diff --git a/include/fmpi_domain.h b/include/fmpi_domain.h
index 51cb754a23b0fb56f8d9cb190fcfd292357c24fc..8512564881627942d4a7345df32e53d0a72b30c9 100644
--- a/include/fmpi_domain.h
+++ b/include/fmpi_domain.h
@@ -65,6 +65,18 @@ void fmpi_domain_set_inner(
     const struct fmpi_ctx * ctx, const struct fmpi_domain * domain,
     const void * data
 );
+/*------------------------------------------------------------------------------
+    fmpi_halo_exchange_1d()
+------------------------------------------------------------------------------*/
+/*
+    Exchanges the edge elements of the 1D halo buffer of `domain` with the
+    ranks to the left (rank - 1) and right (rank + 1) in ctx->mpi->world.
+    Returns FMPI_SUCCESS on success, otherwise the error code returned by
+    MPI_Sendrecv().
+*/
+int fmpi_halo_exchange_1d(
+    const struct fmpi_ctx * ctx, const struct fmpi_domain * domain
+);
 /*==============================================================================
     GUARD
 ==============================================================================*/
diff --git a/src/fmpi_domain.c b/src/fmpi_domain.c
index 4927832425c5753582cd14da88e2b8fd6db2208a..62d1ea8553be7cd7b4a36b4cbe6fe8c6cb96a644 100644
--- a/src/fmpi_domain.c
+++ b/src/fmpi_domain.c
@@ -89,6 +89,65 @@ void fmpi_domain_set_inner(
     assert(data != NULL);
     memcpy(domain->inner.raw, data, domain->inner.size);
 }
+/*------------------------------------------------------------------------------
+    fmpi_halo_exchange_1d()
+------------------------------------------------------------------------------*/
+/*
+    Exchanges the edge elements of the 1D halo buffer of `domain` with the
+    neighbouring ranks of ctx->mpi->world. With n = domain->halo.cnt:
+      - element n-2 is sent to rank+1, element 0 is received from rank-1;
+      - element 1 is sent to rank-1, element n-1 is received from rank+1.
+    The first and last ranks have no left/right neighbour respectively and
+    skip that side of the exchange.
+
+    Returns FMPI_SUCCESS on success, otherwise the error code returned by
+    MPI_Sendrecv() (only observable when the communicator's error handler
+    returns instead of aborting, e.g. MPI_ERRORS_RETURN).
+*/
+int fmpi_halo_exchange_1d(
+    const struct fmpi_ctx * const ctx, const struct fmpi_domain * const domain
+){
+    assert(ctx != NULL);
+    assert(domain != NULL);
+    /* Fewer than 3 elements would make the (cnt-2) offset below invalid or
+       make the send and receive buffers of MPI_Sendrecv() overlap. */
+    assert(domain->halo.cnt >= 3);
+
+    const size_t idx_1   = domain->halo.type.size;
+    const size_t idx_nm2 = (domain->halo.cnt-2) * domain->halo.type.size;
+    const size_t idx_nm1 = idx_nm2 + idx_1;
+
+    void * const buf_0         = domain->halo.raw;
+    const void * const buf_1   = (char *)buf_0 + idx_1;
+    const void * const buf_nm2 = (char *)buf_0 + idx_nm2;
+    void * const buf_nm1       = (char *)buf_0 + idx_nm1;
+
+    const int rank  = ctx->mpi->rank;
+    /* MPI_PROC_NULL turns the matching send/receive half into a no-op. */
+    const int left  = (rank != 0) ? (rank - 1) : MPI_PROC_NULL;
+    const int right = (rank != (ctx->mpi->size-1)) ? (rank + 1) : MPI_PROC_NULL;
+    MPI_Datatype type = fmpi_mpi_type(domain->halo.type.base);
+
+    /* Towards higher ranks: last inner element out, left halo cell in. */
+    int err = MPI_Sendrecv(
+        buf_nm2, 1, type, right, 0,
+        buf_0  , 1, type, left , 0,
+        ctx->mpi->world, MPI_STATUS_IGNORE
+    );
+    if(err != MPI_SUCCESS) {
+        return err;
+    }
+    /* Towards lower ranks: first inner element out, right halo cell in. */
+    err = MPI_Sendrecv(
+        buf_1  , 1, type, left , 1,
+        buf_nm1, 1, type, right, 1,
+        ctx->mpi->world, MPI_STATUS_IGNORE
+    );
+    if(err != MPI_SUCCESS) {
+        return err;
+    }
+    return FMPI_SUCCESS;
+}
 /*==============================================================================
     PRIVATE FUNCTION DEFINITION
 ==============================================================================*/
diff --git a/src/fmpi_task.c b/src/fmpi_task.c
index 3119a205398ab5ca6fbcbe517aa2b05eadda31be..b6b26b03591269a191ccacf17dd413963aa34ded 100644
--- a/src/fmpi_task.c
+++ b/src/fmpi_task.c
@@ -128,36 +128,7 @@ int fmpi_task_run_sync(
     }
     if(task->stencil.type != FMPI_STENCIL_NONE) {
         fmpi_domain_set_inner(ctx, &task->domains[0], task->args.out.raw);
-
-        const int rank = ctx->mpi->rank;
-        size_t type_size = task->domains[0].inner.type.size;
-
-        const size_t idx_1 = type_size;
-        const size_t idx_nm2 = (task->domains[0].halo.cnt-2) * type_size;
-        const size_t idx_nm1 = idx_nm2 + idx_1;
-
-        void * const buf_0 = task->domains[0].halo.raw;
-        const void * const buf_1 = (char *)buf_0 + idx_1;
-        const void * const buf_nm2 = (char *)buf_0 + idx_nm2;
-        void * const buf_nm1 = (char *)buf_0 + idx_nm1;
-
-        MPI_Datatype type = fmpi_mpi_type(task->domains[0].halo.type.base);
-        if((rank % 2) == 0) {
-            int left = (rank != 0) ? (rank - 1) : MPI_PROC_NULL;
-            int right = rank + 1;
-            MPI_Send(buf_nm2, 1, type, right, 0, ctx->mpi->world);
-            MPI_Recv(buf_0 , 1, type, left, 1, ctx->mpi->world, MPI_STATUS_IGNORE);
-            MPI_Send(buf_1 , 1, type, left, 2, ctx->mpi->world);
-            MPI_Recv(buf_nm1, 1, type, right, 3, ctx->mpi->world, MPI_STATUS_IGNORE);
-
-        } else {
-            int left = rank - 1;
-            int right = (rank != (ctx->mpi->size-1)) ? (rank + 1) : MPI_PROC_NULL;
-            MPI_Recv(buf_0 , 1, type, left, 0, ctx->mpi->world, MPI_STATUS_IGNORE);
-            MPI_Send(buf_nm2, 1, type, right, 1, ctx->mpi->world);
-            MPI_Recv(buf_nm1, 1, type, right, 2, ctx->mpi->world, MPI_STATUS_IGNORE);
-            MPI_Send(buf_1 , 1, type, left, 3, ctx->mpi->world);
-        }
+        fmpi_halo_exchange_1d(ctx, &task->domains[0]);
         const int err = fmpi_futhark_free_data_sync(
             ctx->fut, task->args.in[0].gpu, task->domains[0].halo.type.base,
             task->domains[0].halo.dim_cnt