AVL Tree in Java

An AVL tree is a self-balancing binary search tree: after every insertion, the heights of any node's two subtrees differ by at most one. That’s all I remember from when I wrote this Java implementation of an AVL tree a few years ago. This code may help you understand how the LL, RR, LR and RL rotations rebalance the tree.

There are 3 classes included:

  • AVLTree
  • BinaryTree
  • BinaryNode

Hopefully someone out there will find this useful!

AVLTree.java

public class AVLTree extends BinaryTree {

	/**
	 * Insert a value and restore the AVL balance property.
	 *
	 * The value is first inserted with ordinary binary-search-tree insertion
	 * (see BinaryTree.insert). Then every ancestor of the new node has its
	 * cached balance refreshed, from the new node's parent up to the root.
	 * A rotation is performed at the lowest ancestor whose balance is
	 * outside [-1, 1]. After an insertion one single or double rotation
	 * restores that subtree's previous height, so at most one rotation is
	 * applied.
	 *
	 * @param value	The value to insert.
	 * @return		The newly inserted node.
	 */
	public BinaryNode insert(int value){

		// Insert normally.
		BinaryNode n = super.insert(value);

		// Keep track of whether we have already applied a rotation.
		boolean appliedBalance = false;

		// Follow the path up, updating the balance of each node. If a node
		// violates the balance property (abs(balance) >= 2), rotate at THAT node.
		BinaryNode p = n.getParent();
		while(p != null){

			// Update the balance.
			p.reBalance();

			// If the node is not balanced, perform the rotation on that node.
			if(!p.isBalanced() && !appliedBalance){
				rebalanceAt(p);
				appliedBalance = true;
			}

			// Go to the next parent. After a rotation this is the new root of
			// the rotated subtree, whose balance is then refreshed as well.
			p = p.getParent();

		}

		// Return the inserted node.
		return n;

	}

	/**
	 * Choose and apply the rotation for an unbalanced node.
	 *
	 * The case is decided by which side of p is heavy and by which side of
	 * that heavy child is heavy. It must not be decided by comparing values
	 * with the root, which says nothing about the shape below p.
	 *
	 * @param p	A node whose balance is +2 or -2.
	 */
	private void rebalanceAt(BinaryNode p){
		if(p.getBalance() > 1){
			// Left-heavy: the left child tells LL apart from LR.
			BinaryNode child = p.getLeft();
			child.reBalance();
			if(child.getBalance() >= 0){
				singleWithLeftChild(p);   // LL
			} else {
				doubleWithLeftChild(p);   // LR
			}
		} else {
			// Right-heavy: the right child tells RR apart from RL.
			BinaryNode child = p.getRight();
			child.reBalance();
			if(child.getBalance() <= 0){
				singleWithRightChild(p);  // RR
			} else {
				doubleWithRightChild(p);  // RL
			}
		}
	}

	/**
	 * LL rotation (single right rotation): k2's left child k1 becomes the
	 * root of this subtree, and k2 becomes k1's right child.
	 * Precondition: k2 has a non-null left child.
	 * @param k2	The unbalanced node.
	 * @return		k1, the new root of the rotated subtree (used by double rotations).
	 */
	public BinaryNode singleWithLeftChild(BinaryNode k2){

		// 1) k1 is the left child of k2.
		BinaryNode k1 = k2.getLeft();

		// 2) left child of k2 = right child of k1.
		k2.setLeft(k1.getRight());

		// 3) left child of parent of k2 or right child of parent of k2 = k1.
		if(k2.getParent() != null){
			if(k2.getParent().getLeft() == k2){
				// Insert on left.
				k2.getParent().setLeft(k1);
			} else {
				// Insert on right.
				k2.getParent().setRight(k1);
			}
		}

		// 4) right child of k1 = k2. If k2 was the root, setRight also clears
		// k1's stale parent link (which still points at k2).
		k1.setRight(k2);

		// 5) update balance information for k2 and k1 (k2 first, it is now lower).
		k2.reBalance();
		k1.reBalance();

		// If we have a new root, make sure to update the variable.
		// Our new root will occur if k1's parent is null.
		if(k1.getParent() == null){
			_root = k1;
		}

		// Return k1 (for use in double rotations).
		return k1;
	}

	/**
	 * RR rotation (single left rotation): k2's right child k1 becomes the
	 * root of this subtree, and k2 becomes k1's left child.
	 * Precondition: k2 has a non-null right child.
	 * @param k2	The unbalanced node.
	 * @return		k1, the new root of the rotated subtree (used by double rotations).
	 */
	public BinaryNode singleWithRightChild(BinaryNode k2){

		// 1) k1 is the right child of k2.
		BinaryNode k1 = k2.getRight();

		// 2) right child of k2 = left child of k1.
		k2.setRight(k1.getLeft());

		// 3) left child of parent of k2 or right child of parent of k2 = k1.
		if(k2.getParent() != null){
			if(k2.getParent().getLeft() == k2){
				// Insert on left.
				k2.getParent().setLeft(k1);
			} else {
				// Insert on right.
				k2.getParent().setRight(k1);
			}
		}

		// 4) left child of k1 = k2. If k2 was the root, setLeft also clears
		// k1's stale parent link (which still points at k2).
		k1.setLeft(k2);

		// 5) update balance information for k2 and k1 (k2 first, it is now lower).
		k2.reBalance();
		k1.reBalance();

		// If we have a new root, make sure to update the variable.
		// Our new root will occur if k1's parent is null.
		if(k1.getParent() == null){
			_root = k1;
		}

		// Return k1 (for use in double rotations).
		return k1;

	}

	/**
	 * LR double rotation. It fixes a left-heavy node whose left child is
	 * right-heavy: first a left rotation at k3's left child, then a right
	 * rotation at k3.
	 * @param k3	The unbalanced node.
	 * @return		The new root of the rotated subtree.
	 */
	public BinaryNode doubleWithLeftChild(BinaryNode k3){

		k3.setLeft(singleWithRightChild(k3.getLeft()));
		return singleWithLeftChild(k3);

	}

	/**
	 * RL double rotation. It fixes a right-heavy node whose right child is
	 * left-heavy: first a right rotation at k3's right child, then a left
	 * rotation at k3.
	 * @param k3	The unbalanced node.
	 * @return		The new root of the rotated subtree.
	 */
	public BinaryNode doubleWithRightChild(BinaryNode k3){

		k3.setRight(singleWithLeftChild(k3.getRight()));
		return singleWithRightChild(k3);

	}

}

BinaryTree.java

public class BinaryTree {

	// The root of the tree (null when the tree is empty).
	protected BinaryNode _root;

	/**
	 * Insert a value into the tree using plain, unbalanced binary-search-tree
	 * insertion. Values smaller than a node go to its left subtree. Equal or
	 * larger values go to its right subtree.
	 * @param value	The value to insert.
	 * @return		The newly created node.
	 */
	public BinaryNode insert(int value){

		// If we don't have a root node, the new node becomes the root.
		if(getRoot() == null){
			_root = new BinaryNode(this, value);
			return _root;
		}

		// Walk down from the root until we find an empty child slot.
		BinaryNode current = _root;
		while(current != null){

			if(value < current.getValue()){
				// Insert on the left.
				if(current.getLeft() == null){
					BinaryNode n = new BinaryNode(this, value);
					current.setLeft(n);
					return n;
				}
				current = current.getLeft(); // Go to next left.
			} else {
				// Insert on the right (duplicates go right).
				if(current.getRight() == null){
					BinaryNode n = new BinaryNode(this, value);
					current.setRight(n);
					return n;
				}
				current = current.getRight(); // Go to next right.
			}

		}

		// Not reachable: the loop always returns once it finds an empty slot.
		return null;

	}

	/**
	 * Print the entire tree in order, one node per line, indented by depth.
	 * Prints nothing for an empty tree.
	 */
	public void printFormattedTree(){

		// Start from root.
		if(getRoot() != null){
			getRoot().printFormatted();
		}
	}

	/**
	 * Find a node that has the given value.
	 * @param value	The value to search for.
	 * @return		The first matching node on the search path, or null if absent.
	 */
	public BinaryNode findNodeWithValue(int value){
		return findNodeWithValue(getRoot(), value);
	}

	/**
	 * Find a node with the given value inside the subtree rooted at n.
	 * @param n		The subtree root to search (may be null).
	 * @param value	The value to search for.
	 * @return		The first matching node on the search path, or null if absent.
	 */
	public BinaryNode findNodeWithValue(BinaryNode n, int value){

		if(n == null){
			// The value couldn't be found.
			return null;
		} else if(n.getValue() < value){
			// Go to the right.
			return findNodeWithValue(n.getRight(), value);
		} else if(n.getValue() > value) {
			// Go to the left.
			return findNodeWithValue(n.getLeft(), value);
		} else {
			// Found it.
			return n;
		}

	}

	/**
	 * Get the minimum value in the tree (the leftmost node).
	 * @return	The minimum value, or 0 if the tree is empty. Note that 0 is
	 *			indistinguishable from a real 0 stored in the tree.
	 */
	public int getMinimumValue(){

		// Make sure we have at least one node.
		if(getRoot() == null){
			return 0;
		}

		// Iterate until we hit the last left node.
		BinaryNode current = _root;
		while(current.getLeft() != null){
			current = current.getLeft();
		}
		return current.getValue();

	}

	/**
	 * Get the maximum value in the tree (the rightmost node).
	 * @return	The maximum value, or 0 if the tree is empty. Note that 0 is
	 *			indistinguishable from a real 0 stored in the tree.
	 */
	public int getMaximumValue(){

		// Make sure we have at least one node.
		if(getRoot() == null){
			return 0;
		}

		// Iterate until we hit the last right node.
		BinaryNode current = _root;
		while(current.getRight() != null){
			current = current.getRight();
		}
		return current.getValue();

	}

	/**
	 * Get the height of the whole tree.
	 * @return	The number of edges on the longest root-to-leaf path; -1 for an empty tree.
	 */
	public int getHeightOfTree(){
		return getHeightOfNode(getRoot());
	}

	/**
	 * Get the height of a node, computed recursively from its subtree.
	 * @param n	The node (may be null).
	 * @return	0 for a leaf; -1 for null.
	 */
	public int getHeightOfNode(BinaryNode n){

		// Base case: an empty subtree has height -1.
		if(n == null){
			return -1;
		}
		int leftSideDepth = getHeightOfNode(n.getLeft());
		int rightSideDepth = getHeightOfNode(n.getRight());
		return 1 + Math.max(leftSideDepth, rightSideDepth);

	}

	/**
	 * Get the depth of the tree. This is the same value as the tree's height.
	 * @return	The height of the tree; -1 for an empty tree.
	 */
	public int getDepthOfTree(){
		return getHeightOfTree();
	}

	/**
	 * Get the depth (edges from the root) of the first node on the search
	 * path whose value equals s's value.
	 * @param s	The node to look for (must not be null).
	 * @return	The depth of the match. If no node with that value exists,
	 *			this is the depth at which s's value would be inserted.
	 */
	public int getDepthOfNode(BinaryNode s){
		return getDepthOfNode(getRoot(), s);
	}

	/**
	 * Get the depth of s's value relative to the subtree rooted at c.
	 * Nodes are matched by value, not by identity.
	 * @param c	The current subtree root (may be null).
	 * @param s	The node to look for (must not be null).
	 * @return	The number of edges from c to the match, or to the empty slot
	 *			where the value would go if it is absent.
	 */
	public int getDepthOfNode(BinaryNode c, BinaryNode s){

		if(c == null){
			return 0;
		}

		// Search for the s node.
		if(c.getValue() < s.getValue()){
			return 1 + getDepthOfNode(c.getRight(), s);
		} else if(c.getValue() > s.getValue()){
			return 1 + getDepthOfNode(c.getLeft(), s);
		} else { // We found the search node.
			return 0;
		}

	}

	/**
	 * Calculate the balance of the root: left height minus right height.
	 * @return	The balance of the root; 0 for an empty tree.
	 */
	public int getBalanceOfTree(){
		return getBalanceOfNode(getRoot());
	}

	/**
	 * Calculate the balance of a node: left height minus right height.
	 * @param n	The node (may be null).
	 * @return	The balance factor; 0 when n is null (an empty subtree is balanced).
	 */
	public int getBalanceOfNode(BinaryNode n){
		// Previously this threw a NullPointerException for an empty tree.
		if(n == null){
			return 0;
		}
		return getHeightOfNode(n.getLeft()) - getHeightOfNode(n.getRight());
	}

	/**
	 * Return the number of nodes in the tree.
	 * @return	The node count; 0 for an empty tree.
	 */
	public int getSizeOfTree(){
		return getSizeOfNode(_root);
	}

	/**
	 * Return the number of nodes in the subtree rooted at n.
	 * @param n	The subtree root (may be null).
	 * @return	The node count; 0 when n is null.
	 */
	public int getSizeOfNode(BinaryNode n){

		// If the node is null, return 0.
		if(n == null){
			return 0;
		}
		// Count this node plus the sizes of its children.
		return 1 + getSizeOfNode(n.getLeft()) + getSizeOfNode(n.getRight());

	}

	// Getter for the root (null when the tree is empty).
	public BinaryNode getRoot(){
		return _root;
	}

}

BinaryNode.java

public class BinaryNode {

	// Reference to the tree the node belongs to; used to compute heights and balance.
	private BinaryTree _tree;

	// The key value in the node.
	private int _value;

	// Left and right children and parent (null when absent).
	private BinaryNode _left;
	private BinaryNode _right;
	private BinaryNode _parent;

	// Cached balance factor (left height - right height). It is only
	// refreshed when reBalance() is called.
	private int _balance;

	/**
	 * Constructor.
	 * @param tree		The tree the node belongs to.
	 * @param value		The key stored in the node.
	 */
	public BinaryNode(BinaryTree tree, int value){
		_value = value;
		_tree = tree;
	}

	/**
	 * Render the node as its value and cached balance, e.g. "( 5 ) [B=0]".
	 */
	@Override
	public String toString(){
		return "( " + _value + " ) [B=" + getBalance() + "]";
	}

	/**
	 * Check whether the cached balance is -1, 0 or 1.
	 * @return	true if the node is balanced as of the last reBalance().
	 */
	public boolean isBalanced(){
		return Math.abs(getBalance()) < 2;
	}

	/**
	 * Print this subtree in order, one node per line, starting at depth 0.
	 */
	public void printFormatted(){
		printFormatted(0);
	}

	/**
	 * Print this subtree in order, one node per line, indenting each node
	 * with one tab per level of depth.
	 * @param depth	The depth of this node, used for indentation.
	 */
	public void printFormatted(int depth){

		// Build the indent without repeated String concatenation.
		StringBuilder indent = new StringBuilder();
		for(int i = 0; i < depth; i++){
			indent.append('\t');
		}

		// Recurse on left.
		if(getLeft() != null){
			getLeft().printFormatted(depth + 1);
		}

		System.out.println(indent.toString() + this);

		// Recurse on right.
		if(getRight() != null){
			getRight().printFormatted(depth + 1);
		}

	}

	/**
	 * Refresh the cached balance from the current subtree heights.
	 */
	public void reBalance(){
		_balance = getTree().getBalanceOfNode(this);
	}

	/**
	 * Set (or clear, with null) the left child and update parent links.
	 * The replaced child's parent link is left untouched. Callers such as
	 * the rotations rely on this.
	 * @param n	The new left child, or null.
	 */
	public void setLeft(BinaryNode n){
		_left = n;
		adoptChild(n);
	}

	/**
	 * Set (or clear, with null) the right child and update parent links.
	 * The replaced child's parent link is left untouched. Callers such as
	 * the rotations rely on this.
	 * @param n	The new right child, or null.
	 */
	public void setRight(BinaryNode n){
		_right = n;
		adoptChild(n);
	}

	/**
	 * Point a newly attached child back at this node. If that child was
	 * previously this node's parent, break the old upward link so the two
	 * nodes do not reference each other as parents.
	 */
	private void adoptChild(BinaryNode n){
		// A null child has no parent pointer to set.
		if(n != null){
			n.setParent(this);
		}

		// Our former parent is now our child, so we no longer have a parent.
		if(n == getParent()){
			setParent(null);
		}
	}

	public void setParent(BinaryNode n){
		_parent = n;
	}

	// Get the cached balance (see reBalance()).
	public int getBalance(){
		return _balance;
	}

	/**
	 * Return the height of the node, recomputed from its subtree.
	 * @return	0 for a leaf.
	 */
	public int getHeight(){
		return getTree().getHeightOfNode(this);
	}

	// Get the value.
	public int getValue(){
		return _value;
	}

	// Get the tree.
	public BinaryTree getTree(){
		return _tree;
	}

	// Getters for the node references (null when absent).
	public BinaryNode getLeft(){
		return _left;
	}

	public BinaryNode getRight(){
		return _right;
	}

	public BinaryNode getParent(){
		return _parent;
	}

}

Leave a Reply

Your email address will not be published. Required fields are marked *