diff --git a/Python/pokedex_ResNet50.py b/Python/pokedex_ResNet50.py index dc528a77540d91eaffbc0afe391de81c1ff935b3..80160bc0c7d1903db4b99b96e1df20b3aa394d86 100644 --- a/Python/pokedex_ResNet50.py +++ b/Python/pokedex_ResNet50.py @@ -15,10 +15,10 @@ strategy = tf.distribute.MirroredStrategy() print("Number of GPUs:", strategy.num_replicas_in_sync) # --- Paths --- -#dataset_path = "/home/padi/Git/pokedex/Combined_Dataset" -#model_output_path = "/home/padi/Git/pokedex/ResNet50" -dataset_path = "/home/users/d/divia/scratch/Combined_Dataset" -model_output_path = "/home/users/d/divia/pokedex/ResNet50" +dataset_path = "/home/padi/Git/pokedex/Combined_Dataset" +model_output_path = "/home/padi/Git/pokedex/models/ResNet50" +#dataset_path = "/home/users/d/divia/scratch/Combined_Dataset" +#model_output_path = "/home/users/d/divia/pokedex/models/ResNet50" os.makedirs(model_output_path, exist_ok=True) # --- Image settings --- @@ -102,7 +102,7 @@ callbacks = [ model.fit( train_ds, validation_data=val_ds, - epochs=10, + epochs=1, callbacks=callbacks, class_weight=class_weight_dict ) diff --git a/Python/pokedex_xception.py b/Python/pokedex_xception.py new file mode 100644 index 0000000000000000000000000000000000000000..8584135740b475299f44b6d92427dd2090b321ef --- /dev/null +++ b/Python/pokedex_xception.py @@ -0,0 +1,136 @@ +import tensorflow as tf +from tensorflow.keras import layers, models, Input +from tensorflow.keras.callbacks import EarlyStopping +from tensorflow.keras.preprocessing import image_dataset_from_directory +from sklearn.utils.class_weight import compute_class_weight +import numpy as np +import json +import os + +# --- GPU Strategy --- +strategy = tf.distribute.MirroredStrategy() +print("Number of GPUs:", strategy.num_replicas_in_sync) + +# --- Paths --- +dataset_path = "/home/padi/Git/pokedex/Combined_Dataset" +model_output_path = "/home/padi/Git/pokedex/models/Xception" +#dataset_path = "/home/users/d/divia/scratch/Combined_Dataset" +#model_output_path = "/home/users/d/divia/pokedex/models7Xception" +os.makedirs(model_output_path, exist_ok=True) + +# --- Custom Xception-like model --- +def simple_xception(input_shape, num_classes): + inputs = Input(shape=input_shape) + + x = layers.Rescaling(1.0 / 255)(inputs) + x = layers.Conv2D(128, 3, strides=2, padding="same")(x) + x = layers.BatchNormalization()(x) + x = layers.Activation("relu")(x) + + previous_block_activation = x + + for size in [256, 512, 728]: + x = layers.Activation("relu")(x) + x = layers.SeparableConv2D(size, 3, padding="same")(x) + x = layers.BatchNormalization()(x) + + x = layers.Activation("relu")(x) + x = layers.SeparableConv2D(size, 3, padding="same")(x) + x = layers.BatchNormalization()(x) + + x = layers.MaxPooling2D(3, strides=2, padding="same")(x) + residual = layers.Conv2D(size, 1, strides=2, padding="same")(previous_block_activation) + x = layers.add([x, residual]) + previous_block_activation = x + + x = layers.SeparableConv2D(1024, 3, padding="same")(x) + x = layers.BatchNormalization()(x) + x = layers.Activation("relu")(x) + + x = layers.GlobalAveragePooling2D()(x) + x = layers.Dropout(0.25)(x) + outputs = layers.Dense(num_classes, activation='softmax')(x) + + return models.Model(inputs, outputs) + +# --- Image settings --- +img_size = (256, 256) +batch_size = 32 + +# --- Data Augmentation --- +data_augmentation = tf.keras.Sequential([ + layers.RandomFlip("horizontal"), + layers.RandomRotation(0.1), + layers.RandomZoom(0.1), + layers.RandomContrast(0.1), +]) + +# --- Load datasets --- +raw_train_ds = image_dataset_from_directory( + dataset_path, + image_size=img_size, + batch_size=batch_size, + validation_split=0.2, + subset="training", + seed=123, +) + +raw_val_ds = image_dataset_from_directory( + dataset_path, + image_size=img_size, + batch_size=batch_size, + validation_split=0.2, + subset="validation", + seed=123, +) + +# Save class names +class_names = raw_train_ds.class_names +with open(os.path.join(model_output_path, "class_names.json"), "w") as f: + json.dump(class_names, f) +print(f"Detected {len(class_names)} Pokémon classes.") + +# --- Compute class weights --- +print("Computing class weights...") +all_labels = [label.numpy() for _, label in raw_train_ds.unbatch()] +class_weights = compute_class_weight( + class_weight="balanced", + classes=np.unique(all_labels), + y=np.array(all_labels) +) +class_weight_dict = dict(enumerate(class_weights)) +print("Class weights ready.") + +# --- Performance improvements --- +AUTOTUNE = tf.data.AUTOTUNE +train_ds = raw_train_ds.map(lambda x, y: (data_augmentation(x), y)).prefetch(AUTOTUNE) +val_ds = raw_val_ds.prefetch(buffer_size=AUTOTUNE) + +# --- Build and compile model --- +with strategy.scope(): + model = simple_xception((*img_size, 3), num_classes=len(class_names)) + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + +# --- Callbacks --- +callbacks = [ + EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True) +] + +# --- Train the model --- +model.fit( + train_ds, + validation_data=val_ds, + epochs=20, + callbacks=callbacks, + class_weight=class_weight_dict +) + +# --- Save the model --- +model_h5_path = os.path.join(model_output_path, "pokemon_xception.h5") +model.save(model_h5_path) +print(f"Model saved to {model_h5_path}") + +# --- Save as TensorFlow SavedModel --- +saved_model_path = os.path.join(model_output_path, "saved_model") +tf.saved_model.save(model, saved_model_path) +print(f"SavedModel exported to {saved_model_path}") diff --git a/models/ResNet50/pokemon_resnet50.h5 b/models/ResNet50/pokemon_resnet50.h5 index 982cdc6b37379437251ea54fab16d3c5f6609991..a126216f80da6a9506e385e5c7b060c7471179c7 100644 Binary files a/models/ResNet50/pokemon_resnet50.h5 and b/models/ResNet50/pokemon_resnet50.h5 differ diff --git a/models/ResNet50/pokemon_resnet50.har b/models/ResNet50/pokemon_resnet50.har deleted file mode 100644 index 14eeebcaa1e5e747986b3ed467323710f6e58720..0000000000000000000000000000000000000000 Binary files a/models/ResNet50/pokemon_resnet50.har and /dev/null differ diff --git a/models/ResNet50/pokemon_resnet50.hef b/models/ResNet50/pokemon_resnet50.hef deleted file mode 100644 index b3bcca985fe91959707a70628abb16a8df7f969a..0000000000000000000000000000000000000000 Binary files a/models/ResNet50/pokemon_resnet50.hef and /dev/null differ diff --git a/models/ResNet50/pokemon_resnet50.onnx b/models/ResNet50/pokemon_resnet50.onnx deleted file mode 100644 index 3d199a5f5beff4f6d62f2d59dbbd8cd8b85d17c1..0000000000000000000000000000000000000000 Binary files a/models/ResNet50/pokemon_resnet50.onnx and /dev/null differ diff --git a/models/ResNet50/pokemon_resnet50_compiled.har b/models/ResNet50/pokemon_resnet50_compiled.har deleted file mode 100644 index 169c6a48ec988cf2e0dc8cceafcaf8ee0c6f1299..0000000000000000000000000000000000000000 Binary files a/models/ResNet50/pokemon_resnet50_compiled.har and /dev/null differ diff --git a/models/ResNet50/pokemon_resnet50_optimized.har b/models/ResNet50/pokemon_resnet50_optimized.har deleted file mode 100644 index 626b550088837644b1c779b3f9eadd9021008ece..0000000000000000000000000000000000000000 Binary files a/models/ResNet50/pokemon_resnet50_optimized.har and /dev/null differ diff --git a/models/ResNet50/saved_model/fingerprint.pb b/models/ResNet50/saved_model/fingerprint.pb index 90015092b1763c5fbfd24b9e1615847388e03c0b..2378f7e80c93a3c4204936a29dfbf4cd872d4a8c 100644 Binary files a/models/ResNet50/saved_model/fingerprint.pb and b/models/ResNet50/saved_model/fingerprint.pb differ diff --git a/models/ResNet50/saved_model/saved_model.pb b/models/ResNet50/saved_model/saved_model.pb index 9a955fba5e8f98d3de755bee6ea1a0f984aa75a6..5bacac02853fc4439551aaa220bb5392d186b2c5 100644 Binary files a/models/ResNet50/saved_model/saved_model.pb and b/models/ResNet50/saved_model/saved_model.pb differ diff --git a/models/ResNet50/saved_model/variables/variables.data-00000-of-00001 b/models/ResNet50/saved_model/variables/variables.data-00000-of-00001 index 9ab645b4639fbfb2be9039a0a109bc4c44c6fe23..62847711cfdef2a5814449403869126581437a30 100644 Binary files a/models/ResNet50/saved_model/variables/variables.data-00000-of-00001 and b/models/ResNet50/saved_model/variables/variables.data-00000-of-00001 differ diff --git a/models/ResNet50/saved_model/variables/variables.index b/models/ResNet50/saved_model/variables/variables.index index 6d12e83b46c4edb3c172eb4072d96056829d1d6f..bd1ff13d8c2f3c0a8473c14290ecd4d00c82f28a 100644 Binary files a/models/ResNet50/saved_model/variables/variables.index and b/models/ResNet50/saved_model/variables/variables.index differ