Source file
src/math/big/calibrate_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13 package big
14
15 import (
16 "flag"
17 "fmt"
18 "internal/sysinfo"
19 "math"
20 "runtime"
21 "slices"
22 "strings"
23 "sync"
24 "testing"
25 "time"
26 )
27
// Package-level calibration controls.
var (
	// calibrate is the -calibrate test flag; TestCalibrate does nothing
	// unless it is set.
	calibrate = flag.Bool("calibrate", false, "run calibration test")

	// calibrateOnce is a sync.Once for one-time calibration setup.
	// NOTE(review): not referenced in the code visible here — confirm it is
	// still used elsewhere in the package.
	calibrateOnce sync.Once
)
30
31 func TestCalibrate(t *testing.T) {
32 if !*calibrate {
33 return
34 }
35
36 t.Run("KaratsubaMul", computeKaratsubaThreshold)
37 t.Run("BasicSqr", computeBasicSqrThreshold)
38 t.Run("KaratsubaSqr", computeKaratsubaSqrThreshold)
39 t.Run("DivRecursive", computeDivRecursiveThreshold)
40 }
41
42 func computeKaratsubaThreshold(t *testing.T) {
43 set := func(n int) { karatsubaThreshold = n }
44 computeThreshold(t, "karatsuba", set, 0, 4, 200, benchMul, 200, 8, 400)
45 }
46
47 func benchMul(size int) func() {
48 x := rndNat(size)
49 y := rndNat(size)
50 var z nat
51 return func() {
52 z.mul(nil, x, y)
53 }
54 }
55
56 func computeBasicSqrThreshold(t *testing.T) {
57 setDuringTest(t, &karatsubaSqrThreshold, 1e9)
58 set := func(n int) { basicSqrThreshold = n }
59 computeThreshold(t, "basicSqr", set, 2, 1, 40, benchBasicSqr, 1, 1, 40)
60 }
61
62 func benchBasicSqr(size int) func() {
63 x := rndNat(size)
64 var z nat
65 return func() {
66
67
68 for range 100 {
69 z.sqr(nil, x)
70 }
71 }
72 }
73
74 func computeKaratsubaSqrThreshold(t *testing.T) {
75 set := func(n int) { karatsubaSqrThreshold = n }
76 computeThreshold(t, "karatsubaSqr", set, 0, 4, 200, benchSqr, 200, 8, 400)
77 }
78
79 func benchSqr(size int) func() {
80 x := rndNat(size)
81 var z nat
82 return func() {
83 z.sqr(nil, x)
84 }
85 }
86
87 func computeDivRecursiveThreshold(t *testing.T) {
88 set := func(n int) { divRecursiveThreshold = n }
89 computeThreshold(t, "divRecursive", set, 4, 4, 200, benchDiv, 200, 8, 400)
90 }
91
92 func benchDiv(size int) func() {
93 divx := rndNat(2 * size)
94 divy := rndNat(size)
95 var z, r nat
96 return func() {
97 z.div(nil, r, divx, divy)
98 }
99 }
100
101 func computeThreshold(t *testing.T, name string, set func(int), thresholdLo, thresholdStep, thresholdHi int, bench func(int) func(), sizeLo, sizeStep, sizeHi int) {
102
103 fmt.Printf("-- calibrate-%s.csv --\n", name)
104 defer fmt.Printf("-- eof --\n")
105
106 fmt.Printf("goos,%s\n", runtime.GOOS)
107 fmt.Printf("goarch,%s\n", runtime.GOARCH)
108 fmt.Printf("cpu,%s\n", sysinfo.CPUName())
109 fmt.Printf("calibrate,%s\n", name)
110
111
112 var sizes, thresholds []int
113 for size := sizeLo; size <= sizeHi; size += sizeStep {
114 sizes = append(sizes, size)
115 }
116 for thresh := thresholdLo; thresh <= thresholdHi; thresh += thresholdStep {
117 thresholds = append(thresholds, thresh)
118 }
119
120 fmt.Printf("%s\n", csv("size \\ threshold", thresholds))
121
122
123 times := make([][]float64, len(sizes))
124 for i := range sizes {
125 times[i] = make([]float64, len(thresholds))
126 for j := range thresholds {
127 times[i][j] = math.Inf(+1)
128 }
129 }
130
131
132
133
134
135
136 const (
137 MaxRounds = 1600
138 Stable = 20
139 Converged = 200
140 )
141
142 for i, size := range sizes {
143 b := bench(size)
144 same := 0
145 for range MaxRounds {
146 better := false
147 for j, threshold := range thresholds {
148
149 if false && threshold > size+2*sizeStep {
150 continue
151 }
152
153
154
155
156
157
158
159
160 if false && name == "basicSqr" && (threshold < size-1 || threshold > size+3) {
161 continue
162 }
163
164 set(threshold)
165 b()
166 b()
167 tmin := times[i][j]
168 for k := 0; k < Stable; k++ {
169 start := time.Now()
170 b()
171 t := float64(time.Since(start))
172 if t < tmin {
173 if t < tmin*99/100 {
174 better = true
175 k = 0
176 }
177 tmin = t
178 }
179 }
180 times[i][j] = tmin
181 }
182 if !better {
183 if same++; same >= Converged {
184 break
185 }
186 }
187 }
188
189 fmt.Printf("%s\n", csv(fmt.Sprint(size), times[i]))
190 }
191
192
193 fmt.Printf("%s\n", csv("size \\ threshold", thresholds))
194 norms := make([][]float64, len(sizes))
195 for i, times := range times {
196 m := min(1e100, slices.Min(times))
197 norms[i] = make([]float64, len(times))
198 for j, d := range times {
199 norms[i][j] = d / m
200 }
201 fmt.Printf("%s\n", csv(fmt.Sprint(sizes[i]), norms[i]))
202 }
203
204
205 geomeans := make([]float64, len(thresholds))
206 for j := range thresholds {
207 p := 1.0
208 n := 0
209 for i := range sizes {
210 if v := norms[i][j]; !math.IsInf(v, +1) {
211 p *= v
212 n++
213 }
214 }
215 if n == 0 {
216 geomeans[j] = math.Inf(+1)
217 } else {
218 geomeans[j] = math.Pow(p, 1/float64(n))
219 }
220 }
221 fmt.Printf("%s\n", csv("geomean", geomeans))
222
223
224 var lo10, lo5, best, hi5, hi10 int
225 for i, g := range geomeans {
226 if g < geomeans[best] {
227 best = i
228 }
229 }
230 lo5 = best
231 for lo5 > 0 && geomeans[lo5-1] <= 1.05 {
232 lo5--
233 }
234 lo10 = lo5
235 for lo10 > 0 && geomeans[lo10-1] <= 1.10 {
236 lo10--
237 }
238 hi5 = best
239 for hi5+1 < len(geomeans) && geomeans[hi5+1] <= 1.05 {
240 hi5++
241 }
242 hi10 = hi5
243 for hi10+1 < len(geomeans) && geomeans[hi10+1] <= 1.10 {
244 hi10++
245 }
246 fmt.Printf("lo10%%,%d\n", thresholds[lo10])
247 fmt.Printf("lo5%%,%d\n", thresholds[lo5])
248 fmt.Printf("min,%d\n", thresholds[best])
249 fmt.Printf("hi5%%,%d\n", thresholds[hi5])
250 fmt.Printf("hi10%%,%d\n", thresholds[hi10])
251
252 set(thresholds[best])
253 }
254
255
256
// csv formats one CSV record: name followed by each element of values,
// comma-separated. A float64 +Inf marks a missing measurement and is written
// as an empty field; every other value is formatted with fmt.Sprint.
func csv[T int | float64](name string, values []T) string {
	var b strings.Builder
	b.WriteString(name)
	for _, v := range values {
		b.WriteByte(',')
		if math.IsInf(float64(v), +1) {
			continue // missing data: leave the field empty
		}
		b.WriteString(fmt.Sprint(v))
	}
	return b.String()
}
268
View as plain text