/* $NetBSD$ */

/*-
 * Copyright (C) 2006,2008 Jared D. McNeill <jmcneill@invisible.ca>.
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. All advertising materials mentioning features or use of this software
 *    must display the following acknowledgement:
 *	This product includes software developed by Jared D. McNeill.
 * 4. Neither the name of the author nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include <sys/cdefs.h>

#include "porting.h"

#ifdef __COPYRIGHT
__COPYRIGHT("@(#) Copyright (C) 2006,2008\n\
	Jared D. McNeill <jmcneill@invisible.ca>. All rights reserved.\n");
#endif
__RCSID("$NetBSD$");

#include <sys/types.h>
#include <sys/param.h>
#include <sys/stat.h>

#include <arpa/inet.h>

#include <pthread.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <syslog.h>
#include <stdio.h>
#include <errno.h>

#include "pathnames.h"
#include "conf.h"
#include "transmit.h"

#ifdef __linux__
/*
 * Minimal stand-ins for the BSD getprogname(3)/setprogname(3) interfaces,
 * compiled only when __linux__ is defined.
 *
 * The name is kept in a file-local variable.  It deliberately does not use
 * an identifier starting with two underscores: those are reserved for the
 * implementation.
 */
static char *mtftpd_progname = NULL;

/*
 * Return the pointer most recently stored by setprogname(), or NULL if
 * setprogname() has not been called yet.
 */
static char *
getprogname(void)
{
	return mtftpd_progname;
}

/*
 * Remember `p' as the program name.  The pointer itself is stored (no copy
 * is made), so the string must outlive every later getprogname() call.
 */
static void
setprogname(char *p)
{
	mtftpd_progname = p;
}
#endif

/*
 * TFTP wire-format helpers.
 *
 * TFTP_OPCODE(b) combines the first two bytes of packet `b' into a 16-bit
 * value, most significant byte first.  `b' is evaluated twice, so it must
 * not have side effects.
 */
#define	TFTP_OPCODE(b)		((b)[0] << 8 | (b)[1])
/* Opcode values compared against / written into the first two bytes. */
#define	TFTP_OPCODE_RRQ		0x0001
#define	TFTP_OPCODE_ERR		0x0005
#define	TFTP_OPCODE_OACK	0x0006
/* Error code carried in an error packet (passed to mtftpd_send_error()). */
#define		TFTP_ERR_FILENOTFOUND		0x0001

/* Print the command-line synopsis to stderr and exit. */
static void	usage(void);

/* Create and bind the UDP socket, then receive and dispatch datagrams. */
static int	mtftpd_listen(mtftpd_conf_t *);
/* Dispatch one received datagram according to its opcode. */
static void	mtftpd_process(mtftpd_conf_t *, uint8_t *, ssize_t,
			       struct sockaddr_in *);
/* Handle a read request: look up the file, reply, start transmission. */
static void	mtftpd_process_rrq(mtftpd_conf_t *, uint8_t *, ssize_t,
				   struct sockaddr_in *);
/* Send an error packet to the given peer over the listening socket. */
static void	mtftpd_send_error(uint16_t, const char *, struct sockaddr_in *);
/* Send an option acknowledgement to the given peer. */
static void	mtftpd_send_oack(int, in_addr_t, uint16_t, uint16_t, uint32_t,
				 struct sockaddr_in *);

/* UDP socket created in mtftpd_listen(); also used by mtftpd_send_error(). */
static int	sock;
/* Non-zero: main() forks and the parent exits.  Toggled by the -d option. */
static int	dofork = 1;

/*
 * Program entry point.
 *
 * Options:
 *   -d         toggle forking into the background (forking is on by default)
 *   -f config  load the configuration from `config' instead of MTFTPD_PLIST
 *   -h         print usage and exit (also done for any unknown option)
 *
 * Returns 1 if the configuration cannot be loaded, 2 if fork() fails, 0 in
 * the parent after a successful fork, and otherwise whatever
 * mtftpd_listen() returns.
 */
int
main(int argc, char *argv[])
{
	const char *cfgpath = MTFTPD_PLIST;
	mtftpd_conf_t *conf;
	pid_t child;
	int opt;

	setprogname(argv[0]);

	for (;;) {
		opt = getopt(argc, argv, "df:h");
		if (opt == -1)
			break;

		if (opt == 'd')
			dofork = !dofork;
		else if (opt == 'f')
			cfgpath = optarg;
		else
			usage();	/* -h or unknown option; does not return */
	}
	argc -= optind;
	argv += optind;

	conf_init();
	conf = conf_reload(cfgpath);
	if (conf == NULL) {
		fprintf(stderr, "%s: Unable to load configuration from %s\n",
		    getprogname(), cfgpath);
		return 1;
	}

	if (dofork) {
		child = fork();
		if (child > 0)
			return 0;	/* parent: the child carries on */
		if (child == -1) {
			perror("fork");
			return 2;
		}
	}

	return mtftpd_listen(conf);
}

/*
 * Create the UDP listening socket (stored in the file-scope `sock'), bind
 * it to INADDR_ANY on mc->mc_port, and hand every received datagram to
 * mtftpd_process().
 *
 * Returns the errno value of the failing call if socket() or bind() fails,
 * and -1 if recvfrom() fails with anything other than EINTR.  The function
 * does not return while datagrams keep being received successfully.
 */
static int
mtftpd_listen(mtftpd_conf_t *mc)
{
	uint8_t buf[1500];
	struct sockaddr_in sin;
	socklen_t len;
	ssize_t sz;
	int error;

	sock = socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP);
	if (sock < 0) {
		/* save errno first: syslog()/strerror() may overwrite it */
		error = errno;
		syslog(LOG_ERR, "socket(...) failed: %s", strerror(error));
		return error;
	}

	memset(&sin, 0, sizeof(sin));
	sin.sin_family = AF_INET;
	sin.sin_port = htons(mc->mc_port);
	sin.sin_addr.s_addr = INADDR_ANY;
	if (bind(sock, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
		error = errno;
		syslog(LOG_ERR, "bind(...) failed: %s", strerror(error));
		/* do not leak the descriptor on the failure path */
		close(sock);
		sock = -1;
		return error;
	}

	for (;;) {
		len = sizeof(sin);
		memset(&sin, 0, sizeof(sin));
		sz = recvfrom(sock, buf, sizeof(buf), 0,
		    (struct sockaddr *)&sin, &len);
		if (sz == -1) {
			/* an interrupted call is not an error: just retry */
			if (errno == EINTR)
				continue;
			syslog(LOG_ERR, "recvfrom(...) failed: %s",
			    strerror(errno));
			break;
		}

		mtftpd_process(mc, buf, sz, &sin);
	}

	return -1;
}

/*
 * Dispatch a single received datagram.
 *
 * mc      active configuration, passed through to the request handler
 * buf     datagram payload
 * buflen  number of valid bytes in buf
 * sin     address of the sender
 *
 * Datagrams shorter than four bytes are dropped, as is anything whose
 * opcode is not a read request.
 */
static void
mtftpd_process(mtftpd_conf_t *mc, uint8_t *buf, ssize_t buflen,
    struct sockaddr_in *sin)
{
	/* too short to be a request we handle */
	if (buflen < 4)
		return;

	/* only read requests are supported for now */
	if (TFTP_OPCODE(buf) != TFTP_OPCODE_RRQ)
		return;

	mtftpd_process_rrq(mc, buf, buflen, sin);
}

/*
 * Handle a read request.
 *
 * The packet is expected to hold, after the two opcode bytes, a
 * NUL-terminated file name followed by a NUL-terminated mode string; any
 * options after the mode are ignored.  The file name is looked up in the
 * configured file list (mc->mc_mf).  If there is no match, or stat() on the
 * matching entry's path fails, a "File not found" error is sent back.
 * Otherwise an option acknowledgement is sent and transmit_start() is
 * called for the entry.
 *
 * The caller guarantees buflen >= 4.
 */
static void
mtftpd_process_rrq(mtftpd_conf_t *mc, uint8_t *buf, ssize_t buflen,
    struct sockaddr_in *sin)
{
	struct stat sb;
	const char *filename, *xfermode;
	mtftpd_file_t *mf;
	int pos, have_file;

	/*
	 * The final byte must be a NUL so that every string scanned below
	 * is terminated inside the datagram.
	 */
	if (buf[buflen - 1] != '\0')
		return;

	/* the file name starts right after the two opcode bytes */
	filename = (const char *)buf + 2;
	pos = 2 + strlen(filename) + 1;
	if (pos >= buflen)
		return;	/* invalid request: nothing follows the file name */

	/* the mode string follows the file name; later options are ignored */
	xfermode = (const char *)buf + pos;

	syslog(LOG_INFO, "read request from %s for %s (%s)\n",
	    inet_ntoa(sin->sin_addr), filename, xfermode);

	/* look for a configured file with exactly this name */
	have_file = 0;
	TAILQ_FOREACH(mf, &mc->mc_mf, mf_files) {
		if (mf != NULL && strcmp(mf->mf_filename, filename) == 0) {
			have_file = 1;
			break;
		}
	}

	/* mf is only dereferenced when the lookup succeeded */
	if (!have_file || stat(mf->mf_filepath, &sb) == -1) {
		mtftpd_send_error(TFTP_ERR_FILENOTFOUND, "File not found", sin);
		return;
	}

	mtftpd_send_oack(mf->mf_sock, mf->mf_inetaddr.s_addr, mf->mf_port,
	    mc->mc_blksize, sb.st_size, sin);
	transmit_start(mc, mf);
}

/*
 * Send a TFTP error packet to `sin' over the listening socket `sock'.
 *
 * err     16-bit error code, written to the packet most significant byte
 *         first
 * errstr  NUL-terminated human-readable message; must not be NULL
 * sin     destination address
 *
 * Packet layout: two opcode bytes, two error-code bytes, the message, and a
 * terminating NUL.  Allocation and send failures are logged; nothing is
 * reported to the caller.
 */
static void
mtftpd_send_error(uint16_t err, const char *errstr, struct sockaddr_in *sin)
{
	uint8_t *pkt;
	size_t msglen, pktlen;

	msglen = strlen(errstr);
	pktlen = 4 + msglen + 1;

	/* calloc() zero-fills, which also provides the trailing NUL */
	pkt = calloc(1, pktlen);
	if (pkt == NULL) {
		syslog(LOG_WARNING, "calloc(...) failed: %s", strerror(errno));
		return;
	}

	pkt[0] = (TFTP_OPCODE_ERR >> 8) & 0xff;
	pkt[1] = TFTP_OPCODE_ERR & 0xff;
	/* encode both bytes of the code so values above 255 are not cut off */
	pkt[2] = (err >> 8) & 0xff;
	pkt[3] = err & 0xff;
	memcpy(pkt + 4, errstr, msglen);

	if (sendto(sock, pkt, pktlen, 0, (struct sockaddr *)sin,
	    sizeof(*sin)) == -1)
		syslog(LOG_WARNING, "sendto(...) failed: %s", strerror(errno));

	free(pkt);
}

/*
 * Append the NUL-terminated string `str', terminator included, to `buf' at
 * offset `*off' and advance `*off' past it.  Returns 0 on success, or -1
 * (leaving `buf' and `*off' untouched) if the string does not fit within
 * `buflen' bytes.
 */
static int
mtftpd_oack_append(char *buf, size_t buflen, size_t *off, const char *str)
{
	size_t need = strlen(str) + 1;

	if (*off > buflen || need > buflen - *off)
		return -1;

	memcpy(buf + *off, str, need);
	*off += need;
	return 0;
}

/*
 * Build an option acknowledgement and send it to `sin' over socket `s'.
 *
 * s        socket to send on
 * dst      IPv4 address placed in the "multicast" option (formatted with
 *          inet_ntoa())
 * port     currently unused, see the note below
 * blksize  value of the "blksize" option
 * tsize    value of the "tsize" option
 * sin      destination address
 *
 * After the two opcode bytes the packet carries three name/value pairs,
 * every string NUL-terminated: "blksize", "tsize" and "multicast".  A send
 * failure is logged; nothing is reported to the caller.
 */
static void
mtftpd_send_oack(int s, in_addr_t dst, uint16_t port, uint16_t blksize,
		 uint32_t tsize, struct sockaddr_in *sin)
{
	struct in_addr ia;
	char buf[100] = { 0 };
	char blkstr[8];		/* up to "65535" plus NUL */
	char tsizestr[16];	/* up to "4294967295" plus NUL */
	char multicast[32];	/* up to "255.255.255.255,1758,0" plus NUL */
	size_t off;
	ssize_t sent;

	/*
	 * NOTE(review): `port' is accepted but never used; the port
	 * advertised to the client is the fixed value 1758 below.  Confirm
	 * whether the caller's port should be advertised instead.
	 */
	(void)port;

	ia.s_addr = dst;

	/*
	 * Format the values with bounded writes.  A 32-bit tsize can need
	 * ten digits, so its buffer must hold at least eleven bytes.
	 */
	snprintf(blkstr, sizeof(blkstr), "%u", (unsigned int)blksize);
	snprintf(tsizestr, sizeof(tsizestr), "%lu", (unsigned long)tsize);
	snprintf(multicast, sizeof(multicast), "%s,%d,%d",
	    inet_ntoa(ia),
	    1758 /* tftp-multicast */,
	    0 /* open-loop always */);

	/* buf is zero-initialised, so buf[0] is already the high opcode byte */
	buf[1] = TFTP_OPCODE_OACK;
	off = 2;

	if (mtftpd_oack_append(buf, sizeof(buf), &off, "blksize") == -1 ||
	    mtftpd_oack_append(buf, sizeof(buf), &off, blkstr) == -1 ||
	    mtftpd_oack_append(buf, sizeof(buf), &off, "tsize") == -1 ||
	    mtftpd_oack_append(buf, sizeof(buf), &off, tsizestr) == -1 ||
	    mtftpd_oack_append(buf, sizeof(buf), &off, "multicast") == -1 ||
	    mtftpd_oack_append(buf, sizeof(buf), &off, multicast) == -1) {
		syslog(LOG_WARNING, "oack options too large, not sent");
		return;
	}

	/* send oack */
	sent = sendto(s, buf, off, 0, (struct sockaddr *)sin,
	    sizeof(*sin));
	if (sent == -1)
		syslog(LOG_WARNING, "sendto(...) failed: %s", strerror(errno));
}

/*
 * Write the command-line synopsis to stderr and terminate the process with
 * exit status 1.  Does not return.
 */
static void
usage(void)
{
	const char *prog = getprogname();

	fprintf(stderr, "usage: %s [-d] [-f config]\n", prog);
	fprintf(stderr, "       %s [-h]\n", prog);

	exit(1);
}
