Files
markbase/markbase-core/src/ctdb/node.rs

353 lines
10 KiB
Rust

use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Identifier of a node, wrapping the raw `u32` id.
///
/// Serves as the key of the node table held by `NodeManager` and as the slot
/// index inside a `NodeMask`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u32);
/// State a node can be in, as tracked by `NodeManager`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    /// The only state that [`NodeState::is_active`] reports as active.
    Up,
    /// Inactive; the state a freshly created `NodeInfo` starts in and the one
    /// `NodeManager::check_health` assigns when heartbeats stop.
    Down,
    /// Inactive; nothing in this file assigns it automatically.
    Unhealthy,
    /// Inactive; nothing in this file assigns it automatically.
    Banned,
    /// Inactive; nothing in this file assigns it automatically.
    Disabled,
}

impl NodeState {
    /// Returns `true` for [`NodeState::Up`] and `false` for every other state.
    pub fn is_active(&self) -> bool {
        *self == Self::Up
    }

    /// Returns the upper-case label of the state (`"UP"`, `"DOWN"`, ...).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Up => "UP",
            Self::Down => "DOWN",
            Self::Unhealthy => "UNHEALTHY",
            Self::Banned => "BANNED",
            Self::Disabled => "DISABLED",
        }
    }
}
/// Record describing a single node: identity, address, state and heartbeat bookkeeping.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Identifier of the node.
    pub id: NodeId,
    /// Socket address associated with the node.
    pub addr: SocketAddr,
    /// Current state of the node.
    pub state: NodeState,
    /// Instant of the most recent recorded heartbeat, or `None` if there has been none.
    pub last_heartbeat: Option<Instant>,
    /// Public IPs attached to the node; starts out empty.
    // NOTE(review): nothing in this file populates this list - confirm who owns it.
    pub public_ips: Vec<String>,
    /// Counter starting at 0 that `NodeManager` increments when it updates the node's state.
    pub generation: u64,
}

impl NodeInfo {
    /// Creates a record for `id` at `addr` that is `Down`, has no heartbeat,
    /// no public IPs and generation 0.
    pub fn new(id: NodeId, addr: SocketAddr) -> Self {
        let state = NodeState::Down;
        let last_heartbeat = None;
        Self {
            id,
            addr,
            state,
            last_heartbeat,
            public_ips: Vec::new(),
            generation: 0,
        }
    }

    /// Returns `true` when a heartbeat has been recorded and it is strictly
    /// younger than `timeout`; a node without any heartbeat is never alive.
    pub fn is_alive(&self, timeout: Duration) -> bool {
        matches!(self.last_heartbeat, Some(seen) if seen.elapsed() < timeout)
    }
}
/// Fixed-size table of per-node activity flags, indexed by the numeric value of a `NodeId`.
#[derive(Debug, Clone)]
pub struct NodeMask {
    /// `nodes[i]` is `true` when `NodeId(i)` is marked active.
    nodes: Vec<bool>,
}

impl NodeMask {
    /// Creates a mask with `size` slots, all inactive.
    pub fn new(size: usize) -> Self {
        Self {
            nodes: vec![false; size],
        }
    }

    /// Marks `id` as active or inactive.
    ///
    /// Ids that fall outside the mask are ignored; the mask never grows.
    pub fn set(&mut self, id: NodeId, active: bool) {
        if let Some(slot) = self.nodes.get_mut(id.0 as usize) {
            *slot = active;
        }
    }

    /// Returns whether `id` is marked active; ids outside the mask report `false`.
    pub fn is_active(&self, id: NodeId) -> bool {
        self.nodes.get(id.0 as usize).copied().unwrap_or(false)
    }

    /// Returns the ids of all active slots in ascending order.
    pub fn active_nodes(&self) -> Vec<NodeId> {
        self.nodes
            .iter()
            .enumerate()
            .filter(|(_, &active)| active)
            // A slot can only become active through `set`, whose index comes
            // from a `u32`, so this cast cannot truncate an active index.
            .map(|(index, _)| NodeId(index as u32))
            .collect()
    }

    /// Returns the number of slots in the mask (active or not).
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` when the mask has no slots at all.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns how many slots are currently marked active.
    pub fn active_count(&self) -> usize {
        self.nodes.iter().filter(|&&active| active).count()
    }
}
/// Registry of the nodes known to the local node, with heartbeat-based health tracking.
///
/// The node table is guarded by an [`RwLock`], so every method takes `&self`.
///
/// # Panics
///
/// Any method that touches the node table panics if the lock has been
/// poisoned by a panic in another lock holder.
#[derive(Debug)]
pub struct NodeManager {
    /// All known nodes, keyed by id; includes the local node after construction.
    nodes: RwLock<HashMap<NodeId, NodeInfo>>,
    /// Id of the local node; `check_health` never demotes this entry.
    self_id: NodeId,
    /// Heartbeat age at which `check_health` demotes an `Up` node to `Down`.
    heartbeat_timeout: Duration,
    /// Configured heartbeat period; only stored and exposed through its accessor here.
    heartbeat_interval: Duration,
}

impl NodeManager {
    /// Heartbeat timeout used by [`NodeManager::new`].
    pub const DEFAULT_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
    /// Heartbeat interval used by [`NodeManager::new`].
    pub const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(1);

    /// Creates a manager with the default timing (5 s timeout, 1 s interval).
    ///
    /// The local node is registered as `Up` with a heartbeat stamped at construction time.
    pub fn new(self_id: NodeId, self_addr: SocketAddr) -> Self {
        Self::with_heartbeat_config(
            self_id,
            self_addr,
            Self::DEFAULT_HEARTBEAT_TIMEOUT,
            Self::DEFAULT_HEARTBEAT_INTERVAL,
        )
    }

    /// Creates a manager like [`NodeManager::new`] but with caller-chosen heartbeat timing.
    pub fn with_heartbeat_config(
        self_id: NodeId,
        self_addr: SocketAddr,
        heartbeat_timeout: Duration,
        heartbeat_interval: Duration,
    ) -> Self {
        // Same record as `NodeInfo::new`, except the local node starts out
        // `Up` with a fresh heartbeat.
        let local = NodeInfo {
            state: NodeState::Up,
            last_heartbeat: Some(Instant::now()),
            ..NodeInfo::new(self_id, self_addr)
        };
        let mut nodes = HashMap::new();
        nodes.insert(self_id, local);
        Self {
            nodes: RwLock::new(nodes),
            self_id,
            heartbeat_timeout,
            heartbeat_interval,
        }
    }

    /// Acquires the node table for reading, panicking if the lock is poisoned.
    fn read_nodes(&self) -> std::sync::RwLockReadGuard<'_, HashMap<NodeId, NodeInfo>> {
        self.nodes.read().expect("node table lock poisoned")
    }

    /// Acquires the node table for writing, panicking if the lock is poisoned.
    fn write_nodes(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<NodeId, NodeInfo>> {
        self.nodes.write().expect("node table lock poisoned")
    }

    /// Registers `id` at `addr` as a fresh `Down` node with no heartbeat.
    ///
    /// An existing entry with the same id (including the local node's) is
    /// replaced, which resets its state, heartbeat and generation.
    pub fn add_node(&self, id: NodeId, addr: SocketAddr) {
        self.write_nodes().insert(id, NodeInfo::new(id, addr));
    }

    /// Removes `id` from the table; unknown ids are ignored.
    pub fn remove_node(&self, id: NodeId) {
        self.write_nodes().remove(&id);
    }

    /// Stamps a heartbeat for `id`; unknown ids are ignored.
    ///
    /// A `Down` node is promoted to `Up` and its generation is bumped. Nodes
    /// in any other state keep their state and generation.
    pub fn record_heartbeat(&self, id: NodeId) {
        let mut nodes = self.write_nodes();
        if let Some(node) = nodes.get_mut(&id) {
            node.last_heartbeat = Some(Instant::now());
            if node.state == NodeState::Down {
                node.state = NodeState::Up;
                node.generation += 1;
            }
        }
    }

    /// Forces `id` into `state`; unknown ids are ignored.
    ///
    /// The generation is bumped on every call that finds the node, even when
    /// the state does not change. When the new state is `Up` and the node has
    /// never had a heartbeat, one is stamped now.
    // NOTE(review): an existing (possibly stale) heartbeat is left untouched,
    // so a node forced `Up` with an old heartbeat is demoted again by the next
    // `check_health` - confirm this is intended.
    pub fn set_node_state(&self, id: NodeId, state: NodeState) {
        let mut nodes = self.write_nodes();
        if let Some(node) = nodes.get_mut(&id) {
            node.state = state;
            if state == NodeState::Up && node.last_heartbeat.is_none() {
                node.last_heartbeat = Some(Instant::now());
            }
            node.generation += 1;
        }
    }

    /// Returns a copy of the record for `id`, or `None` if it is unknown.
    pub fn get_node(&self, id: NodeId) -> Option<NodeInfo> {
        self.read_nodes().get(&id).cloned()
    }

    /// Returns copies of all node records, in unspecified (hash map) order.
    pub fn all_nodes(&self) -> Vec<NodeInfo> {
        self.read_nodes().values().cloned().collect()
    }

    /// Returns copies of the records whose state is `Up`, in unspecified order.
    pub fn active_nodes(&self) -> Vec<NodeInfo> {
        self.read_nodes()
            .values()
            .filter(|node| node.state.is_active())
            .cloned()
            .collect()
    }

    /// Demotes every `Up` node (other than the local one) that is no longer
    /// alive according to [`NodeInfo::is_alive`] and the configured timeout.
    ///
    /// Each demoted node becomes `Down` with its generation bumped. Returns
    /// the `(id, NodeState::Down)` transitions that were applied.
    pub fn check_health(&self) -> Vec<(NodeId, NodeState)> {
        let timeout = self.heartbeat_timeout;
        let self_id = self.self_id;
        let mut nodes = self.write_nodes();
        let mut transitions = Vec::new();
        for (id, node) in nodes.iter_mut() {
            // Only remote nodes that are currently `Up` can time out. A
            // missing heartbeat counts as expired, as `is_alive` defines it.
            if *id == self_id || node.state != NodeState::Up || node.is_alive(timeout) {
                continue;
            }
            node.state = NodeState::Down;
            node.generation += 1;
            transitions.push((*id, NodeState::Down));
        }
        transitions
    }

    /// Returns the id of the local node.
    pub fn self_id(&self) -> NodeId {
        self.self_id
    }

    /// Returns the heartbeat timeout used by [`NodeManager::check_health`].
    pub fn heartbeat_timeout(&self) -> Duration {
        self.heartbeat_timeout
    }

    /// Returns the configured heartbeat interval.
    pub fn heartbeat_interval(&self) -> Duration {
        self.heartbeat_interval
    }

    /// Returns the number of nodes in the table, including the local one.
    pub fn node_count(&self) -> usize {
        self.read_nodes().len()
    }

    /// Builds a mask with one slot per id up to the largest known id, marking
    /// exactly the nodes whose state is active.
    ///
    /// An empty node table yields an empty mask.
    pub fn build_nodemask(&self) -> NodeMask {
        let nodes = self.read_nodes();
        let size = nodes
            .keys()
            .map(|id| id.0)
            .max()
            .map_or(0, |max_id| max_id as usize + 1);
        let mut mask = NodeMask::new(size);
        for (id, node) in nodes.iter() {
            mask.set(*id, node.state.is_active());
        }
        mask
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Loopback socket address on `port`.
    fn addr(port: u16) -> SocketAddr {
        SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into()
    }

    /// Current state of node `raw_id`; panics if the node is unknown.
    fn state_of(manager: &NodeManager, raw_id: u32) -> NodeState {
        manager.get_node(NodeId(raw_id)).unwrap().state
    }

    /// Current generation of node `raw_id`; panics if the node is unknown.
    fn generation_of(manager: &NodeManager, raw_id: u32) -> u64 {
        manager.get_node(NodeId(raw_id)).unwrap().generation
    }

    /// The local node is registered `Up` with a heartbeat already stamped.
    #[test]
    fn test_node_creation() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        let local = manager.get_node(NodeId(0)).unwrap();
        assert_eq!(local.state, NodeState::Up);
        assert!(local.last_heartbeat.is_some());
    }

    /// Adding and removing a peer changes the node count accordingly.
    #[test]
    fn test_add_remove_node() {
        let manager = NodeManager::new(NodeId(0), addr(4000));

        manager.add_node(NodeId(1), addr(4001));
        assert_eq!(manager.node_count(), 2);

        manager.remove_node(NodeId(1));
        assert_eq!(manager.node_count(), 1);
    }

    /// A freshly added peer is `Down` until its first heartbeat arrives.
    #[test]
    fn test_heartbeat_updates() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));
        assert_eq!(state_of(&manager, 1), NodeState::Down);

        manager.record_heartbeat(NodeId(1));
        assert_eq!(state_of(&manager, 1), NodeState::Up);
    }

    /// A peer whose heartbeat is older than the timeout is reported as going `Down`.
    #[test]
    fn test_health_check_timeout() {
        let original = NodeManager::new(NodeId(0), addr(4000));
        original.add_node(NodeId(1), addr(4001));
        original.record_heartbeat(NodeId(1));
        assert_eq!(state_of(&original, 1), NodeState::Up);

        // Let the heartbeat age past the short timeout configured below.
        std::thread::sleep(Duration::from_millis(100));

        // Rebuild the manager around the same node table with tighter timing.
        let snapshot = original.nodes.read().unwrap().clone();
        let strict = NodeManager {
            nodes: RwLock::new(snapshot),
            self_id: NodeId(0),
            heartbeat_timeout: Duration::from_millis(50),
            heartbeat_interval: Duration::from_millis(10),
        };

        let transitions = strict.check_health();
        assert!(transitions.contains(&(NodeId(1), NodeState::Down)));
    }

    /// Setting individual slots is reflected by the query methods.
    #[test]
    fn test_nodemask() {
        let mut mask = NodeMask::new(5);
        for &raw_id in &[0u32, 2, 4] {
            mask.set(NodeId(raw_id), true);
        }

        assert!(mask.is_active(NodeId(0)));
        assert!(!mask.is_active(NodeId(1)));
        assert_eq!(mask.active_count(), 3);
        assert_eq!(mask.len(), 5);
    }

    /// The mask built by the manager marks exactly the `Up` nodes.
    #[test]
    fn test_build_nodemask() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));
        manager.add_node(NodeId(2), addr(4002));
        manager.record_heartbeat(NodeId(1));

        let mask = manager.build_nodemask();
        assert!(mask.is_active(NodeId(0)));
        assert!(mask.is_active(NodeId(1)));
        assert!(!mask.is_active(NodeId(2)));
    }

    /// Every state maps to its upper-case label.
    #[test]
    fn test_node_state_string() {
        let expected = [
            (NodeState::Up, "UP"),
            (NodeState::Down, "DOWN"),
            (NodeState::Unhealthy, "UNHEALTHY"),
            (NodeState::Banned, "BANNED"),
            (NodeState::Disabled, "DISABLED"),
        ];
        for (state, label) in expected.iter() {
            assert_eq!(state.as_str(), *label);
        }
    }

    /// An explicit state change is visible on the stored record.
    #[test]
    fn test_set_node_state() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));

        manager.set_node_state(NodeId(1), NodeState::Banned);
        assert_eq!(state_of(&manager, 1), NodeState::Banned);
    }

    /// `set_node_state` bumps the generation.
    #[test]
    fn test_generation_increment() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));

        let before = generation_of(&manager, 1);
        manager.set_node_state(NodeId(1), NodeState::Down);
        let after = generation_of(&manager, 1);
        assert!(after > before);
    }

    /// Only the local node and the peer that sent a heartbeat count as active.
    #[test]
    fn test_active_nodes() {
        let manager = NodeManager::new(NodeId(0), addr(4000));
        manager.add_node(NodeId(1), addr(4001));
        manager.add_node(NodeId(2), addr(4002));
        manager.record_heartbeat(NodeId(1));

        assert_eq!(manager.active_nodes().len(), 2);
    }

    /// A node is alive only once it has a sufficiently recent heartbeat.
    #[test]
    fn test_is_alive() {
        let mut node = NodeInfo::new(NodeId(0), addr(4000));
        assert!(!node.is_alive(Duration::from_secs(5)));

        node.last_heartbeat = Some(Instant::now());
        assert!(node.is_alive(Duration::from_secs(5)));
    }
}