Добавить в корзинуПозвонить
Найти в Дзене
Один Rust не п...Rust

Python + Rust с tract

t.me/oneRustnoqRust Обучить модель ML для предсказаний параметров. Найти альтернативы нормализации в Rust и в Python (Разные mean/scale) . Что происходит на самом деле: → Загрузка Iris → Разделение → Нормализация → Обучение MLP → Конвертация в ONNX → Сохранение трёх файлов. mlp.onnx, test_data.csv, scaler_params.pkl → в папку проекта Rust → Чтение параметров нормализации
→ Загрузка и оптимизация ONNX-модели
→ Чтение тестовых примеров
→ Параллельно для каждого примера:
• Извлечение признаков
• Нормализация (точно такая же, как в Python)
• Создание тензора
• Выполнение модели
• argmax по выходным вероятностям
• Сравнение с истинной меткой
→ Подсчёт точности и вывод результатов Нормализация в Rust повторяет Python. Модель обучена на нормализованных данных (Разные mean/scale) → полная бессмыслица предсказаний. // 1. Чтение параметров нормализации из pickle-файла let (mean, scale): (Vec<f64>, Vec<f64>) = bincode::deserialize(&buffer).unwrap(); let mean: Vec<f32> = mean.into_iter().map(|x| x
Оглавление
nicktretyakov1/Rust_Python_tract | Gitverse
ML на RUST без заморочек

t.me/oneRustnoqRust

Для чего нужна данная статья? :

Обучить модель ML для предсказаний параметров.

Зачем Вам это уметь? :

Найти альтернативы нормализации в Rust и в Python (Разные mean/scale) .

Python-часть: Обучение, подготовка и экспорт модели (train_model.py)

Что происходит на самом деле:

  • Модель учится на нормализованных данных (среднее ≈ 0, дисперсия ≈ 1)
  • Параметры нормализации (mean_, scale_) сохраняются отдельно
  • ONNX-файл содержит граф вычислений всей модели (веса, активации, структура слоёв)
  • Тестовые данные сохраняются в исходном масштабе (не нормализованные), чтобы Rust мог повторить весь пайплайн.

→ Загрузка Iris → Разделение → Нормализация → Обучение MLP → Конвертация в ONNX → Сохранение трёх файлов.

mlp.onnx, test_data.csv, scaler_params.pkl → в папку проекта Rust

Rust-часть: Загрузка, нормализация и инференс (src/main.rs)

→ Чтение параметров нормализации
→ Загрузка и оптимизация ONNX-модели
→ Чтение тестовых примеров
→ Параллельно для каждого примера:
• Извлечение признаков
• Нормализация (точно такая же, как в Python)
• Создание тензора
• Выполнение модели
• argmax по выходным вероятностям
• Сравнение с истинной меткой
→ Подсчёт точности и вывод результатов

Нормализация в Rust повторяет Python. Модель обучена на нормализованных данных (Разные mean/scale) → полная бессмыслица предсказаний.

// 1. Чтение параметров нормализации из pickle-файла

let (mean, scale): (Vec<f64>, Vec<f64>) = bincode::deserialize(&buffer).unwrap();

let mean: Vec<f32> = mean.into_iter().map(|x| x as f32).collect();

let scale: Vec<f32> = scale.into_iter().map(|x| x as f32).collect();

// 2. Загрузка ONNX-модели

let model = tract_onnx::onnx()

.model_for_path("mlp.onnx")?

.with_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(1, 4)))? // ← важный момент

.into_optimized()?

.into_runnable()?;

let model = Arc::new(model); // для безопасного многопоточного использования

// 3. Чтение тестовых данных (не нормализованных!)

let mut rdr = Reader::from_path("test_data.csv")?;

let records: Vec<_> = rdr.records().collect::<Result<_, _>>()?;

// 4. Параллельный инференс (самая интересная часть)

let predictions: Vec<_> = records.par_iter().map(|record| {

// 4.1 Извлечение 4-х признаков (как строки → f32)

let features: Vec<f32> = record.iter().take(4).map(|s| s.parse::<f32>().unwrap()).collect();

// 4.2 Ручная нормализация (точно такая же, как в Python)

let normalized: Vec<f32> = features.iter()

.zip(&mean)

.zip(&scale)

.map(|((x, m), s)| (x - m) / s)

.collect();

// 4.3 Создание входного тензора (batch=1, 4 признака)

let input = Array::from_shape_vec((1, 4), normalized).unwrap().into_tensor();

// 4.4 Выполнение модели (forward pass)

let output = model.run(tvec!(input)).unwrap();

// 4.5 Получение предсказаний (3 вероятности — для 3 классов Iris)

let probs = output[0].to_array_view::<f32>().unwrap();

let prediction: Vec<i32> = probs.iter().map(|&x| x.round() as i32).collect();

// 4.6 Выбор класса с максимальной вероятностью (argmax)

let predicted_class = prediction.iter()

.enumerate()

.max_by_key(|&(_, &val)| val)

.unwrap().0 as i32;

// 4.7 Истинная метка (5-й столбец в CSV)

let target: i32 = record[4].parse::<i32>().unwrap();

(predicted_class, target)

}).collect();

// 5. Подсчёт точности и вывод результатов

let correct = predictions.iter().filter(|&&(pred, target)| pred == target).count();

let accuracy = correct as f32 / predictions.len() as f32 * 100.0;

info!("Точность модели: {:.2}%", accuracy);