Source file src/crypto/rsa/rsa_wycheproof_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  package rsa_test
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/internal/cryptotest/wycheproof"
     9  	"crypto/rsa"
    10  	"crypto/x509"
    11  	"fmt"
    12  	"slices"
    13  	"testing"
    14  )
    15  
    16  func TestRSAOAEPDecryptWycheproof(t *testing.T) {
    17  	flagsShouldPass := map[string]bool{
    18  		"Constructed":         true,
    19  		"EncryptionWithLabel": true,
    20  		// rsa.DecryptOAEP happily supports small key sizes
    21  		"SmallIntegerCiphertext": true,
    22  	}
    23  
    24  	for _, file := range []string{
    25  		"rsa_oaep_2048_sha1_mgf1sha1_test.json",
    26  		"rsa_oaep_2048_sha224_mgf1sha1_test.json",
    27  		"rsa_oaep_2048_sha224_mgf1sha224_test.json",
    28  		"rsa_oaep_2048_sha256_mgf1sha1_test.json",
    29  		"rsa_oaep_2048_sha256_mgf1sha256_test.json",
    30  		"rsa_oaep_2048_sha384_mgf1sha1_test.json",
    31  		"rsa_oaep_2048_sha384_mgf1sha384_test.json",
    32  		"rsa_oaep_2048_sha512_224_mgf1sha1_test.json",
    33  		"rsa_oaep_2048_sha512_224_mgf1sha512_224_test.json",
    34  		"rsa_oaep_2048_sha512_mgf1sha1_test.json",
    35  		"rsa_oaep_2048_sha512_mgf1sha512_test.json",
    36  		"rsa_oaep_3072_sha256_mgf1sha1_test.json",
    37  		"rsa_oaep_3072_sha256_mgf1sha256_test.json",
    38  		"rsa_oaep_3072_sha512_256_mgf1sha1_test.json",
    39  		"rsa_oaep_3072_sha512_256_mgf1sha512_256_test.json",
    40  		"rsa_oaep_3072_sha512_mgf1sha1_test.json",
    41  		"rsa_oaep_3072_sha512_mgf1sha512_test.json",
    42  		"rsa_oaep_4096_sha256_mgf1sha1_test.json",
    43  		"rsa_oaep_4096_sha256_mgf1sha256_test.json",
    44  		"rsa_oaep_4096_sha512_mgf1sha1_test.json",
    45  		"rsa_oaep_4096_sha512_mgf1sha512_test.json",
    46  		"rsa_oaep_misc_test.json",
    47  	} {
    48  		var testdata wycheproof.RsaesOaepDecryptSchemaV1Json
    49  		wycheproof.LoadVectorFile(t, file, &testdata)
    50  
    51  		for _, tg := range testdata.TestGroups {
    52  			rawPriv, err := x509.ParsePKCS8PrivateKey(wycheproof.MustDecodeHex(tg.PrivateKeyPkcs8))
    53  			if err != nil {
    54  				t.Fatalf("%s failed to parse PKCS #8 private key: %s", file, err)
    55  			}
    56  			priv := rawPriv.(*rsa.PrivateKey)
    57  			hash := wycheproof.ParseHash(tg.Sha)
    58  			mgfHash := wycheproof.ParseHash(tg.MgfSha)
    59  
    60  			for _, tv := range tg.Tests {
    61  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
    62  					t.Parallel()
    63  
    64  					ct := wycheproof.MustDecodeHex(tv.Ct)
    65  					label := wycheproof.MustDecodeHex(tv.Label)
    66  					wantPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, flagsShouldPass)
    67  					opts := &rsa.OAEPOptions{
    68  						Hash:    hash,
    69  						MGFHash: mgfHash,
    70  						Label:   label,
    71  					}
    72  					plaintext, err := priv.Decrypt(nil, ct, opts)
    73  					if wantPass {
    74  						if err != nil {
    75  							t.Fatalf("expected success: %s", err)
    76  						}
    77  						if !bytes.Equal(plaintext, wycheproof.MustDecodeHex(tv.Msg)) {
    78  							t.Errorf("unexpected plaintext: got %x, want %s", plaintext, tv.Msg)
    79  						}
    80  					} else if err == nil {
    81  						t.Errorf("expected failure")
    82  					}
    83  				})
    84  			}
    85  		}
    86  	}
    87  }
    88  
    89  func TestRSAPKCS1DecryptWycheproof(t *testing.T) {
    90  	for _, file := range []string{
    91  		"rsa_pkcs1_2048_test.json",
    92  		"rsa_pkcs1_3072_test.json",
    93  		"rsa_pkcs1_4096_test.json",
    94  	} {
    95  		var testdata wycheproof.RsaesPkcs1DecryptSchemaV1Json
    96  		wycheproof.LoadVectorFile(t, file, &testdata)
    97  
    98  		for _, tg := range testdata.TestGroups {
    99  			rawPriv, err := x509.ParsePKCS8PrivateKey(wycheproof.MustDecodeHex(tg.PrivateKeyPkcs8))
   100  			if err != nil {
   101  				t.Fatalf("%s: failed to parse PKCS #8 private key: %v", file, err)
   102  			}
   103  			priv := rawPriv.(*rsa.PrivateKey)
   104  
   105  			for _, tv := range tg.Tests {
   106  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
   107  					t.Parallel()
   108  
   109  					ct := wycheproof.MustDecodeHex(tv.Ct)
   110  					expectedMsg := wycheproof.MustDecodeHex(tv.Msg)
   111  					wantPass := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
   112  
   113  					plaintext, err := rsa.DecryptPKCS1v15(nil, priv, ct)
   114  					if err != nil {
   115  						if wantPass {
   116  							t.Fatalf("DecryptPKCS1v15: %v", err)
   117  						}
   118  						return
   119  					}
   120  					if !wantPass {
   121  						t.Errorf("DecryptPKCS1v15 unexpectedly succeeded")
   122  						return
   123  					}
   124  					if !bytes.Equal(plaintext, expectedMsg) {
   125  						t.Errorf("plaintext mismatch: got %x, want %x", plaintext, expectedMsg)
   126  					}
   127  				})
   128  			}
   129  		}
   130  	}
   131  }
   132  
   133  func TestRSAPKCS1SignaturesWycheproof(t *testing.T) {
   134  	// A map of supported modulus sizes to the list of hashes that Wycheproof has
   135  	// test vector coverage for.
   136  	modsAndHashes := map[int][]string{
   137  		2048: {
   138  			"sha224",
   139  			"sha256",
   140  			"sha384",
   141  			"sha512",
   142  			"sha512_224",
   143  			"sha512_256",
   144  			"sha3_224",
   145  			"sha3_256",
   146  			"sha3_384",
   147  			"sha3_512",
   148  		},
   149  		3072: {
   150  			"sha256",
   151  			"sha384",
   152  			"sha512",
   153  			"sha512_256",
   154  			"sha3_256",
   155  			"sha3_384",
   156  			"sha3_512",
   157  		},
   158  		4096: {
   159  			"sha256",
   160  			"sha384",
   161  			"sha512",
   162  			"sha512_256",
   163  		},
   164  		8192: {
   165  			"sha256",
   166  			"sha384",
   167  			"sha512",
   168  		},
   169  	}
   170  
   171  	var files []string
   172  	for m, hashes := range modsAndHashes {
   173  		for _, h := range hashes {
   174  			files = append(files, fmt.Sprintf("rsa_signature_%d_%s_test.json", m, h))
   175  		}
   176  	}
   177  
   178  	flagsShouldPass := map[string]bool{
   179  		// Omitting the parameter field in an ASN encoded integer is a legacy behavior.
   180  		"MissingNull": false,
   181  	}
   182  
   183  	for _, file := range files {
   184  		var testdata wycheproof.RsassaPkcs1VerifySchemaV1Json
   185  		wycheproof.LoadVectorFile(t, file, &testdata)
   186  
   187  		for _, tg := range testdata.TestGroups {
   188  			hash := wycheproof.ParseHash(tg.Sha)
   189  
   190  			pub, err := x509.ParsePKCS1PublicKey(wycheproof.MustDecodeHex(tg.PublicKeyAsn))
   191  			if err != nil {
   192  				t.Fatalf("failed to decode pubkey: %v", err)
   193  			}
   194  
   195  			for _, tv := range tg.Tests {
   196  				t.Run(wycheproof.TestName(file, tv), func(t *testing.T) {
   197  					t.Parallel()
   198  
   199  					sig := wycheproof.MustDecodeHex(tv.Sig)
   200  					h := hash.New()
   201  					h.Write(wycheproof.MustDecodeHex(tv.Msg))
   202  					err := rsa.VerifyPKCS1v15(pub, hash, h.Sum(nil), sig)
   203  					want := wycheproof.ShouldPass(t, tv.Result, tv.Flags, flagsShouldPass)
   204  					if (err == nil) != want {
   205  						t.Errorf("wanted success: %t err: %v", want, err)
   206  					}
   207  				})
   208  			}
   209  		}
   210  	}
   211  }
   212  
   213  func TestRSAPSSSignaturesWycheproof(t *testing.T) {
   214  	// filesOverrideToPassZeroSLen is a map of all test files
   215  	// and which TcIds that should be overridden to pass if the
   216  	// rsa.PSSOptions.SaltLength is zero.
   217  	// These tests expect a failure with a PSSOptions.SaltLength: 0
   218  	// and a signature that uses a different salt length. However,
   219  	// a salt length of 0 is defined as rsa.PSSSaltLengthAuto which
   220  	// works deterministically to auto-detect the length when
   221  	// verifying, so these tests actually pass as they should.
   222  	filesOverrideToPassZeroSLen := map[string][]int{
   223  		"rsa_pss_2048_sha1_mgf1_20_test.json":   {46, 47, 48, 49, 50, 51},
   224  		"rsa_pss_2048_sha256_mgf1_0_test.json":  {67, 68, 69, 70},
   225  		"rsa_pss_2048_sha256_mgf1_32_test.json": {67, 68, 69, 70, 71, 72},
   226  		"rsa_pss_3072_sha256_mgf1_32_test.json": {67, 68, 69, 70, 71, 72},
   227  		"rsa_pss_4096_sha256_mgf1_32_test.json": {67, 68, 69, 70, 71, 72},
   228  		"rsa_pss_4096_sha512_mgf1_32_test.json": {136, 137, 138, 139, 140, 141},
   229  		"rsa_pss_misc_test.json":                nil,
   230  	}
   231  
   232  	for file, overrideIDs := range filesOverrideToPassZeroSLen {
   233  		var testdata wycheproof.RsassaPssVerifySchemaV1Json
   234  		wycheproof.LoadVectorFile(t, file, &testdata)
   235  
   236  		for _, tg := range testdata.TestGroups {
   237  			// Go's PSS implementation doesn't support different hash
   238  			// algorithms for message digest and MGF1. See #46233.
   239  			// Skip the affected test groups in rsa_pss_misc_test.json.
   240  			if file == "rsa_pss_misc_test.json" && tg.Sha != tg.MgfSha {
   241  				continue
   242  			}
   243  
   244  			hash := wycheproof.ParseHash(tg.Sha)
   245  
   246  			pub, err := x509.ParsePKCS1PublicKey(wycheproof.MustDecodeHex(tg.PublicKeyAsn))
   247  			if err != nil {
   248  				t.Fatalf("failed to decode pubkey: %v", err)
   249  			}
   250  
   251  			// Run all the tests twice: the first time with the salt length
   252  			// as PSSSaltLengthAuto, and the second time with the salt length
   253  			// explicitly set to tg.SLen.
   254  			for i := 0; i < 2; i++ {
   255  				saltLabel := "autoSalt"
   256  				if i == 1 {
   257  					saltLabel = "vecSalt"
   258  				}
   259  				opts := &rsa.PSSOptions{
   260  					Hash:       hash,
   261  					SaltLength: rsa.PSSSaltLengthAuto,
   262  				}
   263  
   264  				for _, tv := range tg.Tests {
   265  					t.Run(wycheproof.TestName(file, tv)+" "+saltLabel, func(t *testing.T) {
   266  						h := hash.New()
   267  						h.Write(wycheproof.MustDecodeHex(tv.Msg))
   268  						sig := wycheproof.MustDecodeHex(tv.Sig)
   269  						err = rsa.VerifyPSS(pub, hash, h.Sum(nil), sig, opts)
   270  						want := wycheproof.ShouldPass(t, tv.Result, tv.Flags, nil)
   271  						if opts.SaltLength == 0 && slices.Contains(overrideIDs, tv.TcId) {
   272  							want = true
   273  						}
   274  						if (err == nil) != want {
   275  							t.Errorf("wanted success: %t err: %v", want, err)
   276  						}
   277  					})
   278  				}
   279  
   280  				// Update opts.SaltLength for the second run of the tests.
   281  				opts.SaltLength = tg.SLen
   282  			}
   283  		}
   284  	}
   285  }
   286  

View as plain text