Monday, March 19, 2012

Understanding Scala's partially applied functions and partial functions

The two similar terms, partially applied function and partial function, quickly confuse anybody new to functional programming languages. Mastering Scala is difficult on its own, which doesn't make it any easier to understand the differences between these two similarly named function types. I'll try to dispel this confusion through some examples.

Before we start, we need to define some terms.

  • A method in Scala, just like in Java, is part of a class (or object). It has a name, a parameter list, a return type and, unless it is abstract, a body that implements the method. For example:
    class Math {
    
      def add(a: Int, b: Int): Int = a + b
    }
    As syntactic sugar, the return type can be omitted, because Scala can infer it from the last expression of the method body (recursive methods are the exception: they need an explicit return type).
  • A Scala function, on the other hand, is in fact a class. Scala has a series of traits, each of which represents a function with a given number of arguments: Function0 for parameterless functions, Function1 for functions with one parameter, Function2 for functions with two parameters, and so on up to Function22. Given these traits, a function is a class that mixes in one of them (depending on its number of parameters), and as such, it has methods. The most important one is the apply method, which contains the implementation of the function body. The apply method takes exactly as many arguments as the number in the name of the mixed-in trait: 0 for Function0, 1 for Function1, 2 for Function2, and so on. For example:
    object add extends Function2[Int, Int, Int] {
    
      override def apply(a: Int, b: Int): Int = a + b
    }
    Scala has plenty of syntactic sugar; one example is the special apply syntax: if you write a name followed by an argument list in parentheses, Scala translates that into a call to the apply method of the named object. Using the previous example, instead of writing add.apply(1, 2) we can simply write add(1, 2) to get the same result.
  • A function literal or anonymous function is an alternative syntax for defining functions. Instead of creating a new class that mixes in one of the FunctionN traits and overriding its apply method, we can simply list the named function parameters in parentheses, then write a right arrow, and then the body of the function. For example:
    (a: Int, b: Int) => a + b
    From such a function literal the Scala compiler generates a function object that mixes in one of the FunctionN traits: the left-hand side of the => becomes the parameter list and the right-hand side becomes the implementation of the apply method. As you can see, it is really nothing more than a shorthand form of a function.
  • A function value is an instance of some class that extends one of the FunctionN traits. A function literal, for example, is compiled into a class that mixes in a FunctionN trait and is then instantiated at runtime (since Scala 2.12 that class is generated at runtime via invokedynamic instead of being emitted as a separate class file, but the result is still a FunctionN instance). The created instance is a function value. Since function values are objects, we can store them in a variable, but they are functions too, so we can invoke them using the parentheses function-call notation (which is converted into a call to the apply method of the function object assigned to the variable). For example:
    val add = (a: Int, b: Int) => a + b
    add(1, 2) // returns 3
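To see the four terms side by side, here is a minimal sketch that defines the same addition as a method, as an explicit Function2 object and as a function literal. The names Calculator, addObj and addLit are hypothetical, chosen only to keep the three variants apart in one file (Calculator stands in for the Math class above).

```scala
// A method: part of a class, with a name, parameter list, return type and body.
class Calculator {
  def add(a: Int, b: Int): Int = a + b
}

// A function written out by hand: an object that extends Function2.
object addObj extends Function2[Int, Int, Int] {
  override def apply(a: Int, b: Int): Int = a + b
}

// A function literal; the resulting function value is stored in a val.
val addLit: (Int, Int) => Int = (a: Int, b: Int) => a + b

val viaMethod  = new Calculator().add(1, 2)
val viaObject  = addObj(1, 2)        // sugar for addObj.apply(1, 2)
val viaApply   = addObj.apply(1, 2)  // the explicit call
val viaLiteral = addLit(1, 2)        // sugar for addLit.apply(1, 2)

// Both the hand-written object and the literal are instances of Function2.
val bothAreFunction2 =
  addObj.isInstanceOf[Function2[_, _, _]] && addLit.isInstanceOf[Function2[_, _, _]]
```

All four calls evaluate to 3; the only difference is how much of the FunctionN plumbing we write ourselves.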
    

Partially applied functions

Now that we're familiar with these terms, we have arrived at partially applied functions. From now on, the terms method and function are interchangeable. While in general this is not true, I'll use this simplification because it really doesn't matter when we talk about partially applied functions.

Normally, when we invoke a method or a function, we pass in all the required arguments. The function is then applied to these arguments and computes some value. In Scala, however, we don't have to supply all of the arguments: we can pass in only some of them, or none at all. In either case Scala cannot compute a result, since arguments are missing, so instead it automatically generates a partially applied function from the supplied arguments and creates a function value from it. The generated partially applied function mixes in one of the FunctionN traits, where N is the number of missing arguments. The body of its apply method simply invokes the original function with the supplied arguments, filling in the missing ones with the arguments passed to the apply method of the generated function. While this may sound complicated, it turns out to be very simple when demonstrated through an example:

val add = (a: Int, b: Int) => a + b
val inc = add(_: Int, 1)
inc(10) // returns 11

What happens here is the following: we invoke the add(Int, Int) function with a fixed argument (the number 1), but leave out the other Int argument. Each missing argument is replaced with an underscore (here with a type annotation, _: Int). Since only some of the arguments are supplied, the add function cannot be executed, so inc does not store a result; instead it refers to the function value instantiated from the generated partially applied function. What is the type of this generated function? There is only one missing argument, hence its type is Function1. Function1 has two type parameters: one for the function's input parameter and one for its result type. The type of the missing argument of add is Int and its return type is also Int, so both type parameters are Int. And what exactly does its apply(Int) method do? It simply invokes the original add function with the parameter supplied to apply and the fixed argument value 1. Given all this information, we could even implement our own partially applied function (fortunately, the Scala compiler takes care of this so we don't have to):

object inc extends Function1[Int, Int] {

  override def apply(v: Int): Int = add(v, 1)
}
inc(10) // returns 11
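Since N is the number of missing arguments, leaving out a different number of arguments yields a different FunctionN. The following sketch uses a hypothetical three-parameter function, volume, to show that supplying one argument gives a Function2 and supplying two gives a Function1, and that the latter is equivalent to a hand-written lambda.

```scala
// A three-argument function value (a Function3).
val volume = (l: Int, w: Int, h: Int) => l * w * h

// One argument supplied, two missing: the result is a Function2.
val withLength2: (Int, Int) => Int = volume(2, _: Int, _: Int)

// Two arguments supplied, one missing: the result is a Function1.
val withLength2Width3: Int => Int = volume(2, 3, _: Int)

// Written out by hand, the Function1 case just forwards to volume.
val handWritten: Int => Int = (h: Int) => volume(2, 3, h)
```

For example, withLength2(3, 4) and withLength2Width3(4) both compute 2 * 3 * 4 = 24.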

Another use case is when none of the arguments are given. We could substitute each of the missing arguments with an underscore, but Scala provides an even simpler form: a single underscore after the name of a method replaces its entire parameter list, for example add _ when add is a method. (Notice the space between the name and the underscore! It is needed, otherwise the Scala compiler thinks we're referring to an identifier called add_.) We can use an even more concise form by leaving off the underscore altogether. This form, however, can only be used where a function is expected at that point in the code. A typical example is the argument of a higher order function, that is, a function that expects another function as an argument, such as the foreach method (at the time of writing declared by the GenTraversableOnce trait), which accepts a single-parameter function: Array(1, 2, 3).foreach(println).

In this last example the println(Any) method (defined in the scala.Predef object) is passed to the foreach method of the array. No arguments are given to println, and a function is expected at that point, so we simply leave off the underscore after println. The Scala compiler generates a partially applied function from the println method, which takes exactly one argument. The argument actually passed to it is the current element of the array as foreach iterates over it, so in every iteration the current element is printed to the console.
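Here is a minimal sketch of the three forms described above, using a hypothetical method sum and a hypothetical helper record that collects the elements so the result can be checked. Note that Scala 3 performs this conversion (eta-expansion) automatically in most places, so the trailing underscore is rarely needed there.

```scala
import scala.collection.mutable.ListBuffer

// An ordinary method, not a function value.
def sum(a: Int, b: Int): Int = a + b

// "sum _" replaces the whole parameter list: the result is a (Int, Int) => Int.
val sumFn = sum _

// Where a function type is expected, the underscore can be left off.
val sumFn2: (Int, Int) => Int = sum

// foreach expects a one-argument function, so the method name alone is enough.
val seen = ListBuffer.empty[Int]
def record(x: Int): Unit = seen += x
Array(1, 2, 3).foreach(record)
Array(1, 2, 3).foreach(println) // prints 1, 2 and 3 on separate lines
```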


Partial functions

Despite the similar name, partial functions have nothing to do with partially applied functions. If you have studied mathematics, however, you already know what they are: a partial function from X to Y is a function ƒ: X' → Y, where X' is a subset of X. In other words, the function is not defined for every element of X. The most obvious example is perhaps division, which maps a pair of real numbers to a real number but is not defined when the divisor is 0. More formally: ƒ(x, y) = x / y, where x, y ∈ ℝ and y ≠ 0.

Scala's partial functions do exactly the same: the function is defined only on a subset of the possible input arguments. Before we go into the details of how they are actually implemented, let's take a little detour into the world of patterns. Patterns are beyond the scope of this post, so for now it's enough to know that Scala uses them heavily, from variable definitions through for expressions to pattern matching. Most of the time you will encounter them inside a match expression as a series of cases, which let you select from a number of alternatives:

args(0) match {
  case "foo" => println("bar")
  case _ => println("?")
}

There are different kinds of patterns, but to simplify things let's just say that whatever follows the case keyword is the pattern, and if the expression matches it, the right-hand side of the arrow is executed. The match expression is not the only place where we can use case sequences, though. As a matter of fact, a case sequence in curly braces can be used anywhere a function literal can be used. Why? Because such a construction is itself a function literal, only a more general one: while a normal function or method has a single entry point with a given parameter list, a case sequence has multiple entry points, each with its own parameters (defined by its pattern). Another difference is that a normal function has exactly one body, while a case sequence has as many bodies as it has cases: the right-hand side of each case (or entry point) is a function body. This generalization leads us to partial functions, because such a case sequence is only defined on the inputs that match one of its patterns, and it is undefined for every other input.
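As a small illustration of a case sequence standing in for a function literal, the sketch below passes case sequences directly to map. The first one has several entry points, each with its own pattern and body; the second destructures a tuple argument. The sample lists are made up for the example.

```scala
// A case sequence with three entry points, used as the function passed to map.
val described = List[Any](1, "two", 3.0).map {
  case i: Int    => s"int $i"
  case s: String => s"string $s"
  case other     => s"something else: $other"
}

// A single case whose pattern destructures the tuple argument.
val sums = List((1, 2), (3, 4)).map { case (a, b) => a + b }
```

Here described becomes List("int 1", "string two", "something else: 3.0") and sums becomes List(3, 7).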

val div: (Double, Double) => Double = {
  case (x, y) if y != 0 => x / y
}

The above example is a possible implementation of the aforementioned division function. While it does not demonstrate the use of multiple cases, the single pattern (x, y) with the guard if y != 0 is enough to describe our partial function: it is defined for every pair of numbers except when the divisor is 0. What, then, is the result of the call div(1, 0)? It fails with a MatchError, because no case in our function matches an input whose divisor is 0.
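The following sketch makes the failure observable. It assumes scala.util.Try is an acceptable stand-in for an explicit try-catch block. Note that the guard matters here: for Double operands, 1.0 / 0.0 does not throw but evaluates to Double.PositiveInfinity, so without the guard the undefined input would silently produce a value.

```scala
import scala.util.{Failure, Try}

val div: (Double, Double) => Double = {
  case (x, y) if y != 0 => x / y
}

val ok = div(1, 2) // 0.5

// No case matches a zero divisor, so the call throws scala.MatchError.
val failed = Try(div(1, 0))
val isMatchError = failed match {
  case Failure(_: MatchError) => true
  case _                      => false
}
```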

Defining a sequence of cases this way has two small problems, though. The first is that there is no way to test (other than invoking the function itself) whether the input arguments are in the subset of valid values. To avoid exceptions we would have to put the function call into a try-catch block, which pollutes our code. The second is that when the Scala compiler knows all of the possible values the patterns could cover (for example, when matching on a sealed type), it emits a warning if we don't handle all of them. In general the compiler is absolutely right, since forgotten cases lead to runtime exceptions. In the case of partial functions, however, we deliberately handle only a subset of them.

Fortunately, Scala has an elegant way out of both problems. If we tell the compiler that we want to create a partial function, we can check whether the function is defined at the input arguments, and the compiler also stops complaining about non-exhaustive matches. So how do we let the compiler know that we really want a partial function? It's simple: we only have to declare the type of the function as PartialFunction. Here's the div function again, this time as a real partial function:

val div: PartialFunction[(Double, Double), Double] = {
  case (x, y) if y != 0 => x / y
}

PartialFunction[A, B] is a specialized Function1[A, B] (it extends the A => B trait), so div takes a single argument: a (Double, Double) tuple. It has all the methods Function1 has, plus some very handy new ones, for example isDefinedAt, which we can use to check whether the function is defined at the input arguments:

div.isDefinedAt((1.0, 0.0)) // returns false

By using this method we can avoid the try-catch block, and additionally it's more expressive than catching a MatchError.
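isDefinedAt is not the only helper. The sketch below shows isDefinedAt used as a guard (safeDiv is a hypothetical helper name), and three other standard PartialFunction-related methods: lift, which turns the partial function into a total one returning Option, orElse, which chains in a fallback (the NaN fallback is just an assumption for the example), and the collections' collect, which keeps only the inputs the partial function is defined at.

```scala
val div: PartialFunction[(Double, Double), Double] = {
  case (x, y) if y != 0 => x / y
}

// isDefinedAt as a guard instead of a try-catch block.
def safeDiv(x: Double, y: Double): Option[Double] =
  if (div.isDefinedAt((x, y))) Some(div((x, y))) else None

// lift: PartialFunction[A, B] becomes A => Option[B].
val lifted: ((Double, Double)) => Option[Double] = div.lift

// orElse: fall back to another partial function where div is undefined.
val fallback: PartialFunction[(Double, Double), Double] = { case _ => Double.NaN }
val divOrNaN = div.orElse(fallback)

// collect: apply div only where it is defined, dropping the other inputs.
val quotients = List((6.0, 3.0), (1.0, 0.0), (9.0, 3.0)).collect(div)
```

Here quotients is List(2.0, 3.0): the (1.0, 0.0) pair is skipped rather than raising a MatchError.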

As you get more and more familiar with Scala, you'll notice that partially applied functions and partial functions are used in many places. For instance, whenever you pass a method such as println to a higher order function, the compiler creates a partially applied function behind the scenes. And if you look at a try-catch block, you'll notice that the catch block is essentially a partial function. I hope this post helped to clear up the "myth" of partially applied and partial functions, and, should you find a use case for them, you can already start writing your own.
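To make the catch-block remark concrete, here is a minimal sketch: since Scala 2.10, catch accepts any expression of type PartialFunction[Throwable, T], so a handler can be defined once and reused. The names parseFailure and parseOrMinusOne, and the choice of -1 as a fallback value, are hypothetical.

```scala
// A catch handler is just a PartialFunction[Throwable, Int]: it handles only
// the exceptions its cases match; anything else propagates.
val parseFailure: PartialFunction[Throwable, Int] = {
  case _: NumberFormatException => -1
}

def parseOrMinusOne(s: String): Int =
  try s.toInt catch parseFailure

// Being a partial function, the handler can be queried like any other.
val handlesNfe = parseFailure.isDefinedAt(new NumberFormatException("x"))
val handlesNpe = parseFailure.isDefinedAt(new NullPointerException())
```

parseOrMinusOne("42") returns 42, while parseOrMinusOne("abc") returns -1; a NullPointerException would not be caught, because the handler is not defined for it.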