use anyhow::{anyhow, Result}; use log::{info, warn}; use std::collections::HashMap; use std::fs; use std::io::{BufRead, BufReader}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::path::{Path, PathBuf}; #[derive(Debug, Clone, PartialEq)] pub enum KnownHostKey { Ed25519(Vec), Rsa(Vec), Ecdsa(Vec), Dsa(Vec), } #[derive(Debug, Clone)] pub struct KnownHostEntry { pub hosts: Vec, pub key_type: String, pub key: KnownHostKey, pub comment: Option, pub is_hashed: bool, pub is_cert_authority: bool, } impl KnownHostEntry { pub fn matches_host(&self, hostname: &str, ip: Option) -> bool { if self.is_hashed { return self.matches_hashed_host(hostname, ip); } for host in &self.hosts { if host == hostname { return true; } if let Some(ip_addr) = ip { if host == &ip_addr.to_string() { return true; } } if host.contains(',') { let parts: Vec<&str> = host.split(',').collect(); for part in parts { if part == hostname { return true; } if let Some(ip_addr) = ip { if part == &ip_addr.to_string() { return true; } } } } if host.starts_with('|') { if self.matches_pattern_host(host, hostname) { return true; } } } false } fn matches_hashed_host(&self, hostname: &str, _ip: Option) -> bool { for host in &self.hosts { if host.starts_with('|') { if let Ok(decoded) = decode_hashed_host(host) { if decoded == hostname { return true; } } } } false } fn matches_pattern_host(&self, pattern: &str, hostname: &str) -> bool { if pattern.contains('*') || pattern.contains('?') { let regex_pattern = pattern.replace('*', ".*").replace('?', "."); if let Ok(re) = regex::Regex::new(®ex_pattern) { return re.is_match(hostname); } } false } pub fn verify_key(&self, server_key: &[u8], key_type: &str) -> Result { if self.key_type != key_type { return Ok(false); } match &self.key { KnownHostKey::Ed25519(key_bytes) => { if key_type == "ssh-ed25519" { Ok(key_bytes == server_key) } else { Ok(false) } } KnownHostKey::Rsa(key_bytes) => { if key_type == "ssh-rsa" || key_type == "rsa-sha2-256" || key_type == "rsa-sha2-512" { Ok(key_bytes == server_key) } else { Ok(false) } } KnownHostKey::Ecdsa(key_bytes) => { if key_type.starts_with("ecdsa-sha2-") { Ok(key_bytes == server_key) } else { Ok(false) } } KnownHostKey::Dsa(key_bytes) => { if key_type == "ssh-dss" { Ok(key_bytes == server_key) } else { Ok(false) } } } } } fn decode_hashed_host(hashed: &str) -> Result { use base64::{engine::general_purpose::STANDARD, Engine as _}; let parts: Vec<&str> = hashed.split('|').collect(); if parts.len() < 4 || parts[0] != "1" { return Err(anyhow!("Invalid hashed host format")); } let salt = STANDARD.decode(parts[1])?; let hash = STANDARD.decode(parts[2])?; use sha2::{Digest, Sha256}; let mut hasher = Sha256::new(); hasher.update(&salt); hasher.update(parts[3].as_bytes()); let computed_hash = hasher.finalize(); if hash == computed_hash.as_slice() { Ok(parts[3].to_string()) } else { Err(anyhow!("Hash mismatch")) } } pub struct KnownHostsParser { entries: Vec, } impl KnownHostsParser { pub fn new() -> Self { Self { entries: Vec::new(), } } pub fn load_default() -> Result { let known_hosts_path = Self::default_known_hosts_path()?; Self::load_from_file(&known_hosts_path) } pub fn load_from_file(path: &Path) -> Result { if !path.exists() { info!("Known hosts file not found: {}", path.display()); return Ok(Self::new()); } let file = fs::File::open(path)?; let reader = BufReader::new(file); let mut parser = Self::new(); for line in reader.lines() { let line = line?; if line.is_empty() || line.starts_with('#') { continue; } if let Some(entry) = parser.parse_line(&line) { parser.entries.push(entry); } } info!("Loaded {} known hosts entries from {}", parser.entries.len(), path.display()); Ok(parser) } fn default_known_hosts_path() -> Result { let home = std::env::var("HOME") .or_else(|_| std::env::var("USERPROFILE")) .map_err(|_| anyhow!("Cannot determine home directory"))?; Ok(PathBuf::from(home).join(".ssh").join("known_hosts")) } fn parse_line(&self, line: &str) -> Option { let parts: Vec<&str> = line.split_whitespace().collect(); if parts.len() < 3 { return None; } let is_cert_authority = parts[0].starts_with("@cert-authority"); let (hosts_part, key_type, key_base64, rest_parts) = if is_cert_authority { (parts[1], parts[2], parts[3], &parts[4..]) } else { (parts[0], parts[1], parts[2], &parts[3..]) }; let comment = if rest_parts.len() > 0 { Some(rest_parts.join(" ")) } else { None }; let hosts: Vec = hosts_part.split(',').map(|s| s.to_string()).collect(); let is_hashed = hosts.iter().any(|h| h.starts_with('|')); let key = self.decode_key(key_type, key_base64)?; Some(KnownHostEntry { hosts, key_type: key_type.to_string(), key, comment, is_hashed, is_cert_authority, }) } fn decode_key(&self, key_type: &str, key_base64: &str) -> Option { use base64::{engine::general_purpose::STANDARD, Engine as _}; let key_bytes = STANDARD.decode(key_base64).ok()?; match key_type { "ssh-ed25519" => Some(KnownHostKey::Ed25519(key_bytes)), "ssh-rsa" | "rsa-sha2-256" | "rsa-sha2-512" => Some(KnownHostKey::Rsa(key_bytes)), "ecdsa-sha2-nistp256" | "ecdsa-sha2-nistp384" | "ecdsa-sha2-nistp521" => { Some(KnownHostKey::Ecdsa(key_bytes)) } "ssh-dss" => Some(KnownHostKey::Dsa(key_bytes)), _ => { warn!("Unknown key type: {}", key_type); None } } } pub fn verify_host_key( &self, hostname: &str, ip: Option, server_key: &[u8], key_type: &str, ) -> Result { let matching_entries: Vec<&KnownHostEntry> = self .entries .iter() .filter(|e| e.matches_host(hostname, ip)) .collect(); if matching_entries.is_empty() { return Ok(VerifyResult::UnknownHost); } for entry in matching_entries { if entry.verify_key(server_key, key_type)? { return Ok(VerifyResult::Verified); } } Ok(VerifyResult::KeyMismatch) } pub fn add_host_key( &self, hostname: &str, key_type: &str, key: &[u8], comment: Option<&str>, ) -> Result { use base64::{engine::general_purpose::STANDARD, Engine as _}; let key_base64 = STANDARD.encode(key); let line = if let Some(c) = comment { format!("{} {} {} {}", hostname, key_type, key_base64, c) } else { format!("{} {} {}", hostname, key_type, key_base64) }; Ok(line) } pub fn get_entries(&self) -> &[KnownHostEntry] { &self.entries } pub fn get_entries_for_host(&self, hostname: &str) -> Vec<&KnownHostEntry> { self.entries .iter() .filter(|e| e.matches_host(hostname, None)) .collect() } pub fn remove_host(&mut self, hostname: &str) -> usize { let original_len = self.entries.len(); self.entries.retain(|e| !e.matches_host(hostname, None)); original_len - self.entries.len() } pub fn hash_host(&self, hostname: &str) -> Result { use base64::{engine::general_purpose::STANDARD, Engine as _}; use rand::Rng; use sha2::{Digest, Sha256}; let salt: [u8; 20] = rand::rngs::OsRng.gen(); let mut hasher = Sha256::new(); hasher.update(&salt); hasher.update(hostname.as_bytes()); let hash = hasher.finalize(); Ok(format!( "|1|{}|{}|{}", STANDARD.encode(&salt), STANDARD.encode(&hash), hostname )) } } #[derive(Debug, Clone, PartialEq)] pub enum VerifyResult { Verified, KeyMismatch, UnknownHost, } impl std::fmt::Display for VerifyResult { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { VerifyResult::Verified => write!(f, "Host key verified"), VerifyResult::KeyMismatch => write!(f, "Host key mismatch - possible MITM attack"), VerifyResult::UnknownHost => write!(f, "Unknown host - key not found in known_hosts"), } } } #[cfg(test)] mod tests { use super::*; use base64::{engine::general_purpose::STANDARD, Engine as _}; use tempfile::TempDir; #[test] fn test_parse_simple_entry() { let parser = KnownHostsParser::new(); let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let line = format!("example.com ssh-ed25519 {}", valid_key); let entry = parser.parse_line(&line).unwrap(); assert_eq!(entry.hosts, vec!["example.com"]); assert_eq!(entry.key_type, "ssh-ed25519"); assert!(!entry.is_hashed); assert!(!entry.is_cert_authority); } #[test] fn test_parse_multiple_hosts() { let parser = KnownHostsParser::new(); let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let line = format!("host1,host2,192.168.1.1 ssh-ed25519 {}", valid_key); let entry = parser.parse_line(&line).unwrap(); assert_eq!(entry.hosts.len(), 3); assert!(entry.hosts.contains(&"host1".to_string())); assert!(entry.hosts.contains(&"host2".to_string())); } #[test] fn test_parse_cert_authority() { let parser = KnownHostsParser::new(); let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let line = format!("@cert-authority *.example.com ssh-ed25519 {}", valid_key); let entry = parser.parse_line(&line).unwrap(); assert!(entry.is_cert_authority); } #[test] fn test_matches_host() { let parser = KnownHostsParser::new(); let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let line = format!("example.com ssh-ed25519 {}", valid_key); let entry = parser.parse_line(&line).unwrap(); assert!(entry.matches_host("example.com", None)); assert!(!entry.matches_host("other.com", None)); } #[test] fn test_matches_ip() { let parser = KnownHostsParser::new(); let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let line = format!("example.com,192.168.1.1 ssh-ed25519 {}", valid_key); let entry = parser.parse_line(&line).unwrap(); let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)); assert!(entry.matches_host("example.com", Some(ip))); } #[test] fn test_verify_host_key() { let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let parser = KnownHostsParser::new(); let line = format!("example.com ssh-ed25519 {}", valid_key); let entry = parser.parse_line(&line).unwrap(); let mut parser = KnownHostsParser::new(); parser.entries.push(entry); let key_bytes = STANDARD.decode(valid_key).unwrap(); let result = parser.verify_host_key("example.com", None, &key_bytes, "ssh-ed25519"); assert_eq!(result.unwrap(), VerifyResult::Verified); let result = parser.verify_host_key("example.com", None, &[0u8; 32], "ssh-ed25519"); assert_eq!(result.unwrap(), VerifyResult::KeyMismatch); let result = parser.verify_host_key("unknown.com", None, &key_bytes, "ssh-ed25519"); assert_eq!(result.unwrap(), VerifyResult::UnknownHost); } #[test] fn test_load_from_file() { let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let temp_dir = TempDir::new().unwrap(); let known_hosts_path = temp_dir.path().join("known_hosts"); fs::write( &known_hosts_path, format!("example.com ssh-ed25519 {}\n", valid_key), ) .unwrap(); let parser = KnownHostsParser::load_from_file(&known_hosts_path).unwrap(); assert_eq!(parser.entries.len(), 1); } #[test] fn test_add_host_key() { let parser = KnownHostsParser::new(); let key_bytes = vec![1, 2, 3, 4]; let line = parser.add_host_key("example.com", "ssh-ed25519", &key_bytes, None).unwrap(); assert!(line.contains("example.com")); assert!(line.contains("ssh-ed25519")); } #[test] fn test_remove_host() { let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let parser = KnownHostsParser::new(); let line = format!("example.com ssh-ed25519 {}", valid_key); let entry = parser.parse_line(&line).unwrap(); let mut parser = KnownHostsParser::new(); parser.entries.push(entry); let removed = parser.remove_host("example.com"); assert_eq!(removed, 1); assert_eq!(parser.entries.len(), 0); } #[test] fn test_hash_host() { let parser = KnownHostsParser::new(); let hashed = parser.hash_host("example.com").unwrap(); assert!(hashed.starts_with("|1|")); } #[test] fn test_comment_parsing() { let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let parser = KnownHostsParser::new(); let line = format!("example.com ssh-ed25519 {} this is a comment", valid_key); let entry = parser.parse_line(&line).unwrap(); assert_eq!(entry.comment, Some("this is a comment".to_string())); } #[test] fn test_empty_file() { let temp_dir = TempDir::new().unwrap(); let known_hosts_path = temp_dir.path().join("known_hosts"); fs::write(&known_hosts_path, "").unwrap(); let parser = KnownHostsParser::load_from_file(&known_hosts_path).unwrap(); assert_eq!(parser.entries.len(), 0); } #[test] fn test_skip_comments() { let valid_key = "c3NoLWVkMjU1MTkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="; let temp_dir = TempDir::new().unwrap(); let known_hosts_path = temp_dir.path().join("known_hosts"); fs::write( &known_hosts_path, format!("# This is a comment\nexample.com ssh-ed25519 {}\n", valid_key), ) .unwrap(); let parser = KnownHostsParser::load_from_file(&known_hosts_path).unwrap(); assert_eq!(parser.entries.len(), 1); } }