#include "meta_task_scheduler.h"

// enables some sanity checks
//#define SANITY

// Builds an idle scheduler: no task has been inserted yet (newest_task == -1),
// work stealing is off, and both "all inserted" / "all finished" flags start
// out true. Both mutexes are created with default attributes and a single
// ready queue (index 0) is created.
meta_task_scheduler::meta_task_scheduler()
    : newest_task(-1),
      work_stealing(false),
      all_tasks_inserted(true),
      all_tasks_finished(true)
{
    pthread_mutex_init(&task_hash_mutex, nullptr);
    pthread_mutex_init(&ready_queue_mutex, nullptr);

    // queue 0 is the one getMetaTask() consults for every caller
    ready_queues.emplace_back();
}

// Releases the two pthread mutexes created by the constructor.
meta_task_scheduler::~meta_task_scheduler()
{
    pthread_mutex_destroy(&ready_queue_mutex);
    pthread_mutex_destroy(&task_hash_mutex);
}

// Counts how many input dependencies of `task` are already satisfied and
// stores that count in task.fulfilled.
//
// A dependency counts as satisfied when its id is absent from task_hash:
// all tasks < newest_task are either finished or still present in the hash
// (tasks are inserted in order and not removed until they are finished).
//
// Returns true when every entry of task.in_dep is satisfied.
// This function reads task_hash without taking task_hash_mutex itself.
bool meta_task_scheduler::ready(meta_task &task)
{
    int satisfied = 0;

    for (const auto &dep : task.in_dep)
    {
#ifdef SANITY
        if (dep > newest_task) std::cout << "error: sanity check failed dependency " << dep << " is greater than newest_task " << newest_task << std::endl;
#endif
        const bool still_pending = (task_hash.find(dep) != task_hash.end());
        if (!still_pending)
        {
            ++satisfied;
        }
    }

    task.fulfilled.store(satisfied, std::memory_order_relaxed);

    return satisfied == task.in_dep.size();
}

// Registers a new meta task with the scheduler.
//
// While holding task_hash_mutex the task's readiness is evaluated with
// ready(), the task is added to the task hash and its id is recorded as
// newest_task. Once the mutex has been released, a task that was found ready
// is handed to insertReadyQueue().
void meta_task_scheduler::insertMetaTask(meta_task_ptr task)
{
    /* This code can be used to analyse the inserted task DAG
    std::cout << "id:"    << task->id
              << " type:" << task->type
              << " x:"    << task->x
              << " y:"    << task->y
              << " w:"    << task->meta_x
              << " h:"    << task->meta_y
              << " pri:"  << task->pri
              << " node:" << task->preferred_node
              << "  in:";
    //for (int i = 0; i < task->in_dep.size(); i++) std::cout << task->in_dep.at(i) << " ";
    for (auto dep : task->in_dep) std::cout << dep << " ";
    std::cout << "  out:";
    //for (int i = 0; i < task->out_dep.size(); i++) std::cout << task->out_dep.at(i) << " ";
    for (auto dep : task->out_dep) std::cout << dep << " ";
    std::cout << std::endl;
    //*/

    pthread_mutex_lock(&task_hash_mutex);
    // ready() is evaluated first, then the task itself enters the hash
    const bool runnable = ready(*task);
    insertTaskHash(task);
    newest_task = task->id;
    pthread_mutex_unlock(&task_hash_mutex);

    if (!runnable) return;

    insertReadyQueue(task);
}


// Fetches the next task for the caller identified by `id`.
//
// Selection, all done under ready_queue_mutex:
//   1. Queue 0 is taken as the initial candidate whenever it is non-empty.
//   2. Queue id+1 replaces it only if its top task has a strictly smaller pri.
//   3. If neither yielded a task and work_stealing is enabled, every other
//      queue with index >= 1 is scanned and the one whose top task has the
//      smallest pri (strictly below INT_MAX) is chosen.
//
// On success the chosen task is popped, written to *meta and true is returned;
// otherwise *meta is left untouched and false is returned.
bool meta_task_scheduler::getMetaTask(int id, meta_task_ptr* meta)
{
    const int own_index = id + 1;
    int chosen   = -1;
    int best_pri = std::numeric_limits<int>::max();

    pthread_mutex_lock(&ready_queue_mutex);

    auto &shared_queue = ready_queues.at(0);
    if (!shared_queue.empty())
    {
        chosen   = 0;
        best_pri = shared_queue.top()->pri;
    }

    auto &own_queue = ready_queues.at(own_index);
    if (!own_queue.empty() && own_queue.top()->pri < best_pri)
    {
        chosen   = own_index;
        best_pri = own_queue.top()->pri;
    }

    // only look at other callers' queues when nothing was found so far
    if (chosen < 0 && work_stealing)
    {
        for (int victim = 1; victim < ready_queues.size(); ++victim)
        {
            if (victim == own_index) continue;

            auto &candidate = ready_queues.at(victim);
            if (!candidate.empty() && candidate.top()->pri < best_pri)
            {
                chosen   = victim;
                best_pri = candidate.top()->pri;
            }
        }
    }

    const bool found = (chosen >= 0);
    if (found)
    {
        auto &source = ready_queues.at(chosen);
        (*meta) = source.top();
        source.pop();
    }

    pthread_mutex_unlock(&ready_queue_mutex);
    return found;
}


// Marks `task` as finished and notifies the tasks that depend on it.
//
// The whole operation runs under task_hash_mutex:
//   - `task` is removed from the task hash via removeTaskHash(), which also
//     returns the value of newest_task at that moment.
//   - For every id in task->out_dep that is not greater than that value, the
//     dependent task is looked up in the hash and its `fulfilled` counter is
//     incremented. When the counter reaches the number of input dependencies,
//     the dependent task is pushed to a ready queue.
//
// Dependents with an id greater than newest_task are skipped; they are handled
// by ready() when they are inserted, since this task is then no longer in the
// hash.
void meta_task_scheduler::finishMetaTask(meta_task_ptr task)
{
    pthread_mutex_lock(&task_hash_mutex);

    const int temp_newest = removeTaskHash(task);

    for (auto dep : task->out_dep)
    {
        if (dep > temp_newest) continue;

        auto entry = task_hash.find(dep);
        if (entry == task_hash.end())
        {
            // The mutex must stay locked here: the loop keeps accessing
            // task_hash, and the single unlock happens after the loop.
            std::cout << "warning: task" << dep << " not in hash while clearing dependencies." << std::endl;
            continue;
        }
        auto t = entry->second;

        if (t->status != NOT_READY)
        {
            std::cout << "error: meta_task " << t->id << " has been declared READY prematurely" << std::endl;
            continue;
        }

        // atomic because (possible race condition)
        int temp = ++(t->fulfilled);
        if (temp == t->in_dep.size())
        {
#ifdef SANITY
            if (!ready(*t)) std::cout << "error: sanity check failed " << t->id << " id not ready yet" << std::endl;
#endif
            insertReadyQueue(t);
        }
    }

    pthread_mutex_unlock(&task_hash_mutex);
}

// Adds `meta` to task_hash, keyed by meta->id.
//
// This function does not lock or unlock task_hash_mutex; its caller
// insertMetaTask() holds the mutex around the call and releases it afterwards.
// With SANITY defined, an id that is already present is reported and the
// insertion is skipped.
void meta_task_scheduler::insertTaskHash(meta_task_ptr meta)
{
#ifdef SANITY
    if ( task_hash.count(meta->id) )
    {
        // No unlock here: the caller still owns task_hash_mutex and unlocks
        // it itself; unlocking here would release it twice.
        std::cout << "error: tried to insert a task which is already hashed!" << std::endl;
        return;
    }
#endif
    task_hash.insert(std::make_pair(meta->id, meta));
}

// Removes `meta` from task_hash and marks it FINISHED.
//
// If the hash becomes empty, all_tasks_finished takes the current value of
// all_tasks_inserted. Returns the current value of newest_task.
//
// This function does not lock or unlock task_hash_mutex; its caller
// finishMetaTask() holds the mutex around the call.
// With SANITY defined, an id that is not in the hash is reported and the
// function returns newest_task without changing any state.
int meta_task_scheduler::removeTaskHash(meta_task_ptr meta)
{
#ifdef SANITY
    if ( ! task_hash.count(meta->id) )
    {
        // No unlock here (the caller owns the mutex), and an int function
        // has to return a value on every path.
        std::cout << "error: tried to erase a task which is not hashed!" << std::endl;
        return newest_task;
    }
#endif
    // Erasing an end() iterator is undefined, so only erase a real match.
    auto entry = task_hash.find(meta->id);
    if (entry != task_hash.end())
    {
        task_hash.erase(entry);
    }

    if (task_hash.empty()) all_tasks_finished = all_tasks_inserted;

    meta->status = FINISHED;

    return newest_task;
}

// Pushes `meta` onto the ready queue with index meta->preferred_node + 1 and
// then sets its status to GLOB_QUEUE. Both steps happen while holding
// ready_queue_mutex.
void meta_task_scheduler::insertReadyQueue(meta_task_ptr meta)
{
    pthread_mutex_lock(&ready_queue_mutex);

    auto &target_queue = ready_queues.at(meta->preferred_node + 1);
    target_queue.push(meta);
    meta->status = GLOB_QUEUE;

    pthread_mutex_unlock(&ready_queue_mutex);
}
