diff --git a/main.py b/main.py
index 6c8587e98b25c13552b25af6472b89953bb220f9..27b7e6588b37334e6f2301a1ac811bc38b53a564 100644
--- a/main.py
+++ b/main.py
@@ -12,6 +12,15 @@ Those methods are required if the Python version is lower than 3.8.
 With Python 3.8+ we can use the built-in 'pow' function to calculate the modular inverse, where y = pow(x, -1, p).
 """
 
+NIBBLE_BIT_SIZE = 16
+MODULO_FOR_ADD = 2**4  # = 16
+MODULO_FOR_MUL = MODULO_FOR_ADD + 1  # = 2^4 + 1 = 17
+NIBBLE_STRING_SIZE = 4
+BINARY_BASE = 2
+STRING_BINARY_FORMAT = '04b'
+NB_SUBKEYS_BY_ROUND = 6
+NB_SUBKEYS_FINAL_ROUND = 4
+
 
 # Return a mod b
 def modulo(a, b):
@@ -30,38 +39,38 @@ def modulo(a, b):
 # Return (a+b) mod 16
 def add_mod(a, b):
     if type(a) == str:
-        a = int(a, 2)
+        a = int(a, BINARY_BASE)
 
     if type(b) == str:
-        b = int(b, 2)
+        b = int(b, BINARY_BASE)
 
     tmp = a + b
-    res = modulo(tmp, 2**4)
-    return format(res, '04b')
+    res = modulo(tmp, MODULO_FOR_ADD)
+    return format(res, STRING_BINARY_FORMAT)
 
 
 # Return (a*b) mod 17
 def mul_mod(a, b):
     if type(a) == str:
-        a = int(a, 2)
+        a = int(a, BINARY_BASE)
 
     if type(b) == str:
-        b = int(b, 2)
+        b = int(b, BINARY_BASE)
 
     # If we have a nibble of 0000, then this corresponds to 16 (base 10) that is congruent to -1 modulo 17
     if a == 0:
-        a = 16
+        a = NIBBLE_BIT_SIZE
     if b == 0:
-        b = 16
+        b = NIBBLE_BIT_SIZE
 
     tmp = a * b
-    res = modulo(tmp, 2**4 + 1)
+    res = modulo(tmp, MODULO_FOR_MUL)
     # Because we work with nibbles and  2^4 + 1 = 17, if we obtain 16 as result this will correspond to 10000 in binary.
     # This value can't be stored in a nibble, so we must exclude it
-    if res == 2**4:
+    if res == NIBBLE_BIT_SIZE:
         res = 0
 
-    return format(res, '04b')
+    return format(res, STRING_BINARY_FORMAT)
 
 
 # Return a string of the result of a XOR b
@@ -70,10 +79,10 @@ def xor(a, b):
 
     # Format to binary string
     if type(a) != str:
-        a = format(a, "04b")
+        a = format(a, STRING_BINARY_FORMAT)
 
     if type(b) != str:
-        b = format(b, "04b")
+        b = format(b, STRING_BINARY_FORMAT)
 
     # Adjust the size of a and b
     if len(a) != len(b):
@@ -144,90 +153,92 @@ def modular_inverse(a, n):
         int: The reversed number.
     """
     if type(a) == str:
-        a = int(a, 2)
+        a = int(a, BINARY_BASE)
 
     coefficients = get_bezout_coefficients(a, n)
 
     if a * coefficients[0] % n == 1:
-        return format(coefficients[0] % n, '04b')
+        return format(coefficients[0] % n, STRING_BINARY_FORMAT)
     return None
 
 
 def inverse_add_mod(a):
     if type(a) == str:
-        a = int(a, 2)
+        a = int(a, BINARY_BASE)
 
-    res = 16 - a
+    res = NIBBLE_BIT_SIZE - a
 
-    return format(res, '04b')
+    return format(res, STRING_BINARY_FORMAT)
 
 
-# Shift the 6 first bit of the key to the end
-def shift_key(key):
-    bit_to_shift = key[0:6]
-    new_key = key[6:] + bit_to_shift
+# Shift the x bits starting from the left of the key to the end
+def shift_key(key, x=6):
+    bits_to_shift = key[0:x]
+    new_key = key[x:] + bits_to_shift
     return new_key
 
 
 # Create the table of the 28 subkeys used for the IDEA encryption
 def create_subkeys_table(key):
     subkeys = []
+    subkeys_round = []
     nibble = ''
 
-    while len(subkeys) < 28:
+    current_round = 1
+    total_rounds = 4
+
+    while current_round <= total_rounds:
         # Decompose the key into 8 nibbles
         for bit in key:
             nibble += bit
-            if len(nibble) >= 4 and len(subkeys) < 28:
-                subkeys.append(nibble)
+            if len(nibble) >= NIBBLE_STRING_SIZE:
+                subkeys_round.append(nibble)
                 nibble = ''
 
+                if len(subkeys_round) == NB_SUBKEYS_BY_ROUND:
+                    subkeys.append(subkeys_round)
+                    subkeys_round = []
+
+                # The final round contains less subkeys
+                if current_round == total_rounds and len(subkeys_round) == NB_SUBKEYS_FINAL_ROUND:
+                    subkeys.append(subkeys_round)
+                    # We must break the loop when we have all subkeys
+                    break
+
         # Shift the key for the next round
         key = shift_key(key)
+        current_round += 1
 
     return subkeys
 
 
-# Modify the subkeys table into a 2d array where each array correspond to a round
-def create_subkeys_table_with_rounds(subkeys_table):
-    subkeys_table_with_rounds = []
-    round_keys = []
-
-    for i, key in enumerate(subkeys_table):
-        round_keys.append(key)
-        if modulo(i + 1, 6) == 0:
-            subkeys_table_with_rounds.append(round_keys)
-            round_keys = []
-
-    # The final round have fewer keys than the other rounds, so we have to add it afterward
-    subkeys_table_with_rounds.append(round_keys)
-
-    return subkeys_table_with_rounds
-
-
 def create_decryption_subkeys_table(subkeys_table):
     decryption_subkeys = []
-    # Modify the subkeys table into a 2d array where each array correspond to a round
-    subkeys_table = create_subkeys_table_with_rounds(subkeys_table)
 
     remaining_round = 4
+    # The first decryption subkeys is created by the first encryption key of the last round
+    # So we will loop in reversed
     while remaining_round >= 0:
-        k1 = modular_inverse(subkeys_table[remaining_round][0], 17)
+        subkeys_round = []
+        k1 = modular_inverse(subkeys_table[remaining_round][0], MODULO_FOR_MUL)
         k2 = inverse_add_mod(subkeys_table[remaining_round][1])
         k3 = inverse_add_mod(subkeys_table[remaining_round][2])
-        k4 = modular_inverse(subkeys_table[remaining_round][3], 17)
+        k4 = modular_inverse(subkeys_table[remaining_round][3], MODULO_FOR_MUL)
 
-        decryption_subkeys.append(k1)
-        decryption_subkeys.append(k2)
-        decryption_subkeys.append(k3)
-        decryption_subkeys.append(k4)
+        subkeys_round.append(k1)
+        subkeys_round.append(k2)
+        subkeys_round.append(k3)
+        subkeys_round.append(k4)
 
         if remaining_round > 0:
+            # The last two decryption subkeys of a round correspond to the last
+            # two encryption subkeys of the previous round
             k5 = subkeys_table[remaining_round - 1][4]
             k6 = subkeys_table[remaining_round - 1][5]
-            decryption_subkeys.append(k5)
-            decryption_subkeys.append(k6)
+            subkeys_round.append(k5)
+            subkeys_round.append(k6)
 
+        decryption_subkeys.append(subkeys_round)
         remaining_round -= 1
 
     return decryption_subkeys
@@ -242,13 +253,13 @@ def encrypt(plaintext, subkeys):
 
     while current_round <= total_rounds:
         # 1. Multiply X1 and the first subkey Z1
-        res_1 = mul_mod(input_block[0:4], subkeys[0])
+        res_1 = mul_mod(input_block[0:4], subkeys[current_round][0])
         # 2. Add X2 and the second subkey Z2
-        res_2 = add_mod(input_block[4:8], subkeys[1])
+        res_2 = add_mod(input_block[4:8], subkeys[current_round][1])
         # 3. Add X3 and the third subkey Z3
-        res_3 = add_mod(input_block[8:12], subkeys[2])
+        res_3 = add_mod(input_block[8:12], subkeys[current_round][2])
         # 4. Multiply X4 and the fourth subkey Z4
-        res_4 = mul_mod(input_block[12:16], subkeys[3])
+        res_4 = mul_mod(input_block[12:16], subkeys[current_round][3])
 
         # Don't do the next steps if this is the final round
         if current_round < total_rounds:
@@ -257,11 +268,11 @@ def encrypt(plaintext, subkeys):
             # 6. Bitwise XOR the results of steps 2 and 4
             res_6 = xor(res_2, res_4)
             # 7. Multiply the result of step 5 and the fifth subkey Z5
-            res_7 = mul_mod(res_5, subkeys[4])
+            res_7 = mul_mod(res_5, subkeys[current_round][4])
             # 8. Add the results of step 6 and 7
             res_8 = add_mod(res_6, res_7)
             # 9. Multiply the result of step 8 and the sixth subkey Z6
-            res_9 = mul_mod(res_8, subkeys[5])
+            res_9 = mul_mod(res_8, subkeys[current_round][5])
             # 10. Add the results of steps 7 and 9
             res_10 = add_mod(res_7, res_9)
             # 11. Bitwise XOR the results of steps 1 and 9
@@ -277,7 +288,6 @@ def encrypt(plaintext, subkeys):
             input_block = res_11 + res_13 + res_12 + res_14
 
         current_round += 1
-        subkeys = subkeys[6:]
 
     ciphertext = res_1 + res_2 + res_3 + res_4
     return ciphertext
@@ -285,6 +295,7 @@ def encrypt(plaintext, subkeys):
 
 # Decrypt the ciphertext with the subkeys table by using the simplified IDEA algorithm
 def decrypt(ciphertext, subkeys_table):
+    # To decrypt the ciphertext, we just have to encrypt it by using the decryption subkeys table
     decryption_keys = create_decryption_subkeys_table(subkeys_table)
     return encrypt(ciphertext, decryption_keys)