1
2
3
4
5 package hpke
6
7 import (
8 "crypto"
9 "crypto/aes"
10 "crypto/cipher"
11 "crypto/ecdh"
12 "crypto/hkdf"
13 "crypto/rand"
14 "errors"
15 "internal/byteorder"
16 "math/bits"
17
18 "golang.org/x/crypto/chacha20poly1305"
19 )
20
21
22
// testingOnlyGenerateKey, when non-nil, is called by dhKEM.Encap instead of
// generating a random ephemeral key. It is a test hook (presumably so tests
// can inject known ephemeral keys, e.g. for RFC 9180 test vectors) and must
// stay nil outside of tests.
var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
24
// hkdfKDF provides the HPKE labeled HKDF operations (LabeledExtract and
// LabeledExpand, RFC 9180 Section 4) on top of HKDF instantiated with hash.
type hkdfKDF struct {
	// hash is the hash function underlying HKDF (e.g. crypto.SHA256).
	hash crypto.Hash
}
28
29 func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) ([]byte, error) {
30 labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey))
31 labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
32 labeledIKM = append(labeledIKM, sid...)
33 labeledIKM = append(labeledIKM, label...)
34 labeledIKM = append(labeledIKM, inputKey...)
35 return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
36 }
37
38 func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) {
39 labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
40 labeledInfo = byteorder.BEAppendUint16(labeledInfo, length)
41 labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
42 labeledInfo = append(labeledInfo, suiteID...)
43 labeledInfo = append(labeledInfo, label...)
44 labeledInfo = append(labeledInfo, info...)
45 return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length))
46 }
47
48
// dhKEM is a Diffie-Hellman based KEM (DHKEM, RFC 9180 Section 4.1). It is
// built by newDHKem from an entry in SupportedKEMs.
type dhKEM struct {
	// dh is the curve used for the ephemeral/static Diffie-Hellman exchange.
	dh ecdh.Curve
	// kdf is the KEM's internal KDF, used by ExtractAndExpand.
	kdf hkdfKDF

	// suiteID is "KEM" followed by the 2-byte big-endian KEM identifier
	// (see newDHKem); it labels every KDF call made by this KEM.
	suiteID []byte
	// nSecret is the length in bytes of the KEM shared secret (Nsecret).
	nSecret uint16
}
56
// KemID is a KEM identifier from the HPKE KEM registry. In the code shown
// here, KEM identifiers are passed as plain uint16 values.
type KemID uint16

// DHKEM_X25519_HKDF_SHA256 is the identifier of DHKEM(X25519, HKDF-SHA256).
const DHKEM_X25519_HKDF_SHA256 = 0x0020

// SupportedKEMs maps each supported KEM identifier to its parameters:
// the DH curve, the hash used by the KEM's internal HKDF, and the length
// in bytes of the resulting shared secret (Nsecret).
var SupportedKEMs = map[uint16]struct {
	curve   ecdh.Curve
	hash    crypto.Hash
	nSecret uint16
}{
	// RFC 9180, Section 7.1
	DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32},
}
69
70 func newDHKem(kemID uint16) (*dhKEM, error) {
71 suite, ok := SupportedKEMs[kemID]
72 if !ok {
73 return nil, errors.New("unsupported suite ID")
74 }
75 return &dhKEM{
76 dh: suite.curve,
77 kdf: hkdfKDF{suite.hash},
78 suiteID: byteorder.BEAppendUint16([]byte("KEM"), kemID),
79 nSecret: suite.nSecret,
80 }, nil
81 }
82
83 func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) ([]byte, error) {
84 eaePRK, err := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
85 if err != nil {
86 return nil, err
87 }
88 return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
89 }
90
91 func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
92 var privEph *ecdh.PrivateKey
93 if testingOnlyGenerateKey != nil {
94 privEph, err = testingOnlyGenerateKey()
95 } else {
96 privEph, err = dh.dh.GenerateKey(rand.Reader)
97 }
98 if err != nil {
99 return nil, nil, err
100 }
101 dhVal, err := privEph.ECDH(pubRecipient)
102 if err != nil {
103 return nil, nil, err
104 }
105 encPubEph := privEph.PublicKey().Bytes()
106
107 encPubRecip := pubRecipient.Bytes()
108 kemContext := append(encPubEph, encPubRecip...)
109 sharedSecret, err = dh.ExtractAndExpand(dhVal, kemContext)
110 if err != nil {
111 return nil, nil, err
112 }
113 return sharedSecret, encPubEph, nil
114 }
115
116 func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
117 pubEph, err := dh.dh.NewPublicKey(encPubEph)
118 if err != nil {
119 return nil, err
120 }
121 dhVal, err := secRecipient.ECDH(pubEph)
122 if err != nil {
123 return nil, err
124 }
125 kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
126 return dh.ExtractAndExpand(dhVal, kemContext)
127 }
128
// context is the state of an HPKE encryption context, as produced by the
// key schedule in newContext. It is shared by Sender and Receipient.
type context struct {
	// aead is the AEAD keyed with key.
	aead cipher.AEAD

	// sharedSecret is the KEM shared secret the context was derived from.
	sharedSecret []byte

	// suiteID is "HPKE" || kem_id || kdf_id || aead_id (see suiteID).
	suiteID []byte

	key            []byte
	baseNonce      []byte
	exporterSecret []byte

	// seqNum counts the messages processed so far. It is combined with
	// baseNonce by nextNonce and advanced by incrementNonce after each
	// successful Seal/Open.
	seqNum uint128
}
142
// Sender is the sending side of an HPKE context, created by SetupSender.
// It seals messages with Seal.
type Sender struct {
	*context
}

// Receipient is the receiving side of an HPKE context, created by
// SetupReceipient. It opens messages with Open. (The exported name's
// spelling is kept as is for API compatibility.)
type Receipient struct {
	*context
}
150
// aesGCMNew constructs AES-GCM with the standard 12-byte nonce. The AES
// variant (AES-128 or AES-256) is selected by len(key); an invalid key
// length is reported as an error by aes.NewCipher.
var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
	aesBlock, keyErr := aes.NewCipher(key)
	if keyErr != nil {
		return nil, keyErr
	}
	return cipher.NewGCM(aesBlock)
}
158
// AEADID is an AEAD identifier from the HPKE AEAD registry. In the code
// shown here, AEAD identifiers are passed as plain uint16 values.
type AEADID uint16

// AEAD identifiers (RFC 9180, Section 7.3).
const (
	AEAD_AES_128_GCM      = 0x0001
	AEAD_AES_256_GCM      = 0x0002
	AEAD_ChaCha20Poly1305 = 0x0003
)

// SupportedAEADs maps each supported AEAD identifier to its key size (Nk)
// and nonce size (Nn) in bytes, plus a constructor. Both AES-GCM entries use
// aesGCMNew; which AES variant is used follows from the key length.
var SupportedAEADs = map[uint16]struct {
	keySize   int
	nonceSize int
	aead      func([]byte) (cipher.AEAD, error)
}{
	// RFC 9180, Section 7.3
	AEAD_AES_128_GCM:      {keySize: 16, nonceSize: 12, aead: aesGCMNew},
	AEAD_AES_256_GCM:      {keySize: 32, nonceSize: 12, aead: aesGCMNew},
	AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
}
177
// KDFID is a KDF identifier from the HPKE KDF registry. In the code shown
// here, KDF identifiers are passed as plain uint16 values.
type KDFID uint16

// KDF_HKDF_SHA256 is the identifier of HKDF-SHA256.
const KDF_HKDF_SHA256 = 0x0001

// SupportedKDFs maps each supported KDF identifier to a constructor that
// returns a fresh hkdfKDF for it.
var SupportedKDFs = map[uint16]func() *hkdfKDF{
	// RFC 9180, Section 7.2
	KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
}
186
187 func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) {
188 sid := suiteID(kemID, kdfID, aeadID)
189
190 kdfInit, ok := SupportedKDFs[kdfID]
191 if !ok {
192 return nil, errors.New("unsupported KDF id")
193 }
194 kdf := kdfInit()
195
196 aeadInfo, ok := SupportedAEADs[aeadID]
197 if !ok {
198 return nil, errors.New("unsupported AEAD id")
199 }
200
201 pskIDHash, err := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
202 if err != nil {
203 return nil, err
204 }
205 infoHash, err := kdf.LabeledExtract(sid, nil, "info_hash", info)
206 if err != nil {
207 return nil, err
208 }
209 ksContext := append([]byte{0}, pskIDHash...)
210 ksContext = append(ksContext, infoHash...)
211
212 secret, err := kdf.LabeledExtract(sid, sharedSecret, "secret", nil)
213 if err != nil {
214 return nil, err
215 }
216 key, err := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) )
217 if err != nil {
218 return nil, err
219 }
220 baseNonce, err := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) )
221 if err != nil {
222 return nil, err
223 }
224 exporterSecret, err := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) )
225 if err != nil {
226 return nil, err
227 }
228
229 aead, err := aeadInfo.aead(key)
230 if err != nil {
231 return nil, err
232 }
233
234 return &context{
235 aead: aead,
236 sharedSecret: sharedSecret,
237 suiteID: sid,
238 key: key,
239 baseNonce: baseNonce,
240 exporterSecret: exporterSecret,
241 }, nil
242 }
243
244 func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) {
245 kem, err := newDHKem(kemID)
246 if err != nil {
247 return nil, nil, err
248 }
249 sharedSecret, encapsulatedKey, err := kem.Encap(pub)
250 if err != nil {
251 return nil, nil, err
252 }
253
254 context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
255 if err != nil {
256 return nil, nil, err
257 }
258
259 return encapsulatedKey, &Sender{context}, nil
260 }
261
262 func SetupReceipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Receipient, error) {
263 kem, err := newDHKem(kemID)
264 if err != nil {
265 return nil, err
266 }
267 sharedSecret, err := kem.Decap(encPubEph, priv)
268 if err != nil {
269 return nil, err
270 }
271
272 context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
273 if err != nil {
274 return nil, err
275 }
276
277 return &Receipient{context}, nil
278 }
279
280 func (ctx *context) nextNonce() []byte {
281 nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():]
282 for i := range ctx.baseNonce {
283 nonce[i] ^= ctx.baseNonce[i]
284 }
285 return nonce
286 }
287
288 func (ctx *context) incrementNonce() {
289
290
291 if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 {
292 panic("message limit reached")
293 }
294 ctx.seqNum = ctx.seqNum.addOne()
295 }
296
297 func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
298 ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
299 s.incrementNonce()
300 return ciphertext, nil
301 }
302
303 func (r *Receipient) Open(aad, ciphertext []byte) ([]byte, error) {
304 plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
305 if err != nil {
306 return nil, err
307 }
308 r.incrementNonce()
309 return plaintext, nil
310 }
311
312 func suiteID(kemID, kdfID, aeadID uint16) []byte {
313 suiteID := make([]byte, 0, 4+2+2+2)
314 suiteID = append(suiteID, []byte("HPKE")...)
315 suiteID = byteorder.BEAppendUint16(suiteID, kemID)
316 suiteID = byteorder.BEAppendUint16(suiteID, kdfID)
317 suiteID = byteorder.BEAppendUint16(suiteID, aeadID)
318 return suiteID
319 }
320
321 func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
322 kemInfo, ok := SupportedKEMs[kemID]
323 if !ok {
324 return nil, errors.New("unsupported KEM id")
325 }
326 return kemInfo.curve.NewPublicKey(bytes)
327 }
328
329 func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) {
330 kemInfo, ok := SupportedKEMs[kemID]
331 if !ok {
332 return nil, errors.New("unsupported KEM id")
333 }
334 return kemInfo.curve.NewPrivateKey(bytes)
335 }
336
// uint128 is an unsigned 128-bit integer stored as high and low 64-bit
// words. It holds the HPKE message sequence number.
type uint128 struct {
	hi, lo uint64
}
340
341 func (u uint128) addOne() uint128 {
342 lo, carry := bits.Add64(u.lo, 1, 0)
343 return uint128{u.hi + carry, lo}
344 }
345
346 func (u uint128) bitLen() int {
347 return bits.Len64(u.hi) + bits.Len64(u.lo)
348 }
349
350 func (u uint128) bytes() []byte {
351 b := make([]byte, 16)
352 byteorder.BEPutUint64(b[0:], u.hi)
353 byteorder.BEPutUint64(b[8:], u.lo)
354 return b
355 }
356
View as plain text