Backpropagation

The section on costs was deliberately sparse, because costs are closely tied to the notion of backpropagation. There is also a twist: we're not going to calculate the full cost function, mainly because we don't need to for this specific case. Instead, we're going to do some mathematical trickery.

Recall that our cost was the sum of squared errors. We can write it like so:

    cost = \sum_i (y_i - pred_i)^2

Now what I am about to describe can sound very much like cheating, but it's a valid strategy. The derivative of the cost with regard to each prediction is this:

    \frac{\partial cost}{\partial pred_i} = -2 (y_i - pred_i)

To make things a bit easier on ourselves, let's redefine the cost as this:

    cost = \frac{1}{2} \sum_i (y_i - pred_i)^2

With this definition the factor of 2 cancels, and the derivative becomes simply -(y_i - pred_i).

It doesn't make a difference to the process of finding the lowest cost. Think about it: imagine a highest cost and a lowest cost. Putting a 1/2 multiplier in front of both of them does not change the fact that the lowest cost is still lower than the highest cost. Take some time to work this out on your own to convince yourself that having a positive constant multiplier doesn't change the process.

The derivative of a sigmoid function is:

    \sigma'(x) = \sigma(x) (1 - \sigma(x))

From there, we can work out the derivative of the cost function with regard to each weight matrix. How to work out the full backpropagation will be explained in the next chapter. For now, here is the code:

// backpropagation
outputErrors := m.do(func() (tensor.Tensor, error) { return tensor.Sub(y, pred) })

// Check for errors before touching outputErrors directly:
// it may be nil if an earlier step failed.
if m.err != nil {
	return 0, m.err
}
cost = sum(outputErrors.Data().([]float64))

// Propagate the output errors back through the final layer's weights.
hidErrs := m.do(func() (tensor.Tensor, error) {
	if err := nn.final.T(); err != nil {
		return nil, err
	}
	defer nn.final.UT()
	return tensor.MatMul(nn.final, outputErrors)
})

// dpred = dsigmoid(pred) * outputErrors, computed in place.
dpred := m.do(func() (tensor.Tensor, error) { return pred.Apply(dsigmoid, tensor.UseUnsafe()) })
m.do(func() (tensor.Tensor, error) { return tensor.Mul(dpred, outputErrors, tensor.UseUnsafe()) })
dcost_dfinal := m.do(func() (tensor.Tensor, error) {
	if err := act0.T(); err != nil {
		return nil, err
	}
	defer act0.UT()
	return tensor.MatMul(dpred, act0)
})

// hidErrs = hidErrs * dsigmoid(act0), computed in place.
dact0 := m.do(func() (tensor.Tensor, error) { return act0.Apply(dsigmoid) })
m.do(func() (tensor.Tensor, error) { return tensor.Mul(hidErrs, dact0, tensor.UseUnsafe()) })
m.do(func() (tensor.Tensor, error) { err := hidErrs.Reshape(hidErrs.Shape()[0], 1); return hidErrs, err })
dcost_dhidden := m.do(func() (tensor.Tensor, error) {
	if err := x.T(); err != nil {
		return nil, err
	}
	defer x.UT()
	return tensor.MatMul(hidErrs, x)
})

And there we have it, the derivatives of the cost with regard to the weight matrices, nn.final and nn.hidden.

The thing to do with the derivatives is to use them as gradients to update the weight matrices. To do that, we use a simple gradient descent algorithm. Strictly speaking, gradient descent subtracts the gradient from the weights; here we add, because outputErrors was computed as y - pred rather than pred - y, so the sign is already flipped. But we don't want to add the full value of the gradient. If we did that and our starting value was very close to the minimum, we'd overshoot it. So we multiply the gradients by some small value, known as the learn rate:

  // gradient update
m.do(func() (tensor.Tensor, error) { return tensor.Mul(dcost_dfinal, learnRate, tensor.UseUnsafe()) })
m.do(func() (tensor.Tensor, error) { return tensor.Mul(dcost_dhidden, learnRate, tensor.UseUnsafe()) })
m.do(func() (tensor.Tensor, error) { return tensor.Add(nn.final, dcost_dfinal, tensor.UseUnsafe()) })
m.do(func() (tensor.Tensor, error) { return tensor.Add(nn.hidden, dcost_dhidden, tensor.UseUnsafe()) })

And this is the training function in full:

// Train runs one training step. x is the image, y is a one-hot vector.
func (nn *NN) Train(x, y tensor.Tensor, learnRate float64) (cost float64, err error) {
	// predict
	var m maybe
	m.do(func() (tensor.Tensor, error) { err := x.Reshape(x.Shape()[0], 1); return x, err })
	m.do(func() (tensor.Tensor, error) { err := y.Reshape(10, 1); return y, err })

	hidden := m.do(func() (tensor.Tensor, error) { return tensor.MatMul(nn.hidden, x) })
	act0 := m.do(func() (tensor.Tensor, error) { return hidden.Apply(sigmoid, tensor.UseUnsafe()) })

	final := m.do(func() (tensor.Tensor, error) { return tensor.MatMul(nn.final, act0) })
	pred := m.do(func() (tensor.Tensor, error) { return final.Apply(sigmoid, tensor.UseUnsafe()) })
	// log.Printf("pred %v, correct %v", argmax(pred.Data().([]float64)), argmax(y.Data().([]float64)))

	// backpropagation
	outputErrors := m.do(func() (tensor.Tensor, error) { return tensor.Sub(y, pred) })

	// Check for errors before touching outputErrors directly:
	// it may be nil if an earlier step failed.
	if m.err != nil {
		return 0, m.err
	}
	cost = sum(outputErrors.Data().([]float64))

	// Propagate the output errors back through the final layer's weights.
	hidErrs := m.do(func() (tensor.Tensor, error) {
		if err := nn.final.T(); err != nil {
			return nil, err
		}
		defer nn.final.UT()
		return tensor.MatMul(nn.final, outputErrors)
	})

	// dpred = dsigmoid(pred) * outputErrors, computed in place.
	dpred := m.do(func() (tensor.Tensor, error) { return pred.Apply(dsigmoid, tensor.UseUnsafe()) })
	m.do(func() (tensor.Tensor, error) { return tensor.Mul(dpred, outputErrors, tensor.UseUnsafe()) })
	dcost_dfinal := m.do(func() (tensor.Tensor, error) {
		if err := act0.T(); err != nil {
			return nil, err
		}
		defer act0.UT()
		return tensor.MatMul(dpred, act0)
	})

	// hidErrs = hidErrs * dsigmoid(act0), computed in place.
	dact0 := m.do(func() (tensor.Tensor, error) { return act0.Apply(dsigmoid) })
	m.do(func() (tensor.Tensor, error) { return tensor.Mul(hidErrs, dact0, tensor.UseUnsafe()) })
	m.do(func() (tensor.Tensor, error) { err := hidErrs.Reshape(hidErrs.Shape()[0], 1); return hidErrs, err })
	dcost_dhidden := m.do(func() (tensor.Tensor, error) {
		if err := x.T(); err != nil {
			return nil, err
		}
		defer x.UT()
		return tensor.MatMul(hidErrs, x)
	})

	// gradient update
	m.do(func() (tensor.Tensor, error) { return tensor.Mul(dcost_dfinal, learnRate, tensor.UseUnsafe()) })
	m.do(func() (tensor.Tensor, error) { return tensor.Mul(dcost_dhidden, learnRate, tensor.UseUnsafe()) })
	m.do(func() (tensor.Tensor, error) { return tensor.Add(nn.final, dcost_dfinal, tensor.UseUnsafe()) })
	m.do(func() (tensor.Tensor, error) { return tensor.Add(nn.hidden, dcost_dhidden, tensor.UseUnsafe()) })
	return cost, m.err
}
There are several observations to be made:

  • You may note that parts of the body of the Predict method are repeated at the top of the Train method
  • The tensor.UseUnsafe() function option is used a lot

Both of these are going to be pain points when we start scaling up into deeper networks. As such, in the next chapter, we will explore possible solutions to these problems.
