Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/modernize/bloop.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package modernize
     6  
import (
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"strings"
	"unicode"
	"unicode/utf8"

	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/analysis/passes/inspect"
	"golang.org/x/tools/go/ast/inspector"
	"golang.org/x/tools/go/types/typeutil"
	"golang.org/x/tools/internal/analysis/analyzerutil"
	typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
	"golang.org/x/tools/internal/astutil"
	"golang.org/x/tools/internal/moreiters"
	"golang.org/x/tools/internal/typesinternal"
	"golang.org/x/tools/internal/typesinternal/typeindex"
	"golang.org/x/tools/internal/versions"
)
    26  
    27  var BLoopAnalyzer = &analysis.Analyzer{
    28  	Name: "bloop",
    29  	Doc:  analyzerutil.MustExtractDoc(doc, "bloop"),
    30  	Requires: []*analysis.Analyzer{
    31  		inspect.Analyzer,
    32  		typeindexanalyzer.Analyzer,
    33  	},
    34  	Run: bloop,
    35  	URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#bloop",
    36  }
    37  
// bloop updates benchmarks that use "for range b.N", replacing it
// with go1.24's b.Loop() and eliminating any preceding
// b.{Start,Stop,Reset}Timer calls.
//
// Variants:
//
//	for i := 0; i < b.N; i++ {}  =>   for b.Loop() {}
//	for range b.N {}
//
// A loop is reported only when usesBenchmarkNOnce accepts it, and only
// files yielded by filesUsingGoVersion(pass, versions.Go1_24) are
// examined (presumably those whose Go version is at least 1.24, where
// b.Loop exists). The pass result and error are always nil.
func bloop(pass *analysis.Pass) (any, error) {
	// A package that does not import "testing" cannot mention *testing.B.
	if !typesinternal.Imports(pass.Pkg, "testing") {
		return nil, nil
	}

	var (
		index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
		info  = pass.TypesInfo
	)

	// edits computes the text edits for a matched for/range loop
	// at the specified cursor. b is the *testing.B value, and
	// (start, end) is the portion using b.N to delete; that portion
	// is replaced by "<b>.Loop()". The returned slice lists the
	// timer-call deletions first (in traversal order), followed by
	// that replacement.
	edits := func(curLoop inspector.Cursor, b ast.Expr, start, end token.Pos) (edits []analysis.TextEdit) {
		// The ok result is ignored: edits is only called after
		// usesBenchmarkNOnce has found an enclosing FuncDecl.
		curFn, _ := enclosingFunc(curLoop)
		// Within the same function, delete all calls to
		// b.{Start,Stop,Reset}Timer that precede the loop.
		filter := []ast.Node{(*ast.ExprStmt)(nil), (*ast.FuncLit)(nil)}
		curFn.Inspect(filter, func(cur inspector.Cursor) (descend bool) {
			node := cur.Node()
			if is[*ast.FuncLit](node) {
				return false // don't descend into FuncLits (e.g. sub-benchmarks)
			}
			stmt := node.(*ast.ExprStmt)
			// Not preceding: skip it. Returning false only prunes this
			// subtree; later statements are still visited, but they
			// fail this same position check.
			if stmt.Pos() > start {
				return false
			}
			if call, ok := stmt.X.(*ast.CallExpr); ok {
				obj := typeutil.Callee(info, call)
				if typesinternal.IsMethodNamed(obj, "testing", "B", "StopTimer", "StartTimer", "ResetTimer") {
					// Delete call statement.
					// TODO(adonovan): delete following newline, or
					// up to start of next stmt? (May delete a comment.)
					edits = append(edits, analysis.TextEdit{
						Pos: stmt.Pos(),
						End: stmt.End(),
					})
				}
			}
			return true
		})

		// Replace ...b.N... with b.Loop().
		return append(edits, analysis.TextEdit{
			Pos:     start,
			End:     end,
			NewText: fmt.Appendf(nil, "%s.Loop()", astutil.Format(pass.Fset, b)),
		})
	}

	// Find all for/range statements.
	loops := []ast.Node{
		(*ast.ForStmt)(nil),
		(*ast.RangeStmt)(nil),
	}
	for curFile := range filesUsingGoVersion(pass, versions.Go1_24) {
		for curLoop := range curFile.Preorder(loops...) {
			switch n := curLoop.Node().(type) {
			case *ast.ForStmt:
				// for _; i < b.N; _ {}
				//
				// The condition must be "x < b.N" where b is a *testing.B.
				if cmp, ok := n.Cond.(*ast.BinaryExpr); ok && cmp.Op == token.LSS {
					if sel, ok := cmp.Y.(*ast.SelectorExpr); ok &&
						sel.Sel.Name == "N" &&
						typesinternal.IsPointerToNamed(info.TypeOf(sel.X), "testing", "B") && usesBenchmarkNOnce(curLoop, info) {

						// By default only the condition is replaced,
						// giving e.g. "for i := 0; b.Loop(); i++".
						delStart, delEnd := n.Cond.Pos(), n.Cond.End()

						// Eliminate variable i if no longer needed:
						//  for i := 0; i < b.N; i++ {
						//    ...no references to i...
						//  }
						// The loop body is the ForStmt's last child.
						body, _ := curLoop.LastChild()
						if v := isIncrementLoop(info, n); v != nil &&
							!uses(index, body, v) {
							delStart, delEnd = n.Init.Pos(), n.Post.End()
						}

						pass.Report(analysis.Diagnostic{
							// Highlight "i < b.N".
							Pos:     n.Cond.Pos(),
							End:     n.Cond.End(),
							Message: "b.N can be modernized using b.Loop()",
							SuggestedFixes: []analysis.SuggestedFix{{
								Message:   "Replace b.N with b.Loop()",
								TextEdits: edits(curLoop, sel.X, delStart, delEnd),
							}},
						})
					}
				}

			case *ast.RangeStmt:
				// for range b.N {} -> for b.Loop() {}
				//
				// TODO(adonovan): handle "for i := range b.N".
				if sel, ok := n.X.(*ast.SelectorExpr); ok &&
					n.Key == nil &&
					n.Value == nil &&
					sel.Sel.Name == "N" &&
					typesinternal.IsPointerToNamed(info.TypeOf(sel.X), "testing", "B") && usesBenchmarkNOnce(curLoop, info) {

					pass.Report(analysis.Diagnostic{
						// Highlight "range b.N".
						Pos:     n.Range,
						End:     n.X.End(),
						Message: "b.N can be modernized using b.Loop()",
						SuggestedFixes: []analysis.SuggestedFix{{
							Message:   "Replace b.N with b.Loop()",
							TextEdits: edits(curLoop, sel.X, n.Range, n.X.End()),
						}},
					})
				}
			}
		}
	}
	return nil, nil
}
   162  
   163  // uses reports whether the subtree cur contains a use of obj.
   164  func uses(index *typeindex.Index, cur inspector.Cursor, obj types.Object) bool {
   165  	for use := range index.Uses(obj) {
   166  		if cur.Contains(use) {
   167  			return true
   168  		}
   169  	}
   170  	return false
   171  }
   172  
   173  // enclosingFunc returns the cursor for the innermost Func{Decl,Lit}
   174  // that encloses c, if any.
   175  func enclosingFunc(c inspector.Cursor) (inspector.Cursor, bool) {
   176  	return moreiters.First(c.Enclosing((*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)))
   177  }
   178  
   179  // usesBenchmarkNOnce reports whether a b.N loop should be modernized to b.Loop().
   180  // Only modernize loops that are:
   181  // 1. Directly in a benchmark function (not in nested functions)
   182  //   - b.Loop() must be called in the same goroutine as the benchmark function
   183  //   - Function literals are often used with goroutines (go func(){...})
   184  //
   185  // 2. The only b.N loop in that benchmark function
   186  //   - b.Loop() can only be called once per benchmark execution
   187  //   - Multiple calls result in "B.Loop called with timer stopped" error
   188  //   - Multiple loops may have complex interdependencies that are hard to analyze
   189  func usesBenchmarkNOnce(c inspector.Cursor, info *types.Info) bool {
   190  	// Find the enclosing benchmark function
   191  	curFunc, ok := enclosingFunc(c)
   192  	if !ok {
   193  		return false
   194  	}
   195  
   196  	// Check if this is actually a benchmark function
   197  	fdecl, ok := curFunc.Node().(*ast.FuncDecl)
   198  	if !ok {
   199  		return false // not in a function; or, inside a FuncLit
   200  	}
   201  	if !isBenchmarkFunc(fdecl) {
   202  		return false
   203  	}
   204  
   205  	// Count all b.N references in this benchmark function (including nested functions)
   206  	bnRefCount := 0
   207  	filter := []ast.Node{(*ast.SelectorExpr)(nil)}
   208  	curFunc.Inspect(filter, func(cur inspector.Cursor) bool {
   209  		sel := cur.Node().(*ast.SelectorExpr)
   210  		if sel.Sel.Name == "N" &&
   211  			typesinternal.IsPointerToNamed(info.TypeOf(sel.X), "testing", "B") {
   212  			bnRefCount++
   213  		}
   214  		return true
   215  	})
   216  
   217  	// Only modernize if there's exactly one b.N reference
   218  	return bnRefCount == 1
   219  }
   220  
   221  // isBenchmarkFunc reports whether f is a benchmark function.
   222  func isBenchmarkFunc(f *ast.FuncDecl) bool {
   223  	return f.Recv == nil &&
   224  		f.Name != nil &&
   225  		f.Name.IsExported() &&
   226  		strings.HasPrefix(f.Name.Name, "Benchmark") &&
   227  		f.Type.Params != nil &&
   228  		len(f.Type.Params.List) == 1
   229  }
   230  
   231  // isIncrementLoop reports whether loop has the form "for i := 0; ...; i++ { ... }",
   232  // and if so, it returns the symbol for the index variable.
   233  func isIncrementLoop(info *types.Info, loop *ast.ForStmt) *types.Var {
   234  	if assign, ok := loop.Init.(*ast.AssignStmt); ok &&
   235  		assign.Tok == token.DEFINE &&
   236  		len(assign.Rhs) == 1 &&
   237  		isZeroIntConst(info, assign.Rhs[0]) &&
   238  		is[*ast.IncDecStmt](loop.Post) &&
   239  		loop.Post.(*ast.IncDecStmt).Tok == token.INC &&
   240  		astutil.EqualSyntax(loop.Post.(*ast.IncDecStmt).X, assign.Lhs[0]) {
   241  		return info.Defs[assign.Lhs[0].(*ast.Ident)].(*types.Var)
   242  	}
   243  	return nil
   244  }
   245  

View as plain text