# pokedex.py
import os
import numpy as np
import keras
from keras import layers
import matplotlib.pyplot as plt
from tensorflow import data as tf_data
import random
# Uncomment to force CPU-only execution:
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Load the training and validation datasets (80/20 split)
train_ds, val_ds = keras.utils.image_dataset_from_directory(
    "Combined_Dataset",
    labels="inferred",
    label_mode="int",
    image_size=(256, 256),
    batch_size=20,
    shuffle=True,
    validation_split=0.2,
    subset="both",
    # A random seed produces a different split on every run; hard-code a
    # value here if you need a reproducible train/validation split.
    seed=random.randint(0, 8000),
)
# Get class (Pokémon) names
class_names = train_ds.class_names
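# Optional sanity check (illustrative, not required for training): preview a
# few raw training images with their labels. This is what matplotlib is
# imported for above.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(images[i]).astype("uint8"))
        plt.title(class_names[int(labels[i])])
        plt.axis("off")
plt.show()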
# Introduce artificial sample diversity
data_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]
def data_augmentation(images):
    for layer in data_augmentation_layers:
        images = layer(images)
    return images
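# Optional (illustrative): run the augmentation layers on one batch and plot
# the results, to confirm the flips/rotations look reasonable before training.
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented = data_augmentation(images)
        plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(augmented[0]).astype("uint8"))
        plt.axis("off")
plt.show()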
# Apply `data_augmentation` to the training images.
train_ds = train_ds.map(
    lambda img, label: (data_augmentation(img), label),
    num_parallel_calls=tf_data.AUTOTUNE,
)
# Prefetching samples in GPU memory helps maximize GPU utilization.
train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)
# MODEL
def simple_xception_network(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    # Entry block
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        units = 1
    else:
        units = num_classes

    x = layers.Dropout(0.25)(x)
    # Specify activation=None so the model returns raw logits
    outputs = layers.Dense(units, activation=None)(x)
    return keras.Model(inputs, outputs)
model = simple_xception_network(input_shape=(256, 256, 3), num_classes=152)
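# Optional: print a layer-by-layer summary to verify the architecture and
# parameter count before training.
model.summary()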
# Train
epochs = 25
callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)
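# Minimal inference sketch. The image path below is hypothetical; point it at
# a real file. The model outputs logits, so apply a softmax to get per-class
# probabilities. Rescaling happens inside the model, so raw [0, 255] pixel
# values are the expected input.
img = keras.utils.load_img("test_pokemon.png", target_size=(256, 256))  # hypothetical path
img_array = keras.utils.img_to_array(img)
img_array = np.expand_dims(img_array, 0)  # build a batch of one
logits = model.predict(img_array)[0]
probs = np.exp(logits - np.max(logits))
probs /= probs.sum()  # numerically stable softmax
best = int(np.argmax(probs))
print(f"Predicted: {class_names[best]} ({100 * probs[best]:.1f}% confidence)")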