Skip to content
Snippets Groups Projects
Commit e15bb05c authored by michael.divia's avatar michael.divia
Browse files

Migrated from SoftMax to Logits

parent cc235fa3
No related branches found
No related tags found
No related merge requests found
...@@ -69,15 +69,11 @@ print(f"Detected {len(class_names)} Pokémon classes.") ...@@ -69,15 +69,11 @@ print(f"Detected {len(class_names)} Pokémon classes.")
# --- Compute class weights --- # --- Compute class weights ---
print("Computing class weights...") print("Computing class weights...")
all_labels = [] all_labels = [label.numpy() for _, label in raw_train_ds.unbatch()]
for _, labels in raw_train_ds.unbatch():
all_labels.append(labels.numpy())
all_labels = np.array(all_labels)
class_weights = compute_class_weight( class_weights = compute_class_weight(
class_weight="balanced", class_weight="balanced",
classes=np.unique(all_labels), classes=np.unique(all_labels),
y=all_labels y=np.array(all_labels)
) )
class_weight_dict = dict(enumerate(class_weights)) class_weight_dict = dict(enumerate(class_weights))
print("Class weights ready.") print("Class weights ready.")
...@@ -92,26 +88,29 @@ with strategy.scope(): ...@@ -92,26 +88,29 @@ with strategy.scope():
base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3)) base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
x = base_model.output x = base_model.output
x = GlobalAveragePooling2D()(x) x = GlobalAveragePooling2D()(x)
outputs = Dense(len(class_names), activation="softmax")(x) outputs = Dense(len(class_names), activation=None)(x)
model = Model(inputs=base_model.input, outputs=outputs) model = Model(inputs=base_model.input, outputs=outputs)
# Freeze some layers
for layer in base_model.layers[:100]: for layer in base_model.layers[:100]:
layer.trainable = False layer.trainable = False
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
# --- Callbacks --- # --- Callbacks ---
callbacks = [ callbacks = [
EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True) EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
] ]
# --- Train the model with class weights --- # --- Train the model ---
model.fit( model.fit(
train_ds, train_ds,
validation_data=val_ds, validation_data=val_ds,
epochs=1, epochs=20,
callbacks=callbacks, callbacks=callbacks,
class_weight=class_weight_dict class_weight=class_weight_dict
) )
...@@ -121,7 +120,7 @@ model_h5_path = os.path.join(model_output_path, "pokedex_ResNet50.h5") ...@@ -121,7 +120,7 @@ model_h5_path = os.path.join(model_output_path, "pokedex_ResNet50.h5")
model.save(model_h5_path) model.save(model_h5_path)
print(f"Model saved to {model_h5_path}") print(f"Model saved to {model_h5_path}")
# --- Save as TensorFlow SavedModel (for ONNX export) --- # --- Save as TensorFlow SavedModel ---
saved_model_path = os.path.join(model_output_path, "saved_model") saved_model_path = os.path.join(model_output_path, "saved_model")
tf.saved_model.save(model, saved_model_path) tf.saved_model.save(model, saved_model_path)
print(f"SavedModel exported to {saved_model_path}") print(f"SavedModel exported to {saved_model_path}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment