/* transport.c -- implements the server, and transports messages
   Copyright (C) 2004 Maximiliano Pin

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License
   as published by the Free Software Foundation; either version 2
   of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along
   with this program; if not, write to the Free Software Foundation, Inc.,
   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

#define _POSIX_SOURCE

#include <unistd.h>		/* read, write, close */
#include <stdio.h>		/* stdlib.h (on some systems) */
#include <stdlib.h>		/* malloc, realloc, free */
#include <string.h>		/* memcpy, memmove, strcpy */
#include <netdb.h>		/* gethostbyname */
#include <netinet/in.h>		/* ntohl, *HTON, *NTOH */
#include <errno.h>		/* errno */
#include "common.h"
#include "uconfig.h"
#include "user_iface.h"
#include "transport.h"
#include "demux.h"
#include "tcp.h"
#include "tls.h"
#include "contact_list.h"
#include "protocol.h"
#include "misc.h"

/* errno comes from <errno.h> (included above). It may be a macro that
   expands to a per-thread lvalue, so it must not be re-declared here
   with "extern int errno;". */

#define HELLO_TIMEOUT 20	/* max time waiting for hello */
#define RESOLVER_INTERVAL 3600  /* seconds between resolves (TODO conf) */
#define CHECKER_INTERVAL 60     /* seconds between contact checks (TODO conf) */

#define RECV_ERROR (-1)		/* return values of recv_msg */
#define RECV_INCOMPLETE 0
#define RECV_COMPLETE 1

/* Prototypes */
static void cb_server_data (int fd, void *data);
static void cb_net_input (int fd, void *data);
static void cb_out_connection (int fd, void *data);
static void cb_net_output (int fd, void *data);
static int recv_msg (contact_t *c);
static int net_read (int fd, void *buf, size_t nbytes);
static void send_or_postpone (contact_t *contact);
static contact_t *create_temp_contact (ip_t ip);
static void create_hello_timeout (contact_t *c);
static void cb_hello_timeout (void *data);
static BYTE *inc_out_buffer (contact_t *contact, int bytes);
static void cb_resolv (void *data);
static void cb_check_contacts (void *data);

/* Open the listening socket on cfg.listen_port and hand incoming
   connection requests to cb_server_data. A bind failure is reported
   but is not fatal: outgoing connections still work. */
void
tr_init_server ()
{
	sock_t listener = tcp_listen (cfg.listen_port);

	if (listener >= 0) {
		dmx_add_input_fd (listener, cb_server_data, NULL);
		return;
	}

	ui_output_err ("Cannot bind to port %u. We won't be able "
	               "to receive connections, but we can still "
	               "initiate them.", cfg.listen_port);
}

/* Start a non-blocking outgoing connection to 'contact'. The contact
   must have a resolved address and no socket; otherwise (or if
   tcp_connect fails) this silently does nothing. The connection result
   is delivered later to cb_out_connection. */
void
tr_connect (contact_t *contact)
{
	sock_t sock;

	CHECK (contact->state.ip && contact->state.socket < 0);

	/* having no socket while not offline is an internal inconsistency */
	if (contact->state.cn_state != CS_OFFLINE)
		ui_output_err ("Program error in tr_connect. Please report. "
		               "Connected without a socket!?");

	sock = tcp_connect (contact->state.ip, contact->port);
	CHECK (sock >= 0);

	contact->state.socket = sock;
	/* writability of a connecting socket signals the connect outcome */
	dmx_add_output_fd (sock, cb_out_connection, contact);

	/* if the connection is not accepted nor rejected, the kernel will
	   keep the socket open and trying to connect with exponential
	   back-off, until it gives up; then cb_out_connection is called,
	   the socket is closed, and a new connecting socket will be
	   created when we get here again */

error:
	return;
}

/* Tear down the connection to 'contact' and put it back in CS_OFFLINE.
   This cancels a pending hello timer, unregisters and closes the socket
   (finishing and releasing its TLS session), drops any partially
   received input and any unsent output, and notifies the UI if the
   contact has a nick. Every step is guarded, so calling it on an
   already-disconnected contact is harmless. */
void
tr_disconnect (contact_t *contact)
{
	if (contact->state.hello_timeout) {
		dmx_remove_timer (contact->state.hello_timeout);
		contact->state.hello_timeout = NULL;
	}

	if (contact->state.socket >= 0) {
		dmx_remove_input_fd (contact->state.socket);
		dmx_remove_output_fd (contact->state.socket);
		tls_session_finish (contact);
		close (contact->state.socket);
		contact->state.socket = -1;
		tls_session_deinit (contact);
	}

	/* free (NULL) is a no-op, so the buffers need no guards */
	free (contact->state.in_buffer);
	contact->state.in_buffer = NULL;
	contact->state.in_buf_size = 0;
	contact->state.in_buf_rcvd = 0;

	free (contact->state.out_buffer);
	contact->state.out_buffer = NULL;
	/* out_buf_size must be reset as well: tr_send_msg uses it to decide
	   whether a send must be started, and inc_out_buffer uses it as the
	   offset at which new data is appended */
	contact->state.out_buf_size = 0;

	contact->state.cn_state = CS_OFFLINE;

	if (contact->nick[0] != '\0') {
		ui_redraw_contacts ();
		ui_output_info ("Contact %s disconnected.", contact->nick);
	}
}

/* Queue one framed message for 'contact' (or for every connected
   contact when contact == BROADCAST) and try to send it right away.

   Wire frame: a len_t length in network byte order (converted with
   LEN_T_HTON), then hdr_len bytes of 'header', then pl_len bytes of
   'payload'. For a unicast, the contact must be connected or in
   CS_HELLO_PENDING; otherwise an error is reported and nothing is sent.
   If memory runs out, the demux loop is stopped. */
void
tr_send_msg (contact_t *contact, const BYTE *header, int hdr_len,
             const BYTE *payload, int pl_len)
{
	len_t len_msg = hdr_len + pl_len;
	int len_full_msg;
	BYTE *p_full, *p;
	BOOL bcast, was_void, first_was_void;
	contact_t *first;

	bcast = (contact == BROADCAST);
	if (bcast) {
		/* we'll build the full message in the first connected
		   contact, and then we'll copy it to the rest */
		/* TODO we have a problem here; if this is a nick-change-hello,
		   connections in CS_HELLO_PENDING or CS_HANDSHAKING won't
		   know about the nick-change! we cannot send it because
		   we could interfer with TLS handshake...
		   A nice solution would be to implement two HELLO's, one
		   to use before handshake, including version & TLS information
		   (even the public key signature to find the contact_t!),
		   and another to use after handshake, including the nick.
		   So, even the nick and all information we add to HELLO
		   in the future, will be encrypted!
		   (NOTE: this also affects broadcast 'for' below) */
		for (contact = contacts;
		     contact && !CT_IS_CONNECTED (contact);
		     contact = contact->next)
			;
		if (!contact)
			return;
	}
	else if (!CT_IS_CONNECTED (contact) &&
	         contact->state.cn_state != CS_HELLO_PENDING) {
		ui_output_err ("Trying to send data to not connected "
		               "contact. This is a bug, please report it.");
		return;
	}

	first = contact;
	len_full_msg = sizeof (len_msg) + len_msg;
	/* an empty queue means the socket is not registered for output yet,
	   so the send must be started by hand */
	first_was_void = (first->state.out_buf_size == 0);
	p = p_full = inc_out_buffer (first, len_full_msg);
	CHECK (p_full);

	len_msg = LEN_T_HTON (len_msg);
	memcpy (p, &len_msg, sizeof (len_msg));
	p += sizeof (len_msg);
	/* memcpy with a NULL source is undefined even for 0 bytes */
	if (hdr_len > 0)
		memcpy (p, header, hdr_len);
	p += hdr_len;
	if (pl_len > 0)
		memcpy (p, payload, pl_len);

	if (bcast) {
		/* p_full points into first's out_buffer. Sending for 'first'
		   may free, or memmove and realloc, that buffer, so the copies
		   to the other contacts must be made before first's send is
		   started (below). */
		for (contact = first->next; contact; contact = contact->next) {
			if (!CT_IS_CONNECTED (contact))
				continue;
			was_void = (contact->state.out_buf_size == 0);
			p = inc_out_buffer (contact, len_full_msg);
			CHECK (p);
			memcpy (p, p_full, len_full_msg);
			if (was_void)
				send_or_postpone (contact);
		}
	}

	if (first_was_void) {
		/* if the buffer was void, the socket isn't tied to callback */
		send_or_postpone (first);
	}
	return;

error:
	/* TODO printf and ui_output... won't work, what could we do? */
	dmx_stop ();
}

/* Refresh contact->state.ip from contact->hname with gethostbyname().
   Contacts that currently own a socket (connected or connecting) keep
   their address. Only an IPv4 result of exactly sizeof (ip_t) bytes is
   accepted; it is stored in host byte order. Any other outcome is
   reported and leaves the old address in place. */
void
tr_resolv (contact_t *contact)
{
	struct hostent *host;

	if (contact->state.socket >= 0)
		return;

	host = gethostbyname (contact->hname);
	if (!host || host->h_addrtype != AF_INET ||
	    host->h_length != sizeof (ip_t)) {
		ui_output_err ("Error resolving hostname: %s "
			       "(for contact %s).",
			       contact->hname, contact->nick);
		return;
	}

	memcpy (&(contact->state.ip), host->h_addr_list[0], sizeof (ip_t));
	contact->state.ip = ntohl (contact->state.ip);
}

void
tr_init_resolver ()
{
	cb_resolv (NULL);
	dmx_add_periodic_alarm (RESOLVER_INTERVAL, -1, cb_resolv, NULL);
}

void
tr_init_checker ()
{
	cb_check_contacts (NULL);
	dmx_add_periodic_alarm (CHECKER_INTERVAL, -1, cb_check_contacts, NULL);
}

/* Called when there is a connection request on 'fd'. */
static void
cb_server_data (int fd, void *data)
{
	sock_t      s;
	ip_t        ip;
	contact_t  *c;

	s = tcp_accept (fd, &ip);
	CHECK (s >= 0);

	c = cl_find_by_ip (ip); /* TODO in the future, find by public key */

	if (!c) {
		c = create_temp_contact (ip);
		CHECK (c);
	}
	else if (c->state.cn_state != CS_OFFLINE) {
		if (CT_IS_CONNECTED (c)) {
			ui_output_info ("%s reconnected, closing old "
			                "connection.", c->nick);
		}
		tr_disconnect (c);
	}
	else if (c->state.socket >= 0) {
		/* this is a connecting socket, just remove it */
		dmx_remove_output_fd (c->state.socket);
		close (c->state.socket);
	}

	c->state.socket = s;
	c->state.cn_state = CS_HELLO_PENDING;
	tls_session_init (c, TRUE);
	dmx_add_input_fd (s, cb_net_input, (void *)c);
	create_hello_timeout (c);
	pr_send_hello (c);

error:
	return;
}

/* Input callback for a contact's socket ('data' is the contact_t).
   Reads at most one message per call through recv_msg. A complete
   message is handed to pr_msg_received and then released. A receive
   error (or disconnection), or a negative result from pr_msg_received,
   disconnects the contact. Input is ignored while in CS_HANDSHAKING. */
static void
cb_net_input (int fd, void *data)
{
	contact_t *peer = (contact_t *)data;

	/* ignore data received during TLS handshake */
	if (peer->state.cn_state == CS_HANDSHAKING)
		return;

	switch (recv_msg (peer)) {
	case RECV_COMPLETE:
		if (pr_msg_received (peer, peer->state.in_buffer,
		                     peer->state.in_buf_size) < 0)
			tr_disconnect (peer);
		/* get ready for the next length prefix */
		free (peer->state.in_buffer);
		peer->state.in_buffer = NULL;
		peer->state.in_buf_size = 0;
		peer->state.in_buf_rcvd = 0;
		break;

	case RECV_ERROR:
		/* error or disconnection */
		tr_disconnect (peer);
		break;

	default:
		/* RECV_INCOMPLETE: wait for more data */
		break;
	}
}

/* Output callback for an outgoing connecting socket registered by
   tr_connect ('data' is the contact_t). Stops watching for
   writability and checks the connect outcome.
   - On success: sets up a client-side TLS session, moves the contact
     to CS_HELLO_PENDING, starts reading, arms the hello timer and sends
     our hello.
   - On failure: closes the socket and clears it, so a later tr_connect
     (for example from cb_check_contacts) can retry. */
static void
cb_out_connection (int fd, void *data)
{
	contact_t *c = (contact_t *)data;

	if (fd != c->state.socket || c->state.cn_state != CS_OFFLINE)
		ui_output_err ("Program error in cb_out_connection. Please "
		               "report. fd=%d socket=%d.", fd, c->state.socket);

	dmx_remove_output_fd (fd);

	if (tcp_connect_result (fd) != OK) {
		close (fd);
		c->state.socket = -1;
		return;
	}

	tls_session_init (c, FALSE);
	c->state.cn_state = CS_HELLO_PENDING;
	dmx_add_input_fd (fd, cb_net_input, data);
	create_hello_timeout (c);
	pr_send_hello (c);
}

/* Output callback, also called directly by send_or_postpone ('data' is
   the contact_t). Sends as much of the queued output as tls_send
   accepts. A send error is reported and treated as 0 bytes sent.
   - If the queue is empty afterwards: the buffer is freed and the
     socket is unregistered for output.
   - Otherwise: the unsent bytes are moved to the front of the buffer. */
static void
cb_net_output (int fd, void *data)
{
	contact_t *c = (contact_t *)data;
	BYTE *buf = c->state.out_buffer;
	len_t cur_size = c->state.out_buf_size;
	BYTE *shrunk;
	int b = 0;

	if (buf) {
		b = tls_send (c, buf, cur_size);
		if (b < 0) {
			ui_output_err ("Error sending data to %s.", c->nick);
			b = 0;
		}
	}

	cur_size -= b;
	c->state.out_buf_size = cur_size;

	if (cur_size == 0) {
		free (buf);	/* no-op when buf is NULL */
		c->state.out_buffer = NULL;
		dmx_remove_output_fd (c->state.socket);
	}
	else if (b > 0) {
		/* move remaining data to the beginning */
		memmove (buf, &buf[b], cur_size);
		/* Shrinking is only an optimisation. If realloc fails, the
		   original (larger) block is still valid and holds the data,
		   so keep using it instead of losing the queue. */
		shrunk = realloc (buf, cur_size);
		if (shrunk)
			c->state.out_buffer = shrunk;
	}
}

/* Incrementally receive one length-prefixed message for 'c'.

   Frame: a len_t length in network byte order, then that many bytes.
   While c->state.in_buffer is NULL, the length prefix is accumulated
   into c->state.in_buf_size. Once it is complete, a buffer of
   in_buf_size + 1 bytes is allocated (the extra byte is presumably for
   a terminator added by the protocol layer; verify in protocol.c) and
   the body is accumulated into it. in_buf_rcvd counts the bytes
   received for the current phase.

   Returns RECV_COMPLETE when the whole body is in c->state.in_buffer,
   RECV_INCOMPLETE when more data is needed, and RECV_ERROR on a
   receive error, disconnection or invalid length. It also returns
   RECV_ERROR after stopping the demux loop if allocation fails. */
static int
recv_msg (contact_t *c)
{
	BYTE   *buf = c->state.in_buffer;
	len_t   rcvd;
	size_t  alloc_size;
	int     b;

	if (!buf) {
		/* begin receiving, or continue receiving, the length of
		   the next message */
		rcvd = c->state.in_buf_rcvd;
		/* TODO CHECK (c->state.in_buf_rcvd <= sizeof (len_t)); */
		b = tls_recv (c, ((BYTE *)(&(c->state.in_buf_size))) + rcvd,
		              sizeof (len_t) - rcvd);
		CHECK (b >= 0);

		c->state.in_buf_rcvd += b;
		if (c->state.in_buf_rcvd < sizeof (len_t))
			return RECV_INCOMPLETE;

		/* now the length is complete */
		c->state.in_buf_size = LEN_T_NTOH (c->state.in_buf_size);
		CHECK (c->state.in_buf_size > 0); /* TODO check limit? */

		/* The length comes from the peer. Compute the allocation in
		   size_t and reject a +1 that wraps: a wrapped size would give
		   a tiny buffer followed by a huge tls_recv into it. */
		alloc_size = (size_t)c->state.in_buf_size + 1;
		CHECK (alloc_size > (size_t)c->state.in_buf_size);

		c->state.in_buf_rcvd = 0;
		buf = c->state.in_buffer = malloc (alloc_size);
		CHECK_J (c->state.in_buffer, critical);
	}

	/* receive or continue receiving the message */
	rcvd = c->state.in_buf_rcvd;
	b = tls_recv (c, buf + rcvd, c->state.in_buf_size - rcvd);
	CHECK (b >= 0);
	c->state.in_buf_rcvd += b;

	if (c->state.in_buf_rcvd == c->state.in_buf_size) {
		return RECV_COMPLETE;
	}
	else {
		return RECV_INCOMPLETE;
	}

critical:
	dmx_stop ();
error:
	return RECV_ERROR;
}

/* Wrapper over read().
   - Returns ERR when the peer closed the connection (read returned 0)
     or on any error other than EINTR/EAGAIN.
   - Returns 0 when read was interrupted or would block.
   - Otherwise returns the number of bytes read.
   NOTE(review): nothing in this file calls net_read (recv_msg uses
   tls_recv); it may be dead code.
   TODO: distinguish error and disconnect so errors are displayed. */
static int
net_read (int fd, void *buf, size_t nbytes)
{
	ssize_t got = read (fd, buf, nbytes);

	if (got > 0)
		return got;

	/* transient conditions: no data yet, but the peer is still there */
	if (got < 0 && (errno == EINTR || errno == EAGAIN))
		return 0;

	/* got == 0 (orderly shutdown) or a fatal read error */
	return ERR;
}

/* Try to flush contact's output queue right now. If anything is left
   over, register cb_net_output so the rest is sent when the socket
   becomes writable. */
static void
send_or_postpone (contact_t *contact)
{
	void *self = (void *)contact;

	cb_net_output (contact->state.socket, self);

	if (contact->state.out_buf_size <= 0)
		return;

	dmx_add_output_fd (contact->state.socket, cb_net_output, self);
}

/* Create a contact for an unknown peer at 'ip' and add it to the
   contact list. Its hostname is the textual form of the address.
   Returns NULL if ct_new_contact fails. */
static contact_t *
create_temp_contact (ip_t ip)
{
	contact_t *tmp = ct_new_contact ();

	CHECK_J (tmp, end);

	strcpy (tmp->hname, ipv4_to_string (ip));
	tmp->state.ip = ip;
	cl_add_contact (tmp);

end:
	return tmp;
}

/* Arm a one-shot timer of HELLO_TIMEOUT that calls cb_hello_timeout
   for 'c', and remember it in c->state.hello_timeout. */
static void
create_hello_timeout (contact_t *c)
{
	void *owner = c;

	/* TODO it shouldn't have a timeout yet, check it (a previous
	   c->state.hello_timeout is overwritten here, not removed) */
	c->state.hello_timeout =
		dmx_add_timer (HELLO_TIMEOUT, cb_hello_timeout, owner);
}

static void
cb_hello_timeout (void *data)
{
	ui_output_info ("Someone from %s connected and didn't say hello!",
	                ipv4_to_string (((contact_t *)data)->state.ip));
	tr_disconnect ((contact_t *)data);
}

/* Increase contact's output buffer by 'bytes' and return a pointer to
   the new (uninitialised) portion at the end.

   Returns NULL if 'bytes' is negative, if the new size would wrap
   around len_t, or if realloc fails. In every failure case
   contact->state.out_buffer and out_buf_size are left untouched, so
   the data already queued is neither leaked nor lost. */
static BYTE *
inc_out_buffer (contact_t *contact, int bytes)
{
	len_t new_size, old_size;
	BYTE *grown;

	if (bytes < 0)
		return NULL;

	old_size = contact->state.out_buf_size;
	new_size = old_size + bytes;
	if (new_size < old_size)	/* wrapped around len_t */
		return NULL;

	/* never assign realloc's result straight to the only pointer */
	grown = realloc (contact->state.out_buffer, new_size);
	if (!grown)
		return NULL;

	contact->state.out_buffer = grown;
	contact->state.out_buf_size = new_size;

	return &grown[old_size];
}

/* Periodic alarm callback: re-resolve the hostname of every contact
   in the list (tr_resolv skips contacts that own a socket). 'data' is
   unused. This blocks while gethostbyname runs. */
static void
cb_resolv (void *data)
{
	contact_t *it;

	/* TODO if pthreads are available, use them so the program
	   doesn't block while we resolv (or use libadns) */

/* TODO this should be a configurable setting */
#ifdef IPCHAT_VERBOSE
	ui_output_info ("Resolving hostnames...");
#endif
	it = contacts;
	while (it) {
		tr_resolv (it);
		it = it->next;
	}
#ifdef IPCHAT_VERBOSE
	ui_output_info ("Resolving done.");
#endif
}

/* Try to connect to all disconnected contacts. Used as a periodic alarm
   callback; 'data' is unused. Contacts that already own a socket
   (connected or still connecting) are skipped. */
static void
cb_check_contacts (void *data)
{
	contact_t *ct = contacts;

	while (ct) {
		if (ct->state.socket < 0)
			tr_connect (ct);
		ct = ct->next;
	}
}
