1
2
3
4
5 package modernize
6
7 import (
8 "bytes"
9 "fmt"
10 "go/ast"
11 "go/printer"
12 "slices"
13
14 "golang.org/x/tools/go/analysis"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/types/typeutil"
17 "golang.org/x/tools/internal/analysis/analyzerutil"
18 typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
19 "golang.org/x/tools/internal/astutil"
20 "golang.org/x/tools/internal/refactor"
21 "golang.org/x/tools/internal/typesinternal/typeindex"
22 "golang.org/x/tools/internal/versions"
23 )
24
25 var WaitGroupAnalyzer = &analysis.Analyzer{
26 Name: "waitgroup",
27 Doc: analyzerutil.MustExtractDoc(doc, "waitgroup"),
28 Requires: []*analysis.Analyzer{
29 inspect.Analyzer,
30 typeindexanalyzer.Analyzer,
31 },
32 Run: waitgroup,
33 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#waitgroup",
34 }
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62 func waitgroup(pass *analysis.Pass) (any, error) {
63 var (
64 index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
65 info = pass.TypesInfo
66 syncWaitGroupAdd = index.Selection("sync", "WaitGroup", "Add")
67 syncWaitGroupDone = index.Selection("sync", "WaitGroup", "Done")
68 )
69 if !index.Used(syncWaitGroupDone) {
70 return nil, nil
71 }
72
73 for curAddCall := range index.Calls(syncWaitGroupAdd) {
74
75 addCall := curAddCall.Node().(*ast.CallExpr)
76 if !isIntLiteral(info, addCall.Args[0], 1) {
77 continue
78 }
79
80
81 addCallRecv := ast.Unparen(addCall.Fun).(*ast.SelectorExpr).X
82
83
84 curAddStmt := curAddCall.Parent()
85 if !is[*ast.ExprStmt](curAddStmt.Node()) {
86 continue
87 }
88 curNext, ok := curAddCall.Parent().NextSibling()
89 if !ok {
90 continue
91 }
92 goStmt, ok := curNext.Node().(*ast.GoStmt)
93 if !ok {
94 continue
95 }
96 lit, ok := goStmt.Call.Fun.(*ast.FuncLit)
97 if !ok || len(goStmt.Call.Args) != 0 {
98 continue
99 }
100 list := lit.Body.List
101 if len(list) == 0 {
102 continue
103 }
104
105
106 var doneStmt ast.Stmt
107 if deferStmt, ok := list[0].(*ast.DeferStmt); ok &&
108 typeutil.Callee(info, deferStmt.Call) == syncWaitGroupDone &&
109 astutil.EqualSyntax(ast.Unparen(deferStmt.Call.Fun).(*ast.SelectorExpr).X, addCallRecv) {
110 doneStmt = deferStmt
111
112 } else if lastStmt, ok := list[len(list)-1].(*ast.ExprStmt); ok {
113 if doneCall, ok := lastStmt.X.(*ast.CallExpr); ok &&
114 typeutil.Callee(info, doneCall) == syncWaitGroupDone &&
115 astutil.EqualSyntax(ast.Unparen(doneCall.Fun).(*ast.SelectorExpr).X, addCallRecv) {
116 doneStmt = lastStmt
117 }
118 }
119 if doneStmt == nil {
120 continue
121 }
122 curDoneStmt, ok := curNext.FindNode(doneStmt)
123 if !ok {
124 panic("can't find Cursor for 'done' statement")
125 }
126
127 file := astutil.EnclosingFile(curAddCall)
128 if !analyzerutil.FileUsesGoVersion(pass, file, versions.Go1_25) {
129 continue
130 }
131 tokFile := pass.Fset.File(file.Pos())
132
133 var addCallRecvText bytes.Buffer
134 err := printer.Fprint(&addCallRecvText, pass.Fset, addCallRecv)
135 if err != nil {
136 continue
137 }
138
139 pass.Report(analysis.Diagnostic{
140
141
142 Pos: goStmt.Pos(),
143 End: lit.Type.End(),
144 Message: "Goroutine creation can be simplified using WaitGroup.Go",
145 SuggestedFixes: []analysis.SuggestedFix{{
146 Message: "Simplify by using WaitGroup.Go",
147 TextEdits: slices.Concat(
148
149 refactor.DeleteStmt(tokFile, curAddStmt),
150
151 refactor.DeleteStmt(tokFile, curDoneStmt),
152 []analysis.TextEdit{
153
154
155
156 {
157 Pos: goStmt.Pos(),
158 End: goStmt.Call.Pos(),
159 NewText: fmt.Appendf(nil, "%s.Go(", addCallRecvText.String()),
160 },
161
162
163
164 {
165 Pos: goStmt.Call.Lparen,
166 End: goStmt.Call.Rparen,
167 },
168 },
169 ),
170 }},
171 })
172 }
173 return nil, nil
174 }
175
View as plain text