[PyTorch] ๋ชจ๋ธ ๋งค๊ฐœ๋ณ€์ˆ˜ ์ตœ์ ํ™”(Optimization) ํ•˜๊ธฐ
โš ๏ธ ๋ณธ ๋‚ด์šฉ์€ PyTorch Korea์˜ ๊ณต์‹ ๋ฌธ์„œ์— ๊ธฐ๋ฐ˜ํ•˜์—ฌ ๊ณต๋ถ€ํ•œ ๋‚ด์šฉ์„ ์ ์€๊ฒƒ์ด๋‹ˆ ์–‘ํ•ด๋ฐ”๋ž๋‹ˆ๋‹ค!
 

๋ชจ๋ธ ๋งค๊ฐœ๋ณ€์ˆ˜ ์ตœ์ ํ™”ํ•˜๊ธฐ

ํŒŒ์ดํ† ์น˜(PyTorch) ๊ธฐ๋ณธ ์ตํžˆ๊ธฐ|| ๋น ๋ฅธ ์‹œ์ž‘|| ํ…์„œ(Tensor)|| Dataset๊ณผ Dataloader|| ๋ณ€ํ˜•(Transform)|| ์‹ ๊ฒฝ๋ง ๋ชจ๋ธ ๊ตฌ์„ฑํ•˜๊ธฐ|| Autograd|| ์ตœ์ ํ™”(Optimization)|| ๋ชจ๋ธ ์ €์žฅํ•˜๊ณ  ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ์ด์ œ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํ„ฐ๊ฐ€ ์ค€๋น„

tutorials.pytorch.kr


Model ๋งค๊ฐœ๋ณ€์ˆ˜ ์ตœ์ ํ™” ํ•˜๊ธฐ

์ด๋ฒˆ์—๋Š” ์ค€๋น„๋œ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํ„ฐ๋กœ, ๋ฐ์ดํ„ฐ์— ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์ตœ์ ํ™” ํ•˜์—ฌ ๋ชจ๋ธ์„ ํ•™์Šต, ๊ฒ€์ฆ, ํ…Œ์ŠคํŠธ๋ฅผ ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
  • ๋ชจ๋ธ์„ ํ•™์Šตํ•˜๋Š” ๊ณผ์ •์€ ๋ฐ˜๋ณต์ ์ธ ๊ณผ์ •์„ ๊ฑฐ์นฉ๋‹ˆ๋‹ค.
  • ๊ฐ ๋ฐ˜๋ณต ๋‹จ๊ณ„์—์„œ ๋ชจ๋ธ์€ ์ถœ๋ ฅ์„ ์ถ”์ธกํ•˜๊ณ , ์ถ”์ธก๊ณผ ์ •๋‹ต ์‚ฌ์ด์˜ ์˜ค๋ฅ˜(์†์‹ค(loss))๋ฅผ ๊ณ„์‚ฐํ•˜๊ณ , ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•œ ์˜ค๋ฅ˜์˜ ๋„ํ•จ์ˆ˜(derivative)๋ฅผ ์ˆ˜์ง‘ํ•œ ๋’ค, ๊ฒฝ์‚ฌํ•˜๊ฐ•๋ฒ•์„ ์‚ฌ์šฉํ•˜์—ฌ ์ด ํŒŒ๋ผ๋ฏธํ„ฐ๋“ค์„ ์ตœ์ ํ™”(optimize)ํ•ฉ๋‹ˆ๋‹ค.
  • ์ด ๊ณผ์ •์— ๋Œ€ํ•œ ์ž์„ธํ•œ ์„ค๋ช…์€ ์•„๋ž˜ ๋งํฌ์— 3Blue1Brown์˜ Backpropagation(์—ญ์ „ํŒŒ)์˜์ƒ์„ ์ฐธ๊ณ ํ•˜์„ธ์š”.

  • 3Blue1Brown์˜ Backpropagation(์—ญ์ „ํŒŒ)์˜์ƒ ๋งํฌ ์ž…๋‹ˆ๋‹ค.

Pre-requisite Code

์ „์— ๊ณต๋ถ€ํ–ˆ๋˜ ๋‚ด์šฉ์—์„œ Dataset, Dataloader ๋ถ€๋ถ„๊ณผ ์‹ ๊ฒฝ๋ง ๋ชจ๋ธ ๊ตฌ์„ฑํ•˜๊ธฐ ๋ถ€๋ถ„์—์„œ ์ฝ”๋“œ๋ฅผ ๊ฐ€์ ธ์™”์Šต๋‹ˆ๋‹ค.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

Hyperparameters

ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ(Hyperparameter)๋Š” ๋ชจ๋ธ ์ตœ์ ํ™” ๊ณผ์ •์„ ์ œ์–ดํ•  ์ˆ˜ ์žˆ๋Š” ์กฐ์ ˆ ๊ฐ€๋Šฅํ•œ ๋งค๊ฐœ๋ณ€์ˆ˜์ž…๋‹ˆ๋‹ค.
  • ์„œ๋กœ ๋‹ค๋ฅธ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ๊ฐ’์€ ๋ชจ๋ธ ํ•™์Šต๊ณผ ์ˆ˜๋ ด์œจ(convergence rate)์— ์˜ํ–ฅ์„ ๋ฏธ์น  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • Hyperparameter์— ๋ฐํ•œ ๊ณต์‹๋ฌธ์„œ ์ž…๋‹ˆ๋‹ค. ์ด ๋ถ€๋ถ„์— ๋ฐํ•˜์—ฌ ๊ณต๋ถ€ํ•œ ๋‚ด์šฉ์€ ๋‚˜์ค‘์— ์˜ฌ๋ฆฌ๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.
 

Ray Tune์„ ์‚ฌ์šฉํ•œ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ํŠœ๋‹

๋ฒˆ์—ญ: ์‹ฌํ˜•์ค€ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ํŠœ๋‹์€ ๋ณดํ†ต์˜ ๋ชจ๋ธ๊ณผ ๋งค์šฐ ์ •ํ™•ํ•œ ๋ชจ๋ธ๊ฐ„์˜ ์ฐจ์ด๋ฅผ ๋งŒ๋“ค์–ด ๋‚ผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ข…์ข… ๋‹ค๋ฅธ ํ•™์Šต๋ฅ (Learnig rate)์„ ์„ ํƒํ•˜๊ฑฐ๋‚˜ layer size๋ฅผ ๋ณ€๊ฒฝํ•˜๋Š” ๊ฒƒ๊ณผ ๊ฐ™์€ ๊ฐ„๋‹จํ•œ ์ž‘์—…

tutorials.pytorch.kr

  • ๋ชจ๋ธ์„ ํ•™์Šตํ• ๋•Œ, ์ผ๋ฐ˜์ ์œผ๋กœ ๋‹ค์Œ๊ณผ ๊ฐ™์€ Hyperparameter๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.
  • ์—ํญ(epoch) ์ˆ˜ - ๋ฐ์ดํ„ฐ์…‹์„ ๋ฐ˜๋ณตํ•˜๋Š” ํšŸ์ˆ˜
  • ๋ฐฐ์น˜ ํฌ๊ธฐ(batch size) - ๋งค๊ฐœ๋ณ€์ˆ˜๊ฐ€ ๊ฐฑ์‹ ๋˜๊ธฐ ์ „ ์‹ ๊ฒฝ๋ง์„ ํ†ตํ•ด ์ „ํŒŒ๋œ ๋ฐ์ดํ„ฐ ์ƒ˜ํ”Œ์˜ ์ˆ˜
  • ํ•™์Šต๋ฅ (learning rate) - ๊ฐ ๋ฐฐ์น˜/์—ํญ์—์„œ ๋ชจ๋ธ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์กฐ์ ˆํ•˜๋Š” ๋น„์œจ. ๊ฐ’์ด ์ž‘์„์ˆ˜๋ก ํ•™์Šต ์†๋„๊ฐ€ ๋Š๋ ค์ง€๊ณ , ๊ฐ’์ด ํฌ๋ฉด ํ•™์Šต ์ค‘ ์˜ˆ์ธกํ•  ์ˆ˜ ์—†๋Š” ๋™์ž‘์ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
learning_rate = 1e-3
batch_size = 64
epochs = 5
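
These three values plug directly into the code used throughout this post: batch_size when constructing the DataLoader, learning_rate when constructing the optimizer, and epochs in the outer training loop. A quick sketch of where each one is used (all objects here are defined elsewhere in this post):
train_dataloader = DataLoader(training_data, batch_size=batch_size)  # batch size
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)    # learning rate
for t in range(epochs):                                              # number of epochs
    pass  # one full pass over the data: a train loop followed by a test loop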

์ตœ์ ํ™” ๋‹จ๊ณ„(Optimization Loop)

ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์„ค์ •ํ•œ ๋’ค์—๋Š” ์ตœ์ ํ™” ๋‹จ๊ณ„๋ฅผ ํ†ตํ•ด ๋ชจ๋ธ์„ ํ•™์Šตํ•˜๊ณ  ์ตœ์ ํ™”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์ตœ์ ํ™” ๋‹จ๊ณ„์˜ ๊ฐ ๋ฐ˜๋ณต(iteration)์„ Epoch(์—ํญ) ์ด๋ผ๊ณ  ๋ถ€๋ฆ…๋‹ˆ๋‹ค
  • ํ•˜๋‚˜์˜ ์—ํญ์€ ๋‘๊ฐ€์ง€๋กœ ๊ตฌ์„ฑ๋ฉ๋‹ˆ๋‹ค
    • ํ•™์Šต ๋‹จ๊ณ„(train loop) - ํ•™์Šต์šฉ ๋ฐ์ดํ„ฐ์…‹์„ ๋ฐ˜๋ณต(iterate)ํ•˜๊ณ  ์ตœ์ ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ์ˆ˜๋ ดํ•ฉ๋‹ˆ๋‹ค.
    • ๊ฒ€์ฆ / ํ…Œ์ŠคํŠธ ๋‹จ๊ณ„(validation / test loop) - ๋ชจ๋ธ ์„ฑ๋Šฅ์ด ๊ฐœ์„ ๋˜๊ณ  ์žˆ๋Š”์ง€๋ฅผ ํ™•์ธํ•˜๊ธฐ ์œ„ํ•ด ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ์…‹์„ ๋ฐ˜๋ณต(iterate)ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทธ๋Ÿฌ๋ฉด ํ•™์Šต ๋‹จ๊ณ„(training loop)์—์„œ ์ผ์–ด๋‚˜๋Š” ๋ช‡ ๊ฐ€์ง€ ๊ฐœ๋…๋“ค์„ ๊ฐ„๋žตํžˆ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

์†์‹ค ํ•จ์ˆ˜ (Loss Function)

ํ•™์Šต์šฉ ๋ฐ์ดํ„ฐ๋ฅผ ์ œ๊ณตํ•˜๋ฉด, ํ•™์Šต๋˜์ง€ ์•Š์€ ์‹ ๊ฒฝ๋ง์€ ์ •๋‹ต์„ ์ œ๊ณตํ•˜์ง€ ์•Š์„ ํ™•๋ฅ ์ด ๋†’์Šต๋‹ˆ๋‹ค.
  • ์†์‹ค ํ•จ์ˆ˜(loss function)๋Š” ํš๋“ํ•œ ๊ฒฐ๊ณผ์™€ ์‹ค์ œ ๊ฐ’ ์‚ฌ์ด์˜ ํ‹€๋ฆฐ ์ •๋„(degree of dissimilarity)๋ฅผ ์ธก์ •ํ•˜๋ฉฐ, ํ•™์Šต ์ค‘์— ์ด ๊ฐ’์„ ์ตœ์†Œํ™”ํ•˜๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.
  • ์ฃผ์–ด์ง„ ๋ฐ์ดํ„ฐ ์ƒ˜ํ”Œ์„ ์ž…๋ ฅ์œผ๋กœ ๊ณ„์‚ฐํ•œ ์˜ˆ์ธก๊ณผ ์ •๋‹ต(label)์„ ๋น„๊ตํ•˜์—ฌ ์†์‹ค(loss)์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
  • ์ผ๋ฐ˜์ ์ธ ์†์‹คํ•จ์ˆ˜์—๋Š” ํšŒ๊ท€ ๋ฌธ์ œ(regression task)์— ์‚ฌ์šฉํ•˜๋Š” nn.MSELoss (ํ‰๊ท  ์ œ๊ณฑ ์˜ค์ฐจ(MSE; Mean Square Error))๋‚˜ ๋ถ„๋ฅ˜(classification)์— ์‚ฌ์šฉํ•˜๋Š” nn.NLLLoss (์Œ์˜ ๋กœ๊ทธ ์šฐ๋„(Negative Log Likelihood)), ๊ทธ๋ฆฌ๊ณ  nn.LogSoftmax์™€ nn.NLLLoss๋ฅผ ํ•ฉ์นœ nn.CrossEntropyLoss ๋“ฑ์ด ์žˆ์Šต๋‹ˆ๋‹ค.
  • ๋ชจ๋ธ์˜ ์ถœ๋ ฅ ๋กœ์ง“(logit)์„ nn.CrossEntropyLoss์— ์ „๋‹ฌํ•˜์—ฌ ๋กœ์ง“(logit)์„ ์ •๊ทœํ™”ํ•˜๊ณ  ์˜ˆ์ธก ์˜ค๋ฅ˜๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
# ์†์‹ค ํ•จ์ˆ˜๋ฅผ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค.
loss_fn = nn.CrossEntropyLoss()
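
Since nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss, both formulations below should produce the same loss value. A minimal sketch to check this, with made-up logits and labels:
logits = torch.randn(4, 10)          # a batch of 4 samples over 10 classes (made-up values)
labels = torch.tensor([1, 0, 3, 9])  # ground-truth class indices

ce = nn.CrossEntropyLoss()(logits, labels)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)
print(torch.allclose(ce, nll))  # True: CrossEntropyLoss = LogSoftmax + NLLLoss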
  • ํšŒ๊ท€ ๋ฌธ์ œ(regression task)์— ์‚ฌ์šฉํ•˜๋Š” nn.MSELoss ๊ณต์‹๋ฌธ์„œ ์ž…๋‹ˆ๋‹ค.
 

MSELoss — PyTorch 2.4 documentation

Shortcuts

pytorch.org

  • Here is the official documentation for nn.NLLLoss (Negative Log Likelihood):

NLLLoss — PyTorch 2.4 documentation (pytorch.org)

  • Here is the official documentation for nn.CrossEntropyLoss:

CrossEntropyLoss — PyTorch 2.4 documentation (pytorch.org)


์˜ตํ‹ฐ๋งˆ์ด์ € (Optimizer)

์ตœ์ ํ™”๋Š” ๊ฐ ํ•™์Šต ๋‹จ๊ณ„์—์„œ ๋ชจ๋ธ์˜ ์˜ค๋ฅ˜๋ฅผ ์ค„์ด๊ธฐ ์œ„ํ•ด ๋ชจ๋ธ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์กฐ์ •ํ•˜๋Š” ๊ณผ์ •์ž…๋‹ˆ๋‹ค.
์ตœ์ ํ™” ์•Œ๊ณ ๋ฆฌ์ฆ˜์€ ์ด ๊ณผ์ •์ด ์ˆ˜ํ–‰๋˜๋Š” ๋ฐฉ์‹(์—ฌ๊ธฐ์—์„œ๋Š” ํ™•๋ฅ ์  ๊ฒฝ์‚ฌํ•˜๊ฐ•๋ฒ•(SGD; Stochastic Gradient Descent))์„ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.
  • ์ตœ์ ํ™”๋Š” ๊ฐ ํ•™์Šต ๋‹จ๊ณ„์—์„œ ๋ชจ๋ธ์˜ ์˜ค๋ฅ˜๋ฅผ ์ค„์ด๊ธฐ ์œ„ํ•ด ๋ชจ๋ธ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์กฐ์ •ํ•˜๋Š” ๊ณผ์ •์ž…๋‹ˆ๋‹ค.
  • ์ตœ์ ํ™” ์•Œ๊ณ ๋ฆฌ์ฆ˜์€ ์ด ๊ณผ์ •์ด ์ˆ˜ํ–‰๋˜๋Š” ๋ฐฉ์‹(์—ฌ๊ธฐ์—์„œ๋Š” ํ™•๋ฅ ์  ๊ฒฝ์‚ฌํ•˜๊ฐ•๋ฒ•(SGD; Stochastic Gradient Descent))์„ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.
  • ๋ชจ๋“  ์ตœ์ ํ™” ์ ˆ์ฐจ(logic)๋Š” optimizer ๊ฐ์ฒด์— ์บก์Šํ™”(encapsulate)๋ฉ๋‹ˆ๋‹ค.
  • ์—ฌ๊ธฐ์„œ๋Š” SGD ์˜ตํ‹ฐ๋งˆ์ด์ €๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ์œผ๋ฉฐ, PyTorch์—๋Š” ADAM์ด๋‚˜ RMSProp๊ณผ ๊ฐ™์€ ๋‹ค๋ฅธ ์ข…๋ฅ˜์˜ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํ„ฐ์—์„œ ๋” ์ž˜ ๋™์ž‘ํ•˜๋Š” ๋‹น์–‘ํ•œ ์˜ตํ‹ฐ๋งˆ์ด์ €๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ํ•™์Šตํ•˜๋ ค๋Š” ๋ชจ๋ธ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜์™€ ํ•™์Šต๋ฅ (learning rate) ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๋“ฑ๋กํ•˜์—ฌ ์˜ตํ‹ฐ๋งˆ์ด์ €๋ฅผ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
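
Swapping in a different optimizer only changes this one constructor call; the rest of the training code stays the same. A minimal sketch, reusing the model and learning_rate defined above:
# Alternative optimizers; only the construction line changes.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)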

 

ํ•™์Šต ๋‹จ๊ณ„(loop)์—์„œ ์ตœ์ ํ™”๋Š” ์„ธ๋‹จ๊ณ„๋กœ ์ด๋ค„์ง‘๋‹ˆ๋‹ค.

  • Call optimizer.zero_grad() to reset the gradients of the model parameters.
  • Gradients add up by default; to prevent double-counting, we explicitly zero them at every iteration (see the sketch after this list).
  • Backpropagate the prediction loss with a call to loss.backward().
  • PyTorch deposits the gradient of the loss with respect to each parameter.
  • Once we have our gradients, we call optimizer.step() to adjust the parameters by the gradients collected in the backward pass.
  • ์•„๋ž˜๋Š” Optimizer ๊ด€๋ จ ๊ณต์‹๋ฌธ์„œ ์ž…๋‹ˆ๋‹ค.
 

torch.optim — PyTorch 2.4 documentation

torch.optim torch.optim is a package implementing various optimization algorithms. Most commonly used methods are already supported, and the interface is general enough, so that more sophisticated ones can also be easily integrated in the future. How to us

pytorch.org
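
To see why the zeroing step matters, here is a minimal sketch (with a hypothetical one-parameter tensor) showing that .grad accumulates across backward() calls until it is reset:
w = torch.tensor([1.0], requires_grad=True)  # a single hypothetical parameter

(2 * w).sum().backward()
print(w.grad)  # tensor([2.])

(2 * w).sum().backward()
print(w.grad)  # tensor([4.]) - the second gradient was added, not overwritten

w.grad.zero_()  # conceptually what optimizer.zero_grad() does for each parameter
(2 * w).sum().backward()
print(w.grad)  # tensor([2.]) again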


์ „์ฒด ์ฝ”๋“œ ๊ตฌํ˜„

์ตœ์ ํ™” ์ฝ”๋“œ๋ฅผ ๋ฐ˜๋ณตํ•˜์—ฌ ์ˆ˜ํ–‰ํ•˜๋Š” train_loop์™€ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ์ธก์ •ํ•˜๋Š” test_loop๋ฅผ ์ •์˜ํ•˜์˜€์Šต๋‹ˆ๋‹ค.
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # ๋ชจ๋ธ์„ ํ•™์Šต(train) ๋ชจ๋“œ๋กœ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค - ๋ฐฐ์น˜ ์ •๊ทœํ™”(Batch Normalization) ๋ฐ ๋“œ๋กญ์•„์›ƒ(Dropout) ๋ ˆ์ด์–ด๋“ค์— ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค.
    # ์ด ์˜ˆ์‹œ์—์„œ๋Š” ์—†์–ด๋„ ๋˜์ง€๋งŒ, ์ถ”๊ฐ€ํ•ด๋‘์—ˆ์Šต๋‹ˆ๋‹ค.
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # ์˜ˆ์ธก(prediction)๊ณผ ์†์‹ค(loss) ๊ณ„์‚ฐ
        pred = model(X)
        loss = loss_fn(pred, y)

        # ์—ญ์ „ํŒŒ
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    # ๋ชจ๋ธ์„ ํ‰๊ฐ€(eval) ๋ชจ๋“œ๋กœ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค - ๋ฐฐ์น˜ ์ •๊ทœํ™”(Batch Normalization) ๋ฐ ๋“œ๋กญ์•„์›ƒ(Dropout) ๋ ˆ์ด์–ด๋“ค์— ์ค‘์š”ํ•ฉ๋‹ˆ๋‹ค.
    # ์ด ์˜ˆ์‹œ์—์„œ๋Š” ์—†์–ด๋„ ๋˜์ง€๋งŒ, ๋ชจ๋ฒ” ์‚ฌ๋ก€๋ฅผ ์œ„ํ•ด ์ถ”๊ฐ€ํ•ด๋‘์—ˆ์Šต๋‹ˆ๋‹ค.
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode.
    # This also reduces unnecessary gradient computations and memory usage for tensors with requires_grad=True.
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # argmax(1) picks the highest-scoring class; count predictions that match the label
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
  • ์†์‹ค ํ•จ์ˆ˜์™€ ์˜ตํ‹ฐ๋งˆ์ด์ €๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜๊ณ  train_loop์™€ test_loop์— ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.
  • ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ ํ–ฅ์ƒ์„ ์•Œ์•„๋ณด๊ธฐ ์œ„ํ•ด ์ž์œ ๋กญ๊ฒŒ ์—ํญ(epoch) ์ˆ˜๋ฅผ ์ฆ๊ฐ€์‹œ์ผœ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.305544  [   64/60000]
loss: 2.296132  [ 6464/60000]
loss: 2.269240  [12864/60000]
loss: 2.265443  [19264/60000]
loss: 2.257847  [25664/60000]
loss: 2.210369  [32064/60000]
loss: 2.225546  [38464/60000]
loss: 2.182035  [44864/60000]
loss: 2.192158  [51264/60000]
loss: 2.160539  [57664/60000]
Test Error: 
 Accuracy: 38.8%, Avg loss: 2.149502 

Epoch 2
-------------------------------
loss: 2.165807  [   64/60000]
loss: 2.152438  [ 6464/60000]
loss: 2.086566  [12864/60000]
loss: 2.106932  [19264/60000]
loss: 2.064639  [25664/60000]
loss: 1.981671  [32064/60000]
loss: 2.020240  [38464/60000]
loss: 1.928579  [44864/60000]
loss: 1.948187  [51264/60000]
loss: 1.877932  [57664/60000]
Test Error: 
 Accuracy: 52.7%, Avg loss: 1.867320 

Epoch 3
-------------------------------
loss: 1.906202  [   64/60000]
loss: 1.871369  [ 6464/60000]
loss: 1.746385  [12864/60000]
loss: 1.800164  [19264/60000]
loss: 1.699125  [25664/60000]
loss: 1.632668  [32064/60000]
loss: 1.670603  [38464/60000]
loss: 1.561565  [44864/60000]
loss: 1.602784  [51264/60000]
loss: 1.505543  [57664/60000]
Test Error: 
 Accuracy: 57.2%, Avg loss: 1.510549 

Epoch 4
-------------------------------
loss: 1.583206  [   64/60000]
loss: 1.545566  [ 6464/60000]
loss: 1.391972  [12864/60000]
loss: 1.475445  [19264/60000]
loss: 1.366858  [25664/60000]
loss: 1.349078  [32064/60000]
loss: 1.376894  [38464/60000]
loss: 1.291643  [44864/60000]
loss: 1.333184  [51264/60000]
loss: 1.248018  [57664/60000]
Test Error: 
 Accuracy: 62.3%, Avg loss: 1.258796 

Epoch 5
-------------------------------
loss: 1.338726  [   64/60000]
loss: 1.319980  [ 6464/60000]
loss: 1.150760  [12864/60000]
loss: 1.264624  [19264/60000]
loss: 1.147503  [25664/60000]
loss: 1.162498  [32064/60000]
loss: 1.192891  [38464/60000]
loss: 1.120305  [44864/60000]
loss: 1.161247  [51264/60000]
loss: 1.091846  [57664/60000]
Test Error: 
 Accuracy: 64.2%, Avg loss: 1.099003 

Epoch 6
-------------------------------
loss: 1.172135  [   64/60000]
loss: 1.174778  [ 6464/60000]
loss: 0.989002  [12864/60000]
loss: 1.128839  [19264/60000]
loss: 1.007943  [25664/60000]
loss: 1.032277  [32064/60000]
loss: 1.074430  [38464/60000]
loss: 1.006675  [44864/60000]
loss: 1.046682  [51264/60000]
loss: 0.989977  [57664/60000]
Test Error: 
 Accuracy: 65.4%, Avg loss: 0.992776 

Epoch 7
-------------------------------
loss: 1.053729  [   64/60000]
loss: 1.078095  [ 6464/60000]
loss: 0.875616  [12864/60000]
loss: 1.035629  [19264/60000]
loss: 0.916670  [25664/60000]
loss: 0.937665  [32064/60000]
loss: 0.994427  [38464/60000]
loss: 0.930131  [44864/60000]
loss: 0.966351  [51264/60000]
loss: 0.920454  [57664/60000]
Test Error: 
 Accuracy: 66.6%, Avg loss: 0.919069 

Epoch 8
-------------------------------
loss: 0.965864  [   64/60000]
loss: 1.009843  [ 6464/60000]
loss: 0.793096  [12864/60000]
loss: 0.968506  [19264/60000]
loss: 0.854228  [25664/60000]
loss: 0.867148  [32064/60000]
loss: 0.937596  [38464/60000]
loss: 0.877755  [44864/60000]
loss: 0.908159  [51264/60000]
loss: 0.870483  [57664/60000]
Test Error: 
 Accuracy: 68.0%, Avg loss: 0.865721 

Epoch 9
-------------------------------
loss: 0.898012  [   64/60000]
loss: 0.958372  [ 6464/60000]
loss: 0.730860  [12864/60000]
loss: 0.918205  [19264/60000]
loss: 0.809332  [25664/60000]
loss: 0.813669  [32064/60000]
loss: 0.894510  [38464/60000]
loss: 0.840718  [44864/60000]
loss: 0.864679  [51264/60000]
loss: 0.832418  [57664/60000]
Test Error: 
 Accuracy: 69.3%, Avg loss: 0.825300 

Epoch 10
-------------------------------
loss: 0.843837  [   64/60000]
loss: 0.917194  [ 6464/60000]
loss: 0.682222  [12864/60000]
loss: 0.879231  [19264/60000]
loss: 0.775033  [25664/60000]
loss: 0.772480  [32064/60000]
loss: 0.859712  [38464/60000]
loss: 0.813271  [44864/60000]
loss: 0.830852  [51264/60000]
loss: 0.801870  [57664/60000]
Test Error: 
 Accuracy: 70.6%, Avg loss: 0.793228 

Done!

๋” ์ž์„ธํ•œ ๋‚ด์šฉ์„ ๋ณด๊ณ  ์‹ถ์œผ์‹œ๋ฉด ์•„๋ž˜ ๋งํฌ์— ๋“ค์–ด๊ฐ€์„œ ๊ณต์‹ ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”!

  • Official documentation for loss functions:

torch.nn — PyTorch 2.4 documentation (pytorch.org)

  • Official documentation for torch.optim:

torch.optim — PyTorch 2.4 documentation (pytorch.org)

  • Official documentation for warmstarting a model:

Warmstarting model using parameters from a different model in PyTorch (tutorials.pytorch.kr)