// Commit notes:
// - Add `VfsFile: Send` supertrait for Mutex compatibility
// - Fix SmbServerCommand: struct → Subcommand enum with a Start variant
// - Fix tracing_subscriber::init() → try_init() to avoid a panic when a logger is already initialized
// - Fix CLI subcommand name: smb-server → smb-start (flattened naming)
// - Add #[command(name = "smb-start")] for CLI disambiguation
// - Fix unused-variable warnings (smb_fs.rs, smb_server_backend.rs)
// - Remove unused VfsFile imports (webdav.rs, scp_handler.rs)
// - Integration test: Docker smbclient verified (list, upload, read)
// File: 490 lines, 14 KiB, Rust
// SCP protocol implementation (Phase 8).
// Modeled on the OpenSSH scp.c source.
use crate::vfs::open_flags::OpenFlags;
|
||
use crate::vfs::{VfsBackend, VfsStat};
|
||
use anyhow::{anyhow, Result};
|
||
use log::{debug, info, warn};
|
||
use std::io::{BufRead, Read, Write};
|
||
use std::path::{Path, PathBuf};
|
||
|
||
/// SCP Handler(参考OpenSSH scp.c)
|
||
pub struct ScpHandler {
|
||
pub root_dir: PathBuf,
|
||
mode: ScpMode,
|
||
recursive: bool,
|
||
preserve_times: bool,
|
||
pub vfs: Box<dyn VfsBackend>,
|
||
}
|
||
|
||
/// Transfer direction of an SCP session, selected by the remote command's flag.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScpMode {
    /// `scp -f`: this side reads local files and streams them to the peer.
    Source,
    /// `scp -t`: this side accepts files from the peer and stores them.
    Destination,
}

impl ScpHandler {
|
||
pub fn new(root_dir: PathBuf, vfs: Box<dyn VfsBackend>) -> Self {
|
||
Self {
|
||
root_dir,
|
||
mode: ScpMode::Destination,
|
||
recursive: false,
|
||
preserve_times: false,
|
||
vfs,
|
||
}
|
||
}
|
||
|
||
/// 解析SCP命令(参考OpenSSH scp.c: parse_command())
|
||
pub fn parse_scp_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
|
||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||
|
||
if parts.len() < 2 || parts[0] != "scp" {
|
||
return Err(anyhow!("Invalid SCP command: {}", command));
|
||
}
|
||
|
||
let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs);
|
||
|
||
for part in &parts[1..] {
|
||
match part {
|
||
&"-f" => handler.mode = ScpMode::Source,
|
||
&"-t" => handler.mode = ScpMode::Destination,
|
||
&"-r" => handler.recursive = true,
|
||
&"-p" => handler.preserve_times = true,
|
||
path if !path.starts_with('-') => {
|
||
handler.root_dir = PathBuf::from(path);
|
||
}
|
||
_ => warn!("Unknown SCP flag: {}", part),
|
||
}
|
||
}
|
||
|
||
Ok(handler)
|
||
}
|
||
|
||
/// 处理SCP传输(参考OpenSSH scp.c: source() / sink())
|
||
pub fn handle_scp(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||
match self.mode {
|
||
ScpMode::Source => self.handle_source_mode(channel),
|
||
ScpMode::Destination => self.handle_destination_mode(channel),
|
||
}
|
||
}
|
||
|
||
/// SCP Source Mode(scp -f,发送文件)
|
||
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||
info!(
|
||
"SCP source mode: sending files from {}",
|
||
self.root_dir.display()
|
||
);
|
||
|
||
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
|
||
let stat = self
|
||
.vfs
|
||
.stat(&full_path)
|
||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||
|
||
if stat.is_dir {
|
||
if !self.recursive {
|
||
return Err(anyhow!("Directory detected but -r flag not specified"));
|
||
}
|
||
self.send_directory(channel, &full_path)?;
|
||
} else {
|
||
self.send_file(channel, &full_path)?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// SCP Destination Mode(scp -t,接收文件)
|
||
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||
info!(
|
||
"SCP destination mode: receiving files to {}",
|
||
self.root_dir.display()
|
||
);
|
||
|
||
channel.write_all(&[0])?;
|
||
channel.flush()?;
|
||
|
||
let mut buffer = String::new();
|
||
|
||
loop {
|
||
buffer.clear();
|
||
|
||
let mut reader = std::io::BufReader::new(&mut *channel);
|
||
match reader.read_line(&mut buffer)? {
|
||
0 => break,
|
||
_ => {
|
||
let command = buffer.trim();
|
||
debug!("SCP command: {}", command);
|
||
|
||
match command.chars().next() {
|
||
Some('C') => self.handle_file_command(channel, command)?,
|
||
Some('D') => self.handle_directory_command(channel, command)?,
|
||
Some('E') => self.handle_end_directory(channel)?,
|
||
Some('T') => self.handle_time_command(channel, command)?,
|
||
Some('\0') => {
|
||
continue;
|
||
}
|
||
_ => {
|
||
warn!("Unknown SCP command: {}", command);
|
||
self.send_error(channel, &format!("Unknown command: {}", command))?;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 发送文件(参考OpenSSH scp.c: source())
|
||
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
|
||
let stat = self
|
||
.vfs
|
||
.stat(path)
|
||
.map_err(|e| anyhow!("stat error: {}", e))?;
|
||
let size = stat.size;
|
||
let filename = path.file_name().unwrap().to_string_lossy();
|
||
|
||
let command = format!("C0644 {} {}\n", size, filename);
|
||
channel.write_all(command.as_bytes())?;
|
||
channel.flush()?;
|
||
|
||
let mut ack = [0u8; 1];
|
||
channel.read_exact(&mut ack)?;
|
||
if ack[0] != 0 {
|
||
return Err(anyhow!("SCP file command rejected"));
|
||
}
|
||
|
||
let flags = OpenFlags::new().read();
|
||
let mut file = self
|
||
.vfs
|
||
.open_file(path, &flags)
|
||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||
|
||
let mut buffer = vec![0u8; 8192];
|
||
|
||
loop {
|
||
let n = file
|
||
.read(&mut buffer)
|
||
.map_err(|e| anyhow!("read error: {}", e))?;
|
||
if n == 0 {
|
||
break;
|
||
}
|
||
channel.write_all(&buffer[..n])?;
|
||
}
|
||
|
||
channel.flush()?;
|
||
|
||
channel.write_all(&[0])?;
|
||
channel.flush()?;
|
||
|
||
channel.read_exact(&mut ack)?;
|
||
if ack[0] != 0 {
|
||
return Err(anyhow!("SCP file transfer rejected"));
|
||
}
|
||
|
||
info!("SCP file sent: {} ({} bytes)", filename, size);
|
||
Ok(())
|
||
}
|
||
|
||
/// 发送目录(参考OpenSSH scp.c: source())
|
||
fn send_directory(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
|
||
let dirname = path.file_name().unwrap().to_string_lossy();
|
||
|
||
let command = format!("D0755 0 {}\n", dirname);
|
||
channel.write_all(command.as_bytes())?;
|
||
channel.flush()?;
|
||
|
||
let mut ack = [0u8; 1];
|
||
channel.read_exact(&mut ack)?;
|
||
if ack[0] != 0 {
|
||
return Err(anyhow!("SCP directory command rejected"));
|
||
}
|
||
|
||
let entries = self
|
||
.vfs
|
||
.read_dir(path)
|
||
.map_err(|e| anyhow!("read_dir error: {}", e))?;
|
||
|
||
for entry in &entries {
|
||
let entry_path = path.join(&entry.name);
|
||
|
||
if entry.stat.is_dir {
|
||
if self.recursive {
|
||
self.send_directory(channel, &entry_path)?;
|
||
}
|
||
} else {
|
||
self.send_file(channel, &entry_path)?;
|
||
}
|
||
}
|
||
|
||
channel.write_all("E\n".as_bytes())?;
|
||
channel.flush()?;
|
||
|
||
channel.read_exact(&mut ack)?;
|
||
if ack[0] != 0 {
|
||
return Err(anyhow!("SCP end directory rejected"));
|
||
}
|
||
|
||
info!("SCP directory sent: {}", dirname);
|
||
Ok(())
|
||
}
|
||
|
||
/// 处理文件命令(C0644 size filename)
|
||
fn handle_file_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
|
||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||
|
||
if parts.len() != 3 {
|
||
return self.send_error(channel, "Invalid file command format");
|
||
}
|
||
|
||
let mode_str = parts[0].trim_start_matches('C');
|
||
let size: u64 = parts[1].parse()?;
|
||
let filename = parts[2];
|
||
|
||
debug!(
|
||
"SCP receive file: mode={}, size={}, name={}",
|
||
mode_str, size, filename
|
||
);
|
||
|
||
if size > 1024 * 1024 * 1024 {
|
||
return self.send_error(channel, "File too large (max 1GB)");
|
||
}
|
||
|
||
let full_path = self.resolve_path(filename)?;
|
||
|
||
let flags = OpenFlags::new().write().create().truncate();
|
||
let mut file = self
|
||
.vfs
|
||
.open_file(&full_path, &flags)
|
||
.map_err(|e| anyhow!("open error: {}", e))?;
|
||
|
||
channel.write_all(&[0])?;
|
||
channel.flush()?;
|
||
|
||
let mut buffer = vec![0u8; 8192];
|
||
let mut remaining = size;
|
||
|
||
while remaining > 0 {
|
||
let to_read = std::cmp::min(buffer.len() as u64, remaining) as usize;
|
||
let n = channel.read(&mut buffer[..to_read])?;
|
||
if n == 0 {
|
||
break;
|
||
}
|
||
file.write_all(&buffer[..n])
|
||
.map_err(|e| anyhow!("write error: {}", e))?;
|
||
remaining -= n as u64;
|
||
}
|
||
|
||
file.flush().map_err(|e| anyhow!("flush error: {}", e))?;
|
||
|
||
// 设置文件权限
|
||
let mode_int: u32 = mode_str.parse()?;
|
||
if mode_int != 0 {
|
||
let mut set_stat = VfsStat::new();
|
||
set_stat.mode = mode_int;
|
||
self.vfs
|
||
.set_stat(&full_path, &set_stat)
|
||
.map_err(|e| anyhow!("set_stat error: {}", e))?;
|
||
}
|
||
|
||
let mut ack = [0u8; 1];
|
||
channel.read_exact(&mut ack)?;
|
||
|
||
channel.write_all(&[0])?;
|
||
channel.flush()?;
|
||
|
||
info!("SCP file received: {} ({} bytes)", filename, size);
|
||
Ok(())
|
||
}
|
||
|
||
/// 处理目录命令(D0755 0 dirname)
|
||
fn handle_directory_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
|
||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||
|
||
if parts.len() != 3 {
|
||
return self.send_error(channel, "Invalid directory command format");
|
||
}
|
||
|
||
if !self.recursive {
|
||
return self.send_error(channel, "Recursive flag not specified");
|
||
}
|
||
|
||
let mode_str = parts[0].trim_start_matches('D');
|
||
let dirname = parts[2];
|
||
|
||
debug!("SCP receive directory: mode={}, name={}", mode_str, dirname);
|
||
|
||
let full_path = self.resolve_path(dirname)?;
|
||
|
||
let mode_int: u32 = mode_str.parse()?;
|
||
self.vfs
|
||
.create_dir_all(&full_path, mode_int)
|
||
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
|
||
|
||
channel.write_all(&[0])?;
|
||
channel.flush()?;
|
||
|
||
info!("SCP directory created: {}", dirname);
|
||
Ok(())
|
||
}
|
||
|
||
/// 处理结束目录命令(E)
|
||
fn handle_end_directory(&self, channel: &mut dyn ReadWrite) -> Result<()> {
|
||
debug!("SCP end directory");
|
||
|
||
channel.write_all(&[0])?;
|
||
channel.flush()?;
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 处理时间命令(T mtime atime)
|
||
fn handle_time_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
|
||
if !self.preserve_times {
|
||
channel.write_all(&[0])?;
|
||
channel.flush()?;
|
||
return Ok(());
|
||
}
|
||
|
||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||
|
||
if parts.len() != 3 {
|
||
return self.send_error(channel, "Invalid time command format");
|
||
}
|
||
|
||
let mtime_secs: i64 = parts[1].parse()?;
|
||
let atime_secs: i64 = parts[2].parse()?;
|
||
|
||
debug!("SCP set times: mtime={}, atime={}", mtime_secs, atime_secs);
|
||
|
||
channel.write_all(&[0])?;
|
||
channel.flush()?;
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 发送错误消息
|
||
fn send_error(&self, channel: &mut dyn ReadWrite, message: &str) -> Result<()> {
|
||
let error_msg = format!("{}\n", message);
|
||
channel.write_all(error_msg.as_bytes())?;
|
||
channel.flush()?;
|
||
Err(anyhow!("SCP error: {}", message))
|
||
}
|
||
|
||
/// 路径解析(安全性检查)
|
||
pub fn resolve_path(&self, path: &str) -> Result<PathBuf> {
|
||
let full_path = self.root_dir.join(path);
|
||
|
||
let canonical_path = self
|
||
.vfs
|
||
.real_path(&full_path)
|
||
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
|
||
|
||
let root_canonical = self
|
||
.vfs
|
||
.real_path(&self.root_dir)
|
||
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
|
||
|
||
if !canonical_path.starts_with(&root_canonical) {
|
||
return Err(anyhow!("Path traversal attempt detected"));
|
||
}
|
||
|
||
Ok(canonical_path)
|
||
}
|
||
}
|
||
|
||
/// Object-safe combination of [`Read`] and [`Write`], used as the SCP channel type.
///
/// Every sized type that is both readable and writable gets it automatically.
pub trait ReadWrite: Read + Write {}

impl<T> ReadWrite for T where T: Read + Write {}

/// Phase 8: in-memory channel adapter for the SCP protocol.
///
/// Implements [`Read`] and [`Write`] so the SCP handler can exchange data with an SSH
/// channel. Bytes received from the channel are queued with
/// [`ChannelReadWrite::feed_input`] and handed out by `read`; `read` returns `Ok(0)`
/// once the queue is empty. Everything written is buffered until
/// [`ChannelReadWrite::drain_output`] collects it.
#[derive(Debug, Default)]
pub struct ChannelReadWrite {
    /// Queued inbound bytes; only `input_buffer[input_pos..]` is still unread.
    input_buffer: Vec<u8>,
    /// Outbound bytes written since the last `drain_output`.
    output_buffer: Vec<u8>,
    /// Read cursor into `input_buffer`.
    input_pos: usize,
}

impl ChannelReadWrite {
    /// Creates a channel whose unread input starts as `input_buffer`.
    pub fn new(input_buffer: Vec<u8>) -> Self {
        Self {
            input_buffer,
            output_buffer: Vec::new(),
            input_pos: 0,
        }
    }

    /// Appends `data` to the unread input.
    ///
    /// Bytes that were already read are discarded first, so a long-lived channel does
    /// not keep its entire input history in memory.
    pub fn feed_input(&mut self, data: &[u8]) {
        if self.input_pos > 0 {
            self.input_buffer.drain(..self.input_pos);
            self.input_pos = 0;
        }
        self.input_buffer.extend_from_slice(data);
    }

    /// Takes everything written so far, leaving the output buffer empty.
    pub fn drain_output(&mut self) -> Vec<u8> {
        // Move the buffer out instead of cloning it and then clearing the original.
        std::mem::take(&mut self.output_buffer)
    }

    /// Returns `true` while some fed input has not been read yet.
    pub fn has_remaining_input(&self) -> bool {
        self.input_pos < self.input_buffer.len()
    }
}

impl Read for ChannelReadWrite {
    /// Copies as many unread bytes as fit into `buf`; `Ok(0)` means nothing is queued.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let unread = &self.input_buffer[self.input_pos..];
        let count = unread.len().min(buf.len());
        buf[..count].copy_from_slice(&unread[..count]);
        self.input_pos += count;
        Ok(count)
    }
}

impl Write for ChannelReadWrite {
    /// Buffers all of `buf`; never fails and never writes partially.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.output_buffer.extend_from_slice(buf);
        Ok(buf.len())
    }

    /// No-op: output stays buffered until `drain_output` is called.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::vfs::local_fs::LocalFs;

    /// Handler over the local filesystem rooted at `/tmp`, with default flags.
    fn make_handler() -> ScpHandler {
        ScpHandler::new(PathBuf::from("/tmp"), Box::new(LocalFs::new()))
    }

    #[test]
    fn test_new_defaults() {
        let handler = make_handler();
        assert_eq!(handler.mode, ScpMode::Destination);
        assert!(!handler.recursive);
        assert!(!handler.preserve_times);
        assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
    }

    #[test]
    fn test_scp_command_parse() {
        let handler =
            ScpHandler::parse_scp_command("scp -t /tmp", Box::new(LocalFs::new())).unwrap();
        assert_eq!(handler.mode, ScpMode::Destination);
        assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
    }

    #[test]
    fn test_scp_recursive_parse() {
        let handler =
            ScpHandler::parse_scp_command("scp -r -t /tmp", Box::new(LocalFs::new())).unwrap();
        assert!(handler.recursive);
    }

    #[test]
    fn test_scp_source_parse() {
        let handler =
            ScpHandler::parse_scp_command("scp -f /tmp", Box::new(LocalFs::new())).unwrap();
        assert_eq!(handler.mode, ScpMode::Source);
    }

    #[test]
    fn test_scp_preserve_times_parse() {
        let handler =
            ScpHandler::parse_scp_command("scp -p -t /tmp", Box::new(LocalFs::new())).unwrap();
        assert!(handler.preserve_times);
    }

    #[test]
    fn test_scp_invalid_commands_rejected() {
        assert!(ScpHandler::parse_scp_command("ls -t /tmp", Box::new(LocalFs::new())).is_err());
        assert!(ScpHandler::parse_scp_command("scp", Box::new(LocalFs::new())).is_err());
    }

    #[test]
    fn test_channel_read_write_round_trip() {
        let mut channel = ChannelReadWrite::new(b"abc".to_vec());
        let mut buf = [0u8; 2];
        assert_eq!(channel.read(&mut buf).unwrap(), 2);
        assert_eq!(&buf, b"ab");
        assert!(channel.has_remaining_input());
        channel.write_all(b"ok").unwrap();
        assert_eq!(channel.drain_output(), b"ok".to_vec());
        assert!(channel.drain_output().is_empty());
    }
}