From 4513c80d81d5797aba8572ccc899a370596d80ce Mon Sep 17 00:00:00 2001
From: "raphael.bach" <raphael.bach@etu.hesge.ch>
Date: Tue, 19 Jul 2022 02:23:44 +0200
Subject: [PATCH] Add more reduce operations in `fmpi_reduce.fut`

---
Review notes (this section is ignored by `git am`):
- Line breaks of the patch were restored; indentation is normalised to 4 spaces.
- Logical `and`/`or`/`xor` on numeric types now reduce over `bool`.
- Bitwise `&` now uses the all-ones neutral element instead of 1.
- The blob ids on the `index` line are the ones of the original commit.

 src/futhark/fmpi_reduce.fut | 92 ++++++++++++++++++++++++++++++++++---
 1 file changed, 86 insertions(+), 6 deletions(-)

diff --git a/src/futhark/fmpi_reduce.fut b/src/futhark/fmpi_reduce.fut
index 0bb889e..7e37430 100644
--- a/src/futhark/fmpi_reduce.fut
+++ b/src/futhark/fmpi_reduce.fut
@@ -1,17 +1,97 @@
 -- Default operations for MPI_Reduce()
 -- See https://www.open-mpi.org/doc/current/man3/MPI_Reduce.3.php#sect10
-local module type fmpi_reduce_t = {
+
+-- Reductions over `bool` arrays: `sum` adds up `i64.bool` of each element,
+-- `and`/`or`/`xor` are the logical reductions.
+local module type fmpi_reduce_bool_t = {
     type t
-    val sum [n]: [n]t -> t
-    val prod [n]: [n]t -> t
+    val sum [n]: [n]t -> i64
+    val and [n]: [n]t -> t
+    val or [n]: [n]t -> t
+    val xor [n]: [n]t -> t
+}
+
+-- Reductions over integer arrays.
+-- `and`/`or`/`xor` are logical: non-zero counts as true and the outcome is
+-- converted back to `t`. `&`/`|`/`^` are the bitwise reductions.
+local module type fmpi_reduce_integer_t = {
+    type t
+    val sum [n]: [n]t -> i64
+    val prod [n]: [n]t -> i64
+    val min [n]: [n]t -> t
+    val max [n]: [n]t -> t
+    val and [n]: [n]t -> t
+    val or [n]: [n]t -> t
+    val xor [n]: [n]t -> t
+    val & [n]: [n]t -> t
+    val | [n]: [n]t -> t
+    val ^ [n]: [n]t -> t
+}
+
+-- Reductions over floating-point arrays.
+-- `and`/`or`/`xor` are logical: non-zero counts as true, outcome converted to `t`.
+local module type fmpi_reduce_float_t = {
+    type t
+    val sum [n]: [n]t -> f64
+    val prod [n]: [n]t -> f64
     val min [n]: [n]t -> t
     val max [n]: [n]t -> t
+    val and [n]: [n]t -> t
+    val or [n]: [n]t -> t
+    val xor [n]: [n]t -> t
 }
 
-module fmpi_reduce (T: float): (fmpi_reduce_t with t = T.t) = {
+module fmpi_reduce_bool: (fmpi_reduce_bool_t with t = bool) = {
+    type t = bool
+    def sum [n] (xs: [n]t): i64 = reduce (+) 0 (map (\x -> i64.bool x) xs)
+    def and [n] (xs: [n]t): t = reduce (&&) true xs
+    def or [n] (xs: [n]t): t = reduce (||) false xs
+    def xor [n] (xs: [n]t): t = reduce (!=) false xs
+}
+
+module fmpi_reduce_integer (T: integral): (fmpi_reduce_integer_t with t = T.t) = {
+    type t = T.t
+    -- Truth value of an element: anything other than zero is true.
+    local def truthy (x: t): bool = x T.!= (T.i32 0)
+    -- NOTE: `sum`/`prod` reduce in `t` and only widen the final result, so an
+    -- overflow in `t` is not avoided by the `i64` result type.
+    def sum [n] (xs: [n]t): i64 = T.to_i64 (T.sum xs)
+    def prod [n] (xs: [n]t): i64 = T.to_i64 (T.product xs)
+    def min [n] (xs: [n]t): t = T.minimum xs
+    def max [n] (xs: [n]t): t = T.maximum xs
+    -- Logical reductions run on `bool`, where `&&`/`||`/`!=` are associative
+    -- and `true`/`false` are real neutral elements, as `reduce` requires.
+    def and [n] (xs: [n]t): t = T.bool (reduce (&&) true (map truthy xs))
+    def or [n] (xs: [n]t): t = T.bool (reduce (||) false (map truthy xs))
+    def xor [n] (xs: [n]t): t = T.bool (reduce (!=) false (map truthy xs))
+    -- The neutral element of bitwise AND has every bit set: -1 converted to `t`.
+    def (&) [n] (xs: [n]t): t = reduce (T.&) (T.i32 (-1)) xs
+    def (|) [n] (xs: [n]t): t = reduce (T.|) (T.i32 0) xs
+    def (^) [n] (xs: [n]t): t = reduce (T.^) (T.i32 0) xs
+}
+
+module fmpi_reduce_float (T: float): (fmpi_reduce_float_t with t = T.t) = {
     type t = T.t
-    def sum [n] (xs: [n]t): t = T.sum xs
-    def prod [n] (xs: [n]t): t = T.product xs
+    -- Truth value of an element: anything other than zero is true.
+    local def truthy (x: t): bool = x T.!= (T.i32 0)
+    def sum [n] (xs: [n]t): f64 = T.to_f64 (T.sum xs)
+    def prod [n] (xs: [n]t): f64 = T.to_f64 (T.product xs)
     def min [n] (xs: [n]t): t = T.minimum xs
     def max [n] (xs: [n]t): t = T.maximum xs
+    -- Logical reductions run on `bool` so that `reduce` gets an associative
+    -- operator with a real neutral element.
+    def and [n] (xs: [n]t): t = T.bool (reduce (&&) true (map truthy xs))
+    def or [n] (xs: [n]t): t = T.bool (reduce (||) false (map truthy xs))
+    def xor [n] (xs: [n]t): t = T.bool (reduce (!=) false (map truthy xs))
 }
+
+module fmpi_reduce_i8 = fmpi_reduce_integer i8
+module fmpi_reduce_i16 = fmpi_reduce_integer i16
+module fmpi_reduce_i32 = fmpi_reduce_integer i32
+module fmpi_reduce_i64 = fmpi_reduce_integer i64
+module fmpi_reduce_u8 = fmpi_reduce_integer u8
+module fmpi_reduce_u16 = fmpi_reduce_integer u16
+module fmpi_reduce_u32 = fmpi_reduce_integer u32
+module fmpi_reduce_u64 = fmpi_reduce_integer u64
+module fmpi_reduce_f32 = fmpi_reduce_float f32
+module fmpi_reduce_f64 = fmpi_reduce_float f64
-- 
GitLab