Transfer Learning

My post introducing convolutional neural networks (CNNs) used a dataset with photos of Arctic foxes, polar bears, and walruses to train a CNN to recognize Arctic wildlife. Trained with 300 images – 100 for each of the three classes – the CNN achieved an accuracy of around 60%. That’s not sufficient for most purposes. Imagine you’re a climate scientist tracking polar bears in the wild, and you’re using AI to analyze photos snapped by motion-activated cameras to determine which ones contain polar bears. 60% accuracy won’t get you very far.

One solution is to train the CNN with tens of thousands of photos. A better solution – one that can deliver world-class accuracy with the 300 photos you have and doesn’t require expensive GPUs – is transfer learning. In the hands of software developers and engineers, transfer learning makes CNNs a practical solution for a variety of computer-vision problems, and it requires orders of magnitude less time and computing power than training a CNN from scratch. Let’s take a moment to understand what transfer learning is and how it works – and then put it to work finding polar bears.

Understanding Transfer Learning

CNNs pretrained on the ImageNet dataset can identify Arctic foxes and polar bears, but as my previous post demonstrated, they can’t detect walruses because they weren’t trained with walrus photos. Transfer learning lets you repurpose pretrained CNNs to identify objects they weren’t originally trained to identify. It leverages the intelligence baked into pretrained CNNs, but it redirects that intelligence to solve domain-specific problems.

Recall that a CNN contains two groups of layers: bottleneck layers (also known as feature-extraction layers), which comprise the convolution and pooling layers that extract features from images at various resolutions, and classification layers, which classify the features output by the bottleneck layers as belonging to an Arctic fox, a polar bear, or something else. The convolution layers extract features using matrices called convolution kernels, and the values in those kernels are learned during training. This learning accounts for the bulk of the training time. When sophisticated CNNs are trained with millions of images, the convolution kernels become very adept at extracting features.
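To see these two groups of layers in a real network, you can load a pretrained model and print its summary. Here’s a minimal sketch using Keras’s ResNet50V2 – the same model used later in this post. The 1,000-unit softmax layer at the very end is ImageNet’s classifier; the convolution and pooling layers before it are the bottleneck layers:

from tensorflow.keras.applications import ResNet50V2

# Load ResNet50V2 with its ImageNet weights and list its layers; the final
# Dense layer has 1,000 outputs, one for each ImageNet class
model = ResNet50V2(weights='imagenet')
model.summary()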

The premise behind transfer learning is shown below. You load the bottleneck layers of a pretrained CNN, but you don’t load the classification layers. Instead, you provide your own, which train orders of magnitude more quickly than an entire CNN. Then you pass the training images through the bottleneck layers for feature extraction, and train the classification layers on those features. The pretrained CNN might have been trained to extract features from pictures of apples and oranges, but those same layers are probably pretty good at extracting features from photos of dogs and cats, too. By using the pretrained bottleneck layers to do feature extraction and then using those features to train your own classification layers, you can teach the model that a certain feature extracted from an image might be indicative of a dog rather than an apple.


Transfer learning


Transfer learning is relatively simple to implement with Keras and TensorFlow. Recall that the following statement loads Microsoft’s ResNet50V2 CNN and initializes it with the weights (including kernel values) learned when the network was trained on a subset of the ImageNet dataset:

base_model = ResNet50V2(weights='imagenet')

To load ResNet50V2 (or any other pretrained CNN that Keras supports) without the classification layers, simply pass an include_top=False parameter:

base_model = ResNet50V2(weights='imagenet', include_top=False)

From that point, there are two different ways to implement transfer learning. The first involves appending classification layers to the base model’s bottleneck layers, and setting each base layer’s trainable attribute to False so the convolution kernels won’t be updated when the network is trained:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense

# Freeze the bottleneck layers so their kernels aren't updated during training
for layer in base_model.layers:
    layer.trainable = False

# Append your own classification layers (for Flatten to infer its input size,
# create base_model with input_shape=(224, 224, 3))
model = Sequential()
model.add(base_model)
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(x, y, validation_split=0.2, epochs=10, batch_size=10)

The second technique is to run all the training images through the base model for feature extraction, and then run the features through a separate network containing your classification layers:

# Run the training images through the bottleneck layers once to extract features
features = base_model.predict(x)

# Train a separate model to classify the extracted features
model = Sequential()
model.add(Flatten(input_shape=features.shape[1:]))
model.add(Dense(128, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(features, y, validation_split=0.2, epochs=10, batch_size=10)

Which technique is better? The second is faster because the training images go through the bottleneck layers for feature extraction just once rather than once per epoch. It’s the technique you should use in the absence of a compelling reason to do otherwise. The first technique is slightly slower, but it lends itself to fine-tuning, in which you unfreeze one or more of the bottleneck layers after training is complete and train for a few more epochs with a very low learning rate (see the sketch below). It also makes it easy to perform data augmentation, which I’ll introduce in my next post.
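For the record, here’s a minimal sketch of that fine-tuning step, assuming base_model and model were built with the first technique above. The learning rate and epoch count are illustrative, not prescriptive; the key is recompiling after you change the trainable attribute:

from tensorflow.keras.optimizers import Adam

# Unfreeze the bottleneck layers once initial training is complete
for layer in base_model.layers:
    layer.trainable = True

# Recompile with a very low learning rate so the pretrained kernels are
# only nudged, not retrained, and then train for a few more epochs
model.compile(optimizer=Adam(learning_rate=1e-5),
              loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x, y, validation_split=0.2, epochs=5, batch_size=10)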

Because no training is done in the bottleneck layers when the network is trained, transfer learning is much faster than training a sophisticated CNN. And because the bottleneck layers were trained when the pretrained CNN was trained, they’re already adept at extracting features from images.
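You can quantify this with a quick check. Here’s a minimal sketch, assuming model was assembled with the first technique above; virtually all of the network’s parameters sit in the frozen bottleneck layers, and only the small classification head is actually trained:

# Tally parameters: the frozen bottleneck layers hold nearly all of them,
# while only the classification layers are updated by fit()
trainable = sum(w.shape.num_elements() for w in model.trainable_weights)
frozen = sum(w.shape.num_elements() for w in model.non_trainable_weights)

print(f'Trainable parameters: {trainable:,}')
print(f'Frozen parameters:    {frozen:,}')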

If you use the first technique above to implement transfer learning, you make predictions by preprocessing the images and passing them to the model’s predict method. If you use the second (faster) technique, making predictions is a two-step process. After preprocessing the images, you pass them to the base model’s predict method, and then you pass the output from that method to your model’s predict method:

x = image.img_to_array(x)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x) / 255

features = base_model.predict(x)
predictions = model.predict(features)

And with that, transfer learning is complete. All that remains is to put it into practice.

Use Transfer Learning to Identify Arctic Wildlife

Let’s use transfer learning to solve the same problem that we attempted to solve with a scratch-built CNN in my post introducing CNNs: building a model that determines whether a photo contains an Arctic fox, a polar bear, or a walrus.

Start by downloading the zip file containing Arctic wildlife images if you haven’t downloaded it already. Unpack the zip file and place its contents in the directory where your Jupyter notebooks are hosted. The zip file contains folders named “train,” “test,” and “samples.” Each folder contains subfolders named “arctic_fox,” “polar_bear,” and “walrus.” The training folders contain 100 images each, while the test folders contain 40 images each. Here again are some of the polar-bear training images.



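Before moving on, you can optionally verify that the files landed where the code expects them. This quick sanity check – not part of the workflow proper, just a convenience – counts the images in each folder:

import os

# Optional sanity check: count the images in each class folder
for split in ('train', 'test'):
    for animal in ('arctic_fox', 'polar_bear', 'walrus'):
        path = os.path.join(split, animal)
        print(f'{path}: {len(os.listdir(path))} images')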
As you did in the earlier tutorial, create a Jupyter notebook and paste the following code into the first cell to define helper functions for loading and labeling images and declare Python lists for accumulating images and labels:

import os
import numpy as np
from tensorflow.keras.preprocessing import image
import matplotlib.pyplot as plt
%matplotlib inline

def load_images_from_path(path, label):
    images = []
    labels = []

    for file in os.listdir(path):
        img = image.load_img(os.path.join(path, file), target_size=(224, 224))
        images.append(image.img_to_array(img))
        labels.append(label)

    return images, labels

def show_images(images):
    fig, axes = plt.subplots(1, 8, figsize=(20, 20), subplot_kw={'xticks': [], 'yticks': []})

    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i] / 255)

x_train = []
y_train = []
x_test = []
y_test = []

Use the following statements to load 100 Arctic-fox training images and plot a subset of them:

images, labels = load_images_from_path('train/arctic_fox', 0)
show_images(images)

x_train += images
y_train += labels

Load and label the polar-bear training images:

images, labels = load_images_from_path('train/polar_bear', 1)
show_images(images)

x_train += images
y_train += labels

And then the walrus training images:

images, labels = load_images_from_path('train/walrus', 2)
show_images(images)

x_train += images
y_train += labels

We also need to load the images used to validate the CNN. Start with 40 Arctic-fox test images:

images, labels = load_images_from_path('test/arctic_fox', 0)
show_images(images)

x_test += images
y_test += labels

Then the polar-bear test images:

images, labels = load_images_from_path('test/polar_bear', 1)
show_images(images)

x_test += images
y_test += labels

And finally, the walrus test images:

images, labels = load_images_from_path('test/walrus', 2)
show_images(images)

x_test += images
y_test += labels

Now that the training and test images are loaded and labeled, the next step is to one-hot-encode the labels and preprocess the images. We’ll be using ResNet50V2 as our pretrained CNN, so we’ll use the ResNet version of preprocess_input to preprocess the pixels, and then divide each pixel value by 255:

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.resnet50 import preprocess_input

x_train = preprocess_input(np.array(x_train)) / 255
x_test = preprocess_input(np.array(x_test)) / 255

y_train_encoded = to_categorical(y_train)
y_test_encoded = to_categorical(y_test)

The next step is to load a pretrained CNN, being careful to load the bottleneck layers but not the classification layers, and use it to extract features from the training and test images:

from tensorflow.keras.applications import ResNet50V2

base_model = ResNet50V2(weights='imagenet', include_top=False)

x_train = base_model.predict(x_train)
x_test = base_model.predict(x_test)

Now we’ll train our own neural network to classify features extracted from the training images:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense

model = Sequential()
model.add(Flatten(input_shape=x_train.shape[1:]))
model.add(Dense(1024, activation='relu'))
model.add(Dense(3, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

hist = model.fit(x_train, y_train_encoded, validation_data=(x_test, y_test_encoded), batch_size=10, epochs=10)

How well did the network train? Plot the training accuracy and validation accuracy for each epoch:

acc = hist.history['accuracy']
val_acc = hist.history['val_accuracy']
epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, '-', label='Training Accuracy')
plt.plot(epochs, val_acc, ':', label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

Your results will differ from mine, but I got about 95% accuracy. If you didn’t quite get there, try training the network again:

Training accuracy

Finally, use a confusion matrix to visualize just how well the network is able to distinguish the various classes:

from sklearn.metrics import confusion_matrix
import seaborn as sns
sns.set()

y_predicted = model.predict(x_test)
mat = confusion_matrix(y_test_encoded.argmax(axis=1), y_predicted.argmax(axis=1))
class_labels = ['arctic fox', 'polar bear', 'walrus']

sns.heatmap(mat, square=True, annot=True, fmt='d', cbar=False, cmap='Blues',
            xticklabels=class_labels,
            yticklabels=class_labels)

plt.xlabel('Predicted label')
plt.ylabel('Actual label')

To see transfer learning at work, load one of the Arctic-fox images from the “samples” folder. That folder contains wildlife images that the model was neither trained nor tested with:

x = image.load_img('samples/arctic_fox/arctic_fox_140.jpeg', target_size=(224, 224))
plt.xticks([])
plt.yticks([])
plt.imshow(x)

Now preprocess the image, run it through ResNet50V2’s feature-extraction layers, and run the output through the newly trained classification layers:

x = image.img_to_array(x)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x) / 255

y = base_model.predict(x)
predictions = model.predict(y)

for i, label in enumerate(class_labels):
    print(f'{label}: {predictions[0][i]}')

With a little luck, the network predicted with almost 100% confidence that the image contains an Arctic fox. Perhaps that’s not surprising, since ResNet50V2 was trained with Arctic-fox images as well as polar-bear images. But now let’s load a walrus image, which, you’ll recall, ResNet50V2 was unable to classify:

x = image.load_img('samples/walrus/walrus_143.png', target_size=(224, 224))
plt.xticks([])
plt.yticks([])
plt.imshow(x)

Preprocess the image and make a prediction:

x = image.img_to_array(x)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x) / 255

y = base_model.predict(x)
predictions = model.predict(y)

for i, label in enumerate(class_labels):
    print(f'{label}: {predictions[0][i]}')

ResNet50V2 wasn’t trained to recognize walruses, but our network was. That’s transfer learning in a nutshell. It’s the deep-learning equivalent of having your cake and eating it, too. And it’s the secret sauce that makes CNNs a viable tool for anyone with a laptop and a few hundred training images.

That’s not to say that transfer learning will always get you 95% accuracy with 100 images per class. It won’t. If a dataset lacks the information to achieve that level of separation, neither scratch-built CNNs nor transfer learning will magically make it happen. That’s always true in machine learning and AI. You can’t get water from a rock. And you can’t build an accurate model from data that doesn’t support it.

Get the Code

You can download a Jupyter notebook demonstrating transfer learning from the deep-learning repo that I maintain on GitHub. Feel free to check out the other notebooks in the repo while you’re at it. Also be sure to check back from time to time because I am constantly uploading new samples and updating existing ones.
