/*
 * COPYRIGHT    2001
 * THE REGENTS OF THE UNIVERSITY OF MICHIGAN
 * ALL RIGHTS RESERVED
 *
 * Permission is granted to use, copy, create derivative works
 * and redistribute this software and such derivative works
 * for any purpose, so long as the name of The University of
 * Michigan is not used in any advertising or publicity
 * pertaining to the use of distribution of this software
 * without specific, written prior authorization.  If the
 * above copyright notice or any other identification of the
 * University of Michigan is included in any copy of any
 * portion of this software, then the disclaimer below must
 * also be included.
 *
 * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION
 * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY
 * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY O
 * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
 * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE
 * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE
 * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR
 * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING
 * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN
 * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGES.
 */

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <net/if.h>
#include <netinet/in.h>
#include <netinet/if_ether.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>
#ifdef LOCAL_BPF_HEADERS
/* So we can build this on a system w/o kernel changes */
#include "../sys/net/bpf.h"
#else
#include "/usr/src/sys/net/bpf.h"
#endif

#include <err.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#ifdef __FreeBSD__
typedef u_int32_t in_addr_t;
#endif

#include "../bpf/parse_bpf.h"
#include "../crypto/crypto.h"
#include "../dump/cache.h"
#include "crypto_file.h"

/* Size in bytes of the fixed (option-less) part of an IPv4 header.
   NOTE(review): packets with IP options are not accounted for by the
   offsets computed from this constant. */
#define IP_HDR_SIZE 20

/* Algorithm index of the most recently opened input file; assigned by
   open_input_crypto_file().  Kept as a global because the btree, hash and
   list code expects a global cryptalg variable. */
int cryptalg;

struct crypto_file * open_input_crypto_file(char * filename) {
  struct crypto_file * new;
  int blocklen;

  if ((new = malloc(sizeof(*new))) == NULL)
    errx(1,"unable to allocate memory for struct crypto_file");
  if ((new->filename = malloc(strlen(filename)+1)) == NULL)
    errx(1,"unable to allocate memory for filename %s",filename);
  memcpy(new->filename,filename,strlen(filename)+1);

  if ((new->file=fopen(filename,"r")) == NULL)
    errx(1,"Error opening input file %s\n",new->filename);
  if (fread(&new->format,sizeof(new->format),1,new->file) != 1)
    errx(1,"Error reading file format from input file %s",new->filename);
  if ((new->format < 0) || (new->format > 2))
    errx(1,"Unknown crypto file format %d in file %s",new->format,
	 new->filename);

  if (fread(&new->cryptalg,sizeof(new->cryptalg),1,new->file) != 1)
    errx(1,"Error reading cryptalg from %s",new->filename);
  if ((new->cryptalg<0) || (new->cryptalg >= NUM_ALGORITHMS))
    errx(1,"%s uses unknown crypto algorithm %d",new->filename, 
	 new->cryptalg);

  if (new->format != 3) {
    blocklen = algorithms[new->cryptalg].block_length;
    if (fread(new->iv,1,blocklen,new->file) != blocklen)
      errx(1,"%s truncated before iv",new->filename);
  }
  /* fix me: btree, hash, list code expects a global cryptalg variable. */
  cryptalg = new->cryptalg;
  return(new);
}

void try_to_decrypt_packet(struct crypto_packet * packet,
			   struct crypto_key_ring * keys, char * iv) {
  struct crypto_session conv_key_session;
  int blocklen = algorithms[packet->cryptalg].block_length;
  struct crypto_session * conv_key_session_ptr;

  if (!keys)
    goto fail;

  if (packet->format == 0) {
    char * conv_key;
    struct bpf_hdr * bpf = (struct bpf_hdr *)(packet->data);
    char * ether_hdr = packet->data + BPF_HDR_SIZE;
    struct ip * ip_hdr = (struct ip *)(ether_hdr + ETHER_HDR_LEN);

    if (keys->vol_key && keys->tr_key) { /* we're decrypting everything */
      if (!reveal_address(&ip_hdr->ip_src.s_addr))
	printf("WARNING: can't find substituted ip 0x%08x in table\n",
	       ntohl(ip_hdr->ip_src.s_addr));
      if (!reveal_address(&ip_hdr->ip_dst.s_addr))
	printf("WARNING: can't find substituted ip 0x%08x in table\n",
	       ntohl(ip_hdr->ip_dst.s_addr));
      conv_key = make_conv_key(keys->vol_key,ip_hdr,packet->cryptalg);
      get_crypto_session(&conv_key_session,
			 packet->cryptalg,CBC_MODE,DECRYPT,conv_key,0);
      conv_key_session_ptr = &conv_key_session;
      
    } else { /* look for a conversation key */
      if (!keys->conv_key) {
	goto fail;
      }
      if ((keys->conv_key->fake_src == ip_hdr->ip_src.s_addr) &&
	  (keys->conv_key->fake_dst == ip_hdr->ip_dst.s_addr)) {
	ip_hdr->ip_src.s_addr = keys->conv_key->real_src;
	ip_hdr->ip_dst.s_addr = keys->conv_key->real_dst;
	conv_key_session_ptr = &keys->conv_key->session;
      } else {
	goto fail;
      }
    } /* At this point we know conv_key_session_ptr points to a key session 
	 for this packet. */
    /* Decrypt ethernet header (plus first 2 bytes of ip header): */
    crypt_buffer(conv_key_session_ptr,ether_hdr,16,iv);
    /* Decrypt IP payload: */
    crypt_buffer(conv_key_session_ptr,(char *)ip_hdr+IP_HDR_SIZE,
		 bpf->bh_caplen-IP_HDR_SIZE-ETHER_HDR_LEN,iv);
    goto succeed;

  } else { /* format == 1 or 2 */
    char * pkt_end_ptr = packet->data + packet->data_length;
    char * timeval_ptr;
    struct ip * ip_hdr;
    int keylen = algorithms[packet->cryptalg].key_length;

    /* note we assume here keys don't need to be padded (so key length is
       a multiple of block length) */
    if (packet->format == 1)
      timeval_ptr = packet->data + keylen;
    else /* format == 2 */
      timeval_ptr = packet->data + 3*keylen;

    /* Note that for now we're assuming that we'll only get one type
       of key; if both a tr_key and a conv_key are given, for example, the
       code below will be incorrect.*/
    if (keys->conv_key) {
      conv_key_session_ptr = &keys->conv_key->session;
      memcpy(iv,timeval_ptr-blocklen,blocklen);
    } else { /* We have to recover a conversation key */
      char * conv_key_ptr = NULL;
      if (keys->src_key) {
	conv_key_ptr = packet->data;
	crypt_buffer(&keys->src_key->session,conv_key_ptr,keylen,iv);
	memcpy(iv,timeval_ptr-blocklen,blocklen);
      } else if (keys->dst_key) {  
	conv_key_ptr = packet->data + keylen;
	memcpy(iv,conv_key_ptr - blocklen,blocklen);
	crypt_buffer(&keys->dst_key->session,conv_key_ptr,keylen,iv);
	memcpy(iv,timeval_ptr-blocklen,blocklen);
      } else if (keys->tr_key) {
	if (packet->format == 2) {
	  conv_key_ptr = packet->data + 2*keylen;
	  memcpy(iv,conv_key_ptr - blocklen,blocklen);
	} else {
	  conv_key_ptr = packet->data;
	}
	crypt_buffer(keys->tr_key,conv_key_ptr,keylen,iv);
      } /* Now we've found and tried to decrypt a conversation key */
      get_crypto_session(&conv_key_session,packet->cryptalg,
			 CBC_MODE,DECRYPT,conv_key_ptr,0);
      conv_key_session_ptr = &conv_key_session;
    } /* and now we've got a conversation key session */

    crypt_buffer(conv_key_session_ptr,timeval_ptr,pkt_end_ptr-timeval_ptr,iv);

    /* Except in the case where we used a tr_key, we don't know that
       the conversation key we used was actually the right one.  If it
       was, then the IP addresses on the now-decrypted packet should
       match.  If not, then there's still a 1 in 2^32 chance that
       they'll match anyway (for endpoint keys, where we check one IP
       address; 1 in 2^64 for conversation keys).  But we're willing to
       live with those kinds of odds. */ 
    if (keys->tr_key) {
      goto succeed;
    }
    ip_hdr = (struct ip *)(timeval_ptr + sizeof(struct timeval) 
			   + ETHER_HDR_LEN);
    if ((keys->conv_key) && 
	(ip_hdr->ip_src.s_addr == keys->conv_key->real_src) &&
	(ip_hdr->ip_dst.s_addr == keys->conv_key->real_dst)) {
      goto succeed;
    }
    if ((keys->src_key) &&
	(ip_hdr->ip_src.s_addr == keys->src_key->addr)) {
      goto succeed;
    }
    if ((keys->dst_key) &&
	(ip_hdr->ip_dst.s_addr == keys->dst_key->addr)) {
      goto succeed;
    }
    /* note that we do not want to "goto fail" at this point; we've already
       reset the iv by attempting decryption of the packet, and "goto fail"
       will set the iv to the wrong thing! */
    return;
  }
  errx(1,"I thought we couldn't get to this point in try_to_decrypt_packet");
  return;
fail:
  memcpy(iv,packet->data + packet->total_length - blocklen,blocklen);
  return;
succeed:
  packet->encrypted = 0;
  return;
}

/*
 * Read every packet from infile, try to decrypt each with the keys on the
 * ring, and pass it to handle_packet(infile, packet, user).
 *
 * The packet handed to the callback is a local of this function that is
 * overwritten by the next packet.  Callbacks must copy anything they keep
 * and must not free it.  Format 3 (whole-file CBC) is not implemented and
 * terminates the program.
 */
void read_crypto_file(struct crypto_file * infile,
		      struct crypto_key_ring * keys,
		      crypto_infile_callback handle_packet,
		      void * user) {
  struct crypto_packet this_packet;
  char * buffer;
  int length;

  this_packet.cryptalg=infile->cryptalg;
  this_packet.format=infile->format;
  if (infile->format == 3) {
    /* read and decrypt the whole thing right now. */
    read_cbc_encrypted_buffer(keys->seg_key,infile->file,
			      infile->filename,&buffer,&length);
    /* code to step through file, return packets */
    errx(1,"-z3 not implemented in crypto_file yet.");
  }
  /* infile->format = 0, 1, or 2 */
  while (1) {
    if (infile->format == 0) {
      struct bpf_hdr * bpf = (struct bpf_hdr *)(this_packet.data);
      if (read_bpf(infile->file, bpf,this_packet.data+BPF_HDR_SIZE, 0) != 0)
	break;
      /* pkt_dump has to extend a packet's length to the next higher
       * multiple of the crypto block size, and adjusts bh_caplen
       * accordingly, but  bh_datalen remains set to the actual packet
       * length.  Note that the pkt_dump output files thus have
       * bh_caplen >= bh_datalen.
       */
      this_packet.data_length = bpf->bh_datalen + BPF_HDR_SIZE;
      this_packet.total_length = bpf->bh_caplen + BPF_HDR_SIZE;
      /* sanity check: */
      if (this_packet.total_length !=
	  get_encrypted_length(infile->cryptalg,this_packet.data_length 
			       - BPF_HDR_SIZE - IP_HDR_SIZE - ETHER_HDR_LEN) 
	  + BPF_HDR_SIZE + IP_HDR_SIZE + ETHER_HDR_LEN)
	errx(1,"data_length = %d, total_length = %d inconsistent in "
	     "read_crypto_file",this_packet.data_length,
	     this_packet.total_length);
    } else { /* format == 1 or 2 */
      if (fread(&this_packet.data_length,sizeof(this_packet.data_length),1,
		infile->file) == 0)
	break;
      this_packet.total_length = 
	get_encrypted_length(infile->cryptalg,this_packet.data_length);
      /* data_length comes straight from the (untrusted) file: refuse any
	 length that would overrun this_packet.data before reading into it. */
      if ((this_packet.data_length < 0) || (this_packet.total_length < 0) ||
	  ((size_t)this_packet.total_length > sizeof(this_packet.data)))
	errx(1,"%s: bad packet length %d",infile->filename,
	     this_packet.data_length);
      if (fread(this_packet.data,1,this_packet.total_length,infile->file)
	  != (size_t)this_packet.total_length)
	errx(1,"%s is truncated",infile->filename);
    }
    this_packet.encrypted = 1;
    try_to_decrypt_packet(&this_packet,keys,infile->iv);
    handle_packet(infile,&this_packet,user);
  }
}

/*
 * Load the encrypted IP-address substitution table from `filename`.  The
 * table is decrypted with keys->tr_key and every address pair is inserted
 * into the substitution table (btree implementation).
 *
 * The decrypted buffer is a packed array of pairs of in_addr_t values.
 * Each pair is passed to insert_substitution(first, second) in file order.
 * A buffer whose length is negative or not a whole number of pairs
 * terminates the program.
 */
void init_subs_table(struct crypto_key_ring * keys,
		     char * filename) {
  const size_t pair_size = 2*sizeof(in_addr_t);
  char * buffer;
  int length;
  size_t off;

  init_substitution_table(1); /* just using btree for now. */
  load_cbc_encrypted_buffer(keys->tr_key,filename,&buffer,&length);
  /* NOTE(review): buffer is not released here; if load_cbc_encrypted_buffer
     heap-allocates it, it leaks -- confirm ownership before freeing. */

  if ((length < 0) || (((size_t)length % pair_size) != 0))
    errx(1,"Bad length value %d in substitution table", length);

  for (off = 0; off < (size_t)length; off += pair_size) {
    in_addr_t first, second;

    /* memcpy instead of casting buffer to in_addr_t *: avoids strict
       aliasing and alignment problems on the raw byte buffer. */
    memcpy(&first, buffer + off, sizeof(first));
    memcpy(&second, buffer + off + sizeof(in_addr_t), sizeof(second));
    insert_substitution(first, second);
  }
}

/*
 * Locate the capture timestamp inside packet->data.
 *
 * Format 0 keeps it in the bpf header.  Formats 1 and 2 store it after one
 * or three key-length fields respectively, so those packets must already
 * be decrypted.  An encrypted format 1/2 packet or an unknown format
 * terminates the program via errx().
 */
struct timeval * get_pkt_timestamp(struct crypto_packet * packet) {
  int klen;

  if (packet->encrypted && (packet->format == 1 || packet->format == 2))
    errx(1,"can't get timestamp from packet encrypted in format 1 or 2");

  klen = algorithms[packet->cryptalg].key_length;
  if (packet->format == 0)
    return &((struct bpf_hdr *)packet->data)->bh_tstamp;
  if (packet->format == 1)
    return (struct timeval *)(packet->data + klen);
  if (packet->format == 2)
    return (struct timeval *)(packet->data + 3*klen);

  errx(1,"format %d unsupported in get_pkt_timestamp",packet->format);
}

/*
 * Return the length of the captured packet itself, excluding this file
 * format's per-packet prefix.  For format 0 that prefix is the bpf header.
 * For formats 1 and 2 it is one or three key-length fields plus a struct
 * timeval.  An unknown format terminates the program via errx().
 *
 * The prefix sizes are converted to int so the subtraction stays signed.
 * Previously the sizeof term made it unsigned, so a too-short packet
 * wrapped around and relied on an implementation-defined conversion back
 * to int.
 */
int get_pkt_size(struct crypto_packet * packet) {
  int keylen = algorithms[packet->cryptalg].key_length;
  int tvlen = (int)sizeof(struct timeval);

  switch (packet->format) {
  case 0:
    return(packet->data_length - (int)(BPF_HDR_SIZE));
  case 1:
    return(packet->data_length -   keylen - tvlen);
  case 2:
    return(packet->data_length - 3*keylen - tvlen);
  default:
    errx(1,"format %d unsupported in get_pkt_size",
	 packet->format);
  }
}

/*
 * Return a pointer to the captured packet bytes inside packet->data,
 * skipping the format's per-packet prefix.  For format 0 that prefix is the
 * bpf header.  For formats 1 and 2 it is one or three key-length fields
 * plus a struct timeval.  The packet must already be decrypted.  An
 * encrypted packet or an unknown format terminates the program via errx().
 */
char * get_pkt_contents(struct crypto_packet * packet) {
  int klen = algorithms[packet->cryptalg].key_length;
  char * base = packet->data;

  if (packet->encrypted)
    errx(1,"can't get contents of encrypted packet");

  if (packet->format == 0)
    return base + BPF_HDR_SIZE;
  if (packet->format == 1)
    return base + klen + sizeof(struct timeval);
  if (packet->format == 2)
    return base + 3*klen + sizeof(struct timeval);

  errx(1,"format %d unsupported in get_pkt_contents",packet->format);
}

/*
 * Free a heap-allocated packet (NULL is accepted, as with free()).
 * NOTE(review): do not call this on the packet read_crypto_file passes to
 * its callback -- that packet is a local variable of read_crypto_file, not
 * a heap allocation.
 */
void destroy_packet(struct crypto_packet * packet) {
  free(packet);
}

/*
 * Release a handle from open_input_crypto_file().  This closes its stdio
 * stream (opened for reading, so fclose's result is not checked) and frees
 * the copied filename and the handle itself.
 */
void close_input_crypto_file(struct crypto_file * infile) {
  FILE * stream = infile->file;
  char * name = infile->filename;

  free(infile);
  fclose(stream);
  free(name);
}

/*
 * Create an empty key ring for cipher algorithm `cryptalg`.  All key
 * pointers start out NULL because the ring is zero-filled.  Allocation
 * failure or an out-of-range algorithm index terminates the program.
 */
struct crypto_key_ring * init_crypto_keys(int cryptalg) {
  struct crypto_key_ring * ring = calloc(1, sizeof(*ring));

  if (ring == NULL)
    errx(1,"unable to allocate memory for key ring");
  if (cryptalg < 0 || cryptalg >= NUM_ALGORITHMS)
    errx(1,"bad cryptalg %d in init_crypto_keys",cryptalg);

  ring->cryptalg = cryptalg;
  return ring;
}

/*
 * Free a key ring from init_crypto_keys() together with every key session
 * attached to it by the get_*_key() functions.  NULL is accepted.  Unset
 * keys are NULL because the ring is calloc'ed, and free(NULL) is a no-op.
 */
void destroy_crypto_keys(struct crypto_key_ring * keys) {

  if (keys == NULL)
    return;
  free(keys->vol_key);
  free(keys->tr_key);
  free(keys->conv_key);
  free(keys->src_key);
  free(keys->dst_key);
  /* segment key (used for format 3 files); was previously leaked */
  free(keys->seg_key);
  free(keys);
}

/*
 * Load the volume key from `filename` into a new CBC session that is set
 * up in ENCRYPT mode.  The session is stored in keys->vol_key.  Allocation
 * failure terminates the program.
 */
void get_vol_key(struct crypto_key_ring * keys, char * filename) {
  char raw[MAX_KEY_LENGTH];
  struct crypto_session * session;

  read_key_file(filename,raw,keys->cryptalg);
  session = malloc(sizeof(*session));
  keys->vol_key = session;
  if (session == NULL)
    errx(1,"unable to allocate memory for volume key");
  get_crypto_session(session,keys->cryptalg,CBC_MODE,ENCRYPT,raw,0);
}

/*
 * Load the translation-table key from `filename` into a new CBC session
 * set up in DECRYPT mode.  The session is stored in keys->tr_key.
 * Allocation failure terminates the program.
 */
void get_tr_key(struct crypto_key_ring * keys, char * filename) {
  char raw[MAX_KEY_LENGTH];
  struct crypto_session * session;

  read_key_file(filename,raw,keys->cryptalg);
  session = malloc(sizeof(*session));
  keys->tr_key = session;
  if (session == NULL)
    errx(1,"unable to allocate memory for translation table key");
  get_crypto_session(session,keys->cryptalg,CBC_MODE,DECRYPT,raw,0);
}

/*
 * Load a conversation key from `filename` into keys->conv_key: a DECRYPT
 * CBC session plus the real source/destination addresses (`src`, `dst`) it
 * belongs to.  The fake (substituted) addresses start out zeroed; set them
 * with get_conv_key_open_headers().  Allocation failure terminates the
 * program.
 */
void get_conv_key(struct crypto_key_ring * keys, char * filename,
		  in_addr_t src, in_addr_t dst) {
  char key[MAX_KEY_LENGTH];

  read_key_file(filename,key,keys->cryptalg);
  /* calloc rather than malloc: try_to_decrypt_packet compares fake_src and
     fake_dst for format 0 packets even when they were never set, which
     previously read uninitialized memory. */
  if ((keys->conv_key = calloc(1,sizeof(*keys->conv_key))) == NULL)
    errx(1,"unable to allocate memory for conversation key");
  get_crypto_session(&keys->conv_key->session,keys->cryptalg,CBC_MODE,DECRYPT,
		     key,0);
  keys->conv_key->real_src = src;
  keys->conv_key->real_dst = dst;
}

/*
 * Like get_conv_key(), but also record the substituted ("fake") addresses
 * that appear in the packets' cleartext IP headers.  For format 0 packets
 * they are matched against those headers to select this key.
 */
void get_conv_key_open_headers(struct crypto_key_ring * keys, char * filename,
			       in_addr_t real_src, in_addr_t real_dst,
			       in_addr_t fake_src, in_addr_t fake_dst) {
  get_conv_key(keys, filename, real_src, real_dst);

  /* addresses as they appear on the wire */
  keys->conv_key->fake_dst = fake_dst;
  keys->conv_key->fake_src = fake_src;
}

/*
 * Load a source-endpoint key from `filename` into keys->src_key: a DECRYPT
 * CBC session and the endpoint address `addr`.  That address is checked
 * against the source address of decrypted packets.  Allocation failure
 * terminates the program.
 */
void get_src_key(struct crypto_key_ring * keys, char * filename,
		 in_addr_t addr) {
  char raw[MAX_KEY_LENGTH];

  read_key_file(filename,raw,keys->cryptalg);
  keys->src_key = malloc(sizeof(*keys->src_key));
  if (keys->src_key == NULL)
    errx(1,"unable to allocate memory for source endpoint key");
  keys->src_key->addr = addr;
  get_crypto_session(&keys->src_key->session,keys->cryptalg,CBC_MODE,
		     DECRYPT,raw,0);
}

/*
 * Load a destination-endpoint key from `filename` into keys->dst_key: a
 * DECRYPT CBC session and the endpoint address `addr`.  That address is
 * checked against the destination address of decrypted packets.
 * Allocation failure terminates the program.
 */
void get_dst_key(struct crypto_key_ring * keys, char * filename,
		  in_addr_t addr) {
  char raw[MAX_KEY_LENGTH];

  read_key_file(filename,raw,keys->cryptalg);
  keys->dst_key = malloc(sizeof(*keys->dst_key));
  if (keys->dst_key == NULL)
    errx(1,"unable to allocate memory for destination endpoint key");
  keys->dst_key->addr = addr;
  get_crypto_session(&keys->dst_key->session,keys->cryptalg,CBC_MODE,
		     DECRYPT,raw,0);
}

/*
 * Load the segment key (used for format 3 files) from `filename` into a
 * new DECRYPT CBC session stored in keys->seg_key.  Allocation failure
 * terminates the program.
 *
 * Fixes two bugs in the earlier version:
 * - the allocation was stored in keys->dst_key (clobbering/leaking it);
 * - &keys->seg_key (a pointer to the pointer field) was handed to
 *   get_crypto_session, which then wrote a session over the field.
 */
void get_seg_key(struct crypto_key_ring * keys, char * filename) {
  char key[MAX_KEY_LENGTH];

  read_key_file(filename,key,keys->cryptalg);
  if ((keys->seg_key = malloc(sizeof(*keys->seg_key))) == NULL)
    errx(1,"unable to allocate memory for segment key");
  get_crypto_session(keys->seg_key,keys->cryptalg,CBC_MODE,DECRYPT,key,0);
}
