回顾前两篇文章《基于流式幂迭代的Muon实现:1. 初识》《基于流式幂迭代的Muon实现:2. 加速》,我们引入了Muon的流式幂迭代(Streaming Power Iteration)实现方案,初步验证了它的可行性,并进一步讨论了核心运算——QR分解——的加速,使其接近Newton-Schulz迭代实现的效率。

在这篇文章中,我们不再局限于优化单步的QR分解,而是从更整体的视角看待流式幂迭代,并结合具体的计算背景,对其实现细节做进一步的“精雕细琢”,尽可能减少计算瓶颈,使其效率趋近理论极限。

现有结果 #

流式幂迭代本质上是“边训练边SVD”,它的想法是通过幂迭代来求SVD,并通过缓存上一步的结果,将计算平摊到每一步训练上,使得在优化器中嵌入SVD成为可能。至于Muon,只不过是它的一个基本应用,因为Muon的核心运算$\newcommand{msign}{\mathop{\text{msign}}}\msign$最基本的实现方式就是SVD。具体来说,Muon的更新公式是
\begin{equation}\begin{aligned}
\boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt]
\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t [\msign(\boldsymbol{M}_t) + \lambda \boldsymbol{W}_{t-1}] \\
\end{aligned}\end{equation}
这里的矩阵都是$n\times m$大小,并约定$n\geq m$。设$\boldsymbol{M}$的SVD是$\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$(其中$\boldsymbol{U}\in\mathbb{R}^{n\times m}$以及$\boldsymbol{\Sigma},\boldsymbol{V}\in\mathbb{R}^{m\times m}$),那么$\msign(\boldsymbol{M})=\boldsymbol{U}\boldsymbol{V}^{\top}$,因此实现了SVD就实现了$\msign$。当然,直接SVD通常是比较昂贵的,而借助流式幂迭代,能让它变得可行起来。

在上一篇文章中,我们还讨论了流式幂迭代的四种提速思路,其中第一种是启用全精度的FP32乘法,这是通用的,而后三种思路在某种程度上是互斥的,我们只能选其一。笔者建议是选择第二种,它理论上限更高,接下来的雕琢也是基于第二种进行。将第二种思路代入流式幂迭代版并用于Muon,迭代公式是
\begin{equation}\newcommand{QR}{\mathop{\text{QR}}}\newcommand{ColNorm}{\mathop{\text{ColNorm}}}\begin{aligned}
\boldsymbol{M}_t =&\, \beta\boldsymbol{M}_{t-1} + \boldsymbol{G}_t \\[5pt]
\boldsymbol{V}_t =&\, \QR(\boldsymbol{M}_t^{\top}\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1})) \\[5pt]
\boldsymbol{U}_t =&\, \ColNorm(\boldsymbol{M}_t\boldsymbol{V}_t) \\[5pt]
\boldsymbol{W}_t =&\, \boldsymbol{W}_{t-1} - \eta_t (\boldsymbol{U}_t\boldsymbol{V}_t^{\top} + \lambda \boldsymbol{W}_{t-1}) \\
\end{aligned}\end{equation}
显然,现在最昂贵一步运算是$\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1}))$,这正是接下来的优化对象。

加速分解 #

为了保证效率,这里的$\QR$我们并非调用框架自带的QR分解函数,而是用“SCQR(Shifted Cholesky QR)”,它将矩阵$\boldsymbol{A}$的QR分解分为两步:1. 对$\boldsymbol{A}^{\top}\boldsymbol{A} + \lambda \boldsymbol{I}$做Cholesky分解得到上三角阵$\boldsymbol{R}$;2. 解方程$\boldsymbol{Q}\boldsymbol{R}=\boldsymbol{A}$得到正交矩阵$\boldsymbol{Q}$。

这两步理论上都非常高效,问题是它并非总是能成功,所以需要多一步检测,失败后回退到自带的标准QR函数,标准QR几乎一定能成功。不过一旦触发回退,那么端到端的效率将会大打折扣。SCQR失败的原因,主要是Cholesky分解极其依赖于矩阵的条件数,$+\lambda\boldsymbol{I}$正是用来降低$\boldsymbol{A}^{\top}\boldsymbol{A}$条件数的。

然而,这也有个两难全的问题:$\lambda$越大,SCQR越容易成功,但最终结果会越偏离正交(即误差越大),导致效果变差;$\lambda$越小,精度自然是高了,但回退到标准QR的概率也越高,从而导致效率越差。实测发现取$\lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F$且$\epsilon=10^{-9}$,能较好地平衡效果和效率。

上一篇文章加速思路,都是围绕降低条件数展开的。第一版流式幂迭代是$\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1})$,我们要Cholesky分解的矩阵是$\boldsymbol{V}_{t-1}^{\top}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t)^2\boldsymbol{V}_{t-1}$,即将$\boldsymbol{M}_t$的条件数4次方了,显然暴涨。而改为$\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1}))$后,虽然$\QR$变成两次,但每一次Cholesky分解的矩阵条件数只是$\boldsymbol{M}_t$的平方,明显降低条件数,SCQR成功率大大提高,所以速度反而变快。

调整顺序 #

以上的内容都还是前两篇文章的复述(抱歉,铺垫有点长了,但磨刀不误砍柴工),这一节我们才开始讲新的优化思路。仔细留意可以发现,目前我们引入了两次$\QR$,但这两次$\QR$都是独立考虑的。而 @YouJiacheng@Kimi 两位同学发现,如果将它们合起来考虑,能获得一些加速手段。

按照默认的顺序,$\boldsymbol{V}_t = \QR(\boldsymbol{M}_t^{\top}\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1}))$的计算流程是
\begin{equation}\begin{aligned}
\boldsymbol{A}_{(1), t} =&\, \boldsymbol{M}_t\boldsymbol{V}_{t-1} \\
\boldsymbol{R}_{(1), t}^{\top}\boldsymbol{R}_{(1), t} =&\, \boldsymbol{A}_{(1), t}^{\top}\boldsymbol{A}_{(1), t} + \lambda \boldsymbol{I}\qquad(\text{Cholesky分解}) \\
\boldsymbol{Q}_{(1), t} =&\, \boldsymbol{A}_{(1), t} \boldsymbol{R}_{(1), t}^{-1} \qquad(\text{Triangular Solve}) \\
\boldsymbol{A}_{(2), t} =&\, \boldsymbol{M}_t^{\top}\boldsymbol{Q}_{(1), t} \\
\boldsymbol{R}_{(2), t}^{\top}\boldsymbol{R}_{(2), t} =&\, \boldsymbol{A}_{(2), t}^{\top}\boldsymbol{A}_{(2), t} + \lambda \boldsymbol{I}\quad(\text{Cholesky分解}) \\
\boldsymbol{Q}_{(2), t} =&\, \boldsymbol{A}_{(2), t} \boldsymbol{R}_{(2), t}^{-1} \qquad(\text{Triangular Solve}) \\
\end{aligned}\label{eq:qr2}\end{equation}
其中$\boldsymbol{M}_t\boldsymbol{V}_{t-1}$、$\boldsymbol{A}_{(1), t}^{\top}\boldsymbol{A}_{(1), t}$、$\boldsymbol{A}_{(1), t} \boldsymbol{R}_{(1), t}^{-1}$、$\boldsymbol{M}_t^{\top}\boldsymbol{Q}_{(1), t}$这四步都是$\mathcal{O}(nm^2)$复杂度,剩余的是$\mathcal{O}(m^3)$,在$n \gg m$时,$\mathcal{O}(nm^2)$可能会成为瓶颈。有趣的是,我们可以通过恒等变换,让$\mathcal{O}(nm^2)$只出现一次!
\begin{equation}\begin{aligned}
\boldsymbol{A}_{(1), t} =&\, (\boldsymbol{M}_t^{\top}\boldsymbol{M}_t)\boldsymbol{V}_{t-1} \\
\boldsymbol{R}_{(1), t}^{\top}\boldsymbol{R}_{(1), t} =&\, \boldsymbol{V}_{t-1}^{\top}\boldsymbol{A}_{(1), t} + \lambda \boldsymbol{I}\qquad(\text{Cholesky分解}) \\
\boldsymbol{A}_{(2), t} =&\, \boldsymbol{A}_{(1), t} \boldsymbol{R}_{(1), t}^{-1} \qquad(\text{Triangular Solve}) \\
\boldsymbol{R}_{(2), t}^{\top}\boldsymbol{R}_{(2), t} =&\, \boldsymbol{A}_{(2), t}^{\top}\boldsymbol{A}_{(2), t} + \lambda \boldsymbol{I}\qquad(\text{Cholesky分解}) \\
\boldsymbol{Q}_{(2), t} =&\, \boldsymbol{A}_{(2), t} \boldsymbol{R}_{(2), t}^{-1} \qquad(\text{Triangular Solve}) \\
\end{aligned}\label{eq:qr2-sim}\end{equation}
这个等价版本非常值得细细品味!首先可以证明,它理论上跟原版完全等价,且这种等价性不依赖于$\boldsymbol{V}_{t-1}$和$\boldsymbol{Q}_{(1), t}$的绝对正交性;经过变换之后,只有$\boldsymbol{M}_t^{\top}\boldsymbol{M}_t$这一步是$\mathcal{O}(nm^2)$的,剩下的都是$\mathcal{O}(m^3)$,并且总步数还少了一步(将原本的$\boldsymbol{Q}_{(1), t} = \boldsymbol{A}_{(1), t} \boldsymbol{R}_{(1), t}^{-1}$和$\boldsymbol{A}_{(2), t} = \boldsymbol{M}_t^{\top}\boldsymbol{Q}_{(1), t}$合成了一步)!

注1:根据 @YouJiacheng 表述,这个精妙的变换是他将式$\eqref{eq:qr2}$告诉Kimi后Kimi自动发现的;

注2:式$\eqref{eq:qr2}$和$\eqref{eq:qr2-sim}$实际上会有一点细微的区别——式$\eqref{eq:qr2}$第一步是$(\boldsymbol{M}_t\boldsymbol{V}_{t-1})^{\top}(\boldsymbol{M}_t\boldsymbol{V}_{t-1})$,而式$\eqref{eq:qr2-sim}$则是$\boldsymbol{V}_{t-1}^{\top}(\boldsymbol{M}_t^{\top}\boldsymbol{M}_t)\boldsymbol{V}_{t-1}$,这两种算法数学上完全等价,但对有限精度的浮点运算会有所区别,后者乘出来的矩阵条件数会更大,Cholesky分解可能需要稍微调大一点正则。

简化正则 #

对于矩阵$\boldsymbol{A}^{\top}\boldsymbol{A}$,我们在Cholesky分解时添加的正则项是$\lambda\boldsymbol{I}$,其中$\lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F$且$\epsilon=10^{-9}$,但为何取这个形式一直没有做详细解释,这里我们展开介绍一下,并结合问题背景得出一个更简洁的正则项。

由于$\boldsymbol{A}^{\top}\boldsymbol{A}$的正定对称性,它的SVD跟特征值分解相同,不妨设它的SVD为$\boldsymbol{A}^{\top}\boldsymbol{A} = \boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$,那么$\boldsymbol{A}^{\top}\boldsymbol{A} + \lambda\boldsymbol{I} = \boldsymbol{V}(\boldsymbol{\Sigma} + \lambda\boldsymbol{I})\boldsymbol{V}^{\top}$,设$\boldsymbol{A}^{\top}\boldsymbol{A}$的最大、最小奇异值分别为$\sigma_{\max},\sigma_{\min}$,那么$\boldsymbol{A}^{\top}\boldsymbol{A} + \lambda\boldsymbol{I}$的最大、最小奇异值为$\sigma_{\max} + \lambda,\sigma_{\min} + \lambda$,条件数是最大最小奇异值之比,所以它从$\sigma_{\max}/\sigma_{\min}$降低到
\begin{equation}\frac{\sigma_{\max} + \lambda}{\sigma_{\min} + \lambda} < \frac{\sigma_{\max} + \lambda}{\lambda} = \frac{\sigma_{\max}}{\lambda} + 1\end{equation}
如果我们想要条件数控制到不超过$1/\epsilon + 1$,那么$\lambda \geq \epsilon \sigma_{\max}$,这表明理想情况下我们应该要用$\boldsymbol{A}^{\top}\boldsymbol{A}$的最大奇异值——也就是谱范数——来作为基准来调节$\lambda$。但谱范数算起来比较复杂,所以我们改用了比较简单的F范数$\Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F$,这便是$\lambda=\epsilon \Vert\boldsymbol{A}^{\top}\boldsymbol{A}\Vert_F$的来源,至于$\epsilon=10^{-9}$则纯粹是实验结论了。

然而,“谱范数算起来比较复杂故而改用F范数”只是针对任意矩阵的一般结论,这里我们要做的流式幂迭代本身就是用来求SVD的,随着训练的进行,$\boldsymbol{V}_t$会越来越接近$\boldsymbol{M}_t$的右奇异矩阵,由于$\boldsymbol{M}_t$是缓变的,所以$\boldsymbol{V}_{t-1}$也大差不差,因此理论上$\boldsymbol{V}_{t-1}^{\top}\boldsymbol{M}_t^{\top}\boldsymbol{M}_t\boldsymbol{V}_{t-1}$会越来越接近一个对角阵,它左上角的元素会越来越接近它的谱范数!

同理,$\tilde{\boldsymbol{U}}_t = \QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1})$会越来越接近$\boldsymbol{M}_t$的左奇异矩阵,所以$\tilde{\boldsymbol{U}}_t^{\top}\boldsymbol{M}_t\boldsymbol{M}_t^{\top}\tilde{\boldsymbol{U}}_t$也会越来越接近一个对角阵,其左上角的元素会越来越接近它的谱范数。因此,在我们的场景下,最简单也最准确的基准,就是直接用$(\boldsymbol{A}^{\top}\boldsymbol{A})_{[0,0]}$作为谱范数近似,即$\lambda=\epsilon \cdot (\boldsymbol{A}^{\top}\boldsymbol{A})_{[0,0]}$就行了,实测$\epsilon=10^{-7}$就能较好地平衡效果和效率。

参考实现 #

综合上面两节的改动,从$\boldsymbol{V}_{t-1}$到$\boldsymbol{V}_t$的迭代的参考实现为:

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax import lax

def shift_old(A, eps=1e-9):
    return A + eps * jnp.linalg.matrix_norm(A, keepdims=True) * jnp.eye(A.shape[-1])

def scqr(A, eps=1e-9):
    """先按Shifted Cholesky QR算,失败则回退到默认QR
    """
    R = jnp.linalg.cholesky(shift_old(A.mT @ A, eps), upper=True)
    Q = solve_triangular(R.mT, A.mT, lower=True).mT
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])

def v_step_old(M, V, eps=1e-9):
   return scqr(M.mT @ scqr(M @ V, eps), eps)


def shift(A, eps=1e-7):
    return A + eps * A[..., :1, :1] * jnp.eye(A.shape[-1])

def v_step(M, V, eps=1e-7):
    A = (M.mT @ M) @ V
    R = jnp.linalg.cholesky(shift(V.mT @ A, upper=True)
    B = solve_triangular(R.mT, A.mT, lower=True).mT
    R = jnp.linalg.cholesky(shift(B.mT @ B, eps), upper=True)
    Q = solve_triangular(R.mT, B.mT, lower=True).mT
    return lax.cond(jnp.isfinite(Q).all(), lambda: Q, lambda: jnp.linalg.qr(A)[0])

同期工作 #

《基于流式幂迭代的Muon实现:2. 加速》发布之后到本文发布之前的这段时间,外界也出现了一些有趣的优化工作,它们跟本文所提的两处改动有着类似的优化思想,我们可以将它们放在一起学习一下。

首先是在上篇发布之后,@Ji_Ha_Kim 同学也提出了一些改进思路。比如,他表示在跟GPT交流中发现(链接),我们也许可以省去一次$\text{Triangular Solve}$!具体来说,我们有
\begin{equation}\begin{aligned}
\boldsymbol{V}_t =&\, \QR(\boldsymbol{M}_t^{\top}\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1})) \\
=&\, \QR(\boldsymbol{V}_{t-1}(\boldsymbol{M}_t \boldsymbol{V}_{t-1})^{\top}\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1})) \\
=&\, \QR(\boldsymbol{V}_{t-1}(\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1})^{\top}\boldsymbol{M}_t \boldsymbol{V}_{t-1})^{\top}) \\
=&\, \boldsymbol{V}_{t-1}\QR((\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1})^{\top}\boldsymbol{M}_t \boldsymbol{V}_{t-1})^{\top}) \\
\end{aligned}\end{equation}
易知$\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1})^{\top}\boldsymbol{M}_t \boldsymbol{V}_{t-1}$就是对$\boldsymbol{M}_t\boldsymbol{V}_{t-1}$做QR分解后的R,这可以直接通过Cholesky分解获得。

换言之,理论上第一次Cholesky分解得到R后就可以进行第二次$\QR$,省去一次$\text{Triangular Solve}$。然而,这只有理论意义,因为这个结果依赖于$\boldsymbol{V}_{t-1}$和$\QR(\boldsymbol{M}_t\boldsymbol{V}_{t-1})$的严格正交性,这仅在精确QR(即$\lambda=0$)的前提下成立,实践中为了保证效率,我们只能用SCQR,这导致结果并非严格正交,所以在恒等变换过程中提前利用它的正交性反而会导致迭代算法误差累积,

期间,Tri-Dao团队还发表了《Gram Newton-Schulz: A Fast, Hardware-Aware Newton-Schulz Algorithm for Muon》,提出了$\msign$算子的一种加速思路。根据定义,$\msign(\boldsymbol{M}) = \boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$,团队希望通过Newton-Schulz迭代来计算$m\times m$的$(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$而非$\msign$,这样在$n\gg m$时就能明显降低计算量。事实上,此前很多研究人员都尝试过这个思路,但都失败了,而Tri-Dao他们通过Restart巧妙地解决了这个问题。

显然,$\msign$的这个优化方向跟流式幂迭代从$\eqref{eq:qr2}$到$\eqref{eq:qr2-sim}$的转变是一致的。无独有偶,@Ji_Ha_Kim建议将$\msign$的Newton-Schulz迭代从多项式改为分式,这样可以用更少的迭代次数达到同样好的效果。分式迭代的问题是需要求逆矩阵,不过结合到$\msign$的具体背景后,它只需要对$m\times m$的正定对称矩阵求逆,这可以通过Cholesky分解和两次$\text{Triangular Solve}$来求,尚可接受。

不过,如此一来这种分式迭代的计算流程其实跟流式幂迭代高度重合了,并且每一步还会多出一次$\text{Triangular Solve}$,所以看起来它的速度是无法优于流式幂迭代。

文章小结 #

本文对流式幂迭代的实现细节进行了进一步“精雕细琢”,主要改进包括:1. 调整计算顺序将$\mathcal{O}(nm^2)$复杂度的运算从四次降至一次;2. 利用流式幂迭代的特殊背景简化正则项。这些优化进一步减少了流式幂迭代的计算瓶颈,将计算效率推向极致。

转载到请包括本文地址:https://kexue.fm/archives/11697

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Apr. 07, 2026). 《基于流式幂迭代的Muon实现:3. 雕琢 》[Blog post]. Retrieved from https://kexue.fm/archives/11697

@online{kexuefm-11697,
        title={基于流式幂迭代的Muon实现:3. 雕琢},
        author={苏剑林},
        year={2026},
        month={Apr},
        url={\url{https://kexue.fm/archives/11697}},
}