谷歌用算力爆了一篇論文,解答有關(guān)無限寬度網(wǎng)絡(luò)的一切
無限寬度神經(jīng)網(wǎng)絡(luò)是近來一個重要的研究課題,但要通過實證實驗來探索它們的性質(zhì),必需大規(guī)模的計算能力才行。近日,谷歌大腦公布的一篇論文介紹了他們在有限和無限神經(jīng)網(wǎng)絡(luò)方面的系統(tǒng)性探索成果。該研究通過大規(guī)模對比實驗得到了 12 條重要的實驗結(jié)論并在此過程中找到了一些新的改進(jìn)方法。該文作者之一 Jascha Sohl-Dickstein 表示:「這篇論文包含你想知道的但沒有足夠的計算能力探求的有關(guān)無限寬度網(wǎng)絡(luò)的一切!」

近日,谷歌大腦的研究者通過大規(guī)模實證研究探討了寬神經(jīng)網(wǎng)絡(luò)與核(kernel)方法之間的對應(yīng)關(guān)系。在此過程中,研究者解決了一系列與無限寬度神經(jīng)網(wǎng)絡(luò)研究相關(guān)的問題,并總結(jié)得到了 12 項實驗結(jié)果。
此外,實驗還額外為權(quán)重衰減找到了一種改進(jìn)版逐層擴(kuò)展方法,可以提升有限寬度網(wǎng)絡(luò)的泛化能力。
最后,他們還為使用 NNGP(神經(jīng)網(wǎng)絡(luò)高斯過程)和 NT(神經(jīng)正切)核的預(yù)測任務(wù)找到了一種改進(jìn)版的最佳實踐,其中包括一種全新的集成(ensembling)技術(shù)。這些最佳實踐技術(shù)讓實驗中每種架構(gòu)對應(yīng)的核在 CIFAR-10 分類任務(wù)上均取得了當(dāng)前最佳的成績。

論文鏈接:https://arxiv.org/pdf/2007.15801v1.pdf
當(dāng)使用貝葉斯方法和梯度下降方法訓(xùn)練的神經(jīng)網(wǎng)絡(luò)的中間層是無限寬時,這些網(wǎng)絡(luò)可以收斂至高斯過程或緊密相關(guān)的核方法。這些無限寬度網(wǎng)絡(luò)的預(yù)測過程可通過貝葉斯網(wǎng)絡(luò)的神經(jīng)網(wǎng)絡(luò)高斯過程(NNGP)核函數(shù)來描述,也可通過梯度下降方法所訓(xùn)練網(wǎng)絡(luò)的神經(jīng)正切核(NTK)和權(quán)重空間線性化來描述。
這種對應(yīng)關(guān)系是近來在理解神經(jīng)網(wǎng)絡(luò)方面獲得突破的關(guān)鍵,同時還使核方法、貝葉斯深度學(xué)習(xí)、主動學(xué)習(xí)和半監(jiān)督學(xué)習(xí)取得了切實的進(jìn)步。在為大規(guī)模神經(jīng)網(wǎng)絡(luò)提供確切理論描述時,NNGP、NTK 和相關(guān)的寬度限制都是獨(dú)特的。因此可以相信它們?nèi)詫⒗^續(xù)為深度學(xué)習(xí)理論帶來變革。
無限網(wǎng)絡(luò)是近來一個活躍的研究領(lǐng)域,但其基礎(chǔ)性的實證問題仍待解答。谷歌大腦的這項研究對有限和無限寬度神經(jīng)網(wǎng)絡(luò)進(jìn)行了廣泛深入的實證研究。在此過程中,研究者通過實證數(shù)據(jù)定量地解答了影響有限網(wǎng)絡(luò)和核方法性能的變化因素,揭示了出人意料的新行為,并開發(fā)了可提升有限與無限寬度網(wǎng)絡(luò)性能的最佳實踐。
實驗設(shè)計
為了系統(tǒng)性地對無限和有限神經(jīng)網(wǎng)絡(luò)進(jìn)行實證研究,研究者首先確立了每種架構(gòu)的 base,方便直接對比無限寬度核方法、線性化權(quán)重空間網(wǎng)絡(luò)和基于非線性梯度下降的訓(xùn)練方法。對于有限寬度的情況,base 架構(gòu)使用了恒定小學(xué)習(xí)率且損失為 MSE(均方誤差)的 mini-batch 梯度下降。在核學(xué)習(xí)設(shè)置中,研究者為整個數(shù)據(jù)集計算了 NNGP 和 NTK。
完成這種一對一的比較之后,研究者在 base 模型之上進(jìn)行了大量不同種類的修改。某些修改會大致保留其對應(yīng)關(guān)系(比如數(shù)據(jù)增強(qiáng)),而另一些則會打破這種對應(yīng)關(guān)系,并且假設(shè)對應(yīng)關(guān)系的打破會影響到性能結(jié)果(比如使用較大的學(xué)習(xí)率)。
此外,研究者還圍繞 base 模型的初始化對其進(jìn)行線性化嘗試,在這種情況下,其訓(xùn)練動態(tài)可使用常量核來精準(zhǔn)地描述。由于有限寬度效應(yīng),這不同于前文描述的核設(shè)置。
該研究使用 MSE 損失的原因是能更容易地與核方法進(jìn)行比較,交叉熵?fù)p失在性能方面比 MSE 損失略好,但這還留待未來研究。
該研究涉及的架構(gòu)要么是基于全連接層(FCN)構(gòu)建的,要么就是用卷積層(CNN)構(gòu)建的。所有案例都使用了 ReLU 非線性函數(shù)。除非另有說明,該研究使用的模型都是 3 層的 FCN 和 8 層的 CNN。對于卷積網(wǎng)絡(luò),在最后的讀出層(readout layer)之前必須壓縮圖像形狀數(shù)據(jù)的空間維度。為此,要么是將圖像展平為一維向量(VEC),要么是對空間維度應(yīng)用全局平均池化(GAP)。
最后,研究者比較了兩種參數(shù)化網(wǎng)絡(luò)權(quán)重和偏置的方法:標(biāo)準(zhǔn)參數(shù)化(STD)和 NTK 參數(shù)化(NTK)。其中 STD 用于有限寬度網(wǎng)絡(luò)的研究,NTK 則在目前大多數(shù)無限寬度網(wǎng)絡(luò)研究中得到應(yīng)用。
除非另有說明,該研究中所有核方法的實驗都是基于對角核正則化(diagonal kernel regularization)獨(dú)立優(yōu)化完成的。有限寬度網(wǎng)絡(luò)則全都使用了與 base 模型相對應(yīng)的小學(xué)習(xí)率。
這篇論文中的實驗基本都是計算密集型的。舉個例子,要為 CNN-GAP 架構(gòu)在 CIFAR-10 上計算 NTK 或 NNGP,就必須用 6×10^7 乘 6×10^7 的核矩陣對各項進(jìn)行評估。通常來說,這需要雙精度 GPU 時間約 1200 小時,因此研究者使用了基于 beam 的大規(guī)模分布式計算基礎(chǔ)設(shè)施。
所有實驗都使用了基于 JAX 的 Neural Tangents 庫:https://github.com/google/neural-tangents。
為了盡可能地做到系統(tǒng)性,同時又考慮到如此巨大的計算需求,于是研究者僅使用了一個數(shù)據(jù)集 CIFAR-10,即在該數(shù)據(jù)集上評估對每種架構(gòu)的每種修改措施。同時,為了保證結(jié)果也適用于不同的數(shù)據(jù)集,研究者還在 CIFAR-100 和 Fashion-MNIST 上評估了部分關(guān)鍵結(jié)果。
從實驗中得到的 12 條結(jié)論
以下為基于實驗結(jié)果總結(jié)的 12 個結(jié)論(詳細(xì)分析請參閱原論文):
1. NNGP/NTK 的表現(xiàn)可勝過有限網(wǎng)絡(luò)
在無限網(wǎng)絡(luò)研究中,一個常見假設(shè)是它們在大數(shù)據(jù)環(huán)境中的表現(xiàn)趕不上對應(yīng)的有限網(wǎng)絡(luò)。通過比較核方法與有限寬度架構(gòu)(使用小學(xué)習(xí)率,無正則化)的 base 模型,并逐一驗證可打破(大學(xué)習(xí)率、L2 正則化)或改進(jìn)(集成)無限寬度與核方法對應(yīng)性的訓(xùn)練實踐的效果,研究者驗證了這一假設(shè)。結(jié)果見下圖 1:

圖 1:有限和無限網(wǎng)絡(luò)及其變體在 CIFAR-10 上的測試準(zhǔn)確率。從給定架構(gòu)類別的有限寬度 base 網(wǎng)絡(luò)開始,標(biāo)準(zhǔn)和 NTK 參數(shù)化的模型表現(xiàn)隨著修改而發(fā)生變化:+C 指居中(Centering)、+LR 指大學(xué)習(xí)率、+U 指通過早停實現(xiàn)欠擬合、+ZCA 指使用 ZCA 正則化進(jìn)行輸入預(yù)處理、+Ens 指多個初始化集成,另外還有一些組合方案。Lin 指線性化 base 網(wǎng)絡(luò)的性能。
從中可以觀察到,對于 base 有限網(wǎng)絡(luò),無限 FCN 和 CNN-VEC 的表現(xiàn)要優(yōu)于它們各自對應(yīng)的有限網(wǎng)絡(luò)。另一方面,無限 CNN-GAP 網(wǎng)絡(luò)的表現(xiàn)又比其對應(yīng)的有限版本差。研究者指出這其實與架構(gòu)有關(guān)。舉例來說,即使有限寬度 FCN 網(wǎng)絡(luò)組合了高學(xué)習(xí)率、L2 和欠擬合等多種不同技巧,無限 FCN 網(wǎng)絡(luò)的性能還是更優(yōu)。只有再加上集成之后,有限網(wǎng)絡(luò)的性能才能達(dá)到相近程度。
另一個有趣的觀察是,ZCA 正則化預(yù)處理能顯著提升 CNN-GAP 核的表現(xiàn)。
2. NNGP 通常優(yōu)于 NTK
從下圖 2 中可以看出,在 CIFAR-10、CIFAR-100 和 Fashion-MNIST 數(shù)據(jù)集上 NNGP 的性能持續(xù)優(yōu)于 NTK。NNGP 核不僅能得到更強(qiáng)的模型,而且所需的內(nèi)存和計算量也僅有對應(yīng)的 NTK 的一半左右,而且某些性能最高的核根本就沒有對應(yīng)的 NTK 版本。

圖 2:當(dāng)對角正則化經(jīng)過精心調(diào)整時,NNGP 在圖像分類任務(wù)上通常優(yōu)于 NTK。
3. 居中和集成有限網(wǎng)絡(luò)都會得到類 kernel 的表現(xiàn)

圖 3:居中可以加速訓(xùn)練和提升性能。

圖 4:集成 base 網(wǎng)絡(luò)可讓它們達(dá)到與核方法相媲美的表現(xiàn),并且在非線性 CNN 上還優(yōu)于核方法。
4. 大學(xué)習(xí)率和 L2 正則化會讓有限網(wǎng)絡(luò)和核之間出現(xiàn)差異
從上圖 1 中可以觀察到,大學(xué)習(xí)率(LR)的效果容易受到架構(gòu)和參數(shù)化的影響。
L2 正則化則能穩(wěn)定地提升所有架構(gòu)和參數(shù)化的性能(+1-2%)。即使使用經(jīng)過精心調(diào)節(jié)的 L2 正則化,有限寬度 CNN-VEC 和 FCN 依然比不上 NNGP/NTK。L2 結(jié)合早停能為有限寬度 CNN-VEC 帶來 10-15% 的顯著性能提升,使其超過 NNGP/NTK。
5. 使用標(biāo)準(zhǔn)參數(shù)化能為網(wǎng)絡(luò)提升 L2 正則化

圖 5:受 NTK 啟發(fā)的逐層擴(kuò)展能讓 L2 正則化在標(biāo)準(zhǔn)參數(shù)化網(wǎng)絡(luò)中更有幫助。
研究者發(fā)現(xiàn),相比于使用標(biāo)準(zhǔn)參數(shù)化,使用 NTK 參數(shù)化時 L2 正則化能為有限寬度網(wǎng)絡(luò)帶來顯著的性能提升。使用兩種參數(shù)化的網(wǎng)絡(luò)的權(quán)重之間存在雙射映射。受 NTK 參數(shù)化中 L2 正則化項性能提升的啟發(fā),研究者使用這一映射構(gòu)建了一個可用于標(biāo)準(zhǔn)參數(shù)化網(wǎng)絡(luò)的正則化項,其得到的懲罰項與原版 L2 正則化在對應(yīng)的 NTK 參數(shù)化網(wǎng)絡(luò)上得到的一樣。
6. 在超過兩次下降的寬度中,性能表現(xiàn)可能是非單調(diào)的

圖 6:有限寬度網(wǎng)絡(luò)在寬度增大時通常會有更好的表現(xiàn),但 CNN-VEC 表現(xiàn)出了出人意料的非單調(diào)行為。L2:在訓(xùn)練階段允許非零權(quán)重衰減,LR:允許大學(xué)習(xí)率,虛線表示允許欠擬合(U)。
7. 核對角正則化的行為類似于早停

圖 7:對角核正則化的行為類似于早停。實線對應(yīng)具備不同對角正則化 ε 的 NTK 推斷;虛線對應(yīng)梯度下降到時間 τ = ηt 后的預(yù)測結(jié)果,線條顏色表示不同的訓(xùn)練集大小 m。在時間 t 執(zhí)行早停緊密對應(yīng)于使用系數(shù) ε = Km/ηt 的正則化,其中 K=10 表示輸出類別的數(shù)量。
8. 浮點數(shù)精度決定了核方法失敗的關(guān)鍵數(shù)據(jù)集大小

圖 8:無限網(wǎng)絡(luò)核的尾部特征值表現(xiàn)出了冪律衰減趨勢。
9. 由于條件不好,線性化 CNN-GAP 模型表現(xiàn)很差
研究者觀察到線性化 CNN-GAP 在訓(xùn)練集上的收斂速度非常慢,導(dǎo)致其驗證表現(xiàn)也很差(見上圖 3)。
這一結(jié)果的原因是池化網(wǎng)絡(luò)的條件很差。Xiao 等人的研究 [33] 表明 CNN-GAP 網(wǎng)絡(luò)初始化的條件比 FCN 或 CNN-VEC 網(wǎng)絡(luò)差了像素數(shù)倍(對 CIFAR-10 來說是 1024)。

表 1:對應(yīng)架構(gòu)類型的核的 CIFAR-10 測試準(zhǔn)確率。
10. 正則化 ZCA 白化(whitening)可提升準(zhǔn)確率

圖 9:正則化 ZCA 白化可提升有限和無限寬度網(wǎng)絡(luò)的圖像分類性能。所有的圖都將性能表現(xiàn)為 ZCA 正則化強(qiáng)度的函數(shù)。a)在 CIFAR-10、Fashion-MNIST、CIFAR-100 上核方法輸入的 ZCA 白化;b)有限寬度網(wǎng)絡(luò)輸入的 ZCA 白化。
11. 同變性(equivariance)僅對遠(yuǎn)離核區(qū)域的窄網(wǎng)絡(luò)有益

圖 10:同變性僅在核區(qū)域之外的 CNN 模型中得到利用。
如果 CNN 模型能有效地利用同變性,則預(yù)計它能比 FCN 更穩(wěn)健地處理裁剪和平移。出人意料的是,寬 CNN-VEC 的性能會隨輸入擾動的幅度而下降,而且下降速度與 FCN 一樣快,這說明同變性并未得到利用。相反,使用權(quán)重衰減的窄模型(CNN-VEC+L2+narrow)的性能下降速度要慢得多。正如預(yù)期,平移不變型 CNN-GAP 依然是最穩(wěn)健的。
12. 集成核預(yù)測器可使用 NNGP/NTK 進(jìn)行實用的數(shù)據(jù)增強(qiáng)

圖 11:集成核預(yù)測器(ensembling kernel predictors)可使基于大規(guī)模增強(qiáng)數(shù)據(jù)集的預(yù)測在計算上可行。
可以觀察到,DA 集成可提升準(zhǔn)確率,且相比于 NTK,它對 NNGP 的效果要好得多。
這里研究者提出了一種直接讓集成核預(yù)測器實現(xiàn)更廣泛的數(shù)據(jù)增強(qiáng)的方法。該策略涉及到構(gòu)建一組經(jīng)過增強(qiáng)的數(shù)據(jù)批,為其中每一批執(zhí)行核推斷,然后執(zhí)行所得結(jié)果的集成。這相當(dāng)于用模塊對角近似替代核,其中每個模塊都對應(yīng)一個數(shù)據(jù)批,所有增強(qiáng)的數(shù)據(jù)批的并集即為完整的增強(qiáng)數(shù)據(jù)集。該方法在該研究所有無線寬度架構(gòu)的對應(yīng)核方法上都取得了當(dāng)前最佳結(jié)果。


























