Gradient descent

Backpropagation is how we actually train our model: it works out how the prediction error depends on each of the model's weights, so that we can adjust those weights to minimize the error. The adjustment itself is usually done via a method called gradient descent.

Let's begin with a basic example. Say we want to train a simple neural network to produce the following targets by multiplying each input by 0.5:

Input   Target
1       0.5
2       1.0
3       1.5
4       2.0

We have a basic model to start with, as follows:

y = W * x
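To make the model concrete, here is a minimal sketch in plain Go (standard library only, no Gorgonia). The function name `predict` and the `inputs` slice are hypothetical names chosen for illustration; the model itself is just the single multiplication above.

```go
package main

import "fmt"

// predict implements the single-weight model y = W * x.
func predict(w, x float64) float64 {
	return w * x
}

func main() {
	inputs := []float64{1, 2, 3, 4} // the four inputs from the table
	w := 2.0                        // an initial guess for the weight
	for _, x := range inputs {
		fmt.Printf("x = %.0f  ->  W * x = %.0f\n", x, predict(w, x))
	}
}
```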

So, to start, let's guess that W is 2. The following table shows the results:

Input   Target   W * x
1       0.5      2
2       1.0      4
3       1.5      6
4       2.0      8

Now that we have the outputs of our guess, we can compare them with the answers we expect and calculate the error. In the following table, we measure it using the sum of the squared errors:

Input   Target   W * x   Error (target - W * x)   Squared error
1       0.5      2       -1.5                     2.25
2       1.0      4       -3.0                     9
3       1.5      6       -4.5                     20.25
4       2.0      8       -6.0                     36

By adding up the values in the last column of the preceding table, we get a sum of squared errors of 67.5.
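The same sum can be computed in a few lines. This is a sketch in plain Go; `inputs`, `targets`, and `sumSquaredErrors` are hypothetical names, and the data is the four-row table above.

```go
package main

import "fmt"

// The four training examples from the table.
var (
	inputs  = []float64{1, 2, 3, 4}
	targets = []float64{0.5, 1.0, 1.5, 2.0}
)

// sumSquaredErrors returns the sum over all examples of (target - W*x)^2.
func sumSquaredErrors(w float64) float64 {
	total := 0.0
	for i, x := range inputs {
		diff := targets[i] - w*x
		total += diff * diff
	}
	return total
}

func main() {
	fmt.Printf("SSE at W = 2: %.4f\n", sumSquaredErrors(2.0))
}
```

Running it prints `SSE at W = 2: 67.5000`, matching the table.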

We could certainly brute-force an answer by trying every value of W from -10 to +10, but surely there is a better way. Ideally, we want a more efficient method that scales to datasets that are not simple tables with four inputs.

A better method is to look at the derivative (or gradient) of the error with respect to W. One way to estimate it is to repeat the same calculation with a slightly higher weight, for example, W = 2.01. The following table shows the results:

Input   Target   W * x   Error (target - W * x)   Squared error
1       0.5      2.01    -1.51                    2.2801
2       1.0      4.02    -3.02                    9.1204
3       1.5      6.03    -4.53                    20.5209
4       2.0      8.04    -6.04                    36.4816

This gives us a sum of squared errors of 68.403, which is higher! Intuitively, this means that if we increase the weight, the error is likely to increase. The converse is also true: if we decrease the weight, the error is likely to decrease. For example, let's try W = 1.99, as shown in the following table:

Input   Target   W * x   Error (target - W * x)   Squared error
1       0.5      1.99    -1.49                    2.2201
2       1.0      3.98    -2.98                    8.8804
3       1.5      5.97    -4.47                    19.9809
4       2.0      7.96    -5.96                    35.5216

This gives us a lower error of 66.603.
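The two nudged weights give us a numerical estimate of the slope. The following Go sketch computes forward, backward, and central difference estimates at W = 2 with a step of 0.01; `slopeEstimates` is a hypothetical helper. Because this error is a quadratic in W, the central difference recovers the exact derivative, up to floating-point rounding.

```go
package main

import "fmt"

var (
	inputs  = []float64{1, 2, 3, 4}
	targets = []float64{0.5, 1.0, 1.5, 2.0}
)

func sumSquaredErrors(w float64) float64 {
	total := 0.0
	for i, x := range inputs {
		diff := targets[i] - w*x
		total += diff * diff
	}
	return total
}

// slopeEstimates approximates dE/dW at w using a small step h.
func slopeEstimates(w, h float64) (forward, backward, central float64) {
	base := sumSquaredErrors(w)
	up := sumSquaredErrors(w + h)   // e.g. W = 2.01
	down := sumSquaredErrors(w - h) // e.g. W = 1.99
	forward = (up - base) / h
	backward = (base - down) / h
	central = (up - down) / (2 * h) // averages the two one-sided estimates
	return
}

func main() {
	f, b, c := slopeEstimates(2.0, 0.01)
	fmt.Printf("forward: %.1f  backward: %.1f  central: %.1f\n", f, b, c)
}
```

It should report roughly 90.3, 89.7, and 90.0: the error rises by about 90 units per unit increase in W around W = 2.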

If we plot the error over a range of W, we can see that there is a natural bottom point, a minimum. Descending the gradient means repeatedly moving W downhill toward that minimum to reduce the error.

For this specific example, we can easily plot the error as a function of our single weight.

The goal is to follow the slope down to the bottom, where the error is zero (at W = 0.5).
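Since the plot itself is not reproduced here, the following Go sketch prints a rough text version of it, sampling W from -1 to 2 in steps of 0.25 (a range we chose for illustration). For this data the error works out to 30 * (W - 0.5)², a parabola with its minimum of zero at W = 0.5.

```go
package main

import (
	"fmt"
	"strings"
)

var (
	inputs  = []float64{1, 2, 3, 4}
	targets = []float64{0.5, 1.0, 1.5, 2.0}
)

func sumSquaredErrors(w float64) float64 {
	total := 0.0
	for i, x := range inputs {
		diff := targets[i] - w*x
		total += diff * diff
	}
	return total
}

func main() {
	// One bar per sampled weight; each '#' stands for 2 units of error.
	for i := 0; i <= 12; i++ {
		w := -1.0 + 0.25*float64(i)
		e := sumSquaredErrors(w)
		fmt.Printf("W = %5.2f  E = %6.2f  %s\n", w, e, strings.Repeat("#", int(e/2)))
	}
}
```

The bars shrink to nothing at W = 0.5 and grow symmetrically on either side, which is the valley that gradient descent walks down.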

Let's apply a weight update to our example to illustrate how this works. In general, we follow what is called the delta learning rule, which looks like this:

new_W = old_W - eta * derivative

In this formula, eta is a constant, usually called the learning rate. Recall that when we create a solver in Gorgonia, we pass the learning rate as one of the options, as shown here:

solver := NewVanillaSolver(WithLearnRate(0.001), WithClip(5))
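Outside of Gorgonia, the delta rule is a single subtraction. This plain-Go sketch uses a hypothetical `deltaUpdate` function, a learning rate of 0.001, and an illustrative slope of 90 (roughly what the W = 2 and W = 2.01 tables imply).

```go
package main

import "fmt"

// deltaUpdate applies new_W = old_W - eta * derivative.
// A positive derivative means the error grows with W, so W moves down;
// a negative derivative moves W up.
func deltaUpdate(oldW, eta, derivative float64) float64 {
	return oldW - eta*derivative
}

func main() {
	fmt.Printf("new W = %.4f\n", deltaUpdate(2.0, 0.001, 90))
}
```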

You will also often see a factor of 0.5 in front of the squared error. This is because the derivative of a squared term (y - target)² with respect to y is 2(y - target), and the 0.5 cancels that factor of 2. However, eta is a constant anyway, so you can also just consider the factor absorbed into eta.
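Worked through the chain rule for our model y = W * x, the cancellation looks like this. Here t_i is our own shorthand for the target of input x_i:

```latex
\begin{aligned}
E &= \frac{1}{2} \sum_{i} (y_i - t_i)^2, \qquad y_i = W x_i \\
\frac{\partial E}{\partial y_i} &= \frac{1}{2} \cdot 2\,(y_i - t_i) = y_i - t_i \\
\frac{\partial E}{\partial W} &= \sum_{i} \frac{\partial E}{\partial y_i}\,\frac{\partial y_i}{\partial W} = \sum_{i} (W x_i - t_i)\,x_i
\end{aligned}
```

Without the 0.5 factor, the last line is simply doubled, which has the same effect on the update as doubling eta.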

So, first, we need to work out the derivative of the error with respect to the output, y = W * x. For a squared error (y - target)², that derivative is 2(y - target); multiplying by dy/dW = x (the chain rule) and summing over the four examples gives the derivative of the sum of squared errors with respect to W. At W = 2, this is 2 * (1.5 * 1 + 3.0 * 2 + 4.5 * 3 + 6.0 * 4) = 2 * 45 = 90, which agrees with the change in error we saw between W = 2 and W = 2.01.

If our learning rate is 0.001, our new weight is the following:

new_W = 2.00 - 0.001 * 90

Computing this gives new_W = 1.91. This is closer to our eventual target weight of 0.5, and with enough repetition, we would eventually get there. You'll notice that our learning rate is small. If we set it too large (say, 1), the first update would take us to W = 2 - 90 = -88, far too far into the negative, and each subsequent update would overshoot by even more in alternating directions, so we would bounce back and forth across the valley instead of descending it. Our choice of learning rate is important: too small and our model will take too long to converge; too large and it may even diverge instead.
