#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ptrace.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/resource.h>
#include <unistd.h>
#include <errno.h>
#include <sys/wait.h>
#include <sys/reg.h>
#include <sys/time.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <assert.h>

#include "supervisor.h"
#include "prototypes.h"

#define VERSION_STRING "1.2"

/* Set to 1 once the tracked process has delivered its first SIGSTOP. */
static volatile int tracked_attached;
/* Syscall number recorded at entry; checked at exit, -1 between syscalls. */
static long current_syscall;
/* Set asynchronously by termhandler(). volatile sig_atomic_t is the only
 * object type the C standard allows a signal handler to write safely. */
static volatile sig_atomic_t terminate;
/* Number of syscall entries seen (reported in the result line). */
static long syscalls_counter;
/* Return value to store into EAX at the exit of an emulated syscall. */
static int em_retval;
/* Original syscall number saved by change_syscall() before it rewrote ORIG_EAX. */
static long changed_syscall;
/* Descriptor given with -f, or -1 when results go to stderr. */
static int result_fd = -1;

/* PID of the tracked process: the fork()ed child, or the pid given with -w. */
pid_t tracked;
/* Resource usage of the tracked process, filled in by wait4() in tracker_process(). */
struct rusage res_usage;
/* Current state of the syscall entry/exit state machine (SS_* values). */
short sc_state = SS_OUT;
/* Most recent memory sample in kB, from /proc/<pid>/statm; -1 before the first sample. */
long used_memory = -1LL;
/* Largest memory sample seen so far, in kB; checked against mem_limit. */
long max_memory;
/* /proc/<pid>/exectime reading minus exectime_offset (all ones when the file is unavailable). */
unsigned long long exectime;
/* Baseline subtracted from the raw exectime reading; updated by zero_exectime(). */
unsigned long long exectime_offset;
/* -e sets enable_extensions; -x sets both. */
int enable_extensions, force_extensions;
/* -r: let the supervisor's exit status reflect the tracked program's. */
int pass_retval;

/* Settings read from the environment in main(). */
int count_systime;	/* COUNT_SYSTIME: add system CPU time to the measured time */
long time_limit;	/* TIME_LIMIT or TIME, in ms (default 10000) */
long hard_time_limit;	/* HARD_LIMIT or HARD, in ms; enforced as RLIMIT_CPU in whole seconds */
long mem_limit;	/* MEM_LIMIT or MEM, in kB, rounded down to a multiple of 4 */
long out_limit;	/* OUT_LIMIT, in bytes (reported as kB) */
unsigned long long et_limit;	/* ET_LIMIT: exectime limit; non-zero replaces the time limits */
int allow_all;	/* ALLOW: a non-empty value disables the syscall filter */
int quiet;	/* -q: suppress informational output on stderr */

/* Destination of the __RESULT__ record: stderr, or fdopen() of the -f descriptor. */
FILE* resultf;

/* Command line left after option parsing; the_argv[0] is the program to run.
 * NOTE(review): the_argc is stored but not read anywhere visible here. */
static int the_argc;
static char** the_argv;

/* Human-readable descriptions of result codes, indexed by the retval passed to
 * report_result(); codes not listed here have a NULL entry. */
const char* result_codes[256] = {
	[RETVAL_OK]   = "Exited normally",
	[RETVAL_TLE]  = "Time limit exceeded",
	[RETVAL_MLE]  = "Memory limit exceeded",
	[RETVAL_RE]   = "Runtime error",
	[RETVAL_RV]   = "Rule violation",
	[RETVAL_OLE]  = "Output limit exceeded"
};
/* Scratch buffer for the comment strings formatted in tracker_process(). */
static char buf[255];

/*
 * Print a supervisor diagnostic to stderr in the form
 *     supervisor: <str> (errno='<description of current errno>')
 * errno is sampled at call time, so call this immediately after the failing call.
 */
static void fail_msg(const char* str)
{
	const char* reason = strerror(errno);

	fprintf(stderr, "supervisor: %s (errno='%s')\n", str, reason);
}

/* Fatal error outside the tracking loop: report via fail_msg() and exit with status 122. */
static void fail(const char* str)
{
	const int status = 122;

	fail_msg(str);
	exit(status);
}

/*
 * Fatal error while tracking: report via fail_msg(), then make sure the tracked
 * process dies, and exit with status 122.
 *
 * The kill is attempted through ptrace (PTRACE_KILL, then PTRACE_CONT) and
 * again with plain kill(): SIGKILL, followed by SIGCONT.
 */
void fail_tracker(const char* str)
{
	static const int final_signals[] = { SIGKILL, SIGCONT };
	size_t i;

	fail_msg(str);

	ptrace(PTRACE_KILL, tracked, 0, 0);
	ptrace(PTRACE_CONT, tracked, 0, 0);
	for (i = 0; i < sizeof(final_signals) / sizeof(final_signals[0]); i++)
		kill(tracked, final_signals[i]);

	exit(122);
}

/* Convert a struct timeval to whole milliseconds; the sub-millisecond part is truncated. */
static long tv2long(struct timeval* tv)
{
	long millis = tv->tv_sec * 1000;

	millis += tv->tv_usec / 1000;
	return millis;
}

/*
 * CPU time charged to the tracked process, in ms: user time, plus system
 * time when count_systime (COUNT_SYSTIME) is set.
 */
static long extract_time(struct rusage* ru)
{
	long total = tv2long(&ru->ru_utime);

	if (count_systime)
		total += tv2long(&ru->ru_stime);
	return total;
}

/*
 * Emit the final verdict for the tracked program.
 *
 * Always writes the machine-readable record
 *     "__RESULT__ <retval> <time ms> <exectime> <max memory kB> <syscalls>\n<comment>\n"
 * to resultf (stderr, or the descriptor given with -f). Unless -q was given,
 * it also prints a human-readable summary to stderr.
 *
 * retval is looked up in result_codes[]. A code outside the table, or one
 * without a name (NULL entry, e.g. possibly RETVAL_RE_BASE/RETVAL_SIG_BASE + n),
 * is shown as "Unknown result code". Handing that NULL or out-of-bounds
 * pointer to printf's %s would be undefined behaviour.
 */
void report_result(int retval, const char* comment)
{
    const size_t ncodes = sizeof(result_codes) / sizeof(result_codes[0]);
    const char* code_name = NULL;
    long time_used = extract_time(&res_usage);

    if (retval >= 0 && (size_t)retval < ncodes)
        code_name = result_codes[retval];
    if (!code_name)
        code_name = "Unknown result code";

    fprintf(stderr, "\n");
    fprintf(resultf, "__RESULT__ %d %ld %llu %ld %ld\n%s\n",
            retval,
            time_used,
            exectime,
            max_memory,
            syscalls_counter,
            comment);
    /* Push the record out right away so it is not lost if the supervisor
     * dies before exit() gets a chance to flush stdio buffers. */
    fflush(resultf);

    if (!quiet) {
        fprintf(stderr, "SUPERVISOR REPORT\n");
        fprintf(stderr, "-----------------\n");
        fprintf(stderr, "  Result code: %s\n", code_name);
        fprintf(stderr, "  Time used:   %ldms\n", time_used);
        fprintf(stderr, "  Memory used: %ldkB\n", max_memory);
        fprintf(stderr, "  Comment:     %s\n", comment);
        fprintf(stderr, "  Syscalls:    %ld\n", syscalls_counter);
        fprintf(stderr, "  ExecTime:    %llu\n", exectime);
    }
}

/*
 * Report a result, but let limit violations take precedence over retval.
 *
 * Precedence: memory limit (max_memory >= mem_limit) first. Then, if an
 * exectime limit is configured, exectime > et_limit; otherwise CPU time
 * >= time_limit. Only when no limit was hit is (retval, comment) reported.
 */
void report_result_chklim(int retval, const char* comment)
{
	if (max_memory >= mem_limit) {
		report_result(RETVAL_MLE, "memory limit exceeded");
		return;
	}

	if (et_limit) {
		if (exectime > et_limit) {
			report_result(RETVAL_TLE, "exectime limit exceeded");
			return;
		}
	} else if (extract_time(&res_usage) >= time_limit) {
		report_result(RETVAL_TLE, "time limit exceeded");
		return;
	}

	report_result(retval, comment);
}

/* Read the offset-th long-sized slot of the tracked process's USER area (PTRACE_PEEKUSER). */
long uread(long offset)
{
	void* slot = (void*)(offset * sizeof(long));

	return ptrace(PTRACE_PEEKUSER, tracked, slot, NULL);
}

/*
 * Store value into the offset-th long-sized slot of the tracked process's USER
 * area (PTRACE_POKEUSER). In builds without NDEBUG, the slot is read back
 * and must hold value.
 */
void uwrite(long offset, long value)
{
	void* slot = (void*)(offset * sizeof(long));

	ptrace(PTRACE_POKEUSER, tracked, slot, (void*)value);
	assert(uread(offset) == value);
}

/*
 * Sample the memory footprint of the tracked process and update used_memory
 * (latest sample) and max_memory (peak), both in kB.
 *
 * The first field of /proc/<pid>/statm (total program size, in pages) is
 * converted to kB using the system page size. The previous code assumed
 * 4 kB pages.
 *
 * If the file cannot be opened, the tracker is aborted via fail_tracker().
 * If it cannot be parsed, the previous samples are kept instead of using an
 * uninitialised value.
 */
void track_memory(void)
{
	char path[64];
	long pages;
	long page_kb;
	FILE* f;
	int parsed;

	snprintf(path, sizeof(path), "/proc/%d/statm", (int)tracked);
	if (!(f = fopen(path, "r")))
		fail_tracker("cannot open /proc/.../statm");
	parsed = (fscanf(f, "%ld", &pages) == 1);
	fclose(f);
	if (!parsed)
		return;

	page_kb = sysconf(_SC_PAGESIZE) / 1024;
	if (page_kb <= 0)
		page_kb = 4;	/* fall back to the historical 4 kB assumption */

	used_memory = pages * page_kb;
	if (used_memory > max_memory)
		max_memory = used_memory;
}

/*
 * Refresh exectime from /proc/<pid>/exectime, relative to exectime_offset.
 *
 * If the file cannot be opened, exectime is set to all ones
 * (0xffffffffffffffff). With an ET_LIMIT configured, tracker_process() then
 * reports an exectime TLE.
 *
 * If the value cannot be parsed, the previous exectime is kept. Subtracting
 * the offset again from an already offset-relative value would drift, and
 * could wrap around into a false limit violation.
 */
void track_exectime(void)
{
	char path[64];
	unsigned long long raw;
	FILE* f;
	int parsed;

	snprintf(path, sizeof(path), "/proc/%d/exectime", (int)tracked);
	if (!(f = fopen(path, "r"))) {
		exectime = 0xffffffffffffffffULL;
		return;
	}
	parsed = (fscanf(f, "%llu", &raw) == 1);
	fclose(f);

	if (parsed)
		exectime = raw - exectime_offset;
}

/*
 * Restart exectime accounting from the most recent sample.
 *
 * exectime holds (raw counter - exectime_offset). The raw counter at the
 * last sample is therefore exectime_offset + exectime, and that becomes the
 * new baseline. The previous "offset = exectime" was only correct while the
 * offset was still 0, i.e. on the first call.
 *
 * NOTE(review): the baseline is the last periodic sample taken by
 * track_exectime(), not the counter value at the moment of this call.
 */
void zero_exectime(void)
{
    exectime_offset += exectime;
    exectime = 0;
}

/*
 * Kill the tracked process through ptrace and exit the supervisor.
 * The exit status is 1 when -r (pass_retval) is set, 0 otherwise.
 */
void kill_and_exit(void)
{
	int status = (pass_retval != 0);

	ptrace(PTRACE_KILL, tracked, NULL, NULL);
	/* Resume the tracee; presumably so the pending kill takes effect. */
	ptrace(PTRACE_CONT, tracked, NULL, NULL);
	exit(status);
}

/*
 * Terminate the run with a rule-violation verdict (RETVAL_RV) and msg as the
 * comment. Does not return: kill_and_exit() exits the supervisor.
 */
void forbidden(const char* msg)
{
	report_result(RETVAL_RV, msg);
	kill_and_exit();
}

/*
 * Emulate the syscall at the current entry stop. ORIG_EAX is set to -1 so
 * the kernel does not run it, and retval is remembered so that
 * emulate_syscall_exit() can store it in EAX at the exit stop.
 */
void emulate_syscall(int retval)
{
	em_retval = retval;
	sc_state = SS_EMULATING;
	uwrite(ORIG_EAX, -1);
}

/* Exit stop of an emulated syscall: hand the saved em_retval back in EAX. */
static void emulate_syscall_exit(void)
{
	long result = em_retval;

	uwrite(EAX, result);
}

/* Cancel the syscall at the current entry stop (ORIG_EAX = -1); its exit will report -ENOSYS. */
void skip_syscall(void)
{
	sc_state = SS_SKIPPING;
	uwrite(ORIG_EAX, -1);
}

/* Exit stop of a skipped syscall: make it appear to have failed with ENOSYS. */
static void skip_syscall_exit(void)
{
	const long not_implemented = -ENOSYS;

	uwrite(EAX, not_implemented);
}

/*
 * Replace the syscall about to run with sysnr. The original number is kept
 * in changed_syscall so handle_syscall() can restore it at the exit stop.
 */
void change_syscall(long sysnr)
{
	long original = uread(ORIG_EAX);

	changed_syscall = original;
	uwrite(ORIG_EAX, sysnr);
	sc_state = SS_CHANGED;
}

/*
 * Syscall-entry stop: record and count the syscall, then enforce the filter.
 *
 * A syscall flagged SC_SKIP is neutralised with skip_syscall(). Otherwise it
 * is permitted when any of these holds:
 *   - the filter is disabled (allow_all) or checks are waived (!check_perms);
 *   - it is SC_ALLOWED, unless -x is in force and it is also SC_NOEXTENSION;
 *   - it is SC_EXTENSION and extensions are enabled.
 * A syscall that is not permitted is skipped and reported through
 * forbidden(), which does not return.
 */
static void handle_syscall_entry(int sysnr, short flags, int check_perms)
{
	int permitted;
	char msg[256];

	current_syscall = sysnr;
	syscalls_counter++;

	if (flags & SC_SKIP) {
		skip_syscall();
		return;
	}

	permitted = allow_all
		|| !check_perms
		|| ((flags & SC_ALLOWED) &&
		    (!force_extensions || !(flags & SC_NOEXTENSION)))
		|| ((flags & SC_EXTENSION) && enable_extensions);
	if (permitted)
		return;

	skip_syscall();
	snprintf(msg, sizeof(msg), "intercepted forbidden syscall %d (%s)",
			sysnr, syscall_name(sysnr));
	forbidden(msg);
}

/*
 * Syscall-exit stop. It verifies that the exiting syscall is the one that
 * was entered (otherwise fail_tracker()) and clears current_syscall. It then
 * logs the result for debugging and resamples memory for SC_MEMORY syscalls.
 */
static void handle_syscall_exit(int sysnr, short flags)
{
	long res = uread(EAX);

	if (current_syscall != sysnr)
		fail_tracker("exiting different syscall than entered!");
	current_syscall = -1;

	if (res >= 0)
		dbgprintf(stderr, "exit %ld (ok)\n", res);
	else
		dbgprintf(stderr, "exit %ld (%s)\n", res, strerror(-res));

	if (flags & SC_MEMORY)
		track_memory();
}

/* Short debug-log marker for a syscall state; "?!?" for any state not listed. */
static inline const char* ss_prefix(short ss)
{
	if (ss == SS_OUT)
		return "-->";
	if (ss == SS_IN)
		return "<--";
	if (ss == SS_SKIPPING)
		return "<**";
	if (ss == SS_EMULATING)
		return "<EM";
	return "?!?";
}

/*
 * Handle one syscall stop of the tracked process. tracker_process() calls
 * this for every SIGTRAP stop. Entry and exit stops are told apart by sc_state.
 */
static void handle_syscall(void)
{
	long sysnr = uread(ORIG_EAX);
	short flags;
    int qlevel;

    /* change_syscall() rewrote ORIG_EAX at entry. At this stop, go back to
     * the number that was originally entered so the exit is matched against it. */
    if (sc_state == SS_CHANGED) {
        sysnr = changed_syscall;
        sc_state = SS_IN;
    }

	/* Exit stop of a syscall neutralised by skip_syscall()/emulate_syscall(),
	 * both of which set ORIG_EAX to -1. Store its result (-ENOSYS or the
	 * emulated value) and wait for the next entry. */
	if (sysnr == -1 && (sc_state == SS_SKIPPING || sc_state == SS_EMULATING))
	{
		dbgprintf(stderr, "<** SYSCALL\n");
        if (sc_state == SS_SKIPPING)
            skip_syscall_exit();
        else
            emulate_syscall_exit();
		sc_state = SS_OUT;
		return;
	}
		
	/* Numbers outside the syscall table are rejected; forbidden() does not return. */
	if (sysnr < 0 || sysnr >= MAX_SYSNR)
	{
		char buf[256];
		skip_syscall();
		snprintf(buf, sizeof(buf), "intercepted unknown syscall %ld", sysnr);
		forbidden(buf);
	}

	flags = syscall_flags[sysnr];

    if (flags & SC_EXTENSION && !enable_extensions)
        forbidden("extensions are disabled");

	dbgprintf(stderr, "%s SYSCALL %s\n", ss_prefix(sc_state),
			syscall_name(sysnr));
	
	/* qlevel 0: stop here without advancing the state machine.
	 * qlevel 1: entry goes through the normal permission checks.
	 * Any other value: entry is handled without permission checks.
	 * NOTE(review): the exact meaning is defined by quirk_syscall(); confirm. */
	qlevel = quirk_syscall(sysnr);
	if (qlevel == 0)
		return;

	switch (sc_state)
	{
		case SS_OUT:
			sc_state = SS_IN;
			handle_syscall_entry(sysnr, flags, qlevel == 1);
			break;
		case SS_IN:
			sc_state = SS_OUT;
			handle_syscall_exit(sysnr, flags);
			break;
        /* Reached when quirk_syscall() (presumably via change_syscall())
         * switched the state during this entry stop; the entry is then
         * recorded without permission checks. */
        case SS_CHANGED:
			handle_syscall_entry(sysnr, flags, 0);
			break;
		default:
			fail_tracker("state machine in unknown/invalid state");
	}
}

/*
 * SIGALRM handler (interval timer tick). Each tick is forwarded to the
 * tracked process as SIGVTALRM; tracker_process() samples memory and
 * exectime on the resulting stop.
 */
static void periodic_fn(int sig)
{
	(void)sig;
	kill(tracked, SIGVTALRM);
}

/* SIGTERM/SIGINT handler: only raises the flag that tracker_process() polls. */
static void termhandler(int sig)
{
	(void)sig;
	terminate = 1;
}

/*
 * SIGSEGV handler for the supervisor itself. It reports the crash and tears
 * down the tracked process via fail_tracker(), which exits with status 122.
 * It only runs if SIGSEGV is not blocked in the supervisor's signal mask.
 * NOTE(review): fail_tracker() uses fprintf() and exit(), which are not
 * async-signal-safe; this is a best-effort last-ditch path.
 */
static void segvhandler(int sig)
{
    fail_tracker("segmentation fault");
}

/*
 * Return a heap copy (caller frees) of environment variable name. If name is
 * unset, name2 (when non-NULL) is tried; if that is unset too, a copy of dflt
 * is returned.
 * When a value is found, name is removed from the environment, so the tracked
 * program does not inherit it; name2 is left untouched.
 */
static char* get_env_string(const char* name, const char* name2, const char* dflt)
{
	char* copy;
	const char* value = getenv(name);

	if (value == NULL && name2 != NULL)
		value = getenv(name2);
	if (value == NULL)
		return strdup(dflt);

	copy = strdup(value);
	unsetenv(name);
	return copy;
}

/*
 * Parse environment variable name (falling back to name2 when non-NULL) with
 * strtoll() in base 0, so decimal, 0x-hex and 0-octal forms are accepted.
 * Returns dflt when neither variable is set. When a value is found, name is
 * removed from the environment.
 */
static long long get_env_number(const char* name, const char* name2, long dflt)
{
	long long parsed;
	const char* text = getenv(name);

	if (!text && name2)
		text = getenv(name2);
	if (!text)
		return dflt;

	parsed = strtoll(text, NULL, 0);
	unsetenv(name);
	return parsed;
}

/*
 * Treat environment variable name (falling back to name2 when non-NULL) as a
 * flag: a set, non-empty value is true and a set but empty value is false.
 * Returns dflt when neither variable is set. When a value is found, name is
 * removed from the environment.
 */
static int get_env_bool(const char* name, const char* name2, int dflt)
{
	int enabled;
	const char* text = getenv(name);

	if (!text && name2)
		text = getenv(name2);
	if (!text)
		return dflt;

	enabled = (*text != '\0');
	unsetenv(name);
	return enabled;
}

/*
 * Body of the fork()ed child. It confines itself with resource limits,
 * silences stderr, requests tracing by the parent, stops, and then execs
 * the judged program.
 *
 * Limits applied before exec:
 *   RLIMIT_CORE  = 0               (no core dumps)
 *   RLIMIT_STACK = unlimited
 *   RLIMIT_AS    = mem_limit kB
 *   RLIMIT_CPU   = hard_time_limit ms, truncated to whole seconds (if non-zero)
 *
 * Setup failures exit through fail() with status 122; a failed execvp()
 * exits with 123.
 * NOTE(review): only the_argv[0] is passed to the program. Further
 * command-line parameters are dropped, although usage() mentions them.
 */
static void tracked_process(void)
{
	struct rlimit rl;
	int fd;
	char* params[2] = { the_argv[0], NULL };

	rl.rlim_cur = rl.rlim_max = 0;
	if (setrlimit(RLIMIT_CORE, &rl))
		fail("setrlimit(RLIMIT_CORE) failed");

	rl.rlim_cur = rl.rlim_max = RLIM_INFINITY;
	if (setrlimit(RLIMIT_STACK, &rl))
		fail("setrlimit(RLIMIT_STACK) failed");

	rl.rlim_cur = rl.rlim_max = mem_limit*1024;
	if (setrlimit(RLIMIT_AS, &rl))
		fail("setrlimit(RLIMIT_AS) failed");
	mem_limit = 0;

	if (hard_time_limit)
	{
		rl.rlim_cur = rl.rlim_max = hard_time_limit/1000;
		if (setrlimit(RLIMIT_CPU, &rl))
			fail("setrlimit(RLIMIT_CPU) failed");
	}

	signal(SIGALRM, SIG_IGN);

	/* The result descriptor belongs to the supervisor; do not leak it into
	 * the judged program. result_fd is -1 when -f was not given, and 0 is a
	 * valid descriptor, so test for >= 0 rather than for non-zero. */
	if (result_fd >= 0)
		close(result_fd);

	fd = open("/dev/null", O_WRONLY);
	if (fd < 0)
		fail("cannot open /dev/null");
	if (fd != 2)
	{
		if (dup2(fd, 2) < 0)
			fail("dup2() failed");
		close(fd);
	}

	/* -x: the program must use the extension interface instead of stdin/stdout. */
	if (force_extensions) {
		close(0);
		close(1);
	}

	/* If this failed, the SIGSTOP below would leave an untraced, stopped
	 * child that the parent's wait4() never reports; exit instead so the
	 * parent sees "process exited before attachment". */
	if (ptrace(PTRACE_TRACEME, 0, NULL, NULL) < 0)
		fail("ptrace(PTRACE_TRACEME) failed");
	raise(SIGSTOP);

	execvp(the_argv[0], params);
	exit(123);
}

/*
 * Main loop of the supervising parent; it never returns, because every
 * terminal path calls exit().
 *
 * It blocks every signal except SIGALRM, SIGTERM, SIGINT and SIGSEGV (plus
 * the unblockable SIGKILL/SIGSTOP) and installs handlers for those four.
 * Then it services ptrace stops of the tracked process:
 *   - first SIGSTOP : attachment; start the 10 ms ITIMER_REAL sampling timer
 *   - SIGTRAP       : syscall stop, handed to handle_syscall()
 *   - SIGVTALRM     : sampling tick forwarded by periodic_fn(); sample memory
 *                     and exectime and enforce ET_LIMIT (signal suppressed)
 *   - other signals : re-injected into the tracked process
 * On exit or death of the tracked process, it reports the verdict through
 * report_result_chklim().
 */
static void tracker_process(void)
{
	struct itimerval itv;
	sigset_t ss;

	sigfillset(&ss);
	sigdelset(&ss, SIGKILL);
	sigdelset(&ss, SIGSTOP);
	sigdelset(&ss, SIGALRM);
	sigdelset(&ss, SIGTERM);
	/* SIGINT and SIGSEGV must stay deliverable too, because a handler is
	 * installed for each below. A blocked SIGINT would only stay pending, and
	 * POSIX leaves a SIGSEGV generated while blocked undefined (Linux kills
	 * the process without running the handler). */
	sigdelset(&ss, SIGINT);
	sigdelset(&ss, SIGSEGV);
	sigprocmask(SIG_SETMASK, &ss, NULL);
	signal(SIGALRM, periodic_fn);
	signal(SIGTERM, termhandler);
	signal(SIGINT, termhandler);
	signal(SIGSEGV, segvhandler);

	for (;;)
	{
		int status;
		pid_t p = wait4(tracked, &status, 0, &res_usage);

		/* Checked before the EINTR retry, so a termination request that
		 * interrupted wait4() is honoured immediately. */
		if (terminate)
		{
			report_result(RETVAL_RE, "terminated on request");
			kill_and_exit();
		}

		if (p < 0)
		{
			if (errno == EINTR)
				continue;
			fail_tracker("waitpid unsuccessful");
		}

		if (WIFEXITED(status))
		{
			int code = WEXITSTATUS(status);

			if (!tracked_attached)
				fail_tracker("process exited before attachment");

			if (code == 0)
				report_result_chklim(RETVAL_OK, "ok");
			else
			{
				char msg[256];
				snprintf(msg, sizeof(msg), "runtime error %d", code);
				report_result_chklim(RETVAL_RE_BASE + code, msg);
			}

			/* Pass the decoded exit code. On Linux the raw wait status keeps
			 * it in bits 8-15, so the former exit(status) was truncated to 0. */
			exit(pass_retval ? code : 0);
		}
		else if (WIFSIGNALED(status))
		{
			int sig = WTERMSIG(status);

			if (execve_called <= 2 && sig == SIGKILL)
			{
				snprintf(buf, sizeof(buf), "process killed before execve(), OOM?");
				report_result_chklim(RETVAL_MLE, buf);
				exit(pass_retval ? 1 : 0);
			}
			snprintf(buf, sizeof(buf), "process exited due to signal %d", sig);
			report_result_chklim(RETVAL_SIG_BASE + sig, buf);
			/* NOTE(review): shells use 128+sig for signal deaths; 127+sig is kept as-is. */
			exit(pass_retval ? 127 + sig : 0);
		}

		if (!WIFSTOPPED(status))
			fail_tracker("no events after wait4() returned, strange ...");

		if (WSTOPSIG(status) == SIGSTOP)
		{
			/* Only the first SIGSTOP is expected. In the fork case it is
			 * raised by tracked_process() just before exec. */
			if (tracked_attached)
				fail_tracker("tracked process unexpectedly stopped, killing");

			itv.it_value.tv_sec = 0;
			itv.it_value.tv_usec = 10000;
			itv.it_interval.tv_sec = 0;
			itv.it_interval.tv_usec = 10000;
			if (setitimer(ITIMER_REAL, &itv, NULL) < 0)
				fail_tracker("cannot set interval timer");

			tracked_attached = 1;
			track_memory();
		}
		else if (WSTOPSIG(status) == SIGTRAP)
			handle_syscall();
		else if (WSTOPSIG(status) == SIGVTALRM)
		{
			track_memory();
			track_exectime();
			if (et_limit && exectime > et_limit)
			{
				report_result(RETVAL_TLE, "exectime limit exceeded");
				kill_and_exit();
			}
		}
		else
		{
			if (!quiet)
				fprintf(stderr, "tracked process received signal %d\n",
						WSTOPSIG(status));
			track_memory();
			/* Re-inject the signal. ptrace's data argument is pointer-sized,
			 * so convert explicitly instead of passing a bare int through
			 * the variadic prototype. */
			ptrace(PTRACE_SYSCALL, tracked, NULL, (void*)(long)WSTOPSIG(status));
			continue;
		}

		ptrace(PTRACE_SYSCALL, tracked, NULL, NULL);
	}
}

/* Print command-line help to stderr and exit with status 0. */
static void usage(void)
{
	static const char* const help_lines[] = {
		"Usage: supervisor [options] program [parameters]\n\n",
		"Options:\n",
		"  -e      enable ioshm extensions\n",
		"  -x      force extensions, disable regular i/o\n",
		"  -f fd   write result to the specified file descriptor\n",
		"  -w pid  wait for the specified process and attach to it\n",
		"  -r      pass return value from the program\n",
		"  -q      be quiet\n",
		"  -h      print usage information\n",
	};
	size_t i;

	for (i = 0; i < sizeof(help_lines) / sizeof(help_lines[0]); i++)
		fputs(help_lines[i], stderr);
	exit(0);
}

/*
 * Entry point: parse options, read limits from the environment, then either
 * fork and trace a new program, or (with -w) track an existing pid.
 *
 * Options are consumed until the first unrecognised argument. That argument
 * is the program to run and becomes the_argv[0]. The program is required
 * unless -w was given.
 */
int main(int argc, char** argv)
{
	const char* progname = NULL;
	int no_fork = 0;

	resultf = stderr;

	argv++;
	argc--;
	while (argc > 0)
	{
		if (!strcmp(*argv, "-q"))
			quiet = 1;
		else if (!strcmp(*argv, "-x"))
			enable_extensions = force_extensions = 1;
		else if (!strcmp(*argv, "-e"))
			enable_extensions = 1;
		else if (!strcmp(*argv, "-f"))
		{
			/* Without this check, *++argv would be the NULL argv terminator. */
			if (argc < 2)
				usage();
			result_fd = atoi(*++argv);
			argc--;
			resultf = fdopen(result_fd, "w");
			if (!resultf)
				fail("cannot open result fd");
		}
		else if (!strcmp(*argv, "-w"))
		{
			if (argc < 2)
				usage();
			tracked = (pid_t)atoi(*++argv);
			argc--;
			no_fork = 1;
			execve_called = 3;
			progname = "<auto>";
		}
		else if (!strcmp(*argv, "-r"))
			pass_retval = 1;
		else if (!strcmp(*argv, "-h"))
			usage();
		else
			break;
		argv++;
		argc--;
	}

	the_argc = argc;
	the_argv = argv;

	if (!quiet)
	{
		fprintf(stderr, "SJudge Supervisor version " VERSION_STRING "\n");
		fprintf(stderr, "Copyright 2004-2006 (C) Szymon Acedanski\n\n");
	}

	/* With -w and no program, keep "<auto>" instead of taking the NULL argv
	 * terminator as the program name (it is printed with %s below). */
	if (argc >= 1)
		progname = argv[0];
	else if (!no_fork)
		usage();

	mem_limit = get_env_number("MEM_LIMIT", "MEM", 16000) & ~3;
	time_limit = get_env_number("TIME_LIMIT", "TIME", 10000);
	hard_time_limit = get_env_number("HARD_LIMIT", "HARD", time_limit + 2000);
	out_limit = get_env_number("OUT_LIMIT", NULL, 10000000);
	et_limit = get_env_number("ET_LIMIT", NULL, 0);
	count_systime = get_env_bool("COUNT_SYSTIME", NULL, 0);
	allow_all = get_env_bool("ALLOW", NULL, 0);

	/* An exectime limit replaces the CPU-time based hard limit. */
	if (et_limit)
		hard_time_limit = 0;

	if (!quiet)
	{
		fprintf(stderr, "Program:        %s\n", progname);
		fprintf(stderr, "Memory limit:   %ld kB\n", mem_limit);
		if (et_limit)
			fprintf(stderr, "ExecTime limit: %llu Mops\n", et_limit/1000000ULL);
		else
			fprintf(stderr, "Time limit:     %ld msecs\n", time_limit);
		fprintf(stderr, "Output limit:   %ld kB\n", out_limit/1024);
		fprintf(stderr, "System time:    %scounted\n\n",
				count_systime?"":"not ");
		fprintf(stderr, "Syscall filter: %sabled\n\n", allow_all?"dis":"en");
	}

	nice(10);

	if (enable_extensions)
		init_ioshm();

	if (!no_fork) {
		tracked = fork();
		if (tracked < 0)
			fail("cannot fork");
		if (tracked == 0)
		{
			tracked_process();
			return 0;
		}
	}

	tracker_process();
	return 0;
}
