diff --git a/Python/new_pokedex_resnet50.py b/Python/new_pokedex_resnet50.py
deleted file mode 100644
index bd105bb1585a3a721bcfb1a677107bf61899ceb1..0000000000000000000000000000000000000000
--- a/Python/new_pokedex_resnet50.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import tensorflow as tf
-from tensorflow.keras.applications import ResNet50
-from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
-from tensorflow.keras.models import Model
-from tensorflow.keras.preprocessing import image_dataset_from_directory
-import json
-import os
-
-# --- GPU Strategy ---
-strategy = tf.distribute.MirroredStrategy()
-print("Number of GPUs:", strategy.num_replicas_in_sync)
-
-# --- Paths ---
-dataset_path = "/home/users/d/divia/scratch/Combined_Dataset"
-model_output_path = "/home/users/d/divia/ResNet50"
-os.makedirs(model_output_path, exist_ok=True)
-
-# --- Image settings ---
-img_size = (224, 224)
-batch_size = 32
-
-# --- Load datasets ---
-raw_train_ds = image_dataset_from_directory(
-    dataset_path,
-    image_size=img_size,
-    batch_size=batch_size,
-    validation_split=0.2,
-    subset="training",
-    seed=123,
-)
-
-raw_val_ds = image_dataset_from_directory(
-    dataset_path,
-    image_size=img_size,
-    batch_size=batch_size,
-    validation_split=0.2,
-    subset="validation",
-    seed=123,
-)
-
-# Save class names
-class_names = raw_train_ds.class_names
-with open(os.path.join(model_output_path, "class_names.json"), "w") as f:
-    json.dump(class_names, f)
-print(f"Detected {len(class_names)} Pokémon classes.")
-
-# --- Performance improvements ---
-AUTOTUNE = tf.data.AUTOTUNE
-train_ds = raw_train_ds.prefetch(buffer_size=AUTOTUNE)
-val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE)
-
-# --- Build the model ---
-with strategy.scope():
-    base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
-    x = base_model.output
-    x = GlobalAveragePooling2D()(x)
-    outputs = Dense(len(class_names), activation="softmax")(x)
-
-    model = Model(inputs=base_model.input, outputs=outputs)
-
-    # Freeze base model
-    for layer in base_model.layers:
-        layer.trainable = False
-
-    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
-
-# --- Train the model ---
-model.fit(train_ds, validation_data=val_ds, epochs=10)
-
-# --- Optional Fine-tuning (Unfreeze top layers) ---
-# Uncomment to fine-tune top layers of ResNet50
-for layer in base_model.layers[-30:]:
-    layer.trainable = True
-model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
-model.fit(train_ds, validation_data=val_ds, epochs=5)
-
-# --- Save the model ---
-model_h5_path = os.path.join(model_output_path, "pokemon_resnet50.h5")
-model.save(model_h5_path)
-print(f"Model saved to {model_h5_path}")
-
-# --- Save as TensorFlow SavedModel (for ONNX export) ---
-saved_model_path = os.path.join(model_output_path, "saved_model")
-tf.saved_model.save(model, saved_model_path)
-print(f"SavedModel exported to {saved_model_path}")
diff --git a/Python/pokedex_ResNet50.py b/Python/pokedex_ResNet50.py
index 29b498d374c154b1cf7ef6d06f5a58fe1150378f..59886992752f6d21bb07db0e88a3b6bfcb647599 100644
--- a/Python/pokedex_ResNet50.py
+++ b/Python/pokedex_ResNet50.py
@@ -1,93 +1,116 @@
-import keras
 import tensorflow as tf
-from keras import layers
-from tensorflow import data as tf_data
+from tensorflow.keras.applications import ResNet50
+from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
+from tensorflow.keras.models import Model
+from tensorflow.keras.preprocessing import image_dataset_from_directory
+from tensorflow.keras import layers
+from tensorflow.keras.callbacks import EarlyStopping
+from sklearn.utils.class_weight import compute_class_weight
+import numpy as np
+import json
+import os
 
 # --- GPU Strategy ---
 strategy = tf.distribute.MirroredStrategy()
 print("Number of GPUs:", strategy.num_replicas_in_sync)
 
-# --- Parameters ---
-data_dir = "/home/users/d/divia/scratch/Combined_Dataset"
-image_size = (224, 224)
-num_classes = 151
-base_batch_size = 32
-base_lr = 1e-3
-
-global_batch_size = 32
-scaled_lr = min(base_lr * (global_batch_size / base_batch_size), 1e-3)
-
-# --- Load Dataset ---
-full_ds = keras.utils.image_dataset_from_directory(
-    data_dir,
-    labels="inferred",
-    label_mode="int",
-    image_size=image_size,
-    batch_size=global_batch_size,
-    shuffle=True,
-    seed=1234
-)
+# --- Paths ---
+dataset_path = "/home/padi/Git/pokedex/Combined_Dataset"
+model_output_path = "/home/padi/Git/pokedex/ResNet50"
+os.makedirs(model_output_path, exist_ok=True)
 
-# --- Train/Val Split ---
-total_batches = tf.data.experimental.cardinality(full_ds).numpy()
-train_size = int(0.8 * total_batches)
-train_ds = full_ds.take(train_size)
-val_ds = full_ds.skip(train_size)
+# --- Image settings ---
+img_size = (224, 224)
+batch_size = 32
 
 # --- Data Augmentation ---
-data_augmentation_layers = keras.Sequential([
+data_augmentation = tf.keras.Sequential([
     layers.RandomFlip("horizontal"),
     layers.RandomRotation(0.1),
+    layers.RandomZoom(0.1),
+    layers.RandomContrast(0.1),
 ])
 
-def preprocess_train(img, label):
-    img = data_augmentation_layers(img)
-    label = tf.one_hot(label, num_classes)
-    return img, label
+# --- Load datasets ---
+raw_train_ds = image_dataset_from_directory(
+    dataset_path,
+    image_size=img_size,
+    batch_size=batch_size,
+    validation_split=0.2,
+    subset="training",
+    seed=123,
+)
 
-def preprocess_val(img, label):
-    label = tf.one_hot(label, num_classes)
-    return img, label
+raw_val_ds = image_dataset_from_directory(
+    dataset_path,
+    image_size=img_size,
+    batch_size=batch_size,
+    validation_split=0.2,
+    subset="validation",
+    seed=123,
+)
 
-train_ds = train_ds.map(preprocess_train, num_parallel_calls=tf_data.AUTOTUNE)
-val_ds = val_ds.map(preprocess_val, num_parallel_calls=tf_data.AUTOTUNE)
+# Save class names
+class_names = raw_train_ds.class_names
+with open(os.path.join(model_output_path, "class_names.json"), "w") as f:
+    json.dump(class_names, f)
+print(f"Detected {len(class_names)} Pokémon classes.")
+
+# --- Compute class weights ---
+print("Computing class weights...")
+all_labels = []
+for _, labels in raw_train_ds.unbatch():
+    all_labels.append(labels.numpy())
+
+all_labels = np.array(all_labels)
+class_weights = compute_class_weight(
+    class_weight="balanced",
+    classes=np.unique(all_labels),
+    y=all_labels
+)
+class_weight_dict = dict(enumerate(class_weights))
+print("Class weights ready.")
 
-train_ds = train_ds.prefetch(buffer_size=tf_data.AUTOTUNE)
-val_ds = val_ds.prefetch(buffer_size=tf_data.AUTOTUNE)
+# --- Performance improvements ---
+AUTOTUNE = tf.data.AUTOTUNE
+train_ds = raw_train_ds.map(lambda x, y: (data_augmentation(x), y)).prefetch(AUTOTUNE)
+val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE)
 
-# --- Build & Compile Model ---
+# --- Build the model ---
 with strategy.scope():
-    base_model = tf.keras.applications.ResNet50(
-        include_top=False,
-        weights='imagenet',
-        input_shape=(224, 224, 3)
-    )
-
-    model = keras.Sequential([
-        base_model,
-        layers.GlobalAveragePooling2D(),
-        layers.Dense(256, activation='relu'),
-        layers.Dropout(0.5),
-        layers.Dense(num_classes, activation='softmax')
-    ])
-
-    optimizer = tf.keras.optimizers.Adam(learning_rate=scaled_lr)
-
-    model.compile(
-        optimizer=optimizer,
-        loss='categorical_crossentropy',
-        metrics=['accuracy']
-    )
-
-# --- Train ---
+    base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
+    x = base_model.output
+    x = GlobalAveragePooling2D()(x)
+    outputs = Dense(len(class_names), activation="softmax")(x)
+
+    model = Model(inputs=base_model.input, outputs=outputs)
+
+    # Freeze some layers
+    for layer in base_model.layers[:100]:
+        layer.trainable = False
+
+    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+
+# --- Callbacks ---
 callbacks = [
-    keras.callbacks.ModelCheckpoint("/home/users/d/divia/ResNet50/save_at_{epoch}.h5"),
-    keras.callbacks.EarlyStopping(monitor="val_loss", patience=2, restore_best_weights=True)
+    EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
 ]
 
+# --- Train the model with class weights ---
 model.fit(
     train_ds,
     validation_data=val_ds,
-    epochs=15,
-    callbacks=callbacks
-)
\ No newline at end of file
+    epochs=10,
+    callbacks=callbacks,
+    class_weight=class_weight_dict
+)
+
+# --- Save the model ---
+model_h5_path = os.path.join(model_output_path, "pokemon_resnet50.h5")
+model.save(model_h5_path)
+print(f"Model saved to {model_h5_path}")
+
+# --- Save as TensorFlow SavedModel (for ONNX export) ---
+saved_model_path = os.path.join(model_output_path, "saved_model")
+tf.saved_model.save(model, saved_model_path)
+print(f"SavedModel exported to {saved_model_path}")
\ No newline at end of file
diff --git a/Python/test_ResNet50.py b/Python/test_ResNet50.py
new file mode 100644
index 0000000000000000000000000000000000000000..04dd17289f4678570ac94118685fd594f51ef6ab
--- /dev/null
+++ b/Python/test_ResNet50.py
@@ -0,0 +1,61 @@
+from tensorflow import keras
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import random
+import json
+
+# === Load class names from JSON ===
+with open("../models/ResNet50/class_names.json", "r") as f:
+    class_names = json.load(f)
+    class_names = [class_names[i] for i in range(len(class_names))]  # convert to list
+
+# === Load trained model ===
+model = keras.models.load_model("../models/ResNet50/pokemon_resnet50.h5")
+
+# === Paths ===
+base_path = "../Combined_Dataset"
+
+# === Prepare 2x2 Plot ===
+plt.figure(figsize=(10, 10))
+
+for i in range(4):
+    # Pick random class and image
+    random_class = random.choice(class_names)
+    class_folder = os.path.join(base_path, random_class)
+    random_image = random.choice([
+        f for f in os.listdir(class_folder)
+        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
+    ])
+    img_path = os.path.join(class_folder, random_image)
+
+    # === Load & Preprocess Image ===
+    img = keras.utils.load_img(img_path, target_size=(224, 224))  # resize to match model input
+    img_array = keras.utils.img_to_array(img)
+    img_array = img_array / 255.0  # normalize if your model expects it
+    img_array = tf.expand_dims(img_array, 0)
+
+    # === Predict ===
+    predictions = model.predict(img_array, verbose=0)
+    probabilities = tf.nn.softmax(predictions[0])
+    predicted_class_index = np.argmax(probabilities)
+    predicted_label = class_names[predicted_class_index]
+    confidence = 100 * probabilities[predicted_class_index]
+
+    # Compare with actual
+    is_correct = predicted_label == random_class
+
+    # === Plot ===
+    ax = plt.subplot(2, 2, i + 1)
+    plt.imshow(img)
+    plt.axis("off")
+    plt.title(
+        f"Pred: {predicted_label}\n"
+        f"True: {random_class}\n"
+        f"{'YES' if is_correct else 'NO'} | {confidence:.1f}%",
+        fontsize=10
+    )
+
+plt.tight_layout()
+plt.show()
diff --git a/ResNet50/class_names.json b/ResNet50/class_names.json
new file mode 100644
index 0000000000000000000000000000000000000000..57874547cdef98b4f08dd1eebb09a2dcedc3ed4e
--- /dev/null
+++ b/ResNet50/class_names.json
@@ -0,0 +1 @@
+["Abo", "Abra", "Akwakwak", "Alakazam", "Amonistar", "Amonita", "Aquali", "Arbok", "Arcanin", "Artikodin", "Aspicot", "A\u00e9romite", "Boustiflor", "Bulbizarre", "Caninos", "Carabaffe", "Carapuce", "Chenipan", "Chrysacier", "Ch\u00e9tiflor", "Coconfort", "Colossinge", "Crustabri", "Dardargnan", "Dodrio", "Doduo", "Dracaufeu", "Draco", "Dracolosse", "Ectoplasma", "Empiflor", "Excelangue", "Fantominus", "Farfetchd", "Feunard", "Flagadoss", "Florizarre", "F\u00e9rosinge", "Galopa", "Goupix", "Gravalanch", "Grodoudou", "Grolem", "Grotadmorv", "Herbizarre", "Hypnomade", "Hypoc\u00e9an", "Hypotrempe", "Ins\u00e9cateur", "Kabuto", "Kabutops", "Kadabra", "Kangourex", "Kicklee", "Kokiyas", "Krabboss", "Krabby", "Lamantine", "Leveinard", "Lippoutou", "Lokhlass", "L\u00e9viator", "M. Mime", "Machoc", "Machopeur", "Mackogneur", "Magicarpe", "Magmar", "Magn\u00e9ti", "Magn\u00e9ton", "Mew", "Mewtwo", "Miaouss", "Mimitoss", "Minidraco", "Mystherbe", "M\u00e9lodelfe", "M\u00e9lof\u00e9e", "M\u00e9tamorph", "Nidoking", "Nidoqueen", "Nidoran_femelle", "Nidoran_male", "Nidorina", "Nidorino", "Noadkoko", "Noeunoeuf", "Nosferalto", "Nosferapti", "Onix", "Ortide", "Ossatueur", "Osselait", "Otaria", "Papilusion", "Paras", "Parasect", "Persian", "Piafabec", "Pikachu", "Poissir\u00e8ne", "Poissoroy", "Ponyta", "Porygon", "Psykokwak", "Ptitard", "Pt\u00e9ra", "Pyroli", "Racaillou", "Rafflesia", "Raichu", "Ramoloss", "Rapasdepic", "Rattata", "Rattatac", "Reptincel", "Rhinocorne", "Rhinof\u00e9ros", "Rondoudou", "Ronflex", "Roucarnage", "Roucool", "Roucoups", "Sabelette", "Sablaireau", "Salam\u00e8che", "Saquedeneu", "Scarabrute", "Smogo", "Smogogo", "Soporifik", "Spectrum", "Stari", "Staross", "Sulfura", "Tadmorv", "Tartard", "Taupiqueur", "Tauros", "Tentacool", "Tentacruel", "Tortank", "Triopikeur", "Tygnon", "T\u00eatarte", "Voltali", "Voltorbe", "\u00c9lecthor", "\u00c9lectrode", "\u00c9lektek", "\u00c9voli"]
\ No newline at end of file
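
The training script above ends by exporting a TensorFlow SavedModel "(for ONNX export)", but the conversion itself is not part of this patch. A minimal sketch of that follow-up step, assuming the tf2onnx package is installed; convert_to_onnx.py is a hypothetical helper (not included in this diff) and it reuses model_output_path from Python/pokedex_ResNet50.py:

    # convert_to_onnx.py -- hypothetical follow-up helper, not part of this change
    import os
    import subprocess
    import sys

    # Same output directory as in Python/pokedex_ResNet50.py
    model_output_path = "/home/padi/Git/pokedex/ResNet50"
    saved_model_path = os.path.join(model_output_path, "saved_model")
    onnx_path = os.path.join(model_output_path, "pokemon_resnet50.onnx")

    # tf2onnx ships a command-line converter; invoking it through the current
    # interpreter turns the exported SavedModel into an ONNX file.
    subprocess.run(
        [sys.executable, "-m", "tf2onnx.convert",
         "--saved-model", saved_model_path,
         "--output", onnx_path],
        check=True,
    )
    print(f"ONNX model written to {onnx_path}")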