본문 바로가기

부스트캠프 AI Tech 3기/이론 : U-stage

[Day7] Pytorch 5. Pytorch Dataset : dataset, transforms, DataLoader

대용량 데이터 처리가 중요하다

Dataset 클래스

데이터 입력 형태를 정의해서 표준화 시킨다
데이터를 어떻게 불러올 것인지, 길이는 얼마인지를 정의한다.
어떻게 반환해 줄 것인지를 __getitem__()에서 map-style로 결정해준다

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    # 초기 데이터 생성 방법을 지정
    # 데이터의 directory 지정을 하기도 함
    def __init__(self,text,labels):
        self.label=labels
        self.data=text
    
    # 데이터의 전체 길이
    def __len__(self):
        return len(self.labels)
        
    # 반환되는 데이터의 형태
    # index를 이용해서 dict로 반환하는 코드
    def __getitem__(self, idx):
        label=self.labels[idx]
        text=self.data[idx]
        sample={"text":text, "class":label}
        return sample​

 

위의 예제는 text에 대한 것이지만 이미지, 오디오 등 데이터의 형태에 따라 함수를 다르게 정의해야한다.
huggingface와 같은 표준화된 라이브러리를 사용해도 된다.

 

이렇게 Dataset을 생성할 수 있다

transfroms

ToTensor()로 데이터를 tensor로 바꾸는 전처리를 한다. dataset에서 하지 않고 transforms에서 한다
ex) 이미지를 숫자로 바꾸어서 tensor로 저장

dataLoader

data를 묶어서 model에 feeding할 때 batch를 생성해준다.

GPU로 데이터를 feed하기 직전에 데이터를 변환하는 역할을 한다.

또한 shuffle을 통해 데이터를 잘 섞어준다.

MyDataLoader = DataLoader(MyDataset, batch_size=2, shuffle=True)
next(iter(MyDataLoader)) #데이터 generator로 변환 -> 추출 -> 메모리에 올라감

DataLoader가 1번 돌아가는 것을 epic이라고 한다. 

 

DataLoader의 다른 속성

  • sampler
    index를 결정하는 기법
    어떤 sampler가 있을까
  • batch_sampler
    batch를 어떻게 뽑을지 결정하는 기법
  • collate_fn
    data와 label을 분리해서 사용할 수 있는 옵션
    text 길이가 다를 때 padding처리, sequence data 처리 시 사용

 

추가공부 +)mnist의 source clone coding