/*
 * COPYRIGHT    2001
 * THE REGENTS OF THE UNIVERSITY OF MICHIGAN
 * ALL RIGHTS RESERVED
 *
 * Permission is granted to use, copy, create derivative works
 * and redistribute this software and such derivative works
 * for any purpose, so long as the name of The University of
 * Michigan is not used in any advertising or publicity
 * pertaining to the use of distribution of this software
 * without specific, written prior authorization.  If the
 * above copyright notice or any other identification of the
 * University of Michigan is included in any copy of any
 * portion of this software, then the disclaimer below must
 * also be included.
 *
 * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION
 * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY
 * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY O
 * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
 * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE
 * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE
 * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR
 * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING
 * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN
 * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGES.
 */

#include <sys/types.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdio.h>
#include <fcntl.h>
#include <limits.h>
#include <netinet/in_systm.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <err.h>
#include <string.h>
#include "crypto.h"

/*
 * Parameters of every supported cipher, indexed by the algorithm ids
 * (NONE, DESX, AES, GAES) used throughout this file, e.g.
 * algorithms[session->algorithm].block_length.
 * NONE keeps a non-zero block_length so padded lengths can still be
 * computed for it by get_encrypted_length.
 */
struct algorithm_struct algorithms[NUM_ALGORITHMS] = {
  /* name, block_length, key_length, key_schedule_length */
  {"NONE", 16,            0,            0},
  {"DESX", 8,            24,          144}, /* 144 = sizeof(long)*2*16 + 16 */
  /* NOTE(review): the formula above equals 144 only when sizeof(long) == 4;
     with an 8-byte long it gives 272 -- confirm against the real size of
     the DESX key schedule before relying on this field. */
  {"AES" , 16,           16, 4*(AES_ROUNDS+1)*4},
  {"GAES", 16,           16, sizeof(aes_ctx)}
};

/*
 * init_desx_ks: build a DESX key schedule from a 24-byte key laid out as
 * [0..7] whitening key A, [8..15] DES key, [16..23] whitening key B.
 * For ENCRYPT, A is applied before DES and B after; for DECRYPT the two
 * whitening keys swap roles and the DES schedule is built for decryption.
 * Any other operation terminates the process with status 1.
 */
void init_desx_ks(desx_ks ks, char * key, int operation) {
  char *before_des;
  char *after_des;

  if (operation == ENCRYPT) {
    before_des = key + 0;
    after_des = key + 16;
  } else if (operation == DECRYPT) {
    before_des = key + 16;
    after_des = key + 0;
  } else {
    exit(1);
  }

  memcpy(ks->prewhitening, before_des, 8);
  deskey(ks->des, key + 8, operation);
  memcpy(ks->postwhitening, after_des, 8);
}

/* Longest conversation-data input handed to make_derived_key:
   make_src_key and make_dst_key pass 9 bytes (a 4-byte address plus a
   5-byte tag); make_conv_key passes 8 (two addresses). */
#define MAX_CONV_DATA_LENGTH 9

void get_crypto_session(struct crypto_session * session,
			int algorithm, int chaining, 
			int operation, char * key, int save_key) {

  session->algorithm = algorithm;
  if (algorithm == NONE) {
#ifndef INSECURE
    printf("You are not allowed to choose to use no encryption\n");
    exit(1);
#endif
    return;
  }
  session->chaining = chaining;
  session->operation = operation;
  if (save_key) {
    memcpy(session->key,key,algorithms[algorithm].key_length);
  }
  switch (algorithm) {
  case DESX:
    init_desx_ks(&session->key_schedule.desx,key,operation);
    break;
  case AES:
    switch (session->operation) {
    case ENCRYPT:
      rijndaelKeySetupEnc(session->key_schedule.aes,key,
			  algorithms[AES].key_length*8);
      break;
    case DECRYPT:
      rijndaelKeySetupDec(session->key_schedule.aes,key,
			  algorithms[AES].key_length*8);
      break;
    default:
      printf("unknown operation in get_crypto_session\n");
      exit(1);
    }
    break;
  case GAES:
    if (aes_set_key(key,16,(session->operation == ENCRYPT ? aes_enc : aes_dec),
		    &session->key_schedule.gaes) == aes_bad)
      errx(1,"error setting key schedule");
    break;
  default: 
    printf("unknown crypto algorithm\n");
    exit(1);
  }
}

/*
 * xor_bytes: xored[k] ^= xorer[k] for every k in [0, length).
 *
 * Works byte by byte, so any length is handled exactly (nothing past
 * `length` is read or written).  No type-punned, possibly misaligned
 * (long long *) accesses are made on the char buffers.
 */
void xor_bytes(char * xored, char * xorer, int length) {
  int i;

  for (i = 0; i < length; i++) {
    xored[i] ^= xorer[i];
  }
}

/*
 * desx: apply one DESX block operation in place to the 8-byte `buffer`.
 * The buffer is XORed with the pre-whitening key, run through DES with the
 * stored schedule, then XORed with the post-whitening key.  Whether this
 * encrypts or decrypts depends on how init_desx_ks built the schedule.
 */
void desx(desx_ks ks, char * buffer)
{
  const int whiten_len = 8;

  xor_bytes(buffer, ks->prewhitening, whiten_len);
  des(ks->des, buffer);
  xor_bytes(buffer, ks->postwhitening, whiten_len);
}

/*
 * decrypt_block: decrypt one block_length-sized block of `block` in place
 * with the session's key schedule.  The schedule should have been built
 * with operation DECRYPT (get_crypto_session builds the AES and GAES
 * decryption schedules only in that case).  An unknown algorithm exits
 * the process.
 */
void decrypt_block(struct crypto_session * session, char * block) {
  switch (session->algorithm) {
  case DESX:
    /* The DESX direction is fixed by the schedule from init_desx_ks. */
    desx(&session->key_schedule.desx, block);
    break;
  case AES:
    rijndaelDecrypt(session->key_schedule.aes, AES_ROUNDS,
		    (u8 *)block, (u8 *)block);
    break;
  case GAES:
    aes_decrypt(block, block, &session->key_schedule.gaes);
    break;
  default:
    /* Name this function: crypt_block no longer exists. */
    printf("unknown algorithm in decrypt_block\n");
    exit(1);
  }
}

/*
 * encrypt_block: encrypt one block_length-sized block of `block` in place
 * with the session's key schedule.  The schedule should have been built
 * with operation ENCRYPT (get_crypto_session builds the AES and GAES
 * encryption schedules only in that case).  An unknown algorithm exits
 * the process.
 */
void encrypt_block(struct crypto_session * session, char * block) {
  switch (session->algorithm) {
  case DESX:
    /* The DESX direction is fixed by the schedule from init_desx_ks. */
    desx(&session->key_schedule.desx, block);
    break;
  case AES:
    rijndaelEncrypt(session->key_schedule.aes, AES_ROUNDS,
		    (u8 *)block, (u8 *)block);
    break;
  case GAES:
    aes_encrypt(block, block, &session->key_schedule.gaes);
    break;
  default:
    /* Name this function: crypt_block no longer exists. */
    printf("unknown algorithm in encrypt_block\n");
    exit(1);
  }
}

/*
void crypt_block(struct crypto_session * session, char * block) {

  switch (session->algorithm) {
  case DESX:
    desx(&session->key_schedule.desx,block);
    break;
  case AES:
    switch (session->operation) {
    case ENCRYPT:
      rijndaelEncrypt(session->key_schedule.aes,AES_ROUNDS,
		      (u8 *)(block),(u8 *)block);
      break;
    case DECRYPT:
      rijndaelDecrypt(session->key_schedule.aes,AES_ROUNDS,
		      (u8 *)(block),(u8 *)block);
      break;
    default:
      printf("unknown operation in crypt_block");
      exit(1);
      break;
    }
    break;
  case GAES:
    switch (session->operation) {
    case ENCRYPT:
      aes_encrypt(block,block,&session->key_schedule.gaes);
      break;
    case DECRYPT:
      aes_decrypt(block,block,&session->key_schedule.gaes);
      break;
    default:
      errx(1,"unknown operation in crypt_block");
    }
    break;
  default:
    printf("unknown algorithm in crypt_block\n");
    exit(1);
  }
}
*/

/* shortcut assuming we always cbc encrypt, cryptalg = gaes, length>0:
int crypt_buffer(struct crypto_session * session, char * buffer, int length,
		 char * iv) {
    int i;
    
    *(long long int *)(buffer) ^=
	*(long long int *)iv;
    *(long long int *)(buffer + 8) ^=
	*(long long int *)(iv + 8);
    aes_encrypt(buffer,buffer,&session->key_schedule.gaes);
    for (i=16; i<length; i+=16) {
	*(long long int *)(buffer + i) ^= 
	    *(long long int *)(buffer+i-16);
	*(long long int *)(buffer + i + 8 ) ^= 
	    *(long long int *)((buffer+i-16) + 8);
	aes_encrypt(buffer+i,buffer+i,&session->key_schedule.gaes);
    }
    memcpy(iv,buffer+i-16,16);
    return(i);
}
*/

/*
 * crypt_buffer: transform `buffer` in place according to the session's
 * algorithm, chaining mode and operation.
 *
 *   session - supplies the algorithm, the chaining mode (only CBC_MODE is
 *             handled; any other mode exits) and the operation (ENCRYPT or
 *             DECRYPT; any other operation exits).
 *   buffer  - data to encrypt/decrypt in place.  Whole blocks are processed
 *             up to `length` rounded up to the block size, so the buffer
 *             must have room for that padded length.
 *   length  - number of meaningful bytes in `buffer`.
 *   iv      - on entry, the CBC initialization vector (block_length bytes);
 *             on return, the last ciphertext block, so that a following
 *             call continues the same CBC chain.
 *
 * Returns the padded length processed (0 if length <= 0 in CBC mode).  For
 * algorithm NONE neither buffer nor iv is touched and only the padded
 * length is returned.
 */
int crypt_buffer(struct crypto_session * session, char * buffer, int length,
		 char * iv) {
  int i, padded_length;
  char saved_iv[MAX_BLOCK_LENGTH];
  int blocklen;

  if (session->algorithm == NONE) {
    return(get_encrypted_length(NONE,length));
  }
  blocklen = algorithms[session->algorithm].block_length;
  switch (session->chaining) {
      /*
       * Disabled ECB implementation, kept for reference.
       * NOTE(review): it compares session->algorithm with ENCRYPT;
       * presumably session->operation was meant -- fix before re-enabling.
  case ECB_MODE:
      for (i=0; i<length; i+=blocklen) {
	  if (session->algorithm == ENCRYPT)
	      encrypt_block(session,buffer+i);
	  else
	      decrypt_block(session,buffer+i);
      }
      return(i);
      break;
      */
  case CBC_MODE:
      if (length <= 0)
	return(0);
      switch (session->operation) {
      case ENCRYPT:
	/* C[0] = E(P[0] ^ IV);  C[k] = E(P[k] ^ C[k-1]). */
	xor_bytes(buffer,iv,blocklen);
	encrypt_block(session,buffer);
	for (i=blocklen; i<length; i+=blocklen) {
	  xor_bytes(buffer+i,(buffer+i-blocklen),blocklen);
	  encrypt_block(session,buffer+i); 
	}
	/* i is now the padded length; hand back the last ciphertext block
	   as the IV for a following call. */
	memcpy(iv,buffer+i-blocklen,blocklen);
	return(i);
	break;
      case DECRYPT:
	i = padded_length = get_encrypted_length(session->algorithm,length);
	/* Keep the caller's IV for block 0, and export the last ciphertext
	   block as the next IV before decryption overwrites it. */
	memcpy(saved_iv,iv,blocklen);
	memcpy(iv,buffer+i-blocklen,blocklen);
	/* Walk from the last block to the first so that when block k is
	   decrypted, block k-1 is still ciphertext:
	   P[k] = D(C[k]) ^ C[k-1],  P[0] = D(C[0]) ^ IV. */
	for (i = i-blocklen; i>=0; i-=blocklen) {
	  decrypt_block(session,buffer+i);
	  xor_bytes(buffer+i,(i ? buffer+i-blocklen : saved_iv),blocklen);
	}
	return(padded_length);
	break;
      default:
	printf("unknown crypto operation in crypt_buffer\n");
	exit(1);
	break;
      }
      /* Not reached: every inner case returns or exits, so control never
	 falls through into the default below. */
  default:
    printf("unknown chaining mode in crypt_buffer\n");
    exit(1);
  }
  /* Not reached; kept as a safety net. */
  printf("unexpectedly fell off the end of crypt_buffer\n");
  exit(1);
}

/*
 * get_encrypted_length: round `length` up to the next multiple of the
 * block size of algorithm `cryptalg`.  Only block sizes 1, 8 and 16 are
 * supported; any other size prints a message and exits.
 */
int get_encrypted_length(int cryptalg, int length) {
  int block = algorithms[cryptalg].block_length;
  int leftover;

  if (block != 1 && block != 8 && block != 16) {
    printf("unknown block size in get_encrypted_length\n");
    exit(1);
  }
  /* All supported sizes are powers of two, so block - 1 is the mask. */
  leftover = length & (block - 1);
  if (leftover == 0)
    return length;
  return length + (block - leftover);
}

void write_random_key(char * keyfile, int algorithm) {
  int devrand;
  FILE * out;
  char bytes[MAX_BLOCK_LENGTH];

  if ((algorithm <0) || (algorithm >= NUM_ALGORITHMS)) {
    perror("unknown algorithm in make_random_key");
    exit(1);
  }
  
  if (!(devrand = open("/dev/srandom",O_RDONLY,0))) {
    perror("couldn't open /dev/srandom");
    exit(1);
  }
  if (read(devrand,bytes,algorithms[algorithm].key_length)
      != algorithms[algorithm].key_length) {
    perror("failed to read from /dev/srandom");
    exit(1);
  }
  close(devrand);

  if (!(out = fopen(keyfile,"w"))) {
    perror("couldn't open key file");
    exit(1);
  }
  if (fwrite(bytes,1,algorithms[algorithm].key_length,out)
      != algorithms[algorithm].key_length) {
    perror("failed to write to key file");
    exit(1);
  }
  fclose(out);
}

/*
 * print_algorithms: list every entry of the algorithms table on stdout,
 * one "<id>: <name>" line per algorithm.
 */
void print_algorithms() {
  int idx = 0;

  while (idx < NUM_ALGORITHMS) {
    printf("%d: %s\n", idx, algorithms[idx].name);
    idx++;
  }
}
     

/* The key-generation algorithm below works as follows:
   * Assume the following are given:
     a block cipher and a key-generating key (specified in "session")
     conversation data (e.g., a pair of IP addresses)
     the length of the desired key (determined by "cryptalg")
   * Use the key-generating-key (in "session") to encrypt, in CBC mode,
     the conversation data, zero-padded to a sufficient length that when
     we take the final bits of the ciphertext as a key, their overlap will
     be contained in at most one block.  Then return a pointer to the final
     bits, to be used as a key.
   In most cases this level of generality isn't needed (e.g., if we use AES
   with 128-bit blocks and keys, then our conversation data and key length
   will each always fit in one AES block).  If we settle on a single
   algorithm, we may decide to make this less general, for efficiency's sake.
*/

/* Note that *buffer will be destructively modified, and must be of size
   at least DERIVED_KEY_BUFFER_SIZE. */

/* Scratch-buffer size for make_derived_key callers.  Parenthesized so the
   macro is safe in any expression.  The extra MAX_BLOCK_LENGTH is
   headroom: make_derived_key rounds both the input and the target key up
   to the block size, so it may touch up to
   length + key_length + block_length - 1 bytes. */
#define DERIVED_KEY_BUFFER_SIZE \
  (MAX_CONV_DATA_LENGTH + MAX_KEY_LENGTH + MAX_BLOCK_LENGTH)

/*
 * make_derived_key: derive a key for algorithm `cryptalg` by CBC-encrypting
 * the first `length` bytes of `buffer` (zero-padded, zero IV) under the
 * key-generating key in `session`.  Returns a pointer to the final
 * key_length bytes of the ciphertext, which live inside `buffer`.
 * `buffer` is modified in place.
 */
char * make_derived_key(struct crypto_session * session,
			char * buffer, int length,
			int cryptalg) {
  char iv[MAX_BLOCK_LENGTH];
  /* Plain automatics keep the function reentrant. */
  int keygen_blocklen, target_keylen, padded_conv_length;
  int padded_target_keylen, buffer_length;

  keygen_blocklen = algorithms[session->algorithm].block_length;
  target_keylen = algorithms[cryptalg].key_length;
  padded_conv_length = get_encrypted_length(session->algorithm,length);
  padded_target_keylen =
    get_encrypted_length(session->algorithm,target_keylen);
  /* Total CBC length: the padded input plus the padded key, sharing one
     block between them when both are non-empty. */
  if ( (padded_conv_length == 0) || (padded_target_keylen == 0) )
    buffer_length = padded_conv_length + padded_target_keylen;
  else 
    buffer_length = padded_conv_length + padded_target_keylen
      - keygen_blocklen;

  /* Zero padding and an all-zero IV make the derived key depend only on
     the key-generating key and the input bytes.  memset is used instead of
     the non-standard bzero. */
  memset(buffer+length,0,buffer_length-length);
  memset(iv,0,keygen_blocklen);
  crypt_buffer(session,buffer,buffer_length,iv);

  /* The derived key is the trailing target_keylen bytes of ciphertext. */
  return(buffer+buffer_length-target_keylen);
}

/*
 * make_conv_key: derive a per-conversation key for algorithm `cryptalg`
 * from the packet's source and destination addresses.  Returns a pointer
 * into a static buffer that is overwritten by the next call.
 */
char * make_conv_key(struct crypto_session * session,
		     struct ip * ip, int cryptalg) {
  static char buffer[DERIVED_KEY_BUFFER_SIZE];

  /* Conversation data = source address followed by destination address.
     Each field is copied on its own rather than reading 8 bytes starting
     at ip_src, which depended on ip_dst immediately following it. */
  memcpy(buffer, &ip->ip_src, sizeof(ip->ip_src));
  memcpy(buffer + sizeof(ip->ip_src), &ip->ip_dst, sizeof(ip->ip_dst));
  return(make_derived_key(session, buffer,
			  (int)(sizeof(ip->ip_src) + sizeof(ip->ip_dst)),
			  cryptalg));
}

/*
 * make_src_key: derive a key for algorithm `cryptalg` from the packet's
 * source address alone.  Returns a pointer into a static buffer that is
 * overwritten by the next call.
 */
char * make_src_key(struct crypto_session * session,
		    struct ip * ip, int cryptalg) {
  static char buffer[DERIVED_KEY_BUFFER_SIZE];
  /* Four zero bytes plus an 'S' marker, so this 9-byte input never
     coincides with the 8-byte address pair used for conversation keys. */
  static const char src_tag[5] = { 0, 0, 0, 0, 'S' };

  memcpy(buffer, &ip->ip_src, 4);
  memcpy(buffer + 4, src_tag, sizeof(src_tag));
  return make_derived_key(session, buffer, 4 + (int)sizeof(src_tag),
			  cryptalg);
}

/*
 * make_dst_key: derive a key for algorithm `cryptalg` from the packet's
 * destination address alone.  Returns a pointer into a static buffer that
 * is overwritten by the next call.
 */
char * make_dst_key(struct crypto_session * session,
		    struct ip * ip, int cryptalg) {
  static char buffer[DERIVED_KEY_BUFFER_SIZE];
  /* Four zero bytes plus a 'D' marker, so this 9-byte input never
     coincides with the 8-byte address pair used for conversation keys. */
  static const char dst_tag[5] = { 0, 0, 0, 0, 'D' };

  memcpy(buffer, &ip->ip_dst, 4);
  memcpy(buffer + 4, dst_tag, sizeof(dst_tag));
  return make_derived_key(session, buffer, 4 + (int)sizeof(dst_tag),
			  cryptalg);
}

#define MAX_FILENAME_SIZE 20

char * make_segment_key(struct crypto_session * session,
			char * filename, int cryptalg) {
  static char buffer[MAX_KEY_LENGTH + MAX_FILENAME_SIZE];

  strncpy(buffer,filename,MAX_KEY_LENGTH + MAX_FILENAME_SIZE);
  /* Bother with uniqueness? If filenames always at least certain
     length, maybe we don't care. */
  return(make_derived_key(session,buffer,strlen(filename),cryptalg));
}

/*
 * read_key_file: read exactly key_length raw bytes for algorithm
 * `cryptalg` from `filename` into `key`.  Does nothing for algorithms
 * whose key_length is 0.  Exits the process if the file cannot be opened
 * or holds fewer bytes than a key.
 */
void read_key_file(char * filename, char * key, int cryptalg) {
  FILE * keyfile;

  if (algorithms[cryptalg].key_length == 0)
    return;
  /* Raw key bytes: open in binary mode (no effect on POSIX). */
  if ((keyfile=fopen(filename,"rb")) == NULL) {
    printf("Error opening key file %s\n",filename);
    exit(1);
  }
  if (fread(key,algorithms[cryptalg].key_length,1,keyfile)!=1) {
    /* Newline-terminated like the other diagnostics. */
    printf("Key file %s truncated\n",filename);
    fclose(keyfile);
    exit(1);
  }
  fclose(keyfile);
}

/*
 * write_cbc_encrypted_buffer: encrypt `buffer` in place with `session` and
 * append a record to `outfile`:
 *   1. a fresh IV (block_length bytes),
 *   2. the plaintext length as a native int,
 *   3. the padded ciphertext.
 *
 * `buffer` must have room for `length` rounded up to the block size,
 * because crypt_buffer processes whole blocks in place.  The session is
 * presumably set up for ENCRYPT (confirm at call sites).  `filename` is
 * used only in error messages.  Exits the process on write errors.
 */
void write_cbc_encrypted_buffer(struct crypto_session * session,
				FILE * outfile,
				char * filename,
				char * buffer, int length) {
  int i, blocklen, padded_length;
  char iv[MAX_BLOCK_LENGTH];

  blocklen = algorithms[session->algorithm].block_length;
  /* We assume the random number generator has been seeded elsewhere.
     (Currently that happens in dump/pkt_dump.c, main().)
     NOTE(review): random() is not a cryptographic generator, so these IVs
     may be predictable; consider a CSPRNG for CBC IVs.
     Each value is memcpy'd into iv (no type-punned store through the char
     array), and the last chunk is clipped so nothing is written past
     blocklen. */
  for (i = 0; i < blocklen; i += (int)sizeof(long)) {
    long r = random();
    size_t chunk = (size_t)(blocklen - i);

    if (chunk > sizeof(r))
      chunk = sizeof(r);
    memcpy(iv + i, &r, chunk);
  }

  if (fwrite(iv,1,blocklen,outfile) != (size_t)blocklen) {
    errx(1,"Error writing iv to %s",filename);
  }
  padded_length = crypt_buffer(session,buffer,length,iv);
  if ((fwrite(&length,sizeof(length),1,outfile) != 1)
      || (fwrite(buffer,1,padded_length,outfile) != (size_t)padded_length)) {
    errx(1,"Error writing %s",filename);
  }
}

/*
 * save_cbc_encrypted_buffer: create/truncate `filename` and write one
 * encrypted record to it with write_cbc_encrypted_buffer.  `buffer` is
 * encrypted in place.  Exits the process if the file cannot be opened or
 * fully written.
 */
void save_cbc_encrypted_buffer(struct crypto_session * session,
			       char * filename,
			       char * buffer, int length) {
  FILE * outfile;
  
  if ((outfile = fopen(filename,"w+")) == NULL) {
    errx(1,"Failed to open %s",filename);
  }
  write_cbc_encrypted_buffer(session, outfile, filename, buffer, length);
  /* fclose flushes stdio buffers; if that fails the file on disk is
     incomplete even though every fwrite reported success. */
  if (fclose(outfile) != 0) {
    err(1,"Error writing %s",filename);
  }
}

/*
 * read_cbc_encrypted_buffer: read one record in the layout written by
 * write_cbc_encrypted_buffer from `infile`, then decrypt it:
 *   1. IV (block_length bytes),
 *   2. plaintext length (a native int),
 *   3. padded ciphertext.
 *
 * On return *buffer points to a malloc'd buffer holding the plaintext
 * (the caller owns and frees it), and *length is the plaintext length.
 * The session is presumably set up for DECRYPT (confirm at call sites).
 * `filename` is used only in error messages.  Exits the process on read
 * errors, allocation failure, or an out-of-range length field.
 */
void read_cbc_encrypted_buffer(struct crypto_session * session,
			       FILE * infile,
			       char * filename,
			       char ** buffer, int * length) {
  int blocklen, padded_length;
  char iv[MAX_BLOCK_LENGTH];

  blocklen = algorithms[session->algorithm].block_length;
  if (fread(iv,1,blocklen,infile) != (size_t)blocklen) {
    errx(1,"Error reading iv from %s",filename);
  }
  if (fread(length,sizeof(*length), 1,infile) != 1) {
    errx(1,"Error reading length from %s",filename);
  }
  /* The length comes from the file and is untrusted.  A negative value,
     or one close to INT_MAX, would overflow the padded-length arithmetic
     and mis-size the allocation and read below. */
  if ((*length < 0) || (*length > INT_MAX - blocklen)) {
    errx(1,"Invalid length %d in %s",*length,filename);
  }
  padded_length = get_encrypted_length(session->algorithm,*length);
  /* Allocate at least one byte: malloc(0) may legitimately return NULL. */
  if (!(*buffer = malloc(padded_length > 0 ? (size_t)padded_length : 1))) {
    errx(1,"Unable to malloc memory for buffer in read_cbc_encrypted_buffer");
  }
  if (fread(*buffer,1,padded_length,infile) != (size_t)padded_length) {
    errx(1,"File %s truncated",filename);
  }
  crypt_buffer(session,*buffer,*length,iv);
}

/*
 * load_cbc_encrypted_buffer: open `filename` for reading and hand it to
 * read_cbc_encrypted_buffer, which allocates *buffer and fills *length.
 * Exits the process if the file cannot be opened.
 */
void load_cbc_encrypted_buffer(struct crypto_session * session,
			       char * filename,
			       char ** buffer, int * length) {
  FILE * src = fopen(filename, "r");

  if (src == NULL)
    errx(1, "Error opening file %s", filename);

  read_cbc_encrypted_buffer(session, src, filename, buffer, length);
  fclose(src);
}

/*
 * print_key: write the key_length bytes of `key` for algorithm `cryptalg`
 * to stdout as lowercase hex.  A space precedes each group of 8 bytes.
 * No trailing newline is printed.
 */
void print_key(char * key, int cryptalg) {
  int pos;
  int keylen = algorithms[cryptalg].key_length;

  for (pos = 0; pos < keylen; pos++) {
    if ((pos % 8) == 0)
      putchar(' ');
    printf("%02x", (unsigned char)key[pos]);
  }
}
