Найти в Дзене
Один Rust не п...Rust

Rust для распределенных акторов

Оглавление
GitHub - nicktretyakov/federated_learning

Для чего нужна данная статья? :

Найти компромисс между Rust, Erlang или Akka.

Научиться использовать ML, для создания распределенных систем и акторов.

Зачем Вам это уметь? :

Научиться создавать акторную модель с использованием:

Actix ( Actix Book - Distributed Systems)

Добавьте в ваш Cargo.toml зависимость от Actix:

[dependencies]
actix = "0.14.0"

В этом примере созданы два актора: GreetActor, который обрабатывает сообщение Greet, и DoubleActor, который обрабатывает сообщение Double. Оба актора работают параллельно, и мы можем отправлять им сообщения независимо друг от друга.

use actix::prelude::*;
// Сообщения для акторов
#[derive(Message)]
#[rtype(result = "String")]
struct Greet {
name: String,
}
#[derive(Message)]
#[rtype(result = "i32")]
struct Double;
// Актор
struct GreetActor;
impl Actor for GreetActor {
type Context = Context<Self>;
}
// Обработка сообщений для первого актора
impl Handler<Greet> for GreetActor {
type Result = String;
fn handle(&mut self, msg: Greet, _: &mut Context<Self>) -> Self::Result {
format!("Hello, {}!", msg.name)
}
}
// Актор
struct DoubleActor;
impl Actor for DoubleActor {
type Context = Context<Self>;
}
// Обработка сообщений для второго актора
impl Handler<Double> for DoubleActor {
type Result = i32;
fn handle(&mut self, _: Double, _: &mut Context<Self>) -> Self::Result {
2
}
}
#[actix::main]
async fn main() {
// Создание системы акторов
let system = System::new();
// Создание экземпляров акторов
let greet_actor = GreetActor.start();
let double_actor = DoubleActor.start();
// Отправка сообщений акторам
let greet_result = greet_actor.send(Greet {
name: "Actix".to_string(),
});
let double_result = double_actor.send(Double);
// Обработка результатов
match greet_result.await {
Ok(response) => println!("{}", response),
Err(e) => println!("Error: {:?}", e),
}
match double_result.await {
Ok(response) => println!("Doubled: {}", response),
Err(e) => println!("Error: {:?}", e),
}
// Остановка системы акторов
system.stop();
}

Riker ( Riker - Distributed Systems)

use riker::actors::*;
use riker_default::DefaultModel;
#[derive(Debug, PartialEq, Eq)]
struct Greet {
name: String,
}
#[derive(Debug, PartialEq, Eq)]
struct Double;
#[actor(Greet, Double)]
struct GreetActor;
impl Actor for GreetActor {
type Msg = Msg;
fn receive(&mut self, ctx: &Context<Self::Msg>, msg: Self::Msg, _sender: Option<BasicActorRef>) {
match msg {
Greet { name } => {
println!("Hello, {}!", name);
}
_ => (),
}
}
}
#[actor(Greet, Double)]
struct DoubleActor;
impl Actor for DoubleActor {
type Msg = Msg;
fn receive(&mut self, ctx: &Context<Self::Msg>, msg: Self::Msg, _sender: Option<BasicActorRef>) {
match msg {
Double => {
println!("Doubled!");
}
_ => (),
}
}
}
fn main() {
let system = ActorSystem::new(ActorSystemConfig::load("app.conf")).unwrap();
// Создание экземпляров акторов
let greet_actor = system.actor_of::<GreetActor>("greet-actor").unwrap();
let double_actor = system.actor_of::<DoubleActor>("double-actor").unwrap();
// Отправка сообщений акторам
greet_actor.tell(Greet {
name: "Riker".to_string(),
}, None);
double_actor.tell(Double, None);
// Задержка для обработки сообщений
std::thread::sleep(std::time::Duration::from_secs(1));
// Остановка системы акторов
system.shutdown().unwrap();
}

Tokio с Serde

Добавьте следующие зависимости в ваш Cargo.toml:

[dependencies]
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }

Этот пример создает два актора (greet_actor и double_actor), каждый из которых асинхронно ожидает сообщения из своего канала. В main, мы создаем каналы для взаимодействия с акторами, отправляем сообщения и ожидаем завершения.

use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio::time::Duration;
// Сообщение для акторов
#[derive(Debug, Serialize, Deserialize)]
enum Message {
Greet { name: String },
Double,
}
// Актор
async fn greet_actor(mut receiver: mpsc::Receiver<Message>) {
while let Some(msg) = receiver.recv().await {
match msg {
Message::Greet { name } => {
println!("Hello, {}!", name);
}
_ => (),
}
}
}
// Актор
async fn double_actor(mut receiver: mpsc::Receiver<Message>) {
while let Some(msg) = receiver.recv().await {
match msg {
Message::Double => {
println!("Doubled!");
}
_ => (),
}
}
}
#[tokio::main]
async fn main() {
// Создание каналов для взаимодействия с акторами
let (greet_sender, greet_receiver) = mpsc::channel(10);
let (double_sender, double_receiver) = mpsc::channel(10);
// Запуск акторов
tokio::spawn(greet_actor(greet_receiver));
tokio::spawn(double_actor(double_receiver));
// Отправка сообщений акторам
greet_sender.send(Message::Greet {
name: "Tokio".to_string(),
}).await.unwrap();
double_sender.send(Message::Double).await.unwrap();
// Задержка для обработки сообщений
tokio::time::sleep(Duration::from_secs(1)).await;
}

Cистема с федеративным обучением, где узлы обучают модели на локальных данных, обмениваются обновлениями с центральным сервером через асинхронное сетевое взаимодействие, и включаем обработку ошибок, обнаружение узлов и балансировку нагрузки.

  • Архитектура:CentralServer: Агрегирует параметры моделей от узлов и рассылает обновления.
    NodeActor: Локально обучает модель, отправляет параметры на сервер и принимает обновления.
  • Федеративное обучение:Каждый узел обучает модель на своих данных (в примере данные фиктивные).
    Центральный сервер использует алгоритм FedAvg для усреднения параметров.
  • Асинхронность:Используем tokio и actix для асинхронного взаимодействия между узлами и сервером.
  • Обработка ошибок:Логируем сбои с помощью log.
    Используем anyhow для удобной обработки ошибок.

Добавьте следующие зависимости в Cargo.toml:

[dependencies]

actix = "0.13"

actix-rt = "2.8"

actix-web = "4.3"

serde = { version = "1.0", features = ["derive"] }

serde_json = "1.0"

tokio = { version = "1.23", features = ["full"] }

tch = "0.10"

rand = "0.8"

log = "0.4"

env_logger = "0.10"

awc = "3.1"

anyhow = "1.0"

futures = "0.3"

etcd-client = "0.8"
tokio = { version = "1", features = ["full"] } # Для асинхронного выполнения

1. Определение структур данных и сообщений

use serde::{Deserialize, Serialize};

use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Kind, Tensor};

use actix::prelude::*;

use anyhow::Result;

use log::{info, error};

// Типы сообщений для обмена между узлами и сервером

#[derive(Serialize, Deserialize, Message)]

#[rtype(result = "Result<(), String>")]

enum NodeMessage {

Train { data: Vec<f32>, labels: Vec<f32> }, // Запрос на обучение

Predict { data: Vec<f32> }, // Запрос на предсказание

UpdateModel { params: Vec<f32> }, // Обновление параметров модели

RegisterNode { addr: String }, // Регистрация узла на сервере

}

// Сообщение для центрального сервера

#[derive(Message)]

#[rtype(result = "Result<(), String>")]

struct ServerMessage {

node_addr: String,

params: Vec<f32>,

}

// Структура модели (простая нейронная сеть)

fn build_model(vs: &nn::Path) -> nn::Sequential {

nn::seq()

.add(nn::linear(vs / "layer1", 10, 64, Default::default()))

.add_fn(|xs| xs.relu())

.add(nn::linear(vs / "layer2", 64, 1, Default::default()))

}

2. Реализация актора узла (NodeActor)

struct NodeActor {

model: nn::Sequential,

opt: nn::Optimizer,

device: Device,

server_addr: String,

node_addr: String,

}

impl Actor for NodeActor {

type Context = Context<Self>;

fn started(&mut self, ctx: &mut Self::Context) {

info!("Node {} started", self.node_addr);

// Регистрация узла на сервере

let msg = NodeMessage::RegisterNode {

addr: self.node_addr.clone(),

};

ctx.run_later(std::time::Duration::from_secs(1), move |act, _| {

act.send_to_server(msg);

});

}

}

impl Handler<NodeMessage> for NodeActor {

type Result = Result<(), String>;

fn handle(&mut self, msg: NodeMessage, ctx: &mut Self::Context) -> Self::Result {

match msg {

NodeMessage::Train { data, labels } => {

info!("Training on node {}", self.node_addr);

let data_tensor = Tensor::of_slice(&data).to_device(self.device);

let labels_tensor = Tensor::of_slice(&labels).to_device(self.device);

// Обучение модели

for _ in 0..10 { // 10 эпох для примера

self.opt.zero_grad();

let pred = self.model.forward(&data_tensor);

let loss = pred.mse_loss(&labels_tensor, tch::Reduction::Mean);

loss.backward();

self.opt.step();

}

// Отправка обновленных параметров на сервер

let params = self.extract_params()?;

let server_msg = ServerMessage {

node_addr: self.node_addr.clone(),

params,

};

Arbiter::current().spawn(async move {

if let Err(e) = CENTRAL_SERVER.try_send(server_msg) {

error!("Failed to send params to server: {}", e);

}

});

Ok(())

}

NodeMessage::Predict { data } => {

let data_tensor = Tensor::of_slice(&data).to_device(self.device);

let pred = self.model.forward(&data_tensor);

info!("Prediction on node {}: {:?}", self.node_addr, pred);

Ok(())

}

NodeMessage::UpdateModel { params } => {

self.update_model(&params)?;

info!("Model updated on node {}", self.node_addr);

Ok(())

}

NodeMessage::RegisterNode { .. } => Ok(()), // Игнорируем, так как это для сервера

}

}

}

impl NodeActor {

fn new(server_addr: String, node_addr: String) -> Self {

let vs = nn::VarStore::new(Device::Cpu);

let model = build_model(&vs.root());

let opt = nn::Adam::default().build(&vs, 1e-3).unwrap();

Self {

model,

opt,

device: Device::Cpu,

server_addr,

node_addr,

}

}

fn extract_params(&self) -> Result<Vec<f32>> {

let mut params = Vec::new();

for param in self.model.parameters() {

let p = param.flatten().to_kind(Kind::Float);

params.extend(p.data::<f32>()?.to_vec());

}

Ok(params)

}

fn update_model(&mut self, params: &[f32]) -> Result<()> {

let mut offset = 0;

for mut param in self.model.parameters() {

let size = param.numel();

let slice = &params[offset..offset + size];

offset += size;

param.copy_(&Tensor::of_slice(slice).to_device(self.device));

}

Ok(())

}

fn send_to_server(&self, msg: NodeMessage) {

let server_addr = self.server_addr.clone();

Arbiter::current().spawn(async move {

let client = awc::Client::default();

if let Err(e) = client.post(&server_addr).send_json(&msg).await {

error!("Failed to send message to server: {}", e);

}

});

}

}

3. Реализация центрального сервера (CentralServer)

struct CentralServer {

nodes: Vec<String>,

aggregated_params: Option<Vec<f32>>,

model: nn::Sequential,

updates_received: usize,

total_nodes: usize,

}

impl Actor for CentralServer {

type Context = Context<Self>;

}

impl Handler<ServerMessage> for CentralServer {

type Result = Result<(), String>;

fn handle(&mut self, msg: ServerMessage, _ctx: &mut Self::Context) -> Self::Result {

if !self.nodes.contains(&msg.node_addr) {

self.nodes.push(msg.node_addr.clone());

}

self.updates_received += 1;

if let Some(ref mut aggregated) = self.aggregated_params {

for (a, b) in aggregated.iter_mut().zip(msg.params.iter()) {

*a += *b;

}

} else {

self.aggregated_params = Some(msg.params);

}

if self.updates_received >= self.total_nodes {

self.aggregate_and_broadcast()?;

self.updates_received = 0;

self.aggregated_params = None;

}

Ok(())

}

}

impl CentralServer {

fn new(total_nodes: usize) -> Self {

let vs = nn::VarStore::new(Device::Cpu);

let model = build_model(&vs.root());

Self {

nodes: Vec::new(),

aggregated_params: None,

model,

updates_received: 0,

total_nodes,

}

}

fn aggregate_and_broadcast(&mut self) -> Result<()> {

let aggregated = self.aggregated_params.as_mut().unwrap();

for param in aggregated.iter_mut() {

*param /= self.total_nodes as f32; // FedAvg

}

self.update_model(aggregated)?;

let msg = NodeMessage::UpdateModel {

params: aggregated.clone(),

};

for node in &self.nodes {

let client = awc::Client::default();

let node_addr = format!("{}/message", node);

Arbiter::current().spawn(async move {

if let Err(e) = client.post(&node_addr).send_json(&msg).await {

error!("Failed to broadcast to {}: {}", node_addr, e);

}

});

}

Ok(())

}

fn update_model(&mut self, params: &[f32]) -> Result<()> {

let mut offset = 0;

for mut param in self.model.parameters() {

let size = param.numel();

let slice = &params[offset..offset + size];

offset += size;

param.copy_(&Tensor::of_slice(slice).to_device(Device::Cpu));

}

Ok(())

}

}

static CENTRAL_SERVER: Lazy<Addr<CentralServer>> = Lazy::new(|| {

CentralServer::new(2).start() // Ожидаем обновления от 2 узлов

});

4. Сетевые маршруты

use actix_web::{web, App, HttpResponse, HttpServer};

use futures::future::join_all;

async fn receive_message(

msg: web::Json<NodeMessage>,

actor: web::Data<Addr<NodeActor>>,

) -> HttpResponse {

match actor.send(msg.0).await {

Ok(Ok(())) => HttpResponse::Ok().json(serde_json::json!({"status": "received"})),

_ => HttpResponse::InternalServerError().json(serde_json::json!({"status": "error"})),

}

}

async fn start_node(server_addr: &str, node_addr: &str) -> std::io::Result<()> {

let actor = NodeActor::new(server_addr.to_string(), node_addr.to_string()).start();

HttpServer::new(move || {

App::new()

.app_data(web::Data::new(actor.clone()))

.route("/message", web::post().to(receive_message))

})

.bind(node_addr)?

.run()

.await

}

5. Точка входа

use once_cell::sync::Lazy;

#[actix_rt::main]

async fn main() -> std::io::Result<()> {

env_logger::init();

info!("Starting distributed ML system");

let server_addr = "http://127.0.0.1:8080/message";

let nodes = vec![

start_node(server_addr, "127.0.0.1:8081"),

start_node(server_addr, "127.0.0.1:8082"),

];

// Симуляция обучения

let client = awc::Client::default();

let train_msg = NodeMessage::Train {

data: vec![1.0; 10], // Пример данных

labels: vec![0.5; 1],

};

for port in 8081..=8082 {

let url = format!("http://127.0.0.1:{}/message", port);

client.post(&url).send_json(&train_msg).await.unwrap();

}

join_all(nodes).await;

Ok(())

}

Обнаружение узлов через etcd, Балансировка нагрузки, Поддержка GPU через tch-rs:

use actix::prelude::*;

use etcd_client::{Client, Error as EtcdError};

use tch::{nn, Device};

use std::time::Duration;

// Центральный сервер

struct CentralServer {

nodes: Vec<String>,

current_node_index: usize,

etcd_addr: String,

}

impl Actor for CentralServer {

type Context = Context<Self>;

fn started(&mut self, ctx: &mut Self::Context) {

// Периодическое обновление списка узлов

ctx.run_interval(Duration::from_secs(10), |act, _| {

actix_rt::spawn(async move {

if let Ok(nodes) = discover_nodes(&act.etcd_addr).await {

act.nodes = nodes;

}

});

});

}

}

// Узел

struct NodeActor {

model: nn::Sequential,

device: Device,

server_addr: String,

node_addr: String,

etcd_addr: String,

}

impl Actor for NodeActor {

type Context = Context<Self>;

fn started(&mut self, ctx: &mut Self::Context) {

println!("Узел {} запущен", self.node_addr);

// Регистрация узла в etcd

let etcd_addr = self.etcd_addr.clone();

let node_addr = self.node_addr.clone();

actix_rt::spawn(async move {

if let Err(e) = register_node(&etcd_addr, &node_addr).await {

eprintln!("Ошибка регистрации узла: {}", e);

}

});

}

}

impl NodeActor {

fn new(server_addr: String, node_addr: String, etcd_addr: String) -> Self {

let device = if tch::Cuda::is_available() {

Device::Cuda(0)

} else {

Device::Cpu

};

let vs = nn::VarStore::new(device);

let model = build_model(&vs.root());

Self {

model,

device,

server_addr,

node_addr,

etcd_addr,

}

}

}

// Точка входа

#[actix_rt::main]

async fn main() -> std::io::Result<()> {

let etcd_addr = "http://127.0.0.1:2379";

let server_addr = "127.0.0.1:8080";

// Запуск узлов

let node1 = NodeActor::new(server_addr.to_string(), "127.0.0.1:8081".to_string(), etcd_addr.to_string()).start();

let node2 = NodeActor::new(server_addr.to_string(), "127.0.0.1:8082".to_string(), etcd_addr.to_string()).start();

// Запуск сервера

let server = CentralServer {

nodes: Vec::new(),

current_node_index: 0,

etcd_addr: etcd_addr.to_string(),

}.start();

actix_rt::System::current().run()

}

полный код

use actix::prelude::*;

use actix_web::{web, App, HttpResponse, HttpServer};

use anyhow::Result;

use etcd_client::{Client as EtcdClient, Error as EtcdError};

use futures::future::join_all;

use log::{error, info};

use once_cell::sync::Lazy;

use serde::{Deserialize, Serialize};

use std::sync::Arc;

use std::time::Duration;

use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Kind, Tensor};

// Messages for communication between nodes and server

#[derive(Serialize, Deserialize, Message)]

#[rtype(result = "Result<(), String>")]

enum NodeMessage {

Train { data: Vec<f32>, labels: Vec<f32> }, // Train request

Predict { data: Vec<f32> }, // Prediction request

UpdateModel { params: Vec<f32> }, // Model parameter update

RegisterNode { addr: String }, // Node registration

}

// Message for the central server

#[derive(Message)]

#[rtype(result = "Result<(), String>")]

struct ServerMessage {

node_addr: String,

params: Vec<f32>,

}

// Neural network model definition

fn build_model(vs: &nn::Path) -> nn::Sequential {

nn::seq()

.add(nn::linear(vs / "layer1", 10, 64, Default::default()))

.add_fn(|xs| xs.relu())

.add(nn::linear(vs / "layer2", 64, 1, Default::default()))

}

// CentralServer actor

struct CentralServer {

nodes: Vec<String>,

aggregated_params: Option<Vec<f32>>,

model: nn::Sequential,

updates_received: usize,

total_nodes: usize,

etcd_addr: String,

current_node_index: usize, // For load balancing

}

impl Actor for CentralServer {

type Context = Context<Self>;

fn started(&mut self, ctx: &mut Self::Context) {

info!("CentralServer started");

let etcd_addr = self.etcd_addr.clone();

ctx.run_interval(Duration::from_secs(10), move |act, _| {

let etcd_addr = etcd_addr.clone();

actix_rt::spawn(async move {

match discover_nodes(&etcd_addr).await {

Ok(nodes) => act.nodes = nodes,

Err(e) => error!("Failed to discover nodes: {}", e),

}

});

});

}

}

impl Handler<ServerMessage> for CentralServer {

type Result = Result<(), String>;

fn handle(&mut self, msg: ServerMessage, _ctx: &mut Self::Context) -> Self::Result {

if !self.nodes.contains(&msg.node_addr) {

self.nodes.push(msg.node_addr.clone());

}

self.updates_received += 1;

if let Some(ref mut aggregated) = self.aggregated_params {

for (a, b) in aggregated.iter_mut().zip(msg.params.iter()) {

*a += *b;

}

} else {

self.aggregated_params = Some(msg.params);

}

if self.updates_received >= self.total_nodes {

self.aggregate_and_broadcast()?;

self.updates_received = 0;

self.aggregated_params = None;

}

Ok(())

}

}

impl CentralServer {

fn new(total_nodes: usize, etcd_addr: String) -> Self {

let vs = nn::VarStore::new(Device::Cpu);

let model = build_model(&vs.root());

Self {

nodes: Vec::new(),

aggregated_params: None,

model,

updates_received: 0,

total_nodes,

etcd_addr,

current_node_index: 0,

}

}

fn aggregate_and_broadcast(&mut self) -> Result<()> {

let aggregated = self.aggregated_params.as_mut().unwrap();

for param in aggregated.iter_mut() {

*param /= self.total_nodes as f32; // FedAvg algorithm

}

self.update_model(aggregated)?;

let msg = NodeMessage::UpdateModel {

params: aggregated.clone(),

};

for node in &self.nodes {

let client = awc::Client::default();

let node_addr = format!("{}/message", node);

actix_rt::spawn(async move {

if let Err(e) = client.post(&node_addr).send_json(&msg).await {

error!("Failed to broadcast to {}: {}", node_addr, e);

}

});

}

Ok(())

}

fn update_model(&mut self, params: &[f32]) -> Result<()> {

let mut offset = 0;

for mut param in self.model.parameters() {

let size = param.numel();

let slice = &params[offset..offset + size];

offset += size;

param.copy_(&Tensor::of_slice(slice).to_device(Device::Cpu));

}

Ok(())

}

}

// NodeActor actor

struct NodeActor {

model: nn::Sequential,

opt: nn::Optimizer,

device: Device,

server_addr: String,

node_addr: String,

etcd_addr: String,

}

impl Actor for NodeActor {

type Context = Context<Self>;

fn started(&mut self, ctx: &mut Self::Context) {

info!("Node {} started", self.node_addr);

let etcd_addr = self.etcd_addr.clone();

let node_addr = self.node_addr.clone();

actix_rt::spawn(async move {

if let Err(e) = register_node(&etcd_addr, &node_addr).await {

error!("Failed to register node in etcd: {}", e);

}

});

let msg = NodeMessage::RegisterNode {

addr: self.node_addr.clone(),

};

self.send_to_server(msg);

}

}

impl Handler<NodeMessage> for NodeActor {

type Result = Result<(), String>;

fn handle(&mut self, msg: NodeMessage, _ctx: &mut Self::Context) -> Self::Result {

match msg {

NodeMessage::Train { data, labels } => {

info!("Training on node {}", self.node_addr);

let data_tensor = Tensor::of_slice(&data).to_device(self.device);

let labels_tensor = Tensor::of_slice(&labels).to_device(self.device);

for _ in 0..10 {

self.opt.zero_grad();

let pred = self.model.forward(&data_tensor);

let loss = pred.mse_loss(&labels_tensor, tch::Reduction::Mean);

loss.backward();

self.opt.step();

}

let params = self.extract_params()?;

let server_msg = ServerMessage {

node_addr: self.node_addr.clone(),

params,

};

actix_rt::spawn(async move {

if let Err(e) = CENTRAL_SERVER.try_send(server_msg) {

error!("Failed to send params to server: {}", e);

}

});

Ok(())

}

NodeMessage::Predict { data } => {

let data_tensor = Tensor::of_slice(&data).to_device(self.device);

let pred = self.model.forward(&data_tensor);

info!("Prediction on node {}: {:?}", self.node_addr, pred);

Ok(())

}

NodeMessage::UpdateModel { params } => {

self.update_model(&params)?;

info!("Model updated on node {}", self.node_addr);

Ok(())

}

NodeMessage::RegisterNode { .. } => Ok(()),

}

}

}

impl NodeActor {

fn new(server_addr: String, node_addr: String, etcd_addr: String) -> Self {

let device = if tch::Cuda::is_available() {

info!("GPU available, using CUDA");

Device::Cuda(0)

} else {

info!("No GPU available, using CPU");

Device::Cpu

};

let vs = nn::VarStore::new(device);

let model = build_model(&vs.root());

let opt = nn::Adam::default().build(&vs, 1e-3).unwrap();

Self {

model,

opt,

device,

server_addr,

node_addr,

etcd_addr,

}

}

fn extract_params(&self) -> Result<Vec<f32>> {

let mut params = Vec::new();

for param in self.model.parameters() {

let p = param.flatten().to_kind(Kind::Float);

params.extend(p.data::<f32>()?.to_vec());

}

Ok(params)

}

fn update_model(&mut self, params: &[f32]) -> Result<()> {

let mut offset = 0;

for mut param in self.model.parameters() {

let size = param.numel();

let slice = &params[offset..offset + size];

offset += size;

param.copy_(&Tensor::of_slice(slice).to_device(self.device));

}

Ok(())

}

fn send_to_server(&self, msg: NodeMessage) {

let server_addr = self.server_addr.clone();

actix_rt::spawn(async move {

let client = awc::Client::default();

if let Err(e) = client.post(&server_addr).send_json(&msg).await {

error!("Failed to send message to server: {}", e);

}

});

}

}

// etcd node discovery and registration

async fn discover_nodes(etcd_addr: &str) -> Result<Vec<String>, EtcdError> {

let client = EtcdClient::connect([etcd_addr], None).await?;

let resp = client.get("/nodes/", None).await?;

let nodes = resp

.kvs()

.iter()

.map(|kv| kv.value_str().unwrap().to_string())

.collect();

Ok(nodes)

}

async fn register_node(etcd_addr: &str, node_addr: &str) -> Result<(), EtcdError> {

let client = EtcdClient::connect([etcd_addr], None).await?;

client

.put(format!("/nodes/{}", node_addr), node_addr, None)

.await?;

Ok(())

}

// HTTP endpoint for receiving messages

async fn receive_message(

msg: web::Json<NodeMessage>,

actor: web::Data<Addr<NodeActor>>,

) -> HttpResponse {

match actor.send(msg.0).await {

Ok(Ok(())) => HttpResponse::Ok().json(serde_json::json!({"status": "received"})),

Ok(Err(e)) => HttpResponse::InternalServerError().json(serde_json::json!({"status": "error", "message": e})),

Err(e) => HttpResponse::InternalServerError().json(serde_json::json!({"status": "error", "message": format!("Actor error: {}", e)})),

}

}

async fn start_node(server_addr: &str, node_addr: &str, etcd_addr: &str) -> std::io::Result<()> {

let actor = NodeActor::new(

server_addr.to_string(),

node_addr.to_string(),

etcd_addr.to_string(),

)

.start();

HttpServer::new(move || {

App::new()

.app_data(web::Data::new(actor.clone()))

.route("/message", web::post().to(receive_message))

})

.bind(node_addr)?

.run()

.await

}

// Global reference to CentralServer

static CENTRAL_SERVER: Lazy<Addr<CentralServer>> = Lazy::new(|| {

CentralServer::new(2, "http://127.0.0.1:2379".to_string()).start()

});

// Main entry point

#[actix_rt::main]

async fn main() -> std::io::Result<()> {

env_logger::init();

info!("Starting federated learning system");

let server_addr = "http://127.0.0.1:8080/message";

let etcd_addr = "http://127.0.0.1:2379";

let nodes = vec![

start_node(server_addr, "127.0.0.1:8081", etcd_addr),

start_node(server_addr, "127.0.0.1:8082", etcd_addr),

];

// Simulate training

let client = awc::Client::default();

let train_msg = NodeMessage::Train {

data: vec![1.0; 10],

labels: vec![0.5; 1],

};

for port in 8081..=8082 {

let url = format!("http://127.0.0.1:{}/message", port);

if let Err(e) = client.post(&url).send_json(&train_msg).await {

error!("Failed to send training message to port {}: {}", port, e);

}

}

join_all(nodes).await;

Ok(())

}