๋ฐ์ํ
Batch Normalization - ๋ฐฐ์น ์ ๊ทํ
Batch Normalization (๋ฐฐ์น ์ ๊ทํ)์ ๊ฐ๋ ์ 2015๋ ์ ์ ์๋ ๋ฐฉ๋ฒ์ ๋๋ค.
- ์ผ๋จ, Batch Normalization(๋ฐฐ์น ์ ๊ทํ)๊ฐ ์ฃผ๋ชฉ๋ฐ๋ ์ด์ ๋ ๋ค์์ ์ด์ ๋ค๊ณผ ๊ฐ์ต๋๋ค.
- Training(ํ์ต)์ ๋นจ๋ฆฌ ํ ์ ์์ต๋๋ค. ์ฆ, Training(ํ์ต) ์๋๋ฅผ ๊ฐ์ ํ๋ ํจ๊ณผ๊ฐ ์์ต๋๋ค.
- ์ด๊น๊ฐ์ ํฌ๊ฒ ์์กดํ์ง ์๋๋ค๋ ํน์ง์ด ์์ต๋๋ค.
- ๊ทธ๋ฆฌ๊ณ Overiftting์ ์ต์ ํ๋ ํน์ง์ด ์์ต๋๋ค. ์ฆ, Dropout๋ฑ์ ํ์์ฑ์ด ๊ฐ์ํฉ๋๋ค.
- Batch Normalization(๋ฐฐ์น ์ ๊ทํ)์ ๊ธฐ๋ณธ ์์ด๋์ด๋ ์์์ ๋งํ๋ฏ์ด ๊ฐ Layer(์ธต)์์์ Activation Value(ํ์ฑํ ๊ฐ)์ด ์ ๋นํ ๋ถํฌ๊ฐ ๋๋๋ก ์กฐ์ ํ๋ ๊ฒ์ ๋๋ค. ํ๋ฒ ์์๋ฅผ ๋ณด๊ฒ ์ต๋๋ค.
- Batch Normalization(๋ฐฐ์น ์ ๊ทํ)๋ ๊ทธ ์ด๋ฆ๊ณผ ๊ฐ์ด ํ์ต์ Mini-Batch๋ฅผ ๋จ์๋ก ์ ๊ทํ๋ฅผ ํฉ๋๋ค.
๋ฏธ๋ ๋ฐฐ์น(mini-batch)๋ ๋ฐ์ดํฐ์ ์ ์์ ํฌ๊ธฐ์ ์ผ๋ถ๋ก ๋๋์ด ๋คํธ์ํฌ๋ฅผ ํ์ต์ํค๋ ๋ฐฉ๋ฒ์ ๋๋ค.
์ ์ฒด ๋ฐ์ดํฐ์ ์ ํ ๋ฒ์ ๋ชจ๋ ์ฌ์ฉํ๋ ๊ฒ์ด ์๋๋ผ ๋ฐ์ดํฐ๋ฅผ ์์ ๋ฐฐ์น๋ก ๋๋์ด ๊ฐ ๋ฐฐ์น์ ๋ํด ์์ฐจ์ ์ผ๋ก ํ์ต์ ์งํํฉ๋๋ค.
- ๊ตฌ์ฒด์ ์ผ๋ก๋ Mean(ํ๊ท )์ด 0, Variance(๋ถ์ฐ)์ด 1์ด ๋๋๋ก ์ ๊ทํํฉ๋๋ค. ์์์ผ๋ก๋ ์๋์ ์๊ณผ ๊ฐ์ต๋๋ค.
- ์ ์์์ Mini-Batch B = ๊ฐ์ ์ ๋ ฅ ๋ฐ์ดํฐ์ ์งํฉ์ ๋ํค ํ๊ท μB์ ๋ถ์ฐ σBโ²์ ๊ตฌํฉ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ Mean(ํ๊ท )์ด 0, Variance(๋ถ์ฐ)์ด 1์ด ๋๊ฒ Normalization(์ ๊ทํ) ํ๊ณ , ε์ 0์ผ๋ก ๋๋๋ ์ฌํ๋ฅผ ์๋ฐฉํ๋ ์ญํ ์ ํฉ๋๋ค.
- ๋ Batch Normalization(๋ฐฐ์น ์ ๊ทํ) Layer ๋ง๋ค ์ด ์ ๊ทํ๋ ๋ฐ์ดํฐ์ ๊ณ ์ ํ ํ๋ scale์ ์ด๋ shift ๋ณํ์ ์ํํฉ๋๋ค.
- ์์์ผ๋ก๋ ์๋์ ๊ฐ์ต๋๋ค.
- γ : ํ๋, β : ์ด๋์ ๋ด๋นํฉ๋๋ค.
- ๋ ๊ฐ์ ์ฒ์์๋ γ=1, β=0 (1๋ฐฐ ํ๋, ์ด๋ ์์=์๋ณธ ๊ทธ๋๋ก)์์ ์์ํด์ ํ์ตํ๋ฉฐ ์ ํฉํ ๊ฐ์ผ๋ก ์กฐ์ ํด๊ฐ๋๋ค.
- ์ด๊ฒ์ด Batch Normalization(๋ฐฐ์น ์ ๊ทํ)์ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ์ค๋ช ํ ๋ด์ฉ์ ์๋์ ๊ทธ๋ํ๋ก ๋ํ๋ผ ์ ์์ต๋๋ค.
Batch Normalization (๋ฐฐ์น ์ ๊ทํ)์ ํจ๊ณผ
ํ๋ฒ Batch Normalization(๋ฐฐ์น ์ ๊ทํ) ๊ณ์ธต์ ์ฌ์ฉํ ์คํ์ ํ๋ฒ Mnist Dataset์ ์ฌ์ฉํ์ฌ Batch Normalization Layer๋ฅผ ์ฌ์ฉํ ๋, ์ฌ์ฉํ์ง ์์๋์ ํ์ต ์ง๋๊ฐ ์ด๋ป๊ฒ ๋ฌ๋ผ์ง๋์ง๋ฅผ ๋ณด๊ฒ ์ต๋๋ค.
# coding: utf-8
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
sys.path.append(os.pardir) # ๋ถ๋ชจ ๋๋ ํฐ๋ฆฌ์ ํ์ผ์ ๊ฐ์ ธ์ฌ ์ ์๋๋ก ์ค์
from dataset.mnist import load_mnist # MNIST ๋ฐ์ดํฐ์
์ ๋ถ๋ฌ์ค๋ ํจ์
from common.multi_layer_net_extend import MultiLayerNetExtend # ๋ค์ธต ์ ๊ฒฝ๋ง ๋ชจ๋ธ ํด๋์ค
from common.optimizer import SGD, Adam # ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ ํด๋์ค
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)
# ํ์ต ๋ฐ์ดํฐ๋ฅผ ์ค์
x_train = x_train[:1000]
t_train = t_train[:1000]
max_epochs = 20
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.01
# ๊ฐ์ค์น ์ด๊ธฐํ ํ์คํธ์ฐจ ์ค์ ๋ฐ ํ์ต ํจ์ ์ ์
def __train(weight_init_std):
# ๋ฐฐ์น ์ ๊ทํ๋ฅผ ์ ์ฉํ ์ ๊ฒฝ๋ง ๋ชจ๋ธ๊ณผ ์ ์ฉํ์ง ์์ ์ ๊ฒฝ๋ง ๋ชจ๋ธ ์์ฑ
bn_network = MultiLayerNetExtend(input_size=784,
hidden_size_list=[100, 100, 100, 100, 100],
output_size=10,
weight_init_std=weight_init_std,
use_batchnorm=True) # ๋ฐฐ์น ์ ๊ทํ ์ฌ์ฉ
network = MultiLayerNetExtend(input_size=784,
hidden_size_list=[100, 100, 100, 100, 100],
output_size=10,
weight_init_std=weight_init_std) # ๋ฐฐ์น ์ ๊ทํ ๋ฏธ์ฌ์ฉ
optimizer = SGD(lr=learning_rate) # ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ
train_acc_list = [] # ํ์ต ์ ํ๋ ๊ธฐ๋ก ๋ฆฌ์คํธ
bn_train_acc_list = [] # ๋ฐฐ์น ์ ๊ทํ ์ ์ฉ ํ์ต ์ ํ๋ ๊ธฐ๋ก ๋ฆฌ์คํธ
iter_per_epoch = max(train_size / batch_size, 1)
epoch_cnt = 0
for i in range(1000000000):
batch_mask = np.random.choice(train_size, batch_size) # ๋ฌด์์ ๋ฐฐ์น ์ํ๋ง
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
for _network in (bn_network, network):
grads = _network.gradient(x_batch, t_batch) # ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ
optimizer.update(_network.params, grads) # ๊ฐ์ค์น ์
๋ฐ์ดํธ
if i % iter_per_epoch == 0:
train_acc = network.accuracy(x_train, t_train) # ์ ํ๋ ๊ณ์ฐ
bn_train_acc = bn_network.accuracy(x_train, t_train) # ๋ฐฐ์น ์ ๊ทํ ์ ์ฉ ์ ํ๋ ๊ณ์ฐ
train_acc_list.append(train_acc)
bn_train_acc_list.append(bn_train_acc)
print("epoch:" + str(epoch_cnt) + " | " + str(train_acc) + " - "
+ str(bn_train_acc))
epoch_cnt += 1
if epoch_cnt >= max_epochs:
break
return train_acc_list, bn_train_acc_list
# ๊ทธ๋ํ ๊ทธ๋ฆฌ๊ธฐ==========
weight_scale_list = np.logspace(0, -4, num=16)
x = np.arange(max_epochs)
for i, w in enumerate(weight_scale_list):
print("============== " + str(i+1) + "/16" + " =============
- ๋ณด๋ฉด Batch Normalization (๋ฐฐ์น ์ ๊ทํ)๊ฐ ํ์ต์ ๋นจ๋ฆฌ ์ง์ ์ํค๊ณ ์์ต๋๋ค.
- ๊ทธ๋ฌ๋ฉด Weight ์ด๊น๊ฐ์ ํ์คํธ์ฐจ๋ฅผ ๋ค์ํ๊ฒ ๋ด๊ฟ๊ฐ๋ฉด์ ํ์ต ๊ฒฝ๊ณผ๋ฅผ ๊ด์ฐฐํ ๊ทธ๋ํ ์ ๋๋ค.
- ๊ฑฐ์ด ๋ชจ๋ ๊ฒฝ์ฐ์์ Batch Normalization(๋ฐฐ์น ์ ๊ทํ)๋ฅผ ์ฌ์ฉํ ๋์ Training(ํ์ต) ์๋๊ฐ ๋น ๋ฅธ ๊ฒ์ผ๋ก ๋ํ๋ฉ๋๋ค.
- ์ค์ ๋ก Batch Normalization(๋ฐฐ์น ์ ๊ทํ)๋ฅผ ์ด์ฉํ์ง ์๋ ๊ฒฝ์ฐ์๋ ์ด๊ฐ๊ฐ์ด ์ ๋ถํฌ๋์ด ์์ง ์์ผ๋ฉด Training(ํ์ต)์ด ์ ํ ์งํ๋์ง ์์ ๋ชจ์ต๋ ํ์ธํ ์ ์์ต๋๋ค.
Summary: Batch Normalization(๋ฐฐ์น ์ ๊ทํ)๋ฅผ ์ฌ์ฉํ๋ฉด Training(ํ์ต)์ด ๋นจ๋ฆฌ์๋ฉด, Weight(๊ฐ์ค์น) ์ด๊น๊ฐ์ ํฌ๊ฒ ์์กด ํ์ง ์์๋ ๋๋ค๋ ํน์ง์ด ์์ต๋๋ค.
๋ฐ์ํ