KV Cache管理-llama.cpp

概述

llama.cpp中的KV Cache管理主要涉及Buffer的分配和KV Cache的管理。其中,Buffer采用静态分配的方式,初始时分配上下文大小Buffer。当序列长度超出设定的上下文时,一种方式通过删除位置相对较远的KV Cache来释放一部分Cache空间,只保留相对较新的KV Cache;另一种是在模型上下文固定的情况下,通过压缩位置嵌入位置信息,实现更长上下文的支持。

黄海森林公园

KV Cache Init

llama_kv_cache_init函数实现了KV CacheBuffer的分配和初始化,并根据模型架构不同,使用不同的数据类型保存缓存数据,其中Mamba结构的模型使用 FP32 类型保存,其它架构默认使用 FP16 类型保存。

llama_kv_cache对象初始化

如下代码所示,cells 容器的大小初始化为 上下文的大小k_lv_l 容器大小初始化为模型的层数,用于存储每层KV向量的数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
cache.has_shift = false;

cache.recurrent = llama_model_is_recurrent(&model);
cache.v_trans = !cache.recurrent && !cparams.flash_attn;

cache.head = 0;
cache.size = kv_size;
cache.used = 0;

cache.type_k = type_k;
cache.type_v = type_v;

cache.cells.clear();
cache.cells.resize(kv_size);

cache.k_l.reserve(n_layer);
cache.v_l.reserve(n_layer);

相应的数据结构如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// ring-buffer of cached KV data
struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed

// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)

// computed before each graph build
uint32_t n = 0;

ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;

std::vector<llama_kv_cell> cells;

std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;

std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;

size_t total_size() {
size_t size = 0;
for (auto & buf : bufs) {
size += ggml_backend_buffer_get_size(buf.get());
}
return size;
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
int32_t src = -1; // used by recurrent state models to copy states
int32_t tail = -1;

std::set<llama_seq_id> seq_id;

bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}

bool is_empty() const {
return seq_id.empty();
}

bool is_same_seq(const llama_kv_cell & other) const {
return seq_id == other.seq_id;
}
};

上述数据结构图形化表示如下:

llama_kv_cache结构示意图

k_lv_l 容器分配

下述代码的核心功能是为每一层 Transformer 模型创建 KeyValue 张量,并将它们存储到 llama_kv_cache 结构体中。它根据模型超参数(hparams.n_embd_k_gqa(i)hparams.n_embd_k_s())和是否启用 GPU OffLoad 来确定张量的维度和存储位置。这段代码确保了 KV 缓存为每一层模型都分配了足够的内存空间来存储 KeyValue 数据,为后续的推理计算做好了准备。ggml_format_name 函数用于设置张量的名称,方便调试和可视化。 其中,n_embd_k_gqa 变量的大小为 hparams.n_embd_k_gqa(i)hparams.n_embd_k_s() 两个参数之和,其中 hparams.n_embd_k_s() 只有 MambaRWKV 框架才有效。所以 Transformer 框架只有 hparams.n_embd_k_gqa(i) 参数生效,其表达式为 \(n\_head\_embd \times n\_head\_kv\)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

ggml_backend_buffer_type_t buft;
if (offload) {
auto * dev = model.dev_layer.at(i).dev;
buft = ggml_backend_dev_buffer_type(dev);
} else {
buft = ggml_backend_cpu_buffer_type();
}
ggml_context * ctx = ctx_for_buft(buft);

if (!ctx) {
LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
return false;
}

ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
cache.v_l.push_back(v);
}

Allocate Buffer

下述代码的主要功能是为每种类型的后端缓冲区分配内存,并进行初始化。它遍历之前创建的上下文映射 ctx_map ,为每个上下文分配一个后端缓冲区,将缓冲区清零以避免 NaNs,然后将缓冲区句柄存储到 cache.bufs 向量中。这段代码确保了 KV 缓存拥有足够的内存空间,并且内存被正确地初始化,为后续的推理计算做好了准备。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// allocate tensors and initialize the buffers to avoid NaNs in the padding
for (auto it : ctx_map) {
auto * buft = it.first;
auto * ctx = it.second;

ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (!buf) {
LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
return false;
}
ggml_backend_buffer_clear(buf, 0);
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
cache.bufs.emplace_back(buf);
}

KV Cache Update

如下代码是Decode阶段的 KV cache的更新流程如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
kv_self.head = 0;
}
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
return 1;
}
kv_slot_restorer.save(slot);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = llama_kv_cache_get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}

llama_kv_cache_update

llama_kv_cache_update 负责维护 KV 缓存的有效性和效率,它通过应用K-shift碎片整理预留最坏情况下的计算图内存来实现这一目标。

K-shift

为什么只有K-shift 没有 V-shift?

GG say:

在注意力机制中,token 的位置通过 RoPE(即隐藏状态的旋转)进行编码。由于 RoPE 编码是可加的,所以我们可以通过使用新旧位置的增量(delta)来应用 RoPE “移动” 缓存的键(Key)。我们不将其应用于值 (Value),因为它们没有显式进行 RoPE。此操作在数学上不等同于从头开始重新计算新上下文,但它的速度要快得多,并且似乎出于某种原因产生了合理的结果。

RoPE编码的可加性

如下代码验证了RoPE编码的可加性,即张量 r1 和 r2 相等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
const int n_past_0 = 100;
const int n_past_2 = 33;

for (int i = 0; i < ne[2]; ++i) {
((int32_t *) p0->data)[i] = n_past_0 + i;
((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
((int32_t *) p2->data)[i] = n_past_2 + i;
}

x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);

// 100, 101, 102, ..., 172
struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
// -67, -67, -67, ..., -67
struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens

// 33, 34, 35, ..., 105
struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode);
Value 不执行 RoPE

如下图所示,常见LLM中只有 QK 执行 RopE:

llama3网络图
K-Shift 算法

如下代码是 K-Shift 算法的实现,通过构建一个graph实现对 Cache 中的 Key 向量的 RoPE 修正。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
ggml_backend_sched_reset(lctx.sched.get());

ggml_cgraph * gf = llama_build_graph_k_shift(lctx);

ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);

llama_set_k_shift(lctx);

llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);

need_reserve = true;
}

如下图所示,是 K-Shift 算法流程的示意图:

K-Shift示意图

通过对 K 向量添加位置偏移量,实现其位置的修正。

Defragment

关于KV Cache的碎片整理代码如下,其主要由 llama_kv_cache_defrag_internal 函数实现,并使能 need_reserve

1
2
3
4
5
6
7
8
// defragment the KV cache if needed
if (lctx.kv_self.do_defrag) {
llama_kv_cache_defrag_internal(lctx);

need_reserve = true;

lctx.kv_self.do_defrag = false;
}

llama_kv_cache_defrag_internal 函数主要分为两个功能模块:

更新 ids 映射表

通过更新 ids 映射表, 确定哪些 KV Cache 需要移动到哪里;

  1. 找空洞: 从 KV 缓存的起始位置开始,寻找连续的空闲单元,使用 \(\text{hole_size}\) 变量表示找到的空洞数量。

  2. 找数据: 从 KV 缓存的末尾开始,寻找连续的有效单元,其数量与找到的空洞数量一致。

  3. 移动数据: 将末尾的有效单元移动到空洞的位置,并清空原来的单元。

  4. 循环执行: 重复以上步骤,直到扫描完整个缓存或达到最大移动次数限制。

构建dedragment图

通过构建 dedragment 图实现KV Cache中的张量数据搬移,这里会用到 cpy 算子, 如下图所示,描述了内存碎片整理的过程,其中,蓝色表示找空洞的过程,红色表示找数据的过程:

dedragment.gif

llama_kv_cache_find_slot

llama_kv_cache_find_slot 函数是 KV 缓存管理的核心部分,它负责在缓存中找到合适的空闲 slot,用于存储新的 KV 值。函数根据模型类型的不同采用不同的策略,循环模型注重序列状态的连续性,而非循环模型则更关注找到连续的空闲单元格

kv_slot_restorer.save(slot)

kv_slot_restorer.save(slot) 函数用于保存当前周期下内存插槽的边界位置信息,当出现异常推理时,便于恢复对缓存的占用。

kv_self.n

1
2
3
4
5
6
7
8
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = llama_kv_cache_get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}

动态调整 kv_self.n 的值来限制参与注意力计算的缓存大小,最小尺寸为pad大小

1
2
3
4
static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
// the FA kernels require padding to avoid extra runtime boundary checks
return cparams.flash_attn ? 256u : 32u;
}

KV Cache Update(Computing)

Prefill and Decode

如下图所示,左图是Prefill阶段的更新的过程,一次性写入所有序列的KV值,右图是Decode阶段的更新过程,每次读取所有的历史KV值,并写入当前序列的KV值:

Prefill
Decode

KV Store

llama.cpp中通过调用 llm_build_kv_store 函数,实现对 KV cache 的保存。

1
llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
  • K向量
1
2
3
4
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head);

// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
  • V向量
1
2
3
4
5
6
7
8
9
10
11
if (cparams.flash_attn) {
v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head);
} else {
// note: the V cache is transposed when not using flash attention
v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv.v_l[il]),
(kv_head)*ggml_element_size(kv.v_l[il]));

v_cur = ggml_transpose(ctx, v_cur);
}
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));

The size of cache

对于每个token,需要为每个注意力头和每一层存储两个向量,且向量中的每个元素假设按照 fp16 格式存储,则每个token在内存中按照字节存储的缓存大小为:

\[2\times2\times head\_dim \times n\_kv\_heads \times n\_layers\]

其中, \(head\_dim\) 表示每个头的 keyvalue 的向量大小, \(n\_kv\_heads\) 表示KV注意力头的数量, \(n\_layers\)表示模型的层数。

Model Cache size/token
LLama3.2-1B 32KB
LLama3.2-3B 84KB
LLama3.2-11B-Vision 5MB

为了容纳单一推理任务的完整上下文大小,我们必须相应地分配足够的缓存空间。更多的时候,可能会以小批量进行,那么cache大小还需要进一步扩大为:

\[2\times2\times head\_dim \times n\_kv\_heads \times n\_layers \times max\_context\_lenght \times batch\_size\]

如果想利用LLama3.2-11B-Vision128K个token的完整上下文,且以4个批次推理,则需要缓存的大小将近 2.5TB的大小,这远大于存储模型参数所需要的 22GB大小。

因此,KV缓存的大小限制了两件事:

  1. 系统能够支持的最大上下文大小;
  2. 系统每个推理时批次的最大大小;

How to reduce cache size?

Grouped Query Attention

分组查询注意力 (Grouped-Query-Attention GQA)是multi-head attention的一种变体,减少了KV Cache的大小,大体思想如下图所示:

分组查询方法概述

如下左图是MHA的计算过程,右图是GQA的计算过程,可见GQA可以有效减少KV缓存的大小:

MHA
GQA

Sliding Window Attention

滑动窗口注意力 (SWA)Mistral-7B 用于支持更长上下文长度而无需增加 KV 缓存大小的技术。 SWA是对原始自注意力机制的改进,利用Transformer的堆叠层来关注窗口大小 \(W\) 以外的信息。第 \(k\) 层位置 \(i\) 的隐藏状态 \(h_i\) ,关注于前一层位置在 \(i-W\)\(W\) 之间的全部隐藏状态。递归地,\(h_i\) 可以访问从输入层到距离最远为 \(W\times k\) 个token中的所有token,如下图所示:

SWA

按照下图中的参数,在模型最后一层,使用窗口大小 \(W = 4096\) ,我们具有大约 131K 个 token 的理论注意力跨度。

SWA-params
滚动缓冲区缓存

固定的注意力跨度意味着我们可以使用滚动缓冲区缓存来限制我们的缓存大小。缓存大小固定为 W,时间步 i 的键值对存储在缓存的 \(i \mod W\) 位置。因此,当位置 \(i > W\) 时,缓存中的旧值将被覆盖,缓存的大小停止增长。

滚动缓冲区缓存。缓存大小固定为W = 4
缓存大小固定为W = 6
Prefill and Block

生成序列时,我们需要逐个预测标记,因为每个标记都以之前的标记为条件。然而,提示是预先已知的,我们可以用提示预填充 (k, v) 缓存。如果提示词非常大,我们可以将其分割成更小的片段,并用每个片段预填充缓存。为此,我们可以选择窗口大小作为我们的片段大小。因此,对于每个片段,我们需要计算缓存和片段上的注意力。如下图所示,展示了注意力掩码如何在缓存和块上运作。

Prefill and Block

PagedAttention

PagedAttention 是一种受操作系统中 虚拟内存分页 经典思想启发的注意力算法,与传统的注意力算法不同,PagedAttention允许将连续的键值对存储在非连续的内存空间中。具体来说,PagedAttention 将每个序列的KV缓存划分为块,每个块包含固定数量标记的键值对。在注意力计算过程中,PagedAttention内核高效地识别并提取这些块。

PagedAttention:KV Cache 被划分为块,块在内存空间中不需要连续

由于块不需要在内存中连续,我们可以像操作系统虚拟内存那样更灵活地管理键值对:可以将块视为页面,标记视为字节,序列视为进程。序列中连续的逻辑块通过块表映射到非连续的物理块,新的token生成时,物理块按需分配。

使用PagedAttention的请求示例生成过程

PagedAttention中,内存浪费仅发生在序列的最后一个块。实际上,这导致了接近最佳的内存使用率,浪费仅低于4%。这种内存效率的提升证明非常有益:它允许系统将更多序列一起批处理,提高GPU利用率,从而显著提高吞吐量。

内存共享

PagedAttention的另一个关键优势在于:高效的内存共享。例如,在并行采样中,多个输出序列是从相同的提示生成的。在这种情况下,提示的计算和内存可以在输出序列之间共享。

并行采样示例

PagedAttention 通过其块表自然地实现了内存共享。类似于进程共享物理页面,不同的序列在PagedAttention中可以通过将其逻辑块映射到相同的物理块来共享这些块。为了确保安全共享,PagedAttention 跟踪物理块的引用计数并实现了写时复制机制。

多输出采样请求的示例生成过程

PageAttention的内存共享极大地降低了复杂采样算法(例如并行采样和束搜索)的内存开销,将其内存使用量减少了高达55%。这可以转化为高达2.2倍的吞吐量提升。这使得此类采样方法在LLM服务中变得实用。

循环缓冲区更新

1
2
3
4
5
6
7
8
9
// update the kv ring buffer
{
kv_self.head += n_tokens;

// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}

这段代码实现了 KV 缓存的环形缓冲区更新逻辑。 它将头部指针 kv_self.head 向前移动 n_tokens 个位置,并在超出缓存边界时将其重置为 0。 这使得 KV 缓存可以循环利用存储空间,从而支持处理更长的序列。

缓存碎片管理

llama_kv_cache_defrag_internal 函数通过在 KV 缓存中移动数据来整理碎片,提高缓存利用率。它使用 ggml_graph 来高效地执行数据移动操作。

1
2
3
4
5
6
7
8
9
10
if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;

// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);

llama_kv_cache_defrag(kv_self);
}
}

KV Cache管理-无限文本生成

众所周知,大型语言模型(LLM)难以很好地泛化到长度超过训练序列长度的长文本上下文,这在推理过程中使用LLM处理长输入序列时带来了挑战。在这项工作中,我们认为LLM本身具有无需微调即可处理长文本上下文的能力。为实现这一目标,我们提出SelfExtend方法,通过构建双层注意力信息:组注意力邻域注意力来扩展LLM的上下文窗口。

SelfExtend

该方法针对基于旋转位置编码(RoPE)的LLM,其内积仅以相对形式的编码位置信息:

\[\langle f_q(x_m,m),f_k(x_n,n) \rangle =g(x_m,x_n,m-n).\]

大型语言模型为何在推理过程中无法处理超过其预训练上下文窗口长度的序列,根据前人的经验,在未见过的相对位置上,其注意力分布与预训练上下文窗口内的注意力分布不同。故这种失败源于分布外 (OOD) 的相对距离,即神经网络对分布外输入不鲁棒。

Grouped Attention

Grouped Attention(GA)与原始的自注意力机制相同,只是在进行内积之前,对每个token的原始位置应用了FLOOR(向下取整)操作。 我们可以使用FLOOR操作将未见位置映射到预训练上下文窗口内的位置。 \[P_g=P // G_s\]

假设LLM预训练的上下文窗口长度为5,而推理序列的长度为8。下图展示了当输入长度超出预训练上下文窗口大小时出现的位置超出分布外的问题。该矩阵的纵轴代表查询词元的位置,横轴代表键词元的位置。在这种情况下,在相对位置矩阵中,只有橙色的部分在预训练期间可见。灰色区域中的相对位置位于预训练上下文窗口之外。

分组注意力图示

在上述右图中,展示了FLOOR(向下取整)操作的应用方式以及分组自注意力机制的相对位置矩阵。当 \(G_s=2\) 时,查询token和键token的位置由FLOOR (//)从0-7映射到0-3。新的相对位置(蓝色显示)均在预训练上下文窗口的范围内。

分组注意力对模型影响

大型语言模型在没有精确位置信息时似乎能有效工作,但并非十全十美。 如下图所示,虚线表示未进行 FLOOR 运算的原始模型的PPL,实线表示利用 FLOOR 操作的模型PPL。可见,采用分组注意力的大型语言模型在长度超出预训练上下文窗口的序列上能保持相对较低且稳定的困惑度。同时,采用分组注意力机制的模型 PPL 比原始 LLM 略高。

分组注意力使用不同的组大小
  • 如何恢复由分组注意力引起的退化的语言建模能力?

核心思路是在邻近区域重新引入普通注意力,根据一系列的研究表明,与目标token邻近的token,在生成下一个token时,起着至关重要的作用。组注意力可能不显著影响句子整体质量,但需精确定位注意力。保留标准注意力机制,可以确保语言模型捕捉局部上下文细微之处的精确性和有效性。

Self-Extend

SelfExtend是一种无需微调即可增强大型语言模型处理长上下文自然能力的方法,融合了两种不同类型的注意力机制:

  1. Grouped Attention

这是专门为距离较远的标记而设计的,通过对位置进行下取整运算来处理标记之间的长距离关系。

  1. Standard Attention

它使用传统注意力机制来处理指定范围内的相邻词元。

注意:SelfExtend仅在推理过程中修改注意力机制,无需额外微调。

如下图所示,上下文窗口从预训练长度7扩展到(7−4)∗2+4=10

SelfExtend图示

如下图所示,左图是代码实现介绍,右图是对该代码的图形化解释。

SelfExtend代码
SelfExtend注意力生成示意图

KV Cache管理方案分析

目前llama.cpp中关于无限文本生成的KV Cache管理始终保持Cache大小不变,根据是否采用分组注意力机制,采用不同的Cache更新方式。

  • KV Cache更新

ga_n==1 时,会始终保留系统提示词的 KV Cache,从剩余的缓存中移除一半相对当前token较远的token,只保留一半相对较近的token。具体操作的代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
if (n_past + (int) embd.size() >= n_ctx) {
if (!params.ctx_shift){
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
break;
}

if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
}

const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;

LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);

llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);

n_past -= n_discard;

LOG_DBG("after swap: n_past = %d\n", n_past);

LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());

LOG_DBG("clear session path\n");
path_session.clear();
}

下图描述了上述代码的中的KV Cache更新过程:

KV Cache Update
  • Grouped Attention

ga_n!=1 时,说明采用 grouped self-attention,采用上文提到的 SelfExtend 方法,通过压缩token的相对位置,达到扩展模型上下文的能力,代码实现逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// context extension via Self-Extend
while (n_past >= ga_i + ga_w) {
const int ib = (ga_n*ga_i)/ga_w;
const int bd = (ga_w/ga_n)*(ga_n - 1);
const int dd = (ga_w/ga_n) - ib*bd - ga_w;

LOG_DBG("\n");
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd);
LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);

llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);

n_past -= bd;

ga_i += ga_w/ga_n;

LOG_DBG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
}

下图描述了上述代码的实现流程:

Grouped Attention

Session Load and Save

llama.cpp支持加载历史会话文件,该文件保存了上次会话时位于KV Cache中的数据,如果匹配到提示词在历史会话中,则在decode阶段可以直接加载其KV Cache, 不需要重新计算token的嵌入,可以提升总体性能。

Session file load

首先调用 llama_state_load_file 函数,然后调用 llama_state_load_file_internal 函数,其代码如下:

  • 文件校验
1
2
3
4
5
6
7
8
9
10
// sanity checks
{
const uint32_t magic = file.read_u32();
const uint32_t version = file.read_u32();

if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
return false;
}
}
  • 加载保存的token计数

从文件中读取token的数量,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
// load the prompt
{
const uint32_t n_token_count = file.read_u32();

if (n_token_count > n_token_capacity) {
LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
return false;
}

file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
*n_token_count_out = n_token_count;
}
  • 更新上下文的状态
1
2
3
4
5
6
7
8
9
10
11
12
// restore the context state
{
const size_t n_state_size_cur = file.size - file.tell();

llama_data_read_file data_ctx(&file);
const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);

if (n_read != n_state_size_cur) {
LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
return false;
}
}

其调用 llama_state_set_data_internal 函数更新上下文的状态:

1
2
3
4
5
6
7
8
9
10
11
12
llama_synchronize(ctx);

data_ctx.read_model_info(ctx);

// set outputs
data_ctx.read_output_ids(ctx);
data_ctx.read_logits(ctx);
data_ctx.read_embeddings(ctx);

data_ctx.read_kv_cache(ctx);

return data_ctx.get_size_read();

其中,KV Cache 使用函数 read_kv_cache 读取更新:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
uint32_t cell_count;
read_to(&cell_count, sizeof(cell_count));

bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);

if (!res) {
if (seq_id == -1) {
llama_kv_cache_clear(ctx);
} else {
llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
}
throw std::runtime_error("failed to restore kv cache");
}
}

其中 read_kv_cache_meta 函数读取 cells 信息, read_kv_cache_data 函数读取每层的 KV cache 数据。