With SVMs

The Shogun library also implements the multi-class SVM algorithm in the CMulticlassLibSVM class. Instances of this class are configured with a kernel object and with a parameter named C, which controls how strongly misclassified training points are penalized (a larger C tolerates fewer training errors). In the following example, we use an instance of the CGaussianKernel class as the kernel object. This object also has configuration parameters, but we tuned only one of them, combined_kernel_weight, because after a series of experiments it gave the most reasonable configuration for our model. Let's look at the code in the following block:

Some<CDenseFeatures<DataType>> features;
Some<CMulticlassLabels> labels;
Some<CDenseFeatures<DataType>> test_features;
Some<CMulticlassLabels> test_labels;

These are the definitions of our training and test dataset objects.
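For reference, the textbook soft-margin SVM objective shows where C enters: it weights the slack variables \xi_i, so a larger C penalizes margin violations and misclassified training points more heavily. The second formula is the Gaussian kernel as Shogun's CGaussianKernel documentation defines it, with width \tau; the value 5 passed to the kernel constructor in the next block is this width. Both are standard formulas added here for illustration, not derived in this chapter:

```latex
\min_{\mathbf{w},\, b,\, \boldsymbol{\xi}} \; \frac{1}{2}\lVert \mathbf{w} \rVert^2 + C \sum_{i=1}^{n} \xi_i
\quad \text{s.t.} \quad y_i\left(\mathbf{w}^\top \phi(\mathbf{x}_i) + b\right) \ge 1 - \xi_i, \quad \xi_i \ge 0

k(\mathbf{x}, \mathbf{x}') = \exp\left(-\frac{\lVert \mathbf{x} - \mathbf{x}' \rVert^2}{\tau}\right)
```

The objective is written for a single binary SVM; a multi-class LibSVM model combines several such binary problems.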

auto kernel = some<CGaussianKernel>(features, features, 5);
// one vs one classification
auto svm = some<CMulticlassLibSVM>();
svm->set_kernel(kernel);

Using these datasets, we initialized the CMulticlassLibSVM class object and configured its kernel.
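The comment in the code notes that CMulticlassLibSVM performs one-vs-one classification. The following is a minimal sketch of that voting scheme, not Shogun's implementation: for K classes, one binary classifier is trained per pair of classes, and each casts a vote at prediction time. PairClassifier is a hypothetical stand-in for the trained binary SVMs.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical binary decision: returns true if the (a, b) classifier
// assigns sample x to class a rather than class b.
using PairClassifier =
    std::function<bool(int a, int b, const std::vector<double>& x)>;

// One-vs-one needs K * (K - 1) / 2 pairwise classifiers.
int num_pairwise_classifiers(int num_classes) {
  assert(num_classes > 1);
  return num_classes * (num_classes - 1) / 2;
}

// Each pairwise classifier votes for one of its two classes; the class
// with the most votes wins (ties go to the lowest class index).
int predict_one_vs_one(int num_classes, const PairClassifier& clf,
                       const std::vector<double>& x) {
  assert(num_classes > 1);
  std::vector<int> votes(static_cast<std::size_t>(num_classes), 0);
  for (int a = 0; a < num_classes; ++a) {
    for (int b = a + 1; b < num_classes; ++b) {
      if (clf(a, b, x)) {
        ++votes[static_cast<std::size_t>(a)];
      } else {
        ++votes[static_cast<std::size_t>(b)];
      }
    }
  }
  int best = 0;
  for (int c = 1; c < num_classes; ++c) {
    if (votes[static_cast<std::size_t>(c)] > votes[static_cast<std::size_t>(best)]) {
      best = c;
    }
  }
  return best;
}
```

With four classes, for example, this scheme trains six pairwise classifiers, which is one reason one-vs-one training time grows quadratically with the number of classes.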

// search for hyper-parameters
auto root = some<CModelSelectionParameters>();
// C - how much you want to avoid misclassifying
CModelSelectionParameters* c = new CModelSelectionParameters("C");
root->append_child(c);
c->build_values(1.0, 1000.0, R_LINEAR, 100.);

auto params_kernel = some<CModelSelectionParameters>("kernel", kernel);
root->append_child(params_kernel);

auto params_kernel_width =
    some<CModelSelectionParameters>("combined_kernel_weight");
params_kernel_width->build_values(0.1, 10.0, R_LINEAR, 0.5);

params_kernel->append_child(params_kernel_width);

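Then, we built the tree of model-selection parameters that defines the search space: a linear range of values for C and, as a child of the kernel node, a linear range of values for the kernel's combined_kernel_weight parameter.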
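To get a feel for how large this search space is, the sketch below enumerates linear ranges and counts the resulting training runs. It assumes that a range built with build_values(min, max, R_LINEAR, step) contains every value min + i * step that does not exceed max; that is an assumption about Shogun's behavior, and linear_range and grid_search_fits are hypothetical helpers, not library functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: min, min + step, ..., up to the largest value <= max.
// Assumed (not confirmed) to match build_values(min, max, R_LINEAR, step).
std::vector<double> linear_range(double min, double max, double step) {
  assert(step > 0.0 && max >= min);
  const auto count =
      static_cast<std::size_t>(std::floor((max - min) / step)) + 1;
  std::vector<double> values;
  values.reserve(count);
  for (std::size_t i = 0; i < count; ++i) {
    values.push_back(min + static_cast<double>(i) * step);
  }
  return values;
}

// A full grid search evaluates the Cartesian product of all ranges,
// and each combination is trained once per cross-validation fold.
std::size_t grid_search_fits(const std::vector<std::vector<double>>& ranges,
                             std::size_t num_folds) {
  std::size_t combinations = 1;
  for (const auto& r : ranges) {
    combinations *= r.size();
  }
  return combinations * num_folds;
}
```

Under this assumption, C takes 10 values (1, 101, ..., 901) and combined_kernel_weight takes 20 values (0.1, 0.6, ..., 9.6), giving 200 combinations; with k = 3 folds, that is 600 SVM trainings for a single run.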

index_t k = 3;
CStratifiedCrossValidationSplitting* splitting =
    new CStratifiedCrossValidationSplitting(labels, k);

auto eval_criterium = some<CMulticlassAccuracy>();

auto cross =
    some<CCrossValidation>(svm, features, labels, splitting,
                           eval_criterium);
cross->set_num_runs(1);

auto model_selection = some<CGridSearchModelSelection>(cross, root);
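// keep the result in the Some<> wrapper returned by wrap(); assigning it to a
// raw pointer would let the temporary wrapper release the object immediately
auto best_params = wrap(model_selection->select_model(false));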
best_params->apply_to_machine(svm);
best_params->print_tree();

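Having configured the model-selection parameters, we created a stratified splitting into k = 3 folds, initialized the CCrossValidation class object with multiclass accuracy as the evaluation criterion, and ran the grid search; the best parameter combination found was then applied to the SVM.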
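Stratified splitting keeps the class proportions of every fold close to those of the whole dataset, so each validation fold sees all classes. The sketch below illustrates the idea by dealing the samples of each class round-robin across k folds; it is not Shogun's CStratifiedCrossValidationSplitting, and unlike most real implementations it does not shuffle the samples first.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Returns k folds of sample indices. Samples are grouped by class and dealt
// round-robin; the fold counter carries over between classes so leftover
// samples do not all pile up in fold 0.
std::vector<std::vector<std::size_t>> stratified_folds(
    const std::vector<int>& labels, std::size_t k) {
  assert(k > 0);
  std::map<int, std::vector<std::size_t>> by_class;
  for (std::size_t i = 0; i < labels.size(); ++i) {
    by_class[labels[i]].push_back(i);
  }

  std::vector<std::vector<std::size_t>> folds(k);
  std::size_t next_fold = 0;
  for (const auto& entry : by_class) {
    for (std::size_t idx : entry.second) {
      folds[next_fold].push_back(idx);
      next_fold = (next_fold + 1) % k;
    }
  }
  return folds;
}
```

In each cross-validation round, one fold serves as the validation set and the remaining k - 1 folds are used for training; the reported score is the average over the k rounds.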

// train SVM
svm->set_labels(labels);
svm->train(features);

// evaluate model on test data
auto new_labels = wrap(svm->apply_multiclass(test_features));

// estimate accuracy
auto accuracy = eval_criterium->evaluate(new_labels, test_labels);
// `name` is assumed to hold the current dataset's name; it is defined
// outside this excerpt
std::cout << "svm " << name << " accuracy = " << accuracy << std::endl;

// process results
auto feature_matrix = test_features->get_feature_matrix();
for (index_t i = 0; i < new_labels->get_num_labels(); ++i) {
  auto label_idx_pred = new_labels->get_label(i);
  auto vector = feature_matrix.get_column(i);
  ...
}

After the best hyperparameters were found and applied to the model, we retrained it on the full training set and evaluated it on the test data.
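The accuracy printed above is, as we understand the CMulticlassAccuracy criterion, the fraction of test samples whose predicted label matches the true label. A minimal stand-alone sketch of that computation follows; it uses int labels for simplicity, whereas Shogun's CMulticlassLabels stores labels as floating-point values.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Fraction of samples whose predicted label equals the ground-truth label.
double multiclass_accuracy(const std::vector<int>& predicted,
                           const std::vector<int>& truth) {
  assert(predicted.size() == truth.size() && !truth.empty());
  std::size_t correct = 0;
  for (std::size_t i = 0; i < truth.size(); ++i) {
    if (predicted[i] == truth[i]) {
      ++correct;
    }
  }
  return static_cast<double>(correct) / static_cast<double>(truth.size());
}
```

For example, if three of four test points are labeled correctly, the function returns 0.75.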

Notice that, apart from the model-specific parameters, the code is the same as in the previous example: we created the parameter tree, the cross-validation object, and the same evaluation metric object, and used grid search to find the best combination of model parameters. Then, we trained the model and used the apply_multiclass() method for evaluation. This shows that the library provides a unified API across different algorithms, which lets us try different models with minimal changes to the code.

The following screenshot shows the results of applying the Shogun implementation of the SVM algorithm to our datasets:

Notice that the SVM also made classification errors: there are incorrect labels in Dataset 2, Dataset 3, and Dataset 4, while the other datasets were classified almost correctly.
