MNIST with PyTorch¶

In [ ]:
import numpy as np
np.set_printoptions(precision=3)
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
%matplotlib inline

num_samples = 1000
max_iter = 100

Load and Understand the Data¶

First, let's load the MNIST dataset. Each image is a 28x28 grid of grayscale pixels.

In [ ]:
# Load a subset of MNIST to start
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
x_full = mnist.data[:num_samples] / 255.0  # Scale pixel values to [0,1]
x_full = x_full.astype('float32')
y_full = mnist.target[:num_samples].astype(int)

# Look at the shape of one image
single_image = x_full[0].reshape(28, 28)  # Reshape back to 2D
plt.imshow(single_image, cmap='gray')
plt.title(f"Example digit (label: {y_full[0]})")
print("Shape of one image:", single_image.shape)

# Look at the distribution of labels
plt.figure()
plt.hist(y_full, bins=10, rwidth=0.8)
plt.title("Distribution of Labels")
plt.xlabel("Digit")
plt.ylabel("Count");
Shape of one image: (28, 28)

Understanding Flattening¶

To use these images in our model, we need to "flatten" each 28x28 grid into a single vector of 784 numbers:

In [ ]:
# Demonstrate flattening on one image
flat_image = single_image.reshape(-1)  # -1 means "figure out this dimension"
print("Shape after flattening:", flat_image.shape)

# Visualize how flattening works
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.imshow(single_image, cmap='gray')
ax1.set_title("Original 28x28 image")
ax2.scatter(np.arange(784), flat_image)
ax2.set_title("Same image as 784 numbers")
ax2.set_xlabel("Pixel position")
ax2.set_ylabel("Pixel value")
Shape after flattening: (784,)
Out[ ]:
Text(0, 0.5, 'Pixel value')

Our dataset x_full already comes in this flattened form: each row is one 784-dimensional image:

In [ ]:
print("Full dataset shape:", x_full.shape)
print("Number of training examples:", x_full.shape[0])
print("Number of features per example:", x_full.shape[1])
Full dataset shape: (1000, 784)
Number of training examples: 1000
Number of features per example: 784

Setting up data loaders¶

We'll split the data into training and validation sets, and set up PyTorch DataLoader objects to feed the data to our model.

In [ ]:
# Split the data into training and validation sets
x_train, x_val, y_train, y_val = train_test_split(x_full, y_full, test_size=0.2, random_state=42)

# Create DataLoader objects
train_data = TensorDataset(torch.tensor(x_train), torch.tensor(y_train))
val_data = TensorDataset(torch.tensor(x_val), torch.tensor(y_val))
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
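
As a quick sanity check (not in the original notebook), we can pull one batch from the loader to confirm the shapes and dtypes it yields:

In [ ]:
# Optional: grab one batch and inspect it.
x_batch, y_batch = next(iter(train_loader))
print("x_batch shape:", x_batch.shape)   # expected: torch.Size([32, 784])
print("y_batch shape:", y_batch.shape)   # expected: torch.Size([32])
print("x_batch dtype:", x_batch.dtype)   # float32, which nn.Linear expects
print("y_batch dtype:", y_batch.dtype)   # int64, which F.cross_entropy expects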

Train an MLP to classify MNIST¶

We'll use PyTorch to train a simple multi-layer perceptron (MLP) to classify MNIST digits.

You should be able to piece this together from what we've done before!

In [ ]:
# Define the model
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        self.linear_1 = nn.Linear(in_features=..., out_features=..., bias=True)
        self.linear_2 = nn.Linear(in_features=..., out_features=..., bias=True)

    def forward(self, x):
        z1 = ...
        # Apply the nonlinearity here. Internally, F.relu(z1) is doing
        # z1.max(torch.tensor(0.0)), or equivalently torch.max(z1, torch.tensor(0.0)).
        a1 = ...
        logits = ...
        return logits

# Instantiate the model
input_size = 784
hidden_size = ...
num_classes = ...

model = MLP(input_size, hidden_size, num_classes)
learning_rate = ...

# Train the model
train_losses = []
val_losses = []

for epoch in range(max_iter):
    train_loss = 0.
    for i, (x_batch, y_batch) in enumerate(train_loader):
        # Forward pass
        logits = ...
        # Compute the loss (remember that F.cross_entropy does the softmax internally)
        loss = F.cross_entropy(logits, y_batch)

        train_loss += loss.item()

        # Backward pass
        # Clear old gradients
        model.zero_grad()
        # Compute the gradients
        loss.backward()
        # Update the weights
        for param in model.parameters():
            param.data -= ... * param.grad
    train_loss /= len(train_loader)
    train_losses.append(train_loss)

    # Compute validation loss
    with torch.no_grad(): # Don't compute gradients, since we're not updating the model
        val_loss = 0.
        for x_val_batch, y_val_batch in val_loader:
            logits = ...
            val_loss += ...
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
    print(f"Epoch {epoch}: val_loss = {val_loss:.3f}")

plt.plot(train_losses, label='train')
plt.plot(val_losses, label='val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
Epoch 0: val_loss = 1.654
Epoch 1: val_loss = 0.974
Epoch 2: val_loss = 0.699
Epoch 3: val_loss = 0.609
Epoch 4: val_loss = 0.570
Epoch 5: val_loss = 0.517
Epoch 6: val_loss = 0.498
Epoch 7: val_loss = 0.518
Epoch 8: val_loss = 0.482
Epoch 9: val_loss = 0.494
Epoch 10: val_loss = 0.465
Epoch 11: val_loss = 0.471
Epoch 12: val_loss = 0.488
Epoch 13: val_loss = 0.465
Epoch 14: val_loss = 0.480
Epoch 15: val_loss = 0.478
Epoch 16: val_loss = 0.478
Epoch 17: val_loss = 0.497
Epoch 18: val_loss = 0.494
Epoch 19: val_loss = 0.488
Epoch 20: val_loss = 0.510
Epoch 21: val_loss = 0.485
Epoch 22: val_loss = 0.495
Epoch 23: val_loss = 0.501
Epoch 24: val_loss = 0.517
Epoch 25: val_loss = 0.510
Epoch 26: val_loss = 0.520
Epoch 27: val_loss = 0.528
Epoch 28: val_loss = 0.518
Epoch 29: val_loss = 0.523
Epoch 30: val_loss = 0.532
Epoch 31: val_loss = 0.531
Epoch 32: val_loss = 0.533
Epoch 33: val_loss = 0.538
Epoch 34: val_loss = 0.535
Epoch 35: val_loss = 0.541
Epoch 36: val_loss = 0.554
Epoch 37: val_loss = 0.551
Epoch 38: val_loss = 0.557
Epoch 39: val_loss = 0.560
Epoch 40: val_loss = 0.563
Epoch 41: val_loss = 0.569
Epoch 42: val_loss = 0.565
Epoch 43: val_loss = 0.566
Epoch 44: val_loss = 0.569
Epoch 45: val_loss = 0.568
Epoch 46: val_loss = 0.575
Epoch 47: val_loss = 0.575
Epoch 48: val_loss = 0.584
Epoch 49: val_loss = 0.587
Epoch 50: val_loss = 0.583
Epoch 51: val_loss = 0.590
Epoch 52: val_loss = 0.598
Epoch 53: val_loss = 0.590
Epoch 54: val_loss = 0.590
Epoch 55: val_loss = 0.602
Epoch 56: val_loss = 0.597
Epoch 57: val_loss = 0.598
Epoch 58: val_loss = 0.602
Epoch 59: val_loss = 0.598
Epoch 60: val_loss = 0.604
Epoch 61: val_loss = 0.608
Epoch 62: val_loss = 0.613
Epoch 63: val_loss = 0.611
Epoch 64: val_loss = 0.616
Epoch 65: val_loss = 0.614
Epoch 66: val_loss = 0.621
Epoch 67: val_loss = 0.616
Epoch 68: val_loss = 0.623
Epoch 69: val_loss = 0.622
Epoch 70: val_loss = 0.628
Epoch 71: val_loss = 0.626
Epoch 72: val_loss = 0.628
Epoch 73: val_loss = 0.629
Epoch 74: val_loss = 0.633
Epoch 75: val_loss = 0.632
Epoch 76: val_loss = 0.635
Epoch 77: val_loss = 0.635
Epoch 78: val_loss = 0.637
Epoch 79: val_loss = 0.640
Epoch 80: val_loss = 0.637
Epoch 81: val_loss = 0.638
Epoch 82: val_loss = 0.643
Epoch 83: val_loss = 0.645
Epoch 84: val_loss = 0.649
Epoch 85: val_loss = 0.649
Epoch 86: val_loss = 0.650
Epoch 87: val_loss = 0.650
Epoch 88: val_loss = 0.651
Epoch 89: val_loss = 0.653
Epoch 90: val_loss = 0.653
Epoch 91: val_loss = 0.654
Epoch 92: val_loss = 0.657
Epoch 93: val_loss = 0.659
Epoch 94: val_loss = 0.659
Epoch 95: val_loss = 0.658
Epoch 96: val_loss = 0.660
Epoch 97: val_loss = 0.660
Epoch 98: val_loss = 0.662
Epoch 99: val_loss = 0.665
Out[ ]:
<matplotlib.legend.Legend at 0x1180b0770>
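
Notice that the validation loss bottoms out around epoch 10 and then slowly creeps back up for the rest of training — a classic sign that the model is starting to overfit the small training set. We'll address this with data augmentation below.

For reference, here is one way the blanks above might be filled in. This is only a sketch: MLPSolution is a hypothetical name used to avoid clobbering the MLP class, and hidden_size = 64 is an illustrative choice (num_classes must be 10 for the ten digits; learning_rate = 0.1 matches the value used in the augmentation section below).

In [ ]:
# One possible completion of the exercise above (not the only valid one).
class MLPSolution(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=input_size, out_features=hidden_size, bias=True)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=num_classes, bias=True)

    def forward(self, x):
        z1 = self.linear_1(x)
        a1 = F.relu(z1)              # same as torch.max(z1, torch.tensor(0.0))
        logits = self.linear_2(a1)   # raw scores; F.cross_entropy applies the softmax
        return logits

model_ref = MLPSolution(input_size=784, hidden_size=64, num_classes=10)

# Inside the training loop, the remaining blanks would become:
#   logits = model(x_batch)
#   param.data -= learning_rate * param.grad
# and in the validation loop:
#   logits = model(x_val_batch)
#   val_loss += F.cross_entropy(logits, y_val_batch).item()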

Data Augmentation¶
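
Our MLP treats each of the 784 pixels as an independent input feature, so it has no built-in notion of spatial structure. Let's see what happens to its prediction when we shift an image by just one pixel.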

In [ ]:
# Shifting the image by one pixel shouldn't make a difference, right?
def shift_image(image, dx, dy):
    # torch.roll does a circular shift: pixels that fall off one edge wrap around
    # to the other. With MNIST's mostly empty borders, a small shift is effectively a translation.
    image = image.reshape((28, 28))
    shifted_image = torch.roll(image, shifts=(dy, dx), dims=(0, 1))
    return shifted_image.reshape(-1)

def shift_images(image_batch, dx, dy):
    n_batch = image_batch.shape[0]
    return torch.roll(image_batch.reshape(n_batch, 28, 28), shifts=(dy, dx), dims=(1, 2)).reshape(n_batch, 784)

def classify(model, x):
    with torch.no_grad():
        logits = model(x)
        return logits.argmax(dim=-1)

# Visualize the shifted image
single_image = torch.tensor(x_train)[0]
shifted_image = shift_image(single_image, -1, -1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.imshow(single_image.reshape(28, 28), cmap='gray')
ax1.set_title(f"Original image (classified as {classify(model, single_image)})")
ax2.imshow(shifted_image.reshape(28, 28), cmap='gray')
ax2.set_title(f"Shifted image (classified as {classify(model, shifted_image)})");

Solution: shift the images while we're training!

In [ ]:
model = MLP(input_size, hidden_size, num_classes)
learning_rate = 0.1

# Training loop!
train_losses = []
val_losses = []
augment_range = 4

for epoch in range(max_iter):
    train_loss = 0.
    for i, (x_batch, y_batch) in enumerate(train_loader):
        # Apply random shifts to the images. Pick a shift for this batch.
        shift_x = torch.randint(-augment_range, augment_range + 1, size=(1,))
        shift_y = torch.randint(-augment_range, augment_range + 1, size=(1,))
        x_batch_shifted = shift_images(x_batch, shift_x, shift_y)

        # Forward pass
        logits = ...
        # Compute the loss
        loss = F.cross_entropy(logits, y_batch)

        train_loss += loss.item()

        # Backward pass
        # Clear old gradients
        model.zero_grad()
        # Compute the gradients
        loss.backward()
        # Update the weights
        for param in model.parameters():
            param.data -= ... * param.grad
    train_loss /= len(train_loader)
    train_losses.append(train_loss)

    # Compute validation loss
    with torch.no_grad():
        val_loss = 0.
        for x_val_batch, y_val_batch in val_loader:
            logits = ...
            val_loss += ...
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
    print(f"Epoch {epoch}: val_loss = {val_loss:.3f}")

plt.plot(train_losses, label='train')
plt.plot(val_losses, label='val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
Epoch 0: val_loss = 2.192
Epoch 1: val_loss = 1.963
Epoch 2: val_loss = 1.811
Epoch 3: val_loss = 1.590
Epoch 4: val_loss = 1.520
Epoch 5: val_loss = 1.470
Epoch 6: val_loss = 1.445
Epoch 7: val_loss = 1.314
Epoch 8: val_loss = 1.261
Epoch 9: val_loss = 1.306
Epoch 10: val_loss = 1.243
Epoch 11: val_loss = 1.287
Epoch 12: val_loss = 1.075
Epoch 13: val_loss = 1.218
Epoch 14: val_loss = 1.021
Epoch 15: val_loss = 0.924
Epoch 16: val_loss = 0.883
Epoch 17: val_loss = 0.957
Epoch 18: val_loss = 0.985
Epoch 19: val_loss = 0.963
Epoch 20: val_loss = 0.937
Epoch 21: val_loss = 0.791
Epoch 22: val_loss = 0.733
Epoch 23: val_loss = 0.675
Epoch 24: val_loss = 0.639
Epoch 25: val_loss = 0.705
Epoch 26: val_loss = 0.669
Epoch 27: val_loss = 0.693
Epoch 28: val_loss = 0.703
Epoch 29: val_loss = 0.640
Epoch 30: val_loss = 0.567
Epoch 31: val_loss = 0.571
Epoch 32: val_loss = 0.579
Epoch 33: val_loss = 0.639
Epoch 34: val_loss = 0.550
Epoch 35: val_loss = 0.506
Epoch 36: val_loss = 0.509
Epoch 37: val_loss = 0.528
Epoch 38: val_loss = 0.477
Epoch 39: val_loss = 0.481
Epoch 40: val_loss = 0.528
Epoch 41: val_loss = 0.532
Epoch 42: val_loss = 0.490
Epoch 43: val_loss = 0.488
Epoch 44: val_loss = 0.479
Epoch 45: val_loss = 0.444
Epoch 46: val_loss = 0.446
Epoch 47: val_loss = 0.444
Epoch 48: val_loss = 0.424
Epoch 49: val_loss = 0.399
Epoch 50: val_loss = 0.412
Epoch 51: val_loss = 0.461
Epoch 52: val_loss = 0.423
Epoch 53: val_loss = 0.410
Epoch 54: val_loss = 0.448
Epoch 55: val_loss = 0.374
Epoch 56: val_loss = 0.376
Epoch 57: val_loss = 0.373
Epoch 58: val_loss = 0.383
Epoch 59: val_loss = 0.368
Epoch 60: val_loss = 0.408
Epoch 61: val_loss = 0.363
Epoch 62: val_loss = 0.386
Epoch 63: val_loss = 0.371
Epoch 64: val_loss = 0.357
Epoch 65: val_loss = 0.342
Epoch 66: val_loss = 0.349
Epoch 67: val_loss = 0.358
Epoch 68: val_loss = 0.364
Epoch 69: val_loss = 0.354
Epoch 70: val_loss = 0.325
Epoch 71: val_loss = 0.345
Epoch 72: val_loss = 0.333
Epoch 73: val_loss = 0.351
Epoch 74: val_loss = 0.365
Epoch 75: val_loss = 0.326
Epoch 76: val_loss = 0.322
Epoch 77: val_loss = 0.325
Epoch 78: val_loss = 0.340
Epoch 79: val_loss = 0.324
Epoch 80: val_loss = 0.317
Epoch 81: val_loss = 0.322
Epoch 82: val_loss = 0.313
Epoch 83: val_loss = 0.314
Epoch 84: val_loss = 0.327
Epoch 85: val_loss = 0.323
Epoch 86: val_loss = 0.421
Epoch 87: val_loss = 0.352
Epoch 88: val_loss = 0.353
Epoch 89: val_loss = 0.336
Epoch 90: val_loss = 0.308
Epoch 91: val_loss = 0.306
Epoch 92: val_loss = 0.343
Epoch 93: val_loss = 0.316
Epoch 94: val_loss = 0.325
Epoch 95: val_loss = 0.337
Epoch 96: val_loss = 0.350
Epoch 97: val_loss = 0.306
Epoch 98: val_loss = 0.321
Epoch 99: val_loss = 0.306
Out[ ]:
<matplotlib.legend.Legend at 0x119f58770>
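
How robust is the augmented-trained model to shifts at evaluation time? The next cell repeatedly applies random shifts of up to a given size to the validation images and plots accuracy against that shift range.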
In [ ]:
accuracies = []
augment_ranges = list(range(0, 15))
for augment_range in augment_ranges:
    with torch.no_grad(): # Don't compute gradients, since we're not updating the model
        val_loss = 0.
        val_predictions = []
        val_actuals = []
        for trial in range(50):
            for x_val_batch, y_val_batch in val_loader:
                shift_amount_x = torch.randint(-augment_range, augment_range+1, size=(1,))
                shift_amount_y = torch.randint(-augment_range, augment_range+1, size=(1,))
                x_val_batch_shifted = shift_images(x_val_batch, shift_amount_x, shift_amount_y)
                logits = model(x_val_batch_shifted)
                val_loss += F.cross_entropy(logits, y_val_batch)
                val_predictions.extend(logits.argmax(dim=-1).tolist())
                val_actuals.extend(y_val_batch.tolist())
        accuracies.append(np.mean(np.array(val_predictions) == np.array(val_actuals)))

plt.plot(augment_ranges, accuracies)
plt.xlabel("How many pixels we shift the images by")
plt.ylabel("Validation accuracy")
Out[ ]:
Text(0, 0.5, 'Validation accuracy')