package main

import (
	"fmt"
	"slices"
	"strconv"
	"unicode"

	"git.mstar.dev/mstar/aoc24/util"
	"git.mstar.dev/mstar/goutils/other"
	"git.mstar.dev/mstar/goutils/sliceutils"
)

// EmptyBlockId marks a free (unused) block in a block layout. Blocks that belong
// to a file hold that file's id instead, which indicesToBlocks assigns starting at 0.
const EmptyBlockId = -1

// findFirstEmptyIndex reports the position of the leftmost free block in list.
// When list contains no free block it returns (-1, false).
func findFirstEmptyIndex(list []int) (int, bool) {
	pos := slices.Index(list, EmptyBlockId)
	return pos, pos != -1
}

// findLastUsedIndex reports the position of the rightmost block in list that is
// not free. When every block is free (or list is empty) it returns (-1, false).
func findLastUsedIndex(list []int) (int, bool) {
	idx := len(list) - 1
	// Walk left over the trailing free blocks.
	for idx >= 0 && list[idx] == EmptyBlockId {
		idx--
	}
	return idx, idx >= 0
}

func indicesToBlocks(indices []int) []int {
	blocks := []int{}
	// indices are len of block, len of empty, len of block, len of empty
	isFile := true
	// block id increments each time a block with a file completes
	i := 0
	for _, size := range indices {
		// Write the full block
		for range size {
			if isFile {
				blocks = append(blocks, i)
			} else {
				blocks = append(blocks, EmptyBlockId)
			}
		}
		// If current block is a file, increment id counter
		if isFile {
			i++
		}
		// Swap state
		isFile = !isFile
	}
	return blocks
}

func compactBlocks1(blocks []int) (compacted []int) {
	compacted = slices.Clone(blocks)

	firstEmpty, _ := findFirstEmptyIndex(compacted)
	lastUsed, _ := findLastUsedIndex(compacted)
	for firstEmpty < lastUsed {
		util.SliceSwap(compacted, firstEmpty, lastUsed)
		firstEmpty, _ = findFirstEmptyIndex(compacted)
		lastUsed, _ = findLastUsedIndex(compacted)
	}

	return
}

// checksumBlocks returns the sum of position*id over every non-free block.
// Free blocks (EmptyBlockId) contribute nothing, so a layout may be passed with
// or without its free blocks, as long as positions are unchanged.
func checksumBlocks(blocks []int) uint64 {
	var sum uint64
	for pos := range blocks {
		if id := blocks[pos]; id != EmptyBlockId {
			sum += uint64(pos * id)
		}
	}
	return sum
}

// findFirstEmptyBlockOfLength returns the start index of the leftmost run of at
// least length consecutive free blocks in list. The scan stops as soon as a run
// reaches length, so the returned index is followed by at least length free
// blocks. It returns (-1, false) when no run is long enough; a negative length
// therefore never matches.
// NOTE(review): callers are expected to pass length >= 1; with 0 the returned
// index points just past a used block and is not meaningful.
func findFirstEmptyBlockOfLength(list []int, length int) (int, bool) {
	// runStart is the index just after the most recent used block, i.e. where
	// the current (possibly zero-length) run of free blocks begins.
	runStart := 0
	for i := range list {
		if list[i] != EmptyBlockId {
			runStart = i + 1
		}
		if i+1-runStart == length {
			return runStart, true
		}
	}
	return -1, false
}

// findBlockById locates the rightmost contiguous run of entries equal to blockId
// and returns its start index and length. When blockId is absent it returns
// (-1, -1, false).
//
// NOTE(review): a run that extends all the way to index 0 is also reported as
// (-1, -1, false). A file starting at index 0 cannot move further left, so
// compaction loses nothing, but other callers may not expect this. It looks
// unintentional.
func findBlockById(list []int, blockId int) (start, length int, found bool) {
	// Skip the trailing entries that belong to something else.
	end := len(list) - 1
	for end >= 0 && list[end] != blockId {
		end--
	}
	// Extend the run leftwards; begin ends up just before its first entry.
	begin := end
	for begin >= 0 && list[begin] == blockId {
		begin--
	}
	if end < 0 || begin < 0 {
		return -1, -1, false
	}
	return begin + 1, end - begin, true
}

// compactBlocks2 performs the part-two compaction. Whole files are moved, each
// at most once, in order of decreasing file id. A file moves into the leftmost
// run of free blocks that lies entirely to its left and is long enough to hold
// it. If no such run exists, or the file cannot be located, it stays put. The
// input slice is not modified; a compacted copy is returned.
//
// An empty disk, or one with no files, comes back unchanged (the previous
// version indexed position -1 and panicked).
func compactBlocks2(list []int) (compacted []int) {
	compacted = slices.Clone(list)
	lastUsed, ok := findLastUsedIndex(compacted)
	if !ok {
		return compacted
	}
	// NOTE(review): assumes the highest file id occupies the last used block,
	// which holds for layouts produced by indicesToBlocks.
	for id := compacted[lastUsed]; id >= 0; id-- {
		start, length, found := findBlockById(compacted, id)
		if !found {
			continue
		}
		// Search only the prefix before the file, so it can never move right
		// (e.g. into trailing free space after the last file).
		target, ok := findFirstEmptyBlockOfLength(compacted[:start], length)
		if !ok {
			continue
		}
		// The free run and the file do not overlap, so element-wise swapping
		// moves the file and leaves free blocks where it used to be.
		for i := range length {
			compacted[start+i], compacted[target+i] = compacted[target+i], compacted[start+i]
		}
	}
	return compacted
}

// main loads the puzzle input via util.LoadFileFromArgs and keeps only its
// numeric runes as the dense disk map. It expands the map into blocks, then
// prints the checksum after the part-one (per-block) and part-two (whole-file)
// compactions.
func main() {
	input := []rune(string(util.LoadFileFromArgs()))
	digits := sliceutils.Filter(input, func(r rune) bool { return unicode.IsNumber(r) })
	sizes := sliceutils.Map(digits, func(r rune) int { return other.Must(strconv.Atoi(string(r))) })
	layout := indicesToBlocks(sizes)

	firstPass := compactBlocks1(layout)
	// After part-one compaction all used blocks are contiguous from index 0,
	// so dropping the free tail leaves every position unchanged.
	usedOnly := sliceutils.Filter(firstPass, func(id int) bool { return id != EmptyBlockId })
	fmt.Printf("Task 1: %d\n", checksumBlocks(usedOnly))

	secondPass := compactBlocks2(layout)
	fmt.Printf("Task 2: %d\n", checksumBlocks(secondPass))
}