TVM LoadFromFile函数分析
概述
本文主要分析动态库(.so)文件的加载过程,其中包括模块的反序列化和GraphExecuter 的创建流程。
动态库加载
当使用函数relay.build将模型转换为GraphExecutorFactoryModule类型的lib对象后,就可以在python端和C++端调用库文件去执行模型推理。
1 | lib = relay.build(model, target, params) |
直接构建
由于lib的类型为GraphExecutorFactory,且在编译阶段,如果用户不特别指定模型名,默认模型名为default。
1 | def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"): |
可以基于lib对象直接构建运行模块,其中lib["default"](dev)
操作相当于调用GetFunction函数传入参数dev并返回模块名为default的runtime Module。
1 | PackedFunc GraphExecutorFactory::GetFunction( |
然后使用python/tvm/contrib/graph_executor.py
文件中的GraphModule类型对返回的runtime Module进行封装。
1 | graph_mod = tvm.contrib.GraphModule(lib["default"](dev)) |
上述封装方便使用统一接口进行模型输入设定、运行和获取输出结果等。
1 | graph_mod.set_input("x", data) |
基于动态库文件构建
基于动态库文件的构建方式,首先需要调用GraphExecutorFactoryModule类中的export_library函数,将lib对象导出为xxx.so
动态库的形式。然后在需要构建时调用load_module或LoadFromFile函数加载动态库并生成相应的runtime::Module对象。
- python端
在python/tvm/runtime/module.py
文件中,使用tvm.runtime.load_module函数加载动态库。
- C++端
在src/runtime/module.cc
文件中,使用tvm::runtime::Module::LoadFromFile函数加载动态库。
load_module
首先判定待加载文件是否存在,然后根据文件后缀类型是否为.o
或.tar
,自动调用cc.create_shared编译器进行文件编译并生成.so
库文件并追加到path中。支持该操作是为了与RPC加载保持一致。最后调用全局函数ModuleLoadFromFile,通过TVM_REGISTER_GLOBAL宏将Module::LoadFromFile函数注册为全局函数runtime.ModuleLoadFromFile。
1 | def load_module(path, fmt=""): |
LoadFromFile
首先进行文件的类型判断,如果类型为dll、dylib或dso,统一调用runtime.module.loadfile_so函数进行文件处理。然后根据文件后缀类型,调用不同的加载函数std::string load_f_name = "runtime.module.loadfile_" + fmt
。
1 | Module Module::LoadFromFile(const std::string& file_name, const std::string& format) { |
runtime.module.loadfile_so
在src/runtime/dso_library.cc
文件中注册了runtime.module.loadfile_so全局函数。首先创建DSOLibrary对象,并将待加载的文件名参数传入到init函数中,用于初始化DSOLibrary对象n。然后调用CreateModuleFromLibrary函数,将创建的DSOLibrary对象n传入该函数中,该函数位于文件src/runtime/library_module.cc
中。
1 | TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { |
DSOLibrary
根据系统平台不同,DSOLibrary对象会调用不同的Load函数。
1 | void Init(const std::string& name) { Load(name); } |
- Windows
使用#include <windows.h>
函数库加载动态库。
1 | // Load the library |
- Linux
使用#include <dlfcn.h>
函数库来加载动态库。
1 | // Library handle |
CreateModuleFromLibrary
该函数主要实现从动态库中恢复runtime::Module对象的过程,包括反序列化和重新构建模块间导入的关系。
- InitContextFunctions
调用InitContextFunctions函数从动态库中获取上下文相关函数的函数句柄。
1 | InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); |
- Load the imported modules
从动态库中获取__tvm_dev_mblob
符号,将符号地址返回到dev_mblob对象。
1 | const char* dev_mblob = reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); |
如果dev_mblob != nullptr
,即动态库中存在__tvm_dev_mblob
符号,说明动态库中存在序列化的imported modules,则调用ProcessModuleBlob函数对其进行处理。否则表明动态库中只有一个简单的DSO Module,直接将动态库转为LibraryModuleNode对象n,并基于n构造runtime::Module对象root_mod。
1 | if (dev_mblob != nullptr) { |
InitContextFunctions
函数InitContextFunctions用于初始化上下文函数,将动态库对应的函数的地址映射到相应函数名,即从动态库中获取相关功能函数的函数句柄。
1 | // Initialize the functions |
- ProcessModuleBlob
该函数主要实现对序列化的module进行反序列化处理,恢复导入模块的关系。
- 获取字段大小
首先根据dev_mblob字段的前8个字节,获取blob字段的大小。
1 | uint64_t nbytes = 0; |
- 构建数据流
根据blob字段的大小和地址,创建dmlc::MemoryFixedSizeStream fs对象,然后将fs对象转化为数据流。
1 | dmlc::MemoryFixedSizeStream fs(const_cast<char*>(mblob + sizeof(nbytes)), |
- 读取blob数量
首先读取blob数量,然后根据数量大小,循环处理每个blob的对象。
1 | uint64_t size; |
- 处理blob对象
在处理blob对象时,根据type_key类型不同,分别进行不同方式的处理。
1 | std::vector<Module> modules; |
type_key == "_lib"
如果类型键为_lib
,直接将动态库转化为LibraryModuleNode对象,然后调用runtime::Module构造函数,进一步转化为dso_module对象,并且DSO的上下文地址也指向dso_module对象。
1 | auto dso_module = Module(make_object<LibraryModuleNode>(lib)); |
type_key == "_import_tree"
如果类型键为_import_tree
,则依次读取import_tree_row_ptr和import_tree_child_indices数组的内容,用于后续构建导入模块的关系。
1 | ICHECK(stream->Read(&import_tree_row_ptr)); |
other type_key
如果类型键非_lib
或_import_tree
类型,则调用LoadModuleFromBinary函数从二进制流中根据type_key加载相应的runtime::Module,并将模型推送进modules对象中。
1 | auto m = LoadModuleFromBinary(tkey, stream); |
在函数LoadModuleFromBinary中,调用注册的全局函数runtime.module.loadbinary_ + type_key
,实现二进制数据流到相应runtime::Module的转换。目前TVM框架中已经存在的type_key
包括如下:
arm_compute_lib | bnns_json | coreml | dnnl_json | ethos-n | tensorrt | cuda | opencl | vulkan |
---|
1 | Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { |
- 重新构建模块关系
如果读取的import_tree_row_ptr为空,说明动态库中无导入树,可能使用的是老版本的动态库格式。直接使用动态库lib构建LibraryModuleNode类型对象n,然后获取实例对象n的imports_
成员的地址并赋给module_import_addr变量。将从二进制数据流中恢复的所有modules都追加到module_import_addr中。
1 | auto n = make_object<LibraryModuleNode>(lib); |
如果存在import_tree_row_ptr对象,则按照CSR格式恢复模块间的关系。
1 | for (size_t i = 0; i < modules.size(); ++i) { |
lookup symbol from boot
如果动态库中存在__tvm_module_ctx
符号,则将该根地址指向DSO模块地址dso_ctx_addr。从而允许从根模块查找符号,实现所有符号都可见。
1 | // allow lookup of symbol from root (so all symbols are visible). |