From bda86fe4809337483be3749c2cb96f72eb47ae91 Mon Sep 17 00:00:00 2001 From: "raphael.bach" <raphael.bach@etu.hesge.ch> Date: Sun, 26 Jun 2022 02:21:49 +0200 Subject: [PATCH] Duplicate `MPI_COMM_WORLD` --- include/internal/fmpi_mpi.h | 1 + src/fmpi_mpi.c | 18 ++++++++++++------ src/fmpi_reduce.c | 2 +- src/fmpi_task.c | 4 ++-- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/include/internal/fmpi_mpi.h b/include/internal/fmpi_mpi.h index 2e9f832..34cd565 100644 --- a/include/internal/fmpi_mpi.h +++ b/include/internal/fmpi_mpi.h @@ -50,6 +50,7 @@ typedef struct fmpi_mpi_ctx { char cpu_name[MPI_MAX_PROCESSOR_NAME]; //!< TODO int cpu_name_length; //!< TODO int root; //!< TODO + MPI_Comm world; int futhark_err_class; //!< TODO int futhark_err_code; //!< TODO struct fmpi_error_handler err_handler; //!< TODO diff --git a/src/fmpi_mpi.c b/src/fmpi_mpi.c index 03da923..7b730dc 100644 --- a/src/fmpi_mpi.c +++ b/src/fmpi_mpi.c @@ -61,18 +61,24 @@ struct fmpi_mpi_ctx * fmpi_mpi_init( FMPI_RAISE_ERROR(err_handler, "MPI", "malloc(fmpi_mpi_ctx) failed!"); return NULL; } - ctx->root = FMPI_MPI_ROOT; ctx->err_handler = err_handler; + ctx->root = FMPI_MPI_ROOT; int err_id = MPI_Init(argc, argv); if(fmpi_mpi_check_error(ctx, err_id, "MPI_Init") == true) { return ctx; } - err_id = MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank); + err_id = MPI_Comm_dup(MPI_COMM_WORLD, &ctx->world); + if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_dup") == true) { + free(ctx); + return NULL; + } + + err_id = MPI_Comm_rank(ctx->world, &ctx->rank); fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank"); - err_id = MPI_Comm_size(MPI_COMM_WORLD, &ctx->size); + err_id = MPI_Comm_size(ctx->world, &ctx->size); fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_size"); // `ctx->size` is used with `malloc()` and `malloc(0)` is implementation-defined. assert(ctx->size > 0); @@ -145,7 +151,7 @@ int fmpi_mpi_world_rank(const struct fmpi_mpi_ctx * const ctx) { assert(ctx != NULL); int rank = -1; - int err_id = MPI_Comm_rank(MPI_COMM_WORLD, &rank); + int err_id = MPI_Comm_rank(ctx->world, &rank); if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank") == true) { return -1; } @@ -170,7 +176,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx) { assert(ctx != NULL); if(fmpi_mpi_finalized(ctx) == false) { - int err_id = MPI_Abort(MPI_COMM_WORLD, MPI_ERR_UNKNOWN); + int err_id = MPI_Abort(ctx->world, MPI_ERR_UNKNOWN); fmpi_mpi_check_error(ctx, err_id, "MPI_Abort"); } } @@ -180,7 +186,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx) int fmpi_mpi_world_barrier(const struct fmpi_mpi_ctx * const ctx) { assert(ctx != NULL); - const int err_id = MPI_Barrier(MPI_COMM_WORLD); + const int err_id = MPI_Barrier(ctx->world); if(fmpi_mpi_check_error(ctx, err_id, "MPI_Barrier") == true) { return -1; } diff --git a/src/fmpi_reduce.c b/src/fmpi_reduce.c index f071552..b896109 100644 --- a/src/fmpi_reduce.c +++ b/src/fmpi_reduce.c @@ -66,7 +66,7 @@ T fmpi_reduce_prod_##T(const struct fmpi_ctx * const ctx, const T * const array) if(fact_global == NULL) { \ fprintf(stderr, "malloc(fact_global) failed!\n"); \ } \ - int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, MPI_COMM_WORLD); \ + int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, ctx->mpi->world); \ if(fmpi_mpi_check_error(ctx->mpi, err_id, "MPI_Gather()") == true) { \ free(fact_global); \ } \ diff --git a/src/fmpi_task.c b/src/fmpi_task.c index cd6f59b..cfc1ef6 100644 --- a/src/fmpi_task.c +++ b/src/fmpi_task.c @@ -187,13 +187,13 @@ int fmpi_task_finalize( MPI_Reduce( MPI_IN_PLACE, task->args.out.raw, (int)task->args.out.cnt, fmpi_mpi_type(task->args.out.type.base), MPI_SUM, - ctx->mpi->root, MPI_COMM_WORLD + ctx->mpi->root, ctx->mpi->world ); } else { MPI_Reduce( task->args.out.raw, task->args.out.raw, (int)task->args.out.cnt, fmpi_mpi_type(task->args.out.type.base), MPI_SUM, - ctx->mpi->root, MPI_COMM_WORLD + ctx->mpi->root, ctx->mpi->world ); } } -- GitLab