๐งโ๏ธ best model์ ์ด๋ค ์งํ๋ฅผ ๊ธฐ์ค์ผ๋ก ์ ์ฅํ ๊น?
๋ชจ๋ธ์ ํ์ต์ํค๊ณ ์ ์ฅ ์์ ๊ณ ๋ฏผ์ด ์๊ฒผ๋ค. ์ด๋ค ๊ฒ์ ๊ธฐ์ค์ผ๋ก ์ข์ ๋ชจ๋ธ์ด๋ผ๊ณ ๋งํ ์ ์์๊น? ๊ธฐ์กด์ loss๋ฅผ ์ค์ฌ์ผ๋ก ์๊ฐํ์ง๋ง ๋ด๊ฐ ๋ณด๊ณ ์๋ ๋ฐ์ดํฐ๋ ๋ถ๊ท ํ์ด ์ฌํ๊ธฐ ๋๋ฌธ์ ๋จ์ํ loss๊ฐ ๊ธฐ์ค์ด ๋๋ ๊ฒ ์๋๊ฑฐ๋ผ๋ ์๊ฐ์ ํ๋ค. ๊ทธ๋์ ๋ค๋ฅธ ์ฌ๋๋ค์ ์ด๋ป๊ฒ ์ฌ์ฉํ๊ณ ์๋์ง ์ฐพ์๋ดค๋ค. ์ ์ผ ์ข์ ๋ ํผ๋ฐ์ค๋ ์บ๊ธ์ด๋ค. ์ฝ๋๋ ๋ง๊ณ ์ฑ๋ฅ ์ค์ฌ์ด๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ ์ค์ ๋ดค๋ ์ฝ๋๋ Pytorch multi labels by cumulated level ๐ ์๋ค. ์ฌ๊ธฐ์ best model์ ์ ์ฅํ๋ ์ฝ๋๋ง ๊ฐ์ ธ์๋ค.
if auroc > best_metric:
best_metric = auroc
torch.save(model.state_dict(), f'dict_model_{j}_fold_{fold}_ckpt_pytorch')
else :
early_stoping += 1
if early_stoping > EARLY_STOPPING :
print(f'{Fore.RED}{Style.BRIGHT}====> early stopping{Style.RESET_ALL}\\n')
break
if epoch+1 < 10 :
a =' '
else :
a =''
print(f'Epoch: {epoch+1}{a}/{EPOCHS} | Train Loss: {train_loss:.6f} | Val loss: {val_loss:.6f} | Val auc {auroc:.6f} | Best auc {best_metric:.6f} | lr: {lr} ')
auroc = metric.reset()
์ฌ๊ธฐ์ auprc๋ฅผ ๊ธฐ์ค์ผ๋ก best model์ ์ ์ฅํ๊ณ ์์๊ณ early stopping์ด ์ง์ ํ ๊ฐ๋ณด๋ค ํฌ๋ค๋ฉด ํ์ต์ ๋ฉ์ถ๋ ๋ก์ง์ด๋ค. ๋ฐ์ดํฐ์ ํน์ฑ์ ๋ง์ถฐ์ best model metric ๊ธฐ์ค์ ์ก์์ผ ๋๋๋ณด๋ค. ์ด๋๋ก๋ ์ฐ์ฐํ๋ ์ข ๋ ๊ฒ์ฆ๋ ์ฝ๋๋ฅผ ์ฐพ์๋ดค๋ค. ์ ๊ตฌํ๋์ด์๋ ํ๊น ํ์ด์ค์ transformer ์ฝ๋๋ฅผ ๋ฏ์ด๋ดค๋ค. ํ๊น ํ์ด์ค์๋ trainer์์ early stopping ๋ฐฉ์์ผ๋ก ํ์ต์ํฌ ์ ์๋ค. ๋จผ์ trainer callbacks์์ EarlyStoppingCallback ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํด์ผํ๋ค.
trainer = Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
)
metric_for_best_model์์ ์ํ๋ metric์ ์ค์ ํด์ฃผ๋ฉด ๋๊ณ ๊ธฐ๋ณธ์ loss๋ก ๋์ด์๋ค. ์ด ํ๋ผ๋ฏธํฐ๋ฅผ ์์ ํ๋ ๊ฒฝ์ฐ์ greater_is_better๋ ํจ๊ป ์ง์ ํด์ผํ๋ค. auc์ฒ๋ผ ๋์ ๊ฐ์ด ์ข๋ค๋ฉด True, loss์ฒ๋ผ ๋ฎ์ ๊ฐ์ด ์ข๋ค๋ฉด False๋ฅผ ์ง์ ํด์ผํ๋ค.
metric_for_best_model (str, optional) — Use in conjunction with load_best_model_at_end to specify the metric to use to compare two different models. Must be the name of a metric returned by the evaluation with or without the prefix "eval_". Will default to "loss" if unspecified and load_best_model_at_end=True (to use the evaluation loss).
If you set this value, greater_is_better will default to True. Don’t forget to set it to False if your metric is better when lower.
์ด์ธ์๋ ์ง์ ํ ํ๋ผ๋ฏธํฐ๋ค์ ์๋์ ๊ฐ๋ค.
- load_best_model_at_end = True (EarlyStoppingCallback() requires this to be True).
- evaluation_strategy=’steps’ # or epoch
- eval_steps = 50 (evaluate the metrics after N steps).
๊ฒฐ๊ตญ, loss ์ธ์๋ f1์ด๋ auc ๋ฑ ๋ฐ์ดํฐ ํน์ฑ์ ๋ง์ถฐ early stopping์ ํ๋ ๊ฒ์ด ๋ง๋ค. ๋จ, early stopping์ ํ๋ จ ๋จ์๊ฐ step์ผ ๊ฒฝ์ฐ test ์ฑ๋ฅ์ด ๋ ์ ์ข์ ์ ์๋ค๋ ๊ฒ์ ์ฃผ์ํ์.