MNIST with PyTorch¶

In [ ]:
import numpy as np
np.set_printoptions(precision=3)
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
%matplotlib inline

num_samples = 1000
max_iter = 100

Load and Understand the Data¶

First, let's load the MNIST dataset. Each image is a 28x28 grid of grayscale pixels.

In [ ]:
# Load a subset of MNIST to start
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
x_full = mnist.data[:num_samples] / 255.0  # Scale pixel values to [0,1]
x_full = x_full.astype('float32')
y_full = mnist.target[:num_samples].astype('int64')  # int64 so the labels become LongTensors later

# Look at the shape of one image
single_image = x_full[0].reshape(28, 28)  # Reshape back to 2D
plt.imshow(single_image, cmap='gray')
plt.title(f"Example digit (label: {y_full[0]})")
print("Shape of one image:", single_image.shape)

# Look at the distribution of labels
plt.figure()
plt.hist(y_full, bins=10, rwidth=0.8)
plt.title("Distribution of Labels")
plt.xlabel("Digit")
plt.ylabel("Count");
Shape of one image: (28, 28)
[Figure: the example digit rendered as a 28x28 grayscale image, and a histogram of the label distribution (digit vs. count).]

Understanding Flattening¶

To use these images in our model, we need to "flatten" each 28x28 grid into a single vector of 784 numbers:

In [ ]:
# Demonstrate flattening on one image
flat_image = single_image.reshape(-1)  # -1 means "figure out this dimension"
print("Shape after flattening:", flat_image.shape)

# Visualize how flattening works
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.imshow(single_image, cmap='gray')
ax1.set_title("Original 28x28 image")
ax2.scatter(np.arange(784), flat_image)
ax2.set_title("Same image as 784 numbers")
ax2.set_xlabel("Pixel position")
ax2.set_ylabel("Pixel value")
Shape after flattening: (784,)
Out[ ]:
Text(0, 0.5, 'Pixel value')
[Figure: the original 28x28 image (left) and the same image plotted as 784 pixel values (right).]

Our dataset x_full is already flattened: each row is one image as a vector of 784 pixel values:

In [ ]:
print("Full dataset shape:", x_full.shape)
print("Number of training examples:", x_full.shape[0])
print("Number of features per example:", x_full.shape[1])
Full dataset shape: (1000, 784)
Number of training examples: 1000
Number of features per example: 784

Setting up data loaders¶

We'll split the data into training and validation sets, and set up PyTorch DataLoader objects to feed the data to our model.

In [ ]:
# Split the data into training and validation sets
x_train, x_val, y_train, y_val = train_test_split(x_full, y_full, test_size=0.2, random_state=42)

# Create DataLoader objects
train_data = TensorDataset(torch.tensor(x_train), torch.tensor(y_train))
val_data = TensorDataset(torch.tensor(x_val), torch.tensor(y_val))
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
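
As a quick sanity check (a sketch, not required for what follows), you can pull one batch from train_loader and confirm the shapes and dtypes the model will see:

In [ ]:
# Inspect a single batch from the training loader
x_batch, y_batch = next(iter(train_loader))
print("Images:", x_batch.shape, x_batch.dtype)   # expected: torch.Size([32, 784]) torch.float32
print("Labels:", y_batch.shape, y_batch.dtype)   # expected: torch.Size([32]) and an integer dtype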

Train an MLP to classify MNIST¶

We'll use PyTorch to train a simple multi-layer perceptron (MLP) to classify MNIST digits.

You should be able to piece this together from what we've done before! (One possible way to fill in the blanks is sketched after the training output below.)

In [ ]:
# Define the model
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        self.linear_1 = nn.Linear(in_features=..., out_features=..., bias=True)
        self.linear_2 = nn.Linear(in_features=..., out_features=..., bias=True)

    def forward(self, x):
        # First linear layer
        z1 = ...
        # ReLU activation
        a1 = ...
        # Internally, ReLU takes an elementwise max with zero, so you could also write
        # a1 = z1.max(torch.tensor(0.0))
        # or
        # a1 = torch.max(z1, torch.tensor(0.0))
        # Second linear layer produces the logits
        logits = ...
        return logits

# Instantiate the model
input_size = 784
hidden_size = ...
num_classes = ...

model = MLP(input_size, hidden_size, num_classes)
learning_rate = ...

# Train the model
train_losses = []
val_losses = []

for epoch in range(max_iter):
    train_loss = 0.
    for i, (x_batch, y_batch) in enumerate(train_loader):
        # Forward pass
        logits = ...
        # Compute the loss (remember that F.cross_entropy does the softmax internally)
        loss = F.cross_entropy(logits, y_batch)

        train_loss += loss.item()

        # Backward pass
        # Clear old gradients
        model.zero_grad()
        # Compute the gradients
        loss.backward()
        # Update the weights
        for param in model.parameters():
            param.data -= ... * param.grad
    train_loss /= len(train_loader)
    train_losses.append(train_loss)

    # Compute validation loss
    with torch.no_grad(): # Don't compute gradients, since we're not updating the model
        val_loss = 0.
        for x_val_batch, y_val_batch in val_loader:
            logits = ...
            val_loss += ...
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
    print(f"Epoch {epoch}: val_loss = {val_loss:.3f}")

plt.plot(train_losses, label='train')
plt.plot(val_losses, label='val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
Epoch 0: val_loss = 1.704
Epoch 1: val_loss = 1.014
Epoch 2: val_loss = 0.741
Epoch 3: val_loss = 0.648
Epoch 4: val_loss = 0.550
Epoch 5: val_loss = 0.494
Epoch 6: val_loss = 0.524
Epoch 7: val_loss = 0.486
Epoch 8: val_loss = 0.493
Epoch 9: val_loss = 0.469
Epoch 10: val_loss = 0.500
Epoch 11: val_loss = 0.488
Epoch 12: val_loss = 0.475
Epoch 13: val_loss = 0.474
Epoch 14: val_loss = 0.465
Epoch 15: val_loss = 0.470
Epoch 16: val_loss = 0.493
Epoch 17: val_loss = 0.495
Epoch 18: val_loss = 0.503
Epoch 19: val_loss = 0.521
Epoch 20: val_loss = 0.493
Epoch 21: val_loss = 0.509
Epoch 22: val_loss = 0.510
Epoch 23: val_loss = 0.530
Epoch 24: val_loss = 0.527
Epoch 25: val_loss = 0.524
Epoch 26: val_loss = 0.527
Epoch 27: val_loss = 0.527
Epoch 28: val_loss = 0.536
Epoch 29: val_loss = 0.533
Epoch 30: val_loss = 0.543
Epoch 31: val_loss = 0.542
Epoch 32: val_loss = 0.547
Epoch 33: val_loss = 0.553
Epoch 34: val_loss = 0.547
Epoch 35: val_loss = 0.560
Epoch 36: val_loss = 0.561
Epoch 37: val_loss = 0.575
Epoch 38: val_loss = 0.571
Epoch 39: val_loss = 0.573
Epoch 40: val_loss = 0.572
Epoch 41: val_loss = 0.573
Epoch 42: val_loss = 0.580
Epoch 43: val_loss = 0.583
Epoch 44: val_loss = 0.583
Epoch 45: val_loss = 0.589
Epoch 46: val_loss = 0.588
Epoch 47: val_loss = 0.593
Epoch 48: val_loss = 0.598
Epoch 49: val_loss = 0.599
Epoch 50: val_loss = 0.605
Epoch 51: val_loss = 0.606
Epoch 52: val_loss = 0.609
Epoch 53: val_loss = 0.610
Epoch 54: val_loss = 0.613
Epoch 55: val_loss = 0.617
Epoch 56: val_loss = 0.615
Epoch 57: val_loss = 0.621
Epoch 58: val_loss = 0.620
Epoch 59: val_loss = 0.628
Epoch 60: val_loss = 0.625
Epoch 61: val_loss = 0.627
Epoch 62: val_loss = 0.628
Epoch 63: val_loss = 0.634
Epoch 64: val_loss = 0.630
Epoch 65: val_loss = 0.635
Epoch 66: val_loss = 0.642
Epoch 67: val_loss = 0.639
Epoch 68: val_loss = 0.642
Epoch 69: val_loss = 0.642
Epoch 70: val_loss = 0.644
Epoch 71: val_loss = 0.647
Epoch 72: val_loss = 0.650
Epoch 73: val_loss = 0.650
Epoch 74: val_loss = 0.651
Epoch 75: val_loss = 0.654
Epoch 76: val_loss = 0.656
Epoch 77: val_loss = 0.657
Epoch 78: val_loss = 0.660
Epoch 79: val_loss = 0.660
Epoch 80: val_loss = 0.660
Epoch 81: val_loss = 0.661
Epoch 82: val_loss = 0.663
Epoch 83: val_loss = 0.666
Epoch 84: val_loss = 0.668
Epoch 85: val_loss = 0.671
Epoch 86: val_loss = 0.673
Epoch 87: val_loss = 0.675
Epoch 88: val_loss = 0.676
Epoch 89: val_loss = 0.673
Epoch 90: val_loss = 0.677
Epoch 91: val_loss = 0.681
Epoch 92: val_loss = 0.680
Epoch 93: val_loss = 0.680
Epoch 94: val_loss = 0.682
Epoch 95: val_loss = 0.681
Epoch 96: val_loss = 0.684
Epoch 97: val_loss = 0.684
Epoch 98: val_loss = 0.687
Epoch 99: val_loss = 0.686
Out[ ]:
<matplotlib.legend.Legend at 0x1420c75f0>
[Figure: training and validation loss curves over 100 epochs.]
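
For reference, here is one possible way to fill in the blanks above. It is a minimal sketch: the hidden size is an assumed value, the learning rate matches the 0.1 used in the augmentation cell below, and F.relu is just one way to write the activation. Remember that F.cross_entropy applies the softmax internally, so the model returns raw logits.

In [ ]:
# One possible completion of the exercise (assumed values, not the only valid ones)
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=input_size, out_features=hidden_size, bias=True)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=num_classes, bias=True)

    def forward(self, x):
        z1 = self.linear_1(x)
        a1 = F.relu(z1)              # elementwise max(z1, 0)
        logits = self.linear_2(a1)
        return logits

hidden_size = 64       # assumed width
num_classes = 10       # ten digits
learning_rate = 0.1    # matches the later augmentation cell

# In the training loop:
#   logits = model(x_batch)
#   param.data -= learning_rate * param.grad
# In the validation loop:
#   logits = model(x_val_batch)
#   val_loss += F.cross_entropy(logits, y_val_batch).item()

Note how, in the printed losses above, the validation loss bottoms out after roughly 10-15 epochs and then creeps back up while the training loss keeps falling; that gap is the overfitting that motivates the data augmentation below.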

Data Augmentation¶

In [ ]:
# Shifting the image by one pixel shouldn't make a difference, right?
def shift_image(image, dx, dy):
    image = image.reshape((28, 28))
    shifted_image = torch.roll(image, shifts=(dy, dx), dims=(0, 1))
    return shifted_image.reshape(-1)

def shift_images(image_batch, dx, dy):
    n_batch = image_batch.shape[0]
    return torch.roll(image_batch.reshape(n_batch, 28, 28), shifts=(dy, dx), dims=(1, 2)).reshape(n_batch, 784)

def classify(model, x):
    with torch.no_grad():
        logits = model(x)
        return logits.argmax(dim=-1)

# Visualize the shifted image
single_image = torch.tensor(x_train)[0]
shifted_image = shift_image(single_image, -1, -1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.imshow(single_image.reshape(28, 28), cmap='gray')
ax1.set_title(f"Original image (classified as {classify(model, single_image)})")
ax2.imshow(shifted_image.reshape(28, 28), cmap='gray')
ax2.set_title(f"Shifted image (classified as {classify(model, shifted_image)})");
[Figure: the original digit and the one-pixel-shifted digit side by side, each titled with the model's predicted class.]
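
One detail worth keeping in mind: torch.roll wraps around, so pixels shifted off one edge reappear on the opposite edge. For one-pixel shifts of MNIST's mostly-black border this is usually harmless; a tiny sketch makes the behavior easy to see:

In [ ]:
# torch.roll wraps values around rather than padding with zeros
t = torch.arange(9).reshape(3, 3)
print(t)
print(torch.roll(t, shifts=(0, 1), dims=(0, 1)))  # each row's last element wraps to the front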

Solution: shift the images while we're training!

In [ ]:
model = MLP(input_size, hidden_size, num_classes)
learning_rate = 0.1

# Training loop!
train_losses = []
val_losses = []

for epoch in range(max_iter):
    train_loss = 0.
    for i, (x_batch, y_batch) in enumerate(train_loader):
        # Apply random shifts to the images. Pick a shift for this batch.
        shift_x = torch.randint(-1, 2, size=(1,))
        shift_y = torch.randint(-1, 2, size=(1,))
        x_batch_shifted = shift_images(x_batch, shift_x, shift_y)

        # Forward pass
        logits = ...
        # Compute the loss
        loss = F.cross_entropy(logits, y_batch)

        train_loss += loss.item()

        # Backward pass
        # Clear old gradients
        model.zero_grad()
        # Compute the gradients
        loss.backward()
        # Update the weights
        for param in model.parameters():
            param.data -= ... * param.grad
    train_loss /= len(train_loader)
    train_losses.append(train_loss)

    # Compute validation loss
    with torch.no_grad():
        val_loss = 0.
        for x_val_batch, y_val_batch in val_loader:
            logits = ...
            val_loss += ...
        val_loss /= len(val_loader)
        val_losses.append(val_loss)
    print(f"Epoch {epoch}: val_loss = {val_loss:.3f}")

plt.plot(train_losses, label='train')
plt.plot(val_losses, label='val')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
Epoch 0: val_loss = 1.800
Epoch 1: val_loss = 1.218
Epoch 2: val_loss = 0.855
Epoch 3: val_loss = 0.701
Epoch 4: val_loss = 0.618
Epoch 5: val_loss = 0.563
Epoch 6: val_loss = 0.526
Epoch 7: val_loss = 0.512
Epoch 8: val_loss = 0.529
Epoch 9: val_loss = 0.499
Epoch 10: val_loss = 0.522
Epoch 11: val_loss = 0.483
Epoch 12: val_loss = 0.456
Epoch 13: val_loss = 0.466
Epoch 14: val_loss = 0.435
Epoch 15: val_loss = 0.442
Epoch 16: val_loss = 0.431
Epoch 17: val_loss = 0.476
Epoch 18: val_loss = 0.420
Epoch 19: val_loss = 0.417
Epoch 20: val_loss = 0.421
Epoch 21: val_loss = 0.391
Epoch 22: val_loss = 0.393
Epoch 23: val_loss = 0.413
Epoch 24: val_loss = 0.379
Epoch 25: val_loss = 0.384
Epoch 26: val_loss = 0.376
Epoch 27: val_loss = 0.399
Epoch 28: val_loss = 0.425
Epoch 29: val_loss = 0.403
Epoch 30: val_loss = 0.380
Epoch 31: val_loss = 0.410
Epoch 32: val_loss = 0.407
Epoch 33: val_loss = 0.383
Epoch 34: val_loss = 0.387
Epoch 35: val_loss = 0.399
Epoch 36: val_loss = 0.430
Epoch 37: val_loss = 0.440
Epoch 38: val_loss = 0.385
Epoch 39: val_loss = 0.365
Epoch 40: val_loss = 0.376
Epoch 41: val_loss = 0.367
Epoch 42: val_loss = 0.407
Epoch 43: val_loss = 0.358
Epoch 44: val_loss = 0.434
Epoch 45: val_loss = 0.383
Epoch 46: val_loss = 0.355
Epoch 47: val_loss = 0.376
Epoch 48: val_loss = 0.371
Epoch 49: val_loss = 0.347
Epoch 50: val_loss = 0.377
Epoch 51: val_loss = 0.356
Epoch 52: val_loss = 0.332
Epoch 53: val_loss = 0.344
Epoch 54: val_loss = 0.340
Epoch 55: val_loss = 0.337
Epoch 56: val_loss = 0.361
Epoch 57: val_loss = 0.344
Epoch 58: val_loss = 0.370
Epoch 59: val_loss = 0.350
Epoch 60: val_loss = 0.364
Epoch 61: val_loss = 0.369
Epoch 62: val_loss = 0.368
Epoch 63: val_loss = 0.371
Epoch 64: val_loss = 0.369
Epoch 65: val_loss = 0.366
Epoch 66: val_loss = 0.354
Epoch 67: val_loss = 0.329
Epoch 68: val_loss = 0.347
Epoch 69: val_loss = 0.357
Epoch 70: val_loss = 0.350
Epoch 71: val_loss = 0.335
Epoch 72: val_loss = 0.336
Epoch 73: val_loss = 0.344
Epoch 74: val_loss = 0.337
Epoch 75: val_loss = 0.367
Epoch 76: val_loss = 0.357
Epoch 77: val_loss = 0.328
Epoch 78: val_loss = 0.352
Epoch 79: val_loss = 0.368
Epoch 80: val_loss = 0.337
Epoch 81: val_loss = 0.342
Epoch 82: val_loss = 0.356
Epoch 83: val_loss = 0.376
Epoch 84: val_loss = 0.342
Epoch 85: val_loss = 0.371
Epoch 86: val_loss = 0.367
Epoch 87: val_loss = 0.359
Epoch 88: val_loss = 0.344
Epoch 89: val_loss = 0.366
Epoch 90: val_loss = 0.348
Epoch 91: val_loss = 0.327
Epoch 92: val_loss = 0.357
Epoch 93: val_loss = 0.350
Epoch 94: val_loss = 0.354
Epoch 95: val_loss = 0.354
Epoch 96: val_loss = 0.365
Epoch 97: val_loss = 0.354
Epoch 98: val_loss = 0.360
Epoch 99: val_loss = 0.343
Out[ ]:
<matplotlib.legend.Legend at 0x1691a9c10>
[Figure: training and validation loss curves for the run with shifted training images.]
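
The blanks here mirror the earlier loop; the main change is that the forward pass should use x_batch_shifted. As an optional final check (a sketch using the classify helper and variables defined above), you can re-classify the shifted example with the retrained model and report overall validation accuracy:

In [ ]:
# Re-check the shifted example and compute validation accuracy (optional sketch)
print("Original:", classify(model, single_image).item(),
      "Shifted:", classify(model, shifted_image).item())

correct, total = 0, 0
for x_val_batch, y_val_batch in val_loader:
    preds = classify(model, x_val_batch)
    correct += (preds == y_val_batch).sum().item()
    total += y_val_batch.shape[0]
print(f"Validation accuracy: {correct / total:.3f}")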
In [ ]: