Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/modernize/slicescontains.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package modernize
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/token"
    11  	"go/types"
    12  
    13  	"golang.org/x/tools/go/analysis"
    14  	"golang.org/x/tools/go/analysis/passes/inspect"
    15  	"golang.org/x/tools/go/ast/inspector"
    16  	"golang.org/x/tools/go/types/typeutil"
    17  	"golang.org/x/tools/internal/analysis/analyzerutil"
    18  	typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
    19  	"golang.org/x/tools/internal/astutil"
    20  	"golang.org/x/tools/internal/refactor"
    21  	"golang.org/x/tools/internal/typeparams"
    22  	"golang.org/x/tools/internal/typesinternal/typeindex"
    23  	"golang.org/x/tools/internal/versions"
    24  )
    25  
    26  var SlicesContainsAnalyzer = &analysis.Analyzer{
    27  	Name: "slicescontains",
    28  	Doc:  analyzerutil.MustExtractDoc(doc, "slicescontains"),
    29  	Requires: []*analysis.Analyzer{
    30  		inspect.Analyzer,
    31  		typeindexanalyzer.Analyzer,
    32  	},
    33  	Run: slicescontains,
    34  	URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#slicescontains",
    35  }
    36  
    37  // The slicescontains pass identifies loops that can be replaced by a
    38  // call to slices.Contains{,Func}. For example:
    39  //
    40  //	for i, elem := range s {
    41  //		if elem == needle {
    42  //			...
    43  //			break
    44  //		}
    45  //	}
    46  //
    47  // =>
    48  //
    49  //	if slices.Contains(s, needle) { ... }
    50  //
    51  // Variants:
    52  //   - if the if-condition is f(elem), the replacement
    53  //     uses slices.ContainsFunc(s, f).
    54  //   - if the if-body is "return true" and the fallthrough
    55  //     statement is "return false" (or vice versa), the
    56  //     loop becomes "return [!]slices.Contains(...)".
    57  //   - if the if-body is "found = true" and the previous
    58  //     statement is "found = false" (or vice versa), the
    59  //     loop becomes "found = [!]slices.Contains(...)".
    60  //
    61  // It may change cardinality of effects of the "needle" expression.
    62  // (Mostly this appears to be a desirable optimization, avoiding
    63  // redundantly repeated evaluation.)
    64  //
    65  // TODO(adonovan): Add a check that needle/predicate expression from
    66  // if-statement has no effects. Now the program behavior may change.
    67  func slicescontains(pass *analysis.Pass) (any, error) {
    68  	// Skip the analyzer in packages where its
    69  	// fixes would create an import cycle.
    70  	if within(pass, "slices", "runtime") {
    71  		return nil, nil
    72  	}
    73  
    74  	var (
    75  		index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
    76  		info  = pass.TypesInfo
    77  	)
    78  
    79  	// check is called for each RangeStmt of this form:
    80  	//   for i, elem := range s { if cond { ... } }
    81  	check := func(file *ast.File, curRange inspector.Cursor) {
    82  		rng := curRange.Node().(*ast.RangeStmt)
    83  		ifStmt := rng.Body.List[0].(*ast.IfStmt)
    84  
    85  		// isSliceElem reports whether e denotes the
    86  		// current slice element (elem or s[i]).
    87  		isSliceElem := func(e ast.Expr) bool {
    88  			if rng.Value != nil && astutil.EqualSyntax(e, rng.Value) {
    89  				return true // "elem"
    90  			}
    91  			if x, ok := e.(*ast.IndexExpr); ok &&
    92  				astutil.EqualSyntax(x.X, rng.X) &&
    93  				astutil.EqualSyntax(x.Index, rng.Key) {
    94  				return true // "s[i]"
    95  			}
    96  			return false
    97  		}
    98  
    99  		// Examine the condition for one of these forms:
   100  		//
   101  		// - if elem or s[i] == needle  { ... } => Contains
   102  		// - if predicate(s[i] or elem) { ... } => ContainsFunc
   103  		var (
   104  			funcName string   // "Contains" or "ContainsFunc"
   105  			arg2     ast.Expr // second argument to func (needle or predicate)
   106  		)
   107  		switch cond := ifStmt.Cond.(type) {
   108  		case *ast.BinaryExpr:
   109  			if cond.Op == token.EQL {
   110  				var elem ast.Expr
   111  				if isSliceElem(cond.X) {
   112  					funcName = "Contains"
   113  					elem = cond.X
   114  					arg2 = cond.Y // "if elem == needle"
   115  				} else if isSliceElem(cond.Y) {
   116  					funcName = "Contains"
   117  					elem = cond.Y
   118  					arg2 = cond.X // "if needle == elem"
   119  				}
   120  
   121  				// Reject if elem and needle have different types.
   122  				if elem != nil {
   123  					tElem := info.TypeOf(elem)
   124  					tNeedle := info.TypeOf(arg2)
   125  					if !types.Identical(tElem, tNeedle) {
   126  						// Avoid ill-typed slices.Contains([]error, any).
   127  						if !types.AssignableTo(tNeedle, tElem) {
   128  							return
   129  						}
   130  						// TODO(adonovan): relax this check to allow
   131  						//   slices.Contains([]error, error(any)),
   132  						// inserting an explicit widening conversion
   133  						// around the needle.
   134  						return
   135  					}
   136  				}
   137  			}
   138  
   139  		case *ast.CallExpr:
   140  			if len(cond.Args) == 1 &&
   141  				isSliceElem(cond.Args[0]) &&
   142  				typeutil.Callee(info, cond) != nil { // not a conversion
   143  
   144  				// Attempt to get signature
   145  				sig, isSignature := info.TypeOf(cond.Fun).(*types.Signature)
   146  				if isSignature {
   147  					// skip variadic functions
   148  					if sig.Variadic() {
   149  						return
   150  					}
   151  
   152  					// Slice element type must match function parameter type.
   153  					var (
   154  						tElem  = typeparams.CoreType(info.TypeOf(rng.X)).(*types.Slice).Elem()
   155  						tParam = sig.Params().At(0).Type()
   156  					)
   157  					if !types.Identical(tElem, tParam) {
   158  						return
   159  					}
   160  				}
   161  
   162  				funcName = "ContainsFunc"
   163  				arg2 = cond.Fun // "if predicate(elem)"
   164  			}
   165  		}
   166  		if funcName == "" {
   167  			return // not a candidate for Contains{,Func}
   168  		}
   169  
   170  		// body is the "true" body.
   171  		body := ifStmt.Body
   172  		if len(body.List) == 0 {
   173  			// (We could perhaps delete the loop entirely.)
   174  			return
   175  		}
   176  
   177  		// Reject if the body, needle or predicate references either range variable.
   178  		usesRangeVar := func(n ast.Node) bool {
   179  			cur, ok := curRange.FindNode(n)
   180  			if !ok {
   181  				panic(fmt.Sprintf("FindNode(%T) failed", n))
   182  			}
   183  			return uses(index, cur, info.Defs[rng.Key.(*ast.Ident)]) ||
   184  				rng.Value != nil && uses(index, cur, info.Defs[rng.Value.(*ast.Ident)])
   185  		}
   186  		if usesRangeVar(body) {
   187  			// Body uses range var "i" or "elem".
   188  			//
   189  			// (The check for "i" could be relaxed when we
   190  			// generalize this to support slices.Index;
   191  			// and the check for "elem" could be relaxed
   192  			// if "elem" can safely be replaced in the
   193  			// body by "needle".)
   194  			return
   195  		}
   196  		if usesRangeVar(arg2) {
   197  			return
   198  		}
   199  
   200  		// Prepare slices.Contains{,Func} call.
   201  		prefix, importEdits := refactor.AddImport(info, file, "slices", "slices", funcName, rng.Pos())
   202  		contains := fmt.Sprintf("%s%s(%s, %s)",
   203  			prefix,
   204  			funcName,
   205  			astutil.Format(pass.Fset, rng.X),
   206  			astutil.Format(pass.Fset, arg2))
   207  
   208  		report := func(edits []analysis.TextEdit) {
   209  			pass.Report(analysis.Diagnostic{
   210  				Pos:     rng.Pos(),
   211  				End:     rng.End(),
   212  				Message: fmt.Sprintf("Loop can be simplified using slices.%s", funcName),
   213  				SuggestedFixes: []analysis.SuggestedFix{{
   214  					Message:   "Replace loop by call to slices." + funcName,
   215  					TextEdits: append(edits, importEdits...),
   216  				}},
   217  			})
   218  		}
   219  
   220  		// Last statement of body must return/break out of the loop.
   221  		//
   222  		// TODO(adonovan): opt:consider avoiding FindNode with new API of form:
   223  		//    curRange.Get(edge.RangeStmt_Body, -1).
   224  		//             Get(edge.BodyStmt_List, 0).
   225  		//             Get(edge.IfStmt_Body)
   226  		curBody, _ := curRange.FindNode(body)
   227  		curLastStmt, _ := curBody.LastChild()
   228  
   229  		// Reject if any statement in the body except the
   230  		// last has a free continuation (continue or break)
   231  		// that might affected by melting down the loop.
   232  		//
   233  		// TODO(adonovan): relax check by analyzing branch target.
   234  		for curBodyStmt := range curBody.Children() {
   235  			if curBodyStmt != curLastStmt {
   236  				for range curBodyStmt.Preorder((*ast.BranchStmt)(nil), (*ast.ReturnStmt)(nil)) {
   237  					return
   238  				}
   239  			}
   240  		}
   241  
   242  		switch lastStmt := curLastStmt.Node().(type) {
   243  		case *ast.ReturnStmt:
   244  			// Have: for ... range seq { if ... { stmts; return x } }
   245  
   246  			// Special case:
   247  			// body={ return true } next="return false"   (or negation)
   248  			// => return [!]slices.Contains(...)
   249  			if curNext, ok := curRange.NextSibling(); ok {
   250  				nextStmt := curNext.Node().(ast.Stmt)
   251  				tval := isReturnTrueOrFalse(info, lastStmt)
   252  				fval := isReturnTrueOrFalse(info, nextStmt)
   253  				if len(body.List) == 1 && tval*fval < 0 {
   254  					//    for ... { if ... { return true/false } }
   255  					// => return [!]slices.Contains(...)
   256  					report([]analysis.TextEdit{
   257  						// Delete the range statement and following space.
   258  						{
   259  							Pos: rng.Pos(),
   260  							End: nextStmt.Pos(),
   261  						},
   262  						// Change return to [!]slices.Contains(...).
   263  						{
   264  							Pos: nextStmt.Pos(),
   265  							End: nextStmt.End(),
   266  							NewText: fmt.Appendf(nil, "return %s%s",
   267  								cond(tval > 0, "", "!"),
   268  								contains),
   269  						},
   270  					})
   271  					return
   272  				}
   273  			}
   274  
   275  			// General case:
   276  			// => if slices.Contains(...) { stmts; return x }
   277  			report([]analysis.TextEdit{
   278  				// Replace "for ... { if ... " with "if slices.Contains(...)".
   279  				{
   280  					Pos:     rng.Pos(),
   281  					End:     ifStmt.Body.Pos(),
   282  					NewText: fmt.Appendf(nil, "if %s ", contains),
   283  				},
   284  				// Delete '}' of range statement and preceding space.
   285  				{
   286  					Pos: ifStmt.Body.End(),
   287  					End: rng.End(),
   288  				},
   289  			})
   290  			return
   291  
   292  		case *ast.BranchStmt:
   293  			if lastStmt.Tok == token.BREAK && lastStmt.Label == nil { // unlabeled break
   294  				// Have: for ... { if ... { stmts; break } }
   295  
   296  				var prevStmt ast.Stmt // previous statement to range (if any)
   297  				if curPrev, ok := curRange.PrevSibling(); ok {
   298  					// If the RangeStmt's previous sibling is a Stmt,
   299  					// the RangeStmt must be among the Body list of
   300  					// a BlockStmt, CauseClause, or CommClause.
   301  					// In all cases, the prevStmt is the immediate
   302  					// predecessor of the RangeStmt during execution.
   303  					//
   304  					// (This is not true for Stmts in general;
   305  					// see [Cursor.Children] and #71074.)
   306  					prevStmt, _ = curPrev.Node().(ast.Stmt)
   307  				}
   308  
   309  				// Special case:
   310  				// prev="lhs = false" body={ lhs = true; break }
   311  				// => lhs = slices.Contains(...) (or its negation)
   312  				if assign, ok := body.List[0].(*ast.AssignStmt); ok &&
   313  					len(body.List) == 2 &&
   314  					assign.Tok == token.ASSIGN &&
   315  					len(assign.Lhs) == 1 &&
   316  					len(assign.Rhs) == 1 {
   317  
   318  					// Have: body={ lhs = rhs; break }
   319  					if prevAssign, ok := prevStmt.(*ast.AssignStmt); ok &&
   320  						len(prevAssign.Lhs) == 1 &&
   321  						len(prevAssign.Rhs) == 1 &&
   322  						astutil.EqualSyntax(prevAssign.Lhs[0], assign.Lhs[0]) &&
   323  						isTrueOrFalse(info, assign.Rhs[0]) ==
   324  							-isTrueOrFalse(info, prevAssign.Rhs[0]) {
   325  
   326  						// Have:
   327  						//    lhs = false
   328  						//    for ... { if ... { lhs = true; break } }
   329  						//  =>
   330  						//    lhs = slices.Contains(...)
   331  						//
   332  						// TODO(adonovan):
   333  						// - support "var lhs bool = false" and variants.
   334  						// - allow the break to be omitted.
   335  						neg := cond(isTrueOrFalse(info, assign.Rhs[0]) < 0, "!", "")
   336  						report([]analysis.TextEdit{
   337  							// Replace "rhs" of previous assignment by [!]slices.Contains(...)
   338  							{
   339  								Pos:     prevAssign.Rhs[0].Pos(),
   340  								End:     prevAssign.Rhs[0].End(),
   341  								NewText: []byte(neg + contains),
   342  							},
   343  							// Delete the loop and preceding space.
   344  							{
   345  								Pos: prevAssign.Rhs[0].End(),
   346  								End: rng.End(),
   347  							},
   348  						})
   349  						return
   350  					}
   351  				}
   352  
   353  				// General case:
   354  				//    for ... { if ...        { stmts; break } }
   355  				// => if slices.Contains(...) { stmts        }
   356  				report([]analysis.TextEdit{
   357  					// Replace "for ... { if ... " with "if slices.Contains(...)".
   358  					{
   359  						Pos:     rng.Pos(),
   360  						End:     ifStmt.Body.Pos(),
   361  						NewText: fmt.Appendf(nil, "if %s ", contains),
   362  					},
   363  					// Delete break statement and preceding space.
   364  					{
   365  						Pos: func() token.Pos {
   366  							if len(body.List) > 1 {
   367  								beforeBreak, _ := curLastStmt.PrevSibling()
   368  								return beforeBreak.Node().End()
   369  							}
   370  							return lastStmt.Pos()
   371  						}(),
   372  						End: lastStmt.End(),
   373  					},
   374  					// Delete '}' of range statement and preceding space.
   375  					{
   376  						Pos: ifStmt.Body.End(),
   377  						End: rng.End(),
   378  					},
   379  				})
   380  				return
   381  			}
   382  		}
   383  	}
   384  
   385  	for curFile := range filesUsingGoVersion(pass, versions.Go1_21) {
   386  		file := curFile.Node().(*ast.File)
   387  
   388  		for curRange := range curFile.Preorder((*ast.RangeStmt)(nil)) {
   389  			rng := curRange.Node().(*ast.RangeStmt)
   390  
   391  			if is[*ast.Ident](rng.Key) &&
   392  				rng.Tok == token.DEFINE &&
   393  				len(rng.Body.List) == 1 &&
   394  				is[*types.Slice](typeparams.CoreType(info.TypeOf(rng.X))) {
   395  
   396  				// Have:
   397  				// - for _, elem := range s { S }
   398  				// - for i       := range s { S }
   399  
   400  				if ifStmt, ok := rng.Body.List[0].(*ast.IfStmt); ok &&
   401  					ifStmt.Init == nil && ifStmt.Else == nil {
   402  
   403  					// Have: for i, elem := range s { if cond { ... } }
   404  					check(file, curRange)
   405  				}
   406  			}
   407  		}
   408  	}
   409  	return nil, nil
   410  }
   411  
   412  // -- helpers --
   413  
   414  // isReturnTrueOrFalse returns nonzero if stmt returns true (+1) or false (-1).
   415  func isReturnTrueOrFalse(info *types.Info, stmt ast.Stmt) int {
   416  	if ret, ok := stmt.(*ast.ReturnStmt); ok && len(ret.Results) == 1 {
   417  		return isTrueOrFalse(info, ret.Results[0])
   418  	}
   419  	return 0
   420  }
   421  
   422  // isTrueOrFalse returns nonzero if expr is literally true (+1) or false (-1).
   423  func isTrueOrFalse(info *types.Info, expr ast.Expr) int {
   424  	if id, ok := expr.(*ast.Ident); ok {
   425  		switch info.Uses[id] {
   426  		case builtinTrue:
   427  			return +1
   428  		case builtinFalse:
   429  			return -1
   430  		}
   431  	}
   432  	return 0
   433  }
   434  

View as plain text