Train a simple image classifier¶
Task: Train a flower classifier.
Outline:
- Load the data
  - Download the dataset.
  - Set up the dataloaders (which handle the train-validation split, batching, and resizing).
- Train a model
  - Get a foundation model (an EfficientNet in our case).
  - Fine-tune it.
- Get the model's predictions on an image.
This notebook includes tasks (marked with "Task") and blank code cells (labeled # your code here) where you should fill in your answers.
Setup¶
Run this code. (You do not need to read or modify the code in this section to successfully complete this assignment.)
# Check versions of Keras and Tensorflow
!pip list | egrep 'keras|tensorflow$'
import os
# Results are better with the TensorFlow backend; this is probably a bug in Keras 3 but I haven't tracked it down.
os.environ["KERAS_BACKEND"] = "tensorflow"
from IPython.display import display, HTML
import io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import keras
import keras_cv
import tensorflow as tf
import tensorflow_datasets as tfds
print(f"Keras version: {keras.__version__}, backend: {keras.backend.backend()}")
num_gpus = len(tf.config.list_physical_devices('GPU'))
print(f"GPUs: {num_gpus}")
if num_gpus == 0:
    display(HTML("No GPUs available. Training will be slow. <b>Please enable an accelerator.</b>"))
Configure our experiments¶
You'll be invited to change parameters in this code block later; for now just run it as-is.
class config:
    seed = 123
    learning_rate = 1e-3
    epochs = 1
    batch_size = 16
    image_size = (256, 256)
    model_preset = "efficientnetv2_b0_imagenet"
    use_zero_init = True
# Reproducibility
# See https://keras.io/examples/keras_recipes/reproducibility_recipes/
#
# Set a seed so that the results are the same every time this is run.
keras.utils.set_random_seed(config.seed)
# If using TensorFlow, this will make GPU ops as deterministic as possible,
# but it will affect the overall performance, so be mindful of that.
tf.config.experimental.enable_op_determinism()
Load the data¶
We'll use a dataset of flower images for this example, but you can later switch this out for another dataset as long as you keep the file-and-folder structure.
The details of the code in this section are not important at this time; just run these cells.
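For reference, the loading code below expects one subfolder per class; that is the "file-and-folder structure" to preserve if you later swap in your own dataset. After download and extraction, the flowers dataset looks roughly like this (a sketch; exact file names will differ):
flower_photos/
    daisy/         <- one folder per class, each containing that class's .jpg images
    dandelion/
    roses/
    sunflowers/
    tulips/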
path_to_downloaded_file = keras.utils.get_file(
origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
extract=True,
)
Let's see what just got downloaded.
data_path = Path(path_to_downloaded_file).parent / 'flower_photos'
!ls {data_path}
We'll use a Keras helper function to load the data.
Docs: https://keras.io/api/data_loading/image/#imagedatasetfromdirectory-function
# Define which classes we want to use, in what order.
class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
# Create training and validation datasets
train_dataset, val_dataset = keras.utils.image_dataset_from_directory(
data_path,
validation_split=0.2,
labels='inferred',
class_names=class_names,
label_mode='int',
batch_size=config.batch_size,
image_size=config.image_size,
shuffle=True,
seed=128,
subset='both',
crop_to_aspect_ratio=True
)
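Optional: if you're curious what the dataloaders produced, the snippet below is one way to peek (not required for the assignment; the expected shapes assume the batch size and image size set in config):
# The returned datasets carry the inferred class names and the batched element shapes.
print(train_dataset.class_names)   # should match class_names above, in the same order
print(train_dataset.element_spec)  # e.g. images (None, 256, 256, 3) float32, labels (None,) int32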
Let's show some example images.
[[example_images, example_labels]] = train_dataset.take(1)
fig, axs = plt.subplots(3, 3, figsize=(10, 10))
for i, ax in enumerate(axs.flatten()):
    ax.imshow(np.array(example_images[i]).astype('uint8'))
    label = int(example_labels[i])  # convert the scalar tensor to a plain int for display
    ax.set(title=f"{label} ({class_names[label]})")
    ax.axis('off')
# Alternative approach (doesn't show labels)
# keras_cv.visualization.plot_image_gallery(example_images, value_range=(0, 255))
Train a model¶
We'll unpack this code over the next several weeks. For now, pay attention to the progress bar that will (eventually) show on the last line of the output.
# Create a model using a pretrained backbone
# See https://keras.io/api/keras_cv/models/tasks/image_classifier/ for options
model = keras_cv.models.ImageClassifier.from_preset(
config.model_preset,
num_classes=len(class_names))
# Zero the output-layer weights (they were randomly initialized, which adds noise to gradients when fine-tuning)
# I was reminded of this by https://twitter.com/wightmanr/status/1742570388016758822
if config.use_zero_init:
    output_layer = model.layers[-1]
    output_layer.set_weights([w * 0 for w in output_layer.weights])
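# (Optional sanity check, not part of the assignment: after zeroing, the output layer's
# kernel and bias should both be all zeros, so each max absolute value prints as 0.0.)
if config.use_zero_init:
    print([float(np.abs(w).max()) for w in model.layers[-1].get_weights()])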
# Set up the model for training
model.compile(
loss='sparse_categorical_crossentropy',
optimizer=keras.optimizers.Adam(learning_rate=config.learning_rate),
metrics=['accuracy']
)
model.summary(show_trainable=True)
# Train the model. (Note: this may show some warnings, and it may stop without showing
# progress for up to a minute while it translates the model to run on the GPU.)
history = model.fit(
train_dataset,
validation_data=val_dataset,
epochs=config.epochs
)
Task: Fill in the table below, using the last line from the output above:
- Training set accuracy: ___%
- Validation set accuracy: ___%
- Training loss: ___
- Validation loss: ___
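If the progress bar has scrolled out of view, you can also read the same numbers from the history object that model.fit returned — a minimal sketch:
# history.history maps each metric name to a list with one value per epoch;
# the last entry corresponds to the final epoch shown in the progress bar.
{name: round(values[-1], 4) for name, values in history.history.items()}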
Make some predictions¶
# Load a new image
image_file = keras.utils.get_file(origin='https://upload.wikimedia.org/wikipedia/commons/thumb/c/c2/Beautiful_red_tulip.jpg/382px-Beautiful_red_tulip.jpg')
image = keras.utils.load_img(image_file, target_size=(256, 256), keep_aspect_ratio=True)
display(image)
probabilities = model.predict(np.array(image)[np.newaxis, ...])[0]
pd.DataFrame({'class': class_names, 'prob': probabilities}).sort_values('prob', ascending=False)
Task: Is the second column a valid probability distribution (ignoring round-off errors)? Describe why or why not.
Your answer here.
Task: Write code to show the category with the highest predicted probability. To do this, use the np.argmax function and the class_names list.
Hint: look at the value of the probabilities variable by making a code cell with just probabilities in it.
# your code here
Experimentation¶
Try changing one parameter in the config code block above and rerun the notebook. What effect does this have on the validation accuracy? (Run the same parameters a few times with different values for config.seed to check if the result is robust.)
Tip: an easy way to track your experiments is to copy and paste the final progress-bar line from each run.
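If you'd rather track runs programmatically, one possible pattern is to log each run into a small table (just a sketch; the runs list and the keys recorded here are illustrative, not part of the assignment):
# Hypothetical run log: append one row after each call to model.fit, then compare runs.
runs = []
runs.append({
    'seed': config.seed,
    'learning_rate': config.learning_rate,
    'epochs': config.epochs,
    'val_accuracy': history.history['val_accuracy'][-1],
})
pd.DataFrame(runs)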
Your answer here
Optional extension: try out your own image¶
Finish the code below so you can try out the classifier on your own image.
from ipywidgets import widgets
uploader = widgets.FileUpload()
uploader
if len(uploader.data) > 0:
    image_file = io.BytesIO(uploader.data[0])
    image = keras.utils.load_img(image_file, target_size=(256, 256), keep_aspect_ratio=True)
    display(image)
# TODO: show the model's predictions for this image and report the most likely class.
# your code here