// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package hpke

import (
	"bytes"
	"encoding/hex"
	"encoding/json"
	"os"
	"strconv"
	"strings"
	"testing"

	"crypto/ecdh"
	_ "crypto/sha256"
	_ "crypto/sha512"
)

// mustDecodeHex returns the bytes encoded by the hexadecimal string in.
// It is a test helper: if in is not valid hex, the calling test is failed
// immediately via t.Fatal and the function does not return.
func mustDecodeHex(t *testing.T, in string) []byte {
	t.Helper()
	decoded, decodeErr := hex.DecodeString(in)
	if decodeErr == nil {
		return decoded
	}
	t.Fatal(decodeErr)
	return nil // unreachable: t.Fatal stops the test goroutine
}

// parseVectorSetup parses the "Setup" text of a test vector into a map.
//
// Each line has the form "key: value". The key is everything before the
// first ": " and the value is everything after it, so values that themselves
// contain ": " are kept intact. Lines without a ": " separator, such as the
// empty line produced by a trailing newline, are ignored instead of causing
// an index-out-of-range panic.
func parseVectorSetup(vector string) map[string]string {
	vals := map[string]string{}
	for _, l := range strings.Split(vector, "\n") {
		key, value, ok := strings.Cut(l, ": ")
		if !ok {
			continue
		}
		vals[key] = value
	}
	return vals
}

// parseVectorEncryptions parses the "Encryptions" text of a test vector.
//
// The text is a sequence of sections separated by blank lines ("\n\n").
// Each section becomes one map of "key: value" lines, split on the first
// ": " so that values containing ": " are kept whole. Lines without a
// separator are ignored. Sections that contain no key/value lines at all
// (for example, trailing blank lines) are dropped, so they cannot produce a
// spurious empty encryption entry.
func parseVectorEncryptions(vector string) []map[string]string {
	vals := []map[string]string{}
	for _, section := range strings.Split(vector, "\n\n") {
		e := map[string]string{}
		for _, l := range strings.Split(section, "\n") {
			key, value, ok := strings.Cut(l, ": ")
			if !ok {
				continue
			}
			e[key] = value
		}
		if len(e) == 0 {
			continue
		}
		vals = append(vals, e)
	}
	return vals
}

// TestRFC9180Vectors checks the HPKE implementation against the known-answer
// vectors in testdata/rfc9180-vectors.json.
//
// For every vector whose KEM, KDF, and AEAD are all supported, it does the
// following:
//   - pins the sender's ephemeral key to skEm via testingOnlyGenerateKey;
//   - runs SetupSender and SetupReceipient;
//   - compares the encapsulated key and both contexts' key schedules with the
//     vector;
//   - replays each listed encryption (nonce, Seal, Open) at its stated
//     sequence number.
func TestRFC9180Vectors(t *testing.T) {
	vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
	if err != nil {
		t.Fatal(err)
	}

	var vectors []struct {
		Name        string
		Setup       string
		Encryptions string
	}
	if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
		t.Fatal(err)
	}

	for _, vector := range vectors {
		t.Run(vector.Name, func(t *testing.T) {
			setup := parseVectorSetup(vector.Setup)

			// SetupSender/SetupReceipient implement the base mode only. Skip
			// PSK/auth vectors if the data carries any. Vectors without a
			// "mode" field are still run.
			if mode, ok := setup["mode"]; ok && mode != "0" {
				t.Skip("unsupported mode " + mode)
			}

			// parseID parses a 16-bit algorithm identifier. With bitSize 16,
			// ParseUint rejects out-of-range values. A plain uint16 conversion
			// would instead wrap them silently onto another, possibly
			// supported, ID.
			parseID := func(key string) uint16 {
				t.Helper()
				id, err := strconv.ParseUint(setup[key], 10, 16)
				if err != nil {
					t.Fatal(err)
				}
				return uint16(id)
			}

			kemID := parseID("kem_id")
			kem, ok := SupportedKEMs[kemID]
			if !ok {
				t.Skip("unsupported KEM")
			}
			kdfID := parseID("kdf_id")
			if _, ok := SupportedKDFs[kdfID]; !ok {
				t.Skip("unsupported KDF")
			}
			aeadID := parseID("aead_id")
			if _, ok := SupportedAEADs[aeadID]; !ok {
				t.Skip("unsupported AEAD")
			}

			info := mustDecodeHex(t, setup["info"])
			pub, err := ParseHPKEPublicKey(kemID, mustDecodeHex(t, setup["pkRm"]))
			if err != nil {
				t.Fatal(err)
			}

			// Pin the sender's ephemeral key to the vector's skEm so that
			// encapsulation is deterministic and comparable with "enc".
			ephemeralPrivKey := mustDecodeHex(t, setup["skEm"])
			testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) {
				return kem.curve.NewPrivateKey(ephemeralPrivKey)
			}
			t.Cleanup(func() { testingOnlyGenerateKey = nil })

			encap, sender, err := SetupSender(kemID, kdfID, aeadID, pub, info)
			if err != nil {
				t.Fatal(err)
			}
			if want := mustDecodeHex(t, setup["enc"]); !bytes.Equal(encap, want) {
				t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, want)
			}

			priv, err := ParseHPKEPrivateKey(kemID, mustDecodeHex(t, setup["skRm"]))
			if err != nil {
				t.Fatal(err)
			}
			recipient, err := SetupReceipient(kemID, kdfID, aeadID, priv, info, encap)
			if err != nil {
				t.Fatal(err)
			}

			// Both sides must derive exactly the vector's key schedule. The
			// expected values are decoded once, not once per context.
			wantSharedSecret := mustDecodeHex(t, setup["shared_secret"])
			wantKey := mustDecodeHex(t, setup["key"])
			wantBaseNonce := mustDecodeHex(t, setup["base_nonce"])
			wantExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
			for _, side := range []struct {
				role string
				ctx  *context
			}{
				{"sender", sender.context},
				{"recipient", recipient.context},
			} {
				for _, f := range []struct {
					name      string
					got, want []byte
				}{
					{"shared secret", side.ctx.sharedSecret, wantSharedSecret},
					{"key", side.ctx.key, wantKey},
					{"base nonce", side.ctx.baseNonce, wantBaseNonce},
					{"exporter secret", side.ctx.exporterSecret, wantExporterSecret},
				} {
					if !bytes.Equal(f.got, f.want) {
						t.Errorf("%s: unexpected %s, got: %x, want %x", side.role, f.name, f.got, f.want)
					}
				}
			}

			for _, enc := range parseVectorEncryptions(vector.Encryptions) {
				t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
					// ParseUint rejects negative values, which a plain
					// uint64 conversion would otherwise wrap.
					seqNum, err := strconv.ParseUint(enc["sequence number"], 10, 64)
					if err != nil {
						t.Fatal(err)
					}
					// Rewind both contexts so that each encryption is checked
					// at its own sequence number, independent of any earlier
					// subtest.
					sender.seqNum = uint128{lo: seqNum}
					recipient.seqNum = uint128{lo: seqNum}

					aad := mustDecodeHex(t, enc["aad"])
					pt := mustDecodeHex(t, enc["pt"])
					ct := mustDecodeHex(t, enc["ct"])

					wantNonce := mustDecodeHex(t, enc["nonce"])
					if got := sender.nextNonce(); !bytes.Equal(got, wantNonce) {
						t.Errorf("unexpected nonce: got %x, want %x", got, wantNonce)
					}

					ciphertext, err := sender.Seal(aad, pt)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(ciphertext, ct) {
						t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, ct)
					}

					// Open the vector's ciphertext, not the computed one, so
					// decryption is checked independently of Seal.
					plaintext, err := recipient.Open(aad, ct)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(plaintext, pt) {
						t.Errorf("unexpected plaintext: got %x want %x", plaintext, pt)
					}
				})
			}
		})
	}
}