Wednesday 3 January 2024

Tail Call Optimization

Some days ago a colleague told me that his main complaint about Python was the lack of Tail Call Optimization. Indeed, Guido was very clear about that: in order to avoid messing up stack traces there will never be TCO in Python (well, I don't know if now that he's no longer the BDFL that could change), at least in the CPython runtime, I guess. Though over time I have written several posts about how to adapt your code to prevent stack overflows (by using async/await in JavaScript and Python, or by using trampolines), I'd never paid much attention to this Tail Call Optimization thing, so I suddenly got intrigued.

First, I immediately associate Tail Call Optimization with Tail Recursive code, but this is not necessarily so. A Tail Call happens when the last instruction in a function is a call to another function. Callbacks and Continuations come easily to mind. Normally Callbacks and Continuations are not used massively, so optimizing them won't be particularly important, neither in performance terms (improved memory locality) nor in terms of risking a stack overflow, save for functional style programming fully adhering to CPS (Continuation Passing Style). That's why Tail Call Optimization is mainly important for recursive code (though it can be applied to any general Tail Call).
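To make the distinction concrete, here's a minimal sketch in Python (the function names are mine, just for illustration) contrasting a recursive call that is in tail position with one that is not:

```python
def fact_not_tail(n):
    if n == 0:
        return 1
    # NOT a tail call: after the recursive call returns, the caller
    # still has to do the multiplication, so its frame must be kept
    return n * fact_not_tail(n - 1)

def fact_tail(n, acc=1):
    if n == 0:
        return acc
    # Tail call: the recursive call is the very last thing done,
    # so an optimizing runtime could reuse the current stack frame
    return fact_tail(n - 1, acc * n)
```

Both compute the same factorial, but only the second one is a candidate for TCO (which CPython, as said, will never apply).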

Another important point is distinguishing between direct recursion and mutual or indirect recursion. In direct recursion a function calls itself; in mutual/indirect recursion f1 calls f2, which calls f1 again. Indirect recursion and non-recursive tail calls are more complex to optimize than direct recursive tail calls. From the Wikipedia article:
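The classic textbook example of mutual recursion is the even/odd pair, sketched here in Python:

```python
def is_even(n):
    if n == 0:
        return True
    return is_odd(n - 1)   # tail call into a *different* function

def is_odd(n):
    if n == 0:
        return False
    return is_even(n - 1)  # ...which tail-calls back: mutual recursion
```

Each call is in tail position, but since the callee is a different function, a compiler can't just rewrite it as a loop inside one function body the way it can with direct tail recursion.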

The special case of tail-recursive calls, when a function calls itself, may be more amenable to call elimination than general tail calls. When the language semantics do not explicitly support general tail calls, a compiler can often still optimize sibling calls, or tail calls to functions which take and return the same types as the caller.

However, for language implementations which store function arguments and local variables on a call stack (which is the default implementation for many languages, at least on systems with a hardware stack, such as the x86), implementing generalized tail-call optimization (including mutual tail recursion) presents an issue: if the size of the callee's activation record is different from that of the caller, then additional cleanup or resizing of the stack frame may be required. For these cases, optimizing tail recursion remains trivial, but general tail-call optimization may be harder to implement efficiently.

For example, in the Java virtual machine (JVM), tail-recursive calls can be eliminated (as this reuses the existing call stack), but general tail calls cannot be (as this changes the call stack).[13][14] As a result, functional languages such as Scala that target the JVM can efficiently implement direct tail recursion, but not mutual tail recursion.

Curious about the last statement, I've found this that explains it pretty well.

Have you noticed that the method address starts with 0? That all methods offsets start with 0? JVM doesn't allow one to jump outside a method.

As I mentioned above, Python does not provide any sort of TCO. We can modify our code a bit by writing trampolines ourselves, which will work fine with both tail and non-tail recursion, direct and mutual. We could also use the async/await trick, but I assume it'll be pretty awful in performance terms. And then, several smart people have come up with some general solutions. There is this decorator that throws (and catches) an exception every 2 recursive calls, so the stack never grows (but I guess it'll hit performance pretty hard). Then this module that seems more elaborate, and then this one that modifies a function's bytecode to use trampolines.
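A hand-written trampoline can be as simple as this sketch (assuming the recursive function returns thunks instead of recursing directly, and that its final result is never itself callable):

```python
def trampoline(fn, *args):
    """Call fn, then keep invoking results as long as they are thunks."""
    result = fn(*args)
    while callable(result):
        result = result()
    return result

def fact(n, acc=1):
    if n == 0:
        return acc
    # Return a zero-argument thunk instead of recursing directly,
    # so each step runs in a fresh, one-deep call from the trampoline
    return lambda: fact(n - 1, acc * n)
```

With this, `trampoline(fact, 100000)` completes without a RecursionError, since the stack never grows beyond a couple of frames. The same trampoline handles mutual recursion, because a thunk is free to call any function, not just the current one.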

Kotlin has the tailrec modifier that you can use in a function definition, and that only works for direct tail recursion. In this article you can see how the Java bytecodes generated by the Kotlin compiler for a tailrec function implement the tail recursion by updating values in the current stack frame and performing a jump to the start of the function (so transforming the recursion into a loop), rather than doing a call.
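That recursion-to-loop transformation can be sketched by hand in Python. This is just my own illustration of what a tailrec-style compiler effectively does, using gcd as the example:

```python
def gcd(a, b):
    """Tail-recursive version: the source form a programmer writes."""
    if b == 0:
        return a
    return gcd(b, a % b)

def gcd_loop(a, b):
    """What the compiler effectively emits: reassign the parameters
    (the current frame's values) and jump back to the top, i.e. loop."""
    while b != 0:
        a, b = b, a % b
    return a
```

Both return the same results, but the loop version uses a single, constant-size stack frame no matter the input.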

I've learned from this article that .Net comes with support for tail calls at the bytecode level by means of the tail. instruction. I guess at the CLR bytecode level you are not allowed to write a jump to outside of the current function (same as in the JVM), so this tail. prefix instructs the JIT compiler to compile the ensuing call bytecode into a native jump instruction, rather than creating a new stack frame and doing a native call. It seems the F# compiler uses that instruction for non-recursive tail calls, while transforming recursion into loops for recursive tail calls. The C# compiler does neither.

The above brings up 2 general points for me to reflect on.
First, when providing Tail Call Optimizations in a VM environment, should the "high level language" to bytecodes compiler perform the optimization, or should that be delegated to the bytecodes to native code JIT compiler? Well, for VMs where we use an interpreter and 1 or n JITs for "hot code", it's clear that it should be done by the "language to bytecodes" compiler. If there's no interpreter involved, just a JIT, I'm not sure which of the two compilers should perform the optimization.
Second point, should we instruct the compiler to perform the optimization (as we do in Kotlin with tailrec), or should the compiler decide on its own and perform the optimization when it can and sees fit (as the F# compiler seems to do)? I would say I prefer the former approach.

There's this interesting discussion about the above topics applied to C# and .Net.
