SciTech.blog

Recursive sum types in C++

26 Sep 2017, 09:20 • c++, algebraic types

The recently released C++17 standard provides a new class template, std::variant, for implementing sum types. This post focuses on its basic use.

Sum types (also called tagged unions or, in mathematical parlance, coproducts) are useful for representing values that can be of any one of several types. We'll illustrate their use by implementing a simple expression evaluator.
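Before getting to the recursive case, here is a minimal sketch of the basic std::variant API: a value that holds either an int or a std::string, with the variant tracking which alternative is active. The alias IntOrString and the function show are hypothetical names used only for this illustration.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Either an int or a std::string; index() reports which one is active
// (this is the "tag" of the tagged union).
using IntOrString = std::variant<int, std::string>;

// std::get_if returns nullptr when the requested alternative is not active,
// so this never throws std::bad_variant_access.
std::string show(const IntOrString& v) {
  if (const int* i = std::get_if<int>(&v)) return std::to_string(*i);
  return "\"" + std::get<std::string>(v) + "\"";
}
```

Assigning a value of a different alternative simply switches the active member: after `IntOrString v = 42; v = std::string("hi");`, `v.index()` is 1 and `show(v)` yields `"hi"` in quotes.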

Let's define an expression type that can be either a constant or a sum. The declarations are:

template<typename T, template<typename> class E> using ExprVariant = std::variant<T, Ref<Add<E<T>>>>;
template<typename T> class Expr; // forward declaration; defined below, deriving from ExprVariant<T,Expr>

We want the class to be generic so we can represent expressions of any (meaningful) type. The template arguments of std::variant specify that an Expr<T> holds either a T or a Ref<Add<Expr<T>>>, where Add is defined as follows:

template<typename T> class Add {
public:
  T left, right;
  // Move from rvalue arguments instead of copying them
  Add(T&& l, T&& r) : left(std::move(l)), right(std::move(r)) {}
  Add(const T& l, const T& r) : left(l), right(r) {}
};

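This class represents the sum of two expressions. We can't store an Add<Expr<T>> in the variant directly: std::variant stores its alternatives inline and requires them to be complete types, so an Expr holding an Add that itself holds two Exprs would have to contain itself. Storing the Add behind a pointer breaks this recursion. We use Ref, which is just an alias for std::shared_ptr, together with a make_ref helper: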
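The same idea can be seen in a smaller, non-generic setting. In this sketch (Node, Pair and total are hypothetical names), the recursive alternative is held through a std::shared_ptr, just as Expr holds its Add through Ref:

```cpp
#include <cassert>
#include <memory>
#include <variant>

struct Pair;  // still incomplete here: only a pointer to it is stored below

// A node is either a leaf value or a pointer to a pair of nodes.
// A direct std::variant<int, Pair> would not compile: the variant stores
// its alternatives inline, so a Pair would have to contain itself.
using Node = std::variant<int, std::shared_ptr<Pair>>;

struct Pair {
  Node left, right;
};

// Sum all leaves, recursing through the pointer alternative.
int total(const Node& n) {
  if (const int* leaf = std::get_if<int>(&n)) return *leaf;
  const auto& p = std::get<std::shared_ptr<Pair>>(n);
  return total(p->left) + total(p->right);
}
```

For example, `total(std::make_shared<Pair>(Pair{1, std::make_shared<Pair>(Pair{2, 3})}))` adds the three leaves and returns 6.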

template<typename T> using Ref = std::shared_ptr<T>;
template<typename T, typename... Args> Ref<T> make_ref(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

We need to provide one constructor for each shape an expression can take (a constant or a sum), plus an eval member and an operator+ for building sums:

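template<typename T> class Expr : public ExprVariant<T,Expr> {
public:
  Expr(const T& x) : ExprVariant<T,Expr>(x) {}
  Expr(Add<Expr<T>>&& a) : ExprVariant<T,Expr>(make_ref<Add<Expr<T>>>(std::move(a))) {}
  // Visit through the std::variant base: C++17's std::visit is not
  // guaranteed to accept classes derived from std::variant.
  T eval() const {
    return std::visit(ExprEval<T,Expr>(), static_cast<const ExprVariant<T,Expr>&>(*this));
  }
  Expr<T> operator+(const Expr<T>& e) const {
    return Expr<T>(Add<Expr<T>>(*this, e));
  }
};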

We now have a definition of what an expression is, but we don't know how to evaluate it. The following visitor class defines how to evaluate each alternative. Because Expr::eval refers to it, it must be declared before Expr in the source file:

template<typename T, template<typename> class E> class ExprEval {
public:
  T operator()(const T& x) const { return x; }
  T operator()(const Ref<Add<E<T>>>& a) const {
    return a->left.eval() + a->right.eval();
  }
};
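std::visit calls the overload of the visitor's operator() that matches the alternative currently held, and every overload must return the same type (T in ExprEval). The hypothetical Describe visitor below shows this mechanism on a simple, non-recursive variant:

```cpp
#include <cassert>
#include <string>
#include <variant>

// One operator() per alternative, in the style of ExprEval.
// Both overloads return std::string, as std::visit requires.
struct Describe {
  std::string operator()(int) const { return "int"; }
  std::string operator()(const std::string&) const { return "string"; }
};

std::string describe(const std::variant<int, std::string>& v) {
  return std::visit(Describe{}, v);
}
```

Adding another operation (printing, simplification, and so on) in this style means writing another visitor class; the variant itself does not change.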

We can now use eval to evaluate expressions:

auto e = Expr(2) + Expr(3);
std::cout << e.eval() << std::endl; // prints 5
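Collected into one file, the pieces have to appear in dependency order: Ref and Add first, then ExprEval (which Expr::eval names), then ExprVariant and Expr. The following is a sketch of such a file under that ordering; it moves from rvalue arguments with std::move and calls std::visit on the std::variant base, since C++17's std::visit is not guaranteed to accept classes derived from std::variant. Put the two usage lines above inside main to run it.

```cpp
#include <cassert>
#include <iostream>
#include <memory>
#include <utility>
#include <variant>

// Ref and make_ref: thin aliases over std::shared_ptr / std::make_shared.
template<typename T> using Ref = std::shared_ptr<T>;
template<typename T, typename... Args> Ref<T> make_ref(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

// The sum of two subexpressions.
template<typename T> class Add {
public:
  T left, right;
  Add(T&& l, T&& r) : left(std::move(l)), right(std::move(r)) {}
  Add(const T& l, const T& r) : left(l), right(r) {}
};

// Evaluation visitor; it must precede Expr because Expr::eval names it.
template<typename T, template<typename> class E> class ExprEval {
public:
  T operator()(const T& x) const { return x; }
  T operator()(const Ref<Add<E<T>>>& a) const {
    return a->left.eval() + a->right.eval();
  }
};

template<typename T, template<typename> class E>
using ExprVariant = std::variant<T, Ref<Add<E<T>>>>;

template<typename T> class Expr : public ExprVariant<T, Expr> {
public:
  Expr(const T& x) : ExprVariant<T, Expr>(x) {}
  Expr(Add<Expr<T>>&& a)
      : ExprVariant<T, Expr>(make_ref<Add<Expr<T>>>(std::move(a))) {}
  T eval() const {
    // Visit through a reference to the std::variant base class.
    const ExprVariant<T, Expr>& base = *this;
    return std::visit(ExprEval<T, Expr>(), base);
  }
  Expr<T> operator+(const Expr<T>& e) const {
    return Expr<T>(Add<Expr<T>>(*this, e));
  }
};
```

Because subexpressions are held through std::shared_ptr, copying an Expr copies only the pointer, so sums can share subtrees cheaply.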

The code was tested with clang-6.0 and libc++ (don't forget to compile with -std=c++17).
