В этой статье хочу поделиться наглядным примером процесса обучения нейронной сети. Рассмотрение на пальцах позволяет лучше понять весь процесс, находить узкие места и производить подстройку модели в целях повышения эффективности ее работы.
Рассмотрим сеть из одного нейрона, который будет искать параметры прямой, описывающей тренировочные данные. Последние получены путем умножения массива из 1000 точек (x ) на множитель (7), добавления сдвига (5) и гауссова шума (noise). Таким образом, искомые параметры прямой (у=7x + 5 +noise) - 7 (угол наклона) и 5 (свободный член). Ниже представлен код (tf - псевдоним библиотеки TensorFlow):
Нейрон сформируем как подкласс tf.Module, содержащий две переменные W, b (после обучения они должны содержать искомые параметры) и служебный метод __call__, в котором происходит умножение входных наборов на W и сложение с b (подробнее о формировании структурных элементов нейросетей рассказывал ранее):
Обратите внимание, что первоначально W и b инициализированы произвольными значениями (3 и 0), не равными искомым. Соответственно, построенная на их основе прямая y=W*x + b плохо описывает наши данные:
Мы уже располагаем данными x и y, классом нейрона, а теперь определим функцию потерь. Как я рассказывал ранее, она является математической интерпретацией цели, которую мы хотим достичь. В нашем случае - это определить W и b, которые лучше всего приближают к y результат воздействия нейрона на x (вызов метода __call__). Поэтому возьмем среднее квадрата разностей истинных и предсказанных нейроном значений. Чем оно меньше, тем ближе предсказанные значения к реальным:
Теперь реализуем стадию тренировки:
В данном коде в течение 200 итераций происходит вычисление производных функции потерь по параметрам нейрона и на основании этого - изменение W и b (подробнее читай здесь), чтобы в следующей итерации функция потерь стала меньше. Также в целях демонстрации постепенного улучшения результата, выборочные значения W и b запоминаются, а затем построенные на их основе прямые отображаются на графике:
Как можно заметить, прямые приближаются к нашим данным и каждая последующая все лучше описывает их.
Полный код сценария представлен ниже: