Data Augmentation

My previous post demonstrated how to use transfer learning to build a model that, with just 300 training images, can classify photos of three different types of Arctic wildlife with 95% accuracy. One of the benefits of transfer learning is that it can do a lot with relatively few images. This feature, however, can also be a bug. With just 100 or so samples of each class, there isn’t a lot of diversity among images. A model might be able to recognize a polar bear if the bear’s head is perfectly aligned in the center of the photo. But if the training images don’t include photos with the bear’s head aligned differently or tilted at different angles, the model might have difficulty classifying such photos.

One solution is data augmentation. Rather than scare up more training images, you can rotate, translate, and scale the images you have. It doesn’t always increase accuracy, but it frequently does. Keras makes it easy to randomly transform training images provided to a network. Images are transformed differently in each epoch, so if you train for 10 epochs, the network sees 10 different variations of each training image. This can increase a model’s ability to generalize with little to no impact on training time. The figure below shows the effect of applying random transforms to a hot-dog image. You can see why presenting the same image to a model in different ways might make the model more adept at recognizing hot dogs, regardless of how the hot dog is framed.

Figure: Data augmentation (random transforms applied to a hot-dog image)
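
To make the idea concrete, here is a minimal NumPy sketch of two such transforms, a random horizontal flip and a random shift. The helper name random_flip_and_shift is made up for illustration, and the shift wraps pixels around the edges with np.roll, which is a simplification rather than what Keras does:

```python
import numpy as np

def random_flip_and_shift(image, rng, max_shift=0.2):
    """Return a randomly flipped and translated copy of an H x W x C image."""
    out = image.copy()
    # Flip left-to-right half of the time
    if rng.random() < 0.5:
        out = out[:, ::-1, :]
    # Shift by up to max_shift of the height and width; pixels wrap around the edges
    height, width = out.shape[:2]
    max_dy, max_dx = int(max_shift * height), int(max_shift * width)
    dy = int(rng.integers(-max_dy, max_dy + 1))
    dx = int(rng.integers(-max_dx, max_dx + 1))
    return np.roll(out, shift=(dy, dx), axis=(0, 1))

rng = np.random.default_rng(0)
image = np.arange(10 * 10 * 3, dtype=np.float32).reshape(10, 10, 3)

# One freshly transformed copy per "epoch"
variants = [random_flip_and_shift(image, rng) for _ in range(10)]
```

Each call draws new random values, which is why a network trained for 10 epochs can see 10 different versions of the same picture.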

Keras has built-in support for data augmentation with images. Let’s look at a couple of ways to put image augmentation to work, and then apply it to the Arctic-wildlife model presented in the previous post.

Image Augmentation with ImageDataGenerator

One way to leverage image augmentation when training a model is to use Keras’s ImageDataGenerator class. ImageDataGenerator generates batches of training images on the fly, either from images you’ve loaded (for example, with Keras’s load_img function) or from a specified location in the file system. The latter is especially useful when training CNNs with millions of images because it loads images into memory in batches rather than all at once. Regardless of where the images come from, however, ImageDataGenerator is happy to apply transforms as it serves them up.
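
The batching idea itself is simple. The sketch below is not how ImageDataGenerator is implemented; it just shows a plain Python generator handing out a list of (hypothetical) file names one batch at a time, so only one batch needs to be in memory at once:

```python
def batches(paths, batch_size):
    """Yield successive lists of at most batch_size paths."""
    for start in range(0, len(paths), batch_size):
        yield paths[start:start + batch_size]

# Hypothetical file names; a real loader would read and decode each file in the batch
image_paths = [f'train/polar_bear/{i:03d}.jpg' for i in range(25)]

batch_sizes = [len(batch) for batch in batches(image_paths, batch_size=10)]
print(batch_sizes)  # [10, 10, 5]
```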

Here’s a simple example that you can try yourself. Use the following code to load an image from your file system, wrap an ImageDataGenerator around it, and generate 24 versions of the image. Be sure to replace polar_bear.png on line 8 with the path to the image:

<code data-highlighted="yes" class="hljs language-python"><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">from</span> keras.preprocessing <span class="hljs-keyword">import</span> image
<span class="hljs-keyword">from</span> tensorflow.keras.preprocessing.image <span class="hljs-keyword">import</span> ImageDataGenerator
<span class="hljs-keyword">import</span> matplotlib.pyplot <span class="hljs-keyword">as</span> plt
%matplotlib inline
 
<span class="hljs-comment"># Load an image</span>
x = image.load_img(<span class="hljs-string">'polar_bear.png'</span>)
x = image.img_to_array(x)
x = np.expand_dims(x, axis=<span class="hljs-number">0</span>)
 
<span class="hljs-comment"># Wrap an ImageDataGenerator around it</span>
idg = ImageDataGenerator(rescale=<span class="hljs-number">1.</span>/<span class="hljs-number">255</span>,
                         horizontal_flip=<span class="hljs-literal">True</span>,
                         rotation_range=<span class="hljs-number">30</span>,
                         width_shift_range=<span class="hljs-number">0.2</span>,
                         height_shift_range=<span class="hljs-number">0.2</span>,
                         zoom_range=<span class="hljs-number">0.2</span>)
idg.fit(x)
 
<span class="hljs-comment"># Generate 24 versions of the image</span>
generator = idg.flow(x, [<span class="hljs-number">0</span>], batch_size=<span class="hljs-number">1</span>, seed=<span class="hljs-number">0</span>)
fig, axes = plt.subplots(<span class="hljs-number">3</span>, <span class="hljs-number">8</span>, figsize=(<span class="hljs-number">16</span>, <span class="hljs-number">6</span>), subplot_kw={<span class="hljs-string">'xticks'</span>: [], <span class="hljs-string">'yticks'</span>: []})
 
<span class="hljs-keyword">for</span> i, ax <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(axes.flat):
    img, label = <span class="hljs-built_in">next</span>(generator)
    ax.imshow(img[<span class="hljs-number">0</span>])
</code>

Here’s the result:


Figure: Polar bears (24 randomly transformed versions of the image)


The parameters passed to ImageDataGenerator tell it how to transform the image each time it’s fetched:

  • rescale=1./255 divides each pixel value by 255
  • horizontal_flip=True randomly flips the image horizontally (around a vertical axis)
  • rotation_range=30 randomly rotates the image by -30 to 30 degrees
  • width_shift_range=0.2 and height_shift_range=0.2 randomly translate the image horizontally and vertically by up to 20% of its width and height
  • zoom_range=0.2 randomly zooms the image in or out by up to 20%
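
One way to picture these ranges is as limits on random draws. The following sketch assumes each value is drawn uniformly from its range; sample_transform is a hypothetical helper written for this post, not part of Keras:

```python
import random

def sample_transform(rng, rotation_range=30, shift_range=0.2, zoom_range=0.2):
    """Draw one set of random transform parameters (hypothetical helper)."""
    return {
        'flip': rng.random() < 0.5,
        'rotation_deg': rng.uniform(-rotation_range, rotation_range),
        'width_shift': rng.uniform(-shift_range, shift_range),
        'height_shift': rng.uniform(-shift_range, shift_range),
        'zoom': rng.uniform(1 - zoom_range, 1 + zoom_range),
    }

rng = random.Random(0)
samples = [sample_transform(rng) for _ in range(24)]  # one draw per generated image
print(samples[0])
```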

There are other parameters that you can use, such as vertical_flip, shear_range, and brightness_range, but you get the picture. The flow method generates images from the images you pass to it. (Calling fit beforehand matters only for options that depend on dataset statistics, such as featurewise_center.) The related flow_from_directory method loads images from the file system and optionally labels them based on the subdirectories they’re in.

The generator returned by flow can be passed directly to a model’s fit method to provide randomly transformed images to the model as it is trained. Assume that x_train and y_train hold a collection of training images and labels. The following code wraps an ImageDataGenerator around them and uses them to train a model:

<code data-highlighted="yes" class="hljs language-makefile">idg = ImageDataGenerator(rescale=1./255,
                         horizontal_flip=True,
                         rotation_range=30,
                         width_shift_range=0.2,
                         height_shift_range=0.2,
                         zoom_range=0.2)
 
idg.fit(x_train)
image_batch_size = 10
generator = idg.flow(x_train, y_train, batch_size=image_batch_size, seed=0)
 
# Don't pass batch_size to fit here; the generator already yields batches
model.fit(generator,
          steps_per_epoch=len(x_train) // image_batch_size,
          validation_data=(x_test, y_test),
          epochs=10)
</code>

The steps_per_epoch parameter is key because an ImageDataGenerator can provide an infinite number of versions of each image. In this example, the batch_size parameter passed to flow tells the generator to create 10 images in each batch (each call to next). Dividing the number of images by the image batch size to calculate steps_per_epoch ensures that in each training epoch, the model is provided with one transformed version of each image in the dataset.
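
The arithmetic is worth a quick check. This sketch uses 300 images as a stand-in for len(x_train); the 305-image case is a made-up example of what happens when the batch size doesn’t divide evenly:

```python
import math

num_images = 300         # stand-in for len(x_train)
image_batch_size = 10

steps_per_epoch = num_images // image_batch_size
print(steps_per_epoch)   # 30

# With 305 images and batches of 20, floor division leaves 5 images out of each epoch
floor_steps = 305 // 20            # 15 steps cover 300 images
ceil_steps = math.ceil(305 / 20)   # 16 steps would be needed to cover all 305
print(floor_steps, ceil_steps)     # 15 16
```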

Earlier versions of Keras didn’t allow a generator to be passed to a model’s fit method. Instead, they provided a separate method named fit_generator. That method is deprecated and should no longer be used. It will be removed in a future release.

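Observe that the call to fit includes a validation_data parameter identifying a separate set of images and labels for validating the network during training. You generally don’t want to augment validation images, so you should avoid using validation_split when passing a generator to fit. Because the generator’s rescale option applies only to the images it serves, the validation images in x_test must already be scaled to the same 0-to-1 range.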

Image Augmentation with Augmentation Layers

You can use ImageDataGenerator to provide transformed images to a model, but recent versions of Keras provide an alternative in the form of image-preprocessing layers and image-augmentation layers. Rather than transform training images separately, you can integrate the transforms into the model. Here’s an example:

<code data-highlighted="yes" class="hljs language-csharp"><span class="hljs-keyword">from</span> keras.models import Sequential
<span class="hljs-keyword">from</span> keras.layers import Conv2D, MaxPooling2D
<span class="hljs-keyword">from</span> keras.layers import Rescaling, RandomFlip, RandomRotation, RandomTranslation, RandomZoom
<span class="hljs-keyword">from</span> keras.layers import Flatten, Dense
 
model = Sequential()
model.<span class="hljs-keyword">add</span>(Rescaling(<span class="hljs-number">1.</span>/<span class="hljs-number">255</span>))
model.<span class="hljs-keyword">add</span>(RandomFlip(mode=<span class="hljs-string">'horizontal'</span>))
model.<span class="hljs-keyword">add</span>(RandomTranslation(<span class="hljs-number">0.2</span>, <span class="hljs-number">0.2</span>))
model.<span class="hljs-keyword">add</span>(RandomRotation(<span class="hljs-number">0.2</span>))
model.<span class="hljs-keyword">add</span>(RandomZoom(<span class="hljs-number">0.2</span>))
model.<span class="hljs-keyword">add</span>(Conv2D(<span class="hljs-number">32</span>, (<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), activation=<span class="hljs-string">'relu'</span>, input_shape=(<span class="hljs-number">224</span>, <span class="hljs-number">224</span>, <span class="hljs-number">3</span>)))
model.<span class="hljs-keyword">add</span>(MaxPooling2D(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>))
model.<span class="hljs-keyword">add</span>(Conv2D(<span class="hljs-number">128</span>, (<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), activation=<span class="hljs-string">'relu'</span>))
model.<span class="hljs-keyword">add</span>(MaxPooling2D(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>))
model.<span class="hljs-keyword">add</span>(Flatten())
model.<span class="hljs-keyword">add</span>(Dense(<span class="hljs-number">128</span>, activation=<span class="hljs-string">'relu'</span>))
model.<span class="hljs-keyword">add</span>(Dense(<span class="hljs-number">3</span>, activation=<span class="hljs-string">'softmax'</span>))
</code>

Each image used to train the CNN has its pixel values divided by 255 and is then randomly flipped, translated, rotated, and scaled. Significantly, the RandomFlip, RandomTranslation, RandomRotation, and RandomZoom layers only operate on training images. They are inactive when the network is validated or asked to make predictions. The Rescaling layer is active at all times, meaning you no longer have to remember to divide by 255 before passing an image to the network for classification.
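
Two details are easy to miss. First, RandomRotation takes a fraction of a full circle rather than degrees, so 0.2 allows far more rotation than rotation_range=30 did. Second, the “training only” behavior hinges on a training flag. The sketch below illustrates both with plain Python; ToyRandomFlip is a hypothetical stand-in, not a Keras class:

```python
import random

def rotation_factor_to_degrees(factor):
    """RandomRotation factors are fractions of a full circle (360 degrees)."""
    return factor * 360

class ToyRandomFlip:
    """Hypothetical stand-in for a training-only augmentation layer."""
    def __init__(self, seed=0):
        self.rng = random.Random(seed)

    def __call__(self, row, training=False):
        # At inference time, pass the input through untouched
        if not training:
            return row
        # During training, reverse the row half of the time
        return row[::-1] if self.rng.random() < 0.5 else row

print(round(rotation_factor_to_degrees(0.2)))  # 72, i.e. rotations of up to +/-72 degrees

flip = ToyRandomFlip()
print(flip([1, 2, 3]))                 # [1, 2, 3] (inactive at inference)
print(flip([1, 2, 3], training=True))  # either [1, 2, 3] or [3, 2, 1]
```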

Apply Image Augmentation to Arctic Wildlife

Would image augmentation make the model featured in my previous post even better? There’s one way to find out.

If you haven’t already, download the zip file containing wildlife images. Unpack the zip file and place its contents in the directory where your Jupyter notebooks are hosted. The zip file contains folders named “train,” “test,” and “samples.” Each folder contains subfolders named “arctic_fox,” “polar_bear,” and “walrus.” The training folders contain 100 images each, while the test folders contain 40 images each.

Create a Jupyter notebook and paste the following code into the first cell to define helper functions for loading and labeling images and declare Python lists for accumulating images and labels:

<code data-highlighted="yes" class="hljs language-python"><span class="hljs-keyword">import</span> os
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">from</span> keras.preprocessing <span class="hljs-keyword">import</span> image
<span class="hljs-keyword">import</span> matplotlib.pyplot <span class="hljs-keyword">as</span> plt
%matplotlib inline
 
<span class="hljs-keyword">def</span> <span class="hljs-title function_">load_images_from_path</span>(<span class="hljs-params">path, label</span>):
    images = []
    labels = []
 
    <span class="hljs-keyword">for</span> file <span class="hljs-keyword">in</span> os.listdir(path):
        img = image.load_img(os.path.join(path, file), target_size=(<span class="hljs-number">224</span>, <span class="hljs-number">224</span>))
        images.append(image.img_to_array(img))
        labels.append(label)
         
    <span class="hljs-keyword">return</span> images, labels
 
<span class="hljs-keyword">def</span> <span class="hljs-title function_">show_images</span>(<span class="hljs-params">images</span>):
    fig, axes = plt.subplots(<span class="hljs-number">1</span>, <span class="hljs-number">8</span>, figsize=(<span class="hljs-number">20</span>, <span class="hljs-number">20</span>), subplot_kw={<span class="hljs-string">'xticks'</span>: [], <span class="hljs-string">'yticks'</span>: []})
 
    <span class="hljs-keyword">for</span> i, ax <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(axes.flat):
        ax.imshow(images[i] / <span class="hljs-number">255</span>)
 
x_train = []
y_train = []
x_test = []
y_test = []
</code>

Use the following statements to load the Arctic-fox training images and plot a few of them:

<code data-highlighted="yes" class="hljs language-makefile">images, labels = load_images_from_path('train/arctic_fox', 0)
show_images(images)
       
x_train += images
y_train += labels
</code>

Load and label the polar-bear training images:

<code data-highlighted="yes" class="hljs language-makefile">images, labels = load_images_from_path('train/polar_bear', 1)
show_images(images)
       
x_train += images
y_train += labels
</code>

And then the walrus training images:

<code data-highlighted="yes" class="hljs language-makefile">images, labels = load_images_from_path('train/walrus', 2)
show_images(images)
   
x_train += images
y_train += labels
</code>

The dataset also contains test images. Load the Arctic-fox test images:

<code data-highlighted="yes" class="hljs language-makefile">images, labels = load_images_from_path('test/arctic_fox', 0)
show_images(images)
       
x_test += images
y_test += labels
</code>

Then the polar-bear test images:

<code data-highlighted="yes" class="hljs language-makefile">images, labels = load_images_from_path('test/polar_bear', 1)
show_images(images)
       
x_test += images
y_test += labels
</code>

And finally, the walrus test images:

<code data-highlighted="yes" class="hljs language-makefile">images, labels = load_images_from_path('test/walrus', 2)
show_images(images)
       
x_test += images
y_test += labels
</code>

The next step is to one-hot-encode the labels and preprocess the images the way ResNet50V2 expects. Note that there is no need to divide pixel values by 255 because we’ll include a Rescaling layer in our network to do that:

<code data-highlighted="yes" class="hljs language-javascript"><span class="hljs-keyword">from</span> tensorflow.<span class="hljs-property">keras</span>.<span class="hljs-property">utils</span> <span class="hljs-keyword">import</span> to_categorical
<span class="hljs-keyword">from</span> tensorflow.<span class="hljs-property">keras</span>.<span class="hljs-property">applications</span>.<span class="hljs-property">resnet50</span> <span class="hljs-keyword">import</span> preprocess_input
 
x_train = <span class="hljs-title function_">preprocess_input</span>(np.<span class="hljs-title function_">array</span>(x_train))
x_test = <span class="hljs-title function_">preprocess_input</span>(np.<span class="hljs-title function_">array</span>(x_test))
     
y_train_encoded = <span class="hljs-title function_">to_categorical</span>(y_train)
y_test_encoded = <span class="hljs-title function_">to_categorical</span>(y_test)
</code>
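
If one-hot encoding is new to you, this small NumPy sketch shows roughly what happens to integer labels like ours. It uses np.eye as an illustration and isn’t meant to mirror how to_categorical is written:

```python
import numpy as np

labels = [0, 1, 2, 1]               # 0 = arctic fox, 1 = polar bear, 2 = walrus
one_hot = np.eye(3)[labels]         # row i has a 1 in column labels[i]
print(one_hot)

decoded = one_hot.argmax(axis=1)    # argmax recovers the original integer labels
print(decoded)                      # [0 1 2 1]
```

The argmax step is the same trick the confusion-matrix code uses later to turn encoded labels and softmax outputs back into class indexes.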

Now load ResNet50V2 without the classification layers and initialize it with the weights arrived at when it was trained on the ImageNet dataset. A key element here is preventing the bottleneck layers from training when the network is trained by setting their trainable attributes to False, effectively freezing those layers:

<code data-highlighted="yes" class="hljs language-python"><span class="hljs-keyword">from</span> tensorflow.keras.applications <span class="hljs-keyword">import</span> ResNet50V2
 
base_model = ResNet50V2(weights=<span class="hljs-string">'imagenet'</span>, include_top=<span class="hljs-literal">False</span>)
 
<span class="hljs-keyword">for</span> layer <span class="hljs-keyword">in</span> base_model.layers:
    layer.trainable = <span class="hljs-literal">False</span>
</code>

Define a network that incorporates rescaling and augmentation layers, ResNet50V2’s bottleneck layers, dense layers for classification, and a dropout layer to help the network generalize. Then train the network using an increased number of epochs so it sees more randomly transformed training samples:

<code data-highlighted="yes" class="hljs language-csharp"><span class="hljs-keyword">from</span> keras.models import Sequential
<span class="hljs-keyword">from</span> keras.layers import Flatten, Dense, Dropout
<span class="hljs-keyword">from</span> keras.layers import Rescaling, RandomFlip, RandomRotation, RandomTranslation, RandomZoom
 
model = Sequential()
model.<span class="hljs-keyword">add</span>(Rescaling(<span class="hljs-number">1.</span>/<span class="hljs-number">255</span>))
model.<span class="hljs-keyword">add</span>(RandomFlip(mode=<span class="hljs-string">'horizontal'</span>))
model.<span class="hljs-keyword">add</span>(RandomTranslation(<span class="hljs-number">0.2</span>, <span class="hljs-number">0.2</span>))
model.<span class="hljs-keyword">add</span>(RandomRotation(<span class="hljs-number">0.2</span>))
model.<span class="hljs-keyword">add</span>(RandomZoom(<span class="hljs-number">0.2</span>))
model.<span class="hljs-keyword">add</span>(base_model)
model.<span class="hljs-keyword">add</span>(Flatten())
model.<span class="hljs-keyword">add</span>(Dense(<span class="hljs-number">1024</span>, activation=<span class="hljs-string">'relu'</span>))
model.<span class="hljs-keyword">add</span>(Dropout(<span class="hljs-number">0.2</span>))
model.<span class="hljs-keyword">add</span>(Dense(<span class="hljs-number">3</span>, activation=<span class="hljs-string">'softmax'</span>))
model.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=<span class="hljs-string">'categorical_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])
 
hist = model.fit(x_train, y_train_encoded, validation_data=(x_test, y_test_encoded), batch_size=<span class="hljs-number">10</span>, epochs=<span class="hljs-number">25</span>)
</code>

Dropout is a commonly used technique for increasing a neural network’s ability to generalize by preventing it from fitting too tightly to the training data. In Keras, dropout is introduced by including Dropout layers in the network. Dropout(0.2) tells Keras to zero out a randomly selected 20% of the values flowing into the layer in each training pass (that is, each time a batch of training samples is run through the network) and to scale up the remaining values to compensate. Dropout layers are active during training but are ignored when the network is asked to make predictions.
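
Here is a rough NumPy sketch of the idea, assuming the common “inverted dropout” formulation in which the surviving values are scaled by 1 / (1 - rate). The dropout function below is written for illustration and is not Keras’s implementation:

```python
import numpy as np

def dropout(activations, rate, rng, training):
    """Inverted dropout: zero a random fraction of values and rescale the rest."""
    if not training:
        return activations                 # inactive at inference time
    keep = rng.random(activations.shape) >= rate
    return activations * keep / (1.0 - rate)

rng = np.random.default_rng(0)
activations = np.ones((1000, 100))

dropped = dropout(activations, rate=0.2, rng=rng, training=True)
zero_fraction = (dropped == 0).mean()      # close to 0.2
mean_value = dropped.mean()                # close to 1.0 because survivors are scaled up
print(zero_fraction, mean_value)
```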

How well did the network train? Let’s plot the training accuracy and validation accuracy for each epoch:

<code data-highlighted="yes" class="hljs language-go">acc = hist.history[<span class="hljs-string">'accuracy'</span>]
val_acc = hist.history[<span class="hljs-string">'val_accuracy'</span>]
epochs = <span class="hljs-keyword">range</span>(<span class="hljs-number">1</span>, <span class="hljs-built_in">len</span>(acc) + <span class="hljs-number">1</span>)
 
plt.plot(epochs, acc, <span class="hljs-string">'-'</span>, label=<span class="hljs-string">'Training Accuracy'</span>)
plt.plot(epochs, val_acc, <span class="hljs-string">':'</span>, label=<span class="hljs-string">'Validation Accuracy'</span>)
plt.title(<span class="hljs-string">'Training and Validation Accuracy'</span>)
plt.xlabel(<span class="hljs-string">'Epoch'</span>)
plt.ylabel(<span class="hljs-string">'Accuracy'</span>)
plt.legend(loc=<span class="hljs-string">'lower right'</span>)
plt.show()
</code>
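
Besides plotting, you can pull numbers straight out of the history dictionary. The values below are invented placeholders standing in for hist.history, not results from the model:

```python
# Hypothetical per-epoch accuracies standing in for hist.history
history = {
    'accuracy':     [0.81, 0.90, 0.94, 0.95, 0.96],
    'val_accuracy': [0.88, 0.93, 0.97, 0.96, 0.95],
}

val_acc = history['val_accuracy']
best_epoch = max(range(len(val_acc)), key=val_acc.__getitem__) + 1  # epochs are 1-based
print(f'Best validation accuracy: {val_acc[best_epoch - 1]:.2f} at epoch {best_epoch}')
```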

With a little luck, the network achieved 97% to 98% accuracy, which is a couple of percentage points more than it achieved without data augmentation. Use a confusion matrix to visualize how well the network performed during testing:

<code data-highlighted="yes" class="hljs language-python"><span class="hljs-keyword">from</span> sklearn.metrics <span class="hljs-keyword">import</span> confusion_matrix
<span class="hljs-keyword">import</span> seaborn <span class="hljs-keyword">as</span> sns
sns.<span class="hljs-built_in">set</span>()
 
y_predicted = model.predict(x_test)
mat = confusion_matrix(y_test_encoded.argmax(axis=<span class="hljs-number">1</span>), y_predicted.argmax(axis=<span class="hljs-number">1</span>))
class_labels = [<span class="hljs-string">'arctic fox'</span>, <span class="hljs-string">'polar bear'</span>, <span class="hljs-string">'walrus'</span>]
 
sns.heatmap(mat, square=<span class="hljs-literal">True</span>, annot=<span class="hljs-literal">True</span>, fmt=<span class="hljs-string">'d'</span>, cbar=<span class="hljs-literal">False</span>, cmap=<span class="hljs-string">'Blues'</span>,
            xticklabels=class_labels,
            yticklabels=class_labels)
 
plt.xlabel(<span class="hljs-string">'Predicted label'</span>)
plt.ylabel(<span class="hljs-string">'Actual label'</span>)
</code>
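
If you’re curious what a confusion matrix actually counts, this plain-Python sketch builds one by hand. The labels are made up for illustration and are not the model’s output; confusion_counts is a hypothetical helper:

```python
def confusion_counts(actual, predicted, num_classes):
    """Build a matrix whose [row][col] entry counts actual=row, predicted=col."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for a, p in zip(actual, predicted):
        matrix[a][p] += 1
    return matrix

# Made-up labels for illustration only, not results from the model
actual    = [0, 0, 0, 1, 1, 1, 2, 2, 2]
predicted = [0, 0, 1, 1, 1, 1, 2, 0, 2]

matrix = confusion_counts(actual, predicted, num_classes=3)
accuracy = sum(matrix[i][i] for i in range(3)) / len(actual)
print(matrix)    # [[2, 1, 0], [0, 3, 0], [1, 0, 2]]
print(accuracy)  # correct predictions sit on the diagonal
```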

Use the following statements to load an Arctic-fox image that the network was neither trained nor tested with:

<code data-highlighted="yes" class="hljs language-bash">x = image.load_img(<span class="hljs-string">'samples/arctic_fox/arctic_fox_140.jpeg'</span>, target_size=(224, 224))
plt.xticks([])
plt.yticks([])
plt.imshow(x)
</code>

Preprocess the image and see how the network classifies it:

<code data-highlighted="yes" class="hljs language-python">x = image.img_to_array(x)
x = np.expand_dims(x, axis=<span class="hljs-number">0</span>)
x = preprocess_input(x)
predictions = model.predict(x)
 
<span class="hljs-keyword">for</span> i, label <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(class_labels):
    <span class="hljs-built_in">print</span>(<span class="hljs-string">f'<span class="hljs-subst">{label}</span>: <span class="hljs-subst">{predictions[<span class="hljs-number">0</span>][i]}</span>'</span>)
</code>

Now load a walrus image:

<code data-highlighted="yes" class="hljs language-bash">x = image.load_img(<span class="hljs-string">'samples/walrus/walrus_143.png'</span>, target_size=(224, 224))
plt.xticks([])
plt.yticks([])
plt.imshow(x)
</code>

And submit it to the network for classification:

<code data-highlighted="yes" class="hljs language-python">x = image.img_to_array(x)
x = np.expand_dims(x, axis=<span class="hljs-number">0</span>)
x = preprocess_input(x)
predictions = model.predict(x)
 
<span class="hljs-keyword">for</span> i, label <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(class_labels):
    <span class="hljs-built_in">print</span>(<span class="hljs-string">f'<span class="hljs-subst">{label}</span>: <span class="hljs-subst">{predictions[<span class="hljs-number">0</span>][i]}</span>'</span>)
</code>
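
To turn the printed probabilities into a single answer, pick the class with the highest value. The probabilities in this sketch are hypothetical, chosen only to show the mechanics:

```python
class_labels = ['arctic fox', 'polar bear', 'walrus']
probabilities = [0.02, 0.03, 0.95]   # hypothetical softmax output for one image

# Index of the largest probability identifies the predicted class
best = max(range(len(class_labels)), key=lambda i: probabilities[i])
print(f'{class_labels[best]} ({probabilities[best]:.0%})')  # walrus (95%)
```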

Data scientists often employ data augmentation even when they’re training a CNN from scratch rather than employing transfer learning. It’s a useful tool to know about, and one that could make a difference when you’re trying to squeeze every last ounce of accuracy out of a deep-learning model.

Get the Code

You can download a Jupyter notebook demonstrating transfer learning with data augmentation from the deep-learning repo that I maintain on GitHub. Feel free to check out the other notebooks in the repo while you’re at it. Also be sure to check back from time to time because I am constantly uploading new samples and updating existing ones.
