
Kth Smallest Element Of BST Problem

Given a binary search tree (BST) and an integer k, find the k-th smallest element.

Example

Input:

BST:


    2
   / \
  1   3

 

k=3

Output: 3

The 3rd smallest element is 3.

Notes

Input Format: There are two arguments in the input. The first is the root of the BST and the second is an integer k.

Output: Return an integer, the k-th smallest element of the BST.

Constraints:

● 1

● 1

● -2 * 10^9

● You are not allowed to alter the given BST in any way.


Solution

We want to find the k-th smallest element of the given BST. If we list all the elements of the BST in sorted order, the answer is the k-th element of that list. An in-order traversal visits the nodes of a BST in sorted order, so a single traversal produces the list. However, this solution always takes O(N) time and, because it stores every element, O(N) auxiliary space.
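The store-everything approach described above can be sketched as follows. This is a minimal illustration, not the reference solution: the class name KthSmallestNaive and the Node type are assumptions made up for this example, and the code assumes 1 <= k <= number of nodes.

```java
import java.util.ArrayList;
import java.util.List;

class KthSmallestNaive {
    // Hypothetical minimal node type used only for this sketch.
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    // Appends the values of the subtree rooted at node to out, in sorted order.
    static void inorder(Node node, List<Integer> out) {
        if (node == null) return;
        inorder(node.left, out);
        out.add(node.val);
        inorder(node.right, out);
    }

    // Collects all N values, then picks the k-th one: O(N) time and O(N) extra space.
    // Assumes 1 <= k <= number of nodes.
    static int kthSmallest(Node root, int k) {
        List<Integer> sorted = new ArrayList<>();
        inorder(root, sorted);
        return sorted.get(k - 1);
    }

    public static void main(String[] args) {
        // The tree from the example: 2 at the root, 1 on the left, 3 on the right.
        Node root = new Node(2);
        root.left = new Node(1);
        root.right = new Node(3);
        System.out.println(kthSmallest(root, 3)); // prints 3
    }
}
```

The list is the only reason this version needs O(N) auxiliary space, even when k is small.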

Note that we don’t need to store all the elements; it is enough to keep a count of the nodes visited so far. When the counter reaches k, the current node is the one we want!

Your code should look like:

void modified_inorder(root, k)
{
    handle base case;
    modified_inorder(root->l, k);
    if (answer is not found in left subtree)
    {
        // Make sure that you increment only
        // after the left subtree is visited.
        counter++;
        consider current node;
        modified_inorder(root->r, k);
    }
}

Time complexity:

In terms of the number of tree nodes N, it is O(N). Using other variables we can state a tighter bound for this solution: in terms of the height of the tree h and k, it is O(h + k).

The algorithm first traverses down to the leftmost node which takes O(h) time, then traverses k elements in O(k) time. Therefore overall time complexity is O(h + k).

Note that even if k = 1, the algorithm has to go all the way down the left side of the tree to reach the smallest element, passing through every node on the way, and each node costs constant time. So far we have used O(h) time, where h is the height of the tree (the worst case is when the leftmost path is also the longest one).

Having gone all the way down to the smallest element, the algorithm then visits exactly k nodes from there (still constant time per node); complexity so far is O(h) + O(k).

Having found and saved the k-th element, the algorithm still needs to unwind the recursive calls before it can return the answer. That costs constant time per level of recursion, so at most the depth of the tree (the worst case, again, is when the k-th element sits in the leaf of the longest branch). That takes another O(h) time. Therefore the overall time complexity is O(h) + O(k) + O(h) = O(2h + k) = O(h + k).
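The O(h + k) bound may be easier to see in an iterative form, where the descent and the k visits are explicit. The sketch below is a swapped-in variant with an explicit stack, not the recursive solution this article uses; the class name KthSmallestIterative and the Node type are assumptions made up for this example.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class KthSmallestIterative {
    // Hypothetical minimal node type used only for this sketch.
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    // In-order traversal with an explicit stack that stops at the k-th visited node.
    // Assumes k >= 1; throws if the tree has fewer than k nodes.
    static int kthSmallest(Node root, int k) {
        Deque<Node> stack = new ArrayDeque<>(); // never holds more than h nodes
        Node cur = root;
        while (cur != null || !stack.isEmpty()) {
            // Descend as far left as possible, remembering the path.
            while (cur != null) {
                stack.push(cur);
                cur = cur.left;
            }
            // The top of the stack is the next node in sorted order.
            cur = stack.pop();
            if (--k == 0) {
                return cur.val;
            }
            cur = cur.right;
        }
        throw new IllegalArgumentException("k exceeds the number of nodes");
    }

    public static void main(String[] args) {
        // The tree from the example: 2 at the root, 1 on the left, 3 on the right.
        Node root = new Node(2);
        root.left = new Node(1);
        root.right = new Node(3);
        System.out.println(kthSmallest(root, 3)); // prints 3
    }
}
```

Because this version returns as soon as the k-th node is popped, there is no recursion to unwind; the stack plays the role of the O(h) call frames.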

Auxiliary Space Used:

O(h) due to the stack frames for the recursive calls.

Space Complexity:

O(N) due to input size.


 // -------- START --------

    // kth smallest element is stored in this variable. 
    static int kth_element;                                             
    /*
    Number of nodes visited so far, in in-order. It must be 0 at the start of every query;
    when more than one test case runs in the same process, reset it before each one.
    */
    static int counter = 0;

    static void get_k_th_element(TreeNode root, int k)
    {
        // This function uses the idea of inorder_traversal. 
        // either root is null or we have already found the answer.             
        if (root == null || counter >= k)                               
        {
            return;
        }
        /*
        First search the left subtree, because the elements in the left subtree are smaller
        than the root.
        */
        get_k_th_element(root.left_ptr, k);                             
        // if we have not found the answer till now.        
        if (counter < k)                                                
        {
            counter++;
            // if current node is the kth node.
            if (counter == k)                                           
            {
                kth_element = root.val;
                return;
            }
            // We have explored the left subtree and the root; now explore the right subtree.
            get_k_th_element(root.right_ptr, k);                        
        }
    }

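    static int kth_smallest_element(TreeNode root, int k)
    {
        // Reset the shared counter so that repeated calls (multiple test cases) start from zero.
        counter = 0;
        // find kth smallest element
        get_k_th_element(root, k);
        return kth_element;
    }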

    // -------- END --------