// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tls

import (
	"bytes"
	"crypto/x509"
	"encoding/hex"
	"math"
	"math/rand"
	"reflect"
	"strings"
	"testing"
	"testing/quick"
	"time"
)

// tests holds one instance of every message type exercised by
// TestMarshalUnmarshal and TestFuzz. The order is significant:
// TestMarshalUnmarshal skips its truncated-prefix check for the first three
// entries, so clientHelloMsg, serverHelloMsg and finishedMsg must stay at the
// front of the list.
var tests = []handshakeMessage{
	&clientHelloMsg{},
	&serverHelloMsg{},
	&finishedMsg{},

	&certificateMsg{},
	&certificateRequestMsg{},
	&certificateVerifyMsg{
		hasSignatureAlgorithm: true,
	},
	&certificateStatusMsg{},
	&clientKeyExchangeMsg{},
	&newSessionTicketMsg{},
	&encryptedExtensionsMsg{},
	&endOfEarlyDataMsg{},
	&keyUpdateMsg{},
	&newSessionTicketMsgTLS13{},
	&certificateRequestMsgTLS13{},
	&certificateMsgTLS13{},
	&SessionState{},
}

// mustMarshal returns the encoding produced by msg.marshal. If marshaling
// reports an error the calling test is aborted with t.Fatal; the function is
// marked as a test helper so the failure is attributed to the caller.
func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
	t.Helper()
	encoded, marshalErr := msg.marshal()
	if marshalErr != nil {
		t.Fatal(marshalErr)
	}
	return encoded
}

// TestMarshalUnmarshal round-trips randomly generated values of every message
// type in tests through marshal and unmarshal and checks that the decoded
// message is deeply equal to the generated one. For every type except the
// first three entries of tests it additionally checks that no strict prefix
// of the encoding is accepted by unmarshal.
//
// Each type runs as its own subtest with 100 random values, or 5 when
// testing.Short is set.
func TestMarshalUnmarshal(t *testing.T) {
	// The generator is seeded from the clock, so record the seed: without it
	// a failure triggered by one particular random message cannot be
	// reproduced. The line is printed when the test fails or runs with -v.
	seed := time.Now().UnixNano()
	t.Logf("random seed: %d", seed)
	rand := rand.New(rand.NewSource(seed))

	for i, m := range tests {
		ty := reflect.ValueOf(m).Type()
		t.Run(ty.String(), func(t *testing.T) {
			n := 100
			if testing.Short() {
				n = 5
			}
			for j := 0; j < n; j++ {
				v, ok := quick.Value(ty, rand)
				if !ok {
					t.Errorf("#%d: failed to create value", i)
					break
				}

				m1 := v.Interface().(handshakeMessage)
				marshaled := mustMarshal(t, m1)
				if !m.unmarshal(marshaled) {
					t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
					break
				}

				// Clear the fields that are not expected to survive a
				// marshal + unmarshal round trip, so that DeepEqual compares
				// only what the generator produced.
				switch msg := m.(type) {
				case *SessionState:
					// The generator never sets activeCertHandles.
					msg.activeCertHandles = nil
				case *clientHelloMsg:
					// extensions is special cased, as it is only populated by the
					// server-side of a handshake and is not expected to roundtrip
					// through marshal + unmarshal.  m ends up with the list of
					// extensions necessary to serialize the other fields of
					// clientHelloMsg, so check that it is non-empty, then clear it.
					if len(msg.extensions) == 0 {
						t.Errorf("expected ch.extensions to be populated on unmarshal")
					}
					msg.extensions = nil
					// clientHelloMsg and serverHelloMsg, when unmarshalled, store
					// their original representation, for later use in the handshake
					// transcript. In order to prevent DeepEqual from failing since
					// we didn't create the original message via unmarshalling, nil
					// the field.
					msg.original = nil
				case *serverHelloMsg:
					// See the clientHelloMsg case for why original is cleared.
					msg.original = nil
				}

				if !reflect.DeepEqual(m1, m) {
					t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
					break
				}

				if i >= 3 {
					// The first three message types (ClientHello,
					// ServerHello and Finished) are allowed to
					// have parsable prefixes because the extension
					// data is optional and the length of the
					// Finished varies across versions.
					for prefixLen := 0; prefixLen < len(marshaled); prefixLen++ {
						if m.unmarshal(marshaled[0:prefixLen]) {
							t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, prefixLen, m1)
							break
						}
					}
				}
			}
		})
	}
}

// TestFuzz feeds 1000 random byte strings, each shorter than 1000 bytes, to
// the unmarshal method of every message in tests. The generator uses a fixed
// seed of 0. The return value of unmarshal is ignored.
func TestFuzz(t *testing.T) {
	rng := rand.New(rand.NewSource(0))
	for _, msg := range tests {
		for round := 0; round < 1000; round++ {
			size := rng.Intn(1000)
			input := randomBytes(size, rng)
			// This just looks for crashes due to bounds errors etc.
			msg.unmarshal(input)
		}
	}
}

// randomBytes returns a freshly allocated slice of n bytes filled from rand.
// It panics if rand.Read reports an error.
func randomBytes(n int, rand *rand.Rand) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic("rand.Read failed: " + err.Error())
	}
	return buf
}

// randomString returns a string made of n random bytes drawn from rand. The
// bytes are used as-is, so the result is not necessarily valid UTF-8.
func randomString(n int, rand *rand.Rand) string {
	return string(randomBytes(n, rand))
}

// Generate implements quick.Generator. It returns a *clientHelloMsg with
// randomly populated fields, wrapped in a reflect.Value. The size argument is
// ignored; every length is drawn from rand within fixed bounds.
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	m := &clientHelloMsg{}
	m.vers = uint16(rand.Intn(65536))
	m.random = randomBytes(32, rand)
	m.sessionId = randomBytes(rand.Intn(32), rand)
	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
	for i := range m.cipherSuites {
		cs := uint16(rand.Int31())
		// Never emit scsvRenegotiation as an ordinary cipher suite value.
		if cs == scsvRenegotiation {
			cs++
		}
		m.cipherSuites[i] = cs
	}
	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
	if rand.Intn(10) > 5 {
		m.serverName = randomString(rand.Intn(255), rand)
		// Strip any trailing dots from the generated name.
		for strings.HasSuffix(m.serverName, ".") {
			m.serverName = m.serverName[:len(m.serverName)-1]
		}
	}
	m.ocspStapling = rand.Intn(10) > 5
	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
	for i := range m.supportedCurves {
		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
	}
	if rand.Intn(10) > 5 {
		m.ticketSupported = true
		if rand.Intn(10) > 5 {
			m.sessionTicket = randomBytes(rand.Intn(300), rand)
		} else {
			m.sessionTicket = make([]byte, 0)
		}
	}
	if rand.Intn(10) > 5 {
		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
	}
	if rand.Intn(10) > 5 {
		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
	}
	// The element counts below are drawn once, before the loop starts.
	// Putting rand.Intn in the loop condition would re-draw the bound on
	// every iteration and bias the counts toward zero.
	for i, n := 0, rand.Intn(5); i < n; i++ {
		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
	}
	if rand.Intn(10) > 5 {
		m.scts = true
	}
	if rand.Intn(10) > 5 {
		m.secureRenegotiationSupported = true
		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
	}
	if rand.Intn(10) > 5 {
		m.extendedMasterSecret = true
	}
	for i, n := 0, rand.Intn(5); i < n; i++ {
		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
	}
	if rand.Intn(10) > 5 {
		m.cookie = randomBytes(rand.Intn(500)+1, rand)
	}
	for i, n := 0, rand.Intn(5); i < n; i++ {
		var ks keyShare
		ks.group = CurveID(rand.Intn(30000) + 1)
		ks.data = randomBytes(rand.Intn(200)+1, rand)
		m.keyShares = append(m.keyShares, ks)
	}
	switch rand.Intn(3) {
	case 1:
		m.pskModes = []uint8{pskModeDHE}
	case 2:
		m.pskModes = []uint8{pskModeDHE, pskModePlain}
	}
	// Identities and binders are appended in lockstep so both lists always
	// have the same length.
	for i, n := 0, rand.Intn(5); i < n; i++ {
		var psk pskIdentity
		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
		psk.label = randomBytes(rand.Intn(500)+1, rand)
		m.pskIdentities = append(m.pskIdentities, psk)
		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
	}
	if rand.Intn(10) > 5 {
		m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
	}
	if rand.Intn(10) > 5 {
		m.earlyData = true
	}
	if rand.Intn(10) > 5 {
		m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
	}

	return reflect.ValueOf(m)
}

// Generate implements quick.Generator. It returns a *serverHelloMsg with
// randomly populated fields, wrapped in a reflect.Value. The size argument is
// ignored; every length is drawn from rand within fixed bounds.
func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	m := &serverHelloMsg{}
	m.vers = uint16(rand.Intn(65536))
	m.random = randomBytes(32, rand)
	m.sessionId = randomBytes(rand.Intn(32), rand)
	m.cipherSuite = uint16(rand.Int31())
	m.compressionMethod = uint8(rand.Intn(256))
	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)

	if rand.Intn(10) > 5 {
		m.ocspStapling = true
	}
	if rand.Intn(10) > 5 {
		m.ticketSupported = true
	}
	if rand.Intn(10) > 5 {
		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
	}

	// Draw the SCT count once. Putting rand.Intn in the loop condition would
	// re-draw the bound on every iteration and bias the count toward zero.
	// Each SCT is at least one byte long.
	for i, n := 0, rand.Intn(4); i < n; i++ {
		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
	}

	if rand.Intn(10) > 5 {
		m.secureRenegotiationSupported = true
		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
	}
	if rand.Intn(10) > 5 {
		m.extendedMasterSecret = true
	}
	if rand.Intn(10) > 5 {
		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
	}
	if rand.Intn(10) > 5 {
		m.cookie = randomBytes(rand.Intn(500)+1, rand)
	}
	// serverShare and selectedGroup are never both set. serverShare is a
	// single value, so it is assigned exactly once rather than inside a loop.
	if rand.Intn(10) > 5 {
		m.serverShare.group = CurveID(rand.Intn(30000) + 1)
		m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
	} else if rand.Intn(10) > 5 {
		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
	}
	if rand.Intn(10) > 5 {
		m.selectedIdentityPresent = true
		m.selectedIdentity = uint16(rand.Intn(0xffff))
	}
	if rand.Intn(10) > 5 {
		m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
	}
	if rand.Intn(10) > 5 {
		m.serverNameAck = rand.Intn(2) == 1
	}

	return reflect.ValueOf(m)
}

// Generate implements quick.Generator. It returns an *encryptedExtensionsMsg
// that may carry a random ALPN protocol of 1 to 32 bytes and a random
// earlyData flag. The size argument is ignored.
func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	ee := new(encryptedExtensionsMsg)

	if rand.Intn(10) > 5 {
		ee.alpnProtocol = randomString(rand.Intn(32)+1, rand)
	}
	ee.earlyData = rand.Intn(10) > 5

	return reflect.ValueOf(ee)
}

// Generate implements quick.Generator. It returns a *certificateMsg holding
// between 0 and 19 random certificates of 1 to 10 bytes each. The size
// argument is ignored.
func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	msg := &certificateMsg{certificates: make([][]byte, rand.Intn(20))}
	for idx := range msg.certificates {
		msg.certificates[idx] = randomBytes(rand.Intn(10)+1, rand)
	}
	return reflect.ValueOf(msg)
}

// Generate implements quick.Generator. It returns a *certificateRequestMsg
// with 1 to 5 random certificate type bytes and between 0 and 99 random
// certificate authorities of 1 to 15 bytes each. The size argument is
// ignored.
func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	m := &certificateRequestMsg{}
	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
	// Draw the authority count once. Putting rand.Intn in the loop condition
	// would re-draw the bound on every iteration and bias the count toward
	// zero.
	for i, n := 0, rand.Intn(100); i < n; i++ {
		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
	}
	return reflect.ValueOf(m)
}

// Generate implements quick.Generator. It returns a *certificateVerifyMsg
// with hasSignatureAlgorithm set, a random signature algorithm below 30000
// and a random signature of 1 to 15 bytes. The size argument is ignored.
func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	scheme := SignatureScheme(rand.Intn(30000))
	sig := randomBytes(rand.Intn(15)+1, rand)
	return reflect.ValueOf(&certificateVerifyMsg{
		hasSignatureAlgorithm: true,
		signatureAlgorithm:    scheme,
		signature:             sig,
	})
}

// Generate implements quick.Generator. It returns a *certificateStatusMsg
// whose response is 1 to 10 random bytes. The size argument is ignored.
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	status := &certificateStatusMsg{
		response: randomBytes(rand.Intn(10)+1, rand),
	}
	return reflect.ValueOf(status)
}

// Generate implements quick.Generator. It returns a *clientKeyExchangeMsg
// whose ciphertext is 1 to 1000 random bytes. The size argument is ignored.
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	ckx := &clientKeyExchangeMsg{
		ciphertext: randomBytes(rand.Intn(1000)+1, rand),
	}
	return reflect.ValueOf(ckx)
}

// Generate implements quick.Generator. It returns a *finishedMsg whose
// verifyData is always exactly 12 random bytes. The size argument is ignored.
func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	fin := &finishedMsg{
		verifyData: randomBytes(12, rand),
	}
	return reflect.ValueOf(fin)
}

// Generate implements quick.Generator. It returns a *newSessionTicketMsg
// whose ticket is 0 to 3 random bytes. The size argument is ignored.
func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	nst := &newSessionTicketMsg{
		ticket: randomBytes(rand.Intn(4), rand),
	}
	return reflect.ValueOf(nst)
}

// sessionTestCerts holds the certificates that init parses from
// testRSACertificate and testRSACertificateIssuer, in that order.
// SessionState.Generate uses them as the peer certificates of generated
// sessions.
var sessionTestCerts []*x509.Certificate

// init fills sessionTestCerts by parsing testRSACertificate followed by
// testRSACertificateIssuer. It panics if either one fails to parse.
func init() {
	for _, der := range [][]byte{testRSACertificate, testRSACertificateIssuer} {
		parsed, err := x509.ParseCertificate(der)
		if err != nil {
			panic(err)
		}
		sessionTestCerts = append(sessionTestCerts, parsed)
	}
}

// Generate implements quick.Generator. It returns a *SessionState with
// randomly populated fields, wrapped in a reflect.Value. The size argument is
// ignored; every length is drawn from rand within fixed bounds.
//
// Peer certificates are taken from sessionTestCerts. Client sessions always
// get peer certificates; other sessions get them some of the time. The OCSP
// response, SCTs and verified chains are only generated when peer
// certificates are present.
func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
	s := &SessionState{}
	isTLS13 := rand.Intn(10) > 5
	if isTLS13 {
		s.version = VersionTLS13
	} else {
		s.version = uint16(rand.Intn(VersionTLS13))
	}
	s.isClient = rand.Intn(10) > 5
	s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
	s.createdAt = uint64(rand.Int63())
	s.secret = randomBytes(rand.Intn(100)+1, rand)
	// All element counts in this function are drawn once, before the loop
	// starts. Putting rand.Intn in the loop condition would re-draw the bound
	// on every iteration.
	for n, i := rand.Intn(3), 0; i < n; i++ {
		s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
	}
	if rand.Intn(10) > 5 {
		s.EarlyData = true
	}
	if rand.Intn(10) > 5 {
		s.extMasterSecret = true
	}
	if s.isClient || rand.Intn(10) > 5 {
		if rand.Intn(10) > 5 {
			s.peerCertificates = sessionTestCerts
		} else {
			s.peerCertificates = sessionTestCerts[:1]
		}
	}
	if rand.Intn(10) > 5 && s.peerCertificates != nil {
		s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
	}
	if rand.Intn(10) > 5 && s.peerCertificates != nil {
		for n, i := rand.Intn(2)+1, 0; i < n; i++ {
			s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
		}
	}
	if len(s.peerCertificates) > 0 {
		for n, i := rand.Intn(3), 0; i < n; i++ {
			if rand.Intn(10) > 5 {
				s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
			} else {
				s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
			}
		}
	}
	if rand.Intn(10) > 5 && s.EarlyData {
		s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
	}
	// useBy and ageAdd are only generated for TLS 1.3 client sessions;
	// curveID is only generated for earlier versions.
	if isTLS13 {
		if s.isClient {
			s.useBy = uint64(rand.Int63())
			s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
		}
	} else {
		s.curveID = CurveID(rand.Intn(30000) + 1)
	}
	return reflect.ValueOf(s)
}

// marshal returns the result of s.Bytes. Together with unmarshal it lets a
// *SessionState be stored in the tests list of handshakeMessage values.
func (s *SessionState) marshal() ([]byte, error) {
	return s.Bytes()
}
// unmarshal parses b with ParseSessionState. On success it overwrites *s with
// the parsed state and returns true; on a parse error it leaves *s untouched
// and returns false.
func (s *SessionState) unmarshal(b []byte) bool {
	parsed, err := ParseSessionState(b)
	if err == nil {
		*s = *parsed
	}
	return err == nil
}

// Generate implements quick.Generator. It always returns an empty
// *endOfEarlyDataMsg; neither rand nor size is used.
func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	return reflect.ValueOf(&endOfEarlyDataMsg{})
}

// Generate implements quick.Generator. It returns a *keyUpdateMsg with a
// randomly chosen updateRequested flag. The size argument is ignored.
func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	ku := &keyUpdateMsg{
		updateRequested: rand.Intn(10) > 5,
	}
	return reflect.ValueOf(ku)
}

// Generate implements quick.Generator. It returns a *newSessionTicketMsgTLS13
// with a random lifetime and ageAdd (both below 500000), a nonce of up to 99
// bytes, a label of up to 999 bytes and, some of the time, a maxEarlyData
// value below 500000. The size argument is ignored.
func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	// The values are drawn one statement at a time so the order in which
	// rand is consumed is explicit.
	lifetime := uint32(rand.Intn(500000))
	ageAdd := uint32(rand.Intn(500000))
	nonce := randomBytes(rand.Intn(100), rand)
	label := randomBytes(rand.Intn(1000), rand)

	ticket := &newSessionTicketMsgTLS13{
		lifetime: lifetime,
		ageAdd:   ageAdd,
		nonce:    nonce,
		label:    label,
	}
	if rand.Intn(10) > 5 {
		ticket.maxEarlyData = uint32(rand.Intn(500000))
	}
	return reflect.ValueOf(ticket)
}

// Generate implements quick.Generator. It returns a
// *certificateRequestMsgTLS13 in which each optional part (OCSP stapling,
// SCTs, the two signature algorithm lists and a list of three certificate
// authorities of 1 to 10 bytes each) is independently present or absent. The
// size argument is ignored.
func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	req := new(certificateRequestMsgTLS13)
	req.ocspStapling = rand.Intn(10) > 5
	req.scts = rand.Intn(10) > 5
	if rand.Intn(10) > 5 {
		req.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
	}
	if rand.Intn(10) > 5 {
		req.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
	}
	if rand.Intn(10) > 5 {
		authorities := make([][]byte, 3)
		for idx := range authorities {
			authorities[idx] = randomBytes(rand.Intn(10)+1, rand)
		}
		req.certificateAuthorities = authorities
	}
	return reflect.ValueOf(req)
}

// Generate implements quick.Generator. It returns a *certificateMsgTLS13
// carrying one or two random certificates of 1 to 500 bytes and, each some of
// the time, an OCSP staple of 1 to 100 bytes and one or two SCTs of 1 to 500
// bytes. The ocspStapling and scts flags are set exactly when the
// corresponding data is generated. The size argument is ignored.
func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
	msg := new(certificateMsgTLS13)
	for k := 0; k < rand.Intn(2)+1; k++ {
		der := randomBytes(rand.Intn(500)+1, rand)
		msg.certificate.Certificate = append(msg.certificate.Certificate, der)
	}
	if rand.Intn(10) > 5 {
		msg.ocspStapling = true
		msg.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
	}
	if rand.Intn(10) > 5 {
		msg.scts = true
		for k := 0; k < rand.Intn(2)+1; k++ {
			sct := randomBytes(rand.Intn(500)+1, rand)
			msg.certificate.SignedCertificateTimestamps = append(
				msg.certificate.SignedCertificateTimestamps, sct)
		}
	}
	return reflect.ValueOf(msg)
}

// TestRejectEmptySCTList checks that serverHelloMsg.unmarshal refuses a
// ServerHello whose SCT list is empty. It marshals a valid message with a
// single SCT, rewrites the encoding so the list is empty, fixes up the
// enclosing length fields and expects unmarshal to fail.
func TestRejectEmptySCTList(t *testing.T) {
	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.

	sct := []byte{0x42, 0x42, 0x42, 0x42}
	hello := &serverHelloMsg{
		vers:   VersionTLS12,
		random: make([]byte, 32),
		scts:   [][]byte{sct},
	}
	encoded := mustMarshal(t, hello)

	var parsed serverHelloMsg
	if !parsed.unmarshal(encoded) {
		t.Fatal("Failed to unmarshal initial message")
	}

	// Change the encoding so that the SCT list is empty.
	sctAt := bytes.Index(encoded, sct)
	if sctAt < 0 {
		t.Fatal("Cannot find SCT in ServerHello")
	}

	// Drop the six bytes in front of the SCT together with the SCT itself and
	// put the extension length and SCT list length for an empty list in
	// their place.
	emptied := make([]byte, 0, len(encoded))
	emptied = append(emptied, encoded[:sctAt-6]...)
	emptied = append(emptied, 0, 2, 0, 0)
	emptied = append(emptied, encoded[sctAt+len(sct):]...)

	// Update the handshake message length.
	bodyLen := len(emptied) - 4
	emptied[1] = byte(bodyLen >> 16)
	emptied[2] = byte(bodyLen >> 8)
	emptied[3] = byte(bodyLen)

	// Update the extensions length.
	extensionsLen := len(emptied) - 44
	emptied[42] = byte(extensionsLen >> 8)
	emptied[43] = byte(extensionsLen)

	if parsed.unmarshal(emptied) {
		t.Fatal("Unmarshaled ServerHello with empty SCT list")
	}
}

// TestRejectEmptySCT checks that serverHelloMsg.unmarshal refuses a
// ServerHello whose SCT list contains a zero-length SCT.
func TestRejectEmptySCT(t *testing.T) {
	// Not only must the SCT list be non-empty, but the SCT elements must
	// not be zero length.

	hello := serverHelloMsg{
		vers:   VersionTLS12,
		random: make([]byte, 32),
		scts:   [][]byte{nil},
	}
	encoded := mustMarshal(t, &hello)

	if parsed := new(serverHelloMsg); parsed.unmarshal(encoded) {
		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
	}
}

// TestRejectDuplicateExtensions checks that unmarshal refuses a hand-built
// ClientHello and a hand-built ServerHello that each carry the same extension
// twice.
func TestRejectDuplicateExtensions(t *testing.T) {
	const (
		duplicatedClientHello = "010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f"
		duplicatedServerHello = "02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000"
	)

	rawClientHello, err := hex.DecodeString(duplicatedClientHello)
	if err != nil {
		t.Fatalf("failed to decode test ClientHello: %s", err)
	}
	// A ClientHello failure is reported with Error, not Fatal, so the
	// ServerHello case below still runs.
	if ch := new(clientHelloMsg); ch.unmarshal(rawClientHello) {
		t.Error("Unmarshaled ClientHello with duplicate extensions")
	}

	rawServerHello, err := hex.DecodeString(duplicatedServerHello)
	if err != nil {
		t.Fatalf("failed to decode test ServerHello: %s", err)
	}
	if sh := new(serverHelloMsg); sh.unmarshal(rawServerHello) {
		t.Fatal("Unmarshaled ServerHello with duplicate extensions")
	}
}