1
2
3
4
5 package lostcancel
6
7 import (
8 _ "embed"
9 "fmt"
10 "go/ast"
11 "go/types"
12
13 "golang.org/x/tools/go/analysis"
14 "golang.org/x/tools/go/analysis/passes/ctrlflow"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
17 "golang.org/x/tools/go/ast/inspector"
18 "golang.org/x/tools/go/cfg"
19 )
20
21
22 var doc string
23
24 var Analyzer = &analysis.Analyzer{
25 Name: "lostcancel",
26 Doc: analysisutil.MustExtractDoc(doc, "lostcancel"),
27 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/lostcancel",
28 Run: run,
29 Requires: []*analysis.Analyzer{
30 inspect.Analyzer,
31 ctrlflow.Analyzer,
32 },
33 }
34
// debug, when true, prints each analyzed function's control-flow graph and
// the block path that lostCancelPath follows to an offending return.
const debug bool = false

// contextPackage is the import path against which package-qualified
// identifiers such as context.WithCancel are matched.
var contextPackage string = "context"
38
39
40
41
42
43
44
45
46
47
48
49 func run(pass *analysis.Pass) (interface{}, error) {
50
51 if !analysisutil.Imports(pass.Pkg, contextPackage) {
52 return nil, nil
53 }
54
55
56 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
57 nodeTypes := []ast.Node{
58 (*ast.FuncLit)(nil),
59 (*ast.FuncDecl)(nil),
60 }
61 inspect.Preorder(nodeTypes, func(n ast.Node) {
62 runFunc(pass, n)
63 })
64 return nil, nil
65 }
66
67 func runFunc(pass *analysis.Pass, node ast.Node) {
68
69 var funcScope *types.Scope
70 switch v := node.(type) {
71 case *ast.FuncLit:
72 funcScope = pass.TypesInfo.Scopes[v.Type]
73 case *ast.FuncDecl:
74 funcScope = pass.TypesInfo.Scopes[v.Type]
75 }
76
77
78 cancelvars := make(map[*types.Var]ast.Node)
79
80
81
82
83
84
85 stack := make([]ast.Node, 0, 32)
86 ast.Inspect(node, func(n ast.Node) bool {
87 switch n.(type) {
88 case *ast.FuncLit:
89 if len(stack) > 0 {
90 return false
91 }
92 case nil:
93 stack = stack[:len(stack)-1]
94 return true
95 }
96 stack = append(stack, n)
97
98
99
100
101
102
103
104 if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-2]) {
105 return true
106 }
107 var id *ast.Ident
108 stmt := stack[len(stack)-3]
109 switch stmt := stmt.(type) {
110 case *ast.ValueSpec:
111 if len(stmt.Names) > 1 {
112 id = stmt.Names[1]
113 }
114 case *ast.AssignStmt:
115 if len(stmt.Lhs) > 1 {
116 id, _ = stmt.Lhs[1].(*ast.Ident)
117 }
118 }
119 if id != nil {
120 if id.Name == "_" {
121 pass.ReportRangef(id,
122 "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
123 n.(*ast.SelectorExpr).Sel.Name)
124 } else if v, ok := pass.TypesInfo.Uses[id].(*types.Var); ok {
125
126
127 if funcScope.Contains(v.Pos()) {
128 cancelvars[v] = stmt
129 }
130 } else if v, ok := pass.TypesInfo.Defs[id].(*types.Var); ok {
131 cancelvars[v] = stmt
132 }
133 }
134 return true
135 })
136
137 if len(cancelvars) == 0 {
138 return
139 }
140
141
142 cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs)
143 var g *cfg.CFG
144 var sig *types.Signature
145 switch node := node.(type) {
146 case *ast.FuncDecl:
147 sig, _ = pass.TypesInfo.Defs[node.Name].Type().(*types.Signature)
148 if node.Name.Name == "main" && sig.Recv() == nil && pass.Pkg.Name() == "main" {
149
150
151 return
152 }
153 g = cfgs.FuncDecl(node)
154
155 case *ast.FuncLit:
156 sig, _ = pass.TypesInfo.Types[node.Type].Type.(*types.Signature)
157 g = cfgs.FuncLit(node)
158 }
159 if sig == nil {
160 return
161 }
162
163
164 if debug {
165 fmt.Println(g.Format(pass.Fset))
166 }
167
168
169
170
171 for v, stmt := range cancelvars {
172 if ret := lostCancelPath(pass, g, v, stmt, sig); ret != nil {
173 lineno := pass.Fset.Position(stmt.Pos()).Line
174 pass.ReportRangef(stmt, "the %s function is not used on all paths (possible context leak)", v.Name())
175
176 pos, end := ret.Pos(), ret.End()
177
178
179 if pass.Fset.File(pos) != pass.Fset.File(end) {
180 end = pos
181 }
182 pass.Report(analysis.Diagnostic{
183 Pos: pos,
184 End: end,
185 Message: fmt.Sprintf("this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno),
186 })
187 }
188 }
189 }
190
// isCall reports whether n is a call expression.
func isCall(n ast.Node) bool {
	switch n.(type) {
	case *ast.CallExpr:
		return true
	default:
		return false
	}
}
192
193
194
195 func isContextWithCancel(info *types.Info, n ast.Node) bool {
196 sel, ok := n.(*ast.SelectorExpr)
197 if !ok {
198 return false
199 }
200 switch sel.Sel.Name {
201 case "WithCancel", "WithCancelCause",
202 "WithTimeout", "WithTimeoutCause",
203 "WithDeadline", "WithDeadlineCause":
204 default:
205 return false
206 }
207 if x, ok := sel.X.(*ast.Ident); ok {
208 if pkgname, ok := info.Uses[x].(*types.PkgName); ok {
209 return pkgname.Imported().Path() == contextPackage
210 }
211
212
213 return x.Name == "context"
214 }
215 return false
216 }
217
218
219
220
221
// lostCancelPath searches the control-flow graph g for a path from stmt
// (the statement that assigns the cancel variable v) to a return statement
// along which v is never "used". It returns the first such return statement
// found by its depth-first search, or nil if v is used on every path.
//
// A "use" is any identifier resolving to v or, when v is one of the named
// results of sig, a bare return statement. sig may be nil.
//
// It panics if stmt does not appear in any block of g.
func lostCancelPath(pass *analysis.Pass, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
	vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)

	// uses reports whether any of stmts contains a use of v.
	uses := func(pass *analysis.Pass, v *types.Var, stmts []ast.Node) bool {
		found := false
		for _, stmt := range stmts {
			ast.Inspect(stmt, func(n ast.Node) bool {
				switch n := n.(type) {
				case *ast.Ident:
					if pass.TypesInfo.Uses[n] == v {
						found = true
					}
				case *ast.ReturnStmt:
					// A bare return implicitly returns the named results,
					// so it counts as a use of v when v is one of them.
					if n.Results == nil && vIsNamedResult {
						found = true
					}
				}
				// Once a use has been found, don't descend any further.
				return !found
			})
		}
		return found
	}

	// blockUses applies uses to all nodes of a CFG block, memoized per block.
	memo := make(map[*cfg.Block]bool)
	blockUses := func(pass *analysis.Pass, v *types.Var, b *cfg.Block) bool {
		res, ok := memo[b]
		if !ok {
			res = uses(pass, v, b.Nodes)
			memo[b] = res
		}
		return res
	}

	// Locate the block containing stmt, and the nodes that follow stmt
	// within that block.
	var defblock *cfg.Block
	var rest []ast.Node
outer:
	for _, b := range g.Blocks {
		for i, n := range b.Nodes {
			if n == stmt {
				defblock = b
				rest = b.Nodes[i+1:]
				break outer
			}
		}
	}
	if defblock == nil {
		panic("internal error: can't find defining block for cancel var")
	}

	// v is used later in its own block, so every path from stmt uses it.
	if uses(pass, v, rest) {
		return nil
	}

	// The defining block itself ends in a return without using v.
	if ret := defblock.Return(); ret != nil {
		return ret
	}

	// Depth-first search over successor blocks for one that returns before
	// any use of v. Blocks that use v prune the search, and each block is
	// visited at most once.
	seen := make(map[*cfg.Block]bool)
	var search func(blocks []*cfg.Block) *ast.ReturnStmt
	search = func(blocks []*cfg.Block) *ast.ReturnStmt {
		for _, b := range blocks {
			if seen[b] {
				continue
			}
			seen[b] = true

			// v is used in this block; paths through it are fine.
			if blockUses(pass, v, b) {
				continue
			}

			// Reached a return without having used v.
			if ret := b.Return(); ret != nil {
				if debug {
					fmt.Printf("found path to return in block %s\n", b)
				}
				return ret
			}

			// Keep searching through this block's successors.
			if ret := search(b.Succs); ret != nil {
				if debug {
					fmt.Printf(" from block %s\n", b)
				}
				return ret
			}
		}
		return nil
	}
	return search(defblock.Succs)
}
323
// tupleContains reports whether v is one of the variables in tuple.
// A nil or empty tuple contains nothing.
func tupleContains(tuple *types.Tuple, v *types.Var) bool {
	for idx := tuple.Len() - 1; idx >= 0; idx-- {
		if v == tuple.At(idx) {
			return true
		}
	}
	return false
}
332
View as plain text