// SPDX-License-Identifier: 0BSD
/*!
 * @file
 * @license{
 * BSD Zero Clause License
 *
 * Copyright (c) 2022 by Raphael Bach
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
 * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
 * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
 * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
 * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 * PERFORMANCE OF THIS SOFTWARE.
 * }
 */
/*==============================================================================
    INCLUDE
==============================================================================*/
// Own header
#include "internal/fmpi_mpi.h"
// C Standard Library
#include <assert.h>
#include <stdbool.h>
#include <stdio.h>  // printf()
#include <stdlib.h> // NULL, free(), malloc()
// External
#include <mpi.h>
// Internal
#include "internal/fmpi_error.h"
/*==============================================================================
    DEFINE
==============================================================================*/
// Rank of the process used as root; fmpi_mpi_init() stores it in `ctx->root`.
#define FMPI_MPI_ROOT (0)
/*==============================================================================
    MACRO
==============================================================================*/
/*
 * Report an "MPI"-tagged error through the error handler stored in `ctx`.
 * Nothing is reported when the handler has no function installed
 * (`err_handler.func == NULL`). The remaining arguments are forwarded as-is to
 * FMPI_RAISE_ERROR().
 * NOTE: `ctx` is expanded more than once; pass a side-effect-free expression.
 */
#define FMPI_RAISE_MPI_ERROR(ctx, ...) \
do { \
    if((ctx)->err_handler.func == NULL) { \
        break; \
    } \
    FMPI_RAISE_ERROR((ctx)->err_handler, "MPI", __VA_ARGS__); \
} while(0)
/*==============================================================================
    PUBLIC FUNCTION DEFINITION
==============================================================================*/
/*------------------------------------------------------------------------------
    fmpi_mpi_init()
------------------------------------------------------------------------------*/
/*
 * Allocate an MPI context and initialise MPI for this process.
 *
 * Stores `err_handler` and FMPI_MPI_ROOT in the context, calls MPI_Init() with
 * `argc`/`argv`, then queries the rank and size of MPI_COMM_WORLD and the
 * processor name. MPI failures are reported through fmpi_mpi_check_error().
 *
 * Returns NULL only if allocating the context fails. If MPI_Init() fails the
 * context is still returned (rank -1, size 0, all other fields zeroed) so the
 * caller keeps access to the error handler. Release it with fmpi_mpi_exit().
 */
struct fmpi_mpi_ctx * fmpi_mpi_init(
    int * const argc, char **  argv[],
    const struct fmpi_error_handler err_handler
){
    // calloc() rather than malloc(): fields this function never writes (e.g.
    // futhark_err_class/code, printed by fmpi_mpi_ctx_print()) or that stay
    // unset because an MPI call failed must not hold indeterminate values.
    struct fmpi_mpi_ctx * ctx = calloc(1, sizeof(*ctx));
    if(ctx == NULL) {
        FMPI_RAISE_ERROR(err_handler, "MPI", "malloc(fmpi_mpi_ctx) failed!");
        return NULL;
    }
    ctx->root = FMPI_MPI_ROOT;
    ctx->err_handler = err_handler;
    // Sentinels until MPI provides real values; -1 is also what
    // fmpi_mpi_world_rank() returns on failure.
    ctx->rank = -1;
    ctx->size = 0;

    int err_id = MPI_Init(argc, argv);
    if(fmpi_mpi_check_error(ctx, err_id, "MPI_Init") == true) {
        return ctx;
    }

    err_id = MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
    fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_rank");

    err_id = MPI_Comm_size(MPI_COMM_WORLD, &ctx->size);
    if(fmpi_mpi_check_error(ctx, err_id, "MPI_Comm_size") == false) {
        // `ctx->size` is used with `malloc()` and `malloc(0)` is
        // implementation-defined. Only validate a value MPI actually returned.
        assert(ctx->size > 0);
    }

    err_id = MPI_Get_processor_name(ctx->cpu_name, &ctx->cpu_name_length);
    fmpi_mpi_check_error(ctx, err_id, "MPI_Get_processor_name");

    return ctx;
}
/*------------------------------------------------------------------------------
    fmpi_mpi_exit()
------------------------------------------------------------------------------*/
/*
 * Finalize MPI and release the context created by fmpi_mpi_init().
 *
 * `*ctx` is freed and set to NULL regardless of the outcome.
 * Returns the MPI_Finalize() status on success, or -1 if it reported an error.
 */
int fmpi_mpi_exit(struct fmpi_mpi_ctx ** const ctx)
{
    assert(ctx != NULL);
    assert(*ctx != NULL);

    const int finalize_status = MPI_Finalize();
    // Report through the handler before the context (and its handler) is freed.
    const _Bool finalize_failed =
        fmpi_mpi_check_error(*ctx, finalize_status, "MPI_Finalize");

    free(*ctx);
    *ctx = NULL;

    return finalize_failed ? -1 : finalize_status;
}
/*------------------------------------------------------------------------------
    fmpi_mpi_is_root()
------------------------------------------------------------------------------*/
/*
 * Tell whether this process is the root, i.e. whether the rank stored in the
 * context equals the context's root rank.
 */
_Bool fmpi_mpi_is_root(const struct fmpi_mpi_ctx * const ctx)
{
    assert(ctx != NULL);
    const _Bool rank_is_root = (ctx->root == ctx->rank);
    return rank_is_root;
}
/*------------------------------------------------------------------------------
    fmpi_mpi_check_error()
------------------------------------------------------------------------------*/
/*
 * Check an MPI return code and report failures through the context's handler.
 *
 * @param ctx    Context whose error handler receives the report (non-NULL).
 * @param err_id Return code of the MPI call being checked.
 * @param func   Name of the MPI function that produced `err_id` (non-NULL).
 * @return true if `err_id` is not MPI_SUCCESS (an error was reported),
 *         false otherwise.
 */
_Bool fmpi_mpi_check_error(
    const struct fmpi_mpi_ctx * ctx, int err_id, const char * const func
){
    assert(ctx != NULL);
    assert(func != NULL);
    if(err_id == MPI_SUCCESS) {
        return false;
    }
    char err_str[MPI_MAX_ERROR_STRING] = {'\0'};
    int err_str_len = 0;
    // Use a separate variable so the original `err_id` is still available
    // when no textual description can be obtained.
    const int str_err_id = MPI_Error_string(err_id, err_str, &err_str_len);
    if(str_err_id != MPI_SUCCESS) {
        FMPI_RAISE_MPI_ERROR(ctx, "MPI_Error_string() failed!");
        // No description available: report the raw code instead of an empty
        // message so the failure remains diagnosable.
        FMPI_RAISE_MPI_ERROR(ctx, "%s() failed! (MPI error code %d)", func, err_id);
    } else {
        FMPI_RAISE_MPI_ERROR(ctx, "%s() failed! %s", func, err_str);
    }
    return true;
}
/*------------------------------------------------------------------------------
    fmpi_mpi_ctx_print()
------------------------------------------------------------------------------*/
/*
 * Print the fields of an MPI context to stdout, one per line.
 */
void fmpi_mpi_ctx_print(const struct fmpi_mpi_ctx * const ctx)
{
    assert(ctx != NULL);
    printf("MPI size: %d\n", ctx->size);
    printf("MPI rank: %d\n", ctx->rank);
    printf("MPI CPU name: %s\n", ctx->cpu_name);
    // The root is a rank number; it was previously printed from `cpu_name`.
    printf("MPI root: %d\n", ctx->root);
    printf("MPI futhark error class: %d\n", ctx->futhark_err_class);
    printf("MPI futhark error code: %d\n", ctx->futhark_err_code);
}
/*------------------------------------------------------------------------------
    fmpi_mpi_world_rank()
------------------------------------------------------------------------------*/
/*
 * Query this process's rank in MPI_COMM_WORLD.
 * Returns the rank, or -1 if MPI_Comm_rank() reports an error (which is
 * reported through the context's handler).
 */
int fmpi_mpi_world_rank(const struct fmpi_mpi_ctx * const ctx)
{
    assert(ctx != NULL);
    int world_rank = -1;
    const int status = MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
    const _Bool failed = fmpi_mpi_check_error(ctx, status, "MPI_Comm_rank");
    return failed ? -1 : world_rank;
}
/*------------------------------------------------------------------------------
    fmpi_mpi_finalized()
------------------------------------------------------------------------------*/
/*
 * Tell whether MPI_Finalize() has already been called.
 *
 * `ctx` may be NULL; in that case an MPI_Finalized() error is not reported.
 * If the query fails, the flag keeps its initial value, so false is returned.
 */
_Bool fmpi_mpi_finalized(const struct fmpi_mpi_ctx * const ctx)
{
    int flag = 0;
    const int status = MPI_Finalized(&flag);
    if(ctx == NULL) {
        return flag != 0;
    }
    fmpi_mpi_check_error(ctx, status, "MPI_Finalized");
    return flag != 0;
}
/*------------------------------------------------------------------------------
    fmpi_mpi_abort()
------------------------------------------------------------------------------*/
/*
 * Abort every process of MPI_COMM_WORLD with MPI_ERR_UNKNOWN as error code.
 * Does nothing if MPI has already been finalized.
 */
void fmpi_mpi_abort(const struct fmpi_mpi_ctx * const ctx)
{
    assert(ctx != NULL);
    if(fmpi_mpi_finalized(ctx) == true) {
        return;
    }
    const int status = MPI_Abort(MPI_COMM_WORLD, MPI_ERR_UNKNOWN);
    fmpi_mpi_check_error(ctx, status, "MPI_Abort");
}