/*
    asm_matprod32 helps compare different implementations of the product
    of two square matrices
    
    Copyright (C) 2019  Jean-Michel RICHER

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

	Contact: jean-michel.richer@univ-angers.fr
	
 */

#include <iostream>
#include <sstream>
#include <cstdlib>
#include <xmmintrin.h>
#include <pmmintrin.h>
#include <stdint.h>
#include <cstring>
#include <iomanip>
#include <getopt.h>
#include <cmath>
#include <vector>
#include <set>
using namespace std;
#include "cpu_timer.h"
#include "signal_handler.h"
#include "cpp_config.h"
#include "common.h"
#include "gnu_gpl.h"

#ifdef __PGIC__
#include <mm_malloc.h>
#endif

/**
 * Print on standard output a single line "compiler=<family>" naming the
 * compiler that built this program: intel, llvm, pgi, gnu or unknown.
 * __GNUC__ is tested last because other compilers (e.g. Clang, Intel)
 * also define it.
 */
void report_compiler() {
#if defined( __INTEL_COMPILER )
	const char *compiler_family = "intel";
#elif defined( __clang__ )
	const char *compiler_family = "llvm";
#elif defined( __PGIC__ )
	const char *compiler_family = "pgi";
#elif defined( __GNUC__ )
	const char *compiler_family = "gnu";
#else
	const char *compiler_family = "unknown";
#endif
	cout << "compiler=" << compiler_family << endl;
}


// ==================================================================
// GLOBAL VARIABLES
// ==================================================================

// operand matrices A and B and result matrix C (size x size each),
// allocated and initialized by allocate_resources()
f32 *A = nullptr;
f32 *B = nullptr;
f32 *C = nullptr;

// index in methods[] of the method to benchmark (option -m)
u32 method_id{1};

// dimension of the square matrices (option -s)
u32 size{2048};

// when true, run validity_test() instead of the benchmark (option -t)
bool test_flag{false};

// identifiers of methods skipped by validity_test() (option -a)
set<u32> test_avoid;

// tile dimension used by the loop-blocking methods (option -b)
u32 BLOCK_DIM{64};

// when true, print A, B and C after the benchmark (option -v)
bool verbose_flag{false};

// if non-empty, name of the method to run; overrides method_id (option -n)
string select_by_name{};

// number of times the selected method is called (option -z)
u32 zillions{1};

/**
 * print matrix
 */
void print_matrix(string name, f32 *m, u32 size) {
	cout << "Matrix " << name << endl;
	for (u32 y=0; y<size; ++y) {
		cout << "line " << y << ": ";
		for (u32 x=0; x<size; ++x) {
			cout << m[y*size+x] << " ";
		}
		cout << endl;
	}
}



// ==================================================================
// Assembly methods
// ==================================================================
// Defined outside this file with C linkage (assembly, per the section
// title). They share the signature of the C++ methods of this file and
// are called through methods[] as method(A, B, C_or_D, size).
extern "C" {
	void mp_asm_fpu(f32 *A, f32 *B, f32 *C, u32 size);
	void mp_asm_fpu_ur4(f32 *A, f32 *B, f32 *C, u32 size);
	void mp_inv_jk_sse(f32 *A, f32 *B, f32 *C, u32 size);
	void mp_inv_jk_avx(f32 *A, f32 *B, f32 *C, u32 size);
}

// ==================================================================
// C implementation with -O2 optimization flags instead of -O3
// ==================================================================
// Defined in another translation unit, presumably so that it can be built
// with different optimization flags than this file (see title above).
extern void mp_inv_jk_O2(f32 *A, f32 *B, f32 *C, u32 size);


/**
 * Allocation of resources
 */
void allocate_resources() {
	u32 size2 = size * size;
	A = (f32 *) _mm_malloc(size2 * sizeof(f32), CPU_MEMORY_ALIGNMENT);
	B = (f32 *) _mm_malloc(size2 * sizeof(f32), CPU_MEMORY_ALIGNMENT);
	C = (f32 *) _mm_malloc(size2 * sizeof(f32), CPU_MEMORY_ALIGNMENT);
	cout << hex;
	cout << "A=" << (f32 *) A << endl;
	cout << "B=" << (f32 *) B << endl;
	cout << "C=" << (f32 *) C << endl;
	cout << dec;

	// initialize matrices
	for (u32 i = 0; i < size2; ++i) {
		A[i] = 1.0 + (i % 2);
		B[i] = -2.0 + (i % 5);
		//A[i] = i % 4 + 1;
		//B[i] = 1;
	}
	memset(C, 0, size2 * sizeof(f32));

}

/**
 * free resources
 */
void free_resources() {
	_mm_free(A);
	_mm_free(B);
	_mm_free(C);
}

/**
 * Reference product C = A * B of size x size matrices. Each C(i,j) is the
 * dot product of row i of A and column j of B, accumulated in a local and
 * then stored, so C is overwritten rather than accumulated into.
 * validity_test() compares every other method against this one.
 */
void mp_reference(f32 *A, f32 *B, f32 *C, u32 size) {
	for (u32 row = 0; row < size; ++row) {
		for (u32 col = 0; col < size; ++col) {
			f32 dot = 0;
			u32 k = 0;
			while (k < size) {
				dot += a(row,k) * b(k,col);
				++k;
			}
			c(row,col) = dot;
		}
	}
}

/**
 * SSE multiply-accumulate of one 4x4 tile: C_tile += A_tile * B_tile.
 * A, B and C point to the top-left element of a tile inside row-major
 * matrices whose rows are `size` floats long. Rows of B and C are read and
 * written with aligned SSE loads/stores, so B, C and every row start must
 * be 16-byte aligned (B/C aligned and size a multiple of 4).
 */
void M4x4_SSE(float *A, float *B, float *C, u32 size) {
	// the four rows of the B tile stay in registers for the whole tile
	const __m128 b0 = _mm_load_ps(B);
	const __m128 b1 = _mm_load_ps(B + size);
	const __m128 b2 = _mm_load_ps(B + 2 * size);
	const __m128 b3 = _mm_load_ps(B + 3 * size);
	for (int r = 0; r < 4; ++r) {
		const float *a_row = A + size * r;
		float *c_row = C + size * r;
		// C(r,:) += (A(r,0)*B(0,:) + A(r,1)*B(1,:)) + (A(r,2)*B(2,:) + A(r,3)*B(3,:))
		const __m128 p01 = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(a_row[0]), b0),
		                              _mm_mul_ps(_mm_set1_ps(a_row[1]), b1));
		const __m128 p23 = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(a_row[2]), b2),
		                              _mm_mul_ps(_mm_set1_ps(a_row[3]), b3));
		_mm_store_ps(c_row, _mm_add_ps(_mm_add_ps(p01, p23), _mm_load_ps(c_row)));
	}
}

/**
 * Compute C += A * B for size x size row-major matrices by walking 4x4
 * tiles and delegating each tile product to M4x4_SSE. C is accumulated
 * into, so the caller must zero it to obtain A * B.
 *
 * M4x4_SSE reads and writes whole 4-float rows with aligned SSE accesses.
 * That is only valid when size is a multiple of 4: otherwise the last tiles
 * run past each row (and past the matrix), and the row starts are not
 * 16-byte aligned. main only enforces size >= 16, so such sizes fall back
 * to a scalar i-k-j loop.
 */
void mp_tile_4x4(f32 *A, f32 *B, f32 *C, u32 size) {
	if (size % 4 != 0) {
		for (u32 i = 0; i < size; ++i) {
			for (u32 k = 0; k < size; ++k) {
				for (u32 j = 0; j < size; ++j) {
					C[i*size+j] += A[i*size+k] * B[k*size+j];
				}
			}
		}
		return;
	}

	for (u32 i = 0; i < size; i+=4) {
		for (u32 j = 0; j < size; j+=4) {
			for (u32 k = 0; k < size; k+=4) {
				M4x4_SSE(&A[i*size+k], &B[k*size+j], &C[i*size+j], size);
			}
		}
	}
}


/**
 * Product C += A * B in i-k-j order (loops j and k of the naive algorithm
 * swapped). The innermost loop walks row k of B and row i of C
 * contiguously. C is accumulated into and must be zeroed beforehand.
 */
void mp_inv_jk(f32 *A, f32 *B, f32 *C, u32 size) {
	for (u32 i = 0; i < size; ++i) {
		f32 *c_row = &c(i,0);
		for (u32 k = 0; k < size; ++k) {
			const f32 *b_row = &b(k,0);
			for (u32 j = 0; j < size; ++j) {
				c_row[j] += a(i,k) * b_row[j];
			}
		}
	}
}

/**
 * Loop-blocked product C += A * B for dim x dim row-major matrices.
 * The column index j and the reduction index k are split into blocks of
 * BLOCK_DIM. For each (j block, k block) pair, every row i is updated.
 * C is accumulated into and must be zeroed beforehand. BLOCK_DIM must be
 * non-zero.
 *
 * The matrices are indexed explicitly with `dim`. The a()/b()/c() macros
 * used elsewhere in this file rely on a parameter named `size`, and here
 * they would silently pick up the global `size` instead of `dim`.
 */
void mp_tile_bxb_v1(f32 *A, f32 *B, f32 *C, u32 dim) {
	for (u32 jj = 0; jj < dim; jj += BLOCK_DIM) {
		const u32 j_end = min(jj + BLOCK_DIM, dim);
		for (u32 kk = 0; kk < dim; kk += BLOCK_DIM) {
			const u32 k_end = min(kk + BLOCK_DIM, dim);
			for (u32 i = 0; i < dim; ++i) {
				for (u32 j = jj; j < j_end; ++j) {
					for (u32 k = kk; k < k_end; ++k) {
						C[i*dim + j] += A[i*dim + k] * B[k*dim + j];
					}
				}
			}
		}
	}
}

/**
 * Blocked product C += A * B with BLOCK_DIM x BLOCK_DIM tiles on all three
 * loops (i, j, k). For each tile triple, the partial dot product of every
 * C(ib,jb) over the current k tile is summed in a local and then added to C.
 * C must be zeroed beforehand. BLOCK_DIM must be non-zero.
 */
void mp_tile_bxb_v2(f32 *A, f32 *B, f32 *C, u32 size) {
	const u32 bs = BLOCK_DIM;
	for (u32 i0 = 0; i0 < size; i0 += bs) {
		const u32 i1 = min(i0 + bs, size);
		for (u32 j0 = 0; j0 < size; j0 += bs) {
			const u32 j1 = min(j0 + bs, size);
			for (u32 k0 = 0; k0 < size; k0 += bs) {
				const u32 k1 = min(k0 + bs, size);
				for (u32 ii = i0; ii < i1; ++ii) {
					for (u32 jj = j0; jj < j1; ++jj) {
						f32 partial = 0;
						for (u32 kk = k0; kk < k1; ++kk) {
							partial += a(ii,kk) * b(kk,jj);
						}
						c(ii,jj) += partial;
					}
				}
			}
		}
	}
}

/**
 * Blocked product C += A * B with BLOCK_DIM x BLOCK_DIM tiles on i and j,
 * and an unblocked k loop. For every k, each row of the (i, j) tile is
 * updated with A(ib,k) * B(k, j-tile). C must be zeroed beforehand.
 * BLOCK_DIM must be non-zero.
 */
void mp_tile_bxb_v3(f32 *A, f32 *B, f32 *C, u32 size) {
	const u32 bs = BLOCK_DIM;
	for (u32 i0 = 0; i0 < size; i0 += bs) {
		const u32 i1 = min(i0 + bs, size);
		for (u32 j0 = 0; j0 < size; j0 += bs) {
			const u32 j1 = min(j0 + bs, size);
			for (u32 k = 0; k < size; ++k) {
				for (u32 ii = i0; ii < i1; ++ii) {
					f32 *a_row = &a(ii,0);
					f32 *c_row = &c(ii,0);
					for (u32 jj = j0; jj < j1; ++jj) {
						c_row[jj] += a_row[k] * b(k,jj);
					}
				}
			}
		}
	}
}

/**
 * Blocked product C += A * B with BLOCK_DIM x BLOCK_DIM tiles on i, j and
 * k. Inside a tile triple the loops run in i-k-j order, so the innermost
 * loop walks a row of B and a row of C. C must be zeroed beforehand.
 * BLOCK_DIM must be non-zero.
 */
void mp_tile_bxb_v4(f32 *A, f32 *B, f32 *C, u32 size) {
	const u32 bs = BLOCK_DIM;
	for (u32 i0 = 0; i0 < size; i0 += bs) {
		const u32 i1 = min(i0 + bs, size);
		for (u32 j0 = 0; j0 < size; j0 += bs) {
			const u32 j1 = min(j0 + bs, size);
			for (u32 k0 = 0; k0 < size; k0 += bs) {
				const u32 k1 = min(k0 + bs, size);
				for (u32 ii = i0; ii < i1; ++ii) {
					// row pointers only depend on ii, not on kk
					f32 *a_row = &a(ii,0);
					f32 *c_row = &c(ii,0);
					for (u32 kk = k0; kk < k1; ++kk) {
						for (u32 jj = j0; jj < j1; ++jj) {
							c_row[jj] += a_row[kk] * b(kk,jj);
						}
					}
				}
			}
		}
	}
}


// ==================================================================
// array of methods to test
// ==================================================================
// The index of an entry is the method identifier used by -m/--method,
// --list and validity_test().
// - Entry 0 is a placeholder, so identifiers start at 1.
// - Entry 1 (mp_reference) is the reference that validity_test() checks
//   all the other methods against.
// - The trailing entry with a nullptr method is the end marker that
//   methods_count(), list_methods() and the test loop stop on.
MethodDeclaration methods[] = {
	{ nullptr, "not defined" },
	add_method(mp_reference),
	add_method(mp_asm_fpu),
	add_method(mp_asm_fpu_ur4),
	add_method(mp_inv_jk_O2),
	add_method(mp_inv_jk),
	add_method(mp_inv_jk_sse),
	add_method(mp_inv_jk_avx),
	add_method(mp_tile_4x4),
	add_method(mp_tile_bxb_v1),
	add_method(mp_tile_bxb_v2),
	add_method(mp_tile_bxb_v3),
	add_method(mp_tile_bxb_v4),
	{ nullptr, "not defined" }
};



/**
 * Check whether two size x size matrices are element-wise equal within an
 * absolute tolerance of 1e-4. If not, report the first differing position
 * in row-major order.
 * @param A pointer to first matrix
 * @param B pointer to second matrix
 * @param size dimension of the matrices
 * @param r set to the row of the first difference (only written on failure)
 * @param c set to the column of the first difference (only written on failure)
 * @return true if all elements match. A NaN in either matrix counts as a
 *         mismatch.
 */
bool are_equal(f32 *A, f32 *B, u32 size, int &r, int &c) {
	for (u32 i = 0; i < size; ++i) {
		for (u32 j = 0; j < size; ++j) {
			const f32 lhs = a(i,j);
			const f32 rhs = b(i,j);
			// !(diff <= tol) instead of (diff > tol): every comparison with
			// NaN is false, so the original test let NaN results pass.
			// The lhs != rhs guard keeps identical infinities equal.
			if ((lhs != rhs) && !(fabs(lhs - rhs) <= 1e-4)) {
				r = i;
				c = j;
				return false;
			}
		}
	}
	return true;
}

/**
 * Perform test to check if all methods return the same result.
 * We compute the product of A * B for each method and compare
 * the result to the product given by the reference method.
 * @param A pointer to input matrix A
 */
void validity_test() {
	cout << "test" << endl;
	f32 *D = (f32 *) _mm_malloc(size * size * sizeof(f32), CPU_MEMORY_ALIGNMENT);
	
	CPUTimer timer;
	timer.start();
	methods[1].method(A, B, C, size); 
	timer.stop();
	
	cout << 1 << " " << setw(25) << methods[1].name << " ";
	cout << setw(20) << timer;
	cout << endl;
	
	for (u32 i = 2; methods[i].method != nullptr; ++i) {
		if (test_avoid.find(i) != test_avoid.end()) continue;
		
		memset(D, 0, size*size*sizeof(f32));
		timer.start();
		methods[i].method(A, B, D, size);
		timer.stop();
		cout << i << " " << setw(25) << methods[i].name << " ";
		cout << setw(20) << timer << " ";
		int row, col;
		if (are_equal(C,D,size,row,col)) {
			cout << " OK";
		} else {
			cout << " !!! FAIL !!! at row=" << row << ", col=" << col << " ";
			cout << D[row*size+col] << " != " << C[row*size+col]; 
		}
		cout << endl;
	}
	free_resources();
	_mm_free(D);
	exit(EXIT_SUCCESS);
}

/**
 * Count the selectable methods: the entries of methods[] after the
 * placeholder at index 0 and before the first entry whose method is nullptr.
 * @return number of methods; valid identifiers are 1..count
 */
u32 methods_count() {
	u32 count = 0;
	for (const MethodDeclaration *entry = &methods[1]; entry->method != nullptr; ++entry) {
		++count;
	}
	return count;
}

/**
 * list methods
 */
void list_methods() {
	u32 i;
	for (i = 1; methods[i].method != nullptr; ++i) {
		cout << "method.id=" << i << ",method.name=" << methods[i].name << endl;
	}
	cout << "methods.count=" << i-1 << endl;
	report_compiler();
	exit(EXIT_SUCCESS);
}

/**
 * Print the manual of the program (name, synopsis, description and every
 * command line option) on standard output, then exit with EXIT_SUCCESS.
 * @param program_name path used to invoke the program; only the part after
 *        the last '/' is displayed
 */
void usage(string program_name) {
	size_t pos = program_name.find_last_of('/');
	if (pos != string::npos) program_name = program_name.substr(pos+1);

	// ANSI escape sequences: ESC[1m starts bold text, ESC[0m resets it.
	// \033 is the standard spelling of ESC (\e is a GNU extension).
	string b_bold = "\033[1m";
	string e_bold = "\033[0m";
	string b_option = "\t\033[1m";
	string e_option = "\033[0m\n\t\t";

	cout << b_bold << "NAME" << e_bold << endl;
	cout << "\t" << program_name << endl << endl;
	cout << b_bold << "SYNOPSIS" << e_bold << endl;
	cout << "\t" << program_name << " [OPTION]..." << endl << endl;
	cout << b_bold << "DESCRIPTION" << e_bold << endl;
	cout << "\tcompute result of matrix product for square matrices" << endl << endl;

	cout << b_option << "-h, --help" << e_option;
	cout << "this message" << endl << endl;

	cout << b_option << "-c, --copying" << e_option;
	cout << "print GNU GPL" << endl << endl;

	cout << b_option << "-l, --list" << e_option;
	cout << "list all methods" << endl << endl;

	cout << b_option << "-m, --method=INT" << e_option;
	cout << "select method given its integer identifier" << endl << endl;

	// the long option registered in main() is "method-name"; getopt_long does
	// not accept "--name" for it
	cout << b_option << "-n, --method-name=STRING" << e_option;
	cout << "select method given its name" << endl << endl;

	cout << b_option << "-s, --size=INT" << e_option;
	cout << "dimension of matrix" << endl << endl;

	cout << b_option << "-t, --test" << e_option;
	cout << "test of all functions for validity" << endl << endl;

	cout << b_option << "-a, --avoid=LIST" << e_option;
	cout << "comma separated list of methods not to test" << endl << endl;

	cout << b_option << "-b, --block=INT" << e_option;
	cout << "dimension of blocking factor, default is " << BLOCK_DIM << endl;
	// second description line, indented like the ones produced by e_option
	cout << "\t\tpossible values are 16, 32, 48, 64" << endl << endl;

	cout << b_option << "-z, --zillions=INT" << e_option;
	cout << "number of times the method is called" << endl << endl;

	cout << b_option << "-v, --verbose" << e_option;
	cout << "print matrices " << endl << endl;

	exit(EXIT_SUCCESS);
}

/**
 * Parse a delimiter-separated list of integers and insert each value into
 * tokens.
 * - A single leading '=' (left by forms such as "-a=1,2") is ignored.
 * - Empty or blank fields ("1,,2", a trailing ',') are skipped instead of
 *   making std::stoi throw.
 * - A non-numeric or out-of-range field is reported on standard error and
 *   terminates the program with EXIT_FAILURE.
 * @param tokens set receiving the parsed values
 * @param s string to parse
 * @param delimiter field separator, ',' by default
 */
void tokenize(set<u32> &tokens, string s, char delimiter=',') {
	if (s.empty()) return;
	if (s[0] == '=') s.erase(0, 1);
	istringstream iss(s);

	string token;
	while (std::getline(iss, token, delimiter)) {
		if (token.find_first_not_of(" \t") == string::npos) continue;
		try {
			tokens.insert(std::stoi(token));
		} catch (const exception &) {
			// std::stoi throws invalid_argument / out_of_range
			cerr << "invalid value '" << token << "' in list '" << s << "'" << endl;
			exit(EXIT_FAILURE);
		}
	}
}

	
/**
 * main program: parse the command line and allocate and initialize the
 * matrices. Then either run validity_test() (-t, which exits), or time
 * `zillions` calls of the selected method and print the elapsed cycles,
 * the method name and a checksum of C.
 */
int main(int argc, char *argv[]) {
	SignalHandler sh;

	string program_name = "asm_matprod32";
	gnu_header(program_name);

	// get parameters
	while (1) {
		int option_index = 0;
		static struct option long_options[] = {
			OPTION_ARG("block", 'b'),
			OPTION_NO_ARG("copying", 'c'),
			OPTION_ARG("size", 's'),
			OPTION_NO_ARG("help", 'h'),
			OPTION_NO_ARG("list", 'l'),
			OPTION_ARG("method", 'm'),
			OPTION_ARG("method-name", 'n'),
			OPTION_NO_ARG("test", 't'),
			OPTION_ARG("avoid", 'a'),
			OPTION_NO_ARG("verbose", 'v'),
			OPTION_ARG("zillions", 'z'),
			{0, 0, 0, 0 }
		};

		int c = getopt_long(argc, argv, "hcb:s:lm:n:tva:z:", long_options, &option_index);
		if (c == -1) break;

		switch (c) {
			case 0:
				cerr << "option " << long_options[option_index].name;
				if (optarg) {
					cerr << " with arg " << optarg;
				}
				cerr << endl;
				break;

			case 'h':
				usage(argv[0]);
				break;

			case 'c':
				gnu_gpl(program_name);
				break;

			case 'b':
				BLOCK_DIM = static_cast<u32>(atoi(optarg));
				break;

			case 's':
				size = static_cast<u32>(atoi(optarg));
				break;

			case 'l':
				list_methods();
				break;

			case 'm':
				method_id = atoi(optarg);
				break;

			case 'n':
				select_by_name = optarg;
				break;

			case 't':
				test_flag = true;
				break;

			case 'a':
				tokenize(test_avoid, optarg);
				break;

			case 'v':
				verbose_flag = true;
				break;

			case 'z':
				zillions = atoi(optarg);
				break;

			default:
				cerr << "Error ! Check command line arguments !" << endl;
				exit(EXIT_FAILURE);
		};
	}

	// check parameters
	if ((method_id < 1) || (method_id > methods_count())) {
		cerr << "method should be inside [1.." << methods_count() << "]" << endl;
		exit(EXIT_FAILURE);
	}

	if (size < 16) {
		cerr << "dimension should be greater or equal than 16" << endl;
		exit(EXIT_FAILURE);
	}

	// the tiled methods step their loops by BLOCK_DIM, so 0 (e.g. from
	// "-b 0" or a non-numeric "-b", since atoi yields 0) would loop forever
	if (BLOCK_DIM == 0) {
		cerr << "block dimension should be greater than 0" << endl;
		exit(EXIT_FAILURE);
	}

	report_compiler();

	// allocation of matrices
	allocate_resources();

	// initialize random number generator
	srand(19702013);

	if (select_by_name.length() != 0) {
		u32 id;
		for (id = 1; methods[id].method != nullptr; ++id) {
			if (methods[id].name == select_by_name) {
				break;
			}
		}
		if (methods[id].method == nullptr) {
			cerr << "There is no method of name '" << select_by_name << "'" << endl;
			exit(EXIT_FAILURE);
		} else {
			method_id = id;
		}
	}

	cout << "size=" << size << endl;

	if (test_flag) {
		validity_test();
	}

	// performance test
	// NOTE: C is not reset between calls. Methods that accumulate into C
	// (+=) therefore sum zillions products, while mp_reference overwrites
	// C, so "result" is only comparable across methods when zillions == 1.
	CPUTimer timer;
	timer.start();
	for (u32 zillion = 1; zillion <= zillions; ++zillion) {
		methods[method_id].method(A, B, C, size);
	}
	timer.stop();
	cout << "cycles=" << timer << endl;
	cout << "method.name=" << methods[method_id].name << endl;

	// compute result which is the sum values of first three lines
	f32 result = 0.0;
	for (u32 i = 0; i < 3; ++i) {
		for (u32 j = 0; j < size; ++j) result += c(i,j);
	}
	cout << "result=" << result << endl;

	if (verbose_flag) {
		print_matrix("A", A, size);
		print_matrix("B", B, size);
		print_matrix("C", C, size);
	}

	free_resources();

	return EXIT_SUCCESS;
}

