From 4202ad35d0a15882077c624e2a22c0c540a664b3 Mon Sep 17 00:00:00 2001
From: "raphael.bach" <raphael.bach@etu.hesge.ch>
Date: Mon, 18 Jul 2022 22:24:21 +0200
Subject: [PATCH] Add `fmpi_task_update_input()`

---
 src/fmpi_task.c | 68 ++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 51 insertions(+), 17 deletions(-)

diff --git a/src/fmpi_task.c b/src/fmpi_task.c
index b6b26b0..421e63d 100644
--- a/src/fmpi_task.c
+++ b/src/fmpi_task.c
@@ -37,6 +37,15 @@
 #include "internal/fmpi_futhark.h"
 #include "internal/fmpi_mpi.h"
 #include "internal/fmpi_type.h"
+/*==============================================================================
+    PRIVATE FUNCTION DECLARATION
+==============================================================================*/
+/*------------------------------------------------------------------------------
+    fmpi_task_update_input()
+------------------------------------------------------------------------------*/
+static int fmpi_task_update_input(
+    const struct fmpi_ctx * ctx, struct fmpi_task * const task
+);
 /*==============================================================================
     PUBLIC FUNCTION DEFINITION
 ==============================================================================*/
@@ -127,25 +136,10 @@ int fmpi_task_run_sync(
         }
     }
     if(task->stencil.type != FMPI_STENCIL_NONE) {
-        fmpi_domain_set_inner(ctx, &task->domains[0], task->args.out.raw);
-        fmpi_halo_exchange_1d(ctx, &task->domains[0]);
-        const int err = fmpi_futhark_free_data_sync(
-            ctx->fut, task->args.in[0].gpu, task->domains[0].halo.type.base,
-            task->domains[0].halo.dim_cnt
-        );
+        const int err = fmpi_task_update_input(ctx, task);
         if(err != FMPI_SUCCESS) {
             FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
-                "fmpi_futhark_free_data_sync() failed!"
-            );
-        }
-        task->args.in[0].gpu = fmpi_futhark_new_data_sync(
-            ctx->fut, task->domains[0].halo.raw, task->domains[0].halo.type.base,
-            task->domains[0].halo.dim_cnt, task->domains[0].halo.dim_len[0],
-            task->domains[0].halo.dim_len[1], task->domains[0].halo.dim_len[2]
-        );
-        if(task->args.in[0].gpu == NULL) {
-            FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
-                "fmpi_futhark_new_data_sync() failed!"
+                "fmpi_task_update_input() failed!"
             );
         }
     }
@@ -200,3 +194,43 @@ int fmpi_task_finalize(
     }
     return FMPI_SUCCESS;
 }
+/*==============================================================================
+    PRIVATE FUNCTION DEFINITION
+==============================================================================*/
+/*------------------------------------------------------------------------------
+    fmpi_task_update_input()
+------------------------------------------------------------------------------*/
+static int fmpi_task_update_input(
+    const struct fmpi_ctx * const ctx, struct fmpi_task * const task
+){
+    assert(ctx != NULL);
+    assert(task != NULL);
+
+    fmpi_domain_set_inner(ctx, &task->domains[0], task->args.out.raw);
+    const size_t dim_cnt = task->domains[0].inner.dim_cnt;
+    if(dim_cnt == 1) {
+        fmpi_halo_exchange_1d(ctx, &task->domains[0]);
+    } else if (dim_cnt == 2) {
+        fmpi_halo_exchange_2d(ctx, &task->domains[0]);
+    }
+    const int err = fmpi_futhark_free_data_sync(
+        ctx->fut, task->args.in[0].gpu, task->domains[0].halo.type.base,
+        task->domains[0].halo.dim_cnt
+    );
+    if(err != FMPI_SUCCESS) {
+        FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
+            "fmpi_futhark_free_data_sync() failed!"
+        );
+    }
+    task->args.in[0].gpu = fmpi_futhark_new_data_sync(
+        ctx->fut, task->domains[0].halo.raw, task->domains[0].halo.type.base,
+        task->domains[0].halo.dim_cnt, task->domains[0].halo.dim_len[0],
+        task->domains[0].halo.dim_len[1], task->domains[0].halo.dim_len[2]
+    );
+    if(task->args.in[0].gpu == NULL) {
+        FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
+            "fmpi_futhark_new_data_sync() failed!"
+        );
+    }
+    return err;
+}
-- 
GitLab