/*

Copyright (C) 2003 Christian Kreibich <christian@whoop.org>.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies of the Software and its documentation and acknowledgment shall be
given in the documentation and software packages that this Software was
used.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

*/
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include "honeycomb.h"
#include "hc_debug.h"
#include "hc_bitmaps.h"
#include "hc_tcp_conns.h"



/* Hash bucket type: a tail queue of HC_TCPConn entries, linked through
 * each connection's `conns` field. */
typedef TAILQ_HEAD(hc_tcp_conn_list, hc_tcp_conn) HC_TCPConnList;

/* Hash table of connection buckets; allocated by hc_tcp_conns_init(),
 * NULL until then. */
static HC_TCPConnList   *tcp_conns;
/* Number of buckets in tcp_conns (0 until hc_tcp_conns_init() succeeds). */
static u_int             tcp_conns_num_slots;
/* Number of stored connections above which the oldest one is evicted. */
static u_int             tcp_conns_max;
/* Number of connections currently stored across all buckets. */
static u_int             tcp_conns_cur;

/* Next key handed out by tcp_conn_new(); smaller keys mean older
 * connections. NOTE(review): this is a signed int, so after INT_MAX
 * connections the increment overflows (undefined behavior) and the
 * oldest-key eviction ordering breaks. */
static int               tcp_conns_key_counter;

/*
 * Hash a connection identifier by XOR-ing all four fields together.
 * Because XOR is commutative, swapping source and destination gives the
 * same value, so both directions of a connection land in one bucket.
 */
static u_int
tcp_conn_id_hash(HC_TCPConnID *id)
{
	u_int addrs = id->src_addr ^ id->dst_addr;
	u_int ports = id->src_port ^ id->dst_port;

	return addrs ^ ports;
}


/*
 * Return 1 if ID1 and ID2 describe the same endpoints in the same
 * direction (equal source address/port and equal destination
 * address/port), 0 otherwise.
 */
static int
tcp_conn_id_direct_match(HC_TCPConnID *id1, HC_TCPConnID *id2)
{
	int same_src = (id1->src_addr == id2->src_addr) &&
		       (id1->src_port == id2->src_port);
	int same_dst = (id1->dst_addr == id2->dst_addr) &&
		       (id1->dst_port == id2->dst_port);

	return same_src && same_dst;
}


/*
 * Return 1 if ID1 is ID2 with source and destination swapped (i.e. the
 * same connection seen from the opposite direction), 0 otherwise.
 */
static int
tcp_conn_id_reverse_match(HC_TCPConnID *id1, HC_TCPConnID *id2)
{
	/* id1's source endpoint must be id2's destination endpoint ... */
	if (id1->src_addr != id2->dst_addr || id1->src_port != id2->dst_port)
		return 0;

	/* ... and id1's destination endpoint must be id2's source endpoint. */
	return id1->dst_addr == id2->src_addr && id1->dst_port == id2->src_port;
}

/*
 * Direction-independent identifier comparison: returns 1 if ID1 matches
 * ID2 either as-is or with source and destination swapped, else 0.
 */
static int
tcp_conn_id_equal(HC_TCPConnID *id1, HC_TCPConnID *id2)
{
	return tcp_conn_id_direct_match(id1, id2) ||
	       tcp_conn_id_reverse_match(id1, id2);
}


/*
 * Allocate and initialise state for a new TCP connection.
 *
 * The connection identifier is taken from the packet's addresses and
 * ports. (iphdr->ip_hl + tcphdr->th_off) * 4 bytes starting at IPHDR are
 * copied into conn->hdr, so the TCP header must directly follow the IP
 * header in memory. The connection receives the next value of
 * tcp_conns_key_counter as its key, and stream_reversed starts at -1
 * ("no data seen yet").
 *
 * NOTE(review): assumes conn->hdr can hold up to
 * IP_HDR_LEN_MAX + TCP_HDR_LEN_MAX bytes -- confirm in hc_tcp_conns.h.
 *
 * Returns the new connection, or NULL if either header pointer is NULL
 * or an allocation fails (nothing is leaked in that case).
 */
static HC_TCPConn *
tcp_conn_new(struct ip_hdr *iphdr, struct tcp_hdr *tcphdr)
{
	HC_TCPConn *conn;

	/* Both headers are dereferenced below. */
	if (!iphdr || !tcphdr)
		return NULL;

	if (! (conn = calloc(1, sizeof(*conn)))) {
		D(("Out of memory\n"));
		return NULL;
	}

	if (! (conn->stream = hc_bitmap_new(HC_TCP_STREAM_PART_SIZE)))
		goto error_conn;

	if (! (conn->stream_headers = hc_bitmap_new(IP_HDR_LEN_MAX + TCP_HDR_LEN_MAX)))
		goto error_stream;

	conn->id.src_addr = iphdr->ip_src;
	conn->id.src_port = tcphdr->th_sport;
	conn->id.dst_addr = iphdr->ip_dst;
	conn->id.dst_port = tcphdr->th_dport;

	/* Copy TCP and IP header into connection structure: */
	memcpy(conn->hdr, iphdr, (iphdr->ip_hl + tcphdr->th_off) << 2);

	conn->key = tcp_conns_key_counter++;

	/* Mark stream direction as "no data seen yet". */
	conn->stream_reversed = -1;

	return conn;

 error_stream:
	hc_bitmap_free(conn->stream);
 error_conn:
	free(conn);
	return NULL;
}


/*
 * Release a connection and both of its bitmaps (payload stream and
 * per-message headers). Passing NULL is a no-op.
 */
static void
tcp_conn_free(HC_TCPConn *conn)
{
	/* Check before touching conn->stream: the original freed the
	 * bitmaps first and only then tested conn for NULL. */
	if (!conn)
		return;

	hc_bitmap_free(conn->stream);
	hc_bitmap_free(conn->stream_headers);
	free(conn);
}


/*
 * Allocate the connection hash table with HC_TCP_CONNSHASH_SIZE empty
 * buckets and set the stored-connection limit to MAX_NUM_CONNS.
 *
 * Does nothing if the table already exists or MAX_NUM_CONNS is 0. If the
 * allocation fails the module stays uninitialised (tcp_conns == NULL).
 */
void
hc_tcp_conns_init(u_int max_num_conns)
{
	HC_TCPConnList *table;
	u_int slot;

	/* Already set up, or a zero limit requested: nothing to do. */
	if (tcp_conns != NULL || max_num_conns == 0)
		return;

	table = malloc(HC_TCP_CONNSHASH_SIZE * sizeof(HC_TCPConnList));
	if (table == NULL) {
		D(("Out of memory.\n"));
		return;
	}

	for (slot = 0; slot < HC_TCP_CONNSHASH_SIZE; slot++)
		TAILQ_INIT(&table[slot]);

	tcp_conns           = table;
	tcp_conns_num_slots = HC_TCP_CONNSHASH_SIZE;
	tcp_conns_max       = max_num_conns;
	tcp_conns_cur       = 0;
}




/*
 * Look up the active connection with the given endpoints. The match is
 * direction-independent (source/destination may be swapped). Connections
 * marked terminated are skipped.
 *
 * Returns the connection, or NULL if none matches or the table has not
 * been set up by hc_tcp_conns_init().
 */
HC_TCPConn *
hc_tcp_conns_find(ip_addr_t src_addr, uint16_t src_port,
		  ip_addr_t dst_addr, uint16_t dst_port)
{
	HC_TCPConnID id;
	HC_TCPConn *conn;
	HC_TCPConnList *list;
	u_int hash;

	/* Without a table tcp_conns_num_slots is 0, so the modulo below
	 * would divide by zero and tcp_conns[] would be a NULL dereference. */
	if (!tcp_conns)
		return NULL;

	id.src_addr = src_addr; id.src_port = src_port;
	id.dst_addr = dst_addr; id.dst_port = dst_port;

	hash = tcp_conn_id_hash(&id) % tcp_conns_num_slots;
	list = &tcp_conns[hash];

	for (conn = list->tqh_first; conn; conn = conn->conns.tqe_next) {

		/* Skip old connections, they're not here to be actively
		 * looked up but only for pattern detection!
		 */
		if (conn->terminated)
			continue;

		if (tcp_conn_id_equal(&id, &conn->id))
			return conn;
	}

	D(("TCP connection lookup failed, hash was %u\n", hash));
	return NULL;
}


/*
 * Remove and free the stored connection with the smallest key, i.e. the
 * earliest created one, scanning every hash bucket. KEEP is never chosen
 * (the caller passes the connection it has just added). Does nothing if
 * no other connection exists.
 */
static void
tcp_conns_evict_oldest(HC_TCPConn *keep)
{
	HC_TCPConnList *oldest_list = NULL;
	HC_TCPConn     *oldest_conn = NULL, *cur_conn;
	int             oldest_key = tcp_conns_key_counter;
	u_int           i;

	/* Full scan of the table; if this proves too slow the algorithm
	 * can be changed later. */
	for (i = 0; i < tcp_conns_num_slots; i++) {
		for (cur_conn = tcp_conns[i].tqh_first; cur_conn;
		     cur_conn = cur_conn->conns.tqe_next) {

			if (cur_conn != keep && cur_conn->key < oldest_key) {
				oldest_list = &tcp_conns[i];
				oldest_key = cur_conn->key;
				oldest_conn = cur_conn;
			}
		}
	}

	D_ASSERT_PTR(oldest_conn);

	/* In case D_ASSERT_PTR does not stop execution (e.g. in non-debug
	 * builds -- confirm in hc_debug.h), never hand NULL to TAILQ_REMOVE. */
	if (!oldest_conn)
		return;

	TAILQ_REMOVE(oldest_list, oldest_conn, conns);
	tcp_conn_free(oldest_conn);
	tcp_conns_cur--;
}


/*
 * Return the active connection for the packet with headers IPHDR/TCPHDR,
 * creating and storing a new one if none exists. When storing it pushes
 * the count above tcp_conns_max, the oldest other connection is evicted.
 *
 * Returns NULL if the table is not initialised, a header pointer is NULL,
 * or the new connection cannot be allocated.
 */
HC_TCPConn *
hc_tcp_conns_add(struct ip_hdr *iphdr, struct tcp_hdr *tcphdr)
{
	HC_TCPConnID id;
	HC_TCPConn *conn;
	HC_TCPConnList *list;

	/* Uninitialised table would mean a modulo by zero below. */
	if (!tcp_conns || !iphdr || !tcphdr)
		return NULL;

	/* No need to ntohs() ports, we only hash on them. */
	id.src_addr = iphdr->ip_src; id.dst_addr = iphdr->ip_dst;
	id.src_port = tcphdr->th_sport; id.dst_port = tcphdr->th_dport;

	if ( (conn = hc_tcp_conns_find(id.src_addr, id.src_port,
				       id.dst_addr, id.dst_port))) {
		D(("Not creating new TCP connection state, already there\n"));
		return conn;
	}

	D(("Creating new TCP connection state\n"));
	/* Connection not found, create new one, add it, and return it! */
	list = &tcp_conns[tcp_conn_id_hash(&id) % tcp_conns_num_slots];

	if (! (conn = tcp_conn_new(iphdr, tcphdr)))
		return NULL;

	TAILQ_INSERT_TAIL(list, conn, conns);
	tcp_conns_cur++;

	/* Too many connections: nuke the one with the oldest key. */
	if (tcp_conns_cur > tcp_conns_max)
		tcp_conns_evict_oldest(conn);

	return conn;
}


/*
 * Remove and free the first connection in LIST whose identifier equals
 * ID in either direction; terminated connections are not skipped. Does
 * nothing if either argument is NULL or no entry matches.
 */
static void
tcp_conns_drop_in_list(HC_TCPConnList *list, HC_TCPConnID *id)
{
	HC_TCPConn *victim;

	if (list == NULL || id == NULL)
		return;

	victim = list->tqh_first;
	while (victim != NULL && !tcp_conn_id_equal(id, &victim->id))
		victim = victim->conns.tqe_next;

	if (victim == NULL)
		return;

	TAILQ_REMOVE(list, victim, conns);
	tcp_conn_free(victim);
	tcp_conns_cur--;
}


/*
 * Remove and free the first stored connection matching the given
 * endpoints in either direction (terminated connections included).
 * Does nothing if no connection matches or the table has not been set
 * up by hc_tcp_conns_init().
 */
void
hc_tcp_conns_drop(ip_addr_t src_addr, uint16_t src_port,
		  ip_addr_t dst_addr, uint16_t dst_port)
{
	HC_TCPConnID id;
	HC_TCPConnList *list;

	/* tcp_conns_num_slots is 0 until init succeeds; avoid modulo by zero. */
	if (!tcp_conns)
		return;

	id.src_addr = src_addr; id.src_port = src_port;
	id.dst_addr = dst_addr; id.dst_port = dst_port;

	list = &tcp_conns[tcp_conn_id_hash(&id) % tcp_conns_num_slots];
	tcp_conns_drop_in_list(list, &id);
}


/*
 * Invoke CALLBACK(conn, USER_DATA) for every stored connection,
 * terminated ones included, bucket by bucket. Does nothing if CALLBACK
 * is NULL.
 *
 * Each connection's successor is read before the callback runs, so the
 * callback may drop/free the connection it is handed; it must not remove
 * any *other* connection during the walk.
 */
void
hc_tcp_conns_foreach(HC_TCPConnCB callback, void *user_data)
{
	HC_TCPConn *conn, *next;
	u_int i;

	if (!callback)
		return;

	for (i = 0; i < tcp_conns_num_slots; i++) {

		for (conn = tcp_conns[i].tqh_first; conn; conn = next) {
			next = conn->conns.tqe_next;
			callback(conn, user_data);
		}
	}
}


static void          
tcp_conn_add_data(HC_TCPConn *conn, struct ip_hdr *iphdr)
{
	HC_TCPConnID id;
	HC_Blob *blob;
	struct tcp_hdr *tcphdr;
	int payload_size, header_size;
	u_char *payload;
	
	/* Don't do anything on invalid input or if this connection
	 * has already exchanged too many messages and we've lost interest.
	 */
	if (!conn || !iphdr ||
	    conn->stream->num_blobs == HC_TCP_STREAM_NUM_PARTS)
		return;
	
	tcphdr = (struct tcp_hdr *) ((u_char *) iphdr + (iphdr->ip_hl << 2));
	
	/* Check if we actually have any payload: */	
	header_size = ((tcphdr->th_off + iphdr->ip_hl) << 2);
	payload_size = ntohs(iphdr->ip_len) - header_size;

	if (payload_size <= 0)
		return;
	
	D(("Adding %i bytes to connection\n", payload_size));

	payload = (u_char *) tcphdr + (tcphdr->th_off << 2);
	id.src_addr = iphdr->ip_src; id.src_port = tcphdr->th_sport;
	id.dst_addr = iphdr->ip_dst; id.dst_port = tcphdr->th_dport;

	/* Okay, we have data. Now figure out what direction it is and
	 * what was the last direction we've seen data flow in. If it
	 * changed, add a new blob to the stream, otherwise add as much
	 * data of this packet to the last blob as possible/necessary.
	 */
	if (conn->stream_reversed < 0) {

		blob = conn->stream->blobs.tqh_first;
		hc_blob_add_data(blob, payload, payload_size);
		D(("First blob now %u %u\n", blob->data_used, blob->data_len));

		blob = conn->stream_headers->blobs.tqh_first;
		hc_blob_add_data(blob, (u_char*) iphdr, header_size);
		
		if (tcp_conn_id_direct_match(&id, &conn->id))
			conn->stream_reversed = 0;
		else
			conn->stream_reversed = 1;

		return;
	}

	blob = hc_bitmap_get_last_blob(conn->stream);
	D_ASSERT_PTR(blob);

	if (tcp_conn_id_reverse_match(&id, &conn->id)) {
		
		if (conn->stream_reversed) {
			hc_blob_add_data(blob, payload, payload_size);
		} else {
			hc_blob_crop(blob);
			
			hc_bitmap_add_blob(conn->stream_headers, (u_char *) iphdr, header_size,
					   IP_HDR_LEN_MAX + TCP_HDR_LEN_MAX);
			hc_bitmap_add_blob(conn->stream, payload, payload_size,
					   HC_TCP_STREAM_PART_SIZE);
		}

		conn->stream_reversed = 1;

	} else {
		
		if (conn->stream_reversed) {
			hc_blob_crop(blob);

			hc_bitmap_add_blob(conn->stream_headers, (u_char *) iphdr, header_size,
					   IP_HDR_LEN_MAX + TCP_HDR_LEN_MAX);
			hc_bitmap_add_blob(conn->stream, payload, payload_size,
					   HC_TCP_STREAM_PART_SIZE);
		} else {
			hc_blob_add_data(blob, payload, payload_size);
		}

		conn->stream_reversed = 0;
	}
}


/*
 * Return non-zero if TCP sequence number A is at or after B. Uses
 * modulo-2^32 serial-number arithmetic so the comparison stays correct
 * when the sequence space wraps around.
 */
static int
tcp_seq_geq(uint32_t a, uint32_t b)
{
	return (int32_t) (a - b) >= 0;
}


/*
 * Update CONN's teardown state from the packet at IPHDR, then record the
 * packet's payload via tcp_conn_add_data().
 *
 * Fields maintained:
 *  - answered:        set once a packet is seen in the reverse direction;
 *  - fin / fin_back:  acknowledgment number that acknowledges the FIN sent
 *                     in the forward / reverse direction;
 *  - fin_acked / fin_back_acked: set once the other side acknowledged it;
 *  - terminated:      set on RST, or once both FINs have been acknowledged.
 *
 * NOTE(review): conn->fin == 0 doubles as "no FIN seen", so a FIN whose
 * acknowledgment number wraps to exactly 0 is not tracked.
 */
void
hc_tcp_conn_update_state(HC_TCPConn *conn, struct ip_hdr *iphdr)
{
	HC_TCPConnID id;
	struct tcp_hdr *tcphdr;
	int reverse_match, has_ack;
	uint32_t ack;

	if (!conn || !iphdr)
		return;

	tcphdr = (struct tcp_hdr *) ((u_char *) iphdr + (iphdr->ip_hl << 2));

	id.src_addr = iphdr->ip_src; id.src_port = tcphdr->th_sport;
	id.dst_addr = iphdr->ip_dst; id.dst_port = tcphdr->th_dport;

	reverse_match = tcp_conn_id_reverse_match(&conn->id, &id);

	if (reverse_match && conn->answered == 0) {
		D(("TCP connection has been answered\n"));
		conn->answered = 1;
	}

	/* If we see a FIN, track orderly connection teardown. The FIN
	 * occupies the sequence number after any payload carried in the same
	 * segment, so it is acknowledged by seq + payload length + 1.
	 */
	if (tcphdr->th_flags & TH_FIN) {
		int payload_size;
		uint32_t fin_ack;

		payload_size = ntohs(iphdr->ip_len) -
			((iphdr->ip_hl + tcphdr->th_off) << 2);
		if (payload_size < 0)
			payload_size = 0;

		fin_ack = (uint32_t) ntohl(tcphdr->th_seq) +
			(uint32_t) payload_size + 1;

		if (! reverse_match) {
			conn->fin = fin_ack;
		} else {
			conn->fin_back = fin_ack;
		}
	}

	/* th_ack only carries a valid acknowledgment number when ACK is set. */
	has_ack = (tcphdr->th_flags & TH_ACK) != 0;
	ack = ntohl(tcphdr->th_ack);

	if (has_ack && conn->fin && reverse_match &&
	    tcp_seq_geq(ack, conn->fin) &&
	    !conn->fin_acked) {

		/* --> FIN was seen, now we see <-- ACK covering it, thus
		 * the source side shutdown is complete.
		 */
		D(("FIN --> acked.\n"));
		conn->fin_acked = 1;
	}

	if (has_ack && conn->fin_back && !reverse_match &&
	    tcp_seq_geq(ack, conn->fin_back) &&
	    !conn->fin_back_acked) {

		/* <-- FIN was seen, now we see --> ACK covering it, thus
		 * the dest side shutdown is complete.
		 */
		D(("FIN <-- acked.\n"));
		conn->fin_back_acked = 1;
	}

	/* Drop this connection once we see a reset or if we've completed
	 * the shutdown procedure:
	 */
	if ((tcphdr->th_flags & TH_RST) ||
	    (conn->fin_acked && conn->fin_back_acked)) {

		D(("Connection terminated.\n"));
		conn->terminated = 1;
	}

	tcp_conn_add_data(conn, iphdr);
}


/*
 * Number of messages (direction-delimited payload blobs) recorded for
 * CONN; 0 when CONN is NULL.
 */
u_int
hc_tcp_conn_get_num_messages(HC_TCPConn *conn)
{
	return conn ? conn->stream->num_blobs : 0;
}


/*
 * Return the payload blob of CONN's message at zero-based index NUM, or
 * NULL if CONN is NULL or NUM is not a valid index.
 */
HC_Blob *
hc_tcp_conn_get_nth_message(HC_TCPConn *conn, u_int num)
{
	HC_Blob *blob;

	/* Valid indices are 0 .. num_blobs - 1. */
	if (!conn || num >= conn->stream->num_blobs)
		return NULL;

	for (blob = conn->stream->blobs.tqh_first; blob && num > 0; num--)
		blob = blob->items.tqe_next;

	return blob;
}


/*
 * Return the IP header (followed by the TCP header) stored for CONN's
 * message at zero-based index NUM, or NULL if CONN is NULL or NUM is not
 * a valid index.
 */
struct ip_hdr *
hc_tcp_conn_get_nth_message_header(HC_TCPConn *conn, u_int num)
{
	HC_Blob *blob;

	/* Valid indices are 0 .. num_blobs - 1. */
	if (!conn || num >= conn->stream_headers->num_blobs)
		return NULL;

	for (blob = conn->stream_headers->blobs.tqh_first; blob && num > 0; num--)
		blob = blob->items.tqe_next;

	return blob ? (struct ip_hdr *) blob->data : NULL;
}
