본문 바로가기

부스트캠프 AI Tech 3기/프로젝트 : P-stage

Wandb Sweep 하는 방법 : 하이퍼 파라미터 자동 튜닝

Wandb는 자동으로 train을 할 때 마다 그 때의 하이퍼 파라미터와를 기록하고 auc, rmse등을 기록해서 보여주는 굉장히 편리한 도구이다.

 

Weights & Biases – Developer tools for ML

WandB is a central dashboard to keep track of your hyperparameters, system metrics, and predictions so you can compare models live, and share your findings.

wandb.ai

이 Wandb에는 기록뿐 아니라 hyper parameter를 좋은 성능이 나오게 찾아주는 기능도 있다. 

 

1. wandb.log로 loss와 정확도 등 목적 값을 보낸다

wandb.log(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "train_auc": train_auc,
                "train_acc": train_acc,
                "valid_auc": auc,
                "valid_acc": acc,
            }
        )

for epoch in range(epochs) 안에 넣어서 epoch 마다 보내주면 된다.

 

2. Configuration

 

Sweep Configuration - Documentation

The brackets for this example are: [3, 3*eta, 3*eta*eta, 3*eta*eta*eta], which equals [3, 9, 27, 81].

docs.wandb.ai

뭐든 document가 제일 잘 나와있지만 설명을 보는게 편하긴 하다,,

program : train.py
entity : wandb에서의 team명
method : bayes
metric:
  name: valid_auc
  goal : minimize
parameters:
  max_seq_len : 
    values : [40,50,60,70,80,90,100,110]
  lr: 
    values : [0.1, 0.05, 0.01, 0.005, 0.0001, 0.0005, 0.00001]
  batch_size : 
    values : [32,64,128,256]
  patience : 
    min : 10
    max : 30

이런 내용으로 train.py 와 같은 위치에 sweep_config.yaml 파일을 만들어준다. (다른 위치에 하면 에러난다)

이 파일이 핵심이다

 

Method

method는 bayes, grid등이 있는데 grid는 말 그대로 grid search로 있는 값을 모두 탐색해서 시간이 엄청 걸린다. bayes가 적절하다.

 

Metirc

해당 학습 동안의 metric이 wandb.log()로 넣었던 것 중 어떤 것인지 알려주고, 이 metric값을 줄일지 늘릴지를 써준다.

 

Parameters

실험해보고싶은 hyper parameter를 min,max로 범위를 정해주거나 values로 아예 내가 지정한 값 내에서만 하게 할 수 있다. min max로 하면 시간이 오래걸려서 대부분 values로 정해줬던 것 같다. 

 

3. Config과 wandb 연결

https://docs.wandb.ai/guides/sweeps/quickstart

 

Sweeps Quickstart - Documentation

Set up a YAML file to specify the hyperparameters you wish to sweep over, along with the structure of the sweep like the training script to call, the search strategy and stopping criteria to use, etcetera.

docs.wandb.ai

wandb.init(project="FeatureEngineering",config=vars(args), entity=args.entity,name=args.wandb_name)

wandb init시에 config에 args를 넣어서 나는 따로 config를 만들지 않고 활용했다.

 

4. 실행

wandb sweep sweep_config.yaml

train.py와 sweep_config.yaml이 있는 위치에서 실행해주면

이렇게 뜨는 데 밑의 노랑색글자를 복사해서 다시 CLI 에서 실행해주면 sweep이 된다.

 

5. 결과 확인

실행이 끝나고 결과를 보려면 wandb 홈페이지에서 sweep 탭에 들어가면 된다

 

이렇게 어떤 파라미터일 때 auc가 제일 높은지를 한 눈에 보여준다!