Open Image Dataset์ Object Detection ํ์ต ๋ฐ Inference
Open Image Dataset์์ Football ๊ด๋ จ Object, Fish๊ด๋ จ Object๋ฅผ ์ถ์ถ ํ ํ์ต ๋ฐ์ดํฐ ์ธํธ ์์ฑํ, ์ด๋ฅผ ์ด์ฉํ์ฌ Object Detection์ ์ํํด ๋ณด๊ฒ ์ต๋๋ค.
๋ฐ์ดํฐ ๋ค์ด๋ก๋ ๋งํฌ
GitHub - chulminkw/DLCV
Contribute to chulminkw/DLCV development by creating an account on GitHub.
github.com
์ด๋ฒ์ Keras-yolo ํจํค์ง, OpenImage Dataset์ ์ฌ์ฉํ์ต๋๋ค.
# ํ์ฌ ๋๋ ํ ๋ฆฌ๋ /content์ด๋ฉฐ ์ด ๋๋ ํ ๋ฆฌ๋ฅผ ๊ธฐ์ค์ผ๋ก ์ค์ต์ฝ๋์ ๋ฐ์ดํฐ๋ฅผ ๋ค์ด๋ก๋ ํฉ๋๋ค.
!pwd
!rm -rf DLCV
!git clone https://github.com/chulminkw/DLCV.git
# DLCV ๋๋ ํ ๋ฆฌ๊ฐ Download๋๊ณ DLCV ๋ฐ์ Detection๊ณผ Segmentation ๋๋ ํ ๋ฆฌ๊ฐ ์๋ ๊ฒ์ ํ์ธ
!ls -lia
!ls -lia DLCV
# Colab์์ GPU ์ปค๋ ์ ์ฉ์ tensorflow 1.13์ผ๋ก downgrade๊ฐ ๋์ง ์์ต๋๋ค.
# ๋๋ฌธ์ colab์์๋ Segmentation ํ์ต ์ tensorflow 1.15, keras 2.3 ์ ์ค์นํ๊ฒ ์ต๋๋ค.
# tensorflow 1.15์ ์ค์นํฉ๋๋ค. ์๋์ผ๋ก tensorflow 2.2๊ฐ 1.15๋ก downgrade ๋ฉ๋๋ค.
!pip install tensorflow-gpu==1.15.2
# keras 2.3๋ฅผ ์ค์นํฉ๋๋ค.
!pip install keras==2.3.0
์ฃผ์์ฌํญ
Keras-yolo3 ๋ Custom data ๋ฅผ train์ ์ค๋ฅ๊ฐ ๋ฐ์ํ๋๋ฑ tensorflow 1.15 ์ ์๋ฒฝํ๊ฒ ํธํํ์ง ์์ต๋๋ค.
๋๋ฌธ์ ์๋์ ๊ฐ์ด DLCV github์์ ์์ ๋ init.py ๋ฅผ ๋ค์ด๋ก๋ ๋ฐ์์ keras์ backend์ init.py ๋ฅผ ์์ ํด์ผ ํฉ๋๋ค.
init.py๋ ๋ฐ๋์ import tensorflow, import keras ์ด์ ์ ์ํ๋์ด์ผ ํฉ๋๋ค. ๋ง์ผ tensorflow, keras ์ค์นํ ๋ค ์๋์ import tensorflow, import keras๋ฅผ ๋จผ์ ์ํํ์์ผ๋ฉด ๋ค์ ์ฌ์์์ ํ ํ init.py๋ฅผ ์์ ํฉ๋๋ค.
import os
# keras backend ๋๋ ํ ๋ฆฌ ์ด๋.
os.chdir('/usr/local/lib/python3.6/dist-packages/keras/backend')
!rm -rf __init__.py
!rm -rf __pycache__
# ๊ธฐ์กด __init__.py ์ญ์ ํ๊ณ ์๋ก์ด __init__.py๋ฅผ download
!wget https://raw.githubusercontent.com/chulminkw/DLCV/master/colab_tf115_modify_files/__init__.py
# tensorflow๋ 1.15, keras๋ 2.3 ๋ฒ์ ํ์ธ
# GPU๊ฐ ์ธํ
๋์ด ์์ง ์์ผ๋ฉด ์๋จ ๋ฉ๋ด์์ ๋ฐํ์->๋ฐํ์ ์ ํ ๋ณ๊ฒฝ์์ GPU๋ฅผ ์ ํํ ํ ๋ฐํ์ ๋ค์ ์์์ ์ ํํ๊ณ ์ฒ์ ๋ถํฐ์ธ tensorflow, keras ์ค์น ๋ถํฐ ๋ค์ ์์.
import tensorflow as tf
import keras
print(tf.__version__)
print(keras.__version__)
# gpu๊ฐ ์ธํ
๋์ด ์๋์ง ํ์ธ.
tf.test.gpu_device_name()
Tip: ์ฝ๋ฉ ๋ฒ์ ์ ์๋๋ฅผ ์ด์ฉํ์ฌ keras-yolo3 ํจํค์ง๋ฅผ download ํ์ฌ /content/DLCV/Detection/yolo ๋ฐ์ ์ค์นํฉ๋๋ค.
%cd /content/DLCV/Detection/yolo
!git clone https://github.com/qqwweee/keras-yolo3.git
!ls -lia /content/DLCV/Detection/yolo/keras-yolo3
pretrained ๋ชจ๋ธ ์ฌ์์ฑ ๋ฐ font ๋๋ ํ ๋ฆฌ ๊ต์ฒด (Colab Version)
- ์ฝ๋ฉ์ผ๋ก ๋๋ฆฌ๋ ๊ฒฝ์ฐ์๋ ๋ค์ model_data ๋ฐ์ coco dataset๋ก pretrained ๋ yolov3.weights ํ์ผ์ yolo.h5 ํ์ผ๋ก ๋ณ๊ฒฝํด ์ค์ผ ํฉ๋๋ค.
- keras-yolo3์ font ๋๋ ํ ๋ฆฌ๋ ์ฌ ๊ต์ฒด ํฉ๋๋ค.
# ์ฝ๋ฉ ๋ฒ์ ์ ์๋๋ฅผ ์ด์ฉํ์ฌ yolov3.weights ํ์ผ์ download ๋ฐ๊ณ , convert.py ๋ฅผ ์ํํ์ฌ model_data ๋ฐ์ yolo.h5 ํ์ผ ์์ฑ ์ํ.
%cd /content/DLCV/Detection/yolo/keras-yolo3
# yolo ๊ณต์ ์ฌ์ดํธ์์ download์ download ์๋๊ฐ ์ฝ 25๋ถ ์ ๋ ์์๋จ. github์์ ๋ค์ด๋ก๋ ์๋ง.
# !wget https://pjreddie.com/media/files/yolov3.weights
# github์์ ๋ค์ด๋ก๋
!wget https://github.com/chulminkw/DLCV/releases/download/1.0/yolov3.weights
# yolov3.weights๋ฅผ keras-yolo3์์ ์ฌ์ฉํ ์ ์๋๋ก yolo.h5 ๋ก ๋ณํ
!python convert.py yolov3.cfg yolov3.weights model_data/yolo.h5
# model_data ๋ฐ์ yolo.h5 ํ์ผ์ด ์์ฑ๋์๋์ง ํ์ธ.
!ls /content/DLCV/Detection/yolo/keras-yolo3/model_data
# yolo.detect_image() ๋ฉ์๋๋ PIL package๋ฅผ ์ด์ฉํ์ฌ image ์์
์ํ. keras-yolo3/font ๋๋ ํ ๋ฆฌ๋ฅผ ์์ ๋๋ ํ ๋ฆฌ๋ก ๋ณต์ฌ ํด์ผํจ.
%cd /content/DLCV/Detection/yolo
!cp -rf keras-yolo3/font ./font
OID Toolkit์ ํตํด ๋ง๋ค์ด์ง ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ค์ด๋ก๋ ํฉ๋๋ค.
- ๋ง๋ค์ด์ง ballnfish.tar ํ์ผ์ ๋ค์ด๋ก๋ ํ ๋ค ์ด๋ฅผ ์ด์ฉํ์ฌ ํ์ต ํฉ๋๋ค.
# ballnfish.tar ํ์ผ์ ๊ฐ์์ github์์ download ํ ๋ค ์์ถ์ ํ๋ฉด
# /content/DLCV/data/ballnfish ๋๋ ํ ๋ฆฌ๋ก annotations, images ๋๋ ํ ๋ฆฌ์ ๊ด๋ จ ํ์ผ๋ค์ด ์์ฑ ๋ฉ๋๋ค.
%cd /content/DLCV/data
!wget https://github.com/chulminkw/DLCV/releases/download/1.0/ballnfish.tar
!tar -xvf ballnfish.tar > /dev/null 2>&1
import os
from pathlib import Path
# ์ฝ๋ฉ ํ๊ฒฝ์์๋ ํ ๋๋ ํ ๋ฆฌ๋ฅผ '/content'๋ก ์ค์
HOME_DIR = '/content'
# ์ฃผ์๊ณผ ์ด๋ฏธ์ง ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก ์ค์
ANNO_DIR = os.path.join(HOME_DIR, 'DLCV/data/ballnfish/annotations')
IMAGE_DIR = os.path.join(HOME_DIR, 'DLCV/data/ballnfish/images')
print(ANNO_DIR)
# ์ฃผ์ ๋๋ ํ ๋ฆฌ์ ํ์ผ ๋ชฉ๋ก๊ณผ ๊ฐ์ ์ถ๋ ฅ
files = os.listdir(ANNO_DIR)
print('ํ์ผ ๊ฐ์๋:', len(files))
print(files)
- Open Image Dataset์ ์ฃผ์๊ณผ ์ด๋ฏธ์ง ๋๋ ํ ๋ฆฌ๋ฅผ ์ค์ ํ๊ณ , ์ฃผ์ ๋๋ ํ ๋ฆฌ ๋ด ํ์ผ์ ๊ฐ์์ ์ด๋ฆ์ ์ถ๋ ฅํ์ฌ ๋ฐ์ดํฐ ์ค๋น ์ํ๋ฅผ ํ์ธํฉ๋๋ค.
# ์๋๋ ์์์ ์ถ๋ ฅ๋ ์ ๋นํ xml ํ์ผ๋ก ๋ณ๊ฒฝ๋์ด์ผ ํฉ๋๋ค.
!cat /content/DLCV/data/ballnfish/annotations/13825baad5531265.xml
- ์ฃผ์ด์ง ๊ฒฝ๋ก์ XML ํ์ผ๋ค์ ์ฝ์ด ๊ฐ ๊ฐ์ฒด์ ๋ฐ์ด๋ฉ ๋ฐ์ค์ ํด๋์ค ์ ๋ณด๋ฅผ CSV ํ์ผ๋ก ๋ณํํฉ๋๋ค.
- ํด๋์ค ์ด๋ฆ์ ID๋ก ๋งคํํ๊ณ , ์ด๋ฏธ์ง ๊ฒฝ๋ก์ ๊ฐ์ฒด ์ ๋ณด๋ฅผ ํ ์ค์ ๊ธฐ๋กํ์ฌ ๊ฐ์ฒด ๊ฒ์ถ ๋ชจ๋ธ ํ์ต์ ํ์ํ ๋ฐ์ดํฐ๋ฅผ ์ค๋นํฉ๋๋ค.
import glob
import xml.etree.ElementTree as ET
# ํด๋์ค ์ด๋ฆ๊ณผ ID ๋งคํ
classes_map = {'Football':0, 'Football_helmet':1, 'Fish':2, 'Shark':3, 'Shellfish':4 }
def xml_to_csv(path, output_filename):
xml_list = []
# ์ถ๋ ฅ CSV ํ์ผ ์ด๊ธฐ
with open(output_filename, "w") as train_csv_file:
# ์ง์ ๋ ๊ฒฝ๋ก์ ๋ชจ๋ XML ํ์ผ ์ํ
for xml_file in glob.glob(path + '/*.xml'):
# XML ํ์ผ ํ์ฑ
tree = ET.parse(xml_file)
root = tree.getroot()
print('xml file:', xml_file)
# ์ด๋ฏธ์ง ํ์ผ ๊ฒฝ๋ก ๊ฐ์ ธ์ค๊ธฐ
full_image_name = os.path.join(IMAGE_DIR, root.find('filename').text)
value_str_list = ' '
# ๋ชจ๋ object ์์ ์ํ
for obj in root.findall('object'):
xmlbox = obj.find('bndbox')
class_name = obj.find('name').text
x1 = int(xmlbox.find('xmin').text)
y1 = int(xmlbox.find('ymin').text)
x2 = int(xmlbox.find('xmax').text)
y2 = int(xmlbox.find('ymax').text)
# ํด๋์ค ID ๋งคํ
class_id = classes_map[class_name]
# ๊ฐ์ฒด ์ ๋ณด ๋ฌธ์์ด ์์ฑ
value_str = ('{0},{1},{2},{3},{4}').format(x1, y1, x2, y2, class_id)
# ๊ฐ์ฒด ์ ๋ณด ๋์
value_str_list += value_str + ' '
# ์ด๋ฏธ์ง ๊ฒฝ๋ก์ ๊ฐ์ฒด ์ ๋ณด CSV์ ์์ฑ
train_csv_file.write(full_image_name + ' ' + value_str_list + '\n')
# XML ํ์ผ ์ฒ๋ฆฌ ์ข
๋ฃ
xml_to_csv(ANNO_DIR, os.path.join(ANNO_DIR,'ballnfish_anno.csv'))
print(os.path.join(ANNO_DIR,'ballnfish_anno.csv'))
!cat /content/DLCV/data/ballnfish/annotations/ballnfish_anno.csv
์ฌ๊ธฐ์ ๋ณด๋ฉด ์ฌ์ง์ซ์.jpg ๋ค์ ์๋ ์ซ์๋ค์ ์๋ฏธ๋, format(x1, y1, x2, y2, class_id) - ์ข์๋จ, ์ฐํ๋จ ์ขํ
Class id ('Football':0, 'Football_helmet':1, 'Fish':2, 'Shark':3, 'Shellfish':4)๋ฅผ ์๋ฏธํฉ๋๋ค.
๋ฐ์ดํฐ์ ์ฌ์ง ๋ณด๊ธฐ
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
default_dir = '/content'
# 9c27811a78b74a48.jpg ๋ images ๋๋ ํ ๋ฆฌ์ ์๋ ์์ ํ์ผ๋ก ๋ฐ๋์ด์ผ ํฉ๋๋ค.
plt.imshow(cv2.cvtColor(cv2.imread(os.path.join(default_dir, 'DLCV/data/ballnfish/images/9c27811a78b74a48.jpg')), cv2.COLOR_BGR2RGB))
- ์คํ ์์
์คํ ๋ชจ๋ ๋ก๋ฉ
import numpy as np
import keras.backend as K
from keras.layers import Input, Lambda
from keras.models import Model
from keras.optimizers import Adam
from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
import sys, os
import os
# ์ฝ๋ฉ ์ผ๋ก ๋๋ฆฌ๋ ๊ฒฝ์ฐ์์ ์๋์ ๊ฐ์ด ์ ๋ ๊ฒฝ๋ก๋ฅผ ์ง์ ํ์ฌ Local Package ์ง์ .
default_dir = '/content/DLCV'
default_yolo_dir = os.path.join(default_dir, 'Detection/yolo')
LOCAL_PACKAGE_DIR = os.path.abspath(os.path.join(default_yolo_dir,'keras-yolo3'))
print(LOCAL_PACKAGE_DIR)
sys.path.append(LOCAL_PACKAGE_DIR)
from yolo3.model import preprocess_true_boxes, yolo_body, tiny_yolo_body, yolo_loss
from yolo3.utils import get_random_data
- ๋ํ colab ๋ฒ์ ์ ์๋ ๋ช ๋ น์ด๋ก ballnfish_classes.txt ๋ฅผ ์์
import os
# ์ฝ๋ฉ ํ๊ฒฝ์์ BASE_DIR๊ณผ classes_path ์ค์
BASE_DIR = os.path.join(HOME_DIR, 'DLCV/Detection/yolo/keras-yolo3')
classes_path = os.path.join(BASE_DIR, 'model_data/ballnfish_classes.txt')
# ํด๋์ค ์ด๋ฆ์ ํ์ผ์ ์์ฑ
with open(classes_path, "w") as f:
f.write("Football\n")
f.write("Football_Helmet\n")
f.write("Fish\n")
f.write("Shark\n")
f.write("Shell_Fish\n")
# ์์ฑ๋ ํด๋์ค ํ์ผ ๋ด์ฉ ํ์ธ
!cat /content/DLCV/Detection/yolo/keras-yolo3/model_data/ballnfish_classes.txt
from train import get_classes, get_anchors
from train import create_model, data_generator, data_generator_wrapper
# BASE_DIR ์ค์ : YOLO ๋ชจ๋ธ ๊ด๋ จ ํ์ผ๋ค์ด ์๋ ๋๋ ํ ๋ฆฌ
BASE_DIR = os.path.join(HOME_DIR, 'DLCV/Detection/yolo/keras-yolo3')
# ํ์ต์ ์ํ ๊ฒฝ๋ก ์ค์
annotation_path = os.path.join(ANNO_DIR, 'ballnfish_anno.csv') # ์ฃผ์ ํ์ผ ๊ฒฝ๋ก
log_dir = os.path.join(BASE_DIR, 'snapshots/ballnfish/') # ํ์ต๋ ๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก
classes_path = os.path.join(BASE_DIR, 'model_data/ballnfish_classes.txt') # ํด๋์ค ํ์ผ ๊ฒฝ๋ก
anchors_path = os.path.join(BASE_DIR, 'model_data/yolo_anchors.txt') # ์ต์ปค ํ์ผ ๊ฒฝ๋ก
# ํด๋์ค ์ด๋ฆ๊ณผ ์ ๋ถ๋ฌ์ค๊ธฐ
class_names = get_classes(classes_path)
num_classes = len(class_names)
# ์ต์ปค ๋ฐ์ค ๋ถ๋ฌ์ค๊ธฐ
anchors = get_anchors(anchors_path)
# ํด๋์ค ์ด๋ฆ, ํด๋์ค ์ ๋ฐ ์ต์ปค ๋ฐ์ค ์ถ๋ ฅ
print(class_names, num_classes)
print(anchors)
- ์ถ๋ ฅ๋ ๊ฒฐ๊ณผ๊ฐ (์๋จ์ ํด๋์ค ๋ช & ๊ฐ์), ํ๋จ์ ์ซ์๋ค์ Anchor Box๋ค์ ์ ๋ณด๋ฅผ ์๋ฏธ.
yolo ๋ชจ๋ธ ํ์ต์ ์ํ ์ ๋ฐ์ ์ธ ํ๋ผ๋ฏธํฐ๋ฅผ config ํด๋์ค๋ก ์ค์ ํ๊ณ ํ์์ ์ด๋ฅผ ์์ ํ์ฌ ํ์ต.
์ฆ, ์ด๋ง์ config๋ฅผ ํด๋์คํ ํ๋ค๋ ์๋ฏธ์ ๋๋ค.
# csv annotation ํ์ผ์ ์ฝ์ด์ lines ๋ฆฌ์คํธ๋ก ๋ง๋ฌ.
with open(annotation_path) as f:
lines = f.readlines()
class config:
#tiny yolo๋ก ๋ชจ๋ธ๋ก ์ด๊ธฐ weight ํ์ต ์ํ ์ ์๋๋ฅผ tiny-yolo.h5๋ก ์์ .
initial_weights_path=os.path.join(BASE_DIR, 'model_data/yolo.h5' )
# input_shape๋ ๊ณ ์ .
input_shape=(416, 416)
# epochs๋ freeze, unfreeze 2 step์ ๋ฐ๋ผ ์ค์ .
first_epochs=50
first_initial_epochs=0
second_epochs=100
second_initial_epochs=50
# ํ์ต์ batch size, train,valid๊ฑด์, epoch steps ํ์
batch_size = 4
val_split = 0.1
num_val = int(len(lines)*val_split)
num_train = len(lines) - num_val
train_epoch_steps = num_train//batch_size
val_epoch_steps = num_val//batch_size
anchors = get_anchors(anchors_path)
class_names = get_classes(classes_path)
num_classes = len(class_names)
# epoch์ ์ ์ฅ๋ weight ํ์ผ ๋๋ ํ ๋ฆฌ
log_dir = os.path.join(BASE_DIR, 'snapshots/ballnfish/')
print('Class name:', config.class_names,'\nNum classes:', config.num_classes)
csv ํ์ผ์ ์ ๋ ฅ ๋ฐ์์ train ๋ฐ์ดํฐ์ valid ๋ฐ์ดํฐ ์ฒ๋ฆฌ๋ฅผ ์ํ data_generator_wrapper๊ฐ์ฒด๋ฅผ ๊ฐ๊ฐ ์์ฑ.
- train์ฉ, valid ์ฉ data_generator_wrapper๋ Yolo ๋ชจ๋ธ์ fit_generator()ํ์ต์ ์ธ์๋ก ์ ๋ ฅ๋จ.
def create_generator(lines):
# ํ๋ จ ๋ฐ์ดํฐ ์์ฑ๊ธฐ ์์ฑ (๋ฐ์ดํฐ์ ์ฒ์ config.num_train ๊ฐ ์ฌ์ฉ)
train_data_generator = data_generator_wrapper(
lines[:config.num_train],
config.batch_size,
config.input_shape,
config.anchors,
config.num_classes
)
# ๊ฒ์ฆ ๋ฐ์ดํฐ ์์ฑ๊ธฐ ์์ฑ (๋๋จธ์ง ๋ฐ์ดํฐ ์ฌ์ฉ)
valid_data_generator = data_generator_wrapper(
lines[config.num_train:],
config.batch_size,
config.input_shape,
config.anchors,
config.num_classes
)
# ํ๋ จ ๋ฐ ๊ฒ์ฆ ๋ฐ์ดํฐ ์์ฑ๊ธฐ ๋ฐํ
return train_data_generator, valid_data_generator
- YOLO ๋ชจ๋ธ ๋๋ tiny yolo ๋ชจ๋ธ ๋ฐํ. ์ด๊ธฐ weight๊ฐ์ pretrained๋ yolo weight๊ฐ์ผ๋ก ํ ๋น.
# anchor ๊ฐ์์ ๋ฐ๋ผ Tiny YOLO ๋ชจ๋ธ ๋๋ ์ผ๋ฐ YOLO ๋ชจ๋ธ์ ์์ฑํ๋ ํจ์
def create_yolo_model():
# ์ต์ปค ๊ฐ์๊ฐ 6๊ฐ์ด๋ฉด Tiny YOLO ๋ฒ์ ์ผ๋ก ๊ฐ์ฃผ
is_tiny_version = len(config.anchors) == 6
if is_tiny_version:
# Tiny YOLO ๋ชจ๋ธ ์์ฑ
model = create_tiny_model(
config.input_shape,
config.anchors,
config.num_classes,
freeze_body=2,
weights_path=config.initial_weights_path
)
else:
# ์ผ๋ฐ YOLO ๋ชจ๋ธ ์์ฑ
model = create_model(
config.input_shape,
config.anchors,
config.num_classes,
freeze_body=2,
weights_path=config.initial_weights_path
)
return model
- ์ต์ปค์ ๊ฐ์์ ๋ฐ๋ผ Tiny YOLO ๋ชจ๋ธ ๋๋ ์ผ๋ฐ YOLO ๋ชจ๋ธ์ ์์ฑํฉ๋๋ค.
- ์ต์ปค๊ฐ 6๊ฐ์ธ ๊ฒฝ์ฐ Tiny YOLO๋ฅผ ์ฌ์ฉํ๊ณ , ๊ทธ๋ ์ง ์์ผ๋ฉด ํ์ค YOLO ๋ชจ๋ธ์ ์์ฑํฉ๋๋ค.
# Tensorboard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping callback ๋ฐํ
def create_callbacks():
logging = TensorBoard(log_dir=config.log_dir)
checkpoint = ModelCheckpoint(config.log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
monitor='val_loss', save_weights_only=True, save_best_only=True, period=3)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1)
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)
#๊ฐ๋ณ callback๋ค์ ํ๊บผ๋ฒ์ list๋ก ๋ฌถ์ด์ ๋ฐํ
return [logging, checkpoint, reduce_lr, early_stopping]
ํ์ต ์ํ
# create_generator(), create_model(), create_callbacks() ์ํ.
train_data_generator, valid_data_generator = create_generator(lines)
ballnfish_model = create_yolo_model()
callbacks_list = create_callbacks()
# ์ต์ด ๋ชจ๋ธ์ ์ฃผ์ layer๊ฐ freeze๋์ด ์์. ์์ ์ ์ธ loss๋ฅผ ํ๋ณดํ๊ธฐ ์ํด ์ฃผ์ layer๋ฅผ freezeํ ์ํ๋ก ๋จผ์ ํ์ต.
print('First train ์์')
ballnfish_model.compile(optimizer=Adam(lr=1e-3), loss={'yolo_loss': lambda y_true, y_pred: y_pred})
# ์ฒซ ๋ฒ์งธ ๋จ๊ณ ํ์ต ์ํ
ballnfish_model.fit_generator(
train_data_generator,
steps_per_epoch=config.train_epoch_steps,
validation_data=valid_data_generator,
validation_steps=config.val_epoch_steps,
epochs=config.first_epochs,
initial_epoch=config.first_initial_epochs,
callbacks=callbacks_list
)
# 1๋จ๊ณ ํ์ต ์๋ฃ ๋ชจ๋ธ ์ ์ฅ
ballnfish_model.save_weights(os.path.join(log_dir, 'trained_weights_stage_1.h5'))
# ๋ชจ๋ layer๋ฅผ trainable=True๋ก ์ค์ ํ๊ณ ์ถ๊ฐ ํ์ต ์ํ
for layer in ballnfish_model.layers:
layer.trainable = True
print('Second train ์์')
ballnfish_model.compile(optimizer=Adam(lr=1e-4), loss={'yolo_loss': lambda y_true, y_pred: y_pred})
# ๋ ๋ฒ์งธ ๋จ๊ณ ํ์ต ์ํ
ballnfish_model.fit_generator(
train_data_generator,
steps_per_epoch=config.train_epoch_steps,
validation_data=valid_data_generator,
validation_steps=config.val_epoch_steps,
epochs=config.second_epochs,
initial_epoch=config.second_initial_epochs,
callbacks=callbacks_list
)
# ์ต์ข
ํ์ต ์๋ฃ ๋ชจ๋ธ ์ ์ฅ
ballnfish_model.save_weights(os.path.join(log_dir, 'trained_weights_final.h5'))
- ์ฒซ ๋ฒ์งธ ๋จ๊ณ์์๋ ์ฃผ์ ๋ ์ด์ด๋ฅผ ๊ณ ์ (freeze)ํ ์ํ๋ก ๋ชจ๋ธ์ ํ์ตํ์ฌ ์์ ์ ์ธ ์์ค์ ํ๋ณดํ๊ณ , ํ์ต์ด ์๋ฃ๋๋ฉด ๋ชจ๋ ๋ ์ด์ด๋ฅผ ํ์ต ๊ฐ๋ฅ(trainable) ์ํ์ผ๋ก ์ค์ ํ ํ ์ถ๊ฐ๋ก ํ์ต์ ์งํํฉ๋๋ค.
- ๊ฐ ๋จ๊ณ ํ์ ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ์ ์ฅํ์ฌ ๋์ค์ ์ฌ์ฉํ ์ ์๋๋ก ํฉ๋๋ค.
์ต์ข ํ์ต๋ ๋ชจ๋ธ์ ๋ก๋ฉํ์ฌ Object Detection ์ํ.
from yolo import YOLO
#keras-yolo์์ image์ฒ๋ฆฌ๋ฅผ ์ฃผ์ PIL๋ก ์ํ.
from PIL import Image
# ์ฝ๋ฉ ๋ฒ์ ์ ์๋์ ๊ฐ์ด ์ ๋ ๊ฒฝ๋ก๋ฅผ ์ง์ ํ์ฌ Local Package ์ง์ .
default_dir = '/content/DLCV'
default_yolo_dir = os.path.join(default_dir, 'Detection/yolo')
LOCAL_PACKAGE_DIR = os.path.abspath(os.path.join(default_yolo_dir,'keras-yolo3'))
sys.path.append(LOCAL_PACKAGE_DIR)
ballnfish_yolo = YOLO(model_path='/content/DLCV/Detection/yolo/keras-yolo3/snapshots/ballnfish/trained_weights_final.h5',
anchors_path='/content/DLCV/Detection/yolo/keras-yolo3/model_data/yolo_anchors.txt',
classes_path='/content/DLCV/Detection/yolo/keras-yolo3/model_data/ballnfish_classes.txt')
์ด๋ฏธ์ง Object Detection.
%cd /content/DLCV/data/ballnfish
!ls annotations
# ์๋ football_list๋ ์ ์ ํ jpg ํ์ผ๋ช
์ผ๋ก ๋ณ๊ฒฝ๋์ด์ผ ํฉ๋๋ค.
football_list = ['f1b492a9bce3ac9a.jpg', '1e6ff631bb0c198b.jpg', '97ac013310bda756.jpg',
'e5b1646c395aecfd.jpg', '53ef241dad498f6c.jpg', '02ccbf5ddaaecedb.jpg' ]
for image_name in football_list:
img = Image.open(os.path.join(IMAGE_DIR, image_name))
detected_img = ballnfish_yolo.detect_image(img)
plt.figure(figsize=(8, 8))
plt.imshow(detected_img)
- Object Detection ๊ฒฐ๊ณผ
# ์๋ helmet_list๋ ์ ์ ํ jpg ํ์ผ๋ช
์ผ๋ก ๋ณ๊ฒฝ๋์ด์ผ ํฉ๋๋ค.
helmet_list = ['1fed5c930211c6e0.jpg', '011a59a160d7a091.jpg', 'd39b46aa4bc0c165.jpg', '7e9eb7eba80e34e7.jpg', '9c27811a78b74a48.jpg']
for image_name in helmet_list:
img = Image.open(os.path.join(IMAGE_DIR, image_name))
detected_img = ballnfish_yolo.detect_image(img)
plt.figure(figsize=(8, 8))
plt.imshow(detected_img)
- Object Detection ๊ฒฐ๊ณผ
# ์๋ fish_list๋ ์ ์ ํ jpg ํ์ผ๋ช
์ผ๋ก ๋ณ๊ฒฝ๋์ด์ผ ํฉ๋๋ค.
fish_list = ['25e42c55bfcbaa88.jpg', 'a571e4cdcfbcb79e.jpg', '872c435491f2b4d3.jpg',
'bebac23c45451d93.jpg', 'eba7caf07a26829b.jpg', 'dc607a2989bdc9dc.jpg' ]
for image_name in fish_list:
img = Image.open(os.path.join(IMAGE_DIR, image_name))
detected_img = ballnfish_yolo.detect_image(img)
plt.figure(figsize=(8, 8))
plt.imshow(detected_img)
- Object Detection ๊ฒฐ๊ณผ (์ฒซ๋ฒ์งธ, 3๋ฒ์งธ ์ฌ์ง์ ํ์ง ์ค๋ฅ์ ๋๋ค.)
#์๋ shark_list๋ ์ ์ ํ jpg ํ์ผ๋ช
์ผ๋ก ๋ณ๊ฒฝ๋์ด์ผ ํฉ๋๋ค.
shark_list = ['d92290f6c04dd83b.jpg', '3a37a09ec201cdeb.jpg', '32717894b5ce0052.jpg', 'a848df5dbed78a0f.jpg', '3283eafe11a847c3.jpg']
for image_name in shark_list:
img = Image.open(os.path.join(IMAGE_DIR, image_name))
detected_img = ballnfish_yolo.detect_image(img)
plt.figure(figsize=(8, 8))
plt.imshow(detected_img)
- Object Detection ๊ฒฐ๊ณผ
#์๋ shell_list๋ ์ ์ ํ jpg ํ์ผ๋ช
์ผ๋ก ๋ณ๊ฒฝ๋์ด์ผ ํฉ๋๋ค.
shell_list=['5cc89bc28084e8e8.jpg', '055e756883766e1f.jpg', '089354fc39f5d82d.jpg', '80eddfdcb3384458.jpg']
for image_name in shell_list:
img = Image.open(os.path.join(IMAGE_DIR, image_name))
detected_img = ballnfish_yolo.detect_image(img)
plt.figure(figsize=(8, 8))
plt.imshow(detected_img)
- Object Detection ๊ฒฐ๊ณผ
์์ Object Detection์ Part.2 ๊ธ๋ก ์ด์ด์ง๋๋ค.
'๐ Computer Vision' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[CV] OpenImage Dataset์ Object Detection Inference Part.2 (with Keras-yolo) (0) | 2025.01.22 |
---|---|
[CV] Google Open Image Dataset (0) | 2024.10.11 |
[CV] Keras YOLO๋ก Raccoon Dataset์ ์ด์ฉํ Object Detection (0) | 2024.10.10 |
[CV] Object Detection Model Training์ ์ ์์ฌํญ (0) | 2024.10.08 |
[CV] Keras ๊ธฐ๋ฐ YOLO Open Source Package & Object Detection (0) | 2024.10.07 |