Добавить в корзинуПозвонить
Найти в Дзене
Один Rust не п...Rust

Bash & Rust ML

Оглавление
GitHub - nicktretyakov/error_bash_predi
ML на RUST без заморочек
Один Rust не п...Rust

Реализация утилиты командной строки, для пресказаний с использованием ML

  • Модель: ResNet-18 обучается на CIFAR-10.
  • API: Веб-сервер принимает изображения и возвращает предсказания.
  • CLI: Утилита для предсказаний из командной строки.

1. Структуры данных для запросов и ответов

PredictRequest и PredictResponse

  • Что это? Это структуры, которые описывают формат данных для запроса к серверу и ответа от сервера.
  • Пример:
    #[derive(Deserialize)]
    struct PredictRequest {
    image: Vec<u8>,
    // Вектор байтов — изображение в "плоском" виде
    }PredictRequest: сервер ожидает получить изображение в виде вектора байтов.
    PredictResponse: сервер возвращает класс (например, "кошка" или "машина") и уверенность модели в этом предсказании (например, 95%).

2. Определение модели ResNet-18

resnet18

  • Что это? Это функция, которая создаёт нейронную сеть ResNet-18 — популярную архитектуру для классификации изображений.
  • Как работает? Сначала идёт свёрточный слой (conv1), который извлекает признаки из изображения.
    Затем идёт пакетная нормализация (bn1) и слои layer1 — это блоки ResNet, которые помогают сети обучаться на глубоких слоях.
  • Пример:
    let model = resnet18(&vs.root(), 10); // 10 классов для CIFAR-10

3. Загрузка и предобработка данных

load_cifar10

  • Что это? Функция загружает датасет CIFAR-10 (60 000 цветных изображений 32x32 в 10 классах) и переносит их на GPU для ускорения обучения.
  • Пример:
    let (train_images, train_labels, test_images, test_labels) = load_cifar10();

4. Обучение модели

train_model

  • Что это? Функция обучает модель на тренировочных данных.
  • Как работает?Использует оптимизатор Adam для обновления весов модели.
    На каждой эпохе (цикл обучения) вычисляет ошибку (loss) и обновляет веса, чтобы уменьшить эту ошибку.
  • Пример:
    train_model(&model, &train_images, &train_labels);

5. Веб-сервер для предсказаний

predict и HttpServer

  • Что это? predict: функция-обработчик, которая принимает изображение, пропускает его через модель и возвращает предсказание.
    HttpServer: запускает веб-сервер на порту 8080, который слушает запросы на /predict.
  • Пример:
    HttpServer::new(move || {
    App::new()
    .app_data(model_data.clone())
    .route("/predict", web::post().to(predict))
    })
    .bind("127.0.0.1:8080")?
    .run()
    .await

6. Утилита командной строки

predict_from_cli

  • Что это? Функция для предсказания класса изображения из командной строки (без веб-сервера).
  • Пример:
    ./target/release/predict --model resnet18.pt --image cat.jpg
  • Вывод:
    Предсказанный класс: 3
    Уверенность: 95.00%

7. Скрипты для запуска

Bash-скрипты

  • Что это? Скрипты для компиляции, запуска веб-сервера и пакетной обработки изображений.
  • Пример:
    cargo build --release
    ./target/release/my_project