新增Pass

本文从三个方面介绍了Lite中的Pass结构:Pass是什么Pass的实现与接口Pass的一般注册流程。最后以Fc_fuse_pass为例介绍了fusion_pass的作用与注册方法。

前述:Pass是什么?

CxxPredictor加载模型后,在执行预测前会先优化模型。模型优化过程是通过Pass实现的。 具体调用关系如下: https://user-images.githubusercontent.com/45189361/69638690-20d21880-1096-11ea-8169-1d2c7e1a1609.png图片

  • CreatePredictor(CxxConfig)函数调用了Predictor->Build(CxxConfig)
    • CxxPredictor的构建过程(Build)分为两步:
      • Predictor->LoadModel() 加载模型文件到program中
      • Predicotr->optimizer_.Run() 对Program中的原始图形结构进行优化
        • 对图结构的优化是通过调用 Pass->Apply(const std::unique_ptr<SSAGraph>& graph)方法实现的。

每一类Pass定义了一种优化过程,包括:原模型中的kernel选取、OP融合、冗余OP去除、子图创建、内存优化、类型推导、类型转换等。

Pass的实现与接口 :Pass基类、PassManager和Pass注册

1、Pass基类:paddle::lite::mir::Pass

  1. class Pass {
  2. public:
  3. // Pass的类型,Pass按照作用的不同可以分为三种
  4. enum class Kind { //种类的作用不太清楚
  5. // 1. 修改模型中的图拓扑结构的Pass
  6. kProgramWise = 0,
  7. // 2. 不修改图结构,修改状态的Pass
  8. kStmtWise,
  9. // 3. 不修改 IR,用于搜集信息和可视化信息的Pass.
  10. kDebug,
  11. };
  12. // 主要实现函数:Apply 函数定义了 Pass 运行时执行的操作
  13. virtual void Apply(const std::unique_ptr<SSAGraph>& graph) = 0;
  14. bool is_program_pass() const { return kind_ == Kind::kProgramWise; }
  15. bool is_stmt_pass() const { return kind_ == Kind::kStmtWise; }
  16. virtual ~Pass() = default;
  17. private:
  18. const Kind kind_; // pass 的种类
  19. std::string name_; // pass 的名称
  20. std::set<TargetType> bound_targets_; // 指定了Pass运行的硬件平台,模型优化过程会根据当前硬件平台是否匹配筛选Pass。
  21. std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_; // 绑定的kernel
  22. };
  23. // Different kinds.
  24. class ProgramPass : public Pass {
  25. public:
  26. ProgramPass() : Pass(Kind::kProgramWise) {}
  27. };
  28. class StmtPass : public Pass {
  29. public:
  30. StmtPass() : Pass(Kind::kStmtWise) {}
  31. };
  32. class DebugPass : public Pass {
  33. public:
  34. DebugPass() : Pass(Kind::kDebug) {}
  35. };

代码位置lite/core/mir/pass.h 主要类成员const Kind kind_ : Pass类型。pass 有三种基本基本类型 :修改图结构的ProgramPass、修改状态量的StmtPass和Debug过程采集信息与控制可视化的DebugPassstd::string name_ :pass 的名称 std::set<TargetType> bound_targets_ : Pass运行的硬件平台,optimizer.Run()优化过程会根据硬件平台选择匹配的Pass。———根据硬件平台自动选择需要的pass std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_ : Pass 绑定的kernel (what’s this used for) 主要接口Pass::Apply(const std::unique_ptr& graph) : Pass优化过程的具体操作,是新注册Pass需要实现的接口。输入为SSAGraph型指针,是对模型结构的拓扑表示。

2、Pass管理 paddle::lite::mir::PassManager

  1. class PassManager {
  2. public:
  3. // 内部静态变量PassManager,用来存储使用的Pass和图优化操作
  4. static PassManager& Global() {
  5. static PassManager x;
  6. return x;
  7. }
  8. // 执行所有的 Pass
  9. void Run(const std::unique_ptr<SSAGraph>& graph) {
  10. for (auto& pass : passes_) {
  11. LOG(INFO) << "Running MIR pass " << pass->name();
  12. pass->Apply(graph);
  13. }
  14. private:
  15. std::list<std::unique_ptr> passes_; //存储所有的 Pass
  16. std::map<std::string, mir::Pass*> pass_map_; //使用map变量存储 PassName::Pass
  17. }

代码位置lite/core/mir/pass_manager.h 主要类成员std::list:unique_ptr> passes_; : List类型,存储了所有已注册Pass。 std::map<std::string, mir::Pass*> pass_map_; : Map类型,存储了所有”Pass名称-Pass类”键对,用于根据名称查找Pass。

主要接口static PassManager& Global() 返回PassManager全局静态变量,该变量存储了所有已注册的Pass bool AddNewPass(const std::string& name, Pass* pass) 添加新的Pass到PassManager中

3、 Pass 注册 paddle::lite::mir::PassRegistry

代码位置lite/core/mir/pass_registry.h 主要接口REGISTER_MIR_PASS(name__, class__) :宏定义函数,用于注册Pass。注册Pass过程实现的是 PassManager::Global().AddNewPass(name__, class__),将新注册Pass添加到全局变量PassManager中。

Pass的一般注册流程与使用方法

1. Pass 注册流程

lite/core/mir或其子目录下继承Pass基类,实现Pass::Apply接口,并使用宏REGISTER_MIR_PASS(name__, class__)将Pass注册到PassManager即完成了新Pass注册。

**以新建 **new_demo_pass为例,具体流程如下: (1)在lite/core/mir路径下新建example_pass.ccnew_demo_pass.h 文件 (2)在example_pass.h 文件中继承Pass基类(ProgramPass、StmtPass或DebugPass)定义自己的Pass类。

  1. #include "lite/core/mir/pass.h"
  2. namespace paddle {
  3. namespace lite {
  4. namespace mir {
  5. class ExamplePass : public ProgramPass {
  6. void Apply(const std::unique_ptr<SSAGraph> &graph) override {}
  7. ...
  8. };
  9. } // namespace mir
  10. } // namespace lite
  11. } // namespace paddle

(3)在example_pass.cc 文件中实现ExamplePass::Apply()接口,并注册ExamplePass

  1. #include "lite/core/mir/pass_registry.h"
  2. #include "lite/core/mir/example_pass.h"
  3. namespace paddle {
  4. namespace lite {
  5. namespace mir {
  6. void ExamplePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  7. ...
  8. }
  9. } // namespace mir
  10. } // namespace lite
  11. } // namespace paddle
  12. REGISTER_MIR_PASS(example_pass, paddle::lite::mir::ExamplePass)
  13. .BindTargets({TARGET(kARM)}); // Pass执行的目标硬件平台
  14. // .BindKernel("conv2d"); //Pass绑定的 kernel

(4)修改lite/core/mir/CMakeLists.txt文件,将example_pass.cc 编译到mir_passes库中

  1. lite_cc_library(mir_passes
  2. SRCS
  3. demo_pass.cc // 新建的Pass文件
  4. ...
  5. memory_optimize_pass.cc
  6. DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})

2. Pass使用流程

将Pass注册到PassManager后不会自动生效。需要在optimizer->run() 函数中添加该Pass才会在模型优化过程中调用。 (1)在paddle_use_passes.h文件中调用该Pass

  1. #include "paddle_lite_factory_helper.h" // NOLINT
  2. ...
  3. USE_MIR_PASS(new_demo_pass); //调用 new_demo_pass

(2)要想在优化模型时调用该Pass,需要在optimizer->run()函数中手动添加调用。

修改lite/core/optimizer.h文件,添加new_demo_passOptimizer::Run()函数;

  1. class Optimizer {
  2. public:
  3. void Run(...) {
  4. ...
  5. if (passes.empty()) {
  6. RunPasses(std::vector<std::string>{
  7. {"new_demo_pass" //将新注册的Pass添加在这里
  8. ...
  9. }
  10. ...
  11. }

(3)只有CxxPredictor才会在模型加载后根据Pass优化模型。

  1. ...
  2. #include "paddle_use_passes.h" // 引用Pass优化模型
  3. void RunModel() {
  4. // 1. 创建 CxxConfig
  5. CxxConfig config;
  6. config.set_model_dir(FLAGS_model_dir);
  7. config.set_valid_places(Place{TARGET(kARM), PRECISION(kFloat)});
  8. // 2. 创建CxxPredictor,该过程包括加载模型和用Pass优化模型
  9. std::shared_ptr> predictor =
  10. Creat<CxxConfig>(config);
  11. }

Fusion Pass的定义与注册

Fusion Pass是一种常见图结构优化Pass,可将多个连续OP融合成单个等效OP,减少数据交换并简化图结构。Pass运行时调用Fuser自动查找并替换指定图结构,所以注册FuserPass时还需要实现对应的Fuser类。

下面以fc_fuse_pass为例,详细说明FusionPass的效果和注册方法。

fc_fuse_pass的作用

将相邻的mul算子和 element_wise add算子 融合成一个 FC 算子

  1. mul(X) = X * W
  2. elementwise_add( mul(x) ) = X * W + Bias
  3. //----------> after fusion
  4. FC(X) = X * W +Bias

Pass 运行效果如下: https://user-images.githubusercontent.com/45189361/69639193-12383100-1097-11ea-9063-21f030414080.png图片 mul和elementwise_add的原有参数映射到FC的参数上: https://user-images.githubusercontent.com/45189361/69638836-74446680-1096-11ea-9cdc-a961fa995dfe.png图片

fc_fuse_pass的注册方法

1、创建FcFuser

(1)在lite/core/mir/fusion路径下新建fc_fuser.ccfc_fuser.h 文件 (2)在fc_fuser.h 文件中继承FuseBase定义自己的Fuser类。

  1. #include "lite/core/mir/pattern_matcher_high_api.h"
  2. namespace paddle {
  3. namespace lite {
  4. namespace mir {
  5. namespace fusion {
  6. class FcFuser : public FuseBase {
  7. public:
  8. void BuildPattern() override;
  9. void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
  10. private:
  11. cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
  12. };
  13. } // namespace fusion
  14. } // namespace mir
  15. } // namespace lite
  16. } // namespace paddle

主要接口FuseBase::BuildPattern : 描述需要替换位置的图结构(pattern),Fuser运行时会自动查找并替换该pattern。 FuseBase::GenOpDesc : 创建融合后的等效Fused_op。 FuseBase::InsertNewNode :用Fused_op替换原始图结构(pattern)。

对于 FcFuser:BuildPattern描述的Pattern是mul+elementwise add,GenOpDesc创建的FC_op,InsertNewNode函数的效果是用新建的FC_op替换模型中的mul+elementwise add pattern。

(3) 在fc_fuser.cc文件中实现 BuildPattern()GenOpDesc()InsertNewNode()接口

下面以FcFuser为例介绍三种接口的实现:

  1. // 1. BuildPattern函数,描述需要替换的图结构
  2. // FcFuser::BuildPattern() 描述了 mul + element_wise add 图结构
  3. void FcFuser::BuildPattern() {
  4. // (1) 用OpNode描述和VarNode
  5. // mul OP
  6. auto* mul = OpNode("mul", "mul");
  7. // mul OP 的输入和输出
  8. auto* x = VarNode("x")->assert_is_op_input("mul", "X");
  9. auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
  10. auto* mul_out = VarNode("mul_out");
  11. // elementwise_add OP
  12. auto* add = OpNode("add", "elementwise_add");
  13. //elementwise_add 的输入
  14. auto* b = VarNode("b")->assert_is_persistable_var();
  15. // elementwise_add OP的输出(最终输出)
  16. auto* Out = VarNode("Out");
  17. //(2) 描述拓扑连接 (Fuse之前mul 和elementwise_add的连接)
  18. std::vector<PMNode*> mul_inputs{W, x};
  19. std::vector<PMNode*> add_inputs{mul_out, b};
  20. mul_inputs >> *mul >> *mul_out;
  21. add_inputs >> *add >> *Out;
  22. //(3) 声明新的拓扑结构中将会被移除的节点,包括被fuse的OP和OP之间的中间变量
  23. mul_out->AsIntermediate();
  24. mul->AsIntermediate();
  25. add->AsIntermediate();
  26. }
  27. // 2. GenOpDesc函数新建等效 Fused_op
  28. // FcFuser::GenOpDesc() 新建了Fc_op
  29. cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
  30. // (1) 得到第一个OP节点的 OpDesc ,并清空输入输出信息
  31. cpp::OpDesc op_desc = *matched.at("mul")->stmt()->op_info();
  32. op_desc.mutable_inputs()->clear();
  33. op_desc.mutable_outputs()->clear();
  34. // (2) 修改OpDesc , 将OpType设置为 "fc" (FC OP 的OP_type),
  35. op_desc.SetType("fc");
  36. // (3) 设置OpDesc中的Input、Output、Attrbute。分别连接到BuildPattern()函数中创建的VarNode
  37. op_desc.SetInput("Input", {matched.at("x")->arg()->name});
  38. op_desc.SetInput("W", {matched.at("W")->arg()->name});
  39. op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
  40. op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
  41. op_desc.SetAttr(
  42. "in_num_col_dims",
  43. matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
  44. return op_desc;
  45. }
  46. // 3. InsertNewNode函数用Fused OP 替换模型图中的原始 Pattern
  47. // FcFuser::InsertNewNode() 用Fc_OP替换原始模型图中的 " mul + element_wise add "
  48. void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
  49. // (1) 创建FC OP的参数(OpDesc)
  50. auto op_desc = GenOpDesc(matched);
  51. // 创建一个 FC OP
  52. auto fc_op = LiteOpRegistry::Global().Create("fc");
  53. // 找到原拓扑结构中的scope (作用域)和 valid_places (可支持设备类型)
  54. auto mul = matched.at("mul")->stmt()->op();
  55. auto* scope = mul->scope();
  56. auto& valid_places = mul->valid_places();
  57. // (2) 将 FC OP的 scope和 valid_places设置与fuse前相同,并在图中创建该节点(node)
  58. fc_op->Attach(op_desc, scope);
  59. auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
  60. // (3) 将FC节点连接到输入输出(var_node)
  61. IR_NODE_LINK_TO(matched.at("W"), new_op_node);
  62. IR_NODE_LINK_TO(matched.at("x"), new_op_node);
  63. IR_NODE_LINK_TO(matched.at("b"), new_op_node);
  64. IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
  65. }

2、注册fc_fuse_pass

(1)在lite/core/mir/fusion路径下新建fc_fuse_pass.ccfc_fuse_pass.h 文件 (2)在fc_fuse_pass.h 文件中,继承ProgramPass定义FcFusePass

  1. #include "lite/core/mir/pass.h"
  2. namespace paddle {
  3. namespace lite {
  4. namespace mir {
  5. class FcFusePass : public ProgramPass {
  6. public:
  7. void Apply(const std::unique_ptr<SSAGraph>& graph) override; namespace mir namespace lite namespace paddle

(3)在fc_fuse_pass.cc 文件中实现FcFusePass::Apply()接口,并注册FcFusePass

  1. #include "lite/core/mir/pass_registry.h"
  2. #include "lite/core/mir/example_pass.h"
  3. namespace paddle {
  4. namespace lite {
  5. namespace mir {
  6. void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  7. fusion::FcFuser fuser;
  8. fuser(graph.get());namespace mir
  9. } // namespace lite
  10. } // namespace paddle
  11. REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
  12. .BindTargets({TARGET(kAny)}) // FcFusePass 可以在任何硬件平台执行
  13. .BindKernel("fc"); // FcFusePass 绑定 fc_kernel

(4)修改lite/core/mir/fusion/CMakeLists.txt文件,将fc_fuser.cc 编译到mir_fusers

  1. lite_cc_library(fuse_fc
  2. SRCS fc_fuser.cc
  3. DEPS pattern_matcher_high_api)
  4. set(mir_fusers
  5. fuse_fc
  6. ...
  7. CACHE INTERNAL "fusers")

(5)修改lite/core/mir/CMakeLists.txt文件,将fc_fuse_pass.cc 编译到mir_pass

  1. lite_cc_library(mir_passes
  2. SRCS
  3. fusion/fc_fuse_pass.cc
  4. ...
  5. DEPS mir_pass types context ${mir_fusers} ${subgraph_passes})

3、使用 fc_fuse_pass

(1) lite/api/paddle_use_passes.h使用USE_LITE_PASS宏来引入新加入的pass

  1. USE_MIR_PASS(lite_fc_fuse_pass);

(2) 在lite/core/optimizer.h文件的Optimizer::Run()函数中添加新注册的pass

  1. class Optimizer {
  2. public:
  3. void Run(Program&& program,
  4. const std::vector<Place>& valid_places,
  5. core::KernelPickFactor kernel_pick_factor,
  6. const std::vector<std::string>& passes = {}) {
  7. ...
  8. if (passes.empty()) {
  9. RunPasses(std::vector<std::string>{
  10. {"lite_fc_fuse_pass", // the newly registered pass
  11. ...
  12. "argument_type_display_pass"}});
  13. } else {
  14. RunPasses(passes);
  15. }
  16. exec_scope_ = program.exec_scope();
  17. }

(3) 以上修改完成后,在CreatePredictor(CxxConfig)创建CxxPredictor时,模型优化过程会调用lite_fc_fuse_pass,扫描mul + element_wise add结构并替换为等效的Fc_OP。