Tail Recursion

In Flix, and in functional programming in general, iteration is expressed through recursion.

For example, if we want to determine whether a list contains an element, we can write a recursive function:

def memberOf(x: a, l: List[a]): Bool with Eq[a] = 
    match l {
        case Nil     => false
        case y :: ys => if (x == y) true else memberOf(x, ys)
    }

The memberOf function pattern matches on the list l. If it is empty then it returns false. Otherwise, we have an element y and the rest of the list is ys. If x == y then we have found the element and we return true. Otherwise we recurse on the rest of the list ys.

The recursive call to memberOf is in tail position, i.e. it is the last thing to happen in the memberOf function: its result is returned directly, with no work left to do afterwards. This has two important benefits: (a) the Flix compiler is able to rewrite memberOf to use an ordinary loop (which is more efficient than a function call) and, more importantly, (b) a call to memberOf cannot overflow the stack, because the call stack never increases in height.
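To make the notion of tail position concrete, here is a minimal sketch that contrasts two ways of computing the length of a list. The names lengthNaive and lengthAcc are hypothetical and chosen only for illustration; they are not part of the Flix standard library.

```flix
// Not tail recursive: the addition `1 + ...` still has to run
// after the recursive call returns, so each call needs a stack frame.
def lengthNaive(l: List[a]): Int32 = match l {
    case Nil     => 0
    case _ :: xs => 1 + lengthNaive(xs)
}

// Tail recursive: the recursive call is the last thing evaluated,
// and the running count is carried in the extra parameter `acc`.
def lengthAcc(l: List[a], acc: Int32): Int32 = match l {
    case Nil     => acc
    case _ :: xs => lengthAcc(xs, acc + 1)
}

def main(): Unit \ IO =
    let l = 1 :: 2 :: 3 :: Nil;
    println(lengthNaive(l));
    println(lengthAcc(l, 0))
```

Both calls compute the same length, but only lengthAcc keeps the stack height constant on long lists.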

Tip: Flix has support for full tail call elimination which means that recursive calls in tail position never increase the stack height and hence cannot cause the stack to overflow!

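We remark that Flix has full tail call elimination, not just tail call optimization: every call in tail position is eliminated, including calls to other functions, not only calls from a function to itself. This means that the following mutually recursive program compiles and runs successfully: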

def isOdd(n: Int32): Bool =
    if (n == 0) false else isEvn(n - 1)

def isEvn(n: Int32): Bool =
    if (n == 0) true else isOdd(n - 1)

def main(): Unit \ IO =
    isOdd(12345) |> println

which is not the case in many other programming languages.

Non-Tail Calls and Stack Overflows

While the Flix compiler guarantees that tail calls cannot overflow the stack, the same is not true for function calls in non-tail positions.

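For example, the following implementation of the factorial function is not tail recursive, because the multiplication by n happens after the recursive call returns. For large inputs it overflows the call stack: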

def factorial(n: Int32): Int32 = match n {
    case 0 => 1
    case _ => n * factorial(n - 1)
}

as this program shows:

def main(): Unit \ IO = 
    println(factorial(1_000_000))

which, when compiled and run, produces:

java : Exception in thread "main" java.lang.StackOverflowError
	at Cont%Int32.unwind(Cont%Int32)
	at Def%factorial.invoke(Unknown Source)
	at Cont%Int32.unwind(Cont%Int32)
	at Def%factorial.invoke(Unknown Source)
	at Cont%Int32.unwind(Cont%Int32)
    ... many more frames ...

A well-known technique to avoid this is to rewrite factorial to use an accumulator that carries the partial result:

def factorial(n: Int32): Int32 = 
    def visit(x, acc) = match x {
        case 0 => acc
        case _ => visit(x - 1, x * acc)
    };
    visit(n, 1)

Here the visit function is tail recursive, hence it cannot overflow the stack.
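The same accumulator pattern carries over to other functions. As a hedged sketch, the hypothetical function sumTo below adds up the integers from 1 to n, assuming n is non-negative. It uses Int64 so that the sum for one million fits in the result type.

```flix
// Sums the integers 1 + 2 + ... + n using an accumulator.
// Assumption: n is non-negative.
def sumTo(n: Int64): Int64 =
    // `visit` calls itself in tail position, so the stack does not grow.
    def visit(x, acc) =
        if (x == 0i64) acc else visit(x - 1i64, acc + x);
    visit(n, 0i64)

def main(): Unit \ IO =
    println(sumTo(1000000i64))
```

Because the addition is performed before the recursive call rather than after it, no work is pending when visit calls itself, which is what makes the call a tail call.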