๐ผ๏ธ [๋ ผ๋ฌธ๋ฆฌ๋ทฐ] FixMatch: ์ ์ label์๋ ์ฑ๋ฅ์ ์ฌ๋ฆฌ๊ธฐ ์ํ ๊ธฐ๋ฒ
ํ์ฌ ํ์ฌ์์ label์ ์ ๋ขฐ์ฑ์ด ๋ฎ์ ๋ฐ์ดํฐ๋ฅผ ๋ค๋ฃจ๊ณ ์์ต๋๋ค. ์ด๋ฅผ ์์์ ์ผ๋ก ๊ฑฐ๋ฅผ ์๋ ์์ด ์ฌ๋์ ๋ฆฌ์์ค๊ฐ ์ ์ผ๋ฉด์๋ ์ฑ๋ฅ์ ์ฌ๋ฆฌ๊ธฐ ์ํ ๋ฐฉ๋ฒ์ ๋ํด ๊ณ ๋ฏผํ๊ณ ๊ทธ ์ค ๋ฐ๊ฒฌํ ๋ ผ๋ฌธ์ ๋๋ค.
NeurIPS 2020์์ Google Research๊ฐ ๋ฐํํ FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence์ ๋ํด ์ ๋ฆฌํ ๊ธ์ ๋๋ค. ๊ณ ๋ ค๋ํ๊ต ์ฐ์ ๊ฒฝ์ํ๋ถ DSBA ์ฐ๊ตฌ์ค Lab Seminar ์ ํ๋ธ ์์์ ์ฐธ๊ณ ํ์์ต๋๋ค.
Introduction
deep network๋ supervised learning์ ์ข์ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค. ๋ฐฉ๋ํ ์์ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ์ฑ๋ฅ์ ์ฌ๋ฆฌ๊ธฐ์ labeling์ด ํ์ํ๊ณ ์ด๋ฅผ ์ํด์ ์ฌ๋์ ๋ ธ๋๋ ฅ์ด ํ์ํฉ๋๋ค. ๋ํ, ์ ๋ฌธ๊ฐ(์๋ฅผ ๋ค๋ฉด, ์๋ฃ ๋ถ์ผ์ ์์ฌ)๊ฐ labeling์ ์ํํ๋ ๊ฒฝ์ฐ ๊ทน์ฌํ ๋น์ฉ์ด ๋ค ์ ์์ต๋๋ค. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด semi-supervised learning(SSL) ๊ธฐ๋ฒ์ ๋ง์ด ์ฐ๊ตฌํ๊ณ ์์ต๋๋ค. SSL์ label์ด ์๋ ๋ฐ์ดํฐ๋ฅผ ํ์ตํ๋ ์ ๊ทผ๋ฒ์ ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก label์ด ์๋ ์ด๋ฏธ์ง์ pseudo label์ ์์ฑํ๊ณ label์ด ์๋ ์ด๋ฏธ์ง๋ฅผ pseudo label๋ก ์์ธกํ๋๋ก ๋ชจ๋ธ์ ํ๋ จ์ํค๋ ๊ธฐ๋ฒ์ ์ฌ์ฉํฉ๋๋ค.
Model Structure
label์ด ์ฃผ์ด์ง ๋ฐ์ดํฐ๋ก ํ์ตํ๊ณ label์ด ์๋ ๋ฐ์ดํฐ๋ก ๊ฒ์ฆํ์ฌ ์์ธกํ ๊ฐ์ pseudo-labeling์ด๋ผ ํฉ๋๋ค. ์๋์ ๊ทธ๋ฆผ์ผ๋ก ์ดํดํด๋ณด๊ฒ ์ต๋๋ค. 1. label์ด ์๋ ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํด ๋ชจ๋ธ์ ํ์ตํฉ๋๋ค. 2. ํ์ต๋ ๋ชจ๋ธ๋ก unlabeled data๋ฅผ ์์ธกํฉ๋๋ค. ์ด๋ฅผ ํตํด ์์ธก๋ ๊ฐ์ด pseudo label์ด ๋๋ ๊ฒ์ด์ฃ . 3. ๋ง์ง๋ง์ผ๋ก label์ด ์์ผ๋ pseudo label์ด ์๊ธด ๋ฐ์ดํฐ์ label์ด ์๋ ๋ฐ์ดํฐ๋ฅผ ํฉ์ณ ์ฌํ์ตํ๊ฒ ๋ฉ๋๋ค.
์ถ์ฒ: ๐ Introduction to Pseudo-Labelling : A Semi-Supervised learning technique
fixmatch๋ consistency regularization๊ณผ pseudo-labeling์ ์ด์ฉํ์ฌ pseudo label์ ๋ง๋ค์ด๋ ๋๋ค. ๊ทธ ์ค์์๋ weakly-augmented unlabeled image๋ฅผ ์ด์ฉํ์ฌ strong-augmented unlabeled image์ pseudo label์ ๋ง๋ค์ด๋ ๋๋ค. fixmatch๋ฅผ CIFAR-10์ ์ด์ฉํ์ฌ 250๊ฐ์ label์ด ์๋ ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ง๊ณ ์๋ SOTA๋ฅผ ์ฐ์ด๋์ต๋๋ค.
Pseudo-labeling
labeling๋ ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํด ๋ชจ๋ธ์ ํ์ตํ ๋ค ํ์ตํ ๋ชจ๋ธ๋ก labeling๋์ด์์ง ์์ unlabeled ๋ฐ์ดํฐ๋ฅผ ์์ธกํฉ๋๋ค. ์์ธก๋ ๊ฐ๋ค ์ค confidence score๊ฐ ๋์ ๋ฐ์ดํฐ๋ค๋ง pseudo label์ ๋ถ์ฌํฉ๋๋ค. ์์ธก๋ label์ ์ค์ label์ด ์๋๊ธฐ ๋๋ฌธ์ pseudo label์ด๋ผ ํฉ๋๋ค. labeled ๋ฐ์ดํฐ์ unlabeled ๋ฐ์ดํฐ์ pseudo label์ ๋ถ์ฌ์ ํ์ตํ ๋ค ๋ ๋ค์ labeling ๋์ด์์ง ์์ ๋๋จธ์ง unlabeled ๋ฐ์ดํฐ๋ฅผ ํ์ตํฉ๋๋ค. ๊ณ์ํด์ ๋ฐ๋ณตํ์ฌ unlabeled data๋ฅผ ์ค์ฌ๋๊ฐ๋๋ค.
Consistency regularization
๋ฐ์ดํฐ์ ์์ ๋ณํ(weakly augmentation)์ ๊ฐํด๋ ์์ธกํ ํ๋ฅ ์ ๋ณํ์ง ์์ ๊ฒ์ด๋ผ๋ ๊ฐ์ค ํ์ unlabeled ๋ฐ์ดํฐ์ noise(augmentation)์ ์ฃผ์ ํ ๋ค noise๊ฐ ์๋ ๋ฐ์ดํฐ์ noise๊ฐ ์ฃผ์ ๋ ๋ฐ์ดํฐ๋ฅผ ๋์ผํ class ๋ถํฌ๋ก ์์ธกํ๋๋ก ํ์ตํ๋ ๊ธฐ๋ฒ์ ๋๋ค.
Entropy Minimization
softmax input์ temperature(T)๋ hyperparameter๋ฅผ ์ ์ฉํด ๋ณด๋ค ๋ ํ์คํ ์์ธก ํ๋ฅ ์ ๋ง๋ค์ด๋ ๋๋ค. ์ฝ๋๋ฅผ ๋ณด์๋ฉด temperature๋ฅผ ์ ์ฉํ ๊ฐ๋ค์ ๋ ๊ทน๋จ์ ์ธ ๊ฐ์ ๊ฐ๋ ๊ฒ์ ๋ณด์ค ์ ์์ต๋๋ค.
import torchimport torch.nn as nn
inputs = torch.randn(2, 3)
m = nn.Softmax(dim=1)T = 0.3
print('#inputs\\n', inputs)
print('\\n#softmax(inputs)\\n', m(inputs))
print('\\n#softmax(inputs)+temperature\\n', m(inputs/T))
# #inputs
# tensor([[ 1.4486, -0.0243, -0.1175],
# [ 0.3694, 0.8460, -0.0310]])
# #softmax(inputs)
# tensor([[0.6954, 0.1594, 0.1452],
# [0.3048, 0.4909, 0.2042]])
# #softmax(inputs)+temperature
# tensor([[0.9874, 0.0073, 0.0053],
# [0.1623, 0.7950, 0.0427]])
loss function
- x: label์ด ์๋ ์ด๋ฏธ์ง๋ค ์งํฉ(b๊ฐ์๋งํผ)
- p(yโฅx): x๋ฅผ model์ ์ด์ฉํด class๋ฅผ ์์ธกํ ํ๋ฅ
- A(.): strongly augmentation
- α(.): weakly augmentation
- p: one-hot label๋ค์ ํ๋ฅ ๋ถํฌ
- H(p,โq): p์ q์ ํ๋ฅ ๋ถํฌ์ cross-entropy
supervised loss ls๋ ์ผ๋ฐ์ ์ผ๋ก ์ฌ์ฉ๋๋ supervised loss์ด๋ฉฐ weakly-augmented ๋ฐ์ดํฐ์ ์ฌ์ฉ๋ฉ๋๋ค. ์์ ๋ด์ฉ์ ๋ฐํ์ผ๋ก ์ค๋ช ๋๋ฆฌ๋ฉด label์ด ์๋ ์ด๋ฏธ์ง๋ค์ ์ฝํ augmentation์ ๊ฐํ ์ด๋ฏธ์ง๋ค์ด y๋ฅผ ์์ธกํ ํ๋ฅ ๊ฐ๊ณผ label์ ํ๋ฅ ๊ฐ์ cross-entropy loss์ ๋๋ค.
unsupervised loss lu๋ strongly-augmented ๋ฐ์ดํฐ์ ์ฌ์ฉ๋๋ loss์ ๋๋ค. ์ค๋ช ๋๋ฆฌ์๋ฉด unlabeled ์ด๋ฏธ์ง์ ๊ฐํ augmentaion์ ๊ฐํ y๋ฅผ ์์ธกํ ํ๋ฅ ๊ฐ๊ณผ ์ฝํ augmentation์ ๊ฐํ y๋ฅผ ์์ธกํ ํ๋ฅ ๊ฐ์ cross-entropy loss์ max(qb)๋ฅผ ๊ณฑํด์ฃผ๋๋ฐ max(qb)๊ฐ threshold(τ) ์ด์์ธ ๊ฒฝ์ฐ๋ง ๊ณฑํด์ค๋๋ค.
fixmatch์ ์ต์ข loss๋ ์๋์ ๊ฐ์ต๋๋ค. supervised loss์ unsupervised loss์ ๊ฐ์ค์น ๋๋ค๋ฅผ ๊ณฑํ ํฉ์ ๋๋ค.
Experiment & Results
- weight decay regularization
- standard SGD with momentum
- cosine learning rate decay, k๋ ํ์ฌ training step, K๋ ์ ์ฒด training step
- hyperparameters: λu=1,η=0.03,β=0.9,τ=0.95,μ=7,B=64,K=220
- backbone: wide resnet-28-2 with 1.5M parameters
Conclusion
์ ์ ๋ ์ด๋ธ๋ง ๋ฐ์ดํฐ์๋ ๋์ ์ ํ๋๋ฅผ ๋์ผ ์ ์๋ ๋น๊ต์ ๊ฐ๋จํ ์๊ณ ๋ฆฌ์ฆ์ธ FixMatch ๋ชจ๋ธ์ ์ ์ํ๊ณ , weight decay์ optimizer ๊ฐ์ ์ค๊ณ๊ฐ ์ค์ํ๋ค๋ ๊ฒ์ ๊ฐ์กฐํ์์ต๋๋ค. FixMatch์ ๊ฐ์ ๊ฐ๋จํ๋ฉด์๋ ๋์ ์ฑ๋ฅ์ ๋ด๋ SSL ์๊ณ ๋ฆฌ์ฆ์ ์ฃผ๋ชฉํด์ผํ๋ค๊ณ ์ ์ํ์ต๋๋ค.