Commit cc235fa3 authored by michael.divia

Ow Shit

parent b165ab46
@@ -13,9 +13,8 @@ strategy = tf.distribute.MirroredStrategy()
 print("Number of GPUs:", strategy.num_replicas_in_sync)
 
 # --- Paths ---
-parser = argparse.ArgumentParser(description="WHERE ?!")
-parser.add_argument("--hpc", choices=["yes", "no"], default="no",
-                    help="Use HPC paths if 'yes', otherwise local paths.")
+parser = argparse.ArgumentParser(description="Train Xception Pokémon Classifier")
+parser.add_argument("--hpc", choices=["yes", "no"], default="no", help="Use HPC paths if 'yes', otherwise local paths.")
 args = parser.parse_args()
 
 if args.hpc == "yes":
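Note: the branch bodies that actually assign the HPC and local paths fall outside this hunk, so they are not reproduced here. As a rough sketch of how such a toggle is usually wired, with purely illustrative placeholder paths (not the repository's real ones):

import argparse
import os

parser = argparse.ArgumentParser(description="Train Xception Pokémon Classifier")
parser.add_argument("--hpc", choices=["yes", "no"], default="no",
                    help="Use HPC paths if 'yes', otherwise local paths.")
args = parser.parse_args()

# Hypothetical locations -- the real script defines its own paths here.
if args.hpc == "yes":
    dataset_path = "/scratch/datasets/pokemon"        # illustrative HPC path
    model_output_path = "/scratch/models/xception"     # illustrative HPC path
else:
    dataset_path = "./data/pokemon"                    # illustrative local path
    model_output_path = "./models/xception"            # illustrative local path

os.makedirs(model_output_path, exist_ok=True)
print(f"Dataset: {dataset_path} | Output: {model_output_path}")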
@@ -30,14 +29,12 @@ os.makedirs(model_output_path, exist_ok=True)
 # --- Custom Xception-like model ---
 def simple_xception(input_shape, num_classes):
     inputs = Input(shape=input_shape)
     x = layers.Rescaling(1.0 / 255)(inputs)
     x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
     x = layers.BatchNormalization()(x)
     x = layers.Activation("relu")(x)
 
     previous_block_activation = x
 
     for size in [256, 512, 728]:
         x = layers.Activation("relu")(x)
         x = layers.SeparableConv2D(size, 3, padding="same")(x)
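The middle of each loop iteration is cut off by the hunk boundary. The visible skeleton (separable convolutions, previous_block_activation, filter sizes 256/512/728) closely matches the mini-Xception from the Keras "Image classification from scratch" example, so each block presumably ends in a max-pool plus a 1x1-conv residual add. A sketch under that assumption, with an example input size:

from tensorflow.keras import Input, layers

inputs = Input(shape=(256, 256, 3))   # example input size; the script defines img_size elsewhere
x = layers.Conv2D(128, 3, strides=2, padding="same")(inputs)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
previous_block_activation = x

for size in [256, 512, 728]:
    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(size, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(size, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
    # 1x1 strided conv projects the skip branch to the new shape before the add.
    residual = layers.Conv2D(size, 1, strides=2, padding="same")(previous_block_activation)
    x = layers.add([x, residual])
    previous_block_activation = x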
@@ -55,11 +52,11 @@ def simple_xception(input_shape, num_classes):
     x = layers.SeparableConv2D(1024, 3, padding="same")(x)
     x = layers.BatchNormalization()(x)
     x = layers.Activation("relu")(x)
 
     x = layers.GlobalAveragePooling2D()(x)
     x = layers.Dropout(0.25)(x)
-    outputs = layers.Dense(num_classes, activation='softmax')(x)
+    # Output logits
+    outputs = layers.Dense(num_classes, activation=None)(x)
     return models.Model(inputs, outputs)
 
 # --- Image settings ---
@@ -83,7 +80,6 @@ raw_train_ds = image_dataset_from_directory(
     subset="training",
     seed=123,
 )
-
 raw_val_ds = image_dataset_from_directory(
     dataset_path,
     image_size=img_size,
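For the training and validation subsets to be disjoint and reproducible, both image_dataset_from_directory calls must use the same validation_split and seed. A minimal sketch of the paired calls; the split fraction, batch size, image size and dataset path shown here are assumed values, not taken from the diff:

import tensorflow as tf

dataset_path = "./data/pokemon"   # illustrative path
img_size = (256, 256)             # example; the script defines its own img_size
batch_size = 32                   # example; not visible in the diff

raw_train_ds = tf.keras.utils.image_dataset_from_directory(
    dataset_path,
    image_size=img_size,
    batch_size=batch_size,
    validation_split=0.2,         # assumed fraction -- must match the validation call
    subset="training",
    seed=123,                     # same seed in both calls keeps the split consistent
)
raw_val_ds = tf.keras.utils.image_dataset_from_directory(
    dataset_path,
    image_size=img_size,
    batch_size=batch_size,
    validation_split=0.2,
    subset="validation",
    seed=123,
)
class_names = raw_train_ds.class_names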
@@ -110,6 +106,12 @@ class_weights = compute_class_weight(
 class_weight_dict = dict(enumerate(class_weights))
 print("Class weights ready.")
 
+# --- Debug print for class balance ---
+print("Unique labels in training set:", np.unique(all_labels))
+print("Class Names (index -> name):")
+for i, name in enumerate(class_names):
+    print(f"{i}: {name}")
+
 # --- Performance improvements ---
 AUTOTUNE = tf.data.AUTOTUNE
 train_ds = raw_train_ds.map(lambda x, y: (data_augmentation(x), y)).prefetch(AUTOTUNE)
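The computed weights only influence training once the dictionary is passed to model.fit(..., class_weight=class_weight_dict); that call is outside this hunk. A self-contained sketch of the weighting step, using made-up labels in place of the script's all_labels:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Illustrative labels; in the script these come from iterating the training dataset.
all_labels = np.array([0, 0, 0, 1, 2, 2])

class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(all_labels),
    y=all_labels,
)
class_weight_dict = dict(enumerate(class_weights))
print(class_weight_dict)   # e.g. {0: 0.67, 1: 2.0, 2: 1.0} -- rare classes get larger weights

# Passing the dictionary to fit is what actually applies the weighting:
# model.fit(train_ds, validation_data=val_ds, epochs=..., class_weight=class_weight_dict)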
@@ -118,7 +120,11 @@ val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE)
 # --- Build and compile model ---
 with strategy.scope():
     model = simple_xception((*img_size, 3), num_classes=len(class_names))
-    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+    model.compile(
+        optimizer='adam',
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        metrics=['accuracy']
+    )
 
 # --- Callbacks ---
 callbacks = [
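This compile change is the counterpart of the new logits head above: the string loss 'sparse_categorical_crossentropy' expects probabilities, whereas SparseCategoricalCrossentropy(from_logits=True) applies the softmax inside the loss, which is the numerically safer pairing with a linear output layer. A small self-contained check of the equivalence, with illustrative values:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
labels = tf.constant([0])

loss_from_logits = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_from_probs = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

# Same value either way: crossentropy-from-logits equals softmax-then-crossentropy.
print(loss_from_logits(labels, logits).numpy())
print(loss_from_probs(labels, tf.nn.softmax(logits)).numpy())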
@@ -142,4 +148,4 @@ print(f"Model saved to {model_h5_path}")
 # --- Save as TensorFlow SavedModel ---
 saved_model_path = os.path.join(model_output_path, "saved_model")
 tf.saved_model.save(model, saved_model_path)
-print(f"SavedModel exported to {saved_model_path}")
\ No newline at end of file
+print(f"SavedModel exported to {saved_model_path}")
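The script exports the trained weights twice, as a Keras .h5 file and as a TensorFlow SavedModel directory. Either artifact can be reloaded later; a brief sketch, with placeholder paths standing in for model_h5_path and saved_model_path:

import tensorflow as tf

# Keras format: returns a Keras model object that can be evaluated or fine-tuned.
keras_model = tf.keras.models.load_model("Models/pokemon_xception.h5")   # placeholder path

# SavedModel format: returns a generic TF object whose serving signature runs inference.
saved = tf.saved_model.load("Models/saved_model")                        # placeholder path
infer = saved.signatures["serving_default"]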
@@ -54,10 +54,16 @@ for i in range(4):
 # --- Predict ---
 predictions = model.predict(img_array, verbose=0)
-probabilities = tf.nn.softmax(predictions[0])
+probabilities = predictions[0]
 predicted_class_index = np.argmax(probabilities)
 predicted_label = class_names[predicted_class_index]
 confidence = 100 * probabilities[predicted_class_index]
+
+top_5_indices = np.argsort(probabilities)[-5:][::-1]
+print("\nTop 5 predictions:")
+for idx in top_5_indices:
+    print(f"{class_names[idx]:<20}: {probabilities[idx]:.4f}")
+
 # Compare with actual
 is_correct = predicted_label == random_class
...
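One caveat: with the model now emitting raw logits, predictions[0] is no longer a probability vector, so the printed scores and the confidence value are logit magnitudes rather than percentages; the argmax, and therefore the predicted label, is unaffected because softmax is monotonic. If calibrated probabilities are wanted for display, an explicit softmax restores them; a minimal self-contained sketch with made-up values:

import numpy as np
import tensorflow as tf

class_names = ["Bulbasaur", "Charmander", "Squirtle"]    # illustrative labels
logits = np.array([3.1, 0.2, -1.4], dtype=np.float32)     # stands in for predictions[0]

probabilities = tf.nn.softmax(logits).numpy()              # logits -> probability vector
predicted_class_index = int(np.argmax(probabilities))      # same argmax as on the raw logits
confidence = 100 * probabilities[predicted_class_index]    # now a genuine percentage

print(f"{class_names[predicted_class_index]} ({confidence:.1f}%)")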