From b319f1252f2a866177be954f34385584a2a15346 Mon Sep 17 00:00:00 2001
From: "raphael.bach" <raphael.bach@etu.hesge.ch>
Date: Thu, 23 Jun 2022 22:30:55 +0200
Subject: [PATCH] Make `fmpi_task_run_sync()` retrieve output data from futhark

---
 include/fmpi_task.h                          |  1 +
 include/internal/generic/fmpi_task_generic.h |  8 ++++---
 src/fmpi_task.c                              | 22 ++++++++++++++++++--
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/include/fmpi_task.h b/include/fmpi_task.h
index c40dac7..187661f 100644
--- a/include/fmpi_task.h
+++ b/include/fmpi_task.h
@@ -59,6 +59,7 @@
 typedef struct fmpi_task_args {
     struct fmpi_data in[FMPI_TASK_ARGS_MAX]; //!< TODO
     struct fmpi_data out;                        //!< TODO
+    void * out_raw;
     size_t cnt;                                  //!< TODO
 } fmpi_task_args;
 /*------------------------------------------------------------------------------
diff --git a/include/internal/generic/fmpi_task_generic.h b/include/internal/generic/fmpi_task_generic.h
index 9b5bbcc..58baa49 100644
--- a/include/internal/generic/fmpi_task_generic.h
+++ b/include/internal/generic/fmpi_task_generic.h
@@ -51,15 +51,15 @@ _Pragma("GCC diagnostic ignored \"-Wcast-qual\"")\
     if(args->out.type.derived == FMPI_TYPE_ARRAY) { \
         if(args->out.dim_cnt == 1) { \
             CPL_MAP_FIXED(FMPI_PRIV_TASK_RET_FUNC, CPL_EMPTY, \
-                (FUNC, N, 1, ctx->fut->ctx, args->out.start, args->in), FMPI_TYPE_REAL) \
+                (FUNC, N, 1, ctx->fut->ctx, args->out_raw, args->in), FMPI_TYPE_REAL) \
         } \
         if(args->out.dim_cnt == 2) { \
             CPL_MAP_FIXED(FMPI_PRIV_TASK_RET_FUNC, CPL_EMPTY, \
-                (FUNC, N, 2, ctx->fut->ctx, args->out.start, args->in), FMPI_TYPE_REAL) \
+                (FUNC, N, 2, ctx->fut->ctx, args->out_raw, args->in), FMPI_TYPE_REAL) \
         } \
         if(args->out.dim_cnt == 3) { \
             CPL_MAP_FIXED(FMPI_PRIV_TASK_RET_FUNC, CPL_EMPTY, \
-                (FUNC, N, 3, ctx->fut->ctx, args->out.start, args->in), FMPI_TYPE_REAL) \
+                (FUNC, N, 3, ctx->fut->ctx, args->out_raw, args->in), FMPI_TYPE_REAL) \
         } \
     } \
     return futhark_entry_##FUNC(ctx->fut->ctx, args->out.start, FMPI_PRIV_TASK_ARGS_##N(args->in)); \
@@ -80,6 +80,7 @@ _Pragma("GCC diagnostic warning \"-Wincompatible-pointer-types\"")\
 #define FMPI_PRIV_TASK_REGISTER_1(FUNC, ctx, type, stencil, arg_out) \
     fmpi_task_register((ctx), FUNC##_0, #FUNC, (type), (stencil), &(struct fmpi_task_args){ \
         .out = (arg_out), \
+        .out_raw = NULL, \
         .cnt = 0 \
     })
 
@@ -87,6 +88,7 @@ _Pragma("GCC diagnostic warning \"-Wincompatible-pointer-types\"")\
     fmpi_task_register((ctx), FUNC##_##N, #FUNC, (type), (stencil), &(struct fmpi_task_args){ \
         .in = {__VA_ARGS__}, \
         .out = (arg_out), \
+        .out_raw = NULL, \
         .cnt = N \
     })
 
diff --git a/src/fmpi_task.c b/src/fmpi_task.c
index 848365e..5895a18 100644
--- a/src/fmpi_task.c
+++ b/src/fmpi_task.c
@@ -35,6 +35,7 @@
 #include "internal/fmpi_error.h"
 #include "internal/fmpi_futhark.h"
 #include "internal/fmpi_mpi.h"
+#include "internal/fmpi_type.h"
 /*==============================================================================
     PUBLIC FUNCTION DEFINITION
 ==============================================================================*/
@@ -110,8 +111,25 @@ int fmpi_task_run_sync(
     assert(task != NULL);
     const int err_id = task->func(ctx, &task->args);
     fmpi_futhark_sync(ctx->fut);
-    if(fmpi_futhark_check_error(ctx->fut, task->name) == true) {
-        return -1;
+    if(task->args.out.type.derived == FMPI_TYPE_ARRAY) {
+        void * out = fmpi_futhark_get_data_sync(
+            ctx->fut, task->args.out_raw, task->args.out.start,
+            task->args.out.type.base, task->args.out.dim_cnt
+        );
+        if(out == NULL) {
+            FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
+                "fmpi_futhark_get_data_sync() failed!"
+            );
+        }
+        const int err = fmpi_futhark_free_data_sync(
+            ctx->fut, task->args.out_raw, task->args.out.type.base,
+            task->args.out.dim_cnt
+        );
+        if(err != 0) {
+            FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
+                "fmpi_futhark_free_data_sync() failed!"
+            );
+        }
     }
     return err_id;
 }
-- 
GitLab