From 0da6c504ce8ebb84906f952c910902814b3127e8 Mon Sep 17 00:00:00 2001
From: Alexis Durgnat <alexis.durgnat@hesge.ch>
Date: Wed, 3 Nov 2021 11:06:32 +0100
Subject: [PATCH] Cleanup examples, optimize a little

---
 src/ar_sandbox/examples/examples.py | 57 ++++++++++++++++-------------
 1 file changed, 32 insertions(+), 25 deletions(-)

diff --git a/src/ar_sandbox/examples/examples.py b/src/ar_sandbox/examples/examples.py
index 2a762e0..83cb4f9 100644
--- a/src/ar_sandbox/examples/examples.py
+++ b/src/ar_sandbox/examples/examples.py
@@ -1,51 +1,60 @@
+from typing import Any
 import numpy as np
 
 COLORS = [
-    (0, [255, 0, 0]),
-    (0.25, [255, 255, 0]),
-    (0.5, [0, 255, 0]),
-    (0.75, [0, 255, 255]),
-    (1, [0, 0, 255])
+    (0, (255, 0, 0)),
+    (0.25, (255, 255, 0)),
+    (0.5, (0, 255, 0)),
+    (0.75, (0, 255, 255)),
+    (1, (0, 0, 255))
 ]
 
 class LevelDisplay():
 
-    @staticmethod
-    def get_color(depth_matrix, frame, draw_lines=True, colormap=COLORS):
+    def __init__(self, colormap:list = COLORS) -> None:
         """
-        Given a depth matrix between 0-1, return a color from a given colormap.
+        Inialize the level display with a given colormap.
         The colormap is a list of tuples containing the depth value, and the
         color for this level :
-        (depth_value, [RedVal, GreenVal, BlueVal])
+        (depth_value, (RedVal, GreenVal, BlueVal))
 
         The colormap should at least contain a color for the depths 0 and 1.
         See COLORS for an example colormap.
-        
+
+        Arguments:
+            colormap : A list of tuple representing a mapping between a float value and
+            a triple of Red, Green and Blue values.
+        """
+        self.colormap = colormap
+        self.points = np.array([c[0] for c in colormap])
+        self.r = [c[1][0] for c in colormap]
+        self.g = [c[1][1] for c in colormap]
+        self.b = [c[1][2] for c in colormap]
+
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        return self.get_color(*args, **kwds)
+
+    def get_color(self, depth_matrix: np.ndarray, frame: np.ndarray, draw_lines: bool=True) -> np.ndarray:
+        """
+        Given a depth matrix between 0-1, return a color from the colormap
+
         Arguments:
             depth_matrix : Normalized 1 channel numpy matrix
             frame: Unused. The frame captured by the camera
             draw_line: Should line be drawn between levels ?
-            colormap: Override the default colormap.
         Return:
             A 3 channel matrix in BGR frame of the same size as depth_matrix.
         """
-        points = np.array([c[0] for c in colormap])
-        r = [c[1][2] for c in colormap]
-        g = [c[1][1] for c in colormap]
-        b = [c[1][0] for c in colormap]
-        
-        # print(rval)
-        if draw_lines: line_mask = LevelDisplay.draw_lines(depth_matrix, points)
-        # print(line_mask.shape)
-        rval = np.interp(depth_matrix, points, r)
-        gval = np.interp(depth_matrix, points, g)
-        bval = np.interp(depth_matrix, points, b)
+        if draw_lines: line_mask = LevelDisplay.draw_lines(depth_matrix, self.points)
+        rval = np.interp(depth_matrix, self.points, self.r)
+        gval = np.interp(depth_matrix, self.points, self.g)
+        bval = np.interp(depth_matrix, self.points, self.b)
         res = np.dstack((bval, gval, rval)).astype(np.uint8)
         if draw_lines: res[line_mask,:] = 0 
         return res
 
     @staticmethod
-    def draw_lines(depth, points, width=0.025, between_levels=False):
+    def draw_lines(depth:np.ndarray, points:list, width: float=0.025, between_levels: bool=False) -> np.ndarray:
         """
         Given the depth matrix and a list of points, return a mask for every 
         value near the points. If between_levels is set to True, mask in 
@@ -73,6 +82,4 @@ class LevelDisplay():
         halfwidth = width/2
         for line in lines:
             mask = np.where(np.logical_or(mask, np.logical_and(depth >= line - halfwidth, depth <= line+halfwidth)), True, False)
-            # rpos_mline = np.interp([rpos_mline])
-        # print(mask)
         return mask
\ No newline at end of file
-- 
GitLab