diff --git a/__main__.py b/__main__.py index 303a7d9a4a695e5b756463ed7ac9a9b19346ca32..c9680ebfac7ffa1aea42109213bbc791901bde27 100644 --- a/__main__.py +++ b/__main__.py @@ -11,8 +11,7 @@ class Node: self.children = [] -# lower_bound -def array_binary_search(array, value): +def lower_bound(array, value): low = 0 high = len(array) - 1 @@ -28,9 +27,25 @@ def array_binary_search(array, value): return low +def is_value_in_array(array, value): + low = 0 + high = len(array) - 1 + + while low <= high: + m = (low + high) // 2 + + if array[m] < value: + low = m + 1 + elif array[m] > value: + high = m - 1 + else: + return True + + return False + def array_insert_sorted(array, value): - index = array_binary_search(array, value) + index = lower_bound(array, value) array.insert(index, value) return index @@ -41,7 +56,7 @@ def find_leaf(root, key): while not current.is_leaf: parents.append(current) - children_index = array_binary_search(current.keys, key) + children_index = lower_bound(current.keys, key) current = current.children[children_index] return parents, current @@ -66,7 +81,7 @@ def redistribute_keys(left_node, right_node, left_index, right_index): def split_leaf(node, key): - virtual_insertion_index = array_binary_search(node.keys, key) + virtual_insertion_index = lower_bound(node.keys, key) median_index = len(node.keys) // 2 right_node = Node(node.order) @@ -99,7 +114,7 @@ def redistribute_children(left_node, right_node, left_index, right_index): def split_internal(node, key, right_child_node): - virtual_insertion_index = array_binary_search(node.keys, key) + virtual_insertion_index = lower_bound(node.keys, key) median_index = len(node.keys) // 2 right_node = Node(node.order) right_node.is_leaf = False @@ -162,6 +177,18 @@ def insert_non_full(node, key, previous_split_right_node): node.children.insert(inserted_at_index + 1, previous_split_right_node) +def tree_search(root, key): + if root.is_leaf: + return is_value_in_array(root.keys, key) + + children_index = lower_bound(root.keys, key) + + if children_index < len(root.keys) and root.keys[children_index] == key: + children_index += 1 + + return tree_search(root.children[children_index], key) + + def tree_print(root, depth=0): print(" " * depth, end="") print(root.keys) @@ -218,6 +245,18 @@ def main(): extracted_keys = extract_all_keys(root) assert extracted_keys == sorted(keys) + for key in keys: + assert tree_search(root, key) + + for _ in range(5): + while True: + random_key = random.randint(1, 99) + + if random_key not in keys: + break + + assert not tree_search(root, random_key) + if __name__ == "__main__": main()