1
2
3
4
5 package tls
6
7 import (
8 "crypto/aes"
9 "crypto/cipher"
10 "crypto/hmac"
11 "crypto/sha256"
12 "crypto/subtle"
13 "crypto/x509"
14 "errors"
15 "io"
16
17 "golang.org/x/crypto/cryptobyte"
18 )
19
20
21 type SessionState struct {
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74 Extra [][]byte
75
76
77
78
79 EarlyData bool
80
81 version uint16
82 isClient bool
83 cipherSuite uint16
84
85
86
87 createdAt uint64
88 secret []byte
89 extMasterSecret bool
90 peerCertificates []*x509.Certificate
91 activeCertHandles []*activeCert
92 ocspResponse []byte
93 scts [][]byte
94 verifiedChains [][]*x509.Certificate
95 alpnProtocol string
96
97
98 useBy uint64
99 ageAdd uint32
100 ticket []byte
101
102
103 curveID CurveID
104 }
105
106
107
108
109
110
111
112 func (s *SessionState) Bytes() ([]byte, error) {
113 var b cryptobyte.Builder
114 b.AddUint16(s.version)
115 if s.isClient {
116 b.AddUint8(2)
117 } else {
118 b.AddUint8(1)
119 }
120 b.AddUint16(s.cipherSuite)
121 addUint64(&b, s.createdAt)
122 b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
123 b.AddBytes(s.secret)
124 })
125 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
126 for _, extra := range s.Extra {
127 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
128 b.AddBytes(extra)
129 })
130 }
131 })
132 if s.extMasterSecret {
133 b.AddUint8(1)
134 } else {
135 b.AddUint8(0)
136 }
137 if s.EarlyData {
138 b.AddUint8(1)
139 } else {
140 b.AddUint8(0)
141 }
142 marshalCertificate(&b, Certificate{
143 Certificate: certificatesToBytesSlice(s.peerCertificates),
144 OCSPStaple: s.ocspResponse,
145 SignedCertificateTimestamps: s.scts,
146 })
147 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
148 for _, chain := range s.verifiedChains {
149 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
150
151 if len(chain) == 0 {
152 b.SetError(errors.New("tls: internal error: empty verified chain"))
153 return
154 }
155 for _, cert := range chain[1:] {
156 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
157 b.AddBytes(cert.Raw)
158 })
159 }
160 })
161 }
162 })
163 if s.EarlyData {
164 b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
165 b.AddBytes([]byte(s.alpnProtocol))
166 })
167 }
168 if s.version >= VersionTLS13 {
169 if s.isClient {
170 addUint64(&b, s.useBy)
171 b.AddUint32(s.ageAdd)
172 }
173 } else {
174 b.AddUint16(uint16(s.curveID))
175 }
176 return b.Bytes()
177 }
178
// certificatesToBytesSlice returns the Raw encoding of every certificate in
// certs, in order. The result is always non-nil, even for an empty input, and
// shares the Raw byte slices with the certificates rather than copying them.
func certificatesToBytesSlice(certs []*x509.Certificate) [][]byte {
	raw := make([][]byte, len(certs))
	for i := range certs {
		raw[i] = certs[i].Raw
	}
	return raw
}
186
187
188 func ParseSessionState(data []byte) (*SessionState, error) {
189 ss := &SessionState{}
190 s := cryptobyte.String(data)
191 var typ, extMasterSecret, earlyData uint8
192 var cert Certificate
193 var extra cryptobyte.String
194 if !s.ReadUint16(&ss.version) ||
195 !s.ReadUint8(&typ) ||
196 !s.ReadUint16(&ss.cipherSuite) ||
197 !readUint64(&s, &ss.createdAt) ||
198 !readUint8LengthPrefixed(&s, &ss.secret) ||
199 !s.ReadUint24LengthPrefixed(&extra) ||
200 !s.ReadUint8(&extMasterSecret) ||
201 !s.ReadUint8(&earlyData) ||
202 len(ss.secret) == 0 ||
203 !unmarshalCertificate(&s, &cert) {
204 return nil, errors.New("tls: invalid session encoding")
205 }
206 for !extra.Empty() {
207 var e []byte
208 if !readUint24LengthPrefixed(&extra, &e) {
209 return nil, errors.New("tls: invalid session encoding")
210 }
211 ss.Extra = append(ss.Extra, e)
212 }
213 switch typ {
214 case 1:
215 ss.isClient = false
216 case 2:
217 ss.isClient = true
218 default:
219 return nil, errors.New("tls: unknown session encoding")
220 }
221 switch extMasterSecret {
222 case 0:
223 ss.extMasterSecret = false
224 case 1:
225 ss.extMasterSecret = true
226 default:
227 return nil, errors.New("tls: invalid session encoding")
228 }
229 switch earlyData {
230 case 0:
231 ss.EarlyData = false
232 case 1:
233 ss.EarlyData = true
234 default:
235 return nil, errors.New("tls: invalid session encoding")
236 }
237 for _, cert := range cert.Certificate {
238 c, err := globalCertCache.newCert(cert)
239 if err != nil {
240 return nil, err
241 }
242 ss.activeCertHandles = append(ss.activeCertHandles, c)
243 ss.peerCertificates = append(ss.peerCertificates, c.cert)
244 }
245 if ss.isClient && len(ss.peerCertificates) == 0 {
246 return nil, errors.New("tls: no server certificates in client session")
247 }
248 ss.ocspResponse = cert.OCSPStaple
249 ss.scts = cert.SignedCertificateTimestamps
250 var chainList cryptobyte.String
251 if !s.ReadUint24LengthPrefixed(&chainList) {
252 return nil, errors.New("tls: invalid session encoding")
253 }
254 for !chainList.Empty() {
255 var certList cryptobyte.String
256 if !chainList.ReadUint24LengthPrefixed(&certList) {
257 return nil, errors.New("tls: invalid session encoding")
258 }
259 var chain []*x509.Certificate
260 if len(ss.peerCertificates) == 0 {
261 return nil, errors.New("tls: invalid session encoding")
262 }
263 chain = append(chain, ss.peerCertificates[0])
264 for !certList.Empty() {
265 var cert []byte
266 if !readUint24LengthPrefixed(&certList, &cert) {
267 return nil, errors.New("tls: invalid session encoding")
268 }
269 c, err := globalCertCache.newCert(cert)
270 if err != nil {
271 return nil, err
272 }
273 ss.activeCertHandles = append(ss.activeCertHandles, c)
274 chain = append(chain, c.cert)
275 }
276 ss.verifiedChains = append(ss.verifiedChains, chain)
277 }
278 if ss.EarlyData {
279 var alpn []byte
280 if !readUint8LengthPrefixed(&s, &alpn) {
281 return nil, errors.New("tls: invalid session encoding")
282 }
283 ss.alpnProtocol = string(alpn)
284 }
285 if ss.version >= VersionTLS13 {
286 if ss.isClient {
287 if !s.ReadUint64(&ss.useBy) || !s.ReadUint32(&ss.ageAdd) {
288 return nil, errors.New("tls: invalid session encoding")
289 }
290 }
291 } else {
292 if !s.ReadUint16((*uint16)(&ss.curveID)) {
293 return nil, errors.New("tls: invalid session encoding")
294 }
295 }
296 return ss, nil
297 }
298
299
300
301 func (c *Conn) sessionState() *SessionState {
302 return &SessionState{
303 version: c.vers,
304 cipherSuite: c.cipherSuite,
305 createdAt: uint64(c.config.time().Unix()),
306 alpnProtocol: c.clientProtocol,
307 peerCertificates: c.peerCertificates,
308 activeCertHandles: c.activeCertHandles,
309 ocspResponse: c.ocspResponse,
310 scts: c.scts,
311 isClient: c.isClient,
312 extMasterSecret: c.extMasterSecret,
313 verifiedChains: c.verifiedChains,
314 curveID: c.curveID,
315 }
316 }
317
318
319
320 func (c *Config) EncryptTicket(cs ConnectionState, ss *SessionState) ([]byte, error) {
321 ticketKeys := c.ticketKeys(nil)
322 stateBytes, err := ss.Bytes()
323 if err != nil {
324 return nil, err
325 }
326 return c.encryptTicket(stateBytes, ticketKeys)
327 }
328
329 func (c *Config) encryptTicket(state []byte, ticketKeys []ticketKey) ([]byte, error) {
330 if len(ticketKeys) == 0 {
331 return nil, errors.New("tls: internal error: session ticket keys unavailable")
332 }
333
334 encrypted := make([]byte, aes.BlockSize+len(state)+sha256.Size)
335 iv := encrypted[:aes.BlockSize]
336 ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
337 authenticated := encrypted[:len(encrypted)-sha256.Size]
338 macBytes := encrypted[len(encrypted)-sha256.Size:]
339
340 if _, err := io.ReadFull(c.rand(), iv); err != nil {
341 return nil, err
342 }
343 key := ticketKeys[0]
344 block, err := aes.NewCipher(key.aesKey[:])
345 if err != nil {
346 return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
347 }
348 cipher.NewCTR(block, iv).XORKeyStream(ciphertext, state)
349
350 mac := hmac.New(sha256.New, key.hmacKey[:])
351 mac.Write(authenticated)
352 mac.Sum(macBytes[:0])
353
354 return encrypted, nil
355 }
356
357
358
359
360
361 func (c *Config) DecryptTicket(identity []byte, cs ConnectionState) (*SessionState, error) {
362 ticketKeys := c.ticketKeys(nil)
363 stateBytes := c.decryptTicket(identity, ticketKeys)
364 if stateBytes == nil {
365 return nil, nil
366 }
367 s, err := ParseSessionState(stateBytes)
368 if err != nil {
369 return nil, nil
370 }
371 return s, nil
372 }
373
374 func (c *Config) decryptTicket(encrypted []byte, ticketKeys []ticketKey) []byte {
375 if len(encrypted) < aes.BlockSize+sha256.Size {
376 return nil
377 }
378
379 iv := encrypted[:aes.BlockSize]
380 ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
381 authenticated := encrypted[:len(encrypted)-sha256.Size]
382 macBytes := encrypted[len(encrypted)-sha256.Size:]
383
384 for _, key := range ticketKeys {
385 mac := hmac.New(sha256.New, key.hmacKey[:])
386 mac.Write(authenticated)
387 expected := mac.Sum(nil)
388
389 if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
390 continue
391 }
392
393 block, err := aes.NewCipher(key.aesKey[:])
394 if err != nil {
395 return nil
396 }
397 plaintext := make([]byte, len(ciphertext))
398 cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
399
400 return plaintext
401 }
402
403 return nil
404 }
405
406
407
408 type ClientSessionState struct {
409 session *SessionState
410 }
411
412
413
414
415
416
417 func (cs *ClientSessionState) ResumptionState() (ticket []byte, state *SessionState, err error) {
418 if cs == nil || cs.session == nil {
419 return nil, nil, nil
420 }
421 return cs.session.ticket, cs.session, nil
422 }
423
424
425
426
427
428
429 func NewResumptionState(ticket []byte, state *SessionState) (*ClientSessionState, error) {
430 state.ticket = ticket
431 return &ClientSessionState{
432 session: state,
433 }, nil
434 }
435
View as plain text