MNIST with PyTorch¶
In [ ]:
import numpy as np
np.set_printoptions(precision=3)
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
%matplotlib inline
num_samples = 1000  # how many MNIST examples to load
max_iter = 100      # how many epochs to train for
Load and Understand the Data¶
First, let's load the MNIST dataset. Each image is a 28x28 grid of grayscale pixels.
In [ ]:
# Load a subset of MNIST to start
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
x_full = mnist.data[:num_samples] / 255.0 # Scale pixel values to [0,1]
x_full = x_full.astype('float32')
y_full = mnist.target[:num_samples].astype(int)
# Look at the shape of one image
single_image = x_full[0].reshape(28, 28) # Reshape back to 2D
plt.imshow(single_image, cmap='gray')
plt.title(f"Example digit (label: {y_full[0]})")
print("Shape of one image:", single_image.shape)
# Look at the distribution of labels
plt.figure()
plt.hist(y_full, bins=10, rwidth=0.8)
plt.title("Distribution of Labels")
plt.xlabel("Digit")
plt.ylabel("Count");
Shape of one image: (28, 28)
Understanding Flattening¶
To use these images in our model, we need to "flatten" each 28x28 grid into a single vector of 784 numbers:
In [ ]:
# Demonstrate flattening on one image
flat_image = single_image.reshape(-1) # -1 means "figure out this dimension"
print("Shape after flattening:", flat_image.shape)
# Visualize how flattening works
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.imshow(single_image, cmap='gray')
ax1.set_title("Original 28x28 image")
ax2.scatter(np.arange(784), flat_image)
ax2.set_title("Same image as 784 numbers")
ax2.set_xlabel("Pixel position")
ax2.set_ylabel("Pixel value")
Shape after flattening: (784,)
Out[ ]:
Text(0, 0.5, 'Pixel value')
Our dataset x_full is already flattened: each row is one image stored as a vector of 784 pixel values:
In [ ]:
print("Full dataset shape:", x_full.shape)
print("Number of training examples:", x_full.shape[0])
print("Number of features per example:", x_full.shape[1])
Full dataset shape: (1000, 784)
Number of training examples: 1000
Number of features per example: 784
Setting up data loaders¶
We'll split the data into training and validation sets, and set up PyTorch DataLoader objects to feed the data to our model.
In [ ]:
# Split the data into training and validation sets
x_train, x_val, y_train, y_val = train_test_split(x_full, y_full, test_size=0.2, random_state=42)
# Create DataLoader objects
train_data = TensorDataset(torch.tensor(x_train), torch.tensor(y_train))
val_data = TensorDataset(torch.tensor(x_val), torch.tensor(y_val))
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
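As a quick sanity check (a minimal sketch using the train_loader defined above), you can pull one batch from the loader and confirm the shapes look right:
In [ ]:
# Peek at one batch to confirm shapes
x_batch, y_batch = next(iter(train_loader))
print("x_batch shape:", x_batch.shape)  # expect torch.Size([32, 784])
print("y_batch shape:", y_batch.shape)  # expect torch.Size([32])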
Train an MLP to classify MNIST¶
We'll use PyTorch to train a simple multi-layer perceptron (MLP) to classify MNIST digits.
You should be able to piece this together from what we've done before!
In [ ]:
# Define the model
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=..., out_features=..., bias=True)
        self.linear_2 = nn.Linear(in_features=..., out_features=..., bias=True)

    def forward(self, x):
        z1 = ...
        a1 = ...
        # Internally, the ReLU activation is just an elementwise max with zero,
        # so you could also write
        a1 = z1.max(torch.tensor(0.0))
        # or
        a1 = torch.max(z1, torch.tensor(0.0))
        logits = ...
        return logits

# Instantiate the model
input_size = 784
hidden_size = ...
num_classes = ...
model = MLP(input_size, hidden_size, num_classes)
learning_rate = ...

# Train the model
train_losses = []
val_losses = []
for epoch in range(max_iter):
    train_loss = 0.
    for i, (x_batch, y_batch) in enumerate(train_loader):
        # Forward pass
        logits = ...
        # Compute the loss (remember that F.cross_entropy does the softmax internally)
        loss = F.cross_entropy(logits, y_batch)
        train_loss += loss.item()
        # Backward pass
        # Clear old gradients
        model.zero_grad()
        # Compute the gradients
        loss.backward()
        # Update the weights
        for param in model.parameters():
            param.data -= ... * param.grad
    train_loss /= len(train_loader)
    train_losses.append(train_loss)
    # Compute validation loss
    with torch.no_grad():  # Don't compute gradients, since we're not updating the model
        val_loss = 0.
        for x_val_batch, y_val_batch in val_loader:
            logits = ...
            val_loss += ...
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
    print(f"Epoch {epoch}: val_loss = {val_loss:.3f}")

plt.plot(train_losses, label='train')
plt.plot(val_losses, label='val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
Epoch 0: val_loss = 1.654
Epoch 1: val_loss = 0.974
Epoch 2: val_loss = 0.699
...
Epoch 10: val_loss = 0.465
...
Epoch 98: val_loss = 0.662
Epoch 99: val_loss = 0.665
Out[ ]:
<matplotlib.legend.Legend at 0x1180b0770>
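If you want to check your work, here is one possible way to fill in the blanks. The class name MLPExample and the hyperparameter values below are our own illustrative choices, not the only correct answers:
In [ ]:
# One possible completion of the exercise above (not the only one!)
class MLPExample(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=input_size, out_features=hidden_size, bias=True)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=num_classes, bias=True)

    def forward(self, x):
        z1 = self.linear_1(x)
        a1 = F.relu(z1)  # same result as torch.max(z1, torch.tensor(0.0))
        logits = self.linear_2(a1)
        return logits

# Example hyperparameters (arbitrary but reasonable):
#   hidden_size = 128, num_classes = 10, learning_rate = 0.1
# Inside the training loop, the blanks become:
#   logits = model(x_batch)
#   param.data -= learning_rate * param.grad
# and in the validation loop:
#   logits = model(x_val_batch)
#   val_loss += F.cross_entropy(logits, y_val_batch).item()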
Data Augmentation¶
Shifting a digit by a pixel or two doesn't change what digit it is, but it completely changes the 784-dimensional vector the model sees. Let's check whether our model's predictions survive a small shift.
In [ ]:
# Shifting the image by one pixel shouldn't make a difference, right?
def shift_image(image, dx, dy):
    image = image.reshape((28, 28))
    shifted_image = torch.roll(image, shifts=(dy, dx), dims=(0, 1))
    return shifted_image.reshape(-1)

def shift_images(image_batch, dx, dy):
    # Same thing, but for a whole batch of flattened images at once
    n_batch = image_batch.shape[0]
    return torch.roll(image_batch.reshape(n_batch, 28, 28), shifts=(dy, dx), dims=(1, 2)).reshape(n_batch, 784)

def classify(model, x):
    with torch.no_grad():
        logits = model(x)
    return logits.argmax(dim=-1)

# Visualize the shifted image
single_image = torch.tensor(x_train[0])
shifted_image = shift_image(single_image, -1, -1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.imshow(single_image.reshape(28, 28), cmap='gray')
ax1.set_title(f"Original image (classified as {classify(model, single_image).item()})")
ax2.imshow(shifted_image.reshape(28, 28), cmap='gray')
ax2.set_title(f"Shifted image (classified as {classify(model, shifted_image).item()})");
Solution: shift the images while we're training!
In [ ]:
model = MLP(input_size, hidden_size, num_classes)
learning_rate = 0.1

# Training loop!
train_losses = []
val_losses = []
augment_range = 4
for epoch in range(max_iter):
    train_loss = 0.
    for i, (x_batch, y_batch) in enumerate(train_loader):
        # Apply random shifts to the images. Pick one shift for this whole batch.
        shift_x = torch.randint(-augment_range, augment_range + 1, size=(1,)).item()
        shift_y = torch.randint(-augment_range, augment_range + 1, size=(1,)).item()
        x_batch_shifted = shift_images(x_batch, shift_x, shift_y)
        # Forward pass (on the shifted batch!)
        logits = ...
        # Compute the loss
        loss = F.cross_entropy(logits, y_batch)
        train_loss += loss.item()
        # Backward pass
        # Clear old gradients
        model.zero_grad()
        # Compute the gradients
        loss.backward()
        # Update the weights
        for param in model.parameters():
            param.data -= ... * param.grad
    train_loss /= len(train_loader)
    train_losses.append(train_loss)
    # Compute validation loss (on unshifted images)
    with torch.no_grad():
        val_loss = 0.
        for x_val_batch, y_val_batch in val_loader:
            logits = ...
            val_loss += ...
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
    print(f"Epoch {epoch}: val_loss = {val_loss:.3f}")

plt.plot(train_losses, label='train')
plt.plot(val_losses, label='val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
Epoch 0: val_loss = 2.192
Epoch 1: val_loss = 1.963
Epoch 2: val_loss = 1.811
...
Epoch 50: val_loss = 0.412
...
Epoch 98: val_loss = 0.321
Epoch 99: val_loss = 0.306
Out[ ]:
<matplotlib.legend.Legend at 0x119f58770>
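As a quick spot check, we can re-classify the shifted digit from before with this augmentation-trained model (this assumes single_image and shifted_image are still defined from the earlier cell):
In [ ]:
# The augmented model should be more likely to agree on the two versions
print("Original image classified as:", classify(model, single_image).item())
print("Shifted image classified as: ", classify(model, shifted_image).item())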
How robust is the model to shifts of different sizes? Let's measure validation accuracy as a function of how far we shift the validation images:
In [ ]:
accuracies = []
augment_ranges = list(range(0, 15))
for augment_range in augment_ranges:
    with torch.no_grad():  # Don't compute gradients, since we're not updating the model
        val_predictions = []
        val_actuals = []
        for trial in range(50):
            for x_val_batch, y_val_batch in val_loader:
                shift_amount_x = torch.randint(-augment_range, augment_range + 1, size=(1,)).item()
                shift_amount_y = torch.randint(-augment_range, augment_range + 1, size=(1,)).item()
                x_val_batch_shifted = shift_images(x_val_batch, shift_amount_x, shift_amount_y)
                logits = model(x_val_batch_shifted)
                val_predictions.extend(logits.argmax(dim=-1).tolist())
                val_actuals.extend(y_val_batch.tolist())
        accuracies.append(np.mean(np.array(val_predictions) == np.array(val_actuals)))

plt.plot(augment_ranges, accuracies)
plt.xlabel("How many pixels we shift the images by")
plt.ylabel("Validation accuracy")
Out[ ]:
Text(0, 0.5, 'Validation accuracy')