package httpsec

import (
	"fmt"
	"math/rand"
	"net/http"
	"strings"
	"testing"

	"github.com/DataDog/appsec-internal-go/netip"
	"github.com/stretchr/testify/require"
)

// ipTestCase describes a single ClientIP scenario: the request's remote
// address and headers, the header names to monitor, and the expected client IP.
type ipTestCase struct {
	// name is used as the subtest name.
	name string

	// remoteAddr is passed to ClientIP as the request's remote address; it may
	// be empty.
	remoteAddr string

	// headers maps header names to a single raw header value each.
	headers map[string]string

	// expectedIP is the client IP the test expects; the zero Addr means no
	// client IP should be resolved.
	expectedIP netip.Addr

	// clientIPHeaders, when non-nil, replaces the test's default list of
	// monitored header names.
	clientIPHeaders []string
}

// genIPTestCases builds the table of client IP resolution scenarios run by
// TestClientIP. Fresh random global/private IPv4 and IPv6 addresses are drawn on
// every call, so each invocation exercises different concrete addresses.
//
// Cases are returned in this order:
//   - missing/invalid values, multi-header ordering and custom-header cases;
//   - private and global addresses mixed within one x-forwarded-for value;
//   - remote-address-only cases;
//   - per-header IPv4 cases, then per-header IPv6 cases.
//
// A nil clientIPHeaders field means TestClientIP falls back to its default
// monitored header list. Every case name is unique so subtest failures can be
// attributed unambiguously.
func genIPTestCases() []ipTestCase {
	ipv4Global := randGlobalIPv4().String()
	ipv6Global := randGlobalIPv6().String()
	ipv4Private := randPrivateIPv4().String()
	ipv6Private := randPrivateIPv6().String()

	// Missing or invalid IPs, and mixes of valid/invalid values over a single
	// header or several headers.
	tcs := []ipTestCase{
		{
			name:       "no-headers",
			expectedIP: netip.Addr{},
		},
		{
			name:       "invalid-ipv4",
			headers:    map[string]string{"x-forwarded-for": "127..0.0.1"},
			expectedIP: netip.Addr{},
		},
		{
			name:       "invalid-ipv4-header-valid-remoteaddr",
			remoteAddr: ipv4Private,
			headers:    map[string]string{"x-forwarded-for": "127..0.0.1"},
			expectedIP: netip.MustParseAddr(ipv4Private),
		},
		{
			name:       "invalid-ipv4-recover",
			headers:    map[string]string{"x-forwarded-for": "127..0.0.1, " + ipv6Private + "," + ipv4Global},
			expectedIP: netip.MustParseAddr(ipv4Global),
		},
		{
			name:            "ip-multi-header-order-0",
			headers:         map[string]string{"x-forwarded-for": ipv4Global, "forwarded-for": ipv6Global},
			expectedIP:      netip.MustParseAddr(ipv4Global),
			clientIPHeaders: []string{"x-forwarded-for", "forwarded-for"},
		},
		{
			name:            "ip-multi-header-order-1",
			headers:         map[string]string{"x-forwarded-for": ipv4Global, "forwarded-for": ipv6Global},
			expectedIP:      netip.MustParseAddr(ipv6Global),
			clientIPHeaders: []string{"forwarded-for", "x-forwarded-for"},
		},
		{
			name:            "ipv4-multi-header-0",
			headers:         map[string]string{"x-forwarded-for": ipv4Private, "forwarded-for": ipv4Global},
			expectedIP:      netip.MustParseAddr(ipv4Global),
			clientIPHeaders: []string{"x-forwarded-for", "forwarded-for"},
		},
		{
			name:            "ipv4-multi-header-1",
			headers:         map[string]string{"x-forwarded-for": ipv4Global, "forwarded-for": ipv4Private},
			expectedIP:      netip.MustParseAddr(ipv4Global),
			clientIPHeaders: []string{"x-forwarded-for", "forwarded-for"},
		},
		{
			name:            "ipv4-multi-header-2",
			headers:         map[string]string{"x-forwarded-for": "127.0.0.1, " + ipv4Private, "forwarded-for": ipv4Private + ", " + ipv4Global},
			expectedIP:      netip.MustParseAddr(ipv4Global),
			clientIPHeaders: []string{"x-forwarded-for", "forwarded-for"},
		},
		{
			name:            "ipv4-multi-header-3",
			headers:         map[string]string{"x-forwarded-for": "127.0.0.1, " + ipv4Global, "forwarded-for": ipv4Private},
			expectedIP:      netip.MustParseAddr(ipv4Global),
			clientIPHeaders: []string{"x-forwarded-for", "forwarded-for"},
		},
		{
			name:       "invalid-ipv6",
			headers:    map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::"},
			expectedIP: netip.Addr{},
		},
		{
			name:       "invalid-ipv6-recover",
			headers:    map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::, " + ipv6Global},
			expectedIP: netip.MustParseAddr(ipv6Global),
		},
		{
			name:            "ipv6-multi-header-0",
			headers:         map[string]string{"x-forwarded-for": ipv6Private, "forwarded-for": ipv6Global},
			expectedIP:      netip.MustParseAddr(ipv6Global),
			clientIPHeaders: []string{"x-forwarded-for", "forwarded-for"},
		},
		{
			name:            "ipv6-multi-header-1",
			headers:         map[string]string{"x-forwarded-for": ipv6Global, "forwarded-for": ipv6Private},
			expectedIP:      netip.MustParseAddr(ipv6Global),
			clientIPHeaders: []string{"x-forwarded-for", "forwarded-for"},
		},
		{
			name:            "ipv6-multi-header-2",
			headers:         map[string]string{"x-forwarded-for": "127.0.0.1, " + ipv6Private, "forwarded-for": ipv6Private + ", " + ipv6Global},
			expectedIP:      netip.MustParseAddr(ipv6Global),
			clientIPHeaders: []string{"x-forwarded-for", "forwarded-for"},
		},
		{
			name:            "ipv6-multi-header-3",
			headers:         map[string]string{"x-forwarded-for": "127.0.0.1, " + ipv6Global, "forwarded-for": ipv6Private},
			expectedIP:      netip.MustParseAddr(ipv6Global),
			clientIPHeaders: []string{"x-forwarded-for", "forwarded-for"},
		},
		{
			name:       "header-case",
			headers:    map[string]string{"X-fOrWaRdEd-FoR": ipv4Global},
			expectedIP: netip.MustParseAddr(ipv4Global),
		},
		{
			name:            "user-header",
			headers:         map[string]string{"x-forwarded-for": ipv6Global, "custom-header": ipv4Global},
			expectedIP:      netip.MustParseAddr(ipv4Global),
			clientIPHeaders: []string{"custom-header"},
		},
		{
			name:            "user-header-not-found",
			headers:         map[string]string{"x-forwarded-for": ipv4Global},
			expectedIP:      netip.Addr{},
			clientIPHeaders: []string{"custom-header"},
		},
	}

	// Private and global addresses within the same header value: the first
	// global address in the list is expected to win.
	tcs = append(tcs, []ipTestCase{
		{
			name:       "ipv4-private+global",
			headers:    map[string]string{"x-forwarded-for": ipv4Private + "," + ipv4Global},
			expectedIP: netip.MustParseAddr(ipv4Global),
		},
		{
			name:       "ipv4-global+private",
			headers:    map[string]string{"x-forwarded-for": ipv4Global + "," + ipv4Private},
			expectedIP: netip.MustParseAddr(ipv4Global),
		},
		{
			name:       "ipv6-private+global",
			headers:    map[string]string{"x-forwarded-for": ipv6Private + "," + ipv6Global},
			expectedIP: netip.MustParseAddr(ipv6Global),
		},
		{
			name:       "ipv6-global+private",
			headers:    map[string]string{"x-forwarded-for": ipv6Global + "," + ipv6Private},
			expectedIP: netip.MustParseAddr(ipv6Global),
		},
		{
			name:       "mixed-private+ipv6global+ipv4global",
			headers:    map[string]string{"x-forwarded-for": ipv4Private + "," + ipv6Global + "," + ipv4Global},
			expectedIP: netip.MustParseAddr(ipv6Global),
		},
		{
			name:       "mixed-private+ipv4global+ipv6global",
			headers:    map[string]string{"x-forwarded-for": ipv4Private + "," + ipv4Global + "," + ipv6Global},
			expectedIP: netip.MustParseAddr(ipv4Global),
		},
	}...)

	// Remote address only, no headers.
	tcs = append(tcs, []ipTestCase{
		{
			name:       "ipv4-global-remoteaddr",
			remoteAddr: ipv4Global,
			expectedIP: netip.MustParseAddr(ipv4Global),
		},
		{
			name:       "ipv4-private-remoteaddr",
			remoteAddr: ipv4Private,
			expectedIP: netip.MustParseAddr(ipv4Private),
		},
		{
			name:       "ipv6-global-remoteaddr",
			remoteAddr: ipv6Global,
			expectedIP: netip.MustParseAddr(ipv6Global),
		},
		{
			name:       "ipv6-private-remoteaddr",
			remoteAddr: ipv6Private,
			expectedIP: netip.MustParseAddr(ipv6Private),
		},
	}...)

	// Every header below is monitored in the per-header cases, so each case
	// checks that the single header it sets is picked up.
	testHeaders := []string{
		"x-forwarded-for",
		"x-real-ip",
		"true-client-ip",
		"x-client-ip",
		"x-forwarded",
		"forwarded-for",
		"x-cluster-client-ip",
		"fastly-client-ip",
		"cf-connecting-ip",
		"cf-connecting-ip6",
	}

	// Simple IPv4 cases over all headers.
	for _, header := range testHeaders {
		tcs = append(tcs,
			ipTestCase{
				name:            "ipv4-global." + header,
				remoteAddr:      ipv4Private,
				headers:         map[string]string{header: ipv4Global},
				expectedIP:      netip.MustParseAddr(ipv4Global),
				clientIPHeaders: testHeaders,
			},
			ipTestCase{
				name:            "ipv4-private." + header,
				remoteAddr:      ipv6Private,
				headers:         map[string]string{header: ipv4Private},
				expectedIP:      netip.MustParseAddr(ipv4Private),
				clientIPHeaders: testHeaders,
			},
			ipTestCase{
				name:            "ipv4-global-remoteaddr-local-ip-header." + header,
				remoteAddr:      ipv4Global,
				headers:         map[string]string{header: ipv4Private},
				expectedIP:      netip.MustParseAddr(ipv4Global),
				clientIPHeaders: testHeaders,
			},
			ipTestCase{
				name:            "ipv4-global-remoteaddr-global-ip-header." + header,
				remoteAddr:      ipv6Global,
				headers:         map[string]string{header: ipv4Global},
				expectedIP:      netip.MustParseAddr(ipv4Global),
				clientIPHeaders: testHeaders,
			},
		)
	}

	// Simple IPv6 cases over all headers.
	for _, header := range testHeaders {
		tcs = append(tcs,
			ipTestCase{
				name:            "ipv6-global." + header,
				remoteAddr:      ipv4Private,
				headers:         map[string]string{header: ipv6Global},
				expectedIP:      netip.MustParseAddr(ipv6Global),
				clientIPHeaders: testHeaders,
			},
			ipTestCase{
				name:            "ipv6-private." + header,
				remoteAddr:      ipv4Private,
				headers:         map[string]string{header: ipv6Private},
				expectedIP:      netip.MustParseAddr(ipv6Private),
				clientIPHeaders: testHeaders,
			},
			ipTestCase{
				name:            "ipv6-global-remoteaddr-local-ip-header." + header,
				remoteAddr:      ipv6Global,
				headers:         map[string]string{header: ipv6Private},
				expectedIP:      netip.MustParseAddr(ipv6Global),
				clientIPHeaders: testHeaders,
			},
			ipTestCase{
				name:            "ipv6-global-remoteaddr-global-ip-header." + header,
				remoteAddr:      ipv4Global,
				headers:         map[string]string{header: ipv6Global},
				expectedIP:      netip.MustParseAddr(ipv6Global),
				clientIPHeaders: testHeaders,
			},
		)
	}

	return tcs
}

// TestClientIP runs every generated case twice. In the first pass, header keys
// are inserted through http.Header.Add, which canonicalizes them. In the second
// pass, keys are stored lower-cased as-is. The matching flag is forwarded to
// ClientIP as hasCanonicalMIMEHeaderKeys.
func TestClientIP(t *testing.T) {
	for _, canonical := range []bool{true, false} {
		t.Run(fmt.Sprintf("canonical-headers-%t", canonical), func(t *testing.T) {
			for _, tc := range genIPTestCases() {
				t.Run(tc.name, func(t *testing.T) {
					hdr := http.Header{}
					for name, value := range tc.headers {
						if !canonical {
							lowered := strings.ToLower(name)
							hdr[lowered] = append(hdr[lowered], value)
							continue
						}
						hdr.Add(name, value)
					}

					watched := tc.clientIPHeaders
					if watched == nil {
						// Default list: cases relying on it only set x-forwarded-for.
						watched = []string{"x-client-ip", "x-forwarded-for", "true-client-ip"}
					}

					remoteIP, clientIP := ClientIP(hdr, canonical, tc.remoteAddr, watched)
					tags := ClientIPTags(remoteIP, clientIP)

					// No client IP expected: neither a tag nor a valid address.
					if !tc.expectedIP.IsValid() {
						require.NotContains(t, tags, ClientIPTag)
						require.False(t, clientIP.IsValid())
						return
					}

					want := tc.expectedIP.String()
					require.Equal(t, want, clientIP.String())
					if tc.remoteAddr == "" {
						require.NotContains(t, tags, RemoteIPTag)
					} else {
						require.Equal(t, tc.remoteAddr, remoteIP.String())
						require.Equal(t, tc.remoteAddr, tags[RemoteIPTag])
					}
					require.Equal(t, want, tags[ClientIPTag])
				})
			}
		})
	}
}

// randIPv4 returns an IPv4 address whose four octets are each the low byte of
// a separate math/rand Uint32 draw, taken in order from first to last octet.
func randIPv4() netip.Addr {
	var octets [4]byte
	for i := range octets {
		octets[i] = byte(rand.Uint32())
	}
	return netip.IPv4(octets[0], octets[1], octets[2], octets[3])
}

// randIPv6 returns a 16-byte address whose bytes are each the low byte of a
// separate math/rand Uint32 draw, filled from first to last byte.
func randIPv6() netip.Addr {
	var raw [16]byte
	for i := range raw {
		raw[i] = byte(rand.Uint32())
	}
	return netip.AddrFrom16(raw)
}

// randGlobalIPv4 draws random IPv4 addresses until isGlobal accepts one, then
// returns it.
func randGlobalIPv4() netip.Addr {
	candidate := randIPv4()
	for !isGlobal(candidate) {
		candidate = randIPv4()
	}
	return candidate
}

// randGlobalIPv6 draws random 16-byte addresses until isGlobal accepts one,
// then returns it.
func randGlobalIPv6() netip.Addr {
	candidate := randIPv6()
	for !isGlobal(candidate) {
		candidate = randIPv6()
	}
	return candidate
}

// randPrivateIPv4 draws random IPv4 addresses until one is both rejected by
// isGlobal and reported private by IsPrivate, then returns it.
func randPrivateIPv4() netip.Addr {
	candidate := randIPv4()
	for isGlobal(candidate) || !candidate.IsPrivate() {
		candidate = randIPv4()
	}
	return candidate
}

// randPrivateIPv6 draws random 16-byte addresses until one is both rejected by
// isGlobal and reported private by IsPrivate, then returns it.
func randPrivateIPv6() netip.Addr {
	candidate := randIPv6()
	for isGlobal(candidate) || !candidate.IsPrivate() {
		candidate = randIPv6()
	}
	return candidate
}
