/*
g++ -o matrix_prod.exe matrix_prod.cpp -O2 -std=c++11 -fmax-errors=1 -ftree-vectorize -funroll-loops -fopenmp -lgsl -lgslcblas -lblas
*/

#include <string>
using namespace std;
#include "matrix.h"
#include "timer.h"

using namespace base;
using namespace mylib;

// Side length of the square matrices built by every test.
// main() overwrites it with the second command-line argument when given.
u32 dimension = 1024;
// Thread count requested by the user (third command-line argument).
// main() forwards it to the OpenMP / MKL runtime when the corresponding
// macros are defined; defaults to a single thread.
int max_threads{1};

/**
 * test sum of two matrices using sum_impl1 implementation
 */
void test_sum1() {
	cout << "test matrix addition with sum_impl1" << endl;
	
	Matrix<double> a(dimension);
	Matrix<double> b(dimension);
	Matrix<double> c(dimension);

	a.fill(1);
	b.fill(2);

	Timer timer;

	timer.start();
	sum_impl1(c, a, b);
	timer.stop();
	cout << "time=" << timer << endl;
	cout << "accumulate=" << c.get_data().sum() << endl;
}

/**
 * test sum of two matrices using sum_impl2 implementation
 */
void test_sum2() {
	cout << "test matrix addition with sum_impl2" << endl;
	
	Matrix<double> a(dimension);
	Matrix<double> b(dimension);
	Matrix<double> c(dimension);

	a.fill(1);
	b.fill(2);

	Timer timer;

	timer.start();
	sum_impl2(c, a, b);
	timer.stop();
	cout << "time=" << timer << endl;
	cout << "accumulate=" << c.get_data().sum() << endl;
}

/**
 * test multiplication of two matrices using mul_impl1 implementation
 */
void test_mul1() {
	cout << "test matrix multiplication with mul_impl1" << endl;
	
	Matrix<double> a(dimension);
	Matrix<double> b(dimension);
	Matrix<double> c(dimension);

	a.fill(1);
	b.fill(2);

	cout << "a.address=" << hex << &a.get_data()[0] << endl;
	cout << "b.address=" << hex << &b.get_data()[0] << endl;
	cout << "c.address=" << hex << &c.get_data()[0] << endl;
	cout << dec;
	
	Timer timer;

	timer.start();
	mul_impl1(c, a, b);
	timer.stop();
	cout << "time=" << timer << endl;
	cout << "accumulate=" << c.get_data().sum() << endl;
}

/**
 * test multiplication of two matrices using mul_impl2 implementation
 */
void test_mul2() {
	cout << "test matrix multiplication with mul_impl2" << endl;
	
	Matrix<double> a(dimension);
	Matrix<double> b(dimension);
	Matrix<double> c(dimension);

	a.fill(1);
	b.fill(2);

	Timer timer;

	timer.start();
	mul_impl2(c, a, b);
	timer.stop();
	cout << "time=" << timer << endl;
	cout << "accumulate=" << c.get_data().sum() << endl;
}

/**
 * test multiplication of two matrices using mul_impl3 implementation
 */
void test_mul3() {
	cout << "test matrix multiplication with mul_impl3" << endl;
	
	Matrix<double> a(dimension);
	Matrix<double> b(dimension);
	Matrix<double> c(dimension);

	a.fill(1);
	b.fill(2);

	Timer timer;

	timer.start();
	mul_impl3(c, a, b);
	timer.stop();
	cout << "time=" << timer << endl;
	cout << "accumulate=" << c.get_data().sum() << endl;
}

/**
 * test multiplication of two matrices using mul_impl4 implementation
 */
void test_mul4() {
	cout << "test matrix multiplication with mul_impl4" << endl;
	
	Matrix<double> a(dimension);
	Matrix<double> b(dimension);
	Matrix<double> c(dimension);

	a.fill(1);
	b.fill(2);

	Timer timer;

	timer.start();
	mul_impl4(c, a, b);
	timer.stop();
	cout << "time=" << timer << endl;
	cout << "accumulate=" << c.get_data().sum() << endl;
}

/**
 * test multiplication of two matrices using mul_impl5 implementation
 */
void test_mul5() {
	cout << "test matrix multiplication with mul_impl5" << endl;
	
	Matrix<double> a(dimension);
	Matrix<double> b(dimension);
	Matrix<double> c(dimension);

	a.fill(1);
	b.fill(2);

	Timer timer;

	timer.start();
	mul_impl5(c, a, b);
	timer.stop();
	cout << "time=" << timer << endl;
	cout << "accumulate=" << c.get_data().sum() << endl;
}


/**
 * test matrix computation
 */
void test_mul_op() {
	cout << "test matrix multiplication using operators" << endl;
	
	Matrix<double> a(dimension);
	Matrix<double> b(dimension);
	Matrix<double> c(dimension);
	Matrix<double> d(dimension);
	Matrix<double> r1(dimension);
	Matrix<double> r2(dimension);

	a.fill(1);
	b.fill(2);
	c.fill(3);
	d.fill(-1);

	Timer timer;

	timer.start();
	r1 = a * b + c * d;
	timer.stop();
	cout << "r1.time=" << timer << endl;
	cout << "r1.accumulate=" << r1.get_data().sum() << endl;
		
	Matrix<double> t1(dimension);
	Matrix<double> t2(dimension);
		
	timer.start();	
	mul_impl5(t1, a , b);
	mul_impl5(t2, c , d);
	sum_impl2(r2, t1, t2); 
	timer.stop();
	cout << "r2.time=" << timer << endl;
	cout << "r2.accumulate=" << r2.get_data().sum() << endl;
}

// Run every individual test once, in a fixed order:
// the two sums, the five multiplications, then the operator test.
void test_all() {
	typedef void (*test_fn)();
	const test_fn sequence[] = {
		test_sum1, test_sum2,
		test_mul1, test_mul2, test_mul3, test_mul4, test_mul5,
		test_mul_op
	};
	for (test_fn run : sequence) {
		run();
	}
}

// One selectable test: the name typed on the command line and the
// parameterless function to invoke for it.
struct method_t {
	std::string name;      // identifier compared against the first argument
	void (*ptr_method)();  // function to call; nullptr marks the table end
};

// Table of selectable tests. get_method() returns an index into this
// array and usage() lists its names; both stop at the first entry whose
// function pointer is null.
method_t methods[] = {
	{ "all",   &test_all    },
	{ "sum1",  &test_sum1   },
	{ "sum2",  &test_sum2   },
	{ "mul1",  &test_mul1   },
	{ "mul2",  &test_mul2   },
	{ "mul3",  &test_mul3   },
	{ "mul4",  &test_mul4   },
	{ "mul5",  &test_mul5   },
	{ "mulop", &test_mul_op },
	{ "",      nullptr      }  // sentinel terminating the table
};


/**
 * Find a test by name in the `methods` table.
 *
 * @param s name to look for
 * @return index of the first entry whose name equals s, or -1 when the
 *         end-of-table sentinel is reached without a match
 */
int get_method(string s) {
	int index = 0;
	while (methods[index].ptr_method != nullptr) {
		if (methods[index].name == s) {
			return index;
		}
		++index;
	}
	return -1;
}


/**
 * Print the command-line help on standard output and terminate.
 *
 * The help lists every name found in the `methods` table (up to the
 * end-of-table sentinel) followed by a description of the size and
 * thread-count arguments. This function does not return: it ends with
 * exit(EXIT_FAILURE).
 *
 * @param argv program argument vector; only argv[0] is read
 */
void usage(char *argv[]) {
	cout << argv[0] << " [method] [size] [num-threads]" << endl;
	cout << "- first argument  is the method and must be chosen";
	cout << " between: ";
	for (int i=0; methods[i].ptr_method != nullptr; ++i) {
		cout << methods[i].name << " ";
	}
	cout << endl;
	cout << "- second argument is the size of the square matrix" << endl;
	cout << "- third argument is the number of threads used by OpenMP" << endl;
	exit(EXIT_FAILURE);
}

/**
 *
 */
int main(int argc, char *argv[]) {
	int method=1;

	if (argc > 1) {
		method = get_method(argv[1]);
		if (method == -1) {
			usage(argv); 
		}
	}

	if (argc > 2) {
		dimension = atoi(argv[2]);
	}
	
	if (argc > 3) {
		max_threads = atoi(argv[3]);
	}
	
	cout << "dimension=" << dimension << endl;
	
	#if defined(CBLAS) || defined(GSL_CBLAS)
	omp_set_num_threads(max_threads);
	omp_set_dynamic(0);
	#endif
	
	#ifdef MKL
	mkl_set_threading_layer(MKL_THREADING_INTEL); 
	cout << "MKL " << mkl_get_version << endl;
	mkl_set_num_threads(max_threads);
	cout << "mkl max threads=" << mkl_get_max_threads() << endl;
	#endif
	
	methods[method].ptr_method();

	return 0;
}
