diff --git a/include/internal/fmpi_mpi.h b/include/internal/fmpi_mpi.h
index 2e9f832740ded562f29dcb378bb958d3ede045ed..34cd5652553e53321d08981c879eb1ee63da9b7b 100644
--- a/include/internal/fmpi_mpi.h
+++ b/include/internal/fmpi_mpi.h
@@ -50,6 +50,7 @@ typedef struct fmpi_mpi_ctx {
     char cpu_name[MPI_MAX_PROCESSOR_NAME]; //!< TODO
     int cpu_name_length; //!< TODO
     int root; //!< TODO
+    MPI_Comm world; //!< Duplicate of MPI_COMM_WORLD created with MPI_Comm_dup() in fmpi_mpi_init().
     int futhark_err_class; //!< TODO
     int futhark_err_code; //!< TODO
     struct fmpi_error_handler err_handler; //!< TODO
diff --git a/src/fmpi_mpi.c b/src/fmpi_mpi.c
index 03da923bb3119230c7b1313a4cc43017a8a0e04d..7b730dc63ac45c648f666bb61c0fb7474066ffb4 100644
--- a/src/fmpi_mpi.c
+++ b/src/fmpi_mpi.c
@@ -61,18 +61,24 @@ struct fmpi_mpi_ctx * fmpi_mpi_init(
         FMPI_RAISE_ERROR(err_handler, "MPI", "malloc(fmpi_mpi_ctx) failed!");
         return NULL;
     }
-    ctx->root = FMPI_MPI_ROOT;
     ctx->err_handler = err_handler;
+    ctx->root = FMPI_MPI_ROOT;
 
     int err_id = MPI_Init(argc, argv);
     if(fmpi_mpi_check_error(ctx, err_id, "MPI_Init") == true) {
         return ctx;
     }
 
-    err_id = MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
+    err_id = MPI_Comm_dup(MPI_COMM_WORLD, &ctx->world);
+    if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_dup") == true) {
+        free(ctx);
+        return NULL;
+    }
+
+    err_id = MPI_Comm_rank(ctx->world, &ctx->rank);
     fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank");
 
-    err_id = MPI_Comm_size(MPI_COMM_WORLD, &ctx->size);
+    err_id = MPI_Comm_size(ctx->world, &ctx->size);
     fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_size");
     // `ctx->size` is used with `malloc()` and `malloc(0)` is implementation-defined.
     assert(ctx->size > 0);
@@ -145,7 +151,7 @@ int fmpi_mpi_world_rank(const struct fmpi_mpi_ctx * const ctx)
 {
     assert(ctx != NULL);
     int rank = -1;
-    int err_id = MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    int err_id = MPI_Comm_rank(ctx->world, &rank);
     if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank") == true) {
         return -1;
     }
@@ -170,7 +176,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx)
 {
     assert(ctx != NULL);
     if(fmpi_mpi_finalized(ctx) == false) {
-        int err_id = MPI_Abort(MPI_COMM_WORLD, MPI_ERR_UNKNOWN);
+        int err_id = MPI_Abort(ctx->world, MPI_ERR_UNKNOWN);
         fmpi_mpi_check_error(ctx, err_id, "MPI_Abort");
     }
 }
@@ -180,7 +186,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx)
 int fmpi_mpi_world_barrier(const struct fmpi_mpi_ctx * const ctx)
 {
     assert(ctx != NULL);
-    const int err_id = MPI_Barrier(MPI_COMM_WORLD);
+    const int err_id = MPI_Barrier(ctx->world);
     if(fmpi_mpi_check_error(ctx, err_id, "MPI_Barrier") == true) {
         return -1;
     }
diff --git a/src/fmpi_reduce.c b/src/fmpi_reduce.c
index f07155294fd951917613876f330f885279f1362e..b8961097a3f600348222472a28e96ac980834bf7 100644
--- a/src/fmpi_reduce.c
+++ b/src/fmpi_reduce.c
@@ -66,7 +66,7 @@ T fmpi_reduce_prod_##T(const struct fmpi_ctx * const ctx, const T * const array)
     if(fact_global == NULL) { \
         fprintf(stderr, "malloc(fact_global) failed!\n"); \
     } \
-    int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, MPI_COMM_WORLD); \
+    int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, ctx->mpi->world); \
     if(fmpi_mpi_check_error(ctx->mpi, err_id, "MPI_Gather()") == true) { \
         free(fact_global); \
     } \
diff --git a/src/fmpi_task.c b/src/fmpi_task.c
index cd6f59b340e771dc064eb1ccb8419ae14f609ff1..cfc1ef6f7010a3b0c46880963b32b41b873aaf8b 100644
--- a/src/fmpi_task.c
+++ b/src/fmpi_task.c
@@ -187,13 +187,13 @@ int fmpi_task_finalize(
             MPI_Reduce(
                 MPI_IN_PLACE, task->args.out.raw, (int)task->args.out.cnt,
                 fmpi_mpi_type(task->args.out.type.base), MPI_SUM,
-                ctx->mpi->root, MPI_COMM_WORLD
+                ctx->mpi->root, ctx->mpi->world
             );
         } else {
             MPI_Reduce(
                 task->args.out.raw, task->args.out.raw, (int)task->args.out.cnt,
                 fmpi_mpi_type(task->args.out.type.base), MPI_SUM,
-                ctx->mpi->root, MPI_COMM_WORLD
+                ctx->mpi->root, ctx->mpi->world
             );
         }
     }