// SPDX-FileCopyrightText: 2024 Olivier Charvin <git@olivier.pfad.fr>
//
// SPDX-License-Identifier: CC0-1.0

package check

import (
	"bytes"
	"fmt"
	"slices"
	"strconv"
	"strings"
)

// Diff computation based on Myers' diff algorithm.
// Implemented with help from the visualizations at https://blog.robertelder.org/diff-algorithm/

// In the edit graph, the "old" sequence runs horizontally (x indexes oldSlice)
// and the "new" sequence vertically (y indexes newSlice).

// hunk is one run of consecutive elements in an edit script (as returned by
// newDiff).
//
// The two flags tell which side(s) the run belongs to:
//   - inOld && inNew: unchanged run; position indexes the old slice
//   - inOld only: deleted run; position indexes the old slice
//   - inNew only: inserted run; position indexes the new slice
type hunk struct {
	inOld    bool // the run is present in the old slice
	inNew    bool // the run is present in the new slice
	position int  // index of the first element (slice depends on the flags)
	len      int  // number of elements in the run
}

// kLines is a helper that lays out, in a single contiguous slice, the best
// (furthest-reaching) x of every k-line for every edit distance d.
//
// The slots of step d follow those of step d-1, and within a step the
// k-lines are stored from the highest k to the lowest. Without clamping,
// step d has d+1 k-lines, so it starts at the "triangle" number d(d+1)/2.
// Once d exceeds newLen (resp. oldLen), the lowest (resp. highest) k-lines
// leave the edit graph and the corresponding smaller triangles are
// subtracted, so that no slot is wasted (see Range and IndexOffset).
//
// Example with oldLen=5 and newLen=3 (1-based numbering):
//
//	1 2 4  7  10
//	3 5 8  11 13
//	6 9 12 14 15
//
// NOTE(review): this grid numbers the anti-diagonals of a 5x3 grid; it
// illustrates the "triangle minus clamped corners" idea rather than the
// exact slice layout for these lengths — confirm.
type kLines struct {
	oldLen int // length of the old (horizontal) sequence
	newLen int // length of the new (vertical) sequence
}

// Range describes the k-lines visited at edit distance d.
//
// kMin and kMax are the lowest and highest k-lines (k = x - y, stepping by 2)
// that still intersect the edit graph: at most -d..d, but once d exceeds
// newLen (resp. oldLen) the lower (resp. upper) end is clamped.
//
// indexBase is the value to which IndexOffset(d, k) must be added to get the
// slot of k-line k in the storage slice. Without clamping it is the slot of
// k = d, i.e. d(d+1)/2. The clamped k-lines of all previous steps, and for
// kMax < d those of the current step, are subtracted so storage stays
// contiguous. In that case indexBase is a virtual base, not the slot of kMax.
func (kl kLines) Range(d int) (kMin, kMax, indexBase int) {
	kMin, kMax = -d, d
	indexBase = d * (d + 1) / 2

	if excess := d - kl.newLen; excess > 0 {
		// k = deletions - insertions with d = deletions + insertions, so
		// insertions <= newLen implies k >= d - 2*newLen.
		kMin = excess - kl.newLen
		indexBase -= excess * (excess - 1) / 2
	}
	if excess := d - kl.oldLen; excess > 0 {
		// Likewise deletions <= oldLen implies k <= 2*oldLen - d.
		kMax = kl.oldLen - excess
		indexBase -= excess * (excess + 1) / 2
	}
	return kMin, kMax, indexBase
}

// IndexOffset returns the position of k-line k inside the block of slots of
// edit distance d, counted from the k = d line. k-lines are two apart, so
// k = d, d-2, d-4, ... map to 0, 1, 2, ...
// It is meant to be added to the indexBase returned by Range(d).
func (kLines) IndexOffset(d, k int) int {
	distanceFromTop := d - k
	return distanceFromTop / 2
}

// newDiff computes a shortest edit script turning oldSlice into newSlice with
// Myers' greedy algorithm (forward search only).
//
// For each edit distance d = 0, 1, 2, ... it records, for every k-line
// (k = x - y) that still intersects the edit graph, the furthest x reachable
// with d edits followed by as many matching elements as possible. The search
// stops at the first d that reaches (len(oldSlice), len(newSlice)); the hunks
// are then rebuilt by reconstructEditScript from the recorded values.
//
// Both slices empty: nil. Exactly one empty: a single inserted or deleted hunk.
func newDiff[T comparable](oldSlice, newSlice []T) []hunk {
	oldLen, newLen := len(oldSlice), len(newSlice)
	switch {
	case oldLen == 0 && newLen == 0:
		return nil
	case oldLen == 0:
		return []hunk{{inNew: true, position: 0, len: newLen}}
	case newLen == 0:
		return []hunk{{inOld: true, position: 0, len: oldLen}}
	}

	// TODO: run from the end as well to find a middle snake
	// (but keep the history to prevent recursion)
	// TODO: add benchmark to compare with
	//  - https://pkg.go.dev/cloudeng.io/algo/lcs
	//  - https://github.com/golang/tools/blob/master/internal/diff/myers/diff.go
	kl := kLines{oldLen, newLen}

	// bestX holds, at slot base(d) + kl.IndexOffset(d, k), the furthest x on
	// k-line k after d edits.
	var bestX []int
	// prevBase is the indexBase of step d-1; it stays negative during step 0.
	prevBase := -1
	for d := 0; ; d++ {
		kMin, kMax, base := kl.Range(d)
		for k := kMin; k <= kMax; k += 2 {
			x := 0 // step 0 starts at the origin
			if prevBase >= 0 {
				minusSlot := prevBase + kl.IndexOffset(d-1, k-1)
				plusSlot := prevBase + kl.IndexOffset(d-1, k+1)
				// Step down (insertion) from k-line k+1 or right (deletion) from
				// k-line k-1, whichever reaches further. The short-circuit keeps
				// the non-existent neighbor (k-1 when k == -d, k+1 when k == d)
				// from being read.
				insertion := k == -d || (k != d && bestX[minusSlot] < bestX[plusSlot])
				if insertion {
					x = bestX[plusSlot]
				} else {
					x = bestX[minusSlot] + 1
				}
			}
			y := x - k
			// Follow the diagonal while the elements match.
			for x < oldLen && y < newLen && oldSlice[x] == newSlice[y] {
				x, y = x+1, y+1
			}
			if x == oldLen && y == newLen {
				return reconstructEditScript(d, k, x, bestX, kl)
			}
			slot := base + kl.IndexOffset(d, k)
			if missing := slot + 1 - len(bestX); missing > 0 {
				bestX = append(bestX, make([]int, missing)...)
			}
			bestX[slot] = x
		}
		prevBase = base
	}
}

// reconstructEditScript rebuilds the edit script once the forward search has
// reached the end of both sequences at edit distance d, on k-line k, with
// final x (the length of the old sequence).
//
// It walks the recorded best-x values in v backward from d to 0 and, at each
// step, re-derives which move the forward search took. A move down is an
// insertion from k-line k+1; a move right is a deletion from k-line k-1. The
// diagonal run that followed that move is recorded as unchanged. Consecutive
// insertions (resp. deletions) are merged into one hunk. The hunks are
// returned in forward order.
func reconstructEditScript(d, k, x int, v []int, kl kLines) []hunk {
	const (
		kindSame   = 0
		kindInsert = 1
		kindDelete = -1
	)
	hunks := make([]hunk, 0, d) // built back to front, reversed at the end
	last := kindSame

	// keep records the unchanged run old[from:to], if it is not empty.
	keep := func(from, to int) {
		if from >= to {
			return
		}
		hunks = append(hunks, hunk{inOld: true, inNew: true, position: from, len: to - from})
		last = kindSame
	}
	// edit records one inserted (pos in new) or deleted (pos in old) element.
	// Walking backward, an element of the same kind as the previous one sits
	// right before it, so that hunk is grown toward the start instead.
	edit := func(kind, pos int) {
		if last == kind {
			hunks[len(hunks)-1].position--
			hunks[len(hunks)-1].len++
			return
		}
		h := hunk{position: pos, len: 1}
		if kind == kindInsert {
			h.inNew = true
		} else {
			h.inOld = true
		}
		hunks = append(hunks, h)
		last = kind
	}

	for ; d > 0; d-- {
		_, _, prevBase := kl.Range(d - 1)
		minusSlot := prevBase + kl.IndexOffset(d-1, k-1)
		plusSlot := prevBase + kl.IndexOffset(d-1, k+1)
		if k == -d || (k != d && v[minusSlot] < v[plusSlot]) {
			// Came down from k-line k+1: the new element at y = prevX-(k+1)
			// was inserted, then the diagonal ran from prevX to x.
			prevX := v[plusSlot]
			keep(prevX, x)
			edit(kindInsert, prevX-(k+1))
			x, k = prevX, k+1
		} else {
			// Came right from k-line k-1: the old element at prevX was
			// deleted, then the diagonal ran from prevX+1 to x.
			prevX := v[minusSlot]
			keep(prevX+1, x)
			edit(kindDelete, prevX)
			x, k = prevX, k-1
		}
	}
	// Diagonal run found at d == 0, starting from the origin.
	keep(0, x)

	slices.Reverse(hunks)
	return hunks
}

// newUnifiedDiff renders the differences between oldSlice (shown as "want")
// and newSlice (shown as "got") as unified-diff-like text.
//
// The output starts with a header line giving both lengths. Each following
// hunk starts with "@@ -oldStart[,oldCount] +newStart[,newCount] @@". Starts
// are 1-based, and a count is omitted when it is exactly 1. Each element is
// listed on its own line, prefixed by ' ' (unchanged context), '-' (only in
// oldSlice) or '+' (only in newSlice). Up to `context` unchanged elements are
// kept around each change. Two changes separated by at most 2*context
// unchanged elements share a hunk.
//
// Elements are formatted with %#v, except when T is string. Strings are
// escaped like strconv.Quote but printed without surrounding quotes, and inner
// double quotes are left unescaped.
//
// If the slices are equal, only the header line is returned.
func newUnifiedDiff[T comparable](oldSlice, newSlice []T) string {
	const context = 2

	hunks := newDiff(oldSlice, newSlice)
	var sb strings.Builder
	fmt.Fprintf(&sb, "--- want(len=%d) +++ got(len=%d)", len(oldSlice), len(newSlice))

	// Equal inputs yield only unchanged hunks. Rendering them would produce a
	// hunk made solely of context lines, so stop after the header.
	if !slices.ContainsFunc(hunks, func(h hunk) bool { return !h.inOld || !h.inNew }) {
		return sb.String()
	}

	// 0-based start of the hunk being built, in each slice.
	oldStart, newStart := 0, 0
	// Number of elements of each slice consumed so far.
	oldPos, newPos := 0, 0

	formatValue := func(v any) string { return fmt.Sprintf("%#v", v) }
	var zero T
	if _, isString := any(zero).(string); isString {
		// Quote to escape special characters (like newlines), then drop the
		// surrounding quotes and unescape the inner ones for readability.
		formatValue = func(v any) string {
			quoted := strconv.Quote(v.(string))
			quoted = quoted[1 : len(quoted)-1]
			return strings.ReplaceAll(quoted, "\\\"", "\"")
		}
	}

	// Body lines of the hunk being built.
	var currentHunk bytes.Buffer
	appendDiff := func(prefix byte, v any) {
		currentHunk.WriteByte('\n')
		currentHunk.WriteByte(prefix)
		currentHunk.WriteString(formatValue(v))
	}
	// writeCurrentHunk flushes the pending hunk, if any, with its header.
	// oldEnd and newEnd are the exclusive 0-based ends of its ranges.
	// NOTE(review): for an empty range this prints start+1 (e.g. "-1,0"),
	// whereas GNU diff prints the preceding line number (e.g. "-0,0").
	writeCurrentHunk := func(oldEnd, newEnd int) {
		if currentHunk.Len() == 0 {
			return
		}
		sb.WriteString("\n@@ -")
		sb.WriteString(strconv.Itoa(oldStart + 1))
		if l := oldEnd - oldStart; l != 1 {
			sb.WriteByte(',')
			sb.WriteString(strconv.Itoa(l))
		}
		sb.WriteString(" +")
		sb.WriteString(strconv.Itoa(newStart + 1))
		if l := newEnd - newStart; l != 1 {
			sb.WriteByte(',')
			sb.WriteString(strconv.Itoa(l))
		}
		sb.WriteString(" @@")
		sb.Write(currentHunk.Bytes())

		currentHunk.Reset()
	}

	for i, h := range hunks {
		isLast := i == len(hunks)-1
		switch {
		case h.inOld && h.inNew: // unchanged run
			oldPos += h.len
			newPos += h.len
			switch {
			case currentHunk.Len() == 0:
				// Leading unchanged run (only possible for the first hunk):
				// keep its last `context` elements as leading context.
				if skip := h.len - context; skip > 0 {
					oldStart += skip
					newStart += skip
				}
				for p := 0; p < min(context, h.len); p++ {
					appendDiff(' ', oldSlice[oldStart+p])
				}
			case h.len <= 2*context && !isLast:
				// Short run between two changes: keep it whole in this hunk.
				for p := 0; p < h.len; p++ {
					appendDiff(' ', oldSlice[h.position+p])
				}
			default:
				// Long (or trailing) run: close the current hunk after up to
				// `context` elements of trailing context...
				skip := max(0, h.len-context)
				for p := 0; p < h.len-skip; p++ {
					appendDiff(' ', oldSlice[h.position+p])
				}
				writeCurrentHunk(oldPos-skip, newPos-skip)

				if !isLast {
					// ...and open the next one with `context` elements of
					// leading context (h.len > 2*context here).
					for p := h.len - context; p < h.len; p++ {
						appendDiff(' ', oldSlice[h.position+p])
					}
					oldStart = oldPos - context
					newStart = newPos - context
				}
			}
		case h.inOld: // deleted run
			oldPos += h.len
			for p := 0; p < h.len; p++ {
				appendDiff('-', oldSlice[h.position+p])
			}
		case h.inNew: // inserted run
			newPos += h.len
			for p := 0; p < h.len; p++ {
				appendDiff('+', newSlice[h.position+p])
			}
		}
	}
	writeCurrentHunk(oldPos, newPos)
	return sb.String()
}
