Для чего нужна данная статья? :
Найти компромисс между Rust, Erlang или Akka.
Научиться использовать ML, для создания распределенных систем и акторов.
Зачем Вам это уметь? :
Научиться создавать акторную модель с использованием:
Actix ( Actix Book - Distributed Systems)
Добавьте в ваш Cargo.toml зависимость от Actix:
[dependencies]
actix = "0.14.0"
В этом примере созданы два актора: GreetActor, который обрабатывает сообщение Greet, и DoubleActor, который обрабатывает сообщение Double. Оба актора работают параллельно, и мы можем отправлять им сообщения независимо друг от друга.
use actix::prelude::*;
// Сообщения для акторов
#[derive(Message)]
#[rtype(result = "String")]
struct Greet {
name: String,
}
#[derive(Message)]
#[rtype(result = "i32")]
struct Double;
// Актор
struct GreetActor;
impl Actor for GreetActor {
type Context = Context<Self>;
}
// Обработка сообщений для первого актора
impl Handler<Greet> for GreetActor {
type Result = String;
fn handle(&mut self, msg: Greet, _: &mut Context<Self>) -> Self::Result {
format!("Hello, {}!", msg.name)
}
}
// Актор
struct DoubleActor;
impl Actor for DoubleActor {
type Context = Context<Self>;
}
// Обработка сообщений для второго актора
impl Handler<Double> for DoubleActor {
type Result = i32;
fn handle(&mut self, _: Double, _: &mut Context<Self>) -> Self::Result {
2
}
}
#[actix::main]
async fn main() {
// Создание системы акторов
let system = System::new();
// Создание экземпляров акторов
let greet_actor = GreetActor.start();
let double_actor = DoubleActor.start();
// Отправка сообщений акторам
let greet_result = greet_actor.send(Greet {
name: "Actix".to_string(),
});
let double_result = double_actor.send(Double);
// Обработка результатов
match greet_result.await {
Ok(response) => println!("{}", response),
Err(e) => println!("Error: {:?}", e),
}
match double_result.await {
Ok(response) => println!("Doubled: {}", response),
Err(e) => println!("Error: {:?}", e),
}
// Остановка системы акторов
system.stop();
}
Riker ( Riker - Distributed Systems)
use riker::actors::*;
use riker_default::DefaultModel;
#[derive(Debug, PartialEq, Eq)]
struct Greet {
name: String,
}
#[derive(Debug, PartialEq, Eq)]
struct Double;
#[actor(Greet, Double)]
struct GreetActor;
impl Actor for GreetActor {
type Msg = Msg;
fn receive(&mut self, ctx: &Context<Self::Msg>, msg: Self::Msg, _sender: Option<BasicActorRef>) {
match msg {
Greet { name } => {
println!("Hello, {}!", name);
}
_ => (),
}
}
}
#[actor(Greet, Double)]
struct DoubleActor;
impl Actor for DoubleActor {
type Msg = Msg;
fn receive(&mut self, ctx: &Context<Self::Msg>, msg: Self::Msg, _sender: Option<BasicActorRef>) {
match msg {
Double => {
println!("Doubled!");
}
_ => (),
}
}
}
fn main() {
let system = ActorSystem::new(ActorSystemConfig::load("app.conf")).unwrap();
// Создание экземпляров акторов
let greet_actor = system.actor_of::<GreetActor>("greet-actor").unwrap();
let double_actor = system.actor_of::<DoubleActor>("double-actor").unwrap();
// Отправка сообщений акторам
greet_actor.tell(Greet {
name: "Riker".to_string(),
}, None);
double_actor.tell(Double, None);
// Задержка для обработки сообщений
std::thread::sleep(std::time::Duration::from_secs(1));
// Остановка системы акторов
system.shutdown().unwrap();
}
Tokio с Serde
Добавьте следующие зависимости в ваш Cargo.toml:
[dependencies]
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
Этот пример создает два актора (greet_actor и double_actor), каждый из которых асинхронно ожидает сообщения из своего канала. В main, мы создаем каналы для взаимодействия с акторами, отправляем сообщения и ожидаем завершения.
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio::time::Duration;
// Сообщение для акторов
#[derive(Debug, Serialize, Deserialize)]
enum Message {
Greet { name: String },
Double,
}
// Актор
async fn greet_actor(mut receiver: mpsc::Receiver<Message>) {
while let Some(msg) = receiver.recv().await {
match msg {
Message::Greet { name } => {
println!("Hello, {}!", name);
}
_ => (),
}
}
}
// Актор
async fn double_actor(mut receiver: mpsc::Receiver<Message>) {
while let Some(msg) = receiver.recv().await {
match msg {
Message::Double => {
println!("Doubled!");
}
_ => (),
}
}
}
#[tokio::main]
async fn main() {
// Создание каналов для взаимодействия с акторами
let (greet_sender, greet_receiver) = mpsc::channel(10);
let (double_sender, double_receiver) = mpsc::channel(10);
// Запуск акторов
tokio::spawn(greet_actor(greet_receiver));
tokio::spawn(double_actor(double_receiver));
// Отправка сообщений акторам
greet_sender.send(Message::Greet {
name: "Tokio".to_string(),
}).await.unwrap();
double_sender.send(Message::Double).await.unwrap();
// Задержка для обработки сообщений
tokio::time::sleep(Duration::from_secs(1)).await;
}
Cистема с федеративным обучением, где узлы обучают модели на локальных данных, обмениваются обновлениями с центральным сервером через асинхронное сетевое взаимодействие, и включаем обработку ошибок, обнаружение узлов и балансировку нагрузки.
- Архитектура:CentralServer: Агрегирует параметры моделей от узлов и рассылает обновления.
NodeActor: Локально обучает модель, отправляет параметры на сервер и принимает обновления. - Федеративное обучение:Каждый узел обучает модель на своих данных (в примере данные фиктивные).
Центральный сервер использует алгоритм FedAvg для усреднения параметров. - Асинхронность:Используем tokio и actix для асинхронного взаимодействия между узлами и сервером.
- Обработка ошибок:Логируем сбои с помощью log.
Используем anyhow для удобной обработки ошибок.
Добавьте следующие зависимости в Cargo.toml:
[dependencies]
actix = "0.13"
actix-rt = "2.8"
actix-web = "4.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.23", features = ["full"] }
tch = "0.10"
rand = "0.8"
log = "0.4"
env_logger = "0.10"
awc = "3.1"
anyhow = "1.0"
futures = "0.3"
etcd-client = "0.8"
tokio = { version = "1", features = ["full"] } # Для асинхронного выполнения
1. Определение структур данных и сообщений
use serde::{Deserialize, Serialize};
use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Kind, Tensor};
use actix::prelude::*;
use anyhow::Result;
use log::{info, error};
// Типы сообщений для обмена между узлами и сервером
#[derive(Serialize, Deserialize, Message)]
#[rtype(result = "Result<(), String>")]
enum NodeMessage {
Train { data: Vec<f32>, labels: Vec<f32> }, // Запрос на обучение
Predict { data: Vec<f32> }, // Запрос на предсказание
UpdateModel { params: Vec<f32> }, // Обновление параметров модели
RegisterNode { addr: String }, // Регистрация узла на сервере
}
// Сообщение для центрального сервера
#[derive(Message)]
#[rtype(result = "Result<(), String>")]
struct ServerMessage {
node_addr: String,
params: Vec<f32>,
}
// Структура модели (простая нейронная сеть)
fn build_model(vs: &nn::Path) -> nn::Sequential {
nn::seq()
.add(nn::linear(vs / "layer1", 10, 64, Default::default()))
.add_fn(|xs| xs.relu())
.add(nn::linear(vs / "layer2", 64, 1, Default::default()))
}
2. Реализация актора узла (NodeActor)
struct NodeActor {
model: nn::Sequential,
opt: nn::Optimizer,
device: Device,
server_addr: String,
node_addr: String,
}
impl Actor for NodeActor {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
info!("Node {} started", self.node_addr);
// Регистрация узла на сервере
let msg = NodeMessage::RegisterNode {
addr: self.node_addr.clone(),
};
ctx.run_later(std::time::Duration::from_secs(1), move |act, _| {
act.send_to_server(msg);
});
}
}
impl Handler<NodeMessage> for NodeActor {
type Result = Result<(), String>;
fn handle(&mut self, msg: NodeMessage, ctx: &mut Self::Context) -> Self::Result {
match msg {
NodeMessage::Train { data, labels } => {
info!("Training on node {}", self.node_addr);
let data_tensor = Tensor::of_slice(&data).to_device(self.device);
let labels_tensor = Tensor::of_slice(&labels).to_device(self.device);
// Обучение модели
for _ in 0..10 { // 10 эпох для примера
self.opt.zero_grad();
let pred = self.model.forward(&data_tensor);
let loss = pred.mse_loss(&labels_tensor, tch::Reduction::Mean);
loss.backward();
self.opt.step();
}
// Отправка обновленных параметров на сервер
let params = self.extract_params()?;
let server_msg = ServerMessage {
node_addr: self.node_addr.clone(),
params,
};
Arbiter::current().spawn(async move {
if let Err(e) = CENTRAL_SERVER.try_send(server_msg) {
error!("Failed to send params to server: {}", e);
}
});
Ok(())
}
NodeMessage::Predict { data } => {
let data_tensor = Tensor::of_slice(&data).to_device(self.device);
let pred = self.model.forward(&data_tensor);
info!("Prediction on node {}: {:?}", self.node_addr, pred);
Ok(())
}
NodeMessage::UpdateModel { params } => {
self.update_model(¶ms)?;
info!("Model updated on node {}", self.node_addr);
Ok(())
}
NodeMessage::RegisterNode { .. } => Ok(()), // Игнорируем, так как это для сервера
}
}
}
impl NodeActor {
fn new(server_addr: String, node_addr: String) -> Self {
let vs = nn::VarStore::new(Device::Cpu);
let model = build_model(&vs.root());
let opt = nn::Adam::default().build(&vs, 1e-3).unwrap();
Self {
model,
opt,
device: Device::Cpu,
server_addr,
node_addr,
}
}
fn extract_params(&self) -> Result<Vec<f32>> {
let mut params = Vec::new();
for param in self.model.parameters() {
let p = param.flatten().to_kind(Kind::Float);
params.extend(p.data::<f32>()?.to_vec());
}
Ok(params)
}
fn update_model(&mut self, params: &[f32]) -> Result<()> {
let mut offset = 0;
for mut param in self.model.parameters() {
let size = param.numel();
let slice = ¶ms[offset..offset + size];
offset += size;
param.copy_(&Tensor::of_slice(slice).to_device(self.device));
}
Ok(())
}
fn send_to_server(&self, msg: NodeMessage) {
let server_addr = self.server_addr.clone();
Arbiter::current().spawn(async move {
let client = awc::Client::default();
if let Err(e) = client.post(&server_addr).send_json(&msg).await {
error!("Failed to send message to server: {}", e);
}
});
}
}
3. Реализация центрального сервера (CentralServer)
struct CentralServer {
nodes: Vec<String>,
aggregated_params: Option<Vec<f32>>,
model: nn::Sequential,
updates_received: usize,
total_nodes: usize,
}
impl Actor for CentralServer {
type Context = Context<Self>;
}
impl Handler<ServerMessage> for CentralServer {
type Result = Result<(), String>;
fn handle(&mut self, msg: ServerMessage, _ctx: &mut Self::Context) -> Self::Result {
if !self.nodes.contains(&msg.node_addr) {
self.nodes.push(msg.node_addr.clone());
}
self.updates_received += 1;
if let Some(ref mut aggregated) = self.aggregated_params {
for (a, b) in aggregated.iter_mut().zip(msg.params.iter()) {
*a += *b;
}
} else {
self.aggregated_params = Some(msg.params);
}
if self.updates_received >= self.total_nodes {
self.aggregate_and_broadcast()?;
self.updates_received = 0;
self.aggregated_params = None;
}
Ok(())
}
}
impl CentralServer {
fn new(total_nodes: usize) -> Self {
let vs = nn::VarStore::new(Device::Cpu);
let model = build_model(&vs.root());
Self {
nodes: Vec::new(),
aggregated_params: None,
model,
updates_received: 0,
total_nodes,
}
}
fn aggregate_and_broadcast(&mut self) -> Result<()> {
let aggregated = self.aggregated_params.as_mut().unwrap();
for param in aggregated.iter_mut() {
*param /= self.total_nodes as f32; // FedAvg
}
self.update_model(aggregated)?;
let msg = NodeMessage::UpdateModel {
params: aggregated.clone(),
};
for node in &self.nodes {
let client = awc::Client::default();
let node_addr = format!("{}/message", node);
Arbiter::current().spawn(async move {
if let Err(e) = client.post(&node_addr).send_json(&msg).await {
error!("Failed to broadcast to {}: {}", node_addr, e);
}
});
}
Ok(())
}
fn update_model(&mut self, params: &[f32]) -> Result<()> {
let mut offset = 0;
for mut param in self.model.parameters() {
let size = param.numel();
let slice = ¶ms[offset..offset + size];
offset += size;
param.copy_(&Tensor::of_slice(slice).to_device(Device::Cpu));
}
Ok(())
}
}
static CENTRAL_SERVER: Lazy<Addr<CentralServer>> = Lazy::new(|| {
CentralServer::new(2).start() // Ожидаем обновления от 2 узлов
});
4. Сетевые маршруты
use actix_web::{web, App, HttpResponse, HttpServer};
use futures::future::join_all;
async fn receive_message(
msg: web::Json<NodeMessage>,
actor: web::Data<Addr<NodeActor>>,
) -> HttpResponse {
match actor.send(msg.0).await {
Ok(Ok(())) => HttpResponse::Ok().json(serde_json::json!({"status": "received"})),
_ => HttpResponse::InternalServerError().json(serde_json::json!({"status": "error"})),
}
}
async fn start_node(server_addr: &str, node_addr: &str) -> std::io::Result<()> {
let actor = NodeActor::new(server_addr.to_string(), node_addr.to_string()).start();
HttpServer::new(move || {
App::new()
.app_data(web::Data::new(actor.clone()))
.route("/message", web::post().to(receive_message))
})
.bind(node_addr)?
.run()
.await
}
5. Точка входа
use once_cell::sync::Lazy;
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
env_logger::init();
info!("Starting distributed ML system");
let server_addr = "http://127.0.0.1:8080/message";
let nodes = vec![
start_node(server_addr, "127.0.0.1:8081"),
start_node(server_addr, "127.0.0.1:8082"),
];
// Симуляция обучения
let client = awc::Client::default();
let train_msg = NodeMessage::Train {
data: vec![1.0; 10], // Пример данных
labels: vec![0.5; 1],
};
for port in 8081..=8082 {
let url = format!("http://127.0.0.1:{}/message", port);
client.post(&url).send_json(&train_msg).await.unwrap();
}
join_all(nodes).await;
Ok(())
}
Обнаружение узлов через etcd, Балансировка нагрузки, Поддержка GPU через tch-rs:
use actix::prelude::*;
use etcd_client::{Client, Error as EtcdError};
use tch::{nn, Device};
use std::time::Duration;
// Центральный сервер
struct CentralServer {
nodes: Vec<String>,
current_node_index: usize,
etcd_addr: String,
}
impl Actor for CentralServer {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
// Периодическое обновление списка узлов
ctx.run_interval(Duration::from_secs(10), |act, _| {
actix_rt::spawn(async move {
if let Ok(nodes) = discover_nodes(&act.etcd_addr).await {
act.nodes = nodes;
}
});
});
}
}
// Узел
struct NodeActor {
model: nn::Sequential,
device: Device,
server_addr: String,
node_addr: String,
etcd_addr: String,
}
impl Actor for NodeActor {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
println!("Узел {} запущен", self.node_addr);
// Регистрация узла в etcd
let etcd_addr = self.etcd_addr.clone();
let node_addr = self.node_addr.clone();
actix_rt::spawn(async move {
if let Err(e) = register_node(&etcd_addr, &node_addr).await {
eprintln!("Ошибка регистрации узла: {}", e);
}
});
}
}
impl NodeActor {
fn new(server_addr: String, node_addr: String, etcd_addr: String) -> Self {
let device = if tch::Cuda::is_available() {
Device::Cuda(0)
} else {
Device::Cpu
};
let vs = nn::VarStore::new(device);
let model = build_model(&vs.root());
Self {
model,
device,
server_addr,
node_addr,
etcd_addr,
}
}
}
// Точка входа
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
let etcd_addr = "http://127.0.0.1:2379";
let server_addr = "127.0.0.1:8080";
// Запуск узлов
let node1 = NodeActor::new(server_addr.to_string(), "127.0.0.1:8081".to_string(), etcd_addr.to_string()).start();
let node2 = NodeActor::new(server_addr.to_string(), "127.0.0.1:8082".to_string(), etcd_addr.to_string()).start();
// Запуск сервера
let server = CentralServer {
nodes: Vec::new(),
current_node_index: 0,
etcd_addr: etcd_addr.to_string(),
}.start();
actix_rt::System::current().run()
}
полный код
use actix::prelude::*;
use actix_web::{web, App, HttpResponse, HttpServer};
use anyhow::Result;
use etcd_client::{Client as EtcdClient, Error as EtcdError};
use futures::future::join_all;
use log::{error, info};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Kind, Tensor};
// Messages for communication between nodes and server
#[derive(Serialize, Deserialize, Message)]
#[rtype(result = "Result<(), String>")]
enum NodeMessage {
Train { data: Vec<f32>, labels: Vec<f32> }, // Train request
Predict { data: Vec<f32> }, // Prediction request
UpdateModel { params: Vec<f32> }, // Model parameter update
RegisterNode { addr: String }, // Node registration
}
// Message for the central server
#[derive(Message)]
#[rtype(result = "Result<(), String>")]
struct ServerMessage {
node_addr: String,
params: Vec<f32>,
}
// Neural network model definition
fn build_model(vs: &nn::Path) -> nn::Sequential {
nn::seq()
.add(nn::linear(vs / "layer1", 10, 64, Default::default()))
.add_fn(|xs| xs.relu())
.add(nn::linear(vs / "layer2", 64, 1, Default::default()))
}
// CentralServer actor
struct CentralServer {
nodes: Vec<String>,
aggregated_params: Option<Vec<f32>>,
model: nn::Sequential,
updates_received: usize,
total_nodes: usize,
etcd_addr: String,
current_node_index: usize, // For load balancing
}
impl Actor for CentralServer {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
info!("CentralServer started");
let etcd_addr = self.etcd_addr.clone();
ctx.run_interval(Duration::from_secs(10), move |act, _| {
let etcd_addr = etcd_addr.clone();
actix_rt::spawn(async move {
match discover_nodes(&etcd_addr).await {
Ok(nodes) => act.nodes = nodes,
Err(e) => error!("Failed to discover nodes: {}", e),
}
});
});
}
}
impl Handler<ServerMessage> for CentralServer {
type Result = Result<(), String>;
fn handle(&mut self, msg: ServerMessage, _ctx: &mut Self::Context) -> Self::Result {
if !self.nodes.contains(&msg.node_addr) {
self.nodes.push(msg.node_addr.clone());
}
self.updates_received += 1;
if let Some(ref mut aggregated) = self.aggregated_params {
for (a, b) in aggregated.iter_mut().zip(msg.params.iter()) {
*a += *b;
}
} else {
self.aggregated_params = Some(msg.params);
}
if self.updates_received >= self.total_nodes {
self.aggregate_and_broadcast()?;
self.updates_received = 0;
self.aggregated_params = None;
}
Ok(())
}
}
impl CentralServer {
fn new(total_nodes: usize, etcd_addr: String) -> Self {
let vs = nn::VarStore::new(Device::Cpu);
let model = build_model(&vs.root());
Self {
nodes: Vec::new(),
aggregated_params: None,
model,
updates_received: 0,
total_nodes,
etcd_addr,
current_node_index: 0,
}
}
fn aggregate_and_broadcast(&mut self) -> Result<()> {
let aggregated = self.aggregated_params.as_mut().unwrap();
for param in aggregated.iter_mut() {
*param /= self.total_nodes as f32; // FedAvg algorithm
}
self.update_model(aggregated)?;
let msg = NodeMessage::UpdateModel {
params: aggregated.clone(),
};
for node in &self.nodes {
let client = awc::Client::default();
let node_addr = format!("{}/message", node);
actix_rt::spawn(async move {
if let Err(e) = client.post(&node_addr).send_json(&msg).await {
error!("Failed to broadcast to {}: {}", node_addr, e);
}
});
}
Ok(())
}
fn update_model(&mut self, params: &[f32]) -> Result<()> {
let mut offset = 0;
for mut param in self.model.parameters() {
let size = param.numel();
let slice = ¶ms[offset..offset + size];
offset += size;
param.copy_(&Tensor::of_slice(slice).to_device(Device::Cpu));
}
Ok(())
}
}
// NodeActor actor
struct NodeActor {
model: nn::Sequential,
opt: nn::Optimizer,
device: Device,
server_addr: String,
node_addr: String,
etcd_addr: String,
}
impl Actor for NodeActor {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
info!("Node {} started", self.node_addr);
let etcd_addr = self.etcd_addr.clone();
let node_addr = self.node_addr.clone();
actix_rt::spawn(async move {
if let Err(e) = register_node(&etcd_addr, &node_addr).await {
error!("Failed to register node in etcd: {}", e);
}
});
let msg = NodeMessage::RegisterNode {
addr: self.node_addr.clone(),
};
self.send_to_server(msg);
}
}
impl Handler<NodeMessage> for NodeActor {
type Result = Result<(), String>;
fn handle(&mut self, msg: NodeMessage, _ctx: &mut Self::Context) -> Self::Result {
match msg {
NodeMessage::Train { data, labels } => {
info!("Training on node {}", self.node_addr);
let data_tensor = Tensor::of_slice(&data).to_device(self.device);
let labels_tensor = Tensor::of_slice(&labels).to_device(self.device);
for _ in 0..10 {
self.opt.zero_grad();
let pred = self.model.forward(&data_tensor);
let loss = pred.mse_loss(&labels_tensor, tch::Reduction::Mean);
loss.backward();
self.opt.step();
}
let params = self.extract_params()?;
let server_msg = ServerMessage {
node_addr: self.node_addr.clone(),
params,
};
actix_rt::spawn(async move {
if let Err(e) = CENTRAL_SERVER.try_send(server_msg) {
error!("Failed to send params to server: {}", e);
}
});
Ok(())
}
NodeMessage::Predict { data } => {
let data_tensor = Tensor::of_slice(&data).to_device(self.device);
let pred = self.model.forward(&data_tensor);
info!("Prediction on node {}: {:?}", self.node_addr, pred);
Ok(())
}
NodeMessage::UpdateModel { params } => {
self.update_model(¶ms)?;
info!("Model updated on node {}", self.node_addr);
Ok(())
}
NodeMessage::RegisterNode { .. } => Ok(()),
}
}
}
impl NodeActor {
fn new(server_addr: String, node_addr: String, etcd_addr: String) -> Self {
let device = if tch::Cuda::is_available() {
info!("GPU available, using CUDA");
Device::Cuda(0)
} else {
info!("No GPU available, using CPU");
Device::Cpu
};
let vs = nn::VarStore::new(device);
let model = build_model(&vs.root());
let opt = nn::Adam::default().build(&vs, 1e-3).unwrap();
Self {
model,
opt,
device,
server_addr,
node_addr,
etcd_addr,
}
}
fn extract_params(&self) -> Result<Vec<f32>> {
let mut params = Vec::new();
for param in self.model.parameters() {
let p = param.flatten().to_kind(Kind::Float);
params.extend(p.data::<f32>()?.to_vec());
}
Ok(params)
}
fn update_model(&mut self, params: &[f32]) -> Result<()> {
let mut offset = 0;
for mut param in self.model.parameters() {
let size = param.numel();
let slice = ¶ms[offset..offset + size];
offset += size;
param.copy_(&Tensor::of_slice(slice).to_device(self.device));
}
Ok(())
}
fn send_to_server(&self, msg: NodeMessage) {
let server_addr = self.server_addr.clone();
actix_rt::spawn(async move {
let client = awc::Client::default();
if let Err(e) = client.post(&server_addr).send_json(&msg).await {
error!("Failed to send message to server: {}", e);
}
});
}
}
// etcd node discovery and registration
async fn discover_nodes(etcd_addr: &str) -> Result<Vec<String>, EtcdError> {
let client = EtcdClient::connect([etcd_addr], None).await?;
let resp = client.get("/nodes/", None).await?;
let nodes = resp
.kvs()
.iter()
.map(|kv| kv.value_str().unwrap().to_string())
.collect();
Ok(nodes)
}
async fn register_node(etcd_addr: &str, node_addr: &str) -> Result<(), EtcdError> {
let client = EtcdClient::connect([etcd_addr], None).await?;
client
.put(format!("/nodes/{}", node_addr), node_addr, None)
.await?;
Ok(())
}
// HTTP endpoint for receiving messages
async fn receive_message(
msg: web::Json<NodeMessage>,
actor: web::Data<Addr<NodeActor>>,
) -> HttpResponse {
match actor.send(msg.0).await {
Ok(Ok(())) => HttpResponse::Ok().json(serde_json::json!({"status": "received"})),
Ok(Err(e)) => HttpResponse::InternalServerError().json(serde_json::json!({"status": "error", "message": e})),
Err(e) => HttpResponse::InternalServerError().json(serde_json::json!({"status": "error", "message": format!("Actor error: {}", e)})),
}
}
async fn start_node(server_addr: &str, node_addr: &str, etcd_addr: &str) -> std::io::Result<()> {
let actor = NodeActor::new(
server_addr.to_string(),
node_addr.to_string(),
etcd_addr.to_string(),
)
.start();
HttpServer::new(move || {
App::new()
.app_data(web::Data::new(actor.clone()))
.route("/message", web::post().to(receive_message))
})
.bind(node_addr)?
.run()
.await
}
// Global reference to CentralServer
static CENTRAL_SERVER: Lazy<Addr<CentralServer>> = Lazy::new(|| {
CentralServer::new(2, "http://127.0.0.1:2379".to_string()).start()
});
// Main entry point
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
env_logger::init();
info!("Starting federated learning system");
let server_addr = "http://127.0.0.1:8080/message";
let etcd_addr = "http://127.0.0.1:2379";
let nodes = vec![
start_node(server_addr, "127.0.0.1:8081", etcd_addr),
start_node(server_addr, "127.0.0.1:8082", etcd_addr),
];
// Simulate training
let client = awc::Client::default();
let train_msg = NodeMessage::Train {
data: vec![1.0; 10],
labels: vec![0.5; 1],
};
for port in 8081..=8082 {
let url = format!("http://127.0.0.1:{}/message", port);
if let Err(e) = client.post(&url).send_json(&train_msg).await {
error!("Failed to send training message to port {}: {}", port, e);
}
}
join_all(nodes).await;
Ok(())
}