KVCache in Transformers: Accelerating Inference with Efficient Memory Management
In this article, we will discuss the KVCache (Key-Value Cache), an inference optimization technique. We will look at the inference bottlenecks of decoder-based transformer models, explore why KVCache is needed and where its limitations lie, and then implement it using the PyTorch library.
Introduction
The Key-Value Cache, also known as KVCache, is a popular and powerful optimization technique that significantly improves efficiency during inference, especially in autoregressive tasks (i.e. where future tokens depend on the previously generated tokens).
As large language models (LLMs) continue to grow in size and complexity, deploying them efficiently in real-world applications becomes challenging. One major bottleneck in inference is the repeated computation of attention at every step: each time we generate a token, we recalculate attention scores for all previously generated tokens, even though their representations are already available to us. The problem worsens as the sequence gets longer, since the work per step grows with the sequence length, leading to a quadratic growth in total computation, slower inference, and higher resource usage.
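As a rough back-of-envelope, here is a tiny script comparing how many query-key dot products the two approaches perform; the 1,000-token figure is just an illustrative number, not a measurement from any particular model:

```python
# Rough count of query-key dot products when generating 1,000 tokens.
# Without a cache, step t rebuilds a full t x t QK^T matrix from scratch;
# with a cache, only the new token's query attends over the stored keys.
num_steps = 1000

without_cache = sum(t * t for t in range(1, num_steps + 1))  # full QK^T each step
with_cache = sum(t for t in range(1, num_steps + 1))         # one new row per step

print(f"without cache: ~{without_cache:,} query-key dot products")
print(f"with cache:    ~{with_cache:,} query-key dot products")
```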
KVCache: The Optimization Hero
So, how do we fix this? Enter **KVCache**, a simple but powerful idea. The main idea behind KVCache is that we store and reuse the key-value pairs generated during each inference step instead of recomputing them for every token.
By implementing KVCache, instead of recalculating everything at every step, we store the key-value pairs after computing them once. When a new token is generated, we simply append its key-value pair to the existing cache. This way, the model only processes the new token while reusing the precomputed keys and values of previous tokens, and the matrix multiplications become much smaller and faster.
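To make this concrete, here is a minimal single-head sketch; the shapes and random tensors are purely illustrative stand-ins for a real model's projections:

```python
import torch
import torch.nn.functional as F

head_dim = 64
cached_keys = torch.randn(1, 10, head_dim)    # keys for the 10 tokens seen so far
cached_values = torch.randn(1, 10, head_dim)  # values for those tokens

# A new token arrives: compute its Q, K, V (normally via the model's projection layers).
q_new = torch.randn(1, 1, head_dim)
k_new = torch.randn(1, 1, head_dim)
v_new = torch.randn(1, 1, head_dim)

# Append the new key/value to the cache instead of recomputing the old ones.
cached_keys = torch.cat([cached_keys, k_new], dim=1)      # (1, 11, head_dim)
cached_values = torch.cat([cached_values, v_new], dim=1)  # (1, 11, head_dim)

# Attention only for the new token's query against all cached keys: a single (1 x 11) score row.
scores = q_new @ cached_keys.transpose(-2, -1) / head_dim**0.5
out = F.softmax(scores, dim=-1) @ cached_values           # (1, 1, head_dim)
```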
Figure 1: Visualization of KVCache vs. Standard Attention Computation
KVCache in Action: Visualizing the Optimization and its Need
Let's see why KVCache matters by understanding how transformer inference typically works. Say we want to generate the sentence:
👉 "I love to write code <EOS>"
In autoregressive generation (where a model generates one token at a time), the model predicts each token step by step. By the time the model is generating "code", it has already produced “I”, “love”, “to”, and “write”, and it uses those previous tokens to predict "code".
But here’s the problem: every time the model predicts the next token, it recomputes attention scores for all previous tokens, even though we already know what those tokens are.
This means that when predicting "code", the model doesn’t just compute attention for "write", but also recomputes attention for "I", "love", and "to", even though their key-value representations haven’t changed. This is redundant and wastes a lot of computation, especially as the sequence gets longer.
Now, let’s take a closer look at the GIF (Figure 1) to really understand why KVCache is such a game-changer.
Without KVCache: The Inefficient Way
In the “Without KVCache” part of the image, every time we generate a new token, we recompute the entire QKᵀ (query-key dot product) matrix from scratch. All previously computed key-value pairs are recomputed even though they haven’t changed, which enlarges the matrices and makes each step computationally inefficient.
If you notice the gray boxes in the QKᵀ matrix, those represent values that will be masked. But here’s the kicker, we don’t actually care about them! What we’re really interested in is just the last row, since that’s the only new computation required for the latest token.
The same issue happens when computing the final attention scores. The attention output includes all previous tokens, even though we only care about the last row (the output for our latest token). This unnecessary recomputation results in slower inference and more resource usage, especially for long sequences.
With KVCache: The Smarter Way
Now, look at the “With KVCache” part. Here, instead of recomputing everything, we simply append the new key-value pair to our cached KV store. When computing attention, we only focus on the new token (i.e. we feed only the current token as input to the attention mechanism, instead of passing in all previous tokens) while reusing the previously stored keys and values. This shrinks the matrices, makes the matrix multiplications much faster, removes the redundant calculations, and significantly speeds up inference.
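Here is a small sanity check of that claim, using random tensors in place of real projections; it verifies that the cached path reproduces exactly the last row of the full causal-attention recomputation:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
head_dim, seq_len = 64, 8
q = torch.randn(seq_len, head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)

# Without KVCache: recompute the full (seq_len x seq_len) causal attention every step.
scores = q @ k.T / head_dim**0.5
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
full_out = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v

# With KVCache: only the last token's query attends over the cached keys/values.
last_scores = q[-1:] @ k.T / head_dim**0.5  # a single (1 x seq_len) row; nothing is masked here
cached_out = F.softmax(last_scores, dim=-1) @ v

print(torch.allclose(full_out[-1:], cached_out))  # True: same result, far less work
```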
Limitations of KVCache
Although KVCache largely improves inference efficiency, it has some tradeoffs. The most striking drawback is higher memory consumption, because we have to keep the key-value pairs of every generated token; as the sequence length grows, so does the cache, which can become an issue in memory-constrained environments. Furthermore, KVCache is most useful for autoregressive models and offers little for models that work in a non-sequential or bidirectional manner (like BERT), where the full context is always needed for prediction. Handling varying sequence lengths is also a challenge: if a model needs to update or drop previous tokens (for instance in editing or reinforcement learning tasks), a simple append-only cache is insufficient. Finally, KVCache does not remove all computational cost; it avoids repeated computations, but each new token still requires a fresh attention computation, and managing the cache itself introduces some overhead.
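To get a feel for the memory cost, here is a rough estimate for a hypothetical 32-layer model with 32 attention heads of dimension 128; the numbers are illustrative and not tied to any specific model:

```python
# Rough KV cache size for a hypothetical 32-layer, 32-head, head_dim=128 model,
# stored in fp16 (2 bytes per element). Real models vary; this is just an estimate.
num_layers, num_heads, head_dim = 32, 32, 128
bytes_per_elem = 2
seq_len, batch_size = 4096, 1

# Both keys and values are cached, hence the factor of 2.
cache_bytes = 2 * num_layers * num_heads * head_dim * seq_len * batch_size * bytes_per_elem
print(f"~{cache_bytes / 1024**3:.1f} GiB of cache for a {seq_len}-token sequence")
```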
Let’s implement KVCache using PyTorch
Let's walk through a KVCache implementation, which I wrote while working on an implementation of the PaliGemma model.
Figure 2: KVCache Class
So, let's dive into the KVCache class. As shown in Figure 2, it's a simple way to store key-value pairs so we don’t have to keep recalculating them every time we process a new token. Think of it as a memory bank for the model. In the pre-filling stage (the first time we process the prompt tokens and store their key-value pairs), the cache is initialized; once those initial tokens are stored, we're good to go. From then on, as each new token arrives, the update() method simply appends its key-value pairs to the cache. This saves us from doing redundant work, making things much faster.
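The exact code is shown in Figure 2; as a stand-in, here is a minimal sketch of a KVCache class in the same spirit, storing one key tensor and one value tensor per layer:

```python
from typing import List, Tuple
import torch

class KVCache:
    """A simple per-layer store of attention keys and values (an illustrative sketch)."""

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def num_items(self) -> int:
        # Number of tokens cached so far (the sequence length of the stored keys).
        return 0 if len(self.key_cache) == 0 else self.key_cache[0].shape[-2]

    def update(
        self,
        key_states: torch.Tensor,    # new keys for this layer; the sequence dim is second-to-last
        value_states: torch.Tensor,  # new values, same shape as key_states
        layer_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(self.key_cache) <= layer_idx:
            # Pre-filling: first call for this layer, just store the prompt's keys/values.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Decoding: append the new token's keys/values along the sequence dimension.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        # Return everything cached so far for this layer.
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```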
Figure 3: Forward Method of GemmaAttention
Now, moving on to Figure 3, here’s where KVCache gets integrated into the attention mechanism. In the forward() method of GemmaAttention, we compute the query, key, and value states from the input hidden states. If a KVCache is present, we update it with the new key-value pairs. Instead of redoing the attention calculation for every single token in the sequence, we work only with the new token’s query and the cached keys and values, which speeds up the whole process.
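Figure 3 shows the real GemmaAttention code; below is a simplified, single-head sketch of the same idea. The class name and layout are my own, and the causal mask needed during the multi-token pre-filling pass is omitted for brevity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCachedAttention(nn.Module):
    """Single-head attention that consults a KV cache (an illustrative stand-in, not GemmaAttention)."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.scale = hidden_size ** -0.5

    def forward(self, hidden_states: torch.Tensor, kv_cache=None, layer_idx: int = 0) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size); after pre-filling, seq_len is 1.
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        if kv_cache is not None:
            # Store the new keys/values and get back everything cached so far for this layer.
            k, v = kv_cache.update(k, v, layer_idx)

        # Only the current query attends over the (possibly much longer) cached keys/values.
        attn = F.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        return self.o_proj(attn @ v)
```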
Figure 4: Inference with KVCache
Finally, in Figure 4, we see KVCache in action during inference. At first, the cache is empty; during the pre-filling phase, we store the key-value pairs for the prompt tokens. Then, as we generate more tokens, the cache gets updated with the new key-value pairs. This means we don’t have to recalculate attention for the whole sequence, just for the new token. This approach really boosts efficiency, especially when handling long sequences.
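Putting it together, here is a simplified sketch of a greedy decoding loop. The `generate` function and the `model` call signature are hypothetical: the model is assumed to return logits of shape (batch, seq_len, vocab_size) and to push its per-layer keys and values into the cache (the KVCache sketch from above):

```python
import torch

def generate(model, prompt_ids: torch.Tensor, max_new_tokens: int, eos_id: int) -> torch.Tensor:
    kv_cache = KVCache()    # empty cache before pre-filling (sketch class from above)
    generated = prompt_ids  # (1, prompt_len)

    # Pre-filling: run the whole prompt once so its keys/values land in the cache.
    logits = model(prompt_ids, kv_cache=kv_cache)

    for _ in range(max_new_tokens):
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick, shape (1, 1)
        generated = torch.cat([generated, next_id], dim=-1)
        if next_id.item() == eos_id:
            break
        # Decoding: feed only the new token; the cached keys/values cover the rest.
        logits = model(next_id, kv_cache=kv_cache)

    return generated
```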
By combining pre-filling and updating, KVCache helps the model process sequences faster, making everything run more smoothly and efficiently.
Conclusion
To conclude, in this article we explored how KVCache optimizes the inference process by reusing key-value pairs, which makes generating long sequences much faster, especially for autoregressive models. While it comes with some tradeoffs, like higher memory usage, KVCache proves to be a valuable technique, trading extra memory for a substantial speedup in inference.