// SCP协议实现(Phase 8) // 参考OpenSSH scp.c源码 use crate::vfs::open_flags::OpenFlags; use crate::vfs::{VfsBackend, VfsStat}; use anyhow::{anyhow, Result}; use log::{debug, info, warn}; use std::io::{BufRead, Read, Write}; use std::path::{Path, PathBuf}; /// SCP Handler(参考OpenSSH scp.c) pub struct ScpHandler { pub root_dir: PathBuf, mode: ScpMode, recursive: bool, preserve_times: bool, pub vfs: Box, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum ScpMode { Source, // scp -f(发送文件) Destination, // scp -t(接收文件) } impl ScpHandler { pub fn new(root_dir: PathBuf, vfs: Box) -> Self { Self { root_dir, mode: ScpMode::Destination, recursive: false, preserve_times: false, vfs, } } /// 解析SCP命令(参考OpenSSH scp.c: parse_command()) pub fn parse_scp_command(command: &str, vfs: Box) -> Result { let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() < 2 || parts[0] != "scp" { return Err(anyhow!("Invalid SCP command: {}", command)); } let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs); for part in &parts[1..] { match part { &"-f" => handler.mode = ScpMode::Source, &"-t" => handler.mode = ScpMode::Destination, &"-r" => handler.recursive = true, &"-p" => handler.preserve_times = true, path if !path.starts_with('-') => { handler.root_dir = PathBuf::from(path); } _ => warn!("Unknown SCP flag: {}", part), } } Ok(handler) } /// 处理SCP传输(参考OpenSSH scp.c: source() / sink()) pub fn handle_scp(&mut self, channel: &mut dyn ReadWrite) -> Result<()> { match self.mode { ScpMode::Source => self.handle_source_mode(channel), ScpMode::Destination => self.handle_destination_mode(channel), } } /// SCP Source Mode(scp -f,发送文件) fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> { info!( "SCP source mode: sending files from {}", self.root_dir.display() ); let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?; let stat = self .vfs .stat(&full_path) .map_err(|e| anyhow!("stat error: {}", e))?; if stat.is_dir { if !self.recursive { return Err(anyhow!("Directory detected but -r flag not specified")); } self.send_directory(channel, &full_path)?; } else { self.send_file(channel, &full_path)?; } Ok(()) } /// SCP Destination Mode(scp -t,接收文件) fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> { info!( "SCP destination mode: receiving files to {}", self.root_dir.display() ); channel.write_all(&[0])?; channel.flush()?; let mut buffer = String::new(); loop { buffer.clear(); let mut reader = std::io::BufReader::new(&mut *channel); match reader.read_line(&mut buffer)? { 0 => break, _ => { let command = buffer.trim(); debug!("SCP command: {}", command); match command.chars().next() { Some('C') => self.handle_file_command(channel, command)?, Some('D') => self.handle_directory_command(channel, command)?, Some('E') => self.handle_end_directory(channel)?, Some('T') => self.handle_time_command(channel, command)?, Some('\0') => { continue; } _ => { warn!("Unknown SCP command: {}", command); self.send_error(channel, &format!("Unknown command: {}", command))?; } } } } } Ok(()) } /// 发送文件(参考OpenSSH scp.c: source()) fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> { let stat = self .vfs .stat(path) .map_err(|e| anyhow!("stat error: {}", e))?; let size = stat.size; let filename = path.file_name().unwrap().to_string_lossy(); let command = format!("C0644 {} {}\n", size, filename); channel.write_all(command.as_bytes())?; channel.flush()?; let mut ack = [0u8; 1]; channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP file command rejected")); } let flags = OpenFlags::new().read(); let mut file = self .vfs .open_file(path, &flags) .map_err(|e| anyhow!("open error: {}", e))?; let mut buffer = vec![0u8; 8192]; loop { let n = file .read(&mut buffer) .map_err(|e| anyhow!("read error: {}", e))?; if n == 0 { break; } channel.write_all(&buffer[..n])?; } channel.flush()?; channel.write_all(&[0])?; channel.flush()?; channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP file transfer rejected")); } info!("SCP file sent: {} ({} bytes)", filename, size); Ok(()) } /// 发送目录(参考OpenSSH scp.c: source()) fn send_directory(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> { let dirname = path.file_name().unwrap().to_string_lossy(); let command = format!("D0755 0 {}\n", dirname); channel.write_all(command.as_bytes())?; channel.flush()?; let mut ack = [0u8; 1]; channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP directory command rejected")); } let entries = self .vfs .read_dir(path) .map_err(|e| anyhow!("read_dir error: {}", e))?; for entry in &entries { let entry_path = path.join(&entry.name); if entry.stat.is_dir { if self.recursive { self.send_directory(channel, &entry_path)?; } } else { self.send_file(channel, &entry_path)?; } } channel.write_all("E\n".as_bytes())?; channel.flush()?; channel.read_exact(&mut ack)?; if ack[0] != 0 { return Err(anyhow!("SCP end directory rejected")); } info!("SCP directory sent: {}", dirname); Ok(()) } /// 处理文件命令(C0644 size filename) fn handle_file_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> { let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() != 3 { return self.send_error(channel, "Invalid file command format"); } let mode_str = parts[0].trim_start_matches('C'); let size: u64 = parts[1].parse()?; let filename = parts[2]; debug!( "SCP receive file: mode={}, size={}, name={}", mode_str, size, filename ); if size > 1024 * 1024 * 1024 { return self.send_error(channel, "File too large (max 1GB)"); } let full_path = self.resolve_path(filename)?; let flags = OpenFlags::new().write().create().truncate(); let mut file = self .vfs .open_file(&full_path, &flags) .map_err(|e| anyhow!("open error: {}", e))?; channel.write_all(&[0])?; channel.flush()?; let mut buffer = vec![0u8; 8192]; let mut remaining = size; while remaining > 0 { let to_read = std::cmp::min(buffer.len() as u64, remaining) as usize; let n = channel.read(&mut buffer[..to_read])?; if n == 0 { break; } file.write_all(&buffer[..n]) .map_err(|e| anyhow!("write error: {}", e))?; remaining -= n as u64; } file.flush().map_err(|e| anyhow!("flush error: {}", e))?; // 设置文件权限 let mode_int: u32 = mode_str.parse()?; if mode_int != 0 { let mut set_stat = VfsStat::new(); set_stat.mode = mode_int; self.vfs .set_stat(&full_path, &set_stat) .map_err(|e| anyhow!("set_stat error: {}", e))?; } let mut ack = [0u8; 1]; channel.read_exact(&mut ack)?; channel.write_all(&[0])?; channel.flush()?; info!("SCP file received: {} ({} bytes)", filename, size); Ok(()) } /// 处理目录命令(D0755 0 dirname) fn handle_directory_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> { let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() != 3 { return self.send_error(channel, "Invalid directory command format"); } if !self.recursive { return self.send_error(channel, "Recursive flag not specified"); } let mode_str = parts[0].trim_start_matches('D'); let dirname = parts[2]; debug!("SCP receive directory: mode={}, name={}", mode_str, dirname); let full_path = self.resolve_path(dirname)?; let mode_int: u32 = mode_str.parse()?; self.vfs .create_dir_all(&full_path, mode_int) .map_err(|e| anyhow!("create_dir_all error: {}", e))?; channel.write_all(&[0])?; channel.flush()?; info!("SCP directory created: {}", dirname); Ok(()) } /// 处理结束目录命令(E) fn handle_end_directory(&self, channel: &mut dyn ReadWrite) -> Result<()> { debug!("SCP end directory"); channel.write_all(&[0])?; channel.flush()?; Ok(()) } /// 处理时间命令(T mtime atime) fn handle_time_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> { if !self.preserve_times { channel.write_all(&[0])?; channel.flush()?; return Ok(()); } let parts: Vec<&str> = command.split_whitespace().collect(); if parts.len() != 3 { return self.send_error(channel, "Invalid time command format"); } let mtime_secs: i64 = parts[1].parse()?; let atime_secs: i64 = parts[2].parse()?; debug!("SCP set times: mtime={}, atime={}", mtime_secs, atime_secs); channel.write_all(&[0])?; channel.flush()?; Ok(()) } /// 发送错误消息 fn send_error(&self, channel: &mut dyn ReadWrite, message: &str) -> Result<()> { let error_msg = format!("{}\n", message); channel.write_all(error_msg.as_bytes())?; channel.flush()?; Err(anyhow!("SCP error: {}", message)) } /// 路径解析(安全性检查) pub fn resolve_path(&self, path: &str) -> Result { let full_path = self.root_dir.join(path); let canonical_path = self .vfs .real_path(&full_path) .map_err(|e| anyhow!("Path resolution error: {}", e))?; let root_canonical = self .vfs .real_path(&self.root_dir) .map_err(|e| anyhow!("Root path resolution error: {}", e))?; if !canonical_path.starts_with(&root_canonical) { return Err(anyhow!("Path traversal attempt detected")); } Ok(canonical_path) } } /// Read + Write trait组合(用于Channel) pub trait ReadWrite: Read + Write {} impl ReadWrite for T {} /// ⭐⭐⭐⭐⭐ Phase 8: Channel wrapper for SCP protocol /// 实现 Read + Write traits,用于 ScpHandler 和 SSH channel 之间传递数据 pub struct ChannelReadWrite { input_buffer: Vec, output_buffer: Vec, input_pos: usize, } impl ChannelReadWrite { pub fn new(input_buffer: Vec) -> Self { Self { input_buffer, output_buffer: Vec::new(), input_pos: 0, } } pub fn feed_input(&mut self, data: &[u8]) { self.input_buffer.extend_from_slice(data); } pub fn drain_output(&mut self) -> Vec { let output = self.output_buffer.clone(); self.output_buffer.clear(); output } pub fn has_remaining_input(&self) -> bool { self.input_pos < self.input_buffer.len() } } impl Read for ChannelReadWrite { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { let remaining = self.input_buffer.len() - self.input_pos; let to_read = std::cmp::min(buf.len(), remaining); if to_read == 0 { return Ok(0); } buf[..to_read].copy_from_slice(&self.input_buffer[self.input_pos..self.input_pos + to_read]); self.input_pos += to_read; Ok(to_read) } } impl Write for ChannelReadWrite { fn write(&mut self, buf: &[u8]) -> std::io::Result { self.output_buffer.extend_from_slice(buf); Ok(buf.len()) } fn flush(&mut self) -> std::io::Result<()> { Ok(()) } } #[cfg(test)] mod tests { use super::*; use crate::vfs::local_fs::LocalFs; fn make_handler() -> ScpHandler { ScpHandler::new(PathBuf::from("/tmp"), Box::new(LocalFs::new())) } #[test] fn test_scp_command_parse() { let handler = ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap(); assert_eq!(handler.mode, ScpMode::Destination); assert_eq!(handler.root_dir, PathBuf::from("/tmp")); } #[test] fn test_scp_recursive_parse() { let handler = ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap(); assert!(handler.recursive); } #[test] fn test_scp_source_parse() { let handler = ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap(); assert_eq!(handler.mode, ScpMode::Source); } }