⚠️ This post contains my study notes based on the official PyTorch Korea documentation, so please keep that in mind!
Saving and Loading Models
This time, we'll look at how to persist a model's state by saving and loading it, and how to run predictions with the loaded model.
import torch
import torchvision.models as models
Saving and Loading Model Weights
- PyTorch models store their learned parameters in an internal state dictionary called state_dict.
- These state values can be persisted with the torch.save method.
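# Build VGG16 with ImageNet pre-trained weights (the checkpoint is downloaded on first use)
model = models.vgg16(weights='IMAGENET1K_V1')
# Save only the learned parameters (the state_dict), not the model class itself
torch.save(model.state_dict(), 'model_weights.pth')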
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:07<00:00, 69.5MB/s]
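To see what `state_dict()` actually holds without downloading VGG16, here is a minimal sketch using a small hypothetical `TinyNet` model (an illustration only, not part of the tutorial) and a temporary file path. The state dictionary is an ordered mapping from parameter names to tensors, and `torch.save` writes that mapping to disk.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-layer model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet()
state = model.state_dict()

# Each entry maps a parameter name to its tensor
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# fc1.weight (8, 4)
# fc1.bias (8,)
# fc2.weight (2, 8)
# fc2.bias (2,)

# Persist only the learned parameters (not the class) to a file
path = os.path.join(tempfile.mkdtemp(), "tiny_weights.pth")
torch.save(state, path)
```

The parameter names come from the attribute names in `__init__`, which is why the loading side must build a model with the same attribute layout.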
- To load model weights, first create an instance of the same model, then load the parameters into it with the load_state_dict() method.
model = models.vgg16()  # No ``weights`` is specified here, so this creates an untrained model.
# weights_only=True restricts unpickling to tensors and basic containers, which is all a state_dict needs
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
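The load procedure above (create a fresh instance, then call `load_state_dict()`) can be checked end to end with a small stand-in model. This sketch assumes a hypothetical `TinyNet` class in place of VGG16 and a temporary file path; it confirms that the reloaded model reproduces the original outputs. It also shows that `load_state_dict()` is strict by default: a hypothetical `OtherNet` whose parameter names differ causes a `RuntimeError`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical small model standing in for VGG16."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


class OtherNet(nn.Module):
    """Hypothetical model whose parameter names differ from TinyNet's."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 2)


torch.manual_seed(0)
trained = TinyNet().eval()
path = os.path.join(tempfile.mkdtemp(), "tiny_weights.pth")
torch.save(trained.state_dict(), path)

# 1) Create a new, randomly initialized instance of the same class
restored = TinyNet()
# 2) Copy the saved tensors into it
restored.load_state_dict(torch.load(path, weights_only=True))
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))
print(same_output)  # True

# Strict loading: missing/unexpected keys raise a RuntimeError
try:
    OtherNet().load_state_dict(torch.load(path, weights_only=True))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```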
- What if you want to load the model with pre-trained weights directly? Write it as in the code below.
# Load the model with pre-trained weights
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
# Set the model to evaluation mode
model.eval()
Note
Before running inference, call model.eval() to put dropout and batch normalization layers into evaluation mode. Otherwise dropout keeps randomly zeroing activations and batch normalization keeps using per-batch statistics, which produces inconsistent inference results.
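The effect described in the note can be observed with a hypothetical toy model containing `nn.BatchNorm1d` and `nn.Dropout`. In training mode, two forward passes on the same input typically differ because dropout draws a new random mask each time; after `eval()` the outputs are identical. Note that `eval()` only switches layer behaviour and does not turn off gradient tracking, which is why this sketch also wraps inference in `torch.no_grad()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model with batch normalization and dropout
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Dropout(p=0.5))
x = torch.randn(4, 8)

# Training mode: dropout draws a new random mask on every call,
# and BatchNorm normalizes with the statistics of the current batch
model.train()
with torch.no_grad():
    train_a, train_b = model(x), model(x)
print(torch.equal(train_a, train_b))  # almost certainly False

# Evaluation mode: dropout is disabled, BatchNorm uses its stored running statistics
model.eval()
with torch.no_grad():
    eval_a, eval_b = model(x), model(x)
print(torch.equal(eval_a, eval_b))  # True
print(model.training)  # False
```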
Saving & Loading Models Together with Their Structure
When loading model weights, we first had to instantiate the model class, because the class defines the network's structure.
- To save the structure of the class together with the model, pass model itself (not model.state_dict()) to the save function.
torch.save(model, 'model.pth')
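- The model can then be loaded as follows.
# A fully pickled model needs weights_only=False (torch.load defaults to True since PyTorch 2.6); only load files you trust
model = torch.load('model.pth', weights_only=False)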
model
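A minimal sketch of saving and loading a whole model, assuming a hypothetical `TinyNet` class and a temporary file path. Starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which refuses arbitrary pickled classes, so loading a full model requires `weights_only=False`. Because unpickling can execute arbitrary code, this should only be done with files from a trusted source.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model; its definition must be available when loading."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny_model.pth")

# Pass the module itself (not model.state_dict()) so the whole object is pickled
torch.save(model, path)

# Full-model pickles need weights_only=False; use only with trusted files
loaded = torch.load(path, weights_only=False)
loaded.eval()

print(type(loaded).__name__)  # TinyNet
print(torch.equal(loaded.fc.weight, model.fc.weight))  # True
```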
Note
This approach uses the Python pickle module to serialize the model, so it relies on the actual class definition being available when the model is loaded.
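Pickle stores only a reference to the class (its module and name), not the class code, so loading fails when that definition cannot be found. This minimal sketch assumes a hypothetical `TinyNet` class and an in-memory buffer, and deletes the class before loading to imitate a script that never defined or imported it.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model class used to show the pickle dependency."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


# Save the whole model into an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.save(TinyNet(), buffer)

# Simulate a loading environment where the class definition is missing
del TinyNet

buffer.seek(0)
try:
    torch.load(buffer, weights_only=False)
    load_failed = False
except AttributeError as err:
    # pickle cannot resolve the stored class reference
    load_failed = True
    print("load failed:", err)
```

Saving the state_dict instead avoids storing any class reference in the file, although the class is still needed to rebuild the model before calling load_state_dict().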