21 Oct 2016

Lazy tree walking made easy with Kotlin coroutines

Let's imagine a simple unbalanced binary tree structure, in which an abstract BinaryTree<E> is either a concrete Node labelled with a value attribute of type E and having a left and right subtree, or a concrete Empty unlabelled tree. With trees, one very common requirement is to traverse the nodes in some appropriate order (preorder, inorder, or postorder).
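For readers who like something concrete, here is a minimal Java sketch of such a tree, together with the three traversal orders done eagerly by plain recursion. The class names follow the text. The exact shape is an assumption for illustration only: an interface with two nested classes, final fields, and the hypothetical `leaf`, `exampleTree` and `labels` helpers. The example tree is the expression (3 - 1) * (4 / 2 + 5 * 6) that this post uses throughout.

```java
import java.util.function.Consumer;

public class TreeOrders {
    // Assumed shape of the tree: an interface with two implementations.
    interface BinaryTree<E> {}

    static final class Empty<E> implements BinaryTree<E> {}

    static final class Node<E> implements BinaryTree<E> {
        final E value;
        final BinaryTree<E> left, right;

        Node(E value, BinaryTree<E> left, BinaryTree<E> right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    static <E> Node<E> leaf(E value) {
        return new Node<>(value, new Empty<E>(), new Empty<E>());
    }

    // The running example: (3 - 1) * (4 / 2 + 5 * 6)
    static Node<String> exampleTree() {
        return new Node<>("*",
                new Node<>("-", leaf("3"), leaf("1")),
                new Node<>("+",
                        new Node<>("/", leaf("4"), leaf("2")),
                        new Node<>("*", leaf("5"), leaf("6"))));
    }

    // The three orders differ only in when a node is visited relative to its subtrees.
    static <E> void preorder(BinaryTree<E> t, Consumer<Node<E>> visit) {
        if (t instanceof Node) {
            Node<E> n = (Node<E>) t;
            visit.accept(n);
            preorder(n.left, visit);
            preorder(n.right, visit);
        }
    }

    static <E> void inorder(BinaryTree<E> t, Consumer<Node<E>> visit) {
        if (t instanceof Node) {
            Node<E> n = (Node<E>) t;
            inorder(n.left, visit);
            visit.accept(n);
            inorder(n.right, visit);
        }
    }

    static <E> void postorder(BinaryTree<E> t, Consumer<Node<E>> visit) {
        if (t instanceof Node) {
            Node<E> n = (Node<E>) t;
            postorder(n.left, visit);
            postorder(n.right, visit);
            visit.accept(n);
        }
    }

    // Concatenates the labels visited by the chosen traversal ("pre", "in" or "post").
    static String labels(BinaryTree<String> t, String order) {
        StringBuilder sb = new StringBuilder();
        Consumer<Node<String>> visit = n -> sb.append(n.value);
        switch (order) {
            case "pre": preorder(t, visit); break;
            case "in": inorder(t, visit); break;
            default: postorder(t, visit); break;
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        Node<String> tree = exampleTree();
        System.out.println(labels(tree, "pre"));  // *-31+/42*56
        System.out.println(labels(tree, "in"));   // 3-1*4/2+5*6
        System.out.println(labels(tree, "post")); // 31-42/56*+*
    }
}
```

These recursive versions are simple, but they always walk the whole tree before the caller sees anything.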

In this post, we are going to consider how to do that lazily. That is, we want to be able to short-circuit the traversal, visiting only as many nodes in the tree as we need to see. Moreover, we'd prefer a single-threaded solution.
It's actually not trivial to do that in Java. If you want an iterator, you have to keep track of a lot of state yourself. (Cf. these explanations, or look at the source code of java.util.TreeMap.) You might give up on single-threadedness and have two threads communicating over a blocking queue, but that also entails some complexity, especially when it comes to cancelling the producer task after short-circuiting.
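To give an idea of the bookkeeping involved, here is a sketch of a lazy postorder iterator that manages its own stack instead of relying on recursion. It redefines minimal tree classes so that it compiles on its own. `PostorderIterator`, `firstLabels` and the other helpers are hypothetical names, not taken from java.util or any other library.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.NoSuchElementException;

public class PostorderIteratorSketch {
    // Minimal tree classes (assumed shape), repeated so that this block compiles on its own.
    interface BinaryTree<E> {}

    static final class Empty<E> implements BinaryTree<E> {}

    static final class Node<E> implements BinaryTree<E> {
        final E value;
        final BinaryTree<E> left, right;

        Node(E value, BinaryTree<E> left, BinaryTree<E> right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    // Lazy postorder iterator. The stack always holds a path from the root downwards;
    // its top is the next node to return, and the nodes below it are waiting ancestors.
    static final class PostorderIterator<E> implements Iterator<Node<E>> {
        private final Deque<Node<E>> stack = new ArrayDeque<>();

        PostorderIterator(BinaryTree<E> root) {
            descend(root);
        }

        // Push the path to the first postorder node of t: go left if possible, else right.
        private void descend(BinaryTree<E> t) {
            while (t instanceof Node) {
                Node<E> n = (Node<E>) t;
                stack.push(n);
                t = (n.left instanceof Node) ? n.left : n.right;
            }
        }

        @Override
        public boolean hasNext() {
            return !stack.isEmpty();
        }

        @Override
        public Node<E> next() {
            if (stack.isEmpty()) throw new NoSuchElementException();
            Node<E> n = stack.pop();
            // A finished left child: the parent's right subtree comes before the parent itself.
            Node<E> parent = stack.peek();
            if (parent != null && parent.left == n) descend(parent.right);
            return n;
        }
    }

    static <E> Node<E> leaf(E value) {
        return new Node<>(value, new Empty<E>(), new Empty<E>());
    }

    // The running example: (3 - 1) * (4 / 2 + 5 * 6)
    static Node<String> exampleTree() {
        return new Node<>("*",
                new Node<>("-", leaf("3"), leaf("1")),
                new Node<>("+",
                        new Node<>("/", leaf("4"), leaf("2")),
                        new Node<>("*", leaf("5"), leaf("6"))));
    }

    // Pulls at most `limit` nodes from the iterator and concatenates their labels.
    static String firstLabels(BinaryTree<String> t, int limit) {
        StringBuilder sb = new StringBuilder();
        Iterator<Node<String>> it = new PostorderIterator<>(t);
        for (int i = 0; i < limit && it.hasNext(); i++) {
            sb.append(it.next().value);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(firstLabels(exampleTree(), 3));   // 31-
        System.out.println(firstLabels(exampleTree(), 100)); // 31-42/56*+*
    }
}
```

Even this fairly compact version depends on a non-obvious invariant about the stack, and preorder or inorder would each need a variant of their own.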

And even when turning to stream-based processing instead of iteration, it doesn't quite work out as expected. Here's a proposed solution in a hypothetical BinaryTreeStreamer class:

// Uses: import static java.util.stream.Stream.concat;
/**
 * Supplies a postorder stream of the nodes in the given tree.
 */
public static <E> Stream<Node<E>> postorderNodes(BinaryTree<E> t) {
    return t.match(
                empty -> Stream.<Node<E>> empty(),
                node -> concat(Stream.of(node.left, node.right).flatMap(BinaryTreeStreamer::postorderNodes),
                               Stream.of(node)));
}

The corresponding inorder or preorder traversals would be similar. The technique for structural pattern matching with the method BinaryTree#match() goes back to Alonzo Church and is explained in more detail on Rúnar Bjarnason's blog. Basically, each subclass of BinaryTree applies the appropriate function to itself, i.e. Empty invokes the first argument of match, and Node the second.
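The Church-style match can be sketched in Java roughly as follows. This is an assumed reconstruction for illustration, not the actual code behind BinaryTreeStreamer: each subclass simply applies the function meant for it. The hypothetical `size` and `height` functions show that structural recursion then needs no instanceof checks or casts.

```java
import java.util.function.Function;

public class MatchSketch {
    // Assumed signature: match takes one function per case and returns that function's result.
    interface BinaryTree<E> {
        <R> R match(Function<Empty<E>, R> ifEmpty, Function<Node<E>, R> ifNode);
    }

    static final class Empty<E> implements BinaryTree<E> {
        @Override
        public <R> R match(Function<Empty<E>, R> ifEmpty, Function<Node<E>, R> ifNode) {
            return ifEmpty.apply(this); // Empty picks the first function
        }
    }

    static final class Node<E> implements BinaryTree<E> {
        final E value;
        final BinaryTree<E> left, right;

        Node(E value, BinaryTree<E> left, BinaryTree<E> right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }

        @Override
        public <R> R match(Function<Empty<E>, R> ifEmpty, Function<Node<E>, R> ifNode) {
            return ifNode.apply(this); // Node picks the second function
        }
    }

    // Structural recursion through match: no instanceof checks or casts needed.
    static <E> int size(BinaryTree<E> t) {
        return t.match(empty -> 0, node -> 1 + size(node.left) + size(node.right));
    }

    static <E> int height(BinaryTree<E> t) {
        return t.match(empty -> 0, node -> 1 + Math.max(height(node.left), height(node.right)));
    }

    static <E> Node<E> leaf(E value) {
        return new Node<>(value, new Empty<E>(), new Empty<E>());
    }

    // The running example: (3 - 1) * (4 / 2 + 5 * 6)
    static Node<String> exampleTree() {
        return new Node<>("*",
                new Node<>("-", leaf("3"), leaf("1")),
                new Node<>("+",
                        new Node<>("/", leaf("4"), leaf("2")),
                        new Node<>("*", leaf("5"), leaf("6"))));
    }

    public static void main(String[] args) {
        Node<String> tree = exampleTree();
        System.out.println(size(tree));   // 11
        System.out.println(height(tree)); // 4
    }
}
```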

The stream-based postorderNodes code above looks quite reasonable, but unfortunately it is broken by the same JDK feature/bug that I mentioned over a year ago in this post. Nested flatMap calls just aren't lazy enough, and they break short-circuiting. Suppose we construct a tree representing the expression (3 - 1) * (4 / 2 + 5 * 6). I'll use this as an example throughout this article. Then we start streaming, with the aim of finding out whether the expression contains a node for division:

boolean divides = BinaryTreeStreamer.postorderNodes(tree).filter(node -> node.value.equals("/")).findAny().isPresent();

which leads the code to traverse the entire tree, all the way down to nodes 5 and 6, even though the division node comes sixth in postorder (after 3, 1, -, 4 and 2). And anyway, we are no closer to an iterator-based solution.

Now in Python, things look quite different. The thing is, Python has coroutines, which Python calls generators. Here's how Wikipedia defines coroutines:
Coroutines are computer program components that generalize subroutines for nonpreemptive multitasking, by allowing multiple entry points for suspending and resuming execution at certain locations.
In Python you can say "yield" anywhere in a coroutine, and the calling coroutine starts up again with the value that was yielded. Coroutines are like functions that return multiple times and keep their state (which includes the values of local variables plus the instruction pointer), so they can resume from where they yielded. This means they have multiple entry points as well. So here's a Python solution to our problem, using a defaultdict as the tree implementation, with value, left and right as the dictionary keys. (For a presentation that goes a bit beyond our simple example, see e.g. Matt Bone's page.)

from collections import defaultdict

tree = lambda: defaultdict(tree)

def postorder(tree):
    if not tree:
        return  # an empty subtree yields nothing
    for x in postorder(tree['left']):
        yield x
    for x in postorder(tree['right']):
        yield x
    yield tree

One thing to note is that we must re-yield each value from the sub-generators. Without that, the recursive calls would still dutifully yield all required nodes, but only to their own nested generators. We have to pass the values on one level up. That corresponds to the successive flat-mapping in our Java code. Here's how we can enumerate the first few nodes of our example tree in postorder. I also show a bit of the Python tree encoding.

expr = tree()
expr['value'] = '*'
expr['left']['value'] = '-'
expr['left']['left']['value'] = '3'
expr['left']['right']['value'] = '1'
node = postorder(expr)
print(next(node)['value'])  # 3
print(next(node)['value'])  # 1
print(next(node)['value'])  # -
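The re-yielding in the Python generator above can be mimicked by hand in Java. The sketch below uses a swapped-in technique, not anything from this post: it builds the postorder walk from lazily concatenated iterators. Each subtree's iterator is created only once the previous part is exhausted, much as a generator body only resumes when asked for the next value. `LazyConcat`, `take` and the `calls` counter are hypothetical names; the counter merely makes the laziness observable.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Supplier;

public class LazyConcatSketch {
    // Minimal tree classes (assumed shape), repeated so that this block compiles on its own.
    interface BinaryTree<E> {}

    static final class Empty<E> implements BinaryTree<E> {}

    static final class Node<E> implements BinaryTree<E> {
        final E value;
        final BinaryTree<E> left, right;

        Node(E value, BinaryTree<E> left, BinaryTree<E> right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    // Chains iterators; each part is created only when the previous one is exhausted.
    static final class LazyConcat<T> implements Iterator<T> {
        private final Iterator<Supplier<Iterator<T>>> parts;
        private Iterator<T> current = Collections.emptyIterator();

        LazyConcat(List<Supplier<Iterator<T>>> parts) {
            this.parts = parts.iterator();
        }

        @Override
        public boolean hasNext() {
            while (!current.hasNext() && parts.hasNext()) {
                current = parts.next().get();
            }
            return current.hasNext();
        }

        @Override
        public T next() {
            if (!hasNext()) throw new NoSuchElementException();
            return current.next();
        }
    }

    static int calls = 0; // number of postorder invocations, to make the laziness visible

    // Left subtree, then right subtree, then the node: the same three steps as the generator.
    static <E> Iterator<Node<E>> postorder(BinaryTree<E> t) {
        calls++;
        if (!(t instanceof Node)) return Collections.emptyIterator();
        Node<E> n = (Node<E>) t;
        List<Supplier<Iterator<Node<E>>>> parts = new ArrayList<>();
        parts.add(() -> postorder(n.left));
        parts.add(() -> postorder(n.right));
        parts.add(() -> Collections.singletonList(n).iterator());
        return new LazyConcat<>(parts);
    }

    static <E> Node<E> leaf(E value) {
        return new Node<>(value, new Empty<E>(), new Empty<E>());
    }

    // The running example: (3 - 1) * (4 / 2 + 5 * 6)
    static Node<String> exampleTree() {
        return new Node<>("*",
                new Node<>("-", leaf("3"), leaf("1")),
                new Node<>("+",
                        new Node<>("/", leaf("4"), leaf("2")),
                        new Node<>("*", leaf("5"), leaf("6"))));
    }

    // Resets the counter, then pulls at most k nodes and concatenates their labels.
    static String take(BinaryTree<String> t, int k) {
        calls = 0;
        StringBuilder sb = new StringBuilder();
        Iterator<Node<String>> it = postorder(t);
        for (int i = 0; i < k && it.hasNext(); i++) {
            sb.append(it.next().value);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(take(exampleTree(), 3) + " after " + calls + " calls");   // 31- after 8 calls
        System.out.println(take(exampleTree(), 100) + " after " + calls + " calls"); // 31-42/56*+* after 23 calls
    }
}
```

Taking the first three nodes invokes postorder 8 times, whereas a complete walk needs 23 invocations (11 nodes plus 12 empty subtrees). As with nested Python generators, every element is handed up through one iterator per tree level, which is the price of this style.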

Many other languages besides Python have coroutines, or something similar, if not in the language itself, then at least as a library. Java does not have them, so I started looking for other JVM languages that do. There aren't many, but I found a library for Scala. However, Scala is not a language that Java developers readily embrace. So I was all the more pleased to learn that coroutines will be a feature of Kotlin 1.1, which is now in the early access phase.

I already knew about Kotlin. It is fun. It's very much like Java, only better in many respects; it is 100% interoperable with Java, and, being developed by JetBrains, it has great tool support from IntelliJ IDEA out of the box. It really has a host of nice features. You might want to check out the following articles, which piqued my interest in the language.
  1. 10 Features I Wish Java Would Steal From the Kotlin Language
  2. Kotlin for Java Developers: 10 Features You Will Love About Kotlin
  3. Why Kotlin is my next programming language 
Kotlin seems to have gained popularity especially among Android developers, partly because of its small footprint and partly because, up to the release of Android Nougat, people had been stuck with Java 6 on Android.

The current milestone is Kotlin 1.1-M04. In Kotlin, unlike Python, yield is not a keyword, but a library function. Kotlin as a language has a more basic notion of suspendable computation. You can read all about it in this informal overview. All that talk about suspending lambdas and coroutine builders and whatnot may seem somewhat intimidating, but fortunately there are already libraries that build upon standard Kotlin to provide functions that are easy to understand and use.

One such library is kotlinx-coroutines. It contains a function generate that takes a coroutine parameter. Inside that coroutine we can use yield to suspend and return a value, just as in Python. The values are returned as a Kotlin Sequence object. Let me show you my attempt to port the above Python code to Kotlin. I tried to do a faithful translation, almost line by line. That turned out to be pretty straightforward, which I can only explain by guessing that the designers of Kotlin's generate must have been influenced by Python.

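fun <E> postorderNodes(t : BinaryTree<E>): Iterable<Node<E>> = generate<Node<E>> {
    when(t) {
        is Empty -> {}
        is Node -> {
            postorderNodes(t.left).forEach { yield(it) }   // re-yield the left subtree
            postorderNodes(t.right).forEach { yield(it) }  // re-yield the right subtree
            yield(t)                                       // then the node itself
        }
    }
}.asIterable()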

We can seamlessly use Kotlin classes in Java code and vice versa. However, instead of the Kotlin sequence, java.util.Iterable is much nicer to work with on the Java side of things. Fortunately, as shown above, we can simply call asIterable() on the sequence to effect the conversion. So, let BinaryTreeWalker be a Kotlin class that contains the Iterable-returning generator method, and look at some Java code exercising that method:

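Iterable<Node<String>> postfix = new BinaryTreeWalker().postorderNodes(expr);
Iterator<Node<String>> postfixIterator = postfix.iterator();
// Pull just the first three nodes of the postorder walk
for (int i = 0; i < 3; i++) {
    System.out.print(postfixIterator.next().value);
}
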
For our example tree, this will correctly print the sequence "31-" and visit no further nodes.

Stream processing comes for free, as you can easily obtain a Stream from the Iterable with StreamSupport.stream(postfix.spliterator(), false). That gives you a Java stream based on an iterator over a sequence backed by a Kotlin generator. On that stream, the little snippet looking for a division node would work as well, only now it would really be lazy.
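Whether the resulting stream really pulls nodes on demand is easy to check. In the sketch below, a hypothetical counting Iterable stands in for the Kotlin-backed one: it simply hands out the postorder labels of the example tree (3 - 1) * (4 / 2 + 5 * 6) and records how many were requested. The conversion itself uses StreamSupport.stream with the Iterable's default spliterator, as described above.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.stream.StreamSupport;

public class LazyStreamSketch {
    static int pulled = 0; // how many elements the stream has requested so far

    // Hypothetical stand-in for the Iterable returned by BinaryTreeWalker.postorderNodes:
    // it yields the postorder labels of the example tree one at a time and counts them.
    static Iterable<String> countingPostorder() {
        List<String> labels = Arrays.asList("3", "1", "-", "4", "2", "/", "5", "6", "*", "+", "*");
        return () -> new Iterator<String>() {
            private final Iterator<String> it = labels.iterator();

            @Override
            public boolean hasNext() {
                return it.hasNext();
            }

            @Override
            public String next() {
                pulled++;
                return it.next();
            }
        };
    }

    // The division query, run on a sequential stream built from the Iterable.
    static boolean containsDivision() {
        pulled = 0;
        return StreamSupport.stream(countingPostorder().spliterator(), false)
                .filter(label -> label.equals("/"))
                .findAny()
                .isPresent();
    }

    public static void main(String[] args) {
        System.out.println(containsDivision() + " after " + pulled + " elements"); // true after 6 elements
    }
}
```

On a sequential stream, findAny stops pulling as soon as the division node, the sixth in postorder, has been found, so the counter ends at 6 rather than 11.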

Seamless integration also means that you are at liberty to migrate any class in a project to Kotlin, while all the rest stays in Java. For example, having written some JUnit tests for the expected behavior of the tree iterator in Java, I could simply keep these tests to verify the Kotlin class, after I had thrown out the Java implementation as insufficient.

You can try this out yourself. The easiest way is to download and install IntelliJ IDEA Community Edition. Then follow the instructions under "How to try it" in the Kotlin blog post. I was able to create a Maven project with dependencies on kotlin-stdlib 1.1-M04 and kotlinx-coroutines 0.2-beta without problems.

Edit 2017-03-01: Kotlin 1.1 was released today. The generate function has been moved to the Kotlin standard library under the name buildSequence. Thus, to use it, you don't have to depend on kotlinx.coroutines; just import that function from the package kotlin.coroutines.experimental. Here is the release announcement.

In closing, I should mention the Quasar library. I must admit I am not sure of the relation between Quasar and Kotlin coroutines. On the one hand, on the page cited, Quasar claims to "provide high-performance lightweight threads, Go-like channels, Erlang-like actors, and other asynchronous programming tools for Java and Kotlin", on the other hand, this very informative presentation from JVMLS 2016 says that Kotlin coroutines are not based on Quasar, and are in effect a much simpler construct. The distinction here is between stackless and stackful coroutines. However, as the Kotlin blog now says (my emphasis)
suspending functions are only allowed to make tail-calls to other suspending functions. This restriction will be lifted in the future.
this distinction may not be so relevant after all. There seems to be some discussion at JetBrains about whether to integrate more tightly with Quasar (see this issue). It will be interesting to see how this develops.

Addendum: Just in case you're wondering: no, Kotlin sequences are no lazier than Java streams. The following Kotlin version of the initial Java attempt also traverses the entire tree when trying to find the first division node:

fun <E> postorderNodes(t: BinaryTree<E>): Sequence<Node<E>> =
    when (t) {
        is Node -> {
            (sequenceOf(t.left, t.right).flatMap { postorderNodes(it) }
            + sequenceOf(t))
        }
        else -> emptySequence()
    }