Grid Search в Машинном обучении простыми словами

116 прочитали
Изображение: Midjourney
Изображение: Midjourney

Алгоритм поиска по сетке – это метод подбора оптимальных гиперпараметров для Модели (Model) путем перебора всех возможных комбинаций значений Гиперпараметров (Hyperparameter) из заданного набора. Гиперпараметры – это параметры модели, которые не оптимизируются во время процесса обучения, а задаются до его начала. Их оптимальный выбор влияет на качество и обобщающую способность модели.

Допустим, мы создали Дерево решений (Decision Tree) для банковского кредитного датасета. С полным кодом модели вы можете ознакомиться в этом ноутбуке. У первой версии дерева следующие характеристики эффективности:

>>> print('Доля правильных ответов: %.3f' % tr.score(X_test, y_test))
>>> print('Доля правильных ответов во время кросс-валидации: %0.3f' % cv_tr)
>>> print('Точность результата измерений: %.3f' % precision_score(y_test, tr_pred))
>>> print('Полнота: %.3f' % recall_score(y_test, tr_pred))
>>> print('Оценка F1: %.3f' % f1_score(y_test, tr_pred))
... Доля правильных ответов: 0.910
... Доля правильных ответов во время кросс-валидации: 0.908
... Точность результата измерений: 0.946
... Полнота: 0.953
... Оценка F1: 0.950

Но мы хотим автоматически подобрать наилучшие параметры. С этим нам поможет модуль scikit-learn под названием GridSearchCV:

from sklearn.model_selection GridSearchCV

Выберем параметры, значения которых будем перебирать и инициируем экземпляр GridSearchCV:

parameters = {'criterion':['gini','entropy'],
'max_depth':[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
}

default_tr = tree.DecisionTreeClassifier(random_state=25)
gs_tree = GridSearchCV(default_tr, parameters, cv=10, n_jobs=-1,verbose=1)
gs_tree.fit(X_train, y_train)

Запустим обучение дерева решений, чтобы подобрать значения гиперпараметров:

gs_tree_pred = gs_tree.predict(X_test)

Выведем оптимизированные значения:

>>> print('Лучшие параметры дерева решений: {}'.format(gs_tree.best_params_))
>>> print('Доля правильных ответов: %0.3f' % (gs_tree.score(X_test,y_test)))
>>> print('Доля правильных ответов кросс-валидации: %0.3f' % gs_tree.best_score_)
>>> print('Точность: %.3f' % precision_score(y_test, gs_tree_pred))
>>> print('Полнота: %.3f' % recall_score(y_test, gs_tree_pred))
>>> print('F1-мера: %.3f' % f1_score(y_test, gs_tree_pred))

... Лучшие параметры дерева решений: {'criterion': 'gini', 'max_depth': 5}
... Доля правильных ответов: 0.915
... Доля правильных ответов кросс-валидации: 0.913
... Точность: 0.936
... Полнота: 0.970
... F1-мера: 0.953