Source file src/math/big/calibrate_test.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // TestCalibrate determines appropriate thresholds for when to use
     6  // different calculation algorithms. To run it, use:
     7  //
     8  //	go test -run=Calibrate -calibrate >cal.log
     9  //
    10  // Calibration data is printed in CSV format, along with the normal test output.
    11  // See calibrate.md for more details about using the output.
    12  
    13  package big
    14  
    15  import (
    16  	"flag"
    17  	"fmt"
    18  	"internal/sysinfo"
    19  	"math"
    20  	"runtime"
    21  	"slices"
    22  	"strings"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  )
    27  
// calibrate is the -calibrate command-line flag (default false).
// TestCalibrate returns immediately unless it is set.
var calibrate = flag.Bool("calibrate", false, "run calibration test")
// calibrateOnce is not referenced by any code visible in this file.
// NOTE(review): confirm whether another file in the package still uses it.
var calibrateOnce sync.Once
    30  
    31  func TestCalibrate(t *testing.T) {
    32  	if !*calibrate {
    33  		return
    34  	}
    35  
    36  	t.Run("KaratsubaMul", computeKaratsubaThreshold)
    37  	t.Run("BasicSqr", computeBasicSqrThreshold)
    38  	t.Run("KaratsubaSqr", computeKaratsubaSqrThreshold)
    39  	t.Run("DivRecursive", computeDivRecursiveThreshold)
    40  }
    41  
    42  func computeKaratsubaThreshold(t *testing.T) {
    43  	set := func(n int) { karatsubaThreshold = n }
    44  	computeThreshold(t, "karatsuba", set, 0, 4, 200, benchMul, 200, 8, 400)
    45  }
    46  
    47  func benchMul(size int) func() {
    48  	x := rndNat(size)
    49  	y := rndNat(size)
    50  	var z nat
    51  	return func() {
    52  		z.mul(nil, x, y)
    53  	}
    54  }
    55  
    56  func computeBasicSqrThreshold(t *testing.T) {
    57  	setDuringTest(t, &karatsubaSqrThreshold, 1e9)
    58  	set := func(n int) { basicSqrThreshold = n }
    59  	computeThreshold(t, "basicSqr", set, 2, 1, 40, benchBasicSqr, 1, 1, 40)
    60  }
    61  
    62  func benchBasicSqr(size int) func() {
    63  	x := rndNat(size)
    64  	var z nat
    65  	return func() {
    66  		// Run 100 squarings because 1 is too fast at the small sizes we consider.
    67  		// Some systems don't even have precise enough clocks to measure it accurately.
    68  		for range 100 {
    69  			z.sqr(nil, x)
    70  		}
    71  	}
    72  }
    73  
    74  func computeKaratsubaSqrThreshold(t *testing.T) {
    75  	set := func(n int) { karatsubaSqrThreshold = n }
    76  	computeThreshold(t, "karatsubaSqr", set, 0, 4, 200, benchSqr, 200, 8, 400)
    77  }
    78  
    79  func benchSqr(size int) func() {
    80  	x := rndNat(size)
    81  	var z nat
    82  	return func() {
    83  		z.sqr(nil, x)
    84  	}
    85  }
    86  
    87  func computeDivRecursiveThreshold(t *testing.T) {
    88  	set := func(n int) { divRecursiveThreshold = n }
    89  	computeThreshold(t, "divRecursive", set, 4, 4, 200, benchDiv, 200, 8, 400)
    90  }
    91  
    92  func benchDiv(size int) func() {
    93  	divx := rndNat(2 * size)
    94  	divy := rndNat(size)
    95  	var z, r nat
    96  	return func() {
    97  		z.div(nil, r, divx, divy)
    98  	}
    99  }
   100  
// computeThreshold times bench over a grid of operand sizes and threshold
// values, prints the results to standard output as CSV inside txtar-style
// framing, and finishes by calling set with the best threshold found.
//
// Parameters:
//   - t: the test handle. The visible body does not use it; inside the timing
//     loop it is shadowed by a local float64 holding one measurement.
//   - name: label used in the CSV section name and the "calibrate" row.
//   - set: called with each candidate threshold before timing, and once more
//     at the end with the best threshold.
//   - thresholdLo, thresholdStep, thresholdHi: inclusive arithmetic range of
//     thresholds to try. The step must be positive or the expansion loop
//     never terminates.
//   - bench: called once per size; returns the function to be timed.
//   - sizeLo, sizeStep, sizeHi: inclusive arithmetic range of sizes to try.
//     The step must be positive, and both ranges must be non-empty because
//     the code below indexes thresholds and takes a per-row minimum.
//
// Output sections, in order: environment rows (goos, goarch, cpu, calibrate),
// a table of minimum raw timings, a table of timings normalized per size,
// a geomean row per threshold, and the lo10%/lo5%/min/hi5%/hi10% thresholds.
func computeThreshold(t *testing.T, name string, set func(int), thresholdLo, thresholdStep, thresholdHi int, bench func(int) func(), sizeLo, sizeStep, sizeHi int) {
	// Start CSV output; wrapped in txtar framing to separate CSV from other test output.
	fmt.Printf("-- calibrate-%s.csv --\n", name)
	defer fmt.Printf("-- eof --\n")

	fmt.Printf("goos,%s\n", runtime.GOOS)
	fmt.Printf("goarch,%s\n", runtime.GOARCH)
	fmt.Printf("cpu,%s\n", sysinfo.CPUName())
	fmt.Printf("calibrate,%s\n", name)

	// Expand lists of sizes and thresholds we will test.
	var sizes, thresholds []int
	for size := sizeLo; size <= sizeHi; size += sizeStep {
		sizes = append(sizes, size)
	}
	for thresh := thresholdLo; thresh <= thresholdHi; thresh += thresholdStep {
		thresholds = append(thresholds, thresh)
	}

	// Header row for the raw timing table: one column per threshold.
	fmt.Printf("%s\n", csv("size \\ threshold", thresholds))

	// Track minimum time observed for each size, threshold pair.
	// Entries start at +Inf, which csv prints as an empty cell.
	times := make([][]float64, len(sizes))
	for i := range sizes {
		times[i] = make([]float64, len(thresholds))
		for j := range thresholds {
			times[i][j] = math.Inf(+1)
		}
	}

	// For each size, run at most MaxRounds of considering every threshold.
	// If we run a threshold Stable times in a row without seeing more
	// than a 1% improvement in the observed minimum, move on to the next one.
	// After we run Converged rounds (not necessarily in a row)
	// without seeing any threshold improve by more than 1%, stop.
	const (
		MaxRounds = 1600
		Stable    = 20
		Converged = 200
	)

	for i, size := range sizes {
		b := bench(size)
		// same counts rounds with no >1% improvement; it is never reset,
		// so the rounds need not be consecutive.
		same := 0
		for range MaxRounds {
			better := false
			for j, threshold := range thresholds {
				// No point if threshold is far beyond size
				// (this pruning is currently disabled by the leading "false &&").
				if false && threshold > size+2*sizeStep {
					continue
				}

				// BasicSqr is different from the recursive thresholds: it either applies or not,
				// without any question of recursive subproblems. Only try the thresholds
				//	size-1, size, size+1, size+2
				// to get two data points using basic multiplication and two using basic squaring.
				// This avoids gathering many redundant data points.
				// (The others have redundant data points as well, but for them the math is less trivial
				// and best not duplicated in the calibration code.)
				// This pruning is also currently disabled by the leading "false &&".
				// NOTE(review): as written the condition would admit thresholds up to size+3,
				// one more than the list above — confirm before re-enabling.
				if false && name == "basicSqr" && (threshold < size-1 || threshold > size+3) {
					continue
				}

				set(threshold)
				b() // warm up
				b()
				// Resume from the minimum recorded in earlier rounds.
				tmin := times[i][j]
				for k := 0; k < Stable; k++ {
					start := time.Now()
					b()
					// Elapsed time in nanoseconds; this t shadows the *testing.T parameter.
					t := float64(time.Since(start))
					if t < tmin {
						// Only an improvement of more than 1% counts as progress
						// and restarts the stability count (the loop's k++ then makes k 1).
						if t < tmin*99/100 {
							better = true
							k = 0
						}
						tmin = t
					}
				}
				times[i][j] = tmin
			}
			if !better {
				if same++; same >= Converged {
					break
				}
			}
		}

		// Emit this size's row of minimum timings as soon as it is complete.
		fmt.Printf("%s\n", csv(fmt.Sprint(size), times[i]))
	}

	// For each size, normalize timings by the minimum achieved for that size.
	// A normalized value of 1 marks the fastest threshold for that size.
	fmt.Printf("%s\n", csv("size \\ threshold", thresholds))
	norms := make([][]float64, len(sizes))
	// The loop variable times (one row) shadows the outer times table.
	for i, times := range times {
		m := min(1e100, slices.Min(times)) // make finite so divide preserves inf values
		norms[i] = make([]float64, len(times))
		for j, d := range times {
			norms[i][j] = d / m
		}
		fmt.Printf("%s\n", csv(fmt.Sprint(sizes[i]), norms[i]))
	}

	// For each threshold, compute geomean of normalized timings across all sizes.
	// +Inf entries (no data) are left out; a threshold with no data at all gets +Inf.
	geomeans := make([]float64, len(thresholds))
	for j := range thresholds {
		p := 1.0
		n := 0
		for i := range sizes {
			if v := norms[i][j]; !math.IsInf(v, +1) {
				p *= v
				n++
			}
		}
		if n == 0 {
			geomeans[j] = math.Inf(+1)
		} else {
			geomeans[j] = math.Pow(p, 1/float64(n))
		}
	}
	fmt.Printf("%s\n", csv("geomean", geomeans))

	// Add best threshold and smallest, largest within 10% and 5% of best.
	// The bands are found by walking outward from best while the neighboring
	// geomean is <= 1.05 (then <= 1.10). These are absolute limits on the
	// normalized geomean, not multiples of geomeans[best].
	// All five variables are indices into thresholds and geomeans.
	var lo10, lo5, best, hi5, hi10 int
	for i, g := range geomeans {
		if g < geomeans[best] {
			best = i
		}
	}
	lo5 = best
	for lo5 > 0 && geomeans[lo5-1] <= 1.05 {
		lo5--
	}
	lo10 = lo5
	for lo10 > 0 && geomeans[lo10-1] <= 1.10 {
		lo10--
	}
	hi5 = best
	for hi5+1 < len(geomeans) && geomeans[hi5+1] <= 1.05 {
		hi5++
	}
	hi10 = hi5
	for hi10+1 < len(geomeans) && geomeans[hi10+1] <= 1.10 {
		hi10++
	}
	fmt.Printf("lo10%%,%d\n", thresholds[lo10])
	fmt.Printf("lo5%%,%d\n", thresholds[lo5])
	fmt.Printf("min,%d\n", thresholds[best])
	fmt.Printf("hi5%%,%d\n", thresholds[hi5])
	fmt.Printf("hi10%%,%d\n", thresholds[hi10])

	// Finish by handing the best threshold to the caller-supplied setter.
	set(thresholds[best])
}
   254  
   255  // csv returns a single csv line starting with name and followed by the values.
   256  // Values that are float64 +infinity, denoting missing data, are replaced by an empty string.
   257  func csv[T int | float64](name string, values []T) string {
   258  	line := []string{name}
   259  	for _, v := range values {
   260  		if math.IsInf(float64(v), +1) {
   261  			line = append(line, "")
   262  		} else {
   263  			line = append(line, fmt.Sprint(v))
   264  		}
   265  	}
   266  	return strings.Join(line, ",")
   267  }
   268  

View as plain text