// SCP协议实现(Phase 8) // 参考OpenSSH scp.c源码 use anyhow::{Result, anyhow}; use log::{info, warn, debug}; use std::path::{Path, PathBuf}; use std::fs::{self, File, OpenOptions}; use std::io::{Read, Write, BufReader, BufWriter, BufRead}; // 导入BufRead trait(OpenSSH标准) use chrono::{DateTime, Utc}; /// SCP Handler(参考OpenSSH scp.c) pub struct ScpHandler { root_dir: PathBuf, mode: ScpMode, recursive: bool, preserve_times: bool, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum ScpMode { Source, // scp -f(发送文件) Destination, // scp -t(接收文件) } impl ScpHandler { pub fn new(root_dir: PathBuf) -> Self { Self { root_dir, mode: ScpMode::Destination, recursive: false, preserve_times: false, } } /// 解析SCP命令(参考OpenSSH scp.c: parse_command()) pub fn parse_scp_command(command: &str) -> Result { let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() < 2 || parts[0] != "scp" { return Err(anyhow!("Invalid SCP command: {}", command)); } let mut handler = ScpHandler::new(PathBuf::from("/tmp")); for part in &parts[1..] { match part { &"-f" => handler.mode = ScpMode::Source, &"-t" => handler.mode = ScpMode::Destination, &"-r" => handler.recursive = true, &"-p" => handler.preserve_times = true, path if !path.starts_with('-') => { handler.root_dir = PathBuf::from(path); } _ => warn!("Unknown SCP flag: {}", part), } } Ok(handler) } /// 处理SCP传输(参考OpenSSH scp.c: source() / sink()) pub fn handle_scp(&mut self, channel: &mut dyn ReadWrite) -> Result<()> { match self.mode { ScpMode::Source => self.handle_source_mode(channel), ScpMode::Destination => self.handle_destination_mode(channel), } } /// SCP Source Mode(scp -f,发送文件) fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> { info!("SCP source mode: sending files from {}", self.root_dir.display()); // 使用display()(Rust标准) let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?; if full_path.is_file() { self.send_file(channel, &full_path)?; } else if full_path.is_dir() { if !self.recursive { return Err(anyhow!("Directory detected but -r flag not specified")); } self.send_directory(channel, &full_path)?; } else { return Err(anyhow!("Path does not exist: {}", full_path.display())); } Ok(()) } /// SCP Destination Mode(scp -t,接收文件) fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> { info!("SCP destination mode: receiving files to {}", self.root_dir.display()); // 使用display()(Rust标准) // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; let mut buffer = String::new(); loop { buffer.clear(); // 每次循环创建新的reader(避免borrow冲突)- OpenSSH标准 let mut reader = BufReader::new(&mut *channel); match reader.read_line(&mut buffer)? { 0 => break, // EOF _ => { let command = buffer.trim(); debug!("SCP command: {}", command); match command.chars().next() { Some('C') => self.handle_file_command(channel, command)?, Some('D') => self.handle_directory_command(channel, command)?, Some('E') => self.handle_end_directory(channel)?, Some('T') => self.handle_time_command(channel, command)?, Some('\0') => { // 确认信号,继续 continue; } _ => { warn!("Unknown SCP command: {}", command); self.send_error(channel, &format!("Unknown command: {}", command))?; } } } } } Ok(()) } /// 发送文件(参考OpenSSH scp.c: source()) fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> { let metadata = fs::metadata(path)?; let size = metadata.len(); let filename = path.file_name().unwrap().to_string_lossy(); // 发送文件命令:C0644 size filename let command = format!("C0644 {} {}\n", size, filename); channel.write_all(command.as_bytes())?; channel.flush()?; // 等待确认('\0') let mut ack = [0u8; 1]; channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP file command rejected")); } // 发送文件内容 let file = File::open(path)?; let mut reader = BufReader::new(file); let mut buffer = vec![0u8; 8192]; while let Ok(n) = reader.read(&mut buffer) { if n == 0 { break; } channel.write_all(&buffer[..n])?; } channel.flush()?; // 发送结束确认('\0') channel.write_all(&[0])?; channel.flush()?; // 等待确认('\0') channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP file transfer rejected")); } info!("SCP file sent: {} ({} bytes)", filename, size); Ok(()) } /// 发送目录(参考OpenSSH scp.c: source()) fn send_directory(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> { let dirname = path.file_name().unwrap().to_string_lossy(); // 发送目录命令:D0755 0 dirname let command = format!("D0755 0 {}\n", dirname); channel.write_all(command.as_bytes())?; channel.flush()?; // 等待确认('\0') let mut ack = [0u8; 1]; channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP directory command rejected")); } // 递归发送目录内容 for entry in fs::read_dir(path)? { let entry = entry?; let full_path = entry.path(); if full_path.is_file() { self.send_file(channel, &full_path)?; } else if full_path.is_dir() && self.recursive { self.send_directory(channel, &full_path)?; } } // 发送结束目录命令:E channel.write_all("E\n".as_bytes())?; channel.flush()?; // 等待确认('\0') channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP end directory rejected")); } info!("SCP directory sent: {}", dirname); Ok(()) } /// 处理文件命令(C0644 size filename) fn handle_file_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> { let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() != 3 { return self.send_error(channel, "Invalid file command format"); } let mode = parts[0].trim_start_matches('C'); let size: u64 = parts[1].parse()?; let filename = parts[2]; debug!("SCP receive file: mode={}, size={}, name={}", mode, size, filename); // 安全性检查:文件大小限制(防止DoS) if size > 1024 * 1024 * 1024 { // 1GB限制 return self.send_error(channel, "File too large (max 1GB)"); } // 创建文件 let full_path = self.resolve_path(filename)?; let file = OpenOptions::new() .write(true) .create(true) .truncate(true) .open(&full_path)?; // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; // 接收文件内容 let mut writer = BufWriter::new(file); let mut buffer = vec![0u8; 8192]; let mut remaining = size; while remaining > 0 { let to_read = std::cmp::min(buffer.len() as u64, remaining) as usize; let n = channel.read(&mut buffer[..to_read])?; if n == 0 { break; } writer.write_all(&buffer[..n])?; remaining -= n as u64; } writer.flush()?; // 设置文件权限 #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; let mode_int: u32 = mode.parse()?; fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?; } // 接收结束确认('\0') let mut ack = [0u8; 1]; channel.read_exact(&mut ack)?; // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; info!("SCP file received: {} ({} bytes)", filename, size); Ok(()) } /// 处理目录命令(D0755 0 dirname) fn handle_directory_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> { let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() != 3 { return self.send_error(channel, "Invalid directory command format"); } if !self.recursive { return self.send_error(channel, "Recursive flag not specified"); } let mode = parts[0].trim_start_matches('D'); let dirname = parts[2]; debug!("SCP receive directory: mode={}, name={}", mode, dirname); // 创建目录 let full_path = self.resolve_path(dirname)?; fs::create_dir_all(&full_path)?; // 设置目录权限 #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; let mode_int: u32 = mode.parse()?; fs::set_permissions(&full_path, fs::Permissions::from_mode(mode_int))?; } // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; info!("SCP directory created: {}", dirname); Ok(()) } /// 处理结束目录命令(E) fn handle_end_directory(&self, channel: &mut dyn ReadWrite) -> Result<()> { debug!("SCP end directory"); // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; Ok(()) } /// 处理时间命令(T mtime atime) fn handle_time_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> { if !self.preserve_times { // 发送确认('\0'),但不设置时间 channel.write_all(&[0])?; channel.flush()?; return Ok(()); } let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() != 3 { return self.send_error(channel, "Invalid time command format"); } let mtime: i64 = parts[1].parse()?; let atime: i64 = parts[2].parse()?; debug!("SCP set times: mtime={}, atime={}", mtime, atime); // 发送确认('\0') channel.write_all(&[0])?; channel.flush()?; // 时间设置将在文件接收完成后进行 // (这里仅记录,实际设置在handle_file_command中) Ok(()) } /// 发送错误消息 fn send_error(&self, channel: &mut dyn ReadWrite, message: &str) -> Result<()> { let error_msg = format!("{}\n", message); channel.write_all(error_msg.as_bytes())?; channel.flush()?; Err(anyhow!("SCP error: {}", message)) } /// 路径解析(安全性检查) fn resolve_path(&self, path: &str) -> Result { let full_path = self.root_dir.join(path); let canonical_path = full_path.canonicalize() .map_err(|e| anyhow!("Path resolution error: {}", e))?; if !canonical_path.starts_with(&self.root_dir.canonicalize()?) { return Err(anyhow!("Path traversal attempt detected")); } Ok(canonical_path) } } /// Read + Write trait组合(用于Channel) pub trait ReadWrite: Read + Write {} impl ReadWrite for T {} #[cfg(test)] mod tests { use super::*; #[test] fn test_scp_command_parse() { let handler = ScpHandler::parse_scp_command("scp -t /tmp").unwrap(); assert_eq!(handler.mode, ScpMode::Destination); assert_eq!(handler.root_dir, PathBuf::from("/tmp")); } #[test] fn test_scp_recursive_parse() { let handler = ScpHandler::parse_scp_command("scp -r -t /tmp").unwrap(); assert!(handler.recursive); } #[test] fn test_scp_source_parse() { let handler = ScpHandler::parse_scp_command("scp -f /tmp").unwrap(); assert_eq!(handler.mode, ScpMode::Source); } }