[Hongong Machine Learning] Cross-Validation & Grid Search

Validation Set

Without using the test set, it is hard to tell whether a model is overfitting or underfitting.
  • A simple way to measure this without touching the test set is to split the training set once more.
  • The portion split off this way is called the validation set.
  • For example, if 20% of the whole dataset is held out as the test set and the remaining 80% forms the training set, we then take 20% of that training set and set it aside as the validation set. Relative to the whole dataset, that is roughly 64% training, 16% validation, and 20% test.

  • Train the model on the reduced training set and evaluate it on the validation set.
  • Repeat this while changing the parameters you want to test, pick the best model, retrain it on the full training data (training set plus validation set), and finally evaluate it once on the test set to get its final score.
  • Let's look at an example.
import pandas as pd
# Load the wine dataset
wine = pd.read_csv('https://bit.ly/wine-date')

# Features: alcohol, sugar, pH; target: the 'class' column
data = wine[['alcohol', 'sugar', 'pH']].to_numpy()
target = wine['class'].to_numpy()
  • Now let's split the data into a training set and a test set.
from sklearn.model_selection import train_test_split
train_input, test_input, train_target, test_target = train_test_split(data, target, test_size=0.2, random_state=42)
  • Next, pass train_input and train_target to train_test_split() again to create the training set (sub_input, sub_target) and the validation set (val_input, val_target).
  • Setting test_size=0.2 puts about 20% of train_input into val_input.
sub_input, val_input, sub_target, val_target = train_test_split(
    train_input, train_target, test_size=0.2, random_state=42)
  • Then let's check the sizes of the new training and validation sets.
print(sub_input.shape, val_input.shape)

# (4157, 3) (1040, 3)
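Where do 4,157 and 1,040 come from? The sketch below mirrors the rounding that scikit-learn's train_test_split() applies to a float test_size: the test part is rounded up and the remainder goes to training. The helper split_sizes() is hypothetical, and the total of 6,497 samples is an assumption inferred from the 5,197 training samples reported here.

```python
import math

def split_sizes(n_samples, test_size):
    """Hypothetical helper: sizes produced by a float test_size.
    The test part is rounded up; everything else goes to training."""
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test

# Assumed total of 6,497 samples (inferred from the 5,197 training samples)
train_n, test_n = split_sizes(6497, 0.2)   # first split: train / test
sub_n, val_n = split_sizes(train_n, 0.2)   # second split: sub-train / validation
print(train_n, test_n)   # 5197 1300
print(sub_n, val_n)      # 4157 1040
```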
  • The training set, originally 5,197 samples, has shrunk to 4,157, and the validation set has 1,040 samples.
  • Now let's build a model with sub_input and sub_target and evaluate it with val_input and val_target.
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(random_state=42)
dt.fit(sub_input, sub_target)

print(dt.score(sub_input, sub_target))
print(dt.score(val_input, val_target))

# 0.9971133028626413
# 0.864423076923077
  • This is how the validation set (val_input, val_target) is used to evaluate the model.
  • The model clearly overfits the training set: about 0.997 accuracy on the training set versus about 0.864 on the validation set.
  • We need to tune its parameters to find a better model.
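The whole validation-set workflow described at the start of this section (tune on the validation set, retrain on training plus validation, score the test set once) can be sketched as follows. To stay self-contained, the sketch assumes a synthetic dataset from make_classification in place of the wine data, and the max_depth candidates are illustrative choices, not values from the original.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in for the wine data (3 numeric features, binary target)
X, y = make_classification(n_samples=1000, n_features=3, n_informative=3,
                           n_redundant=0, random_state=42)

# Hold out a test set, then carve a validation set out of the training set
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2, random_state=42)
sub_X, val_X, sub_y, val_y = train_test_split(train_X, train_y, test_size=0.2, random_state=42)

# Try a few illustrative max_depth values, scoring each on the validation set
val_scores = {}
for depth in [2, 4, 6, 8, None]:
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    model.fit(sub_X, sub_y)
    val_scores[depth] = model.score(val_X, val_y)
best_depth = max(val_scores, key=val_scores.get)

# Retrain on the full training set (train + validation), then score the test set once
final_model = DecisionTreeClassifier(max_depth=best_depth, random_state=42)
final_model.fit(train_X, train_y)
test_score = final_model.score(test_X, test_y)
print(best_depth, val_scores[best_depth], test_score)
```

The test set is touched only once, at the very end, so its score stays an honest estimate of performance on unseen data.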

Cross-Validation

Creating the validation set shrank the training set. In general, the more data used for training, the better the resulting model.
  • But if the validation set is too small, the validation score can fluctuate and become unstable.
  • Cross-validation addresses this: it gives a stable validation score while using more of the data for training.

3-fold Cross-Validation
Splitting the training set into three parts and performing cross-validation is called 3-fold cross-validation: each part takes one turn as the validation set while the model is trained on the other two, and the three validation scores are averaged.
The general form is called k-fold cross-validation, named after the number of parts (folds) the training set is split into.
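To make the fold mechanics concrete, here is a minimal pure-Python sketch of how k-fold cross-validation assigns sample indices, assuming no shuffling. kfold_indices() is a hypothetical helper; it lays folds out contiguously and gives the first n_samples % k folds one extra sample, similar to scikit-learn's KFold with shuffle=False.

```python
def kfold_indices(n_samples, k):
    """Hypothetical helper: yield (train_idx, val_idx) for k-fold CV without shuffling.
    Folds are contiguous; the first n_samples % k folds get one extra sample."""
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = list(range(start, start + size))            # this fold validates
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))  # the rest trains
        yield train_idx, val_idx
        start += size

for train_idx, val_idx in kfold_indices(10, 3):
    print(len(train_idx), val_idx)
# 6 [0, 1, 2, 3]
# 7 [4, 5, 6]
# 7 [7, 8, 9]
```

Every sample is used for validation exactly once. With k folds, each model trains on (k - 1)/k of the training set: 80% for 5-fold and 90% for 10-fold.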

 

  • 3-fold cross-validation was used as the example, but 5-fold or 10-fold cross-validation is more common in practice.
  • This way, 80% to 90% of the data can be used for training. Each validation fold is smaller, but because the validation scores from all folds are averaged, the result can be considered stable.
  • Scikit-learn provides a cross-validation function called cross_validate(). It is simple to use: pass the model object to evaluate as the first argument.
  • Then, instead of splitting off a validation set by hand as before, pass the entire training set to cross_validate().
  • Scikit-learn also has cross_val_score(), the predecessor of cross_validate().
  • It returns only the test_score values from the cross_validate() result.
from sklearn.model_selection import cross_validate
scores = cross_validate(dt, train_input, train_target)
print(scores)
# {'fit_time': array([0.02602839, 0.02728128, 0.04096222, 0.01634145, 0.01161575]),
#  'score_time': array([0.00565434, 0.01029921, 0.00742149, 0.00183392, 0.00161386]),
#  'test_score': array([0.86923077, 0.84615385, 0.87680462, 0.84889317, 0.83541867])}
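The relationship between cross_validate() and cross_val_score() described above can be checked directly. This sketch assumes a synthetic make_classification dataset in place of the wine training set; both calls use the same default 5-fold split and a deterministic tree, so the scores should match.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in for train_input / train_target
X, y = make_classification(n_samples=500, n_features=3, n_informative=3,
                           n_redundant=0, random_state=42)
dt = DecisionTreeClassifier(random_state=42)

result = cross_validate(dt, X, y)          # dict with fit_time, score_time, test_score
only_scores = cross_val_score(dt, X, y)    # just the array of test scores

print(sorted(result.keys()))
print(np.allclose(result['test_score'], only_scores))
```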

 

  • ๊ต์ฐจ๊ฒ€์ฆ์˜ ์ตœ์ข…์ ์ˆ˜๋Š” test_score ํ‚ค์— ๋‹ด๊ธด 5๊ฐœ์˜ ์ ์ˆ˜๋ฅผ ํ‰๊ท ํ•˜์—ฌ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
import numpy as np
print(np.mean(scores['test_score']))

# 0.855300214703487
  • ๊ต์ฐจ ๊ฒ€์ฆ์„ ์ˆ˜ํ–‰ํ•˜๋ฉด ์ž…๋ ฅํ•œ ๋ชจ๋ธ์—์„œ ์–ป์„ ์ˆ˜ ์žˆ๋Š” ์ตœ์ƒ์˜ ๊ฒ€์ฆ ์ ์ˆ˜๋ฅผ ๊ฐ€๋Š ํ•ด ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ํ•œ ๊ฐ€์ง€ ์ฃผ์˜ํ•  ์ ์€ cross_validate()๋Š” ํ›ˆ๋ จ ์„ธํŠธ๋ฅผ ์‰ฝ๊ฒŒ ํด๋“œ๋ฅผ ๋‚˜๋ˆ„์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
  • ์•ž์„œ train_test_split() ํ•จ์ˆ˜๋กœ ์ „์ฒด ๋ฐ์ดํ„ฐ๋ฅผ ์‰ฝ์€ ํ›„ ํ›ˆ๋ จ ์„ธํŠธ๋ฅผ ์ค€๋น„ํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— ๋”ฐ๋กœ ์„ž์„ ํ•„์š”๋Š” ์—†์ง€๋งŒ, ๋งŒ์•ฝ ๊ต์ฐจ ๊ฒ€์ฆ์„ ํ•  ๋•Œ ํ›ˆ๋ จ ์„ธํŠธ๋ฅผ ์„ž๋Š” ๊ฒฝ์šฐ์—๋Š” ๋ถ„ํ• ๊ธฐ(splitter)๋ฅผ ์ง€์ •ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
  • Scikit-learn์˜ ๋ถ„ํ• ๊ธฐ๋Š” ๊ต์ฐจ ๊ฒ€์ฆ์—์„œ ํด๋“œ๋ฅผ ์–ด๋–ป๊ฒŒ ๋‚˜๋ˆ„์ง€ ๊ฒฐ์ •ํ•ด ์ค๋‹ˆ๋‹ค.
  • cross_validate() ํ•จ์ˆ˜๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ ํšŒ๊ท€ ๋ชจ๋ธ์ผ ๊ฒฝ์šฐ KFold ๋ถ„ํ• ๊ธฐ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ๋ถ„๋ฅ˜ ๋ชจ๋ธ์ผ ๊ฒฝ์šฐ ํƒ€๊นƒ ํด๋ž˜์Šค๋ฅผ ๊ณจ๊ณ ๋ฃจ ๋‚˜๋ˆ„๊ธฐ ์œ„ํ•ด StratifiedKFold๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
from sklearn.model_selection import StratifiedKFold
scores = cross_validate(dt, train_input, train_target, cv=StratifiedKFold())
print(np.mean(scores['test_score']))

# 0.855300214703487
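Why stratify? The sketch below compares the validation folds produced by KFold and StratifiedKFold on a small, deliberately sorted and imbalanced label array (a hypothetical example, not the wine data) and reports the fraction of class 1 in each validation fold.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical labels: 80 samples of class 0 followed by 20 of class 1 (sorted on purpose)
y = np.array([0] * 80 + [1] * 20)
X = np.zeros((len(y), 1))  # dummy features; only the labels matter for the split

ratios = {}
for name, splitter in [("KFold", KFold(n_splits=5)),
                       ("StratifiedKFold", StratifiedKFold(n_splits=5))]:
    # Fraction of class-1 samples in each validation fold
    ratios[name] = [y[val_idx].mean() for _, val_idx in splitter.split(X, y)]
    print(name, np.round(ratios[name], 2))
```

Without shuffling, KFold puts every class-1 sample into the last fold, while StratifiedKFold keeps the 20% class-1 share in each fold.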
  • To perform 10-fold cross-validation instead, create the splitter with n_splits=10 (for example StratifiedKFold(n_splits=10, shuffle=True, random_state=42) to also shuffle the data) and pass it to the cv parameter.
  • The KFold class can be used in the same way.
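In code, that looks roughly like this: a shuffled 10-fold StratifiedKFold passed to cv, then the same with KFold. As a self-contained sketch it assumes a synthetic make_classification dataset; with the wine data you would pass train_input and train_target instead.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, StratifiedKFold, cross_validate
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in for the wine training set
X, y = make_classification(n_samples=500, n_features=3, n_informative=3,
                           n_redundant=0, random_state=42)
dt = DecisionTreeClassifier(random_state=42)

# 10 folds, shuffled before splitting; random_state makes the shuffle reproducible
splitter = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
scores = cross_validate(dt, X, y, cv=splitter)
print(len(scores['test_score']), np.mean(scores['test_score']))

# KFold takes the same arguments; it simply ignores the class labels when splitting
kf_scores = cross_validate(dt, X, y, cv=KFold(n_splits=10, shuffle=True, random_state=42))
print(len(kf_scores['test_score']), np.mean(kf_scores['test_score']))
```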


Hyperparameter Tuning

๋ชจ๋ธ์ด ํ•™์Šตํ• ์ˆ˜ ์—†์–ด์„œ ์‚ฌ์šฉ์ž๊ฐ€ ๊ผญ ๋ชจ๋ธ์—๊ฒŒ ์ง€์ •์„ ํ•ด์ค˜์•ผ ํ•˜๋Š” ๋ถ€๋ถ„์ด ์žˆ์Šต๋‹ˆ๋‹ค.
์‚ฌ์šฉ์ž๊ฐ€ ์ง€์ •์„ ํ•ด์ค˜์•ผ ํ•˜๋Š” ๋ถ€๋ถ„์„ HyperParameter ๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทธ๋Ÿฌ๋ฉด ์ด๋Ÿฐ Hyperparameter๋ฅผ ํŠœ๋‹ํ•˜๋Š” ์ž‘์—…์€ ์–ด๋–ป๊ฒŒ ํ• ๊นŒ์š”?
  • ๋ชจ๋ธ์€ ์ผ๋ฐ˜์ ์œผ๋กœ, ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์—์„œ ์ œ๊ณตํ•˜๋Š” ๊ธฐ๋ณธ๊ฐ’์— ๊ธฐ๋ฐ˜ํ•˜์—ฌ ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทธ ๋‹ค์Œ, Validation(๊ฒ€์ฆ ์„ธํŠธ) & Cross-Validation(๊ต์ฐจ ๊ฒ€์ฆ)์„ ํ†ตํ•ด์„œ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์กฐ๊ธˆ์”ฉ ๋ด๊ฟ”๋ณด๋Š” ํ˜•ํƒœ๋กœ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทผ๋ฐ ์•„์ฃผ ์ค‘์š”ํ•œ ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ๊ฒฐ์ • ํŠธ๋ฆฌ ๋ชจ๋ธ์—์„œ ์ตœ์ ์˜ max_depth ๊ฐ’์„ ์ฐพ์•˜๋‹ค๊ณ  ๊ฐ€์ •ํ•ด๋ด…์‹œ๋‹ค.
  • ๊ทธ ๋‹ค์Œ max_depth๋ฅผ ์ตœ์ ์˜ ๊ฐ’์œผ๋กœ ๊ณ ์ •ํ•˜๊ณ  min_samples_split์„ ๋ฐ”๊ฟ”๊ฐ€๋ฉฐ ์ตœ์ ์˜ ๊ฐ’์„ ์ฐพ์Šต๋‹ˆ๋‹ค.
  • ์ด๋ ‡๊ฒŒ ํ•œ ๋งค๊ฐœ๋ณ€์ˆ˜์˜ ์ตœ์  ๊ฐ’์„ ์ฐพ๊ณ  ๋‹ค๋ฅธ ๋งค๊ฐœ๋ณ€์ˆ˜์˜ ์ตœ์  ๊ฐ’์„ ์ฐพ์•„๋„ ๋ ๊นŒ์š”?
  • ์•„๋‹™๋‹ˆ๋‹ค. max_depth์˜ ์ตœ์  ๊ฐ’์€ min_samples_split ๋งค๊ฐœ๋ณ€์ˆ˜์˜ ๊ฐ’์ด ๋ฐ”๋€Œ๋ฉด ํ•จ๊ป˜ ๋‹ฌ๋ผ์ง‘๋‹ˆ๋‹ค.
  • ์ฆ‰, ์ด ๋‘ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๋™์‹œ์— ๋ฐ”๊ฟ”๊ฐ€๋ฉฐ ์ตœ์ ์˜ ๊ฐ’์„ ์ฐพ์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ฒŒ๋‹ค๊ฐ€ ๋งค๊ฐœ๋ณ€์ˆ˜๊ฐ€ ๋งŽ์•„์ง€๋ฉด ๋ฌธ์ œ๋Š” ๋” ๋ณต์žกํ•ด์ง‘๋‹ˆ๋‹ค.
  • ํŒŒ์ด์ฌ์˜ for ๋ฐ˜๋ณต๋ฌธ์œผ๋กœ ์ด๋Ÿฐ ๊ณผ์ •์„ ์ง์ ‘ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ์ง€๋งŒ, Scikit-learn์—์„œ ์ œ๊ณตํ•˜๋Š” ๊ทธ๋ฆฌ๋“œ ์„œ์น˜(GridSearch)๋ฅผ ์‚ฌ์šฉํ•ด์„œ ๊ฐ„๋‹จํ•˜๊ฒŒ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
from sklearn.model_selection import GridSearchCV
params = {'min_impurity_decrease': [0.0001, 0.0002, 0.0003, 0.0004, 0.0005]}
  • Above, we import the GridSearchCV class and create a dictionary that maps the parameter to explore (min_impurity_decrease) to a list of five candidate values.
  • Passing the model to search and the params dictionary to GridSearchCV creates a grid search object. Setting n_jobs=-1 runs the work in parallel on all CPU cores.
gs = GridSearchCV(DecisionTreeClassifier(random_state=42), params, n_jobs=-1)
  • ๊ทธ ๋‹ค์Œ ์ผ๋ฐ˜ ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•˜๋Š” ๊ฒƒ์ฒ˜๋Ÿผ gs ๊ฐ์ฒด์— fit() ๋ฉ”์„œ๋“œ๋ฅผ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค.
  • ์ด ๋ฉ”์„œ๋“œ๋ฅผ ํ˜ธ์ถœํ•˜๋ฉด ๊ทธ๋ฆฌ๋“œ ์„œ์น˜ ๊ฐ์ฒด๋Š” ๊ฒฐ์ • ํŠธ๋ฆฌ ๋ชจ๋ธ์˜ min_impurity_decrease ๊ฐ’์„ ๋ฐ”๊ฟ”๊ฐ€๋ฉฐ ์ด 5๋ฒˆ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.
  • GridSearchCV์˜ cv ๋งค๊ฐœ๋ณ€์ˆ˜ ๊ธฐ๋ณธ๊ฐ’์€ 5์ž…๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ min_impurity_decrease ๊ฐ’๋งˆ๋‹ค 5-ํด๋“œ ๊ต์ฐจ ๊ฒ€์ฆ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทธ๋Ÿฌ๋ฉด ์ด (5 x 5 = 25)๊ฐœ์˜ ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•ฉ๋‹ˆ๋‹ค.
gs.fit(train_input, train_target)
  • ์‚ฌ์ดํ‚ท๋Ÿฐ์˜ ๊ทธ๋ฆฌ๋“œ ์„œ์น˜๋Š” ํ›ˆ๋ จ์ด ๋๋‚˜๋ฉด 25๊ฐœ์˜ ๋ชจ๋ธ ์ค‘์—์„œ ๊ฒ€์ฆ ์ ์ˆ˜๊ฐ€ ๊ฐ€์žฅ ๋†’์€ ๋ชจ๋ธ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜ ์กฐํ•ฉ์œผ๋กœ ์ „์ฒด ํ›ˆ๋ จ ์„ธํŠธ์—์„œ ์ž๋™์œผ๋กœ ๋‹ค์‹œ ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•ฉ๋‹ˆ๋‹ค.
  • ์ด ๋ชจ๋ธ์€ gs ๊ฐ์ฒด์˜ best_estimator_ ์†์„ฑ์— ์ €์žฅ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์ด ๋ชจ๋ธ์„ ์ผ๋ฐ˜ ๊ฒฐ์ • ํŠธ๋ฆฌ์ฒ˜๋Ÿผ ๋˜‘๊ฐ™์ด ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
dt = gs.best_estimator_
print(dt.score(train_input, train_target))
print(gs.best_params_)

# 0.9615162593804117
# {'min_impurity_decrease': 0.0001}
  • The best parameters found by the grid search are stored in the best_params_ attribute.
  • Next, let's print the mean 5-fold cross-validation score for each of the five parameter values.
print(gs.cv_results_['mean_test_score'])

# [0.86819297 0.86453617 0.86492226 0.86780891 0.86761661]
  • The first value appears to be the largest. Rather than picking it by eye, NumPy's argmax() function returns the index of the largest value.
  • That index can then be used to look up the corresponding parameters stored under the params key.
  • This is the parameter combination that produced the best validation score. Check that it matches gs.best_params_ printed earlier.
best_index = np.argmax(gs.cv_results_['mean_test_score'])
print(gs.cv_results_['params'][best_index])

# {'min_impurity_decrease': 0.0001}
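cv_results_ also keeps the score of every individual fold. The sketch below, on an assumed synthetic make_classification dataset, checks that mean_test_score is the average of the per-fold columns split0_test_score through split4_test_score, and that the argmax lookup shown above agrees with GridSearchCV's own best_index_ attribute.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in for the wine training set
X, y = make_classification(n_samples=500, n_features=3, n_informative=3,
                           n_redundant=0, random_state=42)
params = {'min_impurity_decrease': [0.0001, 0.0002, 0.0003, 0.0004, 0.0005]}
gs = GridSearchCV(DecisionTreeClassifier(random_state=42), params)
gs.fit(X, y)

res = gs.cv_results_
# One column per fold (cv=5 by default): rows = folds, columns = candidates
per_fold = np.array([res[f'split{i}_test_score'] for i in range(5)])
print(per_fold.shape)
print(np.allclose(per_fold.mean(axis=0), res['mean_test_score']))
print(np.argmax(res['mean_test_score']) == gs.best_index_)
print(res['params'][gs.best_index_] == gs.best_params_)
```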

 

Let's summarize the process above.

  1. First, specify the parameters to explore.
  2. Then run the grid search on the training set to find the parameter combination with the best mean validation score.
    • This combination is stored in the grid search object.
  3. Using the best parameters, the grid search trains the final model on the entire training set (not on the reduced training folds used during cross-validation).
    • This model is also stored in the grid search object.

 

  • Now let's explore a more complex set of parameters.
  • In a decision tree, min_impurity_decrease sets the minimum decrease in impurity required to split a node.
  • On top of that, we will also tune max_depth, which limits the depth of the tree, and min_samples_split, the minimum number of samples required to split a node.

 

  • ๋„˜ํŒŒ์ด arange() ํ•จ์ˆ˜(โ‘ )๋Š” ์ฒซ ๋ฒˆ์งธ ๋งค๊ฐœ๋ณ€์ˆ˜ ๊ฐ’์—์„œ ์‹œ์ž‘ํ•˜์—ฌ ๋‘ ๋ฒˆ์งธ ๋งค๊ฐœ๋ณ€์ˆ˜ ๊ฐ’์— ๋„๋‹ฌํ•  ๋•Œ๊นŒ์ง€ ์„ธ ๋ฒˆ์งธ ๋งค๊ฐœ๋ณ€์ˆ˜์”ฉ ๊ณ„์† ๋”ํ•œ ๋ฐฐ์—ด์„ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
  • ์ฝ”๋“œ์—์„œ๋Š” 0.0001์—์„œ ์‹œ์ž‘ํ•˜์—ฌ 0.001์ด ๋  ๋•Œ๊นŒ์ง€ 0.0001์„ ๊ณ„์† ๋”ํ•œ ๋ฐฐ์—ด์ž…๋‹ˆ๋‹ค.
  • ๋‘ ๋ฒˆ์งธ ๋งค๊ฐœ๋ณ€์ˆ˜๋Š” ํฌํ•จ๋˜์ง€ ์•Š์œผ๋ฏ€๋กœ ๋ฐฐ์—ด์˜ ์›์†Œ๋Š” ์ด 9๊ฐœ์ž…๋‹ˆ๋‹ค.
  • ํŒŒ์ด์ฌ range() ํ•จ์ˆ˜(โ‘ก)๋„ ๋น„์Šทํ•ฉ๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ด ํ•จ์ˆ˜๋Š” ์ •์ˆ˜๋งŒ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์ด ๊ฒฝ์šฐ max_depth๋ฅผ 5์—์„œ 20๊นŒ์ง€ 1์”ฉ ์ฆ๊ฐ€ํ•˜๋ฉด์„œ 15๊ฐœ์˜ ๊ฐ’์„ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
  • min_samples_split์€ 2์—์„œ 100๊นŒ์ง€ 10์”ฉ ์ฆ๊ฐ€ํ•˜๋ฉด์„œ 10๊ฐœ์˜ ๊ฐ’์„ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
  • ๋”ฐ๋ผ์„œ ์ด ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ์ˆ˜ํ–‰ํ•  ๊ต์ฐจ ๊ฒ€์ฆ ํšŸ์ˆ˜๋Š” 9 × 15 × 10 = 1,350๊ฐœ์ž…๋‹ˆ๋‹ค.
  • ๊ธฐ๋ณธ 5-ํด๋“œ ๊ต์ฐจ ๊ฒ€์ฆ์„ ์ˆ˜ํ–‰ํ•˜๋ฏ€๋กœ ๋งŒ๋“ค์–ด์ง€๋Š” ๋ชจ๋ธ์˜ ์ˆ˜๋Š” 6,750๊ฐœ ์ž…๋‹ˆ๋‹ค.

 

  • Let's set n_jobs to -1, run the grid search, and then check the best parameter combination and its cross-validation score.
params = {'min_impurity_decrease': np.arange(0.0001, 0.001, 0.0001),
          'max_depth': range(5, 20, 1),
          'min_samples_split': range(2, 100, 10)
          }

gs = GridSearchCV(DecisionTreeClassifier(random_state=42), params, n_jobs=-1)
gs.fit(train_input, train_target)
print(gs.best_params_)
print(np.max(gs.cv_results_['mean_test_score']))

# {'max_depth': 14, 'min_impurity_decrease': 0.0004, 'min_samples_split': 12}
# 0.8683865773302731

 

With the GridSearchCV class, there is no need to change parameters one at a time and run cross-validation by hand: we just list the parameter values we want, and it cross-validates them all to find the best ones.


Random Search

  • ๋งค๊ฐœ๋ณ€์ˆ˜์˜ ๊ฐ’์ด ์ˆ˜์น˜์ผ ๋•Œ ๊ฐ’์˜ ๋ฒ”์œ„๋‚˜ ๊ฐ„๊ฒฉ์„ ๋ฏธ๋ฆฌ ์ •ํ•˜๊ธฐ ์–ด๋ ค์šธ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ๋˜ ๋„ˆ๋ฌด ๋งŽ์€ ๋งค๊ฐœ๋ณ€์ˆ˜ ์กฐ๊ฑด์ด ์žˆ์–ด ๊ทธ๋ฆฌ๋“œ ์„œ์น˜ ์ˆ˜ํ–‰ ์‹œ๊ฐ„์ด ์˜ค๋ž˜ ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿด ๋•Œ ๋žœ๋ค ์„œ์น˜(Random Search)๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ์ข‹์Šต๋‹ˆ๋‹ค.
  • ๋žœ๋ค ์„œ์น˜์—๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜ ๊ฐ’์˜ ๋ชฉ๋ก์„ ์ „๋‹ฌํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์ƒ˜ํ”Œ๋งํ•  ์ˆ˜ ์žˆ๋Š” ํ™•๋ฅ  ๋ถ„ํฌ ๊ฐ์ฒด๋ฅผ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.
  • ํ™•๋ฅ  ๋ถ„ํฌ(probability distribution)๋ž€ ๋ฌด์ž‘์œ„ ๊ฐ’์„ ๋ฝ‘์„ ์ˆ˜ ์žˆ๋Š” ๋ฒ”์œ„๋ฅผ ๋งํ•ฉ๋‹ˆ๋‹ค.
from scipy.stats import uniform, randint
  • The uniform and randint classes in SciPy's stats subpackage both draw values evenly from a given range.
  • randint draws integers and uniform draws real numbers. This is called sampling from a uniform distribution.
  • Let's create a randint object for the integers 0 to 9 (the upper bound 10 is excluded) and sample 10 numbers.
rgen = randint(0, 10)
rgen.rvs(10)

# array([0, 6, 5, 1, 4, 0, 9, 4, 0, 3])
With only 10 samples, the values do not look evenly spread, but increasing the number of samples makes this easy to check.
Let's sample 1,000 numbers and count how many times each one appears.
np.unique(rgen.rvs(1000), return_counts=True)

# (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
# array([ 95,  90,  92,  98, 130,  86, 104, 116,  92,  97]))

 

  • With more samples, the numbers 0 through 9 are drawn roughly evenly.
  • The uniform class is used the same way. Let's draw 10 real numbers between 0 and 1.
ugen = uniform(0, 1)
ugen.rvs(10)

# array([0.94577774, 0.89877912, 0.63212905, 0.32615542, 0.37558058,
#       0.47341714, 0.16045226, 0.83559588, 0.26931821, 0.43901825])
  • You can think of these objects as random number generators.
  • We pass randint and uniform objects to the random search and tell it how many times to sample in order to find the best parameters.
  • It is best to make the number of samples as large as your system resources allow.
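To see what "sampling the parameters n times" means, scikit-learn's ParameterSampler draws candidate combinations from a dictionary of distributions; RandomizedSearchCV uses it internally to generate its n_iter candidates. The two-entry dictionary here is a trimmed-down, illustrative version of the one defined below.

```python
from scipy.stats import randint, uniform
from sklearn.model_selection import ParameterSampler

# Illustrative dictionary: parameter name -> distribution to sample from
dists = {'max_depth': randint(20, 50),
         'min_impurity_decrease': uniform(0.0001, 0.001)}

# Draw 5 candidate combinations, the way RandomizedSearchCV draws n_iter of them
candidates = list(ParameterSampler(dists, n_iter=5, random_state=42))
for c in candidates:
    print(c)
```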

 

  • ํƒ์ƒ‰ํ•  ๋งค๊ฐœ๋ณ€์ˆ˜์˜ ๋”•์…”๋„ˆ๋ฆฌ๋ฅผ ๋งŒ๋“ค์–ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
  • ์—ฌ๊ธฐ์—์„œ๋Š” min_samples_leaf ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ํƒ์ƒ‰ ๋Œ€์ƒ์— ์ถ”๊ฐ€ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.
  • ์ด ๋งค๊ฐœ๋ณ€์ˆ˜๋Š” ๋ฆฌํ”„ ๋…ธ๋“œ๊ฐ€ ๋˜๊ธฐ ์œ„ํ•œ ์ตœ์†Œ ์ƒ˜ํ”Œ์˜ ๊ฐœ์ˆ˜์ž…๋‹ˆ๋‹ค.
  • ์–ด๋–ค ๋…ธ๋“œ๊ฐ€ ๋ถ„ํ• ํ•˜์—ฌ ๋งŒ๋“ค์–ด์งˆ ์ž์‹ ๋…ธ๋“œ์˜ ์ƒ˜ํ”Œ ์ˆ˜๊ฐ€ ์ด ๊ฐ’๋ณด๋‹ค ์ž‘์„ ๊ฒฝ์šฐ ๋ถ„ํ• ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
params = {'min_impurity_decrease': uniform(0.0001, 0.001),
          'max_depth': randint(20, 50),
          'min_samples_split': randint(2, 25),
          'min_samples_leaf': randint(1, 25),
          }
  • min_impurity_decrease samples real values with uniform(loc, scale), which covers the range from loc to loc + scale. uniform(0.0001, 0.001) therefore samples between 0.0001 and 0.0011.
  • randint(low, high) excludes high, so max_depth samples integers from 20 to 49, min_samples_split from 2 to 24, and min_samples_leaf from 1 to 24.
  • The number of samples is set with the n_iter parameter of RandomizedSearchCV, scikit-learn's random search class.
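The bounds are easy to get wrong, so here is a quick empirical check of the SciPy conventions: randint excludes its upper bound, and the second argument of uniform is a width (scale), not an end point. The sample size and random_state are arbitrary choices for the sketch.

```python
from scipy.stats import randint, uniform

# randint(low, high) draws integers with low <= k < high (high is excluded)
depths = randint(20, 50).rvs(10_000, random_state=0)
print(depths.min(), depths.max())

# uniform(loc, scale) draws floats from [loc, loc + scale], not [loc, scale]
decreases = uniform(0.0001, 0.001).rvs(10_000, random_state=0)
print(decreases.min() >= 0.0001, decreases.max() <= 0.0011)
print((decreases > 0.001).any())   # some draws land above 0.001
```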
from sklearn.model_selection import RandomizedSearchCV

gs = RandomizedSearchCV(DecisionTreeClassifier(random_state=42), params, 
                        n_iter=100, n_jobs=-1, random_state=42)
gs.fit(train_input, train_target)
  • This samples 100 parameter combinations (n_iter=100) from the ranges defined in params, cross-validates each one, and finds the best combination. With the default 5 folds that is 500 models, far fewer than the 6,750 of the earlier grid search, while still exploring a wide range of values effectively.
print(gs.best_params_)

# {'max_depth': 39, 'min_impurity_decrease': 0.00034102546602601173,
# 'min_samples_leaf': 7, 'min_samples_split': 13}
  • ์ตœ๊ณ ์˜ ๊ต์ฐจ ๊ฒ€์ฆ ์ ์ˆ˜ & ํ…Œ์ŠคํŠธ ์„ธํŠธ์˜ ์„ฑ๋Šฅ๋„ ํ™•์ธํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
print(np.max(gs.cv_results_['mean_test_score']))
# 0.8695428296438884
dt = gs.best_estimator_
print(dt.score(test_input, test_target))
# 0.86
  • It is common for the test set score to be slightly lower than the validation score.

Summary

Keywords

  • A validation set is data split off from the training set so that models can be evaluated during hyperparameter tuning without using the test set.
  • Cross-validation splits the training set into several folds; each fold in turn serves as the validation set while the model is trained on the remaining folds. The validation scores from all folds are then averaged to evaluate the model.
  • Grid search automates hyperparameter search. Given the parameter values to explore, it runs cross-validation, picks the parameter combination with the best validation score, and finally trains the final model with that combination.
  • Random search is useful for exploring continuous parameter values. Instead of listing the values to explore, you pass probability distribution objects from which values can be sampled.
    • Because it cross-validates a specified number of samples, you can adjust the amount of searching to what your system resources allow.

 

Key packages and functions

  • cross_validate(): performs cross-validation.
    • Pass the model object to cross-validate as the first argument, and the feature and target data as the second and third arguments.
    • The scoring parameter specifies the evaluation metric used for validation. By default this is accuracy ('accuracy') for classification models and the coefficient of determination ('r2') for regression models.
    • The cv parameter sets the number of cross-validation folds or a splitter object; the default is 5. With an integer, regression uses the KFold class and classification uses the StratifiedKFold class to perform the k-fold cross-validation.
    • The n_jobs parameter sets the number of CPU cores used for cross-validation. The default (None) uses a single core.
    • Setting n_jobs to -1 uses all cores in the system. Setting return_train_score to True also returns the training set scores; the default is False.
  • GridSearchCV: performs hyperparameter search with cross-validation.
    • After finding the best model, it trains the final model on the entire training set.
    • Pass the model object to search as the first argument and the model parameters with the values to explore as the second argument.
    • The scoring, cv, n_jobs, and return_train_score parameters are the same as in cross_validate().
  • RandomizedSearchCV: performs randomized hyperparameter search with cross-validation.
    • After finding the best model, it trains the final model on the entire training set.
    • Pass the model object to search as the first argument and the model parameters with the probability distribution objects to sample from as the second argument.
    • The scoring, cv, n_jobs, and return_train_score parameters are the same as in cross_validate().
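As a usage sketch of the parameters listed above, the call below sets scoring, cv, n_jobs and return_train_score explicitly on an assumed synthetic make_classification dataset. With return_train_score=True the result gains a train_score entry, which makes the gap between training and validation scores (that is, overfitting) easy to see.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic stand-in for the wine training set
X, y = make_classification(n_samples=500, n_features=3, n_informative=3,
                           n_redundant=0, random_state=42)

scores = cross_validate(
    DecisionTreeClassifier(random_state=42), X, y,
    scoring='accuracy',        # same as the classifier default, written out explicitly
    cv=5,                      # an int means StratifiedKFold(5) for a classifier
    n_jobs=-1,                 # use all available CPU cores
    return_train_score=True,   # adds 'train_score' to the result
)
print(sorted(scores))
print(np.mean(scores['train_score']), np.mean(scores['test_score']))
```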