/*	$NetBSD$	*/

/*-
 * Copyright (c) 2014 The NetBSD Foundation, Inc.
 * All rights reserved.
 *
 * This code is derived from software contributed to The NetBSD Foundation
 * by Taylor R. Campbell.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

/*
 * Legacy arc4random(3) API from OpenBSD, reimplemented to use the
 * ChaCha stream cipher, which in 2014 remains unbroken with a high
 * security margin (the only known attack better than brute force
 * covers just 7 of the 20 rounds), unlike RC4, which is completely
 * broken.
 *
 * For a 256-bit key k and 128-bit input i, ChaCha_k(i) yields a
 * 512-bit output.  ChaCha is conjectured to be a PRF -- that is, if K
 * is a random variable in {0,1}^256 with uniform distribution, then
 * the ChaCha_K random variable is computationally indistinguishable
 * from a random variable in functions {0,1}^128 --> {0,1}^512 with
 * uniform distribution.  Computing ChaCha_k(i) for any k, i takes
 * about 300 cycles on an Intel Ivy Bridge CPU with naive C code.
 *
 * The arc4random(3) state is a 256-bit key from sysctl(KERN_URND), a
 * 128-bit nonce, and a 512-bit output buffer aligned to 32-bit words.
 * Whenever the buffer is exhausted we refill it with a single ChaCha
 * output and increment the 128-bit nonce, which serves as a block
 * counter.  Long requests are served by generating a key from the
 * buffer and producing a new stream from that key.
 *
 * The state is global to a process and protected by a mutex (if
 * _REENTRANT is defined).
 *
 * Before fork, we zero the CPRNG state.  That way, if the child drops
 * privileges, it won't be able to see the parent's secrets.  The next
 * time either process uses the API it will request a new key from the
 * kernel.
 */

#include <sys/cdefs.h>
__RCSID("$NetBSD$");

#include "namespace.h"
#include "reentrant.h"

#include <sys/bitops.h>
#include <sys/sysctl.h>

#include <assert.h>
#include <sha2.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/*
 * When the toolchain provides __weak_alias, declare each public name
 * as a weak alias for its underscore-prefixed counterpart.
 */
#ifdef __weak_alias
__weak_alias(arc4random,_arc4random)
__weak_alias(arc4random_addrandom,_arc4random_addrandom)
__weak_alias(arc4random_buf,_arc4random_buf)
__weak_alias(arc4random_stir,_arc4random_stir)
__weak_alias(arc4random_uniform,_arc4random_uniform)
#endif

/* ChaCha core */

/* Sizes of the crypto_core arguments, counted in 32-bit words.  */
#define	crypto_core_OUTPUTWORDS	16
#define	crypto_core_INPUTWORDS	4
#define	crypto_core_KEYWORDS	8
#define	crypto_core_CONSTWORDS	4

/* Number of rounds; crypto_core performs two per loop iteration.  */
#define	crypto_core_ROUNDS	20

/*
 * rotate(u, c)
 *
 *	Rotate the 32-bit word u left by c bits.  Callers must pass
 *	0 < c < 32: the right shift by (32 - c) is undefined for c == 0.
 */
static uint32_t
rotate(uint32_t u, unsigned c)
{
	const uint32_t high = u << c;
	const uint32_t low = u >> (32 - c);

	return high | low;
}

/*
 * One quarter-round on four 32-bit words: four add/xor/rotate steps
 * using left rotations by 16, 12, 8, and 7 bits.  The arguments are
 * updated in place and each is evaluated more than once, so they must
 * be lvalues without side effects.
 */
#define	QUARTERROUND(a, b, c, d) do {					      \
	(a) += (b); (d) ^= (a); (d) = rotate((d), 16);			      \
	(c) += (d); (b) ^= (c); (b) = rotate((b), 12);			      \
	(a) += (b); (d) ^= (a); (d) = rotate((d),  8);			      \
	(c) += (d); (b) ^= (c); (b) = rotate((b),  7);			      \
} while (0)

static void
crypto_core(uint32_t *out, const uint32_t *in, const uint32_t *k,
    const uint32_t *c)
{
	uint32_t x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15;
	uint32_t j0,j1,j2,j3,j4,j5,j6,j7,j8,j9,j10,j11,j12,j13,j14,j15;
	int i;

	j0 = x0 = c[0];
	j1 = x1 = c[1];
	j2 = x2 = c[2];
	j3 = x3 = c[3];
	j4 = x4 = k[0];
	j5 = x5 = k[1];
	j6 = x6 = k[2];
	j7 = x7 = k[3];
	j8 = x8 = k[4];
	j9 = x9 = k[5];
	j10 = x10 = k[6];
	j11 = x11 = k[7];
	j12 = x12 = in[0];
	j13 = x13 = in[1];
	j14 = x14 = in[2];
	j15 = x15 = in[3];

	for (i = crypto_core_ROUNDS; i > 0; i -= 2) {
		QUARTERROUND( x0, x4, x8,x12);
		QUARTERROUND( x1, x5, x9,x13);
		QUARTERROUND( x2, x6,x10,x14);
		QUARTERROUND( x3, x7,x11,x15);
		QUARTERROUND( x0, x5,x10,x15);
		QUARTERROUND( x1, x6,x11,x12);
		QUARTERROUND( x2, x7, x8,x13);
		QUARTERROUND( x3, x4, x9,x14);
	}

	out[0] = x0 + j0;
	out[1] = x1 + j1;
	out[2] = x2 + j2;
	out[3] = x3 + j3;
	out[4] = x4 + j4;
	out[5] = x5 + j5;
	out[6] = x6 + j6;
	out[7] = x7 + j7;
	out[8] = x8 + j8;
	out[9] = x9 + j9;
	out[10] = x10 + j10;
	out[11] = x11 + j11;
	out[12] = x12 + j12;
	out[13] = x13 + j13;
	out[14] = x14 + j14;
	out[15] = x15 + j15;
}

/* CPRNG algorithm */

/* Size of a seed in bytes: exactly one crypto_core key.  */
#define	CPRNG_SEED_BYTES	(crypto_core_KEYWORDS * sizeof(uint32_t))
/* Largest request cprng() serves: one crypto_core output, in bytes.  */
#define	CPRNG_SHORT_REQ		(crypto_core_OUTPUTWORDS * sizeof(uint32_t))

/*
 * `expand 32-byte k', stored as four 32-bit words (each the
 * little-endian decoding of four characters); passed to crypto_core
 * as its constant argument.
 */
static const uint32_t sigma[4] = {
	0x61707865U, 0x3320646eU, 0x79622d32U, 0x6b206574U,
};

/*
 * nonce_inc(n)
 *
 *	Add one to the nonce, treated as a multiword integer whose
 *	least significant 32-bit word is n[0].  Carries propagate
 *	toward n[crypto_core_INPUTWORDS - 1]; a carry out of the last
 *	word is dropped.
 */
static void
nonce_inc(uint32_t n[crypto_core_INPUTWORDS])
{
	uint64_t carry = 1;
	uint32_t *word;

	for (word = n; word != n + crypto_core_INPUTWORDS; word++) {
		carry += *word;
		*word = (uint32_t)carry;	/* keep the low 32 bits */
		carry >>= 32;			/* pass on the overflow */
	}

	/*
	 * If the nonce overflows, you counted sequentially to 2^128.
	 * If you count once per femtosecond, it will take you about
	 * ten quadrillion years to manage this.  Don't worry about it.
	 */
}

/*
 * State of one generator: the key and nonce fed to crypto_core, the
 * most recent output block, and how many of its words are still
 * unused.  Unused words are handed out from the front of the block,
 * so they always occupy the last `buffered' slots of the buffer.
 */
struct cprng {
	uint32_t	key[crypto_core_KEYWORDS];
	uint32_t	nonce[crypto_core_INPUTWORDS];
	uint32_t	buffer[crypto_core_OUTPUTWORDS];
	unsigned	buffered;	/* unused words left in buffer */
};

/*
 * cprng_seed(cprng, seed)
 *
 *	Rekey the generator from the CPRNG_SEED_BYTES bytes at seed
 *	and restart the nonce at zero.
 *
 *	Any output still buffered was produced under the previous key.
 *	Wipe it and mark the buffer empty so that the next request is
 *	served from the new key rather than from stale output.
 */
static void
cprng_seed(struct cprng *cprng, const void *seed)
{

	__CTASSERT(CPRNG_SEED_BYTES == sizeof cprng->key);

	(void)memcpy(cprng->key, seed, sizeof cprng->key);
	(void)memset(cprng->nonce, 0, sizeof cprng->nonce);

	/* Discard what the old key generated.  */
	(void)explicit_memset(cprng->buffer, 0, sizeof cprng->buffer);
	cprng->buffered = 0;
}

/*
 * cprng(cprng, buf, len)
 *
 *	Generate a short output (at most CPRNG_SHORT_REQ bytes) from a
 *	CPRNG state.  Output is consumed in whole 32-bit words: a
 *	request whose length is not a multiple of four uses up the
 *	entire final word.  If the buffer holds too few unused words,
 *	whatever remains is discarded and a fresh block is generated.
 */
static void
cprng(struct cprng *cprng, void *buf, size_t len)
{
	const size_t nwords = howmany(len, sizeof(uint32_t));
	const uint32_t *src;

	_DIAGASSERT(len <= CPRNG_SHORT_REQ);

	/* Refill if the request cannot be met from what is buffered.  */
	if (__predict_false(nwords > cprng->buffered)) {
		crypto_core(cprng->buffer, cprng->nonce, cprng->key, sigma);
		nonce_inc(cprng->nonce);
		cprng->buffered = crypto_core_OUTPUTWORDS;
	}

	/* The unused words are the last `buffered' ones in the buffer.  */
	src = &cprng->buffer[crypto_core_OUTPUTWORDS - cprng->buffered];
	(void)memcpy(buf, src, len);
	cprng->buffered -= nwords;
}

/*
 * cprng1(seed, buf, len)
 *
 *	Generate a long stream of len bytes into buf from a one-time
 *	key of CPRNG_SEED_BYTES bytes at seed, starting from a zero
 *	nonce.  The output region is split into an unaligned head of
 *	fewer than four bytes, a run of whole blocks written in place
 *	at a 32-bit-aligned address, and a tail of less than one
 *	block.  Head and tail each consume a full block generated into
 *	a local buffer, of which only the needed bytes are copied out.
 *	The local key copy, and the local block if it was used, are
 *	wiped before returning.
 */
static void
cprng1(const void *seed, void *buf, size_t len)
{
	uint32_t key[crypto_core_KEYWORDS];
	uint32_t nonce[crypto_core_INPUTWORDS] = {0};
	uint32_t block[crypto_core_OUTPUTWORDS];
	uint8_t *const start = buf;
	uint32_t *aligned;
	size_t head, nblocks, tail;

	/*
	 * Guarantee we can generate up to len bytes.  We have
	 * 2^(32*INPUTWORDS) possible inputs yielding output of
	 * 4*OUTPUTWORDS*2^(32*INPUTWORDS) bytes.  It suffices to
	 * require that sizeof len > (1/CHAR_BIT) log_2 len be less
	 * than (1/CHAR_BIT) log_2 of the total output stream length.
	 * We have
	 *
	 *	log_2 (4 o 2^(32 i)) = log_2 (4 o) + log_2 2^(32 i)
	 *	  = 2 + log_2 o + 32 i.
	 */
	__CTASSERT(CHAR_BIT*sizeof len <=
	    (2 + ilog2(crypto_core_OUTPUTWORDS) + 32*crypto_core_INPUTWORDS));

	__CTASSERT(CPRNG_SEED_BYTES == sizeof key);
	(void)memcpy(key, seed, sizeof key);

	/* Carve the request into head, whole blocks, and tail.  */
	aligned = (uint32_t *)roundup2((uintptr_t)start, sizeof(uint32_t));
	head = (uint8_t *)aligned - start;
	nblocks = (len - head) / sizeof block;
	tail = (len - head) % sizeof block;

	_DIAGASSERT(((uintptr_t)aligned & 3) == 0);
	_DIAGASSERT(len == (head + (nblocks * sizeof block) + tail));
	_DIAGASSERT(head < sizeof(uint32_t));
	_DIAGASSERT(tail < sizeof(uint32_t)*crypto_core_OUTPUTWORDS);

	/* Unaligned leading bytes, if any.  */
	if (__predict_false(head != 0)) {
		crypto_core(block, nonce, key, sigma);
		nonce_inc(nonce);
		(void)memcpy(start, block, head);
	}

	/* Whole blocks, generated directly into the caller's buffer.  */
	for (; nblocks > 0; nblocks--) {
		crypto_core(aligned, nonce, key, sigma);
		nonce_inc(nonce);
		aligned += crypto_core_OUTPUTWORDS;
	}

	/* Trailing partial block, if any.  */
	if (__predict_false(tail != 0)) {
		crypto_core(block, nonce, key, sigma);
		nonce_inc(nonce);
		(void)memcpy(aligned, block, tail);
	}

	/* The local block only holds output if head or tail used it.  */
	if (__predict_false((head | tail) != 0))
		(void)explicit_memset(block, 0, sizeof block);
	(void)explicit_memset(key, 0, sizeof key);
}

/* Library state */

/*
 * Without _REENTRANT the state below has no lock member, so the lock
 * operations expand to empty statements that never evaluate their
 * argument.
 */
#ifndef _REENTRANT
#define	mutex_lock(m)	do {} while (0)
#define	mutex_unlock(m)	do {} while (0)
#endif

/*
 * The single arc4random state of this file: an optional lock, the
 * generator, and two flags.  `seeded' is set once the generator has
 * been keyed and cleared again before fork; `atfork' is set once the
 * fork handlers have been registered.
 */
static struct {
#ifdef _REENTRANT
	mutex_t		lock;
#endif
	struct cprng	cprng;
	bool		seeded:1;
	bool		atfork:1;
} arc4random_state = {
#ifdef _REENTRANT
	.lock = MUTEX_INITIALIZER,
#endif
	.seeded = false,
	.atfork = false,
};

/*
 * arc4random_atfork_prepare()
 *
 *	Registered with pthread_atfork as the prepare handler.  Takes
 *	the state lock, which the parent and child handlers release
 *	again, and erases the generator so that nothing derived from
 *	the current key is carried across the fork.  Clearing `seeded'
 *	forces a reseed on the next request.
 */
static void
arc4random_atfork_prepare(void)
{

	mutex_lock(&arc4random_state.lock);

	/* Force a reseed on next use, and don't let the child see our key.  */
	arc4random_state.seeded = false;
	(void)explicit_memset(&arc4random_state.cprng, 0,
	    sizeof arc4random_state.cprng);
}

/*
 * Registered with pthread_atfork as the parent handler (see
 * arc4random_atfork_locked): releases the state lock taken by
 * arc4random_atfork_prepare.
 */
static void
arc4random_atfork_parent(void)
{

	mutex_unlock(&arc4random_state.lock);
}

/*
 * Registered with pthread_atfork as the child handler (see
 * arc4random_atfork_locked): releases the state lock taken by
 * arc4random_atfork_prepare.
 */
static void
arc4random_atfork_child(void)
{

	mutex_unlock(&arc4random_state.lock);
}

/*
 * arc4random_atfork_locked()
 *
 *	Register the three fork handlers and record that this has been
 *	done.  Aborts the process if registration fails.  The only
 *	caller, arc4random_enter, invokes this after taking the state
 *	lock and only while the `atfork' flag is still clear.
 */
static void __noinline
arc4random_atfork_locked(void)
{
	const int error = pthread_atfork(arc4random_atfork_prepare,
	    arc4random_atfork_parent, arc4random_atfork_child);

	if (__predict_false(error != 0))
		abort();
	arc4random_state.atfork = true;
}

/*
 * arc4random_stir_locked()
 *
 *	Key the global generator with CPRNG_SEED_BYTES fresh bytes
 *	obtained from the kernel via sysctl(kern.urandom) and mark it
 *	seeded.  Aborts the process if the kernel call fails or
 *	delivers fewer bytes than requested: continuing would key the
 *	generator partly from uninitialized stack memory.  The local
 *	seed copy is wiped before returning.  Every caller in this
 *	file invokes this between arc4random_enter and
 *	arc4random_exit.
 */
static void
arc4random_stir_locked(void)
{
	const int mib[] = { CTL_KERN, KERN_URND };
	uint8_t seed[CPRNG_SEED_BYTES];
	size_t len = sizeof seed;

	if (sysctl(mib, __arraycount(mib), seed, &len, NULL, 0) == -1)
		abort();
	/* sysctl stores the number of bytes it copied out in len.  */
	if (len != sizeof seed)
		abort();
	cprng_seed(&arc4random_state.cprng, seed);
	(void)explicit_memset(seed, 0, sizeof seed);
	arc4random_state.seeded = true;
}

/*
 * arc4random_addrandom_locked(data, datalen)
 *
 *	Rekey the global generator with
 *
 *		SHA-256(fresh kernel seed || caller data)
 *
 *	and mark it seeded.  The kernel seed comes from
 *	sysctl(kern.urandom); the process aborts if that call fails or
 *	returns fewer bytes than requested.  Caller data is hashed
 *	only when data is non-null and datalen is positive, because a
 *	negative int length would convert to an enormous size_t.  All
 *	intermediate secrets (seed, hash context, digest) are wiped
 *	before returning.  The only caller, arc4random_addrandom,
 *	invokes this between arc4random_enter and arc4random_exit.
 */
static void
arc4random_addrandom_locked(u_char *data, int datalen)
{
	const int mib[] = { CTL_KERN, KERN_URND };
	SHA256_CTX ctx;
	uint8_t seed[CPRNG_SEED_BYTES], hash[SHA256_DIGEST_LENGTH];
	size_t len = sizeof seed;

	__CTASSERT(sizeof seed == sizeof hash);

	if (sysctl(mib, __arraycount(mib), seed, &len, NULL, 0) == -1)
		abort();
	/* A short read would leave part of the seed uninitialized.  */
	if (len != sizeof seed)
		abort();

	SHA256_Init(&ctx);
	SHA256_Update(&ctx, seed, sizeof seed);
	(void)explicit_memset(seed, 0, sizeof seed);
	if (data != NULL && datalen > 0)
		SHA256_Update(&ctx, data, (size_t)datalen);
	SHA256_Final(hash, &ctx);
	/* The context has absorbed the kernel seed; don't leave it around.  */
	(void)explicit_memset(&ctx, 0, sizeof ctx);

	cprng_seed(&arc4random_state.cprng, hash);
	(void)explicit_memset(hash, 0, sizeof hash);
	arc4random_state.seeded = true;
}

/*
 * arc4random_buf_locked(buf, len)
 *
 *	Fill buf with len <= CPRNG_SHORT_REQ bytes from the global
 *	generator, keying it from the kernel first if it is not
 *	currently seeded.  Every caller in this file invokes this
 *	between arc4random_enter and arc4random_exit.
 */
static inline void
arc4random_buf_locked(void *buf, size_t len)
{
	struct cprng *const prng = &arc4random_state.cprng;

	_DIAGASSERT(len <= CPRNG_SHORT_REQ);

	if (__predict_false(!arc4random_state.seeded)) {
		arc4random_stir_locked();
		_DIAGASSERT(arc4random_state.seeded);
	}

	cprng(prng, buf, len);
}

/*
 * arc4random_locked()
 *
 *	Return one 32-bit word drawn from the global generator.  Both
 *	callers in this file invoke this between arc4random_enter and
 *	arc4random_exit.
 */
static inline uint32_t
arc4random_locked(void)
{
	uint32_t word;

	arc4random_buf_locked(&word, sizeof word);
	return word;
}

/*
 * arc4random_enter()
 *
 *	Take the state lock and, the first time through, register the
 *	fork handlers.  Pair every call with arc4random_exit.
 */
static inline void
arc4random_enter(void)
{

	mutex_lock(&arc4random_state.lock);

	/* Common case: the fork handlers are already installed.  */
	if (__predict_true(arc4random_state.atfork))
		return;

	arc4random_atfork_locked();
}

/*
 * Release the state lock taken by arc4random_enter.
 */
static inline void
arc4random_exit(void)
{

	mutex_unlock(&arc4random_state.lock);
}

/* Public API */

/*
 * arc4random()
 *
 *	Return a 32-bit word drawn from the global generator while
 *	holding the state lock.
 */
uint32_t
arc4random(void)
{
	uint32_t value;

	arc4random_enter();
	value = arc4random_locked();
	arc4random_exit();

	return value;
}

/*
 * arc4random_buf(buf, len)
 *
 *	Fill buf with len bytes of generator output.  Requests of at
 *	most CPRNG_SHORT_REQ bytes are served from the global generator
 *	under the lock.  Longer requests draw only a one-time key from
 *	the global generator under the lock, then expand that key into
 *	the caller's buffer after the lock has been released; the key
 *	is wiped afterwards.
 */
void
arc4random_buf(void *buf, size_t len)
{
	uint8_t seed[CPRNG_SEED_BYTES];

	__CTASSERT(CPRNG_SEED_BYTES <= CPRNG_SHORT_REQ);

	/* Short request: serve it straight from the shared state.  */
	if (len <= CPRNG_SHORT_REQ) {
		arc4random_enter();
		arc4random_buf_locked(buf, len);
		arc4random_exit();
		return;
	}

	/* Long request: fetch a one-time key, then expand it unlocked.  */
	arc4random_enter();
	arc4random_buf_locked(seed, sizeof seed);
	arc4random_exit();

	cprng1(seed, buf, len);
	(void)explicit_memset(seed, 0, sizeof seed);
}

/*
 * arc4random_uniform(bound)
 *
 *	Return a uniform random choice in [0, bound), free of modulo
 *	bias.  For bound < 2 the only possible answer is 0, which is
 *	returned immediately; in particular bound == 0 no longer
 *	divides by zero.
 */
uint32_t
arc4random_uniform(uint32_t bound)
{
	uint32_t minimum, r;

	/*
	 * [0, 0) is empty and [0, 1) contains only 0.  Handle both
	 * here so that the modulo operations below never see a zero
	 * divisor.
	 */
	if (bound < 2)
		return 0;

	/*
	 * We want a uniform random choice in [0, bound), and
	 * arc4random() makes a uniform random choice in [0, 2^32).
	 * If we reduce that modulo bound, values in
	 * [0, 2^32 mod bound) will be represented slightly more than
	 * values in [2^32 mod bound, bound).  Instead we choose only
	 * from [2^32 mod bound, 2^32) by rejecting samples in
	 * [0, 2^32 mod bound), to avoid counting the extra
	 * representative of [0, 2^32 mod bound).  To compute
	 * 2^32 mod bound, note that
	 *
	 *	2^32 mod bound = 2^32 mod bound - 0
	 *	  = 2^32 mod bound - bound mod bound
	 *	  = (2^32 - bound) mod bound,
	 *
	 * the last of which is what we compute in 32-bit arithmetic.
	 */
	minimum = (-bound % bound);

	/* Rejection sampling; the lock is held across all retries.  */
	arc4random_enter();
	do r = arc4random_locked(); while (r < minimum);
	arc4random_exit();

	return (r % bound);
}

/*
 * Rekey the global generator from the kernel (see
 * arc4random_stir_locked) while holding the state lock.
 */
void
arc4random_stir(void)
{

	arc4random_enter();
	arc4random_stir_locked();
	arc4random_exit();
}

/*
 * Rekey the global generator from fresh kernel randomness hashed
 * together with the datalen bytes at data (see
 * arc4random_addrandom_locked), while holding the state lock.
 *
 * Silly signature here is for hysterical raisins.  Should instead be
 * const void *data and size_t datalen.
 */
void
arc4random_addrandom(u_char *data, int datalen)
{

	arc4random_enter();
	arc4random_addrandom_locked(data, datalen);
	arc4random_exit();
}

#if _ARC4RANDOM_TEST

#include <sys/endian.h>

#include <err.h>
#include <stdint.h>
#include <stdio.h>

/* All-zero words, used as both key and input in the known-answer test.  */
static const uint32_t zero32[8];

/*
 * Expected crypto_core output, as bytes, for an all-zero key and an
 * all-zero input with the sigma constant.  From
 * <http://tools.ietf.org/html/draft-strombergson-chacha-test-vectors-00>.
 */

static const uint8_t out[64] = {
	0x76,0xb8,0xe0,0xad,0xa0,0xf1,0x3d,0x90,
	0x40,0x5d,0x6a,0xe5,0x53,0x86,0xbd,0x28,
	0xbd,0xd2,0x19,0xb8,0xa0,0x8d,0xed,0x1a,
	0xa8,0x36,0xef,0xcc,0x8b,0x77,0x0d,0xc7,
	0xda,0x41,0x59,0x7c,0x51,0x57,0x48,0x8d,
	0x77,0x24,0xe0,0x3f,0xb8,0xd8,0x4a,0x37,
	0x6a,0x43,0xb8,0xf4,0x15,0x18,0xa1,0x1c,
	0xc3,0x87,0xb6,0x69,0xb2,0xee,0x65,0x86,
};

/*
 * check(E): terminate the test with exit status 1 and a message
 * naming the function, file, line, and expression if E is false.
 * errx() supplies the trailing newline itself, so the format string
 * must not end in one.
 */
#define	check(E)	do {						      \
	if (!__predict_true(E))						      \
		errx(1, "check failed in %s: %s:%d: %s", __func__,	      \
		    __FILE__, __LINE__, #E);				      \
} while (0)

/*
 * Self-test driver, compiled only when _ARC4RANDOM_TEST is nonzero.
 *
 * Verifies that sigma matches the string `expand 32-byte k' decoded
 * as little-endian words, checks every word of one crypto_core
 * output block against the published test vector, and then prints
 * one arc4random() word plus a short and a long arc4random_buf()
 * result for manual inspection.  Returns 0 on success; a failed
 * check or a failed printf terminates the program with status 1.
 */
int
main(int argc __unused, char **argv __unused)
{
	const uint8_t sigma_string[] = "expand 32-byte k";
	uint32_t block[crypto_core_OUTPUTWORDS];
	uint8_t small[8], large[256];
	unsigned i;

	__CTASSERT(crypto_core_ROUNDS == 20);
	/* The test vector must cover the entire output block.  */
	__CTASSERT(sizeof block == sizeof out);

	check(sigma[0] == le32dec(&sigma_string[0]));
	check(sigma[1] == le32dec(&sigma_string[4]));
	check(sigma[2] == le32dec(&sigma_string[8]));
	check(sigma[3] == le32dec(&sigma_string[12]));

	/* Known-answer test: compare all output words, not just half.  */
	crypto_core(block, zero32, zero32, sigma);
	for (i = 0; i < __arraycount(block); i++)
		check(block[i] == le32dec(&out[i*4]));

	if (printf("arc4random: %"PRIx32"\n", arc4random()) < 0)
		err(1, "printf");

	/* Small enough to be served from the shared buffer.  */
	arc4random_buf(small, sizeof small);
	if (printf("arc4randombuf small:") < 0)
		err(1, "printf");
	for (i = 0; i < sizeof small; i++)
		if (printf(" %02x", small[i]) < 0)
			err(1, "printf");
	if (printf("\n") < 0)
		err(1, "printf");

	/* Larger than CPRNG_SHORT_REQ, so this takes the one-time-key path.  */
	arc4random_buf(large, sizeof large);
	if (printf("arc4randombuf_large:") < 0)
		err(1, "printf");
	for (i = 0; i < sizeof large; i++)
		if (printf(" %02x", large[i]) < 0)
			err(1, "printf");
	if (printf("\n") < 0)
		err(1, "printf");

	return 0;
}
#endif