1
2
3
4
5 package midway
6
import (
	"cmd/compile/internal/base"
	"cmd/compile/internal/syntax"
	"cmd/compile/internal/types2"
	"fmt"
	"internal/buildcfg"
	"strconv"
	"strings"
)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
162
163 type Rewriter struct {
164 pkg *types2.Package
165 analyzer *Analyzer
166 info *types2.Info
167 sizes []int
168 }
169
170 func NewRewriter(pkg *types2.Package, info *types2.Info, analyzer *Analyzer, sizes []int) *Rewriter {
171 return &Rewriter{
172 pkg: pkg,
173 info: info,
174 analyzer: analyzer,
175 sizes: sizes,
176 }
177 }
178
179 func (r *Rewriter) Rewrite(files []*syntax.File) {
180
181
182 for _, fileAST := range files {
183
184 var newDecls []syntax.Decl
185 for _, k := range r.sizes {
186 newDecls = r.generateForSize(fileAST, k, newDecls)
187 }
188
189
190
191 r.generateDispatchers(fileAST)
192
193 fileAST.DeclList = append(fileAST.DeclList, newDecls...)
194 }
195 }
196
197 func (r *Rewriter) generateDispatchers(fileAST *syntax.File) {
198 var newDecls []syntax.Decl
199
200 change := false
201
202 for _, decl := range fileAST.DeclList {
203 switch d := decl.(type) {
204 case *syntax.FuncDecl:
205 if d.Name == nil {
206 newDecls = append(newDecls, d)
207 continue
208 }
209 obj := r.info.Defs[d.Name]
210 if !r.analyzer.isDependentObj[obj] || r.analyzer.inSimd {
211 newDecls = append(newDecls, d)
212 continue
213 }
214
215 sig, ok := obj.Type().(*types2.Signature)
216 if !ok {
217 newDecls = append(newDecls, d)
218 continue
219 }
220
221 change = true
222 if r.analyzer.HasDependentSignature(sig) {
223 if base.Debug.Simd > 0 {
224 base.Warn("%s: removing body of dependent-sig original function %v", d.Pos().String(), d.Name.Value)
225 }
226 d.Body = r.blockOf(d.Pos(), r.panicStmt(d.Pos(),
227 "unexpected call of original function rewritten to specialized SIMD"))
228 newDecls = append(newDecls, d)
229 continue
230 }
231
232
233 d.Body = r.createDispatcherBody(d, sig)
234 newDecls = append(newDecls, d)
235
236 case *syntax.VarDecl:
237
238
239 newDecls = append(newDecls, d)
240
241 case *syntax.TypeDecl:
242
243
244 newDecls = append(newDecls, d)
245 default:
246 newDecls = append(newDecls, decl)
247 }
248 }
249
250 if !change {
251 return
252 }
253
254 fileAST.DeclList = newDecls
255
256 if !r.analyzer.inSimd {
257
258 hasArchSimd := false
259 var simdImport *syntax.ImportDecl
260 p := fileAST.Pos()
261 for _, decl := range fileAST.DeclList {
262 if imp, ok := decl.(*syntax.ImportDecl); ok {
263 if imp.Path.Value == `"`+archFullPkg+`"` {
264 hasArchSimd = true
265 if simdImport == nil {
266 p = imp.Pos()
267 }
268 }
269 if imp.Path.Value == `"`+simdPkg+`"` {
270 simdImport = imp
271 p = imp.Pos()
272 }
273 }
274 }
275
276 if !hasArchSimd {
277 r.injectImport(fileAST, archFullPkg, p)
278 }
279
280
281
282 fun := &syntax.SelectorExpr{
283 X: syntax.NewName(p, simdPkg),
284 Sel: syntax.NewName(p, vectorSizeFn),
285 }
286 fun.SetPos(p)
287 call := &syntax.CallExpr{Fun: fun}
288 call.SetPos(p)
289
290 name := syntax.NewName(p, "_")
291
292 varDecl := &syntax.VarDecl{NameList: []*syntax.Name{name}, Values: call}
293 varDecl.SetPos(p)
294 fileAST.DeclList = append(fileAST.DeclList, varDecl)
295 }
296 }
297
298 func (r *Rewriter) injectImport(fileAST *syntax.File, toImport string, simdImportPos syntax.Pos) {
299 importDecl := &syntax.ImportDecl{
300 Path: &syntax.BasicLit{Value: `"` + toImport + `"`, Kind: syntax.StringLit},
301 }
302 importDecl.Path.SetPos(simdImportPos)
303 importDecl.SetPos(simdImportPos)
304 fileAST.DeclList = append([]syntax.Decl{importDecl}, fileAST.DeclList...)
305 }
306
307 func (r *Rewriter) createDispatcherBody(d *syntax.FuncDecl, sig *types2.Signature) *syntax.BlockStmt {
308
309
310 args := func() []syntax.Expr {
311 var args []syntax.Expr
312 if d.Type.ParamList != nil {
313 for _, field := range d.Type.ParamList {
314 if field.Name != nil {
315 paramName := syntax.NewName(field.Pos(), field.Name.Value)
316 args = append(args, paramName)
317 }
318 }
319 }
320 return args
321 }
322
323
324 pe := func(e syntax.Expr) syntax.Expr {
325 e.SetPos(d.Pos())
326 return e
327 }
328
329 ps := func(e syntax.Stmt) syntax.Stmt {
330 e.SetPos(d.Pos())
331 return e
332 }
333
334
335
336
337
338
339
340
341
342
343
344 switchStmt := &syntax.SwitchStmt{
345 Tag: pe(&syntax.CallExpr{
346 Fun: pe(&syntax.SelectorExpr{
347 X: syntax.NewName(d.Pos(), simdPkg),
348 Sel: syntax.NewName(d.Pos(), vectorSizeFn),
349 }),
350 }),
351 Body: []*syntax.CaseClause{},
352 }
353
354 var emulation syntax.Stmt
355
356 for _, k := range r.sizes {
357 fnName := fmt.Sprintf("%s@simd%d", d.Name.Value, k)
358 fnIdent := syntax.NewName(d.Pos(), fnName)
359
360 callExpr := pe(&syntax.CallExpr{
361 Fun: pe(fnIdent),
362 ArgList: args(),
363 })
364
365
366 var callReturnStmt syntax.Stmt
367 if d.Type.ResultList != nil && len(d.Type.ResultList) > 0 {
368 callReturnStmt = &syntax.ReturnStmt{Results: callExpr}
369 } else {
370 callReturnStmt = &syntax.BlockStmt{
371 List: []syntax.Stmt{
372 ps(&syntax.ExprStmt{X: callExpr}),
373 ps(&syntax.ReturnStmt{}),
374 },
375 Rbrace: d.Pos(),
376 }
377 }
378 callReturnStmt.SetPos(d.Pos())
379
380 if k == 0 {
381
382
383 cond := pe(&syntax.CallExpr{
384 Fun: pe(&syntax.SelectorExpr{
385 X: syntax.NewName(d.Pos(), simdPkg),
386 Sel: syntax.NewName(d.Pos(), emulatedFn),
387 })})
388
389 blockStmt, ok := callReturnStmt.(*syntax.BlockStmt)
390 if !ok {
391 blockStmt = &syntax.BlockStmt{
392 List: []syntax.Stmt{callReturnStmt},
393 Rbrace: d.Pos(),
394 }
395 blockStmt.SetPos(d.Pos())
396 }
397
398 emulation = ps(&syntax.IfStmt{
399 Cond: cond,
400 Then: blockStmt,
401 })
402 continue
403 }
404
405 var caseBody []syntax.Stmt
406
407
408 if emulation != nil && k == 128 {
409 caseBody = append(caseBody, emulation)
410 emulation = nil
411 }
412
413 caseClause := &syntax.CaseClause{
414 Cases: pe(&syntax.BasicLit{Kind: syntax.IntLit, Value: fmt.Sprintf("%d", k)}),
415 Body: append(caseBody, callReturnStmt),
416 }
417 caseClause.SetPos(d.Pos())
418 switchStmt.Body = append(switchStmt.Body, caseClause)
419 }
420
421 panicStmt := r.panicStmt(d.Pos(), "unsupported vector size in simd-rewritten code")
422 return r.blockOf(d.Pos(), switchStmt, panicStmt)
423 }
424
425 func (r *Rewriter) blockOf(p syntax.Pos, stmts ...syntax.Stmt) *syntax.BlockStmt {
426 for _, s := range stmts {
427 s.SetPos(p)
428 }
429 blockStmt := &syntax.BlockStmt{List: stmts}
430 blockStmt.SetPos(p)
431 return blockStmt
432 }
433
434 func (r *Rewriter) panicStmt(p syntax.Pos, unquotedMessage string) *syntax.ExprStmt {
435 pe := func(e syntax.Expr) syntax.Expr {
436 e.SetPos(p)
437 return e
438 }
439 fnName := "panic"
440 fnIdent := pe(syntax.NewName(p, fnName))
441 callExpr := pe(&syntax.CallExpr{
442 Fun: fnIdent,
443 ArgList: []syntax.Expr{pe(&syntax.BasicLit{Value: `"` + unquotedMessage + `"`, Kind: syntax.StringLit})},
444 })
445 panicStmt := &syntax.ExprStmt{X: callExpr}
446 panicStmt.SetPos(p)
447 return panicStmt
448 }
449
450 func (r *Rewriter) generateForSize(fileAST *syntax.File, k int, newDecls []syntax.Decl) []syntax.Decl {
451 copier := NewDeepCopier(r.pkg, r.info, k, r.analyzer, fmt.Sprintf("@simd%d", k))
452 for _, decl := range fileAST.DeclList {
453 if r.shouldIncludeDecl(decl) {
454 newDecl := copier.CopyDecl(decl)
455 newDecls = append(newDecls, newDecl)
456 }
457 }
458 return newDecls
459 }
460
// nameToElemBitWidth maps a vector type name (Int8s, Uint16s, Float32s,
// Mask64s, ...) to its element width in bits. Unrecognized names yield 0.
func nameToElemBitWidth(name string) int {
	switch name {
	case "Int8s", "Uint8s", "Mask8s":
		return 8
	case "Int16s", "Uint16s", "Mask16s":
		return 16
	case "Int32s", "Uint32s", "Float32s", "Mask32s":
		return 32
	case "Int64s", "Uint64s", "Float64s", "Mask64s":
		return 64
	default:
		return 0
	}
}
475
476 func (r *Rewriter) shouldIncludeDecl(decl syntax.Decl) bool {
477
478
479
480 if r.analyzer.inSimd {
481 theFile := decl.Pos().Base().Filename()
482
483 lastSlash := strings.LastIndex(theFile, simdPkg+"/")
484 lastBackslash := strings.LastIndex(theFile, simdPkg+"\\")
485
486
487
488
489 maxSlash := max(lastSlash, lastBackslash)
490 if maxSlash == -1 {
491 return false
492 }
493 if !strings.HasPrefix(theFile[maxSlash:], simdPkg+"/tofrom_") &&
494 !strings.HasPrefix(theFile[maxSlash:], simdPkg+"\\tofrom_") {
495 return false
496 }
497 }
498
499 switch d := decl.(type) {
500 case *syntax.FuncDecl:
501 if d.Name != nil {
502 return r.analyzer.isDependentObj[r.info.Defs[d.Name]]
503 }
504 case *syntax.TypeDecl:
505 return r.analyzer.isDependentObj[r.info.Defs[d.Name]]
506 case *syntax.VarDecl:
507 for _, name := range d.NameList {
508 if r.analyzer.isDependentObj[r.info.Defs[name]] {
509 return true
510 }
511 }
512 }
513 return false
514 }
515
516
517 func RewriteWrapper(pkg *types2.Package, info *types2.Info, files []*syntax.File) bool {
518 if !buildcfg.Experiment.SIMD {
519 return false
520 }
521
522 switch buildcfg.GOARCH {
523 case "wasm", "amd64", "arm64":
524 default:
525 return false
526 }
527
528 sizes := rewriteSizes()
529 if len(sizes) == 0 {
530 return false
531 }
532 analyzer := NewAnalyzer(pkg, info)
533 if !analyzer.Analyze(files) {
534 return false
535 }
536
537 CheckPositions(files, "before midway")
538
539 rewriter := NewRewriter(pkg, info, analyzer, sizes)
540 rewriter.Rewrite(files)
541
542 CheckPositions(files, "after midway")
543
544 return true
545 }
546
View as plain text