Files
markbase/markbase-core/src/provider/pg.rs
Warren d94cb2df4c Fix code quality: trailing whitespace, unused imports, clippy warnings
- Fix trailing whitespace in kex.rs and s3.rs
- Add missing KexProposal import in kex_complete.rs
- Auto-fix clippy warnings across all crates
- All 153 tests pass
2026-06-19 05:21:38 +08:00

205 lines
6.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use super::{DataProvider, ProviderError, User};
use bcrypt::verify;
use postgres::{Client, NoTls};
use std::path::PathBuf;
/// PostgreSQL-backed data provider (compatible with SFTPGo's `users` table).
pub struct PgProvider {
    // Connection string handed to `Client::connect` whenever a connection is opened.
    conn_str: String,
}
impl PgProvider {
    /// Creates a `PgProvider` from a ready-made connection string.
    ///
    /// Expected key/value format, for example:
    /// `host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=<secret>`
    ///
    /// The string is only stored; no connection is opened here, so a bad
    /// string is reported by the first query rather than by this constructor.
    pub fn new(conn_str: &str) -> Result<Self, ProviderError> {
        Ok(Self {
            conn_str: conn_str.to_string(),
        })
    }

    /// Creates a `PgProvider` from individual connection parameters.
    ///
    /// Every textual value is single-quoted and escaped, so values that are
    /// empty or contain whitespace, `'` or `\` (typically passwords) still
    /// produce a well-formed key/value connection string instead of being
    /// split into unrelated options or rejected by the parser.
    pub fn from_params(
        host: &str,
        port: u16,
        dbname: &str,
        user: &str,
        password: &str,
    ) -> Result<Self, ProviderError> {
        let conn_str = format!(
            "host={} port={} dbname={} user={} password={}",
            Self::quote_conn_value(host),
            port,
            Self::quote_conn_value(dbname),
            Self::quote_conn_value(user),
            Self::quote_conn_value(password),
        );
        Ok(Self { conn_str })
    }

    /// Wraps `value` in single quotes, backslash-escaping `'` and `\`,
    /// as required for values in a key/value connection string.
    fn quote_conn_value(value: &str) -> String {
        let mut quoted = String::with_capacity(value.len() + 2);
        quoted.push('\'');
        for ch in value.chars() {
            if matches!(ch, '\'' | '\\') {
                quoted.push('\\');
            }
            quoted.push(ch);
        }
        quoted.push('\'');
        quoted
    }

    /// Opens a new, non-TLS connection using the stored connection string.
    ///
    /// Connection failures are reported as `ProviderError::Internal`.
    fn open_conn(&self) -> Result<Client, ProviderError> {
        Client::connect(&self.conn_str, NoTls)
            .map_err(|e| ProviderError::Internal(format!("PostgreSQL connect failed: {}", e)))
    }
}
impl DataProvider for PgProvider {
fn get_user(&self, username: &str) -> Result<Option<User>, ProviderError> {
let mut conn = self.open_conn()?;
let result = conn
.query_opt(
"SELECT username, password, home_dir, permissions, uid, gid, status
FROM users WHERE username = $1 AND status = 1",
&[&username],
)
.map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
match result {
Some(row) => Ok(Some(User {
username: row.get(0),
password_hash: row.get::<_, Option<String>>(1).unwrap_or_default(),
home_dir: PathBuf::from(row.get::<_, String>(2)),
permissions: row
.get::<_, Option<String>>(3)
.unwrap_or_else(|| "*".to_string()),
uid: row.get::<_, i64>(4) as u32,
gid: row.get::<_, i64>(5) as u32,
status: row.get(6),
})),
None => Ok(None),
}
}
fn check_password(&self, username: &str, password: &str) -> Result<bool, ProviderError> {
let hash = match self.get_user(username)? {
Some(user) => user.password_hash,
None => return Ok(false),
};
if hash.is_empty() {
return Ok(false);
}
verify(password, &hash)
.map_err(|e| ProviderError::Internal(format!("bcrypt verify error: {}", e)))
}
fn get_home_dir(&self, username: &str) -> Result<Option<String>, ProviderError> {
Ok(self
.get_user(username)?
.map(|u| u.home_dir.to_string_lossy().to_string()))
}
fn get_public_keys(&self, username: &str) -> Result<Vec<String>, ProviderError> {
let mut conn = self.open_conn()?;
let result = conn
.query_opt(
"SELECT public_keys FROM users WHERE username = $1 AND status = 1",
&[&username],
)
.map_err(|e| ProviderError::Internal(format!("Query error: {}", e)))?;
match result {
Some(row) => {
let json_str: Option<String> = row.get(0);
match json_str {
Some(s) if !s.is_empty() => {
let keys: Vec<serde_json::Value> =
serde_json::from_str(&s).map_err(|e| {
ProviderError::Internal(format!("JSON parse error: {}", e))
})?;
Ok(keys
.iter()
.filter_map(|v| v.get("public_key")?.as_str().map(|s| s.to_string()))
.collect())
}
_ => Ok(Vec::new()),
}
}
None => Ok(Vec::new()),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection string of the local SFTPGo PostgreSQL instance used by the
    /// integration tests below.
    const TEST_CONN_STR: &str =
        "host=127.0.0.1 port=5432 dbname=sftpgo user=sftpgo password=sftpgo_pass_2026";

    /// Builds a provider pointing at the test database.
    fn test_provider() -> PgProvider {
        PgProvider::new(TEST_CONN_STR).unwrap()
    }

    // `PgProvider::new` only stores the string and never opens a connection,
    // so this test needs no database.
    #[test]
    fn test_pg_provider_connection() {
        let provider = PgProvider::new(TEST_CONN_STR);
        assert!(
            provider.is_ok(),
            "PgProvider::new should accept a connection string"
        );
    }

    // The remaining tests query a live database and are therefore ignored by
    // default; run them with `cargo test -- --ignored` when the SFTPGo
    // PostgreSQL instance is available.

    #[test]
    #[ignore = "requires a live SFTPGo PostgreSQL instance"]
    fn test_pg_get_user_demo() {
        let user = test_provider().get_user("demo").unwrap();
        assert!(user.is_some(), "Demo user should exist");
        assert_eq!(user.unwrap().username, "demo");
    }

    #[test]
    #[ignore = "requires a live SFTPGo PostgreSQL instance"]
    fn test_pg_get_user_momentry() {
        let user = test_provider().get_user("momentry").unwrap();
        assert!(user.is_some(), "Momentry user should exist");
    }

    #[test]
    #[ignore = "requires a live SFTPGo PostgreSQL instance"]
    fn test_pg_get_user_warren() {
        let user = test_provider().get_user("warren").unwrap();
        assert!(user.is_some(), "Warren user should exist");
    }

    #[test]
    #[ignore = "requires a live SFTPGo PostgreSQL instance"]
    fn test_pg_check_password_demo() {
        let valid = test_provider().check_password("demo", "demo123").unwrap();
        assert!(valid, "Password should be valid");
    }

    #[test]
    #[ignore = "requires a live SFTPGo PostgreSQL instance"]
    fn test_pg_check_password_invalid() {
        let valid = test_provider().check_password("demo", "wrong").unwrap();
        assert!(!valid, "Wrong password should fail");
    }

    #[test]
    #[ignore = "requires a live SFTPGo PostgreSQL instance"]
    fn test_pg_get_home_dir() {
        let dir = test_provider().get_home_dir("demo").unwrap();
        assert!(dir.is_some());
        assert!(dir.unwrap().contains("momentry"));
    }

    #[test]
    #[ignore = "requires a live SFTPGo PostgreSQL instance"]
    fn test_pg_nonexistent_user() {
        let user = test_provider().get_user("__nonexistent__").unwrap();
        assert!(user.is_none());
    }
}