Source file src/crypto/internal/fips140/bigmod/nat.go
1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package bigmod 6 7 import ( 8 _ "crypto/internal/fips140/check" 9 "crypto/internal/fips140deps/byteorder" 10 "errors" 11 "math/bits" 12 ) 13 14 const ( 15 // _W is the size in bits of our limbs. 16 _W = bits.UintSize 17 // _S is the size in bytes of our limbs. 18 _S = _W / 8 19 ) 20 21 // Note: These functions make many loops over all the words in a Nat. 22 // These loops used to be in assembly, invisible to -race, -asan, and -msan, 23 // but now they are in Go and incur significant overhead in those modes. 24 // To bring the old performance back, we mark all functions that loop 25 // over Nat words with //go:norace. Because //go:norace does not 26 // propagate across inlining, we must also mark functions that inline 27 // //go:norace functions - specifically, those that inline add, addMulVVW, 28 // assign, cmpGeq, rshift1, and sub. 29 30 // choice represents a constant-time boolean. The value of choice is always 31 // either 1 or 0. We use an int instead of bool in order to make decisions in 32 // constant time by turning it into a mask. 33 type choice uint 34 35 func not(c choice) choice { return 1 ^ c } 36 37 const yes = choice(1) 38 const no = choice(0) 39 40 // ctMask is all 1s if on is yes, and all 0s otherwise. 41 func ctMask(on choice) uint { return -uint(on) } 42 43 // ctEq returns 1 if x == y, and 0 otherwise. The execution time of this 44 // function does not depend on its inputs. 45 func ctEq(x, y uint) choice { 46 // If x != y, then either x - y or y - x will generate a carry. 47 _, c1 := bits.Sub(x, y, 0) 48 _, c2 := bits.Sub(y, x, 0) 49 return not(choice(c1 | c2)) 50 } 51 52 // Nat represents an arbitrary natural number 53 // 54 // Each Nat has an announced length, which is the number of limbs it has stored. 
55 // Operations on this number are allowed to leak this length, but will not leak 56 // any information about the values contained in those limbs. 57 type Nat struct { 58 // limbs is little-endian in base 2^W with W = bits.UintSize. 59 limbs []uint 60 } 61 62 // preallocTarget is the size in bits of the numbers used to implement the most 63 // common and most performant RSA key size. It's also enough to cover some of 64 // the operations of key sizes up to 4096. 65 const preallocTarget = 2048 66 const preallocLimbs = (preallocTarget + _W - 1) / _W 67 68 // NewNat returns a new nat with a size of zero, just like new(Nat), but with 69 // the preallocated capacity to hold a number of up to preallocTarget bits. 70 // NewNat inlines, so the allocation can live on the stack. 71 func NewNat() *Nat { 72 limbs := make([]uint, 0, preallocLimbs) 73 return &Nat{limbs} 74 } 75 76 // expand expands x to n limbs, leaving its value unchanged. 77 func (x *Nat) expand(n int) *Nat { 78 if len(x.limbs) > n { 79 panic("bigmod: internal error: shrinking nat") 80 } 81 if cap(x.limbs) < n { 82 newLimbs := make([]uint, n) 83 copy(newLimbs, x.limbs) 84 x.limbs = newLimbs 85 return x 86 } 87 extraLimbs := x.limbs[len(x.limbs):n] 88 clear(extraLimbs) 89 x.limbs = x.limbs[:n] 90 return x 91 } 92 93 // reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs). 94 func (x *Nat) reset(n int) *Nat { 95 if cap(x.limbs) < n { 96 x.limbs = make([]uint, n) 97 return x 98 } 99 // Clear both the returned limbs and the previously used ones. 100 clear(x.limbs[:max(n, len(x.limbs))]) 101 x.limbs = x.limbs[:n] 102 return x 103 } 104 105 // resetToBytes assigns x = b, where b is a slice of big-endian bytes, resizing 106 // n to the appropriate size. 107 // 108 // The announced length of x is set based on the actual bit size of the input, 109 // ignoring leading zeroes. 
110 func (x *Nat) resetToBytes(b []byte) *Nat { 111 x.reset((len(b) + _S - 1) / _S) 112 if err := x.setBytes(b); err != nil { 113 panic("bigmod: internal error: bad arithmetic") 114 } 115 return x.trim() 116 } 117 118 // trim reduces the size of x to match its value. 119 func (x *Nat) trim() *Nat { 120 // Trim most significant (trailing in little-endian) zero limbs. 121 // We assume comparison with zero (but not the branch) is constant time. 122 for i := len(x.limbs) - 1; i >= 0; i-- { 123 if x.limbs[i] != 0 { 124 break 125 } 126 x.limbs = x.limbs[:i] 127 } 128 return x 129 } 130 131 // set assigns x = y, optionally resizing x to the appropriate size. 132 func (x *Nat) set(y *Nat) *Nat { 133 x.reset(len(y.limbs)) 134 copy(x.limbs, y.limbs) 135 return x 136 } 137 138 // Bits returns x as a little-endian slice of uint. The length of the slice 139 // matches the announced length of x. The result and x share the same underlying 140 // array. 141 func (x *Nat) Bits() []uint { 142 return x.limbs 143 } 144 145 // Bytes returns x as a zero-extended big-endian byte slice. The size of the 146 // slice will match the size of m. 147 // 148 // x must have the same size as m and it must be less than or equal to m. 149 func (x *Nat) Bytes(m *Modulus) []byte { 150 i := m.Size() 151 bytes := make([]byte, i) 152 for _, limb := range x.limbs { 153 for j := 0; j < _S; j++ { 154 i-- 155 if i < 0 { 156 if limb == 0 { 157 break 158 } 159 panic("bigmod: modulus is smaller than nat") 160 } 161 bytes[i] = byte(limb) 162 limb >>= 8 163 } 164 } 165 return bytes 166 } 167 168 // SetBytes assigns x = b, where b is a slice of big-endian bytes. 169 // SetBytes returns an error if b >= m. 170 // 171 // The output will be resized to the size of m and overwritten. 
172 // 173 //go:norace 174 func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { 175 x.resetFor(m) 176 if err := x.setBytes(b); err != nil { 177 return nil, err 178 } 179 if x.cmpGeq(m.nat) == yes { 180 return nil, errors.New("input overflows the modulus") 181 } 182 return x, nil 183 } 184 185 // SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. 186 // SetOverflowingBytes returns an error if b has a longer bit length than m, but 187 // reduces overflowing values up to 2^⌈log2(m)⌉ - 1. 188 // 189 // The output will be resized to the size of m and overwritten. 190 func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { 191 x.resetFor(m) 192 if err := x.setBytes(b); err != nil { 193 return nil, err 194 } 195 // setBytes would have returned an error if the input overflowed the limb 196 // size of the modulus, so now we only need to check if the most significant 197 // limb of x has more bits than the most significant limb of the modulus. 198 if bitLen(x.limbs[len(x.limbs)-1]) > bitLen(m.nat.limbs[len(m.nat.limbs)-1]) { 199 return nil, errors.New("input overflows the modulus size") 200 } 201 x.maybeSubtractModulus(no, m) 202 return x, nil 203 } 204 205 // bigEndianUint returns the contents of buf interpreted as a 206 // big-endian encoded uint value. 207 func bigEndianUint(buf []byte) uint { 208 if _W == 64 { 209 return uint(byteorder.BEUint64(buf)) 210 } 211 return uint(byteorder.BEUint32(buf)) 212 } 213 214 func (x *Nat) setBytes(b []byte) error { 215 i, k := len(b), 0 216 for k < len(x.limbs) && i >= _S { 217 x.limbs[k] = bigEndianUint(b[i-_S : i]) 218 i -= _S 219 k++ 220 } 221 for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 { 222 x.limbs[k] |= uint(b[i-1]) << s 223 i-- 224 } 225 if i > 0 { 226 return errors.New("input overflows the modulus size") 227 } 228 return nil 229 } 230 231 // SetUint assigns x = y. 232 // 233 // The output will be resized to a single limb and overwritten. 
234 func (x *Nat) SetUint(y uint) *Nat { 235 x.reset(1) 236 x.limbs[0] = y 237 return x 238 } 239 240 // Equal returns 1 if x == y, and 0 otherwise. 241 // 242 // Both operands must have the same announced length. 243 // 244 //go:norace 245 func (x *Nat) Equal(y *Nat) choice { 246 // Eliminate bounds checks in the loop. 247 size := len(x.limbs) 248 xLimbs := x.limbs[:size] 249 yLimbs := y.limbs[:size] 250 251 equal := yes 252 for i := 0; i < size; i++ { 253 equal &= ctEq(xLimbs[i], yLimbs[i]) 254 } 255 return equal 256 } 257 258 // IsZero returns 1 if x == 0, and 0 otherwise. 259 // 260 //go:norace 261 func (x *Nat) IsZero() choice { 262 // Eliminate bounds checks in the loop. 263 size := len(x.limbs) 264 xLimbs := x.limbs[:size] 265 266 zero := yes 267 for i := 0; i < size; i++ { 268 zero &= ctEq(xLimbs[i], 0) 269 } 270 return zero 271 } 272 273 // IsOne returns 1 if x == 1, and 0 otherwise. 274 // 275 //go:norace 276 func (x *Nat) IsOne() choice { 277 // Eliminate bounds checks in the loop. 278 size := len(x.limbs) 279 xLimbs := x.limbs[:size] 280 281 if len(xLimbs) == 0 { 282 return no 283 } 284 285 one := ctEq(xLimbs[0], 1) 286 for i := 1; i < size; i++ { 287 one &= ctEq(xLimbs[i], 0) 288 } 289 return one 290 } 291 292 // IsMinusOne returns 1 if x == -1 mod m, and 0 otherwise. 293 // 294 // The length of x must be the same as the modulus. x must already be reduced 295 // modulo m. 296 // 297 //go:norace 298 func (x *Nat) IsMinusOne(m *Modulus) choice { 299 minusOne := m.Nat() 300 minusOne.SubOne(m) 301 return x.Equal(minusOne) 302 } 303 304 // IsOdd returns 1 if x is odd, and 0 otherwise. 305 func (x *Nat) IsOdd() choice { 306 if len(x.limbs) == 0 { 307 return no 308 } 309 return choice(x.limbs[0] & 1) 310 } 311 312 // TrailingZeroBitsVarTime returns the number of trailing zero bits in x. 
313 func (x *Nat) TrailingZeroBitsVarTime() uint { 314 var t uint 315 limbs := x.limbs 316 for _, l := range limbs { 317 if l == 0 { 318 t += _W 319 continue 320 } 321 t += uint(bits.TrailingZeros(l)) 322 break 323 } 324 return t 325 } 326 327 // cmpGeq returns 1 if x >= y, and 0 otherwise. 328 // 329 // Both operands must have the same announced length. 330 // 331 //go:norace 332 func (x *Nat) cmpGeq(y *Nat) choice { 333 // Eliminate bounds checks in the loop. 334 size := len(x.limbs) 335 xLimbs := x.limbs[:size] 336 yLimbs := y.limbs[:size] 337 338 var c uint 339 for i := 0; i < size; i++ { 340 _, c = bits.Sub(xLimbs[i], yLimbs[i], c) 341 } 342 // If there was a carry, then subtracting y underflowed, so 343 // x is not greater than or equal to y. 344 return not(choice(c)) 345 } 346 347 // assign sets x <- y if on == 1, and does nothing otherwise. 348 // 349 // Both operands must have the same announced length. 350 // 351 //go:norace 352 func (x *Nat) assign(on choice, y *Nat) *Nat { 353 // Eliminate bounds checks in the loop. 354 size := len(x.limbs) 355 xLimbs := x.limbs[:size] 356 yLimbs := y.limbs[:size] 357 358 mask := ctMask(on) 359 for i := 0; i < size; i++ { 360 xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i]) 361 } 362 return x 363 } 364 365 // add computes x += y and returns the carry. 366 // 367 // Both operands must have the same announced length. 368 // 369 //go:norace 370 func (x *Nat) add(y *Nat) (c uint) { 371 // Eliminate bounds checks in the loop. 372 size := len(x.limbs) 373 xLimbs := x.limbs[:size] 374 yLimbs := y.limbs[:size] 375 376 for i := 0; i < size; i++ { 377 xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c) 378 } 379 return 380 } 381 382 // sub computes x -= y. It returns the borrow of the subtraction. 383 // 384 // Both operands must have the same announced length. 385 // 386 //go:norace 387 func (x *Nat) sub(y *Nat) (c uint) { 388 // Eliminate bounds checks in the loop. 
389 size := len(x.limbs) 390 xLimbs := x.limbs[:size] 391 yLimbs := y.limbs[:size] 392 393 for i := 0; i < size; i++ { 394 xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c) 395 } 396 return 397 } 398 399 // ShiftRightVarTime sets x = x >> n. 400 // 401 // The announced length of x is unchanged. 402 // 403 //go:norace 404 func (x *Nat) ShiftRightVarTime(n uint) *Nat { 405 // Eliminate bounds checks in the loop. 406 size := len(x.limbs) 407 xLimbs := x.limbs[:size] 408 409 shift := int(n % _W) 410 shiftLimbs := int(n / _W) 411 412 var shiftedLimbs []uint 413 if shiftLimbs < size { 414 shiftedLimbs = xLimbs[shiftLimbs:] 415 } 416 417 for i := range xLimbs { 418 if i >= len(shiftedLimbs) { 419 xLimbs[i] = 0 420 continue 421 } 422 423 xLimbs[i] = shiftedLimbs[i] >> shift 424 if i+1 < len(shiftedLimbs) { 425 xLimbs[i] |= shiftedLimbs[i+1] << (_W - shift) 426 } 427 } 428 429 return x 430 } 431 432 // BitLenVarTime returns the actual size of x in bits. 433 // 434 // The actual size of x (but nothing more) leaks through timing side-channels. 435 // Note that this is ordinarily secret, as opposed to the announced size of x. 436 func (x *Nat) BitLenVarTime() int { 437 // Eliminate bounds checks in the loop. 438 size := len(x.limbs) 439 xLimbs := x.limbs[:size] 440 441 for i := size - 1; i >= 0; i-- { 442 if xLimbs[i] != 0 { 443 return i*_W + bitLen(xLimbs[i]) 444 } 445 } 446 return 0 447 } 448 449 // bitLen is a version of bits.Len that only leaks the bit length of n, but not 450 // its value. bits.Len and bits.LeadingZeros use a lookup table for the 451 // low-order bits on some architectures. 452 func bitLen(n uint) int { 453 len := 0 454 // We assume, here and elsewhere, that comparison to zero is constant time 455 // with respect to different non-zero values. 456 for n != 0 { 457 len++ 458 n >>= 1 459 } 460 return len 461 } 462 463 // Modulus is used for modular arithmetic, precomputing relevant constants. 
464 // 465 // A Modulus can leak the exact number of bits needed to store its value 466 // and is stored without padding. Its actual value is still kept secret. 467 type Modulus struct { 468 // The underlying natural number for this modulus. 469 // 470 // This will be stored without any padding, and shouldn't alias with any 471 // other natural number being used. 472 nat *Nat 473 474 // If m is even, the following fields are not set. 475 odd bool 476 m0inv uint // -nat.limbs[0]⁻¹ mod _W 477 rr *Nat // R*R for montgomeryRepresentation 478 } 479 480 // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). 481 func rr(m *Modulus) *Nat { 482 rr := NewNat().ExpandFor(m) 483 n := uint(len(rr.limbs)) 484 mLen := uint(m.BitLen()) 485 logR := _W * n 486 487 // We start by computing R = 2^(_W * n) mod m. We can get pretty close, to 488 // 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce. 489 rr.limbs[n-1] = 1 << ((mLen - 1) % _W) 490 // Then we double until we reach 2^(_W * n). 491 for i := mLen - 1; i < logR; i++ { 492 rr.Add(rr, m) 493 } 494 495 // Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in 496 // the Montgomery domain, meaning we can use Montgomery multiplication now). 497 // We could do that by doubling _W * n times, or with a square-and-double 498 // chain log2(_W * n) long. Turns out the fastest thing is to start out with 499 // doublings, and switch to square-and-double once the exponent is large 500 // enough to justify the cost of the multiplications. 501 502 // The threshold is selected experimentally as a linear function of n. 503 threshold := n / 4 504 505 // We calculate how many of the most-significant bits of the exponent we can 506 // compute before crossing the threshold, and we do it with doublings. 
507 i := bits.UintSize 508 for logR>>i <= threshold { 509 i-- 510 } 511 for k := uint(0); k < logR>>i; k++ { 512 rr.Add(rr, m) 513 } 514 515 // Then we process the remaining bits of the exponent with a 516 // square-and-double chain. 517 for i > 0 { 518 rr.montgomeryMul(rr, rr, m) 519 i-- 520 if logR>>i&1 != 0 { 521 rr.Add(rr, m) 522 } 523 } 524 525 return rr 526 } 527 528 // minusInverseModW computes -x⁻¹ mod _W with x odd. 529 // 530 // This operation is used to precompute a constant involved in Montgomery 531 // multiplication. 532 func minusInverseModW(x uint) uint { 533 // Every iteration of this loop doubles the least-significant bits of 534 // correct inverse in y. The first three bits are already correct (1⁻¹ = 1, 535 // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough 536 // for 64 bits (and wastes only one iteration for 32 bits). 537 // 538 // See https://crypto.stackexchange.com/a/47496. 539 y := x 540 for i := 0; i < 5; i++ { 541 y = y * (2 - x*y) 542 } 543 return -y 544 } 545 546 // NewModulus creates a new Modulus from a slice of big-endian bytes. The 547 // modulus must be greater than one. 548 // 549 // The number of significant bits and whether the modulus is even is leaked 550 // through timing side-channels. 551 func NewModulus(b []byte) (*Modulus, error) { 552 n := NewNat().resetToBytes(b) 553 return newModulus(n) 554 } 555 556 // NewModulusProduct creates a new Modulus from the product of two numbers 557 // represented as big-endian byte slices. The result must be greater than one. 
558 // 559 //go:norace 560 func NewModulusProduct(a, b []byte) (*Modulus, error) { 561 x := NewNat().resetToBytes(a) 562 y := NewNat().resetToBytes(b) 563 n := NewNat().reset(len(x.limbs) + len(y.limbs)) 564 for i := range y.limbs { 565 n.limbs[i+len(x.limbs)] = addMulVVW(n.limbs[i:i+len(x.limbs)], x.limbs, y.limbs[i]) 566 } 567 return newModulus(n.trim()) 568 } 569 570 func newModulus(n *Nat) (*Modulus, error) { 571 m := &Modulus{nat: n} 572 if m.nat.IsZero() == yes || m.nat.IsOne() == yes { 573 return nil, errors.New("modulus must be > 1") 574 } 575 if m.nat.IsOdd() == 1 { 576 m.odd = true 577 m.m0inv = minusInverseModW(m.nat.limbs[0]) 578 m.rr = rr(m) 579 } 580 return m, nil 581 } 582 583 // Size returns the size of m in bytes. 584 func (m *Modulus) Size() int { 585 return (m.BitLen() + 7) / 8 586 } 587 588 // BitLen returns the size of m in bits. 589 func (m *Modulus) BitLen() int { 590 return m.nat.BitLenVarTime() 591 } 592 593 // Nat returns m as a Nat. 594 func (m *Modulus) Nat() *Nat { 595 // Make a copy so that the caller can't modify m.nat or alias it with 596 // another Nat in a modulus operation. 597 n := NewNat() 598 n.set(m.nat) 599 return n 600 } 601 602 // shiftIn calculates x = x << _W + y mod m. 603 // 604 // This assumes that x is already reduced mod m. 605 // 606 //go:norace 607 func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { 608 d := NewNat().resetFor(m) 609 610 // Eliminate bounds checks in the loop. 611 size := len(m.nat.limbs) 612 xLimbs := x.limbs[:size] 613 dLimbs := d.limbs[:size] 614 mLimbs := m.nat.limbs[:size] 615 616 // Each iteration of this loop computes x = 2x + b mod m, where b is a bit 617 // from y. Effectively, it left-shifts x and adds y one bit at a time, 618 // reducing it every time. 619 // 620 // To do the reduction, each iteration computes both 2x + b and 2x + b - m. 621 // The next iteration (and finally the return line) will use either result 622 // based on whether 2x + b overflows m. 
623 needSubtraction := no 624 for i := _W - 1; i >= 0; i-- { 625 carry := (y >> i) & 1 626 var borrow uint 627 mask := ctMask(needSubtraction) 628 for i := 0; i < size; i++ { 629 l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i])) 630 xLimbs[i], carry = bits.Add(l, l, carry) 631 dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow) 632 } 633 // Like in maybeSubtractModulus, we need the subtraction if either it 634 // didn't underflow (meaning 2x + b > m) or if computing 2x + b 635 // overflowed (meaning 2x + b > 2^_W*n > m). 636 needSubtraction = not(choice(borrow)) | choice(carry) 637 } 638 return x.assign(needSubtraction, d) 639 } 640 641 // Mod calculates out = x mod m. 642 // 643 // This works regardless how large the value of x is. 644 // 645 // The output will be resized to the size of m and overwritten. 646 // 647 //go:norace 648 func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { 649 out.resetFor(m) 650 // Working our way from the most significant to the least significant limb, 651 // we can insert each limb at the least significant position, shifting all 652 // previous limbs left by _W. This way each limb will get shifted by the 653 // correct number of bits. We can insert at least N - 1 limbs without 654 // overflowing m. After that, we need to reduce every time we shift. 655 i := len(x.limbs) - 1 656 // For the first N - 1 limbs we can skip the actual shifting and position 657 // them at the shifted position, which starts at min(N - 2, i). 658 start := len(m.nat.limbs) - 2 659 if i < start { 660 start = i 661 } 662 for j := start; j >= 0; j-- { 663 out.limbs[j] = x.limbs[i] 664 i-- 665 } 666 // We shift in the remaining limbs, reducing modulo m each time. 667 for i >= 0 { 668 out.shiftIn(x.limbs[i], m) 669 i-- 670 } 671 return out 672 } 673 674 // ExpandFor ensures x has the right size to work with operations modulo m. 675 // 676 // The announced size of x must be smaller than or equal to that of m. 
677 func (x *Nat) ExpandFor(m *Modulus) *Nat { 678 return x.expand(len(m.nat.limbs)) 679 } 680 681 // resetFor ensures out has the right size to work with operations modulo m. 682 // 683 // out is zeroed and may start at any size. 684 func (out *Nat) resetFor(m *Modulus) *Nat { 685 return out.reset(len(m.nat.limbs)) 686 } 687 688 // maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes. 689 // 690 // It can be used to reduce modulo m a value up to 2m - 1, which is a common 691 // range for results computed by higher level operations. 692 // 693 // always is usually a carry that indicates that the operation that produced x 694 // overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m. 695 // 696 // x and m operands must have the same announced length. 697 // 698 //go:norace 699 func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { 700 t := NewNat().set(x) 701 underflow := t.sub(m.nat) 702 // We keep the result if x - m didn't underflow (meaning x >= m) 703 // or if always was set. 704 keep := not(choice(underflow)) | choice(always) 705 x.assign(keep, t) 706 } 707 708 // Sub computes x = x - y mod m. 709 // 710 // The length of both operands must be the same as the modulus. Both operands 711 // must already be reduced modulo m. 712 // 713 //go:norace 714 func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { 715 underflow := x.sub(y) 716 // If the subtraction underflowed, add m. 717 t := NewNat().set(x) 718 t.add(m.nat) 719 x.assign(choice(underflow), t) 720 return x 721 } 722 723 // SubOne computes x = x - 1 mod m. 724 // 725 // The length of x must be the same as the modulus. 726 func (x *Nat) SubOne(m *Modulus) *Nat { 727 one := NewNat().ExpandFor(m) 728 one.limbs[0] = 1 729 // Sub asks for x to be reduced modulo m, while SubOne doesn't, but when 730 // y = 1, it works, and this is an internal use. 731 return x.Sub(one, m) 732 } 733 734 // Add computes x = x + y mod m. 
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
//
//go:norace
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
	// The carry out of the limb-wise addition forces the reduction, since it
	// means the true sum no longer fits in the limbs of x.
	overflow := x.add(y)
	x.maybeSubtractModulus(choice(overflow), m)
	return x
}

// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs).
//
// Faster Montgomery multiplication replaces standard modular multiplication for
// numbers in this representation.
//
// This assumes that x is already reduced mod m.
func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
	// A Montgomery multiplication (which computes a * b / R) by R * R works out
	// to a multiplication by R, which takes the value into the Montgomery domain.
	return x.montgomeryMul(x, m.rr, m)
}

// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs).
//
// This assumes that x is already reduced mod m.
func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
	// Montgomery multiplying by 1 (which is not in Montgomery representation)
	// works out to dividing by R, which converts x back out of Montgomery
	// representation.
	one := NewNat().ExpandFor(m)
	one.limbs[0] = 1
	return x.montgomeryMul(x, one, m)
}

// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs), also known as a Montgomery multiplication.
//
// All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten.
//
//go:norace
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
	n := len(m.nat.limbs)
	mLimbs := m.nat.limbs[:n]
	aLimbs := a.limbs[:n]
	bLimbs := b.limbs[:n]

	switch n {
	default:
		// Attempt to use a stack-allocated backing array.
		T := make([]uint, 0, preallocLimbs*2)
		if cap(T) < n*2 {
			T = make([]uint, 0, n*2)
		}
		T = T[:n*2]

		// This loop implements Word-by-Word Montgomery Multiplication, as
		// described in Algorithm 4 (Fig. 3) of "Efficient Software
		// Implementations of Modular Exponentiation" by Shay Gueron
		// [https://eprint.iacr.org/2011/239.pdf].
		var c uint
		for i := 0; i < n; i++ {
			_ = T[n+i] // bounds check elimination hint

			// Step 1 (T = a × b) is computed as a large pen-and-paper column
			// multiplication of two numbers with n base-2^_W digits. If we just
			// wanted to produce 2n-wide T, we would do
			//
			//   for i := 0; i < n; i++ {
			//       d := bLimbs[i]
			//       T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
			//   }
			//
			// where d is a digit of the multiplier, T[i:n+i] is the shifted
			// position of the product of that digit, and T[n+i] is the final carry.
			// Note that T[i] isn't modified after processing the i-th digit.
			//
			// Instead of running two loops, one for Step 1 and one for Steps 2–6,
			// the result of Step 1 is computed during the next loop. This is
			// possible because each iteration only uses T[i] in Step 2 and then
			// discards it in Step 6.
			d := bLimbs[i]
			c1 := addMulVVW(T[i:n+i], aLimbs, d)

			// Step 6 is replaced by shifting the virtual window we operate
			// over: T of the algorithm is T[i:] for us. That means that T1 in
			// Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
			Y := T[i] * m.m0inv

			// Step 4 and 5 add Y × m to T, which as mentioned above is stored
			// at T[i:]. The two carries (from a × d and Y × m) are added up in
			// the next word T[n+i], and the carry bit from that addition is
			// brought forward to the next iteration.
			c2 := addMulVVW(T[i:n+i], mLimbs, Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}

		// Finally for Step 7 we copy the final T window into x, and subtract m
		// if necessary (which as explained in maybeSubtractModulus can be the
		// case both if x >= m, or if x overflowed).
		//
		// The paper suggests in Section 4 that we can do an "Almost Montgomery
		// Multiplication" by subtracting only in the overflow case, but the
		// cost is very similar since the constant time subtraction tells us if
		// x >= m as a side effect, and taking care of the broken invariant is
		// highly undesirable (see https://go.dev/issue/13907).
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)

	// The following specialized cases follow the exact same algorithm, but
	// are specialized for the sizes most used in RSA. They call the
	// fixed-size helpers addMulVVW1024, addMulVVW1536 and addMulVVW2048
	// (declared outside this part of the file) on the first element of each
	// operand, with n redeclared as a constant as a hint to the compiler.
	case 1024 / _W:
		const n = 1024 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW1024(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW1024(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)

	case 1536 / _W:
		const n = 1536 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW1536(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW1536(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)

	case 2048 / _W:
		const n = 2048 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW2048(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW2048(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)
	}

	return x
}

// addMulVVW multiplies the multi-word value x by the single-word value y,
// adding the result to the multi-word value z and returning the final carry.
// It can be thought of as one row of a pen-and-paper column multiplication.
//
// z must not be empty, and x must be at least as long as z: the bounds check
// hint below indexes x[len(z)-1].
//
//go:norace
func addMulVVW(z, x []uint, y uint) (carry uint) {
	_ = x[len(z)-1] // bounds check elimination hint
	for i := range z {
		// hi:lo = x[i] * y + z[i] + carry, which cannot overflow two words.
		hi, lo := bits.Mul(x[i], y)
		lo, c := bits.Add(lo, z[i], 0)
		// We use bits.Add with zero to get an add-with-carry instruction that
		// absorbs the carry from the previous bits.Add.
		hi, _ = bits.Add(hi, 0, c)
		lo, c = bits.Add(lo, carry, 0)
		hi, _ = bits.Add(hi, 0, c)
		carry = hi
		z[i] = lo
	}
	return carry
}

// Mul calculates x = x * y mod m.
919 // 920 // The length of both operands must be the same as the modulus. Both operands 921 // must already be reduced modulo m. 922 // 923 //go:norace 924 func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { 925 if m.odd { 926 // A Montgomery multiplication by a value out of the Montgomery domain 927 // takes the result out of Montgomery representation. 928 xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m 929 return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m 930 } 931 932 n := len(m.nat.limbs) 933 xLimbs := x.limbs[:n] 934 yLimbs := y.limbs[:n] 935 936 switch n { 937 default: 938 // Attempt to use a stack-allocated backing array. 939 T := make([]uint, 0, preallocLimbs*2) 940 if cap(T) < n*2 { 941 T = make([]uint, 0, n*2) 942 } 943 T = T[:n*2] 944 945 // T = x * y 946 for i := 0; i < n; i++ { 947 T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i]) 948 } 949 950 // x = T mod m 951 return x.Mod(&Nat{limbs: T}, m) 952 953 // The following specialized cases follow the exact same algorithm, but 954 // optimized for the sizes most used in RSA. See montgomeryMul for details. 955 case 1024 / _W: 956 const n = 1024 / _W // compiler hint 957 T := make([]uint, n*2) 958 for i := 0; i < n; i++ { 959 T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i]) 960 } 961 return x.Mod(&Nat{limbs: T}, m) 962 case 1536 / _W: 963 const n = 1536 / _W // compiler hint 964 T := make([]uint, n*2) 965 for i := 0; i < n; i++ { 966 T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i]) 967 } 968 return x.Mod(&Nat{limbs: T}, m) 969 case 2048 / _W: 970 const n = 2048 / _W // compiler hint 971 T := make([]uint, n*2) 972 for i := 0; i < n; i++ { 973 T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i]) 974 } 975 return x.Mod(&Nat{limbs: T}, m) 976 } 977 } 978 979 // Exp calculates out = x^e mod m. 980 // 981 // The exponent e is represented in big-endian order. The output will be resized 982 // to the size of m and overwritten. x must already be reduced modulo m. 
983 // 984 // m must be odd, or Exp will panic. 985 // 986 //go:norace 987 func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { 988 if !m.odd { 989 panic("bigmod: modulus for Exp must be odd") 990 } 991 992 // We use a 4 bit window. For our RSA workload, 4 bit windows are faster 993 // than 2 bit windows, but use an extra 12 nats worth of scratch space. 994 // Using bit sizes that don't divide 8 are more complex to implement, but 995 // are likely to be more efficient if necessary. 996 997 table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1) 998 // newNat calls are unrolled so they are allocated on the stack. 999 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 1000 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 1001 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 1002 } 1003 table[0].set(x).montgomeryRepresentation(m) 1004 for i := 1; i < len(table); i++ { 1005 table[i].montgomeryMul(table[i-1], table[0], m) 1006 } 1007 1008 out.resetFor(m) 1009 out.limbs[0] = 1 1010 out.montgomeryRepresentation(m) 1011 tmp := NewNat().ExpandFor(m) 1012 for _, b := range e { 1013 for _, j := range []int{4, 0} { 1014 // Square four times. Optimization note: this can be implemented 1015 // more efficiently than with generic Montgomery multiplication. 1016 out.montgomeryMul(out, out, m) 1017 out.montgomeryMul(out, out, m) 1018 out.montgomeryMul(out, out, m) 1019 out.montgomeryMul(out, out, m) 1020 1021 // Select x^k in constant time from the table. 1022 k := uint((b >> j) & 0b1111) 1023 for i := range table { 1024 tmp.assign(ctEq(k, uint(i+1)), table[i]) 1025 } 1026 1027 // Multiply by x^k, discarding the result if k = 0. 1028 tmp.montgomeryMul(out, tmp, m) 1029 out.assign(not(ctEq(k, 0)), tmp) 1030 } 1031 } 1032 1033 return out.montgomeryReduction(m) 1034 } 1035 1036 // ExpShortVarTime calculates out = x^e mod m. 1037 // 1038 // The output will be resized to the size of m and overwritten. x must already 1039 // be reduced modulo m. 
This leaks the exponent through timing side-channels. 1040 // 1041 // m must be odd, or ExpShortVarTime will panic. 1042 func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { 1043 if !m.odd { 1044 panic("bigmod: modulus for ExpShortVarTime must be odd") 1045 } 1046 // For short exponents, precomputing a table and using a window like in Exp 1047 // doesn't pay off. Instead, we do a simple conditional square-and-multiply 1048 // chain, skipping the initial run of zeroes. 1049 xR := NewNat().set(x).montgomeryRepresentation(m) 1050 out.set(xR) 1051 for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ { 1052 out.montgomeryMul(out, out, m) 1053 if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { 1054 out.montgomeryMul(out, xR, m) 1055 } 1056 } 1057 return out.montgomeryReduction(m) 1058 } 1059 1060 // InverseVarTime calculates x = a⁻¹ mod m and returns (x, true) if a is 1061 // invertible. Otherwise, InverseVarTime returns (x, false) and x is not 1062 // modified. 1063 // 1064 // a must be reduced modulo m, but doesn't need to have the same size. The 1065 // output will be resized to the size of m and overwritten. 1066 // 1067 //go:norace 1068 func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) { 1069 u, A, err := extendedGCD(a, m.nat) 1070 if err != nil { 1071 return x, false 1072 } 1073 if u.IsOne() == no { 1074 return x, false 1075 } 1076 return x.set(A), true 1077 } 1078 1079 // GCDVarTime calculates x = GCD(a, b) where at least one of a or b is odd, and 1080 // both are non-zero. If GCDVarTime returns an error, x is not modified. 1081 // 1082 // The output will be resized to the size of the larger of a and b. 1083 func (x *Nat) GCDVarTime(a, b *Nat) (*Nat, error) { 1084 u, _, err := extendedGCD(a, b) 1085 if err != nil { 1086 return nil, err 1087 } 1088 return x.set(u), nil 1089 } 1090 1091 // extendedGCD computes u and A such that a = GCD(a, m) and u = A*a - B*m. 
1092 // 1093 // u will have the size of the larger of a and m, and A will have the size of m. 1094 // 1095 // It is an error if either a or m is zero, or if they are both even. 1096 func extendedGCD(a, m *Nat) (u, A *Nat, err error) { 1097 // This is the extended binary GCD algorithm described in the Handbook of 1098 // Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound 1099 // coefficients and avoid negative numbers. For more details and proof of 1100 // correctness, see https://github.com/mit-plv/fiat-crypto/pull/333/files. 1101 // 1102 // Following the proof linked in the PR above, the changes are: 1103 // 1104 // 1. Negate [B] and [C] so they are positive. The invariant now involves a 1105 // subtraction. 1106 // 2. If step 2 (both [x] and [y] are even) runs, abort immediately. This 1107 // case needs to be handled by the caller. 1108 // 3. Subtract copies of [x] and [y] as needed in step 6 (both [u] and [v] 1109 // are odd) so coefficients stay in bounds. 1110 // 4. Replace the [u >= v] check with [u > v]. This changes the end 1111 // condition to [v = 0] rather than [u = 0]. This saves an extra 1112 // subtraction due to which coefficients were negated. 1113 // 5. Rename x and y to a and n, to capture that one is a modulus. 1114 // 6. Rearrange steps 4 through 6 slightly. Merge the loops in steps 4 and 1115 // 5 into the main loop (step 7's goto), and move step 6 to the start of 1116 // the loop iteration, ensuring each loop iteration halves at least one 1117 // value. 1118 // 1119 // Note this algorithm does not handle either input being zero. 
1120 1121 if a.IsZero() == yes || m.IsZero() == yes { 1122 return nil, nil, errors.New("extendedGCD: a or m is zero") 1123 } 1124 if a.IsOdd() == no && m.IsOdd() == no { 1125 return nil, nil, errors.New("extendedGCD: both a and m are even") 1126 } 1127 1128 size := max(len(a.limbs), len(m.limbs)) 1129 u = NewNat().set(a).expand(size) 1130 v := NewNat().set(m).expand(size) 1131 1132 A = NewNat().reset(len(m.limbs)) 1133 A.limbs[0] = 1 1134 B := NewNat().reset(len(a.limbs)) 1135 C := NewNat().reset(len(m.limbs)) 1136 D := NewNat().reset(len(a.limbs)) 1137 D.limbs[0] = 1 1138 1139 // Before and after each loop iteration, the following hold: 1140 // 1141 // u = A*a - B*m 1142 // v = D*m - C*a 1143 // 0 < u <= a 1144 // 0 <= v <= m 1145 // 0 <= A < m 1146 // 0 <= B <= a 1147 // 0 <= C < m 1148 // 0 <= D <= a 1149 // 1150 // After each loop iteration, u and v only get smaller, and at least one of 1151 // them shrinks by at least a factor of two. 1152 for { 1153 // If both u and v are odd, subtract the smaller from the larger. 1154 // If u = v, we need to subtract from v to hit the modified exit condition. 1155 if u.IsOdd() == yes && v.IsOdd() == yes { 1156 if v.cmpGeq(u) == no { 1157 u.sub(v) 1158 A.Add(C, &Modulus{nat: m}) 1159 B.Add(D, &Modulus{nat: a}) 1160 } else { 1161 v.sub(u) 1162 C.Add(A, &Modulus{nat: m}) 1163 D.Add(B, &Modulus{nat: a}) 1164 } 1165 } 1166 1167 // Exactly one of u and v is now even. 1168 if u.IsOdd() == v.IsOdd() { 1169 panic("bigmod: internal error: u and v are not in the expected state") 1170 } 1171 1172 // Halve the even one and adjust the corresponding coefficient. 
1173 if u.IsOdd() == no { 1174 rshift1(u, 0) 1175 if A.IsOdd() == yes || B.IsOdd() == yes { 1176 rshift1(A, A.add(m)) 1177 rshift1(B, B.add(a)) 1178 } else { 1179 rshift1(A, 0) 1180 rshift1(B, 0) 1181 } 1182 } else { // v.IsOdd() == no 1183 rshift1(v, 0) 1184 if C.IsOdd() == yes || D.IsOdd() == yes { 1185 rshift1(C, C.add(m)) 1186 rshift1(D, D.add(a)) 1187 } else { 1188 rshift1(C, 0) 1189 rshift1(D, 0) 1190 } 1191 } 1192 1193 if v.IsZero() == yes { 1194 return u, A, nil 1195 } 1196 } 1197 } 1198 1199 //go:norace 1200 func rshift1(a *Nat, carry uint) { 1201 size := len(a.limbs) 1202 aLimbs := a.limbs[:size] 1203 1204 for i := range size { 1205 aLimbs[i] >>= 1 1206 if i+1 < size { 1207 aLimbs[i] |= aLimbs[i+1] << (_W - 1) 1208 } else { 1209 aLimbs[i] |= carry << (_W - 1) 1210 } 1211 } 1212 } 1213 1214 // DivShortVarTime calculates x = x / y and returns the remainder. 1215 // 1216 // It panics if y is zero. 1217 // 1218 //go:norace 1219 func (x *Nat) DivShortVarTime(y uint) uint { 1220 if y == 0 { 1221 panic("bigmod: division by zero") 1222 } 1223 1224 var r uint 1225 for i := len(x.limbs) - 1; i >= 0; i-- { 1226 x.limbs[i], r = bits.Div(r, x.limbs[i], y) 1227 } 1228 return r 1229 } 1230