๋ฐ์ํ
โ ๏ธ ๋ณธ ๋ด์ฉ์ PyTorch Korea์ ๊ณต์ ๋ฌธ์์ ๊ธฐ๋ฐํ์ฌ ๊ณต๋ถํ ๋ด์ฉ์ ์ ์๊ฒ์ด๋ ์ํด๋ฐ๋๋๋ค!
PyTorch์์ Inference & Training์ ๋ค์ ํ๊ธฐ ์ํด์ Checkpoint Model์ ์ ์ฅ & ๋ถ๋ฌ์ค๋๊ฒ์ ํ๋ฒ ํด๋ณด๊ฒ ์ต๋๋ค.
Intro
- PyTorch์์ ์ฌ๋ฌ Checkpoint๋ค์ ์ ์ฅํ๊ธฐ ์ํด์ ์ฌ์ (Dictionary)์ Checkpoint๋ค์ ๊ตฌ์ฑํํ,
- torch.save()๋ฅผ ์ฌ์ฉํ์ฌ Dictionary๋ฅผ ์ง๋ ฌํ(Seralize)ํด์ผ ํฉ๋๋ค.
- PyTorch์์๋ ์ฌ๋ฌ ์ฒดํฌํฌ์ธํธ๋ค์ ์ ์ฅํ ๋ .tar ํ์ฅ์๋ฅผ ์ฌ์ฉํฉ๋๋ค.
- ํญ๋ชฉ๋ค์ ๋ถ๋ฌ์ฌ ๋์๋, ๋จผ์ ๋ชจ๋ธ๊ณผ ์ตํฐ๋ง์ด์ ๋ฅผ ์ด๊ธฐํํ๊ณ , torch.load()๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ์ ์ ๋ถ๋ฌ์ต๋๋ค.
- ์ดํ ์ํ๋๋๋ก ์ ์ฅํ ํญ๋ชฉ๋ค์ ์ฌ์ ์ ์กฐํํ์ฌ ์ ๊ทผํ ์ ์์ต๋๋ค.
Setting (์ค์ )
์ผ๋จ, ๋จผ์ PyTorch ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํด์ค์ผ ํฉ๋๋ค.
!pip install torch
Model ์ ์ฅ & ๋ถ๋ฌ์ค๋ ๋จ๊ณ
ํ๋ฒ, ๋ชจ๋ธ์ ์ ์ฅ & ๋ถ๋ฌ์ค๋ ๋จ๊ณ๋ฅผ ํ๋ฒ ๊ตฌํํด ๋ณด๊ฒ ์ต๋๋ค.
- ๋ฐ์ดํฐ ๋ถ๋ฌ์ฌ ๋ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ค ๋ถ๋ฌ์ค๊ธฐ
- ์ ๊ฒฝ๋ง์ ๊ตฌ์ฑํ๊ณ ์ด๊ธฐํํ๊ธฐ
- ์ตํฐ๋ง์ด์ ์ด๊ธฐํํ๊ธฐ
- ์ผ๋ฐ์ ์ธ ์ฒดํฌํฌ์ธํธ ์ ์ฅํ๊ธฐ
- ์ผ๋ฐ์ ์ธ ์ฒดํฌํฌ์ธํธ ๋ถ๋ฌ์ค๊ธฐ
- ๋ฐ์ดํฐ๋ฅผ ๋ถ๋ฌ์ฌ๋ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ๋ถ๋ฌ์ค๊ธฐ
import torch
import torch.nn as nn
import torch.optim as optim
Neural Network (์ ๊ฒฝ๋ง) ๊ตฌ์ฑ & ์ด๊ธฐํ
์์๋ฅผ ๋ค์ด ๋ชจ๋ธ์ ์ ์ฅ & ๋ถ๋ฌ์ค๊ธฐ ์ํ ์ ๊ฒฝ๋ง์ ํ๋ฒ ๊ตฌ์ฑํด ๋ณด๊ฒ ์ต๋๋ค.
class NeuarlNetwork(nn.Module):
def __init__(self):
super(NeuarlNetwork, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.fc1 = nn.Linear(64 * 8 * 8, 512)
self.fc2 = nn.Linear(512, 64)
self.fc3 = nn.Linear(64, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 64 * 8 * 8)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
net = NeuarlNetwork()
print(net)
NeuarlNetwork(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=4096, out_features=512, bias=True)
(fc2): Linear(in_features=512, out_features=64, bias=True)
(fc3): Linear(in_features=64, out_features=10, bias=True)
)
Optimizer ์ด๊ธฐํ ํ๊ธฐ
Momentum(๋ชจ๋ฉํ )์ด๋ SGD๋ฅผ ์ฌ์ฉํ๊ฒ ์ต๋๋ค.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
Model์ Checkpoint ์ ์ฅํ๊ธฐ
Model์ ๊ด๋ จ๋ ์ ๋ณด๋ฅผ ๋ถ๋ฌ์์ Dictionary๋ฅผ ๊ตฌ์ฑํด ๋ณด๊ฒ ์ต๋๋ค.
# ์ถ๊ฐ ์ ๋ณด
EPOCH = 5 # ์ ์ฅํ ์ํญ ์
PATH = "model.pt" # ๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก
LOSS = 0.4 # ๋ง์ง๋ง ์์ค ๊ฐ
# ๋ชจ๋ธ๊ณผ ์ตํฐ๋ง์ด์ ์ํ ์ ์ฅ
torch.save({
'epoch': EPOCH, # ํ์ฌ ์ํญ ์
'model_state_dict': net.state_dict(), # ๋ชจ๋ธ์ ์ํ(dictionary)
'optimizer_state_dict': optimizer.state_dict(), # ์ตํฐ๋ง์ด์ ์ ์ํ(dictionary)
'loss': LOSS, # ๋ง์ง๋ง ์์ค ๊ฐ
}, PATH) # ์ง์ ๋ ๊ฒฝ๋ก์ ์ ์ฅ
Model์ Checkpoint ๋ถ๋ฌ์ค๊ธฐ
Model๊ณผ Optimizer๋ฅผ ์ด๊ธฐํ ํํ, Checkpoint๋ฅผ ์ ์ฅํ Dictionary๋ฅผ ๋ถ๋ฌ์์ผ ํฉ๋๋ค.
# ๋ชจ๋ธ๊ณผ ์ตํฐ๋ง์ด์ ์ด๊ธฐํ
model = Net() # ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ ์ธ์คํด์ค ์์ฑ
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # SGD ์ตํฐ๋ง์ด์ ์ด๊ธฐํ
# ์ฒดํฌํฌ์ธํธ ๋ถ๋ฌ์ค๊ธฐ
checkpoint = torch.load(PATH) # ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ ํ์ผ ๋ก๋
model.load_state_dict(checkpoint['model_state_dict']) # ๋ชจ๋ธ ์ํ ๋ณต์
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # ์ตํฐ๋ง์ด์ ์ํ ๋ณต์
epoch = checkpoint['epoch'] # ์ํญ ์ ๋ณต์
loss = checkpoint['loss'] # ์์ค ๊ฐ ๋ณต์
# ๋ชจ๋ธ ๋ชจ๋ ์ค์
model.eval() # ํ๊ฐ ๋ชจ๋๋ก ์ค์ (๋๋กญ์์, ๋ฐฐ์น ์ ๊ทํ ๋นํ์ฑํ)
# - ๋๋ -
model.train() # ํ์ต ๋ชจ๋๋ก ์ค์ (๋๋กญ์์, ๋ฐฐ์น ์ ๊ทํ ํ์ฑํ)
- ์ถ๋ก (inference)์ ์คํํ๊ธฐ ์ ์ model.eval() ์ ํธ์ถํ์ฌ ๋๋กญ์์(dropout)๊ณผ ๋ฐฐ์น ์ ๊ทํ ์ธต(batch normalization layer)์ ํ๊ฐ(evaluation) ๋ชจ๋๋ก ๋ฐ๊ฟ์ผํ๋ค๋ ๊ฒ์ ๊ธฐ์ตํด์ผํฉ๋๋ค.
- ๋ง์ฝ ์ด๊ฒ์ ๋นผ๋จน์ผ๋ฉด ์ผ๊ด์ฑ ์๋ ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ป๊ฒ ๋ฉ๋๋ค.
- ๋ํ ํ์ต์ ๊ณ์ํ๊ธธ ์ํ๋ค๋ฉด model.train() ์ ํธ์ถํ์ฌ ์ด ์ธต(layer)๋ค์ด ํ์ต ๋ชจ๋์ธ์ง ํ์ธ(ensure)ํด์ผ ํฉ๋๋ค.
๋ฐ์ํ
'๐ฅ PyTorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[PyTorch] Model ์ ์ฅ & ๋ถ๋ฌ์ค๊ธฐ (0) | 2024.07.31 |
---|---|
[PyTorch] ๋ชจ๋ธ ๋งค๊ฐ๋ณ์ ์ต์ ํ(Optimization) ํ๊ธฐ (0) | 2024.07.30 |
[PyTorch] Torch.Autograd๋ฅผ ์ด์ฉํ ์๋ ๋ฏธ๋ถ (0) | 2024.07.30 |
[PyTorch] Neural Network Model (์ ๊ฒฝ๋ง ๋ชจ๋ธ) ๊ตฌ์ฑํ๊ธฐ (0) | 2024.07.26 |
[PyTorch] Transform (๋ณํ) (0) | 2024.07.26 |