1
2
3
4
5 package modernize
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/token"
11 "go/types"
12 "strings"
13 "unicode"
14 "unicode/utf8"
15
16 "golang.org/x/tools/go/analysis"
17 "golang.org/x/tools/go/analysis/passes/inspect"
18 "golang.org/x/tools/go/ast/edge"
19 "golang.org/x/tools/go/types/typeutil"
20 "golang.org/x/tools/internal/analysis/analyzerutil"
21 typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
22 "golang.org/x/tools/internal/astutil"
23 "golang.org/x/tools/internal/typesinternal"
24 "golang.org/x/tools/internal/typesinternal/typeindex"
25 "golang.org/x/tools/internal/versions"
26 )
27
28 var TestingContextAnalyzer = &analysis.Analyzer{
29 Name: "testingcontext",
30 Doc: analyzerutil.MustExtractDoc(doc, "testingcontext"),
31 Requires: []*analysis.Analyzer{
32 inspect.Analyzer,
33 typeindexanalyzer.Analyzer,
34 },
35 Run: testingContext,
36 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#testingcontext",
37 }
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57 func testingContext(pass *analysis.Pass) (any, error) {
58 var (
59 index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
60 info = pass.TypesInfo
61
62 contextWithCancel = index.Object("context", "WithCancel")
63 )
64
65 calls:
66 for cur := range index.Calls(contextWithCancel) {
67 call := cur.Node().(*ast.CallExpr)
68
69
70 arg, ok := call.Args[0].(*ast.CallExpr)
71 if !ok {
72 continue
73 }
74 if !typesinternal.IsFunctionNamed(typeutil.Callee(info, arg), "context", "Background", "TODO") {
75 continue
76 }
77
78
79 parent := cur.Parent()
80 assign, ok := parent.Node().(*ast.AssignStmt)
81 if !ok || assign.Tok != token.DEFINE {
82 continue
83 }
84
85
86
87 var lhs []types.Object
88 for _, expr := range assign.Lhs {
89 id, ok := expr.(*ast.Ident)
90 if !ok {
91 continue calls
92 }
93 obj, ok := info.Defs[id]
94 if !ok {
95 continue calls
96 }
97 lhs = append(lhs, obj)
98 }
99
100 next, ok := parent.NextSibling()
101 if !ok {
102 continue
103 }
104 defr, ok := next.Node().(*ast.DeferStmt)
105 if !ok {
106 continue
107 }
108 deferId, ok := defr.Call.Fun.(*ast.Ident)
109 if !ok || !soleUseIs(index, lhs[1], deferId) {
110 continue
111 }
112
113
114
115
116
117 var testObj types.Object
118 if curFunc, ok := enclosingFunc(cur); ok {
119 switch n := curFunc.Node().(type) {
120 case *ast.FuncLit:
121 if ek, idx := curFunc.ParentEdge(); ek == edge.CallExpr_Args && idx == 1 {
122
123 obj := typeutil.Callee(info, curFunc.Parent().Node().(*ast.CallExpr))
124 if (typesinternal.IsMethodNamed(obj, "testing", "T", "Run") ||
125 typesinternal.IsMethodNamed(obj, "testing", "B", "Run")) &&
126 len(n.Type.Params.List[0].Names) == 1 {
127
128
129 testObj = info.Defs[n.Type.Params.List[0].Names[0]]
130 }
131 }
132
133 case *ast.FuncDecl:
134 testObj = isTestFn(info, n)
135 }
136 }
137 if testObj != nil && analyzerutil.FileUsesGoVersion(pass, astutil.EnclosingFile(cur), versions.Go1_24) {
138
139
140 if _, obj := lhs[0].Parent().LookupParent(testObj.Name(), lhs[0].Pos()); obj == testObj {
141 pass.Report(analysis.Diagnostic{
142 Pos: call.Fun.Pos(),
143 End: call.Fun.End(),
144 Message: fmt.Sprintf("context.WithCancel can be modernized using %s.Context", testObj.Name()),
145 SuggestedFixes: []analysis.SuggestedFix{{
146 Message: fmt.Sprintf("Replace context.WithCancel with %s.Context", testObj.Name()),
147 TextEdits: []analysis.TextEdit{{
148 Pos: assign.Pos(),
149 End: defr.End(),
150 NewText: fmt.Appendf(nil, "%s := %s.Context()", lhs[0].Name(), testObj.Name()),
151 }},
152 }},
153 })
154 }
155 }
156 }
157 return nil, nil
158 }
159
160
161
162 func soleUseIs(index *typeindex.Index, obj types.Object, id *ast.Ident) bool {
163 empty := true
164 for use := range index.Uses(obj) {
165 empty = false
166 if use.Node() != id {
167 return false
168 }
169 }
170 return !empty
171 }
172
173
174
175
176
177
178
179
180
181
182
183 func isTestFn(info *types.Info, fn *ast.FuncDecl) types.Object {
184
185 if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
186 fn.Type.Params == nil ||
187 len(fn.Type.Params.List) != 1 ||
188 len(fn.Type.Params.List[0].Names) != 1 {
189
190 return nil
191 }
192
193 prefix := testKind(fn.Name.Name)
194 if prefix == "" {
195 return nil
196 }
197
198 if tparams := fn.Type.TypeParams; tparams != nil && len(tparams.List) > 0 {
199 return nil
200 }
201
202 obj := info.Defs[fn.Type.Params.List[0].Names[0]]
203 if obj == nil {
204 return nil
205 }
206
207 var name string
208 switch prefix {
209 case "Test":
210 name = "T"
211 case "Benchmark":
212 name = "B"
213 case "Fuzz":
214 name = "F"
215 }
216
217 if !typesinternal.IsPointerToNamed(obj.Type(), "testing", name) {
218 return nil
219 }
220 return obj
221 }
222
223
224
225
226
// testKind classifies a function name by its test prefix. It returns
// "Test", "Benchmark" or "Fuzz" when name begins with that prefix and the
// remainder is either empty or does not start with a lowercase letter
// (so "TestFoo" and "Test" qualify but "Testify" does not). Any other
// name yields "".
func testKind(name string) string {
	for _, kind := range [...]string{"Test", "Benchmark", "Fuzz"} {
		rest, found := strings.CutPrefix(name, kind)
		if !found {
			continue
		}
		// An empty rest decodes to utf8.RuneError, which is not lowercase,
		// so the bare prefix itself is accepted.
		if first, _ := utf8.DecodeRuneInString(rest); unicode.IsLower(first) {
			return ""
		}
		return kind
	}
	return ""
}
251
View as plain text