#include <mpi.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include "project.h"

/*
 * Build and commit an MPI struct datatype describing the seven int fields of
 * t_process_info: rank, coord0, coord1, start_m, start_n, end_m, end_n.
 *
 * datatype: out-parameter receiving the committed datatype. The caller owns
 *           it and should release it with MPI_Type_free when done.
 *
 * Exits the process with status 1 if any MPI call reports an error.
 */
void Create_MPI_Type_t_process_info(MPI_Datatype *datatype) {
	enum { FIELD_COUNT = 7 };
	MPI_Datatype types[FIELD_COUNT] = {MPI_INT, MPI_INT, MPI_INT, MPI_INT, MPI_INT, MPI_INT, MPI_INT};
	int blocklens[FIELD_COUNT] = {1, 1, 1, 1, 1, 1, 1};
	/* offsetof is standard C; subtracting void pointers (the previous
	 * approach) only compiles as a GNU extension. */
	MPI_Aint disps[FIELD_COUNT] = {
		(MPI_Aint)offsetof(t_process_info, rank),
		(MPI_Aint)offsetof(t_process_info, coord0),
		(MPI_Aint)offsetof(t_process_info, coord1),
		(MPI_Aint)offsetof(t_process_info, start_m),
		(MPI_Aint)offsetof(t_process_info, start_n),
		(MPI_Aint)offsetof(t_process_info, end_m),
		(MPI_Aint)offsetof(t_process_info, end_n)
	};

	MPI_Datatype fields_only;
	if(MPI_Type_create_struct(FIELD_COUNT, blocklens, disps, types, &fields_only)) {
		fprintf(stderr, "Type creation failed\n");
		exit(1);
	}
	/* Force the extent to sizeof(t_process_info) so that arrays of the struct
	 * are strided correctly even if it has padding or members beyond the seven
	 * listed here (its definition is not visible in this file). */
	if(MPI_Type_create_resized(fields_only, 0, (MPI_Aint)sizeof(t_process_info), datatype)) {
		fprintf(stderr, "Type resize failed\n");
		exit(1);
	}
	MPI_Type_free(&fields_only);
	if(MPI_Type_commit(datatype)) {
		fprintf(stderr, "Type commit failed\n");
		exit(1);
	}
}

/*
 * Look up the Cartesian neighbours of the calling process and allocate the
 * staging buffers used by Sync() for the dimension-1 (column) edge exchange.
 *
 * cart_comm:     Cartesian communicator; dimensions 0 and 1 are queried.
 * sync_requests: request array of at least 9 entries. All 9 are reset to
 *                MPI_REQUEST_NULL.
 * matrix_size:   local field dimensions including the halo row/column on
 *                each side (as indexed by Sync). Only matrix_size[0] is used
 *                here.
 * neighbor_dim{0,1}_{left,right}: out-params receiving the neighbour ranks
 *                reported by MPI_Cart_shift (MPI_PROC_NULL where none exists).
 * dim1_own_edge_values, dim1_neighbor_egde_values: out-params receiving
 *                heap buffers of 2*(matrix_size[0]-2) doubles each. The
 *                caller owns them and must free() them.
 *
 * Exits the process with status 1 on MPI failure, an invalid matrix size,
 * or allocation failure.
 */
void Init_Neighbor_Comm(MPI_Comm cart_comm, MPI_Request *sync_requests, int *matrix_size, int *neighbor_dim0_left, int *neighbor_dim0_right, int *neighbor_dim1_left, int *neighbor_dim1_right, double **dim1_own_edge_values, double **dim1_neighbor_egde_values) {
	enum { SYNC_REQUEST_COUNT = 9 };

	if(MPI_Cart_shift(cart_comm, 0, 1, neighbor_dim0_left, neighbor_dim0_right)) {
		fprintf(stderr, "Shift failed\n");
		exit(1);
	}
	if(MPI_Cart_shift(cart_comm, 1, 1, neighbor_dim1_left, neighbor_dim1_right)) {
		fprintf(stderr, "Shift failed\n");
		exit(1);
	}

	/* Interior rows only: the first and last rows are halo cells. A negative
	 * value would otherwise wrap to a huge size_t in the malloc size. */
	int edge_len = matrix_size[0] - 2;
	if(edge_len < 0) {
		fprintf(stderr, "Invalid matrix size\n");
		exit(1);
	}
	/* Left-edge values occupy [0, edge_len), right-edge values [edge_len, 2*edge_len). */
	size_t count = 2 * (size_t)edge_len;
	*dim1_own_edge_values = malloc(count * sizeof **dim1_own_edge_values);
	*dim1_neighbor_egde_values = malloc(count * sizeof **dim1_neighbor_egde_values);
	/* malloc(0) may legitimately return NULL, so only treat NULL as an error
	 * when something was actually requested. */
	if(count > 0 && (*dim1_own_edge_values == NULL || *dim1_neighbor_egde_values == NULL)) {
		fprintf(stderr, "Allocation failed\n");
		exit(1);
	}

	/* Reset to MPI_REQUEST_NULL for MPI_Waitany. This is only needed once:
	 * slots that are never posted stay null, and completed requests are set
	 * back to MPI_REQUEST_NULL by MPI_Waitany. */
	for(int i = 0; i < SYNC_REQUEST_COUNT; i++) {
		sync_requests[i] = MPI_REQUEST_NULL;
	}
}

/*
 * Check global completion and, if any rank still has work, exchange halo
 * edges with the (up to) four Cartesian neighbours.
 *
 * completion:        this rank's completion flag (non-zero = done).
 * completions:       scratch array of at least cart_cluster_size ints, filled
 *                    with every rank's flag via MPI_Allgather.
 * cart_cluster_size: number of entries of completions to inspect
 *                    (presumably the size of cart_comm; verify at call site).
 * matrix_size:       dimensions of partial_field including the halo row/column
 *                    on each side.
 * sync_requests:     request array of at least 9 entries whose idle slots are
 *                    MPI_REQUEST_NULL.
 * neighbor_*:        neighbour ranks, or MPI_PROC_NULL to skip that side.
 * partial_field:     local field. Interior edge rows/columns are sent; halo
 *                    rows/columns 0 and size-1 are overwritten with the
 *                    neighbours' edges (corners are not exchanged).
 * dim1_own_edge_values, dim1_neighbor_egde_values: staging buffers of
 *                    2*(matrix_size[0]-2) doubles for the non-contiguous
 *                    columns (first half = left edge, second half = right).
 *
 * Returns 1 if every rank reported completion (no exchange is performed),
 * otherwise 0 once all posted exchanges have completed. Exits with status 1
 * if MPI_Allgather or MPI_Waitany fails.
 */
int Sync(MPI_Comm cart_comm, int completion, int *completions, int cart_cluster_size, int *matrix_size, MPI_Request *sync_requests, int neighbor_dim0_left, int neighbor_dim0_right, int neighbor_dim1_left, int neighbor_dim1_right, double **partial_field, double *dim1_own_edge_values, double *dim1_neighbor_egde_values) {
	/* Tags encode the direction a message travels. When both neighbours in a
	 * dimension are the same rank (e.g. a periodic dimension of size 1 or 2),
	 * MPI would otherwise match sends to receives purely in posting order and
	 * deliver each edge into the wrong ghost row/column. */
	enum {
		TAG_DIM0_TO_LEFT = 0,
		TAG_DIM0_TO_RIGHT = 1,
		TAG_DIM1_TO_LEFT = 2,
		TAG_DIM1_TO_RIGHT = 3
	};
	/* Request slots. Nine slots are scanned to match the array size that
	 * Init_Neighbor_Comm resets; slot 8 is never posted and stays null. */
	enum {
		REQ_SEND_DIM0_LEFT = 0,
		REQ_SEND_DIM0_RIGHT = 1,
		REQ_SEND_DIM1_LEFT = 2,
		REQ_SEND_DIM1_RIGHT = 3,
		REQ_RECV_DIM0_LEFT = 4,
		REQ_RECV_DIM0_RIGHT = 5,
		REQ_RECV_DIM1_LEFT = 6,
		REQ_RECV_DIM1_RIGHT = 7,
		SYNC_REQUEST_COUNT = 9
	};

	if(MPI_Allgather(&completion, 1, MPI_INT, completions, 1, MPI_INT, cart_comm)) {
		fprintf(stderr, "Allgather failed\n");
		exit(1);
	}
	int all_completed = 1;
	for(int i = 0; i < cart_cluster_size; i++) {
		if(!completions[i]) {
			all_completed = 0;
			break;
		}
	}
	if(all_completed) {
		return 1;
	}

	int rows = matrix_size[0] - 2;	/* interior rows */
	int cols = matrix_size[1] - 2;	/* interior columns */

	/* Dimension 0: edge rows are contiguous, so send/receive them in place. */
	if(neighbor_dim0_left != MPI_PROC_NULL) {
		MPI_Isend(&partial_field[1][1], cols, MPI_DOUBLE, neighbor_dim0_left, TAG_DIM0_TO_LEFT, cart_comm, &sync_requests[REQ_SEND_DIM0_LEFT]);
		MPI_Irecv(&partial_field[0][1], cols, MPI_DOUBLE, neighbor_dim0_left, TAG_DIM0_TO_RIGHT, cart_comm, &sync_requests[REQ_RECV_DIM0_LEFT]);
	}
	if(neighbor_dim0_right != MPI_PROC_NULL) {
		MPI_Isend(&partial_field[rows][1], cols, MPI_DOUBLE, neighbor_dim0_right, TAG_DIM0_TO_RIGHT, cart_comm, &sync_requests[REQ_SEND_DIM0_RIGHT]);
		MPI_Irecv(&partial_field[rows + 1][1], cols, MPI_DOUBLE, neighbor_dim0_right, TAG_DIM0_TO_LEFT, cart_comm, &sync_requests[REQ_RECV_DIM0_RIGHT]);
	}

	/* Dimension 1: edge columns are strided, so pack them into the staging
	 * buffer before sending. Received columns are unpacked below, once the
	 * matching receive completes. */
	if(neighbor_dim1_left != MPI_PROC_NULL) {
		for(int i = 0; i < rows; i++) {
			dim1_own_edge_values[i] = partial_field[i + 1][1];
		}
		MPI_Isend(&dim1_own_edge_values[0], rows, MPI_DOUBLE, neighbor_dim1_left, TAG_DIM1_TO_LEFT, cart_comm, &sync_requests[REQ_SEND_DIM1_LEFT]);
		MPI_Irecv(&dim1_neighbor_egde_values[0], rows, MPI_DOUBLE, neighbor_dim1_left, TAG_DIM1_TO_RIGHT, cart_comm, &sync_requests[REQ_RECV_DIM1_LEFT]);
	}
	if(neighbor_dim1_right != MPI_PROC_NULL) {
		for(int i = 0; i < rows; i++) {
			dim1_own_edge_values[rows + i] = partial_field[i + 1][cols];
		}
		MPI_Isend(&dim1_own_edge_values[rows], rows, MPI_DOUBLE, neighbor_dim1_right, TAG_DIM1_TO_RIGHT, cart_comm, &sync_requests[REQ_SEND_DIM1_RIGHT]);
		MPI_Irecv(&dim1_neighbor_egde_values[rows], rows, MPI_DOUBLE, neighbor_dim1_right, TAG_DIM1_TO_LEFT, cart_comm, &sync_requests[REQ_RECV_DIM1_RIGHT]);
	}

	/* Drain every posted request. MPI_Waitany resets each completed request to
	 * MPI_REQUEST_NULL and reports MPI_UNDEFINED once none are active, which
	 * also leaves the array ready for the next call. */
	for(;;) {
		int done;
		if(MPI_Waitany(SYNC_REQUEST_COUNT, sync_requests, &done, MPI_STATUS_IGNORE)) {
			fprintf(stderr, "Waitany failed\n");
			exit(1);
		}
		if(done == MPI_UNDEFINED) {
			break;
		}
		if(done == REQ_RECV_DIM1_LEFT) {
			for(int i = 0; i < rows; i++) {
				partial_field[i + 1][0] = dim1_neighbor_egde_values[i];
			}
		} else if(done == REQ_RECV_DIM1_RIGHT) {
			for(int i = 0; i < rows; i++) {
				partial_field[i + 1][cols + 1] = dim1_neighbor_egde_values[rows + i];
			}
		}
	}
	return 0;
}