ROC AUC Example with Scikit Plot
1. Example
1.1 Imports
from itertools import cycle
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, svm
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize
1.2 Data - Digits
# Seed NumPy's global generator so repeated runs give the same results.
np.random.seed(0)

# Load the digits dataset: feature matrix and integer class labels.
data = datasets.load_digits()
x_data, y_data = data.data, data.target
n_classes = np.unique(y_data).size

# Add noise: append 200 * n_features columns of standard-normal noise
# to the right of the original features.
random_state = np.random.RandomState(0)
n_samples, n_features = x_data.shape
x_data = np.hstack((x_data, random_state.randn(n_samples, 200 * n_features)))

# Shuffle and split into equally sized training and test sets.
x_train, x_test, y_train, y_test = train_test_split(
    x_data,
    y_data,
    test_size=0.5,
    random_state=0,
)

# Report the shape of every array so the split can be checked by eye.
for label, arr in (
    ("x_data :", x_data),
    ("y_data :", y_data),
    ('x_train:', x_train),
    ('y_train:', y_train),
    ('x_test :', x_test),
    ('y_test :', y_test),
):
    print(label, arr.shape)
print('unique labels:', len(np.unique(y_test)))
1.3 Model
from lightgbm import LGBMClassifier
# Hyper-parameters for the LightGBM multiclass classifier, kept in one
# place so they are easy to read and tweak.
lgbm_params = dict(
    random_state=0,
    n_estimators=2,
    num_leaves=2,
    max_depth=1,
    objective='multiclass',
)
model = LGBMClassifier(**lgbm_params)

# Train on the training split only; the test split is kept for evaluation.
model.fit(x_train, y_train)
1.4 Evaluation
from sklearn.metrics import accuracy_score, recall_score, precision_score, classification_report
# Hard class predictions and per-class probabilities for the test split.
y_pred = model.predict(x_test)
y_prob = model.predict_proba(x_test)

# Accuracy plus macro-averaged recall and precision
# (macro = unweighted mean of the per-class scores).
print('ACC :', accuracy_score(y_test, y_pred))
print('Recall :', recall_score(y_test, y_pred, average='macro'))
# Precision must come from precision_score; calling recall_score here
# would simply repeat the recall value under the wrong label.
print('Precision :', precision_score(y_test, y_pred, average='macro'))

# Per-class precision / recall / F1 / support table.
print(classification_report(y_test, y_pred))
1.5 ROC - Matplotlib
from sklearn.metrics import auc, roc_auc_score, roc_curve
# One-vs-rest ROC per class: class i is treated as the positive label and
# column i of y_prob as its score.
# NOTE(review): assumes labels are the integers 0..n_classes-1 in the same
# order as the columns of y_prob (true for the digits data loaded above).
fpr = [None] * n_classes
tpr = [None] * n_classes
roc_auc = [None] * n_classes
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test == i, y_prob[:, i])
    # auc() integrates a curve given its x and y coordinates, so it must be
    # fed the ROC curve (fpr, tpr), not the true and predicted label vectors.
    roc_auc[i] = auc(fpr[i], tpr[i])
def plot_roc(lw=0.7):
    """Plot one ROC curve per class on a single matplotlib figure.

    Reads the module-level ``fpr``, ``tpr`` and ``roc_auc`` lists (one entry
    per class) and draws each class curve together with a dashed diagonal
    reference line.

    Args:
        lw: Line width used for every class curve.

    Returns:
        The matplotlib ``Figure`` that holds the plot.
    """
    fig, ax = plt.subplots(1, figsize=(7, 5))

    # Dashed diagonal from (0, 0) to (1, 1) as a visual reference.
    ax.plot([0, 1], [0, 1], "k--")

    for i, (class_fpr, class_tpr, class_auc) in enumerate(zip(fpr, tpr, roc_auc)):
        # Legend entry reads e.g. "[3] AUC: 0.97".
        ax.plot(class_fpr, class_tpr, lw=lw, label=f"[{i}] AUC: {class_auc:.2f}")

    ax.set_ylabel('True Positive')
    ax.set_xlabel('False Positive')
    fig.legend(loc='lower right')
    return fig


# Draw the ROC figure with the default line width.
fig = plot_roc()
1.6 ROC - Scikit Plot
import scikitplot as skplt
# Recompute the class probabilities so this section can run on its own.
y_prob = model.predict_proba(x_test)

# Let scikit-plot draw its ROC curves onto an axes we create ourselves.
fig, ax = plt.subplots(1, figsize=(7, 5))
skplt.metrics.plot_roc(y_true=y_test, y_probas=y_prob, ax=ax)

# Thin every line on the axes so overlapping curves stay readable.
roc_lines = ax.get_lines()
for curve in roc_lines:
    curve.set_linewidth(0.5)