あれもPython,これもPython

Pythonメモ※本サイトはアフィリエイトを利用しています

レコード単位の予測の解釈がしたい(Limeを使う)

木構造モデルを作った後、変数のモデル内での影響度は簡単に出すことができます。 ただし、lm,glmと異なり、実際に各レコードの予測値に各変数が影響しているのかを確認するのはデフォルト機能では難しいです。

そこでLimeを使うことで、実際にレコード単位での予測値への影響を可視化することができます。

コード

今回はテーブル構造データに対して使ってみます。 作ったモデルの予測メソッドと予測用の特徴量、対象とするレコードを与えることで、可視化をすることができます

下準備

下準備

import pandas as pd
from sklearn.datasets import load_iris,load_boston

iris_y = load_iris()['target']
iris_X = load_iris()['data']

boston_y = load_boston()['target']
boston_X = load_boston()['data']

import lime
import lime.lime_tabular

from sklearn.ensemble import RandomForestClassifier,RandomForestRegressor

def data_split(X,y):
    X_train,X_test,y_train,y_test = train_test_split(
        X,
        y,
        test_size = 0.3  
    )

    return {
        "X_train":X_train,
        "X_test":X_test,
        "y_train":y_train,
        "y_test":y_test   
    }

分類モデル

現在Limeは分類の場合は、各分類のprobaを与える必要があります。

予測モデルの作成

rfc_data = data_split(iris_X,iris_y)
rfc = RandomForestClassifier()
rfc.fit(rfc_data['X_train'],rfc_data['y_train'])

rfc_res = pd.DataFrame(
    {
        'pred':rfc.predict(rfc_data['X_test']),
        'obs':rfc_data['y_test']
    }
)

#不正解のデータを探す
rfc_res[rfc_res['pred']!=rfc_res['obs']]

Limeの使用

lime_ex = lime.lime_tabular.LimeTabularExplainer(
    training_data= rfc_data['X_test'],
    feature_names = load_iris()['feature_names'],
    mode='classification' #or_regression
)

lime_res = lime_ex.explain_instance(
    rfc_data['X_test'][43],
    rfc.predict_proba #proba
)

lime_res.show_in_notebook(show_table=True)

f:id:esu-ko:20200919221402p:plain

petalの二つに引っ張られているようです。

回帰

ほぼ同じ処理です。

rfr_data = data_split(boston_X,boston_y)
rfr = RandomForestRegressor()
rfr.fit(rfr_data['X_train'],rfr_data['y_train'])
rfr_res = pd.DataFrame(
    {
        'pred':rfr.predict(rfr_data['X_test']),
        'obs':rfr_data['y_test']
    }
)
rfr_res['diff']= (rfr_res['pred']-rfr_res['obs']).abs()
#よくなさそうなデータを確認
rfr_res.sort_values('diff',ascending=False).head(5)

lime_ex = lime.lime_tabular.LimeTabularExplainer(
    training_data= rfr_data['X_test'],
    feature_names = load_boston()['feature_names'],
    mode='regression' #or_regression
)
lime_res = lime_ex.explain_instance(
    rfr_data['X_test'][36],
    rfr.predict #proba
)
lime_res.show_in_notebook(show_table=True)

f:id:esu-ko:20200919221548p:plain

LSTATやRMに強く引っ張られているか、ネガティブ系でどれか弱いのがありそうです。

PythonでCatboostを使いたい

Catboostはboosting系の中で、やや重いですが。カテゴリをそのまま扱えるといった特徴があります。

xgboost,lightgbm同様に専用のパッケージをインストールします。

コード

catboostは、Poolという形でデータをできるかぎり扱うようにすると楽です。
今回も基本的な動きの確認のため、評価データを作りませんでした。

from catboost import CatBoost,Pool

回帰

#回帰
reg_train = Pool(
    boston_X,
    boston_y
)

reg_mdl = CatBoost()
reg_mdl.fit(
    reg_train,
    verbose=False,
    #jupyter上で学習過程を可視化してくれる
    #plot=True  
)

分類

# 多値分類問題

mc_train = Pool(
    iris_X,
    iris_y
)

mc_mdl = CatBoost(
    {
        'loss_function': 'MultiClass'
    }
)
mc_mdl.fit(
    mc_train,
    verbose=False,
    #plot=True  
)

予測

予測で帰ってくる値は`prediction_typeで調整します。

#Classだとクラスを、Probabilityだと各確率を返してくれる
mc_mdl.predict(iris_X,prediction_type='Probability')

特徴重要度の可視化

作成したモデルの特徴重要度はfeature_importances_で見れます。
木構造を可視化してくれるメソッドも存在しますが、何番目の木か、を指定する必要があります。

mc_mdl.feature_importances_

mc_mdl.plot_tree(tree_idx=1)

f:id:esu-ko:20200919222655p:plain

Pythonでxgboostを使う

xgboostはlightgbmに比べると、やや重い印象があります。
ただし、ケースによってはlightgbm以上の精度が出ることもあります。

lightgbm同様、sklearnには導入されていません。

コード

xgboostもsklearnライクな使い方もありますが、デフォルトの方で使ってみます。

基本的にはlightgbmと同じですが、オブジェクトに渡す文字列や、専用のデータセットの作り方など若干の近いがあります。(lightgbmの方が後発ですが、、、)

今回も基本的な動きの確認のため、評価データを作りませんでした。

import xgboost as xgb

回帰

#回帰
dtrain = xgb.DMatrix(boston_X, label=boston_y)
reg_mdl = xgb.train(
    {"objective":"reg:squarederror"},
    dtrain
)

#予測
#reg_mdl.predict(dtrain)

分類

dtrain = xgb.DMatrix(iris_X, iris_y)
mc_mdl = xgb.train(
    {
        "objective":"multi:softmax",
        "num_class":3
    },
    dtrain
)
#予測
#mc_mdl.predict(dtrain)

特徴重要度の可視化

作成したモデルはfeature_importances_になります。 また、可視化を直接してくれるメソッド、木構造を可視化してくれるメソッドも存在します。

xgb_mdl.feature_importances_
xgb.plot_tree(reg_mdl)
xgb.plot_importance(mc_mdl)

f:id:esu-ko:20200919222107p:plain

f:id:esu-ko:20200919222120p:plain