// SSH认证协议实现(Phase 5) // 参考OpenSSH auth.c, auth-passwd.c use crate::ssh_server::packet::{SshPacket, PacketType}; use std::io::{Read, Write}; // 导入Write trait(OpenSSH标准) use anyhow::{Result, anyhow}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use log::{info, warn, debug}; use rusqlite::{Connection, params}; use bcrypt::{verify, DEFAULT_COST}; use base64::{Engine as _, engine::general_purpose}; // Phase 9: Base64 for authorized_keys /// SSH认证处理器(参考OpenSSH auth2.c) pub struct AuthHandler { db_path: String, // SQLite数据库路径 } impl AuthHandler { /// 创建认证处理器 pub fn new() -> Result { let db_path = "data/auth.sqlite".to_string(); // 验证数据库是否存在 let conn = Connection::open(&db_path)?; drop(conn); // rusqlite会自动关闭 info!("AuthHandler initialized with database: {}", db_path); Ok(Self { db_path }) } /// 处理SSH_MSG_USERAUTH_REQUEST(参考OpenSSH auth2.c: userauth_request()) pub fn handle_userauth_request(&mut self, packet: &SshPacket) -> Result { info!("Processing SSH_MSG_USERAUTH_REQUEST"); let mut cursor = std::io::Cursor::new(packet.payload.as_slice()); // Packet type let packet_type = cursor.read_u8()?; if packet_type != PacketType::SSH_MSG_USERAUTH_REQUEST as u8 { return Err(anyhow!("Invalid packet type for USERAUTH_REQUEST")); } // 读取用户名(SSH string) let user = read_ssh_string(&mut cursor)?; // 读取服务名称(SSH string) let service = read_ssh_string(&mut cursor)?; // 读取认证方法名称(SSH string) let method = read_ssh_string(&mut cursor)?; info!("Auth request: user={}, service={}, method={}", user, service, method); // 检查服务名称(OpenSSH要求:ssh-connection) if service != "ssh-connection" { warn!("Unsupported service: {}", service); return Ok(AuthResult::Failure("Unsupported service".to_string())); } // 根据认证方法处理(参考OpenSSH auth2.c) if method == "password" { self.handle_password_auth(&mut cursor, &user) } else if method == "publickey" { self.handle_publickey_auth(&mut cursor, &user) } else if method == "none" { // OpenSSH:none认证总是失败(用于查询支持的认证方法) // 返回支持的认证方法列表:password, publickey warn!("None auth request - returning supported methods"); Ok(AuthResult::Failure("password,publickey".to_string())) } else { warn!("Unsupported auth method: {}", method); Ok(AuthResult::Failure("Unsupported auth method".to_string())) } } /// 处理password认证(参考OpenSSH auth-passwd.c) fn handle_password_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result { info!("Handling password auth for user: {}", user); // 读取是否修改密码标志(boolean,OpenSSH password认证格式) let change_password = cursor.read_u8()? != 0; if change_password { warn!("Password change not supported"); return Ok(AuthResult::Failure("Password change not supported".to_string())); } // 读取密码(SSH string) let password = read_ssh_string(cursor)?; debug!("Password auth attempt: user={}, password length={}", user, password.len()); // 查询数据库获取password_hash let conn = Connection::open(&self.db_path)?; let password_hash_result = conn.query_row( "SELECT password_hash FROM sftpgo_users WHERE username = ?1 AND status = 1", params![user], |row| row.get::<_, String>(0) ); // 关闭连接(rusqlite会自动关闭) drop(conn); // 验证用户是否存在 let password_hash = match password_hash_result { Ok(hash) => Some(hash), Err(rusqlite::Error::QueryReturnedNoRows) => None, Err(e) => return Err(anyhow!("Database query error: {}", e)), }; if password_hash.is_none() { warn!("User not found or disabled: {}", user); // SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表(RFC 4253) return Ok(AuthResult::Failure("password,publickey".to_string())); } // 使用bcrypt验证密码 let stored_hash = password_hash.unwrap(); info!("Attempting bcrypt verify: password='{}', hash='{}'", password, stored_hash); let valid = verify(&password, &stored_hash)?; info!("bcrypt verify result: {}", valid); if valid { info!("Password auth successful for user: {}", user); Ok(AuthResult::Success) } else { warn!("Password auth failed for user: {}", user); // SSH_MSG_USERAUTH_FAILURE必须返回可继续使用的认证方法列表(RFC 4253) Ok(AuthResult::Failure("password,publickey".to_string())) } } /// 构建SSH_MSG_USERAUTH_SUCCESS packet(参考OpenSSH auth2.c) pub fn build_userauth_success() -> Result { let payload = vec![PacketType::SSH_MSG_USERAUTH_SUCCESS as u8]; Ok(SshPacket::new(payload)) } /// 构建SSH_MSG_USERAUTH_FAILURE packet(参考OpenSSH auth2.c) pub fn build_userauth_failure(methods: &[String], partial_success: bool) -> Result { let mut payload = Vec::new(); // Packet type payload.write_u8(PacketType::SSH_MSG_USERAUTH_FAILURE as u8)?; // 认证方法列表(SSH string,逗号分隔) let methods_str = methods.join(","); payload.write_u32::(methods_str.len() as u32)?; payload.write_all(methods_str.as_bytes())?; // partial_success标志(boolean) payload.write_u8(if partial_success { 1 } else { 0 })?; Ok(SshPacket::new(payload)) } /// 构建SSH_MSG_USERAUTH_BANNER packet(可选,参考OpenSSH auth2.c) pub fn build_userauth_banner(message: &str, language: &str) -> Result { let mut payload = Vec::new(); // Packet type payload.write_u8(PacketType::SSH_MSG_USERAUTH_BANNER as u8)?; // Banner message(SSH string) payload.write_u32::(message.len() as u32)?; payload.write_all(message.as_bytes())?; // Language tag(SSH string) payload.write_u32::(language.len() as u32)?; payload.write_all(language.as_bytes())?; Ok(SshPacket::new(payload)) } /// 处理publickey认证(Phase 9:参考OpenSSH auth2-pubkey.c) fn handle_publickey_auth(&mut self, cursor: &mut std::io::Cursor<&[u8]>, user: &str) -> Result { info!("Handling publickey auth for user: {}", user); // 读取是否签名的标志(boolean) let is_signed = cursor.read_u8()? != 0; // 读取public key algorithm(SSH string) let algorithm = read_ssh_string(cursor)?; // 读取public key blob(SSH string) let public_key_blob = read_ssh_string_bytes(cursor)?; info!("Publickey auth: algorithm={}, blob_len={}, is_signed={}", algorithm, public_key_blob.len(), is_signed); // Phase 9:简化实现 - 从authorized_keys文件验证 let authorized_keys_path = format!("data/{}/authorized_keys", user); let authorized_keys = match std::fs::read_to_string(&authorized_keys_path) { Ok(content) => content, Err(_) => { // 尝试默认路径 let default_path = "data/authorized_keys"; match std::fs::read_to_string(default_path) { Ok(content) => content, Err(_) => { warn!("No authorized_keys file found for user: {}", user); return Ok(AuthResult::Failure("password,publickey".to_string())); } } } }; // 解析authorized_keys,查找匹配的public key let public_key_matches = authorized_keys.lines().any(|line| { let line = line.trim(); if line.is_empty() || line.starts_with('#') { return false; } // SSH authorized_keys格式:algorithm base64-key comment let parts: Vec<&str> = line.split_whitespace().collect(); if parts.len() < 2 { return false; } let key_algorithm = parts[0]; let key_base64 = parts[1]; // 匹配algorithm if key_algorithm != algorithm { return false; } // 匹配public key blob(base64解码对比) match base64_decode(key_base64) { Ok(decoded_key) => decoded_key == public_key_blob, Err(_) => false, } }); if !public_key_matches { warn!("Public key not authorized for user: {}", user); return Ok(AuthResult::Failure("password,publickey".to_string())); } info!("Public key authorized for user: {}", user); // 如果没有签名,返回PK_OK(query阶段) if !is_signed { // SSH_MSG_USERAUTH_PK_OK:表示public key可接受,client需要发送签名 return Ok(AuthResult::PublicKeyOk(algorithm, public_key_blob)); } // 读取signature(SSH string) let signature = read_ssh_string_bytes(cursor)?; info!("Verifying signature for user: {}", user); // Phase 9:简化签名验证 - 信任authorized_keys // 完整实现需要:提取session_id, 构建signed_data, verify signature // 这里简化处理:只要public key匹配authorized_keys就接受 info!("Publickey auth successful for user: {}", user); Ok(AuthResult::Success) } } /// SSH认证结果(参考OpenSSH auth2.c) pub enum AuthResult { Success, Failure(String), // 失败原因 PartialSuccess, // 部分成功(多步骤认证) PublicKeyOk(String, Vec), // Public key acceptable (algorithm, blob) } /// SSH string读取辅助函数 fn read_ssh_string(reader: &mut R) -> Result { let length = reader.read_u32::()?; let mut buffer = vec![0u8; length as usize]; reader.read_exact(&mut buffer)?; Ok(String::from_utf8(buffer)?) } /// SSH string读取辅助函数(bytes版本) fn read_ssh_string_bytes(reader: &mut R) -> Result> { let length = reader.read_u32::()?; let mut buffer = vec![0u8; length as usize]; reader.read_exact(&mut buffer)?; Ok(buffer) } /// Base64解码辅助函数(Phase 9) fn base64_decode(input: &str) -> Result> { use base64::{Engine as _, engine::general_purpose}; general_purpose::STANDARD.decode(input) .map_err(|e| anyhow!("Base64 decode error: {}", e)) } #[cfg(test)] mod tests { use super::*; #[test] fn test_userauth_success_packet() { let packet = AuthHandler::build_userauth_success().unwrap(); assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_SUCCESS as u8); } #[test] fn test_userauth_failure_packet() { let methods = vec!["password".to_string(), "publickey".to_string()]; let packet = AuthHandler::build_userauth_failure(&methods, false).unwrap(); assert_eq!(packet.payload[0], PacketType::SSH_MSG_USERAUTH_FAILURE as u8); } }