確率補正に必要なsklearn APIの取り付け

更新日:2021年5月12日

本連載では分類モデルの予測値を信頼性曲線にプロットしたり、 クラス確率に近づける確率補正について取り上げています。 前回、sklearn APIが元々備わっていたLightGBMモデルを対象にしましたが、 今回は、AutoGluon Tabular という AutoML を対象に、 確率補正に必要なAPIメソッドを取り付ける 具体例をご紹介します。

ライブラリの用意

sklearn を使って信頼性曲線を書いたり確率補正します。 図形は Matplotlib と Plotly で作ります。 補正対象のモデルとしては AutoGluon Tabular を使います。 AutoGluon内部では lightgbm, catboost, xgboost などが使われます。

Condaやpipでそれぞれインストールできます。

サンプルデータ

Adult Census Income (国勢調査の成人収入) を使います。5万ドル以上の年収があるかどうかを分類するデータセットです。 このデータは元々正例の割合が少なく、偏っていますが、 更にノイズを投入しました。

import numpy as np
import shap
import sklearn
from sklearn.model_selection import train_test_split

## Census income
X, y = shap.datasets.adult()
X = X.values

print("Original dataframe shape", X.shape)
n_samples, n_features = X.shape

# Add noise
random_state = np.random.RandomState(0)
X = X + 4 * random_state.randn(n_samples, n_features)
X = np.c_[X, random_state.randn(n_samples, 100 * n_features)]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=random_state
)
print("Noisy dataframe shape", X.shape)
print("Classes", np.unique(y))

X_test, X_calib, y_test, y_calib = train_test_split(
    X_test,
    y_test,
    test_size=4000 / len(X_test),
    random_state=random_state,
)
print("学習・補正・テスト用データ比率", np.array([len(X_train), len(X_calib), len(X_test)]) / len(X))

from autogluon.tabular import TabularDataset
tr_data = TabularDataset(X_train)
tr_data["y"] = y_train
te_data = TabularDataset(X_test)
te_data["y"] = y_test
Original dataframe shape (32561, 12)
Noisy dataframe shape (32561, 1212)
Classes [False  True]
学習・補正・テスト用データ比率 [0.74997697 0.12284635 0.12717668]

モデル学習

AutoGluonモデルを学習データにフィットさせます。


from autogluon.tabular import TabularPredictor

save_path = "trained-model"
predictor = TabularPredictor(label="y", path=save_path).fit(
    tr_data, hyperparameters="toy", time_limit=30
)

Sklearnラッパー

sklearnを使って確率補正を行うために、必要なインタフェースを用意する。 試験データでメトリック AUC, Brier を図り、信頼性曲線を書いてみます。


from fastcore.basics import store_attr


class AutoGluonWrapper:
    """
    sklearnを使って信頼性曲線を描いたり、確率補正を行うために、
    必要なインタフェースを用意する。
    """

    def __init__(
        self,
        trained_model_path,  # AutoGluon学習済みモデルの保存パス
        classes_,  # sklearn APIに求められる属性
    ):
        store_attr()

    def load_model(self):
        """ AutoGluon学習済みモデルをロード """
        self.ag_model = TabularPredictor.load(self.trained_model_path)

    def fit(self):
        """ sklearn API に求められるメソッド """
        return True

    def predict_proba(self, X):
        """ sklearn API に求められるメソッド """
        X = TabularDataset(X)
        proba = self.ag_model.predict_proba(X)
        return proba.values
ag_ = AutoGluonWrapper(save_path, classes_=np.unique(y))
ag_.load_model() 

信頼性曲線

試験データに対してメトリック AUC, Brier を図り、信頼性曲線を書いてみます。 前回定義したメソッド plot_calibration_curve を使います。


from kowaza.proba_calib import plot_calibration_curve
plot_calibration_curve(dict(AutoGluon=ag_), X_test, y_test)


確率補正の実施

補正するためにsklearn.calibration. CalibratedClassifierCVを使います。 cv = "prefit" と指定することによって、ベースモデルが学習済みであり、補正モデルに渡すデータは全量補正用であることを伝えます。

今回は確率補正方法として sigmoid と isotonic をそれぞれ使ってみましょう。



from sklearn.calibration import CalibratedClassifierCV, calibration_curve
sigmoid = CalibratedClassifierCV(ag_, cv="prefit", method="sigmoid")
sigmoid.fit(X_calib, y_calib)
isotonic = CalibratedClassifierCV(ag_, cv="prefit", method="isotonic")
isotonic.fit(X_calib, y_calib)

補正前と補正後のモデルの信頼性曲線を同じ図に書いて比較してみましょう。


plot_calibration_curve(
    dict(
        AutoGluon=ag_,
        Sigmoid=sigmoid,
        Isotonic=isotonic,
    ),
    X_test,
    y_test,
)



補正の結果、信頼性曲線が改善され、Brierスコアも改善されました。 最も、対象データセットにノイズを投入せずにAutoGluonに学習させると、確率補正を行わなくてもきれいな信頼性曲線が得られます。 しかし、実運用では、データに必ずノイズが含まれるし、データの分布も時間とともに少しずつ変化していくものなので、信頼性曲線をプロットして必要に応じて確率補正を行う必要があります。

まとめ

第1回目に、分類モデルの出力値が必ずしもクラス確率とは限らないので、信頼性曲線を確認したり、確率補正を行ってみました。 今回は、sklearnが規定しているインタフェースを持たない学習済みモデルの補正実験を行いました。 次回は、確率補正に関する背景的な理論について書ければと思います。 実行可能な Jupyter Notebook はこちらです。



閲覧数:54回0件のコメント

最新記事

すべて表示

スパムメールやクレジット詐欺を見分けるタスクなどを学習した 分類モデルが出力する予測値は通常 (0, 1) の範囲内に収まり、 予測確率とも呼ばれるので、 うっかり正例であるクラス確率だと 思い込みかねません。 実運用では、閾値を設けて、予測値がその閾値を超えるかどうかで...