Source file src/math/big/nat_test.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package big
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"math/bits"
    11  	"math/rand/v2"
    12  	"runtime"
    13  	"strings"
    14  	"testing"
    15  )
    16  
// cmpTests lists pairs of nats and the expected value of x.cmp(y):
// -1 if x < y, 0 if x == y, and +1 if x > y. The first entries check that
// nil nats compare equal however nil is spelled (untyped or nat(nil)).
var cmpTests = []struct {
	x, y nat
	r    int
}{
	{nil, nil, 0},
	{nil, nat(nil), 0},
	{nat(nil), nil, 0},
	{nat(nil), nat(nil), 0},
	{nat{0}, nat{0}, 0},
	{nat{0}, nat{1}, -1},
	{nat{1}, nat{0}, 1},
	{nat{1}, nat{1}, 0},
	{nat{0, _M}, nat{1}, 1},
	{nat{1}, nat{0, _M}, -1},
	{nat{1, _M}, nat{0, _M}, 1},
	{nat{0, _M}, nat{1, _M}, -1},
	{nat{16, 571956, 8794, 68}, nat{837, 9146, 1, 754489}, -1},
	{nat{34986, 41, 105, 1957}, nat{56, 7458, 104, 1957}, 1},
}
    36  
    37  func TestCmp(t *testing.T) {
    38  	for i, a := range cmpTests {
    39  		r := a.x.cmp(a.y)
    40  		if r != a.r {
    41  			t.Errorf("#%d got r = %v; want %v", i, r, a.r)
    42  		}
    43  	}
    44  }
    45  
    46  type funNN func(z, x, y nat) nat
    47  type funSNN func(z nat, stk *stack, x, y nat) nat
    48  type argNN struct {
    49  	z, x, y nat
    50  }
    51  
// sumNN lists triples {z, x, y} with z = x + y. TestAdd checks them with
// both operand orders and TestSub rearranges them into subtractions.
var sumNN = []argNN{
	{},
	{nat{1}, nil, nat{1}},
	{nat{1111111110}, nat{123456789}, nat{987654321}},
	{nat{0, 0, 0, 1}, nil, nat{0, 0, 0, 1}},
	{nat{0, 0, 0, 1111111110}, nat{0, 0, 0, 123456789}, nat{0, 0, 0, 987654321}},
	{nat{0, 0, 0, 1}, nat{0, 0, _M}, nat{0, 0, 1}},
}
    60  
// prodNN lists triples {z, x, y} with z = x * y: the generated cases from
// prodTests followed by the hand-written cases in prodNNExtra.
var prodNN = append(prodTests(), prodNNExtra...)
    62  
    63  func permute[E any](x []E) {
    64  	out := make([]E, len(x))
    65  	for i, j := range rand.Perm(len(x)) {
    66  		out[i] = x[j]
    67  	}
    68  	copy(x, out)
    69  }
    70  
    71  // testMul returns the product of x and y using the grade-school algorithm,
    72  // as a reference implementation.
    73  func testMul(x, y nat) nat {
    74  	z := make(nat, len(x)+len(y))
    75  	for i, xi := range x {
    76  		for j, yj := range y {
    77  			hi, lo := bits.Mul(uint(xi), uint(yj))
    78  			k := i + j
    79  			s, c := bits.Add(uint(z[k]), lo, 0)
    80  			z[k] = Word(s)
    81  			k++
    82  			for hi != 0 || c != 0 {
    83  				s, c = bits.Add(uint(z[k]), hi, c)
    84  				hi = 0
    85  				z[k] = Word(s)
    86  				k++
    87  			}
    88  		}
    89  	}
    90  	return z.norm()
    91  }
    92  
    93  func prodTests() []argNN {
    94  	var tests []argNN
    95  	for size := range 10 {
    96  		var x, y nat
    97  		for i := range size {
    98  			x = append(x, Word(i+1))
    99  			y = append(y, Word(i+1+size))
   100  		}
   101  		permute(x)
   102  		permute(y)
   103  		x = x.norm()
   104  		y = y.norm()
   105  		tests = append(tests, argNN{testMul(x, y), x, y})
   106  	}
   107  
   108  	words := []Word{0, 1, 2, 3, 4, ^Word(0), ^Word(1), ^Word(2), ^Word(3)}
   109  	for size := range 10 {
   110  		if size == 0 {
   111  			continue // already tested the only 0-length possibility above
   112  		}
   113  		for range 10 {
   114  			x := make(nat, size)
   115  			y := make(nat, size)
   116  			for i := range size {
   117  				x[i] = words[rand.N(len(words))]
   118  				y[i] = words[rand.N(len(words))]
   119  			}
   120  			x = x.norm()
   121  			y = y.norm()
   122  			tests = append(tests, argNN{testMul(x, y), x, y})
   123  		}
   124  	}
   125  	return tests
   126  }
   127  
   128  var prodNNExtra = []argNN{
   129  	{nil, nat{991}, nil},
   130  	{nat{991}, nat{991}, nat{1}},
   131  	{nat{991 * 991}, nat{991}, nat{991}},
   132  	{nat{8, 22, 15}, nat{2, 3}, nat{4, 5}},
   133  	{nat{10, 27, 52, 45, 28}, nat{2, 3, 4}, nat{5, 6, 7}},
   134  	{nat{12, 32, 61, 100, 94, 76, 45}, nat{2, 3, 4, 5}, nat{6, 7, 8, 9}},
   135  	{nat{12, 32, 61, 100, 94, 76, 45}, nat{2, 3, 4, 5}, nat{6, 7, 8, 9}},
   136  	{nat{14, 37, 70, 114, 170, 166, 148, 115, 66}, nat{2, 3, 4, 5, 6}, nat{7, 8, 9, 10, 11}},
   137  	{nat{991 * 991, 991 * 2, 1}, nat{991, 1}, nat{991, 1}},
   138  	{nat{991 * 991, 991 * 777 * 2, 777 * 777}, nat{991, 777}, nat{991, 777}},
   139  	{nat{0, 0, 991 * 991}, nat{0, 991}, nat{0, 991}},
   140  	{nat{1 * 991, 2 * 991, 3 * 991, 4 * 991}, nat{1, 2, 3, 4}, nat{991}},
   141  	{nat{4, 11, 20, 30, 20, 11, 4}, nat{1, 2, 3, 4}, nat{4, 3, 2, 1}},
   142  	// 3^100 * 3^28 = 3^128
   143  	{
   144  		natFromString("11790184577738583171520872861412518665678211592275841109096961"),
   145  		natFromString("515377520732011331036461129765621272702107522001"),
   146  		natFromString("22876792454961"),
   147  	},
   148  	// z = 111....1 (70000 digits)
   149  	// x = 10^(99*700) + ... + 10^1400 + 10^700 + 1
   150  	// y = 111....1 (700 digits, larger than Karatsuba threshold on 32-bit and 64-bit)
   151  	{
   152  		natFromString(strings.Repeat("1", 70000)),
   153  		natFromString("1" + strings.Repeat(strings.Repeat("0", 699)+"1", 99)),
   154  		natFromString(strings.Repeat("1", 700)),
   155  	},
   156  	// z = 111....1 (20000 digits)
   157  	// x = 10^10000 + 1
   158  	// y = 111....1 (10000 digits)
   159  	{
   160  		natFromString(strings.Repeat("1", 20000)),
   161  		natFromString("1" + strings.Repeat("0", 9999) + "1"),
   162  		natFromString(strings.Repeat("1", 10000)),
   163  	},
   164  }
   165  
   166  func natFromString(s string) nat {
   167  	x, _, _, err := nat(nil).scan(strings.NewReader(s), 0, false)
   168  	if err != nil {
   169  		panic(err)
   170  	}
   171  	return x
   172  }
   173  
   174  func TestSet(t *testing.T) {
   175  	for _, a := range sumNN {
   176  		z := nat(nil).set(a.z)
   177  		if z.cmp(a.z) != 0 {
   178  			t.Errorf("got z = %v; want %v", z, a.z)
   179  		}
   180  	}
   181  }
   182  
   183  func testFunNN(t *testing.T, msg string, f funNN, a argNN) {
   184  	z := f(nil, a.x, a.y)
   185  	if z.cmp(a.z) != 0 {
   186  		t.Errorf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z)
   187  	}
   188  }
   189  
   190  func testFunSNN(t *testing.T, msg string, f funSNN, a argNN) {
   191  	t.Helper()
   192  	stk := getStack()
   193  	defer stk.free()
   194  	z := f(nil, stk, a.x, a.y)
   195  	if z.cmp(a.z) != 0 {
   196  		t.Fatalf("%s%+v\n\tgot z = %v; want %v", msg, a, z, a.z)
   197  	}
   198  }
   199  
   200  func setDuringTest[V any](t *testing.T, p *V, v V) {
   201  	old := *p
   202  	*p = v
   203  	t.Cleanup(func() { *p = old })
   204  }
   205  
   206  func TestAdd(t *testing.T) {
   207  	for _, a := range sumNN {
   208  		testFunNN(t, "add", nat.add, a)
   209  		a.x, a.y = a.y, a.x
   210  		testFunNN(t, "add", nat.add, a)
   211  	}
   212  }
   213  
   214  func TestSub(t *testing.T) {
   215  	for _, a := range sumNN {
   216  		a.x, a.z = a.z, a.x
   217  		testFunNN(t, "sub", nat.sub, a)
   218  
   219  		a.y, a.z = a.z, a.y
   220  		testFunNN(t, "sub", nat.sub, a)
   221  	}
   222  }
   223  
   224  func TestNatMul(t *testing.T) {
   225  	t.Run("Basic", func(t *testing.T) {
   226  		setDuringTest(t, &karatsubaThreshold, 1e9)
   227  		for _, a := range prodNN {
   228  			if len(a.z) >= 100 {
   229  				continue
   230  			}
   231  			testFunSNN(t, "mul", nat.mul, a)
   232  			a.x, a.y = a.y, a.x
   233  			testFunSNN(t, "mul", nat.mul, a)
   234  		}
   235  	})
   236  	t.Run("Karatsuba", func(t *testing.T) {
   237  		setDuringTest(t, &karatsubaThreshold, 2)
   238  		for _, a := range prodNN {
   239  			testFunSNN(t, "mul", nat.mul, a)
   240  			a.x, a.y = a.y, a.x
   241  			testFunSNN(t, "mul", nat.mul, a)
   242  		}
   243  	})
   244  
   245  	t.Run("Mul", func(t *testing.T) {
   246  		for _, a := range prodNN {
   247  			testFunSNN(t, "mul", nat.mul, a)
   248  			a.x, a.y = a.y, a.x
   249  			testFunSNN(t, "mul", nat.mul, a)
   250  		}
   251  	})
   252  }
   253  
   254  func testSqr(t *testing.T, x nat) {
   255  	stk := getStack()
   256  	defer stk.free()
   257  
   258  	got := make(nat, 2*len(x))
   259  	want := make(nat, 2*len(x))
   260  	got = got.sqr(stk, x)
   261  	want = want.mul(stk, x, x)
   262  	if got.cmp(want) != 0 {
   263  		t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
   264  	}
   265  }
   266  
   267  func TestNatSqr(t *testing.T) {
   268  	t.Run("Basic", func(t *testing.T) {
   269  		setDuringTest(t, &basicSqrThreshold, 0)
   270  		setDuringTest(t, &karatsubaSqrThreshold, 1e9)
   271  		for _, a := range prodNN {
   272  			if len(a.z) >= 100 {
   273  				continue
   274  			}
   275  			testSqr(t, a.x)
   276  			testSqr(t, a.y)
   277  			testSqr(t, a.z)
   278  		}
   279  	})
   280  	t.Run("Karatsuba", func(t *testing.T) {
   281  		setDuringTest(t, &basicSqrThreshold, 2)
   282  		setDuringTest(t, &karatsubaSqrThreshold, 2)
   283  		for _, a := range prodNN {
   284  			testSqr(t, a.x)
   285  			testSqr(t, a.y)
   286  			testSqr(t, a.z)
   287  		}
   288  	})
   289  	t.Run("Sqr", func(t *testing.T) {
   290  		for _, a := range prodNN {
   291  			testSqr(t, a.x)
   292  			testSqr(t, a.y)
   293  			testSqr(t, a.z)
   294  		}
   295  	})
   296  }
   297  
// mulRangesN lists ranges [a, b] with the expected decimal product
// a*(a+1)*...*b. An empty range (a > b) has product 1, and a range that
// includes 0 has product 0. The last entries use values near
// math.MaxUint64 to check that the range bounds do not overflow.
var mulRangesN = []struct {
	a, b uint64
	prod string
}{
	{0, 0, "0"},
	{1, 1, "1"},
	{1, 2, "2"},
	{1, 3, "6"},
	{10, 10, "10"},
	{0, 100, "0"},
	{0, 1e9, "0"},
	{1, 0, "1"},                    // empty range
	{100, 1, "1"},                  // empty range
	{1, 10, "3628800"},             // 10!
	{1, 20, "2432902008176640000"}, // 20!
	{1, 100,
		"933262154439441526816992388562667004907159682643816214685929" +
			"638952175999932299156089414639761565182862536979208272237582" +
			"51185210916864000000000000000000000000", // 100!
	},
	{math.MaxUint64 - 0, math.MaxUint64, "18446744073709551615"},
	{math.MaxUint64 - 1, math.MaxUint64, "340282366920938463408034375210639556610"},
	{math.MaxUint64 - 2, math.MaxUint64, "6277101735386680761794095221682035635525021984684230311930"},
	{math.MaxUint64 - 3, math.MaxUint64, "115792089237316195360799967654821100226821973275796746098729803619699194331160"},
}
   323  
   324  func TestMulRangeN(t *testing.T) {
   325  	stk := getStack()
   326  	defer stk.free()
   327  
   328  	for i, r := range mulRangesN {
   329  		prod := string(nat(nil).mulRange(stk, r.a, r.b).utoa(10))
   330  		if prod != r.prod {
   331  			t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
   332  		}
   333  	}
   334  }
   335  
   336  // allocBytes returns the number of bytes allocated by invoking f.
   337  func allocBytes(f func()) uint64 {
   338  	var stats runtime.MemStats
   339  	runtime.ReadMemStats(&stats)
   340  	t := stats.TotalAlloc
   341  	f()
   342  	runtime.ReadMemStats(&stats)
   343  	return stats.TotalAlloc - t
   344  }
   345  
   346  // TestMulUnbalanced tests that multiplying numbers of different lengths
   347  // does not cause deep recursion and in turn allocate too much memory.
   348  // Test case for issue 3807.
   349  func TestMulUnbalanced(t *testing.T) {
   350  	stk := getStack()
   351  	defer stk.free()
   352  
   353  	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
   354  	x := rndNat(50000)
   355  	y := rndNat(40)
   356  	allocSize := allocBytes(func() {
   357  		nat(nil).mul(stk, x, y)
   358  	})
   359  	inputSize := uint64(len(x)+len(y)) * _S
   360  	if ratio := allocSize / uint64(inputSize); ratio > 10 {
   361  		t.Errorf("multiplication uses too much memory (%d > %d times the size of inputs)", allocSize, ratio)
   362  	}
   363  }
   364  
   365  // rndNat returns a random nat value >= 0 of (usually) n words in length.
   366  // In extremely unlikely cases it may be smaller than n words if the top-
   367  // most words are 0.
   368  func rndNat(n int) nat {
   369  	return nat(rndV(n)).norm()
   370  }
   371  
   372  // rndNat1 is like rndNat but the result is guaranteed to be > 0.
   373  func rndNat1(n int) nat {
   374  	x := nat(rndV(n)).norm()
   375  	if len(x) == 0 {
   376  		x.setWord(1)
   377  	}
   378  	return x
   379  }
   380  
   381  func benchmarkNatMul(b *testing.B, nwords int) {
   382  	x := rndNat(nwords)
   383  	y := rndNat(nwords)
   384  	var z nat
   385  	b.ResetTimer()
   386  	b.ReportAllocs()
   387  	for i := 0; i < b.N; i++ {
   388  		z.mul(nil, x, y)
   389  	}
   390  }
   391  
// mulBenchSizes are the operand lengths, in words, used by BenchmarkNatMul.
var mulBenchSizes = []int{10, 100, 1000, 10000, 100000}
   393  
   394  func BenchmarkNatMul(b *testing.B) {
   395  	for _, n := range mulBenchSizes {
   396  		if isRaceBuilder && n > 1e3 {
   397  			continue
   398  		}
   399  		b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
   400  			benchmarkNatMul(b, n)
   401  		})
   402  	}
   403  }
   404  
   405  func TestNLZ(t *testing.T) {
   406  	var x Word = _B >> 1
   407  	for i := 0; i <= _W; i++ {
   408  		if int(nlz(x)) != i {
   409  			t.Errorf("failed at %x: got %d want %d", x, nlz(x), i)
   410  		}
   411  		x >>= 1
   412  	}
   413  }
   414  
// shiftTest describes one shift case: shifting in by shift bits is
// expected to produce out.
type shiftTest struct {
	in    nat
	shift uint
	out   nat
}
   420  
   421  var leftShiftTests = []shiftTest{
   422  	{nil, 0, nil},
   423  	{nil, 1, nil},
   424  	{natOne, 0, natOne},
   425  	{natOne, 1, natTwo},
   426  	{nat{1 << (_W - 1)}, 1, nat{0}},
   427  	{nat{1 << (_W - 1), 0}, 1, nat{0, 1}},
   428  }
   429  
   430  func TestShiftLeft(t *testing.T) {
   431  	for i, test := range leftShiftTests {
   432  		var z nat
   433  		z = z.shl(test.in, test.shift)
   434  		for j, d := range test.out {
   435  			if j >= len(z) || z[j] != d {
   436  				t.Errorf("#%d: got: %v want: %v", i, z, test.out)
   437  				break
   438  			}
   439  		}
   440  	}
   441  }
   442  
// rightShiftTests lists inputs, shift amounts, and the expected normalized
// results of nat.shr. A nil out means the result is zero.
var rightShiftTests = []shiftTest{
	{nil, 0, nil},
	{nil, 1, nil},
	{natOne, 0, natOne},
	{natOne, 1, nil},
	{natTwo, 1, natOne},
	{nat{0, 1}, 1, nat{1 << (_W - 1)}},
	{nat{2, 1, 1}, 1, nat{1<<(_W-1) + 1, 1 << (_W - 1)}},
}
   452  
   453  func TestShiftRight(t *testing.T) {
   454  	for i, test := range rightShiftTests {
   455  		var z nat
   456  		z = z.shr(test.in, test.shift)
   457  		for j, d := range test.out {
   458  			if j >= len(z) || z[j] != d {
   459  				t.Errorf("#%d: got: %v want: %v", i, z, test.out)
   460  				break
   461  			}
   462  		}
   463  	}
   464  }
   465  
   466  func BenchmarkZeroShifts(b *testing.B) {
   467  	x := rndNat(800)
   468  
   469  	b.Run("Shl", func(b *testing.B) {
   470  		for i := 0; i < b.N; i++ {
   471  			var z nat
   472  			z.shl(x, 0)
   473  		}
   474  	})
   475  	b.Run("ShlSame", func(b *testing.B) {
   476  		for i := 0; i < b.N; i++ {
   477  			x.shl(x, 0)
   478  		}
   479  	})
   480  
   481  	b.Run("Shr", func(b *testing.B) {
   482  		for i := 0; i < b.N; i++ {
   483  			var z nat
   484  			z.shr(x, 0)
   485  		}
   486  	})
   487  	b.Run("ShrSame", func(b *testing.B) {
   488  		for i := 0; i < b.N; i++ {
   489  			x.shr(x, 0)
   490  		}
   491  	})
   492  }
   493  
   494  type modWTest struct {
   495  	in       string
   496  	dividend string
   497  	out      string
   498  }
   499  
   500  var modWTests32 = []modWTest{
   501  	{"23492635982634928349238759823742", "252341", "220170"},
   502  }
   503  
   504  var modWTests64 = []modWTest{
   505  	{"6527895462947293856291561095690465243862946", "524326975699234", "375066989628668"},
   506  }
   507  
   508  func runModWTests(t *testing.T, tests []modWTest) {
   509  	for i, test := range tests {
   510  		in, _ := new(Int).SetString(test.in, 10)
   511  		d, _ := new(Int).SetString(test.dividend, 10)
   512  		out, _ := new(Int).SetString(test.out, 10)
   513  
   514  		r := in.abs.modW(d.abs[0])
   515  		if r != out.abs[0] {
   516  			t.Errorf("#%d failed: got %d want %s", i, r, out)
   517  		}
   518  	}
   519  }
   520  
   521  func TestModW(t *testing.T) {
   522  	if _W >= 32 {
   523  		runModWTests(t, modWTests32)
   524  	}
   525  	if _W >= 64 {
   526  		runModWTests(t, modWTests64)
   527  	}
   528  }
   529  
// montgomeryTests lists inputs for nat.montgomery:
//   - x, y and the modulus m, written in hex;
//   - k0 = (-m)^-1 mod 2^_W, compared after truncation to a Word;
//   - the expected result x*y*R^-1 mod m, with R = 2^(_W*len(m)), for
//     32-bit words (out32) and 64-bit words (out64).
// TestMontgomery re-derives out and k0 to validate the table itself.
var montgomeryTests = []struct {
	x, y, m      string
	k0           uint64
	out32, out64 string
}{
	{
		"0xffffffffffffffffffffffffffffffffffffffffffffffffe",
		"0xffffffffffffffffffffffffffffffffffffffffffffffffe",
		"0xfffffffffffffffffffffffffffffffffffffffffffffffff",
		1,
		"0x1000000000000000000000000000000000000000000",
		"0x10000000000000000000000000000000000",
	},
	{
		"0x000000000ffffff5",
		"0x000000000ffffff0",
		"0x0000000010000001",
		0xff0000000fffffff,
		"0x000000000bfffff4",
		"0x0000000003400001",
	},
	{
		"0x0000000080000000",
		"0x00000000ffffffff",
		"0x1000000000000001",
		0xfffffffffffffff,
		"0x0800000008000001",
		"0x0800000008000001",
	},
	{
		"0x0000000080000000",
		"0x0000000080000000",
		"0xffffffff00000001",
		0xfffffffeffffffff,
		"0xbfffffff40000001",
		"0xbfffffff40000001",
	},
	{
		"0x0000000080000000",
		"0x0000000080000000",
		"0x00ffffff00000001",
		0xfffffeffffffff,
		"0xbfffff40000001",
		"0xbfffff40000001",
	},
	{
		"0x0000000080000000",
		"0x0000000080000000",
		"0x0000ffff00000001",
		0xfffeffffffff,
		"0xbfff40000001",
		"0xbfff40000001",
	},
	{
		"0x3321ffffffffffffffffffffffffffff00000000000022222623333333332bbbb888c0",
		"0x3321ffffffffffffffffffffffffffff00000000000022222623333333332bbbb888c0",
		"0x33377fffffffffffffffffffffffffffffffffffffffffffff0000000000022222eee1",
		0xdecc8f1249812adf,
		"0x04eb0e11d72329dc0915f86784820fc403275bf2f6620a20e0dd344c5cd0875e50deb5",
		"0x0d7144739a7d8e11d72329dc0915f86784820fc403275bf2f61ed96f35dd34dbb3d6a0",
	},
	{
		"0x10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000ffffffffffffffffffffffffffffffff00000000000022222223333333333444444444",
		"0x10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000ffffffffffffffffffffffffffffffff999999999999999aaabbbbbbbbcccccccccccc",
		"0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff33377fffffffffffffffffffffffffffffffffffffffffffff0000000000022222eee1",
		0xdecc8f1249812adf,
		"0x5c0d52f451aec609b15da8e5e5626c4eaa88723bdeac9d25ca9b961269400410ca208a16af9c2fb07d7a11c7772cba02c22f9711078d51a3797eb18e691295293284d988e349fa6deba46b25a4ecd9f715",
		"0x92fcad4b5c0d52f451aec609b15da8e5e5626c4eaa88723bdeac9d25ca9b961269400410ca208a16af9c2fb07d799c32fe2f3cc5422f9711078d51a3797eb18e691295293284d8f5e69caf6decddfe1df6",
	},
}
   600  
   601  func TestMontgomery(t *testing.T) {
   602  	stk := getStack()
   603  	defer stk.free()
   604  
   605  	one := NewInt(1)
   606  	_B := new(Int).Lsh(one, _W)
   607  	for i, test := range montgomeryTests {
   608  		x := natFromString(test.x)
   609  		y := natFromString(test.y)
   610  		m := natFromString(test.m)
   611  		for len(x) < len(m) {
   612  			x = append(x, 0)
   613  		}
   614  		for len(y) < len(m) {
   615  			y = append(y, 0)
   616  		}
   617  
   618  		if x.cmp(m) > 0 {
   619  			_, r := nat(nil).div(stk, nil, x, m)
   620  			t.Errorf("#%d: x > m (0x%s > 0x%s; use 0x%s)", i, x.utoa(16), m.utoa(16), r.utoa(16))
   621  		}
   622  		if y.cmp(m) > 0 {
   623  			_, r := nat(nil).div(stk, nil, x, m)
   624  			t.Errorf("#%d: y > m (0x%s > 0x%s; use 0x%s)", i, y.utoa(16), m.utoa(16), r.utoa(16))
   625  		}
   626  
   627  		var out nat
   628  		if _W == 32 {
   629  			out = natFromString(test.out32)
   630  		} else {
   631  			out = natFromString(test.out64)
   632  		}
   633  
   634  		// t.Logf("#%d: len=%d\n", i, len(m))
   635  
   636  		// check output in table
   637  		xi := &Int{abs: x}
   638  		yi := &Int{abs: y}
   639  		mi := &Int{abs: m}
   640  		p := new(Int).Mod(new(Int).Mul(xi, new(Int).Mul(yi, new(Int).ModInverse(new(Int).Lsh(one, uint(len(m))*_W), mi))), mi)
   641  		if out.cmp(p.abs.norm()) != 0 {
   642  			t.Errorf("#%d: out in table=0x%s, computed=0x%s", i, out.utoa(16), p.abs.norm().utoa(16))
   643  		}
   644  
   645  		// check k0 in table
   646  		k := new(Int).Mod(&Int{abs: m}, _B)
   647  		k = new(Int).Sub(_B, k)
   648  		k = new(Int).Mod(k, _B)
   649  		k0 := Word(new(Int).ModInverse(k, _B).Uint64())
   650  		if k0 != Word(test.k0) {
   651  			t.Errorf("#%d: k0 in table=%#x, computed=%#x\n", i, test.k0, k0)
   652  		}
   653  
   654  		// check montgomery with correct k0 produces correct output
   655  		z := nat(nil).montgomery(x, y, m, k0, len(m))
   656  		z = z.norm()
   657  		if z.cmp(out) != 0 {
   658  			t.Errorf("#%d: got 0x%s want 0x%s", i, z.utoa(16), out.utoa(16))
   659  		}
   660  	}
   661  }
   662  
// expNNTests lists x, y, m (decimal or 0x-prefixed hex) and the expected
// result of expNN. An empty m is passed to expNN as a nil modulus; the
// table's entry for that case expects the plain power x**y.
var expNNTests = []struct {
	x, y, m string
	out     string
}{
	{"0", "0", "0", "1"},
	{"0", "0", "1", "0"},
	{"1", "1", "1", "0"},
	{"2", "1", "1", "0"},
	{"2", "2", "1", "0"},
	{"10", "100000000000", "1", "0"},
	{"0x8000000000000000", "2", "", "0x40000000000000000000000000000000"},
	{"0x8000000000000000", "2", "6719", "4944"},
	{"0x8000000000000000", "3", "6719", "5447"},
	{"0x8000000000000000", "1000", "6719", "1603"},
	{"0x8000000000000000", "1000000", "6719", "3199"},
	{
		"2938462938472983472983659726349017249287491026512746239764525612965293865296239471239874193284792387498274256129746192347",
		"298472983472983471903246121093472394872319615612417471234712061",
		"29834729834729834729347290846729561262544958723956495615629569234729836259263598127342374289365912465901365498236492183464",
		"23537740700184054162508175125554701713153216681790245129157191391322321508055833908509185839069455749219131480588829346291",
	},
	{
		"11521922904531591643048817447554701904414021819823889996244743037378330903763518501116638828335352811871131385129455853417360623007349090150042001944696604737499160174391019030572483602867266711107136838523916077674888297896995042968746762200926853379",
		"426343618817810911523",
		"444747819283133684179",
		"42",
	},
	{"375", "249", "388", "175"},
	{"375", "18446744073709551801", "388", "175"},
	{"0", "0x40000000000000", "0x200", "0"},
	{"0xeffffff900002f00", "0x40000000000000", "0x200", "0"},
	{"5", "1435700818", "72", "49"},
	{"0xffff", "0x300030003000300030003000300030003000302a3000300030003000300030003000300030003000300030003000300030003030623066307f3030783062303430383064303630343036", "0x300000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", "0xa3f94c08b0b90e87af637cacc9383f7ea032352b8961fc036a52b659b6c9b33491b335ffd74c927f64ddd62cfca0001"},
}
   697  
   698  func TestExpNN(t *testing.T) {
   699  	stk := getStack()
   700  	defer stk.free()
   701  
   702  	for i, test := range expNNTests {
   703  		x := natFromString(test.x)
   704  		y := natFromString(test.y)
   705  		out := natFromString(test.out)
   706  
   707  		var m nat
   708  		if len(test.m) > 0 {
   709  			m = natFromString(test.m)
   710  		}
   711  
   712  		z := nat(nil).expNN(stk, x, y, m, false)
   713  		if z.cmp(out) != 0 {
   714  			t.Errorf("#%d got %s want %s", i, z.utoa(10), out.utoa(10))
   715  		}
   716  	}
   717  }
   718  
   719  func FuzzExpMont(f *testing.F) {
   720  	f.Fuzz(func(t *testing.T, x1, x2, x3, y1, y2, y3, m1, m2, m3 uint) {
   721  		if m1 == 0 && m2 == 0 && m3 == 0 {
   722  			return
   723  		}
   724  		x := new(Int).SetBits([]Word{Word(x1), Word(x2), Word(x3)})
   725  		y := new(Int).SetBits([]Word{Word(y1), Word(y2), Word(y3)})
   726  		m := new(Int).SetBits([]Word{Word(m1), Word(m2), Word(m3)})
   727  		out := new(Int).Exp(x, y, m)
   728  		want := new(Int).expSlow(x, y, m)
   729  		if out.Cmp(want) != 0 {
   730  			t.Errorf("x = %#x\ny=%#x\nz=%#x\nout=%#x\nwant=%#x\ndc: 16o 16i %X %X %X |p", x, y, m, out, want, x, y, m)
   731  		}
   732  	})
   733  }
   734  
   735  func BenchmarkExp3Power(b *testing.B) {
   736  	stk := getStack()
   737  	defer stk.free()
   738  
   739  	const x = 3
   740  	for _, y := range []Word{
   741  		0x10, 0x40, 0x100, 0x400, 0x1000, 0x4000, 0x10000, 0x40000, 0x100000, 0x400000,
   742  	} {
   743  		b.Run(fmt.Sprintf("%#x", y), func(b *testing.B) {
   744  			var z nat
   745  			for i := 0; i < b.N; i++ {
   746  				z.expWW(stk, x, y)
   747  			}
   748  		})
   749  	}
   750  }
   751  
   752  func fibo(n int) nat {
   753  	switch n {
   754  	case 0:
   755  		return nil
   756  	case 1:
   757  		return nat{1}
   758  	}
   759  	f0 := fibo(0)
   760  	f1 := fibo(1)
   761  	var f2 nat
   762  	for i := 1; i < n; i++ {
   763  		f2 = f2.add(f0, f1)
   764  		f0, f1, f2 = f1, f2, f0
   765  	}
   766  	return f1
   767  }
   768  
// fiboNums[i] is the decimal value of the Fibonacci number F(10*i).
var fiboNums = []string{
	"0",
	"55",
	"6765",
	"832040",
	"102334155",
	"12586269025",
	"1548008755920",
	"190392490709135",
	"23416728348467685",
	"2880067194370816120",
	"354224848179261915075",
}
   782  
   783  func TestFibo(t *testing.T) {
   784  	for i, want := range fiboNums {
   785  		n := i * 10
   786  		got := string(fibo(n).utoa(10))
   787  		if got != want {
   788  			t.Errorf("fibo(%d) failed: got %s want %s", n, got, want)
   789  		}
   790  	}
   791  }
   792  
   793  func BenchmarkFibo(b *testing.B) {
   794  	for i := 0; i < b.N; i++ {
   795  		fibo(1e0)
   796  		fibo(1e1)
   797  		fibo(1e2)
   798  		fibo(1e3)
   799  		fibo(1e4)
   800  		fibo(1e5)
   801  	}
   802  }
   803  
// bitTests lists a value x (decimal or hex), a bit index i, and the
// expected value of x.bit(i). The cases include indices just below, at,
// and above word and multi-word boundaries.
var bitTests = []struct {
	x    string
	i    uint
	want uint
}{
	{"0", 0, 0},
	{"0", 1, 0},
	{"0", 1000, 0},

	{"0x1", 0, 1},
	{"0x10", 0, 0},
	{"0x10", 3, 0},
	{"0x10", 4, 1},
	{"0x10", 5, 0},

	{"0x8000000000000000", 62, 0},
	{"0x8000000000000000", 63, 1},
	{"0x8000000000000000", 64, 0},

	{"0x3" + strings.Repeat("0", 32), 127, 0},
	{"0x3" + strings.Repeat("0", 32), 128, 1},
	{"0x3" + strings.Repeat("0", 32), 129, 1},
	{"0x3" + strings.Repeat("0", 32), 130, 0},
}
   828  
   829  func TestBit(t *testing.T) {
   830  	for i, test := range bitTests {
   831  		x := natFromString(test.x)
   832  		if got := x.bit(test.i); got != test.want {
   833  			t.Errorf("#%d: %s.bit(%d) = %v; want %v", i, test.x, test.i, got, test.want)
   834  		}
   835  	}
   836  }
   837  
// stickyTests lists a value x, a bit count i, and the expected x.sticky(i).
// The entries match "1 if any of the i lowest bits of x is set, else 0";
// for example, 0x1350's lowest set bit is bit 4, so sticky(4) = 0 and
// sticky(5) = 1.
var stickyTests = []struct {
	x    string
	i    uint
	want uint
}{
	{"0", 0, 0},
	{"0", 1, 0},
	{"0", 1000, 0},

	{"0x1", 0, 0},
	{"0x1", 1, 1},

	{"0x1350", 0, 0},
	{"0x1350", 4, 0},
	{"0x1350", 5, 1},

	{"0x8000000000000000", 63, 0},
	{"0x8000000000000000", 64, 1},

	{"0x1" + strings.Repeat("0", 100), 400, 0},
	{"0x1" + strings.Repeat("0", 100), 401, 1},
}
   860  
   861  func TestSticky(t *testing.T) {
   862  	for i, test := range stickyTests {
   863  		x := natFromString(test.x)
   864  		if got := x.sticky(test.i); got != test.want {
   865  			t.Errorf("#%d: %s.sticky(%d) = %v; want %v", i, test.x, test.i, got, test.want)
   866  		}
   867  		if test.want == 1 {
   868  			// all subsequent i's should also return 1
   869  			for d := uint(1); d <= 3; d++ {
   870  				if got := x.sticky(test.i + d); got != 1 {
   871  					t.Errorf("#%d: %s.sticky(%d) = %v; want %v", i, test.x, test.i+d, got, 1)
   872  				}
   873  			}
   874  		}
   875  	}
   876  }
   877  
   878  func benchmarkNatSqr(b *testing.B, nwords int) {
   879  	x := rndNat(nwords)
   880  	var z nat
   881  	b.ResetTimer()
   882  	b.ReportAllocs()
   883  	for i := 0; i < b.N; i++ {
   884  		z.sqr(nil, x)
   885  	}
   886  }
   887  
// sqrBenchSizes are the operand lengths, in words, used by BenchmarkNatSqr.
var sqrBenchSizes = []int{
	1, 2, 3, 5, 8, 10, 20, 30, 50, 80,
	100, 200, 300, 500, 800,
	1000, 10000, 100000,
}
   893  
   894  func BenchmarkNatSqr(b *testing.B) {
   895  	for _, n := range sqrBenchSizes {
   896  		if isRaceBuilder && n > 1e3 {
   897  			continue
   898  		}
   899  		b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
   900  			benchmarkNatSqr(b, n)
   901  		})
   902  	}
   903  }
   904  
// subMod2NTests lists decimal x and y, a bit count n, and the expected
// z = (x - y) mod 2^n. A negative difference wraps modulo 2^n.
var subMod2NTests = []struct {
	x string
	y string
	n uint
	z string
}{
	{"1", "2", 0, "0"},
	{"1", "0", 1, "1"},
	{"0", "1", 1, "1"},
	{"3", "5", 3, "6"},
	{"5", "3", 3, "2"},
	// 2^65, 2^66-1, 2^65 - (2^66-1) + 2^67
	{"36893488147419103232", "73786976294838206463", 67, "110680464442257309697"},
	// 2^66-1, 2^65, 2^65-1
	{"73786976294838206463", "36893488147419103232", 67, "36893488147419103231"},
}
   921  
   922  func TestNatSubMod2N(t *testing.T) {
   923  	for _, mode := range []string{"noalias", "aliasX", "aliasY"} {
   924  		t.Run(mode, func(t *testing.T) {
   925  			for _, tt := range subMod2NTests {
   926  				x0 := natFromString(tt.x)
   927  				y0 := natFromString(tt.y)
   928  				want := natFromString(tt.z)
   929  				x := nat(nil).set(x0)
   930  				y := nat(nil).set(y0)
   931  				var z nat
   932  				switch mode {
   933  				case "aliasX":
   934  					z = x
   935  				case "aliasY":
   936  					z = y
   937  				}
   938  				z = z.subMod2N(x, y, tt.n)
   939  				if z.cmp(want) != 0 {
   940  					t.Fatalf("subMod2N(%d, %d, %d) = %d, want %d", x0, y0, tt.n, z, want)
   941  				}
   942  				if mode != "aliasX" && x.cmp(x0) != 0 {
   943  					t.Fatalf("subMod2N(%d, %d, %d) modified x", x0, y0, tt.n)
   944  				}
   945  				if mode != "aliasY" && y.cmp(y0) != 0 {
   946  					t.Fatalf("subMod2N(%d, %d, %d) modified y", x0, y0, tt.n)
   947  				}
   948  			}
   949  		})
   950  	}
   951  }
   952  
   953  func BenchmarkNatSetBytes(b *testing.B) {
   954  	const maxLength = 128
   955  	lengths := []int{
   956  		// No remainder:
   957  		8, 24, maxLength,
   958  		// With remainder:
   959  		7, 23, maxLength - 1,
   960  	}
   961  	n := make(nat, maxLength/_W) // ensure n doesn't need to grow during the test
   962  	buf := make([]byte, maxLength)
   963  	for _, l := range lengths {
   964  		b.Run(fmt.Sprint(l), func(b *testing.B) {
   965  			for i := 0; i < b.N; i++ {
   966  				n.setBytes(buf[:l])
   967  			}
   968  		})
   969  	}
   970  }
   971  
   972  func TestNatDiv(t *testing.T) {
   973  	stk := getStack()
   974  	defer stk.free()
   975  
   976  	sizes := []int{
   977  		1, 2, 5, 8, 15, 25, 40, 65, 100,
   978  		200, 500, 800, 1500, 2500, 4000, 6500, 10000,
   979  	}
   980  	for _, i := range sizes {
   981  		for _, j := range sizes {
   982  			a := rndNat1(i)
   983  			b := rndNat1(j)
   984  			// the test requires b >= 2
   985  			if len(b) == 1 && b[0] == 1 {
   986  				b[0] = 2
   987  			}
   988  			// choose a remainder c < b
   989  			c := rndNat1(len(b))
   990  			if len(c) == len(b) && c[len(c)-1] >= b[len(b)-1] {
   991  				c[len(c)-1] = 0
   992  				c = c.norm()
   993  			}
   994  			// compute x = a*b+c
   995  			x := nat(nil).mul(stk, a, b)
   996  			x = x.add(x, c)
   997  
   998  			var q, r nat
   999  			q, r = q.div(stk, r, x, b)
  1000  			if q.cmp(a) != 0 {
  1001  				t.Fatalf("wrong quotient: got %s; want %s for %s/%s", q.utoa(10), a.utoa(10), x.utoa(10), b.utoa(10))
  1002  			}
  1003  			if r.cmp(c) != 0 {
  1004  				t.Fatalf("wrong remainder: got %s; want %s for %s/%s", r.utoa(10), c.utoa(10), x.utoa(10), b.utoa(10))
  1005  			}
  1006  		}
  1007  	}
  1008  }
  1009  
  1010  // TestIssue37499 triggers the edge case of divBasic where
  1011  // the inaccurate estimate of the first word's quotient
  1012  // happens at the very beginning of the loop.
  1013  func TestIssue37499(t *testing.T) {
  1014  	stk := getStack()
  1015  	defer stk.free()
  1016  
  1017  	// Choose u and v such that v is slightly larger than u >> N.
  1018  	// This tricks divBasic into choosing 1 as the first word
  1019  	// of the quotient. This works in both 32-bit and 64-bit settings.
  1020  	u := natFromString("0x2b6c385a05be027f5c22005b63c42a1165b79ff510e1706b39f8489c1d28e57bb5ba4ef9fd9387a3e344402c0a453381")
  1021  	v := natFromString("0x2b6c385a05be027f5c22005b63c42a1165b79ff510e1706c")
  1022  
  1023  	q := nat(nil).make(8)
  1024  	q.divBasic(stk, u, v)
  1025  	q = q.norm()
  1026  	if s := string(q.utoa(16)); s != "fffffffffffffffffffffffffffffffffffffffffffffffb" {
  1027  		t.Fatalf("incorrect quotient: %s", s)
  1028  	}
  1029  }
  1030  
  1031  // TestIssue42552 triggers an edge case of recursive division
  1032  // where the first division loop is never entered, and correcting
  1033  // the remainder takes exactly two iterations in the final loop.
  1034  func TestIssue42552(t *testing.T) {
  1035  	stk := getStack()
  1036  	defer stk.free()
  1037  
  1038  	u := natFromString("0xc23b166884c3869092a520eceedeced2b00847bd256c9cf3b2c5e2227c15bd5e6ee7ef8a2f49236ad0eedf2c8a3b453cf6e0706f64285c526b372c4b1321245519d430540804a50b7ca8b6f1b34a2ec05cdbc24de7599af112d3e3c8db347e8799fe70f16e43c6566ba3aeb169463a3ecc486172deb2d9b80a3699c776e44fef20036bd946f1b4d054dd88a2c1aeb986199b0b2b7e58c42288824b74934d112fe1fc06e06b4d99fe1c5e725946b23210521e209cd507cce90b5f39a523f27e861f9e232aee50c3f585208b4573dcc0b897b6177f2ba20254fd5c50a033e849dee1b3a93bd2dc44ba8ca836cab2c2ae50e50b126284524fa0187af28628ff0face68d87709200329db1392852c8b8963fbe3d05fb1efe19f0ed5ca9fadc2f96f82187c24bb2512b2e85a66333a7e176605695211e1c8e0b9b9e82813e50654964945b1e1e66a90840396c7d10e23e47f364d2d3f660fa54598e18d1ca2ea4fe4f35a40a11f69f201c80b48eaee3e2e9b0eda63decf92bec08a70f731587d4ed0f218d5929285c8b2ccbc497e20db42de73885191fa453350335990184d8df805072f958d5354debda38f5421effaaafd6cb9b721ace74be0892d77679f62a4a126697cd35797f6858193da4ba1770c06aea2e5c59ec04b8ea26749e61b72ecdde403f3bc7e5e546cd799578cc939fa676dfd5e648576d4a06cbadb028adc2c0b461f145b2321f42e5e0f3b4fb898ecd461df07a6f5154067787bf74b5cc5c03704a1ce47494961931f0263b0aac32505102595957531a2de69dd71aac51f8a49902f81f21283dbe8e21e01e5d82517868826f86acf338d935aa6b4d5a25c8d540389b277dd9d64569d68baf0f71bd03dba45b92a7fc052601d1bd011a2fc6790a23f97c6fa5caeea040ab86841f268d39ce4f7caf01069df78bba098e04366492f0c2ac24f1bf16828752765fa523c9a4d42b71109d123e6be8c7b1ab3ccf8ea03404075fe1a9596f1bba1d267f9a7879ceece514818316c9c0583469d2367831fc42b517ea028a28df7c18d783d16ea2436cee2b15d52db68b5dfdee6b4d26f0905f9b030c911a04d078923a4136afea96eed6874462a482917353264cc9bee298f167ac65a6db4e4eda88044b39cc0b33183843eaa946564a00c3a0ab661f2c915e70bf0bb65bfbb6fa2eea20aed16bf2c1a1d00ec55fb4ff2f76b8e462ea70c19efa579c9ee78194b86708fdae66a9ce6e2cf3d366037798cfb50277ba6d2fd4866361022fd788ab7735b40b8b61d55e32243e06719e53992e9ac16c9c4b6e6933635c3c47c8f7e73e17dd54d0dd8aeba5d76de46894e7b3f9d3ec25ad78ee82297ba69905ea0fa094b8667faa2b8885e2187b3da80268a
a1164761d7b0d6de206b676777348152b8ae1d4afed753bc63c739a5ca8ce7afb2b241a226bd9e502baba391b5b13f5054f070b65a9cf3a67063bfaa803ba390732cd03888f664023f888741d04d564e0b5674b0a183ace81452001b3fbb4214c77d42ca75376742c471e58f67307726d56a1032bd236610cbcbcd03d0d7a452900136897dc55bb3ce959d10d4e6a10fb635006bd8c41cd9ded2d3dfdd8f2e229590324a7370cb2124210b2330f4c56155caa09a2564932ceded8d92c79664dcdeb87faad7d3da006cc2ea267ee3df41e9677789cc5a8cc3b83add6491561b3047919e0648b1b2e97d7ad6f6c2aa80cab8e9ae10e1f75b1fdd0246151af709d259a6a0ed0b26bd711024965ecad7c41387de45443defce53f66612948694a6032279131c257119ed876a8e805dfb49576ef5c563574115ee87050d92d191bc761ef51d966918e2ef925639400069e3959d8fe19f36136e947ff430bf74e71da0aa5923b00000000")
  1039  	v := natFromString("0x838332321d443a3d30373d47301d47073847473a383d3030f25b3d3d3e00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002e00000000000000000041603038331c3d32f5303441e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e01c0a5459bfc7b9be9fcbb9d2383840464319434707303030f43a32f53034411c0a5459413820878787878787878787878787878787878787878787878787878787878787878787870630303a3a30334036605b923a6101f83638413943413960204337602043323801526040523241846038414143015238604060328452413841413638523c0240384141364036605b923a6101f83638413943413960204334602043323801526040523241846038414143015238604060328452413841413638523c02403841413638433030f25a8b83838383838383838383838383838383837d838383ffffffffffffffff838383838383838383000000000000000000030000007d26e27c7c8b83838383838383838383838383838383837d838383ffffffffffffffff83838383838383838383838383838383838383838383435960f535073030f3343200000000000000011881301938343030fa398383300000002300000000000000000000f11af4600c845252904141364138383c60406032414443095238010241414303364443434132305b595a15434160b042385341ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff47476043410536613603593a6005411c437405fcfcfcfcfcfcfc0000000000005a3b075815054359000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
  1040  	q := nat(nil).make(16)
  1041  	q.div(stk, q, u, v)
  1042  }
  1043  

View as plain text