TVM LoadFromFile 函数分析
概述
本文主要分析动态库 (.so) 文件的加载过程,其中包括模块的反序列化和 GraphExecuter 的创建流程。

动态库加载
当使用函数 relay.build 将模型转换为 GraphExecutorFactoryModule 类型的 lib 对象后,就可以在 python 端和 C++ 端调用库文件去执行模型推理。
1 | lib = relay.build(model, target, params) |
直接构建
由于 lib 的类型为 GraphExecutorFactory,且在编译阶段,如果用户不特别指定模型名,默认模型名为 default。
1 | def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"): |
可以基于 lib 对象直接构建运行模块,其中 lib["default"](dev)
操作相当于调用 GetFunction 函数传入参数 dev 并返回模块名为 default 的 runtime Module。
1 | PackedFunc GraphExecutorFactory::GetFunction( |
然后使用 python/tvm/contrib/graph_executor.py
文件中的 GraphModule 类型对返回的 runtime Module 进行封装。
1 | graph_mod = tvm.contrib.GraphModule(lib["default"](dev)) |
上述封装方便使用统一接口进行模型输入设定、运行和获取输出结果等。
1 | graph_mod.set_input("x", data) |
基于动态库文件构建
基于动态库文件的构建方式,首先需要调用 GraphExecutorFactoryModule 类中的 export_library 函数,将 lib 对象导出为 xxx.so
动态库的形式。然后在需要构建时调用 load_module 或 LoadFromFile 函数加载动态库并生成相应的 runtime::Module 对象。
- python 端
在 python/tvm/runtime/module.py
文件中,使用 tvm.runtime.load_module 函数加载动态库。
- C++ 端
在 src/runtime/module.cc
文件中,使用 tvm::runtime::Module::LoadFromFile 函数加载动态库。
load_module
首先判定待加载文件是否存在,然后根据文件后缀类型是否为.o
或.tar
,自动调用 cc.create_shared 编译器进行文件编译并生成.so
库文件并追加到 path 中。支持该操作是为了与 RPC 加载保持一致。最后调用全局函数 ModuleLoadFromFile,通过 TVM_REGISTER_GLOBAL 宏将 Module::LoadFromFile 函数注册为全局函数 runtime.ModuleLoadFromFile。
1 | def load_module(path, fmt=""): |
LoadFromFile
首先进行文件的类型判断,如果类型为 dll、dylib 或 dso,统一调用 runtime.module.loadfile_so 函数进行文件处理。然后根据文件后缀类型,调用不同的加载函数 std::string load_f_name = "runtime.module.loadfile_" + fmt
。
1 | Module Module::LoadFromFile(const std::string& file_name, const std::string& format) { |
runtime.module.loadfile_so
在 src/runtime/dso_library.cc
文件中注册了 runtime.module.loadfile_so 全局函数。首先创建 DSOLibrary 对象,并将待加载的文件名参数传入到 init 函数中,用于初始化 DSOLibrary 对象 n。然后调用 CreateModuleFromLibrary 函数,将创建的 DSOLibrary 对象 n 传入该函数中,该函数位于文件 src/runtime/library_module.cc
中。
1 | TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { |
DSOLibrary
根据系统平台不同,DSOLibrary 对象会调用不同的 Load 函数。
1 | void Init(const std::string& name) { Load(name); } |
- Windows
使用#include <windows.h>
函数库加载动态库。
1 | // Load the library |
- Linux
使用#include <dlfcn.h>
函数库来加载动态库。
1 | // Library handle |
CreateModuleFromLibrary
该函数主要实现从动态库中恢复 runtime::Module 对象的过程,包括反序列化和重新构建模块间导入的关系。
- InitContextFunctions
调用 InitContextFunctions 函数从动态库中获取上下文相关函数的函数句柄。
1 | InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); |
- Load the imported modules
从动态库中获取__tvm_dev_mblob
符号,将符号地址返回到 dev_mblob 对象。
1 | const char* dev_mblob = reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); |
如果 dev_mblob != nullptr
,即动态库中存在__tvm_dev_mblob
符号,说明动态库中存在序列化的 imported modules,则调用 ProcessModuleBlob 函数对其进行处理。否则表明动态库中只有一个简单的 DSO Module,直接将动态库转为 LibraryModuleNode 对象 n,并基于 n 构造 runtime::Module 对象 root_mod。
1 | if (dev_mblob != nullptr) { |
InitContextFunctions
函数 InitContextFunctions 用于初始化上下文函数,将动态库对应的函数的地址映射到相应函数名,即从动态库中获取相关功能函数的函数句柄。
1 | // Initialize the functions |
- ProcessModuleBlob
该函数主要实现对序列化的 module 进行反序列化处理,恢复导入模块的关系。
- 获取字段大小
首先根据 dev_mblob 字段的前 8 个字节,获取 blob 字段的大小。
1 | uint64_t nbytes = 0; |
- 构建数据流
根据 blob 字段的大小和地址,创建 dmlc::MemoryFixedSizeStream fs 对象,然后将 fs 对象转化为数据流。
1 | dmlc::MemoryFixedSizeStream fs(const_cast<char*>(mblob + sizeof(nbytes)), |
- 读取 blob 数量
首先读取 blob 数量,然后根据数量大小,循环处理每个 blob 的对象。
1 | uint64_t size; |
- 处理 blob 对象
在处理 blob 对象时,根据 type_key 类型不同,分别进行不同方式的处理。
1 | std::vector<Module> modules; |
type_key == "_lib"
如果类型键为_lib
,直接将动态库转化为 LibraryModuleNode 对象,然后调用 runtime::Module 构造函数,进一步转化为 dso_module 对象,并且 DSO 的上下文地址也指向 dso_module 对象。
1 | auto dso_module = Module(make_object<LibraryModuleNode>(lib)); |
type_key == "_import_tree"
如果类型键为_import_tree
,则依次读取 import_tree_row_ptr 和 import_tree_child_indices 数组的内容,用于后续构建导入模块的关系。
1 | ICHECK(stream->Read(&import_tree_row_ptr)); |
other type_key
如果类型键非_lib
或_import_tree
类型,则调用 LoadModuleFromBinary 函数从二进制流中根据 type_key 加载相应的 runtime::Module,并将模型推送进 modules 对象中。
1 | auto m = LoadModuleFromBinary(tkey, stream); |
在函数 LoadModuleFromBinary 中,调用注册的全局函数 runtime.module.loadbinary_ + type_key
,实现二进制数据流到相应 runtime::Module 的转换。目前 TVM 框架中已经存在的 type_key
包括如下:
arm_compute_lib | bnns_json | coreml | dnnl_json | ethos-n | tensorrt | cuda | opencl | vulkan |
---|
1 | Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { |
- 重新构建模块关系
如果读取的 import_tree_row_ptr 为空,说明动态库中无导入树,可能使用的是老版本的动态库格式。直接使用动态库 lib 构建 LibraryModuleNode 类型对象 n,然后获取实例对象 n 的 imports_
成员的地址并赋给 module_import_addr 变量。将从二进制数据流中恢复的所有 modules 都追加到 module_import_addr 中。
1 | auto n = make_object<LibraryModuleNode>(lib); |
如果存在 import_tree_row_ptr 对象,则按照 CSR 格式恢复模块间的关系。
1 | for (size_t i = 0; i < modules.size(); ++i) { |
lookup symbol from boot
如果动态库中存在__tvm_module_ctx
符号,则将该根地址指向 DSO 模块地址 dso_ctx_addr。从而允许从根模块查找符号,实现所有符号都可见。
1 | // allow lookup of symbol from root (so all symbols are visible). |