In preparation for starting a PhD in Machine Learning/Data Science at the University of Edinburgh, I have been watching some of the MIT lectures on algorithms and data structures.
One of the data structures discussed is the binary search tree (BST), so in this post I will explain what it is and give a Python implementation.
As you might expect, a BST is a way of organizing data in a binary tree. It consists of a collection of nodes, each of which has a key corresponding to a piece of data, say an integer. Importantly, the keys must be comparable, so that for any two nodes $x$ and $y$ either $\mathrm{key}(x) \le \mathrm{key}(y)$ or $\mathrm{key}(y) \le \mathrm{key}(x)$.
Each node also has (at most) three pointers: parent, left child and right child.
Most importantly we have the invariant of the BST. An invariant is, of course, something that stays the same: a property that is not altered under any of the data structure’s allowable operations. The invariant, or search property, is that for any given node $x$, all of the nodes in the left subtree of $x$ have keys less than $\mathrm{key}(x)$, and all of the nodes in the right subtree of $x$ have keys greater than $\mathrm{key}(x)$.
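To make the search property concrete, here is a minimal sketch of a checker. It assumes a throwaway representation that is not used elsewhere in this post: a tree is either None or a tuple (key, left, right). The name satisfies_search_property is made up for this illustration.

```python
def satisfies_search_property(tree, low=None, high=None):
    # A tree is None (empty) or a tuple (key, left, right).
    # This representation is an assumption made only for this sketch.
    # low and high are exclusive bounds inherited from the ancestors.
    if tree is None:
        return True
    key, left, right = tree
    if low is not None and not (low < key):
        return False
    if high is not None and not (key < high):
        return False
    # Everything on the left must stay below key,
    # everything on the right must stay above it.
    return (satisfies_search_property(left, low, key) and
            satisfies_search_property(right, key, high))


good = (4, (2, None, (3, None, None)), (5, None, None))
bad = (4, (2, None, (6, None, None)), (5, None, None))  # 6 sits to the left of 4
print(satisfies_search_property(good))
print(satisfies_search_property(bad))
```

Note that comparing each node only with its immediate children would wrongly accept the second tree, since 6 is a perfectly good right child of 2. The property is about whole subtrees, which is why the bounds are passed down.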
Any useful data structure has operations. Some obvious things to want to do are to insert and delete. Of course in order to do these things we have to take care not to mess up the search property. To insert we begin at the root and at each stage ask whether the key to be inserted is less than, greater than or equal to the key of the current node, in which case we move left, right or nowhere respectively. When we try to move into an empty position, the new node is placed there. It’s easiest just to see an example:
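As a small sketch of that walk, the function below records the direction taken at each comparison. It uses plain dictionaries rather than the classes developed in this post, and the name insert_with_trace is made up for this illustration.

```python
def insert_with_trace(tree, key):
    # tree is None or a dict {'key': k, 'left': subtree, 'right': subtree}.
    # Returns (tree, path), where path lists the moves made from the root.
    if tree is None:
        return {'key': key, 'left': None, 'right': None}, []
    node = tree
    path = []
    while True:
        if key == node['key']:
            path.append('stop')  # Key already present: go nowhere.
            return tree, path
        side = 'left' if key < node['key'] else 'right'
        path.append(side)
        if node[side] is None:
            # Empty position found: the new node goes here.
            node[side] = {'key': key, 'left': None, 'right': None}
            return tree, path
        node = node[side]


tree = None
for k in [1, 4, 2, 5, 3]:
    tree, path = insert_with_trace(tree, k)
    print(k, path)
```

The last key, 3, goes right of 1, left of 4 and then right of 2, so every comparison on the way down narrows the range of keys the new node may hold.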
Now I think you get the idea, so let’s start coding this up. First we will make the BST node.
class TreeNode(object):
    # A tree node has a key, which can be compared with other keys,
    # and possibly 'pointers' to a parent, left child or right child.
    def __init__(self, key, parent=None, left_child=None, right_child=None):
        self.parent = parent
        self.left_child = left_child
        self.right_child = right_child
        self.key = key
Next we will start making the BST. The BST remembers the root, and the pointers do the rest of the work. We begin by implementing a helpful method called ‘find’, which returns the node with a given key or, if there is no such node, the node that could take a new node with that key as a child.
class BinarySearchTree(object):
    def __init__(self, root=None):
        self.root = root

    def find(self, key):
        # Returns a node corresponding to a key.
        # If key is the key of some node, returns that node.
        # If tree is empty, returns None.
        # Otherwise returns the node that could accept
        # a node with that key as a child.
        # This function wraps a call to a recursive function self._find.
        if self.root is None:
            return None
        else:
            return self._find(key, self.root)

    def _find(self, key, node):
        # This is a recursive function that does all the work for
        # self.find.
        if key == node.key:
            return node
        elif key < node.key:
            if node.left_child is None:
                return node
            else:
                return self._find(key, node.left_child)
        else:
            if node.right_child is None:
                return node
            else:
                return self._find(key, node.right_child)
Using ‘find’ we can then implement ‘insert’.
    def insert(self, key):
        # Inserts a node with given key.
        # If key already in tree, then returns the node with that key.
        # Otherwise creates a new node and returns it.
        # This takes time of order the height of the tree.
        node = self.find(key)
        if node is None:
            self.root = TreeNode(key)
            return self.root
        elif key == node.key:
            return node
        elif key < node.key:
            left_child = TreeNode(key, parent=node)
            node.left_child = left_child
            return left_child
        else:
            right_child = TreeNode(key, parent=node)
            node.right_child = right_child
            return right_child
So far so good!
The next thing we want to do is be able to delete keys.
If a key belongs to a leaf node then this is simple: we just delete the node.
If the node has exactly one child, then deleting it means connecting that child to the node’s parent in its place; if the node has no parent, the child becomes the new root.
If the node has two children things are a bit more tricky. If we delete the node then we break the tree. What we need to do instead is replace the node’s key with something else in the tree. What we replace it with is the leftmost node of the subtree rooted at the node’s right child; as we shall see later, this is the successor of the node in the BST. The following diagram should explain:
So we swap the node’s key for this leftmost node’s key, then delete the leftmost node. Since it is leftmost it cannot have a left child, and so it is one of our base cases for the deletion operation. If this is still not clear to you, do some examples on paper and you will soon see the idea. Let’s implement this.
    def delete(self, key):
        # Delete key from tree.
        # If key is not in BST does nothing.
        # Otherwise it calls a semi-recursive function _delete.
        node = self.find(key)
        if (node is None) or (node.key != key):
            return
        else:
            self._delete(node)

    def _delete(self, node):
        # If the node has less than two children we can delete the node
        # directly by removing it and then gluing the tree back together.
        # If node has two children, it swaps a key lower down in the tree to
        # replace the deleted node, and deletes the lower down node.
        if (node.left_child is not None) and (node.right_child is not None):
            # Has two children: copy in the successor's key,
            # then delete the successor, which has no left child.
            successor = self.find_leftmost(node.right_child)
            node.key = successor.key
            self._delete(successor)
            return
        # Has at most one child (possibly None for a leaf).
        if node.left_child is not None:
            child = node.left_child
        else:
            child = node.right_child
        parent = node.parent
        if child is not None:
            child.parent = parent
        if parent is None:
            # The node was the root.
            self.root = child
        elif parent.left_child is node:
            # Reattach on the same side the deleted node was on.
            parent.left_child = child
        else:
            parent.right_child = child

    def find_rightmost(self, node):
        if node.right_child is None:
            return node
        else:
            return self.find_rightmost(node.right_child)

    def find_leftmost(self, node):
        if node.left_child is None:
            return node
        else:
            return self.find_leftmost(node.left_child)
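As a cross-check of the three deletion cases, here is a separate sketch that does not use the classes above. It assumes trees written as nested tuples (key, left, right) and rebuilds the tree rather than editing it in place; the names delete_key, leftmost_key and in_order are made up for this illustration. In the two-children case it swaps in the successor, that is, the leftmost key of the right subtree.

```python
def leftmost_key(tree):
    # Smallest key in a non-empty tree.
    key, left, _ = tree
    return key if left is None else leftmost_key(left)


def delete_key(tree, key):
    # Returns a tree without key; the input tuples are left untouched.
    if tree is None:
        return None
    k, left, right = tree
    if key < k:
        return (k, delete_key(left, key), right)
    if key > k:
        return (k, left, delete_key(right, key))
    # Found the key.
    if left is None:
        return right  # Leaf, or only a right child.
    if right is None:
        return left   # Only a left child.
    # Two children: promote the successor, then remove it lower down.
    s = leftmost_key(right)
    return (s, left, delete_key(right, s))


def in_order(tree):
    # Keys in ascending order.
    if tree is None:
        return []
    k, left, right = tree
    return in_order(left) + [k] + in_order(right)


tree = (4, (2, None, (3, None, None)),
        (7, (5, None, None), (11, None, None)))
print(in_order(delete_key(tree, 3)))   # Leaf.
print(in_order(delete_key(tree, 2)))   # One child.
print(in_order(delete_key(tree, 4)))   # Two children.
```

In each case the in-order listing is the original one with a single key missing, which is a handy way to convince yourself that the search property survives.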
Ok great, we have now covered the basics of how to modify a BST. Now we want to know how to extract information from it. In particular, given a key, we want to be able to find the next largest or next smallest key in the BST, if any. Calling the successor (predecessor) operation repeatedly, starting from the left-most (right-most) node, will give us the keys in ascending (descending) order.
I will explain how to find the successor; the predecessor case is of course very similar. Let’s suppose we want to find the successor of a node $x$ with right child $r$. In this case we choose the left-most node of the subtree rooted at $r$, which we will call $s$.
Note that in a subtree rooted at some node, the left-most node is the smallest element in the subtree, and the right-most node is the largest.
So in the subtree rooted at $x$, the right subtree rooted at $r$ contains all the elements in the subtree greater than $x$. The left-most element $s$ of the right subtree is then the smallest element in the subtree greater than $x$ and is a good candidate for a successor. Let’s draw a picture and try to see why this must in fact be the correct node.
OK, but what if $x$ has no right child? In this case we move up through the ancestors of $x$, starting from $x$ itself, until we reach a node that is a left child of its parent; that parent is the successor. If there is no such node, i.e. we get to the root, then $x$ is the right-most, and hence largest, element in the tree, and has no successor. To see that this is correct, employ a similar argument to the previous case.
Let’s code that up.
    def find_next(self, key):
        node = self.find(key)
        if (node is None) or (node.key != key):
            # Key not in tree.
            return
        else:
            ret = self._find_next(node)
            if ret is None:
                # Key is largest in tree.
                return ret
            else:
                return ret.key

    def _find_next(self, node):
        right_child = node.right_child
        if right_child is not None:
            return self.find_leftmost(right_child)
        else:
            parent = node.parent
            while parent is not None:
                if node is parent.left_child:
                    break
                node = parent
                parent = node.parent
            return parent

    def find_prev(self, key):
        node = self.find(key)
        if (node is None) or (node.key != key):
            # Key not in tree.
            return
        else:
            ret = self._find_prev(node)
            if ret is None:
                # Key is smallest in tree.
                return ret
            else:
                return ret.key

    def _find_prev(self, node):
        left_child = node.left_child
        if left_child is not None:
            return self.find_rightmost(left_child)
        else:
            parent = node.parent
            while parent is not None:
                if node is parent.right_child:
                    break
                node = parent
                parent = node.parent
            return parent
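To illustrate the earlier claim that repeatedly taking successors from the left-most node lists the keys in ascending order, here is a self-contained sketch. It uses a stripped-down Node class and free functions invented for this example rather than the BinarySearchTree class from this post.

```python
class Node(object):
    # Minimal stand-in for a tree node with a parent pointer (sketch only).
    def __init__(self, key, parent=None):
        self.key = key
        self.parent = parent
        self.left = None
        self.right = None


def insert(root, key):
    # Plain BST insert; returns the root. Duplicate keys are ignored.
    if root is None:
        return Node(key)
    node = root
    while key != node.key:
        side = 'left' if key < node.key else 'right'
        child = getattr(node, side)
        if child is None:
            setattr(node, side, Node(key, parent=node))
            break
        node = child
    return root


def leftmost(node):
    while node.left is not None:
        node = node.left
    return node


def successor(node):
    if node.right is not None:
        return leftmost(node.right)
    # Climb while we are a right child; the first parent reached
    # from its left side is the successor (None if we pass the root).
    while node.parent is not None and node is node.parent.right:
        node = node.parent
    return node.parent


def sorted_keys(root):
    keys = []
    node = leftmost(root) if root is not None else None
    while node is not None:
        keys.append(node.key)
        node = successor(node)
    return keys


root = None
for k in [1, 4, 2, 5, 1, 3, 7, 11, 4.5]:
    root = insert(root, k)
print(sorted_keys(root))
```

The walk visits every node once and crosses each edge at most twice, so listing all the keys this way costs time proportional to the number of nodes, even though a single successor call can take time of order the height.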
By now I think we deserve to see the fruits of our hard work, so let’s whip up a plotting routine using igraph. Don’t worry about the details: since I can’t find a binary tree plotting style in igraph, I have colored the root green, left children red and right children blue.
In order to preserve information about left and right children, the plotting routine builds the graph in igraph using a queue.
    def plot(self):
        # Builds a copy of the BST in igraph for plotting.
        # Since exporting the adjacency lists loses information about
        # left and right children, we build it using a queue.
        import igraph as igraphs
        G = igraphs.Graph()
        if self.root is None:
            # Nothing to draw for an empty tree.
            return
        G.add_vertices(1)
        # Queue has a pointer to the node in our BST, and its index
        # in the igraph copy.
        queue = [[self.root, 0]]
        index = 0
        while len(queue) > 0:
            # At each iteration, we label the head of the queue with its key,
            # then add any children into the igraph graph,
            # and into the queue.
            node, node_index = queue.pop(0)  # Select front of queue.
            G.vs[node_index]['label'] = node.key
            if node_index == 0:
                # Label root green.
                G.vs[node_index]['color'] = 'green'
            if node.left_child is not None:
                G.add_vertices(1)
                G.add_edges([(node_index, index + 1)])
                queue.append([node.left_child, index + 1])
                G.vs[index + 1]['color'] = 'red'  # Left children are red.
                index += 1
            if node.right_child is not None:
                G.add_vertices(1)
                G.add_edges([(node_index, index + 1)])
                G.vs[index + 1]['color'] = 'blue'  # Right children are blue.
                queue.append([node.right_child, index + 1])
                index += 1
        layout = G.layout_reingold_tilford(root=[0])
        igraphs.plot(G, layout=layout)
Now let’s test it!
def test():
    lst = [1, 4, 2, 5, 1, 3, 7, 11, 4.5]
    B = BinarySearchTree()
    for item in lst:
        B.insert(item)
    B.plot()
    B.delete(5)
    B.delete(1)
    B.plot()
    print(B.root.key)
    print(B.find_next(3))
    print(B.find_prev(7))
    print(B.find_prev(1))

test()
You can probably see that the operations take $O(h)$ time, where $h$ is the height of the tree. This means that we can sort a list of length $n$ in $O(nh)$ time. On average $h = O(\log n)$, but if given a monotonic list we will get a linear tree with $h = n$. So next time we will see how to ‘balance’ the tree using tree rotations to achieve fast operations.
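The gap between the two cases is easy to see numerically. The sketch below is independent of the classes above: it builds an unbalanced BST out of nested lists and reports its height, counted as the number of nodes on the longest path from the root. The function name bst_height, the list layout and the choice of 1000 keys are all assumptions made for this illustration.

```python
import random


def bst_height(keys):
    # Inserts keys in the given order into an unbalanced BST and returns
    # its height, counted in nodes on the longest root-to-leaf path.
    # Nodes are [key, left, right] lists; this layout is just for the sketch.
    root = None
    height = 0
    for key in keys:
        if root is None:
            root = [key, None, None]
            height = 1
            continue
        node, depth = root, 1
        while key != node[0]:
            side = 1 if key < node[0] else 2
            depth += 1  # Depth of the position we are about to look at.
            if node[side] is None:
                node[side] = [key, None, None]
                height = max(height, depth)
                break
            node = node[side]
    return height


n = 1000
shuffled = list(range(n))
random.Random(0).shuffle(shuffled)
print(bst_height(range(n)))   # Monotonic input: a linear tree.
print(bst_height(shuffled))   # Random order: typically far shallower.
```

The insertion is written as a loop rather than recursively so that the linear tree does not run into Python's recursion limit.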
Please feel free to leave any comments, suggestions or corrections! Here is the source code.