Source file src/crypto/internal/fips140/edwards25519/field/_asm/fe_amd64_asm.go

     1  // Copyright (c) 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  
    10  	. "github.com/mmcloughlin/avo/build"
    11  	. "github.com/mmcloughlin/avo/gotypes"
    12  	. "github.com/mmcloughlin/avo/operand"
    13  	. "github.com/mmcloughlin/avo/reg"
    14  )
    15  
    16  //go:generate go run . -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field
    17  
    18  func main() {
    19  	Package("crypto/internal/fips140/edwards25519/field")
    20  	ConstraintExpr("!purego")
    21  	feMul()
    22  	feSquare()
    23  	Generate()
    24  }
    25  
    26  type namedComponent struct {
    27  	Component
    28  	name string
    29  }
    30  
    31  func (c namedComponent) String() string { return c.name }
    32  
    33  type uint128 struct {
    34  	name   string
    35  	hi, lo GPVirtual
    36  }
    37  
    38  func (c uint128) String() string { return c.name }
    39  
// feSquare defines the assembly function
//
//	func feSquare(out, a *Element)
//
// which sets out = a * a, following feSquareGeneric. The five 128-bit
// column sums r0..r4 are built with symmetric cross terms folded together
// (hence the 2× multipliers). Terms that land at or beyond limb 5 are
// pre-multiplied by 19 (or 2×19 = 38) so they fold back into the lower
// columns. Two carry chains then bring every limb back to its low 51 bits
// plus a small incoming carry.
func feSquare() {
	TEXT("feSquare", NOSPLIT, "func(out, a *Element)")
	Doc("feSquare sets out = a * a. It works like feSquareGeneric.")
	Pragma("noescape")

	// Components referring to the five limbs of *a. The names only label
	// operands in the generated comments.
	a := Dereference(Param("a"))
	l0 := namedComponent{a.Field("l0"), "l0"}
	l1 := namedComponent{a.Field("l1"), "l1"}
	l2 := namedComponent{a.Field("l2"), "l2"}
	l3 := namedComponent{a.Field("l3"), "l3"}
	l4 := namedComponent{a.Field("l4"), "l4"}

	// r0 = l0×l0 + 19×2×(l1×l4 + l2×l3)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, l0, l0)
	addMul64(r0, 38, l1, l4)
	addMul64(r0, 38, l2, l3)

	// r1 = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 2, l0, l1)
	addMul64(r1, 38, l2, l4)
	addMul64(r1, 19, l3, l3)

	// r2 = 2×l0×l2 + l1×l1 + 19×2×l3×l4
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 2, l0, l2)
	addMul64(r2, 1, l1, l1)
	addMul64(r2, 38, l3, l4)

	// r3 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 2, l0, l3)
	addMul64(r3, 2, l1, l2)
	addMul64(r3, 19, l4, l4)

	// r4 = 2×l0×l4 + 2×l1×l3 + l2×l2
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 2, l0, l4)
	addMul64(r4, 2, l1, l3)
	addMul64(r4, 1, l2, l2)

	Comment("First reduction chain")
	// Split every 128-bit sum into a carry ci = ri >> 51 and its low 64
	// bits. All five splits happen before any masking. Limb i then keeps its
	// low 51 bits and absorbs the carry from limb i-1. The carry out of
	// limb 4 has weight 2^255 ≡ 19 (mod 2^255 - 19), so it re-enters limb 0
	// multiplied by 19.
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits) // 2^51 - 1
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Second reduction chain (carryPropagate)")
	// Repeat the same carry pattern on the now 64-bit limbs, reusing the
	// c0..c4 registers. All five carries are extracted before any limb is
	// masked, so each carry comes from the first chain's result.
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	// Every limb of *a has already been read above, so these stores are
	// safe even when out and a point to the same Element.
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}
   128  
// feMul defines the assembly function
//
//	func feMul(out, a, b *Element)
//
// which sets out = a * b, following feMulGeneric. It forms the five 128-bit
// schoolbook column sums r0..r4. Products whose limb indices add up to 5 or
// more are multiplied by 19 so they fold back into the lower columns. Two
// carry chains then bring every limb back to its low 51 bits plus a small
// incoming carry.
func feMul() {
	TEXT("feMul", NOSPLIT, "func(out, a, b *Element)")
	Doc("feMul sets out = a * b. It works like feMulGeneric.")
	Pragma("noescape")

	// Components referring to the five limbs of *a and of *b. The names
	// only label operands in the generated comments.
	a := Dereference(Param("a"))
	a0 := namedComponent{a.Field("l0"), "a0"}
	a1 := namedComponent{a.Field("l1"), "a1"}
	a2 := namedComponent{a.Field("l2"), "a2"}
	a3 := namedComponent{a.Field("l3"), "a3"}
	a4 := namedComponent{a.Field("l4"), "a4"}

	b := Dereference(Param("b"))
	b0 := namedComponent{b.Field("l0"), "b0"}
	b1 := namedComponent{b.Field("l1"), "b1"}
	b2 := namedComponent{b.Field("l2"), "b2"}
	b3 := namedComponent{b.Field("l3"), "b3"}
	b4 := namedComponent{b.Field("l4"), "b4"}

	// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, a0, b0)
	addMul64(r0, 19, a1, b4)
	addMul64(r0, 19, a2, b3)
	addMul64(r0, 19, a3, b2)
	addMul64(r0, 19, a4, b1)

	// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 1, a0, b1)
	addMul64(r1, 1, a1, b0)
	addMul64(r1, 19, a2, b4)
	addMul64(r1, 19, a3, b3)
	addMul64(r1, 19, a4, b2)

	// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 1, a0, b2)
	addMul64(r2, 1, a1, b1)
	addMul64(r2, 1, a2, b0)
	addMul64(r2, 19, a3, b4)
	addMul64(r2, 19, a4, b3)

	// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 1, a0, b3)
	addMul64(r3, 1, a1, b2)
	addMul64(r3, 1, a2, b1)
	addMul64(r3, 1, a3, b0)
	addMul64(r3, 19, a4, b4)

	// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 1, a0, b4)
	addMul64(r4, 1, a1, b3)
	addMul64(r4, 1, a2, b2)
	addMul64(r4, 1, a3, b1)
	addMul64(r4, 1, a4, b0)

	Comment("First reduction chain")
	// Split every 128-bit sum into a carry ci = ri >> 51 and its low 64
	// bits. All five splits happen before any masking. Limb i then keeps its
	// low 51 bits and absorbs the carry from limb i-1. The carry out of
	// limb 4 has weight 2^255 ≡ 19 (mod 2^255 - 19), so it re-enters limb 0
	// multiplied by 19.
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits) // 2^51 - 1
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Second reduction chain (carryPropagate)")
	// Repeat the same carry pattern on the now 64-bit limbs, reusing the
	// c0..c4 registers. All five carries are extracted before any limb is
	// masked, so each carry comes from the first chain's result.
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	// Every limb of *a and *b has already been read above, so these stores
	// are safe even when out aliases a or b.
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}
   234  
   235  // mul64 sets r to i * aX * bX.
   236  func mul64(r uint128, i int, aX, bX namedComponent) {
   237  	switch i {
   238  	case 1:
   239  		Comment(fmt.Sprintf("%s = %s×%s", r, aX, bX))
   240  		Load(aX, RAX)
   241  	case 2:
   242  		Comment(fmt.Sprintf("%s = 2×%s×%s", r, aX, bX))
   243  		Load(aX, RAX)
   244  		SHLQ(Imm(1), RAX)
   245  	default:
   246  		panic("unsupported i value")
   247  	}
   248  	MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
   249  	MOVQ(RAX, r.lo)
   250  	MOVQ(RDX, r.hi)
   251  }
   252  
   253  // addMul64 sets r to r + i * aX * bX.
   254  func addMul64(r uint128, i uint64, aX, bX namedComponent) {
   255  	switch i {
   256  	case 1:
   257  		Comment(fmt.Sprintf("%s += %s×%s", r, aX, bX))
   258  		Load(aX, RAX)
   259  	case 2:
   260  		Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
   261  		Load(aX, RAX)
   262  		SHLQ(U8(1), RAX)
   263  	case 19:
   264  		Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
   265  		// 19 * v ==> v + (v+v*8)*2
   266  		tmp := Load(aX, GP64())
   267  		LEAQ(Mem{Base: tmp, Index: tmp, Scale: 8}, RAX)
   268  		LEAQ(Mem{Base: tmp, Index: RAX, Scale: 2}, RAX)
   269  	case 38:
   270  		Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
   271  		// 38 * v ==> (v + (v+v*8)*2) * 2
   272  		tmp := Load(aX, GP64())
   273  		LEAQ(Mem{Base: tmp, Index: tmp, Scale: 8}, RAX)
   274  		LEAQ(Mem{Base: tmp, Index: RAX, Scale: 2}, RAX)
   275  		SHLQ(U8(1), RAX)
   276  	default:
   277  		Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
   278  		IMUL3Q(Imm(i), Load(aX, GP64()), RAX)
   279  	}
   280  	MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
   281  	ADDQ(RAX, r.lo)
   282  	ADCQ(RDX, r.hi)
   283  }
   284  
   285  // shiftRightBy51 returns r >> 51 and r.lo.
   286  //
   287  // After this function is called, the uint128 may not be used anymore.
   288  func shiftRightBy51(r *uint128) (out, lo GPVirtual) {
   289  	out = r.hi
   290  	lo = r.lo
   291  	SHLQ(Imm(64-51), r.lo, r.hi)
   292  	r.lo, r.hi = nil, nil // make sure the uint128 is unusable
   293  	return
   294  }
   295  
   296  // maskAndAdd sets r = r&mask + c*i.
   297  func maskAndAdd(r, mask, c GPVirtual, i uint64) {
   298  	ANDQ(mask, r)
   299  	if i != 1 {
   300  		IMUL3Q(Imm(i), c, c)
   301  	}
   302  	ADDQ(c, r)
   303  }
   304  
   305  func mustAddr(c Component) Op {
   306  	b, err := c.Resolve()
   307  	if err != nil {
   308  		panic(err)
   309  	}
   310  	return b.Addr
   311  }
   312  

View as plain text