あれもPython,これもPython

Pythonメモ※本サイトはアフィリエイトを利用しています

lightgbmをoptunaで楽にチューニングしたい

optunaは下準備にコードを書く必要がありますが、lightgbmが対象の場合は、より簡単なコードで処理することができます。

コード

lightbgmでtrainしていたところを、optuna.integration.lightgbmに変えるだけです。

▼lightgbmに関してはこちら

esu-ko.hatenablog.com

下準備

import lightgbm
import optuna.integration.lightgbm as lgbo
from sklearn.model_selection import train_test_split

# 回帰の場合
opt_params = {
    "objective":"regression",
    "metric":"rmse"
}

# データを用意する
 X_train,X_test,y_train,y_test = train_test_split(
                boston_X,
                boston_y,
                test_size=0.2)
    
reg_train = lgb.Dataset(
    X_train,
    y_train
)

reg_eval = lgb.Dataset(
    X_test,
    y_test,
    reference=reg_train
)

パラメータを探す

#パラメータを探す
opt=lgbo.train(
    opt_params,
    reg_train,
    valid_sets = reg_eval,
    verbose_eval=False,
    #ラウンド数
    num_boost_round = 5,
    #打ち切り
    #early_stopping_rounds = 100
)

パラメータを確認

#最適なパラメータを取得
opt.params
    {'objective': 'regression',
     'metric': 'rmse',
     'lambda_l1': 0.06827725234472487,
     'lambda_l2': 0.0026080375045565317,
     
     'num_leaves': 31,
     'feature_fraction': 1.0,
     'bagging_fraction': 1.0,
     'bagging_freq': 0,
     'min_child_samples': 5}

あとはこれをモデルに渡し、再学習します。

ベイズ最適でハイパラチューニングをしたい(optunaを使う)

ハイパーパラメータのチューニングをする場合、グリッドサーチでは非常に時間がかかることがあります。
対して、ベイズを用いたハイパーパラメータチューニングは非常に早く実行できます。

今回はそのためにoptunaを使用します。

基本的な使い方

最大化、最小化する目的関数を作成し、それを最適化します。

目的関数の作成

探索するハイパーパラメータはtrial.xxxで型、範囲、名前を定義します。

import optuna
from sklearn.metrics import accuracy_score

def objective(trial):
    params = {
        "min_samples_split":trial.suggest_int("min_samples_split", 8, 16),
        "max_leaf_nodes":int(trial.suggest_discrete_uniform("max_leaf_nodes", 4, 64, 4)),
        "criterion":trial.suggest_categorical("criterion", ["gini", "entropy"])
    }
    
    rfc = RandomForestClassifier(**params)
    rfc.fit(rfc_data['X_train'], rfc_data["y_train"])
    #return #1.0 - accuracy_score(rfc_data["y_test"], rfc.predict(rfc_data["X_test"]))
    return accuracy_score(rfc_data["y_test"], rfc.predict(rfc_data["X_test"]))

プロセスの定義

directionで最小化か、最大化か、過程をsqliteに書き込むかなどを指定し、optimizeで実行します。

study_name = 'rfc_optimize'
s = optuna.create_study(
    study_name = study_name,
    direction='maximize',
    storage='sqlite:///example.db',
    load_if_exists=True #途中までの結果があればそれを利用する

)
s.optimize(objective,n_trials=10)

#結果の確認
s.best_params
s.best_value

その他便利な物

sqliteに入れた場合は、sqlite3で確認できます。

sqlite3 example.db
.table
select * from trials

また、studyからデータフレームで確認したり、可視化をすることもできます。(plotlyのextensionが必要になります)

s.trials_dataframe(attrs=('number', 'value', 'params', 'state'))
#ハイパーパラメータの重要度
optuna.visualization.plot_param_importances(s)
#過程の可視化
optuna.visualization.plot_optimization_history(s)

予測の解釈がしたい(SHAPを使う)

Limeでレコード単位での予測の把握をすることはできますが、それを全体化したものを見てみたい場合があります。 そうした時はSHAPを使うと便利です。

コード

セットアップ

import shap
shap.initjs()
#ランダムフォレストなどのモデルを渡す
ex = shap.TreeExplainer(mdl)
#トレインデータを渡す
shap_v = ex.shap_values(boston_X)

可視化

個別データの可視化

まずはLimeと同様に個別のデータを見てみます

shap.force_plot(
    ex.expected_value,
    shap_v[15,:],
    boston_train_X[15,:], #対象レコード
    feature_names = boston['feature_names']
)

f:id:esu-ko:20200920104314p:plain

各レコードを同時に可視化してみます。 元の順番や、目的変数の大きさで並べることができます。

shap.force_plot(
    base_value=ex.expected_value,
    shap_values=shap_v, 
    feature_names=boston['feature_names']
)

f:id:esu-ko:20200920104354p:plain f:id:esu-ko:20200920104406p:plain]

特徴そのものの確認

特徴のインパクトの大きさや、特徴量内の大きさがどう影響しているのかを平面で表現できます。

#特徴量のインパクト
shap.summary_plot(
    shap_v,
    boston_test_X, #予測データ
    plot_type="bar",
    feature_names = boston['feature_names']
)

shap.summary_plot(
    shap_v,
    boston_test_X,
    feature_names = boston['feature_names']
)

f:id:esu-ko:20200920104436p:plain f:id:esu-ko:20200920104449p:plain

個別の変数の増減、他の変数との関係、SHAP値との関係も見ることができます。

shap.dependence_plot(
    ind='ZN',
    interaction_index = 'AGE',
    shap_values=shap_v,
    #features=pd.DataFrame(boston_X,columns =load_boston()['feature_names'])
    features = boston_test_X,
    feature_names = boston['feature_names']
)

f:id:esu-ko:20200920104502p:plain ZNが40くらいまでは加工し、そこからは上昇していき、その群はAGEが40以下くらい、といったところでしょうか。