Skip to content
Snippets Groups Projects
Verified Commit 4513c80d authored by raphael.bach's avatar raphael.bach
Browse files

Add more reduce operations in `fmpi_reduce.fut`

parent 8ef4248c
No related branches found
No related tags found
No related merge requests found
-- Default operations for MPI_Reduce() -- Default operations for MPI_Reduce()
-- See https://www.open-mpi.org/doc/current/man3/MPI_Reduce.3.php#sect10 -- See https://www.open-mpi.org/doc/current/man3/MPI_Reduce.3.php#sect10
local module type fmpi_reduce_t = { local module type fmpi_reduce_bool_t = {
type t type t
val sum [n]: [n]t -> t val sum [n]: [n]t -> i64
val prod [n]: [n]t -> t val and [n]: [n]t -> t
val or [n]: [n]t -> t
val xor [n]: [n]t -> t
}
local module type fmpi_reduce_integer_t = {
type t
val sum [n]: [n]t -> i64
val prod [n]: [n]t -> i64
val min [n]: [n]t -> t
val max [n]: [n]t -> t
val and [n]: [n]t -> t
val or [n]: [n]t -> t
val xor [n]: [n]t -> t
val & [n]: [n]t -> t
val | [n]: [n]t -> t
val ^ [n]: [n]t -> t
}
local module type fmpi_reduce_float_t = {
type t
val sum [n]: [n]t -> f64
val prod [n]: [n]t -> f64
val min [n]: [n]t -> t val min [n]: [n]t -> t
val max [n]: [n]t -> t val max [n]: [n]t -> t
val and [n]: [n]t -> t
val or [n]: [n]t -> t
val xor [n]: [n]t -> t
} }
module fmpi_reduce (T: float): (fmpi_reduce_t with t = T.t) = { module fmpi_reduce_bool: (fmpi_reduce_bool_t with t = bool) = {
type t = bool
def sum [n] (xs: [n]t): i64 = reduce (+) 0 (map (\x -> i64.bool x) xs)
def and [n] (xs: [n]t): t = reduce (&&) true xs
def or [n] (xs: [n]t): t = reduce (||) false xs
def xor [n] (xs: [n]t): t = reduce (!=) false xs
}
module fmpi_reduce_integer (T: integral): (fmpi_reduce_integer_t with t = T.t) = {
type t = T.t
local def and_impl (a: t) (b: t): t = T.bool (a T.!= (T.i32 0) && b T.!= (T.i32 0))
local def or_impl (a: t) (b: t): t = T.bool (a T.!= (T.i32 0) || b T.!= (T.i32 0))
local def xor_impl (a: t) (b: t): t = T.bool (a T.!= b)
def sum [n] (xs: [n]t): i64 = T.to_i64 (T.sum xs)
def prod [n] (xs: [n]t): i64 = T.to_i64 (T.product xs)
def min [n] (xs: [n]t): t = T.minimum xs
def max [n] (xs: [n]t): t = T.maximum xs
def and [n] (xs: [n]t): t = reduce (and_impl) (T.i32 1) xs
def or [n] (xs: [n]t): t = reduce (or_impl) (T.i32 0) xs
def xor [n] (xs: [n]t): t = reduce (xor_impl) (T.i32 0) xs
def (&) [n] (xs: [n]t): t = reduce (T.&) (T.i32 1) xs
def (|) [n] (xs: [n]t): t = reduce (T.|) (T.i32 0) xs
def (^) [n] (xs: [n]t): t = reduce (T.^) (T.i32 0) xs
}
module fmpi_reduce_float (T: float): (fmpi_reduce_float_t with t = T.t) = {
type t = T.t type t = T.t
def sum [n] (xs: [n]t): t = T.sum xs local def and_impl (a: t) (b: t): t = T.bool (a T.!= (T.i32 0) && b T.!= (T.i32 0))
def prod [n] (xs: [n]t): t = T.product xs local def or_impl (a: t) (b: t): t = T.bool (a T.!= (T.i32 0) || b T.!= (T.i32 0))
local def xor_impl (a: t) (b: t): t = T.bool (a T.!= b)
def sum [n] (xs: [n]t): f64 = T.to_f64 (T.sum xs)
def prod [n] (xs: [n]t): f64 = T.to_f64 (T.product xs)
def min [n] (xs: [n]t): t = T.minimum xs def min [n] (xs: [n]t): t = T.minimum xs
def max [n] (xs: [n]t): t = T.maximum xs def max [n] (xs: [n]t): t = T.maximum xs
def and [n] (xs: [n]t): t = reduce (and_impl) (T.i32 1) xs
def or [n] (xs: [n]t): t = reduce (or_impl) (T.i32 0) xs
def xor [n] (xs: [n]t): t = reduce (xor_impl) (T.i32 0) xs
} }
module fmpi_reduce_i8 = fmpi_reduce_integer i8
module fmpi_reduce_i16 = fmpi_reduce_integer i16
module fmpi_reduce_i32 = fmpi_reduce_integer i32
module fmpi_reduce_i64 = fmpi_reduce_integer i64
module fmpi_reduce_u8 = fmpi_reduce_integer u8
module fmpi_reduce_u16 = fmpi_reduce_integer u16
module fmpi_reduce_u32 = fmpi_reduce_integer u32
module fmpi_reduce_u64 = fmpi_reduce_integer u64
module fmpi_reduce_f32 = fmpi_reduce_float f32
module fmpi_reduce_f64 = fmpi_reduce_float f64
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment