
To prune or not to prune: cross-validation is the answer

Last time we developed a method for pruning decision trees based on a compromise between training error and tree complexity, the latter of which is weighted with a parameter \alpha. Choose a different \alpha and get a different pruned tree. So how can we choose such a parameter?
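To make the role of \alpha concrete, here is a minimal sketch. It assumes the cost being minimised is training error plus \alpha times the number of leaves (using the leaf count as the complexity measure is an assumption here). `best_subtree` is a hypothetical helper, and the three candidate subtrees and their error figures are made-up toy values, not results from the tree code.

```python
def best_subtree(candidates, alpha):
    """Return the name of the candidate minimising error + alpha * leaves.

    candidates: list of (name, training_error, number_of_leaves) tuples.
    """
    name, _, _ = min(candidates, key=lambda c: c[1] + alpha * c[2])
    return name

# toy candidates: more leaves means lower training error
candidates = [('full', 0.10, 20), ('medium', 0.15, 8), ('stump', 0.30, 2)]
for alpha in (0.0, 0.01, 0.1):
    print(alpha, best_subtree(candidates, alpha))
```

With these numbers, \alpha = 0 keeps the full tree, \alpha = 0.01 picks the medium tree and \alpha = 0.1 collapses it to a stump.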

If data is plentiful, we can split our training data into training and holdout portions, and use the holdout portion to test which \alpha gives the best result. Our data is not so plentiful. The main problem with this approach is that the final model never gets to see the whole training set while it is being built.
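A holdout split with pandas might look like the following sketch. `holdout_split` and the 30% holdout fraction are illustrative assumptions; the idea is that you would grow and prune on `train` for each candidate \alpha and keep the \alpha with the lowest error on `holdout`.

```python
import pandas as pd

def holdout_split(data, holdout_frac=0.3, seed=0):
    """Split a DataFrame into (train, holdout) with no rows in common."""
    holdout = data.sample(frac=holdout_frac, random_state=seed)
    train = data.drop(holdout.index)  # everything that was not held out
    return train, holdout

df_demo = pd.DataFrame({'x': range(10)})
train, holdout = holdout_split(df_demo)
print(len(train), len(holdout))
```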

A similar idea is k-fold cross-validation. Randomly partition the data into k equal-sized subsets. For each subset in turn, use it as test data and its complement as training data: build a model on the training data, test it, and record its error. The average error over the k subsets is an estimate of the model's generalisation error.
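Here is a small, library-free sketch of that estimate. Rows are `(x, y)` pairs and `fit` is any function that takes training rows and returns a prediction function; the majority-class `fit_majority` is a toy stand-in for growing a tree, not anything from the tree code itself.

```python
import random
from collections import Counter

def cv_error(rows, k, fit, seed=0):
    """k-fold estimate of the error rate of the model built by fit."""
    idx = list(range(len(rows)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]     # k near-equal random subsets
    fold_errors = []
    for test_idx in folds:
        held_out = set(test_idx)
        train = [r for i, r in enumerate(rows) if i not in held_out]
        predict = fit(train)                  # build on the complement
        wrong = sum(predict(rows[i][0]) != rows[i][1] for i in test_idx)
        fold_errors.append(wrong / len(test_idx))
    return sum(fold_errors) / k               # average over the k folds

def fit_majority(train):
    """Toy model: always predict the most common label in train."""
    label = Counter(y for _, y in train).most_common(1)[0][0]
    return lambda x: label

rows = [(i, 0) for i in range(8)] + [(i, 1) for i in range(8, 10)]
print(cv_error(rows, 5, fit_majority))
```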

Now repeat this (on the same k subsets) for each set of model parameters you wish to test, and choose the set of parameters minimising the estimated generalisation error. Use these parameters to build a model on the whole data set. Done.
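Extending that to parameter selection, here is a sketch under the same assumptions: the folds are drawn once and reused for every candidate, and the winner is refitted on all the data. `fit(train, param)` is a hypothetical stand-in for growing and pruning a tree with a given \alpha; the threshold classifier used to exercise it ignores its training rows and is purely a toy.

```python
import random

def choose_and_refit(rows, k, candidates, fit, seed=0):
    """Pick the candidate with the lowest k-fold error, then refit on all rows."""
    idx = list(range(len(rows)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]     # drawn once, shared by all candidates

    def cv_error(param):
        total = 0.0
        for test_idx in folds:
            held_out = set(test_idx)
            train = [r for i, r in enumerate(rows) if i not in held_out]
            predict = fit(train, param)
            wrong = sum(predict(rows[i][0]) != rows[i][1] for i in test_idx)
            total += wrong / len(test_idx)
        return total / k

    best = min(candidates, key=cv_error)
    return best, fit(rows, best)              # the final model sees every row

def fit_threshold(train, t):
    """Toy model: predict 1 when x exceeds the threshold t."""
    return lambda x: int(x > t)

rows = [(x, int(x > 4)) for x in range(10)]
best, model = choose_and_refit(rows, 5, [0, 4, 8], fit_threshold)
print(best, model(7))
```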

OK, so let's build a generator function to return the folded data.

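import random
import pandas as pd

def cross_validate(no_folds, data, resample=False):
    """Yield no_folds train/test splits of the DataFrame data.

    If resample is True, each training set is padded back up to len(data)
    rows by sampling its own rows with replacement.
    """
    rows = list(data.index)
    random.shuffle(rows)          # randomise which rows land in which fold
    N = len(rows)
    len_fold = N // no_folds
    for i in range(no_folds):
        start = i * len_fold
        # the last fold absorbs any leftover rows
        stop = N if i == no_folds - 1 else start + len_fold
        test = data.loc[rows[start:stop]]
        train_rows = rows[:start] + rows[stop:]
        if resample:
            no_resamples = N - len(train_rows)
            train_rows = train_rows + [random.choice(train_rows)
                                       for _ in range(no_resamples)]
        train = data.loc[train_rows]
        yield {'test': test, 'train': train}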

The idea is pretty simple: we take the list of row indexes and randomly shuffle it, then split the data set into k parts, taking the rows from position i*len(data)/k up to (i+1)*len(data)/k of the shuffled ordering for the ith part (counting from 0). The last part also picks up any leftover rows when len(data) isn't divisible by k.
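The fold boundaries on their own, as a quick sanity check of the index arithmetic (`fold_bounds` is an illustrative helper, not part of the original code):

```python
def fold_bounds(n, k):
    """(start, stop) slice positions for k folds over n shuffled rows."""
    len_fold = n // k
    bounds = []
    for i in range(k):
        start = i * len_fold
        stop = n if i == k - 1 else start + len_fold  # last fold takes the remainder
        bounds.append((start, stop))
    return bounds

print(fold_bounds(10, 3))  # [(0, 3), (3, 6), (6, 10)]
```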

Also observe that I have added an option to pump up the training sets to be the same size as the original data set using sampling with replacement.
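A minimal sketch of that padding step (`pad_with_replacement` is a hypothetical helper). Note that padding restores the number of rows but not the number of distinct rows: with 10 folds over 100 rows, the padded training set still contains only 90 distinct rows.

```python
import random

def pad_with_replacement(train_rows, target_size, seed=0):
    """Pad train_rows up to target_size by resampling its own rows."""
    rng = random.Random(seed)
    extra = [rng.choice(train_rows) for _ in range(target_size - len(train_rows))]
    return train_rows + extra

# one fold of a 10-fold split over 100 rows: 90 training rows padded back to 100
train_rows = list(range(10, 100))
padded = pad_with_replacement(train_rows, 100)
print(len(padded), len(set(padded)))  # 100 90
```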

I tried using this to select \alpha, but I found that the \alpha you get over-prunes the tree trained on the full data set. I believe this is because each cross-validation training set contains fewer unique rows than the full data set (resampling only adds duplicates), so the leaves make up a bigger slice of the error pie. They are therefore more resistant to pruning and need more ‘encouragement’, in the form of a larger \alpha, to fuse. As a result, cross-validation tends to overestimate the optimum \alpha.

The new approach I settled upon was something like cross-validation: I split the data into training and testing portions, grow the tree on the training data, and prune it using the test data instead of the training data. In this case \alpha is irrelevant, so we set it to 0.
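Here is a minimal sketch of pruning against held-out data with \alpha = 0, on a toy tree made of nested dicts (a hypothetical representation, not the tree class used in this series). Each internal node stores its majority training label in `'label'`; working bottom-up, a subtree is collapsed to a leaf whenever doing so does not increase the error on the held-out rows. This is essentially what is usually called reduced-error pruning.

```python
def predict(node, x):
    """Follow the splits in the toy tree down to a leaf label."""
    while not node['leaf']:
        branch = 'left' if x[node['feature']] <= node['threshold'] else 'right'
        node = node[branch]
    return node['label']

def errors(node, rows):
    return sum(predict(node, x) != y for x, y in rows)

def prune(node, rows):
    """Bottom-up pruning against held-out (x, y) rows, i.e. alpha = 0."""
    if node['leaf']:
        return node
    f, t = node['feature'], node['threshold']
    left_rows = [(x, y) for x, y in rows if x[f] <= t]
    right_rows = [(x, y) for x, y in rows if x[f] > t]
    node = dict(node, left=prune(node['left'], left_rows),
                right=prune(node['right'], right_rows))
    leaf = {'leaf': True, 'label': node['label']}
    # collapse unless the split strictly lowers held-out error (ties favour the smaller tree)
    return leaf if errors(leaf, rows) <= errors(node, rows) else node

tree = {'leaf': False, 'feature': 'age', 'threshold': 10, 'label': 0,
        'left': {'leaf': True, 'label': 1},
        'right': {'leaf': True, 'label': 0}}
print(prune(tree, [({'age': 5}, 0), ({'age': 20}, 0)])['leaf'])  # True: split doesn't help
print(prune(tree, [({'age': 5}, 1), ({'age': 20}, 0)])['leaf'])  # False: split helps
```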

The problem, of course, is that while our model has good information on how much to prune, it has only been trained on half (say) of the data. To alleviate this problem I also grew a tree on the test data and pruned it using the training data, swapping the roles of the two portions. This gives 2 models, or k models if using k-fold cross-validation.
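In outline, the two-portion version looks something like this sketch. `swap_role_models` is a hypothetical name, `grow` and `prune` are placeholders for the tree's growing step and its held-out pruning step, and the toy lambdas in the example just record how many rows each step saw.

```python
import random

def swap_role_models(rows, grow, prune, seed=0):
    """Two models: grow on one half and prune on the other, then swap."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    half = len(rows) // 2
    a, b = rows[:half], rows[half:]
    return [prune(grow(a), b), prune(grow(b), a)]

grow = lambda data: ('grown on', len(data))
prune = lambda model, data: (model, 'pruned on', len(data))
print(swap_role_models(range(9), grow, prune))
# [(('grown on', 4), 'pruned on', 5), (('grown on', 5), 'pruned on', 4)]
```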

I then take the predictions from each model and combine them by taking the modal prediction. Of course, the mode is not very informative for only 2 predictions, so I repeated this process several times, which also helps to reduce the variability that comes from the random partitioning of the data. Here's how to achieve this.


# df holds the training data and df2 the test data (loading code not shown)
df = df[['Survived', 'Pclass', 'Sex', 'Age', 'SibSp', 'Parch',
         'Fare', 'Embarked']]
data_type_dict = {'Survived': 'nominal', 'Pclass': 'ordinal', 'Sex': 'nominal',
                  'Age': 'ordinal', 'SibSp': 'ordinal', 'Parch': 'ordinal',
                  'Fare': 'ordinal', 'Embarked': 'nominal'}

First I’ve loaded the training data (df) and test data (df2), kept only the columns I want to use, and created the dictionary of data types that the tree expects.

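def tree_train(data_type_dict, train_data, test_data, response, no_folds,
               min_node_size, max_depth, no_iter):
    # alpha is 0: pruning is driven by held-out data rather than a penalty
    parameters = {'min_node_size': min_node_size, 'max_node_depth': max_depth,
                  'threshold': 0, 'metric_kind': 'Gini', 'alpha': 0}
    predictions = []
    for i in range(no_iter):
        # each iteration reshuffles the rows into a fresh set of folds
        for fold in cross_validate(no_folds, train_data):
            # `model` is the decision tree object from the earlier posts (not defined here)
            model.train(fold['train'], data_type_dict, parameters, prune=False)
            # prune with new (held-out) data and alpha=0
            model.prune_tree(alpha=0, new_data=True)
            predictions.append(test_data.apply(model.predict, axis=1))
    return predictions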

This is pretty straightforward: for each iteration we chop up the data, train on all but one portion and prune with the remaining portion, repeat this for each fold and predict on the test data, returning a list of predictions at the end.

Next we need a way to combine predictions. We combine the list of predictions into a dataframe, then apply the mode to each row.

def combine_predictions(predictions):
    # one column per model, one row per test example
    data_dict = {i: predictions[i] for i in range(len(predictions))}
    d = pd.DataFrame(data_dict)
    def mode(x):
        # most frequent prediction in the row (ties broken arbitrarily)
        return x.value_counts().idxmax()
    pred = d.apply(mode, axis=1)
    return pred

predictions = tree_train(data_type_dict=data_type_dict, train_data=df,
                         test_data=df2, response='Survived', no_folds=10,
                         max_depth=50, min_node_size=5, no_iter=50)
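As a possible alternative to the hand-written `mode`, pandas' built-in `DataFrame.mode(axis=1)` computes the row-wise mode directly. When a row is tied it returns all the modal values spread across several columns, so taking the first column is one way of picking a single value. A sketch, with made-up predictions from three models (`combine_with_pandas_mode` is a hypothetical name):

```python
import pandas as pd

def combine_with_pandas_mode(predictions):
    """Row-wise majority vote over a list of prediction Series."""
    votes = pd.concat(predictions, axis=1)  # one column per model
    return votes.mode(axis=1).iloc[:, 0]    # first modal value in each row

preds = [pd.Series([1, 0, 0]), pd.Series([1, 0, 1]), pd.Series([0, 1, 0])]
print(combine_with_pandas_mode(preds).tolist())
```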

And that’s it. Next time I am going to walk through the whole process, from cleaning the data through to building the model and making a submission to Kaggle.


