#!/usr/bin/env python3
import re
import sys
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from fnmatch import fnmatch
from pathlib import Path

from shared.common import DATA_DIR, REPO_DIR, SRC_DIR, TOOLS_DIR

# File suffixes (entries starting with a dot) and exact file names that select
# a file for linting; get_files() compares both Path.suffix and Path.name
# against this list.
EXTENSIONS = [
    "Dockerfile",
    ".h",
    ".c",
    ".build",
    ".ninja",
    ".py",
    ".cs",
    ".xaml",
    ".csproj",
    ".txt",
    ".rc",
    ".glsl",
    ".md",
    ".json",
    ".json5",
]
# Patterns matched (anchored at the start) against the repository-relative
# path of each candidate file; any match excludes the file from linting.
# This one skips everything below a "bin" or "obj" directory under tools/.
EXCEPTIONS = [
    re.compile(r"tools/([^/]*/)*(bin|obj)/.*"),
]


@dataclass
class LintWarning:
    path: Path
    message: str
    line: int | None = None

    def __str__(self) -> str:
        prefix = str(self.path.relative_to(REPO_DIR))
        if self.line is not None:
            prefix += f":{self.line}"
        return f"{prefix}: {self.message}"


def get_files() -> Iterable[Path]:
    """Yield every file that should be linted.

    Walks ``DATA_DIR``, ``SRC_DIR`` and ``TOOLS_DIR`` recursively and yields
    each regular file whose suffix or full name is listed in ``EXTENSIONS``
    and whose repository-relative path matches none of the ``EXCEPTIONS``
    patterns.
    """
    for root_dir in (DATA_DIR, SRC_DIR, TOOLS_DIR):
        # rglob() already recurses, so a plain "*" visits every entry.
        for path in root_dir.rglob("*"):
            if not path.is_file():
                continue
            if path.suffix not in EXTENSIONS and path.name not in EXTENSIONS:
                continue
            # The exception patterns are written with forward slashes, so
            # match against the POSIX form of the path; str() would produce
            # backslashes on Windows and the patterns would never match.
            relative = path.relative_to(REPO_DIR).as_posix()
            if any(pattern.match(relative) for pattern in EXCEPTIONS):
                continue
            yield path


def lint_newlines(path: Path) -> Iterable[LintWarning]:
    text = path.read_text()
    if text and not text.endswith("\n"):
        yield LintWarning(path, "missing newline character at end of file")


def lint_trailing_whitespace(path: Path) -> Iterable[LintWarning]:
    for i, line in enumerate(path.open("r"), 1):
        if line.rstrip("\n").endswith(" "):
            yield LintWarning(path, "trailing whitespace", line=i)


# Linters that main() runs, in this order, on every file. Each one takes the
# path of the file to check and yields the warnings it finds.
LINTERS: list[Callable[[Path], Iterable[LintWarning]]] = [
    lint_newlines,
    lint_trailing_whitespace,
]


def main() -> None:
    """Run every linter on every file and exit with the overall status.

    Each warning is printed to stderr on its own line. The process exits
    with status 1 if at least one warning was produced and 0 otherwise.
    """
    exit_code = 0
    for file in get_files():
        for linter in LINTERS:
            for lint_warning in linter(file):
                print(lint_warning, file=sys.stderr)
                exit_code = 1

    # Use sys.exit instead of the bare ``exit`` builtin: the latter is added
    # by the ``site`` module for interactive use and is missing when Python
    # runs with -S.
    sys.exit(exit_code)


# Run the linter only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
