How to Handle Overfitting In Deep Learning Models
Deep learning is one of the most revolutionary technologies at present. It gives machines the ability to learn patterns directly from data. The key motivation for deep learning is to build algorithms loosely inspired by the way the human brain works.
To achieve this, we need to feed the models as much relevant data as possible. Many classical machine learning algorithms stop improving after a certain amount of data. Deep learning models, in contrast, usually keep improving as we feed them more data. When there is not enough data, however, deep learning models easily run into the overfitting issue.
That's why developing a well-generalized deep learning model is always a challenging problem. Deep learning models usually need a lot of data to reach a good score, but unfortunately, in some cases, we don't have enough data.
One of the most common problems when building neural networks is overfitting. An overfitted model does not generalize well because it is optimized only for the training dataset. In layman's terms, the model has memorized how to predict the target class for the training examples only.
Overfitting usually happens when we don't have enough data, or when we use complex architectures without regularization.
If we don't have sufficient data, the model fails to capture the general trend in the data. Instead, it tries to fit each and every data point in the training data and performs poorly on test/unseen data.
In some cases, a model overfits because we use a very complex neural network architecture without applying proper techniques to prevent overfitting.
So we need to learn smart techniques, both for preparing the data and for building and training the models, that keep overfitting under control. We are going to see these techniques later in this article.
In this article, you are going to learn how to handle overfitting in deep learning smartly, which helps you build accurate models that generalize well.
Before we dive in, let's see what you will learn in this article.
Deep Learning Introduction
A lot of cutting-edge research is happening in the deep learning field. New techniques, new model architectures and better-optimized models appear all the time. This keeps the field young, and it is growing very fast.
This growth is reasonable and expected. Two decades ago, we had problems like storing data, data scarcity, a lack of powerful processors, the high cost of processors, and so on.
At present, the scenario is completely different. Big data technologies allow us to store huge amounts of data easily. We have very powerful processors at a low cost, and we can apply neural networks to a wide range of problems.
Deep learning offers many different architectures for solving complex problems efficiently, such as:
- ANN (Artificial Neural Networks),
- CNN (Convolutional Neural Networks),
- RNN (Recurrent Neural Networks), etc.
These architectures give us the ability to classify images, detect objects, segment objects/images, forecast the future, and so on.
Deep Learning Applications
We have plenty of real-world applications of deep learning, which makes this field super hot.
You can see a few examples below:
- Auto Image Captioning
  - Automatic image captioning is the task where, given an image, the model generates a caption that describes the contents of the image.
- Self-driving cars
  - This is one of the greatest inventions: a car that can drive without a driver. It can distinguish different types of objects, road signals, people, etc., and drive without human intervention. Many companies are building these types of cars using deep learning.
- Healthcare Sector
  - Deep learning is also widely used in the medical field to assist doctors and patients. It can classify diseases, segment medical images, and help predict future health conditions.
- Voice assistant
  - Your favorite voice assistant uses deep learning every time it's used. Siri, for example, uses deep learning both to recognize your voice and to “learn” based on your queries.
If you haven't heard about overfitting or don't know how to handle it, don't worry. In the next couple of sections of this article, we are going to explain it in detail.
Different issues with deep learning models
In general, once we finish building a machine learning or deep learning model, it can face some common issues. It's worth investigating these issues before we deploy the model in the production environment. The two common issues are:
- Overfitting
- Underfitting
In this article, we are focusing only on how to handle the overfitting issue while building deep learning models.
Before we learn the difference between these modeling issues and how to handle them, we need to know about bias and variance.
Bias
Bias is the error that comes from the model's simplifying assumptions. Put simply, it is how far, on average, our predicted values are from the actual values. There are two cases:
- Low bias: the predictions are close to the actual target values.
- High bias: the predictions are far from the actual target values.
Variance
Variance describes how much the model's predictions change when it is trained on different samples of the data. In practice, a high-variance model performs well on the training data but does not generalize to new data. This shows up as a large gap between training and test performance/accuracy.
There are two cases:
- Low variance: a small difference between test accuracy and train accuracy.
- High variance: a large difference between test accuracy and train accuracy.
Bias variance tradeoff
Finding the right balance between the bias and variance of the model is called the bias-variance tradeoff. If our model is too simple and has very few parameters, it may have high bias and low variance.
On the other hand, if our model has a large number of parameters, it tends to have high variance and low bias. So we need to find a good balance that neither overfits nor underfits the data.
The well-known bias-variance diagram summarizes these combinations. From it, we can take away a few things:
- Low bias & low variance: good model
- Low bias & high variance: overfitted model
- High bias & low variance: underfitted model
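To make the tradeoff concrete, here is a small NumPy simulation. It is an illustrative sketch of our own, not the article's code. It assumes a made-up "true" function, sin(2πx), sampled at fixed inputs with Gaussian noise. It fits polynomials of a given degree to many independently drawn noisy training sets, then estimates squared bias and variance on a grid of test points. The function names, noise level and degrees are hypothetical choices.

```python
import numpy as np

def true_fn(x):
    """Assumed underlying relationship for this illustration only."""
    return np.sin(2 * np.pi * x)

def bias_variance(degree, n_datasets=200, n_points=20, noise=0.3, seed=0):
    """Estimate squared bias and variance of a polynomial fit of the given degree."""
    rng = np.random.default_rng(seed)
    x_train = np.linspace(0, 1, n_points)   # fixed inputs; only the noise changes
    x_test = np.linspace(0, 1, 100)
    preds = np.empty((n_datasets, x_test.size))
    for i in range(n_datasets):
        # A fresh noisy training set for every repetition
        y_train = true_fn(x_train) + rng.normal(0, noise, n_points)
        coefs = np.polyfit(x_train, y_train, degree)
        preds[i] = np.polyval(coefs, x_test)
    avg_pred = preds.mean(axis=0)
    # Bias^2: how far the average prediction is from the truth
    bias_sq = float(np.mean((avg_pred - true_fn(x_test)) ** 2))
    # Variance: how much predictions scatter around their average across datasets
    variance = float(np.mean(preds.var(axis=0)))
    return bias_sq, variance

if __name__ == "__main__":
    for degree in (1, 3, 9):
        b, v = bias_variance(degree)
        print(f"degree={degree}: bias^2={b:.4f}, variance={v:.4f}")
```

A straight line (degree 1) should come out with high bias and low variance. A degree-9 polynomial should show the reverse, which is the simple-versus-complex pattern described in the list above.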
By now we know all the pieces needed to understand underfitting and overfitting, so let's jump into them.
What is Underfitting
If the model shows high bias on both the train and test data, it is said to be underfitted. In simple terms, the model fails to capture the underlying trend of the data, so it performs poorly on both training and testing data.
As we said earlier, in this article we are focusing only on dealing with overfitting issues.
What is Overfitting
If the model shows low bias on the training data but high variance on the test data, it is overfitted. In simple terms, a model is overfitted if it learns the training data, including its noise, so closely that its performance on unseen data suffers.
The problem with overfitting is that the model gives high accuracy on the training data but performs very poorly on new data (high variance).
Overfitting example
In the overfitting example, the model is very complex. It tries to learn each and every data point in training and fails to generalize to unseen/test data.
This example showcases overfitting in regression models.
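The same effect can be reproduced numerically with a minimal NumPy sketch, assuming a noisy sine curve as the data (an illustration of our own, not the figure's original data). A degree-12 polynomial fitted to only 15 training points has nearly one coefficient per point. It drives the training error down while the error on held-out points stays much higher. The degrees, sample sizes and noise level are assumptions.

```python
import numpy as np

rng = np.random.default_rng(42)

def make_data(n, noise=0.3):
    """Noisy samples of an assumed true function sin(2*pi*x) on [0, 1]."""
    x = rng.uniform(0, 1, n)
    y = np.sin(2 * np.pi * x) + rng.normal(0, noise, n)
    return x, y

def mse(coefs, x, y):
    """Mean squared error of a fitted polynomial on (x, y)."""
    return float(np.mean((np.polyval(coefs, x) - y) ** 2))

x_train, y_train = make_data(15)    # small training set
x_test, y_test = make_data(500)     # large held-out test set

results = {}
for degree in (1, 3, 12):
    coefs = np.polyfit(x_train, y_train, degree)
    results[degree] = (mse(coefs, x_train, y_train), mse(coefs, x_test, y_test))
    print(f"degree={degree:2d}  train MSE={results[degree][0]:.3f}  "
          f"test MSE={results[degree][1]:.3f}")
```

The training error can only go down as the degree grows, because every higher-degree polynomial family contains the lower-degree ones. That is why the training score alone cannot reveal overfitting; the test error is what exposes it.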
What about classification problems? For classification models, we compare the train and test accuracy to decide whether a model is overfitted.
Have a look at the train and test results of the classification model in the table below:
| Model | Train Accuracy | Test Accuracy |
|---|---|---|
| Model 01 | 97% | 57% |
We can clearly see that the model performs well on the training data but is unable to perform well on the test data.
You can also see this as a gap between the training and test loss when you plot them.
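A rough rule of thumb for the classification case can be written as a tiny helper. This is our own illustrative heuristic, not a standard method. The function name `diagnose_fit` and both thresholds (a 10-point accuracy gap, 70% accuracy) are hypothetical and should be tuned to your problem.

```python
def diagnose_fit(train_acc, test_acc, gap_threshold=0.10, low_acc=0.70):
    """Heuristic label for a classifier from its train/test accuracy (0..1 scale).

    The thresholds are illustrative assumptions, not standard values.
    """
    if train_acc < low_acc and test_acc < low_acc:
        return "underfitting (high bias)"       # poor on both sets
    if train_acc - test_acc > gap_threshold:
        return "overfitting (high variance)"    # good on train, much worse on test
    return "good fit"

print(diagnose_fit(0.97, 0.57))  # Model 01 from the table: overfitting (high variance)
```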
Model with overfitting issue
Now we are going to build a deep learning model that suffers from overfitting. Later, we will apply different techniques to handle the overfitting issue.
After learning how to apply these techniques, we will rebuild the same model to show how they improve the deep learning model's performance.
Before that, let's quickly see the synopsis of the model flow.
Synopsis of the model we are going to build
Before we handle overfitting, we need to create a base model:
- First, we create a base model in order to showcase overfitting.
- To create the model and showcase the example, we first need data. We create the data using the make_moons() function.
- Then we fit a very basic model (without applying any techniques) on the newly created data points.
- Then we walk you through the different techniques to handle overfitting, with example code and graphs.
Data preparation
The make_moons() function from scikit-learn generates a toy binary classification dataset: two interleaving half circles, or "two moons".
Parameters:
- n_samples (int, optional, default=100): the total number of points generated.
- shuffle (bool, optional, default=True): whether to shuffle the samples.
- noise (float or None, default=None): the standard deviation of Gaussian noise added to the data.
- random_state (int, RandomState instance or None, default=None): controls the random number generation; pass an int for reproducible output.
Returns:
- X: array of shape [n_samples, 2], the generated samples.
- y: array of shape [n_samples], the integer labels (0 or 1) for class membership of each sample.
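Here is a minimal data-preparation sketch with make_moons(). The sample count, noise level, seed and 30/70 train/test split are our assumptions, since the original settings are not shown. A small training set makes it easy for a large network to memorize the data.

```python
from sklearn.datasets import make_moons

# Assumed settings: 100 points, Gaussian noise with std 0.2, fixed seed.
X, y = make_moons(n_samples=100, noise=0.2, random_state=1)

# Deliberately small training set (30 points) so a large network can memorize it.
# make_moons shuffles by default, so simple slicing gives a random split.
n_train = 30
X_train, X_test = X[:n_train], X[n_train:]
y_train, y_test = y[:n_train], y[n_train:]

print(X_train.shape, X_test.shape)   # (30, 2) (70, 2)
```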
Model Creation
Here, we create a Sequential model with two layers and compile it with the binary_crossentropy loss.
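Below is a sketch of such a base model in Keras (via TensorFlow). It is not the author's exact code. The layer sizes (500 ReLU units, 1 sigmoid output), the adam optimizer, the 4000-epoch default and the make_moons settings (100 points, noise 0.2, 30/70 split) are all assumptions. No regularization is applied.

```python
from sklearn.datasets import make_moons
from tensorflow import keras

# Assumed data setup: 100 noisy moon points, 30 for training, 70 for testing.
X, y = make_moons(n_samples=100, noise=0.2, random_state=1)
X_train, X_test = X[:30], X[30:]
y_train, y_test = y[:30], y[30:]

def build_base_model():
    """Two Dense layers and no regularization: big enough to memorize 30 points."""
    model = keras.Sequential([
        keras.Input(shape=(2,)),                      # 2 input features
        keras.layers.Dense(500, activation="relu"),   # assumed hidden width
        keras.layers.Dense(1, activation="sigmoid"),  # probability of class 1
    ])
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model

def train(model, epochs=4000):
    """Fit on the training split; the test split is only used for validation metrics."""
    return model.fit(X_train, y_train, validation_data=(X_test, y_test),
                     epochs=epochs, verbose=0)
```

After `history = train(build_base_model())`, plot `history.history["loss"]` against `history.history["val_loss"]`. With this setup, the validation loss typically drifts away from the training loss.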
Model Evaluation
Let’s see both training and validation loss in graphical representation.
We can clearly see that the model shows high variance on the test data.
By now you know that the deep learning model built above has an overfitting issue. Now let's learn how to handle such overfitting issues with different techniques.
Techniques to Handle Overfitting In Deep Learning
To handle overfitting, we can use any of the techniques below, but we should be aware of how and when to use each of them.
Let’s learn about these techniques one by one.
- Regularization
- Dropout
- Data Augmentation
- Early stopping
Regularization
Regularization is one of the best techniques to avoid overfitting. It works by adding a penalty to the loss function that grows with the size of the weights in the model. L2 regularization, for example, adds the sum of the squared weights multiplied by a small factor. A regularized neural network may not score as well on the training data, but it usually performs better on unseen data.
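To see what the penalty adds, here is a NumPy sketch of an L2-regularized binary cross-entropy loss: λ·Σw² on top of the data loss. This is the same form as the penalty added by Keras's `keras.regularizers.l2(λ)`. The helper names, example weights and λ values are illustrative.

```python
import numpy as np

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    y_prob = np.clip(y_prob, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob)))

def l2_penalty(weights, lam):
    """lam * (sum of squared entries of every weight matrix)."""
    return lam * sum(float(np.sum(np.square(w))) for w in weights)

def regularized_loss(y_true, y_prob, weights, lam=0.001):
    """Data loss plus L2 penalty: large weights now cost something."""
    return binary_crossentropy(y_true, y_prob) + l2_penalty(weights, lam)

if __name__ == "__main__":
    y_true = np.array([1, 0, 1, 1])
    y_prob = np.array([0.9, 0.2, 0.8, 0.6])
    small_w = [np.array([[0.1, -0.2], [0.05, 0.0]])]
    large_w = [np.array([[3.0, -4.0], [2.5, 1.0]])]
    # Same predictions, but the large weights are penalized much more
    print(regularized_loss(y_true, y_prob, small_w, lam=0.01))
    print(regularized_loss(y_true, y_prob, large_w, lam=0.01))
```

During training, the optimizer minimizes this combined loss. Shrinking the weights therefore becomes part of the objective, which discourages the very large weights an overfitted network tends to develop.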
You can see the example below:
Regularized model
In the above code, we are
- Creating an instance of Sequential class
- Adding the input layer with 2 input dimensions, 500 neurons, the relu activation function, and an L2 kernel regularizer
- Adding the output layer with 1 neuron, the sigmoid activation function, and an L2 kernel regularizer
- Compiling the model with the ‘binary_crossentropy’ loss, the adam optimizer, and the accuracy metric
- Finally, fitting the model on the training data, with the test data as validation data, for 4000 epochs.
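Here is a Keras sketch of the regularized model described in the steps above. The L2 factor of 0.001 is an assumed value that you would tune for your data. The make_moons settings (100 points, noise 0.2, 30/70 split) are also assumptions.

```python
from sklearn.datasets import make_moons
from tensorflow import keras

X, y = make_moons(n_samples=100, noise=0.2, random_state=1)
X_train, X_test = X[:30], X[30:]
y_train, y_test = y[:30], y[30:]

L2_FACTOR = 0.001  # assumed penalty strength

def build_regularized_model():
    model = keras.Sequential([
        keras.Input(shape=(2,)),
        # L2 penalty on the kernel (weight matrix) of each Dense layer
        keras.layers.Dense(500, activation="relu",
                           kernel_regularizer=keras.regularizers.l2(L2_FACTOR)),
        keras.layers.Dense(1, activation="sigmoid",
                           kernel_regularizer=keras.regularizers.l2(L2_FACTOR)),
    ])
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model

def train(model, epochs=4000):
    """Train on the training split and report validation metrics on the test split."""
    return model.fit(X_train, y_train, validation_data=(X_test, y_test),
                     epochs=epochs, verbose=0)
```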
Model Evaluation
We can see that the model no longer shows high variance with respect to the test data. By adding regularization, we made our model generalize better.
Dropout
Dropout simply drops neurons in a neural network. While a deep learning model is being trained, a random subset of its neurons is switched off and the network trains on the rest. Only the active neurons contribute to the output and receive gradients in that step.
On every training step, a new random subset of neurons is dropped according to the dropout ratio, while the rest stay active. At inference time, dropout is turned off and all neurons are used. This helps create a more robust model that performs well on unseen data.
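The mechanism fits in a few lines of NumPy. This sketch uses "inverted dropout", the variant Keras's Dropout layer uses. During training, the surviving activations are scaled by 1 / (1 - rate), so nothing needs rescaling at inference time. The function name and arguments are our own.

```python
import numpy as np

def dropout(activations, rate, training, rng):
    """Inverted dropout: zero a random fraction `rate` of units, rescale the rest."""
    if not training or rate == 0.0:
        return activations                  # dropout is disabled at inference time
    keep_prob = 1.0 - rate
    # A new random mask on every call, i.e. on every training step
    mask = rng.random(activations.shape) < keep_prob
    # Rescale survivors so the expected activation stays the same
    return activations * mask / keep_prob

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h = np.ones((2, 6))
    print(dropout(h, rate=0.5, training=True, rng=rng))   # about half 0.0, the rest 2.0
    print(dropout(h, rate=0.5, training=False, rng=rng))  # unchanged at inference
```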
You can see the example below
In the above code, we are
- Creating an instance of the Sequential class
- Adding an input layer with 2 input dimensions, 500 neurons, the relu activation function, and a 0.5 dropout ratio
- Adding a hidden layer with 128 neurons, the relu activation function, and a 0.25 dropout ratio
- Adding the output layer with 1 neuron and the sigmoid activation function
- Compiling the model with the ‘binary_crossentropy’ loss, the adam optimizer, and the accuracy metric
- Finally, fitting the model on the training data, with the test data as validation data, for 500 epochs.
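Here is a Keras sketch that follows the steps above. The make_moons settings (100 points, noise 0.2, 30/70 split) are assumptions, since the original data settings are not shown.

```python
from sklearn.datasets import make_moons
from tensorflow import keras

X, y = make_moons(n_samples=100, noise=0.2, random_state=1)
X_train, X_test = X[:30], X[30:]
y_train, y_test = y[:30], y[30:]

def build_dropout_model():
    model = keras.Sequential([
        keras.Input(shape=(2,)),
        keras.layers.Dense(500, activation="relu"),
        keras.layers.Dropout(0.5),                    # drop 50% of these units per step
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dropout(0.25),                   # drop 25% of the hidden units
        keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model

def train(model, epochs=500):
    """Dropout is active inside fit() and automatically disabled in predict()/evaluate()."""
    return model.fit(X_train, y_train, validation_data=(X_test, y_test),
                     epochs=epochs, verbose=0)
```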
Model Evaluation
Data Augmentation
We can prevent the model from overfitting by training it on more examples. Data augmentation increases the size of the dataset by applying small changes to the existing data.
Examples:
- Translations,
- Rotations,
- Changes in scale,
- Shearing,
- Horizontal (and in some cases, vertical) flips.
This technique is mostly used with image data and CNNs.
Data Augmentation code snippet
To generate augmented data, the Keras library provides a class called ImageDataGenerator.
You can see a demo of data augmentation below.
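The same kinds of transformations can also be expressed with Keras preprocessing layers, which recent TensorFlow documentation recommends over ImageDataGenerator for new code. Note that this sketch swaps in a different Keras API from the one named above. The image size, batch and transformation factors are illustrative assumptions, and shearing is left out.

```python
import numpy as np
from tensorflow import keras

# Random transformations: applied when called with training=True (or inside fit),
# passed through unchanged at inference.
augment = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),      # horizontal flips
    keras.layers.RandomRotation(0.1),           # rotations of up to ±10% of a full turn
    keras.layers.RandomTranslation(0.1, 0.1),   # shifts of up to 10% of height/width
    keras.layers.RandomZoom(0.1),               # changes in scale of up to ±10%
])

# Dummy batch of 8 RGB images, 32x32 pixels, values in [0, 1].
images = np.random.rand(8, 32, 32, 3).astype("float32")
augmented = augment(images, training=True)
print(augmented.shape)
```

In practice, you would place `augment` right after the input of a CNN. Every epoch then sees slightly different versions of the same images, while validation and prediction use the originals.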
Early Stopping
Early stopping is one of the most widely used techniques for overcoming overfitting in deep learning. Training for too many epochs can lead to overfitting the training dataset, and early stopping is a smart way to avoid that.
Early stopping monitors the model's performance on a validation or test set using a given metric. It stops training when that metric stops improving.
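The stopping rule itself is simple. Here is a plain-Python sketch of patience-based early stopping, the same idea behind Keras's EarlyStopping callback. The function name and the example loss values are made up for illustration.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a list of per-epoch validation losses.

    Training stops once val_loss has failed to improve by more than min_delta
    for `patience` consecutive epochs. Epochs are 0-based.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0                                   # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch       # stop here; best weights came from best_epoch
    return len(val_losses) - 1, best_epoch     # never triggered: ran all epochs

if __name__ == "__main__":
    losses = [0.70, 0.55, 0.48, 0.45, 0.46, 0.47, 0.49, 0.52]
    print(early_stopping_epoch(losses, patience=3))  # (6, 3)
```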
You can find the example below
In the above code, we are
- Creating an instance of the Sequential class.
- Adding an input layer with 2 input dimensions, 128 neurons, and the relu activation function.
- Adding the output layer with 1 neuron and the sigmoid activation function.
- Compiling the model with the ‘binary_crossentropy’ loss, the adam optimizer, and the accuracy metric.
- Creating a callback that monitors ‘val_loss’ and stops training when val_loss stops improving.
- Finally, fitting the model on the training data, with the test data as validation data, for up to 2000 epochs with the defined callbacks.
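Here is a Keras sketch of the model and callback above. The text does not state `patience=200` or `restore_best_weights=True`, so both are assumptions. The make_moons settings (100 points, noise 0.2, 30/70 split) are assumptions too.

```python
from sklearn.datasets import make_moons
from tensorflow import keras

X, y = make_moons(n_samples=100, noise=0.2, random_state=1)
X_train, X_test = X[:30], X[30:]
y_train, y_test = y[:30], y[30:]

def build_model():
    model = keras.Sequential([
        keras.Input(shape=(2,)),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model

def train_with_early_stopping(model, epochs=2000, patience=200):
    early_stop = keras.callbacks.EarlyStopping(
        monitor="val_loss",          # watch the validation loss
        mode="min",                  # lower is better
        patience=patience,           # assumed: epochs to wait without improvement
        restore_best_weights=True,   # assumed: roll back to the best epoch's weights
    )
    return model.fit(X_train, y_train, validation_data=(X_test, y_test),
                     epochs=epochs, verbose=0, callbacks=[early_stop])
```

Here the test split doubles as validation data, as in the steps above. With more data, a separate validation split keeps the final test score an unbiased estimate, since the test data no longer influences when training stops.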
Model Evaluation
Complete Code
Below is the complete code used in this article. You can also fork this code in our GitHub repository.
Conclusion
Each technique approaches the problem differently and tries to make the model more generalized and robust, so that it performs well on new data. There are different techniques to avoid overfitting, and you can also combine all of them in one model.
Don't limit yourself to these techniques for handling overfitting; you can try other new and advanced techniques while building deep learning models.
No single technique is always the best, so try the techniques and select the one that works best for your data.
Suggestions
- Classical approach: use early stopping and L2 regularization
- Modern approach: use early stopping and dropout, in addition to regularization.
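As a sketch of the modern-approach suggestion, the techniques can be combined in one Keras model: L2 weight penalties, dropout and an early-stopping callback. All hyperparameters here are assumptions to tune: the L2 factor, dropout rate, patience and epoch limit. So are the make_moons settings (100 points, noise 0.2, 30/70 split).

```python
from sklearn.datasets import make_moons
from tensorflow import keras

X, y = make_moons(n_samples=100, noise=0.2, random_state=1)
X_train, X_test = X[:30], X[30:]
y_train, y_test = y[:30], y[30:]

def build_combined_model(l2_factor=0.001, dropout_rate=0.5):
    model = keras.Sequential([
        keras.Input(shape=(2,)),
        keras.layers.Dense(500, activation="relu",
                           kernel_regularizer=keras.regularizers.l2(l2_factor)),  # L2 penalty
        keras.layers.Dropout(dropout_rate),                                     # dropout
        keras.layers.Dense(1, activation="sigmoid",
                           kernel_regularizer=keras.regularizers.l2(l2_factor)),
    ])
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model

def train_combined(model, epochs=2000, patience=200):
    # Early stopping on top of the regularized, dropout-equipped network
    early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", mode="min",
                                               patience=patience,
                                               restore_best_weights=True)
    return model.fit(X_train, y_train, validation_data=(X_test, y_test),
                     epochs=epochs, verbose=0, callbacks=[early_stop])
```

Each technique targets overfitting from a different side: the penalty keeps weights small, dropout prevents units from co-adapting, and early stopping caps training time. So they can be stacked rather than chosen between.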