#include "common.h"
#include "packet.h"
#include <errno.h>
#include <sys/socket.h>
#include <netinet/in.h>

void
init_packet(p)
PACKET *p;
{
  memset(p, 0, sizeof(*p));
}

/*
 * stuff_short -- append one SHORT value to the packet's short array.
 *
 * On success the array grows by one slot, p->cshort is incremented and
 * p->size grows by sizeof(SHORT); returns 0.  If the array cannot be
 * grown, returns PACKET_ERROR_MEMORY and leaves *p untouched (a failed
 * realloc keeps the old block valid).
 */
int
stuff_short(p, s)
PACKET *p;
SHORT s;
{
  SHORT *grown;
  int nbytes = (p->cshort + 1) * sizeof(SHORT);

  grown = (SHORT *)realloc(p->shorts, nbytes);
  if (grown == NULL)
    return PACKET_ERROR_MEMORY;   /* out of memory; p->shorts still valid */

  p->shorts = grown;
  p->shorts[p->cshort++] = s;
  p->size += sizeof(SHORT);
  return 0;
}

stuff_long(p, l)
PACKET *p;
LONG l;
{
  LONG *newptr;
  int newsz = (p->clong+1)*sizeof(LONG);
  newptr = (LONG *)realloc(p->longs, newsz);
  if (newptr == NULL) {
    /* out of memory */
    return PACKET_ERROR_MEMORY;
  }
  newptr[p->clong] = l;
  p->longs = newptr;
  p->clong++;
  p->size += sizeof(LONG);
  return 0;
}

stuff_str(p, str)
PACKET *p;
char *str;
{
  char **newptr;
  int newsz = (p->cstr+1)*sizeof(char *);
  int len;
  newptr = (char **)realloc(p->strs, newsz);
  if (newptr == NULL) {
    /* out of memory */
    return PACKET_ERROR_MEMORY;
  }
  if (str == NULL) str = "";
  len = strlen(str);
  newptr[p->cstr] = (char *)malloc(len+1);
  if (newptr[p->cstr] == NULL) {
    /* out of memory */
    free(newptr);
    return PACKET_ERROR_MEMORY;
  }
  strcpy(newptr[p->cstr], str);
  p->strs = newptr;
  p->cstr++;
  p->size += (sizeof(SHORT)+len);
  return 0;
}

/*
 * stuff_bin -- append a private copy of the size bytes at data.
 *
 * The copy is owned by the packet and released by free_packet().  A
 * zero-length blob is stored with data == NULL, the same representation
 * raw_to_packet() produces, and data is not read in that case (so it may
 * be NULL).  p->size grows by sizeof(SHORT) + size (length prefix plus
 * bytes).
 *
 * Returns 0 on success or PACKET_ERROR_MEMORY if an allocation fails; on
 * error *p is left exactly as it was.
 */
int
stuff_bin(p, data, size)
PACKET *p;
void *data;
SHORT size;
{
  BINDATA *newptr;
  char *copy = NULL;
  int newsz = (p->cbin+1)*sizeof(BINDATA);

  /*
   * Copy the payload before growing p->bins so that a failure never frees
   * a block p->bins still points to.  malloc(0) is skipped: it may
   * legitimately return NULL and would be misreported as out of memory.
   */
  if (size != 0) {
    if ((copy = (char *)malloc(size)) == NULL) {
      /* out of memory */
      return PACKET_ERROR_MEMORY;
    }
    memcpy(copy, data, size);
  }

  newptr = (BINDATA *)realloc(p->bins, newsz);
  if (newptr == NULL) {
    /* out of memory; p->bins is still valid and unchanged */
    free(copy);
    return PACKET_ERROR_MEMORY;
  }
  newptr[p->cbin].size = size;
  newptr[p->cbin].data = copy;
  p->bins = newptr;
  p->cbin++;
  p->size += (sizeof(SHORT)+size);
  return 0;
}

/*
 * free_packet -- release everything the packet owns and reset it to empty.
 *
 * Frees the short and long arrays, every string and blob payload, and the
 * string/blob arrays themselves; all pointers are set back to NULL, and
 * counts, size and flags are zeroed so the packet can be refilled.
 */
void
free_packet(p)
PACKET *p;
{
  SHORT k;

  /* free(NULL) is a no-op, so the flat arrays need no guard */
  free(p->shorts);
  p->shorts = NULL;
  free(p->longs);
  p->longs = NULL;

  /* the per-item loops do need the guard: they index the array */
  if (p->strs != NULL) {
    for (k = 0; k < p->cstr; k++)
      free(p->strs[k]);
    free(p->strs);
    p->strs = NULL;
  }
  if (p->bins != NULL) {
    for (k = 0; k < p->cbin; k++)
      free(p->bins[k].data);
    free(p->bins);
    p->bins = NULL;
  }

  p->cshort = 0;
  p->clong = 0;
  p->cstr = 0;
  p->cbin = 0;
  p->size = 0;
  p->flags = 0;
  /* p->magic and p->opcode are left alone */
}

/*
 * raw_to_packet -- decode the payload bytes at pdata into *p.
 *
 * Before calling, p->cshort, p->clong, p->cstr, p->cbin and p->size must
 * already hold the values from the packet header (read_packet fills them
 * in).  The payload is decoded in this order:
 *   - p->cshort SHORTs, converted with ntohs() into a new p->shorts array;
 *   - p->clong LONGs, converted with ntohl() into a new p->longs array;
 *   - p->cstr strings, each a SHORT length (ntohs) followed by that many
 *     bytes; each is copied into a zeroed buffer one byte longer, so the
 *     stored string is NUL-terminated;
 *   - p->cbin binary blobs, each a SHORT length followed by that many
 *     bytes; a zero-length blob is stored with data == NULL.
 *
 * p->size is the number of payload bytes available.  Every read is checked
 * against what is left, and the payload must be consumed exactly.
 *
 * Returns 0 on success, PACKET_ERROR_MEMORY if an allocation fails, or
 * PACKET_ERROR_SANITY if the counts/lengths do not agree with p->size.
 *
 * NOTE(review): arrays are assigned to p->shorts/longs/strs/bins without
 * freeing any previous contents, and on an error return whatever was
 * allocated so far stays attached to p (presumably released later with
 * free_packet -- confirm against callers).  Pointer fields of sections
 * whose count is zero are not touched at all.
 */
raw_to_packet(p, pdata)
PACKET *p;
char *pdata;
{
  /* pdata must point to a buffer of size at least p->size */
  SHORT i, sz, ts;
  /* left: payload bytes not yet consumed; ts/tl: scratch wire values */
  LONG left = p->size, tl;

  /* Section 1: fixed-width SHORTs in network byte order. */
  if (p->cshort > 0) {
    if ((p->shorts = (SHORT *)calloc(p->cshort, sizeof(SHORT))) == NULL) {
      /* out of memory */
      return PACKET_ERROR_MEMORY;
    }
    if (left < p->cshort*sizeof(SHORT)) {
      /* inconsistent packet */
      return PACKET_ERROR_SANITY;
    }
    for (i=0; i<p->cshort; i++) {
      /* copy via a local instead of casting pdata, so no aligned access
         through pdata is required */
      memcpy(&ts, pdata, sizeof(SHORT));
      p->shorts[i] = ntohs(ts);
      pdata += sizeof(SHORT);
    }
    left -= (p->cshort*sizeof(SHORT));
  }

  /* Section 2: fixed-width LONGs in network byte order. */
  if (p->clong > 0) {
    if ((p->longs = (LONG *)calloc(p->clong, sizeof(LONG))) == NULL) {
      /* out of memory */
      return PACKET_ERROR_MEMORY;
    }
    if (left < p->clong*sizeof(LONG)) {
      /* inconsistent packet */
      return PACKET_ERROR_SANITY;
    }
    for (i=0; i<p->clong; i++) {
      memcpy(&tl, pdata, sizeof(LONG));
      p->longs[i] = ntohl(tl);
      pdata += sizeof(LONG);
    }
    left -= (p->clong*sizeof(LONG));
  }

  /* Section 3: length-prefixed strings. */
  if (p->cstr > 0) {
    if ((p->strs = (char **)calloc(p->cstr, sizeof(char *))) == NULL) {
      /* out of memory */
      return PACKET_ERROR_MEMORY;
    }
    for (i=0; i<p->cstr; i++) {
      if (left < sizeof(SHORT)) {
        /* inconsistent packet */
        return PACKET_ERROR_SANITY;
      }
      memcpy(&ts, pdata, sizeof(SHORT));
      sz = ntohs(ts);
      pdata += sizeof(SHORT);
      left -= sizeof(SHORT);
      /* sz+1 zeroed bytes: the extra byte is the NUL terminator */
      if ((p->strs[i] = (char *)calloc(sz+1, 1)) == NULL) {
        /* out of memory */
        return PACKET_ERROR_MEMORY;
      }
      if (sz > 0) {
        if (left < sz) {
          /* inconsistent packet */
          return PACKET_ERROR_SANITY;
        }
        memcpy(p->strs[i], pdata, sz);
        pdata += sz;
        left -= sz;
      }
    }
  }

  /* Section 4: length-prefixed binary blobs. */
  if (p->cbin > 0) {        
    if ((p->bins = (BINDATA *)calloc(p->cbin, sizeof(BINDATA))) == NULL) {
      /* out of memory */
      return PACKET_ERROR_MEMORY;
    }
    for (i=0; i<p->cbin; i++) {
      if (left < sizeof(SHORT)) {
        /* inconsistent packet */
        return PACKET_ERROR_SANITY;
      }
      memcpy(&ts, pdata, sizeof(SHORT));
      sz = ntohs(ts);
      pdata += sizeof(SHORT);
      left -= sizeof(SHORT);
      p->bins[i].size = sz;
      /* empty blobs carry no buffer at all */
      if (sz == 0) p->bins[i].data = NULL;
      else {
        if ((p->bins[i].data = (void *)calloc(sz, 1)) == NULL) {
          /* out of memory */
          return PACKET_ERROR_MEMORY;
        }
        if (left < sz) {
          /* inconsistent packet */
          return PACKET_ERROR_SANITY;
        }
        memcpy(p->bins[i].data, pdata, sz);
        pdata += sz;
        left -= sz;
      }
    }
  }

  /* trailing bytes mean the header counts and size disagree */
  if (left != 0) {
    /* inconsistent packet */
    return PACKET_ERROR_SANITY;
  }

  return 0;
}

/*
 * packet_to_raw -- encode the contents of *p into the payload buffer pdata.
 *
 * Output order: every short (htons), every long (htonl), then each string
 * and each binary blob as a SHORT length (htons) followed by that many
 * bytes.  Strings are written without their terminating NUL; a NULL
 * string entry is written as length 0.
 *
 * pdata must point to a buffer of at least p->size bytes.  p->size acts as
 * the byte budget: PACKET_ERROR_SANITY is returned if a section would
 * exceed it, or if the encoded payload does not fill it exactly (pdata
 * may already be partially written then).  Returns 0 on success.
 */
int
packet_to_raw(p, pdata)
PACKET *p;
char *pdata;
{
  SHORT idx, len, net16;
  LONG budget = p->size, net32;
  char *out = pdata;

  /* fixed-width sections: bounds-check each whole run up front */
  if (p->cshort > 0) {
    if (budget < p->cshort*sizeof(SHORT))
      return PACKET_ERROR_SANITY;
    for (idx = 0; idx < p->cshort; idx++) {
      net16 = htons(p->shorts[idx]);
      memcpy(out, &net16, sizeof(SHORT));
      out += sizeof(SHORT);
    }
    budget -= (p->cshort*sizeof(SHORT));
  }

  if (p->clong > 0) {
    if (budget < p->clong*sizeof(LONG))
      return PACKET_ERROR_SANITY;
    for (idx = 0; idx < p->clong; idx++) {
      net32 = htonl(p->longs[idx]);
      memcpy(out, &net32, sizeof(LONG));
      out += sizeof(LONG);
    }
    budget -= (p->clong*sizeof(LONG));
  }

  /* length-prefixed sections: check each item before writing it */
  for (idx = 0; idx < p->cstr; idx++) {
    char *s = p->strs[idx];

    len = (s != NULL) ? strlen(s) : 0;
    if (budget < (sizeof(SHORT)+len))
      return PACKET_ERROR_SANITY;
    net16 = htons(len);
    memcpy(out, &net16, sizeof(SHORT));
    out += sizeof(SHORT);
    if (len > 0) {
      memcpy(out, s, len);
      out += len;
    }
    budget -= (sizeof(SHORT)+len);
  }

  for (idx = 0; idx < p->cbin; idx++) {
    len = p->bins[idx].size;
    if (budget < (sizeof(SHORT)+len))
      return PACKET_ERROR_SANITY;
    net16 = htons(len);
    memcpy(out, &net16, sizeof(SHORT));
    out += sizeof(SHORT);
    if (len > 0) {
      memcpy(out, p->bins[idx].data, len);
      out += len;
    }
    budget -= (sizeof(SHORT)+len);
  }

  /* the encoded payload must account for every byte of p->size */
  return (budget == 0) ? 0 : PACKET_ERROR_SANITY;
}

/*
 * recv_all -- receive exactly len bytes from socket sd into buf.
 *
 * Short reads are continued and recv() calls interrupted by a signal
 * (EINTR) are retried.  Returns 0 on success, the errno value of a failed
 * recv(), or PACKET_ERROR_CLOSED if the peer closes the connection first.
 */
static int
recv_all(int sd, char *buf, int len)
{
  int got;

  while (len > 0) {
    got = recv(sd, buf, len, 0);
    if (got == -1) {
      if (errno == EINTR)
        continue;
      return errno;
    }
    if (got == 0)
      return PACKET_ERROR_CLOSED;
    buf += got;
    len -= got;
  }
  return 0;
}

/*
 * read_packet -- receive one complete packet from socket sd into *p.
 *
 * Reads the PACKET_HDR_SIZE-byte header and converts its fields from
 * network byte order into p->magic, opcode, flags, cshort, clong, cstr,
 * cbin and size.  It then reads p->size payload bytes and decodes them
 * with raw_to_packet().
 *
 * Returns 0 on success; the errno value if recv() fails; PACKET_ERROR_CLOSED
 * if the peer closes mid-packet; PACKET_ERROR_SANITY for a bad magic
 * number or a header/payload mismatch; PACKET_ERROR_MEMORY if an
 * allocation fails.
 *
 * NOTE(review): the header is received directly into a PACKET, which
 * assumes the first PACKET_HDR_SIZE bytes of the struct mirror the wire
 * header layout -- confirm against packet.h.
 */
int
read_packet(p, sd)
PACKET *p;
int sd;
{
  int aok;
  char *packdata;
  PACKET tp;

  if ((aok = recv_all(sd, (char *)&tp, PACKET_HDR_SIZE)) != 0)
    return aok;

  p->magic = ntohl(tp.magic);
  p->opcode = ntohs(tp.opcode);
  p->flags = ntohs(tp.flags);
  p->cshort = ntohs(tp.cshort);
  p->clong = ntohs(tp.clong);
  p->cstr = ntohs(tp.cstr);
  p->cbin = ntohs(tp.cbin);
  p->size = ntohs(tp.size);

  if (p->magic != PACKET_MAGIC) {
    /* not a valid packet */
    return PACKET_ERROR_SANITY;
  }

  if (p->size == 0) {
    /*
     * An empty payload is only valid when no items are announced.
     * Otherwise we would return success with non-zero counts but without
     * setting up the arrays those counts describe.
     */
    if (p->cshort != 0 || p->clong != 0 || p->cstr != 0 || p->cbin != 0)
      return PACKET_ERROR_SANITY;
    return 0;
  }

  if ((packdata = (char *)malloc(p->size)) == NULL) {
    /* out of memory */
    return PACKET_ERROR_MEMORY;
  }

  /* recv_all captures errno before we free, so free() cannot clobber it */
  aok = recv_all(sd, packdata, p->size);
  if (aok == 0)
    aok = raw_to_packet(p, packdata);

  free(packdata);
  return aok;
}

write_packet(p, sd)
PACKET *p;
int sd;
{
  int towrite, numwrote, aok;
  char *cp, *packdata;
  PACKET tp;

  tp.magic = htonl(PACKET_MAGIC);
  tp.opcode = htons(p->opcode);
  tp.flags = htons(p->flags);
  tp.cshort = htons(p->cshort);
  tp.clong = htons(p->clong);
  tp.cstr = htons(p->cstr);
  tp.cbin = htons(p->cbin);
  tp.size = htons(p->size);

  towrite = PACKET_HDR_SIZE;
  cp = (char *)&tp;
  while (towrite > 0) {
    numwrote = send(sd, cp, towrite, 0);
    if (numwrote == -1) {
      free(packdata);
      return errno;
    }
    else if (numwrote == 0) {
      free(packdata);
      return PACKET_ERROR_CLOSED;
    }
    cp += numwrote;
    towrite -= numwrote;
  }

  if (p->size == 0) {
    /* empty packet -- we're done */
    return 0;
  }

  if ((packdata = (char *)malloc(p->size)) == NULL) {
    /* out of memory */
    return PACKET_ERROR_MEMORY;
  }

  aok = packet_to_raw(p, packdata);
  if (aok != 0) {
    free(packdata);
    return aok;
  }

  towrite = p->size;
  cp = packdata;
  while (towrite > 0) {
    numwrote = send(sd, cp, towrite, 0);
    if (numwrote == -1) {
      free(packdata);
      return errno;
    }
    else if (numwrote == 0) {
      free(packdata);
      return PACKET_ERROR_CLOSED;
    }
    cp += numwrote;
    towrite -= numwrote;
  }

  free(packdata);
  return 0;
}
