强制间隔投影(Margin-Enforcing Projection)

苏剑林 • 2026-06-19

这篇文章我们介绍一个数学运算,名为“强制间隔投影(Margin-Enforcing Projection,MEP)”,它将向量按指定方式划分为两部分,然后要求这两部分的间隔至少为$m$(Margin)。

背景简介 #

我们知道,分类任务是要选出正确的类别,所以训练目标通常是“正类分数大于负类分数”就行。但在某些场景下,我们不仅希望正类分数要超过负类分数,还希望至少超过指定间隔$m > 0$。

这些场景主要分两种。第一种是希望分类结果更为稳健,不轻易受到随机噪声的干扰,尤其是低精度推理场景,所以我们希望预测结果的间隔能更显著一些;第二种是根本不用分类模型,而是想着通过分类来学习特征,最终场景是用特征去检索,此时如果不设置Margin,那么边界处的检索结果就很容易出错。

实际上,这两个场景在早些年就已经非常成熟,尤其是场景二,典型代表就是人脸识别特征模型的训练,所以本文其实算是“挖坟”了。该问题的标准思路是设计各种Margin Loss,比如Hinge Loss、Margin Softmax、AM-Softmax等,我们之前在《基于GRU和AM-Softmax的句子相似度模型》《从三角不等式到Margin Softmax》中也略有介绍。

本文的设想是,如果我们有某种运算,能够将预测分数$\boldsymbol{x}$投影成符合间隔要求的分数$\boldsymbol{y}$,那么直接以它为目标给模型学就是了,比如最小化$\Vert\boldsymbol{y} - \boldsymbol{x}\Vert_2^2$。而这个投影运算,便是接下来要研究的问题。

数学定义 #

记$\boldsymbol{x}=(x_1,x_2,\cdots,x_n), \boldsymbol{z}=(z_1,z_2,\cdots,z_n)$,$\boldsymbol{x}$代表模型的预测分数,简单起见,我们假设前$k$个分数代表得分,后$n-k$个代表负类得分,我们定义
\begin{equation}\newcommand{argmin}{\mathop{\text{argmin}}}\mathcal{P}_m(\boldsymbol{x})\triangleq \argmin_{\boldsymbol{z}\in\mathbb{R}^n} d(\boldsymbol{z}, \boldsymbol{x})\quad\text{s.t.}\quad \min(\boldsymbol{z}_{\leq k}) - \max(\boldsymbol{z}_{> k}) \geq m\end{equation}
其中$\boldsymbol{z}_{\leq k} = (z_1,\cdots,z_k), \boldsymbol{z}_{> k} = (z_{k+1},\cdots,z_n)$,$k$是大于0、小于$n$的整数,$m > 0$是给定的间隔,$d(\boldsymbol{z}, \boldsymbol{x})$是要最小化的距离函数,我们有L1距离和L2距离两种选择,接下来会分别讨论。

这个定义很直观,就是在满足间隔要求的前提下,找跟预测分数最接近的向量,这符合投影运算的一般思想,所以我们称之为“强制间隔投影(Margin-Enforcing Projection,MEP)”。以这样的投影结果作为学习目标,理论上能让模型往最轻松、最快捷的路径走,同时也能在达到目标后及时“刹车”,避免过度训练。

共同结果 #

若$\boldsymbol{x}$本就满足$\min(\boldsymbol{x}_{\leq k}) - \max(\boldsymbol{x}_{> k}) \geq m$,那么显然$\boldsymbol{z}^* = \boldsymbol{x}$,这是平凡的。不失一般性,下面都假设$\min(\boldsymbol{x}_{\leq k}) - \max(\boldsymbol{x}_{> k}) < m$。不难看出,不管$d$选择L1还是L2距离,最优解$\boldsymbol{z}^*$都必然在\begin{equation}\min(\boldsymbol{z}^*_{\leq k}) - \max(\boldsymbol{z}^*_{> k}) = m\end{equation}
取到,否则总可以让某些$z_i$更靠近$x_i$来降低目标值。基于这个观察,设$\min(\boldsymbol{z}^*_{\leq k})=\ell$,那么$\max(\boldsymbol{z}^*_{> k}) = \ell - m$,那么可以得到
\begin{equation}\boldsymbol{z}^*_{\leq k} = \max(\boldsymbol{x}_{\leq k}, \ell),\qquad \boldsymbol{z}^*_{> k} = \min(\boldsymbol{x}_{> k}, \ell - m)\end{equation}
此时
\begin{equation}\begin{aligned}
\boldsymbol{z}^* - \boldsymbol{x} =&\, [\max(\boldsymbol{x}_{\leq k}, \ell) - \boldsymbol{x}_{\leq k}, \min(\boldsymbol{x}_{> k}, \ell - m) - \boldsymbol{x}_{> k}] \\[4pt]
=&\, [\max(\ell - \boldsymbol{x}_{\leq k}, 0), \min(\ell - m - \boldsymbol{x}_{> k}, 0)] \\[4pt]
=&\, [\max(\ell - \boldsymbol{x}_{\leq k}, 0), -\max(\boldsymbol{x}_{> k} + m - \ell, 0)]
\end{aligned}\end{equation}
接下来就是要根据不同的$d$来求解$\ell$。

距离之二 #

先考虑L2距离,此时我们有
\begin{equation}\Vert\boldsymbol{z}^* - \boldsymbol{x}\Vert_2^2 = \sum_{i=1}^k\max(\ell - x_i, 0)^2 + \sum_{j=k+1}^n \max(x_j + m - \ell, 0)^2 \triangleq f(\ell)\end{equation}
求导得
\begin{gather}f'(\ell) = 2\sum_{i=1}^k\max(\ell - x_i, 0) - 2\sum_{j=k+1}^n \max(x_j + m - \ell, 0) \\[5pt]
f''(\ell) = 2\#\{\ell > x_i\} + 2\#\{\ell < x_j + m\} \end{gather}
其中$\#$是计数函数,约定$1\leq i\leq k < j\leq n$。显然$f''(\ell) \geq 0$,但还可以加强到$f''(\ell) > 0$。

这是因为$f''(\ell) = 0$意味着$\#\{\ell > x_i\}=0$且$\#\{\ell < x_j + m\}=0$,即同时成立$\min(\boldsymbol{x}_{\leq k}) \geq \ell$和$\max(\boldsymbol{x}_{> k})\leq \ell - m$,这跟假设$\min(\boldsymbol{x}_{\leq k}) - \max(\boldsymbol{x}_{> k}) < m$矛盾。所以$f''(\ell) > 0$,即$f(\ell)$是严格凸的,加上$f(\ell)$和$f'(\ell)$的连续性,以及$f'(-\infty)=-\infty$和$f'(\infty)=\infty$,可以得出$f(\ell)$的最小值点只有一个,并且必然在$f'(\ell)=0$处取到。

由于$f'(\ell)$是分段线性函数$\max(x, 0)$的复合,所以$f'(\ell)$也是$\ell$的分段线性函数,边界点是全体$x_i$和$x_j + m$。为了求解$f'(\ell)=0$,我们先将边界点$\{x_1,\cdots,x_k,x_{k+1}+m,\cdots,x_n+m\}$从小到大排列,得到$n-1$个区间,在单个区间内,$f'(\ell)$是一条直线,遍历所有区间$[a, b]$,找到$f'(a) \leq 0$且$f'(b) \geq 0$的区间,在该区间内求直线的零点即可。

距离之一 #

接着考虑L1距离,此时我们有
\begin{equation}\Vert\boldsymbol{z}^* - \boldsymbol{x}\Vert_1 = \sum_{i=1}^k\max(\ell - x_i, 0) + \sum_{j=k+1}^n \max(x_j + m - \ell, 0) \triangleq g(\ell)\end{equation}
显然,$g(\ell)$本身就是分段线性函数,这种函数的最小值只能在边界点取到,所以最朴素的解法就是遍历所有边界点$\{x_1,\cdots,x_k,x_{k+1}+m,\cdots,x_n+m\}$,取让$g(\ell)$最小者,复杂度为$\mathcal{O}(n^2)$。但我们还可以更进一步简化,首先求导得
\begin{equation}\begin{aligned}
g'(\ell) =&\, \#\{\ell > x_i\} - \#\{\ell < x_j + m\} \\[4pt]
=&\, \#\{\ell > x_i\} + \#\{\ell \geq x_j + m\} - (n - k)
\end{aligned}\end{equation}
第二个等号用了恒等式$\#\{\ell < x_j + m\} + \#\{\ell \geq x_j + m\} = n - k$,现在容易看出,$g'(\ell)$是从负到正单调递增的。当然,$g'(\ell)$不是连续的,我们没法保证找到$g'(\ell)=0$的点。

不过,如果我们将$\#\{\ell > x_i\}$改为$\#\{\ell \geq x_i\}$(不连续的边界点的导数可以视为任意的,所以这种调整是允许的),那么$g'(\ell)=0$的含义正好是“小于等于$\ell$的边界点刚好有$n-k$个”,如果进一步假设所有边界点两两不同,那么$l$正好是全体边界点从小到大排序后的第$n-k$个!所以,L1场景下,其实一次排序就可以得到$\ell$,相当漂亮。

参考实现 #

两个版本的MEP参考实现如下:

import jax
import jax.numpy as jnp

@jax.jit
def l2mep(inputs, mask, margin):
    x, m = inputs, margin
    u = jnp.where(mask, x, x + m).sort()[:, None]
    v = jnp.where(mask, jnp.fmax(u - x, 0), -jnp.fmax(x + m - u, 0)).sum(axis=1)
    i = ((v[:-1] < 0) & (v[1:] >= 0)).argmax()
    l = (u[i + 1] * v[i] - u[i] * v[i + 1]) / (v[i] - v[i + 1])
    return jnp.where(mask, jnp.fmax(x, l), jnp.fmin(x, l - m))

@jax.jit
def l1mep(inputs, mask, margin):
    x, m = inputs, margin
    u = jnp.where(mask, x, x + m).sort(axis=-1)
    l = jnp.take_along_axis(u, (~mask).sum(axis=-1, keepdims=True) - 1)
    return jnp.where(mask, jnp.fmax(x, l), jnp.fmin(x, l - m))

这里稍加说明一下,实际场景下,正类通常不会简单地排列在前$k$个,数目、位置都可能不确定,所以在上述实现中,我们用一个mask向量来标记正负类。

l2mep当前实现只支持1d的输入,如果需要支持batch维度,用jax.vmap包装一层即可。这个算法通过遍历所有区间来找零点区间,每步复杂度是$\mathcal{O}(n)$,总复杂度是$\mathcal{O}(n^2)$。如果想要提高效率,可以改用二分法来找零点区间,这样能降低到$\mathcal{O}(n\log n)$

l1mep由于算法简单,所以现在的写法就已经支持任意的batch维度,它的核心运算只有排序,复杂度是$\mathcal{O}(n \log n)$,不管从简洁还是速度看都已经相当理想,如果没有别的考虑,实践中推荐用L1版。

文章小结 #

本文介绍了“强制间隔投影(Margin-Enforcing Projection,MEP)”这一运算:给定一个分数向量,把它投影到满足“正类最小值至少比负类最大值大$m$”的最近向量上。这可以为传统的Margin Learning提供一些新的思路。

转载到请包括本文地址:https://kexue.fm/archives/11784

更详细的转载事宜请参考:《科学空间FAQ》