Source file src/cmd/compile/internal/midway/rewrite.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package midway
     6  
     7  import (
     8  	"cmd/compile/internal/base"
     9  	"cmd/compile/internal/syntax"
    10  	"cmd/compile/internal/types2"
    11  	"fmt"
    12  	"internal/buildcfg"
    13  	"strings"
    14  )
    15  
    16  // "Midway" rewriting
    17  //
    18  // Go attempts to provide a package similar to the "Highway" library
    19  // for C++ (https://google.github.io/highway).  The library package is "simd"
    20  // and defines vector types with unspecified widths that are bound to particular
    21  // machine dependent types as late as program execution.  This is accomplished
    22  // by rewriting code that depends on these types into code that references
    23  // architecture-specific types, perhaps more than once, and if necessary
    24  // dynamically choosing which version to execute based on hardware attributes.
    25  //
    26  // The rewriting takes place early in the compiler, after type checking but
    27  // before conversion to "unified" IR.  To ensure that types are correctly set
    28  // on the modified version of the code, type checking information is reset and
    29  // the type checking phase is re-run.  This places some limits on the shape of
    30  // the rewrites, but it also ensures that the rewritten code is well-formed.
    31  //
    32  // Rewritten code does not reference "archsimd" types directly, but instead
    33  // references types in a "bridge" package that filters the available methods
    34  // and adds a few more.  The package used relies on a builder/compiler hack;
    35  // the compiler's type checker enforces export naming conventions, but the
    36  // build system limits visibility to unrelated "internal" packages and can be
    37  // modified to allow access in special cases (like this one).  This allows the
    38  // rewritten code to reference types, functions, and methods that are not
    39  // accessible otherwise.
    40  //
    41  // The rewrite works in phases.  The first is "analysis", to discover functions,
    42  // types, methods, and variables that depend on "simd" types.  "Depend on" means
    43  // any mention of a simd type, and for types, also includes types that have a
    44  // simd-dependent method.  Dependent functions are split into two categories;
    45  // those whose dependence includes their signature, and those that do not.
    46  // The second category forms the boundary between code that depends on simd and
    47  // code that does not.  Notice that there cannot be a boundary method, because
    48  // (by design) the receiver type is simd-dependent and thus a dependent method
    49  // also has a dependent type in its signature.
    50  //
    51  // The second phase rewrites such "boundary" functions into a "dispatch" version
    52  // and (later, third phase) "specialized" versions.  The dispatch function
    53  // will choose which specialized version to call based on which simd implementation
    54  // has been chosen, and forward parameters and results to/from that specialized version
    55  // of the function.  The dispatch version shares the same name as the original function.
    56  // Note that this applies to functions only, and not methods.
    57  
    58  // The third phase specializes dependent functions (both kinds), methods,
    59  // global variables, and types into size/emulation/feature-specific variants.
    60  // Except for methods, this is done by adding a suffix beginning with "@" to
    61  // the name.  Because "@" cannot appear in legal Go identifiers this removes
    62  // the risk of a naming overlap.  Methods are specialized, but not renamed,
    63  // because their receiver type is renamed instead.  Not changing method names
    64  // preserves interface satisfaction, for example in the case of generic interfaces.
    65  //
    66  // Non-boundary dependent functions and methods are not rewritten into dispatch
    67  // functions/methods, but remain in the generated code because they must be
    68  // present in the export data so that other packages that import them will still
    69  // compile before rewriting.  Their bodies are replaced with panic(...) to allow
    70  // compilation while preventing even worse chaos in the event of a bug either in
    71  // the compiler or through ambitious use of reflection or assembly language.
    72  //
    73  
    74  /* Example rewrites
    75  
    76  // Type alias, global variable, and init function:
    77  
    78  // before:
    79  type MyInt8s = simd.Int8s
    80  func Generic[T haslen](x int) int {
    81      var v T
    82      return x + v.Len()
    83  }
    84  var VL int
    85  func init() {
    86      VL = Generic[MyInt8s](1)
    87  }
    88  // dispatch:
    89  func init() {
    90      switch simd.VectorBitSize() {
    91      case
    92          128:
    93              init@simd128()
    94              return
    95      case 256:
    96              init@simd256()
    97              return
    98      case 512:
    99              init@simd512()
   100              return
   101      default:
   102          panic("unsupported vector size")
   103      }
   104  }
   105  // specialized (128)
   106  type MyInt8s@simd128 = archsimd.Int8x16
   107  func init@simd128() {
   108          VL = Generic[MyInt8s@simd128](1)
   109  }
   110  
   111  
   112  // structure containing simd fields, and with simd methods
   113  
   114  // before
   115  // A struct dependent on SIMD
   116  type VectorC struct {
   117      Field simd.Float32s
   118  }
   119  func (v *VectorC) MethodOfSimd() bool {
   120      return false
   121  }
   122  func (v VectorC) Data() simd.Float32s {
   123      return v.Field
   124  }
   125  func (v VectorC) Foo(x VectorC) VectorC {
   126      return VectorC{Field: v.Field.Add(x.Field)}
   127  }
   128  
   129  // dispatch
   130  // technically there is none, but functions with panicking bodies
   131  // remain because code must pass type checking before rewriting.
   132  type VectorC struct {
   133      Field simd.Float32s
   134  }
   135  func (v *VectorC) MethodOfSimd() bool {
   136      panic(...)
   137  }
   138  func (v VectorC) Data() simd.Float32s {
   139      panic(...)
   140  }
   141  func (v VectorC) Foo(x VectorC) VectorC {
   142      panic(...)
   143  }
   144  
   145  // specialized (128)
   146  
   147  // A struct dependent on SIMD
   148  type VectorC@simd128 struct {
   149      Field bridge.Float32x4
   150  }
   151  func (v *VectorC@simd128) MethodOfSimd() bool {
   152      return false
   153  }
   154  func (v VectorC@simd128) Data() bridge.Float32x4 {
   155      return v.Field
   156  }
   157  func (v VectorC@simd128) Foo(x VectorC@simd128) VectorC@simd128 {
   158      return VectorC@simd128{Field: v.Field.Add(x.Field)}
   159  }
   160  
   161  */
   162  
// Rewriter carries the inputs used by Rewrite to generate specialized
// copies of simd-dependent declarations and to turn the originals into
// dispatchers or panicking stubs.
type Rewriter struct {
	// pkg is passed to NewDeepCopier when declarations are specialized.
	pkg *types2.Package
	// analyzer supplies isDependentObj, inSimd and HasDependentSignature.
	analyzer *Analyzer
	// info.Defs maps declared names to the objects looked up in the analyzer.
	info *types2.Info
	// sizes lists the variants to generate; each entry k yields the suffix
	// "@simd<k>", and k == 0 is treated as the emulation variant when
	// dispatchers are built.
	sizes []int
}
   169  
   170  func NewRewriter(pkg *types2.Package, info *types2.Info, analyzer *Analyzer, sizes []int) *Rewriter {
   171  	return &Rewriter{
   172  		pkg:      pkg,
   173  		info:     info,
   174  		analyzer: analyzer,
   175  		sizes:    sizes,
   176  	}
   177  }
   178  
   179  func (r *Rewriter) Rewrite(files []*syntax.File) {
   180  
   181  	// First duplicate and specialize all dependent functions and variables.
   182  	for _, fileAST := range files {
   183  
   184  		var newDecls []syntax.Decl
   185  		for _, k := range r.sizes {
   186  			newDecls = r.generateForSize(fileAST, k, newDecls)
   187  		}
   188  
   189  		// Then replace original functions with dispatchers.
   190  		// This also edits the DeclList of fileAST.
   191  		r.generateDispatchers(fileAST)
   192  
   193  		fileAST.DeclList = append(fileAST.DeclList, newDecls...)
   194  	}
   195  }
   196  
   197  func (r *Rewriter) generateDispatchers(fileAST *syntax.File) {
   198  	var newDecls []syntax.Decl
   199  
   200  	change := false
   201  
   202  	for _, decl := range fileAST.DeclList {
   203  		switch d := decl.(type) {
   204  		case *syntax.FuncDecl:
   205  			if d.Name == nil {
   206  				newDecls = append(newDecls, d)
   207  				continue
   208  			}
   209  			obj := r.info.Defs[d.Name]
   210  			if !r.analyzer.isDependentObj[obj] || r.analyzer.inSimd {
   211  				newDecls = append(newDecls, d)
   212  				continue
   213  			}
   214  
   215  			sig, ok := obj.Type().(*types2.Signature)
   216  			if !ok {
   217  				newDecls = append(newDecls, d)
   218  				continue
   219  			}
   220  
   221  			change = true
   222  			if r.analyzer.HasDependentSignature(sig) {
   223  				if base.Debug.Simd > 0 {
   224  					base.Warn("%s: removing body of dependent-sig original function %v", d.Pos().String(), d.Name.Value)
   225  				}
   226  				d.Body = r.blockOf(d.Pos(), r.panicStmt(d.Pos(),
   227  					"unexpected call of original function rewritten to specialized SIMD"))
   228  				newDecls = append(newDecls, d)
   229  				continue
   230  			}
   231  
   232  			// Clean signature -> Replace body with dispatcher
   233  			d.Body = r.createDispatcherBody(d, sig)
   234  			newDecls = append(newDecls, d)
   235  
   236  		case *syntax.VarDecl:
   237  			// Keep var decls even if rewritten, so that pre-rewrite code parses correctly.
   238  			// TODO figure out how to deal with side-effects in initializers.
   239  			newDecls = append(newDecls, d)
   240  
   241  		case *syntax.TypeDecl:
   242  			// Keep all types; we need the untranslated copy if a method referencing it
   243  			// needs to typecheck pre-translation.
   244  			newDecls = append(newDecls, d)
   245  		default:
   246  			newDecls = append(newDecls, decl)
   247  		}
   248  	}
   249  
   250  	if !change {
   251  		return
   252  	}
   253  
   254  	fileAST.DeclList = newDecls
   255  
   256  	if !r.analyzer.inSimd {
   257  		// Inject an import to the bridge package (if not exists)
   258  		hasArchSimd := false
   259  		var simdImport *syntax.ImportDecl
   260  		p := fileAST.Pos()
   261  		for _, decl := range fileAST.DeclList {
   262  			if imp, ok := decl.(*syntax.ImportDecl); ok {
   263  				if imp.Path.Value == `"`+archFullPkg+`"` {
   264  					hasArchSimd = true
   265  					if simdImport == nil {
   266  						p = imp.Pos()
   267  					}
   268  				}
   269  				if imp.Path.Value == `"`+simdPkg+`"` {
   270  					simdImport = imp
   271  					p = imp.Pos()
   272  				}
   273  			}
   274  		}
   275  
   276  		if !hasArchSimd {
   277  			r.injectImport(fileAST, archFullPkg, p)
   278  		}
   279  
   280  		// Ensure at least one use of "simd"
   281  		// var _ = simd.VectorBitLen()
   282  		fun := &syntax.SelectorExpr{
   283  			X:   syntax.NewName(p, simdPkg), // Assume this is resolvable
   284  			Sel: syntax.NewName(p, vectorSizeFn),
   285  		}
   286  		fun.SetPos(p)
   287  		call := &syntax.CallExpr{Fun: fun}
   288  		call.SetPos(p)
   289  
   290  		name := syntax.NewName(p, "_")
   291  
   292  		varDecl := &syntax.VarDecl{NameList: []*syntax.Name{name}, Values: call}
   293  		varDecl.SetPos(p)
   294  		fileAST.DeclList = append(fileAST.DeclList, varDecl)
   295  	}
   296  }
   297  
   298  func (r *Rewriter) injectImport(fileAST *syntax.File, toImport string, simdImportPos syntax.Pos) {
   299  	importDecl := &syntax.ImportDecl{
   300  		Path: &syntax.BasicLit{Value: `"` + toImport + `"`, Kind: syntax.StringLit},
   301  	}
   302  	importDecl.Path.SetPos(simdImportPos)
   303  	importDecl.SetPos(simdImportPos)
   304  	fileAST.DeclList = append([]syntax.Decl{importDecl}, fileAST.DeclList...)
   305  }
   306  
   307  func (r *Rewriter) createDispatcherBody(d *syntax.FuncDecl, sig *types2.Signature) *syntax.BlockStmt {
   308  
   309  	// Build call arguments from the function parameters
   310  	args := func() []syntax.Expr {
   311  		var args []syntax.Expr
   312  		if d.Type.ParamList != nil {
   313  			for _, field := range d.Type.ParamList {
   314  				if field.Name != nil {
   315  					paramName := syntax.NewName(field.Pos(), field.Name.Value)
   316  					args = append(args, paramName)
   317  				}
   318  			}
   319  		}
   320  		return args
   321  	}
   322  
   323  	// Slap a pos on an expression
   324  	pe := func(e syntax.Expr) syntax.Expr {
   325  		e.SetPos(d.Pos())
   326  		return e
   327  	}
   328  	// Slap a pos on a statement
   329  	ps := func(e syntax.Stmt) syntax.Stmt {
   330  		e.SetPos(d.Pos())
   331  		return e
   332  	}
   333  
   334  	// switch ast node.
   335  	// the goal is something like (for now, till there are finer-grained choices)
   336  	// switch simd.VectorSize() {
   337  	//   case 128: if simd.Emulated() { call the specialize-for-emulation-code(args) }
   338  	//             else { call the specialize-for-128-code(args) }
   339  	//   case 256: call the specialize-for-256-code(args)
   340  	//   etc
   341  	// }
   342  	//
   343  	// the cases above deal with the usual `return call(...)` vs `call(...); return`
   344  	switchStmt := &syntax.SwitchStmt{
   345  		Tag: pe(&syntax.CallExpr{
   346  			Fun: pe(&syntax.SelectorExpr{
   347  				X:   syntax.NewName(d.Pos(), simdPkg), // Assume this is resolvable
   348  				Sel: syntax.NewName(d.Pos(), vectorSizeFn),
   349  			}),
   350  		}),
   351  		Body: []*syntax.CaseClause{},
   352  	}
   353  
   354  	var emulation syntax.Stmt
   355  
   356  	for _, k := range r.sizes {
   357  		fnName := fmt.Sprintf("%s@simd%d", d.Name.Value, k)
   358  		fnIdent := syntax.NewName(d.Pos(), fnName)
   359  
   360  		callExpr := pe(&syntax.CallExpr{
   361  			Fun:     pe(fnIdent),
   362  			ArgList: args(),
   363  		})
   364  
   365  		// callReturnStmt is either `return call(...)` or `call(...); return`
   366  		var callReturnStmt syntax.Stmt
   367  		if d.Type.ResultList != nil && len(d.Type.ResultList) > 0 {
   368  			callReturnStmt = &syntax.ReturnStmt{Results: callExpr}
   369  		} else {
   370  			callReturnStmt = &syntax.BlockStmt{
   371  				List: []syntax.Stmt{
   372  					ps(&syntax.ExprStmt{X: callExpr}),
   373  					ps(&syntax.ReturnStmt{}),
   374  				},
   375  				Rbrace: d.Pos(),
   376  			}
   377  		}
   378  		callReturnStmt.SetPos(d.Pos())
   379  
   380  		if k == 0 {
   381  			// emulation == `if simd.Emulated() { callReturnStmt }`
   382  			// save it for the first part of the 128 case.
   383  			cond := pe(&syntax.CallExpr{
   384  				Fun: pe(&syntax.SelectorExpr{
   385  					X:   syntax.NewName(d.Pos(), simdPkg), // Assume this is resolvable
   386  					Sel: syntax.NewName(d.Pos(), emulatedFn),
   387  				})})
   388  
   389  			blockStmt, ok := callReturnStmt.(*syntax.BlockStmt)
   390  			if !ok {
   391  				blockStmt = &syntax.BlockStmt{
   392  					List:   []syntax.Stmt{callReturnStmt},
   393  					Rbrace: d.Pos(),
   394  				}
   395  				blockStmt.SetPos(d.Pos())
   396  			}
   397  
   398  			emulation = ps(&syntax.IfStmt{
   399  				Cond: cond,
   400  				Then: blockStmt,
   401  			})
   402  			continue
   403  		}
   404  
   405  		var caseBody []syntax.Stmt
   406  		// assume that 128 is a case; when we do scalable simd, this may change.
   407  		// For now, if there is emulation, it is 128-bit (only).
   408  		if emulation != nil && k == 128 {
   409  			caseBody = append(caseBody, emulation)
   410  			emulation = nil
   411  		}
   412  
   413  		caseClause := &syntax.CaseClause{
   414  			Cases: pe(&syntax.BasicLit{Kind: syntax.IntLit, Value: fmt.Sprintf("%d", k)}),
   415  			Body:  append(caseBody, callReturnStmt),
   416  		}
   417  		caseClause.SetPos(d.Pos())
   418  		switchStmt.Body = append(switchStmt.Body, caseClause)
   419  	}
   420  
   421  	panicStmt := r.panicStmt(d.Pos(), "unsupported vector size in simd-rewritten code")
   422  	return r.blockOf(d.Pos(), switchStmt, panicStmt)
   423  }
   424  
   425  func (r *Rewriter) blockOf(p syntax.Pos, stmts ...syntax.Stmt) *syntax.BlockStmt {
   426  	for _, s := range stmts {
   427  		s.SetPos(p)
   428  	}
   429  	blockStmt := &syntax.BlockStmt{List: stmts}
   430  	blockStmt.SetPos(p)
   431  	return blockStmt
   432  }
   433  
   434  func (r *Rewriter) panicStmt(p syntax.Pos, unquotedMessage string) *syntax.ExprStmt {
   435  	pe := func(e syntax.Expr) syntax.Expr {
   436  		e.SetPos(p)
   437  		return e
   438  	}
   439  	fnName := "panic"
   440  	fnIdent := pe(syntax.NewName(p, fnName))
   441  	callExpr := pe(&syntax.CallExpr{
   442  		Fun:     fnIdent,
   443  		ArgList: []syntax.Expr{pe(&syntax.BasicLit{Value: `"` + unquotedMessage + `"`, Kind: syntax.StringLit})},
   444  	})
   445  	panicStmt := &syntax.ExprStmt{X: callExpr}
   446  	panicStmt.SetPos(p)
   447  	return panicStmt
   448  }
   449  
   450  func (r *Rewriter) generateForSize(fileAST *syntax.File, k int, newDecls []syntax.Decl) []syntax.Decl {
   451  	copier := NewDeepCopier(r.pkg, r.info, k, r.analyzer, fmt.Sprintf("@simd%d", k))
   452  	for _, decl := range fileAST.DeclList {
   453  		if r.shouldIncludeDecl(decl) {
   454  			newDecl := copier.CopyDecl(decl)
   455  			newDecls = append(newDecls, newDecl)
   456  		}
   457  	}
   458  	return newDecls
   459  }
   460  
   461  func nameToElemBitWidth(name string) int {
   462  	var width int
   463  	switch name {
   464  	case "Int8s", "Uint8s", "Mask8s":
   465  		width = 8
   466  	case "Int16s", "Uint16s", "Mask16s":
   467  		width = 16
   468  	case "Int32s", "Uint32s", "Float32s", "Mask32s":
   469  		width = 32
   470  	case "Int64s", "Uint64s", "Float64s", "Mask64s":
   471  		width = 64
   472  	}
   473  	return width
   474  }
   475  
   476  func (r *Rewriter) shouldIncludeDecl(decl syntax.Decl) bool {
   477  	// Files (and declarations) in the simd package are excluded
   478  	// from processing, except for those that whose name begins
   479  	// with "tofrom_".
   480  	if r.analyzer.inSimd {
   481  		theFile := decl.Pos().Base().Filename()
   482  
   483  		lastSlash := strings.LastIndex(theFile, simdPkg+"/")
   484  		lastBackslash := strings.LastIndex(theFile, simdPkg+"\\")
   485  
   486  		// Windows paths can be chaos, all we care, is whether the very last part
   487  		// of the path is any-path-separator + "tofrom_" + anything-else, given that
   488  		// we already know that we are in the simd package.
   489  		maxSlash := max(lastSlash, lastBackslash)
   490  		if maxSlash == -1 {
   491  			return false
   492  		}
   493  		if !strings.HasPrefix(theFile[maxSlash:], simdPkg+"/tofrom_") &&
   494  			!strings.HasPrefix(theFile[maxSlash:], simdPkg+"\\tofrom_") {
   495  			return false
   496  		}
   497  	}
   498  
   499  	switch d := decl.(type) {
   500  	case *syntax.FuncDecl:
   501  		if d.Name != nil {
   502  			return r.analyzer.isDependentObj[r.info.Defs[d.Name]]
   503  		}
   504  	case *syntax.TypeDecl:
   505  		return r.analyzer.isDependentObj[r.info.Defs[d.Name]]
   506  	case *syntax.VarDecl:
   507  		for _, name := range d.NameList {
   508  			if r.analyzer.isDependentObj[r.info.Defs[name]] {
   509  				return true
   510  			}
   511  		}
   512  	}
   513  	return false
   514  }
   515  
   516  // Generate an API matching the standalone compilation call
   517  func RewriteWrapper(pkg *types2.Package, info *types2.Info, files []*syntax.File) bool {
   518  	if !buildcfg.Experiment.SIMD {
   519  		return false
   520  	}
   521  
   522  	switch buildcfg.GOARCH {
   523  	case "wasm", "amd64", "arm64":
   524  	default:
   525  		return false
   526  	}
   527  
   528  	sizes := rewriteSizes()
   529  	if len(sizes) == 0 {
   530  		return false
   531  	}
   532  	analyzer := NewAnalyzer(pkg, info)
   533  	if !analyzer.Analyze(files) {
   534  		return false
   535  	}
   536  
   537  	CheckPositions(files, "before midway")
   538  
   539  	rewriter := NewRewriter(pkg, info, analyzer, sizes)
   540  	rewriter.Rewrite(files)
   541  
   542  	CheckPositions(files, "after midway")
   543  
   544  	return true
   545  }
   546  

View as plain text