SyntaxHighlighter

10 Sept 2015

Cartesian Products with Kleisli Composition


Over at the JOOQ blog, they have written about How to use Java 8 Functional Programming to Generate an Alphabetic Sequence. By "alphabetic sequence" they mean the (n-order) Cartesian product of the letters in the alphabet. For example, the third-order Cartesian product of the alphabet would be
AAA, AAB, AAC, ..., ZZZ
The author of that post (it is unsigned, but I believe it was written by Lukas Eder) claims that
the Java 8 Stream API does not offer enough functionality for this task.
I will show that this claim is wrong, and that there is in fact a pure Java 8 solution. This solution brings out the structure of the problem rather nicely. It is general in the sense that it uses a general function combinator well-known in functional programming. It relies only on the Java 8 Streams API without needing any constructs from jOOλ.

The first building block of the solution is provided by Misha in his answer to this Stackoverflow question. Here's Misha's function that will create a stream of combinations of its first argument with the given stream, where the mode of combination can be externally specified as well. (In the following I will not only call this method crossJoin, but also the function that it returns.)
<T, U, R> Function<T, Stream<R>> crossJoin(Supplier<Stream<? extends U>> streamSupplier, BiFunction<? super T, ? super U, ? extends R> combiner) {
    return t -> streamSupplier.get().map(u -> combiner.apply(t, u));
}
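
To make the combinator concrete, here is a small sketch of my own that applies it by hand. The class name CrossJoinDemo and the method thirdOrder are made up for illustration; only crossJoin itself is taken from Misha's answer. With the order hard-coded, the third-order product is just two flatMap steps:

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class CrossJoinDemo {

    // Misha's combinator, copied unchanged from above
    static <T, U, R> Function<T, Stream<R>> crossJoin(
            Supplier<Stream<? extends U>> streamSupplier,
            BiFunction<? super T, ? super U, ? extends R> combiner) {
        return t -> streamSupplier.get().map(u -> combiner.apply(t, u));
    }

    // hypothetical helper: third-order product with the number of steps fixed at compile time
    static List<String> thirdOrder(List<String> letters) {
        Function<String, Stream<String>> step = crossJoin(letters::stream, String::concat);
        return letters.stream()
                .flatMap(step)   // all two-letter strings
                .flatMap(step)   // all three-letter strings
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(thirdOrder(Arrays.asList("A", "B")));
    }
}
```

For the two-letter alphabet A, B this yields the eight strings from AAA to BBB in lexicographic order.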

Now you can do the nth-order Cartesian product by applying flatMap with crossJoin a fixed number of times. A question that goes unanswered in that Stackoverflow discussion is at the core of the JOOQ post: How do we deal with the situation when the required number of applications is not fixed, but supplied at runtime as a parameter? That is, how do we implement

List<String> cartesian(int n, Collection<String> coll)
Of course, you might do it recursively. But that would rather count as supporting the claim that Java streams were not sufficient! Fortunately, the implementation becomes easy once we remember that the repeated application of functions is equivalent to a single application of the composition of these functions. In our case, we actually want sequential composition, which is just the flip of regular composition. Looking at the signature

crossJoin :: a -> Stream b
it may be hard to see at first how we might compose functions of this signature sequentially. Actually, however, there is guaranteed to be such a combinator. It is called "Kleisli composition", written ">=>", and has the following signature:

>=> :: (a -> Stream b) -> (b -> Stream c) -> (a -> Stream c)

In fact, >=> has a completely general definition for all so-called "monadic types", of which Stream in Java is one. What this means is that we could even abstract over Stream in the above signature, and implement sequential composition for all those types in the same way. This is easier in some languages (Haskell, Scala) than in others (Java). This article by Yoav Abrahami gives an instructive example of what you can do with Kleisli composition in Scala. It demonstrates the same trick we are about to perform, namely replacing recursion with a form of functional composition.

For the sake of example, instead of generalizing we will simplify somewhat. In our case we are dealing only with streams of strings. Here's the Java implementation of >=> for this special case:
BinaryOperator<Function<String,Stream<String>>> kleisli = (f,g) -> s -> f.apply(s).flatMap(g);
The Stream type, being "monadic", obeys certain axioms, one of which states that >=> is associative. So we may use it in a reduction. The required identity function is s -> Stream.of(s). You can easily convince yourself of that by considering the first-order product, which is of course just the original list, which we'll get by lifting every element to Stream and flat-mapping back down, without invoking crossJoin at all. (Of course, you can also prove it by plugging the identity into the definition of >=>.)
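
If you would rather check than prove, the following sketch tries the identity and associativity laws on two sample functions. The class KleisliLaws, the functions F and G, and the helper run are my own inventions for this test; it is a spot check on one input, not a proof.

```java
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class KleisliLaws {

    // >=> specialized to String, as defined above
    static final BinaryOperator<Function<String, Stream<String>>> KLEISLI =
            (f, g) -> s -> f.apply(s).flatMap(g);

    // the candidate identity element for >=>
    static final Function<String, Stream<String>> UNIT = s -> Stream.of(s);

    // two arbitrary sample functions (assumptions, chosen only for the check)
    static final Function<String, Stream<String>> F = s -> Stream.of(s + "a", s + "b");
    static final Function<String, Stream<String>> G = s -> Stream.of(s + "x", s + "y");

    // apply a composed function to an input and materialize the result for comparison
    static List<String> run(Function<String, Stream<String>> h, String input) {
        return h.apply(input).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // left and right identity
        System.out.println(run(KLEISLI.apply(UNIT, F), "_").equals(run(F, "_")));
        System.out.println(run(KLEISLI.apply(F, UNIT), "_").equals(run(F, "_")));
        // associativity: (F >=> G) >=> F versus F >=> (G >=> F)
        System.out.println(run(KLEISLI.apply(KLEISLI.apply(F, G), F), "_")
                .equals(run(KLEISLI.apply(F, KLEISLI.apply(G, F)), "_")));
    }
}
```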

Let's put it all together: The idea is to create a stream of as many crossJoin functions as we need, reduce this stream to a single function in memory by composing them together, and finally apply the entire function chain in one fell swoop.

List<String> cartesian(int n, Collection<String> coll) {
    return coll.stream()
           .flatMap( IntStream.range(1, n).boxed()
                     .map(_any -> crossJoin(coll::stream, String::concat)) // create (n-1) appropriate crossJoin instances
                     .reduce(s -> Stream.of(s), kleisli)                   // compose them sequentially with >=>
                    )                                                      // flatMap the stream with the entire function chain
           .collect(toList());
}
The following call will give you a list of all three-letter alphabetic sequences:
cartesian(3, alphabet)
A variation on the above is when you do not want to multiply a collection n times with itself, but with n other collections (all of the same type). Instead of passing in n, you might pass in the sequence of those other collections, and instead of streaming an integer range you might stream that sequence, creating a crossJoin function for each element, like this: Stream.of(colls).map(c -> crossJoin(c::stream, String::concat))
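
Spelled out, that variation might look like the sketch below. The class name MixedCartesian and the two-parameter signature of cartesian are assumptions of mine; I also take the other collections as a List rather than an array, so the sequence is streamed with others.stream() instead of Stream.of(colls).

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class MixedCartesian {

    static <T, U, R> Function<T, Stream<R>> crossJoin(
            Supplier<Stream<? extends U>> streamSupplier,
            BiFunction<? super T, ? super U, ? extends R> combiner) {
        return t -> streamSupplier.get().map(u -> combiner.apply(t, u));
    }

    static final BinaryOperator<Function<String, Stream<String>>> KLEISLI =
            (f, g) -> s -> f.apply(s).flatMap(g);

    // hypothetical signature: multiply 'first' with each of 'others', in the given order
    static List<String> cartesian(Collection<String> first,
                                  List<? extends Collection<String>> others) {
        Function<String, Stream<String>> chain = others.stream()
                // one crossJoin per collection; the explicit type arguments keep inference simple
                .map(c -> MixedCartesian.<String, String, String>crossJoin(c::stream, String::concat))
                // compose them sequentially; with no other collections this is just the identity
                .reduce(s -> Stream.of(s), KLEISLI);
        return first.stream().flatMap(chain).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(cartesian(
                Arrays.asList("A", "B"),
                Arrays.asList(Arrays.asList("1", "2"), Arrays.asList("x"))));
    }
}
```

Passing an empty list of other collections returns the first collection unchanged, which is the first-order case again.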

For good measure, here's how you may construct the list of letters alphabet, if you do not wish to write them down severally with Arrays.asList:
List<String> alphabet = IntStream.rangeClosed('A','Z')
                        .mapToObj(c -> String.valueOf((char) c))
                        .collect(toList());    

[Addendum]:
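  1. It's important that we pass a Supplier<Stream> to crossJoin, not the stream itself, because a Java stream can be traversed only once and every application of crossJoin needs a fresh one.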
  2. You might also be interested in Tagir Valeev's answer to a related question on Stackoverflow.
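
Regarding the first point, a minimal sketch shows what goes wrong when the same Stream instance is reused. The class StreamReuse and its two methods are made up for this demonstration.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Stream;

class StreamReuse {

    // traverses one and the same Stream instance twice; the second attempt is rejected
    static boolean secondTraversalFails(List<String> letters) {
        Stream<String> once = letters.stream();
        once.count();
        try {
            once.count();
            return false;
        } catch (IllegalStateException expected) {
            // "stream has already been operated upon or closed"
            return true;
        }
    }

    // with a Supplier, every traversal gets its own fresh stream
    static long traverseTwice(Supplier<Stream<String>> supplier) {
        return supplier.get().count() + supplier.get().count();
    }

    public static void main(String[] args) {
        List<String> letters = Arrays.asList("A", "B");
        System.out.println(secondTraversalFails(letters));
        System.out.println(traverseTwice(letters::stream));
    }
}
```

Since crossJoin calls streamSupplier.get() once per element of the upstream, handing it a single stream would fail on the second element.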

3 Jun 2015

Recursive Parallel Search with Shallow Backtracking

The implementation in the previous post had a serious disadvantage: We generated substitutions for all the letters, and then checked if those substitutions constituted a solution. However, in a more efficient approach some partial substitutions could be rejected out-of-hand as not leading to a solution. This is called "shallow backtracking". In this post I show how to examine the letters as they occur from right to left in the operands, and interleave the checking of arithmetic constraints with the generation of substitutions.

This solution combines the advantages of flatMap-based search in Java 8, parallelization, persistent collections, and recursion. The code that I will show is completely general for cryptarithmetic puzzles with two summands, because the constraints are no longer problem-specific: in fact they encode only the rules of long addition plus a general side condition.

The implementation is also a bit more idiomatic (for Java), with less clutter in the method interfaces, because we keep the read-only operands in instance variables instead of passing them around explicitly.

The really nice thing is that this approach is about 25 times faster than the original parallel flatMap solution with exhaustive search.

Here's the code:
import static com.googlecode.totallylazy.collections.PersistentList.constructors.list;
import static com.googlecode.totallylazy.collections.PersistentMap.constructors.map;
import static java.util.stream.Collectors.toList;

import java.util.Collection;
import java.util.Map;
import java.util.stream.Stream;

import com.googlecode.totallylazy.collections.PersistentList;
import com.googlecode.totallylazy.collections.PersistentMap;

public class SendMoreMoneyShallow {

    static final int PRUNED = -1;
    static final char PADDING = ' ';
    static final PersistentList<Integer> DIGITS = list(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
    
    // padded puzzle arguments: op1 + op2 = op3
    final String op1;
    final String op2;
    final String op3;
    
    public static void main(String[] args) {
        SendMoreMoneyShallow puzzle = new SendMoreMoneyShallow(" send", " more", "money");
        Collection<String> solutions = puzzle.solve();
        System.out.println("There are " + solutions.size() + " solutions: " + solutions);
    }

    public SendMoreMoneyShallow(String op1, String op2, String op3) {
        // the arguments come padded with blanks so they all have the same length
        // there is no need to reverse the strings, because we have random access and can process them backwards
        assert op1.length() == op3.length();
        assert op2.length() == op3.length();
        this.op1 = op1;
        this.op2 = op2;
        this.op3 = op3;
    }
    
    public Collection<String> solve() {
        PersistentMap<Character, Integer> subst = map();
        Collection<String> solutions = go(op1.length() - 1, subst, 0).collect(toList());
        return solutions;
    }

    Stream<String> go(int i, PersistentMap<Character, Integer> subst, int carry) {
        // Each level of recursion accomplishes up to three substitutions of a character with a number. The recursion
        // should end when we run out of characters to substitute. At this point, all constraints have already been
        // checked and therefore the substitutions must represent a solution.
        if (i < 0) {
            return solution(subst);
        }
        
        // the state consists of partial substitutions and the carry. Every time we have made a substitution for a column
        // of letters (from right to left), we immediately check constraints.
        Character sx = op1.charAt(i);
        Character sy = op2.charAt(i);
        Character sz = op3.charAt(i);
        return candidates(sx, subst).stream().parallel().flatMap(x -> {
                PersistentMap<Character, Integer> substX = subst.insert(sx,x);
                return candidates(sy, substX).stream().flatMap(y -> {
                    PersistentMap<Character, Integer> substXY = substX.insert(sy,y);
                    return candidates(sz, substXY).stream().flatMap(z ->   {
                        PersistentMap<Character, Integer> substXYZ = substXY.insert(sz, z);
                        // recurse if not pruned, moving one column to the left with the substitutions we have just made and
                        // the value for carry that results from checking the arithmetic constraints
                        int nextCarry = prune(substXYZ, carry, x, y, z);
                        return nextCarry == PRUNED ? Stream.empty() : go(i - 1, substXYZ, nextCarry);
                    });});});
    }

    int prune(PersistentMap<Character, Integer> subst, int carry, Integer x, Integer y, Integer z) {
        // neither of the most significant digits may be 0, and we cannot be sure the substitutions have already been made
        if (subst.getOrDefault(mostSignificantLetter(op1), 1) == 0 || subst.getOrDefault(mostSignificantLetter(op2), 1) == 0) {
            return PRUNED;
        }

        // the column sum must be correct
        int zPrime = x + y + carry;
        if (zPrime % 10 != z) {
            return PRUNED;
        }

        // return next carry
        return zPrime / 10;
    }

    PersistentList<Integer> candidates(Character letter, PersistentMap<Character, Integer> subst) {
        if (letter == PADDING) {
            return list(0);
        }
        // if we have a substitution, use that, otherwise consider only those digits that have not yet been assigned
        return subst.containsKey(letter) ? list(subst.get(letter)) : DIGITS.deleteAll(subst.values());
    }

    Stream<String> solution(PersistentMap<Character, Integer> subst) {
        // transform the set of substitutions to a solution (in this case a String because Java has no tuples)
        int a = toNumber(subst, op1.trim());
        int b = toNumber(subst, op2.trim());
        int c = toNumber(subst, op3.trim());
        return Stream.of("(" + a + "," + b + "," + c + ")");
    }

    static final int toNumber(Map<Character, Integer> subst, String word) {
        // return the integer corresponding to the given word according to the substitutions
        assert word.length() > 0;
        return word.chars().map(x -> subst.get((char)x)).reduce((x, y) -> 10 * x + y).getAsInt();
    }

    static char mostSignificantLetter(String op) {
        return op.trim().charAt(0);
    }
}

And here's a representative performance measurement:

# JMH 1.9.1 (released 40 days ago)
# VM invoker: C:\Program Files\Java\jdk1.8.0_25\jre\bin\java.exe
# VM options: -Dfile.encoding=UTF-8
# Warmup: 5 iterations, 1 s each
# Measurement: 25 iterations, 1 s each
# Timeout: 10 min per iteration
# Threads: 1 thread, will synchronize iterations
# Benchmark mode: Average time, time/op
# Benchmark: java8.streams.SendMoreMoneyShallowBenchmark.measureShallowBacktrackingPerformance


Benchmark                              Mode  Cnt  Score   Error  Units
measureShallowBacktrackingPerformance  avgt   25  6.319 ± 0.082  ms/op


There is room for still more improvement, because the substitution for "z" in each round is in fact determined by the previous substitutions and need not be guessed.
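
I have not implemented that improvement. As a rough sketch of the idea, and under my own assumptions (the class ColumnSum and both method names are hypothetical, and a plain java.util.Map stands in for the persistent map), the digit for the result letter of a column could be computed from x, y and the carry and then merely checked against the substitutions made so far:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.OptionalInt;

class ColumnSum {

    static final char PADDING = ' ';

    // The digit that the result letter sz of the current column must take,
    // or empty if that digit conflicts with the existing substitutions.
    static OptionalInt forcedDigit(char sz, Map<Character, Integer> subst, int x, int y, int carry) {
        int z = (x + y + carry) % 10;
        if (sz == PADDING) {
            // a padded position stands for the digit 0
            return z == 0 ? OptionalInt.of(0) : OptionalInt.empty();
        }
        Integer assigned = subst.get(sz);
        if (assigned != null) {
            // the letter already has a digit: it must agree with the column sum
            return assigned == z ? OptionalInt.of(z) : OptionalInt.empty();
        }
        // a new letter may only take a digit that is not yet in use
        return subst.containsValue(z) ? OptionalInt.empty() : OptionalInt.of(z);
    }

    // the carry handed on to the next column to the left
    static int nextCarry(int x, int y, int carry) {
        return (x + y + carry) / 10;
    }

    public static void main(String[] args) {
        Map<Character, Integer> subst = new HashMap<>();
        subst.put('d', 7);
        subst.put('e', 5);
        // rightmost column of send + more = money: d + e determines y
        System.out.println(forcedDigit('y', subst, 7, 5, 0) + ", carry " + nextCarry(7, 5, 0));
    }
}
```

In go(), the innermost flatMap over candidates for z would then be replaced by a single lookup of this kind, so that only x and y are still guessed.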

27 May 2015

Recursive Parallel Search

Bartosz Milewski has been discussing monadic programming in C++. In this post he presents a recursive version of the flatMap-based solution of the "Send more money" cryptarithmetic puzzle.

(BTW: Bartosz' writings always make for excellent reading if you're at all interested in functional programming.)

Here's a Java version of the recursive approach. I have already discussed the non-recursive solution in a previous blog entry. The main advantage over the original solution is that the recursive approach is less redundant and much easier to generalize. The main drawback is that on my machine it is about 6 times slower.

Please note how easy it is to parallelize this recursive solution when using persistent collections, in this case the ones from the TotallyLazy framework. Please spend a minute thinking about how you would go about coding this solution using only classes from the JDK. It's not trivial. (I would guess the same is true of the C++ implementation.)

So here goes:

import static com.googlecode.totallylazy.collections.PersistentList.constructors.list;
import static com.googlecode.totallylazy.collections.PersistentMap.constructors.map;
import static java.util.stream.Collectors.toList;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import com.googlecode.totallylazy.collections.PersistentList;
import com.googlecode.totallylazy.collections.PersistentMap;

public class SendMoreMoneyRecursive {

    static final PersistentList<Character> LETTERS = list(uniqueChars("sendmoremoney"));
    static final PersistentList<Integer> DIGITS = list(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

    public static void main(String[] args) {
        SendMoreMoneyRecursive puzzle = new SendMoreMoneyRecursive();
        Collection<String> solutions = puzzle.solve(LETTERS, DIGITS);
        System.out.println("There are " + solutions.size() + " solutions: " + solutions);
    }

    Collection<String> solve(PersistentList<Character> str, PersistentList<Integer> digits) {
        PersistentMap<Character, Integer> subst = map();
        Collection<String> solutions = go(str, digits, subst).collect(toList());
        return solutions;
    }

    Stream<String> go(PersistentList<Character> str, PersistentList<Integer> digits,
            PersistentMap<Character, Integer> subst) {
        // Each level of nesting accomplishes one substitution of a character by a number. 
        // The recursion ends when we run out of characters to substitute.
        if (str.isEmpty()) {
            return prune(subst);
        }
        return digits.stream().parallel().flatMap(n -> go(str.tail(), digits.delete(n), subst.insert(str.head(), n)));
    }

    Stream<String> prune(PersistentMap<Character, Integer> subst) {
        // we know that we never look up a value that’s not in the map
        if (subst.get('s') == 0 || subst.get('m') == 0) {
            return Stream.empty();
        }
        int send = toNumber(subst, "send");
        int more = toNumber(subst, "more");
        int money = toNumber(subst, "money");
        return send + more == money ? Stream.of(solution(send, more, money)) : Stream.empty();
    }

    String solution(int send, int more, int money) {
        return "(" + send + "," + more + "," + money + ")";
    }

    static final int toNumber(Map<Character, Integer> subst, String word) {
        assert word.length() > 0;
        return word.chars().map(x -> subst.get((char)x)).reduce((x, y) -> 10 * x + y).getAsInt();
    }

    static List<Character> uniqueChars(String s) {
        return s.chars().distinct().mapToObj(d -> Character.valueOf((char) d)).collect(toList());
    }
}
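
For comparison, here is one way the same recursion might be approximated with JDK classes only. This is a sketch under my own assumptions (the class name SendMoreMoneyJdkOnly is made up, and I have not benchmarked it): it simulates persistence by copying the digit list and the substitution map at every step, which is simple but wasteful compared to the structural sharing of real persistent collections.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

class SendMoreMoneyJdkOnly {

    static List<String> solve() {
        List<Character> letters = "sendmoremoney".chars().distinct()
                .mapToObj(c -> (char) c).collect(Collectors.toList());
        List<Integer> digits = IntStream.rangeClosed(0, 9).boxed().collect(Collectors.toList());
        return go(letters, digits, Collections.emptyMap()).collect(Collectors.toList());
    }

    static Stream<String> go(List<Character> str, List<Integer> digits, Map<Character, Integer> subst) {
        if (str.isEmpty()) {
            return check(subst);
        }
        List<Character> tail = str.subList(1, str.size());
        return digits.stream().flatMap(n -> {
            // defensive copies instead of persistent collections: the arguments are never mutated
            List<Integer> remaining = new ArrayList<>(digits);
            remaining.remove(n); // n is an Integer, so this removes the value, not an index
            Map<Character, Integer> extended = new HashMap<>(subst);
            extended.put(str.get(0), n);
            return go(tail, remaining, extended);
        });
    }

    static Stream<String> check(Map<Character, Integer> subst) {
        if (subst.get('s') == 0 || subst.get('m') == 0) {
            return Stream.empty();
        }
        int send = toNumber(subst, "send");
        int more = toNumber(subst, "more");
        int money = toNumber(subst, "money");
        return send + more == money
                ? Stream.of("(" + send + "," + more + "," + money + ")")
                : Stream.empty();
    }

    static int toNumber(Map<Character, Integer> subst, String word) {
        return word.chars().map(c -> subst.get((char) c)).reduce(0, (x, y) -> 10 * x + y);
    }

    public static void main(String[] args) {
        System.out.println(solve());
    }
}
```

Because every lambda works on its own copies, nothing is shared between branches of the search. The price is one new list and one new map per node of the search tree.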

1 May 2015

Stream#flatMap() may cause short-circuiting of downstream operations to break

There is a bug report showing how flatMapping to an infinite stream may lead to non-termination of stream processing even in the presence of a short-circuiting terminal operation.

On StackOverflow, there is a discussion (in fact it was this discussion that led to the entry in the bug database) in which participants agree that the behavior is confusing and perhaps unexpected, but do not agree on whether it is actually a bug. There seem to be different ways to read the spec.

Here's an example:
Stream.iterate(0, i->i+1).findFirst()
works as expected, while
Stream.of("").flatMap(x->Stream.iterate(0, i->i+1)).findFirst()
will end up in an infinite loop.

The behavior of flatMap becomes still more surprising when one throws in an intermediate (as opposed to terminal) short-circuiting operation. While the following works as expected, printing out the infinite sequence of integers
Stream.of("").flatMap(x -> Stream.iterate(1, i -> i + 1)).forEach(System.out::println);
the following code prints out only the "1", but still does not terminate:
Stream.of("").flatMap(x -> Stream.iterate(1, i -> i + 1)).limit(1).forEach(System.out::println);

I cannot imagine a reading of the spec in which this were not a bug.
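
To observe the effect without hanging the JVM, one can flatMap to a large but finite inner stream and count how many of its elements are generated before findFirst() returns. The sketch below is my own (class and method names are hypothetical). I would expect the count to equal the size of the inner stream on a JDK that shows the behavior described above, and to be 1 on a JDK where flatMap respects short-circuiting, but I make no claim about any particular build.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import java.util.stream.Stream;

class FlatMapLaziness {

    // returns how many elements of the inner stream were generated before findFirst() returned
    static int innerElementsPulled(int innerSize) {
        AtomicInteger pulled = new AtomicInteger();
        Stream.of("")
              .flatMap(x -> IntStream.range(0, innerSize).boxed()
                                     .peek(i -> pulled.incrementAndGet()))
              .findFirst();
        return pulled.get();
    }

    public static void main(String[] args) {
        System.out.println("inner elements generated: " + innerElementsPulled(1000));
    }
}
```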

27 Apr 2015

Yield Return in Java (comment on Benji Webber)

Benji Webber has said that a feature often missed in Java by C# developers is yield return, and considers very complex ways of bringing this into Java for the implementation of iterators and generators. In particular he discusses a threading approach in some detail, which makes it seem really hard to generate and print out the integers from one to five.

In fact, with Java 8, there is no need for any of that, because the required behavior is already built into the lazy execution model of streams. Iterators and generators are part and parcel of the Stream API.

The following code, for example, will print the infinite series of positive numbers:

Stream.iterate(1, x->x+1).forEach(System.out::println);
Throwing in a limit(5) will give you the one-to-five example.

The reason that this works is that each stream element is only generated (lazily) when a downstream method explicitly asks for it. Contrary to appearances, no list of five elements is ever constructed in the snippet

Stream.iterate(1, x->x+1).limit(5).forEach(System.out::println);
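
The same works for generators that carry state from one element to the next, which is where yield return is typically used. As a sketch (the class LazyGenerator and the method fibonacci are my own names), here is an infinite Fibonacci generator whose consumer decides how many elements to take:

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class LazyGenerator {

    // infinite stream of Fibonacci numbers; the state is the pair (current, next)
    static Stream<Long> fibonacci() {
        return Stream.iterate(new long[] {0, 1}, p -> new long[] {p[1], p[0] + p[1]})
                     .map(p -> p[0]);
    }

    public static void main(String[] args) {
        // only the ten requested elements are ever computed
        List<Long> firstTen = fibonacci().limit(10).collect(Collectors.toList());
        System.out.println(firstTen);
    }
}
```

The first ten elements are 0, 1, 1, 2, 3, 5, 8, 13, 21, 34.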

To my eyes, the Java 8 code is even nicer to read than the corresponding code in C#.

Easy exhaustive search with Java 8 Streams

I have just been reading this post by Mark Dominus on Haskell. It discusses how the Haskell list monad can be used to hide some of the glue code involved in doing exhaustive searches. Java 8 Streams, which are somewhat similar to Haskell lists in also being monadic, lend themselves to the same style of coding.

The example used in the post I have quoted is the well-known cryptarithmetic puzzle in which you are asked to find all possible ways of mapping the letters S, E, N, D, M, O, R, Y to distinct digits 0 through 9 (where we may assume that S is not 0) so that the following comes out true:

    S E N D
  + M O R E
  ---------
  M O N E Y

Here's my Java 8 port of Mark's Haskell example.

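import static java.util.Arrays.asList;
import static java.util.Collections.unmodifiableList;
import static java.util.stream.Collectors.toList;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

public class SendMoreMoney {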

    static final List<Integer> DIGITS = unmodifiableList(asList(0,1,2,3,4,5,6,7,8,9));
    
    public static void main(String[] args) {
        List<String> solutions = 
            remove(DIGITS, 0).stream().flatMap( s ->
            remove(DIGITS, s).stream().flatMap( e ->
            remove(DIGITS, s, e).stream().flatMap( n ->
            remove(DIGITS, s, e, n).stream().flatMap( d ->
            remove(DIGITS, s, e, n, d).stream().flatMap( m ->
            remove(DIGITS, s, e, n, d, m).stream().flatMap( o ->
            remove(DIGITS, s, e, n, d, m, o).stream().flatMap( r ->
            remove(DIGITS, s, e, n, d, m, o, r).stream().flatMap( y ->
                { int send = toNumber(s, e, n, d);
                  int more = toNumber(m, o, r, e);
                  int money = toNumber(m, o, n, e, y);
                  return  send + more == money ? Stream.of(solution(send, more, money)) : Stream.empty();
                }
            ))))))))
            .collect(toList());
           
         System.out.println(solutions);
    }

    static String solution(int send, int more, int money) {
        return "(" + send + "," + more + "," + money + ")";
    }
    
    static final int toNumber(Integer... digits) {
        assert digits.length > 0;
        return Stream.of(digits).reduce((x,y) -> 10*x + y).get();
    }
    
    static final List<Integer> remove(List<Integer> xs, Integer... ys) {
        // this naive implementation is O(n^2).
        List<Integer> zs = new ArrayList<>(xs);
        zs.removeAll(asList(ys));
        return zs;
    }
}

The minor optimization of not unnecessarily recomputing "send" and "more" is left out for the sake of readability. The methods remove() (which implements list difference), toNumber(), and solution() have simple implementations. Of these, toNumber() is again a lot like the corresponding Haskell code. Method solution() here returns a String because Java does not have tuples.

Too bad that in Java one must have the nested method calls, but the formatting goes some way to hide this. All in all, I think this is quite nice.

But how fast is it? I did a simple micro-benchmark with JMH 1.9.1 (available from Maven Central) on my laptop computer, which is a quad-core machine with an Intel i7 processor.

Here are the measurement parameters:

# JMH 1.9.1 (released 5 days ago)
# VM invoker: C:\Program Files\Java\jdk1.8.0_25\jre\bin\java.exe
# VM options: -Dfile.encoding=UTF-8
# Warmup: 5 iterations, 1 s each
# Measurement: 25 iterations, 1 s each
# Timeout: 10 min per iteration
# Threads: 1 thread, will synchronize iterations
# Benchmark mode: Average time, time/op

I measured the flatMap solution against the equivalent formulation with eight nested forEach-loops and an external accumulator variable. The flatMap solution is about half as fast. Here's a representative measurement:

Benchmark                        Mode  Cnt    Score   Error  Units
measureFlatMapSearchPerformance  avgt   25  662.377 ± 3.747  ms/op
measureForLoopSearchPerformance  avgt   25  316.105 ± 3.823  ms/op
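
The for-loop formulation is not reproduced in this post. To illustrate what the two styles look like side by side, here is a sketch on a much smaller puzzle, TO + GO = OUT. The puzzle, the class name ToGoOut and the helper distinct are my own choices for brevity; this is not the benchmarked code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

class ToGoOut {

    // true if no digit occurs twice
    static boolean distinct(int... ds) {
        return IntStream.of(ds).distinct().count() == ds.length;
    }

    // nested loops with an external accumulator; leading digits t, o, g start at 1
    static List<String> withLoops() {
        List<String> solutions = new ArrayList<>();
        for (int t = 1; t <= 9; t++)
            for (int o = 1; o <= 9; o++)
                for (int g = 1; g <= 9; g++)
                    for (int u = 0; u <= 9; u++) {
                        if (!distinct(t, o, g, u)) continue;
                        int to = 10 * t + o, go = 10 * g + o, out = 100 * o + 10 * u + t;
                        if (to + go == out) solutions.add("(" + to + "," + go + "," + out + ")");
                    }
        return solutions;
    }

    // the same search as nested flatMaps; failure is the empty stream
    static List<String> withFlatMap() {
        return IntStream.rangeClosed(1, 9).boxed().flatMap(t ->
               IntStream.rangeClosed(1, 9).boxed().flatMap(o ->
               IntStream.rangeClosed(1, 9).boxed().flatMap(g ->
               IntStream.rangeClosed(0, 9).boxed().flatMap(u -> {
                   if (!distinct(t, o, g, u)) return Stream.<String>empty();
                   int to = 10 * t + o, go = 10 * g + o, out = 100 * o + 10 * u + t;
                   return to + go == out
                           ? Stream.of("(" + to + "," + go + "," + out + ")")
                           : Stream.<String>empty();
               }))))
               .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(withLoops());
        System.out.println(withFlatMap());
    }
}
```

Both methods find the single solution (21,81,102).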


The nice thing about Streams is that they are so easily parallelizable. Just throw in a .parallel() in the first line like this:

   remove(DIGITS, 0).stream().parallel().flatMap( s ->

leaving everything else unchanged, and the (parallel) flatMap version becomes twice as fast as the (serial) for-loop version:

Benchmark                        Mode  Cnt    Score   Error  Units
measureFlatMapSearchPerformance  avgt   25  168.278 ± 1.700  ms/op
measureForLoopSearchPerformance  avgt   25  315.806 ± 2.878  ms/op

16 Sept 2013

The Y Combinator

The Y combinator is one of the most beautiful ideas in programming theory. It allows one to define recursive functions from non-recursive, anonymous functions. The "anonymous" is what makes this particularly amazing: We are going to have recursion without having a way for a function to refer to itself by name. The usual example is the factorial function.

The theoretical background is beautifully and thoroughly explained by Mike Vanier in his article The Y Combinator (Slight Return), which I am not going to repeat here. (For a very short, readable intro to the topic, you might also refer to Deron Meranda's blog post on fixpoints.)

Many people seem to enjoy implementing the Y combinator, even in Java. Most implementations seem in some way to go back to this one by Ken Shirriff, which is in Java 7 and almost unreadable due to the use of inner classes. And here is a Java 8 implementation by Arul Dhesiaseelan. Note that the code in that last-mentioned post uses a method name Y, and that the comments on it propose a simplification which uses the this reference. Strictly speaking, neither is allowed; one might say that this is cheating. So here is my own, very similar but non-cheating implementation:

 
import java.util.function.Function;

public class YCombinator {

    // First a few abbreviations

    /** Function on T returning T. */
    interface Func<T> extends Function<T,T> {
    }

    /**
     * Higher-order function returning a T function.
     * <br>F: F -> (T -> T)
     * <br>This is the type of the Y combinator subexpressions.
     */
    interface FuncToTFunc<T> extends Function<FuncToTFunc<T>, Func<T>> {
    }

    /**
     * Function from a function on T functions to a T function.
     * <br>((T -> T) -> (T -> T)) -> (T -> T)
     * <br>This is the type of the fixed point operator.
     */
    interface Fix<T> extends Function<Func<Func<T>>, Func<T>>  {
    }

    // And now the real thing

    /**
     * Computes the factorial with the applicative-order Y combinator.
     * <br>Y = λr.(λf.(f f)) λf.(r λx.((f f) x))
     */
    public static int factorial(int n) {
        return  // Y combinator
                ((Fix<Integer>) r ->
                    ((FuncToTFunc<Integer>) (f -> f.apply(f)))
                            .apply(f -> r.apply(x -> f.apply(f).apply(x)))
                )
                .apply(
                        // Recursive function generator
                        f -> x -> (x == 0) ? 1 : x * f.apply(x - 1)
                )
                .apply(
                        // Argument
                        n);
    }
}

This code really computes the factorial function. And it does this using only anonymous functions: no function names or variables anywhere except lambda parameters. (You can plug in any other recursive function you might want to compute, e.g. Fibonacci, or Ackermann, etc.)
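
As an example of plugging in another function, here is a sketch that swaps the factorial generator for a Fibonacci one. The class name YFibonacci is made up; the three interfaces and the combinator expression are repeated from above so that the snippet stands alone.

```java
import java.util.function.Function;

class YFibonacci {

    interface Func<T> extends Function<T, T> {
    }

    interface FuncToTFunc<T> extends Function<FuncToTFunc<T>, Func<T>> {
    }

    interface Fix<T> extends Function<Func<Func<T>>, Func<T>> {
    }

    /** Computes the n-th Fibonacci number (fib(0) = 0, fib(1) = 1) with the same Y combinator. */
    static int fibonacci(int n) {
        return  // Y combinator, unchanged
                ((Fix<Integer>) r ->
                    ((FuncToTFunc<Integer>) (f -> f.apply(f)))
                            .apply(f -> r.apply(x -> f.apply(f).apply(x)))
                )
                .apply(
                        // Recursive function generator: only this lambda differs from factorial
                        f -> x -> (x < 2) ? x : f.apply(x - 1) + f.apply(x - 2)
                )
                .apply(
                        // Argument
                        n);
    }

    public static void main(String[] args) {
        System.out.println(fibonacci(10));
    }
}
```

With the usual convention fib(0) = 0 and fib(1) = 1, fibonacci(10) evaluates to 55.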

The regrettable thing is having to have those casts. It would be nice to be able to get rid of them.

-- Sebastian