From scratch model part 2: training an image recognition model using fastai

image recognition
CNN
fastai
Author

Mike Gallimore

Published

February 8, 2023

In this presentation, Mike walked us through training an image classifier with the fastai library. It was a group coding session that went from no code to a working model.

Running a model on a GPU in a Jupyter notebook.

Install required libraries

!pip install -Uqq fastai
!pip install duckduckgo_search
Requirement already satisfied: duckduckgo_search in /usr/local/lib/python3.9/dist-packages (2.8.0)
Requirement already satisfied: requests>=2.28.1 in /usr/local/lib/python3.9/dist-packages (from duckduckgo_search) (2.28.1)
Requirement already satisfied: click>=8.1.3 in /usr/local/lib/python3.9/dist-packages (from duckduckgo_search) (8.1.3)
Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.28.1->duckduckgo_search) (2.8)
Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.9/dist-packages (from requests>=2.28.1->duckduckgo_search) (2.1.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests>=2.28.1->duckduckgo_search) (2019.11.28)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.28.1->duckduckgo_search) (1.26.10)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Imports

from fastai.vision.all import *
from fastai.vision.widgets import *
from fastdownload import download_url
import torch
from duckduckgo_search import ddg_images
import warnings
warnings.filterwarnings('ignore')
def search_images(term, max_images = 200):
    return L(ddg_images(term, max_results=max_images)).itemgot('image')
search_images('canoe')[0]
'https://northernwilds.com/wp-content/uploads/2018/06/IMG_1530_opt.jpeg'
ims = search_images('canoe')
dest = 'images/canoe.jpg'
download_url(ims[0], dest)
100.32% [2170880/2163989 00:00<00:00]
Path('images/canoe.jpg')
im = Image.open(dest)
im.to_thumb(128, 128)

Change your code to reflect the categories you’re using.

boat_types = 'canoe', 'kayak', 'sailboat'
path = Path('boats')

Download images from chosen categories.

This cell may take a few minutes to run.

if not path.exists():
    path.mkdir()
    for boat in boat_types:
        dest = path/boat
        dest.mkdir(exist_ok=True)
        results = search_images(boat)
        download_images(dest, urls=results)
!ls {path}
path
canoe  kayak  sailboat
Path('boats')
filenames = get_image_files(path)
filenames
(#554) [Path('boats/sailboat/3b873c96-0d1f-4404-9aae-af25d3805177.jpg'),Path('boats/sailboat/6336c57d-8ad9-4318-836c-fba58b5357be.jpg'),Path('boats/sailboat/d2a981d0-0d3e-4b0e-a077-a5cbda1a41b0.JPG'),Path('boats/sailboat/da7dc60e-7c83-4807-9fc1-a0f5061a5080.jpg'),Path('boats/sailboat/168643af-803b-4783-8c6b-a785157c481e.jpeg'),Path('boats/sailboat/f4055206-199c-42a8-bf74-8ed6c4dc7e08.jpg'),Path('boats/sailboat/d5883260-5325-41cb-a645-016051c37819.jpg'),Path('boats/sailboat/4b44e5d8-75e1-413b-a9b9-bd3f19b63df4.jpg'),Path('boats/sailboat/86d8f1b9-b75b-467f-8476-2d91805a150b.jpg'),Path('boats/sailboat/e07b4c1d-a717-4c10-9059-005016798081.jpeg')...]
failed = verify_images(filenames)
failed
(#0) []
failed.map(Path.unlink)
(#0) []
failed
(#0) []
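
Because each category lives in its own folder, a quick per-folder count shows whether the classes are roughly balanced before training. The sketch below uses only the standard library; `count_images_per_class` and the set of extensions are illustrative choices, not part of fastai.

```python
from collections import Counter
from pathlib import Path

# Extensions treated as images here; an assumption made for this sketch.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}

def count_images_per_class(root):
    """Count image files under root, keyed by the name of their parent folder."""
    counts = Counter()
    for file in Path(root).rglob("*"):
        if file.is_file() and file.suffix.lower() in IMAGE_SUFFIXES:
            counts[file.parent.name] += 1
    return counts

# Example call in the notebook (no output shown, it depends on your download):
# count_images_per_class("boats")
```

A category with far fewer images than the others is worth topping up with another search before training.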

Create a way to load datasets and dataloaders

boats = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=Resize(128)
)
dataloaders = boats.dataloaders(path)
dataloaders.device
device(type='cuda', index=0)

If the device type is ‘cuda’, the DataLoaders will place each batch on the GPU, and the model will be trained there too. If the device type is ‘cpu’, now is a good time to check whether a free GPU instance is available. The data you downloaded earlier won’t be lost, since Paperspace has persistent storage.
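
To check for a GPU before building the DataLoaders, something like the following could be used. It is a minimal sketch: `pick_device` is a hypothetical helper, and it reports 'cpu' when PyTorch is not installed at all.

```python
import importlib.util

def pick_device():
    """Return 'cuda' when PyTorch can see a GPU, otherwise 'cpu'."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed, so no GPU is usable from here
    import torch
    return "cuda" if torch.cuda.is_available() else "cpu"

print(f"Training would run on: {pick_device()}")
```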

# Docs for some modules. Uncomment as required. 
# DataLoader?
# DataLoaders?
# Datasets?
# torch.utils.data.Dataset?
# torch.utils.data.DataLoader?

Training and validation sets

Our data is split into a training set and a validation set. 20% of the data is in the validation set, and 80% is in the training set.
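
The 80/20 split can be illustrated with a few lines of plain Python. This is a sketch of the idea only, not fastai's `RandomSplitter` implementation (which shuffles with PyTorch, so the actual indices will differ); `random_split` is a hypothetical helper. The 554 is the file count reported by `get_image_files` earlier.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=42):
    """Shuffle the indices 0..n_items-1 and cut them into (train, valid) index lists."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    cut = int(valid_pct * n_items)  # size of the validation set, rounded down
    return indices[cut:], indices[:cut]

train_idx, valid_idx = random_split(554)
print(len(train_idx), len(valid_idx))  # 444 110
```

Fixing the seed means the same images land in the validation set on every run, which keeps error rates comparable between experiments.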

Dataset: a collection of tuples, each pairing an image with its category.
DataLoader: a PyTorch iterable that yields batches of items drawn from a Dataset.
DataLoaders: a fastai object that holds a training DataLoader and a validation DataLoader.
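
A rough picture of these three layers in plain Python, with file paths standing in for images. The names `make_dataset` and `batches` are made up for this sketch; fastai and PyTorch do considerably more (transforms, shuffling, collating tensors, moving batches to the device).

```python
from pathlib import Path

def make_dataset(paths):
    """Dataset: a list of (item, label) tuples, labelled by parent folder name."""
    return [(p, Path(p).parent.name) for p in paths]

def batches(dataset, batch_size):
    """DataLoader: yield successive lists of up to batch_size (item, label) tuples."""
    for start in range(0, len(dataset), batch_size):
        yield dataset[start:start + batch_size]

paths = [Path("boats/canoe/a.jpg"), Path("boats/kayak/b.jpg"), Path("boats/sailboat/c.jpg")]
dataset = make_dataset(paths)

# DataLoaders: a container holding one loader for training and one for validation.
loaders = {"train": list(batches(dataset[:2], batch_size=2)),
           "valid": list(batches(dataset[2:], batch_size=2))}
print(loaders["train"][0])
```

Labelling by parent folder name is the same idea as `get_y=parent_label` in the DataBlock above.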

dataloaders.train.show_batch()

Item and batch transformations

boats = boats.new(item_tfms=RandomResizedCrop(224, min_scale=0.5),
                  batch_tfms=aug_transforms())
dataloaders = boats.dataloaders(path)
dataloaders.train.show_batch(max_n=8, nrows=2, unique=True)
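
What `RandomResizedCrop(224, min_scale=0.5)` does to each image can be approximated as follows: pick a random region covering at least half of the image area, with a randomly varied aspect ratio, then resize that region to 224×224. The sketch below only computes the crop box. It is a simplified illustration rather than fastai's code, and `random_crop_box`, its aspect-ratio range, and the retry count are assumptions.

```python
import math
import random

def random_crop_box(width, height, min_scale=0.5, ratio=(3 / 4, 4 / 3),
                    rng=None, attempts=10):
    """Return (left, top, crop_width, crop_height) for a random crop of the image."""
    rng = rng or random.Random()
    for _ in range(attempts):
        # Target area: between min_scale and 100% of the original area.
        area = rng.uniform(min_scale, 1.0) * width * height
        # Aspect ratio sampled uniformly on a log scale.
        aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        crop_w = int(round(math.sqrt(area * aspect)))
        crop_h = int(round(math.sqrt(area / aspect)))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            left = rng.randint(0, width - crop_w)
            top = rng.randint(0, height - crop_h)
            return left, top, crop_w, crop_h
    # No sampled box fitted inside the image: fall back to the whole image.
    return 0, 0, width, height

print(random_crop_box(640, 480, rng=random.Random(0)))
```

Because the box changes every time an image is loaded, the model sees a slightly different view of each photo in every epoch.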

Train a model

learn = vision_learner(dataloaders, resnet18, metrics=error_rate)
learn.fine_tune(4)
learn.recorder.plot_loss()
epoch train_loss valid_loss error_rate time
0 1.512921 0.358238 0.145455 00:13
epoch train_loss valid_loss error_rate time
0 0.609499 0.233965 0.100000 00:13
1 0.459157 0.267047 0.118182 00:13
2 0.373013 0.286165 0.118182 00:13
3 0.338020 0.260390 0.109091 00:13
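
The `error_rate` column is the fraction of validation images that were classified incorrectly. With 554 images and `valid_pct=0.2`, the validation set should hold 110 images, so the final rate of 0.109091 corresponds to 12 mistakes. The mistake counts below are inferred from the reported rates, not taken from the notebook, and assume all 554 files ended up in the DataLoaders.

```python
n_images = 554                 # files found by get_image_files above
n_valid = int(0.2 * n_images)  # validation set size, rounded down -> 110

def error_rate(n_wrong, n_total):
    """Fraction of predictions that were wrong (1 - accuracy)."""
    return n_wrong / n_total

# Inferred mistake counts for the four fine-tuning epochs.
for epoch, n_wrong in enumerate([11, 13, 13, 12]):
    print(epoch, f"{error_rate(n_wrong, n_valid):.6f}")
# 0 0.100000
# 1 0.118182
# 2 0.118182
# 3 0.109091
```

With a validation set this small, a single image moves the error rate by almost one percentage point, so small differences between epochs should not be over-interpreted.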

Confusion matrix

interpreter = ClassificationInterpretation.from_learner(learn)
interpreter.plot_confusion_matrix()

interpreter.plot_top_losses(10)
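
A confusion matrix is a table of counts with the actual classes as rows and the predicted classes as columns; the diagonal holds the correct predictions. A minimal pure-Python sketch with made-up labels (not the notebook's results):

```python
def confusion_matrix(actual, predicted, classes):
    """Return matrix[i][j] = number of items of class i that were predicted as class j."""
    index = {name: i for i, name in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for a, p in zip(actual, predicted):
        matrix[index[a]][index[p]] += 1
    return matrix

classes = ["canoe", "kayak", "sailboat"]
# Made-up labels for illustration only.
actual    = ["canoe", "canoe", "kayak", "kayak", "sailboat", "sailboat"]
predicted = ["canoe", "kayak", "kayak", "kayak", "sailboat", "canoe"]

for name, row in zip(classes, confusion_matrix(actual, predicted, classes)):
    print(f"{name:>8}: {row}")
#    canoe: [1, 1, 0]
#    kayak: [0, 2, 0]
# sailboat: [1, 0, 1]
```

Off-diagonal cells show which pairs of categories the model mixes up, which is a good hint about where to look when cleaning the data.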

cleaner = ImageClassifierCleaner(learn)
cleaner
# Delete bad images
for idx in cleaner.delete(): 
    cleaner.fns[idx].unlink()
    
# Recategorize mislabelled images
for idx, cat in cleaner.change():
    shutil.move(str(cleaner.fns[idx]), path/cat)

Retrain the model with better data

learn.lr_find()
SuggestedLRs(valley=0.00010964782268274575)

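# Rebuild the DataLoaders so that deleted and relabelled files are picked up
dataloaders = boats.dataloaders(path)
learn = vision_learner(dataloaders, resnet18, metrics=error_rate, lr=1e-4)
# fine_tune takes its own base_lr (default 2e-3); pass base_lr=1e-4 there to apply the suggested rate
learn.fine_tune(4)
learn.recorder.plot_loss()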
epoch train_loss valid_loss error_rate time
0 1.487951 0.548378 0.172727 00:12
epoch train_loss valid_loss error_rate time
0 0.503072 0.315105 0.127273 00:13
1 0.459028 0.248447 0.118182 00:13
2 0.389208 0.258396 0.072727 00:12
3 0.332792 0.233255 0.063636 00:13

Use the model to predict on unseen data.

Upload some images to the images folder of your GPU instance. Make sure the images weren’t in the training set.
See how well your model distinguishes between the categories. I’ve uploaded some of my own photos to test.

!ls images
canoe.jpg      testboat2.JPG  testboat4.JPG  testboat6.JPG  testboat8.JPG
testboat1.JPG  testboat3.JPG  testboat5.JPG  testboat7.JPG  testboat9.JPG
prediction, index, probs = learn.predict('images/testboat1.JPG')
Image.open('images/testboat1.JPG')

print(f"The model predicted {prediction} with a confidence of {probs[index]}")
The model predicted sailboat with a confidence of 0.9995493292808533
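
The confidence is the model's softmax probability for the predicted class: `probs` holds one probability per category, and `index` is the position of the largest one. A small sketch with invented raw scores (logits), assuming the categories are stored in alphabetical order:

```python
import math

def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["canoe", "kayak", "sailboat"]
logits = [0.5, -1.0, 8.0]  # invented values, not from the trained model

probs = softmax(logits)
index = max(range(len(probs)), key=probs.__getitem__)  # position of the largest probability
print(f"The model predicted {vocab[index]} with a confidence of {probs[index]:.1%}")
```

A high probability only means the model strongly prefers one of its three categories; a photo of something that is not a boat at all would still be assigned to one of them.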

That’s pretty good! Let’s have a discussion about the results.