/*
 * Amanda, The Advanced Maryland Automatic Network Disk Archiver
 * Copyright (c) 1993 University of Maryland
 * All Rights Reserved.
 *
 * Permission to use, copy, modify, distribute, and sell this software and its
 * documentation for any purpose is hereby granted without fee, provided that
 * the above copyright notice appear in all copies and that both that
 * copyright notice and this permission notice appear in supporting
 * documentation, and that the name of U.M. not be used in advertising or
 * publicity pertaining to distribution of the software without specific,
 * written prior permission.  U.M. makes no representations about the
 * suitability of this software for any purpose.  It is provided "as is"
 * without express or implied warranty.
 *
 * U.M. DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT SHALL U.M.
 * BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 * Author: James da Silva, Systems Design and Analysis Group
 *			   Computer Science Department
 *			   University of Maryland at College Park
 */
/*
 * krb4-security.c - helper functions for kerberos v4 security.
 */
#include "amanda.h"
#include "krb4-security.h"
#include "protocol.h"
#include "diskfile.h"

#define HOSTNAME_INSTANCE inst

void kerberos_service_init()
{
    int rc;
    char hostname[256], inst[256], realm[256];

    gethostname(hostname, sizeof(hostname));
    hostname[255] = '\0';

    host2krbname(hostname, inst, realm);

    rc = krb_get_svc_in_tkt(SERVER_HOST_PRINCIPLE, SERVER_HOST_INSTANCE,
			    realm, "krbtgt", realm, TICKET_LIFETIME,
			    SERVER_HOST_KEY_FILE);
    if(rc) error("could not get ticket-granting-ticket: %s", krb_err_txt[rc]);

    krb_set_lifetime(TICKET_LIFETIME);
}


/*
 * kerberos_cksum - quad_cksum() of the NUL-terminated string STR, computed
 * with an all-zero seed and a single output round, with no output buffer.
 */
unsigned long kerberos_cksum(str)
char *str;
{
    des_cblock zero_seed;	/* quad_cksum seed: every byte zero */

    memset(zero_seed, '\0', sizeof zero_seed);
    return quad_cksum(str, NULL, strlen(str), 1, zero_seed);
}

/*
 * host2krbname - resolve ALIAS and derive its Kerberos v4 instance and realm.
 *
 * INST receives the first label of the canonical host name (everything
 * before the first '.'), lower-cased.  REALM receives krb_realmofhost() of
 * the canonical name.
 *
 * Returns the gethostbyname() result (static storage owned by the resolver),
 * or NULL if ALIAS cannot be resolved, in which case INST and REALM are left
 * untouched.
 *
 * NOTE(review): neither copy is bounded; callers in this file pass 256-byte
 * buffers.  Confirm other callers do the same.
 */
struct hostent *host2krbname(alias, inst, realm)
char *alias, *inst, *realm;
{
    struct hostent *hp;
    char *s, *d;

    if((hp = gethostbyname(alias)) == NULL) return NULL;

    /* get inst name: like krb_get_phost, but avoid multiple gethostbyname */
    for(s = hp->h_name, d = inst; *s != '\0' && *s != '.'; s++, d++)
	/* ctype functions need an unsigned char value; tolower() leaves
	   non-upper-case characters unchanged */
	*d = (char)tolower((unsigned char)*s);
    *d = '\0';

    /* get realm name: krb_realmofhost always returns *something* */
    strcpy(realm, krb_realmofhost(hp->h_name));

    return hp;
}

/*
 * krb4_pcbc_in_place - run DES in PCBC mode over LENGTH bytes of DATA in
 * place, using KEY both as the key and as the initial vector.  MODE is
 * DES_ENCRYPT or DES_DECRYPT.  The expanded key schedule is zeroed before
 * returning so key material does not linger on the stack.
 */
static void krb4_pcbc_in_place(data, length, key, mode)
char *data;
int length;
des_cblock key;
int mode;
{
    des_key_schedule sched;
    volatile unsigned char *wipe;
    size_t left;

    des_key_sched(key, sched);
    des_pcbc_encrypt(data, data, length, sched, key, mode);

    /* volatile stores, so the compiler cannot drop the wipe as dead code */
    wipe = (volatile unsigned char *)sched;
    for(left = sizeof(sched); left > 0; left--)
	*wipe++ = 0;
}

/*
 * encrypt_data - DES-PCBC encrypt LENGTH bytes of DATA in place with KEY
 * (KEY also serves as the IV).
 */
void encrypt_data(data, length, key)
char *data;
int length;
des_cblock key;
{
    krb4_pcbc_in_place(data, length, key, DES_ENCRYPT);
}


/*
 * decrypt_data - DES-PCBC decrypt LENGTH bytes of DATA in place with KEY
 * (KEY also serves as the IV); the inverse of encrypt_data().
 */
void decrypt_data(data, length, key)
char *data;
int length;
des_cblock key;
{
    krb4_pcbc_in_place(data, length, key, DES_DECRYPT);
}


/*
 * kerberos_handshake - timestamp challenge/response with the peer on FD,
 * proving that both ends hold KEY.
 *
 *   1. send our current time (fields htonl'd, then encrypted with KEY);
 *   2. read the peer's encrypted time, add 1 to both fields, send it back;
 *   3. read the peer's reply to our time.
 *
 * Returns 1 if the peer's reply equals our time with both fields
 * incremented by one, 0 otherwise.  A failed or short read/write is reported
 * with error().
 *
 * NOTE(review): each message is sizeof(struct timeval) bytes and only the
 * low 32 bits of each field go through htonl()/ntohl(), so the wire format
 * depends on the platform's struct timeval layout.  It is kept unchanged
 * for compatibility with existing peers.
 */
int kerberos_handshake(fd, key)
int fd;
des_cblock key;
{
    int rc;
    struct timeval local, localenc, remote, rcvlocal;
    struct timezone tz;
    char *strerror();

    gettimeofday(&local, &tz);
    memcpy(&localenc, &local, sizeof local);

    localenc.tv_sec  = htonl(localenc.tv_sec );
    localenc.tv_usec = htonl(localenc.tv_usec);
    encrypt_data((char *)&localenc, (int)sizeof local, key);

    /*
     * Compare rc as an int: against the unsigned sizeof, a -1 error return
     * would be converted to a huge value and the failure would go unnoticed.
     */
    if((rc = write(fd, &localenc, sizeof local)) < (int)sizeof local)
	error("kerberos_handshake write error: [%s]", 
	      rc == -1? strerror(errno) : "short write");
    
    if((rc = read(fd, &remote, sizeof remote)) < (int)sizeof remote)
	error("kerberos_handshake read error: [%s]", 
	      rc == -1? strerror(errno) : "short read");
    
    decrypt_data((char *)&remote, (int)sizeof remote, key);
    remote.tv_sec  = ntohl(remote.tv_sec);
    remote.tv_usec = ntohl(remote.tv_usec);

    /* XXX do timestamp checking here */

    /* answer the peer's challenge: its timestamp plus one in each field */
    remote.tv_sec += 1;
    remote.tv_usec += 1;

    remote.tv_sec  = htonl(remote.tv_sec);
    remote.tv_usec = htonl(remote.tv_usec);
    encrypt_data((char *)&remote, (int)sizeof remote, key);
    
    if((rc = write(fd, &remote, sizeof remote)) < (int)sizeof remote)
	error("kerberos_handshake write2 error: [%s]", 
	      rc == -1? strerror(errno) : "short write");
    
    if((rc = read(fd, &rcvlocal, sizeof rcvlocal)) < (int)sizeof rcvlocal)
	error("kerberos_handshake read2 error: [%s]", 
	      rc == -1? strerror(errno) : "short read");

    decrypt_data((char *)&rcvlocal, (int)sizeof rcvlocal, key);
    rcvlocal.tv_sec  = ntohl(rcvlocal.tv_sec);
    rcvlocal.tv_usec = ntohl(rcvlocal.tv_usec);

    return (rcvlocal.tv_sec  == local.tv_sec + 1) &&
	   (rcvlocal.tv_usec == local.tv_usec + 1);
}

/*
 * host2key - fetch the session key from the cached Kerberos credential for
 * CLIENT_HOST_PRINCIPLE / CLIENT_HOST_INSTANCE in HOSTP's realm.
 *
 * Returns a pointer to a static buffer that the next call overwrites, or
 * NULL if HOSTP's name cannot be resolved or krb_get_cred() fails.  In
 * either failure case there is no key to return.
 */
des_cblock *host2key(hostp)
host_t *hostp;
{
    static des_cblock key;
    /* NOTE(review): "inst" is presumably referenced through
       CLIENT_HOST_INSTANCE -> HOSTNAME_INSTANCE; keep the name. */
    char inst[256], realm[256];
    CREDENTIALS cred;

    if(host2krbname(hostp->hostname, inst, realm) == NULL)
	return NULL;
    if(krb_get_cred(CLIENT_HOST_PRINCIPLE, CLIENT_HOST_INSTANCE, realm,
		    &cred) != 0)
	return NULL;

    memcpy(key, cred.session, sizeof key);
    memset(&cred, 0, sizeof cred);	/* don't leave the key on the stack */
    return &key;
}

/*
 * check_mutual_authenticator - verify the peer's "MUTUAL-AUTH <astr>" reply.
 *
 * The astr is decoded with astr2bin(), decrypted with *KEY and read as an
 * unsigned long.  Returns 1 if it equals P->auth_cksum + 1.  Returns 0 if it
 * differs, if PKT has no security field, if the field does not parse, or if
 * the decoded authenticator is empty or larger than expected.
 */
int check_mutual_authenticator(key, pkt, p)
des_cblock *key;
pkt_t *pkt;
proto_t *p;
{
    char astr[256];
    char raw[sizeof(astr)];
    union {
	char pad[8];
	unsigned long i;
    } mutual;
    int len;

    if(pkt->security == NULL)
	return 0;

    /* the security field comes off the network: bound the copy
       (255 == sizeof(astr) - 1) */
    if(sscanf(pkt->security, "MUTUAL-AUTH %255[^\n]", astr) != 1)
	return 0;

    /*
     * Decode into a buffer as large as astr (a well-formed astr decodes to
     * at most one byte per character), then accept only what fits in the
     * authenticator.  Decoding straight into "mutual" let a long reply
     * overwrite the stack.
     */
    astr2bin(astr, raw, &len);
    if(len <= 0 || len > (int)sizeof(mutual))
	return 0;

    memset(&mutual, 0, sizeof(mutual));
    memcpy(&mutual, raw, (size_t)len);

    decrypt_data((char *)&mutual, len, *key);
    return mutual.i == p->auth_cksum + 1;
}


/* ---------------- */

/*
 * Single hex-digit helpers for the "astr" encoding below.
 *   hex_digit(d):   0..15 -> '0'..'9', 'A'..'F'
 *   unhex_digit(h): '0'..'9', 'A'..'F' -> 0..15 (upper case only; any other
 *                   character yields a meaningless value)
 * NOTE(review): high-bit characters were reported in astr debug output; a
 * plausible cause is bin2astr passing plain (possibly signed) char to
 * isprint()/isgraph() -- verify.
 */

#define hex_digit(d)	("0123456789ABCDEF"[(d)])
#define unhex_digit(h)	((h) > '9' ? (h) - 'A' + 10 : (h) - '0')

/*
 * bin2astr - encode LEN bytes at BUF as a printable "astr" string.
 *
 * A byte that is printable and is neither '$' nor '"' is copied as-is.  Every
 * other byte is written as '$' followed by two upper-case hex digits.  The
 * result is wrapped in double quotes whenever any byte was escaped or a
 * printable non-graphic byte (e.g. a space) was copied.  Quoting on escapes
 * is required because astr2bin() decodes $XX only inside quotes and copies
 * unquoted strings verbatim.
 *
 * Returns a NUL-terminated string obtained from alloc(); the caller owns it.
 */
char *bin2astr(buf, len)
char *buf;
int len;
{
    unsigned char *p;
    char *str, *q;
    int slen, i, c, needs_quote;

    /* first pass: compute the encoded length and whether to quote.
       Bytes are read as unsigned char because the ctype functions are
       undefined for negative arguments other than EOF. */
    slen = needs_quote = 0;
    p = (unsigned char *)buf;
    for(i = 0; i < len; i++) {
	c = *p++;
	if(isprint(c) && c != '$' && c != '"') {
	    slen += 1;
	    if(!isgraph(c)) needs_quote = 1;	/* e.g. a space */
	}
	else {
	    slen += 3;
	    needs_quote = 1;	/* escapes are only decoded when quoted */
	}
    }
    if(needs_quote) slen += 2;

    /* second pass: allocate the string and fill it in */
    str = (char *)alloc(slen+1);
    p = (unsigned char *)buf;
    q = str;
    if(needs_quote) *q++ = '"';
    for(i = 0; i < len; i++) {
	c = *p++;
	if(isprint(c) && c != '$' && c != '"')
	    *q++ = (char)c;
	else {
	    *q++ = '$';
	    *q++ = hex_digit((c >> 4) & 0xF);
	    *q++ = hex_digit(c & 0xF);
	}
    }
    if(needs_quote) *q++ = '"';
    *q = '\0';
    if(q-str != slen)
	printf("bin2str: hmmm.... calculated %d got %d\n",
	       slen, (int)(q-str));
    return str;
}

/*
 * astr_hexval - value (0..15) of the hex digit C, or -1 if C is not a hex
 * digit.  Accepts both upper- and lower-case letters.
 */
static int astr_hexval(c)
int c;
{
    if(c >= '0' && c <= '9') return c - '0';
    if(c >= 'A' && c <= 'F') return c - 'A' + 10;
    if(c >= 'a' && c <= 'f') return c - 'a' + 10;
    return -1;
}

/*
 * astr2bin - decode an "astr" string (see bin2astr) into BUF.
 *
 * An unquoted ASTR is copied verbatim (without its NUL).  In a quoted ASTR,
 * each "$XX" is one byte with hex value XX and every other character is
 * copied as-is.  Decoding stops at the closing quote, at the end of the
 * string, or at a malformed or truncated escape; a debug line is printed
 * if input remains unconsumed.  *LENP receives the number of bytes written.
 *
 * BUF must hold at least strlen(ASTR) bytes.
 */
void astr2bin(astr, buf, lenp)
char *astr, *buf;
int  *lenp;
{
    char *p, *q;
    int hi, lo;

    p = astr; q = buf;

    if(*p != '"') {
	/* strcpy, but without the null */
	while(*p) *q++ = *p++;
	*lenp = q-buf;
	return;
    }

    p++;
    /* also stop at the NUL: an unterminated quote must not run off the end */
    while(*p != '\0' && *p != '"') {
	if(*p != '$') {
	    *q++ = *p++;
	    continue;
	}
	/* p[2] is only examined once p[1] is known not to be the NUL */
	hi = astr_hexval((unsigned char)p[1]);
	lo = hi < 0 ? -1 : astr_hexval((unsigned char)p[2]);
	if(lo < 0)
	    break;		/* truncated or malformed escape */
	*q++ = (char)((hi << 4) | lo);
	p += 3;
    }
    /* well-formed input ends exactly at the closing quote */
    if(*p != '"' || p[1] != '\0')
	printf("astr2bin: hmmm... short inp exp %d got %d\n",
	       (int)strlen(astr), (int)(p-astr+1));
    *lenp = q-buf;
}

/* -------------------------- */
/* debug routines */

/*
 * print_hex - debug dump: prints "STR:" and then the LEN bytes of BUF as
 * space-separated upper-case hex, 25 bytes per line.
 *
 * Always returns 0.  The int return type, previously implicit, is kept so
 * any existing declarations stay compatible.
 */
int print_hex(str,buf,len)
char *str;
unsigned char *buf;
int len;
{
    int i;

    printf("%s:", str);
    for(i=0;i<len;i++) {
	if(i%25 == 0) putchar('\n');
	printf(" %02X", buf[i]);
    }
    putchar('\n');
    return 0;
}

/*
 * print_ticket - debug dump of a KTEXT: STR, its length and mbz field,
 * followed by its data bytes via print_hex().
 *
 * Always returns 0.  The int return type, previously implicit, is kept for
 * compatibility.
 */
int print_ticket(str, tktp)
char *str;
KTEXT tktp;
{
    /* mbz is cast so the %lX conversion matches whatever integer type
       krb.h declares it with */
    printf("%s: length %d chk %lX\n", str, tktp->length,
	   (unsigned long)tktp->mbz);
    print_hex("ticket data", tktp->dat, tktp->length);
    fflush(stdout);
    return 0;
}

/*
 * print_auth - debug dump of an AUTH_DAT: principal, instance, realm,
 * checksum, lifetime, session-key length and the session key bytes.
 *
 * Always returns 0.  The int return type, previously implicit, is kept for
 * compatibility.
 */
int print_auth(authp)
AUTH_DAT *authp;
{
    printf("\nAuth Data:\n");
    printf("  Principal \"%s\" Instance \"%s\" Realm \"%s\"\n",
	   authp->pname, authp->pinst, authp->prealm);
    /* explicit casts so each argument matches its conversion specifier;
       sizeof yields size_t, which %d does not accept */
    printf("  cksum %lu life %d keylen %d\n",
	   (unsigned long)authp->checksum, (int)authp->life,
	   (int)sizeof(authp->session));
    /* print_hex has no prototype here, so pass its length as an int */
    print_hex("session key", authp->session, (int)sizeof(authp->session));
    fflush(stdout);
    return 0;
}

/*
 * print_credentials - debug dump of a CREDENTIALS record: service, instance,
 * realm, lifetime, key version, then the session key and ticket bytes.
 *
 * Always returns 0.  The int return type, previously implicit, is kept for
 * compatibility.
 */
int print_credentials(credp)
CREDENTIALS *credp;
{
    printf("\nCredentials:\n");
    printf("  service \"%s\" instance \"%s\" realm \"%s\" life %d kvno %d\n",
	   credp->service, credp->instance, credp->realm, credp->lifetime,
	   credp->kvno);
    /* print_hex has no prototype here, so pass its length as an int */
    print_hex("session key", credp->session, (int)sizeof(credp->session));
    print_hex("ticket", credp->ticket_st.dat, credp->ticket_st.length);
    fflush(stdout);
    return 0;
}
