// SPDX-License-Identifier: 0BSD
/*!
 * @file
 * @license{
 * BSD Zero Clause License
 *
 * Copyright (c) 2022 by Raphael Bach
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
 * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
 * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
 * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
 * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 * PERFORMANCE OF THIS SOFTWARE.
 * }
 */
/*==============================================================================
    INCLUDE
==============================================================================*/
// Own header
#include "fmpi_task.h"
// C Standard Library
#include <assert.h>
#include <stdbool.h>
#include <stddef.h> // NULL, size_t
#include <stdio.h>
// Internal
#include "fmpi_domain.h"
#include "fmpi_stencil.h"
#include "internal/fmpi_ctx.h"
#include "internal/fmpi_error.h"
#include "internal/fmpi_futhark.h"
#include "internal/fmpi_mpi.h"
#include "internal/fmpi_type.h"
/*==============================================================================
    PUBLIC FUNCTION DEFINITION
==============================================================================*/
/*------------------------------------------------------------------------------
    fmpi_task_register()
------------------------------------------------------------------------------*/
struct fmpi_task fmpi_task_register(
    struct fmpi_ctx * const ctx, const fmpi_task_func func,
    const char * const name, const enum fmpi_task_type type,
    const struct fmpi_stencil stencil, const struct fmpi_task_args * const args
){
    assert(ctx != NULL);
    assert(ctx->task_cnt < FMPI_TASK_MAX);
    assert(name != NULL);
    assert(args != NULL);

    struct fmpi_task task = {
        .func = func,
        .name = name,
        .type = type,
        .args = *args,
        .stencil = stencil
    };
    for(size_t i = 0; i < task.args.cnt; i++) {
        task.domains[i] = fmpi_new_domain(ctx, &args->in[i], stencil);
        const struct fmpi_data * const data = (stencil.type != FMPI_STENCIL_NONE)
            ? &task.domains[i].halo
            : &task.domains[i].inner;
        void * gpu_data = NULL;
        const char * func_name = NULL;
        if(task.type == FMPI_TASK_TYPE_SYNC) {
        //! @todo Could fmpi_futhark_new_data_async() be called here instead?
            gpu_data = fmpi_futhark_new_data_sync(
                ctx->fut, data->raw, data->type.base, data->dim_cnt,
                data->dim_len[0], data->dim_len[1], data->dim_len[2]
            );
            func_name = "fmpi_futhark_new_data_sync";
        } else {
            gpu_data = fmpi_futhark_new_data_async(
                ctx->fut, data->raw, data->type.base, data->dim_cnt,
                data->dim_len[0], data->dim_len[1], data->dim_len[2]
            );
            func_name = "fmpi_futhark_new_data_async";
        }
        if(gpu_data == NULL) {
            FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
                "%s() failed!", func_name
            );
        }
        task.args.in[i].gpu = gpu_data;
    }
    ctx->tasks[ctx->task_cnt++] = task;
    if(task.type == FMPI_TASK_TYPE_SYNC) {
        fmpi_futhark_sync(ctx->fut);
    }
    return task;
}
/*------------------------------------------------------------------------------
    fmpi_task_run_sync()
------------------------------------------------------------------------------*/
int fmpi_task_run_sync(
    const struct fmpi_ctx * const ctx, const struct fmpi_task * const task
){
    assert(ctx != NULL);
    assert(task != NULL);

    const int err_id = task->func(ctx, &task->args);
    fmpi_futhark_sync(ctx->fut);
    fmpi_futhark_check_error(ctx->fut, "task->func");
    if(task->args.out.type.derived == FMPI_TYPE_ARRAY) {
        void * out = fmpi_futhark_get_data_sync(
            ctx->fut, task->args.out.gpu, task->args.out.raw,
            task->args.out.type.base, task->args.out.dim_cnt
        );
        if(out == NULL) {
            FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
                "fmpi_futhark_get_data_sync() failed!"
            );
        }
        const int err = fmpi_futhark_free_data_sync(
            ctx->fut, task->args.out.gpu, task->args.out.type.base,
            task->args.out.dim_cnt
        );
        if(err != 0) {
            FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
                "fmpi_futhark_free_data_sync() failed!"
            );
        }
    }
    return err_id;
}
/*------------------------------------------------------------------------------
    fmpi_task_run_async()
------------------------------------------------------------------------------*/
/*!
 * Runs a task's function without any synchronization or output retrieval.
 *
 * @param ctx  Context forwarded to the task's function; must not be NULL.
 * @param task Task to run; must not be NULL.
 *
 * @return The value returned by the task's function.
 */
int fmpi_task_run_async(
    const struct fmpi_ctx * const ctx, const struct fmpi_task * const task
){
    assert(ctx != NULL);
    assert(task != NULL);

    const int func_status = task->func(ctx, &task->args);
    return func_status;
}
/*------------------------------------------------------------------------------
    fmpi_task_finalize()
------------------------------------------------------------------------------*/
/*!
 * Applies a final collective operation to a task's output buffer.
 *
 *  - `FMPI_TASK_OP_SUM`: calls fmpi_mpi_world_reduce_in_place() with `MPI_SUM`
 *    on `task->args.out.raw` over `task->args.out.cnt` elements.
 *  - `FMPI_TASK_OP_GATHER`: calls fmpi_mpi_world_gather_in_place() on
 *    `task->args.out.raw`, using the element count of the first domain's
 *    inner data for both count arguments.
 *  - `FMPI_TASK_OP_NONE` and any other value: nothing is done.
 *
 * @param ctx  Context providing the MPI handle and the error handler; must
 *             not be NULL.
 * @param task Task whose output is finalized; must not be NULL.
 * @param op   Operation to apply.
 *
 * @return `FMPI_SUCCESS` when nothing had to be done, otherwise the status of
 *         the MPI wrapper that was called. A status other than `FMPI_SUCCESS`
 *         is also reported through `ctx->err_handler`.
 */
int fmpi_task_finalize(
    const struct fmpi_ctx * const ctx, const struct fmpi_task * const task,
    const enum fmpi_task_op op
){
    assert(ctx != NULL);
    assert(task != NULL);

    switch(op) {
    case FMPI_TASK_OP_SUM: {
        const int status = fmpi_mpi_world_reduce_in_place(
            ctx->mpi, task->args.out.raw, task->args.out.cnt,
            fmpi_mpi_type(task->args.out.type.base), MPI_SUM
        );
        if(status != FMPI_SUCCESS) {
            FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
                "fmpi_mpi_world_reduce_in_place() failed!"
            );
        }
        return status;
    }
    case FMPI_TASK_OP_GATHER: {
        // The same count (first domain's inner element count) is passed for
        // both count arguments of the gather wrapper.
        const size_t inner_cnt = task->domains[0].inner.cnt;
        MPI_Datatype mpi_type = fmpi_mpi_type(task->args.out.type.base);
        const int status = fmpi_mpi_world_gather_in_place(
            ctx->mpi, task->args.out.raw, mpi_type, inner_cnt, inner_cnt
        );
        if(status != FMPI_SUCCESS) {
            FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
                "fmpi_mpi_world_gather_in_place() failed!"
            );
        }
        return status;
    }
    case FMPI_TASK_OP_NONE:
    default:
        // Nothing to do for FMPI_TASK_OP_NONE or an unrecognized operation.
        break;
    }
    return FMPI_SUCCESS;
}