Stochastic Gradient Descent, Part I, Gradient descent on linear, quadratic and sinusoidal data
After seeing the visualisations of gradient descent in the FastAI course, I thought I'd try to create my own. I start by looking at gradient descent applied to linear, quadratic and sinusoidal data.
Introduction
At the end of Lecture 3 of the 2020 FastAI course and at the end of Lecture 2 of the 2018 FastAI course, there are visualisations of the gradient descent algorithm. I quite liked them, in particular the animation from the 2018 version, and I wanted to re-create them and try them on more complex examples.
The animations I created are available below. Note that in all animations, the orange dots represent the training data, and the blue line represents the model’s predictions. I will go through them and give my thoughts. At the end I describe some insights I gained by doing this.
Linear data
I created some linear data, y = a*x + b + noise, and then tried to use gradient descent to determine the coefficients.
This animation is representative of the various examples I tried for linear data. Gradient descent is quick to get very close to the data, but then the learning dramatically slows down and it takes many iterations to improve further. (Note that you have to pay close attention to notice that there is still learning going on throughout the whole video.) Clearly, there is some optimisation that can be done with the learning rate; I did try to create a cutoff point after which the learning rate gets bigger, but I am sure there are much better ways of doing things.
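As a rough illustration of the setup in this section, here is a minimal sketch of gradient descent on noisy linear data using NumPy. The true coefficients, noise level, range of x, learning rate and number of steps are all assumptions chosen for the example, not the values used for the animation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic linear data: y = a*x + b + noise (all values here are illustrative).
true_a, true_b = 3.0, 2.0
x = np.linspace(-1.0, 1.0, 50)
y = true_a * x + true_b + rng.normal(scale=0.1, size=x.shape)


def mse(a, b):
    """Mean squared error of the line a*x + b against the data."""
    return np.mean((a * x + b - y) ** 2)


a, b = 0.0, 0.0          # initial guess for the coefficients
learning_rate = 0.1
losses = []

for step in range(500):
    error = a * x + b - y
    # Gradients of the mean squared error with respect to a and b.
    grad_a = 2.0 * np.mean(error * x)
    grad_b = 2.0 * np.mean(error)
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    losses.append(mse(a, b))

print(f"a = {a:.3f}, b = {b:.3f}, final loss = {losses[-1]:.5f}")
```

Plotting `losses` against the step number is one way to see the pattern described above: a steep initial drop followed by a long, slow tail.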
Quadratic data
Next I created some quadratic data, y = a*x*x + b*x + c + noise.
The pattern here is very similar to the pattern for the linear case: gradient descent quickly reaches a good model, and then the learning dramatically slows down. This is not too surprising: although the fitted function is non-linear in x, it is linear in the coefficients, so this is still a linear regression problem once x*x and x are treated as separate features.
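To make the "separate features" point concrete, here is a small sketch that builds a feature matrix with columns x*x, x and 1, runs plain gradient descent on the three weights, and compares the result with NumPy's least-squares solver. The coefficients, noise level, learning rate and step count are assumptions picked for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

# Illustrative true coefficients (a, b, c) for y = a*x*x + b*x + c + noise.
true_w = np.array([2.0, -1.0, 0.5])
x = np.linspace(-1.0, 1.0, 50)

# Treat x*x, x and a constant as three separate features.
X = np.column_stack([x * x, x, np.ones_like(x)])
y = X @ true_w + rng.normal(scale=0.1, size=x.shape)

w = np.zeros(3)
learning_rate = 0.3

for step in range(3000):
    # Gradient of the mean squared error with respect to the weights.
    grad = 2.0 * X.T @ (X @ w - y) / len(y)
    w -= learning_rate * grad

# Closed-form least-squares solution, for comparison.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

print("gradient descent:", w)
print("least squares:   ", w_lstsq)
```

Because the model is linear in the weights, the loss has a single minimum, so gradient descent with a suitable learning rate ends up at the same answer as the least-squares solver.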
Sinusoidal data
Next I created some sinusoidal data, y = a*sin(b*x + c) + d. Things were more interesting here.
The first video shows you what happened when I chose a learning rate that was too large (but not so large that everything diverged to infinity):
Crazy, right! The model is oscillating back and forth, with the oscillations slowly getting larger with time. In Lecture 3 of the 2020 course, this behaviour is illustrated with the example of using gradient descent to minimise a quadratic function, but I never thought I would actually encounter this behaviour out in the wild.
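The quadratic example mentioned above is easy to reproduce. Here is a minimal sketch, assuming the function being minimised is f(w) = w*w; the helper name `descend` and the two learning rates are made up for the illustration.

```python
def descend(learning_rate, start=1.0, steps=10):
    """Run gradient descent on f(w) = w**2 and return every visited w."""
    w = start
    path = [w]
    for _ in range(steps):
        grad = 2.0 * w                 # derivative of w**2
        w = w - learning_rate * grad   # each step multiplies w by (1 - 2*learning_rate)
        path.append(w)
    return path


too_large = descend(learning_rate=1.05)  # multiplier is -1.1: flips sign and grows
sensible = descend(learning_rate=0.1)    # multiplier is 0.8: shrinks towards 0

print("too large:", [round(w, 3) for w in too_large])
print("sensible: ", [round(w, 3) for w in sensible])
```

With the larger learning rate, each step overshoots the minimum and lands further away on the other side, which is the same back-and-forth pattern seen in the video.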
This second video shows what happens when I choose a smaller learning rate:
No craziness here, but it does not converge to an appropriate solution. I think the explanation for this is that the algorithm has found a local minimum, and so it gets stuck and cannot improve. This is qualitatively different to the linear and quadratic cases: since those were both instances of linear regression, it is known from theory that the loss is convex, so with a suitable learning rate gradient descent will reach the minimum loss. This sinusoidal case cannot be re-written as a linear regression problem, so there is no automatic guarantee of there being only one minimum point; from this experimentation, it looks like there are multiple minimum points!
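One way to check the local-minimum idea is to look at the loss directly. The sketch below is a simplification rather than the experiment from the video: it assumes noise-free data with a = 1, c = 0, d = 0 and a true frequency b = 2, then scans the mean squared error over a grid of candidate values of b and counts the dips.

```python
import numpy as np

# Noise-free sinusoidal data with an assumed true frequency of 2.
x = np.linspace(-3.0, 3.0, 100)
true_b = 2.0
y = np.sin(true_b * x)

# Mean squared error as a function of the frequency b alone
# (amplitude, phase and offset are held at their true values).
b_grid = np.linspace(0.1, 8.0, 791)
loss = np.array([np.mean((np.sin(b * x) - y) ** 2) for b in b_grid])

# A grid point is a local minimum if its loss is below both neighbours.
is_min = (loss[1:-1] < loss[:-2]) & (loss[1:-1] < loss[2:])
local_minima = b_grid[1:-1][is_min]
best_b = b_grid[np.argmin(loss)]

print("local minima near b =", np.round(local_minima, 2))
print("lowest loss at b =", round(float(best_b), 2))
```

Even in this one-parameter slice the loss has several dips, so gradient descent started in the wrong basin would settle at a frequency other than the true one.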
Conclusion
I learnt various things by doing this experiment.
- The learning rate is very important! I had to play around with the learning rates to get things to work.
- The range of values in the training data seemed to have a big impact on which learning rates to use. I am not 100% sure about this, but I have read in places that it is important to normalise your data, and perhaps its effect on learning rates is a big reason.
- I learnt how to create animations in matplotlib! And also how to include video files in this blog. :D
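The point about the range of the training data can be probed with a small experiment. The following is a sketch under assumed values (x between 0 and 10, a learning rate of 0.1, 50 steps; the helper `final_loss` is hypothetical): the same learning rate is used once on the raw inputs and once on standardised inputs.

```python
import numpy as np

rng = np.random.default_rng(2)

# Linear data whose inputs span a fairly wide range (illustrative values).
x_raw = np.linspace(0.0, 10.0, 50)
y = 3.0 * x_raw + 2.0 + rng.normal(scale=0.1, size=x_raw.shape)


def final_loss(x, y, learning_rate=0.1, steps=50):
    """Fit y = a*x + b by gradient descent and return the final mean squared error."""
    a, b = 0.0, 0.0
    for _ in range(steps):
        error = a * x + b - y
        a -= learning_rate * 2.0 * np.mean(error * x)
        b -= learning_rate * 2.0 * np.mean(error)
    return np.mean((a * x + b - y) ** 2)


# Standardise the inputs to zero mean and unit variance.
x_std = (x_raw - x_raw.mean()) / x_raw.std()

loss_raw = final_loss(x_raw, y)
loss_std = final_loss(x_std, y)

print(f"raw inputs:          final loss = {loss_raw:.3e}")
print(f"standardised inputs: final loss = {loss_std:.3e}")
```

In this setup the raw inputs make the loss blow up, while the standardised inputs converge with exactly the same learning rate, which is at least consistent with the suspicion that normalisation matters largely through its effect on usable learning rates.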
There are various things I would like to try.
- The next thing I will try is using the same datasets, but seeing if I can fit a neural network to the data.
- Stochastic gradient descent. My hope is that it will avoid the local minimum problem in the sinusoidal case.
- Creating a web-app out of this, so you can easily experiment for yourselves.
The code
The code for this project is in this GitHub repository.