Рассмотрим пример построения дерева решений и работы модели на примере классификации цветков Ириса:
from sklearn.datasets import load_iris
iris_df = load_iris(as_frame=True)['frame']
iris_df.head()
Обучим классификатор:
from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier(random_state=0).fit(iris_df.drop(columns='target'), iris_df.target)
features_l = iris_df.drop(columns='target').columns.tolist()
Визуализация дерева
В модуле sklearn.tree есть функция plot_tree, с которой можно легко нарисовать дерево, для каждого узла включается признак ветвления, граница, загрязненность, количество примеров всего и их распределение по классам:
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
plt.figure(figsize=(18,7))
_ = plot_tree(model, feature_names = features_l)
Экспорт в текст
Есть и текстовое представление того же дерева, которое можно получить при помощи export_text:
from sklearn.tree import export_text
print(export_text(model, feature_names=features_l))
Атрибуты работы
Более детальные единицы работы алгоритма извлекаются из свойств атрибута tree_: номера левого и правого узлов ветвления (children_left, children_right) и "загрязненность" (impurities) для каждого узла, номера признаков ветвления (features), соответствующие им границы (threshold). Проще поместить эти элементы в один датафрейм и получится целостная картина работы дерева:
import pandas as pd
children_left = model.tree_.children_left
children_right = model.tree_.children_right
features = model.tree_.feature
names = [features_l[i] if i>0 else '' for i in features ]
thres = model.tree_.threshold
impurities = model.tree_.impurity
pd.DataFrame({'children_left':children_left, 'children_right':children_right,
'names':names, 'features':features, 'thresholds':thres,
'impurities':impurities})