/*
 * sthreads_sol.c -- interface between Toba's thread model and Solaris threads
 *
 * defines
 *    mythread
 *    sthread_init
 *    sthread_launch
 *    sthread_got_exception
 *    sthread_create
 *    sthread_stop
 *    sthread_yield
 *    sthread_sleep
 *    sthread_suspend
 *    sthread_resume
 *    sthread_setpriority
 *    sthread_current
 *    sthread_exit
 *    sthread_dontkill_start
 *    sthread_dontkill_end
 *    sthread_dontkill_p
 *    sthread_mutex_init
 *    sthread_mutex_lock
 *    sthread_mutex_unlock
 *    sthread_mutex_destroy
 *    sthread_cond_init
 *    sthread_cond_wait
 *    sthread_cond_signal
 *    sthread_cond_broadcast
 *    sthread_cond_destroy
 * 
 * TO DO:
 *   - There is some code in here that is not dependent
 *     on the underlying threads package - it should probably
 *     be broken out
 *
 * NOTE: We use the (otherwise unused) threadQ field of
 *       a java.lang.Thread to point to a thread's TCB
 */

#include <stdio.h>
#include <stddef.h>
#include <time.h>
#include <assert.h>
#include <stdlib.h>
#include <errno.h>
#include <signal.h>

#include "toba.h"
#include "../gc/gc.h"
#include "sthreads.h"
#include "runtime.h"

#include "java_lang_Object.h"
#include "java_lang_Thread.h"
#include "java_lang_ThreadGroup.h"
#include "java_lang_IllegalThreadStateException.h"

/* Signals used to deliver suspend requests (sthread_suspend) and
 * asynchronous exceptions (sthread_stop) to a target thread; both are
 * handled by signal_handler(). */
#define SUSPEND_SIG		SIGUSR1
#define ASYNC_EXCEPTION_SIG	SIGUSR2

/* global key for the per-thread TCB pointer (read by sthread_current()) */
static thread_key_t perthread;

/* Counter of non-daemon threads and semaphore for signalling the last
 * thread: counter_lock guards thread_counter; last_thread is posted by
 * rem_nondaemon() when the count drops to zero and is waited on by
 * sthread_exit(). */
static sema_t counter_lock;
static sema_t last_thread;
static int thread_counter;

/* set to 1 at the end of sthread_init() */
static int sthreads_initialized = 0;


/* Single-threaded flag and the single thread's TCB: while singlethreaded
 * is non-zero, sthread_current() returns onlythread without a key lookup.
 * sthread_create() clears the flag when the first new thread is started.
 * NOTE(review): singlethreaded is not static, so other parts of the
 * runtime presumably read it - confirm before changing its meaning. */
int singlethreaded = 1;
static tcb *onlythread;

/*
 * Handler for SUSPEND_SIG and ASYNC_EXCEPTION_SIG (installed by
 * sthread_init() and thread_startup()).
 *
 * SUSPEND_SIG: inside a "don't kill" region (thr->defer non-zero) the
 * request is only recorded in defer_suspend, and sthread_dontkill_end()
 * re-sends the signal later.  Otherwise the thread marks itself
 * TS_SUSPENDED, posts synch so the thread blocked in sthread_suspend()
 * can continue, and then blocks on its suspend semaphore until
 * sthread_resume() posts it.
 *
 * ASYNC_EXCEPTION_SIG: throws thr->mythread.exception (stored by
 * sthread_stop()) with athrow() from inside the handler, unless the
 * thread is in a "don't kill" region (then defer_async_exception is set
 * and sthread_dontkill_end() re-sends the signal) or an async exception
 * was already delivered (then the signal is ignored).  thr->sig is
 * recorded so sthread_got_exception() can restore the signal mask after
 * the exception is caught - presumably athrow() leaves the handler by a
 * longjmp, so the signal stays blocked until then; TODO confirm.
 *
 * NOTE(review): sema_wait() is called from signal context here; confirm
 * it is safe to do so on the target Solaris release.
 */
static void
signal_handler(int sig)
/*ARGSUSED*/
{
    tcb *thr = sthread_current();

    /* NOTE(review): with NDEBUG this check disappears, and a signal that
     * arrives before the thread's TCB is registered would dereference
     * NULL below. */
    assert(thr);
    if(sig == SUSPEND_SIG) {
        if(thr->defer) {
            /* remember the request; sthread_dontkill_end() re-sends it */
            thr->defer_suspend = 1;
        } else {
            thr->defer_suspend = 0;
            sema_wait(&thr->lock);
            thr->state = TS_SUSPENDED;
            sema_post(&thr->lock);
            /* let the suspending thread continue */
            sema_post(&thr->synch);
            /* block here until sthread_resume() posts suspend */
            sema_wait(&thr->suspend);
        }
    }

    if(sig == ASYNC_EXCEPTION_SIG) {
        if(thr->got_async_exception) {
            /* ignore all but the first async exception */
            thr->defer_async_exception = 0;
            return;
        }

        /* record sig - sigmask will have to be cleared */
        thr->sig = sig;
        if(thr->defer) {
            /* remember the request; sthread_dontkill_end() re-sends it */
            thr->defer_async_exception = 1;
        } else {
            thr->got_async_exception = 1;
            thr->defer_async_exception = 0;
            athrow(thr->mythread.exception);
        }
    }
}

/*
 * Notification that an exception was caught by the current thread.
 *
 * If the exception was raised from inside signal_handler() (thr->sig was
 * recorded there), the thread left the handler with that signal still in
 * its blocked set; unblock it and forget the record.
 */
void
sthread_got_exception(void)
{
    tcb *self = sthread_current();
    sigset_t mask;

    if (self->sig == 0)
        return;     /* exception was not thrown from a signal handler */

    sigemptyset(&mask);
    sigaddset(&mask, self->sig);
    self->sig = 0;
    sigprocmask(SIG_UNBLOCK, &mask, 0);
}

/*
 * Initialization of thread system.
 *
 * Installs the suspend/async-exception signal handlers, creates the key
 * used to find a thread's TCB, builds a TCB for the calling thread
 * together with a java.lang.Thread object named "main" in a new
 * ThreadGroup, and sets up the non-daemon thread counter.  The runtime
 * starts out single-threaded.
 *
 * The thread-specific key calls are made outside assert(): inside it they
 * would be compiled away under NDEBUG, leaving the key uncreated.  Nothing
 * in this file works without the key, so failure aborts the process.
 */
void
sthread_init(void) 
{
    tcb *newtcb;
    struct in_java_lang_Thread *obj;
    struct in_java_lang_ThreadGroup *sysgroup;
    struct sigaction act;

    act.sa_handler = signal_handler;
    act.sa_flags = SA_RESTART;
    sigemptyset(&act.sa_mask);
    sigaction(ASYNC_EXCEPTION_SIG, &act, 0);
    sigaction(SUSPEND_SIG, &act, 0);

    /* Start up singlethreaded */
    singlethreaded = 1;

    /* set up a key for looking up TCBs */
    if (thr_keycreate(&perthread, 0) != 0) {
        fprintf(stderr, "sthread_init: thr_keycreate failed\n");
        abort();
    }

    /* create a java.lang.Thread object and a thread group
     * for this thread */
    newtcb = (tcb *)allocate(sizeof(*newtcb));

    /* associate the TCB with this thread */
    if (thr_setspecific(perthread, newtcb) != 0) {
        fprintf(stderr, "sthread_init: thr_setspecific failed\n");
        abort();
    }
    onlythread = newtcb;

    newtcb->thread = thr_self();
    /* defer asynchronous actions until we're ready for them */
    newtcb->defer = 1;
    sema_init(&newtcb->synch, 0, USYNC_THREAD, 0);
    sema_init(&newtcb->suspend, 0, USYNC_THREAD, 0);
    sema_init(&newtcb->lock, 1, USYNC_THREAD, 0);
    newtcb->state = TS_RUNNING;

    /* new ThreadGroup() */
    sysgroup = new(&cl_java_lang_ThreadGroup.C);
    init__UzqJE(sysgroup);

    /* new Thread(group, runnable, string) */
    obj = new(&cl_java_lang_Thread.C);
    newtcb->obj = obj;
    init_TRS_jtRvs(obj, sysgroup, obj, javastring("main"));

    /* set up non-daemon thread counter and locks */
    thread_counter = 0;
    sema_init(&counter_lock, 1, USYNC_THREAD, 0);
    sema_init(&last_thread, 0, USYNC_THREAD, 0);

    /* the thread is still not killable yet - 
     * the lock is held until default exception
     * handling is set up in start_thread() */

    sthreads_initialized = 1;
}

/*
 * Run routine(arg) on the calling thread, then wait in sthread_exit() until
 * every non-daemon thread has finished before returning.
 *
 * This routine runs before sthread_init(); presumably routine itself
 * calls sthread_init() - confirm against callers.
 */
void
sthread_launch(void (*routine)(void *arg), void *arg)
{
    routine(arg);       /* dispatch to the named routine */
    sthread_exit();     /* returns once the non-daemon count reaches zero */
}

/* Count one more live non-daemon thread (thread_counter is guarded by
 * counter_lock). */
static void
add_nondaemon(void)
{
    sema_wait(&counter_lock);
    ++thread_counter;
    sema_post(&counter_lock);
}

/*
 * Count one fewer live non-daemon thread.  When the count drops to zero,
 * post last_thread so a thread waiting in sthread_exit() can re-check it.
 */
static void
rem_nondaemon(void)
{
    int remaining;

    sema_wait(&counter_lock);
    remaining = --thread_counter;
    sema_post(&counter_lock);

    if (remaining == 0)
        sema_post(&last_thread);
}

/*
 * Start routine for every thread created by sthread_create().
 *
 * obj is the java.lang.Thread being started.  Its threadQ field already
 * points at the TCB that sthread_create() allocated, and the parent is
 * blocked on that TCB's synch semaphore until this thread posts it.
 *
 * After the Thread's run() method returns (via start_thread()), the TCB is
 * detached from the object, threads waiting on the object's monitor
 * (joiners) are notified, and a non-daemon thread decrements the
 * non-daemon count.  Always returns 0.
 */
static void *
thread_startup(void *obj)
{
    struct in_java_lang_Thread *o = obj;    /* void * converts implicitly */
    tcb *curthread;
    struct sigaction act;
    int monitor_held;
    int registered;

    /* setup handlers for asynchronous events */
    act.sa_handler = signal_handler;
    act.sa_flags = SA_RESTART;
    sigemptyset(&act.sa_mask);
    sigaction(ASYNC_EXCEPTION_SIG, &act, 0);
    sigaction(SUSPEND_SIG, &act, 0);

    curthread = (tcb *)o->threadQ;

    /*
     * Associate our TCB with this thread BEFORE letting the parent
     * continue: signal_handler() finds the TCB through sthread_current(),
     * so a suspend/stop signal arriving earlier would find none.
     * Set the priority via thr_self(): curthread->thread is written by
     * the parent's thr_create(), which need not have stored it before
     * this thread starts running.
     */
    registered = (thr_setspecific(perthread, curthread) == 0);
    thr_setprio(thr_self(), o->priority);

    /* let parent continue - we're done our critical stuff */
    sema_post(&curthread->synch);

    if (registered) {
        /* start running in the java.lang.Thread start() method */
        start_thread(o->class->M.run__QJ0S5.f, o);
    } else {
        /* should never happen - report it and skip run() */
        fprintf(stderr, "thr_setspecific failed!\n");
    }

    /* break association of object and tcb - mark thread as not being active */
    o->threadQ = 0;

    /* send notification up to anyone doing a join on us */
    sthread_dontkill_start(&curthread->mythread);
    monitorenter(o, &curthread->mythread, 1, &monitor_held);
    monitornotifyall(o);
    monitorexit(o, &curthread->mythread, 0, &monitor_held);
    sthread_dontkill_end(&curthread->mythread);

    /* last non-daemon thread signals the main thread to exit */
    if(o->daemon == JAVA_FALSE)
        rem_nondaemon();

    return 0;
}

/*
 * Create a new thread associated with this java.lang.Thread object
 * and start it running in its start method.
 *
 * Throws IllegalThreadStateException if o already has a TCB (threadQ is
 * set).  The first call switches the runtime out of single-threaded mode.
 * On success this returns once the child has posted the TCB's synch
 * semaphore in thread_startup().  If thr_create() fails, the bookkeeping
 * done here is undone and the call returns without starting a thread; the
 * failure is not reported to the caller.
 */
void
sthread_create(struct in_java_lang_Thread *o)
{
    tcb *newtcb;

    /* dont start a new thread if there's already a running thread */
    if(o->threadQ) 
        throwMesg(&cl_java_lang_IllegalThreadStateException.C,
                  "thread already running");

    /* we've just gone multithreaded boys */
    if (singlethreaded) {
        singlethreaded = 0;
        fixup_monitors();
    }

    newtcb = (tcb *)allocate(sizeof(*newtcb));
    newtcb->obj = o;
    sema_init(&newtcb->synch, 0, USYNC_THREAD, 0);
    sema_init(&newtcb->suspend, 0, USYNC_THREAD, 0);
    sema_init(&newtcb->lock, 1, USYNC_THREAD, 0);
    newtcb->state = TS_RUNNING;

    /* defer asynchronous actions until we're ready for them */
    newtcb->defer = 1;
    o->threadQ = newtcb;    /* threadQ points to TCB */

    if(o->daemon == JAVA_FALSE)
        add_nondaemon();

    if(thr_create(0, 0, thread_startup, o, 0, &newtcb->thread) != 0) {
        /* Thread creation failed: no child will ever post synch, so do
         * not wait for it (that would block forever), and detach the TCB
         * so the object is not seen as a running thread. */
        o->threadQ = 0;
        if(o->daemon == JAVA_FALSE)
            rem_nondaemon();
        return;
    }

    /* wait here until child does critical stuff */
    sema_wait(&newtcb->synch);
}

/*
 * Retrieve per-thread state
 */
tcb *
sthread_current()
{
    tcb *ret;

    if (singlethreaded) {
	return onlythread;
    } else {
        thr_getspecific(perthread, (void **)&ret);
        /* should only be null during thread initialization */
        /* assert(ret || !sthreads_initialized); */
        return ret;
    }
}

struct mythread *
mythread()
{
    return &sthread_current()->mythread;
}

/* Offer the processor to another runnable thread (Solaris thr_yield()). */
void
sthread_yield()
{
    thr_yield();    /* no state of our own to update */
}

/*
 * Sleep for millis milliseconds; returns at once if millis <= 0.
 *
 * select() returns early with EINTR whenever a signal handler runs on this
 * thread (for example a SUSPEND_SIG suspend/resume cycle).  In that case
 * the sleep is resumed for whatever remains until the original deadline,
 * so the thread never wakes early because of a handled signal.
 */
void
sthread_sleep(Long millis)
{
    struct timeval tv, now, deadline;

    if (millis <= 0)
        return;

    tv.tv_sec = millis / 1000;
    tv.tv_usec = (millis % 1000) * 1000;

    /* absolute wake-up time, normalized so tv_usec < 1000000 */
    gettimeofday(&deadline, NULL);
    deadline.tv_sec += tv.tv_sec;
    deadline.tv_usec += tv.tv_usec;
    if (deadline.tv_usec >= 1000000) {
        deadline.tv_sec++;
        deadline.tv_usec -= 1000000;
    }

    while (select(0, 0, 0, 0, &tv) < 0 && errno == EINTR) {
        gettimeofday(&now, NULL);
        if (now.tv_sec > deadline.tv_sec ||
            (now.tv_sec == deadline.tv_sec &&
             now.tv_usec >= deadline.tv_usec))
            break;              /* deadline already passed */

        /* sleep again for the remaining time only */
        tv.tv_sec = deadline.tv_sec - now.tv_sec;
        tv.tv_usec = deadline.tv_usec - now.tv_usec;
        if (tv.tv_usec < 0) {
            tv.tv_sec--;
            tv.tv_usec += 1000000;
        }
    }
}

/*
 * Suspend the thread associated with o; a no-op if o has no TCB.
 *
 * Sends SUSPEND_SIG and then blocks until the target's signal_handler()
 * posts synch to acknowledge.  If the target is inside a "don't kill"
 * region, the acknowledgement comes only after sthread_dontkill_end()
 * re-sends the signal.  If thr_kill() fails (for example because the
 * target exited after threadQ was read), nobody will ever post synch, so
 * we must not wait for it.
 */
void
sthread_suspend(struct in_java_lang_Thread *o)
{
    tcb *thread = (tcb *)o->threadQ;

    if (thread == NULL)
        return;

    if (thr_kill(thread->thread, SUSPEND_SIG) == 0) {
        /* wait till he actually handles it to continue */
        sema_wait(&thread->synch);
    }
}

/*
 * Resume the thread associated with o if it is currently TS_SUSPENDED,
 * releasing it from the sema_wait(&suspend) in signal_handler().  State
 * checks and updates happen under the target's lock semaphore.
 */
void
sthread_resume(struct in_java_lang_Thread *o)
{
    tcb *target = (tcb *)o->threadQ;

    if (!target)
        return;     /* not running: nothing to resume */

    sema_wait(&target->lock);
    if (target->state == TS_SUSPENDED) {
        target->state = TS_RUNNING;
        sema_post(&target->suspend);
    }
    sema_post(&target->lock);
}

/*
 * Record the new priority in the Thread object and, if the thread is
 * running (has a TCB), apply it to the underlying Solaris thread too.
 */
void
sthread_setpriority(struct in_java_lang_Thread *o, int priority)
{
    tcb *target = (tcb *)o->threadQ;

    o->priority = priority;
    if (target == NULL)
        return;
    thr_setprio(target->thread, priority);
}

/*
 * Ask the thread associated with o to throw e asynchronously: the
 * exception is stored in its TCB and ASYNC_EXCEPTION_SIG is sent, which
 * signal_handler() turns into an athrow() (or defers).  No-op if o has no
 * TCB.
 */
void
sthread_stop(struct in_java_lang_Thread *o, Object e)
{
    tcb *victim = (tcb *)o->threadQ;

    if (victim == NULL)
        return;

    /* Concurrent stop() calls may overwrite each other's exception; which
     * one wins is nondeterministic anyway, so no locking is done. */
    victim->mythread.exception = e;
    thr_kill(victim->thread, ASYNC_EXCEPTION_SIG);
}

/*
 * wait for all non-daemon threads to exit
 */
void
sthread_exit()
{
    while(thread_counter != 0) {
        sema_wait(&last_thread);
    }
}

/*
 * Map a struct mythread pointer back to the TCB that embeds it.
 * Returns NULL for a NULL argument.  It also yields NULL, in practice,
 * when thr was taken as &tcb->mythread of a NULL tcb (e.g. via mythread()
 * while sthread_current() is NULL), because subtracting the member offset
 * cancels out.
 */
static tcb *
tcb_of_mythread(struct mythread *thr)
{
    if (thr == NULL)
        return NULL;
    return (tcb *)((char *)thr - offsetof(struct tcb, mythread));
}

/*
 * Start a sequence in which a thread shouldn't be killed.  While
 * tcb->defer is non-zero, signal_handler() only records suspend and
 * async-exception requests instead of acting on them.  Calls nest.
 */
void
sthread_dontkill_start(struct mythread *thr)
{
    tcb *curthread = tcb_of_mythread(thr);

    /* if we have no tcb there's no chance we could be killed */
    if (curthread)
        curthread->defer ++;
}

/*
 * End a sequence in which a thread shouldn't be killed.  When the
 * outermost sequence ends, re-send any suspend / async-exception signal
 * that signal_handler() deferred in the meantime.
 */
void
sthread_dontkill_end(struct mythread *thr)
{
    tcb *curthread = tcb_of_mythread(thr);

    if (curthread == NULL)
        return;

    curthread->defer --;
    if (curthread->defer != 0)
        return;     /* still inside an enclosing sequence */

    /* re-send deferred signals for deferred actions */
    if (curthread->defer_suspend)
        thr_kill(curthread->thread, SUSPEND_SIG);
    if (curthread->defer_async_exception)
        thr_kill(curthread->thread, ASYNC_EXCEPTION_SIG);
}

/*
 * Return the current "don't kill" nesting depth of the thread (non-zero
 * means asynchronous actions are being deferred).  A thread without a TCB
 * is reported as 0, matching the NULL handling of the functions above.
 */
int
sthread_dontkill_p(struct mythread *thr)
{
    tcb *curthread = tcb_of_mythread(thr);

    return curthread ? curthread->defer : 0;
}

/*
 * sthread_mutex_* -- thin wrappers around the Solaris mutex_t embedded in
 * struct sthread_mutex.  Return codes from the Solaris calls are
 * discarded, as before.
 */

/* Initialize m's mutex with USYNC_THREAD (threads of this process). */
void
sthread_mutex_init(struct sthread_mutex *m)
{
    (void)mutex_init(&m->mutex, USYNC_THREAD, 0);
}

/* Acquire m, blocking until it is available. */
void
sthread_mutex_lock(struct sthread_mutex *m)
{
    (void)mutex_lock(&m->mutex);
}

/* Release m. */
void
sthread_mutex_unlock(struct sthread_mutex *m)
{
    (void)mutex_unlock(&m->mutex);
}

/* Release any resources held by m's mutex. */
void
sthread_mutex_destroy(struct sthread_mutex *m)
{
    (void)mutex_destroy(&m->mutex);
}

/*
 * Initialize c: a USYNC_THREAD Solaris condition variable plus the
 * bookkeeping counters (no waiters, no pending signals) that
 * sthread_cond_wait/_signal/_broadcast rely on.
 */
void
sthread_cond_init(struct sthread_cond *c)
{
    c->nwaiting = 0;
    c->nsignals = 0;
    cond_init(&c->cond, USYNC_THREAD, 0);
}

/*
 * Wait on c; the caller holds m.  timeout is in milliseconds, and
 * timeout <= 0 means wait without a time limit.
 *
 * We can't just use the bare conditions that Solaris or Pthreads
 * provide, because they don't guarantee that only one thread leaves a
 * signaled wait.  sthread_cond_signal()/_broadcast() therefore count
 * pending signals in c->nsignals, and a waiter leaves only after
 * consuming one - or, for a timed wait, once the wait times out or fails.
 * Spurious wake-ups and EINTR (e.g. from SUSPEND_SIG) go back to waiting.
 * A waiter that wakes for any reason while a signal is pending consumes
 * it, so nsignals never exceeds nwaiting and no notification is lost.
 */
void
sthread_cond_wait(struct sthread_cond *c, struct sthread_mutex *m, 
		  Long timeout)
{
    struct timeval tv;
    timestruc_t waittime;
    Long timeoutsec;
    int retval;

    if (timeout > 0) {
        gettimeofday(&tv, NULL);
        timeoutsec = timeout / 1000;
        waittime.tv_sec = tv.tv_sec + timeoutsec;
        waittime.tv_nsec = tv.tv_usec * 1000L +
            (long)(timeout - 1000 * timeoutsec) * 1000000L;
        /* tv_nsec must stay below one second, otherwise the timed wait
         * rejects the deadline instead of waiting */
        if (waittime.tv_nsec >= 1000000000L) {
            waittime.tv_sec++;
            waittime.tv_nsec -= 1000000000L;
        }
    }

    /* We hold m, so we can access fields of c atomically */
    c->nwaiting++;
    for (;;) {
        if (timeout > 0)
            retval = cond_timedwait(&c->cond, &m->mutex, &waittime);
        else
            retval = cond_wait(&c->cond, &m->mutex);

        if (c->nsignals > 0) {
            /* a signal is pending: take it and leave */
            c->nsignals--;
            break;
        }
        if (retval != 0 && retval != EINTR)
            break;      /* timed out or failed, with no signal for us */
        /* spurious wake-up or EINTR: keep waiting */
    }
    c->nwaiting--;
}

/*
 * Release exactly one waiter on c, unless every current waiter is already
 * covered by a pending signal.  The caller holds the associated mutex,
 * which makes the plain reads/writes of nsignals and nwaiting safe.
 */
void
sthread_cond_signal(struct sthread_cond *c) 
{
    if (c->nsignals >= c->nwaiting)
        return;     /* no waiter left without a pending signal */

    c->nsignals++;
    cond_signal(&c->cond);
}

/*
 * Release every current waiter on c by giving each one a pending signal.
 * Presumably, as with sthread_cond_signal(), the caller holds the
 * associated mutex.
 */
void
sthread_cond_broadcast(struct sthread_cond *c)
{
    if (c->nsignals >= c->nwaiting)
        return;     /* every waiter already has a pending signal */

    c->nsignals = c->nwaiting;
    cond_broadcast(&c->cond);
}

/* Release the Solaris condition variable embedded in c; the result of
 * cond_destroy() is discarded. */
void
sthread_cond_destroy(struct sthread_cond *c)
{
    (void)cond_destroy(&c->cond);
}
