Classifying the MNIST dataset using random forest

In this section, we will show an example of classification using random forest. We will break down the code step by step so that you can understand the solution easily.

Step 1. Load and parse the MNIST dataset in LIBSVM format

import org.apache.spark.mllib.util.MLUtils

// Load training data in LIBSVM format.
val data = MLUtils.loadLibSVMFile(spark.sparkContext, "data/mnist.bz2")

Step 2. Prepare the training and test sets

Split the data into training (75%) and test (25%) sets, and set the seed for reproducibility, as follows:

val splits = data.randomSplit(Array(0.75, 0.25), seed = 12345L)
val training = splits(0).cache()
val test = splits(1)

Step 3. Run the training algorithm to build the model

Train a random forest model with an empty categoricalFeaturesInfo. This is required since all the features in the dataset are continuous:

import org.apache.spark.mllib.tree.RandomForest

val numClasses = 10 // number of classes in the MNIST dataset
val categoricalFeaturesInfo = Map[Int, Int]()
val numTrees = 50 // use more in practice; more is generally better
val featureSubsetStrategy = "auto" // let the algorithm choose
val impurity = "gini" // see the earlier notes on random forest for an explanation
val maxDepth = 30 // deeper trees often help, at the cost of training time
val maxBins = 32 // more bins often help, at the cost of memory
val model = RandomForest.trainClassifier(training, numClasses, categoricalFeaturesInfo,
  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

Note that training a random forest model is resource-intensive: with many deep trees it can consume a lot of memory, so beware of out-of-memory (OOM) errors. Consider increasing the Java heap space before running this code.
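One way to raise the heap, shown as a sketch: pass larger memory settings to spark-submit. The memory sizes, class name, and JAR path below are illustrative placeholders; adjust them to your application and cluster.

```shell
# Raise the JVM heap for the driver (where trained trees are collected)
# and for the executors. Values here are examples, not recommendations.
spark-submit \
  --driver-memory 8g \
  --executor-memory 8g \
  --class com.example.MnistRandomForest \
  target/mnist-rf.jar
```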

Step 4. Compute raw scores on the test set

Compute raw scores on the test set so that we can evaluate the model using the aforementioned performance metrics, as follows:

val scoreAndLabels = test.map { point =>
  val score = model.predict(point.features)
  (score, point.label)
}

Step 5. Instantiate multiclass metrics for the evaluation

import org.apache.spark.mllib.evaluation.MulticlassMetrics

val metrics = new MulticlassMetrics(scoreAndLabels)

Step 6. Construct the confusion matrix

println("Confusion matrix:")
println(metrics.confusionMatrix)

The preceding code prints the following confusion matrix for our classification:

Figure 31: Confusion matrix generated by the random forest classifier
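To see how the per-label metrics in the next step fall out of this matrix, here is a minimal, self-contained sketch in plain Scala (the 2x2 matrix values are made up for illustration and are not from the MNIST run). Rows are true labels and columns are predicted labels, so precision for class j is the diagonal entry divided by its column sum, and recall for class i is the diagonal entry divided by its row sum:

```scala
// Toy confusion matrix: rows = true labels, columns = predicted labels.
val m = Array(
  Array(50.0, 2.0), // true class 0: 50 correct, 2 misclassified as 1
  Array(3.0, 45.0)  // true class 1: 3 misclassified as 0, 45 correct
)
// Precision(0) = correct 0-predictions / all 0-predictions (column sum)
val precision0 = m(0)(0) / (m(0)(0) + m(1)(0))
// Recall(0) = correct 0-predictions / all true 0s (row sum)
val recall0 = m(0)(0) / (m(0)(0) + m(0)(1))
```

MulticlassMetrics computes the same ratios for all ten MNIST classes from the matrix printed above.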

Step 7. Overall statistics

Now let's compute the overall statistics to judge the performance of the model:

val accuracy = metrics.accuracy
println("Summary Statistics")
println(s"Accuracy = $accuracy")
// Precision by label
val labels = metrics.labels
labels.foreach { l =>
  println(s"Precision($l) = " + metrics.precision(l))
}
// Recall by label
labels.foreach { l =>
  println(s"Recall($l) = " + metrics.recall(l))
}
// False positive rate by label
labels.foreach { l =>
  println(s"FPR($l) = " + metrics.falsePositiveRate(l))
}
// F-measure by label
labels.foreach { l =>
  println(s"F1-Score($l) = " + metrics.fMeasure(l))
}

The preceding code segment produces the following output, containing per-label performance metrics such as accuracy, precision, recall, false positive rate, and F1 score:

Summary Statistics:
------------------------------
Precision(0.0) = 0.9861932938856016
Precision(1.0) = 0.9891799544419134
.
.
Precision(8.0) = 0.9546079779917469
Precision(9.0) = 0.9474747474747475

Recall(0.0) = 0.9778357235984355
Recall(1.0) = 0.9897435897435898
.
.
Recall(8.0) = 0.9442176870748299
Recall(9.0) = 0.9449294828744124

FPR(0.0) = 0.0015387997362057595
FPR(1.0) = 0.0014151646059883808
.
.
FPR(8.0) = 0.0048136532710962
FPR(9.0) = 0.0056967572304995615

F1-Score(0.0) = 0.9819967266775778
F1-Score(1.0) = 0.9894616918256907
.
.
F1-Score(8.0) = 0.9493844049247605
F1-Score(9.0) = 0.9462004034969739

Now let's compute the overall statistics, as follows:

println(s"Weighted precision: ${metrics.weightedPrecision}")
println(s"Weighted recall: ${metrics.weightedRecall}")
println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")
val testErr = scoreAndLabels.filter(r => r._1 != r._2).count.toDouble / test.count()
println("Accuracy = " + (1 - testErr) * 100 + " %")

The preceding code segment prints the following output, containing weighted precision, recall, F1 score, and false positive rate:

Overall statistics
----------------------------
Weighted precision: 0.966513107682512
Weighted recall: 0.9664712469534286
Weighted F1 score: 0.9664794711607312
Weighted false positive rate: 0.003675328222679072
Accuracy = 96.64712469534287 %

The overall statistics show that the model's accuracy is above 96%, which is better than that of logistic regression. However, we can still improve it with better model tuning.
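One simple way to tune, sketched below under the assumption that the variables from Step 3 are still in scope: re-train with a few candidate settings for numTrees and maxDepth, and keep the combination that scores best on a held-out validation split. The candidate values here are illustrative, not recommendations.

```scala
// Split the training set further to get a validation set for tuning.
val Array(trainSet, validSet) = training.randomSplit(Array(0.8, 0.2), seed = 12345L)

// A small, illustrative grid over two hyperparameters.
val candidates = for (trees <- Seq(50, 100); depth <- Seq(20, 30)) yield (trees, depth)

val (bestParams, bestAcc) = candidates.map { case (trees, depth) =>
  val m = RandomForest.trainClassifier(trainSet, numClasses, categoricalFeaturesInfo,
    trees, featureSubsetStrategy, impurity, depth, maxBins)
  // Accuracy on the held-out validation split.
  val acc = validSet.filter(p => m.predict(p.features) == p.label).count.toDouble /
    validSet.count()
  ((trees, depth), acc)
}.maxBy(_._2)

println(s"Best (numTrees, maxDepth) = $bestParams with validation accuracy $bestAcc")
```

After choosing the best setting, you would re-train on the full training set and report the final score on the untouched test set.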
