此论文通过隐式方法将优化问题表达为神经网络中的一层,它的内部不是传统意义上的神经元堆叠,不是包含全连接层、激活函数等结构,而是一个特殊的“可学习的可微优化层”
通过理论推导得到它的前向传播和反向传播的公式表达,并不需要定义它的内部结构
2. 相关研究
使用神经网络解决受限类别的优化问题的整体思路一般可以分为四种
基于能量的学习方法(Energy-based learning methods)
总体思路为:在训练过程中将能量函数在观测数据流形附近调低,而在其他地方调高。缺点:一些问题中可能没有参测数据;可能会出现不稳定问题
解析法(Analytically)
如果能找到优化问题的解析解,那么梯度一般也可以解析地计算出来。缺点:大多数问题没有解析解
展开(Unrolling)
将优化问题的迭代求解过程近似展开为神经网络的结构。缺点:展开操作会增加网络的复杂性和深度;对于有约束的问题可能难以展开
最小化操作求导(Argmin differentiation)
类似本文的操作,对argmin问题进行微分
3. OptNet:在神经网络中解决优化问题
本文重点研究二次规划问题
zmin s.t. 21z⊤Qz+q⊤zAz=b,Gz≤h
其中z∈Rn是优化变量,Q∈Rn,n⪰0,q∈Rn,A∈Rm,n,b∈Rm,G∈Rp,n,h∈Rp
3.1 反向传播
拉格朗日函数可以写为
L(z,ν,λ)=21z⊤Qz+q⊤z+ν⊤(Az−b)+λ⊤(Gz−h)
其中ν和λ≥0是对偶变量。可以写出它的KKT条件(分别为平稳性条件、原始可行性和互补松弛性)
Qz∗+q+A⊤ν∗+G⊤λ∗=0Az∗−b=0D(λ∗)(Gz∗−h)=0
其中D(⋅)表示由向量展开为对角矩阵,z∗,ν∗,λ∗代表最优原始变量和最优对偶变量
对这些条件求微分得到
ΔQz∗+QΔz∗+Δq+ΔA⊤ν∗+A⊤Δν∗+ΔG⊤λ∗+G⊤Δλ∗=0ΔAz∗+AΔz∗−Δb=0D(Gz∗−h)Δλ∗+D(λ∗)(ΔGz∗+GΔz∗−Δh)=0
写成矩阵形式为
QD(λ∗)GAG⊤D(Gz∗−h)0A⊤00Δz∗Δλ∗Δν∗=−Δq−ΔA⊤ν∗−ΔG⊤λ∗−ΔQz∗−D(λ∗)ΔGz∗+D(λ∗)Δh−ΔAz∗+Δb
KKT条件是优化问题的必要条件,即如果x∗是最优解,则它必须满足KKT条件
对于凸优化问题,KKT条件是充要条件,即满足KKT条件的解一定是最优解
这一层的输入为zi,部分参数是由zi决定的,优化变量为z,输出为z∗,参数为q,b,h,Q,A,G。想要进行反向传播,需要计算得到损失对于这些参数的梯度
式(7)的推导
将KKT条件的矩阵形式写为紧凑形式
KΔz∗Δλ∗Δν∗=r
令上一层传回来的梯度为
g=∂z∗∂ℓ
那么微分可以写为
Δℓ=g⊤Δz∗
定义
e=g00
那么
Δℓ=e⊤Δz∗Δλ∗Δν∗
结合紧凑版KKT条件,上式可以进一步写为
Δℓ=e⊤K−1r
为了避免显式求解e⊤K−1,引入变量
d=dzdλdν
它满足
K⊤d=−e
这里K⊤为
K⊤=QGAG⊤D(λ∗)D(Gz∗−h)0A⊤00
展开d=(K⊤)−1(−e)就可以得到论文中式(7)的形式
式(8)的推导
因为用d取代了(K⊤)−1(−e),因此Δℓ可以表示为
Δℓ=−d⊤r
展开为
Δℓ=dzTΔQz∗+dzTΔq+dzTΔGTλ∗+dzTΔATν∗+dλTD(λ∗)ΔGz∗−dλTD(λ∗)Δh+dνTΔAz∗−dνTΔb
下面逐项拆解梯度
对q的梯度
含Δq的项为
dzTΔq
因此
∂q∂ℓ=dz
对b的梯度
含Δb的项为
−dνTΔb
因此
∂b∂ℓ=−dν
对h的梯度
含Δh的项为
−dλTD(λ∗)Δh
因为 D(λ⋆) 是对角矩阵,所以
−dλTD(λ∗)Δh=−(D(λ∗)dλ)⊤Δh
因此
∂h∂ℓ=−D(λ∗)dλ
对Q的梯度
含ΔQ的项为
dzTΔQz∗
利用迹运算
dzTΔQz∗=tr(dzTΔQz∗)=tr(z∗dzTΔQ)
因此正常应该为
∂Q∂ℓ=dzz∗T
但为了结果保持对称性,梯度写为
∂Q∂ℓ=21(dzz∗T+z∗dzT)
对A的梯度
含ΔA的项为
dzTΔATν∗+dνTΔAz∗
因此
∂A∂ℓ=dνz∗T+ν∗dzT
对G的梯度
含ΔG的项为
dzTΔGTλ∗+dλTD(λ∗)ΔGz∗
因此
∂G∂ℓ=D(λ∗)dλz∗T+λ∗dzT
这样便得到了式(8)的形式
这样不需要显示求解显式求解e⊤K−1,而是通过求解
K⊤d=−e
得到d的值,继而求得反向传播时的参数梯度
3.2 正向传播
由于传统求解器Gurobi和CPLEX只能运行在cpu上,在进行批量运算时耗时较长,作者根据原始-对偶内点法开发了能跑在gpu上的求解器qpth,具体公式推导不再展开
值得一提的是,在正向传播时会顺便把反向传播时需要的d求出,所以计算耗时主要集中在正向传播,反向传播开销极低