As I mentioned in a previous post, a convolutional neural network (CNN) can be used to classify colour images in much the same way as grey scale classification. The way to achieve this is by utilizing the depth dimension of our input tensors and kernels. In this example I’ll be using the CIFAR-10 dataset, which consists of 32×32 colour images belonging to 10 different classes.

You can see a few examples of each class in the following image from the CIFAR-10 website:

Although I’ve previously written about the Lasagne and nolearn packages (here and here), extending those examples to colour images is a fairly trivial task. So instead, in this post I’ll build the network from the ground up, using a module by Michael Nielsen to handle the plumbing. You can find his original code here (I’ll be using, but renamed to

The key here is to make use of the depth channel to handle colour. In the Lasagne post our input layer received a 4D tensor that looked like this: (None, 1, 28, 28). The first dimension (None) is the mini-batch size, left unspecified; the second (1) is the depth; and the last two (28, 28) are the height and width of the input image. In the Lasagne MNIST example we only had 1 depth channel because we were dealing with grey scale images. Here, we simply set the depth to 3 to handle RGB colour, so a CIFAR-10 input looks like (None, 3, 32, 32). We need to do the same with the convolution kernels: they also need a depth of 3, as the model will be learning colour kernels instead of grey scale ones.
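To make the depth dimension concrete, here is a minimal Python 3/NumPy sketch (not the module's Theano implementation) of a "valid" convolution over a batch shaped (batch, depth, height, width). The kernel bank is shaped (num_filters, depth, k_h, k_w), and each output value sums over all input channels, which is why the kernel depth must match the input depth. The function name conv2d_valid is hypothetical, and it computes a cross-correlation (the kernel is not flipped); for learned kernels that distinction does not matter.

```python
import numpy as np


def conv2d_valid(images, kernels):
    """Naive 'valid' convolution (cross-correlation) of multi-channel images.

    images:  (batch, depth, height, width)
    kernels: (num_filters, depth, k_h, k_w); depth must match the images
    returns: (batch, num_filters, height - k_h + 1, width - k_w + 1)
    """
    batch, depth, h, w = images.shape
    n_f, k_depth, k_h, k_w = kernels.shape
    if k_depth != depth:
        raise ValueError("kernel depth must equal input depth")
    out = np.zeros((batch, n_f, h - k_h + 1, w - k_w + 1))
    for i in range(h - k_h + 1):
        for j in range(w - k_w + 1):
            # patch has shape (batch, depth, k_h, k_w); each output value
            # sums over depth (all colour channels), rows and columns
            patch = images[:, :, i:i + k_h, j:j + k_w]
            out[:, :, i, j] = np.tensordot(patch, kernels, axes=([1, 2, 3], [1, 2, 3]))
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.random((2, 3, 32, 32))   # 2 RGB images: depth 3
    w = rng.random((20, 3, 5, 5))    # 20 colour kernels: also depth 3
    print(conv2d_valid(x, w).shape)  # prints (2, 20, 28, 28)
```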

The goal of this post is to demonstrate how to train a model for colour image classification, rather than try to obtain high classification accuracy (this can be fine-tuned later).

Note: the code below is for Python 2.7 – modifications may be needed to run this on Python 3.x

Loading CIFAR-10 data

To begin, let's import some packages:

I have downloaded the CIFAR-10 data to ../data/cifar-10-batches-py, so my directory structure looks like this:

In my main script, I use the following functions to load the batches into training and test sets (along with their labels):
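A minimal Python 3 sketch of loading the batches, assuming the standard layout of the CIFAR-10 Python version: five pickled training files named data_batch_1 to data_batch_5 plus test_batch, each a dictionary whose data entry is a 10000x3072 uint8 array and whose labels entry is a list of 10000 integers in 0-9. The function names are hypothetical; under Python 2.7 you would drop the encoding argument.

```python
import os
import pickle

import numpy as np


def load_batch(path):
    """Load one pickled CIFAR-10 batch file; returns (data, labels)."""
    with open(path, "rb") as f:
        # encoding="bytes" lets Python 3 read the Python 2 pickles;
        # the dictionary keys then come back as bytes (b"data", b"labels")
        batch = pickle.load(f, encoding="bytes")
    data = np.asarray(batch[b"data"], dtype=np.uint8)      # (n, 3072)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)  # (n,)
    return data, labels


def load_cifar10(data_dir):
    """Concatenate data_batch_1..5 into a training set and load test_batch."""
    parts = [load_batch(os.path.join(data_dir, "data_batch_%d" % i))
             for i in range(1, 6)]
    train_x = np.concatenate([p[0] for p in parts])
    train_y = np.concatenate([p[1] for p in parts])
    test_x, test_y = load_batch(os.path.join(data_dir, "test_batch"))
    return (train_x, train_y), (test_x, test_y)
```

With the directory used in this post, the call would be `load_cifar10("../data/cifar-10-batches-py")`.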

Then I create a validation set from the test set:
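One way to carve a validation set out of the test set, sketched in NumPy. The post doesn't say how many examples it moves or whether it shuffles first, so n_valid and the shuffled split below are assumptions (split_validation is a hypothetical name). The two resulting subsets are disjoint, so no test example leaks into validation.

```python
import numpy as np


def split_validation(test_x, test_y, n_valid, seed=0):
    """Split the test set into disjoint validation and test subsets.

    n_valid (hypothetical parameter) is the number of examples moved to
    validation. Shuffling first avoids inheriting any ordering in the file.
    """
    if not 0 < n_valid < len(test_x):
        raise ValueError("n_valid must be between 1 and len(test_x) - 1")
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(test_x))
    valid_idx, test_idx = idx[:n_valid], idx[n_valid:]
    return ((test_x[valid_idx], test_y[valid_idx]),
            (test_x[test_idx], test_y[test_idx]))
```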

Finally, I place all of the data into a Theano shared variable for use on the GPU during training:
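A sketch of preparing the arrays and wrapping them in Theano shared variables. Scaling pixels to [0, 1] is an assumption (the post doesn't say whether it normalises), and the function names are hypothetical. Labels are stored as floatX and cast to int32 symbolically, an idiom from the Theano tutorials, because Theano's older GPU backend could only keep float32 shared variables on the GPU. Theano is imported inside to_shared so the NumPy part runs without it.

```python
import numpy as np


def prepare_arrays(x, y):
    """Scale uint8 pixels to float32 in [0, 1] and return labels as int32."""
    x = np.asarray(x, dtype=np.float32) / 255.0
    y = np.asarray(y, dtype=np.int32)
    return x, y


def to_shared(x, y):
    """Wrap prepared arrays in Theano shared variables (requires Theano)."""
    import theano                 # imported lazily: only needed on this path
    import theano.tensor as T
    x, y = prepare_arrays(x, y)
    shared_x = theano.shared(np.asarray(x, dtype=theano.config.floatX), borrow=True)
    # labels live on the device as floatX and are cast back to int32 in the graph
    shared_y = theano.shared(np.asarray(y, dtype=theano.config.floatX), borrow=True)
    return shared_x, T.cast(shared_y, "int32")
```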

Training the model

We can visualize a single random training example to see what the data looks like:
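Each CIFAR-10 row holds 1024 red values, then 1024 green, then 1024 blue, each channel in row-major order. Reshaping to (3, 32, 32) gives the (depth, height, width) layout the network consumes; plotting libraries want the depth last, hence the transpose. row_to_image is a hypothetical helper.

```python
import numpy as np


def row_to_image(row):
    """Convert one 3072-value CIFAR-10 row into a (32, 32, 3) RGB image."""
    # (depth, height, width) -> (height, width, depth) for plotting
    return np.asarray(row).reshape(3, 32, 32).transpose(1, 2, 0)
```

With matplotlib, `plt.imshow(row_to_image(train_x[np.random.randint(len(train_x))]))` would display a random example, assuming `train_x` holds the raw uint8 training rows.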

We’ll be using Michael Nielsen’s module, which contains classes for different layer types. His classes handle things like weight initialization and passing parameters to and from other layers in the model, and they simplify the process of constructing a network. For example, the combined convolution and pooling layer (which we’ll be using) looks like this:
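The pooling half of a combined convolution and pooling layer can be sketched in NumPy as non-overlapping max pooling over each feature map. The 2x2 window is an assumption (a common default; the post doesn't state its pool size), and max_pool is a hypothetical name.

```python
import numpy as np


def max_pool(feature_maps, pool=(2, 2)):
    """Non-overlapping max pooling over the last two axes.

    feature_maps: (batch, channels, height, width). Rows/columns that do not
    fill a whole pooling window are dropped.
    """
    ph, pw = pool
    b, c, h, w = feature_maps.shape
    h2, w2 = h // ph, w // pw
    trimmed = feature_maps[:, :, :h2 * ph, :w2 * pw]
    # split each spatial axis into (blocks, window) and take the max per window
    return trimmed.reshape(b, c, h2, ph, w2, pw).max(axis=(3, 5))
```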

First we should set our mini-batch size, which is the number of training examples per batch. I use a small number here just for demonstration purposes, but how large a batch size you can use depends on how much memory you have available (this is mostly a concern if you’re training on a GPU with limited memory):
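To make the trade-off concrete: mini-batches are consecutive fixed-size slices of the training data, and the input alone for one float32 mini-batch of 32x32 RGB images costs batch size x 3 x 32 x 32 x 4 bytes. Activations, gradients and parameters usually take far more memory than this, so the figure is only a lower bound. The helper names are hypothetical, and dropping a trailing partial batch is an assumption that suits graphs built for a fixed batch size.

```python
import numpy as np


def iterate_minibatches(x, y, mini_batch_size):
    """Yield consecutive (x, y) mini-batches; a trailing partial batch is dropped."""
    n_batches = len(x) // mini_batch_size
    for i in range(n_batches):
        s = slice(i * mini_batch_size, (i + 1) * mini_batch_size)
        yield x[s], y[s]


def input_bytes_per_batch(mini_batch_size, depth=3, height=32, width=32, itemsize=4):
    """Bytes needed just for one mini-batch of float32 input images."""
    return mini_batch_size * depth * height * width * itemsize
```

For example, `input_bytes_per_batch(10)` returns 122880 (120 KiB).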

Next, we define the network architecture by combining layers, setting their parameters, and choosing a minimization method (Stochastic Gradient Descent in this case). The training function also lets you train multiple networks for an ensemble by setting its n parameter:
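When stacking conv-pool layers, each layer's input shape has to match the previous layer's output, and after the first layer the depth is the number of feature maps rather than the number of colour channels. The helper below does that bookkeeping for a "valid" convolution followed by pooling; the two-layer architecture (20 then 40 filters of size 5x5, 2x2 pooling) is hypothetical and not necessarily the one used in this post.

```python
def conv_pool_output_shape(image_shape, filter_shape, pool=(2, 2)):
    """Output shape of a 'valid' convolution followed by max pooling.

    image_shape:  (mini_batch_size, depth, height, width)
    filter_shape: (num_filters, depth, k_h, k_w)
    """
    batch, depth, h, w = image_shape
    n_f, k_depth, k_h, k_w = filter_shape
    if depth != k_depth:
        raise ValueError("filter depth must match input depth")
    return (batch, n_f, (h - k_h + 1) // pool[0], (w - k_w + 1) // pool[1])


# Hypothetical architecture for 32x32 RGB input (not necessarily the post's).
mini_batch_size = 10
shape = (mini_batch_size, 3, 32, 32)
for filters in [(20, 3, 5, 5), (40, 20, 5, 5)]:  # 2nd layer depth = 20 feature maps
    shape = conv_pool_output_shape(shape, filters)
    print(shape)  # prints (10, 20, 14, 14), then (10, 40, 5, 5)
```

A fully connected layer after these two would then receive 40 x 5 x 5 = 1000 inputs per example.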

Finally, call the function to start training. Here I train only 1 network for only 10 epochs; ideally you’d train over many more epochs to obtain better model accuracy:
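The mechanics of an epoch (reshuffle, walk through the mini-batches, take a gradient step on each, then measure accuracy) are the same whatever the model. Here is a self-contained NumPy sketch using a softmax classifier as a stand-in for the CNN; it is not the module's training code, and sgd_train, eta and the toy setup are assumptions.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def sgd_train(x, y, n_classes, epochs=10, mini_batch_size=10, eta=0.1, seed=0):
    """Mini-batch SGD on a softmax (multinomial logistic) classifier.

    Returns the weights, bias and the training accuracy after each epoch.
    """
    rng = np.random.default_rng(seed)
    w = np.zeros((x.shape[1], n_classes))
    b = np.zeros(n_classes)
    history = []
    for _ in range(epochs):
        order = rng.permutation(len(x))  # reshuffle every epoch
        for start in range(0, len(x) - mini_batch_size + 1, mini_batch_size):
            idx = order[start:start + mini_batch_size]
            grad = softmax(x[idx] @ w + b)
            # gradient of mean cross-entropy w.r.t. logits: probs - one_hot
            grad[np.arange(len(idx)), y[idx]] -= 1.0
            w -= eta * x[idx].T @ grad / len(idx)
            b -= eta * grad.mean(axis=0)
        history.append(float(np.mean(np.argmax(x @ w + b, axis=1) == y)))
    return w, b, history
```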

Results and visualizations

During training, you can see the current epoch’s train/validation/test accuracy:
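Per-epoch accuracy is just the fraction of correct predictions. A common convention, assumed here rather than confirmed by the post, is to report the test accuracy from the epoch with the best validation accuracy, so the test set plays no part in choosing the model. Both helpers are hypothetical.

```python
import numpy as np


def accuracy(predictions, labels):
    """Fraction of predictions that match the labels."""
    return float(np.mean(np.asarray(predictions) == np.asarray(labels)))


def best_epoch_report(val_accs, test_accs):
    """Return (epoch, validation accuracy, test accuracy) for the epoch with
    the highest validation accuracy; ties go to the earliest epoch."""
    best = int(np.argmax(val_accs))
    return best, val_accs[best], test_accs[best]
```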

We achieved 66.86% test set accuracy after only 15 minutes of training over 10 epochs. There are many tweaks that could improve accuracy, for example changing the model’s architecture or simply increasing the number of training epochs.

Finally, we can plot the training and accuracy curves to see how our model performed over time:

Training a CNN for colour image classification is very similar to training for grey scale classification. Additionally, using a package to handle the layers and the passing of parameters (whether that’s Lasagne or a custom module like the one used here) makes the process a whole lot easier.

As always you can find the full code in my GitHub repo: and
