Another way we could test how well the neural network is learning is to cross-validate. The neural network could learn very well on the training data, in essence, memorizing which collections of pixels will result in a particular label. However, to check that the machine learning algorithm generalizes well, we need to show the neural network some data it's never seen before.
Here's the code to do so:
	log.Printf("Start testing")
	testImgs, err := readImageFile(os.Open("t10k-images.idx3-ubyte"))
	if err != nil {
		log.Fatal(err)
	}
	testlabels, err := readLabelFile(os.Open("t10k-labels.idx1-ubyte"))
	if err != nil {
		log.Fatal(err)
	}
	testData := prepareX(testImgs)
	testLbl := prepareY(testlabels)
	shape := testData.Shape()
	testData2, err := zca(testData)
	if err != nil {
		log.Fatal(err)
	}
	visualize(testData, 10, 10, "testData.png")
	visualize(testData2, 10, 10, "testData2.png")

	var correct, total float64
	var oneimg, onelabel tensor.Tensor
	var predicted, errcount int
	for i := 0; i < shape[0]; i++ {
		if oneimg, err = testData.Slice(makeRS(i, i+1)); err != nil {
			log.Fatalf("Unable to slice one image %d", i)
		}
		if onelabel, err = testLbl.Slice(makeRS(i, i+1)); err != nil {
			log.Fatalf("Unable to slice one label %d", i)
		}
		if predicted, err = nn.Predict(oneimg); err != nil {
			log.Fatalf("Failed to predict %d", i)
		}
		label := argmax(onelabel.Data().([]float64))
		if predicted == label {
			correct++
		} else if errcount < 5 {
			visualize(oneimg, 1, 1, fmt.Sprintf("%d_%d_%d.png", i, label, predicted))
			errcount++
		}
		total++
	}
	fmt.Printf("Correct/Totals: %v/%v = %1.3f\n", correct, total, correct/total)
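The makeRS helper used to slice out a single row is defined elsewhere in the chapter. As a rough sketch, it can be a small value type satisfying Gorgonia's tensor.Slice interface (Start, End, Step); the exact implementation here is an assumption:

```go
package main

import "fmt"

// rangedSlice describes the half-open row range [start, end).
// It satisfies Gorgonia's tensor.Slice interface (Start, End, Step),
// so it can be passed to a Tensor's Slice method.
type rangedSlice struct{ start, end int }

func (r rangedSlice) Start() int { return r.start }
func (r rangedSlice) End() int   { return r.end }
func (r rangedSlice) Step() int  { return 1 }

// makeRS builds a rangedSlice; makeRS(i, i+1) selects the i-th row.
func makeRS(start, end int) rangedSlice { return rangedSlice{start, end} }

func main() {
	rs := makeRS(3, 4)
	fmt.Println(rs.Start(), rs.End(), rs.Step()) // prints 3 4 1
}
```

With a step of 1 hardcoded, makeRS(i, i+1) is exactly "give me row i", which is how the loop above peels off one image and one label at a time.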
Note that this code is largely the same as the earlier code in the main function. The exception is that instead of calling nn.Train, we call nn.Predict, and then check whether the predicted label matches the true one.
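The argmax helper that recovers the true label from its one-hot encoding isn't shown in this listing. A minimal sketch of what it might look like (the name and signature are taken from the call in the loop; the body is an assumption):

```go
package main

import "fmt"

// argmax returns the index of the largest value in the slice.
// For a one-hot encoded label, that index is the digit class.
func argmax(a []float64) int {
	maxIdx := 0
	for i, v := range a {
		if v > a[maxIdx] {
			maxIdx = i
		}
	}
	return maxIdx
}

func main() {
	// A one-hot vector encoding the digit 3.
	label := []float64{0, 0, 0, 1, 0, 0, 0, 0, 0, 0}
	fmt.Println(argmax(label)) // prints 3
}
```

The same function works on the network's raw output vector: whichever of the ten outputs is largest is taken as the predicted digit.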
Here are the tweakable parameters:
After tweaking various parameters and running the code (it takes about 6.5 minutes), I got the following results:
$ go build -o chapter7 .
$ ./chapter7
Correct/Totals: 9719/10000 = 0.972
A simple three-layer neural network yields 97% accuracy! This is, of course, nowhere near state of the art. In the next chapter, we'll build one that goes up to 99.xx%, but it requires a big shift in mindset.
Furthermore, let's have a look at a few of the images that were wrongly classified. In the preceding code, this snippet writes out the first five wrong predictions:
if predicted == label {
	correct++
} else if errcount < 5 {
	visualize(oneimg, 1, 1, fmt.Sprintf("%d_%d_%d.png", i, label, predicted))
	errcount++
}
And here they are:
In the first image, the neural network classified it as a 0, while the true value is 6. As you can see, it's an easy mistake to make. The second image shows a 2, which the neural network classified as a 4; you may be inclined to agree that it looks a bit like a 4. And lastly, if you are an American reader, chances are you have been exposed to the Palmer handwriting method. If so, I'll bet that you might classify the last picture as a 7 instead of a 2, which is exactly what the neural network predicts. Unfortunately, the real label is 2. Some people just have terrible handwriting.