TVM LoadFromFile函数分析

概述

本文主要分析动态库(.so)文件的加载过程,其中包括模块的反序列化GraphExecuter 的创建流程。

东台条子泥

动态库加载

当使用函数relay.build将模型转换为GraphExecutorFactoryModule类型的lib对象后,就可以在python端和C++端调用库文件去执行模型推理。

1
lib = relay.build(model, target, params)

直接构建

由于lib的类型为GraphExecutorFactory,且在编译阶段,如果用户不特别指定模型名,默认模型名为default

1
def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"):

可以基于lib对象直接构建运行模块,其中lib["default"](dev)操作相当于调用GetFunction函数传入参数dev并返回模块名为defaultruntime Module

1
2
3
4
5
6
7
8
9
10
11
12
13
PackedFunc GraphExecutorFactory::GetFunction(
const std::string& name, const tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
if (name == module_name_) {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::vector<Device> devices;
for (int i = 0; i < args.num_args; ++i) {
devices.emplace_back(args[i].operator Device());
}
*rv = this->ExecutorCreate(devices);
});
}
...
}

然后使用python/tvm/contrib/graph_executor.py文件中的GraphModule类型对返回的runtime Module进行封装。

1
graph_mod = tvm.contrib.GraphModule(lib["default"](dev))

上述封装方便使用统一接口进行模型输入设定、运行和获取输出结果等。

1
2
3
graph_mod.set_input("x", data)
graph_mod.run()
graph_mod.get_output(0, out_data)

基于动态库文件构建

基于动态库文件的构建方式,首先需要调用GraphExecutorFactoryModule类中的export_library函数,将lib对象导出为xxx.so动态库的形式。然后在需要构建时调用load_moduleLoadFromFile函数加载动态库并生成相应的runtime::Module对象。

  • python

python/tvm/runtime/module.py文件中,使用tvm.runtime.load_module函数加载动态库。

  • C++

src/runtime/module.cc文件中,使用tvm::runtime::Module::LoadFromFile函数加载动态库。

load_module

首先判定待加载文件是否存在,然后根据文件后缀类型是否为.o.tar,自动调用cc.create_shared编译器进行文件编译并生成.so库文件并追加到path中。支持该操作是为了与RPC加载保持一致。最后调用全局函数ModuleLoadFromFile,通过TVM_REGISTER_GLOBAL宏将Module::LoadFromFile函数注册为全局函数runtime.ModuleLoadFromFile

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def load_module(path, fmt=""):
if os.path.isfile(path):
path = os.path.realpath(path)
else:
raise ValueError("cannot find file %s" % path)

# c++ compiler/linker
cc = os.environ.get("CXX", "g++")

# High level handling for .o and .tar file.
# We support this to be consistent with RPC module load.
if path.endswith(".o"):
# Extra dependencies during runtime.
from tvm.contrib import cc as _cc

_cc.create_shared(path + ".so", path, cc=cc)
path += ".so"
elif path.endswith(".tar"):
# Extra dependencies during runtime.
from tvm.contrib import cc as _cc, utils as _utils, tar as _tar

tar_temp = _utils.tempdir(custom_path=path.replace(".tar", ""))
_tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
_cc.create_shared(path + ".so", files, cc=cc)
path += ".so"
# Redirect to the load API
return _ffi_api.ModuleLoadFromFile(path, fmt)

LoadFromFile

首先进行文件的类型判断,如果类型为dlldylibdso,统一调用runtime.module.loadfile_so函数进行文件处理。然后根据文件后缀类型,调用不同的加载函数std::string load_f_name = "runtime.module.loadfile_" + fmt

1
2
3
4
5
6
7
8
9
10
11
Module Module::LoadFromFile(const std::string& file_name, const std::string& format) {
std::string fmt = GetFileFormat(file_name, format);
ICHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name;
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so";
}
std::string load_f_name = "runtime.module.loadfile_" + fmt;
const PackedFunc* f = Registry::Get(load_f_name);
Module m = (*f)(file_name, format);
return m;
}
runtime.module.loadfile_so

src/runtime/dso_library.cc文件中注册了runtime.module.loadfile_so全局函数。首先创建DSOLibrary对象,并将待加载的文件名参数传入到init函数中,用于初始化DSOLibrary对象n。然后调用CreateModuleFromLibrary函数,将创建的DSOLibrary对象n传入该函数中,该函数位于文件src/runtime/library_module.cc中。

1
2
3
4
5
TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<DSOLibrary>();
n->Init(args[0]);
*rv = CreateModuleFromLibrary(n);
});
DSOLibrary

根据系统平台不同,DSOLibrary对象会调用不同的Load函数。

1
void Init(const std::string& name) { Load(name); }
  • Windows

使用#include <windows.h>函数库加载动态库。

1
2
3
4
5
6
7
// Load the library
void Load(const std::string& name) {
// use wstring version that is needed by LLVM.
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name;
}
  • Linux

使用#include <dlfcn.h>函数库来加载动态库。

1
2
3
4
5
6
7
8
9
10
11
12
13
// Library handle

void* lib_handle_{nullptr};

// load the library

void Load(const std::string& name) {

lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);

ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " << dlerror();

}
CreateModuleFromLibrary

该函数主要实现从动态库中恢复runtime::Module对象的过程,包括反序列化和重新构建模块间导入的关系。

  • InitContextFunctions

调用InitContextFunctions函数从动态库中获取上下文相关函数的函数句柄。

1
InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); });
  • Load the imported modules

从动态库中获取__tvm_dev_mblob符号,将符号地址返回到dev_mblob对象。

1
const char* dev_mblob = reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_dev_mblob));

如果dev_mblob != nullptr,即动态库中存在__tvm_dev_mblob符号,说明动态库中存在序列化的imported modules,则调用ProcessModuleBlob函数对其进行处理。否则表明动态库中只有一个简单的DSO Module,直接将动态库转为LibraryModuleNode对象n,并基于n构造runtime::Module对象root_mod

1
2
3
4
5
6
7
if (dev_mblob != nullptr) {
ProcessModuleBlob(dev_mblob, lib, &root_mod, &dso_ctx_addr);
} else {
// Only have one single DSO Module
root_mod = Module(n);
dso_ctx_addr = root_mod.operator->();
}
InitContextFunctions

函数InitContextFunctions用于初始化上下文函数,将动态库对应的函数的地址映射到相应函数名,即从动态库中获取相关功能函数的函数句柄。

1
2
3
4
5
6
7
8
// Initialize the functions
TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv);
TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
  • ProcessModuleBlob

该函数主要实现对序列化的module进行反序列化处理,恢复导入模块的关系。

  1. 获取字段大小

首先根据dev_mblob字段的前8个字节,获取blob字段的大小。

1
2
3
4
5
uint64_t nbytes = 0;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
uint64_t c = mblob[i];
nbytes |= (c & 0xffUL) << (i * 8);
}
  1. 构建数据流

根据blob字段的大小和地址,创建dmlc::MemoryFixedSizeStream fs对象,然后将fs对象转化为数据流。

1
2
3
dmlc::MemoryFixedSizeStream fs(const_cast<char*>(mblob + sizeof(nbytes)),
static_cast<size_t>(nbytes));
dmlc::Stream* stream = &fs;
  1. 读取blob数量

首先读取blob数量,然后根据数量大小,循环处理每个blob的对象。

1
2
uint64_t size;
ICHECK(stream->Read(&size));
  1. 处理blob对象

在处理blob对象时,根据type_key类型不同,分别进行不同方式的处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
std::vector<Module> modules;
std::vector<uint64_t> import_tree_row_ptr;
std::vector<uint64_t> import_tree_child_indices;
int num_dso_module = 0;

for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
ICHECK(stream->Read(&tkey));
if (tkey == "_lib") {
// construct dso module using lib
} else if (tkey == "_import_tree") {
// READ(_import_tree_row_ptr)
// READ(_import_tree_child_indices)
} else {
// call module.loadbinary_blob_type_key, such as module.loadbinary_cuda to restore.
}
}
  • type_key == "_lib"

如果类型键为_lib,直接将动态库转化为LibraryModuleNode对象,然后调用runtime::Module构造函数,进一步转化为dso_module对象,并且DSO的上下文地址也指向dso_module对象。

1
2
3
4
auto dso_module = Module(make_object<LibraryModuleNode>(lib));
*dso_ctx_addr = dso_module.operator->();
++num_dso_module;
modules.emplace_back(dso_module);
  • type_key == "_import_tree"

如果类型键为_import_tree,则依次读取import_tree_row_ptrimport_tree_child_indices数组的内容,用于后续构建导入模块的关系。

1
2
ICHECK(stream->Read(&import_tree_row_ptr));
ICHECK(stream->Read(&import_tree_child_indices));
  • other type_key

如果类型键非_lib_import_tree类型,则调用LoadModuleFromBinary函数从二进制流中根据type_key加载相应的runtime::Module,并将模型推送进modules对象中。

1
2
auto m = LoadModuleFromBinary(tkey, stream);
modules.emplace_back(m);

在函数LoadModuleFromBinary中,调用注册的全局函数runtime.module.loadbinary_ + type_key,实现二进制数据流到相应runtime::Module的转换。目前TVM框架中已经存在的type_key包括如下:

arm_compute_lib bnns_json coreml dnnl_json ethos-n tensorrt cuda opencl vulkan
1
2
3
4
5
6
Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
std::string loadkey = "runtime.module.loadbinary_";
std::string fkey = loadkey + type_key;
const PackedFunc* f = Registry::Get(fkey);
return (*f)(static_cast<void*>(stream));
}
  1. 重新构建模块关系

如果读取的import_tree_row_ptr为空,说明动态库中无导入树,可能使用的是老版本的动态库格式。直接使用动态库lib构建LibraryModuleNode类型对象n,然后获取实例对象nimports_成员的地址并赋给module_import_addr变量。将从二进制数据流中恢复的所有modules都追加到module_import_addr中。

1
2
3
4
5
6
7
auto n = make_object<LibraryModuleNode>(lib);
auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->());
for (const auto& m : modules) {
module_import_addr->emplace_back(m);
}
*dso_ctx_addr = n.get();
*root_module = Module(n);

如果存在import_tree_row_ptr对象,则按照CSR格式恢复模块间的关系。

1
2
3
4
5
6
7
8
9
10
11
12
13
for (size_t i = 0; i < modules.size(); ++i) {
for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) {
auto module_import_addr = ModuleInternal::GetImportsAddr(modules[i].operator->());
auto child_index = import_tree_child_indices[j];
ICHECK(child_index < modules.size());
module_import_addr->emplace_back(modules[child_index]);
}
}

ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present";
// invariance: root module is always at location 0.
// The module order is collected via DFS
*root_module = modules[0];
lookup symbol from boot

如果动态库中存在__tvm_module_ctx符号,则将该根地址指向DSO模块地址dso_ctx_addr。从而允许从根模块查找符号,实现所有符号都可见。

1
2
3
4
// allow lookup of symbol from root (so all symbols are visible).
if (auto* ctx_addr = reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = dso_ctx_addr;
}