Image Classification using LeNet CNN

Fashion MNIST Dataset - Clothing Objects (10 classes)

t-shirt/top, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag, ankle boot

[Figure: sample Fashion MNIST images]

In [1]:
# import tensorflow module. Check API version.
import tensorflow as tf
import numpy as np

print(tf.__version__)

# required for TF to run within docker using GPU (ignore otherwise)
# gpu = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(gpu[0], True)
2.3.0

Load the data

In [2]:
# grab the Fashion MNIST dataset (may take time the first time)
print("[INFO] downloading Fashion MNIST...")
(trainData, trainLabels), (testData, testLabels) = tf.keras.datasets.fashion_mnist.load_data()
[INFO] downloading Fashion MNIST...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

Prepare the data

In [3]:
print(trainData.shape)
print(testData.shape)
(60000, 28, 28)
(10000, 28, 28)
In [4]:
# parameters for Fashion MNIST data set
num_classes = 10
image_width = 28
image_height = 28
image_channels = 1
# define human readable class names
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
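The labels arrive as integers 0-9, so class_names simply indexes them. A one-line lookup check (an addition, not in the original notebook):

# map an integer label to its human-readable name
print(trainLabels[0], "->", class_names[trainLabels[0]])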
In [5]:
# shape the input data using "channels last" ordering
# num_samples x rows x columns x depth
trainData = trainData.reshape(
        (trainData.shape[0], image_height, image_width, image_channels))
testData = testData.reshape(
        (testData.shape[0], image_height, image_width, image_channels))
In [6]:
# convert to floating point and scale data to the range of [0.0, 1.0]
trainData = trainData.astype("float32") / 255.0
testData = testData.astype("float32") / 255.0
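A quick sanity check (an addition, not part of the original run) confirms the conversion:

# verify dtype and value range after scaling
print(trainData.dtype, trainData.min(), trainData.max())  # expect: float32 0.0 1.0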
In [7]:
# zero-pad the 28x28 images to 32x32, the input size expected by LeNet-5
trainData = np.pad(trainData, ((0,0),(2,2),(2,2),(0,0)), 'constant')
testData = np.pad(testData, ((0,0),(2,2),(2,2),(0,0)), 'constant')
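Why 32x32? Each valid 5x5 convolution shrinks the spatial size by 4 and each 2x2 pooling halves it, so a 32x32 input yields 28 -> 14 -> 10 -> 5, producing the 5x5x16 volume (400 units) that LeNet-5's first fully connected layer expects. A minimal sketch (an addition, not in the original notebook) traces this:

# trace LeNet-5 spatial dimensions: valid 5x5 convs and 2x2 poolings
size = 32
for layer in ("conv 5x5", "pool 2x2", "conv 5x5", "pool 2x2"):
    size = size - 4 if layer.startswith("conv") else size // 2
    print(layer, "->", size)  # 28, 14, 10, 5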
In [8]:
# display data dimensions
print("trainData:", trainData.shape)
print("trainLabels:", trainLabels.shape)
print("testData:", testData.shape)
print("testLabels:", testLabels.shape)
trainData: (60000, 32, 32, 1)
trainLabels: (60000,)
testData: (10000, 32, 32, 1)
testLabels: (10000,)
In [9]:
# updated image parameters after padding to 32x32
num_classes = 10
image_width = 32
image_height = 32
image_channels = 1

Define Model

LeNet-5 Model

In [39]:
# import the necessary packages
from tensorflow.keras import backend
from tensorflow.keras import models
from tensorflow.keras import layers

# define the model as a class
class LeNet:
    # INPUT => CONV => TANH => AVG-POOL => CONV => TANH => AVG-POOL => FC => TANH => FC => TANH => FC => SMAX
    @staticmethod
    def init(numChannels, imgRows, imgCols, numClasses, weightsPath=None):
        # if we are using "channels first", update the input shape
        if backend.image_data_format() == "channels_first":
            inputShape = (numChannels, imgRows, imgCols)
        else:  # "channels last"
            inputShape = (imgRows, imgCols, numChannels)

        # initialize the model
        model = models.Sequential()

        # define the first set of CONV => ACTIVATION => POOL layers
        model.add(layers.Conv2D(filters=6, kernel_size=(5, 5), strides=(1, 1),
                padding="valid", activation=tf.nn.tanh, input_shape=inputShape))
        model.add(layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2)))

        # define the second set of CONV => ACTIVATION => POOL layers
        model.add(layers.Conv2D(filters=16, kernel_size=(5, 5), strides=(1, 1),
                padding="valid", activation=tf.nn.tanh))
        model.add(layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2)))

        # flatten the convolution volume to fully connected layers
        model.add(layers.Flatten())

        # define the first FC => ACTIVATION layers
        model.add(layers.Dense(units=120, activation=tf.nn.tanh))

        # define the second FC => ACTIVATION layers
        model.add(layers.Dense(units=84, activation=tf.nn.tanh))

        # lastly, define the soft-max classifier
        model.add(layers.Dense(units=numClasses, activation=tf.nn.softmax))

        # if a weights path is supplied (indicating that the model was
        # pre-trained), then load the weights
        if weightsPath is not None:
            model.load_weights(weightsPath)

        # return the constructed network architecture
        return model

Compile Model

In [40]:
# initialize the model
print("[INFO] compiling model...")
model = LeNet.init(numChannels=image_channels,
                    imgRows=image_height, imgCols=image_width,
                    numClasses=num_classes,
                    weightsPath=None)

# compile the model
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.05),  # Stochastic Gradient Descent
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"])

# print model summary
model.summary()
[INFO] compiling model...
Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_20 (Conv2D)           (None, 28, 28, 6)         156       
_________________________________________________________________
average_pooling2d_20 (Averag (None, 14, 14, 6)         0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 10, 10, 16)        2416      
_________________________________________________________________
average_pooling2d_21 (Averag (None, 5, 5, 16)          0         
_________________________________________________________________
flatten_10 (Flatten)         (None, 400)               0         
_________________________________________________________________
dense_30 (Dense)             (None, 120)               48120     
_________________________________________________________________
dense_31 (Dense)             (None, 84)                10164     
_________________________________________________________________
dense_32 (Dense)             (None, 10)                850       
=================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
_________________________________________________________________
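The 61,706 total can be verified by hand: a Conv2D layer has (kernel_h * kernel_w * in_channels + 1) * filters parameters, and a Dense layer has (inputs + 1) * units. A short check (an addition, not in the original notebook):

# verify the summary's parameter counts by hand
conv1 = (5 * 5 * 1 + 1) * 6     # 156
conv2 = (5 * 5 * 6 + 1) * 16    # 2,416
fc1 = (400 + 1) * 120           # 48,120
fc2 = (120 + 1) * 84            # 10,164
fc3 = (84 + 1) * 10             # 850
print(conv1 + conv2 + fc1 + fc2 + fc3)  # 61,706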

Train Model

In [41]:
# define callback function for the training termination criterion
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # stop once training accuracy exceeds 95%
        if logs.get('accuracy') > 0.95:
            print("\nReached 95% accuracy so cancelling training!")
            self.model.stop_training = True

# initialize training config
batch_size = 128
epochs = 120

# run training
print("[INFO] training...")
history = model.fit(x=trainData, y=trainLabels,
                    validation_data=(testData, testLabels),
                    batch_size=batch_size, epochs=epochs,
                    verbose=1, callbacks=[myCallback()])
[INFO] training...
Epoch 1/120
469/469 [==============================] - 2s 4ms/step - loss: 0.7883 - accuracy: 0.7223 - val_loss: 0.6043 - val_accuracy: 0.7791
Epoch 2/120
469/469 [==============================] - 2s 3ms/step - loss: 0.5214 - accuracy: 0.8107 - val_loss: 0.5044 - val_accuracy: 0.8123
[... epochs 3-68 omitted: training accuracy rises steadily from 0.84 to 0.95; validation accuracy plateaus near 0.90 ...]
Epoch 69/120
469/469 [==============================] - 1s 3ms/step - loss: 0.1407 - accuracy: 0.9484 - val_loss: 0.2832 - val_accuracy: 0.9019
Epoch 70/120
455/469 [============================>.] - ETA: 0s - loss: 0.1378 - accuracy: 0.9505
Reached 95% accuracy so cancelling training!
469/469 [==============================] - 1s 3ms/step - loss: 0.1387 - accuracy: 0.9503 - val_loss: 0.2833 - val_accuracy: 0.9028
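The custom callback stops on a fixed training-accuracy threshold, even though validation accuracy has largely plateaued around 0.90 well before epoch 70. An alternative (a sketch, not used in this run) is Keras' built-in EarlyStopping callback, which monitors validation loss instead:

# sketch: stop when validation loss stops improving
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",         # watch generalization, not training fit
    patience=5,                 # allow 5 epochs without improvement
    restore_best_weights=True)  # roll back to the best epoch
# history = model.fit(..., callbacks=[early_stop])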

Evaluate Training Performance

Expected Output

[Figures: training/validation accuracy plot, training/validation loss plot]

In [42]:
%matplotlib inline
import matplotlib.pyplot as plt

# retrieve a list of list results on training and test data sets for each training epoch
acc      = history.history['accuracy']
val_acc  = history.history['val_accuracy']
loss     = history.history['loss']
val_loss = history.history['val_loss']

epochs   = range(len(acc)) # get number of epochs

# plot training and validation accuracy per epoch
plt.plot(epochs, acc, label='train accuracy')
plt.plot(epochs, val_acc, label='val accuracy')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.legend(loc="lower right")
plt.title('Training and validation accuracy')
plt.figure()

# plot training and validation loss per epoch
plt.plot(epochs, loss, label='train loss')
plt.plot(epochs, val_loss, label='val loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend(loc="upper right")
plt.title('Training and validation loss')
Out[42]:
Text(0.5, 1.0, 'Training and validation loss')
In [43]:
# show the accuracy on the testing set
print("[INFO] evaluating...")
(loss, accuracy) = model.evaluate(testData, testLabels,
                                  batch_size=batch_size, verbose=1)
print("[INFO] accuracy: {:.2f}%".format(accuracy * 100))
[INFO] evaluating...
79/79 [==============================] - 0s 3ms/step - loss: 0.2833 - accuracy: 0.9028
[INFO] accuracy: 90.28%
In [44]:
model.save_weights("LeNetFashionMNIST.temp.hdf5", overwrite=True)
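save_weights stores only the layer parameters, so reloading requires rebuilding the architecture first (as the next section does). A sketch of the alternative (an addition, not part of the original notebook) that saves architecture, weights, and optimizer state together:

# alternative: persist the whole model in one file
model.save("LeNetFashionMNIST.h5", overwrite=True)
# ...and restore it later without redefining LeNet
restored = tf.keras.models.load_model("LeNetFashionMNIST.h5")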

Evaluate Pre-trained Model

In [45]:
# init model and load the model weights
print("[INFO] compiling model...")
model = LeNet.init(numChannels=image_channels,
                    imgRows=image_height, imgCols=image_width,
                    numClasses=num_classes,
                    weightsPath="LeNetFashionMNIST.temp.hdf5")

# compile the model
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),  # Stochastic Gradient Descent
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"])
[INFO] compiling model...
In [46]:
# show the accuracy on the testing set
print("[INFO] evaluating...")
batch_size = 128
(loss, accuracy) = model.evaluate(testData, testLabels,
                                  batch_size=batch_size, verbose=1)
print("[INFO] accuracy: {:.2f}%".format(accuracy * 100))
[INFO] evaluating...
79/79 [==============================] - 0s 2ms/step - loss: 0.2833 - accuracy: 0.9028
[INFO] accuracy: 90.28%

Model Predictions

In [51]:
%matplotlib inline
import numpy as np
import cv2
import matplotlib.pyplot as plt

# set up matplotlib fig, and size it to fit a 3x10 grid of images
nrows = 3
ncols = 10
fig = plt.gcf()
fig.set_size_inches(ncols*4, nrows*4)

# randomly select a few test images
num_predictions = 25
test_indices = np.random.choice(np.arange(0, len(testLabels)), size=(num_predictions,))
test_images = np.stack(([testData[i] for i in test_indices]))
test_labels = np.stack(([testLabels[i] for i in test_indices]))

# compute predictions
predictions = model.predict(test_images)

for i in range(num_predictions):
    # select the most probable class
    prediction = np.argmax(predictions[i])

    # rescale the test image
    image = (test_images[i] * 255).astype("uint8")

    # resize the image from 32 x 32 to 96 x 96 so we can better see it
    image = cv2.resize(image, (96, 96), interpolation=cv2.INTER_CUBIC)

    # convert grayscale image to RGB color
    image = cv2.merge([image] * 3)

    # select prediction text color
    if prediction == test_labels[i]:
        rgb_color = (0, 255, 0) # green for correct predictions
    else:
        rgb_color = (255, 0, 0) # red for wrong predictions

    # show the image and prediction
    cv2.putText(image, str(class_names[prediction]), (0, 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, rgb_color, 1)
    
    # set up subplot; subplot indices start at 1
    sp = plt.subplot(nrows, ncols, i + 1, title="label: %s" % class_names[test_labels[i]])
    sp.axis('Off') # don't show axes (or gridlines)
    plt.imshow(image)

# show figure matrix
plt.show()
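Spot-checking random samples shows individual hits and misses; a confusion matrix summarizes errors across all 10,000 test images at once (Shirt vs. T-shirt/top is a typical trouble pair). A minimal sketch using tf.math.confusion_matrix (an addition, not in the original notebook):

# predicted class for every test image
predClasses = np.argmax(model.predict(testData, batch_size=128), axis=1)
# rows = true class, columns = predicted class
cm = tf.math.confusion_matrix(testLabels, predClasses, num_classes=num_classes)
print(cm.numpy())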