/*	$NetBSD$	*/

/*-
 * Copyright (c) 2013 The NetBSD Foundation, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
#include <sys/cdefs.h>
__RCSID("$NetBSD$");

#include <errno.h>
#include <limits.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <stdarg.h>

#include "fmtmatch.h"

enum length_t {
	LM_NONE,
	LM_hh,
	LM_h,
	LM_l,
	LM_ll,
	LM_j,
	LM_z,
	LM_t,
	LM_L,
};

/*
 * Classification of the argument consumed by a conversion specification.
 *
 * The order of the integer enumerators is significant: every signed
 * variant is immediately followed by its unsigned counterpart, so the
 * parser can compute "signed + 1" to get the unsigned one and the
 * compatibility table can cover both with a two-bit mask.  A_UNDEF is
 * zero, which is what a zero-filled struct spec holds.
 */
enum spec_t {
	A_UNDEF,		/* no specification recorded */
	A_INT,			/* [di] or ... */
	A_UNSIGNED_INT,		/* ... [ouxX] specifiers */
	A_LONG,			/* ... with l modifier */
	A_UNSIGNED_LONG,
	A_LONG_LONG,		/* ... with ll modifier */
	A_UNSIGNED_LONG_LONG,
	A_SIZE,			/* ... with z modifier */
	A_UNSIGNED_SIZE,
	A_INTMAX,		/* ... with j modifier */
	A_UNSIGNED_INTMAX,
	A_PTRDIFF,		/* ... with t modifier */
	A_UNSIGNED_PTRDIFF,	/* NB: there's no such type */

	A_DOUBLE,		/* [aAeEfFgG] specifiers */
	A_LONG_DOUBLE,		/* ... with L modifier */

	A_STRING,		/* %s */
	A_WIDE_STRING,		/* %ls (and %S) */

	A_ASTERISK,		/* variable width/precision */
};


/*
 * What is known about one argument position.  Slots live in the
 * fmtmatch::specs array and are indexed by argument number.
 */
struct spec {
	int template;		/* enum spec_t the template uses for it */
	uint32_t format;	/* bitmask of (1U << enum spec_t) the format uses */
};
/*
 * State shared by the parser callbacks while one template/format pair
 * is being compared.
 */
struct fmtmatch {
	size_t maxspecs;	/* number of slots allocated in specs */
	size_t nspecs;		/* highest argument number seen in the template */
	char errmsg[1024];	/* first error reported; empty string if none */
	struct spec *specs;	/* realloc()ed array indexed by argument number */
};

/*
 * Parser callback: invoked with the argument number, the kind of
 * argument a conversion consumes and the shared state.  A non-zero
 * return value stops the parser, which returns that value to its caller.
 */
typedef int (*spec_cb_t)(unsigned int, enum spec_t, struct fmtmatch *);

/*
 * Record a printf-style formatted error message in fm->errmsg.
 * Only the first message is kept: if one is already stored, the call
 * is a no-op.
 */
static void __attribute__((__format__(__printf__, 2, 3)))
addmsg(struct fmtmatch *fm, const char *fmt, ...)
{
	va_list args;

	/* An earlier error is already recorded; keep it. */
	if (fm->errmsg[0] != '\0')
		return;

	va_start(args, fmt);
	vsnprintf(fm->errmsg, sizeof(fm->errmsg), fmt, args);
	va_end(args);
}

/**
 * Try to parse an "n$" numbered argument prefix at the start of "fmt".
 *
 * On return *endptr points just past the '$' when a prefix was found
 * (even if its number is rejected) and at "fmt" itself otherwise.
 *
 * Returns the argument number (1..INT_MAX) on success, 0 when "fmt"
 * does not start with digits followed by '$', and -1 with a message
 * recorded in "fm" when the number is zero or does not fit into int.
 */
static int
get_numbered_argument(const char * __restrict fmt,
    const char ** __restrict endptr, struct fmtmatch *fm)
{
	char *stop;
	long value;
	int too_large;

	/* Until a complete "n$" is seen nothing is consumed. */
	*endptr = fmt;

	if (!('0' <= *fmt && *fmt <= '9'))
		return 0;

	errno = 0;
	value = strtol(fmt, &stop, 10);
	if (*stop != '$')
		return 0;	/* plain digits: a flag or a field width */

	*endptr = stop + 1;

	too_large = (errno == ERANGE && value == LONG_MAX) || value > INT_MAX;
	if (too_large) {
		addmsg(fm, "Numbered argument `%s' is too large", fmt);
		return -1;
	}

	if (value == 0) {
		addmsg(fm, "Numbered argument `%s' is zero", fmt);
		errno = ERANGE;
		return -1;
	}

	return (int)value;
}


/**
 * Walk the printf-style format string "fmt0" and report every argument
 * it consumes to "callback" (if not NULL) as an (argument number,
 * enum spec_t) pair.  '*' widths and precisions are reported as
 * A_ASTERISK before the conversion they belong to.
 *
 * Unnumbered conversions are numbered sequentially starting from 1.
 * Mixing numbered ("%1$d") and unnumbered ("%d") references is an
 * error, as are %p, %n and length modifiers that do not apply to the
 * conversion specifier.
 *
 * Returns 0 on success, EINVAL for a malformed or disallowed format
 * (a message is recorded in "fm"), or the first non-zero value returned
 * by the callback.
 */
static int
fmt_parse_template(const char *fmt0, struct fmtmatch *fm, spec_cb_t callback)
{
	const char *fmt;
	int numbered;
	size_t arg;

#define FMTERR(...)					\
	do {						\
		addmsg(fm, __VA_ARGS__);		\
		return EINVAL;				\
	} while (/*CONSTCOND*/0)

#define CALLBACK(_arg, _spec)						\
	do {								\
		if (callback != NULL) {					\
			int status;					\
			status = (*callback)(_arg, _spec, fm);		\
			if (status != 0)				\
				return status;				\
		}							\
	} while (/*CONSTCOND*/0)


	fmt = fmt0;
	numbered = -1;	/* -1: not decided yet, 0: unnumbered, 1: numbered */
	arg = 0;	/* number of the last argument consumed */
	for (;;) {
		size_t ndigits;
		int numarg, precision, lenmod;
		int unsign;
		const char *end;

		fmt = strchr(fmt, '%');
		if (fmt == NULL)
			break;

		/* Skip '%' that starts conversion specification */
		++fmt;

		/*
		 * For '%' conversion specifier "the complete
		 * conversion specification shall be %%."
		 */
		if (*fmt == '%') {
			++fmt;
			continue;
		}

		/*
		 * Numbered argument specification.  The helper reports a
		 * zero or oversized number as -1 after recording a message;
		 * it must not be mistaken for a valid argument number.
		 */
		numarg = get_numbered_argument(fmt, &end, fm);
		if (numarg < 0)
			return EINVAL;

		if (numarg == 0) {
			if (numbered > 0)
				FMTERR("numbered and unnumbered"
				   " arguments cannot be mixed");
			numbered = 0;
		} else {
			/* numbered and unnumbered cannot be mixed */
			if (numbered == 0)
				FMTERR("numbered and unnumbered"
				       " arguments cannot be mixed");
			numbered = 1;

			if (*fmt == '0') /* leading zero is a flag */
				FMTERR("leading zero in numbered argument");

			fmt = end; /* skip past '$' */
		}

		/*
		 * Flags
		 */
		fmt += strspn(fmt, " -+#0'");

		/*
		 * Handle the optional "n$" that may follow a '*' (fmt
		 * already points past the asterisk) and report the int
		 * argument that supplies the width or precision.
		 */
#define NUMBERED_ARGUMENT_ASTERISK() do {				\
			int asterisk;					\
									\
			asterisk = get_numbered_argument(fmt, &end, fm);\
			if (asterisk < 0)				\
				return EINVAL;				\
									\
			if (asterisk == 0) {				\
				if (numbered > 0)			\
					FMTERR("numbered and unnumbered"\
					" arguments cannot be mixed");	\
				numbered = 0;				\
				asterisk = ++arg;			\
			}						\
			else {						\
				if (numbered == 0)			\
					FMTERR("numbered and unnumbered"\
					" arguments cannot be mixed");	\
				numbered = 1;				\
				fmt = end;				\
			}						\
			CALLBACK(asterisk, A_ASTERISK);			\
		} while (/* CONSTCOND */ 0)

		/*
		 * Minimum field width
		 */
		if (*fmt == '*') {
			++fmt;
			NUMBERED_ARGUMENT_ASTERISK();
		}
		else {
			ndigits = strspn(fmt, "0123456789");
			fmt += ndigits;
		}

		/*
		 * Precision
		 */
		precision = 0;
		if (*fmt == '.') {
			++fmt;
			precision = 1;

			if (*fmt == '*') {
				++fmt;
				NUMBERED_ARGUMENT_ASTERISK();
			}
			else {
				if (*fmt == '-')
					++fmt;

				/* fmt[-1] is either the '.' or the '-' */
				ndigits = strspn(fmt, "0123456789");
				if (ndigits == 0 && fmt[-1] == '-')
					FMTERR("orphan '-' in precision");
				fmt += ndigits;
			}
		}

		/*
		 * Length modifier
		 */
		lenmod = LM_NONE;
		switch (*fmt) {
		case 'h':
			if (*++fmt == 'h') {
				++fmt;
				lenmod = LM_hh;
			}
			else
				lenmod = LM_h;
			break;
		case 'l':
			if (*++fmt == 'l') {
				++fmt;
				lenmod = LM_ll;
			}
			else
				lenmod = LM_l;
			break;
		case 'j':
			++fmt;
			lenmod = LM_j;
			break;
		case 'z':
			++fmt;
			lenmod = LM_z;
			break;
		case 't':
			++fmt;
			lenmod = LM_t;
			break;
		case 'L':
			++fmt;
			lenmod = LM_L;
			break;
		default:
			break;
		}


		/*
		 * Conversion specifier
		 */
		if (numarg)
			arg = numarg;
		else
			++arg;

		unsign = 0;
		switch (*fmt) {
		case 'o': case 'u':
		case 'x': case 'X':
			unsign = 1;
			/* FALLTHROUGH */
		case 'd': case 'i':
			/*
			 * Each unsigned enumerator directly follows its
			 * signed counterpart, hence "+ unsign".  hh and h
			 * arguments are passed as int.
			 */
			switch (lenmod) {
			case LM_NONE:
			case LM_hh:
			case LM_h:
				CALLBACK(arg, A_INT + unsign);
				break;
			case LM_l:
				CALLBACK(arg, A_LONG + unsign);
				break;
			case LM_ll:
				CALLBACK(arg, A_LONG_LONG + unsign);
				break;
			case LM_j:
				CALLBACK(arg, A_INTMAX + unsign);
				break;
			case LM_z:
				CALLBACK(arg, A_SIZE + unsign);
				break;
			case LM_t:
				CALLBACK(arg, A_PTRDIFF + unsign);
				break;
			case LM_L:
				FMTERR("%%%c doesn't take 'L' modifier", *fmt);
				break;
			}
			break;

		case 'a': case 'A':
		case 'e': case 'E':
		case 'f': case 'F':
		case 'g': case 'G':
			if (lenmod == LM_NONE)
				CALLBACK(arg, A_DOUBLE);
			else if (lenmod == LM_L)
				CALLBACK(arg, A_LONG_DOUBLE);
			else
				FMTERR("%%%c takes only 'L' modifier", *fmt);
			break;

		case 'C':
			if (precision)
				FMTERR("%%C doesn't take precision");
			if (lenmod != LM_NONE)
				FMTERR("%%C doesn't take modifiers");
		wide_char:
			{
				/*
				 * %C and %lc take a wint_t.  Classify it as
				 * the standard integer type with the same
				 * maximum value and signedness.  This is
				 * computed outside of the CALLBACK() argument
				 * list because preprocessing directives
				 * inside macro arguments are undefined
				 * behavior.
				 */
				int wint_spec;

#define WINT_IS(_T) (WINT_MAX == (WINT_MIN == 0 ? U## _T ##_MAX : _T ##_MAX))
#if WINT_IS(INT)
				wint_spec = A_INT;
#elif WINT_IS(LONG)
				wint_spec = A_LONG;
#elif WINT_IS(LLONG)
				wint_spec = A_LONG_LONG;
#else
#error Unable to determine wint_t size
#endif
#undef WINT_IS
				CALLBACK(arg, wint_spec + (WINT_MIN == 0));
			}
			break;

		case 'c':
			if (precision)
				FMTERR("%%c doesn't take precision");
			if (lenmod == LM_NONE)
				CALLBACK(arg, A_INT);
			else if (lenmod == LM_l)
				goto wide_char;
			else
				FMTERR("%%c takes only 'l' modifier");
			break;

		case 'S':
			if (lenmod != LM_NONE)
				FMTERR("%%S doesn't take modifiers");
			CALLBACK(arg, A_WIDE_STRING);
			break;

		case 's':
			if (lenmod == LM_NONE)
				CALLBACK(arg, A_STRING);
			else if (lenmod == LM_l)
				CALLBACK(arg, A_WIDE_STRING);
			else
				FMTERR("%%s takes only 'l' modifier");
			break;

		case 'p':
			FMTERR("%%p is not allowed");

		case 'n':
			FMTERR("%%n is not allowed");

		default:
			FMTERR("no conversion specifier");
		}

		/*
		 * fmt is left on the conversion specifier character; it is
		 * never '%' here, so the next strchr() moves past it.
		 */
	}

	return 0;
}

static const char *const names[] = {
	[A_INT] = "INT",
	[A_UNSIGNED_INT] = "UNSIGNED INT",
	[A_LONG] = "LONG",
	[A_UNSIGNED_LONG] = "UNSIGNED LONG",
	[A_LONG_LONG] = "LONG LONG",
	[A_UNSIGNED_LONG_LONG] = "UNSIGNED LONG LONG",
	[A_SIZE] = "SSIZE_T",
	[A_UNSIGNED_SIZE] = "SIZE_T",
	[A_INTMAX] = "INTMAX",
	[A_UNSIGNED_INTMAX] = "UNSIGNED INTMAX",
	[A_PTRDIFF] = "PTRDIFF",
	[A_UNSIGNED_PTRDIFF] = "UNSIGNED PTRDIFF",

	[A_DOUBLE] = "DOUBLE",
	[A_LONG_DOUBLE] = "LONG DOUBLE",

	[A_STRING] = "STRING",
	[A_WIDE_STRING] = "WIDE STRING",

	[A_ASTERISK] = "ASTERISK",
};


/*
 * Debugging callback for the parser: print the number and the kind of
 * every argument reported.  Always returns 0, so parsing continues.
 */
static int
callback(unsigned int arg, enum spec_t spec, struct fmtmatch *fm __unused)
{
	const char *kind = names[spec];

	printf("arg %u is %s\n", arg, kind);

	return 0;
}

/*
 * For every template specification, the set of format specifications
 * (as a bitmask of 1U << enum spec_t) that may be used in its place.
 * The signed and unsigned variants of an integer type accept each
 * other; everything else accepts only itself.
 */
#define COMPAT_PAIR(_signed)	(3U << (_signed))	/* signed + unsigned */
#define COMPAT_SELF(_spec)	(1U << (_spec))
static const uint32_t compatible_specs[] = {
	[A_INT]			= COMPAT_PAIR(A_INT),
	[A_UNSIGNED_INT]	= COMPAT_PAIR(A_INT),
	[A_LONG]		= COMPAT_PAIR(A_LONG),
	[A_UNSIGNED_LONG]	= COMPAT_PAIR(A_LONG),
	[A_LONG_LONG]		= COMPAT_PAIR(A_LONG_LONG),
	[A_UNSIGNED_LONG_LONG]	= COMPAT_PAIR(A_LONG_LONG),
	[A_SIZE]		= COMPAT_PAIR(A_SIZE),
	[A_UNSIGNED_SIZE]	= COMPAT_PAIR(A_SIZE),
	[A_INTMAX]		= COMPAT_PAIR(A_INTMAX),
	[A_UNSIGNED_INTMAX]	= COMPAT_PAIR(A_INTMAX),
	[A_PTRDIFF]		= COMPAT_PAIR(A_PTRDIFF),
	[A_UNSIGNED_PTRDIFF]	= COMPAT_PAIR(A_PTRDIFF),

	[A_DOUBLE]		= COMPAT_SELF(A_DOUBLE),
	[A_LONG_DOUBLE]		= COMPAT_SELF(A_LONG_DOUBLE),

	[A_STRING]		= COMPAT_SELF(A_STRING),
	[A_WIDE_STRING]		= COMPAT_SELF(A_WIDE_STRING),

	[A_ASTERISK]		= COMPAT_SELF(A_ASTERISK),
};
#undef COMPAT_SELF
#undef COMPAT_PAIR


/*
 * Make sure fm->specs has a slot for argument number "arg", growing the
 * array if necessary.  Newly added slots are zero-filled, i.e. their
 * template is A_UNDEF and their format mask is empty.
 *
 * Returns 0 on success.  On failure records a message in "fm" and
 * returns ENOMEM; the existing array is left untouched.
 */
static int
makespecs(unsigned int arg, struct fmtmatch *fm)
{
	size_t newmax;
	void *p;

	if (arg < fm->maxspecs)
		return 0;

	/*
	 * Do the arithmetic in size_t and refuse argument numbers for
	 * which the slot count or the byte count would wrap around.
	 */
	if (arg > SIZE_MAX / sizeof(*fm->specs) - 10) {
		addmsg(fm, "Can't allocate template %u (%s)", arg,
		    strerror(ENOMEM));
		return ENOMEM;
	}
	newmax = (size_t)arg + 10;

	p = realloc(fm->specs, newmax * sizeof(*fm->specs));
	if (p == NULL) {
		addmsg(fm, "Can't allocate template %u (%s)", arg,
		    strerror(errno));
		return ENOMEM;
	}

	/* Clear only the tail that realloc() has just added. */
	(void)memset((char *)p + fm->maxspecs * sizeof(*fm->specs), 0,
	    (newmax - fm->maxspecs) * sizeof(*fm->specs));
	fm->maxspecs = newmax;
	fm->specs = p;
	return 0;
}

static int
template_callback(unsigned int arg, enum spec_t spec, struct fmtmatch *fm)
{
	int error;

	if ((error = makespecs(arg, fm)) != 0)
		return error;

	if (fm->specs[arg].template != A_UNDEF
	    && !(fm->specs[arg].template == A_ASTERISK && spec == A_ASTERISK))
	{
		addmsg(fm, "template refers to argument %u multiple times",
		    arg);
		return EINVAL;
	}

	if (fm->nspecs < arg)
		fm->nspecs = arg;

	fm->specs[arg].template = spec;
	return 0;
}

static int
format_callback(unsigned int arg, enum spec_t spec, struct fmtmatch *fm)
{
	int error;

	if ((error = makespecs(arg, fm)) != 0)
		return error;

	fm->specs[arg].format |= 1U << spec;
	return 0;
}

/*
 * Check that "format" can safely be used in place of "template", i.e.
 * that it consumes the same arguments with compatible conversions.
 *
 * The template must reference every argument from 1 up to the highest
 * one it uses.  The format must reference each of them (unless
 * FMTMATCH_ALLOW_EXTRA is set) with a conversion compatible with the
 * template's, and must not reference any argument the template does
 * not have.
 *
 * Returns "format" on success.  On failure the result is selected by
 * the first of these flags that is set:
 *   FMTMATCH_ERROR_RETURN_FORMAT  - "format";
 *   FMTMATCH_ERROR_RETURN_NULL    - NULL;
 *   FMTMATCH_ERROR_RETURN_ERROR   - a strdup()ed copy of the error
 *                                   message (NULL if strdup() fails);
 * and "template" if none of them is set.  With FMTMATCH_ERROR_PRINT the
 * error message is also printed with warnx().
 */
const char * __attribute__((__format_arg__(1)))
fmtmatch(const char *template, const char *format, int flags)
{
	int error;
	unsigned int i;
	struct fmtmatch fm;

	fm.nspecs = 0;
	fm.maxspecs = 0;
	fm.specs = NULL;
	fm.errmsg[0] = '\0';

	error = fmt_parse_template(template, &fm, template_callback);
	if (error)
		goto out;

	/* Argument numbers start at 1; slot 0 is never used. */
	for (i = 1; i <= fm.nspecs; ++i) {
		if (fm.specs[i].template == 0) {
			addmsg(&fm, "Unreferenced argument %u in template", i);
			errno = EINVAL;
			goto out;
		}
	}

	error = fmt_parse_template(format, &fm, format_callback);
	if (error)
		goto out;

	for (i = 1; i <= fm.nspecs; ++i) {
		if (fm.specs[i].format == 0) {
			if (flags & FMTMATCH_ALLOW_EXTRA)
				continue;
			addmsg(&fm, "Argument %u not referenced in format", i);
			errno = EINVAL;
			goto out;
		}
		if (fm.specs[i].format &
		    ~compatible_specs[fm.specs[i].template]) {
			addmsg(&fm, "Specification %u in format is "
			    "incompatible", i);
			errno = EINVAL;
			goto out;
		}
	}

	/*
	 * Parsing the format may have created slots beyond the last
	 * template argument.  Any of them that is in use means the format
	 * would fetch an argument the caller never passes.
	 */
	for (i = (unsigned int)fm.nspecs + 1; i < fm.maxspecs; ++i) {
		if (fm.specs[i].format != 0) {
			addmsg(&fm, "Argument %u in format is not present "
			    "in template", i);
			errno = EINVAL;
			goto out;
		}
	}

	free(fm.specs);
	return format;

out:
	/* errmsg is stored inside fm itself, so it outlives the array. */
	free(fm.specs);
	if (flags & FMTMATCH_ERROR_PRINT)
		warnx("%s", fm.errmsg);
	if (flags & FMTMATCH_ERROR_RETURN_FORMAT)
		return format;
	if (flags & FMTMATCH_ERROR_RETURN_NULL)
		return NULL;
	if (flags & FMTMATCH_ERROR_RETURN_ERROR)
		return strdup(fm.errmsg);
	return template;
}

#ifdef TEST
/*
 * Test driver: the operands are taken as (template, format) pairs and
 * each pair is checked, with errors printed by fmtmatch() itself.
 */
int
main(int argc, char **argv)
{
	int pair;

	/* argv[0] plus a non-empty list of pairs gives an odd argc >= 3 */
	if (argc % 2 == 0 || argc < 2)
		errx(EXIT_FAILURE, "Even number of argument(s) required");

	pair = 1;
	while (pair < argc) {
		fmtmatch(argv[pair], argv[pair + 1], FMTMATCH_ERROR_PRINT);
		pair += 2;
	}

	return EXIT_SUCCESS;
}
#endif