sklearnだけで決定木の可視化をしたい(plot_treeを使う)

Rでは決定木の可視化は非常に楽だが、Pythonでは他のツールを入れながらでないと、、、と昔は大変だったのですが、現在ではsklearnのplot_treeだけで簡単に表示できるようになっています。

さらにplot_treeはmatplotlibと同様に操作できるため、pandasなどに慣れている人はカスタムも楽になっています。

やってみる

from sklearn.tree import DecisionTreeClassifier,plot_tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt


#データと決定木の準備
clf = DecisionTreeClassifier()
clf.fit(X=iris['data'],y=iris['target'])


#可視化を行う
# matplotlibで画像サイズの調整
plt.figure(figsize=(20,20))
#描画
plot_tree(
    clf,
    max_depth=2,
    proportion=True,
    feature_names = iris['feature_names'],
    filled=True,
    rotate=False,
    precision=2,
    fontsize=12
    #ax
    #class_names = あまり使わない
    )
#matplotlibで保存
plt.savefig("tree.pdf")

f:id:esu-ko:20200819233146p:plain

箱の大きさはfontsizeなどで勝手に変わるようです。 日本語を使う場合などは、japanize-matplotlibなどと組み合わせる必要があるかもしれません。

引数など

特に以下の5つは探索的に試す際に便利です。

  • max_depth : rootからいくつまで可視化するか
  • proportion : データの数を比率で出すか
  • filled:うまく分類できた場合色をつけるか
  • fontsize:文字のサイズ
  • precision:小数点以下をいくつか

また、axはグループごとに木をつくるときに表示をいじる際に便利です。