353 lines
10 KiB
Rust
353 lines
10 KiB
Rust
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::sync::{Arc, RwLock};
|
|
use std::time::{Duration, Instant};
|
|
|
|
/// Numeric identifier of a cluster node.
///
/// The wrapped value doubles as the slot index inside a `NodeMask`, and the
/// type is `Hash + Eq` so it can key the node table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u32);
|
|
|
|
/// Membership state of a node in the cluster.
///
/// Only [`NodeState::Up`] counts as active (see [`NodeState::is_active`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    /// Node is reachable and counted as active.
    Up,
    /// Node is not considered reachable. Newly registered peers start here, and
    /// peers whose heartbeats time out are moved here.
    Down,
    // The remaining states are only ever set explicitly; heartbeat tracking
    // neither enters nor leaves them on its own.
    /// Explicitly flagged as unhealthy.
    Unhealthy,
    /// Explicitly banned.
    Banned,
    /// Explicitly disabled.
    Disabled,
}

impl NodeState {
    /// Returns `true` only for [`NodeState::Up`].
    pub fn is_active(&self) -> bool {
        matches!(self, NodeState::Up)
    }

    /// Returns the state's upper-case name, e.g. `"UP"` or `"BANNED"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            NodeState::Up => "UP",
            NodeState::Down => "DOWN",
            NodeState::Unhealthy => "UNHEALTHY",
            NodeState::Banned => "BANNED",
            NodeState::Disabled => "DISABLED",
        }
    }
}

/// Formats the state as its [`NodeState::as_str`] name.
///
/// Uses `Formatter::pad`, so width/alignment flags such as `{:>8}` are honoured.
impl std::fmt::Display for NodeState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct NodeInfo {
|
|
pub id: NodeId,
|
|
pub addr: SocketAddr,
|
|
pub state: NodeState,
|
|
pub last_heartbeat: Option<Instant>,
|
|
pub public_ips: Vec<String>,
|
|
pub generation: u64,
|
|
}
|
|
|
|
impl NodeInfo {
|
|
pub fn new(id: NodeId, addr: SocketAddr) -> Self {
|
|
Self {
|
|
id,
|
|
addr,
|
|
state: NodeState::Down,
|
|
last_heartbeat: None,
|
|
public_ips: Vec::new(),
|
|
generation: 0,
|
|
}
|
|
}
|
|
|
|
pub fn is_alive(&self, timeout: Duration) -> bool {
|
|
match self.last_heartbeat {
|
|
Some(t) => t.elapsed() < timeout,
|
|
None => false,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct NodeMask {
|
|
nodes: Vec<bool>,
|
|
}
|
|
|
|
impl NodeMask {
|
|
pub fn new(size: usize) -> Self {
|
|
Self {
|
|
nodes: vec![false; size],
|
|
}
|
|
}
|
|
|
|
pub fn set(&mut self, id: NodeId, active: bool) {
|
|
if (id.0 as usize) < self.nodes.len() {
|
|
self.nodes[id.0 as usize] = active;
|
|
}
|
|
}
|
|
|
|
pub fn is_active(&self, id: NodeId) -> bool {
|
|
self.nodes.get(id.0 as usize).copied().unwrap_or(false)
|
|
}
|
|
|
|
pub fn active_nodes(&self) -> Vec<NodeId> {
|
|
self.nodes
|
|
.iter()
|
|
.enumerate()
|
|
.filter(|(_, &active)| active)
|
|
.map(|(i, _)| NodeId(i as u32))
|
|
.collect()
|
|
}
|
|
|
|
pub fn len(&self) -> usize {
|
|
self.nodes.len()
|
|
}
|
|
|
|
pub fn active_count(&self) -> usize {
|
|
self.nodes.iter().filter(|&&a| a).count()
|
|
}
|
|
}
|
|
|
|
pub struct NodeManager {
|
|
nodes: RwLock<HashMap<NodeId, NodeInfo>>,
|
|
self_id: NodeId,
|
|
heartbeat_timeout: Duration,
|
|
heartbeat_interval: Duration,
|
|
}
|
|
|
|
impl NodeManager {
|
|
pub fn new(self_id: NodeId, self_addr: SocketAddr) -> Self {
|
|
let mut nodes = HashMap::new();
|
|
nodes.insert(
|
|
self_id,
|
|
NodeInfo {
|
|
id: self_id,
|
|
addr: self_addr,
|
|
state: NodeState::Up,
|
|
last_heartbeat: Some(Instant::now()),
|
|
public_ips: Vec::new(),
|
|
generation: 0,
|
|
},
|
|
);
|
|
Self {
|
|
nodes: RwLock::new(nodes),
|
|
self_id,
|
|
heartbeat_timeout: Duration::from_secs(5),
|
|
heartbeat_interval: Duration::from_secs(1),
|
|
}
|
|
}
|
|
|
|
pub fn add_node(&self, id: NodeId, addr: SocketAddr) {
|
|
let mut nodes = self.nodes.write().unwrap();
|
|
nodes.insert(id, NodeInfo::new(id, addr));
|
|
}
|
|
|
|
pub fn remove_node(&self, id: NodeId) {
|
|
let mut nodes = self.nodes.write().unwrap();
|
|
nodes.remove(&id);
|
|
}
|
|
|
|
pub fn record_heartbeat(&self, id: NodeId) {
|
|
let mut nodes = self.nodes.write().unwrap();
|
|
if let Some(node) = nodes.get_mut(&id) {
|
|
node.last_heartbeat = Some(Instant::now());
|
|
if node.state == NodeState::Down {
|
|
node.state = NodeState::Up;
|
|
node.generation += 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn set_node_state(&self, id: NodeId, state: NodeState) {
|
|
let mut nodes = self.nodes.write().unwrap();
|
|
if let Some(node) = nodes.get_mut(&id) {
|
|
node.state = state;
|
|
if state == NodeState::Up && node.last_heartbeat.is_none() {
|
|
node.last_heartbeat = Some(Instant::now());
|
|
}
|
|
node.generation += 1;
|
|
}
|
|
}
|
|
|
|
pub fn get_node(&self, id: NodeId) -> Option<NodeInfo> {
|
|
self.nodes.read().unwrap().get(&id).cloned()
|
|
}
|
|
|
|
pub fn all_nodes(&self) -> Vec<NodeInfo> {
|
|
self.nodes.read().unwrap().values().cloned().collect()
|
|
}
|
|
|
|
pub fn active_nodes(&self) -> Vec<NodeInfo> {
|
|
self.nodes
|
|
.read()
|
|
.unwrap()
|
|
.values()
|
|
.filter(|n| n.state == NodeState::Up)
|
|
.cloned()
|
|
.collect()
|
|
}
|
|
|
|
pub fn check_health(&self) -> Vec<(NodeId, NodeState)> {
|
|
let timeout = self.heartbeat_timeout;
|
|
let mut nodes = self.nodes.write().unwrap();
|
|
let mut transitions = Vec::new();
|
|
|
|
for (id, node) in nodes.iter_mut() {
|
|
if *id == self.self_id {
|
|
continue;
|
|
}
|
|
match &node.last_heartbeat {
|
|
Some(t) => {
|
|
if t.elapsed() > timeout && node.state == NodeState::Up {
|
|
node.state = NodeState::Down;
|
|
node.generation += 1;
|
|
transitions.push((*id, NodeState::Down));
|
|
}
|
|
}
|
|
None => {
|
|
if node.state == NodeState::Up {
|
|
node.state = NodeState::Down;
|
|
node.generation += 1;
|
|
transitions.push((*id, NodeState::Down));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
transitions
|
|
}
|
|
|
|
pub fn self_id(&self) -> NodeId {
|
|
self.self_id
|
|
}
|
|
|
|
pub fn node_count(&self) -> usize {
|
|
self.nodes.read().unwrap().len()
|
|
}
|
|
|
|
pub fn build_nodemask(&self) -> NodeMask {
|
|
let nodes = self.nodes.read().unwrap();
|
|
let max_id = nodes.keys().map(|k| k.0).max().unwrap_or(0) as usize;
|
|
let mut mask = NodeMask::new(max_id + 1);
|
|
for (id, node) in nodes.iter() {
|
|
mask.set(*id, node.state.is_active());
|
|
}
|
|
mask
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Loopback address on `port`, used as a placeholder node address.
    fn addr(port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))
    }

    /// Copies `mgr`'s node table into a new manager that uses `timeout` as its
    /// heartbeat timeout. As a child module, the tests can reach the private fields.
    fn with_timeout(mgr: &NodeManager, timeout: Duration) -> NodeManager {
        NodeManager {
            nodes: RwLock::new(mgr.nodes.read().unwrap().clone()),
            self_id: mgr.self_id,
            heartbeat_timeout: timeout,
            heartbeat_interval: Duration::from_millis(10),
        }
    }

    #[test]
    fn test_node_creation() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        let self_node = mgr.get_node(NodeId(0)).unwrap();
        assert_eq!(self_node.state, NodeState::Up);
        assert!(self_node.last_heartbeat.is_some());
    }

    #[test]
    fn test_add_remove_node() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        assert_eq!(mgr.node_count(), 2);
        mgr.remove_node(NodeId(1));
        assert_eq!(mgr.node_count(), 1);
    }

    #[test]
    fn test_heartbeat_updates() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Down);
        mgr.record_heartbeat(NodeId(1));
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Up);
    }

    #[test]
    fn test_heartbeat_bumps_generation_once() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        mgr.record_heartbeat(NodeId(1));
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().generation, 1);
        // A heartbeat from an already-Up node refreshes it without a transition.
        mgr.record_heartbeat(NodeId(1));
        let node = mgr.get_node(NodeId(1)).unwrap();
        assert_eq!(node.state, NodeState::Up);
        assert_eq!(node.generation, 1);
    }

    #[test]
    fn test_record_heartbeat_unknown_node_is_noop() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.record_heartbeat(NodeId(9));
        assert_eq!(mgr.node_count(), 1);
        assert!(mgr.get_node(NodeId(9)).is_none());
    }

    #[test]
    fn test_health_check_timeout() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        mgr.record_heartbeat(NodeId(1));
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Up);
        std::thread::sleep(Duration::from_millis(100));
        let mgr = with_timeout(&mgr, Duration::from_millis(50));
        let transitions = mgr.check_health();
        assert!(transitions.contains(&(NodeId(1), NodeState::Down)));
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Down);
    }

    #[test]
    fn test_check_health_skips_self_and_down_nodes() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        // A zero timeout would expire every heartbeat that actually got checked.
        let mgr = with_timeout(&mgr, Duration::from_millis(0));
        assert!(mgr.check_health().is_empty());
        assert_eq!(mgr.get_node(NodeId(0)).unwrap().state, NodeState::Up);
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Down);
    }

    #[test]
    fn test_nodemask() {
        let mut mask = NodeMask::new(5);
        mask.set(NodeId(0), true);
        mask.set(NodeId(2), true);
        mask.set(NodeId(4), true);
        assert!(mask.is_active(NodeId(0)));
        assert!(!mask.is_active(NodeId(1)));
        assert_eq!(mask.active_count(), 3);
        assert_eq!(mask.len(), 5);
        assert_eq!(mask.active_nodes(), vec![NodeId(0), NodeId(2), NodeId(4)]);
    }

    #[test]
    fn test_nodemask_ignores_out_of_range() {
        let mut mask = NodeMask::new(2);
        mask.set(NodeId(5), true);
        assert!(!mask.is_active(NodeId(5)));
        assert_eq!(mask.active_count(), 0);
        assert_eq!(mask.len(), 2);
    }

    #[test]
    fn test_build_nodemask() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        mgr.add_node(NodeId(2), addr(4002));
        mgr.record_heartbeat(NodeId(1));
        let mask = mgr.build_nodemask();
        assert!(mask.is_active(NodeId(0)));
        assert!(mask.is_active(NodeId(1)));
        assert!(!mask.is_active(NodeId(2)));
    }

    #[test]
    fn test_node_state_string() {
        assert_eq!(NodeState::Up.as_str(), "UP");
        assert_eq!(NodeState::Down.as_str(), "DOWN");
        assert_eq!(NodeState::Unhealthy.as_str(), "UNHEALTHY");
        assert_eq!(NodeState::Banned.as_str(), "BANNED");
        assert_eq!(NodeState::Disabled.as_str(), "DISABLED");
    }

    #[test]
    fn test_set_node_state() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        mgr.set_node_state(NodeId(1), NodeState::Banned);
        assert_eq!(mgr.get_node(NodeId(1)).unwrap().state, NodeState::Banned);
    }

    #[test]
    fn test_generation_increment() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        let gen0 = mgr.get_node(NodeId(1)).unwrap().generation;
        mgr.set_node_state(NodeId(1), NodeState::Down);
        let gen1 = mgr.get_node(NodeId(1)).unwrap().generation;
        assert!(gen1 > gen0);
    }

    #[test]
    fn test_active_nodes() {
        let mgr = NodeManager::new(NodeId(0), addr(4000));
        mgr.add_node(NodeId(1), addr(4001));
        mgr.add_node(NodeId(2), addr(4002));
        mgr.record_heartbeat(NodeId(1));
        let active = mgr.active_nodes();
        assert_eq!(active.len(), 2);
    }

    #[test]
    fn test_is_alive() {
        let mut node = NodeInfo::new(NodeId(0), addr(4000));
        assert!(!node.is_alive(Duration::from_secs(5)));
        node.last_heartbeat = Some(Instant::now());
        assert!(node.is_alive(Duration::from_secs(5)));
    }
}