Найти тему

PyTorch: что это и с чем едят?

Фото: Mateus Maia / Unsplash
Фото: Mateus Maia / Unsplash

PyTorch – это библиотека Машинного обучения (ML) с открытым исходным кодом, позволяющая решать огромное количество задач, в числе которых:

  • Распознавание изображений (Image Recognition)
  • Обнаружение мошеннических операций (Fraud Detection)
  • Распознавание рукописного текста (Hadrwriting Recognition) и проч.
-2

Основные преимущества:

  • Готовность к развертыванию: сопутствующие сервисы TorchScript – программа для создания и оптимизации Модели (Model) и TorchServe – фреймворк для развертывания обученной модели
  • Оптимизация производительности
  • Надежная экосистема расширяет PyTorch и поддерживает разработку в области Компьютерного зрения (CV), Обработки естественного языка (NLP) и других областях.
  • Облачная поддержка: PyTorch хорошо поддерживается на основных облачных платформах, обеспечивая беспроблемную разработку и легкое масштабирование.

Простейшая модель на PyTorch

Посмотрим, как работает фреймворк и насколько кратким может быть код. Для начала импортируем необходимые библиотеки:

-3

Мы инициируем линейную регрессию. Мы используем функцию, которая комбинирует входные значения x и веса w линейно. Создадим обучающие выборки x и y: тип данных – ‘numpy.float32’. Y – это просто удвоенные значения x. Инициализируем наши веса, равные нулю для начала:

-4

А теперь нам нужно создать прогноз, рассчитать потери и градиент. Каждый из этих шагов мы выполним вручную. Инициализируем функцию "Прямой проход" (Forward Pass), и она получит x в качестве аргумента. "На выходе" у нашей функции W, умноженный на X:

-5

Здесь мы определяем потерю функции, которая зависит от прогнозируемых y. В случае линейной регрессии, это среднеквадратическая ошибка. Мы можем рассчитать это, возведя разницы между предсказанными и реальными значениями y в квадрат и взяв среднее от полученного: нам предстоит теперь вручную рассчитать градиент потерь согласно нашим параметрам.

Давайте посмотрим на среднеквадратичную ошибку. Формула выглядит так: у нас есть производная. Частное производных J и весов равно отношению единицы к N, умноженное на 2x и (w * x - y):Мы реализуем это с помощью функции gradient следующим образом. Мы можем сделать это в одной строке, поэтому возвращаем numpy., умноженное на два X, разница между предсказанным и реальным y. И, конечно, нам также понадобится среднее:

-6

Выведем наш прогноз перед обучением:

-7

До наших тренировок прогноз равен нулю.

-8

А теперь приступим к обучению и определим некоторые параметры. Нам нужна скорость обучения, которая равна нулю. Определим число итераций, равное 20:

-9

Теперь давайте выполним наш цикл обучения. Выполним прямой проход, функцией forward(). Затем рассчитаем потери, потому Y теперь нам нужен, чтобы получить градиенты. И теперь мы должны обновить наши веса. Формула обновления в алгоритме градиентного спуска заключается в том, что мы просто меняем на противоположное скорость обучения градиента, а затем умножаем ее на наш градиент.

Итак, это формула обновления. Допустим, мы также хотим вывести здесь некоторую информацию, поэтому мы выводим на экран каждый шаг, если он не равен 0. Мы хотим отобразить данные о каждой эпохе. Мы инкриминируем число эпох на один, потому что нумерация смещена, а затем отображаем веса с тремя знаками после запятой. Потери отобразим по восьмой знак после запятой:

-10

Смотрите, как снижаются потери с каждой эпохой:

-11

Напомню, наша формула y равен 2 на X, поэтому наша W равна двум в начале. Мы видим, что с каждым шагом тренировки увеличиваются веса и уменьшаются потери.

Давайте используем интерполяцию – смешение текстовых данных и вычисляемого значения. Предположим, что мы хотим предсказать значение y при х, равное 5; это 10. Отобразим результат как число с трем знаками после запятой:

-12

После обучения наша модель предсказывает с точностью 1,9999:

-13

Ноутбук, не требующий дополнительной настройки на момент написания статьи, можно скачать здесь.

Понравилась статья? Поддержите нас, поделившись статьей в социальных сетях и подписавшись на канал. И попробуйте курсы на Udemy.