One Attention Head To Rule Them All
Don't spend all your GPU memory in one place :)
Storing a KV cache at inference time has become common practice in LLMs. Although the KV cache speeds up generating new tokens from a prompt, it becomes a problem with long context windows or large batch sizes. In this blog, I’ll talk about Multi Query Attention as a way to lower the memory footprint of this cache.
Let’s first understand how token generation happens inside a decoder-only model (e.g. ChatGPT, Mistral, Llama, etc.) and why we need KV caching. Below is an overall view of how tokens are generated iteratively in the decoder-only transformer architecture.

The figure below shows the inside of a decoder block. It consists of a multi-head attention block where all the magic happens.

Each attention head in the multi-head attention block computes its own K and V matrices. One thing worth mentioning here is that the attention scores are calculated with causality in mind. In other words, the attention scores for a token are calculated only over that token and the tokens that came before it. For example, for the token “you”, only “who”, “are” and “you” itself are taken into consideration, and the tokens occurring after it are masked.
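Here is a minimal NumPy sketch of that masking for a single head. The function name `causal_attention_weights`, the toy dimensions and the random inputs are made up for illustration; real models do the same thing on batched tensors.

```python
import numpy as np

def causal_attention_weights(q, k):
    """Softmax attention weights with a causal mask. q, k: (seq_len, d)."""
    seq_len, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    # True above the diagonal, i.e. for positions that come later in the sequence.
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over each row.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
tokens = ["who", "are", "you"]
q = rng.normal(size=(len(tokens), 4))  # toy query vectors, one row per token
k = rng.normal(size=(len(tokens), 4))  # toy key vectors, one row per token
weights = causal_attention_weights(q, k)
print(np.round(weights, 3))
```

Row i of `weights` holds the attention that token i pays to every token. Everything above the diagonal is exactly zero, so “who” can only look at itself, while “you” spreads its attention over all three tokens.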
“Who are you? I am”
“Who are you? I am a”
“Who are you? I am a large”
“Who are you? I am a large language”
In all four cases above, the final vector for “I” is identical, as it depends only on “Who are you? I”, which is the same every time. So, for predicting the next word “model”, the only new information we need is the K, V and Q vectors of “language”, as the rest can be reused from the previous steps. This is called KV caching: incrementally storing the K and V values and reusing them in the next steps. All we have to do is:
Calculate the query, key and value vectors for “language”
Retrieve the cached key and value vectors for all the previous tokens
Calculate attention scores for “language” and then the output vector
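The three steps above can be sketched for a single attention head as follows. This is a toy NumPy version under simplifying assumptions (one head, one layer, no batching, random weights); `decode_step` and all the dimensions are hypothetical names and values chosen for illustration.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def decode_step(x_new, W_q, W_k, W_v, k_cache, v_cache):
    """One generation step. x_new: (d_model,); caches: (t, d_head)."""
    # Step 1: query, key and value for the new token only.
    q, k, v = x_new @ W_q, x_new @ W_k, x_new @ W_v
    # Step 2: reuse the cached keys/values and append the new ones.
    k_cache = np.vstack([k_cache, k])
    v_cache = np.vstack([v_cache, v])
    # Step 3: attention scores for the new token, then its output vector.
    scores = k_cache @ q / np.sqrt(q.shape[-1])
    out = softmax(scores) @ v_cache
    return out, k_cache, v_cache

rng = np.random.default_rng(0)
d_model, d_head, seq_len = 8, 4, 5
W_q, W_k, W_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
x = rng.normal(size=(seq_len, d_model))  # toy token embeddings

k_cache = np.empty((0, d_head))
v_cache = np.empty((0, d_head))
for t in range(seq_len):
    out, k_cache, v_cache = decode_step(x[t], W_q, W_k, W_v, k_cache, v_cache)

# Reference: recompute K and V for the whole sequence without any cache.
K_full, V_full = x @ W_k, x @ W_v
ref = softmax((x[-1] @ W_q) @ K_full.T / np.sqrt(d_head)) @ V_full
print(np.allclose(out, ref))
```

The cached path and the full recomputation give the same output for the last token. The difference is that the cached path only projects one new token per step instead of the whole sequence, and no explicit mask is needed because the cache never contains future tokens.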
Now think of multiple attention heads - we’ll store the KV cache for all the heads separately. Now think of long context prompts - K and V values of a lot of tokens across multiple heads. Yes, you are thinking right, the memory footprint is too damn high. The cache grows significantly and takes up a good chunk of the GPU memory. We obviously don’t want that. There have been quite a few innovations around this problem, grouped under “inference time optimization”. One of these optimizations is Multi Query Attention.
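To get a feel for the numbers, here is a back-of-the-envelope estimate. The model shape below (32 layers, 32 heads, head size 128, 16-bit values) is a hypothetical configuration picked for illustration, not a figure from any specific model, and `kv_cache_bytes` is just a helper name.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch_size, bytes_per_value=2):
    """Size of the KV cache: one K and one V vector per token, per head, per layer."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch_size * bytes_per_value

# Hypothetical model shape and workload, assumed for illustration only.
size = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, seq_len=4096, batch_size=8)
print(size / 2**30)  # 16.0 (GiB)
```

Under these assumptions the cache alone takes 16 GiB, and it grows linearly with both the context length and the batch size.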
Multi Query Attention
Every attention head in the self-attention block in a decoder has its own set of W_k, W_q and W_v matrices for calculating key, query and value vectors respectively. As discussed above, this contributes to the growing cache problem.
Noam Shazeer from Google proposed sharing a single pair of W_k and W_v matrices across all attention heads, while each head keeps its own W_q. This gives every head the same key and value vectors for each token, so only one set of keys and values needs to be cached instead of one per head, which reduces the cache considerably. I know what you are thinking - wouldn’t this type of “quantization” (probably not the best term for this) hurt the quality of the model’s responses? Not too much! The author ran experiments showing that it’s only slightly worse than the baseline.
At the same time, the paper reports huge gains in inference speed, especially in the decoder. The baseline took 46 microseconds per token in the decoder, compared to just 3.8 microseconds (about 12x faster) with multi-query attention.
The main reason for this speed gain is the smaller KV cache. It frees up GPU memory for processing more tokens at a time (a larger batch), which increases the generation rate. Less data also has to be read from memory at every step, which helps with speed too. Folks at character.ai have claimed to reduce their KV cache 8x by using MQA.
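To make the difference concrete, here is a toy NumPy sketch of both variants side by side. It is a simplified illustration, not the paper's implementation: the output projection is left out, the weights are random, and the function names and dimensions are made up.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def causal_mask(t):
    # True for positions that come later in the sequence.
    return np.triu(np.ones((t, t), dtype=bool), k=1)

def multi_head_attention(x, W_q, W_k, W_v):
    """W_q, W_k, W_v: (n_heads, d_model, d_head). Returns output plus the K, V to cache."""
    Q = np.einsum("td,hdk->htk", x, W_q)
    K = np.einsum("td,hdk->htk", x, W_k)  # one set of keys per head
    V = np.einsum("td,hdk->htk", x, W_v)  # one set of values per head
    scores = np.einsum("htk,hsk->hts", Q, K) / np.sqrt(Q.shape[-1])
    scores = np.where(causal_mask(x.shape[0]), -np.inf, scores)
    return np.einsum("hts,hsk->htk", softmax(scores), V), K, V

def multi_query_attention(x, W_q, W_k, W_v):
    """W_q: (n_heads, d_model, d_head); W_k, W_v: (d_model, d_head), shared by all heads."""
    Q = np.einsum("td,hdk->htk", x, W_q)
    K = x @ W_k  # a single set of keys, reused by every head
    V = x @ W_v  # a single set of values, reused by every head
    scores = np.einsum("htk,sk->hts", Q, K) / np.sqrt(Q.shape[-1])
    scores = np.where(causal_mask(x.shape[0]), -np.inf, scores)
    return np.einsum("hts,sk->htk", softmax(scores), V), K, V

rng = np.random.default_rng(0)
n_heads, d_model, d_head, t = 4, 16, 8, 6
x = rng.normal(size=(t, d_model))
W_q = rng.normal(size=(n_heads, d_model, d_head))

mha_out, mha_K, mha_V = multi_head_attention(
    x, W_q,
    rng.normal(size=(n_heads, d_model, d_head)),
    rng.normal(size=(n_heads, d_model, d_head)),
)
mqa_out, mqa_K, mqa_V = multi_query_attention(
    x, W_q,
    rng.normal(size=(d_model, d_head)),
    rng.normal(size=(d_model, d_head)),
)

cache_ratio = (mha_K.size + mha_V.size) / (mqa_K.size + mqa_V.size)
print(mha_out.shape, mqa_out.shape, cache_ratio)
```

Both variants produce one output vector per head per token, but the K and V that have to be cached are `n_heads` times smaller in the multi-query version (4x in this toy setup).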