Wednesday 2 October 2024

Using Suspend Functions to Prevent Stack Overflows in Kotlin

This post is the Kotlin counterpart to these 2 previous JavaScript and Python posts about avoiding stack overflows by means of suspendable functions (async/await in JavaScript and Python). This technique does not have much real practical purpose, as the standard mechanism to prevent stack overflows is using trampolines, so consider this post a brain exercise that has helped me better understand how coroutines and suspend functions work (just as the aforementioned posts helped me better understand async/await).
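
For comparison, here is a minimal sketch of the trampoline approach applied to the same kind of wrapText function used in this post. Trampoline, runTrampoline and wrapTextTrampolined are hypothetical names of my own, not a library API: each step returns either the final result or a thunk for the next step, and a plain loop drives the thunks, so the stack depth stays constant whatever the number of steps.

```kotlin
// A minimal trampoline (hypothetical helper types, not a library API).
// Each step either carries the final result (Done) or a thunk that computes
// the next step (More); a plain loop drives the steps, so the call stack
// never grows with the recursion depth.
sealed class Trampoline<out T> {
    class Done<out T>(val value: T) : Trampoline<T>()
    class More<out T>(val next: () -> Trampoline<T>) : Trampoline<T>()
}

// Runs the steps iteratively until a final result is produced.
fun <T> Trampoline<T>.runTrampoline(): T {
    var current: Trampoline<T> = this
    while (true) {
        when (val step = current) {
            is Trampoline.Done -> return step.value
            is Trampoline.More -> current = step.next()
        }
    }
}

// wrapText rewritten in trampolined style: instead of calling itself,
// it returns a thunk describing the next call.
fun wrapTextTrampolined(txt: String, times: Int): Trampoline<String> =
    if (times <= 0) Trampoline.Done(txt)
    else Trampoline.More { wrapTextTrampolined("[$txt]", times - 1) }
```

With this, wrapTextTrampolined("b", 10000).runTrampoline() builds the whole wrapped string without recursing on the JVM stack. The price is that the function has to be rewritten in this "return the next step" style.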

When a Kotlin suspend function gets suspended at a suspension point, it returns to its calling function (which in turn returns to its own caller, and so on), so the corresponding stack frame is released. The function will be resumed later, at the statement following that suspension point, by "something" calling the resume method of the continuation object corresponding to that function. If we are not piling up frames on the stack, there's no risk of a stack overflow. So if our recursive function calls a trivial suspend function before the next recursive call, we should avoid the stack overflow. The main question here is which "trivial" suspend function we can use for this.
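
The mechanism can be reproduced with nothing but the standard library's kotlin.coroutines primitives. The sketch below is my own toy, not how kotlinx.coroutines is implemented, and ToyEventLoop, yieldToLoop, runToCompletion and wrapTextToy are hypothetical names. yieldToLoop really suspends by parking its continuation in a queue, so the whole chain of callers returns and releases its frames; the loop later calls resume on that continuation, which continues the recursion on a fresh, shallow stack.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.resume
import kotlin.coroutines.startCoroutine
import kotlin.coroutines.suspendCoroutine

// A toy single-threaded "event loop" (hypothetical, not kotlinx.coroutines):
// just a queue of callbacks that is drained until the coroutine completes.
class ToyEventLoop {
    private val queue = ArrayDeque<() -> Unit>()

    // Really suspends: the continuation is parked in the queue, so the caller
    // (and its callers) return COROUTINE_SUSPENDED and their frames are released.
    suspend fun yieldToLoop(): Unit = suspendCoroutine { cont ->
        queue.addLast { cont.resume(Unit) }
    }

    // Starts the coroutine, then keeps resuming parked continuations until none are left.
    fun <T> runToCompletion(block: suspend ToyEventLoop.() -> T): T {
        var result: Result<T>? = null
        block.startCoroutine(this, Continuation(EmptyCoroutineContext) { result = it })
        while (queue.isNotEmpty()) queue.removeFirst().invoke()
        return checkNotNull(result) { "coroutine did not complete" }.getOrThrow()
    }
}

// The recursion used in this post, with a real suspension before each recursive call.
suspend fun ToyEventLoop.wrapTextToy(txt: String, times: Int): String =
    if (times <= 0) txt
    else {
        yieldToLoop()
        wrapTextToy("[$txt]", times - 1)
    }
```

For example, ToyEventLoop().runToCompletion { wrapTextToy("b", 10000) } completes without a stack overflow. As far as I understand the stdlib internals, the final unwinding is safe too: when the innermost call returns, the continuation machinery hands each result to the caller's continuation in a loop rather than through nested calls.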

Let's start by showing a recursive function that causes a stack overflow:


fun wrapText1(txt: String, times: Int): String = 
    if (times <= 0) txt
    else wrapText1("[$txt]", times - 1)
	
try {
	println(wrapText1("b", 10000))
}
catch (ex: StackOverflowError) {
	println("StackOverflow!!!")
}

// StackOverflow!!!


First of all, that case could be fixed by leveraging the Kotlin compiler's support for tail call optimization. That function is tail recursive, so we can just mark it with the tailrec modifier and the compiler will transform the recursion into a loop: no more overflows.


tailrec fun wrapText1(txt: String, times: Int): String = 
    if (times <= 0) txt
    else wrapText1("[$txt]", times - 1)

That said, let's forget about the power of tailrec. It only works for tail-recursive functions, while the idea of leveraging suspend functions to prevent stack overflows works the same for tail and non-tail recursion. Notice that I'll run my recursive functions in a coroutine created with the runBlocking builder, which uses an event loop dispatcher, so everything runs in a single thread, as in JavaScript and Python asyncio.
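
It is worth knowing that the standard library already packages this principle in a ready-made form for non-tail recursion: DeepRecursiveFunction (stable since Kotlin 1.7) runs its body as a suspend lambda and keeps the chain of pending calls on the heap. This is not the approach followed in this post, just a sketch of the same principle at work. wrapTextDeep is a hypothetical non-tail variant of wrapText1 that tailrec could not optimize, because the result of the recursive call still has to be wrapped in brackets.

```kotlin
// Non-tail recursion: the result of the recursive call is still used afterwards
// (it gets wrapped in brackets), so tailrec could not turn this into a loop.
// DeepRecursiveFunction runs the body as a suspend lambda and keeps pending
// calls on the heap instead of the JVM call stack (requires Kotlin 1.7+).
val wrapTextDeep = DeepRecursiveFunction<Pair<String, Int>, String> { (txt, times) ->
    if (times <= 0) txt
    else "[" + callRecursive(txt to times - 1) + "]"
}
```

Calling wrapTextDeep("b" to 10000) returns the same string as the tail-recursive version, even though a plain recursive implementation of this non-tail variant would need one stack frame per level.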

Let's first try using a suspend function that really does nothing: it does not invoke any other "real" suspend function. As you can see in my post from 2021, in JavaScript awaiting a resolved Promise (or just a normal object) is enough to avoid the stack overflow, as it causes the function to be suspended, returning and releasing its stack frame. The "then" part to run after the await is not executed immediately, but put on the microtask queue, meaning that the function returns to the event loop (which will then execute the "then handler" taken from the microtask queue). Asynchrony in Kotlin is implemented in a quite different way (continuations...), and I was fairly convinced that this "do nothing" function would not cause a suspension, but I had to confirm it.


import kotlinx.coroutines.runBlocking

suspend fun doNothingSuspend(): String {
    return "a"
}

// As expected, this one DOES NOT prevent stack overflows. Calling a function that is marked as suspend
// but never really suspends is useless for this: there won't be any suspension, as in Python asyncio
suspend fun wrapTextAsync1(txt: String, times: Int): String {
    return if (times <= 0) txt
    else {
        doNothingSuspend()
        wrapTextAsync1("[$txt]", times - 1)
    }
}

try {
	// wrapTextAsync1 is a suspend function, so it has to be called from a coroutine
	runBlocking {
		println(wrapTextAsync1("b", 10000))
	}
}
catch (ex: StackOverflowError) {
	println("StackOverflow!!!")
}

//StackOverflow!!!


Yes, we get an ugly stack overflow. Calling that suspend function does not cause a suspension because the code the compiler generates for a suspend function only returns the special COROUTINE_SUSPENDED value when something really suspends, and doNothingSuspend just returns "a". COROUTINE_SUSPENDED is what the calling function checks to decide between returning to its own caller (and so on down to the event loop) or continuing with the next statement, which is what happens in this case.

This is a bit related to what happens in Python, where marking a function that does not really do anything asynchronous as async (a sort of "fake async function") will not really suspend the calling function. Because of how coroutines are implemented in Python (similar to generators), awaiting such a coroutine just runs it to completion: it never yields anything up the chain of awaits to the Task that drives it, so the calling coroutine continues sequentially, not interrupted by a suspension, without having returned control to the event loop.
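
The difference between a function that is merely marked suspend and one that really suspends can be observed directly with the low-level stdlib intrinsics. In the sketch below, firstReturn, suspends, reallySuspend and parked are hypothetical helpers of mine; startCoroutineUninterceptedOrReturn runs a suspend lambda once and hands back whatever its compiled code returned: the actual value when nothing suspended, or the COROUTINE_SUSPENDED marker when something did.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.startCoroutineUninterceptedOrReturn
import kotlin.coroutines.suspendCoroutine

// Marked as suspend, but never suspends: its compiled form just returns "a".
suspend fun doNothingSuspend(): String {
    return "a"
}

// Really suspends: the continuation is stored and (in this probe) never resumed.
var parked: Continuation<String>? = null
suspend fun reallySuspend(): String = suspendCoroutine { cont -> parked = cont }

// Runs a suspend lambda once and returns whatever its compiled code returned:
// the value itself, or the COROUTINE_SUSPENDED marker if something suspended.
fun firstReturn(block: suspend () -> String): Any? =
    block.startCoroutineUninterceptedOrReturn(Continuation(EmptyCoroutineContext) { })

// True when calling the lambda led to a real suspension.
fun suspends(block: suspend () -> String): Boolean =
    firstReturn(block) === COROUTINE_SUSPENDED
```

Here firstReturn { doNothingSuspend() } gives back "a" and suspends { doNothingSuspend() } is false, while suspends { reallySuspend() } is true. Only the second kind of call makes the frames above it return, which is what the recursive function needs.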

So if we want a real suspension in our recursive function, returning from it and thus avoiding a stack overflow, we have to invoke a function that really suspends. The immediate option that comes to mind is using a sort of sleep (delay, in the Kotlin coroutines universe) with a minimal value. I say "minimal value" because passing 0 to kotlinx.coroutines.delay returns immediately without causing any suspension (it doesn't return COROUTINE_SUSPENDED).


import kotlinx.coroutines.delay

suspend fun wrapTextAsync2(txt: String, times: Int): String {
    return if (times <= 0) txt
    else {
        // using delay(0) is NOT an option, as it returns immediately and no suspension happens
        delay(1)
        wrapTextAsync2("[$txt]", times - 1)
    }
}


That works fine. delay(1) suspends the function, returning all the way down to the event loop, and once the delay has elapsed the corresponding continuation resumes the recursive function at its next statement (the next call to wrapTextAsync2) with a clean stack. But obviously there's an important performance penalty due to the delay: with times = 10000, the function spends at least 10 seconds just waiting.

If you've read this previous post, you'll also be familiar with kotlinx.coroutines.yield. Reading the implementation details, I was a bit doubtful as to whether running my code in an event loop created by runBlocking, with no other coroutines running concurrently (so no items waiting in the event loop queue), would be enough for yield() to cause a suspension. But if you invoke the function below with however large a times value, you'll see that no stack overflow happens. So yes, yield() prevents a stack overflow.


import kotlinx.coroutines.yield

suspend fun wrapTextAsync3(txt: String, times: Int): String {
    return if (times <= 0) txt
    else {
        yield()
        wrapTextAsync3("[$txt]", times - 1)
    }
}

