/*
 * COPYRIGHT    2001
 * THE REGENTS OF THE UNIVERSITY OF MICHIGAN
 * ALL RIGHTS RESERVED
 *
 * Permission is granted to use, copy, create derivative works
 * and redistribute this software and such derivative works
 * for any purpose, so long as the name of The University of
 * Michigan is not used in any advertising or publicity
 * pertaining to the use of distribution of this software
 * without specific, written prior authorization.  If the
 * above copyright notice or any other identification of the
 * University of Michigan is included in any copy of any
 * portion of this software, then the disclaimer below must
 * also be included.
 *
 * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION
 * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY
 * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY O
 * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
 * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE
 * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE
 * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR
 * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING
 * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN
 * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGES.
 */

#include <sys/types.h>
#include <netinet/in.h>
#include <netdb.h>
#include <stdio.h>
#include <sys/stat.h>
#include <unistd.h>
#include <sys/time.h>
/* So we can build this on a system w/o kernel changes */
#include "../sys/net/bpf.h"	/* #include <net/bpf.h> */
#include <pcap.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <net/if.h>
#include <netinet/if_ether.h>

#include <assert.h>
#include <string.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>

#include "../bpf/parse_bpf.h"
#include "../crypto/crypto.h"
#include "../dump/btree.h"
#include "../dump/list.h"

/*
 * Size in bytes of the fixed-length part of the IP header (no options).
 * The fixed-length part of the IP header is all we care about.
 */
#define IP_HDR_SIZE 20

/* Return codes of translate() and translate_init(). */
#define	SUCCESS	0
#define	FAILURE	(-1)

/* globals */
/* Substitution table: a btree (ipDB) when -b sets bFlag, else a list (ipList). */
struct list * ipList;
DB * ipDB;
int bFlag = 0;
/* Not referenced in this file. */
int dFlag = 0;
/* Not referenced in this file; main() loads the volume key into a
   buffer of its own. */
char key_volume_key[MAX_KEY_LENGTH];
/* Not referenced in this file. */
int totLen = 0;
int totPackets = 0;
/* Crypto algorithm index; main() overwrites it with the value stored
   in the input file header. */
int cryptalg = DEFAULT_ALGORITHM;
/* Old-style (unprototyped) declarations; the definitions are at the
   end of this file. */
void print_hex();
char *base();
/* Summary counters printed by main(): short records, keys (never
   incremented in this file), bytes read, bytes written, packets
   processed, and packets with an address missing from the table. */
int shorts, ckeys, rbytes, wbytes, pkts, untrans;
/* Statistics output opened by -d; stays NULL (no statistics) otherwise. */
FILE *stats;
/* Scratch buffer for stat(). */
struct stat sb;
/* Not referenced in this file. */
int reals, realt;

/* Not referenced in this file. */
unsigned char  key_trans_key[MAX_KEY_LENGTH];

/* num_elements: number of address pairs loaded by translate_init();
   num_read: not referenced in this file. */
int num_elements, num_read;

/*
 * translate: look the substituted address *addr up in the substitution
 * table chosen at startup (btree ipDB when bFlag is set, list ipList
 * otherwise).  The lookup helpers are defined elsewhere and presumably
 * rewrite *addr in place with the real address -- confirm there.
 *
 * Returns SUCCESS when the address was found.  Otherwise it prints a
 * warning (with *addr shown in host byte order) and returns FAILURE.
 */
int translate(in_addr_t * addr) {
  int found;

  if (bFlag)
    found = (findDB_real_ip(ipDB, addr) != 0);
  else
    found = (find_real_ip(ipList, addr) != 0);

  if (found)
    return SUCCESS;

  printf("WARNING, can't find substituted ip 0x%08x in table\n",
	 ntohl(*addr));
  return FAILURE;
}

/*
 * translate_init: load the encrypted substitution table stored in
 * st_name, decrypt it with tr_key_session, and insert every entry into
 * ipDB (when bFlag is set) or ipList.
 *
 * The decrypted table is a packed array of address pairs.  For each
 * pair, the first address goes to the insert routine as its first
 * address argument and the second address as its second; their meaning
 * is defined by insert_substitution()/insertDB_substitution().
 *
 * Sets num_elements to the number of pairs and returns SUCCESS.  A
 * negative length, or one that is not a whole number of pairs, is
 * fatal: a message is printed and the program exits with status 1.
 */
int translate_init(char * st_name, struct crypto_session * tr_key_session)
{
  const int pair_size = (int)(2 * sizeof(struct in_addr));
  int length, i;
  char * buffer;

  /* buffer is not freed: in list mode insert_substitution() receives
     pointers into it, so it presumably has to outlive the table --
     confirm before adding a free(). */
  read_cbc_encrypted_buffer(tr_key_session,st_name,&buffer,&length);
  /* Reject negative lengths explicitly; otherwise they would be
     treated as a (huge) valid-looking size. */
  if (length < 0 || length % pair_size != 0) {
    printf("Bad length value %d in substitution table\n", length);
    exit(1);
  }
  num_elements = length / pair_size;
  for (i = 0; i < length; i += pair_size) {
    if (bFlag)
      insertDB_substitution(ipDB,*(in_addr_t *)(buffer + i),
			    *(in_addr_t *)(buffer + i
					   + sizeof(in_addr_t)));
    else
      insert_substitution(ipList,(struct in_addr *)(buffer + i),
			  (struct in_addr *)(buffer + i
					     + sizeof(struct in_addr)));
  }
  printf("%d elements in translation table\n", num_elements);
  return SUCCESS;
}

/*
 * write_error: report that writing the output file `filename` failed
 * and terminate the program with exit status 1.  Never returns.
 */
void
write_error(char *filename)
{
	const int status = 1;

	printf("Error writing file %s\n", filename);
	exit(status);
}

/*
 * usage: print the command-line synopsis for prog and exit with
 * status 1.  Lists every option in main()'s getopt() string,
 * including -b, which the original help text omitted.
 */
void usage(char * prog) {
	printf("usage: %s\n",prog);
	printf("-o file: file to write cleartext packets to, in pcap\n");
	printf("         format (packets will be appended if file exists)\n");
	printf("-b:      use a btree (btree.s) for the substitution table\n");
	printf("         instead of a list (old -z0 file format only)\n");
	printf("-c file: encrypted file to read\n");
	printf("-d file: write statistics to file, don't write decrypted data\n");
	printf("         (not supported in conversation or endpoint modes)\n");
	printf("-s file: file with encrypted substitution table\n");
	printf("-t file: file with encrypted list of conversations\n");
	printf("-V file: file containing volume key\n");
	printf("-T file: file containing translation table key\n");
	exit(1);
}

/*
 * decrypt_z_file: decrypt a file in format 1 or 2 from infile and write
 * the recovered packets to outfile as pcap record header + frame pairs.
 * Both files are closed before returning.
 *
 * Input layout, as read below: one cipher block of IV, then a sequence
 * of records.  Each record is a native-endian short holding the unpadded
 * record length, followed by the padded encrypted record.  A record
 * holds:
 *   - the conversation key, decrypted with tr_key_session;
 *   - then a struct timeval and the Ethernet frame, decrypted with a
 *     session built from that conversation key.
 * Format 2 records carry two extra encrypted keys in front, which are
 * skipped.
 *
 * A truncated input, a malformed record length, or a write error is
 * fatal: a message is printed and the program exits with status 1.
 */
void decrypt_z_file(FILE * infile, FILE * outfile,
		    int cryptalg, int format,
		    struct crypto_session * vol_key_session,
		    struct crypto_session * tr_key_session) {
  short int packet_length;	/* on-disk length field; its size is fixed by the file */
  int padded_packet_length;
  int blocklen;
  int padded_keylen;
  size_t key_offset, clear_offset, frame_len;
  char iv[MAX_BLOCK_LENGTH];
  char packet[MAX_PACKET_SIZE];
  char * conv_key_ptr, * timeval_ptr, * ether_start_ptr, * pkt_end_ptr;
  struct crypto_session conv_key_session;
  struct pcap_pkthdr clear_pkt_hdr;

  blocklen = algorithms[cryptalg].block_length;
  if (fread(iv,1,blocklen,infile) != (size_t)blocklen) {
    printf("infile truncated\n");
    exit(1);
  }
  /* The padded key length depends only on the algorithm, so compute it once. */
  padded_keylen =
    get_encrypted_length(tr_key_session,algorithms[cryptalg].key_length);
  /* Skip over the first two encrypted keys of a format 2 record; we
     wouldn't use them unless we were trying to recover a few packets
     with an endpoint key. */
  key_offset = (format == 2) ? 2 * (size_t)padded_keylen : 0;
  /* Bytes in front of the Ethernet frame: skipped keys, conversation
     key, timestamp. */
  clear_offset = key_offset + (size_t)padded_keylen + sizeof(struct timeval);

  while (fread(&packet_length,sizeof(packet_length),1,infile) == 1) {
    padded_packet_length =
      get_encrypted_length(vol_key_session,packet_length);
    /* The length comes from the file: refuse anything that would
       overflow packet[] or leave no room for the key and timestamp. */
    if (packet_length < 0
	|| (size_t)packet_length < clear_offset
	|| (size_t)packet_length > sizeof packet
	|| padded_packet_length < 0
	|| (size_t)padded_packet_length > sizeof packet) {
      printf("bad packet length %d in infile\n", packet_length);
      exit(1);
    }
    if (fread(packet,1,padded_packet_length,infile)
	!= (size_t)padded_packet_length) {
      printf("truncated infile\n");
      exit(1);
    }
    pkt_end_ptr = packet + packet_length;
    conv_key_ptr = packet + key_offset;
    if (format == 2)
      /* Continue the CBC chain: the IV for the conversation key is the
	 last block of the (skipped) encrypted key in front of it. */
      memcpy(iv,conv_key_ptr - blocklen,blocklen);
    crypt_buffer(tr_key_session,conv_key_ptr,padded_keylen,iv);
    get_crypto_session(&conv_key_session,cryptalg,CBC_MODE,DECRYPT,
		       conv_key_ptr,0);
    timeval_ptr = conv_key_ptr+padded_keylen;
    crypt_buffer(&conv_key_session,timeval_ptr,pkt_end_ptr-timeval_ptr,iv);
    ether_start_ptr = timeval_ptr + sizeof(struct timeval);
    frame_len = (size_t)(pkt_end_ptr - ether_start_ptr);

    /* memcpy rather than a pointer cast: timeval_ptr need not be
       suitably aligned for struct timeval. */
    memcpy(&clear_pkt_hdr.ts, timeval_ptr, sizeof clear_pkt_hdr.ts);
    clear_pkt_hdr.len = frame_len;
    clear_pkt_hdr.caplen = frame_len;

    if (fwrite(&clear_pkt_hdr,sizeof(clear_pkt_hdr),1,outfile) != 1) {
      printf("Error writing packet header to clear file\n");
      exit(1);
    }
    /* A partially written record corrupts the output, so treat it as
       fatal, the same as a failed header write. */
    if (fwrite(ether_start_ptr,1,frame_len,outfile) != frame_len) {
      printf("error writing packet contents to clear file\n");
      exit(1);
    }
  }
  /* fclose flushes buffered output, so its failure is a write error too. */
  if (fclose(outfile) != 0) {
    printf("error closing clear file\n");
    exit(1);
  }
  fclose(infile);
}

int main(int argc, char * argv[])
{
  FILE * source, * dest;
  char * volume_key_filename, * tr_key_filename;
  unsigned char key_volume_key[MAX_KEY_LENGTH];
  unsigned char key_tr_key[MAX_KEY_LENGTH];
  struct bpf_hdr bpf;
  char bpf_content[MAX_PACKET_SIZE];
  struct ip * ip_hdr;
  char iv[MAX_BLOCK_LENGTH];
  int need_header=1;
  struct pcap_pkthdr pkt_hdr;
  int optch;
  char * sfilename, * tfilename, * cfilename, * ofilename;
  extern char * optarg;
  extern int optind;
  char * conv_key;
  struct crypto_session conv_key_session, vol_key_session,
    tr_key_session;
  int blocklen;
  int bytes;
  int format = 0;

  umask(077);	/* Make files only visible to the owner */

  while ((optch = getopt(argc,argv,"o:bc:d:s:t:V:T:")) != -1) {
    switch(optch) {
    case 'o':
      ofilename = optarg;
      /* If the output file doesn't already exist, then we'll need to
         add a header: */
      need_header=access(ofilename,F_OK);
      if ((dest=fopen(ofilename,"a")) == NULL) {
	printf("Error opening output file %s\n",ofilename);
	exit(1);
      }
      break;
    case 'b':
      bFlag++;
      break;
    case 'c':
      cfilename = optarg;
      if ((source=fopen(cfilename,"r")) == NULL) {
	printf("Error opening input file %s\n",cfilename);
	exit(1);
      }
      break;
    case 'd':
      if ((stats=fopen(optarg,"a")) == NULL) {
	printf("Error opening stats file %s\n",optarg);
	exit(1);
      }
      break;
    case 's':
      sfilename=optarg;
      break;
    case 't':
      tfilename = optarg; /* actually ignored.... */
      break;
    case 'V':
      volume_key_filename = optarg;
      break;
    case 'T':
      tr_key_filename = optarg;
      break;
    case '?':
    default:
      usage(argv[0]);
    }
  }
  if (fread(&format,sizeof(format),1,source) != 1) {
    perror("Error reading file format from input file");
    exit(1);
  }
  if ((format != 0) && (format != 1) && (format != 2)) {
    printf("File used unknown file format %d\n",format);
    exit(1);
  }
  if (fread(&cryptalg,sizeof(cryptalg),1,source) != 1) {
    perror("Error reading cryptalg from input file");
    exit(1);
  }
  if ((cryptalg<0) || (cryptalg >= NUM_ALGORITHMS)) {
    printf("File used unknown crypto algorithm %d.\n",cryptalg);
    exit(1);
  }
  read_key_file(volume_key_filename,key_volume_key,cryptalg);
  get_crypto_session(&vol_key_session,cryptalg,CBC_MODE,ENCRYPT,
		     key_volume_key,0);
  read_key_file(tr_key_filename,key_tr_key,cryptalg);
  get_crypto_session(&tr_key_session,cryptalg,CBC_MODE,DECRYPT,
		     key_tr_key,0);

  if (!(stats && format == 0) && need_header) {
    bytes = (write_pcap_header(dest));
    if (!bytes)
      write_error(ofilename);
    wbytes += bytes;
  }

  if (format != 0) {
    decrypt_z_file(source,dest,cryptalg,format,
		   &vol_key_session,&tr_key_session);
    exit(0);
  }
  /* From now on we're just dealing with the old (-z0) file format. */
  if (bFlag) {
    ipDB = newDB("btree.s");
  } else {
    ipList = newList(0);
  }

  if (translate_init(sfilename,&tr_key_session) == FAILURE) {
    printf("%s doesn't contain a valid translation table\n", sfilename);
    exit(1);
  }
  if (stats) {
	if ((stat(cfilename, &sb)) < 0) {
		perror(cfilename);
		exit (1);
	}
	fprintf(stats, "f %s %d.0 %d\n",
		cfilename, sb.st_mtime, (int)sb.st_size);
	if ((stat(sfilename, &sb)) < 0) {
		perror(sfilename);
		exit (1);
	}
	fprintf(stats, "s %s %d.0 %d\n",
		sfilename, sb.st_mtime, (int)sb.st_size);
	if ((stat(tfilename, &sb)) < 0) {
		perror(tfilename);
		exit (1);
	}
	fprintf(stats, "t %s %d.0 %d\n",
		tfilename, sb.st_mtime, (int)sb.st_size);
  }
  blocklen = algorithms[cryptalg].block_length;
  if (fread(&iv,1,blocklen,source)!=blocklen) {
    perror("Error reading iv");
    exit(1);
    }
  
  while(1)
  {
    if (read_bpf(source, &bpf, bpf_content, 0) != 0)
	break;
    rbytes += BPF_HDR_SIZE+bpf.bh_caplen;
    ip_hdr = (struct ip *)(bpf_content + ETHER_HDR_LEN);
    if (ntohs(ip_hdr->ip_len) > (bpf.bh_caplen - ETHER_HDR_LEN)) {
	printf("WARNING, ip_len = %d (0x%x), ",
	       ntohs(ip_hdr->ip_len), ntohs(ip_hdr->ip_len));
	printf("bh_caplen-ETHER_HDR_LEN=%d (0x%x)\n",
	       bpf.bh_caplen-ETHER_HDR_LEN, bpf.bh_caplen-ETHER_HDR_LEN);
	++shorts;
	continue;
    }

    /*
     * set up the pcap header for this packet.  pkt_dump has
     * to extend a packet's length to the next higher multiple
     * of the crypto block size, and adjusts bh_caplen accordingly.
     * bh_datalen remains set to the actual packet length, which
     * we use to set pkt_hdr.caplen below.  note the pkt_dump
     * output files thus have bh_caplen >= bh_datalen.
     */
    pkt_hdr.ts.tv_sec=bpf.bh_tstamp.tv_sec;
    pkt_hdr.ts.tv_usec=bpf.bh_tstamp.tv_usec;
    pkt_hdr.len=bpf.bh_datalen;
    pkt_hdr.caplen=bpf.bh_datalen;

    if (!stats) {
      if (fwrite(&pkt_hdr, sizeof(pkt_hdr), 1, dest) == 1)
	wbytes += sizeof pkt_hdr;
      else
	write_error(ofilename);
    }

    if (translate(&ip_hdr->ip_src.s_addr) != SUCCESS ||
        translate(&ip_hdr->ip_dst.s_addr) != SUCCESS) {
	++untrans;
    }
    conv_key = make_conv_key(&vol_key_session,ip_hdr,cryptalg);
    get_crypto_session(&conv_key_session,
		       cryptalg,CBC_MODE,DECRYPT,conv_key,0);
    /* Decrypt ethernet header: */
    crypt_buffer(&conv_key_session,bpf_content+0,16,iv);
    /* Decrypt IP payload: */
    crypt_buffer(&conv_key_session,bpf_content+34,bpf.bh_caplen-34,iv);
    if (!stats) {
      if (fwrite(bpf_content, bpf.bh_datalen, 1, dest) == 1)
	  wbytes += bpf.bh_datalen;
      else
	write_error(ofilename);
    }
    if (stats)
      fprintf(stats, "p %s %ld.%ld %d %d\n", cfilename,
	      bpf.bh_tstamp.tv_sec, bpf.bh_tstamp.tv_usec,
	      BPF_HDR_SIZE+bpf.bh_caplen, bpf.bh_caplen);
    ++pkts;
  }
  printf("+ f %s s %s t %s:  p %d r %d w %d k %d", base(cfilename),
	 base(sfilename), base(tfilename), pkts, rbytes, wbytes, ckeys);
  if (untrans)
	printf (", %d UNTRANSLATED", untrans);
  if (shorts)
	printf (", %d SHORT", shorts);
  putchar('\n');
  exit(untrans + shorts);
}

/*
 * print_hex: print the label s followed by the n bytes at p in
 * upper-case hex.  A space goes before every group of 8 bytes, and
 * the line ends with the byte count in parentheses.
 *
 * Written as a prototype definition: the original K&R form relied on
 * implicit int for n, which C99 removed.  It remains compatible with
 * the old-style `void print_hex();` declaration.
 */
void
print_hex(const char *s, const char *p, int n)
{
	int i;

	printf("%s", s);
	for (i = 0; i < n; ++i) {
		if ((i % 8) == 0)
			putchar(' ');
		/* unsigned char avoids sign extension of bytes >= 0x80 */
		printf("%02X", (unsigned char)p[i]);
	}
	printf(" (%d)\n", n);
}

/*
 * base: return the final path component of s, i.e. the text after the
 * last '/', or s itself when it contains no '/'.  The result points
 * into s; nothing is allocated.
 */
char *
base(char *s)
{
	/* strrchr is the standard <string.h> equivalent of BSD rindex(),
	   which needs <strings.h> (not included here) to be declared. */
	char *p = strrchr(s, '/');

	return p ? (p+1) : s;
}




