Skip to content
Snippets Groups Projects
Commit e644289c authored by michael.divia's avatar michael.divia
Browse files

Here we go again

parent 67858bd1
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing import image_dataset_from_directory
import json
import os
# --- GPU Strategy ---
strategy = tf.distribute.MirroredStrategy()
print("Number of GPUs:", strategy.num_replicas_in_sync)
# --- Paths ---
dataset_path = "/home/users/d/divia/scratch/Combined_Dataset"
model_output_path = "/home/users/d/divia/ResNet50"
os.makedirs(model_output_path, exist_ok=True)
# --- Image settings ---
img_size = (224, 224)
batch_size = 32
# --- Load datasets ---
raw_train_ds = image_dataset_from_directory(
dataset_path,
image_size=img_size,
batch_size=batch_size,
validation_split=0.2,
subset="training",
seed=123,
)
raw_val_ds = image_dataset_from_directory(
dataset_path,
image_size=img_size,
batch_size=batch_size,
validation_split=0.2,
subset="validation",
seed=123,
)
# Save class names
class_names = raw_train_ds.class_names
with open(os.path.join(model_output_path, "class_names.json"), "w") as f:
json.dump(class_names, f)
print(f"Detected {len(class_names)} Pokémon classes.")
# --- Performance improvements ---
AUTOTUNE = tf.data.AUTOTUNE
train_ds = raw_train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE)
# --- Build the model ---
with strategy.scope():
base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
x = base_model.output
x = GlobalAveragePooling2D()(x)
outputs = Dense(len(class_names), activation="softmax")(x)
model = Model(inputs=base_model.input, outputs=outputs)
# Freeze base model
for layer in base_model.layers:
layer.trainable = False
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# --- Train the model ---
model.fit(train_ds, validation_data=val_ds, epochs=10)
# --- Optional Fine-tuning (Unfreeze top layers) ---
# Uncomment to fine-tune top layers of ResNet50
for layer in base_model.layers[-30:]:
layer.trainable = True
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_ds, validation_data=val_ds, epochs=5)
# --- Save the model ---
model_h5_path = os.path.join(model_output_path, "pokemon_resnet50.h5")
model.save(model_h5_path)
print(f"Model saved to {model_h5_path}")
# --- Save as TensorFlow SavedModel (for ONNX export) ---
saved_model_path = os.path.join(model_output_path, "saved_model")
tf.saved_model.save(model, saved_model_path)
print(f"SavedModel exported to {saved_model_path}")
import keras
import tensorflow as tf import tensorflow as tf
from keras import layers from tensorflow.keras.applications import ResNet50
from tensorflow import data as tf_data from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import json
import os
# --- GPU Strategy --- # --- GPU Strategy ---
strategy = tf.distribute.MirroredStrategy() strategy = tf.distribute.MirroredStrategy()
print("Number of GPUs:", strategy.num_replicas_in_sync) print("Number of GPUs:", strategy.num_replicas_in_sync)
# --- Parameters --- # --- Paths ---
data_dir = "/home/users/d/divia/scratch/Combined_Dataset" dataset_path = "/home/padi/Git/pokedex/Combined_Dataset"
image_size = (224, 224) model_output_path = "/home/padi/Git/pokedex/ResNet50"
num_classes = 151 os.makedirs(model_output_path, exist_ok=True)
base_batch_size = 32
base_lr = 1e-3
global_batch_size = 32
scaled_lr = min(base_lr * (global_batch_size / base_batch_size), 1e-3)
# --- Load Dataset ---
full_ds = keras.utils.image_dataset_from_directory(
data_dir,
labels="inferred",
label_mode="int",
image_size=image_size,
batch_size=global_batch_size,
shuffle=True,
seed=1234
)
# --- Train/Val Split --- # --- Image settings ---
total_batches = tf.data.experimental.cardinality(full_ds).numpy() img_size = (224, 224)
train_size = int(0.8 * total_batches) batch_size = 32
train_ds = full_ds.take(train_size)
val_ds = full_ds.skip(train_size)
# --- Data Augmentation --- # --- Data Augmentation ---
data_augmentation_layers = keras.Sequential([ data_augmentation = tf.keras.Sequential([
layers.RandomFlip("horizontal"), layers.RandomFlip("horizontal"),
layers.RandomRotation(0.1), layers.RandomRotation(0.1),
layers.RandomZoom(0.1),
layers.RandomContrast(0.1),
]) ])
def preprocess_train(img, label): # --- Load datasets ---
img = data_augmentation_layers(img) raw_train_ds = image_dataset_from_directory(
label = tf.one_hot(label, num_classes) dataset_path,
return img, label image_size=img_size,
batch_size=batch_size,
validation_split=0.2,
subset="training",
seed=123,
)
def preprocess_val(img, label): raw_val_ds = image_dataset_from_directory(
label = tf.one_hot(label, num_classes) dataset_path,
return img, label image_size=img_size,
batch_size=batch_size,
validation_split=0.2,
subset="validation",
seed=123,
)
train_ds = train_ds.map(preprocess_train, num_parallel_calls=tf_data.AUTOTUNE) # Save class names
val_ds = val_ds.map(preprocess_val, num_parallel_calls=tf_data.AUTOTUNE) class_names = raw_train_ds.class_names
with open(os.path.join(model_output_path, "class_names.json"), "w") as f:
json.dump(class_names, f)
print(f"Detected {len(class_names)} Pokémon classes.")
# --- Compute class weights ---
print("Computing class weights...")
all_labels = []
for _, labels in raw_train_ds.unbatch():
all_labels.append(labels.numpy())
all_labels = np.array(all_labels)
class_weights = compute_class_weight(
class_weight="balanced",
classes=np.unique(all_labels),
y=all_labels
)
class_weight_dict = dict(enumerate(class_weights))
print("Class weights ready.")
train_ds = train_ds.prefetch(buffer_size=tf_data.AUTOTUNE) # --- Performance improvements ---
val_ds = val_ds.prefetch(buffer_size=tf_data.AUTOTUNE) AUTOTUNE = tf.data.AUTOTUNE
train_ds = raw_train_ds.map(lambda x, y: (data_augmentation(x), y)).prefetch(AUTOTUNE)
val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE)
# --- Build & Compile Model --- # --- Build the model ---
with strategy.scope(): with strategy.scope():
base_model = tf.keras.applications.ResNet50( base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
include_top=False, x = base_model.output
weights='imagenet', x = GlobalAveragePooling2D()(x)
input_shape=(224, 224, 3) outputs = Dense(len(class_names), activation="softmax")(x)
)
model = keras.Sequential([ model = Model(inputs=base_model.input, outputs=outputs)
base_model,
layers.GlobalAveragePooling2D(),
layers.Dense(256, activation='relu'),
layers.Dropout(0.5),
layers.Dense(num_classes, activation='softmax')
])
optimizer = tf.keras.optimizers.Adam(learning_rate=scaled_lr) # Freeze some layers
for layer in base_model.layers[:100]:
layer.trainable = False
model.compile( model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
optimizer=optimizer,
loss='categorical_crossentropy',
metrics=['accuracy']
)
# --- Train --- # --- Callbacks ---
callbacks = [ callbacks = [
keras.callbacks.ModelCheckpoint("/home/users/d/divia/ResNet50/save_at_{epoch}.h5"), EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
keras.callbacks.EarlyStopping(monitor="val_loss", patience=2, restore_best_weights=True)
] ]
# --- Train the model with class weights ---
model.fit( model.fit(
train_ds, train_ds,
validation_data=val_ds, validation_data=val_ds,
epochs=15, epochs=10,
callbacks=callbacks callbacks=callbacks,
class_weight=class_weight_dict
) )
# --- Save the model ---
model_h5_path = os.path.join(model_output_path, "pokemon_resnet50.h5")
model.save(model_h5_path)
print(f"Model saved to {model_h5_path}")
# --- Save as TensorFlow SavedModel (for ONNX export) ---
saved_model_path = os.path.join(model_output_path, "saved_model")
tf.saved_model.save(model, saved_model_path)
print(f"SavedModel exported to {saved_model_path}")
\ No newline at end of file
from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import json
# === Load class names from JSON ===
with open("../models/ResNet50/class_names.json", "r") as f:
class_names = json.load(f)
class_names = [class_names[i] for i in range(len(class_names))] # convert to list
# === Load trained model ===
model = keras.models.load_model("../models/ResNet50/pokemon_resnet50.h5")
# === Paths ===
base_path = "../Combined_Dataset"
# === Prepare 2x2 Plot ===
plt.figure(figsize=(10, 10))
for i in range(4):
# Pick random class and image
random_class = random.choice(class_names)
class_folder = os.path.join(base_path, random_class)
random_image = random.choice([
f for f in os.listdir(class_folder)
if f.lower().endswith(('.png', '.jpg', '.jpeg'))
])
img_path = os.path.join(class_folder, random_image)
# === Load & Preprocess Image ===
img = keras.utils.load_img(img_path, target_size=(224, 224)) # resize to match model input
img_array = keras.utils.img_to_array(img)
img_array = img_array / 255.0 # normalize if your model expects it
img_array = tf.expand_dims(img_array, 0)
# === Predict ===
predictions = model.predict(img_array, verbose=0)
probabilities = tf.nn.softmax(predictions[0])
predicted_class_index = np.argmax(probabilities)
predicted_label = class_names[predicted_class_index]
confidence = 100 * probabilities[predicted_class_index]
# Compare with actual
is_correct = predicted_label == random_class
# === Plot ===
ax = plt.subplot(2, 2, i + 1)
plt.imshow(img)
plt.axis("off")
plt.title(
f"Pred: {predicted_label}\n"
f"True: {random_class}\n"
f"{'YES' if is_correct else 'NO'} | {confidence:.1f}%",
fontsize=10
)
plt.tight_layout()
plt.show()
["Abo", "Abra", "Akwakwak", "Alakazam", "Amonistar", "Amonita", "Aquali", "Arbok", "Arcanin", "Artikodin", "Aspicot", "A\u00e9romite", "Boustiflor", "Bulbizarre", "Caninos", "Carabaffe", "Carapuce", "Chenipan", "Chrysacier", "Ch\u00e9tiflor", "Coconfort", "Colossinge", "Crustabri", "Dardargnan", "Dodrio", "Doduo", "Dracaufeu", "Draco", "Dracolosse", "Ectoplasma", "Empiflor", "Excelangue", "Fantominus", "Farfetchd", "Feunard", "Flagadoss", "Florizarre", "F\u00e9rosinge", "Galopa", "Goupix", "Gravalanch", "Grodoudou", "Grolem", "Grotadmorv", "Herbizarre", "Hypnomade", "Hypoc\u00e9an", "Hypotrempe", "Ins\u00e9cateur", "Kabuto", "Kabutops", "Kadabra", "Kangourex", "Kicklee", "Kokiyas", "Krabboss", "Krabby", "Lamantine", "Leveinard", "Lippoutou", "Lokhlass", "L\u00e9viator", "M. Mime", "Machoc", "Machopeur", "Mackogneur", "Magicarpe", "Magmar", "Magn\u00e9ti", "Magn\u00e9ton", "Mew", "Mewtwo", "Miaouss", "Mimitoss", "Minidraco", "Mystherbe", "M\u00e9lodelfe", "M\u00e9lof\u00e9e", "M\u00e9tamorph", "Nidoking", "Nidoqueen", "Nidoran_femelle", "Nidoran_male", "Nidorina", "Nidorino", "Noadkoko", "Noeunoeuf", "Nosferalto", "Nosferapti", "Onix", "Ortide", "Ossatueur", "Osselait", "Otaria", "Papilusion", "Paras", "Parasect", "Persian", "Piafabec", "Pikachu", "Poissir\u00e8ne", "Poissoroy", "Ponyta", "Porygon", "Psykokwak", "Ptitard", "Pt\u00e9ra", "Pyroli", "Racaillou", "Rafflesia", "Raichu", "Ramoloss", "Rapasdepic", "Rattata", "Rattatac", "Reptincel", "Rhinocorne", "Rhinof\u00e9ros", "Rondoudou", "Ronflex", "Roucarnage", "Roucool", "Roucoups", "Sabelette", "Sablaireau", "Salam\u00e8che", "Saquedeneu", "Scarabrute", "Smogo", "Smogogo", "Soporifik", "Spectrum", "Stari", "Staross", "Sulfura", "Tadmorv", "Tartard", "Taupiqueur", "Tauros", "Tentacool", "Tentacruel", "Tortank", "Triopikeur", "Tygnon", "T\u00eatarte", "Voltali", "Voltorbe", "\u00c9lecthor", "\u00c9lectrode", "\u00c9lektek", "\u00c9voli"]
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment