Commit c801ffeb authored by michael.divia's avatar michael.divia

Added Freezing and Fine Tuning for EfficientNetV2M

parent e85b4b89
 import os
+import gc
 import keras
 import tensorflow as tf
 from keras import layers
 from tensorflow import data as tf_data
 
+# --- Silence TensorFlow logs ---
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+
 # --- GPU Strategy ---
 strategy = tf.distribute.MirroredStrategy()
 print("Number of GPUs:", strategy.num_replicas_in_sync)
@@ -14,9 +16,8 @@ data_dir = "/home/users/d/divia/scratch/Combined_Dataset"
 image_size = (240, 240)
 num_classes = 151
 base_batch_size = 32
-base_lr = 1e-3
 global_batch_size = 32
+base_lr = 1e-3
 scaled_lr = min(base_lr * (global_batch_size / base_batch_size), 1e-3)
 
 # --- Load Dataset ---
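The scaled_lr line is linear learning-rate scaling with a ceiling: the base rate grows with the global batch but is clamped at 1e-3. With the 32/32 values in this commit the rule is a no-op; the other batch sizes below are hypothetical, shown only to illustrate the clamp:

base_lr = 1e-3
base_batch_size = 32
for global_batch_size in (32, 128, 16):
    # 32  -> 1e-3 * 1.0 = 1e-3  (unchanged)
    # 128 -> 1e-3 * 4.0 = 4e-3  -> clamped to 1e-3
    # 16  -> 1e-3 * 0.5 = 5e-4  (used as-is)
    scaled_lr = min(base_lr * (global_batch_size / base_batch_size), 1e-3)
    print(global_batch_size, scaled_lr)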
@@ -54,10 +55,13 @@ def preprocess_val(img, label):
 train_ds = train_ds.map(preprocess_train, num_parallel_calls=tf_data.AUTOTUNE)
 val_ds = val_ds.map(preprocess_val, num_parallel_calls=tf_data.AUTOTUNE)
 
-train_ds = train_ds.prefetch(buffer_size=tf_data.AUTOTUNE)
-val_ds = val_ds.prefetch(buffer_size=tf_data.AUTOTUNE)
+# Add auto-shard options to suppress Grappler warning
+options = tf.data.Options()
+options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
+train_ds = train_ds.with_options(options).prefetch(buffer_size=tf_data.AUTOTUNE)
+val_ds = val_ds.with_options(options).prefetch(buffer_size=tf_data.AUTOTUNE)
 
-# --- Build & Compile Model ---
+# --- Build model ---
 with strategy.scope():
     base_model = tf.keras.applications.EfficientNetV2M(
         include_top=False,
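The "# --- Load Dataset ---" section and the preprocess_train/preprocess_val bodies are elided from this diff. One constraint is still visible: the model is compiled with categorical_crossentropy (not the sparse variant), so the loader must emit one-hot labels. A sketch of how that section might look with keras.utils.image_dataset_from_directory; the split and seed values are assumptions, not part of the commit:

import keras

data_dir = "/home/users/d/divia/scratch/Combined_Dataset"

# label_mode="categorical" yields one-hot labels, as categorical_crossentropy requires.
train_ds = keras.utils.image_dataset_from_directory(
    data_dir, validation_split=0.2, subset="training", seed=1337,  # split/seed assumed
    image_size=(240, 240), batch_size=32, label_mode="categorical",
)
val_ds = keras.utils.image_dataset_from_directory(
    data_dir, validation_split=0.2, subset="validation", seed=1337,
    image_size=(240, 240), batch_size=32, label_mode="categorical",
)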
@@ -65,31 +69,41 @@ with strategy.scope():
         input_shape=(240, 240, 3)
     )
 
-    model = keras.Sequential([
-        base_model,
-        layers.GlobalAveragePooling2D(),
-        layers.Dense(256, activation='relu'),
-        layers.Dropout(0.5),
-        layers.Dense(num_classes, activation='softmax')
-    ])
-
-    optimizer = tf.keras.optimizers.Adam(learning_rate=scaled_lr)
+    x = layers.GlobalAveragePooling2D()(base_model.output)
+    x = layers.Dense(256, activation='relu')(x)
+    x = layers.Dropout(0.5)(x)
+    predictions = layers.Dense(num_classes, activation='softmax')(x)
+
+    model = keras.Model(inputs=base_model.input, outputs=predictions)
+
+    # PHASE 1: Freeze the base model
+    base_model.trainable = False
 
     model.compile(
-        optimizer=optimizer,
-        loss='categorical_crossentropy',
-        metrics=['accuracy']
+        optimizer=tf.keras.optimizers.Adam(learning_rate=scaled_lr),
+        loss="categorical_crossentropy",
+        metrics=["accuracy"]
     )
 
-# --- Train ---
+# --- Callbacks ---
 callbacks = [
     keras.callbacks.ModelCheckpoint("/home/users/d/divia/EfficientNetV2M/save_at_{epoch}.keras"),
     keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)
 ]
 
-model.fit(
-    train_ds,
-    validation_data=val_ds,
-    epochs=10,
-    callbacks=callbacks
-)
\ No newline at end of file
+# --- Train head only ---
+model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=callbacks)
+
+# PHASE 2: Fine-tune top of the base model
+# Unfreeze the whole base:
+base_model.trainable = True
+
+# Recompile with lower LR for fine-tuning
+model.compile(
+    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-5, momentum=0.9),
+    loss="categorical_crossentropy",
+    metrics=["accuracy"]
+)
+
+# Fine-tune the full model
+model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=callbacks)
\ No newline at end of file
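One mismatch worth flagging: the phase-2 comment says "Fine-tune top of the base model", but base_model.trainable = True unfreezes the entire backbone. If only the top should train, a common variant looks like the sketch below; it is not what this commit does, the -100 cutoff is an arbitrary example, and weights="imagenet" is an assumption for the argument elided by "..." in the diff:

import tensorflow as tf

base_model = tf.keras.applications.EfficientNetV2M(
    include_top=False,
    weights="imagenet",  # assumption: the "..." in the diff presumably sets this
    input_shape=(240, 240, 3),
)

# Unfreeze only the top of the backbone.
base_model.trainable = True
for layer in base_model.layers[:-100]:  # arbitrary example cutoff
    layer.trainable = False

# Keep BatchNorm layers frozen: re-adapting their statistics on a new
# dataset during low-LR fine-tuning usually degrades accuracy.
for layer in base_model.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

Either way, model.compile(...) must run again after changing trainable flags, which the commit already does. Separately, both fit() calls write to the same save_at_{epoch}.keras pattern and restart epoch numbering at 1, so phase 2 silently overwrites the phase-1 checkpoints; passing initial_epoch=5, epochs=10 to the second fit() (or using a distinct checkpoint path per phase) keeps the two phases apart.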
A second file in the commit receives the same import change (the rest of its diff is truncated in this view):

 import os
+import gc
 import keras
 import tensorflow as tf
 from keras import layers
...