diff --git a/src/fmpi_task.c b/src/fmpi_task.c
index b6b26b03591269a191ccacf17dd413963aa34ded..421e63d6525b530dc629632818f5997010378b6e 100644
--- a/src/fmpi_task.c
+++ b/src/fmpi_task.c
@@ -37,6 +37,15 @@
 #include "internal/fmpi_futhark.h"
 #include "internal/fmpi_mpi.h"
 #include "internal/fmpi_type.h"
+/*==============================================================================
+    PRIVATE FUNCTION DECLARATION
+==============================================================================*/
+/*------------------------------------------------------------------------------
+    fmpi_task_update_input()
+------------------------------------------------------------------------------*/
+static int fmpi_task_update_input(
+    const struct fmpi_ctx * ctx, struct fmpi_task * const task
+);
 /*==============================================================================
     PUBLIC FUNCTION DEFINITION
 ==============================================================================*/
@@ -127,25 +136,10 @@ int fmpi_task_run_sync(
         }
     }
     if(task->stencil.type != FMPI_STENCIL_NONE) {
-        fmpi_domain_set_inner(ctx, &task->domains[0], task->args.out.raw);
-        fmpi_halo_exchange_1d(ctx, &task->domains[0]);
-        const int err = fmpi_futhark_free_data_sync(
-            ctx->fut, task->args.in[0].gpu, task->domains[0].halo.type.base,
-            task->domains[0].halo.dim_cnt
-        );
+        const int err = fmpi_task_update_input(ctx, task);
         if(err != FMPI_SUCCESS) {
             FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
-                "fmpi_futhark_free_data_sync() failed!"
-            );
-        }
-        task->args.in[0].gpu = fmpi_futhark_new_data_sync(
-            ctx->fut, task->domains[0].halo.raw, task->domains[0].halo.type.base,
-            task->domains[0].halo.dim_cnt, task->domains[0].halo.dim_len[0],
-            task->domains[0].halo.dim_len[1], task->domains[0].halo.dim_len[2]
-        );
-        if(task->args.in[0].gpu == NULL) {
-            FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
-                "fmpi_futhark_new_data_sync() failed!"
+                "fmpi_task_update_input() failed!"
             );
         }
     }
@@ -200,3 +194,43 @@ int fmpi_task_finalize(
     }
     return FMPI_SUCCESS;
 }
+/*==============================================================================
+    PRIVATE FUNCTION DEFINITION
+==============================================================================*/
+/*------------------------------------------------------------------------------
+    fmpi_task_update_input()
+------------------------------------------------------------------------------*/
+static int fmpi_task_update_input(
+    const struct fmpi_ctx * const ctx, struct fmpi_task * const task
+){
+    assert(ctx != NULL);
+    assert(task != NULL);
+
+    fmpi_domain_set_inner(ctx, &task->domains[0], task->args.out.raw);
+    const size_t dim_cnt = task->domains[0].inner.dim_cnt;
+    if(dim_cnt == 1) {
+        fmpi_halo_exchange_1d(ctx, &task->domains[0]);
+    } else if (dim_cnt == 2) {
+        fmpi_halo_exchange_2d(ctx, &task->domains[0]);
+    }
+    const int err = fmpi_futhark_free_data_sync(
+        ctx->fut, task->args.in[0].gpu, task->domains[0].halo.type.base,
+        task->domains[0].halo.dim_cnt
+    );
+    if(err != FMPI_SUCCESS) {
+        FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
+            "fmpi_futhark_free_data_sync() failed!"
+        );
+    }
+    task->args.in[0].gpu = fmpi_futhark_new_data_sync(
+        ctx->fut, task->domains[0].halo.raw, task->domains[0].halo.type.base,
+        task->domains[0].halo.dim_cnt, task->domains[0].halo.dim_len[0],
+        task->domains[0].halo.dim_len[1], task->domains[0].halo.dim_len[2]
+    );
+    if(task->args.in[0].gpu == NULL) {
+        FMPI_RAISE_ERROR(ctx->err_handler, "FMPI",
+            "fmpi_futhark_new_data_sync() failed!"
+        );
+    }
+    return err;
+}