(pyenv setup with Python 3.9.20)
The story so far
In my previous post about Convolutional Neural Networks (CNNs), I demonstrated how to train a model for handwritten digit recognition. Now, we’ll take this model further by deploying it in a web application using TensorFlow.js. This post guides you through the process of exporting the trained model and implementing it in an Observable notebook where users can draw digits on a canvas for real-time classification.
Model export and conversion
(If you’re just interested in the converted model, you can skip the next few steps and scroll down to download the converted files.)
We can export the trained model object model in TensorFlow’s ‘SavedModel’ format using model.export(). This will generate a directory called mnist_cnn_savedmodel in your current working directory, which contains all the files needed for a subsequent conversion to TensorFlow.js format.
# Export the model in 'savedModel' format
model.export("mnist_cnn_savedmodel")
The file structure should look like this:
mnist_cnn_savedmodel/
├── fingerprint.pb
├── saved_model.pb
├── assets/
└── variables/
To convert the exported model to TensorFlow.js format, we use the tensorflowjs_converter tool, which comes with the tensorflowjs Python package. You may need to install it first:
pip install tensorflowjs
Afterwards, run the following in a shell (paths are relative to the project directory):1
# convert model to tensorflow.js
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_format=tfjs_graph_model \
./mnist_cnn_savedmodel \
./mnist_cnn_tfjs_output
The output folder, mnist_cnn_tfjs_output (located in your working directory), contains two relevant files:
- model.json, representing the converted model in JSON format (model topology), and
- group1-shard1of1.bin, storing the associated weights in binary format.
These files can be used to run predictions with our Python-trained model in a TensorFlow.js environment.
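As an aside: if both files are served from a regular web server, so that model.json and its weight shard sit next to each other under stable URLs, loading the model in the browser reduces to a one-liner. A minimal sketch, with a placeholder URL:

// Sketch: load the converted graph model directly from a URL
// (placeholder URL; model.json and group1-shard1of1.bin must sit side by side,
// since tf.loadGraphModel fetches the weight shard listed in the weights manifest)
const model = await tf.loadGraphModel("https://example.com/mnist/model.json");

In an Observable notebook, however, file attachments are served under generated URLs, so the relative path to the weight file recorded in model.json will generally not resolve. That motivates assembling the model artifacts manually, as shown in the next section.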
Importing and loading the model in TensorFlow.js
To deploy this model in a web environment, we need to create a frontend component that includes:
- A webpage that loads the TensorFlow.js library
- Input elements (like a canvas for drawing)
- JavaScript logic to process the input into the correct format (28x28 grayscale image)
- Code to run predictions using our trained model
I’ve implemented all this in an Observable notebook, which provides an interactive interface where users can draw digits and get real-time predictions from the model. While the detailed implementation of the interactive components and the conversion of the input to a suitable format are beyond the scope of this post, I will provide a brief overview of how to load the model and encourage you to work through the other steps yourself using the notebook.
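To give a rough idea of the input-processing step anyway, here is a simplified sketch (not the notebook’s actual implementation; the canvas variable and the exact normalization are assumptions) of how a drawing canvas can be turned into a 1×28×28×1 grayscale tensor with TensorFlow.js:

input = tf.tidy(() => {
  // read the canvas pixels as an HxWx1 grayscale tensor (values 0..255);
  // `canvas` is assumed to be the HTMLCanvasElement the user draws on
  let img = tf.browser.fromPixels(canvas, 1);
  // downscale to the 28x28 MNIST resolution
  img = tf.image.resizeBilinear(img, [28, 28]);
  // rescale pixel values to 0..1
  img = img.toFloat().div(255);
  // add a batch dimension -> shape [1, 28, 28, 1]
  return img.expandDims(0);
})

Depending on how the canvas is drawn, the image may also need to be inverted (MNIST digits are bright strokes on a dark background) and cropped and centered before resizing; this is the kind of preprocessing the notebook takes care of.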
The following chunk shows how to load the pre-trained TensorFlow.js model from local files.2 We perform the following steps:
- Load the model architecture from a JSON file.
- Load the model weights from a binary file.
- Gather the model artifacts: the model topology and weights.
- Use an IOHandler to load the model artifacts with tf.loadGraphModel()3, where tf is my reference to the TensorFlow.js library.
mdl = {
  // load model.json content
  let modelJSON = await FileAttachment("model.json").json();
  // load the weight file as ArrayBuffer
  let weightData = await FileAttachment("group1-shard1of1.bin").arrayBuffer();
  // assemble model artifacts
  let modelArtifacts = {
    modelTopology: modelJSON.modelTopology || modelJSON.graph,
    weightSpecs: modelJSON.weightsManifest[0].weights,
    weightData: weightData
  };
  // IOHandler
  const ioHandler = {
    load: async function() {
      return modelArtifacts;
    }
  };
  // load the model using tf.loadGraphModel
  return await tf.loadGraphModel(ioHandler);
}
Now we can use the TensorFlow.js interface to run predictions with the imported model object mdl!
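For example, with a 1×28×28×1 input tensor like the one sketched earlier, a prediction could look roughly like this (again a sketch, not the notebook’s exact code; input is the hypothetical tensor from the preprocessing sketch):

prediction = {
  // run the graph model on the (hypothetical) preprocessed input tensor;
  // the output has shape [1, 10]: one value per digit class
  let probs = mdl.predict(input);
  // copy the values from GPU memory into a plain typed array
  let values = await probs.data();
  // the index of the largest value is the predicted digit
  let digit = values.indexOf(Math.max(...values));
  return {digit, probabilities: Array.from(values)};
}

If you run many predictions, remember to dispose of intermediate tensors (e.g. probs.dispose()), since TensorFlow.js does not garbage-collect GPU memory automatically.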
Below are the interactive components of my Observable notebook in action. You may use any common pointing device (touchscreen, trackpad, mouse) to draw a digit on the input canvas (left). Results are best if you draw your digits as large as possible.
Click the predict button to get the classification result as a (predefined) image on the output canvas (middle)! The widget also shows the predicted class probabilities (right).
If you are reading this on a small-screen mobile device, I suggest switching to landscape mode.
After testing the application extensively, I observed that the misclassification rate is higher than for the test dataset results shown in my previous post. In particular, the model has trouble with (my) nines, which seem to ‘look’ like threes, as indicated by the predicted class probabilities. This discrepancy relative to the test-set performance can be attributed to several factors:
- Input device variation: The model was trained on standardized handwritten digits, while my app accepts input from various devices (mouse, trackpad, touchscreen), each producing different stroke patterns. On these input devices, we might draw digits in ways that deviate significantly from the training data’s characteristics, making predictions more challenging.
- Input preprocessing: Although I made some effort to implement decent input standardization (cropping, centering, resizing), there is likely room for improvement in matching the preprocessing pipeline to the characteristics of the original MNIST dataset.
I might address some of these points in later posts, but that’s it for now 😊.
1. This should work fine for UNIX-like shells. You need to adjust it for the Windows command prompt (CMD) or PowerShell! ↩︎
2. Note that this code is written as a cell in Observable’s reactive environment. You’ll need to wrap it in a function in other contexts. ↩︎
3. An IOHandler manages input and output operations, abstracting the details of reading from and writing to data sources. ↩︎