
/*
 * Smaller / larger of two values.
 * These are function-like macros: whichever argument is selected is evaluated
 * a second time, so pass only expressions without side effects.
 */
#define min(lhs, rhs) (((lhs) < (rhs)) ? (lhs) : (rhs))
#define max(lhs, rhs) (((lhs) > (rhs)) ? (lhs) : (rhs))
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <mpi.h> 
#include <sys/time.h>
#include "project.h"

int main(int argc, char **argv)
{
	double eps = 0.001;
	double delta_t = 0.000001;
	double alpha = 1;
	MPI_Init(&argc, &argv);
	int m, n;
	double **root_field;		// complete field owned by root
	double **partial_field;		// partial field where process works on
	double start, end;
	t_process_info pi;
	t_process_info *infos;
	int pro_per_dim[2];
	int cell_per_pro[2];
	MPI_Comm cart_comm;
	int matrix_size[2];

	Process_Args(argc, argv, &m, &n, &eps, &delta_t);

	int rank, cart_cluster_size;
	if(MPI_Comm_rank(MPI_COMM_WORLD, &rank)) {
		fprintf(stderr, "Cannot fetch rank\n");
		exit(1);
	}

	if(rank == 0) {
		root_field = New_Matrix(m, n);
		if (root_field == NULL) {
			fprintf(stderr, "rank %s: Can't allocate root_field !\n", rank);
			exit(1);
		}
	}

	// optimize cart cluster
	Optimize_Cart_Cluster(m, n, MPI_COMM_WORLD, rank, pro_per_dim, cell_per_pro);

	cart_comm = Create_MPI_Cart_Cluster(MPI_COMM_WORLD, rank, pro_per_dim);

	pi = Calculate_Process_Info(cart_comm, rank, m, n, cell_per_pro);

	matrix_size[0] = pi.end_m - pi.start_m + 3;
	matrix_size[1] = pi.end_n - pi.start_n + 3;

	if(MPI_Comm_size(cart_comm, &cart_cluster_size)) {
		fprintf(stderr, "Cannot fetch size of cart\n");		
		exit(1);
	}

	Print_Process_Info(pi);

	infos = Gather_Process_Info(&pi, rank, cart_cluster_size, cart_comm);

	if(rank == 0) {
		for(int i = 0; i < cart_cluster_size; i++) {
			Print_Process_Info(infos[i]);
		}
	}

	double delta_a;
	partial_field = New_Matrix(matrix_size[0], matrix_size[1]);
	if (partial_field == NULL) {
		fprintf(stderr, "rank %d: Can't allocate partial_field %d, %d end_M: %d, start_m: %d, end_n: %d, start_n: %d!\n", 
			rank, 
			matrix_size[0], 
			matrix_size[1],
			pi.end_m,
			pi.start_m,
			pi.end_n,
			pi.start_n
		);
		exit(1);
	}

	Init_Matrix(partial_field, matrix_size[0], matrix_size[1], 0);

	double **partial_field_tmp = New_Matrix(matrix_size[0], matrix_size[1]);
	Init_Matrix(partial_field_tmp, matrix_size[0], matrix_size[1], 0);
	double **swap;
	double hx = 1.0/(double)m;
	double hy = 1.0/(double)n;
	double hx_square = hx * hx;
	double hy_square = hy * hy;
	double maxdiff;

	double *dim1_own_edge_values = malloc(sizeof(double) * (2*(matrix_size[0]-2)));
	double *dim1_neighbor_egde_values = malloc(sizeof(double) * (2*(matrix_size[0]-2)));
	MPI_Request sync_requests[9];	// 2 for each edge + 1 completion

	// set to MPI null for waitany -- only needed once as others are overwritten with every iteration
	for(int i = 0; i < 9; i++) {
		sync_requests[i] = MPI_REQUEST_NULL;
	}

	double max_delta_t = 0.25*((min(hx,hy))*(min(hx,hy)))/alpha;  /* minimaler Wert für Konvergenz */
	if (delta_t > max_delta_t) { 
		delta_t = max_delta_t;
		if(rank == 0)
			printf ("Info: delta_t set to %.10lf.\n", delta_t);
	}

	int neighbor_dim0_left, neighbor_dim0_right, neighbor_dim1_left, neighbor_dim1_right;
	if(MPI_Cart_shift(cart_comm, 0, 1, &neighbor_dim0_left, &neighbor_dim0_right)) {
		fprintf(stderr, "Shift failed\n");		
		exit(1);
	}
	if(MPI_Cart_shift(cart_comm, 1, 1, &neighbor_dim1_left, &neighbor_dim1_right)) {
		fprintf(stderr, "Shift failed\n");		
		exit(1);
	}

	//init edges
	if(neighbor_dim1_left == MPI_PROC_NULL) {
		for(int i = pi.start_m; i <= pi.end_m; i++) {
			partial_field[i - pi.start_m + 1][1] = (double)i / (m-1);
			partial_field_tmp[i - pi.start_m + 1][1] = (double)i / (m-1);
		}
	}
	if(neighbor_dim1_right == MPI_PROC_NULL) {
		for(int i = pi.start_m; i <= pi.end_m; i++) {
			partial_field[i - pi.start_m + 1][matrix_size[1]-2] = 1 - (double)i / (m-1);
			partial_field_tmp[i - pi.start_m + 1][matrix_size[1]-2] = 1 - (double)i / (m-1);
		}
	}
	if(neighbor_dim0_left == MPI_PROC_NULL) {
		for(int i = pi.start_n; i <= pi.end_n; i++) {
			partial_field[1][i - pi.start_n + 1] = (double)i / (n-1);
			partial_field_tmp[1][i - pi.start_n + 1] = (double)i / (n-1);
		}
	}
	if(neighbor_dim0_right == MPI_PROC_NULL) {
		for(int i = pi.start_n; i <= pi.end_n; i++) {
			partial_field[matrix_size[0] - 2][i - pi.start_n + 1] = 1 - (double)i / (n-1);
			partial_field_tmp[matrix_size[0] - 2][i - pi.start_n + 1] = 1 - (double)i / (n-1);
		}
	}
	int *completions = malloc(sizeof(int) * cart_cluster_size);
	int k = 0;

	/*
	*
	* START ITERATION
	*
	*/

	while (1) {			// iterate until break;
		k++;
		maxdiff = 0;
		for(
			int i = (neighbor_dim0_left == MPI_PROC_NULL) ? 2 : 1; // catch edges
			i < pi.end_m - pi.start_m + ((neighbor_dim0_right == MPI_PROC_NULL) ? 1 : 2); 
			i++
		) {	
			for(
				int j = (neighbor_dim1_left == MPI_PROC_NULL) ? 2 : 1; // catch edges
				j < pi.end_n - pi.start_n + ((neighbor_dim1_right == MPI_PROC_NULL) ? 1 : 2); 
				j++
			) {
				delta_a = alpha * 
					    ( (partial_field[i+1][j] + partial_field[i-1][j] - 2.0 * partial_field[i][j]) / (hy_square)
						 +(partial_field[i][j-1] + partial_field[i][j+1] - 2.0 * partial_field[i][j]) / (hx_square) );
				delta_a = delta_a * delta_t;
				partial_field_tmp[i][j] = partial_field[i][j] + delta_a;
				
				if(delta_a > maxdiff)
					maxdiff = delta_a;
			}
		}
		swap = partial_field_tmp;
		partial_field_tmp = partial_field;
		partial_field = swap;

		printf("comp rank %d: max: %.10lf eps: %.10lf\n", rank, maxdiff, eps);
		int completion = maxdiff <= eps;

		printf("rank %d completion %d\n", rank, completion);
		if(MPI_Allgather(&completion, 1, MPI_INT, completions, 1, MPI_INT, cart_comm)) {
			fprintf(stderr, "Alltoall failed\n");		
			exit(1);
		}
		for(int i = 0; i < cart_cluster_size; i++) {
			printf("rank %d: %d -> %d \n", rank, i, completions[i]);
		}
		int all_completed = 1;
		for(int i = 0; i < cart_cluster_size; i++) {
			if(!completions[i]) {
				all_completed = 0;
				break;
			}
		}
		if(all_completed) {
			printf("rank: %d: break after %d iterations\n", rank, k);
			break;
		}

		// Sync edges

		// copy own edges in send buffer
		if(neighbor_dim0_left != MPI_PROC_NULL) {
			MPI_Isend(&(partial_field[1][1]), (matrix_size[1] - 2), MPI_DOUBLE, neighbor_dim0_left, 0, cart_comm, &(sync_requests[0]));
			MPI_Irecv(&(partial_field[0][1]), (matrix_size[1] - 2), MPI_DOUBLE, neighbor_dim0_left, 0, cart_comm, &(sync_requests[4]));
			//memcpy(&(dim1_own_edge_values[2*(matrix_size[0]-2)]), &(partial_field[1][1]), sizeof(double) * (matrix_size[1] - 2));
		}
		if(neighbor_dim0_right != MPI_PROC_NULL) {
			MPI_Isend(&(partial_field[matrix_size[0]-2][1]), (matrix_size[1] - 2), MPI_DOUBLE, neighbor_dim0_right, 0, cart_comm, &(sync_requests[1]));
			MPI_Irecv(&(partial_field[matrix_size[0]-1][1]), (matrix_size[1] - 2), MPI_DOUBLE, neighbor_dim0_right, 0, cart_comm, &(sync_requests[5]));
			//memcpy(&(dim1_own_edge_values[2*(matrix_size[0]-2) + matrix_size[1] - 2]), &(partial_field[matrix_size[0]-2][1]), sizeof(double) * (matrix_size[1] - 2));
		}
		if(neighbor_dim1_left != MPI_PROC_NULL) {
			for(int i = 0; i < matrix_size[0] - 2; i++) {
				dim1_own_edge_values[i] = partial_field[i+1][1];
			}
			MPI_Isend(&(dim1_own_edge_values[0]), matrix_size[0] - 2, MPI_DOUBLE, neighbor_dim1_left, 0, cart_comm, &(sync_requests[2]));
			MPI_Irecv(&(dim1_neighbor_egde_values[0]), matrix_size[0] -2, MPI_DOUBLE, neighbor_dim1_left, 0, cart_comm, &(sync_requests[6]));
		}
		if(neighbor_dim1_right != MPI_PROC_NULL) {
			int right_edge_index = matrix_size[1]-2;
			for(int i = 0; i < matrix_size[0] - 2; i++) {
				dim1_own_edge_values[matrix_size[0]-2+i] = partial_field[i+1][right_edge_index];
			}
			MPI_Isend(&(dim1_own_edge_values[matrix_size[0]-2]), matrix_size[0] - 2, MPI_DOUBLE, neighbor_dim1_right, 0, cart_comm, &(sync_requests[3]));
			MPI_Irecv(&(dim1_neighbor_egde_values[matrix_size[0]-2]), matrix_size[0] -2, MPI_DOUBLE, neighbor_dim1_right, 0, cart_comm, &(sync_requests[7]));
		}

		while(1) {
			int current;
			MPI_Waitany(9, sync_requests, &current, MPI_STATUS_IGNORE);
			if(current == MPI_UNDEFINED) {
				break;
			}
			if(current == 6) {
				for(int i = 0; i < matrix_size[0] - 2; i++) {
					partial_field[i+1][0] = dim1_neighbor_egde_values[i];
				}
			}
			if(current == 7) {
				int right_edge_index = matrix_size[1] - 1;
				for(int i = 0; i < matrix_size[0] - 2; i++) {
					partial_field[i+1][right_edge_index] = dim1_neighbor_egde_values[matrix_size[0]-2+i];
				}
			}
			if(current == 8) {
				int all_completed = 1;
				for(int i = 0; i < cart_cluster_size; i++) {
					if(!completions[i]) {
						all_completed = 0;
						break;
					}
				}
				if(all_completed) {
					printf("rank: %d: break after %d iterations\n", rank, k);
					break;
				}
			}
			printf("waitany...\n");
		}
		printf("waitany completed\n");
	}

	/*
	*
	* END ITERATION
	*
	*/

	//Send_To_Root(partial_field, pi.end_m - pi.start_m + 2, pi.end_n - pi.start_n + 2);
	MPI_Send(partial_field[0], matrix_size[0]*matrix_size[1], MPI_DOUBLE, 0, 0 ,cart_comm);
	if(rank == 0) {
		MPI_Request *requests = malloc(sizeof(MPI_Request) * cart_cluster_size);
		double **allocation = malloc(sizeof(double*) * cart_cluster_size);
		for(int i = 0; i < cart_cluster_size; i++) {
			allocation[i] = malloc(sizeof(double) * (infos[i].end_m - infos[i].start_m + 3) * (infos[i].end_n - infos[i].start_n + 3));
			MPI_Irecv(allocation[i], 
				(infos[i].end_m - infos[i].start_m + 3) * (infos[i].end_n - infos[i].start_n + 3), 
				MPI_DOUBLE, 
				infos[i].rank, 
				MPI_ANY_TAG, 
				cart_comm,
				&requests[i]
			);
		}
		for(int i = 0; i < cart_cluster_size; i++) {
			int current;
			MPI_Waitany(cart_cluster_size, requests, &current, MPI_STATUS_IGNORE);
			printf("\n");
			Insert_Array_In_Matrix(
				root_field, 
				m, 
				n, 
				infos[current].start_m, 
				infos[current].start_n, 
				allocation[current], 
				infos[current].end_m - infos[current].start_m + 3, 
				infos[current].end_n - infos[current].start_n + 3, 
				1, 1, 1, 1);
			free(allocation[current]);
		}
		free(requests);
		free(allocation);
		Write_Matrix(root_field, m, n);
	}

	MPI_Finalize();

	return 0;
}