Добавить в корзинуПозвонить
Найти в Дзене
Один Rust не п...Rust

Как распознать текст на фото с использованием Rust ML

Добавьте в Cargo.toml: [dependencies] actix-web = "4" image = "0.24" tch = "0.14" serde = { version = "1.0", features = ["derive"] } use actix_web::{web, App, HttpServer, Responder, HttpResponse}; use image::{DynamicImage, imageops}; use tch::{Device, Tensor, vision::image as tch_image, CModule}; use serde::Serialize; use std::error::Error; // Структура для хранения результатов обнаружения #[derive(Serialize, Clone)] struct Detection { x: f32, // Координата x центра y: f32, // Координата y центра width: f32, // Ширина рамки height: f32, // Высота рамки confidence: f32, // Уверенность предсказания class_id: i64, // Идентификатор класса } // Структура детектора YOLO struct Detector { model: CModule, device: Device, input_size: i64, // Размер входного изображения (например, 416) confidence_threshold: f32, // Порог уверенности nms_threshold: f32, // Порог NMS num_classes: i64, // Количество классов } impl Detector { // Инициализация детектора f
Оглавление
ML на RUST без заморочек
Один Rust не п...Rust

Код реализует сервер для детекции объектов с использованием YOLO, оптимизированного NMS из tch, детальной обработки ошибок через Result и поддержки нескольких классов:

Добавьте в Cargo.toml:

[dependencies]

actix-web = "4"

image = "0.24"

tch = "0.14"

serde = { version = "1.0", features = ["derive"] }

Оптимизация с tch::vision::nms

  • В методе apply_nms мы преобразуем список bounding boxes и их уверенности в тензоры, а затем используем tch::vision::nms для выполнения NMS. Это быстрее и надежнее, чем писать NMS вручную.

Обработка ошибок

  • Все методы, которые могут завершиться неудачей (например, загрузка модели, предобработка изображения, инференс), возвращают Result. Ошибки обрабатываются в API-эндпоинте detect_objects, где клиенту возвращается понятное сообщение и соответствующий HTTP-статус.

Многоклассовая поддержка

  • Структура Detection теперь включает class_id.
  • В postprocess_output для каждого предсказания извлекаются вероятности классов (начиная с 5-го элемента выходного тензора), определяется класс с максимальной вероятностью, и эта информация добавляется в результат.

use actix_web::{web, App, HttpServer, Responder, HttpResponse};

use image::{DynamicImage, imageops};

use tch::{Device, Tensor, vision::image as tch_image, CModule};

use serde::Serialize;

use std::error::Error;

// Структура для хранения результатов обнаружения

#[derive(Serialize, Clone)]

struct Detection {

x: f32, // Координата x центра

y: f32, // Координата y центра

width: f32, // Ширина рамки

height: f32, // Высота рамки

confidence: f32, // Уверенность предсказания

class_id: i64, // Идентификатор класса

}

// Структура детектора YOLO

struct Detector {

model: CModule,

device: Device,

input_size: i64, // Размер входного изображения (например, 416)

confidence_threshold: f32, // Порог уверенности

nms_threshold: f32, // Порог NMS

num_classes: i64, // Количество классов

}

impl Detector {

// Инициализация детектора

fn new(

model_path: &str,

input_size: i64,

confidence_threshold: f32,

nms_threshold: f32,

num_classes: i64,

) -> Result<Self, Box<dyn Error>> {

let model = CModule::load(model_path)

.map_err(|e| format!("Не удалось загрузить модель: {}", e))?;

let device = Device::cuda_if_available(); // GPU, если доступно

model.to(device, tch::Kind::Float, true);

Ok(Detector {

model,

device,

input_size,

confidence_threshold,

nms_threshold,

num_classes,

})

}

// Обнаружение объектов на изображении

fn detect(&self, image: &DynamicImage) -> Result<Vec<Detection>, Box<dyn Error>> {

let tensor = self.preprocess_image(image)?;

let output = self.model.forward_ts(&[tensor])

.map_err(|e| format!("Ошибка инференса модели: {}", e))?;

self.postprocess_output(output)

}

// Предобработка изображения

fn preprocess_image(&self, image: &DynamicImage) -> Result<Tensor, Box<dyn Error>> {

let resized = imageops::resize(

image,

self.input_size as u32,

self.input_size as u32,

imageops::FilterType::Lanczos3,

);

let tensor = tch_image::from_image(&resized)

.map_err(|e| format!("Ошибка преобразования изображения: {}", e))?

.to_device(self.device)

.to_kind(tch::Kind::Float)

/ 255.0; // Нормализация

Ok(tensor.unsqueeze(0)) // Добавляем размерность батча

}

// Постобработка выходных данных модели

fn postprocess_output(&self, output: Tensor) -> Result<Vec<Detection>, Box<dyn Error>> {

let num_boxes = output.size()[1];

let mut detections = Vec::new();

for i in 0..num_boxes {

let det = output.get(0).get(i);

let confidence = f32::from(det.get(4)); // Уверенность объекта

if confidence > self.confidence_threshold {

let x = f32::from(det.get(0));

let y = f32::from(det.get(1));

let w = f32::from(det.get(2));

let h = f32::from(det.get(3));

let class_probs = det.slice(0, 5, 5 + self.num_classes, 1);

let (class_id, &max_prob) = class_probs

.iter::<f32>()?

.enumerate()

.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())

.ok_or("Не удалось определить класс")?;

detections.push(Detection {

x,

y,

width: w,

height: h,

confidence: confidence * max_prob,

class_id: class_id as i64,

});

}

}

// Применяем NMS

self.apply_nms(detections)

}

// Non-Maximum Suppression с использованием tch

fn apply_nms(&self, detections: Vec<Detection>) -> Result<Vec<Detection>, Box<dyn Error>> {

if detections.is_empty() {

return Ok(Vec::new());

}

let boxes: Vec<Tensor> = detections

.iter()

.map(|d| Tensor::of_slice(&[d.x - d.width / 2.0, d.y - d.height / 2.0, d.x + d.width / 2.0, d.y + d.height / 2.0]))

.collect();

let scores: Vec<f32> = detections.iter().map(|d| d.confidence).collect();

let boxes_tensor = Tensor::stack(&boxes, 0);

let scores_tensor = Tensor::of_slice(&scores);

let keep = tch::vision::nms(&boxes_tensor, &scores_tensor, self.nms_threshold)

.map_err(|e| format!("Ошибка NMS: {}", e))?;

let filtered: Vec<Detection> = keep

.iter::<i64>()?

.map(|i| detections[i as usize].clone())

.collect();

Ok(filtered)

}

}

// API-эндпоинт для обнаружения

async fn detect_objects(

detector: web::Data<Detector>,

image: web::Bytes,

) -> impl Responder {

match image::load_from_memory(&image) {

Ok(img) => match detector.detect(&img) {

Ok(detections) => HttpResponse::Ok().json(detections),

Err(e) => HttpResponse::InternalServerError().body(format!("Ошибка: {}", e)),

},

Err(_) => HttpResponse::BadRequest().body("Неверный формат изображения"),

}

}

// Запуск сервера

#[actix_web::main]

async fn main() -> std::io::Result<()> {

let detector = web::Data::new(

Detector::new(

"path/to/yolo_model.pt", // Путь к модели

416, // Размер изображения

0.5, // Порог уверенности

0.4, // Порог NMS

2, // Количество классов (пример)

)

.expect("Ошибка инициализации детектора"),

);

HttpServer::new(move || {

App::new()

.app_data(detector.clone())

.route("/detect", web::post().to(detect_objects))

})

.bind("127.0.0.1:8080")?

.run()

.await

}

  • Размер входного изображения: Убедитесь, что input_size соответствует требованиям вашей версии YOLO (например, 416 для YOLOv3, 608 для YOLOv4).
  • Формат выходных данных: Проверьте структуру выходного тензора вашей модели и адаптируйте функцию postprocess_output. Например, для YOLOv5 выходные данные могут отличаться.
  • Классы: Если нужно обнаружить несколько классов, измените логику обработки вероятностей классов в postprocess_output.