#!/usr/bin/env python3

"""This script can be used to analyze TR1X for potential memory leaks.

For the script to work, the repository needs to have
tools/analyze_memleaks.patch applied, which logs information about every
memory allocation and free.
"""

import argparse
import re
from collections import defaultdict
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from itertools import groupby
from pathlib import Path
from typing import Any, TypeVar

# Generic placeholders used in group()'s signature: T is the element type of
# the source iterable, U is the type of the grouping key.
T = TypeVar("T")
U = TypeVar("U")


class OperationType(Enum):
    """Kind of memory operation recorded in the log."""

    # Produced for "Allocating memory (<size> bytes): <address>" lines.
    Allocation = 1
    # Produced for "Reallocating memory at <old> (<size> bytes): <new>" lines.
    Reallocation = 2
    # Produced for "Freeing memory at <address>" lines.
    Deallocation = 3


@dataclass
class Operation:
    """A single memory operation parsed from one log line."""

    # Source file, line number and function name taken from the start of the
    # log line.
    file: str
    line: int
    func: str
    # Which kind of memory operation this is.
    type: OperationType
    # Address the operation acts on, parsed from hexadecimal. For a
    # reallocation this is the address named after "at" in the log message,
    # which find_problems stops tracking.
    address: int
    # Size in bytes from the log message; None for deallocations, whose log
    # lines carry no size.
    size: int | None
    # Address reported after a reallocation; stays None for the other
    # operation types.
    new_address: int | None = None


@dataclass
class Problem:
    """A finding reported by find_problems."""

    # Human-readable description of the finding.
    message: str
    # Operation the finding is attributed to; None for findings that are not
    # tied to a single operation (the total reachable memory summary).
    operation: Operation | None


def group(
    source: Iterable[T], key: Callable[[T], U] | None = None
) -> Iterable[tuple[U, list[T]]]:
    """Group the items of ``source`` by ``key``, in ascending key order.

    Unlike a bare ``itertools.groupby``, the input does not have to be
    pre-sorted: it is sorted by ``key`` first, so every distinct key is
    yielded exactly once.

    ``source`` holds the items to group. ``key`` computes the grouping key
    of an item; the keys must be mutually orderable. When ``key`` is
    omitted, items are grouped by their own value.

    Yields ``(key, items)`` pairs. Items within a group keep their relative
    order from ``source`` because ``sorted`` is stable.
    """
    # sorted() and groupby() both treat key=None as "use the item itself".
    ordered = sorted(source, key=key)
    for group_key, members in groupby(ordered, key=key):
        # Materialize the group: groupby's sub-iterators are invalidated as
        # soon as the outer iterator advances.
        yield group_key, list(members)


def get_operations_from_log_lines(lines: Iterable[str]) -> Iterable[Operation]:
    """Yield an Operation for every memory-related line of the log.

    Recognized messages, each preceded by "<file> <line> <func> ":

    - "Allocating memory (<size> bytes): <address>"
    - "Reallocating memory at <address> (<size> bytes): <new address>"
    - "Freeing memory at <address>"

    Lines containing none of these messages are skipped. Addresses are
    hexadecimal in either letter case, optionally prefixed with "0x". The
    text "(nil)", which glibc prints for a null pointer formatted with
    "%p", is read as address 0.
    """
    location = r"(\S+) (\d+) (\S+) "
    size = r"\((\d+) bytes\)"
    # At least one hex digit is required so that int(..., 16) below is never
    # handed an empty string.
    address = r"((?:0[xX])?[0-9A-Fa-f]+|\(nil\))"

    # Compiled once, outside the per-line loop.
    allocation = re.compile(f"{location}Allocating memory {size}: {address}")
    reallocation = re.compile(
        f"{location}Reallocating memory at {address} {size}: {address}"
    )
    deallocation = re.compile(f"{location}Freeing memory at {address}")

    def to_address(text: str) -> int:
        # int() with base 16 accepts the optional "0x" prefix by itself.
        return 0 if text == "(nil)" else int(text, 16)

    for line in lines:
        if match := allocation.search(line):
            file, line_number, func, size_text, new = match.groups()
            yield Operation(
                file=file,
                line=int(line_number),
                func=func,
                type=OperationType.Allocation,
                size=int(size_text),
                address=to_address(new),
            )
        elif match := reallocation.search(line):
            file, line_number, func, old, size_text, new = match.groups()
            yield Operation(
                file=file,
                line=int(line_number),
                func=func,
                type=OperationType.Reallocation,
                address=to_address(old),
                size=int(size_text),
                new_address=to_address(new),
            )
        elif match := deallocation.search(line):
            file, line_number, func, old = match.groups()
            yield Operation(
                file=file,
                line=int(line_number),
                func=func,
                type=OperationType.Deallocation,
                size=None,
                address=to_address(old),
            )


def find_problems(operations: Iterable[Operation]) -> Iterable[Problem]:
    """Replay ``operations`` in order and report suspicious memory usage.

    Yields, as they are detected:

    - "Invalid free" / "Invalid realloc" when a non-null address is freed or
      reallocated although it is not currently tracked as allocated;
    - "Address reused without free" when an allocation or reallocation
      returns an address that is still tracked as allocated.

    After all operations have been replayed, yields one "Unfreed memory"
    problem per block that is still tracked, followed by a final summary
    problem (with ``operation=None``) stating the total size of those
    blocks.
    """
    # Live blocks: address -> the operation that produced the block.
    reachable_memory: dict[int, Operation] = {}

    for operation in operations:
        if operation.type == OperationType.Allocation:
            # Reported instead of asserted: the data comes from a log file,
            # and one inconsistent line should not abort the whole analysis.
            if operation.address in reachable_memory:
                yield Problem(
                    "Address reused without free", operation=operation
                )
            reachable_memory[operation.address] = operation
        elif operation.type == OperationType.Reallocation:
            # Address 0 is a null pointer: nothing to stop tracking.
            if operation.address != 0:
                try:
                    del reachable_memory[operation.address]
                except KeyError:
                    yield Problem("Invalid realloc", operation=operation)
            if operation.new_address in reachable_memory:
                yield Problem(
                    "Address reused without free", operation=operation
                )
            reachable_memory[operation.new_address] = operation
        elif operation.type == OperationType.Deallocation:
            # Freeing a null pointer is not tracked and not reported.
            if operation.address != 0:
                try:
                    del reachable_memory[operation.address]
                except KeyError:
                    yield Problem("Invalid free", operation=operation)

    for address, operation in reachable_memory.items():
        yield Problem(
            f"Unfreed memory: {address:08x} ({operation.size} bytes)",
            operation=operation,
        )

    # Only allocations and reallocations are stored, and both carry a size;
    # "or 0" merely guards against hand-built operations without one.
    reachable_bytes = sum(
        operation.size or 0 for operation in reachable_memory.values()
    )
    yield Problem(
        f"Total reachable memory: {reachable_bytes / 1024.0 / 1024.0:.02f} MB",
        operation=None,
    )


def parse_args() -> argparse.Namespace:
    """Parse the command line, which takes the log file path as ``path``."""
    cli = argparse.ArgumentParser()
    cli.add_argument("path", help="Path to the TR1X.log file.")
    namespace = cli.parse_args()
    return namespace


def main() -> None:
    """Read the log named on the command line and print the problems found.

    Problems are printed grouped by the source location (file and line) of
    the operation they are attributed to, ordered by that location. Problems
    without an operation are printed last, under "General problems:".
    """
    args = parse_args()
    with Path(args.path).open("r", encoding="utf-8") as handle:
        operations = list(get_operations_from_log_lines(handle))

    def location_key(problem: Problem) -> tuple[bool, str, int]:
        # The leading flag makes operation-less problems sort after every
        # real location. A sentinel file name such as "ZZZ" would sort
        # before any path that starts with a lower-case letter.
        operation = problem.operation
        if operation is None:
            return (True, "", 0)
        return (False, operation.file, operation.line)

    for _, values in group(source=find_problems(operations), key=location_key):
        # Every problem in a group shares the same location, so the first
        # one is enough to print the heading.
        operation = values[0].operation
        if operation:
            print(operation.file, operation.line)
        else:
            print("General problems:")
        for problem in values:
            print(" " * 4, problem.message)


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
