diff --git a/include/internal/fmpi_mpi.h b/include/internal/fmpi_mpi.h index 2e9f832740ded562f29dcb378bb958d3ede045ed..34cd5652553e53321d08981c879eb1ee63da9b7b 100644 --- a/include/internal/fmpi_mpi.h +++ b/include/internal/fmpi_mpi.h @@ -50,6 +50,7 @@ typedef struct fmpi_mpi_ctx { char cpu_name[MPI_MAX_PROCESSOR_NAME]; //!< TODO int cpu_name_length; //!< TODO int root; //!< TODO + MPI_Comm world; int futhark_err_class; //!< TODO int futhark_err_code; //!< TODO struct fmpi_error_handler err_handler; //!< TODO diff --git a/src/fmpi_mpi.c b/src/fmpi_mpi.c index 03da923bb3119230c7b1313a4cc43017a8a0e04d..7b730dc63ac45c648f666bb61c0fb7474066ffb4 100644 --- a/src/fmpi_mpi.c +++ b/src/fmpi_mpi.c @@ -61,18 +61,24 @@ struct fmpi_mpi_ctx * fmpi_mpi_init( FMPI_RAISE_ERROR(err_handler, "MPI", "malloc(fmpi_mpi_ctx) failed!"); return NULL; } - ctx->root = FMPI_MPI_ROOT; ctx->err_handler = err_handler; + ctx->root = FMPI_MPI_ROOT; int err_id = MPI_Init(argc, argv); if(fmpi_mpi_check_error(ctx, err_id, "MPI_Init") == true) { return ctx; } - err_id = MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank); + err_id = MPI_Comm_dup(MPI_COMM_WORLD, &ctx->world); + if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_dup") == true) { + free(ctx); + return NULL; + } + + err_id = MPI_Comm_rank(ctx->world, &ctx->rank); fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank"); - err_id = MPI_Comm_size(MPI_COMM_WORLD, &ctx->size); + err_id = MPI_Comm_size(ctx->world, &ctx->size); fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_size"); // `ctx->size` is used with `malloc()` and `malloc(0)` is implementation-defined. assert(ctx->size > 0); @@ -145,7 +151,7 @@ int fmpi_mpi_world_rank(const struct fmpi_mpi_ctx * const ctx) { assert(ctx != NULL); int rank = -1; - int err_id = MPI_Comm_rank(MPI_COMM_WORLD, &rank); + int err_id = MPI_Comm_rank(ctx->world, &rank); if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank") == true) { return -1; } @@ -170,7 +176,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx) { assert(ctx != NULL); if(fmpi_mpi_finalized(ctx) == false) { - int err_id = MPI_Abort(MPI_COMM_WORLD, MPI_ERR_UNKNOWN); + int err_id = MPI_Abort(ctx->world, MPI_ERR_UNKNOWN); fmpi_mpi_check_error(ctx, err_id, "MPI_Abort"); } } @@ -180,7 +186,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx) int fmpi_mpi_world_barrier(const struct fmpi_mpi_ctx * const ctx) { assert(ctx != NULL); - const int err_id = MPI_Barrier(MPI_COMM_WORLD); + const int err_id = MPI_Barrier(ctx->world); if(fmpi_mpi_check_error(ctx, err_id, "MPI_Barrier") == true) { return -1; } diff --git a/src/fmpi_reduce.c b/src/fmpi_reduce.c index f07155294fd951917613876f330f885279f1362e..b8961097a3f600348222472a28e96ac980834bf7 100644 --- a/src/fmpi_reduce.c +++ b/src/fmpi_reduce.c @@ -66,7 +66,7 @@ T fmpi_reduce_prod_##T(const struct fmpi_ctx * const ctx, const T * const array) if(fact_global == NULL) { \ fprintf(stderr, "malloc(fact_global) failed!\n"); \ } \ - int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, MPI_COMM_WORLD); \ + int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, ctx->mpi->world); \ if(fmpi_mpi_check_error(ctx->mpi, err_id, "MPI_Gather()") == true) { \ free(fact_global); \ } \ diff --git a/src/fmpi_task.c b/src/fmpi_task.c index cd6f59b340e771dc064eb1ccb8419ae14f609ff1..cfc1ef6f7010a3b0c46880963b32b41b873aaf8b 100644 --- a/src/fmpi_task.c +++ b/src/fmpi_task.c @@ -187,13 +187,13 @@ int fmpi_task_finalize( MPI_Reduce( MPI_IN_PLACE, task->args.out.raw, (int)task->args.out.cnt, fmpi_mpi_type(task->args.out.type.base), MPI_SUM, - ctx->mpi->root, MPI_COMM_WORLD + ctx->mpi->root, ctx->mpi->world ); } else { MPI_Reduce( task->args.out.raw, task->args.out.raw, (int)task->args.out.cnt, fmpi_mpi_type(task->args.out.type.base), MPI_SUM, - ctx->mpi->root, MPI_COMM_WORLD + ctx->mpi->root, ctx->mpi->world ); } }