(pyenv setup with Python 3.9.20)
Convolutional Neural Networks (CNNs) are powerful tools in image recognition tasks. A classic problem solved by CNNs is predicting handwritten digits from the MNIST dataset. Here’s how to train a CNN for handwritten digit classification using TensorFlow’s Keras API in Python.
Dataset and Preprocessing
The MNIST dataset (delivered in tensorflow.keras.datasets) consists of 70,000 grayscale images, each representing a digit (0–9). We first load the TensorFlow library as tf, preprocess the data by reshaping the images into a 28x28 pixel format with 1 channel (grayscale), and normalize pixel values to the range $[0, 1]$. The labels are one-hot encoded for multi-class classification.
Note that the mnist.load_data() function from TensorFlow’s Keras API automatically loads the dataset and splits it into tuples of training and test sets (60,000 training and 10,000 test images).
import tensorflow as tf
# load the MNIST dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Preprocessing
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0
# Convert labels to one-hot encoding
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
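To make the two preprocessing steps concrete, here is a small dependency-free sketch of what they do to a single example. The helper names scale_pixels and one_hot are hypothetical; they only mimic the effect of the division by 255 and of to_categorical and are not part of TensorFlow.

```python
def scale_pixels(row):
    """Map 8-bit pixel intensities (0..255) to floats in [0, 1]."""
    return [p / 255.0 for p in row]


def one_hot(label, num_classes=10):
    """Return a list with 1.0 at position `label` and 0.0 elsewhere."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    return [1.0 if k == label else 0.0 for k in range(num_classes)]


pixels = [0, 51, 255]          # three made-up pixel intensities
scaled = scale_pixels(pixels)  # black stays 0.0, white becomes 1.0
encoded = one_hot(3)           # the digit 3 as a 10-dimensional vector

print(scaled)
print(encoded)
```

The one-hot vector has exactly one non-zero entry, which is why it pairs naturally with a 10-way softmax output.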
To get an idea of what the observations in MNIST look like, we sample a few examples from the training set and plot them side by side using matplotlib.pyplot.
import matplotlib.pyplot as plt
import numpy as np
# seed for reproducibility
np.random.seed(1337)
# sample a few images from the training data
num_samples = 5
indices = np.random.choice(x_train.shape[0], num_samples, replace=False)
sample_images = x_train[indices]
sample_labels = y_train[indices]
# plot the sampled images
plt.figure(figsize=(10, 2))
for i in range(num_samples):
    plt.subplot(1, num_samples, i + 1)
    plt.imshow(sample_images[i].reshape(28, 28), cmap='gray')
    plt.title(np.argmax(sample_labels[i]))
    plt.axis('off')
plt.show()
Setting up the CNN
Convolutional Neural Networks (CNNs) are well-suited for handwritten digit classification because they are designed for extracting spatial features from image data. Use of shared weights and local connectivity reduces the number of parameters in CNNs, making them computationally efficient.
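A quick back-of-the-envelope calculation illustrates the savings from weight sharing. The sketch below assumes a first layer with 32 filters of size 3x3 applied without padding to a 28x28x1 input. The filter count is an illustrative assumption, not a value stated in this post.

```python
# Assumed layer configuration (illustrative, not taken from the post)
height, width, channels = 28, 28, 1
kernel_size = 3
num_filters = 32

# Convolutional layer: each filter has kernel_size^2 * channels weights
# plus one bias, and the same weights are reused at every position.
conv_params = (kernel_size * kernel_size * channels + 1) * num_filters

# Output size without padding and with stride 1
out_h = height - kernel_size + 1
out_w = width - kernel_size + 1
num_outputs = out_h * out_w * num_filters

# Fully connected layer producing the same number of outputs:
# one weight per (input, output) pair plus one bias per output.
num_inputs = height * width * channels
dense_params = num_inputs * num_outputs + num_outputs

print(f"conv layer:  {conv_params:,} parameters")
print(f"dense layer: {dense_params:,} parameters")
```

Under these assumptions the convolutional layer needs 320 parameters, whereas a dense layer with the same output size would need roughly 17 million.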
Convolution (surprise!) is the key operation in a convolutional layer, where a small kernel, also called a filter (3x3 in our model), slides across the layer input, performing element-wise multiplication with overlapping regions and summing the results to create a feature map. This process detects patterns like edges or textures by focusing on local regions of the input (known as receptive fields). After convolution, results are passed through activation functions (ReLU here) to introduce non-linearity.
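The sliding-window operation can be sketched in a few lines of plain Python. This is a toy illustration, not how TensorFlow implements it: the function names, the 5x5 image, and the hand-picked edge kernel are all made up for the example. Like the convolutional layers in Keras, the sketch does not flip the kernel (strictly speaking, it computes a cross-correlation).

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1); return the feature map."""
    k_h, k_w = len(kernel), len(kernel[0])
    out_h = len(image) - k_h + 1
    out_w = len(image[0]) - k_w + 1
    feature_map = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # multiply the kernel element-wise with the current
            # receptive field and sum the products
            total = sum(
                image[i + a][j + b] * kernel[a][b]
                for a in range(k_h)
                for b in range(k_w)
            )
            row.append(total)
        feature_map.append(row)
    return feature_map


def relu(feature_map):
    """Replace negative activations with zero."""
    return [[max(0, v) for v in row] for row in feature_map]


# 5x5 toy image: three dark columns (0) followed by two bright columns (1)
image = [[0, 0, 0, 1, 1] for _ in range(5)]

# hand-picked kernel that responds to dark-to-bright vertical edges
kernel = [
    [-1, 0, 1],
    [-1, 0, 1],
    [-1, 0, 1],
]

activations = relu(conv2d_valid(image, kernel))
for row in activations:
    print(row)
```

Each row of the resulting 3x3 feature map is [0, 3, 3]: the activation is zero where the window sees only dark pixels and positive where it covers the edge. In a trained CNN the kernel values are learned rather than hand-picked.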
CNNs often have multiple convolutional blocks, with more filters in deeper layers. Initial layers capture basic features, while deeper layers learn complex patterns. More filters enable the network to learn a richer set of features.
We build the CNN with two convolutional blocks, each followed by max-pooling and dropout layers:
- Pooling layers retain the most prominent features (e.g., edges or textures) in the feature map, so that only the strongest activations are passed forward. They also provide a degree of translation invariance, helping the model recognize digits whose position shifts slightly.
- Dropout layers regularize training by randomly setting a fraction of the previous layer's outputs to zero during training (25% in the convolutional blocks), preventing overfitting by reducing reliance on specific neurons.
- The dense layer at the end maps the learned features to the 10 digit classes.[1] The softmax function outputs a probability distribution over the classes by normalizing the raw scores (logits) into probabilities that sum to 1. The class (digit) with the highest estimated probability is the CNN's prediction.
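The architecture described above could be assembled roughly as follows. This is a sketch under stated assumptions, not necessarily the exact model behind the results reported here: the text fixes the 3x3 kernels, ReLU activations, the 25% dropout in the convolutional blocks, and the 10-way softmax output, while the filter counts (32 and 64) and the 50% dropout rate before the output layer are my guesses.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    # first convolutional block (32 filters is an assumption)
    tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    # second convolutional block with more filters (64 is an assumption)
    tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    # classifier head: flatten, dropout (assumed rate), softmax over 10 digits
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# print layer output shapes and parameter counts
model.summary()
```

With these assumed sizes the network has 34,826 trainable parameters, most of them in the second convolutional layer and the final dense layer.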
Training
Next, we compile the model with the Adam optimizer, the categorical cross-entropy loss function, and accuracy as a metric. We train for 20 epochs with a batch size of 128, holding out 20% of the training data for validation.
# define the optimizer
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
# compile model
model.compile(
    optimizer=opt,
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)
# train the model
model.fit(
    x_train,
    y_train,
    epochs=20,
    batch_size=128,
    validation_split=0.2,
    verbose=False
)
<keras.src.callbacks.history.History at 0x1307dcee0>
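For intuition about the quantity being minimized during training, the following plain-Python sketch computes a softmax and the categorical cross-entropy for a single example. The three-class logits and label are made up, and the function names are hypothetical; Keras performs the equivalent computation internally in a vectorized and numerically safer way.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_cross_entropy(one_hot_label, probs, eps=1e-12):
    """Negative log-probability that the model assigns to the true class."""
    return -sum(
        t * math.log(max(p, eps)) for t, p in zip(one_hot_label, probs)
    )


logits = [2.0, 1.0, 0.1]  # made-up scores for a 3-class toy problem
label = [1.0, 0.0, 0.0]   # the true class is the first one

probs = softmax(logits)
loss = categorical_cross_entropy(label, probs)

print([round(p, 3) for p in probs])
print(round(loss, 3))
```

Because the label is one-hot, only the probability of the true class enters the loss. The loss approaches zero as that probability approaches 1 and grows without bound as it approaches 0.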
We can now evaluate the trained model on the test data x_test/y_test.
# evaluate on test set
test_loss, test_accuracy = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=False
)
print(
    f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4%}"
)
Test Loss: 0.0204, Test Accuracy: 99.3900%
The results show that our CNN generalizes very well to unseen test cases, classifying over 99% of the test observations correctly.
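The accuracy metric itself is simple: it is the share of examples for which the class with the highest predicted probability matches the true label. A minimal sketch with made-up three-class predictions (the helper names argmax and accuracy are hypothetical, not Keras functions):

```python
def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=lambda k: values[k])


def accuracy(one_hot_labels, predicted_probs):
    """Share of examples whose most probable class equals the true class."""
    hits = sum(
        argmax(label) == argmax(probs)
        for label, probs in zip(one_hot_labels, predicted_probs)
    )
    return hits / len(one_hot_labels)


# made-up labels and predictions for a 3-class toy problem
labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
predictions = [
    [0.8, 0.1, 0.1],  # correct
    [0.2, 0.7, 0.1],  # correct
    [0.5, 0.3, 0.2],  # wrong: predicts class 0, truth is class 2
    [0.1, 0.6, 0.3],  # correct
]

print(f"accuracy: {accuracy(labels, predictions):.2%}")
```

Three of the four toy predictions are correct, so the sketch reports an accuracy of 75%. Note that accuracy ignores how confident a prediction is, which is why the loss is reported alongside it.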
What next?
One might ask: “Your model performs well on the test set, but what practical applications does it have?” Well, a CNN for handwritten digit classification can be embedded in web-based applications to handle user input on touch screens. This allows for real-time digit recognition in various scenarios, such as form filling or educational tools.
TensorFlow.js is a library for deploying machine learning models in JavaScript environments. While it is possible to train models directly in TensorFlow.js, it is generally more efficient to train in an environment with greater computational power, such as Python on a high-performance computer (the approach I showcase here). Training deep learning models can be computationally intensive and time-consuming, and often requires powerful GPUs or TPUs to achieve reasonable training times.[2] By training the model in Python, we can take advantage of optimized libraries and hardware accelerators, ensuring faster and more efficient training.
After training in Python, we can export the model to TensorFlow.js format for web deployment, combining Python’s training efficiency with JavaScript’s interactive capabilities.
That’s it for now :-).
[1] Here dropout before the output layer ensures that the last layer combines features without over-relying on any single high-level feature, further improving generalization to test data.
[2] Large Language Models (LLMs) are an extreme example of this principle: their training requires multi-million dollar computation centers and petabytes of data, yet the trained models can run inference tasks on personal computers.