In the previous post, we derived several linear attention mechanisms from scratch by formulating them as test-time online regression problems. Here, we'll discuss a more intuitive way to represent the update rules of the internal states of these linear attention mechanisms using a blocked matrix formulation. Then, we'll use it to (1) derive the update rules for linear attention mechanisms that take multiple gradient descent steps per token and (2) derive chunk-wise parallel forms of existing linear attention mechanisms.

Recap: Linear Attention Mechanisms

Linear attention mechanisms typically have an update rule of the form

$$
S_i = S_{i-1}A_i + B_i
$$

where $S_{i-1}$ is the (old) state after processing the first $i-1$ tokens, $S_i$ is the (new) state after processing the first $i$ tokens, and $A_i$ and $B_i$ are update matrices. Think of $A_i$ as an operation that modifies some information already stored in the state while $B_i$ adds new information to the state. In most cases where $A_i \neq I$, $A_i$ typically removes some (old) information from the state. But if we allow $A_i$ to have negative eigenvalues, then we can also think of it as an operation that, in a sense, inverts information instead.

Here are a couple of examples:

| Linear Attention Mechanism | $A_i$ | $B_i$ |
| --- | --- | --- |
| Vanilla Linear Attention | $I$ | $\bm{v}_i \bm{k}_i^T$ |
| Mamba 2 | $\text{diag}\left(\alpha_i I\right)$ | $\bm{v}_i \bm{k}_i^T$ |
| DeltaNet | $I - \beta_i \bm{k}_i \bm{k}_i^T$ | $\beta_i \bm{v}_i \bm{k}_i^T$ |
| Gated DeltaNet | $\alpha_i(I - \beta_i \bm{k}_i \bm{k}_i^T)$ | $\beta_i \bm{v}_i \bm{k}_i^T$ |
| RWKV-7 | $\text{diag}(\bm{w}_i) - \bm{\hat{\kappa}}_i(\bm{a}_i \odot \bm{\hat{\kappa}}_i^T)$ | $\bm{v}_i \bm{k}_i^T$ |

where $\bm{k}_i \in \mathbb{R}^{d_k}$ and $\bm{v}_i \in \mathbb{R}^{d_v}$ are the corresponding key-value pair for the $i$-th token; $\alpha_i \in [0, 1]$ can be thought of as a data-dependent weight decay that controls how much of the previous state to keep or forget; and $\beta_i \in [0, 1]$ can be thought of as a data-dependent learning rate that controls how much of the new information to add to the state.

If we let $\alpha_i \in [-1, 1]$ for Mamba 2 and $\beta_i \in [0, 2]$ for (Gated) DeltaNet, then $A_i$ can have negative eigenvalues while still having norm $\|A_i\| \leq 1$. This allows the models to learn more complex patterns while maintaining training stability (Grazzi et al., 2025).
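To make the table concrete, here is a minimal NumPy sketch of the generic recurrence $S_i = S_{i-1}A_i + B_i$, instantiated for DeltaNet and Gated DeltaNet. The toy shapes, random inputs, and helper names are illustrative choices of mine, not taken from any particular codebase.

```python
import numpy as np

def recurrent_update(A_list, B_list, d_v, d_k):
    """Generic linear-attention recurrence: S_i = S_{i-1} @ A_i + B_i, with S_0 = 0."""
    S = np.zeros((d_v, d_k))
    for A_i, B_i in zip(A_list, B_list):
        S = S @ A_i + B_i
    return S

# Toy dimensions: N tokens, key dim d_k, value dim d_v.
rng = np.random.default_rng(0)
N, d_k, d_v = 8, 4, 3
k = rng.standard_normal((N, d_k))
k /= np.linalg.norm(k, axis=-1, keepdims=True)  # unit-norm keys keep ||I - beta k k^T|| <= 1 for beta in [0, 1]
v = rng.standard_normal((N, d_v))
alpha = rng.uniform(0.0, 1.0, N)                # data-dependent decay
beta = rng.uniform(0.0, 1.0, N)                 # data-dependent learning rate

# DeltaNet: A_i = I - beta_i k_i k_i^T,  B_i = beta_i v_i k_i^T
A_delta = [np.eye(d_k) - b * np.outer(ki, ki) for b, ki in zip(beta, k)]
B_delta = [b * np.outer(vi, ki) for b, vi, ki in zip(beta, v, k)]
S_deltanet = recurrent_update(A_delta, B_delta, d_v, d_k)

# Gated DeltaNet: same B_i, but A_i is additionally scaled by the decay alpha_i.
A_gated = [a * A_i for a, A_i in zip(alpha, A_delta)]
S_gated = recurrent_update(A_gated, B_delta, d_v, d_k)
print(S_deltanet.shape, S_gated.shape)  # (3, 4) (3, 4)
```

Swapping in the other rows of the table is just a matter of changing how `A_list` and `B_list` are built.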

Blocked Matrix Formulation of Linear Attention Mechanisms

Notice that we can rewrite the update rule above as,

$$
\begin{align*}
S_i &= S_{i-1}A_i + B_i\\
S_{i} &= \begin{bmatrix} S_{i-1} & I \end{bmatrix} \begin{bmatrix} A_i \\ B_i \end{bmatrix}
\end{align*}
$$

or, equivalently,

$$
\begin{bmatrix} S_{i} & I \end{bmatrix} = \begin{bmatrix} S_{i-1} & I \end{bmatrix} \begin{bmatrix} A_i & 0 \\ B_i & I \end{bmatrix}
$$

At training time, we need all of the intermediary states, not just the final state. Thus, we need an efficient way to compute $S_N$ for all token indices $N$. To do this, let's unroll the recurrence above:

$$
\begin{align*}
\begin{bmatrix} S_{N} & I \end{bmatrix} &= \begin{bmatrix} S_{N-1} & I \end{bmatrix} \begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix}\\
\begin{bmatrix} S_{N} & I \end{bmatrix} &= \begin{bmatrix} S_{N-2} & I \end{bmatrix} \begin{bmatrix} A_{N-1} & 0 \\ B_{N-1} & I \end{bmatrix} \begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix}\\
&\vdots\\
\begin{bmatrix} S_{N} & I \end{bmatrix} &= \begin{bmatrix} S_{0} & I \end{bmatrix} \begin{bmatrix} A_1 & 0 \\ B_1 & I \end{bmatrix} \begin{bmatrix} A_2 & 0 \\ B_2 & I \end{bmatrix} \cdots \begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix}\\
S_N &= \begin{bmatrix} S_{0} & I \end{bmatrix} \begin{bmatrix} A_1 & 0 \\ B_1 & I \end{bmatrix} \begin{bmatrix} A_2 & 0 \\ B_2 & I \end{bmatrix} \cdots \begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix} \begin{bmatrix} I \\ 0 \end{bmatrix}
\end{align*}
$$

In practice, we usually initialize $S_0$ as the zero matrix. Thus,

$$
\begin{align}
S_N &= \begin{bmatrix} 0 & I \end{bmatrix} \begin{bmatrix} A_1 & 0 \\ B_1 & I \end{bmatrix} \begin{bmatrix} A_2 & 0 \\ B_2 & I \end{bmatrix} \cdots \begin{bmatrix} A_N & 0 \\ B_N & I \end{bmatrix} \begin{bmatrix} I \\ 0 \end{bmatrix} \tag{1}\\
S_N &= \begin{bmatrix} 0 & I \end{bmatrix} \begin{bmatrix} \prod_{i=1}^{N} A_i & 0 \\ \sum_{i=1}^{N} \left(B_i \prod_{j=i+1}^{N} A_j\right) & I \end{bmatrix} \begin{bmatrix} I \\ 0 \end{bmatrix} \tag{2}\\
S_N &= \sum_{i=1}^{N} \left(B_i \prod_{j=i+1}^{N} A_j\right) \tag{3}
\end{align}
$$

where $(1) \rightarrow (2)$ can be proven by induction.

Equation $(1)$ makes it obvious why and how we can parallelize the computation of $S_N$, for all $N$, at training time: the updates are merely (blocked) matrix multiplications; matrix multiplications are associative; thus, we can use the (fully-parallel) associative scan algorithm to compute all the intermediary states with only $O(N)$ total work (and $O(\log N)$ parallel depth)!
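As a sanity check on this claim, here is a small sketch (assuming JAX is available; the toy shapes and the `combine` helper are my own) that computes every prefix state $S_1, \dots, S_N$ with `jax.lax.associative_scan`, using the fact that composing two blocked updates gives $(A_1, B_1) \circ (A_2, B_2) = (A_1 A_2,\; B_1 A_2 + B_2)$:

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    """Compose two blocked updates; `left` comes earlier in the sequence."""
    A_l, B_l = left
    A_r, B_r = right
    return A_l @ A_r, B_l @ A_r + B_r  # batched matmul over the leading (scan) axis

N, d_k, d_v = 8, 4, 3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A = jnp.eye(d_k)[None] * jax.random.uniform(k1, (N, 1, 1))                   # Mamba-2-style A_i = alpha_i I
B = jax.random.normal(k2, (N, d_v, 1)) @ jax.random.normal(k3, (N, 1, d_k))  # B_i = v_i k_i^T

# Prefix "products" of blocked updates; with S_0 = 0, the second output is exactly S_1, ..., S_N.
_, S_all = jax.lax.associative_scan(combine, (A, B))

# Cross-check against the sequential recurrence.
S, S_seq = jnp.zeros((d_v, d_k)), []
for i in range(N):
    S = S @ A[i] + B[i]
    S_seq.append(S)
assert jnp.allclose(S_all, jnp.stack(S_seq), atol=1e-4)
```

The point here is only that the blocked updates compose associatively, so any scan order gives the same states; a real kernel would not materialize a dense $A_i$ per token like this.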

One-Step Online Gradient Descent per Token

Let’s derive SNS_N for each of the linear attention mechanisms in the table above.

Vanilla Linear Attention

Derivation of $S_N$:

$$
A_i = I \quad\quad B_i = \bm{v}_i \bm{k}_i^T
$$

From Equation $(3)$ above, we get:

$$
\begin{align*}
S_N &= \sum_{i=1}^{N} \left(\bm{v}_i \bm{k}_i^T \prod_{j=i+1}^{N} I\right)\\
S_N &= \sum_{i=1}^{N} \bm{v}_i \bm{k}_i^T
\end{align*}
$$

Mamba 2

Derivation of $S_N$:

$$
A_i = \text{diag}\left(\alpha_i I\right) \quad\quad B_i = \bm{v}_i \bm{k}_i^T
$$

Thus,

$$
\begin{align*}
S_N &= \sum_{i=1}^{N} \left(\bm{v}_i \bm{k}_i^T \prod_{j=i+1}^{N} \text{diag}\left(\alpha_j I\right)\right)\\
S_N &= \sum_{i=1}^{N} \left( \prod_{j=i+1}^{N} \alpha_j \right) \bm{v}_i \bm{k}_i^T
\end{align*}
$$

DeltaNet

Derivation of $S_N$:

$$
A_i = I - \beta_i \bm{k}_i \bm{k}_i^T \quad\quad B_i = \beta_i \bm{v}_i \bm{k}_i^T
$$

Thus,

$$
S_N = \sum_{i=1}^{N} \left(\beta_i \bm{v}_i \bm{k}_i^T \prod_{j=i+1}^{N} \left(I - \beta_j \bm{k}_j \bm{k}_j^T\right)\right)
$$

Gated DeltaNet

Derivation of $S_N$:

$$
A_i = \alpha_i(I - \beta_i \bm{k}_i \bm{k}_i^T) \quad\quad B_i = \beta_i \bm{v}_i \bm{k}_i^T
$$

Thus,

$$
\begin{align*}
S_N &= \sum_{i=1}^{N} \left(\beta_i \bm{v}_i \bm{k}_i^T \prod_{j=i+1}^{N} \alpha_j \left(I - \beta_j \bm{k}_j \bm{k}_j^T\right)\right)\\
S_N &= \sum_{i=1}^{N} \left(\left(\beta_i \prod_{j=i+1}^{N} \alpha_j \right) \bm{v}_i \bm{k}_i^T \prod_{j=i+1}^{N} \left(I - \beta_j \bm{k}_j \bm{k}_j^T\right)\right)
\end{align*}
$$

RWKV-7

Derivation of $S_N$:

$$
A_i = \text{diag}(\bm{w}_i) - \bm{\hat{\kappa}}_i(\bm{a}_i \odot \bm{\hat{\kappa}}_i^T) \quad\quad B_i = \bm{v}_i \bm{k}_i^T
$$

Thus,

$$
S_N = \sum_{i=1}^{N} \left(\bm{v}_i \bm{k}_i^T \prod_{j=i+1}^{N} \left(\text{diag}(\bm{w}_j) - \bm{\hat{\kappa}}_j(\bm{a}_j \odot \bm{\hat{\kappa}}_j^T)\right)\right)
$$

Easy!
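Here is a quick NumPy check of the DeltaNet case (toy shapes of my own choosing): the closed form from Equation $(3)$ matches the token-by-token recurrence.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d_k, d_v = 6, 4, 3
k = rng.standard_normal((N, d_k)); k /= np.linalg.norm(k, axis=-1, keepdims=True)
v = rng.standard_normal((N, d_v))
beta = rng.uniform(0.0, 1.0, N)

# Closed form: S_N = sum_i beta_i v_i k_i^T prod_{j>i} (I - beta_j k_j k_j^T)
S_closed = np.zeros((d_v, d_k))
for i in range(N):
    tail = np.eye(d_k)
    for j in range(i + 1, N):
        tail = tail @ (np.eye(d_k) - beta[j] * np.outer(k[j], k[j]))
    S_closed += beta[i] * np.outer(v[i], k[i]) @ tail

# Recurrent form: S_i = S_{i-1} (I - beta_i k_i k_i^T) + beta_i v_i k_i^T
S_rec = np.zeros((d_v, d_k))
for i in range(N):
    S_rec = S_rec @ (np.eye(d_k) - beta[i] * np.outer(k[i], k[i])) + beta[i] * np.outer(v[i], k[i])

assert np.allclose(S_closed, S_rec)
```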


Multi-Step Online Gradient Descent per Token

Now, what if we take $n_h$ gradient descent steps per token?

To do this, we can follow the procedure outlined in the DeltaProduct (Siems et al., 2025) paper where they:

  1. Recurrently generate $n_h$ key-value pairs for each input token,
  2. Update the state using the $n_h$ key-value pairs, and
  3. Keep only the final key-value pair and discard the rest.

In our formulation, this is equivalent to replacing each update with a product of $n_h$ updates:

$$
\begin{bmatrix} A_{i} & 0 \\ B_{i} & I \end{bmatrix} \longrightarrow \begin{bmatrix} A_{i,1} & 0 \\ B_{i,1} & I \end{bmatrix} \begin{bmatrix} A_{i,2} & 0 \\ B_{i,2} & I \end{bmatrix} \cdots \begin{bmatrix} A_{i,n_h} & 0 \\ B_{i,n_h} & I \end{bmatrix}
$$

where $A_{i,j}$ and $B_{i,j}$ are the update matrices for the $j$-th gradient descent step for the $i$-th token.

Thus, Equation $(1)$ becomes:

$$
S_N = \begin{bmatrix} 0 & I \end{bmatrix} \begin{bmatrix} A_{1,1} & 0 \\ B_{1,1} & I \end{bmatrix} \begin{bmatrix} A_{1,2} & 0 \\ B_{1,2} & I \end{bmatrix} \cdots \begin{bmatrix} A_{1,n_h} & 0 \\ B_{1,n_h} & I \end{bmatrix} \begin{bmatrix} A_{2,1} & 0 \\ B_{2,1} & I \end{bmatrix} \cdots \begin{bmatrix} A_{N, n_h} & 0 \\ B_{N, n_h} & I \end{bmatrix} \begin{bmatrix} I \\ 0 \end{bmatrix} \tag{4}
$$

And if we reindex this as $[\cdot]_k = [\cdot]_{\lceil k/n_h \rceil,\ (k-1) \% n_h + 1}$, then from Equation $(3)$ above, we get:

$$
S_N = \sum_{k=1}^{Nn_h} \left( B_k \prod_{k'=k+1}^{Nn_h} A_{k'}\right) \tag{5}
$$

Alternatively, we can also combine the updates for each token into a single update matrix first before multiplying them together:

$$
\begin{bmatrix} A'_{i} & 0 \\ B'_{i} & I \end{bmatrix} = \prod_{j=1}^{n_h} \begin{bmatrix} A_{i,j} & 0 \\ B_{i,j} & I \end{bmatrix} = \begin{bmatrix} \prod_{j=1}^{n_h} A_{i,j} & 0 \\ \sum_{j=1}^{n_h} \left(B_{i,j} \prod_{j'=j+1}^{n_h} A_{i,j'}\right) & I \end{bmatrix} \tag{6}
$$

$$
\begin{align}
S_N &= \begin{bmatrix} 0 & I \end{bmatrix} \begin{bmatrix} A'_1 & 0 \\ B'_1 & I \end{bmatrix} \begin{bmatrix} A'_2 & 0 \\ B'_2 & I \end{bmatrix} \cdots \begin{bmatrix} A'_N & 0 \\ B'_N & I \end{bmatrix} \begin{bmatrix} I \\ 0 \end{bmatrix} \tag{7}\\
S_N &= \begin{bmatrix} 0 & I \end{bmatrix} \begin{bmatrix} \prod_{i=1}^N A'_i & 0 \\ \sum_{i=1}^N \left( B'_i \prod_{i'=i+1}^N A'_{i'} \right) & I \end{bmatrix} \begin{bmatrix} I \\ 0 \end{bmatrix} \tag{8}\\
S_N &= \sum_{i=1}^N \left( B'_i \prod_{i'=i+1}^N A'_{i'} \right) \tag{9}\\
S_N &= \sum_{i=1}^N \sum_{j=1}^{n_h} \left( B_{i,j} \underline{\left(\prod_{j'=j+1}^{n_h} A_{i,j'}\right) \left(\prod_{i'=i+1}^N \prod_{j'=1}^{n_h} A_{i',j'} \right)}\right) \tag{10}
\end{align}
$$

which, again, if we reindex this as $[\cdot]_k = [\cdot]_{\lceil k/n_h \rceil,\ (k-1) \% n_h + 1}$, we get:

$$
S_N = \sum_{k=1}^{Nn_h} \left( B_k \prod_{k'=k+1}^{Nn_h} A_{k'}\right)
$$

as expected.
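A small NumPy sketch of this equivalence (toy shapes; `combine_token` is a hypothetical helper name of mine): folding the $n_h$ per-token micro-updates into a single $(A'_i, B'_i)$ as in Equation $(6)$ and then running the per-token recurrence gives the same state as running all $N n_h$ micro-updates one by one.

```python
import numpy as np

rng = np.random.default_rng(2)
N, n_h, d_k, d_v = 5, 3, 4, 2
# One DeltaProduct-style micro-update per (token, step): A = I - beta k k^T, B = beta v k^T
k = rng.standard_normal((N, n_h, d_k)); k /= np.linalg.norm(k, axis=-1, keepdims=True)
v = rng.standard_normal((N, n_h, d_v))
beta = rng.uniform(0.0, 1.0, (N, n_h))
A = np.eye(d_k) - beta[..., None, None] * np.einsum('ijk,ijl->ijkl', k, k)  # (N, n_h, d_k, d_k)
B = beta[..., None, None] * np.einsum('ijk,ijl->ijkl', v, k)                # (N, n_h, d_v, d_k)

def combine_token(A_i, B_i):
    """Fold the n_h micro-updates of one token into a single (A'_i, B'_i) as in Equation (6)."""
    A_acc, B_acc = np.eye(A_i.shape[-1]), np.zeros_like(B_i[0])
    for j in range(A_i.shape[0]):
        A_acc, B_acc = A_acc @ A_i[j], B_acc @ A_i[j] + B_i[j]
    return A_acc, B_acc

# Path 1: flatten to N * n_h sequential micro-updates.
S_flat = np.zeros((d_v, d_k))
for i in range(N):
    for j in range(n_h):
        S_flat = S_flat @ A[i, j] + B[i, j]

# Path 2: collapse per token first, then run N sequential updates.
S_tok = np.zeros((d_v, d_k))
for i in range(N):
    A_p, B_p = combine_token(A[i], B[i])
    S_tok = S_tok @ A_p + B_p

assert np.allclose(S_flat, S_tok)
```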


Now, let's derive $S_N$ for the linear attention mechanisms in the table above, but this time, with $n_h$ gradient descent steps per token.

MambaSum*

Derivation of $S_N$:

$$
A_{i,j} = \text{diag}\left(\alpha_{i,j} I\right) \quad\quad B_{i,j} = \bm{v}_{i,j} \bm{k}_{i,j}^T
$$

Thus, from Equation $(10)$ above,

$$
\begin{align*}
S_N &= \sum_{i=1}^N \sum_{j=1}^{n_h} \left( \bm{v}_{i,j} \bm{k}_{i,j}^T \left(\prod_{j'=j+1}^{n_h} \text{diag}\left(\alpha_{i,j'} I\right)\right) \left(\prod_{i'=i+1}^N \prod_{j'=1}^{n_h} \text{diag}\left(\alpha_{i',j'} I\right) \right)\right)\\
S_N &= \sum_{i=1}^N \sum_{j=1}^{n_h} \left(\underline{\left( \prod_{j'=j+1}^{n_h} \alpha_{i,j'}\right) \left(\prod_{i'=i+1}^N \prod_{j'=1}^{n_h} \alpha_{i',j'} \right)} \right) \bm{v}_{i,j} \bm{k}_{i,j}^T\\
S_N &= \sum_{k=1}^{Nn_h} \left(\prod_{k'=k+1}^{Nn_h} \alpha_{k'}\right) \bm{v}_k \bm{k}_k^T
\end{align*}
$$

*I’m not actually sure if MambaSum already exists under a different name. If it does, please let me know!

DeltaProduct

Derivation of $S_N$:

$$
A_{i,j} = I - \beta_{i,j} \bm{k}_{i,j} \bm{k}_{i,j}^T \quad\quad B_{i,j} = \beta_{i,j} \bm{v}_{i,j} \bm{k}_{i,j}^T
$$

Thus,

$$
\begin{align*}
S_N &= \sum_{i=1}^N \sum_{j=1}^{n_h} \left( \beta_{i,j} \bm{v}_{i,j} \bm{k}_{i,j}^T \underline{\left(\prod_{j'=j+1}^{n_h} \left(I - \beta_{i,j'} \bm{k}_{i,j'} \bm{k}_{i,j'}^T\right)\right) \left(\prod_{i'=i+1}^N \prod_{j'=1}^{n_h} \left(I - \beta_{i',j'} \bm{k}_{i',j'} \bm{k}_{i',j'}^T\right) \right)}\right)\\
S_N &= \sum_{k=1}^{Nn_h} \left(\beta_k \bm{v}_k \bm{k}_k^T \prod_{k'=k+1}^{Nn_h} \left(I - \beta_{k'} \bm{k}_{k'} \bm{k}_{k'}^T\right)\right)
\end{align*}
$$

Gated DeltaProduct

Derivation of $S_N$:

$$
A_{i,j} = \alpha_{i,j}(I - \beta_{i,j} \bm{k}_{i,j} \bm{k}_{i,j}^T) \quad\quad B_{i,j} = \beta_{i,j} \bm{v}_{i,j} \bm{k}_{i,j}^T
$$

Thus,

$$
\begin{align*}
S_N &= \sum_{i=1}^N \sum_{j=1}^{n_h} \left( \beta_{i,j} \bm{v}_{i,j} \bm{k}_{i,j}^T \underline{\left(\prod_{j'=j+1}^{n_h} \alpha_{i,j'} \left(I - \beta_{i,j'} \bm{k}_{i,j'} \bm{k}_{i,j'}^T\right)\right) \left(\prod_{i'=i+1}^N \prod_{j'=1}^{n_h} \alpha_{i',j'} \left(I - \beta_{i',j'} \bm{k}_{i',j'} \bm{k}_{i',j'}^T\right) \right)}\right)\\
S_N &= \sum_{k=1}^{Nn_h} \left(\beta_k \bm{v}_k \bm{k}_k^T \prod_{k'=k+1}^{Nn_h} \alpha_{k'} \left(I - \beta_{k'} \bm{k}_{k'} \bm{k}_{k'}^T\right)\right)\\
S_N &= \sum_{k=1}^{Nn_h} \left(\left( \beta_k \prod_{k'=k+1}^{Nn_h} \alpha_{k'} \right) \bm{v}_k \bm{k}_k^T \prod_{k'=k+1}^{Nn_h} \left(I - \beta_{k'} \bm{k}_{k'} \bm{k}_{k'}^T\right)\right)
\end{align*}
$$

RWKV-7P

Derivation of $S_N$:

$$
A_{i,j} = \text{diag}(\bm{w}_{i,j}) - \bm{\hat{\kappa}}_{i,j}(\bm{a}_{i,j} \odot \bm{\hat{\kappa}}_{i,j}^T) \quad\quad B_{i,j} = \bm{v}_{i,j} \bm{k}_{i,j}^T
$$

Thus,

$$
\begin{align*}
S_N &= \sum_{i=1}^N \sum_{j=1}^{n_h} \left( \bm{v}_{i,j} \bm{k}_{i,j}^T \underline{\left(\prod_{j'=j+1}^{n_h} \left(\text{diag}(\bm{w}_{i,j'}) - \bm{\hat{\kappa}}_{i,j'}(\bm{a}_{i,j'} \odot \bm{\hat{\kappa}}_{i,j'}^T)\right)\right) \left(\prod_{i'=i+1}^N \prod_{j'=1}^{n_h} \left(\text{diag}(\bm{w}_{i',j'}) - \bm{\hat{\kappa}}_{i',j'}(\bm{a}_{i',j'} \odot \bm{\hat{\kappa}}_{i',j'}^T)\right) \right)}\right)\\
S_N &= \sum_{k=1}^{Nn_h} \left(\bm{v}_k \bm{k}_k^T \prod_{k'=k+1}^{Nn_h} \left(\text{diag}(\bm{w}_{k'}) - \bm{\hat{\kappa}}_{k'}(\bm{a}_{k'} \odot \bm{\hat{\kappa}}_{k'}^T)\right)\right)
\end{align*}
$$


Chunk-wise Parallelism

Since the update operations of the linear attention mechanisms we discussed above are associative (i.e., the order in which we "combine" the updates doesn't matter), we can perform the computations in multiple ways:

  1. The Fully Recurrent Form where we update the state as we loop through the tokens/update matrices one by one,
  2. The Fully-Parallel Associative Scan Form where we hierarchically combine the updates in a tree-like structure, and
  3. The Chunk-wise Parallel Form (Hua et al., 2022; Sun et al., 2023) which is a compromise between the two where we divide the sequence into chunks first, combine intra-chunk updates in parallel, and then combine the chunk-level updates in a recurrent manner.

At inference time, the recurrent form works best*. But at training time, we have to be more hardware-aware to squeeze out as much performance as possible. We will discuss more about this in a separate post. But for now, there are two important things to keep in mind:

  1. The GPU Memory Hierarchy. NVIDIA GPUs have a “global”, high-bandwidth memory (HBM) that all threads in all processing units can access, and a smaller, shared memory (SRAM) that threads in the same processing unit can access. The shared memory, being more “local”, has a much lower latency than the HBM. Thus, as much as possible, we want to limit communications between the processing units and the HBM and use the SRAM instead.
  2. The Tensor Cores. Modern NVIDIA GPUs have tensor cores that can perform matrix multiplications much faster. Thus, ideally, we want to maximize the use of matrix multiplications and limit other operations.

Now, parallel associative scan might seem like the best choice, and indeed it already suffices for some architectures like Mamba 1. However, it requires a lot more (shared) memory and communication between the processing units (and therefore materialization to the HBM), and it doesn't fully utilize the tensor cores. With chunk-wise parallelism, on the other hand, we only need to store the current state in shared memory and use matrix multiplications to compute the next chunk-level state. This way, we don't have to materialize the intermediary $S_N$s to the HBM at all, and we can fully utilize the tensor cores. This is why most flash linear attention kernels use chunk-wise parallelism.

*At inference time, we first need to process the input tokens before generating outputs. This is called the "pre-filling" stage, and chunk-wise parallelism works better there. After that, we can use the recurrent form to generate the outputs.


A better way to think of chunk-wise parallelism is as multi-step online gradient descent, but instead of updating the state $n_h$ times per token, we update the state $n_c$ times per chunk, where $n_c = N/C$ is the number of tokens per chunk and $C$ is the number of chunks. Thus, we can just reuse our results from the previous section!

To make the connection more explicit, let's reindex Equation $(1)$ as $[\cdot]_i = [\cdot]_{\lceil i/n_c \rceil,\ (i-1) \% n_c + 1}$:

$$
\begin{align*}
S_N &= \begin{bmatrix} 0 & I \end{bmatrix} \begin{bmatrix} A_{1} & 0 \\ B_{1} & I \end{bmatrix} \begin{bmatrix} A_{2} & 0 \\ B_{2} & I \end{bmatrix} \cdots \begin{bmatrix} A_{n_c} & 0 \\ B_{n_c} & I \end{bmatrix} \begin{bmatrix} A_{n_c + 1} & 0 \\ B_{n_c + 1} & I \end{bmatrix} \cdots \begin{bmatrix} A_{N} & 0 \\ B_{N} & I \end{bmatrix} \begin{bmatrix} I \\ 0 \end{bmatrix}\\
S_N &= \begin{bmatrix} 0 & I \end{bmatrix} \begin{bmatrix} A_{1,1} & 0 \\ B_{1,1} & I \end{bmatrix} \begin{bmatrix} A_{1,2} & 0 \\ B_{1,2} & I \end{bmatrix} \cdots \begin{bmatrix} A_{1,n_c} & 0 \\ B_{1,n_c} & I \end{bmatrix} \begin{bmatrix} A_{2,1} & 0 \\ B_{2,1} & I \end{bmatrix} \cdots \begin{bmatrix} A_{C, n_c} & 0 \\ B_{C, n_c} & I \end{bmatrix} \begin{bmatrix} I \\ 0 \end{bmatrix}
\end{align*}
$$

where $A_{c,i}$ and $B_{c,i}$ are now the update matrices for the $i$-th token within the $c$-th chunk.

And by combining the updates for each chunk as in Equation $(6)$ above, we get:

$$
\begin{bmatrix} A^*_{c} & 0 \\ B^*_{c} & I \end{bmatrix} = \prod_{i=1}^{n_c} \begin{bmatrix} A_{c,i} & 0 \\ B_{c,i} & I \end{bmatrix} = \begin{bmatrix} \prod_{i=1}^{n_c} A_{c,i} & 0 \\ \sum_{i=1}^{n_c} \left(B_{c,i} \prod_{i'=i+1}^{n_c} A_{c,i'}\right) & I \end{bmatrix} \tag{11}
$$

$$
S_C = \underline{ \begin{bmatrix} 0 & I \end{bmatrix} \begin{bmatrix} A^*_1 & 0 \\ B^*_1 & I \end{bmatrix} \begin{bmatrix} A^*_2 & 0 \\ B^*_2 & I \end{bmatrix} \cdots \begin{bmatrix} A^*_{C-1} & 0 \\ B^*_{C-1} & I \end{bmatrix} } \begin{bmatrix} A^*_C & 0 \\ B^*_C & I \end{bmatrix} \begin{bmatrix} I \\ 0 \end{bmatrix}
$$

which has the equivalent cross-chunk recurrent form:

$$
\begin{align}
\begin{bmatrix} S_{C} & I \end{bmatrix} &= \begin{bmatrix} S_{C-1} & I \end{bmatrix} \begin{bmatrix} A^*_C & 0 \\ B^*_C & I \end{bmatrix} \tag{12}\\
S_C &= S_{C-1}A^*_C + B^*_C \tag{13}
\end{align}
$$
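To see the cross-chunk recurrence in code, here is a minimal NumPy sketch (chunk size, toy shapes, and the `chunk_update` helper name are my own): per chunk we build $A^*_c$ and $B^*_c$ as in Equation $(11)$, run the recurrence of Equation $(13)$, and check the result against the plain token-by-token recurrence.

```python
import numpy as np

rng = np.random.default_rng(3)
C, n_c, d_k, d_v = 4, 8, 4, 2          # C chunks of n_c tokens each; N = C * n_c
k = rng.standard_normal((C, n_c, d_k)); k /= np.linalg.norm(k, axis=-1, keepdims=True)
v = rng.standard_normal((C, n_c, d_v))
beta = rng.uniform(0.0, 1.0, (C, n_c))
# DeltaNet updates, indexed by (chunk, position-in-chunk)
A = np.eye(d_k) - beta[..., None, None] * np.einsum('cik,cil->cikl', k, k)  # (C, n_c, d_k, d_k)
B = beta[..., None, None] * np.einsum('cik,cil->cikl', v, k)                # (C, n_c, d_v, d_k)

def chunk_update(A_c, B_c):
    """A*_c and B*_c as in Equation (11): fold the n_c intra-chunk updates."""
    A_star, B_star = np.eye(d_k), np.zeros((d_v, d_k))
    for i in range(A_c.shape[0]):
        A_star, B_star = A_star @ A_c[i], B_star @ A_c[i] + B_c[i]
    return A_star, B_star

# Cross-chunk recurrence (Equation (13)): S_c = S_{c-1} A*_c + B*_c
S_chunk = np.zeros((d_v, d_k))
for c in range(C):
    A_star, B_star = chunk_update(A[c], B[c])
    S_chunk = S_chunk @ A_star + B_star

# Reference: plain token-by-token recurrence over all N = C * n_c tokens.
S_ref = np.zeros((d_v, d_k))
for c in range(C):
    for i in range(n_c):
        S_ref = S_ref @ A[c, i] + B[c, i]

assert np.allclose(S_chunk, S_ref)
```

In a real kernel the intra-chunk fold would be done with batched matrix multiplications over the whole chunk (and, for DeltaNet-style $A_i$, with the WY/UT tricks teased at the end of this post) rather than a Python loop; the loop here is only to keep the algebra visible.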


Now, let's derive $S_C$ for the linear attention mechanisms in the table above.

Chunk-wise Mamba 2

Derivation of $S_C$:

$$
\begin{align*}
A_{c,i} &= \text{diag}\left(\alpha_{c,i} I\right) & B_{c,i} &= \bm{v}_{c,i} \bm{k}_{c,i}^T\\
A^*_C &= \prod_{i=1}^{n_c} \text{diag}\left(\alpha_{C,i} I\right) & B^*_C &= \sum_{i=1}^{n_c} \left(\bm{v}_{C,i} \bm{k}_{C,i}^T \prod_{i'=i+1}^{n_c} \text{diag}\left(\alpha_{C,i'} I\right)\right)
\end{align*}
$$

Thus, from Equation $(13)$ above,

$$
\begin{align*}
S_C &= S_{C-1}A^*_C + B^*_C\\
S_C &= S_{C-1} \prod_{i=1}^{n_c} \text{diag}\left(\alpha_{C,i} I\right) + \sum_{i=1}^{n_c} \left(\bm{v}_{C,i} \bm{k}_{C,i}^T \prod_{i'=i+1}^{n_c} \text{diag}\left(\alpha_{C,i'} I\right)\right)\\
S_C &= S_{C-1} \prod_{i=1}^{n_c} \alpha_{C,i} + \sum_{i=1}^{n_c} \left(\prod_{i'=i+1}^{n_c} \alpha_{C,i'}\right) \bm{v}_{C,i} \bm{k}_{C,i}^T
\end{align*}
$$

Chunk-wise DeltaNet

Derivation of $S_C$:

$$
\begin{align*}
A_{c,i} &= I - \beta_{c,i} \bm{k}_{c,i} \bm{k}_{c,i}^T & B_{c,i} &= \beta_{c,i} \bm{v}_{c,i} \bm{k}_{c,i}^T\\
A^*_C &= \prod_{i=1}^{n_c} \left(I - \beta_{C,i} \bm{k}_{C,i} \bm{k}_{C,i}^T\right) & B^*_C &= \sum_{i=1}^{n_c} \left(\beta_{C,i} \bm{v}_{C,i} \bm{k}_{C,i}^T \prod_{i'=i+1}^{n_c} \left(I - \beta_{C,i'} \bm{k}_{C,i'} \bm{k}_{C,i'}^T\right)\right)
\end{align*}
$$

Thus,

$$
\begin{align*}
S_C &= S_{C-1}A^*_C + B^*_C\\
S_C &= S_{C-1} \prod_{i=1}^{n_c} \left(I - \beta_{C,i} \bm{k}_{C,i} \bm{k}_{C,i}^T\right) + \sum_{i=1}^{n_c} \left(\beta_{C,i} \bm{v}_{C,i} \bm{k}_{C,i}^T \prod_{i'=i+1}^{n_c} \left(I - \beta_{C,i'} \bm{k}_{C,i'} \bm{k}_{C,i'}^T\right)\right)
\end{align*}
$$

Chunk-wise Gated DeltaNet

Derivation of $S_C$:

$$
\begin{align*}
A_{c,i} &= \alpha_{c,i}(I - \beta_{c,i} \bm{k}_{c,i} \bm{k}_{c,i}^T) & B_{c,i} &= \beta_{c,i} \bm{v}_{c,i} \bm{k}_{c,i}^T\\
A^*_C &= \prod_{i=1}^{n_c} \alpha_{C,i} \left(I - \beta_{C,i} \bm{k}_{C,i} \bm{k}_{C,i}^T\right) & B^*_C &= \sum_{i=1}^{n_c} \left(\beta_{C,i} \bm{v}_{C,i} \bm{k}_{C,i}^T \prod_{i'=i+1}^{n_c} \alpha_{C,i'} \left(I - \beta_{C,i'} \bm{k}_{C,i'} \bm{k}_{C,i'}^T\right)\right)
\end{align*}
$$

Thus,

$$
\begin{align*}
S_C &= S_{C-1}A^*_C + B^*_C\\
S_C &= S_{C-1} \prod_{i=1}^{n_c} \alpha_{C,i} \left(I - \beta_{C,i} \bm{k}_{C,i} \bm{k}_{C,i}^T\right) + \sum_{i=1}^{n_c} \left(\beta_{C,i} \bm{v}_{C,i} \bm{k}_{C,i}^T \prod_{i'=i+1}^{n_c} \alpha_{C,i'} \left(I - \beta_{C,i'} \bm{k}_{C,i'} \bm{k}_{C,i'}^T\right)\right)\\
S_C &= S_{C-1} \left(\prod_{i=1}^{n_c} \alpha_{C,i} \right) \left(\prod_{i=1}^{n_c} \left(I - \beta_{C,i} \bm{k}_{C,i} \bm{k}_{C,i}^T\right)\right) + \sum_{i=1}^{n_c} \left(\left(\beta_{C,i} \prod_{i'=i+1}^{n_c} \alpha_{C,i'} \right) \bm{v}_{C,i} \bm{k}_{C,i}^T \prod_{i'=i+1}^{n_c} \left(I - \beta_{C,i'} \bm{k}_{C,i'} \bm{k}_{C,i'}^T\right)\right)
\end{align*}
$$

Chunk-wise RWKV-7

Derivation of $S_C$:

$$
\begin{align*}
A_{c,i} &= \text{diag}\left(\bm{w}_{c,i}\right) - \bm{\hat{\kappa}}_{c,i}(\bm{a}_{c,i} \odot \bm{\hat{\kappa}}_{c,i}^T) & B_{c,i} &= \bm{v}_{c,i} \bm{k}_{c,i}^T\\
A^*_C &= \prod_{i=1}^{n_c} \left(\text{diag}\left(\bm{w}_{C,i}\right) - \bm{\hat{\kappa}}_{C,i}(\bm{a}_{C,i} \odot \bm{\hat{\kappa}}_{C,i}^T)\right) & B^*_C &= \sum_{i=1}^{n_c} \left(\bm{v}_{C,i} \bm{k}_{C,i}^T \prod_{i'=i+1}^{n_c} \left(\text{diag}\left(\bm{w}_{C,i'}\right) - \bm{\hat{\kappa}}_{C,i'}(\bm{a}_{C,i'} \odot \bm{\hat{\kappa}}_{C,i'}^T)\right)\right)
\end{align*}
$$

Thus,

$$
\begin{align*}
S_C &= S_{C-1}A^*_C + B^*_C\\
S_C &= S_{C-1} \prod_{i=1}^{n_c} \left(\text{diag}\left(\bm{w}_{C,i}\right) - \bm{\hat{\kappa}}_{C,i}(\bm{a}_{C,i} \odot \bm{\hat{\kappa}}_{C,i}^T)\right) + \sum_{i=1}^{n_c} \left(\bm{v}_{C,i} \bm{k}_{C,i}^T \prod_{i'=i+1}^{n_c} \left(\text{diag}\left(\bm{w}_{C,i'}\right) - \bm{\hat{\kappa}}_{C,i'}(\bm{a}_{C,i'} \odot \bm{\hat{\kappa}}_{C,i'}^T)\right)\right)
\end{align*}
$$

Multi-Step Online Gradient Descent per Token with Chunk-wise Parallelism

Let’s combine the two techniques we’ve discussed so far: multi-step online gradient descent per token and chunk-wise parallelism.

The strategy

We can do this in either order, but suppose we chunk the updates first and then expand each of the updates within the chunks into a product of $n_h$ updates. I.e., we have:

$$
\begin{bmatrix} A_{(c-1)n_c + i} & 0 \\ B_{(c-1)n_c + i} & I \end{bmatrix} \xrightarrow{\text{reindex}} \begin{bmatrix} A_{c,i} & 0 \\ B_{c,i} & I \end{bmatrix} \xrightarrow{\text{expand}} \begin{bmatrix} A_{c,i,1} & 0 \\ B_{c,i,1} & I \end{bmatrix} \begin{bmatrix} A_{c,i,2} & 0 \\ B_{c,i,2} & I \end{bmatrix} \cdots \begin{bmatrix} A_{c,i,n_h} & 0 \\ B_{c,i,n_h} & I \end{bmatrix}
$$

where $A_{c,i,j}$ and $B_{c,i,j}$ are the update matrices for the $j$-th gradient descent step for the $i$-th token within the $c$-th chunk.

And from Equations $(6)$, $(10)$, and $(11)$, we have:

$$
\begin{align*}
\begin{bmatrix} A^*_{c} & 0 \\ B^*_{c} & I \end{bmatrix} &= \prod_{i=1}^{n_c} \begin{bmatrix} A'_{c,i} & 0 \\ B'_{c,i} & I \end{bmatrix} = \prod_{i=1}^{n_c} \prod_{j=1}^{n_h} \begin{bmatrix} A_{c,i,j} & 0 \\ B_{c,i,j} & I \end{bmatrix}\\
\begin{bmatrix} A^*_{c} & 0 \\ B^*_{c} & I \end{bmatrix} &= \begin{bmatrix} \prod_{i=1}^{n_c} \prod_{j=1}^{n_h} A_{c,i,j} & 0 \\ \sum_{i=1}^{n_c}\sum_{j=1}^{n_h} \left( B_{c,i,j} \left(\prod_{j'=j+1}^{n_h} A_{c,i,j'}\right) \left(\prod_{i'=i+1}^{n_c} \prod_{j'=1}^{n_h} A_{c,i',j'}\right)\right) & I \end{bmatrix}
\end{align*}
$$

Thus,

$$
\begin{align*}
A^*_{c} &= \prod_{i=1}^{n_c} \prod_{j=1}^{n_h} A_{c,i,j} \\
B^*_{c} &= \sum_{i=1}^{n_c}\sum_{j=1}^{n_h} \left( B_{c,i,j} \left(\prod_{j'=j+1}^{n_h} A_{c,i,j'}\right) \left(\prod_{i'=i+1}^{n_c} \prod_{j'=1}^{n_h} A_{c,i',j'}\right)\right)
\end{align*}
$$

which we can then plug into Equation $(13)$ to get the cross-chunk recurrence:

$$
\begin{align*}
S_C &= S_{C-1}A^*_C + B^*_C\\
S_C &= S_{C-1} \prod_{i=1}^{n_c} \prod_{j=1}^{n_h} A_{C,i,j} + \sum_{i=1}^{n_c}\sum_{j=1}^{n_h} \left( B_{C,i,j} \left(\prod_{j'=j+1}^{n_h} A_{C,i,j'}\right) \left(\prod_{i'=i+1}^{n_c} \prod_{j'=1}^{n_h} A_{C,i',j'}\right)\right)
\end{align*}
$$

or, if we reindex this as $[\cdot]_{C,k} = [\cdot]_{C,\ \lceil k/n_h \rceil,\ (k-1) \% n_h + 1}$, we get:

$$
S_C = S_{C-1} \prod_{k=1}^{n_c n_h} A_{C,k} + \sum_{k=1}^{n_c n_h} \left( B_{C,k} \prod_{k'=k+1}^{n_c n_h} A_{C,k'}\right)
$$
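In code, the combination is just a reshape on top of the chunk-wise sketch above: within each chunk, flatten the (token, step) indices into a single axis of length $n_c n_h$ and fold as before. A self-contained NumPy sketch (toy shapes and the `fold` helper are my own), using Gated-DeltaProduct-style micro-updates:

```python
import numpy as np

rng = np.random.default_rng(4)
C, n_c, n_h, d_k, d_v = 3, 4, 2, 4, 2    # C chunks, n_c tokens per chunk, n_h steps per token
k = rng.standard_normal((C, n_c, n_h, d_k)); k /= np.linalg.norm(k, axis=-1, keepdims=True)
v = rng.standard_normal((C, n_c, n_h, d_v))
beta = rng.uniform(0.0, 1.0, (C, n_c, n_h))
alpha = rng.uniform(0.0, 1.0, (C, n_c, n_h))
# One micro-update per (chunk, token, step): A = alpha (I - beta k k^T), B = beta v k^T
A = alpha[..., None, None] * (np.eye(d_k) - beta[..., None, None] * np.einsum('...k,...l->...kl', k, k))
B = beta[..., None, None] * np.einsum('...k,...l->...kl', v, k)

def fold(A_seq, B_seq):
    """Fold a sequence of blocked updates into a single (A*, B*)."""
    A_star, B_star = np.eye(d_k), np.zeros((d_v, d_k))
    for A_i, B_i in zip(A_seq, B_seq):
        A_star, B_star = A_star @ A_i, B_star @ A_i + B_i
    return A_star, B_star

# Chunk-wise form: flatten (token, step) within each chunk, then recur across chunks.
S_chunk = np.zeros((d_v, d_k))
for c in range(C):
    A_star, B_star = fold(A[c].reshape(-1, d_k, d_k), B[c].reshape(-1, d_v, d_k))
    S_chunk = S_chunk @ A_star + B_star

# Reference: one long recurrence over all C * n_c * n_h micro-updates.
S_ref = np.zeros((d_v, d_k))
for A_i, B_i in zip(A.reshape(-1, d_k, d_k), B.reshape(-1, d_v, d_k)):
    S_ref = S_ref @ A_i + B_i

assert np.allclose(S_chunk, S_ref)
```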


As an exercise, try deriving the cross-chunk recurrence for MambaSum, DeltaProduct, Gated DeltaProduct, and RWKV-7P.


Conclusion

And that’s it!

Not only is the blocked matrix formulation of linear attention mechanisms intuitive, it also makes the connections between different algorithms and computational forms much more obvious. I’d even go as far as to say that we now have the proper abstraction to do an evolutionary search for new linear attention mechanisms ;)


In the next post, we'll talk about faster ways to calculate $A^*_{c}$ and $B^*_{c}$ for diagonal and diagonal-plus-low-rank $A^*_{c}$ using the WY Representations and the UT Transform. Stay tuned!

Acknowledgements

Big thanks to Songlin Yang, Julien Siems, @Smerky, @BeeGass, @safelix, and @jacobbuckman for their feedback and discussions!

How to Cite

@misc{cesista2025blockmatlinearattn,
  author = {Franz Louis Cesista},
  title = {Blocked Matrix Formulation of Linear Attention Mechanisms},
  year = {2025},
  url = {https://leloykun.github.io/ponder/blockmat-linear-attn/},
}

References

  1. Riccardo Grazzi, Julien Siems, Jörg K.H. Franke, Arber Zela, Frank Hutter, Massimiliano Pontil (2025). Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues. URL https://arxiv.org/abs/2411.12537
  2. Julien Siems, Timur Carstensen, Arber Zela, Frank Hutter, Massimiliano Pontil, Riccardo Grazzi (2025). DeltaProduct: Increasing the Expressivity of DeltaNet Through Products of Householders. URL https://arxiv.org/abs/2502.10297
  3. Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are RNNs: Fast autoregressive transformers with linear attention. In Proceedings of the 37th International Conference on Machine Learning, ICML 2020, 13-18 July 2020, Virtual Event, volume 119 of Proceedings of Machine Learning Research, pp. 5156–5165. PMLR, 2020. URL http://proceedings.mlr.press/v119/katharopoulos20a.html.
  4. Tri Dao and Albert Gu. Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality. In Proceedings of the 41st International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pp. 10041–10071. PMLR, 2024. URL https://proceedings.mlr.press/v235/dao24a.html.
  5. Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, and Yoon Kim (2025). Parallelizing Linear Transformers with the Delta Rule over Sequence Length. URL https://arxiv.org/abs/2406.06484
  6. Songlin Yang, Jan Kautz, Ali Hatamizadeh (2025). Gated Delta Networks: Improving Mamba2 with Delta Rule. URL https://arxiv.org/abs/2412.06464
  7. Weizhe Hua, Zihang Dai, Hanxiao Liu, and Quoc V. Le. Transformer quality in linear time. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvári, Gang Niu, and Sivan Sabato (eds.), International Conference on Machine Learning, ICML 2022, 17-23 July 2022, Baltimore, Maryland, USA, volume 162 of Proceedings of Machine Learning Research, pp. 9099–9117. PMLR, 2022. URL https://proceedings.mlr.press/v162/hua22a.html.
  8. Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and Furu Wei. Retentive network: A successor to transformer for large language models. ArXiv preprint, abs/2307.08621, 2023. URL https://arxiv.org/abs/2307.08621.
  9. Bo Peng, Ruichong Zhang, Daniel Goldstein, Eric Alcaide, Haowen Hou, Janna Lu, William Merrill, Guangyu Song, Kaifeng Tan, Saiteja Utpala, Nathan Wilce, Johan S. Wind, Tianyi Wu, Daniel Wuttke, Christian Zhou-Zheng (2025). RWKV-7 “Goose” with Expressive Dynamic State Evolution. URL https://arxiv.org/abs/2503.14456