/*-
 * Copyright (c) 2015 Taylor R. Campbell
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#if defined(__NetBSD__) && defined(_KERNEL)
#include <sys/types.h>
#include <sys/errno.h>
#include <sys/systm.h>
#else
#include <errno.h>
#include <limits.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#endif

#include "pb.h"
#include "pb_decode.h"

/*
 * Input source for the decoder: every byte of input is obtained by
 * calling (*callback)(arg, buf, &size) -- see pb_decode_buf_partial.
 * Submessages are decoded through a nested struct decode whose
 * callback reads from the enclosing one (see pb_decode_submsg).
 */
struct decode {
	pb_decoder_callback_t		*callback;	/* byte source */
	void				*arg;		/* its first argument */
};

/* Raw byte input through the decoder callback.  */
static int	pb_decode_buf_partial(struct decode *, void *, size_t *);
static int	pb_decode_buf(struct decode *, void *, size_t);
static int	pb_skip(struct decode *, size_t);
static int	pb_decode_1_eof(struct decode *, uint8_t *, bool *);
static int	pb_decode_1(struct decode *, uint8_t *);

/* Message-level decoding and required-field bookkeeping.  */
static int	pb_decode_by_hdr(struct decode *, struct pb_msg_hdr *);
static const struct pb_field *
		pb_find_field(const struct pb_msgdesc *, uint32_t);
static int	pb_decode_check_required(const struct pb_msgdesc *,
		    uint32_t[PB_MAX_REQUIRED_FIELDS], unsigned int);

/* Discarding fields this message descriptor does not know about.  */
static int	pb_skip_field(struct decode *, enum pb_wiretype);
static int	pb_skip_varint(struct decode *);
static int	pb_skip_length_delimited(struct decode *);

/* Field-level decoding.  */
static int	pb_decode_field(struct decode *, unsigned char *,
		    const struct pb_field *, enum pb_wiretype,
		    uint32_t[PB_MAX_REQUIRED_FIELDS], unsigned int *);
static int	pb_decode_field_value(struct decode *,
		    const struct pb_field *, enum pb_wiretype,
		    unsigned char *);
static const struct pb_enumerand *
		pb_enumerand_by_number(const struct pb_enumeration *, int32_t);

/* Tags and varints.  */
static int	pb_decompose_tag(uint64_t, uint32_t *, enum pb_wiretype *);
static int	pb_decode_tag(struct decode *, uint32_t *, enum pb_wiretype *);
static int	pb_decode_varint_eof(struct decode *, uint64_t *, bool *);
static int	pb_decode_varint(struct decode *, uint64_t *);

/* Unsigned integer encodings.  */
static int	pb_decode_varint_u(struct decode *, uint64_t *);
static int	pb_decode_fixed32(struct decode *, uint32_t *);
static int	pb_decode_fixed64(struct decode *, uint64_t *);

/* Signed integer encodings.  */
static int	pb_decode_varint_s(struct decode *, int64_t *);
static int	pb_decode_zigzag(struct decode *, int64_t *);
static int	pb_decode_sfixed32(struct decode *, int32_t *);
static int	pb_decode_sfixed64(struct decode *, int64_t *);

/* Floating-point encodings.  */
static int	pb_decode_ieee32(struct decode *, float *);
static int	pb_decode_ieee64(struct decode *, double *);

/* Length prefixes of length-delimited fields.  */
static int	pb_decode_length(struct decode *, size_t *);

/* Embedded messages.  */
static int	pb_decode_submsg(struct decode *, const struct pb_msgdesc *,
		    struct pb_msg_hdr *, size_t);

/* Sorting of recorded required-field tags.  */
static void	sort32(uint32_t *, size_t);

/*
 * Cursor over a contiguous in-memory input buffer: PTR is the next
 * unread byte and NLEFT the number of bytes remaining from PTR on.
 */
struct decode_memory {
	const uint8_t	*ptr;
	size_t		nleft;
};

static pb_decoder_callback_t decode_memory_callback;

/*
 * decode_memory_callback(cookie, buf, size)
 *
 *	Decoder callback for in-memory input.  Consumes up to *size bytes
 *	from the struct decode_memory at COOKIE, copying them into BUF
 *	unless BUF is NULL, and sets *size to the number of bytes
 *	consumed.  A result shorter than requested means the buffer is
 *	exhausted.  Always returns 0.
 */
static int
decode_memory_callback(void *cookie, void *buf, size_t *size)
{
	struct decode_memory *const M = cookie;
	const size_t n = *size < M->nleft? *size : M->nleft;

	if (n != 0) {
		/*
		 * pb_skip passes a null BUF to discard input; memcpy to a
		 * null pointer is undefined, so only copy when there is a
		 * destination.  (Guarding on n also avoids memcpy/pointer
		 * arithmetic on a null PTR for an empty input buffer.)
		 */
		if (buf != NULL)
			(void)memcpy(buf, M->ptr, n);
		M->ptr += n;
		M->nleft -= n;
	}
	*size = n;

	return 0;
}

int
pb_decode_from_memory(struct pb_msg msg, const void *buf, size_t len)
{
	struct decode_memory M = { .ptr = buf, .nleft = len };

	return pb_decode(msg, &decode_memory_callback, &M);
}

/*
 * pb_decode(msg, callback, arg)
 *
 *	Decode MSG from the byte source (*callback)(arg, ...).  Returns
 *	EINVAL if MSG's header does not belong to MSG's descriptor,
 *	otherwise the result of decoding.
 */
int
pb_decode(struct pb_msg msg, pb_decoder_callback_t *callback, void *arg)
{
	struct pb_msg_hdr *const hdr = (struct pb_msg_hdr *)msg.pbm_ptr;
	struct decode D;

	/* XXX pb_bug: the header must match the descriptor we were given.  */
	if (hdr->pbmh_msgdesc != msg.pbm_msgdesc)
		return EINVAL;

	D.callback = callback;
	D.arg = arg;
	return pb_decode_by_hdr(&D, hdr);
}

/*
 * XXX It would be nice if we could have a pb_decode_ptr(D, &p, size)
 * which would store in p a pointer to a buffer of the requested size,
 * rather than copying it to another buffer, if possible.
 */

static int
pb_decode_buf_partial(struct decode *D, void *buf, size_t *size)
{

	return (*D->callback)(D->arg, buf, size);
}

/*
 * pb_decode_buf(D, buf, size)
 *
 *	Read exactly SIZE bytes from D into BUF.  Returns the callback's
 *	error, or EIO if fewer than SIZE bytes were available.
 */
static int
pb_decode_buf(struct decode *D, void *buf, size_t size)
{
	size_t got = size;
	int error;

	if ((error = pb_decode_buf_partial(D, buf, &got)) != 0)
		return error;

	/* A short transfer means the input ended early.  XXX error code?  */
	return (got == size) ? 0 : EIO;
}

/*
 * pb_skip(D, nbytes)
 *
 *	Consume and discard exactly NBYTES bytes of input from D by
 *	handing the callback a null buffer (the callback presumably must
 *	treat NULL as "discard" -- the in-memory one does).
 */
static int
pb_skip(struct decode *D, size_t nbytes)
{
	void *const discard = NULL;

	return pb_decode_buf(D, discard, nbytes);
}

/*
 * pb_decode_1_eof(D, p, eofp)
 *
 *	Try to read one byte from D into *p.  On success set *eofp to
 *	true if no byte was available (end of input), false otherwise.
 */
static int
pb_decode_1_eof(struct decode *D, uint8_t *p, bool *eofp)
{
	size_t got = 1;
	int error;

	error = pb_decode_buf_partial(D, p, &got);
	if (error == 0) {
		pb_assert(got <= 1);
		*eofp = (got == 0);
	}
	return error;
}

/*
 * pb_decode_1(D, byte)
 *
 *	Read exactly one byte from D into *byte; EIO at end of input.
 */
static int
pb_decode_1(struct decode *D, uint8_t *byte)
{
	return pb_decode_buf(D, byte, sizeof *byte);
}

static int
pb_decode_by_hdr(struct decode *D, struct pb_msg_hdr *msg_hdr)
{
	unsigned char *const addr = (void *)msg_hdr;
	const struct pb_msgdesc *const msgdesc = msg_hdr->pbmh_msgdesc;
	uint32_t tag;
	enum pb_wiretype wiretype;
	const struct pb_field *field;
	uint32_t req_fields[PB_MAX_REQUIRED_FIELDS];
	unsigned int nreq_fields = 0;
	int error;

	while ((error = pb_decode_tag(D, &tag, &wiretype)) == 0) {
		/* Allow for zero-delimited or externally framed messages.  */
		if (tag == 0)
			break;
		field = pb_find_field(msgdesc, tag);
		if (field == NULL)
			error = pb_skip_field(D, wiretype);
		else
			error = pb_decode_field(D, addr, field, wiretype,
			    req_fields, &nreq_fields);
		if (error)
			break;
	}
	if (error)
		return error;

	error = pb_decode_check_required(msgdesc, req_fields, nreq_fields);
	if (error)
		return error;

	return 0;
}

static const struct pb_field *
pb_find_field(const struct pb_msgdesc *msgdesc, uint32_t tag)
{
	size_t start = 0, end = msgdesc->pbmd_nfields;

	while (start < end) {
		const size_t i = (start + ((end - start) / 2));

		if (tag < msgdesc->pbmd_fields[i].pbf_tag)
			end = i;
		else if (tag > msgdesc->pbmd_fields[i].pbf_tag)
			start = (i + 1);
		else
			return &msgdesc->pbmd_fields[i];
	}

	return NULL;
}

/*
 * pb_decode_check_required(msgdesc, req_fields, nreq_fields)
 *
 *	Check that the distinct required-field tags recorded while
 *	decoding, req_fields[0..nreq_fields), are exactly the tags of the
 *	fields of MSGDESC whose quantifier is PBQ_REQUIRED.  Sorts
 *	req_fields in place.  Returns 0 on success, or EIO if a required
 *	field is missing or a recorded tag is not a required field.
 *
 *	Relies on msgdesc->pbmd_fields being sorted by tag (pb_find_field
 *	binary-searches it too) and on pb_decode_field recording each
 *	required tag at most once.
 */
static int
pb_decode_check_required(const struct pb_msgdesc *msgdesc,
    uint32_t req_fields[PB_MAX_REQUIRED_FIELDS], unsigned int nreq_fields)
{
	unsigned int i = 0;
	size_t j;

	sort32(req_fields, nreq_fields);
	for (j = 0; j < msgdesc->pbmd_nfields; j++) {
		if (msgdesc->pbmd_fields[j].pbf_quant != PBQ_REQUIRED)
			continue;
		/*
		 * Walk the descriptor's required fields in tag order in
		 * step with the sorted record.  Any mismatch -- including
		 * running out of recorded tags before the last required
		 * field -- means a required field is missing.
		 */
		if (i == nreq_fields ||
		    req_fields[i] != msgdesc->pbmd_fields[j].pbf_tag)
			return EIO; /* XXX What error code?  */
		i++;
	}
	/* Leftover recorded tags were not required fields of MSGDESC.  */
	if (i != nreq_fields)
		return EIO; /* XXX What error code?  */

	return 0;
}

/*
 * pb_skip_field(D, wiretype)
 *
 *	Consume and discard one value of wire type WIRETYPE from D.
 *	Returns EIO for any wire type not handled below.
 */
static int
pb_skip_field(struct decode *D, enum pb_wiretype wiretype)
{

	switch (wiretype) {
	case PB_WIRETYPE_VARINT:
		return pb_skip_varint(D);
	case PB_WIRETYPE_32BIT:
		return pb_skip(D, 4);
	case PB_WIRETYPE_64BIT:
		return pb_skip(D, 8);
	case PB_WIRETYPE_LENGTH_DELIMITED:
		return pb_skip_length_delimited(D);

	default:
		return EIO;	/* XXX What error code?  */
	}
}

/*
 * pb_skip_varint(D)
 *
 *	Discard bytes from D up to and including the first one whose
 *	continuation bit (0x80) is clear.
 */
static int
pb_skip_varint(struct decode *D)
{
	uint8_t o;
	int error;

	do {
		error = pb_decode_1(D, &o);
		if (error)
			return error;
	} while ((o & 0x80) != 0);

	return 0;
}

/*
 * pb_skip_length_delimited(D)
 *
 *	Read a length prefix from D and discard that many bytes.
 */
static int
pb_skip_length_delimited(struct decode *D)
{
	size_t n;
	int error;

	error = pb_decode_length(D, &n);
	if (error)
		return error;
	return pb_skip(D, n);
}

/*
 * pb_record_required(tag, req_fields, nreq_fields)
 *
 *	Note that required field TAG has been decoded, unless it is
 *	already noted: a field may legitimately occur more than once on
 *	the wire, and recording it twice would both overflow req_fields
 *	and make pb_decode_check_required fail.  Returns 0, or EIO if the
 *	table is unexpectedly full.
 */
static int
pb_record_required(uint32_t tag, uint32_t req_fields[PB_MAX_REQUIRED_FIELDS],
    unsigned int *nreq_fields)
{
	unsigned int k;

	for (k = 0; k < *nreq_fields; k++) {
		if (req_fields[k] == tag)
			return 0;
	}
	if (*nreq_fields >= PB_MAX_REQUIRED_FIELDS)
		return EIO; /* XXX What error code?  */
	req_fields[(*nreq_fields)++] = tag;
	return 0;
}

/*
 * pb_decode_field(D, addr, field, wiretype, req_fields, nreq_fields)
 *
 *	Decode one occurrence of FIELD, encoded with wire type WIRETYPE,
 *	into the message whose storage begins at ADDR.  Required fields
 *	are noted in req_fields for pb_decode_check_required; optional
 *	fields are marked present; repeated fields get a new element
 *	appended, or the value is skipped once the field's nonzero
 *	maximum element count has been reached.
 */
static int
pb_decode_field(struct decode *D, unsigned char *addr,
    const struct pb_field *field, enum pb_wiretype wiretype,
    uint32_t req_fields[PB_MAX_REQUIRED_FIELDS], unsigned int *nreq_fields)
{
	int error;

	switch (field->pbf_quant) {
	case PBQ_REQUIRED:
		error = pb_decode_field_value(D, field, wiretype,
		    (addr + field->pbf_qu.required.offset));
		if (error)
			return error;
		/* Count, not pointer: bump *nreq_fields via the helper.  */
		return pb_record_required(field->pbf_tag, req_fields,
		    nreq_fields);
	case PBQ_OPTIONAL:
		*(bool *)(addr + field->pbf_qu.optional.present_offset) = true;
		return pb_decode_field_value(D, field, wiretype,
		    (addr + field->pbf_qu.optional.value_offset));
	case PBQ_REPEATED: {
		struct pb_repeated *const repeated =
		    (struct pb_repeated *)(addr +
			field->pbf_qu.repeated.hdr_offset);
		unsigned char *ptr;
		const size_t elemsize = pb_type_size(&field->pbf_type);
		size_t i;

		/* A maximum of 0 means unbounded.  */
		if ((0 < field->pbf_qu.repeated.maximum) &&
		    (field->pbf_qu.repeated.maximum <=
			pb_repeated_count(repeated)))
			return pb_skip_field(D, wiretype);
		error = pb_repeated_add(repeated, &i);
		if (error)
			return error;
		/* Reload the element array: pb_repeated_add may move it.  */
		ptr = *(void *const *)(addr +
		    field->pbf_qu.repeated.ptr_offset);
		return pb_decode_field_value(D, field, wiretype,
		    (ptr + (i * elemsize)));
	}
	default:
		return EIO; /* XXX */
	}
}

/*
 * pb_decode_field_value(D, field, wiretype, value)
 *
 *	Decode one value of FIELD's type from D into the storage at
 *	VALUE, after checking that WIRETYPE is the wire type that type
 *	uses (EIO otherwise).  Integers that do not fit the field's C
 *	type are rejected with ERANGE rather than truncated, enum values
 *	must name a known enumerand, and strings must be valid UTF-8.
 */
static int
pb_decode_field_value(struct decode *D, const struct pb_field *field,
    enum pb_wiretype wiretype, unsigned char *value)
{
	int error;

	switch (field->pbf_type.pbt_type) {
		uint32_t u32;
		uint64_t u64;
		int32_t s32;
		int64_t s64;
		float f;
		double d;

/*
 * DECODE_CHECK: require wire type WIRETYPE, decode into VAR with
 * DECODER, run the statement CHECK (which may return an error), then
 * store VAR into *value as TYPE and return success.
 */
#define	DECODE_CHECK(TYPE, WIRETYPE, DECODER, VAR, CHECK) do		      \
	{								      \
		if (wiretype != (WIRETYPE))				      \
			return EIO;	/* XXX What error code?  */	      \
		error = DECODER(D, &(VAR));				      \
		if (error)						      \
			return error;					      \
		CHECK;							      \
		*(TYPE *)value = (VAR);					      \
		return 0;						      \
	} while (0)

#define	DECODE(TYPE, WIRETYPE, DECODER, VAR)				      \
	DECODE_CHECK(TYPE, WIRETYPE, DECODER, VAR, do {} while (0))

#define	DECODE_MAX(TYPE, WIRETYPE, DECODER, VAR, MAXIMUM)		      \
	DECODE_CHECK(TYPE, WIRETYPE, DECODER, VAR, do {			      \
		if ((MAXIMUM) < (VAR))					      \
			return ERANGE; /* XXX What error code?  */	      \
	} while (0))

/* Reject values on either side of [MINIMUM, MAXIMUM].  */
#define	DECODE_MINMAX(TYPE, WIRETYPE, DECODER, VAR, MINIMUM, MAXIMUM)	      \
	DECODE_CHECK(TYPE, WIRETYPE, DECODER, VAR, do {			      \
		if (((VAR) < (MINIMUM)) || ((MAXIMUM) < (VAR)))		      \
			return ERANGE; /* XXX What error code?  */	      \
	} while (0))

	case PB_TYPE_BOOL:
		DECODE_MAX(bool, PB_WIRETYPE_VARINT, pb_decode_varint_u, u64,
		    1);
	case PB_TYPE_UINT32:
		DECODE_MAX(uint32_t, PB_WIRETYPE_VARINT, pb_decode_varint_u,
		    u64, UINT32_MAX);
	case PB_TYPE_UINT64:
		DECODE(uint64_t, PB_WIRETYPE_VARINT, pb_decode_varint_u, u64);
	case PB_TYPE_FIXED32:
		DECODE(uint32_t, PB_WIRETYPE_32BIT, pb_decode_fixed32, u32);
	case PB_TYPE_FIXED64:
		DECODE(uint64_t, PB_WIRETYPE_64BIT, pb_decode_fixed64, u64);
	case PB_TYPE_INT32:
		DECODE_MINMAX(int32_t, PB_WIRETYPE_VARINT, pb_decode_varint_s,
		    s64, INT32_MIN, INT32_MAX);
	case PB_TYPE_INT64:
		DECODE(int64_t, PB_WIRETYPE_VARINT, pb_decode_varint_s, s64);
	case PB_TYPE_SINT32:
		DECODE_MINMAX(int32_t, PB_WIRETYPE_VARINT, pb_decode_zigzag,
		    s64, INT32_MIN, INT32_MAX);
	case PB_TYPE_SINT64:
		DECODE(int64_t, PB_WIRETYPE_VARINT, pb_decode_zigzag, s64);
	case PB_TYPE_SFIXED32:
		DECODE(int32_t, PB_WIRETYPE_32BIT, pb_decode_sfixed32, s32);
	case PB_TYPE_SFIXED64:
		DECODE(int64_t, PB_WIRETYPE_64BIT, pb_decode_sfixed64, s64);
	case PB_TYPE_ENUM:
		DECODE_CHECK(int32_t, PB_WIRETYPE_VARINT, pb_decode_varint_s,
		    s64, do {
			    const struct pb_type *const type =
				&field->pbf_type;
			    const struct pb_enumeration *const enumeration =
				type->pbt_u.enumerated.enumeration;

			    if ((s64 < INT32_MIN) || (INT32_MAX < s64))
				    return ERANGE; /* XXX What error code?  */
			    if (pb_enumerand_by_number(enumeration, s64) ==
				NULL)
				    return EIO; /* XXX What error code?  */
		    } while (0));
	case PB_TYPE_FLOAT:
		DECODE(float, PB_WIRETYPE_32BIT, pb_decode_ieee32, f);
	case PB_TYPE_DOUBLE:
		DECODE(double, PB_WIRETYPE_64BIT, pb_decode_ieee64, d);
	case PB_TYPE_BYTES: {
		struct pb_bytes *const bytes = (struct pb_bytes *)value;
		size_t size, tsize pb_attr_diagused;
		uint8_t *ptr;

		if (wiretype != PB_WIRETYPE_LENGTH_DELIMITED)
			return EIO; /* XXX What error code?  */
		error = pb_decode_length(D, &size);
		if (error)
			return error;
		error = pb_bytes_alloc(bytes, size);
		if (error)
			return error;
		ptr = pb_bytes_ptr_mutable(bytes, &tsize);
		pb_assert(tsize == size);
		return pb_decode_buf(D, ptr, size);
	}
	case PB_TYPE_STRING: {
		struct pb_string *const string = (struct pb_string *)value;
		size_t len;
		char *ptr;

		if (wiretype != PB_WIRETYPE_LENGTH_DELIMITED)
			return EIO; /* XXX What error code?  */
		error = pb_decode_length(D, &len);
		if (error)
			return error;
		error = pb_string_alloc(string, len);
		if (error)
			return error;
		pb_assert(pb_string_len(string) == len);
		ptr = pb_string_ptr_mutable(string);
		pb_assert(ptr[len] == '\0');
		error = pb_decode_buf(D, ptr, len);
		if (error)
			return error;
		error = pb_utf8_validate(ptr, len);
		if (error) {
			(void)memset(ptr, 0, len); /* paranoia */
			pb_string_set_ptr(string, "", 0);
			return error;
		}
		return 0;
	}
	case PB_TYPE_MSG: {
		struct pb_msg_hdr *msg_hdr = (struct pb_msg_hdr *)value;
		size_t size;

		if (wiretype != PB_WIRETYPE_LENGTH_DELIMITED)
			return EIO; /* XXX What error code?  */
		error = pb_decode_length(D, &size);
		if (error)
			return error;
		return pb_decode_submsg(D, field->pbf_type.pbt_u.msg.msgdesc,
		    msg_hdr, size);
	}
	default:
		return EIO;	/* XXX What error code?  */
	}

/* The helper macros are private to this function.  */
#undef	DECODE_MINMAX
#undef	DECODE_MAX
#undef	DECODE
#undef	DECODE_CHECK
}

static const struct pb_enumerand *
pb_enumerand_by_number(const struct pb_enumeration *enumeration,
    int32_t number)
{
	size_t start = 0, end = enumeration->pben_nenumerands;

	while (start < end) {
		const size_t i = (start + ((end - start) / 2));

		if (number < enumeration->pben_enumerands[i].pbed_number)
			end = i;
		else if (number > enumeration->pben_enumerands[i].pbed_number)
			start = (i + 1);
		else
			return &enumeration->pben_enumerands[i];
	}

	return NULL;
}

/*
 * Tags and varints
 */

static int
pb_decompose_tag(uint64_t wtag, uint32_t *tag, enum pb_wiretype *wiretype)
{

	if ((wtag >> 3) & ~(uint64_t)0xffffffff)
		return EIO;	/* XXX What error code?  */
	*tag = ((wtag >> 3) & 0xffffffff);
	*wiretype = (wtag & 7);
	return 0;
}

/*
 * pb_decode_tag(D, tag, wiretype)
 *
 *	Read the next field key from D.  At a clean end of input, set
 *	*tag to 0 (leaving *wiretype unset) and succeed.
 */
static int
pb_decode_tag(struct decode *D, uint32_t *tag, enum pb_wiretype *wiretype)
{
	uint64_t key;
	bool at_end;
	int error;

	if ((error = pb_decode_varint_eof(D, &key, &at_end)) != 0)
		return error;
	if (!at_end)
		return pb_decompose_tag(key, tag, wiretype);

	*tag = 0;
	return 0;
}

/*
 * pb_decode_varint_eof(D, value, eofp)
 *
 *	Decode a base-128 varint (least significant 7-bit group first,
 *	0x80 = more groups follow) from D into *value.  If D is at end of
 *	input before the first byte, set *eofp and succeed without
 *	touching *value; end of input mid-varint is an error.  Returns
 *	ERANGE if the value does not fit in 64 bits.
 */
static int
pb_decode_varint_eof(struct decode *D, uint64_t *value, bool *eofp)
{
	uint8_t o;
	uint64_t v;
	unsigned int s = 0;
	int error;

	error = pb_decode_1_eof(D, &o, eofp);
	if (error)
		return error;
	if (*eofp)
		return 0;
	if ((o & 0x80) == 0) {
		*value = o;
		return 0;
	}

	v = (o & 0x7f);
	do {
		s += 7;
		/*
		 * A uint64_t holds at most ten 7-bit groups (shifts 0, 7,
		 * ..., 63); a continuation past the tenth is too long.
		 * (Negative int32/int64/enum values always use ten bytes.)
		 */
		if (s >= 64)
			return ERANGE; /* XXX What error code?  */
		error = pb_decode_1(D, &o);
		if (error)
			return error;
		/* The tenth group has room for only bit 63.  */
		if (s == 63 && (o & 0x7e) != 0)
			return ERANGE; /* XXX What error code?  */
		v |= (uint64_t)(o & 0x7f) << s;
	} while ((o & 0x80) != 0);

	*value = v;
	return 0;
}

/*
 * pb_decode_varint(D, value)
 *
 *	Decode a varint from D into *value; end of input is an error.
 */
static int
pb_decode_varint(struct decode *D, uint64_t *value)
{
	bool at_end;
	int error;

	error = pb_decode_varint_eof(D, value, &at_end);
	if (error == 0 && at_end)
		error = EIO;	/* XXX What error code?  */
	return error;
}

/*
 * Unsigned integer formats
 */

/*
 * pb_decode_varint_u(D, valuep)
 *
 *	Decode an unsigned varint (uint32/uint64/bool wire form) from D.
 */
static int
pb_decode_varint_u(struct decode *D, uint64_t *valuep)
{
	/* An unsigned varint is just the raw base-128 value.  */
	return pb_decode_varint(D, valuep);
}

/*
 * pb_decode_fixed32(D, p)
 *
 *	Read four bytes from D and store them in *p as a little-endian
 *	32-bit unsigned integer.
 */
static int
pb_decode_fixed32(struct decode *D, uint32_t *p)
{
	uint8_t le[4];
	uint32_t acc = 0;
	size_t k;
	int error;

	if ((error = pb_decode_buf(D, le, sizeof le)) != 0)
		return error;

	/* Fold from the most significant (last) byte downward.  */
	for (k = sizeof le; k-- > 0;)
		acc = (acc << 8) | le[k];
	*p = acc;
	return 0;
}

/*
 * pb_decode_fixed64(D, p)
 *
 *	Read eight bytes from D and store them in *p as a little-endian
 *	64-bit unsigned integer.
 */
static int
pb_decode_fixed64(struct decode *D, uint64_t *p)
{
	uint8_t le[8];
	uint64_t acc = 0;
	size_t k;
	int error;

	if ((error = pb_decode_buf(D, le, sizeof le)) != 0)
		return error;

	/* Fold from the most significant (last) byte downward.  */
	for (k = sizeof le; k-- > 0;)
		acc = (acc << 8) | le[k];
	*p = acc;
	return 0;
}

/*
 * Signed integer formats
 *
 * XXX These assume two's-complement arithmetic.
 */

/*
 * pb_decode_varint_s(D, p)
 *
 *	Decode a plain (non-ZigZag) signed varint from D: the 64 bits
 *	are reinterpreted as a two's-complement int64_t.
 */
static int
pb_decode_varint_s(struct decode *D, int64_t *p)
{
	uint64_t raw;
	int error;

	error = pb_decode_varint_u(D, &raw);
	if (error == 0)
		*p = (int64_t)raw;
	return error;
}

/*
 * pb_decode_zigzag(D, p)
 *
 *	Decode a ZigZag-encoded varint (sint32/sint64 wire form) from D
 *	into *p.  ZigZag maps 0, 1, 2, 3, ... to 0, -1, 1, -2, ...: the
 *	low bit is the sign and the remaining bits the magnitude, offset
 *	by one for negative values.
 */
static int
pb_decode_zigzag(struct decode *D, int64_t *p)
{
	uint64_t u;
	int64_t half;
	int error;

	error = pb_decode_varint_u(D, &u);
	if (error)
		return error;

	/*
	 * u >> 1 is at most INT64_MAX, so neither the conversion nor
	 * -half - 1 can overflow (the most negative result is
	 * -INT64_MAX - 1 == INT64_MIN).
	 */
	half = (int64_t)(u >> 1);
	*p = (u & 1) ? -half - 1 : half;
	return 0;
}

/*
 * pb_decode_sfixed32(D, p)
 *
 *	Decode a little-endian 32-bit two's-complement integer from D.
 */
static int
pb_decode_sfixed32(struct decode *D, int32_t *p)
{
	uint32_t bits;
	int error;

	error = pb_decode_fixed32(D, &bits);
	if (error == 0)
		*p = (int32_t)bits;
	return error;
}

/*
 * pb_decode_sfixed64(D, p)
 *
 *	Decode a little-endian 64-bit two's-complement integer from D.
 */
static int
pb_decode_sfixed64(struct decode *D, int64_t *p)
{
	uint64_t bits;
	int error;

	error = pb_decode_fixed64(D, &bits);
	if (error == 0)
		*p = (int64_t)bits;
	return error;
}

/*
 * pb_decode_ieee32(D, p)
 *
 *	Decode a float from D: four little-endian bytes reinterpreted as
 *	the bits of a float (assumes float is 32-bit IEEE 754).
 */
static int
pb_decode_ieee32(struct decode *D, float *p)
{
	union { uint32_t bits; float value; } pun;
	int error;

	error = pb_decode_fixed32(D, &pun.bits);
	if (error == 0)
		*p = pun.value;
	return error;
}

/*
 * pb_decode_ieee64(D, p)
 *
 *	Decode a double from D: eight little-endian bytes reinterpreted
 *	as the bits of a double (assumes double is 64-bit IEEE 754).
 */
static int
pb_decode_ieee64(struct decode *D, double *p)
{
	union { uint64_t bits; double value; } pun;
	int error;

	error = pb_decode_fixed64(D, &pun.bits);
	if (error == 0)
		*p = pun.value;
	return error;
}

/*
 * Length-delimited fields
 */

/*
 * pb_decode_length(D, p)
 *
 *	Decode the varint length prefix of a length-delimited field.
 *	Returns ERANGE if it cannot be represented as a size_t.
 */
static int
pb_decode_length(struct decode *D, size_t *p)
{
	uint64_t len64;
	int error;

	if ((error = pb_decode_varint(D, &len64)) != 0)
		return error;
	if (len64 > SIZE_MAX)
		return ERANGE;	/* XXX What error code?  */

	*p = (size_t)len64;
	return 0;
}

/*
 * Submessages
 */

struct submsg {
	size_t		sm_size;
	struct decode	*sm_D;
};

static pb_decoder_callback_t decode_submsg_callback;

static int
pb_decode_submsg(struct decode *D, const struct pb_msgdesc *msgdesc,
    struct pb_msg_hdr *msg_hdr, size_t size)
{
	struct submsg submsg = {
		.sm_size = size,
		.sm_D = D,
	};
	struct decode Dsub = {
		.callback = &decode_submsg_callback,
		.arg = &submsg,
	};

	if (msg_hdr->pbmh_msgdesc != msgdesc)
		return EINVAL;
	return pb_decode_by_hdr(&Dsub, msg_hdr);
}

/*
 * decode_submsg_callback(arg, buf, n)
 *
 *	Decoder callback for a submessage (ARG is a struct submsg).
 *	Forwards the read to the enclosing decoder but clamps the request
 *	to the bytes left in the submessage, so the submessage decoder
 *	sees end of input at the submessage boundary.  Sets *n to the
 *	number of bytes supplied and charges them against sm_size.
 */
static int
decode_submsg_callback(void *arg, void *buf, size_t *n)
{
	struct submsg *const submsg = arg;
	size_t limit, n0;
	int error;

	/* Computed locally: a file-wide MIN macro can clash with <sys/param.h>.  */
	limit = (*n < submsg->sm_size? *n : submsg->sm_size);
	n0 = limit;
	error = pb_decode_buf_partial(submsg->sm_D, buf, &n0);
	if (error)
		return error;
	/* Use the project's assertion, as elsewhere in this file.  */
	pb_assert(n0 <= limit);
	*n = n0;
	submsg->sm_size -= n0;

	return 0;
}

/*
 * Trivial heap sort
 */

/* Index of the left child of binary-heap node I; the right is one past. */
static size_t
heap_child(size_t i)
{
	return 2*i + 1;
}

/*
 * sift_down32(a, root, last)
 *
 *	Restore the max-heap property of the subtree of a[0..last] (LAST
 *	inclusive) rooted at ROOT, assuming both child subtrees are
 *	already heaps.
 *
 *	XXX Index arithmetic is not overflow-checked; that is fine for
 *	the small arrays sorted here but not for arbitrary sizes.
 */
static void
sift_down32(uint32_t *a, size_t root, size_t last)
{
	size_t child;

	for (child = heap_child(root); child <= last;
	     child = heap_child(root)) {
		size_t big = root;
		uint32_t tmp;

		if (a[big] < a[child])
			big = child;
		if (child + 1 <= last && a[big] < a[child + 1])
			big = child + 1;
		if (big == root)
			return;

		tmp = a[root];
		a[root] = a[big];
		a[big] = tmp;
		root = big;
	}
}

/*
 * sort32(a, n)
 *
 *	Sort a[0..n) into ascending order in place (heap sort).
 */
static void
sort32(uint32_t *a, size_t n)
{
	size_t last, i;

	if (n < 2)
		return;
	last = n - 1;

	/* Build a max-heap by sifting down every internal node, bottom up. */
	for (i = (last - 1)/2 + 1; i-- > 0;)
		sift_down32(a, i, last);

	/* Move the current maximum to the end, then shrink the heap.  */
	for (; last > 0; last--) {
		uint32_t tmp = a[0];

		a[0] = a[last];
		a[last] = tmp;
		sift_down32(a, 0, last - 1);
	}
}
