對比學習 Contrastive Learning 主流方法一覽

SimCLR, MoCo, SwAV, BYOL, CLIP, DeepCluster, PIRL, Barlow Twins,一覽各種模型的底層邏輯

Rice Yang
18 min readNov 24, 2022
圖片來源:www.redbubble.com

Contrastive Learning 對比學習 最近在學術界與產業界都受到了極大的關注,並且擁有非常多的應用場景。之前在文章《10 分鐘搭建萬物識別 Live Demo》中也稍微介紹過,應用對比學習可以訓練出到非常具有泛化性 Generalization 的預訓練模型或特徵表述,加速終端任務的開發。目前所熟知的各種影音、電商、社交平台,背後無一沒有對比學習的技術。如果今天去面試相關工作,也一定會被問到對比學習相關的技術問題。

本文會試著從對比學習的定義開始,自上而下做一個完整的分層,並試著把近期在對比學習領域中比較有名的方法在一個大的框架下整理出來,比較各種方法的異同。

文章目錄

  • Contrastive Learning 概念
  • 對比學習五大要素
  • Parallel Augmentation: End-to-End, Momentum Encoder
  • Architecture: End-to-End, Memory Bank, MoCo, Clustering
  • Loss Function: NCE, InfoNCE
  • Data-Model: CLIP
  • 主流方法一覽
  • Contrastive Learning v.s. Self-Supervised Learning

Contrastive Learning 概念

Contrastive Learning (CL) 與 VAE、GAN 類似,都是可以作為 Unsupervised Learning 無監督學習方法的學習方式。但用在無標籤數據時,通常我們會將對比學習歸類為自監督學習 Self-Supervised Learning,而不是無監督學習。

如果說 VAE 的概念是「自己學自己」,那麼對比學習的概念就是「自己只像自己」。

從人臉識別開始

比較容易理解對比學習的是 2015 年的人臉辨識模型 FaceNet。在此之前引用傳統監督學習來做人臉辨識時,模型只會要求同一個人的不同照片要保持高相似度;而 FaceNet 創造了 Triplet Loss ,增加要求不同的人的不同照片也要保持低相似度。

在傳統的人臉識別中,模型只會要求兩張歐巴馬的照片要足夠相似,意即在右側的 2D 空間投影要足夠近。而在 FaceNet 中,還會額外要求與其他所有照片都要保持低相似度,例如川普。圖片來源:Building a dog search engine with FaceNet

這時候對比學習還沒大行其道,因此論文中不會出現任何對比學習相關的名詞。但多年以後,我們發現 FaceNet 與對比學習一樣,都關注於「最大化類內相似、最小化類間相似」。而在對比學習與 FaceNet 最大的不同,在於對比學習的類內數據 (in-class data)不是收集來的,而是來自於數據擴增 Data Augmentation

數據擴增 Data Augmentation

數據擴增可以將一筆數據擴增成 N 筆數據,同時不影響這筆數據的意義。例如下圖,這些方法可以把一張圖片變成 N 張,同時不改變圖片的語意:

列舉幾種數據擴增的方法。圖片來源:SimCLR 論文

通常實作時,我們會用預先設置好的機率值來觸發某一種數據擴增方法。例如 SimCLR 的 GitHub 裡面,當隨機數 < 機率 p 的時候就會觸發數據擴增方法 func

數據擴增可以讓我們增加數據的變化性,免費得到幾倍以上的數據量,對最終模型的效果也有很重要的幫助。有一些決定自動化數據擴增策略的方法可以引用,例如 AutoArgument, RandArgument 等。PyTorch 裡面也有已經包好的自動化方法方法 (Autumatic Augmentation Transforms),可以讓你直接使用。

從數據擴增到對比學習

如果我們每個類別都只有一條無標籤數據 (unlabled data),我們是否可以用數據擴增的方法來把無監督數據轉變成監督學習的任務呢?

答案是可以的。假設我們有 N 條數據,每條數據擴增 K 次,那我們就有 N×K 條數據,總共有 N 類 (class),每一類有 K 筆數據。這就是對比學習 CL 的基礎思想。

在對比學習的架構中,不同的數據永遠會被定義成不同的類。這可能衍生出錯誤的標籤,例如兩張歐巴馬的照片會被對比學習定義成不同的人。這是對比學習不可避免的問題,因為如果你手工修正了這兩張圖是同一人,那就不是對比學習,而是監督學習。因此對比學習往往需要大量的數據,來降低這類天然的標籤錯誤問題。

你可能會好奇,一張圖作為一個樣本有什麼意義?如果你用大數據做訓練,訓練出來的模型可以有非常好的特徵提取 (feature extraction) 能力,可以很好的描述並區別不同的東西,達到 Zero-Shot, Few-Shot Learning 的效果。以人臉識別舉例,即使拜登沒有出現在訓練集中,訓練出的模型還是可以提出拜登的特徵,並利用相似性比對來辨識拜登。這個性質稱為泛化性 Generalization

現在來定義一下 Contrastive Learning:

Contrastive Learning 旨在學習一種高泛化性的通用特徵提取方式,所提取的特徵對於相似的樣本有較高的相似度,而相異的樣本有更低的相似度。

對比學習的 5 大要素

經過我整理了一些論文,對比學習方法由 5 種要素組成。只要在要素之間排列組合,就可以變化出不同的對比學習方法:

  1. Data Augmentation: 這是對比學習的基礎成分。提出新的數據擴增方法也很有學術價值,例如 MixUp, CutMix
  2. Parallel Augmentation: 指的是如何運用 noise data (擴增數據) 提取特徵。主要有兩種方式:端到端 End-to-End 與 動量編碼器 Momentum Encoder
  3. Architecture: 指的是如何操作正負樣本來計算損失函數的架構設計。主要可分 4類:End-to-End, Memory Bank, MoCo, Clustering
  4. Loss Function: 損失函數的設計。最經典的可能是 InfoNCE
  5. Data-Modal: 不同模態的數據會衍生不同的對比學習方法。

Parallel Augmentation

在對比學習中,我們的每一類數據都會有原始數據 (anchor) 以及擴增數據 (noise)。Parallel Augmentation 關注於如何從 anchor/noise 數據提取特徵,但最主要的還是配合後面的 Architecture 做相應設計。

End-to-End Encoder 端到端編碼器

指同一個 batch 中,所有 anchor/noise data 的特徵提取都會使用相同參數、相同結構的 Encoder。其實與其說 End-to-End,孿生網路 Siamese Network 可能是更適合的名詞。

在 End-to-End 中,最純粹的方法就是 SimCLR。他會從 1 個 anchor data 生成 2 個 noise data,然後進行排列組合比對。考慮 Batch size = N,最終會產生 2N×N 個樣本對 data pair,包含 2N 個正樣本對 positive pair2N×(N-1) 個副樣本對 negative pair

SimCLR的學習流程圖。圖片來源:google-research/simclr

SimCLR 的架構中,一個 batch 中所有圖片使用的 CNN 都是相同權重的。同樣的,無論是正樣本對還是負樣本對,都會計算 loss 並且向後傳播更新 CNN 權重。

Momentum Encoder 動量編碼器

Momentum 在物理中代表動量,是用來描述慣性一個概念;在機器學習中,動量常常用來借指移動平均 moving-average 的概念。論文 BYOL 較能簡單地闡述 Momentum Encoder 的概念,如下圖。值得特別注意的是,BYOL 計算 loss 的時候只用上了 positive pair,沒有 negative pair

補充說明,圖中的 sg 代表 stop-gradient,代表梯度更新不會傳導到參數 ξ。圖片來源:BYOL Paper

我們可以用函數 f, g, q來表示神經網路,而用下標來表示網路的參數。在上圖的 online 流程中都是使用參數 θ 的神經網路,而下面的 target 流程則是參數 ξ。當 θξ 的時候,這其實就是 end-to-end 網路,而當 ξθ 的移動平均數的時候,f( · ; ξ) 就是一個 Momentum Encoder

所以 Momentum Encoder 指的就是由 End-to-End Encoder 的歷史權重做移動平均而求得獲得的編碼器。這個概念在 Batch Normalization 裡面也有,裡面的超參數 mean, variance 也是透過歷史的 mean, variance 的平均移動得到的。

使用 Momentum Encoder 的好處在於避免 over-fitting,因為這使得正樣本的差異除了來自於 Data Augmentation 外,也可能來自 Momentum Encoder,而這兩者都是不可由梯度傳播直接消除的。在 BYOL 裡面間接證實了 Momentum Encoder 避免 over-fitting 的效果卓越,甚至可以在損失函數中忽略 negative pair 也不會出現問題。

Architecture 對比學習架構

這裡的架構不是只 ResNet, VGG 等網路基礎架構,而是如何在正負樣本的特徵中設計對比學習的方法。

End-to-End 端到端

End-to-End 是最簡單的方法,就是直接比對 Batch 之中的正負樣本。除了 SimCLR 之外,最直觀的方法是 Barlow Twins,該方法直接將總共 N×N 個樣本相似度排列成相似度矩陣 cross-correlation matrix,並且直觀的將該矩陣求解為一個單位矩陣 (Identity Matrix)。

Barlow Twins 訓練流程。圖片來源:Barlow Twins Paper

SimCLR, Barlow Twins 等基於 End-to-End 的架構中其實都有一個問題,就是每次的迭代中 negative class 太少了,與 batch size 是 1:1 的關係。因此 SimCLR 的訓練 batch size 必須要非常大,否則效果就會下降。

最理想的狀況是在 loss 計算的時候考慮所有 negative pair,也就是 batch size = class size = data size,但這是不可能的。如果我們的數據量都有百萬級,在對比學習中就會有百萬個類別,在一次迭代中要計算所有的類別相似度就需要對所有 data 提取特徵 — — 這太耗時了。

Memory Bank

Memory Bank 就是解決這個問題的簡單設計。通常我們計算相似度不需要原始數據,只需要提取出來的特徵就行。因此我們可以把所有樣本的特徵都建立一個 cache table,有需要的時候直接用查表法取出就行,就可以避免重新計算、提取特徵。這個 cache table 就稱作 Memory Bank

PIRL 是近期使用 Memory Bank 比較純粹且成功的方法。在 Memory Bank 中會為所有數據都暫存一份特徵,並且計算 loss 時,所有的 negative sample 都取自於 Memory Bank。值得注意的是,這裡的 Memory Bank 緩存的特徵 (representation) 也是使用 moving-averaged representation,不是直接更新。

圖片來源:PIRL Paper

MoCo: Moving Contrast

MoCoMemory BankMomentum Encoder 融合後的一種變形方法。Memory Bank 的問題是其數量級還是太大了,在記憶體與計算量上都有不少的消耗。MoCo 利用一個大小有限的 FIFO queue 來暫存過往的特徵 (論文中稱為 key), 並且將 queue 裡面的所有數據都當作 negative data

這個方法將 negative pair 的數量級從樣本大小降低到 queue 大小,節省了大量的空間與計算力;但該方法的確比 Memory Bank 降低了一些精度,在一些 benchmark 上,MoCo 的效果都略遜於只用 Memory BankPIRL

MoCo (Momentum Contrast) 的基礎慨念。圖片來源:MoCo Paper

Clustering 分群

早在 20 年代,Clustering 就是用於處理無監督學習的代表性方法,因此也有人引入到對比學習之中,試著利用分群來解決負類別過多的問題。

DeepCluster 流程。圖片來源:DeepCluster Paper

DeepCluster 直觀的解釋了分群如何應用在深度學習中。該方法藉由設計一個額外的分群任務,在特徵空間中求解 K-means 問題,給樣本打上不同的虛擬標籤 psudo-label,再用 pseudo-label 進行傳統的監督學習。

SwAV 與一般對比學習的差異。圖片來源:SwAV Paper

SwAV 是目前 Clustering 架構中最具代表性的方法,如上圖。SwAV 引入了 Prototypes (C) 來儲存每一個聚類中心的特徵向量。論文中還使用了 Codes (Q) 代表抽取出的特徵 z 與每一個 Prototype 的相似度,以作為對比學習計算損失用的向量。需要注意,SwAVBOYL 一樣,都是只考慮 positive pair 的對比學習設計。

理論上將 C 乘以每一個數據特徵 z 之後就可以作為 Q 來使用,但是這種設計很容易讓 C 的每一個 prototypes 都學成相同的數值。因為 SwAV 是只考慮 positive pair 的模型設計,因此機器學習會將所有特徵都學習為同一種 Q,類似於 GAN 網路中的模式崩壞問題 (Mode Collapse)。因此論文使用線性代數方法 Sinkhorn-Knopp Algorithm 來求解 Q,而不是交由機器學習常用的梯度更新求解器,避免模式崩壞。

簡而言之,SwAV 的損失函數負責解決類內相似最大化的問題,而類間相似最小化則交給 Sinkhorn-Knopp Algorithm 保證。

除此之外,SwAV 還引入了數據擴增設計 Multi-crop,在低解析度圖片上做更多的數據擴增,來增加訓練的數據量並提升效果。

Loss Function 損失函數

在對比學習中,傳統的損失函數理論上都可以引入進來用。但其中有 2 個比較有名的函數需要特別注意 —— NCE、InfoNCE。

NCE: Noise-Contrastive Estimation

NCE 是在 2010 年提出的方法,比較有歷史了,但是其重要性在對比學習中依舊。在計算樣本 x 的 NCE loss 時,我們會取一個正樣本 x+ 與一群負樣本 X-,並且期望 xx+ 足夠相似,且與所有 X- 足夠不同。網路上 NCE 的公式都滿不直覺的,我認為最直觀的公式應該寫成如下:

第一行的函數 h 是個機率估計函數,它的作用是引入 softmax 公式來計算機率。第二行則是一個 Binary Cross Entropy (BCE) loss 的標準寫法。所以簡而言之,NCE 就是用 一對正樣本一群負樣本,透過 相似性計算機率 所寫成的 BCE 損失

最後補充一下,通常相似度函數可以寫成特徵向量的 cosine 相似度,通常會再除以一個溫度參數 temperature

InfoNCE

InfoNCE 是在 2018 年的對比學習論文 CPC 中提出來,基於 NCE 的變化版本。InfoNCE 可以當作 NCECross Entropy 版本:我們會取一個正樣本 x+ 與全部樣本 X,並用 Cross Entropy 的方式計算損失:

在論文《Understanding the Behaviour of Contrastive Loss》中,指出 InfoNCE 具有更關注困難樣本的性質,屬於 Hardness-aware loss。在大多數的研究都指出,關注於困難樣本學習可以獲得更好的效果。

Data-Modal 數據模態

本文中主要的方法都是基於純圖片,因此高度仰賴數據擴增進行對比學習。但如果數據包含兩個模態以上,就可以在不同的模態之間做對比學習。例如 CLIP。值得注意的是,CLIP 在跨模態相似度空間上使用 Cross Entropy 作為損失函數,本質上其實就是 IntoNCE loss

圖片來源:CLIP Paper

更多的介紹可以參考之前的文章《10 分鐘搭建萬物識別 Live Demo》。

主流方法一覽

下圖把本篇文章中所有提到的方法整理成一個表格,大家可以從更廣的角度來看所有方法之間的異同。當然,這些方法並不是全部。受限於本人的知識量,還有很多對比學習的方法並沒有列在上面。

Contrastive Learning 各種主流模型的異同。Loss Function 中的加減號代表是否使用正樣本對、負樣本對。圖片來源:自製

除了技術架構外,論文《A Survey on Contrastive Self-Supervised Learning》也整理了主流對比學習技術的 參數量-精度 關係,如下圖。圖中可以看出 SwAV 的效果已經接近於監督學習。

各種對比學習方法在 ImageNet Top-1 Accuracy 的對比。圖片來源:A Survey on Contrastive Self-Supervised Learning

注意,SimCLR 的效果好是因為使用了超大 batch size = 8192 進行訓練,會需要大量的運算與記憶體資源。如果降低 batch size 的話,end-to-end 類型的方法會因為 negative pair 不夠,效果大幅降低。如果你想要重現 SimCLR 的實驗結果,硬體不夠的話是絕對不行的。

對比學習 v.s. 自監督學習

對比學習是自監督學習 Self-Supervised Learning (SSL) 裡面效果較好的方法,可以幫助我們在擁有海量數據並缺少標注的情況下,能夠訓練出很好的特徵提取網路。提取出的特徵可以直接用於做相似度比對,例如人臉辨識 Face Recognition、行人重識別 Person Re-Identification。也可以用學到的 encoder 模型來做遷移學習 transfer learning,做少樣本學習 Few-Shot Learning 或是零樣本學習 Zero-Shot Learning。

反過來說,SSL 除了對比學習之外,還有很多種基於數據本身的學習方式。如果以為「SSL = 對比學習」的話,那就大錯特錯了。

--

--

Rice Yang
Rice Yang

Written by Rice Yang

A Senior Engineer in AI. Experienced in NVIDIA, Alibaba, Pony.ai. Familiar with Deep Learning, Computer Vision, Software Engineering, Autonomous Driving