...Читать далее
Оглавление
Реализация утилиты командной строки, для пресказаний с использованием ML
- Модель: ResNet-18 обучается на CIFAR-10.
- API: Веб-сервер принимает изображения и возвращает предсказания.
- CLI: Утилита для предсказаний из командной строки.
1. Структуры данных для запросов и ответов
PredictRequest и PredictResponse
- Что это? Это структуры, которые описывают формат данных для запроса к серверу и ответа от сервера.
- Пример:
#[derive(Deserialize)]
struct PredictRequest {
image: Vec<u8>, // Вектор байтов — изображение в "плоском" виде
}PredictRequest: сервер ожидает получить изображение в виде вектора байтов.
PredictResponse: сервер возвращает класс (например, "кошка" или "машина") и уверенность модели в этом предсказании (например, 95%).
2. Определение модели ResNet-18
resnet18
- Что это? Это функция, которая создаёт нейронную сеть ResNet-18 — популярную архитектуру для классификации изображений.
- Как работает? Сначала идёт свёрточный слой (conv1), который извлекает признаки из изображения.
Затем идёт пакетная нормализация (bn1) и слои layer1 — это блоки ResNet, которые помогают сети обучаться на глубоких слоях. - Пример:
let model = resnet18(&vs.root(), 10); // 10 классов для CIFAR-10
3. Загрузка и предобработка данных
load_cifar10
- Что это? Функция загружает датасет CIFAR-10 (60 000 цветных изображений 32x32 в 10 классах) и переносит их на GPU для ускорения обучения.
- Пример:
let (train_images, train_labels, test_images, test_labels) = load_cifar10();
4. Обучение модели
train_model
- Что это? Функция обучает модель на тренировочных данных.
- Как работает?Использует оптимизатор Adam для обновления весов модели.
На каждой эпохе (цикл обучения) вычисляет ошибку (loss) и обновляет веса, чтобы уменьшить эту ошибку. - Пример:
train_model(&model, &train_images, &train_labels);
5. Веб-сервер для предсказаний
predict и HttpServer
- Что это? predict: функция-обработчик, которая принимает изображение, пропускает его через модель и возвращает предсказание.
HttpServer: запускает веб-сервер на порту 8080, который слушает запросы на /predict. - Пример:
HttpServer::new(move || {
App::new()
.app_data(model_data.clone())
.route("/predict", web::post().to(predict))
})
.bind("127.0.0.1:8080")?
.run()
.await
6. Утилита командной строки
predict_from_cli
- Что это? Функция для предсказания класса изображения из командной строки (без веб-сервера).
- Пример:
./target/release/predict --model resnet18.pt --image cat.jpg - Вывод:
Предсказанный класс: 3
Уверенность: 95.00%
7. Скрипты для запуска
Bash-скрипты
- Что это? Скрипты для компиляции, запуска веб-сервера и пакетной обработки изображений.
- Пример:
cargo build --release
./target/release/my_project