strint

访问者模式和应用于表达式求解的实现

代码的扩展问题

在面向对象编程中,我们定义对类和处理对类的函数。一般情况下,要么类比较稳定,这样可以比较独立的扩展函数;要么函数比较稳定,可以通过类继承扩展类。但是一种麻烦的情况是,类和处理类的函数都要扩展。

如果以类为中心,把函数都写成类方法,扩展子类很容易,但是类方法就需要改变每个原有的子类。如果以函数为中心,把对类的一类操作都定义成函数,那么扩展函数很容易,但是扩展一个子类一般需要改动每一类操作对应的函数。

这个类和函数都需要扩展的问题在语言解释器的实现中体现比较明显。表达式 Expression 如 Constant 代表常量、BinaryPlus 代表二元加法,Expression 定义了算式;操作 Operation, 如 Evaluator 进行求值、Stringifier 进行字符串化,Operation 定义了对表达式的操作。Expression 和 Operation 通常都需要不断扩展。

以类为中心的扩展问题

通常我们的代码都是以类为中心的。以一个简单的表达式解释器为例,我们看下以类为中心的实现。

class Expr {
public:
  virtual std::string ToString() const = 0;
  virtual double Eval() const = 0;
};

Expr 是表达式的基类,定义了 Eval 和 ToString 两个基类方法。Eval 对表达式求值,ToSting 生成表达式的字符串表达。

基于 Expr,我们可以定义一个常量表达式:

class Constant : public Expr {
 public:
  Constant(double value) : value_(value) {}
  std::string ToString() const {
    std::ostringstream ss;
    ss << value_;
    return ss.str();
  }
  double Eval() const {
    return value_;
  }
 private:
  double value_;
};

定义一个典型的二元加法表达式,且该表达式是存在子表达式的:

class BinaryPlus : public Expr {
 public:
  BinaryPlus(const Expr& lhs, const Expr& rhs) : lhs_(lhs), rhs_(rhs) {}
  std::string ToString() const {
    return lhs_.ToString() + " + " + rhs_.ToString();
  }
  double Eval() const {
    return lhs_.Eval() + rhs_.Eval();
  }
 private:
  const Expr& lhs_;
  const Expr& rhs_;
};

现在我们可以这样去使用:

  Constant c0{1.1};
  std::cout << "Value of a Constant " << c.Eval() << std::endl;
  Constant c1{1.2};
  BinaryPlus b0{c0, c1};
  std::cout << "String of a Binary " << b0.ToString() << std::endl;

很显然,如果我们想再定义一个表达式,只需要新定义一个表达式的子类,然后该子类实现 Eval 和 ToString 就好。

但是如果我们想新增一个基类方法,如生成表达式的汇编代码 ToAssembly,我们需要修改原来每个子类,给其增加 ToAssembly 方法。

这样就违反了面向对象设计中的开闭原则:类、模块、函数等软件实体需要对扩展保持开放但对修改保持封闭。

总之,以类为中心的实现,扩展函数方法是困难的。

以函数为中心的扩展问题

如果我们要处理的类的种类有限,而类的方法需要不断扩展,我们可以考虑实现成以函数为中心。以解释器为例,我们可以把 Operation 从类中拆分出来,然后以 Operation 为中心来实现上文的功能。

如下定义 Operation 的基类,它负责对 Constant 和 BinaryPlus 这样的表达式执行处理。

class ExprProcessor {
 public:
  virtual void ProcessConstant(const Constant& c) = 0;
  virtual void ProcessBinaryPlus(const BinaryPlus& bp) = 0;
};

然后定义 Expression 的基类,它只是有个接受处理的虚函数方法。

class Expr {
 public:
  virtual void Accept(ExprProcessor* process) const = 0;
};

基于此,我们定义 Constant 和 BinaryPlus:

class Constant : public Expr {
 public:
  Constant(double value) : value_(value) {}
  void Accept(ExprProcessor* process) const {
    op->ProcessConstant(*this);
  }
  double GetValue() const {
    return value_;
  }
 private:
  double value_;
};

class BinaryPlus : public Expr {
 public:
  BinaryPlus(const Expr& lhs, const Expr& rhs) : lhs_(lhs), rhs_(rhs) {}
  void Accept(ExprOp* op) const {
    op->ProcessBinaryPlus(*this);
  }
  const Expr& GetLhs() const {
    return lhs_;
  }
  const Expr& GetRhs() const {
    return rhs_;
  }
 private:
  const Expr& lhs_;
  const Expr& rhs_;
};

然后我们分别实现求值、Stringifier操作。

class Evaluator : public ExprProcessor {
 public:
  double GetValueForExpr(const Expr& e) {
    return value_map_[&e];
  }

  void ProcessConstant(const Constant& c) {
    value_map_[&c] = c.GetValue();
  }

  void ProcessBinaryPlus(const BinaryPlus& bp) {
    bp.GetLhs().Accept(this);
    bp.GetRhs().Accept(this);
    value_map_[&bp] = value_map_[&(bp.GetLhs())] + value_map_[&(bp.GetRhs())];
  }

 private:
  std::map<const Expr*, double> value_map_;
};

Class Stringifier : public ExprProcessor {
 public:
  std::string GetStrForExpr(const Expr& e) {
    return str_map_[&e];
  }

  void ProcessConstant(const Constant& c) {
    str_map_[&c] = std::to_string(c.GetValue());
  }

  void ProcessBinaryPlus(const BinaryPlus& bp) {
    bp.GetLhs().Accept(this);
    bp.GetRhs().Accept(this);
    str_map_[&bp] = str_map_[&(bp.GetLhs())] + " + " + str_map_[&(bp.GetRhs())];
  }
 private:
  std::map<const Expr*, std::string> str_map_;
}

现在我们可以这样去使用:

  Evaluator e_op;
  Constant c0{1.1};
  c0.Accept(&e_op);
  std::cout << "Value of a Constant " << e_op.GetValueForExpr(c0) << std::endl;

  Stringifier s_op;
  Constant c1{1.2};
  BinaryPlus b0{c0, c1};
  b0.Accept(s_op)
  std::cout << "String of a Binary " << s_op.GetStrForExpr(b0) << std::endl;

可以发现,对比原来以类为中心的实现,把原来类里面的方法给拆分出来,就变成了以函数为中心的设计:Evaluator 对应原 Eval 方法,Stringifier 对应原 ToString 方法。

原来的 Expr 中只保存 Expr 的数据信息,而不涉及具体的处理方法,只是定义了一个接口 Accept 来接受处理方法的处理。

在这种写法下,当我们想扩展一个新的方法时,只需要继承 ExprProcessor,然后实现针对每种 Expr 的 Process[Expr] 方法就好。现在扩展方法可以不改动原来的代码,只是增加一个 ExprProcessor 就完成了。

但是当我们增加一个 Expr 类型时,发现需要扩展 ExprProcessor 的基类来扩展对新 Expr 的处理,此时增加类型是困难的。

这里如果我们把上面代码中的 Process 都换成 Visit,就符合 Visitor 模式的典型函数接口了。实际这已经很接近 Visitor 模式了,Visitor 模式的 Visit 理解成 Process 会更简单一点,实际在做的主要是把方法(Method)从类(Class)中拆分出来的工作。

扩展问题总结

上面的扩展类和扩展方法不能两全的问题,Eli 把它叫做表达问题,并且做了一个很好的描述叫表达问题矩阵。下面的 2-D 矩阵中,列代表 Type(类)的维度,行代表 Operation (方法)的维度。打对钩的表示对于一个 Type 实现了其 Operation,那么下面的矩阵就可以很好的描述我们之前实现的对 Constant 和 BinaryPlus 的 Evaluate 和 Stringifier:

在面向对象语言中,以类为中心的设计,一个类对应该矩阵中的一行。扩展类(增加新的行)很容易,但是扩展方法(增加新的列)很难,可以表示为:

在以函数为中心的设计中,一个处理操作对应该矩阵中的一列。扩展方法(增加新的列)很容易,但是扩展类(增加新的行)很难,可以表示为:

如何让扩展类和扩展方法都变得容易

在以函数为中心的设计中,已经把类中的函数拆分出来了,叫做 ExprProcessor。但我们不满意的是 ExprProcessor 基类中耦合了处理各种 Expr 的方法如 ProcessConstant 和 ProcessBinaryPlus,这导致扩展新 Expr 需要在 ExprProcessor 增加新的 Process 方法。

对于 ExprProcessor 耦合了 Expr 类型,可以去掉这个耦合,把 ExprProcessor 从:

class ExprProcessor {
 public:
  virtual void ProcessConstant(const Constant& c) = 0;
  virtual void ProcessBinaryPlus(const BinaryPlus& bp) = 0;
};

改为:

class ExprProcessor {
 public:
  virtual void Process(const Expr* exp) = 0;
};

如此再实现新的 Evaluator,负责 Expr 的求值,它也是一个基类:

class Evaluator : virtual public ExprProcessor {
 public:
  double GetValueForExpr(const Expr& e) {
    const auto& iter = value_map_.find(&e);
    if (iter == value_map_.end()) {
      std::cout << "Can't find value." << std::endl;
      return 0.0;
    }
    return iter->second;
  }
  // Eval 只是对 Accept 调用的简单包装
  void Eval(const Expr& e) {
    return e.Accept(this);
  }

 protected:
  std::map<const Expr*, double> value_map_;
};

现在再实现一个针对 Constant 的 Evaluator:

class EvalConstant: virtual public Evaluator {
 public:
  void Process(const Expr* e) {
    const auto* de = dynamic_cast<const Constant*>(e);
    if (de) {
      value_map_[e] = de->GetValue();
    }
  }
};

可以看到,当前的 Evaluator 中,已经没有和 Constant 或者 BinaryPlus 这些具体 Expr 相关的耦合了。而 Evaluator 和具体 Expr 的耦合已经被拆分出来,比如 Evaluator 和 Constant 的组合,已经被拆分成了 EvalConstant。

之前的设计,在 2-D 的表达矩阵中,以类为中心的设计以行为单位做定义,以方法为中心的设计以列为单位做定义。而现在实现了以点为单位做定义,这样就克服了扩展行、扩展列不可兼得的问题。

ExprProcesser 之前扩展方法就比较方便,现在可以让扩展 Expr 也变得方便,比如现在扩展一个 FunctionCall Expr 类型:

class FunctionCall : public Expr {
 public:
  FunctionCall(const std::string& name, const Expr& argument)
      : name_(name), argument_(argument) {}

  void Accept(ExprProcessor* processor) const {
    processor->Process(this);
  }
  const std::string GetName() const {
    return name_;
  }
  const Expr& GetArg() const {
    return argument_;
  }

 private:
  std::string name_;
  const Expr& argument_;
};

扩展它并不用修改原来的代码,现在再实现针对它的 Evaluator:

class EvalFunctionCall: virtual public Evaluator {
 public:
  void Process(const Expr* e) {
    const auto* de = dynamic_cast<const FunctionCall*>(e);
    if (de) {
      std::cout << de->GetName() << " called." << std::endl;
      de->GetArg().Accept(this);
      value_map_[e] = value_map_[&de->GetArg()];
    }
  }
};

达到了扩展新类型也比较方便的目的。这时我们可以这样使用它:

  std::unique_ptr<Expr> c1(new Constant(1.1));
  EvalConstant ec;
  // 这里是 Visitor 模式中典型的 Double dispatch
  // ec.Eval(*c1) will call c1.Accept(ec)
  // c1.Accept(ec) is the first dispatch, finding the right Constatnt::Accept
  // In Constatnt::Accept, processor->Process(this) is ec->Process(*c1)
  // ec->Process(*c1) is the second dispatch, find the right EvalConstant::Process
  ec.Eval(*c1);
  std::cout << "Eval constant: " << ec.GetValueForExpr(*c1) << std::endl;

如果 Expr 都是像 Constant 这样不会嵌套 Expr,那么现在的实现就是很完美的。大部分情况下的问题,现在的实现就足够了。

但是 BinaryPlus 和 FunctionCall 都比较特别,他们会嵌套 Expr。比如我们定义一个 FunctionCall Expression:

  std::unique_ptr<Expr> c1(new Constant(1.1));
  std::unique_ptr<Expr> c2(new Constant(2.2));
  std::unique_ptr<Expr> p2(new BinaryPlus(*c1, *c2));
  std::unique_ptr<Expr> f(new FunctionCall("test_func", *p2));

f 这个 FunctionCall 做 Evaluate 时,其内部的子 Expr 需要的 Evaluator 和 EvalFunctionCall 不同,这里发现不得不定义一个通用的 GeneralEval 来实现对嵌套 Expr 的支持:

// A general Evaluator to support nested expression evaluation
class GeneralEval: public EvalConstant, EvalBinaryPlus, EvalFunctionCall {
 public:
  void Process(const Expr* e) {
      if (dynamic_cast<const Constant*>(e)) {
        EvalConstant::Process(e);
      }
      if (dynamic_cast<const BinaryPlus*>(e)) {
        EvalBinaryPlus::Process(e);
      }
      if (dynamic_cast<const FunctionCall*>(e)) {
        EvalFunctionCall::Process(e);
      }
  }
};

这是嵌套 Expr 带来的新问题,需要在一个地方注册 Expr 和 ExprProcessor 的组合。这里就不那么完美了,我现在还没想到更好的方法。

现在试用下这个 GeneralEval:

  // Double dispatch with nested expressioin
  GeneralEval e;
  // e.Eval(*f) will call f.Accept(e)
  // f.Accept(ec) is the first dispatch, finding the right Constatnt::Accept
  // In FunctionCall::Accept, processor->Process(this) is e->Process(*f)
  // e->Process(*f) is the second dispatch, find the right EvalFunctionCall::Process
  e.Eval(*f);
  std::cout << "Eval function call: " << e.GetValueForExpr(*f) << std::endl;

输入和返回值类型的扩展问题

上文中 Evaluator 返回数值类型,而 Stringifier 返回字符串类型。这种输出(也可能是输入)的扩展性也是常见的。

上文规避了这个问题,在 Accept 或者 Process【Visitor 模式中经常被写成 Visit】 这些基类方法中没有涉及返回值类型,而是把返回值存在了 Processor【或者叫 Visitor 】中。之后每种 Processor 定义自己特殊的结果获取方法。比如 Evaluator 就通过 GetValueForExpr 获取计算的数值,而 Stringifier 则通过 GetStrForExpr 来获取字符串表达结果。

如果想让输出、输入也具有扩展性,可以采用 CRTP 和 Template 来改进,具体的例子可以参考 OneFlow 中实现的一个以函数为中心的 StreamTypeVisitor 的例子。它支持了输出、输入的扩展性,每类 Visitor 可以有不同的输出、输入类型:

template<typename DerivedT>
struct StreamTypeVisitor {
  template<typename... Args>
  static auto Visit(StreamType stream_type, Args&&... args) {
    switch (stream_type) {
      case StreamType::kInvalid: LOG(FATAL) << "invalid stream type";
      case StreamType::kCompute: return DerivedT::VisitCompute(std::forward<Args>(args)...);
      case StreamType::kHost2Device: return DerivedT::VisitHost2Device(std::forward<Args>(args)...);
    }
    LOG(FATAL) << "invalid stream type";
  }
};

其 Visit【或者叫 Process 】方法的输入为可变参数类型,输出使用 auto 自动推导。且使用了 CRTP 这种静态多态来支持输入、输出的类型变换。

输入为空,输出为 char*情况1

struct GetStreamTypeName : public StreamTypeVisitor<GetStreamTypeName> {
  static const char* VisitCompute() { return "compute"; }
  static const char* VisitHost2Device() { return "h2d"; }
};

输入为Symbol<Device>, 输出为Maybe<vm::StreamPolicy>情况2

struct CreateStreamPolicy final : public StreamTypeVisitor<CreateStreamPolicy> {
  static Maybe<vm::StreamPolicy> VisitCompute(Symbol<Device> device) {
    return std::shared_ptr<vm::StreamPolicy>(new vm::EpStreamPolicy(device));
  }
  static Maybe<vm::StreamPolicy> VisitHost2Device(Symbol<Device> device) {
    std::unique_ptr<vm::Allocator> allocator{};
    if (device->enum_type() == DeviceType::kCPU) {
      allocator = vm::EventRecordedEpStreamPolicy::CreateEpBackendDeviceAllocator(device);
    } else {
      allocator =
          std::make_unique<vm::UnimplementedAllocator>("allocator is not supported on h2d stream.");
    }
    return std::shared_ptr<vm::StreamPolicy>(
        new vm::EventRecordedEpStreamPolicy(device, std::move(allocator)));
  }
};

总结

在类和方法都需要扩展时,可以使用访问者模式(Visitor Pattern)是提高代码可扩展性。

最后这个版本的代码见:https://github.com/strint/strint.github.io/blob/master/codes/221113-visitor_pattern.cpp

可以在这里方便的执行它:https://godbolt.org/

参考资料