
[Paper Review] FixMatch: A Technique for Improving Performance Even with Few Labels

이유 YIYU 2024. 5. 2. 15:06

At my current company I work with data whose labels are not very reliable. Filtering the labels by hand is not feasible, so I was looking for ways to raise performance with little human effort, and this is a paper I found along the way.

This post summarizes FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence, presented by Google Research at NeurIPS 2020. I referred to the Lab Seminar YouTube video from the DSBA lab of the School of Industrial and Management Engineering at Korea University.

Introduction

Deep networks perform well with supervised learning. Raising performance with very large datasets, however, requires labeling, and labeling requires human labor. When experts (for example, doctors in the medical domain) have to do the labeling, the cost can be extreme. Semi-supervised learning (SSL) is widely studied as a way around this. SSL is an approach that also learns from unlabeled data. A common technique is to generate pseudo labels for unlabeled images and then train the model to predict those pseudo labels on the unlabeled images.

Model Structure

Pseudo-labeling means training on labeled data, running the trained model on unlabeled data, and using the predicted values as labels. The figure from the source below illustrates the process:

1. Train a model on the labeled data.
2. Predict on the unlabeled data with the trained model. The predicted values become the pseudo labels.
3. Finally, combine the pseudo-labeled data with the labeled data and retrain.

Source: Introduction to Pseudo-Labelling : A Semi-Supervised learning technique

FixMatch creates pseudo labels by combining consistency regularization and pseudo-labeling. Specifically, it uses the prediction on a weakly augmented unlabeled image as the pseudo label for the strongly augmented version of that image. On CIFAR-10, FixMatch reached state-of-the-art results with only 250 labeled examples.

Pseudo-labeling

A model is trained on the labeled data and then used to predict on the unlabeled data. Only the predictions with a high confidence score are assigned pseudo labels. They are called pseudo labels because they are predictions, not ground-truth labels. The model is then retrained on the labeled data together with the pseudo-labeled data, and it predicts again on the unlabeled data that remains. Repeating this process gradually shrinks the unlabeled set.
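The loop described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the "model" is a hypothetical 1-D nearest-centroid classifier, and `fit_centroids`, `predict_proba`, `self_train`, and the threshold of 0.8 are names and values chosen for this sketch, not anything from the paper.

```python
import math

def fit_centroids(xs, ys):
    """Toy 'model': the mean of each class on a 1-D feature."""
    centroids = {}
    for label in set(ys):
        members = [x for x, y in zip(xs, ys) if y == label]
        centroids[label] = sum(members) / len(members)
    return centroids

def predict_proba(centroids, x):
    """Softmax over negative distances to each class centroid."""
    labels = sorted(centroids)
    scores = [math.exp(-abs(x - centroids[c])) for c in labels]
    total = sum(scores)
    return {c: s / total for c, s in zip(labels, scores)}

def self_train(labeled_x, labeled_y, unlabeled_x, threshold=0.8, max_rounds=10):
    """Iteratively move confident predictions into the training set."""
    xs, ys = list(labeled_x), list(labeled_y)
    remaining = list(unlabeled_x)
    for _ in range(max_rounds):
        model = fit_centroids(xs, ys)
        still_unlabeled = []
        for x in remaining:
            probs = predict_proba(model, x)
            label, confidence = max(probs.items(), key=lambda kv: kv[1])
            if confidence >= threshold:
                xs.append(x)          # accept the pseudo label
                ys.append(label)
            else:
                still_unlabeled.append(x)
        if len(still_unlabeled) == len(remaining):
            break                     # nothing new was labeled; stop
        remaining = still_unlabeled
    return fit_centroids(xs, ys), remaining

model, leftover = self_train([0.0, 4.0], [0, 1], [-1.0, 0.5, 2.0, 3.5, 5.0])
print(model, leftover)
```

In this toy run the point at 2.0 sits halfway between the two classes, so it never passes the confidence threshold and stays unlabeled.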

Consistency regularization

Consistency regularization rests on the assumption that a small perturbation of the input (weak augmentation) should not change the predicted probabilities. Noise (augmentation) is injected into unlabeled data, and the model is trained to predict the same class distribution for the noise-free input and the noisy input.
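As a rough illustration of the idea, the sketch below measures how far the predictions on a clean and a perturbed input drift apart. The linear `model`, the Gaussian-noise `weak_augment`, and the squared-difference loss are all assumptions made for this example; real implementations use image augmentations and a neural network.

```python
import math
import random

def softmax(logits):
    m = max(logits)                      # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def model(x, weights):
    """Hypothetical linear classifier: one weight row per class."""
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
    return softmax(logits)

def weak_augment(x, rng, scale=0.05):
    """Stand-in for a weak augmentation: small Gaussian noise."""
    return [xi + rng.gauss(0.0, scale) for xi in x]

def consistency_loss(x, weights, rng):
    """Squared distance between predictions on the clean and noisy input."""
    p_clean = model(x, weights)
    p_noisy = model(weak_augment(x, rng), weights)
    return sum((a - b) ** 2 for a, b in zip(p_clean, p_noisy))

rng = random.Random(0)
weights = [[1.0, -0.5], [-0.3, 0.8], [0.2, 0.1]]
loss = consistency_loss([0.7, -1.2], weights, rng)
print(loss)
```

Minimizing this quantity over unlabeled inputs pushes the model toward predictions that are stable under the perturbation; no label is needed to compute it.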

Entropy Minimization

Applying a hyperparameter called temperature (T) to the softmax input produces more confident predicted probabilities. In the code below, the values computed with temperature are visibly more extreme.

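import torch
import torch.nn as nn

inputs = torch.randn(2, 3)  # random logits: 2 samples, 3 classes
m = nn.Softmax(dim=1)       # softmax over the class dimension
T = 0.3                     # temperature below 1 sharpens the distribution
print('#inputs\n', inputs)
print('\n#softmax(inputs)\n', m(inputs))
print('\n#softmax(inputs)+temperature\n', m(inputs / T))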

# #inputs
#  tensor([[ 1.4486, -0.0243, -0.1175],
#         [ 0.3694,  0.8460, -0.0310]])
# #softmax(inputs)
#  tensor([[0.6954, 0.1594, 0.1452],
#         [0.3048, 0.4909, 0.2042]])
# #softmax(inputs)+temperature
#  tensor([[0.9874, 0.0073, 0.0053],
#         [0.1623, 0.7950, 0.0427]])
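To make "more confident" concrete, one can compare the entropy of the two distributions. The sketch below is framework-free Python that reuses the first row of logits printed above; the helper names `softmax` and `entropy` are chosen for this example.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)                      # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; lower means a more peaked distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

logits = [1.4486, -0.0243, -0.1175]      # first row of the sample output above
plain = softmax(logits)
sharp = softmax(logits, temperature=0.3)
print(entropy(plain), entropy(sharp))
```

The entropy of the temperature-scaled distribution is lower than that of the plain softmax, which is the effect entropy minimization is after.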

loss function

  • x: the set of labeled images (a batch of B)
  • p(y | x): the class probabilities the model predicts for x
  • A(·): strong augmentation
  • α(·): weak augmentation
  • p: the probability distribution of the one-hot labels
  • H(p, q): the cross-entropy between the distributions p and q

The supervised loss ls is the usual supervised loss, applied to weakly augmented data. In terms of the notation above, it is the cross-entropy between the label distribution and the class probabilities predicted for the weakly augmented labeled images.
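Written out with the notation from the list above, the supervised loss can be expressed as follows. This follows the paper's formulation as I recall it; the per-example index b and the subscripted x_b and p_b are added here for clarity.

```latex
\ell_s = \frac{1}{B} \sum_{b=1}^{B} H\bigl(p_b,\; p(y \mid \alpha(x_b))\bigr)
```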

The unsupervised loss lu is applied to strongly augmented data. Let qb be the class probabilities predicted for a weakly augmented unlabeled image. The loss is the cross-entropy between the pseudo label taken from qb (its most likely class) and the probabilities predicted for the strongly augmented version of the same image. Each term is multiplied by an indicator that is 1 when max(qb) is at or above the threshold τ and 0 otherwise, so only confident predictions contribute.
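In symbols, this can be sketched as below, following the paper's formulation as I recall it. Here u_b denotes an unlabeled image and \hat{q}_b the one-hot pseudo label; both symbols are not in the notation list above and are introduced for this block. μ is the ratio of unlabeled to labeled examples per batch.

```latex
\begin{aligned}
q_b &= p(y \mid \alpha(u_b)), \qquad \hat{q}_b = \arg\max(q_b) \\
\ell_u &= \frac{1}{\mu B} \sum_{b=1}^{\mu B}
  \mathbb{1}\bigl(\max(q_b) \ge \tau\bigr)\,
  H\bigl(\hat{q}_b,\; p(y \mid \mathcal{A}(u_b))\bigr)
\end{aligned}
```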

The final FixMatch loss is the supervised loss plus the unsupervised loss multiplied by a weight λu, that is, ls + λu·lu.
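A minimal sketch of how the two terms combine, written in plain Python on precomputed probabilities rather than with a real network. The function names and the toy numbers are assumptions made for illustration; the default `threshold=0.95` and `lambda_u=1.0` mirror the hyperparameters listed in the next section.

```python
import math

def cross_entropy(label, probs):
    """H(one_hot(label), probs) reduces to -log probs[label]."""
    return -math.log(probs[label])

def supervised_loss(labels, labeled_probs):
    """Mean cross-entropy on weakly augmented labeled images."""
    return sum(cross_entropy(y, p) for y, p in zip(labels, labeled_probs)) / len(labels)

def unsupervised_loss(weak_probs, strong_probs, threshold=0.95):
    """Masked cross-entropy: pseudo label from the weak view, loss on the strong view."""
    total = 0.0
    for q, p_strong in zip(weak_probs, strong_probs):
        confidence = max(q)
        if confidence >= threshold:              # indicator 1(max(q) >= tau)
            pseudo_label = q.index(confidence)   # argmax of the weak prediction
            total += cross_entropy(pseudo_label, p_strong)
    return total / len(weak_probs)               # averaged over the whole batch

def fixmatch_loss(labels, labeled_probs, weak_probs, strong_probs,
                  lambda_u=1.0, threshold=0.95):
    l_s = supervised_loss(labels, labeled_probs)
    l_u = unsupervised_loss(weak_probs, strong_probs, threshold)
    return l_s + lambda_u * l_u

labels = [0, 1]
labeled_probs = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]]
weak = [[0.97, 0.02, 0.01], [0.50, 0.30, 0.20]]    # only the first passes 0.95
strong = [[0.60, 0.30, 0.10], [0.10, 0.80, 0.10]]
total = fixmatch_loss(labels, labeled_probs, weak, strong)
print(total)
```

Note that the second unlabeled example contributes nothing because its weak prediction is below the threshold, yet the sum is still divided by the full batch size.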

Experiment & Results

  • weight decay regularization
  • standard SGD with momentum
  • cosine learning rate decay η·cos(7πk / 16K), where k is the current training step and K is the total number of training steps
  • hyperparameters: λu=1, η=0.03, β=0.9, τ=0.95, μ=7, B=64, K=2^20
  • backbone: wide resnet-28-2 with 1.5M parameters
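The cosine learning rate decay in the list above can be sketched as a small function. The form η·cos(7πk / 16K) is the schedule from the paper as I recall it; the function name `cosine_lr` and its defaults (η = 0.03, K = 2^20) are chosen for this example.

```python
import math

def cosine_lr(step, total_steps=2**20, base_lr=0.03):
    """Learning rate at training step `step` under the cosine decay schedule."""
    return base_lr * math.cos(7 * math.pi * step / (16 * total_steps))

for k in (0, 2**19, 2**20):
    print(k, cosine_lr(k))
```

Because the argument only reaches 7π/16 at the last step, the learning rate decays smoothly but does not drop all the way to zero.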

Conclusion

The paper proposes FixMatch, a relatively simple algorithm that reaches high accuracy even with little labeled data, and it stresses that design choices such as weight decay and the optimizer matter. It also argues that simple yet high-performing SSL algorithms like FixMatch deserve more attention.
