The Curious Time-Traveling Reverse State Monad

By Zhouyu Qian

Published July 8, 2018

Here is a very simple programming exercise that even a beginner Haskell programmer should be able to complete in a few minutes: given a list of Ints, can you produce a cumulative sum of those integers? For example, if we had the list [2, 3, 5, 7, 11, 13], we want to have [2, 5, 10, 17, 28, 41].

There are actually many different ways to write this function. Depending on your taste for imperative programming, you can choose anywhere between highly imperative ST-based destructive updates and an idiomatic functional style. Here’s how I would write it:

cumulative :: [Int] -> [Int]
cumulative = tail . scanl (+) 0

In general, this kind of operation is known as a “scan”, and Prelude conveniently exports both scanl and scanr for the two directions of scanning. For example, what if you want to produce a cumulative sum that is accumulated from the right? So if we had the list [2, 3, 5, 7, 11, 13], we want to have [41, 39, 36, 31, 24, 13]. This is also a one-liner using scanr and the analogue of tail, which is init:

cumulativeR :: [Int] -> [Int]
cumulativeR = init . scanr (+) 0
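
To see why tail and init are needed, it can help to look at the raw output of the two scans. The following is only an illustrative sketch; the names rawLeft, rawRight, fromLeft and fromRight are made up for this example and do not appear elsewhere in the article.

```haskell
-- scanl and scanr both include the seed value in their output:
-- scanl puts it at the front, scanr puts it at the end.
rawLeft :: [Int]
rawLeft = scanl (+) 0 [2, 3, 5, 7, 11, 13]
-- [0,2,5,10,17,28,41]

rawRight :: [Int]
rawRight = scanr (+) 0 [2, 3, 5, 7, 11, 13]
-- [41,39,36,31,24,13,0]

-- Dropping the seed leaves exactly the cumulative sums we want.
fromLeft :: [Int]
fromLeft = tail rawLeft    -- [2,5,10,17,28,41]

fromRight :: [Int]
fromRight = init rawRight  -- [41,39,36,31,24,13]
```

For non-empty lists, scanl1 (+) and scanr1 (+) produce the same results without the extra seed element.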

As Haskell programmers, we have the instinct to generalize things. For example, instead of just adding Ints, we can actually add all kinds of numbers. In fact we can generalize to arbitrary monoids. As for the list structure, what if we support arbitrary Traversable structures? For example we might want to traverse a Map from strings to integers, finding the cumulative sum according to the order of the strings. Or we might want to traverse a Map of integers to strings, concatenating the strings. This leads to this final type for cumulative and cumulativeR:

import qualified Data.Map as M

-- | Make a traversable data structure cumulative by performing a
-- partial sum (actually a monoidal append).
-- >>> cumulative (M.fromList [(1, "one"), (3, "three"), (2, "two")])
-- fromList [(1,"one"),(2,"onetwo"),(3,"onetwothree")]
cumulative :: (Monoid w, Traversable t) => t w -> t w

-- | Like 'cumulative' but in reverse.
-- >>> cumulativeR (M.fromList [(1, "one"), (3, "three"), (2, "two")])
-- fromList [(1,"onetwothree"),(2,"twothree"),(3,"three")]
cumulativeR :: (Monoid w, Traversable t) => t w -> t w

But how might we implement such a function? Let’s consider cumulative first. What we really need is to keep track of the running sum as we traverse, and then return the running sum as the new value. The State monad then becomes helpful. This leads to the following implementation:

import Control.Monad.State (evalState, state) -- from mtl

cumulative :: (Monoid w, Traversable t) => t w -> t w
cumulative t =
  evalState (traverse (\b -> state $ \s -> let r = s <> b in (r, r)) t) mempty
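
To make the state threading concrete, here is a hand-unrolled sketch of what this traversal amounts to when the structure is a plain list. It is not how Traversable is implemented in general; goL, cumulativeList and exampleL are hypothetical names used only for this illustration.

```haskell
import Data.Monoid (Sum (..))

-- Hand-unrolled version of the State traversal, specialised to lists.
-- The second argument is the running total accumulated so far; the result
-- pairs the output list with the final running total.
goL :: Monoid w => [w] -> w -> ([w], w)
goL [] s = ([], s)
goL (b : bs) s =
  let r = s <> b              -- new running total, also the output element
      (rs, final) = goL bs r  -- the updated state flows to the right
  in (r : rs, final)

cumulativeList :: Monoid w => [w] -> [w]
cumulativeList xs = fst (goL xs mempty)

exampleL :: [Int]
exampleL = map getSum (cumulativeList (map Sum [2, 3, 5, 7, 11, 13]))
-- [2,5,10,17,28,41]
```

Each element sees the state produced by the element to its left, which is exactly the direction a left-to-right traversal supports naturally.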

But cumulativeR seems harder to write. After all, we really want to start traversing from the right instead of the left. But Traversable allows no such thing. Even the first sentence of its documentation says it’s just for traversing from left to right:

Functors representing data structures that can be traversed from left to right.

So how might we implement this generalized cumulativeR function?

Enter the reverse state monad.

The normal state monad allows you to take a state, and then produce a value and a new state. This is captured in the state class method in mtl. Alternatively, think of it as two functionalities: you can ask it for the current state, and you can also set the state, so that the next time you ask for it, you will get back what you previously set:

state :: (s -> (a, s)) -> State s a
get :: State s s
put :: s -> State s ()

The reverse state monad, on the other hand, has the same API, except that the state flows backwards: when you set the state, it is the previous get, not the next one, that sees the value. In other words, asking for the state gives you back the value that will be set in the future:

state :: (s -> (a, s)) -> ReverseState s a
get :: ReverseState s s
put :: s -> ReverseState s ()

Of course, we can’t really time travel, and the previous sentence is deliberately (slightly) hyperbolic because after all, a program is finite and we always know, from the structure of the program, what exactly is the statement that syntactically follows the current statement. So “setting the state you will retrieve in the future” is not really something mythical and mysterious, but rather just a clever use of laziness. Here’s the implementation of the reverse state monad:

{-# LANGUAGE DeriveFunctor #-}

newtype ReverseState s a = ReverseState
  { runReverseState :: s -> (a, s)
  } deriving Functor

instance Applicative (ReverseState s) where
  pure x = ReverseState $ (,) x
  mf <*> mx =
    ReverseState $ \s ->
      let (f, past) = runReverseState mf future
          (x, future) = runReverseState mx s
      in (f x, past)

instance Monad (ReverseState s) where
  mx >>= f =
    ReverseState $ \s ->
      let (a, past) = runReverseState mx future
          (b, future) = runReverseState (f a) s
      in (b, past)

In the above code, we used the DeriveFunctor extension so the Functor instance is automatically generated. The Applicative and Monad instances are pretty standard, except we just swapped the state variables.
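
As a small sanity check of the “time travel”, the sketch below repeats the definitions so that it is self-contained (with the Functor instance written by hand instead of derived) and adds two hypothetical helpers, getR and putR, that are not part of the article’s code. It is meant as an illustration under those assumptions, not a reference implementation.

```haskell
newtype ReverseState s a = ReverseState
  { runReverseState :: s -> (a, s)
  }

-- Written by hand here; the article derives it with DeriveFunctor.
instance Functor (ReverseState s) where
  fmap f m =
    ReverseState $ \s ->
      let (a, s') = runReverseState m s
      in (f a, s')

instance Applicative (ReverseState s) where
  pure x = ReverseState $ (,) x
  mf <*> mx =
    ReverseState $ \s ->
      let (f, past) = runReverseState mf future
          (x, future) = runReverseState mx s
      in (f x, past)

instance Monad (ReverseState s) where
  mx >>= f =
    ReverseState $ \s ->
      let (a, past) = runReverseState mx future
          (b, future) = runReverseState (f a) s
      in (b, past)

-- Hypothetical helpers, written directly against the newtype.
getR :: ReverseState s s
getR = ReverseState $ \s -> (s, s)

putR :: s -> ReverseState s ()
putR s = ReverseState $ \_ -> ((), s)

-- getR comes first in program order, yet it observes the value that
-- putR sets on the following line. The initial state 0 is overwritten
-- before anything reads it.
peekFuture :: (Int, Int)
peekFuture = runReverseState action 0
  where
    action = do
      x <- getR
      putR 42
      pure x
-- (42,42)
```

This only works because the let bindings in the instances are lazy: the state handed to getR is a thunk that is filled in by the later putR.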

After this detour, we can finally write our generalized cumulativeR:

cumulativeR :: (Monoid w, Traversable t) => t w -> t w
cumulativeR t =
  evalReverseState
    (traverse (\b -> ReverseState $ \s -> let r = b <> s in (r, r)) t)
    mempty

evalReverseState :: ReverseState s a -> s -> a
evalReverseState m s = fst (runReverseState m s)

The function is essentially identical, save for the fact that we replaced state with ReverseState, evalState with evalReverseState, and also interchanged the order of the monoidal append. So how are we able to accomplish this even when the Traversable class only allows left-to-right traversals?

The answer is that as we traverse from left to right, we keep returning values that depend on state that will only be set in the future. When we are traversing the first element, we return the value r, defined to be b <> s: the current element’s value appended with the “future” state. That future state is all the remaining elements appended together, even though at this step we don’t yet know what it is; only after traversing the entire structure do we know it. The same thing happens afterwards: when we are at the second element, the “future” state is all the elements from the third one onwards appended. At the last element, when we are about to finish the traversal, the future state is simply the initial mempty.
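
For plain lists, the description above can be sketched as a hand-unrolled recursion. This is an illustration of the data flow rather than the actual Traversable machinery; goR, cumulativeRList and exampleR are hypothetical names introduced only here.

```haskell
import Data.Monoid (Sum (..))

-- Hand-unrolled version of the reverse-state traversal, specialised to
-- lists. The second argument is the state entering from the right end.
goR :: Monoid w => [w] -> w -> ([w], w)
goR [] s = ([], s)
goR (b : bs) s =
  let (rs, future) = goR bs s  -- lazily refers to the rest of the list
      r = b <> future          -- current element appended with the "future"
  in (r : rs, r)               -- r is the state seen by the element to the left

cumulativeRList :: Monoid w => [w] -> [w]
cumulativeRList xs = fst (goR xs mempty)

exampleR :: [Int]
exampleR = map getSum (cumulativeRList (map Sum [2, 3, 5, 7, 11, 13]))
-- [41,39,36,31,24,13]
```

The list is still walked from left to right; laziness is what allows r to mention future before the recursive call has produced it.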

If this sounds a bit mind-boggling, it might be instructive to load the code in GHCi, play around with it, possibly augmenting it with Debug.Trace. As an example, to investigate exactly when and in what order the so-called “future” state s in the above function is evaluated, we can define the following function:

import Debug.Trace (trace)

cumulativeRExperiment :: (Show w, Monoid w, Traversable t) => t w -> t w
cumulativeRExperiment t =
  evalReverseState
    (traverse
       (\b ->
          ReverseState $ \s ->
            let r = b <> trace ("evaluating the cumulative value " ++ show s) s
            in (r, r))
       t)
    mempty

After loading a file with this definition in GHCi, we can observe the evaluation of s like this:

ghci> :seti -XTypeApplications
ghci> :seti -XBangPatterns
ghci> import qualified Data.Map as M
ghci> import Data.Monoid
ghci> import Control.DeepSeq
ghci> let r = cumulativeRExperiment (M.fromList @Int (map (\i -> (i, Sum i)) [1,3 .. 11]))
ghci> let !z = force (M.findMax r) in print z
evaluating the cumulative value Sum {getSum = 0}
(11,Sum {getSum = 11})
ghci> let !z = force (r M.! 5) in print z
evaluating the cumulative value Sum {getSum = 11}
evaluating the cumulative value Sum {getSum = 20}
evaluating the cumulative value Sum {getSum = 27}
Sum {getSum = 32}

In this session, we calculated the right cumulative sum of the map containing (as both keys and values) the odd integers from 1 to 11. When we called M.findMax to find the entry with the largest key, the only “future” state that was evaluated was the mempty we provided. When we called M.! to find the value at key 5, we see, as expected, that the states were evaluated starting from the right: first the rightmost 11, then 20 which is 9+11, then 27 which is 7+9+11.

We used a map in this example, but of course with a simpler data structure like the venerable list, it works in the same way.

In summary, we have used the reverse state monad to essentially perform a right-to-left traversal to find the cumulative sum.

Can we do better? Can we generalize this to more kinds of “cumulative” operations? What if, instead of a simple running sum, we want a running average? Or a running standard deviation? Or some entirely new thing such as the running maximum multiplied by the minimum? The only difference between all of those tasks is the specific state-transforming function (the function that was passed to ReverseState). The remaining parts (the use of the reverse state monad and the overall structure of the function) are the same. We ought to extract this part out.

It turns out, however, that this work has already been done by Gabriel Gonzalez in his excellent foldl library. This library started with a Control.Foldl module to provide efficient left folds, but recently (to my surprise!) gained a Control.Scanl module. It defines this data type to represent a map-with-accumulator operation:

{-# LANGUAGE ExistentialQuantification #-}

import Control.Monad.Trans.State.Strict

data Scan a b = forall x. Scan (a -> State x b) x

Then we can use its scan function to perform a scan from left to right:

scan :: Traversable t => Scan a b -> t a -> t b
scan (Scan step begin) as = fst $ runState (traverse step as) begin

(The above source code is taken from version 1.4.2 of the foldl library.)

If we want to provide a right-to-left scan, we can reuse the definition of Scan, and provide a function scanRight that looks like this:

import Control.Monad.Trans.State.Strict (runState)
import qualified Control.Scanl as SL

scanRight :: Traversable t => SL.Scan a b -> t a -> t b
scanRight (SL.Scan step begin) as =
  evalReverseState (traverse (ReverseState . runState . step) as) begin

Here, we simply unwrap a given state monad action, wrap it again in our ReverseState, do the traversal, then unwrap it again.

We can certainly use this for our cumulative sum:

ghci> import qualified Control.Foldl as F
ghci> import qualified Control.Scanl as SL
ghci> scanRight (SL.postscan F.sum) [2 :: Int, 3, 5, 7, 11, 13]

But we gain so much more. Running averages:

ghci> scanRight (SL.postscan F.mean) [2 :: Double, 3, 5, 7, 11, 13]

Running standard deviation:

ghci> scanRight (SL.postscan F.std) [2 :: Double, 3, 5, 7, 11, 13]

Or the crazy minimum-so-far multiplied by maximum-so-far:

ghci> import Control.Applicative
ghci> scanRight (SL.postscan ((liftA2 . liftA2) (*) F.maximum F.minimum)) [2 :: Int, 3, 5, 7, 11, 13]
[Just 26,Just 39,Just 65,Just 91,Just 143,Just 169]

The Control.Scanl module gives us efficient, composable, left-to-right map-with-accumulator. With a little bit of magic from the reverse state monad, we can have efficient, composable, right-to-left map-with-accumulator.

Just to further dispel any notion that this reverse state monad is somehow too magical, a somewhat surprising discovery was that a copy of it actually already exists in base. Indeed, the innocuous-sounding mapAccumR function in base can already provide you with a right-to-left map-with-accumulator operation. If you don’t have complicated scans, it is perfectly fine to simply use mapAccumR with a handwritten scan and not use the above abstraction. And the kicker is that mapAccumR uses the reverse state monad too.
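
As a minimal sketch of that last suggestion, the right-to-left cumulative append can be written directly with mapAccumR from Data.Traversable. The names cumulativeRViaBase and exampleStrings are hypothetical and used only for this illustration.

```haskell
import Data.Traversable (mapAccumR)

-- A handwritten right-to-left cumulative append using mapAccumR from base.
-- The accumulator plays the role of the "future" state: it starts as
-- mempty at the right end and is threaded towards the left.
cumulativeRViaBase :: (Monoid w, Traversable t) => t w -> t w
cumulativeRViaBase = snd . mapAccumR step mempty
  where
    step s b = let r = b <> s in (r, r)

exampleStrings :: [String]
exampleStrings = cumulativeRViaBase ["one", "two", "three"]
-- ["onetwothree","twothree","three"]
```

For lists of numbers wrapped in Sum, this agrees with the init . scanr (+) 0 one-liner from the beginning of the article.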