/*
 * 	ucon_crypto.c
 *
 * Copyright (c) 2004-2005 Evgeniy Polyakov <johnpol@2ka.mipt.ru>
 * 
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

#include <asm/types.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/poll.h>
#include <sys/mman.h>
#include <sys/signal.h>
#include <sys/time.h>
#include <sys/resource.h>
#include <sys/wait.h>

#include <linux/netlink.h>
#include <linux/types.h>
#include <linux/rtnetlink.h>

#include <arpa/inet.h>

#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#include "perf.h"

/*
 * Stop flag polled by the receive loop in nps_create_user() and the wait
 * loop in nps_start(). SIGTERM_h stores the signal number here, and the
 * test code sets it to 1 on failure. volatile sig_atomic_t is the only
 * object type ISO C guarantees a signal handler may safely write.
 */
static volatile sig_atomic_t need_exit;
/* Sequence number stamped on (and incremented for) each outgoing netlink message. */
static __u32 seq;

/*
 * SIGCHLD handler: reap every child that has exited so none linger as
 * zombies.
 *
 * Several child exits can be coalesced into a single SIGCHLD. So keep
 * calling waitpid() while it reaps a child (> 0) or is interrupted
 * (-1/EINTR). Stop once it reports nothing more to reap (0) or any other
 * error, such as ECHILD when there are no children left.
 *
 * errno is saved and restored because the handler can interrupt code that
 * is about to inspect it.
 */
static void SIGCHLD_h(int signo __attribute__((unused)))
{
	int saved_errno = errno;
	pid_t pid;

	do {
		pid = waitpid(-1, NULL, WNOHANG);
	} while (pid > 0 || (pid == -1 && errno == EINTR));

	errno = saved_errno;
}
/*
 * Handler for SIGTERM and SIGINT (installed in nps_start()).
 * Latches the signal number into need_exit, which the waiting and
 * receiving loops check to decide when to stop.
 */
static void SIGTERM_h(int signo)
{
	const int stop_reason = signo;

	need_exit = stop_reason;
}

/*
 * Wrap @msg in a single netlink message and write it to socket @s.
 *
 * @out:  stream the message summary and any send error are logged to.
 * @s:    netlink socket the message is written to with send().
 * @msg:  payload, copied verbatim right after the netlink header.
 * @type: label used only in the log line.
 *
 * Each call consumes one value of the global sequence counter.
 * Returns 0 on success, -1 if send() fails (the failure is logged to @out).
 */
static int nps_netlink_send(FILE *out, int s, struct nps_msg *msg, const char *type)
{
	/*
	 * The union gives the header the alignment struct nlmsghdr requires.
	 * It also sizes the buffer from the payload instead of a fixed 128
	 * bytes, so a larger struct nps_msg can no longer overflow it.
	 */
	union {
		struct nlmsghdr hdr;
		char raw[NLMSG_SPACE(sizeof(struct nps_msg))];
	} buf;
	struct nlmsghdr *nlh = &buf.hdr;
	unsigned int size = NLMSG_SPACE(sizeof(struct nps_msg));
	ssize_t err;

	/* Zero alignment padding so no stack garbage is handed to the kernel. */
	memset(&buf, 0, sizeof(buf));

	nlh->nlmsg_seq = seq++;
	nlh->nlmsg_pid = getpid();
	nlh->nlmsg_type = NLMSG_DONE;
	nlh->nlmsg_len = NLMSG_LENGTH(size - sizeof(*nlh));
	nlh->nlmsg_flags = 0;

	memcpy(NLMSG_DATA(nlh), msg, sizeof(struct nps_msg));

	fprintf(out, "type=%s, num=%u, size=%u, pid=%u, users=%u.\n", 
			type, msg->num, msg->size, msg->pid, msg->users);

	err = send(s, nlh, size, 0);
	if (err == -1) {
		fprintf(out, "Failed to send: %s [%d].\n", strerror(errno), errno);
		return -1;
	}

	return 0;
}

/*
 * Send the command message @m on socket @s. The log label is
 * "Unicast" for NPS_UNICAST tests and "Broadcast" for anything else.
 * Returns whatever nps_netlink_send() returns.
 */
static int nps_send_cmd(FILE *out, int s, struct nps_msg *m)
{
	char *label;

	if (m->type == NPS_UNICAST)
		label = "Unicast";
	else
		label = "Broadcast";

	return nps_netlink_send(out, s, m, label);
}

/*
 * Print command-line help to stderr.
 *
 * @procname: program name shown in the synopsis (main() passes argv[0]).
 *
 * The defaults listed here mirror the initial values set in main().
 */
static void usage(const char *procname)
{
	fprintf(stderr, "Usage: %s -p proc -l logfile -s size -t type -n num -h\n", procname);
	fprintf(stderr, "\t-p proc\t\t- number of clients receiving data. Default 1.\n");
	fprintf(stderr, "\t-l logfile\t- log file. Default stdout.\n");
	fprintf(stderr, "\t-s size\t\t- size of each message being sent from kernelspace. Default 4kb.\n");
	fprintf(stderr, "\t-t type\t\t- test type. 0 - unicast delivery, 1 - broadcast delivery. Default 0.\n");
	fprintf(stderr, "\t-n num\t\t- number of messages kernelspace will try to send. Default 1000.\n");
	fprintf(stderr, "\t-h\t\t- this help.\n");
}

static int nps_create_user(FILE *out, struct nps_msg *m)
{
	struct pollfd pfd;
	struct sockaddr_nl l_local;
	char *buf;
	int s, len;
	struct nlmsghdr *reply;
	unsigned int received = 0;
	struct timeval tm1, tm2;
	long diff, written;
	double speed;

	buf = malloc(m->size * 2);	/* Should be enough to store netlink overhead. */
	if (!buf)
		return -ENOMEM;

	s = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_W1);
	if (s == -1) {
		perror("socket");
		return -1;
	}

	l_local.nl_family = AF_NETLINK;
	l_local.nl_groups = 1;
	l_local.nl_pid = m->pid;

	if (bind(s, (struct sockaddr *)&l_local, sizeof(struct sockaddr_nl)) == -1) {
		fprintf(out, "Failed to bind to pid %u: %s [%d].\n", m->pid, strerror(errno), errno);
		close(s);
		return -1;
	}

	gettimeofday(&tm1, NULL);
	fprintf(out, "[%2u] started.\n", m->pid);

	pfd.fd = s;

	while (!need_exit) {
		pfd.events = POLLIN;
		pfd.revents = 0;
		switch (poll(&pfd, 1, -1)) {
			case 0:
				need_exit = 1;
				break;
			case -1:
				if (errno != EINTR) {
					need_exit = 1;
					break;
				}
				continue;
		}
		if (need_exit)
			break;
		
		memset(buf, 0, 2 * m->size);
		len = recv(s, buf, 2 * m->size, 0);
		if (len == -1) {
			perror("recv buf");
			close(s);
			return -1;
		}
		reply = (struct nlmsghdr *)buf;

		switch (reply->nlmsg_type) {
		case NLMSG_ERROR:
			fprintf(out, "Error message received.\n");
			break;
		case NLMSG_DONE:
			received++;
			break;
		default:
			break;
		}

		if (received == m->num)
			break;
	}

	close(s);
	
	gettimeofday(&tm2, NULL);

	written = received * m->size / (1024*1024);
	diff = (tm2.tv_sec - tm1.tv_sec)*1000000 + tm2.tv_usec - tm1.tv_usec;
	speed = ((double)(unsigned long)written)*1000000.0/((double)diff);

	fprintf(out, "[%2u] received: %u [%ld Mb], speed %f Mb/sec, size %u, users %u.\n", m->pid, received, written, speed, m->size, m->users);
	
	return 0;
}

/*
 * Open a netlink control socket (group mask 2) and send the test
 * parameters in @m as a command message.
 *
 * m->pid is overwritten with this process's pid before the message is sent.
 * Returns 0 on success, -1 on socket, bind or send failure.
 */
static int nps_control(FILE *out, struct nps_msg *m)
{
	int s, err;
	struct sockaddr_nl l_local;

	s = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_W1);
	if (s == -1) {
		perror("socket");
		return -1;
	}

	/* Zero the whole address (including nl_pad) before binding. */
	memset(&l_local, 0, sizeof(l_local));
	l_local.nl_family = AF_NETLINK;
	l_local.nl_groups = 2;
	/* nl_pid 0 asks the kernel to assign the port id (see netlink(7)). */
	l_local.nl_pid = 0;

	if (bind(s, (struct sockaddr *)&l_local, sizeof(l_local)) == -1) {
		fprintf(out, "Failed to bind to control pid %u: %s [%d].\n", m->pid, strerror(errno), errno);
		close(s);
		return -1;
	}

	m->pid = getpid();
	err = nps_send_cmd(out, s, m);

	close(s);
	return err;
}

/*
 * Run the test:
 *  - install the signal handlers;
 *  - fork m->users receiver processes (each runs nps_create_user());
 *  - send the control message;
 *  - sleep until SIGTERM/SIGINT or a failure sets need_exit.
 *
 * If fork() fails, m->users is reduced to the number of receivers actually
 * started. The control message sent afterwards reflects that count.
 *
 * Returns the error from nps_control() if the control message could not
 * be sent, otherwise 0 once need_exit is set.
 */
static int nps_start(FILE *out, struct nps_msg *m)
{
	int err;
	unsigned int i;
	pid_t pid;

	signal(SIGCHLD, SIGCHLD_h);
	signal(SIGTERM, SIGTERM_h);
	signal(SIGINT, SIGTERM_h);

	/*
	 * Flush all stdio buffers before forking. Otherwise every child
	 * inherits a copy of pending output (e.g. main()'s configuration line
	 * when stdout is redirected) and writes it again when it calls exit().
	 */
	fflush(NULL);

	for (i = 0; i < m->users; ++i) {
		/* Each receiver binds its netlink socket to this pid. */
		m->pid = getpid() + i;

		pid = fork();
		if (pid == -1) {
			fprintf(out, "Failed to fork user %u: %s [%d].\n",
				i, strerror(errno), errno);
			m->users = i;
			break;
		} else if (pid == 0) {
			exit(nps_create_user(out, m) ? EXIT_FAILURE : EXIT_SUCCESS);
		}
	}

	err = nps_control(out, m);
	if (err) {
		need_exit = 1;
		return err;
	}

	while (!need_exit)
		sleep(1);

	return 0;
}

/*
 * Parse a non-negative decimal command-line value into *@val.
 *
 * Returns 0 on success. Returns -EINVAL for empty, negative, non-numeric,
 * trailing-garbage or out-of-range input, all of which atoi() silently
 * accepted.
 */
static int nps_parse_uint(const char *arg, unsigned int *val)
{
	unsigned long v;
	char *end;

	/* strtoul() would silently wrap "-1" to a huge value. */
	if (strchr(arg, '-'))
		return -EINVAL;

	errno = 0;
	v = strtoul(arg, &end, 10);
	if (end == arg || *end != '\0' || errno == ERANGE || v > UINT_MAX)
		return -EINVAL;

	*val = (unsigned int)v;
	return 0;
}

/*
 * Parse options, open the log file (append mode, falling back to stdout)
 * and run the test. Returns -1 on bad options, -EINVAL for an unsupported
 * test type, otherwise the result of nps_start().
 */
int main(int argc, char *argv[])
{
	int ch, err;
	FILE *out;
	char *logfile = NULL;
	unsigned int size = 4096, num = 1000, type = 0, proc = 1;
	struct nps_msg m;

	while ((ch = getopt(argc, argv, "p:l:s:t:n:h")) != -1) {
		err = 0;
		switch (ch) {
		case 'p':
			err = nps_parse_uint(optarg, &proc);
			break;
		case 'l':
			logfile = optarg;
			break;
		case 's':
			err = nps_parse_uint(optarg, &size);
			break;
		case 'n':
			err = nps_parse_uint(optarg, &num);
			break;
		case 't':
			err = nps_parse_uint(optarg, &type);
			break;
		default:
		case 'h':
			usage(argv[0]);
			return -1;
		}
		if (err) {
			fprintf(stderr, "Invalid value '%s' for -%c.\n", optarg, ch);
			usage(argv[0]);
			return -1;
		}
	}

	if (type != NPS_UNICAST && type != NPS_BROADCAST) {
		fprintf(stderr, "Unsupported test type %u.\n", type);
		return -EINVAL;
	}

	if (logfile == NULL) {
		out = stdout;
		logfile = "(stdout)";
	} else {
		/* Open the file named by -l (previously argv[1] was used by mistake). */
		out = fopen(logfile, "a+");
		if (!out) {
			fprintf(stderr, "Unable to open %s for writing: %s\n",
				logfile, strerror(errno));
			out = stdout;
			logfile = "(stdout)";
		}
	}

	printf("logfile: %s, size: %u, type: %u, num: %u, proc: %u.\n", logfile, size, type, num, proc);

	/* The whole struct is copied to the kernel: no uninitialized fields. */
	memset(&m, 0, sizeof(m));
	m.size = size;
	m.type = type;
	m.num = num;
	m.users = proc;

	err = nps_start(out, &m);

	if (out != stdout)
		fclose(out);

	return err;
}

