Код реализует сервер для детекции объектов с использованием YOLO, оптимизированного NMS из tch, детальной обработки ошибок через Result и поддержки нескольких классов:
Добавьте в Cargo.toml:
[dependencies]
actix-web = "4"
image = "0.24"
tch = "0.14"
serde = { version = "1.0", features = ["derive"] }
Оптимизация с tch::vision::nms
- В методе apply_nms мы преобразуем список bounding boxes и их уверенности в тензоры, а затем используем tch::vision::nms для выполнения NMS. Это быстрее и надежнее, чем писать NMS вручную.
Обработка ошибок
- Все методы, которые могут завершиться неудачей (например, загрузка модели, предобработка изображения, инференс), возвращают Result. Ошибки обрабатываются в API-эндпоинте detect_objects, где клиенту возвращается понятное сообщение и соответствующий HTTP-статус.
Многоклассовая поддержка
- Структура Detection теперь включает class_id.
- В postprocess_output для каждого предсказания извлекаются вероятности классов (начиная с 5-го элемента выходного тензора), определяется класс с максимальной вероятностью, и эта информация добавляется в результат.
use actix_web::{web, App, HttpServer, Responder, HttpResponse};
use image::{DynamicImage, imageops};
use tch::{Device, Tensor, vision::image as tch_image, CModule};
use serde::Serialize;
use std::error::Error;
// Структура для хранения результатов обнаружения
#[derive(Serialize, Clone)]
struct Detection {
x: f32, // Координата x центра
y: f32, // Координата y центра
width: f32, // Ширина рамки
height: f32, // Высота рамки
confidence: f32, // Уверенность предсказания
class_id: i64, // Идентификатор класса
}
// Структура детектора YOLO
struct Detector {
model: CModule,
device: Device,
input_size: i64, // Размер входного изображения (например, 416)
confidence_threshold: f32, // Порог уверенности
nms_threshold: f32, // Порог NMS
num_classes: i64, // Количество классов
}
impl Detector {
// Инициализация детектора
fn new(
model_path: &str,
input_size: i64,
confidence_threshold: f32,
nms_threshold: f32,
num_classes: i64,
) -> Result<Self, Box<dyn Error>> {
let model = CModule::load(model_path)
.map_err(|e| format!("Не удалось загрузить модель: {}", e))?;
let device = Device::cuda_if_available(); // GPU, если доступно
model.to(device, tch::Kind::Float, true);
Ok(Detector {
model,
device,
input_size,
confidence_threshold,
nms_threshold,
num_classes,
})
}
// Обнаружение объектов на изображении
fn detect(&self, image: &DynamicImage) -> Result<Vec<Detection>, Box<dyn Error>> {
let tensor = self.preprocess_image(image)?;
let output = self.model.forward_ts(&[tensor])
.map_err(|e| format!("Ошибка инференса модели: {}", e))?;
self.postprocess_output(output)
}
// Предобработка изображения
fn preprocess_image(&self, image: &DynamicImage) -> Result<Tensor, Box<dyn Error>> {
let resized = imageops::resize(
image,
self.input_size as u32,
self.input_size as u32,
imageops::FilterType::Lanczos3,
);
let tensor = tch_image::from_image(&resized)
.map_err(|e| format!("Ошибка преобразования изображения: {}", e))?
.to_device(self.device)
.to_kind(tch::Kind::Float)
/ 255.0; // Нормализация
Ok(tensor.unsqueeze(0)) // Добавляем размерность батча
}
// Постобработка выходных данных модели
fn postprocess_output(&self, output: Tensor) -> Result<Vec<Detection>, Box<dyn Error>> {
let num_boxes = output.size()[1];
let mut detections = Vec::new();
for i in 0..num_boxes {
let det = output.get(0).get(i);
let confidence = f32::from(det.get(4)); // Уверенность объекта
if confidence > self.confidence_threshold {
let x = f32::from(det.get(0));
let y = f32::from(det.get(1));
let w = f32::from(det.get(2));
let h = f32::from(det.get(3));
let class_probs = det.slice(0, 5, 5 + self.num_classes, 1);
let (class_id, &max_prob) = class_probs
.iter::<f32>()?
.enumerate()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
.ok_or("Не удалось определить класс")?;
detections.push(Detection {
x,
y,
width: w,
height: h,
confidence: confidence * max_prob,
class_id: class_id as i64,
});
}
}
// Применяем NMS
self.apply_nms(detections)
}
// Non-Maximum Suppression с использованием tch
fn apply_nms(&self, detections: Vec<Detection>) -> Result<Vec<Detection>, Box<dyn Error>> {
if detections.is_empty() {
return Ok(Vec::new());
}
let boxes: Vec<Tensor> = detections
.iter()
.map(|d| Tensor::of_slice(&[d.x - d.width / 2.0, d.y - d.height / 2.0, d.x + d.width / 2.0, d.y + d.height / 2.0]))
.collect();
let scores: Vec<f32> = detections.iter().map(|d| d.confidence).collect();
let boxes_tensor = Tensor::stack(&boxes, 0);
let scores_tensor = Tensor::of_slice(&scores);
let keep = tch::vision::nms(&boxes_tensor, &scores_tensor, self.nms_threshold)
.map_err(|e| format!("Ошибка NMS: {}", e))?;
let filtered: Vec<Detection> = keep
.iter::<i64>()?
.map(|i| detections[i as usize].clone())
.collect();
Ok(filtered)
}
}
// API-эндпоинт для обнаружения
async fn detect_objects(
detector: web::Data<Detector>,
image: web::Bytes,
) -> impl Responder {
match image::load_from_memory(&image) {
Ok(img) => match detector.detect(&img) {
Ok(detections) => HttpResponse::Ok().json(detections),
Err(e) => HttpResponse::InternalServerError().body(format!("Ошибка: {}", e)),
},
Err(_) => HttpResponse::BadRequest().body("Неверный формат изображения"),
}
}
// Запуск сервера
#[actix_web::main]
async fn main() -> std::io::Result<()> {
let detector = web::Data::new(
Detector::new(
"path/to/yolo_model.pt", // Путь к модели
416, // Размер изображения
0.5, // Порог уверенности
0.4, // Порог NMS
2, // Количество классов (пример)
)
.expect("Ошибка инициализации детектора"),
);
HttpServer::new(move || {
App::new()
.app_data(detector.clone())
.route("/detect", web::post().to(detect_objects))
})
.bind("127.0.0.1:8080")?
.run()
.await
}
- Размер входного изображения: Убедитесь, что input_size соответствует требованиям вашей версии YOLO (например, 416 для YOLOv3, 608 для YOLOv4).
- Формат выходных данных: Проверьте структуру выходного тензора вашей модели и адаптируйте функцию postprocess_output. Например, для YOLOv5 выходные данные могут отличаться.
- Классы: Если нужно обнаружить несколько классов, измените логику обработки вероятностей классов в postprocess_output.