# How to Balance your Binary Search Trees – AVL Trees

Last time we introduced the binary search tree (BST) and saw that it supports insertions and deletions in $O(h)$ time, where $h$ is the height of the tree. We also saw that it can find the successor or predecessor of a node in the same time, and hence that it can sort a list in $O(nh)$ time, where $n$ is the length of the list.

Ideally we would have $h=O(\log n)$, and this is often the case ‘on average’, but the worst case is $h=n$, which will occur if we start with a sorted list! So today we are going to discuss how to have our tree try to balance itself automagically so as to keep its height $O(\log n)$.
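
To make the worst case concrete, here is a minimal sketch that inserts keys into a plain, unbalanced BST and reports the resulting height. The function name `bst_height` and the list-based node layout are assumptions made for this illustration, not the code from the last post. Height is counted in edges here, so a chain of $n$ nodes has height $n-1$.

```python
import random

def bst_height(keys):
    """Insert keys into a plain (unbalanced) BST and return its height in edges."""
    root = None   # each node is a list: [key, left, right]
    height = -1   # an empty tree has height -1 under this convention
    for key in keys:
        if root is None:
            root = [key, None, None]
            height = 0
            continue
        node, depth = root, 0
        while True:
            if key == node[0]:
                break  # duplicate key, nothing to insert
            side = 1 if key < node[0] else 2
            if node[side] is None:
                node[side] = [key, None, None]
                height = max(height, depth + 1)
                break
            node, depth = node[side], depth + 1
    return height

n = 1000
sorted_height = bst_height(range(n))   # already sorted input: one long chain
shuffled = list(range(n))
random.Random(0).shuffle(shuffled)     # fixed seed so the run is repeatable
shuffled_height = bst_height(shuffled)
print(sorted_height, shuffled_height)
```

The sorted input degenerates into a linked list, while the shuffled input typically gives a much shallower tree.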

AVL-Trees

Recall that the defining characteristic of a BST is the so-called ‘search property’: for any subtree rooted at a node, the stuff on the left is smaller, and the stuff on the right is greater.

We now introduce another property, or invariant, that we will insist upon. We call it the ‘plus or minus property’ (pm) or ‘balance property’, and a node satisfies it if the height of its left subtree differs from the height of its right subtree by no more than one. We call a BST in which every node satisfies pm an ‘AVL tree’ (AVL).

So we will give each node two extra fields, node.height and node.balance, where node.balance is equal to the height on the left minus the height on the right. Hence it will be 0 when they are equal, +1 when the left is larger, and -1 when the right is larger.
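
As a small illustration of these two fields, the sketch below computes the height and balance of a tree written as nested tuples `(key, left, right)`, and checks pm at every node. The tuple representation and the name `annotate` are assumptions for this example only. It uses the same convention as the node class later in the post: an empty side contributes 0, and a non-empty side contributes the child's height plus one.

```python
def annotate(tree):
    """Return (height, balance, is_avl) for a tree given as (key, left, right)."""
    key, left, right = tree
    left_height = right_height = 0
    ok = True
    if left is not None:
        child_height, _, left_ok = annotate(left)
        left_height = child_height + 1
        ok = ok and left_ok
    if right is not None:
        child_height, _, right_ok = annotate(right)
        right_height = child_height + 1
        ok = ok and right_ok
    balance = left_height - right_height
    # The tree is AVL if both subtrees are and this node satisfies pm.
    return max(left_height, right_height), balance, ok and abs(balance) <= 1

balanced = (2, (1, None, None), (3, None, None))
chain = (1, None, (2, None, (3, None, None)))
print(annotate(balanced))  # (1, 0, True)
print(annotate(chain))     # (2, -2, False)
```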

Before we go any further, you might be wondering whether or not this property forces the tree to have $O(\log n)$ height, and indeed it does, so let’s prove that to ourselves before coercing our trees to have this property.

Proof That AVL Trees are Balanced

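Define $n_h$ to be the smallest number of nodes with which you can build an AVL tree of height $h$. That is, any AVL tree of height $h$ must have at least $n_h$ nodes. For this proof we count height in levels, so that a lone node has height $1$, giving $n_1=1$ and $n_2=2$.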

Now suppose that we had shown that $h=O(\log n_h)$, or in particular that $h \le K \log n_h$ for all $h$. Now take any AVL tree of height $h$ and having $n$ nodes. Then $n \ge n_h \Rightarrow h \le K \log n_h \le K \log n \Rightarrow h=O(\log n)$ as desired.

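So we just need to show that $h=O(\log n_h)$. The way to do this is by setting up a recurrence relation. A smallest AVL tree of height $h+2$ consists of a root, a smallest AVL tree of height $h+1$ on one side, and, since pm lets the two sides differ by at most one, a smallest AVL tree of height $h$ on the other. Hence $n_{h+2} = n_{h+1}+n_{h}+1$.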

Now using this, and the fact that $n_{h+1} \ge n_h$, notice that:

$n_{h+2} = n_{h+1}+n_{h} + 1 \ge 2n_h$.

And so if $h=2k$ then applying this repeatedly yields:

$n_h \ge 2^{k-1} \cdot n_2 = 2^{k} = 2^{h/2}$.

On the other hand, if $h=2k+1$ then we apply it until we get to $n_1=1$, giving:

$n_h \ge 2^k \cdot n_1 = 2^k = 2^{(h-1)/2}$.

Combining these gives $n_h \ge 2^{(h-1)/2}$ for every $h$. Now taking logs base $2$ gives

$\log n_h \ge \frac{h-1}{2} \Rightarrow h \le 2\log n_h + 1 \Rightarrow h=O(\log n_h) \Rightarrow h=O(\log n)$.

And we are done.
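
As a numerical sanity check on the argument, the sketch below tabulates $n_h$ from the recurrence, starting from $n_1=1$ and $n_2=2$, and confirms that $h \le 2\log_2 n_h + 1$ for the first thirty heights. The helper name `min_nodes` is made up for this example.

```python
import math

def min_nodes(max_height):
    """Return a dict h -> n_h using n_1 = 1, n_2 = 2, n_{h+2} = n_{h+1} + n_h + 1."""
    n = {1: 1, 2: 2}
    for h in range(3, max_height + 1):
        n[h] = n[h - 1] + n[h - 2] + 1
    return n

n = min_nodes(30)
for h, n_h in sorted(n.items()):
    # A bound of this shape is exactly what h = O(log n_h) asks for.
    assert h <= 2 * math.log2(n_h) + 1
print([n[h] for h in range(1, 9)])  # [1, 2, 4, 7, 12, 20, 33, 54]
```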

Keeping Track of the AVL Property

Now that we agree that the AVL or pm property is a thing worth having, we should alter our code for a BST so that the nodes have height and balance fields, and so that the insertion and deletion methods update them accordingly. Once we have done that we will look at how to enforce the balance property.

Now we could use class inheritance and modify the BST class we built last time, but some of the code is a bit kludgy and so I will implement this from scratch, although I won’t go into detail about the stuff that is the same from last time.

We begin with the AVLNode and gift it with a sense of balance.

class AVLNode(object):
    def __init__(self, key):
        self.key=key
        self.right_child=None
        self.left_child=None
        self.parent=None
        self.height=0
        self.balance=0

    def update_height(self, upwards=True):
        #If upwards we go up the tree correcting heights and balances,
        #if not we just correct the given node.
        if self.left_child is None:
            #Empty left tree.
            left_height = 0
        else:
            left_height = self.left_child.height+1
        if self.right_child is None:
            #Empty right tree.
            right_height = 0
        else:
            right_height = self.right_child.height+1
        #Note that the balance can change even when the height does not,
        #so change it before checking to see if height needs updating.
        self.balance = left_height-right_height
        height = max(left_height, right_height)
        if self.height != height:
            self.height = height
            if self.parent is not None:
                #We only need to go up a level if the height changes.
                if upwards:
                    self.parent.update_height()

    def is_left(self):
        #Handy to find out whether a node is a left or right child or neither.
        if self.parent is None:
            #No parent, so neither: signal this with None.
            return None
        else:
            return self is self.parent.left_child


I think the code and its comments should explain everything if you read my last post. Let’s also pretend that we have a magic method balance(node), which balances a node and then moves up through its ancestors doing the same. We won’t worry yet about exactly how it works, but we will call it from our insert and delete routines.

class AVLTree(object):
    def __init__(self):
        self.root = None

    def insert(self, key, node=None):
        #The first call is slightly different.
        if node is None:
            #First call, start node at root.
            node = self.root
            if node is None:
                #Empty tree, create root.
                node = AVLNode(key=key)
                self.root = node
                return node
            else:
                ret = self.insert(key=key, node=node)
                self.balance(ret)
                return ret
        #Not a first call.
        if node.key == key:
            #No need to insert, key already present.
            return node
        elif node.key > key:
            child = node.left_child
            if child is None:
                #Reached the bottom, insert node and update heights.
                child = AVLNode(key=key)
                child.parent = node
                node.left_child = child
                node.update_height()
                return child
            else:
                return self.insert(key=key, node=child)
        elif node.key < key:
            child = node.right_child
            if child is None:
                #Reached the bottom, insert node and update heights.
                child = AVLNode(key=key)
                child.parent = node
                node.right_child = child
                node.update_height()
                return child
            else:
                return self.insert(key=key, node=child)
        else:
            print("This shouldn't happen.")

    def find(self, key, node=None):
        if node is None:
            #First call.
            node = self.root
            if self.root is None:
                return None
            else:
                return self.find(key, self.root)
        #Now we handle nonfirst calls.
        elif node.key == key:
            #Found the node.
            return node
        elif key < node.key:
            if node.left_child is None:
                #If key not in tree, we return a node that would be its parent.
                return node
            else:
                return self.find(key, node.left_child)
        else:
            if node.right_child is None:
                return node
            else:
                return self.find(key, node.right_child)

    def delete(self, key, node=None):
        #Delete key from tree.
        if node is None:
            #Initial call.
            node = self.find(key)
            if (node is None) or (node.key != key):
                #Empty tree or key not in tree.
                return

        if (node.left_child is None) and (node.right_child is not None):
            #Has one right child.
            right_child = node.right_child
            left = node.is_left()
            if left is not None:
                parent = node.parent
                if not left:
                    parent.right_child = right_child
                else:
                    parent.left_child = right_child
                right_child.parent = parent
                self.balance(parent)
            else:
                right_child.parent = None
                self.root = right_child
                #No need to update heights or rebalance.

        elif (node.left_child is not None) and (node.right_child is None):
            #Has one left child.
            left_child = node.left_child
            left = node.is_left()
            if left is not None:
                parent = node.parent
                if left:
                    parent.left_child = left_child
                else:
                    #Fixed: this used to assign the undefined name right_child.
                    parent.right_child = left_child
                left_child.parent = parent

                self.balance(parent)
            else:
                left_child.parent = None
                self.root = left_child
        elif node.left_child is None:
            #Has no children.
            parent = node.parent
            if parent is None:
                #Deleting a lone root, set tree to empty.
                self.root = None
            else:
                if parent.left_child is node:
                    parent.left_child = None
                else:
                    parent.right_child = None
                self.balance(parent)
        else:
            #Node has two children, copy the key of the successor node
            #(the leftmost node of the right subtree) and delete that node.
            successor = self.find_leftmost(node.right_child)
            node.key = successor.key
            self.delete(key=node.key, node=successor)
            #Note that updating the heights will be handled in the next
            #call of delete.

    def find_rightmost(self, node):
        if node.right_child is None:
            return node
        else:
            return self.find_rightmost(node.right_child)

    def find_leftmost(self, node):
        if node.left_child is None:
            return node
        else:
            return self.find_leftmost(node.left_child)

    def find_next(self, key):
        node = self.find(key)
        if (node is None) or (node.key != key):
            #Key not in tree.
            return None
        else:
            right_child = node.right_child
            if right_child is not None:
                node = self.find_leftmost(right_child)
            else:
                parent = node.parent
                while(parent is not None):
                    if node is parent.left_child:
                        break
                    node = parent
                    parent = node.parent
                node = parent
            if node is None:
                #Key is largest in tree.
                return node
            else:
                return node.key

    def find_prev(self, key):
        node = self.find(key)
        if (node is None) or (node.key != key):
            #Key not in tree.
            return None
        else:
            left_child = node.left_child
            if left_child is not None:
                node = self.find_rightmost(left_child)
            else:
                parent = node.parent
                while(parent is not None):
                    if node is parent.right_child:
                        break
                    node = parent
                    parent = node.parent
                node = parent
            if node is None:
                #Key is smallest in tree.
                return node
            else:
                return node.key

    #I also include a new plotting routine to show the balances or keys of the node.
    def plot(self, balance=False):
        #Builds a copy of the BST in igraphs for plotting.
        #Since exporting the adjacency lists loses information about
        #left and right children, we build it using a queue.
        import igraph as igraphs
        G = igraphs.Graph()
        if self.root is None:
            #Nothing to draw for an empty tree.
            return
        G.add_vertices(1)
        queue = [[self.root,0]]
        #Queue has a pointer to the node in our BST, and its index
        #in the igraphs copy.
        index=0
        not_break=True
        while(not_break):
            #At each iteration, we label the head of the queue with its key,
            #then add any children into the igraphs graph,
            #and into the queue.

            node=queue[0][0] #Select front of queue.
            node_index = queue[0][1]
            if not balance:
                G.vs[node_index]['label']=node.key
            else:
                G.vs[node_index]['label']=node.balance
            if index ==0:
                #Label root green.
                G.vs[node_index]['color']='green'
            if node.left_child is not None:
                G.add_vertices(1)
                G.add_edges([(node_index, index+1)])
                queue.append([node.left_child,index+1])
                G.vs[index+1]['color']='red' #Left children are red.
                index+=1
            if node.right_child is not None:
                G.add_vertices(1)
                G.add_edges([(node_index, index+1)])
                G.vs[index+1]['color']='blue' #Right children are blue.
                queue.append([node.right_child, index+1])
                index += 1

            queue.pop(0)
            if len(queue)==0:
                not_break=False
        layout = G.layout_reingold_tilford(root=0)
        igraphs.plot(G, layout=layout)


If you want to make sure it works, I have written a small test function.

def test():
    lst= [1,4,2,5,1,3,7,11,4.5]
    B = AVLTree()
    for item in lst:
        print("inserting", item)
        B.insert(item)
        B.plot(True)
    print("End of inserts")
    print("Deleting 5")
    B.plot(True)
    B.delete(5)
    print("Deleting 1")
    B.plot(True)
    B.delete(1)
    B.plot(False)
    #Each of the following should print True.
    print(B.root.key ==4)
    print(B.find_next(3) ==4)
    print(B.find_prev(7)==4.5)
    print(B.find_prev(1) is None)
    print(B.find_prev(2) is None)
    print(B.find_prev(11) == 7)


Great, now we know the balance of each node, but how do we balance the tree?

Tree Rotations

Before solving our balancing problem I want to introduce a new tool, an operation called a tree rotation that respects the search property. Here’s what a tree rotation looks like:

So there are two moves that are inverses of one another, a right rotation and a left rotation. Staring at the diagram for a little bit you can also see that these moves preserve the search property. Let’s implement this, and then consider when to use it.
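
To see the claim without staring at a diagram, here is a minimal sketch of both rotations on trees written as nested tuples `(key, left, right)`. The tuple representation and the function names are assumptions for illustration only; the labels root, pivot, $A$, $B$ and $C$ follow the text. The in-order listing of the keys is the same before and after, which is exactly what preserving the search property means.

```python
def rotate_right(tree):
    """(root, (pivot, A, B), C) becomes (pivot, A, (root, B, C))."""
    root_key, left, c = tree
    if left is None:
        return tree  # no left child to rotate around
    pivot_key, a, b = left
    return (pivot_key, a, (root_key, b, c))

def rotate_left(tree):
    """The inverse move: (top, A, (lower, B, C)) becomes (lower, (top, A, B), C)."""
    top_key, a, right = tree
    if right is None:
        return tree  # no right child to rotate around
    lower_key, b, c = right
    return (lower_key, (top_key, a, b), c)

def in_order(tree):
    """List the keys of a tuple tree in sorted (in-order) order."""
    if tree is None:
        return []
    key, left, right = tree
    return in_order(left) + [key] + in_order(right)

def leaf(key):
    return (key, None, None)

# root = 6, pivot = 4, A = 2, B = 5, C = 8
before = (6, (4, leaf(2), leaf(5)), leaf(8))
after = rotate_right(before)
print(in_order(before), in_order(after))
print(rotate_left(after) == before)
```

Note that only $B$ changes parents: it moves from being the right child of pivot to being the left child of root.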

We’ll first create a right rotation, then swap the lefts for rights and vice versa to make a left rotation.

The argument to the function will always be the root, that is the parent of the two nodes, so in our diagram above, the label root is appropriate for the right-rotation.

We then move around the parent-child pointers. Notice that the relation between the root and $C$, and the relation between pivot and $A$, is unchanged. The changes we have to make are that $B$ changes parents, and that root and pivot swap places from parent/child to child/parent.

So the left-child of root, which was pivot, is changed to the right-child of pivot, $B$; the right-child of pivot is changed to root; the parent of root is changed to pivot; and finally the previous parent of root, if any, becomes the new parent of pivot.

    def right_rotation(self, root):
        left=root.is_left()
        pivot = root.left_child
        if pivot is None:
            return
        root.left_child = pivot.right_child
        if pivot.right_child is not None:
            root.left_child.parent = root
        pivot.right_child = root
        pivot.parent = root.parent
        root.parent=pivot
        if left is None:
            self.root = pivot
        elif left:
            pivot.parent.left_child=pivot
        else:
            pivot.parent.right_child=pivot
        root.update_height(False)
        pivot.update_height(False)


And analogously:

    def left_rotation(self, root):
        left=root.is_left()
        pivot = root.right_child
        if pivot is None:
            return
        root.right_child = pivot.left_child
        if pivot.left_child is not None:
            root.right_child.parent = root
        pivot.left_child = root
        pivot.parent = root.parent
        root.parent=pivot
        if left is None:
            self.root = pivot
        elif left:
            pivot.parent.left_child=pivot
        else:
            pivot.parent.right_child=pivot
        root.update_height(False)
        pivot.update_height(False)


Now we ought to test it, so let’s use the same list as last time:

def test_rotation():
    lst= [1,4,2,5,1,3,7,11,4.5]
    print("List is ", lst)
    B = AVLTree()
    for item in lst:
        print("inserting", item)
        B.insert(item)
    node=B.find(4)
    B.plot()
    B.left_rotation(node)
    B.plot()
    B.right_rotation(node.parent)
    B.plot()
test_rotation()


Producing the following two plots:

Finally we show how to use tree rotations to keep an AVL tree balanced after insertions and deletions.

Bringing Balance to the Force

After an insertion or deletion, heights can change by at most 1. So assuming that the tree was balanced prior to the operation, if a node becomes unbalanced, it will have a balance of $\pm 2$.

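Let’s look, with some pictures, at the case where we have a node $N$ with balance $+2$; the other case is the same modulo lefts and rights.

This is called the left-left case: the left child of $N$ has balance $0$ or $+1$, and a single right rotation at $N$ restores the balance. Alternatively we could have the left-right case, where the left child of $N$ has balance $-1$:

A left rotation at the left child takes us to a case similar to the one before, which a right rotation at $N$ then fixes: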

Implementing this is simple now that we have the right tools. We want the function to start with the changed node, either the inserted node in insertion, or the parent of the deleted node in deletion, and move up through the ancestors, updating heights and balancing as it goes.

    def balance(self, node):
        node.update_height(False)
        if node.balance == 2:
            if node.left_child.balance != -1:
                #Left-left case.
                self.right_rotation(node)
                if node.parent.parent is not None:
                    #Move up a level.
                    self.balance(node.parent.parent)
            else:
                #Left-right case.
                self.left_rotation(node.left_child)
                self.balance(node)
        elif node.balance ==-2:
            if node.right_child.balance != 1:
                #Right-right case.
                self.left_rotation(node)
                if node.parent.parent is not None:
                    self.balance(node.parent.parent)
            else:
                #Right-left case.
                self.right_rotation(node.right_child)
                self.balance(node)
        else:
            if node.parent is not None:
                self.balance(node.parent)
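
For readers who want to experiment without the plotting dependency, here is a self-contained sketch of the same four cases. It is a different formulation from the method above: it has no parent pointers, is recursive, and each function returns the new root of the subtree it was given. All names in it (`Node`, `rebalance`, `is_avl` and so on) are made up for this example. It inserts an already sorted list, the worst case for a plain BST, and checks that the result is still a valid, shallow AVL tree.

```python
class Node:
    def __init__(self, key):
        self.key, self.left, self.right, self.height = key, None, None, 0

def height(node):
    return -1 if node is None else node.height  # empty tree counts as -1

def update(node):
    node.height = 1 + max(height(node.left), height(node.right))
    return node

def balance_factor(node):
    return height(node.left) - height(node.right)

def rotate_right(root):
    pivot = root.left
    root.left, pivot.right = pivot.right, root
    update(root)
    return update(pivot)

def rotate_left(root):
    pivot = root.right
    root.right, pivot.left = pivot.left, root
    update(root)
    return update(pivot)

def rebalance(node):
    update(node)
    if balance_factor(node) == 2:
        if balance_factor(node.left) == -1:    # left-right case
            node.left = rotate_left(node.left)
        return rotate_right(node)              # left-left case
    if balance_factor(node) == -2:
        if balance_factor(node.right) == 1:    # right-left case
            node.right = rotate_right(node.right)
        return rotate_left(node)               # right-right case
    return node

def insert(node, key):
    if node is None:
        return Node(key)
    if key < node.key:
        node.left = insert(node.left, key)
    elif key > node.key:
        node.right = insert(node.right, key)
    return rebalance(node)  # duplicates fall through unchanged

def in_order(node):
    return [] if node is None else in_order(node.left) + [node.key] + in_order(node.right)

def is_avl(node):
    """Check the balance property at every node."""
    if node is None:
        return True
    return abs(balance_factor(node)) <= 1 and is_avl(node.left) and is_avl(node.right)

root = None
for key in range(1, 1024):   # sorted input
    root = insert(root, key)
print(root.height, is_avl(root))
```

Rebalancing on the way back out of the recursion plays the role of walking up the ancestors in the method above.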


We might as well make a sorting routine to wrap the AVL Tree now.

def sort(lst, ascending=True):
    #Note that insert ignores repeated keys, so duplicates appear only once.
    A = AVLTree()
    for item in lst:
        A.insert(item)
    ret=[]
    if A.root is None:
        #Empty input, nothing to sort.
        return ret
    if ascending:
        node=A.find_leftmost(A.root)
        if node is not None:
            key = node.key
        else:
            key=node
        while (key is not None):
            ret.append(key)
            key=A.find_next(key)
    else:
        node=A.find_rightmost(A.root)
        if node is not None:
            key = node.key
        else:
            key=node
        while (key is not None):
            ret.append(key)
            key=A.find_prev(key)
    return ret


And that’s it! Please let me know if you spot any mistakes. You can download the source here.