
Decision Trees Part 3: Pruning your Tree

OK, last time we learned how to grow a tree automatically, using a greedy algorithm to choose splits that maximise a given ‘metric’.

The problem is that the trees become huge and almost certainly overfit our data, meaning that they will generalize poorly to unseen data. This is the age-old battle between signal and noise, where we have to build a level of fuzziness into the model: pixelating the trees to see the forest.

So first we come up with a way of measuring what we want to achieve, then we find an algorithm that optimises that quantity.

Let’s introduce some terminology and notation. For a model built on training data, the resubstitution error is the error (for some notion of error) on that same training data. For a tree T we write R(T) for this quantity.

In our case we will use the misclassification rate as our error, since that is how predictions are scored in the Kaggle competition on the Titanic data.

Now we quantify tree complexity: write |T| for the number of leaves. There is a balance to be struck between these two quantities, which we capture in the ‘\alpha-cost’ of the tree, given by
R_{\alpha}(T) = R(T) + \alpha |T|.

So \alpha is the price we pay for adding an extra leaf, which must be paid for in reduced training error.
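A quick worked example of the trade-off, as a small Python sketch. The error rates and leaf counts below are made up for illustration; they are not taken from the Titanic trees.

```python
def alpha_cost(resub_error, n_leaves, alpha):
    """alpha-cost of a tree: R_alpha(T) = R(T) + alpha * |T|."""
    return resub_error + alpha * n_leaves


alpha = 0.005
# Hypothetical numbers: a large tree and a smaller subtree of it.
big = alpha_cost(resub_error=0.15, n_leaves=10, alpha=alpha)   # 0.15 + 0.05
small = alpha_cost(resub_error=0.17, n_leaves=4, alpha=alpha)  # 0.17 + 0.02
print(f"large tree: {big:.3f}, pruned subtree: {small:.3f}")
```

At this \alpha the smaller tree wins (0.19 against 0.20) even though its training error is higher; with \alpha=0 the larger tree would win.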

So, given our tree T and \alpha, we seek a subtree T^* that minimises R_{\alpha}, or at least one that gives a small value, since there are far too many subtrees to check them all.

One thing we can try is to look at each node that is not a leaf, fuse it (turn it into a leaf by discarding everything below it) and calculate the resulting cost. We then choose the node, if any, that gives the biggest improvement in tree cost. If none of them gives an improvement we’re done; otherwise we fuse that node and repeat the process.
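The greedy procedure just described can be sketched on a toy tree. This is a self-contained illustration, not the post's tree class: the `Node` dataclass, its `wrong` field (the training rows a node would misclassify if it were a leaf) and the helper names are all hypothetical, and fusing is done by simply dropping a node's children.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical toy node: n rows reach it, and it would misclassify
    `wrong` of them if it were a leaf predicting its majority class."""
    name: str
    n: int
    wrong: int
    children: List["Node"] = field(default_factory=list)


def leaves(node):
    """Leaves of the subtree rooted at node."""
    if not node.children:
        return [node]
    return [leaf for child in node.children for leaf in leaves(child)]


def internal_nodes(node):
    """Non-leaf nodes of the subtree rooted at node (the candidate prunes)."""
    if not node.children:
        return []
    return [node] + [m for child in node.children for m in internal_nodes(child)]


def fuse_change(node, alpha, total):
    """Change in R_alpha(T) if node is fused; negative means fusing helps."""
    fused_error = node.wrong / total
    unfused_error = sum(leaf.wrong for leaf in leaves(node)) / total
    return (fused_error - unfused_error) - alpha * (len(leaves(node)) - 1)


def greedy_prune(root, alpha):
    """Repeatedly fuse the node that lowers the alpha-cost most (mutates root)."""
    total = root.n
    while True:
        candidates = internal_nodes(root)
        if not candidates:  # only the root is left
            return root
        best = min(candidates, key=lambda m: fuse_change(m, alpha, total))
        if fuse_change(best, alpha, total) >= 0:  # no fuse improves the cost
            return root
        best.children = []  # fuse: best becomes a leaf


def make_tree():
    a = Node("A", 60, 10, [Node("A1", 30, 4), Node("A2", 30, 5)])
    b = Node("B", 40, 12, [Node("B1", 20, 2), Node("B2", 20, 3)])
    return Node("root", 100, 40, [a, b])


for alpha in (0.0, 0.02, 0.1, 0.2):
    pruned = greedy_prune(make_tree(), alpha)
    print(alpha, [leaf.name for leaf in leaves(pruned)])
```

On this toy tree the surviving leaves shrink as \alpha grows: all four leaves at \alpha=0, then A, B1 and B2 at \alpha=0.02, then A and B at \alpha=0.1, and only the root at \alpha=0.2.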

OK, let’s get coding. First we want to be able to measure the error rate at a node. I call this the conditional error rate, because it assumes that you are already in the node. Since we will be calculating this quite a lot, I store it on the node so that it isn’t recalculated when called again.

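    def conditional_error_rate(self, node, new_data=False):
        # Cached on the node so repeated calls are cheap.
        if node.error is None:
            if new_data:
                # How often each class appears among the new rows in this node
                counts = node.local_data[self.response].value_counts()
                if node.predicted_class in counts.index:
                    node.error = 1 - counts[node.predicted_class]/float(node.size)
                else:
                    # Predicted class never appears: every row is misclassified
                    node.error = 1.0
            else:
                # Training data: one minus the predicted class probability
                node.error = 1 - node.predicted_prob
        return node.error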

In the training (resubstitution) case, where new_data=False, the misclassification rate is just one minus the predicted class probability. In the case we will look at later, where we run new data through the tree after building it, we need to know how often the predicted class appears among the new data at that node. If the predicted class does not appear in the new data at all, the misclassification rate is total, i.e. 1.
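The new-data branch boils down to a small counting rule. Here is a sketch using plain Python lists and `collections.Counter` in place of the pandas `value_counts` call; the function name mirrors the method above, but this stand-alone signature is hypothetical. On training rows the predicted class is the majority class, so the same rule gives one minus the predicted class probability.

```python
from collections import Counter


def conditional_error_rate(labels, predicted_class):
    """Misclassification rate at a node that predicts predicted_class,
    measured on the labels of the rows that land in that node."""
    if not labels:
        # Assumption: a node that no rows reach contributes no error.
        return 0.0
    counts = Counter(labels)
    # Counter returns 0 for a missing key, so an unseen predicted class gives 1.0.
    return 1 - counts[predicted_class] / len(labels)


node_labels = ["died", "died", "survived", "died"]
print(conditional_error_rate(node_labels, "died"))      # 1 - 3/4
print(conditional_error_rate(node_labels, "survived"))  # 1 - 1/4
print(conditional_error_rate(["survived"], "died"))     # class absent
```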

That was the conditional error rate, but what contribution does a node make to the overall error rate? To answer that we need to know the probability of being at that node, then multiply the conditional rate by that probability.

    def node_prob(self, node):
        return node.size/float(self.data_size)

    def local_error_rate(self, node, new_data=False):
        return self.conditional_error_rate(node, new_data)*self.node_prob(node)

So local_error_rate tells us the error a node contributes if we fuse at that node. To compare with not fusing, we look at the subtree rooted at that node and sum the local errors of its leaves.

    def node_error(self,node):
        return sum(self.local_error_rate(leaf) for leaf in self.get_node_leaves(node))

    def get_node_leaves(self, node): #get leaves of subtree at a node
        leaves = []
        for descendant in self.node_iter_down(node):
            if descendant in self.leaves:
                leaves.append(descendant)
        return leaves

    def node_iter_down(self, base, first=True): #iterates top to bottom over the subtree at a node
        if first:
            yield base  # only the top-level call yields base; children are yielded by their parent
        if base.children is None:
            return  # a leaf has nothing below it
        for child in base.children:
            yield child
            for node in self.node_iter_down(child, first=False):
                yield node
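Recursive generators like node_iter_down would only run into Python's recursion limit on very deep trees, but an iterative version with an explicit stack is a handy alternative. The sketch below visits nodes in the same top-to-bottom order. The minimal `Node` class is a hypothetical stand-in that follows the post's convention of `children` being None at a leaf, and the leaf test uses that convention instead of checking membership in `self.leaves`.

```python
class Node:
    """Hypothetical minimal node: children is None at a leaf, as in the post."""
    def __init__(self, name, children=None):
        self.name = name
        self.children = children


def node_iter_down(base):
    """Yield base and every node below it, parents before their children."""
    stack = [base]
    while stack:
        node = stack.pop()
        yield node
        if node.children is not None:
            # Push in reverse so the leftmost child is popped (visited) first.
            stack.extend(reversed(node.children))


def get_node_leaves(base):
    return [n for n in node_iter_down(base) if n.children is None]


tree = Node("root", [Node("A", [Node("A1"), Node("A2")]), Node("B")])
print([n.name for n in node_iter_down(tree)])
print([n.name for n in get_node_leaves(tree)])
```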

Great! Now if we want to know R(T) we just sum all of the local errors at the leaves.

    def error(self, new_data=False):
        return sum(self.local_error_rate(leaf,new_data) for leaf in self.leaves)
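Since every training row lands in exactly one leaf, summing P(error | leaf) × P(leaf) over the leaves simply gives the total number of misclassified rows divided by the number of rows. A quick numerical check of that identity, with made-up leaf sizes and error rates:

```python
def local_error_rate(conditional_error, node_size, data_size):
    """A leaf's contribution to R(T): P(error | leaf) * P(leaf)."""
    return conditional_error * node_size / data_size


# Hypothetical leaves as (rows reaching the leaf, conditional error rate).
leaf_stats = [(500, 0.20), (300, 0.10), (200, 0.35)]
data_size = sum(size for size, _ in leaf_stats)  # 1000 rows in total

R_T = sum(local_error_rate(err, size, data_size) for size, err in leaf_stats)
misclassified = sum(err * size for size, err in leaf_stats)  # 100 + 30 + 70 rows
print(R_T, misclassified / data_size)
```

Both numbers come out at 0.2 (up to floating-point rounding).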

This is looking promising. Now we need a way of iterating through each possible prune and calculating the change in tree cost. The good thing about the change in cost is that it depends only on what’s going on in the subtree rooted at the node: fusing removes \alpha times (the number of leaves in the subtree minus one) from the cost, but adds the difference between the node’s local error rate and its node error (the summed local errors of the subtree’s leaves).

The change in cost by fusing a node is given by

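    def node_cost(self, node, alpha):
        # Change in alpha-cost if we fuse at node; negative means fusing helps.
        fused_error = self.local_error_rate(node)
        unfused_error = self.node_error(node)
        error_diff = fused_error - unfused_error
        number_leaves_lost = len(self.get_node_leaves(node))-1
        return error_diff - alpha*number_leaves_lost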
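To see what node_cost computes, here is the same arithmetic as stand-alone functions with hypothetical numbers. Rearranging the sign condition also gives the value of \alpha above which fusing a given node lowers the \alpha-cost; this ratio is the "weakest link" quantity used in CART-style cost-complexity pruning. The function names are illustrative and not part of the post's code.

```python
def fuse_cost_change(fused_error, unfused_error, n_subtree_leaves, alpha):
    """Change in R_alpha(T) from fusing a node; negative means fusing helps."""
    return (fused_error - unfused_error) - alpha * (n_subtree_leaves - 1)


def critical_alpha(fused_error, unfused_error, n_subtree_leaves):
    """alpha at which fusing breaks even: the change above is zero there and
    negative for any larger alpha (needs at least 2 leaves below the node)."""
    return (fused_error - unfused_error) / (n_subtree_leaves - 1)


# Hypothetical node: local error 0.06 if fused; its 3 leaves contribute 0.05.
for alpha in (0.001, 0.01):
    print(alpha, fuse_cost_change(0.06, 0.05, 3, alpha))
print("break-even alpha:", critical_alpha(0.06, 0.05, 3))
```

With these numbers the change is roughly +0.008 at \alpha=0.001, so the subtree is kept, and roughly -0.01 at \alpha=0.01, so the node is fused; the break-even point is \alpha=0.005.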

And now we iterate over possible prunes:

    def get_best_prune(self, alpha):
        candidates = list(self.iter_prune_cost(alpha))
        if not candidates:  #This is a base case, if the only node left is the root, then no more can be done.
            return None
        best_node_cost, best_node = min(candidates, key=lambda x: x[0])
        if best_node_cost < 0:  # fusing this node lowers the alpha-cost
            return best_node
        else: #If no good prunes then we return nothing.
            return None

    def iter_prune_cost(self, alpha):
        for node in self.vertices:
            if node not in self.leaves:
                yield [self.node_cost(node, alpha), node]

Having laid the groundwork, the final method is simplicity itself.

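    def prune_tree(self, alpha):
        best_prune = self.get_best_prune(alpha)
        while best_prune is not None:  # stop once no fuse lowers the alpha-cost
            self.fuse_node(best_prune)  # fuse_node: hypothetical name for the method that makes a node a leaf
            best_prune = self.get_best_prune(alpha)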

And that is that! Let’s give it a whirl. I’ve added a method called train to the class (you can see it in the code I have put up), which lets us pass a dictionary of parameters to build the tree. You can also see that I have separated the regression and classification trees into their own subclasses. First we try \alpha=0, meaning no pruning at all.

    import cleantitanic as ct
    # df: the cleaned Titanic training DataFrame (loading step not shown)
    df=df[['survived', 'pclass', 'sex', 'age', 'sibsp', 'parch',
           'fare', 'embarked']]
    data_type_dict={'survived':'nominal', 'pclass':'ordinal', 'sex':'nominal',
                    'age':'ordinal', 'sibsp':'ordinal', 'parch':'ordinal',
                    'fare':'ordinal', 'embarked':'nominal'}
    # further parameters (such as the pruning alpha) not shown
    parameters={'min_node_size':5, 'max_node_depth':20}
    # g: a classification tree instance (construction not shown)
    g.train(data=df,data_type_dict=data_type_dict, parameters=parameters)

Which produces this monstrosity.

[Figure: the unpruned tree, \alpha=0]

Yikes! Let’s see it with \alpha=0.001.

[Figure: pruned tree, \alpha=0.001]

And \alpha=0.00135.

[Figure: pruned tree, \alpha=0.00135]

And \alpha=0.0015.

[Figure: pruned tree, \alpha=0.0015]

And finally \alpha=0.002.

[Figure: pruned tree, \alpha=0.002]

So you can see that the tree you get depends entirely on \alpha. Which one should you use? We will attempt to answer that question next time, when we discuss cross-validation and OOB (out-of-bag) resampling.

Here is the code I have been promising. It also includes the means to test the tree on unseen data, which we will discuss next time.


