use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::{Duration, Instant};

/// Identifier of a cluster node; also used as the node's slot index in a [`NodeMask`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub u32);

/// State of a node as tracked by [`NodeManager`].
///
/// Only [`NodeState::Up`] counts as active. The manager's automatic paths
/// ([`NodeManager::record_heartbeat`] and [`NodeManager::check_health`]) only move
/// nodes between `Down` and `Up`; the other states are entered only through
/// [`NodeManager::set_node_state`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    Up,
    Down,
    Unhealthy,
    Banned,
    Disabled,
}

impl NodeState {
    /// Returns `true` only for [`NodeState::Up`].
    pub fn is_active(&self) -> bool {
        matches!(self, NodeState::Up)
    }

    /// Upper-case name of the state, e.g. `"UP"` or `"BANNED"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            NodeState::Up => "UP",
            NodeState::Down => "DOWN",
            NodeState::Unhealthy => "UNHEALTHY",
            NodeState::Banned => "BANNED",
            NodeState::Disabled => "DISABLED",
        }
    }
}

/// Snapshot of everything the manager knows about one node.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    pub id: NodeId,
    pub addr: SocketAddr,
    pub state: NodeState,
    /// When the most recent heartbeat was recorded; `None` if none has been seen.
    pub last_heartbeat: Option<Instant>,
    /// NOTE(review): the element type was lost in the original text and has been
    /// restored as `IpAddr`; nothing in this file populates the list — confirm
    /// against callers.
    pub public_ips: Vec<IpAddr>,
    /// Bumped by [`NodeManager`] on every state transition and on every explicit
    /// [`NodeManager::set_node_state`] call (even if the state is unchanged).
    pub generation: u64,
}

impl NodeInfo {
    /// Creates a `Down` node with no heartbeat, no public IPs and generation 0.
    pub fn new(id: NodeId, addr: SocketAddr) -> Self {
        Self {
            id,
            addr,
            state: NodeState::Down,
            last_heartbeat: None,
            public_ips: Vec::new(),
            generation: 0,
        }
    }

    /// Returns `true` if a heartbeat was recorded strictly less than `timeout` ago.
    ///
    /// A node that has never sent a heartbeat is never alive, and with a zero
    /// `timeout` no node is alive.
    pub fn is_alive(&self, timeout: Duration) -> bool {
        match self.last_heartbeat {
            Some(t) => t.elapsed() < timeout,
            None => false,
        }
    }
}

/// Fixed-size bitmap of active nodes, indexed by [`NodeId`].
#[derive(Debug, Clone)]
pub struct NodeMask {
    nodes: Vec<bool>,
}

impl NodeMask {
    /// Creates a mask covering ids `0..size`, all inactive.
    pub fn new(size: usize) -> Self {
        Self {
            nodes: vec![false; size],
        }
    }

    /// Marks `id` active or inactive. Ids outside the mask are silently ignored.
    pub fn set(&mut self, id: NodeId, active: bool) {
        if let Some(slot) = self.nodes.get_mut(id.0 as usize) {
            *slot = active;
        }
    }

    /// Returns whether `id` is active; ids outside the mask are reported inactive.
    pub fn is_active(&self, id: NodeId) -> bool {
        self.nodes.get(id.0 as usize).copied().unwrap_or(false)
    }

    /// Ids of all active slots, in ascending order.
    pub fn active_nodes(&self) -> Vec<NodeId> {
        self.nodes
            .iter()
            .enumerate()
            .filter(|&(_, &active)| active)
            // A mask larger than `u32::MAX + 1` slots would truncate here; masks
            // from `NodeManager::build_nodemask` are sized from `u32` ids.
            .map(|(i, _)| NodeId(i as u32))
            .collect()
    }

    /// Number of slots in the mask (active or not).
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if the mask has no slots at all.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Number of active slots.
    pub fn active_count(&self) -> usize {
        self.nodes.iter().filter(|&&a| a).count()
    }
}

/// Thread-safe table of cluster nodes with heartbeat-based liveness tracking.
///
/// The local node (`self_id`) is inserted as `Up` on construction and is never
/// demoted by [`NodeManager::check_health`].
pub struct NodeManager {
    nodes: RwLock<HashMap<NodeId, NodeInfo>>,
    self_id: NodeId,
    heartbeat_timeout: Duration,
    heartbeat_interval: Duration,
}

impl NodeManager {
    /// Heartbeat timeout used by [`NodeManager::new`].
    pub const DEFAULT_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
    /// Heartbeat interval used by [`NodeManager::new`].
    pub const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(1);

    /// Creates a manager with the default heartbeat timeout (5 s) and interval (1 s).
    pub fn new(self_id: NodeId, self_addr: SocketAddr) -> Self {
        Self::with_timeouts(
            self_id,
            self_addr,
            Self::DEFAULT_HEARTBEAT_TIMEOUT,
            Self::DEFAULT_HEARTBEAT_INTERVAL,
        )
    }

    /// Creates a manager with explicit heartbeat settings.
    ///
    /// The local node is registered as `Up` with a heartbeat stamped "now".
    pub fn with_timeouts(
        self_id: NodeId,
        self_addr: SocketAddr,
        heartbeat_timeout: Duration,
        heartbeat_interval: Duration,
    ) -> Self {
        let mut self_node = NodeInfo::new(self_id, self_addr);
        self_node.state = NodeState::Up;
        self_node.last_heartbeat = Some(Instant::now());

        let mut nodes = HashMap::new();
        nodes.insert(self_id, self_node);

        Self {
            nodes: RwLock::new(nodes),
            self_id,
            heartbeat_timeout,
            heartbeat_interval,
        }
    }

    /// How long after its last heartbeat a node is considered dead by `check_health`.
    pub fn heartbeat_timeout(&self) -> Duration {
        self.heartbeat_timeout
    }

    /// Configured heartbeat interval. The manager itself never reads this value;
    /// it is only stored and returned here.
    pub fn heartbeat_interval(&self) -> Duration {
        self.heartbeat_interval
    }

    /// Shared access to the node table.
    fn read_nodes(&self) -> RwLockReadGuard<'_, HashMap<NodeId, NodeInfo>> {
        self.nodes.read().expect("node table lock poisoned")
    }

    /// Exclusive access to the node table.
    fn write_nodes(&self) -> RwLockWriteGuard<'_, HashMap<NodeId, NodeInfo>> {
        self.nodes.write().expect("node table lock poisoned")
    }

    /// Inserts `id` as a fresh `Down` node (generation 0, no heartbeat).
    ///
    /// Any existing entry for `id` — including the local node — is replaced.
    pub fn add_node(&self, id: NodeId, addr: SocketAddr) {
        self.write_nodes().insert(id, NodeInfo::new(id, addr));
    }

    /// Removes `id` if present. Removing `self_id` is not prevented.
    pub fn remove_node(&self, id: NodeId) {
        self.write_nodes().remove(&id);
    }

    /// Stamps a heartbeat for `id`; unknown ids are ignored.
    ///
    /// A `Down` node is brought back `Up` (bumping its generation); nodes in any
    /// other state keep their state.
    pub fn record_heartbeat(&self, id: NodeId) {
        let mut nodes = self.write_nodes();
        if let Some(node) = nodes.get_mut(&id) {
            node.last_heartbeat = Some(Instant::now());
            if node.state == NodeState::Down {
                node.state = NodeState::Up;
                node.generation += 1;
            }
        }
    }

    /// Forces `id` into `state` and bumps its generation (even if unchanged).
    ///
    /// Setting `Up` on a node that never sent a heartbeat stamps one now so the
    /// next `check_health` does not immediately demote it. Unknown ids are ignored.
    pub fn set_node_state(&self, id: NodeId, state: NodeState) {
        let mut nodes = self.write_nodes();
        if let Some(node) = nodes.get_mut(&id) {
            node.state = state;
            if state == NodeState::Up && node.last_heartbeat.is_none() {
                node.last_heartbeat = Some(Instant::now());
            }
            node.generation += 1;
        }
    }

    /// Returns a copy of the entry for `id`, if any.
    pub fn get_node(&self, id: NodeId) -> Option<NodeInfo> {
        self.read_nodes().get(&id).cloned()
    }

    /// Returns copies of all entries, in unspecified order.
    pub fn all_nodes(&self) -> Vec<NodeInfo> {
        self.read_nodes().values().cloned().collect()
    }

    /// Returns copies of all active (`Up`) entries, in unspecified order.
    pub fn active_nodes(&self) -> Vec<NodeInfo> {
        self.read_nodes()
            .values()
            .filter(|n| n.state.is_active())
            .cloned()
            .collect()
    }

    /// Demotes every `Up` peer whose heartbeat is not alive within the timeout.
    ///
    /// "Alive" is exactly [`NodeInfo::is_alive`], so a heartbeat aged exactly
    /// `heartbeat_timeout` counts as expired. The local node is never demoted.
    /// Returns the transitions made, sorted by node id.
    pub fn check_health(&self) -> Vec<(NodeId, NodeState)> {
        let timeout = self.heartbeat_timeout;
        let mut transitions = Vec::new();
        {
            let mut nodes = self.write_nodes();
            for (&id, node) in nodes.iter_mut() {
                if id == self.self_id || node.state != NodeState::Up || node.is_alive(timeout) {
                    continue;
                }
                node.state = NodeState::Down;
                node.generation += 1;
                transitions.push((id, NodeState::Down));
            }
        }
        // HashMap iteration order is arbitrary; sort so callers see a stable order.
        transitions.sort_unstable_by_key(|&(id, _)| id);
        transitions
    }

    /// Id of the local node.
    pub fn self_id(&self) -> NodeId {
        self.self_id
    }

    /// Number of known nodes, including the local node.
    pub fn node_count(&self) -> usize {
        self.read_nodes().len()
    }

    /// Builds a mask sized to the largest known id + 1, with `Up` nodes marked active.
    ///
    /// An empty table (only possible after removing `self_id`) yields a 1-slot mask.
    pub fn build_nodemask(&self) -> NodeMask {
        let nodes = self.read_nodes();
        let max_id = nodes.keys().map(|k| k.0).max().unwrap_or(0) as usize;
        let mut mask = NodeMask::new(max_id + 1);
        for (id, node) in nodes.iter() {
            mask.set(*id, node.state.is_active());
        }
        mask
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};

    fn addr(port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))
    }

    #[test]
    fn test_node_creation() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        let self_node = mgr.get_node(NodeId(0)).unwrap();
        assert_eq!(self_node.state, NodeState::Up);
        assert!(self_node.last_heartbeat.is_some());
    }

    #[test]
    fn test_add_remove_node() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        assert_eq!(mgr.node_count(), 2);
        mgr.remove_node(NodeId(1));
        assert_eq!(mgr.node_count(), 1);
    }

    #[test]
    fn test_heartbeat_updates() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Down);
        mgr.record_heartbeat(NodeId(1));
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Up);
    }

    #[test]
    fn test_health_check_timeout() {
        let mgr = NodeManager::with_timeouts(
            NodeId(0),
            addr(4000),
            Duration::from_millis(50),
            Duration::from_millis(10),
        );
        mgr.add_node(NodeId(1), addr(4001));
        mgr.record_heartbeat(NodeId(1));
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Up);
        std::thread::sleep(Duration::from_millis(100));
        let transitions = mgr.check_health();
        assert!(transitions.contains(&(NodeId(1), NodeState::Down)));
    }

    #[test]
    fn test_zero_timeout_expires_peers_but_not_self() {
        let mgr =
            NodeManager::with_timeouts(NodeId(0), addr(4000), Duration::ZERO, Duration::ZERO);
        mgr.add_node(NodeId(2), addr(4002));
        mgr.add_node(NodeId(1), addr(4001));
        mgr.record_heartbeat(NodeId(2));
        mgr.record_heartbeat(NodeId(1));
        assert_eq!(
            mgr.check_health(),
            vec![(NodeId(1), NodeState::Down), (NodeId(2), NodeState::Down)]
        );
        assert_eq!(mgr.get_node(NodeId(0)).unwrap().state, NodeState::Up);
    }

    #[test]
    fn test_nodemask() {
        let mut mask = NodeMask::new(5);
        mask.set(NodeId(0), true);
        mask.set(NodeId(2), true);
        mask.set(NodeId(4), true);
        assert!(mask.is_active(NodeId(0)));
        assert!(!mask.is_active(NodeId(1)));
        assert_eq!(mask.active_count(), 3);
        assert_eq!(mask.len(), 5);
        assert!(!mask.is_empty());
        assert!(NodeMask::new(0).is_empty());
    }

    #[test]
    fn test_build_nodemask() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        mgr.add_node(NodeId(2), addr(4002));
        mgr.record_heartbeat(NodeId(1));
        let mask = mgr.build_nodemask();
        assert!(mask.is_active(NodeId(0)));
        assert!(mask.is_active(NodeId(1)));
        assert!(!mask.is_active(NodeId(2)));
    }

    #[test]
    fn test_node_state_string() {
        assert_eq!(NodeState::Up.as_str(), "UP");
        assert_eq!(NodeState::Down.as_str(), "DOWN");
        assert_eq!(NodeState::Unhealthy.as_str(), "UNHEALTHY");
        assert_eq!(NodeState::Banned.as_str(), "BANNED");
        assert_eq!(NodeState::Disabled.as_str(), "DISABLED");
    }

    #[test]
    fn test_set_node_state() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        mgr.set_node_state(NodeId(1), NodeState::Banned);
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Banned);
    }

    #[test]
    fn test_generation_increment() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        let gen0 = mgr.get_node(NodeId(1)).unwrap().generation;
        mgr.set_node_state(NodeId(1), NodeState::Down);
        let gen1 = mgr.get_node(NodeId(1)).unwrap().generation;
        assert!(gen1 > gen0);
    }

    #[test]
    fn test_active_nodes() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        mgr.add_node(NodeId(2), addr(4002));
        mgr.record_heartbeat(NodeId(1));
        let active = mgr.active_nodes();
        assert_eq!(active.len(), 2);
    }

    #[test]
    fn test_is_alive() {
        let node = NodeInfo::new(NodeId(0), addr(4000));
        assert!(!node.is_alive(Duration::from_secs(5)));
        let mut node = node;
        node.last_heartbeat = Some(Instant::now());
        assert!(node.is_alive(Duration::from_secs(5)));
    }
}