Rでは決定木の可視化は非常に楽だが、Pythonでは他のツールを入れながらでないと、、、と昔は大変だったのですが、現在ではsklearnのplot_treeだけで簡単に表示できるようになっています。
さらにplot_tree
はmatplotlibと同様に操作できるため、pandasなどに慣れている人はカスタムも楽になっています。
やってみる
from sklearn.tree import DecisionTreeClassifier,plot_tree from sklearn.datasets import load_iris import matplotlib.pyplot as plt #データと決定木の準備 clf = DecisionTreeClassifier() clf.fit(X=iris['data'],y=iris['target']) #可視化を行う # matplotlibで画像サイズの調整 plt.figure(figsize=(20,20)) #描画 plot_tree( clf, max_depth=2, proportion=True, feature_names = iris['feature_names'], filled=True, rotate=False, precision=2, fontsize=12 #ax #class_names = あまり使わない ) #matplotlibで保存 plt.savefig("tree.pdf")
箱の大きさはfontsizeなどで勝手に変わるようです。
日本語を使う場合などは、japanize-matplotlib
などと組み合わせる必要があるかもしれません。
引数など
特に以下の5つは探索的に試す際に便利です。
- max_depth : rootからいくつまで可視化するか
- proportion : データの数を比率で出すか
- filled:うまく分類できた場合色をつけるか
- fontsize:文字のサイズ
- precision:小数点以下をいくつか
また、axはグループごとに木をつくるときに表示をいじる際に便利です。