
Decision Trees Part 1: Representation

Now that we have data structures describing graphs, we can put them to work and do some machine learning! We’re going to look at decision trees. The Wikipedia page on decision tree learning has an example from our favourite data set, the Titanic survivors, so let’s steal their picture to illustrate.

The idea is that you start at the root. Each internal node describes an attribute; here the root attribute is ‘sex’. The downward edges from a node represent a partition of that attribute’s values. In a binary tree this just means a yes-or-no question, whose answer leads you to another node, and so on, until you arrive at a leaf.
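As a rough illustration of that walk, here is a plain-Python sketch (no tigraphs or pandas) in which each internal node is a dict holding a yes/no question and each leaf just holds a label. The tree shape, the ‘age’ threshold and the leaf labels are made up for illustration; they are not read off the picture.

```python
# Hypothetical hand-built tree: internal nodes hold a yes/no question,
# leaves hold a label. Attributes and threshold are illustrative only.
tree = {
    "question": lambda p: p["sex"] == "male",
    "yes": {
        "question": lambda p: p["age"] > 10,
        "yes": {"leaf": "leaf A"},
        "no": {"leaf": "leaf B"},
    },
    "no": {"leaf": "leaf C"},
}


def find_leaf(node, point):
    """Start at the root and follow the answers down until a leaf is reached."""
    while "leaf" not in node:
        node = node["yes"] if node["question"](point) else node["no"]
    return node["leaf"]


print(find_leaf(tree, {"sex": "male", "age": 30}))    # leaf A
print(find_leaf(tree, {"sex": "female", "age": 30}))  # leaf C
```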

Each leaf represents a subset of the data, and in the picture the survival rate for each subset is given. So if you have a bunch of training data (where the response is known), you can let each leaf make a prediction by taking the mode of the responses that land in it (or the mean, for a numeric response).
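A minimal sketch of that rule, assuming the training responses that reach a leaf are already collected in a plain list (the helper name `leaf_prediction` is ours, for illustration):

```python
from collections import Counter
from statistics import mean


def leaf_prediction(responses, numeric):
    """Prediction for one leaf from the training responses that reach it:
    the mean for a numeric response, otherwise the mode (most common value)."""
    if numeric:
        return mean(responses)
    # most_common breaks ties in favour of the value seen first
    return Counter(responses).most_common(1)[0][0]


print(leaf_prediction([0, 1, 1, 0, 1], numeric=False))  # 1
print(leaf_prediction([2.0, 3.0, 7.0], numeric=True))   # 4.0
```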

A regression tree is where the response is numeric, and a classification tree is where the response is categorical.

The real challenge is coming up with a tree to use as your model, but the first order of business is to be able to represent a decision tree. So let’s get to work.

To simplify things we’ll insist that our trees be binary. We don’t lose any generality by doing so, since any multiway split can be rewritten as a sequence of binary splits.
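To see why binary splits are enough, here is a sketch that answers a three-way categorical split using only a chain of yes/no set-membership questions. The ‘embarked’-style attribute with port codes C, Q and S is just an example, and `route_multiway` is a hypothetical helper.

```python
def route_multiway(value, groups):
    """Send `value` down a k-way split, given as a list of disjoint sets,
    using only binary questions of the form 'is value in this set?'.
    Each 'no' moves one level deeper; the last group needs no question,
    so any value not in an earlier group ends up there."""
    for branch, group in enumerate(groups[:-1]):
        if value in group:
            return branch
    return len(groups) - 1


ports = [{"C"}, {"Q"}, {"S"}]  # three-way split on an 'embarked'-style attribute
print([route_multiway(p, ports) for p in ["C", "Q", "S"]])  # [0, 1, 2]
```

A k-way split therefore costs k - 1 binary nodes chained down one side of the tree.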

The basic idea is that each node holds a simple boolean test on an attribute. For a numeric attribute the test returns true if datapoint.attribute <= pivot and false otherwise; for a categorical attribute it returns true if datapoint.attribute is in pivot (a set of categories) and false otherwise.
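The test itself fits in a few lines. This sketch assumes, as the classes below do, that a categorical pivot is a Python set and anything else is a numeric threshold; `node_test` is a hypothetical name.

```python
def node_test(value, pivot):
    """The boolean test at a node: membership for a categorical pivot (a set),
    `value <= pivot` for a numeric one."""
    if isinstance(pivot, (set, frozenset)):
        return value in pivot
    return value <= pivot


print(node_test("female", {"female"}))  # True
print(node_test(31.0, 30.0))            # False
```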

Thus each node has a local ‘filter’ on the dataset, which determines its local data. Similarly we can ask a node: to which of your children does this datapoint belong?
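One way to picture a node’s local data: apply, in order, the test of every ancestor on the path from the root, keeping the rows that went the same way the path does. The list-of-dicts toy rows, the `(attribute, pivot, went_left)` path format and the helper names are all assumptions of this sketch, not part of the classes below.

```python
def node_test(value, pivot):
    # membership for categorical pivots (sets), `<=` for numeric ones
    if isinstance(pivot, (set, frozenset)):
        return value in pivot
    return value <= pivot


def local_data(rows, path):
    """Rows reaching a node; `path` lists (attribute, pivot, went_left) for
    each split from the root down to that node."""
    for attribute, pivot, went_left in path:
        rows = [r for r in rows if node_test(r[attribute], pivot) == went_left]
    return rows


def which_child(row, attribute, pivot):
    """Answer 'to which of your children does this datapoint belong?'"""
    return "left" if node_test(row[attribute], pivot) else "right"


rows = [  # toy rows, not real passengers
    {"sex": "female", "age": 8},
    {"sex": "female", "age": 40},
    {"sex": "male", "age": 22},
]
# the node reached by going left on sex, then right on age <= 18
print(local_data(rows, [("sex", {"female"}, True), ("age", 18, False)]))
# [{'sex': 'female', 'age': 40}]
print(which_child(rows[2], "sex", {"female"}))  # right
```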

I want the nodes themselves to keep track of a lot of this, as it feels more natural, so I begin by creating a subclass, DecisionNode.

It is designed to accept a pandas DataFrame so we can help ourselves to all of its goodies.

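import tigraphs as tig
import numpy as np
import pandas as pd

class DecisionNode(tig.BasicNode, object):
    def __init__(self, **kwargs):
        super(DecisionNode, self).__init__(**kwargs)
        self.left = None        # child for datapoints that pass the test
        self.right = None       # child for datapoints that fail the test
        self.prediction = None  # filled in once the node has local data
        self.local_data = None  # the subset of the data reaching this node
        self.size = 0
        self.depth = 0
    def local_filter(self, data): # filters data; subclasses decide how
        pass
    def get_next_node_or_predict(self, datapoint): # subclasses decide how
        pass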

As you can see, I have kept it non-specific about how it filters, so we can reuse it later. Since it will live in a binary tree, I’ve given it some convenient instance attributes to keep track of the left child (corresponding to passing the test) and the right child (corresponding to failing the test).

Now we’ll create a binary decision tree class, inheriting from the n-ary tree class we made last time and using our new DecisionNode class as its vertex type.

class DecisionTree(tig.return_nary_tree_class(directed=True), object):
    def __init__(self, data=None,response='',Vertex=DecisionNode,**kwargs):
        super(DecisionTree, self).__init__(N=2, Vertex=Vertex, **kwargs)
        self.data =data
        self.response=response #data attribute we're trying to predict
    def split_vertex(self, vertex):
        super(DecisionTree, self).split_vertex(vertex)
        vertex.left = vertex.children[0]
        vertex.right = vertex.children[1]
    def fuse_vertex(self, vertex):
        super(DecisionTree, self).fuse_vertex(vertex)
        vertex.left, vertex.right = None, None
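Since tigraphs is our own library, here is a dependency-free stand-in showing what splitting and fusing do to the tree: splitting gives a leaf two children and names them left and right, and fusing removes them again. The `SketchNode` class and the `split`, `fuse` and `leaves` helpers are hypothetical and are not the tigraphs API.

```python
class SketchNode:
    """Hypothetical minimal binary-tree node, standing in for DecisionNode."""

    def __init__(self, parent=None):
        self.parent = parent
        self.children = None          # None means this node is a leaf
        self.left = self.right = None
        self.depth = 0 if parent is None else parent.depth + 1


def split(node):
    # give a leaf two children; the first is 'left', the second 'right'
    node.children = [SketchNode(node), SketchNode(node)]
    node.left, node.right = node.children


def fuse(node):
    # throw away the children (and everything below them)
    node.children = None
    node.left = node.right = None


def leaves(node):
    if node.children is None:
        return [node]
    return [leaf for child in node.children for leaf in leaves(child)]


root = SketchNode()
split(root)
split(root.left)
print(len(leaves(root)), root.left.left.depth)  # 3 2
fuse(root.left)
print(len(leaves(root)))  # 2
```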

That was pretty straightforward. Now we’ll be more specific and introduce the pivots to chop up the data with.

class PivotDecisionNode(DecisionNode, object):
    def __init__(self, **kwargs):
        super(PivotDecisionNode, self).__init__(**kwargs)
        self.split_attribute = None
        self.pivot = None
    def local_filter(self, data): # filters the data based on parent's pivot
        if self.parent is None:   # the root keeps everything
            self.size = len(data)
            return data
        attribute = self.parent.split_attribute
        pivot = self.parent.pivot
        if type(pivot) == set:    # categorical: membership test
            passed = data[attribute].isin(pivot)
        else:                     # numeric: threshold test
            passed = data[attribute] <= pivot
        if self is self.parent.left:
            ret = data[passed]    # left child: rows passing the test
        else:
            ret = data[~passed]   # right child: rows failing the test
        self.size = len(ret)
        return ret
    def get_next_node_or_predict(self, datapoint): # tells us where to find a prediction, or returns one
        if self.children is None:
            return self.prediction
        if type(self.pivot) == set:
            passed = datapoint[self.split_attribute] in self.pivot
        else:
            passed = datapoint[self.split_attribute] <= self.pivot
        return self.left if passed else self.right

So the node can now filter the data based on its parent’s split attribute and pivot. Notice that the left and right children of the same parent have opposite filters, so together they partition the parent’s data.
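Because the right child’s filter is the negation of the left child’s, every row of the parent lands in exactly one child. Here is a quick stdlib check of that property on made-up rows (the helper names are hypothetical). Note that with a float NaN the comparison `nan <= pivot` is False, so a filter built as “not passed” sends such a row right rather than dropping it; pandas comparisons against NaN also return False.

```python
import math


def passes(value, pivot):
    # the node test: set membership for categorical pivots, `<=` for numeric
    if isinstance(pivot, (set, frozenset)):
        return value in pivot
    return value <= pivot


def split_rows(rows, attribute, pivot):
    """Left keeps rows passing the parent's test, right keeps the rest."""
    left = [r for r in rows if passes(r[attribute], pivot)]
    right = [r for r in rows if not passes(r[attribute], pivot)]
    return left, right


rows = [{"age": 5.0}, {"age": 30.0}, {"age": float("nan")}, {"age": 18.0}]
left, right = split_rows(rows, "age", 18.0)
assert len(left) + len(right) == len(rows)                 # nothing lost
assert {id(r) for r in left}.isdisjoint(id(r) for r in right)  # nothing duplicated
print([r["age"] for r in left])                            # [5.0, 18.0]
print(any(math.isnan(r["age"]) for r in right))            # True
```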

Now we’ll use this to create a PivotDecisionTree class.

class PivotDecisionTree(DecisionTree, object):
    def __init__(self, data_type_dict={}, metric_kind='Gini',
                 tree_kind='classification',
                 Vertex=PivotDecisionNode, min_node_size=5,
                 max_node_depth=4, **kwargs):
        super(PivotDecisionTree, self).__init__(Vertex=Vertex, **kwargs)
        # settings kept for growing trees later;
        # tree_kind is either 'classification' or 'regression'
        self.data_type_dict = data_type_dict
        self.metric_kind = metric_kind
        self.tree_kind = tree_kind
        self.min_node_size = min_node_size
        self.max_node_depth = max_node_depth
    def split_vertex(self, vertex, split_attribute, pivot):
        super(PivotDecisionTree, self).split_vertex(vertex)
        vertex.pivot, vertex.split_attribute = pivot, split_attribute
    def fuse_vertex(self, vertex):
        super(PivotDecisionTree, self).fuse_vertex(vertex)
        vertex.pivot, vertex.split_attribute = None, None # reset the vertex, not the tree
    def create_full_n_level(self, *args, **kwargs):
        raise AttributeError('This method is not appropriate as pivots are not specified')
    def set_node_prediction(self, node):
        if self.tree_kind == 'classification': # returns a probability for each class
            counts = node.local_data[self.response].value_counts()
            node.size = float(counts.sum())
            node.prediction = {key: count/node.size for key, count in counts.items()}
        elif self.tree_kind == 'regression': # returns mean of the responses
            node.prediction = node.local_data[self.response].mean()
    def set_predictions(self):
        for node in self.vertices:
            self.set_node_prediction(node)
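The classification branch of `set_node_prediction` turns a tally of responses into proportions. Stripped of pandas, the same arithmetic looks like this (the function name `class_probabilities` is ours, for illustration); in pandas itself, `Series.value_counts(normalize=True)` returns these proportions directly.

```python
from collections import Counter


def class_probabilities(responses):
    """Tally each class, then divide by the node size to get proportions."""
    tally = Counter(responses)
    size = float(sum(tally.values()))  # float() matters under Python 2 division
    return {label: count / size for label, count in tally.items()}


print(class_probabilities([0, 1, 1, 1]))  # {0: 0.25, 1: 0.75}
```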

And now we can represent decision trees! To test it, let’s create a very simple one for the Titanic dataset from Kaggle. We’ll just have a single split on ‘sex’, which should be reasonable, recalling this plot we produced a while back.

So we’ll produce this decision tree ‘by hand’; next time we’ll look at how to automate the process by ‘growing’ trees.

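import cleantitanic as ct

data = ct.cleaneddf()[0]
# 'survived' is assumed to be the response column name in the cleaned data
t = PivotDecisionTree(data=data, response='survived')
root = t.get_root()
t.split_vertex(vertex=root, split_attribute='sex', pivot=set(['female']))
root.local_data = root.local_filter(data)  # the root keeps all of the data
for child in root.children:
    child.local_data = child.local_filter(data)
t.set_predictions()
for leaf in t.leaves:
    print(leaf.prediction)
# producing the following output
# {0: 0.25796178343949044, 1: 0.7420382165605095}
# {0: 0.81109185441941078, 1: 0.18890814558058924}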

From that plot I think you can guess which leaf is which. As you can see, creating a tree by hand is tedious and is best left to computers.
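Once a tree is built, predicting for a new passenger means calling `get_next_node_or_predict` repeatedly, starting at the root, until something other than a node comes back. Below is a self-contained sketch of that loop; `SketchNode`, `predict` and the placeholder leaf predictions are hypothetical stand-ins rather than part of the classes above.

```python
class SketchNode:
    """Hypothetical stand-in with the same get_next_node_or_predict logic."""

    def __init__(self, split_attribute=None, pivot=None,
                 left=None, right=None, prediction=None):
        self.split_attribute, self.pivot = split_attribute, pivot
        self.left, self.right = left, right
        self.children = None if left is None else [left, right]
        self.prediction = prediction

    def get_next_node_or_predict(self, datapoint):
        if self.children is None:  # a leaf: hand back its prediction
            return self.prediction
        value = datapoint[self.split_attribute]
        if isinstance(self.pivot, set):
            passed = value in self.pivot
        else:
            passed = value <= self.pivot
        return self.left if passed else self.right


def predict(root, datapoint):
    """Walk down from the root until a prediction (not a node) comes back."""
    result = root
    while isinstance(result, SketchNode):
        result = result.get_next_node_or_predict(datapoint)
    return result


tree = SketchNode(split_attribute="sex", pivot={"female"},
                  left=SketchNode(prediction="left-leaf prediction"),
                  right=SketchNode(prediction="right-leaf prediction"))
print(predict(tree, {"sex": "male"}))  # right-leaf prediction
```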

If you want to learn more I highly recommend having a read of this.


