Source file src/crypto/mldsa/mldsa_wycheproof_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build !fips140v1.0
     6  
     7  package mldsa_test
     8  
     9  import (
    10  	"bytes"
    11  	"crypto"
    12  	"crypto/internal/cryptotest/wycheproof"
    13  	internalmldsa "crypto/internal/fips140/mldsa"
    14  	"crypto/mldsa"
    15  	"encoding/json"
    16  	"slices"
    17  	"testing"
    18  )
    19  
    20  // TestVerifyWycheproof test signature verification using the public
    21  // mldsa API.
    22  func TestVerifyWycheproof(t *testing.T) {
    23  	for _, file := range []string{
    24  		"mldsa_44_verify_test.json",
    25  		"mldsa_65_verify_test.json",
    26  		"mldsa_87_verify_test.json",
    27  	} {
    28  		var testdata wycheproof.MldsaVerifySchemaJson
    29  		wycheproof.LoadVectorFile(t, file, &testdata)
    30  
    31  		params := paramsForAlg(testdata.Algorithm)
    32  
    33  		for _, tg := range testdata.TestGroups {
    34  			publicKey := wycheproof.MustDecodeHex(tg.PublicKey)
    35  
    36  			for _, tv := range tg.Tests {
    37  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
    38  					t.Parallel()
    39  
    40  					shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
    41  
    42  					pub, err := mldsa.NewPublicKey(params, publicKey)
    43  					if err != nil {
    44  						if shouldPass {
    45  							t.Fatalf("NewPublicKey: %v", err)
    46  						}
    47  						return
    48  					}
    49  
    50  					if !bytes.Equal(pub.Bytes(), publicKey) {
    51  						t.Errorf("public key roundtrip mismatch")
    52  					}
    53  
    54  					msg := wycheproof.MustDecodeHex(tv.Msg)
    55  					sig := wycheproof.MustDecodeHex(tv.Sig)
    56  					opts := new(mldsa.Options)
    57  					if tv.Ctx != nil {
    58  						opts.Context = string(wycheproof.MustDecodeHex(*tv.Ctx))
    59  					}
    60  
    61  					err = mldsa.Verify(pub, msg, sig, opts)
    62  					if shouldPass && err != nil {
    63  						t.Errorf("Verify: %v", err)
    64  					}
    65  					if !shouldPass && err == nil {
    66  						t.Errorf("Verify should have failed")
    67  					}
    68  				})
    69  			}
    70  		}
    71  	}
    72  }
    73  
    74  // TestSignSeedWycheproof tests key generation and signature creation using
    75  // the public mldsa API for seed private key inputs.
    76  //
    77  // It covers deterministic signature creation with and without pre-hashed mu.
    78  func TestSignSeedWycheproof(t *testing.T) {
    79  	// We don't include the mldsa_*_sign_noseed_test.json test vector files.
    80  	// Semi-expanded keys are not supported with the public API.
    81  	for _, file := range []string{
    82  		"mldsa_44_sign_seed_test.json",
    83  		"mldsa_65_sign_seed_test.json",
    84  		"mldsa_87_sign_seed_test.json",
    85  	} {
    86  		var testdata wycheproof.MldsaSignSeedSchemaJson
    87  		wycheproof.LoadVectorFile(t, file, &testdata)
    88  
    89  		params := paramsForAlg(testdata.Algorithm)
    90  
    91  		for _, tg := range testdata.TestGroups {
    92  			seed := wycheproof.MustDecodeHex(tg.PrivateSeed)
    93  			var expectedPublicKey []byte
    94  			if pk, ok := tg.PublicKey.(string); ok {
    95  				expectedPublicKey = wycheproof.MustDecodeHex(pk)
    96  			}
    97  
    98  			for _, raw := range tg.Tests {
    99  				tv := decodeMLDSASignTestVector(t, raw)
   100  
   101  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
   102  					t.Parallel()
   103  
   104  					shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
   105  
   106  					priv, err := mldsa.NewPrivateKey(params, seed)
   107  					if err != nil {
   108  						if shouldPass {
   109  							t.Fatalf("NewPrivateKey: %v", err)
   110  						}
   111  						return
   112  					}
   113  					// By checking the derived public key is equal to the vector's
   114  					// provided public key the 'sign' vectors double as key
   115  					// generation vectors.
   116  					if expectedPublicKey != nil && !bytes.Equal(priv.PublicKey().Bytes(), expectedPublicKey) {
   117  						t.Fatalf("public key mismatch")
   118  					}
   119  
   120  					if slices.Contains(tv.Flags, "Randomized") {
   121  						t.Skipf("randomized signatures not supported with public API")
   122  					}
   123  
   124  					runSignTest(t, priv, tv, shouldPass)
   125  				})
   126  			}
   127  		}
   128  	}
   129  }
   130  
   131  func runSignTest(t *testing.T, priv *mldsa.PrivateKey, tv mldsaSignTestVector, shouldPass bool) {
   132  	t.Helper()
   133  
   134  	var msg, μ []byte
   135  	opts := new(mldsa.Options)
   136  	if tv.Msg != nil {
   137  		msg = wycheproof.MustDecodeHex(*tv.Msg)
   138  		if tv.Ctx != nil {
   139  			opts.Context = string(wycheproof.MustDecodeHex(*tv.Ctx))
   140  		}
   141  	}
   142  	if tv.Mu != nil && *tv.Mu != "" {
   143  		μ = wycheproof.MustDecodeHex(*tv.Mu)
   144  	}
   145  	if msg == nil && μ == nil {
   146  		t.Fatalf("test vector has neither msg nor mu")
   147  	}
   148  
   149  	var sigMsg, sigMu []byte
   150  	var errMsg, errMu error
   151  	if msg != nil {
   152  		sigMsg, errMsg = priv.SignDeterministic(msg, opts)
   153  	}
   154  	if μ != nil {
   155  		sigMu, errMu = priv.SignDeterministic(μ, crypto.MLDSAMu)
   156  	}
   157  
   158  	for _, e := range []error{errMsg, errMu} {
   159  		if e != nil {
   160  			if shouldPass {
   161  				t.Fatalf("Sign: %v", e)
   162  			}
   163  			return
   164  		}
   165  	}
   166  	if !shouldPass {
   167  		t.Errorf("Sign unexpectedly succeeded")
   168  		return
   169  	}
   170  
   171  	expectedSig := wycheproof.MustDecodeHex(tv.Sig)
   172  	sig := sigMsg
   173  	if sig == nil {
   174  		sig = sigMu
   175  	}
   176  	if sigMsg != nil && sigMu != nil && !bytes.Equal(sigMsg, sigMu) {
   177  		t.Errorf("Sign(msg, ctx) and SignExternalMu(mu) disagree")
   178  	}
   179  	if !bytes.Equal(sig, expectedSig) {
   180  		t.Errorf("signature mismatch")
   181  	}
   182  
   183  	pub := priv.PublicKey()
   184  	if msg != nil {
   185  		if err := mldsa.Verify(pub, msg, sig, opts); err != nil {
   186  			t.Errorf("Verify of own signature failed: %v", err)
   187  		}
   188  	}
   189  	// note: we can't round-trip verify external-mu signatures with the public API.
   190  	//  but if that capability were exposed in the future we could check here for
   191  	//  mu != nil.
   192  }
   193  
   194  // TestMLDSASignSeedRandomizedWycheproof tests randomized signing with the
   195  // internal testing-only API.
   196  //
   197  // It covers randomized signature creation with and without pre-hashed mu.
   198  func TestMLDSASignSeedRandomizedWycheproof(t *testing.T) {
   199  	for _, file := range []string{
   200  		"mldsa_44_sign_seed_test.json",
   201  		"mldsa_65_sign_seed_test.json",
   202  		"mldsa_87_sign_seed_test.json",
   203  	} {
   204  		var testdata wycheproof.MldsaSignSeedSchemaJson
   205  		wycheproof.LoadVectorFile(t, file, &testdata)
   206  
   207  		newPriv := newPrivateKeyFromSeedFn(t, testdata.Algorithm)
   208  
   209  		for _, tg := range testdata.TestGroups {
   210  			seed := wycheproof.MustDecodeHex(tg.PrivateSeed)
   211  			var expectedPublicKey []byte
   212  			if pk, ok := tg.PublicKey.(string); ok {
   213  				expectedPublicKey = wycheproof.MustDecodeHex(pk)
   214  			}
   215  
   216  			for _, raw := range tg.Tests {
   217  				tv := decodeMLDSASignTestVector(t, raw)
   218  
   219  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
   220  					t.Parallel()
   221  
   222  					shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
   223  					priv, err := newPriv(seed)
   224  					if err != nil {
   225  						if shouldPass {
   226  							t.Fatalf("NewPrivateKey: %v", err)
   227  						}
   228  						return
   229  					}
   230  
   231  					// By checking the derived public key is equal to the vector's
   232  					// provided public key the 'sign' vectors double as key
   233  					// generation vectors.
   234  					if expectedPublicKey != nil && !bytes.Equal(priv.PublicKey().Bytes(), expectedPublicKey) {
   235  						t.Fatalf("public key mismatch")
   236  					}
   237  
   238  					runRandomizedSignTest(t, priv, tv, shouldPass)
   239  				})
   240  			}
   241  		}
   242  	}
   243  }
   244  
   245  func runRandomizedSignTest(t *testing.T, priv *internalmldsa.PrivateKey, tv mldsaSignTestVector, shouldPass bool) {
   246  	t.Helper()
   247  
   248  	var msg, μ []byte
   249  	var ctx string
   250  	rnd := make([]byte, 32)
   251  
   252  	if tv.Msg != nil {
   253  		msg = wycheproof.MustDecodeHex(*tv.Msg)
   254  		if tv.Ctx != nil {
   255  			ctx = string(wycheproof.MustDecodeHex(*tv.Ctx))
   256  		}
   257  	}
   258  	if tv.Mu != nil && *tv.Mu != "" {
   259  		μ = wycheproof.MustDecodeHex(*tv.Mu)
   260  	}
   261  	if tv.Rnd != nil && *tv.Rnd != "" {
   262  		rnd = wycheproof.MustDecodeHex(*tv.Rnd)
   263  	}
   264  
   265  	if msg == nil && μ == nil {
   266  		t.Fatalf("test vector has neither msg nor mu")
   267  	}
   268  
   269  	var sigMsg, sigMu []byte
   270  	var errMsg, errMu error
   271  	if msg != nil {
   272  		sigMsg, errMsg = internalmldsa.TestingOnlySignWithRandom(priv, msg, ctx, rnd)
   273  	}
   274  	if μ != nil {
   275  		sigMu, errMu = internalmldsa.TestingOnlySignExternalMuWithRandom(priv, μ, rnd)
   276  	}
   277  
   278  	for _, e := range []error{errMsg, errMu} {
   279  		if e != nil {
   280  			if shouldPass {
   281  				t.Fatalf("Sign: %v", e)
   282  			}
   283  			return
   284  		}
   285  	}
   286  	if !shouldPass {
   287  		t.Errorf("Sign unexpectedly succeeded")
   288  		return
   289  	}
   290  
   291  	expectedSig := wycheproof.MustDecodeHex(tv.Sig)
   292  	sig := sigMsg
   293  	if sig == nil {
   294  		sig = sigMu
   295  	}
   296  	if sigMsg != nil && sigMu != nil && !bytes.Equal(sigMsg, sigMu) {
   297  		t.Errorf("Sign(msg, ctx) and SignExternalMu(mu) disagree")
   298  	}
   299  	if !bytes.Equal(sig, expectedSig) {
   300  		t.Errorf("signature mismatch")
   301  	}
   302  
   303  	pub := priv.PublicKey()
   304  	if msg != nil {
   305  		if err := internalmldsa.Verify(pub, msg, sig, ctx); err != nil {
   306  			t.Errorf("Verify of own signature failed: %v", err)
   307  		}
   308  	}
   309  	if μ != nil {
   310  		if err := internalmldsa.VerifyExternalMu(pub, μ, sig); err != nil {
   311  			t.Errorf("VerifyExternalMu of own signature failed: %v", err)
   312  		}
   313  	}
   314  }
   315  
   316  // TestMLDSANoSeedWycheproof tests semi-expanded private key inputs
   317  // derive the correct public key using the internal testing-only API.
   318  //
   319  // We don't perform further signature signing operations as this is covered
   320  // by the seed-form TestSignSeedWycheproof.
   321  func TestMLDSANoSeedWycheproof(t *testing.T) {
   322  	for _, file := range []string{
   323  		"mldsa_44_sign_noseed_test.json",
   324  		"mldsa_65_sign_noseed_test.json",
   325  		"mldsa_87_sign_noseed_test.json",
   326  	} {
   327  		var testdata wycheproof.MldsaSignNoseedSchemaJson
   328  		wycheproof.LoadVectorFile(t, file, &testdata)
   329  
   330  		for _, tg := range testdata.TestGroups {
   331  			privateKey := wycheproof.MustDecodeHex(tg.PrivateKey)
   332  			var expectedPublicKey []byte
   333  			if pk, ok := tg.PublicKey.(string); ok {
   334  				expectedPublicKey = wycheproof.MustDecodeHex(pk)
   335  			}
   336  
   337  			for _, raw := range tg.Tests {
   338  				tv := decodeMLDSASignTestVector(t, raw)
   339  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
   340  					t.Parallel()
   341  
   342  					shouldPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
   343  					priv, err := internalmldsa.TestingOnlyNewPrivateKeyFromSemiExpanded(privateKey)
   344  					if err != nil {
   345  						if shouldPass {
   346  							t.Fatalf("TestingOnlyNewPrivateKeyFromSemiExpanded: %v", err)
   347  						}
   348  						return
   349  					}
   350  
   351  					if expectedPublicKey != nil && !bytes.Equal(priv.PublicKey().Bytes(), expectedPublicKey) {
   352  						t.Fatalf("public key mismatch")
   353  					}
   354  				})
   355  			}
   356  		}
   357  	}
   358  }
   359  
   360  func paramsForAlg(algorithm string) mldsa.Parameters {
   361  	switch algorithm {
   362  	case "ML-DSA-44":
   363  		return mldsa.MLDSA44()
   364  	case "ML-DSA-65":
   365  		return mldsa.MLDSA65()
   366  	case "ML-DSA-87":
   367  		return mldsa.MLDSA87()
   368  	}
   369  	panic("unknown algorithm: " + algorithm)
   370  }
   371  
   372  func newPrivateKeyFromSeedFn(t *testing.T, algorithm string) func([]byte) (*internalmldsa.PrivateKey, error) {
   373  	switch algorithm {
   374  	case "ML-DSA-44":
   375  		return internalmldsa.NewPrivateKey44
   376  	case "ML-DSA-65":
   377  		return internalmldsa.NewPrivateKey65
   378  	case "ML-DSA-87":
   379  		return internalmldsa.NewPrivateKey87
   380  	}
   381  	t.Fatalf("unknown algorithm: %s", algorithm)
   382  	return nil
   383  }
   384  
// mldsaSignTestVector is a typed view of wycheproof.MlDsaSignTestVector,
// which the schema generator emits as interface{} because of the schema's
// conditional clauses. Values are produced by decodeMLDSASignTestVector.
//
// Msg, Ctx, Mu, Rnd and Sig hold hex strings, which the tests decode with
// wycheproof.MustDecodeHex. The optional fields are pointers so that a
// field missing from the JSON (nil) can be told apart from one that is
// present. The signing helpers additionally treat an empty Mu or Rnd string
// as absent.
type mldsaSignTestVector struct {
	TcId    int               `json:"tcId"`
	Comment string            `json:"comment"`
	Msg     *string           `json:"msg,omitempty"`
	Ctx     *string           `json:"ctx,omitempty"`
	Mu      *string           `json:"mu,omitempty"`
	Rnd     *string           `json:"rnd,omitempty"`
	Sig     string            `json:"sig"`
	Result  wycheproof.Result `json:"result"`
	Flags   []string          `json:"flags"`
}
   399  
   400  // decodeMLDSASignTestVector roundtrips an interface{} typed raw
   401  // MlDsaSignTestVector to produce a typed MLDSASignTestVector.
   402  // This is a workaround for a limitation of the schema generator.
   403  func decodeMLDSASignTestVector(t *testing.T, raw wycheproof.MlDsaSignTestVector) mldsaSignTestVector {
   404  	t.Helper()
   405  	b, err := json.Marshal(raw)
   406  	if err != nil {
   407  		t.Fatalf("re-marshal sign test vector: %v", err)
   408  	}
   409  	var tv mldsaSignTestVector
   410  	if err := json.Unmarshal(b, &tv); err != nil {
   411  		t.Fatalf("decode sign test vector: %v", err)
   412  	}
   413  	return tv
   414  }
   415  

View as plain text