1
2
3
4
5
6
7 package analysisinternal
8
9 import (
10 "bytes"
11 "fmt"
12 "go/ast"
13 "go/scanner"
14 "go/token"
15 "go/types"
16 "os"
17 pathpkg "path"
18
19 "golang.org/x/tools/go/analysis"
20 )
21
// TypeErrorEndPos attempts to find the end position of a type error that
// starts at start. src must hold the complete content of the file in fset
// that contains start.
//
// If start does not belong to any file in fset, or its offset lies beyond the
// end of src, start is returned unchanged.
//
// The heuristic is imprecise. It first extends the range over the first token
// found at or after start: leading whitespace is skipped, a comment counts as
// a token, and a ';' token is not consumed. It then extends further, up to but
// not including the next byte that might terminate an operand (one of
// " \n,():;[]+-*/").
func TypeErrorEndPos(fset *token.FileSet, src []byte, start token.Pos) token.Pos {
	file := fset.File(start)
	if file == nil {
		return start
	}
	offset := file.PositionFor(start, false).Offset
	if offset > len(src) {
		return start
	}
	src = src[offset:]

	end := start
	{
		// Scan the suffix of src beginning at start in a private FileSet.
		// Positions returned by the scanner are therefore relative to f,
		// not to file.
		var s scanner.Scanner
		scanFset := token.NewFileSet()
		f := scanFset.AddFile("", scanFset.Base(), len(src))
		s.Init(f, src, nil /* no error handler */, scanner.ScanComments)
		pos, tok, lit := s.Scan()
		if tok != token.SEMICOLON && token.Pos(f.Base()) <= pos && pos <= token.Pos(f.Base()+f.Size()) {
			// pos belongs to f, so convert it with f.Offset. The outer file's
			// Offset only agrees when both files happen to share a base.
			off := f.Offset(pos) + len(lit)
			// lit need not match the source byte-for-byte. For an invalid
			// UTF-8 byte the scanner reports ILLEGAL with lit "\uFFFD"
			// (3 bytes), which could otherwise run past the end of src.
			if off > len(src) {
				off = len(src)
			}
			src = src[off:]
			end += token.Pos(off)
		}
	}

	// Extend over the bytes preceding the next one that might terminate the
	// current operand. This is a rough heuristic, not a parse.
	if width := bytes.IndexAny(src, " \n,():;[]+-*/"); width > 0 {
		end += token.Pos(width)
	}
	return end
}
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88 func StmtToInsertVarBefore(path []ast.Node) ast.Stmt {
89 enclosingIndex := -1
90 for i, p := range path {
91 if _, ok := p.(ast.Stmt); ok {
92 enclosingIndex = i
93 break
94 }
95 }
96 if enclosingIndex == -1 {
97 return nil
98 }
99 enclosingStmt := path[enclosingIndex]
100 switch enclosingStmt.(type) {
101 case *ast.IfStmt:
102
103
104
105
106
107
108 return baseIfStmt(path, enclosingIndex)
109 case *ast.CaseClause:
110
111
112 for i := enclosingIndex + 1; i < len(path); i++ {
113 if node, ok := path[i].(*ast.SwitchStmt); ok {
114 return node
115 } else if node, ok := path[i].(*ast.TypeSwitchStmt); ok {
116 return node
117 }
118 }
119 }
120 if len(path) <= enclosingIndex+1 {
121 return enclosingStmt.(ast.Stmt)
122 }
123
124 switch expr := path[enclosingIndex+1].(type) {
125 case *ast.IfStmt:
126
127 return baseIfStmt(path, enclosingIndex+1)
128 case *ast.ForStmt:
129 if expr.Init == enclosingStmt || expr.Post == enclosingStmt {
130 return expr
131 }
132 case *ast.SwitchStmt, *ast.TypeSwitchStmt:
133 return expr.(ast.Stmt)
134 }
135 return enclosingStmt.(ast.Stmt)
136 }
137
138
139
// baseIfStmt returns the head of the if/else-if chain that contains
// path[index]. Starting from path[index], it moves outward for as long as the
// next element of path is an *ast.IfStmt whose Else branch is the current
// node. path[index] must be an ast.Stmt; the final type assertion panics
// otherwise.
func baseIfStmt(path []ast.Node, index int) ast.Stmt {
	cur := path[index]
	for _, outer := range path[index+1:] {
		ifStmt, isIf := outer.(*ast.IfStmt)
		if !isIf || ifStmt.Else != cur {
			break
		}
		cur = ifStmt
	}
	return cur.(ast.Stmt)
}
151
152
153
// WalkASTWithParent traverses the AST rooted at n in depth-first order. It
// calls f for each node together with that node's parent; the parent passed
// for n itself is nil. As with ast.Inspect, if f returns false the node's
// children are not visited. Unlike ast.Inspect, f is never called with a nil
// node, and a nil n visits nothing.
func WalkASTWithParent(n ast.Node, f func(n ast.Node, parent ast.Node) bool) {
	if n == nil {
		return
	}
	// ancestors holds the nodes whose children are currently being visited.
	// Its last element is the parent of the next node visited.
	var ancestors []ast.Node
	ast.Inspect(n, func(node ast.Node) bool {
		if node == nil {
			// Post-order call: the children of the top node are done.
			ancestors = ancestors[:len(ancestors)-1]
			return false
		}

		var parent ast.Node
		if len(ancestors) > 0 {
			parent = ancestors[len(ancestors)-1]
		}
		if !f(node, parent) {
			// ast.Inspect makes no post-order (nil) call for a node whose
			// visit returned false. Pushing it would leave it on the stack
			// forever, and it would be reported as the parent of its
			// following siblings.
			return false
		}
		ancestors = append(ancestors, node)
		return true
	})
}
170
171
172
173
174
175 func MatchingIdents(typs []types.Type, node ast.Node, pos token.Pos, info *types.Info, pkg *types.Package) map[types.Type][]string {
176
177
178 matches := make(map[types.Type][]string)
179 for _, typ := range typs {
180 if typ == nil {
181 continue
182 }
183 matches[typ] = nil
184 }
185
186 seen := map[types.Object]struct{}{}
187 ast.Inspect(node, func(n ast.Node) bool {
188 if n == nil {
189 return false
190 }
191
192
193
194
195
196
197 if assign, ok := n.(*ast.AssignStmt); ok && pos > assign.Pos() && pos <= assign.End() {
198 return false
199 }
200 if n.End() > pos {
201 return n.Pos() <= pos
202 }
203 ident, ok := n.(*ast.Ident)
204 if !ok || ident.Name == "_" {
205 return true
206 }
207 obj := info.Defs[ident]
208 if obj == nil || obj.Type() == nil {
209 return true
210 }
211 if _, ok := obj.(*types.TypeName); ok {
212 return true
213 }
214
215 if _, ok = seen[obj]; ok {
216 return true
217 }
218 seen[obj] = struct{}{}
219
220
221 innerScope := pkg.Scope().Innermost(pos)
222 if innerScope == nil {
223 return true
224 }
225 _, foundObj := innerScope.LookupParent(ident.Name, pos)
226 if foundObj != obj {
227 return true
228 }
229
230
231 if names, ok := matches[obj.Type()]; ok {
232 matches[obj.Type()] = append(names, ident.Name)
233 } else {
234
235
236
237 for typ := range matches {
238 if equivalentTypes(obj.Type(), typ) {
239 matches[typ] = append(matches[typ], ident.Name)
240 }
241 }
242 }
243 return true
244 })
245 return matches
246 }
247
// equivalentTypes reports whether type want is considered a match for the
// target type got.
//
// The types match in any of these cases:
//   - they are identical;
//   - want is an untyped basic type and got's underlying type is a basic type
//     with the same boolean/numeric/string classification bits
//     (types.IsConstType). In this case the result is decided by that
//     comparison alone;
//   - otherwise, want is assignable to got.
func equivalentTypes(want, got types.Type) bool {
	if types.Identical(want, got) {
		return true
	}
	wantBasic, isBasic := want.(*types.Basic)
	if isBasic && wantBasic.Info()&types.IsUntyped != 0 {
		gotBasic, ok := got.Underlying().(*types.Basic)
		if ok {
			const kindBits = types.IsConstType
			return wantBasic.Info()&kindBits == gotBasic.Info()&kindBits
		}
	}
	return types.AssignableTo(want, got)
}
260
261
262 func MakeReadFile(pass *analysis.Pass) func(filename string) ([]byte, error) {
263 return func(filename string) ([]byte, error) {
264 if err := CheckReadable(pass, filename); err != nil {
265 return nil, err
266 }
267 return os.ReadFile(filename)
268 }
269 }
270
271
272 func CheckReadable(pass *analysis.Pass, filename string) error {
273 if slicesContains(pass.OtherFiles, filename) ||
274 slicesContains(pass.IgnoredFiles, filename) {
275 return nil
276 }
277 for _, f := range pass.Files {
278 if pass.Fset.File(f.FileStart).Name() == filename {
279 return nil
280 }
281 }
282 return fmt.Errorf("Pass.ReadFile: %s is not among OtherFiles, IgnoredFiles, or names of Files", filename)
283 }
284
285
// slicesContains reports whether x occurs in slice, comparing elements with
// ==. It mirrors the standard library's slices.Contains.
func slicesContains[S ~[]E, E comparable](slice S, x E) bool {
	for i := range slice {
		if slice[i] == x {
			return true
		}
	}
	return false
}
294
295
296
297
298
299
300
301
302 func AddImport(info *types.Info, file *ast.File, pos token.Pos, pkgpath, preferredName string) (name string, newImport []analysis.TextEdit) {
303
304 scope := info.Scopes[file].Innermost(pos)
305 if scope == nil {
306 panic("no enclosing lexical block")
307 }
308
309
310
311 for _, spec := range file.Imports {
312 pkgname, ok := importedPkgName(info, spec)
313 if ok && pkgname.Imported().Path() == pkgpath {
314 if _, obj := scope.LookupParent(pkgname.Name(), pos); obj == pkgname {
315 return pkgname.Name(), nil
316 }
317 }
318 }
319
320
321
322 newName := preferredName
323 for i := 0; ; i++ {
324 if _, obj := scope.LookupParent(newName, pos); obj == nil {
325 break
326 }
327 newName = fmt.Sprintf("%s%d", preferredName, i)
328 }
329
330
331
332
333
334
335
336
337 newText := fmt.Sprintf("import %q\n\n", pkgpath)
338 if newName != preferredName || newName != pathpkg.Base(pkgpath) {
339 newText = fmt.Sprintf("import %s %q\n\n", newName, pkgpath)
340 }
341 decl0 := file.Decls[0]
342 var before ast.Node = decl0
343 switch decl0 := decl0.(type) {
344 case *ast.GenDecl:
345 if decl0.Doc != nil {
346 before = decl0.Doc
347 }
348 case *ast.FuncDecl:
349 if decl0.Doc != nil {
350 before = decl0.Doc
351 }
352 }
353 return newName, []analysis.TextEdit{{
354 Pos: before.Pos(),
355 End: before.Pos(),
356 NewText: []byte(newText),
357 }}
358 }
359
360
361
// importedPkgName returns the *types.PkgName recorded for an import spec, and
// whether one was found.
//
// A spec with an explicit name is looked up in info.Defs via its name
// identifier. A spec without one is looked up in info.Implicits via the spec
// itself.
func importedPkgName(info *types.Info, imp *ast.ImportSpec) (*types.PkgName, bool) {
	if imp.Name == nil {
		pn, ok := info.Implicits[imp].(*types.PkgName)
		return pn, ok
	}
	pn, ok := info.Defs[imp.Name].(*types.PkgName)
	return pn, ok
}
372
View as plain text