Keras API предоставляет встроенные классы для регулярного сохранения моделей во время подгонки моделей. Чтобы сохранить модель и восстановить ее позже, мы можем создать обратный вызов ModelCheckPoint, передаваемый в model.fit, и модель будет регулярно сохраняться во время обучения.
Сохранить всю модель v.s. Сохранить олько веса
Есть два варианта сохранения модели - только веса или включение конфигурации обучения, а также архитектуры модели. Если флаг save_weights_only имеет значение True при создании ModelCheckpoint, модель будет сохранена как model.save_weights (путь к файлу), чтобы сохранить только веса модели. Если установлено значение False, полная модель сохраняется как model.save (путь к файлу) в формате SavedModel.
По умолчанию модель будет сохраняться каждую эпоху. Но его можно переопределить с помощью save_freq в ModelCheckpoint.
Также по умолчанию сохраняются 5 последних контрольных точек.
model.save_weights
Если модель сохраняется только с весами, нам нужно сначала создать экземпляр новой модели, прежде чем восстанавливать веса. Скорее всего, мы вызываем исходный код Python (в нашем примере create_model) для создания экземпляра модели. Затем мы загружаем веса модели с помощью model.load_weights.
Без обратного вызова мы также можем использовать model.save_weights для сохранения веса модели вручную.
model.save
Чтобы сохранить полную модель, мы используем model.save (путь к файлу), чтобы сохранить ее как SavedModel.Как объяснено ниже, он содержит состояние оптимизатора и итератора набора данных для возобновления всего обучения с последней сохраненной точки. Поскольку архитектура и конфигурация модели сохраняются, модель можно восстановить напрямую, без создания экземпляра модели. Нам не нужен модельный код Python.Фактически, это может уменьшить количество ошибок при развертывании в производственной среде.
Когда модель сохраняется, все ее tf.Variable сохраняются, а все аннотированные методы @ tf.function будут сохранены в виде графика. Это основная причина того, что нам не нужен код Python, поскольку модель восстанавливается из графа вычислений. Но для этого необходимо, чтобы все методы, необходимые для пользовательских моделей, были охвачены аннотацией @ tf.function.
CheckpointManager
Мы также можем использовать CheckpointManager для сохранения моделей, если мы хотим использовать более низкий уровень Keras API. Приведенный ниже код представляет собой шаблонный код для создания набора данных игрушки и модели. Он также содержит коды для шага обучения.
Чтобы сохранить контрольную точку, мы создаем CheckpointManager с контрольной точкой. Эта контрольная точка содержит модель, оптимизатор, состояние обучения (шаг) и итератор набора данных. Перед обучением мы можем восстановить контрольную точку с последней сохраненной контрольной точкой. Это загружает веса модели и восстанавливает состояние оптимизатора, итератора набора данных и этапов обучения. Короче говоря, мы возобновляем состояние обучения при последнем сохранении модели, а не только веса модели.
В качестве альтернативы мы можем восстановить обучение с нуля, используя новые экземпляры оптимизатора, модели, набора данных и итератора. Вот код, как это сделать.
Восстановить тренировку
Наконец, мы немного подробнее рассмотрим, что сохраняется в SavedModel и как восстанавливается сеанс обучения. Контрольная точка в предыдущем разделе не сохраняет только параметры модели. Он также содержит состояние оптимизатора (скорость обучения, затухание) и любые параметры, связанные с параметрами обучения, например импульс (m). Он также содержит состояние обучения, включая этап обучения и счетчик сохранения, добавленные к имени файла контрольной точки. Следовательно, когда контрольная точка восстанавливается, она также восстанавливает состояние оптимизатора и состояния контрольной точки. Он также проверяет ход выполнения итератора набора данных. Следовательно, итератор можно возобновить с того места, где он остановился, а не начинать с начала.
checkpoint.restore восстанавливает значения переменных для любого подходящего пути из объекта контрольной точки, т.е. мы можем просто загрузить часть контрольной точки. Например, мы можем воссоздать только часть модели, а в примере ниже мы просто загружаем веса смещения из контрольной точки плотного слоя self.l1.
Копировать веса
В приведенном ниже коде веса копируются с одного слоя на другой.
В приведенном ниже коде, даже несмотря на то, что function_model_with_dropout содержит дополнительный слой исключения по сравнению с function_model, слой исключения не содержит никакого веса. Таким образом, мы все еще можем копировать веса из function_model в function_model_with_dropout.
Источники и ссылки
Код в этой статье в основном взят из руководства