In this chapter, we’ll look at some of the functional programming features of Scala, specifically the ubiquitous map and flatMap functions. We’re interested in these because they’re closely related to the idea of monads, a key feature of functional programming.

Mapping Functions

You’ll see the map function on countless classes in Scala. It’s often described in the context of collections. Classes like List, Set, and Map all have it. For these, it applies a given function to each element in the collection, giving back a new collection based on the result of that function. You “map” some function over each element of your collection.
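As a quick illustration (our own example, not part of the original text), here is map applied to each of the three collection types just mentioned. Mapping over a Set can shrink it, because duplicate results collapse. Mapping over a Map with a function that returns a pair gives back another Map. The object and value names are just for illustration.

```scala
object CollectionMapping {
  // List: one output element per input element, order preserved
  val doubled: List[Int] = List(1, 2, 3).map(_ * 2)

  // Set: results go into a new Set, so duplicate results collapse
  val parities: Set[Int] = Set(1, 2, 3, 4).map(_ % 2)

  // Map: the function receives a (key, value) pair; returning a pair yields a Map
  val prices: Map[String, Int] = Map("apple" -> 3, "pear" -> 5)
  val discounted: Map[String, Int] =
    prices.map { case (fruit, price) => fruit -> (price - 1) }

  def main(args: Array[String]): Unit = {
    println(doubled)    // List(2, 4, 6)
    println(parities)   // two elements only: 1 and 0
    println(discounted) // Map(apple -> 2, pear -> 4)
  }
}
```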

For example, you could create a function that works out how old a person is today, given the year of their birth.

  import java.util.Calendar

  def age(birthYear: Int): Int = {
    // read the current year from the system clock
    val currentYear = Calendar.getInstance.get(Calendar.YEAR)
    currentYear - birthYear
  }

We could call the map function on a list of birth years, passing in the function to create a new list of ages.

  val birthdays = List(1990, 1977, 1984, 1961, 1973)
  birthdays.map(age)

The result would be a list of ages. Assuming it’s run in 2017, we can transform the year 1990 into age 27, for example.

  res0: List[Int] = List(27, 40, 33, 56, 44)

Because map is a higher-order function (it takes another function as its argument), you could also have written the function inline as a lambda like this:

  birthdays.map(year => Calendar.getInstance.get(Calendar.YEAR) - year)

Using the underscore as a shorthand for the lambda’s parameter, it would look like this:

  birthdays.map(Calendar.getInstance.get(Calendar.YEAR) - _)
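Both inline forms read Calendar once per element, and all three forms give results that depend on when the code runs. A sketch that fixes the year at 2017 (matching the assumption above) makes the forms easy to compare and reads the year only once. The helper ageIn is hypothetical and not part of the original example.

```scala
object AgeForms {
  val birthdays = List(1990, 1977, 1984, 1961, 1973)

  // fixed rather than read from Calendar, so the output is reproducible
  val currentYear = 2017

  // hypothetical helper: the current year is passed in as its own parameter list
  def ageIn(year: Int)(birthYear: Int): Int = year - birthYear

  val viaMethod = birthdays.map(ageIn(currentYear))         // method, eta-expanded to a function
  val viaLambda = birthdays.map(year => currentYear - year) // explicit lambda
  val viaUnderscore = birthdays.map(currentYear - _)        // placeholder shorthand

  def main(args: Array[String]): Unit =
    println(viaMethod) // List(27, 40, 33, 56, 44)
}
```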

It’s Like foreach

So, map is a transforming function. For collections, it iterates over the collection, applying some function to each element, just like foreach does. The difference is that, unlike foreach, map collects the return values from the function into a new collection and then returns that collection.
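A small side-by-side sketch makes the difference concrete: foreach returns Unit, so any results have to be accumulated by hand, whereas map hands back the new collection directly. The names here are illustrative only.

```scala
import scala.collection.mutable.ListBuffer

object ForeachVsMap {
  val numbers = List(1, 2, 3)

  // foreach runs the function only for its side effect, so we collect by hand
  val squaresByHand: List[Int] = {
    val buffer = ListBuffer[Int]()
    numbers.foreach(n => buffer += n * n)
    buffer.toList
  }

  // map collects each return value into a new List for us
  val squaresByMap: List[Int] = numbers.map(n => n * n)

  def main(args: Array[String]): Unit = {
    val discarded: Unit = numbers.foreach(n => n * n) // the squares are thrown away
    println(squaresByHand == squaresByMap) // true
  }
}
```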

It’s trivial to implement a mapping function by hand. For example, we could create a class Mappable that wraps a list of elements of type A and provides a map function.

  class Mappable[A](val elements: List[A]) {
    def map[B](f: Function1[A, B]): List[B] = ??? // body to be filled in
  }

The parameter to map is a function that transforms from type A to type B; it takes an A and returns a B. I’ve written it longhand as a Function1, which is the rough equivalent of Java’s java.util.function.Function interface. We can also write it using Scala’s shorthand syntax and the compiler will do the conversion for us.

  def map[B](f: A => B): List[B] = ...
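To see that the two spellings name the same type, this sketch declares a value with one and assigns it where the other is expected. It also shows that calling a function directly is shorthand for calling apply. The object and value names are illustrative.

```scala
object FunctionTypes {
  // longhand: the trait Function1[-T1, +R]
  val longhand: Function1[Int, String] = (n: Int) => s"#$n"

  // shorthand: Int => String is syntax for Function1[Int, String]
  val shorthand: Int => String = longhand

  // f(x) is sugar for f.apply(x)
  val viaApply: String = shorthand.apply(7)
  val viaCall: String = shorthand(7)

  def main(args: Array[String]): Unit =
    println(viaApply == viaCall) // true
}
```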

Then it’s just a question of creating a new collection and calling the function (using apply) with each element as the argument. We store each result in the new collection and finally return it.

  class Mappable[A](val elements: List[A]) {
    def map[B](f: A => B): List[B] = {
      val result = collection.mutable.ListBuffer[B]()
      // apply f to each element and store what it returns
      elements.foreach { element =>
        result += f.apply(element)
      }
      result.toList
    }
  }

We can test it by creating a list of numbers, making them “mappable” by creating a new instance of Mappable and calling map with an anonymous function that simply doubles the input.

  object Example extends App {
    val numbers: List[Int] = List(1, 2, 54, 4, 12, 43, 54, 23, 34)
    val mappable: Mappable[Int] = new Mappable(numbers)
    val result = mappable.map(_ * 2)
    println(result)
  }

The output would look like this:

  List(2, 4, 108, 8, 24, 86, 108, 46, 68)
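A mutable buffer isn’t the only way to write this. As an alternative sketch (a different technique from the one above, under a hypothetical class name), map can be built without mutation by folding from the right and prepending each transformed element, which keeps the original order.

```scala
class FoldMappable[A](val elements: List[A]) {
  // foldRight combines starting from the last element, so prepending with :: preserves order
  def map[B](f: A => B): List[B] =
    elements.foldRight(List.empty[B])((element, acc) => f(element) :: acc)
}

object FoldMappableExample {
  def main(args: Array[String]): Unit = {
    val numbers = List(1, 2, 54, 4, 12, 43, 54, 23, 34)
    println(new FoldMappable(numbers).map(_ * 2))
    // List(2, 4, 108, 8, 24, 86, 108, 46, 68)
  }
}
```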

FlatMap

You’ll often see the flatMap function where you see the map function. For collections, it’s very similar in that it maps a function over the collection, storing the result in a new collection, but with a couple of differences.

  • It still transforms, but this time the function applies a one-to-many transformation: it takes a single argument as before but returns multiple values (as a collection).

  • The result would therefore end up being a collection of collections, so flatMap also flattens the result to give a single collection.

Putting the two side by side:

  • For a given collection of A, the map function applies a function to each element transforming an A to B. The result is a collection of B (that is, List[B]).

  • For a given collection of A, the flatMap function applies a function to each element transforming an A to a collection of B. This results in a collection of collections of B (that is, List[List[B]]), which is then flattened to a single collection of B (that is, List[B]).
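Those two descriptions suggest an equivalence that’s easy to check: mapping with a function that returns a List and then calling flatten gives the same result as flatMap. The function surround below is just an illustrative one-to-many function.

```scala
object MapThenFlatten {
  // an illustrative one-to-many function: Int => List[Int]
  def surround(n: Int): List[Int] = List(n - 1, n, n + 1)

  val numbers = List(10, 20)

  val nested: List[List[Int]] = numbers.map(surround)   // List(List(9, 10, 11), List(19, 20, 21))
  val flattened: List[Int] = nested.flatten             // List(9, 10, 11, 19, 20, 21)
  val flatMapped: List[Int] = numbers.flatMap(surround) // the same as flattened

  def main(args: Array[String]): Unit =
    println(flattened == flatMapped) // true
}
```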

Let’s say we want a mapping function to return a person’s age plus or minus a year. So, if we think a person is 38, we’d return a list of 37, 38, 39.

  import java.util.Calendar

  def ageEitherSide(birthYear: Int): List[Int] = {
    val today = Calendar.getInstance.get(Calendar.YEAR)
    // age minus one, current age, age plus one
    List(today - 1 - birthYear, today - birthYear, today + 1 - birthYear)
  }

The signature has changed from the previous example to return a List[Int] rather than just an Int. If we pass the list of birth years into the map function, we get a list of lists back (res0 below).

  val birthdays = List(1990, 1977, 1984)


  scala> birthdays.map(ageEitherSide)
  res0: List[List[Int]] =
    List(List(26, 27, 28), List(39, 40, 41), List(32, 33, 34))

If, however, we pass it into the flatMap function, we get a flattened list back. It maps, then flattens.

  scala> birthdays.flatMap(ageEitherSide)
  res1: List[Int] = List(26, 27, 28, 39, 40, 41, 32, 33, 34)
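As with age, the REPL output above depends on the year the code runs in (2017 here). A sketch with the year passed in reproduces those results deterministically. ageEitherSideIn is a hypothetical variant, not the chapter’s function.

```scala
object AgeEitherSide {
  // hypothetical variant of ageEitherSide that takes the current year as a parameter
  def ageEitherSideIn(currentYear: Int)(birthYear: Int): List[Int] = {
    val age = currentYear - birthYear
    List(age - 1, age, age + 1)
  }

  val birthdays = List(1990, 1977, 1984)

  val nested: List[List[Int]] = birthdays.map(ageEitherSideIn(2017))
  val flat: List[Int] = birthdays.flatMap(ageEitherSideIn(2017))

  def main(args: Array[String]): Unit = {
    println(nested) // List(List(26, 27, 28), List(39, 40, 41), List(32, 33, 34))
    println(flat)   // List(26, 27, 28, 39, 40, 41, 32, 33, 34)
  }
}
```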

If you wanted to write your own version of flatMap, it might look something like this (notice the return type of the function):

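  class FlatMappable[A](elements: A*) {

    def flatMap[B](f: A => List[B]): List[B] = {
      val result = collection.mutable.ListBuffer[B]()
      elements.foreach { element =>
        // f returns a List[B], so loop over it to flatten into result
        f.apply(element).foreach { b =>
          result += b
        }
      }
      result.toList
    }
  }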

The first loop will enumerate the elements of the collection and apply the function to each. Because this function itself returns a list, another loop is needed to enumerate each of these, adding them into the result collection. This is the bit that flattens the function’s result.

To test it, let’s start by creating a function that goes from an Int to a collection of Int. It gives back all the odd numbers between zero and the argument.

  def oddNumbersTo(end: Int): List[Int] = {
    val odds = collection.mutable.ListBuffer[Int]()
    for (i <- 0 to end) {
      if (i % 2 != 0) odds += i
    }
    odds.toList
  }

We then just create an instance of our class with a few numbers in it. Call flatMap and you’ll see that all the odd numbers from 0 to 1, 0 to 2, and 0 to 10 are collected into a single list.

  object Example {
    def main(args: Array[String]): Unit = {
      val mappable = new FlatMappable(1, 2, 10)
      val result = mappable.flatMap(oddNumbersTo)
      println(result)
    }
  }

The output would be the following:

  List(1, 1, 1, 3, 5, 7, 9)
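For comparison, the standard library gets the same result. This sketch writes oddNumbersTo without a mutable buffer (using a Range with a step of 2, a swapped-in technique) and calls the built-in List.flatMap, which should agree with FlatMappable.

```scala
object OddNumbers {
  // the odd numbers from 1 up to and including end, as an immutable List
  def oddNumbersTo(end: Int): List[Int] = (1 to end by 2).toList

  val result: List[Int] = List(1, 2, 10).flatMap(oddNumbersTo)

  def main(args: Array[String]): Unit =
    println(result) // List(1, 1, 1, 3, 5, 7, 9)
}
```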

Not Just for Collections

We’ve seen how map and flatMap work for collections, but they also exist on many other classes. More generally, map and flatMap are operations on what are called monads. In fact, having map and flatMap behavior is one of the defining features of monads.
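Option is one familiar example outside the collections. This sketch uses a hypothetical helper, parsePositive. It shows map transforming the value inside a Some and leaving None alone. It also shows flatMap chaining a step that itself returns an Option without ending up with a nested Option[Option[Int]].

```scala
import scala.util.Try

object OptionExample {
  // hypothetical helper: Some(n) if s parses as a positive Int, otherwise None
  def parsePositive(s: String): Option[Int] = Try(s.toInt).toOption.filter(_ > 0)

  // map: transform the value if there is one; None stays None
  val doubled: Option[Int] = parsePositive("21").map(_ * 2)   // Some(42)
  val missing: Option[Int] = parsePositive("oops").map(_ * 2) // None

  // mapping with an Option-returning function nests; flatMap flattens
  val nested: Option[Option[Int]] = Some("8").map(parsePositive) // Some(Some(8))
  val flat: Option[Int] = Some("8").flatMap(parsePositive)       // Some(8)
  val flatNone: Option[Int] = Some("-3").flatMap(parsePositive)  // None

  def main(args: Array[String]): Unit =
    println(List(doubled, missing, nested, flat, flatNone))
}
```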

So just what are monads? We’ll look at that next.

