package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"unicode"
)

func main() {
	if len(os.Args) == 1 {
		fmt.Fprintln(os.Stderr, "need at least one type name")
		os.Exit(1)
	}
	typeNames := os.Args[1:]

	goFile := os.Getenv("GOFILE")
	if goFile == "" {
		fmt.Fprintln(os.Stderr, "need GOFILE")
		os.Exit(1)
	}
	goPackage := os.Getenv("GOPACKAGE")
	if goPackage == "" {
		fmt.Fprintln(os.Stderr, "need GOPACKAGE")
		os.Exit(1)
	}

	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, goFile, nil, parser.ParseComments)
	if err != nil {
		fmt.Fprintf(os.Stderr, "parse file: %v\n", err)
		os.Exit(1)
	}

	typeFields := make(map[string][]string)
	for n := range ast.Preorder(node) {
		typeSpec, ok := n.(*ast.TypeSpec)
		if !ok {
			continue
		}

		typeName := typeSpec.Name.Name
		if !slices.Contains(typeNames, typeName) {
			continue
		}

		structType, ok := typeSpec.Type.(*ast.StructType)
		if !ok {
			continue
		}

		var fields []string
		for _, field := range structType.Fields.List {
			for _, name := range field.Names {
				fields = append(fields, name.Name)
			}
		}

		typeFields[typeName] = fields
	}

	ext := filepath.Ext(goFile)
	destName := strings.TrimSuffix(goFile, ext) + ".gen" + ext

	destf, err := os.Create(destName)
	if err != nil {
		fmt.Fprintf(os.Stderr, "create dest file: %v\n", err)
		os.Exit(1)
	}
	defer destf.Close()

	fmt.Fprintf(destf, "// Code generated by \"%s %s\"; DO NOT EDIT.\n", filepath.Base(os.Args[0]), strings.Join(os.Args[1:], " "))
	fmt.Fprintf(destf, "\n")
	fmt.Fprintf(destf, "package %s\n", goPackage)
	fmt.Fprintf(destf, "\n")
	fmt.Fprintf(destf, "import (\n")
	fmt.Fprintf(destf, "\t\"database/sql\"\n")
	fmt.Fprintf(destf, "\t\"fmt\"\n")
	fmt.Fprintf(destf, ")\n")

	for _, typeName := range typeNames {
		fmt.Fprintf(destf, "\n")

		fields := typeFields[typeName]

		firstChar := strings.ToLower(string([]rune(typeName)[0]))

		fmt.Fprintf(destf, "func _() {\n")
		fmt.Fprintf(destf, "\t// Validate the struct fields haven't changed. If this doesn't compile you probably need to `go generate` again.\n")
		fmt.Fprintf(destf, "\tvar %s %s\n", firstChar, typeName)

		var fieldRefs []string
		for _, field := range fields {
			fieldRefs = append(fieldRefs, fmt.Sprintf("%s.%s", firstChar, field))
		}

		fmt.Fprintf(destf, "\t_ = %s{%s}\n", typeName, strings.Join(fieldRefs, ", "))
		fmt.Fprintf(destf, "}\n")

		fmt.Fprintf(destf, "\n")
		fmt.Fprintf(destf, "func (%s) PrimaryKey() string {\n", typeName)
		fmt.Fprintf(destf, "\treturn \"id\"\n") // TODO: flexible PK(s)
		fmt.Fprintf(destf, "}\n")

		fmt.Fprintf(destf, "\n")
		fmt.Fprintf(destf, "func (%s %s) Values() []sql.NamedArg {\n", firstChar, typeName)

		var namedArgs []string
		for _, f := range fields {
			namedArgs = append(namedArgs, fmt.Sprintf("sql.Named(\"%s\", %s.%s)", toSnake(f), firstChar, f))
		}

		fmt.Fprintf(destf, "\treturn []sql.NamedArg{%s}\n", strings.Join(namedArgs, ", "))
		fmt.Fprintf(destf, "}\n")

		fmt.Fprintf(destf, "\n")
		fmt.Fprintf(destf, "func (%s *%s) ScanFrom(rows *sql.Rows) error {\n", firstChar, typeName)
		fmt.Fprintf(destf, "\tcolumns, err := rows.Columns()\n")
		fmt.Fprintf(destf, "\tif err != nil {\n")
		fmt.Fprintf(destf, "\t\treturn err\n")
		fmt.Fprintf(destf, "\t}\n")
		fmt.Fprintf(destf, "\tdests := make([]any, 0, len(columns))\n")
		fmt.Fprintf(destf, "\tfor _, c := range columns {\n")
		fmt.Fprintf(destf, "\t\tswitch c {\n")
		
		for _, f := range fields {
			fmt.Fprintf(destf, "\t\tcase \"%s\":\n", toSnake(f))
			fmt.Fprintf(destf, "\t\t\tdests = append(dests, &%s.%s)\n", firstChar, f)
		}

		fmt.Fprintf(destf, "\t\tdefault:\n")
		fmt.Fprintf(destf, "\t\t\treturn fmt.Errorf(\"unknown column name %%q\", c)\n")
		fmt.Fprintf(destf, "\t\t}\n")
		fmt.Fprintf(destf, "\t}\n")
		fmt.Fprintf(destf, "\treturn rows.Scan(dests...)\n")
		fmt.Fprintf(destf, "}\n")
	}
}

func toSnake(s string) string {
	var result strings.Builder
	for i, char := range s {
		if i > 0 && unicode.IsUpper(char) {
			if unicode.IsLower(rune(s[i-1])) || unicode.IsDigit(rune(s[i-1])) || (i+1 < len(s) && unicode.IsUpper(char) && i+1 < len(s) && unicode.IsLower(rune(s[i+1]))) {
				result.WriteRune('_')
			}
		}
		result.WriteRune(unicode.ToLower(char))
	}

	return result.String()
}
