A A
[ML] Model์˜ ํ•™์Šต๊ณผ ํ‰๊ฐ€
๋จธ์‹ ๋Ÿฌ๋‹ ๋ชจ๋ธ์˜ ํ•™์Šต๊ณผ ํ‰๊ฐ€ ๊ณผ์ •์—์„œ ์ค‘์š”ํ•œ ์š”์†Œ๋“ค์— ๋Œ€ํ•ด ๋‹ค๋ฃจ๊ฒ ์Šต๋‹ˆ๋‹ค.

ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ์™€ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ์˜ ๋ถ„ํ• 

๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ์ •ํ™•ํžˆ ํ‰๊ฐ€ํ•˜๊ณ  ์ผ๋ฐ˜ํ™” ๋Šฅ๋ ฅ์„ ํ™•์ธํ•˜๊ธฐ ์œ„ํ•ด ๋ฐ์ดํ„ฐ์…‹์„ ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ์™€ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ ๋ถ„ํ• ํ•ฉ๋‹ˆ๋‹ค.
  • ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ๋Š” ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ค๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋ฉฐ, ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋Š” ํ•™์Šต๋˜์ง€ ์•Š์€ ๋ฐ์ดํ„ฐ์—์„œ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ํ‰๊ฐ€ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
  • ์ผ๋ฐ˜์ ์ธ ๋น„์œจ:
    • Train(ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ) : Test(ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ) = 70:30
    • Train(ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ) : Test(ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ) = 80:20

๋ฐ์ดํ„ฐ ๋ถ„ํ•  ๋ฐฉ๋ฒ•

Train(ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ) & Test(ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ)๋ฅผ ์–ด๋– ํ•œ ๋น„์œจ๋กœ ๋‚˜๋ˆ„๋Š”์ง€ ์•Œ์•˜์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋ฉด ์–ด๋– ํ•œ ๋ฐฉ๋ฒ•์œผ๋กœ ๋ถ„๋ฆฌํ• ๊นŒ์š”?
  • ์ž„์˜ ๋ถ„ํ• (Random Split):
    • ๋ฐ์ดํ„ฐ๋ฅผ ๋ฌด์ž‘์œ„๋กœ ์„ž์€ ํ›„, ์ง€์ •๋œ ๋น„์œจ์— ๋”ฐ๋ผ ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ์™€ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ ๋ถ„ํ• ํ•ฉ๋‹ˆ๋‹ค.
    • ์ด๋Š” ๋ฐ์ดํ„ฐ์˜ ์ˆœ์„œ๊ฐ€ ๋ชจ๋ธ ์„ฑ๋Šฅ์— ์˜ํ–ฅ์„ ๋ฏธ์น˜์ง€ ์•Š๋Š” ๊ฒฝ์šฐ์— ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.
  • ์ธตํ™” ๋ถ„ํ• (Stratified Split):
    • ๋ฐ์ดํ„ฐ์˜ ํด๋ž˜์Šค ๋ถ„ํฌ๋ฅผ ์œ ์ง€ํ•˜๋ฉด์„œ ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ์™€ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ ๋ถ„ํ• ํ•ฉ๋‹ˆ๋‹ค.
    • ์ด๋Š” ํด๋ž˜์Šค ๋ถˆ๊ท ํ˜• ๋ฌธ์ œ๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ์— ์œ ์šฉํ•˜์—ฌ, ํ›ˆ๋ จ ๋ฐ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ ๋ชจ๋‘์—์„œ ํด๋ž˜์Šค ๋น„์œจ์ด ๋™์ผํ•˜๊ฒŒ ์œ ์ง€๋˜๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

๊ต์ฐจ๊ฒ€์ฆ (Cross-Validation)

๊ต์ฐจ๊ฒ€์ฆ (Cross-Validation)์€ ๋ฐ์ดํ„ฐ๋ฅผ ๋ถ„ํ• ํ•จ์œผ๋กœ์จ,
Bias(ํŽธํ–ฅ)์„ ์ค„์ด๊ณ  ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์„ ๋†’์ด๋ฉด์„œ ์ •ํ™•ํ•˜๊ฒŒ ํ‰๊ฐ€ํ•˜๊ธฐ ์œ„ํ•œ ๋ชฉ์ ์œผ๋กœ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
  • ๋ฐ์ดํ„ฐ์˜ ๋ถ„ํ• ๋กœ ์ธํ•œ ํŽธํ–ฅ์„ ์ค„์ด๊ณ  ๋ชจ๋ธ์˜ ์ผ๋ฐ˜ํ™” ์„ฑ๋Šฅ์„ ๋ณด๋‹ค ์ •ํ™•ํ•˜๊ฒŒ ํ‰๊ฐ€ํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
  • k-Fold Cross-Validation
    • ๋ฐ์ดํ„ฐ์…‹์„ k๊ฐœ์˜ ํด๋“œ(fold)๋กœ ๋‚˜๋ˆ„๊ณ , ๊ฐ ํด๋“œ๋Š” ํ•œ ๋ฒˆ์”ฉ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ ์‚ฌ์šฉ๋˜๋ฉฐ, ๋‚˜๋จธ์ง€ k-1๊ฐœ์˜ ํด๋“œ๋Š” ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ๋กœ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์ด k๋ฒˆ ๋ฐ˜๋ณต๋˜๋ฉฐ, ๊ฐ ๋ฐ˜๋ณต์˜ ์„ฑ๋Šฅ์„ ํ‰๊ท ํ•˜์—ฌ ์ตœ์ข… ์„ฑ๋Šฅ ํ‰๊ฐ€์— ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
    • ์˜ˆ์‹œ๋ฅผ ํ•œ๋ฒˆ, 5-Fold Cross-Validation์œผ๋กœ ๋“ค์–ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
    • 100๊ฐœ์˜ ๋ฐ์ดํ„ฐ๊ฐ€ ์žˆ๋‹ค๋ฉด, 100๊ฐœ์˜ ๋ฐ์ดํ„ฐ ์ „์ฒด๋ฅผ 5๊ฐœ์˜ ํด๋“œ๋กœ ๋‚˜๋ˆ„๊ณ , ๊ฐ ํด๋“œ(20๊ฐœ ๋ฐ์ดํ„ฐ)๊ฐ€ ํ•œ๋ฒˆ์”ฉ ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. ๋‚˜๋จธ์ง€ 4๊ฐœ์˜ ํด๋“œ(80๊ฐœ ๋ฐ์ดํ„ฐ)๋Š” ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ๋กœ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์„ 5๋ฒˆ ๋ฐ˜๋ณตํ•˜์—ฌ ์„ฑ๋Šฅ ์ง€ํ‘œ๋ฅผ ํ‰๊ท ํ•ฉ๋‹ˆ๋‹ค.

์ถœ์ฒ˜: https://docs.ultralytics.com/guides/kfold-cross-validation/


๊ต์ฐจ๊ฒ€์ฆ (Cross-Validation) Example Code

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, accuracy_score, mean_squared_error, r2_score
# ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
from sklearn.datasets import load_iris
from sklearn.datasets import load_digits
from sklearn.datasets import load_breast_cancer
# ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
digits = load_digits()  # ์†๊ธ€์”จ ์ˆซ์ž ๋ฐ์ดํ„ฐ์…‹์„ ๋กœ๋“œ
X = digits.data  # ์ž…๋ ฅ ๋ฐ์ดํ„ฐ (ํŠน์ง•)
y = digits.target  # ์ถœ๋ ฅ ๋ฐ์ดํ„ฐ (๋ ˆ์ด๋ธ”)

# ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์™€ ๋ ˆ์ด๋ธ” ์ถœ๋ ฅ
print(f"X: {X}")  # ํŠน์ง• ๋ฐ์ดํ„ฐ ์ถœ๋ ฅ
print(f"y: {y}")  # ๋ ˆ์ด๋ธ” ์ถœ๋ ฅ
X: [[ 0.  0.  5. ...  0.  0.  0.]
 [ 0.  0.  0. ... 10.  0.  0.]
 [ 0.  0.  0. ... 16.  9.  0.]
 ...
 [ 0.  0.  1. ...  6.  0.  0.]
 [ 0.  0.  2. ... 12.  0.  0.]
 [ 0.  0. 10. ... 12.  1.  0.]]
y: [0 1 2 ... 8 9 8]
# ๋ชจ๋ธ ์ƒ์„ฑ
nb = GaussianNB()  # ๋‚˜์ด๋ธŒ ๋ฒ ์ด์ฆˆ ๋ถ„๋ฅ˜๊ธฐ ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ

# ๊ต์ฐจ ๊ฒ€์ฆ
scores = cross_val_score(nb, X, y, cv=5)  # 5๊ฒน ๊ต์ฐจ ๊ฒ€์ฆ ์ˆ˜ํ–‰ํ•˜์—ฌ ์ ์ˆ˜ ๊ณ„์‚ฐ

# ๊ต์ฐจ ๊ฒ€์ฆ ์ ์ˆ˜ ์ถœ๋ ฅ
print(f'Cross-validation scores: {scores}')  # ๊ฐ fold์˜ ๊ฒ€์ฆ ์ ์ˆ˜ ์ถœ๋ ฅ
print(f'Mean CV Score: {np.mean(scores)}')  # ํ‰๊ท  ๊ต์ฐจ ๊ฒ€์ฆ ์ ์ˆ˜ ์ถœ๋ ฅ
Cross-validation scores: [0.78055556 0.78333333 0.79387187 0.8718663  0.80501393]
Mean CV Score: 0.8069281956050759
# ์‹œ๊ฐํ™”
plt.plot(range(1, len(scores) + 1), scores, marker='o', linestyle='--', color='b')
plt.xlabel('Fold')
plt.ylabel('Accuracy')
plt.title('Cross-Validation Scores')
plt.show()


ํ˜ผ๋™ํ–‰๋ ฌ(Confusion Matrix)

ํ˜ผ๋™ํ–‰๋ ฌ(Confusion Matrix)์€ ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ์™€ ์‹ค์ œ ๊ฒฐ๊ณผ๋ฅผ ๋น„๊ตํ•˜์—ฌ ์„ฑ๋Šฅ์„ ํ‰๊ฐ€ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
์ด์ง„ ๋ถ„๋ฅ˜์™€ ๋‹ค์ค‘ ํด๋ž˜์Šค ๋ถ„๋ฅ˜ ๋ฌธ์ œ ๋ชจ๋‘์— ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.

  • ํ˜ผ๋™ํ–‰๋ ฌ(Confusion Matrix)์—์„œ ์ฃผ์š” ์ง€ํ‘œ์— ๋ฐํ•˜์—ฌ ํ•œ๋ฒˆ ์•Œ์•„๋ณด๋ฉด 4๊ฐœ์˜ ์ง€ํ‘œ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.
  • True Positive (TP)
    • ์‹ค์ œ ์–‘์„ฑ์ธ ๋ฐ์ดํ„ฐ๋ฅผ ์–‘์„ฑ์œผ๋กœ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์˜ˆ์ธกํ•œ ๊ฒฝ์šฐ -  ๋ชจ๋ธ์ด Positive๋กœ ์˜ˆ์ธกํ•œ ๊ฒฐ๊ณผ๊ฐ€ ์‹ค์ œ๋กœ๋„ Positive์ธ ๊ฒฝ์šฐ์ž…๋‹ˆ๋‹ค.
    •  ์˜ˆ: ์‹ค์ œ๋กœ ์ŠคํŒธ ์ด๋ฉ”์ผ์ด ์ŠคํŒธ์œผ๋กœ ๋ถ„๋ฅ˜๋œ ๊ฒฝ์šฐ
  • False Negative (FN)
    • ์‹ค์ œ ์–‘์„ฑ์ธ ๋ฐ์ดํ„ฐ๋ฅผ ์Œ์„ฑ์œผ๋กœ ์ž˜๋ชป ์˜ˆ์ธกํ•œ ๊ฒฝ์šฐ - ๋ชจ๋ธ์ด Negative๋กœ ์˜ˆ์ธกํ•œ ๊ฒฐ๊ณผ๊ฐ€ ์‹ค์ œ๋กœ๋Š” Positive์ธ ๊ฒฝ์šฐ ์ž…๋‹ˆ๋‹ค.
    • ์˜ˆ: ์‹ค์ œ๋กœ ์ŠคํŒธ ์ด๋ฉ”์ผ์ด ์ŠคํŒธ์ด ์•„๋‹Œ ๊ฒƒ์œผ๋กœ ๋ถ„๋ฅ˜๋œ ๊ฒฝ์šฐ (๋ˆ„๋ฝ๋œ ์ŠคํŒธ)
  • False Positive (FP)
    • ์‹ค์ œ ์Œ์„ฑ์ธ ๋ฐ์ดํ„ฐ๋ฅผ ์–‘์„ฑ์œผ๋กœ ์ž˜๋ชป ์˜ˆ์ธกํ•œ ๊ฒฝ์šฐ -  ๋ชจ๋ธ์ด Positive๋กœ ์˜ˆ์ธกํ•œ ๊ฒฐ๊ณผ๊ฐ€ ์‹ค์ œ๋กœ๋Š” Negative์ธ ๊ฒฝ์šฐ ์ž…๋‹ˆ๋‹ค.
    • ์˜ˆ: ์‹ค์ œ๋กœ ์ŠคํŒธ์ด ์•„๋‹Œ ์ด๋ฉ”์ผ์ด ์ŠคํŒธ์œผ๋กœ ๋ถ„๋ฅ˜๋œ ๊ฒฝ์šฐ (์ž˜๋ชป๋œ ์ŠคํŒธ)
  • True Negative (TN)
    • ์‹ค์ œ ์Œ์„ฑ์ธ ๋ฐ์ดํ„ฐ๋ฅผ ์Œ์„ฑ์œผ๋กœ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์˜ˆ์ธกํ•œ ๊ฒฝ์šฐ - ๋ชจ๋ธ์ด Negative๋กœ ์˜ˆ์ธกํ•œ ๊ฒฐ๊ณผ๊ฐ€ ์‹ค์ œ๋กœ๋„ Negative์ธ ๊ฒฝ์šฐ ์ž…๋‹ˆ๋‹ค.
    • ์˜ˆ: ์‹ค์ œ๋กœ ์ŠคํŒธ์ด ์•„๋‹Œ ์ด๋ฉ”์ผ์ด ์ŠคํŒธ์ด ์•„๋‹Œ ๊ฒƒ์œผ๋กœ ๋ถ„๋ฅ˜๋œ ๊ฒฝ์šฐ

 

ํ˜ผ๋™ํ–‰๋ ฌ(Confusion Matrix)์˜ ์„ฑ๋Šฅ ์ง€ํ‘œ

ํ˜ผ๋™ํ–‰๋ ฌ(Confusion Matrix)์˜ ์„ฑ๋Šฅ ์ง€ํ‘œ๋Š” 4๊ฐ€์ง€๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์ •ํ™•๋„ (Accuracy): ์ „์ฒด ์˜ˆ์ธก ์ค‘ ๋งž์ถ˜ ๋น„์œจ์ž…๋‹ˆ๋‹ค.
  • ์ฆ‰, ๋ชจ๋ธ์ด ์–ผ๋งˆ๋‚˜ ์ž˜ ์˜ˆ์ธกํ–ˆ๋Š”์ง€๋ฅผ ํ‰๊ฐ€ํ•˜๋Š” ๊ฐ€์žฅ ๊ธฐ๋ณธ์ ์ธ ์ง€ํ‘œ์ž…๋‹ˆ๋‹ค.
  • 100๊ฐœ์˜ ์ƒ˜ํ”Œ ์ค‘ 90๊ฐœ๋ฅผ ๋งž์ถ”๊ณ  10๊ฐœ๋ฅผ ํ‹€๋ ธ๋‹ค๋ฉด, ์ •ํ™•๋„๋Š” 90%์ž…๋‹ˆ๋‹ค.

  • ์ •๋ฐ€๋„ (Precision): ์–‘์„ฑ์œผ๋กœ ์˜ˆ์ธก๋œ ๊ฒƒ ์ค‘ ์‹ค์ œ ์–‘์„ฑ์˜ ๋น„์œจ์ž…๋‹ˆ๋‹ค.
  • ์ฆ‰, ๋ชจ๋ธ์ด ์–‘์„ฑ์ด๋ผ๊ณ  ์˜ˆ์ธกํ•œ ๊ฒƒ๋“ค ์ค‘์—์„œ ์–ผ๋งˆ๋‚˜ ๋งŽ์€ ๊ฒƒ์ด ์‹ค์ œ๋กœ ์–‘์„ฑ์ธ์ง€๋ฅผ ์ธก์ •ํ•ฉ๋‹ˆ๋‹ค.
  • ๋†’์€ ์ •๋ฐ€๋„๋Š” ๋ชจ๋ธ์ด ์–‘์„ฑ์ด๋ผ๊ณ  ์˜ˆ์ธกํ•  ๋•Œ, ๊ทธ ์˜ˆ์ธก์ด ์‹ค์ œ๋กœ ๋งž์„ ๊ฐ€๋Šฅ์„ฑ์ด ๋†’๋‹ค๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.
  • ์ •๋ฐ€๋„๋Š” ์–‘์„ฑ ์˜ˆ์ธก์ด ์ค‘์š”ํ•œ ๊ฒฝ์šฐ(์˜ˆ: ์ŠคํŒธ ๋ฉ”์ผ ํ•„ํ„ฐ๋ง์—์„œ ์ŠคํŒธ์œผ๋กœ ์ž˜๋ชป ๋ถ„๋ฅ˜๋œ ์ •์ƒ ๋ฉ”์ผ์ด ์ ์–ด์•ผ ํ•˜๋Š” ๊ฒฝ์šฐ) ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.

  • ์žฌํ˜„์œจ (Recall): ์‹ค์ œ ์–‘์„ฑ ์ค‘ ๋งž์ถ˜ ๋น„์œจ์ž…๋‹ˆ๋‹ค. ์ฆ‰, ๋ชจ๋ธ์ด ์‹ค์ œ ์–‘์„ฑ ๋ฐ์ดํ„ฐ๋ฅผ ์–ผ๋งˆ๋‚˜ ์ž˜ ์ฐพ์•„๋‚ด๋Š”์ง€๋ฅผ ์ธก์ •ํ•ฉ๋‹ˆ๋‹ค.
  • ๋†’์€ ์žฌํ˜„์œจ์€ ๋ชจ๋ธ์ด ์‹ค์ œ ์–‘์„ฑ ๋ฐ์ดํ„ฐ๋ฅผ ์ž˜ ๋†“์น˜์ง€ ์•Š๋Š”๋‹ค๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.
  • ์žฌํ˜„์œจ์€ ์–‘์„ฑ ๋ฐ์ดํ„ฐ์˜ ํƒ์ง€๊ฐ€ ์ค‘์š”ํ•œ ๊ฒฝ์šฐ(์˜ˆ: ์•” ์ง„๋‹จ์—์„œ ์•” ํ™˜์ž๋ฅผ ๋†“์น˜์ง€ ์•Š๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•œ ๊ฒฝ์šฐ) ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค.

  • F1-์ ์ˆ˜ (F1-Score): ์ •๋ฐ€๋„์™€ ์žฌํ˜„์œจ์˜ ์กฐํ™” ํ‰๊ท ์ž…๋‹ˆ๋‹ค. ์ •๋ฐ€๋„์™€ ์žฌํ˜„์œจ ์‚ฌ์ด์˜ ๊ท ํ˜•์„ ์ธก์ •ํ•ฉ๋‹ˆ๋‹ค.
  • F1-์ ์ˆ˜๋Š” ์ •๋ฐ€๋„์™€ ์žฌํ˜„์œจ์˜ ๊ท ํ˜•์ด ์ค‘์š”ํ•œ ๊ฒฝ์šฐ์— ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
  • ํ•œ์ชฝ์ด ๋งค์šฐ ๋†’๊ณ  ๋‹ค๋ฅธ ์ชฝ์ด ๋‚ฎ์„ ๋•Œ, F1-์ ์ˆ˜๋Š” ์ด๋ฅผ ๋ฐ˜์˜ํ•˜์—ฌ ์ ์ ˆํ•œ ๊ท ํ˜•์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.


ํ˜ผ๋™ํ–‰๋ ฌ(Confusion Matrix) Example Code

# ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
data = load_breast_cancer()  # ์œ ๋ฐฉ์•” ๋ฐ์ดํ„ฐ์…‹์„ ๋กœ๋“œ
X = data.data  # ์ž…๋ ฅ ๋ฐ์ดํ„ฐ (ํŠน์ง•)
y = data.target  # ์ถœ๋ ฅ ๋ฐ์ดํ„ฐ (๋ ˆ์ด๋ธ”)

# ๋ฐ์ดํ„ฐ ๋ถ„ํ• 
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  
# ์ „์ฒด ๋ฐ์ดํ„ฐ๋ฅผ 80%๋Š” ํ•™์Šต ๋ฐ์ดํ„ฐ๋กœ, 20%๋Š” ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ ๋ถ„ํ• 

# ๋ชจ๋ธ ํ•™์Šต
nb = GaussianNB()  # ๋‚˜์ด๋ธŒ ๋ฒ ์ด์ฆˆ ๋ถ„๋ฅ˜๊ธฐ ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ
nb.fit(X_train, y_train)  # ํ•™์Šต ๋ฐ์ดํ„ฐ๋กœ ๋ชจ๋ธ ํ•™์Šต

# ์˜ˆ์ธก
y_pred = nb.predict(X_test)  # ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ๋กœ ์˜ˆ์ธก ์ˆ˜ํ–‰
y_pred_prob = nb.predict_proba(X_test)[:, 1]  # ๊ฐ ํด๋ž˜์Šค์— ๋Œ€ํ•œ ์˜ˆ์ธก ํ™•๋ฅ  ์ค‘ ์–‘์„ฑ ํด๋ž˜์Šค ํ™•๋ฅ  ์ถ”์ถœ

# ํ˜ผ๋™ ํ–‰๋ ฌ
conf_matrix = confusion_matrix(y_test, y_pred)  # ์‹ค์ œ ๋ ˆ์ด๋ธ”๊ณผ ์˜ˆ์ธก ๋ ˆ์ด๋ธ”์„ ๋น„๊ตํ•˜์—ฌ ํ˜ผ๋™ ํ–‰๋ ฌ ์ƒ์„ฑ
print(f'Confusion Matrix:\n{conf_matrix}')  # ํ˜ผ๋™ ํ–‰๋ ฌ ์ถœ๋ ฅ
Confusion Matrix:
[[40  3]
 [ 0 71]]
# ์‹œ๊ฐํ™” - ํ˜ผ๋™ ํ–‰๋ ฌ
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()


ROC ๊ณก์„ ๊ณผ AUC

ROC (Receiver Operating Characteristic) ๊ณก์„ 

๋ชจ๋ธ์˜ ๋ถ„๋ฅ˜ ์ž„๊ณ„๊ฐ’์„ ๋ณ€ํ™”์‹œํ‚ค๋ฉฐ True Positive Rate(์žฌํ˜„์œจ)์™€ False Positive Rate๋ฅผ ๋น„๊ตํ•˜๋Š” ๊ณก์„ ์ž…๋‹ˆ๋‹ค.
  • True Positive Rate (TPR): ์žฌํ˜„์œจ๊ณผ ๋™์ผํ•ฉ๋‹ˆ๋‹ค.
  • False Positive Rate (FPR): ์‹ค์ œ ์Œ์„ฑ ์ค‘ ์ž˜๋ชป ์–‘์„ฑ์œผ๋กœ ์˜ˆ์ธก๋œ ๋น„์œจ์ž…๋‹ˆ๋‹ค.

  • ROC ๊ณก์„ ์€ FPR์„ x์ถ•, TPR์„ y์ถ•์— ๋†“๊ณ  ๊ทธ๋ฆฝ๋‹ˆ๋‹ค.
  • ์ž„๊ณ„๊ฐ’์„ ๋ณ€ํ™”์‹œํ‚ค๋ฉฐ FPR๊ณผ TPR์˜ ๋ณ€ํ™”๋ฅผ ๊ด€์ฐฐํ•ฉ๋‹ˆ๋‹ค.
  • ์ตœ์ ์˜ ๋ชจ๋ธ์€ ROC ๊ณก์„ ์ด ์™ผ์ชฝ ์ƒ๋‹จ ๋ชจ์„œ๋ฆฌ์— ๊ฐ€๊นŒ์šด ํ˜•ํƒœ๋ฅผ ๊ฐ€์ง‘๋‹ˆ๋‹ค.

AUC (Area Under the Curve)

  • ROC ๊ณก์„  ์•„๋ž˜์˜ ๋ฉด์ ์„ ๋‚˜ํƒ€๋‚ด๋ฉฐ, ๋ชจ๋ธ์˜ ์ „๋ฐ˜์ ์ธ ์„ฑ๋Šฅ์„ ํ‰๊ฐ€ํ•˜๋Š” ์ง€ํ‘œ์ž…๋‹ˆ๋‹ค
  • AUC ๊ฐ’์ด 1์— ๊ฐ€๊นŒ์šธ์ˆ˜๋ก ์ข‹์€ ๋ชจ๋ธ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.
  • AUC = 1: ์™„๋ฒฝํ•œ ๋ชจ๋ธ
  • AUC = 0.5: ๋žœ๋ค ์ถ”์ธก
  • AUC < 0.5: ๋ชจ๋ธ ์„ฑ๋Šฅ์ด ๋žœ๋ค ์ถ”์ธก๋ณด๋‹ค ๋‚˜์จ
  • ๋งŒ์•ฝ, AUC๊ฐ€ 0.9์ธ ๊ฒฝ์šฐ: ๋ชจ๋ธ์ด 90%์˜ ํ™•๋ฅ ๋กœ ์–‘์„ฑ ์˜ˆ์ธก๊ณผ ์Œ์„ฑ ์˜ˆ์ธก์„ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ๊ตฌ๋ณ„ํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.

 


ROC ๊ณก์„ ๊ณผ AUC Example Code

# ROC ๋ฐ AUC ๊ณ„์‚ฐ
fpr, tpr, _ = roc_curve(y_test, y_pred_prob)  # ๊ฑฐ์ง“ ๊ธ์ • ๋น„์œจ(fpr)๊ณผ ์ง„์งœ ๊ธ์ • ๋น„์œจ(tpr) ๊ณ„์‚ฐ
roc_auc = roc_auc_score(y_test, y_pred_prob)  # AUC (๊ณก์„  ์•„๋ž˜ ๋ฉด์ ) ๊ณ„์‚ฐ

# AUC ์ ์ˆ˜ ์ถœ๋ ฅ
print(f'ROC AUC Score: {roc_auc}')  # ROC AUC ์ ์ˆ˜ ์ถœ๋ ฅ

# ROC AUC Score: 0.9983622666229938
# ์‹œ๊ฐํ™” - ROC ๊ณก์„ 
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')  
# ROC ๊ณก์„ ์„ ํŒŒ๋ž€์ƒ‰ ์„ ์œผ๋กœ ๊ทธ๋ฆฌ๋ฉฐ AUC ๊ฐ’์„ ๋ ˆ์ด๋ธ”๋กœ ์ถ”๊ฐ€

plt.plot([0, 1], [0, 1], color='grey', lw=2, linestyle='--')  # ๋Œ€๊ฐ์„  ๊ธฐ์ค€์„  ์ถ”๊ฐ€

plt.xlim([0.0, 1.0])  # x์ถ• ๋ฒ”์œ„ ์„ค์ •
plt.ylim([0.0, 1.05])  # y์ถ• ๋ฒ”์œ„ ์„ค์ •
plt.xlabel('False Positive Rate')  # x์ถ• ๋ ˆ์ด๋ธ” ์„ค์ •
plt.ylabel('True Positive Rate')  # y์ถ• ๋ ˆ์ด๋ธ” ์„ค์ •
plt.title('Receiver Operating Characteristic (ROC) Curve')  # ๊ทธ๋ž˜ํ”„ ์ œ๋ชฉ ์„ค์ •
plt.legend(loc="lower right")  # ๋ฒ”๋ก€ ์œ„์น˜ ์„ค์ •
plt.show()  # ๊ทธ๋ž˜ํ”„ ์ถœ๋ ฅ