/*
 * 2007+ Copyright (c) Evgeniy Polyakov <johnpol@2ka.mipt.ru>
 * All rights reserved.
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 */

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/slab.h>
#include <linux/socket.h>
#include <linux/kthread.h>
#include <linux/net.h>
#include <linux/in.h>
#include <linux/poll.h>

#include <net/sock.h>

/* Size of the raw address payload kept in struct kst_addr. */
#define KST_SA_SIZE		128

/*
 * Socket address container.
 *
 * The first two members mirror the start of struct sockaddr (family, then
 * raw bytes), so a struct kst_addr is passed directly as a struct sockaddr
 * to ->bind()/->getname() and cast to struct sockaddr_in by callers.
 * Do not reorder these fields.
 */
struct kst_addr
{
	unsigned short		sa_family;
	unsigned char		sa_data[KST_SA_SIZE];
	/* Address length: passed to ->bind(), filled in by ->getname(). */
	int			sa_data_len;
};

/*
 * One worker kthread and the states it services.
 */
struct kst_worker
{
	/* Link in the global kst_worker_list (under kst_worker_mutex). */
	struct list_head	entry;

	/* Every registered kst_state of this worker; guarded by state_mutex. */
	struct list_head	state_list;
	struct mutex		state_mutex;
	
	/*
	 * States with pending wakeups.  Filled by kst_state_wake_callback() and
	 * drained by the thread; guarded by ready_lock, always taken irqsave.
	 */
	struct list_head	ready_list;
	spinlock_t		ready_lock;
	
	/* The "kst<id>" kthread running kst_thread_func(). */
	struct task_struct	*thread;

	/* The thread sleeps here until ready_list is non-empty. */
	wait_queue_head_t 	wait;
	
	/* Used to build the thread name. */
	int			id;
};

/*
 * One socket driven by the poll state machine (a listener or a connection).
 */
struct kst_state
{
	/* Link in w->state_list (under w->state_mutex). */
	struct list_head	entry;
	/* Link in w->ready_list; self-linked (empty) while not queued. */
	struct list_head	ready_entry;

	/* Entry placed on the socket's wait queue by kst_queue_func(). */
	wait_queue_t 		wait;
	/* Wait queue head that @wait was added to. */
	wait_queue_head_t 	*whead;

	struct kst_worker	*w;
	struct socket		*socket;

	struct kst_addr		addr;
	
	/* Socket creation/listen parameters used by kst_sock_create(). */
	int			backlog;
	int			type, proto;
	
	/* Poll mask; the callback runs only when ->poll() reports one of these. */
	unsigned int		events;
	/* Sets up the state; non-zero return aborts kst_state_init(). */
	int 			(*init)(struct kst_state *, void *);
	/* Handles ready events; a return <= 0 makes the worker tear the state down. */
	int 			(*callback)(struct kst_state *, unsigned int);
	/* Releases resources acquired by ->init (called from kst_state_exit()). */
	void 			(*exit)(struct kst_state *);
};

/*
 * Temporary poll_table wrapper: kst_queue_func() receives only the
 * poll_table pointer and recovers the owning state via container_of().
 */
struct kst_poll_helper
{
	poll_table 		pt;
	struct kst_state	*st;
};

/*
 * All running workers.  kst_worker_init()/kst_worker_exit() modify it under
 * kst_worker_mutex; kst_exit() walks it without the mutex at unload time.
 */
static LIST_HEAD(kst_worker_list);
static DEFINE_MUTEX(kst_worker_mutex);

static int kst_sock_create(struct kst_state *st)
{
	int err;

	err = sock_create(st->addr.sa_family, st->type, 
			st->proto, &st->socket);
	if (err)
		goto err_out_exit;

	err = st->socket->ops->bind(st->socket, (struct sockaddr *)&st->addr, 
			st->addr.sa_data_len);

	err = st->socket->ops->listen(st->socket, st->backlog);
	if (err)
		goto err_out_release;

	return 0;

err_out_release:
	sock_release(st->socket);
err_out_exit:
	return err;
}

/*
 * kst_sock_release - release the socket owned by @st.
 */
static void kst_sock_release(struct kst_state *st)
{
	struct socket *sock = st->socket;

	sock_release(sock);
}

/*
 * kst_state_wake_callback - wait-queue hook installed by kst_queue_func().
 *
 * This does not wake a task directly.  It queues the owning state on its
 * worker's ready list (at most once) and then wakes the worker thread.
 * Always returns 1.
 */
static int kst_state_wake_callback(wait_queue_t *wait, unsigned mode, int sync, void *key)
{
	struct kst_state *state = container_of(wait, struct kst_state, wait);
	struct kst_worker *worker = state->w;
	unsigned long irqflags;

	spin_lock_irqsave(&worker->ready_lock, irqflags);
	/* A self-linked ready_entry means the state is not queued yet. */
	if (list_empty(&state->ready_entry))
		list_add_tail(&state->ready_entry, &worker->ready_list);
	spin_unlock_irqrestore(&worker->ready_lock, irqflags);

	wake_up(&worker->wait);
	return 1;
}

/*
 * kst_queue_func - poll_table queueing callback.
 *
 * Invoked by the socket's ->poll() for its wait queue head.  Records the
 * head in the state and hooks kst_state_wake_callback() onto it.
 */
static void kst_queue_func(struct file *file, wait_queue_head_t *whead,
				 poll_table *pt)
{
	struct kst_poll_helper *helper = container_of(pt, struct kst_poll_helper, pt);
	struct kst_state *st = helper->st;

	init_waitqueue_func_entry(&st->wait, kst_state_wake_callback);
	st->whead = whead;
	add_wait_queue(whead, &st->wait);
}

static void kst_poll_exit(struct kst_state *st)
{
	remove_wait_queue(st->whead, &st->wait);
}

static int kst_poll_init(struct kst_state *st)
{
	struct kst_poll_helper ph;
	unsigned int revents;

	ph.st = st;

	init_poll_funcptr(&ph.pt, &kst_queue_func);

	revents = st->socket->ops->poll(NULL, st->socket, &ph.pt);
	if (revents & st->events) {
		int ret = st->callback(st, revents);

		if (ret <= 0) {
			kst_poll_exit(st);
			return ret;
		}
	}

	return 0;
}

static struct kst_state *kst_state_init(struct kst_worker *w, 
		int (*init)(struct kst_state *, void *),
		int (*callback)(struct kst_state *, unsigned int),
		void (*exit)(struct kst_state *),
		unsigned int events, void *data)
{
	struct kst_state *st;
	int err;

	st = kzalloc(sizeof(struct kst_state), GFP_KERNEL);
	if (!st)
		return ERR_PTR(-ENOMEM);

	st->events = events;
	st->init = init;
	st->exit = exit;
	st->callback = callback;
	st->w = w;
	INIT_LIST_HEAD(&st->ready_entry);
	INIT_LIST_HEAD(&st->entry);

	err = st->init(st, data);
	if (err)
		goto err_out_free;

	mutex_lock(&w->state_mutex);
	list_add_tail(&st->entry, &w->state_list);
	mutex_unlock(&w->state_mutex);

	return st;

err_out_free:
	kfree(st);
	return ERR_PTR(err);
}

/*
 * kst_state_exit - unregister @st from its worker, run ->exit and free it.
 */
static void kst_state_exit(struct kst_state *st)
{
	struct kst_worker *owner = st->w;

	printk("%s: st: %p.\n", __func__, st);

	mutex_lock(&owner->state_mutex);
	list_del_init(&st->entry);
	mutex_unlock(&owner->state_mutex);

	/* ->exit may still use the state, so it runs before the free. */
	st->exit(st);
	kfree(st);
}

/*
 * kst_worker_pop_ready - detach and return the first ready state of @w.
 *
 * Returns NULL when the ready list is empty.  The lock is taken irqsave
 * because kst_state_wake_callback() also modifies the list that way.
 */
static struct kst_state *kst_worker_pop_ready(struct kst_worker *w)
{
	struct kst_state *first = NULL;
	unsigned long irqflags;

	spin_lock_irqsave(&w->ready_lock, irqflags);
	if (!list_empty(&w->ready_list)) {
		first = list_entry(w->ready_list.next, struct kst_state, ready_entry);
		list_del_init(&first->ready_entry);
	}
	spin_unlock_irqrestore(&w->ready_lock, irqflags);

	return first;
}

/*
 * kst_thread_func - main loop of a worker kthread.
 *
 * Sleeps (at most HZ jiffies per round) until a state is ready or the
 * thread is asked to stop.  It then handles one ready state per round:
 * re-polls its socket, runs ->callback if any wanted event is pending, and
 * tears the state down when the callback returns <= 0.  Always returns 0.
 */
static int kst_thread_func(void *data)
{
	struct kst_worker *w = data;
	int err = 0;

	while (!kthread_should_stop()) {
		struct kst_state *st;
		unsigned int revents;
		int ret;

		wait_event_interruptible_timeout(w->wait, 
				!list_empty(&w->ready_list) || kthread_should_stop(), 
				HZ);

		st = kst_worker_pop_ready(w);
		if (st == NULL)
			continue;

		revents = st->socket->ops->poll(st->socket->file, st->socket, NULL);
		printk("%s: st: %p, revents: %x, events: %x.\n", __func__, st, revents, st->events);
		if (!(revents & st->events))
			continue;

		ret = st->callback(st, revents);
		printk("%s: callback returned, st: %p, ret: %d.\n", __func__, st, ret);
		if (ret <= 0)
			kst_state_exit(st);
	}

	return err;
}

static struct kst_worker *kst_worker_init(int id)
{
	struct kst_worker *w;
	int err;

	w = kzalloc(sizeof(struct kst_worker), GFP_KERNEL);
	if (!w)
		return ERR_PTR(-ENOMEM);

	w->id = id;
	init_waitqueue_head(&w->wait);
	spin_lock_init(&w->ready_lock);
	mutex_init(&w->state_mutex);

	INIT_LIST_HEAD(&w->ready_list);
	INIT_LIST_HEAD(&w->state_list);

	w->thread = kthread_run(&kst_thread_func, w, "kst%d", w->id);
	if (IS_ERR(w->thread)) {
		err = PTR_ERR(w->thread);
		goto err_out_free;
	}

	mutex_lock(&kst_worker_mutex);
	list_add_tail(&w->entry, &kst_worker_list);
	mutex_unlock(&kst_worker_mutex);

	return w;

err_out_free:
	kfree(w);
	return ERR_PTR(err);
}

static void kst_worker_exit(struct kst_worker *w)
{
	struct kst_state *st, *n;

	mutex_lock(&kst_worker_mutex);
	list_del(&w->entry);
	mutex_unlock(&kst_worker_mutex);

	kthread_stop(w->thread);

	list_for_each_entry_safe(st, n, &w->state_list, entry) {
		kst_state_exit(st);
	}

	kfree(w);
}

static void kst_common_exit(struct kst_state *st)
{
	unsigned long flags;

	printk("%s: st: %p.\n", __func__, st);
	kst_poll_exit(st);

	spin_lock_irqsave(&st->w->ready_lock, flags);
	list_del_init(&st->ready_entry);
	spin_unlock_irqrestore(&st->w->ready_lock, flags);

	kst_sock_release(st);
}

static int kst_data_callback(struct kst_state *st, unsigned int revents)
{
	unsigned char data[128];
	struct msghdr msg;
	struct kvec iov;
	int err;

	do {
		st->socket->sk->sk_allocation = GFP_NOIO;
		iov.iov_base = data;
		iov.iov_len = sizeof(data);
		msg.msg_name = NULL;
		msg.msg_namelen = 0;
		msg.msg_control = NULL;
		msg.msg_controllen = 0;
		msg.msg_flags = MSG_NOSIGNAL | MSG_DONTWAIT;
		err = kernel_recvmsg(st->socket, &msg, &iov, 1, iov.iov_len, msg.msg_flags);
		printk("%s: st: %p, sock: %p, err: %d.\n", __func__, st, st->socket, err);
		if (err <= 0)
			break;
	} while (1);

	if (err == -EAGAIN)
		return 1;
	return err;
}

/*
 * kst_data_init - ->init hook for an accepted connection.
 * @data: the accepted struct socket; ownership stays with the caller if
 *        this fails.
 *
 * Returns 0 on success or the error from kst_poll_init().
 */
static int kst_data_init(struct kst_state *st, void *data)
{
	st->socket = data;

	return kst_poll_init(st);
}

static int kst_listen_callback(struct kst_state *st, unsigned int revents)
{
	struct socket *newsock;
	struct kst_addr addr;
	struct kst_state *newst;
	struct sockaddr_in *sin;
	int err;

	err = sock_create(st->addr.sa_family, st->type, st->proto, &newsock);
	if (err)
		goto err_out_exit;

	err = st->socket->ops->accept(st->socket, newsock, 0);
	if (err)
		goto err_out_put;
	
	if (newsock->ops->getname(newsock, (struct sockaddr *)&addr,
				  &addr.sa_data_len, 2) < 0) {
		err = -ECONNABORTED;
		goto err_out_put;
	}
	sin = (struct sockaddr_in *)&addr;

	printk("%s: Client: %u.%u.%u.%u:%d, sk: %p.\n", __func__, 
			NIPQUAD(sin->sin_addr.s_addr), ntohs(sin->sin_port), newsock);

	newst = kst_state_init(st->w, &kst_data_init, &kst_data_callback, 
			&kst_common_exit, POLLIN, newsock);
	if (IS_ERR(newst))
		goto err_out_put;
	
	memcpy(&newst->addr, &addr, sizeof(struct kst_addr));

	return 1;

err_out_put:
	sock_release(newsock);
err_out_exit:
	return 1;
}

/*
 * kst_listen_init - ->init hook for the TCP listener.
 * @data: struct kst_addr with the local address to bind to (copied).
 *
 * Creates a bound, listening TCP socket (backlog 100) and registers it for
 * poll wakeups.  The socket is released again if registration fails.
 */
static int kst_listen_init(struct kst_state *st, void *data)
{
	int err;

	memcpy(&st->addr, data, sizeof(struct kst_addr));
	st->type = SOCK_STREAM;
	st->proto = IPPROTO_TCP;
	st->backlog = 100;

	err = kst_sock_create(st);
	if (err)
		return err;

	err = kst_poll_init(st);
	if (err)
		kst_sock_release(st);

	return err;
}

static int kst_init(void)
{
	int err;
	struct kst_state *st;
	struct kst_worker *w;
	struct kst_addr addr;
	struct sockaddr_in *sin;

	w = kst_worker_init(128);
	if (IS_ERR(w))
		return PTR_ERR(w);

	memset(&addr, 0, sizeof(struct kst_addr));

	sin = (struct sockaddr_in *)&addr;

	sin->sin_family = AF_INET;
	sin->sin_port = htons(1025);
	sin->sin_addr.s_addr = htonl(INADDR_ANY);
	addr.sa_data_len = sizeof(struct sockaddr_in);

	st = kst_state_init(w, &kst_listen_init, &kst_listen_callback, 
			&kst_common_exit, POLLIN, &addr);
	if (IS_ERR(st)) {
		err = PTR_ERR(st);
		goto err_out_worker_exit;
	}

	return 0;

err_out_worker_exit:
	kst_worker_exit(w);
	return err;
}

static void kst_exit(void)
{
	struct kst_worker *w, *n;

	list_for_each_entry_safe(w, n, &kst_worker_list, entry) {
		kst_worker_exit(w);
	}
}

/* Load/unload hooks and module metadata. */
module_init(kst_init);
module_exit(kst_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Evgeniy Polyakov <johnpol@2ka.mipt.ru>");
MODULE_DESCRIPTION("Kernel ->poll() based state machine.");
