Multi-Head, Multi-Query, and Grouped-Query Attention: Which One Should You Use?

In this blog, we will discuss different types of attention mechanisms. First, we will introduce Multi-Head Attention and Multi-Query Attention and their limitations. Then we will discuss Grouped-Query Attention (GQA), why it is needed, and how to implement it using PyTorch.

Introduction

In transformer-based architectures, attention heads are crucial for learning long-range dependencies. The traditional Multi-Head Attention (MHA) introduced in the Attention Is All You Need paper has been the standard go-to attention approach, but it has a high computational cost, which makes it inefficient for large language models (LLMs). This inefficiency arises from the high memory bandwidth required, which becomes a bottleneck during both training and inference.

So, an alternative approach, Multi-Query Attention (MQA), was introduced in the paper Fast Transformer Decoding: One Write-Head is All You Need. MQA improves upon MHA by sharing a single key-value head across all query heads. This key modification significantly reduces the memory bandwidth cost and speeds up decoding.

Figure 1: Structure of Multi-Head, Grouped-Query, and Multi-Query Attention (taken from the GQA paper)

As shown in the figure above, in multi-head attention each query head has its own key-value head, while in multi-query attention all query heads share the same key-value pair. We will explore the grouped-query mechanism below.

Limitations of MHA and MQA

The limitations of Multi-Head Attention (MHA) become particularly evident in autoregressive tasks, such as text generation, when Transformer models use key-value (KV) caches. While the actual attention computation via matrix multiplication is fast, the real bottleneck is data movement: at each decoding step, all cached keys and values must be loaded from GPU memory into much faster on-chip memory, and this transfer is slow relative to the computation. The issue worsens as the number of attention heads increases, leading to slower inference and significant memory overhead.

While Multi-Query Attention (MQA) addresses the memory bandwidth issue by using a single key-value head across all query heads, this simplification comes at the cost of model quality. By sharing one key-value head for all queries, MQA reduces memory overhead and speeds up decoding, since keys and values no longer have to be loaded separately for every head in each attention layer. However, it also limits the model's ability to capture nuanced relationships between queries, reducing its capacity to generalize. As a result, MQA can struggle with complex patterns or dependencies in the data, leading to lower-quality predictions and a less robust model, particularly for tasks that require detailed and diverse information processing. This limitation can also cause training instability, as the model may fail to learn optimal representations for each query.

Grouped-Query Attention: The Savior

Grouped-Query Attention (GQA) is an attention mechanism introduced in the paper GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints to address the limitations of both Multi-Head Attention (MHA) and Multi-Query Attention (MQA). GQA is a middle ground between the two: it retains much of the expressiveness of MHA (enabling it to capture intricate relationships) while addressing the memory bandwidth bottleneck associated with autoregressive decoding in large models.

As shown in the figure above, the key idea behind GQA is to group the query heads and share one set of key-value pairs per group. This reduces memory overhead by using fewer key-value heads, yet it still allows the model to capture more complex patterns than MQA, which uses a single key-value pair for all queries. GQA thus keeps quality close to MHA while keeping speed close to MQA, striking a balance between quality and efficiency.
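To make the memory argument concrete, here is a small back-of-the-envelope sketch. The layer count, head count, and head dimension are hypothetical, chosen only to illustrate how the per-token KV-cache size scales with the number of key-value heads:

```python
# Back-of-the-envelope KV-cache size per generated token.
# All numbers below are hypothetical and only illustrate the scaling.

def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    # Keys and values (factor of 2) are cached for every layer,
    # assuming 16-bit (2-byte) precision.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value

num_layers, num_query_heads, head_dim = 32, 32, 128  # hypothetical model config

for name, num_kv_heads in [("MHA", num_query_heads), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim)
    print(f"{name} ({num_kv_heads:2d} KV heads): {size / 1024:.0f} KiB per token")
```

With 32 query heads, dropping to 8 key-value heads (GQA) shrinks the cache, and therefore the data moved at every decoding step, by 4x, while MQA shrinks it by 32x.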

GQA is particularly beneficial in the decoder layers of a transformer model, where memory bandwidth becomes a bottleneck due to the sequential nature of autoregressive decoding. However, in the encoder layers, the representations are computed in parallel, so the memory bandwidth overhead is not as problematic, which is why GQA is not applied there.

How to implement Grouped-Query Attention (GQA)?

We will discuss how to implement GQA, based on code I wrote while implementing the PaliGemma model.

Figure 2: Constructor of the Gemma_Attention Class Implementing MQA and GQA

The figure above shows the constructor of the Gemma_Attention class, where we initialize the key parameters for the attention mechanism, including the number of key-value heads, the number of query heads, and other configurations (a minimal sketch of this constructor follows the list below).

  • When num_key_value_heads = 1, we get the special case of GQA that is exactly MQA: all query heads share a single key-value head.

  • We can switch to GQA by increasing num_key_value_heads (e.g., 2, 4), so that each group of query heads shares its own set of key-value pairs. This recovers much of MHA's quality while remaining far more memory-efficient.

  • The number of key-value groups (num_key_value_groups) is simply num_heads / num_key_value_heads, i.e., the number of query heads that share each key-value head in GQA.

  • The q_proj, k_proj, and v_proj layers transform the input hidden states into query, key, and value vectors. Their input dimension is the embedding dimension of each token (hidden_size), and their output dimensions depend on the number of query and key-value heads, respectively.
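Since Figure 2 is a screenshot, here is a minimal sketch of what such a constructor can look like. The class and config attribute names (hidden_size, num_attention_heads, num_key_value_heads, head_dim) are assumptions modeled on common Hugging Face-style configs rather than the exact PaliGemma code, which may differ in details such as bias terms; an output projection o_proj is included for completeness:

```python
import torch.nn as nn

class GemmaAttention(nn.Module):
    """Minimal sketch of the Gemma_Attention constructor described above (names are illustrative)."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size                  # embedding dimension of each token
        self.num_heads = config.num_attention_heads            # number of query heads
        self.num_key_value_heads = config.num_key_value_heads  # 1 -> MQA, num_heads -> MHA, in between -> GQA
        self.head_dim = config.head_dim
        # Number of query heads that share each key-value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # Queries get num_heads * head_dim outputs; keys and values only
        # num_key_value_heads * head_dim, which is where the memory saving comes from.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
```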

Figures 4, 5: Forward method of the class

In the forward method, we project the query, key, and value states, apply rotary embeddings, update the KV cache, and compute the attention weights.
Key insights from the code in Figures 3, 4, and 5 (a consolidated sketch follows the list below):

  • We first project the input hidden states using linear layers:

    • Query (q_proj): (batch_size, seq_len, num_heads, head_dim)

    • Key (k_proj) and Value (v_proj): (batch_size, seq_len, num_key_value_heads, head_dim)

After projection, we reshape (view) the tensors to separate the attention heads and then transpose them to reorder the dimensions as (batch_size, num_heads, seq_len, head_dim). This places the head dimension before the sequence dimension, making it easier to perform the attention computation independently for each head.

  • We update the KV cache, if it is enabled, by caching key_states and value_states for efficient reuse in autoregressive decoding.

  • Repeat KV Heads function (Figure 3):

    • The _repeat_kv function ensures that key and value heads match the number of query heads (num_heads). This is crucial for handling Grouped Query Attention (GQA) and Multi-Query Attention (MQA).

    • If n_rep == 1 (i.e., num_key_value_heads == num_heads, the plain MHA case): no repetition is needed.

    • If n_rep > 1 (the GQA and MQA cases): each key-value head is repeated n_rep times so that it lines up with its group of query heads.
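Figures 3 to 5 are also screenshots, so the sketch below reconstructs the _repeat_kv helper and a simplified forward pass under the same naming assumptions as the constructor sketch above. It deliberately omits rotary embeddings, the KV-cache update, the attention mask, and dropout so that the GQA-specific steps stay visible; the real implementation handles all of those:

```python
import math
import torch
import torch.nn.functional as F

def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat KV heads so they line up with the query heads.

    hidden_states: (batch, num_key_value_heads, seq_len, head_dim)
    returns:       (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:  # one KV head per query head (plain MHA), nothing to do
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Forward method of the GemmaAttention class sketched earlier
# (shown standalone for brevity; no rotary embeddings, KV cache, or masking).
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size, seq_len, _ = hidden_states.shape

    # 1. Project, 2. split into heads with view, 3. move heads before seq_len.
    q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    v = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # Expand the KV heads so every query head has a matching key/value head.
    k = _repeat_kv(k, self.num_key_value_groups)
    v = _repeat_kv(v, self.num_key_value_groups)

    # Standard scaled dot-product attention over the aligned heads.
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, v)  # (batch, num_heads, seq_len, head_dim)
    attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
    return self.o_proj(attn_output)
```

Recent PyTorch releases (2.5+) also expose an enable_gqa flag on torch.nn.functional.scaled_dot_product_attention that performs this key-value repetition internally, but writing _repeat_kv out makes the mechanics explicit.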

I haven’t dived deep into KV caching and Rotary Position Embeddings in this article, but you can read about them in another article.

Conclusion: Which One Should You Use?

The choice between Multi-Head Attention (MHA), Multi-Query Attention (MQA), and Grouped-Query Attention (GQA) depends on the specific trade-offs you are willing to make.

  • If you prioritize model expressiveness and capturing fine-grained interactions, MHA remains the best option—but at a high computational cost.

  • If your goal is faster inference with minimal memory overhead, MQA is a great alternative, but it sacrifices some model quality.

  • GQA, however, strikes the best balance by improving efficiency while maintaining quality close to MHA, making it an ideal choice for large-scale models that need both speed and accuracy.

For training-intensive tasks where memory is not a bottleneck, MHA is preferable.
For fast, real-time inference, especially in autoregressive decoding, GQA is the way to go.

We also explored how to implement GQA in PyTorch, understanding the key components like query-key-value projections, attention computations, and KV caching. With this knowledge, you can now experiment with GQA in your own transformer models and optimize performance based on your specific needs.

References

  1. https://github.com/sisirdhakal/VIT

  2. https://www.youtube.com/watch?v=vAmKB7iPkWw

  3. https://youtu.be/oM4VmoabDAI?t=4608