diff --git a/__main__.py b/__main__.py
index 303a7d9a4a695e5b756463ed7ac9a9b19346ca32..c9680ebfac7ffa1aea42109213bbc791901bde27 100644
--- a/__main__.py
+++ b/__main__.py
@@ -11,8 +11,7 @@ class Node:
         self.children = []
 
 
-# lower_bound
-def array_binary_search(array, value):
+def lower_bound(array, value):
     low = 0
     high = len(array) - 1
 
@@ -28,9 +27,25 @@ def array_binary_search(array, value):
 
     return low
 
+def is_value_in_array(array, value):
+    low = 0
+    high = len(array) - 1
+
+    while low <= high:
+        m = (low + high) // 2
+
+        if array[m] < value:
+            low = m + 1
+        elif array[m] > value:
+            high = m - 1
+        else:
+            return True
+
+    return False
+
 
 def array_insert_sorted(array, value):
-    index = array_binary_search(array, value)
+    index = lower_bound(array, value)
     array.insert(index, value)
     return index
 
@@ -41,7 +56,7 @@ def find_leaf(root, key):
 
     while not current.is_leaf:
         parents.append(current)
-        children_index = array_binary_search(current.keys, key)
+        children_index = lower_bound(current.keys, key)
         current = current.children[children_index]
 
     return parents, current
@@ -66,7 +81,7 @@ def redistribute_keys(left_node, right_node, left_index, right_index):
 
 
 def split_leaf(node, key):
-    virtual_insertion_index = array_binary_search(node.keys, key)
+    virtual_insertion_index = lower_bound(node.keys, key)
     median_index = len(node.keys) // 2
     right_node = Node(node.order)
 
@@ -99,7 +114,7 @@ def redistribute_children(left_node, right_node, left_index, right_index):
 
 
 def split_internal(node, key, right_child_node):
-    virtual_insertion_index = array_binary_search(node.keys, key)
+    virtual_insertion_index = lower_bound(node.keys, key)
     median_index = len(node.keys) // 2
     right_node = Node(node.order)
     right_node.is_leaf = False
@@ -162,6 +177,18 @@ def insert_non_full(node, key, previous_split_right_node):
         node.children.insert(inserted_at_index + 1, previous_split_right_node)
 
 
+def tree_search(root, key):
+    if root.is_leaf:
+        return is_value_in_array(root.keys, key)
+
+    children_index = lower_bound(root.keys, key)
+
+    if children_index < len(root.keys) and root.keys[children_index] == key:
+        children_index += 1
+
+    return tree_search(root.children[children_index], key)
+
+
 def tree_print(root, depth=0):
     print("  " * depth, end="")
     print(root.keys)
@@ -218,6 +245,18 @@ def main():
     extracted_keys = extract_all_keys(root)
     assert extracted_keys == sorted(keys)
 
+    for key in keys:
+        assert tree_search(root, key)
+
+    for _ in range(5):
+        while True:
+            random_key = random.randint(1, 99)
+
+            if random_key not in keys:
+                break
+
+        assert not tree_search(root, random_key)
+
 
 if __name__ == "__main__":
     main()