### The State Monad

What is the *State Monad*? If you have been following this blog, you already know the answer. In fact, the parser that we have seen in the previous post is one embodiment of it. In general terms, the state monad is just a glorified function that takes a state and computes from it a result value and some new state. Apart from embedding this function, the state monad has a few bells and whistles that help in combining such functions. In functional programming languages, the state monad is used to simulate external mutable state. Its role is to pass the state down through the function calls, so that the state parameter disappears from the functions actually used.
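Stripped of the bells and whistles, such a function can be written with plain `java.util.function` types. The sketch below is a minimal illustration. It assumes `Map.Entry` as the pair type and uses a hypothetical fresh-name generator whose state is a counter. The caller has to pass the state along by hand, and that bookkeeping is exactly what the state monad takes over:

```java
import java.util.Map;
import java.util.function.Function;

public class ManualStateThreading {
    // A stateful computation as a plain function: counter in, (fresh name, next counter) out
    static final Function<Integer, Map.Entry<String, Integer>> freshName =
            counter -> Map.entry("x" + counter, counter + 1);

    // Without a monad, the caller threads the state from one call to the next by hand
    static Map.Entry<String, Integer> twoNames(int counter) {
        Map.Entry<String, Integer> first = freshName.apply(counter);
        Map.Entry<String, Integer> second = freshName.apply(first.getValue());
        return Map.entry(first.getKey() + "," + second.getKey(), second.getValue());
    }

    public static void main(String[] args) {
        Map.Entry<String, Integer> result = twoNames(0);
        System.out.println(result.getKey() + " / next counter = " + result.getValue());
        // prints "x0,x1 / next counter = 2"
    }
}
```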

You have already seen a practical application of this in our parser. In the parser, the state values have been character sequences. We were able to process the input sequence without ever seeing a corresponding parameter in our methods.

Besides parsing, another often-mentioned application of the state monad is the memoization of recursive functions. In this case, the state consists of previously computed function values (a *memo*). We'll see how to memoize a recursive function that computes the Fibonacci numbers without tacking an extra memo argument (such as a HashMap) onto the function. The state monad will keep a memo of already computed values between recursive calls.

You can google this stuff; people write about it often. For example, here are two articles that both deal with memoizing Fibonacci numbers in a fashion similar, but not identical, to mine: one in Java by Pierre-Yves Saumont and another in Scala by Tony Morris. Both miss a point I am going to make below, namely how to factor the memoization logic out of the function itself. (Saumont, by the way, is in the process of writing what looks to be an interesting book.)

I'll first show you how memoizing Fibonacci numbers is done, and define a generic *memoize* method that will enable memoization for any function. Only then will I show you how to derive the underlying state monad implementation through a few small refactorings of our SimpleParser. This way, you won't be inundated with details without knowing what they lead up to. I'm hoping that the code is clear enough for you to grasp its *intent* even without knowing the internals. If that doesn't work for you, try reading this post from the end to the beginning.

And in the last section, I'll give you my evaluation of the whole thing.

### Memoizing Fibonacci numbers

I expect you all know from your introduction to algorithms course that the following naive version will blow up in your face, because its running time grows exponentially with n:

```java
BigInteger fibNaive(int n) {
    if (n <= 2) {
        return BigInteger.ONE;
    }
    BigInteger x = fibNaive(n - 1);
    BigInteger y = fibNaive(n - 2);
    return x.add(y);
}
```
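To see how badly, here is a small sketch that instruments the naive version with a static call counter. The class `NaiveFibCounter` and the field `calls` are hypothetical additions. Every call for n > 2 costs itself plus the calls of its two subproblems, so the count satisfies C(n) = 1 + C(n-1) + C(n-2) with C(1) = C(2) = 1. That works out to 2*fib(n) - 1 calls, already more than 1.6 million for n = 30:

```java
import java.math.BigInteger;

public class NaiveFibCounter {
    // Number of fibNaive invocations since the last reset (hypothetical instrumentation)
    static long calls = 0;

    static BigInteger fibNaive(int n) {
        calls++;
        if (n <= 2) {
            return BigInteger.ONE;
        }
        BigInteger x = fibNaive(n - 1);
        BigInteger y = fibNaive(n - 2);
        return x.add(y);
    }

    public static void main(String[] args) {
        for (int n = 10; n <= 30; n += 10) {
            calls = 0;
            BigInteger f = fibNaive(n);
            System.out.println("fib(" + n + ") = " + f + " after " + calls + " calls");
        }
        // fib(10) = 55 after 109 calls
        // fib(20) = 6765 after 13529 calls
        // fib(30) = 832040 after 1664079 calls
    }
}
```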

One way to make the algorithm feasible is to remember previously computed values.

```java
BigInteger fibMemo(int n, Map<Integer, BigInteger> memo) {
    BigInteger value = memo.get(n);
    if (value != null) {
        return value;
    }
    BigInteger x = fibMemo(n - 1, memo);
    BigInteger y = fibMemo(n - 2, memo);
    BigInteger z = x.add(y);
    memo.put(n, z);
    return z;
}
```

The above is written in a non-functional style with an explicit extra argument (I'll call it the *explicit version*). You would call it like this, with a memo already containing the first two Fibonacci numbers:

```java
BigInteger fibMemo(int n) {
    if (n <= 2) {
        return BigInteger.ONE;
    }
    Map<Integer, BigInteger> memo = new HashMap<>();
    memo.put(1, BigInteger.ONE);
    memo.put(2, BigInteger.ONE);
    return fibMemo(n, memo);
}
```
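As a quick sanity check of the explicit version, the sketch below adds a hypothetical counter for the additions it performs. The first recursive call, `fibMemo(n - 1, memo)`, fills the memo for all smaller arguments. The second call is therefore always a cache hit, and computing fib(n) takes exactly n - 2 additions instead of exponentially many:

```java
import java.math.BigInteger;
import java.util.HashMap;
import java.util.Map;

public class ExplicitMemoCounter {
    // Number of Fibonacci values actually computed (hypothetical instrumentation)
    static int additions = 0;

    static BigInteger fibMemo(int n, Map<Integer, BigInteger> memo) {
        BigInteger value = memo.get(n);
        if (value != null) {
            return value;                      // cache hit: no recursion
        }
        BigInteger x = fibMemo(n - 1, memo);   // fills the memo for all smaller arguments
        BigInteger y = fibMemo(n - 2, memo);   // therefore always a cache hit
        BigInteger z = x.add(y);
        additions++;
        memo.put(n, z);
        return z;
    }

    static BigInteger fibMemo(int n) {
        if (n <= 2) {
            return BigInteger.ONE;
        }
        Map<Integer, BigInteger> memo = new HashMap<>();
        memo.put(1, BigInteger.ONE);
        memo.put(2, BigInteger.ONE);
        return fibMemo(n, memo);
    }

    public static void main(String[] args) {
        BigInteger f = fibMemo(100);
        System.out.println("fib(100) = " + f + " after " + additions + " additions");
        // fib(100) = 354224848179261915075 after 98 additions
    }
}
```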

Now wouldn't it be nice if we could get rid of that extra argument and the tedium of looking up and storing values, and derive a memoizing version that is closer to the naive version? And, yes, we can: simply apply a new function, called *memoize*, to the original function.

```java
StateMonad<BigInteger, Memo<Integer, BigInteger>> fib(int n) {
    return memoize(this::fib).apply(n - 1).bind(x ->    // x = memoize (fib) (n-1)
           memoize(this::fib).apply(n - 2).bind(y ->    // y = memoize (fib) (n-2)
           StateMonad.get(_s -> x.add(y))));            // return x.add(y)
}
```

Basically, this is only a notational variant of the naive version. It's a bit difficult to read at first because Java lacks the syntactic sugar other languages have; I have tried to indicate such sugar in the inline comments. Having had time to get used to it, I find it doesn't read too badly: just ignore the clutter of "apply" and "bind", and remember that the assignment is notated at the end of the line instead of the beginning.

We also had to change the return type a bit. It no longer represents a value. It represents a computation that yields this value (the function represented by the state monad). As you may have guessed, the method *bind* above is the equivalent of *then* in our parser. The signature is the same, except that the continuation we put in and the combined computation we get out are not parsers. They are instances of a more general *StateMonad* class, which can have a memo as its state. As with the parser, the function inside the monad takes no arguments except the state; all other arguments are supplied beforehand.
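To make the threading of state through *bind* concrete, here is a minimal self-contained sketch, independent of the parser code. The interface `State`, its method `run`, and the use of `Map.Entry` as the pair type are assumptions for illustration; this is not the *StateMonad* class used in this post. The computation `tick` returns the current counter and increments it. Chaining two ticks with `bind` never mentions the counter explicitly:

```java
import java.util.Map;
import java.util.function.Function;

public class StateBindSketch {
    // A minimal state monad: a function from state S to a (value T, new state S) pair
    @FunctionalInterface
    interface State<T, S> {
        Map.Entry<T, S> run(S s);

        // Yield a value and leave the state untouched
        static <A, B> State<A, B> unit(A value) {
            return s -> Map.entry(value, s);
        }

        default <V> State<V, S> bind(Function<? super T, State<V, S>> f) {
            return s -> {
                Map.Entry<T, S> first = run(s);        // run this computation
                return f.apply(first.getKey())         // feed its value to the continuation
                        .run(first.getValue());        // and hand on the updated state
            };
        }
    }

    // Returns the current counter as the value and increments the state
    static final State<Integer, Integer> tick = s -> Map.entry(s, s + 1);

    // The counter is threaded through by bind, not by hand
    static final State<Integer, Integer> sumOfTwoTicks =
            tick.bind(a -> tick.bind(b -> State.unit(a + b)));

    public static void main(String[] args) {
        Map.Entry<Integer, Integer> result = sumOfTwoTicks.run(10);
        System.out.println("value = " + result.getKey() + ", final state = " + result.getValue());
        // prints "value = 21, final state = 12"
    }
}
```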

My main point is that here we have a very clear separation of concerns. *memoize* is concerned with caching computed values. The *StateMonad* is concerned with passing the cached values between function calls. Finally, *fib* itself embodies the definition of the Fibonacci numbers and is concerned with making recursive calls and combining their results. This is in contrast to the explicit version, and also to both articles cited above, where these concerns are more intertwined.

For me, this is of the essence of functional programming, and therein lies its beauty: that it is so neat and modular on a small scale.

In order to get a value from the computation that is returned by *fib*, we need to evaluate it against some suitable initial state. The function that does this is called *evalState*. (This, too, you already know from the parser, where it was simply called *eval*.)

```java
BigInteger fibMonadic(int n) {
    return n <= 2
        ? BigInteger.ONE
        : fib(n).evalState(new Memo<Integer, BigInteger>()
                               .insert(1, BigInteger.ONE)
                               .insert(2, BigInteger.ONE));
}
```

The class *Memo* is intended to be a functional equivalent to HashMap.

### Implementing generic memoization

So, what is *memoize*? Its signature is obvious: it takes a function of the type of *fib* and returns another function of the same type. We can abstract over the input and result value types *Integer* and *BigInteger* of *fib* and replace them with the generic type parameters T and R.

Now think of what *memoize* has to do. It takes a function as its parameter. It must return a new function that

- when given an argument, tries to look up a previously computed value in the memo
- if it finds the value, returns a function that will yield this value
- otherwise, applies the given function to the given argument, stores the result in the memo, and returns a function that will yield this value plus an updated state

The memo returns an *Optional* upon lookup, so that we can make the case distinction with *Optional*'s methods *map* and *orElse*. Without further ado, here's the code:

```java
static <T, R> Function<T, StateMonad<R, Memo<T, R>>> memoize(Function<T, StateMonad<R, Memo<T, R>>> f) {
    return t -> {
        // create a computation that would try to find a cached entry
        StateMonad<Optional<R>, Memo<T, R>> s = StateMonad.get(m -> m.lookup(t));
        return s.bind(v ->              // perform the computation, call the result v and do:
            v.map(s::unit)              // if value is present, return a computation that will yield the value
             .orElse(                   // otherwise
                f.apply(t).bind(r ->    // compute r = f(t)
                    // apply a function to the memo that stores r in it, set the value of s to r
                    s.mod(m -> m.insert(t, r)).map(_v -> r)
                )));                    // return a computation that will yield r
    };
}
```
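One detail of *memoize* may be worth knowing. `Optional.orElse` receives an already evaluated argument, so `f.apply(t)` is called even on a cache hit. As far as I can tell, for *fib* this only assembles a new computation, because *bind* does not run anything. The cost is some allocation rather than a recomputation. If `f` did real work before returning its computation, `Optional.orElseGet` would avoid that work: it takes a `Supplier` and calls it only when the `Optional` is empty. The small sketch below demonstrates the difference. The class and method names are hypothetical:

```java
import java.util.Optional;

public class OrElseEagerness {
    static int evaluations = 0;

    // Stands in for f.apply(t) in memoize: counts how often it is evaluated
    static String fallback() {
        evaluations++;
        return "computed";
    }

    // The argument of orElse is evaluated before the call, even if the Optional has a value
    static String eager(Optional<String> cached) {
        return cached.map(v -> "hit: " + v).orElse(fallback());
    }

    // The supplier passed to orElseGet is only invoked if the Optional is empty
    static String lazy(Optional<String> cached) {
        return cached.map(v -> "hit: " + v).orElseGet(OrElseEagerness::fallback);
    }

    public static void main(String[] args) {
        System.out.println(eager(Optional.of("cached")) + ", evaluations = " + evaluations);
        // prints "hit: cached, evaluations = 1"
        evaluations = 0;
        System.out.println(lazy(Optional.of("cached")) + ", evaluations = " + evaluations);
        // prints "hit: cached, evaluations = 0"
    }
}
```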

The appeal of *memoize* is that it is completely general and can be applied to any unary function.

What about functions that take multiple arguments? Well, in Java you cannot abstract over functions of arbitrary arity. I suggest taking a suitable class that can represent an n-ary tuple, such as one from the wonderful jOOλ library, and treating an n-ary function as a unary function of such a tuple. (They should have added tuples to Java; every functional programming person is asking for them.)
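Here is a minimal sketch of the tuple idea, under two simplifying assumptions. First, `java.util.List.of(n, k)` stands in for a proper tuple type such as the ones in jOOλ. A `List` has value-based `equals` and `hashCode`, so it works as a map key. Second, for brevity the memoizer is a plain HashMap-backed one in the style of the *Memoizer* class from the discussion section, not the state-monad version. The example function is the binomial coefficient C(n, k), computed with Pascal's rule C(n, k) = C(n-1, k-1) + C(n-1, k):

```java
import java.math.BigInteger;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class TupleMemoSketch {
    // Generic memoizer for unary functions, backed by the given map
    static <T, R> Function<T, R> memoize(Map<T, R> memo, Function<T, R> f) {
        return t -> {
            R r = memo.get(t);
            if (r == null) {
                r = f.apply(t);
                memo.put(t, r);
            }
            return r;
        };
    }

    // The memo is keyed by the whole argument tuple (n, k)
    private final Map<List<Integer>, BigInteger> memo = new HashMap<>();

    // The binary function C(n, k), written as a unary function of the pair (n, k)
    BigInteger choose(List<Integer> nk) {
        int n = nk.get(0);
        int k = nk.get(1);
        if (k == 0 || k == n) {
            return BigInteger.ONE;
        }
        Function<List<Integer>, BigInteger> c = memoize(memo, this::choose);
        return c.apply(List.of(n - 1, k - 1)).add(c.apply(List.of(n - 1, k)));
    }

    public static void main(String[] args) {
        TupleMemoSketch sketch = new TupleMemoSketch();
        System.out.println(sketch.choose(List.of(10, 3)));   // prints 120
    }
}
```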

### Implementing the state monad

The state monad itself is not so interesting. As I noted above, it should be self-explanatory to readers familiar with the simple parser. Here's how you might change the parser to derive the monad in a few simple steps, mainly consisting of renamings:

- Abstract over the type of the state (replace *CharSequence* with a type variable S).
- Rename a few methods to what's customary in the functional world (e.g. Haskell):
  - *eval* to *evalState*
  - *parse* to *runState*
  - *then* to *bind*
- Delete *many*, *many1*, and *orElse*. (We don't need *orElse* in our example, but you might nevertheless decide to keep it. I guess that in some contexts it might be useful for backtracking.)
- Simplify *evalState* by removing the check for unused input, which was completely specific to parsing.
- You might also wish to rename a few variables. (I renamed *inp* for "input" to *s* for "state", *p* for "parser" to *m* for "monad", etc.)

Then add the state-modifying method *mod* and the static value-extracting utility *get*, which are used in the definition of *memoize*. That is very little work, and here is the complete resulting code. (I've left out the overloaded version of *bind*/*then* that we don't need for this example. Furthermore, the state monad is usually given a few more convenience methods, which we also don't need and which I'm not going to discuss here.)

```java
import java.util.Objects;
import java.util.function.Function;
// Tuple, tuple(...) and empty() are assumed to be statically imported from the
// Tuple class used by the parser (not shown here)

@FunctionalInterface
public interface StateMonad<T, S> {

    default T evalState(S s) {
        Objects.requireNonNull(s);
        Tuple<T, S> t = runState(s);
        if (t.isEmpty()) {
            throw new IllegalArgumentException("Invalid state: " + s);
        }
        return t.first();
    }

    /**
     * The type of the functional interface. A state monad is an abstraction over a function:
     *   StateMonad<T,S> :: S -> Tuple<T, S>
     * In other words, a state monad represents a stateful computation that derives a value and
     * some new state from an input state.
     */
    abstract Tuple<T, S> runState(S s);

    // Monadic operations

    default <V> StateMonad<V, S> unit(V v) {
        return inp -> tuple(v, inp);
    }

    default <V> StateMonad<V, S> bind(Function<? super T, StateMonad<V, S>> f) {
        Objects.requireNonNull(f);
        StateMonad<V, S> m = s -> {
            Tuple<T, S> t = runState(s);
            if (t.isEmpty()) {
                return empty();
            }
            return f.apply(t.first()).runState(t.second());
        };
        return m;
    }

    default <V> StateMonad<V, S> map(Function<? super T, V> f) {
        Objects.requireNonNull(f);
        Function<V, StateMonad<V, S>> unit = x -> unit(x);
        return bind(unit.compose(f));
    }

    // Additional functions

    /** Modify the current state with the given function. */
    default StateMonad<T, S> mod(Function<S, S> f) {
        return s -> runState(f.apply(s));
    }

    /** Create a computation that extracts a value from the state with the given function. */
    static <V, S> StateMonad<V, S> get(Function<S, V> stateProjector) {
        return s -> tuple(stateProjector.apply(s), s);
    }
}
```

Of course we could have inlined that static call to the convenience function *StateMonad.get* in *memoize* (i.e. used the lambda directly), but that would have made the use of tuples visible in *memoize*.
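The *StateMonad* interface in this post relies on the parser's *Tuple* class and on static imports that are not shown. For readers who want to run the whole example end to end, here is a self-contained sketch. It bundles a hypothetical minimal stand-in for *Tuple* with the *StateMonad*, *Memo*, *memoize*, and *fib* code from this post. It takes two small liberties. First, `map` is written directly in terms of `bind` and `unit` instead of via `Function.compose`. Second, `fib` stores the memoized function in a local variable with an explicit type:

```java
import java.math.BigInteger;
import java.util.HashMap;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

public class StateMonadFib {
    // Hypothetical stand-in for the parser's Tuple class: a pair that may be "empty"
    static final class Tuple<A, B> {
        final A first; final B second; final boolean empty;
        private Tuple(A first, B second, boolean empty) { this.first = first; this.second = second; this.empty = empty; }
        static <A, B> Tuple<A, B> tuple(A a, B b) { return new Tuple<>(a, b, false); }
        static <A, B> Tuple<A, B> empty() { return new Tuple<>(null, null, true); }
        boolean isEmpty() { return empty; }
        A first() { return first; }
        B second() { return second; }
    }

    @FunctionalInterface
    interface StateMonad<T, S> {
        Tuple<T, S> runState(S s);

        default T evalState(S s) {
            Objects.requireNonNull(s);
            Tuple<T, S> t = runState(s);
            if (t.isEmpty()) {
                throw new IllegalArgumentException("Invalid state: " + s);
            }
            return t.first();
        }

        default <V> StateMonad<V, S> unit(V v) {
            return s -> Tuple.tuple(v, s);
        }

        default <V> StateMonad<V, S> bind(Function<? super T, StateMonad<V, S>> f) {
            Objects.requireNonNull(f);
            return s -> {
                Tuple<T, S> t = runState(s);
                if (t.isEmpty()) {
                    return Tuple.empty();
                }
                return f.apply(t.first()).runState(t.second());
            };
        }

        default <V> StateMonad<V, S> map(Function<? super T, V> f) {
            return bind(t -> unit(f.apply(t)));
        }

        // Run this computation on the state transformed by f
        default StateMonad<T, S> mod(Function<S, S> f) {
            return s -> runState(f.apply(s));
        }

        static <V, S> StateMonad<V, S> get(Function<S, V> stateProjector) {
            return s -> Tuple.tuple(stateProjector.apply(s), s);
        }
    }

    // Mutable map posing as a functional one, as in the addendum
    static class Memo<K, V> extends HashMap<K, V> {
        Optional<V> lookup(K key) { return Optional.ofNullable(get(key)); }
        Memo<K, V> insert(K key, V value) { put(key, value); return this; }
    }

    static <T, R> Function<T, StateMonad<R, Memo<T, R>>> memoize(Function<T, StateMonad<R, Memo<T, R>>> f) {
        return t -> {
            StateMonad<Optional<R>, Memo<T, R>> s = StateMonad.get(m -> m.lookup(t));
            return s.bind(v -> v.map(s::unit)
                    .orElse(f.apply(t).bind(r -> s.mod(m -> m.insert(t, r)).map(_v -> r))));
        };
    }

    StateMonad<BigInteger, Memo<Integer, BigInteger>> fib(int n) {
        Function<Integer, StateMonad<BigInteger, Memo<Integer, BigInteger>>> mfib = memoize(this::fib);
        return mfib.apply(n - 1).bind(x ->
               mfib.apply(n - 2).bind(y ->
               StateMonad.get(memo -> x.add(y))));
    }

    BigInteger fibMonadic(int n) {
        return n <= 2 ? BigInteger.ONE
                : fib(n).evalState(new Memo<Integer, BigInteger>().insert(1, BigInteger.ONE).insert(2, BigInteger.ONE));
    }

    public static void main(String[] args) {
        System.out.println(new StateMonadFib().fibMonadic(100));   // 354224848179261915075
    }
}
```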

### Discussion

The trick I've shown you in this post is certainly appealing in its cleverness, and sometimes I quite admire it. But most of the time I feel that it is too clever by half. You've got to think practically. In Java, the functional version is no better than the explicit version with regard to conciseness or readability. The effort to rewrite the naive definition to make use of memoization is also about the same in each case, because besides throwing in a call to *memoize*, you have to make all those other syntactic changes. And as for performance, the monadic solution is about 5 times slower on my machine than the explicit one (for 100 ≤ n ≤ 1000).

Then there is the matter of representing state. In OO, state is most naturally represented as a member variable of a class. Here's equivalent code that fits this paradigm. (Well, almost equivalent: this solution is not thread-safe. On the other hand, one might reuse the *Memoizer* instance.)

```java
static class Memoizer<T, R> {
    private final Map<T, R> memo;

    public Memoizer(Map<T, R> memo) {
        this.memo = memo;
    }

    public Function<T, R> memoize(Function<T, R> f) {
        return t -> {
            R r = memo.get(t);
            if (r == null) {
                r = f.apply(t);
                memo.put(t, r);
            }
            return r;
        };
    }
}

Memoizer<Integer, BigInteger> m = new Memoizer<>(new HashMap<>());

BigInteger fib(int n) {
    if (n <= 2) return BigInteger.ONE;
    return m.memoize(this::fib).apply(n - 1).add(
           m.memoize(this::fib).apply(n - 2));
}
```
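Below is a minimal sketch of a thread-safe variant. It assumes that two threads occasionally computing the same value twice is acceptable. It swaps the HashMap for a `ConcurrentHashMap` and publishes results with `putIfAbsent`. It deliberately avoids `computeIfAbsent`, because the `ConcurrentHashMap` documentation requires that the mapping function not update other mappings of the map, and a recursive `fib` would do exactly that. The class names are hypothetical:

```java
import java.math.BigInteger;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.IntStream;

public class ConcurrentMemoizerSketch {
    // Variant of Memoizer whose memo may be shared between threads
    static class ConcurrentMemoizer<T, R> {
        private final Map<T, R> memo = new ConcurrentHashMap<>();

        public Function<T, R> memoize(Function<T, R> f) {
            return t -> {
                R r = memo.get(t);
                if (r == null) {
                    // Compute outside the map: a recursive f updates other mappings,
                    // which computeIfAbsent's mapping function must not do
                    r = f.apply(t);
                    R previous = memo.putIfAbsent(t, r);
                    if (previous != null) {
                        r = previous;   // another thread stored a value first; reuse it
                    }
                }
                return r;
            };
        }
    }

    private final ConcurrentMemoizer<Integer, BigInteger> m = new ConcurrentMemoizer<>();

    BigInteger fib(int n) {
        if (n <= 2) return BigInteger.ONE;
        return m.memoize(this::fib).apply(n - 1).add(
               m.memoize(this::fib).apply(n - 2));
    }

    public static void main(String[] args) {
        ConcurrentMemoizerSketch s = new ConcurrentMemoizerSketch();
        IntStream.rangeClosed(1, 200).parallel().forEach(s::fib);   // several threads share one memo
        System.out.println(s.fib(100));
    }
}
```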

And of course there is a simple, idiomatic, non-recursive, constant-space solution with mutable variables:

```java
BigInteger fib(int n) {
    if (n <= 2) return BigInteger.ONE;
    BigInteger x = BigInteger.ONE;
    BigInteger y = BigInteger.ONE;
    BigInteger z = null;
    for (int i = 3; i <= n; i++) {
        z = x.add(y);
        x = y;
        y = z;
    }
    return z;
}
```

So be careful when you adopt an idea from functional programming. You may be able to retain some of the beauty, but in other respects you won't always win; in fact, sometimes you'll lose. Don't overdo it. Functional thinking may offer you a different, and useful, perspective on a problem, as I hope to have shown with the parser and some other posts. But a functional solution is not *per se* more readable, more maintainable, or more efficient than a good old vanilla object-oriented or imperative solution. Often, it will be none of these. Java is not a functional programming language!

### Addendum

For the sake of completeness, here's the *Memo* class that I have used for demo purposes. It's a mutable map masquerading as a functional data structure. Very bad. Don't copy it. You may put trace statements in its two methods to see in what order elements are computed and retrieved, and that they are indeed computed only once. The method names are the same as in TotallyLazy's *PersistentMap*.

```java
class Memo<K, V> extends HashMap<K, V> {
    public Optional<V> lookup(K key) {
        Optional<V> value = Optional.ofNullable(get(key));
        return value;
    }

    public Memo<K, V> insert(K key, V value) {
        put(key, value);
        return this;
    }
}
```
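If you do want a memo that is genuinely immutable, a simple copy-on-write class is one option. Below is a minimal sketch; the class `ImmutableMemo` is hypothetical. Each `insert` copies the entries into a new map and leaves the old memo untouched. That costs O(size) per insert, whereas a real persistent map such as TotallyLazy's shares structure between versions. *memoize* only ever uses the memo returned by `insert`, so as far as I can see such a class could take the place of *Memo* as the state type:

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public class ImmutableMemoSketch {
    // A memo whose insert returns a new instance instead of modifying this one
    public static final class ImmutableMemo<K, V> {
        private final Map<K, V> entries;

        public ImmutableMemo() {
            this(Collections.emptyMap());
        }

        private ImmutableMemo(Map<K, V> entries) {
            this.entries = entries;
        }

        public Optional<V> lookup(K key) {
            return Optional.ofNullable(entries.get(key));
        }

        public ImmutableMemo<K, V> insert(K key, V value) {
            Map<K, V> copy = new HashMap<>(entries);   // the old memo stays untouched
            copy.put(key, value);
            return new ImmutableMemo<>(Collections.unmodifiableMap(copy));
        }

        public int size() {
            return entries.size();
        }
    }

    public static void main(String[] args) {
        ImmutableMemo<Integer, String> empty = new ImmutableMemo<>();
        ImmutableMemo<Integer, String> one = empty.insert(1, "one");
        System.out.println(empty.lookup(1).isPresent() + " " + one.lookup(1).orElse("?"));
        // prints "false one"
    }
}
```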