#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>

#include <assert.h>
#include <err.h>
#include <errno.h>
#include <netdb.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

/* All test groups share the same UDP port. */
#define MCAST_PORT "6666"

/* IPv4 group 239.1.1.1. */
#define HOST_V4 "239.1.1.1"
#define PORT_V4 MCAST_PORT

/* The same IPv4 group, spelled as an IPv4-mapped IPv6 address. */
#define HOST_V4MAPPED "::FFFF:239.1.1.1"
#define PORT_V4MAPPED MCAST_PORT

/* IPv6 group ff05::1. */
#define HOST_V6 "FF05:0:0:0:0:0:0:1"
#define PORT_V6 MCAST_PORT

/*
 * Join the multicast group named by ai on socket s, using the default
 * interface (INADDR_ANY for IPv4, interface index 0 for IPv6).
 *
 * AF_INET groups, and AF_INET6 groups that are IPv4-mapped, are joined
 * with IP_ADD_MEMBERSHIP; other AF_INET6 groups use IPV6_JOIN_GROUP.
 * Returns the setsockopt() result, or -1 with errno = EOPNOTSUPP for any
 * other address family.
 */
static int
addmc(int s, struct addrinfo *ai)
{
	struct ip_mreq m4;
	struct ipv6_mreq m6;
	const struct sockaddr_in *sin;
	const struct sockaddr_in6 *sin6;

	if (ai->ai_family == AF_INET) {
		sin = (const void *)ai->ai_addr;
		assert(ai->ai_addrlen == sizeof(*sin));
		m4.imr_interface.s_addr = htonl(INADDR_ANY);
		m4.imr_multiaddr = sin->sin_addr;
		return setsockopt(s, IPPROTO_IP, IP_ADD_MEMBERSHIP,
		    &m4, sizeof(m4));
	}

	if (ai->ai_family != AF_INET6) {
		errno = EOPNOTSUPP;
		return -1;
	}

	sin6 = (const void *)ai->ai_addr;
	if (IN6_IS_ADDR_V4MAPPED(&sin6->sin6_addr)) {
		/*
		 * XXX: Both linux and we do this thing wrong...
		 * The embedded IPv4 group (last 4 bytes) is joined with the
		 * IPv4-level option even though the address came in as
		 * AF_INET6 (presumably on an AF_INET6 socket — see caller).
		 */
		memcpy(&m4.imr_multiaddr, &sin6->sin6_addr.s6_addr[12],
		    sizeof(m4.imr_multiaddr));
		m4.imr_interface.s_addr = htonl(INADDR_ANY);
		return setsockopt(s, IPPROTO_IP, IP_ADD_MEMBERSHIP,
		    &m4, sizeof(m4));
	}

	assert(ai->ai_addrlen == sizeof(*sin6));
	/* Zeroing also selects interface index 0 (default interface). */
	memset(&m6, 0, sizeof(m6));
	m6.ipv6mr_multiaddr = sin6->sin6_addr;
	return setsockopt(s, IPPROTO_IPV6, IPV6_JOIN_GROUP,
	    &m6, sizeof(m6));
}

/*
 * If ai is an AF_INET6 address in IPv4-mapped form (::ffff:a.b.c.d),
 * clear IPV6_V6ONLY on socket s so the IPv6 socket may be used with the
 * mapped IPv4 address.  For any other address the socket is left
 * untouched and 0 is returned; otherwise the setsockopt() result is
 * returned.
 */
static int
allowv4mapped(int s, struct addrinfo *ai)
{
	const struct sockaddr_in6 *sin6;
	int off;

	if (ai->ai_family == AF_INET6) {
		sin6 = (const void *)ai->ai_addr;
		if (IN6_IS_ADDR_V4MAPPED(&sin6->sin6_addr)) {
			off = 0;
			return setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY,
			    &off, sizeof(off));
		}
	}
	return 0;
}

/*
 * Resolve host:port as a datagram endpoint and return a socket for the
 * first usable address.  f selects the mode:
 *   bind    - bind the socket and join the multicast group (addmc());
 *   connect - connect the socket.
 * Sockets for IPv4-mapped AF_INET6 addresses get IPV6_V6ONLY cleared
 * first (allowv4mapped()).
 *
 * Addresses are tried in getaddrinfo() order; a candidate whose setup
 * fails at any step is closed and the next one tried.  Does not return
 * on failure: exits via errx() if resolution fails, or via err() naming
 * the last failing step and its errno if no address could be used.
 */
static int
getsocket(const char *host, const char *port,
    int (*f)(int, const struct sockaddr *, socklen_t))
{
	int e, s, saved_errno;
	struct addrinfo hints, *ai0, *ai;
	const char *cause;

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_DGRAM;
	e = getaddrinfo(host, port, &hints, &ai0);
	if (e)
		errx(EXIT_FAILURE, "Can't resolve %s:%s (%s)", host, port,
		    gai_strerror(e));

	s = -1;
	/* Defensive defaults; every failed iteration overwrites both. */
	cause = "no usable address";
	saved_errno = 0;
	for (ai = ai0; ai; ai = ai->ai_next) {
		s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
		if (s == -1) {
			cause = "socket";
			saved_errno = errno;
			continue;
		}
		if (allowv4mapped(s, ai) == -1) {
			cause = "allow v4 mapped";
			goto fail;
		}
		if ((*f)(s, ai->ai_addr, ai->ai_addrlen) == -1) {
			cause = f == bind ? "bind" : "connect";
			goto fail;
		}
		if (f == bind && addmc(s, ai) == -1) {
			cause = "join group";
			goto fail;
		}
		break;
fail:
		/*
		 * Capture errno before close(): close() and, later,
		 * freeaddrinfo() may overwrite it, and err() below must
		 * report the call that actually failed.
		 */
		saved_errno = errno;
		close(s);
		s = -1;
	}
	freeaddrinfo(ai0);
	if (s == -1) {
		errno = saved_errno;
		err(EXIT_FAILURE, "%s", cause);
	}
	return s;
}

/*
 * Multicast send/receive test.
 *
 * With any command-line argument, run as the sender: once per second,
 * send a fixed-size datagram of sizeof(buf) bytes containing a sequence
 * number and the current time.  With no arguments, run as the receiver:
 * bind, join the group and print every datagram that arrives.
 * Both modes loop forever and exit only through err() on a socket error.
 */
int
main(int argc, char *argv[])
{
	int s;
	ssize_t l;
	size_t seq;
	char buf[64];
	const char *host, *port;

	/*
	 * Select the group under test.  Only the last host/port pair
	 * assigned takes effect; reorder the pairs to test another family.
	 */
	host = HOST_V4;
	port = PORT_V4;
	host = HOST_V4MAPPED;
	port = PORT_V4MAPPED;
	host = HOST_V6;
	port = PORT_V6;

	if (argc > 1) {
		s = getsocket(host, port, connect);
		for (seq = 0;; seq++) {
			time_t t = time(NULL);
			const char *now = ctime(&t);

			/*
			 * The full sizeof(buf) bytes are sent, so clear the
			 * buffer first; otherwise the bytes after the
			 * string's NUL would be uninitialized stack data.
			 */
			memset(buf, 0, sizeof(buf));
			snprintf(buf, sizeof(buf), "%zu: %-24.24s",
			    seq, now != NULL ? now : "?");
			printf("sending: %s\n", buf);
			l = send(s, buf, sizeof(buf), 0);
			if (l == -1)
				err(EXIT_FAILURE, "send");
			sleep(1);
		}
	} else {
		s = getsocket(host, port, bind);
		for (;;) {
			/*
			 * A datagram need not contain a NUL; reserve one byte
			 * so the payload can always be terminated before it
			 * is printed with %s.
			 */
			l = recv(s, buf, sizeof(buf) - 1, 0);
			if (l == -1)
				err(EXIT_FAILURE, "recv");
			buf[l] = '\0';
			printf("got: %s\n", buf);
		}
	}
	return 0;
}