- Fix trailing whitespace in kex.rs and s3.rs - Add missing KexProposal import in kex_complete.rs - Auto-fix clippy warnings across all crates - All 153 tests pass
205 lines
6.7 KiB
Rust
205 lines
6.7 KiB
Rust
use super::{DataProvider, ProviderError, User};
|
||
use bcrypt::verify;
|
||
use postgres::{Client, NoTls};
|
||
use std::path::PathBuf;
|
||
|
||
/// Data provider backed by PostgreSQL, reading SFTPGo's `users` table.
///
/// No connection is held by this value; only the connection string is kept,
/// and it is used to connect on demand.
pub struct PgProvider {
    /// Key/value connection string handed to the PostgreSQL client on connect.
    conn_str: String,
}
|
||
|
||
impl PgProvider {
|
||
/// 从连接字符串创建 PgProvider
|
||
///
|
||
/// 连接字符串格式:host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026
|
||
pub fn new(conn_str: &str) -> Result<Self, ProviderError> {
|
||
Ok(Self {
|
||
conn_str: conn_str.to_string(),
|
||
})
|
||
}
|
||
|
||
pub fn from_params(
|
||
host: &str,
|
||
port: u16,
|
||
dbname: &str,
|
||
user: &str,
|
||
password: &str,
|
||
) -> Result<Self, ProviderError> {
|
||
let conn_str = format!(
|
||
"host={} port={} dbname={} user={} password={}",
|
||
host, port, dbname, user, password
|
||
);
|
||
Ok(Self { conn_str })
|
||
}
|
||
|
||
fn open_conn(&self) -> Result<Client, ProviderError> {
|
||
Client::connect(&self.conn_str, NoTls)
|
||
.map_err(|e| ProviderError::Internal(format!("PostgreSQL connect failed: {}", e)))
|
||
}
|
||
}
|
||
|
||
impl DataProvider for PgProvider {
|
||
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError> {
|
||
let mut conn = self.open_conn()?;
|
||
|
||
let result = conn
|
||
.query_opt(
|
||
"SELECT username, password, home_dir, permissions, uid, gid, status
|
||
FROM users WHERE username = $1 AND status = 1",
|
||
&[&username],
|
||
)
|
||
.map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||
|
||
match result {
|
||
Some(row) => Ok(Some(User {
|
||
username: row.get(0),
|
||
password_hash: row.get::<_, Option<String>>(1).unwrap_or_default(),
|
||
home_dir: PathBuf::from(row.get::<_, String>(2)),
|
||
permissions: row
|
||
.get::<_, Option<String>>(3)
|
||
.unwrap_or_else(|| "*".to_string()),
|
||
uid: row.get::<_, i64>(4) as u32,
|
||
gid: row.get::<_, i64>(5) as u32,
|
||
status: row.get(6),
|
||
})),
|
||
None => Ok(None),
|
||
}
|
||
}
|
||
|
||
fn check_password(&self, username: &str, password: &str) -> Result<bool, ProviderError> {
|
||
let hash = match self.get_user(username)? {
|
||
Some(user) => user.password_hash,
|
||
None => return Ok(false),
|
||
};
|
||
|
||
if hash.is_empty() {
|
||
return Ok(false);
|
||
}
|
||
|
||
verify(password, &hash)
|
||
.map_err(|e| ProviderError::Internal(format!("bcrypt verify error: {}", e)))
|
||
}
|
||
|
||
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
|
||
Ok(self
|
||
.get_user(username)?
|
||
.map(|u| u.home_dir.to_string_lossy().to_string()))
|
||
}
|
||
|
||
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
|
||
let mut conn = self.open_conn()?;
|
||
let result = conn
|
||
.query_opt(
|
||
"SELECT public_keys FROM users WHERE username = $1 AND status = 1",
|
||
&[&username],
|
||
)
|
||
.map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
|
||
|
||
match result {
|
||
Some(row) => {
|
||
let json_str: Option<String> = row.get(0);
|
||
match json_str {
|
||
Some(s) if !s.is_empty() => {
|
||
let keys: Vec<serde_json::Value> =
|
||
serde_json::from_str(&s).map_err(|e| {
|
||
ProviderError::Internal(format!("JSON parse error: {}", e))
|
||
})?;
|
||
Ok(keys
|
||
.iter()
|
||
.filter_map(|v| v.get("public_key")?.as_str().map(|s| s.to_string()))
|
||
.collect())
|
||
}
|
||
_ => Ok(Vec::new()),
|
||
}
|
||
}
|
||
None => Ok(Vec::new()),
|
||
}
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    //! Integration tests against a live SFTPGo PostgreSQL database.
    //!
    //! They are `#[ignore]`d by default because they need an external server
    //! containing the fixture users `demo`, `momentry` and `warren`. Run them
    //! explicitly with:
    //!
    //! ```text
    //! SFTPGO_PG_CONN="host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=..." \
    //!     cargo test -- --ignored
    //! ```
    use super::*;

    /// Environment variable holding the key/value connection string.
    const CONN_ENV: &str = "SFTPGO_PG_CONN";

    /// Builds a provider from `SFTPGO_PG_CONN`.
    ///
    /// Panics with a clear message if the variable is unset.
    fn provider() -> PgProvider {
        let conn_str = std::env::var(CONN_ENV).unwrap_or_else(|_| {
            panic!("set {} to run the PostgreSQL integration tests", CONN_ENV)
        });
        PgProvider::new(&conn_str).expect("PgProvider::new does not validate or connect")
    }

    #[test]
    #[ignore = "needs a live SFTPGo PostgreSQL (set SFTPGO_PG_CONN, run with --ignored)"]
    fn test_pg_provider_connection() {
        // `new` never connects, so open a real connection to test reachability.
        assert!(
            provider().open_conn().is_ok(),
            "Should connect to SFTPGo PostgreSQL"
        );
    }

    #[test]
    #[ignore = "needs a live SFTPGo PostgreSQL (set SFTPGO_PG_CONN, run with --ignored)"]
    fn test_pg_get_user_demo() {
        let user = provider().get_user("demo").unwrap();
        assert!(user.is_some(), "Demo user should exist");
        assert_eq!(user.unwrap().username, "demo");
    }

    #[test]
    #[ignore = "needs a live SFTPGo PostgreSQL (set SFTPGO_PG_CONN, run with --ignored)"]
    fn test_pg_get_user_momentry() {
        let user = provider().get_user("momentry").unwrap();
        assert!(user.is_some(), "Momentry user should exist");
    }

    #[test]
    #[ignore = "needs a live SFTPGo PostgreSQL (set SFTPGO_PG_CONN, run with --ignored)"]
    fn test_pg_get_user_warren() {
        let user = provider().get_user("warren").unwrap();
        assert!(user.is_some(), "Warren user should exist");
    }

    #[test]
    #[ignore = "needs a live SFTPGo PostgreSQL (set SFTPGO_PG_CONN, run with --ignored)"]
    fn test_pg_check_password_demo() {
        let valid = provider().check_password("demo", "demo123").unwrap();
        assert!(valid, "Password should be valid");
    }

    #[test]
    #[ignore = "needs a live SFTPGo PostgreSQL (set SFTPGO_PG_CONN, run with --ignored)"]
    fn test_pg_check_password_invalid() {
        let valid = provider().check_password("demo", "wrong").unwrap();
        assert!(!valid, "Wrong password should fail");
    }

    #[test]
    #[ignore = "needs a live SFTPGo PostgreSQL (set SFTPGO_PG_CONN, run with --ignored)"]
    fn test_pg_get_home_dir() {
        let dir = provider().get_home_dir("demo").unwrap();
        assert!(dir.is_some());
        assert!(dir.unwrap().contains("momentry"));
    }

    #[test]
    #[ignore = "needs a live SFTPGo PostgreSQL (set SFTPGO_PG_CONN, run with --ignored)"]
    fn test_pg_nonexistent_user() {
        let user = provider().get_user("__nonexistent__").unwrap();
        assert!(user.is_none());
    }
}
|