
#include "numa.h"


/*
 * While the CHECK macro is defined, the computed factorization is verified against the input.
 * Because the check is single-threaded and unoptimized, we do not recommend enabling it for large matrices.
 */
//#define CHECK

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>

#include "matrix_desc.h"
#include "context.h"
#include "task_generation.h"
#include "utility.h"


/*
 * Controls which timed iterations produce a log file (only effective when the
 * program is built with CREATE_LOG). Enumerators keep their implicit values
 * 0, 1, 2 in declaration order.
 */
enum log_stat {
    LOG,        // write a log for every timed iteration
    LOG_ONCE,   // write a log only for the last iteration of each matrix size
    NO_LOG      // never write a log
};

/*************** Actual main ************************************/
/*
 * Benchmark driver. Parses the command line, initializes the meta_context,
 * runs a fixed warm-up, then for every matrix size n in [nStart, nEnd]
 * (step nStep) performs `iter` timed factorizations and prints one result
 * line per size (average seconds, average Gflop/s, sample standard deviation
 * of Gflop/s, log file name or "-").
 *
 * Options (all optional):
 *   --n_range=START,END,STEP   matrix sizes (one arbitrary separator char after each number)
 *   --nb=INT                   tile size (>= 1)
 *   --niter=INT                timed iterations per size (>= 1)
 *   --meta=INT                 meta-tile parameter passed to the DAG generators
 *   --threads=INT              number of worker threads
 *   --cores=C0,C1,...          explicit core list; overrides --threads
 *   --nodes=INT                number of NUMA nodes
 *   --numa=rand|trand|col|det_X_Y
 *                              selects the node_func_* variant used by the DAG
 *                              generators (det: X in {R,T}, Y in {R,A,B,C})
 *   --log | --log_once | --nolog   log creation policy (only with CREATE_LOG)
 *   --sync | --async           synchronized / asynchronous start
 *   --steal | --nosteal        enable / disable work stealing
 *   --path=DIR                 directory prefix for log files
 *
 * Returns 0 on success, 1 if libnuma is unavailable or an option is invalid.
 */
int main(int argc, char** argv) {

    // numa_available() is non-zero (libnuma documents -1) when NUMA cannot be used.
    if (numa_available())
    {
        std::cout << "error: libnuma is not available!" << std::endl;
        return 1;
    }

    // INITIALIZE PARAMETERS (defaults, overridable from the command line)
    int  nb              =  256;
    int  nStart          = 8192;
    int  nEnd            = 8192;
    int  nStep           = 1024;
    int  meta            =    3;
    int  cores           = numa_num_task_cpus()/2;  // half of the CPUs available to this task
    int  nodes           = numa_num_task_nodes();
    int  iter            =    5;
    bool work_stealing   = true;
    bool work_while_gen  = true;
    log_stat log_mode    = LOG_ONCE;   // not named `log`: avoids shadowing the <cmath> function
    std::string numa_string = "rand";
    std::string log_path    = "";
    std::vector<int> threads;

    void (*rand_dag) (matrix_desc&, int, bool, bool) = &generate_random_matrix<node_func_rand>;
    void (*gen_dag)  (matrix_desc&, int, bool, bool) = &generate_meta_task_dag<node_func_rand>;

    // READ PARAMETERS
    for (int i = 1; i < argc && argv[i]; ++i)
    {
        std::string str(argv[i]);
        try
        {
            if (!str.compare(0,10,"--n_range="))
            {
                // Exactly one separator character is skipped after each number.
                size_t idx;
                str    = str.substr(10);
                nStart = std::stoi(str,&idx);
                str    = str.substr(++idx);
                nEnd   = std::stoi(str,&idx);
                str    = str.substr(++idx);
                nStep  = std::stoi(str,&idx);
            }
            else if (!str.compare(0,5,"--nb="))
            {
                nb = std::stoi(str.substr(5));
            }
            else if (!str.compare(0,8,"--niter="))
            {
                iter = std::stoi(str.substr(8));
            }
            else if (!str.compare(0,7,"--meta="))
            {
                meta = std::stoi(str.substr(7));
            }
            else if (!str.compare(0,10,"--threads="))
            {
                cores = std::stoi(str.substr(10));
            }
            else if (!str.compare(0,8,"--nodes="))
            {
                nodes = std::stoi(str.substr(8));
            }
            else if (!str.compare(0,7,"--numa="))
            {
                if (str.size() < 8) std::cout << "warning: unknown option " << str << std::endl;
                else if (!str.compare(7,4,"rand"))
                {
                    gen_dag  = &generate_meta_task_dag<node_func_rand>;
                    rand_dag = &generate_random_matrix<node_func_rand>;
                    numa_string = "rand";
                }
                else if (!str.compare(7,5,"trand"))
                {
                    gen_dag  = &generate_meta_task_dag<node_func_true_rand>;
                    rand_dag = &generate_random_matrix<node_func_true_rand>;
                    numa_string = "trand";
                }
                else if (!str.compare(7,3,"col"))
                {
                    gen_dag  = &generate_meta_task_dag<node_func_col>;
                    rand_dag = &generate_random_matrix<node_func_col>;
                    numa_string = "col";
                }
                else if (!str.compare(7,3,"det"))
                {
                    // Expected layout "--numa=det?X?Y": X at index 11, Y at index 13.
                    rand_dag = &generate_random_matrix<node_func_det<panel_strat_rand, schur_strat_rand> >;
                    if (str.size() < 14) std::cout << "warning: unknown option " << str << std::endl;
                    else if (!str.compare(11,1,"R"))
                    {
                        if      (!str.compare(13,1,"R"))
                        {
                            gen_dag = &generate_meta_task_dag<node_func_det<panel_strat_rand, schur_strat_rand> >;
                            numa_string = "RR";
                        }
                        else if (!str.compare(13,1,"A"))
                        {
                            gen_dag = &generate_meta_task_dag<node_func_det<panel_strat_rand, schur_strat_A> >;
                            numa_string = "RA";
                        }
                        else if (!str.compare(13,1,"B"))
                        {
                            gen_dag = &generate_meta_task_dag<node_func_det<panel_strat_rand, schur_strat_B> >;
                            numa_string = "RB";
                        }
                        else if (!str.compare(13,1,"C"))
                        {
                            gen_dag = &generate_meta_task_dag<node_func_det<panel_strat_rand, schur_strat_C> >;
                            numa_string = "RC";
                        }
                        else std::cout << "warning: unknown option " << str << std::endl;
                    }
                    else if (!str.compare(11,1,"T"))
                    {
                        if      (!str.compare(13,1,"R"))
                        {
                            gen_dag = &generate_meta_task_dag<node_func_det<panel_strat_top, schur_strat_rand> >;
                            numa_string = "TR";
                        }
                        else if (!str.compare(13,1,"A"))
                        {
                            gen_dag = &generate_meta_task_dag<node_func_det<panel_strat_top, schur_strat_A> >;
                            numa_string = "TA";
                        }
                        else if (!str.compare(13,1,"B"))
                        {
                            gen_dag = &generate_meta_task_dag<node_func_det<panel_strat_top, schur_strat_B> >;
                            numa_string = "TB";
                        }
                        else if (!str.compare(13,1,"C"))
                        {
                            gen_dag = &generate_meta_task_dag<node_func_det<panel_strat_top, schur_strat_C> >;
                            numa_string = "TC";
                        }
                        else std::cout << "warning: unknown option " << str << std::endl;
                    }
                    else std::cout << "warning: unknown option " << str << std::endl;
                }
                else std::cout << "warning: unknown option " << str << std::endl;
            }
            else if (!str.compare("--log"))
            {
                log_mode = LOG;
            }
            else if (!str.compare("--nolog"))
            {
                log_mode = NO_LOG;
            }
            else if (!str.compare("--log_once"))
            {
                log_mode = LOG_ONCE;
            }
            else if (!str.compare("--sync"))
            {
                work_while_gen = false;
            }
            else if (!str.compare("--async"))
            {
                work_while_gen = true;
            }
            else if (!str.compare("--nosteal"))
            {
                work_stealing = false;
            }
            else if (!str.compare("--steal"))
            {
                work_stealing = true;
            }
            else if (!str.compare(0,7,"--path="))
            {
                log_path = str.substr(7);
                log_path.append("/");
            }
            else if (!str.compare(0,8,"--cores="))
            {
                // Comma (or any single char) separated list of core ids.
                str = str.substr(8);
                while (!str.empty())
                {
                    size_t idx;
                    threads.push_back(std::stoi(str,&idx));
                    if (++idx > str.size())  break;
                    str = str.substr(idx);
                }
            }
            else
            {
                std::cout << "warning: unknown parameter " << str << std::endl;
            }
        }
        catch (const std::exception&)
        {
            // std::stoi throws std::invalid_argument / std::out_of_range and
            // substr throws std::out_of_range on malformed option values.
            std::cout << "error: invalid value in parameter " << argv[i] << std::endl;
            return 1;
        }
    }

    // Reject values that would divide by zero (iter), loop forever (nStep)
    // or describe an empty tile (nb).
    if (nb < 1 || iter < 1 || nStep < 1)
    {
        std::cout << "error: --nb, --niter and the --n_range step must be positive" << std::endl;
        return 1;
    }

    meta_context& context = meta_context::getContext();

    if (threads.empty())
    {
        context.initialize(cores, nodes);
    }
    else
    {
        cores = threads.size();
        context.initialize(threads, nodes);
    }

    printf( "#\n"
            "# META %s\n"
            "# Meta tiles: %d\n"
            "# Nb threads: %d\n"
            "# Nb nodes:   %d\n"
            "# NB:         %d\n"
            "# NUMA opt:   %s\n"
            "#\n"
            "# Nb iter:    %d\n"
            "# %s \n"
            "# %s \n"
            "#\n#\n",
            argv[0],
            meta,
            cores,
            nodes,
            nb,
            numa_string.c_str(),
            iter,
            (work_while_gen) ? "asynchronized start" : "synchronized start",
            (work_stealing)  ? "work stealing"       : "no work stealing"
            );
    printf( "#     M       N  K/NRHS   seconds   Gflop/s Deviation  logfile\n" );

    // WARM UP: randomize and factorize a fixed 16384 x 16384 matrix (tile 256) three times.
    for (int warm = 0; warm < 3; ++warm)
    {
        matrix_desc matrix(16384,16384,256,256);
        rand_dag(matrix, meta, work_while_gen, false);
        context.work_0();
        context.cleanup();

        gen_dag(matrix, meta, work_while_gen, work_stealing);
        context.work_0();
        context.cleanup();
    }

    std::vector<long long> times;   // per-iteration wall time in nanoseconds
    times.reserve(iter);
    for (int n = nStart; n <= nEnd; n += nStep)
    {
        times.clear();
        // Log file written for this size; stays "-" when no log was created.
        char filename[256] = "-";

        for (int i = 0; i < iter; ++i)
        {
            // ALLOCATE MATRIX could be moved to outer for loop
            matrix_desc matrix(n, n, nb, nb);

            // RANDOMIZE MATRIX
            rand_dag(matrix, meta, work_while_gen, false);

            context.work_0();
            context.cleanup();

            #ifdef CHECK
                matrix_desc m_copy(n,n,n,n);
                m_copy.initialize(0,0);
                copy(matrix, m_copy);
            #endif

            // START TIME MEASURE (DAG generation is part of the measured time)
            auto start = std::chrono::high_resolution_clock::now();

            // CREATE JOB DAG
            gen_dag(matrix, meta, work_while_gen, work_stealing);

            // RUN DECOMPOSITION
            context.work_0();
            auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>
                    (std::chrono::high_resolution_clock::now() - start);

            // END TIME MEASURE
            times.push_back(duration.count());

            // CREATE LOG TO VIEW
            #ifdef CREATE_LOG
                if (log_mode == LOG || (log_mode == LOG_ONCE && i == iter-1))
                {
                    // snprintf instead of sprintf: --path is user input and may be long.
                    const int len = std::snprintf(filename, sizeof(filename),
                                 "%slog_N%d_Nb%d_M%d_numa%s_%s_%s_I%d.txt",
                                 log_path.c_str(),
                                 matrix.nt,
                                 nb,
                                 meta,
                                 numa_string.c_str(),
                                 (work_while_gen) ? "async" : "sync",
                                 (work_stealing)  ? "ws"    : "nows",
                                 i);
                    if (len < 0 || static_cast<size_t>(len) >= sizeof(filename))
                    {
                        std::cout << "warning: log file name too long, log not written" << std::endl;
                        filename[0] = '-';
                        filename[1] = '\0';
                    }
                    else
                    {
                        create_log(n, nb, start, duration.count(), filename);
                    }
                }
            #endif

            // RESET CONTEXT
            context.cleanup();

            #ifdef CHECK
                matrix_desc lapack_L(n,n,n,n);
                lapack_L.initialize(0,0);
                matrix_desc lapack_U(n,n,n,n);
                lapack_U.initialize(0,0);
                matrix_desc lapack_LU(n,n,n,n);
                lapack_LU.initialize(0,0);

                // Rebuild P^T * L * U and compare it against the original input.
                copy_lower(matrix, lapack_L);
                copy_upper(matrix, lapack_U);
                multiply(lapack_L, lapack_U, lapack_LU);
                permute_back(lapack_LU, matrix.ipiv);
                compare(m_copy, lapack_LU);
            #endif
        }

        // Flop count of an n x n LU factorization (LAPACK getrf count with m == n), in Gflop.
        const double gflop = (2./3.*n*n*n - 1./2.*n*n + 5./6.*n)/1e9;
        double time_sum  = 0.;
        double gflops    = 0.;
        double gflops_sq = 0.;
        for (const long long t : times)
        {
            const double dt   = t/1e9;   // ns -> s
            const double rate = gflop/dt;
            time_sum  += dt;
            gflops    += rate;
            gflops_sq += rate*rate;
        }

        // Sample standard deviation of the per-iteration Gflop/s. The clamp
        // absorbs tiny negative values caused by floating-point cancellation.
        double deviation = 0.;
        if (iter > 1)
            deviation = std::sqrt(std::max(0., (gflops_sq - gflops*gflops/iter)/(iter-1)));

        printf( "%7d %7d %7d %9.3f %9.2f %9.2f  # %s \n",
                n, n, 1, time_sum/iter, gflops/iter, deviation, filename);
    }
    context.shut_down();

    return 0;
}
