This package provides training and evaluation code for digit recognition with a Convolutional Neural Network on the MNIST dataset. The accuracy on the test set is 99.47%.
The package depends on:
- TF-Slim: A lightweight library providing syntactic sugar for TensorFlow.
- Menpo Project: The Python framework for data handling and deformable modeling.
In general, as explained in Menpo's installation instructions, it is highly recommended to use conda as your Python distribution.
Once conda is downloaded and installed, this project can be installed by:
Step 1: Create a new conda environment and activate it:
```
$ conda create -n mnist python=3.5
$ source activate mnist
```
Step 2: Install TensorFlow following the official installation instructions. For example, for 64-bit Linux, the installation of GPU enabled, Python 3.5 TensorFlow involves:
```
(mnist)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.10.0rc0-cp35-cp35m-linux_x86_64.whl
(mnist)$ pip install --upgrade $TF_BINARY_URL
```
Step 3: Install menpo from the menpo channel as:
```
(mnist)$ conda install -c menpo menpo
```
Step 4: Clone and install the digitrecognition project as:

```
(mnist)$ cd ~/Documents
(mnist)$ git clone git@github.com:nontas/digitrecognition.git
(mnist)$ pip install -e digitrecognition/
```
The solution implemented in this package is the following:

**Data pre-processing**

During training, each image is pre-processed in the following way:

- The image is rotated around its centre by a random angle.
- The image is skewed (distorted) with random angles.
The pre-processing happens on the fly, i.e. every time a new example is loaded, the pre-processing is applied. Given that we also do not limit the number of batches, the system will keep training on an effectively infinite stream of randomly perturbed examples. This pre-processing is implemented using Menpo; you can find the implementation in `data_provider.py`. To get an idea of the results of the employed pre-processing, you can run the following code in a Jupyter notebook:
```python
%matplotlib inline
from numpy.random import randint
import matplotlib.pyplot as plt
from menpo.image import Image
from digitrecognition import import_mnist_data
from digitrecognition.data_provider import preprocess

# Load train images
train_images, _, _, _, _, _ = import_mnist_data(verbose=True)

# Generate random image index
i = randint(0, len(train_images))

# Pre-process image
im = train_images[i].pixels_with_channels_at_back()[..., None]
im = preprocess(im)

# Plot before and after
plt.subplot(121)
train_images[i].view()
plt.subplot(122)
Image.init_from_channels_at_back(im).view()
```
**Network architecture**

After trying a few network architectures, the best performing one is:

- Convolutional layer (64 filters, 5x5 kernel, batch normalization)
- Max-Pooling layer (2x2 kernel)
- Convolutional layer (32 filters, 5x5 kernel, batch normalization)
- Max-Pooling layer (2x2 kernel)
- Fully Connected layer (1024 outputs, batch normalization)
- Fully Connected layer (10 outputs)
The definitions of the various architectures can be found in `model.py`.
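As a sanity check on this stack, the tensor shapes can be traced with plain Python. This sketch assumes SAME padding for the convolutions and stride-2 pooling, which is not stated above and is an assumption:

```python
def conv_same(h, w, c_out):
    # 5x5 convolution with SAME padding and stride 1 keeps the spatial size
    return (h, w, c_out)

def max_pool_2x2(h, w, c):
    # 2x2 max-pooling with stride 2 halves the spatial size
    return (h // 2, w // 2, c)

shape = (28, 28, 1)                        # MNIST input image
shape = conv_same(shape[0], shape[1], 64)  # (28, 28, 64)
shape = max_pool_2x2(*shape)               # (14, 14, 64)
shape = conv_same(shape[0], shape[1], 32)  # (14, 14, 32)
shape = max_pool_2x2(*shape)               # (7, 7, 32)

# Number of inputs to the 1024-unit fully connected layer
flat = shape[0] * shape[1] * shape[2]
print(shape, flat)  # (7, 7, 32) 1568
```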
**Learning rate decay**

The experiments showed that decaying the learning rate helps significantly. The initial value of the learning rate is `0.001` and it then decreases with a rate of `0.9` every `10000` steps. Refer to `train.py` for more details on how this is implemented.
**Optimizer**

The employed optimizer is `tf.train.AdamOptimizer`, which proved better than `tf.train.RMSPropOptimizer`.
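The decay schedule above matches the style of `tf.train.exponential_decay`. A minimal sketch of the arithmetic, assuming the non-staircase form in which the exponent is the fractional number of completed decay periods:

```python
def decayed_learning_rate(step, initial_rate=0.001, decay_rate=0.9,
                          decay_steps=10000):
    # Exponential decay: initial_rate * decay_rate ** (step / decay_steps)
    return initial_rate * decay_rate ** (step / decay_steps)

print(decayed_learning_rate(0))      # 0.001
print(decayed_learning_rate(10000))  # 0.0009
print(decayed_learning_rate(20000))  # ~0.00081
```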
To run the training and evaluation, do the following:
**Data Collection**: In the terminal, run

```
(mnist)$ python digitrecognition/data_converter.py
```

which will download the MNIST data, if they are not already downloaded, load them and convert them to `tfrecords` files. The files are stored in the `data/` folder. Note that the data will be split into training, validation and testing sets.
**Training**: To train the model, run:

```
(mnist)$ python digitrecognition/train.py
```

which will initiate the training. Various arguments can be passed to the script:
```
--architecture ARCHITECTURE
                      The network architecture to use: baseline, ultimate,
                      ultimate_v2.
--batch_size BATCH_SIZE
                      Batch size.
--num_train_batches NUM_TRAIN_BATCHES
                      Number of batches to train (epochs).
--log_train_dir LOG_TRAIN_DIR
                      Directory with the training log data.
--initial_learning_rate INITIAL_LEARNING_RATE
                      Initial value of the learning rate.
--decay_steps DECAY_STEPS
                      Learning rate decay steps.
--decay_rate DECAY_RATE
                      Learning rate decay rate.
--eval_set EVAL_SET   The dataset to evaluate on: train, validation or test.
--num_samples NUM_SAMPLES
                      Number of samples to evaluate.
--log_eval_dir LOG_EVAL_DIR
                      Directory with the evaluation log data.
--momentum MOMENTUM   Optimizer momentum.
--optimization OPTIMIZATION
                      The optimization method to use. Either 'rms' or 'adam'.
--verbose [VERBOSE]   Print log in terminal.
```
These arguments and their default values are defined in `params.py`. Note that by default, the training log files are stored in `./log/train/`.
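The flags above can be wired up with `argparse`. A hypothetical sketch of how a few of them might be defined; the defaults shown here are illustrative assumptions, not necessarily the values in `params.py`:

```python
import argparse

# Hypothetical flag definitions; defaults are illustrative assumptions
parser = argparse.ArgumentParser(description='Digit recognition training options.')
parser.add_argument('--architecture', default='ultimate',
                    help='The network architecture to use: baseline, ultimate, ultimate_v2.')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size.')
parser.add_argument('--initial_learning_rate', type=float, default=0.001,
                    help='Initial value of the learning rate.')
parser.add_argument('--decay_steps', type=int, default=10000,
                    help='Learning rate decay steps.')
parser.add_argument('--decay_rate', type=float, default=0.9,
                    help='Learning rate decay rate.')

# Example: override two flags on the command line
args = parser.parse_args(['--architecture', 'baseline', '--batch_size', '128'])
print(args.architecture, args.batch_size)  # baseline 128
```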
**Evaluation**: To evaluate the model on the validation set, run:

```
(mnist)$ python digitrecognition/eval.py
```

Note that by default, the validation log files are stored in `./log/eval/`.
**Testing**: Testing can be performed as:

```
(mnist)$ python digitrecognition/eval.py --eval_set=test --log_eval_dir=./log/eval_test
```
**TensorBoard**: Training and validation can be run simultaneously, and the results can be observed through TensorBoard. Simply run:

```
(mnist)$ tensorboard --logdir=log
```
This makes it easy to explore the graph, data, loss evolution and accuracy on the validation set.
The `ultimate` architecture achieves the best accuracy on the test set, followed by `ultimate_v2` and `baseline`. Specifically, the accuracy on the test set is:
| Architecture | Accuracy | Steps |
|---|---|---|
| `ultimate` | 99.47% | ~210k |
| `ultimate_v2` | 99.37% | ~257k |
_Figure 1: The streaming accuracy of the `ultimate` network. The blue and orange lines report the accuracy on the validation and test sets._
_Figure 2: The streaming loss of the `ultimate` (blue) and `ultimate_v2` (green) networks._
_Figure 3: The graph of the `ultimate` network._