Files
markbase/markbase-core/src/ssh_server/scp_handler.rs
Warren 7eb528d35f
Some checks failed
Test / test (push) Has been cancelled
Test / build (push) Has been cancelled
SMB Server Phase 2: VFS backend build fix + integration test
- Add VfsFile: Send supertrait for Mutex compatibility
- Fix SmbServerCommand: struct → Subcommand enum with Start variant
- Fix tracing_subscriber::init() → try_init() to avoid panic when
  logger already initialized
- Fix CLI subcommand name: smb-server → smb-start (flatten naming)
- Add #[command(name = "smb-start")] for CLI disambiguation
- Fix unused variable warnings (smb_fs.rs, smb_server_backend.rs)
- Remove unused VfsFile imports (webdav.rs, scp_handler.rs)
- Integration test: Docker smbclient verified (list, upload, read)
2026-06-20 19:42:29 +08:00

490 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// SCP protocol implementation (Phase 8).
// Based on OpenSSH's scp.c.
use crate::vfs::open_flags::OpenFlags;
use crate::vfs::{VfsBackend, VfsStat};
use anyhow::{anyhow, Result};
use log::{debug, info, warn};
use std::io::{BufRead, Read, Write};
use std::path::{Path, PathBuf};
/// Server side of the classic SCP (rcp-style) protocol, modeled on OpenSSH `scp.c`.
///
/// A handler is configured from one `scp` command line (see `parse_scp_command`)
/// and performs all file access through the injected [`VfsBackend`].
pub struct ScpHandler {
    /// Path named on the `scp` command line; received paths are checked against it.
    pub root_dir: PathBuf,
    /// Storage backend used for every file-system operation.
    pub vfs: Box<dyn VfsBackend>,
    /// Transfer direction selected by `-f` / `-t`.
    mode: ScpMode,
    /// Set by `-r`: directories may be sent or received.
    recursive: bool,
    /// Set by `-p`: the peer sends `T` (timestamp) records.
    preserve_times: bool,
}
/// Direction of an SCP transfer, as seen from this server.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScpMode {
    /// `scp -f` ("from"): the server sends files to the client.
    Source,
    /// `scp -t` ("to"): the server receives files from the client.
    Destination,
}
impl ScpHandler {
pub fn new(root_dir: PathBuf, vfs: Box<dyn VfsBackend>) -> Self {
Self {
root_dir,
mode: ScpMode::Destination,
recursive: false,
preserve_times: false,
vfs,
}
}
/// 解析SCP命令参考OpenSSH scp.c: parse_command()
pub fn parse_scp_command(command: &str, vfs: Box<dyn VfsBackend>) -> Result<Self> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() < 2 || parts[0] != "scp" {
return Err(anyhow!("Invalid SCP command: {}", command));
}
let mut handler = ScpHandler::new(PathBuf::from("/tmp"), vfs);
for part in &parts[1..] {
match part {
&"-f" => handler.mode = ScpMode::Source,
&"-t" => handler.mode = ScpMode::Destination,
&"-r" => handler.recursive = true,
&"-p" => handler.preserve_times = true,
path if !path.starts_with('-') => {
handler.root_dir = PathBuf::from(path);
}
_ => warn!("Unknown SCP flag: {}", part),
}
}
Ok(handler)
}
/// 处理SCP传输参考OpenSSH scp.c: source() / sink()
pub fn handle_scp(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
match self.mode {
ScpMode::Source => self.handle_source_mode(channel),
ScpMode::Destination => self.handle_destination_mode(channel),
}
}
/// SCP Source Modescp -f发送文件
fn handle_source_mode(&self, channel: &mut dyn ReadWrite) -> Result<()> {
info!(
"SCP source mode: sending files from {}",
self.root_dir.display()
);
let full_path = self.resolve_path(&self.root_dir.to_string_lossy())?;
let stat = self
.vfs
.stat(&full_path)
.map_err(|e| anyhow!("stat error: {}", e))?;
if stat.is_dir {
if !self.recursive {
return Err(anyhow!("Directory detected but -r flag not specified"));
}
self.send_directory(channel, &full_path)?;
} else {
self.send_file(channel, &full_path)?;
}
Ok(())
}
/// SCP Destination Modescp -t接收文件
fn handle_destination_mode(&mut self, channel: &mut dyn ReadWrite) -> Result<()> {
info!(
"SCP destination mode: receiving files to {}",
self.root_dir.display()
);
channel.write_all(&[0])?;
channel.flush()?;
let mut buffer = String::new();
loop {
buffer.clear();
let mut reader = std::io::BufReader::new(&mut *channel);
match reader.read_line(&mut buffer)? {
0 => break,
_ => {
let command = buffer.trim();
debug!("SCP command: {}", command);
match command.chars().next() {
Some('C') => self.handle_file_command(channel, command)?,
Some('D') => self.handle_directory_command(channel, command)?,
Some('E') => self.handle_end_directory(channel)?,
Some('T') => self.handle_time_command(channel, command)?,
Some('\0') => {
continue;
}
_ => {
warn!("Unknown SCP command: {}", command);
self.send_error(channel, &format!("Unknown command: {}", command))?;
}
}
}
}
}
Ok(())
}
/// 发送文件参考OpenSSH scp.c: source()
fn send_file(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
let stat = self
.vfs
.stat(path)
.map_err(|e| anyhow!("stat error: {}", e))?;
let size = stat.size;
let filename = path.file_name().unwrap().to_string_lossy();
let command = format!("C0644 {} {}\n", size, filename);
channel.write_all(command.as_bytes())?;
channel.flush()?;
let mut ack = [0u8; 1];
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP file command rejected"));
}
let flags = OpenFlags::new().read();
let mut file = self
.vfs
.open_file(path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
let mut buffer = vec![0u8; 8192];
loop {
let n = file
.read(&mut buffer)
.map_err(|e| anyhow!("read error: {}", e))?;
if n == 0 {
break;
}
channel.write_all(&buffer[..n])?;
}
channel.flush()?;
channel.write_all(&[0])?;
channel.flush()?;
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP file transfer rejected"));
}
info!("SCP file sent: {} ({} bytes)", filename, size);
Ok(())
}
/// 发送目录参考OpenSSH scp.c: source()
fn send_directory(&self, channel: &mut dyn ReadWrite, path: &Path) -> Result<()> {
let dirname = path.file_name().unwrap().to_string_lossy();
let command = format!("D0755 0 {}\n", dirname);
channel.write_all(command.as_bytes())?;
channel.flush()?;
let mut ack = [0u8; 1];
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP directory command rejected"));
}
let entries = self
.vfs
.read_dir(path)
.map_err(|e| anyhow!("read_dir error: {}", e))?;
for entry in &entries {
let entry_path = path.join(&entry.name);
if entry.stat.is_dir {
if self.recursive {
self.send_directory(channel, &entry_path)?;
}
} else {
self.send_file(channel, &entry_path)?;
}
}
channel.write_all("E\n".as_bytes())?;
channel.flush()?;
channel.read_exact(&mut ack)?;
if ack[0] != 0 {
return Err(anyhow!("SCP end directory rejected"));
}
info!("SCP directory sent: {}", dirname);
Ok(())
}
/// 处理文件命令C0644 size filename
fn handle_file_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() != 3 {
return self.send_error(channel, "Invalid file command format");
}
let mode_str = parts[0].trim_start_matches('C');
let size: u64 = parts[1].parse()?;
let filename = parts[2];
debug!(
"SCP receive file: mode={}, size={}, name={}",
mode_str, size, filename
);
if size > 1024 * 1024 * 1024 {
return self.send_error(channel, "File too large (max 1GB)");
}
let full_path = self.resolve_path(filename)?;
let flags = OpenFlags::new().write().create().truncate();
let mut file = self
.vfs
.open_file(&full_path, &flags)
.map_err(|e| anyhow!("open error: {}", e))?;
channel.write_all(&[0])?;
channel.flush()?;
let mut buffer = vec![0u8; 8192];
let mut remaining = size;
while remaining > 0 {
let to_read = std::cmp::min(buffer.len() as u64, remaining) as usize;
let n = channel.read(&mut buffer[..to_read])?;
if n == 0 {
break;
}
file.write_all(&buffer[..n])
.map_err(|e| anyhow!("write error: {}", e))?;
remaining -= n as u64;
}
file.flush().map_err(|e| anyhow!("flush error: {}", e))?;
// 设置文件权限
let mode_int: u32 = mode_str.parse()?;
if mode_int != 0 {
let mut set_stat = VfsStat::new();
set_stat.mode = mode_int;
self.vfs
.set_stat(&full_path, &set_stat)
.map_err(|e| anyhow!("set_stat error: {}", e))?;
}
let mut ack = [0u8; 1];
channel.read_exact(&mut ack)?;
channel.write_all(&[0])?;
channel.flush()?;
info!("SCP file received: {} ({} bytes)", filename, size);
Ok(())
}
/// 处理目录命令D0755 0 dirname
fn handle_directory_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() != 3 {
return self.send_error(channel, "Invalid directory command format");
}
if !self.recursive {
return self.send_error(channel, "Recursive flag not specified");
}
let mode_str = parts[0].trim_start_matches('D');
let dirname = parts[2];
debug!("SCP receive directory: mode={}, name={}", mode_str, dirname);
let full_path = self.resolve_path(dirname)?;
let mode_int: u32 = mode_str.parse()?;
self.vfs
.create_dir_all(&full_path, mode_int)
.map_err(|e| anyhow!("create_dir_all error: {}", e))?;
channel.write_all(&[0])?;
channel.flush()?;
info!("SCP directory created: {}", dirname);
Ok(())
}
/// 处理结束目录命令E
fn handle_end_directory(&self, channel: &mut dyn ReadWrite) -> Result<()> {
debug!("SCP end directory");
channel.write_all(&[0])?;
channel.flush()?;
Ok(())
}
/// 处理时间命令T mtime atime
fn handle_time_command(&self, channel: &mut dyn ReadWrite, command: &str) -> Result<()> {
if !self.preserve_times {
channel.write_all(&[0])?;
channel.flush()?;
return Ok(());
}
let parts: Vec<&str> = command.split_whitespace().collect();
if parts.len() != 3 {
return self.send_error(channel, "Invalid time command format");
}
let mtime_secs: i64 = parts[1].parse()?;
let atime_secs: i64 = parts[2].parse()?;
debug!("SCP set times: mtime={}, atime={}", mtime_secs, atime_secs);
channel.write_all(&[0])?;
channel.flush()?;
Ok(())
}
/// 发送错误消息
fn send_error(&self, channel: &mut dyn ReadWrite, message: &str) -> Result<()> {
let error_msg = format!("{}\n", message);
channel.write_all(error_msg.as_bytes())?;
channel.flush()?;
Err(anyhow!("SCP error: {}", message))
}
/// 路径解析(安全性检查)
pub fn resolve_path(&self, path: &str) -> Result<PathBuf> {
let full_path = self.root_dir.join(path);
let canonical_path = self
.vfs
.real_path(&full_path)
.map_err(|e| anyhow!("Path resolution error: {}", e))?;
let root_canonical = self
.vfs
.real_path(&self.root_dir)
.map_err(|e| anyhow!("Root path resolution error: {}", e))?;
if !canonical_path.starts_with(&root_canonical) {
return Err(anyhow!("Path traversal attempt detected"));
}
Ok(canonical_path)
}
}
/// Object-safe union of [`Read`] and [`Write`], so that one `&mut dyn ReadWrite`
/// can carry an SCP channel in both directions.
pub trait ReadWrite: Read + Write {}

// Any type that can both read and write gets `ReadWrite` automatically.
impl<T> ReadWrite for T where T: Read + Write {}
/// In-memory stand-in for an SSH channel, used to exchange data with `ScpHandler`.
///
/// Bytes from the client are queued with [`feed_input`](Self::feed_input) and
/// consumed through [`Read`]. Everything written through [`Write`] is collected
/// and handed back by [`drain_output`](Self::drain_output). Reads never block:
/// once the queued input is exhausted, `read` returns `Ok(0)` (EOF).
pub struct ChannelReadWrite {
    /// Queued client bytes; the prefix before `input_pos` has already been read.
    input_buffer: Vec<u8>,
    /// Bytes written and not yet drained.
    output_buffer: Vec<u8>,
    /// Index of the next unread byte in `input_buffer`.
    input_pos: usize,
}

impl ChannelReadWrite {
    /// Creates a channel whose pending input is `input_buffer`.
    pub fn new(input_buffer: Vec<u8>) -> Self {
        Self {
            input_buffer,
            output_buffer: Vec::new(),
            input_pos: 0,
        }
    }

    /// Appends `data` to the pending input.
    ///
    /// Already-consumed bytes are discarded first, so the buffer holds only
    /// unread input instead of growing with every byte ever received.
    pub fn feed_input(&mut self, data: &[u8]) {
        if self.input_pos > 0 {
            self.input_buffer.drain(..self.input_pos);
            self.input_pos = 0;
        }
        self.input_buffer.extend_from_slice(data);
    }

    /// Returns everything written since the last call and clears it.
    pub fn drain_output(&mut self) -> Vec<u8> {
        // `take` hands over the buffer itself instead of copying it.
        std::mem::take(&mut self.output_buffer)
    }

    /// Returns `true` while queued input remains unread.
    pub fn has_remaining_input(&self) -> bool {
        self.input_pos < self.input_buffer.len()
    }
}

impl Read for ChannelReadWrite {
    /// Copies as many pending input bytes as fit into `buf`; `Ok(0)` means EOF.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let pending = &self.input_buffer[self.input_pos..];
        let n = pending.len().min(buf.len());
        buf[..n].copy_from_slice(&pending[..n]);
        self.input_pos += n;
        Ok(n)
    }
}

impl Write for ChannelReadWrite {
    /// Appends `buf` to the output buffer; always accepts all of it.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.output_buffer.extend_from_slice(buf);
        Ok(buf.len())
    }

    /// No-op: output stays buffered until `drain_output` is called.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vfs::local_fs::LocalFs;

    /// Builds a handler over the local file system rooted at `/tmp`.
    fn make_handler() -> ScpHandler {
        ScpHandler::new(PathBuf::from("/tmp"), Box::new(LocalFs::new()))
    }

    /// Parses `command` with a fresh local-file-system backend.
    fn parse(command: &str) -> Result<ScpHandler> {
        ScpHandler::parse_scp_command(command, Box::new(LocalFs::new()))
    }

    #[test]
    fn test_new_defaults() {
        let handler = make_handler();
        assert_eq!(handler.mode, ScpMode::Destination);
        assert!(!handler.recursive);
        assert!(!handler.preserve_times);
        assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
    }

    #[test]
    fn test_scp_command_parse() {
        let handler = parse("scp -t /tmp").unwrap();
        assert_eq!(handler.mode, ScpMode::Destination);
        assert_eq!(handler.root_dir, PathBuf::from("/tmp"));
    }

    #[test]
    fn test_scp_recursive_parse() {
        let handler = parse("scp -r -t /tmp").unwrap();
        assert!(handler.recursive);
    }

    #[test]
    fn test_scp_source_parse() {
        let handler = parse("scp -f /tmp").unwrap();
        assert_eq!(handler.mode, ScpMode::Source);
    }

    #[test]
    fn test_scp_preserve_times_parse() {
        let handler = parse("scp -p -t /tmp").unwrap();
        assert!(handler.preserve_times);
        assert_eq!(handler.mode, ScpMode::Destination);
    }

    #[test]
    fn test_invalid_command_rejected() {
        assert!(parse("ls -la").is_err());
        assert!(parse("scp").is_err());
    }

    #[test]
    fn test_channel_read_write_buffers() {
        let mut channel = ChannelReadWrite::new(b"ab".to_vec());
        let mut byte = [0u8; 1];
        assert_eq!(channel.read(&mut byte).unwrap(), 1);
        assert_eq!(byte, [b'a']);
        channel.feed_input(b"c");
        let mut rest = Vec::new();
        channel.read_to_end(&mut rest).unwrap();
        assert_eq!(rest, b"bc".to_vec());
        assert!(!channel.has_remaining_input());
        channel.write_all(b"out").unwrap();
        assert_eq!(channel.drain_output(), b"out".to_vec());
        assert!(channel.drain_output().is_empty());
    }
}