#ifndef TASK_GENERATION
#define TASK_GENERATION

#include <cmath>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>

#include "matrix_desc.h"
#include "meta_task.h"
#include "meta_task_scheduler.h"
#include "context.h"

//#define META_SIZE 4

/// Placement policy whose every node query ignores its arguments and
/// answers -1 (presumably "no fixed node" -- confirm against meta_task).
/// rand_meta() always yields 1, so generate_random_matrix() works with a
/// meta size of 1 under this policy.
struct node_func_rand
{
    /// Node for a PANEL task: always -1.
    static int panel(...)     { return -1; }

    /// Node for a B task: always -1.
    static int b(...)         { return -1; }

    /// Node for a SCHUR_COMPLEMENT task: always -1.
    static int schur(...)     { return -1; }

    /// Node for a BEHIND_PANEL_UPDATE task: always -1.
    static int behind(...)    { return -1; }

    /// Node for a RANDOMIZE task: always -1.
    static int random(...)    { return -1; }

    /// Meta size used while randomizing: always 1, whatever is passed in.
    static int rand_meta(...) { return 1; }
};

/// Placement policy identical to the all -1 policy except for RANDOMIZE
/// tasks, whose node is drawn from the C library generator.
/// (-1 presumably means "no fixed node" -- confirm against meta_task.)
struct node_func_true_rand
{
    /// Node for a PANEL task: always -1.
    static int panel(...)     { return -1; }

    /// Node for a B task: always -1.
    static int b(...)         { return -1; }

    /// Node for a SCHUR_COMPLEMENT task: always -1.
    static int schur(...)     { return -1; }

    /// Node for a BEHIND_PANEL_UPDATE task: always -1.
    static int behind(...)    { return -1; }

    /// Node for a RANDOMIZE task: rand() reduced modulo `nodes`.
    /// The tile coordinates tj/ti are accepted but not used.
    /// `nodes` must be non-zero (it is the divisor of the modulo).
    static int random(int tj, int ti, int nodes)
    {
        const int draw = rand();
        return draw % nodes;
    }

    /// Meta size used while randomizing: always 1, whatever is passed in.
    static int rand_meta(...) { return 1; }
};

/// PANEL placement strategy that ignores its arguments and answers -1
/// (presumably "no fixed node" -- confirm against meta_task).
struct panel_strat_rand
{
    static int panel(...) { return -1; }
};

/// PANEL placement strategy: node (2*tk + 1) mod nodes.
struct panel_strat_top
{
    /// @param k      accepted for interface compatibility, not used
    /// @param tk     index fed into the formula (callers in this file pass
    ///               the meta-column of the panel)
    /// @param nodes  divisor of the modulo; must be non-zero
    static int panel(int k, int tk, int nodes)
    {
        const int slot = 2*tk + 1;
        return slot % nodes;
    }
};

/// SCHUR_COMPLEMENT placement strategy that ignores its arguments and
/// answers -1 (presumably "no fixed node" -- confirm against meta_task).
struct schur_strat_rand
{
    static int schur(...) { return -1; }
};

/// SCHUR_COMPLEMENT placement strategy: node (tj + tk + 1) mod nodes.
/// The second coordinate `ti` is accepted but does not influence the result.
struct schur_strat_A
{
    static int schur(int tj, int ti, int tk, int nodes)
    {
        const int shifted = tj + tk + 1;
        return shifted % nodes;
    }
};

/// SCHUR_COMPLEMENT placement strategy: node (tk + ti + 1) mod nodes.
/// The first coordinate `tj` is accepted but does not influence the result.
struct schur_strat_B
{
    static int schur(int tj, int ti, int tk, int nodes)
    {
        const int shifted = tk + ti + 1;
        return shifted % nodes;
    }
};

/// SCHUR_COMPLEMENT placement strategy: node (tj + ti + 1) mod nodes.
/// The third coordinate `tk` is accepted but does not influence the result.
struct schur_strat_C
{
    static int schur(int tj, int ti, int tk, int nodes)
    {
        const int shifted = tj + ti + 1;
        return shifted % nodes;
    }
};

/// Deterministic placement policy.
///
/// PANEL and SCHUR_COMPLEMENT placement is delegated to the two strategy
/// template parameters; B, BEHIND_PANEL_UPDATE and RANDOMIZE tasks use the
/// fixed formula (first + second + 1) mod nodes.  rand_meta() hands the
/// requested meta size back unchanged.
///
/// @tparam panel_strat  type providing static panel(k, tk, nodes)
/// @tparam schur_strat  type providing static schur(tj, ti, tk, nodes)
template < typename panel_strat = panel_strat_rand,
           typename schur_strat = schur_strat_rand>
struct node_func_det
{
    /// Node for a PANEL task, as decided by panel_strat.
    static int panel(int k, int tk, int nodes)
    {
        return panel_strat::panel(k, tk, nodes);
    }

    /// Node for a B task: (tk + ti + 1) mod nodes.
    static int b(int ti, int tk, int nodes)
    {
        return shifted_sum(tk, ti, nodes);
    }

    /// Node for a SCHUR_COMPLEMENT task, as decided by schur_strat.
    static int schur(int tj, int ti, int tk, int nodes)
    {
        return schur_strat::schur(tj, ti, tk, nodes);
    }

    /// Node for a BEHIND_PANEL_UPDATE task: (tk + ti + 1) mod nodes.
    static int behind(int ti, int tk, int nodes)
    {
        return shifted_sum(tk, ti, nodes);
    }

    /// Node for a RANDOMIZE task: (tj + ti + 1) mod nodes.
    static int random(int tj, int ti, int nodes)
    {
        return shifted_sum(tj, ti, nodes);
    }

    /// Meta size used while randomizing: the caller's value, untouched.
    static int rand_meta(int meta)
    {
        return meta;
    }

private:
    /// (first + second + 1) mod nodes; `nodes` must be non-zero.
    static int shifted_sum(int first, int second, int nodes)
    {
        return (first + second + 1) % nodes;
    }
};

/// Placement policy that derives the node from a single index:
/// (index + 1) mod nodes.  PANEL tasks use `tk`; every other task kind
/// uses `ti`.  The remaining coordinates are accepted but ignored.
/// rand_meta() hands the requested meta size back unchanged.
struct node_func_col
{
    /// Node for a PANEL task: (tk + 1) mod nodes.
    static int panel(int k, int tk, int nodes)
    {
        return successor(tk, nodes);
    }

    /// Node for a B task: (ti + 1) mod nodes.
    static int b(int ti, int tk, int nodes)
    {
        return successor(ti, nodes);
    }

    /// Node for a SCHUR_COMPLEMENT task: (ti + 1) mod nodes.
    static int schur(int tj, int ti, int tk, int nodes)
    {
        return successor(ti, nodes);
    }

    /// Node for a BEHIND_PANEL_UPDATE task: (ti + 1) mod nodes.
    static int behind(int ti, int tk, int nodes)
    {
        return successor(ti, nodes);
    }

    /// Node for a RANDOMIZE task: (ti + 1) mod nodes.
    static int random(int tj, int ti, int nodes)
    {
        return successor(ti, nodes);
    }

    /// Meta size used while randomizing: the caller's value, untouched.
    static int rand_meta(int meta)
    {
        return meta;
    }

private:
    /// (index + 1) mod nodes; `nodes` must be non-zero.
    static int successor(int index, int nodes)
    {
        return (index + 1) % nodes;
    }
};

/// Smaller of two ints; when both are equal that common value is returned.
inline int min(int a, int b)
{
    if (a < b)
    {
        return a;
    }
    return b;
}

/// Ceiling of (mat.nt - (k + 1)) / meta, computed exactly in integers.
///
/// With k == -1 this is the number of meta-sized groups needed to cover all
/// mat.nt tiles; for k >= 0 it covers the tiles that follow index k.
///
/// The previous float-based formula lost exactness once the numerator
/// exceeded 2^24 and bounced the value through int -> float -> int.  This
/// version gives the same result as the mathematical ceiling for every
/// representable input, including negative numerators (e.g. -1 / 4 -> 0,
/// -5 / 4 -> -1).
///
/// @param mat   matrix descriptor; only mat.nt is read
/// @param k     index of the current step, may be -1
/// @param meta  group size; must be non-zero
inline int nMeta(matrix_desc& mat, int k, int meta)
{
    const int remaining = mat.nt - (k + 1);
    const int quotient  = remaining / meta;   // truncates toward zero
    const int remainder = remaining % meta;   // carries the sign of `remaining`

    // Truncation already equals the ceiling when the exact quotient is
    // negative or integral.  Only a positive, fractional quotient (non-zero
    // remainder with the same sign as the divisor) has to be bumped up.
    const bool round_up = (remainder != 0) && ((remainder > 0) == (meta > 0));
    return round_up ? quotient + 1 : quotient;
}

/// Ceiling of (mat.mt - (k + 1)) / meta, computed exactly in integers.
///
/// With k == -1 this is the number of meta-sized groups needed to cover all
/// mat.mt tiles; for k >= 0 it covers the tiles that follow index k.
///
/// The previous float-based formula lost exactness once the numerator
/// exceeded 2^24 and bounced the value through int -> float -> int.  This
/// version gives the same result as the mathematical ceiling for every
/// representable input, including negative numerators (e.g. -1 / 4 -> 0,
/// -5 / 4 -> -1).
///
/// @param mat   matrix descriptor; only mat.mt is read
/// @param k     index of the current step, may be -1
/// @param meta  group size; must be non-zero
inline int mMeta(matrix_desc& mat, int k, int meta)
{
    const int remaining = mat.mt - (k + 1);
    const int quotient  = remaining / meta;   // truncates toward zero
    const int remainder = remaining % meta;   // carries the sign of `remaining`

    // Truncation already equals the ceiling when the exact quotient is
    // negative or integral.  Only a positive, fractional quotient (non-zero
    // remainder with the same sign as the divisor) has to be bumped up.
    const bool round_up = (remainder != 0) && ((remainder > 0) == (meta > 0));
    return round_up ? quotient + 1 : quotient;
}

/// Sums, over every step k in [0, min(mat.nt, mat.mt)), the value
/// nMeta(k) * (1 + mMeta(k)) and adds one more per step.
/// generate_meta_task_dag() uses the result (minus one) to seed the ids it
/// predicts for BEHIND_PANEL_UPDATE tasks.
///
/// @param mat   matrix descriptor handed on to nMeta()/mMeta()
/// @param meta  meta size handed on to nMeta()/mMeta()
inline int nCom(matrix_desc& mat, int meta)
{
    const int steps = min(mat.nt, mat.mt);

    int total = 0;
    for (int step = 0; step < steps; ++step)
    {
        const int columns = nMeta(mat, step, meta);
        const int rows    = mMeta(mat, step, meta);
        total += columns * (1 + rows);
    }

    return total + steps;
}

/**
 * Creates all meta tasks of the factorization and hands them to the global
 * scheduler (meta_context::getContext().scheduler), wiring up their incoming
 * and outgoing dependency id lists on the way.
 *
 * Two passes are made:
 *   1. For every step k in [0, min(mat.nt, mat.mt)): one PANEL task, then for
 *      each remaining meta-column one B task followed by one
 *      SCHUR_COMPLEMENT task per remaining meta-row.
 *   2. For every step k and every i < k: one BEHIND_PANEL_UPDATE task.
 *
 * Task ids are handed out consecutively in creation order.  Ids of tasks that
 * do not exist yet (next level, behind-panel updates) are predicted
 * arithmetically from nMeta()/mMeta()/nCom(), so the creation order and the
 * id arithmetic below must stay in sync.
 *
 * @tparam node_func             placement policy; its static panel/b/schur/
 *                               behind functions provide the last
 *                               constructor argument of each task
 * @param mat                    matrix descriptor (mat.nt and mat.mt are read)
 * @param meta                   meta size (number of tiles per meta tile)
 * @param run_while_generating   all_tasks_finished is set to the negation of
 *                               this flag for the duration of the generation
 * @param run_with_work_stealing stored into scheduler.work_stealing once
 *                               generation has finished
 *
 * NOTE(review): idx is derived from mat.nt only but is also used for row
 * extents (idx.at(m)-y); presumably this expects mat.nt == mat.mt -- confirm.
 * NOTE(review): idx.at(0) throws std::out_of_range when nMeta(mat, -1, meta)
 * is 0 (e.g. mat.nt == 0).
 */
template< typename node_func = node_func_rand >
void generate_meta_task_dag (matrix_desc& mat,          int meta,
                             bool run_while_generating, bool run_with_work_stealing)
{
    meta_task_scheduler& scheduler = meta_context::getContext().scheduler;

	scheduler.all_tasks_inserted = false;
    scheduler.all_tasks_finished = !run_while_generating;
    // Enable work stealing after task generation to make sure
    // that the first few panel tasks are not executed on Node 0
    // scheduler.work_stealing = run_with_work_stealing;

    int nodes  = meta_context::getContext().node_scheduler.size();
    int min_nm = min(mat.nt, mat.mt);
    int nm     = nMeta(mat, -1, meta);   // k == -1: number of meta-columns covering all mat.nt tiles

    std::vector<int> idx(nm);            // end indices of meta-columns / meta-rows (symmetric)
    // The first meta-column takes the remainder (or a full meta if there is none);
    // every following meta-column is exactly meta tiles wide.
    int x = (mat.nt % meta) ? mat.nt % meta : meta;
    idx.at(0) = x;
    for (int i = 1; i < nm; i++)
    {
        x += meta;
        idx.at(i) = x;
    }

    std::vector<std::vector<int> > m_column(  nm, std::vector<int>() ); // for each meta-tile-column one vector with all jobs from the previous level (used for incoming dependencies in the following level)
    std::vector<std::vector<int> > level( min_nm, std::vector<int>() ); // all U-tasks ordered by level (used for incoming edges on BEHIND_PANEL jobs)
    std::vector<int>               panel;                               // id of the PANEL task of each level

    int id=0;
    int behind_panel_id = nCom(mat, meta)-1;                            // first id belonging to a BEHIND_PANEL job that is waiting for the current level
    int m_col_k = 0;                                                    // meta-column of the current PANEL
    std::shared_ptr<meta_task> temp;

    // ---- pass 1: PANEL, B and SCHUR_COMPLEMENT tasks, level by level ----
    for (int k=0; k < min_nm; k++)
    {
        // k has moved past the end of the current meta-column: advance to the next one
        if (k >= idx.at(m_col_k))
        {
            m_column.at(m_col_k).clear();
            m_col_k ++;
        }

        temp = std::shared_ptr<meta_task>(new meta_task(id, PANEL, mat, k, k, k, 1, mat.mt-k, 3*k, node_func::panel(k, m_col_k, nodes))); // (2*m_col_k) % nodes
        // incoming dependencies (schur complement from previous round)
        temp->in_dep = m_column.at(m_col_k);
        // outgoing dependencies (b jobs); each B task is followed by mMeta SCHUR tasks, hence the stride mMeta+1
        for (int i = 0; i < nMeta(mat, k, meta); ++i) temp->out_dep.push_back(id+1+i*mMeta(mat, k, meta)+i);
        // outgoing dependencies (X jobs)
        for (int i = k-1; i >= 0; --i) temp->out_dep.push_back(behind_panel_id-i);

        scheduler.insertMetaTask(temp);

        int panel_id      = id;
        int n_this_level  = nMeta(mat, k, meta) * (1 + mMeta(mat,k,meta)) + 1;  // B + SCHUR tasks of this level, plus the PANEL
        int next_level_id = id + n_this_level;                                  // predicted id of the next level's PANEL
        level.at(k).reserve(n_this_level);
        panel.push_back(id);
        id++;

        behind_panel_id += k+1;

        int x = k+1;                       // first tile column right of the panel (shadows the outer x)
        std::vector<int> out;              // used for outgoing dependencies to the B-TASK (and possibly the PANEL) of the next level;
        out.push_back(next_level_id++);    // the SCHUR jobs of the first column have outgoing edges to the next PANEL

        if ( idx.at(m_col_k)-x > 1 || (idx.at(m_col_k) == x && meta > 1))
        {
            //std::cout << "hi" << std::endl;
            out.push_back(next_level_id);
            next_level_id += mMeta(mat, k+1, meta)+1;
        }

        // one B task per meta-column still to the right of x
        for(int n = m_col_k + ((x >= idx.at(m_col_k)) ? 1 : 0) ; n < nm; n++)
        {
            temp = std::shared_ptr<meta_task>(new meta_task(id, B, mat, k, x, k, idx.at(n)-x, 1, x+2*k, node_func::b(n, m_col_k, nodes)));  // (m_col_k+n) % nodes pri = x0+3k = x-k+3k
            // incoming dependencies (schur complement from previous round + panel job this round)
            temp->in_dep = m_column.at(n);
            temp->in_dep.push_back(panel_id);
            // outgoing dependencies (schur complement jobs)
            for (int i = 0; i < mMeta(mat, k, meta); i++) temp->out_dep.push_back(id+1+i);
            //level.at(k).push_back(id);

            scheduler.insertMetaTask(temp);
            int b_id = id;
            id++;

            // from here on m_column.at(n) collects the SCHUR tasks of this level
            m_column.at(n).clear();
            int y = k+1;
            // one SCHUR_COMPLEMENT task per meta-row still below y
            for (int m = m_col_k + ((y >= idx.at(m_col_k)) ? 1 : 0) ; m < nm; m++)
            {
                temp = std::shared_ptr<meta_task>(new meta_task(id, SCHUR_COMPLEMENT, mat, k, x, y, idx.at(n)-x, idx.at(m)-y, x+2*k+1, node_func::schur(m, n, m_col_k,nodes))); // (n+m_col_k) % nodes instead of ... pri x0+3k+1 -> x0 = x-k
                // incoming dependencies (b job this round)
                temp->in_dep.push_back(b_id);
                // outgoing dependencies (panel_job (+b_job?) next round)
                temp->out_dep = out;
                temp->out_dep.push_back(behind_panel_id);
                level.at(k).push_back(id);

                scheduler.insertMetaTask(temp);

                m_column.at(n).push_back(id);
                id++;
                y = idx.at(m);
            }
            x = idx.at(n);
            // SCHUR tasks of the following meta-columns only point at one task of the next level
            out.clear();
            out.push_back(next_level_id);
            next_level_id += mMeta(mat, k+1, meta)+1;
        }
    }

    // ---- pass 2: BEHIND_PANEL_UPDATE tasks ----
    int in(id);      // id of the first BEHIND_PANEL_UPDATE task; advanced as predecessors are consumed
    m_col_k = 0;
    for (int k=0; k<min_nm; k++)
    {
        if (k >= idx.at(m_col_k)) m_col_k++;
        int m_col_i = 0;
        for (int i=0; i<k; i++)
        {
            if (i >= idx.at(m_col_i)) m_col_i++;
            temp = std::shared_ptr<meta_task>(new meta_task(id, BEHIND_PANEL_UPDATE, mat, k,i,k,1,mat.mt-k, 9999, node_func::behind(m_col_i, m_col_k, nodes))); // (m_col_k+m_col_i) % nodes
            // incoming dependencies (all jobs of level k + behind_panel_updates level k-1)
            if (i < k-1)
            {
                temp->in_dep.push_back(in++);
            }
            else
            {
                temp->in_dep = level.at(k-1);
            }

            temp->in_dep.push_back(panel.at(k));
            // outgoing dependencies (behind_panel_updates for level k+1)
            if (k < min_nm - 1) temp->out_dep.push_back(id+k);

            scheduler.insertMetaTask(temp);

            id++;
        }
    }
    // generation complete: publish the final scheduler state
    scheduler.all_tasks_finished = false;
	scheduler.all_tasks_inserted = true;
    scheduler.work_stealing = run_with_work_stealing;
}

/**
 * Inserts one RANDOMIZE meta task per meta tile into the global scheduler
 * (meta_context::getContext().scheduler).
 *
 * The tile range [0, A.nt) is cut into nm = nMeta(A, -1, temp_meta) pieces,
 * where temp_meta = node_func::rand_meta(meta).  The first piece takes the
 * remainder A.nt % temp_meta (or a full temp_meta if there is none), all
 * others are temp_meta wide.  The same partition is used for both tile
 * coordinates, so nm * nm tasks are created with consecutive ids from 0.
 *
 * The C library generator is re-seeded with the fixed value 3456 on every
 * call.
 *
 * Compared to the earlier version the partition table is held in a
 * std::vector, so it is released even if task construction or insertion
 * throws, and nothing is written to the table when nm is 0.
 *
 * @tparam node_func             placement policy; rand_meta() and random()
 *                               are used
 * @param A                      matrix descriptor (A.nt is read)
 * @param meta                   requested meta size, filtered through
 *                               node_func::rand_meta()
 * @param run_while_generating   all_tasks_finished is set to the negation of
 *                               this flag for the duration of the generation
 * @param run_with_work_stealing stored into scheduler.work_stealing up front
 */
template < typename node_func >
void generate_random_matrix(matrix_desc& A,            int meta,
                            bool run_while_generating, bool run_with_work_stealing)
{
    std::srand(3456);
    meta_task_scheduler& scheduler = meta_context::getContext().scheduler;
    scheduler.work_stealing = run_with_work_stealing;
    scheduler.all_tasks_finished = !run_while_generating;
    scheduler.all_tasks_inserted = false;

    int nodes = meta_context::getContext().node_scheduler.size();
    int temp_meta = node_func::rand_meta(meta);

    int nm = nMeta(A,-1,temp_meta);

    // (offset, extent) of every piece:
    // (0, off), (off, meta), (off+meta, meta), (off+2meta, meta)...
    std::vector<std::pair<int,int> > idx;
    if (nm > 0)
    {
        idx.reserve(nm);
        int x = (A.nt % temp_meta) ? A.nt % temp_meta : temp_meta;
        idx.push_back(std::make_pair(0, x));
        for (int i = 1; i < nm; i++)
        {
            idx.push_back(std::make_pair(x, temp_meta));
            x += temp_meta;
        }
    }

    int id = 0;
    std::shared_ptr<meta_task> temp;
    for (int j = 0; j < nm; j++)
    {
        for (int i = 0; i < nm; i++)
        {
            temp = std::shared_ptr<meta_task>(new meta_task(id++,RANDOMIZE, A, 3456, idx[i].first, idx[j].first, idx[i].second, idx[j].second, 0, node_func::random(j,i,nodes))); // (i+j) % nodes
            scheduler.insertMetaTask(temp);
        }
    }

    // generation complete: publish the final scheduler state
    scheduler.all_tasks_finished = false;
    scheduler.all_tasks_inserted = true;
}


#endif //TASK_GENERATION
