The trek goes on

0023

The traced state monad

Posted at 16:10 on 1 October 2021, and revised at 16:16 on 11 October 2021

The covid outbreak that started in May is now suppressed, and life is gradually going back to normal (modulo mask-wearing). Coincidentally, Tchaikovsky’s violin concerto and fifth symphony marked the beginning and end of the outbreak: the two works were performed at both the last concert I went to before the outbreak and the first one after it (by the Taipei Symphony Orchestra and the Taiwan Philharmonic respectively). I was also able to visit my parents’ place during the Mid-Autumn holidays, relieving some stress accumulated during the four months; the holidays, however, caused a long delay in writing this post, and I shouldn’t delay any more.


My undergraduate advisee Zong-You Shih and I are looking into traversal of potentially infinite trees, which reminded me of Maciej and Jeremy’s work on tracing monadic computations back when I started my DPhil. As a small exercise, I decided to play with the traced state monad (a special case of their construction) in Cubical Agda, where we can naturally reason about negatively typed constructions (including functions and codata) extensionally.

The usual state monad State S = λ A → S → A × S is less useful when computation may be non-terminating: if the computation terminates, then we get the result and final state, which is good; but if the computation does not terminate, then we can only wait indefinitely and get nothing. Here a more useful construction is the traced state monad States S, which provides an additional operation mark : States S Unit that, when invoked, records the current state in a trace. If mark is invoked at the right times, even when the computation does not terminate, we still get an infinite trace of intermediate states instead of having to wait for nothing; when the computation does terminate, we get a finite trace of intermediate states followed by the result and final state, as in the case of the usual state monad.

Since the trace is essentially a (non-empty) colist, in Agda we define it as a coinductive record:

mutual

  record Trace (S A : Type) : Type where
    coinductive
    field
      force : TraceStep S A

  data TraceStep (S A : Type) : Type where
    ret : A → S → TraceStep S A
    int : S → Trace S A → TraceStep S A

open Trace

We can force a trace to reveal one TraceStep, which can be either ret a s where a is the result and s the final state, or int s t where s is an intermediate state and t the rest of the trace. Computation in the traced state monad takes an initial state and produces a trace:

States : Type → Type → Type
States S A = S → Trace S A

The mark operation can then be defined by copattern matching:

mark : States S Unit
force (mark s) = int s (return tt s)

Given an initial state s, forcing the trace mark s yields s as an intermediate state, while the rest of the trace is return tt s, where return is similarly defined by copattern matching:

return : A → States S A
force (return a s) = ret a s
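
To make the benefit for non-terminating computations concrete, here is a minimal sketch that is not part of the original development. The names countUp, takeStates, and countUpPrefix are hypothetical; ℕ and List are assumed to be imported (for instance from Cubical.Data.Nat and Cubical.Data.List in the cubical library), and S and A are assumed to be generalised variables as in the definitions above.

```agda
-- Hypothetical example: a computation that never returns, but records
-- every state it passes through (intuitively, an endless loop of
-- "mark, then increment the state").
countUp : States ℕ Unit
force (countUp n) = int n (countUp (suc n))

-- Hypothetical observer: collect at most n intermediate states from
-- the front of a trace, stopping early if the trace ends.
takeStates : ℕ → Trace S A → List S
takeStates zero    t = []
takeStates (suc n) t with force t
... | ret _ _  = []
... | int s t' = s ∷ takeStates n t'

-- The first three recorded states of countUp started at 0.
countUpPrefix : takeStates 3 (countUp 0) ≡ 0 ∷ 1 ∷ 2 ∷ []
countUpPrefix = refl
```

Although countUp 0 never reaches a ret, any finite prefix of its trace can be inspected, which is exactly what the usual state monad cannot offer.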

We can also define join by

join : States S (States S A) → States S A
join mma s = joinTrace (mma s)

It calls joinTrace to concatenate mma s : Trace S (States S A) — which essentially consists of two traces — into a single trace:

joinTrace : Trace S (States S A) → Trace S A
force (joinTrace tma) with force tma
... | ret ma s   = force (ma s)
... | int s tma' = int s (joinTrace tma')

When joinTrace tma is forced, we should force tma to see whether we have reached the end of the first trace: if force tma = ret ma s, the first trace is depleted, and we continue to force the second trace ma s; otherwise force tma = int s tma', in which case we emit the intermediate state s and call joinTrace corecursively.

The fmap operation for States also follows the same pattern, where the real work is done by fmapTrace, which leaves all the intermediate states in a trace untouched and applies the given function to the result at the end of the trace:

fmapTrace : (A → B) → Trace S A → Trace S B
force (fmapTrace f ta) with force ta
... | ret a s   = ret (f a) s
... | int s ta' = int s (fmapTrace f ta')

fmap : (A → B) → States S A → States S B
fmap f ma s = fmapTrace f (ma s)

Bind can then be defined in the standard way:

_>>=_ : States S A → (A → States S B) → States S B
ma >>= f = join (fmap f ma)

_>>_ : States S A → States S B → States S B
ma >> mb = ma >>= λ _ → mb

It’s also easy to define the state-manipulating operations, which I’ll just dump below:

get : States S S
force (get s) = ret s s

put : S → States S Unit
force (put s _) = ret tt s

modify : (S → S) → States S Unit
modify f = get >>= (put ∘ f)
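
As a quick sanity check of how these operations compose, the following sketch runs a small terminating computation. It is my own illustration rather than part of the original file: runFor and runExample are hypothetical names, Maybe and _×_ are assumed to be imported (for instance from Cubical.Data.Maybe and Cubical.Data.Sigma), and the explicit parentheses are there because I am assuming no fixity has been declared for _>>_.

```agda
-- Hypothetical helper: follow a trace for at most n steps and return
-- the result and final state if the trace ends within that budget.
runFor : ℕ → Trace S A → Maybe (A × S)
runFor zero    t = nothing
runFor (suc n) t with force t
... | ret a s  = just (a , s)
... | int _ t' = runFor n t'

-- Starting from state 0: record the state, increment it, and read it.
-- The trace is  int 0 , then  ret 1 1 .
runExample : runFor 3 ((mark >> (modify suc >> get)) 0) ≡ just (1 , 1)
runExample = refl
```

Only mark contributes an int step here; get, put, and modify end their traces immediately with ret, so join simply continues with the next computation.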

It’s time to prove some monad laws, which in the case of States are very extensional: monad laws by nature are pointwise equations between functions; moreover, in the case of States, the codomain of the functions is Trace, a codata type, meaning that we have to deal with some kind of bisimilarity as well. Fortunately, pointwise equality and bisimilarity simply coincide with the equality/path types within Cubical Agda. To construct an inhabitant of x ≡ y where x, y : A, that is, a path from x to y, think of the path as a value p i : A parametrised by an interval variable i : I and satisfying the boundary conditions p i0 = x and p i1 = y, where i0, i1 : I are the two endpoints of the (abstract) unit interval I. The construction of a path has to be done without inspecting i (in particular, we cannot pattern match on i against i0 and i1) so that, very informally speaking, the points on the path do not change abruptly, making the path ‘continuous’ (which is the mysterious intuition that HoTT people want us to accept and I don’t understand yet).
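
As a warm-up, here are two tiny path constructions in this style. They are standard Cubical Agda exercises rather than anything from the original file; the names constPath and pointwise are hypothetical (the cubical library provides refl and funExt for the same purposes), and Type and _≡_ are assumed to come from Cubical.Foundations.Prelude.

```agda
-- A constant path: the point does not depend on i at all, so both
-- boundary conditions hold trivially.
constPath : {A : Type} (x : A) → x ≡ x
constPath x i = x

-- Function extensionality: a family of pointwise paths becomes a path
-- between functions by swapping the interval variable and the argument.
-- At i = i0 the result is λ a → f a, and at i = i1 it is λ a → g a.
pointwise : {A B : Type} {f g : A → B} → ((a : A) → f a ≡ g a) → f ≡ g
pointwise p i a = p a i
```

In both cases the interval variable is introduced like an ordinary argument and is only ever passed along, never inspected.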

Let us walk through the proof of the right unit law to see how path construction proceeds:

rightUnit : join ∘ fmap return ≡ id {States S A}

This is a path on States S A → States S A, so we should construct a function of that type parametrised by an interval variable i, which we introduce in the same way as an argument of a function:

rightUnit i {-: I -} = {! States S A → States S A !}

The goal type is now States S A → States S A, and the function we construct should be join ∘ fmap return when i = i0 and id when i = i1. We further introduce the two arguments of the function:

rightUnit i ma {-: States S A -} s {-: S -} = {! Trace S A !}

Crucially, the goal boundaries are also expanded pointwise and simplified to joinTrace (fmapTrace return (ma s)) and ma s. This expansion of boundary conditions inside definitions seems to be the key to achieving extensionality of paths on negative types systematically: to construct a path on a negative type is to construct an I-parametrised value of that type. A negatively typed value is defined by the results of its eliminations (application in the special case of functions); as we specify the results, the boundary conditions are also expanded and simplified and become conditions on the results, so that in the end what we prove is that the two boundaries of a path eliminate to the same results, which is exactly what an extensional equality should be.

Back to rightUnit: if we can construct a path

rightUnitTrace : joinTrace ∘ fmapTrace return ≡ id {Trace S A}

then we can discharge rightUnit by

rightUnit i ma s = rightUnitTrace i (ma s)

For rightUnitTrace we follow the same proof pattern except that this time the elimination also involves force:

force (rightUnitTrace i {-: I -} ta {-: Trace S A -}) = {! TraceStep S A !}

The two goal boundaries are force (joinTrace (fmapTrace return ta)) and force ta; to simplify the first, by the definition of joinTrace we should force fmapTrace return ta, the result of which by the definition of fmapTrace depends on the result of forcing ta, so we should analyse force ta:

force (rightUnitTrace i ta) with force ta
... | ret a s   = {! TraceStep S A !}
... | int s ta' = {! TraceStep S A !}

At the first goal, both goal boundaries reduce to ret a s, so that’s what we fill in. At the second goal, the goal boundaries are int s (joinTrace (fmapTrace return ta')) and int s ta', which can be met by a corecursive call to rightUnitTrace i ta' wrapped within int s:

force (rightUnitTrace i ta) with force ta
... | ret a s   = ret a s
... | int s ta' = int s (rightUnitTrace i ta')

This proof can be verified manually by substituting i0 and i1 for i and checking case by case whether the results are equal to those of joinTrace ∘ fmapTrace return and id. The proof is admittedly elementary (as are the proofs for all the other laws, which are therefore omitted), but correspondingly it uses only basic features of Cubical Agda, which is nice — extensional equalities, even elementary ones, used to be awkward and tedious to deal with, but with modern technologies like Cubical Agda, they can finally be established with practically no formalisation overhead.
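
To give a flavour of the omitted proofs, here is how I would sketch the other unit law, stated pointwise to avoid depending on any particular definition of id. This is a reconstruction under the definitions above, not copied from the accompanying file, and the names leftUnitTrace and leftUnit are hypothetical.

```agda
-- Forcing joinTrace (return ma s) first forces return ma s, which is
-- ret ma s, and then continues with force (ma s). Both boundaries thus
-- reduce to force (ma s), and no case analysis is needed.
leftUnitTrace : (ma : States S A) (s : S) → joinTrace (return ma s) ≡ ma s
force (leftUnitTrace ma s i) = force (ma s)

-- Lift the trace-level path to States by introducing the initial state.
leftUnit : (ma : States S A) → join (return ma) ≡ ma
leftUnit ma i s = leftUnitTrace ma s i
```

Compared with rightUnitTrace, the proof is not even corecursive: a single copattern clause suffices, because the two sides already agree after one force.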

Agda file: TracedState.agda
Follow-up: 0024 (Productivity and the traced state monad)