SciTech.blog

Catamorphisms in C++

3 Oct 2017, 07:07 • c++, catamorphism

In functional programming, recursive data types can be defined as fixpoints of (non-recursive) functors. Let’s see how this technique can be used in C++.

We've already seen how to define catamorphisms in Python, so this post contains little prose and mostly translates the Python code into C++.

Before we begin, we need to define an auxiliary type for representing references to objects:

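#include <memory>   // std::shared_ptr, std::make_shared
#include <utility>  // std::forward, std::move, std::declval

// Ref<T> is a reference-counted handle to a T; make_ref forwards its arguments to T's constructor.
template<typename T> using Ref = std::shared_ptr<T>;
template<typename T, typename... Args> Ref<T> make_ref(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}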

To keep things simple, we'll use the standard shared pointer (std::shared_ptr), but the type alias makes it easy to swap in a different reference type later.
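A quick illustration of the reference semantics Ref inherits from std::shared_ptr: copying a Ref copies the handle, not the object. This is a minimal sketch; append_through_copy is a hypothetical helper for illustration only, not part of the post's code.

```cpp
#include <memory>
#include <string>
#include <utility>

// Same alias and helper as above: Ref is a reference-counted pointer.
template<typename T> using Ref = std::shared_ptr<T>;
template<typename T, typename... Args> Ref<T> make_ref(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

// Hypothetical helper: copies a Ref and mutates the string through the copy.
// Because Ref has reference semantics, the change is visible through the original too.
Ref<std::string> append_through_copy(const Ref<std::string>& r, const std::string& suffix) {
  Ref<std::string> copy = r;  // copies the pointer and bumps the use count
  *copy += suffix;
  return copy;
}
```

For example, with `auto a = make_ref<std::string>(3, 'x');` the call `append_through_copy(a, "y")` leaves `*a == "xxxy"`, and both handles point to the same object. This sharing is what makes it cheap to pass subtrees around by value in the code below.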

Recall how the whole mechanism works:

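Here in : FμF → μF wraps one layer into the fixpoint μF, α : FA → A is the (non-recursive) evaluator, and κ : μF → A is the catamorphism of α: the unique morphism for which the two paths around the square FμF → μF → A (via in, then κ) and FμF → FA → A (via Fκ, then α) agree:

$$\kappa \circ \mathrm{in} = \alpha \circ F\kappa$$

Since in is invertible, this is the recursive definition κ = α ∘ Fκ ∘ in⁻¹.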
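As a worked instance, take the expression functor and the evaluator defined later in the post, so α maps a literal n to n and Add(x, y) to x + y, while Fκ leaves a literal unchanged and applies κ to both operands of Add. Unfolding κ ∘ in = α ∘ Fκ at every node of the term in(Add(in 1, in 2)), written in this informal notation, gives:

$$
\begin{aligned}
\kappa(\mathrm{in}(\mathrm{Add}(\mathrm{in}\,1,\ \mathrm{in}\,2)))
  &= \alpha(F\kappa(\mathrm{Add}(\mathrm{in}\,1,\ \mathrm{in}\,2))) \\
  &= \alpha(\mathrm{Add}(\kappa(\mathrm{in}\,1),\ \kappa(\mathrm{in}\,2))) \\
  &= \alpha(\mathrm{Add}(\alpha(1),\ \alpha(2))) \\
  &= \alpha(\mathrm{Add}(1,\ 2)) = 1 + 2 = 3
\end{aligned}
$$

The evaluator only ever sees one layer whose children have already been reduced to numbers. The recursion comes entirely from κ.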

In C++, the code is slightly more complicated than in Python, but the good thing is that it's type-safe. We'll use the new std::variant type from C++17:

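#include <functional>  // std::function
#include <variant>     // std::variant, std::visit

// Functor on morphisms: specialised below for each functor we want to recursivise.
template<typename F, typename A> class FmapVisitor;

// fmap f a: apply f to the recursive positions of a by visiting the underlying variant.
template<typename F, typename A> auto fmap(F f, A a) { return std::visit(FmapVisitor<F,A>(f), a); }

// Fix<F> is a fixpoint of F: an F whose recursive positions hold Fix<F> again.
template<template<typename> class F> class Fix : public F<Fix<F>> {
public:
  Fix(F<Fix<F>> f) : F<Fix<F>>(f) {}  // this constructor plays the role of in
};

template<template<typename> class F> Fix<F> fix(F<Fix<F>> f) { return Fix<F>(f); }

// unfix (the inverse of in): slices a Fix<F> back to its F<Fix<F>> base.
template<template<typename> class F> F<Fix<F>> unfix(F<Fix<F>> f) { return f; }

// cata alg = alg . fmap (cata alg) . unfix
template<template<typename> class F, typename T> std::function<T(Fix<F>)> cata(const std::function<T(F<T>)>& alg) {
  return [alg](auto f) -> T {
    return alg(fmap(cata(alg), unfix(f)));
  };
}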

Remember that the data type we want to recursivise has to be functorial. The FmapVisitor class template lets us define the functor's action on morphisms as a visitor for (a subclass of) std::variant. fmap then simply uses std::visit to apply a function to the recursive positions of a value of the given type.
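To make "a functor on morphisms written as a visitor" concrete outside the expression example, here is a minimal sketch for a hypothetical list functor ListF<A> = Nil | Cons(int, A). The names ListF, Nil, Cons, ListFmapVisitor and fmap_list are illustrative and do not come from the post. For simplicity the sketch visits a plain std::variant rather than a subclass.

```cpp
#include <utility>
#include <variant>

// Hypothetical base functor for lists of ints: ListF<A> = Nil | Cons(int, A).
struct Nil {};
template<typename A> struct Cons { int head; A tail; };
template<typename A> using ListF = std::variant<Nil, Cons<A>>;

// Action of ListF on a morphism f : A -> B, written as a visitor.
// f is applied to the recursive position (tail) only; head is copied unchanged.
template<typename F, typename A>
struct ListFmapVisitor {
  F f;
  using B = decltype(std::declval<F>()(std::declval<A>()));
  ListF<B> operator()(const Nil&) const { return Nil{}; }
  ListF<B> operator()(const Cons<A>& c) const { return Cons<B>{c.head, f(c.tail)}; }
};

// fmap_list f x: std::visit dispatches on the active alternative of x.
template<typename F, typename A>
auto fmap_list(F f, const ListF<A>& x) {
  return std::visit(ListFmapVisitor<F, A>{f}, x);
}
```

Nil has no recursive position, so the visitor simply rebuilds it. This is the list counterpart of the post's `operator()(int)` for literals. Applying `[](int t) { return t * 10.0; }` to `Cons<int>{1, 20}` yields a `Cons<double>{1, 200.0}`: the element type of the recursive slot changes, while the non-recursive data stays put.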

The constructor of Fix is our in: Fix is a subclass of F⟨Fix⟨F⟩⟩ where F is a functor. unfix is the inverse of in (recall that in is invertible). Finally, cata returns a λ-expression that “unfixes” the value it receives, uses fmap to apply the catamorphism recursively to its subterms, and evaluates the result with the provided (non-recursive) evaluator α.
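The same recursion, cata alg = alg ∘ fmap (cata alg) ∘ unfix, can also be sketched without the inheritance trick. The following self-contained sketch is an alternative design, not the post's. It stores one unfixed layer behind a shared_ptr, uses a hypothetical list functor ListF<A> = Nil | Cons(int, A), and sums a list with a one-layer algebra. The names List, in, unfix, fmap, cata, sumAlg, nil and cons are illustrative.

```cpp
#include <memory>
#include <type_traits>
#include <utility>
#include <variant>

// Hypothetical base functor for lists of ints: ListF<A> = Nil | Cons(int, A).
struct Nil {};
template<typename A> struct Cons { int head; A tail; };
template<typename A> using ListF = std::variant<Nil, Cons<A>>;

// Fixpoint of ListF: a List wraps one ListF<List> layer behind a shared_ptr,
// which keeps the recursive type finite (the role Ref plays in the post).
struct List {
  std::shared_ptr<ListF<List>> layer;
};

// in : ListF<List> -> List and its inverse unfix : List -> ListF<List>.
List in(ListF<List> x) { return List{std::make_shared<ListF<List>>(std::move(x))}; }
const ListF<List>& unfix(const List& l) { return *l.layer; }

// Action of ListF on morphisms: apply f to the tail, keep the head.
template<typename F, typename A>
auto fmap(F f, const ListF<A>& x) {
  using B = decltype(f(std::declval<const A&>()));
  return std::visit([&](const auto& node) -> ListF<B> {
    if constexpr (std::is_same_v<std::decay_t<decltype(node)>, Nil>) return Nil{};
    else return Cons<B>{node.head, f(node.tail)};
  }, x);
}

// cata alg = alg . fmap (cata alg) . unfix
template<typename T, typename Alg>
T cata(Alg alg, const List& l) {
  return alg(fmap([&](const List& sub) { return cata<T>(alg, sub); }, unfix(l)));
}

// A non-recursive algebra: sums one layer whose tail is already an int.
int sumAlg(const ListF<int>& x) {
  if (const auto* c = std::get_if<Cons<int>>(&x)) return c->head + c->tail;
  return 0;  // Nil
}

// Hypothetical helpers for building lists.
List nil() { return in(Nil{}); }
List cons(int h, List t) { return in(Cons<List>{h, std::move(t)}); }
```

Here `cata<int>(sumAlg, cons(1, cons(2, cons(3, nil()))))` returns 6. sumAlg never recurses itself: all recursion lives in cata. Other folds, such as length or maximum, therefore only need another one-layer algebra.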

For simple arithmetic expressions, the functor and the corresponding evaluator can be defined as follows:

template<typename T> class Add {
public:
  T left, right;
  Add(T&& l, T&& r) : left(std::move(l)), right(std::move(r)) {}  // move from rvalue operands
  Add(const T& l, const T& r) : left(l), right(r) {}
};

template<typename A> class ExprFEval {
public:
  int operator()(int x) const { return x; }
  int operator()(const Ref<Add<A>>& a) const { return a->left + a->right; }
};

template<typename T, typename U> using ExprVariant = std::variant<T, Ref<Add<U>>>;

template<typename A> class ExprF : public ExprVariant<int,A> {
public:
  ExprF(int x) : ExprVariant<int,A>(x) {}
  ExprF(Add<A>&& a) : ExprVariant<int,A>(make_ref<Add<A>>(std::move(a))) {}  // store the node behind a Ref
};

int evalExprF(const ExprF<int>& e) {
  return std::visit(ExprFEval<int>(), e);
}

template<typename F, typename A> class FmapVisitor<F, ExprF<A>> {
  F f;
  // B is the result type of applying f to a subterm of type A.
  typedef decltype(std::declval<F>()(std::declval<A>())) B;
public:
  FmapVisitor(F&& _f) : f(std::move(_f)) {}
  FmapVisitor(const F& _f) : f(_f) {}
  ExprF<B> operator()(int x) const { return x; }  // literals have no subterms
  ExprF<B> operator()(const Ref<Add<A>>& a) const { return ExprF<B>(Add<B>(f(a->left), f(a->right))); }
};

using Expr = Fix<ExprF>;

The evaluator (ExprFEval) evaluates a single layer of an expression whose subexpressions have already been reduced to ints, i.e., expression trees of depth at most 2. ExprF is the functor, i.e., the type we want to recursivise. FmapVisitor is a template specialisation for ExprF that applies a function to the immediate subexpressions of an ExprF node. Finally, we define Expr as Fix⟨ExprF⟩.

The code was tested with clang-6.0 and libc++ (don't forget to use -std=c++17).
