1
2
3
4
5 package modernize
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/token"
11 "go/types"
12 "strings"
13
14 "golang.org/x/tools/go/analysis"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/ast/edge"
17 "golang.org/x/tools/go/ast/inspector"
18 "golang.org/x/tools/internal/analysis/analyzerutil"
19 typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
20 "golang.org/x/tools/internal/astutil"
21 "golang.org/x/tools/internal/typeparams"
22 "golang.org/x/tools/internal/typesinternal/typeindex"
23 "golang.org/x/tools/internal/versions"
24 )
25
26 var MinMaxAnalyzer = &analysis.Analyzer{
27 Name: "minmax",
28 Doc: analyzerutil.MustExtractDoc(doc, "minmax"),
29 Requires: []*analysis.Analyzer{
30 inspect.Analyzer,
31 typeindexanalyzer.Analyzer,
32 },
33 Run: minmax,
34 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#minmax",
35 }
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// minmax is the Run function of MinMaxAnalyzer.
//
// It first reports user-defined min/max functions that can be removed
// (see checkUserDefinedMinMax). It then, in files that may use the Go 1.21
// built-ins, looks for two if-statement patterns and suggests replacing
// them with a call to min or max:
//
//  1. if a < b { x = a } else { x = b }        =>      x = min(a, b)
//  2. x = a; if a < b { x = b }                =>      x = max(a, b)
//
// All four ordered comparisons (<, <=, >, >=) are accepted. In pattern 2
// the preceding statement may use "=" or ":=", and the comparison may
// mention x itself in place of the value assigned to it
// ("x = a; if x < b { x = b }").
//
// Operands whose type might be floating-point are rejected: the built-ins
// treat NaN specially, so a plain comparison is not equivalent.
func minmax(pass *analysis.Pass) (any, error) {
	// Suggest removal of user-defined min/max functions that are
	// equivalent to the built-ins.
	checkUserDefinedMinMax(pass)

	// check is called for each statement of the form
	//
	//	if a < b { lhs = rhs }
	//
	// (any ordered comparison, no init statement, body a single simple
	// assignment) and decides whether it forms pattern 1 or pattern 2.
	check := func(file *ast.File, curIfStmt inspector.Cursor, compare *ast.BinaryExpr) {
		var (
			ifStmt  = curIfStmt.Node().(*ast.IfStmt)
			tassign = ifStmt.Body.List[0].(*ast.AssignStmt)
			a       = compare.X
			b       = compare.Y
			lhs     = tassign.Lhs[0]
			rhs     = tassign.Rhs[0]
			sign    = isInequality(compare.Op) // -1 for < and <=, +1 for > and >=

			// callArg formats arg as a call argument, preceded by the
			// comments found in [start, end) so the rewrite does not drop them.
			// Note that b is captured by reference: the comparison arg == b
			// sees any reassignment of b made below (pattern 2).
			callArg = func(arg ast.Expr, start, end token.Pos) string {
				comments := allComments(file, start, end)
				return cond(arg == b, ", ", "") + // second argument needs a comma
					cond(comments != "", "\n", "") + // comments need their own line
					comments +
					astutil.Format(pass.Fset, arg)
			}
		)

		if fblock, ok := ifStmt.Else.(*ast.BlockStmt); ok && isAssignBlock(fblock) {
			fassign := fblock.List[0].(*ast.AssignStmt)

			// Have: if a < b { lhs = rhs } else { lhs2 = rhs2 }
			lhs2 := fassign.Lhs[0]
			rhs2 := fassign.Rhs[0]

			// Pattern 1 requires:
			// - lhs == lhs2 (both branches assign the same target), and
			// - {rhs, rhs2} == {a, b}; choosing a when a < b keeps the
			//   sign (min for <), choosing b flips it (max for <).
			if astutil.EqualSyntax(lhs, lhs2) {
				if astutil.EqualSyntax(rhs, a) && astutil.EqualSyntax(rhs2, b) {
					sign = +sign
				} else if astutil.EqualSyntax(rhs2, a) && astutil.EqualSyntax(rhs, b) {
					sign = -sign
				} else {
					return
				}

				sym := cond(sign < 0, "min", "max")

				// Only suggest the rewrite if sym resolves to the built-in
				// at this point, i.e. it is not shadowed by a user declaration.
				if !is[*types.Builtin](lookup(pass.TypesInfo, curIfStmt, sym)) {
					return
				}

				// Pattern 1.
				pass.Report(analysis.Diagnostic{
					// Highlight the condition a < b.
					Pos:     compare.Pos(),
					End:     compare.End(),
					Message: fmt.Sprintf("if/else statement can be modernized using %s", sym),
					SuggestedFixes: []analysis.SuggestedFix{{
						Message: fmt.Sprintf("Replace if statement with %s", sym),
						TextEdits: []analysis.TextEdit{{
							// Replace the whole if/else with "lhs = sym(a, b)",
							// keeping comments from the then-part before a and
							// from the else-part before b.
							Pos: ifStmt.Pos(),
							End: ifStmt.End(),
							NewText: fmt.Appendf(nil, "%s = %s(%s%s)",
								astutil.Format(pass.Fset, lhs),
								sym,
								callArg(a, ifStmt.Pos(), ifStmt.Else.Pos()),
								callArg(b, ifStmt.Else.Pos(), ifStmt.End()),
							),
						}},
					}},
				})
			}

		} else if prev, ok := curIfStmt.PrevSibling(); ok && isSimpleAssign(prev.Node()) && ifStmt.Else == nil {
			fassign := prev.Node().(*ast.AssignStmt)

			// Have: lhs0 = rhs0; if a < b { lhs = rhs }
			//
			// Pattern 2 requires:
			// - lhs == lhs0 (the if overwrites the value just assigned), and
			// - {a, b} == {rhs, rhs0}, where lhs0 may stand in for rhs0 in
			//   the comparison (e.g. "x = y; if x < z { x = z }").
			lhs0 := fassign.Lhs[0]
			rhs0 := fassign.Rhs[0]

			if astutil.EqualSyntax(lhs, lhs0) {
				if astutil.EqualSyntax(rhs, a) && (astutil.EqualSyntax(rhs0, b) || astutil.EqualSyntax(lhs0, b)) {
					sign = +sign
				} else if (astutil.EqualSyntax(rhs0, a) || astutil.EqualSyntax(lhs0, a)) && astutil.EqualSyntax(rhs, b) {
					sign = -sign
				} else {
					return
				}
				sym := cond(sign < 0, "min", "max")

				// Skip if sym is shadowed and does not denote the built-in here.
				if !is[*types.Builtin](lookup(pass.TypesInfo, curIfStmt, sym)) {
					return
				}

				// If lhs0 stood in for rhs0 in the comparison, pass rhs0 to
				// the call rather than lhs0: the preceding statement may be a
				// ":=" that declares lhs0, so "lhs0 := min(lhs0, ...)" would
				// be wrong.
				if astutil.EqualSyntax(lhs0, a) {
					a = rhs0
				} else if astutil.EqualSyntax(lhs0, b) {
					b = rhs0
				}

				// Pattern 2.
				//
				// NOTE(review): this fix message says "Replace if/else" although
				// this pattern has no else branch, while pattern 1's says
				// "Replace if statement"; the wordings look swapped. Confirm
				// whether any tests or clients match these strings before changing.
				pass.Report(analysis.Diagnostic{
					// Highlight the condition a < b.
					Pos:     compare.Pos(),
					End:     compare.End(),
					Message: fmt.Sprintf("if statement can be modernized using %s", sym),
					SuggestedFixes: []analysis.SuggestedFix{{
						Message: fmt.Sprintf("Replace if/else with %s", sym),
						TextEdits: []analysis.TextEdit{{
							Pos: fassign.Pos(),
							End: ifStmt.End(),
							// Replace "lhs0 = rhs0; if ... { ... }" with
							// "lhs = sym(a, b)", reusing the original "=" or ":="
							// and preserving comments.
							NewText: fmt.Appendf(nil, "%s %s %s(%s%s)",
								astutil.Format(pass.Fset, lhs),
								fassign.Tok.String(),
								sym,
								callArg(a, fassign.Pos(), ifStmt.Pos()),
								callArg(b, ifStmt.Pos(), ifStmt.End()),
							),
						}},
					}},
				})
			}
		}
	}

	// Find all "if a < b { lhs = rhs }" statements. filesUsingGoVersion
	// presumably yields only files whose Go version is at least 1.21, the
	// release that introduced the min and max built-ins.
	info := pass.TypesInfo
	for curFile := range filesUsingGoVersion(pass, versions.Go1_21) {
		astFile := curFile.Node().(*ast.File)
		for curIfStmt := range curFile.Preorder((*ast.IfStmt)(nil)) {
			ifStmt := curIfStmt.Node().(*ast.IfStmt)

			// Skip an if statement that is the "else" branch of another:
			//    if cond { ... } else if a < b { lhs = rhs }
			// Replacing it in place would yield "else lhs = min(a, b)",
			// which is not valid Go; handling it would require introducing
			// a block and checking that no further "else" follows.
			if astutil.IsChildOf(curIfStmt, edge.IfStmt_Else) {
				continue
			}

			if compare, ok := ifStmt.Cond.(*ast.BinaryExpr); ok &&
				ifStmt.Init == nil &&
				isInequality(compare.Op) != 0 &&
				isAssignBlock(ifStmt.Body) {
				// tLHS can be nil (presumably for a blank "_" target), in
				// which case the statement is skipped, as are targets whose
				// type might hold a NaN.
				if tLHS := info.TypeOf(ifStmt.Body.List[0].(*ast.AssignStmt).Lhs[0]); tLHS != nil && !maybeNaN(tLHS) {
					// Have: if a < b { lhs = rhs }
					check(astFile, curIfStmt, compare)
				}
			}
		}
	}
	return nil, nil
}
230
231
232 func allComments(file *ast.File, start, end token.Pos) string {
233 var buf strings.Builder
234 for co := range astutil.Comments(file, start, end) {
235 _, _ = fmt.Fprintf(&buf, "%s\n", co.Text)
236 }
237 return buf.String()
238 }
239
240
241
// isInequality classifies an ordered comparison operator by direction.
// It returns -1 for < and <=, +1 for > and >=, and 0 for any other token.
func isInequality(tok token.Token) int {
	if tok == token.LSS || tok == token.LEQ {
		return -1
	}
	if tok == token.GTR || tok == token.GEQ {
		return +1
	}
	return 0
}
251
252
253 func isAssignBlock(b *ast.BlockStmt) bool {
254 if len(b.List) != 1 {
255 return false
256 }
257
258 return isSimpleAssign(b.List[0])
259 }
260
261
// isSimpleAssign reports whether n is an assignment statement of the form
// "lhs = rhs" or "lhs := rhs" with exactly one expression on each side.
// Compound assignments such as "+=" and tuple assignments are rejected.
func isSimpleAssign(n ast.Node) bool {
	assign, isAssign := n.(*ast.AssignStmt)
	if !isAssign {
		return false
	}
	switch assign.Tok {
	case token.ASSIGN, token.DEFINE:
		return len(assign.Lhs) == 1 && len(assign.Rhs) == 1
	default:
		return false
	}
}
269
270
271 func maybeNaN(t types.Type) bool {
272
273
274
275 t = typeparams.CoreType(t)
276 if t == nil {
277 return true
278 }
279 if basic, ok := t.(*types.Basic); ok && basic.Info()&types.IsFloat != 0 {
280 return true
281 }
282 return false
283 }
284
285
286
287 func checkUserDefinedMinMax(pass *analysis.Pass) {
288 index := pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
289
290
291 for _, funcName := range []string{"min", "max"} {
292 if fn, ok := pass.Pkg.Scope().Lookup(funcName).(*types.Func); ok {
293
294 if def, ok := index.Def(fn); ok {
295 decl := def.Parent().Node().(*ast.FuncDecl)
296
297 if canUseBuiltinMinMax(fn, decl.Body) {
298
299 pos := decl.Pos()
300 if docs := astutil.DocComment(decl); docs != nil {
301 pos = docs.Pos()
302 }
303
304 pass.Report(analysis.Diagnostic{
305 Pos: decl.Pos(),
306 End: decl.End(),
307 Message: fmt.Sprintf("user-defined %s function is equivalent to built-in %s and can be removed", funcName, funcName),
308 SuggestedFixes: []analysis.SuggestedFix{{
309 Message: fmt.Sprintf("Remove user-defined %s function", funcName),
310 TextEdits: []analysis.TextEdit{{
311 Pos: pos,
312 End: decl.End(),
313 }},
314 }},
315 })
316 }
317 }
318 }
319 }
320 }
321
322
323
324 func canUseBuiltinMinMax(fn *types.Func, body *ast.BlockStmt) bool {
325 sig := fn.Type().(*types.Signature)
326
327
328 if sig.Params().Len() != 2 {
329 return false
330 }
331
332
333 for param := range sig.Params().Variables() {
334 if maybeNaN(param.Type()) {
335 return false
336 }
337 }
338
339
340 if sig.Results().Len() != 1 {
341 return false
342 }
343
344
345 if body == nil {
346 return false
347 }
348
349 return hasMinMaxLogic(body, fn.Name())
350 }
351
352
353 func hasMinMaxLogic(body *ast.BlockStmt, funcName string) bool {
354
355 if len(body.List) == 1 {
356 if ifStmt, ok := body.List[0].(*ast.IfStmt); ok {
357
358 if elseBlock, ok := ifStmt.Else.(*ast.BlockStmt); ok && len(elseBlock.List) == 1 {
359 if elseRet, ok := elseBlock.List[0].(*ast.ReturnStmt); ok && len(elseRet.Results) == 1 {
360 return checkMinMaxPattern(ifStmt, elseRet.Results[0], funcName)
361 }
362 }
363 }
364 }
365
366
367 if len(body.List) == 2 {
368 if ifStmt, ok := body.List[0].(*ast.IfStmt); ok && ifStmt.Else == nil {
369 if retStmt, ok := body.List[1].(*ast.ReturnStmt); ok && len(retStmt.Results) == 1 {
370 return checkMinMaxPattern(ifStmt, retStmt.Results[0], funcName)
371 }
372 }
373 }
374
375 return false
376 }
377
378
379
380
381
382 func checkMinMaxPattern(ifStmt *ast.IfStmt, falseResult ast.Expr, funcName string) bool {
383
384 cmp, ok := ifStmt.Cond.(*ast.BinaryExpr)
385 if !ok {
386 return false
387 }
388
389
390 if len(ifStmt.Body.List) != 1 {
391 return false
392 }
393
394 thenRet, ok := ifStmt.Body.List[0].(*ast.ReturnStmt)
395 if !ok || len(thenRet.Results) != 1 {
396 return false
397 }
398
399
400 sign := isInequality(cmp.Op)
401 if sign == 0 {
402 return false
403 }
404
405 t := thenRet.Results[0]
406 f := falseResult
407 x := cmp.X
408 y := cmp.Y
409
410
411 if astutil.EqualSyntax(t, x) && astutil.EqualSyntax(f, y) {
412 sign = +sign
413 } else if astutil.EqualSyntax(t, y) && astutil.EqualSyntax(f, x) {
414 sign = -sign
415 } else {
416 return false
417 }
418
419
420 return cond(sign < 0, "min", "max") == funcName
421 }
422
423
424
// is reports whether the dynamic value of x can be asserted to type T.
// A nil x always yields false.
func is[T any](x any) bool {
	if _, matches := x.(T); matches {
		return true
	}
	return false
}
429
// cond is a conditional expression: it returns t when cond is true and
// f otherwise. Both t and f are evaluated by the caller before the call.
func cond[T any](cond bool, t, f T) T {
	result := f
	if cond {
		result = t
	}
	return result
}
437
View as plain text