From 4513c80d81d5797aba8572ccc899a370596d80ce Mon Sep 17 00:00:00 2001
From: "raphael.bach" <raphael.bach@etu.hesge.ch>
Date: Tue, 19 Jul 2022 02:23:44 +0200
Subject: [PATCH] Add more reduce operations in `fmpi_reduce.fut`

---
 src/futhark/fmpi_reduce.fut | 79 ++++++++++++++++++++++++++++++++++---
 1 file changed, 73 insertions(+), 6 deletions(-)

diff --git a/src/futhark/fmpi_reduce.fut b/src/futhark/fmpi_reduce.fut
index 0bb889e..7e37430 100644
--- a/src/futhark/fmpi_reduce.fut
+++ b/src/futhark/fmpi_reduce.fut
@@ -1,17 +1,84 @@
 -- Default operations for MPI_Reduce()
 -- See https://www.open-mpi.org/doc/current/man3/MPI_Reduce.3.php#sect10
-local module type fmpi_reduce_t = {
+local module type fmpi_reduce_bool_t = {
     type t
-    val sum  [n]: [n]t -> t
-    val prod [n]: [n]t -> t
+    val sum [n]: [n]t -> i64
+    val and [n]: [n]t -> t
+    val or  [n]: [n]t -> t
+    val xor [n]: [n]t -> t
+}
+
+local module type fmpi_reduce_integer_t = {
+    type t
+    val sum  [n]: [n]t -> i64
+    val prod [n]: [n]t -> i64
+    val min  [n]: [n]t -> t
+    val max  [n]: [n]t -> t
+    val and  [n]: [n]t -> t
+    val or   [n]: [n]t -> t
+    val xor  [n]: [n]t -> t
+    val &    [n]: [n]t -> t
+    val |    [n]: [n]t -> t
+    val ^    [n]: [n]t -> t
+}
+
+local module type fmpi_reduce_float_t = {
+    type t
+    val sum  [n]: [n]t -> f64
+    val prod [n]: [n]t -> f64
     val min  [n]: [n]t -> t
     val max  [n]: [n]t -> t
+    val and  [n]: [n]t -> t
+    val or   [n]: [n]t -> t
+    val xor  [n]: [n]t -> t
 }
 
-module fmpi_reduce (T: float): (fmpi_reduce_t with t = T.t) = {
+module fmpi_reduce_bool: (fmpi_reduce_bool_t with t = bool) = {
+    type t = bool
+    def sum [n] (xs: [n]t): i64 = reduce (+) 0 (map (\x -> i64.bool x) xs)
+    def and [n] (xs: [n]t): t = reduce (&&) true xs
+    def or  [n] (xs: [n]t): t = reduce (||) false xs
+    def xor [n] (xs: [n]t): t = reduce (!=) false xs
+}
+
+module fmpi_reduce_integer (T: integral): (fmpi_reduce_integer_t with t = T.t) = {
+    type t = T.t
+    local def and_impl (a: t) (b: t): t = T.bool (a T.!= (T.i32 0) && b T.!= (T.i32 0))
+    local def or_impl  (a: t) (b: t): t = T.bool (a T.!= (T.i32 0) || b T.!= (T.i32 0))
+    local def xor_impl (a: t) (b: t): t = T.bool (a T.!= b)
+    def sum  [n] (xs: [n]t): i64 = T.to_i64 (T.sum xs)
+    def prod [n] (xs: [n]t): i64 = T.to_i64 (T.product xs)
+    def min  [n] (xs: [n]t): t = T.minimum xs
+    def max  [n] (xs: [n]t): t = T.maximum xs
+    def and  [n] (xs: [n]t): t = reduce (and_impl) (T.i32 1) xs
+    def or   [n] (xs: [n]t): t = reduce (or_impl)  (T.i32 0) xs
+    def xor  [n] (xs: [n]t): t = reduce (xor_impl) (T.i32 0) xs
+    def (&)  [n] (xs: [n]t): t = reduce (T.&) (T.i32 1) xs
+    def (|)  [n] (xs: [n]t): t = reduce (T.|) (T.i32 0) xs
+    def (^)  [n] (xs: [n]t): t = reduce (T.^) (T.i32 0) xs
+}
+
+module fmpi_reduce_float (T: float): (fmpi_reduce_float_t with t = T.t) = {
     type t = T.t
-    def sum  [n] (xs: [n]t): t = T.sum xs
-    def prod [n] (xs: [n]t): t = T.product xs
+    local def and_impl (a: t) (b: t): t = T.bool (a T.!= (T.i32 0) && b T.!= (T.i32 0))
+    local def or_impl  (a: t) (b: t): t = T.bool (a T.!= (T.i32 0) || b T.!= (T.i32 0))
+    local def xor_impl (a: t) (b: t): t = T.bool (a T.!= b)
+    def sum  [n] (xs: [n]t): f64 = T.to_f64 (T.sum xs)
+    def prod [n] (xs: [n]t): f64 = T.to_f64 (T.product xs)
     def min  [n] (xs: [n]t): t = T.minimum xs
     def max  [n] (xs: [n]t): t = T.maximum xs
+    def and  [n] (xs: [n]t): t = reduce (and_impl) (T.i32 1) xs
+    def or   [n] (xs: [n]t): t = reduce (or_impl)  (T.i32 0) xs
+    def xor  [n] (xs: [n]t): t = reduce (xor_impl) (T.i32 0) xs
 }
+
+module fmpi_reduce_i8  = fmpi_reduce_integer i8
+module fmpi_reduce_i16 = fmpi_reduce_integer i16
+module fmpi_reduce_i32 = fmpi_reduce_integer i32
+module fmpi_reduce_i64 = fmpi_reduce_integer i64
+module fmpi_reduce_u8  = fmpi_reduce_integer u8
+module fmpi_reduce_u16 = fmpi_reduce_integer u16
+module fmpi_reduce_u32 = fmpi_reduce_integer u32
+module fmpi_reduce_u64 = fmpi_reduce_integer u64
+module fmpi_reduce_f32 = fmpi_reduce_float   f32
+module fmpi_reduce_f64 = fmpi_reduce_float   f64
-- 
GitLab