#! /usr/bin/env python

from e3.env import Env
from e3.diff import diff
from e3.fs import find
from e3.testsuite import logger, Testsuite
from e3.testsuite.control import AdaCoreLegacyTestControlCreator, TestControl
from e3.testsuite.driver import TestDriver
from e3.testsuite.driver.adacore import AdaCoreLegacyTestDriver
from e3.testsuite.driver.diff import DiffTestDriver, OutputRefiner, RefiningChain
from e3.testsuite.result import binary_repr, Log, truncated
from e3.testsuite.testcase_finder import ParsedTest, ProbingError, TestFinder
import e3.yaml
import glob
import multiprocessing
import os
import os.path
import re
import subprocess
import sys
import tempfile
import time
from typing import AnyStr, List, Optional

# Key of the test_env dict under which SPARKTestFinder.probe stores the name
# of the subprogram to call ("do_flow", "no_crash" or "prove_all").
# ProofTestDriver reads it back when it generates a test.py.
call_entry = "call_name"
# Key of the test_env dict under which SPARKTestFinder.probe stores the
# keyword arguments for that prove_all/do_flow/no_crash call.
args_entry = "call_args"


class SPARKControlCreator(AdaCoreLegacyTestControlCreator):
    """Decide the status (XFAIL, SKIP, etc.) of a test.

    This is the standard "test.opt" mechanism with one addition: a test whose
    environment (as loaded from its yaml file) carries a true "large"
    attribute is skipped unless "large" is one of the system tags.
    """

    def __init__(
        self, system_tags: List[str], env=None, opt_filename: str = "test.opt"
    ) -> None:
        """Remember the test environment on top of the parent's settings.

        :param system_tags: Forwarded unchanged to the parent class.
        :param env: Mapping inspected for the "large" attribute. An empty
            dict is used when None is given.
        :param opt_filename: Forwarded unchanged to the parent class.
        """
        super().__init__(system_tags, opt_filename)
        self.condition_env = env if env is not None else {}

    def create(self, driver: TestDriver) -> TestControl:
        """Return the test control for ``driver``.

        Large tests are skipped when "large" is not among the system tags;
        every other case is delegated to the parent class.
        """
        env = self.condition_env
        marked_large = "large" in env and env["large"]
        if marked_large and "large" not in self.system_tags:
            return TestControl("large test", skip=True, xfail=False)
        return super().create(driver)


class ProofTestDriver(AdaCoreLegacyTestDriver):
    """Driver for all SPARK tests.

    When the test has no test.py in its working directory, one is generated.
    The generated script imports everything from test_support and calls the
    subprogram named by self.test_env[call_entry] with the keyword arguments
    found in self.test_env[args_entry].
    """

    copy_test_directory = True

    @property
    def test_control_creator(self):
        """Return a control creator that also knows about "large" tests."""
        assert isinstance(self.env.discs, list)
        return SPARKControlCreator(self.env.discs, self.test_env)

    def set_up(self) -> None:
        """Run the parent set-up, then apply the "timeout" entry, if any."""
        super().set_up()
        try:
            requested = self.test_env["timeout"]
        except KeyError:
            return
        self.timeout = int(requested)

    def get_script_command_line(self):
        """Return the command line running the (possibly generated) test.py.

        The script is only generated when src/test.py is missing, in which
        case the test environment must provide both the call name and its
        arguments.
        """
        script_path = self.working_dir("src", "test.py")
        if not os.path.isfile(script_path):
            assert call_entry in self.test_env
            assert args_entry in self.test_env
            call_args = self.test_env[args_entry]
            with open(script_path, "w") as script:
                script.write("from test_support import *\n")
                # The arguments are embedded through their repr-like str()
                # form and unpacked as keyword arguments by the script.
                script.write("args = " + str(call_args) + "\n")
                script.write(self.test_env[call_entry] + "(**args)\n")
        # Point the legacy driver at test.py, whether generated or shipped
        self.test_control.opt_results["CMD"] = script_path
        return super().get_script_command_line()


class NoEmptyOutputRefiner(OutputRefiner):
    """Refiner that replaces an empty output with a "(no output)" marker."""

    def refine(self, output):
        """Return ``output`` unchanged, or the marker when it is empty."""
        return output if len(output) > 0 else "(no output)"


class SortLinesRefiner(OutputRefiner):
    """Refiner that sorts the lines of the output."""

    def refine(self, output):
        """Return the lines of ``output`` in sorted order, joined by "\\n".

        Line terminators are dropped by the split, so the result has no
        trailing newline.
        """
        return "\n".join(sorted(output.splitlines()))


class RACTestDriver(DiffTestDriver):
    """Driver comparing the output of gnatprove against a compiled run.

    During set-up, main.adb is built with gnatmake and executed; the output
    of that execution becomes the baseline. The test itself then runs
    gnatprove with --debug-exec-rac, and its output is diffed against that
    baseline.
    """

    # Output of the compiled program, filled in by set_up
    _baseline = None

    @property
    def cwd(self):
        """Working directory of the test."""
        return self.test_env["working_dir"]

    @property
    def gprfile(self):
        """Path of the project file generated in the working directory."""
        return os.path.join(self.cwd, "test.gpr")

    @property
    def test_control_creator(self):
        """Return a control creator that also knows about "large" tests."""
        assert isinstance(self.env.discs, list)
        return SPARKControlCreator(self.env.discs, self.test_env)

    @property
    def baseline(self):
        """Return the baseline triple expected by DiffTestDriver.

        There is no baseline file, the text is the output captured in
        set_up, and the last element is always False.
        """
        return (None, self._baseline, False)

    def set_up(self):
        """Generate the project, build and run main to get the baseline."""
        super().set_up()

        project_text = """
project Test is
  for Main use ("main.adb");
end Test;
"""
        with open(self.gprfile, "w") as project:
            project.write(project_text)

        # Build the executable; the compiler output is not part of the test
        self.shell(
            ["gnatmake", "-P", self.gprfile, "-gnata", "-gnatw.A", "-gnat2022"],
            analyze_output=False,
        )

        # Run it; a failing exit status is tolerated (catch_error=False)
        executable = os.path.join(self.cwd, "main")
        self.shell([executable], analyze_output=True, catch_error=False)

        # What was captured so far is the baseline; start the actual test
        # with an empty output.
        self._baseline = str(self.output)
        self.output = Log("")

    def run(self):
        """Run gnatprove on the generated project."""
        gnatprove_switches = [
            "--quiet",
            "--debug-exec-rac",
            "--no-inlining",
            "--warnings=off",
        ]
        compiler_switches = ["-cargs", "-gnat2022"]
        self.shell(
            ["gnatprove", "-P", self.gprfile]
            + gnatprove_switches
            + compiler_switches,
            analyze_output=True,
            catch_error=False,
        )


class UGTestDriver(DiffTestDriver):
    """Simple test driver (in principle) that just runs a fixed gnatprove
    command on the UG tests. The output is sorted before comparison as we
    have observed diffs in the order of file processing (for the tests that
    contain multiple units). When updating baselines, the new actual output
    (without sorting) needs to be stored. Because of this requirement, we
    had to copy the code of "compute_diff" here and modify it to store
    non-refined baselines.
    """

    copy_test_directory = True

    # Refiners handed to RefiningChain in compute_diff: one substitutes a
    # marker for an empty output, the other sorts the lines.
    output_refiners = [NoEmptyOutputRefiner(), SortLinesRefiner()]

    @property
    def test_control_creator(self):
        """Return a control creator that also knows about "large" tests."""
        assert isinstance(self.env.discs, list)
        return SPARKControlCreator(self.env.discs, self.test_env)

    @property
    def refine_baseline(self) -> bool:
        """Whether to apply output refiners to the output baseline."""
        return True

    # Copied from DiffTestDriver.compute_diff and adapted so that a baseline
    # rewrite stores the raw (unrefined) actual output.
    def compute_diff(
        self,
        baseline_file: Optional[str],
        baseline: AnyStr,
        actual: AnyStr,
        failure_message: str = "unexpected output",
        ignore_white_chars: Optional[bool] = None,
        truncate_logs_threshold: Optional[int] = None,
    ) -> List[str]:
        """Compute the diff between expected and actual outputs.

        Return an empty list if there is no diff, and return a list that
        contains an error message based on ``failure_message`` otherwise.

        :param baseline_file: Absolute filename for the text file that contains
            the expected content (for baseline rewriting, if enabled), or None.
        :param baseline: Expected content. It goes through the output
            refiners before the comparison when ``self.refine_baseline`` is
            true.
        :param actual: Actual content to compare.
        :param failure_message: Failure message to return if there is a
            difference.
        :param ignore_white_chars: Whether to ignore whitespaces during the
            diff computation. If left to None, use
            ``self.diff_ignore_white_chars``.
        :param truncate_logs_threshold: Threshold to truncate the diff message
            in ``self.result.log``. See ``e3.testsuite.result.truncated``'s
            ``line_count`` argument. If left to None, use the testsuite's
            ``--truncate-logs`` option.
        """
        if ignore_white_chars is None:
            ignore_white_chars = self.diff_ignore_white_chars

        if truncate_logs_threshold is None:
            truncate_logs_threshold = self.testsuite_options.truncate_logs

        # Run output refiners on the actual output, and also on the baseline
        # when self.refine_baseline is true. The unrefined actual output is
        # kept around: it is what gets logged and written back as baseline.
        refiners = (
            RefiningChain[str](self.output_refiners)
            if isinstance(actual, str)
            else RefiningChain[bytes](self.output_refiners)
        )
        refined_actual = refiners.refine(actual)
        refined_baseline = (
            refiners.refine(baseline) if self.refine_baseline else baseline
        )

        # When running in binary mode, make sure the diff runs on text strings
        if self.default_encoding == "binary":
            assert isinstance(refined_actual, bytes)
            assert isinstance(refined_baseline, bytes)
            decoded_ref_actual = binary_repr(refined_actual)
            decoded_baseline = binary_repr(refined_baseline)
            decoded_actual = binary_repr(actual)
        else:
            assert isinstance(refined_actual, str)
            assert isinstance(refined_baseline, str)
            decoded_ref_actual = refined_actual
            decoded_baseline = refined_baseline
            decoded_actual = actual

        # Get the two texts to compare as list of lines, with trailing
        # characters preserved (splitlines(keepends=True)).
        expected_lines = decoded_baseline.splitlines(True)
        actual_lines = decoded_ref_actual.splitlines(True)

        # Compute the diff. If it is empty, return no failure. Otherwise,
        # include the diff in the test log and return the given failure
        # message.
        d = diff(expected_lines, actual_lines, ignore_white_chars=ignore_white_chars)
        if not d:
            return []

        self.failing_diff_count += 1
        message = failure_message

        diff_lines = []
        for line in d.splitlines():
            # Add colors diff lines
            if line.startswith("-"):
                color = self.Fore.RED
            elif line.startswith("+"):
                color = self.Fore.GREEN
            elif line.startswith("@"):
                color = self.Fore.CYAN
            else:
                color = ""
            diff_lines.append(color + line + self.Style.RESET_ALL)

        # If requested and the failure is not expected, rewrite the test
        # baseline with the new one. Note that the raw "actual" output is
        # written, not its refined (sorted) form.
        if (
            baseline_file is not None
            and not self.test_control.xfail
            and getattr(self.env, "rewrite_baselines", False)
        ):
            if isinstance(actual, str):
                with open(baseline_file, "w", encoding=self.default_encoding) as f:
                    f.write(actual)
            else:
                assert isinstance(actual, bytes)
                with open(baseline_file, "wb") as f:
                    f.write(actual)
            message = "{} (baseline updated)".format(message)

        # Send the appropriate logging. Make sure ".log" has all the
        # information. If there are multiple diff failures for this testcase,
        # do not emit the "expected/out" logs, as they support only one diff.
        diff_log = (
            self.Style.RESET_ALL
            + self.Style.BRIGHT
            + "Diff failure: {}\n".format(message)
            + "\n".join(diff_lines)
            + "\n"
        )
        self.result.log += "\n" + truncated(diff_log, truncate_logs_threshold)
        if self.failing_diff_count == 1:
            # "expected" holds the refined baseline while "out" holds the
            # unrefined actual output.
            self.result.expected = Log(decoded_baseline)
            self.result.out = Log(decoded_actual)
            self.result.diff = Log(diff_log)
        else:
            self.result.expected = None
            self.result.out = None
            assert isinstance(self.result.diff, Log) and isinstance(
                self.result.diff.log, str
            )
            self.result.diff += "\n" + diff_log

        return [message]

    def run(self):
        """Run gnatprove on the single project file of the test.

        The elapsed wall-clock time of the gnatprove run is stored in
        ``self.result.time``.
        """
        env = os.environ.copy()
        # The test must come with exactly one project file. The lookup is
        # done before sparklib.gpr is generated below, so that generated
        # project is not counted.
        pattern = os.path.join(self.test_env["working_dir"], "*.gpr")
        gprfiles = glob.glob(pattern)
        assert len(gprfiles) == 1
        cmd = ["gnatprove", "-P", gprfiles[0], "-f", "-q", "--level=2", "--timeout=30"]

        # Generate sparklib.gpr in case the project depends on SPARKlib
        sparklib_proj = os.path.join(self.test_env["working_dir"], "sparklib.gpr")
        with open(sparklib_proj, "w") as f_prj:
            f_prj.write('project SPARKlib extends "sparklib_external" is\n')
            f_prj.write('   for Object_Dir use "sparklib_obj";\n')
            f_prj.write("   for Source_Dirs use SPARKlib_External'Source_Dirs;\n")
            f_prj.write(
                "   for Excluded_Source_Files use "
                + "SPARKlib_External'Excluded_Source_Files;\n"
            )
            f_prj.write("end SPARKlib;\n")

        # Tests within SPARK libraries need this switch
        cmd += ["--no-subprojects"]

        # tests prefixed by ug__long__ require more steps to prove
        if self.test_env["test_name"].startswith("ug__long__"):
            cmd += ["--steps=19000"]
        else:
            cmd += ["--steps=200"]

        start_time = time.time()
        self.shell(cmd, env=env, catch_error=False)
        self.result.time = time.time() - start_time


class SPARKTestFinder(TestFinder):
    """Class to find the tests. The [probe] method determines if a folder is a
    valid test. To do that it takes into account various command line arguments
    that allow users to select tests. It also reads the test.yaml file of the
    test if present, and loads it into the test environment. Finally, it sets
    some settings, such as the "call_entry" and "args_entry" fields of the test
    environment."""

    def __init__(self, root_dir, testlist=None, pattern="", only_large=False):
        """
        Initialize a SPARKTestFinder instance.

        :param root_dir: Root directory of the testsuite. Tests are looked up
            in its "tests", "internal" and "sparklib" subdirectories.
        :param testlist: Names of the tests to run. When empty or None, tests
            are not filtered by name.
        :param pattern: Python regex. When non-empty, only tests that have an
            .adb/.ads file matching it are selected.
        :param only_large: Whether to select only tests that have a "large"
            entry in their test.yaml.
        """
        self.testlist = [] if testlist is None else testlist
        self.root_dir = root_dir
        self.only_large = only_large
        if pattern:
            # Files are scanned in binary mode, hence the bytes pattern
            self.pattern = re.compile(pattern.encode("utf-8"))
        else:
            self.pattern = None

    def file_contains_regex(self, pattern, fn):
        """
        Return True iff the file [fn] contains [pattern], which is a compiled
        (bytes) regex. The file is searched line by line.
        """
        with open(fn, "rb") as f:
            return any(pattern.search(line) for line in f)

    def test_contains_pattern(self, test, pattern):
        """
        Return True iff the test contains an .adb/s file that contains
        [pattern], which is a compiled regex.
        """
        return any(
            self.file_contains_regex(pattern, fn) for fn in find(test, "*.ad[bs]")
        )

    @staticmethod
    def _call_args(test_env, key):
        """
        Return the call arguments stored under [key] in [test_env]. A missing
        entry and an entry without value (None, as produced by an empty YAML
        value) both mean "no arguments" and yield an empty dict.
        """
        args = test_env.get(key)
        return {} if args is None else args

    def probe(self, testsuite, dirpath, dirnames, filenames):
        """
        See documentation of e3.testsuite for the arguments of this method. We
        check if we consider [dirpath] a valid test of the testsuite, and
        return a ParsedTest object in that case, and [None] otherwise.

        Raise ProbingError if the test.yaml file of the test cannot be loaded.
        """
        parent_name = os.path.dirname(dirpath)
        testname = os.path.basename(dirpath)
        allowed_testdirs = ["tests", "internal", "sparklib"]
        # If the dir is not a direct subdir of one of the allowed test
        # directories of <root_dir>, skip it
        if (
            os.path.basename(parent_name) not in allowed_testdirs
            or os.path.dirname(parent_name) != self.root_dir
        ):
            return None
        # If bogus test dir such as git folder, skip it
        if testname == ".git":
            return None

        # We have a valid test directory, so we don't need to visit its
        # subdirectories, e.g. with session files. This should save us
        # unnecessary I/O operations. Modify the list of directories in-place,
        # as explained in the Python documentation for os.walk routine, which
        # is used by e3-testsuite driver.
        del dirnames[:]

        # If testlist was passed and the dir is not in testlist, skip it
        if self.testlist and testname not in self.testlist:
            return None
        # If pattern was passed and dir doesn't contain files with the pattern,
        # skip it.
        if self.pattern and not self.test_contains_pattern(dirpath, self.pattern):
            return None

        # We read the test.yaml file if present and load it into the
        # environment
        yaml_file = os.path.join(dirpath, "test.yaml")
        if "test.yaml" in filenames:
            try:
                test_env = e3.yaml.load_with_config(yaml_file, Env().to_dict())
            except e3.yaml.YamlError as exc:
                raise ProbingError(
                    "invalid syntax for test.yaml in '{}'".format(testname)
                ) from exc
        else:
            test_env = {}

        # If this is not a large test, but only large was requested, skip
        # test. The "always_fail" test is exempted from this check to include
        # it in both large and non-large testsuites.
        if (
            self.only_large
            and "large" not in test_env
            and "always_fail" not in testname
        ):
            return None

        # NOTE(review): the UG and RAC drivers receive an empty test
        # environment, so the content of test.yaml (e.g. "large") is not
        # forwarded to them - confirm this is intended.
        if testname.startswith("ug__"):
            return ParsedTest(testname, UGTestDriver, {}, dirpath)

        if "__rac" in testname:
            return ParsedTest(testname, RACTestDriver, {}, dirpath)

        # If the test contains a test.py, we use the AdaCoreLegacyTestDriver,
        # otherwise we use the proof/flow driver
        if "test.py" in filenames or "test.cmd" in filenames:
            if "prove_all" in test_env or "do_flow" in test_env:
                logger.warning(
                    '{}: "test.py" file is present, so "prove_all" '
                    'and "do_flow" ignored in "test.yaml"'.format(testname)
                )
        elif "__flow" in testname:
            # if it's a flow test, we set the call to do_flow, and if no
            # do_flow entry is present, we define empty args.
            # We emit a warning if "prove_all" is present
            test_env[call_entry] = "do_flow"
            if "prove_all" in test_env:
                logger.warning(
                    '{}: "prove_all" in "test.yaml" ignored for flow test'.format(
                        testname
                    )
                )
            test_env[args_entry] = self._call_args(test_env, "do_flow")
        else:
            # same for proof, either with "no_crash" or "prove_all" entry
            # point; "no_crash" takes precedence when both are present
            entry_point = "no_crash" if "no_crash" in test_env else "prove_all"
            test_env[call_entry] = entry_point
            if "do_flow" in test_env:
                logger.warning(
                    '{}: "do_flow" in "test.yaml" ignored for proof test'.format(
                        testname
                    )
                )
            test_env[args_entry] = self._call_args(test_env, entry_point)

        # copy replay info into test_env[args_entry] if not already present.
        # Tests with their own test.py/test.cmd have no call arguments at
        # all: there is nothing to update for them, and looking the entry up
        # unconditionally would raise KeyError.
        call_args = test_env.get(args_entry)
        if (
            "replay" in test_env
            and call_args is not None
            and "replay" not in call_args
        ):
            call_args["replay"] = test_env["replay"]
        # This is a valid test, return the ParsedTest object
        return ParsedTest(testname, ProofTestDriver, test_env, dirpath)


class SPARKTestsuite(Testsuite):
    """Testsuite for SPARK."""

    @property
    def test_finders(self):
        """Return the single test finder used to discover SPARK tests."""
        return [
            SPARKTestFinder(
                self.root_dir,
                testlist=self.testlist,
                pattern=self.main.args.pattern,
                only_large=self.env.only_large,
            )
        ]

    def add_options(self, parser):
        """Register the command line options specific to this testsuite."""
        parser.add_argument(
            "--cache", action="store_true", help="Use memcached to run testsuite faster"
        )
        parser.add_argument(
            "--diffs", action="store_true", help="Synonym of --show-error-output/-E"
        )
        parser.add_argument("--disc", type=str, help="discriminants to be used")
        parser.add_argument(
            "--benchmark",
            type=str,
            help="run testsuite in benchmark mode for given prover",
        )
        parser.add_argument(
            "--no-gaia-output",
            action="store_true",
            help="The opposite of --gaia-output (which is on by default).",
        )
        parser.add_argument(
            "--pattern",
            type=str,
            help="Argument is a python regex. Run only tests whose \
                  .adb/ads files match the regex.",
        )
        parser.add_argument(
            "--rewrite",
            action="store_true",
            help="Rewrite test baselines according to current outputs",
        )
        parser.add_argument(
            "--testlist",
            type=str,
            help="Argument is a file which contains one test \
                  name per line. Run only those tests.",
        )
        parser.add_argument(
            "--only-large",
            action="store_true",
            help="Run only large tests; implies --disc=large",
        )
        parser.add_argument(
            "--share-why3server",
            action="store_true",
            help="Use a shared why3server for all tests",
        )

    def run_why3server(self):
        """Start a why3server in the temporary directory.

        The process handle is stored in ``self.why3_process`` so that
        tear_down can stop the server.

        :return: Path of the socket the server was asked to use.
        """
        tmpdir = tempfile.gettempdir()
        socketname = os.path.join(tmpdir, "runtests.sock")
        cmd = [
            "why3server",
            "-j",
            str(multiprocessing.cpu_count()),
            "--socket",
            socketname,
        ]
        # Let Popen start the child in tmpdir instead of temporarily changing
        # the working directory of the whole testsuite process.
        self.why3_process = subprocess.Popen(cmd, cwd=tmpdir)
        return socketname

    def compute_environ(self):
        """Update os.environ for the tests and return a copy of it.

        The "lib/python" directory of the testsuite is appended to
        PYTHONPATH, and the "cache", "benchmark" and "why3server" variables
        are set according to the corresponding command line options. Asking
        for a shared why3server starts that server as a side effect.
        """
        python_lib = os.path.join(self.root_dir, "lib", "python")
        pypath = "PYTHONPATH"
        if pypath in os.environ:
            os.environ[pypath] += os.path.pathsep + python_lib
        else:
            os.environ[pypath] = python_lib
        args = self.main.args
        if args.cache:
            os.environ["cache"] = "true"
        if args.benchmark:
            os.environ["benchmark"] = args.benchmark
        if args.share_why3server:
            os.environ["why3server"] = self.run_why3server()
        return dict(os.environ)

    def set_up(self):
        """Process the command line options before tests are looked up."""
        super().set_up()
        args = self.main.args
        if args.diffs:
            args.show_error_output = True
        if not args.no_gaia_output:
            args.gaia_output = True
        # Work on a copy, so that the additions below cannot alter the list
        # the environment handed out.
        self.env.discs = list(self.env.discriminants)
        if args.disc:
            self.env.discs += args.disc.split(",")
        if args.only_large:
            self.env.only_large = True
            self.env.discs += ["large"]
        else:
            self.env.only_large = False
        self.env.rewrite_baselines = args.rewrite
        # Must be set before compute_environ, which may start the server
        self.why3_process = None
        self.env.test_environ = self.compute_environ()
        if args.testlist:
            with open(args.testlist, "r") as f:
                # Skip blank lines: an empty name matches no test, and a
                # file made only of blank lines would otherwise filter out
                # every test.
                stripped = (line.strip() for line in f)
                self.testlist = [name for name in stripped if name]
        else:
            self.testlist = []

    def tear_down(self):
        """Stop the shared why3server, if any, then run the parent hook."""
        server = getattr(self, "why3_process", None)
        if server is not None:
            server.kill()
            # Reap the child so that it does not linger as a zombie
            server.wait()
            self.why3_process = None
        super().tear_down()


if __name__ == "__main__":
    # Run the testsuite and use what it returns as the process exit status
    exit_status = SPARKTestsuite().testsuite_main()
    sys.exit(exit_status)
