Введение в PyTorch Lightning — платформу, упрощающую и ускоряющую обучение работу моделей глубокого обучения.
Что такое PyTorch Lightning?
PyTorch Lightning — это легкий и высокопроизводительный фреймворк, созданный на основе PyTorch, который позволяет организовать ваш код и автоматизировать процесс оптимизации обучения. PyTorch также предоставляет следующие функции:
- метрики (точность, полнота данных и т. д.)
- регистрация метрик
- контрольные точки обучения модели
- обучение на нескольких GPU, TPU, CPU
- более быстрая реализация (300 мс на эпоху по сравнению с чистым PyTorch)
Официальные документы можно найти здесь, а исходный код на GitHub — здесь.
Установка:
Прежде всего необходимо установить pytorch-lightning:
pip install pytorch-lightning
теперь импортируйте модуль в свой код:
import pytorch_lightning as pl
Существуют разные способы использования PyTorch Lightning. В этом посте мы будем использовать только pl.LightningModule ,изменяя nn.Module при обучении модели .
LightningModule — это torch.nn.Module, но с дополнительными функциями.
Чтобы использовать PyTorch Lightning, вы должны изменить свой код в соответствии с функциями LightningModule. Модуль и все его необходимые функции показаны ниже:
Эти 4 функции — минимум, необходимый для обучения вашей модели с помощью Lightning. Возможно, вам потребуется добавить и другие функции: prepare_data(), validation_step(), test_step() и predict_step().
Итак, изменения, которые вам нужно внести, показаны на картинке ниже:
По шагам:
- Передаем pl.LightningModule вместо nn.Module в модуль
- вы можете удалить .to(device) — Lightning автоматически перемещает данные, поступающие от LightningModule.
Полный код приведенных выше примеров можно найти здесь.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
# Create data loaders.
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# Define model
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# train
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
epochs = 5
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
Для параллельного сравнения PyTorch и PyTorch Lightning прочитайте следующую статью, написанную одним из создателей PyTorch Lightning.
Обучите свою модель
Если вы структурировали свой код на LightningModule, вы можете обучить свою модель всего тремя строками кода с помощью функции Trainer , построенной на основе вложенных циклов, которые выполняются для всех пакетов в загрузчике данных во все эпохи:
model = NeuralNetworkLit()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_dataloader)
Для обучения и тестирования модели в PyTorch Lightning предусмотрена специальная функция Trainer(), имеющая широкий набор настроек, к примеру, можно указать вручную количество графических процессоров, если мы имеем дело с распределенными вычислениями. Перед началом обучения Trainer() самостоятельно проверит наличие GPU или TPU, и переключит вычисления на доступный, в PyTorch это надо прописывать вручную.
В Trainer() добавлен ряд привычных из классического ML методов, таких как fit, test, predict, доступных «из коробки».
Logging 📃
В LightningModule вы можете регистрировать метрики ) для этапа обучения, проверки или тестирования моделей. Для этого используйте метод log() для метрики, которую вы хотите отслеживать. Эта функция отправляет вычисленные метрики в логгер, который затем сохраняет их в каталоге по умолчанию в вашем рабочем каталоге.
def training_step(self, batch, batch_idx):
x, y = batch
loss = self.loss_fn(pred, y)
self.log("train_loss", loss)
return loss
По умолчанию log():
логи в после каждого шага обучения , попадают в training_step()
логи в после каждой эпохи попадают в validation_step() и test_step()
Вы можете изменить это с помощью параметров on_step и on_epoch, как показано ниже:
self.log("train_loss", loss, on_step=False, on_epoch=True)
❗Если вы хотите поставить контрольную точку на своей модели и добавить в обучение остановки, необходимо, чтобы log() был добавлен к метрике, которую вы хотите отслеживать.
Проверка модели 💾
Вы можете автоматически сохранять веса вашей модели на этапе обучения или проверки модели на основе метрики, которую вы хотите отслеживать (например, accuracy, loss) через ModelCheckpoint. Чтобы ModelCheckpoint:
from pytorch_lightning.callbacks import ModelCheckpoint
2. Добавьте log() к метрике, которую вы хотите отслеживать:
def validation_step(self, batch, batch_idx):
x, y = batch
loss = self.loss_fn(pred, y)
self.log("val_loss", loss)
return loss
3. Создадим экземпляр класса ModelCheckpoint:
checkpoint_callback = ModelCheckpoint(monitor='val_loss',mode='min')
передадим в monitor параметр метрики, которую мы хотим использовать (строка, которую мы определили в log())
4. Передадмс checkpoint callback в Trainer через параметр callbacks:
trainer = pl.Trainer(max_epochs=5, callbacks=[checkpoint_callback])
Кроме того, для автоматического сохранения параметров вашей модели добавьте self.save_hyperparameters() в LightningModule 's __init__(). Затем параметры модели будут сохранены в атрибуте self.hparams, а также будут сохранены в контрольной точке модели:
def __init__(self, batch, batch_idx):
self.save_hyperparameters()
Ниже вы можете увидеть папку lightning_logs, созданную логгером. Это каталог по умолчанию, в котором хранятся логи, контрольные точки и параметры модели.
Ранние остановки 🛑
Вы можете остановить обучение своей модели на ранней стадии с помощью обратного вызова EarlyStopping, когда нет улучшений в отслеживаемой метрики. Чтобы остановить работу модели досрочно:
from pytorch_lightning.callbacks.early_stopping import EarlyStop
def validation_step(self, batch, batch_idx):
x, y = batch
pred = self.loss_fn(pred, y)
self.log("val_loss", loss)
return loss
2. Создадим экземпляр класса EarlyStopping:
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=10)
передадим в monitor параметр название метрики, которую вы хотите отслеживать
передадим в monitor параметр режима ‘min’ or ‘max’ , чтобы остановить тренировку, когда отслеживаемая метрика перестала улучшаться
patience - количество эпох проверки без улучшения
вызовем EarlyStopping в Trainer через параметр callbacks:
trainer = pl.Trainer(max_epochs=5, callbacks=[early_stopping])
Обучение работе с несколькими графическими процессорами ⏭
Чтобы тренироваться на нескольких графических процессорах, просто передайте параметру gpus в Trainer количество GPU вашего устройства, которое вы хотите использовать:
например для использования 2 GPU:
trainer = Trainer(gpus=2)
например для тренировки на всех доступных GPU используйте gpus=-1:
trainer = Trainer(gpus=-1)
С помощью Lightning вы можете делать многие другие вещи, не упомянутые в этом руководстве, например структурировать очистку, обработку и разделение данных с помощью модуля LightningDataModule, поиск LR, и многое другое.
Спасибо за прочтение!
Ссылки: