/*	$NetBSD$	*/

/*
 * XXX WARNING WARNING WARNING XXX
 *
 * This code does not run!  I have not even compile-tested it.
 *
 * rrwlock: Pserialized recursive reader/writer lock.  The read lock
 * can be taken recursively by the same lwp.  The write lock is very
 * expensive, but frequent readers do not require interprocessor
 * synchronization.  (Infrequent readers will likely incur the
 * interprocessor synchronization anyway, defeating the purpose.)
 */
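
/*
 * Intended usage, as a sketch (like the rest of this file, this has
 * not been compiled, let alone run).  rrwlocks_init() must have been
 * called once, at boot, before any lock is created:
 *
 *	struct rrwlock *lock = rrwlock_create("example");
 *	struct rrw_reader *reader;
 *	struct rrw_writer *writer;
 *
 *	rrwlock_reader_enter(lock, &reader);
 *	...read the protected data; read-locking may recurse here...
 *	rrwlock_reader_exit(lock, reader);
 *
 *	rrwlock_writer_enter(lock, &writer);
 *	...modify the protected data...
 *	rrwlock_writer_exit(lock, writer);
 *
 *	rrwlock_destroy(lock);
 */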

/*-
 * Copyright (c) 2014 The NetBSD Foundation, Inc.
 * All rights reserved.
 *
 * This code is derived from software contributed to The NetBSD Foundation
 * by Taylor R. Campbell.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <sys/cdefs.h>
__KERNEL_RCSID(0, "$NetBSD$");

#include <sys/param.h>
#include <sys/types.h>
#include <sys/condvar.h>
#include <sys/lwp.h>
#include <sys/mutex.h>
#include <sys/pool.h>
#include <sys/pserialize.h>
#include <sys/queue.h>
#include <sys/specificdata.h>
#include <sys/systm.h>

struct rrwlock {
	kmutex_t		rrwl_lock;
	kcondvar_t		rrwl_cv;
	LIST_HEAD(, rrw_reader)	rrwl_readers;
	struct lwp		*rrwl_writer;
	pserialize_t		rrwl_psz;
	pool_cache_t		rrwl_reader_pc;
#ifdef DIAGNOSTIC
	const char		*rrwl_name;
#endif
};

struct rrw_reader {
	struct rrwlock		*rr_lock;
	LIST_ENTRY(rrw_reader)	rr_list;
	struct lwp		*rr_reader;
	SLIST_ENTRY(rrw_reader)	rr_stack;
	unsigned int		rr_depth;
};

struct rrw_perlwp {
	SLIST_HEAD(, rrw_reader)	rpl_stack;
};

struct rrw_writer;		/* fictitious */
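
/*
 * Each lock keeps a list of all its reader records, idle or held
 * (rr_reader is the holding lwp, or NULL when idle), which the writer
 * scans in order to wait for readers to drain.  Each lwp keeps a
 * stack, in lwp-specific data, of the reader records it currently
 * holds, so that recursive read-locking is detected without touching
 * any shared state.
 */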

static struct pool		rrwlock_pool	__cacheline_aligned;
static pool_cache_t		rrw_perlwp_pc	__read_mostly;
static specificdata_key_t	rrw_perlwp_key	__read_mostly;

static int	rrw_reader_ctor(void *, void *, int);
static void	rrw_reader_dtor(void *, void *);
static void	rrw_perlwp_dtor(void *);

int
rrwlocks_init(void)
{
	int error;

	pool_init(&rrwlock_pool, sizeof(struct rrwlock), 0, 0, 0, "rrwlock",
	    NULL, IPL_NONE);
	rrw_perlwp_pc = pool_cache_init(sizeof(struct rrw_perlwp), 0, 0, 0,
	    "rrw_perlwp", NULL, IPL_NONE, NULL, NULL, NULL);
	error = lwp_specific_key_create(&rrw_perlwp_key, &rrw_perlwp_dtor);
	if (error)
		panic("unable to create lwp-specific rrwlock key: %d", error);

	return 0;
}

static void
rrw_perlwp_dtor(void *arg)
{
	struct rrw_perlwp *const rpl = arg;

	KASSERT(rpl != NULL);
	if (!SLIST_EMPTY(&rpl->rpl_stack))
		panic("lwp %p still holds rrwlock readers", curlwp);
	pool_cache_put(rrw_perlwp_pc, rpl);
}

struct rrwlock *
rrwlock_create(const char *name)
{
	struct rrwlock *const rrwlock = pool_get(&rrwlock_pool, PR_WAITOK);

	mutex_init(&rrwlock->rrwl_lock, MUTEX_DEFAULT, IPL_NONE);
	cv_init(&rrwlock->rrwl_cv, name);
	LIST_INIT(&rrwlock->rrwl_readers);
	rrwlock->rrwl_writer = NULL;
	rrwlock->rrwl_psz = pserialize_create();
	rrwlock->rrwl_reader_pc = pool_cache_init(sizeof(struct rrw_reader),
	    0, 0, 0, "rrw_reader", NULL, IPL_NONE,
	    &rrw_reader_ctor, &rrw_reader_dtor, rrwlock);
#ifdef DIAGNOSTIC
	rrwlock->rrwl_name = name;
#endif

	return rrwlock;
}

void
rrwlock_destroy(struct rrwlock *rrwlock)
{

	pool_cache_destroy(rrwlock->rrwl_reader_pc);
	pserialize_destroy(rrwlock->rrwl_psz);
	KASSERTMSG((rrwlock->rrwl_writer == NULL),
	    "rrwlock %s @ %p writer-locked by lwp %p",
	    rrwlock->rrwl_name, rrwlock, rrwlock->rrwl_writer);
	KASSERTMSG(LIST_EMPTY(&rrwlock->rrwl_readers),
	    "rrwlock %s @ %p reader-locked",
	    rrwlock->rrwl_name, rrwlock);
	cv_destroy(&rrwlock->rrwl_cv);
	mutex_destroy(&rrwlock->rrwl_lock);

	pool_put(&rrwlock_pool, rrwlock);
}

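/*
 * Reader records are managed by the lock's pool cache: the ctor adds
 * a record to the lock's reader list and the dtor removes it, so a
 * cached-but-idle record (rr_reader == NULL) is harmless for the
 * writer to scan.
 */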
static int
rrw_reader_ctor(void *vrrwlock, void *vreader, int flags __unused)
{
	struct rrwlock *const rrwlock = vrrwlock;
	struct rrw_reader *const reader = vreader;

	reader->rr_lock = rrwlock;
	reader->rr_reader = NULL;
	reader->rr_depth = 0;

	mutex_enter(&rrwlock->rrwl_lock);
	/* queue(3) lists have no tail insert; order is immaterial here. */
	LIST_INSERT_HEAD(&rrwlock->rrwl_readers, reader, rr_list);
	mutex_exit(&rrwlock->rrwl_lock);

	return 0;
}

static void
rrw_reader_dtor(void *vrrwlock, void *vreader)
{
	struct rrwlock *const rrwlock = vrrwlock;
	struct rrw_reader *const reader = vreader;

	KASSERTMSG((reader->rr_reader == NULL),
	    "rrwlock %s @ %p reader %p still in use by lwp %p",
	    rrwlock->rrwl_name, rrwlock, reader, reader->rr_reader);

	mutex_enter(&rrwlock->rrwl_lock);
	LIST_REMOVE(reader, rr_list);
	mutex_exit(&rrwlock->rrwl_lock);
}

static inline struct rrw_reader *
rrw_reader_find(struct rrwlock *rrwlock)
{
	struct rrw_perlwp *rpl;
	struct rrw_reader *reader;

	rpl = lwp_getspecific(rrw_perlwp_key);
	if (__predict_false(rpl == NULL))
		return NULL;

	SLIST_FOREACH(reader, &rpl->rpl_stack, rr_stack) {
		if (reader->rr_lock == rrwlock)
			return reader;
	}

	return NULL;
}

static inline void
rrw_reader_push(struct rrw_reader *reader)
{
	struct rrw_perlwp *rpl;

	rpl = lwp_getspecific(rrw_perlwp_key);
	if (__predict_false(rpl == NULL)) {
		rpl = pool_cache_get(rrw_perlwp_pc, PR_WAITOK);
		SLIST_INIT(&rpl->rpl_stack);
		lwp_setspecific(rrw_perlwp_key, rpl);
	}

	SLIST_INSERT_HEAD(&rpl->rpl_stack, reader, rr_stack);
}

static inline void
rrw_reader_pop(struct rrw_reader *reader)
{
	struct rrw_perlwp *rpl;

	rpl = lwp_getspecific(rrw_perlwp_key);
	KASSERT(rpl != NULL);

	KASSERT(!SLIST_EMPTY(&rpl->rpl_stack));
	if (__predict_true(reader == SLIST_FIRST(&rpl->rpl_stack)))
		SLIST_REMOVE_HEAD(&rpl->rpl_stack, rr_stack);
	else
		SLIST_REMOVE(&rpl->rpl_stack, reader, rrw_reader, rr_stack);
}

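/*
 * rrwlock_reader_enter: Acquire a read lock.  Fast path: inside a
 * pserialize read section, observe that there is no writer and
 * publish ourselves by setting rr_reader.  Slow path: if a writer is
 * active, fall back to the mutex and sleep until it is done.
 */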
void
rrwlock_reader_enter(struct rrwlock *rrwlock, struct rrw_reader **readerp)
{
	struct rrw_reader *reader;
	int s;

	KASSERTMSG((rrwlock->rrwl_writer != curlwp),
	    "lwp %p read-locking rrwlock %s @ %p with write-lock held",
	    curlwp, rrwlock->rrwl_name, rrwlock);

	reader = rrw_reader_find(rrwlock);
	if (reader != NULL) {
		if (__predict_false(reader->rr_depth == UINT_MAX))
			panic("rrwlock reader overflow: %p", rrwlock);
		reader->rr_depth++;
		goto out0;
	}

	reader = pool_cache_get(rrwlock->rrwl_reader_pc, PR_WAITOK);
	KASSERT(reader != NULL);
	KASSERT(reader->rr_reader == NULL);
	KASSERT(reader->rr_depth == 0);
	reader->rr_depth = 1;

	s = pserialize_read_enter();
	if (__predict_false(rrwlock->rrwl_writer != NULL)) {
		pserialize_read_exit(s);
		mutex_enter(&rrwlock->rrwl_lock);
		while (rrwlock->rrwl_writer != NULL)
			cv_wait(&rrwlock->rrwl_cv, &rrwlock->rrwl_lock);
		reader->rr_reader = curlwp;
		mutex_exit(&rrwlock->rrwl_lock);
		goto out1;
	}
	reader->rr_reader = curlwp;
	pserialize_read_exit(s);

out1:	rrw_reader_push(reader);
out0:	*readerp = reader;
}

void
rrwlock_reader_exit(struct rrwlock *rrwlock, struct rrw_reader *reader)
{
	int s;

	KASSERT(reader != NULL);
	KASSERTMSG((0 < reader->rr_depth),
	    "rrwlock %s @ %p reader %p underflow",
	    rrwlock->rrwl_name, rrwlock, reader);
	KASSERTMSG((reader->rr_reader != NULL),
	    "rrwlock %s @ %p reader %p lost lwp %p",
	    rrwlock->rrwl_name, rrwlock, reader, curlwp);
	KASSERTMSG((reader->rr_reader == curlwp),
	    "rrwlock %s @ %p reader %p switched lwp from %p to %p",
	    rrwlock->rrwl_name, rrwlock, reader, reader->rr_reader, curlwp);

	if (1 < reader->rr_depth) {
		reader->rr_depth--;
		return;
	}

	s = pserialize_read_enter();
	if (__predict_false(rrwlock->rrwl_writer != NULL)) {
		pserialize_read_exit(s);
		mutex_enter(&rrwlock->rrwl_lock);
		reader->rr_depth = 0;
		reader->rr_reader = NULL;
		cv_broadcast(&rrwlock->rrwl_cv);
		mutex_exit(&rrwlock->rrwl_lock);
		goto out;
	}
	reader->rr_depth = 0;
	reader->rr_reader = NULL;
	pserialize_read_exit(s);

out:	rrw_reader_pop(reader);
	pool_cache_put(rrwlock->rrwl_reader_pc, reader);
}

static bool
rrwlock_readers_p(struct rrwlock *rrwlock)
{
	struct rrw_reader *reader;

	KASSERT(mutex_owned(&rrwlock->rrwl_lock));

	LIST_FOREACH(reader, &rrwlock->rrwl_readers, rr_list) {
		if (reader->rr_reader != NULL)
			return true;
	}

	return false;
}

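/*
 * rrwlock_writer_enter: Acquire the write lock.  Exclude other
 * writers, then use pserialize_perform to wait for all current
 * fast-path read sections to complete, so that every rr_reader store
 * is visible to us; then sleep until the remaining readers drain.
 */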
void
rrwlock_writer_enter(struct rrwlock *rrwlock, struct rrw_writer **writerp)
{

	mutex_enter(&rrwlock->rrwl_lock);
	KASSERTMSG((rrwlock->rrwl_writer != curlwp),
	    "write-locking against myself: rrwlock %s @ %p, lwp %p",
	    rrwlock->rrwl_name, rrwlock, curlwp);
	while (rrwlock->rrwl_writer != NULL)
		cv_wait(&rrwlock->rrwl_cv, &rrwlock->rrwl_lock);
	rrwlock->rrwl_writer = curlwp;
	pserialize_perform(rrwlock->rrwl_psz);
	while (rrwlock_readers_p(rrwlock))
		cv_wait(&rrwlock->rrwl_cv, &rrwlock->rrwl_lock);
	mutex_exit(&rrwlock->rrwl_lock);

	*writerp = (struct rrw_writer *)rrwlock;
}

void
rrwlock_writer_exit(struct rrwlock *rrwlock, struct rrw_writer *writer)
{

	KASSERT(writer == (struct rrw_writer *)rrwlock);

	mutex_enter(&rrwlock->rrwl_lock);
	KASSERTMSG((rrwlock->rrwl_writer == curlwp),
	    "lwp %p writer-unlocking rrwlock %s @ %p held by lwp %p",
	    curlwp, rrwlock->rrwl_name, rrwlock, rrwlock->rrwl_writer);
	rrwlock->rrwl_writer = NULL;
	cv_broadcast(&rrwlock->rrwl_cv);
	mutex_exit(&rrwlock->rrwl_lock);
}