Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/modernize/minmax.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package modernize
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/token"
    11  	"go/types"
    12  	"strings"
    13  
    14  	"golang.org/x/tools/go/analysis"
    15  	"golang.org/x/tools/go/analysis/passes/inspect"
    16  	"golang.org/x/tools/go/ast/edge"
    17  	"golang.org/x/tools/go/ast/inspector"
    18  	"golang.org/x/tools/internal/analysis/analyzerutil"
    19  	typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
    20  	"golang.org/x/tools/internal/astutil"
    21  	"golang.org/x/tools/internal/typeparams"
    22  	"golang.org/x/tools/internal/typesinternal/typeindex"
    23  	"golang.org/x/tools/internal/versions"
    24  )
    25  
    26  var MinMaxAnalyzer = &analysis.Analyzer{
    27  	Name: "minmax",
    28  	Doc:  analyzerutil.MustExtractDoc(doc, "minmax"),
    29  	Requires: []*analysis.Analyzer{
    30  		inspect.Analyzer,
    31  		typeindexanalyzer.Analyzer,
    32  	},
    33  	Run: minmax,
    34  	URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#minmax",
    35  }
    36  
    37  // The minmax pass replaces if/else statements with calls to min or max,
    38  // and removes user-defined min/max functions that are equivalent to built-ins.
    39  //
    40  // If/else replacement patterns:
    41  //
    42  //  1. if a < b { x = a } else { x = b }        =>      x = min(a, b)
    43  //  2. x = a; if a < b { x = b }                =>      x = max(a, b)
    44  //
    45  // Pattern 1 requires that a is not NaN, and pattern 2 requires that b
    46  // is not Nan. Since this is hard to prove, we reject floating-point
    47  // numbers.
    48  //
    49  // Function removal:
    50  // User-defined min/max functions are suggested for removal if they may
    51  // be safely replaced by their built-in namesake.
    52  //
    53  // Variants:
    54  // - all four ordered comparisons
    55  // - "x := a" or "x = a" or "var x = a" in pattern 2
    56  // - "x < b" or "a < b" in pattern 2
    57  func minmax(pass *analysis.Pass) (any, error) {
    58  	// Check for user-defined min/max functions that can be removed
    59  	checkUserDefinedMinMax(pass)
    60  
    61  	// check is called for all statements of this form:
    62  	//   if a < b { lhs = rhs }
    63  	check := func(file *ast.File, curIfStmt inspector.Cursor, compare *ast.BinaryExpr) {
    64  		var (
    65  			ifStmt  = curIfStmt.Node().(*ast.IfStmt)
    66  			tassign = ifStmt.Body.List[0].(*ast.AssignStmt)
    67  			a       = compare.X
    68  			b       = compare.Y
    69  			lhs     = tassign.Lhs[0]
    70  			rhs     = tassign.Rhs[0]
    71  			sign    = isInequality(compare.Op)
    72  
    73  			// callArg formats a call argument, preserving comments from [start-end).
    74  			callArg = func(arg ast.Expr, start, end token.Pos) string {
    75  				comments := allComments(file, start, end)
    76  				return cond(arg == b, ", ", "") + // second argument needs a comma
    77  					cond(comments != "", "\n", "") + // comments need their own line
    78  					comments +
    79  					astutil.Format(pass.Fset, arg)
    80  			}
    81  		)
    82  
    83  		if fblock, ok := ifStmt.Else.(*ast.BlockStmt); ok && isAssignBlock(fblock) {
    84  			fassign := fblock.List[0].(*ast.AssignStmt)
    85  
    86  			// Have: if a < b { lhs = rhs } else { lhs2 = rhs2 }
    87  			lhs2 := fassign.Lhs[0]
    88  			rhs2 := fassign.Rhs[0]
    89  
    90  			// For pattern 1, check that:
    91  			// - lhs = lhs2
    92  			// - {rhs,rhs2} = {a,b}
    93  			if astutil.EqualSyntax(lhs, lhs2) {
    94  				if astutil.EqualSyntax(rhs, a) && astutil.EqualSyntax(rhs2, b) {
    95  					sign = +sign
    96  				} else if astutil.EqualSyntax(rhs2, a) && astutil.EqualSyntax(rhs, b) {
    97  					sign = -sign
    98  				} else {
    99  					return
   100  				}
   101  
   102  				sym := cond(sign < 0, "min", "max")
   103  
   104  				if !is[*types.Builtin](lookup(pass.TypesInfo, curIfStmt, sym)) {
   105  					return // min/max function is shadowed
   106  				}
   107  
   108  				// pattern 1
   109  				//
   110  				// TODO(adonovan): if lhs is declared "var lhs T" on preceding line,
   111  				// simplify the whole thing to "lhs := min(a, b)".
   112  				pass.Report(analysis.Diagnostic{
   113  					// Highlight the condition a < b.
   114  					Pos:     compare.Pos(),
   115  					End:     compare.End(),
   116  					Message: fmt.Sprintf("if/else statement can be modernized using %s", sym),
   117  					SuggestedFixes: []analysis.SuggestedFix{{
   118  						Message: fmt.Sprintf("Replace if statement with %s", sym),
   119  						TextEdits: []analysis.TextEdit{{
   120  							// Replace IfStmt with lhs = min(a, b).
   121  							Pos: ifStmt.Pos(),
   122  							End: ifStmt.End(),
   123  							NewText: fmt.Appendf(nil, "%s = %s(%s%s)",
   124  								astutil.Format(pass.Fset, lhs),
   125  								sym,
   126  								callArg(a, ifStmt.Pos(), ifStmt.Else.Pos()),
   127  								callArg(b, ifStmt.Else.Pos(), ifStmt.End()),
   128  							),
   129  						}},
   130  					}},
   131  				})
   132  			}
   133  
   134  		} else if prev, ok := curIfStmt.PrevSibling(); ok && isSimpleAssign(prev.Node()) && ifStmt.Else == nil {
   135  			fassign := prev.Node().(*ast.AssignStmt)
   136  
   137  			// Have: lhs0 = rhs0; if a < b { lhs = rhs }
   138  			//
   139  			// For pattern 2, check that
   140  			// - lhs = lhs0
   141  			// - {a,b} = {rhs,rhs0} or {rhs,lhs0}
   142  			//   The replacement must use rhs0 not lhs0 though.
   143  			//   For example, we accept this variant:
   144  			//     lhs = x; if lhs < y { lhs = y }   =>   lhs = min(x, y), not min(lhs, y)
   145  			//
   146  			// TODO(adonovan): accept "var lhs0 = rhs0" form too.
   147  			lhs0 := fassign.Lhs[0]
   148  			rhs0 := fassign.Rhs[0]
   149  
   150  			if astutil.EqualSyntax(lhs, lhs0) {
   151  				if astutil.EqualSyntax(rhs, a) && (astutil.EqualSyntax(rhs0, b) || astutil.EqualSyntax(lhs0, b)) {
   152  					sign = +sign
   153  				} else if (astutil.EqualSyntax(rhs0, a) || astutil.EqualSyntax(lhs0, a)) && astutil.EqualSyntax(rhs, b) {
   154  					sign = -sign
   155  				} else {
   156  					return
   157  				}
   158  				sym := cond(sign < 0, "min", "max")
   159  
   160  				if !is[*types.Builtin](lookup(pass.TypesInfo, curIfStmt, sym)) {
   161  					return // min/max function is shadowed
   162  				}
   163  
   164  				// Permit lhs0 to stand for rhs0 in the matching,
   165  				// but don't actually reduce to lhs0 = min(lhs0, rhs)
   166  				// since the "=" could be a ":=". Use min(rhs0, rhs).
   167  				if astutil.EqualSyntax(lhs0, a) {
   168  					a = rhs0
   169  				} else if astutil.EqualSyntax(lhs0, b) {
   170  					b = rhs0
   171  				}
   172  
   173  				// pattern 2
   174  				pass.Report(analysis.Diagnostic{
   175  					// Highlight the condition a < b.
   176  					Pos:     compare.Pos(),
   177  					End:     compare.End(),
   178  					Message: fmt.Sprintf("if statement can be modernized using %s", sym),
   179  					SuggestedFixes: []analysis.SuggestedFix{{
   180  						Message: fmt.Sprintf("Replace if/else with %s", sym),
   181  						TextEdits: []analysis.TextEdit{{
   182  							Pos: fassign.Pos(),
   183  							End: ifStmt.End(),
   184  							// Replace "x := a; if ... {}" with "x = min(...)", preserving comments.
   185  							NewText: fmt.Appendf(nil, "%s %s %s(%s%s)",
   186  								astutil.Format(pass.Fset, lhs),
   187  								fassign.Tok.String(),
   188  								sym,
   189  								callArg(a, fassign.Pos(), ifStmt.Pos()),
   190  								callArg(b, ifStmt.Pos(), ifStmt.End()),
   191  							),
   192  						}},
   193  					}},
   194  				})
   195  			}
   196  		}
   197  	}
   198  
   199  	// Find all "if a < b { lhs = rhs }" statements.
   200  	info := pass.TypesInfo
   201  	for curFile := range filesUsingGoVersion(pass, versions.Go1_21) {
   202  		astFile := curFile.Node().(*ast.File)
   203  		for curIfStmt := range curFile.Preorder((*ast.IfStmt)(nil)) {
   204  			ifStmt := curIfStmt.Node().(*ast.IfStmt)
   205  
   206  			// Don't bother handling "if a < b { lhs = rhs }" when it appears
   207  			// as the "else" branch of another if-statement.
   208  			//    if cond { ... } else if a < b { lhs = rhs }
   209  			// (This case would require introducing another block
   210  			//    if cond { ... } else { if a < b { lhs = rhs } }
   211  			// and checking that there is no following "else".)
   212  			if astutil.IsChildOf(curIfStmt, edge.IfStmt_Else) {
   213  				continue
   214  			}
   215  
   216  			if compare, ok := ifStmt.Cond.(*ast.BinaryExpr); ok &&
   217  				ifStmt.Init == nil &&
   218  				isInequality(compare.Op) != 0 &&
   219  				isAssignBlock(ifStmt.Body) {
   220  				// a blank var has no type.
   221  				if tLHS := info.TypeOf(ifStmt.Body.List[0].(*ast.AssignStmt).Lhs[0]); tLHS != nil && !maybeNaN(tLHS) {
   222  					// Have: if a < b { lhs = rhs }
   223  					check(astFile, curIfStmt, compare)
   224  				}
   225  			}
   226  		}
   227  	}
   228  	return nil, nil
   229  }
   230  
   231  // allComments collects all the comments from start to end.
   232  func allComments(file *ast.File, start, end token.Pos) string {
   233  	var buf strings.Builder
   234  	for co := range astutil.Comments(file, start, end) {
   235  		_, _ = fmt.Fprintf(&buf, "%s\n", co.Text)
   236  	}
   237  	return buf.String()
   238  }
   239  
   240  // isInequality reports non-zero if tok is one of < <= => >:
   241  // +1 for > and -1 for <.
   242  func isInequality(tok token.Token) int {
   243  	switch tok {
   244  	case token.LEQ, token.LSS:
   245  		return -1
   246  	case token.GEQ, token.GTR:
   247  		return +1
   248  	}
   249  	return 0
   250  }
   251  
   252  // isAssignBlock reports whether b is a block of the form { lhs = rhs }.
   253  func isAssignBlock(b *ast.BlockStmt) bool {
   254  	if len(b.List) != 1 {
   255  		return false
   256  	}
   257  	// Inv: the sole statement cannot be { lhs := rhs }.
   258  	return isSimpleAssign(b.List[0])
   259  }
   260  
   261  // isSimpleAssign reports whether n has the form "lhs = rhs" or "lhs := rhs".
   262  func isSimpleAssign(n ast.Node) bool {
   263  	assign, ok := n.(*ast.AssignStmt)
   264  	return ok &&
   265  		(assign.Tok == token.ASSIGN || assign.Tok == token.DEFINE) &&
   266  		len(assign.Lhs) == 1 &&
   267  		len(assign.Rhs) == 1
   268  }
   269  
   270  // maybeNaN reports whether t is (or may be) a floating-point type.
   271  func maybeNaN(t types.Type) bool {
   272  	// For now, we rely on core types.
   273  	// TODO(adonovan): In the post-core-types future,
   274  	// follow the approach of types.Checker.applyTypeFunc.
   275  	t = typeparams.CoreType(t)
   276  	if t == nil {
   277  		return true // fail safe
   278  	}
   279  	if basic, ok := t.(*types.Basic); ok && basic.Info()&types.IsFloat != 0 {
   280  		return true
   281  	}
   282  	return false
   283  }
   284  
   285  // checkUserDefinedMinMax looks for user-defined min/max functions that are
   286  // equivalent to the built-in functions and suggests removing them.
   287  func checkUserDefinedMinMax(pass *analysis.Pass) {
   288  	index := pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
   289  
   290  	// Look up min and max functions by name in package scope
   291  	for _, funcName := range []string{"min", "max"} {
   292  		if fn, ok := pass.Pkg.Scope().Lookup(funcName).(*types.Func); ok {
   293  			// Use typeindex to get the FuncDecl directly
   294  			if def, ok := index.Def(fn); ok {
   295  				decl := def.Parent().Node().(*ast.FuncDecl)
   296  				// Check if this function matches the built-in min/max signature and behavior
   297  				if canUseBuiltinMinMax(fn, decl.Body) {
   298  					// Expand to include leading doc comment
   299  					pos := decl.Pos()
   300  					if docs := astutil.DocComment(decl); docs != nil {
   301  						pos = docs.Pos()
   302  					}
   303  
   304  					pass.Report(analysis.Diagnostic{
   305  						Pos:     decl.Pos(),
   306  						End:     decl.End(),
   307  						Message: fmt.Sprintf("user-defined %s function is equivalent to built-in %s and can be removed", funcName, funcName),
   308  						SuggestedFixes: []analysis.SuggestedFix{{
   309  							Message: fmt.Sprintf("Remove user-defined %s function", funcName),
   310  							TextEdits: []analysis.TextEdit{{
   311  								Pos: pos,
   312  								End: decl.End(),
   313  							}},
   314  						}},
   315  					})
   316  				}
   317  			}
   318  		}
   319  	}
   320  }
   321  
   322  // canUseBuiltinMinMax reports whether it is safe to replace a call
   323  // to this min or max function by its built-in namesake.
   324  func canUseBuiltinMinMax(fn *types.Func, body *ast.BlockStmt) bool {
   325  	sig := fn.Type().(*types.Signature)
   326  
   327  	// Only consider the most common case: exactly 2 parameters
   328  	if sig.Params().Len() != 2 {
   329  		return false
   330  	}
   331  
   332  	// Check if any parameter might be floating-point
   333  	for param := range sig.Params().Variables() {
   334  		if maybeNaN(param.Type()) {
   335  			return false // Don't suggest removal for float types due to NaN handling
   336  		}
   337  	}
   338  
   339  	// Must have exactly one return value
   340  	if sig.Results().Len() != 1 {
   341  		return false
   342  	}
   343  
   344  	// Check that the function body implements the expected min/max logic
   345  	if body == nil {
   346  		return false
   347  	}
   348  
   349  	return hasMinMaxLogic(body, fn.Name())
   350  }
   351  
   352  // hasMinMaxLogic checks if the function body implements simple min/max logic.
   353  func hasMinMaxLogic(body *ast.BlockStmt, funcName string) bool {
   354  	// Pattern 1: Single if/else statement
   355  	if len(body.List) == 1 {
   356  		if ifStmt, ok := body.List[0].(*ast.IfStmt); ok {
   357  			// Get the "false" result from the else block
   358  			if elseBlock, ok := ifStmt.Else.(*ast.BlockStmt); ok && len(elseBlock.List) == 1 {
   359  				if elseRet, ok := elseBlock.List[0].(*ast.ReturnStmt); ok && len(elseRet.Results) == 1 {
   360  					return checkMinMaxPattern(ifStmt, elseRet.Results[0], funcName)
   361  				}
   362  			}
   363  		}
   364  	}
   365  
   366  	// Pattern 2: if statement followed by return
   367  	if len(body.List) == 2 {
   368  		if ifStmt, ok := body.List[0].(*ast.IfStmt); ok && ifStmt.Else == nil {
   369  			if retStmt, ok := body.List[1].(*ast.ReturnStmt); ok && len(retStmt.Results) == 1 {
   370  				return checkMinMaxPattern(ifStmt, retStmt.Results[0], funcName)
   371  			}
   372  		}
   373  	}
   374  
   375  	return false
   376  }
   377  
   378  // checkMinMaxPattern checks if an if statement implements min/max logic.
   379  // ifStmt: the if statement to check
   380  // falseResult: the expression returned when the condition is false
   381  // funcName: "min" or "max"
   382  func checkMinMaxPattern(ifStmt *ast.IfStmt, falseResult ast.Expr, funcName string) bool {
   383  	// Must have condition with comparison
   384  	cmp, ok := ifStmt.Cond.(*ast.BinaryExpr)
   385  	if !ok {
   386  		return false
   387  	}
   388  
   389  	// Check if then branch returns one of the compared values
   390  	if len(ifStmt.Body.List) != 1 {
   391  		return false
   392  	}
   393  
   394  	thenRet, ok := ifStmt.Body.List[0].(*ast.ReturnStmt)
   395  	if !ok || len(thenRet.Results) != 1 {
   396  		return false
   397  	}
   398  
   399  	// Use the same logic as the existing minmax analyzer
   400  	sign := isInequality(cmp.Op)
   401  	if sign == 0 {
   402  		return false // Not a comparison operator
   403  	}
   404  
   405  	t := thenRet.Results[0] // "true" result
   406  	f := falseResult        // "false" result
   407  	x := cmp.X              // left operand
   408  	y := cmp.Y              // right operand
   409  
   410  	// Check operand order and adjust sign accordingly
   411  	if astutil.EqualSyntax(t, x) && astutil.EqualSyntax(f, y) {
   412  		sign = +sign
   413  	} else if astutil.EqualSyntax(t, y) && astutil.EqualSyntax(f, x) {
   414  		sign = -sign
   415  	} else {
   416  		return false
   417  	}
   418  
   419  	// Check if the sign matches the function name
   420  	return cond(sign < 0, "min", "max") == funcName
   421  }
   422  
   423  // -- utils --
   424  
   425  func is[T any](x any) bool {
   426  	_, ok := x.(T)
   427  	return ok
   428  }
   429  
   430  func cond[T any](cond bool, t, f T) T {
   431  	if cond {
   432  		return t
   433  	} else {
   434  		return f
   435  	}
   436  }
   437  

View as plain text