From bda86fe4809337483be3749c2cb96f72eb47ae91 Mon Sep 17 00:00:00 2001
From: "raphael.bach" <raphael.bach@etu.hesge.ch>
Date: Sun, 26 Jun 2022 02:21:49 +0200
Subject: [PATCH] Duplicate `MPI_COMM_WORLD`

---
 include/internal/fmpi_mpi.h |  1 +
 src/fmpi_mpi.c              | 18 ++++++++++++------
 src/fmpi_reduce.c           |  2 +-
 src/fmpi_task.c             |  4 ++--
 4 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/include/internal/fmpi_mpi.h b/include/internal/fmpi_mpi.h
index 2e9f832..34cd565 100644
--- a/include/internal/fmpi_mpi.h
+++ b/include/internal/fmpi_mpi.h
@@ -50,6 +50,7 @@ typedef struct fmpi_mpi_ctx {
     char cpu_name[MPI_MAX_PROCESSOR_NAME]; //!< TODO
     int cpu_name_length; //!< TODO
     int root; //!< TODO
+    MPI_Comm world;
     int futhark_err_class; //!< TODO
     int futhark_err_code; //!< TODO
     struct fmpi_error_handler err_handler; //!< TODO
diff --git a/src/fmpi_mpi.c b/src/fmpi_mpi.c
index 03da923..7b730dc 100644
--- a/src/fmpi_mpi.c
+++ b/src/fmpi_mpi.c
@@ -61,18 +61,24 @@ struct fmpi_mpi_ctx * fmpi_mpi_init(
         FMPI_RAISE_ERROR(err_handler, "MPI", "malloc(fmpi_mpi_ctx) failed!");
         return NULL;
     }
-    ctx->root = FMPI_MPI_ROOT;
     ctx->err_handler = err_handler;
+    ctx->root = FMPI_MPI_ROOT;
 
     int err_id = MPI_Init(argc, argv);
     if(fmpi_mpi_check_error(ctx, err_id, "MPI_Init") == true) {
         return ctx;
     }
 
-    err_id = MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
+    err_id = MPI_Comm_dup(MPI_COMM_WORLD, &ctx->world);
+    if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_dup") == true) {
+        free(ctx);
+        return NULL;
+    }
+
+    err_id = MPI_Comm_rank(ctx->world, &ctx->rank);
     fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank");
 
-    err_id = MPI_Comm_size(MPI_COMM_WORLD, &ctx->size);
+    err_id = MPI_Comm_size(ctx->world, &ctx->size);
     fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_size");
     // `ctx->size` is used with `malloc()` and `malloc(0)` is implementation-defined.
     assert(ctx->size > 0);
@@ -145,7 +151,7 @@ int fmpi_mpi_world_rank(const struct fmpi_mpi_ctx * const ctx)
 {
     assert(ctx != NULL);
     int rank = -1;
-    int err_id = MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    int err_id = MPI_Comm_rank(ctx->world, &rank);
     if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank") == true) {
         return -1;
     }
@@ -170,7 +176,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx)
 {
     assert(ctx != NULL);
     if(fmpi_mpi_finalized(ctx) == false) {
-        int err_id = MPI_Abort(MPI_COMM_WORLD, MPI_ERR_UNKNOWN);
+        int err_id = MPI_Abort(ctx->world, MPI_ERR_UNKNOWN);
         fmpi_mpi_check_error(ctx, err_id, "MPI_Abort");
     }
 }
@@ -180,7 +186,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx)
 int fmpi_mpi_world_barrier(const struct fmpi_mpi_ctx * const ctx)
 {
     assert(ctx != NULL);
-    const int err_id = MPI_Barrier(MPI_COMM_WORLD);
+    const int err_id = MPI_Barrier(ctx->world);
     if(fmpi_mpi_check_error(ctx, err_id, "MPI_Barrier") == true) {
         return -1;
     }
diff --git a/src/fmpi_reduce.c b/src/fmpi_reduce.c
index f071552..b896109 100644
--- a/src/fmpi_reduce.c
+++ b/src/fmpi_reduce.c
@@ -66,7 +66,7 @@ T fmpi_reduce_prod_##T(const struct fmpi_ctx * const ctx, const T * const array)
     if(fact_global == NULL) { \
         fprintf(stderr, "malloc(fact_global) failed!\n"); \
     } \
-    int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, MPI_COMM_WORLD); \
+    int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, ctx->mpi->world); \
     if(fmpi_mpi_check_error(ctx->mpi, err_id, "MPI_Gather()") == true) { \
         free(fact_global); \
     } \
diff --git a/src/fmpi_task.c b/src/fmpi_task.c
index cd6f59b..cfc1ef6 100644
--- a/src/fmpi_task.c
+++ b/src/fmpi_task.c
@@ -187,13 +187,13 @@ int fmpi_task_finalize(
             MPI_Reduce(
                 MPI_IN_PLACE, task->args.out.raw, (int)task->args.out.cnt,
                 fmpi_mpi_type(task->args.out.type.base), MPI_SUM,
-                ctx->mpi->root, MPI_COMM_WORLD
+                ctx->mpi->root, ctx->mpi->world
             );
         } else {
             MPI_Reduce(
                 task->args.out.raw, task->args.out.raw, (int)task->args.out.cnt,
                 fmpi_mpi_type(task->args.out.type.base), MPI_SUM,
-                ctx->mpi->root, MPI_COMM_WORLD
+                ctx->mpi->root, ctx->mpi->world
             );
         }
     }
-- 
GitLab