Making Recursive Functions Tail Recursive

First of all, this blog post assumes that you already have a basic understanding of recursion. I’ll use Python for the code examples simply because it’s a popular language for learners, but the concepts I’ll be talking about are language agnostic. In fact, your preferred language might benefit more from tail recursion than Python does.

A look at a recursive factorial function

Let’s first look at an issue with recursion, so that we can later fix it using tail recursion. For this we’ll make use of the classic introduction to recursion, the factorial function:

def factorial(n):
    # Base case: 0! is defined as 1.
    if n == 0:
        return 1

    # Recursive case: n! = n * (n - 1)!
    return n * factorial(n - 1)

Let’s look at how this function works step by step:

factorial(5) => 5 * factorial(4)
             => 5 * (4 * factorial(3))
             => 5 * (4 * (3 * factorial(2)))
             => 5 * (4 * (3 * (2 * factorial(1))))
             => 5 * (4 * (3 * (2 * (1 * factorial(0))))) # base case
             => 5 * (4 * (3 * (2 * (1 * (1)))))
             => 5 * (4 * (3 * (2 * (1))))
             => 5 * (4 * (3 * (2)))
             => 5 * (4 * (6))
             => 5 * (24)
             => 120

This function works correctly, but it has an issue common to many recursive functions. Each time you call a function, your language has to store the arguments passed in the call as well as the function’s local variables somewhere (typically in a stack frame on the call stack), at least until the function has finished. With recursion, a function does a small step and then calls itself over and over again, and each of those calls needs its own memory. This per-call memory overhead therefore becomes much more prominent than in an equivalent iterative function and can become a problem.

For example, the line 5 * (4 * (3 * (2 * (1 * (1))))) has 5 sets of parentheses, meaning that it has to store the data for 5 different pending function calls.

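Additionally, the memory available for this function call overhead (the call stack) is relatively limited, and once you exceed it your program crashes with a stack overflow. Python guards against this with a recursion limit and raises a RecursionError instead.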
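To see this limit in practice, here is a small sketch for CPython. CPython enforces a configurable recursion limit, which you can query with `sys.getrecursionlimit()` (commonly 1000 by default), and raises a `RecursionError` once a call chain goes deeper than that. The argument `10 * limit` is just an arbitrary value chosen to be far beyond the limit.

```python
import sys


def factorial(n):
    # Naive recursive factorial: each multiplication waits for the recursive call.
    if n == 0:
        return 1
    return n * factorial(n - 1)


limit = sys.getrecursionlimit()
print("Recursion limit:", limit)

try:
    # Far deeper than the limit allows, so CPython gives up with RecursionError.
    factorial(10 * limit)
    crashed = False
except RecursionError as error:
    crashed = True
    print("RecursionError:", error)
```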

Now, admittedly, this is just a toy function and will likely never be called with an argument large enough to cause an issue due to function call overhead, but other kinds of recursive functions, or extensive use of recursion across a project, could very well lead to such an issue.

Rewriting it as a tail recursive function

Now that we’ve seen that the above recursive function could create a memory issue, let’s look at what a tail recursive version would look like and why it fixes this issue. I’ll jump straight to showing the tail recursive function. If the code snippet seems confusing, don’t worry: we’ll talk through it afterwards and see how it fixes the memory overhead issue.

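def factorial(n):
    def impl(n, accumulator):
        # Base case: the accumulator holds the finished product.
        if n == 0:
            return accumulator

        # Tail call: multiply first, then recurse with the updated accumulator.
        return impl(n - 1, n * accumulator)

    # Start with the multiplicative identity 1 as the initial accumulator.
    return impl(n, 1)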

First, a little style explanation: to make the factorial function tail recursive we need an extra function argument, but I didn’t want to make the interface more complicated. So I left the factorial function unchanged on the outside, still taking a single argument, and added a local impl function that takes the extra argument and does the actual work. The factorial function is nothing more than a simple wrapper that calls impl with the initial arguments.

Using one or multiple extra arguments as “accumulators” is a common strategy to convert functions into tail recursive form.
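As a further illustration of the accumulator strategy, here is a sketch applying the same pattern to a different function: summing a list. `sum_list` is a hypothetical example, not part of the original factorial discussion; passing an index instead of slicing the list keeps each step cheap.

```python
def sum_list(values):
    """Sum a list of numbers using a tail recursive helper with an accumulator."""
    def impl(index, accumulator):
        # Base case: every element has already been added to the accumulator.
        if index == len(values):
            return accumulator

        # Tail call: the addition happens in the arguments, before the call.
        return impl(index + 1, accumulator + values[index])

    # Start at the first element with the additive identity 0.
    return impl(0, 0)


print(sum_list([1, 2, 3, 4]))  # 10
```

Just like factorial, the wrapper keeps the simple one-argument interface while the helper carries the extra state.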

Now, before we talk in detail about what it means to be tail recursive and how the above implementation fixes the memory issue, let’s step through the function again to check that it still works:

factorial(5) => impl(5, 1)
impl(5, 1)   => impl(5 - 1, 5 * 1)
impl(4, 5)   => impl(4 - 1, 4 * 5)
impl(3, 20)  => impl(3 - 1, 3 * 20)
impl(2, 60)  => impl(2 - 1, 2 * 60)
impl(1, 120) => impl(1 - 1, 1 * 120)
impl(0, 120) => 120

I’ve written this step-by-step trace in a bit more detail, showing both the calculations inside each function call and what the resulting call actually looks like.
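If you’d like to reproduce this trace yourself, a sketch like the following records the arguments of every impl call. The `calls` list and the name `traced_factorial` are hypothetical instrumentation added purely for illustration; they aren’t part of the factorial itself.

```python
def traced_factorial(n):
    calls = []  # Records the (n, accumulator) pair of every impl call.

    def impl(n, accumulator):
        calls.append((n, accumulator))
        if n == 0:
            return accumulator
        return impl(n - 1, n * accumulator)

    return impl(n, 1), calls


result, calls = traced_factorial(5)
for n, accumulator in calls:
    print(f"impl({n}, {accumulator})")
print(result)
```

The printed calls match the left-hand column of the trace, ending with impl(0, 120) and the result 120.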

You can probably already see that the lines no longer contain nested parentheses, which earlier indicated the memory of multiple function calls we had to store. But why? After all, we’re still calling the function recursively over and over again. The important difference lies in the recursive calls. Compare:

return n * factorial(n - 1)

and:

return impl(n - 1, n * accumulator)

The names “tail recursion” and “tail call” come from the fact that the function call is the last thing happening in the function: it’s the tail of the function. In the second version we just make the recursive call and return whatever it evaluates to, so it is indeed a tail call. In the first version we also make a recursive call, BUT we still have to multiply its result by n before returning.

Since the function call is the last thing that happens in the second version, the language can make use of a trick called “tail call optimization”. Instead of allocating new memory for the arguments of the recursive call, it simply reuses and overwrites the memory of the current call. Overwriting that memory is only safe because the recursive call is the last thing happening, so the old call doesn’t need its memory anymore.

In the first version we can’t reuse the memory of the old call, because we still need to keep track of its n. After all, we still have to multiply that n by the result of the recursive call.

This means that, in a language that performs this optimization, our tail recursive function only needs memory for the first call to impl; all subsequent recursive calls simply reuse this memory.
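Python won’t perform this reuse for us, but we can do it by hand to see what tail call optimization effectively achieves. The sketch below rewrites impl as a loop: instead of making a new call, each iteration overwrites n and accumulator in place with exactly the values the tail call would have passed. This is a manual illustration of the idea, not a description of how any particular compiler implements it.

```python
def factorial(n):
    accumulator = 1
    # Each iteration corresponds to one tail call impl(n - 1, n * accumulator);
    # instead of creating a new frame, the same two variables are overwritten.
    while n != 0:
        n, accumulator = n - 1, n * accumulator
    return accumulator


print(factorial(5))  # 120

# Unlike the recursive versions, a large input is no problem in Python,
# because no new frames are created per step.
big = factorial(5000)
```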

Tail call optimization across languages

Now it’s important to discuss how tail call optimization works across programming languages. As the name suggests, it’s an optimization, and as such not every language performs it. Broadly speaking, there are 3 different approaches to tail call optimization:

  • Disallow it! This is the approach Python takes, for example. As I mentioned at the start, I’m using Python for the snippets only because it’s a popular language to learn; Python itself doesn’t benefit from tail calls. But why would you disallow an optimization? The Python developers argue that tail calls reusing the memory of the calling function make the call stack harder to understand and debugging harder, and they prefer debugging clarity over the benefits of the optimization.

  • Guarantee it! This is the approach of several functional languages, for example Scheme or the languages of the ML family. Recursion is a big part of the functional programming paradigm, and many functional languages discourage or don’t even provide iterative constructs like while or for loops. For these languages, guaranteeing this optimization is essential, as it prevents crashes from deeply nested recursive calls and also brings general speed benefits.

  • Purely another optimization! This is the approach of many languages with extensive compilation steps, for example C, C++ and Rust. When compiling in debug mode, the tail call probably won’t be optimized, but with optimizations turned on the compiler tries its best to optimize your code, and tail call optimization is a relatively easy optimization for it to perform. It isn’t guaranteed, and you may run into a function that wasn’t tail call optimized and causes problems, but for simple functions this is unlikely, since modern optimizers are very good, especially at such a straightforward optimization. In fact, some compilers will automatically convert the first version of our function into the second, tail recursive one if the function is simple enough (like a toy factorial function).
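Because CPython falls into the first category, making a function tail recursive doesn’t protect it from the recursion limit there. A small sketch to check this, using the tail recursive factorial from earlier and an argument deliberately chosen far beyond `sys.getrecursionlimit()`:

```python
import sys


def factorial(n):
    def impl(n, accumulator):
        if n == 0:
            return accumulator
        # This is a tail call, but CPython still creates a new frame for it.
        return impl(n - 1, n * accumulator)

    return impl(n, 1)


try:
    factorial(10 * sys.getrecursionlimit())
    crashed = False
except RecursionError:
    crashed = True

print("Crashed despite tail recursion:", crashed)
```

In a language that guarantees tail call optimization, the equivalent function would run in constant stack space instead.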

Conclusion

I hope I’ve been able to explain tail recursion and tail call optimization in an understandable way. Tail call optimization is just one benefit of tail recursion, though arguably a big one, and I’ll leave it at that since this post is already long enough.

If you want another example to try on your own: the classic Fibonacci function, often used as an introduction to recursion, makes 2 recursive calls and has exponential time complexity. With 2 extra arguments you can write a naturally tail recursive version with linear complexity.
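If you want to check your attempt afterwards, here is one possible sketch. `naive_fibonacci` makes the two recursive calls described above, while `fibonacci` threads two consecutive Fibonacci numbers through its helper as the two extra arguments; the names `current` and `following` are just illustrative. Keep in mind that in Python even the tail recursive version is still bound by the recursion limit.

```python
def naive_fibonacci(n):
    # Two recursive calls per step: exponential running time.
    if n < 2:
        return n
    return naive_fibonacci(n - 1) + naive_fibonacci(n - 2)


def fibonacci(n):
    def impl(n, current, following):
        # current and following are two consecutive Fibonacci numbers.
        if n == 0:
            return current
        # Tail call: shift the pair one step along the sequence.
        return impl(n - 1, following, current + following)

    # Start with F(0) = 0 and F(1) = 1.
    return impl(n, 0, 1)


print([fibonacci(i) for i in range(10)])  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
```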