Learning to use decision trees

We have already seen the power and flexibility of decision trees for adding a decision-making component to our games. They can also be built dynamically through supervised learning, which is why we are revisiting them in this chapter.

There are several algorithms for building decision trees that are suited for different uses such as prediction and classification. In our case, we'll explore decision-tree learning by implementing the ID3 algorithm.

Getting ready…

Although we built decision trees in a previous chapter, and the trees we are about to build rest on the same principles, the learning algorithm calls for different data types, so we will define new ones for this implementation.

We will need two data types: one for the decision nodes and one for storing the examples to be learned.

The code for the DecisionNode data type is as follows:

using System.Collections.Generic;

public class DecisionNode
{
    public string testValue;
    public Dictionary<float, DecisionNode> children;

    public DecisionNode(string testValue = "")
    {
        this.testValue = testValue;
        children = new Dictionary<float, DecisionNode>();
    }
}

The code for the Example data type is as follows:

using UnityEngine;
using System.Collections.Generic;

public enum ID3Action
{
    STOP, WALK, RUN
}

public class ID3Example : MonoBehaviour
{
    public ID3Action action;
    // Unity does not serialize dictionaries, so this must be
    // populated from code rather than the Inspector
    public Dictionary<string, float> values =
            new Dictionary<string, float>();

    public float GetValue(string attribute)
    {
        return values[attribute];
    }
}
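Because the values dictionary cannot be filled in from the Inspector, the examples have to be set up from code. The following is a minimal sketch of preparing one example; the attribute names and values are invented for illustration:

```csharp
// Hypothetical setup code; "distance" and "health" are
// made-up attribute names used only for this sketch
ID3Example e = gameObject.AddComponent<ID3Example>();
e.action = ID3Action.RUN;
e.values = new Dictionary<string, float>();
e.values.Add("distance", 12f);
e.values.Add("health", 0.8f);
```

Every example intended for the same tree should use the same set of attribute names, since GetValue indexes the dictionary directly.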

How to do it…

We will create the ID3 class with several functions for computing the resulting decision tree.

  1. Create the ID3 class:
    using UnityEngine;
    using System.Collections.Generic;
    public class ID3 : MonoBehaviour
    {
        // next steps
    }
  2. Start the implementation of the function responsible for splitting the examples into sets according to the value of a given attribute:
    public Dictionary<float, List<ID3Example>> SplitByAttribute(
            ID3Example[] examples,
            string attribute)
    {
        Dictionary<float, List<ID3Example>> sets;
        sets = new Dictionary<float, List<ID3Example>>();
        // next step
    }
  3. Iterate through all the examples received, extracting each one's value for the given attribute in order to assign it to a set:
    foreach (ID3Example e in examples)
    {
        float key = e.GetValue(attribute);
        if (!sets.ContainsKey(key))
            sets.Add(key, new List<ID3Example>());
        sets[key].Add(e);
    }
    return sets;
  4. Create the function for computing the entropy for a set of examples:
    public float GetEntropy(ID3Example[] examples)
    {
        if (examples.Length == 0) return 0f;
        int numExamples = examples.Length;
        Dictionary<ID3Action, int> actionTallies;
        actionTallies = new Dictionary<ID3Action, int>();
        // next steps
    }
  5. Iterate through all of the examples, tallying how many times each action appears:
    foreach (ID3Example e in examples)
    {
        if (!actionTallies.ContainsKey(e.action))
            actionTallies.Add(e.action, 0);
        actionTallies[e.action]++;
    }
  6. Compute the entropy:
    int actionCount = actionTallies.Keys.Count;
    if (actionCount == 0) return 0f;
    float entropy = 0f;
    float proportion = 0f;
    foreach (int tally in actionTallies.Values)
    {
        proportion = tally / (float)numExamples;
        entropy -= proportion * Mathf.Log(proportion, 2);
    }
    return entropy;
  7. Implement the function for computing the weighted entropy across all the sets of examples. This is very similar to the preceding one; in fact, it uses it:
    public float GetEntropy(
            Dictionary<float, List<ID3Example>> sets,
            int numExamples)
    {
        float entropy = 0f;
        foreach (List<ID3Example> s in sets.Values)
        {
            float proportion;
            proportion = s.Count / (float)numExamples;
            // weighted average of each set's entropy
            entropy += proportion * GetEntropy(s.ToArray());
        }
        return entropy;
    }
  8. Define the function for building a decision tree:
    public void MakeTree(
            ID3Example[] examples,
            List<string> attributes,
            DecisionNode node)
    {
        float initEntropy = GetEntropy(examples);
        if (initEntropy <= 0) return;
        // next steps
    }
  9. Declare and initialize all the required members for the task:
    int numExamples = examples.Length;
    float bestInfoGain = 0f;
    string bestSplitAttribute = "";
    float infoGain = 0f;
    float overallEntropy = 0f;
    Dictionary<float, List<ID3Example>> bestSets;
    bestSets = new Dictionary<float, List<ID3Example>>();
    Dictionary<float, List<ID3Example>> sets;
  10. Iterate through all the attributes to find the split that yields the highest information gain:
    foreach (string a in attributes)
    {
        sets = SplitByAttribute(examples, a);
        overallEntropy = GetEntropy(sets, numExamples);
        infoGain = initEntropy - overallEntropy;
        if (infoGain > bestInfoGain)
        {
            bestInfoGain = infoGain;
            bestSplitAttribute = a;
            bestSets = sets;
        }
    }
  11. Assign the best split attribute to the current node, and remove it from the set of attributes available for building the rest of the tree:
    node.testValue = bestSplitAttribute;
    List<string> newAttributes = new List<string>(attributes);
    newAttributes.Remove(bestSplitAttribute);
  12. Iterate through the sets produced by the best split, calling the function recursively to build each child node:
    foreach (List<ID3Example> set in bestSets.Values)
    {
        float val = set[0].GetValue(bestSplitAttribute);
        DecisionNode child = new DecisionNode();
        node.children.Add(val, child);
        MakeTree(set.ToArray(), newAttributes, child);
    }
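With the class complete, a tree can be built as in the following sketch. The attribute names here are assumptions for illustration; they must match keys present in every example's values dictionary:

```csharp
// Hypothetical usage of the ID3 class from another MonoBehaviour;
// "distance" and "health" are made-up attribute names
ID3 id3 = gameObject.AddComponent<ID3>();
ID3Example[] examples = GetComponents<ID3Example>();
List<string> attributes = new List<string>() { "distance", "health" };
DecisionNode root = new DecisionNode();
id3.MakeTree(examples, attributes, root);
// root.testValue now holds the most informative attribute
```

The resulting tree can then be walked by reading each node's testValue, looking up that attribute on the agent, and following the matching entry in children until a leaf is reached.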

How it works…

The class is modular in terms of functionality: it doesn't store any information, but it can compute and retrieve everything that the tree-building function needs. SplitByAttribute divides the examples into sets, one per value of the given attribute, so that their entropy can be computed. GetEntropy is an overloaded function: one version computes the entropy of a list of examples, and the other the weighted entropy across all the sets of examples, using the formulae defined by the ID3 algorithm. Finally, MakeTree works recursively in order to build the decision tree, at each level getting hold of the most significant attribute, that is, the one with the highest information gain.
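For reference, the formulae implemented by the two GetEntropy overloads and used by MakeTree are the standard ID3 definitions:

```latex
H(S) = -\sum_{a \in A} p_a \log_2 p_a
\qquad
H_{split} = \sum_{v} \frac{|S_v|}{|S|}\, H(S_v)
\qquad
\mathit{Gain} = H(S) - H_{split}
```

Here, p_a is the proportion of examples in S whose action is a, and S_v is the subset of examples whose chosen attribute takes the value v. MakeTree picks the attribute that maximizes the gain.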

See also

  • Chapter 3, Decision Making, the Choosing through a decision tree recipe