NeurIPS 2022 | 四分鐘內(nèi)就能訓(xùn)練目標(biāo)檢測(cè)器,商湯基模型團(tuán)隊(duì)是怎么做到的?
來(lái)自商湯的基模型團(tuán)隊(duì)和香港大學(xué)等機(jī)構(gòu)的研究人員提出了一種大批量訓(xùn)練算法 AGVM,該研究已被NeurIPS 2022接收。
本文提出了一種大批量訓(xùn)練算法 AGVM (Adaptive Gradient Variance Modulator),不僅可以適配于目標(biāo)檢測(cè)任務(wù),同時(shí)也可以適配各類分割任務(wù)。AGVM 可以把目標(biāo)檢測(cè)的訓(xùn)練批量大小擴(kuò)大到 1536,幫助研究人員四分鐘訓(xùn)練 Faster R-CNN,3.5 小時(shí)把 COCO 刷到 62.2 mAP,均打破了目標(biāo)檢測(cè)訓(xùn)練速度的世界紀(jì)錄。
論文地址:https://arxiv.org/pdf/2210.11078.pdf
代碼地址:https://github.com/Sense-X/AGVM
在當(dāng)前的機(jī)器學(xué)習(xí)社區(qū)中,有三個(gè)普遍的趨勢(shì)。首先,神經(jīng)網(wǎng)絡(luò)模型會(huì)越來(lái)越大。在 NLP 領(lǐng)域中最大規(guī)模的模型已經(jīng)達(dá)到了上萬(wàn)億級(jí)別。在視覺(jué)領(lǐng)域,最大規(guī)模的模型也達(dá)到了三百億的量級(jí)。其次,訓(xùn)練的數(shù)據(jù)集也變得越來(lái)越大。比如,ImageNet 21k 和谷歌的 JFT 數(shù)據(jù)集都具有相當(dāng)規(guī)模的數(shù)據(jù)集。另外,由于數(shù)據(jù)集變得越來(lái)越大,訓(xùn)練 SOTA 模型的開(kāi)銷越來(lái)越大。
因此,提升訓(xùn)練效率就變得愈發(fā)重要。而分布式訓(xùn)練因?yàn)槠溥m應(yīng)于數(shù)據(jù)并行、模型并行和流水線并行的加速訓(xùn)練方法的同時(shí),也具備較高的 Deep Learning 通信效率而被廣泛認(rèn)為是一個(gè)有效的解決方案。
隨著大模型時(shí)代的到來(lái),目標(biāo)檢測(cè)器的訓(xùn)練速度越來(lái)越成為學(xué)術(shù)界和工業(yè)界的瓶頸,例如,在 COCO 的標(biāo)準(zhǔn) setting 上把 mAP 訓(xùn)到 62 以上大概需要三天的時(shí)間,算上調(diào)試成本,這在業(yè)界幾乎是不可接受的。那么,我們能不能把這個(gè)訓(xùn)練時(shí)間壓到小時(shí)級(jí)別呢?事實(shí)上,在圖片分類和自然語(yǔ)言處理任務(wù)上,先前的研究人員借助 32K 的批量大小(batch size),只需 14 分鐘就可以完成 ImageNet 的訓(xùn)練,76 分鐘完成 Bert 的訓(xùn)練。但是,在目標(biāo)檢測(cè)領(lǐng)域,還很欠缺這類研究,導(dǎo)致研究人員無(wú)法充分利用當(dāng)前的算力,數(shù)據(jù)集和大模型。
大批量訓(xùn)練算法 AGVM 便是這個(gè)問(wèn)題的最佳解決方案之一。為了支持如此大批量的訓(xùn)練,同時(shí)保持模型的訓(xùn)練精度,本研究提出了一套全新的訓(xùn)練算法,根據(jù)密集預(yù)測(cè)不同模塊的梯度方差(gradient variance),動(dòng)態(tài)調(diào)整每一個(gè)模塊的學(xué)習(xí)率。作者在大量的密集預(yù)測(cè)網(wǎng)絡(luò)和數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn),并且證實(shí)了該方法的合理性。
方法介紹
大批量訓(xùn)練是加速大型分布式系統(tǒng)中深度神經(jīng)網(wǎng)絡(luò)訓(xùn)練的關(guān)鍵。尤其是在如今的大模型時(shí)代,如果不采用大批量訓(xùn)練,一個(gè)網(wǎng)絡(luò)的訓(xùn)練時(shí)間幾乎是難以接受的。但是,大批量訓(xùn)練很難,因?yàn)樗鼤?huì)產(chǎn)生泛化差距(generalization gap), 直接訓(xùn)練會(huì)導(dǎo)致其準(zhǔn)確率降低。此前的大批量工作往往針對(duì)于圖像分類以及一些自然語(yǔ)言處理的任務(wù),但密集預(yù)測(cè)任務(wù)(包括檢測(cè)分割等),同樣在視覺(jué)中處于舉足輕重的位置,此前的方法并不能在密集預(yù)測(cè)任務(wù)上有很好的表現(xiàn),甚至結(jié)果比基準(zhǔn)線更差,這導(dǎo)致我們難以快速訓(xùn)練一個(gè)目標(biāo)檢測(cè)器。
為了解決這個(gè)問(wèn)題,研究人員進(jìn)行了大量的實(shí)驗(yàn)。最后發(fā)現(xiàn),相較于傳統(tǒng)的分類網(wǎng)絡(luò),利用密集預(yù)測(cè)網(wǎng)絡(luò)一個(gè)很重要的特征:密集預(yù)測(cè)網(wǎng)絡(luò)往往是由多個(gè)組件組成的,以 Faster R-CNN 為例:它由四個(gè)部分組成,骨干網(wǎng)絡(luò) (Backbone),特征金字塔網(wǎng)絡(luò)(FPN),區(qū)域生成網(wǎng)絡(luò)(RPN) 和檢測(cè)頭網(wǎng)絡(luò)(head),我們可以發(fā)現(xiàn)一個(gè)很有效的指標(biāo):密集預(yù)測(cè)網(wǎng)絡(luò)不同組件的梯度方差,在訓(xùn)練批量很小時(shí)(例如 32),幾乎是相同的,但當(dāng)訓(xùn)練批量很大時(shí)(例如 512),它們呈現(xiàn)出很大的區(qū)別,如下圖所示:
那么,能不能直接把這些拉平呢?這直接引出了 AGVM 算法。以隨機(jī)梯度下降算法為例,上角標(biāo) i 代表第 i 個(gè)網(wǎng)絡(luò)模塊(例如 FPN 等),上角標(biāo) 1 代表骨干網(wǎng)絡(luò),代表學(xué)習(xí)率,錨定骨干網(wǎng)絡(luò),可以直接將不同網(wǎng)絡(luò)組件的梯度 g 的方差
:
梯度的方差可以由以下式子估計(jì):
方差的具體求解細(xì)節(jié)可以參考原文,本研究同樣引入了滑動(dòng)平均機(jī)制,防止網(wǎng)絡(luò)訓(xùn)練發(fā)散。同時(shí),研究證明了 AGVM 在非凸情況下的收斂性,討論了動(dòng)量以及衰減的處理方式,具體實(shí)現(xiàn)細(xì)節(jié)可以參考原文。
實(shí)驗(yàn)過(guò)程
本研究首先在目標(biāo)檢測(cè)、實(shí)例分割、全景分割和語(yǔ)義分割的各種密集預(yù)測(cè)網(wǎng)絡(luò)上進(jìn)行了測(cè)試,通過(guò)下表可以看到,當(dāng)用標(biāo)準(zhǔn)批量大小訓(xùn)練時(shí),AGVM 相較傳統(tǒng)方法沒(méi)有明顯優(yōu)勢(shì),但當(dāng)在超大批量下訓(xùn)練時(shí),AGVM 相較傳統(tǒng)方法擁有壓倒性的優(yōu)勢(shì),下圖第二列從左至右分別表示目標(biāo)檢測(cè),實(shí)例分割,全景分割和語(yǔ)義分割的表現(xiàn),AGVM 超越了有史以來(lái)的所有方法:
下表詳細(xì)對(duì)比了 AGVM 和傳統(tǒng)方法,體現(xiàn)出了本研究方法的優(yōu)勢(shì):
同時(shí),為了說(shuō)明 AGVM 的優(yōu)越性,本研究進(jìn)行了以下三個(gè)超大規(guī)模的實(shí)驗(yàn)。研究人員把 Faster R-CNN 的 batch size 放到了 1536,這樣利用 768 張 A100 可以在 4.2 分鐘內(nèi)完成訓(xùn)練。其次,借助 UniNet-G,本研究可以在利用 480 張 A100 的情況下,3.5 個(gè)小時(shí)讓模型在 COCO 上達(dá)到 62.2mAP(不包括骨干網(wǎng)絡(luò)預(yù)訓(xùn)練的時(shí)間),極大的減小了訓(xùn)練時(shí)間:
甚至,在 RetinaNet 上,本研究把批量大小擴(kuò)展到 10K。這在目標(biāo)檢測(cè)領(lǐng)域是從未見(jiàn)的批量大小,在如此大的批量下,每一個(gè) epoch 只有十幾個(gè)迭代次數(shù),AGVM 在如此大的批量下,仍然能展現(xiàn)出很強(qiáng)的穩(wěn)定性,性能如下圖所示:
結(jié)果分析
本研究探究了一個(gè)很重要的問(wèn)題:以 RetinaNet 為例,如下圖第一列所示,探究為什么會(huì)出現(xiàn)梯度方差不匹配這一現(xiàn)象。
本研究認(rèn)為,這一現(xiàn)象來(lái)自于:網(wǎng)絡(luò)不同模塊間的有效批量大小 (effective batch size) 是不同的。例如,RetinaNet 的頭網(wǎng)絡(luò)的輸入是由特征金字塔的五層網(wǎng)絡(luò)輸出的,特征金字塔的 top-down 和 bottom-up pathways,以及像素維度的損失函數(shù)計(jì)算會(huì)導(dǎo)致頭網(wǎng)絡(luò)和骨干網(wǎng)絡(luò)的等效批量大小不同,這一原理導(dǎo)致了梯度方差不匹配的現(xiàn)象。
為了驗(yàn)證這一假設(shè),本研究依次給每一層特征使用單獨(dú)的頭網(wǎng)絡(luò),移去特征金字塔網(wǎng)絡(luò),隨機(jī)忽略掉 75% 的用于計(jì)算損失函數(shù)的像素,最終,本研究發(fā)現(xiàn)骨干網(wǎng)絡(luò)和頭網(wǎng)絡(luò)的梯度方差曲線重合了,本研究也對(duì) Faster R-CNN 做了類似的實(shí)驗(yàn),如下圖第二列所示,更多的討論請(qǐng)參見(jiàn)原文。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。