[PyTorch] Dataset & DataLoader with CIFAR-10
⚠️ This post is a write-up of what I studied from the official PyTorch Korea tutorials, so please bear that in mind!
 

Dataset and DataLoader

ํŒŒ์ดํ† ์น˜(PyTorch) ๊ธฐ๋ณธ ์ตํžˆ๊ธฐ|| ๋น ๋ฅธ ์‹œ์ž‘|| ํ…์„œ(Tensor)|| Dataset๊ณผ DataLoader|| ๋ณ€ํ˜•(Transform)|| ์‹ ๊ฒฝ๋ง ๋ชจ๋ธ ๊ตฌ์„ฑํ•˜๊ธฐ|| Autograd|| ์ตœ์ ํ™”(Optimization)|| ๋ชจ๋ธ ์ €์žฅํ•˜๊ณ  ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ๋ฐ์ดํ„ฐ ์ƒ˜ํ”Œ์„ ์ฒ˜๋ฆฌํ•˜๋Š” ์ฝ”

tutorials.pytorch.kr


Dataset & DataLoader

PyTorch's Dataset and DataLoader are powerful tools for loading and preprocessing datasets efficiently.
They let you split a large dataset into batches and feed them to a model.

 

  • Code for processing data samples can get messy and hard to maintain.
  • Ideally, dataset code should be decoupled from model training code for better readability and modularity.
  • PyTorch provides two data primitives, torch.utils.data.DataLoader and torch.utils.data.Dataset, that let you use pre-loaded datasets as well as your own data.
  • Dataset stores the samples and their labels, and DataLoader wraps the Dataset in an iterable for easy access to the samples (see the minimal sketch right after this list).
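
Before moving on to CIFAR-10, here is a minimal sketch of how the two pieces fit together, using a small in-memory toy dataset (the TensorDataset, tensor shapes, and batch size below are purely illustrative):
import torch
from torch.utils.data import TensorDataset, DataLoader

features = torch.randn(100, 3, 32, 32)   # 100 fake RGB images
labels = torch.randint(0, 10, (100,))    # 100 fake class indices

dataset = TensorDataset(features, labels)                   # stores samples and labels
loader = DataLoader(dataset, batch_size=16, shuffle=True)   # iterable over shuffled batches

for batch_features, batch_labels in loader:
    print(batch_features.shape, batch_labels.shape)  # torch.Size([16, 3, 32, 32]) torch.Size([16])
    break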

 

  • PyTorch's domain-specific libraries provide a variety of pre-loaded datasets (such as FashionMNIST).
  • These datasets subclass torch.utils.data.Dataset and implement functions specific to the particular data.
  • They can be used to prototype and benchmark your models.
  • Below are a few dataset collections; I've attached the links, so take a look!

Image Dataset
Datasets — Torchvision 0.19 documentation (pytorch.org)

Text Dataset
torchtext.datasets — Torchtext 0.18.0 documentation (pytorch.org)

Audio Dataset
torchaudio.datasets — Torchaudio 2.4.0 documentation (pytorch.org)


Loading a Dataset

Let's look at an example of loading the CIFAR-10 dataset from TorchVision.
The official tutorial uses FashionMNIST, but just following along felt like copying, so I'm going to try a different dataset instead!
 

CIFAR-10 and CIFAR-100 datasets (www.cs.toronto.edu)

The CIFAR-10 and CIFAR-100 datasets are labeled subsets of the 80 million tiny images dataset, collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.

  • The CIFAR-10 dataset consists of 60,000 32x32 color images, each belonging to one of 10 classes.
  • It is split into 50,000 training images and 10,000 test images (verified with a quick check right after the loading code below).

CIFAR-10 Dataset Parameters

  • root is the path where the training/test data is stored.
  • train specifies whether to load the training or the test split.
  • download=True downloads the data from the internet if it is not present at root.
  • transform and target_transform specify the feature and label transformations.
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Define the data transforms (normalization included)
transform = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # normalize the image
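    # Note: Normalize computes (x - mean) / std per channel; with mean = std = 0.5 the
    # [0, 1] range produced by ToTensor is mapped to [-1, 1], which is why the
    # visualization code below undoes it with img / 2 + 0.5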
])
# Load the CIFAR-10 training dataset
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform
)

# Load the CIFAR-10 test dataset
test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform
)
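
As a quick sanity check of the 50,000 / 10,000 split mentioned above (assuming the download succeeded), the loaded datasets can be inspected directly:
# Quick sanity check on the loaded datasets
print(len(training_data))   # 50000
print(len(test_data))       # 10000

img, label = training_data[0]
print(img.shape)            # torch.Size([3, 32, 32]) -- C x H x W after ToTensor
print(label)                # an integer class index between 0 and 9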

Iterating and Visualizing the Dataset

You can index a Dataset directly like a Python list: training_data[index].
Let's use matplotlib to visualize a few samples from the training data.
# CIFAR-10 class label map
labels_map = {
    0: "Airplane",
    1: "Automobile",
    2: "Bird",
    3: "Cat",
    4: "Deer",
    5: "Dog",
    6: "Frog",
    7: "Horse",
    8: "Ship",
    9: "Truck"
}

# Visualization
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    img = img / 2 + 0.5  # undo the normalization
    plt.imshow(img.permute(1, 2, 0))  # change (C, H, W) -> (H, W, C)
plt.show()


ํŒŒ์ผ์—์„œ ์‚ฌ์šฉ์ž ์ •์˜ ๋ฐ์ดํ„ฐ์…‹ ๋งŒ๋“ค๊ธฐ

  • A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.
  • In the implementation below, the CIFAR-10 images are stored in the img_dir directory, and their labels are stored separately in the annotations_file CSV file.
  • Let's take a closer look at what happens in each of these functions (a short usage sketch follows the per-function walkthrough).
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        sample = {"image": image, "label": label}
        return sample

__init__

  • The __init__ function runs once, when the Dataset object is instantiated.
  • Here we initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the next chapter). For example, the labels.csv file looks like this:
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

 

Parameters

  • annotations_file: path to the CSV file containing the image file names and labels.
  • img_dir: path to the directory where the image files are stored.
  • transform: transformation (preprocessing) applied to the images.
  • target_transform: transformation (preprocessing) applied to the labels.
  • self.img_labels: the CSV file read into a pandas DataFrame.
  • self.img_dir: stores the image directory path.
  • self.transform: stores the image transform.
  • self.target_transform: stores the label transform.

 

__len__

  • The __len__ function returns the number of samples in the dataset.
def __len__(self):
    return len(self.img_labels)

 

  • len(self.img_labels): returns the number of labels stored in the CSV file.

 

__getitem__

  • The __getitem__ function loads and returns the sample at the given index idx.
  • Based on the index, it locates the image on disk, converts it to a tensor with read_image, retrieves the corresponding label from the CSV data in self.img_labels, calls the transform functions on them (if applicable), and returns the tensor image and label as a Python dict.
def __getitem__(self, idx):
    # Build the path to the image file; the first column of self.img_labels holds the file name.
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    
    # Read the image and convert it to a tensor.
    image = read_image(img_path)
    
    # Get the label; the second column of self.img_labels holds the label.
    label = self.img_labels.iloc[idx, 1]
    
    # If a transform is defined, apply it to the image.
    if self.transform:
        image = self.transform(image)
    
    # If a target_transform is defined, apply it to the label.
    if self.target_transform:
        label = self.target_transform(label)
    
    # Return the image and label together as a dict sample.
    sample = {"image": image, "label": label}
    return sample

Parameters

 

  • img_path: builds the path to the image file for the given index.
  • read_image(img_path): reads the image file and converts it to a tensor.
  • label: retrieves the label for the given index.
  • if self.transform: applies the image transform if one was specified.
  • if self.target_transform: applies the label transform if one was specified.
  • sample = {"image": image, "label": label}: returns the image and label together as a dict sample.
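
Putting the pieces together, the custom dataset can be used like the built-in ones. A short usage sketch (data/annotations.csv and data/images are hypothetical placeholder paths, not files from this post):
# Hypothetical paths -- replace them with your own CSV file and image directory.
dataset = CustomImageDataset(
    annotations_file="data/annotations.csv",
    img_dir="data/images",
)

print(len(dataset))           # number of entries listed in the CSV
sample = dataset[0]           # __getitem__ returns a dict here
print(sample["image"].shape)  # e.g. torch.Size([3, 32, 32]) for a 32x32 RGB image
print(sample["label"])

Because __getitem__ returns a dict, the default DataLoader collate function (covered next) will batch such samples into a dict of stacked image and label tensors, provided all images share the same shape.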

Preparing Data for Training with DataLoader

  • The Dataset class retrieves the dataset's features and labels one sample at a time.
  • When training a model, we typically want to pass samples in minibatches, reshuffle the data every epoch to reduce overfitting, and use multiprocessing to speed up data retrieval (a multiprocessing variant of the DataLoader is sketched after the parameter notes below).
  • DataLoader abstracts these complexities behind a simple API and provides an iterable that loads the data efficiently.

Creating a DataLoader

First, import the DataLoader class from the torch.utils.data module and create DataLoaders for the training and test data.
from torch.utils.data import DataLoader

# DataLoader for the training data
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

# DataLoader for the test data
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

Parameters

  • batch_size=64: loads 64 samples at a time.
  • shuffle=True: reshuffles the data every epoch.
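
The multiprocessing mentioned earlier is exposed through the num_workers argument (pin_memory can additionally speed up host-to-GPU transfers). Here is a variant of the loaders above, where the worker count of 2 is only an illustrative choice:
# Same loaders as above, but with worker processes for data loading.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True,
                              num_workers=2, pin_memory=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True,
                             num_workers=2, pin_memory=True)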

 

Iterating Through the DataLoader

  • After loading the dataset into the DataLoader, you can iterate through it as needed (a fuller epoch-style loop is sketched at the end of this section).
  • Each iteration below returns a batch of train_features and train_labels, each containing batch_size=64 features and labels.
  • Because we specified shuffle=True, the data is reshuffled after we iterate over all batches.
# Get the first batch from the DataLoader
train_features, train_labels = next(iter(train_dataloader))

# Print the batch shapes
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

# Take the first image and label from the batch and visualize them
img = train_features[0].permute(1, 2, 0) / 2 + 0.5  # undo normalization and change (C, H, W) -> (H, W, C)
label = train_labels[0]
plt.imshow(img.numpy())
plt.title(f"Label: {label}")
plt.axis("off")
plt.show()

print(f"Label: {label}")

 

 

Feature batch shape: torch.Size([64, 3, 32, 32])
Labels batch shape: torch.Size([64])

Label: 4

Parameters

 

  • next(iter(train_dataloader)): fetches the first batch from the DataLoader.
  • train_features.size(): prints the size of the feature batch.
  • train_labels.size(): prints the size of the label batch.
  • train_features[0]: takes the first image of the batch.
  • permute(1, 2, 0): rearranges the image dimensions to (H, W, C) for visualization.
  • img.numpy(): converts the image to a NumPy array so plt.imshow can display it.
  • plt.title(f"Label: {label}"): shows the image's label as the title.
  • plt.axis("off"): turns the axes off.
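
The snippet above pulls only a single batch with next(iter(...)). During actual training you would loop over the DataLoader once per epoch; here is a minimal sketch of that pattern (the number of epochs is illustrative, and the training step itself is left as a comment):
# Iterate the DataLoader for a few epochs; only the looping pattern is shown.
num_epochs = 2
for epoch in range(num_epochs):
    for batch_idx, (images, labels) in enumerate(train_dataloader):
        # images: [64, 3, 32, 32], labels: [64] (the last batch may be smaller)
        # ... forward pass, loss, backward pass, and optimizer step would go here ...
        pass
    print(f"epoch {epoch}: {batch_idx + 1} batches")  # 50000 / 64 -> 782 batches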

 

๋” ์ž์„ธํ•œ ๋‚ด์šฉ์„ ๋ณด๊ณ  ์‹ถ์œผ์‹œ๋ฉด ์•„๋ž˜ ๋งํฌ์— ๋“ค์–ด๊ฐ€์„œ ๊ณต์‹ ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”!

  • Official PyTorch documentation for torch.utils.data
 

torch.utils.data — PyTorch 2.4 documentation (pytorch.org)

At the heart of PyTorch's data loading utility is the torch.utils.data.DataLoader class, which represents a Python iterable over a dataset.