# The Curious Time-Traveling Reverse State Monad

Published July 8, 2018

Here is a very simple programming exercise that even a beginner Haskell programmer should be able to complete in a few minutes: given a list of `Int`s, can you produce a cumulative sum of those integers? For example, if we had the list `[2, 3, 5, 7, 11, 13]`, we want to have `[2, 5, 10, 17, 28, 41]`.

There are actually many different ways to write this function. Depending on your taste for imperative programming, you can choose anything from highly imperative `ST`-based destructive updates to an idiomatic functional style. Here’s how I would write it:

```haskell
cumulative :: [Int] -> [Int]
cumulative = tail . scanl (+) 0
```

In general, this kind of operation is known as a “scan”, and `Prelude` conveniently exports both `scanl` and `scanr` for the two directions of scanning. What if you want a cumulative sum that is accumulated from the right instead? Given the list `[2, 3, 5, 7, 11, 13]`, we want `[41, 39, 36, 31, 24, 13]`. This is also a one-liner using `scanr` and the analogue of `tail`, which is `init`:

```haskell
cumulativeR :: [Int] -> [Int]
cumulativeR = init . scanr (+) 0
```
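The raw output of the two scans shows why `tail` and `init` are needed. Each scan keeps the seed `0`: `scanl` puts it at the front and `scanr` puts it at the back. The sketch below uses `leftScanRaw` and `rightScanRaw`, which are illustrative names only:

```haskell
-- Hypothetical names: the raw scans of the example list, seed included.
leftScanRaw :: [Int]
leftScanRaw = scanl (+) 0 [2, 3, 5, 7, 11, 13]   -- [0,2,5,10,17,28,41]

rightScanRaw :: [Int]
rightScanRaw = scanr (+) 0 [2, 3, 5, 7, 11, 13]  -- [41,39,36,31,24,13,0]

-- Dropping the seed (first element for scanl, last for scanr)
-- gives the two cumulative sums from the text.
cumulative :: [Int] -> [Int]
cumulative = tail . scanl (+) 0

cumulativeR :: [Int] -> [Int]
cumulativeR = init . scanr (+) 0
```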

As Haskell programmers, we have an instinct to generalize. For example, instead of just adding `Int`s, we could add all kinds of numbers; in fact, we can generalize to arbitrary monoids. As for the list structure, what if we support arbitrary `Traversable` structures? For example, we might want to traverse a `Map` from strings to integers, computing the cumulative sum in the order of the keys. Or we might want to traverse a `Map` from integers to strings, concatenating the strings. This leads to the following final types for `cumulative` and `cumulativeR`:

```haskell
import qualified Data.Map as M

-- | Make a traversable data structure cumulative by performing a
-- partial sum (actually a monoidal append).
--
-- >>> cumulative (M.fromList [(1, "one"), (3, "three"), (2, "two")])
-- fromList [(1,"one"),(2,"onetwo"),(3,"onetwothree")]
cumulative :: (Monoid w, Traversable t) => t w -> t w

-- | Like 'cumulative' but in reverse.
--
-- >>> cumulativeR (M.fromList [(1, "one"), (3, "three"), (2, "two")])
-- fromList [(1,"onetwothree"),(2,"twothree"),(3,"three")]
cumulativeR :: (Monoid w, Traversable t) => t w -> t w
```

But how might we implement such a function? Let’s consider `cumulative` first. What we really need is to keep track of the running sum as we traverse, and then return the running sum as the new value at each position. This is exactly what the `State` monad is for, and it leads to the following implementation:

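```haskell
import Control.Monad.State (evalState, state)

cumulative :: (Monoid w, Traversable t) => t w -> t w
cumulative t =
  evalState
    (traverse
       -- The state is the running sum so far: append the current element
       -- and return the new running sum as the value at this position.
       (\b -> state $ \s -> let r = s <> b in (r, r))
       t)
    mempty
```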
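If we swap `evalState` for `runState`, we also get back the state left over after the traversal. Because that state is the running sum, it ends up holding the total of all elements. The sketch below assumes `mtl` is installed, and `cumulativeWithTotal` is a hypothetical name:

```haskell
import Control.Monad.State (runState, state)
import qualified Data.Map as M

-- Hypothetical helper: same step as 'cumulative', but 'runState' also
-- returns the final state, which after a left-to-right traversal is the
-- monoidal sum of every element.
cumulativeWithTotal :: (Monoid w, Traversable t) => t w -> (t w, w)
cumulativeWithTotal t =
  runState (traverse (\b -> state $ \s -> let r = s <> b in (r, r)) t) mempty

example :: (M.Map Int String, String)
example = cumulativeWithTotal (M.fromList [(1, "one"), (3, "three"), (2, "two")])
-- (fromList [(1,"one"),(2,"onetwo"),(3,"onetwothree")],"onetwothree")
```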

But `cumulativeR` seems harder to write. After all, we really want to start traversing from the right instead of the left. But `Traversable` allows no such thing. Even the first sentence of its documentation says it’s just for traversing from left to right:

> Functors representing data structures that can be traversed from left to right.

So how might we implement this generalized `cumulativeR` function?

The normal state monad allows you to take a state and produce a value and a new state. This is captured by the `state` method of the `MonadState` class in `mtl`. Alternatively, think of it as two capabilities: you can ask for the current state, and you can set the state, so that the next time you ask for it, you get back what you previously set:

```haskell
state :: (s -> (a, s)) -> State s a
get :: State s s
put :: s -> State s ()
```
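The running-sum step from `cumulative` can also be written with the `get`/`put` pair instead of a single `state` call. This is a sketch assuming `mtl`, and `stepGetPut` and `cumulativeGetPut` are hypothetical names:

```haskell
import Control.Monad.State (State, evalState, get, put)

-- Hypothetical names: the running-sum step, spelled out with get and put.
stepGetPut :: Monoid w => w -> State w w
stepGetPut b = do
  s <- get         -- read the running sum so far
  let r = s <> b   -- extend it with the current element
  put r            -- the next element sees r when it calls get
  return r

cumulativeGetPut :: (Monoid w, Traversable t) => t w -> t w
cumulativeGetPut t = evalState (traverse stepGetPut t) mempty
```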

The reverse state monad, on the other hand, has the same API, except that the state flows backwards: when you set the state, it is the *previous* `get` that sees the value you set. In other words, a `get` can return a value that is only set in the future:

```haskell
state :: (s -> (a, s)) -> ReverseState s a
get :: ReverseState s s
put :: s -> ReverseState s ()
```

Of course, we can’t really time travel, and the previous sentence is deliberately (slightly) hyperbolic: a program is finite, and from its structure we always know exactly which statement syntactically follows the current one. So “setting the state you will retrieve in the future” is nothing mythical or mysterious, but rather a clever use of laziness. Here’s the implementation of the reverse state monad:

```haskell
{-# LANGUAGE DeriveFunctor #-}

newtype ReverseState s a = ReverseState
  { runReverseState :: s -> (a, s)
  } deriving Functor

instance Applicative (ReverseState s) where
  pure x = ReverseState $ (,) x
  mf <*> mx =
    ReverseState $ \s ->
      -- 'mx' runs on the incoming state 's'; its output state ('future')
      -- is fed to 'mf', whose output state ('past') is the overall result.
      let (f, past) = runReverseState mf future
          (x, future) = runReverseState mx s
      in (f x, past)

instance Monad (ReverseState s) where
  mx >>= f =
    ReverseState $ \s ->
      let (a, past) = runReverseState mx future
          (b, future) = runReverseState (f a) s
      in (b, past)
```

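In the above code, we used the `DeriveFunctor` extension so that the `Functor` instance is generated automatically. The `Applicative` and `Monad` instances are the standard ones for the state monad, except that the state variables are swapped. The second action receives the incoming state. The state it produces (`future`) is handed back to the first action, whose output state (`past`) becomes the result. This only works because `let` bindings are lazy and may refer to each other.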
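The text describes the `get`/`put` API of the reverse state monad, but the code defines only the type and its instances. The sketch below adds those two operations under the hypothetical names `getR` and `putR`, chosen to avoid clashing with `mtl`. It also includes a tiny program in which a `get` observes a value that is only `put` later:

```haskell
{-# LANGUAGE DeriveFunctor #-}

newtype ReverseState s a = ReverseState { runReverseState :: s -> (a, s) }
  deriving Functor

instance Applicative (ReverseState s) where
  pure x = ReverseState $ \s -> (x, s)
  mf <*> mx = ReverseState $ \s ->
    let (f, past)   = runReverseState mf future
        (x, future) = runReverseState mx s
    in (f x, past)

instance Monad (ReverseState s) where
  mx >>= f = ReverseState $ \s ->
    let (a, past)   = runReverseState mx future
        (b, future) = runReverseState (f a) s
    in (b, past)

-- Hypothetical helpers mirroring the get/put API described above.
getR :: ReverseState s s
getR = ReverseState $ \s -> (s, s)

putR :: s -> ReverseState s ()
putR s = ReverseState $ \_ -> ((), s)

-- The getR comes first in program order, yet it sees the 42 that the
-- later putR supplies; the final output state is also 42.
readFromTheFuture :: ReverseState Int Int
readFromTheFuture = do
  x <- getR
  putR 42
  return (x + 1)
```

Running `runReverseState readFromTheFuture 0` should give `(43, 42)`. The `0` passed in is the state “after the end of the program”, which nothing reads here.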

After this detour, we can finally write our generalized `cumulativeR`:

```haskell
cumulativeR :: (Monoid w, Traversable t) => t w -> t w
cumulativeR t =
  evalReverseState
    (traverse
       (\b ->
          ReverseState $ \s -> let r = b <> s in (r, r))
       t)
    mempty

evalReverseState :: ReverseState s a -> s -> a
evalReverseState m s = fst (runReverseState m s)
```

The function is essentially identical, save for the fact that we replaced `state` with `ReverseState` and `evalState` with `evalReverseState`, and interchanged the order of the monoidal append. So how are we able to accomplish this even though the `Traversable` class only allows left-to-right traversals?

The answer is that as we traverse from left to right, we keep returning values that depend on a state that will only be set in the future. When we are traversing the first element, we return the value `r`, defined as `b <> s`: the current element’s value appended with the “future” state. That future state is in fact the monoidal sum of all the remaining elements, even though we don’t know it at this step; it is only determined after traversing the entire structure. The same thing happens afterwards: at the second element, the “future” state is the sum of all the elements from the third one onwards. At the last element, just before the traversal finishes, the future state is simply the initial `mempty`.
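Two observations may make this less mysterious. First, `traverse` only needs an `Applicative`, so the `Monad` instance is not actually used by `cumulativeR`. Second, if we keep the final state with `runReverseState` instead of discarding it, the state coming out on the left is the sum of every element. The following self-contained sketch shows both; `cumulativeRWithTotal` is a hypothetical name:

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Only Functor and Applicative are defined: that is all 'traverse' needs.
newtype ReverseState s a = ReverseState { runReverseState :: s -> (a, s) }
  deriving Functor

instance Applicative (ReverseState s) where
  pure x = ReverseState $ \s -> (x, s)
  mf <*> mx = ReverseState $ \s ->
    let (f, past)   = runReverseState mf future
        (x, future) = runReverseState mx s
    in (f x, past)

-- Hypothetical helper: the right cumulative sums, paired with the state
-- that flows out of the leftmost element (the total of the whole structure).
cumulativeRWithTotal :: (Monoid w, Traversable t) => t w -> (t w, w)
cumulativeRWithTotal t =
  runReverseState
    (traverse (\b -> ReverseState $ \s -> let r = b <> s in (r, r)) t)
    mempty
```

For example, `cumulativeRWithTotal ["one", "two", "three"]` should give `(["onetwothree","twothree","three"],"onetwothree")`.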

If this sounds a bit mind-boggling, it might be instructive to load the code in GHCi and play around with it, possibly augmenting it with `Debug.Trace`. For example, to investigate exactly when and in what order the so-called “future” state `s` in the above function is evaluated, we can define the following function:

```haskell
import Debug.Trace (trace)

cumulativeRExperiment :: (Show w, Monoid w, Traversable t) => t w -> t w
cumulativeRExperiment t =
  evalReverseState
    (traverse
       (\b ->
          ReverseState $ \s ->
            let r = b <> trace ("evaluating the cumulative value " ++ show s) s
            in (r, r))
       t)
    mempty
```

After loading a file with this definition in GHCi, we can observe the evaluation of `s` like this:

```
ghci> :seti -XTypeApplications
ghci> :seti -XBangPatterns
ghci> import qualified Data.Map as M
ghci> import Data.Monoid
ghci> import Control.DeepSeq
ghci> let r = cumulativeRExperiment (M.fromList @Int (map (\i -> (i, Sum i)) [1,3 .. 11]))
ghci> let !z = force (M.findMax r) in print z
evaluating the cumulative value Sum {getSum = 0}
(11,Sum {getSum = 11})
ghci> let !z = force (r M.! 5) in print z
evaluating the cumulative value Sum {getSum = 11}
evaluating the cumulative value Sum {getSum = 20}
evaluating the cumulative value Sum {getSum = 27}
Sum {getSum = 32}
ghci>
```

In this session, we computed the right cumulative sum of the map containing (as both keys and values) the odd integers from 1 to 11. When we called `M.findMax` to find the entry with the largest key, we saw that the only “future” state evaluated was the `mempty` we provided. When we called `M.!` to find the value at key 5, we saw, as expected, that the states were evaluated starting from the right: first the rightmost 11, then 20, which is 9+11, then 27, which is 7+9+11. Adding the 5 itself gives the final 32.

We used a map in this example, but of course with a simpler data structure like the venerable list, it works the same way.

In summary, we have used the reverse state monad to essentially perform a right-to-left traversal to find the cumulative sum.

Can we do better? Can we generalize this to more kinds of “cumulative” operations? What if, instead of a simple running sum, we want a running average? Or a running standard deviation? Or something entirely new, such as the running maximum multiplied by the running minimum? The only difference between all of these tasks is the specific state-transforming function (the function passed to `ReverseState`). The remaining parts, namely the use of the reverse state monad and the overall structure of the function, are the same. We ought to factor them out.
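Before reaching for a library, here is a minimal sketch of what factoring out the step might look like. A hypothetical `scanRightWith` takes the per-element state transition as an argument and performs the reverse-state traversal. A running average from the right then only needs a different step, whose state is a count paired with a running total. The names and the `(Int, Double)` state are illustrative assumptions, not taken from any library:

```haskell
{-# LANGUAGE DeriveFunctor #-}

newtype ReverseState s a = ReverseState { runReverseState :: s -> (a, s) }
  deriving Functor

instance Applicative (ReverseState s) where
  pure x = ReverseState $ \s -> (x, s)
  mf <*> mx = ReverseState $ \s ->
    let (f, past)   = runReverseState mf future
        (x, future) = runReverseState mx s
    in (f x, past)

-- Hypothetical helper: the traversal skeleton shared by all right scans.
-- 'step' maps an element and the state from its right to (output, new state).
scanRightWith :: Traversable t => (a -> s -> (b, s)) -> s -> t a -> t b
scanRightWith step begin t =
  fst (runReverseState (traverse (ReverseState . step) t) begin)

-- The cumulative sum from the text, expressed with the helper.
cumulativeR :: (Monoid w, Traversable t) => t w -> t w
cumulativeR = scanRightWith (\b s -> let r = b <> s in (r, r)) mempty

-- Hypothetical running mean from the right; state = (count, total).
runningMeanR :: Traversable t => t Double -> t Double
runningMeanR = scanRightWith step (0 :: Int, 0)
  where
    step x (n, total) =
      let n'     = n + 1
          total' = total + x
      in (total' / fromIntegral n', (n', total'))
```

`runningMeanR [2, 3, 5, 7, 11, 13]` should agree, up to floating-point rounding, with the running averages produced by the `foldl`-based approach.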

It turns out, however, that this work has already been done by Gabriel Gonzalez in his excellent `foldl` library. This library started with a `Control.Foldl` module to provide efficient left folds, but recently (to my surprise!) gained a `Control.Scanl` module. It defines this data type to represent a map-with-accumulator operation:

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Scan a b = forall x. Scan (a -> State x b) x
```

Then we can use its `scan` function to perform a scan from left to right:

```haskell
scan :: Traversable t => Scan a b -> t a -> t b
scan (Scan step begin) as = fst $ runState (traverse step as) begin
```

(The above source code is taken from version 1.4.2 of the `foldl` library.)

If we want to provide a right-to-left scan, we can reuse the definition of `Scan`, and provide a function `scanRight` that looks like this:

```haskell
scanRight :: Traversable t => SL.Scan a b -> t a -> t b
scanRight (SL.Scan step begin) as =
  evalReverseState (traverse (ReverseState . runState . step) as) begin
```

Here, we simply unwrap a given state monad action, wrap it again in our `ReverseState`, do the traversal, then unwrap it again.

We can certainly use this for our cumulative sum:

```
ghci> import qualified Control.Foldl as F
ghci> import qualified Control.Scanl as SL
ghci> scanRight (SL.postscan F.sum) [2 :: Int, 3, 5, 7, 11, 13]
[41,39,36,31,24,13]
```

But we gain so much more. Running averages:

```
ghci> scanRight (SL.postscan F.mean) [2 :: Double, 3, 5, 7, 11, 13]
[6.833333333333333,7.8,9.0,10.333333333333334,12.0,13.0]
```

Running standard deviation:

```
ghci> scanRight (SL.postscan F.std) [2 :: Double, 3, 5, 7, 11, 13]
[4.017323597731316,3.7094473981982814,3.1622776601683795,2.494438257849294,1.0,0.0]
```
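As a quick sanity check, take the third entry from the right, which summarizes the suffix 7, 11, 13. The printed value matches the *population* standard deviation (dividing by n rather than n - 1). Consistently, the entry for the suffix 11, 13 is exactly 1.0:

$$
\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} x_i^2 - \Bigl(\frac{1}{n}\sum_{i=1}^{n} x_i\Bigr)^2
         = \frac{49 + 121 + 169}{3} - \Bigl(\frac{31}{3}\Bigr)^2
         = \frac{1017 - 961}{9} = \frac{56}{9},
\qquad
\sigma = \frac{\sqrt{56}}{3} \approx 2.4944
$$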

Or the crazy minimum-so-far multiplied by maximum-so-far:

```
ghci> import Control.Applicative
ghci> scanRight (SL.postscan ((liftA2 . liftA2) (*) F.maximum F.minimum)) [2 :: Int, 3, 5, 7, 11, 13]
[Just 26,Just 39,Just 65,Just 91,Just 143,Just 169]
```

The `Control.Scanl` module gives us efficient, composable, left-to-right map-with-accumulator. With a little bit of magic from the reverse state monad, we can have efficient, composable, right-to-left map-with-accumulator.

Just to further dispel any notion that the reverse state monad is somehow too magical: somewhat surprisingly, a copy of it already exists in `base`. Indeed, the innocuous-sounding `mapAccumR` function in `base` already provides a right-to-left map-with-accumulator operation. If you don’t have complicated scans, it is perfectly fine to simply use `mapAccumR` with a handwritten step function instead of the above abstraction. And the kicker is that `mapAccumR` is itself implemented with a reverse state applicative (called `StateR` internally).
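For comparison, here is a sketch of `cumulativeR` written directly with `mapAccumR` from `Data.Traversable`. Note that `mapAccumR`'s step takes the accumulator first and also returns it first in the pair. That is the opposite order from the `s -> (a, s)` functions used with `state` and `ReverseState`. The name `cumulativeRViaMapAccumR` is only for illustration:

```haskell
import Data.Traversable (mapAccumR)
import qualified Data.Map as M

-- mapAccumR :: Traversable t => (s -> a -> (s, b)) -> s -> t a -> (s, t b)
-- The accumulator flows in from the right, so each position receives
-- its own value appended with the sum of everything to its right.
-- (Hypothetical name, for illustration only.)
cumulativeRViaMapAccumR :: (Monoid w, Traversable t) => t w -> t w
cumulativeRViaMapAccumR =
  snd . mapAccumR (\s b -> let r = b <> s in (r, r)) mempty

example :: M.Map Int String
example = cumulativeRViaMapAccumR (M.fromList [(1, "one"), (3, "three"), (2, "two")])
-- fromList [(1,"onetwothree"),(2,"twothree"),(3,"three")]
```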