State monad

In imperative programming, we have the concept of global variables—variables that are available anywhere in the program. This approach is considered to be a bad practice, but is still used quite often. The concept of global state extends global variables by including system resources. As there is only one filesystem or system clock, it totally makes sense to make them globally and universally accessible from anywhere in the program code, right?

On the JVM, some of these global resources are available via the java.lang.System class. It contains, for instance, references to the "standard" input, output, and error streams, the system timer, environment variables, and properties. Global state must be a good idea, then, if Java exposes it at the language level!

The problem with global state is that it breaks the referential transparency of the code. In essence, referential transparency means that it should always be possible to replace a part of the code, for example, a function call, with the result of the evaluation of this call everywhere in the program, and this change should not cause observable changes in program behaviour.

The concept of referential transparency is closely related to the concept of a pure function: a function f is pure if the expression f(x) is referentially transparent for every referentially transparent argument x.

We will see how this works in a moment, but for starters, please consider the following example:

var globalState = 0

def incGlobal(count: Int): Int = {
  globalState += count
  globalState
}

val g1 = incGlobal(10) // g1 == 10
val g2 = incGlobal(10) // g2 == 20

incGlobal is not pure because it is not referentially transparent: we cannot replace a call to it with the result of its evaluation, since the result is different every time the function is called. This makes it impossible to reason about the possible outcomes of the program without knowing the global state at every moment it is accessed or modified.

In contrast, the following function is referentially transparent and pure:

def incLocal(count: Int, global: Int): Int = global + count

val l1 = incLocal(10, 0) // l1 == 10
val l2 = incLocal(10, 0) // l2 == 10
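Substitution can be tried out directly. In the sketch below (the comparison variables are ours, added for illustration), each program is written twice: once with two calls, and once with the call evaluated a single time and its value substituted for both calls. For incLocal the two versions agree; for incGlobal they do not, because the second call observes the state left behind by the first.

```scala
var globalState = 0

def incGlobal(count: Int): Int = {
  globalState += count
  globalState
}

def incLocal(count: Int, global: Int): Int = global + count

// Program 1: call incLocal twice.
val localTwice = incLocal(10, 0) + incLocal(10, 0)
// Program 2: evaluate once and substitute the value for both calls.
val localValue = incLocal(10, 0)
val localSubstituted = localValue + localValue
println(localTwice == localSubstituted) // true: substitution is safe

// The same experiment with incGlobal, each time from a fresh global state.
globalState = 0
val globalTwice = incGlobal(10) + incGlobal(10) // 10 + 20
globalState = 0
val globalValue = incGlobal(10)
val globalSubstituted = globalValue + globalValue // 10 + 10
println(globalTwice == globalSubstituted) // false: substitution changed behaviour
```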

In functional programming, we are expected to use only pure functions. This makes global state as a concept unsuitable for functional programming.

But there are still many cases where it is necessary to accumulate and modify state. How should we deal with those?

This is where the State monad comes into play. The State monad is built around a function that takes the relevant part of the global state as an argument and returns a result together with a modified state (of course, without changing anything in a global sense). The signature of such a function looks like this: type StatefulFunction[S, A] = S => (A, S).
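To see why a wrapper helps, here is such a stateful function used on its own. The id generator nextId is a hypothetical example of ours: it returns the current id and the next free number as the new state, and every call has to be handed the state produced by the previous call.

```scala
type StatefulFunction[S, A] = S => (A, S)

// Hypothetical id generator: the state is the next free number.
val nextId: StatefulFunction[Int, String] = n => (s"id-$n", n + 1)

// Threading the state by hand: each step is fed the state from the previous one.
val (first, s1) = nextId(0)
val (second, s2) = nextId(s1)
val (third, s3) = nextId(s2)

println((first, second, third, s3)) // (id-0,id-1,id-2,3)
```

Passing s1 twice, or forgetting to pass it on, would silently break the sequence; composing State values takes this bookkeeping off our hands.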

We can wrap this definition into a case class to simplify the definition of helper methods on it. This State class will denote our effect:

final case class State[S, A](run: S => (A, S))

We can also define a few constructors in the companion object so that we're able to create a state in three different situations (to do this in the REPL, use the :paste command, paste both the case class and the companion object, and then press Ctrl + D):

object State {
  def apply[S, A](a: => A): State[S, A] = State(s => (a, s))
  def get[S]: State[S, S] = State(s => (s, s))
  def set[S](s: => S): State[S, Unit] = State(_ => ((), s))
}

The default constructor lifts some value, a: A, into the context of State by returning the given argument as the result and propagating the existing state without changes. The getter creates a State whose function returns the incoming state both as the result and as the new state. The setter creates a State whose function ignores the incoming state, replaces it with the given s, and produces no meaningful result. Their semantics are similar to reading the global state (hence the result equals the state) and setting it (hence the result is Unit), but applied to s: S.
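To see the three constructors side by side, here is a sketch that runs each of them on the same initial state. We name the lifting constructor pure instead of overloading apply; that is our own choice, made only to keep overload resolution out of the way in a standalone snippet.

```scala
final case class State[S, A](run: S => (A, S))

object State {
  // Lifts a value; the state passes through untouched.
  def pure[S, A](a: => A): State[S, A] = State[S, A](s => (a, s))
  // Exposes the current state as the result.
  def get[S]: State[S, S] = State[S, S](s => (s, s))
  // Replaces the state, ignoring the old one; the result carries no information.
  def set[S](s: => S): State[S, Unit] = State[S, Unit](_ => ((), s))
}

val lifted = State.pure[Int, String]("boat").run(42) // == ("boat", 42)
val read = State.get[Int].run(42)                     // == (42, 42)
val written = State.set(7).run(42)                    // == ((), 7)
```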

For now, the State is nothing but a thin wrapper around some computation which involves pushing through (and potentially changing) a bit of state. What we would like to be able to do is compose this computation with the next one. We'd like to do this similarly to how we compose functions, but instead of (A => B) compose (B => C), we now have State[S, A] compose State[S, B]. How can we do this?

By definition, our second computation accepts the result of the first one as its argument, hence we start with (a: A) =>. We also stated that, as a result (because of the possible state change and the return type of the second state), we'll have a State[S, B], which gives us the full signature of the computation to compose with the first one: f: A => State[S, B].

We can implement this composition as a method on State:

final case class State[S, A](run: S => (A, S)) {
  def compose[B](f: A => State[S, B]): State[S, B] = {
    val composedRuns = (s: S) => {
      val (a, nextState) = run(s)
      f(a).run(nextState)
    }
    State(composedRuns)
  }
}

We define our composed computation as a combination of two runs. The first is done with the input provided to the first state, and its output is decomposed into the result and the next state. We then call the provided transformation f on the result and run the state it returns with the next state. These two successive runs might seem strange at first glance, but they just represent the fact that we're fusing two run functions from different states into one function defined on the composed state.
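A quick run makes the two successive runs visible. The steps tick and label below are hypothetical: tick returns the current counter and increments it, and label reports the value it received together with the state it runs on. The printed pair shows that label sees the state left by tick (4), not the original input (3).

```scala
final case class State[S, A](run: S => (A, S)) {
  def compose[B](f: A => State[S, B]): State[S, B] = {
    val composedRuns = (s: S) => {
      val (a, nextState) = run(s) // first run: original state in, result and next state out
      f(a).run(nextState)         // second run: continue from the state the first run left
    }
    State(composedRuns)
  }
}

// Hypothetical steps over an Int state.
val tick: State[Int, Int] = State[Int, Int](n => (n, n + 1))
def label(i: Int): State[Int, String] = State[Int, String](n => (s"step $i at state $n", n * 10))

val composed: State[Int, String] = tick.compose(label)
println(composed.run(3)) // (step 3 at state 4,40)
```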

Now, we have an effect and can create a monad for it. You should have noticed by now that the signature of the compose method we just defined is the same as that of the monadic flatMap.

The compose in this and the following cases does not refer to the function composition we learned about in Chapter 3, Deep Dive into Functions, but to the concept of Kleisli composition. A function of the shape A => F[B] is often called a Kleisli arrow, and Kleisli composition is what allows such functions, which return monadic values, to be composed. The operation with the signature of our compose method is frequently named >>= (bind), but we'll stick to compose here.
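Kleisli composition in the narrow sense glues two Kleisli arrows, A => State[S, B] and B => State[S, C], into a single A => State[S, C]. A minimal sketch built on the compose method; the name andThenK and the two example arrows are ours and not part of the chapter's code (in Haskell this operator is written >=>).

```scala
final case class State[S, A](run: S => (A, S)) {
  def compose[B](f: A => State[S, B]): State[S, B] =
    State[S, B] { s =>
      val (a, next) = run(s)
      f(a).run(next)
    }
}

// Kleisli composition: two "monadic" functions become one, via compose.
def andThenK[S, A, B, C](f: A => State[S, B], g: B => State[S, C]): A => State[S, C] =
  a => f(a).compose(g)

// Hypothetical arrows over a state that counts how many steps ran.
val double: Int => State[Int, Int] = x => State[Int, Int](n => (x * 2, n + 1))
val show: Int => State[Int, String] = x => State[Int, String](n => (s"value $x", n + 1))

val doubleThenShow: Int => State[Int, String] = andThenK(double, show)
println(doubleThenShow(21).run(0)) // (value 42,2)
```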

This allows us to delegate monadic behavior to the logic we already have in the State, the same way as we could do for standard effects:

import ch09._

// State[S, ?] is type-lambda syntax provided by the kind-projector compiler plugin
implicit def stateMonad[S]: Monad[State[S, ?]] = new Monad[State[S, ?]] {
  override def unit[A](a: => A): State[S, A] = State(a)
  override def flatMap[A, B](a: State[S, A])(f: A => State[S, B]): State[S, B] = a.compose(f)
}
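Without the kind-projector plugin, a type alias that fixes the state type gives the compiler a type constructor with a single hole. The sketch below assumes a minimal Monad trait (unit, flatMap, and a map derived from them) written out here so the snippet compiles on its own; the Counter alias and the generic pairUp function are hypothetical examples of ours.

```scala
trait Monad[F[_]] {
  def unit[A](a: => A): F[A]
  def flatMap[A, B](a: F[A])(f: A => F[B]): F[B]
  // map is derivable from flatMap and unit.
  def map[A, B](a: F[A])(f: A => B): F[B] = flatMap(a)(x => unit(f(x)))
}

final case class State[S, A](run: S => (A, S)) {
  def compose[B](f: A => State[S, B]): State[S, B] =
    State[S, B] { s =>
      val (a, next) = run(s)
      f(a).run(next)
    }
}

// Fixing the state type to Int leaves one type parameter open.
type Counter[A] = State[Int, A]

implicit val counterMonad: Monad[Counter] = new Monad[Counter] {
  def unit[A](a: => A): Counter[A] = State[Int, A](s => (a, s))
  def flatMap[A, B](a: Counter[A])(f: A => Counter[B]): Counter[B] = a.compose(f)
}

// Generic code only knows that it has some monad M.
def pairUp[M[_], A](ma: M[A])(implicit M: Monad[M]): M[(A, A)] =
  M.flatMap(ma)(x => M.map(ma)(y => (x, y)))

val next: Counter[Int] = State[Int, Int](n => (n, n + 1))
println(pairUp[Counter, Int](next).run(5)) // ((5,6),7)
```

Running next twice in sequence yields 5 and then 6, and the final state is 7, which shows that the generic code threads the state through the instance.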

Luckily, we can also delegate the lifting done by the unit to the default constructor! This means that we're done with the definition of the monad and can continue with our rigorous testing approach by specifying a property check for it.

Except in this case, we won't.

The rationale behind this is that the State is quite different from the other effects we have looked at so far with regard to the value it incorporates. The State is the first effect which is built exclusively around a function. Technically, because functions are first-class values in Scala, other effects such as Option could also contain a function rather than a plain value, but that is the exception rather than the rule.

This complicates our testing attempts. Earlier, we modified the value contained in the effect in different ways and checked, by comparing them, that the results were equal, as required by the monadic laws. With a function as the value of the effect, we face the challenge of comparing two functions for equality. At the time of writing this book, this is a topic of active academic research. For our practical purposes, there is currently no way to prove that two functions are equal other than calling them with every possible input and checking whether they return the same results, which we obviously cannot afford to do in our properties.
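What we can afford is a weaker check: run both sides of a law on a finite set of sample states and compare the pairs they produce. The sketch below does this for the identity law stated as flatMap(as)(unit) == as; the helper agreeOn, the sample range, and the example as are arbitrary choices of ours. Note that comparing the State values directly only compares function references, and agreement on samples is evidence, not proof.

```scala
final case class State[S, A](run: S => (A, S)) {
  def compose[B](f: A => State[S, B]): State[S, B] =
    State[S, B] { s =>
      val (a, next) = run(s)
      f(a).run(next)
    }
}

def unit[S, A](a: => A): State[S, A] = State[S, A](s => (a, s))

// Compare two State values by running them on each sample state.
def agreeOn[S, A](samples: Seq[S])(x: State[S, A], y: State[S, A]): Boolean =
  samples.forall(s => x.run(s) == y.run(s))

val as: State[Int, String] = State[Int, String](n => (s"saw $n", n * 2 + 1))
val leftSide = as.compose(a => unit[Int, String](a))

println(leftSide == as)                 // false: the wrapped functions are different objects
println(agreeOn(-3 to 3)(leftSide, as)) // true: same outputs on every sample
```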

Instead, we will prove that our implementation is correct. We will use a method called the substitution model for this. The essence of the method is to rely on referential transparency and repeatedly substitute variables and function calls with the values they evaluate to, until the resulting code can't be simplified anymore, very much like solving an algebraic equation.

Let's see how this works.

To get us prepared before proving the monadic laws, we'll prove a useful lemma first.

The lemma is stated as follows: given as: M[A] and f: A => M[B] with M = State, such that as.run = s => (a, s1) (that is, the run method returns the pair of a and s1 for some input s) and f(a) = State(s1 => (b, s2)), M.flatMap(as)(f) will always yield State(s => (b, s2)).

This is how we arrive at this formula:

  1. By definition, as.run = s => (a, s1), which gives us as = State(s => (a, s1))
  2. The flatMap delegates to the compose method defined on State, and therefore M.flatMap(as)(f) for M = State becomes as.compose(f)
  3. In terms of as and f, as.compose(f) can be formulated as State(s => (a, s1)).compose(f)

Now, we're going to substitute the call of the compose method with its definition:

State(s => (a, s1)).compose(f) = State(s => {
  val (a, s1) = run(s) // the first run yields the pair (a, s1)
  f(a).run(s1)
}) = State(s => {
  State(s1 => (b, s2)).run(s1) // substituting f(a) with the result of the call
}) = State(s => (b, s2))

Here, we have proved our assumption that Monad[State].flatMap(as)(f) = State(s => (b, s2)) for as = State(s => (a, s1)) and f(a) = State(s1 => (b, s2)).
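The lemma can also be spot-checked with concrete values. In the sketch below (all values are ours, chosen for illustration), as takes s to s1 = s + 1 and yields a = "fish", and f(a) takes s1 to s2 = s1 * 10 and yields b = a.length. The lemma predicts that the composition goes straight from s to (b, s2), that is, s => (4, (s + 1) * 10).

```scala
final case class State[S, A](run: S => (A, S)) {
  def compose[B](f: A => State[S, B]): State[S, B] =
    State[S, B] { s =>
      val (a, next) = run(s)
      f(a).run(next)
    }
}

// as.run = s => (a, s1) with a = "fish" and s1 = s + 1
val as: State[Int, String] = State[Int, String](s => ("fish", s + 1))
// f(a).run = s1 => (b, s2) with b = a.length and s2 = s1 * 10
val f: String => State[Int, Int] = a => State[Int, Int](s1 => (a.length, s1 * 10))

val composed = as.compose(f)
// What the lemma predicts: State(s => (b, s2))
val predicted = State[Int, Int](s => (4, (s + 1) * 10))

val samples = 0 to 5
println(samples.forall(s => composed.run(s) == predicted.run(s))) // true
```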

Now, we can use this lemma while proving the monadic laws for State.

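We'll start with the identity laws, and more specifically, with the left identity (the property that flatMap with unit gives back the original effect, which many other sources call right identity). This is how we formulated it in our ScalaCheck property: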

val leftIdentity = forAll { as: M[A] =>
  M.flatMap(as)(M.unit(_)) == as
}

Thus, we want to prove that, if we let M = State, then for every as: M[A] the following always holds:

M.flatMap(as)(M.unit(_)) == as

Let's simplify the left side of the equation first. By definition, we can replace as with its State representation:

M.flatMap(State(s => (a, s1)))(M.unit(_))

The next step that we must do is substitute the call of the unit method with its implementation. We're just delegating to the default constructor of the State, which is defined as follows:

 def apply[S, A](a: => A): State[S, A] = State(s => (a, s))

Hence, our definition becomes the following:

 M.flatMap(State(s => (a, s1)))(b => State(s1 => (b, s1)))

To substitute the flatMap call, we have to recall that all it does is just delegate to the compose method defined on State:

State(s => (a, s1)).compose(b => State(s1 => (b, s1)))

Now, we can use our lemma for state composition, which gives us the following simplified form:

State(s => (a, s1))

This can't be simplified further, so we will now take a look at the right side of the equation, as. Again, by definition, as can be represented as State(s => (a, s1)). This gives us the final proof that State(s => (a, s1)) == State(s => (a, s1)), which always holds for any a: A.

The right identity law is proved similarly to the left one, and we leave this as an exercise for the reader.

The second law we need to prove is the associative law. Let's recall how it is described in ScalaCheck terms:

forAll((as: M[A], f: A => M[B], g: B => M[C]) => {
  val leftSide = M.flatMap(M.flatMap(as)(f))(g)
  val rightSide = M.flatMap(as)(a => M.flatMap(f(a))(g))
  leftSide == rightSide
})

Let's see what we can do with that, starting with the leftSide, M.flatMap(M.flatMap(as)(f))(g).

By substituting M with State in the inner part, M.flatMap(as)(f) becomes State(s => (a, s1)).compose(f), which our lemma transforms into State(s => (b, s2)).

Now, we can substitute the outer flatMap:

M.flatMap(State(s => (b, s2)))(g) is the same as State(s => (b, s2)).compose(g) (1)

Let's leave it in this form and look at the rightSide, M.flatMap(as)(a => M.flatMap(f(a))(g)).

First, we substitute the inner flatMap with compose, turning a => M.flatMap(f(a))(g) into (a: A) => f(a).compose(g).

Now, by the definition of f we used for the left side, we have f(a) = State(s1 => (b, s2)), and thus the inner function becomes a => State(s1 => (b, s2)).compose(g).

Replacing the outer flatMap with compose and combining it with the previous definition gives us State(s => (a, s1)).compose(a => State(s1 => (b, s2)).compose(g)).

We'll use our lemma again to substitute the first application of compose, which gives us State(s => (b, s2)).compose(g) as the outcome. (2)

(1) and (2) are identical, which means that the leftSide and rightSide of our property are always equal; we have just proved the associative law.
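As with the identity law, the proof can be complemented by a quick pointwise check on concrete values. The as, f, and g below are arbitrary examples of ours; both ways of nesting compose produce the same pairs on every sample state, as the proof says they must.

```scala
final case class State[S, A](run: S => (A, S)) {
  def compose[B](f: A => State[S, B]): State[S, B] =
    State[S, B] { s =>
      val (a, next) = run(s)
      f(a).run(next)
    }
}

// Arbitrary example steps, each changing the state differently.
val as: State[Int, Int] = State[Int, Int](s => (s, s + 1))
val f: Int => State[Int, String] = a => State[Int, String](s => (s"a=$a", s * 2))
val g: String => State[Int, Int] = b => State[Int, Int](s => (b.length + s, s - 3))

// leftSide: compose as with f first, then compose the result with g.
val leftSide = as.compose(f).compose(g)
// rightSide: compose as with a function that composes f(a) with g.
val rightSide = as.compose(a => f(a).compose(g))

val samples = -5 to 5
println(samples.forall(s => leftSide.run(s) == rightSide.run(s))) // true
```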

Great, we have an implementation of the State and the corresponding monad, which has been proven to be correct. It's time to look at them in action. As an example, let's imagine that we're going fishing by boat. The boat has a position and direction, and can go forward for some time or change direction:

final case class Boat(direction: Double, position: (Double, Double)) {
  def go(speed: Float, time: Float): Boat = ??? // please see the accompanying code
  def turn(angle: Double): Boat = ??? // please see the accompanying code
}

We could go around with this boat by calling its methods:

scala> import ch09._
import ch09._
scala> val boat = Boat(0, (0d, 0d))
boat: Boat = Boat(0.0,(0.0,0.0))
scala> boat.go(10, 5).turn(0.5).go(20, 20).turn(-0.1).go(1,1)
res1: Boat = Boat(0.4,(401.95408575015193,192.15963378398988))
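The accompanying code for go and turn is not reproduced in this section. For experimenting outside the project, here is one implementation that is consistent with the REPL output above; it is our assumption, not the book's code: direction is an angle in radians, go moves the boat speed * time units along that heading, and turn adds the angle to the direction.

```scala
// Assumed semantics, chosen to match the REPL output; not the accompanying code.
final case class Boat(direction: Double, position: (Double, Double)) {
  def go(speed: Float, time: Float): Boat = {
    val distance = speed * time
    val (x, y) = position
    copy(position = (x + distance * math.cos(direction), y + distance * math.sin(direction)))
  }
  def turn(angle: Double): Boat = copy(direction = direction + angle)
}

val boat = Boat(0, (0d, 0d))
val moved = boat.go(10, 5).turn(0.5).go(20, 20).turn(-0.1).go(1, 1)
println(moved) // direction 0.4, position roughly (401.95, 192.16)
```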

There is a problem with this approach, though: it does not include fuel consumption. Unfortunately, this aspect was not envisioned when the boat's navigation was developed, and was added later as global state. We will now refactor this using the state monad. If the quantity of fuel is modelled as a number of litres, the most straightforward way to define the state is as follows:

type FuelState = State[Float, Boat]

Now, we can define our boat moving logic that takes fuel consumption into account. But before doing that, we are going to simplify the syntax of our monadic calls a bit. Currently, the flatMap and map methods of our Monad take two parameters—the container and the function to apply to the container.

We would like to create a wrapper that will incorporate both the effect and a monad so that we have an instance of the effect and only need to pass the transforming function to the mapping methods. This is how we can express this approach:

object lowPriorityImplicits {
  implicit class MonadF[A, F[_] : Monad](val value: F[A]) {
    private val M = implicitly[Monad[F]]
    def unit(a: A) = M.unit(a)
    def flatMap[B](fab: A => F[B]): F[B] = M.flatMap(value)(fab)
    def map[B](fab: A => B): F[B] = M.map(value)(fab)
  }
}

The implicit conversion MonadF will wrap any effect, F[A], as soon as there is an implicit monad definition available for F. Because it holds the value, MonadF can pass it as the first argument to the flatMap and map methods defined on the monad; thus, on MonadF, they are reduced to higher-order functions taking a single parameter. By importing this implicit conversion, we can now call flatMap and map directly on State:

State[Float, Boat](boat).flatMap((boat: Boat) => State[Float, Boat](???))
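The mechanics of MonadF are easiest to observe on an effect that has no flatMap or map of its own, so that the conversion is the only way a for-comprehension over it can compile. In the sketch below, Box and its monad instance are hypothetical, and the Monad trait is a minimal stand-in for the one defined earlier in the chapter.

```scala
trait Monad[F[_]] {
  def unit[A](a: => A): F[A]
  def flatMap[A, B](a: F[A])(f: A => F[B]): F[B]
  def map[A, B](a: F[A])(f: A => B): F[B] = flatMap(a)(x => unit(f(x)))
}

// Hypothetical effect with no methods of its own.
final case class Box[A](content: A)

implicit val boxMonad: Monad[Box] = new Monad[Box] {
  def unit[A](a: => A): Box[A] = Box(a)
  def flatMap[A, B](a: Box[A])(f: A => Box[B]): Box[B] = f(a.content)
}

object lowPriorityImplicits {
  implicit class MonadF[A, F[_]: Monad](val value: F[A]) {
    private val M = implicitly[Monad[F]]
    def flatMap[B](fab: A => F[B]): F[B] = M.flatMap(value)(fab)
    def map[B](fab: A => B): F[B] = M.map(value)(fab)
  }
}
import lowPriorityImplicits._

// Box has no flatMap or map, so the compiler wraps each Box in MonadF here.
val result: Box[Int] = for {
  a <- Box(20)
  b <- Box(22)
} yield a + b

println(result) // Box(42)
```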

We also need to create pure functions that will take fuel consumption into account while moving the boat. Assuming that we can't change the original definition of Boat, we have to pass the boat as a parameter to these functions:

lazy val consumption = 1f

def consume(speed: Float, time: Float) = consumption * time * speed

def turn(angle: Double)(boat: Boat): FuelState =
  State(boat.turn(angle))

def go(speed: Float, time: Float)(boat: Boat): FuelState =
  new State(fuel => {
    val newFuel = fuel - consume(speed, time)
    (boat.go(speed, time), newFuel)
  })

The consume function calculates fuel consumption based on speed and time. In the turn function, we take a boat, turn it by the specified angle (by delegating to Boat.turn), and return the result as an instance of FuelState.

A similar approach is used in the go method: to compute the boat's position, we delegate to the boat logic. To compute the new amount of fuel available, we subtract the consumed fuel from the initial fuel quantity (which is passed in as the state) and return the result as part of the state.
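Here is a self-contained version of the two fuel-aware steps that can be run one at a time. The Boat movement model (speed * time units along the heading, in radians) is our assumption, chosen to agree with the REPL output shown earlier; we also build State through its case-class constructor so the snippet needs no companion object.

```scala
final case class Boat(direction: Double, position: (Double, Double)) {
  // Assumed movement model: speed * time units along the heading (radians).
  def go(speed: Float, time: Float): Boat = {
    val d = speed * time
    copy(position = (position._1 + d * math.cos(direction), position._2 + d * math.sin(direction)))
  }
  def turn(angle: Double): Boat = copy(direction = direction + angle)
}

final case class State[S, A](run: S => (A, S))
type FuelState = State[Float, Boat]

lazy val consumption = 1f
def consume(speed: Float, time: Float): Float = consumption * time * speed

// Turning is free: the boat changes, the fuel passes through untouched.
def turn(angle: Double)(boat: Boat): FuelState =
  State[Float, Boat](fuel => (boat.turn(angle), fuel))

// Moving burns fuel in proportion to speed and time.
def go(speed: Float, time: Float)(boat: Boat): FuelState =
  State[Float, Boat](fuel => (boat.go(speed, time), fuel - consume(speed, time)))

val start = Boat(0, (0d, 0d))
println(go(10, 5)(start).run(100f)) // (Boat(0.0,(50.0,0.0)),50.0)
println(turn(0.5)(start).run(100f)) // (Boat(0.5,(0.0,0.0)),100.0)
```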

We can finally create the same chain of actions we had defined initially, but this time by tracking fuel consumption:

import Monad.lowPriorityImplicits._

def move(boat: Boat) = State[Float, Boat](boat).
  flatMap(go(10, 5)).
  flatMap(turn(0.5)).
  flatMap(go(20,20)).
  flatMap(turn(-0.1)).
  flatMap{b: Boat => go(1,1)(b)}

If you compare this snippet with the original definition, you'll see that the path of the boat is the same. However, much more is happening behind the scenes. Each call of flatMap passes the state along; this is how it is defined in the code of the monad, which in our case delegates to the compose method defined on State. The function given as a parameter to flatMap describes what should happen with the result and possibly with the passed state. In a sense, using monads gives us a separation of responsibilities: the monad describes what should happen between computation steps as the result of one step is passed to the next, and our logic describes what should happen with the result before it is passed over to the next computation.

We defined our logic with partially applied functions, which obscure what is really happening a bit; to make it obvious, the last step is defined using explicit syntax. We could also make the process of passing results between steps more explicit by using a for-comprehension:

def move(boat: Boat) = for {
  a <- State[Float, Boat](boat)
  b <- go(10,5)(a)
  c <- turn(0.5)(b)
  d <- go(20, 20)(c)
  e <- turn(-0.1)(d)
  f <- go(1,1)(e)
} yield f

The approach is the same as before; only the syntax has changed. Now passing the boat between steps is done explicitly, while the state passing has visually disappeared: the for-comprehension makes monadic code look imperative. This is the result of executing both of these approaches:

scala> println(move(boat).value.run(1000f))
(Boat(0.4,(401.95408575015193,192.15963378398988)),549.0)
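Both versions of move can be run side by side in a self-contained sketch. Two simplifications are ours: flatMap and map are defined directly on State instead of coming from the MonadF conversion, and the Boat movement model (speed * time units along the heading, in radians) is an assumption. The remaining fuel matches the 549.0 printed above, since 50 + 400 + 1 = 451 litres are consumed out of 1000.

```scala
final case class Boat(direction: Double, position: (Double, Double)) {
  // Assumed movement model: speed * time units along the heading (radians).
  def go(speed: Float, time: Float): Boat = {
    val d = speed * time
    copy(position = (position._1 + d * math.cos(direction), position._2 + d * math.sin(direction)))
  }
  def turn(angle: Double): Boat = copy(direction = direction + angle)
}

// Shortcut: flatMap and map live on State itself.
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B] { s =>
      val (a, next) = run(s)
      f(a).run(next)
    }
  def map[B](f: A => B): State[S, B] = flatMap(a => State[S, B](s => (f(a), s)))
}
type FuelState = State[Float, Boat]

def turn(angle: Double)(boat: Boat): FuelState = State[Float, Boat](fuel => (boat.turn(angle), fuel))
def go(speed: Float, time: Float)(boat: Boat): FuelState =
  State[Float, Boat](fuel => (boat.go(speed, time), fuel - speed * time))
def lift(boat: Boat): FuelState = State[Float, Boat](fuel => (boat, fuel))

// Chained flatMap calls with partially applied steps.
def moveChained(boat: Boat): FuelState =
  lift(boat).flatMap(go(10, 5)).flatMap(turn(0.5)).flatMap(go(20, 20)).flatMap(turn(-0.1)).flatMap(go(1, 1))

// The same path written as a for-comprehension.
def moveFor(boat: Boat): FuelState = for {
  a <- lift(boat)
  b <- go(10, 5)(a)
  c <- turn(0.5)(b)
  d <- go(20, 20)(c)
  e <- turn(-0.1)(d)
  f <- go(1, 1)(e)
} yield f

val start = Boat(0, (0d, 0d))
println(moveChained(start).run(1000f)._2) // 549.0
println(moveFor(start).run(1000f)._2)     // 549.0
```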

How can we be sure that the state has been passed correctly? Well, this is what the monad laws guarantee. For the curious, we can even manipulate the state using the methods we've defined in the state's companion object:

def logFuelState(f: Float) = println(s"Current fuel level is $f")

def loggingMove(boat: Boat) = for {
  a <- State[Float, Boat](boat)
  f1 <- State.get[Float]
  _ = logFuelState(f1)
  _ <- State.set(Math.min(700, f1))
  b <- go(10,5)(a)
  f2 <- State.get[Float]; _ = logFuelState(f2)
  c <- turn(0.5)(b)
  f3 <- State.get[Float]; _ = logFuelState(f3)
  d <- go(20, 20)(c)
  f4 <- State.get[Float]; _ = logFuelState(f4)
  e <- turn(-0.1)(d)
  f5 <- State.get[Float]; _ = logFuelState(f5)
  f <- go(1,1)(e)
} yield f

We augmented our previous for-comprehension with logging statements that output the current state after each step. These are the statements of the form:

f1 <- State.get[Float]
_ = logFuelState(f1)

Does it feel like we're really reading some global state? In reality, we're getting the current state as a result (this is how we defined State.get earlier), which is then passed over to the next computation, the logging statement. Further computations just use the results of the previous steps explicitly, just as they did before.

Using this technique, we're also modifying the state:

  _ <- State.set(Math.min(700, f1))

Here, we're simulating a boat whose fuel tank has a maximum capacity of 700 litres. We do this by first reading the current state and then setting back whichever is smaller: the state passed by the caller of the run method or our tank capacity. The State.set method returns Unit, which is why we ignore its result.

The output of the definition augmented with the logging looks like this:

scala> println(loggingMove(boat).value.run(1000f))
Current fuel level is 1000.0
Current fuel level is 650.0
Current fuel level is 650.0
Current fuel level is 250.0
Current fuel level is 250.0

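As we can see, the limit of 700 was applied before the first movement of the boat: the first line shows the 1,000 litres passed to run, and the next reading is 700 - 50 = 650 rather than 950.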
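Stripped of the boat, the capping trick reduces to get followed by set. The sketch below uses a hypothetical burn step in place of go, with the same amounts of fuel as the path above (50, 400, and 1 litres), and returns the fuel readings as a list instead of printing them; flatMap and map are placed directly on State for brevity.

```scala
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B] { s =>
      val (a, next) = run(s)
      f(a).run(next)
    }
  def map[B](f: A => B): State[S, B] = flatMap(a => State[S, B](s => (f(a), s)))
}

object State {
  def get[S]: State[S, S] = State[S, S](s => (s, s))
  def set[S](s: => S): State[S, Unit] = State[S, Unit](_ => ((), s))
}

// Hypothetical fuel-only step: burns a fixed amount and reports what is left.
def burn(litres: Float): State[Float, Float] =
  State[Float, Float](fuel => (fuel - litres, fuel - litres))

val trip: State[Float, List[Float]] = for {
  f1 <- State.get[Float]
  _  <- State.set(math.min(700f, f1)) // cap the tank before leaving
  a  <- burn(50)
  b  <- burn(400)
  c  <- burn(1)
} yield List(f1, a, b, c)

println(trip.run(1000f)) // (List(1000.0, 650.0, 250.0, 249.0),249.0)
```

With an initial 300 litres the cap does not apply, and the final state is 300 - 451 = -151, since nothing in this sketch stops the tank from going negative.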

There is still an issue with our implementation of move—it uses hardcoded go and turn functions as if we would only be able to navigate one specific boat. However, this is not the case—we should be able to do this with any boat which has go and turn functionality, even if they are implemented slightly differently. We could model this by passing the go and turn functions as parameters to the move method:

def move(
  go: (Float, Float) => Boat => FuelState,
  turn: Double => Boat => FuelState
)(boat: Boat): FuelState

This definition allows us to have different implementations of the go and turn functions in different situations, while still steering the boat along the given hardcoded path.

If we look carefully, we'll see that after creating the initial wrapper over the provided boat parameter, the definition of the move method has no further notion of the State—we need it to be a monad to be able to use for-comprehension, but this requirement is much more generic than the State we currently have.

We can make the definition of the move function generic by improving on these two aspects: passing the effect in instead of creating it, and making the method polymorphic:

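// requires Monad.lowPriorityImplicits._ in scope, so that flatMap and map work on M[A]
def move[A, M[_]: Monad](
  go: (Float, Float) => A => M[A],
  turn: Double => A => M[A]
)(boat: M[A]): M[A] = for {
  a <- boat
  b <- go(10,5)(a)
  c <- turn(0.5)(b)
  d <- go(20, 20)(c)
  e <- turn(-0.1)(d)
  f <- go(1,1)(e)
} yield f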

Now, we can follow the given path with any type which has a monad and the go and turn functions with specified signatures. Given the fact that this functionality is now generic, we can also move it into the Boat companion object along with the definition of the default boat.

Let's see how this approach works together with the state monad. It turns out that our definition of the go and turn methods does not need to change at all. All we need to do is call the new generic move method:

import Boat.{move, boat}
println(move(go, turn)(State(boat)).run(1000f))
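Put together, the generic move and the State-based steps fit in one self-contained sketch. The Monad trait, the MonadF conversion, the Boat movement model (speed * time units along the heading, in radians), and the FuelState alias below are minimal stand-ins we wrote so the snippet compiles on its own; the chapter's versions live in ch09 and the accompanying code. Passing the type arguments to move explicitly spares the compiler from inferring M.

```scala
trait Monad[F[_]] {
  def unit[A](a: => A): F[A]
  def flatMap[A, B](a: F[A])(f: A => F[B]): F[B]
  def map[A, B](a: F[A])(f: A => B): F[B] = flatMap(a)(x => unit(f(x)))
}

object syntax {
  implicit class MonadF[A, F[_]: Monad](val value: F[A]) {
    def flatMap[B](f: A => F[B]): F[B] = implicitly[Monad[F]].flatMap(value)(f)
    def map[B](f: A => B): F[B] = implicitly[Monad[F]].map(value)(f)
  }
}
import syntax._

final case class Boat(direction: Double, position: (Double, Double)) {
  // Assumed movement model: speed * time units along the heading (radians).
  def go(speed: Float, time: Float): Boat = {
    val d = speed * time
    copy(position = (position._1 + d * math.cos(direction), position._2 + d * math.sin(direction)))
  }
  def turn(angle: Double): Boat = copy(direction = direction + angle)
}

// The generic path: it only needs some monad M and the two steps.
def move[A, M[_]: Monad](go: (Float, Float) => A => M[A], turn: Double => A => M[A])(boat: M[A]): M[A] =
  for {
    a <- boat
    b <- go(10, 5)(a)
    c <- turn(0.5)(b)
    d <- go(20, 20)(c)
    e <- turn(-0.1)(d)
    f <- go(1, 1)(e)
  } yield f

final case class State[S, A](run: S => (A, S))
type FuelState[A] = State[Float, A]

implicit val fuelStateMonad: Monad[FuelState] = new Monad[FuelState] {
  def unit[A](a: => A): FuelState[A] = State[Float, A](s => (a, s))
  def flatMap[A, B](a: FuelState[A])(f: A => FuelState[B]): FuelState[B] =
    State[Float, B] { s =>
      val (x, next) = a.run(s)
      f(x).run(next)
    }
}

val goFuel: (Float, Float) => Boat => FuelState[Boat] =
  (speed, time) => boat => State[Float, Boat](fuel => (boat.go(speed, time), fuel - speed * time))
val turnFuel: Double => Boat => FuelState[Boat] =
  angle => boat => State[Float, Boat](fuel => (boat.turn(angle), fuel))

val start: FuelState[Boat] = fuelStateMonad.unit(Boat(0, (0d, 0d)))
println(move[Boat, FuelState](goFuel, turnFuel)(start).run(1000f)._2) // 549.0
```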

It looks much nicer, but still there is some room for improvement. Specifically, the turn method does nothing but propagate the call to the default implementation. We can make it generic in the same way as we did for the move method:

def turn[M[_]: Monad]: Double => Boat => M[Boat] =
angle => boat => Monad[M].unit(boat.turn(angle))

We can't make it polymorphic with regard to Boat, because we need to delegate the call to that specific type, but the monad type can still be generic. This code uses Monad.apply to summon the implicit monad instance for the given type.

Actually, we can do the same for the go method (provide a default facade implementation) and place both of them into the companion object of Boat:

object Boat {
  val boat = Boat(0, (0d, 0d))
  import Monad.lowPriorityImplicits._
  def go[M[_]: Monad]: (Float, Float) => Boat => M[Boat] =
    (speed, time) => boat => Monad[M].unit(boat.go(speed, time))
  def turn[M[_]: Monad]: Double => Boat => M[Boat] =
    angle => boat => Monad[M].unit(boat.turn(angle))
  def move[A, M[_]: Monad](go: (Float, Float) => A => M[A], turn: Double => A => M[A])(boat: M[A]): M[A] =
    for {
      a <- boat
      b <- go(10,5)(a)
      c <- turn(0.5)(b)
      d <- go(20, 20)(c)
      e <- turn(-0.1)(d)
      f <- go(1,1)(e)
    } yield f
}

Again, to put this definition into the REPL, use the :paste command, paste both the definition of the Boat case class and its companion object, and then press Ctrl + D.

Now, we can use the default implementations wherever we don't need to override the default behavior. For instance, we can get rid of our State-specific turn implementation and call move with the default one:

import ch09._
import Boat.{move => moveB, turn => turnB, boat}
import StateExample._
type FuelState[B] = State[Float, B]
println(moveB(go, turnB[FuelState])(State(boat)).run(1000f))

We have to help the compiler infer the correct monad type by providing the type parameter, but now our definition of stateful behavior is reduced to the overridden definition of the go method; the rest of the code is generic.

As an illustration, we can reuse everything we have built so far with the Id monad; the result should be the same as executing the chain of calls directly on Boat. This is the complete implementation for the Id monad:

import Monad.Id
import Boat._
println(move(go[Id], turn[Id])(boat))

Again, we're providing the type of monad to use, but that is pretty much it. Since Id[Boat] = Boat, we can even pass the boat directly without wrapping it into Id.
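Both uses of the default facades, with Id and with a State whose go is overridden, can be checked in one self-contained sketch. The Monad trait, its Monad.apply summoner, the Id alias, the FuelState alias, and the Boat movement model are stand-ins of ours; goDefault and turnDefault play the role of Boat.go and Boat.turn from the companion object above (renamed here to keep them apart from the State-specific goFuel). The generic move is written with explicit flatMap calls, so this sketch needs no MonadF conversion.

```scala
trait Monad[F[_]] {
  def unit[A](a: => A): F[A]
  def flatMap[A, B](a: F[A])(f: A => F[B]): F[B]
}

object Monad {
  // Summoner: Monad[M] returns the implicit instance for M.
  def apply[M[_]](implicit M: Monad[M]): Monad[M] = M
  type Id[A] = A
  implicit val idMonad: Monad[Id] = new Monad[Id] {
    def unit[A](a: => A): Id[A] = a
    def flatMap[A, B](a: Id[A])(f: A => Id[B]): Id[B] = f(a)
  }
}
import Monad.Id

final case class Boat(direction: Double, position: (Double, Double)) {
  // Assumed movement model: speed * time units along the heading (radians).
  def go(speed: Float, time: Float): Boat = {
    val d = speed * time
    copy(position = (position._1 + d * math.cos(direction), position._2 + d * math.sin(direction)))
  }
  def turn(angle: Double): Boat = copy(direction = direction + angle)
}

// Default facades: delegate to Boat and lift the result into whichever monad is requested.
def goDefault[M[_]: Monad]: (Float, Float) => Boat => M[Boat] =
  (speed, time) => boat => Monad[M].unit(boat.go(speed, time))
def turnDefault[M[_]: Monad]: Double => Boat => M[Boat] =
  angle => boat => Monad[M].unit(boat.turn(angle))

// The generic path, spelled out with the monad's flatMap.
def move[A, M[_]: Monad](go: (Float, Float) => A => M[A], turn: Double => A => M[A])(boat: M[A]): M[A] = {
  val M = Monad[M]
  M.flatMap(boat)(a =>
    M.flatMap(go(10, 5)(a))(b =>
      M.flatMap(turn(0.5)(b))(c =>
        M.flatMap(go(20, 20)(c))(d =>
          M.flatMap(turn(-0.1)(d))(e => go(1, 1)(e))))))
}

val boat = Boat(0, (0d, 0d))
// With Id, no effect is added: the result is the plain boat.
val viaId: Boat = move[Boat, Id](goDefault[Id], turnDefault[Id])(boat)
println(viaId == boat.go(10, 5).turn(0.5).go(20, 20).turn(-0.1).go(1, 1)) // true

final case class State[S, A](run: S => (A, S))
type FuelState[A] = State[Float, A]
implicit val fuelStateMonad: Monad[FuelState] = new Monad[FuelState] {
  def unit[A](a: => A): FuelState[A] = State[Float, A](s => (a, s))
  def flatMap[A, B](a: FuelState[A])(f: A => FuelState[B]): FuelState[B] =
    State[Float, B] { s => val (x, next) = a.run(s); f(x).run(next) }
}

// Only go is overridden for State; turn comes from the default facade.
val goFuel: (Float, Float) => Boat => FuelState[Boat] =
  (speed, time) => b => State[Float, Boat](fuel => (b.go(speed, time), fuel - speed * time))

val viaState = move[Boat, FuelState](goFuel, turnDefault[FuelState])(fuelStateMonad.unit(boat))
println(viaState.run(1000f)._2) // 549.0
```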

Isn't that nice? We could use any monad we've defined so far to pass different effects to the main logic formulated in monadic terms. We'll leave the easy part—using existing definitions—as an exercise for the reader, and will now implement two other monads representing the read and write side of the State, that is, the Reader and Writer monads.
