1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "internal/godebug"
19 "io"
20 "net"
21 "sync"
22 "sync/atomic"
23 "time"
24 )
25
26
27
// A Conn represents a secured connection layered on top of an underlying
// network connection. It provides the full set of net.Conn methods (Read,
// Write, Close, LocalAddr, RemoteAddr, SetDeadline, SetReadDeadline and
// SetWriteDeadline).
type Conn struct {
	conn        net.Conn                    // underlying transport that TLS records are read from and written to
	isClient    bool                        // this side is the TLS client; handleRenegotiation refuses otherwise
	handshakeFn func(context.Context) error // performs the handshake; invoked by handshakeContext
	quic        *quicState                  // non-nil when running over QUIC instead of the TLS record layer

	// isHandshakeComplete is true once a handshake has succeeded. It is
	// loaded atomically (e.g. by the handshakeContext fast path) without
	// holding handshakeMutex; handleRenegotiation resets it to false.
	isHandshakeComplete atomic.Bool

	// handshakeMutex serializes handshakes: handshakeContext and
	// handleRenegotiation hold it while running one.
	handshakeMutex sync.Mutex
	handshakeErr   error   // error from the last handshake attempt; returned to later callers
	vers           uint16  // TLS version in use; may be 0 before negotiation (see writeRecordLocked)
	haveVers       bool    // vers has been negotiated; enables record-version checks in readRecordOrCCS
	config         *Config // configuration (Renegotiation, DynamicRecordSizingDisabled, rand(), ...)

	// NOTE(review): of the negotiated-state fields below, only handshakes
	// (count of successful handshakes) and cipherSuite (negotiated suite
	// ID, used for TLS 1.3 key updates) are referenced in this section of
	// the file; the others are presumably populated by the handshake code.
	handshakes       int
	extMasterSecret  bool
	didResume        bool
	didHRR           bool
	cipherSuite      uint16
	curveID          CurveID
	ocspResponse     []byte
	scts             [][]byte
	peerCertificates []*x509.Certificate

	activeCertHandles []*activeCert

	verifiedChains [][]*x509.Certificate

	serverName string

	secureRenegotiation bool

	ekm func(label string, context []byte, length int) ([]byte, error)

	resumptionSecret []byte
	echAccepted      bool

	ticketKeys []ticketKey

	clientFinishedIsFirst bool

	// closeNotifyErr is the result of the one attempt closeNotify makes to
	// send a close_notify alert; it is returned on every later call.
	closeNotifyErr error

	// closeNotifySent is set once closeNotify has run; Write then fails
	// with errShutdown.
	closeNotifySent bool

	clientFinished [12]byte
	serverFinished [12]byte

	clientProtocol string

	// in and out hold the record-layer state for reading and writing;
	// each is guarded by its own embedded mutex.
	in, out   halfConn
	rawInput  bytes.Buffer // raw bytes read from conn, not yet consumed as records
	input     bytes.Reader // decrypted application data of the current record; aliases rawInput memory
	hand      bytes.Buffer // buffered handshake message bytes awaiting parsing
	buffering bool         // when true, write appends to sendBuf instead of writing to conn
	sendBuf   []byte       // data queued while buffering; written out by flush

	// Counters for dynamic record sizing (see maxPayloadSizeForWrite).
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records/messages that did not advance
	// the connection state; exceeding maxUselessRecords is fatal.
	retryCount int

	// activeCall interlocks Write and Close: Close sets bit 0, and each
	// in-flight Write adds 2 while it runs.
	activeCall atomic.Int32

	tmp [16]byte // scratch space; sendAlertLocked uses tmp[0:2] for an alert record
}
127
128
129
130
131
132
133 func (c *Conn) LocalAddr() net.Addr {
134 return c.conn.LocalAddr()
135 }
136
137
138 func (c *Conn) RemoteAddr() net.Addr {
139 return c.conn.RemoteAddr()
140 }
141
142
143
144
145 func (c *Conn) SetDeadline(t time.Time) error {
146 return c.conn.SetDeadline(t)
147 }
148
149
150
151 func (c *Conn) SetReadDeadline(t time.Time) error {
152 return c.conn.SetReadDeadline(t)
153 }
154
155
156
157
158 func (c *Conn) SetWriteDeadline(t time.Time) error {
159 return c.conn.SetWriteDeadline(t)
160 }
161
162
163
164
165 func (c *Conn) NetConn() net.Conn {
166 return c.conn
167 }
168
169
170
// A halfConn holds the record-layer protection state for one direction
// (reading or writing) of a Conn. The embedded mutex guards it.
type halfConn struct {
	sync.Mutex

	err     error       // sticky error; readRecordOrCCS (in) and Write (out) return it once set
	version uint16      // protocol version used for record protection
	cipher  any         // cipher.Stream, aead or cbcMode; nil means records are unprotected
	mac     hash.Hash   // record MAC; nil when no separate MAC is used (decrypt checks for nil)
	seq     [8]byte     // 64-bit big-endian record sequence number (see incSeq)

	scratchBuf [13]byte // reusable buffer for additional data / MAC input (8-byte seq + 5-byte header)

	nextCipher any       // pending cipher, activated by changeCipherSpec
	nextMac    hash.Hash // pending MAC, activated by changeCipherSpec

	level         QUICEncryptionLevel // encryption level recorded by setTrafficSecret
	trafficSecret []byte              // current TLS 1.3 traffic secret, ratcheted on key update
}
188
// permanentError wraps a net.Error so that it never reports itself as
// temporary. Error, Unwrap and Timeout are forwarded to the wrapped error.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message unchanged.
func (e *permanentError) Error() string {
	inner := e.err
	return inner.Error()
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e *permanentError) Unwrap() error {
	var inner error = e.err
	return inner
}

// Timeout reports whether the wrapped error was a timeout.
func (e *permanentError) Timeout() bool {
	inner := e.err
	return inner.Timeout()
}

// Temporary always reports false: the error is permanent by construction.
func (e *permanentError) Temporary() bool {
	const temporary = false
	return temporary
}
197
198 func (hc *halfConn) setErrorLocked(err error) error {
199 if e, ok := err.(net.Error); ok {
200 hc.err = &permanentError{err: e}
201 } else {
202 hc.err = err
203 }
204 return hc.err
205 }
206
207
208
209 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
210 hc.version = version
211 hc.nextCipher = cipher
212 hc.nextMac = mac
213 }
214
215
216
217 func (hc *halfConn) changeCipherSpec() error {
218 if hc.nextCipher == nil || hc.version == VersionTLS13 {
219 return alertInternalError
220 }
221 hc.cipher = hc.nextCipher
222 hc.mac = hc.nextMac
223 hc.nextCipher = nil
224 hc.nextMac = nil
225 for i := range hc.seq {
226 hc.seq[i] = 0
227 }
228 return nil
229 }
230
231 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
232 hc.trafficSecret = secret
233 hc.level = level
234 key, iv := suite.trafficKey(secret)
235 hc.cipher = suite.aead(key, iv)
236 for i := range hc.seq {
237 hc.seq[i] = 0
238 }
239 }
240
241
242 func (hc *halfConn) incSeq() {
243 for i := 7; i >= 0; i-- {
244 hc.seq[i]++
245 if hc.seq[i] != 0 {
246 return
247 }
248 }
249
250
251
252
253 panic("TLS: sequence number wraparound")
254 }
255
256
257
258
259 func (hc *halfConn) explicitNonceLen() int {
260 if hc.cipher == nil {
261 return 0
262 }
263
264 switch c := hc.cipher.(type) {
265 case cipher.Stream:
266 return 0
267 case aead:
268 return c.explicitNonceLen()
269 case cbcMode:
270
271 if hc.version >= VersionTLS11 {
272 return c.BlockSize()
273 }
274 return 0
275 default:
276 panic("unknown cipher type")
277 }
278 }
279
280
281
282
// extractPadding inspects the CBC padding at the end of payload without
// branching on secret data. It returns the number of bytes to remove
// (padding plus the length byte) and good, which is 255 if the padding is
// well formed and 0 otherwise. When the padding is bad, the padding length
// is masked to zero, so toRemove is 1. An empty payload yields (0, 0).
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// If len(payload)-1 >= paddingLen, bit 31 of t is zero and good becomes
	// 0xff; if the subtraction underflowed, good becomes 0.
	good = byte(int32(^t) >> 31)

	// The maximum possible padding length (255) plus the length byte.
	toCheck := 256
	// The length of the padded data is public, so a branch is fine here.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff when i <= paddingLen (this byte is padding), else 0.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear any bit of good where a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all eight bits of good together into bit 7, then replicate that
	// bit across the whole byte, so good ends up either 0xff or 0x00.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// On bad padding, use a padding length of zero so the caller still
	// performs a MAC check over a plausible length rather than branching
	// early on the padding result.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
329
// roundUp returns a rounded up to the next multiple of b (a itself when it
// is already a multiple). b must be non-zero.
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
333
334
// cbcMode is a cipher.BlockMode whose IV can be replaced. decrypt and
// encrypt call SetIV with the per-record explicit IV when one is present
// (TLS 1.1 and later; see explicitNonceLen).
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
339
340
341
// decrypt removes record protection from record, the full record including
// its recordHeaderLen-byte header. It returns the plaintext and the record
// type (for TLS 1.3, the inner content type).
//
// The payload is decrypted in place, and the header length bytes
// (record[3:5]) may be overwritten. On failure decrypt returns an alert:
// alertBadRecordMAC, alertUnexpectedMessage or alertRecordOverflow. The
// sequence number is incremented only on success.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3, change_cipher_spec records are never protected: return
	// them as-is without consuming a sequence number.
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			// An explicit nonce, if any, prefixes the ciphertext; otherwise
			// the sequence number serves as the nonce.
			nonce := payload[:explicitNonceLen]
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			// TLS 1.3 authenticates the record header as-is. Earlier
			// versions authenticate seq || type || version || length, where
			// length is the plaintext length.
			var additionalData []byte
			if hc.version == VersionTLS13 {
				additionalData = record[:recordHeaderLen]
			} else {
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// Smallest valid payload: explicit IV plus a MAC and at least
			// one padding byte, rounded up to whole blocks.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// The padding check runs in constant time, and its result is
			// folded into the MAC comparison below rather than reported
			// separately. Bad padding and a bad MAC therefore produce the
			// same alert.
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			// Every protected TLS 1.3 record has outer type application_data.
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			// maxPlaintext plus one byte of inner content type.
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Strip trailing zero padding. The last non-zero byte is the
			// real content type. A non-empty plaintext made only of zeros
			// has no content type and is rejected.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		// n is the plaintext length, clamped to 0 without branching
		// (ConstantTimeSelect picks 0 when the sign bit of n is set).
		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC covers a header carrying the plaintext length, so rewrite
		// the length bytes in place.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		// The bytes after the MAC (the padding) are passed to tls10MAC as
		// extra data. NOTE(review): presumably this keeps the MAC cost
		// independent of the secret padding length; confirm in tls10MAC.
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// Combine MAC equality and padding validity with constant-time
		// operations so both failure modes look identical.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
465
466
467
468
// sliceForAppend extends in by n bytes. head is the extended slice, reusing
// in's backing array when its capacity allows and a fresh copy otherwise.
// tail is the final n bytes of head, ready to be filled.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	return head, head[len(in):]
}
479
480
481
// encrypt appends the protected form of payload to record and returns the
// updated record. record must already hold the recordHeaderLen-byte
// header. Its length field is rewritten to cover any explicit nonce, MAC,
// padding and AEAD overhead. Without an active cipher, payload is appended
// unchanged and the sequence number is not advanced. rand supplies random
// explicit IVs where required; its read error is returned.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// AEADs with a short explicit nonce (e.g. TLS 1.2 AES-GCM, 8
			// bytes) use the sequence number as the nonce: 8 bytes is too
			// short to pick safely at random. CBC explicit IVs must be
			// unpredictable instead (RFC 5246, Appendix F.3), so they are
			// read from rand.
			copy(explicitNonce, hc.seq[:])
		} else {
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Append the real content type as the encrypted inner type and
			// set the outer type to application_data.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so its length must be final
			// (plaintext + inner type + AEAD overhead) before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Pre-1.3 additional data: seq || header (with plaintext length).
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		// Always at least one byte of padding; each padding byte (including
		// the final length byte) holds paddingLen-1.
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update length to include nonce, MAC and any block padding needed.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
566
567
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
	// Msg describes the problem with the record header.
	Msg string

	// RecordHeader holds the first bytes of the buffered raw input at the
	// time of the error (see newRecordHeaderError), i.e. the offending
	// record header.
	RecordHeader [5]byte

	// Conn is the underlying net.Conn. In this section of the file,
	// readRecordOrCCS sets it only when the first record does not look
	// like TLS (so the caller can, for instance, reply in another
	// protocol); otherwise it is nil.
	Conn net.Conn
}

// Error returns Msg prefixed with "tls: ".
func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
582
583 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
584 err.Msg = msg
585 err.Conn = conn
586 copy(err.RecordHeader[:], c.rawInput.Bytes())
587 return err
588 }
589
590 func (c *Conn) readRecord() error {
591 return c.readRecordOrCCS(false)
592 }
593
594 func (c *Conn) readChangeCipherSpec() error {
595 return c.readRecordOrCCS(true)
596 }
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
// readRecordOrCCS reads one record from the underlying connection,
// decrypts it and dispatches on its type:
//   - handshake data is appended to c.hand;
//   - application data is exposed through c.input;
//   - alerts are turned into errors or skipped;
//   - ChangeCipherSpec activates the pending read cipher, and is valid only
//     when expectChangeCipherSpec is true (TLS 1.3 CCS records are skipped).
//
// Non-temporary failures are made sticky on c.in, and most send an alert.
// The caller must hold c.in and must have drained c.input.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.isHandshakeComplete.Load()

	// This function modifies c.rawInput, which owns the c.input memory.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	if c.quic != nil {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
	}

	// Read header, payload.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// An EOF exactly at a record boundary (nothing buffered) is
		// reported as a clean io.EOF even without close_notify. Many peers
		// close this way, although RFC 8446 treats it as an error.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are returned without becoming sticky,
		// so the caller may retry.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No valid TLS record has type 0x80, but an SSLv2 ClientHello starts
	// with a two-byte length whose high bit is set and whose first byte is
	// 0x80 for short hellos, so treat it as an SSLv2 handshake attempt.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	expectedVers := c.vers
	if expectedVers == VersionTLS13 {
		// TLS 1.3 records carry the TLS 1.2 version number in their
		// header (RFC 8446, Section 5.1).
		expectedVers = VersionTLS12
	}
	n := int(hdr[3])<<8 | int(hdr[4])
	if c.haveVers && vers != expectedVers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First record: be suspicious, since the peer may not speak TLS at
		// all. Bail out before reading a full body when the type is not
		// alert/handshake or the version is implausibly large.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process message. record aliases c.rawInput's buffer and is decrypted
	// in place.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data must always be protected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// This is a state-advancing message: reset the retry count.
		c.retryCount = 0
	}

	// Handshake messages must not be interleaved with other record types
	// in TLS 1.3.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if c.quic != nil {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		if c.vers == VersionTLS13 {
			// TLS 1.3 has no warning-level alerts apart from user_canceled
			// (RFC 8446, Section 6.1). user_canceled is skipped like a TLS
			// 1.2 warning; any other alert is fatal regardless of its level.
			if alert(data[1]) == alertUserCanceled {
				// Drop the record and read the next one (bounded by retryCount).
				return c.retryReadRecord(expectChangeCipherSpec)
			}
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the record and read the next one (bounded by retryCount).
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// Handshake messages must not be fragmented across a CCS.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, change_cipher_spec records are ignored (RFC 8446,
		// Appendix D.4); they still count towards the retry limit.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Empty application-data records are skipped, up to the retry
		// limit.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data aliases c.rawInput (via Next above), so it is exposed
		// without copying. This is safe because readRecordOrCCS refuses to
		// touch c.rawInput until c.input is drained (see the check at the
		// top).
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
794
795
796
797 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
798 c.retryCount++
799 if c.retryCount > maxUselessRecords {
800 c.sendAlert(alertUnexpectedMessage)
801 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
802 }
803 return c.readRecordOrCCS(expectChangeCipherSpec)
804 }
805
806
807
808
// atLeastReader reads from R and returns io.EOF as soon as at least N bytes
// have been read in total. If R hits io.EOF before N bytes arrive, it
// returns io.ErrUnexpectedEOF instead. It is meant for bytes.Buffer.ReadFrom,
// which stops at io.EOF.
type atLeastReader struct {
	R io.Reader
	N int64
}

func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	read, err := r.R.Read(p)
	r.N -= int64(read)
	stillNeeded := r.N > 0
	switch {
	case stillNeeded && err == io.EOF:
		return read, io.ErrUnexpectedEOF
	case !stillNeeded && err == nil:
		return read, io.EOF
	default:
		return read, err
	}
}
828
829
830
831 func (c *Conn) readFromUntil(r io.Reader, n int) error {
832 if c.rawInput.Len() >= n {
833 return nil
834 }
835 needs := n - c.rawInput.Len()
836
837
838
839 c.rawInput.Grow(needs + bytes.MinRead)
840 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
841 return err
842 }
843
844
845 func (c *Conn) sendAlertLocked(err alert) error {
846 if c.quic != nil {
847 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
848 }
849
850 switch err {
851 case alertNoRenegotiation, alertCloseNotify:
852 c.tmp[0] = alertLevelWarning
853 default:
854 c.tmp[0] = alertLevelError
855 }
856 c.tmp[1] = byte(err)
857
858 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
859 if err == alertCloseNotify {
860
861 return writeErr
862 }
863
864 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
865 }
866
867
868 func (c *Conn) sendAlert(err alert) error {
869 c.out.Lock()
870 defer c.out.Unlock()
871 return c.sendAlertLocked(err)
872 }
873
const (
	// tcpMSSEstimate is an estimate of the TCP maximum segment size. While
	// dynamic record sizing is active, maxPayloadSizeForWrite sizes the
	// first record so that, with overhead, it fits in roughly this many
	// bytes. Later records grow by about this amount each.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes sent after which
	// maxPayloadSizeForWrite stops limiting records and always returns
	// maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904 func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
905 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
906 return maxPlaintext
907 }
908
909 if c.bytesSent >= recordSizeBoostThreshold {
910 return maxPlaintext
911 }
912
913
914 payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
915 if c.out.cipher != nil {
916 switch ciph := c.out.cipher.(type) {
917 case cipher.Stream:
918 payloadBytes -= c.out.mac.Size()
919 case cipher.AEAD:
920 payloadBytes -= ciph.Overhead()
921 case cbcMode:
922 blockSize := ciph.BlockSize()
923
924
925 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
926
927
928 payloadBytes -= c.out.mac.Size()
929 default:
930 panic("unknown cipher type")
931 }
932 }
933 if c.vers == VersionTLS13 {
934 payloadBytes--
935 }
936
937
938 pkt := c.packetsSent
939 c.packetsSent++
940 if pkt > 1000 {
941 return maxPlaintext
942 }
943
944 n := payloadBytes * int(pkt+1)
945 if n > maxPlaintext {
946 n = maxPlaintext
947 }
948 return n
949 }
950
951 func (c *Conn) write(data []byte) (int, error) {
952 if c.buffering {
953 c.sendBuf = append(c.sendBuf, data...)
954 return len(data), nil
955 }
956
957 n, err := c.conn.Write(data)
958 c.bytesSent += int64(n)
959 return n, err
960 }
961
962 func (c *Conn) flush() (int, error) {
963 if len(c.sendBuf) == 0 {
964 return 0, nil
965 }
966
967 n, err := c.conn.Write(c.sendBuf)
968 c.bytesSent += int64(n)
969 c.sendBuf = nil
970 c.buffering = false
971 return n, err
972 }
973
974
// outBufPool recycles the scratch buffers that writeRecordLocked uses to
// assemble outgoing records. It stores *[]byte rather than []byte so the
// slice header can be reused without re-boxing. New returns a pointer to an
// empty (nil) slice.
var outBufPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
980
981
982
// writeRecordLocked writes data as one or more records of type typ. It
// splits data into chunks of at most maxPayloadSizeForWrite(typ) bytes and
// encrypts each chunk with c.out. It returns the number of bytes of data
// written. After writing a ChangeCipherSpec (pre-TLS 1.3), it activates the
// pending write cipher.
//
// On QUIC connections, only handshake data is accepted. It is handed to
// the QUIC layer at the current write level, and flushed unless buffering
// is on. The caller must hold c.out.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	if c.quic != nil {
		if typ != recordTypeHandshake {
			return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
		}
		c.quicWriteCryptoData(c.out.level, data)
		if !c.buffering {
			if _, err := c.flush(); err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// Store the (possibly grown) buffer back through the pointer
		// obtained from Get, rather than calling Put(&outBuf).
		// NOTE(review): presumably this keeps the local slice header from
		// escaping to the heap; verify with -gcflags=-m before changing.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// Before a version is negotiated, advertise TLS 1.0 in the
			// record header; some servers reject a higher record version
			// on the initial ClientHello.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 froze the record-layer version at TLS 1.2
			// (RFC 8446, Section 5.1).
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		// Plaintext length; encrypt rewrites it to the final length.
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
1053
1054
1055
1056
1057 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
1058 c.out.Lock()
1059 defer c.out.Unlock()
1060
1061 data, err := msg.marshal()
1062 if err != nil {
1063 return 0, err
1064 }
1065 if transcript != nil {
1066 transcript.Write(data)
1067 }
1068
1069 return c.writeRecordLocked(recordTypeHandshake, data)
1070 }
1071
1072
1073
1074 func (c *Conn) writeChangeCipherRecord() error {
1075 c.out.Lock()
1076 defer c.out.Unlock()
1077 _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
1078 return err
1079 }
1080
1081
1082 func (c *Conn) readHandshakeBytes(n int) error {
1083 if c.quic != nil {
1084 return c.quicReadHandshakeBytes(n)
1085 }
1086 for c.hand.Len() < n {
1087 if err := c.readRecord(); err != nil {
1088 return err
1089 }
1090 }
1091 return nil
1092 }
1093
1094
1095
1096
1097 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
1098 if err := c.readHandshakeBytes(4); err != nil {
1099 return nil, err
1100 }
1101 data := c.hand.Bytes()
1102
1103 maxHandshakeSize := maxHandshake
1104
1105
1106
1107 if c.haveVers && data[0] == typeCertificate {
1108
1109
1110
1111 maxHandshakeSize = maxHandshakeCertificateMsg
1112 }
1113
1114 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1115 if n > maxHandshakeSize {
1116 c.sendAlertLocked(alertInternalError)
1117 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
1118 }
1119 if err := c.readHandshakeBytes(4 + n); err != nil {
1120 return nil, err
1121 }
1122 data = c.hand.Next(4 + n)
1123 return c.unmarshalHandshakeMessage(data, transcript)
1124 }
1125
// unmarshalHandshakeMessage parses data, a complete handshake message
// including its 4-byte header, into the message type selected by data[0].
// For some types the choice also depends on c.vers. If transcript is
// non-nil, the message bytes are written to it after a successful parse.
// Unknown types and parse failures send unexpected_message and make the
// error sticky on c.in.
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
	var m handshakeMessage
	switch data[0] {
	case typeHelloRequest:
		m = new(helloRequestMsg)
	case typeClientHello:
		m = new(clientHelloMsg)
	case typeServerHello:
		m = new(serverHelloMsg)
	case typeNewSessionTicket:
		if c.vers == VersionTLS13 {
			m = new(newSessionTicketMsgTLS13)
		} else {
			m = new(newSessionTicketMsg)
		}
	case typeCertificate:
		if c.vers == VersionTLS13 {
			m = new(certificateMsgTLS13)
		} else {
			m = new(certificateMsg)
		}
	case typeCertificateRequest:
		if c.vers == VersionTLS13 {
			m = new(certificateRequestMsgTLS13)
		} else {
			m = &certificateRequestMsg{
				hasSignatureAlgorithm: c.vers >= VersionTLS12,
			}
		}
	case typeCertificateStatus:
		m = new(certificateStatusMsg)
	case typeServerKeyExchange:
		m = new(serverKeyExchangeMsg)
	case typeServerHelloDone:
		m = new(serverHelloDoneMsg)
	case typeClientKeyExchange:
		m = new(clientKeyExchangeMsg)
	case typeCertificateVerify:
		m = &certificateVerifyMsg{
			hasSignatureAlgorithm: c.vers >= VersionTLS12,
		}
	case typeFinished:
		m = new(finishedMsg)
	case typeEncryptedExtensions:
		m = new(encryptedExtensionsMsg)
	case typeEndOfEarlyData:
		m = new(endOfEarlyDataMsg)
	case typeKeyUpdate:
		m = new(keyUpdateMsg)
	default:
		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	// data may alias c.hand's internal buffer (readHandshake obtains it via
	// c.hand.Next), which is reused for later reads. Copy it so the parsed
	// message does not see it change. NOTE(review): presumably
	// unmarshal keeps sub-slices of its input.
	data = append([]byte(nil), data...)

	if !m.unmarshal(data) {
		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if transcript != nil {
		transcript.Write(data)
	}

	return m, nil
}
1194
// errShutdown is returned by Write once a close_notify alert has been sent.
var errShutdown = errors.New("tls: protocol is shutdown")
1198
1199
1200
1201
1202
1203
1204
// Write writes b as application data, running the handshake first if it
// has not completed. It fails with net.ErrClosed after Close, and with
// errShutdown after a close_notify has been sent. A write failure is made
// sticky on c.out, so later Writes return it.
func (c *Conn) Write(b []byte) (int, error) {
	// Interlock with Close: register this call by adding 2 to activeCall,
	// unless Close has already set bit 0.
	for {
		x := c.activeCall.Load()
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x+2) {
			break
		}
	}
	defer c.activeCall.Add(-2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.isHandshakeComplete.Load() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// With TLS 1.0 and a block cipher, send the first byte in its own
	// record and the rest in a second one (1/n-1 record splitting). This
	// mitigates the chosen-plaintext attack on TLS 1.0's predictable CBC
	// IVs (BEAST), because the first record's MAC effectively randomizes
	// the IV used for the second.

	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	// setErrorLocked(nil) clears the error and returns nil on success.
	return n + m, c.out.setErrorLocked(err)
}
1260
1261
1262 func (c *Conn) handleRenegotiation() error {
1263 if c.vers == VersionTLS13 {
1264 return errors.New("tls: internal error: unexpected renegotiation")
1265 }
1266
1267 msg, err := c.readHandshake(nil)
1268 if err != nil {
1269 return err
1270 }
1271
1272 helloReq, ok := msg.(*helloRequestMsg)
1273 if !ok {
1274 c.sendAlert(alertUnexpectedMessage)
1275 return unexpectedMessageError(helloReq, msg)
1276 }
1277
1278 if !c.isClient {
1279 return c.sendAlert(alertNoRenegotiation)
1280 }
1281
1282 switch c.config.Renegotiation {
1283 case RenegotiateNever:
1284 return c.sendAlert(alertNoRenegotiation)
1285 case RenegotiateOnceAsClient:
1286 if c.handshakes > 1 {
1287 return c.sendAlert(alertNoRenegotiation)
1288 }
1289 case RenegotiateFreelyAsClient:
1290
1291 default:
1292 c.sendAlert(alertInternalError)
1293 return errors.New("tls: unknown Renegotiation value")
1294 }
1295
1296 c.handshakeMutex.Lock()
1297 defer c.handshakeMutex.Unlock()
1298
1299 c.isHandshakeComplete.Store(false)
1300 if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1301 c.handshakes++
1302 }
1303 return c.handshakeErr
1304 }
1305
1306
1307
1308 func (c *Conn) handlePostHandshakeMessage() error {
1309 if c.vers != VersionTLS13 {
1310 return c.handleRenegotiation()
1311 }
1312
1313 msg, err := c.readHandshake(nil)
1314 if err != nil {
1315 return err
1316 }
1317 c.retryCount++
1318 if c.retryCount > maxUselessRecords {
1319 c.sendAlert(alertUnexpectedMessage)
1320 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1321 }
1322
1323 switch msg := msg.(type) {
1324 case *newSessionTicketMsgTLS13:
1325 return c.handleNewSessionTicket(msg)
1326 case *keyUpdateMsg:
1327 return c.handleKeyUpdate(msg)
1328 }
1329
1330
1331
1332
1333 c.sendAlert(alertUnexpectedMessage)
1334 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1335 }
1336
1337 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1338 if c.quic != nil {
1339 c.sendAlert(alertUnexpectedMessage)
1340 return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
1341 }
1342
1343 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1344 if cipherSuite == nil {
1345 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1346 }
1347
1348 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1349 c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1350
1351 if keyUpdate.updateRequested {
1352 c.out.Lock()
1353 defer c.out.Unlock()
1354
1355 msg := &keyUpdateMsg{}
1356 msgBytes, err := msg.marshal()
1357 if err != nil {
1358 return err
1359 }
1360 _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
1361 if err != nil {
1362
1363 c.out.setErrorLocked(err)
1364 return nil
1365 }
1366
1367 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1368 c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1369 }
1370
1371 return nil
1372 }
1373
1374
1375
1376
1377
1378
1379
// Read reads application data into b, running the handshake first if it
// has not completed. Post-handshake handshake messages are processed
// transparently while waiting for data: TLS 1.3 session tickets and key
// updates, and renegotiation requests before TLS 1.3.
func (c *Conn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if len(b) == 0 {
		// Checked after Handshake so that Read(nil) can still be used to
		// drive the handshake.
		return 0, nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	for c.input.Len() == 0 {
		if err := c.readRecord(); err != nil {
			return 0, err
		}
		for c.hand.Len() > 0 {
			if err := c.handlePostHandshakeMessage(); err != nil {
				return 0, err
			}
		}
	}

	n, _ := c.input.Read(b)

	// If this read drained the current record and an alert record is
	// already buffered, process it now. That way a pending close_notify is
	// reported as (n, io.EOF) on this call rather than on the next one, and
	// callers (e.g. HTTP clients deciding whether to reuse a connection)
	// observe the close promptly.
	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
		recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
		if err := c.readRecord(); err != nil {
			return n, err // io.EOF on close_notify
		}
	}

	return n, nil
}
1422
1423
// Close closes the connection. If a handshake has completed and no Write
// is in flight, it first tries to send a close_notify alert (see
// closeNotify). Failure to send the alert is returned only if closing the
// underlying connection succeeded. A second Close returns net.ErrClosed.
func (c *Conn) Close() error {
	// Interlock with Write: set bit 0 of activeCall, remembering whether
	// any Write was registered at that moment.
	var x int32
	for {
		x = c.activeCall.Load()
		if x&1 != 0 {
			return net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is in flight. Treat this Close as a request to abort it:
		// close the underlying connection without sending close_notify,
		// which could block on c.out or handshakeMutex held by that Write.
		return c.conn.Close()
	}

	var alertErr error
	if c.isHandshakeComplete.Load() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1458
1459 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1460
1461
1462
1463
1464 func (c *Conn) CloseWrite() error {
1465 if !c.isHandshakeComplete.Load() {
1466 return errEarlyCloseWrite
1467 }
1468
1469 return c.closeNotify()
1470 }
1471
1472 func (c *Conn) closeNotify() error {
1473 c.out.Lock()
1474 defer c.out.Unlock()
1475
1476 if !c.closeNotifySent {
1477
1478 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1479 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1480 c.closeNotifySent = true
1481
1482 c.SetWriteDeadline(time.Now())
1483 }
1484 return c.closeNotifyErr
1485 }
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500 func (c *Conn) Handshake() error {
1501 return c.HandshakeContext(context.Background())
1502 }
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514 func (c *Conn) HandshakeContext(ctx context.Context) error {
1515
1516
1517 return c.handshakeContext(ctx)
1518 }
1519
// handshakeContext runs the handshake at most once for c and caches the
// outcome in c.handshakeErr.
//
// The named result ret lets a deferred function replace the returned error
// with the context's error when ctx is canceled while the handshake is
// running. That replacement applies to non-QUIC connections only.
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
	// Fast path: once the handshake has completed successfully, return
	// without creating a context or taking any locks.
	if c.isHandshakeComplete.Load() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// This defer is registered before the interrupter's defer below, so it
	// runs after it (defers run LIFO). The interrupter therefore sees done
	// closed before our own cancel fires. Only a cancellation of the caller's
	// ctx can make it close the connection.
	defer cancel()

	if c.quic != nil {
		// QUIC connections get no interrupter goroutine. The cancellation
		// channel and function are stored on c.quic instead.
		// NOTE(review): how quicState consumes these is not visible here.
		c.quic.cancelc = handshakeCtx.Done()
		c.quic.cancel = cancel
	} else if ctx.Done() != nil {
		// A nil Done channel means ctx can never be canceled (for example
		// context.Background()). The interrupter is only started when it
		// could actually fire.
		//
		// The interrupter closes the underlying connection if handshakeCtx is
		// done before this function returns. Presumably this unblocks any
		// pending handshake I/O. It then reports the context error.
		done := make(chan struct{})
		// Buffered so the interrupter can always deliver its result and exit.
		interruptRes := make(chan error, 1)
		defer func() {
			close(done)
			// Wait for the interrupter's result so it never outlives this call.
			if ctxErr := <-interruptRes; ctxErr != nil {
				// The connection was closed because of the context. Return
				// the context error rather than the resulting I/O error.
				ret = ctxErr
			}
		}()
		go func() {
			select {
			case <-handshakeCtx.Done():
				// The Close error is intentionally discarded.
				_ = c.conn.Close()
				interruptRes <- handshakeCtx.Err()
			case <-done:
				interruptRes <- nil
			}
		}()
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Re-check under the mutex: another goroutine may have finished or failed
	// the handshake while this one waited. A previous failure is sticky and is
	// returned as-is.
	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.isHandshakeComplete.Load() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// On failure, try to push out any buffered outgoing data. This is
		// presumably an alert queued by the handshake. Any result of flush is
		// ignored; the handshake error takes precedence.
		c.flush()
	}

	// Invariant checks: handshakeFn must either succeed and mark the
	// handshake complete, or fail and leave it incomplete.
	if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	if c.quic != nil {
		if c.handshakeErr == nil {
			c.quicHandshakeComplete()
			// Install the application-level read secret only now that the
			// handshake has completed.
			c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
		} else {
			// Choose the alert to report. Use the one recorded in c.out.err
			// if there is one, otherwise alertInternalError.
			var a alert
			c.out.Lock()
			if !errors.As(c.out.err, &a) {
				a = alertInternalError
			}
			c.out.Unlock()
			// Wrap both errors so errors.Is/As can find the AlertError. The
			// "%.0w" verb truncates the alert's text to zero characters, so
			// the message shows only the handshake error.
			c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
		}
		// NOTE(review): closing these presumably wakes goroutines waiting on
		// QUIC handshake progress; confirm against quicState.
		close(c.quic.blockedc)
		close(c.quic.signalc)
	}

	return c.handshakeErr
}
1619
1620
1621 func (c *Conn) ConnectionState() ConnectionState {
1622 c.handshakeMutex.Lock()
1623 defer c.handshakeMutex.Unlock()
1624 return c.connectionStateLocked()
1625 }
1626
1627 var tlsunsafeekm = godebug.New("tlsunsafeekm")
1628
1629 func (c *Conn) connectionStateLocked() ConnectionState {
1630 var state ConnectionState
1631 state.HandshakeComplete = c.isHandshakeComplete.Load()
1632 state.Version = c.vers
1633 state.NegotiatedProtocol = c.clientProtocol
1634 state.DidResume = c.didResume
1635 state.testingOnlyDidHRR = c.didHRR
1636 state.CurveID = c.curveID
1637 state.NegotiatedProtocolIsMutual = true
1638 state.ServerName = c.serverName
1639 state.CipherSuite = c.cipherSuite
1640 state.PeerCertificates = c.peerCertificates
1641 state.VerifiedChains = c.verifiedChains
1642 state.SignedCertificateTimestamps = c.scts
1643 state.OCSPResponse = c.ocspResponse
1644 if (!c.didResume || c.extMasterSecret) && c.vers != VersionTLS13 {
1645 if c.clientFinishedIsFirst {
1646 state.TLSUnique = c.clientFinished[:]
1647 } else {
1648 state.TLSUnique = c.serverFinished[:]
1649 }
1650 }
1651 if c.config.Renegotiation != RenegotiateNever {
1652 state.ekm = noEKMBecauseRenegotiation
1653 } else if c.vers != VersionTLS13 && !c.extMasterSecret {
1654 state.ekm = func(label string, context []byte, length int) ([]byte, error) {
1655 if tlsunsafeekm.Value() == "1" {
1656 tlsunsafeekm.IncNonDefault()
1657 return c.ekm(label, context, length)
1658 }
1659 return noEKMBecauseNoEMS(label, context, length)
1660 }
1661 } else {
1662 state.ekm = c.ekm
1663 }
1664 state.ECHAccepted = c.echAccepted
1665 return state
1666 }
1667
1668
1669
1670 func (c *Conn) OCSPResponse() []byte {
1671 c.handshakeMutex.Lock()
1672 defer c.handshakeMutex.Unlock()
1673
1674 return c.ocspResponse
1675 }
1676
1677
1678
1679
1680 func (c *Conn) VerifyHostname(host string) error {
1681 c.handshakeMutex.Lock()
1682 defer c.handshakeMutex.Unlock()
1683 if !c.isClient {
1684 return errors.New("tls: VerifyHostname called on TLS server connection")
1685 }
1686 if !c.isHandshakeComplete.Load() {
1687 return errors.New("tls: handshake has not yet been performed")
1688 }
1689 if len(c.verifiedChains) == 0 {
1690 return errors.New("tls: handshake did not verify certificate chain")
1691 }
1692 return c.peerCertificates[0].VerifyHostname(host)
1693 }
1694
View as plain text