diff --git a/src/fmpi_task.c b/src/fmpi_task.c index b6b26b03591269a191ccacf17dd413963aa34ded..421e63d6525b530dc629632818f5997010378b6e 100644 --- a/src/fmpi_task.c +++ b/src/fmpi_task.c @@ -37,6 +37,15 @@ #include "internal/fmpi_futhark.h" #include "internal/fmpi_mpi.h" #include "internal/fmpi_type.h" +/*============================================================================== + PRIVATE FUNCTION DECLARATION +==============================================================================*/ +/*------------------------------------------------------------------------------ + fmpi_task_update_input() +------------------------------------------------------------------------------*/ +static int fmpi_task_update_input( + const struct fmpi_ctx * ctx, struct fmpi_task * const task +); /*============================================================================== PUBLIC FUNCTION DEFINITION ==============================================================================*/ @@ -127,25 +136,10 @@ int fmpi_task_run_sync( } } if(task->stencil.type != FMPI_STENCIL_NONE) { - fmpi_domain_set_inner(ctx, &task->domains[0], task->args.out.raw); - fmpi_halo_exchange_1d(ctx, &task->domains[0]); - const int err = fmpi_futhark_free_data_sync( - ctx->fut, task->args.in[0].gpu, task->domains[0].halo.type.base, - task->domains[0].halo.dim_cnt - ); + const int err = fmpi_task_update_input(ctx, task); if(err != FMPI_SUCCESS) { FMPI_RAISE_ERROR(ctx->err_handler, "FMPI", - "fmpi_futhark_free_data_sync() failed!" - ); - } - task->args.in[0].gpu = fmpi_futhark_new_data_sync( - ctx->fut, task->domains[0].halo.raw, task->domains[0].halo.type.base, - task->domains[0].halo.dim_cnt, task->domains[0].halo.dim_len[0], - task->domains[0].halo.dim_len[1], task->domains[0].halo.dim_len[2] - ); - if(task->args.in[0].gpu == NULL) { - FMPI_RAISE_ERROR(ctx->err_handler, "FMPI", - "fmpi_futhark_new_data_sync() failed!" + "fmpi_task_update_input() failed!" ); } } @@ -200,3 +194,43 @@ int fmpi_task_finalize( } return FMPI_SUCCESS; } +/*============================================================================== + PRIVATE FUNCTION DEFINITION +==============================================================================*/ +/*------------------------------------------------------------------------------ + fmpi_task_update_input() +------------------------------------------------------------------------------*/ +static int fmpi_task_update_input( + const struct fmpi_ctx * const ctx, struct fmpi_task * const task +){ + assert(ctx != NULL); + assert(task != NULL); + + fmpi_domain_set_inner(ctx, &task->domains[0], task->args.out.raw); + const size_t dim_cnt = task->domains[0].inner.dim_cnt; + if(dim_cnt == 1) { + fmpi_halo_exchange_1d(ctx, &task->domains[0]); + } else if (dim_cnt == 2) { + fmpi_halo_exchange_2d(ctx, &task->domains[0]); + } + const int err = fmpi_futhark_free_data_sync( + ctx->fut, task->args.in[0].gpu, task->domains[0].halo.type.base, + task->domains[0].halo.dim_cnt + ); + if(err != FMPI_SUCCESS) { + FMPI_RAISE_ERROR(ctx->err_handler, "FMPI", + "fmpi_futhark_free_data_sync() failed!" + ); + } + task->args.in[0].gpu = fmpi_futhark_new_data_sync( + ctx->fut, task->domains[0].halo.raw, task->domains[0].halo.type.base, + task->domains[0].halo.dim_cnt, task->domains[0].halo.dim_len[0], + task->domains[0].halo.dim_len[1], task->domains[0].halo.dim_len[2] + ); + if(task->args.in[0].gpu == NULL) { + FMPI_RAISE_ERROR(ctx->err_handler, "FMPI", + "fmpi_futhark_new_data_sync() failed!" + ); + } + return err; +}