[PyTorch] Saving & Loading a Checkpoint Model
⚠️ This post is a write-up of what I studied based on the official PyTorch Korea documentation, so please bear with me!

Saving and loading a general checkpoint in PyTorch

Saving and loading a checkpoint model for inference or for resuming training can help you pick up where you last left off.

tutorials.pytorch.kr


Let's save and load a checkpoint model in PyTorch so that inference and training can be resumed later.

Intro

  • To save multiple checkpoint items in PyTorch, organize them in a dictionary,
  • then serialize the dictionary with torch.save().
  • A common PyTorch convention is to save these checkpoints with the .tar file extension.
  • To load the items, first initialize the model and optimizer, then load the dictionary with torch.load().
  • After that, you can access any saved item simply by looking it up in the dictionary.
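The flow described in the bullets above can be sketched as a minimal round trip. This is an illustrative sketch, not the tutorial's code: the tiny nn.Linear model and the extra 'note' entry are hypothetical, and an in-memory io.BytesIO buffer stands in for a .tar file so nothing is written to disk.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical tiny model and optimizer, used only to illustrate the pattern
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 1) Gather everything worth restoring into one dictionary
checkpoint = {
    "epoch": 3,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "note": "any extra item can go here",  # hypothetical extra entry
}

# 2) Serialize the dictionary with torch.save (a path such as "checkpoint.tar" also works)
buffer = io.BytesIO()
torch.save(checkpoint, buffer)

# 3) Re-create the model and optimizer first, then load the dictionary
buffer.seek(0)
restored_model = nn.Linear(4, 2)
restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.001, momentum=0.9)
loaded = torch.load(buffer)

# 4) Look up the items you need by key
restored_model.load_state_dict(loaded["model_state_dict"])
restored_optimizer.load_state_dict(loaded["optimizer_state_dict"])
print(loaded["epoch"], loaded["note"])                   # 3 any extra item can go here
print(torch.equal(restored_model.weight, model.weight))  # True
```

Because the checkpoint is an ordinary dictionary, adding another item later only means adding another key.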

Setup

First, install the PyTorch library.
!pip install torch

Steps for saving & loading a model

Let's implement the steps for saving and loading a model.
  1. Import the libraries needed for loading data
  2. Define and initialize the neural network
  3. Initialize the optimizer
  4. Save the general checkpoint
  5. Load the general checkpoint

Importing the libraries needed for loading data

import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in the network below
import torch.optim as optim

Defining & Initializing the Neural Network

As an example, let's define a neural network whose checkpoint we will save and load.
class NeuarlNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)  # halves height and width
        self.fc1 = nn.Linear(64 * 8 * 8, 512)  # 64 channels x 8 x 8 (assumes 32x32 inputs)
        self.fc2 = nn.Linear(512, 64)
        self.fc3 = nn.Linear(64, 10)  # final layer: 10 class scores

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)  # flatten each sample into a 4096-dim vector
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = NeuarlNetwork()
print(net)
NeuarlNetwork(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=4096, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)
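The 64 * 8 * 8 input size of fc1 only works for one input resolution. The post does not name a dataset, so the 32x32 RGB input below is an assumption consistent with those numbers; a dummy batch passed through the same conv/pool layers confirms the shapes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# The same conv/pool layers as NeuarlNetwork; the 32x32 RGB input size is an assumption
conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

x = torch.randn(4, 3, 32, 32)   # dummy batch of four 32x32 RGB images
x = pool(F.relu(conv1(x)))      # padding=1 keeps 32x32, pooling halves it to 16x16
print(x.shape)                  # torch.Size([4, 32, 16, 16])
x = pool(F.relu(conv2(x)))      # 16x16 -> 8x8 with 64 channels
print(x.shape)                  # torch.Size([4, 64, 8, 8])
print(x.flatten(1).shape)       # torch.Size([4, 4096]), i.e. 64 * 8 * 8
```

With a different input size, the x.view(-1, 64 * 8 * 8) call would either raise an error or reshape the batch incorrectly, so a one-off check like this is worth doing before training.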

Initializing the Optimizer

We'll use SGD with momentum.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
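Saving optimizer.state_dict() matters because SGD with momentum keeps a per-parameter momentum buffer; a freshly created optimizer starts without it. A small sketch, using a hypothetical nn.Linear model in place of NeuarlNetwork, shows that these buffers appear only after the first step():

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # hypothetical stand-in for NeuarlNetwork
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print(optimizer.state_dict()["state"])                        # {}  (no buffers yet)
print(optimizer.state_dict()["param_groups"][0]["momentum"])  # 0.9

# One dummy training step creates the momentum buffers
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

state = optimizer.state_dict()["state"]
print(sorted(state))                      # [0, 1]  (weight and bias, by index)
print(state[0]["momentum_buffer"].shape)  # torch.Size([2, 4])
```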

Saving the Model's Checkpoint

Let's gather the model-related information into a dictionary and save it.
# Additional information
EPOCH = 5  # epoch number to store
PATH = "model.pt"  # path where the checkpoint is saved
LOSS = 0.4  # last loss value

# Save the model and optimizer states
torch.save({
    'epoch': EPOCH,  # current epoch
    'model_state_dict': net.state_dict(),  # model state (dictionary)
    'optimizer_state_dict': optimizer.state_dict(),  # optimizer state (dictionary)
    'loss': LOSS,  # last loss value
}, PATH)  # write to the given path
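In the snippet above, EPOCH and LOSS are fixed placeholder values; in practice they usually come from the training loop itself. The following sketch assumes a hypothetical tiny model, random placeholder data, and a temporary directory, and the checkpoint_epoch{N}.tar file name is only an illustration of the .tar convention:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 1)  # hypothetical stand-in for NeuarlNetwork
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = nn.MSELoss()
inputs, targets = torch.randn(32, 4), torch.randn(32, 1)  # random placeholder data

ckpt_dir = tempfile.mkdtemp()
saved_paths = []
for epoch in range(3):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Save a checkpoint at the end of every epoch with the real epoch and loss
    path = os.path.join(ckpt_dir, f"checkpoint_epoch{epoch}.tar")
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss.item(),  # store a plain Python float instead of a tensor
    }, path)
    saved_paths.append(path)

print([os.path.basename(p) for p in saved_paths])
# ['checkpoint_epoch0.tar', 'checkpoint_epoch1.tar', 'checkpoint_epoch2.tar']
```

Keeping one file per epoch makes it possible to go back to an earlier point; overwriting a single file instead saves disk space.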

Loading the Model's Checkpoint

First initialize the model and optimizer, then load the dictionary in which the checkpoint was saved.
# Initialize the model and optimizer
model = NeuarlNetwork()  # new instance of the model class defined above
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # same SGD settings as before

# Load the checkpoint
checkpoint = torch.load(PATH)  # read the saved checkpoint file
model.load_state_dict(checkpoint['model_state_dict'])  # restore the model state
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # restore the optimizer state
epoch = checkpoint['epoch']  # restore the epoch number
loss = checkpoint['loss']  # restore the loss value

# Set the model mode
model.eval()  # evaluation mode: dropout off, batch norm uses its running statistics
# - or -
model.train()  # training mode: dropout on, batch norm uses per-batch statistics
  • Remember to call model.eval() before running inference so that dropout and batch normalization layers switch to evaluation mode.
  • If you skip this, you will get inconsistent inference results.
  • If you want to resume training instead, call model.train() to ensure these layers are in training mode.
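To put the loaded values to use, the sketch below resumes training from the epoch after the saved one and then switches to evaluation mode for inference. It assumes that 'epoch' stores the last completed epoch; to stay self-contained it first writes a checkpoint for a hypothetical nn.Linear model with random placeholder data, and it passes map_location="cpu" so that a checkpoint written on a GPU can also be opened on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

PATH = os.path.join(tempfile.mkdtemp(), "model.pt")
loss_fn = nn.MSELoss()
inputs, targets = torch.randn(32, 4), torch.randn(32, 1)  # random placeholder data

# Write a checkpoint first so the example runs on its own (hypothetical model)
net = nn.Linear(4, 1)
opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
torch.save({"epoch": 4, "model_state_dict": net.state_dict(),
            "optimizer_state_dict": opt.state_dict(), "loss": 0.4}, PATH)

# Re-create the model and optimizer, then restore their states
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(PATH, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # continue after the last completed epoch

# Resume training: make sure dropout/batch-norm layers are in training mode
model.train()
num_epochs = 8
for epoch in range(start_epoch, num_epochs):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

# Inference: switch to evaluation mode and skip gradient tracking
model.eval()
with torch.no_grad():
    predictions = model(inputs)
print(start_epoch, predictions.shape)  # 5 torch.Size([32, 1])
```

Wrapping inference in torch.no_grad() is separate from model.eval(): eval() changes how layers such as dropout behave, while no_grad() only stops autograd from recording operations.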