Добавить в корзинуПозвонить
Найти в Дзене
Один Rust не п...Rust

TensorFlow в Rust

Для чего нужна данная статья? : Написать собственный фреймворк, который: Зачем Вам это уметь? : Есть официальный Rust-крейт tensorflow, который предоставляет API для работы с моделями TensorFlow. Он позволяет: Пример загрузки модели и выполнения инференса: use tensorflow::{Graph, Session, SessionOptions, SessionRunArgs, Tensor}; fn main() -> tensorflow::Result<()> { let model_dir = "my_model"; // Путь к модели let mut graph = Graph::new(); let mut session = Session::new(&SessionOptions::new(), &graph)?; let input_tensor = Tensor::new(&[1]).with_values(&[3.5f32])?; let mut args = SessionRunArgs::new(); let input_op = graph.operation_by_name_required("input")?; let output_op = graph.operation_by_name_required("output")?; args.add_feed(&input_op, 0, &input_tensor); let output_token = args.request_fetch(&output_op, 0); session.run(&mut args)?; let output_tensor: Tensor<f32> = args.fetch(output_token)?; println!("Result: {:?}", output_tensor[0]); Ok(()) } Можно напрямую работать с библиот
Оглавление
ML на RUST без заморочек
Один Rust не п...Rust

Для чего нужна данная статья? :

Написать собственный фреймворк, который:

  1. Загружает и выполняет инференс модели на Rust с помощью официального tensorflow.
  2. Использует FFI (libtensorflow) для работы с TensorFlow напрямую.
  3. Загружает ONNX-модель через tract и делает предсказание.
  4. Интегрируется с Python (pyo3), чтобы вызвать инференс Keras-модели.
  5. Обращается к TensorFlow Serving через gRPC (tonic).

Зачем Вам это уметь? :

1. Официальный крейт tensorflow

Есть официальный Rust-крейт tensorflow, который предоставляет API для работы с моделями TensorFlow. Он позволяет:

  • Загружать модели в формате SavedModel.
  • Выполнять инференс на CPU и GPU.
  • Создавать и управлять тензорами.

Пример загрузки модели и выполнения инференса:

use tensorflow::{Graph, Session, SessionOptions, SessionRunArgs, Tensor};

fn main() -> tensorflow::Result<()> {

let model_dir = "my_model"; // Путь к модели

let mut graph = Graph::new();

let mut session = Session::new(&SessionOptions::new(), &graph)?;

let input_tensor = Tensor::new(&[1]).with_values(&[3.5f32])?;

let mut args = SessionRunArgs::new();

let input_op = graph.operation_by_name_required("input")?;

let output_op = graph.operation_by_name_required("output")?;

args.add_feed(&input_op, 0, &input_tensor);

let output_token = args.request_fetch(&output_op, 0);

session.run(&mut args)?;

let output_tensor: Tensor<f32> = args.fetch(output_token)?;

println!("Result: {:?}", output_tensor[0]);

Ok(())

}

2. Использование TensorFlow через FFI (Foreign Function Interface)

Можно напрямую работать с библиотекой TensorFlow через FFI, вызывая функции из libtensorflow. Этот способ требует C-обертки и не такой удобный, как крейт tensorflow, но даёт больше контроля.

Пример:

  1. Компилируем TensorFlow C API (libtensorflow).
  2. Загружаем её в Rust через libloading или bindgen.

3. Использование ONNX и tract в качестве альтернативы

Если нужна только инференс-часть TensorFlow, можно экспортировать модель в ONNX и использовать tract, который поддерживает ONNX и TensorFlow.

Пример с tract для загрузки TensorFlow-модели:

use tract_tensorflow::prelude::*;

fn main() -> TractResult<()> {

let model = tract_tensorflow::tensorflow()

.model_for_path("my_model.pb")?

.into_optimized()?

.into_runnable()?;

let input = tract_ndarray::arr1(&[3.5f32]).into_dyn();

let result = model.run(tvec!(input))?;

println!("Result: {:?}", result[0]);

Ok(())

}

4. Интеграция с Python через pyo3 или rust-cpython

Если модель требует сложной предобработки или TensorFlow использует специфические функции, можно вызывать его через Python.

Пример с pyo3:

use pyo3::prelude::*;

fn main() -> PyResult<()> {

Python::with_gil(|py| {

let tf = py.import("tensorflow")?;

let keras = tf.getattr("keras")?;

let model = keras.call_method0("load_model", "my_model.h5")?;

let input_data = PyTuple::new(py, &[3.5]);

let result = model.call_method1("predict", (input_data,))?;

println!("Result: {:?}", result);

Ok(())

})

}

5. TensorFlow Serving + gRPC в Rust

Можно запустить TensorFlow Serving и использовать tonic (gRPC) для отправки запросов. Это удобно для продакшен-решений.

Пример вызова модели через gRPC:

use tonic::transport::Channel;

use tensorflow_serving::predict::PredictRequest;

#[tokio::main]

async fn main() {

let channel = Channel::from_static("http://localhost:8500").connect().await.unwrap();

let mut client = tensorflow_serving::PredictionServiceClient::new(channel);

let request = PredictRequest {

model_spec: Some(tensorflow_serving::ModelSpec {

name: "my_model".to_string(),

..Default::default()

}),

inputs: vec![("input", vec![3.5f32])].into_iter().collect(),

..Default::default()

};

let response = client.predict(request).await.unwrap();

println!("{:?}", response);

}

Пример

📌 Зависимости (Cargo.toml)

[dependencies]

tensorflow = "0.21.0"

pyo3 = { version = "0.21", features = ["extension-module"] }

tract-tensorflow = "0.20.18"

tonic = "0.11"

prost = "0.12"

tokio = { version = "1", features = ["full"] }

libloading = "0.8"

📌 Код (main.rs)

use tensorflow::{Graph, Session, SessionOptions, SessionRunArgs, Tensor};

use std::ffi::{CString, c_void};

use libloading::{Library, Symbol};

use tract_tensorflow::prelude::*;

use tonic::transport::Channel;

use pyo3::prelude::*;

use tensorflow_serving::predict::PredictRequest;

// 🔹 1. Инференс через официальный крейт tensorflow

fn run_tensorflow_rust() -> tensorflow::Result<()> {

let model_dir = "my_model";

let mut graph = Graph::new();

let mut session = Session::new(&SessionOptions::new(), &graph)?;

let input_tensor = Tensor::new(&[1]).with_values(&[3.5f32])?;

let mut args = SessionRunArgs::new();

let input_op = graph.operation_by_name_required("input")?;

let output_op = graph.operation_by_name_required("output")?;

args.add_feed(&input_op, 0, &input_tensor);

let output_token = args.request_fetch(&output_op, 0);

session.run(&mut args)?;

let output_tensor: Tensor<f32> = args.fetch(output_token)?;

println!("TensorFlow Rust: {:?}", output_tensor[0]);

Ok(())

}

// 🔹 2. FFI с libtensorflow

fn run_tensorflow_ffi() {

unsafe {

let lib = Library::new("libtensorflow.so").unwrap();

let tf_version: Symbol<unsafe extern "C" fn() -> *const i8> = lib.get(b"TF_Version").unwrap();

let version = CString::from_raw(tf_version() as *mut i8).to_str().unwrap();

println!("TensorFlow FFI Version: {}", version);

}

}

// 🔹 3. tract для инференса ONNX/TensorFlow модели

fn run_tract() -> TractResult<()> {

let model = tract_tensorflow::tensorflow()

.model_for_path("my_model.pb")?

.into_optimized()?

.into_runnable()?;

let input = tract_ndarray::arr1(&[3.5f32]).into_dyn();

let result = model.run(tvec!(input))?;

println!("Tract Inference: {:?}", result[0]);

Ok(())

}

// 🔹 4. Вызываем Python-модель через pyo3

fn run_python() -> PyResult<()> {

Python::with_gil(|py| {

let tf = py.import("tensorflow")?;

let keras = tf.getattr("keras")?;

let model = keras.call_method0("load_model", "my_model.h5")?;

let input_data = PyTuple::new(py, &[3.5]);

let result = model.call_method1("predict", (input_data,))?;

println!("Python TensorFlow: {:?}", result);

Ok(())

})

}

// 🔹 5. gRPC запрос к TensorFlow Serving

#[tokio::main]

async fn run_grpc() -> Result<(), Box<dyn std::error::Error>> {

let channel = Channel::from_static("http://localhost:8500").connect().await?;

let mut client = tensorflow_serving::PredictionServiceClient::new(channel);

let request = PredictRequest {

model_spec: Some(tensorflow_serving::ModelSpec {

name: "my_model".to_string(),

..Default::default()

}),

inputs: vec![("input", vec![3.5f32])].into_iter().collect(),

..Default::default()

};

let response = client.predict(request).await?;

println!("TensorFlow Serving gRPC: {:?}", response);

Ok(())

}

fn main() {

println!("Running TensorFlow in Rust...");

run_tensorflow_rust().unwrap();

println!("Running TensorFlow via FFI...");

run_tensorflow_ffi();

println!("Running Tract...");

run_tract().unwrap();

println!("Running Python TensorFlow...");

run_python().unwrap();

println!("Running TensorFlow Serving gRPC...");

tokio::spawn(async { run_grpc().await.unwrap() });

}

Перед запуском убедись, что:

  • Установлен TensorFlow (pip install tensorflow).
  • Установлен TensorFlow Serving (docker run -p 8500:8500 -t tensorflow/serving).
  • В проекте есть файлы my_model.pb и my_model.h5.

Добавим поддержку CUDA для ускоренного вычисления NVIDIA GPU.

  • TensorFlow автоматически использует CUDA, если установлены NVIDIA драйверы и cuDNN.
  • FFI через libtensorflow будет работать с libtensorflow-gpu.so.
  • Tract поддерживает CUDA через tract-onnx + tract-tensorflow, но только для ONNX.
  • Python (pyo3) использует TensorFlow с CUDA (если установлен tensorflow-gpu).
  • TensorFlow Serving gRPC также может работать с CUDA.

Установи NVIDIA CUDA Toolkit и cuDNN:

sudo apt install nvidia-cuda-toolkit

Установи TensorFlow с поддержкой CUDA в Python:

pip install tensorflow[and-cuda]

Установи TensorFlow Serving GPU (если используешь gRPC):

docker pull tensorflow/serving:latest-gpu

📌 Обновленный Cargo.toml (с CUDA)

[dependencies]

tensorflow = "0.21.0"

pyo3 = { version = "0.21", features = ["extension-module"] }

tract-tensorflow = "0.20.18"

tract-onnx = "0.20.18"

tonic = "0.11"

prost = "0.12"

tokio = { version = "1", features = ["full"] }

libloading = "0.8"

cuda-runtime-sys = "0.1.0" # FFI для работы с CUDA напрямую

📌 Обновленный main.rs (с CUDA)

use tensorflow::{Graph, Session, SessionOptions, SessionRunArgs, Tensor};

use std::ffi::{CString, c_void};

use libloading::{Library, Symbol};

use tract_onnx::prelude::*;

use tonic::transport::Channel;

use pyo3::prelude::*;

use tensorflow_serving::predict::PredictRequest;

// 🔹 1. TensorFlow с поддержкой CUDA (если установлено)

fn run_tensorflow_cuda() -> tensorflow::Result<()> {

let model_dir = "my_model";

let mut graph = Graph::new();

let mut session = Session::new(&SessionOptions::new(), &graph)?;

let input_tensor = Tensor::new(&[1]).with_values(&[3.5f32])?;

let mut args = SessionRunArgs::new();

let input_op = graph.operation_by_name_required("input")?;

let output_op = graph.operation_by_name_required("output")?;

args.add_feed(&input_op, 0, &input_tensor);

let output_token = args.request_fetch(&output_op, 0);

session.run(&mut args)?;

let output_tensor: Tensor<f32> = args.fetch(output_token)?;

println!("TensorFlow CUDA: {:?}", output_tensor[0]);

Ok(())

}

// 🔹 2. FFI с libtensorflow (GPU)

fn run_tensorflow_ffi_cuda() {

unsafe {

let lib = Library::new("libtensorflow-gpu.so").unwrap();

let tf_version: Symbol<unsafe extern "C" fn() -> *const i8> = lib.get(b"TF_Version").unwrap();

let version = CString::from_raw(tf_version() as *mut i8).to_str().unwrap();

println!("TensorFlow GPU FFI Version: {}", version);

}

}

// 🔹 3. tract-onnx для инференса на GPU

fn run_tract_cuda() -> TractResult<()> {

let model = tract_onnx::onnx()

.model_for_path("model.onnx")?

.into_optimized()?

.into_runnable()?;

let input = tract_ndarray::arr1(&[3.5f32]).into_dyn();

let result = model.run(tvec!(input))?;

println!("Tract CUDA Inference: {:?}", result[0]);

Ok(())

}

// 🔹 4. TensorFlow через Python (GPU)

fn run_python_cuda() -> PyResult<()> {

Python::with_gil(|py| {

let tf = py.import("tensorflow")?;

let keras = tf.getattr("keras")?;

let model = keras.call_method0("load_model", "my_model.h5")?;

let input_data = PyTuple::new(py, &[3.5]);

let result = model.call_method1("predict", (input_data,))?;

println!("Python TensorFlow CUDA: {:?}", result);

Ok(())

})

}

// 🔹 5. gRPC-запрос к TensorFlow Serving (GPU)

#[tokio::main]

async fn run_grpc_cuda() -> Result<(), Box<dyn std::error::Error>> {

let channel = Channel::from_static("http://localhost:8500").connect().await?;

let mut client = tensorflow_serving::PredictionServiceClient::new(channel);

let request = PredictRequest {

model_spec: Some(tensorflow_serving::ModelSpec {

name: "my_model".to_string(),

..Default::default()

}),

inputs: vec![("input", vec![3.5f32])].into_iter().collect(),

..Default::default()

};

let response = client.predict(request).await?;

println!("TensorFlow Serving gRPC (CUDA): {:?}", response);

Ok(())

}

// 🔹 6. CUDA FFI: Запрос GPU информации через `cuda-runtime-sys`

fn run_cuda_info() {

unsafe {

let cuda_version = cuda_runtime_sys::cudaRuntimeGetVersion();

println!("CUDA Version: {:?}", cuda_version);

}

}

fn main() {

println!("Running TensorFlow with CUDA...");

run_tensorflow_cuda().unwrap();

println!("Running TensorFlow FFI with CUDA...");

run_tensorflow_ffi_cuda();

println!("Running Tract ONNX on CUDA...");

run_tract_cuda().unwrap();

println!("Running Python TensorFlow on CUDA...");

run_python_cuda().unwrap();

println!("Running TensorFlow Serving gRPC with CUDA...");

tokio::spawn(async { run_grpc_cuda().await.unwrap() });

println!("Checking CUDA Runtime Info...");

run_cuda_info();

}