#!/usr/bin/env python3
import argparse
import re
from pathlib import Path
from shutil import which
from subprocess import run

from shared.common import SRC_DIR


def fix_imports(path: Path) -> None:
    """Run include-what-you-use on *path* and feed its report to the fixer.

    The text the analyzer writes to stderr is piped to the standard input of
    the fixer script, which is invoked as ``fix_include`` when that name is
    found on ``PATH`` and as ``iwyu-fix-includes`` otherwise.  The exit
    status of neither process is inspected.

    Args:
        path: Source file to analyze.
    """
    analysis = run(
        ["include-what-you-use", "-I", "src", path],
        capture_output=True,
        text=True,
    )
    report = analysis.stderr

    # The fixer script is looked up under two names; prefer ``fix_include``.
    fixer = "fix_include" if which("fix_include") else "iwyu-fix-includes"
    run([fixer], input=report, text=True)


def custom_sort(source: list[str], forced_order: list[str]) -> list[str]:
    """Sort *source* alphabetically while keeping forced items together.

    Items listed in *forced_order* all sort at the alphabetical position of
    ``forced_order[0]`` and, among themselves, in the order given by
    *forced_order*.  Every other item sorts by plain string comparison.

    Args:
        source: Strings to sort.  Any iterable of strings works, because it
            is only handed to ``sorted``.
        forced_order: Items whose relative order must be preserved.

    Returns:
        A new sorted list; *source* is left unmodified.
    """
    # Build the lookup once instead of scanning the list for every item.
    # ``setdefault`` keeps the first occurrence, matching ``list.index``.
    rank: dict[str, int] = {}
    for position, item in enumerate(forced_order):
        rank.setdefault(item, position)

    def key_func(item: str) -> tuple[str, int]:
        if item in rank:
            # Anchor the whole forced group where its first member sorts.
            return (forced_order[0], rank[item])
        return (item, 0)

    return sorted(source, key=key_func)


def sort_imports(path: Path) -> None:
    """Deduplicate, regroup and sort the ``#include`` directives in *path*.

    Every run of consecutive ``#include`` lines (blank lines between them are
    allowed) is rewritten as up to three blocks separated by one blank line,
    in this order: the file's own header, other quoted includes, and
    angle-bracket includes.  Each block is ordered with ``custom_sort``.
    Only the include target is kept, so anything else on an include line
    (such as a trailing comment) is dropped.  A run containing a directive
    whose target cannot be extracted is left untouched.  The file is written
    back only if its text changed.

    Args:
        path: File to rewrite in place; it must be located under ``SRC_DIR``.

    Raises:
        ValueError: If *path* is not located under ``SRC_DIR``.
    """
    original = path.read_text()

    # Include targets are written with forward slashes, so compare in POSIX
    # form instead of relying on the host platform's path separator.
    rel_path = path.relative_to(SRC_DIR)
    own_include = {
        # files headers of which are not a 1:1 match with their filename
        "game/game/game.c": "game/game.h",
        "game/game/game_cutscene.c": "game/game.h",
        "game/game/game_demo.c": "game/game.h",
        "game/game/game_draw.c": "game/game.h",
        "game/game/game_pause.c": "game/game.h",
        "game/gun/gun.c": "game/gun.h",
        "game/inventry.c": "game/inv.h",
        "game/invfunc.c": "game/inv.h",
        "game/invvars.c": "game/inv.h",
        "game/lara/lara.c": "game/lara.h",
        "game/option/option.c": "game/option.h",
        "game/savegame/savegame.c": "game/savegame.h",
        "specific/s_audio_sample.c": "specific/s_audio.h",
        "specific/s_audio_stream.c": "specific/s_audio.h",
        "specific/s_log_unknown.c": "specific/s_log.h",
        "specific/s_log_linux.c": "specific/s_log.h",
        "specific/s_log_windows.c": "specific/s_log.h",
    }.get(rel_path.as_posix(), rel_path.with_suffix(".h").as_posix())

    # These headers keep this relative order instead of alphabetical order.
    forced_order = [
        "<windows.h>",
        "<dbghelp.h>",
        "<tlhelp32.h>",
    ]

    def cb(match: re.Match[str]) -> str:
        block = match.group(0)
        includes = re.findall(r'#include (["<][^"<>]+[">])', block)

        # If any directive in this run could not be parsed (for example
        # ``#include MACRO``), rewriting would silently delete it, so the
        # run is returned unchanged.
        directive_count = sum(
            1 for line in block.splitlines() if line.startswith("#include ")
        )
        if len(includes) != directive_count:
            return block

        # Sets drop duplicate includes; dict order fixes the block order.
        groups: dict[str, set[str]] = {
            "own": set(),
            "proj": set(),
            "lib": set(),
        }
        for include in includes:
            if include.strip('"') == own_include:
                groups["own"].add(include)
            elif include.startswith("<"):
                groups["lib"].add(include)
            else:
                groups["proj"].add(include)

        return "\n\n".join(
            "\n".join(
                f"#include {include}"
                for include in custom_sort(group, forced_order)
            )
            for group in groups.values()
            if group
        ).strip()

    updated = re.sub(
        "^#include [^\n]+(\n*#include [^\n]+)*",
        cb,
        original,
        flags=re.M,
    )
    # Avoid touching the file (and its mtime) when nothing changed.
    if updated != original:
        path.write_text(updated)


def parse_args() -> argparse.Namespace:
    """Parse the command line.

    Returns:
        A namespace whose ``paths`` attribute is the (possibly empty) list
        of positional arguments, each converted to a ``Path``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        dest="paths",
        metavar="path",
        nargs="*",
        type=Path,
    )
    namespace = cli.parse_args()
    return namespace


def main() -> None:
    """Fix and sort includes for the requested files.

    Paths given on the command line are made absolute and processed in the
    order given.  When none are given, every ``*.c`` and ``*.h`` file found
    recursively under ``SRC_DIR`` is processed in sorted order, except
    ``SRC_DIR / "init.c"``.
    """
    args = parse_args()
    targets = [given.absolute() for given in args.paths]

    if not targets:
        excluded = SRC_DIR / "init.c"
        targets = sorted(
            candidate
            for candidate in SRC_DIR.glob("**/*.[ch]")
            if candidate != excluded
        )

    for target in targets:
        fix_imports(target)
        sort_imports(target)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
