Recursion is a well-known technique in computer programming that allows a function to call itself until a base condition is met.
Scala provides an optimization technique known as tail recursion optimization. This technique allows recursive functions to be executed more efficiently by avoiding the creation of a new stack frame for each recursive call. When a tail-recursive function calls itself, the current frame is reused instead of a new frame being pushed onto the stack. This means that the function uses the same stack frame throughout its execution, which reduces the amount of memory needed, improves performance, and rules out stack overflows on deep recursion.
However, in order to apply this optimization, the function must meet certain criteria, such as having the last operation be the recursive call and not performing any additional work after the recursive call. Let's see how to build such functions!
The factorial example
This is a classic example of computing the factorial of a number. A non-tail recursive implementation of the factorial function would look like this:
def factorial(n: Long): Long =
  if n == 0 then 1
  else n * factorial(n - 1)
But a tail-recursive implementation of the same function would look like this:
import scala.annotation.tailrec

@tailrec
def factorial(n: Long, acc: Long = 1): Long =
  if n == 0 then acc
  else factorial(n - 1, n * acc)
The tail-recursive solution computes the factorial of a given number n using an accumulator acc as an extra parameter. The accumulator carries the partial product from one call to the next, so nothing is left to compute after a recursive call returns. The base case in this function is when n equals zero, and the accumulator is returned. Otherwise, the function continues with the recursive call, passing n - 1 and n * acc as arguments. Because the recursive call is the last operation in this function, rather than a multiplication as in the first implementation, the function is tail-recursive and can be optimized by the Scala compiler.
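To make the difference concrete, the sketch below traces both versions for n = 4. The names factorialNaive and factorialAcc are introduced here only so that both definitions can live side by side; they are not part of the original example.

```scala
import scala.annotation.tailrec

// Non-tail-recursive: the multiplication runs after the recursive call returns,
// so every call has to keep its stack frame alive until then.
def factorialNaive(n: Long): Long =
  if n == 0 then 1
  else n * factorialNaive(n - 1)

// factorialNaive(4)
//   = 4 * factorialNaive(3)
//   = 4 * (3 * factorialNaive(2))
//   = 4 * (3 * (2 * factorialNaive(1)))
//   = 4 * (3 * (2 * (1 * factorialNaive(0))))
//   = 4 * (3 * (2 * (1 * 1)))
//   = 24

// Tail-recursive: the partial product travels in `acc`,
// so nothing is pending when the recursive call is made.
@tailrec
def factorialAcc(n: Long, acc: Long = 1): Long =
  if n == 0 then acc
  else factorialAcc(n - 1, n * acc)

// factorialAcc(4)          // acc defaults to 1
//   = factorialAcc(3, 4)
//   = factorialAcc(2, 12)
//   = factorialAcc(1, 24)
//   = factorialAcc(0, 24)
//   = 24
```

In the first trace the expression keeps growing until the base case is reached; in the second, each step has the same shape, which is what lets the compiler reuse a single frame.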
Note that Scala’s @tailrec annotation is used to ensure that a function is tail-recursive. When a function is marked with @tailrec, the compiler checks that every recursive call is in tail position and reports a compile-time error if not. The annotation does not switch the optimization on by itself: the Scala compiler rewrites self-recursive tail calls into iterative loops whenever it can, which conserves memory and improves performance. What the annotation adds is a guarantee that the rewrite actually happens, so a later change that breaks tail recursion is caught at compile time rather than as a stack overflow at runtime.
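As an illustration of the check, here is a minimal sketch using a hypothetical sumTo function (it does not appear elsewhere in this topic). The first variant is left commented out because, with @tailrec applied, the compiler would refuse to compile it: the addition still has to run after the recursive call returns.

```scala
import scala.annotation.tailrec

// Rejected if uncommented: the recursive call is not in tail position,
// because `n + ...` must still be evaluated after it returns.
//
// @tailrec
// def sumToNaive(n: Long): Long =
//   if n == 0 then 0
//   else n + sumToNaive(n - 1)

// Accepted: the recursive call is the last thing the function does.
@tailrec
def sumTo(n: Long, acc: Long = 0): Long =
  if n == 0 then acc
  else sumTo(n - 1, acc + n)

// A million nested calls run in constant stack space after the rewrite.
val millionSum: Long = sumTo(1_000_000L)
```

Without the accumulator, a recursion this deep would be likely to exhaust the default JVM stack.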
Replacing recursion with a loop
Every recursive method can be written iteratively using a loop:
def factorial(n: Long): Long =
  var result: Long = 1
  for (i <- 1L to n) { // 1L keeps the range typed as Long, matching n
    result *= i
  }
  result
But the use of mutable variables can lead to problems in asynchronous or large systems where it can be difficult to keep in mind the whole context of what is going on in the code. Recursion is a fundamental concept in functional programming and is a natural choice for solving problems. When using recursion in Scala, we can avoid the potential issues that arise from mutating data structures in loops.
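The two styles are closer than they may seem. As a rough mental model (not the compiler's literal output), the tail-recursive factorial behaves like the loop below, where the parameters n and acc become local variables that are updated on each pass. The name factorialAsLoop is made up for this sketch.

```scala
// Hand-written approximation of what the tail-call rewrite produces.
def factorialAsLoop(n: Long): Long =
  var remaining = n // plays the role of the parameter n
  var acc = 1L      // plays the role of the accumulator
  while remaining != 0 do
    acc = remaining * acc // next acc = n * acc
    remaining -= 1        // next n = n - 1
  acc
```

The mutation has not disappeared, but it is confined to code the compiler generates, while the source you read and maintain stays free of mutable variables.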
The partition example
Imagine that we have some data about people:
case class Person(name: String, age: Int)
val data: List[Person] =
  List(Person("Dave", 15), Person("Mike", 23), Person("Kate", 6), Person("Mia", 32))
You need to divide the data into two collections: people under 18 and people aged 18 or over. Essentially, we need to go through the list of people once, check each age, and put each person into either the "under 18" or the "18 or over" collection. You can solve this problem with recursion:
import scala.annotation.tailrec

def splitPeopleByAge(people: List[Person]): (List[Person], List[Person]) = {
  @tailrec
  def withAccumulators(
    people: List[Person],
    below: List[Person] = List.empty,
    above: List[Person] = List.empty,
  ): (List[Person], List[Person]) =
    people match
      // people list is over, we can return what we accumulated
      // (prepending built both lists backwards, so restore the input order)
      case Nil => (below.reverse, above.reverse)
      case person :: otherPeople =>
        if person.age < 18 then withAccumulators(otherPeople, person :: below, above)
        else withAccumulators(otherPeople, below, person :: above)
  withAccumulators(people) // accumulators are empty by default
}
We have defined a helper function that takes two (initially empty) lists as accumulators, one for each category. On each recursive call, we check the age and prepend the person to the right list. The base case of the recursion is the end of the list of people: we return the accumulated lists.
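One detail is worth noticing: prepending with :: builds each accumulator in reverse, so a final reverse is needed if the groups should keep the order of the input. The sketch below shows this with a generic helper. The name splitBy and its predicate parameter are assumptions made for illustration, not part of the original code.

```scala
import scala.annotation.tailrec

case class Person(name: String, age: Int)

// Hypothetical generic variant: the caller supplies the test as a function.
def splitBy[A](items: List[A])(isFirstGroup: A => Boolean): (List[A], List[A]) =
  @tailrec
  def loop(rest: List[A], yes: List[A], no: List[A]): (List[A], List[A]) =
    rest match
      // `::` prepends, so both accumulators are in reverse order at this point
      case Nil => (yes.reverse, no.reverse)
      case head :: tail =>
        if isFirstGroup(head) then loop(tail, head :: yes, no)
        else loop(tail, yes, head :: no)
  loop(items, Nil, Nil)

val people: List[Person] =
  List(Person("Dave", 15), Person("Mike", 23), Person("Kate", 6), Person("Mia", 32))

val groups = splitBy(people)(_.age < 18)
// groups._1 == List(Person("Dave", 15), Person("Kate", 6))
// groups._2 == List(Person("Mike", 23), Person("Mia", 32))
```

Reversing once at the end costs a single extra pass over each result list, which is usually cheaper than appending to the end of an immutable list on every step.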
Partitioning is an essential technique for processing large amounts of data efficiently. By breaking down the data into smaller, more manageable chunks, partitioning allows for parallel processing, reducing the overall processing time.
Conveniently, Scala's collections already provide a similar method with a more general structure, called partition. We can simply pass in a function that defines the division. You will learn how this works in a later topic:
data.partition(_.age < 18) // splits into (under 18, 18 or over), like 'splitPeopleByAge(data)'
Conclusion
To sum up, tail recursion is an important concept in functional programming, particularly in Scala. It allows for more efficient program execution by avoiding stack overflows and reducing memory usage. The Scala compiler rewrites tail-recursive functions into loops, and the @tailrec annotation verifies that a function is tail-recursive, so you can be sure the optimization applies.