Найти в Дзене
Властелин машин

Градиентный бустинг с библиотекой H2O

Рассмотрим, какие параметры имеет класс H2OGradientBoostingEstimator для решения задач машинного обучения, как обучать и тестировать модель. В качестве примера загрузим данные о заболеваемости раком через метод load_breast_cancer из sklearn (подробнее о получении датасетов с использованием библиотеки):

Разделим выборки, инициализируем h2o кластер (подробнее здесь), создадим фрейм и укажем колонки:

-2

Ниже привожу основные параметры модели, которые понадобятся для обучения:

  • ntrees - количество деревьев (по умолчанию 50);
  • max_depth - максимальная глубина (5);
  • min_rows - минимальное количество наблюдений в листе (10)
  • min_split_improvement - минимальное улучшение для разделения узла (1e-05);
  • learn_rate - скорость обучения (0.1);
  • learn_rate_annealing - множитель скорости обучения с каждой итерацией (1)
  • sample_rate - доля наблюдений для обучения каждого дерева (1);
  • col_sample_rate - доля признаков, участвующих при сплите дерева(1);
  • seed - инициализатор счетчика случайных чисел;
  • balance_classes - надо ли сбалансировать классы перед обучением (False);
  • model_id - имя модели (идентификатор, по которому модель можно найти в H2O Flow);
  • build_tree_one_node - флаг запуска вычислений только на одном узле (False). При малых датасетах лучше выставить, чтобы не задействовать ресурсы кластера и тратить время на взаимодействие между нодами;
  • distribution - задает тип задачи и распределение target-а. Предпочтительными значениями для бинарной классификации являются bernoulli, многоклассовой - multinomial (тип target-а факторный), для регрессии - gaussian, poisson (тип target-а числовой);

Для обучения создаете класс с желаемыми параметрами и вызываете метод train, которому передайте список имен признаков, целевой колонки, тренировочный фрейм и другие параметры (я передам еще валидационный фрейм):

-3

Значения набора валидационных метрик можно получить методом model_performance с параметром valid=True или, передав ему валидационный фрейм в параметре test_data:

-4

-5