#include "utility.h"

// Dumps the matrix to std::cout.
//
// One output line is produced per index j in [0, m.m); each line holds the
// m.n values m.get(j, 0) ... m.get(j, m.n - 1), every value followed by a
// single space. Each line is terminated (and flushed) with std::endl.
void print(matrix_desc &m)
{
    for (int row = 0; row < m.m; ++row)
    {
        int col = 0;
        while (col < m.n)
        {
            const double value = m.get(row, col);
            std::cout << value << " ";
            ++col;
        }
        std::cout << std::endl;
    }
}

// Extracts the unit lower-triangular part of `source` into `target`.
//
// For every row j in [0, m):
//   * the diagonal entry (j, j) of `target` is set to 1, but only where a
//     diagonal entry exists, i.e. while j < n;
//   * the entries left of the diagonal (columns i < min(j, n)) are copied
//     from source.get(j, i).
// Entries above the diagonal of `target` are not touched. `target` is written
// through its raw buffer using the offset i*m + j for element (row j, col i).
//
// Returns silently, leaving `target` unmodified, when the dimensions of
// `source` and `target` differ.
//
// NOTE(review): unlike copy()/compare(), target.mt is not checked here even
// though the raw buffer is indexed directly -- confirm callers only pass
// targets for which that indexing is valid.
void copy_lower(matrix_desc& source, matrix_desc& target)
{
    if (source.n != target.n || source.m != target.m) return;

    const int n = source.n;
    const int m = source.m;

    for (int j = 0; j < m; ++j)
    {
        // With more rows than columns (m > n) there is no diagonal entry for
        // rows j >= n; the offset j*m + j would then point past the last
        // column (assuming the buffer holds m*n values), so skip the write.
        if (j < n)
        {
            (*target.data)[j*m+j] = 1;
        }

        // Strictly-lower part of row j, clipped to the available columns.
        for (int i = 0; i < j && i < n; ++i)
        {
            (*target.data)[i*m+j] = source.get(j,i);
        }
    }
}

// Copies the upper-triangular part (diagonal included) of `source` into
// `target`.
//
// For every row index r in [0, m) and every column index c in [r, n) the
// value source.get(r, c) is stored at offset c*m + r of target's raw buffer.
// Entries below the diagonal of `target` are left as they are.
//
// Returns silently, without touching `target`, when the dimensions differ.
void copy_upper(matrix_desc& source, matrix_desc& target)
{
    const bool same_shape = (source.n == target.n) && (source.m == target.m);
    if (!same_shape) return;

    const int cols = source.n;
    const int rows = source.m;

    for (int row = 0; row < rows; ++row)
    {
        int col = row;   // start on the diagonal of this row
        while (col < cols)
        {
            (*target.data)[col * rows + row] = source.get(row, col);
            ++col;
        }
    }
}

// Computes target = m1 * m2 with a single cblas_dgemm call (column-major,
// neither operand transposed, alpha = 1, beta = 0).
//
// Preconditions, each reported on std::cout before returning early:
//   * m1.n == m2.m, m1.m == target.m and m2.n == target.n;
//   * none of the three matrices has mt > 1 (treated here as "not in column
//     major form").
// The leading dimension handed to BLAS for each matrix is its own `m`.
void multiply(matrix_desc& m1, matrix_desc& m2, matrix_desc& target)
{
    const bool shapes_agree = (m1.n == m2.m)
                           && (m1.m == target.m)
                           && (m2.n == target.n);
    if (!shapes_agree)
    {
        std::cout << "error: unmatching matrix sizes" << std::endl;
        return;
    }

    const bool all_column_major = !(m1.mt > 1) && !(m2.mt > 1) && !(target.mt > 1);
    if (!all_column_major)
    {
        std::cout << "error: matrix not in column major form" << std::endl;
        return;
    }

    // 111 is the numeric value this file uses for "no transpose".
    const CBLAS_TRANSPOSE no_trans = static_cast<CBLAS_TRANSPOSE>(111);
    const double alpha = 1.0;
    const double beta  = 0.0;

    cblas_dgemm(CblasColMajor, no_trans, no_trans,
                m1.m, m2.n, m1.n,
                alpha, *m1.data, m1.m,
                       *m2.data, m2.m,
                beta,  *target.data, target.m);
}

// Applies the interchanges recorded in `ipiv` to `m` in reverse order.
//
// `ipiv` is read at positions m.m - 1 down to 0 and its entries are treated
// as 1-based indices: whenever ipiv[pos] differs from pos + 1, cblas_dswap
// exchanges the m.n elements starting at offsets pos and ipiv[pos] - 1 of the
// raw buffer, both walked with stride m.m.
// (Presumably this undoes LAPACK-style row pivoting -- confirm with callers.)
//
// Matrices with mt > 1 are rejected with an error message on std::cout.
void permute_back(matrix_desc& m, int* ipiv)
{
    if (m.mt > 1)
    {
        std::cout << "error: matrix not in column major form" << std::endl;
        return;
    }

    for (int pos = m.m - 1; pos >= 0; --pos)
    {
        const int one_based = pos + 1;
        if (ipiv[pos] == one_based)
        {
            continue;   // nothing was interchanged at this position
        }

        const int other = ipiv[pos] - 1;   // zero-based partner offset
        cblas_dswap(m.n, (*m.data) + pos,   m.m,
                         (*m.data) + other, m.m);
    }
}

// Stores the element-wise difference m1 - m2 in `target`.
//
// All three matrices must share the same n and m, and `target` must not have
// mt > 1 (its raw buffer is written with the offset col*target.m + row);
// otherwise an error is printed to std::cout and nothing is written.
void compare(matrix_desc& m1, matrix_desc& m2, matrix_desc& target)
{
    const bool inputs_match = (m1.n == m2.n) && (m1.m == m2.m);
    const bool target_match = (m1.n == target.n) && (m1.m == target.m);
    if (!(inputs_match && target_match))
    {
        std::cout << "error: matrices are not matching" << std::endl;
        return;
    }

    if (target.mt > 1)
    {
        std::cout << "error: matrix not in column major form" << std::endl;
        return;
    }

    for (int col = 0; col < m1.n; ++col)
    {
        const int column_start = col * target.m;
        for (int row = 0; row < m1.m; ++row)
        {
            (*target.data)[column_start + row] = m1.get(row, col) - m2.get(row, col);
        }
    }
}

// Prints the element-wise difference m1 - m2 of largest magnitude.
//
// The signed difference whose absolute value is largest is written to
// std::cout as "maxdiff:<value>"; 0 is printed when no difference has a
// magnitude above zero. If n or m differ between the two matrices an error
// message is printed instead.
void compare(matrix_desc& m1, matrix_desc& m2)
{
    const bool same_shape = (m1.n == m2.n) && (m1.m == m2.m);
    if (!same_shape)
    {
        std::cout << "error: matrices are not matching" << std::endl;
        return;
    }

    double worst = 0.0;             // signed difference with the largest magnitude so far
    double worst_magnitude = 0.0;   // always equal to std::fabs(worst)

    for (int col = 0; col < m1.n; ++col)
    {
        for (int row = 0; row < m1.m; ++row)
        {
            const double diff = m1.get(row, col) - m2.get(row, col);
            const double magnitude = std::fabs(diff);
            if (magnitude > worst_magnitude)
            {
                worst = diff;
                worst_magnitude = magnitude;
            }
        }
    }

    std::cout << "maxdiff:" << worst << std::endl;
}

// Copies every element of `s` into `t`.
//
// Elements are read through s.get(row, col) and written to t's raw buffer at
// offset col*t.m + row. Both matrices must agree in n and m, and `t` must not
// have mt > 1; otherwise an error is printed to std::cout and `t` is left
// unchanged.
void copy(matrix_desc& s, matrix_desc& t)
{
    const bool same_shape = (s.n == t.n) && (s.m == t.m);
    if (!same_shape)
    {
        std::cout << "error: matrices are not matching" << std::endl;
        return;
    }

    if (t.mt > 1)
    {
        std::cout << "error: matrix not in column major form" << std::endl;
        return;
    }

    for (int col = 0; col < s.n; ++col)
    {
        const int column_start = col * t.m;
        for (int row = 0; row < s.m; ++row)
        {
            (*t.data)[column_start + row] = s.get(row, col);
        }
    }
}
//*/
// Writes `t` to `f` as one log column: the value is padded to a minimum field
// width of `space` characters and followed by a single separating space.
template<typename T>
void p(std::ofstream & f, T t, int space)
{
    f.width(space);   // applies to the next formatted insertion only
    f << t;
    f << " ";
}

// Writes the task log of every thread in meta_context::getContext().threads
// to the file `name`.
//
// The file starts with a "#"-prefixed header: number of threads, `n`
// ("matrix size"), `nb` ("tile size"), the total number of log entries, and
// `time` divided by 1e9 (presumably nanoseconds converted to seconds --
// confirm with the caller).
// It is followed by one line per log entry with the columns: thread id,
// task_id, k, x, y, size_y, start and end of the entry as nanosecond offsets
// from `start`, cache_m, cache_a and the ratio cache_m / cache_a.
//
// If the file cannot be opened an error is printed to std::cout and nothing
// is written.
void create_log(int n, int nb, std::chrono::high_resolution_clock::time_point start, long long time, std::string name)
{
    std::ofstream file(name);
    if (!file)
    {
        std::cout << "error: could not open log file " << name << std::endl;
        return;
    }

    std::vector<meta_thread>& threads = meta_context::getContext().threads;

    // Iterate by reference: copying a thread record would also copy its
    // complete log.
    std::size_t sum = 0;
    for (const auto& t : threads)
    {
        sum += t.log.size();
    }

    // '\n' instead of std::endl: the stream is flushed once when it is closed
    // rather than after every line.
    file << "# Nb threads:  " << threads.size()   << '\n';
    file << "# matrix size: " << n                << '\n';
    file << "# tile size:   " << nb               << '\n';
    file << "# Nb jobs:     " << sum              << '\n';
    file << "# time:        " << time/1000000000. << '\n' << '\n';

    for (const auto& t : threads)
    {
        for (const auto& l : t.log)
        {
            const auto s = std::chrono::duration_cast<std::chrono::nanoseconds>(l.start - start);
            const auto e = std::chrono::duration_cast<std::chrono::nanoseconds>(l.end   - start);
            p(file, t.id, 3);
            p(file, l.task_id, 3);
            p(file, l.k, 3);
            p(file, l.x, 3);
            p(file, l.y, 3);
            p(file, l.size_y, 3);
            p(file, s.count(), 13);
            p(file, e.count(), 13);
            p(file, l.cache_m, 6);
            p(file, l.cache_a, 9);
            p(file, static_cast<double>(l.cache_m) / l.cache_a, 13);
            file << '\n';
        }
    }
    // `file` is flushed and closed by its destructor.
}
