Skip to content
Snippets Groups Projects
Verified Commit bda86fe4 authored by raphael.bach's avatar raphael.bach
Browse files

Duplicate `MPI_COMM_WORLD`

parent b47e8600
No related branches found
No related tags found
No related merge requests found
......@@ -50,6 +50,7 @@ typedef struct fmpi_mpi_ctx {
char cpu_name[MPI_MAX_PROCESSOR_NAME]; //!< TODO
int cpu_name_length; //!< TODO
int root; //!< TODO
MPI_Comm world;
int futhark_err_class; //!< TODO
int futhark_err_code; //!< TODO
struct fmpi_error_handler err_handler; //!< TODO
......
......@@ -61,18 +61,24 @@ struct fmpi_mpi_ctx * fmpi_mpi_init(
FMPI_RAISE_ERROR(err_handler, "MPI", "malloc(fmpi_mpi_ctx) failed!");
return NULL;
}
ctx->root = FMPI_MPI_ROOT;
ctx->err_handler = err_handler;
ctx->root = FMPI_MPI_ROOT;
int err_id = MPI_Init(argc, argv);
if(fmpi_mpi_check_error(ctx, err_id, "MPI_Init") == true) {
return ctx;
}
err_id = MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
err_id = MPI_Comm_dup(MPI_COMM_WORLD, &ctx->world);
if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_dup") == true) {
free(ctx);
return NULL;
}
err_id = MPI_Comm_rank(ctx->world, &ctx->rank);
fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank");
err_id = MPI_Comm_size(MPI_COMM_WORLD, &ctx->size);
err_id = MPI_Comm_size(ctx->world, &ctx->size);
fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_size");
// `ctx->size` is used with `malloc()` and `malloc(0)` is implementation-defined.
assert(ctx->size > 0);
......@@ -145,7 +151,7 @@ int fmpi_mpi_world_rank(const struct fmpi_mpi_ctx * const ctx)
{
assert(ctx != NULL);
int rank = -1;
int err_id = MPI_Comm_rank(MPI_COMM_WORLD, &rank);
int err_id = MPI_Comm_rank(ctx->world, &rank);
if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank") == true) {
return -1;
}
......@@ -170,7 +176,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx)
{
assert(ctx != NULL);
if(fmpi_mpi_finalized(ctx) == false) {
int err_id = MPI_Abort(MPI_COMM_WORLD, MPI_ERR_UNKNOWN);
int err_id = MPI_Abort(ctx->world, MPI_ERR_UNKNOWN);
fmpi_mpi_check_error(ctx, err_id, "MPI_Abort");
}
}
......@@ -180,7 +186,7 @@ void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx)
int fmpi_mpi_world_barrier(const struct fmpi_mpi_ctx * const ctx)
{
assert(ctx != NULL);
const int err_id = MPI_Barrier(MPI_COMM_WORLD);
const int err_id = MPI_Barrier(ctx->world);
if(fmpi_mpi_check_error(ctx, err_id, "MPI_Barrier") == true) {
return -1;
}
......
......@@ -66,7 +66,7 @@ T fmpi_reduce_prod_##T(const struct fmpi_ctx * const ctx, const T * const array)
if(fact_global == NULL) { \
fprintf(stderr, "malloc(fact_global) failed!\n"); \
} \
int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, MPI_COMM_WORLD); \
int err_id = MPI_Gather(array, 1, FMPI_TYPE_MPI_##T, fact_global, 1, FMPI_TYPE_MPI_##T, ctx->mpi->root, ctx->mpi->world); \
if(fmpi_mpi_check_error(ctx->mpi, err_id, "MPI_Gather()") == true) { \
free(fact_global); \
} \
......
......@@ -187,13 +187,13 @@ int fmpi_task_finalize(
MPI_Reduce(
MPI_IN_PLACE, task->args.out.raw, (int)task->args.out.cnt,
fmpi_mpi_type(task->args.out.type.base), MPI_SUM,
ctx->mpi->root, MPI_COMM_WORLD
ctx->mpi->root, ctx->mpi->world
);
} else {
MPI_Reduce(
task->args.out.raw, task->args.out.raw, (int)task->args.out.cnt,
fmpi_mpi_type(task->args.out.type.base), MPI_SUM,
ctx->mpi->root, MPI_COMM_WORLD
ctx->mpi->root, ctx->mpi->world
);
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment