
21 Oct 2016

Lazy tree walking made easy with Kotlin coroutines

Let's imagine a simple unbalanced binary tree structure, in which an abstract BinaryTree<E> is either a concrete Node labelled with a value attribute of type E and having a left and right subtree, or a concrete Empty unlabelled tree. With trees, one very common requirement is to traverse the nodes in some appropriate order (preorder, inorder, or postorder).

In this post, we are going to consider how to do that lazily. That is, we want to be able to short-circuit the traversal, visiting only as many nodes in the tree as we require to see. Moreover, we'd prefer a single-threaded solution.
It's actually not trivial to do that in Java. With regard to iteration, you have to keep track of a lot of state. (Cf. these explanations or look at the source code of java.util.TreeMap.) You might scrap single-threadedness and actually have two threads communicating over a blocking queue, but that also entails some complexity, especially with task cancellation on short-circuiting.
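To give an impression of the bookkeeping involved, here is a minimal sketch of a lazy postorder iterator that manages an explicit stack. The classes TreeNode, PostorderIterator and PostorderDemo are made up for this illustration (with null standing for the empty tree); they are not the tree classes used in the rest of this post.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.NoSuchElementException;

// Hypothetical minimal tree node; null stands for the empty tree.
class TreeNode<E> {
    final E value;
    final TreeNode<E> left, right;

    TreeNode(E value, TreeNode<E> left, TreeNode<E> right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }
}

// Lazy postorder iterator: all traversal state lives in the stack and in lastVisited.
class PostorderIterator<E> implements Iterator<TreeNode<E>> {
    private final Deque<TreeNode<E>> stack = new ArrayDeque<>();
    private TreeNode<E> lastVisited = null;

    PostorderIterator(TreeNode<E> root) {
        descendLeft(root);
    }

    // Pushes the node and the whole chain of its left descendants.
    private void descendLeft(TreeNode<E> node) {
        for (TreeNode<E> n = node; n != null; n = n.left) {
            stack.push(n);
        }
    }

    @Override
    public boolean hasNext() {
        return !stack.isEmpty();
    }

    @Override
    public TreeNode<E> next() {
        if (stack.isEmpty()) throw new NoSuchElementException();
        while (true) {
            TreeNode<E> top = stack.peek();
            if (top.right != null && top.right != lastVisited) {
                // The right subtree has not been walked yet: go there first.
                descendLeft(top.right);
            } else {
                // Both subtrees are done (or absent): emit the node itself.
                lastVisited = stack.pop();
                return lastVisited;
            }
        }
    }
}

class PostorderDemo {
    static TreeNode<String> leaf(String v) {
        return new TreeNode<>(v, null, null);
    }

    // A small expression tree for (3 - 1) * (4 / 2 + 5 * 6).
    static TreeNode<String> exampleTree() {
        TreeNode<String> minus = new TreeNode<>("-", leaf("3"), leaf("1"));
        TreeNode<String> div = new TreeNode<>("/", leaf("4"), leaf("2"));
        TreeNode<String> times = new TreeNode<>("*", leaf("5"), leaf("6"));
        return new TreeNode<>("*", minus, new TreeNode<>("+", div, times));
    }

    public static void main(String[] args) {
        PostorderIterator<String> it = new PostorderIterator<>(exampleTree());
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < 3 && it.hasNext(); i++) {
            sb.append(it.next().value);
        }
        System.out.println(sb); // prints 31-
    }
}
```

It works, and it only visits as many nodes as are asked for, but the logic of the traversal is buried under stack handling. Inorder and preorder each need their own variation of this.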

And even when turning to stream-based processing instead of iteration, it doesn't quite work out as expected. Here's a proposed solution in a hypothetical BinaryTreeStreamer class:

/**
 * Supplies a postorder stream of the nodes in the given tree.
 */
public static <E> Stream<Node<E>> postorderNodes(BinaryTree<E> t) {
    return t.match(
                empty -> Stream.<Node<E>> empty(),
                node -> concat(Stream.of(node.left, node.right).flatMap(BinaryTreeStreamer::postorderNodes),
                               Stream.of(node)));
}

The corresponding inorder- or preorder-traversals would be similar. The technique for structural pattern-matching with method BinaryTree#match() goes back to Alonzo Church and is explained in more detail on Rúnar Bjarnason's blog. Basically, each subclass of BinaryTree applies the appropriate function to itself, i. e. Empty invokes the first argument of match, and Node the second.
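As a rough illustration of that technique, here is a minimal sketch of what such a class hierarchy might look like. The field names follow the snippet above; the exact signature of match and the helper class MatchDemo are assumptions made for this sketch and need not coincide with the classes I actually used.

```java
import java.util.function.Function;

// Sketch of a Church-encoded binary tree: each subclass picks "its" function in match.
abstract class BinaryTree<E> {
    abstract <R> R match(Function<Empty<E>, R> ifEmpty, Function<Node<E>, R> ifNode);
}

final class Empty<E> extends BinaryTree<E> {
    @Override
    <R> R match(Function<Empty<E>, R> ifEmpty, Function<Node<E>, R> ifNode) {
        return ifEmpty.apply(this); // Empty invokes the first argument
    }
}

final class Node<E> extends BinaryTree<E> {
    final E value;
    final BinaryTree<E> left, right;

    Node(E value, BinaryTree<E> left, BinaryTree<E> right) {
        this.value = value;
        this.left = left;
        this.right = right;
    }

    @Override
    <R> R match(Function<Empty<E>, R> ifEmpty, Function<Node<E>, R> ifNode) {
        return ifNode.apply(this); // Node invokes the second argument
    }
}

class MatchDemo {
    // Counts the labelled nodes, using match instead of instanceof checks.
    static <E> int size(BinaryTree<E> t) {
        return t.<Integer>match(
                empty -> 0,
                node -> 1 + size(node.left) + size(node.right));
    }
}
```

The caller never has to cast or inspect the runtime type; the case distinction is made by dynamic dispatch on match.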

The code above looks quite reasonable, but unfortunately it is broken by the same JDK feature/bug that I mentioned over a year ago in this post. Embedded flatMap just isn't lazy enough, and breaks short-circuiting. Suppose we construct ourselves a tree representing the expression (3 - 1) * (4 / 2 + 5 * 6). I'll use this as an example throughout this article. Then we start streaming, with the aim of finding out whether the expression contains a node for division:

boolean divides = BinaryTreeStreamer.postorderNodes(tree).filter(node -> node.value.equals("/")).findAny().isPresent();

which leads the code to traverse the entire tree down to nodes 5 and 6. And anyway, we are no closer to an iterating solution.

Now in Python, things look quite different. The thing is, Python has coroutines, called generators in Python. Here's how Wikipedia defines coroutines:
Coroutines are computer program components that generalize subroutines for nonpreemptive multitasking, by allowing multiple entry points for suspending and resuming execution at certain locations.
In Python you can say "yield" anywhere in a coroutine and the calling coroutine starts up again with the value that was yielded. Coroutines are like functions that return multiple times and keep their state (which would include the values of local variables plus the command pointer) so they can resume from where they yielded. Which means they have multiple entry points as well. So here's a Python solution to our problem, with the defaultdict as the tree implementation, using value, left and right as the dictionary keys. (For a presentation that goes a bit beyond our simple example, e. g. see Matt Bone's page.)

tree = lambda: defaultdict(tree)

def postorder(tree):
    if not tree:
        return
    
    for x in postorder(tree['left']):
        yield x
        
    for x in postorder(tree['right']):
        yield x
        
    yield tree


One thing to note is that we must yield each value from the sub-generators. Without that, although the recursive calls would dutifully yield all required nodes, they would yield them in embedded generators. We must append them one level up. That corresponds to the successive flat-mapping in our Java code. Here's how we can enumerate the first few nodes of our example tree in postorder. I also show a bit of the Python tree encoding.

expr = tree()
expr['value'] = '*'
expr['left']['value'] = '-'
expr['left']['left']['value'] = '3'
expr['left']['right']['value'] = '1'
…
node = postorder(expr)  
print(next(node)['value'])
print(next(node)['value'])
print(next(node)['value'])

Many other languages besides Python have coroutines, or something similar, if not in the language, then at least as a library. Java does not have them, so I started looking for other JVM languages that do. There aren't many. But I found a library for Scala. However, Scala is not a language that Java developers readily embrace. So I was all the happier to learn that coroutines will be a feature of Kotlin 1.1, which is now in the early access phase.

I had already known of Kotlin. It is fun. It's very like Java, only better in many respects, it is 100% interoperable with Java, and – being developed by JetBrains – has great tool support from IntelliJ IDEA out-of-the-box. It really has a host of nice features. You might want to check out the following articles, which piqued my interest in the language.
  1. 10 Features I Wish Java Would Steal From the Kotlin Language
  2. Kotlin for Java Developers: 10 Features You Will Love About Kotlin
  3. Why Kotlin is my next programming language 
Kotlin seems to have gained popularity especially among Android developers, among the reasons being its small footprint and the fact that up to the release of Android Nougat, people had been stuck with Java 6 on Android.

The current milestone is Kotlin 1.1-M04. In Kotlin, unlike Python, yield is not a keyword, but a library function. Kotlin as a language has a more basic notion of suspendable computation. You can read all about it in this informal overview. All that talk about suspending lambdas and coroutine builders and what not may seem somewhat intimidating, but fortunately there are already libraries that build upon standard Kotlin to provide functions that are easy to understand and use.

One such library is kotlinx-coroutines.  It contains a function generate that takes a coroutine parameter. Inside that coroutine we can use yield to suspend and return a value, just as in Python. The values are returned as a Kotlin Sequence object. Let me show you my attempt to port the above Python code to Kotlin. I tried to do a faithful translation, almost line by line. That turned out to be pretty straightforward, which I can only explain by guessing that the designers of Kotlin's generate must have been influenced by Python.

fun <E> postorderNodes(t : BinaryTree<E>): Iterable<Node<E>> = generate<Node<E>> {
 when(t) {
  is Empty -> {}
  is Node -> {
   postorderNodes(t.left).forEach { yield(it) }
   postorderNodes(t.right).forEach { yield(it) }
   yield(t)
  }
 }
}.asIterable()

We can seamlessly use Kotlin classes in Java code and vice versa. However, instead of the Kotlin sequence, java.util.Iterable is much nicer to work with on the Java side of things. Fortunately, as shown above, we can simply call asIterable() on the sequence to effect the conversion. So, let BinaryTreeWalker be a Kotlin class that contains the Iterable-returning generator method, and look at some Java code exercising that method:

Iterable<Node<String>> postfix = new BinaryTreeWalker().postorderNodes(expr);
Iterator<Node<String>> postfixIterator = postfix.iterator();

System.out.println(postfixIterator.next().value); 
System.out.println(postfixIterator.next().value); 
System.out.println(postfixIterator.next().value); 

For our example tree, this will correctly print the sequence "31-" and visit no further nodes.

Stream-processing is for free, as you can easily obtain a Stream from the Iterable with StreamSupport.stream(postfix.spliterator(), false). That gives you a Java stream based on an iterator over a sequence backed by a Kotlin generator. On that stream, that little snippet looking for a division-node would work as well, only now it would really be lazy.
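Here is a minimal sketch of that conversion in isolation. A hand-written counting Iterable stands in for the Kotlin generator, and the class and method names (LazyStreamDemo, counting, streamOf) are invented for the illustration. The counter records how many elements the stream actually pulls from the iterator.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

class LazyStreamDemo {
    // Wraps a list in an Iterable that counts how many elements were handed out.
    static <T> Iterable<T> counting(List<T> source, AtomicInteger pulled) {
        return () -> new Iterator<T>() {
            private final Iterator<T> it = source.iterator();

            @Override
            public boolean hasNext() {
                return it.hasNext();
            }

            @Override
            public T next() {
                pulled.incrementAndGet();
                return it.next();
            }
        };
    }

    // Sequential (non-parallel) stream over the Iterable.
    static <T> Stream<T> streamOf(Iterable<T> iterable) {
        return StreamSupport.stream(iterable.spliterator(), false);
    }

    public static void main(String[] args) {
        // Postorder labels of (3 - 1) * (4 / 2 + 5 * 6)
        List<String> postorder =
                Arrays.asList("3", "1", "-", "4", "2", "/", "5", "6", "*", "+", "*");
        AtomicInteger pulled = new AtomicInteger();
        boolean divides = streamOf(counting(postorder, pulled)).anyMatch("/"::equals);
        // prints: true after pulling 6 elements
        System.out.println(divides + " after pulling " + pulled.get() + " elements");
    }
}
```

The stream stops asking for elements as soon as the division sign has been seen, so the laziness of the underlying iterator carries over to the stream.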

Seamless integration also means that you are at liberty to migrate any class in a project to Kotlin, while all the rest stays in Java. For example, having written some JUnit tests for the expected behavior of the tree iterator in Java, I could simply keep these tests to verify the Kotlin class, after I had thrown out the Java implementation as insufficient.

You can try this out yourself. The easiest way is to download and install IntelliJ IDEA Community Edition. Then follow the instructions under "How to try it" in the Kotlin blog post. I was able to create a Maven project with dependencies on kotlin-stdlib 1.1-M04 and kotlinx-coroutines 0.2-beta without problems.

Edit 2017-03-01: Today Kotlin 1.1 has been released. The generate method has been moved to the Kotlin standard library under the name of buildSequence. Thus to use it you don't have to depend on kotlinx.coroutines, just import that function from the package kotlin.coroutines.experimental. Here is the release announcement.

In closing, I should mention the Quasar library. I must admit I am not sure of the relation between Quasar and Kotlin coroutines. On the one hand, on the page cited, Quasar claims to "provide high-performance lightweight threads, Go-like channels, Erlang-like actors, and other asynchronous programming tools for Java and Kotlin", on the other hand, this very informative presentation from JVMLS 2016 says that Kotlin coroutines are not based on Quasar, and are in effect a much simpler construct. The distinction here is between stackless and stackful coroutines. However, as the Kotlin blog now says (my emphasis)
suspending functions are only allowed to make tail-calls to other suspending functions. This restriction will be lifted in the future.
this distinction may not be so relevant after all. There seems to be discussion at JetBrains whether to integrate more tightly with Quasar (see this issue). It will be interesting to see how this develops.

Addendum: Just in case you're wondering, no, Kotlin sequences are no lazier than Java streams. The following Kotlin version of the initial Java attempt also traverses the entire tree when trying to find the first division node:

 fun <E> postorderNodes(t: BinaryTree<E>): Sequence<Node<E>> =
            when(t) {
                is Node -> {
                    (sequenceOf(t.left, t.right).flatMap { postorderNodes(it) }
                    + sequenceOf(t))
                }
                else -> emptySequence()
            }

30 Sept 2016

Map Inversion

Recently, I had occasion to invert a map, i. e. exchange its keys and values. Having read several discussions on Stackoverflow (see here, here, and here), I decided to create my own utility, internally using Java streams. Besides simple maps, I wanted to be able to convert multimaps, both of the Google Guava and "standard Java" kind, the difference being that the entry set of a Guava multimap contains only elements with single values, whereas the entry set of a standard multimap contains an element with a Collection value.

As an example, here's the code for the latter (and most complicated) case:

/**
 * Inverts a map. The map may also be a non-Guava multimap (i. e. the entrySet elements may have collections as values).
 * 
 * @param map the map to be inverted
 * @param valueStreamer a function that converts an original map value to a stream. (In case of a multimap, may stream the original
 *            value collection.) Can also perform any other transformation on the original values to make them suitable as keys.
 * @param mapFactory a factory which returns a new empty Map into which the results will be inserted
 * @param collectionFactory a factory for the value type of the inverted map
 * @return the inverted map, where all keys that map to the same value are now mapped from that value
 */
public static <K, V, V1, C extends Collection<K>> Map<V, C> invertMultiMap(Map<K, V1> map, Function<V1, Stream<V>> valueStreamer,
  Supplier<Map<V, C>> mapFactory, Supplier<C> collectionFactory) {
    Map<V, C> inverted = map.entrySet().stream().flatMap(e ->
       valueStreamer.apply(e.getValue()).map(v -> new SimpleEntry<>(v, newCollection(e.getKey(), collectionFactory))))
      .collect(toMap(Entry::getKey, Entry::getValue, (a, b) -> {a.addAll(b);  return a;}, mapFactory));
    return inverted;
}

private static <T, E extends T, C extends Collection<T>> C newCollection(E elem, Supplier<C> s) {
    C collection = s.get();
    collection.add(elem);
    return collection;
}

You can find the complete code, covering some more cases and with some tests, in this Gist. The tests depend on Guava.
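For comparison, the simplest case (a plain map whose values become keys) needs far less machinery. The following is only a sketch of that case under my own assumptions, not the code from the Gist: the class name SimpleMapInversion is made up, and I have fixed the result to a TreeMap of Lists instead of taking factories as parameters.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

class SimpleMapInversion {
    // Inverts a plain map; keys that share a value are collected into one list.
    // The values must be Comparable because the result is a TreeMap.
    static <K, V extends Comparable<V>> Map<V, List<K>> invert(Map<K, V> map) {
        return map.entrySet().stream()
                .collect(Collectors.groupingBy(
                        Map.Entry::getValue,
                        TreeMap::new,
                        Collectors.mapping(Map.Entry::getKey, Collectors.toList())));
    }

    public static void main(String[] args) {
        Map<String, Integer> lengths = new LinkedHashMap<>();
        lengths.put("a", 1);
        lengths.put("b", 1);
        lengths.put("cc", 2);
        System.out.println(invert(lengths)); // prints {1=[a, b], 2=[cc]}
    }
}
```

The grouping collector does the merging that the toMap merge function performs in the multimap version above.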


3 May 2016

Concurrent Recursive Function Memoization

Recently on concurrency-interest, there has been a discussion triggered by an observation that Heinz Kabutz made about ConcurrentHashMap, namely that the following plausible-looking code is broken:

public static class FibonacciCached {
  private final Map<Integer, BigInteger> cache = new ConcurrentHashMap<>();
  public BigInteger fib(int n) {
    if (n <= 2) return BigInteger.ONE;
    return cache.computeIfAbsent(n, key -> fib(n - 1).add(fib(n - 2)));
  }
}

ConcurrentHashMap livelocks for n ≥ 16. The reason, explained by Doug Lea, is that if the computation attempts to update any mappings, the results of this operation are undefined, but may include IllegalStateExceptions, or in concurrent contexts, deadlock or livelock.  This is partially covered in the Map Javadocs: "The mapping function should not modify this map during computation."

Mainly in conversation between Viktor Klang and me, and based on an original idea of Viktor's, another approach was developed that, in contrast to computeIfAbsent, appears workable. The approach also harks back to two previous posts in this blog, namely the ones about memoization and trampolining. The idea improves on the memoization scheme by providing reusable, thread-safe components for memoizing recursive functions, and combines that with a trampoline to process the function calls in such a way as to eliminate stack overflows.

I'll present the idea with a few explanatory comments. If you want a step-by-step derivation, you can get that by reading the mail list archives. There are three parts to the code.
 
First, a general-purpose memoizer.
public static class ConcurrentTrampoliningMemoizer<T, R> {
  private static final Executor TRAMPOLINE = newSingleThreadExecutor(new ThreadFactoryBuilder().setDaemon(true).build());
  private final ConcurrentMap<T, CompletableFuture<R>> memo;

  public ConcurrentTrampoliningMemoizer(ConcurrentMap<T, CompletableFuture<R>> cache) {
    this.memo = cache;
  }

  public Function<T, CompletableFuture<R>> memoize(Function<T, CompletableFuture<R>> f) {
    return t -> {
      CompletableFuture<R> r = memo.get(t);
      if (r == null) {
        final CompletableFuture<R> compute = new CompletableFuture<>();
        r = memo.putIfAbsent(t, compute);
        if (r == null) {
          r = CompletableFuture.supplyAsync(() -> f.apply(t), TRAMPOLINE).thenCompose(Function.identity())
                .thenCompose(x -> {
                   compute.complete(x);
                   return compute;
                });
        }
      }
      return r;
    };
  }
}

Second, a class that uses the memoizer to compute Fibonacci numbers.

public static class Fibonacci {
  private static final CompletableFuture<BigInteger> ONE = completedFuture(BigInteger.ONE);
  private final Function<Integer, CompletableFuture<BigInteger>> fibMem;

  public Fibonacci(ConcurrentMap<Integer, CompletableFuture<BigInteger>> cache) {
    ConcurrentTrampoliningMemoizer<Integer, BigInteger> memoizer = new ConcurrentTrampoliningMemoizer<>(cache);
    fibMem = memoizer.memoize(this::fib);
  }

  public CompletableFuture<BigInteger> fib(int n) {
    if (n <= 2) return ONE;
    return fibMem.apply(n - 1).thenCompose(x ->
           fibMem.apply(n - 2).thenApply(y -> 
             x.add(y)));
  }
}

Third, any number of clients of a Fibonacci instance.

Fibonacci fibCached = new Fibonacci(new ConcurrentHashMap<>());
BigInteger result = fibCached.fib(550_000).join();

As the final usage example shows, the 550,000th Fibonacci number is about the largest I can get on my box, before all the cached BigInteger values cause an OutOfMemoryError. The Fibonacci instance may be shared among several threads. The whole thing seems to scale OK, as my tests with 2, 4, or 8 threads sharing the same instance indicate.

The formulation of the fib method should be familiar to you from the previous blog entry on memoization, except we're using a different monad here (CompletableFuture instead of StateMonad).

The most interesting thing of course is the memoizer. Here are a few observations:
  • On a cache miss a container (in the form of a CompletableFuture) is inserted into the cache which will later come to contain the value.
  • Different threads will always wind up using the same value instances, so the cache is suitable also for values that are supposed to be singletons. Null values are not allowed.
  • Only the thread that first has a cache miss calls the underlying function.
  • Concurrent readers/writers won't block each other (computeIfAbsent would do that for instance).
  • The trampolining happens because we can bounce off the queue in the executor to which supplyAsync submits tasks. The technique has been explained in a previous post. We are just inlining our two utility methods terminate and tailcall, and fibMem is the thunk. I find it interesting to see how we benefit from this even if there is no tail recursion.
I haven't shown a few obvious static imports. The ThreadFactoryBuilder is from Guava.
You might want to pass in an executor to the memoizer from the outside. This would facilitate separate testing of the implementation vs. performance/scalability. However, performance seems to degrade when using anything but a dedicated single thread, so it may not be worth much in practice.
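A sketch of what that variant might look like follows. It repeats the memoizer logic from above and only replaces the static TRAMPOLINE with a constructor parameter. The class names InjectableMemoizer and InjectableFibonacci and the helper daemonSingleThread are made up for this sketch; the helper uses a plain JDK thread factory instead of Guava's ThreadFactoryBuilder so that the example has no external dependencies.

```java
import java.math.BigInteger;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Function;

// Variant of the memoizer that receives its trampoline executor from the caller.
class InjectableMemoizer<T, R> {
    private final ConcurrentMap<T, CompletableFuture<R>> memo;
    private final Executor trampoline;

    InjectableMemoizer(ConcurrentMap<T, CompletableFuture<R>> cache, Executor trampoline) {
        this.memo = cache;
        this.trampoline = trampoline;
    }

    Function<T, CompletableFuture<R>> memoize(Function<T, CompletableFuture<R>> f) {
        return t -> {
            CompletableFuture<R> r = memo.get(t);
            if (r == null) {
                final CompletableFuture<R> compute = new CompletableFuture<>();
                r = memo.putIfAbsent(t, compute);
                if (r == null) {
                    // We won the race: run f on the injected executor, publish the result.
                    r = CompletableFuture.supplyAsync(() -> f.apply(t), trampoline)
                            .thenCompose(Function.identity())
                            .thenCompose(x -> {
                                compute.complete(x);
                                return compute;
                            });
                }
            }
            return r;
        };
    }
}

class InjectableFibonacci {
    private static final CompletableFuture<BigInteger> ONE =
            CompletableFuture.completedFuture(BigInteger.ONE);
    private final Function<Integer, CompletableFuture<BigInteger>> fibMem;

    InjectableFibonacci(Executor trampoline) {
        fibMem = new InjectableMemoizer<Integer, BigInteger>(new ConcurrentHashMap<>(), trampoline)
                .memoize(this::fib);
    }

    CompletableFuture<BigInteger> fib(int n) {
        if (n <= 2) return ONE;
        return fibMem.apply(n - 1).thenCompose(x ->
               fibMem.apply(n - 2).thenApply(y ->
                 x.add(y)));
    }

    // Hypothetical helper: a single daemon thread, so the JVM can exit without a shutdown call.
    static Executor daemonSingleThread() {
        return Executors.newSingleThreadExecutor(r -> {
            Thread t = new Thread(r, "trampoline");
            t.setDaemon(true);
            return t;
        });
    }
}
```

A client would then write new InjectableFibonacci(InjectableFibonacci.daemonSingleThread()).fib(50).join(), and a test could pass in a different executor of its choosing.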

I am sure there are very many ways to improve on the idea, but I'll leave it at that.

8 Feb 2016

Recursive Function Memoization with the State Monad

The State Monad

What is the State Monad? If you have been following this blog, you already know the answer. In fact, the parser that we have seen in the previous post is one embodiment of it. In general terms, the state monad is just a glorified function that takes a state and computes from that a result value and some new state. Apart from embedding this function, the state monad has a few bells and whistles that help in combining such functions together. In functional programming languages, the state monad is used to simulate external mutable state. The state monad's role is to pass down the state through the calls of the function. The state parameter disappears from the functions actually used.

You have already seen a practical application of this in our parser. In the parser, the state values have been character sequences. We were able to process the input sequence without ever seeing a corresponding parameter in our methods.
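To make the "glorified function" a bit more tangible before we get to the real thing, here is a deliberately tiny sketch. The names MiniState, StateDemo, FRESH_LABEL and TWO_LABELS are invented for this illustration, and java.util.Map.Entry serves as a poor man's pair of (value, state).

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Map.Entry;
import java.util.function.Function;

// A stateful computation: S -> (value of type T, new state of type S).
@FunctionalInterface
interface MiniState<T, S> {
    Entry<T, S> run(S state);

    // Sequencing: feed the value to f and thread the new state through.
    default <U> MiniState<U, S> bind(Function<? super T, MiniState<U, S>> f) {
        return s -> {
            Entry<T, S> r = run(s);
            return f.apply(r.getKey()).run(r.getValue());
        };
    }

    // A computation that yields the given value and leaves the state alone.
    static <T, S> MiniState<T, S> unit(T value) {
        return s -> new SimpleImmutableEntry<>(value, s);
    }
}

class StateDemo {
    // Hands out the current counter value as a label and increments the counter.
    static final MiniState<String, Integer> FRESH_LABEL =
            n -> new SimpleImmutableEntry<>("x" + n, n + 1);

    // Two labels in a row; the counter never shows up as a parameter.
    static final MiniState<String, Integer> TWO_LABELS =
            FRESH_LABEL.bind(a ->
            FRESH_LABEL.bind(b ->
            MiniState.unit(a + "," + b)));

    public static void main(String[] args) {
        Entry<String, Integer> result = TWO_LABELS.run(0);
        // prints: x0,x1 with final state 2
        System.out.println(result.getKey() + " with final state " + result.getValue());
    }
}
```

Note how the counter is passed from one FRESH_LABEL to the next by bind alone. The code that combines the two labels never mentions it.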

Besides parsing, another often-mentioned application of the state monad is recursive function memoization. In this case, the state will consist of pre-computed function values (a memo).  We'll see how to memoize a recursive function that computes the Fibonacci numbers without tacking an extra memo argument onto the function (such as a HashMap). It is the state monad that will keep a memo of already computed values between recursive calls.

You can google this stuff. People write about it often. For example, here are two articles that both deal with memoizing Fibonacci numbers, in a fashion similar to but not identical with mine: one in Java by Pierre-Yves Saumont and another in Scala by Tony Morris. But both miss a point I am going to make below, namely how to factor out the memoization logic from the function itself. (Saumont, by the way, is in the process of writing what looks to be an interesting book.)

I'll first show you how memoizing Fibonacci numbers is done, and define a generic memoize method that will enable memoization for any function. Only then will I show you how to derive the underlying state monad implementation through a few little refactorings of our SimpleParser. This way, you won't be inundated with details, having no idea what they lead up to. I'm hoping that the code is clear enough to grasp its intent even without knowing the internals. If it doesn't work for you, try reading this post from the end to the beginning.

And in the last section, I'll give you my evaluation of the whole thing.

Memoizing Fibonacci numbers

I expect you all know from your introduction to algorithms course that the following naive version will blow up in your face:

BigInteger fibNaive(int n) {
    if (n <= 2) {
        return BigInteger.ONE;
    }
    BigInteger x = fibNaive(n - 1);
    BigInteger y = fibNaive(n - 2);
    return x.add(y);
} 

One way to make the algorithm feasible is to remember previously computed values.

BigInteger fibMemo(int n, Map<Integer, BigInteger> memo) {
    BigInteger value = memo.get(n);
    if (value != null) {
       return value;
    }
    BigInteger x = fibMemo(n - 1, memo);
    BigInteger y = fibMemo(n - 2, memo);
    BigInteger z = x.add(y);
    memo.put(n, z);
    return z;
}

The above is written in a non-functional style with an explicit extra argument. (I'll name it the explicit version.) You would call it like this, with a memo already containing the first two Fibonacci numbers:

BigInteger fibMemo(int n) {
    if (n <= 2) {
        return BigInteger.ONE;
    }
    Map<Integer, BigInteger> memo = new HashMap<>();
    memo.put(1, BigInteger.ONE);
    memo.put(2, BigInteger.ONE);
    return fibMemo(n, memo);
}

Now wouldn't it be nice if we could get rid of that extra argument and the tedium of looking up and storing values, to derive a memoizing version that is closer to the naive version? And, yes, we can. Simply apply a new function, called memoize, to the original function.

StateMonad<BigInteger, Memo<Integer,BigInteger>> fib(int n) {
    return memoize(this::fib).apply(n-1).bind(x ->              // x = memoize (fib) (n-1)
           memoize(this::fib).apply(n-2).bind(y ->              // y = memoize (fib) (n-2)
           StateMonad.get(_s -> x.add(y))));                    // return x.add(y)
}

Basically, this is only a notational variant of the naive version, a bit difficult to read at first because Java lacks the syntactic sugar other languages have. I have tried to indicate such sugar in the inline comments. Having had time to get used to it, I find it doesn't read too badly: just ignore the clutter of "apply" and "bind", and remember that the assignment is notated at the end of the line instead of the beginning.

Well, we also had to change the return type a bit: it no longer represents a value, but a computation that yields this value (the function represented by the state monad). As you perhaps have guessed, the method bind above is the equivalent of then in our parser: the signature is the same, except that the continuation that we put in, and the combined computation that we get out, are not parsers, but instances of a more general StateMonad class that can have a memo as its state. As with the parser, the function inside the monad takes no arguments except the state; all other arguments are supplied beforehand.


My main point is that here we have a very clear separation of concerns: There is one entity memoize, concerned with caching computed values, another entity, the StateMonad, concerned with passing the cached values between function calls, and finally fib itself, that embodies the definition of the Fibonacci numbers and is concerned with making recursive calls and combining their results. This is in contrast to the explicit version and also to both posts quoted above, where these concerns are more intertwined.

For me, this is of the essence of functional programming, and therein lies its beauty: that it is so neat and modular on a small scale.

In order to get a value from the computation that is returned by fib, we need to evaluate it against some suitable initial state. The function that does this is called evalState. (And this you also already know from the parser, where it was called just eval.)

BigInteger fibMonadic(int n) {
    return n <= 2 ? BigInteger.ONE
                  : fib(n).evalState(new Memo<Integer,BigInteger>().insert(1, BigInteger.ONE).insert(2, BigInteger.ONE));
}

The class Memo is intended to be a functional equivalent to HashMap.

Implementing generic memoization

So, what is memoize? The signature of memoize is obvious. It takes a function of the type of fib and returns another function of the same type. We can abstract over the input and result value types Integer and BigInteger of fib and replace them with generic type parameters T and R.

Now think of what memoize has to do. It takes a function as its parameter. It must return a new function that
  • when given an argument, tries to look up a previously computed value from the memo
  • if it finds the value, returns a function that will yield this value
  • otherwise, applies the given function to the given argument, stores that result in the memo, and returns a function that will yield this value plus an updated state
As usual with functional maps, the memo shall return an Optional upon lookup, so that we can make the case distinction with Optional's methods map and orElse. Without further ado, here's the code:

static <T,R> Function<T, StateMonad<R, Memo<T,R>>> memoize(Function<T, StateMonad<R, Memo<T,R>>> f) {
    return t -> {   
                  StateMonad<Optional<R>, Memo<T,R>> s = StateMonad.get(m -> m.lookup(t)); // create a computation that would try to find a cached entry
                  return s.bind(v ->                                                       // perform the computation, call the result v and do: 
                              v.map(s::unit)                                               // if value is present, return a computation that will yield the value
                               .orElse(                                                    // otherwise  
                                     f.apply(t).bind(r ->                                  //   compute r = f(t)  
                                     s.mod(m -> m.insert(t, r)).map(_v -> r)               //   apply a function to the memo that stores r in it, set the value of s to r
                              )));                                                         //   return a computation that will yield r
                };
}


The appeal of memoize is that it is completely general and can be applied to any unary function.

What about functions that take multiple arguments? Well, in Java you cannot abstract over functions with arbitrary arity. I suggest taking a suitable class that can represent an n-ary tuple, like the one from the wonderful jOOλ library, and treating an n-ary function as a unary function of such a tuple. (They should have added tuples to Java. Every functional programming person is asking for them.)
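To show the idea without pulling in the monadic machinery or a tuple library, here is a small sketch that memoizes a binary function (binomial coefficients) by treating the argument pair as a single key. A List<Integer> stands in for a proper tuple class, the cache is a plain HashMap as in the explicit version above, and the class BinomialMemo is invented for this example.

```java
import java.math.BigInteger;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

class BinomialMemo {
    // The cache is keyed by the argument tuple (n, k), represented here as a List.
    private final Map<List<Integer>, BigInteger> memo = new HashMap<>();

    // Unary view of the binary function: takes the tuple and consults the cache first.
    private final Function<List<Integer>, BigInteger> memoized = args -> {
        BigInteger cached = memo.get(args);
        if (cached == null) {
            cached = compute(args.get(0), args.get(1));
            memo.put(args, cached);
        }
        return cached;
    };

    // Pascal's rule; the recursive calls go through the memoized entry point.
    private BigInteger compute(int n, int k) {
        if (k == 0 || k == n) return BigInteger.ONE;
        return binomial(n - 1, k - 1).add(binomial(n - 1, k));
    }

    // Binary facade; assumes 0 <= k <= n.
    BigInteger binomial(int n, int k) {
        return memoized.apply(Arrays.asList(n, k));
    }

    public static void main(String[] args) {
        System.out.println(new BinomialMemo().binomial(10, 3)); // prints 120
    }
}
```

Lists compare by content, so two calls with the same arguments hit the same cache entry. A dedicated tuple type would give the same effect with better type safety when the arguments have different types.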

Implementing the state monad

The state monad itself is not so interesting. As I have noted above, it should be self-explanatory to readers familiar with the simple parser. Here's how you might change the parser to derive the monad in a few simple steps, mainly consisting of renamings:
  • Abstract over the type of the state (replace CharSequence with a type variable S)
  • Rename a few methods to what's customary in the functional world (e. g. Haskell)
    • eval to evalState
    • parse to runState
    • then to bind
  • Delete many, many1, and orElse. (We don't need orElse in our example, but you might nevertheless decide to keep it. I guess that in some contexts, it might be useful for backtracking.)
  • Simplify evalState by removing the checking for unused input that was completely specific to parsing
  • You might also wish to rename a few variables. (I renamed inp for "input" to s for "state", p for "parser" to m for "monad" etc.)
All that remains is to define the new state-manipulating method mod and the static value-extracting utility get, which are used in the definition of memoize. That is very little work, and here is the complete resulting code. (I've left out the overloaded version of bind/then that we don't need for this example. Furthermore, the state monad is usually given a few more convenience methods, which we also don't need and which I'm not going to discuss here.)

@FunctionalInterface
public interface StateMonad<T,S> {
    
    default T evalState(S s) {
        Objects.requireNonNull(s);
        Tuple<T, S> t = runState(s);
        if (t.isEmpty()) {
            throw new IllegalArgumentException("Invalid state: " + s);
        }
        return t.first();
    }

    /** 
     * The type of the functional interface. A state monad is an abstraction over a function:
     * StateMonad<T,S> :: S -> Tuple<T, S> 
     * In other words, a state monad represents a stateful computation that derives a value and 
     * some new state from an input state.
     */
    abstract Tuple<T, S> runState(S s);
    
    // Monadic operations

    default <V> StateMonad<V,S> unit(V v) {
        return inp -> tuple(v, inp);
    }

    default <V> StateMonad<V,S> bind(Function<? super T, StateMonad<V,S>> f) {
        Objects.requireNonNull(f);
        StateMonad<V,S> m = s -> {
            Tuple<T, S> t = runState(s);
            if (t.isEmpty()) {
                return empty();
            }
            return f.apply(t.first()).runState(t.second());
        };
        return m;
    }
    
    default <V> StateMonad<V,S> map(Function<? super T, V> f) {
        Objects.requireNonNull(f);
        Function<V, StateMonad<V,S>> unit = x -> unit(x);
        return bind(unit.compose(f));
    }
       
    // Additional functions
       
    /** Modify the current state with the given function. */
    default StateMonad<T, S> mod(Function<S, S> f) {
        return s -> runState(f.apply(s)); 
    }

    /** Create a computation that extracts a value from the state with the given function. */
    static <V, S> StateMonad<V, S> get(Function<S, V> stateProjector) {
        return s -> tuple(stateProjector.apply(s), s);
    }
}

Of course we could have inlined that static call to the convenience function StateMonad.get in memoize (i. e. used the lambda directly), but that would have made the use of tuples visible in memoize.

Discussion

The trick I've shown you in this post is certainly appealing in its cleverness, and sometimes I quite admire it. But most of the time I feel that it is too clever by half. You've got to think practically. The functional version is (in Java) no better with regard to conciseness or readability than the explicit version. The effort to rewrite the naive definition to make use of memoization is also about the same in each case, because you have to make all those additional syntactic changes in addition to just throwing in a call to memoize. And as for performance, the monadic solution is about 5 times slower on my machine than the explicit one (for 100 ≤ n ≤ 1000).

Then there is the matter of representing state. In OO, state is most naturally represented as a member variable of a class. Here's equivalent coding that fits into this paradigm. (Well, almost equivalent. This solution is not thread-safe. On the other hand one might reuse the Memoizer instance.)

static class Memoizer<T, R> {
  private final Map<T, R> memo;

  public Memoizer(Map<T, R> memo) {
    this.memo = memo;
  }

  public Function<T, R> memoize(Function<T, R> f) {
    return t -> {
      R r = memo.get(t);
      if (r == null) {
        r = f.apply(t);
        memo.put(t, r);
      }
      return r;
    };
  }
}


Memoizer<Integer, BigInteger> m = new Memoizer<>(new HashMap<>());

BigInteger fib(int n) {
  if (n <= 2) return BigInteger.ONE;
  return m.memoize(this::fib).apply(n - 1).add(
         m.memoize(this::fib).apply(n - 2));
} 
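As for the thread-safety caveat: here's a minimal sketch of how the Memoizer might be made safe for concurrent use, assuming it is acceptable that two threads occasionally compute the same value twice. The names ConcurrentMemoizer and ConcurrentFib are made up for this illustration. Note that the sketch avoids ConcurrentHashMap.computeIfAbsent on purpose, because the memoized function calls back into the same map recursively, and the mapping function passed to computeIfAbsent must not update other mappings of that map.

```java
import java.math.BigInteger;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/** Hypothetical thread-safe variant of the Memoizer above. */
class ConcurrentMemoizer<T, R> {
    private final Map<T, R> memo = new ConcurrentHashMap<>();

    Function<T, R> memoize(Function<T, R> f) {
        return t -> {
            R r = memo.get(t);
            if (r == null) {
                // Compute outside of any map operation: f may re-enter this memoizer.
                r = f.apply(t);
                // If another thread stored a value in the meantime, use that one.
                R previous = memo.putIfAbsent(t, r);
                if (previous != null) {
                    r = previous;
                }
            }
            return r;
        };
    }
}

class ConcurrentFib {
    static final ConcurrentMemoizer<Integer, BigInteger> M = new ConcurrentMemoizer<>();

    static BigInteger fib(int n) {
        if (n <= 2) return BigInteger.ONE;
        return M.memoize(ConcurrentFib::fib).apply(n - 1)
                .add(M.memoize(ConcurrentFib::fib).apply(n - 2));
    }
}
```

The price is a possible duplicate computation under contention, which is harmless here because fib is a pure function.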

And of course there is a simple, idiomatic, non-recursive, constant-space solution with mutable variables:

BigInteger fib(int n) {
 if (n <= 2) return BigInteger.ONE;
 
 BigInteger x = BigInteger.ONE;
 BigInteger y = BigInteger.ONE;
 BigInteger z = null;
 for (int i = 3; i <= n; i++) {
  z = x.add(y);
  x = y;
  y = z;
 }
 return z;
} 

So be careful when you take over an idea from functional programming. You may be able to retain some of the beauty, but in other respects you won't always win. In fact, sometimes you'll lose. Don't overdo it. Functional thinking may offer you a different - and useful - perspective on a problem, as I hope to have shown with the parser and some other posts. But a functional solution is not guaranteed to be per se more readable, more maintainable, or more efficient than a good old vanilla object-oriented or imperative solution. Often, it will be none of these. Java is not a functional programming language!

Addendum 

For the sake of completeness, here's the Memo class that I have used for demo purposes. It's a mutable map masquerading as a functional data structure. Very bad. Don't copy it. You may put trace statements in these two methods to see in what order elements are computed and retrieved, and that they are indeed computed only once. The method names are the same as in TotallyLazy's PersistentMap.

class Memo<K, V> extends HashMap<K, V> {

    public Optional<V> lookup(K key) {
        Optional<V> value = Optional.ofNullable(get(key));
        return value;
    }

    public Memo<K, V> insert(K key, V value) {
        put(key, value);
        return this;
    }
}

10 Jan 2016

DSLs with Graham Hutton's Functional Parser

One of the best introductions to functional programming is Graham Hutton's excellent book Programming in Haskell. Although it is primarily a course on the Haskell programming language, it also provides an introduction to the underpinnings of functional programming in general. It is very readable, aimed at the undergraduate level. I think it should be on the shelf of every student of functional programming.

I had thought of writing a book review, but there's really no need for it. You can find in-depth technical reviews on the book's site, and more: You can download slides for each chapter, watch video lectures, and get the complete Haskell source code of all the examples. That is for free, without even buying the book - a tremendous service.

Instead of a review, I'll show you what I did for an exercise. I ported the functional parser that Hutton presents in chapter 8 of his book to Java. My aim is to demonstrate the following:
  • Mastering functional idioms will help you to create powerful interfaces supporting a fluent programming style in Java
  • Of particular use in abstracting from tedious glue code are monadic interfaces, i. e. basically those that have a flatMap method
  • There is a special way of encapsulating internal state that is sometimes useful
The parser itself is a recursive descent parser of the kind you wouldn't necessarily use for large and complex grammars. (You'd perhaps use a compiler-compiler or parser generator tool like ANTLR or Xtext/Xtend.) However, you might use it to quickly spin off a little DSL or some demo code. The real point is not the parser, but the approach to application design.

In the following, some of the text is taken verbatim from Hutton's book, some is my own. I shall not bother with quotes and attributions, because I am not claiming any original ideas here. Remember: this is just an exercise in porting Haskell to Java. However, I'll try to be helpful by interspersing a few extra examples and comments.

Nevertheless, this post is going to be a bit tedious, in the way textbooks are. I apologize in advance. If you're feeling impatient, perhaps you might want to scroll down and look at the final example, a little parser for arithmetic expressions. I summarize my experiences and give some advice in the concluding section.

Basic parsers


So what is a parser? It is a function that takes an input character sequence and produces a pair comprising a result value and an output character sequence, the unconsumed rest of the input. Let's abstract over the type of the result value and define a corresponding functional interface:

@FunctionalInterface
public interface SimpleParser<T> {   
    abstract Tuple<T, CharSequence> parse(CharSequence cs);
}

Hutton's parser in fact returns a list, to include the possibility of returning more than one result if the input sequence can be parsed in more than one way. For simplicity, however, he only considers parsers that return at most one result. For that reason, I have opted for the simpler signature. It follows that our parser cannot handle ambiguity. You might want to do the generalization yourself.

I shall assume the existence of a special empty tuple, which has no components, as an analogue to the empty list, which signifies that the input cannot be parsed at all. (Otherwise, the type Tuple is just what you might expect.)
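Since the Tuple type is only assumed in this post, here's a minimal sketch of what it might look like. Everything in it is my assumption rather than the class actually used: the factory names tuple and empty and the accessors first, second and isEmpty are chosen to match the calls in the code that follows, and a single shared instance stands in for the empty tuple.

```java
import java.util.NoSuchElementException;
import java.util.Objects;

/** Hypothetical minimal Tuple; one shared instance plays the role of the empty tuple. */
final class Tuple<A, B> {
    private static final Tuple<?, ?> EMPTY = new Tuple<>(null, null);

    private final A first;
    private final B second;

    private Tuple(A first, B second) {
        this.first = first;
        this.second = second;
    }

    /** Creates a regular tuple; null components are rejected so they cannot be confused with EMPTY. */
    static <A, B> Tuple<A, B> tuple(A first, B second) {
        return new Tuple<>(Objects.requireNonNull(first), Objects.requireNonNull(second));
    }

    /** Returns the empty tuple, which signals that the input could not be parsed. */
    @SuppressWarnings("unchecked")
    static <A, B> Tuple<A, B> empty() {
        return (Tuple<A, B>) EMPTY;
    }

    boolean isEmpty() {
        return this == EMPTY;
    }

    A first() {
        if (isEmpty()) throw new NoSuchElementException("empty tuple");
        return first;
    }

    B second() {
        if (isEmpty()) throw new NoSuchElementException("empty tuple");
        return second;
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof Tuple)) return false;
        Tuple<?, ?> t = (Tuple<?, ?>) o;
        return Objects.equals(first, t.first) && Objects.equals(second, t.second);
    }

    @Override
    public int hashCode() {
        return Objects.hash(first, second);
    }

    @Override
    public String toString() {
        return isEmpty() ? "()" : "(" + first + ", " + second + ")";
    }
}
```

The equals method is what lets a test compare a parsing result against an expected tuple with assertEquals.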

We will now define some basic parsers, and then see how we can arrange them in a hierarchy of ever higher-level derived parsers. First, here's a method that produces a basic parser that always succeeds with the given result value without consuming any input:

default <V> SimpleParser<V> unit(V v) {
    return inp -> tuple(v, inp);
}

In Haskell, this method would be called "return", but of course that's not possible in Java. We'll also provide output, a static analogue to unit, and failure, a way to produce a parser that always fails regardless of the input. In fact, we'll define quite a few static factory methods for creating parsers. These I have collected in a separate utility class SimpleParserPrimitives. The parser interface itself will contain only the (non-static) methods to combine and manipulate parsers. It will turn out that these methods are quite general. You might re-use them, and each time you might well provide a different set of parsing "primitives" for a different parsing purpose. The primitives I will be discussing are geared towards parsing programming languages.

public static <T> SimpleParser<T> output(T v) {
    return inp -> tuple(v, inp); 
}

public static <T> SimpleParser<T> failure() {
    return _inp -> empty();
} 

The method empty returns the empty tuple, representing failure. Our final basic parser is item, which fails if the input string is empty, and succeeds with the first character as the result value otherwise.

public static SimpleParser<Character> item() {
    return inp -> inp.length() == 0 ? empty() : tuple(inp.charAt(0), inp.subSequence(1, inp.length()));
}

Sequencing


Suppose we want to parse an arithmetic expression like "5 + 4 * 3". One way to decompose the problem is to create a parser for the numeral "5", and a parser for the rest of the expression "+ 4 * 3" (which we would again decompose into parsers for the symbol "+" and the expression "4 * 3", and so on). We then look for a way to combine these parsers into a parser for the whole expression. The combined parser shall execute the parsing steps one after the other. In the process of doing that we need to make decisions based on the result of the previous step. One such decision is whether and how to continue parsing the rest of the input. If we want to continue, we must pass on the intermediate result to the following step.

Here's the signature of such a method that takes a description of the future parsing steps, and returns a new parser that will execute the current step combined with those future ones:

default <V> SimpleParser<V> then(Function<? super T, SimpleParser<V>> f)

The returned parser fails if the application of the current parser to the input string fails, and otherwise applies the function f to the result value to give a second parser q, which is then applied to the output string to give the final result. Thus, the result value produced by this parser is made directly available for processing in the remaining steps by binding it to a lambda variable in a function describing those remaining steps. For example, the expression

p.then(x -> output(x));

may be read as: Parse the input using the parser p, and then, calling the result "x", apply output() to x.

The method then corresponds to flatMap on Java's Stream and Optional, and to CompletableFuture.thenCompose(); these are the three "monadic" interfaces built into Java 8. These methods all have the property of chaining computations together (in the case of streams like nested for-loops, as I have discussed in several earlier posts). Here's the implementation. You can see how each line corresponds to a part of the description we have given above.

default <V> SimpleParser<V> then(Function<? super T, SimpleParser<V>> f) {
    Objects.requireNonNull(f);
    SimpleParser<V> pf = inp -> {
        Tuple<T, CharSequence> t = parse(inp);
        if (t.isEmpty()) {
            return empty();
        }
        return f.apply(t.first()).parse(t.second());
    };
    return pf;
}
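To make the comparison with the JDK's monadic interfaces concrete, here's a small sketch that uses nothing but Optional.flatMap. The class FlatMapAnalogy and its helper methods are invented for this illustration. An empty Optional plays the role of a failed parse: as soon as one step yields empty, the remaining steps are skipped, just as then returns empty() without ever calling f.

```java
import java.util.Optional;

class FlatMapAnalogy {
    /** Parses a decimal integer, yielding empty instead of throwing on bad input. */
    static Optional<Integer> parseInt(String s) {
        try {
            return Optional.of(Integer.parseInt(s.trim()));
        } catch (NumberFormatException e) {
            return Optional.empty();
        }
    }

    /** Integer division that fails (empty) on division by zero. */
    static Optional<Integer> divide(int a, int b) {
        return b == 0 ? Optional.empty() : Optional.of(a / b);
    }

    /** Chains three steps; each intermediate result is bound to a lambda variable. */
    static Optional<Integer> quotient(String x, String y) {
        return parseInt(x).flatMap(a ->
               parseInt(y).flatMap(b ->
               divide(a, b)));
    }
}
```

For instance, quotient("12", "4") yields Optional.of(3), while quotient("12", "0") and quotient("x", "4") both yield Optional.empty().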

Let's look at what we can do with the building blocks we have assembled so far. Here's a contrived example. The parser p consumes three characters, discards the second, and returns the distance between the first and the third:

@Test 
public void composition() {
    SimpleParser<Integer> p = 
            item().then( x -> 
            item().then( y ->
            item().then( z ->
            output(z - x)
            )));
    assertEquals(tuple(2,"def"), p.parse("abcdef"));
    assertEquals(empty(), p.parse("ab"));
}

Notice how the formatting goes some way to hide the nesting of the constructs. I find the code more readable that way, because conceptually, we are just executing parsing steps in sequence. Note, too, how we are ignoring the variable y in the last two steps. Let's introduce a variant on then that allows us to dispense with this superfluous variable.

default <V> SimpleParser<V> then(Supplier<SimpleParser<V>> p) {
    Objects.requireNonNull(p);
    return then(_v -> p.get());
}

The above parser can now be rewritten without the unused variable:

    SimpleParser<Integer> p = 
            item().then( x -> 
            item().then( () ->
            item().then( z ->
            output(z - x)
            )));

Note the use of Supplier to enforce laziness: One could not simply overload then to accept a parser instance, because that would mean that the argument parser p is constructed immediately by the caller, whereas with a Supplier it is constructed only after the current parser has completed successfully. But at least we no longer have to litter our code with dummy variables. Haskell has nice syntactic sugar ("do-notation") that allows just leaving out the unused variables, and also hides the nesting of the calls (which is what I have tried to achieve by code formatting).

Also note how there's no reference to the input that is being parsed in the definition of our little parser. There is no explicit variable of type CharSequence being threaded through the calls, or anything like that. All the mechanics of state handling are neatly encapsulated inside then.

Derived parsers


We can now start combining our basic parsers in meaningful ways to derive more parsing primitives. First, we build upon item to define a parser that accepts a certain character only if it satisfies some predicate. And with the help of the utility methods in java.lang.Character we can go on to define parsers for all kinds of character classes, like digits, lower and upper case letters, alphanumeric characters or whitespace. For example:

/** A parser that returns the next character in the input if it satisfies a predicate. */
public static SimpleParser<Character> sat(Predicate<Character> pred) {
    Objects.requireNonNull(pred);
    return item().then(x -> pred.test(x) ? output(x) : failure());
}

/** A parser that returns the next character in the input if it is the given character. */
public static SimpleParser<Character> character(Character c) {
    return sat(x -> c.equals(x));
}

/** A parser that returns the next character in the input if it is a digit. */
public static SimpleParser<Character> digit() {
    return sat(Character::isDigit);
}

/** A parser that returns the next character in the input if it is a whitespace character. */
public static SimpleParser<Character> whitespace() {
    return sat(Character::isWhitespace);
}

With character we can recursively recognize strings.

public static SimpleParser<String> string(String s) {
    Objects.requireNonNull(s);
    return s.isEmpty() ? output("") : character(s.charAt(0)).then(() -> string(s.substring(1)).then(() -> output(s)));
}

The base case states that the empty string can always be parsed. The recursive case states that a non-empty string can be parsed by parsing the first character, parsing the remaining characters, and returning the entire string as the result value.

Choice and repetition


There are more ways to combine parsers than simple sequencing. We often need to express alternative ways of parsing one string ("choice"), and we need repetition, where a string is parsed by the repeated application of the same parser.

/** Applies the given parser if this parser fails. */
default SimpleParser<T> orElse(SimpleParser<T> p) {
    Objects.requireNonNull(p);
    return inp -> {
        Tuple<T, CharSequence> out = parse(inp);
        return out.isEmpty() ? p.parse(inp) : out;
    };
}

/** Applies this parser at least once or not at all. */
default SimpleParser<List<T>> many() {
    return many1().orElse(unit(emptyList()));
}

/** Applies this parser once and then zero or more times. */
default SimpleParser<List<T>> many1() {
    return then(c -> many().then(cs -> unit(cons(c, cs))));
}

(cons is a utility that adds an element at the front of a persistent list. Think of copying the cs to a new list and adding c at index 0.)
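The cons utility itself is not shown in this post. Here's a minimal sketch of the copying approach just described, assuming a plain java.util.List rather than a true persistent list; the class name ListUtils is made up.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class ListUtils {
    /** Returns a new unmodifiable list with head in front of the elements of tail; tail is left untouched. */
    static <T> List<T> cons(T head, List<? extends T> tail) {
        List<T> result = new ArrayList<>(tail.size() + 1);
        result.add(head);
        result.addAll(tail);
        return Collections.unmodifiableList(result);
    }
}
```

Copying makes each cons linear in the length of the list, so building a list of n elements this way is quadratic. A real persistent list would share the tail instead.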

Note that many and many1 are mutually recursive. In particular, the definition for many states that a parser can either be applied at least once or not at all, while the definition for many1 states that it can be applied once and then zero or more times. The result of applying a parser repeatedly is a list of the collected results. At some point, we may want to convert this list back to a single value, e. g. we'd perhaps like to convert a sequence of digits to an integer value.

public static SimpleParser<Integer> nat() {
    return digit().many1()
                  .then(cs -> output(ParserUtils.charsToString(cs)))
                  .then(s -> output(Integer.parseInt(s)));
}

This looks a bit clumsy; after all, we are only transforming the result value inside a parser without actually consuming any input. It would be more natural if we had a method map to do that. (BTW, Hutton does not include such a method.) Fortunately, map is easy to define in terms of unit and then. In fact, it is axiomatic that this definition should be valid, otherwise we have made a mistake in the definition of our flatMap-analogue then (probably not with unit, since it is so simple).

default <V> SimpleParser<V> map(Function<? super T, V> f) {
    Objects.requireNonNull(f);
    Function<V, SimpleParser<V>> unit = this::unit;
    return then(unit.compose(f));
}

A new parser is "wrapped" around the output of the function f, as opposed to then, which already takes a parser-bearing function and does not wrap it with another parser. The analogy with map and flatMap on Stream is obvious. Our code for parsing a natural number now becomes more readable. And we can now also use method references.

public static SimpleParser<Integer> nat() {
    return digit().many1()
                  .map(ParserUtils::charsToString)
                  .map(Integer::parseInt);
}
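The relationship between map, unit and the flatMap analogue can be tried out on a JDK type. The following sketch derives a map for Optional from flatMap and Optional::of in exactly the way our parser's map is derived from then and unit. MapLaw and mapViaFlatMap are illustrative names only.

```java
import java.util.Optional;
import java.util.function.Function;

class MapLaw {
    /**
     * map expressed through flatMap and the unit Optional::of.
     * Agrees with Optional.map as long as f never returns null
     * (Optional.of rejects null, whereas Optional.map turns it into empty).
     */
    static <T, V> Optional<V> mapViaFlatMap(Optional<T> o, Function<? super T, V> f) {
        Function<V, Optional<V>> unit = Optional::of;
        return o.flatMap(unit.compose(f));
    }
}
```

For a non-null-returning f, mapViaFlatMap(o, f) and o.map(f) give equal results, for present and empty inputs alike.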

Tokenization


In order to handle spacing, we introduce a new parser token that ignores any space before and after applying some given parser to an input string.


/** A parser that collapses as many whitespace characters as possible. */
public static SimpleParser<String> space() {
    return whitespace().many().map(s -> "");
}

/** A parser that ignores any space around a given parser for a token. */ 
public static <T> SimpleParser<T> token(SimpleParser<T> p) {
    Objects.requireNonNull(p);
    return space().then(() -> p.then(v -> space().then(() -> output(v)))); 
}

We may be interested in a diversity of tokens, according to our target language, such as tokens for variable names and other identifiers, reserved words, operators, numbers etc. For example, we can define

/** A parser for a natural number token. */ 
public static SimpleParser<Integer> natural() {
    return token(nat());
} 

/** A parser for a symbol token. */
public static SimpleParser<String> symbol(String sym) {
    Objects.requireNonNull(sym);
    return token(string(sym));
}

This completes our overview of the parsing primitives. Let's see a few applications.

Lists of numbers 


The following test defines a parser for a non-empty list of natural numbers that ignores spacing around tokens. This definition states that such a list begins with an opening square bracket and a natural number, followed by zero or more commas and natural numbers, and concludes with a closing square bracket. Note that the parser should only succeed if a complete list in precisely this format is consumed.

@Test
public void tokenization() {
    SimpleParser<List<Integer>> numList = 
                            symbol("[").then(() ->
                            natural().then(n ->
                                symbol(",").then(() -> natural()).many().then(ns ->
                            symbol("]").then(() ->
                            output(cons(n,ns))))));

    assertEquals(tuple(asList(1,2,3),""), numList.parse("[1,2,3]"));
    assertEquals(tuple(asList(11,22,33),""), numList.parse(" [11,  22, 33 ] "));
    assertEquals(tuple(asList(11,22,33),"abc"), numList.parse(" [11,  22, 33 ] abc"));
    assertEquals(empty(), numList.parse("[1,  2,]"));
}

You get the idea. There is one shortcoming: Detecting failure on account of ill-formed or unused input involves examining the output tuples in caller code. This may not always be what we want.

Representing failure


What would one expect to happen when some input is not accepted? In Java one would certainly expect an exception to be thrown. However, that is not so easy here. Of course, we cannot make then or failure throw an exception instead of returning empty(), because that would abort all possible search paths, not just the one we're currently on. I also consider it bad practice to put exception handling code for regular control flow into orElse. The best thing to do, I suppose, is to add a top-level method to the parser that examines the parsing result.

default T eval(CharSequence cs) {
    Objects.requireNonNull(cs);
    Tuple<T, CharSequence> t = parse(cs);
    if (t.isEmpty()) {
        throw new IllegalArgumentException("Invalid input: " + cs);
    }
    CharSequence unusedInput = t.second();
    if (unusedInput.length() > 0) {
        throw new IllegalArgumentException("Unused input: " + unusedInput);
    }
    return t.first();
}

This way, we have also successfully hidden the internal use of tuples from callers of the parser, which is a good thing.

Arithmetic expressions


Hutton presents an extended example for parsing arithmetic expressions. The example serves to show how easy it can be to create a little DSL.

Consider a simple form of arithmetic expressions built up from natural numbers using addition, multiplication, and parentheses. We assume that addition and multiplication associate to the right, and that multiplication has higher priority than addition. Here's an unambiguous BNF grammar for such expressions.

expr ::= term (+ expr | nil)
term ::= factor (∗ term | nil)
factor ::= natural | (expr)
natural ::= 0 | 1 | 2 | · · ·


It is now straightforward to translate this grammar into a parser for expressions, by simply rewriting the rules using our parsing primitives. We choose to have the parser itself evaluate the expression to its integer value, rather than returning some form of tree.

SimpleParser<Integer> expr() {
    return term().then(t ->
                    symbol("+").then(() -> expr().then(e -> output(t + e)))
                    .orElse(output(t)));
}
    
SimpleParser<Integer> term() {
    return factor().then(f ->
                      symbol("*").then(() -> term().then(t -> output(f * t)))
                      .orElse(output(f)));
}
    
SimpleParser<Integer> factor() {
    return natural()
           .orElse(
               symbol("(").then(() ->
               expr().then(e ->
               symbol(")").then(() ->
               output(e)))));
}

I like the look of that. Once you get used to the formalism, writing the parser is as easy as writing the grammar in the first place. Let's test the whole thing:

@Test
public void multiplicationTakesPrecedence() {
    assertThat( expr().eval("2+3*5"), is(17) );
}
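If you'd like to run the expression parser without collecting the pieces from the sections above, here's a condensed single-file sketch. It takes a few liberties, all of them my assumptions: a tiny Result class replaces Tuple, the interface is called MiniParser, the primitives live in a class Parsers, symbol takes a single char instead of a String, and digits are converted by a helper toInt that ignores overflow.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;

/** Hypothetical stand-in for Tuple<T, CharSequence>; the shared EMPTY instance signals failure. */
final class Result<T> {
    private static final Result<?> EMPTY = new Result<>(null, null);
    final T value;
    final CharSequence rest;
    Result(T value, CharSequence rest) { this.value = value; this.rest = rest; }
    @SuppressWarnings("unchecked")
    static <T> Result<T> empty() { return (Result<T>) EMPTY; }
    boolean isEmpty() { return this == EMPTY; }
}

interface MiniParser<T> {
    Result<T> parse(CharSequence cs);

    default <V> MiniParser<V> then(Function<? super T, MiniParser<V>> f) {
        return inp -> {
            Result<T> r = parse(inp);
            if (r.isEmpty()) return Result.empty();   // failure stops the sequence
            return f.apply(r.value).parse(r.rest);    // continue on the unconsumed input
        };
    }
    default <V> MiniParser<V> then(Supplier<MiniParser<V>> p) { return then(x -> p.get()); }
    default <V> MiniParser<V> map(Function<? super T, V> f) { return then(x -> Parsers.output(f.apply(x))); }
    default MiniParser<T> orElse(MiniParser<T> p) {
        return inp -> {
            Result<T> r = parse(inp);
            return r.isEmpty() ? p.parse(inp) : r;    // retry the same input with p
        };
    }
    default MiniParser<List<T>> many() { return many1().orElse(Parsers.output(Collections.emptyList())); }
    default MiniParser<List<T>> many1() {
        return then(c -> many().then(cs -> Parsers.output(Parsers.cons(c, cs))));
    }
    default T eval(CharSequence cs) {
        Result<T> r = parse(cs);
        if (r.isEmpty()) throw new IllegalArgumentException("Invalid input: " + cs);
        if (r.rest.length() > 0) throw new IllegalArgumentException("Unused input: " + r.rest);
        return r.value;
    }
}

final class Parsers {
    static <T> MiniParser<T> output(T v) { return inp -> new Result<>(v, inp); }
    static <T> MiniParser<T> failure() { return inp -> Result.empty(); }
    static <T> List<T> cons(T head, List<T> tail) {
        List<T> list = new ArrayList<>(Collections.singletonList(head));
        list.addAll(tail);
        return list;
    }
    static MiniParser<Character> item() {
        return inp -> inp.length() == 0
                ? Result.<Character>empty()
                : new Result<Character>(inp.charAt(0), inp.subSequence(1, inp.length()));
    }
    static MiniParser<Character> sat(Predicate<Character> pred) {
        return item().then(c -> pred.test(c) ? output(c) : Parsers.<Character>failure());
    }
    static MiniParser<String> space() { return sat(Character::isWhitespace).many().map(cs -> ""); }
    static <T> MiniParser<T> token(MiniParser<T> p) {
        return space().then(() -> p.then(v -> space().then(() -> output(v))));
    }
    static MiniParser<Character> symbol(char c) { return token(sat(x -> x == c)); }
    static MiniParser<Integer> natural() {
        return token(sat(Character::isDigit).many1().map(Parsers::toInt));
    }
    static int toInt(List<Character> digits) {
        int n = 0;
        for (char d : digits) n = 10 * n + Character.digit(d, 10); // overflow ignored in this sketch
        return n;
    }
    // expr ::= term (+ expr | nil)
    static MiniParser<Integer> expr() {
        return term().then(t ->
               symbol('+').then(() -> expr().then(e -> output(t + e)))
               .orElse(output(t)));
    }
    // term ::= factor (* term | nil)
    static MiniParser<Integer> term() {
        return factor().then(f ->
               symbol('*').then(() -> term().then(t -> output(f * t)))
               .orElse(output(f)));
    }
    // factor ::= natural | (expr)
    static MiniParser<Integer> factor() {
        return natural().orElse(
               symbol('(').then(() -> expr().then(e -> symbol(')').then(() -> output(e)))));
    }
}
```

With these definitions, Parsers.expr().eval("2+3*5") evaluates to 17 and Parsers.expr().eval(" (2 + 3) * 5 ") to 25, while an incomplete input such as "2+" is rejected with an IllegalArgumentException because of the unused "+".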

Laziness


Notice how we have created functions at each step that only get evaluated when we actually use them. Only then is a new parser instantiated. The construction of the parser is interleaved with its execution. As mentioned above, it is important in this context that then always takes a function, if necessary even a no-argument function, that yields a parser, never directly a parser instance. If that were different, we might actually have an infinite loop, e. g. expr() → term() → factor() → expr(), where the last call to expr() would be that in the "orElse"-clause, because Java as a strict language evaluates method arguments immediately.
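The difference between passing a value and passing a function that yields the value can be observed with the JDK's own Optional, which offers both variants. In this sketch (EagerVsLazy and build are made-up names, and Optional.orElse is the JDK method, not our parser's), build stands in for a parser factory such as expr() and merely counts how often it is called.

```java
import java.util.Optional;

class EagerVsLazy {
    static int constructed = 0;

    /** Stands in for an expensive or recursive factory method such as expr(). */
    static String build() {
        constructed++;
        return "fallback";
    }

    /** Eager: Java evaluates the argument before Optional.orElse is even entered. */
    static int eager() {
        constructed = 0;
        Optional.of("present").orElse(build());
        return constructed;
    }

    /** Lazy: the Supplier is only invoked if the Optional turns out to be empty. */
    static int lazy() {
        constructed = 0;
        Optional.of("present").orElseGet(EagerVsLazy::build);
        return constructed;
    }
}
```

Here eager() reports one call to build and lazy() reports none. If build were a method that called itself through such an eager argument, the eager variant would never terminate.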

Recursive lambdas


You may wonder why I have written the expression parser in terms of methods, instead of declaring instance variables for the functions expr, term, and factor. The reason is that lambdas capture local variables by value when they are created, and we would therefore get a compile-time error: "Cannot reference a field before it is defined". The only way to write recursive lambdas is to make everything static and use fully qualified names (see the Lambda FAQ).
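Here's a minimal sketch of the static, fully qualified variant, using a factorial instead of a parser to keep it short. The class name RecursiveLambdas is made up.

```java
import java.util.function.IntUnaryOperator;

class RecursiveLambdas {
    // The simple name "factorial" inside its own initializer is rejected by the compiler;
    // the qualified name RecursiveLambdas.factorial is accepted.
    static final IntUnaryOperator factorial =
            n -> n <= 1 ? 1 : n * RecursiveLambdas.factorial.applyAsInt(n - 1);
}
```

By the time the lambda body runs, the static field has long been initialized, so the recursive call finds the function it needs.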


Lessons learned


I feel I have learned something from this exercise, especially about the feel of functional programming in Java. A few particular lessons have been:
  • "Monadic" interfaces are easy to design once you get your head around flatMap and its relatives. This method in its various guises, in Stream, Optional, CompletableFuture or our SimpleParser, is a powerful tool you need to understand.
  • When you need to add functionality to your lambdas, default methods in interfaces are the way to go. You can instantiate your functions with lambda expressions, and then combine them fluently in a variety of interesting ways. This is also what the JDK does, e. g. the Function interface has a default method compose for functional composition.
  • Our parser interface has rather many (in fact, eight) default methods. All these default methods are strongly coupled. It is generally accepted that this is bad when designing for inheritance (cf. what Josh Bloch has to say in Effective Java, 2nd ed., items 17 and 18). Perhaps that is a reason to feel a bit queasy about the approach. Which leads me to the next point.
  • Functional interfaces need not be designed for inheritance, so that in contrast to "ordinary" interfaces, strong coupling is acceptable: I have adopted a practice according to which I do not extend functional interfaces and implement them only with lambdas. (It's nevertheless OK to extend another functional interface in order to specialize a generic type variable or add a few little static utilities. That is what UnaryOperator and BinaryOperator do, and they are the only two of the 57 functional interfaces in the JDK to extend another interface.)
  • Functions are great for re-usability. Being implemented via interfaces, they are usually public. As a consequence, data structures tend to leak all over the place. (We had to take special care to hide the use of tuples.) I guess that this is part of the functional mind-set: While in the OO-world, we encapsulate and hide as much as we can (with reason), this is less important in the functional world.
  • Don't clutter your interfaces with too many static methods. It is sometimes difficult to decide where a static method should go (e. g. output and failure are so essential, should they have been put into the parser interface?).
  • Java is really not a functional programming language. Stay with idiomatic Java as far as possible. Some of the things that I found to be missing in Java to make it more "functional" are
    • persistent collections, tuples (you can pull-in third party libraries)
    • some important syntactic sugar like do-notation, inconspicuous unused variables (perhaps underscore only) etc. Nested constructs can quickly become difficult to read.
    • the ability to write non-static recursive lambdas.
    • a nice syntax for functional application.
  • Beware of Java's eagerness, prefer lazy idioms. 
  • Don't kill your backtracking with exceptions.
  • Do not throw checked exceptions from your code. Sooner or later (probably sooner) it's going to be a pain when you want it to work with lambdas, because all the most important standard functional methods in the JDK - like Function#apply() - do not have a throws-clause.
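To illustrate the last point: a method that declares a checked exception cannot be passed where a Function is expected, so callers end up writing an adapter of the following kind. This is only a sketch; IOFunction, unchecked and parseStrict are hypothetical names, not JDK API, while UncheckedIOException is the JDK's own wrapper for IOException.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.Function;

class Unchecked {
    /** Hypothetical functional interface whose method may throw IOException. */
    @FunctionalInterface
    interface IOFunction<T, R> {
        R apply(T t) throws IOException;
    }

    /** Adapts an IOFunction to a plain Function by rethrowing as UncheckedIOException. */
    static <T, R> Function<T, R> unchecked(IOFunction<T, R> f) {
        return t -> {
            try {
                return f.apply(t);
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        };
    }

    /** Example of a method with a checked exception in its signature. */
    static Integer parseStrict(String s) throws IOException {
        if (s.isEmpty()) throw new IOException("empty input");
        return Integer.parseInt(s);
    }
}
```

With the adapter, Unchecked.unchecked(Unchecked::parseStrict) can be handed to map or then like any other function. Without it, the method reference Unchecked::parseStrict does not compile as a Function at all.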
In my IDE (Eclipse Mars) I have been annoyed by two things in particular. I don't know if other IDEs offer a better experience.
  • Debugging lambdas is a pain. (As an exercise, you might replace equals with == in the character parser, pretend you didn't know that you made this silly mistake, and try to fix it.) Unit testing becomes all the more important. Here are a few negative experiences I have had:
    • Being restricted to line breakpoints is bad when you're interested only in a part of those nice one-liners. 
    • In order to set breakpoints inside a lambda, you need blocks, again bad for elegance. 
    • When you're inside the lambda, you may see synthetic members named arg$1 etc. without being really able to tell what they are. It helps a bit when you at least show the declared type in the Variables view. 
    • Stepping into a lambda is not really an option, because of all the synthetic stuff in between. 
  • Formatting lambdas is not always easy (I often do it manually, case by case, because I don't always like what Eclipse does automatically). 
And do have a look at Hutton's book!