//
// nono
// Copyright (C) 2020 nono project
// Licensed under nono-license.txt
//

#include "mpu680x0.h"
#include "m680x0acc.h"
#include "stopwatch.h"
#include <csignal>
#include <sys/time.h>

// パフォーマンステストの秒数
#define PERF_SEC	(3)

[[noreturn]] void usage();

static int argc;					// getopt(3) 処理後の argc
static char **argv;					// getopt(3) 処理後の argv
static const char *testname;		// 現在のテスト名
static int total_errcnt;			// 総エラー数
static int errcnt;					// 現在のセクションのエラー数
static char where[256];				// エラー発生箇所
static volatile int perf_signaled;	// パフォーマンス測定用
static Stopwatch sw;				// パフォーマンス測定用
static uint64 perf_count;			// パフォーマンス測定用。処理回数

// value を16進数で指定のビット数分の文字列にする
static std::string
hex(uint32 value, int sz)
{
	char buf[32];

	switch (sz) {
	 case 8:
		snprintf(buf, sizeof(buf), "%02x", value);
		break;
	 case 16:
		snprintf(buf, sizeof(buf), "%04x", value);
		break;
	 case 32:
	 default:
		snprintf(buf, sizeof(buf), "%08x", value);
		break;
	}
	return std::string(buf);
}

// 整数(32bitまで)を検査する
static void __unused
xp_equal(uint32 expected, uint32 actual, const char *name)
{
	if (expected != actual) {
		if (errcnt == 0) {
			printf("\n");
		}
		printf(" %s: %s expects 0x%08x but 0x%08x\n",
			where, name, expected, actual);
		errcnt++;
	}
}

// 値(32bitまで)と CCR を検査する
static void
xp_equal(uint32 actValue, m680x0CCR& actCCR, uint32 expValue,
	bool expX, bool expN, bool expZ, bool expV, bool expC)
{
	if (expValue != actValue ||
		expX != actCCR.IsX() ||
		expN != actCCR.IsN() ||
		expZ != actCCR.IsZ() ||
		expV != actCCR.IsV() ||
		expC != actCCR.IsC())
	{
		if (errcnt == 0) {
			printf("\n");
		}
		printf("%s %s expects %08x %c%c%c%c%c but %08x %c%c%c%c%c\n",
			testname, where,
			expValue,
			(expX ? 'X' : '-'),
			(expN ? 'N' : '-'),
			(expZ ? 'Z' : '-'),
			(expV ? 'V' : '-'),
			(expC ? 'C' : '-'),
			actValue,
			(actCCR.IsX() ? 'X' : '-'),
			(actCCR.IsN() ? 'N' : '-'),
			(actCCR.IsZ() ? 'Z' : '-'),
			(actCCR.IsV() ? 'V' : '-'),
			(actCCR.IsC() ? 'C' : '-')
		);
		errcnt++;
	}
}

// 値(64bit)と CCR を検査する
static void __unused
xp_equal64(uint64 actValue, m680x0CCR& actCCR, uint64 expValue,
	bool expX, bool expN, bool expZ, bool expV, bool expC)
{
	if (expValue != actValue ||
		expX != actCCR.IsX() ||
		expN != actCCR.IsN() ||
		expZ != actCCR.IsZ() ||
		expV != actCCR.IsV() ||
		expC != actCCR.IsC())
	{
		if (errcnt == 0) {
			printf("\n");
		}
		printf("%s %s expects %08x_%08x %c%c%c%c%c but %08x_%08x %c%c%c%c%c\n",
			testname, where,
			(uint32)(expValue >> 32),
			(uint32)(expValue & 0xffffffff),
			(expX ? 'X' : '-'),
			(expN ? 'N' : '-'),
			(expZ ? 'Z' : '-'),
			(expV ? 'V' : '-'),
			(expC ? 'C' : '-'),
			(uint32)(actValue >> 32),
			(uint32)(actValue & 0xffffffff),
			(actCCR.IsX() ? 'X' : '-'),
			(actCCR.IsN() ? 'N' : '-'),
			(actCCR.IsZ() ? 'Z' : '-'),
			(actCCR.IsV() ? 'V' : '-'),
			(actCCR.IsC() ? 'C' : '-')
		);
		errcnt++;
	}
}

// このテストを実行するかどうか
// 引数なしなら全部実行。
// 引数ありなら一致すれば実行。
static bool
check_exec(const char *name)
{
	if (argc == 0) {
		return true;
	}

	// "test_"/"perf_" を除いた後ろが完全一致するか
	name += 5;
	for (int i = 0; i < argc; i++) {
		if (strcmp(name, argv[i]) == 0) {
			return true;
		}
	}

	// サイズ部を除いて一致するか
	// add に対して addx が一致しないようにアンダーバーまでで調べる
	std::string name2(name);
	int u = name2.find('_');
	if (u != std::string::npos) {
		name2.erase(u, name2.size() - u);
	}
	for (int i = 0; i < argc; i++) {
		if (strcmp(name2.c_str(), argv[i]) == 0) {
			return true;
		}
	}

	return false;
}

#define start_test() \
	if (!check_exec(__FUNCTION__))	\
		return;	\
	start_test_func(__FUNCTION__)

static void
start_test_func(const char *name)
{
	testname = name;
	printf("%s ", testname);
	fflush(stdout);
	errcnt = 0;
}

static void
end_test()
{
	if (errcnt == 0) {
		printf("ok\n");
	} else {
		printf("%d error(s)\n", errcnt);
	}
	total_errcnt += errcnt;
}

#define START_PERF \
	if (!check_exec(__FUNCTION__))	\
		return;	\
	start_perf_func(__FUNCTION__)

#define END_PERF	\
	end_perf_func()

static void
signal_alarm(int signo)
{
	perf_signaled = 1;
}

static void
start_perf_func(const char *name)
{
	testname = name;
	printf("%s\t... ", testname);
	fflush(stdout);

	perf_count = 0;
	perf_signaled = 0;
	signal(SIGALRM, signal_alarm);
	struct itimerval it = {};
	it.it_value.tv_sec = PERF_SEC;

	sw.Restart();
	setitimer(ITIMER_REAL, &it, NULL);
}

static void
end_perf_func()
{
	sw.Stop();

	uint64 usec = sw.Elapsed_nsec() / 1000U;
	printf("%" PRIu64 " times/usec\n", perf_count / usec);
}

// 乱数
static uint32
xor32()
{
	static uint32 y = 2463534242;
	y = y ^ (y << 13);
	y = y ^ (y >> 17);
	y = y ^ (y << 5);
	return y;
}

// sz ビット目(最下位を0とする)を立てた値を返す
// sz は 0..63
static uint64
BIT(int sz)
{
	return (1ULL << sz);
}

// 下位 sz ビットがすべて 1 の値を返す
static uint32
MASK(int sz)
{
	return (1ULL << sz) - 1;
}

// value を sz ビット符号付き整数とした時、負かどうかを返す
// sz は 8, 16, 32
static bool
ISNEG(uint64 value, int sz)
{
	return ((value & BIT(sz - 1)) != 0);
}

//
// add,sub
//

static uint32 add_table[] = {
	0x00000000,
	0x00000001,
	0x0000007f,
	0x00000080,
	0x000000ff,
	0x00007fff,
	0x00008000,
	0x0000ffff,
	0x7fffffff,
	0x80000000,
	0xffffffff,
};

static void
test_add(m680x0ACC& acc, int sz)
{
	for (int i = 0; i < countof(add_table); i++) {
		for (int j = 0; j < countof(add_table); j++) {
			uint64 src = add_table[i] & MASK(sz);
			uint64 dst = add_table[j] & MASK(sz);

			uint64 tmp = src + dst;
			uint64 res = tmp & MASK(sz);
			std::string w = hex(src, sz) + " + " + hex(dst, sz);
			strlcpy(where, w.c_str(), sizeof(where));
			bool C = tmp & BIT(sz);
			bool V = ((res ^ src) & (res ^ dst)) & BIT(sz - 1);

			uint32 actual;
			if (sz == 8) {
				actual = acc.add_8(src, dst);
			} else if (sz == 16) {
				actual = acc.add_16(src, dst);
			} else {
				actual = acc.add_32(src, dst);
			}

			xp_equal(actual, acc, res,
				C, ISNEG(res, sz), (res == 0), V, C);
		}
	}
}

static void
test_add_8(m680x0ACC& acc)
{
	start_test();
	test_add(acc, 8);
	end_test();
}

static void
test_add_16(m680x0ACC& acc)
{
	start_test();
	test_add(acc, 16);
	end_test();
}

static void
test_add_32(m680x0ACC& acc)
{
	start_test();
	test_add(acc, 32);
	end_test();
}

static void
test_sub(m680x0ACC& acc, int sz)
{
	for (int i = 0; i < countof(add_table); i++) {
		for (int j = 0; j < countof(add_table); j++) {
			uint64 src = add_table[i] & MASK(sz);
			uint64 dst = add_table[j] & MASK(sz);

			uint64 tmp;
			if (sz == 8) {
				tmp = (uint64)dst - (uint64)src;
			} else if (sz == 16) {
				tmp = (uint64)dst - (uint64)src;
			} else {
				tmp = (uint64)dst - (uint64)src;
			}
			uint64 res = tmp & MASK(sz);
			std::string w = hex(dst, sz) + " - " + hex(src, sz);
			strlcpy(where, w.c_str(), sizeof(where));
			bool C = tmp & BIT(sz);
			bool V = ((src ^ dst) & (res ^ dst)) & BIT(sz - 1);

			uint32 actual;
			if (sz == 8) {
				actual = acc.sub_8(src, dst);
			} else if (sz == 16) {
				actual = acc.sub_16(src, dst);
			} else {
				actual = acc.sub_32(src, dst);
			}

			xp_equal(actual, acc, res,
				C, ISNEG(res, sz), (res == 0), V, C);
		}
	}
}

static void
test_sub_8(m680x0ACC& acc)
{
	start_test();
	test_sub(acc, 8);
	end_test();
}

static void
test_sub_16(m680x0ACC& acc)
{
	start_test();
	test_sub(acc, 16);
	end_test();
}

static void
test_sub_32(m680x0ACC& acc)
{
	start_test();
	test_sub(acc, 32);
	end_test();
}


//
// addx
//
static struct {
	uint32 src, dst;
	bool inX;
	uint32 res;
	bool expN, expZ, expV, expC;
} addx32_table[] = {
	// src        dst         X     res         N  Z  V  C
	{ 0x00000000, 0xffffffff, 0,	0xffffffff, 1, 0, 0, 0 },
	{ 0x00000000, 0xffffffff, 1,	0x00000000, 0, 1, 0, 1 },
	{ 0x7fffffff, 0x00000001, 0,	0x80000000, 1, 0, 1, 0 },
	{ 0x7fffffff, 0x00000000, 1,	0x80000000, 1, 0, 1, 0 },
	{ 0x7fffffff, 0x80000000, 1,	0x00000000, 0, 1, 0, 1 },
};

static void
test_addx_32(m680x0ACC& acc)
{
	start_test();
	for (int i = 0; i < countof(addx32_table); i++) {
		uint32 src = addx32_table[i].src;
		uint32 dst = addx32_table[i].dst;
		bool   inX = addx32_table[i].inX;
		uint32 res = addx32_table[i].res;
		bool  expN = addx32_table[i].expN;
		bool  expZ = addx32_table[i].expZ;
		bool  expV = addx32_table[i].expV;
		bool  expC = addx32_table[i].expC;
		snprintf(where, sizeof(where), "%08x + %08x + %d", src, dst, inX);

		acc.SetX(inX);
		acc.SetZ(true);
		uint32 actual = acc.addx_32(src, dst);
		xp_equal(actual, acc, res, expC, expN, expZ, expV, expC);
	}
	end_test();
}

//
// rotate/shift
//

// ローテート/シフト系のパフォーマンス測定
#define DEFINE_PERF_ROTATE(name)	\
static void	\
__CONCAT(perf_,name)(m680x0ACC& acc)	\
{	\
	START_PERF;	\
	volatile uint32 dst = 0;	\
	for (; perf_signaled == 0;) {	\
		for (int count = 0; count < 32; count++) {	\
			uint32 src = xor32();	\
			dst ^= acc.name(src, count);	\
			perf_count++;	\
		}	\
	}	\
	(void)dst;	\
	END_PERF;	\
}
DEFINE_PERF_ROTATE(asl_32)
DEFINE_PERF_ROTATE(lsl_32)
DEFINE_PERF_ROTATE(roxl_32)
DEFINE_PERF_ROTATE(rol_32)
DEFINE_PERF_ROTATE(asr_32)
DEFINE_PERF_ROTATE(lsr_32)
DEFINE_PERF_ROTATE(roxr_32)
DEFINE_PERF_ROTATE(ror_32)

static void
perf_add_32(m680x0ACC& acc)
{
	START_PERF;
	volatile uint32 dst = 0;
	for (; perf_signaled == 0; ) {
		uint32 src = xor32();
		dst = acc.add_32(src, dst);
		perf_count++;
	}
	END_PERF;
}

int
main(int ac, char *av[])
{
	int c;
	bool do_test;
	bool do_perf;

	do_test = true;
	do_perf = true;

	while ((c = getopt(ac, av, "tp")) != -1) {
		switch (c) {
		 case 't':	// test only
			do_perf = false;
			break;
		 case 'p':	// perf only
			do_test = false;
			break;
		 default:
			usage();
		}
	}
	ac -= optind;
	av += optind;
	argc = ac;
	argv = av;

	std::unique_ptr<m680x0ACC> accptr(new m680x0ACC());
	m680x0ACC& acc = *(accptr.get());

	if (do_test) {
		test_add_8(acc);
		test_add_16(acc);
		test_add_32(acc);
		test_sub_8(acc);
		test_sub_16(acc);
		test_sub_32(acc);
		test_addx_32(acc);
	}

	if (do_perf) {
		perf_asl_32(acc);
		perf_lsl_32(acc);
		perf_roxl_32(acc);
		perf_rol_32(acc);
		perf_asr_32(acc);
		perf_lsr_32(acc);
		perf_roxr_32(acc);
		perf_ror_32(acc);
		perf_add_32(acc);
	}

	return 0;
}

void
usage()
{
	fprintf(stderr, "usage: %s [-t] [-p]\n", getprogname());
	exit(1);
}
