The tale of the flatMap

lianyungang

Background

There is a reason why flatMap and its cousin map deserve a special place in Scala folklore. It all starts with the combinator pattern popularized by the functional programming paradigm, which, according to its definition, is “a style of organizing libraries centered around the idea of combining things”. In this design pattern, there are usually a few ways to construct primitive values of type T and a few combinators (functions) that combine values of type T to build up more complex values of type T. Proponents of the combinator pattern argue that it promotes modularity and composability in a program.

Many real-world libraries are built using the combinator pattern, including the Scala collections and parser combinators, Dataset in Spark, queries in Slick and many domain-specific languages (DSLs).

If someone wants to leverage the combinator pattern, the challenge is to come up with both the type T and a set of combinators for T to encapsulate whatever computation the author wants to express. If the computation happens to be sequential by nature, that is where Monad comes in handy.

There might be a lot of fuss about Monad, but it is not much more than a type T with a few predefined combinator interfaces. At a high level, it allows us to 1) write a description of a computation and 2) apply a pure function to the result of the previous computation and transform it into the description of the next computation. With these two weapons at our disposal, we can take an initial description of a computation and sequentially transform it into the final description by applying a series of pure functions to it. As we can see, the key to this programming approach is how these pure functions get applied. flatMap defines exactly that.
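
To make the two ingredients concrete, here is a minimal sketch of what such an interface could look like. The trait Monad, its method names and the optionMonad instance are assumptions made up for illustration; they are not the definitions used by cats or Scalaz.

```scala
object MonadSketch {
  import scala.language.higherKinds

  // Hypothetical minimal interface, written for this post only.
  trait Monad[F[_]] {
    // 1) Describe a computation that simply yields `a`.
    def pure[A](a: A): F[A]
    // 2) Feed the result of `fa` into the pure function `f`
    //    to obtain the description of the next computation.
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
    // map can be derived from the two methods above.
    def map[A, B](fa: F[A])(f: A => B): F[B] =
      flatMap(fa)(a => pure(f(a)))
  }

  // A hand-written instance for Option.
  val optionMonad: Monad[Option] = new Monad[Option] {
    def pure[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] =
      fa match {
        case Some(a) => f(a)
        case None    => None
      }
  }

  // Start from an initial description and transform it with a pure function.
  val result: Option[Int] =
    optionMonad.flatMap(optionMonad.pure(20))(x => if (x != 0) Some(100 / x) else None)
}
```

Here MonadSketch.result evaluates to Some(5).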

In fact, sequential computations are so common that Monad has become the basis of many types that leverage the combinator pattern, to the point where many languages provide syntactic sugar to make a chain of flatMaps look less nested and easier on the eyes. In Scala, the keyword is for.

flatMap in different contexts

List

Let’s look at flatMap in the context of a Scala List.

scala> List(1, 2, 3).flatMap { i =>
     |   List(4, 5, 6).map { j =>
     |     (i, j)
     |   }
     | }
res0: List[(Int, Int)] = List((1,4), (1,5), (1,6), (2,4), (2,5), (2,6), (3,4), (3,5), (3,6))

If we change flatMap to map in the above example, the result becomes

res0: List[List[(Int, Int)]] = List(List((1,4), (1,5), (1,6)), List((2,4), (2,5), (2,6)), List((3,4), (3,5), (3,6)))

The intuition here seems simple: flatMap maps first and then flattens the resulting nested list.

But if we think further, what kind of computation does List[Int] really encapsulate? One way of looking at it is that a List[T] represents a value with multiple, equally likely possibilities. For example, List(1, 2, 3) could be considered a single value that has a 1/3 chance each of being 1, 2 or 3. In the above example, flatMap combines the elements of the two lists into a list of tuples, where each tuple carries its own chance of being the resulting value. For example, (1, 4) has a 1/9 chance of being the value of the final list, because 1 and 4 each have a 1/3 chance of being the value of the first and second list respectively.
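
The arithmetic behind this reading can be checked with a small sketch. The object and value names below are made up for illustration, and the "chance" is simply a count divided by the size of the list.

```scala
object ListAsPossibilities {
  // Every combination of one element from each list.
  val pairs: List[(Int, Int)] =
    List(1, 2, 3).flatMap(i => List(4, 5, 6).map(j => (i, j)))

  // Each of the 9 pairs is equally likely, so (1, 4) has a 1/9 chance.
  val chanceOfOneFour: Double =
    pairs.count(_ == (1, 4)).toDouble / pairs.size

  // The same reasoning gives the chance of the two values adding up to 7:
  // (1, 6), (2, 5) and (3, 4) are 3 of the 9 pairs.
  val chanceOfSumSeven: Double =
    pairs.count { case (i, j) => i + j == 7 }.toDouble / pairs.size
}
```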

If we rewrite the above example with the for syntax, it reads like a nested loop in an imperative language, implemented with the monadic semantics of List.

scala> for {
     |   i <- List(1, 2, 3)
     |   j <- List(4, 5, 6)
     | } yield (i, j)
res2: List[(Int, Int)] = List((1,4), (1,5), (1,6), (2,4), (2,5), (2,6), (3,4), (3,5), (3,6))

Fun fact: someone has actually implemented a special list whose elements do not have an equal chance of being the value.

Option

Option[T] represents a value that may or may not exist. In contrast to List[T], the semantics of flatMap in the context of Option[T] is to skip the rest of the computation and return None directly if the result of the previous computation is None, and to continue the computation otherwise. The definition of flatMap for Scala's Option is pretty straightforward:

@inline final def flatMap[B](f: A => Option[B]): Option[B] =
  if (isEmpty) None else f(this.get)
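
As a usage sketch of this short-circuiting behaviour, the following chains three steps that can each fail. The helpers parse, safeDiv and divide are hypothetical names invented for this example.

```scala
object OptionShortCircuit {
  // Hypothetical helper: parse a string into an Int if possible.
  def parse(s: String): Option[Int] =
    try Some(s.trim.toInt)
    catch { case _: NumberFormatException => None }

  // Hypothetical helper: division that yields None instead of throwing.
  def safeDiv(a: Int, b: Int): Option[Int] =
    if (b == 0) None else Some(a / b)

  // As soon as one step returns None, the remaining functions are never called.
  def divide(x: String, y: String): Option[Int] =
    parse(x).flatMap(a => parse(y).flatMap(b => safeDiv(a, b)))

  val ok: Option[Int]        = divide("10", "2")  // Some(5)
  val badInput: Option[Int]  = divide("ten", "2") // None, parse(x) failed
  val divByZero: Option[Int] = divide("10", "0")  // None, safeDiv failed
}
```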

Future

Future[T] represents a value that will be available in the future. The semantics of flatMap in the context of Future[T] is to wait until the value of the previous future becomes available and then pass it along to the function that creates the next future, essentially executing a chain of futures in sequence. Any exception along the way will surface in the resulting Future as well.

def flatMap[S](f: T => Future[S])(implicit executor: ExecutionContext): Future[S] = transformWith {
  case Success(s) => f(s)
  case Failure(_) => this.asInstanceOf[Future[S]]
}
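
A small sketch of this sequencing follows. The functions fetchUserId and fetchScore are hypothetical stand-ins for asynchronous calls, and the blocking await helper is there only to inspect the outcome in a demo; it is assumed that the global execution context is acceptable.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.Try

object FutureChaining {
  // Hypothetical asynchronous steps, made up for this example.
  def fetchUserId(name: String): Future[Int] = Future(name.length)

  def fetchScore(id: Int): Future[Int] =
    if (id > 0) Future(id * 10)
    else Future.failed(new IllegalArgumentException("unknown user"))

  // fetchScore is only invoked once fetchUserId has completed successfully.
  def scoreFor(name: String): Future[Int] =
    fetchUserId(name).flatMap(fetchScore)

  // Blocks the caller; for demonstration only.
  def await[A](f: Future[A]): Try[A] = Try(Await.result(f, 5.seconds))
}
```

With these definitions, FutureChaining.await(FutureChaining.scoreFor("alice")) yields a Success holding 50, while scoreFor("") ends up as a Failure carrying the IllegalArgumentException.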

Either and Validation

Either[Err, T] is sometimes referred to as the error monad; it is one of the ways in which functional programming simulates exception handling. The flatMap semantics in the context of Either[Err, T] is similar to that of Option[T] in that it skips the rest of the computation if an error occurs. The difference is that Option[T] represents all error cases as None, whereas Either[Err, T] goes a bit further by representing them with a value of type Err. Here is an example (Scala 2.12+):

scala> :paste
// Entering paste mode (ctrl-D to finish)

import scala.util.Try

val zeroE: Either[Throwable, Int] = Right(0)
val oneE: Either[Throwable, Int] = Right(10)
val twoE: Either[Throwable, Int] = Right(2)

val result: Either[Throwable, Int] = for {
  one          <- oneE
  two          <- twoE
  divideByZero <- Try { one / 0 }.toEither
} yield divideByZero + two

// Exiting paste mode, now interpreting.

import scala.util.Try
zeroE: Either[Throwable,Int] = Right(0)
oneE: Either[Throwable,Int] = Right(10)
twoE: Either[Throwable,Int] = Right(2)
result: Either[Throwable,Int] = Left(java.lang.ArithmeticException: / by zero)

Either[Err, T] is useful when our error handling strategy is to halt the execution and return whatever the error is. If we want to go a bit further than that and collect all the errors that have happened during the entire execution, we can use Validation (strictly speaking an applicative rather than a monad, since flatMap alone cannot see past the first error). Validation is not part of the Scala standard library; here is an example in Scalaz.
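
The difference between halting and collecting can be sketched with the standard library alone. The code below is a hand-rolled illustration, not the Scalaz Validation API: the Validated alias, the combine function and the validation rules are all assumptions made up for this example.

```scala
object AccumulatingErrors {
  // Hypothetical alias: errors are kept in a list so they can be concatenated.
  type Validated[A] = Either[List[String], A]

  // Hypothetical validation rules.
  def validateName(name: String): Validated[String] =
    if (name.nonEmpty) Right(name) else Left(List("name is empty"))

  def validateAge(age: Int): Validated[Int] =
    if (age >= 0) Right(age) else Left(List("age is negative"))

  // Combine two independent results, keeping the errors from both sides.
  def combine[A, B, C](va: Validated[A], vb: Validated[B])(f: (A, B) => C): Validated[C] =
    (va, vb) match {
      case (Right(a), Right(b)) => Right(f(a, b))
      case (Left(e1), Left(e2)) => Left(e1 ++ e2)
      case (Left(e), _)         => Left(e)
      case (_, Left(e))         => Left(e)
    }

  // flatMap-based: stops at the first Left, so later errors are never seen.
  def failFast(name: String, age: Int): Validated[(String, Int)] =
    for {
      n <- validateName(name)
      a <- validateAge(age)
    } yield (n, a)

  // combine-based: both checks always run and all errors are reported.
  def accumulate(name: String, age: Int): Validated[(String, Int)] =
    combine(validateName(name), validateAge(age))((n, a) => (n, a))
}
```

For the input ("", -1), failFast returns Left(List("name is empty")) whereas accumulate returns Left(List("name is empty", "age is negative")).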

State

State[S, T] monad is a way of simulating global mutable state in functional programming. The most concise way of demonstrating the semantics of flatMap in this context is the following piece of Haskell code:

(>>=) :: State s a -> (a -> State s b) -> State s b
(act1 >>= fact2) s = runState act2 is
    where (iv,is) = runState act1 s
          act2 = fact2 iv

>>= is flatMap in Haskell terms. Essentially, flatMap applies a pure function to the result of the previous computation, iv, and also passes along the updated state, is. In this way, it gives the feeling of mutable global state, but what really happens is that a new, updated state is created and passed along on every flatMap invocation.
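
The same idea can be transcribed into Scala as a rough sketch. The State class below is a hypothetical minimal version written for this post; it keeps the (value, state) ordering of the Haskell snippet, which differs from the (state, value) ordering used by cats.

```scala
object StateSketch {
  // Hypothetical minimal State: a function from a state to (value, new state).
  final case class State[S, A](run: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State { (s: S) =>
        val (iv, is) = run(s) // run the first action: intermediate value and state
        f(iv).run(is)         // build the next action from iv and run it on is
      }

    def map[B](f: A => B): State[S, B] =
      State { (s: S) =>
        val (a, s1) = run(s)
        (f(a), s1)
      }
  }

  // A counter action: returns the current number and increments the state.
  val next: State[Int, Int] = State((n: Int) => (n, n + 1))

  // Looks like two mutations, but each step receives a fresh state value.
  val twoTicks: State[Int, (Int, Int)] =
    for {
      a <- next
      b <- next
    } yield (a, b)

  val result: ((Int, Int), Int) = twoTicks.run(10) // ((10, 11), 12)
}
```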

The following is an example of the State[S, T] monad in the cats library.

scala> :paste
// Entering paste mode (ctrl-D to finish)

import cats._, cats.data._, cats.implicits._

type Stack = List[Int]
val pop = State[Stack, Int] {
  case x :: xs => (xs, x)
  case Nil     => sys.error("stack is empty")
}

def push(a: Int) = State[Stack, Unit] {
  case xs => (a :: xs, ())
}

def stackManip: State[Stack, Int] = for {
  _ <- push(3)
  a <- pop
  b <- pop
} yield(b)

stackManip.run(List(5, 8, 2, 1)).value

// Exiting paste mode, now interpreting.

import cats._
import cats.data._
import cats.implicits._
defined type alias Stack
pop: cats.data.State[Stack,Int] = cats.data.StateT@43944da8
push: (a: Int)cats.data.State[Stack,Unit]
stackManip: cats.data.State[Stack,Int]
res1: (Stack, Int) = (List(8, 2, 1),5)

stackManip shows state manipulation written in an imperative programming style.

Roll your own type that supports the for syntax

for in Scala is syntactic sugar, which means that it is expanded during the parser phase of compilation.

➜  mina scalac -Xshow-phases
    phase name  id  description
    ----------  --  -----------
        parser   1  parse source into ASTs, perform simple desugaring

               .....

Simply put, when the compiler encounters a for expression, it rewrites it mostly in terms of flatMap, map, withFilter and so on.

Take the following example:

for {
    i <- Some(1)
    j <- Some(2) if j > 0
} yield i + j

If we click “Desugar scala code…” in IntelliJ, it will be desugared into the following code.

Some.apply(1).flatMap((i: Int) => Some.apply(2).withFilter((j: Int) => j.>(0)).map((j: Int) => i.+(j)))

As we can see, for is transformed into a series of flatMap calls with a map at the end. Conditionals are translated into calls to the withFilter method.

If we want to write a type that supports the for syntax, we just need to implement the flatMap and map methods for it. Conditionals are only possible when withFilter is implemented as well.
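
As a sketch of what that takes, here is a small container that works with for, including an if guard. Maybe, Just and Empty are hypothetical names invented for this example; the type behaves much like Option.

```scala
object RollYourOwn {
  // Hypothetical container holding zero or one value.
  sealed trait Maybe[+A] {
    def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
      case Just(a) => f(a)
      case Empty   => Empty
    }

    def map[B](f: A => B): Maybe[B] = flatMap(a => Just(f(a)))

    // Only needed for `if` guards inside a for expression.
    def withFilter(p: A => Boolean): Maybe[A] = this match {
      case Just(a) if p(a) => this
      case _               => Empty
    }
  }
  final case class Just[+A](value: A) extends Maybe[A]
  case object Empty extends Maybe[Nothing]

  // The guard holds, so both values are combined.
  val sum: Maybe[Int] = for {
    i <- Just(1)
    j <- Just(2) if j > 0
  } yield i + j

  // The guard fails, so withFilter turns the second step into Empty.
  val filtered: Maybe[Int] = for {
    i <- Just(1)
    j <- Just(-2) if j > 0
  } yield i + j
}
```

Here sum is Just(3) and filtered is Empty. Removing withFilter from the trait would make the two guarded expressions fail to compile, while unguarded ones would still work.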

You wish it were that perfect

But it isn’t. Scala monad might not be the real monad.

By definition, a monad has to satisfy a few monad laws. In fact, people have written property-based tests in Scalaz to verify whether a type satisfies those laws. Since the for syntax is purely syntactic sugar, the fact that a type supports it doesn’t qualify that type as a monad. For example, Future[T] can cause side effects and might throw exceptions, which can easily break the associativity law, which basically says that the way the flatMap calls are nested shouldn’t change the result.
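
For reference, the three laws can be spot-checked on Option with a few sample values. This is only an illustrative sketch with made-up functions f and g; it is not a property-based test and not the Scalaz law-checking code.

```scala
object MonadLawsCheck {
  // Two arbitrary functions of the shape flatMap expects.
  val f: Int => Option[Int] = x => if (x > 0) Some(x * 2) else None
  val g: Int => Option[Int] = x => if (x < 100) Some(x + 1) else None

  val values: List[Int] = List(-1, 3, 60)
  val samples: List[Option[Int]] = None :: values.map(v => Some(v))

  // Left identity: wrapping a value and flatMapping f equals calling f directly.
  val leftIdentity: Boolean =
    values.forall(a => Some(a).flatMap(f) == f(a))

  // Right identity: flatMapping the constructor changes nothing.
  val rightIdentity: Boolean =
    samples.forall(m => m.flatMap(x => Some(x)) == m)

  // Associativity: the nesting of the flatMap calls does not matter.
  val associativity: Boolean =
    samples.forall(m => m.flatMap(f).flatMap(g) == m.flatMap(x => f(x).flatMap(g)))
}
```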

Also, the for syntax supports conditionals, pattern matching and variable assignment, which goes beyond what a monad requires.

Because of an implicit conversion, we can flatMap a List[T] with a function that returns an Option[T], which is not that strict by the monad definition.
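
A short sketch of this mixing follows. The helper half is a hypothetical function made up for the example. A strict monadic flatMap on List would only accept functions that return a List, yet both forms below compile.

```scala
object MixingListAndOption {
  // Hypothetical helper that returns an Option, not a List.
  def half(i: Int): Option[Int] =
    if (i % 2 == 0) Some(i / 2) else None

  // The Option results are treated as collections of zero or one element.
  val halves: List[Int] = List(1, 2, 3, 4).flatMap(i => half(i))

  // The same mixing is accepted inside a for expression.
  val viaFor: List[Int] = for {
    i <- List(1, 2, 3, 4)
    h <- half(i)
  } yield h
}
```

Both halves and viaFor evaluate to List(1, 2).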

This discrepancy has not stopped flatMap and Monad from being very popular in Scala, but it’s something interesting to be aware of.

Written on January 21, 2018