{"id":1287,"date":"2018-11-03T19:19:37","date_gmt":"2018-11-03T19:19:37","guid":{"rendered":"http:\/\/www.philippeadjiman.com\/blog\/?p=1287"},"modified":"2025-07-18T13:15:13","modified_gmt":"2025-07-18T13:15:13","slug":"visualising-sgd-with-momentum-adam-and-learning-rate-annealing","status":"publish","type":"post","link":"https:\/\/philippeadjiman.com\/blog\/2018\/11\/03\/visualising-sgd-with-momentum-adam-and-learning-rate-annealing\/","title":{"rendered":"Visualising SGD with Momentum, Adam and Learning Rate Annealing"},"content":{"rendered":"<span style=\"font-size: revert;\"><\/span>\n<div class=\"entry-content\">\n\n\n<figure class=\"wp-block-image size-full\"><img data-recalc-dims=\"1\" decoding=\"async\" width=\"600\" height=\"298\" loading=\"lazy\" src=\"https:\/\/i0.wp.com\/philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/11\/full_anim-1.gif?resize=600%2C298&#038;ssl=1\" alt=\"\" class=\"wp-image-1834\" \/><\/figure>\n\n\n<strong>[Full code on my GitHub<span>\u00a0<\/span><a href=\"https:\/\/github.com\/padjiman\/blog-posts-code\/blob\/master\/sgd_viz\/sgd_and_optims_viz.ipynb\" target=\"_blank\" rel=\"noopener\">here<\/a>. To view it on mobile, once you land on GitHub, click on \u201cDesktop Version\u201d.]<\/strong>\n\nAt the very heart of the model training procedure of almost every modern machine learning or deep learning algorithm applied to large enough data sets, you\u2019ll find Stochastic Gradient Descent (SGD).\n\nThe best part of SGD is its simplicity. As<span>\u00a0<\/span><a href=\"https:\/\/twitter.com\/fchollet\/status\/951906139632840704\">Francois Chollet<\/a><span>\u00a0<\/span>would say, it is made of a small set of high school-level ideas put together. 
But that does not make it any less powerful or beautiful.\n\nIn this post we\u2019ll implement SGD from scratch, along with some optimizations around it such as Momentum, Adam and learning rate annealing, and we\u2019ll apply them to very simple generated toy data in order to compare them visually with animated graphs in Python. In the post we\u2019ll only show snippets of a subset of the code; check<span>\u00a0<\/span><a href=\"https:\/\/github.com\/padjiman\/blog-posts-code\/blob\/master\/sgd_viz\/sgd_and_optims_viz.ipynb\" target=\"_blank\" rel=\"noopener\">here<\/a><span>\u00a0<\/span>for the full code.\n<h2>Vanilla SGD<\/h2>\nFirst we generate some data. We\u2019ll deliberately use the simplest and almost smallest data set possible, by generating 20 random points from a linear function ax+b.\n<div>\n<div id=\"highlighter_102411\" class=\"syntaxhighlighter nogutter  python\">\n<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\">\n<tbody>\n<tr>\n<td class=\"code\">\n<div class=\"container\">\n<div class=\"line alt1\"><code class=\"python keyword\">import<\/code> <code class=\"python plain\">numpy as np<\/code><\/div>\n<div class=\"line alt2\"><code class=\"python keyword\">import<\/code> <code class=\"python plain\">matplotlib.pyplot as plt<\/code><\/div>\n<div class=\"line number1 index0 alt2\"><code class=\"python plain\">np.random.seed(<\/code><code class=\"python value\">7<\/code><code class=\"python plain\">)<\/code><\/div>\n<div class=\"line number2 index1 alt1\"><code class=\"python plain\">a_real <\/code><code class=\"python keyword\">=<\/code> <code class=\"python value\">1.5<\/code><\/div>\n<div class=\"line number3 index2 alt2\"><code class=\"python plain\">b_real <\/code><code class=\"python keyword\">=<\/code> <code class=\"python keyword\">-<\/code><code class=\"python value\">28<\/code><\/div>\n<div class=\"line number4 index3 alt1\"><code class=\"python plain\">xlim <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">[<\/code><code class=\"python keyword\">-<\/code><code class=\"python value\">10<\/code><code class=\"python plain\">,<\/code><code class=\"python value\">10<\/code><code class=\"python plain\">]<\/code><\/div>\n<div class=\"line number5 index4 
alt2\"><code class=\"python plain\">x_gen <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">np.random.randint(low<\/code><code class=\"python keyword\">=<\/code><code class=\"python plain\">xlim[<\/code><code class=\"python value\">0<\/code><code class=\"python plain\">], high<\/code><code class=\"python keyword\">=<\/code><code class=\"python plain\">xlim[<\/code><code class=\"python value\">1<\/code><code class=\"python plain\">], size<\/code><code class=\"python keyword\">=<\/code><code class=\"python value\">20<\/code><code class=\"python plain\">)<\/code><\/div>\n<div class=\"line number6 index5 alt1\"><code class=\"python plain\">y_real <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">a_real<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">x_gen <\/code><code class=\"python keyword\">+<\/code> <code class=\"python plain\">b_real<\/code><\/div>\n<div class=\"line number7 index6 alt2\"><code class=\"python plain\">plt.plot(x_gen, y_real, <\/code><code class=\"python string\">'bo'<\/code><code class=\"python plain\">)<\/code><\/div>\n<\/div><\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<\/div>\n<\/div>\n<a href=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/10\/simple_data.png\"><img loading=\"lazy\" decoding=\"async\" class=\"aligncenter wp-image-1293\" src=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/10\/simple_data.png?resize=377%2C251\" alt=\"\" width=\"377\" height=\"251\" srcset=\"https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/10\/simple_data.png?w=432 432w, https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/10\/simple_data.png?resize=300%2C200 300w\" sizes=\"auto, (max-width: 377px) 100vw, 377px\" data-recalc-dims=\"1\" \/><\/a>\n\nSo we\u2019ll start with some initial a and b (say a=0 and b=0), and the goal of SGD is to find, on its own, the real a and b (a=1.5 
and b=-28 in our example) which were used to generate those few data points. To do so, we simply need to minimize a cost function on the data, in our case\u00a0\\(\\sum_{(x,y) \\in \\text{dataset}} (ax+b-y)^2\\). 
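To make this cost concrete, here is a minimal sketch (our own, not taken from the notebook). It regenerates the same toy data, evaluates this sum of squared errors, and computes its analytical gradient with respect to a and b. The helper names cost and grad and the step size are hypothetical choices for illustration. The last lines check that a small step against the gradient lowers the cost, which is the property SGD relies on.

```python
import numpy as np

# Regenerate the same toy data as above (seed 7, 20 integer x values in [-10, 10)).
np.random.seed(7)
a_real, b_real = 1.5, -28
x_gen = np.random.randint(low=-10, high=10, size=20)
y_real = a_real * x_gen + b_real

def cost(a, b, x=x_gen, y=y_real):
    # Sum of squared errors over the whole data set.
    return np.sum((a * x + b - y) ** 2)

def grad(a, b, x=x_gen, y=y_real):
    # Partial derivatives of the cost with respect to a and b.
    residual = a * x + b - y
    return np.sum(2 * x * residual), np.sum(2 * residual)

a, b = 0.0, 0.0
ga, gb = grad(a, b)
step = 1e-4  # small step size, hand-picked for this check
a_new, b_new = a - step * ga, b - step * gb

print(cost(a_real, b_real))            # zero at the true parameters
print(cost(a, b), cost(a_new, b_new))  # the cost drops after one step against the gradient
```

Because the toy data has no noise, the cost is exactly zero at the true parameters.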
SGD achieves that by iteratively following the negative of the gradient (negative because the gradient points in the direction of steepest increase of the function, and we\u2019re looking for the minimum of the cost function).\n\nSo basically, the vanilla SGD parameter update is simply:\n<div>\n<div id=\"highlighter_321389\" class=\"syntaxhighlighter nogutter  python\">\n<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\">\n<tbody>\n<tr>\n<td class=\"code\">\n<div class=\"container\">\n<div class=\"line number1 index0 alt2\"><code class=\"python plain\">param <\/code><code class=\"python keyword\">+<\/code><code class=\"python keyword\">=<\/code> <code class=\"python keyword\">-<\/code><code class=\"python plain\">lr<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">dx<\/code><\/div>\n<\/div><\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<\/div>\n<\/div>\nwith lr being the learning rate, and dx being the gradient of the cost function with respect to the parameter being updated (in our case, a or b).\n\nHow do we compute dx? If our hypothesis function were a deep neural network, you would apply the chain rule repeatedly (a.k.a. backpropagation), e.g. via<span>\u00a0<\/span><a href=\"https:\/\/pytorch.org\/docs\/stable\/autograd.html\">PyTorch\u2019s autograd<\/a>, but in our case we can simply compute the analytical gradient of the cost function w.r.t. 
a and b, or just use Wolfram Alpha (like<span>\u00a0<\/span><a href=\"http:\/\/www.wolframalpha.com\/input\/?i=D%5B(b+%2B+a+x+-+y)%5E2,+a%5D\">this<\/a><span>\u00a0<\/span>) if you\u2019re lazy.\n\nThis leads to a full implementation of SGD (for our example) in fewer than 10 lines of code:\n<div>\n<div id=\"highlighter_957465\" class=\"syntaxhighlighter nogutter  python\">\n<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\">\n<tbody>\n<tr>\n<td class=\"code\">\n<div class=\"container\">\n<div class=\"line number1 index0 alt2\"><code class=\"python keyword\">def<\/code> <code class=\"python plain\">sgd(X,Y, a, b, lr, epochs<\/code><code class=\"python keyword\">=<\/code><code class=\"python value\">1<\/code><code class=\"python plain\">):<\/code><\/div>\n<div class=\"line number2 index1 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python keyword\">for<\/code> <code class=\"python plain\">e <\/code><code class=\"python keyword\">in<\/code> <code class=\"python functions\">range<\/code><code class=\"python plain\">(epochs):<\/code><\/div>\n<div class=\"line number3 index2 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python keyword\">for<\/code> <code class=\"python plain\">x_ , y_ <\/code><code class=\"python keyword\">in<\/code> <code class=\"python functions\">zip<\/code><code class=\"python plain\">(X,Y):<\/code><\/div>\n<div class=\"line number4 index3 alt1\"><code class=\"python 
spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">a <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">a <\/code><code class=\"python keyword\">-<\/code> <code class=\"python plain\">lr<\/code><code class=\"python keyword\">*<\/code><code class=\"python value\">2<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">x_<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">(a<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">x_<\/code><code class=\"python keyword\">+<\/code><code class=\"python plain\">b<\/code><code class=\"python keyword\">-<\/code><code class=\"python plain\">y_)<\/code><\/div>\n<div class=\"line number5 index4 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">b <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">b <\/code><code class=\"python keyword\">-<\/code> <code class=\"python plain\">lr<\/code><code class=\"python keyword\">*<\/code><code class=\"python value\">2<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">(a<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">x_<\/code><code class=\"python keyword\">+<\/code><code class=\"python plain\">b<\/code><code class=\"python keyword\">-<\/code><code class=\"python plain\">y_)<\/code><\/div>\n<div class=\"line number6 index5 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python keyword\">return<\/code> <code class=\"python plain\">a,b<\/code><\/div>\n<div class=\"line number7 index6 alt2\"><code class=\"python plain\">a,b <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">sgd(x_gen,y_real,<\/code><code class=\"python value\">0<\/code><code 
class=\"python plain\">,<\/code><code class=\"python value\">0<\/code><code class=\"python plain\">,<\/code><code class=\"python value\">0.001<\/code><code class=\"python plain\">,epochs<\/code><code class=\"python keyword\">=<\/code><code class=\"python value\">150<\/code><code class=\"python plain\">) <\/code><\/div>\n<div class=\"line number8 index7 alt1\"><code class=\"python functions\">print<\/code><code class=\"python plain\">(<\/code><code class=\"python string\">\"a={:.3f} , b={:.3f}\"<\/code><code class=\"python plain\">.<\/code><code class=\"python functions\">format<\/code><code class=\"python plain\">(a,b))<\/code><\/div>\n<\/div><\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<\/div>\n<\/div>\nwhich prints:\n<div>\n<div id=\"highlighter_834748\" class=\"syntaxhighlighter nogutter  plain\">\n<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\">\n<tbody>\n<tr>\n<td class=\"code\">\n<div class=\"container\">\n<div class=\"line number1 index0 alt2\"><code class=\"plain plain\">a=1.501 , b=-27.913<\/code><\/div>\n<\/div><\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<\/div>\n<\/div>\nNot far from the real a=1.5 and b=-28.\n\nNote that we could make this code much more efficient by vectorizing it, but we deliberately keep it naive so that it is easy to see how simple it is (and so that we can add metrics to monitor and visualize the gradient steps, cf. a later section). 
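As a rough illustration of that vectorized alternative, here is a sketch that assumes we are willing to switch from one point at a time to mini-batches. Each update uses the average gradient over a small batch, computed with NumPy array operations. Both partial derivatives are computed from the same residual before either parameter is updated. The function name sgd_minibatch, the shuffling seed and all hyperparameter values are our own assumptions, not the notebook's. Because the gradient is averaged rather than summed, lr is not directly comparable to the one used in the loop above.

```python
import numpy as np

def sgd_minibatch(X, Y, a, b, lr, epochs=1, batch_size=5, seed=0):
    # Vectorized variant: gradients of the squared error are averaged over a
    # mini-batch of points instead of being applied one point at a time.
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    rng = np.random.default_rng(seed)
    n = len(X)
    for _ in range(epochs):
        order = rng.permutation(n)  # reshuffle the points at every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            x_b, y_b = X[idx], Y[idx]
            residual = a * x_b + b - y_b  # shared by both partial derivatives
            a -= lr * np.mean(2 * x_b * residual)
            b -= lr * np.mean(2 * residual)
    return a, b

# Same toy data as above.
np.random.seed(7)
x_gen = np.random.randint(low=-10, high=10, size=20)
y_real = 1.5 * x_gen - 28

a, b = sgd_minibatch(x_gen, y_real, 0.0, 0.0, lr=0.005, epochs=2000, batch_size=5)
print('a={:.3f} , b={:.3f}'.format(a, b))
```

With batch_size=len(X), this reduces to plain full-batch gradient descent. On this noiseless toy data, every point is fitted exactly at a=1.5, b=-28, so the run above should end very close to those values.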
In the sgd function above, we process one data point at a time (batch size = 1) and go over all the data points multiple times (as many times as specified by the epochs parameter).\n\nWe can keep track of the a and b updates after each epoch and animate the evolution (see the section named \u201cSimple standalone Animation code\u201d in the<span>\u00a0<\/span><a href=\"https:\/\/github.com\/padjiman\/blog-posts-code\/blob\/master\/sgd_viz\/sgd_and_optims_viz.ipynb\" target=\"_blank\" rel=\"noopener\">notebook<\/a><span>\u00a0<\/span>to see how to generate such an animation):\n\n<a href=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/10\/vanilla_sgd_quicktime.gif\"><img loading=\"lazy\" decoding=\"async\" class=\"aligncenter size-full wp-image-1296\" src=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/10\/vanilla_sgd_quicktime.gif?resize=600%2C320\" alt=\"\" width=\"600\" height=\"320\" data-recalc-dims=\"1\" \/><\/a>\n\nNote how a and b converge slowly but surely to the real values (a = 1.5 and b = -28).\n<h2>Momentum<\/h2>\nSo, vanilla gradient descent converges surely towards the optimum but, as we saw above, rather slowly, taking small steps (in the right direction) at every iteration. Moreover, here we have a nice and easy convex cost function, but in the presence of ravines (areas where the surface curves much more steeply in one dimension than in another), SGD becomes even slower, oscillating across the slopes of the ravine while making only hesitant progress toward the optimum.\n\nTo improve on that, the momentum update, inspired by the notion of velocity in physics, takes advantage of the history of previous gradient directions: it takes more aggressive steps when the gradient direction is stable, and slows down when the gradient starts pointing in different directions.\n\nSo, instead of the vanilla SGD update rule (param += -lr*dx), the momentum update replaces dx with a decaying average of previous gradients. 
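To see why such a decaying average helps in a ravine, here is a small sketch on a hypothetical elongated quadratic bowl, f(p) = 0.5 * (p0**2 + 50 * p1**2). This bowl is 50 times more curved along p1 than along p0. The function descend, the starting point, the tolerance and the step sizes are hand-picked assumptions for this illustration, not values from the notebook.

Vanilla gradient descent must keep its learning rate below 2/50 to stay stable along the steep axis, so it crawls along the flat one. The momentum variant uses the same averaging scheme as the sgd_momentum code in this post. That scheme damps the back-and-forth along the steep axis, which lets it afford a larger step.

```python
import numpy as np

# Hypothetical ravine: curvature 1 along p0 and 50 along p1.
CURVATURE = np.array([1.0, 50.0])

def gradient(p):
    # Gradient of f(p) = 0.5 * sum(CURVATURE * p**2).
    return CURVATURE * p

def descend(lr, beta=0.0, tol=1e-3, max_steps=10_000):
    # beta=0 gives vanilla gradient descent; beta > 0 replaces the gradient
    # with the decaying average avg = beta*avg + (1-beta)*grad.
    p = np.array([10.0, 1.0])
    avg = np.zeros_like(p)
    for step in range(1, max_steps + 1):
        avg = beta * avg + (1.0 - beta) * gradient(p)
        p = p - lr * avg
        # Stop once every coordinate is within tol of the minimum at (0, 0).
        if np.allclose(p, 0.0, rtol=0.0, atol=tol):
            return step
    return max_steps

vanilla_steps = descend(lr=0.035)            # must stay below 2/50 = 0.04
momentum_steps = descend(lr=0.15, beta=0.6)
print(vanilla_steps, momentum_steps)
```

With these particular settings, the momentum run should need far fewer steps than the vanilla one. The gap depends heavily on the hand-picked lr and beta: a poorly chosen pair can just as easily make momentum oscillate or diverge.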
The average is controlled by a parameter beta: at each step, the new average is a linear interpolation between the previous average (weighted by beta) and the current gradient (weighted by 1-beta). This gives the following simple code:\n<div>\n<div id=\"highlighter_402603\" class=\"syntaxhighlighter nogutter  python\">\n<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\">\n<tbody>\n<tr>\n<td class=\"code\">\n<div class=\"container\">\n<div class=\"line number1 index0 alt2\"><code class=\"python keyword\">def<\/code> <code class=\"python plain\">sgd_momentum(X,Y, a, b, lr, beta, epochs<\/code><code class=\"python keyword\">=<\/code><code class=\"python value\">1<\/code><code class=\"python plain\">):<\/code><\/div>\n<div class=\"line number2 index1 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">avg_ga <\/code><code class=\"python keyword\">=<\/code> <code class=\"python value\">0<\/code><\/div>\n<div class=\"line number3 index2 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">avg_gb <\/code><code class=\"python keyword\">=<\/code> <code class=\"python value\">0<\/code><\/div>\n<div class=\"line number4 index3 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python keyword\">for<\/code> <code class=\"python plain\">e <\/code><code class=\"python 
keyword\">in<\/code> <code class=\"python functions\">range<\/code><code class=\"python plain\">(epochs):<\/code><\/div>\n<div class=\"line number5 index4 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python keyword\">for<\/code> <code class=\"python plain\">x_ , y_ <\/code><code class=\"python keyword\">in<\/code> <code class=\"python functions\">zip<\/code><code class=\"python plain\">(X,Y):<\/code><\/div>\n<div class=\"line number6 index5 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">de_da <\/code><code class=\"python keyword\">=<\/code> <code class=\"python value\">2<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">x_<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">(a<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">x_<\/code><code class=\"python keyword\">+<\/code><code class=\"python plain\">b<\/code><code class=\"python keyword\">-<\/code><code class=\"python plain\">y_)<\/code><\/div>\n<div class=\"line number7 index6 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">de_db <\/code><code class=\"python keyword\">=<\/code> <code class=\"python value\">2<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">(a<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">x_<\/code><code class=\"python keyword\">+<\/code><code class=\"python plain\">b<\/code><code class=\"python keyword\">-<\/code><code class=\"python plain\">y_) <\/code><\/div>\n<div class=\"line number8 index7 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">avg_ga <\/code><code class=\"python keyword\">=<\/code> <code 
class=\"python plain\">avg_ga<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">beta <\/code><code class=\"python keyword\">+<\/code> <code class=\"python plain\">(<\/code><code class=\"python value\">1.0<\/code><code class=\"python keyword\">-<\/code><code class=\"python plain\">beta)<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">de_da<\/code><\/div>\n<div class=\"line number9 index8 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">avg_gb <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">avg_gb<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">beta <\/code><code class=\"python keyword\">+<\/code> <code class=\"python plain\">(<\/code><code class=\"python value\">1.0<\/code><code class=\"python keyword\">-<\/code><code class=\"python plain\">beta)<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">de_db<\/code><\/div>\n<div class=\"line number10 index9 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">a <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">a <\/code><code class=\"python keyword\">-<\/code> <code class=\"python plain\">lr<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">avg_ga<\/code><\/div>\n<div class=\"line number11 index10 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">b <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">b <\/code><code class=\"python keyword\">-<\/code> <code class=\"python plain\">lr<\/code><code class=\"python keyword\">*<\/code><code class=\"python plain\">avg_gb\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0 <\/code><\/div>\n<div 
class=\"line number12 index11 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python keyword\">return<\/code> <code class=\"python plain\">a,b<\/code><\/div>\n<\/div><\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<\/div>\n<\/div>\nIt is interesting to compare how the gradient updates evolve for each method, to check whether we see the expected smoother evolution from one iteration to the next. Here is how the vanilla SGD and momentum gradient updates compare on the first learned parameter (a):\n\n<a href=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/10\/momemtum_a.png\"><img loading=\"lazy\" decoding=\"async\" class=\"aligncenter size-full wp-image-1297\" src=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/10\/momemtum_a.png?resize=604%2C318\" alt=\"\" width=\"604\" height=\"318\" srcset=\"https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/10\/momemtum_a.png?w=735 735w, https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/10\/momemtum_a.png?resize=300%2C158 300w\" sizes=\"auto, (max-width: 604px) 100vw, 604px\" data-recalc-dims=\"1\" \/><\/a>\n\nAnd on the second one (b):\n\n<a href=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/10\/momemtum_b.png\"><img loading=\"lazy\" decoding=\"async\" class=\"aligncenter size-full wp-image-1298\" src=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/10\/momemtum_b.png?resize=604%2C324\" alt=\"\" width=\"604\" height=\"324\" srcset=\"https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/10\/momemtum_b.png?w=721 721w, https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/10\/momemtum_b.png?resize=300%2C161 300w\" sizes=\"auto, (max-width: 604px) 100vw, 604px\" data-recalc-dims=\"1\" \/><\/a>\n\nYeah, I know it looks a bit like a\u00a0<img loading=\"lazy\" 
decoding=\"async\" class=\"\" src=\"https:\/\/i0.wp.com\/www.eram.fr\/media\/catalog\/product\/cache\/1\/image\/9df78eab33525d08d6e5fb8d27136e95\/W\/W\/WWWERAM_10383960387_0.jpg?resize=33%2C29&amp;ssl=1\" alt=\"stiletto heel\" width=\"33\" height=\"29\" data-recalc-dims=\"1\" \/>, but the point is that the momentum update is doing its job: rather than jumping around in all directions like the raw gradient, it smooths the updates based on previous iterations. For a simple convex error function like ours, it does not really matter (and, as we\u2019ll see below, it won\u2019t make a dramatic difference in how fast we converge), but it is easy to see how, on a very bumpy loss surface, this could be a great advantage, helping the optimizer ride over the bumps rather than blindly falling into each small ravine.\n<h2>Adam, Learning Rate Annealing and other SGD optimisations<\/h2>\nIn the same spirit as the momentum update, many methods propose variations of the SGD update rule, in order to converge faster and better to the optimal parameters that the model is trying to learn.\n\nWe won\u2019t go into the details of all those great optimizations, because there are already excellent posts and videos on that topic, e.g. from Sebastian Ruder<span>\u00a0<\/span><a href=\"http:\/\/ruder.io\/optimizing-gradient-descent\/\">here<\/a>, Andrej Karpathy<span>\u00a0<\/span><a href=\"http:\/\/cs231n.github.io\/neural-networks-3\/#update\">here<\/a>, or the fantastic video (and corresponding<span>\u00a0<\/span><a href=\"https:\/\/github.com\/fastai\/fastai\/blob\/master\/courses\/dl1\/excel\/graddesc.xlsm\" target=\"_blank\" rel=\"noopener\">Excel file<\/a>) by Jeremy Howard. 
But we\u2019ll mention a couple of important concepts that those methods use:\n<h4>Per coordinate adaptive learning rate<\/h4>\nThe idea is that the size of the steps taken in gradient descent should be adapted separately for each learned parameter. Intuitively, a parameter\u2019s learning rate should shrink as a function of how much data has already been observed for that specific parameter. The first popular method of this kind was\u00a0<a href=\"http:\/\/jmlr.org\/papers\/v12\/duchi11a.html\">AdaGrad<\/a>; Adadelta and RMSProp then evolved from it, Adam (Adaptive Moment Estimation) combined that idea with momentum, and other methods have since built on top of Adam. Again, the links above and the Excel file are the best references for all of those. You can also find an illustration of how to apply and implement this concept (in a few lines of code) in the context of logistic regression in my blog post<span>\u00a0<\/span><a href=\"http:\/\/www.philippeadjiman.com\/blog\/2018\/02\/26\/deep-dive-into-logistic-regression-part-2\/\">here<\/a>.\n<h4>Learning rate annealing (with restarts)<\/h4>\nThis simple concept is also one of the most effective tricks you can find in the deep learning world.\u00a0<span>The idea is that when you start your search for the optimal parameters, you can afford big jumps, but the closer you get to the minimum, the smaller you want your steps to be, to nail it down rather than overshoot it; so you progressively reduce the learning rate.\u00a0<\/span><span>You can also combine that idea with periodically resetting the learning rate to its highest value (this is called SGD with restarts), which can help find more general optima. 
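Both ideas fit in a few lines of code. The sketch below assumes a cosine-shaped annealing schedule with periodic restarts (one common choice; the notebook's AdamAnn method may anneal differently). It combines that schedule with the standard Adam update with bias correction, applied to the same toy data one point at a time. The names cosine_lr and adam_annealed, and every hyperparameter value, are hypothetical choices for this illustration.

```python
import math
import numpy as np

def cosine_lr(step, lr_max, cycle_len, lr_min=0.0):
    # Cosine annealing from lr_max down towards lr_min, restarting every cycle_len steps.
    t = step % cycle_len
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / cycle_len))

def adam_annealed(X, Y, a, b, lr_max, beta1=0.9, beta2=0.999, eps=1e-8,
                  epochs=1, cycle_len=400):
    params = np.array([a, b], dtype=float)
    m = np.zeros(2)  # decaying average of gradients (the momentum part)
    v = np.zeros(2)  # decaying average of squared gradients (per-coordinate scale)
    step = 0
    for _ in range(epochs):
        for x_, y_ in zip(X, Y):
            residual = params[0] * x_ + params[1] - y_
            g = np.array([2 * x_ * residual, 2 * residual])
            step += 1
            m = beta1 * m + (1 - beta1) * g
            v = beta2 * v + (1 - beta2) * g ** 2
            m_hat = m / (1 - beta1 ** step)  # bias correction for the zero initialisation
            v_hat = v / (1 - beta2 ** step)
            lr = cosine_lr(step - 1, lr_max, cycle_len)
            params -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return params[0], params[1]

# Same toy data as above.
np.random.seed(7)
x_gen = np.random.randint(low=-10, high=10, size=20)
y_real = 1.5 * x_gen - 28

a, b = adam_annealed(x_gen, y_real, 0.0, 0.0, lr_max=0.3, epochs=100, cycle_len=400)
print('a={:.3f} , b={:.3f}'.format(a, b))
```

Here m plays the role of the momentum average, while dividing by the square root of v gives each parameter its own effective step size. On this convex toy problem the restarts are not really needed. On a bumpy loss surface, the periodic jumps back to lr_max are meant to kick the optimizer out of narrow minima.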
More on that in the fantastic\u00a0<a href=\"https:\/\/course.fast.ai\/\" target=\"_blank\" rel=\"noopener\">fast.ai course<\/a>\u00a0and also in\u00a0<a href=\"http:\/\/ruder.io\/deep-learning-optimization-2017\/index.html#sgdwithrestarts\">that great blog post<\/a>.\u00a0<\/span>\n<h2>Implementation and Visual comparisons on our simple example<\/h2>\nIn<span>\u00a0<\/span><a href=\"https:\/\/github.com\/padjiman\/blog-posts-code\/blob\/master\/sgd_viz\/sgd_and_optims_viz.ipynb\" target=\"_blank\" rel=\"noopener\">that notebook<\/a>, I\u2019ve implemented a few functions to simulate, visualize, debug, investigate and experiment with a few variations of SGD: vanilla, momentum, Adam, and Adam with learning rate annealing.\n\nThe implementation can easily be extended to any function (not only a linear function as in our example); only its derivative needs to be provided as a function (which could itself be automated using e.g.<span>\u00a0<\/span><a href=\"https:\/\/pytorch.org\/docs\/stable\/autograd.html\">PyTorch\u2019s autograd<\/a>), although the visualisations only handle functions of the form\u00a0\\(\\mathbb{R}^n \\rightarrow \\mathbb{R}\\).\n\nBelow we show some of the output produced when 
calling this code with:\n<div>\n<div id=\"highlighter_574266\" class=\"syntaxhighlighter nogutter  python\">\n<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\">\n<tbody>\n<tr>\n<td class=\"code\">\n<div class=\"container\">\n<div class=\"line number1 index0 alt2\"><code class=\"python plain\">params,methods,x,y, \\<\/code><\/div>\n<div class=\"line number2 index1 alt1\"><code class=\"python plain\">loss_evolution_list,lr_evolution_list <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">\\<\/code><\/div>\n<div class=\"line number3 index2 alt2\"><code class=\"python plain\">compare_methods_and_plot([<\/code><code class=\"python value\">1.5<\/code><code class=\"python plain\">,<\/code><code class=\"python keyword\">-<\/code><code class=\"python value\">28<\/code><code class=\"python plain\">],[<\/code><code class=\"python value\">1<\/code><code class=\"python plain\">,<\/code><code class=\"python value\">1<\/code><code class=\"python plain\">],lin,<\/code><\/div>\n<div class=\"line number4 index3 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">linear_gradients,[<\/code><code class=\"python keyword\">-<\/code><code class=\"python value\">10<\/code><code class=\"python plain\">,<\/code><code class=\"python 
value\">10<\/code><code class=\"python plain\">],<\/code><\/div>\n<div class=\"line number5 index4 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">size_gen_data<\/code><code class=\"python keyword\">=<\/code><code class=\"python value\">30<\/code><code class=\"python plain\">,<\/code><\/div>\n<div class=\"line number6 index5 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">epochs <\/code><code class=\"python keyword\">=<\/code> <code class=\"python value\">50<\/code><code class=\"python plain\">,<\/code><\/div>\n<div class=\"line number7 index6 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">methods <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">[<\/code><code class=\"python string\">\"SGD\"<\/code><code class=\"python plain\">,<\/code><code class=\"python string\">\"Momentum\"<\/code><code class=\"python plain\">,<\/code><code class=\"python string\">\"Adam\"<\/code><code class=\"python plain\">,<\/code><code class=\"python string\">\"AdamAnn\"<\/code><code class=\"python plain\">],<\/code><\/div>\n<div class=\"line number8 index7 alt1\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">methods_optim_params <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">[[<\/code><code class=\"python value\">0.001<\/code><code class=\"python plain\">] , [<\/code><code class=\"python value\">0.001<\/code><code class=\"python plain\">,<\/code><code class=\"python value\">0.95<\/code><code class=\"python plain\">] , [<\/code><code class=\"python value\">1<\/code><code class=\"python plain\">,<\/code><code class=\"python value\">0.7<\/code><code class=\"python plain\">,<\/code><code class=\"python value\">0.9<\/code><code class=\"python plain\">] , [<\/code><code class=\"python value\">1<\/code><code class=\"python plain\">,<\/code><code 
class=\"python value\">0.7<\/code><code class=\"python plain\">,<\/code><code class=\"python value\">0.9<\/code><code class=\"python plain\">] ],<\/code><\/div>\n<div class=\"line number9 index8 alt2\"><code class=\"python spaces\">\u00a0\u00a0\u00a0\u00a0<\/code><code class=\"python plain\">anim_interval_ms<\/code><code class=\"python keyword\">=<\/code><code class=\"python value\">100<\/code><code class=\"python plain\">,ylim_anim<\/code><code class=\"python keyword\">=<\/code><code class=\"python plain\">[<\/code><code class=\"python keyword\">-<\/code><code class=\"python value\">50<\/code><code class=\"python plain\">,<\/code><code class=\"python value\">15<\/code><code class=\"python plain\">])<\/code><\/div>\n<div class=\"line number10 index9 alt1\"><\/div>\n<div class=\"line number11 index10 alt2\"><code class=\"python plain\">fig, ax <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">plt.subplots(figsize<\/code><code class=\"python keyword\">=<\/code><code class=\"python plain\">(<\/code><code class=\"python value\">12<\/code><code class=\"python plain\">, <\/code><code class=\"python value\">6<\/code><code class=\"python plain\">))\u00a0\u00a0\u00a0 <\/code><\/div>\n<div class=\"line number12 index11 alt1\"><code class=\"python plain\">anim <\/code><code class=\"python keyword\">=<\/code> <code class=\"python plain\">draw_animation(params,methods, lin,x,y,<\/code><code class=\"python value\">75<\/code><code class=\"python plain\">,ylim<\/code><code class=\"python keyword\">=<\/code><code class=\"python plain\">[<\/code><code class=\"python keyword\">-<\/code><code class=\"python value\">50<\/code><code class=\"python plain\">,<\/code><code class=\"python value\">15<\/code><code class=\"python plain\">])<\/code><\/div>\n<div class=\"line number13 index12 alt2\"><code class=\"python plain\">HTML(anim.to_jshtml())<\/code><\/div>\n<\/div><\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<\/div>\n<\/div>\nFirst let\u2019s observe the 
animation:\n\n<a href=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/11\/full_anim.gif\"><img loading=\"lazy\" decoding=\"async\" class=\"aligncenter size-full wp-image-1307\" src=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/11\/full_anim.gif?resize=600%2C298\" alt=\"\" width=\"600\" height=\"298\" data-recalc-dims=\"1\" \/><\/a>\n\nWe can note a few things:\n<ul>\n \t<li>Adam and Adam with annealing converge much faster than vanilla SGD or SGD with momentum.<\/li>\n \t<li>However, Adam without annealing is not stable: it keeps jittering around the solution instead of settling. This is probably because, without annealing, the learning rate stays at its initial value, which is too high for the parameters to stay near the minimum. Unlike vanilla SGD steps, Adam's steps do not shrink as the gradients get smaller near the minimum, because Adam divides each update by a running estimate of the gradient magnitude. With annealing, once the parameters have converged they stay there, because after enough iterations the learning rate has become tiny.<\/li>\n \t<li>The Momentum update does not really help convergence in this specific example. This is because our example uses a dead simple convex cost function that is easy to optimize anyway.
But on more complex cost functions (like those of neural nets), the momentum update can add much more value.<\/li>\n \t<li>By tuning the initial learning rate of each method, we could probably make them all converge faster, but for the sake of the comparison we used a fixed, untuned initial value for each method.<\/li>\n<\/ul>\nLet\u2019s compare how the loss evolves over the iterations for each method:\n\n<a href=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/11\/cost_evolution.png\"><img loading=\"lazy\" decoding=\"async\" class=\"aligncenter size-full wp-image-1308\" src=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/11\/cost_evolution.png?resize=604%2C322\" alt=\"\" width=\"604\" height=\"322\" srcset=\"https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/11\/cost_evolution.png?w=727 727w, https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/11\/cost_evolution.png?resize=300%2C160 300w\" sizes=\"auto, (max-width: 604px) 100vw, 604px\" data-recalc-dims=\"1\" \/><\/a>\n\nWe can see how much faster Adam converges to a minimal cost than vanilla SGD or SGD with momentum. SGD with momentum seems to converge about as fast as vanilla SGD (even slightly more slowly), but again, this is most likely due to the dead simple function we used here. In a neural net, momentum would usually prove much more useful.\n\nThe code also plots the evolution of each learned parameter over the iterations. Below, for example, is the evolution of the second parameter, a (whose true value in our example is 1.5). Adam with annealing gets there very fast, while SGD with momentum gets there more slowly, but more smoothly than vanilla SGD.
We can also see that Adam without annealing suffers from large oscillations, as already observed in the animation.\n\n<a href=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/11\/param_a_evolution_all.png\"><img loading=\"lazy\" decoding=\"async\" class=\"aligncenter size-full wp-image-1310\" src=\"https:\/\/i0.wp.com\/www.philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/11\/param_a_evolution_all.png?resize=604%2C320\" alt=\"\" width=\"604\" height=\"320\" srcset=\"https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/11\/param_a_evolution_all.png?w=730 730w, https:\/\/i0.wp.com\/oldblog.philippeadjiman.com\/wp-content\/uploads\/2018\/11\/param_a_evolution_all.png?resize=300%2C159 300w\" sizes=\"auto, (max-width: 604px) 100vw, 604px\" data-recalc-dims=\"1\" \/><\/a>\n\nAgain, one should obviously not generalise about the added value of each method from such a simple example. Here we just wanted to illustrate the concepts, and even on such a toy data set, we can observe and understand the core ideas behind these simple yet powerful methods.\n\nThat\u2019s it for now. I hope you enjoyed this post.
Feel free to comment\/ask questions and\/or use the<span>\u00a0<\/span><a href=\"https:\/\/github.com\/padjiman\/blog-posts-code\/blob\/master\/sgd_viz\/sgd_and_optims_viz.ipynb\" target=\"_blank\" rel=\"noopener\">code<\/a><span>\u00a0<\/span>for your own experiments.\n\n<\/div>","protected":false},"excerpt":{"rendered":"<p>Watch optimizers battle it out in a visual showdown\u2014Momentum vs Adam vs LR schedules, explained with intuition and flair.<\/p>\n","protected":false},"author":1,"featured_media":1303,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"nf_dc_page":"","_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[7,8,12],"tags":[],"class_list":["post-1287","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-data-science","category-experiments","category-machine-learning"],"jetpack_featured_media_url":"https:\/\/i0.wp.com\/philippeadjiman.com\/blog\/wp-content\/uploads\/2018\/11\/gradients_a.png?fit=721%2C387&ssl=1","jetpack_sharing_enabled":true,"_links":{"self":[{"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/posts\/1287","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/comments?post=1287"}],"version-history":[{"count":2,"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/posts\/1287\/revisions"}],"predecessor-version":[{"id":1957,"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/posts\/1287\/revisions\/1957"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/media\/1303"}],"wp:attachment":[{"href":"
https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/media?parent=1287"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/categories?post=1287"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/philippeadjiman.com\/blog\/wp-json\/wp\/v2\/tags?post=1287"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}