ํฐ์คํ ๋ฆฌ ๋ทฐ
๐งโ๏ธ best model์ ์ด๋ค ์งํ๋ฅผ ๊ธฐ์ค์ผ๋ก ์ ์ฅํ ๊น?
์ด์ YIYU 2024. 5. 2. 14:45๋ชจ๋ธ์ ํ์ต์ํค๊ณ ์ ์ฅ ์์ ๊ณ ๋ฏผ์ด ์๊ฒผ๋ค. ์ด๋ค ๊ฒ์ ๊ธฐ์ค์ผ๋ก ์ข์ ๋ชจ๋ธ์ด๋ผ๊ณ ๋งํ ์ ์์๊น? ๊ธฐ์กด์ loss๋ฅผ ์ค์ฌ์ผ๋ก ์๊ฐํ์ง๋ง ๋ด๊ฐ ๋ณด๊ณ ์๋ ๋ฐ์ดํฐ๋ ๋ถ๊ท ํ์ด ์ฌํ๊ธฐ ๋๋ฌธ์ ๋จ์ํ loss๊ฐ ๊ธฐ์ค์ด ๋๋ ๊ฒ ์๋๊ฑฐ๋ผ๋ ์๊ฐ์ ํ๋ค. ๊ทธ๋์ ๋ค๋ฅธ ์ฌ๋๋ค์ ์ด๋ป๊ฒ ์ฌ์ฉํ๊ณ ์๋์ง ์ฐพ์๋ดค๋ค. ์ ์ผ ์ข์ ๋ ํผ๋ฐ์ค๋ ์บ๊ธ์ด๋ค. ์ฝ๋๋ ๋ง๊ณ ์ฑ๋ฅ ์ค์ฌ์ด๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ ์ค์ ๋ดค๋ ์ฝ๋๋ Pytorch multi labels by cumulated level ๐ ์๋ค. ์ฌ๊ธฐ์ best model์ ์ ์ฅํ๋ ์ฝ๋๋ง ๊ฐ์ ธ์๋ค.
if auroc > best_metric:
best_metric = auroc
torch.save(model.state_dict(), f'dict_model_{j}_fold_{fold}_ckpt_pytorch')
else :
early_stoping += 1
if early_stoping > EARLY_STOPPING :
print(f'{Fore.RED}{Style.BRIGHT}====> early stopping{Style.RESET_ALL}\\n')
break
if epoch+1 < 10 :
a =' '
else :
a =''
print(f'Epoch: {epoch+1}{a}/{EPOCHS} | Train Loss: {train_loss:.6f} | Val loss: {val_loss:.6f} | Val auc {auroc:.6f} | Best auc {best_metric:.6f} | lr: {lr} ')
auroc = metric.reset()
์ฌ๊ธฐ์ auprc๋ฅผ ๊ธฐ์ค์ผ๋ก best model์ ์ ์ฅํ๊ณ ์์๊ณ early stopping์ด ์ง์ ํ ๊ฐ๋ณด๋ค ํฌ๋ค๋ฉด ํ์ต์ ๋ฉ์ถ๋ ๋ก์ง์ด๋ค. ๋ฐ์ดํฐ์ ํน์ฑ์ ๋ง์ถฐ์ best model metric ๊ธฐ์ค์ ์ก์์ผ ๋๋๋ณด๋ค. ์ด๋๋ก๋ ์ฐ์ฐํ๋ ์ข ๋ ๊ฒ์ฆ๋ ์ฝ๋๋ฅผ ์ฐพ์๋ดค๋ค. ์ ๊ตฌํ๋์ด์๋ ํ๊น ํ์ด์ค์ transformer ์ฝ๋๋ฅผ ๋ฏ์ด๋ดค๋ค. ํ๊น ํ์ด์ค์๋ trainer์์ early stopping ๋ฐฉ์์ผ๋ก ํ์ต์ํฌ ์ ์๋ค. ๋จผ์ trainer callbacks์์ EarlyStoppingCallback ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํด์ผํ๋ค.
trainer = Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
)
metric_for_best_model์์ ์ํ๋ metric์ ์ค์ ํด์ฃผ๋ฉด ๋๊ณ ๊ธฐ๋ณธ์ loss๋ก ๋์ด์๋ค. ์ด ํ๋ผ๋ฏธํฐ๋ฅผ ์์ ํ๋ ๊ฒฝ์ฐ์ greater_is_better๋ ํจ๊ป ์ง์ ํด์ผํ๋ค. auc์ฒ๋ผ ๋์ ๊ฐ์ด ์ข๋ค๋ฉด True, loss์ฒ๋ผ ๋ฎ์ ๊ฐ์ด ์ข๋ค๋ฉด False๋ฅผ ์ง์ ํด์ผํ๋ค.
metric_for_best_model (str, optional) — Use in conjunction with load_best_model_at_end to specify the metric to use to compare two different models. Must be the name of a metric returned by the evaluation with or without the prefix "eval_". Will default to "loss" if unspecified and load_best_model_at_end=True (to use the evaluation loss).
If you set this value, greater_is_better will default to True. Don’t forget to set it to False if your metric is better when lower.
์ด์ธ์๋ ์ง์ ํ ํ๋ผ๋ฏธํฐ๋ค์ ์๋์ ๊ฐ๋ค.
- load_best_model_at_end = True (EarlyStoppingCallback() requires this to be True).
- evaluation_strategy=’steps’ # or epoch
- eval_steps = 50 (evaluate the metrics after N steps).
๊ฒฐ๊ตญ, loss ์ธ์๋ f1์ด๋ auc ๋ฑ ๋ฐ์ดํฐ ํน์ฑ์ ๋ง์ถฐ early stopping์ ํ๋ ๊ฒ์ด ๋ง๋ค. ๋จ, early stopping์ ํ๋ จ ๋จ์๊ฐ step์ผ ๊ฒฝ์ฐ test ์ฑ๋ฅ์ด ๋ ์ ์ข์ ์ ์๋ค๋ ๊ฒ์ ์ฃผ์ํ์.
์ฐธ๊ณ ํ ์๋ฃ
'๐ง Machine Learning' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
๐ฅซ [ELMo ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ] Deep contextualized word representations (0) | 2024.05.06 |
---|---|
๐ชข ๋ฅ๋ฌ๋์ ํ์ํ ํ๋ฅ ์ด์ง ์ฐ์ด๋จน๊ธฐ, ์ต๋์ฐ๋๋ฒ (0) | 2024.05.02 |
๐ง๏ธ ์ดํดํ๋ฉด ์ฌ์ด ๋ฒ ์ด์ฆ ์ ๋ฆฌ์ VAE (0) | 2024.04.17 |
๐งโ๐ซ KL divergence ์ดํด๋ณด๊ธฐ (1) | 2024.04.08 |
๐ Multivariable Fractional Polynomials(MFP) (0) | 2024.04.03 |
- ๋ฒ ์ด์ฆ ์ ๋ฆฌ
- python
- ์ฑ ๋ฆฌ๋ทฐ
- ๊ธ๋
- tmux
- GIT
- ํ๊ณ
- vscode
- ๊ฐ๋ฐ์
- ๋จธ์ ๋ฌ๋ ์ด๋ก
- Multiprocessing
- ๋ ํ๊ฐ
- Computer Vision
- linux
- Generative Model
์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- Total
- Today
- Yesterday