Skip to content
Snippets Groups Projects
Commit 60d03928 authored by michael.divia's avatar michael.divia
Browse files

Input image test

parent cfc85998
Branches
No related tags found
No related merge requests found
......@@ -178,4 +178,7 @@ pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python
Combined_Dataset
\ No newline at end of file
Combined_Dataset
SSD_Dataset
YOLOv8_Data
pokedex-env
\ No newline at end of file
......@@ -10,6 +10,7 @@ import argparse
# --- Parse CLI arguments ---
parser = argparse.ArgumentParser(description="Test trained Pokémon model.")
parser.add_argument("--model", choices=["1", "2"], required=True, help="1 = ResNet50, 2 = Xception")
parser.add_argument("--image", type=str, help="Optional image path for single inference")
args = parser.parse_args()
# --- Paths ---
......@@ -29,51 +30,80 @@ with open(json_path, "r") as f:
class_names = json.load(f)
class_names = [class_names[i] for i in range(len(class_names))]
# --- Load model (NO COMPILE to avoid 'reduction=auto' bug) ---
# --- Load model ---
model = keras.models.load_model(h5_path, compile=False)
# --- 2x2 Random Image Test ---
plt.figure(figsize=(10, 10))
for i in range(4):
# Pick random Pokémon class & image
true_class = random.choice(class_names)
class_folder = os.path.join(base_path, true_class)
img_file = random.choice([
f for f in os.listdir(class_folder)
if f.lower().endswith(('.png', '.jpg', '.jpeg'))
])
img_path = os.path.join(class_folder, img_file)
# Load and preprocess image
img = keras.utils.load_img(img_path, target_size=size)
def preprocess_image(image_path, target_size):
img = keras.utils.load_img(image_path, target_size=target_size)
img_array = keras.utils.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # [1, height, width, 3]
return img_array, img
# Predict
def predict_and_display(img_array, img, true_label=None):
predictions = model.predict(img_array, verbose=0)
probabilities = tf.nn.softmax(predictions[0]).numpy()
predicted_index = np.argmax(probabilities)
predicted_label = class_names[predicted_index]
confidence = 100 * probabilities[predicted_index]
is_correct = predicted_label == true_class
# Show top 5
print(f"\n Image: {img_file} | True: {true_class}")
print("-- Top 5 predictions:")
print("\n-- Top 5 predictions:")
for idx in np.argsort(probabilities)[-5:][::-1]:
print(f"{class_names[idx]:<20}: {probabilities[idx]*100:.2f}%")
# Plot
ax = plt.subplot(2, 2, i + 1)
# Display image
plt.imshow(img)
plt.axis("off")
plt.title(
f"Pred: {predicted_label}\n"
f"True: {true_class}\n"
f"{'YES' if is_correct else 'NO'} | {confidence:.1f}%",
fontsize=10
)
plt.tight_layout()
plt.show()
title = f"Pred: {predicted_label} | {confidence:.1f}%"
if true_label:
correctness = "YES" if predicted_label == true_label else "NO"
title += f"\nTrue: {true_label} | {correctness}"
plt.title(title, fontsize=12)
plt.show()
# --- Single image prediction ---
if args.image:
if not os.path.isfile(args.image):
print(f"Error: Image '{args.image}' not found.")
exit(1)
img_array, img = preprocess_image(args.image, size)
predict_and_display(img_array, img)
# --- Else: Random 2x2 Test Grid ---
else:
plt.figure(figsize=(10, 10))
for i in range(4):
true_class = random.choice(class_names)
class_folder = os.path.join(base_path, true_class)
img_file = random.choice([
f for f in os.listdir(class_folder)
if f.lower().endswith(('.png', '.jpg', '.jpeg'))
])
img_path = os.path.join(class_folder, img_file)
img_array, img = preprocess_image(img_path, size)
predictions = model.predict(img_array, verbose=0)
probabilities = tf.nn.softmax(predictions[0]).numpy()
predicted_index = np.argmax(probabilities)
predicted_label = class_names[predicted_index]
confidence = 100 * probabilities[predicted_index]
is_correct = predicted_label == true_class
print(f"\nImage: {img_file} | True: {true_class}")
print("-- Top 5 predictions:")
for idx in np.argsort(probabilities)[-5:][::-1]:
print(f"{class_names[idx]:<20}: {probabilities[idx]*100:.2f}%")
ax = plt.subplot(2, 2, i + 1)
plt.imshow(img)
plt.axis("off")
plt.title(
f"Pred: {predicted_label}\n"
f"True: {true_class}\n"
f"{'YES' if is_correct else 'NO'} | {confidence:.1f}%",
fontsize=10
)
plt.tight_layout()
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment