diff --git a/include/internal/fmpi_mpi.h b/include/internal/fmpi_mpi.h index 34cd5652553e53321d08981c879eb1ee63da9b7b..3f09b4391648888e344eb75eff4ba9787456f772 100644 --- a/include/internal/fmpi_mpi.h +++ b/include/internal/fmpi_mpi.h @@ -209,6 +209,14 @@ int fmpi_mpi_world_barrier(const struct fmpi_mpi_ctx * ctx); fmpi_mpi_type() ------------------------------------------------------------------------------*/ MPI_Datatype fmpi_mpi_type(enum fmpi_type_base type); +/*------------------------------------------------------------------------------ + fmpi_mpi_gather_in_place() +------------------------------------------------------------------------------*/ +int fmpi_mpi_gather_in_place( + const struct fmpi_mpi_ctx * ctx, const void * send_buf, size_t send_cnt, + MPI_Datatype send_type, void * recv_buf, size_t recv_cnt, + MPI_Datatype recv_type, int root, MPI_Comm comm +); /*============================================================================== GUARD ==============================================================================*/ diff --git a/src/fmpi_mpi.c b/src/fmpi_mpi.c index 7b730dc63ac45c648f666bb61c0fb7474066ffb4..cb852f4af794dd1d91ff0f74ed37efcac60d0d88 100644 --- a/src/fmpi_mpi.c +++ b/src/fmpi_mpi.c @@ -25,6 +25,7 @@ #include "internal/fmpi_mpi.h" // C Standard Library #include <assert.h> +#include <limits.h> // INT_MAX #include <stdbool.h> #include <stdio.h> // printf() #include <stdlib.h> // NULL, free(), malloc() @@ -212,3 +213,39 @@ MPI_Datatype fmpi_mpi_type(const enum fmpi_type_base type) }; return fmpi_mpi_type_list[type]; } +/*------------------------------------------------------------------------------ + fmpi_mpi_gather_in_place() +------------------------------------------------------------------------------*/ +int fmpi_mpi_gather_in_place( + const struct fmpi_mpi_ctx * const ctx, + const void * const send_buf, const size_t send_cnt, MPI_Datatype send_type, + void * const recv_buf, const size_t recv_cnt, MPI_Datatype recv_type, + const int root, MPI_Comm comm +) { + assert(ctx != NULL); + assert(send_buf != NULL); + assert(recv_buf != NULL); + assert(send_cnt <= INT_MAX); + assert(recv_cnt <= INT_MAX); + if(send_type == recv_type) { + assert(send_cnt <= recv_cnt); + } + int err = 0; + if(ctx->rank == root) { + err = MPI_Gather( + MPI_IN_PLACE, (int)send_cnt, send_type, + recv_buf , (int)recv_cnt, recv_type, + root, comm + ); + } else { + err = MPI_Gather( + send_buf, (int)send_cnt, send_type, + recv_buf, (int)recv_cnt, recv_type, + root, comm + ); + } + if(fmpi_mpi_check_error(ctx, err, "MPI_Gather") == true) { + return -1; + } + return 0; +}