Найти тему
Властелин машин

Предсказываем с моделями машинного обучения

Продемонстрируем мощь алгоритмов машинного обучения для предсказания некой целевой переменной. В демонстрационных целях будем использовать программно сгенерированный набор о доходах и расходах людей (ссылка в конце статьи).

Набор помимо перечисленной информации также включает ФИО человека, индекс города проживания и сумму дотаций из регионального бюджета:

Нашей целью будет являться предсказание расходов человека по доходу и сумме помощи. При этом намеренно установлена следующая зависимость:

data_m['расходы'] = 2*np.sqrt(data_m['зарплата'])+0.5*data_m['сумма_помощи']

Для предсказания будем использовать популярную реализацию алгоритма градиентного бустинга LightGBM.

Рассмотрим минимальные шаги, необходимые для получения модели (без ее настройки, о которой пойдет речь в последующих статьях). В частности, отделим тренировочные данные (для обучения) и тестовые (для проверки качества), инициируем процесс обучения и сравним предсказанные и тестовые данные.

Сначала реализуем этап обучения:

X_tr, X_ts, y_tr, y_ts = train_test_split(data_m[['зарплата', 'сумма_помощи']], data_m['расходы'], test_size=0.2)

train_data = lgb.Dataset(X_tr, label=y_tr)

test_data = lgb.Dataset(X_ts, label=y_ts)

params = {"objective": "regression"}

bst = lgb.train(params, train_set = train_data, valid_sets=[test_data])

С помощью train_test_split мы разделили данные на обучающую и тестовую выборки, затем заполнили внутренний тип библиотеки lightgbm - Dataset данными для обучения и тренировки, задали параметры модели в словаре params (пока достаточно лишь задания типа задачи - классификация/регрессия).

После завершения обучения train вернет экземпляр объекта Booster, который будет использован для будущих предсказаний.

y_pr = bst.predict(X_ts, num_iteration=bst.best_iteration)

-2

Как можно заметить, алгоритм смог даже с настройками по умолчанию уловить закономерности в формировании расходов и достаточно точно предсказать данные.

Код для генерации исследованного набора данных можно скопировать отсюда.

-3