#include "kernel.h"

// THE FUNCTIONS DEFINED IN THIS FILE ARE IN PART MODIFIED VERSIONS OF
// FUNCTIONS DISTRIBUTED WITHIN THE PLASMA LIBRARY (version 2.6)
// THEY WERE ORIGINALLY PUBLISHED UNDER THE FOLLOWING LICENSE:
/*
-- Innovative Computing Laboratory
-- Electrical Engineering and Computer Science Department
-- University of Tennessee
-- (C) Copyright 2008-2010

Redistribution  and  use  in  source and binary forms, with or without
modification,  are  permitted  provided  that the following conditions
are met:

* Redistributions  of  source  code  must  retain  the above copyright
  notice,  this  list  of  conditions  and  the  following  disclaimer.
* Redistributions  in  binary  form must reproduce the above copyright
  notice,  this list of conditions and the following disclaimer in the
  documentation  and/or other materials provided with the distribution.
* Neither  the  name of the University of Tennessee, Knoxville nor the
  names of its contributors may be used to endorse or promote products
  derived from this software without specific prior written permission.

THIS  SOFTWARE  IS  PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS''  AND  ANY  EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED  TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A  PARTICULAR  PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL,  EXEMPLARY,  OR  CONSEQUENTIAL  DAMAGES  (INCLUDING,  BUT NOT
LIMITED  TO,  PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA,  OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY  OF  LIABILITY,  WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF  THIS  SOFTWARE,  EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

/**
 * Builds the exchange area in its "not initialized" state: the finished flag
 * is cleared and the counter is negative until initialize() is called.
 * Also caches in sfmin the value dlamch returns for the 'S' query
 * (presumably LAPACK's "safe minimum" -- confirm against the dlamch wrapper).
 */
kernel::exchange_field::exchange_field()
    : fin_flag(false)
    , counter(-1)
{
    // dlamch takes its query character by pointer, so it needs an lvalue.
    char query = 'S';
    sfmin = dlamch(&query);
}

/**
 * Prepares the exchange area for a panel factorization.
 *
 * @param t_n  number of slots in the value array (one per participating thread)
 */
void kernel::exchange_field::initialize(int t_n)
{
    // Every slot starts at -1.0.
    value.assign(t_n, -1.0);

    // Must stay last: compute_panel() lets the other threads proceed as soon
    // as the counter is no longer negative.
    counter = 0;
}

/**
 * Dispatches one task of a meta task to the routine matching its type.
 *
 * @param meta     the meta task; meta->type selects the routine
 * @param task_id  index of the task inside the meta task
 *
 * An unknown type is only reported on std::cout; nothing is computed.
 */
void kernel::compute(meta_task_ptr meta, int task_id)
{
    switch (meta->type)
    {
    case PANEL:
        compute_panel(meta, task_id);
        return;
    case B:
        compute_b(meta, task_id);
        return;
    case SCHUR_COMPLEMENT:
        compute_schur(meta, task_id);
        return;
    case BEHIND_PANEL_UPDATE:
        compute_behind(meta, task_id);
        return;
    case RANDOMIZE:
        compute_randomize(meta, task_id);
        return;
    default:
        break;
    }

    std::cout << "error: unknown meta task type" << std::endl;
}

/** Returns the smaller of the two integers. */
inline int min( int i1, int i2 )
{
    if (i2 <= i1)
    {
        return i2;
    }
    return i1;
}

/**
 * "B" task: updates one tile column to the right of panel k.
 *
 * The tile column meta->x + task_id, from tile row meta->k downwards, first
 * gets the panel's row interchanges applied; its top tile is then overwritten
 * by the solution of a unit lower-triangular system with the diagonal tile
 * A(k, k).
 *
 * @param meta     meta task (matrix, panel index k, first tile column x)
 * @param task_id  offset of the tile column handled by this task
 */
void kernel::compute_b(meta_task_ptr meta, int task_id)
{
    matrix_desc& A = meta->matrix;

    // One tile column, starting at the panel's tile row.
    matrix_desc column(A, meta->k, meta->x+task_id, A.mt-meta->k, 1);

    // Height of the top tile: number of swaps, order of the triangular
    // system and leading dimension of both operands.
    const int top_rows = column.tm(0);

    dlaswp_tiled(column, top_rows);

    double* diag_tile = A(meta->k, meta->k);
    double* top_tile  = column(0, 0);

    cblas_dtrsm( CblasColMajor, CblasLeft, CblasLower,
                 CblasNoTrans, CblasUnit,
                 top_rows, column.n, 1.0,
                 diag_tile, top_rows,
                 top_tile,  top_rows );
}


/**
 * Schur-complement task: updates one trailing tile with
 *
 *     A(j, i) <- A(j, i) - A(j, k) * A(k, i)
 *
 * where k = meta->k, j is the tile row and i the tile column of the target.
 *
 * @param meta     meta task (matrix, panel index k, tile origin x/y and
 *                 meta_y, the number of tile rows in the meta task)
 * @param task_id  column-major index of the tile inside the meta task
 */
void kernel::compute_schur(meta_task_ptr meta, int task_id)
{
    // Tasks are numbered column-major over the meta task's tile grid.
    const int i = meta->x + task_id / meta->meta_y;   // tile column
    const int j = meta->y + task_id % meta->meta_y;   // tile row

    matrix_desc& A = meta->matrix;
    const int ni = A.tn(i);          // columns of the target tile
    const int mj = A.tm(j);          // rows of the target tile (and of A(j, k))
    const int mk = A.tm(meta->k);    // rows of A(k, i), i.e. its leading dimension

    // Named CBLAS enumerators instead of the former "(CBLAS_TRANSPOSE) 111"
    // casts; CblasNoTrans has exactly that value.
    cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
                mj, ni, A.nb,
                -1.0, A( j      , meta->k), mj,
                      A( meta->k, i      ), mk,
                 1.0, A( j      , i      ), mj);
}

/**
 * "Behind panel" task: applies the row interchanges of panel k to the tile
 * column meta->x + task_id (from tile row k downwards). No arithmetic is done.
 *
 * @param meta     meta task (matrix, panel index k, first tile column x)
 * @param task_id  offset of the tile column handled by this task
 */
void kernel::compute_behind(meta_task_ptr meta, int task_id)
{
    matrix_desc& A = meta->matrix;
    matrix_desc column(A, meta->k, meta->x+task_id, A.mt-meta->k, 1);

    // The swap count is the smaller dimension of the diagonal tile.
    const int diag_rows = A.tm(meta->k);
    const int diag_cols = A.tn(meta->k);
    const int swap_count = (diag_rows < diag_cols) ? diag_rows : diag_cols;

    dlaswp_tiled(column, swap_count);
}

/**
 * Applies the first i2 row interchanges stored in mat.ipiv to all mat.n
 * columns of the tiled sub-matrix mat.
 *
 * Pivot entries are 1-based and include the offset mat.tom*mat.mb; after
 * removing both, entry r names the row (relative to the top of mat) that is
 * exchanged with row r of the top tile.
 *
 * @param mat  sub-matrix descriptor (tiles, pivot array, offsets)
 * @param i2   number of pivot entries to apply
 */
void kernel::dlaswp_tiled(matrix_desc& mat, int i2)
{
    const int  top_ld = mat.tm(0);
    double*    top    = mat(0,0);
    const int* pivots = mat.ipiv;
    const int  shift  = mat.tom*mat.mb;

    for (int row = 0; row < i2; ++row)
    {
        const int target = pivots[row] - shift - 1;
        if (target == row)
        {
            continue;   // row stays where it is
        }

        // Locate the target row: tile index and position inside that tile.
        const int tile      = target / mat.mb;
        const int in_tile   = target % mat.mb;
        const int target_ld = mat.tm(tile);

        cblas_dswap(mat.n, top          + row,     top_ld,
                           mat(tile, 0) + in_tile, target_ld );
    }
}

/**
 * Panel task: factorizes tile column k (from tile row k downwards) together
 * with the other meta->nlocal - 1 tasks of the same meta task.
 *
 * The tile rows of the panel are split into contiguous ranges, one per task;
 * the first (A.mt % nlocal) tasks get one tile more than the others. Task 0
 * sets up the shared exchange area before anyone enters the recursion and
 * marks it unused again afterwards.
 *
 * @param meta     meta task (matrix, panel index k, task count nlocal)
 * @param task_id  index of this task, 0 <= task_id < meta->nlocal
 */
void kernel::compute_panel(std::shared_ptr<meta_task> meta, int task_id)
{
    // Descriptor of the panel: a single tile column.
    matrix_desc A(meta->matrix, meta->k, meta->k, meta->matrix.mt - meta->k, 1);

    const int thread_count = meta->nlocal;
    const int minMN = min(A.m, A.n);

    if (thread_count > A.mt) std::cout << "error: more threads than tiles in panel task" << std::endl;

    // Vertical split of the A.mt tile rows.
    const int base  = A.mt / thread_count;   // tiles every task gets
    const int extra = A.mt % thread_count;   // tasks that get one more

    int first_tile, last_tile;
    if (task_id >= extra)
    {
        first_tile = extra*(base+1) + (task_id-extra)*base;
        last_tile  = min(first_tile + base, A.mt);
    }
    else
    {
        first_tile = task_id * (base+1);
        last_tile  = first_tile + (base+1);
    }

    // Task 0 initializes the exchange area; the others spin until its
    // counter is no longer negative.
    if (task_id != 0)
    {
        while (ex.counter.load() < 0) { }
    }
    else
    {
        ex.initialize(thread_count);
    }

    double pivot;

    panel_recursion(A, A.ipiv, &pivot,
                    task_id, thread_count, 0, minMN, first_tile, last_tile);

    // A panel wider than it is high would need an extra update (we only
    // compute square matrices, so this is not expected to happen).
    if ( A.n > minMN ) {
        std::cout << "error: panel update needed but not yet implemented" << std::endl;
    }

    // Mark the exchange area as uninitialized for the next panel.
    if (task_id == 0) ex.counter = -1;
}

/**
 * Recursive, multi-threaded LU factorization with partial pivoting of the
 * columns [col, col+width) of the tile-column panel A (column-major tiles).
 *
 * All t_n threads of the panel call this function with the same A, col and
 * width; they rendezvous through barrier() and exchange_max(). Thread 0 does
 * the row swaps and the triangular solve on the top tile; every thread
 * updates the tile rows [ft, lt) it owns and takes part in the pivot search.
 *
 * width > 1 : factor the left half, update the right half with it, select
 *             the pivot of the first right-half column, factor the right
 *             half, then (thread 0) apply the right half's swaps to the left.
 * width == 1: for col == 0 select the very first pivot; for the last column
 *             of the panel scale the column by the reciprocal of the pivot.
 *
 * @param A      descriptor of the panel
 * @param ipiv   pivot array; the entry for a column receives the selected
 *               row index (1-based, including the offset A.tom*A.mb)
 * @param pivot  in/out: value of the pivot most recently agreed on
 * @param t_id   index of the calling thread
 * @param t_n    number of threads working on the panel
 * @param col    first column of the sub-panel
 * @param width  number of columns of the sub-panel
 * @param ft     first tile row owned by this thread
 * @param lt     one past the last tile row owned by this thread
 */
void kernel::panel_recursion(matrix_desc& A, int* ipiv, double* pivot,
                             int t_id, int t_n, int col, int width, int ft, int lt)
{
    /* leading dimension of the top tile, pointer to its column `col`,
       and the row offset contained in the pivot indices */
    int ldft = A.tm(0);
	double *Atop = A(0, 0) + col*ldft;
	int offset = A.tom*A.mb;

	double pivval;
	int piv_sf;
	int max_i, max_it;
	double tmp1, abstmp1;
	double tmp2 = 0.;

    if (width > 1)
    {
        /* split into a left half of n1 and a right half of n2 columns */
        int n1 = width/2;
        int n2 = width - n1;
        double* Atop2 = Atop + n1*ldft;
		
        /* factor the left half */
        panel_recursion(A, ipiv, pivot, t_id, t_n, col, n1, ft, lt);

		double*  U;

        if (t_id == 0)
        {
            /* swap to the right: apply the left half's row interchanges
               to the n2 columns of the right half */
            int *lipiv = ipiv + col;
            int idxMax = col + n1;
            for (int j = col; j < idxMax; ++j, ++lipiv)
            {
                int ip = (*lipiv) - offset -1;
                if (ip != j)
                {
                    int it = ip / A.mb;
                    int i  = ip % A.mb;
                    int ld = A.tm(it);
                    cblas_dswap(n2, Atop2                    + j, ldft,
                                    A(it,0) + (col+n1)*ld + i, ld   );
                }
            }

            /* trsm on the upper part */
            U = Atop2 + col;
            
            cblas_dtrsm( CblasColMajor, CblasLeft, CblasLower,
                         CblasNoTrans, CblasUnit,
                         n1, n2, (1.0),
                         Atop+col, ldft,
                         U,        ldft );

            /* signal to other threads that they can start update */
            barrier(t_id,t_n);

            pivval = *pivot;
            
            if ( pivval == 0.0 )
            {
                std::cout << "error: pivot value = 0 in column:" << col+n1 << std::endl;
                return;
            }
            else
            {
                /* multiply by the reciprocal only when the pivot is not
                   smaller than sfmin; otherwise divide element-wise */
                if ( std::fabs(pivval) >= ex.sfmin )
                {
                    piv_sf = 1;
                    pivval = 1.0/pivval;
                }
                else
                {
                    piv_sf = 0;
                }
            }

            /* first tile */
            {
                /* rows below the n1 x n1 triangle in the top tile */
                double* L = Atop + col + n1;
                int tmpM = min(ldft, A.m) - col - n1;

                /* scale last column of L */
                if ( piv_sf )
                {
                    cblas_dscal( tmpM, pivval, L+(n1-1)*ldft, 1);
                }
                else
                {
                    Atop2 = L+(n1-1)*ldft;
                    for (int i = 0; i < tmpM; i++, Atop2++) *Atop2 = *Atop2 / pivval;
                }

                /* Apply the GEMM */
                cblas_dgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
                             tmpM, n2, n1,
                             (-1.0), L,    ldft,
                                     U,    ldft,
                              (1.0), U+n1, ldft);

                /* Search Max in first column of U + n1 */
                tmp2    = U[n1];   /* current diagonal entry, handed to exchange_max */
                max_it  = ft;
                max_i   = cblas_idamax( tmpM, U+n1, 1 ) +n1;
                tmp1    = U[max_i];
                abstmp1 = std::fabs(tmp1);
                max_i  += col;     /* row index inside the tile */
            }

        }
        else  /* ******************************************** t_id != 0 ********** */
        {
			pivval = *pivot;
			/* NOTE(review): this return happens before the barrier below, while
			   thread 0 only checks the pivot after its barrier call -- on a zero
			   pivot thread 0 appears to wait forever. TODO confirm. */
			if (pivval == 0.0) 
			{
                std::cout << "error: pivval == 0 in column:" <<  col + n1 << std::endl;
				return;
			}
			else 
			{
                if (std::fabs(pivval) >= ex.sfmin)
				{
					piv_sf = 1;
					pivval = 1.0 / pivval;
				}
				else 
				{
					piv_sf = 0;
				}
			}

			/* first tile owned by this thread: leading dimension, pointer to
			   column `col`, and number of rows (the last tile may be shorter) */
			int     ld = A.tm(ft);
			double*  L = A(ft, 0) + col * ld;
			int     lm = (ft == A.mt - 1) ? A.m - ft * A.mb : A.mb;
			
			U = Atop2 + col;

			/* First tile */
			/* Scale last column of L */
			if (piv_sf) 
			{
				cblas_dscal(lm, (pivval), L + (n1 - 1)*ld, 1);
			}
			else 
			{
				Atop2 = L + (n1 - 1)*ld;
				for (int i = 0; i < lm; i++, Atop2++)
					*Atop2 = *Atop2 / pivval;
			}

			/* Wait for pivoting and triangular solve to be finished
			* before actually starting the update */
            barrier(t_id, t_n);

			/* Apply the GEMM */
            cblas_dgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
				lm, n2, n1,
				(-1.0), L, ld,
				U, ldft,
				( 1.0), L + n1*ld, ld);

			/* Search Max in first column of L+n1*ld */
			max_it = ft;
			max_i = cblas_idamax(lm, L + n1*ld, 1);
			tmp1 = L[n1*ld + max_i];
			abstmp1 = fabs(tmp1);
        }

		/* Update the other blocks */
		for (int it = ft + 1; it < lt; it++)
		{
			int    ld = A.tm(it);
			double* L = A(it, 0) + col * ld;
			int    lm = (it == A.mt - 1) ? A.m - it * A.mb : A.mb;

			/* Scale last column of L */
			if (piv_sf) 
			{
				cblas_dscal(lm, (pivval), L + (n1 - 1)*ld, 1);
			}
			else 
			{
				int i;
				Atop2 = L + (n1 - 1)*ld;
				for (i = 0; i < lm; i++, Atop2++)
					*Atop2 = *Atop2 / pivval;
			}

			/* Apply the GEMM */
            cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
				lm, n2, n1,
				(-1.0), L        , ld,
				        U        , ldft,
				(1.0) , L + n1*ld, ld);

			/* Search the max on the first column of L+n1*ld */
			int jp = cblas_idamax(lm, L + n1*ld, 1);
			if (fabs(L[n1*ld + jp]) > abstmp1) 
			{
				tmp1 = L[n1*ld + jp];
				abstmp1 = fabs(tmp1);
				max_i = jp;
				max_it = it;
			}
		}

		/* Agree on the pivot of column col+n1: tmp1 is this thread's candidate,
		   jp its row (with offset). Afterwards *pivot holds the winning value,
		   thwin the winning thread and tmp2 the diagonal value of thread 0. */
		int thwin;
		int jp = offset + max_it*A.mb + max_i;

        exchange_max(tmp1, t_id, t_n, &thwin, &tmp2, pivot, jp+1, ipiv + col + n1, col+n1);

		if (t_id == 0) 
		{
			U[n1] = *pivot; /* all threads have the pivot element: no need for synchronization */
		}
		if (thwin == t_id) /* the thread that owns the best pivot */
		{ 
			if (jp - offset != col + n1) /* if there is a need to exchange the pivot */
			{
				/* the old diagonal value goes where the pivot was found */
				int ld = A.tm( max_it );
				Atop2 = A(max_it, 0) + (col + n1)* ld + max_i;
				*Atop2 = tmp2;
			}
		}

        /* factor the right half */
        panel_recursion(A, ipiv, pivot, t_id, t_n, col + n1, n2, ft, lt);
		
		if (t_id == 0)
		{
			/* Swap to the left */
			int *lipiv = ipiv + col + n1;
			int idxMax = col + width;
			for (int j = col + n1; j < idxMax; ++j, ++lipiv) 
			{
				int ip = (*lipiv) - offset - 1;
				if (ip != j)
				{
					int it = ip / A.mb;
					int i  = ip % A.mb;
					int ld = A.tm(it);
					cblas_dswap(n1, Atop + j, ldft,
						A(it, 0) + col*ld + i, ld);
				}
			}
		}

    }
	else if (width == 1)
	{

		/* The pivot of the very first column has not been selected by any
		   enclosing recursion level, so it is searched for here. */
		if (col == 0)
		{

			if (t_id == 0)
				tmp2 = Atop[col];

			/* First tmp1 */
			int ld = A.tm(ft);
			double* Atop2 = A(ft, 0);
            int lm = (ft == A.mt - 1) ? A.m - ft * A.mb : A.mb;
            max_it = ft;
			max_i = cblas_idamax(lm, Atop2, 1);
			tmp1 = Atop2[max_i];
			abstmp1 = fabs(tmp1);

			/* Update */
			for (int it = ft + 1; it < lt; it++)
			{
				Atop2 = A(it, 0);
                int lm = it == A.mt - 1 ? A.m - it * A.mb : A.mb;
                int jp = cblas_idamax(lm, Atop2, 1);
				if (std::fabs(Atop2[jp]) > abstmp1) {
					tmp1 = Atop2[jp];
					abstmp1 = fabs(tmp1);
					max_i = jp;
					max_it = it;
				}
			}

			int thwin;
            int jp = offset + max_it*A.mb + max_i;

            exchange_max(tmp1, t_id, t_n, &thwin, &tmp2, pivot, jp+1, ipiv + col, col);

			if (t_id == 0) 
            {
				Atop[0] = *pivot;  /* all threads have the pivot element: no need for synchronization */
			}
			if (thwin == t_id)  /* the thread that owns the best pivot */
			{
				if (jp - offset != 0)  /* if there is a need to exchange the pivot */
				{
					Atop2 = A(max_it, 0) + max_i;
					*Atop2 = tmp2;
				}
			}
		}

        /* every thread of the panel meets here once per single-column call */
        barrier(t_id, t_n);

		/* If it is the last column, we just scale */
        if (col == min(A.m, A.n) - 1)
		{
			double pivval = *pivot;

			if (pivval != 0.0)
			{
				if (t_id == 0)
				{
                    if (std::fabs(pivval) >= ex.sfmin)
                    {
						pivval = 1.0 / pivval;

						/*
						* We assume that we never enter the function with m == A.mt-1
						* because it means that there is only one thread
						*/


                        /* top tile: only the entries below the diagonal */
                        int lm = (ft == A.mt - 1) ? A.m - ft * A.mb : A.mb;
						cblas_dscal(lm - col - 1, (pivval), Atop + col + 1, 1);

						for (int it = ft + 1; it < lt; it++)
						{
							int ld = A.tm(it);
							double* Atop2 = A(it, 0) + col * ld;
                            int lm = (it == A.mt - 1) ? A.m - it * A.mb : A.mb;
							cblas_dscal(lm, (pivval), Atop2, 1);
						}


					}
                    else   /* *********** fabs(pivval) < ex.sfmin ***************** */
					{


						/*
						* We assume that we never enter the function with m == A.mt-1
						* because it means that there is only one thread
						*/
						double* Atop2 = Atop + col + 1;
                        int lm = ft == A.mt - 1 ? A.m - ft * A.mb : A.mb;
						for (int i = 0; i < lm - col - 1; i++, Atop2++)
							*Atop2 = *Atop2 / pivval;

						for (int it = ft + 1; it < lt; it++)
						{
							int ld = A.tm(it);
							Atop2 = A(it, 0) + col * ld;
                            int lm = it == A.mt - 1 ? A.m - it * A.mb : A.mb;

							for (int i = 0; i < lm; i++, Atop2++)
								*Atop2 = *Atop2 / pivval;
						}

                    }
				}
				else     /* ************** t_id != 0 ******** */
				{
                    /* the other threads scale all rows of all their tiles */
                    if (std::fabs(pivval) >= ex.sfmin)
                    {
						pivval = 1.0 / pivval;

						for (int it = ft; it < lt; it++)
						{
							int ld = A.tm(it);
                            double* Atop2 = A(it, 0) + col * ld;
                            int lm = it == A.mt - 1 ? A.m - it * A.mb : A.mb;
							cblas_dscal(lm, (pivval), Atop2, 1);
						}
					}
					else
					{
						/*
						* We assume that we never enter the function with m == A.mt-1
						* because it means that there is only one thread
						*/
						for (int it = ft; it < lt; it++)
						{
							int ld = A.tm(it);
							double* Atop2 = A(it, 0) + col * ld;
                            int lm = it == A.mt - 1 ? A.m - it * A.mb : A.mb;

							for (int i = 0; i < lm; i++, Atop2++)
								*Atop2 = *Atop2 / pivval;
						}
					}
				}
			}
            else     /* ***************** pivval == 0 ********************** */
			{
				std::cout << "error: pivval = 0 in column:" << col + 1 << std::endl;
				return;
			}
		}
	}
}


/**
 * All-thread reduction that selects the pivot of one column.
 *
 * Each thread publishes its local candidate; the last thread to arrive picks
 * the candidate of largest absolute value (the lowest thread index wins a
 * tie) and releases the waiting threads. A second phase keeps any thread
 * from leaving before all of them have read the result.
 *
 * @param localamx   candidate of the calling thread
 * @param t_id       index of the calling thread
 * @param t_n        number of participating threads
 * @param t_win      out: index of the thread that owns the selected pivot
 * @param diagvalue  in (thread 0 only): current diagonal value;
 *                   out (all threads): the diagonal value thread 0 supplied
 * @param globalamx  out: the selected pivot value
 * @param pividx     row index the calling thread proposes for the pivot
 * @param ipiv       the winning thread stores pividx in ipiv[0]
 * @param col        not used
 */
void kernel::exchange_max(double localamx, int t_id, int t_n, int *t_win, double *diagvalue, double *globalamx, int pividx, int *ipiv, int col)
{
    /* publish the local candidate; thread 0 also publishes the diagonal */
    ex.value.at(t_id) = localamx;
    if (t_id == 0)
    {
        ex.diag = *diagvalue;
    }

    const int arrived = ++(ex.counter);

    if ( arrived != t_n )
    {
        /* not the last one to arrive: wait for the result */
        while (!ex.fin_flag) { }

        *t_win      = ex.thread;
        *globalamx  = ex.fin_value;
    }
    else
    {
        /* last one to arrive: every slot of the value array is filled */
        int    winner   = 0;
        double best     = ex.value.at(0);
        double best_abs = std::fabs(best);

        for (int t = 1; t < t_n; ++t)
        {
            const double candidate = ex.value.at(t);
            const double magnitude = std::fabs(candidate);
            if (magnitude > best_abs)
            {
                best_abs = magnitude;
                best     = candidate;
                winner   = t;
            }
        }

        /* make the result visible to everybody */
        ex.fin_value = best;
        ex.thread    = winner;

        *t_win     = winner;
        *globalamx = best;

        /* release the waiting threads */
        ex.fin_flag = true;
    }

    *diagvalue = ex.diag;
    if (*t_win == t_id)
    {
        *ipiv = pividx;
    }

    /* departure: the last thread to leave re-arms the flag */
    const int remaining = --(ex.counter);

    if (remaining != 0)
    {
        while (ex.fin_flag) { }
    }
    else
    {
        ex.fin_flag = false;
    }
}


/**
 * Spinning two-phase barrier for the t_n threads of a panel, built on the
 * counter and flag of the shared exchange area.
 *
 * @param t_id  index of the calling thread (not used)
 * @param t_n   number of threads that have to arrive
 */
void kernel::barrier(int t_id, int t_n)
{
    /* arrival: the last thread raises the flag, the others wait for it */
    const int arrived = ++(ex.counter);
    if ( arrived != t_n )
    {
        while (!ex.fin_flag) { }
    }
    else
    {
        ex.fin_flag = true;
    }

    /* departure: the last thread lowers the flag, the others wait for that */
    const int remaining = --(ex.counter);
    if ( remaining != 0 )
    {
        while (ex.fin_flag) { }
    }
    else
    {
        ex.fin_flag = false;
    }
}

// Parameters of the 64-bit linear congruential generator used to fill tiles:
//     ran = Rnd64_A * ran + Rnd64_C        (unsigned arithmetic, modulo 2^64)
#define Rnd64_A  6364136223846793005ULL
#define Rnd64_C  1ULL
// Approximately 2^-64 (float and double flavour): multiplying a generator
// state by it maps the state to roughly the unit interval.
#define RndF_Mul 5.4210108624275222e-20f
#define RndD_Mul 5.4210108624275222e-20

/**
 * Advances the generator by n steps in O(log n) multiplications.
 *
 * @param n     number of steps to jump ahead
 * @param seed  state to start from
 * @return      the state reached after applying ran = Rnd64_A*ran + Rnd64_C
 *              n times to seed (seed itself for n == 0)
 */
static inline unsigned long long int Rnd64_jump(unsigned long long int n, unsigned long long int seed ) {
  // a_k / c_k describe 2^k generator steps at once: x -> a_k * x + c_k.
  unsigned long long int a_k = Rnd64_A;
  unsigned long long int c_k = Rnd64_C;
  unsigned long long int ran = seed;

  // Walk over the bits of n; a set bit k applies the 2^k-step map.
  for (; n != 0; n >>= 1) {
    if (n & 1)
      ran = a_k * ran + c_k;
    // Doubling the stride: a*(a*x + c) + c == a^2 * x + c*(a + 1).
    c_k *= (a_k + 1);
    a_k *= a_k;
  }

  return ran;
}

/**
 * Randomize task: fills one tile with pseudo-random values.
 *
 * The generator is seeded with meta->k and, for every tile column, jumped
 * ahead to the column-major position (row + column * A.m) of that column's
 * first element in the whole matrix, so the value of an element depends only
 * on the seed and on its position in the matrix.
 *
 * @param meta     meta task (matrix, seed k, tile origin x/y and meta_y)
 * @param task_id  column-major index of the tile inside the meta task
 */
void kernel::compute_randomize(meta_task_ptr meta, int task_id)
{
    const int x = meta->x + task_id / meta->meta_y;   // tile column
    const int y = meta->y + task_id % meta->meta_y;   // tile row

    matrix_desc& A = meta->matrix;
    A.initialize(y, x);   // presumably sets up the tile's storage -- TODO confirm
    double *out = A(y, x);

    const int tile_cols = A.tn(x);
    const int tile_rows = A.tm(y);

    // Position of the tile's first element in the matrix.
    unsigned long long int jump =
        (unsigned long long int)(y*A.mb) + (unsigned long long int)(x*A.nb) * (unsigned long long int)A.m;

    for (int c = 0; c < tile_cols; ++c, jump += A.m)
    {
        unsigned long long int ran = Rnd64_jump( jump, meta->k );
        for (int r = 0; r < tile_rows; ++r)
        {
            // The value is computed in single precision, as before.
            *out++ = 0.5f - ran * RndF_Mul;
            ran    = Rnd64_A * ran + Rnd64_C;
        }
    }
}
