/*
  svc_in_gssd_proc.c
  
  Copyright (c) 2000 The Regents of the University of Michigan.
  All rights reserved.

  Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>

  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions
  are met:

  1. Redistributions of source code must retain the above copyright
     notice, this list of conditions and the following disclaimer.
  2. Redistributions in binary form must reproduce the above copyright
     notice, this list of conditions and the following disclaimer in the
     documentation and/or other materials provided with the distribution.
  3. Neither the name of the University nor the names of its
     contributors may be used to endorse or promote products derived
     from this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include <sys/param.h>
#include <rpc/rpc.h>

#include <pwd.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "rpc_svcgssd.h"
#include "gss_util.h"
#include "err_util.h"
#include "context.h"

/*
 * Serialize a context record into a local buffer and write it to svc_k5_fd
 * in a single write().  The record consists of, in order: wire_ctx, a
 * timeout (currently always 0), uid, the mechanism OID as a length/value
 * buffer, and the context token.
 *
 * Returns 0 on success.  Returns -1 (after logging) if any field does not
 * fit in the buffer or if write() fails or is short.
 */
static int
do_svc_downcall(int svc_k5_fd, u_int32_t wire_ctx, uid_t uid,
		gss_OID mech, gss_buffer_desc *context_token)
{
	char		msg[1024];
	char		*cur = msg;
	char		*limit = msg + sizeof(msg);
	gss_buffer_desc	oid_buf;
	unsigned int	timeout = 0; /* XXX decide on a reasonable value */
	ssize_t		msglen;

	/* Present the mechanism OID as a plain length/value buffer. */
	oid_buf.length = mech->length;
	oid_buf.value = mech->elements;

	printerr(1, "doing downcall\n");

	if (WRITE_BYTES(&cur, limit, wire_ctx))
		goto out_err;
	/* Not setting any timeout for now: */
	if (WRITE_BYTES(&cur, limit, timeout))
		goto out_err;
	if (WRITE_BYTES(&cur, limit, uid))
		goto out_err;
	if (write_buffer(&cur, limit, &oid_buf))
		goto out_err;
	if (write_buffer(&cur, limit, context_token))
		goto out_err;

	/* A short write is treated the same as a failed one. */
	msglen = cur - msg;
	if (write(svc_k5_fd, msg, msglen) < msglen)
		goto out_err;
	return 0;

out_err:
	printerr(0, "downcall failed\n");
	return -1;
}

/*
 * Verifier as decoded by xdr_decode_rpc_gss_verf(): a 32-bit flavor word
 * followed by an opaque body.  parse_request() only accepts requests whose
 * verifier has flavor 0 and an empty body.
 */
struct gss_verifier {
	u_int32_t	flav;	/* flavor word, read with xdr_get_u32() */
	gss_buffer_desc	body;	/* opaque body, read with xdr_get_buffer() */
};

/*
 * Decode a verifier (flavor word, then opaque body) from the XDR stream at
 * *p, bounded by end, into *verf.  Decoding stops at the first field that
 * fails.
 *
 * Returns 0 on success, -1 if either field could not be decoded.
 */
static int
xdr_decode_rpc_gss_verf(u_int32_t **p, u_int32_t *end,
			struct gss_verifier *verf)
{
	int	failed;

	failed = xdr_get_u32(p, end, &verf->flav) ||
		 xdr_get_buffer(p, end, &verf->body);

	return failed ? -1 : 0;
}

/*
 * Credential as decoded by xdr_decode_rpc_gss_cred(); the fields appear on
 * the wire in the order they are declared here.
 * NOTE(review): field names look like the RPCSEC_GSS credential
 * (version, procedure, sequence number, service, context handle) --
 * confirm against the protocol specification.
 */
struct gss_cred {
	u_int32_t		gc_v;
	u_int32_t		gc_proc;	/* checked against RPC_GSS_PROC_* in parse_request() */
	u_int32_t		gc_seq;
	u_int32_t		gc_svc;
	gss_buffer_desc		gc_ctx;		/* must be empty for INIT, 4 bytes for CONTINUE_INIT */
};

/* Values of gss_cred.gc_proc accepted by parse_request(); all others are rejected. */
#define RPC_GSS_PROC_INIT		1
#define RPC_GSS_PROC_CONTINUE_INIT	2

/*
 * Decode a credential from the XDR stream at *p, bounded by end, into *gc.
 * Four 32-bit words (gc_v, gc_proc, gc_seq, gc_svc) are read first, then the
 * opaque gc_ctx buffer.  Decoding stops at the first field that fails.
 *
 * Returns 0 on success, -1 if any field could not be decoded.
 */
static int
xdr_decode_rpc_gss_cred(u_int32_t **p, u_int32_t *end, struct gss_cred *gc)
{
	int	failed;

	failed = xdr_get_u32(p, end, &gc->gc_v) ||
		 xdr_get_u32(p, end, &gc->gc_proc) ||
		 xdr_get_u32(p, end, &gc->gc_seq) ||
		 xdr_get_u32(p, end, &gc->gc_svc) ||
		 xdr_get_buffer(p, end, &gc->gc_ctx);

	return failed ? -1 : 0;
}

/* Flavor word that write_verifier() emits ahead of the MIC. */
#define RPC_AUTH_GSS	6

/*
 * Append a reply verifier to the XDR stream at *p (bounded by end).
 *
 * The verifier is the RPC_AUTH_GSS flavor word followed by a MIC, computed
 * with ctx, over the sequence window value `win` in network byte order.
 *
 * The MIC buffer returned by gss_get_mic() is released before returning,
 * whether or not it could be written to the stream.
 *
 * Returns 0 on success, -1 if the MIC could not be computed or the verifier
 * does not fit in the remaining space.
 */
static int
write_verifier(u_int32_t **p, u_int32_t *end, gss_ctx_id_t ctx, u_int32_t win)
{
	u_int32_t		maj_stat, min_stat;
	gss_buffer_desc		bufin, bufout;
	int			res = -1;

	win = htonl(win);
	bufin.length = sizeof(win);
	bufin.value = &win;
	maj_stat = gss_get_mic(&min_stat, ctx, 0, &bufin, &bufout);
	if (maj_stat != GSS_S_COMPLETE)
		return -1;	/* bufout is not valid, nothing to release */

	if (xdr_write_u32(p, end, RPC_AUTH_GSS))
		goto out;
	if (xdr_write_buffer(p, end, &bufout))
		goto out;
	res = 0;
out:
	/* The MIC storage is owned by us once gss_get_mic() succeeds. */
	gss_release_buffer(&min_stat, &bufout);
	return res;
}

/*
 * Sequence window value: send_response() has it MIC'd into the reply
 * verifier and also writes it into the reply body.
 */
#define RPCSEC_GSS_SEQ_WIN	5

/*
 * Build the reply to a context-establishment request and write it to nullfd
 * in a single write().  The reply consists of, in order:
 *
 *   reply state (0), verifier (see write_verifier()), accept state (0),
 *   wire_ctx as a 4-byte opaque, maj_stat, min_stat, RPCSEC_GSS_SEQ_WIN,
 *   and output_token as an opaque.
 *
 * Returns 0 on success.  Returns -1 (after logging) if the verifier cannot
 * be produced, the reply does not fit in the local buffer, or write() fails
 * or is short.
 */
static int
send_response(int nullfd, gss_ctx_id_t ctx,
	      u_int32_t maj_stat, u_int32_t min_stat, u_int32_t wire_ctx,
	      gss_buffer_desc *output_token)
{
	char			reply[1024];
	u_int32_t		*cur = (u_int32_t *)reply;
	u_int32_t		*limit = (u_int32_t *)(reply + sizeof(reply));
	gss_buffer_desc		handle;
	ssize_t			replylen;

	printerr(1, "sending null reply\n");

	/* The context handle goes out as an opaque wrapping wire_ctx. */
	handle.length = sizeof(wire_ctx);
	handle.value = &wire_ctx;

	if (xdr_write_u32(&cur, limit, 0) ||			/* reply state */
	    write_verifier(&cur, limit, ctx, RPCSEC_GSS_SEQ_WIN) ||
	    xdr_write_u32(&cur, limit, 0) ||			/* accept state */
	    xdr_write_buffer(&cur, limit, &handle) ||
	    xdr_write_u32(&cur, limit, maj_stat) ||
	    xdr_write_u32(&cur, limit, min_stat) ||
	    xdr_write_u32(&cur, limit, RPCSEC_GSS_SEQ_WIN) ||
	    xdr_write_buffer(&cur, limit, output_token))
		goto out_err;

	/* A short write is treated the same as a failed one. */
	replylen = (char *)cur - reply;
	if (write(nullfd, reply, replylen) < replylen)
		goto out_err;
	return 0;

out_err:
	printerr(0, "svc_in_gssd: error sending null reply\n");
	return -1;
}

/*
 * Status codes passed as `stat` to send_reject_response().  Only
 * rpc_autherr_badcred and rpc_autherr_rejectedcred are used by the code in
 * this part of the file.
 * NOTE(review): the values look like the ONC RPC / RPCSEC_GSS auth_stat
 * codes -- confirm against the protocol specifications.
 */
#define rpc_auth_ok			0 
#define rpc_autherr_badcred		1
#define rpc_autherr_rejectedcred	2
#define rpc_autherr_badverf		3
#define rpc_autherr_rejectedverf	4
#define rpc_autherr_tooweak		5
#define rpcsec_gsserr_credproblem	13
#define rpcsec_gsserr_ctxproblem	14 

/*
 * Write a three-word rejection to nullfd: reject (1), auth_err (1), and the
 * given auth status `stat`.
 *
 * Returns 0 on success.  Returns -1 (after logging) if the words cannot be
 * encoded or write() fails or is short.
 */
static int
send_reject_response(int nullfd, u_int32_t stat)
{
	char		reply[12];	/* exactly three 32-bit words */
	u_int32_t	*cur = (u_int32_t *)reply;
	u_int32_t	*limit = (u_int32_t *)(reply + sizeof(reply));
	ssize_t		replylen;

	printerr(1, "sending null reply with stat = %d\n", stat);

	if (xdr_write_u32(&cur, limit, 1) ||		/* reject */
	    xdr_write_u32(&cur, limit, 1) ||		/* auth_err */
	    xdr_write_u32(&cur, limit, stat))		/* auth_stat */
		goto out_err;

	/* A short write is treated the same as a failed one. */
	replylen = (char *)cur - reply;
	if (write(nullfd, reply, replylen) < replylen)
		goto out_err;
	return 0;

out_err:
	printerr(0, "svc_in_gssd: error sending null reply\n");
	return -1;
}

/*
 * Parse a request of `len` bytes held in buf.
 *
 * Expected layout: a 32-bit credential length, the credential (see
 * xdr_decode_rpc_gss_cred()), a verifier that must have flavor 0 and an
 * empty body, and finally the input token as an opaque.  The request must
 * be a multiple of 4 bytes long and must be consumed exactly.
 *
 * On success *input_token describes the token and *gss_proc receives the
 * credential's gc_proc.  *wire_ctx is written only for
 * RPC_GSS_PROC_CONTINUE_INIT, from the 4-byte context handle carried in the
 * credential; for RPC_GSS_PROC_INIT the handle must be empty and *wire_ctx
 * is left untouched.
 *
 * Returns 0 on success, -1 on any malformed or unsupported request.
 */
static int
parse_request(char *buf, int len, u_int32_t *wire_ctx,
	      gss_buffer_desc *input_token, u_int32_t *gss_proc)
{
	struct gss_cred		cred;
	struct gss_verifier	verf;
	u_int32_t		credlen;
	u_int32_t		*cur = (u_int32_t *)buf;
	u_int32_t		*limit = (u_int32_t *)(buf + len);

	if (len % 4 != 0)
		return -1;
	if (xdr_get_u32(&cur, limit, &credlen))
		return -1;
	if (xdr_decode_rpc_gss_cred(&cur, limit, &cred))
		return -1;
	/* The credential must span exactly credlen bytes after its length word. */
	if ((char *)cur != buf + 4 + credlen)
		return -1;
	if (xdr_decode_rpc_gss_verf(&cur, limit, &verf))
		return -1;
	if (verf.flav != 0 || verf.body.length != 0)
		return -1;
	if (xdr_get_buffer(&cur, limit, input_token))
		return -1;
	/* No trailing data is allowed after the input token. */
	if (cur != limit)
		return -1;

	if (cred.gc_proc == RPC_GSS_PROC_INIT) {
		if (cred.gc_ctx.length != 0)
			return -1;
	} else if (cred.gc_proc == RPC_GSS_PROC_CONTINUE_INIT) {
		if (cred.gc_ctx.length != 4)
			return -1;
		*wire_ctx = *(u_int32_t *)cred.gc_ctx.value;
	} else {
		return -1;
	}

	*gss_proc = cred.gc_proc;
	return 0;
}

/*
 * Map a GSS client name to a local uid.
 *
 * The display form of client_name is looked up with getpwnam(); if there is
 * no such user, the "nobody" account is used instead.  `mech` is accepted
 * for the caller's convenience but is not consulted.
 *
 * The display-name buffer and the temporary NUL-terminated copy of it are
 * both released before returning, on every path.
 *
 * Returns 0 and stores the uid in *uid on success.  Returns -1 (after
 * logging) if the name cannot be displayed, memory cannot be allocated, or
 * neither the user nor "nobody" exists; *uid is then left untouched.
 */
static int
get_uid(gss_name_t client_name, gss_OID *mech, uid_t *uid)
{
	u_int32_t	maj_stat, min_stat;
	gss_buffer_desc	name;
	char		*sname = NULL;
	int		res = -1;
	struct passwd	*pw = NULL;
	gss_OID		name_type;

	maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
	if (maj_stat != GSS_S_COMPLETE)
		goto out;
	/* name.value is not guaranteed to be NUL-terminated; make a copy that is. */
	if (!(sname = calloc(name.length + 1, 1)))
		goto out_release;
	memcpy(sname, name.value, name.length);
	printerr(1, "sname = %s\n", sname);
	/* XXX? mapping unknown users (including machine creds) to nobody: */
	if ( !(pw = getpwnam(sname)) && !(pw = getpwnam("nobody")) )
		goto out_free;
	*uid = pw->pw_uid;
	res = 0;
out_free:
	free(sname);
out_release:
	gss_release_buffer(&min_stat, &name);
out:
	if (res)
		printerr(0, "get_uid failed\n");
	return res;
}

/* request starts right before opaque part of cred (after flavor). */
/* response starts at reject_stat */
void
handle_nullreq(int nullfd, int ctxfd) {
	static char		buf[1024];
	int			nbytes;
	u_int32_t		wire_ctx = 0;
	u_int32_t		req_wire_ctx;
	gss_buffer_desc		input_token,
				output_token = {.length = 0, .value = NULL},
				ctx_token;
	u_int32_t		gss_proc, ret_flags;
	gss_ctx_id_t		ctx = GSS_C_NO_CONTEXT;
	gss_name_t		client_name;
	gss_OID			mech;
	u_int32_t		maj_stat = GSS_S_FAILURE, min_stat = 0;
	uid_t			uid;
	
	printerr(1, "handling null request\n");

	/* XXX doesn't deal with case of long request. */
	if ((nbytes = read(nullfd, buf, sizeof(buf))) < 0) {
		printerr(0, "failed reading request from nullrpc file\n");
		goto out;
	}
	if (parse_request(buf, nbytes, &req_wire_ctx, &input_token,
			  &gss_proc)) {
		printerr(0, "failed parsing request\n");
		send_reject_response(nullfd, rpc_autherr_badcred);
		goto out;
	}
	/* XXX note not supporting mechs requiring continue_init for now: */
	if (gss_proc != RPC_GSS_PROC_INIT) {
		printerr(0, "bad gss_proc\n");
		send_reject_response(nullfd, rpc_autherr_badcred);
		goto out;
	}
	/* get context handle: */
	if (read(ctxfd, &wire_ctx, sizeof(wire_ctx)) != sizeof(wire_ctx)) {
		printerr(0, "failed reading handle from init_context file\n");
		goto out;
	}
	maj_stat = gss_accept_sec_context(&min_stat, &ctx, gssd_creds,
			&input_token, GSS_C_NO_CHANNEL_BINDINGS, &client_name,
			&mech, &output_token, &ret_flags, NULL, NULL);
	if (maj_stat != GSS_S_COMPLETE) {
		printerr(0, "gss_accept_sec_context failed\n");
		send_reject_response(nullfd, rpc_autherr_rejectedcred);
		goto out;
	}
	if (get_uid(client_name, &mech, &uid)) {
		printerr(0, "get_uid failed\n");
		send_reject_response(nullfd, rpc_autherr_rejectedcred);
		goto out;
	}

/* XXX note: Couldn't ctx state change between now and when we
 * use it to calculate the mic in send_response?  Would be better to
 * precalculate the mic here. */
	export_ctx_to_kernel(ctx, &ctx_token);
	do_svc_downcall(ctxfd, wire_ctx, uid, mech, &ctx_token);
out:
	send_response(nullfd, ctx, maj_stat, min_stat, wire_ctx,
			&output_token);
	printerr(1, "finished handling null request\n");
	return;
}
