Source file src/crypto/internal/hpke/hpke.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpke
     6  
     7  import (
     8  	"crypto"
     9  	"crypto/aes"
    10  	"crypto/cipher"
    11  	"crypto/ecdh"
    12  	"crypto/hkdf"
    13  	"crypto/rand"
    14  	"errors"
    15  	"internal/byteorder"
    16  	"math/bits"
    17  
    18  	"golang.org/x/crypto/chacha20poly1305"
    19  )
    20  
    21  // testingOnlyGenerateKey is only used during testing, to provide
    22  // a fixed test key to use when checking the RFC 9180 vectors.
    23  var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
    24  
    25  type hkdfKDF struct {
    26  	hash crypto.Hash
    27  }
    28  
    29  func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) ([]byte, error) {
    30  	labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey))
    31  	labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
    32  	labeledIKM = append(labeledIKM, sid...)
    33  	labeledIKM = append(labeledIKM, label...)
    34  	labeledIKM = append(labeledIKM, inputKey...)
    35  	return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
    36  }
    37  
    38  func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) {
    39  	labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
    40  	labeledInfo = byteorder.BEAppendUint16(labeledInfo, length)
    41  	labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
    42  	labeledInfo = append(labeledInfo, suiteID...)
    43  	labeledInfo = append(labeledInfo, label...)
    44  	labeledInfo = append(labeledInfo, info...)
    45  	return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length))
    46  }
    47  
    48  // dhKEM implements the KEM specified in RFC 9180, Section 4.1.
    49  type dhKEM struct {
    50  	dh  ecdh.Curve
    51  	kdf hkdfKDF
    52  
    53  	suiteID []byte
    54  	nSecret uint16
    55  }
    56  
    57  type KemID uint16
    58  
    59  const DHKEM_X25519_HKDF_SHA256 = 0x0020
    60  
    61  var SupportedKEMs = map[uint16]struct {
    62  	curve   ecdh.Curve
    63  	hash    crypto.Hash
    64  	nSecret uint16
    65  }{
    66  	// RFC 9180 Section 7.1
    67  	DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32},
    68  }
    69  
    70  func newDHKem(kemID uint16) (*dhKEM, error) {
    71  	suite, ok := SupportedKEMs[kemID]
    72  	if !ok {
    73  		return nil, errors.New("unsupported suite ID")
    74  	}
    75  	return &dhKEM{
    76  		dh:      suite.curve,
    77  		kdf:     hkdfKDF{suite.hash},
    78  		suiteID: byteorder.BEAppendUint16([]byte("KEM"), kemID),
    79  		nSecret: suite.nSecret,
    80  	}, nil
    81  }
    82  
    83  func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
    84  	eaePRK, err := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
    89  }
    90  
    91  func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
    92  	var privEph *ecdh.PrivateKey
    93  	if testingOnlyGenerateKey != nil {
    94  		privEph, err = testingOnlyGenerateKey()
    95  	} else {
    96  		privEph, err = dh.dh.GenerateKey(rand.Reader)
    97  	}
    98  	if err != nil {
    99  		return nil, nil, err
   100  	}
   101  	dhVal, err := privEph.ECDH(pubRecipient)
   102  	if err != nil {
   103  		return nil, nil, err
   104  	}
   105  	encPubEph := privEph.PublicKey().Bytes()
   106  
   107  	encPubRecip := pubRecipient.Bytes()
   108  	kemContext := append(encPubEph, encPubRecip...)
   109  	sharedSecret, err = dh.ExtractAndExpand(dhVal, kemContext)
   110  	if err != nil {
   111  		return nil, nil, err
   112  	}
   113  	return sharedSecret, encPubEph, nil
   114  }
   115  
   116  func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
   117  	pubEph, err := dh.dh.NewPublicKey(encPubEph)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	dhVal, err := secRecipient.ECDH(pubEph)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
   126  	return dh.ExtractAndExpand(dhVal, kemContext)
   127  }
   128  
   129  type context struct {
   130  	aead cipher.AEAD
   131  
   132  	sharedSecret []byte
   133  
   134  	suiteID []byte
   135  
   136  	key            []byte
   137  	baseNonce      []byte
   138  	exporterSecret []byte
   139  
   140  	seqNum uint128
   141  }
   142  
   143  type Sender struct {
   144  	*context
   145  }
   146  
   147  type Receipient struct {
   148  	*context
   149  }
   150  
   151  var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
   152  	block, err := aes.NewCipher(key)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	return cipher.NewGCM(block)
   157  }
   158  
   159  type AEADID uint16
   160  
   161  const (
   162  	AEAD_AES_128_GCM      = 0x0001
   163  	AEAD_AES_256_GCM      = 0x0002
   164  	AEAD_ChaCha20Poly1305 = 0x0003
   165  )
   166  
   167  var SupportedAEADs = map[uint16]struct {
   168  	keySize   int
   169  	nonceSize int
   170  	aead      func([]byte) (cipher.AEAD, error)
   171  }{
   172  	// RFC 9180, Section 7.3
   173  	AEAD_AES_128_GCM:      {keySize: 16, nonceSize: 12, aead: aesGCMNew},
   174  	AEAD_AES_256_GCM:      {keySize: 32, nonceSize: 12, aead: aesGCMNew},
   175  	AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
   176  }
   177  
   178  type KDFID uint16
   179  
   180  const KDF_HKDF_SHA256 = 0x0001
   181  
   182  var SupportedKDFs = map[uint16]func() *hkdfKDF{
   183  	// RFC 9180, Section 7.2
   184  	KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
   185  }
   186  
   187  func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) {
   188  	sid := suiteID(kemID, kdfID, aeadID)
   189  
   190  	kdfInit, ok := SupportedKDFs[kdfID]
   191  	if !ok {
   192  		return nil, errors.New("unsupported KDF id")
   193  	}
   194  	kdf := kdfInit()
   195  
   196  	aeadInfo, ok := SupportedAEADs[aeadID]
   197  	if !ok {
   198  		return nil, errors.New("unsupported AEAD id")
   199  	}
   200  
   201  	pskIDHash, err := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  	infoHash, err := kdf.LabeledExtract(sid, nil, "info_hash", info)
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  	ksContext := append([]byte{0}, pskIDHash...)
   210  	ksContext = append(ksContext, infoHash...)
   211  
   212  	secret, err := kdf.LabeledExtract(sid, sharedSecret, "secret", nil)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	key, err := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
   217  	if err != nil {
   218  		return nil, err
   219  	}
   220  	baseNonce, err := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
   221  	if err != nil {
   222  		return nil, err
   223  	}
   224  	exporterSecret, err := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  
   229  	aead, err := aeadInfo.aead(key)
   230  	if err != nil {
   231  		return nil, err
   232  	}
   233  
   234  	return &context{
   235  		aead:           aead,
   236  		sharedSecret:   sharedSecret,
   237  		suiteID:        sid,
   238  		key:            key,
   239  		baseNonce:      baseNonce,
   240  		exporterSecret: exporterSecret,
   241  	}, nil
   242  }
   243  
   244  func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) {
   245  	kem, err := newDHKem(kemID)
   246  	if err != nil {
   247  		return nil, nil, err
   248  	}
   249  	sharedSecret, encapsulatedKey, err := kem.Encap(pub)
   250  	if err != nil {
   251  		return nil, nil, err
   252  	}
   253  
   254  	context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
   255  	if err != nil {
   256  		return nil, nil, err
   257  	}
   258  
   259  	return encapsulatedKey, &Sender{context}, nil
   260  }
   261  
   262  func SetupReceipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Receipient, error) {
   263  	kem, err := newDHKem(kemID)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  	sharedSecret, err := kem.Decap(encPubEph, priv)
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  
   272  	context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
   273  	if err != nil {
   274  		return nil, err
   275  	}
   276  
   277  	return &Receipient{context}, nil
   278  }
   279  
   280  func (ctx *context) nextNonce() []byte {
   281  	nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():]
   282  	for i := range ctx.baseNonce {
   283  		nonce[i] ^= ctx.baseNonce[i]
   284  	}
   285  	return nonce
   286  }
   287  
   288  func (ctx *context) incrementNonce() {
   289  	// Message limit is, according to the RFC, 2^95+1, which
   290  	// is somewhat confusing, but we do as we're told.
   291  	if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 {
   292  		panic("message limit reached")
   293  	}
   294  	ctx.seqNum = ctx.seqNum.addOne()
   295  }
   296  
   297  func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
   298  	ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
   299  	s.incrementNonce()
   300  	return ciphertext, nil
   301  }
   302  
   303  func (r *Receipient) Open(aad, ciphertext []byte) ([]byte, error) {
   304  	plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  	r.incrementNonce()
   309  	return plaintext, nil
   310  }
   311  
   312  func suiteID(kemID, kdfID, aeadID uint16) []byte {
   313  	suiteID := make([]byte, 0, 4+2+2+2)
   314  	suiteID = append(suiteID, []byte("HPKE")...)
   315  	suiteID = byteorder.BEAppendUint16(suiteID, kemID)
   316  	suiteID = byteorder.BEAppendUint16(suiteID, kdfID)
   317  	suiteID = byteorder.BEAppendUint16(suiteID, aeadID)
   318  	return suiteID
   319  }
   320  
   321  func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
   322  	kemInfo, ok := SupportedKEMs[kemID]
   323  	if !ok {
   324  		return nil, errors.New("unsupported KEM id")
   325  	}
   326  	return kemInfo.curve.NewPublicKey(bytes)
   327  }
   328  
   329  func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) {
   330  	kemInfo, ok := SupportedKEMs[kemID]
   331  	if !ok {
   332  		return nil, errors.New("unsupported KEM id")
   333  	}
   334  	return kemInfo.curve.NewPrivateKey(bytes)
   335  }
   336  
   337  type uint128 struct {
   338  	hi, lo uint64
   339  }
   340  
   341  func (u uint128) addOne() uint128 {
   342  	lo, carry := bits.Add64(u.lo, 1, 0)
   343  	return uint128{u.hi + carry, lo}
   344  }
   345  
   346  func (u uint128) bitLen() int {
   347  	return bits.Len64(u.hi) + bits.Len64(u.lo)
   348  }
   349  
   350  func (u uint128) bytes() []byte {
   351  	b := make([]byte, 16)
   352  	byteorder.BEPutUint64(b[0:], u.hi)
   353  	byteorder.BEPutUint64(b[8:], u.lo)
   354  	return b
   355  }
   356  

View as plain text