/*
* Ipq Berkeley db Daemon  WHITE - ibd-white
* written by ale in milano on 17 dec 2008
* edit entries of white.db

Copyright (C) 2008-2021 Alessandro Vesely

This file is part of Ipqbdb.

Ipqbdb is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Ipqbdb is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Ipqbdb.  If not, see <http://www.gnu.org/licenses/>.

*/
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <unistd.h>
#include <limits.h>
#include <ctype.h>
#include <signal.h>
#include <syslog.h>
#include <errno.h>

#include <sys/types.h>

// Berkeley DB v5.3
#include <db.h>

#include <popt.h>

// format of data packet
#include "config_names.h"
#include "dbstruct.h"

#include <time.h>

// for inet_aton
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include "ip_util.h"
#include "max_range_ip6.h"

// Error-reporting context: passed to report_error() and to the database
// helpers (database_fname, open_database); filled in by main().
static app_private ap;
// Program name: used as popt context name, interactive prompt, syslog ident
// and prefix of the messages written directly to stderr.
static char const err_prefix[] = "ibd-white";
#include "setsig_func.h"
#include <assert.h>

/*
* Command line state.  Each variable below is the storage target of one
* entry in opttab[]; popt writes into it while parsing.
*/
static char *db_white_name = IPQBDB_WHITE_DATABASE_NAME;
static double default_decay = (IPQBDB_INITIAL_DECAY/2);
// Must be an int, not a time_t: the matching table entry is POPT_ARG_INT,
// so popt stores an int through this address.  Where time_t is wider than
// int, a time_t target would only be partially written.
static int default_expiration;
static char *ip_address;
static int version_opt, help_opt, syslog_opt, cleanup_opt, force_opt;
// Starts at -1; main() adds 1 after parsing, so that "no -v" yields 0.
static int verbose = -1;
static int list_opt, raw_opt, trunc_opt, read_opt;
static struct poptOption opttab[] =
{
	{"db-white", 'w', POPT_ARG_STRING|POPT_ARGFLAG_SHOW_DEFAULT, &db_white_name, 0,
	"The whitelist database.", "filename"},
	{"ls", 'L', POPT_ARG_NONE, &list_opt, 0,
	"List existing entries (before possibly truncate)", NULL},
	{"ls-raw", '\0', POPT_ARG_NONE, &raw_opt, 0,
	"List without trying to aggregate records", NULL},
	{"truncate", '\0', POPT_ARG_NONE, &trunc_opt, 0,
	"Truncate the database before inserting new entries", NULL},
	{"read", 'R', POPT_ARG_NONE, &read_opt, 0,
	"Read stdin (not needed if redirecting)", NULL},
	{"default-decay", 't', POPT_ARG_DOUBLE|POPT_ARGFLAG_SHOW_DEFAULT, &default_decay, 0,
	"The default decay value (seconds)", "float"},
	{"expiration", 'x', POPT_ARG_INT, &default_expiration, 0,
	"The amount of time that the record will be valid (seconds)", "int[D|H|M|S]"},
	{"ip-addr", 'i', POPT_ARG_STRING, &ip_address, 0,
	"The single address or range to list", "ip[-last|/cidr]"},
	{"verbose", 'v', POPT_ARG_INT|POPT_ARGFLAG_OPTIONAL, &verbose, 0,
	"Be verbose", "level"},
	{"log-syslog", 'l', POPT_ARG_NONE, &syslog_opt, 0,
	"Log to syslog rather than std I/O", NULL},
	{"db-cleanup", '\0', POPT_ARG_NONE, &cleanup_opt, 0,
	"On exit cleanup environment (__db.00? files) if not still busy", NULL},
	{"force-insane", '\0', POPT_ARG_NONE, &force_opt, 0,
	"Allow inserting more than 65536 records per input line", NULL},
	{"version", 'V', POPT_ARG_NONE, &version_opt, 0,
	"Print version number and exit", NULL},
	{"help", 'h', POPT_ARG_NONE, &help_opt, 0,
	"This help.", NULL},
	POPT_TABLEEND
};

/*
* Fill *ip from the --ip-addr option.
*
* Without an argument the structure is zeroed (so ip->args stays 0) and the
* upper bound is set to all ones, which yields the widest possible range.
* Otherwise the option string is parsed; a parse failure is reported.
*
* return 0 on success, or the non-zero code given by parse_ip_address.
*/
static int get_ip_address(ip_range *ip)
{
	int const have_arg = ip_address != NULL && ip_address[0] != '\0';

	if (have_arg)
	{
		int const rc = parse_ip_address(ip_address, ip, NULL);
		if (rc != 0)
		{
			report_error(&ap, LOG_ERR, "invalid %s: %s (err=%d)\n",
				parse_ip_invalid_what(rc), ip_address, rc);
		}
		return rc;
	}

	// no argument: lower bound all zeroes, upper bound all ones
	memset(ip, 0, sizeof *ip);
	memset(ip->u2.ipv6, 0xff, sizeof ip->u2.ipv6); // also ok for IPv4
	return 0;
}

/*
* Date layouts, most complete first.  Entry 0 is the one used for output by
* print_line; read_line tries all of them in order when parsing input.
*/
static char const *date_format[] =
{
	// RFC 5322, day of month padded with a space
	"%a, %e %b %Y %T %z",
	// RFC 5322, day of month padded with a zero
	"%a, %d %b %Y %T %z",
	// without day of week
	"%d %b %Y %T %z",
	// without timezone
	"%d %b %Y %T",
	// without seconds
	"%d %b %Y %H:%M",
	// without time
	"%d %b %Y"
};

/*
* Print one row of the listing: address or range, decay, and, if the record
* has a non-zero expiration, its date followed by a marker when that date
* is already past.
*
* cur:          address or range to print
* decay:        decay value of the record(s)
* expiration:   expiration time, 0 for none
* have_printed: in/out flag; the column header is printed once, on the
*               first call, if verbose is set.
*/
static void
print_line(ip_range *cur, double decay, time_t expiration, int *have_printed)
{
	if (verbose && !*have_printed)
	{
		printf("%30s %12s %s\n", "IP or RANGE", "DECAY", "EXPIRE");
		*have_printed = 1;
	}

	char rangebuf[INET_RANGESTRLEN];
	char timebuf[40];
	char const *expired = "";
	if (expiration)
	{
		struct tm *tm = localtime(&expiration);
		timebuf[0] = ' ';
		if (tm == NULL ||
			strftime(timebuf+1, sizeof timebuf - 1, date_format[0], tm) == 0)
		{
			report_error(&ap, LOG_ERR, "strftime failed\n");
			// The buffer content is indeterminate after a failed strftime,
			// so overwrite it rather than append to it.
			strcpy(timebuf+1, "<error>");
		}
		else if (expiration < time(0))
			expired = " *EXPIRED*";
	}
	else
		timebuf[0] = 0;

	printf("%30s %12g%s%s\n",
		snprint_range(rangebuf, sizeof rangebuf, cur),
		decay, timebuf, expired);
}

/*
* List the records of db whose keys lie between ip->u and ip->u2 inclusive.
*
* Records are fetched in bulk through a cursor.  Unless raw_opt is set,
* consecutive records whose keys are adjacent and whose decay and expiration
* are equal are merged into a single printed range.
*
* db: the IPv4 or IPv6 database, matching ip->ip (4 or 6)
* ip: the range to list
*
* return 0 if the scan ended normally, 1 if it was stopped by a read error,
* bad record or signal; -1 if the bulk buffer cannot be set up, or the
* Berkeley DB error code if the cursor cannot be created.
*/
static int do_list(DB *db, ip_range *ip)
{
	assert(ip->ip == 4 || ip->ip == 6);

	unsigned total_found = 0;
	time_t const now = time(NULL);

	double decay = 0.0;
	time_t expiration = 0;
	unsigned key_size = ip->ip == 4? 4: 16;
	unsigned char cur_key[16];
	memcpy(cur_key, ip->u.ip_data, key_size);
	unsigned char ip_last[16];
	memcpy(ip_last, ip->u2.ip_data, key_size);

	DBT key, data;
	memset(&key, 0, sizeof key);
	memset(&data, 0, sizeof data);

	key.data = cur_key;
	// The search key must have the real key length: with IPv4 only the
	// first 4 bytes of cur_key are meaningful.
	key.size = key_size;
	key.ulen = sizeof cur_key;
	key.flags = data.flags = DB_DBT_USERMEM;
	int rtc = db->get_pagesize(db, &data.ulen);
	data.ulen *= 32; // bulk buffer: 32 pages
	if (rtc ||
		(data.data = malloc(data.ulen)) == NULL)
	{
		report_error(&ap, LOG_ERR, "memory fault\n");
		return -1;
	}

	DBC *curs;
	rtc = db->cursor(db, NULL, &curs, 0);
	if (rtc)
	{
		db->err(db, rtc, "cannot create cursor");
		free(data.data);
		return rtc;
	}

	// cur is the range being accumulated; it is pending while have_cur is set
	int have_cur = 0, have_printed = 0;
	ip_range cur;
	memset(&cur, 0, sizeof cur);
	cur.ip = ip->ip;

	void *p, *retkey, *retdata;
	size_t retklen, retdlen;
	rtc = curs->get(curs, &key, &data, DB_MULTIPLE_KEY|DB_SET_RANGE);
	while (rtc == 0)
	{
		for (DB_MULTIPLE_INIT(p, &data);;)
		{
			DB_MULTIPLE_KEY_NEXT(p, &data, retkey, retklen, retdata, retdlen);
			// p == NULL: batch exhausted (rtc stays 0);
			// key beyond the last wanted address: stop with DB_NOTFOUND
			if (p == NULL ||
				(rtc = memcmp(ip_last, retkey, key_size) < 0?
					DB_NOTFOUND: 0) != 0)
						break;

			ip_white_t *const ip_white = (ip_white_t*)retdata;
			int const short_data = retdlen < sizeof *ip_white;
			int const bad_key = retklen != key_size;
			// only look at the signature if the record is long enough
			int const bad_chk = !short_data &&
				ip_white->chk != IPQBDB_CHK_SIGNATURE;
			if (caught_signal || short_data || bad_key || bad_chk)
			{
				report_error(&ap, LOG_ERR, "reading data: %s%s%s%sstop.\n",
					caught_signal? "signal, ": "",
					short_data? "bad data length, ": "",
					bad_key? "bad key length, ": "",
					bad_chk? "bad record data, ": "");

				rtc = 1;
				break;
			}

			memcpy(cur_key, retkey, key_size);

			if (have_cur)
			{
				// does this record start right after the pending range?
				int is_next_key;

				if (key_size == 4)
				{
					uint32_t next = htonl(ntohl(cur.u2.ipv4l) + 1);
					is_next_key = memcmp(cur_key, &next, 4) == 0;
				}
				else
				{
					unsigned char next[16];
					memcpy(next, cur.u2.ipv6, key_size);
					add_one(next);
					if (ip_white->plen == 0)
						is_next_key = memcmp(cur_key, next, 16) == 0;
					else
					{
						// record with a prefix length: compare the first
						// address of its range
						unsigned char first[16];
						memcpy(first, cur_key, 16);
						first_in_range(first, ip_white->plen);
						is_next_key = memcmp(first, next, 16) == 0;
					}
				}

				if (is_next_key &&
					decay == ip_white->decay &&
					expiration == ip_white->expiration)
				{
					// extend the pending range up to this key
					memcpy(cur.u2.ip_data, cur_key, key_size);
					cur.args = 2;
				}
				else
				{
					print_line(&cur, decay, expiration, &have_printed);
					have_cur = 0;
				}
			}

			if (!have_cur)
			{
				// start a new range from this record
				decay = ip_white->decay;
				expiration = ip_white->expiration;
				memcpy(cur.u.ip_data, cur_key, key_size);
				memcpy(cur.u2.ip_data, cur_key, key_size);
				if (ip_white->plen)
				{
					assert(key_size == 16);
					first_in_range(cur.u.ipv6, ip_white->plen);
					last_in_range(cur.u2.ipv6, ip_white->plen);
					cur.args = 2;
				}
				else
					cur.args = 1;

				if (raw_opt)
					print_line(&cur, decay, expiration, &have_printed);
				else
					have_cur = 1;
			}

			total_found += 1;
		}

		// Fetch another batch only if this one was simply exhausted; a stop
		// code set above (DB_NOTFOUND or 1) must survive to end the scan.
		if (rtc == 0)
			rtc = curs->get(curs, &key, &data, DB_MULTIPLE_KEY|DB_NEXT);
	}

	assert(rtc);
	if (rtc != DB_NOTFOUND && rtc != 1)
		db->err(db, rtc, "cannot read cursor");

	// flush the range still pending, if any
	if (have_cur)
		print_line(&cur, decay, expiration, &have_printed);

	int rtc2 = curs->close(curs);
	if (rtc2)
	{
		db->err(db, rtc2, "cannot close cursor");
		if (rtc == 0 || rtc == DB_NOTFOUND)
			rtc = rtc2;
	}

	free(data.data);

	if (verbose)
	{
		char buf[64];
		time_t const end = time(NULL);
		if (end != now)
			snprintf(buf, sizeof buf, " in %d sec(s)", (int)(end - now));
		else
			buf[0] = 0;
		printf("%u record(s) found%s\n",
			total_found, buf);
	}

	return rtc != 0 && rtc != DB_NOTFOUND;
}

/*
* Read one line from stdin and parse it as
*
*     ip[-last|/cidr] [decay [expiration date]]
*
* If stdin is a terminal a prompt is printed first.  The date is tried
* against each entry of date_format in turn.
*
* line:       line number, used in error messages only
* ip:         receives the parsed address or range
* decay:      receives the decay, 0.0 if not given
* expiration: receives the expiration time, 0 if not given
*
* NOTE(review): decay and expiration are always overwritten here, so any
* default the caller stored in them beforehand is lost -- confirm whether
* that is intended.
*
* return 0=line parsed, -1=hard error or EOF, 1=bad input (also returned for
* blank and comment lines when not interactive), 2=interactive not wanted
* (blank or comment line typed at the prompt)
*/
static int
read_line(int line, ip_range *ip, double *decay, time_t *expiration)
{
	char buf[1024];
	int const prompt = isatty(fileno(stdin));

	if (prompt)
	{
		printf("%s> ", err_prefix);
		// the prompt has no newline: push it out before blocking on input
		fflush(stdout);
	}
	char *s = fgets(buf, sizeof buf, stdin);
	if (s)
	{
		size_t len = strlen(s);
		int ch, err;

		if (len + 1 >= sizeof buf && s[len-1] != '\n')
		{
			// buffer full and no newline: discard the rest of the line
			unsigned extra = 1;
			ch = fgetc(stdin);
			while (ch != '\n' && ch != EOF)
			{
				ch = fgetc(stdin);
				++extra;
			}
			report_error(&ap, LOG_ERR,
				"line %d is %u chars too long\n", line, extra);
			return 1;
		}

		while (isspace(ch = *(unsigned char*)s))
			++s;

		if (ch == 0 || ch == '#')
			return prompt? 2: 1;

		/*
		* IP address or range must be given without spaces
		*/
		char *t = s + 1;
		// stop at the terminator too: the last line may lack a newline
		while ((ch = *(unsigned char*)t) != 0 && !isspace(ch))
			++t;
		*t = 0; // terminate the token temporarily
		err = parse_ip_address(s, ip, NULL);
		if (err || ip->args <= 0)
		{
			report_error(&ap, LOG_ERR, "invalid %s: %s (err=%d) at line %d\n",
				parse_ip_invalid_what(err), s, err, line);
			return 1;
		}

		*t = ch; // restore the character after the token
		s = t;
		while (isspace(ch = *(unsigned char*)s))
			++s;

		*decay = 0.0;
		if (ch)
		{
			*decay = strtod(s, &t);

			int type = fpclassify(*decay);
			if ((type != FP_NORMAL && type != FP_ZERO) || *decay < 0.0)
			{
				report_error(&ap, LOG_ERR,
					"invalid float %s at line %d\n", s, line);
				return 1;
			}

			s = t;
			while (isspace(ch = *(unsigned char*)s))
				++s;
		}

		*expiration = 0;
		if (ch)
		{
			struct tm tm;
			for (unsigned i = 0;
				i < sizeof date_format/sizeof *date_format; i++)
			{
				memset(&tm, 0, sizeof tm);
				t = strptime(s, date_format[i], &tm);
				if (t)
				{
					*expiration = mktime(&tm);
					break;
				}
			}

			if (*expiration < 0 || t == NULL)
			{
				report_error(&ap, LOG_ERR,
					"invalid date %s at line %d\n", s, line);
				return 1;
			}
		}

		return 0;
	}

	return -1;
}

/*
* Read one input line and insert the records it describes.
*
* An IPv4 range is stored as one record per address in db.  An IPv6 range is
* split into CIDR blocks; each block is stored in db6 as one record keyed by
* the last address of the block, with the prefix length in the record.
*
* db:   IPv4 database
* db6:  IPv6 database
* line: line number, for error messages
*
* return the number of records inserted (>= 0), -1 on EOF, -2 on error
*/
static int do_line(DB *db, DB *db6, int line)
{
	double decay = default_decay;
	time_t expiration = default_expiration;
	ip_range ip;
	int rtc = read_line(line, &ip, &decay, &expiration);

	if (rtc == 2) return -1;

	if (rtc == 0 && caught_signal == 0)
	{
		DBT key, data;
		int inserted = 0;
		ip_white_t white;

		memset(&key, 0, sizeof key);
		memset(&data, 0, sizeof data);
		memset(&white, 0, sizeof white);

		// the key aliases ip.u: advancing ip.u below selects the next key
		unsigned key_size = ip.ip == 4? 4: 16;
		key.data = &ip.u.ip_data;
		key.size = key.ulen = key_size;
		data.data = &white;
		data.size = data.ulen = sizeof white;
		key.flags = data.flags = DB_DBT_USERMEM;

		white.decay = decay;
		white.expiration = expiration;
		white.chk = IPQBDB_CHK_SIGNATURE;
		if (ip.ip == 4)
		{
			uint32_t first = ntohl(ip.u.ipv4l), last = ntohl(ip.u2.ipv4l);

			if (last - first >= 0xffff && !force_opt)
				report_error(&ap, LOG_WARNING,
					"Use --force-insane to insert a class B or more\n");
			else
			{
				for (;;)
				{
					rtc = db->put(db, NULL, &key, &data, 0);
					if (rtc)
					{
						db->err(db, rtc, "db->put");
						return -2;
					}
					++inserted;
					// test before incrementing, so that a range ending at
					// 255.255.255.255 does not wrap around to 0
					if (first >= last || caught_signal)
						break;
					ip.u.ipv4l = htonl(++first);
				}
			}
		}
		else
		{
#if !defined NDEBUG
			ip_range save;
			memcpy(&save, &ip, sizeof save);
#endif //  !defined NDEBUG
			for (;;) // loop for ranges spanning multiple CIDRs
			{
				if (ip.args > 1)
				{
					int plen = max_range(&ip);
					if (plen < 0)
					{
						char buf[INET6_ADDRSTRLEN];
						char buf2[INET6_ADDRSTRLEN];
						report_error(&ap, LOG_ERR,
							"Cannot get max_range(%s, %s)\n",
							inet_ntop(AF_INET6, ip.u.ipv6, buf, sizeof buf),
							inet_ntop(AF_INET6, ip.u2.ipv6, buf2, sizeof buf2));
						return -2;
					}

					white.plen = plen;
					last_in_range(ip.u.ipv6, plen);
				}
				else
				{
					// single address: do not inherit the prefix length
					// of the block stored on the previous iteration
					white.plen = 0;
				}
				rtc = db6->put(db6, NULL, &key, &data, 0);
				if (rtc)
				{
					db6->err(db6, rtc, "db->put");
					return -2;
				}
				++inserted;

				// stop once the last address has been stored; testing before
				// add_one also avoids wrapping past the all-ones address
				if (memcmp(ip.u.ipv6, ip.u2.ipv6, 16) >= 0)
					break;

				add_one(ip.u.ipv6);
				assert(memcmp(ip.u2.ipv6, save.u2.ipv6, 16) == 0);
				ip.args = memcmp(ip.u.ipv6, ip.u2.ipv6, 16)? 2: 1;
				if (caught_signal)
					break;
			}
		}
		return inserted; // good return
	}

	return rtc < 0? -1: 0; // discard syntax errors and continue
}

/*
* Discard every record of db_white.
*
* white_fname: supplies the file name shown in the verbose message
* db_white:    the database to empty
*
* return 0 on success, otherwise the Berkeley DB error code, which has
* already been reported through the handle.
*/
int do_trunc(switchable_fname *white_fname, DB* db_white)
{
	uint32_t discarded = 0;
	int const err = db_white->truncate(db_white, NULL, &discarded, 0);

	if (err != 0)
	{
		db_white->err(db_white, err, "truncate");
		return err;
	}

	if (verbose)
	{
		printf("%u record(s) discarded from %s\n",
			discarded, white_fname->fname);
	}
	return 0;
}

/*
* Entry point: parse the options, open the whitelist databases, then in
* this order: list (--ls, --ls-raw), truncate (--truncate), and insert the
* entries read from stdin (when stdin is redirected or --read is given).
*
* return 0 on success; 1 if nothing was attempted (bad usage, --version,
* --help, invalid --ip-addr, or no database file name); 2 if opening the
* database or any of the operations failed.
*/
int main(int argc, char const *argv[])
{
	static const char optaliases[] = IPQBDB_OPTION_FILE;
	int rtc = 0, errs = 0;
	ip_range ip;

	poptContext opt = poptGetContext(err_prefix, argc, argv, opttab, 0);

	if (access(optaliases, F_OK) == 0 &&
		(rtc = poptReadConfigFile(opt, optaliases)) < 0)
	{
		fprintf(stderr, "%s: cannot read %s: %s\n",
			err_prefix, optaliases, poptStrerror(rtc));
		errs = 3;
	}

	// all table entries have val 0, so one call consumes every option
	rtc = poptGetNextOpt(opt);
	if (rtc != -1)
	{
		fprintf(stderr, "%s: %s\n",
			err_prefix, poptStrerror(rtc));
		errs = 1;
	}
	else
	{
		if (poptPeekArg(opt) != NULL)
		{
			fprintf(stderr, "%s: unexpected argument: %s\n",
				err_prefix, poptGetArg(opt));
			errs = 1;
		}

		if (version_opt)
		{
			fprintf(stdout, "%s: version " PACKAGE_VERSION "\n", err_prefix);
			errs = 2;
		}

		if (help_opt)
		{
			poptPrintHelp(opt, stdout, 0);
			fputs_database_help();
			errs = 2;
		}

		if (get_ip_address(&ip))
			errs = 3;

		// popt sets verbose to 0 if no arg is given
		verbose += 1;

		ap.mode = error_report_stderr; // 0
		ap.err_prefix = err_prefix;
		if (syslog_opt)
		{
			openlog(err_prefix, LOG_PID, LOG_DAEMON);
			ap.mode = LOG_DAEMON;
		}
	}

	if (errs == 1)
		poptPrintUsage(opt, stdout, 0);
	poptFreeContext(opt);
	rtc = 0;

	if (errs)
		rtc = 1;
	else
	{
		switchable_fname *white_fname = database_fname(db_white_name, &ap);

		if (white_fname == NULL)
			rtc = 1;

		if (rtc)
		{
			if (verbose)
				report_error(&ap, LOG_INFO, "Bailing out...\n");
		}
		else
		{
			DB_ENV *db_env = NULL;
			DB *db_white = NULL, *db6_white = NULL;

			setsigs();

			rtc = open_database(white_fname, &ap, &db_env, &db_white, &db6_white);

			if (rtc == 0 && (list_opt || raw_opt) && caught_signal == 0)
			{
				if (verbose)
				{
					if (ip.args > 0)
					{
						char range[INET_RANGESTRLEN];
						printf("list entries %s %s\n",
							ip.args > 1? "in IP range": "for IP",
							snprint_range(range, sizeof range, &ip));
					}
					else
						printf("list all records\n");
				}
				if (ip.args > 0)
					rtc = do_list(ip.ip == 6? db6_white: db_white, &ip);
				else
				{
					// no address given: list both databases in full
					ip.ip = 4;
					rtc = do_list(db_white, &ip);
					if (rtc == 0)
					{
						ip.ip = 6;
						rtc = do_list(db6_white, &ip);
					}
				}
			}

			if (rtc == 0 && trunc_opt && caught_signal == 0)
			{
				make_four(white_fname);
				rtc = do_trunc(white_fname, db_white);
				if (rtc == 0)
				{
					make_six(white_fname);
					rtc = do_trunc(white_fname, db6_white);
				}
			}

			if (rtc == 0 && caught_signal == 0 && // used to be unconditional in v.1
				(!isatty(fileno(stdin)) || read_opt))
			{
				unsigned count = 0;
				unsigned line = 0;
				while (!feof(stdin) && !ferror(stdin) && caught_signal == 0)
				{
					rtc = do_line(db_white, db6_white, ++line);
					if (rtc < 0)
						break;
					count += rtc;
				}

				if (ferror(stdin))
					report_error(&ap, LOG_ERR, "Error reading stdin: %s\n",
						strerror(errno));
				else if (rtc == -1)
					rtc = 0; // plain end of input is not an error
				if (verbose)
					printf("%u record(s) total\n", count);
			}

			if (rtc)
				rtc = 2;

			// close both database handles before the environment
			close_db(db_white);
			if (db6_white)
				close_db(db6_white);
			close_dbenv(db_env, cleanup_opt);
			if (caught_signal)
				report_error(&ap, LOG_NOTICE,
					"exiting after signal %s\n", strsignal(caught_signal));
		}

		free(white_fname);
	}

	return rtc;
}
