In the previous post, we derived several linear attention mechanisms from scratch by formulating them as test-time online regression problems. Here, we’ll discuss a more intuitive way to represent the update rules of the internal states of these linear attention mechanisms using a blocked matrix formulation. Then, we’ll discuss how to use it to (1) derive the update rules for linear attention mechanisms that take multiple gradient descent steps per token and (2) derive the update rules for chunk-wise parallelism of already-existing linear attention mechanisms.
Linear attention mechanisms typically have an update rule of the form:
$$S_i = S_{i-1} A_i + B_i$$
where $S_{i-1}$ is the (old) state after processing the first $i-1$ tokens, $S_i$ is the (new) state after processing the first $i$ tokens, and $A_i$ and $B_i$ are update matrices. Think of $A_i$ as an operation that modifies information already stored in the state, while $B_i$ adds new information to the state. In most cases where $A_i \neq I$, $A_i$ typically removes some (old) information from the state. But if we allow $A_i$ to have negative eigenvalues, then we can also think of it as an operation that, in a sense, inverts information instead.
Here are a couple of examples:
| Linear Attention Mechanism | $A_i$ | $B_i$ |
|---|---|---|
| Vanilla Linear Attention | $I$ | $v_i k_i^T$ |
| Mamba 2 | $\alpha_i I$ | $v_i k_i^T$ |
| DeltaNet | $I - \beta_i k_i k_i^T$ | $\beta_i v_i k_i^T$ |
| Gated DeltaNet | $\alpha_i \left(I - \beta_i k_i k_i^T\right)$ | $\beta_i v_i k_i^T$ |
| RWKV-7 | $\mathrm{diag}(w_i) - \hat{\kappa}_i \left(a_i \odot \hat{\kappa}_i\right)^T$ | $v_i k_i^T$ |
where $k_i \in \mathbb{R}^{d_k}$ and $v_i \in \mathbb{R}^{d_v}$ are the corresponding key-value pair for the $i$-th token; $\alpha_i \in [0, 1]$ can be thought of as a data-dependent weight decay that controls how much of the previous state to keep or forget; and $\beta_i \in [0, 1]$ can be thought of as a data-dependent learning rate that controls how much of the new information to add to the state.
If we let $\alpha_i \in [-1, 1]$ for Mamba 2 and $\beta_i \in [0, 2]$ for (Gated) DeltaNet, then $A_i$ can have negative eigenvalues while still having norm $\|A_i\| \leq 1$. This allows the models to learn more complex patterns while maintaining training stability (Grazzi et al., 2025).
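To make the table concrete, here is a minimal NumPy sketch of the generic recurrence $S_i = S_{i-1} A_i + B_i$ instantiated with a Gated DeltaNet-style update. The dimensions, keys, values, and the per-token $\alpha_i$, $\beta_i$ below are random stand-ins; the real architectures compute them from the input and apply normalizations and projections that are omitted here.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v, N = 4, 3, 8          # toy dimensions, chosen only for illustration
S = np.zeros((d_v, d_k))       # state after processing 0 tokens

for i in range(N):
    k = rng.normal(size=d_k); k /= np.linalg.norm(k)   # unit-norm key (stand-in)
    v = rng.normal(size=d_v)                           # value (stand-in)
    alpha = rng.uniform(0.0, 1.0)                      # data-dependent decay (random stand-in)
    beta = rng.uniform(0.0, 1.0)                       # data-dependent learning rate (random stand-in)

    # Gated DeltaNet row of the table: A_i = alpha_i (I - beta_i k_i k_i^T),  B_i = beta_i v_i k_i^T
    A = alpha * (np.eye(d_k) - beta * np.outer(k, k))
    B = beta * np.outer(v, k)

    S = S @ A + B              # S_i = S_{i-1} A_i + B_i

print(S.shape)                 # (d_v, d_k)
```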
Blocked Matrix Formulation of Linear Attention Mechanisms
Notice that we can rewrite the update rule above as,
$$S_i = S_{i-1} A_i + B_i = \begin{bmatrix} S_{i-1} & I \end{bmatrix} \begin{bmatrix} A_i \\ B_i \end{bmatrix}$$
or, equivalently,
$$\begin{bmatrix} S_i & I \end{bmatrix} =
\begin{bmatrix} S_{i-1} & I \end{bmatrix}
\begin{bmatrix} A_i & 0 \\ B_i & I \end{bmatrix}$$
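As a quick sanity check, here is a small NumPy sketch (with random toy matrices, just to make the shapes explicit) verifying that the blocked form reproduces the plain update:

```python
import numpy as np

rng = np.random.default_rng(1)
d_k, d_v = 4, 3
S_prev = rng.normal(size=(d_v, d_k))
A = rng.normal(size=(d_k, d_k))
B = rng.normal(size=(d_v, d_k))

# Plain update: S_i = S_{i-1} A_i + B_i
S_direct = S_prev @ A + B

# Blocked update: [S_i  I] = [S_{i-1}  I] [[A_i, 0], [B_i, I]]
top = np.hstack([S_prev, np.eye(d_v)])          # [S_{i-1}  I], shape (d_v, d_k + d_v)
M = np.block([[A, np.zeros((d_k, d_v))],
              [B, np.eye(d_v)]])                # [[A_i, 0], [B_i, I]]
S_blocked = (top @ M)[:, :d_k]                  # the first d_k columns give S_i

assert np.allclose(S_direct, S_blocked)
```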
At training time, we need all of the intermediary states, not just the final state. Thus, we need an efficient way to compute $S_N$ for all token indices $N$. To do this, let's unroll the recurrence above:
$$\begin{align*}
\begin{bmatrix} S_N & I \end{bmatrix} &=
\begin{bmatrix} S_{N-1} & I \end{bmatrix}
\begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix}\\
\begin{bmatrix} S_N & I \end{bmatrix} &=
\begin{bmatrix} S_{N-2} & I \end{bmatrix}
\begin{bmatrix} A_{N-1} & 0 \\ B_{N-1} & I \end{bmatrix}
\begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix}\\
&\vdots\\
\begin{bmatrix} S_N & I \end{bmatrix} &=
\begin{bmatrix} S_0 & I \end{bmatrix}
\begin{bmatrix} A_1 & 0 \\ B_1 & I \end{bmatrix}
\begin{bmatrix} A_2 & 0 \\ B_2 & I \end{bmatrix}
\cdots
\begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix}\\
S_N &=
\begin{bmatrix} S_0 & I \end{bmatrix}
\begin{bmatrix} A_1 & 0 \\ B_1 & I \end{bmatrix}
\begin{bmatrix} A_2 & 0 \\ B_2 & I \end{bmatrix}
\cdots
\begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix}
\begin{bmatrix} I \\ 0 \end{bmatrix}
\end{align*}$$
In practice, we usually initialize $S_0$ as the zero matrix. Thus,
$$\begin{align}
S_N &=
\begin{bmatrix} 0 & I \end{bmatrix}
\begin{bmatrix} A_1 & 0 \\ B_1 & I \end{bmatrix}
\begin{bmatrix} A_2 & 0 \\ B_2 & I \end{bmatrix}
\cdots
\begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix}
\begin{bmatrix} I \\ 0 \end{bmatrix}\\
S_N &=
\begin{bmatrix} 0 & I \end{bmatrix}
\begin{bmatrix} \prod_{i=1}^{N} A_i & 0 \\ \sum_{i=1}^{N} \left(B_i \prod_{j=i+1}^{N} A_j\right) & I \end{bmatrix}
\begin{bmatrix} I \\ 0 \end{bmatrix}\\
S_N &= \sum_{i=1}^{N} \left(B_i \prod_{j=i+1}^{N} A_j\right)
\end{align}$$
where $(1) \rightarrow (2)$ can be proven by induction.
Equation $(1)$ makes it obvious why and how we can parallelize the computation of $S_N$, for all $N$, at training time: the updates are merely (blocked) matrix multiplications; matrix multiplications are associative; thus, we can use the (fully-parallel) associative scan algorithm to compute all the intermediary states with $O(N)$ total work and only $O(\log N)$ sequential steps!
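To illustrate the idea, here is a hedged NumPy sketch: each token contributes an operator pair $(A_i, B_i)$, two operators compose as $(A_1, B_1) \circ (A_2, B_2) = (A_1 A_2,\ B_1 A_2 + B_2)$, and because this composition is associative we can compute all prefix states with a scan. The scan below is a naive divide-and-conquer version over random DeltaNet-style toy updates; real kernels are far more careful about work efficiency and memory layout.

```python
import numpy as np

def combine(x, y):
    """Compose two block-matrix updates: (A1, B1) then (A2, B2) -> (A1 A2, B1 A2 + B2)."""
    A1, B1 = x
    A2, B2 = y
    return A1 @ A2, B1 @ A2 + B2

def prefix_scan(ops):
    """All-prefix (inclusive) scan of the update operators via divide and conquer.
    This naive version does O(N log N) combines; a work-efficient scan gets O(N)."""
    n = len(ops)
    if n == 1:
        return [ops[0]]
    mid = n // 2
    left = prefix_scan(ops[:mid])
    right = prefix_scan(ops[mid:])
    carry = left[-1]                       # combined operator of the whole left half
    return left + [combine(carry, r) for r in right]

rng = np.random.default_rng(2)
d_k, d_v, N = 4, 3, 8
ops = []
for _ in range(N):
    k = rng.normal(size=d_k); k /= np.linalg.norm(k)
    v = rng.normal(size=d_v)
    beta = rng.uniform(0, 1)
    A = np.eye(d_k) - beta * np.outer(k, k)      # DeltaNet-style A_i (toy)
    B = beta * np.outer(v, k)                    # DeltaNet-style B_i (toy)
    ops.append((A, B))

# Since S_0 = 0, S_i is just the B-part of the i-th prefix operator.
scan_states = [B for (_, B) in prefix_scan(ops)]

# Reference: fully recurrent form.
S, rec_states = np.zeros((d_v, d_k)), []
for A, B in ops:
    S = S @ A + B
    rec_states.append(S)

assert all(np.allclose(a, b) for a, b in zip(scan_states, rec_states))
```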
Now, what if we take $n_h$ gradient descent steps per token?
To do this, we can follow the procedure outlined in the DeltaProduct (Siems et al., 2025) paper where they:
1. Recurrently generate $n_h$ key-value pairs for each input token,
2. Update the state using the $n_h$ key-value pairs, and
3. Keep only the final key-value pair and discard the rest.
In our formulation, this is equivalent to replacing each update with a product of $n_h$ updates:
$$\begin{bmatrix} A_i & 0 \\ B_i & I \end{bmatrix}
\longrightarrow
\begin{bmatrix} A_{i,1} & 0 \\ B_{i,1} & I \end{bmatrix}
\begin{bmatrix} A_{i,2} & 0 \\ B_{i,2} & I \end{bmatrix}
\cdots
\begin{bmatrix} A_{i,n_h} & 0 \\ B_{i,n_h} & I \end{bmatrix}$$
where $A_{i,j}$ and $B_{i,j}$ are the update matrices for the $j$-th gradient descent step for the $i$-th token.
Thus, Equation $(1)$ becomes:
$$\begin{align}
S_N =
\begin{bmatrix} 0 & I \end{bmatrix}
\begin{bmatrix} A_{1,1} & 0 \\ B_{1,1} & I \end{bmatrix}
\begin{bmatrix} A_{1,2} & 0 \\ B_{1,2} & I \end{bmatrix}
\cdots
\begin{bmatrix} A_{1,n_h} & 0 \\ B_{1,n_h} & I \end{bmatrix}
\begin{bmatrix} A_{2,1} & 0 \\ B_{2,1} & I \end{bmatrix}
\cdots
\begin{bmatrix} A_{N,n_h} & 0 \\ B_{N,n_h} & I \end{bmatrix}
\begin{bmatrix} I \\ 0 \end{bmatrix}
\end{align}$$
And if we reindex this as $[\cdot]_k = [\cdot]_{\lceil k/n_h \rceil,\, (k-1) \% n_h + 1}$, then from Equation $(3)$ above, we get:
$$\begin{align}
S_N = \sum_{k=1}^{N n_h} \left( B_k \prod_{k'=k+1}^{N n_h} A_{k'} \right)
\end{align}$$
Alternatively, we can also combine the $n_h$ updates for each token into a single update matrix first before multiplying the per-token updates together:

$$\begin{align}
\begin{bmatrix} A^*_i & 0 \\ B^*_i & I \end{bmatrix}
= \prod_{j=1}^{n_h} \begin{bmatrix} A_{i,j} & 0 \\ B_{i,j} & I \end{bmatrix}
= \begin{bmatrix} \prod_{j=1}^{n_h} A_{i,j} & 0 \\ \sum_{j=1}^{n_h} \left( B_{i,j} \prod_{j'=j+1}^{n_h} A_{i,j'} \right) & I \end{bmatrix}
\end{align}$$

so that

$$S_N = \sum_{i=1}^{N} \left( B^*_i \prod_{i'=i+1}^{N} A^*_{i'} \right)
= \sum_{i=1}^{N} \sum_{j=1}^{n_h} \left( B_{i,j} \left( \prod_{j'=j+1}^{n_h} A_{i,j'} \right) \left( \prod_{i'=i+1}^{N} \prod_{j'=1}^{n_h} A_{i',j'} \right) \right)$$
which, again, if we reindex as $[\cdot]_k = [\cdot]_{\lceil k/n_h \rceil,\, (k-1) \% n_h + 1}$, gives:
$$S_N = \sum_{k=1}^{N n_h} \left( B_k \prod_{k'=k+1}^{N n_h} A_{k'} \right)$$
as expected.
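Here is a hedged NumPy sketch (random, hypothetical DeltaNet-style micro-updates; no claim about how DeltaProduct actually generates them) checking that the nested recurrence over tokens and steps, the flattened single recurrence over $k = 1, \dots, N n_h$, and the per-token combined updates $(A^*_i, B^*_i)$ all produce the same final state:

```python
import numpy as np

rng = np.random.default_rng(3)
d_k, d_v, N, n_h = 4, 3, 6, 2

def make_update():
    """One (A, B) micro-update; keys/values/learning rates are random stand-ins."""
    k = rng.normal(size=d_k); k /= np.linalg.norm(k)
    v = rng.normal(size=d_v)
    beta = rng.uniform(0, 1)
    return np.eye(d_k) - beta * np.outer(k, k), beta * np.outer(v, k)

# updates[i][j] = (A_{i,j}, B_{i,j}): j-th gradient-descent step of the i-th token
updates = [[make_update() for _ in range(n_h)] for _ in range(N)]

# (a) nested recurrence over tokens and steps
S = np.zeros((d_v, d_k))
for tok in updates:
    for A, B in tok:
        S = S @ A + B
S_nested = S

# (b) flattened recurrence over k = i * n_h + j
S = np.zeros((d_v, d_k))
for A, B in (u for tok in updates for u in tok):
    S = S @ A + B
S_flat = S

# (c) combine the n_h steps of each token into (A*_i, B*_i) first, then one update per token
S = np.zeros((d_v, d_k))
for tok in updates:
    A_star, B_star = np.eye(d_k), np.zeros((d_v, d_k))
    for A, B in tok:
        A_star, B_star = A_star @ A, B_star @ A + B
    S = S @ A_star + B_star
S_combined = S

assert np.allclose(S_nested, S_flat) and np.allclose(S_nested, S_combined)
```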
Now, let's derive the $S_N$ for the linear attention mechanisms in the table above, but this time, with $n_h$ gradient descent steps per token.
Since the update operations of the linear attention mechanisms we discussed above are associative (i.e., the order in which we "combine" the updates doesn't matter), we can perform the computations in multiple ways:
1. The Fully Recurrent Form, where we update the state as we loop through the tokens/update matrices one by one;
2. The Fully-Parallel Associative Scan Form, where we hierarchically combine the updates in a tree-like structure; and
3. The Chunk-wise Parallel Form (Hua et al., 2022; Sun et al., 2023), a compromise between the two where we divide the sequence into chunks first, combine the intra-chunk updates in parallel, and then combine the chunk-level updates in a recurrent manner.
At inference time, the recurrent form works best*. But at training time, we have to be more hardware-aware to squeeze out as much performance as possible. We'll discuss this in more detail in a separate post. For now, there are two important things to keep in mind:
The GPU Memory Hierarchy. NVIDIA GPUs have a “global”, high-bandwidth memory (HBM) that all threads in all processing units can access, and a smaller, shared memory (SRAM) that threads in the same processing unit can access. The shared memory, being more “local”, has a much lower latency than the HBM. Thus, as much as possible, we want to limit communications between the processing units and the HBM and use the SRAM instead.
The Tensor Cores. Modern NVIDIA GPUs have tensor cores that can perform matrix multiplications much faster than the general-purpose cores. Thus, ideally, we want to maximize the use of matrix multiplications and limit other operations.
Now, the parallel associative scan might seem like the best choice, and indeed it already suffices for some architectures like Mamba 1. However, it requires a lot more (shared) memory and communication between the processing units (and therefore materialization to the HBM), and it doesn't fully utilize the tensor cores. With chunk-wise parallelism, on the other hand, we only need to keep the current chunk-level state in shared memory and use matrix multiplications to compute the next one. This way, we don't have to materialize the intermediate $S_N$s to the HBM at all, and we can fully utilize the tensor cores. Hence, most flash linear attention kernels use chunk-wise parallelism.
*At inference time, we first need to process the input tokens before generating outputs. This is called the "pre-filling" stage, and chunk-wise parallelism works better there. After that, we can use the recurrent form to generate the outputs.
A better way to think of chunk-wise parallelism is as multi-step online gradient descent, but instead of updating the state $n_h$ times per token, we update the state $n_c$ times per chunk, where $n_c = N/C$ is the number of tokens per chunk and $C$ is the number of chunks. Thus, we just reuse our results from the previous section!
To make the connection more explicit, let's reindex Equation $(1)$ as $[\cdot]_i = [\cdot]_{\lceil i/n_c \rceil,\, (i-1) \% n_c + 1}$:
$$\begin{align*}
S_N &=
\begin{bmatrix} 0 & I \end{bmatrix}
\begin{bmatrix} A_1 & 0 \\ B_1 & I \end{bmatrix}
\begin{bmatrix} A_2 & 0 \\ B_2 & I \end{bmatrix}
\cdots
\begin{bmatrix} A_{n_c} & 0 \\ B_{n_c} & I \end{bmatrix}
\begin{bmatrix} A_{n_c+1} & 0 \\ B_{n_c+1} & I \end{bmatrix}
\cdots
\begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix}
\begin{bmatrix} I \\ 0 \end{bmatrix}\\
S_N &=
\begin{bmatrix} 0 & I \end{bmatrix}
\begin{bmatrix} A_{1,1} & 0 \\ B_{1,1} & I \end{bmatrix}
\begin{bmatrix} A_{1,2} & 0 \\ B_{1,2} & I \end{bmatrix}
\cdots
\begin{bmatrix} A_{1,n_c} & 0 \\ B_{1,n_c} & I \end{bmatrix}
\begin{bmatrix} A_{2,1} & 0 \\ B_{2,1} & I \end{bmatrix}
\cdots
\begin{bmatrix} A_{C,n_c} & 0 \\ B_{C,n_c} & I \end{bmatrix}
\begin{bmatrix} I \\ 0 \end{bmatrix}
\end{align*}$$
where $A_{c,i}$ and $B_{c,i}$ are now the update matrices for the $i$-th token within the $c$-th chunk.
And by combining the updates for each chunk as in Equation $(6)$ above, we get:
$$\begin{align}
\begin{bmatrix} A^*_c & 0 \\ B^*_c & I \end{bmatrix}
= \prod_{i=1}^{n_c} \begin{bmatrix} A_{c,i} & 0 \\ B_{c,i} & I \end{bmatrix}
= \begin{bmatrix} \prod_{i=1}^{n_c} A_{c,i} & 0 \\ \sum_{i=1}^{n_c} \left( B_{c,i} \prod_{i'=i+1}^{n_c} A_{c,i'} \right) & I \end{bmatrix}
\end{align}$$

so that

$$S_C =
\underline{
\begin{bmatrix} 0 & I \end{bmatrix}
\begin{bmatrix} A^*_1 & 0 \\ B^*_1 & I \end{bmatrix}
\begin{bmatrix} A^*_2 & 0 \\ B^*_2 & I \end{bmatrix}
\cdots
\begin{bmatrix} A^*_{C-1} & 0 \\ B^*_{C-1} & I \end{bmatrix}
}
\begin{bmatrix} A^*_C & 0 \\ B^*_C & I \end{bmatrix}
\begin{bmatrix} I \\ 0 \end{bmatrix}$$
which has the equivalent cross-chunk recurrent form:
$$\begin{align}
\begin{bmatrix} S_C & I \end{bmatrix} &=
\begin{bmatrix} S_{C-1} & I \end{bmatrix}
\begin{bmatrix} A^*_C & 0 \\ B^*_C & I \end{bmatrix}\\
S_C &= S_{C-1} A^*_C + B^*_C
\end{align}$$
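Here is a hedged NumPy sketch of the chunk-wise form with random toy updates: we first combine the updates within each chunk into $(A^*_c, B^*_c)$, then run the short cross-chunk recurrence $S_C = S_{C-1} A^*_C + B^*_C$. Real kernels express the intra-chunk combination with batched matmuls and keep the running state in SRAM; this sketch only checks the algebra against the fully recurrent form.

```python
import numpy as np

rng = np.random.default_rng(4)
d_k, d_v, n_c, C = 4, 3, 4, 5          # n_c tokens per chunk, C chunks
N = n_c * C

def make_update():
    """Random DeltaNet-style toy update (stand-in for real keys/values)."""
    k = rng.normal(size=d_k); k /= np.linalg.norm(k)
    v = rng.normal(size=d_v)
    beta = rng.uniform(0, 1)
    return np.eye(d_k) - beta * np.outer(k, k), beta * np.outer(v, k)

updates = [make_update() for _ in range(N)]

# Chunk-level updates: A*_c = prod_i A_{c,i},  B*_c = sum_i B_{c,i} prod_{i'>i} A_{c,i'}
chunk_updates = []
for c in range(C):
    A_star, B_star = np.eye(d_k), np.zeros((d_v, d_k))
    for A, B in updates[c * n_c:(c + 1) * n_c]:
        A_star, B_star = A_star @ A, B_star @ A + B    # intra-chunk combine (parallelizable)
    chunk_updates.append((A_star, B_star))

# Cross-chunk recurrence: S_c = S_{c-1} A*_c + B*_c
S_chunk = np.zeros((d_v, d_k))
for A_star, B_star in chunk_updates:
    S_chunk = S_chunk @ A_star + B_star

# Reference: fully recurrent, token by token
S_rec = np.zeros((d_v, d_k))
for A, B in updates:
    S_rec = S_rec @ A + B

assert np.allclose(S_chunk, S_rec)
```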
Now, let's derive the $S_C$ for the linear attention mechanisms in the table above.
We can do this either way, but suppose we chunk the updates first and then expand each of the updates within the chunks into a product of $n_h$ updates. That is, we have:
$$\begin{bmatrix} A_{(c-1)n_c + i} & 0 \\ B_{(c-1)n_c + i} & I \end{bmatrix}
\xrightarrow{\text{reindex}}
\begin{bmatrix} A_{c,i} & 0 \\ B_{c,i} & I \end{bmatrix}
\xrightarrow{\text{expand}}
\begin{bmatrix} A_{c,i,1} & 0 \\ B_{c,i,1} & I \end{bmatrix}
\begin{bmatrix} A_{c,i,2} & 0 \\ B_{c,i,2} & I \end{bmatrix}
\cdots
\begin{bmatrix} A_{c,i,n_h} & 0 \\ B_{c,i,n_h} & I \end{bmatrix}$$
where $A_{c,i,j}$ and $B_{c,i,j}$ are the update matrices for the $j$-th gradient descent step for the $i$-th token within the $c$-th chunk.
And from Equations $(6)$, $(10)$, and $(11)$, we have:
$$\begin{align*}
\begin{bmatrix} A^*_c & 0 \\ B^*_c & I \end{bmatrix}
&= \prod_{i=1}^{n_c} \begin{bmatrix} A'_{c,i} & 0 \\ B'_{c,i} & I \end{bmatrix}
= \prod_{i=1}^{n_c} \prod_{j=1}^{n_h} \begin{bmatrix} A_{c,i,j} & 0 \\ B_{c,i,j} & I \end{bmatrix}\\
\begin{bmatrix} A^*_c & 0 \\ B^*_c & I \end{bmatrix}
&= \begin{bmatrix}
\prod_{i=1}^{n_c} \prod_{j=1}^{n_h} A_{c,i,j} & 0 \\
\sum_{i=1}^{n_c} \sum_{j=1}^{n_h} \left( B_{c,i,j} \left( \prod_{j'=j+1}^{n_h} A_{c,i,j'} \right) \left( \prod_{i'=i+1}^{n_c} \prod_{j'=1}^{n_h} A_{c,i',j'} \right) \right) & I
\end{bmatrix}
\end{align*}$$
Thus,
$$\begin{align*}
A^*_c &= \prod_{i=1}^{n_c} \prod_{j=1}^{n_h} A_{c,i,j}\\
B^*_c &= \sum_{i=1}^{n_c} \sum_{j=1}^{n_h} \left( B_{c,i,j} \left( \prod_{j'=j+1}^{n_h} A_{c,i,j'} \right) \left( \prod_{i'=i+1}^{n_c} \prod_{j'=1}^{n_h} A_{c,i',j'} \right) \right)
\end{align*}$$
which we can then plug into Equation $(13)$ to get the cross-chunk recurrence:

$$S_C = S_{C-1} \left( \prod_{i=1}^{n_c} \prod_{j=1}^{n_h} A_{C,i,j} \right)
+ \sum_{i=1}^{n_c} \sum_{j=1}^{n_h} \left( B_{C,i,j} \left( \prod_{j'=j+1}^{n_h} A_{C,i,j'} \right) \left( \prod_{i'=i+1}^{n_c} \prod_{j'=1}^{n_h} A_{C,i',j'} \right) \right)$$
or, if we reindex this as $[\cdot]_{C,k} = [\cdot]_{C,\, \lceil k/n_h \rceil,\, (k-1) \% n_h + 1}$, we get:

$$S_C = S_{C-1} \left( \prod_{k=1}^{n_c n_h} A_{C,k} \right)
+ \sum_{k=1}^{n_c n_h} \left( B_{C,k} \prod_{k'=k+1}^{n_c n_h} A_{C,k'} \right)$$
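Finally, here is a hedged NumPy sketch (same kind of random toy setup as before) checking the closed-form expressions for $A^*_c$ and $B^*_c$ above against a triple-nested recurrence over chunks, tokens, and gradient-descent steps:

```python
import numpy as np

rng = np.random.default_rng(5)
d_k, d_v, C, n_c, n_h = 4, 3, 3, 2, 2

def make_update():
    """Random DeltaNet-style toy micro-update (stand-in for real keys/values)."""
    k = rng.normal(size=d_k); k /= np.linalg.norm(k)
    v = rng.normal(size=d_v)
    beta = rng.uniform(0, 1)
    return np.eye(d_k) - beta * np.outer(k, k), beta * np.outer(v, k)

# updates[c][i][j] = (A_{c,i,j}, B_{c,i,j})
updates = [[[make_update() for _ in range(n_h)] for _ in range(n_c)] for _ in range(C)]

def suffix_A(c, i, j):
    """Product of all A's in chunk c that come strictly after step (i, j)."""
    P = np.eye(d_k)
    for jp in range(j + 1, n_h):
        P = P @ updates[c][i][jp][0]
    for ip in range(i + 1, n_c):
        for jp in range(n_h):
            P = P @ updates[c][ip][jp][0]
    return P

# Closed form: A*_c = prod_{i,j} A_{c,i,j};  B*_c = sum_{i,j} B_{c,i,j} (suffix product of later A's)
S_closed = np.zeros((d_v, d_k))
for c in range(C):
    A_star = np.eye(d_k)
    for i in range(n_c):
        for j in range(n_h):
            A_star = A_star @ updates[c][i][j][0]
    B_star = sum(updates[c][i][j][1] @ suffix_A(c, i, j)
                 for i in range(n_c) for j in range(n_h))
    S_closed = S_closed @ A_star + B_star

# Reference: triple-nested recurrence
S_rec = np.zeros((d_v, d_k))
for c in range(C):
    for i in range(n_c):
        for j in range(n_h):
            A, B = updates[c][i][j]
            S_rec = S_rec @ A + B

assert np.allclose(S_closed, S_rec)
```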
Not only is the blocked matrix formulation of linear attention mechanisms intuitive, it also makes the connections between different algorithms and computational forms much more obvious. I’d even go as far as to say that we now have the proper abstraction to do an evolutionary search for new linear attention mechanisms ;)
In the next post, we'll talk about faster ways to calculate $A^*_c$ and $B^*_c$ for diagonal and diagonal-plus-low-rank $A^*_c$ using the WY Representation and the UT Transform. Stay tuned!
@misc{cesista2025blockmatlinearattn,
author = {Franz Louis Cesista},
title = {Blocked Matrix Formulation of Linear Attention Mechanisms},
year = {2025},
url = {https://leloykun.github.io/ponder/blockmat-linear-attn/},
}
Riccardo Grazzi, Julien Siems, Jörg K. H. Franke, Arber Zela, Frank Hutter, and Massimiliano Pontil (2025). Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues. URL https://arxiv.org/abs/2411.12537
Julien Siems, Timur Carstensen, Arber Zela, Frank Hutter, Massimiliano Pontil, and Riccardo Grazzi (2025). DeltaProduct: Increasing the Expressivity of DeltaNet Through Products of Householders. URL https://arxiv.org/abs/2502.10297
Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. In Proceedings of the 37th International Conference on Machine Learning (ICML 2020), volume 119 of Proceedings of Machine Learning Research, pp. 5156–5165. PMLR. URL http://proceedings.mlr.press/v119/katharopoulos20a.html
Tri Dao and Albert Gu (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. In Proceedings of the 41st International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pp. 10041–10071. PMLR. URL https://proceedings.mlr.press/v235/dao24a.html
Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, and Yoon Kim (2025). Parallelizing Linear Transformers with the Delta Rule over Sequence Length. URL https://arxiv.org/abs/2406.06484
Songlin Yang, Jan Kautz, and Ali Hatamizadeh (2025). Gated Delta Networks: Improving Mamba2 with Delta Rule. URL https://arxiv.org/abs/2412.06464
Weizhe Hua, Zihang Dai, Hanxiao Liu, and Quoc V. Le (2022). Transformer Quality in Linear Time. In International Conference on Machine Learning (ICML 2022), volume 162 of Proceedings of Machine Learning Research, pp. 9099–9117. PMLR. URL https://proceedings.mlr.press/v162/hua22a.html
Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and Furu Wei (2023). Retentive Network: A Successor to Transformer for Large Language Models. arXiv preprint arXiv:2307.08621. URL https://arxiv.org/abs/2307.08621
Bo Peng, Ruichong Zhang, Daniel Goldstein, Eric Alcaide, Haowen Hou, Janna Lu, William Merrill, Guangyu Song, Kaifeng Tan, Saiteja Utpala, Nathan Wilce, Johan S. Wind, Tianyi Wu, Daniel Wuttke, and Christian Zhou-Zheng (2025). RWKV-7 "Goose" with Expressive Dynamic State Evolution. URL https://arxiv.org/abs/2503.14456