// ==================================================================
// Author: Jean-Michel Richer
// Email: jean-michel.richer@univ-angers.fr
// Date: Aug 2020
// Last modified: November 2020
// Purpose: interface and class to facilitate use of MPI
// ==================================================================
#ifndef EZ_MPI_H
#define EZ_MPI_H

#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
using namespace std;
#include <time.h>
#include <unistd.h>
#include <mpi.h>

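/**
 * EZ MPI is a wrapper around the MPI C++ bindings. It simplifies the use
 * of the MPI send, receive, gather, scatter and reduce functions.
 * For the gather, scatter and reduce functions the process
 * of rank 0 is considered as the "master" that collects or
 * sends data.
 * Note: the MPI C++ bindings (namespace MPI::) were deprecated in
 * MPI-2.2 and removed from the standard in MPI-3.0.
 */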
 
namespace ez {
 
namespace mpi {

/* 
 * Process Information.
 * This class also acts as a logger so it can be used to print
 * information.
 */
class Process {
protected:
	// maximum number of process working together
	int m_max;
	// identifier of current process, called rank for MPI
	int m_id;
	// identifier of remote process for send, receive, ...
	int m_remote;
	// status of last operation
	MPI::Status m_status;
	// message tag if needed (default is 0)
	int m_message_tag;
	// name of process
	string m_name;
	// Linux process identifier
	int m_pid;
	// verbose mode
	bool m_verbose_flag;
	// verbose mode
	bool m_log_flag;
	// main output stream for current processor
	ostringstream log_stream;
	// temporary output stream
	ostringstream tmp_log;
	
private:
	/**
	 * Find processor name
	 */
	void find_cpu_name();
	
	/**
	 * record output of oss into general output
	 */
	void append();
	
	/**
	 * initialize max_cpus, cpu_rank, processus id
	 */
	void init();
	
public:	
			
	/**
	 * Default constructor
	 */
	Process(int argc, char *argv[], bool verbose=false);
	
	~Process();
	
	/**
	 * Display output of each processor if verbose mode is on
	 */
	void finalize();
	
	/**
	 * set verbose mode
	 */
	void verbose(bool mode) {
		m_verbose_flag = mode;
	}
	
	/**
	 * set log mode
	 */
	void log(bool mode) {
		m_log_flag = mode;
	}
	
	
	
	/**
	 * Get process identifier
	 */
	int pid();
	
	/**
	 * Get processor identifier or rank
	 */
	int id();
	
	/**
	 * Get number of processors used
	 */
	int max();
	
	/**
	 * Get processor name
	 */
	string name();
	
	
	/** 
	 * set remote processor identifier
	 */
	void remote(int rmt);
	
	/**
	 * set message tag
	 * @param tag must be an integer between 0 and 32767
	 */
	void tag(int tag);
	
	/**
	 * Return true if this processor is the processor of rank 0 
	 * considered as the master.
	 */
	bool is_master() {
		return (m_id == 0);
	}
	
	/**
	 * synchronize
	 */
	void synchronize();
	
	/**
	 * Determine type of data T and convert it into MPI::Datatype.
	 * This function needs to be extended with other types.
	 */
	template<class T>
	MPI::Datatype get_type() {
		if (typeid(T) == typeid(char)) {
			return MPI::CHAR;
		} else if (typeid(T) == typeid(int8_t)) {	
			return MPI::CHAR;
		} else if (typeid(T) == typeid(uint8_t)) {	
			return MPI::CHAR;	
		} else if (typeid(T) == typeid(int)) {
			return MPI::INT;
		} else if (typeid(T) == typeid(float)) {
			return MPI::FLOAT;
		} else if (typeid(T) == typeid(double)) {
			return MPI::DOUBLE;
		} 
		//throw std::runtime_error("unknown MPI::Datatype"); 
		cout << "!!!!! unknown " << typeid(T).name() << endl;
		return MPI::INT;
	}
	
	/**
	 * Send one instance of data to remote_cpu
	 * @param v data to send 
	 */
	template<class T>
	void send(T& v) {
		MPI::Datatype data_type = get_type<T>();
		
		tmp_log << "send value=" << v << " to " << m_remote << endl;
		flush();
		
		MPI::COMM_WORLD.Send(&v, 1, data_type, m_remote, m_message_tag);
	}
	
	/**
	 * Send an array to remote_cpu
	 * @param arr address of the array
	 * @param size number of elements to send
	 */
	template<class T>
	void send(T *arr, int size) {
		MPI::Datatype data_type = get_type<T>();
		
		tmp_log << "send array of size=" << size << " to " << m_remote << endl;
		flush();
				
		MPI::COMM_WORLD.Send(&arr[0], size, data_type, m_remote, m_message_tag);
	}
	
	/**
	 * Receive one instance of data from remote cpu
	 * @param v data to receive
	 */
	template<class T>
	void recv(T& v) {
		MPI::Datatype data_type = get_type<T>();
		
		MPI::COMM_WORLD.Recv(&v, 1, data_type, m_remote,
			(m_message_tag == 0) ? MPI::ANY_TAG : m_message_tag,
					m_status);
			
		tmp_log << "receive value=" << v << " from " << m_remote << endl;
		flush();
			
	}
	
	/**
	 * Receive an array of given size
	 * @param arr pointer to address of the array
	 * @param size number of elements
	 */
	template<class T>
	void recv(T *arr, int size) {
		MPI::Datatype data_type = get_type<T>();
		
		MPI::COMM_WORLD.Recv(&arr[0], size, data_type, m_remote,
			(m_message_tag == 0) ? MPI::ANY_TAG : m_message_tag,
					m_status);
			
		tmp_log << "receive array of size=" << size << " from " << m_remote << endl;
		flush();
			
	}
	
	/**
	 * Send array and receive value in return, this is an instance
	 * of the Sendrecv function.
	 * @param arr address of the array to send
	 * @param size size of the array to send
	 * @param value value to receive
	 */
	template<class T, class U>
	void sendrecv(T *array, int size, U& value) {
		MPI::Datatype array_data_type = get_type<T>();
		MPI::Datatype value_data_type = get_type<U>();
		
		tmp_log << "sendrecv/send array of size=" << size << endl;
		flush();
				
		MPI::COMM_WORLD.Sendrecv(&array[0], size, array_data_type, m_remote, 0,
			&value, 1, value_data_type, MPI::ANY_SOURCE, MPI::ANY_TAG, 
			m_status);
			
		tmp_log << "sendrecv/receive value=" << value << endl;
		flush();
			
	}
	
	/**
	 * Perform reduction
	 * @param lcl_value local array used to perform reduction
	 * @param glb_value global data that will contain result
	 * @param op operation to perform (MPI::SUM, MPI::MAX, ...)
	 */
	template<class T>
	void reduce(T &lcl_value, T &glb_value, const MPI::Op& op) {
		MPI::Datatype data_type = get_type<T>();
		
		MPI::COMM_WORLD.Reduce(&lcl_value, 
			&glb_value, 1, data_type, op, 0);
			
		tmp_log << "reduction gives value=" << glb_value << endl;
		flush();

	}
	
	/**
	 * Perform gather operation
	 * @param lcl_array local array that is send to master process
	 * @param glb_array global array that will contain all local arrays
	 */
	template<class T>
	void gather(T *lcl_array, int size, T *glb_array) {
		MPI::Datatype data_type = get_type<T>();
		
		MPI::COMM_WORLD.Gather(lcl_array, size, data_type,
			glb_array, size, data_type, 0);
			
		tmp_log << "gather" << endl;
		flush();

	}
	
	/**
	 * Perform scatter operation
	 * @param glb_array array of data that will be send by to all processors
	 * @param size size of the local array of data
	 * @param lcl_array local array of data
	 */
	template<class T>
	void scatter(T *glb_array, int size, T *lcl_array) {
		MPI::Datatype data_type = get_type<T>();
		
		MPI::COMM_WORLD.Scatter(glb_array, size, data_type,
			lcl_array, size, data_type, 0);
			
		tmp_log << "scatter" << endl;
		flush();

	}
	
	typedef std::ostream& (*ManipFn)(std::ostream&); 
    typedef std::ios_base& (*FlagsFn)(std::ios_base&); 
    
	void print(char v);
	void print(int v);
	void print(string s);
	void print(float f);
	void print(double d);
	
	template<class T> // int, double, strings, etc 
    Process& operator<<(const T& output) { 
		tmp_log << output; 
       	return *this; 
    } 

	// endl, flush, setw, setfill, etc. 
    Process& operator<<(ManipFn manip) {
       	manip(tmp_log); 
		if (manip == static_cast<ManipFn>(std::flush) || manip == static_cast<ManipFn>(std::endl)) {
	      this->flush(); 
		}
     	return *this; 
    } 

	// setiosflags, resetiosflags 
    Process& operator<<(FlagsFn manip) { 
		manip(tmp_log); 
     	return *this; 
    } 

	void flush();

	void logs(ostream& out);
	
	typedef void (*Code)(Process& p);
	
	void run(Code code) {
		code(*this);
	}
};



} // end of namespace mpi

} // end of namespace ez
#endif
