#include "matrix_desc.h"

namespace
{
// Integer ceiling of a / b for a >= 0 and b > 0.
// Unlike std::ceil((float)a / b) this is exact for every int value (a float
// only has a 24-bit mantissa), and unlike (a + b - 1) / b it cannot overflow.
int ceil_div(int a, int b)
{
    return a / b + (a % b != 0 ? 1 : 0);
}
}

/// Creates an owning descriptor for an m x n matrix split into mb x nb tiles.
/// No tile storage is allocated here: every tile pointer starts as nullptr and
/// is allocated later by initialize(). A pivot array of length m is allocated
/// and every entry is set to -1.
matrix_desc::matrix_desc(int m, int n, int mb, int nb)
    : n(n), m(m), nt(ceil_div(n, nb)), mt(ceil_div(m, mb)),
      nb(nb), mb(mb), col_size(ceil_div(m, mb)),
      submatrix(false), ton(0), tom(0)
{
    // The trailing "()" value-initialises the array, so every tile pointer is nullptr.
    data = new double*[nt*mt]();

    ipiv = new int[m];
    for (int r = 0; r < m; ++r)
    {
        ipiv[r] = -1;
    }
}


/// Creates a non-owning view of `mat` covering mt x nt tiles, starting at
/// tile row j and tile column i of the parent.
/// The view shares the parent's tile-pointer array and pivot array (offset to
/// the first covered tile / row). It is flagged as a submatrix, so its
/// destructor does not free that storage; the parent must outlive the view.
/// Out-of-range requests are only reported on std::cout; the view is still
/// constructed.
matrix_desc::matrix_desc(matrix_desc& mat, int j, int i, int mt, int nt)
    : n(0), m(0), nt(nt), mt(mt), nb(mat.nb), mb(mat.mb),
      col_size(mat.col_size),
      submatrix(true), ton(i), tom(j),
      // Use mat.mb rather than the member mb: members are initialised in
      // declaration order (set in the header), so this->mb may not be ready yet.
      data(mat.data + mat.col_size*i + j), ipiv(mat.ipiv + j*mat.mb)
{
    if (i+nt > mat.nt)
        std::cout << "error: submatrix exceeds the matrix dimensions i:"
                  << i << " snt:" << nt << " nt:" << mat.nt << std::endl;
    if (j+mt > mat.mt)
        std::cout << "error: submatrix exceeds the matrix dimensions j:"
                  << j << " smt:" << mt << " mt:" << mat.mt << std::endl;

    // Element extents: if the view reaches the parent's last tile column/row,
    // it inherits the parent's (possibly partial) remainder; otherwise it is
    // made of full tiles.
    n = (i+nt == mat.nt) ? mat.n - i*nb : nt*nb;
    m = (j+mt == mat.mt) ? mat.m - j*mb : mt*mb;
}

/// Releases the tiles, the tile-pointer array and the pivot array, but only for
/// owning descriptors. Submatrix views alias their parent's storage and free nothing.
matrix_desc::~matrix_desc()
{
    if (submatrix)
    {
        return;
    }

#ifdef FORCE_NUMA
    // Tiles obtained from libnuma must be released with numa_free, which needs
    // the byte size that initialize() passed to numa_alloc_local. Tile (j, i)
    // lives at data[j + i*col_size] and holds tm(j) * tn(i) doubles.
    for (int i = 0; i < nt; ++i)
    {
        for (int j = 0; j < mt; ++j)
        {
            double* tile = data[j + i*col_size];
            if (tile)  // never-initialized tiles are still nullptr
            {
                numa_free(tile, static_cast<std::size_t>(tm(j))
                                * static_cast<std::size_t>(tn(i))
                                * sizeof(double));
            }
        }
    }
#else
    for (int k = 0; k < nt*mt; ++k)
    {
        delete[] data[k];  // delete[] on a nullptr (uninitialized tile) is a no-op
    }
#endif

    delete[] data;
    delete[] ipiv;
}

/// Allocates storage for tile (tile row j, tile column i), sized tm(j) x tn(i)
/// doubles. The storage is left uninitialized.
/// Calling it again for an already allocated tile is a silent no-op.
/// With FORCE_NUMA the tile is allocated on the calling thread's local NUMA
/// node via libnuma; otherwise it is allocated with new[].
void matrix_desc::initialize(int j, int i)
{
    const int position = j + i*col_size;
    if (data[position])
    {
        return;  // already allocated
    }

    // Compute the size in std::size_t: the former int expression
    // tn(i)*tm(j)*8 overflowed for tiles of 2^28 doubles or more.
    const std::size_t elements =
        static_cast<std::size_t>(tn(i)) * static_cast<std::size_t>(tm(j));

#ifdef FORCE_NUMA
    data[position] = static_cast<double*>(numa_alloc_local(elements * sizeof(double)));
    // Alternatively force interleaved memory allocation (no NUMA optimizations possible):
    // data[position] = static_cast<double*>(numa_alloc_interleaved(elements * sizeof(double)));
#else
    data[position] = new double[elements];
#endif

    // numa_alloc_local reports failure with nullptr; plain new[] throws
    // std::bad_alloc instead, so this check only triggers in the NUMA build.
    if (data[position] == nullptr) std::cout << "error: not enough memory" << std::endl;
}

/// Returns the storage of tile (tile row m, tile column n).
/// Coordinates outside the mt x nt tile grid yield nullptr silently.
/// An in-range but not yet initialized tile also yields nullptr, after an
/// error message on std::cout.
double* matrix_desc::operator() (std::size_t m, std::size_t n)
{
    const bool inside_grid = n < static_cast<std::size_t>(nt)
                          && m < static_cast<std::size_t>(mt);
    if (!inside_grid)
    {
        return nullptr;
    }

    // Tiles are stored column by column, col_size tiles per tile column.
    double* const tile = data[m + n * col_size];
    if (tile == nullptr)
    {
        std::cout << "error: tile not initialized" << std::endl;
    }
    return tile;
}

double matrix_desc::get(int j, int i)
{
    int ms = tm(j);
    int ti  = i % nb;
    int tti = i / nb;
    int tj  = j % mb;
    int ttj = j / mb;

    double* block = (*this)(ttj, tti);
    block += ti*ms + tj;
    return *block;
}

/// Width (number of columns) of tile column i: nb for every tile column
/// except the last, which holds the remaining n - i*nb columns.
int matrix_desc::tn(std::size_t i)
{
    const bool last_tile_column = (i == nt - 1);
    if (!last_tile_column)
    {
        return nb;
    }
    return n - i*nb;
}

/// Height (number of rows) of tile row j: mb for every tile row except the
/// last, which holds the remaining m - j*mb rows.
int matrix_desc::tm(std::size_t j)
{
    const bool last_tile_row = (j == mt-1);
    if (!last_tile_row)
    {
        return mb;
    }
    return m - j*mb;
}
