"""Show a 2x2 grid of random dataset images with a trained Pokémon model's predictions.

Usage: python <script> --model {1,2}   (1 = ResNet50, 2 = Xception)

All paths below are resolved relative to the current working directory,
not relative to this file.
"""
import argparse
import json
import os
import random

import tensorflow as tf
import keras
import matplotlib.pyplot as plt
import numpy as np

# --model choice -> (saved model file, class-name JSON, model input size).
MODEL_CONFIGS = {
    "1": (
        "../models/ResNet50/pokedex_ResNet50.h5",
        "../models/ResNet50/class_names.json",
        (224, 224),
    ),
    "2": (
        "../models/Xception/pokedex_Xception.h5",
        "../models/Xception/class_names.json",
        (256, 256),
    ),
}
BASE_PATH = "../Combined_Dataset"  # one sub-folder per class name
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
TOP_K = 5


def parse_args(argv=None):
    """Parse the command line; ``argv=None`` means ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(description="Test trained Pokémon model.")
    parser.add_argument("--model", choices=["1", "2"], required=True,
                        help="1 = ResNet50, 2 = Xception")
    return parser.parse_args(argv)


def load_class_names(json_path):
    """Return the class names ordered by model output index.

    Accepts either a JSON list (position == class index) or a JSON object keyed
    by the stringified index ("0", "1", ...). JSON object keys are always
    strings, so integer lookups on a loaded dict would raise KeyError.

    Raises:
        KeyError: if a dict mapping has a gap in its indices (the names could
            not be aligned with the model outputs).
    """
    with open(json_path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    if isinstance(raw, dict):
        return [raw[str(i)] for i in range(len(raw))]
    return list(raw)


def list_images(class_folder):
    """Return the image file names (by extension, case-insensitive) in a folder."""
    return [f for f in os.listdir(class_folder)
            if f.lower().endswith(IMAGE_EXTENSIONS)]


def pick_random_image(base_path, class_names):
    """Pick a random class that has images, then a random image from it.

    Classes whose folder is missing or contains no images are skipped, so the
    choice is uniform over the usable classes.

    Returns:
        (class_name, image_file_name, image_path)

    Raises:
        FileNotFoundError: if no class folder under ``base_path`` has images.
    """
    candidates = list(class_names)
    random.shuffle(candidates)
    for class_name in candidates:
        folder = os.path.join(base_path, class_name)
        if not os.path.isdir(folder):
            continue
        images = list_images(folder)
        if images:
            img_file = random.choice(images)
            return class_name, img_file, os.path.join(folder, img_file)
    raise FileNotFoundError(f"No images found for any class under {base_path!r}")


def to_probabilities(scores):
    """Turn one sample's model output into a probability vector.

    If the output already is a distribution (non-negative and summing to ~1,
    e.g. the model ends in a softmax layer), it is returned unchanged.
    Otherwise it is treated as logits and normalised with a numerically stable
    softmax. Applying softmax on top of softmax would flatten the distribution
    and understate the confidence.
    """
    scores = np.asarray(scores, dtype=np.float64)
    if np.all(scores >= 0) and np.isclose(scores.sum(), 1.0, atol=1e-3):
        return scores
    exp = np.exp(scores - scores.max())
    return exp / exp.sum()


def predict_image(model, img_path, size):
    """Load one image at ``size`` and return ``(PIL image, probability vector)``.

    NOTE(review): no ``preprocess_input`` is applied; raw 0-255 pixels are fed
    to the model. This assumes the saved model contains its own
    rescaling/preprocessing layers. TODO confirm against the training script.
    """
    img = keras.utils.load_img(img_path, target_size=size)
    img_array = keras.utils.img_to_array(img)
    img_array = tf.expand_dims(img_array, 0)  # add batch axis -> [1, height, width, 3]
    predictions = model.predict(img_array, verbose=0)
    return img, to_probabilities(predictions[0])


def main(argv=None):
    """Load the selected model and plot predictions for 4 random images."""
    args = parse_args(argv)
    h5_path, json_path, size = MODEL_CONFIGS[args.model]

    class_names = load_class_names(json_path)

    # compile=False: only inference is needed, and compiling the saved
    # training config hits the 'reduction=auto' bug.
    model = keras.models.load_model(h5_path, compile=False)

    plt.figure(figsize=(10, 10))
    for i in range(4):
        true_class, img_file, img_path = pick_random_image(BASE_PATH, class_names)
        img, probabilities = predict_image(model, img_path, size)

        if len(probabilities) != len(class_names):
            raise ValueError(
                f"Model has {len(probabilities)} outputs but "
                f"{json_path} lists {len(class_names)} classes"
            )

        predicted_index = int(np.argmax(probabilities))
        predicted_label = class_names[predicted_index]
        confidence = 100 * probabilities[predicted_index]
        is_correct = predicted_label == true_class

        # Show top-k predictions, highest probability first.
        print(f"\n Image: {img_file} | True: {true_class}")
        print("-- Top 5 predictions:")
        for idx in np.argsort(probabilities)[-TOP_K:][::-1]:
            print(f"{class_names[idx]:<20}: {probabilities[idx]*100:.2f}%")

        # Plot
        plt.subplot(2, 2, i + 1)
        plt.imshow(img)
        plt.axis("off")
        plt.title(
            f"Pred: {predicted_label}\n"
            f"True: {true_class}\n"
            f"{'YES' if is_correct else 'NO'} | {confidence:.1f}%",
            fontsize=10,
        )

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()