// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Multiplication. package big // Operands that are shorter than karatsubaThreshold are multiplied using // "grade school" multiplication; for longer operands the Karatsuba algorithm // is used. var karatsubaThreshold = 40 // see calibrate_test.go // mul sets z = x*y, using stk for temporary storage. // The caller may pass stk == nil to request that mul obtain and release one itself. func (z nat) mul(stk *stack, x, y nat) nat { m := len(x) n := len(y) switch { case m < n: return z.mul(stk, y, x) case m == 0 || n == 0: return z[:0] case n == 1: return z.mulAddWW(x, y[0], 0) } // m >= n > 1 // determine if z can be reused if alias(z, x) || alias(z, y) { z = nil // z is an alias for x or y - cannot reuse } z = z.make(m + n) // use basic multiplication if the numbers are small if n < karatsubaThreshold { basicMul(z, x, y) return z.norm() } if stk == nil { stk = getStack() defer stk.free() } // Let x = x1:x0 where x0 is the same length as y. // Compute z = x0*y and then add in x1*y in sections // if needed. karatsuba(stk, z[:2*n], x[:n], y) if n < m { clear(z[2*n:]) defer stk.restore(stk.save()) t := stk.nat(2 * n) for i := n; i < m; i += n { t = t.mul(stk, x[i:min(i+n, len(x))], y) addTo(z[i:], t) } } return z.norm() } // Operands that are shorter than basicSqrThreshold are squared using // "grade school" multiplication; for operands longer than karatsubaSqrThreshold // we use the Karatsuba algorithm optimized for x == y. var basicSqrThreshold = 12 // see calibrate_test.go var karatsubaSqrThreshold = 80 // see calibrate_test.go // sqr sets z = x*x, using stk for temporary storage. // The caller may pass stk == nil to request that sqr obtain and release one itself. func (z nat) sqr(stk *stack, x nat) nat { n := len(x) switch { case n == 0: return z[:0] case n == 1: d := x[0] z = z.make(2) z[1], z[0] = mulWW(d, d) return z.norm() } if alias(z, x) { z = nil // z is an alias for x - cannot reuse } z = z.make(2 * n) if n < basicSqrThreshold && n < karatsubaSqrThreshold { basicMul(z, x, x) return z.norm() } if stk == nil { stk = getStack() defer stk.free() } if n < karatsubaSqrThreshold { basicSqr(stk, z, x) return z.norm() } karatsubaSqr(stk, z, x) return z.norm() } // basicSqr sets z = x*x and is asymptotically faster than basicMul // by about a factor of 2, but slower for small arguments due to overhead. // Requirements: len(x) > 0, len(z) == 2*len(x) // The (non-normalized) result is placed in z. func basicSqr(stk *stack, z, x nat) { n := len(x) if n < basicSqrThreshold { basicMul(z, x, x) return } defer stk.restore(stk.save()) t := stk.nat(2 * n) clear(t) z[1], z[0] = mulWW(x[0], x[0]) // the initial square for i := 1; i < n; i++ { d := x[i] // z collects the squares x[i] * x[i] z[2*i+1], z[2*i] = mulWW(d, d) // t collects the products x[i] * x[j] where j < i t[2*i] = addMulVVW(t[i:2*i], x[0:i], d) } t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products addVV(z, z, t) // combine the result } // mulAddWW returns z = x*y + r. func (z nat) mulAddWW(x nat, y, r Word) nat { m := len(x) if m == 0 || y == 0 { return z.setWord(r) // result is r } // m > 0 z = z.make(m + 1) z[m] = mulAddVWW(z[0:m], x, y, r) return z.norm() } // basicMul multiplies x and y and leaves the result in z. // The (non-normalized) result is placed in z[0 : len(x) + len(y)]. func basicMul(z, x, y nat) { clear(z[0 : len(x)+len(y)]) // initialize z for i, d := range y { if d != 0 { z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d) } } } // karatsuba multiplies x and y, // writing the (non-normalized) result to z. // x and y must have the same length n, // and z must have length twice that. func karatsuba(stk *stack, z, x, y nat) { n := len(y) if len(x) != n || len(z) != 2*n { panic("bad karatsuba length") } // Fall back to basic algorithm if small enough. if n < karatsubaThreshold || n < 2 { basicMul(z, x, y) return } // Let the notation x1:x0 denote the nat (x1< D { s, t = s[:len(s)-D], s[len(s)-D:]+"_"+t } return neg + s + t } // trace prints a single debug value. func trace(name string, x *Int) { print(name, "=", ifmt(x), "\n") }