【模型評估指標】機率校正曲線圖Probability Calibration Curves

你的模型能夠預測的多好,除了得到準確度(Accuracy)以外,有什麼樣的方式可以描述分類模型的預測表現,甚至是畫出容易讓非相關背景的人,也能一眼看出預測能力最佳模型的方法呢?

當然有的!!!

我們可以使用scikit-learn當中的Probability calibration(機率校正)來獲得模型對於預測表現的信心,一般來說,我們會應用此Probability calibration在二元分類的模型上,並且利用特定模型當中的predict_proba()方法來得到預測該類別的機率。

校正曲線(Calibration Curve)也稱做標準曲線,原本是在化學分析中非常重要的工具,用於將量測的值與原始已知的濃度(也就是標準品的濃度)進行分析;而在機器學習領域當中,我們也會使用此概念來對應我們想要預測目標。

機率校正曲線圖的組成

機率校正曲線圖是在機器學習中用來評估分類模型預測機率準確度的工具,它能幫助我們理解模型輸出的機率分佈與實際結果之間的關係。

圖片來源: scikit-learn Probability Calibration curves

本文將會分別解釋X軸、Y軸以及完全準確的校正曲線圖,我們就可以知道分類模型在預測上的信心水準囉!

X軸:預測機率(Predicted Probability)

  • 表示分類模型預測輸出的機率,範圍在0到1之間。
  • 一般我們會將想要預測的類設定為class: 1(Positive class)。

Y軸:實際機率(Observed/True Probability)

  • 表示該模型預測的機率值在實際當中有多準確。

完全準確的校正曲線(Perfect Calibration Line)

  • 是一條對角線從(0,0)延伸到(1,1),也就是上圖灰色虛線從最左下角到最右上角。
  • 若是建立的模型完全貼合此虛線,則表示分類模型預測的機率非常準確。

模型建立後繪製出的校準曲線

圖片來源: scikit-learn Probability Calibration curves
  • 通過將預測結果分成多個區間(bins),圖片中是將每0.1做為一個區間,計算每個區間內的平均預測機率與實際機率來繪製機率校正曲線圖。
  • 而預測模型的曲線偏離虛線的程度也就說明模型預測的機率偏離實際機率的程度。

機率校正曲線圖的分析流程

校正曲線圖的目的

主要評估模型的機率預測是否準確,當模型在說明某事件發生的機率時,我們希望這個機率與實際發生的頻率越相近越好。

校正曲線圖的計算過程

  1. 分箱:將模型的預測機率分成多個區間(bins)。
  2. 計算:對於每個區間,計算該區間的平均預測機率和實際觀察到的事件發生的比例。
  3. 繪圖:在圖表中繪製這些計算的點,並連成一條預測曲線。

校正曲線圖的分析

  • 預測曲線的低估或高估:如果曲線在對角線下方,模型的預測機率偏低(低估了發生的機率);如果在對角線上方,則表示偏高(高估了發生的機率)。
  • 完美校正:如果模型的預測曲線與對角線(虛線)完全重合,表示模型的預測機率與實際情況完全吻合。

校正曲線圖的python範例

接下來會說明如何使用生成的資料集進行隨機森林(Random Forest)和支持向量分類器(SVC)在二元分類數據集上圖表的產出,可以線上執行的範例我會放在我的colab上,歡迎使用

# 安裝需要使用到的套件
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from matplotlib.gridspec import GridSpec
from sklearn.calibration import CalibrationDisplay
import matplotlib.pyplot as plt

# 生成二元分類資料集
X, y = make_classification(n_samples=500, n_features=30, n_informative=15, n_redundant=5, n_classes=2, random_state=28)

# 將資料集分為訓練集(90%)和測試集(10%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=28)

# 建立並擬合隨機森林(Random Forest)模型
rf_model = RandomForestClassifier(n_estimators=100, random_state=28)
rf_model.fit(X_train, y_train)

# 建立並擬合支持向量分類器(Support Vector Classification)模型
svc_model = SVC(probability=True, random_state=28)  # 開啟機率估計
svc_model.fit(X_train, y_train)

# 使用模型的predict_proba()來預測正分類的機率
rf_prob = rf_model.predict_proba(X_test)[:, 1]  # 隨機森林的正類機率
svc_prob = svc_model.predict_proba(X_test)[:, 1]  # 支持向量分類器的正類機率

以上的code我們已經建立了兩個具有機率預測能力的模型,接下來就是畫出機率校正曲線圖,也就是使用scikit-learn當中的CalibrationDisplay。

# 初始化圖表和網格
fig = plt.figure(figsize=(10, 10))
gs = GridSpec(4, 2, figure=fig)

# 建立校正曲線
ax_calibration_curve = fig.add_subplot(gs[:2, :2])

# 把隨機森林和支持向量分類器儲存成一個模型列表
clf_list = [(rf_model, "Random Forest"), (svc_model, "SVC")]

# 設定圖片顏色
colors = plt.colormaps['tab10']

# 繪製校正曲線圖
calibration_displays = {}
for i, (clf, name) in enumerate(clf_list):
    display = CalibrationDisplay.from_estimator(
        clf,
        X_test,
        y_test,
        n_bins=10,
        name=name,
        ax=ax_calibration_curve,
        color=colors(i),
    )
    calibration_displays[name] = display

# 設定完全校正的虛線
ax_calibration_curve.plot([0, 1], [0, 1], 'k--')  
ax_calibration_curve.grid()
ax_calibration_curve.set_title("Calibration plots")

# 繪製直方圖
grid_positions = [(2, 0), (2, 1), (3, 0), (3, 1)]
for i, (_, name) in enumerate(clf_list):
    row, col = grid_positions[i]
    ax = fig.add_subplot(gs[row, col])

    ax.hist(
        calibration_displays[name].y_prob,
        range=(0, 1),
        bins=10,
        label=name,
        color=colors(i),
        alpha=0.7
    )
    ax.set(title=name, xlabel="Mean predicted probability", ylabel="Count")

# 自動調整圖片佈局(讓圖片更緊湊)
plt.tight_layout()

# 顯示圖片
plt.show()

輸出圖表可以發現隨機森林模型(藍色線)在整體的分類預測上面,更加貼近完全校正的虛線,下方的直方圖可以幫助我們理解在不同的區域,兩種模型有較為相反的正分類預測結果。

最後的結果可以發現:

  • 隨機森林模型的預測機率分佈較為均勻,比較接近常態分佈,表示模型在對測試集進行預測時,大多數樣本的預測機率分佈在中間區域(約在0.2到0.8之間),而極端高(靠近1)或極端低(靠近0)的機率較少。
  • 支持向量分類器的預測結果則是在<0.1以及>0.9的部分有相當高的預測機率,這樣的模型表現則是較為不理想,會變成此模型在預測的時候,會有產生過度預測正分類(class1)或過度預測負分類(class0)的情形,這樣是在建立模型當中比較不希望看的結果。

參考資料

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *