Breaking @tailrec

To follow this example, you will need to know some basic (really basic) Scala. You will also need to know what call-by-name and call-by-value mean in the context of functional programming. These concepts are only very loosely related to “pass-by-value” and “pass-by-reference” in imperative programming; you should not confuse them. tl;dr at the end.
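For a quick refresher, here is a small sketch that counts how often an argument expression is actually evaluated. The names EvaluationDemo, expensive, byValue, byName and unused are hypothetical. A by-value argument is evaluated once, at the call site. A by-name argument is evaluated once per use inside the body, and never if the body does not use it.

```scala
object EvaluationDemo {
  var evaluations = 0 // counts how many times the argument expression is evaluated

  // Hypothetical argument expression with a visible side effect.
  def expensive(): Int = { evaluations += 1; 42 }

  // Call-by-value: the argument is evaluated once, before the body runs.
  def byValue(x: Int): Int = x + x

  // Call-by-name: the argument expression is re-evaluated at every use of x.
  def byName(x: => Int): Int = x + x

  // Call-by-name, never used: the argument is never evaluated.
  def unused(x: => Int): Int = 0
}
```

Calling byValue(expensive()) bumps the counter once, byName(expensive()) bumps it twice, and unused(expensive()) leaves it untouched.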

I have recently been learning Scala and functional programming in general. Recursive exercises are a great warm-up to make the imperative programmer start thinking in a more recursive manner, so one of the first few examples I wrote was factorial. I realized I needed to understand quickly how the language treats tail recursion, and factorial is a great exercise for that.

The following intuitive one-liner is, unfortunately, not tail recursive, so its stack depth grows linearly with n:

def factorialNaive(n: Int): Int = if (n == 0) 1 else n * factorialNaive(n - 1) // recursive methods need an explicit result type

So a tail-recursive implementation is quite appropriate here, and the following pair of functions will get the job done in exactly such a manner:

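import scala.annotation.tailrec

def factorialTailRec(n: Int): Int = {
  // Inner helper: f accumulates the running product and is passed by value
  @tailrec
  def factorialTailRec(n: Int, f: Int): Int = {
    if (n == 0) f
    else factorialTailRec(n - 1, n * f)
  }
  factorialTailRec(n, 1)
}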
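Why does the tail-recursive helper not grow the stack? A self-call in tail position can be compiled into a jump back to the top of the method, so the helper behaves roughly like the hand-written loop below. This is a sketch of the idea, not the bytecode the Scala compiler actually emits; LoopSketch and factorialLoop are hypothetical names.

```scala
object LoopSketch {
  // Hypothetical hand-written equivalent of the @tailrec helper:
  // the parameters become mutable locals and the self-call becomes a loop iteration.
  def factorialLoop(n0: Int): Int = {
    var n = n0 // plays the role of the helper's n
    var f = 1  // plays the role of the accumulator f
    while (n != 0) {
      f = n * f // same update as factorialTailRec(n - 1, n * f), using the old n
      n = n - 1
    }
    f
  }
}
```

Each iteration overwrites n and f in place, which is why the stack depth stays constant no matter how large n gets.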

Now at this point it’s important to remind ourselves that the @tailrec annotation, which becomes available once we import scala.annotation.tailrec, is not required, but it is a very good idea to include, since it makes the compiler fail with an error if the recursive call is not in tail position. If we were to use that annotation for factorialNaive, the code would not compile, because the multiplication in n * factorialNaive(n - 1) still has to happen after the recursive call returns.

Now here is where an evil idea crept into my mind. I noticed that I was passing the second parameter of factorialTailRec by value. I thought to myself: “This means that every ‘stack frame’ (more like iteration scope at this point) is burdened with one multiplication… so all the way down the call chain we have n-1 multiplications. The alternative, passing the parameter by name, would delay every multiplication until the last stack frame, where we bottom out with a term of 1. So I would expect similar performance.”

Turns out, the above is only partly true, and the most interesting fact is not included in it! While the multiplications are indeed delayed until the very end, every by-name argument n * f is wrapped in an anonymous function (a thunk) that captures the current n together with the previous thunk f. By the time we bottom out, we are holding a chain of n nested thunks. Forcing the outermost one evaluates k * f for its own k, which first has to force the thunk it wraps, which has to force the next one, and so on down to the initial 1. Every multiplication waits on a nested function call, so evaluating the chain pushes at least one stack frame per thunk. The otherwise tail-recursive calls still run as a loop; it is the deferred computation they build, one thunk per call, that needs a linear-size stack!
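The thunk chain can be made explicit by modeling the by-name accumulator as a plain () => Int function. This is a hypothetical desugaring for illustration, not what the compiler literally generates; ThunkChain and build are made-up names. build itself is tail recursive and runs as a loop, but calling the function it returns walks the whole chain recursively.

```scala
import scala.annotation.tailrec

object ThunkChain {
  // Hypothetical desugaring of the by-name accumulator: `f: => Int` is modeled
  // here as an explicit zero-argument function (a thunk) of type () => Int.
  @tailrec
  def build(n: Int, f: () => Int): () => Int =
    if (n == 0) f
    else {
      val k = n    // this level's factor
      val prev = f // the thunk built by the previous call
      // Nothing is multiplied yet: we only wrap the previous thunk in a new one.
      build(n - 1, () => k * prev())
    }

  // Building the chain runs as a loop, but forcing it makes each thunk call the
  // thunk it wraps, so the stack depth during evaluation grows linearly with n.
  def factorial(n: Int): Int = build(n, () => 1)()
}
```

Small inputs give the right answers, while a large n (with default JVM stack settings) should end in StackOverflowError during the final call, not while the chain is being built.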

To test this, we can use the following runner program. Note that I don’t care at all about the actual factorial values, so I let the result overflow (an Int already overflows at 13!) and be as nonsensical as it likes. I’m not even assigning it anywhere. I’m interested exclusively in how the code affects the use of the JVM’s stack. Factorial grows faster than any exponential, so even the Long data type is not sufficient beyond 20!, and one had best use an efficient BigInteger library or Stirling’s approximation if they care about computing the values of large factorials.

package factorial

import scala.annotation.tailrec

object Factorial extends App {

  val ITERS = 100000 // Let's push things

  // Naive factorial: the multiplication happens after the recursive call returns,
  // so every call needs its own stack frame.
  def factorialNaive(n: Int): Int = if (n == 0) 1 else n * factorialNaive(n - 1)

  try {
    for (i <- 1 to ITERS) factorialNaive(i)
    println("Naive factorial worked up to " + ITERS + "!")
  } catch {
    case _: StackOverflowError => println("Naive factorial threw an instance of StackOverflowError.")
    case e: Exception => println("Naive factorial threw an instance of " + e.getClass + " with message: " + e.getMessage + ".")
  }

  // Tail-recursive factorial: the accumulator f is passed by value.
  def factorialTailRec(n: Int): Int = {
    @tailrec
    def factorialTailRec(n: Int, f: Int): Int = {
      if (n == 0) f
      else factorialTailRec(n - 1, n * f)
    }
    factorialTailRec(n, 1)
  }

  try {
    for (i <- 1 to ITERS) factorialTailRec(i)
    println("Tail recursive factorial worked up to " + ITERS + "!")
  } catch {
    case _: StackOverflowError => println("Tail recursive factorial threw an instance of StackOverflowError.")
    case e: Exception => println("Tail recursive factorial threw an instance of " + e.getClass + " with message: " + e.getMessage + ".")
  }

  println("Exiting...")

}

Notice that in factorialTailRec, the accumulator argument is passed by value. Running this program for the given parameter of ITERS=100,000 yields the output:

Naive factorial threw an instance of StackOverflowError.
Tail recursive factorial worked up to 100000!
Exiting...

Nothing surprising so far. But what if I were to pass the second argument by name instead, by changing the method declaration to:

def factorialTailRec(n: Int, f: => Int): Int = ...

Then, our output is:

Naive factorial threw an instance of StackOverflowError.
Tail recursive factorial threw an instance of StackOverflowError.
Exiting...

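Everybody has my express permission to print the above output and paste it on their office doors, perhaps appropriately captioned. Not to mention that @tailrec did not complain at all! The file compiled just fine. Strictly speaking, the annotation is right: the recursive call still is in tail position, and the loop it is compiled into runs in constant stack space; what overflows is the chain of thunks built along the way. I’m not sure whether it’s possible to have @tailrec disallow call-by-name parameters. It’s one of those things that sound easy to do, but are probably very difficult in practice.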
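For reference, here is a sketch of the complete by-name variant, assuming the only change from the runner is the type of f. It is wrapped in a hypothetical ByNameFactorial object so it can run on its own. It compiles with @tailrec in place, yet forcing the result for a large n should, with default JVM stack settings, throw StackOverflowError.

```scala
import scala.annotation.tailrec

object ByNameFactorial {
  def factorialTailRec(n: Int): Int = {
    // Identical to the by-value helper, except that f is now passed by name.
    @tailrec
    def factorialTailRec(n: Int, f: => Int): Int = {
      if (n == 0) f                       // only here is the deferred product forced
      else factorialTailRec(n - 1, n * f) // n * f is wrapped in a thunk, not computed
    }
    factorialTailRec(n, 1)
  }
}
```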

Now let’s play a little game. Suppose that for whatever reason you cannot change the signature of factorialTailRec and you are stuck with an idiotic call-by-name that has doomed your method to be tail-recursively constructing a linear-size stack for a function that it itself builds… 😦 Is there anything we can do to achieve call-by-value?

Yup. This:

def factorialTailRec(n: Int): Int = {
  @tailrec
  def factorialTailRec(n: Int, f: => Int): Int = {
    if (n == 0) {
      val fEvaluated = f
      fEvaluated
    }
    else {
      val fEvaluated = f // forces f now, so the next thunk captures a plain Int
      factorialTailRec(n - 1, n * fEvaluated)
    }
  }
  factorialTailRec(n, 1)
}

Which is a cheap general template to emulate call-by-value from a called-by-name parameter if you need it. Right now with my limited Scala experience, I can’t envision a scenario where this would be preferable to a simple call-by-value declared at the method signature level, but hey, it works:

Naive factorial threw an instance of StackOverflowError.
Tail recursive factorial worked up to 100000!
Exiting...

Finally, to once again underscore the difference between val and def: changing the two lines val fEvaluated = f into def fEvaluated = f does nothing to force the parameter, since a local def is re-evaluated every time it is referenced, just like a by-name parameter; it merely aliases f. In fact, I’d be willing to bet that it roughly doubles the constant factor in front of the O(n) space occupied by the stack, since every level of the chain now goes through an extra method call.

Naive factorial threw an instance of StackOverflowError.
Tail recursive factorial threw an instance of StackOverflowError.
Exiting...
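The val/def difference can be checked in isolation by counting how many times a by-name argument gets forced. A minimal sketch; ValVsDef, tick, withVal and withDef are hypothetical names, and each function references fEvaluated twice.

```scala
object ValVsDef {
  var evaluations = 0 // counts how many times the argument expression is evaluated

  // Hypothetical argument expression with a visible side effect.
  def tick(): Int = { evaluations += 1; 1 }

  // Local val: f is forced exactly once, when fEvaluated is defined.
  def withVal(f: => Int): Int = {
    val fEvaluated = f
    fEvaluated + fEvaluated
  }

  // Local def: nothing is forced at the definition; each reference to
  // fEvaluated re-evaluates f, so it behaves as an alias for the by-name parameter.
  def withDef(f: => Int): Int = {
    def fEvaluated = f
    fEvaluated + fEvaluated
  }
}
```

withVal(tick()) forces its argument once, while withDef(tick()) forces it twice, once per reference.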

tl;dr I built a tail-recursive method in Scala that blows up the JVM’s stack.
