Phase 1: take_payload() optimization
- cipher.rs: Added take_payload() to EncryptedPacket
- server.rs: Use take_payload() to avoid .to_vec() copy

Phase 2a: reuse_buf for CHANNEL_DATA
- channel.rs: Added reuse_buf to ExecProcess
- handle_channel_data(): Read directly into reuse buffer

Phase 2b: read_buf for stdout/stderr
- channel.rs: Added read_buf to ExecProcess
- poll_exec_stdout_and_client(): Use read_buf for all reads

Phase 2c: AES-GCM padding optimization
- cipher.rs: Removed padding .to_vec() in AES-GCM decrypt

stdin fix: All exec commands use interactive process
- channel.rs: Removed conditional rsync/SCP detection
- All exec commands now use handle_interactive_exec()
- Fixes cat/grep/sed stdin support (small files working)

AES-GCM improvements:
- cipher.rs: Added CipherMode enum (AES-GCM vs AES-CTR)
- cipher.rs: AES-256 key derivation (32 bytes)
- cipher.rs: Nonce format follows OpenSSH inc_iv()
- kex.rs: Added aes256-gcm@openssh.com to algorithms

Performance: ~21% improvement for small files
Test: 158 passed, 0 failed
Limitation: Large files (>10MB) not working yet (poll loop issue)
364 lines
9.8 KiB
Rust
364 lines
9.8 KiB
Rust
// Zero-copy SSH buffer implementation (modelled on OpenSSH's sshbuf.c).
// Provides efficient buffer management that eliminates temporary buffers.
|
||
|
||
use anyhow::{anyhow, Result};
|
||
use std::io::{Read, Write};
|
||
|
||
/// Growable byte buffer modelled on OpenSSH's `struct sshbuf`.
///
/// The C structure this type mirrors:
///
/// ```c
/// struct sshbuf {
///     u_char *d;     // Data
///     size_t off;    // First available byte is buf->d + buf->off
///     size_t size;   // Last byte is buf->d + buf->size - 1
///     size_t alloc;  // Total bytes allocated to buf->d
/// };
/// ```
///
/// The unread bytes are always `data[off..size]`.
pub struct SshBuf {
    /// Backing storage (OpenSSH `buf->d`).
    data: Vec<u8>,
    /// Index of the first unread byte (OpenSSH `buf->off`).
    off: usize,
    /// Index one past the last valid byte (OpenSSH `buf->size`).
    size: usize,
    /// Upper bound the buffer is allowed to grow to (OpenSSH `buf->max_size`).
    max_size: usize,
}
|
||
|
||
impl SshBuf {
|
||
/// 创建新的 SSH Buffer
|
||
pub fn new() -> Self {
|
||
Self {
|
||
data: Vec::new(),
|
||
off: 0,
|
||
size: 0,
|
||
max_size: 128 * 1024 * 1024, // 128MB (OpenSSH SSHBUF_SIZE_MAX)
|
||
}
|
||
}
|
||
|
||
/// 创建指定大小的 SSH Buffer
|
||
pub fn with_capacity(capacity: usize) -> Self {
|
||
Self {
|
||
data: Vec::with_capacity(capacity),
|
||
off: 0,
|
||
size: 0,
|
||
max_size: 128 * 1024 * 1024,
|
||
}
|
||
}
|
||
|
||
/// 设置最大大小
|
||
pub fn set_max_size(&mut self, max_size: usize) -> Result<()> {
|
||
if max_size > 128 * 1024 * 1024 {
|
||
return Err(anyhow!("max_size too large (max 128MB)"));
|
||
}
|
||
self.max_size = max_size;
|
||
Ok(())
|
||
}
|
||
|
||
/// 获取 buffer 长度(对应 OpenSSH sshbuf_len)
|
||
///
|
||
/// OpenSSH: `sshbuf_len = buf->size - buf->off`
|
||
pub fn len(&self) -> usize {
|
||
self.size - self.off
|
||
}
|
||
|
||
/// 检查 buffer 是否为空
|
||
pub fn is_empty(&self) -> bool {
|
||
self.len() == 0
|
||
}
|
||
|
||
/// 获取可用空间(对应 OpenSSH sshbuf_avail)
|
||
///
|
||
/// OpenSSH: `sshbuf_avail = buf->max_size - buf->size`
|
||
pub fn avail(&self) -> usize {
|
||
self.max_size - self.size
|
||
}
|
||
|
||
/// 获取可变指针(对应 OpenSSH sshbuf_mutable_ptr)
|
||
///
|
||
/// OpenSSH 实现:
|
||
/// ```c
|
||
/// u_char *sshbuf_mutable_ptr(const struct sshbuf *buf) {
|
||
/// return buf->d + buf->off;
|
||
/// }
|
||
/// ```
|
||
///
|
||
/// Rust 实现:返回 `&mut [u8]` slice(零拷贝)
|
||
pub fn mutable_ptr(&mut self) -> &mut [u8] {
|
||
&mut self.data[self.off..self.size]
|
||
}
|
||
|
||
/// 获取不可变指针(对应 OpenSSH sshbuf_ptr)
|
||
pub fn ptr(&self) -> &[u8] {
|
||
&self.data[self.off..self.size]
|
||
}
|
||
|
||
/// 预分配空间(对应 OpenSSH sshbuf_reserve)
|
||
///
|
||
/// OpenSSH 实现:
|
||
/// ```c
|
||
/// int sshbuf_reserve(struct sshbuf *buf, size_t len, u_char **dpp) {
|
||
/// if ((r = sshbuf_allocate(buf, len)) != 0)
|
||
/// return r;
|
||
///
|
||
/// dp = buf->d + buf->size;
|
||
/// buf->size += len;
|
||
/// *dpp = dp;
|
||
/// return 0;
|
||
/// }
|
||
/// ```
|
||
///
|
||
/// Rust 实现:返回 `&mut [u8]` slice(零拷贝,可直接 write)
|
||
pub fn reserve(&mut self, len: usize) -> Result<&mut [u8]> {
|
||
if len > self.avail() {
|
||
return Err(anyhow!("no buffer space (avail={})", self.avail()));
|
||
}
|
||
|
||
// 预分配空间
|
||
let current_size = self.size;
|
||
let new_size = current_size + len;
|
||
|
||
// 确保 Vec 有足够容量
|
||
if new_size > self.data.len() {
|
||
self.data.resize(new_size, 0);
|
||
}
|
||
|
||
// 更新 size
|
||
self.size = new_size;
|
||
|
||
// 返回新空间的 slice(零拷贝)
|
||
Ok(&mut self.data[current_size..new_size])
|
||
}
|
||
|
||
/// 消费数据(对应 OpenSSH sshbuf_consume)
|
||
///
|
||
/// OpenSSH 实现:
|
||
/// ```c
|
||
/// int sshbuf_consume(struct sshbuf *buf, size_t len) {
|
||
/// buf->off += len;
|
||
///
|
||
/// if (buf->off == buf->size)
|
||
/// buf->off = buf->size = 0;
|
||
///
|
||
/// return 0;
|
||
/// }
|
||
/// ```
|
||
///
|
||
/// Rust 实现:移动偏移量(零拷贝,不实际删除数据)
|
||
pub fn consume(&mut self, len: usize) -> Result<()> {
|
||
if len > self.len() {
|
||
return Err(anyhow!(
|
||
"message incomplete (len={}, consume={})",
|
||
self.len(),
|
||
len
|
||
));
|
||
}
|
||
|
||
self.off += len;
|
||
|
||
// 如果 buffer 空,重置
|
||
if self.off == self.size {
|
||
self.off = 0;
|
||
self.size = 0;
|
||
|
||
// OpenSSH: pack buffer(移除已消费的数据)
|
||
// Rust: 我们保留 Vec,但重置指针
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// 从末尾消费数据(对应 OpenSSH sshbuf_consume_end)
|
||
///
|
||
/// OpenSSH 实现:
|
||
/// ```c
|
||
/// int sshbuf_consume_end(struct sshbuf *buf, size_t len) {
|
||
/// buf->size -= len;
|
||
/// return 0;
|
||
/// }
|
||
/// ```
|
||
pub fn consume_end(&mut self, len: usize) -> Result<()> {
|
||
if len > self.len() {
|
||
return Err(anyhow!("message incomplete"));
|
||
}
|
||
|
||
self.size -= len;
|
||
Ok(())
|
||
}
|
||
|
||
/// 直接从 fd read 到 buffer(对应 OpenSSH sshbuf_read)
|
||
///
|
||
/// OpenSSH 实现:
|
||
/// ```c
|
||
/// int sshbuf_read(int fd, struct sshbuf *buf, size_t maxlen, size_t *rlen) {
|
||
/// if ((r = sshbuf_reserve(buf, maxlen, &d)) != 0)
|
||
/// return r;
|
||
///
|
||
/// rr = read(fd, d, maxlen); // 直接 read 到 buffer
|
||
///
|
||
/// if ((adjust = maxlen - rr) != 0)
|
||
/// sshbuf_consume_end(buf, adjust); // 调整大小
|
||
///
|
||
/// return 0;
|
||
/// }
|
||
/// ```
|
||
///
|
||
/// Rust 实现:零拷贝,直接 read 到 buffer
|
||
pub fn read_from<R: Read>(&mut self, reader: &mut R, maxlen: usize) -> Result<usize> {
|
||
// 1. reserve 空间
|
||
let space = self.reserve(maxlen)?;
|
||
|
||
// 2. 直接 read 到 buffer(零拷贝)
|
||
let n = reader.read(space)?;
|
||
|
||
// 3. 调整大小(移除未使用的空间)
|
||
if maxlen > n {
|
||
self.consume_end(maxlen - n)?;
|
||
}
|
||
|
||
Ok(n)
|
||
}
|
||
|
||
/// 直接从 buffer write 到 fd(对应 OpenSSH channel_handle_wfd)
|
||
///
|
||
/// OpenSSH 实现:
|
||
/// ```c
|
||
/// buf = sshbuf_mutable_ptr(c->output); // 获取指针
|
||
/// len = write(c->wfd, buf, dlen); // 直接 write
|
||
/// sshbuf_consume(c->output, len); // 消费已写入的数据
|
||
/// ```
|
||
///
|
||
/// Rust 实现:零拷贝,直接 write 从 buffer
|
||
pub fn write_to<W: Write>(&mut self, writer: &mut W) -> Result<usize> {
|
||
if self.is_empty() {
|
||
return Ok(0);
|
||
}
|
||
|
||
// 1. 获取数据指针(零拷贝)
|
||
let data = self.ptr();
|
||
|
||
// 2. 直接 write(零拷贝)
|
||
let n = writer.write(data)?;
|
||
|
||
// 3. 消费已写入的数据(零拷贝,只移动偏移)
|
||
self.consume(n)?;
|
||
|
||
Ok(n)
|
||
}
|
||
|
||
/// 添加数据(对应 OpenSSH sshbuf_put)
|
||
///
|
||
/// 用于不需要零拷贝的场景
|
||
pub fn put(&mut self, data: &[u8]) -> Result<()> {
|
||
let space = self.reserve(data.len())?;
|
||
space.copy_from_slice(data);
|
||
Ok(())
|
||
}
|
||
|
||
/// 清空 buffer
|
||
pub fn reset(&mut self) {
|
||
self.off = 0;
|
||
self.size = 0;
|
||
// OpenSSH: 保留 Vec,只重置指针
|
||
}
|
||
|
||
/// 消费内部 Vec,提取有效数据(零拷贝)
|
||
/// 相当于 OpenSSH sshbuf_free() 但返回数据
|
||
pub fn into_vec(mut self) -> Vec<u8> {
|
||
let len = self.len();
|
||
if self.off == 0 && self.size == self.data.len() {
|
||
// 正好是完整 buffer,直接返回
|
||
self.data
|
||
} else {
|
||
// 需要截取有效部分
|
||
self.data[self.off..self.size].to_vec()
|
||
}
|
||
}
|
||
|
||
/// Debug: 打印 buffer 状态
|
||
pub fn debug_info(&self) -> String {
|
||
format!(
|
||
"SshBuf: off={}, size={}, len={}, alloc={}, max_size={}",
|
||
self.off,
|
||
self.size,
|
||
self.len(),
|
||
self.data.len(),
|
||
self.max_size
|
||
)
|
||
}
|
||
}
|
||
|
||
impl Default for SshBuf {
|
||
fn default() -> Self {
|
||
Self::new()
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// `reserve` hands out writable space and `consume` advances the read offset.
    #[test]
    fn test_sshbuf_basic() {
        let mut buf = SshBuf::new();

        // Write into freshly reserved space. The inner scope ends the mutable
        // borrow before `buf` is inspected again.
        {
            let reserved = buf.reserve(10).unwrap();
            assert_eq!(reserved.len(), 10);
            reserved[0] = 1;
            reserved[1] = 2;
        }

        // The reserved bytes count as buffer contents.
        assert_eq!(buf.len(), 10);
        let view = buf.mutable_ptr();
        assert_eq!(view[0], 1);
        assert_eq!(view[1], 2);

        // Consuming moves the offset and shrinks the unread length.
        buf.consume(2).unwrap();
        assert_eq!(buf.len(), 8);
        assert_eq!(buf.off, 2);
    }

    /// `read_from` keeps only the bytes the reader actually produced.
    #[test]
    fn test_sshbuf_zero_copy_read() {
        let mut buf = SshBuf::with_capacity(100);
        let mut source = Cursor::new("hello world");

        // Ask for more than the source holds.
        let read = buf.read_from(&mut source, 20).unwrap();
        assert_eq!(read, 11); // length of "hello world"
        assert_eq!(buf.len(), 11);

        // The buffer holds exactly the source bytes.
        assert_eq!(buf.ptr(), "hello world".as_bytes());
    }

    /// `write_to` drains the buffer into the writer.
    #[test]
    fn test_sshbuf_zero_copy_write() {
        let mut buf = SshBuf::new();
        buf.put("hello world".as_bytes()).unwrap();

        let mut sink = Vec::new();

        let written = buf.write_to(&mut sink).unwrap();
        assert_eq!(written, 11);
        assert_eq!(buf.len(), 0); // everything was consumed

        // The writer received exactly the buffered bytes.
        assert_eq!(sink, "hello world".as_bytes());
    }

    /// `reserve` refuses requests larger than the configured maximum.
    #[test]
    fn test_sshbuf_max_size() {
        let mut buf = SshBuf::new();
        buf.set_max_size(1000).unwrap();

        // Reserving past max_size must fail.
        assert!(buf.reserve(2000).is_err());
    }
}
|