import tensorflow as tf
import keras
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import json
import argparse

# --- Parse CLI arguments ---
parser = argparse.ArgumentParser(description="Test trained Pokémon model.")
parser.add_argument("--model", choices=["1", "2"], required=True, help="1 = ResNet50, 2 = Xception")
parser.add_argument(
    "--seed",
    type=int,
    default=None,
    help="Seed for the random image selection (omit for a different sample each run).",
)
args = parser.parse_args()

# Seed only when explicitly requested, so the default behaviour (a fresh random
# sample on every run) is unchanged while a given sample can be reproduced.
if args.seed is not None:
    random.seed(args.seed)

# --- Paths ---
# Per-architecture artifacts, keyed by the --model CLI choice:
# (saved weights file, class-name list file, input image size fed to the model).
# NOTE(review): these paths are resolved against the current working directory,
# not this script's location, so the script presumably must be run from its own folder.
_MODEL_CONFIGS = {
    "1": (
        "../models/ResNet50/pokedex_ResNet50.h5",
        "../models/ResNet50/class_names.json",
        (224, 224),
    ),
    "2": (
        "../models/Xception/pokedex_Xception.h5",
        "../models/Xception/class_names.json",
        (256, 256),
    ),
}
h5_path, json_path, size = _MODEL_CONFIGS[args.model]

base_path = "../Combined_Dataset"

# --- Load class names ---
# The file may contain either a JSON list (position == model output index) or a
# JSON object keyed by the stringified index ({"0": "Abra", "1": ...}). JSON
# object keys are always strings, so indexing such an object with ints would
# raise KeyError; look the keys up as strings instead.
# UTF-8 is explicit because class names can contain non-ASCII characters and
# the platform default encoding is not guaranteed to be UTF-8.
with open(json_path, "r", encoding="utf-8") as f:
    raw_class_names = json.load(f)

if isinstance(raw_class_names, dict):
    # A missing index raises KeyError on purpose: a gap would misalign labels
    # with the model's output positions.
    class_names = [raw_class_names[str(i)] for i in range(len(raw_class_names))]
else:
    class_names = list(raw_class_names)

# --- Load model ---
# Loaded uncompiled (the original author notes this avoids a 'reduction=auto'
# bug when restoring the saved loss); compiling is not needed to call predict().
model = keras.models.load_model(filepath=h5_path, compile=False)

# --- 2x2 Random Image Test ---
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
GRID_ROWS, GRID_COLS = 2, 2


def _pick_random_image(dataset_dir, classes):
    """Return ``(class_name, file_name, path)`` for a randomly chosen image.

    A class is drawn uniformly from *classes*, then one image file (matched by
    extension, case-insensitively) uniformly from ``dataset_dir/<class_name>``.

    Raises:
        FileNotFoundError: if the class folder is missing or holds no image
            files (instead of the bare IndexError ``random.choice([])`` gives).
    """
    class_name = random.choice(classes)
    class_dir = os.path.join(dataset_dir, class_name)
    candidates = [
        name for name in os.listdir(class_dir)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No {'/'.join(IMAGE_EXTENSIONS)} images found in {class_dir!r}"
        )
    file_name = random.choice(candidates)
    return class_name, file_name, os.path.join(class_dir, file_name)


def _to_probabilities(raw_output):
    """Return a probability vector for a single sample's model output.

    Outputs that already form a probability distribution (all non-negative,
    summing to ~1) are returned unchanged; anything else is treated as logits
    and passed through softmax.
    """
    raw_output = np.asarray(raw_output)
    # NOTE(review): the model's final activation is not visible here. If the
    # last layer already applies softmax, a second softmax would flatten the
    # distribution and understate confidence, so it is skipped in that case.
    if np.all(raw_output >= 0) and np.isclose(raw_output.sum(), 1.0, atol=1e-3):
        return raw_output
    return tf.nn.softmax(raw_output).numpy()


fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(10, 10))

# axes.flat walks the grid row by row, matching subplot positions 1..N.
for ax in axes.flat:
    # Pick random Pokémon class & image
    true_class, img_file, img_path = _pick_random_image(base_path, class_names)

    # Load and preprocess image
    # NOTE(review): only resizing happens here; ResNet50/Xception usually expect
    # their keras.applications preprocess_input scaling. Confirm the saved model
    # includes that preprocessing as a layer.
    img = keras.utils.load_img(img_path, target_size=size)
    img_array = keras.utils.img_to_array(img)
    img_array = tf.expand_dims(img_array, 0)  # [1, height, width, 3]

    # Predict
    predictions = model.predict(img_array, verbose=0)
    probabilities = _to_probabilities(predictions[0])
    predicted_index = int(np.argmax(probabilities))
    predicted_label = class_names[predicted_index]
    confidence = 100 * probabilities[predicted_index]
    is_correct = predicted_label == true_class

    # Show top 5
    print(f"\n Image: {img_file} | True: {true_class}")
    print("-- Top 5 predictions:")
    for idx in np.argsort(probabilities)[-5:][::-1]:
        print(f"{class_names[idx]:<20}: {probabilities[idx]*100:.2f}%")

    # Plot
    ax.imshow(img)
    ax.axis("off")
    ax.set_title(
        f"Pred: {predicted_label}\n"
        f"True: {true_class}\n"
        f"{'YES' if is_correct else 'NO'} | {confidence:.1f}%",
        fontsize=10,
    )

plt.tight_layout()
plt.show()