/*
 *  sock.c
 *
 *  Copyright (C) 1995 by Paal-Kr. Engstad and Volker Lendecke
 *	Modified by Christian Starkjohann <cs@hal.kph.tuwien.ac.at>
 *
 */

#include "syshdr.h"
#include <stdint.h>
#include <smb/smb_fs.h>

#include <smb/smb.h>
#include <smb/smbno.h>
#include "my_defines.h"

extern int	errno;

/*
 * close_fp
 * Close the descriptor carried in a `struct file *` handle.
 * server_sock() fills server->sock_file by casting the integer socket
 * descriptor to `struct file *`; the "pointer" is therefore just an fd.
 * It is converted back through intptr_t because a direct pointer->int cast
 * between types of different width is implementation-defined and truncates
 * (with a warning) on LP64 targets.
 * Returns the result of close(2): 0 on success, -1 with errno set.
 */
int close_fp(struct file *filp)
{
	return close((int)(intptr_t)filp);
}

/*
 * _recvfrom
 * Put sock_fd into blocking (noblock == 0) or non-blocking (noblock != 0)
 * mode, then issue a single recvfrom(2) of at most `size` bytes into ubuf.
 * `flags`, `sa` and `addr_len` are handed to recvfrom() unchanged; sa and
 * addr_len may be NULL.
 * Returns the byte count from recvfrom() (0 means the peer shut the stream
 * down) or a negative errno value if fcntl(F_GETFL) or recvfrom() fails.
 */
static int
_recvfrom(int sock_fd, unsigned char *ubuf, int size,
	  int noblock, unsigned flags, struct sockaddr_in *sa, int *addr_len)
{
	int	mode = fcntl(sock_fd, F_GETFL, 0);
	int	got;

	if (mode == -1)
		return -errno;

	mode = noblock ? (mode | O_NDELAY) : (mode & ~O_NDELAY);
	/* best effort: a failing F_SETFL is not reported */
	fcntl(sock_fd, F_SETFL, mode);

	got = recvfrom(sock_fd, ubuf, size, flags, (struct sockaddr *)sa,
					addr_len);
	return (got < 0) ? -errno : got;
}

/*
 * _send
 * Put sock_fd into blocking (noblock == 0) or non-blocking mode and send
 * `len` bytes from buff with send(2) `flags`.
 *
 * send(2) on a stream socket may transfer fewer bytes than requested, e.g.
 * when a signal arrives after some data went out or when the socket is
 * non-blocking. Every caller in this file treats one _send() as one
 * complete SMB/NetBIOS message, so keep sending until all bytes are out.
 *
 * Returns len on success; the number of bytes already sent if an error
 * occurs after a partial transfer; or a negative errno value if
 * fcntl(F_GETFL) fails or nothing could be sent.
 */
static int
_send(int sock_fd, const void *buff, int len,
      int noblock, unsigned flags)
{
	const unsigned char	*p = buff;
	int					sflags, rval;
	int					sent = 0;

	if ((sflags = fcntl(sock_fd, F_GETFL, 0)) == -1)
		return -errno;
	if (noblock)
		sflags |= O_NDELAY;
	else
		sflags &= ~O_NDELAY;
	/* best effort: a failing F_SETFL is not reported (as before) */
	fcntl(sock_fd, F_SETFL, sflags);

	while (sent < len) {
		rval = send(sock_fd, p + sent, len - sent, flags);
		if (rval < 0) {
			/* report progress if any bytes are already on the wire */
			return (sent > 0) ? sent : -errno;
		}
		if (rval == 0)		/* cannot make progress; avoid spinning */
			break;
		sent += rval;
	}
	return sent;
}

/*
 * smb_catch_keepalive
 * Stub in this implementation: there is nothing to arm, and it always
 * reports success (0). Keep-alive packets (type 0x85) that arrive during a
 * receive are skipped by smb_receive_raw().
 */
int
smb_catch_keepalive(struct smb_server *server)
{
	(void)server;	/* unused */
	return 0;
}
                
/*
 * smb_dont_catch_keepalive
 * Counterpart of smb_catch_keepalive(); also a stub that does nothing and
 * always reports success (0).
 */
int
smb_dont_catch_keepalive(struct smb_server *server)
{
	(void)server;	/* unused */
	return 0;
}

/*
 * smb_recv_exact
 * Read exactly `len` bytes from sock_fd into buf. TCP may deliver a message
 * in several pieces, so short reads are retried.
 * Returns len on success, the negative errno value from _recvfrom() on
 * error, or -EIO if the peer closes the connection before `len` bytes have
 * arrived. Without the EOF check the loop would spin forever, because
 * recvfrom() keeps returning 0 after an orderly shutdown.
 */
static int
smb_recv_exact(int sock_fd, unsigned char *buf, int len)
{
	int	got = 0;

	while (got < len) {
		int	n = _recvfrom(sock_fd, buf + got, len - got, 0, 0, NULL, NULL);

		if (n < 0)
			return n;
		if (n == 0)			/* peer closed the connection */
			return -EIO;
		got += n;
	}
	return got;
}

/*
 * smb_receive_raw
 * fs points to the correct segment, sock != NULL, target != NULL
 * (segment note inherited from the kernel smbfs).
 * The smb header is only stored if want_header != 0.
 *
 * Receive one NetBIOS session message from sock_fd:
 *  - keep-alives (type 0x85) are consumed and skipped;
 *  - session messages (0x00) and positive session responses (0x82) are
 *    accepted;
 *  - any other type yields -EIO.
 * The payload length comes from the 4-byte NB header. A payload longer than
 * max_raw_length is rejected with -EIO. The payload is stored at target,
 * preceded by the 4 header bytes when want_header != 0.
 * Returns the payload length (NB header excluded) or a negative errno
 * value; -EIO also covers the peer closing the connection mid-message.
 */
static int
smb_receive_raw(int sock_fd, unsigned char *target,
				int max_raw_length, int want_header)
{
	unsigned char	peek_buf[4];
	int				len, result;

	for (;;) {
		/* the header itself may also arrive in pieces */
		result = smb_recv_exact(sock_fd, peek_buf, 4);
		if (result < 0) {
			DPRINTK("smb_receive_raw: recv error = %d\n", -result);
			return result;
		}
		if (peek_buf[0] != 0x85)
			break;
		DPRINTK("smb_receive_raw: Got SESSION KEEP ALIVE\n");
	}

	switch (peek_buf[0]) {
	case 0x00:
	case 0x82:
		break;
	default:
		printk("smb_receive_raw: Invalid packet 0x%02x\n", peek_buf[0]);
		return -EIO;
	}
	/* The length in the RFC NB header is the raw data length */
	len = smb_len(peek_buf);
	if (len > max_raw_length) {
		printk("smb_receive_raw: Received length (%d) > max_xmit (%d)!\n",
		len, max_raw_length);
		return -EIO;
	}

	if (want_header != 0) {
		memcpy(target, peek_buf, 4);
		target += 4;
	}
	result = smb_recv_exact(sock_fd, target, len);
	if (result < 0) {
		printk("smb_receive_raw: recvfrom error = %d\n", -result);
		return result;
	}
	return result;
}

/*
 * smb_receive
 * fs points to the correct segment, server != NULL, sock!=NULL
 */
static int
smb_receive(struct smb_server *server, int sock_fd)
{
int result;

	result = smb_receive_raw(sock_fd, server->packet,
								server->max_recv - 4, /* max_xmit in server
														includes NB header */
								1); /* We want the header */
	if (result < 0) {
		printk("smb_receive: receive error: %d\n", result);
		return result;
	}
	server->rcls = *((unsigned char *)(server->packet+9));
	server->err = WVAL(server->packet, 11);
	if (server->rcls != 0) {
		DPRINTK("smb_receive: rcls=%d, err=%d\n", server->rcls, server->err);
	}
	return result;
}


/*
 * smb_receive's preconditions also apply here.
 */
static int
smb_receive_trans2(struct smb_server *server, int sock_fd,
                   int *data_len, int *param_len, char **data, char **param)
{
int				total_data = 0;
int				total_param = 0;
int				result;
unsigned char	*inbuf = server->packet;

	DDPRINTK("smb_receive_trans2: enter\n");
	*data_len = *param_len = 0;
	if ((result = smb_receive(server, sock_fd)) < 0) {
		return result;
	}
	if (server->rcls != 0) {
		return result;
	}
	/* parse out the lengths */
	total_data = WVAL(inbuf, smb_tdrcnt);
	total_param = WVAL(inbuf, smb_tprcnt);

	if ((total_data  > TRANS2_MAX_TRANSFER)
			|| (total_param > TRANS2_MAX_TRANSFER)) {
		printk("smb_receive_trans2: data/param too long\n");
		return -EIO;
	}
	/* allocate it */
	if ((*data  = malloc(total_data)) == NULL) {
		printk("smb_receive_trans2: could not alloc data area\n");
		return -ENOMEM;
	}
	if ((*param = malloc(total_param)) == NULL) {
		printk("smb_receive_trans2: could not alloc param area\n");
		free(*data);
		return -ENOMEM;
	}
	DDPRINTK("smb_rec_trans2: total_data/param: %d/%d\n",
				total_data, total_param);
	while (1)
	{
		if (WVAL(inbuf,smb_prdisp)+WVAL(inbuf, smb_prcnt) > total_param) {
			printk("smb_receive_trans2: invalid parameters\n");
			result = -EIO;
			goto fail;
		}
		memcpy(*param + WVAL(inbuf,smb_prdisp),
				smb_base(inbuf) + WVAL(inbuf,smb_proff),WVAL(inbuf,smb_prcnt));
		*param_len += WVAL(inbuf,smb_prcnt);

		if (WVAL(inbuf,smb_drdisp)+WVAL(inbuf, smb_drcnt)>total_data) {
			printk("smb_receive_trans2: invalid data block\n");
			result = -EIO;
			goto fail;
		}
		memcpy(*data + WVAL(inbuf,smb_drdisp),
				smb_base(inbuf) + WVAL(inbuf,smb_droff),WVAL(inbuf,smb_drcnt));
		*data_len += WVAL(inbuf,smb_drcnt);
		DDPRINTK("smb_rec_trans2: drcnt/prcnt: %d/%d\n",
					WVAL(inbuf, smb_drcnt), WVAL(inbuf, smb_prcnt));

		/* parse out the total lengths again - they can shrink! */
		if ((WVAL(inbuf,smb_tdrcnt) > total_data)
				|| (WVAL(inbuf,smb_tprcnt) > total_param)) {
			printk("smb_receive_trans2: data/params grew!\n");
			result = -EIO;
			goto fail;
		}
		total_data = WVAL(inbuf,smb_tdrcnt);
		total_param = WVAL(inbuf,smb_tprcnt);
		if (total_data <= *data_len && total_param <= *param_len)
				break;

		if ((result = smb_receive(server, sock_fd)) < 0) {
			goto fail;
		}
		if (server->rcls != 0) {
			result = -EIO;
			goto fail;
		}
	}
	DDPRINTK("smb_receive_trans2: normal exit\n");
	return 0;
fail:
	DPRINTK("smb_receive_trans2: failed exit\n");
	free(*param); *param = NULL;
	free(*data);  *data = NULL;
	return result;
}

/*
 * server_sock
 * Return the socket descriptor stored in server->m.fd and mirror it into
 * server->sock_file.
 * sock_file has type `struct file *` but here only carries the integer
 * descriptor; close_fp() converts it back. Presumably the type is kept for
 * compatibility with the kernel smbfs structures (confirm).
 * The conversion goes through intptr_t, because a direct int->pointer cast
 * between types of different width is implementation-defined and warns on
 * LP64 targets.
 */
static inline int
server_sock(struct smb_server *server)
{
	int	sock_fd = server->m.fd;

	server->sock_file = (struct file *)(intptr_t)sock_fd;
	return sock_fd;
}

/*
 * smb_release
 * Close the current socket (if server->m.fd holds a valid descriptor).
 * Then store a fresh, unconnected TCP socket in server->m.fd.
 * Returns 0 on success, or -errno if socket(2) fails. In that case
 * server->m.fd holds the negative socket() result.
 */
int
smb_release(struct smb_server *server)
{
	int	old_fd = server_sock(server);
	int	new_fd;

	if (old_fd >= 0)
		close(old_fd);

	new_fd = socket(AF_INET, SOCK_STREAM, 0);
	server->m.fd = new_fd;
	return (new_fd < 0) ? -errno : 0;
}

/*
 * smb_connect
 * Connect server->m.fd to the IPv4 address in server->m.addr.
 * Returns 0 on success, -EINVAL if there is no socket, or -errno if
 * connect(2) fails. This matches the negative-errno convention used by
 * smb_release() and the request functions; the raw -1 returned before
 * discarded the reason for the failure.
 */
int
smb_connect(struct smb_server *server)
{
	int	sock_fd = server_sock(server);

	if (sock_fd < 0)
		return -EINVAL;
	if (connect(sock_fd, (struct sockaddr *)&(server->m.addr),
				sizeof(struct sockaddr_in)) < 0)
		return -errno;
	return 0;
}
        
/*****************************************************************************/
/*                                                                           */
/*  This routine was once taken from nfs, which is for udp. Here TCP does    */
/*  most of the ugly stuff for us (thanks, Alan!)                            */
/*                                                                           */
/*****************************************************************************/
/*
 * smb_request
 * Send the NetBIOS message prepared in server->packet (its length is taken
 * from the NB header) and receive the reply into the same buffer via
 * smb_receive().
 * Returns smb_receive()'s result (the reply payload length) or a negative
 * errno value:
 *  - -EBADF if server/socket/buffer is missing;
 *  - -EIO if the connection is not CONN_VALID or the send is short.
 * A negative result after the connection checks marks the connection
 * CONN_INVALID and invalidates all inodes.
 */
int
smb_request(struct smb_server *server)
{
	int				len, result, result2;
	int				sock_fd;
	unsigned char	*buffer;

	/* server_sock() dereferences server, so test for NULL first */
	if (server == NULL) {
		printk("smb_request: Bad server!\n");
		return -EBADF;
	}
	sock_fd = server_sock(server);
	buffer = server->packet;
	if ((sock_fd < 0) || (buffer == NULL)) {
		printk("smb_request: Bad server!\n");
		return -EBADF;
	}
	if (server->state != CONN_VALID)
		return -EIO;
	if ((result = smb_dont_catch_keepalive(server)) != 0) {
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
		return result;
	}
	len = smb_len(buffer) + 4;
	DDPRINTK("smb_request: len = %d cmd = 0x%X\n", len, buffer[8]);
#ifdef NETBIOS_TRACE
	{
		int	i;

		/* Print byte by byte. The former fixed 2048-byte staging buffer
		 * overflowed for packets longer than ~680 bytes (3 chars/byte). */
		fprintf(stderr, "<- Tx:");
		for (i = 0; i < len; i++)
			fprintf(stderr, " %02x", buffer[i]);
		fprintf(stderr, "\n");
	}
#endif
	result = _send(sock_fd, (void *)buffer, len, 0, 0);
	if (result < 0) {
		printk("smb_request: send error = %d\n", result);
	} else if (result != len) {
		/* a truncated request leaves the session out of sync */
		printk("smb_request: short send (%d of %d bytes)\n", result, len);
		result = -EIO;
	} else {
		result = smb_receive(server, sock_fd);
	}
	if ((result2 = smb_catch_keepalive(server)) < 0) {
		result = result2;
	}
	if (result < 0) {
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
	}
	DDPRINTK("smb_request: result = %d\n", result);
	return result;
}

/*
 * This is not really a trans2 request, we assume that you only have
 * one packet to send.
 *
 * smb_trans2_request
 * Send the request prepared in server->packet and collect the (possibly
 * multi-packet) TRANSACTION2 reply with smb_receive_trans2(). That function
 * fills *data / *param and their lengths.
 * Returns smb_receive_trans2()'s result or a negative errno value:
 *  - -EBADF if server/socket/buffer is missing;
 *  - -EIO if the connection is not CONN_VALID or the send is short.
 * A negative result after the connection checks marks the connection
 * CONN_INVALID and invalidates all inodes.
 */
int
smb_trans2_request(struct smb_server *server,
                   int *data_len, int *param_len,
                   char **data, char **param)
{
	int				len, result, result2;
	int				sock_fd;
	unsigned char	*buffer;

	/* server_sock() dereferences server, so test for NULL first */
	if (server == NULL) {
		printk("smb_trans2_request: Bad server!\n");
		return -EBADF;
	}
	sock_fd = server_sock(server);
	buffer = server->packet;
	if ((sock_fd < 0) || (buffer == NULL)) {
		printk("smb_trans2_request: Bad server!\n");
		return -EBADF;
	}
	if (server->state != CONN_VALID)
		return -EIO;
	if ((result = smb_dont_catch_keepalive(server)) != 0) {
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
		return result;
	}
	len = smb_len(buffer) + 4;

	DDPRINTK("smb_trans2_request: len = %d cmd = 0x%X\n", len, buffer[8]);
	result = _send(sock_fd, (void *)buffer, len, 0, 0);
	if (result < 0) {
		printk("smb_trans2_request: send error = %d\n", result);
	} else if (result != len) {
		/* a truncated request leaves the session out of sync */
		printk("smb_trans2_request: short send (%d of %d bytes)\n",
				result, len);
		result = -EIO;
	} else {
		result = smb_receive_trans2(server, sock_fd, data_len, param_len,
													data, param);
	}
	if ((result2 = smb_catch_keepalive(server)) < 0) {
		result = result2;
	}
	if (result < 0) {
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
	}
	DDPRINTK("smb_trans2_request: result = %d\n", result);
	return result;
}

/*
 * smb_request_read_raw
 * target must be in user space (note inherited from the kernel smbfs).
 *
 * Send the request prepared in server->packet (presumably a raw read
 * request; confirm with callers). Then receive the raw reply payload,
 * without its NB header, into target, accepting at most max_len bytes.
 * Returns the number of bytes stored at target or a negative errno value:
 *  - -EBADF if server/socket/buffer is missing;
 *  - -EIO if the connection is not CONN_VALID, the send is short, or the
 *    reply is invalid.
 * A negative result after the connection checks marks the connection
 * CONN_INVALID and invalidates all inodes.
 */
int
smb_request_read_raw(struct smb_server *server,
                     unsigned char *target, int max_len)
{
	int				len, result, result2;
	int				sock_fd;
	unsigned char	*buffer;

	/* server_sock() dereferences server, so test for NULL first */
	if (server == NULL) {
		printk("smb_request_read_raw: Bad server!\n");
		return -EBADF;
	}
	sock_fd = server_sock(server);
	buffer = server->packet;
	if ((sock_fd < 0) || (buffer == NULL)) {
		printk("smb_request_read_raw: Bad server!\n");
		return -EBADF;
	}
	if (server->state != CONN_VALID)
		return -EIO;
	if ((result = smb_dont_catch_keepalive(server)) != 0) {
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
		return result;
	}
	len = smb_len(buffer) + 4;

	DPRINTK("smb_request_read_raw: len = %d cmd = 0x%X\n",
			len, buffer[8]);
	/* %p: casting pointers to unsigned int truncates them on LP64 */
	DPRINTK("smb_request_read_raw: target=%p, max_len=%d\n",
			(void *)target, max_len);
	DPRINTK("smb_request_read_raw: buffer=%p, sock=%X\n",
			(void *)buffer, (unsigned int)sock_fd);

	result = _send(sock_fd, (void *)buffer, len, 0, 0);

	DPRINTK("smb_request_read_raw: send returned %d\n", result);
	if (result < 0) {
		printk("smb_request_read_raw: send error = %d\n", result);
	} else if (result != len) {
		/* a truncated request leaves the session out of sync */
		printk("smb_request_read_raw: short send (%d of %d bytes)\n",
				result, len);
		result = -EIO;
	} else {
		result = smb_receive_raw(sock_fd, target, max_len, 0);
	}
	if ((result2 = smb_catch_keepalive(server)) < 0) {
		result = result2;
	}

	if (result < 0) {
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
	}
	DPRINTK("smb_request_read_raw: result = %d\n", result);
	return result;
}

/* Source must be in user space. smb_request_write_raw assumes that
 * the request SMBwriteBraw has been completed successfully, so that
 * we can send the raw data now.  */
/*
 * smb_request_write_raw
 * Frame `length` bytes from source with their own 4-byte NetBIOS header,
 * send both, then wait for the server's reply with smb_receive().
 * Returns:
 *  - `length` when the reply was received with a positive length;
 *  - 0 if smb_receive() returned 0;
 *  - a negative errno value on failure. A failed or short send becomes
 *    -EIO; a missing server/socket/buffer gives -EBADF.
 * A negative result after the connection checks marks the connection
 * CONN_INVALID and invalidates all inodes.
 */
int
smb_request_write_raw(struct smb_server *server,
                      unsigned const char *source, int length)
{
	int				result, result2;
	byte			nb_header[4];
	int				sock_fd;
	unsigned char	*buffer;

	/* server_sock() dereferences server, so test for NULL first */
	if (server == NULL) {
		printk("smb_request_write_raw: Bad server!\n");
		return -EBADF;
	}
	sock_fd = server_sock(server);
	buffer = server->packet;
	if ((sock_fd < 0) || (buffer == NULL)) {
		printk("smb_request_write_raw: Bad server!\n");
		return -EBADF;
	}

	if (server->state != CONN_VALID)
		return -EIO;

	if ((result = smb_dont_catch_keepalive(server)) != 0) {
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
		return result;
	}
	smb_encode_smb_length(nb_header, length);
	result = _send(sock_fd, (void *)nb_header, 4, 0, 0);
	if (result == 4) {
		result = _send(sock_fd, (void *)source, length, 0, 0);
	} else {
		result = -EIO;
	}
	DPRINTK("smb_request_write_raw: send returned %d\n", result);
	if (result == length) {
		result = smb_receive(server, sock_fd);
	} else {
		/* header or payload not fully sent */
		result = -EIO;
	}
	if ((result2 = smb_catch_keepalive(server)) < 0) {
		result = result2;
	}
	if (result < 0) {
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
	}
	if (result > 0) {
		result = length;
	}
	DPRINTK("smb_request_write_raw: result = %d\n", result);
	return result;
}
