MHA: That Scalar Weight Thing and Why V Might Be Underrated
Quick MHA Recap
Assuming you already have some idea about basic attention and MHA, here's the quick recap: you take your input sequence (with embedding dimension d_model) and transform it through learned linear layers into Query (Q), Key (K), and Value (V) tensors. Mash Q and K together via a scaled dot product, softmax(Q @ K^T / sqrt(d_k)), to get attention weights. Use those weights to form a weighted average of V, then run the result through an output linear layer.
The "Multi-Head" part: instead of one big calculation, you have h heads doing smaller calculations in parallel. After getting your main Q, K, V (all d_model deep), you reshape/split their embedding dimension into h chunks, each with dimension d_k = d_model / h. Each head runs attention separately on its chunk, then you concatenate and project back through W_O.
The Scalar Weight Problem
Standard attention calculates one single scalar attention weight a_ij for each query-key pair. That single number then uniformly scales every element of the corresponding Value vector. The mechanism looks at Q_i and K_j, decides "V_j is 70% important overall," and multiplies the whole thing by 0.7. Every dimension gets the same blunt scaling factor.
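Concretely, the whole value vector gets multiplied by that one number (made-up toy numbers, just to make the point):

```python
import torch

v_j = torch.tensor([1.0, -2.0, 0.5, 4.0])  # one value vector (made-up numbers)
a_ij = 0.7                                  # "V_j is 70% important overall"

# Every dimension gets the same blunt scaling factor
print(a_ij * v_j)  # tensor([ 0.7000, -1.4000,  0.3500,  2.8000])
```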
What If Attention Were Pickier?
Why not give a different weight to each dimension within the Value chunk, instead of applying a single scalar? It feels like that should give the model more freedom. Maybe the relationship between Q_i and K_j tells us that certain dimensions of V_j are highly relevant while others aren't needed for this specific query.
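Here's one way such a per-dimension weighting could look. This is a hypothetical sketch, not a standard mechanism: scoring each dimension with the elementwise product Q_i * K_j and softmaxing over keys per dimension is just one arbitrary choice I'm using to show the shapes.

```python
import torch
import torch.nn.functional as F

# Toy sizes, single head, no projections -- just to show the shapes
seq_len, d = 4, 6
Q = torch.randn(seq_len, d)
K = torch.randn(seq_len, d)
V = torch.randn(seq_len, d)

# Standard attention: one scalar weight per (i, j) pair, applied to all of V_j
scalar_weights = F.softmax(Q @ K.T / d ** 0.5, dim=-1)   # (T, T)
standard_out = scalar_weights @ V                        # (T, d)

# Hypothetical per-dimension variant: score each dimension separately via the
# elementwise product Q_i * K_j, then softmax over the keys for each dimension
elementwise_scores = Q.unsqueeze(1) * K.unsqueeze(0)       # (T, T, d)
per_dim_weights = F.softmax(elementwise_scores, dim=1)     # normalize over keys
picky_out = (per_dim_weights * V.unsqueeze(0)).sum(dim=1)  # (T, d)

print(standard_out.shape, picky_out.shape)  # both torch.Size([4, 6])
```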
Extreme Case: num_heads = d_model
What happens if you crank heads all the way up? Since head_dim = d_model / num_heads, setting num_heads = d_model gives head_dim = 1.
- Each of the d_model heads operates on a single scalar from Q, K, and V.
- Attention inside each head is just multiplying two scalars.
- Each head outputs one scalar (a weighted sum of scalars).
- Concatenating all heads reconstructs a d_model vector.
You get d_model parallel scalar-level attention calculations. It seems to give dimension-level control, but in an atomized way: each dimension's attention weights are computed from that dimension alone, with no view of the rest of Q_i or K_j. Definitely not standard (d_k is usually ~64 in practice), and the compute cost may not be worth it.
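A sketch of that degenerate configuration, skipping the learned projections and batching for brevity (the reshape below is just one way to express "one scalar per head"):

```python
import torch
import torch.nn.functional as F

# The extreme: num_heads = d_model, so head_dim = 1
seq_len, d_model = 4, 8
head_dim = 1

Q = torch.randn(seq_len, d_model)
K = torch.randn(seq_len, d_model)
V = torch.randn(seq_len, d_model)

# Reshape so each head holds a single scalar per position: (heads, T, 1)
Qh = Q.T.unsqueeze(-1)   # (d_model, T, 1)
Kh = K.T.unsqueeze(-1)
Vh = V.T.unsqueeze(-1)

# Inside a head, Q @ K^T is literally a product of two scalars per (i, j)
scores = Qh @ Kh.transpose(-2, -1) / head_dim ** 0.5   # (d_model, T, T)
weights = F.softmax(scores, dim=-1)
out_heads = weights @ Vh                               # (d_model, T, 1)

# Concatenating all the heads reconstructs a d_model-wide vector per position
out = out_heads.squeeze(-1).T                          # (T, d_model)
print(out.shape)  # torch.Size([4, 8])
```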
Q, K, and V Don't Need the Same Dimension
Q and K must match in dimension (you can't compute Q @ K^T otherwise). But V can have a different dimension d_v. The process:
- Calculate scalar attention scores using Q (dim d_k) and K (dim d_k).
- Weight V vectors (dim d_v) by those scores. Each head outputs dim d_v.
- Concatenate all heads: dimension h * d_v.
- Final projection W_O maps h * d_v back to d_model.
The comparison dimension (d_k) is separate from the information dimension (d_v). This flexibility is often glossed over.
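In code the asymmetry is easy to see (toy sizes, single head, my own made-up numbers):

```python
import torch
import torch.nn.functional as F

# Toy sizes, single head: comparison dim d_k vs. information dim d_v
seq_len, d_k, d_v = 5, 16, 48

Q = torch.randn(seq_len, d_k)   # what each position is looking for
K = torch.randn(seq_len, d_k)   # must match Q's dim so Q @ K^T works
V = torch.randn(seq_len, d_v)   # the content that actually gets mixed

# Scores only ever touch d_k; the output inherits d_v from V
weights = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1)   # (T, T)
out = weights @ V                                   # (T, d_v)
print(out.shape)  # torch.Size([5, 48])

# With h heads, concatenation gives h * d_v, and W_O maps that back to d_model.
```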
V Is Where the Action Is
While Q and K figure out what to focus on, V carries the actual content that gets aggregated and passed forward. Research has found performance improvements from manipulating the Value stream directly: adding value residuals from previous layers, applying activation functions to V before weighting, etc. The Value pathway seems like a high-leverage target for architectural improvements. But that's a whole other post.