Feed forward

Now that we have a conceptual idea of how the neural network works, let's write the forward propagation function. We'll call it Predict because, well, to make a prediction you merely need to run the network forward:

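func (nn *NN) Predict(a tensor.Tensor) (int, error) {
	if a.Dims() != 1 {
		return -1, errors.New("Expected a vector")
	}

	var m maybe
	// Hidden layer: multiply the hidden weight matrix by the input vector,
	// then apply the sigmoid activation in place (tensor.UseUnsafe).
	hidden := m.do(func() (tensor.Tensor, error) { return nn.hidden.MatVecMul(a) })
	act0 := m.do(func() (tensor.Tensor, error) { return hidden.Apply(sigmoid, tensor.UseUnsafe()) })

	// Output layer: the same pattern with the final weight matrix.
	final := m.do(func() (tensor.Tensor, error) { return tensor.MatVecMul(nn.final, act0) })
	pred := m.do(func() (tensor.Tensor, error) { return final.Apply(sigmoid, tensor.UseUnsafe()) })

	// If any step recorded an error in m, report it instead of a prediction.
	if m.err != nil {
		return -1, m.err
	}
	// The predicted class is the index of the largest output activation.
	return argmax(pred.Data().([]float64)), nil
}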
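To make the data flow concrete, here is a standalone sketch of the same two-layer forward pass in plain Go, with [][]float64 and []float64 standing in for tensors. The helpers matVecMul, sigmoidInPlace, argmax and the maybe type below are our own hypothetical stand-ins, not the chapter's definitions or the tensor package's API; in this sketch, maybe simply records the first error and skips every later step.

```go
package main

import (
	"errors"
	"fmt"
	"math"
)

// maybe is a hypothetical error-latching helper: once a step fails,
// every later call to do is skipped and the first error is kept.
type maybe struct{ err error }

func (m *maybe) do(fn func() ([]float64, error)) []float64 {
	if m.err != nil {
		return nil // an earlier step failed; skip this one
	}
	var out []float64
	out, m.err = fn()
	return out
}

// matVecMul computes W·x for a row-major matrix W whose rows have len(x) columns.
func matVecMul(w [][]float64, x []float64) ([]float64, error) {
	out := make([]float64, len(w))
	for i, row := range w {
		if len(row) != len(x) {
			return nil, errors.New("shape mismatch")
		}
		for j, v := range row {
			out[i] += v * x[j]
		}
	}
	return out, nil
}

// sigmoidInPlace overwrites each element with 1/(1+e^-x),
// loosely mirroring Apply(sigmoid, tensor.UseUnsafe()).
func sigmoidInPlace(x []float64) ([]float64, error) {
	for i, v := range x {
		x[i] = 1 / (1 + math.Exp(-v))
	}
	return x, nil
}

// argmax returns the index of the largest element.
func argmax(x []float64) int {
	best := 0
	for i, v := range x {
		if v > x[best] {
			best = i
		}
	}
	return best
}

// predict runs input -> hidden -> sigmoid -> output -> sigmoid -> argmax.
func predict(hiddenW, finalW [][]float64, a []float64) (int, error) {
	var m maybe
	hidden := m.do(func() ([]float64, error) { return matVecMul(hiddenW, a) })
	act0 := m.do(func() ([]float64, error) { return sigmoidInPlace(hidden) })
	final := m.do(func() ([]float64, error) { return matVecMul(finalW, act0) })
	pred := m.do(func() ([]float64, error) { return sigmoidInPlace(final) })
	if m.err != nil {
		return -1, m.err
	}
	return argmax(pred), nil
}

func main() {
	hiddenW := [][]float64{{1, 0}, {0, 1}, {1, 1}} // 3 hidden units, 2 inputs
	finalW := [][]float64{{1, -1, 0}, {-1, 1, 0}}  // 2 output classes
	fmt.Println(predict(hiddenW, finalW, []float64{2, 0}))    // 0 <nil>
	fmt.Println(predict(hiddenW, finalW, []float64{1, 2, 3})) // -1 shape mismatch
}
```

Because each closure runs only when no earlier error has been recorded, the happy path reads as a straight sequence of layer operations and the error is checked once at the end.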

This is fairly straightforward, except for a few control structures. First, I should explain that the API of the tensor package is quite expressive, in the sense that it allows the user multiple ways of doing the same thing, albeit with different type signatures. Briefly, the patterns are the following:

  • tensor.BINARYOPERATION(a, b tensor.Tensor, opts ...tensor.FuncOpt) (tensor.Tensor, error)
  • tensor.UNARYOPERATION(a tensor.Tensor, opts ...tensor.FuncOpt) (tensor.Tensor, error)
  • (a *tensor.Dense) BINARYOPERATION(b *tensor.Dense, opts ...tensor.FuncOpt) (*tensor.Dense, error)
  • (a *tensor.Dense) UNARYOPERATION(opts ...tensor.FuncOpt) (*tensor.Dense, error)

The key thing to note is that package-level operations (tensor.Add, tensor.Sub, and so on) take one or more tensor.Tensors and return a tensor.Tensor and an error. Any type with the right methods can fulfill the tensor.Tensor interface, and the tensor package provides two structural types that do so:

  • *tensor.Dense: A representation of a densely packed tensor
  • *tensor.CS: A memory-efficient representation of a sparsely packed tensor, with the data arranged in compressed sparse column or compressed sparse row format
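For intuition about why a compressed sparse layout saves memory, here is a minimal sketch of compressed sparse column (CSC) storage in plain Go. The csc type, toCSC and at are hypothetical illustrations of the general format; they are not the fields or methods of *tensor.CS, whose internals may differ.

```go
package main

import "fmt"

// csc is a hypothetical compressed-sparse-column layout (not tensor.CS itself).
type csc struct {
	rows, cols int
	colPtr     []int     // entries of column j live at indices colPtr[j]..colPtr[j+1]-1
	rowIdx     []int     // row index of each stored value
	values     []float64 // non-zero values, stored column by column
}

// toCSC packs a row-major dense matrix, keeping only the non-zero entries.
func toCSC(dense [][]float64) csc {
	m := csc{rows: len(dense)}
	if m.rows > 0 {
		m.cols = len(dense[0])
	}
	m.colPtr = append(make([]int, 0, m.cols+1), 0)
	for j := 0; j < m.cols; j++ {
		for i := 0; i < m.rows; i++ {
			if v := dense[i][j]; v != 0 {
				m.rowIdx = append(m.rowIdx, i)
				m.values = append(m.values, v)
			}
		}
		m.colPtr = append(m.colPtr, len(m.values))
	}
	return m
}

// at reads element (i, j) by scanning only column j's stored entries.
func (m csc) at(i, j int) float64 {
	for k := m.colPtr[j]; k < m.colPtr[j+1]; k++ {
		if m.rowIdx[k] == i {
			return m.values[k]
		}
	}
	return 0 // not stored, so it is zero
}

func main() {
	dense := [][]float64{
		{5, 0, 0},
		{0, 0, 7},
		{0, 3, 0},
		{0, 0, 0},
	}
	m := toCSC(dense)
	fmt.Println(m.values, m.rowIdx, m.colPtr) // [5 3 7] [0 2 1] [0 1 2 3]
	fmt.Println(m.at(1, 2), m.at(3, 0))       // 7 0
}
```

The 4×3 dense matrix holds 12 floats, while the CSC form keeps 3 values plus small index slices; the savings grow as the matrix gets larger and sparser, at the cost of slower random access.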

In practice, the most commonly used tensor.Tensor type is *tensor.Dense. The *tensor.CS data structure is used only for memory-constrained optimizations in specific algorithms. We shan't talk more about the *tensor.CS type in this chapter.

In addition to the package-level operations, each concrete type also implements its own methods. The methods of *tensor.Dense (.Add(...), .Sub(...), and so on) take one or more *tensor.Dense values and return a *tensor.Dense and an error.
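The split between the two forms follows a common Go design: a package-level function accepts and returns the interface so it can work with any implementation, while a method on the concrete type can return that concrete type directly, so callers keep access to type-specific methods without a type assertion. The toy below mirrors that shape with hypothetical Tensor, Dense and Add definitions of our own; it does not use the real tensor package, whose functions also take FuncOpt options and do far more.

```go
package main

import (
	"errors"
	"fmt"
)

// Tensor is a hypothetical, tiny stand-in for an interface like tensor.Tensor.
type Tensor interface {
	Data() []float64
}

// Dense is a hypothetical concrete type that fulfills Tensor.
type Dense struct{ data []float64 }

func (d *Dense) Data() []float64 { return d.data }

// Add is the method form: concrete *Dense in, concrete *Dense out.
func (d *Dense) Add(other *Dense) (*Dense, error) {
	if len(d.data) != len(other.data) {
		return nil, errors.New("shape mismatch")
	}
	out := make([]float64, len(d.data))
	for i := range d.data {
		out[i] = d.data[i] + other.data[i]
	}
	return &Dense{data: out}, nil
}

// Add is the package-level form: any Tensor in, a Tensor out.
// It dispatches on the concrete type, so other Tensor kinds could be supported later.
func Add(a, b Tensor) (Tensor, error) {
	ad, ok1 := a.(*Dense)
	bd, ok2 := b.(*Dense)
	if !ok1 || !ok2 {
		return nil, errors.New("unsupported tensor type")
	}
	sum, err := ad.Add(bd)
	if err != nil {
		return nil, err // return a true nil interface, not a nil *Dense wrapped in Tensor
	}
	return sum, nil
}

func main() {
	x := &Dense{data: []float64{1, 2}}
	y := &Dense{data: []float64{10, 20}}

	viaMethod, err := x.Add(y) // static type *Dense
	if err != nil {
		panic(err)
	}
	viaFunc, err := Add(x, y) // static type Tensor
	if err != nil {
		panic(err)
	}
	fmt.Println(viaMethod.data, viaFunc.Data()) // [11 22] [11 22]
}
```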
