/* -*- pftp-c -*- */
#if HAVE_CONFIG_H
# include "config.h"
#endif
#if HAVE_STDLIB_H
# include <stdlib.h>
#endif
#if HAVE_STDIO_H
# include <stdio.h>
#endif
#if HAVE_INTTYPES_H
# include <inttypes.h>
#endif
#if HAVE_STRING_H
# include <string.h>
#endif
#if HAVE_TIME_H
# include <time.h>
#endif
#if HAVE_NETINET_IN_H
# include <netinet/in.h>
#endif

#ifndef NO_SSL

#include <openssl/ssl.h>

/* Native socket handle type: SOCKET on WIN32, plain int elsewhere. */
#ifdef WIN32
typedef SOCKET socket_t;
#else
typedef int socket_t;
#endif

/*
 * assert() is only active in DEBUG builds.  Otherwise it expands to a void
 * no-op expression, which stays valid wherever an expression is expected
 * (e.g. inside a comma expression) and needs no `//' comment to swallow the
 * argument, so it also preprocesses correctly under C89.
 */
#ifdef DEBUG
# include <assert.h>
#else
# define assert(x) ((void)0)
#endif

#include "pftp_default.h"
#include "pftp_settings.h"
#include "pftp.h"
#include "pftp_speed.h"
#include "pftp_sftp.h"
#include "pftp_internal.h"
#include "pftp_ssh_packet.h"
#include "pftp_ssh.h"
#include "pftp_utf8.h"
#include "pftp_ssh_userauth.h"

#ifdef WITH_DMALLOC
# include <dmalloc.h>
#endif

/*
 * One decoded reply to a user-authentication request.
 *
 * `type' is the SSH message number of the reply; which of the string
 * members are filled in depends on it (unused ones stay NULL).  Every
 * string member is heap-allocated and owned by the struct.
 */
typedef struct {
    uint8_t type;           /* message number of the reply */
    char *name_list;        /* FAILURE: name-list string from the packet */
    int partial_success;    /* FAILURE: decoded partial-success flag */
    char *message;          /* BANNER: banner text */
    char *prompt;           /* PASSWD_CHANGEREQ: prompt text */
    char *lang;             /* BANNER, PASSWD_CHANGEREQ: language string */
} ssh_userauth_response_t;

#define SSH_MSG_USERAUTH_REQUEST 50
#define SSH_MSG_USERAUTH_FAILURE 51
#define SSH_MSG_USERAUTH_SUCCESS 52
#define SSH_MSG_USERAUTH_BANNER  53

#define SSH_MSG_USERAUTH_PASSWD_CHANGEREQ   60

/* Request senders: each builds and sends one SSH_MSG_USERAUTH_REQUEST;
 * callers treat a non-zero return as failure. */
static int _send_userauth_request_none(pftp_ssh_t ssh,
				       const char *username);
static int _send_userauth_request_pass(pftp_ssh_t ssh,
				       const char *username,
				       const char *password);
#if 0
/* Password-change variant; compiled out together with its definition. */
static int _send_userauth_request_chpass(pftp_ssh_t ssh,
					 const char *username,
					 const char *oldpass,
					 const char *newpass);
#endif
/* Reply handling: decode one reply / release a decoded reply. */
static int _get_userauth_response(pftp_ssh_t ssh,
				  ssh_userauth_response_t **response);
static void _free_userauth_response(ssh_userauth_response_t **response);

/*
 * Authenticate the SSH session belonging to `ftp'.
 *
 * Requests the "ssh-userauth" service and makes sure a username is set
 * (asking through pftp_userauth() when it is missing or empty).  The "none"
 * method is tried first; the first FAILURE switches to the "password"
 * method, asking for a password when none is set.  Every later FAILURE asks
 * pftp_userauth() for a fresh username and password and retries.  Banner
 * messages are shown with pftp_status_message() and otherwise skipped.
 *
 * Returns 0 once the server answers SSH_MSG_USERAUTH_SUCCESS.  Returns -1 if
 * the service request, a send, or a receive fails; if pftp_userauth()
 * returns 0; or if any other message type arrives (this includes
 * SSH_MSG_USERAUTH_PASSWD_CHANGEREQ).
 */
int pftp_ssh_userauth(pftp_server_t ftp, pftp_ssh_t ssh)
{
    ssh_userauth_response_t *reply = NULL;
    int use_none = 1;	/* cleared after the server rejects "none" */

    if (pftp_ssh_request_service(ssh, "ssh-userauth"))
	return -1;

    if (ftp->settings->username == NULL
	|| ftp->settings->username[0] == '\0') {
	if (!pftp_userauth(ftp, &ftp->settings->username, NULL))
	    return -1;
    }

    for (;;) {
	int send_failed;

	if (use_none) {
	    send_failed = _send_userauth_request_none(ssh,
						      ftp->settings->username);
	} else {
	    if ((ftp->settings->password == NULL
		 || ftp->settings->password[0] == '\0')
		&& !pftp_userauth(ftp, NULL, &ftp->settings->password))
		return -1;
	    send_failed = _send_userauth_request_pass(ssh,
						      ftp->settings->username,
						      ftp->settings->password);
	}
	if (send_failed)
	    return -1;

	/* Skip over any banners until the actual verdict arrives. */
	for (;;) {
	    if (_get_userauth_response(ssh, &reply)) {
		_free_userauth_response(&reply);
		return -1;
	    }
	    if (reply->type != SSH_MSG_USERAUTH_BANNER)
		break;
	    pftp_status_message(ftp, "SSH: Userauth: Banner message `%s'",
				reply->message);
	    _free_userauth_response(&reply);
	}

	switch (reply->type) {
	case SSH_MSG_USERAUTH_SUCCESS:
	    _free_userauth_response(&reply);
	    return 0;
	case SSH_MSG_USERAUTH_FAILURE:
	    _free_userauth_response(&reply);
	    if (use_none) {
		use_none = 0;
	    } else if (!pftp_userauth(ftp, &ftp->settings->username,
				      &ftp->settings->password)) {
		return -1;
	    }
	    break;
	default:
	    pftp_status_message(ftp, 
				"SSH: Userauth: Unknown response type `%u'",
				reply->type);
	    _free_userauth_response(&reply);
	    return -1;
	}
    }
}

/*
 * Send an SSH_MSG_USERAUTH_REQUEST for the "none" method, for the
 * "ssh-connection" service, on behalf of `username' (converted with
 * pftp_make_unicode_utf8() first).
 *
 * Returns the result of pftp_ssh_send(), or -1 if `username' is NULL or the
 * converted name or the packet could not be created.
 */
static int _send_userauth_request_none(pftp_ssh_t ssh, const char *username)
{
    pftp_ssh_pkt_t pkg;
    char *user;
    int ret;

    if (!username)
	return -1;

    user = strdup(username);
    if (!user)
	return -1;
    pftp_make_unicode_utf8(&user);
    /* NOTE(review): guards against the converter leaving NULL behind on
     * failure; never hand NULL to the packet writer. */
    if (!user)
	return -1;

    pkg = pftp_ssh_create_pkt(SSH_MSG_USERAUTH_REQUEST);
    if (!pkg) {
	free(user);
	return -1;
    }

    pftp_ssh_pkt_put_string(pkg, user);
    free(user);
    pftp_ssh_pkt_put_string(pkg, "ssh-connection");
    pftp_ssh_pkt_put_string(pkg, "none");

    ret = pftp_ssh_send(ssh, pkg, 0);
    pftp_ssh_pkt_free(pkg);
    return ret;
}

/*
 * Overwrite a NUL-terminated secret with zero bytes, in place.  The stores
 * go through a volatile pointer so the compiler cannot drop them as dead
 * writes just before the buffer is handed to free().  NULL is accepted.
 */
static void userauth_wipe_secret(char *secret)
{
    volatile char *p = secret;

    if (!p)
	return;
    while (*p != '\0')
	*p++ = '\0';
}

/*
 * Send an SSH_MSG_USERAUTH_REQUEST for the "password" method, for the
 * "ssh-connection" service, with the change-password flag FALSE.
 * `username' and `password' are converted with pftp_make_unicode_utf8()
 * first.
 *
 * Returns the result of pftp_ssh_send(), or -1 if an argument is NULL or the
 * converted strings or the packet could not be created.
 */
static int _send_userauth_request_pass(pftp_ssh_t ssh,
				       const char *username,
				       const char *password)
{
    pftp_ssh_pkt_t pkg;
    char *user = NULL;
    char *pass = NULL;
    int ret = -1;

    if (!username || !password)
	return -1;

    pkg = pftp_ssh_create_pkt(SSH_MSG_USERAUTH_REQUEST);
    if (!pkg)
	return -1;

    user = strdup(username);
    pass = strdup(password);
    if (!user || !pass)
	goto out;
    pftp_make_unicode_utf8(&user);
    pftp_make_unicode_utf8(&pass);
    /* NOTE(review): guards against the converter leaving NULL behind on
     * failure; never hand NULL to the packet writer. */
    if (!user || !pass)
	goto out;

    pftp_ssh_pkt_put_string(pkg, user);
    pftp_ssh_pkt_put_string(pkg, "ssh-connection");
    pftp_ssh_pkt_put_string(pkg, "password");
    pftp_ssh_pkt_put_char(pkg, '\0');	/* FALSE: not a password change */
    pftp_ssh_pkt_put_string(pkg, pass);

    ret = pftp_ssh_send(ssh, pkg, 0);

out:
    /* Only this local copy of the password is cleared.  Any buffer that
     * pftp_make_unicode_utf8() may have replaced, and the packet payload
     * itself, are outside this function's control. */
    userauth_wipe_secret(pass);
    free(pass);
    free(user);
    pftp_ssh_pkt_free(pkg);
    return ret;
}

#if 0
/*
 * Compiled out: nothing calls this, and pftp_ssh_userauth() treats
 * SSH_MSG_USERAUTH_PASSWD_CHANGEREQ as an unknown reply and gives up.
 *
 * Builds a "password" SSH_MSG_USERAUTH_REQUEST for the "ssh-connection"
 * service with the boolean set to '\1', followed by the old and then the new
 * password.  The username and both passwords are converted with
 * pftp_make_unicode_utf8() first.  Returns the result of pftp_ssh_send().
 *
 * NOTE(review): strdup() results are not checked, and the password copies
 * are freed without being cleared.  Address both before re-enabling this.
 */
int _send_userauth_request_chpass(pftp_ssh_t ssh,
				  const char *username, 
				  const char *oldpass,
				  const char *newpass)
{
    pftp_ssh_pkt_t pkg;
    char *pass1, *pass2, *user;
    int ret;
    pkg = pftp_ssh_create_pkt(SSH_MSG_USERAUTH_REQUEST);
    user = strdup(username);
    pass1 = strdup(oldpass);
    pass2 = strdup(newpass);
    pftp_make_unicode_utf8(&user);
    pftp_make_unicode_utf8(&pass1);
    pftp_make_unicode_utf8(&pass2);
    
    pftp_ssh_pkt_put_string(pkg, user);
    free(user);
    pftp_ssh_pkt_put_string(pkg, "ssh-connection");
    pftp_ssh_pkt_put_string(pkg, "password");
    /* TRUE: this request carries a new password after the old one. */
    pftp_ssh_pkt_put_char(pkg, '\1');
    pftp_ssh_pkt_put_string(pkg, pass1);
    free(pass1);
    pftp_ssh_pkt_put_string(pkg, pass2);
    free(pass2);
    
    ret = pftp_ssh_send(ssh, pkg, 0);
    pftp_ssh_pkt_free(pkg);
    return ret;
}
#endif

/*
 * Read one packet from `ssh' and decode it as a user-authentication reply.
 *
 * `*response' must be NULL on entry.  On success it is set to a newly
 * allocated ssh_userauth_response_t, which the caller releases with
 * _free_userauth_response(), and 0 is returned.  Message types not handled
 * below are accepted with only `type' filled in.
 *
 * Returns -1 and leaves `*response' NULL if no packet was received,
 * allocation failed, or a field of a known message type could not be read.
 */
static int _get_userauth_response(pftp_ssh_t ssh,
				  ssh_userauth_response_t **response)
{
    pftp_ssh_pkt_t pkg;
    ssh_userauth_response_t *resp;

    assert(*response == NULL);

    pkg = pftp_ssh_get(ssh, 0);
    if (!pkg)
	return -1;

    /* Zeroed so that unused string members are NULL and freeing is safe. */
    resp = calloc(1, sizeof(*resp));
    if (!resp) {
	pftp_ssh_pkt_free(pkg);
	return -1;
    }

    resp->type = pftp_ssh_pkt_msg(pkg);

    switch (resp->type) {
    case SSH_MSG_USERAUTH_SUCCESS:
	/* No arguments that I know of */
	break;
    case SSH_MSG_USERAUTH_FAILURE: {
	int valid;
	resp->name_list = pftp_ssh_pkt_get_string(pkg);
	if (!resp->name_list)
	    goto fail;
	/* SSH booleans: any non-zero byte is TRUE (RFC 4251, section 5). */
	resp->partial_success = (pftp_ssh_pkt_get_char(pkg, &valid) != '\0');
	if (!valid)
	    goto fail;
	break;
    }
    case SSH_MSG_USERAUTH_BANNER:
	resp->message = pftp_ssh_pkt_get_string(pkg);
	if (!resp->message)
	    goto fail;
	pftp_parse_unicode_utf8(resp->message);
	resp->lang = pftp_ssh_pkt_get_string(pkg);
	if (!resp->lang)
	    goto fail;
	break;
    case SSH_MSG_USERAUTH_PASSWD_CHANGEREQ:
	resp->prompt = pftp_ssh_pkt_get_string(pkg);
	if (!resp->prompt)
	    goto fail;
	pftp_parse_unicode_utf8(resp->prompt);
	resp->lang = pftp_ssh_pkt_get_string(pkg);
	if (!resp->lang)
	    goto fail;
	break;
    default:
	/* Take anything */
#ifdef DEBUG
	fprintf(stderr, "SFTP: Unknown packet `%u'.\n",
		(unsigned)resp->type);
#endif
	break;
    }

    pftp_ssh_pkt_free(pkg);
    *response = resp;
    return 0;

fail:
    /* resp->type is set, so this frees exactly the strings read so far. */
    pftp_ssh_pkt_free(pkg);
    _free_userauth_response(&resp);
    return -1;
}

/*
 * Release a reply produced by _get_userauth_response() and reset the
 * caller's pointer to NULL.  Safe to call when `*response' is already NULL.
 * Only the string members that belong to the reply's message type are
 * freed; SUCCESS and unrecognised types own no strings.
 */
void _free_userauth_response(ssh_userauth_response_t **response)
{
    ssh_userauth_response_t *resp = *response;

    if (resp == NULL)
	return;

    if (resp->type == SSH_MSG_USERAUTH_FAILURE) {
	free(resp->name_list);
    } else if (resp->type == SSH_MSG_USERAUTH_BANNER) {
	free(resp->message);
	free(resp->lang);
    } else if (resp->type == SSH_MSG_USERAUTH_PASSWD_CHANGEREQ) {
	free(resp->prompt);
	free(resp->lang);
    }

    free(resp);
    *response = NULL;
}

#endif /* NO_SSL */
