diff --git a/__main__.py b/__main__.py
index 1f9ccd1b9a58c484d58630eb40787d56b33e663e..9d5cd69ab2d0b296859ff4942575f24ee2def387 100644
--- a/__main__.py
+++ b/__main__.py
@@ -37,18 +37,23 @@ def array_insert_sorted(array, value):
     return index
 
 
-def insert(root, key):
-    leaf_node = root
+def find_leaf(root, key):
     parents = []
-    while not leaf_node.is_leaf:
-        parents.append(leaf_node)
-        index = array_binary_search(leaf_node.keys, key)
-        leaf_node = leaf_node.children[index]
+    current = root
+    while not current.is_leaf:
+        parents.append(current)
+        children_index = array_binary_search(current.keys, key)
+        current = current.children[children_index]
+    return parents, current
+
+
+def insert(root, key):
+    parents, leaf = find_leaf(root, key)
 
-    if node_is_full(leaf_node):
-        insert_full(root, parents, leaf_node, key, None)
+    if node_is_full(leaf):
+        insert_full(root, parents, leaf, key, None)
     else:
-        insert_non_full(leaf_node, key, None)
+        insert_non_full(leaf, key, None)
 
 
 def insert_non_full(node, key, right_child_node):
@@ -156,8 +161,6 @@ def increase_height(root, key, right_child_node):
 
 
 def tree_print(root, depth=0):
-    if len(root.keys) == 0:
-        return
     print("  " * depth, end="")
     print(root.keys)