ECCV 2022 | 多域長尾分布學習,不平衡域泛化問題研究(開源)
來源丨h(huán)ttps://zhuanlan.zhihu.com/p/539749541編輯丨極市平臺 導讀
本文由被ECCV2022接受論文的作者親自解讀,講述如何推廣傳統不平衡分類問題的范式,將數據不平衡問題從單領域推廣到多領域。
前言項目主頁:http://mdlt.csail.mit.edu/論文鏈接:https://arxiv.org/abs/2203.09513代碼,數據和模型開源鏈接:https://github.com/YyzHarry/multi-domain-imbalance來給大家介紹一下我們的新工作,目前已被ECCV 2022接收:On Multi-Domain Long-Tailed Recognition, Imbalanced Domain Generalization and Beyond。顧名思義,這項工作研究的問題是當有多個領域(domain)數據,且這些 domain 都存在(可能互不相同的)數據不平衡情況下,該如何學習到魯棒的模型?,F有的處理不平衡數據/長尾分布的方法僅針對單域,即數據來源于同一個 domain;但是,自然數據可以源自不同 domain,而其中一個 domain 中的 minority class 可能在其他 domain 會是 majority class;而有效的利用不同域的數據很可能會提升長尾學習的表現。本文推廣了傳統不平衡分類問題的范式,將數據不平衡問題從單領域推廣到多領域。其中,多域長尾學習(Multi-Domain Long-Tailed Recognition,MDLT)的首個目標,是模型能夠在每一個 domain 的每一個 class 上,都有較好的 performance。而更一進步,因為從多個不同的 domain 學習,我們希望模型也能夠泛化到 unseen domain,即 (Imbalanced) Domain Generalization,不平衡域泛化。我們首先提出了 domain-class transferability graph,用來刻畫不同 <domain, class>對 之間的可轉移性(相似程度)。我們發(fā)現,基于這種定義的 transferability,直接決定了模型在 MDLT 任務上的表現。在此基礎上,基于理論分析,我們提出了 BoDA,一個理論上能夠 upper-bound 住 transferability 統計量的損失函數,來提升模型在 MDLT 問題上的性能。我們基于流行的多域數據集,構建了五個新的 benchmark MDLT 數據集,并實現和對比了約 20種 涵蓋 DA,DG,imbalance 等不同的算法,發(fā)現 BoDA 能夠穩(wěn)定提升 MDLT 的性能。此外,更有意思的是,我們發(fā)現目前流行的 域泛化(domain generalization,DG)問題的數據集本質上也是不平衡的,這種不平衡貫穿于 (1) 同一個 domain 內部的標簽不平衡;(2)不同 domain 之間的不平衡標簽分布的不一致。這證實了數據不平衡是 DG 中的一個內在問題, 但被過去的工作所忽視。神奇的是,我們發(fā)現當和 DG 算法結合到一起,BoDA 能穩(wěn)定提升 DG 的表現,這也揭示了標簽不平衡會影響 out-of-distribution generalization,而實用魯棒的 DG 算法設計也需要整合標簽不平衡的重要性。
1. 研究背景與動機現實世界的數據經常表現出標簽不平衡 — 現實數據通常不會是每個類別都具有理想的均勻分布,而是本質上會呈現長尾分布,其中某些類別的觀測數據量明顯較少。為了應對這種現象,許多解決數據不平衡的方法被陸續(xù)提出;完整的現有不平衡學習方法調研歡迎查看:分類機器學習中,某一標簽占比太大(標簽稀疏),如何學習?(https://www.zhihu.com/question/372186043/answer/1501948720)但是,現有的從不平衡數據中學習的解決方案,主要考慮的是 single domain 的情況,也就是說所有樣本來自于同樣的 data distribution。然而,真實情況下,針對同一項任務的數據可以來自不同的域(domain)。例如下圖所示,Terra Incognita[1] 是一個實際采集的野生動物識別+分類的數據集。左邊子圖顯示的是在不同位置建立的 camera trap,以及拍到的野生動物樣例;而右圖則是(一部分)不同 camera location 拿到的具體數據分布以及其拍攝效果。我們可以明顯的看出,即使是同一個 wildlife 分類任務,不同 camera 的參數、拍攝背景、光照強度等也完全不同,即不同 camera trap 之間存在 domain gap。而由于某些動物只會出現在特定位置,這導致了一個 camera(domain)的數據是不平衡的,甚至沒有某一些類別的數據(例如 location 100 幾乎沒有類別0和1的數據)。但由于不同 camera 拍到的 label distribution 往往截然不同,這也暗示了其他 domain 很可能在這些類別有許多樣本 — 如 location 46 就有較多的類別1 的數據。這說明了我們可以利用多域數據來解決在每個域內固有的數據不平衡問題。Terra Incognita數據集樣例。同一個wildlife分類任務中,不同相機的參數、拍攝背景、光照強度等也完全不同;并且,同一個相機拿到的數據也是類別極度不平衡的;不僅如此,不同相機拍到的標簽分布也截然不同,往往是非常mismatch的。但這也說明我們可以利用多域數據來解決每個域內固有的數據不平衡。同樣,在其他實際應用中也會發(fā)生類似的情況。例如,在視覺識別問題中,來自“照片”圖像的少數類可以用來自“草圖”圖像的潛在豐富樣本來補充。同樣,在自動駕駛中,“現實”生活中的少數事故類別可以通過“模擬”中產生的事故來豐富。此外,在醫(yī)學診斷中,來自不同人群的數據可以相互增強,例如其中一個機構的少數樣本可以與其他機構的可能存在的多數實例相結合。在以上這些例子中,不同的數據類型可以充當不同的域,而這樣的多域數據也可以被有效的利用來解決數據不平衡問題。因此,在這項工作中,我們定義并研究多域長尾分布學習,Multi-Domain Long-Tailed Recognition(MDLT),即從來自多個不同域的不平衡數據中學習。具體來說,給定具有多個域的目標數據集,MDLT 旨在從來自多域的不平衡數據中學習,解決每個域內的標簽不平衡、不同域之前的不同標簽分布,并且最終模型能夠泛化到所有域的所有類別上。對于 MDLT 我們考慮用一個在每個域的每個類別上分布是平衡的測試集來測試模型的泛化能力,這樣能夠為 MDLT 提供非常全面并且沒有偏差的評估。這種 setting 也是對單域長尾識別問題的自然的推廣,與其 setting 保持一致。
多域長尾分布學習,Multi-Domain Long-Tailed Recognition(MDLT),即從來自多個不同域的不平衡數據中學習,解決每個域內的標簽不平衡、不同域之前的不同標簽分布,并泛化到所有域的所有類別上。
2. 多域長尾學習的難點與挑戰(zhàn)需要注意到的是,相比于單域的長尾識別問題,MDLT 帶來了以下全新的挑戰(zhàn)。(一)首先,每個域的標簽分布都可能與其他域不同(label distribution shift across domains)。例如,在上一個gif圖中,“照片” 和 “卡通” 域都表現出不平衡的標簽分布;然而,“卡通” 中的 “馬” 類比 “照片” 中的樣本多得多。因此,除了域內數據不平衡之外,這還帶來了跨域標簽分布不同的挑戰(zhàn)。(二)此外,多域數據必然會涉及到域之間存在偏差(domain shift)。簡單地將來自不同域的數據視為一個整體并應用傳統的數據不平衡方法不太可能產生最佳結果,因為域之間的 gap 可以任意大。例如在第一張圖中顯示的 wildlife camera traps,不同camera的參數、拍攝背景等往往差距很大,而模型設計上也需要考慮到這一點。(三)最后,與單域不同,在多域長尾學習中,某些域的某些類別可能就根本沒有數據。因此,MDLT 自然地包含了 域內 和 跨域 的零樣本泛化(zero-shot generalization within and across domains) — 即(1)泛化到域內缺失類(gif圖中 “草圖” 域的右側部分);以及(2)完全沒有訓練數據的新域,也通常稱為域泛化(Domain Generalization,DG)??偨Y上述的問題,我們可以看到MDLT相比與傳統的單域不平衡分類具有全新的難點與挑戰(zhàn)。那么,我們應該如何進行多域長尾學習呢?在接下來的兩節(jié),我們將從整體建模、motivating examples、觀察到的現象、理論推導,到最終損失函數的設計,來一步一步分析這個問題,并最終提升模型在MDLT任務上的表現。
3. Domain-Class Transferability Graph(域-類對可轉移性圖)這里我們首先提出了一系列定義,來對 MDLT 這個問題建模。在單域長尾識別問題中,我們通??紤]的 “最小單位” 是 一個類別(class),也即按照樣本數量不同分成 majority classes 和 minority classes。然而當拓展到多域情況,我們該如何定義這個 “最小單位”,從而能同時考慮到 domain shift 和 class imbalance 呢?我們提出,在 MDLT 下,這個基本單元自然而然地變成了一個 “域-類對”(domain-class pair)。那么當我們從 “域類對” 下手,我們則可以在 embedding space 上,通過定義不同域類對之間的距離,來定義其之間的可轉移性(相似程度):直觀地說,兩個域類對之間的可遷移性是它們特征之間的平均距離,表征它們在特征空間中的接近程度。距離函數 d 默認設置為 Euclidean distance(一階統計量),但也可以選用其他距離來度量高階統計量(例如用 Mahalanobis distance 也用到了 covariance)。那么自然而然地,基于 transferability 我們可以定義 transferability graph(可轉移性圖):在 Transferability graph 里,每一個 node 是一個 域類對,而每一條邊則是兩個域類對之間的 transferability。通過這種定義,我們可以直觀地將 transferability graph 可視化到一個二維平面。可轉移性圖的總體框架。(a) 為所有域類對計算分布統計量,由此我們生成一個完整的可轉移性矩陣。(b) 我們利用 MDS 將可轉移性圖投影到二維空間中進行可視化。(c) 我們定義 (α, β, γ) 可轉移性統計量以進一步描述整個可轉移性圖。具體而言,由上圖 (a)(b) 所示,對于每一個域類對,我們可以計算出屬于這個域類對的所有數據的特征統計量(mean,covariance等)。那么對于不同域類對,我們進一步計算兩兩之間的 transferability,由此我們生成一個完整的可轉移性圖,由矩陣形式表示(圖a)。之后我們可以使用多維縮放(MDS)[2] 在2D平面上可視化這種相似性以及其可轉移性圖(圖b)。在圖b中,我們可以看到不同domain用不同顏色來標記,每一個點代表一個域類對,其大小代表所含數據量多少,數字則代表具體類別;而他們之間的距離,則可以看作 transferability。顯而易見,我們希望相同的數字(即相同類別)的域類對更接近,而不同類別的域類對互相遠離;而這種關系,能夠更加被抽象化成三種可轉移性統計量:不同domain相同class( α ),相同domain不同class( β ),以及不同domain不同class( γ ):那么到此為止,我們?yōu)?MDLT 進行了建模和數學形式上的定義。接下來我們將進一步探索 transferability 和 最終MDLT performance的關系。
4. 什么是多域長尾學習上好的特征?4.1. 發(fā)現1:跨域不匹配的標簽分布會阻礙模型學到可轉移的特征我們首先發(fā)現:由于不平衡的存在,不同域上不同的標簽分布阻礙了模型學到可轉移的特征。Motivating Example:我們首先構建了一個小型 MDLT 數據集,Digits-MLT,是將兩個數字分類數據集合并到一起:(1) MNIST-M[3],一個彩色背景的MNIST手寫數字數據集,以及 (2) SVHN[4],一個街頭拍攝的數字數據集。這兩個數據集的任務是一致的,也即0~9的十個數字分類問題。我們手動改變每個域類對的樣本數量以模擬不同的標簽分布,并針對每種情況使用經驗風險最小化 (ERM) 訓練一個普通的 ResNet-18。我們保持所有測試集是平衡且相同的。改變 Digits-MLT 的標簽比例時可轉移性圖的演變模式。(a) 兩個域的標簽分布是平衡且相同的。(b) 兩個域的標簽分布不平衡但相同。(c) 兩個域的標簽分布不平衡且發(fā)散。上圖的結果揭示了有趣的觀察結果。當每個域的標簽分布平衡且跨域相同時,盡管存在域差距,但并不妨礙模型學習高精度(90.5%)的判別特征,如圖a 所示。如果標簽分布不平衡但保持相同(圖b),ERM 仍然能夠對齊兩個域中的相似類,其中多數類(例如類9)在可轉移性方面要好于少數類(例如類0)。相反,當標簽在域之間既不平衡又不匹配時,如圖c 所示,學習到的特征不再是可遷移的,這也導致了學到的特征在域之間存在明顯的gap,以及最差的準確率。這是因為跨域的不同標簽分布會產生shortcut;模型可以簡單地通過分離兩個域來最小化分類損失。這種現象表明,可轉移的特征是我們所需要的。上面的結果表明,模型需要學到跨域類對的可轉移的特征,尤其是在數據不平衡時。特別是,同一類跨域之間的transferability 應大于域內或跨域的不同類之間的transferability — 而這則可以通過 ( αβγ ) 可轉移性統計量來量化。
4.2. 發(fā)現2:轉移統計量刻畫了模型的泛化能力承接上文,我們說到模型需要可轉移的特征,而可轉移性統計量則可以幫助量化判斷模型的好壞。那么可轉移性統計量和 模型performance 的具體關系是什么呢?Motivating Example:同樣,我們使用具有不同標簽分布的 Digits-MLT。我們考慮三種不平衡類型來組成不同的標簽配置:(1)統一(即平衡標簽),(2)前向長尾,其中標簽在類別ID上表現出長尾分布,以及(3)后向長尾,其中標簽相對于類別ID 是反向長尾的。對于每種配置,我們訓練了 20 個具有不同超參數的 ERM 模型。然后我們計算每個模型的 ( αβγ ) 統計量,并繪制其分類準確度與 βγα 的關系。(β + γ) - α 統計量與 Digits-MLT 不同標簽配置的測試準確度之間的對應關系。每個子圖代表兩個域的特定標簽分布(例如,子圖a對域1 使用“Uniform”,對域2 使用“Uniform”)。圖中每個點對應于使用不同超參數使用 ERM 訓練的模型。上圖揭示了以下發(fā)現:
- ( αβγ ) 統計量表征了模型在 MDLT 中的性能。特別是, βγα 統計量在整個范圍上和每個標簽配置的測試性能均顯示出非常強的相關性。
- 數據不平衡會增加學到不可遷移特征的風險。當跨域的標簽分布一致且平衡時(圖a),模型對變化的參數具有魯棒性,在右上區(qū)域聚集。然而,隨著標簽變得不平衡(圖b、c)和進一步發(fā)散(圖d、e),模型學習不可遷移特征(即較低的 βγα)的機會增加,導致性能大幅下降。
我們利用上述發(fā)現設計了一個特別適合 MDLT 的新損失函數。我們將首先介紹損失函數,然后理論證明它最小化了 ( αβγ ) 統計量的上限。我們從一個受度量學習目標啟發(fā)的簡單損失開始,并稱這種損失為 ,因為它旨在實現域類分布的對齊,即跨域對齊同一類的特征:直觀來看, 解決了標簽分布跨域不匹配的問題,因為共享同一類的域類對將被拉得更近,反之亦然。它還與 ( αβγ ) 統計有關,因為分子表示正跨域對 ( α ),分母表示負跨域對 ( βγ)。但是,它并沒有解決標簽不平衡問題。我們注意到( αβγ )是以平衡的方式定義的,與每個域類對中的樣本數無關。然而,給定一個不平衡的數據集,大多數樣本將來自多數域類對,這將主導 并導致少數域類對被忽略。BoDA loss:為了應對上述問題,我們進一步修改公式1,得到 Balanced Domain-Class Distribution Alignment (BoDA) loss —可以發(fā)現,BoDA 將原始的距離函數 d 縮放了 的因子,其中 是域類對 的樣本數量。即,BoDA 通過引入平衡的距離度量 來抵消不平衡域類對的影響。而對于 ,我們證明了以下定理:具體的證明細節(jié)請詳見我們文章。定理1有如下的有趣的含義:
- 是 ( αβγ ) 統計量的一種理想的形式的upper-bound。通過最小化 ,我們確保了低 α (吸引相同的類)和高 β、γ (分離不同的類),這是 MDLT 中泛化的必要條件,自然轉化為更好的性能。
- 統計量中的常數因子對應于每個部分對可遷移性圖的貢獻程度。我們注意到在 里,目標與 αβγ 成正比。根據定義3,我們注意到 總結了同一類的數據相似性,而 $ \frac1}{D|β + \fracD|?1D|\gamma使用 β和γ$ 的加權平均值總結了不同類的數據相似性,其中它們的權重與相關域的數量成正比(即, β 為 1, γ 為 )。
BoDA 的工作原理是鼓勵跨域的相似類的特征遷移,即如果 (d, c) 和 (d' , c) 是不同域中的同一類,那么我們希望將它們的特征是相互遷移的。但是,由于數據不平衡,少數域類對的統計量估計值自然會更差,而這種情況下迫使其他對轉移到它們會損害模型的學習過程。因此,當在特征空間中使兩個域類對更接近時,我們希望少數域類對轉移到多數,而反過來則不是。這里細節(jié)較多,就直接跳過了,我們的 paper 中給出了詳細的 motivating example 和 interpretation。結論是,可以通過在 BoDA的基礎上加上一個 Calibration 項,由兩個域類對的相對樣本數量來實現轉移程度的控制:
5. 基準MDLT數據集及實驗分析基準MDLT數據集:終于來到了激動人心的實驗部分 ;) 為了方便對不平衡算法進行標準的測試,以及方便未來的research工作,我們在現有的multi-domain數據集基礎上,建立了五個MDLT的基準數據集。具體來說,我們使用的是域泛化的基準數據集[5],并將它們用于 MDLT 評估。為此,我們?yōu)槊總€數據集創(chuàng)建兩個平衡的數據集,一個用于驗證,另一個用于測試,其余的用于訓練。驗證和測試數據集的大小分別約為原始數據的 5% 和 10%。這些數據集的訓練數據分布如下圖所示:此外,我們選取了近20種算法,涵蓋了 multi-domain learning,distributionally robust optimization,invariant feature learning,meta-learning,imbalanced learning 等各種類別作為基線方法比較,并對每種算法優(yōu)化了超參數。這樣的過程確保了比較是最佳與最佳的(best-vs-best),并且超參數針對所有算法進行了優(yōu)化。在評估過程中,除了跨域的平均準確率外,我們還報告了所有域的最差準確率,并將所有域類對進一步劃分為幾個不相交的子集:稱為many-shot(訓練樣本超過 100 個的),medium-shot(20~100 個訓練樣本的),few-shot(訓練樣本少于 20 個的),還有zero-shot(沒有訓練數據的),并報告這些子集的結果。具體詳見我們的文章。實驗:由于實驗較多,這里僅展示在所有數據集上的合并結果,所有的結果請詳見論文。如下圖所示,BoDA(及其變種)在所有數據集上始終保持最佳平均準確度。在大多數情況下,它還可以達到最佳的最壞情況精度。此外,在某些數據集(如OfficeHome-MLT)上,MDL 方法表現更好(如CORAL),而在其他數據集(如TerraInc-MLT)上,不平衡方法獲得更高的收益(如CRT);盡管如此,無論數據集如何,BoDA 都優(yōu)于所有方法,突出了其對 MDLT 任務的有效性。最后,與 ERM 相比,BoDA 略微提高了平均和many-shot的性能,同時大幅提升了medium-shot、few-shot和zero-shot的性能。實驗分析之 BoDA 學到了怎樣的可轉移性圖:我們進一步來對提出的方法做一些進一步的分析。我們繪制了通過BoDA學到的可轉移性圖,并在不同跨域標簽分布下與 ERM 進性對比。從下圖可以發(fā)現,BoDA 學習到了更加平衡的特征空間,將不同的類別分開。當標簽分布是平衡且一致時,ERM 和 BoDA 都能學到好的特征;而當標簽開始不平衡(b,c),甚至跨域不匹配(d,e)時,ERM 的可轉移性圖出現了明顯的 domain gap;與之對應,BoDA 則能一直學到平衡且對齊的特征空間。更好的學習特征便轉化為更好的準確度(9.5% 的絕對準確度增益)。
6. MDLT 更進一步:不平衡域泛化問題域泛化(DG)是指從多個域中學習并泛化到未見過的域。由于學習域的標簽分布很可能不同,甚至可能在每個域內都存在類不平衡,因此我們研究解決跨域數據不平衡是否可以進一步增強 DG 的性能?;叵胛覀?yōu)?MDLT 建立的所有數據集都是 DG 的基準數據集,這證實了數據不平衡是 DG 的一個內在問題,但過去的工作卻忽略了這一點。我們研究 BoDA 是否可以提高 DG 的性能。為了測試 BoDA,我們遵循標準的 DG 評估協議 [5]。通過上表,我們發(fā)現僅 BoDA 就可以在五個數據集中的四個上提升當前的結果,并實現顯著的平均性能提升。此外,結合當前的SOTA,BoDA 進一步將所有數據集的結果顯著提升,這表明標簽不平衡與現有的 DG 特定算法是正交的。最后,與 MDLT 類似,增益取決于數據集內不平衡的嚴重程度——例如,TerraInc 表現出跨域最嚴重的標簽不平衡,而 BoDA 在其上獲得最高增益。這些有趣的結果揭示了標簽不平衡如何影響域泛化,并強調了整合標簽不平衡對于實際 DG 算法設計的重要性。
7. 結語最后總結一下本文,我們提出了一個新的任務,稱為多域長尾分布學習(MDLT),同時我們系統性地研究了MDLT,并提出了有理論保障的新損失函數 BoDA,以解決多域的學習不平衡數據的問題,最后我們建立了五個新的benchmark來方便未來在多域不平衡數據上的研究。本文有很直觀的問題分析與解釋,理論證明,以及用非常簡潔并且通用的框架去提升多域下的不平衡學習任務。此外,我們發(fā)現標簽不平衡會影響 out-of-distribution generalization,而實用魯棒的DG算法設計也需要整合標簽不平衡的重要性。
參考- Recognition in Terra Incognita. ECCV, 2018.
- Multidimensional scaling. Measurement, judgment and decision making, pages 179–250, 1998.
- Domain-adversarial training of neural networks. Journal of machine learning research, 17(1):2096–2030, 2016.
- Reading digits in natural images with unsupervised feature learning. NIPS Workshop on Deep Learning and Unsupervised Feature Learning, 2011.
- In search of lost domain generalization. In ICLR, 2021.
- Delving into Deep Imbalanced Regression. ICML, 2021.
本文僅做學術分享,如有侵權,請聯系刪文。
*博客內容為網友個人發(fā)布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。