/*-
 * Copyright (c) 2010 Joerg Sonnenberger <joerg@NetBSD.org>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in
 *    the documentation and/or other materials provided with the
 *    distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
 * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
 * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
 * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

/*
 * Derived from the public domain code of Daniel J. Bernstein.
 */

#include <sys/cdefs.h>
__RCSID("$NetBSD$");

//#include "namespace.h"
//#include "reentrant.h"

#include <inttypes.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <sys/sysctl.h>

/*
 * Number of 64-byte keystream blocks generated between automatic
 * reseeds from the kernel (256 Ki blocks, i.e. 16 MiB of output).
 * Only read, never modified, hence const.
 */
static const size_t reseed_volume = 256 * 1024;
#ifdef _REENTRANT
static mutex_t mutex = MUTEX_INITIALIZER;
#define	LOCK()	mutex_lock(&mutex)
#define	UNLOCK()	mutex_unlock(&mutex)
#else
#define	LOCK()	do {} while (/* CONSTCOND */0)
#define	UNLOCK()	do {} while (/* CONSTCOND */0)
#endif

/*
 * Generator state.
 *
 * init_done:	the one-time fork-handler registration has been done.
 * cur_output:	offset of the next unused byte in output; 64 means the
 *		buffered block is exhausted.
 * cur_volume:	blocks generated since the last reseed; SIZE_MAX forces
 *		a reseed before the next block is generated.
 * input:	the Salsa20 input matrix (constants, key, nonce, counter).
 * output:	the current 64-byte keystream block, viewable as 16 words
 *		or 64 bytes.
 */
static struct state {
	bool init_done;
	size_t cur_output;
	size_t cur_volume;

	uint32_t input[16];
	union {
		uint32_t out32[16];
		uint8_t out8[64];
	} output;
} global_state = { .init_done = false, .cur_volume = SIZE_MAX, .cur_output = 64 };

#ifdef _REENTRANT
/*
 * pthread_atfork() "prepare" hook: take the generator lock before
 * fork() so that no thread is in the middle of updating the state
 * when the address space is copied.
 */
static void
arc4random_fork_prepare(void)
{

	LOCK();
}

/*
 * pthread_atfork() "parent" hook: the parent just drops the lock it
 * took in arc4random_fork_prepare() and carries on as before.
 */
static void
arc4random_fork_parent(void)
{

	UNLOCK();
}
#endif /* _REENTRANT */

static void
arc4random_fork_child(void)
{
	global_state.cur_volume = SIZE_MAX;
	global_state.cur_output = 64;
	UNLOCK();
}

/*
 * init:
 *	One-time setup, performed lazily before the first stir.  It
 *	registers the fork handlers, whose child side forces a reseed in
 *	the child, and records that this has been done.
 *	The return value of pthread_atfork() is not checked.
 */
static void
init(struct state *state)
{
#ifndef _REENTRANT
	pthread_atfork(NULL, NULL, arc4random_fork_child);
#else
	pthread_atfork(arc4random_fork_prepare, arc4random_fork_parent,
	    arc4random_fork_child);
#endif
	state->init_done = true;
}

/*
 * rekey:
 *	Load a new Salsa20 input matrix from ten random words.
 *	key[0..3] and key[6..9] fill the eight 256-bit key slots, and
 *	key[4..5] fill the two nonce slots.  The block counter
 *	(input[8] low, input[9] high) restarts at zero.  The diagonal
 *	words hold the "expand 32-byte k" constants.
 */
static void
rekey(struct state *state, const uint32_t key[10])
{
	/* "expa", "nd 3", "2-by", "te k" as little-endian words. */
	static const uint32_t magic[] = { 0x61707865, 0x3320646e,
	    0x79622d32, 0x6b206574 };

	state->input[0] = magic[0];
	state->input[1] = key[0];
	state->input[2] = key[1];
	state->input[3] = key[2];
	state->input[4] = key[3];
	state->input[5] = magic[1];
	state->input[6] = key[4];
	state->input[7] = key[5];
	state->input[8] = 0;
	state->input[9] = 0;
	state->input[10] = magic[2];
	state->input[11] = key[6];
	state->input[12] = key[7];
	state->input[13] = key[8];
	state->input[14] = key[9];
	state->input[15] = magic[3];
}

/*
 * stir:
 *	Fetch ten fresh 32-bit words from the kernel via the
 *	CTL_KERN/KERN_ARND sysctl, install them as the new key and nonce
 *	with rekey(), and restart the reseed volume count.
 */
static void
stir(struct state *state)
{
	static const int mib[2] = { CTL_KERN, KERN_ARND };
	uint32_t seed[10];
	size_t seedlen = sizeof(seed);

	/*
	 * The sysctl result is deliberately ignored: there is nothing
	 * sensible to do here if it fails.
	 * NOTE(review): on failure seed[] may be partly or wholly
	 * unfilled stack contents.
	 */
	sysctl(mib, 2, seed, &seedlen, NULL, 0);

	rekey(state, seed);
	state->cur_volume = 0;
}

/*
 * Rotate the 32-bit word v left by c bits.  Callers only pass
 * 0 < c < 32, so neither shift count reaches the word width.
 */
static __inline uint32_t
rotate(uint32_t v, int c)
{
	uint32_t hi = v << c;
	uint32_t lo = v >> (32 - c);

	return hi | lo;
}

/*
 * One Salsa20 quarter-round on words a, b, c, d of the working array.
 * In every call below, a is the diagonal word (0, 5, 10 or 15), which
 * is updated last.
 */
static __inline void
salsa20_quarterround(uint32_t x[16], int a, int b, int c, int d)
{
	x[b] ^= rotate(x[a] + x[d], 7);
	x[c] ^= rotate(x[b] + x[a], 9);
	x[d] ^= rotate(x[c] + x[b], 13);
	x[a] ^= rotate(x[d] + x[c], 18);
}

/*
 * extract_block:
 *	Produce the next 64-byte keystream block into state->output and
 *	rewind cur_output to 0.
 *
 *	If the reseed budget (reseed_volume blocks) is exhausted, first
 *	reseed via stir(), doing the one-time init() if necessary.  The
 *	block is the Salsa20/8 core: 8 rounds (4 column/row pairs) over a
 *	copy of state->input, followed by the feed-forward addition of
 *	the input.  The 64-bit counter in input[8..9] then advances.
 */
static void
extract_block(struct state *state)
{
	uint32_t w[16];
	int n;
	size_t k;

	if (__predict_false(state->cur_volume++ >= reseed_volume)) {
		if (__predict_false(!state->init_done))
			init(state);
		stir(state);
	}

	memcpy(w, state->input, sizeof(w));

	for (n = 0; n < 8; n += 2) {
		/* Column round. */
		salsa20_quarterround(w, 0, 4, 8, 12);
		salsa20_quarterround(w, 5, 9, 13, 1);
		salsa20_quarterround(w, 10, 14, 2, 6);
		salsa20_quarterround(w, 15, 3, 7, 11);
		/* Row round. */
		salsa20_quarterround(w, 0, 1, 2, 3);
		salsa20_quarterround(w, 5, 6, 7, 4);
		salsa20_quarterround(w, 10, 11, 8, 9);
		salsa20_quarterround(w, 15, 12, 13, 14);
	}

	for (k = 0; k < 16; k++)
		state->output.out32[k] = w[k] + state->input[k];

	/* Low counter word wrapped: carry into the high word. */
	if (++state->input[8] == 0)
		state->input[9]++;

	state->cur_output = 0;
}

/*
 * arc4random_stir:
 *	Reseed the generator from the kernel immediately, performing the
 *	one-time init() first if it has not run yet.
 *
 *	Any keystream still buffered from the previous key is discarded
 *	(cur_output = 64).  Without that, up to 60 bytes derived from the
 *	old key would still be handed out after an explicit stir.
 */
void
arc4random_stir(void)
{
	struct state *state = &global_state;

	LOCK();
	if (__predict_false(!state->init_done))
		init(state);

	stir(state);
	/* The next request must come from a block made with the new key. */
	state->cur_output = 64;
	UNLOCK();
}

/*
 * arc4random_addrandom:
 *	Accepts caller-supplied data and ignores it; the generator is
 *	seeded only from the kernel in stir().  Presumably kept for
 *	source compatibility with the older interface.
 */
void
arc4random_addrandom(uint8_t *data, int len)
{

	(void)data;
	(void)len;
}

#ifdef __weak_alias
__weak_alias(arc4random,_arc4random)
#endif

/*
 * arc4random:
 *	Return 32 random bits as the next whole word of the buffered
 *	keystream block, generating a new block when none is left.
 */
uint32_t
arc4random(void)
{
	struct state *state = &global_state;
	uint32_t val;

	LOCK();

	/*
	 * arc4random_buf() may leave cur_output unaligned, and every byte
	 * below cur_output has already been handed out.  Round up to the
	 * next word boundary *before* reading, so the returned word never
	 * shares bytes with earlier output.
	 */
	state->cur_output = (state->cur_output + 3) & ~(size_t)3;
	if (__predict_false(state->cur_output > 60))
		extract_block(state);
	val = state->output.out32[state->cur_output >> 2];
	state->cur_output += 4;

	UNLOCK();

	return val;
}

/*
 * arc4random_buf:
 *	Fill buf_[0 .. len-1] with random bytes.
 *
 *	Bytes are copied from the buffered keystream block starting at
 *	cur_output.  Whenever the block is used up (cur_output == 64), a
 *	new one is generated.  cur_output always advances by exactly the
 *	number of bytes copied, so no keystream byte is handed out twice
 *	and no read goes past the 64-byte block.
 */
void
arc4random_buf(void *buf_, size_t len)
{
	struct state *state = &global_state;
	uint8_t *buf = buf_;
	size_t avail, chunk;

	if (len == 0)
		return;

	LOCK();

	while (len > 0) {
		if (state->cur_output >= 64)
			extract_block(state);	/* resets cur_output to 0 */
		avail = 64 - state->cur_output;
		chunk = len < avail ? len : avail;
		memcpy(buf, state->output.out8 + state->cur_output, chunk);
		buf += chunk;
		len -= chunk;
		state->cur_output += chunk;
	}

	UNLOCK();
}

/*
 * arc4random_uniform:
 *	Return a uniformly distributed value in [0, limit); returns 0 when
 *	limit < 2.
 *
 *	Uses rejection sampling.  A draw val lies in the bucket
 *	[val - rem, val - rem + limit).  Draws in the top bucket (the one
 *	containing UINT32_MAX, which may be incomplete) are rejected, so
 *	every accepted bucket is complete and val % limit is unbiased.
 */
uint32_t
arc4random_uniform(uint32_t limit)
{
	struct state *state = &global_state;
	uint32_t val, rem;

	if (limit < 2)
		return 0;

	LOCK();

	do {
		/*
		 * Round up before reading so the word never contains bytes
		 * already returned by arc4random_buf().
		 */
		state->cur_output = (state->cur_output + 3) & ~(size_t)3;
		if (__predict_false(state->cur_output > 60))
			extract_block(state);
		val = state->output.out32[state->cur_output >> 2];
		state->cur_output += 4;
		rem = val % limit;
	} while (val - rem > UINT32_MAX - limit);

	UNLOCK();

	return rem;
}