use std::convert::TryFrom;
use std::io::{self, Read, Write};
use std::net::TcpStream;

// Framing for a CTDB-style cluster protocol. Every frame is a fixed 24-byte
// little-endian header (magic, version, command, status as `u32`, then the
// payload length as `u64`) followed by exactly `length` payload bytes.

/// Magic number opening every header. Its hex digits spell `"CTDB"` in ASCII;
/// like every other header field it is written little-endian on the wire.
pub const CTDB_MAGIC: u32 = 0x43544442;

/// Protocol version emitted by [`CtdbHeader::new`] and required by [`CtdbHeader::is_valid`].
pub const CTDB_VERSION: u32 = 1;

/// Size in bytes of an encoded [`CtdbHeader`]: four `u32` fields plus one `u64`.
pub const CTDB_HEADER_SIZE: usize = 24;

/// Default cap on the payload length accepted by [`CtdbConnection::recv_message`].
///
/// The length field is supplied by the peer, so without a bound a single
/// corrupt or hostile header could make the receiver try to allocate up to
/// `u64::MAX` bytes. Adjust per connection with
/// [`CtdbConnection::set_max_payload_size`].
pub const CTDB_MAX_PAYLOAD_SIZE: u64 = 64 * 1024 * 1024;

/// Command codes carried in [`CtdbHeader::command`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum CtdbCommand {
    Connect = 1,
    Disconnect = 2,
    Ping = 3,
    Pong = 4,
    GetDb = 10,
    Fetch = 11,
    Store = 12,
    Delete = 13,
    Keys = 14,
    SetNodeMask = 20,
    GetNodeMask = 21,
    NodeStatus = 22,
    TakeIp = 30,
    ReleaseIp = 31,
    Monitor = 40,
    Recovery = 50,
    RecoveryDone = 51,
    /// Returned by [`CtdbCommand::from_u32`] for any code not listed above.
    Unknown = 0xFFFF,
}

impl CtdbCommand {
    /// Decodes a raw command code; unrecognised values map to [`CtdbCommand::Unknown`].
    pub fn from_u32(v: u32) -> Self {
        match v {
            1 => CtdbCommand::Connect,
            2 => CtdbCommand::Disconnect,
            3 => CtdbCommand::Ping,
            4 => CtdbCommand::Pong,
            10 => CtdbCommand::GetDb,
            11 => CtdbCommand::Fetch,
            12 => CtdbCommand::Store,
            13 => CtdbCommand::Delete,
            14 => CtdbCommand::Keys,
            20 => CtdbCommand::SetNodeMask,
            21 => CtdbCommand::GetNodeMask,
            22 => CtdbCommand::NodeStatus,
            30 => CtdbCommand::TakeIp,
            31 => CtdbCommand::ReleaseIp,
            40 => CtdbCommand::Monitor,
            50 => CtdbCommand::Recovery,
            51 => CtdbCommand::RecoveryDone,
            _ => CtdbCommand::Unknown,
        }
    }
}

/// Status codes carried in [`CtdbHeader::status`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum CtdbStatus {
    Success = 0,
    Error = 1,
    NotFound = 2,
    Exists = 3,
    Corrupt = 4,
    Timeout = 5,
    NotActive = 6,
}

impl CtdbStatus {
    /// Decodes a raw status code; `1` and every unrecognised value map to
    /// [`CtdbStatus::Error`].
    pub fn from_u32(v: u32) -> Self {
        match v {
            0 => CtdbStatus::Success,
            2 => CtdbStatus::NotFound,
            3 => CtdbStatus::Exists,
            4 => CtdbStatus::Corrupt,
            5 => CtdbStatus::Timeout,
            6 => CtdbStatus::NotActive,
            // Covers the explicit `Error` code (1) and any unknown code.
            _ => CtdbStatus::Error,
        }
    }
}

/// Reads a little-endian `u32` starting at `off`.
///
/// # Panics
/// Panics if `buf` is shorter than `off + 4`; callers check the length first.
fn le_u32(buf: &[u8], off: usize) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&buf[off..off + 4]);
    u32::from_le_bytes(raw)
}

/// Reads a little-endian `u64` starting at `off`.
///
/// # Panics
/// Panics if `buf` is shorter than `off + 8`; callers check the length first.
fn le_u64(buf: &[u8], off: usize) -> u64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&buf[off..off + 8]);
    u64::from_le_bytes(raw)
}

/// Fixed-size frame header.
///
/// Fields hold raw wire values, so a header parsed from untrusted bytes may
/// carry unknown codes or a wrong magic/version; check
/// [`CtdbHeader::is_valid`] before trusting it.
#[derive(Debug, Clone)]
pub struct CtdbHeader {
    pub magic: u32,
    pub version: u32,
    pub command: u32,
    pub status: u32,
    /// Number of payload bytes that follow the header.
    pub length: u64,
}

impl CtdbHeader {
    /// Builds a header stamped with [`CTDB_MAGIC`] and [`CTDB_VERSION`].
    pub fn new(command: CtdbCommand, status: CtdbStatus, length: u64) -> Self {
        Self {
            magic: CTDB_MAGIC,
            version: CTDB_VERSION,
            command: command as u32,
            status: status as u32,
            length,
        }
    }

    /// Encodes the header as exactly [`CTDB_HEADER_SIZE`] little-endian bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(CTDB_HEADER_SIZE);
        buf.extend_from_slice(&self.magic.to_le_bytes());
        buf.extend_from_slice(&self.version.to_le_bytes());
        buf.extend_from_slice(&self.command.to_le_bytes());
        buf.extend_from_slice(&self.status.to_le_bytes());
        buf.extend_from_slice(&self.length.to_le_bytes());
        buf
    }

    /// Decodes the first [`CTDB_HEADER_SIZE`] bytes of `buf`; extra bytes are ignored.
    ///
    /// Magic and version are NOT checked here; see [`CtdbHeader::is_valid`].
    ///
    /// # Errors
    /// [`CtdbProtoError::HeaderTooShort`] if `buf` holds fewer than
    /// [`CTDB_HEADER_SIZE`] bytes.
    pub fn from_bytes(buf: &[u8]) -> Result<Self, CtdbProtoError> {
        if buf.len() < CTDB_HEADER_SIZE {
            return Err(CtdbProtoError::HeaderTooShort);
        }
        Ok(Self {
            magic: le_u32(buf, 0),
            version: le_u32(buf, 4),
            command: le_u32(buf, 8),
            status: le_u32(buf, 12),
            length: le_u64(buf, 16),
        })
    }

    /// Returns `true` when both the magic number and the version match this implementation.
    pub fn is_valid(&self) -> bool {
        self.magic == CTDB_MAGIC && self.version == CTDB_VERSION
    }
}

/// A complete frame: header plus payload.
#[derive(Debug, Clone)]
pub struct CtdbMessage {
    pub header: CtdbHeader,
    pub payload: Vec<u8>,
}

impl CtdbMessage {
    /// Builds a message whose header length is taken from `payload`.
    pub fn new(command: CtdbCommand, status: CtdbStatus, payload: Vec<u8>) -> Self {
        let length = payload.len() as u64;
        Self {
            header: CtdbHeader::new(command, status, length),
            payload,
        }
    }

    /// Encodes the header followed by the payload.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = self.header.to_bytes();
        buf.extend_from_slice(&self.payload);
        buf
    }

    /// Decodes one frame from the start of `buf`.
    ///
    /// Exactly `header.length` payload bytes are taken; anything after them
    /// (for example the next frame in a stream buffer) is ignored.
    ///
    /// # Errors
    /// - [`CtdbProtoError::HeaderTooShort`] if `buf` cannot hold a header.
    /// - [`CtdbProtoError::InvalidMagic`] if the magic or version is wrong.
    /// - [`CtdbProtoError::InvalidPayload`] if fewer than `header.length`
    ///   payload bytes are present.
    pub fn from_bytes(buf: &[u8]) -> Result<Self, CtdbProtoError> {
        let header = CtdbHeader::from_bytes(buf)?;
        if !header.is_valid() {
            return Err(CtdbProtoError::InvalidMagic);
        }
        let payload_len =
            usize::try_from(header.length).map_err(|_| CtdbProtoError::InvalidPayload)?;
        // Header parsing already guaranteed at least CTDB_HEADER_SIZE bytes.
        let body = &buf[CTDB_HEADER_SIZE..];
        let payload = body
            .get(..payload_len)
            .ok_or(CtdbProtoError::InvalidPayload)?
            .to_vec();
        Ok(Self { header, payload })
    }

    /// The header's command code, decoded.
    pub fn command(&self) -> CtdbCommand {
        CtdbCommand::from_u32(self.header.command)
    }

    /// The header's status code, decoded.
    pub fn status(&self) -> CtdbStatus {
        CtdbStatus::from_u32(self.header.status)
    }
}

/// Errors produced while framing, parsing or transporting messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CtdbProtoError {
    HeaderTooShort,
    InvalidMagic,
    /// Any underlying I/O failure; the original `io::Error` is not retained.
    IoError,
    InvalidPayload,
}

impl std::fmt::Display for CtdbProtoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CtdbProtoError::HeaderTooShort => write!(f, "header too short"),
            CtdbProtoError::InvalidMagic => write!(f, "invalid magic number"),
            CtdbProtoError::IoError => write!(f, "I/O error"),
            CtdbProtoError::InvalidPayload => write!(f, "invalid payload"),
        }
    }
}

impl std::error::Error for CtdbProtoError {}

impl From<io::Error> for CtdbProtoError {
    fn from(_: io::Error) -> Self {
        CtdbProtoError::IoError
    }
}

/// Encoders and decoders for the structures carried in [`CtdbMessage::payload`].
///
/// Variable-length fields are prefixed with their byte length as a
/// little-endian `u32`. Decoders never panic on malformed input; they return
/// [`CtdbProtoError::InvalidPayload`] instead, and ignore trailing bytes.
pub mod payload {
    use super::CtdbProtoError;
    use std::convert::TryFrom;

    /// Converts a field length to its `u32` wire prefix.
    ///
    /// # Panics
    /// Panics if `len` exceeds `u32::MAX`; such a field cannot be represented
    /// and silently truncating it would produce a corrupt frame.
    fn prefix_len(len: usize) -> u32 {
        u32::try_from(len).expect("field length exceeds u32::MAX and cannot be length-prefixed")
    }

    /// Appends `bytes` to `buf`, preceded by its `u32` length prefix.
    fn put_field(buf: &mut Vec<u8>, bytes: &[u8]) {
        buf.extend_from_slice(&prefix_len(bytes.len()).to_le_bytes());
        buf.extend_from_slice(bytes);
    }

    /// Cursor that reports short input as `InvalidPayload` instead of panicking.
    struct Reader<'a> {
        buf: &'a [u8],
        pos: usize,
    }

    impl<'a> Reader<'a> {
        fn new(buf: &'a [u8]) -> Self {
            Reader { buf, pos: 0 }
        }

        /// Consumes exactly `n` bytes.
        fn take(&mut self, n: usize) -> Result<&'a [u8], CtdbProtoError> {
            // `n` usually comes from a length prefix on the wire, so guard the addition.
            let end = self.pos.checked_add(n).ok_or(CtdbProtoError::InvalidPayload)?;
            let chunk = self
                .buf
                .get(self.pos..end)
                .ok_or(CtdbProtoError::InvalidPayload)?;
            self.pos = end;
            Ok(chunk)
        }

        /// Consumes a little-endian `u32`.
        fn read_u32(&mut self) -> Result<u32, CtdbProtoError> {
            let b = self.take(4)?;
            Ok(u32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        }

        /// Consumes a `u32` prefix and returns it as a `usize`.
        fn read_len(&mut self) -> Result<usize, CtdbProtoError> {
            let len = self.read_u32()?;
            usize::try_from(len).map_err(|_| CtdbProtoError::InvalidPayload)
        }

        /// Consumes a length-prefixed byte field.
        fn read_field(&mut self) -> Result<&'a [u8], CtdbProtoError> {
            let len = self.read_len()?;
            self.take(len)
        }
    }

    /// Encodes a key/value pair as two length-prefixed fields.
    ///
    /// # Panics
    /// Panics if either slice is longer than `u32::MAX` bytes.
    pub fn encode_kv(key: &[u8], value: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(8 + key.len() + value.len());
        put_field(&mut buf, key);
        put_field(&mut buf, value);
        buf
    }

    /// Decodes a payload produced by [`encode_kv`].
    ///
    /// # Errors
    /// [`CtdbProtoError::InvalidPayload`] if either field is truncated.
    pub fn decode_kv(payload: &[u8]) -> Result<(Vec<u8>, Vec<u8>), CtdbProtoError> {
        let mut reader = Reader::new(payload);
        let key = reader.read_field()?.to_vec();
        let value = reader.read_field()?.to_vec();
        Ok((key, value))
    }

    /// Encodes a single length-prefixed key.
    ///
    /// # Panics
    /// Panics if `key` is longer than `u32::MAX` bytes.
    pub fn encode_key(key: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(4 + key.len());
        put_field(&mut buf, key);
        buf
    }

    /// Decodes a payload produced by [`encode_key`].
    ///
    /// # Errors
    /// [`CtdbProtoError::InvalidPayload`] if the key is truncated.
    pub fn decode_key(payload: &[u8]) -> Result<Vec<u8>, CtdbProtoError> {
        Ok(Reader::new(payload).read_field()?.to_vec())
    }

    /// Encodes a node id as a little-endian `u32`.
    pub fn encode_node_id(id: u32) -> Vec<u8> {
        id.to_le_bytes().to_vec()
    }

    /// Decodes the leading little-endian `u32` node id.
    ///
    /// # Errors
    /// [`CtdbProtoError::InvalidPayload`] if fewer than 4 bytes are present.
    pub fn decode_node_id(payload: &[u8]) -> Result<u32, CtdbProtoError> {
        Reader::new(payload).read_u32()
    }

    /// Encodes a list of node ids as a `u32` count followed by each id.
    ///
    /// # Panics
    /// Panics if `active` has more than `u32::MAX` entries.
    pub fn encode_nodemask(active: &[u32]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(4 + 4 * active.len());
        buf.extend_from_slice(&prefix_len(active.len()).to_le_bytes());
        for id in active {
            buf.extend_from_slice(&id.to_le_bytes());
        }
        buf
    }

    /// Decodes a payload produced by [`encode_nodemask`].
    ///
    /// The payload length is validated against the declared count before
    /// anything is allocated, so a forged count cannot trigger a huge allocation.
    ///
    /// # Errors
    /// [`CtdbProtoError::InvalidPayload`] if fewer than `count` ids are present.
    pub fn decode_nodemask(payload: &[u8]) -> Result<Vec<u32>, CtdbProtoError> {
        let mut reader = Reader::new(payload);
        let count = reader.read_len()?;
        let byte_len = count.checked_mul(4).ok_or(CtdbProtoError::InvalidPayload)?;
        let ids = reader.take(byte_len)?;
        Ok(ids
            .chunks_exact(4)
            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect())
    }

    /// Encodes an IP address string and an interface name as two length-prefixed fields.
    ///
    /// # Panics
    /// Panics if either string is longer than `u32::MAX` bytes.
    pub fn encode_ip(ip: &str, interface: &str) -> Vec<u8> {
        let mut buf = Vec::with_capacity(8 + ip.len() + interface.len());
        put_field(&mut buf, ip.as_bytes());
        put_field(&mut buf, interface.as_bytes());
        buf
    }

    /// Decodes a payload produced by [`encode_ip`].
    ///
    /// Invalid UTF-8 is replaced with U+FFFD rather than rejected.
    ///
    /// # Errors
    /// [`CtdbProtoError::InvalidPayload`] if either field is truncated.
    pub fn decode_ip(payload: &[u8]) -> Result<(String, String), CtdbProtoError> {
        let mut reader = Reader::new(payload);
        let ip = String::from_utf8_lossy(reader.read_field()?).into_owned();
        let interface = String::from_utf8_lossy(reader.read_field()?).into_owned();
        Ok((ip, interface))
    }
}

/// Blocking request/response client over a single `TcpStream`.
pub struct CtdbConnection {
    stream: TcpStream,
    /// Largest `header.length` that `recv_message` will allocate for.
    max_payload: u64,
}

impl CtdbConnection {
    /// Wraps an already-connected stream, using [`CTDB_MAX_PAYLOAD_SIZE`] as the payload cap.
    pub fn new(stream: TcpStream) -> Self {
        Self {
            stream,
            max_payload: CTDB_MAX_PAYLOAD_SIZE,
        }
    }

    /// Opens a TCP connection to `addr`.
    ///
    /// # Errors
    /// [`CtdbProtoError::IoError`] if the connection cannot be established.
    pub fn connect(addr: &str) -> Result<Self, CtdbProtoError> {
        let stream = TcpStream::connect(addr)?;
        Ok(Self::new(stream))
    }

    /// Overrides the largest payload length [`CtdbConnection::recv_message`] will accept.
    pub fn set_max_payload_size(&mut self, limit: u64) {
        self.max_payload = limit;
    }

    /// Writes one encoded frame to the stream.
    ///
    /// # Errors
    /// [`CtdbProtoError::IoError`] on any write failure.
    pub fn send_message(&mut self, msg: &CtdbMessage) -> Result<(), CtdbProtoError> {
        self.stream.write_all(&msg.to_bytes())?;
        Ok(())
    }

    /// Reads one complete frame from the stream, blocking until it arrives.
    ///
    /// # Errors
    /// - [`CtdbProtoError::IoError`] on read failure or EOF.
    /// - [`CtdbProtoError::InvalidMagic`] if the magic or version is wrong.
    /// - [`CtdbProtoError::InvalidPayload`] if the declared length exceeds the
    ///   configured cap. The payload bytes are left unread in that case, so the
    ///   stream is no longer frame-aligned and the connection should be dropped.
    pub fn recv_message(&mut self) -> Result<CtdbMessage, CtdbProtoError> {
        let mut header_buf = [0u8; CTDB_HEADER_SIZE];
        self.stream.read_exact(&mut header_buf)?;
        let header = CtdbHeader::from_bytes(&header_buf)?;
        if !header.is_valid() {
            return Err(CtdbProtoError::InvalidMagic);
        }
        // Reject before allocating: the length comes from the peer.
        if header.length > self.max_payload {
            return Err(CtdbProtoError::InvalidPayload);
        }
        let payload_len =
            usize::try_from(header.length).map_err(|_| CtdbProtoError::InvalidPayload)?;
        let mut payload = vec![0u8; payload_len];
        self.stream.read_exact(&mut payload)?;
        Ok(CtdbMessage { header, payload })
    }

    /// Sends `Ping` and expects a successful `Pong`.
    ///
    /// # Errors
    /// [`CtdbProtoError::InvalidPayload`] if the reply is anything else, plus
    /// any error from sending or receiving.
    pub fn ping(&mut self) -> Result<(), CtdbProtoError> {
        let msg = CtdbMessage::new(CtdbCommand::Ping, CtdbStatus::Success, Vec::new());
        self.send_message(&msg)?;
        let resp = self.recv_message()?;
        if resp.command() == CtdbCommand::Pong && resp.status() == CtdbStatus::Success {
            Ok(())
        } else {
            Err(CtdbProtoError::InvalidPayload)
        }
    }

    /// Sends `Store` for `key`/`value`; returns whether the reply status is `Success`.
    ///
    /// The reply's command code is not checked.
    pub fn store(&mut self, key: &[u8], value: &[u8]) -> Result<bool, CtdbProtoError> {
        let msg = CtdbMessage::new(
            CtdbCommand::Store,
            CtdbStatus::Success,
            payload::encode_kv(key, value),
        );
        self.send_message(&msg)?;
        let resp = self.recv_message()?;
        Ok(resp.status() == CtdbStatus::Success)
    }

    /// Sends `Fetch` for `key` and returns the raw reply payload.
    ///
    /// # Errors
    /// Any non-`Success` reply status, including `NotFound`, is reported as
    /// [`CtdbProtoError::InvalidPayload`].
    pub fn fetch(&mut self, key: &[u8]) -> Result<Vec<u8>, CtdbProtoError> {
        let msg = CtdbMessage::new(
            CtdbCommand::Fetch,
            CtdbStatus::Success,
            payload::encode_key(key),
        );
        self.send_message(&msg)?;
        let resp = self.recv_message()?;
        if resp.status() == CtdbStatus::Success {
            Ok(resp.payload)
        } else {
            Err(CtdbProtoError::InvalidPayload)
        }
    }

    /// Sends `Delete` for `key`; returns whether the reply status is `Success`.
    ///
    /// The reply's command code is not checked.
    pub fn delete(&mut self, key: &[u8]) -> Result<bool, CtdbProtoError> {
        let msg = CtdbMessage::new(
            CtdbCommand::Delete,
            CtdbStatus::Success,
            payload::encode_key(key),
        );
        self.send_message(&msg)?;
        let resp = self.recv_message()?;
        Ok(resp.status() == CtdbStatus::Success)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_header_roundtrip() {
        let header = CtdbHeader::new(CtdbCommand::Ping, CtdbStatus::Success, 42);
        let bytes = header.to_bytes();
        let restored = CtdbHeader::from_bytes(&bytes).unwrap();
        assert_eq!(restored.magic, CTDB_MAGIC);
        assert_eq!(restored.version, CTDB_VERSION);
        assert_eq!(restored.command, CtdbCommand::Ping as u32);
        assert_eq!(restored.status, CtdbStatus::Success as u32);
        assert_eq!(restored.length, 42);
        assert!(restored.is_valid());
    }

    #[test]
    fn test_message_roundtrip() {
        let msg = CtdbMessage::new(
            CtdbCommand::Store,
            CtdbStatus::Success,
            b"test_payload".to_vec(),
        );
        let bytes = msg.to_bytes();
        let restored = CtdbMessage::from_bytes(&bytes).unwrap();
        assert_eq!(restored.command(), CtdbCommand::Store);
        assert_eq!(restored.payload, b"test_payload");
    }

    #[test]
    fn test_message_truncated_payload_rejected() {
        let bytes = CtdbMessage::new(CtdbCommand::Store, CtdbStatus::Success, vec![1, 2, 3])
            .to_bytes();
        let result = CtdbMessage::from_bytes(&bytes[..bytes.len() - 1]);
        assert_eq!(result.unwrap_err(), CtdbProtoError::InvalidPayload);
    }

    #[test]
    fn test_message_trailing_bytes_ignored() {
        let mut bytes =
            CtdbMessage::new(CtdbCommand::Store, CtdbStatus::Success, vec![9]).to_bytes();
        bytes.extend_from_slice(&[0xAA, 0xBB]);
        let restored = CtdbMessage::from_bytes(&bytes).unwrap();
        assert_eq!(restored.payload, vec![9]);
    }

    #[test]
    fn test_command_from_u32() {
        assert_eq!(CtdbCommand::from_u32(1), CtdbCommand::Connect);
        assert_eq!(CtdbCommand::from_u32(3), CtdbCommand::Ping);
        assert_eq!(CtdbCommand::from_u32(0xFFFF), CtdbCommand::Unknown);
        assert_eq!(CtdbCommand::from_u32(999), CtdbCommand::Unknown);
    }

    #[test]
    fn test_status_from_u32() {
        assert_eq!(CtdbStatus::from_u32(0), CtdbStatus::Success);
        assert_eq!(CtdbStatus::from_u32(2), CtdbStatus::NotFound);
        assert_eq!(CtdbStatus::from_u32(99), CtdbStatus::Error);
    }

    #[test]
    fn test_payload_encode_decode_kv() {
        let (key, val) = (b"mykey", b"myvalue");
        let encoded = payload::encode_kv(key, val);
        let (k, v) = payload::decode_kv(&encoded).unwrap();
        assert_eq!(k, key);
        assert_eq!(v, val);
    }

    #[test]
    fn test_payload_kv_truncated() {
        let encoded = payload::encode_kv(b"k", b"value");
        assert_eq!(
            payload::decode_kv(&encoded[..encoded.len() - 1]),
            Err(CtdbProtoError::InvalidPayload)
        );
    }

    #[test]
    fn test_payload_encode_decode_key() {
        let key = b"test_key";
        let encoded = payload::encode_key(key);
        let decoded = payload::decode_key(&encoded).unwrap();
        assert_eq!(decoded, key);
    }

    #[test]
    fn test_payload_node_id() {
        let encoded = payload::encode_node_id(42);
        let decoded = payload::decode_node_id(&encoded).unwrap();
        assert_eq!(decoded, 42);
    }

    #[test]
    fn test_payload_nodemask() {
        let ids = vec![0u32, 1, 2, 3];
        let encoded = payload::encode_nodemask(&ids);
        let decoded = payload::decode_nodemask(&encoded).unwrap();
        assert_eq!(decoded, ids);
    }

    #[test]
    fn test_payload_nodemask_forged_count() {
        let mut forged = u32::MAX.to_le_bytes().to_vec();
        forged.extend_from_slice(&1u32.to_le_bytes());
        assert_eq!(
            payload::decode_nodemask(&forged),
            Err(CtdbProtoError::InvalidPayload)
        );
    }

    #[test]
    fn test_payload_ip() {
        let encoded = payload::encode_ip("192.168.1.100", "eth0");
        let (ip, iface) = payload::decode_ip(&encoded).unwrap();
        assert_eq!(ip, "192.168.1.100");
        assert_eq!(iface, "eth0");
    }

    #[test]
    fn test_invalid_magic() {
        let mut bad_header = CtdbHeader::new(CtdbCommand::Ping, CtdbStatus::Success, 0);
        bad_header.magic = 0xDEADBEEF;
        assert!(!bad_header.is_valid());
    }

    #[test]
    fn test_empty_message() {
        let msg = CtdbMessage::new(CtdbCommand::Connect, CtdbStatus::Success, vec![]);
        let bytes = msg.to_bytes();
        let restored = CtdbMessage::from_bytes(&bytes).unwrap();
        assert!(restored.payload.is_empty());
        assert_eq!(restored.header.length, 0);
    }

    #[test]
    fn test_header_too_short() {
        let result = CtdbHeader::from_bytes(&[0u8; 10]);
        assert!(result.is_err());
    }

    #[test]
    fn test_large_payload() {
        let large = vec![0xABu8; 65000];
        let msg = CtdbMessage::new(CtdbCommand::Fetch, CtdbStatus::Success, large.clone());
        let bytes = msg.to_bytes();
        let restored = CtdbMessage::from_bytes(&bytes).unwrap();
        assert_eq!(restored.payload.len(), 65000);
        assert_eq!(restored.payload, large);
    }
}