๋ฐ์ํ
Logistic Regression (๋ก์ง์คํฑ ํ๊ท) ๋ก ์์ธ ๋ถ๋ฅํ๊ธฐ
์์ธ์ ๋ถ๋ฅ ํ๊ธฐ ์ํด์ ์ผ๋จ ๋ฐ์ดํฐ์ ์ ๋ถ๋ฌ์ค๊ฒ ์ต๋๋ค.
import pandas as pd
wine = pd.read_csv('https://bit.ly/wine_csv_data')
wine.head()
- ์ด๋ ๊ฒ ๋ฐ์ดํฐ์ ์ Pandas DataFrame์ผ๋ก ์ ๋ถ๋ฌ ์๋์ง head() Method๋ก ํ๋ฒ ๋ถ๋ฌ์์ต๋๋ค.
- ์ฒ์ 3๊ฐ์ ์ด(alcohol, suger, pH)๋ ์์ฝ์ฌ ๋์, ๋น๋, pH(์ฐ๋)๋ฅผ ๋ํ๋ ๋๋ค.
- class๋ ํ๊น๊ฐ์ด 0์ด๋ฉด ๋ ๋์์ธ, 1์ด๋ฉด ํ์ดํธ ์์ธ ์ด๋ผ๊ณ ํฉ๋๋ค.
- ์ด๊ฑด ๋ ๋ & ํ์ดํธ ์์ธ์ ๊ตฌ๋ถํ๋ Binary Classification(์ด์ง ๋ถ๋ฅ)๋ฌธ์ ์ธ๊ฑฐ ๊ฐ์ต๋๋ค. ์ฆ, ์ ์ฒด ์์ธ์ ๋ฐ์ดํฐ์์ ํ์ดํธ ์์ธ์ ๊ณจ๋ผ๋ด๋ ๋ฌธ์ ์ ๋๋ค.
- ๊ณจ๋ผ๋ด๊ธฐ ์ ์, ๋ฐ์ดํฐํ๋ ์์ ๊ฐ ์ด & ๋ฐ์ดํฐ ํ์ ๊ณผ ๋๋ฝ๋ ๋ฐ์ดํฐ๊ฐ ์๋์ง ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค.
# wine-dataframe(์ด๋ค type์ ๋ฐ์ดํฐ, ํฌ๊ธฐ๊ฐ ์ผ๋ง์ธ์ง ์ ๋ณด ์ ๊ณต)
# Non-Null count: ๋๋ฝ๋ ๊ฐ ํ์ธ (๋ง์ฝ ์์ผ๋ฉด 6496๊ฐ๊ฐ ๋์ด)
wine.info()
# ํ๊ท , ํ์คํธ์ฐจ, ์ต์, ์ต๋๊ฐ ๋ฑ ์ ๋ณด ํ์ธ (describe method)
wine.describe()
์ด๋ ๊ฒ EDA ๊ณผ์ ์ ๊ฑฐ์ณ์ ์์ ์๋๊ฑด ์์ฝ์ฌ ๋์, ๋น๋, pH(์ฐ๋)๊ฐ์ scale์ด ๋ค๋ฅด๋ค๋ ๊ฒ์ ๋๋ค. Scikit-learn์ StandardScaler ํด๋์ค๋ฅผ ์ฌ์ฉํด์ ํน์ฑ๋ค์ ํ์คํ ํ๊ฒ ์ต๋๋ค.
๊ทธ์ ์, Pandas Dataframe์ Numpy ๋ฐฐ์ด๋ก ๋ด๊พธ๊ณ Training, Test set๋ก ๋๋๊ฒ ์ต๋๋ค.
# Pandas Dataframe๋ฅผ numpy ๋ฐฐ์ด๋ก ๋ด๊ฟ
data = wine[['alcohol', 'sugar', 'pH']].to_numpy()
target = wine['class'].to_numpy()
- wine์ Dataframe์์ ์ฒ์ 3๊ฐ์ ์ด์ Numpy ๋ฐฐ์ด๋ก ๋ด๊ฟ์ data์ ์ ์ฅํ๊ณ , class ์ด์ ๋ด๊ฟ์ target ๋ฐฐ์ด์ ์ ์ฅํ๊ฒ ์ต๋๋ค.
- ๊ทธ๋ฆฌ๊ณ Training, Test set๋ก ๋๋๊ฒ ์ต๋๋ค.
# Data๋ฅผ train, test set๋ก ๋๋
from sklearn.model_selection import train_test_split
train_input, test_input, train_target, test_target = train_test_split(data, target, test_size=0.2, random_state=42)
์ฌ๊ธฐ์ ๋ง์ฝ์ train_test_split() ํจ์๋ ์ค์ ๊ฐ์ ์ง์ ํ์ง ์์ผ๋ฉด 25%๋ฅผ ํ ์คํธ ์ธํธ๋ก ๋๋๋๋ค.
์ฌ๊ธฐ์๋ ์ํ๊ฐ์๊ฐ ์ถฉ๋ถํ ๋ง์ผ๋ฏ๋ก 20%์ ๋๋ง Test Set๋ก ๋๋๊ฒ ์ต๋๋ค.
์ฌ๊ธฐ์ test_size=0.2๊ฐ ์ด๋ฐ ์๋ฏธ ์ ๋๋ค. ํ๋ฒ ๋๋ Training, Test set์ ํฌ๊ธฐ๋ฅผ ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค.
print(train_input.shape, test_input.shape)
(5197, 3) (1300, 3)
- Training set๋ 5,197๊ฐ, Test set๋ 1,300๊ฐ๋ก ๋๋์ด ์ง๊ฑธ ํ์ธํ์ต๋๋ค.
- ์ด์ StandardScaler ํด๋์ค๋ฅผ ์ฌ์ฉํด์ Training set๋ฅผ ์ ์ฒ๋ฆฌ ํํ, ๊ฐ์ ๊ฐ์ฒด๋ฅผ ์ฌ์ฉํด์ Test set๋ ๋ณํ ํ๊ฒ ์ต๋๋ค.
from sklearn.preprocessing import StandardScaler
ss = StandardScaler()
ss.fit(train_input)
train_scaled = ss.transform(train_input)
test_scaled = ss.transform(test_input)
- ํ์ค ์ ์๋ก ๋ณํ๋ train_scaled, test_scaled๋ฅผ ์ฌ์ฉํด์ Logistic Regression ๋ชจ๋ธ์ ํ๋ จํ๊ฒ ์ต๋๋ค.
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
lr.fit(train_scaled, train_target)
print(lr.score(train_scaled, train_target))
print(lr.score(test_scaled, test_target))
0.7808350971714451
0.7776923076923077
์๊ฐ์ธ๋ก, ์ ์๊ฐ ๋์ง ์๋ค์, ๋ญ๊ฐ ์ ์๊ฐ ๋ฎ์ผ๋๊น ๊ณผ๋์ ํฉ ๋๊ฑฐ ๊ฐ์ต๋๋ค.
๊ทธ๋๋ ์ด ๋ชจ๋ธ์ ํ๋จํ๊ธฐ ์ํด์ ํ์ตํ ๊ณ์ & ์ ํธ์ ์ถ๋ ฅํด ๋ณด๊ฒ ์ต๋๋ค.
# ํ์ตํ ๊ณ์, ์ ํธ ์ถ๋ ฅ
print(lr.coef_, lr.intercept_)
[[ 0.51270274. 1.6733911. -0.68767781 ]] [1.81777902]
Decision Tree(๊ฒฐ์ ํธ๋ฆฌ)
๋ชจ๋ธ์ ์ ํ๋ ํ๋จ ๋ช ์ค๋ช ์ ํ๊ธฐ ์ํด์๋ Decision Tree(๊ฒฐ์ ํธ๋ฆฌ) ๋ชจ๋ธ์ ์ฌ์ฉํ๊ธฐ๋ ํฉ๋๋ค.
- ๋ฐ์ดํฐ๋ฅผ ์ ๋๋์ ์๋ ์ง๋ฌธ๋ค์ ์ฐพ์์ ๋ณด๋ฉด ๊ณ์ ์ง๋ฌธ์ ์ถ๊ฐํด์ ๋ถ๋ฅ ์ ํ๋๋ฅผ ๋์ผ ์ ์์ต๋๋ค.
- ์ฌ๊ธฐ์ Scikit-learn์ Decision Tree Algorithm์ ์ ๊ณตํฉ๋๋ค.
- ํ๋ฒ Scikit-learn์ Decision Tree Algorithm์ ์ฌ์ฉํด์ ๊ฒฐ์ ํธ๋ฆฌ ๋ชจ๋ธ์ ํ๋ จ์์ผ ๋ณด๊ฒ ์ต๋๋ค.
- ์ฒ์ ๋ณด๋ ํด๋์ค ์ด์ง๋ง, ํ์ต ๋ฐฉ๋ฒ์ Logistic Regression ๊ณผ ๋์ผํฉ๋๋ค. fit() Method๋ฅผ ํธ์ถํด์ ๋ชจ๋ธ์ ํ๋ จํ ๋ค์, score() Method๋ก ์ ํ๋๋ฅผ ํ๊ฐํด ๋ณด๊ฒ ์ต๋๋ค.
from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(random_state=42)
dt.fit(train_scaled, train_target)
print(dt.score(train_scaled, train_target)) # ํ๋ จ ์ธํธ
print(dt.score(test_scaled, test_target)) # ํ
์คํธ ์ธํธ
0.996921300750433
0.8592307692307692
๋ณด์๋ฉด, ํ๋ จ์ธํธ์ ๋ํ ์ ์๋ ๋์ง๋ง ํ ์คํธ ์ธํธ์ ์ฑ๋ฅ์ ๊ทธ์ ๋นํด์ ์กฐ๊ธ ๋ฎ์๊ฑฐ ๊ฐ์ต๋๋ค. ๊ณผ๋ ์ ํฉ ๋์๋ค๊ณ ๋ณผ ์ ์๊ฒ ๋ค์.
๊ทผ๋ฐ ์ด๋ป๊ฒ ํด์ผ ๊ฒฐ์ ํธ๋ฆฌ๋ฅผ ์๊ฐํ ํ ์ ์์๊น์? Scikit-learn์ plot_tree() ํจ์๋ฅผ ์ฌ์ฉํด์ ๊ฒฐ์ ํธ๋ฆฌ๋ฅผ ์ดํดํ๊ธฐ ์ฌ์ด ๊ทธ๋ฆผ
์ผ๋ก ์ถ๋ ฅํด์ค๋๋ค. ํ๋ฒ ํด๋ณด๊ฒ ์ต๋๋ค.
# ๊ฒฐ์ ํธ๋ฆฌ ์๊ฐํ
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
plt.figure(figsize=(10,7))
plot_tree(dt)
plt.show()
- ์ด ํธ๋ฆฌ๋ ์ผ๋ฐ์ ์ธ ํธ๋ฆฌ๋ค๊ณผ๋ ์ฝ๊ฐ ๋ค๋ฆ ๋๋ค. ๊ฒฐ์ ํธ๋ฆฌ์ ๋งจ์๋ root node, ๋งจ ์๋๋ leaf node๋ก ๋ด ๋๋ค.
- ์ผ๋ฐ์ ์ธ ํธ๋ฆฌ์๋ ๋ฐ๋๋ก ๋ณผ์ ์์ต๋๋ค.
- ์ฐธ๊ณ ์ฌํญ.
๋๊น์ง ์ค๋ฉด ์์,์์ ์ธ์ง ๊ตฌ๋ถ ๊ฐ๋ฅ(red,white wine ์ธ์ง)
๋ง์ง๋ง leaf node์ ์๋ ์์ธก๊ฐ์ด ์๋ก์ด sample x์ ๋ํ ์์ธก๊ฐ.
๋ถ๋ฅ์ธ ๊ฒฝ์ฐ ์ต๊ทผ์ ์ด์๊ณผ ๋น์ท, ๋ง์ง๋ง leaf node์ ๋ค์ด์๋ sample๋ค์ค ๋ค์์ class๊ฐ ์์ธก ํด๋์ค
ํ๊ท๋ ๋ง์ง๋ง leaf node์ sample๋ค์ target๊ฐ์ ํ๊ท ์ด x sample์ ์์ธก๊ฐ.
๊ทผ๋ฐ, ํธ๋ฆฌ๊ฐ ๋๋ฌด ๋ณต์กํด ๋ณด์ด์ง ์๋์? ํ๋ฒ ํธ๋ฆฌ์ ๊น์ด๋ฅผ ์ ํํด์ ์ถ๋ ฅํด ๋ณด๊ฒ ์ต๋๋ค.
- max_depth์ ๋งค๊ฐ๋ณ์๋ฅผ 1๋ก ์ฃผ๋ฉด, ๋ฃจํธ ๋ ธ๋๋ฅผ ์ ์ธํ๊ณ ํ๋์ ๋ ธ๋๋ฅผ ๋ ํ์ฅํด์ ๊ทธ๋ฆฝ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ filled ๋งค๊ฐ๋ณ์์์ ํด๋์ค์ ๋ง๊ฒ ๋ ธ๋์ ์์ ์น ํ ์ ์์ต๋๋ค.
- feature_names ๋งค๊ฐ๋ณ์์์ ํน์ฑ์ ์ด๋ฆ์ ์ ๋ฌํ ์ ์์ต๋๋ค. ์ด๊ฒ์ ๋ณด๋ฉด ๋ ธ๋๊ฐ ์ด๋ ํ ํน์ฑ๋ค๋ก ๋๋๋์ง ์์ ์์ต๋๋ค. ํ๋ฒ ๊ทธ๋ ค๋ณด๊ฒ ์ต๋๋ค.
plt.figure(figsize=(10,7))
plot_tree(dt, max_depth=1, filled=True, feature_names=['alcohol', 'sugar', 'pH'])
plt.show()
๊ทธ๋ฌ๋ฉด ์ด์ ๊ทธ๋ฆผ์ ์ฝ๋ ๋ฐฉ๋ฒ์ ํ๋ฒ ์์๋ณด๊ฒ ์ต๋๋ค.
- ๋ฃจํธ ๋ ธ๋๋ ๋น๋(sugar)๊ฐ -0.239 ์ดํ์ธ์ง ์ง๋ฌธ์ ํฉ๋๋ค.
- ๋ง์ฝ ์ด๋ค ์ํ์ ๋น๋๊ฐ -0.239์ ๊ฐ๊ฑฐ๋ ์์ผ๋ฉด ์ผ์ชฝ ๊ฐ์ง๋ก ๊ฐ๋๋ค.
- ๊ทธ๋ ์ง ์์ผ๋ฉด ์ค๋ฅธ์ชฝ ๊ฐ์ง๋ก ์ด๋ํฉ๋๋ค.
- ์ฆ ์ผ์ชฝ์ด Yes, ์ค๋ฅธ์ชฝ์ด No์ ๋๋ค.
- ๋ฃจํธ ๋ ธ๋์ ์ด ์ํ ์ (samples)๋ 5,197๊ฐ ์ ๋๋ค.
- ์ด ์ค์์ ์์ฑ ํด๋์ค(๋ ๋ ์์ธ)๋ 1,258๊ฐ ์ด๊ณ ,
- ์์ฑ ํด๋์ค(ํ์ดํธ ์์ธ)๋ 3,939๊ฐ์ ๋๋ค. ์ด ๊ฐ์ด value์ ๋ํ๋ ์์ต๋๋ค.
- ์ผ์ชฝ ๋ ธ๋๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค. ์ด ๋ ธ๋๋ ๋น๋๊ฐ ๋ ๋ฎ์์ง๋ฅผ ๋ฌผ์ด๋ณด๋ค์.
- ๋น๋๊ฐ -0.802์ ๊ฐ๊ฑฐ๋ ๋ฎ๋ค๋ฉด ๋ค์ ์ผ์ชฝ ๊ฐ์ง๋ก, ๊ทธ๋ ์ง ์์ผ๋ฉด ์ค๋ฅธ์ชฝ ๊ฐ์ง๋ก ์ด๋ํฉ๋๋ค.
- ์ด ๋ ธ๋์์ ์์ฑ ํด๋์ค์ ์์ฑ ํด ๋์ค์ ์ํ ๊ฐ์๋ ๊ฐ๊ฐ 1,177๊ฐ์ 1,745๊ฐ์ ๋๋ค.
- ๋ฃจํธ ๋ ธ๋๋ณด๋ค ์์ฑ ํด๋์ค, ์ฆ ํ์ดํธ ์์ธ์ ๋น์จ์ด ํฌ๊ฒ ์ค์ด๋ค์์ต๋๋ค
- ๊ทธ ์ด์ ๋ ์ค๋ฅธ์ชฝ ๋ ธ๋๋ฅผ ๋ณด๋ฉด ์ ์ ์์ต๋๋ค.
- ์ค๋ฅธ์ชฝ ๋ ธ๋๋ ์์ฑ ํด๋์ค๊ฐ 81๊ฐ, ์์ฑ ํด๋์ค๊ฐ 2,194๊ฐ๋ก ๋๋ถ๋ถ์ ํ์ดํธ ์์ธ ์ํ์ด ์ด ๋ ธ๋๋ก ์ด๋ํ์ต๋๋ค.
- ๋ ธ๋์ ๋ฐํ ์๊น์ ์ ์ฌํ ๋ณด์ธ์. ๋ฃจํธ ๋ ธ๋๋ณด๋ค ์ด ๋ ธ๋๊ฐ ๋ ์งํ๊ณ , ์ผ์ชฝ ๋ ธ๋๋ ๋ ์ฐํด์ง์ง ์์๋์?
- plot_tree() ํจ์์์ filled=True๋ก ์ง์ ํ๋ฉด ํด๋์ค ๋ง๋ค ์๊น์ ๋ถ์ฌํ๊ณ ์ด๋ค ํด๋์ค์ ๋น์จ์ด ๋์์ง๋ฉด ์ ์ ์งํ์ ์ผ๋ก ํ์ํฉ๋๋ค. ์์ฃผ ์ง๊ด์ ์ด๋ค์.
๊ฒฐ์ ํธ๋ฆฌ์์ ์์ธก ํ๋ ๋ฐฉ๋ฒ์ ๊ฐ๋จํฉ๋๋ค. Leaf Node์์ ๊ฐ์ฅ ๋ง์ ํด๋์ค๊ฐ ์์ธก ํด๋์ค๊ฐ ๋ฉ๋๋ค.
์์์ ๋ณด์๋ k-์ต๊ทผ์ ์ด์๊ณผ ๋น์ทํฉ๋๋ค.
๋ง์ฝ์ ๊ฒฐ์ ํธ๋ฆฌ์ ์ฑ์ฅ์ ๋ฉ์ถ๋ค๋ฉด, ์ผ์ชฝ ๋ ธ๋์ ๋๋ฌํ ์ํ๊ณผ ์ค๋ฅธ์ชฝ ๋ ธ๋์ ๋๋ฌํ ์ํ์ ๋ชจ๋ ์์ฑ ํด๋์ค๋ก ์์ธก ๋ฉ๋๋ค.
๋ ๋ ธ๋ ๋ชจ๋ ์์ฑ ํด๋์ค์ ๊ฐ์๊ฐ ๋ง๊ธฐ ๋๋ฌธ์ ๋๋ค.
์์(ํ๋์), ์์(๋นจ๊ฐ์)์ผ๋ก ํํ
ํ ์คํธ ์กฐ๊ฑด(sugar), ๋ถ์๋ (gini), ์ด ์ํ์(samples), ํด๋์ค๋ณ ์ํ ์(value)
leaf node ๋๋๋ ๊ธฐ์ค:sugar
- ๊ทผ๋ฐ...? ๋ ธ๋ ์์ ์์ gini ๋ผ๋ ๊ฒ์ด ์๋ค์..? ์ด๊ฒ ๋ญ์ง ์์ ๋ณด๊ฒ ์ต๋๋ค.
Gini Impurity (์ง๋ ๋ถ์๋)
Gini๋ ์ง๋๋ถ์๋ ์ ๋๋ค. DecisionTreeClassifier ํด๋์ค์ Criterion ๋งค๊ฐ๋ณ์์ ๊ธฐ๋ณธ๊ฐ์ด 'gini'์ ๋๋ค.
Criterion ๋งค๊ฐ๋ณ์์ ์๋๋ ๋ ธ๋์์ ๋ฐ์ดํฐ๋ฅผ ๋ถํ ํ ๊ธฐ์ค์ ์ ํ๋ ๊ฒ์ ๋๋ค. ๊ทธ๋ฌ๋ฉด ์ง๋ ๋ถ์๋๋ ์ด๋ป๊ฒ ๊ตฌํ๋ ๊ฒ์ผ๊น์?
์ง๋ ๋ถ์๋๋ ํด๋์ค์ ๋น์จ์ ์ ๊ณฑํด์ ๋ํ ๋งํผ 1์ ๋นผ๋ฉด ๋ฉ๋๋ค.
- ๋ง์ฝ, ๋ค์ค ํด๋์ค๋ฉด ํด๋์ค๊ฐ ๋ ๋ง๊ฒ ์ง๋ง? ๊ณ์ฐ ํ๋ ๋ฐฉ๋ฒ์ ๋์ผํฉ๋๋ค.
- ๊ทธ๋ฌ๋ฉด ์ด์ ํธ๋ฆฌ ๊ทธ๋ฆผ์ ์๋ ๋ฃจํธ ๋ ธ๋์ ์ง๋ ๋ถ์๋๋ฅผ ๊ณ์ฐํด ๋ณด๊ฒ ์ต๋๋ค.
- ๋ฃจํธ ๋ ธ๋๋ 5,197๊ฐ์ ์ํ์ด ์๊ณ 1,258๊ฐ์ ์์ฑ ํด๋์ค, 3,939๊ฐ๊ฐ ์์ฑ ํด๋์ค ์ ๋๋ค.
- ๋ฐ๋ผ์ ๋ค์๊ณผ ๊ฐ์ด ์ง๋ ๋ถ์๋๋ฅผ ๊ณ์ฐํ ์ ์์ต๋๋ค.
๊ทธ๋ฌ๋ฉด, ํ๋ฒ ์ผ, ์ค๋ฅธ์ชฝ์ ์ง๋ ๋ถ์๋๋ ํ๋ฒ ๊ณ์ฐํด ๋ณด๊ฒ ์ต๋๋ค. ๋ง์ฝ 100๊ฐ์ ์ํ์ด ์๋ ๋ ธ๋์ ํด๋์ค์ ๋น์จ์ด 1/2๋ฉด,
์ง๋๋ถ์๋๋ 0.5๊ฐ ๋์ด ์ต์ ์ด ๋ฉ๋๋ค.
๋ ธ๋์ ํ๋์ ํด๋์ค๊ฐ ์์ผ๋ฉด ์ง๋ ๋ถ์๋๋ 0์ด ๋์ด์ ๊ฐ์ฅ ์์ต๋๋ค. ์ด๋ฐ ๋ ธ๋๋ฅผ ์์ ๋ ธ๋๋ผ๊ณ ๋ถ๋ฆ ๋๋ค.
์๋ฅผ ๋ค์ด ์์ ํธ๋ฆฌ ๊ทธ๋ฆผ์์ ๋ฃจํธ ๋ ธ๋๋ฅผ ๋ถ๋ชจ ๋ ธ๋๋ผ ํ๋ฉด ์ผ์ชฝ ๋ ธ๋์ ์ค๋ฅธ์ชฝ ๋ ธ๋๊ฐ ์์ ๋ ธ๋ ๊ฐ ๋ฉ๋๋ค
์ผ์ชฝ ๋ ธ๋๋ก๋ 2,922๊ฐ์ ์ํ์ด ์ด๋ํ๊ณ ์ค๋ฅธ์ชฝ ๋ ธ๋๋ก๋ 2,275๊ฐ์ ์ํ์ด ์ด๋ํ์ต๋๋ค.
๊ทธ๋ฌ๋ฉด, ๋ถ์๋์ ์ฐจ์ด๋ ๋ค์๊ณผ ๊ฐ์ด ๊ณ์ฐํฉ๋๋ค.
- ์ด๋ฌํ ๋ถ๋ชจ์ ์์ ๋ ธ๋ ์ฌ์ด์ ๋ถ์๋ ์ฐจ์ด๋ฅผ ์ ๋ณด์ด๋(Imformation Gain)์ด๋ผ๊ณ ๋ถ๋ฆ ๋๋ค.
- ๊ทผ๋ฐ, Scikit-learn์ ๋ ๋ค๋ฅธ ๋ถ์๋ ๊ธฐ์ค์ด ์์ต๋๋ค.
- DecisionTreeClassifier ํด๋์ค์์ criterion=‘entropy’๋ฅผ ์ง์ ํ์ฌ ์ํธ๋กํผ ๋ถ์๋๋ฅผ ์์ฉํ ์ ์์ต๋๋ค.
์ํธ๋กํผ ๋ถ์๋๋ ๋ ธ๋์ ํด๋์ค ๋น์จ์ ์ฌ์ฉํ์ง๋ง ์ง๋ ๋ถ์๋์ฒ๋ผ ์ ๊ณฑ์ด ์๋๋ผ ๋ฐ์ด 2์ธ ๋ก๊ทธ๋ฅผ ์์ฉํ์ฌ ๊ณฑํฉ๋๋ค.
์๋ฅผ ๋ค์ด ๋ฃจํธ ๋ ธ๋์ ์ํธ๋กํผ ๋ถ์๋๋ ๋ค์๊ณผ ๊ฐ์ด ๊ณ์ฐํ ์ ์์ต๋๋ค.
- ๋ณดํต ๊ธฐ๋ณธ๊ฐ์ธ ์ง๋ ๋ถ์๋์ ์ํธ๋กํผ ๋ถ์๋๊ฐ ๋ง๋ ๊ฒฐ๊ณผ์ ์ฐจ์ด๋ ํฌ์ง ์์ต๋๋ค. ์ฌ๊ธฐ์๋ ๊ธฐ๋ณธ ๊ฐ์ธ ์ง๋ ๋ถ์๋๋ฅผ ๊ณ์ ์ฌ์ฉํ๊ฒ ์ต๋๋ค.
๊ฐ์ง์น๊ธฐ
๊ฒฐ์ ํธ๋ฆฌ์์ ๊ฐ์ง์น๊ธฐ๋ฅผ ํ๋ ๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ ์๋ผ๋ ์ ์๋ ํธ๋ฆฌ์ ์ต๋ ๊น์ด๋ฅผ ์ง์ ํ๋ ๊ฒ์ ๋๋ค.
DecisionTreeClassifier ํด๋์ค์ max_depth ๋งค๊ฐ๋ณ์๋ฅผ 3์ผ๋ก ์ง์ ํ์ฌ ๋ชจ๋ธ์ ๋ง๋ค์ด ๋ณด๊ฒ ์ต๋๋ค.
์ด๋ ๊ฒ ํ๋ฉด ๋ฃจํธ ๋ ธ๋ ์๋๋ก ์ต๋ 3๊ฐ์ ๋ ธ๋๊น์ง๋ง ์ฑ์ฅ ํ ์์์ต๋๋ค.
# Logistic ํ๊ท๋ ์ ํํจ์ ๊ฐ์ค์น๋ฅผ ํ์ตํ๊ธฐ ์ํ์ฌ ํน์ฑ์ scale์ ๋ง์ถค.
# ๊ฒฐ์ Tree๋ ์ ํํจ์ ํ๋ จํ๋ ์๊ณ ๋ฆฌ์ฆ์ด ์๋๊ธฐ ๋๋ฌธ์ ํน์ฑ์ scale๋ฅผ ์กฐ์ ํ ํ์ ์์.
dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(train_scaled, train_target)
print(dt.score(train_scaled, train_target))
print(dt.score(test_scaled, test_target))
0.8454877814123533
0.8415384615384616
- ํ๋ จ ์ธํธ์ ์ฑ๋ฅ์ ๋ฎ์์ก์ง๋ง, ํ ์คํธ ์ธํธ์ ์ฑ๋ฅ์ ๊ทธ๋๋ก ์ ๋๋ค. ํ๋ฒ plot_tree() ํจ์๋ก ๊ทธ๋ ค๋ณด๊ฒ ์ต๋๋ค.
plt.figure(figsize=(20,15))
plot_tree(dt, filled=True, feature_names=['alcohol', 'sugar', 'pH'])
plt.show()
๋ง์ง๋ง์ผ๋ก ๊ฒฐ์ ํธ๋ฆฌ๋ ์ด๋ค ํน์ฑ์ด ๊ฐ์ฅ ์ ์ฉํ์ง ๋ํ๋ด๋ ํน์ฑ ์ค์๋๋ฅผ ๊ณ์ฐํด ์ค๋๋ค.
์ด ํธ๋ฆฌ ์ ๋ฃจํธ ๋ ธ๋์ ๊น์ด 1์์ ๋น๋๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ์๋ง๋ ๋น๋(sugar)๊ฐ ๊ฐ์ฅ ์ ์ฉํ ํน์ฑ ์ค ํ ๋์ผ ๊ฒ ๊ฐ์ต๋๋ค.
ํน์ฑ ์ค์๋๋ ๊ฒฐ์ ํธ๋ฆฌ ๋ชจ๋ธ์ feature_importances_ ์์ฑ์ ์ ์ฅ๋์ด ์์ต ๋๋ค. ์ด ๊ฐ์ ์ถ๋ ฅํด ํ์ธํด ๋ณด์ฃ .
# ํน์ฑ ์ค์๋ ์ถ๋ ฅ. 0.868๋ก sugar์ ๊ฐ์ด ๊ฐ์ฅ ๋์ผ๋ฏ๋ก ๊ฐ์ฅ ์ค์ํ๊ฒ ์ฌ์ฉ๋๋ ํน์ฑ์ด๋ผ๋๊ฑธ ์์ ์์
print(dt.feature_importances_)
[0.12345626 0.86862934 0.0079144 ]
- ์ด ๊ฐ๋ค์ ํ์ธํด ๋ณด๋ฉด ๋๋ฒ์งธ ํน์ฑ์ธ ๋น๋๊ฐ 0.87์ ๋๋ก ํน์ฑ์ค์๋๊ฐ ๊ฐ์ฅ๋๋ค์.
- ๊ทธ ๋ค์์์ฝ์ฌ, ๋์, pH ์์ ๋๋ค. ์ด ๊ฐ์ ๋ชจ๋ ๋ํ๋ฉด 1์ด ๋ฉ๋๋ค.
- ํน์ฑ ์ค์๋๋ ๊ฐ ๋ ธ๋์ ์ ๋ณด ์ด๋๊ณผ ์ ์ฒด ์ํ์ ๋ํ ๋น์จ์ ๊ณฑํ ํ ํน์ฑ๋ณ๋ก ๋ํ์ฌ ๊ณ์ฐํฉ๋๋ค.
- ํน์ฑ ์ค์๋๋ฅผ ํ์ฉํ๋ฉด ๊ฒฐ์ ํธ๋ฆฌ ๋ชจ๋ธ์ ํน์ฑ ์ ํ์ ํ์ฉํ ์ ์์ต๋๋ค.
- ์ด๊ฒ์ด ๊ฒฐ์ ํธ๋ฆฌ ์๊ณ ๋ฆฌ์ฆ์ ๋ ๋ค๋ฅธ ์ฅ์ ์ค ํ๋ ์ ๋๋ค.
KeyWords
๊ฒฐ์ ํธ๋ฆฌ
- ๊ฒฐ์ ํธ๋ฆฌ๋ ์ / ์๋์ค์ ๋ํ ์ง๋ฌธ์ ์ด์ด๋๊ฐ๋ฉด์ ์ ๋ต์ ์ฐพ์ ํ์ตํ๋ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ๋น๊ต์ ์์ธก๊ณผ์ ์ ์ดํดํ๊ธฐ ์ฝ๊ณ ์ฑ๋ฅ๋ ๋ฐ์ด๋ฉ๋๋ค.
๋ถ์๋
- ๋ถ์๋๋ ๊ฒฐ์ ํธ๋ฆฌ๊ฐ ์ต์ ์ ์ง๋ฌธ์ ์ฐพ๊ธฐ ์ํ ๊ธฐ์ค์ ๋๋ค. ์ฌ์ดํท๋ฐ์ ์ง๋ ๋ถ์๋์ ์ํธ๋กํผ ๋ถ์๋๋ฅผ ์ ๊ณตํฉ๋๋ค.
์ ๋ณด ์ด๋
- ์ ๋ณด ์ด๋์ ๋ถ๋ชจ ๋ ธ๋์ ์์ ๋ ธ๋์ ๋ถ์๋ ์ฐจ์ด์ ๋๋ค. ๊ฒฐ์ ํธ๋ฆฌ ์๊ณ ๋ฆฌ์ฆ์ ์ ๋ณด ์ด๋์ด ์ต๋ํ๋๋๋ก ํ์ตํฉ๋๋ค.
๊ฒฐ์ ํธ๋ฆฌ
- ๊ฒฐ์ ํธ๋ฆฌ๋ ์ ํ ์์ด ์ฑ์ฅํ๋ฉด ํ๋ จ ์ธํธ์ ๊ณผ๋์ ํฉ ๋๊ธฐ ์ฝ์ต๋๋ค.
๊ฐ์ง์น๊ธฐ
- ๊ฐ์ง์น๊ธฐ๋ ๊ฒฐ์ ํธ๋ฆฌ์ ์ฑ์ฅ์ ์ ํํ๋ ๋ฐฉ๋ฒ์ ๋๋ค. ์ฌ์ด๊ฒ๋ฐ์ ๊ฒฐ์ ํธ๋ฆฌ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ๋ฌ ๊ฐ์ง ๊ฐ์น์ง๊ธฐ ๋งค๊ฐ๋ณ์ ๋ฅผ ์ ๊ณตํฉ๋๋ค.
ํน์ฑ ์ค์๋
- ํน์ฑ ์ค์๋๋ ๊ฒฐ์ ํธ๋ฆฌ์ ์์ฉ๋ ํน์ฑ์ด ๋ถ์๋๋ฅผ ๊ฐ์ํ๋๋ฐ ๊ธฐ์ฌํ ์ ๋๋ฅผ ๋ํ๋ด๋ ๊ฐ์ ๋๋ค. ํน์ฑ ์ค์๋๋ฅผ ๊ณ์ฐํ ์ ์๋ ๊ฒ์ด ๊ฒฐ์ ํธ๋ฆฌ์ ๋๋ค๋ฅธ ํฐ ์ฅ์ ์ ๋๋ค.
Pandas
- info()๋ ๋ฐ์ดํฐํ๋ ์์ ์์ฝ๋ ์ ๋ณด๋ฅผ ์ถ๋ ฅํฉ๋๋ค. ์ธํ ์ค์ ์ปฌ๋ผ ํ์ ์ ์ถ๋ ฅํ๊ณ ๋ (null)์ด ์๋ ๊ฐ์ ๊ฐ์, ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์์ ๊ณตํฉ๋๋ค.
- verbose ๋งค๊ฐ๋ณ์์ ๊ธฐ๋ณธ๊ฐ True๋ฅผ False๋ก ๋น๊พธ๋ฉด ๊ฐ ์ด์ ๋ํ ์ ๋ณด๋ฅผ ์ถ๋ ฅํ์ง ์์ต๋๋ค.
- describe()๋๋ฐ์ดํฐํ๋ ์ ์ด์ ํต๊ณ ๊ฐ์์ ๊ณตํฉ๋๋ค. ์์นํ์ผ ๊ฒฝ์ฐ ์ต์, ์ต๋, ํ๊ท , ํ ์คํธ์ฐจ์์ฌ๋ถ์๊ฐ๋ฑ์ด ์ถ๋ ฅ๋ฉ๋๋ค.
- ๋ฌธ์์ด ๊ฐ์ ๊ฐ์ฒด ํ์ ์ ์ด์๊ฐ์ฅ์ง์ฃผ๋ฑ์ฅํ๋๊ฐ๊ณผํ์๋ฑ์ด์ถ๋ ฅ๋ฉ๋๋ค. percentiles ๋งค๊ฐ๋ณ์์์ ๋ฐฑ๋ถ์์๋ฅผ ์ง์ ํฉ๋๋ค. ๊ธฐ๋ณธ๊ฐ์ [0.25, 0.5, 0.75]์ ๋๋ค.
scikit-learn
- DecisionTreeClassifier๋ ๊ฒฐ์ ํธ๋ฆฌ ๋ถ๋ฅ ํด๋์ค์ ๋๋ค.
- criterion ๋งค๊ฐ๋ณ์๋ ๋ถ์๋๋ฅผ ์ง์ ํ๋ฉฐ ๊ธฐ๋ณธ๊ฐ์ ์ง๋ ๋ถ์๋๋ฅผ ์๋ฏธํ๋ ‘gini’์ด๊ณ ‘entropy’๋ฅผ ์ ํํ์ฌ ์ํธ๋กํผ ๋ถ์๋๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
- splitter ๋งค๊ฐ๋ณ์๋ ๋ ธ๋๋ฅผ ๋ถํ ํ๋ ์ ๋ต์ ์ ํํฉ๋๋ค. ๊ธฐ๋ณธ๊ฐ์ ‘best’๋ก ์ ๋ณด ์ด๋์ด ์ต๋๊ฐ ๋๋๋ก ๋ถํ ํฉ๋๋ค. ‘random’์ด๋ฉด ์์๋ก ๋ ธ๋๋ฅผ ๋ถํ ํฉ๋๋ค.
- max_depth๋ ํธ๋ฆฌ๊ฐ ์ฑ์ฅํ ์ต๋ ๊น์ด๋ฅผ ์ง์ ํฉ๋๋ค. ๊ธฐ๋ณธ๊ฐ์ None์ผ๋ก ๋ฆฌํ ๋ ธ๋๊ฐ ์์ํ๊ฑฐ๋ min_samples_split๋ณด๋ค ์ํ ๊ฐ์๊ฐ ์ ์ ๋๊น์ง ์ฑ์ฅํฉ๋๋ค.
- min_samples_split์ ๋ ธ๋๋ฅผ ๋๋๊ธฐ ์ํ ์ต์ ์ํ ๊ฐ์์ ๋๋ค. ๊ธฐ๋ณธ๊ฐ์ 2์ ๋๋ค.
- max_features ๋งค๊ฐ๋ณ์๋ ์ต์ ์ ๋ถํ ์ ์ํด ํ์ํ ํน์ฑ์ ๊ฐ์๋ฅผ ์ง์ ํฉ๋๋ค. ๊ธฐ๋ณธ๊ฐ์ None์ผ๋ก ๋ชจ๋ ํน์ฑ์ ์ฌ์ฉํฉ๋๋ค.
- plot tree()๋ ๊ฒฐ์ ํธ๋ฆฌ ๋ชจ๋์ ์๊ฐํํฉ๋๋ค. ์ฒซ ๋ฒ์งธ ๋งค๊ฐ๋ณ์๋ก ๊ฒฐ์ ํธ๋ฆฌ ๋ชจ๋ธ ๊ฐ์ฒด๋ฅผ ์ ๋ฌํฉ๋๋ค.
- max_depth ๋งค๊ฐ๋ณ์๋ก ๋ํ๋ผ ํธ๋ฆฌ์ ๊น์ด๋ฅผ ์ง์ ํฉ๋๋ค. ๊ธฐ๋ณธ๊ฐ์ None์ผ๋ก ๋ชจ๋ ๋ ธ๋ ๋ฅผ์ถ๋ ฅํฉ๋๋ค.
- feature_names ๋งค๊ฐ๋ณ์๋ก ํน์ฑ์ ์ด๋ฆ์ ์ง์ ํ ์ ์์ต๋๋ค.
- filled ๋งค๊ฐ๋ณ์๋ฅผ True๋ก ์ง์ ํ๋ฉด ํ๊น๊ฐ์ ๋ฐ๋ผ๋ ธ๋ ์์ ์์ ์ฑ์๋๋ค.
๋ฐ์ํ
'๐น๏ธ ํผ๊ณต๋จธ์ ' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[ํผ๊ณต๋จธ์ ] Tree's Ensemble - Random Forest (๋๋ค ํฌ๋ ์คํธ) (0) | 2024.07.30 |
---|---|
[ํผ๊ณต๋จธ์ ] Cross-Validation & Grid Search (0) | 2024.07.30 |
[ํผ๊ณต๋จธ์ ] Stochastic Gradient Descent (ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ) (0) | 2023.11.05 |
[ํผ๊ณต๋จธ์ ] Logistic Regression (๋ก์ง์คํฑ ํ๊ท) (0) | 2023.09.25 |
[ML] ํน์ฑ ๊ณตํ๊ณผ ๊ท์ (0) | 2023.09.24 |