[PyTorch] Transform
⚠️ This post is a write-up of what I studied based on the official PyTorch Korea documentation, so please bear that in mind!
 

๋ณ€ํ˜•(Transform)

ํŒŒ์ดํ† ์น˜(PyTorch) ๊ธฐ๋ณธ ์ตํžˆ๊ธฐ|| ๋น ๋ฅธ ์‹œ์ž‘|| ํ…์„œ(Tensor)|| Dataset๊ณผ Dataloader|| ๋ณ€ํ˜•(Transform)|| ์‹ ๊ฒฝ๋ง ๋ชจ๋ธ ๊ตฌ์„ฑํ•˜๊ธฐ|| Autograd|| ์ตœ์ ํ™”(Optimization)|| ๋ชจ๋ธ ์ €์žฅํ•˜๊ณ  ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ๋ฐ์ดํ„ฐ๊ฐ€ ํ•ญ์ƒ ๋จธ์‹ ๋Ÿฌ๋‹ ์•Œ

tutorials.pytorch.kr


Transform

๋ฐ์ดํ„ฐ ๋ณ€ํ˜•(Transform)์€ ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ ๋ฐ ๋ฐ์ดํ„ฐ ์ฆ๊ฐ•(data augmentation)์„ ์œ„ํ•ด ์ž์ฃผ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.

  • ๋ฐ์ดํ„ฐ๊ฐ€ ํ•ญ์ƒ ๋จธ์‹ ๋Ÿฌ๋‹ ์•Œ๊ณ ๋ฆฌ์ฆ˜ ํ•™์Šต์— ํ•„์š”ํ•œ ์ตœ์ข… ์ฒ˜๋ฆฌ๊ฐ€ ๋œ ํ˜•ํƒœ๋กœ ์ œ๊ณต๋˜์ง€๋Š” ์•Š์Šต๋‹ˆ๋‹ค.
  • ๋ณ€ํ˜•(transform) ์„ ํ•ด์„œ ๋ฐ์ดํ„ฐ๋ฅผ ์กฐ์ž‘ํ•˜๊ณ  ํ•™์Šต์— ์ ํ•ฉํ•˜๊ฒŒ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
  • ๋ชจ๋“  TorchVision ๋ฐ์ดํ„ฐ์…‹๋“ค์€ ๋ณ€ํ˜• ๋กœ์ง์„ ๊ฐ–๋Š”, ํ˜ธ์ถœ ๊ฐ€๋Šฅํ•œ ๊ฐ์ฒด(callable)๋ฅผ ๋ฐ›๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜ ๋‘๊ฐœ (ํŠน์ง•(feature)์„ ๋ณ€๊ฒฝํ•˜๊ธฐ ์œ„ํ•œ transform ๊ณผ ์ •๋‹ต(label)์„ ๋ณ€๊ฒฝํ•˜๊ธฐ ์œ„ํ•œ target_transform)๋ฅผ ๊ฐ–์Šต๋‹ˆ๋‹ค.
  • torchvision.transforms ๋ชจ๋“ˆ์€ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•œ ๋‹ค์–‘ํ•œ ๋ณ€ํ˜• ๊ธฐ๋Šฅ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
 

Reference: Transforming and augmenting images — Torchvision 0.19 documentation (pytorch.org)

 

  • ๋ณ€ํ˜•์€ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ๋ฅผ ํ…์„œ๋กœ ๋ณ€ํ™˜ํ•˜๊ฑฐ๋‚˜, ์ •๊ทœํ™”ํ•˜๊ฑฐ๋‚˜, ํšŒ์ „, ์ž๋ฅด๊ธฐ ๋“ฑ์˜ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์ด๋Ÿฌํ•œ ๋ณ€ํ˜•์€ ๋ฐ์ดํ„ฐ์…‹์„ ๋กœ๋“œํ•  ๋•Œ ์ ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Example (Fashion-MNIST)

  • The FashionMNIST features are in PIL Image format, and the labels are integers.
  • For training, we need the features as normalized tensors and the labels as one-hot encoded tensors.
  • To make these transformations, we use ToTensor and Lambda.
  • A quick note on the Fashion-MNIST dataset: it consists of 60,000 training images and 10,000 test images.
  • Each image is a 28x28 grayscale image classified into one of 10 classes.
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

# Load the FashionMNIST dataset
ds = datasets.FashionMNIST(
    root="data",  # path where the dataset is stored
    train=True,  # load the training set (set train=False for the test set)
    download=True,  # download the dataset from the internet if it is not at root
    transform=ToTensor(),  # convert the images to tensors
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))  # convert the labels to one-hot encodings
)
  • We load the FashionMNIST dataset, converting the images to tensors and the labels to one-hot encodings.
  • The resulting dataset can then be used for model training.
  • A DataLoader can fetch the data in batches and pass the transformed images and labels to the model.

ToTensor()

  • ToTensor converts a PIL Image or NumPy ndarray into a FloatTensor and scales the pixel intensity values into the range [0., 1.].
  • It also rearranges an (H, W, C) image into (C, H, W) format (a grayscale (H, W) image becomes (1, H, W)).
Here C, H, and W stand for Channel, Height, and Width, respectively.

Lambda ๋ณ€ํ˜• (Transform)

  • The Lambda transform applies any user-defined lambda function. Here, we define a function that turns an integer into a one-hot encoded tensor.
  • The function first creates a zero tensor of size 10 (the number of labels in the dataset) and then calls scatter_, which assigns value=1 at the index given by the label y.
 

Reference: torch.Tensor.scatter_ — PyTorch 2.4 documentation (pytorch.org)

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
  • ์ด ๋ณ€ํ˜•์€ ๋ ˆ์ด๋ธ”์„ ์›-ํ•ซ ์ธ์ฝ”๋”ฉ ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
  • ์›-ํ•ซ ์ธ์ฝ”๋”ฉ์€ ํด๋ž˜์Šค ๋ ˆ์ด๋ธ”์„ ์ด์ง„ ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ, ํ•ด๋‹น ํด๋ž˜์Šค์—๋งŒ 1์„ ํ• ๋‹นํ•˜๊ณ  ๋‚˜๋จธ์ง€ ํด๋ž˜์Šค์—๋Š” 0์„ ํ• ๋‹นํ•˜๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.
์˜ˆ๋ฅผ ๋“ค์–ด, ์›๋ณธ ๋ ˆ์ด๋ธ”์ด 3์ผ ๊ฒฝ์šฐ, ๋ณ€ํ™˜๋œ ์›-ํ•ซ ์ธ์ฝ”๋”ฉ ๋ฒกํ„ฐ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:
์›๋ณธ ๋ ˆ์ด๋ธ”: 3์›-ํ•ซ ์ธ์ฝ”๋”ฉ ๋ฒกํ„ฐ: [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]

One-Hot Encoding ๋ณ€ํ™˜ ๊ณผ์ •

Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))

 

  • torch.zeros(10, dtype=torch.float): creates a tensor of length 10 filled with zeros.
  • This zero tensor represents one one-hot encoding vector over the 10 classes.
  • scatter_(0, torch.tensor(y), value=1): sets the position at index y to 1.
  • y is the original label, which takes values from 0 to 9.

๋” ์ž์„ธํ•œ ๋‚ด์šฉ์„ ๋ณด๊ณ  ์‹ถ์œผ์‹œ๋ฉด ์•„๋ž˜ ๋งํฌ์— ๋“ค์–ด๊ฐ€์„œ ๊ณต์‹ ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”!

  • torchvision.transforms ๊ด€๋ จ PyTorch ๊ณต์‹ ๋ฌธ์„œ
 
