Monads in Python

22 Sep 2017, 16:02 • python, monads

A monad is an abstract concept used in functional programming to sequence computation, but it can also be useful in OOP. Let's take a look at how to implement monads in Python.

A monad can be thought of as an interface providing a few methods which make it constructible and composable. It typically wraps a value of an underlying type and has the peculiar property of sequencing computation.

In mathematical parlance, monads are functors, which means that they take objects to objects and functions to functions (in programming languages they provide an implementation of fmap). We'll see an example below.

Every monad \(M\) must provide the following:

  • A function which wraps objects of the underlying type \(t\). In programming this function is usually called unit. In maths it's typically denoted by \(\eta\) and is of type \(t \rightarrow Mt\).
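  • A function which flattens values of type \(MMt\) by removing one layer of wrapping (that is, it's of type \(MMt \rightarrow Mt\)). In programming this function is usually called join. In maths it's typically denoted by \(\mu\).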
  • A function called bind which takes a monadic value and a function taking a nonmonadic value to a monadic one and applies the latter to the former. It's typically denoted by \(>\!\!\!>\!\!=\).
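
To make the three operations concrete, here is a minimal sketch that treats plain Python lists as the monad \(M\). The free functions unit, join and bind are illustrative names chosen for this example and do not come from any library.

```python
# Illustrative sketch: Python lists viewed as a monad (names are assumptions).

def unit(x):
    # eta: wrap a plain value, t -> Mt
    return [x]

def join(xss):
    # mu: flatten one level of nesting, MMt -> Mt
    return [x for xs in xss for x in xs]

def bind(xs, func):
    # >>=: apply func (of type t -> Mt) to each element and concatenate the results
    return [y for x in xs for y in func(x)]

print(unit(3))                              # [3]
print(join([[1, 2], [3]]))                  # [1, 2, 3]
print(bind([1, 2], lambda x: [x, x * 10]))  # [1, 10, 2, 20]
```

Here wrapping means putting a value in a one-element list, and flattening means concatenating a list of lists.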

\(\eta\) and \(\mu\) must satisfy a few conditions which we won't go into here. It is however important to note that the methods monads provide can be expressed in terms of each other. Specifically we have:

$$\begin{array}{c} \mu(x) = x >\!\!\!>\!\!= \lambda y.y \\ x >\!\!\!>\!\!= f = \mu((M f)(x)) \\ (M f)(x) = x >\!\!\!>\!\!= (\eta \circ f) \end{array}$$

It might be easier to see the relationships in code. Here's what we have in Python.

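class Monad:
	# Subclasses must provide unit and override either bind, or both fmap and join.
	def join(self):
		# mu: flatten by binding the identity function
		return self.bind(lambda x: x)
	def fmap(self, func):
		# M f: apply func to the wrapped value, then rewrap the result with unit
		return self.bind(lambda x: self.__class__.unit(func(x)))
	def bind(self, func):
		# >>=: map func over the wrapped value, then flatten the result
		return self.fmap(func).join()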

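Note that to define an instance of a monad, we needn't implement all of the methods. \(\mu\) can be expressed in terms of \(>\!\!\!>\!\!=\), \(>\!\!\!>\!\!=\) can be expressed in terms of \(\mu\) and the functor, etc. On the other hand, we must be careful in deciding which methods to implement, because the defaults are defined in terms of each other. If we implement only join, for example, bind calls fmap and fmap calls bind, and we end up in infinite recursion. Implementing bind alone, or fmap and join together, is enough to break the cycle.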
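
The following sketch illustrates both points. Box and Broken are hypothetical classes invented for this example: Box implements only unit and bind and gets fmap and join for free, while Broken implements neither bind nor the fmap/join pair and therefore recurses until Python gives up. The base class is repeated so that the block runs on its own, with the default bind written as fmap followed by join.

```python
class Monad:
    # Default definitions, each expressed through the others.
    def join(self):
        return self.bind(lambda x: x)
    def fmap(self, func):
        return self.bind(lambda x: self.__class__.unit(func(x)))
    def bind(self, func):
        return self.fmap(func).join()

class Box(Monad):
    # Hypothetical single-value wrapper; only unit and bind are implemented.
    def __init__(self, val):
        self.val = val
    @staticmethod
    def unit(x):
        return Box(x)
    def bind(self, func):
        return func(self.val)

class Broken(Monad):
    # Hypothetical class that overrides nothing but unit,
    # so join, bind and fmap keep calling each other.
    @staticmethod
    def unit(x):
        return Broken()

print(Box.unit(2).fmap(lambda x: x + 1).val)  # 3
print(Box.unit(Box.unit(5)).join().val)       # 5

try:
    Broken.unit(1).join()
except RecursionError:
    print("join, bind and fmap called each other without end")
```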

As an example, let's implement continuations. Continuations are monadic because they are functorial and can be equipped with \(\eta\) and \(\mu\). The continuation monad takes a type \(A\) to the type \((A \rightarrow R) \rightarrow R\), so a continuation is a function which accepts a callback of type \(A \rightarrow R\) and returns an \(R\). Let's see how the implementation looks in Python (pay special attention to how bind is defined):

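class Continuation(Monad):
	@staticmethod
	def unit(x):
		# eta: a continuation that passes x to whatever callback it receives
		return Continuation(lambda f: f(x))
	def __init__(self, val):
		self.val = val
	def __call__(self, func):
		return self.val(func)
	def bind(self, func):
		# run self, feed its result x to func, then run the resulting continuation with f
		return Continuation(lambda f: self.val(lambda x: func(x)(f)))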

The Continuation class is a callable object wrapping a λ-expression of type \((A \rightarrow R) \rightarrow R\) (Python isn't statically typed so we needn't care about what \(R\) is). We can now wrap a value into a continuation like so:

c = Continuation.unit(1234)

Note that the wrapped λ-expression is \(\lambda f.f(1234)\) and this is what the instance evaluates when called (see __call__). The bind method is slightly more complicated: it produces a λ-expression (taking \(f\)) which unwraps the continuation bind was called on and applies it to \(\lambda x.g(x)(f)\), where \(g\) is the function passed to bind. The result can be fed into another bind or applied to a function of a compatible type.
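
As a small usage sketch, here is how two binds can be chained and the result applied to a final function. The helpers add_one and double are hypothetical names made up for this example, and the class is repeated without the Monad base so that the block runs on its own.

```python
class Continuation:
    # Standalone copy of the class from the text (Monad base omitted for brevity).
    def __init__(self, val):
        self.val = val
    @staticmethod
    def unit(x):
        return Continuation(lambda f: f(x))
    def __call__(self, func):
        return self.val(func)
    def bind(self, func):
        return Continuation(lambda f: self.val(lambda x: func(x)(f)))

# Hypothetical steps, each taking a plain value to a continuation.
def add_one(x):
    return Continuation.unit(x + 1)

def double(x):
    return Continuation.unit(x * 2)

c = Continuation.unit(1234).bind(add_one).bind(double)

# Nothing is computed until the final callback is supplied.
result = c(lambda x: x)
print(result)  # 2470
```

Passing the identity function as the final callback simply returns the computed value, here (1234 + 1) * 2.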