Thursday 10 October 2024

Partial Function Application in Python

I won't explain here the difference between partial function application and currying; there are millions of articles, posts and discussions about it everywhere. As Wikipedia explains well, partial application refers to the process of fixing a number of arguments of a function, producing another function of smaller arity. JavaScript comes with function.bind, and Python comes with functools.partial. I rarely use it, though. Most times I just need to bind one parameter of a function that expects few parameters, and for that I tend to just use a lambda. I mean:


from functools import partial

def wrap_msg(wrapper: str, msg: str) -> str:
    return f"{wrapper}{msg}{wrapper}"

wrap1 = partial(wrap_msg, "||")
print(wrap1("aaa"))

wrap2 = lambda msg: wrap_msg("||", msg)
print(wrap2("aaa"))


In the above case it's a matter of taste, but for other use cases there are clear differences. First, let's see exactly what functools.partial does:

Return a new partial object which when called will behave like func called with the positional arguments args and keyword arguments keywords. If more arguments are supplied to the call, they are appended to args. If additional keyword arguments are supplied, they extend and override keywords. Roughly equivalent to:
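For reference, here is a sketch along the lines of the pure-Python equivalent that the documentation gives at this point. I've renamed it my_partial (my own name, not part of functools) so that it does not shadow the real thing, and the real partial is an object rather than a plain function:

```python
def my_partial(func, /, *args, **keywords):
    # "my_partial" is a hypothetical name used only for this sketch
    def newfunc(*fargs, **fkeywords):
        # keywords given at call time extend and override the bound ones
        newkeywords = {**keywords, **fkeywords}
        # positional arguments given at call time are appended to the bound ones
        return func(*args, *fargs, **newkeywords)
    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc


def wrap_msg(wrapper: str, msg: str) -> str:
    return f"{wrapper}{msg}{wrapper}"


wrap = my_partial(wrap_msg, "||")
print(wrap("aaa"))  # ||aaa||
```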

So if we want to bind a parameter that is not the first one, we have to bind it as a named parameter, and then we have to invoke the resulting function passing all the parameters located after it as named ones. That constraint probably makes a lambda the more convenient choice. On the other hand, given the "they extend and override keywords" behaviour, binding a parameter by name gives us a very nice feature: we can use it to turn parameters into default parameters. We use partial to bind those parameters for which we've realised we want a default value, and we can still invoke the resulting function providing those values if we want. I mean:


from functools import partial
from typing import Iterable

cities = ["Xixon", "Paris", "Lisbon"]
def join_and_wrap(items: Iterable[str], jn: str, wr: str) -> str:
    return f"{wr}{jn.join(items)}{wr}"

join_and_wrap_with_defaults = partial(join_and_wrap, jn=".", wr="||")

# make use of both defaults
print(join_and_wrap_with_defaults(cities))

# "overwrite" one of the defaults
print(join_and_wrap_with_defaults(cities, jn=";"))
print(join_and_wrap_with_defaults(cities, wr="--"))
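The constraint mentioned earlier (once a middle parameter is bound by name, the ones after it must be passed by name too) can be seen in a small sketch. The names join_with_dot and join_with_dot_lambda are mine, just for illustration:

```python
from functools import partial


def join_and_wrap(items, jn, wr):
    return f"{wr}{jn.join(items)}{wr}"


cities = ["Xixon", "Paris", "Lisbon"]

# bind the middle parameter by name
join_with_dot = partial(join_and_wrap, jn=".")

# fine: the parameter located after "jn" is passed by name
by_keyword = join_with_dot(cities, wr="||")
print(by_keyword)  # ||Xixon.Paris.Lisbon||

# not fine: "||" would land in the "jn" position, which is already bound
try:
    join_with_dot(cities, "||")
    positional_failed = False
except TypeError:
    positional_failed = True
print(positional_failed)  # True

# the lambda version has no such restriction
join_with_dot_lambda = lambda items, wr: join_and_wrap(items, ".", wr)
print(join_with_dot_lambda(cities, "||"))  # ||Xixon.Paris.Lisbon||
```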

I was checking what others had to say about all this, and this post is an interesting one. It mentions how functools.partial is a convenient way to avoid the "closure in a loop" issue. As you can see in the code below, it's less verbose than the strategy of creating a wrapper function that works as a scope. Notice that with partial we are not actually creating a closure: it returns a callable object that is not a function and that stores the bound arguments as attributes of the object, not in the cells of a __closure__.


greetings = ["Bonjour", "Bon día", "Hola", "Hi"]

# remember the "closure in a loop" issue: the loop variable has no per-iteration scope

print("- 'Closure in a loop' issue")
greet_fns = []
for greeting in greetings:
    def say_hi(name):
        return f"{greeting} {name}"
    greet_fns.append(say_hi)

for fn in greet_fns:
    print(fn("Francois"))

#Hi Francois
#Hi Francois
#Hi Francois
#Hi Francois

print("------------------------------")

print("- Fixed 1")
greet_fns = []
def create_scope(greeting):
    def say_hi(name):
        return f"{greeting} {name}"
    return say_hi
for greeting in greetings:
    greet_fns.append(create_scope(greeting))

for fn in greet_fns:
    print(fn("Francois"))

#Bonjour Francois
#Bon día Francois
#Hola Francois
#Hi Francois

print("------------------------------")

print("- Fixed 2")
greet_fns = []
def say_hi(greeting, name):
    return f"{greeting} {name}"
for greeting in greetings:
    greet_fns.append(partial(say_hi, greeting))

for fn in greet_fns:
    print(fn("Francois"))

#Bonjour Francois
#Bon día Francois
#Hola Francois
#Hi Francois
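To make the "attributes, not closure cells" point a bit more tangible, here is a small sketch that inspects both kinds of callable. The variable names bound and closure are mine:

```python
from functools import partial


def say_hi(greeting, name):
    return f"{greeting} {name}"


def create_scope(greeting):
    def inner(name):
        return f"{greeting} {name}"
    return inner


bound = partial(say_hi, "Hola")
closure = create_scope("Hola")

# the partial object keeps the wrapped function and the bound arguments as attributes
print(bound.func is say_hi)  # True
print(bound.args)            # ('Hola',)
print(bound.keywords)        # {}

# the closure traps the value in a cell of its __closure__ instead
print(closure.__closure__[0].cell_contents)  # Hola

# I would not expect the partial object to expose closure cells at all
print(getattr(bound, "__closure__", None))
```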

Reading about this stuff I came across the better_partial module, which implements an interesting approach to partial application that feels like something in between "classic" partial application and currying. It creates partial functions that create other partial functions: you invoke them with real values and placeholders until you have provided all the parameters. I'm sure this explanation is rather unclear, so I'd better copy-paste an example from the module's GitHub page.



from better_partial import partial, _
# Note that "_" can be imported under a different name if it clashes with your conventions

@partial
def f(a, b, c, d, e):
  return a, b, c, d, e

f(1, 2, 3, 4, 5)
f(_, 2, _, 4, _)(1, 3, 5)
f(1, _, 3, ...)(2, 4, 5)
f(..., a=1, e=5)(_, 3, _)(2, 4)
f(..., e=5)(..., d=4)(1, 2, 3)
f(1, ..., e=5)(2, ..., d=4)(3)
f(..., e=5)(..., d=4)(..., c=3)(..., b=2)(..., a=1)
f(_, _, _, _, _)(1, 2, 3, 4, 5)
f(...)(1, 2, 3, 4, 5)
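To get a feeling for the placeholder idea, here is a toy sketch of my own. It is NOT how better_partial is implemented and it only covers the positional "_" case (no "..." and no keyword arguments). The names with_placeholders and _Hole are hypothetical, invented for this sketch:

```python
import inspect


class _Hole:
    """Marker type for an argument that has not been supplied yet."""


_ = _Hole()  # hypothetical placeholder object for this sketch


def with_placeholders(func):
    arity = len(inspect.signature(func).parameters)

    def make(slots):
        def call(*args):
            # positions still waiting for a real value
            open_positions = [i for i, v in enumerate(slots) if v is _]
            if len(args) != len(open_positions):
                raise TypeError(
                    f"expected {len(open_positions)} arguments, got {len(args)}"
                )
            filled = list(slots)
            for pos, value in zip(open_positions, args):
                filled[pos] = value
            # still some holes left: return another partial function
            if any(v is _ for v in filled):
                return make(tuple(filled))
            return func(*filled)
        return call

    return make((_,) * arity)


@with_placeholders
def f(a, b, c, d, e):
    return a, b, c, d, e


all_at_once = f(1, 2, 3, 4, 5)
in_two_steps = f(_, 2, _, 4, _)(1, 3, 5)
in_three_steps = f(_, _, 3, _, _)(1, _, 4, _)(2, 5)
print(all_at_once, in_two_steps, in_three_steps)
```

Each call fills the holes that are still open, in order, and you get the real result only once no holes are left.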


Wednesday 2 October 2024

Using Suspend Functions to Prevent Stack Overflows in Kotlin

This post is the Kotlin counterpart to these 2 previous JavaScript and Python posts: avoiding stack overflows by means of suspendable functions (or async/await in JavaScript and Python). This technique does not have much real practical use, as the standard mechanism to prevent stack overflows is using trampolines, so consider this post a brain exercise that has helped me better understand how coroutines and suspendable functions work (just as the aforementioned posts helped me better understand async/await).
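In case the trampoline idea is not fresh in your mind, here is a minimal sketch of it in Python (the language of one of those previous posts). The names trampoline and wrap_text are mine, and this naive version assumes the final result is never itself a callable:

```python
def trampoline(fn, *args):
    # keep calling the returned thunks until we get a real (non-callable) result
    result = fn(*args)
    while callable(result):
        result = result()
    return result


def wrap_text(txt, times):
    if times <= 0:
        return txt
    # instead of recursing, return a thunk that performs the next step
    return lambda: wrap_text(f"[{txt}]", times - 1)


wrapped = trampoline(wrap_text, "b", 10000)
print(len(wrapped))  # 20001
```

Each step returns to the loop in trampoline before the next one starts, so the stack never grows.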

When a Kotlin suspend function gets suspended at a suspension point it returns to the calling function, so the corresponding stack frame is released (the function will later be resumed, at the statement following that suspension point, by "something" calling the resume method of the continuation object corresponding to that function). If we are not piling frames up on the stack, there's no risk of a stack overflow. So if we add to our recursive function a call to a trivial suspend function before the next recursive call, we should avoid the stack overflow. The main thing here is clarifying what "trivial" suspend function we can use for this.

Let's start by showing a recursive function that causes a stack overflow:


fun wrapText1(txt: String, times: Int): String =
    if (times <= 0) txt
    else wrapText1("[$txt]", times - 1)

try {
    println(wrapText1("b", 10000))
}
catch (ex: StackOverflowError) {
    println("StackOverflow!!!")
}

// StackOverflow!!!


First of all, that case could be fixed by leveraging the Kotlin compiler's support for tail call optimization. That function is tail recursive, so we can just mark it with the tailrec modifier and the compiler will transform the recursion into a loop, so no more overflows.


tailrec fun wrapText1(txt: String, times: Int): String = 
    if (times <= 0) txt
    else wrapText1("[$txt]", times - 1)
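Conceptually, the loop that results from that transformation looks something like the sketch below. I've written it in Python just to show the shape of it (this is not the actual output of the Kotlin compiler, and wrap_text_loop is my own name):

```python
def wrap_text_loop(txt: str, times: int) -> str:
    # the recursive call in tail position becomes "update the parameters and jump back"
    while times > 0:
        txt = f"[{txt}]"
        times -= 1
    return txt


print(wrap_text_loop("b", 3))            # [[[b]]]
print(len(wrap_text_loop("b", 10000)))   # 20001
```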

That said, let's forget about the power of tailrec. It works for tail recursive functions, while the idea of leveraging suspendable functions to prevent stack overflows works the same for tail and non-tail recursion. Notice that I'll run my recursive function in a coroutine created with the runBlocking builder using an event loop dispatcher, so everything runs in a single thread, as in JavaScript and Python asyncio.

Let's first try using a suspend function that really does nothing: it does not invoke another "real" suspendable function. As you can see in my post from 2021, in JavaScript awaiting a resolved Promise (or just a normal object) is enough to avoid the stack overflow, as it causes the function to be suspended, returning and releasing its stack frame. The "then" part to run after the await is not executed immediately but put on the microtask queue, meaning that the function returns to the event loop (which will then execute the "then handler" taken from the microtask queue). Asynchrony in Kotlin is implemented in a quite different way (continuations...) and I was quite convinced that this "doNothing" function would not cause a suspension, but I had to confirm it.


suspend fun doNothingSuspend(): String {
    return "a"
}

// as expected this one DOES NOT prevent Stack Overflows. Calling a function marked as suspend but that is not really suspending is useless for this
// there won't be suspension. So it's like Python asyncio
suspend fun wrapTextAsync1(txt: String, times: Int): String {
    return if (times <= 0) txt
    else {
        doNothingSuspend()
        wrapTextAsync1("[$txt]", times - 1)
    }
}

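// runBlocking comes from kotlinx.coroutines; a suspend function must be called from a coroutine
runBlocking {
    try {
        println(wrapTextAsync1("b", 10000))
    }
    catch (ex: StackOverflowError) {
        println("StackOverflow!!!")
    }
}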

//StackOverflow!!!


Yes, we get an ugly stack overflow. Calling that suspend function does not cause a suspension, because the complex code into which the compiler transforms a suspend function will not return a COROUTINE_SUSPENDED value. That value is what the calling function checks in order to decide between returning to its own caller (and so on down to the event loop) and continuing with the next statement (which is what happens in this case). This is somewhat related to what happens in Python, where marking as async a function that "does not really do anything asynchronous" (a sort of "fake async function") will not really suspend the caller. Because of how coroutines are implemented in Python (similar to generators), awaiting such a coroutine just runs it to completion and hands its result straight back to the awaiting coroutine: nothing is yielded down the chain to the Task that drives it, so the code continues running sequentially, not interrupted by a suspension and without having returned control to the event loop.
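The Python side of that comparison is easy to check with a small sketch. Two workers log their steps: if awaiting the "fake async function" really suspended, their entries would interleave. For contrast I use asyncio.sleep(0) as a "real" suspension, which is my own choice for this example. The names do_nothing, worker and run are mine too:

```python
import asyncio


async def do_nothing():
    # "fake async function": marked async, but it never really suspends
    return "a"


async def worker(name, log, really_suspend):
    for i in range(2):
        log.append(f"{name}{i}")
        if really_suspend:
            await asyncio.sleep(0)  # gives control back to the event loop
        else:
            await do_nothing()      # runs to completion inside the caller


async def run(really_suspend):
    log = []
    await asyncio.gather(
        worker("A", log, really_suspend),
        worker("B", log, really_suspend),
    )
    return log


no_suspend_log = asyncio.run(run(False))
suspend_log = asyncio.run(run(True))
print(no_suspend_log)  # ['A0', 'A1', 'B0', 'B1']
print(suspend_log)     # ['A0', 'B0', 'A1', 'B1']
```

With the fake async function, worker A runs from start to end before B gets a chance; with a real suspension they take turns.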

So if we want a real suspension in our recursive function, returning from the function and avoiding a stack overflow, we'll have to invoke a function that really suspends. The immediate option that comes to mind is using a sort of sleep (delay, in the Kotlin coroutines universe) with a minimum value. I say "minimum value" because passing just 0 to kotlinx.coroutines.delay will return immediately without causing any suspension (it won't return a COROUTINE_SUSPENDED value).


suspend fun wrapTextAsync2(txt: String, times: Int): String {
    return if (times <= 0) txt
    else {
        // using a delay(0) is NOT an option, as it returns immediately and no suspension happens
        delay(1)
        wrapTextAsync2("[$txt]", times - 1)
    }
}


That works fine. delay(1) returns to the event loop and, after that delay, the corresponding continuation will resume the recursive function at the next statement (the next call to wrapTextAsync2) with a clean stack. But obviously there's an important performance penalty due to the delay.

If you've read this previous post you'll also be familiar with kotlinx.coroutines.yield. Reading the implementation details I was a bit dubious as to whether running my code in an event loop created by runBlocking, with no other suspend functions running concurrently (so no items waiting in the event loop queue), would be enough for yield() to cause a suspension. But if you invoke the function below with any huge times value, you'll see that no stack overflow happens. So yes, yield() prevents a stack overflow.


suspend fun wrapTextAsync3(txt: String, times: Int): String {
    return if (times <= 0) txt
    else {
        yield()
        wrapTextAsync3("[$txt]", times - 1)
    }
}