A A
[DL] Transfer Learning - 전이 ν•™μŠ΅
Transfer Learning, 즉 전이 ν•™μŠ΅μ€ ML(λ¨Έμ‹  λŸ¬λ‹)κ³Ό DL(λ”₯λŸ¬λ‹)μ—μ„œ 기쑴의 Pre-Training 된 λͺ¨λΈμ„ μƒˆλ‘œμš΄ μž‘μ—…μ— μž¬μ‚¬μš©ν•˜λŠ” κΈ°λ²•μž…λ‹ˆλ‹€. 
  • 이 방법은 특히 λŒ€κ·œλͺ¨ λ°μ΄ν„°μ…‹μ—μ„œ ν•™μŠ΅λœ λͺ¨λΈμ„ μž‘μ€ 데이터셋에 μ μš©ν•  λ•Œ μœ μš©ν•©λ‹ˆλ‹€.
  • 전이 ν•™μŠ΅μ€ λͺ¨λΈμ΄ 사전 ν•™μŠ΅ν•œ 지식을 μƒˆλ‘œμš΄ λ¬Έμ œμ— μ μš©ν•˜μ—¬ ν•™μŠ΅ 속도λ₯Ό 높이고 μ„±λŠ₯을 ν–₯μƒμ‹œν‚¬ 수 μžˆμŠ΅λ‹ˆλ‹€.

Transfer Learning (전이 ν•™μŠ΅)

  • 기쑴의 Neural Network(신경망)μ—μ„œ μ΅œμƒμœ„ 뢀뢄을 μƒˆλ‘œ μ •μ˜ν•œ λ‹€μŒ, 이 뢀뢄을 Training μ‹œν‚€λŠ” 것이 Transfer Learning (전이 ν•™μŠ΅) 이라고 ν•©λ‹ˆλ‹€.
  • μ΄λ•Œ Neural Network(신경망)의 ν•˜μœ„ 뢀뢄은 이미 Training된 Neural Network(신경망)을 μ‚¬μš©ν•˜λ―€λ‘œ 적은 μ–‘μ˜ λ°μ΄ν„°λ‘œ μ‹œμŠ€ν…œμ„ Training μ‹œν‚¬ 수 μžˆμŠ΅λ‹ˆλ‹€.

Transfer Learning (전이 ν•™μŠ΅)의 μ£Όμš” κ°œλ…

  • 사전 ν•™μŠ΅λœ λͺ¨λΈ
    • λŒ€κ·œλͺ¨ 데이터셋, 예λ₯Ό λ“€μ–΄ ImageNetκ³Ό 같은 Training된 λͺ¨λΈμ„ μ‚¬μš©ν•©λ‹ˆλ‹€.
    • μ΄λŸ¬ν•œ λͺ¨λΈμ€ λ‹€μ–‘ν•œ κ°μ²΄λ‚˜ νŒ¨ν„΄μ„ μΈμ‹ν•˜λŠ” 데 맀우 λŠ₯μˆ™ν•©λ‹ˆλ‹€.
  • λ―Έμ„Έ μ‘°μ •(Fine-Tuning)
    • 사전 ν•™μŠ΅λœ λͺ¨λΈμ˜ 일뢀 Layerλ₯Ό κ³ μ •(Freeze)ν•˜κ³ , λ‚˜λ¨Έμ§€ Layerλ₯Ό μƒˆλ‘œμš΄ λ°μ΄ν„°μ…‹μœΌλ‘œ ν•™μŠ΅ν•©λ‹ˆλ‹€.
    • κ³ μ •λœ λ ˆμ΄μ–΄λŠ” 이전 μž‘μ—…μ—μ„œ Trainingν•œ νŠΉμ§•μ„ μœ μ§€ν•©λ‹ˆλ‹€.
    • μƒˆλ‘œμš΄ λ°μ΄ν„°μ…‹μ—μ„œ ν•™μŠ΅λœ LayerλŠ” μƒˆλ‘œμš΄ μž‘μ—…μ— 맞게 μ‘°μ •λ©λ‹ˆλ‹€.
  • κ³ μ •λœ νŠΉμ§• μΆ”μΆœκΈ°(Feature Extractor)
    • Pre-Train된 λͺ¨λΈμ˜ λͺ¨λ“  Layerλ₯Ό κ³ μ •ν•˜κ³ , λ§ˆμ§€λ§‰μ— μƒˆλ‘œμš΄ λΆ„λ₯˜κΈ°λ₯Ό μΆ”κ°€ν•˜μ—¬ Trainingν•©λ‹ˆλ‹€.
    • Pre-Train된 λͺ¨λΈμ€ μƒˆλ‘œμš΄ Dataset의 Feature(νŠΉμ§•)을 μΆ”μΆœν•˜κ³ , μƒˆλ‘œμš΄ λΆ„λ₯˜κΈ°λŠ” μ΄λŸ¬ν•œ νŠΉμ§•μ„ 기반으둜 Predict ν•©λ‹ˆλ‹€.

Transfer Learning (전이 ν•™μŠ΅)의 μž₯점

  • ν•™μŠ΅ μ‹œκ°„ 단좕
    • 초기 Weightλ₯Ό 잘 ν•™μŠ΅λœ λͺ¨λΈλ‘œλΆ€ν„° κ°€μ Έμ˜€κΈ° λ•Œλ¬Έμ— Training μ‹œκ°„μ΄ λ‹¨μΆ•λ©λ‹ˆλ‹€.
  • 더 적은 데이터 ν•„μš”
    • μž‘μ€ Datasetμ—μ„œλ„ 높은 μ„±λŠ₯을 달성할 수 μžˆμŠ΅λ‹ˆλ‹€.
    • λŒ€κ·œλͺ¨ Datasetμ—μ„œ Train된 λͺ¨λΈμ˜ 지식을 ν™œμš©ν•˜κΈ° λ•Œλ¬Έμ— 더 적은 μ–‘μ˜ Dataλ‘œλ„ 쒋은 μ„±λŠ₯을 λ‚Ό 수 μžˆμŠ΅λ‹ˆλ‹€.
  • μ„±λŠ₯ ν–₯상
    • κΈ°μ‘΄ λͺ¨λΈμ˜ Feature(νŠΉμ§•) μΆ”μΆœ λŠ₯λ ₯을 ν™œμš©ν•˜μ—¬ μƒˆλ‘œμš΄ μž‘μ—…μ—μ„œλ„ 높은 μ„±λŠ₯을 κΈ°λŒ€ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

CNN Transfer Learning (CNN μ „μ΄ν•™μŠ΅)

Keras λΌμ΄λΈŒλŸ¬λ¦¬μ—μ„œ μ •μ˜λ˜μ–΄ μžˆλŠ” CNN νŒ¨ν‚€μ§€λ“€μ„ μ‚¬μš©ν•  수 μžˆμŠ΅λ‹ˆλ‹€.


Transfer Learning Example (전이 ν•™μŠ΅ μ˜ˆμ‹œ by TF, Keras)

ν•œλ²ˆ Transfer Learning의 μ˜ˆμ‹œ μ½”λ“œλ₯Ό ν•œλ²ˆ λ΄λ³΄κ² μŠ΅λ‹ˆλ‹€.
  • μ΄μ½”λ“œλŠ” Tensorflow와 Kears 라이브러리λ₯Ό μ‚¬μš©ν•˜μ—¬ Transfer Learning을 μˆ˜ν–‰ν•˜λŠ” μ˜ˆμ‹œμž…λ‹ˆλ‹€.

Pre-Train된 λͺ¨λΈ λ‘œλ“œ 및 μ€€λΉ„

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# 사전 ν•™μŠ΅λœ VGG16 λͺ¨λΈ λ‘œλ“œ (ImageNet λ°μ΄ν„°μ…‹μ—μ„œ ν•™μŠ΅λ¨)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# 사전 ν•™μŠ΅λœ λͺ¨λΈμ˜ λ ˆμ΄μ–΄ κ³ μ • (Feature Extractor둜 μ‚¬μš©)
for layer in base_model.layers:
    layer.trainable = False

# μƒˆλ‘œμš΄ λΆ„λ₯˜κΈ° λ ˆμ΄μ–΄ μΆ”κ°€
x = base_model.output
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)

# 전체 λͺ¨λΈ μ •μ˜
model = Model(inputs=base_model.input, outputs=predictions)

# λͺ¨λΈ 컴파일
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

 

Data μ€€λΉ„

데이터셋을 μ€€λΉ„ν•˜κ³ , 이미지 데이터λ₯Ό 뢈러였고, μ „μ²˜λ¦¬ν•˜λŠ” 과정을 ν¬ν•¨ν•©λ‹ˆλ‹€.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 데이터셋 경둜
train_data_dir = '/path/to/train/data'
validation_data_dir = '/path/to/validation/data'

# 이미지 데이터 μ œλ„ˆλ ˆμ΄ν„° μ •μ˜ 및 데이터 μ „μ²˜λ¦¬
train_datagen = ImageDataGenerator(rescale=1.0/255.0, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
validation_datagen = ImageDataGenerator(rescale=1.0/255.0)

# ν•™μŠ΅ 데이터 μ œλ„ˆλ ˆμ΄ν„°
train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(224, 224), batch_size=32, class_mode='categorical')

# 검증 데이터 μ œλ„ˆλ ˆμ΄ν„°
validation_generator = validation_datagen.flow_from_directory(validation_data_dir, target_size=(224, 224), batch_size=32, class_mode='categorical')

 

λͺ¨λΈ ν•™μŠ΅

# λͺ¨λΈ ν•™μŠ΅
model.fit(train_generator, epochs=10, validation_data=validation_generator)
  • 사전 ν•™μŠ΅λœ VGG16 λͺ¨λΈμ„ μ‚¬μš©ν•˜μ—¬ μƒˆλ‘œμš΄ 데이터셋에 λŒ€ν•΄ 전이 ν•™μŠ΅μ„ μˆ˜ν–‰ν•˜λŠ” 방법을 λ³΄μ—¬μ€λ‹ˆλ‹€.