- Fix trailing whitespace in kex.rs and s3.rs - Add missing KexProposal import in kex_complete.rs - Auto-fix clippy warnings across all crates - All 153 tests pass
140 lines
4.2 KiB
Rust
140 lines
4.2 KiB
Rust
// SSH版本交换实现
|
||
// 参考OpenSSH sshd.c: ssh_exchange_identification()
|
||
|
||
use anyhow::Result;
|
||
use log::{debug, info};
|
||
use std::io::{Read, Write};
|
||
|
||
/// SSH identification string this server announces during version exchange.
///
/// It has the `SSH-2.0-<softwareversion>` shape and carries no line
/// terminator; `send_version` appends `\r\n` when writing it to the stream.
|
||
pub const SSH_VERSION: &str = "SSH-2.0-MarkBaseSSH_1.0";
|
||
|
||
/// Handler for the SSH identification-string (version) exchange.
///
/// The type is a unit struct and holds no data. The standard traits are
/// derived so it can be logged, copied and default-constructed freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct VersionExchange;
|
||
|
||
impl VersionExchange {
|
||
/// 执行版本交换(服务器端)
|
||
pub fn exchange<T: Read + Write>(stream: &mut T) -> Result<String> {
|
||
info!("Starting SSH version exchange");
|
||
|
||
// 1. 发送服务器版本
|
||
Self::send_version(stream)?;
|
||
|
||
// 2. 接收客户端版本
|
||
let client_version = Self::receive_version(stream)?;
|
||
|
||
info!(
|
||
"Version exchange completed: server={}, client={}",
|
||
SSH_VERSION, client_version
|
||
);
|
||
Ok(client_version)
|
||
}
|
||
|
||
/// 发送服务器版本(参考OpenSSH ssh_exchange_identification)
|
||
fn send_version<T: Write>(stream: &mut T) -> Result<()> {
|
||
let version_line = format!("{}\r\n", SSH_VERSION);
|
||
stream.write_all(version_line.as_bytes())?;
|
||
stream.flush()?;
|
||
|
||
debug!("Sent version: {}", SSH_VERSION);
|
||
Ok(())
|
||
}
|
||
|
||
/// 接收客户端版本(参考OpenSSH ssh_exchange_identification)
|
||
fn receive_version<T: Read>(stream: &mut T) -> Result<String> {
|
||
let mut buffer = Vec::new();
|
||
let mut byte = [0u8; 1];
|
||
|
||
// 读取直到遇到'\n'(参考OpenSSH实现)
|
||
loop {
|
||
stream.read_exact(&mut byte)?;
|
||
|
||
// OpenSSH兼容性处理:跳过前导空行和调试信息
|
||
if buffer.is_empty() && byte[0] == b'\n' {
|
||
continue; // 跳过空行
|
||
}
|
||
|
||
// 调试信息行(以'#'开头),跳过
|
||
if buffer.is_empty() && byte[0] == b'#' {
|
||
// 读取整行调试信息
|
||
while byte[0] != b'\n' {
|
||
stream.read_exact(&mut byte)?;
|
||
}
|
||
buffer.clear();
|
||
continue;
|
||
}
|
||
|
||
buffer.push(byte[0]);
|
||
|
||
// 遇到'\n'结束
|
||
if byte[0] == b'\n' {
|
||
break;
|
||
}
|
||
|
||
// 缓冲区溢出保护(OpenSSH限制:255字节)
|
||
if buffer.len() > 255 {
|
||
return Err(anyhow::anyhow!("Version string too long"));
|
||
}
|
||
}
|
||
|
||
// 解析版本字符串
|
||
let version_line = String::from_utf8(buffer)?;
|
||
let version = version_line.trim().trim_matches('\r');
|
||
|
||
// 验证版本格式(SSH-2.0-*)
|
||
if !version.starts_with("SSH-2.0-") {
|
||
return Err(anyhow::anyhow!("Invalid SSH version: {}", version));
|
||
}
|
||
|
||
debug!("Received version: {}", version);
|
||
Ok(version.to_string())
|
||
}
|
||
|
||
/// 解析客户端版本信息(兼容性检查)
|
||
pub fn parse_client_version(version: &str) -> Result<ClientVersionInfo> {
|
||
// 格式:SSH-protoversion-softwareversion SP comments
|
||
let parts: Vec<&str> = version.split_whitespace().collect();
|
||
|
||
let main_part = parts.first().map_or(version, |v| v);
|
||
let dash_parts: Vec<&str> = main_part.split('-').collect();
|
||
|
||
if dash_parts.len() < 3 {
|
||
return Err(anyhow::anyhow!("Invalid version format: {}", version));
|
||
}
|
||
|
||
let proto_version = dash_parts.get(1).map_or("2.0", |v| v);
|
||
let software_version = dash_parts.get(2).map_or("unknown", |v| v);
|
||
let comments = parts.get(1).map(|s| s.to_string());
|
||
|
||
Ok(ClientVersionInfo {
|
||
proto_version: proto_version.to_string(),
|
||
software_version: software_version.to_string(),
|
||
comments,
|
||
})
|
||
}
|
||
}
|
||
|
||
/// Components of a client's SSH identification string
/// (`SSH-protoversion-softwareversion SP comments`).
///
/// `Debug`, `Clone`, `PartialEq` and `Eq` are derived so values can be
/// logged, duplicated and compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientVersionInfo {
    /// Protocol version field, e.g. `2.0`.
    pub proto_version: String,
    /// Software version field, e.g. `OpenSSH_10.2`.
    pub software_version: String,
    /// Comment text that followed the identifier, if any was present.
    pub comments: Option<String>,
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, Cursor, Read, Write};

    /// In-memory stand-in for a socket: reads come from `incoming`, writes
    /// are collected in `outgoing`.
    struct Duplex {
        incoming: Cursor<Vec<u8>>,
        outgoing: Vec<u8>,
    }

    impl Duplex {
        fn new(incoming: &[u8]) -> Self {
            Duplex {
                incoming: Cursor::new(incoming.to_vec()),
                outgoing: Vec::new(),
            }
        }
    }

    impl Read for Duplex {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.incoming.read(buf)
        }
    }

    impl Write for Duplex {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.outgoing.write(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn test_version_format() {
        assert!(SSH_VERSION.starts_with("SSH-2.0-"));
    }

    #[test]
    fn test_parse_client_version() {
        let version = "SSH-2.0-OpenSSH_10.2";
        let info = VersionExchange::parse_client_version(version).unwrap();
        assert_eq!(info.proto_version, "2.0");
        assert_eq!(info.software_version, "OpenSSH_10.2");
        assert!(info.comments.is_none());
    }

    #[test]
    fn test_parse_client_version_with_comment() {
        let info = VersionExchange::parse_client_version("SSH-2.0-OpenSSH_9.6 Ubuntu").unwrap();
        assert_eq!(info.proto_version, "2.0");
        assert_eq!(info.software_version, "OpenSSH_9.6");
        assert_eq!(info.comments.as_deref(), Some("Ubuntu"));
    }

    #[test]
    fn test_parse_client_version_rejects_malformed() {
        assert!(VersionExchange::parse_client_version("").is_err());
        assert!(VersionExchange::parse_client_version("SSH-2.0").is_err());
    }

    #[test]
    fn test_send_version_writes_crlf_terminated_line() {
        let mut out = Vec::new();
        VersionExchange::send_version(&mut out).unwrap();
        assert_eq!(out, format!("{}\r\n", SSH_VERSION).into_bytes());
    }

    #[test]
    fn test_receive_version_strips_line_terminator() {
        let mut input = Cursor::new(b"SSH-2.0-Test_1.0\r\n".to_vec());
        let version = VersionExchange::receive_version(&mut input).unwrap();
        assert_eq!(version, "SSH-2.0-Test_1.0");
    }

    #[test]
    fn test_receive_version_skips_blank_and_debug_lines() {
        let mut input = Cursor::new(b"\n#debug output\nSSH-2.0-Test_1.0\r\n".to_vec());
        let version = VersionExchange::receive_version(&mut input).unwrap();
        assert_eq!(version, "SSH-2.0-Test_1.0");
    }

    #[test]
    fn test_receive_version_rejects_other_protocol_versions() {
        let mut input = Cursor::new(b"SSH-1.5-Old\r\n".to_vec());
        assert!(VersionExchange::receive_version(&mut input).is_err());
    }

    #[test]
    fn test_receive_version_rejects_overlong_line() {
        let mut bytes = b"SSH-2.0-".to_vec();
        bytes.extend(std::iter::repeat(b'A').take(300));
        bytes.push(b'\n');
        let mut input = Cursor::new(bytes);
        assert!(VersionExchange::receive_version(&mut input).is_err());
    }

    #[test]
    fn test_receive_version_fails_on_eof_before_newline() {
        let mut input = Cursor::new(b"SSH-2.0-Test_1.0".to_vec());
        assert!(VersionExchange::receive_version(&mut input).is_err());
    }

    #[test]
    fn test_exchange_sends_server_version_and_returns_client_version() {
        let mut stream = Duplex::new(b"SSH-2.0-Client_0.1\r\n");
        let client = VersionExchange::exchange(&mut stream).unwrap();
        assert_eq!(client, "SSH-2.0-Client_0.1");
        assert_eq!(stream.outgoing, format!("{}\r\n", SSH_VERSION).into_bytes());
    }
}
|