あれもPython,これもPython

Pythonメモ※本サイトはアフィリエイトを利用しています

レコード単位の予測の解釈がしたい(Limeを使う)

木構造モデルを作った後、変数のモデル内での影響度は簡単に出すことができます。 ただし、lm,glmと異なり、実際に各レコードの予測値に各変数が影響しているのかを確認するのはデフォルト機能では難しいです。

そこでLimeを使うことで、実際にレコード単位での予測値への影響を可視化することができます。

コード

今回はテーブル構造データに対して使ってみます。 作ったモデルの予測メソッドと予測用の特徴量、対象とするレコードを与えることで、可視化をすることができます

下準備

下準備

import pandas as pd
from sklearn.datasets import load_iris,load_boston

iris_y = load_iris()['target']
iris_X = load_iris()['data']

boston_y = load_boston()['target']
boston_X = load_boston()['data']

import lime
import lime.lime_tabular

from sklearn.ensemble import RandomForestClassifier,RandomForestRegressor

def data_split(X,y):
    X_train,X_test,y_train,y_test = train_test_split(
        X,
        y,
        test_size = 0.3  
    )

    return {
        "X_train":X_train,
        "X_test":X_test,
        "y_train":y_train,
        "y_test":y_test   
    }

分類モデル

現在Limeは分類の場合は、各分類のprobaを与える必要があります。

予測モデルの作成

rfc_data = data_split(iris_X,iris_y)
rfc = RandomForestClassifier()
rfc.fit(rfc_data['X_train'],rfc_data['y_train'])

rfc_res = pd.DataFrame(
    {
        'pred':rfc.predict(rfc_data['X_test']),
        'obs':rfc_data['y_test']
    }
)

#不正解のデータを探す
rfc_res[rfc_res['pred']!=rfc_res['obs']]

Limeの使用

lime_ex = lime.lime_tabular.LimeTabularExplainer(
    training_data= rfc_data['X_test'],
    feature_names = load_iris()['feature_names'],
    mode='classification' #or_regression
)

lime_res = lime_ex.explain_instance(
    rfc_data['X_test'][43],
    rfc.predict_proba #proba
)

lime_res.show_in_notebook(show_table=True)

f:id:esu-ko:20200919221402p:plain

petalの二つに引っ張られているようです。

回帰

ほぼ同じ処理です。

rfr_data = data_split(boston_X,boston_y)
rfr = RandomForestRegressor()
rfr.fit(rfr_data['X_train'],rfr_data['y_train'])
rfr_res = pd.DataFrame(
    {
        'pred':rfr.predict(rfr_data['X_test']),
        'obs':rfr_data['y_test']
    }
)
rfr_res['diff']= (rfr_res['pred']-rfr_res['obs']).abs()
#よくなさそうなデータを確認
rfr_res.sort_values('diff',ascending=False).head(5)

lime_ex = lime.lime_tabular.LimeTabularExplainer(
    training_data= rfr_data['X_test'],
    feature_names = load_boston()['feature_names'],
    mode='regression' #or_regression
)
lime_res = lime_ex.explain_instance(
    rfr_data['X_test'][36],
    rfr.predict #proba
)
lime_res.show_in_notebook(show_table=True)

f:id:esu-ko:20200919221548p:plain

LSTATやRMに強く引っ張られているか、ネガティブ系でどれか弱いのがありそうです。