
3 May 2016

Concurrent Recursive Function Memoization

Recently on concurrency-interest, a discussion was triggered by an observation that Heinz Kabutz made about ConcurrentHashMap, namely that the following plausible-looking code is broken:

public static class FibonacciCached {
  private final Map<Integer, BigInteger> cache = new ConcurrentHashMap<>();
  public BigInteger fib(int n) {
    if (n <= 2) return BigInteger.ONE;
    // The mapping function recursively calls fib, which re-enters computeIfAbsent on the same map.
    return cache.computeIfAbsent(n, key -> fib(n - 1).add(fib(n - 2)));
  }
}

ConcurrentHashMap livelocks for n ≥ 16. The reason, explained by Doug Lea, is that if the computation attempts to update any mappings, the results of the operation are undefined and may include IllegalStateExceptions or, in concurrent contexts, deadlock or livelock. This is partially covered in the Map Javadocs: "The mapping function should not modify this map during computation."
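To see the problem, it is enough to run the class above (a minimal repro):

FibonacciCached broken = new FibonacciCached();
System.out.println(broken.fib(16)); // livelocks: the mapping function re-enters computeIfAbsent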

Mainly in a conversation between Viktor Klang and me, and based on an original idea of Viktor's, another approach was developed that works where computeIfAbsent fails. The approach also harks back to two previous posts on this blog, namely the ones about memoization and trampolining. The idea improves on the memoization scheme by providing reusable, thread-safe components for memoizing recursive functions, and combines that with a trampoline that processes the function calls in such a way as to eliminate stack overflows.

I'll present the idea with a few explanatory comments. If you want a step-by-step derivation, you can get that by reading the mailing list archives. There are three parts to the code.
 
First, a general-purpose memoizer.
public static class ConcurrentTrampoliningMemoizer<T, R> {
  // A single-threaded daemon executor; recursive calls are bounced off its
  // queue instead of growing the call stack.
  private static final Executor TRAMPOLINE =
      newSingleThreadExecutor(new ThreadFactoryBuilder().setDaemon(true).build());
  private final ConcurrentMap<T, CompletableFuture<R>> memo;

  public ConcurrentTrampoliningMemoizer(ConcurrentMap<T, CompletableFuture<R>> cache) {
    this.memo = cache;
  }

  public Function<T, CompletableFuture<R>> memoize(Function<T, CompletableFuture<R>> f) {
    return t -> {
      CompletableFuture<R> r = memo.get(t);
      if (r == null) {
        // Cache miss: publish a container for the value before computing it,
        // so concurrent callers share the same future.
        final CompletableFuture<R> compute = new CompletableFuture<>();
        r = memo.putIfAbsent(t, compute);
        if (r == null) {
          // We won the race, so only this thread invokes the underlying function.
          // supplyAsync bounces the call off the trampoline's queue; the first
          // thenCompose flattens the nested CompletableFuture<CompletableFuture<R>>.
          r = CompletableFuture.supplyAsync(() -> f.apply(t), TRAMPOLINE)
                .thenCompose(Function.identity())
                .thenCompose(x -> {
                   compute.complete(x);
                   return compute;
                });
        }
      }
      return r;
    };
  }
}
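As a quick aside (my toy example, not part of the original discussion), this is the shape of the API on a trivial non-recursive function; note that the function to be memoized must itself return a CompletableFuture:

ConcurrentTrampoliningMemoizer<String, Integer> memoizer =
    new ConcurrentTrampoliningMemoizer<>(new ConcurrentHashMap<>());
Function<String, CompletableFuture<Integer>> length =
    memoizer.memoize(s -> completedFuture(s.length()));
System.out.println(length.apply("hello").join()); // prints 5; repeated calls hit the cache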

Second, a class that uses the memoizer to compute Fibonacci numbers.

public static class Fibonacci {
  private static final CompletableFuture<BigInteger> ONE = completedFuture(BigInteger.ONE);
  private final Function<Integer, CompletableFuture<BigInteger>> fibMem;

  public Fibonacci(ConcurrentMap<Integer, CompletableFuture<BigInteger>> cache) {
    ConcurrentTrampoliningMemoizer<Integer, BigInteger> memoizer = new ConcurrentTrampoliningMemoizer<>(cache);
    fibMem = memoizer.memoize(this::fib);
  }

  public CompletableFuture<BigInteger> fib(int n) {
    if (n <= 2) return ONE;
    // Recurse through the memoized function, so every subresult is cached
    // and every call is trampolined.
    return fibMem.apply(n - 1).thenCompose(x ->
           fibMem.apply(n - 2).thenApply(y ->
             x.add(y)));
  }
}

Third, any number of clients of a Fibonacci instance.

Fibonacci fibCached = new Fibonacci(new ConcurrentHashMap<>());
BigInteger result = fibCached.fib(550_000).join();

As the final usage example shows, the 550,000th Fibonacci number is about the largest I can get on my box before all the cached BigInteger values cause an OutOfMemoryError. The Fibonacci instance may be shared among several threads. The whole thing seems to scale OK, as my tests with 2, 4, or 8 threads sharing the same instance indicate.
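Such sharing might look like the following sketch (the thread count and argument are illustrative, not my actual benchmark setup):

Fibonacci shared = new Fibonacci(new ConcurrentHashMap<>());
ExecutorService pool = Executors.newFixedThreadPool(4);
List<CompletableFuture<BigInteger>> results = new ArrayList<>();
for (int i = 0; i < 4; i++) {
  // All threads funnel through the same cache and reuse each other's results.
  results.add(CompletableFuture.supplyAsync(() -> shared.fib(100_000).join(), pool));
}
results.forEach(CompletableFuture::join);
pool.shutdown();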

The formulation of the fib method should be familiar to you from the previous blog entry on memoization, except we're using a different monad here (CompletableFuture instead of StateMonad).

The most interesting thing of course is the memoizer. Here are a few observations:
  • On a cache miss a container (in the form of a CompletableFuture) is inserted into the cache which will later come to contain the value.
  • Different threads will always wind up using the same value instances, so the cache is also suitable for values that are supposed to be singletons. Null values are not allowed.
  • Only the thread that first has a cache miss calls the underlying function.
  • Concurrent readers and writers won't block each other (as computeIfAbsent, for instance, would).
  • The trampolining happens because we can bounce off the queue in the executor to which supplyAsync submits tasks. The technique has been explained in a previous post. We are just inlining our two utility methods terminate and tailcall, and fibMem is the thunk. I find it interesting to see how we benefit from this even if there is no tail recursion.
  • I haven't shown a few obvious static imports. The ThreadFactoryBuilder is from Guava.
  • You might want to pass in an executor to the memoizer from the outside, which would make it easier to test the implementation separately from its performance and scalability. However, performance seems to degrade when using anything but a dedicated single thread, so it may not be worth much in practice; see the sketch below.
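Such a variant could look like this (the class name ConcurrentTrampoliningMemoizer2 and the constructor signature are mine; everything else stays as above):

public static class ConcurrentTrampoliningMemoizer2<T, R> {
  private final ConcurrentMap<T, CompletableFuture<R>> memo;
  private final Executor trampoline; // injected instead of the static TRAMPOLINE

  public ConcurrentTrampoliningMemoizer2(ConcurrentMap<T, CompletableFuture<R>> cache, Executor trampoline) {
    this.memo = cache;
    this.trampoline = trampoline;
  }

  public Function<T, CompletableFuture<R>> memoize(Function<T, CompletableFuture<R>> f) {
    return t -> {
      CompletableFuture<R> r = memo.get(t);
      if (r == null) {
        final CompletableFuture<R> compute = new CompletableFuture<>();
        r = memo.putIfAbsent(t, compute);
        if (r == null) {
          r = CompletableFuture.supplyAsync(() -> f.apply(t), trampoline)
                .thenCompose(Function.identity())
                .thenCompose(x -> { compute.complete(x); return compute; });
        }
      }
      return r;
    };
  }
}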

I am sure there are very many ways to improve on the idea, but I'll leave it at that.

1 comment:

  1. Example using Spring project Reactor


    public static class ConcurrentTrampoliningMemoizer1<T, R> {

      private final ConcurrentMap<T, Mono<R>> memo;

      public ConcurrentTrampoliningMemoizer1(ConcurrentMap<T, Mono<R>> cache) {
        this.memo = cache;
      }

      public Function<T, Mono<R>> memoize(Function<T, Mono<R>> f) {
        return t -> {
          Mono<R> r = memo.get(t);
          if (r == null) {
            final CompletableFuture<R> compute = new CompletableFuture<>();
            final Mono<R> mono = Mono.fromFuture(compute);

            r = memo.putIfAbsent(t, mono);
            if (r == null) {
              r = f.apply(t)
                   .flatMap(x -> {
                     compute.complete(x);
                     return mono;
                   });
            }
          }
          return r;
        };
      }
    }

    public static class Fibonacci1 {
      private static final Mono<BigInteger> ONE = Mono.just(BigInteger.ONE);

      private final Function<Integer, Mono<BigInteger>> fibMem;

      public Fibonacci1(ConcurrentMap<Integer, Mono<BigInteger>> cache) {
        ConcurrentTrampoliningMemoizer1<Integer, BigInteger> memoizer = new ConcurrentTrampoliningMemoizer1<>(cache);
        fibMem = memoizer.memoize(this::fib);
      }

      public Mono<BigInteger> fib(int n) {
        if (n <= 2)
          return ONE;

        return fibMem.apply(n - 1)
                     .flatMap(x -> fibMem.apply(n - 2)
                                         .map(y -> x.add(y)));
      }
    }



    But I don't know how to avoid CompletableFuture completely??
