//go:build linux

package iptables

import (
	"errors"
	"fmt"
	"os/exec"
	"slices"
	"strings"

	log "github.com/sirupsen/logrus"

	"github.com/crowdsecurity/crowdsec/pkg/models"

	"github.com/crowdsecurity/cs-firewall-bouncer/pkg/cfg"
	"github.com/crowdsecurity/cs-firewall-bouncer/pkg/ipsetcmd"
	"github.com/crowdsecurity/cs-firewall-bouncer/pkg/types"
)

// Positional indexes for dropped-traffic counters.
// NOTE(review): presumably these index a two-element [packets, bytes] counter
// pair built elsewhere in the package — confirm against the code that uses them.
const (
	// IPTablesDroppedPacketIdx is the index (0) of the dropped-packet value.
	IPTablesDroppedPacketIdx = 0
	// IPTablesDroppedByteIdx is the index (1) of the dropped-byte value.
	IPTablesDroppedByteIdx   = 1
)

// iptables is the backend value returned by NewIPTables. It holds one context
// per address family. A nil field means that family is disabled: NewIPTables
// only populates a field when the configuration does not disable the family,
// and every method checks for nil before using it.
type iptables struct {
	// v4 is the IPv4 context, or nil when IPv4 is disabled.
	v4 *ipTablesContext
	// v6 is the IPv6 context, or nil when IPv6 is disabled.
	v6 *ipTablesContext
}

// NewIPTables builds the iptables/ipset backend described by config.
//
// The deny action is upper-cased, defaults to DROP when empty, and must be one
// of DROP, REJECT, TARPIT or LOG. For every address family that is not
// disabled, the matching "*-save" binary is looked up in PATH. In ipset mode
// only the configured set is managed; otherwise the iptables/ip6tables binary
// is resolved as well, existing sets whose name starts with the blacklist name
// are adopted, and config.IptablesChains is appended to the family-specific
// chain list.
//
// Outside ipset mode config.IptablesV4Chains / config.IptablesV6Chains are
// updated in place (the shared chains are appended to them), so the caller's
// config is modified.
//
// An error is returned when the deny action is invalid, when a required binary
// cannot be found, or when ipsetcmd.NewIPSet fails.
func NewIPTables(config *cfg.BouncerConfig) (types.Backend, error) {
	defaultSet, err := ipsetcmd.NewIPSet("")
	if err != nil {
		return nil, err
	}

	allowedActions := []string{"DROP", "REJECT", "TARPIT", "LOG"}

	target := strings.ToUpper(config.DenyAction)
	if target == "" {
		target = "DROP"
	}

	// Validate before announcing the action, so an invalid value is never
	// logged as being in use.
	if !slices.Contains(allowedActions, target) {
		return nil, fmt.Errorf("invalid deny_action '%s', must be one of %s", config.DenyAction, strings.Join(allowedActions, ", "))
	}

	log.Infof("using '%s' as deny_action", target)

	ret := &iptables{}

	if !config.DisableIPV4 {
		ctx := newIPTablesFamilyContext(config, "v4", config.BlacklistsIpv4, defaultSet, target)

		if err := resolveIPTablesFamily(ctx, config, "iptables-save", "iptables"); err != nil {
			return nil, err
		}

		if !ctx.ipsetContentOnly {
			config.IptablesV4Chains = append(config.IptablesV4Chains, config.IptablesChains...)
			ctx.Chains = config.IptablesV4Chains
		}

		ret.v4 = ctx
	}

	if !config.DisableIPV6 {
		ctx := newIPTablesFamilyContext(config, "v6", config.BlacklistsIpv6, defaultSet, target)

		if err := resolveIPTablesFamily(ctx, config, "ip6tables-save", "ip6tables"); err != nil {
			return nil, err
		}

		if !ctx.ipsetContentOnly {
			config.IptablesV6Chains = append(config.IptablesV6Chains, config.IptablesChains...)
			ctx.Chains = config.IptablesV6Chains
		}

		ret.v6 = ctx
	}

	return ret, nil
}

// newIPTablesFamilyContext returns a context for one address family, filled
// with the settings that are identical for IPv4 and IPv6.
func newIPTablesFamilyContext(config *cfg.BouncerConfig, version string, setName string, defaultSet *ipsetcmd.IPSet, target string) *ipTablesContext {
	return &ipTablesContext{
		version:              version,
		SetName:              setName,
		SetType:              config.SetType,
		SetSize:              config.SetSize,
		ipsetDisableTimeouts: config.SetDisableTimeouts,
		Chains:               []string{},
		defaultSet:           defaultSet,
		target:               target,
		loggingEnabled:       config.DenyLog,
		loggingPrefix:        config.DenyLogPrefix,
		addRuleComments:      config.IptablesAddRuleComments,
	}
}

// resolveIPTablesFamily locates the binaries required for ctx's address family
// and initialises ctx.ipsets. saveBin and bin are the executable names to look
// up in PATH (for example "iptables-save" and "iptables").
func resolveIPTablesFamily(ctx *ipTablesContext, config *cfg.BouncerConfig, saveBin string, bin string) error {
	var err error

	ctx.iptablesSaveBin, err = exec.LookPath(saveBin)
	if err != nil {
		return errors.New("unable to find " + saveBin)
	}

	if config.Mode == cfg.IpsetMode {
		ctx.ipsetContentOnly = true

		set, err := ipsetcmd.NewIPSet(ctx.SetName)
		if err != nil {
			return err
		}

		ctx.ipsets = map[string]*ipsetcmd.IPSet{"ipset": set}

		return nil
	}

	ctx.iptablesBin, err = exec.LookPath(bin)
	if err != nil {
		return errors.New("unable to find " + bin)
	}

	// Try to "adopt" any leftover sets from a previous run if we crashed.
	// They will get flushed/deleted just after. This is best effort: a failure
	// is logged rather than returned.
	sets, lookupErr := ipsetcmd.GetSetsStartingWith(ctx.SetName)
	if lookupErr != nil {
		log.Warnf("unable to list existing ipsets starting with '%s': %v", ctx.SetName, lookupErr)
	}

	// Never hand a nil map to the context: writing to a nil map panics.
	if sets == nil {
		sets = make(map[string]*ipsetcmd.IPSet)
	}

	ctx.ipsets = sets

	return nil
}

// Init prepares every enabled address family, IPv4 first and then IPv6.
// For each one it runs shutDown to flush state left by a previous run and,
// unless the context only manages ipset content, calls setupChain.
// It stops at the first shutDown failure and returns it wrapped.
func (ipt *iptables) Init() error {
	families := []struct {
		ctx     *ipTablesContext
		started string
	}{
		{ctx: ipt.v4, started: "iptables for ipv4 initiated"},
		{ctx: ipt.v6, started: "iptables for ipv6 initiated"},
	}

	for _, family := range families {
		// A nil context means the family is disabled.
		if family.ctx == nil {
			continue
		}

		log.Info(family.started)

		// flush before init
		if err := family.ctx.shutDown(); err != nil {
			return fmt.Errorf("iptables shutdown failed: %w", err)
		}

		if family.ctx.ipsetContentOnly {
			continue
		}

		family.ctx.setupChain()
	}

	return nil
}

// Commit calls commit on the IPv4 context and then on the IPv6 context,
// skipping any family that is disabled (nil). The first failure is returned
// wrapped, and the remaining family is not committed.
func (ipt *iptables) Commit() error {
	if v4 := ipt.v4; v4 != nil {
		if commitErr := v4.commit(); commitErr != nil {
			return fmt.Errorf("ipset for ipv4 commit failed: %w", commitErr)
		}
	}

	if v6 := ipt.v6; v6 != nil {
		if commitErr := v6.commit(); commitErr != nil {
			return fmt.Errorf("ipset for ipv6 commit failed: %w", commitErr)
		}
	}

	return nil
}

// Add hands decision to the context of the matching address family.
// Decisions whose type starts with "simulation:" are skipped. A value that
// contains ':' is treated as IPv6, anything else as IPv4. When the matching
// family is disabled the decision is only logged at debug level.
// Add always returns nil.
func (ipt *iptables) Add(decision *models.Decision) error {
	if strings.HasPrefix(*decision.Type, "simulation:") {
		log.Debugf("measure against '%s' is in simulation mode, skipping it", *decision.Value)
		return nil
	}

	value := *decision.Value
	isV6 := strings.Contains(value, ":")

	switch {
	case isV6 && ipt.v6 == nil:
		log.Debugf("not adding '%s' because ipv6 is disabled", value)
	case isV6:
		ipt.v6.add(decision)
	case ipt.v4 == nil:
		log.Debugf("not adding '%s' because ipv4 is disabled", value)
	default:
		ipt.v4.add(decision)
	}

	return nil
}

// ShutDown calls shutDown on the IPv4 context and then on the IPv6 context,
// skipping any family that is disabled (nil). The first failure is returned
// wrapped, and the remaining family is left untouched.
func (ipt *iptables) ShutDown() error {
	if v4 := ipt.v4; v4 != nil {
		stopErr := v4.shutDown()
		if stopErr != nil {
			return fmt.Errorf("iptables for ipv4 shutdown failed: %w", stopErr)
		}
	}

	if v6 := ipt.v6; v6 != nil {
		stopErr := v6.shutDown()
		if stopErr != nil {
			return fmt.Errorf("iptables for ipv6 shutdown failed: %w", stopErr)
		}
	}

	return nil
}

// Delete removes decision from the context of the matching address family.
//
// Family detection mirrors Add: a value containing ':' is IPv6 (this includes
// IPv4-mapped forms such as "::ffff:192.0.2.1", which also contain '.');
// otherwise a value containing '.' is IPv4. A value containing neither is
// rejected with an error. When the matching family is disabled the decision is
// only logged at debug level and nil is returned.
//
// A failure reported by the context's delete is returned wrapped, so callers
// can see the underlying cause.
func (ipt *iptables) Delete(decision *models.Decision) error {
	value := *decision.Value

	switch {
	case strings.Contains(value, ":"):
		if ipt.v6 == nil {
			log.Debugf("not deleting '%s' because ipv6 is disabled", value)
			return nil
		}

		if err := ipt.v6.delete(decision); err != nil {
			return fmt.Errorf("failed deleting ban: %w", err)
		}
	case strings.Contains(value, "."):
		if ipt.v4 == nil {
			log.Debugf("not deleting '%s' because ipv4 is disabled", value)
			return nil
		}

		if err := ipt.v4.delete(decision); err != nil {
			return fmt.Errorf("failed deleting ban: %w", err)
		}
	default:
		return fmt.Errorf("failed deleting ban: ip %s was not recognized", value)
	}

	return nil
}
