Introduction to Transfer Learning and Neural Networks
In this article, we will look at how to repurpose a pre-trained image recognition neural network using Transfer Learning. Transfer Learning is the process of adapting a pre-trained neural network by training only the specific modules or layers associated with the problem at hand, which significantly reduces development effort and training time. We will describe how to repurpose a pre-trained neural network to classify flowers, and then repurpose it again to classify patient Chest X-Rays. The results are encouraging: retraining a model with thousands of images takes only a few minutes on a MacBook and, most importantly, the prediction accuracy is better than 90%.
Image recognition using deep neural networks has come a long way in the last couple of years. Amazon Web Services (AWS) and Google both provide general-purpose Application Programming Interfaces (APIs) that can recognise images across a large range of general categories; Google's offering is the Vision API. These APIs are general in nature, however. Consider trying to recognise something very specific, something that is not in the general set of examples, or something not defined to the required level of detail: you could be faced with the prospect of developing a custom solution. In our case, that means deep convolutional neural networks to classify flowers and Chest X-Rays.
To give you an idea of the scale, image recognition neural networks are very complex, with potentially thousands of neurons and anywhere from 25 to 150 or more layers. Designing, training and validating these networks is complex, expensive and time consuming, and requires deep expertise. But what if you could take the vision capability of a pre-trained neural network and reuse it to concentrate on your specific problem? This is called Transfer Learning: the process of training only a part of a pre-trained neural network. Suddenly, specific image classification using neural networks becomes practical.
Transfer Learning with TensorFlow
TensorFlow is an open source software library for high performance numerical computation. It was originally developed by Google and graciously provided to the open source community. This article focuses on image recognition and Transfer Learning by retraining a TensorFlow pre-trained convolutional neural network, following the process described in a TensorFlow tutorial. The pre-trained model provides the image feature extraction, and is based on the TensorFlow Hub image modules, including the Inception-V3 and MobileNetV1 modules. While all this sounds complicated, it is relatively easy to apply, provided you understand neural network fundamentals and can use UNIX shell commands.
The TensorFlow tutorial comes with a Python script that pre-processes the images and executes four thousand training steps. Each step chooses ten images at random from the training set and feeds them through the model to get predictions. These predictions are then compared against the actual labels (image classifications), and the final layer's weights are updated through a back-propagation learning process. This means only the final layer is being trained, not the many layers of the pre-trained network. As the training process continues, the accuracy improves as the network converges towards a solution. After all the training steps are completed, a final test accuracy evaluation is run on a set of images kept separate from the training images. This evaluation is an estimate of how the trained model will perform on new images. Once the model has been trained, it can be used to classify images whether or not they were presented to it during training.
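The training loop described above can be sketched in pure NumPy. The sketch below stands in for the pre-trained network with fixed random "bottleneck" feature vectors (in the real script these come from running each image through the frozen network once), and the data, learning rate and example counts are illustrative stand-ins, not the tutorial's actual values. Only the final softmax layer's weights are updated, which is the essence of this form of Transfer Learning.

```python
import numpy as np

rng = np.random.default_rng(0)

num_classes = 5      # e.g. the five flower classes
feature_dim = 2048   # Inception-V3 bottleneck size
num_examples = 200   # hypothetical stand-in training set

# Stand-in bottleneck features and labels; the frozen network's output.
features = rng.normal(size=(num_examples, feature_dim))
labels = rng.integers(0, num_classes, size=num_examples)

weights = np.zeros((feature_dim, num_classes))
bias = np.zeros(num_classes)
learning_rate = 0.01

for step in range(4000):                       # four thousand training steps
    batch = rng.choice(num_examples, size=10)  # ten images chosen at random
    x, y = features[batch], labels[batch]

    # Forward pass through the final layer only: logits -> softmax.
    logits = x @ weights + bias
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

    # Cross-entropy gradient; back-propagation stops at the final layer.
    grad = probs.copy()
    grad[np.arange(10), y] -= 1.0
    weights -= learning_rate * (x.T @ grad) / 10
    bias -= learning_rate * grad.mean(axis=0)

predictions = (features @ weights + bias).argmax(axis=1)
train_accuracy = (predictions == labels).mean()
```

Because the frozen network has already done the hard work of turning pixels into discriminative features, this final-layer fit is cheap, which is why retraining completes in minutes rather than days.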
To demonstrate how TensorFlow can be applied, the model was retrained with two separate datasets to demonstrate two distinct applications:
- Flower Classification for five types of flowers
- Chest X-Ray Classification as Normal or Pneumonia
The flower dataset consists of five classes of flowers: Daisy, Dandelion, Rose, Sunflower and Tulip. There are approximately four thousand images in the dataset. The images are very diverse with different backgrounds and views. It took approximately eight minutes to retrain the model using a MacBook. The following table illustrates a small sample of a diverse range of flower images for each class of flower.
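The retraining script infers the class labels from sub-folder names: one folder per class under a single image directory. A minimal sketch of that layout, using a hypothetical `flower_photos` root (the folder names follow the tutorial's flower dataset):

```python
from pathlib import Path

# One sub-folder per class; the folder name becomes the image label.
root = Path("flower_photos")
class_labels = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
for label in class_labels:
    (root / label).mkdir(parents=True, exist_ok=True)

# The tutorial's retraining script is then pointed at the root folder,
# e.g. from the shell:
#   python retrain.py --image_dir ./flower_photos
created = sorted(p.name for p in root.iterdir())
```

Swapping in a different problem domain is then just a matter of populating a different directory tree and re-running the script with a different `--image_dir`.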
The accuracy of the model was shown to be:
- Train accuracy = 97.0%
- Validation accuracy = 91.0% (N=100)
- Test accuracy = 92.3% (N=362)
The results indicate that the model performs very well in classifying the classes of flowers it was trained on. If a flower from an unknown class, e.g. a Daffodil, is presented to the trained model, the outcome is unpredictable.
Chest X-Ray Classification
The model was then retrained for Chest X-Rays simply by pointing the Python script at a different dataset location. The goal of the Chest X-Ray application is to classify a patient's Chest X-Ray as showing Pneumonia in the lungs or as a Normal result. The Chest X-Ray images were sourced from the Kaggle Chest X-ray dataset, and were selected and classified (labelled) as part of a study conducted at the Guangzhou Women and Children's Medical Center, China. There are 5,863 Chest X-Ray images split across the two categories, Normal and Pneumonia. The following images are indicative of the dataset.
It took approximately seven minutes to retrain the model using the training dataset. The accuracy of the model was:
- Train accuracy = 96.0%
- Validation accuracy = 99.0% (N=100)
- Test accuracy = 95.7% (N=531)
The results indicate that the model performs extremely well in classifying Pneumonia. However, the model is constrained to classifying a Chest X-Ray as either Normal or Pneumonia; it cannot recognise anything else, including other chest conditions.
Image recognition is complex and until recently beyond the capability of many software development and system integration organisations. The ability to repurpose image recognition neural networks using Transfer Learning has many benefits, including:
- Greatly reduced model development time, effort and expertise
- Greatly reduced computations and time required to train a model
- Opens up possibilities previously considered too hard or too expensive, or not considered at all.
Some of the challenges moving forward include data consistency, systems integration and system integrity. Having enough of the right data, classified (labelled) correctly, is key to developing a working model: a model's accuracy depends on the data it is trained on. Once a model is developed and trained, it needs to be deployed in a user-friendly, accessible manner, i.e., the model needs to be integrated into a workable system. This article has presented closed classification examples, i.e., a specific set of classes; images that fall outside the trained classification set have no meaning to the model. In the real world this is a challenge. For example, a Chest X-Ray indicating a serious condition other than Pneumonia may be completely missed or, worse, classified as Normal, when we would prefer the answer to be Unknown. Serious attention must be paid to the overall system integrity.
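One partial mitigation for the closed-set problem, sketched below under the assumption that the model exposes per-class softmax probabilities, is to reject any prediction whose top probability falls below a threshold and report Unknown instead. The threshold value here is illustrative and would need tuning on validation data; softmax confidence is also an imperfect signal for out-of-distribution inputs, so this is a safeguard rather than a fix.

```python
import numpy as np

LABELS = ["Normal", "Pneumonia"]
THRESHOLD = 0.9  # illustrative; tune on validation data

def classify(probs, threshold=THRESHOLD):
    """Return a label, or "Unknown" when the model is not confident."""
    top = int(np.argmax(probs))
    if probs[top] < threshold:
        return "Unknown"
    return LABELS[top]

print(classify(np.array([0.97, 0.03])))  # confident -> "Normal"
print(classify(np.array([0.55, 0.45])))  # ambiguous -> "Unknown"
```

In a clinical setting, anything classified as Unknown would be routed to a human reviewer rather than silently reported as Normal.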
Image recognition using neural networks is a complex subject. Yet it has come a long way and is actively being used in real applications. It will be interesting to see how it progresses and how it will be used in the future.
Reference: Kermany, Daniel; Zhang, Kang; Goldbaum, Michael; et al. Identifying Medical Diagnoses and Treatable Diseases by Image-Based Deep Learning. Cell, 2018. https://www.cell.com/cell/fulltext/S0092-8674(18)30154-5