/*-
 * Copyright (c) 1995 Berkeley Software Design, Inc. All rights reserved.
 * The Berkeley Software Design Inc. software License Agreement specifies
 * the terms and conditions for redistribution.
 *
 *	BSDI login_cap.c,v 2.7 1995/10/12 06:05:51 prb Exp
 */
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/resource.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <paths.h>
#include <pwd.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>

#include "login_cap.h"

extern char *__progname;

/* Prototypes for the file-local helpers defined further down. */
static	u_quad_t multiply __P((u_quad_t, u_quad_t));
static	void setuserpath __P((login_cap_t *, char *));
static	u_quad_t strtolimit __P((char *, char **, int));
static	u_quad_t strtosize __P((char *, char **, int));

/* NULL-terminated list of capability databases passed to cgetent(3). */
static	char *classfiles[] = { _PATH_LOGIN_CONF, NULL };
/* Style list login_getstyle() uses when no non-empty auth list is found. */
static	char *_authtypes[] = { LOGIN_DEFSTYLE, NULL };

/*
 * login_getclass --
 *	Allocate a login_cap_t for pwd's login class, or for LOGIN_DEFCLASS
 *	when pwd is NULL or its class is NULL/empty, and load its capability
 *	record from the databases listed in classfiles.
 *	Returns the new object (release it with login_close()), or NULL on
 *	allocation failure.  When the capability lookup fails, NULL is also
 *	returned unless pwd has uid 0, or the default class was requested and
 *	classfiles[0] cannot be opened; in those cases the object is returned
 *	with no capability record (lc_cap == NULL).
 */
login_cap_t *
login_getclass(pwd)
	struct passwd *pwd;
{
	login_cap_t *lc;
	int res;
	int defclass;

	if ((lc = malloc(sizeof(login_cap_t))) == NULL) {
		syslog(LOG_ERR, "%s:%d malloc: %m", __FILE__, __LINE__);
		return (0);
	}

	lc->lc_cap = 0;
	lc->lc_class = 0;
	lc->lc_style = 0;

	/* No passwd entry, or no class recorded in it, selects the default. */
	defclass = !pwd || pwd->pw_class == NULL || pwd->pw_class[0] == '\0';
	lc->lc_class = defclass ? LOGIN_DEFCLASS : pwd->pw_class;

	if ((lc->lc_class = strdup(lc->lc_class)) == NULL) {
		syslog(LOG_ERR, "%s:%d strdup: %m", __FILE__, __LINE__);
		free(lc);
		return (0);
	}

	if ((res = cgetent(&lc->lc_cap, classfiles, lc->lc_class)) != 0) {
		/*
		 * res == 1 (unresolved 'tc') still hands back an allocated
		 * record per getcap(3); it is discarded here, so release it.
		 * lc_cap started out NULL, so free() is harmless if unset.
		 */
		if (res == 1)
			free(lc->lc_cap);
		lc->lc_cap = 0;
		switch (res) {
		case 1:
			syslog(LOG_ERR, "%s: %s: couldn't resolve 'tc'",
				__progname, lc->lc_class);
			break;
		case -1:
			/*
			 * Class not found.  If the default class was wanted
			 * and the database file cannot even be opened, accept
			 * an empty capability set instead of failing.
			 */
			if ((res = open(classfiles[0], O_RDONLY)) >= 0)
				close(res);
			if (defclass && res < 0)
				return (lc);
			syslog(LOG_ERR, "%s: %s: unknown class",
				__progname, lc->lc_class);
			break;
		case -2:
			syslog(LOG_ERR, "%s: %s: getting class information: %m",
				__progname, lc->lc_class);
			break;
		case -3:
			syslog(LOG_ERR, "%s: %s: 'tc' reference loop",
				__progname, lc->lc_class);
			break;
		default:
			syslog(LOG_ERR, "%s: %s: unexpected cgetent error",
				__progname, lc->lc_class);
			break;
		}
		/*
		 * Lookup failure is fatal except for uid 0, which gets an
		 * object with no capabilities (presumably so root is not
		 * locked out by a broken database).
		 */
		if (!pwd || pwd->pw_uid) {
			free(lc->lc_class);
			free(lc);
			return (0);
		}
	}
	return (lc);
}

/*
 * login_getstyle --
 *	Select an authentication style for lc.  The candidate list is the
 *	comma-separated capability "atype" (if atype is non-NULL and set),
 *	else the "auth" capability, else the built-in _authtypes list.
 *	Empty fields in the list stand for LOGIN_DEFSTYLE.  If style is NULL
 *	the first candidate is chosen; otherwise style must appear in the
 *	list.
 *	Returns a malloc'd copy of the chosen style, also stored in
 *	lc->lc_style (freed by login_close()), or NULL (with lc->lc_style
 *	cleared) if style is not allowed or memory runs out.
 */
char *
login_getstyle(lc, style, atype)
	login_cap_t *lc;
	char *style;
	char *atype;
{
	char **authtypes = _authtypes;
	char *auths, *ta;
	char *f1, **f2;
	int i;

	f1 = 0;
	f2 = 0;

	if (!atype || !(auths = login_getcapstr(lc, atype, NULL, NULL)))
		auths = login_getcapstr(lc, "auth", "", NULL);

	if (auths && *auths) {
		/*
		 * A non-empty result can only come from cgetstr(3) (the
		 * defaults passed above are NULL and ""), so it is a heap
		 * buffer owned here: split it in place and free it at the end.
		 */
		f1 = ta = auths;

		/* Count the comma-separated fields. */
		i = 1;
		while (*ta)
			if (*ta++ == ',')
				++i;
		/* One slot per field plus the terminating NULL pointer. */
		f2 = authtypes = malloc(sizeof(char *) * (i + 1));
		if (!authtypes) {
			syslog(LOG_ERR, "%s: malloc: %m", __progname);
			free(f1);
			return (0);
		}
		i = 0;
		while (*auths) {
			authtypes[i] = auths;
			while (*auths && *auths != ',')
				++auths;
			if (*auths)
				*auths++ = 0;
			/* An empty field means the default style. */
			if (!*authtypes[i])
				authtypes[i] = LOGIN_DEFSTYLE;
			++i;
		}
		authtypes[i] = 0;
	}

	if (!style)
		style = authtypes[0];

	while (*authtypes && strcmp(style, *authtypes))
		++authtypes;

	if (*authtypes == NULL || (auths = strdup(*authtypes)) == NULL) {
		/* Log before freeing: authtypes may point into f2. */
		if (*authtypes)
			syslog(LOG_ERR, "%s: strdup: %m", __progname);
		free(f1);
		free(f2);
		lc->lc_style = 0;
		return (0);
	}
	/* auths is an independent copy, so the parse buffers can go. */
	free(f1);
	free(f2);
	/*
	 * NOTE(review): a previous lc->lc_style is overwritten without being
	 * freed; callers may still hold that pointer, so it is left alone.
	 */
	return (lc->lc_style = auths);
}

/*
 * login_getcapstr --
 *	Fetch the string capability "cap" from lc.
 *	Returns the value as produced by cgetstr(3) (heap memory per
 *	getcap(3)), def if lc has no capability record or cap is absent,
 *	or e (after logging) if the lookup fails.  errno is cleared on entry.
 */
char *
login_getcapstr(lc, cap, def, e)
	login_cap_t *lc;
	char *cap;
	char *def;
	char *e;
{
	char *value;
	int rv;

	errno = 0;
	if (lc->lc_cap == NULL)
		return (def);

	rv = cgetstr(lc->lc_cap, cap, &value);
	if (rv >= 0)
		return (value);
	if (rv == -1)			/* capability not present */
		return (def);

	if (rv == -2)
		syslog(LOG_ERR, "%s: %s: getting capability %s: %m",
		    __progname, lc->lc_class, cap);
	else
		syslog(LOG_ERR, "%s: %s: unexpected error with capability %s",
		    __progname, lc->lc_class, cap);
	return (e);
}

/*
 * login_getcaptime --
 *	Look up capability "cap" in lc and parse it as a time in seconds.
 *	The value is a sequence of numbers, each optionally followed by one
 *	of s, m, h, d, w or y (either case); the scaled terms are summed.
 *	"infinity" (any case) yields RLIM_INFINITY.
 *	Returns def if lc has no capability record or cap is absent, and e
 *	with errno set to ERANGE if the lookup fails, the value is malformed,
 *	or the total does not fit in a quad_t.
 */
quad_t
login_getcaptime(lc, cap, def, e)
	login_cap_t *lc;
	char *cap;
	quad_t def;
	quad_t e;
{
	char *ep;
	char *res, *sres;
	int stat;
	quad_t q, r, mult;

	errno = 0;
	if (!lc->lc_cap)
		return (def);

	switch (stat = cgetstr(lc->lc_cap, cap, &res)) {
	case -1:
		return (def);
	case -2:
		syslog(LOG_ERR, "%s: %s: getting capability %s: %m",
		    __progname, lc->lc_class, cap);
		errno = ERANGE;
		return (e);
	default:
		if (stat >= 0)
			break;
		syslog(LOG_ERR, "%s: %s: unexpected error with capability %s",
		    __progname, lc->lc_class, cap);
		errno = ERANGE;
		return (e);
	}

	/* res was allocated by cgetstr(3); every path below frees it. */
	if (strcasecmp(res, "infinity") == 0) {
		free(res);
		return (RLIM_INFINITY);
	}

	errno = 0;

	q = 0;
	sres = res;
	while (*res) {
		r = strtoq(res, &ep, 0);
		if (!ep || ep == res ||
		    ((r == QUAD_MIN || r == QUAD_MAX) && errno == ERANGE))
			goto invalid;
		switch (*ep++) {
		case '\0':
			--ep;		/* no unit: stay on the terminator */
			mult = 1;
			break;
		case 's': case 'S':
			mult = 1;
			break;
		case 'm': case 'M':
			mult = 60;
			break;
		case 'h': case 'H':
			mult = 60 * 60;
			break;
		case 'd': case 'D':
			mult = 60 * 60 * 24;
			break;
		case 'w': case 'W':
			mult = 60 * 60 * 24 * 7;
			break;
		case 'y': case 'Y':	/* Pretty absurd */
			mult = 60 * 60 * 24 * 365;
			break;
		default:
			goto invalid;
		}
		/* Signed overflow is undefined: reject it before it happens. */
		if (r > QUAD_MAX / mult || r < QUAD_MIN / mult)
			goto invalid;
		r *= mult;
		if ((r > 0 && q > QUAD_MAX - r) || (r < 0 && q < QUAD_MIN - r))
			goto invalid;
		q += r;
		res = ep;
	}
	free(sres);
	return (q);

invalid:
	syslog(LOG_ERR, "%s: %s:%s=%s: invalid time",
	    __progname, lc->lc_class, cap, sres);
	free(sres);
	errno = ERANGE;
	return (e);
}

/*
 * login_getcapnum --
 *	Look up capability "cap" in lc and parse it as a number (any base
 *	accepted by strtoq with radix 0); "infinity" (any case) yields
 *	RLIM_INFINITY.
 *	Returns def if lc has no capability record or cap is absent, and e
 *	with errno set to ERANGE if the lookup fails, the value has trailing
 *	garbage, or it is out of range.
 */
quad_t
login_getcapnum(lc, cap, def, e)
	login_cap_t *lc;
	char *cap;
	quad_t def;
	quad_t e;
{
	char *ep;
	char *res;
	int stat;
	quad_t q;

	errno = 0;
	if (!lc->lc_cap)
		return (def);

	switch (stat = cgetstr(lc->lc_cap, cap, &res)) {
	case -1:
		return (def);
	case -2:
		syslog(LOG_ERR, "%s: %s: getting capability %s: %m",
		    __progname, lc->lc_class, cap);
		errno = ERANGE;
		return (e);
	default:
		if (stat >= 0)
			break;
		syslog(LOG_ERR, "%s: %s: unexpected error with capability %s",
		    __progname, lc->lc_class, cap);
		errno = ERANGE;
		return (e);
	}

	/* res was allocated by cgetstr(3); every path below frees it. */
	if (strcasecmp(res, "infinity") == 0) {
		free(res);
		return (RLIM_INFINITY);
	}

	errno = 0;
	q = strtoq(res, &ep, 0);
	if (!ep || ep == res || ep[0] ||
	    ((q == QUAD_MIN || q == QUAD_MAX) && errno == ERANGE)) {
		syslog(LOG_ERR, "%s: %s:%s=%s: invalid number",
		    __progname, lc->lc_class, cap, res);
		free(res);
		errno = ERANGE;
		return (e);
	}
	free(res);
	return (q);
}

/*
 * login_getcapsize --
 *	Look up capability "cap" in lc and parse it with strtolimit():
 *	a number with an optional b/k/m/g/t suffix, an x- or *-separated
 *	product of such numbers, or "infinity"/"inf" for RLIM_INFINITY.
 *	Returns def if lc has no capability record or cap is absent, and e
 *	with errno set to ERANGE if the lookup fails, the value is malformed,
 *	or it overflows.
 */
quad_t
login_getcapsize(lc, cap, def, e)
	login_cap_t *lc;
	char *cap;
	quad_t def;
	quad_t e;
{
	char *ep;
	char *res;
	int stat;
	quad_t q;

	errno = 0;

	if (!lc->lc_cap)
		return (def);

	switch (stat = cgetstr(lc->lc_cap, cap, &res)) {
	case -1:
		return (def);
	case -2:
		syslog(LOG_ERR, "%s: %s: getting capability %s: %m",
		    __progname, lc->lc_class, cap);
		errno = ERANGE;
		return (e);
	default:
		if (stat >= 0)
			break;
		syslog(LOG_ERR, "%s: %s: unexpected error with capability %s",
		    __progname, lc->lc_class, cap);
		errno = ERANGE;
		return (e);
	}

	errno = 0;
	q = strtolimit(res, &ep, 0);
	/*
	 * strtolimit() signals overflow through errno.  Its UQUAD_MAX result
	 * typically reads back as -1 in a quad_t, so comparing q against
	 * QUAD_MIN/QUAD_MAX cannot detect it; test errno instead.
	 * NOTE(review): (ep[0] && ep[1]) tolerates one trailing unparsed
	 * character; kept unchanged, confirm whether that is intended.
	 */
	if (!ep || ep == res || (ep[0] && ep[1]) || errno != 0) {
		syslog(LOG_ERR, "%s: %s:%s=%s: invalid size",
		    __progname, lc->lc_class, cap, res);
		free(res);
		errno = ERANGE;
		return (e);
	}
	/* res was allocated by cgetstr(3). */
	free(res);
	return (q);
}

/*
 * login_getcapbool --
 *	Report whether boolean capability "cap" is present in lc.
 *	Returns def when lc has no capability record, otherwise 1 if
 *	cgetcap(3) finds cap and 0 if it does not.
 */
int
login_getcapbool(lc, cap, def)
	login_cap_t *lc;
	char *cap;
	u_int def;
{
	char *found;

	if (lc->lc_cap == NULL)
		return (def);

	found = cgetcap(lc->lc_cap, cap, ':');
	return (found == NULL ? 0 : 1);
}

/*
 * login_close --
 *	Release lc and the class name, capability record and style string
 *	it owns.  A NULL lc is ignored.
 */
void
login_close(lc)
	login_cap_t *lc;
{
	if (lc == NULL)
		return;

	/* free(NULL) is a no-op, so fields that were never set need no test. */
	free(lc->lc_class);
	free(lc->lc_cap);
	free(lc->lc_style);
	free(lc);
}

/* How a resource capability's value is parsed (see gsetrl()). */
enum {
	CTIME = 1,	/* login_getcaptime() */
	CSIZE = 2,	/* login_getcapsize() */
	CNUMB = 3	/* login_getcapnum() */
};

/*
 * Resource-limit capabilities: setrlimit(2) resource, value type and
 * capability base name.  The list ends at the entry with a NULL name.
 */
static struct {
	int	what;		/* RLIMIT_* resource */
	int	type;		/* CTIME, CSIZE or CNUMB */
	char *	name;		/* capability name */
} r_list[] = {
	{ RLIMIT_CPU,		CTIME,	"cputime" },
	{ RLIMIT_FSIZE,		CSIZE,	"filesize" },
	{ RLIMIT_DATA,		CSIZE,	"datasize" },
	{ RLIMIT_STACK,		CSIZE,	"stacksize" },
	{ RLIMIT_RSS,		CSIZE,	"memoryuse" },
	{ RLIMIT_MEMLOCK,	CSIZE,	"memorylocked" },
	{ RLIMIT_NPROC,		CNUMB,	"maxproc" },
	{ RLIMIT_NOFILE,	CNUMB,	"openfiles" },
	{ RLIMIT_CORE,		CSIZE,	"coredumpsize" },
	{ -1,			0,	NULL }
};

/*
 * gsetrl --
 *	Apply one resource limit from lc.  Starting from the current limits,
 *	capability "<name>" overrides both the soft and hard values, then
 *	"<name>-cur" and "<name>-max" override them individually; values are
 *	parsed according to type (CTIME, CSIZE or CNUMB).  A missing or bad
 *	capability leaves the value it would have replaced unchanged.
 *	Returns 0 on success, -1 (after logging) on failure.
 */
static int
gsetrl(lc, what, name, type)
	login_cap_t *lc;
	int what;
	char *name;
	int type;
{
	struct rlimit rl;
	struct rlimit r;
	char name_cur[32];
	char name_max[32];
	int n;

	/* Build the per-limit capability names, refusing to truncate them. */
	n = snprintf(name_cur, sizeof(name_cur), "%s-cur", name);
	if (n < 0 || (size_t)n >= sizeof(name_cur)) {
		syslog(LOG_ERR, "%s: %s: resource name too long: %s",
		    __progname, lc->lc_class, name);
		return (-1);
	}
	n = snprintf(name_max, sizeof(name_max), "%s-max", name);
	if (n < 0 || (size_t)n >= sizeof(name_max)) {
		syslog(LOG_ERR, "%s: %s: resource name too long: %s",
		    __progname, lc->lc_class, name);
		return (-1);
	}

	if (getrlimit(what, &r)) {
		syslog(LOG_ERR, "%s: getting resource limit: %m", __progname);
		return (-1);
	}

	switch (type) {
	case CTIME:
		r.rlim_cur = login_getcaptime(lc, name, r.rlim_cur, r.rlim_cur);
		r.rlim_max = login_getcaptime(lc, name, r.rlim_max, r.rlim_max);
		rl.rlim_cur = login_getcaptime(lc, name_cur, r.rlim_cur,
		    r.rlim_cur);
		rl.rlim_max = login_getcaptime(lc, name_max, r.rlim_max,
		    r.rlim_max);
		break;
	case CSIZE:
		r.rlim_cur = login_getcapsize(lc, name, r.rlim_cur, r.rlim_cur);
		r.rlim_max = login_getcapsize(lc, name, r.rlim_max, r.rlim_max);
		rl.rlim_cur = login_getcapsize(lc, name_cur, r.rlim_cur,
		    r.rlim_cur);
		rl.rlim_max = login_getcapsize(lc, name_max, r.rlim_max,
		    r.rlim_max);
		break;
	case CNUMB:
		r.rlim_cur = login_getcapnum(lc, name, r.rlim_cur, r.rlim_cur);
		r.rlim_max = login_getcapnum(lc, name, r.rlim_max, r.rlim_max);
		rl.rlim_cur = login_getcapnum(lc, name_cur, r.rlim_cur,
		    r.rlim_cur);
		rl.rlim_max = login_getcapnum(lc, name_max, r.rlim_max,
		    r.rlim_max);
		break;
	default:
		return (-1);
	}

	/* Only reported here; setrlimit(2) below rejects the pair. */
	if (rl.rlim_max < rl.rlim_cur)
		syslog(LOG_ERR, "%s: %s: inverted resource limits %s",
		    __progname, lc->lc_class, name);

	if (setrlimit(what, &rl)) {
		syslog(LOG_ERR, "%s: %s: setting resource limit %s: %m",
		    __progname, lc->lc_class, name);
		return (-1);
	}
	return (0);
}

/*
 * setclasscontext --
 *	Apply the resource, priority and umask settings of login class
 *	"class" (the default class if class is NULL or empty) to the current
 *	process.  Other flag bits are ignored.
 *	Returns the setusercontext() result, or -1 if the class cannot be
 *	loaded.
 */
int
setclasscontext(class, flags)
	char *class;
	u_int flags;
{
	struct passwd pwd;
	login_cap_t *lc;
	int ret;

	/* login_getclass() reads only pw_class and pw_uid; zero the rest. */
	memset(&pwd, 0, sizeof(pwd));
	pwd.pw_class = class ? class : "";
	/* uid 0 lets the lookup fall back to an empty capability set. */
	pwd.pw_uid = 0;
	flags &= LOGIN_SETRESOURCES | LOGIN_SETPRIORITY | LOGIN_SETUMASK;

	if ((lc = login_getclass(&pwd)) == NULL)
		return (-1);
	ret = setusercontext(lc, 0, 0, flags);
	/* setusercontext() does not release a caller-supplied lc. */
	login_close(lc);
	return (ret);
}

/*
 * setusercontext --
 *	Apply the login-class settings selected by flags to the current
 *	process: resource limits, priority, umask, group ID and supplementary
 *	groups, login name, user ID (set to uid) and PATH.
 *	If lc is NULL the class is looked up from pwd and released again
 *	before returning.  pwd must be non-NULL when LOGIN_SETGROUP,
 *	LOGIN_SETLOGIN or LOGIN_SETPATH is requested.
 *	Returns 0 on success; -1 if pwd is missing, the class cannot be
 *	loaded or setlogin() fails; -2 if setgid(), initgroups() or setuid()
 *	fails.  Resource, priority and umask settings are best effort.
 */
int
setusercontext(lc, pwd, uid, flags)
	login_cap_t *lc;
	struct passwd *pwd;
	uid_t uid;
	u_int flags;
{
	login_cap_t *flc;
	quad_t p;
	int i, ret;

	/* These operations dereference pwd; refuse before changing anything. */
	if (!pwd &&
	    (flags & (LOGIN_SETGROUP | LOGIN_SETLOGIN | LOGIN_SETPATH))) {
		syslog(LOG_ERR, "%s: setusercontext: no passwd entry",
		    __progname);
		return (-1);
	}

	/* flc is non-NULL only when the class was allocated here. */
	flc = 0;
	if (!lc && !(flc = lc = login_getclass(pwd)))
		return (-1);
	ret = 0;

	if (flags & LOGIN_SETRESOURCES) {
		/* Best effort: gsetrl() logs its own failures. */
		for (i = 0; r_list[i].name; ++i)
			(void)gsetrl(lc, r_list[i].what, r_list[i].name,
			    r_list[i].type);
	}

	if (flags & LOGIN_SETPRIORITY) {
		p = login_getcapnum(lc, "priority", 0, 0);

		if (p < PRIO_MIN) {
			syslog(LOG_ERR, "%s: %s: invalid priority %lld",
			    __progname, lc->lc_class, p);
			p = PRIO_MIN;
		}
		if (p > PRIO_MAX) {
			syslog(LOG_ERR, "%s: %s: invalid priority %lld",
			    __progname, lc->lc_class, p);
			p = PRIO_MAX;
		}

		(void)setpriority(PRIO_PROCESS, 0, (int)p);
	}

	if (flags & LOGIN_SETUMASK) {
		p = login_getcapnum(lc, "umask", LOGIN_DEFUMASK,LOGIN_DEFUMASK);

		if (p & ~ACCESSPERMS) {
			syslog(LOG_ERR, "%s: %s: invalid umask %lld",
			    __progname, lc->lc_class, p);
			p = LOGIN_DEFUMASK;
		}
		(void)umask((mode_t)p);
	}

	if (flags & LOGIN_SETGROUP) {
		if (setgid(pwd->pw_gid) < 0) {
			syslog(LOG_ERR, "%s: setgid(%d): %m", __progname,
			    (int)pwd->pw_gid);
			ret = -2;
			goto done;
		}

		/* Failing here would leave the caller's supplementary groups. */
		if (initgroups(pwd->pw_name, pwd->pw_gid) < 0) {
			syslog(LOG_ERR, "%s: initgroups(%s,%d): %m",
			    __progname, pwd->pw_name, (int)pwd->pw_gid);
			ret = -2;
			goto done;
		}
	}

	if (flags & LOGIN_SETLOGIN) {
		if (setlogin(pwd->pw_name) < 0) {
			syslog(LOG_ERR, "%s: setlogin() failure: %m",
			    __progname);
			ret = -1;
			goto done;
		}
	}

	if (flags & LOGIN_SETUSER) {
		if (setuid(uid) < 0) {
			syslog(LOG_ERR, "%s: setuid(%d): %m", __progname,
			    (int)uid);
			ret = -2;
			goto done;
		}
	}

	if (flags & LOGIN_SETPATH)
		setuserpath(lc, pwd->pw_dir);

done:
	login_close(flc);	/* NULL when lc came from the caller */
	return (ret);
}

/*
 * setuserpath --
 *	Set PATH (without replacing an existing PATH) from lc's "path"
 *	capability: each run of blanks/tabs becomes one ':' and every '~'
 *	is replaced by home.  Falls back to _PATH_DEFPATH when the
 *	capability is absent or memory cannot be allocated.
 */
static void
setuserpath(lc, home)
	login_cap_t *lc;
	char *home;
{
	size_t hlen, plen, cnt;
	char *capval, *path, *p, *q;

	if (home == NULL)
		home = "";
	hlen = strlen(home);

	path = NULL;
	/*
	 * def and e are NULL, so a non-NULL result can only be the cgetstr(3)
	 * buffer, which is owned (and freed) here.
	 */
	if ((capval = login_getcapstr(lc, "path", NULL, NULL)) != NULL) {
		cnt = 0;
		for (p = capval; *p; ++p)
			if (*p == '~')
				++cnt;
		/* Upper bound: '~' becomes hlen bytes, blank runs only shrink. */
		plen = (size_t)(p - capval) + cnt * (hlen + 1) + 1;
		if ((path = malloc(plen)) != NULL) {
			for (p = capval, q = path; *p != '\0'; ) {
				if (*p == ' ' || *p == '\t') {
					*q++ = ':';
					while (*p == ' ' || *p == '\t')
						++p;
				} else if (*p == '~') {
					/* Copy without a NUL so the tail stays visible. */
					memcpy(q, home, hlen);
					q += hlen;
					++p;
				} else
					*q++ = *p++;
			}
			*q = '\0';
		}
		free(capval);
	}
	if (setenv("PATH", path ? path : _PATH_DEFPATH, 0))
		warn("could not set PATH");	/* warn() appends strerror */
	/* setenv() copies its argument. */
	free(path);
}

/*
 * Convert an expression of the following forms
 *	1) A number.
 *	2) A number followed by a b (mult by 512).
 *	3) A number followed by a k (mult by 1024).
 *	4) A number followed by a m (mult by 1024 * 1024).
 *	5) A number followed by a g (mult by 1024 * 1024 * 1024).
 *	6) A number followed by a t (mult by 1024 * 1024 * 1024 * 1024).
 *	7) Two or more numbers (with/without k, b, m, g, or t)
 *	   separated by x (also * for backwards compatibility), specifying
 *	   the product of the indicated values.
 * Suffixes are accepted in either case.  *endptr (if endptr is non-NULL)
 * is left just past the parsed text; an 'x' or '*' not followed by a
 * number is not consumed.  On overflow errno is set to ERANGE and
 * UQUAD_MAX is returned.
 */
static
u_quad_t
strtosize(str, endptr, radix)
	char *str;
	char **endptr;
	int radix;
{
	u_quad_t num, num2;
	char *expr, *expr2;

	errno = 0;
	num = strtouq(str, &expr, radix);
	if (errno || expr == str) {
		if (endptr)
			*endptr = expr;
		return (num);
	}

	switch(*expr) {
	case 'b': case 'B':
		num = multiply(num, (u_quad_t)512);
		++expr;
		break;
	case 'k': case 'K':
		num = multiply(num, (u_quad_t)1024);
		++expr;
		break;
	case 'm': case 'M':
		num = multiply(num, (u_quad_t)1024 * 1024);
		++expr;
		break;
	case 'g': case 'G':
		num = multiply(num, (u_quad_t)1024 * 1024 * 1024);
		++expr;
		break;
	case 't': case 'T':
		/* Two steps of 2^20 to reach 2^40. */
		num = multiply(num, (u_quad_t)1024 * 1024);
		num = multiply(num, (u_quad_t)1024 * 1024);
		++expr;
		break;
	}

	/* multiply() sets errno only on overflow. */
	if (errno)
		goto erange;

	switch(*expr) {
	case '*':			/* Backward compatible. */
	case 'x':
		/* The right-hand side may itself be a product (recursion). */
		num2 = strtosize(expr+1, &expr2, radix);
		if (errno) {
			expr = expr2;
			goto erange;
		}

		/* Nothing parsed after the operator: stop before it. */
		if (expr2 == expr + 1) {
			if (endptr)
				*endptr = expr;
			return (num);
		}
		expr = expr2;
		num = multiply(num, num2);
		if (errno)
			goto erange;
		break;
	}
	if (endptr)
		*endptr = expr;
	return (num);
erange:
	if (endptr)
		*endptr = expr;
	errno = ERANGE;
	return (UQUAD_MAX);
}

/*
 * strtolimit --
 *	Like strtosize(), but "infinity" or "inf" (any case), spanning the
 *	whole string, yields RLIM_INFINITY with *endptr (if endptr is
 *	non-NULL) at the end of str.
 */
static
u_quad_t
strtolimit(str, endptr, radix)
	char *str;
	char **endptr;
	int radix;
{
	int unlimited;

	unlimited = strcasecmp(str, "infinity") == 0 ||
	    strcasecmp(str, "inf") == 0;
	if (!unlimited)
		return (strtosize(str, endptr, radix));

	if (endptr != NULL)
		*endptr = str + strlen(str);
	return ((u_quad_t)RLIM_INFINITY);
}

/*
 * multiply --
 *	Return n1 * n2.  If the product does not fit in a u_quad_t, set
 *	errno to ERANGE and return UQUAD_MAX; otherwise errno is left
 *	untouched.
 */
static u_quad_t
multiply(n1, n2)
	u_quad_t n1;
	u_quad_t n2;
{
	if (n1 == 0 || n2 == 0)
		return (0);

	/*
	 * n2 >= 1 here, so the division is safe.  The product exceeds
	 * UQUAD_MAX exactly when n1 > floor(UQUAD_MAX / n2).
	 */
	if (n1 > UQUAD_MAX / n2) {
		errno = ERANGE;
		return (UQUAD_MAX);
	}
	return (n1 * n2);
}
