Cross-validation

Another way we could test how well the neural network is learning is to cross-validate. The neural network could learn very well on the training data, in essence, memorizing which collections of pixels will result in a particular label. However, to check that the machine learning algorithm generalizes well, we need to show the neural network some data it's never seen before.

Here's the code to do so:

    log.Printf("Start testing")
    testImgs, err := readImageFile(os.Open("t10k-images.idx3-ubyte"))
    if err != nil {
        log.Fatal(err)
    }

    testlabels, err := readLabelFile(os.Open("t10k-labels.idx1-ubyte"))
    if err != nil {
        log.Fatal(err)
    }

    testData := prepareX(testImgs)
    testLbl := prepareY(testlabels)
    shape := testData.Shape()
    testData2, err := zca(testData)
    if err != nil {
        log.Fatal(err)
    }

    visualize(testData, 10, 10, "testData.png")
    visualize(testData2, 10, 10, "testData2.png")

    var correct, total float64
    var oneimg, onelabel tensor.Tensor
    var predicted, errcount int
    for i := 0; i < shape[0]; i++ {
        if oneimg, err = testData.Slice(makeRS(i, i+1)); err != nil {
            log.Fatalf("Unable to slice one image %d", i)
        }
        if onelabel, err = testLbl.Slice(makeRS(i, i+1)); err != nil {
            log.Fatalf("Unable to slice one label %d", i)
        }
        if predicted, err = nn.Predict(oneimg); err != nil {
            log.Fatalf("Failed to predict %d", i)
        }

        label := argmax(onelabel.Data().([]float64))
        if predicted == label {
            correct++
        } else if errcount < 5 {
            visualize(oneimg, 1, 1, fmt.Sprintf("%d_%d_%d.png", i, label, predicted))
            errcount++
        }
        total++
    }
    fmt.Printf("Correct/Totals: %v/%v = %1.3f\n", correct, total, correct/total)

Note that this code is largely the same as the code in the main function earlier. The difference is that instead of calling nn.Train, we call nn.Predict, and then check whether the predicted label matches the true label.
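The argmax helper used in the loop above isn't shown in this section. A minimal sketch, assuming the labels are one-hot-encoded []float64 slices (as prepareY produces), might look like this:

```go
package main

import "fmt"

// argmax returns the index of the largest value in the slice.
// For a one-hot label vector this recovers the digit; applied to the
// network's output, it picks the most probable class.
func argmax(xs []float64) int {
	best := 0
	for i, x := range xs {
		if x > xs[best] {
			best = i
		}
	}
	return best
}

func main() {
	// one-hot encoding of the digit 7
	label := []float64{0, 0, 0, 0, 0, 0, 0, 1, 0, 0}
	fmt.Println(argmax(label)) // 7
}
```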

After tweaking the various parameters and running the code (a run takes about 6.5 minutes), I got the following results:

    $ go build -o chapter7 .
    $ ./chapter7
    Correct/Totals: 9719/10000 = 0.972

A simple three-layer neural network yields 97% accuracy! This is, of course, nowhere near state of the art. In the next chapter, we'll build one that reaches 99.xx%, but it requires a big shift in mindset.

Training a neural network takes time, so it's often wise to save the trained network to disk. The *tensor.Dense type implements gob.GobEncoder and gob.GobDecoder, so to save the neural network, simply save the weights (nn.hidden and nn.final). For an additional challenge, write a gob encoder and decoder for those weight matrices and add save/load functionality.
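As a starting point for that challenge, here is a sketch of a gob round trip. It uses a stand-in weights struct with plain float64 slices rather than the real *tensor.Dense matrices, so it runs on the standard library alone; since *tensor.Dense implements gob.GobEncoder and gob.GobDecoder, the same Encode/Decode calls should work on the tensors directly:

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

// weights is a stand-in for the network's weight matrices
// (nn.hidden and nn.final).
type weights struct {
	Hidden []float64
	Final  []float64
}

// save serializes the weights with encoding/gob. In a real program
// you'd write to a file instead of an in-memory buffer.
func save(w weights) ([]byte, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(w); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// load deserializes weights previously written by save.
func load(data []byte) (weights, error) {
	var w weights
	err := gob.NewDecoder(bytes.NewReader(data)).Decode(&w)
	return w, err
}

func main() {
	orig := weights{Hidden: []float64{0.1, 0.2}, Final: []float64{0.3}}
	data, err := save(orig)
	if err != nil {
		panic(err)
	}
	back, err := load(data)
	if err != nil {
		panic(err)
	}
	fmt.Println(back.Hidden[1], back.Final[0]) // 0.2 0.3
}
```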

Furthermore, let's have a look at a few of the images that were wrongly classified. In the preceding code, this snippet writes out the first five wrong predictions:

    if predicted == label {
        correct++
    } else if errcount < 5 {
        visualize(oneimg, 1, 1, fmt.Sprintf("%d_%d_%d.png", i, label, predicted))
        errcount++
    }

And here they are:


In the first image, the neural network classified it as a 0, while the true value is 6. As you can see, it is an easy mistake to make. The second image shows a 2, and the neural network classified it as a 4. You may be inclined to think that it does look a bit like a 4. And lastly, if you are an American reader, the chances are you have been exposed to the Palmer handwriting method. If so, I'll bet that you might classify the last picture as a 7 instead of a 2, which is exactly what the neural network predicts. Unfortunately, the real label is 2. Some people just have terrible handwriting.
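Writing out five images is a good start, but a confusion matrix summarizes every mistake at once. The following is a hypothetical extension, not part of the chapter's code: in the test loop you would replace the per-pair slice here with confusion[label][predicted]++ and then inspect which digit pairs are confused most often:

```go
package main

import "fmt"

// tally builds a 10x10 confusion matrix from (trueLabel, predicted)
// pairs: rows are the true digits, columns the predicted ones.
func tally(pairs [][2]int) [10][10]int {
	var confusion [10][10]int
	for _, p := range pairs {
		confusion[p[0]][p[1]]++
	}
	return confusion
}

func main() {
	// The three mistakes discussed in the text: a 6 read as a 0,
	// and two 2s read as a 4 and a 7.
	confusion := tally([][2]int{{6, 0}, {2, 4}, {2, 7}})
	fmt.Println(confusion[2][4]) // how often a true 2 was predicted as a 4
}
```

A large off-diagonal entry (say, many 4s predicted as 9s) points at a systematic weakness, which is far more actionable than eyeballing a handful of images.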
