Image Embeddings¶
Today we'll explore what happens inside a neural network classifier, focusing on the embeddings (feature vectors) that the model learns.
Outline:
- Train a flower classifier (same as u01n1, just to get a model)
- Split the model into body (feature extractor) and head (classifier)
- Extract embeddings and explore what they capture
- Compare embedding similarity to raw pixel similarity
- Visualize the embedding space
- See how the classifier uses embeddings as prototypes
Course Objectives Addressed¶
- TM-Embeddings: "I can explain how neural networks represent data as vectors (embeddings) where geometric relationships encode meaning."
- TM-RepresentationLearning: "I can explain how a neural network learns useful internal representations through training."
- TM-DotProduct: "I can compute and reason about dot products of vectors."
- OG-Pretrained: "I can explain the benefits and risks of using pretrained models."
Setup¶
Same setup as u01n1. Run these cells without modification.
import random
import time
import os
import torch
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import display, HTML
from PIL import Image
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
# Report library versions so results can be reproduced with the same stack.
print(f"PyTorch version: {torch.__version__}")
print(f"TorchVision version: {torchvision.__version__}")
# torch.accelerator is the unified accelerator API (covers CUDA, MPS, ...).
num_gpus = torch.accelerator.device_count()
print(f"Accelerators available: {num_gpus}")
if num_gpus == 0:
    # Warn visibly in the notebook; training on CPU works but is very slow.
    display(HTML("No Accelerators available. Training will be slow. <b>Please enable an accelerator.</b>"))
PyTorch version: 2.10.0 TorchVision version: 0.25.0 Accelerators available: 1
# Prefer the current accelerator (CUDA/MPS); fall back to CPU when none exists.
device = torch.accelerator.current_accelerator() or torch.device("cpu")
print(f"Using device: {device}")
Using device: mps
def set_seed(seed):
    """Seed every RNG in play (Python, NumPy, PyTorch) for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def show_image_grid(images, titles=None, rows=None, cols=3, title_fontsize=8, figsize=(10, 10)):
    """Display a grid of PIL images.

    Args:
        images: sequence of PIL images (or anything np.array can turn into uint8).
        titles: optional per-image titles, same length as `images`.
        rows: number of grid rows; computed from `cols` when None.
        cols: number of grid columns.
        title_fontsize: font size for the per-image titles.
        figsize: matplotlib figure size in inches.
    """
    if rows is None:
        rows = (len(images) + (cols - 1)) // cols  # ceiling division
    fig, axs = plt.subplots(rows, cols, figsize=figsize)
    # plt.subplots returns a bare Axes (not an array) when rows == cols == 1,
    # so axs.flatten() would crash; wrap uniformly before flattening.
    axes = np.atleast_1d(axs).flatten()
    for i, ax in enumerate(axes):
        ax.axis('off')  # hide ticks/frames on every cell, including unused ones
        if i >= len(images):
            continue
        ax.imshow(np.array(images[i]).astype('uint8'))
        if titles is not None:
            ax.set_title(titles[i], fontsize=title_fontsize)
# Fraction of the dataset held out for validation.
VALIDATION_FRAC = 0.2

class config:
    # Run settings and hyperparameters, grouped for easy reference.
    seed = 123                 # RNG seed for reproducibility
    learning_rate = 1e-3       # Adam step size
    epochs = 1                 # number of passes over the training data
    batch_size = 16
    image_size = 256           # images are cropped/resized to this square size
    # ImageNet-pretrained EfficientNet-B0 weights; also supplies the matching
    # preprocessing transforms (see data_transforms below).
    pretrained_weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
    freeze_backbone = False    # True would train only the classifier head

set_seed(config.seed)
Load the data¶
import urllib.request
import tarfile

# Download and unpack the TensorFlow "flower_photos" dataset into ./data.
# Both steps are skipped when the files already exist, so re-running is cheap.
url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
download_dir = Path("./data")
download_dir.mkdir(parents=True, exist_ok=True)
archive_path = download_dir / "flower_photos.tgz"
extract_path = download_dir / "flower_photos"

if not archive_path.exists():
    print(f"Downloading {url} to {archive_path}...")
    urllib.request.urlretrieve(url, archive_path)
    print("Download complete.")

if not extract_path.exists():
    print(f"Extracting {archive_path} to {extract_path}...")
    with tarfile.open(archive_path, "r:gz") as tar:
        # filter="data" rejects unsafe tar members (path traversal, device
        # files, absolute paths) and silences the Python 3.12+ deprecation
        # warning about calling extractall without an extraction filter.
        tar.extractall(path=download_dir, filter="data")
    print("Extraction complete.")

data_path = extract_path
print(f"Data path set to: {data_path}")
Data path set to: data/flower_photos
# Preprocessing pipeline supplied by the pretrained weights, so inputs are
# resized/cropped/normalized the same way the backbone was trained.
data_transforms = config.pretrained_weights.transforms(crop_size=config.image_size)

# ImageFolder infers labels from subdirectory names (daisy/, dandelion/, ...).
full_dataset = datasets.ImageFolder(root=data_path, transform=data_transforms)
class_names = full_dataset.classes

# Deterministic train/validation split: the generator is seeded so the same
# images land in the same split on every run.
val_size = int(VALIDATION_FRAC * len(full_dataset))
train_size = len(full_dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(config.seed)
)

# Use half the CPUs for loading. NOTE(review): the 'fork' start method is not
# available on Windows — presumably this notebook targets macOS/Linux; confirm.
num_dataloader_workers = os.cpu_count() // 2 if os.cpu_count() else 0
train_dataloader = DataLoader(
    train_dataset, batch_size=config.batch_size, shuffle=True,
    num_workers=num_dataloader_workers,
    multiprocessing_context='fork' if num_dataloader_workers > 0 else None
)
# Validation is not shuffled so indices stay aligned with val_dataset.indices.
val_dataloader = DataLoader(
    val_dataset, batch_size=config.batch_size, shuffle=False,
    num_workers=num_dataloader_workers,
    multiprocessing_context='fork' if num_dataloader_workers > 0 else None
)
print(f"Classes: {class_names}")
print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")
Classes: ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] Training samples: 2936, Validation samples: 734
We also need a way to load the original (un-transformed) images for display. The transformed images are normalized and hard to visualize, so we'll load images directly from file paths when we need to display them.
def get_val_image(idx):
    """Load the original (un-transformed) image for a validation set index."""
    # Map the validation-subset index back to the underlying dataset sample.
    sample_path, _ = full_dataset.samples[val_dataset.indices[idx]]
    image = Image.open(sample_path).convert('RGB')
    return image.resize((config.image_size, config.image_size))
def get_val_images(indices):
    """Load original images for a list of validation set indices."""
    return list(map(get_val_image, indices))
Train a model¶
Same as u01n1: load a pretrained EfficientNet, replace the classification head, train for 1 epoch.
# Start from an ImageNet-pretrained EfficientNet-B0 backbone.
model = models.efficientnet_b0(weights=config.pretrained_weights)

if config.freeze_backbone:
    # Freeze all layers except the classifier
    # (the replacement Linear created below is new, so it stays trainable).
    for param in model.parameters():
        param.requires_grad = False

# Swap the pretrained head's final Linear layer for one sized to our classes.
num_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(num_features, len(class_names))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
# One epoch of training. Order matters within each iteration: backward()
# accumulates gradients, step() applies them, zero_grad() clears them so the
# next batch starts fresh.
model.train()
for inputs, labels in tqdm(train_dataloader, desc="Training"):
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
# Quick validation check
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_inputs, batch_labels in val_dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        # Predicted class = index of the largest logit per row.
        predictions = model(batch_inputs).argmax(dim=1)
        correct += (predictions == batch_labels).sum().item()
        total += batch_labels.size(0)
print(f"Validation accuracy: {correct / total:.4f}")
Training: 0%| | 0/184 [00:00<?, ?it/s]
Validation accuracy: 0.8624
The Body and the Head¶
A neural network trained for classification can be split into two parts:
- Body (feature extractor / backbone): takes an image and produces a vector of numbers (the embedding)
- Head (classifier): takes that vector and produces class predictions
Let's look at the structure of our EfficientNet model:
# List the model's top-level children (body pieces + classifier head),
# then print the head in full.
for child_name, child_module in model.named_children():
    print(f"{child_name}: {type(child_module).__name__}")
print()
print("Classifier (head):")
print(model.classifier)
features: Sequential avgpool: AdaptiveAvgPool2d classifier: Sequential Classifier (head): Sequential( (0): Dropout(p=0.2, inplace=True) (1): Linear(in_features=1280, out_features=5, bias=True) )
The body is features + avgpool: it processes the image through many layers and produces a single vector.
The head is classifier: a Dropout layer followed by a single Linear layer that maps the feature vector to class scores.
We can build a feature extractor by copying the model and removing the final Linear layer from the classifier. You've already seen this trick:
import copy

# Clone the trained model and delete its final Linear layer in place, leaving
# the body (features + avgpool) plus the Dropout. Deep-copying keeps the
# original `model` intact for later use as a classifier.
feature_extractor = copy.deepcopy(model)
del feature_extractor.classifier[-1]
# eval() disables Dropout so embeddings are deterministic.
feature_extractor = feature_extractor.eval()
print("Feature extractor classifier (no more Linear layer):")
print(feature_extractor.classifier)
Feature extractor classifier (no more Linear layer): Sequential( (0): Dropout(p=0.2, inplace=True) )
Let's try it on a single batch to see what comes out:
# Push one validation batch through the body to inspect the embedding shape.
sample_images, sample_labels = next(iter(val_dataloader))
with torch.no_grad():
    sample_features = feature_extractor(sample_images.to(device))
print("Input shape:", sample_images.shape)
print("Output shape:", sample_features.shape)
Input shape: torch.Size([16, 3, 256, 256]) Output shape: torch.Size([16, 1280])
Task: What do the numbers in the output shape mean? What happened to the 256×256×3 image?
Your answer here.
Extract Embeddings¶
Now let's extract the embedding (feature vector) for every image in the validation set. Each image becomes a single vector of 1,280 numbers — its embedding.
# Run every validation batch through the body and collect results on the CPU.
feature_batches = []
label_batches = []
with torch.no_grad():
    for images, labels in tqdm(val_dataloader, desc="Extracting embeddings"):
        embeddings = feature_extractor(images.to(device))
        feature_batches.append(embeddings.cpu())
        label_batches.append(labels)

# Concatenate the per-batch chunks into one (N, 1280) array and one (N,) array.
val_features = torch.cat(feature_batches).numpy()
val_labels = torch.cat(label_batches).numpy()
print("Embeddings shape:", val_features.shape)
print("Labels shape:", val_labels.shape)
Extracting embeddings: 0%| | 0/46 [00:00<?, ?it/s]
Embeddings shape: (734, 1280) Labels shape: (734,)
Task: The embeddings have shape (734, 1280). What does the 734 represent? What does the 1280 represent?
Your answer here.
Similarity in Embedding Space¶
If the model has learned useful representations, then similar images should have similar embeddings. Let's test this.
First, let's pick three example images: two sunflowers and one tulip.
# Validation-set positions of each class of interest.
sunflower_indices = np.flatnonzero(val_labels == class_names.index('sunflowers')).tolist()
tulip_indices = np.flatnonzero(val_labels == class_names.index('tulips')).tolist()
# Two sunflowers and one tulip as running examples.
example_indices = [sunflower_indices[0], sunflower_indices[2], tulip_indices[0]]
show_image_grid(
    get_val_images(example_indices),
    titles=['sunflower A', 'sunflower B', 'tulip C'])
Now let's compute dot products between their embeddings. Recall that the dot product measures how much two vectors point in the same direction — it's a rough measure of similarity.
# Pull the three example embeddings and compare them pairwise with dot products.
vec_a, vec_b, vec_c = val_features[example_indices]
print(f"sunflower A · sunflower B = {vec_a @ vec_b:.2f}")
print(f"sunflower A · tulip C = {vec_a @ vec_c:.2f}")
print(f"sunflower B · tulip C = {vec_b @ vec_c:.2f}")
sunflower A · sunflower B = 17.95 sunflower A · tulip C = 6.81 sunflower B · tulip C = -15.59
Task: Which pair is most similar according to the dot product? Does this match your intuition?
Your answer here.
Cosine similarity¶
Raw dot products depend on the magnitude of the vectors — a vector with large values will have a high dot product with everything. To control for this, we normalize each vector to unit length before computing the dot product. This is called cosine similarity, and it ranges from -1 (opposite) to +1 (identical direction).
def normalize(vectors):
    """Normalize each row to unit length."""
    # Row-wise Euclidean length, kept 2-D so division broadcasts per row.
    row_lengths = np.sqrt((vectors ** 2).sum(axis=1, keepdims=True))
    return vectors / row_lengths
# Cosine similarity = dot product of unit-length vectors. The resulting matrix
# is symmetric with 1.0 on the diagonal (every vector matches itself exactly).
example_normed = normalize(val_features[example_indices])
cosine_sim = example_normed @ example_normed.T
pd.DataFrame(
    cosine_sim,
    index=['sunflower A', 'sunflower B', 'tulip C'],
    columns=['sunflower A', 'sunflower B', 'tulip C']
).round(3)
| sunflower A | sunflower B | tulip C | |
|---|---|---|---|
| sunflower A | 1.000 | 0.228 | 0.063 |
| sunflower B | 0.228 | 1.000 | -0.095 |
| tulip C | 0.063 | -0.095 | 1.000 |
Finding similar images¶
Let's use cosine similarity to find the images most and least similar to our first sunflower.
# Rank every validation image by cosine similarity to the query sunflower.
query_idx = example_indices[0]
query_vec = val_features[query_idx]
# Compute cosine similarity of the query against all validation embeddings
val_normed = normalize(val_features)
unit_query = query_vec / np.linalg.norm(query_vec)
similarities = val_normed @ unit_query
most_similar = np.argsort(similarities)[::-1]  # descending similarity
print("Most similar (by embedding):")
show_image_grid(get_val_images(most_similar[:9]))
Most similar (by embedding):
Exercise: Show the 9 least similar images. (Hint: just change which end of most_similar you slice from.)
# your code here
Least similar (by embedding):
Embeddings vs Raw Pixels¶
Are embeddings actually better than just comparing raw pixels? Let's find out.
We'll flatten each image into a single long vector of pixel values and compute similarity the same way.
# Collect raw (transformed) images as flat pixel vectors
flat_batches = [images.reshape(images.size(0), -1) for images, _ in val_dataloader]
raw_pixels = torch.cat(flat_batches).numpy()
print("Raw pixel vectors shape:", raw_pixels.shape)
print(f"Each image is now a vector of {raw_pixels.shape[1]} numbers")
Raw pixel vectors shape: (734, 196608) Each image is now a vector of 196608 numbers
# Similarity in pixel space
query_pixels = raw_pixels[query_idx]
# Use cosine similarity here too, exactly as we did for embeddings, so the two
# comparisons are fair: a raw (un-normalized) dot product is dominated by
# vector magnitude — i.e. overall image brightness — rather than content.
pixels_normed = normalize(raw_pixels)
unit_query_pixels = query_pixels / np.linalg.norm(query_pixels)
pixel_similarities = pixels_normed @ unit_query_pixels
most_similar_pixels = np.argsort(pixel_similarities)[::-1]
print("Most similar (by raw pixels):")
show_image_grid(get_val_images(most_similar_pixels[:9]))
Most similar (by raw pixels):
Task: Compare the "most similar" results from embeddings vs raw pixels. Which approach better captures what is in the image (as opposed to surface-level appearance like color or brightness)? Why do you think that is?
Your answer here.
Task: We saw two ways to represent images:
- Raw pixels (256 × 256 × 3 = 196,608 numbers)
- Embeddings (1,280 numbers)
For each representation: What information does it preserve? What information does it discard?
Your answer here.
Visualizing the Embedding Space¶
Our embeddings live in 1,280 dimensions — way too many to visualize directly. But we can project them down to 2D and see if structure emerges.
Important caveat: Any single 2D projection can only show a tiny slice of the full 1,280-dimensional space. Different projections will reveal different structure. To get a less misleading picture, we'll look at several projections side by side.
First, we'll use PCA (Principal Component Analysis) to find the directions of greatest variation. Then we'll also try LDA (Linear Discriminant Analysis), which finds directions that best separate the classes. (The details of PCA and LDA are beyond this course — just think of them as different strategies for flattening 1,280 dimensions into 2.)
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
# Get consistent colors for each class
tab10 = plt.cm.tab10
class_colors = [tab10(i) for i in range(len(class_names))]

# Fit multiple projections on the raw (un-normalized) embeddings,
# since the classifier uses dot products, not cosine similarity.
pca = PCA(n_components=6).fit(val_features)
lda = LinearDiscriminantAnalysis(n_components=2).fit(val_features, val_labels)

# Four different 2D views of the same 1280-dimensional embedding space.
projections = {
    'PCA dims 1–2': pca.transform(val_features)[:, :2],
    'PCA dims 3–4': pca.transform(val_features)[:, 2:4],
    'PCA dims 5–6': pca.transform(val_features)[:, 4:6],
    'LDA (best class separation)': lda.transform(val_features),
}

# One scatter plot per projection, colored by class.
fig, axes = plt.subplots(2, 2, figsize=(15, 8))
for ax, (title, proj) in zip(axes.flat, projections.items()):
    for i, name in enumerate(class_names):
        mask = val_labels == i
        ax.scatter(proj[mask, 0], proj[mask, 1], label=name,
                   color=class_colors[i], alpha=0.5, s=10)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
# One shared legend, placed outside the last subplot.
axes.flat[-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', markerscale=3)
plt.tight_layout()
plt.show()

# Keep references for later use
projected = projections['LDA (best class separation)']
Task: What do you notice about how the classes are arranged? Why do clusters form?
Your answer here.
Prototypes: How the Classifier Uses Embeddings¶
Now let's look at the head — the final Linear layer that maps embeddings to class predictions.
This layer has a weight matrix where each row is a learned vector for one class. We can think of each row as a prototype — the model's internal template for what that class "looks like" in embedding space. The classifier predicts a class by computing how similar (dot product) each image's embedding is to each prototype.
# Pull the head's parameters off the accelerator into NumPy arrays.
head_layer = model.classifier[-1]
classifier_weights = head_layer.weight.detach().cpu().numpy()
classifier_bias = head_layer.bias.detach().cpu().numpy()
print("Classifier weights shape:", classifier_weights.shape)
print("Classifier bias shape:", classifier_bias.shape)
print(f"\nThere are {classifier_weights.shape[0]} prototypes, one per class, each with {classifier_weights.shape[1]} dimensions.")
Classifier weights shape: (5, 1280) Classifier bias shape: (5,) There are 5 prototypes, one per class, each with 1280 dimensions.
Let's extract the prototype for one class and see which images align with it most.
Exercise: Extract the prototype for the "roses" class. (Hint: class_names tells you the index.)
rose_idx = ...  # TODO(student): find the index of 'roses' in class_names
rose_prototype = classifier_weights[...]  # TODO(student): select that row of the weight matrix
print(f"Rose prototype shape: {rose_prototype.shape}")
Rose prototype shape: (1280,)
Now compute the dot product of every validation embedding with the rose prototype. This tells us how "rosy" the model thinks each image is.
# Dot product of every embedding with the rose prototype: each image's raw
# "roses" logit contribution (the head would add classifier_bias on top).
rose_scores = val_features @ rose_prototype
print("Rose scores shape:", rose_scores.shape)

# np.argsort is ascending: least rosy first, most rosy last
# (the following exercise relies on this ordering).
images_by_rosiness = np.argsort(rose_scores)
print("\nMost 'rosy' images:")
show_image_grid(get_val_images(images_by_rosiness[::-1][:9]))
Rose scores shape: (734,) Most 'rosy' images:
Exercise: Show the least rosy images.
# your code here
Least 'rosy' images:
We can also project the prototypes into the same LDA space. Since prototypes are directions (weight vectors), not points, we'll show them as arrows from the origin. Each arrow points in the direction that the classifier associates with that class.
# Overlay the class prototype directions (arrows from the origin) on the
# LDA-projected embeddings.
# Project prototypes using only the LDA rotation, without centering.
# lda.transform() subtracts the data mean first, which is correct for data points
# but wrong for weight vectors (they're directions, not points in data space).
projected_prototypes = classifier_weights @ lda.scalings_[:, :2]
# scale up to be on a similar scale to the projected data points
projected_prototypes *= 5

plt.figure(figsize=(8, 6))
for i, name in enumerate(class_names):
    mask = val_labels == i
    plt.scatter(projected[mask, 0], projected[mask, 1], label=name,
                color=class_colors[i], alpha=0.4, s=10)
    # Show prototype as an arrow from the origin
    plt.annotate('', xy=projected_prototypes[i], xytext=(0, 0),
                 arrowprops=dict(arrowstyle='->', color=class_colors[i], lw=2))
    # Label each arrow tip with its class name.
    plt.text(projected_prototypes[i, 0], projected_prototypes[i, 1], f' {name}',
             color=class_colors[i], fontsize=9, fontweight='bold')
plt.scatter(0, 0, color='black', s=30, zorder=10, marker='o')  # mark the origin
plt.legend(markerscale=3)
plt.title("Embeddings + class prototype directions (arrows)")
plt.show()
Task: How do the prototypes (arrows) relate to the data points in their class? In other classes?
Your answer here.
Wrap-up¶
Task: In your own words, what is an embedding? How is it different from the raw image?
Your answer here.
Task: Suppose you have a small dataset of bird photos (only 50 images) and you want to build a bird species classifier. How could you use the pretrained EfficientNet body (which was trained on ImageNet, not birds specifically) to help? Why would this work better than training from scratch?
Your answer here.