Classifier Diagnostics

Task: plot a confusion matrix and find images that were misclassified.

Setup

You do not need to read or modify the code in this section to successfully complete this assignment.

In [ ]:
# Import fastai code.
from fastai.vision.all import *

# Set a seed for reproducibility.
set_seed(0, reproducible=True)

Monkey-patch ClassificationInterpretation.plot_top_losses to work around a bug in this version of fastai.

In [ ]:
def _plot_top_losses(self, k, largest=True, **kwargs):
    losses,idx = self.top_losses(k, largest)
    if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)
    if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)
    else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
    b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
    x,y,its = self.dl._pre_show_batch(b, max_n=k)
    b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
    x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
    if its is not None:
        plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses, **kwargs)
ClassificationInterpretation.plot_top_losses = _plot_top_losses

Set up the dataset

In [ ]:
path = untar_data(URLs.PETS)/'images'
In [ ]:
image_files = get_image_files(path).sorted()
In [ ]:
# Cat images have filenames that start with a capital letter.
def is_cat(filename):
    return filename[0].isupper()
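
To see the naming convention in action, here is `is_cat` applied to two sample filenames. These follow the Oxford-IIIT Pet naming pattern; the specific names are illustrative:

```python
# Cat breeds are capitalized, dog breeds are lowercase in this dataset.
def is_cat(filename):
    return filename[0].isupper()

print(is_cat("Abyssinian_1.jpg"))  # True: capitalized, so a cat
print(is_cat("beagle_2.jpg"))      # False: lowercase, so a dog
```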

Optionally corrupt some of the image labels

In [ ]:
FLIP_PROB = 0.0  # <-- AFTER FINISHING the assignment, try setting this to 0.25 to experiment with detecting mislabeled images
correct_labels = [is_cat(path.name) for path in image_files]
corrupted_labels = [
    not correct_label if random.random() < FLIP_PROB else correct_label
    for correct_label in correct_labels]
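
Here is a minimal, self-contained sketch of the flipping logic above, using toy labels and an assumed FLIP_PROB of 0.25, seeded so the outcome is repeatable:

```python
import random

random.seed(0)  # fixed seed for reproducibility
FLIP_PROB = 0.25
labels = [True] * 1000  # toy labels, all "cat"
flipped = [not label if random.random() < FLIP_PROB else label
           for label in labels]
# The surviving fraction should be close to 1 - FLIP_PROB.
fraction_kept = sum(l == f for l, f in zip(labels, flipped)) / len(labels)
print(fraction_kept)
```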

Check how many labels are still correct.

In [ ]:
sum(
    correct_label == corrupted_label
    for correct_label, corrupted_label in zip(correct_labels, corrupted_labels)
) / len(correct_labels)

Train the classifier

In [ ]:
dataloaders = ImageDataLoaders.from_lists(
    path=path, fnames=image_files, labels=corrupted_labels,
    valid_pct=0.2,
    seed=42,
    item_tfms=Resize(224)
)
In [ ]:
learn = cnn_learner(
    dls=dataloaders,
    arch=resnet18,
    metrics=accuracy
)
learn.fine_tune(epochs=4)
learn.recorder.plot_loss()

Task

We've given you a classifier (the learn object). It makes a few mistakes; can you find them?

The code above provides a way to corrupt some of the labels before training. For this assignment, the corruption machinery is turned off (FLIP_PROB = 0.0). You might find it enlightening to re-enable it and see how a classifier handles mislabeled data, but wait until after you finish this assignment.

Follow these steps:

  1. Show one batch from each of the training and validation sets. (Find the DataLoader objects at dataloaders.train and dataloaders.valid; each of them has a .show_batch() method.)
In [ ]:
# your code here
In [ ]:
# your code here
  2. Compute the accuracy and error rate of this classifier on the validation set. First build an interpretation object with interp = ClassificationInterpretation.from_learner(learn); then accuracy(interp.preds, interp.targs) gives the accuracy. Check that this number matches the last accuracy figure reported during training above. Multiply the error rate by the number of images in the validation set to get the actual number of misclassified images.

Hints:

  • You may need .item() on a result (e.g. accuracy(...).item()) to get a plain Python number instead of a Tensor.
  • Each DataLoader (e.g. dataloaders.valid) has a .n attribute giving the number of images it contains.
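
The arithmetic in this step can be sketched with made-up numbers (the real values come from your interp object and dataloaders.valid.n):

```python
# Hypothetical values for illustration only -- not the trained model's.
acc = 0.993        # would come from accuracy(interp.preds, interp.targs).item()
n_valid = 1478     # would come from dataloaders.valid.n
error_rate = 1 - acc
n_misclassified = round(error_rate * n_valid)
print(n_misclassified)
```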
In [ ]:
# your code here
  3. Plot the confusion matrix on the validation set (see chapter 2).
In [ ]:
# your code here
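As a reminder of what the confusion matrix tallies, here is a minimal pure-Python sketch with made-up predictions (the actual plot comes from the interpretation object, not from code like this):

```python
# Toy data: True = cat, False = dog. Not the model's actual predictions.
targets     = [True, True, True, False, False, False, False]
predictions = [True, True, False, False, False, False, True]

# counts[(actual, predicted)] -- the four cells of a 2x2 confusion matrix.
counts = {(a, p): 0 for a in (True, False) for p in (True, False)}
for a, p in zip(targets, predictions):
    counts[(a, p)] += 1

print(counts[(True, True)], counts[(True, False)])    # cats: correct, called dog
print(counts[(False, False)], counts[(False, True)])  # dogs: correct, called cat
```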
  4. Compute the accuracy on the training set. (Since "dataset 0" is the training set and "dataset 1" is the validation set, we can use interp_train = ClassificationInterpretation.from_learner(learn, ds_idx=0).)
In [ ]:
interp_train = ClassificationInterpretation.from_learner(learn, ds_idx=0)
# your code here
  5. Plot the top 12 losses in the validation set.
In [ ]:
interp.plot_top_losses(12)

Analysis

  1. How many dogs in the validation set were misclassified as cats? Vice versa?

X out of XX images were incorrectly labeled "cat".

Y out of YY images were incorrectly labeled "dog".

  2. If we had only looked at the accuracy on the training set, would we have overestimated or underestimated how well the classifier would have performed on the validation set? By how much?

your answer here

  3. Examine the top losses plot.
    1. Explain what the four things above each image mean.
    2. Explain why some correctly classified images appear in the "top losses".
    3. What is the relationship between "loss" and "probability"?

your answer here
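
As a hint for question 3c: for a single example, the cross-entropy loss is -log of the probability the model assigned to the true class, so loss grows sharply as that probability falls. A quick numeric illustration:

```python
import math

# loss = -log(p_true): small when the model is confidently right,
# large when it puts little probability on the true class.
for p_true in (0.99, 0.6, 0.51, 0.01):
    print(f"p_true={p_true:.2f}  loss={-math.log(p_true):.3f}")
```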