Task: plot a confusion matrix and find the images that were misclassified.
You do not need to read or modify the code in this section to successfully complete this assignment.
# Import fastai code.
from fastai.vision.all import *
# Set a seed for reproducibility.
set_seed(0, reproducible=True)
Monkey-patch plot_top_losses because of a bug.
def _plot_top_losses(self, k, largest=True, **kwargs):
    losses,idx = self.top_losses(k, largest)
    if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)
    if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)
    else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
    b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
    x,y,its = self.dl._pre_show_batch(b, max_n=k)
    b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
    x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
    if its is not None:
        plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses, **kwargs)
ClassificationInterpretation.plot_top_losses = _plot_top_losses
path = untar_data(URLs.PETS)/'images'
image_files = get_image_files(path).sorted()
# Cat images have filenames that start with a capital letter.
def is_cat(filename):
    return filename[0].isupper()
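Labels come only from the first character of each filename. The sketch below applies the same rule to a few illustrative filenames in the Pets naming style. The names are examples and are not read from disk.

```python
# Sketch of the filename convention is_cat relies on: cat-breed files start
# with a capital letter, dog-breed files with a lowercase letter.
# The example filenames are illustrative, not loaded from the dataset.

def is_cat(filename):
    return filename[0].isupper()

examples = ["Abyssinian_1.jpg", "Siamese_12.jpg", "basset_hound_3.jpg", "pug_7.jpg"]
for name in examples:
    print(name, "->", "cat" if is_cat(name) else "dog")
```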
FLIP_PROB = 0.0  # Fraction of labels to flip at random. AFTER FINISHING, try 0.25 to experiment with detecting mislabeled images.
correct_labels = [is_cat(path.name) for path in image_files]
corrupted_labels = [
not correct_label if random.random() < FLIP_PROB else correct_label
for correct_label in correct_labels]
Check how many labels are still correct.
sum(
correct_label == corrupted_label
for correct_label, corrupted_label in zip(correct_labels, corrupted_labels)
) / len(correct_labels)
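Each label is flipped independently with probability FLIP_PROB, so the fraction printed above should be exactly 1.0 when FLIP_PROB = 0.0 and close to 1 − FLIP_PROB otherwise. The sketch below checks this with the standard library. The labels are hypothetical, and a local random.Random replaces fastai's set_seed.

```python
import random

def corrupt(labels, flip_prob, rng):
    # Flip each boolean label independently with probability flip_prob,
    # mirroring the list comprehension in the notebook.
    return [not y if rng.random() < flip_prob else y for y in labels]

def fraction_correct(clean, noisy):
    # Same computation as the notebook's "how many labels are still correct" cell.
    return sum(a == b for a, b in zip(clean, noisy)) / len(clean)

clean = [i % 3 == 0 for i in range(10_000)]  # hypothetical labels
rng = random.Random(0)                       # seeded for repeatability
print(fraction_correct(clean, corrupt(clean, 0.0, rng)))   # exactly 1.0
print(fraction_correct(clean, corrupt(clean, 0.25, rng)))  # close to 0.75
```

With flip_prob = 0.0 the comparison rng.random() < 0.0 is never true, because random() returns values in [0, 1). No label can change.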
dataloaders = ImageDataLoaders.from_lists(
path=path, fnames=image_files, labels=corrupted_labels,
valid_pct=0.2,
seed=42,
item_tfms=Resize(224)
)
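With valid_pct=0.2 and seed=42, a reproducible 20% of the images is held out for validation. This is roughly the idea behind fastai's RandomSplitter. The sketch below uses Python's random module instead of torch, so its indices will not match fastai's split. The helper random_split is hypothetical.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=42):
    # Hypothetical helper: shuffle indices with a seeded RNG, then use the
    # first `cut` indices as the validation set and the rest for training.
    idx = list(range(n_items))
    random.Random(seed).shuffle(idx)
    cut = int(valid_pct * n_items)
    return idx[cut:], idx[:cut]

train_idx, valid_idx = random_split(100)
print(len(train_idx), len(valid_idx))
```

Fixing the seed means every run holds out the same images. The validation accuracy is therefore comparable across runs.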
learn = cnn_learner(
dls=dataloaders,
arch=resnet18,
metrics=accuracy
)
learn.fine_tune(epochs=4)
learn.recorder.plot_loss()
We've given you a classifier (the learn object). It makes a few mistakes; can you find them?
The code above provides a way to corrupt some of the labels before training. For this assignment, the corruption is turned off (FLIP_PROB = 0.0). After you finish the assignment, you may find it enlightening to re-enable it and see how a classifier handles mislabeled data.
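A confusion matrix tabulates the classifier's mistakes. Entry [actual][predicted] counts how often images of each actual class received each prediction. Correct predictions sit on the diagonal and mistakes sit off it. fastai's ClassificationInterpretation.plot_confusion_matrix draws such a table, and we assume it uses actual classes on the rows and predicted classes on the columns. The sketch below builds one by hand from hypothetical labels, with 0 = dog (False) and 1 = cat (True).

```python
def confusion_matrix(targets, preds, n_classes):
    # cm[actual][predicted] counts how often each actual class got each prediction.
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(targets, preds):
        cm[t][p] += 1
    return cm

# Hypothetical labels and predictions: 0 = dog (False), 1 = cat (True).
targets = [1, 1, 0, 0, 0, 1, 0]
preds   = [1, 0, 0, 1, 0, 1, 0]
cm = confusion_matrix(targets, preds, 2)
print(cm)  # [[3, 1], [1, 2]]

dogs_called_cats = cm[0][1]  # actual dog, predicted cat
cats_called_dogs = cm[1][0]  # actual cat, predicted dog
```

The two off-diagonal cells are the two kinds of mistakes: dogs predicted as cats and cats predicted as dogs.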
Follow these steps:
DataLoader objects at dataloaders.train and dataloaders.valid; each of them has a .show_batch() method.)
# your code here
# your code here
accuracy(interp.preds, interp.targs)). Check that this number matches the last accuracy figure reported during training above. Then multiply the error rate (1 minus the accuracy) by the number of images in the validation set to get the actual number of misclassified images.
Hints:
- Use WHATEVER.item() to get a plain number instead of a Tensor.
- DataLoader objects (such as dataloaders.valid) have a .n attribute that gives the number of images in them.
# your code here
# your code here
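The arithmetic of this step can be sketched without fastai. The predicted class is the argmax of each row of probabilities, and accuracy is the fraction of predictions that match the targets. The number of misclassified images is the error rate times the number of images. The probabilities below are hypothetical, and accuracy_from_probs is an illustrative helper, not fastai's accuracy.

```python
def accuracy_from_probs(probs, targets):
    # Hypothetical helper: predicted class = index of the largest probability.
    preds = [max(range(len(p)), key=p.__getitem__) for p in probs]
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets), preds

# Hypothetical validation predictions for 5 images; columns are [dog, cat].
probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7], [0.55, 0.45]]
targets = [0, 1, 1, 1, 0]

acc, preds = accuracy_from_probs(probs, targets)
n_valid = len(targets)                # plays the role of dataloaders.valid.n
n_wrong = round((1 - acc) * n_valid)  # error rate times number of images
print(acc, n_wrong)
```

round is used instead of int because (1 - 0.8) * 5 evaluates to slightly less than 1 in floating point, and int would truncate it to 0.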
# ds_idx=0 selects the training set; the default (ds_idx=1) is the validation set.
interp_train = ClassificationInterpretation.from_learner(learn, ds_idx=0)
# your code here
interp.plot_top_losses(12)
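"Top losses" are the examples with the highest per-example loss. Assuming the learner uses the default cross-entropy loss for single-label classification, each image's loss is −log of the probability the model assigned to its label. A confident wrong prediction therefore ranks first. The stdlib sketch below uses hypothetical probabilities, and top_losses here is an illustrative helper, not fastai's method.

```python
import math

def top_losses(probs, targets, k):
    # Per-example cross-entropy: -log(probability assigned to the labeled class).
    losses = [-math.log(p[t]) for p, t in zip(probs, targets)]
    # Indices sorted by loss, largest first; keep the k worst examples.
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    return [(i, losses[i]) for i in order[:k]]

# Hypothetical predictions; columns are [dog, cat].
probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7], [0.05, 0.95]]
targets = [0, 1, 1, 1, 0]
for i, loss in top_losses(probs, targets, 3):
    print(i, round(loss, 3))
```

Image 4 is labeled dog but predicted cat with probability 0.95, so it tops the list. When labels have been corrupted, confident "mistakes" like this are often the mislabeled images. A high-loss image can also be correctly classified but unconfident, as image 3 is here.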
X out of XX images were incorrectly labeled "cat".
Y out of YY images were incorrectly labeled "dog".
your answer here
your answer here