/*
 * Socket operations
 *
 * getLocalAddr - no error reporting
 * getRemoteAddr - no error reporting
 * openClient - reports error
 * openServer - reports error
 */
#ifdef _POSIX_SOURCE
#undef _POSIX_SOURCE	/* see comment below */
#endif
#include <sys/types.h>	/* pid_t */
/*
 * errno.h must be included with _POSIX_SOURCE undefined on BSDI's BSDI/386,
 * otherwise ENOTSOCK is not defined
 */
#include <errno.h>	/* EINTR ENOTSOCK */
/*
 * These symbols are defined in <sys/types.h>, but not when _POSIX_SOURCE is
 * defined
 *
 * Some machines can cause them to be defined by:
 * DEC/OSF1  - #define _OSF_SOURCE
 * HP-UX     - #define _HPUX_SOURCE - must still do this for other declarations
 * SunOS 4.1 - not possible, that's why we put them here.
 */
#if 0
typedef unsigned char	u_char;		/* <netinet/in.h> */
typedef unsigned short	u_short;	/* <sys/socket.h> <netinet/tcp.h> */
typedef unsigned long  	u_long;		/* <netinet/tcp.h> */
typedef unsigned int	u_int;		/* <sys/socket.h> and DEC OSF/1 */
#endif
#define _POSIX_SOURCE
#include <sys/socket.h>	/* struct sockaddr */
			/* connect accept listen bind socket */
			/* getpeername getsockname */
#include <netinet/in.h>
#include <netinet/tcp.h>/* TCP_NODELAY */
#include <netdb.h>	/* gethostbyaddr */
#include <arpa/inet.h>	/* inet_ntoa on mips and hpux */

#include <unistd.h>	/* close waitpid */
#include <stdlib.h>	/* atexit */
#include <stdio.h>
#include <string.h>
#include <signal.h>
#include <sys/wait.h>	/* WNOHANG */
#include "log.h"
#include "socket.h"
#include "deslogin.h"
#include "posignal.h"

/*
 * Report the socket type of the specified file descriptor.
 *
 * Returns:
 *    > 0  the SO_TYPE of fd; socket types in <sys/socket.h> are always > 0:
 *         SOCK_STREAM SOCK_DGRAM SOCK_RAW SOCK_RDM SOCK_SEQPACKET
 *      0  fd is not a socket (getsockopt failed with ENOTSOCK)
 *     -1  getsockopt failed for any other reason (errno left as it set it)
 */
int socktype(fd)
   int fd;
{
   int sotype = -1;
   socklen_t size = sizeof sotype;	/* getsockopt wants socklen_t *, not int * */

   if (getsockopt(fd, SOL_SOCKET, SO_TYPE, &sotype, &size) >= 0) {
      return sotype;
   }
   return (errno == ENOTSOCK) ? 0 : -1;
}

/*
 * Enable or disable TCP transmit delays on a socket by clearing or setting
 * the TCP_NODELAY option.
 *
 * If flag is non-zero, delays are enabled (TCP_NODELAY is set to 0);
 * if flag is zero, delays are disabled (TCP_NODELAY is set to 1).
 *
 * Returns the state that was in effect before the call:
 *     1 delay was previously enabled
 *     0 delay was previously disabled
 *    -1 if getprotobyname for "tcp" failed
 *    -2 if getsockopt TCP_NODELAY failed
 *    -3 if setsockopt TCP_NODELAY failed
 */
int tcpDelay(sock, flag)
   int sock, flag;
{
   struct protoent *pent;
   int nodelay = !flag;			/* NODELAY is the opposite sense */
   int oldflag = -1, newflag = -1;
   /* getsockopt takes socklen_t * for the length, not int * */
   socklen_t oldlen = sizeof oldflag, newlen = sizeof newflag;
   int tcplevel;

   pent = getprotobyname("tcp");	/* opens file /etc/protocols */
   if (pent == (struct protoent *) 0) {
      return -1;
   }
   tcplevel = pent->p_proto;		/* protocol number for tcp */

   if (getsockopt(sock, tcplevel, TCP_NODELAY, &oldflag, &oldlen) == 0) {
      oldflag = !oldflag;		/* non-zero NODELAY means delay was off */
   } else if (errno == ENOPROTOOPT) {	/* never happens on hpux */
      oldflag = 1;			/* NODELAY was disabled */
   } else {
      return -2;
   }

   if (setsockopt(sock, tcplevel, TCP_NODELAY, &nodelay, sizeof nodelay) != 0) {
      return -3;
   }
   /*
    * Read the option back so the new state can be inspected in a debugger;
    * the result is deliberately not used.
    */
   (void) getsockopt(sock, tcplevel, TCP_NODELAY, &newflag, &newlen);
   return oldflag;
}

/*
 * Look up the host name for the specified network address (from accept for
 * example) using gethostbyaddr.
 *
 * If name is non-null, up to namelen-1 bytes of the host name are copied
 * into it and the result is always NUL-terminated (when namelen > 0).
 *
 * Returns name on success, (char *) 0 if the lookup failed.  Note that when
 * name is null the return value is null even if the lookup succeeded.
 */
char *hostAddrToName(addr, name, namelen)
   struct sockaddr_in *addr;	/* address from accept(2) */
   char *name;			/* where to put hostname */
   unsigned namelen;		/* how many bytes are there for it */
{
   struct hostent *hp;

   hp = gethostbyaddr(
      (char *) &(addr->sin_addr), sizeof addr->sin_addr, addr->sin_family);
   if (hp == (struct hostent *) 0) {
      return (char *) 0;
   }
   if (name != (char *) 0 && namelen > 0) {
      /*
       * strncpy does not terminate the destination when the source is too
       * long, so copy one byte less and terminate explicitly.
       */
      strncpy(name, hp->h_name, namelen - 1);
      name[namelen - 1] = '\0';
   }
   return name;
}

/*
 * Internal: convert address to string and port
 *
 * hostName (if non-null) receives the host name found by hostAddrToName,
 * or the dotted-decimal internet address from inet_ntoa if that fails.
 * At most size bytes are written and the result is always NUL-terminated
 * (when size > 0).  port (if non-null) receives the port in host byte order.
 *
 * Returns:
 *    -1 - if hostname not found (string contains internet address);
 *         this is also the result whenever hostName is null, because
 *         hostAddrToName returns its name argument
 *     0 - if hostname found and copied
 */
int mapAddr(iaddr, hostName, size, port)
   struct sockaddr_in *iaddr;
   char *hostName;
   unsigned size;
   unsigned *port;
{
   int res = 0;

   if (hostAddrToName(iaddr, hostName, size) == (char *) 0) {
      if (hostName != (char *) 0) {
         strncpy(hostName, inet_ntoa(iaddr->sin_addr), size);
      }
      res = -1;
   }
   /*
    * strncpy leaves the buffer unterminated when the source does not fit;
    * force termination whichever path filled it in.
    */
   if (hostName != (char *) 0 && size > 0) {
      hostName[size - 1] = '\0';
   }
   if (port != (unsigned *) 0) {
      *port = (unsigned) ntohs(iaddr->sin_port);
   }
   return res;
}

/*
 * Find the address for the specified host and port (AF_INET only).
 *
 * Input:  addr, size    location and size of where to put the result
 *         name          host name handed to gethostbyname
 *         port          which port number to use (host byte order)
 *
 * Output: addr - zeroed, then filled in as a struct sockaddr_in
 * Returns: The size of the area used by addr (sizeof (struct sockaddr_in)),
 *          or -1 if addr or name is null, the area is too small to hold a
 *          struct sockaddr_in, the lookup failed, or the host's address is
 *          not an AF_INET address.  addr is not touched on failure.
 */
int hostNameToAddr(addr, size, name, port)
   struct sockaddr *addr;
   unsigned size;
   char *name;
   int port;
{
   struct sockaddr_in *iaddr = (struct sockaddr_in *) addr;
   struct hostent *hp;

   /* Refuse to write outside the caller's buffer or to look up nothing. */
   if (addr == (struct sockaddr *) 0 || name == (char *) 0
    || size < sizeof (struct sockaddr_in)) {
      return -1;
   }

   hp = gethostbyname(name);
   if (hp == (struct hostent *) 0 || hp->h_addrtype != AF_INET) {
      return -1;
   }
   /* The address must fit in sin_addr; never report success without it. */
   if (hp->h_length <= 0
    || (unsigned) hp->h_length > sizeof iaddr->sin_addr) {
      return -1;
   }

   memset(addr, '\0', size);
   iaddr->sin_family = AF_INET;
   memcpy(&iaddr->sin_addr, hp->h_addr_list[0], hp->h_length);
   iaddr->sin_port = htons((unsigned short) port);
   return sizeof (struct sockaddr_in);
}

/*
 * Try to find the local address of the specified socket (getsockname).
 *
 * hostName/size and port are filled in by mapAddr; either pointer may be
 * null.  mapAddr's own result is ignored: when no host name is found it
 * stores the dotted-decimal address instead, which still counts as success
 * here.
 *
 * Returns:
 *     0        - success
 *     ENOTSOCK - if not a socket
 *     errno    - if other error
 */
int getLocalAddr(sock, hostName, size, port)
   int sock;
   char *hostName;
   unsigned size;
   unsigned *port;
{
   struct sockaddr_in iaddr;
   socklen_t addrlen = sizeof iaddr;	/* getsockname wants socklen_t * */

   if (getsockname(sock, (struct sockaddr *) &iaddr, &addrlen) < 0) {
      return errno;
   }

   (void) mapAddr(&iaddr, hostName, size, port);
   return 0;
}

/*
 * Try to find the remote address of the specified socket (getpeername).
 *
 * hostName/size and port are filled in by mapAddr; either pointer may be
 * null.  mapAddr's own result is ignored: when no host name is found it
 * stores the dotted-decimal address instead, which still counts as success
 * here.
 *
 * Returns:
 *     0        - success
 *     ENOTSOCK - if not a socket
 *     errno    - if other error
 */
int getRemoteAddr(sock, hostName, size, port)
   int sock;
   char *hostName;
   unsigned size;
   unsigned *port;
{
   struct sockaddr_in iaddr;
   socklen_t addrlen = sizeof iaddr;	/* getpeername wants socklen_t * */

   if (getpeername(sock, (struct sockaddr *) &iaddr, &addrlen) < 0) {
      return errno;
   }

   (void) mapAddr(&iaddr, hostName, size, port);
   return 0;
}

/*
 * Look up a service by name (any protocol) with getservbyname.
 *
 * Returns the s_port field of the matching entry exactly as stored, or -1
 * if there is no such service.
 *
 * NOTE(review): getservbyname is documented to store s_port in network
 * byte order and no conversion is done here -- confirm that callers
 * expect that.
 */
int getServicePort(name)
   char *name;
{
   struct servent *entry = getservbyname(name, (char *) 0);

   if (entry == (struct servent *) 0) {
      return -1;
   }
   return entry->s_port;
}

/*
 * SIGCHLD handler installed by openServer.  It does no work of its own
 * apart from an optional debug log entry; sig is not used.
 */
void sockHandler(sig)
   int sig;
{
   if (!debug) {
      return;
   }
   log("%s(openServer): SIGCHLD\n", progName);
}

/*
 * Listen for a TCP connection on the specified port.  Return socket to it.
 * If hostName is non-null, return up to size bytes of remote host's name.
 * If rport is non-null, set it to the remote port.
 *
 * If serverflag is non-zero, a child is forked for every accepted
 * connection: the child returns from this routine with the connected
 * socket while the parent closes its copy and goes back to accept.
 * If serverflag is zero, the first accepted connection is returned.
 *
 * SIGCHLD is caught (via posignal) while waiting so that accept is
 * interrupted and dead children can be reaped; the previous handler is
 * restored before returning.
 *
 * Reports Error if failure and returns -1; otherwise the socket.
 * A failed accept (other than EINTR) or a failed fork is reported and
 * terminates the process with exit(1).
 */
int openServer(port, hostName, size, rport, serverflag)
   int port;
   char *hostName;
   unsigned size;
   unsigned *rport;
   int serverflag;			/* 1 if server, 0 if only once */
{
   int lsock, csock = -1, res;
   socklen_t caddrLen;			/* accept wants socklen_t *, not int * */
   pid_t pid;
   struct sockaddr_in laddr, caddr;
   pid_t child = 0;
   pfv   oldHandler;

   lsock = socket(AF_INET, SOCK_STREAM, 0);
   if (lsock < 0) {
      log("%s: stream socket create failed--%s\n",
         progName, ERRMSG);
      return -1;
   }
   memset(&laddr, '\0', sizeof laddr);
   laddr.sin_family = AF_INET;
   laddr.sin_addr.s_addr = htonl(INADDR_ANY);
   laddr.sin_port = htons(port);

   /*
    * On every failure below, report first (the message is built while the
    * failing call's errno is still current) and then release the
    * listening socket so it is not leaked.
    */
   res = bind(lsock, (struct sockaddr *) &laddr, sizeof laddr);
   if (res < 0) {
      log("%s: bind to port %d failed --%s\n", progName, port, ERRMSG);
      close(lsock);
      return -1;
   }
   res = listen(lsock, 2);
   if (res < 0) {
      log("%s: listen failed --%s\n", progName, ERRMSG);
      close(lsock);
      return -1;
   }

   oldHandler = posignal(SIGCHLD, sockHandler);
   if (oldHandler == (pfv) -1) {
      log("%s(openServer): sigaction SIGCHLD failed--%s\n",
         progName, ERRMSG);
      close(lsock);
      return -1;
   }

   do {
      if (csock >= 0) {
         close(csock);			/* parent: the child owns this one */
      }
      /*
       * Block until we receive a connect request.
       */
      caddrLen = sizeof caddr;
      memset(&caddr, '\0', sizeof caddr);
      csock = accept(lsock, (struct sockaddr *) &caddr, &caddrLen);
      if (csock < 0) {
         if (errno != EINTR) {
            log("%s: accept call failed--%s\n", progName, ERRMSG);
            exit(1);
         }
         /* SIGCHLD: Reap all children so we don't leave zombie processes. */
         do {
            pid = waitpid(-1, (int *) 0, WNOHANG);
         } while (pid > 0);
      } else if (serverflag) {
         child = fork();
      }
      /*
       * Keep going while we have no connection (accept was interrupted)
       * or while we are the parent of a freshly forked server child.
       * Without the csock test, an interrupted accept before the first
       * fork fell out of the loop with csock == -1.
       */
   } while (csock < 0 || (serverflag && child > 0));

   posignal(SIGCHLD, oldHandler);
   close(lsock);

   if (debug) {
      log("%s: accept from %s:%u\n",
         progName, inet_ntoa(caddr.sin_addr), ntohs(caddr.sin_port));
   }
   if (child < 0) {
      log("%s: openServer: fork of child failed\n", progName);
      exit(1);
   }

   /*
    * Don't delay TCP transmissions.  tcpDelay enables delays for a
    * non-zero flag, so pass 0 to turn them off.
    */
   res = tcpDelay(csock, 0);
   if (res < 0) {
      log("%s: tcpDelay returned (%d) --%s\n", progName, res, ERRMSG);
      /* OK to continue */
   }
   /* A failed name lookup just leaves the numeric address in hostName. */
   (void) mapAddr(&caddr, hostName, size, rport);
   return csock;
}

/*
 * Establish a TCP connection to the remote host and port.
 *
 * hostName is resolved with hostNameToAddr; port is in host byte order.
 *
 * Returns the connected socket, or -1 for failure (error reported).
 * No socket is left open on failure.
 */
int openClient(hostName, port)
   char *hostName;
   int port;
{
   int sock, res;
   struct sockaddr_in addr;		/* hostNameToAddr fills in a sockaddr_in */
   int addrLen;

   addrLen = hostNameToAddr(
      (struct sockaddr *) &addr, sizeof addr, hostName, port);
   if (addrLen < 0) {
      /* The format takes two strings only; no errno text is passed. */
      log("%s: can't find address for host \"%s\"\n",
         progName, hostName);
      return -1;
   }

   sock = socket(addr.sin_family, SOCK_STREAM, 0);
   if (sock < 0) {
      log("%s: stream socket create failed--%s\n",
         progName, ERRMSG);
      return -1;
   }
   res = connect(sock, (struct sockaddr *) &addr, addrLen);
   if (res < 0) {
      /* Report while connect's errno is current, then release the socket. */
      log("%s: connect to %s:%d failed--%s\n",
         progName, hostName, port, ERRMSG);
      close(sock);
      return -1;
   }
#if 0		/* BSD removed this and replaced with struct linger. */
   res = setsockopt(sock, SOL_SOCKET, SO_DONTLINGER, (char *) 0, 0);
#endif
   /*
    * Don't delay TCP transmissions.  tcpDelay enables delays for a
    * non-zero flag, so pass 0 to turn them off.
    */
   res = tcpDelay(sock, 0);
   if (res < 0) {
      log("%s: tcpDelay returned (%d) --%s\n", progName, res, ERRMSG);
      /* OK to continue */
   }
   return sock;
}
