#include <unistd.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <netdb.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <sys/time.h>
#include <signal.h>

#include "xtux.h"

/* Describes one network message type: a printable name and the number of
   bytes the message occupies on the wire (sizeof its netmsg_* struct). */
typedef struct {
    char *name;     /* printable message name */
    int   size;     /* size in bytes of the message on the wire */
} message_t;

static message_t netmsg_table[NUM_NETMESSAGES] = {
    { DFNTOSTR( NETMSG_NONE ),             sizeof(netmsg_none) },
    { DFNTOSTR( NETMSG_NOOP ),             sizeof(netmsg_noop) },
    { DFNTOSTR( NETMSG_QUERY_VERSION ),    sizeof(netmsg_query_version) },
    { DFNTOSTR( NETMSG_VERSION ),          sizeof(netmsg_version) },
    { DFNTOSTR( NETMSG_TEXTMESSAGE ),      sizeof(netmsg_textmessage) },
    { DFNTOSTR( NETMSG_QUIT ),             sizeof(netmsg_quit) },
    { DFNTOSTR( NETMSG_REJECTION ),        sizeof(netmsg_rejection) },
    { DFNTOSTR( NETMSG_SV_INFO ),          sizeof(netmsg_sv_info) },
    { DFNTOSTR( NETMSG_CHANGELEVEL ),      sizeof(netmsg_changelevel) },
    { DFNTOSTR( NETMSG_START_FRAME ),      sizeof(netmsg_start_frame) },
    { DFNTOSTR( NETMSG_END_FRAME ),        sizeof(netmsg_end_frame) },
    { DFNTOSTR( NETMSG_ENTITY ),           sizeof(netmsg_entity) },
    { DFNTOSTR( NETMSG_MYENTITY ),         sizeof(netmsg_entity) },
    { DFNTOSTR( NETMSG_PARTICLES ),        sizeof(netmsg_particles) },
    { DFNTOSTR( NETMSG_UPDATE_STATUSBAR ), sizeof(netmsg_update_statusbar) },
    { DFNTOSTR( NETMSG_JOIN ),             sizeof(netmsg_join) },
    { DFNTOSTR( NETMSG_READY ),            sizeof(netmsg_ready) },
    { DFNTOSTR( NETMSG_QUERY_SV_INFO ),    sizeof(netmsg_query_sv_info) },
    { DFNTOSTR( NETMSG_CL_UPDATE ),        sizeof(netmsg_cl_update) },
    { DFNTOSTR( NETMSG_GAMEMESSAGE ),      sizeof(netmsg_gamemessage) }
};

/* Traffic counters reported by net_stats(); cleared by net_reset() */
static unsigned long bytes_sent;
static unsigned long bytes_recieved;
/* Timestamp of the previous net_stats() call or net_reset() */
static msec_t last_time;

/* File-local helpers defined further down */
static void sig_catcher(int signo);
static void net_reset(void);


/*
 * Basically just a wrapper for socket().
 *
 * Also clears the buffered-read state and traffic counters (net_reset())
 * and installs sig_catcher() as the SIGPIPE handler, so writing to a closed
 * connection does not take the signal's default action.
 *
 * Returns the new TCP socket descriptor, or a negative value (after a
 * perror() message) if socket() fails.
 */
int net_init(void)
{
    int sock;

    net_reset();

    signal(SIGPIPE, sig_catcher);

    if( (sock = socket(AF_INET, SOCK_STREAM, 0)) < 0 ) {
	perror("socket");
	return sock;
    }

    /* #ifdef (not #if) to match the rest of the file: works even when
       DEBUG is defined with no value */
#ifdef DEBUG
    printf("Socket initialised\n");
#endif

    return sock;
}


/*
 * Returns network traffic statistics as a vector: x = incoming, y = outgoing.
 *
 *   NS_TOTAL  - byte counters accumulated since the last net_reset().
 *   NS_RECENT - bytes since the previous net_stats() call, scaled by
 *               M_SEC / elapsed time (presumably bytes per second, if M_SEC
 *               is the number of msec_t units in a second - confirm).
 *
 * Every call, whatever the type, starts a new "recent" window. Unknown
 * types, and NS_RECENT with no elapsed time, yield a zero vector.
 */
vector_t net_stats(netstats_t type)
{
    static unsigned long last_in = 0, last_out = 0;
    float recent_in, recent_out;
    msec_t now, msecs;
    vector_t NS;

    /* Never hand back uninitialised fields */
    memset(&NS, 0, sizeof(NS));

    /* Byte difference since last call. net_reset() zeroes the counters, so
       if one went backwards count everything since the reset rather than
       letting the unsigned subtraction wrap to a huge value. */
    recent_in = bytes_recieved >= last_in ?
	bytes_recieved - last_in : bytes_recieved;
    recent_out = bytes_sent >= last_out ?
	bytes_sent - last_out : bytes_sent;
    last_in = bytes_recieved;
    last_out = bytes_sent;

    /* Time diff since last call (net_reset() also restarts this timer) */
    now = gettime();
    msecs = now - last_time;
    last_time = now;

    switch( type ) {
    case NS_TOTAL:
	NS.x = bytes_recieved;
	NS.y = bytes_sent;
	break;
    case NS_RECENT:
	/* Two calls within the same tick: no rate can be computed */
	if( msecs > 0 ) {
	    NS.x = recent_in * M_SEC/msecs;
	    NS.y = recent_out * M_SEC/msecs;
	}
	break;
    default:
	break;
    }

    return NS;
}

/*
  Any net-message that has an integer data type larger than a byte must be
  converted to and from network byte order! tonet() and fromnet() below
  perform these conversions.
*/

/* Macros to save typing: convert a 16-bit lvalue in place to (HTONS) or
   from (NTOHS) network byte order. The argument is evaluated twice, so it
   must not have side effects. */
#define HTONS(a) ((a) = htons(a))
#define NTOHS(a) ((a) = ntohs(a))

/*
 * Host -> network byte order.
 *
 * Returns a copy of msg in which the 16-bit fields of the message types
 * handled below have been converted with HTONS, ready for write().
 * Any other message type is returned untouched.
 */
static netmsg tonet(netmsg msg)
{
    if( msg.type == NETMSG_READY ) {
	HTONS(msg.ready.view_w);
	HTONS(msg.ready.view_h);
    } else if( msg.type == NETMSG_START_FRAME ) {
	HTONS(msg.start_frame.screenpos.x);
	HTONS(msg.start_frame.screenpos.y);
    } else if( msg.type == NETMSG_ENTITY || msg.type == NETMSG_MYENTITY ) {
	/* Both types use the entity layout */
	HTONS(msg.entity.x);
	HTONS(msg.entity.y);
    } else if( msg.type == NETMSG_PARTICLES ) {
	HTONS(msg.particles.x);
	HTONS(msg.particles.y);
    } else if( msg.type == NETMSG_UPDATE_STATUSBAR ) {
	HTONS(msg.update_statusbar.frags);
	HTONS(msg.update_statusbar.ammo);
    }

    return msg;
}


/*
 * Network -> host byte order; the inverse of tonet().
 *
 * Returns a copy of a just-received msg whose 16-bit fields have been
 * converted with NTOHS for the message types handled below. Any other
 * message type is returned untouched.
 */
static netmsg fromnet(netmsg msg)
{
    if( msg.type == NETMSG_READY ) {
	NTOHS(msg.ready.view_w);
	NTOHS(msg.ready.view_h);
    } else if( msg.type == NETMSG_START_FRAME ) {
	NTOHS(msg.start_frame.screenpos.x);
	NTOHS(msg.start_frame.screenpos.y);
    } else if( msg.type == NETMSG_ENTITY || msg.type == NETMSG_MYENTITY ) {
	/* Both types use the entity layout */
	NTOHS(msg.entity.x);
	NTOHS(msg.entity.y);
    } else if( msg.type == NETMSG_PARTICLES ) {
	NTOHS(msg.particles.x);
	NTOHS(msg.particles.y);
    } else if( msg.type == NETMSG_UPDATE_STATUSBAR ) {
	NTOHS(msg.update_statusbar.frags);
	NTOHS(msg.update_statusbar.ammo);
    }

    return msg;
}


/*
 * Receive buffer: room for NETBUFSIZ bytes of read() data plus one extra
 * message, in case only part of a message arrives in the first read().
 */
static char netbuf[NETBUFSIZ + sizeof(netmsg)];
static int recv_fd;                 /* descriptor the buffered bytes came from (-1 = none) */
static int recv_pos, recv_size;     /* parse position / end of bytes net_get_message() may use */
static int netbuf_size;             /* the ACTUAL number of bytes held in netbuf */
static int frame_start, frame_end;  /* offsets of the last complete frame's START and END messages */

/*
 * Forgets the current connection and everything buffered from it, and
 * restarts the traffic statistics (counters zeroed, timer restarted).
 */
static void net_reset(void)
{
    /* No connection, nothing buffered, no frame located yet */
    recv_fd = -1;
    recv_pos = recv_size = netbuf_size = 0;
    frame_start = 0;
    frame_end = NETBUFSIZ;

    /* State used by net_stats() */
    bytes_sent = bytes_recieved = 0;
    last_time = gettime();
}


int net_buffered_read(int fd)
{
    int last_start_found = -1;
    int offset;
    int msg_size;
    int frames_found = 0;
    byte type;

    /* Initial defaults, so that if no frame starts or ends are found
       then net_get_message() will work ok */
    frame_start = 0;
    frame_end = NETBUFSIZ;

    /* offset is the amount of data left over in
       netbuf that wasn't delt with */
    offset = netbuf_size - recv_size;

    if( offset > 0 ) {
	if( recv_size > NETBUFSIZ ) {
	    printf("Recv_size = %d (netbufsiz = %d)!\n", recv_size, NETBUFSIZ);
	    offset = 0;
	} else
	    memmove( netbuf, netbuf + recv_size, offset );
    }

    if( recv_fd != fd ) {
	net_reset();
	recv_fd = fd;
    }

    if( (recv_size = read(recv_fd, netbuf + offset, NETBUFSIZ)) <= 0) {
	if( errno == EAGAIN )
	    return 0; /* Error reading, but we should try again */
	else if( errno == EINTR ) {
	    printf("%s, BUFFERED READ ERROR!!\n", __FILE__);
	    return 0; /* Most probably got hit with a sigalarm timeout */
	} else {
	    net_reset();
	    return -1; /* Got EOF or non-recoverable error */
	}
    }

#ifdef DEBUG
    printf("%d bytes from read()\n", recv_size);
#endif

    bytes_recieved += recv_size;
    netbuf_size = offset + recv_size;
    recv_pos = 0;

    /* Go through netbuf and find the last full frame start & finish */
    while( netbuf_size - recv_pos > 0 ) {
	type = netbuf[recv_pos];
	msg_size = net_message_size( type );
	if( msg_size <=0 || msg_size > sizeof(netmsg) ) {
	    printf("net_read_frame: netbuf is corrupted! RESETTING.\n");
	    net_reset();
	    return 0;
	}

	if( type == NETMSG_START_FRAME )
	    last_start_found = recv_pos;
	else if( type == NETMSG_END_FRAME && last_start_found >= 0 ) {
	    frame_end = recv_pos;
	    frame_start = last_start_found;
	    last_start_found = -1;
	    frames_found++;

	    /*
	      printf("Last start found: %d\n", last_start_found);
	      printf("Start: %d\n", frame_start);
	      printf("  End: %d\n", frame_end);
	    */

	    recv_size = frame_end + msg_size - frame_start;
	}

	recv_pos += msg_size;

    }

    recv_pos = 0;

    /* PARTIAL FRAME!! Save for next read */
    if( last_start_found >= frame_start )
	recv_size = last_start_found;
    else
	recv_size = netbuf_size;

    /* printf("%d frames found\n", frames_found); */

    /*
      printf("done, recv_size = %d\n", recv_size);
      printf("%d frames read\n", frames_found);
      printf("frame: ST= %d .... END=%d\n", frame_start, frame_end);
    */

    return recv_fd;

}


netmsg net_get_message(void)
{
    netmsg msg;
    int msg_size, r;
    int bts; /* BTS = bytes left to process */

    memset(&msg, 0, sizeof(netmsg));

    if( recv_fd < 0 ) {
	printf("net_get_message: Current_fd is not set!\n");
	msg.type = NETMSG_QUIT;
	return msg;
    }

    while( (bts = recv_size - recv_pos) > 0 ) {
	/* printf("%d bytes left...\n", bts); */
	msg_size = net_message_size( netbuf[recv_pos] );
	if( msg_size <=0 || msg_size > sizeof(netmsg) ) {
	    msg.type = NETMSG_NONE; /* corrupted messages */
	    return msg;
	}

	/* Message is not completly in netbuf. Keep read()'ing while the
	   amount of data unprocessed in netbuf is greater than the size of
	   the message */
	while( bts < msg_size ) {
	    printf("net_get_message: completing %d bytes of message\n",
		   msg_size - bts);
	    if((r = read(recv_fd, &netbuf[recv_size], msg_size-bts)) <= 0) {
		perror("read");
		fprintf(stderr, "Error reading data! Returning NETMSG_QUIT\n");
		msg.type = NETMSG_QUIT;
		return msg;
	    }
	    bts += r;
	    bytes_recieved += r;
	}

	/* These kind of messages are dropped if they occur outside
	   of the current frame_start and frame_end positions */
	if( netbuf[recv_pos] == NETMSG_ENTITY ||
	    netbuf[recv_pos] == NETMSG_MYENTITY ||
	    netbuf[recv_pos] == NETMSG_START_FRAME ||
	    netbuf[recv_pos] == NETMSG_END_FRAME ) {
	    if( recv_pos < frame_start || recv_pos > frame_end ) {
#ifdef DEBUG
		printf("get_message: Dropping %s!\n",
		       net_message_name( netbuf[recv_pos] ));
#endif
		recv_pos += msg_size;
		continue;
	    }
	}
	
	memcpy( &msg, &netbuf[recv_pos], msg_size );
	recv_pos += msg_size;
#ifdef DEBUG
	printf("RECIEVING %s\n", net_message_name(msg.type));
#endif
	return fromnet(msg);

    }

    /* No messages (or bytes) left in netbuf */
    msg.type = NETMSG_NONE;
    return msg;

}


/*
 * Sends one message to fd after converting it to network byte order with
 * tonet(). Only net_message_size(msg.type) bytes of msg are sent.
 *
 * Returns the number of bytes written on success. If the message type has
 * no valid size, that non-positive size is returned and nothing is sent.
 * If write() fails, its non-positive return value is returned.
 *
 * FIXME: unbuffered! Every message is a separate write().
 */
int net_send_message(int fd, netmsg msg)
{
    const char *p;
    char *addr;
    int size, sent, n, saved_errno;

#ifdef DEBUG
    printf("SENDING %s\n", net_message_name(msg.type));
#endif

    if( (size = net_message_size( msg.type )) <= 0 ) {
	printf("Error sending %s (%db)\n", net_message_name(msg.type), size);
	return size;
    }

    msg = tonet(msg);

    /* write() may accept fewer bytes than requested; keep going until the
       whole message is out, otherwise the receiver's stream desyncs */
    p = (const char *) &msg;
    for( sent = 0; sent < size; sent += n ) {
	if( (n = write( fd, p + sent, size - sent )) <= 0 ) {
	    /* net_get_address() makes syscalls that may overwrite errno */
	    saved_errno = errno;
	    addr = net_get_address(fd);
	    fprintf(stderr, "net_send_message: error sending message %s to %s\n",
		    net_message_name(msg.type), addr ? addr : "(unknown)");
	    errno = saved_errno;
	    perror("write");
	    return n; /* write's status */
	}
	bytes_sent += n;
    }

    return sent;
}


/*
 * Size in bytes that a message of the given type occupies on the wire,
 * taken from netmsg_table. Out-of-range types are reported on stdout and
 * yield -1.
 */
int net_message_size(byte type)
{
    if( type >= NUM_NETMESSAGES ) {
	printf("Requested size for %s (%d)\n", net_message_name(type),
	       (int)type);
	return -1;
    }

    return netmsg_table[(int)type].size;
}


/*
 * Printable name of a message type from netmsg_table, or a fixed
 * placeholder string for out-of-range types. Never returns NULL.
 */
char *net_message_name(byte type)
{
    return type < NUM_NETMESSAGES ? netmsg_table[(int)type].name
				  : "UNKNOWN MESSAGE!";
}


/*
 * Describes the peer of the connected socket fd: its hostname if a reverse
 * lookup succeeds, otherwise its IPv4 address in ascii dot notation.
 *
 * Returns NULL (after a perror() message) if getpeername() fails. The
 * returned string comes from gethostbyaddr() or inet_ntoa(), which may use
 * static storage, so it must not be freed. It may also be overwritten by
 * later calls.
 */
char *net_get_address(int fd)
{
    struct sockaddr_in sin;
    struct hostent *hp;
    socklen_t len = sizeof(sin); /* a real socklen_t, not a punned int */

    if( getpeername(fd, (struct sockaddr *) &sin, &len) != 0 ) {
	perror("getpeername");
	return NULL;
    }

    /* Return hostname, or IP address if we can't resolve hostname */
    hp = gethostbyaddr((char *) &sin.sin_addr.s_addr,
		       sizeof(sin.sin_addr.s_addr), AF_INET);
    if( hp == NULL )
	return inet_ntoa(sin.sin_addr); /* IP address */

    return hp->h_name; /* Hostname */
}


/*
 * SIGPIPE handler installed by net_init(). It only reports the signal and
 * returns.
 *
 * Uses write() rather than printf(): stdio is not async-signal-safe, and
 * the signal may arrive while the program is inside stdio itself. errno is
 * saved and restored so the interrupted code still sees the error from the
 * write() that raised SIGPIPE.
 */
static void sig_catcher(int signo)
{
    static const char text[] = "Caught SIGPIPE\n";
    int saved_errno = errno;
    ssize_t ignored;

    if( signo == SIGPIPE ) {
	ignored = write(STDOUT_FILENO, text, sizeof(text) - 1);
	(void) ignored; /* nothing useful to do on failure inside a handler */
    }

    errno = saved_errno;
}
