// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build ignore

// Note: this program must be run in this directory.
//   go run mknode.go

package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io/fs"
	"log"
	"os"
	"slices"
	"strings"
)

// Package-level state shared by the generator's passes.
var (
	// fset is the file set handed to the parser for every file in the package.
	fset = token.NewFileSet()

	// buf accumulates the unformatted text of the generated file.
	buf bytes.Buffer

	// concreteNodes holds every concrete type in the package that
	// implements Node, with the mini* types left out.
	concreteNodes []*ast.TypeSpec

	// interfaceNodes holds every interface type in the package that
	// implements Node.
	interfaceNodes []*ast.TypeSpec

	// mini maps the name of each embeddable mini type (miniNode, miniExpr,
	// and miniStmt) to its declaration.
	mini = make(map[string]*ast.TypeSpec)
)

// implementsNode reports whether t names one of the types recorded in
// interfaceNodes or concreteNodes, i.e. a type that represents a Node in
// the AST. Anything other than a plain identifier is rejected.
func implementsNode(t ast.Expr) bool {
	ident, isIdent := t.(*ast.Ident)
	if !isIdent {
		return false // only named types
	}
	sameName := func(spec *ast.TypeSpec) bool {
		return spec.Name.Name == ident.Name
	}
	return slices.ContainsFunc(interfaceNodes, sameName) ||
		slices.ContainsFunc(concreteNodes, sameName)
}

// isMini reports whether t is an identifier naming one of the mini types
// recorded in the mini map.
func isMini(t ast.Expr) bool {
	ident, isIdent := t.(*ast.Ident)
	if !isIdent {
		return false
	}
	return mini[ident.Name] != nil
}

// isNamedType reports whether t is a plain identifier spelled exactly name.
// Qualified, pointer, slice, and all other non-identifier expressions yield
// false.
func isNamedType(t ast.Expr, name string) bool {
	ident, isIdent := t.(*ast.Ident)
	return isIdent && ident.Name == name
}

// main generates node_gen.go for package ir.
//
// It parses every Go file in the current directory except those whose name
// starts with "mknode", records the mini types, collects the concrete and
// interface types that implement Node, emits the per-type methods and the
// slice helpers into buf, and finally writes the gofmt'd result to
// node_gen.go.
//
// If the generated text cannot be formatted, the unformatted text is still
// written to node_gen.go so the problem can be inspected, and the program
// then exits with a non-zero status so the failure is not mistaken for
// success.
func main() {
	fmt.Fprintln(&buf, "// Code generated by mknode.go. DO NOT EDIT.")
	fmt.Fprintln(&buf)
	fmt.Fprintln(&buf, "package ir")
	fmt.Fprintln(&buf)
	fmt.Fprintln(&buf, `import "fmt"`)

	filter := func(file fs.FileInfo) bool {
		return !strings.HasPrefix(file.Name(), "mknode")
	}
	pkgs, err := parser.ParseDir(fset, ".", filter, 0)
	if err != nil {
		log.Fatal(err)
	}
	pkg := pkgs["ir"]
	if pkg == nil {
		// Without this check the loops below would dereference a nil
		// package, which happens whenever the tool is run from the
		// wrong directory.
		log.Fatal("package ir not found in the current directory; run mknode.go from its own directory")
	}

	// Find all the mini types. These let us determine which
	// concrete types implement Node, so we need to find them first.
	forEachTypeSpec(pkg, func(t *ast.TypeSpec) {
		name := t.Name.Name
		if !strings.HasPrefix(name, "mini") {
			return
		}
		mini[name] = t
		if name == "miniNode" {
			return
		}
		// Double-check that it embeds miniNode as its first field.
		s, ok := t.Type.(*ast.StructType)
		if !ok || len(s.Fields.List) == 0 || !isNamedType(s.Fields.List[0].Type, "miniNode") {
			panic(fmt.Sprintf("can't find miniNode in %s", name))
		}
	})

	// Find all the declarations of types that implement Node.
	forEachTypeSpec(pkg, func(t *ast.TypeSpec) {
		if strings.HasPrefix(t.Name.Name, "mini") {
			// We don't treat the mini types as
			// concrete implementations of Node
			// (even though they are) because
			// we only use them by embedding them.
			return
		}
		if isConcreteNode(t) {
			concreteNodes = append(concreteNodes, t)
		}
		if isInterfaceNode(t) {
			interfaceNodes = append(interfaceNodes, t)
		}
	})

	// Sort for deterministic output: the files of a package are held in a
	// map, so the discovery order above is not stable.
	slices.SortFunc(concreteNodes, func(a, b *ast.TypeSpec) int {
		return strings.Compare(a.Name.Name, b.Name.Name)
	})
	// Generate code for each concrete type.
	for _, t := range concreteNodes {
		processType(t)
	}
	// Add some helpers.
	generateHelpers()

	// Format and write output.
	out, fmtErr := format.Source(buf.Bytes())
	if fmtErr != nil {
		// Write out the mangled source so we can see the bug.
		out = buf.Bytes()
	}
	if err := os.WriteFile("node_gen.go", out, 0666); err != nil {
		log.Fatal(err)
	}
	if fmtErr != nil {
		log.Fatalf("formatting generated code (unformatted output left in node_gen.go): %v", fmtErr)
	}
}

// forEachTypeSpec calls fn for every type declaration found in the files
// of pkg. The order in which files are visited is unspecified.
func forEachTypeSpec(pkg *ast.Package, fn func(*ast.TypeSpec)) {
	for _, f := range pkg.Files {
		for _, d := range f.Decls {
			g, ok := d.(*ast.GenDecl)
			if !ok {
				continue
			}
			for _, s := range g.Specs {
				if t, ok := s.(*ast.TypeSpec); ok {
					fn(t)
				}
			}
		}
	}
}

// isConcreteNode reports whether t declares a concrete Node implementation,
// which is recognized as a struct type with at least one field whose type
// is one of the mini types.
func isConcreteNode(t *ast.TypeSpec) bool {
	st, isStruct := t.Type.(*ast.StructType)
	if !isStruct {
		return false
	}
	return slices.ContainsFunc(st.Fields.List, func(field *ast.Field) bool {
		return isMini(field.Type)
	})
}

// isInterfaceNode reports whether t declares an interface type that
// implements Node. Node itself qualifies; any other interface qualifies
// only if it directly embeds Node.
func isInterfaceNode(t *ast.TypeSpec) bool {
	iface, isInterface := t.Type.(*ast.InterfaceType)
	if !isInterface {
		return false
	}

	switch t.Name.Name {
	case "Node":
		return true
	case "OrigNode", "InitNode":
		// Exempt from consideration: fields of these types
		// don't need to be walked or copied.
		return false
	}

	// Search the interface's elements for an embedded Node.
	// Multi-level embedding is not handled, but there is none
	// of that at the moment.
	for _, elem := range iface.Methods.List {
		// A method has a name; an embedded type does not.
		if len(elem.Names) == 0 && isNamedType(elem.Type, "Node") {
			return true
		}
	}
	return false
}

// processType appends the generated methods for the concrete node type t
// to buf.
//
// Every type gets a Format method that forwards to fmtNode. Name and Func
// get nothing else. All other types, which must be struct types (the type
// assertion below panics otherwise), additionally get copy, doChildren,
// doChildrenWithHidden, editChildren, and editChildrenWithHidden.
//
// A field contributes to those methods when its type, after treating Nodes
// as []Node and then unwrapping at most one slice followed by at most one
// pointer, names a type for which implementsNode reports true. A field
// tagged `mknode:"-"` is skipped entirely unless its type is exactly Node,
// in which case it is "hidden": it appears only in the WithHidden variants.
// Any other mknode tag value panics, as does an embedded field that is
// neither a mini type nor origNode.
func processType(t *ast.TypeSpec) {
	name := t.Name.Name
	fmt.Fprintf(&buf, "\n")
	fmt.Fprintf(&buf, "func (n *%s) Format(s fmt.State, verb rune) { fmtNode(n, s, verb) }\n", name)

	switch name {
	case "Name", "Func":
		// Too specialized to automate.
		return
	}

	s := t.Type.(*ast.StructType)
	fields := s.Fields.List

	// Expand any embedded fields. The index is decremented after each
	// rewrite so that position i is examined again: the fields spliced in
	// from a mini type may themselves start with an embedded mini type.
	for i := 0; i < len(fields); i++ {
		f := fields[i]
		if len(f.Names) != 0 {
			continue // not embedded
		}
		if isMini(f.Type) {
			// Insert the fields of the embedded type into the main type.
			// (It would be easier just to append, but inserting in place
			// matches the old mknode behavior.)
			// A fresh slice is built, so the mini type's own field list
			// is left untouched.
			ss := mini[f.Type.(*ast.Ident).Name].Type.(*ast.StructType)
			var f2 []*ast.Field
			f2 = append(f2, fields[:i]...)
			f2 = append(f2, ss.Fields.List...)
			f2 = append(f2, fields[i+1:]...)
			fields = f2
			i--
			continue
		} else if isNamedType(f.Type, "origNode") {
			// Ignore this field: shift the tail left by one and shrink.
			// NOTE(review): if no mini type was expanded before this
			// point, fields still shares its backing array with
			// s.Fields.List, so this shift also rewrites the parsed AST.
			copy(fields[i:], fields[i+1:])
			fields = fields[:len(fields)-1]
			i--
			continue
		} else {
			panic("unknown embedded field " + fmt.Sprintf("%v", f.Type))
		}
	}
	// Process fields, accumulating the body of each generated method
	// separately so they can be emitted one method at a time below.
	var copyBody strings.Builder
	var doChildrenBody strings.Builder
	var doChildrenWithHiddenBody strings.Builder
	var editChildrenBody strings.Builder
	var editChildrenWithHiddenBody strings.Builder
	for _, f := range fields {
		names := f.Names
		ft := f.Type
		hidden := false
		if f.Tag != nil {
			// Strip the delimiters surrounding the tag literal.
			tag := f.Tag.Value[1 : len(f.Tag.Value)-1]
			// Tags that do not start with "mknode:" are ignored.
			if strings.HasPrefix(tag, "mknode:") {
				if tag[7:] == "\"-\"" {
					// mknode:"-" drops the field, except that a field of
					// type Node is kept for the WithHidden methods only.
					if !isNamedType(ft, "Node") {
						continue
					}
					hidden = true
				} else {
					panic(fmt.Sprintf("unexpected tag value: %s", tag))
				}
			}
		}
		if isNamedType(ft, "Nodes") {
			// Nodes == []Node
			ft = &ast.ArrayType{Elt: &ast.Ident{Name: "Node"}}
		}
		// Peel off a slice (an array type with no length), then a pointer,
		// leaving ft as the underlying element type.
		isSlice := false
		if a, ok := ft.(*ast.ArrayType); ok && a.Len == nil {
			isSlice = true
			ft = a.Elt
		}
		isPtr := false
		if p, ok := ft.(*ast.StarExpr); ok {
			isPtr = true
			ft = p.X
		}
		if !implementsNode(ft) {
			continue
		}
		// From here on ft is an identifier (implementsNode accepts nothing
		// else), so formatting it with %s below yields the type's name.
		for _, name := range names {
			ptr := ""
			if isPtr {
				ptr = "*"
			}
			// The WithHidden variants include every qualifying field.
			if isSlice {
				// Slices are delegated to the do<T>s / edit<T>s helpers.
				fmt.Fprintf(&doChildrenWithHiddenBody,
					"if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name)
				fmt.Fprintf(&editChildrenWithHiddenBody,
					"edit%ss(n.%s, edit)\n", ft, name)
			} else {
				fmt.Fprintf(&doChildrenWithHiddenBody,
					"if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name)
				fmt.Fprintf(&editChildrenWithHiddenBody,
					"if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
			}
			// Hidden fields are left out of copy, doChildren and
			// editChildren.
			if hidden {
				continue
			}
			if isSlice {
				// Only slices need explicit copying; every other field is
				// already duplicated by the "c := *n" emitted below.
				fmt.Fprintf(&copyBody, "c.%s = copy%ss(c.%s)\n", name, ft, name)
				fmt.Fprintf(&doChildrenBody,
					"if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name)
				fmt.Fprintf(&editChildrenBody,
					"edit%ss(n.%s, edit)\n", ft, name)
			} else {
				fmt.Fprintf(&doChildrenBody,
					"if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name)
				fmt.Fprintf(&editChildrenBody,
					"if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
			}
		}
	}
	// Emit the five methods, wrapping each accumulated body in its
	// function header and trailer.
	fmt.Fprintf(&buf, "func (n *%s) copy() Node {\nc := *n\n", name)
	buf.WriteString(copyBody.String())
	fmt.Fprintf(&buf, "return &c\n}\n")
	fmt.Fprintf(&buf, "func (n *%s) doChildren(do func(Node) bool) bool {\n", name)
	buf.WriteString(doChildrenBody.String())
	fmt.Fprintf(&buf, "return false\n}\n")
	fmt.Fprintf(&buf, "func (n *%s) doChildrenWithHidden(do func(Node) bool) bool {\n", name)
	buf.WriteString(doChildrenWithHiddenBody.String())
	fmt.Fprintf(&buf, "return false\n}\n")
	fmt.Fprintf(&buf, "func (n *%s) editChildren(edit func(Node) Node) {\n", name)
	buf.WriteString(editChildrenBody.String())
	fmt.Fprintf(&buf, "}\n")
	fmt.Fprintf(&buf, "func (n *%s) editChildrenWithHidden(edit func(Node) Node) {\n", name)
	buf.WriteString(editChildrenWithHiddenBody.String())
	fmt.Fprintf(&buf, "}\n")
}

// generateHelpers appends to buf the slice helpers that the generated
// methods call: for each of CaseClause, CommClause, Name, and Node it emits
// copy<T>s (duplicate a slice, preserving nil), do<T>s (visit each non-nil
// element, stopping at the first true result), and edit<T>s (replace each
// non-nil element in place with the result of edit).
func generateHelpers() {
	emit := func(format string, args ...any) {
		fmt.Fprintf(&buf, format, args...)
	}

	elemTypes := [...]string{"CaseClause", "CommClause", "Name", "Node"}
	for i := range elemTypes {
		typ := elemTypes[i]

		// Slices hold pointers to the element type, except for Node:
		// interfaces don't need *.
		star := ""
		if typ != "Node" {
			star = "*"
		}

		emit("\n")

		// copy<T>s
		emit("func copy%ss(list []%s%s) []%s%s {\n", typ, star, typ, star, typ)
		emit("if list == nil { return nil }\n")
		emit("c := make([]%s%s, len(list))\n", star, typ)
		emit("copy(c, list)\n")
		emit("return c\n")
		emit("}\n")

		// do<T>s
		emit("func do%ss(list []%s%s, do func(Node) bool) bool {\n", typ, star, typ)
		emit("for _, x := range list {\n")
		emit("if x != nil && do(x) {\n")
		emit("return true\n")
		emit("}\n")
		emit("}\n")
		emit("return false\n")
		emit("}\n")

		// edit<T>s
		emit("func edit%ss(list []%s%s, edit func(Node) Node) {\n", typ, star, typ)
		emit("for i, x := range list {\n")
		emit("if x != nil {\n")
		emit("list[i] = edit(x).(%s%s)\n", star, typ)
		emit("}\n")
		emit("}\n")
		emit("}\n")
	}
}