1. Introduction
In this blog post, we shall consider the problem of steepest descent on Finsler-structured (matrix) manifolds. This problem naturally arises in deep learning optimization because we want model training to be fast and robust. That is, we want our weight updates to maximally change activations (or outputs) while keeping both activations and weights stable.
As discussed in prior blog posts and our latest paper, we can achieve this by properly considering the geometry in which we ‘place’ our weights. This then raises two questions:
1. Which geometry should we ‘place’ our weights in? And,
2. How do we perform optimization in this geometry?
For (1), note that we have two degrees of freedom: the choice of the underlying manifold and the choice of metric or norm with which to equip its tangent spaces. The latter makes (2) tricky because the manifold we end up with may not only be non-Euclidean but even non-Riemannian, and work on non-Riemannian optimization is scarce to almost non-existent.
While it might seem that we’re just inventing a difficult problem for bored mathematicians to solve, we will show in the next sections that such problems can be motivated with simple arguments and that solving them can lead to 1.5x to 2x speedups in large-scale LLM training.
This blog post generalizes work by Jeremy Bernstein and Jianlin Su on ‘Stiefel Muon’ to optimization on Finsler-structured (matrix) manifolds.
2. Case studies
2.1. Case study #1: Muon
Following Bernstein & Newhouse (2024), one can think of the Muon optimizer (Jordan et al., 2024) as doing steepest descent under the spectral norm on $\mathbb{R}^{m \times n}$. But why choose the spectral norm in the first place? Why not the simpler Frobenius norm? As we discussed in previous blog posts, if we want the “natural” norm of our features and feature updates to be stable regardless of the model size, then the “natural” norm of our weights and weight updates must also be stable regardless of the model size. Here, the ‘natural’ feature norm is the RMS norm, or the scaled Euclidean norm, while the ‘natural’ weight norm is the RMS-to-RMS operator norm, or the scaled spectral norm. Note that the spectral norm does not satisfy the parallelogram law, so it is not induced by an inner product and is therefore non-Riemannian. It does, however, induce a Finsler structure on the manifold, which is exactly the kind of structure we are trying to generalize here!
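To make the Muon case concrete, here is a minimal sketch (my own illustrative code, not Muon’s actual implementation, which approximates the same map with a Newton-Schulz iteration rather than an SVD): steepest descent under the spectral norm amounts to ‘orthogonalizing’ the gradient, since $\arg\max_{\| A \|_{2 \to 2} \leq 1} \langle G, A \rangle$ is attained at $A = U V^T$ for the thin SVD $G = U \Sigma V^T$.

```python
import jax
import jax.numpy as jnp

def spectral_steepest_ascent_direction(G: jax.Array) -> jax.Array:
    # argmax_{||A||_{2->2} <= 1} <G, A> is attained at A = U V^T,
    # where G = U S V^T is the thin SVD of G.
    U, _, Vt = jnp.linalg.svd(G, full_matrices=False)
    return U @ Vt

def muon_like_update(W, G, lr):
    # Steepest descent under the spectral norm: step against the dualized gradient.
    return W - lr * spectral_steepest_ascent_direction(G)
```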
2.2. Case study #2: steepest descent on spectral norm Finsler-structured spectral norm ball around the origin
In our latest paper, Training Transformers with Enforced Lipschitz Bounds, we provide methods for keeping the weight norms controlled in addition to using the Muon optimizer. Although we did not explicitly mention it there, one can interpret our approach as performing steepest descent on the spectral norm Finsler-structured spectral norm ball around the origin. Inside the norm ball, the space is locally similar to the previous case. But whenever the weights get sent outside of the norm ball, we retract them back via the weight norm controls introduced in our paper, as in the sketch below.
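Here is a rough sketch of that picture (my own illustrative code; `spectral_hardcap` and `_orthogonalize_via_newton_schulz` are the functions sketched in Section 4, and `sigma_max` is a hypothetical norm-ball radius): take a Muon-style step, then cap the singular values to retract the weights back into the ball.

```python
def muon_step_with_spectral_cap(W, G, lr, sigma_max=1.0):
    # Muon-style step: steepest descent under the spectral norm.
    W_new = W - lr * _orthogonalize_via_newton_schulz(G)
    # Retraction back into the spectral norm ball of radius sigma_max:
    # spectral_hardcap caps singular values at 1, so rescale to cap at sigma_max.
    return sigma_max * spectral_hardcap(W_new / sigma_max)
```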
2.3. Case study #3: steepest descent on spectral norm Finsler-structured Stiefel manifold
The problem that Jeremy, Jianlin, and I have been trying to solve is then the following:
Given the current weight $W \in \texttt{St}(m, n)$ and a “raw gradient” we get via e.g. backpropagation $G \in \mathbb{R}^{m \times n}$, we want to find the optimal update $A^*$ such that, $$\begin{equation} A^* = \arg\max_{A \in \mathbb{R}^{m \times n}} \langle G, A \rangle \quad \text{ s.t. } \quad \| A \|_{2 \to 2} \leq 1,\quad A \in T_{W}\texttt{St}(m, n) \end{equation}$$
Inspired by a partial solution by Jianlin (which did not yet work at the time), I proposed heuristic solutions here. Jianlin then solved the problem via a fixed-point iteration method. Finally, Jeremy proposed a more general solution via the dual ascent algorithm. Cédric Simal also independently suggested studying the dual problem to Jeremy and me.
3. General solution via block-wise Primal-Dual Hybrid Gradient
Let $\mathcal{M}$ be a (matrix) manifold and $\| \cdot \|$ be a Finsler norm defined on the tangent spaces of $\mathcal{M}$, both chosen a priori. We want to solve the problem,
Given the current weight $W \in \mathcal{M}$ and a “raw gradient” or differential we get via e.g. backpropagation $G \in T_{W}^*\mathcal{M} \subseteq \mathbb{R}^{m \times n}$, we want to find the optimal update $A^*$ such that, $$\begin{equation} A^* = \arg\max_{A \in \mathbb{R}^{m \times n}} \langle G, A \rangle \quad \text{ s.t. } \quad \| A \| \leq 1,\quad A \in T_{W}\mathcal{M} \end{equation}$$
Replacing the constraints with indicator functions yields, $$\begin{equation} A^* = -\arg\min_{A \in \mathbb{R}^{m \times n}} \left\{ \langle G, A \rangle + \mathcal{i}_{\| \cdot \| \leq 1}(A) + \mathcal{i}_{T_{W}\mathcal{M}}(A) \right\} \end{equation}$$ where, $$ \mathcal{i}_{\| \cdot \| \leq 1}(A) = \begin{cases} 0 &\text{ if } \| A \| \leq 1 \\ \infty &\text{ otherwise} \end{cases} \qquad \text{ and } \qquad \mathcal{i}_{T_{W}\mathcal{M}}(A) = \begin{cases} 0 &\text{ if } A \in T_{W}\mathcal{M} \\ \infty &\text{ otherwise} \end{cases} $$
Equivalently, $$\begin{equation} A^* = -\arg\min_{A \in \mathbb{R}^{m \times n}} \left\{ f(A) + g(A) \right\} \end{equation}$$ where $f(\cdot) := \mathcal{i}_{\| \cdot \| \leq 1}(\cdot)$ and $g(\cdot) := \mathcal{i}_{T_{W}\mathcal{M}}(\cdot) + \langle G, \cdot \rangle$. Note that we could move the $\langle G, \cdot \rangle$ term to $f$ instead, but as we will see later, the proximal operator for $g$ stays simple this way, so we keep the term in $g$ for improved numerical stability.
We can then split Equation (4) into two subproblems by ‘copying’ $A$, $$\begin{equation} A^* = -\left[\arg\min_{A,B \in \mathbb{R}^{m \times n}} \{f(A) + g(B)\} \quad \text{ s.t. } \quad A - B = 0\right]_{A} \end{equation}$$ This effectively blows up our solution search space, but one can easily prove that the optimal solution to the problem above also solves our original problem!
3.1. Recasting as a primal-dual problem
Define, $$ \begin{align*} X &:= \begin{bmatrix} A \\ B \end{bmatrix}\\ L &:= \begin{bmatrix} I & -I \end{bmatrix} \\ \mathcal{F}(X) &:= f(A) + g(B) \\ \mathcal{G}(Y) &:= \mathcal{i}_{\{0\}}(Y) = \begin{cases} 0 &\text{ if } Y = 0 \\ \infty &\text{ otherwise} \end{cases} \end{align*} $$ where $X \in \mathcal{X} = \mathbb{R}^{2m \times n}$, $Y \in \mathcal{Y} = \mathbb{R}^{m \times n}$, $L: \mathcal{X} \to \mathcal{Y}$ is a linear operator, $\mathcal{F}: \mathcal{X} \to \mathbb{R}$, and $\mathcal{G}: \mathcal{Y} \to \mathbb{R}$.
Then Equation (5) can be rewritten as, $$\begin{align} A^* &= -\left[ \arg\min_{X \in \mathcal{X}} \{\mathcal{F}(X) + \mathcal{G}(LX)\} \right]_{1} \end{align}$$
Fenchel duality then yields the saddle-point problem, $$\begin{align} \min_{X \in \mathcal{X}} \max_{Y \in \mathcal{Y}} \mathcal{L}(X,Y) &:= \mathcal{F}(X) + \langle LX, Y \rangle - \mathcal{G}^*(Y) \nonumber \\ &\ = \mathcal{F}(X) + \langle LX, Y \rangle \end{align}$$ since $\mathcal{G}^*(Y) = \sup_{Z \in \mathcal{Y}} \{ \langle Y, Z \rangle - \underbrace{\mathcal{G}(Z)}_{=\infty \text{ if } Z \neq 0} \} = \langle Y, 0 \rangle - \mathcal{G}(0) = 0$ for all $Y \in \mathcal{Y}$.
3.2. Block-wise Primal-Dual Hybrid Gradient
Following ODL’s page on PDHG, we choose $\tau_A, \tau_B, \sigma > 0$, $\theta \in [0,1]$, and initialize $X_0 \in \mathcal{X}$, $Y_0 \in \mathcal{Y}$, and $\widetilde{X}_0 = X_0$. We then iterate, $$\begin{align} Y_{k+1} &= \texttt{prox}_{\sigma \mathcal{G}^*} (Y_{k} + \sigma L \widetilde{X}_{k}) \\ X_{k+1} &= \texttt{prox}_{\tau \mathcal{F}} (X_{k} - \tau L^T Y_{k+1}) \\ \widetilde{X}_{k+1} &= X_{k+1} + \theta (X_{k+1} - X_{k}) \end{align}$$ where $\tau = \text{diag}(\tau_A I_m, \tau_B I_m)$ and $\texttt{prox}$ is the proximal operator. For scalar step sizes, the standard convergence condition is $\sigma \tau \| L \|^2 \leq 1$; since $L = \begin{bmatrix} I & -I \end{bmatrix}$ has $\| L \|^2 = 2$, this motivates choices like $\sigma = 0.49$ with $\tau_A = \tau_B = 1$ in the implementation below.
To speed up convergence, we can also re-use the $X^*$ and $Y^*$ from the previous optimization step to initialize $X_0$ and $Y_0$. This is especially useful when, e.g., applying (Nesterov) momentum to $G$, since the momentum keeps the ‘input gradients’ from varying too much between steps.
3.2.1. Converting proximal operators to projections
For the $Y$-variable, $$\begin{align*} Y_{k+1} &= \texttt{prox}_{\sigma \mathcal{G}^*} (Y_{k} + \sigma L \widetilde{X}_{k}) \\ &= \arg\min_{Y \in \mathcal{Y}} \left\{ \sigma \cancel{\mathcal{G}^*(Y)} + \frac{1}{2} \| Y - (Y_{k} + \sigma L \widetilde{X}_{k}) \|_F^2 \right\} \\ &= Y_{k} + \sigma L \widetilde{X}_{k} \end{align*}$$
For the $X$-variable, $$\begin{align*} X_{k+1} &= \texttt{prox}_{\tau \mathcal{F}} (X_{k} - \tau L^T Y_{k+1}) \\ &= \arg\min_{X \in \mathcal{X}} \left\{ \tau \mathcal{F}(X) + \frac{1}{2} \| X - (X_{k} - \tau L^T Y_{k+1}) \|_F^2 \right\} \\ &= \arg\min_{X \in \mathcal{X}} \left\{ \tau_A f(A) + \tau_B g(B) + \frac{1}{2} \left\| \begin{bmatrix} A - (A_k - \tau_A Y_{k+1}) \\ B - (B_k + \tau_B Y_{k+1}) \end{bmatrix} \right\|_F^2 \right\} \\ &= \arg\min_{X \in \mathcal{X}} \{ \tau_A f(A) + \frac{1}{2} \left\| A - (A_k - \tau_A Y_{k+1}) \right\|_F^2 \\ &\qquad\qquad + \tau_B g(B) + \frac{1}{2} \left\| B - (B_k + \tau_B Y_{k+1}) \right\|_F^2 \} \\ \end{align*}$$
Note that we can optimize for $A$ and $B$ separately and thus get, $$\begin{align*} A_{k+1} &= \arg\min_{A \in \mathbb{R}^{m \times n}} \left\{ \tau_A f(A) + \frac{1}{2} \left\| A - (A_k - \tau_A Y_{k+1}) \right\|_F^2 \right\} \\ &= \arg\min_{\| A \| \leq 1} \left\{ \frac{1}{2} \left\| A - (A_k - \tau_A Y_{k+1}) \right\|_F^2 \right\} \\ &= \texttt{proj}_{\| \cdot \| \leq 1} (A_k - \tau_A Y_{k+1}) \\ \end{align*}$$ where $\texttt{proj}_{\| \cdot \| \leq 1}$ is the projection onto the unit norm ball. Likewise, $$\begin{align*} B_{k+1} &= \arg\min_{B \in \mathbb{R}^{m \times n}} \left\{ \tau_B g(B) + \frac{1}{2} \left\| B - (B_k + \tau_B Y_{k+1}) \right\|_F^2 \right\} \\ &= \arg\min_{B \in T_W\mathcal{M}} \left\{ \tau_B \langle G, B \rangle + \frac{1}{2} \left\| B - (B_k + \tau_B Y_{k+1}) \right\|_F^2 \right\} \\ &= \arg\min_{B \in T_W\mathcal{M}} \left\{ \tau_B \langle G, B \rangle + \frac{1}{2} \| B \|_F^2 - \langle B, B_k + \tau_B Y_{k+1} \rangle + \frac{1}{2} \| B_k + \tau_B Y_{k+1} \|_F^2 \right\} \\ &= \arg\min_{B \in T_W\mathcal{M}} \left\{ \frac{1}{2} \| B \|_F^2 - \langle B, B_k + \tau_B Y_{k+1} - \tau_B G \rangle + \text{ constant} \right\} \\ &= \arg\min_{B \in T_W\mathcal{M}} \left\{ \frac{1}{2} \| B - (B_k + \tau_B Y_{k+1} - \tau_B G) \|_F^2 + \text{ constant} \right\} \\ &= \texttt{proj}_{T_W\mathcal{M}} (B_k + \tau_B Y_{k+1} - \tau_B G) \end{align*}$$ Thus, $$ \begin{equation} X_{k+1} = \begin{bmatrix} \texttt{proj}_{\| \cdot \| \leq 1} (A_k - \tau_A Y_{k+1}) \\ \texttt{proj}_{T_W\mathcal{M}} (B_k + \tau_B Y_{k+1} - \tau_B G) \end{bmatrix} \end{equation} $$
3.2.2. Block-wise PDHG algorithm for steepest descent on Finsler-structured manifolds
Putting everything together, our iteration becomes,
$$\begin{align} Y_{k+1} &= Y_{k} + \sigma (\widetilde{A}_{k} - \widetilde{B}_{k}) \\ A_{k+1} &= \texttt{proj}_{\| \cdot \| \leq 1} (A_k - \tau_A Y_{k+1}) \\ B_{k+1} &= \texttt{proj}_{T_W\mathcal{M}} (B_k + \tau_B Y_{k+1} - \tau_B G) \\ \widetilde{A}_{k+1} &= A_{k+1} + \theta (A_{k+1} - A_{k}) \\ \widetilde{B}_{k+1} &= B_{k+1} + \theta (B_{k+1} - B_{k}) \end{align}$$
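To make the iteration concrete, here is a minimal fixed-step-size sketch in JAX (my own illustrative code: no adaptive step sizes and no stopping criterion; `proj_norm_ball` and `proj_tangent` stand in for the projections onto the unit Finsler norm ball and onto $T_W\mathcal{M}$, which depend on the chosen norm and manifold):

```python
import jax
import jax.numpy as jnp

def finsler_steepest_descent_pdhg(G, proj_norm_ball, proj_tangent, *,
                                  tau_A=1.0, tau_B=1.0, sigma=0.49, theta=1.0,
                                  num_iters=100):
    # Block-wise PDHG for Equations (12)-(16).
    A = B = Y = jnp.zeros_like(G)
    A_bar, B_bar = A, B

    def body(_, state):
        A, B, Y, A_bar, B_bar = state
        Y = Y + sigma * (A_bar - B_bar)                   # dual ascent, Eq. (12)
        A_new = proj_norm_ball(A - tau_A * Y)             # norm-ball block, Eq. (13)
        B_new = proj_tangent(B + tau_B * Y - tau_B * G)   # tangent-space block, Eq. (14)
        A_bar = A_new + theta * (A_new - A)               # extrapolation, Eqs. (15)-(16)
        B_bar = B_new + theta * (B_new - B)
        return A_new, B_new, Y, A_bar, B_bar

    A, *_ = jax.lax.fori_loop(0, num_iters, body, (A, B, Y, A_bar, B_bar))
    return -A  # Equation (3): negate the minimizer to get the optimal update A*
```

The concrete Stiefel/spectral-norm instantiation with adaptive step sizes follows in the next section.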
4. Alternative solution to Stiefel Muon via Primal-Dual Hybrid Gradient
Here we have $\mathcal{M} = \texttt{St}(m, n)$ and $\| \cdot \| = \| \cdot \|_{2 \to 2}$. For the projection onto the unit spectral norm ball, $\texttt{proj}_{\| \cdot \|_{2 \to 2} \leq 1}$, we can use the GPU/TPU-friendly spectral hardcap function discussed in my previous blog post and in our latest paper.
def spectral_hardcap(W: jax.Array):
    # If W is tall (more rows than columns), work with the wide transpose and
    # transpose back at the end.
    if transpose := W.shape[0] > W.shape[1]:
        W = W.T
    # OW approximates msign(W) = U V^T, the orthogonal polar factor of W.
    OW = _orthogonalize_via_newton_schulz(W)
    aW = OW - W
    # Applies sigma -> (1/2)(sigma + 1 - |sigma - 1|) = min(sigma, 1) to the singular values.
    result = (1/2) * (OW + W - aW @ _orthogonalize_via_newton_schulz(aW).T @ OW)
    if transpose:
        result = result.T
    return result
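The `_orthogonalize_via_newton_schulz` helper is not spelled out in this post; it should return (an approximation of) $\texttt{msign}(W) = U V^T$ from the SVD $W = U \Sigma V^T$. Below is a minimal cubic Newton-Schulz sketch as a stand-in; in practice, Muon and our paper use a tuned quintic iteration instead.

```python
import jax
import jax.numpy as jnp

def _orthogonalize_via_newton_schulz(W: jax.Array, steps: int = 10, eps: float = 1e-7):
    # Cubic Newton-Schulz iteration X <- 1.5 X - 0.5 X X^T X, which drives the
    # singular values of X towards 1 (it converges for singular values in (0, sqrt(3))).
    X = W / (jnp.linalg.norm(W) + eps)  # Frobenius-normalize so the spectral norm is <= 1
    if transpose := X.shape[0] > X.shape[1]:
        X = X.T
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X.T if transpose else X
```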
And for the projection onto the tangent space at $W \in \texttt{St}(m, n)$, we can use the projection map discussed in Theorem 2 in this blog post,
$$\texttt{proj}_{T_W\texttt{St}(m, n)}(V) = V - W \,\text{sym}(W^T V), \qquad \text{where } \text{sym}(M) := \tfrac{1}{2}\left(M + M^T\right)$$
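The implementation below also calls a `project_to_stiefel_tangent_space` helper, which is just this projection map in code. A minimal sketch, assuming $W$ has orthonormal columns ($W^T W = I_n$):

```python
import jax
import jax.numpy as jnp

def project_to_stiefel_tangent_space(W: jax.Array, V: jax.Array) -> jax.Array:
    # proj_{T_W St(m, n)}(V) = V - W sym(W^T V), with sym(M) = (M + M^T) / 2
    WtV = W.T @ V
    return V - W @ (0.5 * (WtV + WtV.T))
```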
4.1. Full implementation with adaptive step sizes
def pdhg_stiefel_spectral(
    W, G, *,
    tau_A=1.0, tau_B=1.0, sigma=0.49, gamma=1.,
    max_iters=200, tol=1e-6,
    A0=None, B0=None, y0=None
):
    m, n = W.shape
    A = jnp.zeros((m, n), W.dtype) if A0 is None else A0
    B = jnp.zeros((m, n), W.dtype) if B0 is None else B0
    y = jnp.zeros((m, n), W.dtype) if y0 is None else y0
    A_bar, B_bar = A, B
    # Cast the loop-carried scalars to arrays of a fixed dtype so the carry of
    # jax.lax.while_loop keeps identical types across iterations.
    tau_A, tau_B, sigma = (jnp.asarray(v, dtype=W.dtype) for v in (tau_A, tau_B, sigma))
    res0 = jnp.asarray(jnp.inf, dtype=W.dtype)
    def cond(state):
        _, _, _, _, _, k, res, *_ = state
        return jnp.logical_and(k < max_iters, res > tol)
    def body(state):
        A, B, y, A_bar, B_bar, k, _, tau_A, tau_B, sigma = state
        # Dual ascent: y <- y + sigma * L X_bar, with L = [I  -I]
        y_new = y + sigma * (A_bar - B_bar)
        # Primal descent (A & B updates): project onto the unit spectral norm ball
        # and onto the tangent space of the Stiefel manifold at W
        A_new = spectral_hardcap(A - tau_A * y_new)
        B_new = project_to_stiefel_tangent_space(W, B + tau_B * y_new - tau_B * G)
        # Update step sizes (accelerated PDHG schedule)
        tau = 0.5 * (tau_A + tau_B)
        theta = 1 / jnp.sqrt(1 + 2 * gamma * tau)
        tau_A = theta * tau_A
        tau_B = theta * tau_B
        sigma = sigma / theta
        # Extrapolation
        A_bar_new = A_new + theta * (A_new - A)
        B_bar_new = B_new + theta * (B_new - B)
        # Primal residual: how far the two copies of the update are from agreeing
        res = jnp.linalg.norm(A_new - B_new).astype(W.dtype)
        return (A_new, B_new, y_new, A_bar_new, B_bar_new, k + 1, res, tau_A, tau_B, sigma)
    init = (A, B, y, A_bar, B_bar, 0, res0, tau_A, tau_B, sigma)
    A, B, y, *_ = jax.lax.while_loop(cond, body, init)
    # (To warm-start the next optimizer step as discussed in Section 3.2, one
    # could also return A, B, and y here.)
    return -A
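As a usage sketch (the step size `lr` and the polar retraction back onto $\texttt{St}(m, n)$ via `_orthogonalize_via_newton_schulz` are my own illustrative choices, not a prescription), a single optimizer step might look like:

```python
def stiefel_muon_like_step(W, G, lr=0.1):
    # Constrained steepest-descent direction via PDHG, then step and retract
    # back onto the Stiefel manifold (polar retraction via orthogonalization).
    A_star = pdhg_stiefel_spectral(W, G)
    return _orthogonalize_via_newton_schulz(W - lr * A_star)
```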
4.2. Experimental results
Here I’ve plotted the alignment <-> off-tangency frontier for the different methods proposed by Jeremy, Jianlin, and me. The alternating projections method seems to do well despite being provably suboptimal in some cases. But the PDHG method closes the gap as we increase the number of iterations. And if we initialize $X_0$ and $Y_0$ from the previous optimization step, we can save compute while potentially improving performance.
5. Generalization to an arbitrary number of constraints on the update
Our solution above generalizes to an arbitrary number of constraints on $A$ so long as the feasible set of each constraint is convex. We then only need the metric projection onto each feasible set.
For example, suppose we add another constraint $A \in S$ in Equation (2) above where $S$ is a convex set and $\texttt{proj}_{S}(\cdot)$ is the (metric) projection onto $S$. Then our Equation (5) becomes, $$\begin{equation} A^* = -\left[\arg\min_{A,B,C \in \mathbb{R}^{m \times n}} \{f(A) + g(B) + h(C)\} \quad \text{ s.t. } \quad A - B = A - C = 0\right]_{A} \end{equation}$$ where, $$ h(C) := \mathcal{i}_{S}(C) = \begin{cases} 0 &\text{ if } C \in S \\ \infty &\text{ otherwise} \end{cases} $$
We then define, $$ \begin{align*} X &:= \begin{bmatrix} A \\ B \\ C \end{bmatrix}\\ L &:= \begin{bmatrix} I & -I & 0 \\ I & 0 & -I \end{bmatrix} \\ \mathcal{F}(X) &:= f(A) + g(B) + h(C) \\ \end{align*} $$ The rest of the derivation goes through as before, with the dual variable now having two blocks $[Y]_1, [Y]_2 \in \mathbb{R}^{m \times n}$, and Equation (11) becomes, $$ \begin{equation} X_{k+1} = \begin{bmatrix} \texttt{proj}_{\| \cdot \| \leq 1} (A_k - \tau_A [Y_{k+1}]_1 - \tau_A [Y_{k+1}]_2) \\ \texttt{proj}_{T_W\mathcal{M}} (B_k + \tau_B [Y_{k+1}]_1 - \tau_B G) \\ \texttt{proj}_{S} (C_k + \tau_C [Y_{k+1}]_2) \end{bmatrix} \end{equation} $$
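As a sketch, in the running Stiefel/spectral-norm example the primal update then splits into three independent projections (the function name and the `proj_S` argument are mine; `spectral_hardcap` and `project_to_stiefel_tangent_space` are as in Section 4), while the dual update becomes $Y_1 \leftarrow Y_1 + \sigma(\widetilde{A} - \widetilde{B})$ and $Y_2 \leftarrow Y_2 + \sigma(\widetilde{A} - \widetilde{C})$:

```python
def three_block_primal_update(W, G, A, B, C, Y1, Y2, tau_A, tau_B, tau_C, proj_S):
    # Equation (18): each copy of the update gets its own projection.
    A_new = spectral_hardcap(A - tau_A * (Y1 + Y2))
    B_new = project_to_stiefel_tangent_space(W, B + tau_B * Y1 - tau_B * G)
    C_new = proj_S(C + tau_C * Y2)
    return A_new, B_new, C_new
```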
Acknowledgements
Big thanks to Jeremy Bernstein and Cédric Simal for productive discussions on the topic!
How to cite
@misc{cesista2025steepestdescentfinsler,
author = {Franz Louis Cesista},
title = {Steepest Descent on Finsler-Structured (Matrix) Manifolds},
year = {2025},
url = {http://leloykun.github.io/ponder/steepest-descent-finsler/},
}
If you find this post useful, please consider supporting my work by sponsoring me on GitHub:
References
- Jeremy Bernstein (2025). Stiefel manifold. URL https://docs.modula.systems/algorithms/manifold/stiefel/
- Jianlin Su (2025). Muon + Stiefel. URL https://kexue.fm/archives/11221
- Laker Newhouse, R. Preston Hess, Franz Cesista, Andrii Zahorodnii, Jeremy Bernstein, Phillip Isola (2025). Training Transformers with Enforced Lipschitz Bounds. URL https://arxiv.org/abs/2507.13338
- Jeremy Bernstein & Laker Newhouse (2024). Old optimizer, new norm: an anthology. URL https://arxiv.org/abs/2409.20325
- Keller Jordan, Yuchen Jin, Vlado Boza, Jiacheng You, Franz Cesista, Laker Newhouse, Jeremy Bernstein (2024). Muon: An optimizer for hidden layers in neural networks. URL https://kellerjordan.github.io/posts/muon/
- Greg Yang, James B. Simon, Jeremy Bernstein (2024). A Spectral Condition for Feature Learning. URL https://arxiv.org/abs/2310.17813
- ODL (2020). Primal-Dual Hybrid Gradient Algorithm (PDHG). URL https://odlgroup.github.io/odl/math/solvers/nonsmooth/pdhg.html