
How to Balance your Binary Search Trees – AVL Trees

Last time we introduced the binary search tree (BST) and saw that it supports insertion and deletion in O(h) time, where h is the height of the tree. We also saw that it can find the successor or predecessor of a node in the same time, and hence that it can sort a list in O(nh) time, where n is the length of the list.

Ideally we would have h=O(\log n), and this is often the case ‘on average’, but the worst case is h=n, which occurs if we insert an already sorted list! So today we are going to discuss how to have our tree balance itself automagically so as to keep its height O(\log n).
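
To see just how lopsided things can get, here is a minimal sketch, separate from the AVL code we will write below and with names chosen purely for illustration, that inserts an already sorted list into a plain, unbalanced BST and measures the resulting height:

class PlainNode(object):
    #A bare-bones BST node with no balancing information.
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None

def plain_insert(root, key):
    #Ordinary BST insertion with no rebalancing.
    if root is None:
        return PlainNode(key)
    if key < root.key:
        root.left = plain_insert(root.left, key)
    elif key > root.key:
        root.right = plain_insert(root.right, key)
    return root

def plain_height(node):
    #Height of a subtree, where a single node has height 0 and an empty tree -1.
    if node is None:
        return -1
    return 1 + max(plain_height(node.left), plain_height(node.right))

root = None
for k in range(100):             #Sorted input: the worst case.
    root = plain_insert(root, k)
print(plain_height(root))        #Prints 99: the tree is one long path.

The AVL tree we build below will keep its height logarithmic even on this input.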

AVL-Trees

Recall that the defining characteristic of a BST is the so-called ‘search property’: for any subtree rooted at a node, the stuff on the left is smaller, and the stuff on the right is greater.

We now introduce another property, or invariant, that we will insist upon: call it the ‘plus or minus property’ (pm) or ‘balance property’. It is satisfied by a node if the height of its left subtree differs from the height of its right subtree by no more than one. A BST in which every node satisfies pm is called an ‘AVL tree’ (AVL).

So we will give each node two extra fields, node.height and node.balance, where node.balance is equal to the height on the left minus the height on the right. Hence it will be 0 when they are equal, +1 when the left is higher, and -1 when the right is higher.
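
To make the definitions concrete, consider this small tree, drawn with just the keys and using the convention from the code below, where a lone node has height 0:

        5
       / \
      2   8
     /
    1

Here the leaves 1 and 8 have height 0 and balance 0; node 2 has height 1 and balance +1 (a left subtree of height 0 against an empty right subtree); and the root 5 has height 2 and balance +1 (left subtree of height 1 against right subtree of height 0). Every balance is 0 or +1, so this is an AVL tree.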

Before we go any further, you might be wondering whether or not this property forces the tree to have O(\log n) height, and indeed it does, so let’s prove that to ourselves before coercing our trees to have this property.

Proof That AVL Trees are Balanced

Define n_h to be the smallest number of nodes you can build an AVL of height h with. That is, any AVL of height h must have at least n_h nodes.

Now suppose that we had shown that h=O(\log n_h), or in particular that h \le K \log n_h for all h. Now take any AVL tree of height h and having n nodes. Then n \ge n_h \Rightarrow h \le K \log n_h \le K \log n \Rightarrow h=O(\log n) as desired.

So we just need to show that h=O(\log n_h). The way to do this is by setting up a recurrence relation. Indeed we will see that n_{h+2} = n_{h+1}+n_{h}+1.

To see this, note that the smallest AVL tree of height h+2 must consist of a root whose two subtrees are as small as possible: one subtree has height h+1 (to give the whole tree height h+2), and by the balance property the other can have height as low as h, so the cheapest choice is the smallest AVL trees of heights h+1 and h respectively.

Now using this notice that:

n_{h+2} = n_{h+1}+n_{h} + 1 \ge 2n_h.

And so if h=2k then applying this repeatedly until we reach n_2 = 2 (the smallest AVL of height two is a root with a single child) yields:

n_h \ge 2^{k-1} \cdot n_2 = 2^{k} = 2^{h/2}.

On the other hand if h=2k+1 then we apply till we get to n_1=1 giving:

n_h \ge 2^k \cdot n_1 = 2^k = 2^{(h-1)/2}.

Combining the two cases gives n_h \ge 2^{(h-1)/2} for every h. Now taking logs base 2 gives

\log n_h \ge \frac{h-1}{2} \Rightarrow h \le 2\log n_h + 1 \Rightarrow h=O(\log n_h) \Rightarrow h=O(\log n).

And we are done.
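
If you want to see some numbers, here is a quick sanity check of the recurrence and the bound we just proved, using the base cases n_1 = 1 and n_2 = 2:

#Compute n_h from the recurrence and check n_h >= 2^((h-1)/2).
n = {1: 1, 2: 2}
for h in range(3, 31):
    n[h] = n[h - 1] + n[h - 2] + 1
for h in range(1, 31):
    assert n[h] >= 2 ** ((h - 1) / 2.0)
print("bound holds for h = 1 to 30")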

Keeping Track of the AVL Property

Now that we agree that the AVL or pm property is a thing worth having, we should alter our BST code so that the nodes have height and balance fields, and so that the insertion and deletion methods keep them up to date. Once we have done that we will look at how to enforce the balance property.

Now we could use class inheritance and modify the BST class we built last time, but some of that code is a bit kludgy, so I will implement this from scratch, although I won’t go into detail about the parts that are the same as last time.

We begin with the AVLNode and gift it with a sense of balance.

class AVLNode(object):
    def __init__(self, key):
        self.key=key
        self.right_child=None
        self.left_child=None
        self.parent=None
        self.height=0
        self.balance=0

    def update_height(self, upwards=True):
        #If upwards we go up the tree correcting heights and balances,
        #if not we just correct the given node.
        if self.left_child is None:
            #Empty left tree.
            left_height = 0
        else:
            left_height = self.left_child.height+1
        if self.right_child is None:
            #Empty right tree.
            right_height = 0
        else:
            right_height = self.right_child.height+1
        #Note that the balance can change even when the height does not,
        #so change it before checking to see if height needs updating.
        self.balance = left_height-right_height
        height = max(left_height, right_height)
        if self.height != height:
            self.height = height
            if self.parent is not None:
                #We only need to go up a level if the height changes.
                if upwards:
                    self.parent.update_height()

    def is_left(self):
        #Handy to find out whether a node is a left child, a right child, or neither.
        #Returns True, False, or None (when the node has no parent).
        if self.parent is None:
            return None
        else:
            return self is self.parent.left_child

I think the code and comments should explain everything if you read my last post. Let’s also pretend that we have a magic function balance(node), which will rebalance a node and then move up through its ancestors doing the same. We won’t worry yet about exactly how it works, but we will add a call to it in our insert routine.

class AVLTree(object):
    def __init__(self):
        self.root =None

    def insert(self, key, node=None):
        #The first call is slightly different.
        if node is None:
            #First call, start node at root.
            node = self.root
            if node is None:
                #Empty tree, create root.
                node = AVLNode(key=key)
                self.root=node
                return node
            else:
                ret= self.insert(key=key, node=node)
                self.balance(ret)
                return ret
        #Not a first call.
        if node.key ==key:
            #No need to insert, key already present.
            return node
        elif node.key >key:
            child = node.left_child
            if child is None:
                #Reached the bottom, insert node and update heights.
                child = AVLNode(key=key)
                child.parent=node
                node.left_child = child
                node.update_height()
                return child
            else:
                return self.insert(key=key, node=child)
        elif node.key < key:
            child = node.right_child
            if child is None:
                #Reached the bottom, insert node and update heights.
                child = AVLNode(key=key)
                child.parent=node
                node.right_child = child
                node.update_height()
                return child
            else:
                return self.insert(key=key, node=child)
        else:
            print "This shouldn't happen."

    def find(self, key, node=None):
        if node is None:
            #First call.
            node=self.root
            if self.root is None:
                return None
            else:
                return self.find(key, self.root)
        #Now we handle nonfirst calls.
        elif node.key == key:
            #Found the node.
            return node
        elif key < node.key:
            if node.left_child is None:
                #If key not in tree, we return a node that would be its parent.
                return node
            else:
                return self.find(key,node.left_child)
        else:
            if node.right_child is None:
                return node
            else:
                return self.find(key, node.right_child)

    def delete(self, key, node=None):
        #Delete key from tree.
        if node is None:
            #Initial call.
            node = self.find(key)
            if (node is None) or (node.key != key):
                #Empty tree or key not in tree.
                return

        if (node.left_child is None) and (node.right_child is not None):
            #Has one right child.
            right_child = node.right_child
            left = node.is_left()
            if left is not None:
                parent=node.parent
                if not left:
                    parent.right_child=right_child
                else:
                    parent.left_child=right_child
                right_child.parent =parent
                self.balance(parent)
            else:
                right_child.parent=None
                self.root = right_child
                #No need to update heights or rebalance.

        elif (node.left_child is not None) and (node.right_child is None):
            #Has one left child.
            left_child = node.left_child
            left= node.is_left()
            if left is not None:
                parent=node.parent
                if left:
                    parent.left_child=left_child
                else:
                    parent.right_child=left_child
                left_child.parent =parent

                self.balance(parent)
            else:
                left_child.parent=None
                self.root=left_child
        elif node.left_child is None:
            #Has no children.
            parent = node.parent
            if parent is None:
                #Deleting a lone root, set tree to empty.
                self.root = None
            else:
                if parent.left_child is node:
                    parent.left_child =None
                else:
                    parent.right_child=None
                self.balance(parent)
        else:
            #Node has two childen, swap keys with successor node
            #and delete successor node.
            successor = self.find_leftmost(node.right_child)
            node.key = successor.key
            self.delete(key=node.key, node=successor)
            #Note that updating the heights will be handled in the next
            #call of delete.

    def find_rightmost(self, node):
        if node.right_child is None:
            return node
        else:
            return self.find_rightmost(node.right_child)

    def find_leftmost(self, node):
        if node.left_child is None:
            return node
        else:
            return self.find_leftmost(node.left_child)

    def find_next(self, key):
        node = self.find(key)
        if (node is None) or (node.key != key):
            #Key not in tree.
            return None
        else:
            right_child = node.right_child
            if right_child is not None:
                node= self.find_leftmost(right_child)
            else:
                parent = node.parent
                while(parent is not None):
                    if node is parent.left_child:
                        break
                    node = parent
                    parent = node.parent
                node=parent
            if node is None:
                #Key is largest in tree.
                return node
            else:
                return node.key

    def find_prev(self, key):
        node = self.find(key)
        if (node is None) or (node.key != key):
            #Key not in tree.
            return None
        else:
            left_child = node.left_child
            if left_child is not None:
                node= self.find_rightmost(left_child)
            else:
                parent = node.parent
                while(parent is not None):
                    if node is parent.right_child:
                        break
                    node = parent
                    parent = node.parent
                node=parent
            if node is None:
                #Key is smallest in tree.
                return node
            else:
                return node.key
    #I also include a new plotting routine that can show either the balances or the keys of the nodes.
    def plot(self, balance=False):
        #Builds a copy of the BST in igraphs for plotting.
        #Since exporting the adjacency lists loses information about
        #left and right children, we build it using a queue.
        import igraph as igraphs
        G = igraphs.Graph()
        if self.root is None:
            #Nothing to plot.
            return
        G.add_vertices(1)
        queue = [[self.root,0]]
        #Queue has a pointer to the node in our BST, and its index
        #in the igraphs copy.
        index=0
        while queue:
            #At each iteration, we label the head of the queue with its key,
            #then add any children into the igraphs graph,
            #and into the queue.

            node=queue[0][0] #Select front of queue.
            node_index = queue[0][1]
            if not balance:
                G.vs[node_index]['label']=node.key
            else:
                G.vs[node_index]['label']=node.balance
            if index ==0:
                #Label root green.
                G.vs[node_index]['color']='green'
            if node.left_child is not None:
                G.add_vertices(1)
                G.add_edges([(node_index, index+1)])
                queue.append([node.left_child,index+1])
                G.vs[index+1]['color']='red' #Left children are red.
                index+=1
            if node.right_child is not None:
                G.add_vertices(1)
                G.add_edges([(node_index, index+1)])
                G.vs[index+1]['color']='blue'
                queue.append([node.right_child, index+1])
                index += 1 

            queue.pop(0)
        layout = G.layout_reingold_tilford(root=0)
        igraphs.plot(G, layout=layout)

If you want to make sure it works, I have written a small test function.

def test():
    lst= [1,4,2,5,1,3,7,11,4.5]
    B = AVLTree()
    for item in lst:
        print "inserting", item
        B.insert(item)
        B.plot(True)
    print "End of inserts"
    print "Deleting 5"
    B.plot(True)
    B.delete(5)
    print "Deleting 1"
    B.plot(True)
    B.delete(1)
    B.plot(False)
    print B.root.key ==4
    print B.find_next(3) ==4
    print B.find_prev(7)==4.5
    print B.find_prev(1) is None
    print B.find_prev(7)==4.5
    print B.find_prev(2) is None
    print B.find_prev(11) == 7
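
As an extra check, here is a small validator sketch of my own (the function name is just for illustration; it assumes the full AVLTree class, including the balance method defined later in the post) that walks the tree and verifies the search property and the balance property directly:

def check_avl(node, lo=None, hi=None):
    #Verify the search property and that every balance is -1, 0 or +1.
    #Returns the height of the subtree rooted at node (-1 for an empty subtree).
    if node is None:
        return -1
    if lo is not None:
        assert node.key > lo
    if hi is not None:
        assert node.key < hi
    left_height = check_avl(node.left_child, lo, node.key)
    right_height = check_avl(node.right_child, node.key, hi)
    assert abs(left_height - right_height) <= 1
    return 1 + max(left_height, right_height)

B = AVLTree()
for item in [1, 4, 2, 5, 3, 7, 11, 4.5]:
    B.insert(item)
check_avl(B.root)
print("search and balance properties hold")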

Great, now we know the balance of each node, but how do we balance the tree?

Tree Rotations

Before solving our balancing problem I want to introduce a new tool: an operation called a tree rotation, which rearranges a tree while preserving the search property. Here’s what a tree rotation looks like:

[Diagram: a right rotation and its inverse, a left rotation, exchanging root and pivot with subtrees A, B and C.]

So there are two moves that are inverses of one another, a right rotation and a left rotation. Staring at the diagram for a little bit you can also see that these moves preserve the search property. Let’s implement this, and then consider when to use it.

We’ll first create a right rotation, then swap the lefts for rights and vice versa to make a left rotation.

The argument to the function will always be the root, that is, the parent of the two nodes being rotated; so in our diagram above the label root marks the argument for the right rotation.

We then move around the parent-child pointers. Notice that the relation between the root and C, and the relation between pivot and A, is unchanged. The changes we have to make are that B changes parents, and that root and pivot swap places from parent/child to child/parent.

So the left-child of root (previously pivot) is changed to the right-child of pivot, B; the right-child of pivot is changed to root; the parent of root is changed to pivot; and finally the previous parent of root, if any, becomes the new parent of pivot.

    def right_rotation(self, root):
        left=root.is_left()
        pivot = root.left_child
        if pivot is None:
            return
        root.left_child = pivot.right_child
        if pivot.right_child is not None:
            root.left_child.parent = root
        pivot.right_child = root
        pivot.parent = root.parent
        root.parent=pivot
        if left is None:
            self.root = pivot
        elif left:
            pivot.parent.left_child=pivot
        else:
            pivot.parent.right_child=pivot
        root.update_height(False)
        pivot.update_height(False)

And analogously:

    def left_rotation(self, root):
        left=root.is_left()
        pivot = root.right_child
        if pivot is None:
            return
        root.right_child = pivot.left_child
        if pivot.left_child is not None:
            root.right_child.parent = root
        pivot.left_child = root
        pivot.parent = root.parent
        root.parent=pivot
        if left is None:
            self.root = pivot
        elif left:
            pivot.parent.left_child=pivot
        else:
            pivot.parent.right_child=pivot
        root.update_height(False)
        pivot.update_height(False)

Now we ought to test it, so, using the same list as last time:

def test_rotation():
    lst= [1,4,2,5,1,3,7,11,4.5]
    print "List is ",lst
    B = AVLTree()
    for item in lst:
       print "inserting", item
       B.insert(item)
    node=B.find(4)
    B.plot()
    B.left_rotation(node)
    B.plot()
    B.right_rotation(node.parent)
    B.plot()
test_rotation()

Producing the following two plots:

[Plots: the tree before rotation, and the tree after the left rotation.]

Finally we show how to use tree rotations to keep an AVL tree balanced after insertions and deletions.

Bringing Balance to the Force

After an insertion or deletion, heights can change by at most 1. So assuming that the tree was balanced prior to the operation, if a node becomes unbalanced, it will have a balance of \pm 2.

Let’s look, with some pictures, at the case where we have a node N with balance +2; the other case is the same modulo lefts and rights.

[Diagram: node N with balance +2 whose left child is itself left-heavy.]

This is called the left-left case, and a single right rotation at N fixes it. Alternatively we could have the left-right case:

[Diagram: node N with balance +2 whose left child is right-heavy, with a left rotation applied at that child.]

This takes us to a case similar to the one before:

[Diagram: after the left rotation we are back in the left-left case, which a right rotation at N then fixes.]

Implementing this is simple now that we have the right tools. We want the function to start at the changed node, either the inserted node for an insertion or the parent of the deleted node for a deletion, and move up through the ancestors, updating heights and rebalancing as it goes.

    def balance(self, node):
        node.update_height(False)
        if node.balance == 2:
            if node.left_child.balance != -1:
                #Left-left case.
                self.right_rotation(node)
                if node.parent.parent is not None:
                    #Move up a level.
                    self.balance(node.parent.parent)
            else:
                #Left-right case.
                self.left_rotation(node.left_child)
                self.balance(node)
        elif node.balance ==-2:
            if node.right_child.balance != 1:
                #Right-right case.
                self.left_rotation(node)
                if node.parent.parent is not None:
                    self.balance(node.parent.parent)
            else:
                #Right-left case.
                self.right_rotation(node.right_child)
                self.balance(node)
        else:
            if node.parent is not None:
                self.balance(node.parent)

We might as well make a sorting routine to wrap the AVL Tree now.

def sort(lst, ascending=True):
    A = AVLTree()
    for item in lst:
        A.insert(item)
    ret=[]
    if ascending:
        node=A.find_leftmost(A.root)
        if node is not None:
            key = node.key
        else:
            key=node
        while (key is not None):
            ret.append(key)
            key=A.find_next(key)
    else:
        node=A.find_rightmost(A.root)
        if node is not None:
            key = node.key
        else:
            key=node
        while (key is not None):
            ret.append(key)
            key=A.find_prev(key)
    return ret
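
For example, here is a quick usage sketch; note that duplicate keys are silently dropped, since insert ignores a key that is already present:

print(sort([1, 4, 2, 5, 1, 3, 7, 11, 4.5]))
#[1, 2, 3, 4, 4.5, 5, 7, 11]
print(sort([1, 4, 2, 5, 1, 3, 7, 11, 4.5], ascending=False))
#[11, 7, 5, 4.5, 4, 3, 2, 1]

Since the tree keeps its height at O(\log n), each insert and each find_next or find_prev costs O(\log n), so this sorts a list of length n in O(n \log n) time.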

And that’s it! Please let me know if you spot any mistakes. You can download the source here.
