在實務中用生成網路 VAE 做半監督學習的原理與技巧

詳解 NIPS 論文:Semi-supervised Learning with Deep Generative Models

Rice Yang
11 min readOct 5, 2022
圖片來源:Deep learning in chemistry free wallpaper and backgrounds

這篇主要介紹 NIPS 的論文《Semi-supervised Learning with Deep Generative Models》,屬於用生成網路來輔助分類任務,完成半監督學習的開山之作。在我的工作中也有實際的應用到這篇論文,也取得了不錯的效果。

這篇文章雖然經典,但是論文裡面沒有任何圖片,只有數學,所以非常難看懂。我後來看到了原作者所發表 GitHub 文章,比原文好懂很多,於是想用這篇文章借用作者的圖解與概念來說明這篇論文的中心思想,並且在最後加上我在實務上的使用技巧與心得。

對於還不理解什麼是半監督學習與半監督學習的讀者,可以先看看以前的文章:

VAE: Variational Autoencoder

VAE 屬於生成網路的一個主要分支,也是無監督學習 Unsupervised Learning 的方法之一。VAE 的主要精髓思想就是 自己學自己——把訓練數據轉換到較低維度的特徵,再從低維度特徵還原回原始數據。

圖解 VAE。圖片來源:From Autoencoder to Beta-VAE | Lil’Log

VAE 雖然在預測、數據挖掘方面有大量實際應用,例如異常檢測;但在影像識別方面配合半監督學習的卻應用很少。主要有幾個原因:

  1. 複雜網路不適合作為 Encoder:ResNet, MobileNet 等電腦視覺任務表現優秀的架構,作為 VAE Encoder 卻常常無法讓 Decoder 順利解碼。
  2. 影像的維度太大:VAE 對影像用 CNN 做編碼的時候,泛化性 Generalization 不足,容易過擬合 overfitting
  3. 需要前處理:承先前 2 點,如果還是要對影像做 VAE ,需要做相當多的圖片對齊等預處理工作,降低 VAE/CNN 的學習難度。

半監督學習的 2 種思路

半監督學習 Semi-Supervised Learning 主流上有兩種思路。第一種類似之前介紹過半監督學習的經典策略 Self-Training,這個方法假設數據符合 Low-Density Separation(LDS) ,也就是類別之間的間距必須夠大的時候,就能夠透過循環打上虛擬標籤 psudo label 的範式來進行訓練。但是,不符合 LDS 假設的數據就不能使用這種 psudo label 的思路。

另外一個主流思路是 無監督學習的特徵提取+監督學習的任務訓練。前面段落提到的 VAE 特徵提取就屬於這種思路。但這種思路的缺點是「無監督學習的特徵提取」對於圖像這種高維度數據是十分難做好的,直到近年自監督 Self-Supervised Learning 出現後才有比較好的改善。

本篇論文屬於比較傾向於第二種思路,但是不是使用前後處理的 串聯 設計,而是使用聯合訓練的 並聯 設計。

核心思想:有監督與無監督的聯合訓練

在生成網路 Generative Models 的大類中,有些方法允許模型輸入額外的標籤用以控制輸出的數據,例如可以控制性別、年齡的人臉生成器。這種額外的輸入標籤稱為 Condition,引入 Condition 的方法通常被稱為 Conditional Generative Models,例如 Conditional GAN

理解了 Conditional Generative Models 後,一個有趣的想法來了:

如果我今天不知道某個數據的標籤,那我是否可以找到一個 Conditional Generative Models,然後看看哪個標籤還原回來的數據最像呢?

從邏輯上推論,這個方法似乎是可行的。我們可以先拿一些有標籤的數據訓練一個 Conditional Generative Models,再利用這個生成模型來反過來「驗算」我們對無標籤數據的預測標籤是否正確。

這就是這篇論文的最核心思想了。

這篇論文提出了三種模型結構:M1M2M1+M2。其中 M1+M2 我認為只是一種訓練的技巧,不需要太多的介紹。下面主要從 M1 切入,並說明文章重點 M2

M1:VAE + Classifier

M1 是屬於傳統的 VAE 特徵提取+分類網路的架構。我們使用 VAE 訓練一個 encoder + decoder 架構,然後用裡面的 encoder 作為特徵提取網路,再去訓練一個 MLP 或是 SVM 的分類器。

M1 架構。圖片來源:Semi-supervised Learning with Variational Autoencoders

M1 包含 VAE 的損失以及分類任務本身的任務損失。其中 VAE 損失包含兩部分:

  1. Reconstruction loss 重建損失:確保 VAE 能夠從特徵重建回原圖。這部分讓 encoder 網路提取的特徵足夠有表述性。
  2. Latent variable loss:確保提取的特徵符合常態分佈。這部分是用來確保VAE 不要 overfitting。

訓練 VAE 屬於無監督學習,可以在沒有標註的情況下訓練特徵提取能力。後續的分類器損失就是單純的分類任務損失,使用有監督的方式訓練。整體加起來屬於標準的學習模式:「無監督的特徵提取+有監督的訓練任務」。

M2:Conditional VAE + Classifier

M2 是本篇文章的核心思想。我們把分類任務的標籤 label 作為 latent variable 的一部份串接 (concat) 到 VAE 的中間特徵。對於有標注的數據,可以直接把標注的 label 拿來使用;對於無標注的數據,可以利用分類器預測一個 label 來使用。

在論文的實驗結果中,M2 架構可以比 M1 得到更好的結果。

M2 架構。圖片來源:Semi-supervised Learning with Variational Autoencoders

看著上述的架構圖,我們有 2 個解釋來幫助我們理解與實作這個架構。

觀點 1:梯度傳播的半監督學習

從梯度傳播的觀點來看,分類器 classifier 的更新梯度不只能來自於有標籤數據的任務損失,也可以透過左邊外接的 VAE 來產生。這是因為 VAE 自身無監督學習任務產生的損失梯度能夠透過 latent variable 層傳播到 classifier,達到更新權重的效果。

這個解釋雖然直觀,在 TensorFlow、Torch 等框架也是可實作的,但這不是論文原作者的實作邏輯。論文使用基於機率論的數學推導來完成半監督學習的損失函數。

觀點 2:基於機率論推導的半監督學習

作者基於機率論,在最大化生成圖片的機率分佈的前提下,推導出了下面這個 loss function,整個半監督學習任務的目標就是 最小化 這個函數:

損失函數拆解。公式部分來自:Semi-supervised Learning with Variational Autoencoders

其中 q(y|x) 代表 classifier 對標籤 y 符合圖像 x 的估計機率值。圖中綠色的部份表示 VAE 的損失,橘色的部份表示 classifier 的估計結果 q(y|x)上圖最需要注意的是淺藍色的 Maximum Entropy Distribution 這個項次。這個項次的輸入 q(y|x) 雖然與最右邊的 classification loss 是相同的,但是這兩個項次是完全相反的作用,會讓 q(y|x) 往兩個不同的方向拉扯。

  • Maximum Entropy Distribution:如圖下方所示,是一個向下開口的函數。在損失函數中又加上了負號,在最小化任務中會使 q(y|x) 往 0.5 靠攏,也就是傾向於任何標籤都給出低信心的估計機率。
  • Classification Loss:這是一個標準的 cross entropy loss,傾向讓 q(y|x) 給正確的標籤趨近於 1 的估計機率,而給其他標籤趨近 0 的估計值。

VAE, Maximum Entropy Distribution, Classification Loss 這 3個項次的交互作用下,Classifier 會在有標注的訓練數據上給出高信心的預測標籤,在無標注的訓練數據上找出讓 VAE 重建還原度最高的最佳標籤,否則就盡量給出中性的估計信心值。

更詳細的數學推導,建議看作者的 GitHub 原文,比論文詳細得多:

實務經驗

我在工作中實現過幾次該論文中的 M2 架構,也得到了不錯的成果。以下列出幾條經驗分享,如果你也準備嘗試這個方法的話,不仿先看一下能不能幫助到你。

數據類型是否適合

不得不說,使用生成網路類的訓練方式還是對數據有要求的。以下這幾種數據可能比較適合該任務:

  1. 低維度的數據,如工業機器監控數據、手工特徵數據。
  2. 簡單的圖片數據:例如 MNIST、CIFAR 數據集。
  3. 高對齊的圖片數據:例如人臉圖片。

架構優勢

這篇文章雖然問世很久了,但直至今日還是有些優勢讓你可以選用這個架構做為你的半監督學習方案。

  1. 支援任何 classifier 架構。因為 M2 網路使用外接的方式到 VAE 中,所以對於 classifier 是沒有任何要求的。如果你對網路有輕量化、essemble、大網路(或是炫砲網路)等設計需求,理論上都可以嵌到 M2 網路裡面。
  2. 容易新增損失函數。例如對抗損失、多任務學習、focal loss 等等額外的設計想法都可以很簡單的加進來。
  3. 平行計算訓練。因為 VAE 本身不需要 mini-batch 之間的數據溝通,相對於 Contractive Learning 對比學習類的架構,這個方法在平行計算上會更有優勢。

架構劣勢

雖然實務上我從這個架構拿到過不錯的成果,但是其中也碰到了一些麻煩,讓我在架構選型上會想避開這個架構。

  1. Hyper-parameters 超參數增加。包括 latent variable 的維度、各項 loss 之前的調整權重、VAE 的學習率等等,都會增加更多需要調整的超參數,增加工程師調參的難度。
  2. 數據要求過硬。如果數據類型不符合前面提到的假設,使用這個方法很可能怎麼調整都沒有辦法提升效果。又如果你的數據符合文章開頭提到的 LDS 假設的話,應該直接使用 psudo label 類的半監督學習架構。

其他小技巧

  1. 先分開訓練 VAE 與 classifier,再做聯合訓練。
  2. 可以對分類器 classifier 做數據擴增(無論是分開訓或是聯合訓)。
  3. 不要對 VAE 的輸入做數據擴增(無論是分開訓或是聯合訓)。
  4. 預測結果接入 latent variables 的時候,可以用些技巧對分類器的預測結果做銳化或平滑化。例如 Gumbel-Softmax
  5. M2 對於 dirty label 有不錯的容忍能力。如果你的數據標籤品質不是很高,不仿加入無監督數據接入 M2 ,可以抵消 dirty label 的負面影響。

--

--

Rice Yang
Rice Yang

Written by Rice Yang

A Tech Manager in AI. Experienced in NVIDIA, Alibaba, Pony.ai. Familiar with Deep Learning, Computer Vision, Software Engineering, Autonomous Driving

No responses yet