SMB Server Phase 2: VFS backend build fix + integration test
- Add VfsFile: Send supertrait for Mutex compatibility - Fix SmbServerCommand: struct → Subcommand enum with Start variant - Fix tracing_subscriber::init() → try_init() to avoid panic when logger already initialized - Fix CLI subcommand name: smb-server → smb-start (flatten naming) - Add #[command(name = "smb-start")] for CLI disambiguation - Fix unused variable warnings (smb_fs.rs, smb_server_backend.rs) - Remove unused VfsFile imports (webdav.rs, scp_handler.rs) - Integration test: Docker smbclient verified (list, upload, read)
This commit is contained in:
8
vendor/smb2/.cargo/audit.toml
vendored
Normal file
8
vendor/smb2/.cargo/audit.toml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# Ignored advisories for cargo-audit
|
||||
|
||||
[advisories]
|
||||
ignore = [
|
||||
# Marvin Attack timing sidechannel in `rsa` crate. No fix available.
|
||||
# Only affects benchmarks/smb/ (via sspi -> rsa), not the smb2 crate itself.
|
||||
"RUSTSEC-2023-0071",
|
||||
]
|
||||
12
vendor/smb2/.claude/rules/docs-maintenance.md
vendored
Normal file
12
vendor/smb2/.claude/rules/docs-maintenance.md
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
When modifying code in a directory that contains a `CLAUDE.md` file, check whether your changes affect the documented
|
||||
architecture, key decisions, or gotchas. If they do, update the `CLAUDE.md` to stay in sync. If you notice a `CLAUDE.md`
|
||||
missing in a directory where there should be one, add it. Skip this for trivial changes (bug fixes, formatting, small
|
||||
refactors that don't change the architecture).
|
||||
|
||||
If something failed due to a wrong assumption, add a `Gotcha/Why` entry to the nearest `CLAUDE.md`.
|
||||
|
||||
Add `Decision/Why` entries to the nearest colocated `CLAUDE.md` for key decisions. If the decision has rich evidence
|
||||
(benchmarks, detailed analysis), put the evidence in `docs/notes/` and link from the CLAUDE.md.
|
||||
|
||||
When writing guides, see [this diff](https://github.com/vdavid/cmdr/commit/13ad8f3#diff-795210f) for the formatting
|
||||
standard. (Before: AI-written. After: matching our standards for conciseness and clarity.)
|
||||
14
vendor/smb2/.claude/rules/git-conventions.md
vendored
Normal file
14
vendor/smb2/.claude/rules/git-conventions.md
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
## Commit messages
|
||||
|
||||
- Use conventional commit messages.
|
||||
- Title: Capture the IMPACT of the change, not the tech details. From the title, we need to understand WHY we did this,
|
||||
what we ACHIEVED with the commit. Length-wise, aim for about 50 chars max.
|
||||
- Body: Use bullets primarily. No word wrap. Don't hard-wrap body lines at 72 chars or any other width. Let the
|
||||
terminal/viewer wrap naturally. Enclose entities in ``. No co-author!
|
||||
|
||||
## PRs
|
||||
|
||||
- Use the PR title to summarize the changes in a casual/informal tone. Be information dense and concise.
|
||||
- In the desc., write a thorough, organized, but concise, often bulleted list of the changes. Use no headings.
|
||||
- At the bottom of the PR description, use a single "## Test plan" heading, in which, explain how the changes were
|
||||
tested. Assume that the changes were also tested manually if it makes sense for the type of changes.
|
||||
16
vendor/smb2/.codegraph/.gitignore
vendored
Normal file
16
vendor/smb2/.codegraph/.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
# CodeGraph data files
|
||||
# These are local to each machine and should not be committed
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.db-wal
|
||||
*.db-shm
|
||||
|
||||
# Cache
|
||||
cache/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Hook markers
|
||||
.dirty
|
||||
140
vendor/smb2/.codegraph/config.json
vendored
Normal file
140
vendor/smb2/.codegraph/config.json
vendored
Normal file
@@ -0,0 +1,140 @@
|
||||
{
|
||||
"version": 1,
|
||||
"include": [
|
||||
"**/*.ts",
|
||||
"**/*.tsx",
|
||||
"**/*.js",
|
||||
"**/*.jsx",
|
||||
"**/*.py",
|
||||
"**/*.go",
|
||||
"**/*.rs",
|
||||
"**/*.java",
|
||||
"**/*.c",
|
||||
"**/*.h",
|
||||
"**/*.cpp",
|
||||
"**/*.hpp",
|
||||
"**/*.cc",
|
||||
"**/*.cxx",
|
||||
"**/*.cs",
|
||||
"**/*.php",
|
||||
"**/*.rb",
|
||||
"**/*.swift",
|
||||
"**/*.kt",
|
||||
"**/*.kts",
|
||||
"**/*.dart",
|
||||
"**/*.svelte",
|
||||
"**/*.liquid",
|
||||
"**/*.pas",
|
||||
"**/*.dpr",
|
||||
"**/*.dpk",
|
||||
"**/*.lpr",
|
||||
"**/*.dfm",
|
||||
"**/*.fmx"
|
||||
],
|
||||
"exclude": [
|
||||
"**/.git/**",
|
||||
"**/node_modules/**",
|
||||
"**/vendor/**",
|
||||
"**/Pods/**",
|
||||
"**/dist/**",
|
||||
"**/build/**",
|
||||
"**/out/**",
|
||||
"**/bin/**",
|
||||
"**/obj/**",
|
||||
"**/target/**",
|
||||
"**/*.min.js",
|
||||
"**/*.bundle.js",
|
||||
"**/.next/**",
|
||||
"**/.nuxt/**",
|
||||
"**/.svelte-kit/**",
|
||||
"**/.output/**",
|
||||
"**/.turbo/**",
|
||||
"**/.cache/**",
|
||||
"**/.parcel-cache/**",
|
||||
"**/.vite/**",
|
||||
"**/.astro/**",
|
||||
"**/.docusaurus/**",
|
||||
"**/.gatsby/**",
|
||||
"**/.webpack/**",
|
||||
"**/.nx/**",
|
||||
"**/.yarn/cache/**",
|
||||
"**/.pnpm-store/**",
|
||||
"**/storybook-static/**",
|
||||
"**/.expo/**",
|
||||
"**/web-build/**",
|
||||
"**/ios/Pods/**",
|
||||
"**/ios/build/**",
|
||||
"**/android/build/**",
|
||||
"**/android/.gradle/**",
|
||||
"**/__pycache__/**",
|
||||
"**/.venv/**",
|
||||
"**/venv/**",
|
||||
"**/site-packages/**",
|
||||
"**/dist-packages/**",
|
||||
"**/.pytest_cache/**",
|
||||
"**/.mypy_cache/**",
|
||||
"**/.ruff_cache/**",
|
||||
"**/.tox/**",
|
||||
"**/.nox/**",
|
||||
"**/*.egg-info/**",
|
||||
"**/.eggs/**",
|
||||
"**/go/pkg/mod/**",
|
||||
"**/target/debug/**",
|
||||
"**/target/release/**",
|
||||
"**/.gradle/**",
|
||||
"**/.m2/**",
|
||||
"**/generated-sources/**",
|
||||
"**/.kotlin/**",
|
||||
"**/.dart_tool/**",
|
||||
"**/.vs/**",
|
||||
"**/.nuget/**",
|
||||
"**/artifacts/**",
|
||||
"**/publish/**",
|
||||
"**/cmake-build-*/**",
|
||||
"**/CMakeFiles/**",
|
||||
"**/bazel-*/**",
|
||||
"**/vcpkg_installed/**",
|
||||
"**/.conan/**",
|
||||
"**/Debug/**",
|
||||
"**/Release/**",
|
||||
"**/x64/**",
|
||||
"**/.pio/**",
|
||||
"**/release/**",
|
||||
"**/*.app/**",
|
||||
"**/*.asar",
|
||||
"**/DerivedData/**",
|
||||
"**/.build/**",
|
||||
"**/.swiftpm/**",
|
||||
"**/xcuserdata/**",
|
||||
"**/Carthage/Build/**",
|
||||
"**/SourcePackages/**",
|
||||
"**/__history/**",
|
||||
"**/__recovery/**",
|
||||
"**/*.dcu",
|
||||
"**/.composer/**",
|
||||
"**/storage/framework/**",
|
||||
"**/bootstrap/cache/**",
|
||||
"**/.bundle/**",
|
||||
"**/tmp/cache/**",
|
||||
"**/public/assets/**",
|
||||
"**/public/packs/**",
|
||||
"**/.yardoc/**",
|
||||
"**/coverage/**",
|
||||
"**/htmlcov/**",
|
||||
"**/.nyc_output/**",
|
||||
"**/test-results/**",
|
||||
"**/.coverage/**",
|
||||
"**/.idea/**",
|
||||
"**/logs/**",
|
||||
"**/tmp/**",
|
||||
"**/temp/**",
|
||||
"**/_build/**",
|
||||
"**/docs/_build/**",
|
||||
"**/site/**"
|
||||
],
|
||||
"languages": [],
|
||||
"frameworks": [],
|
||||
"maxFileSize": 1048576,
|
||||
"extractDocstrings": true,
|
||||
"trackCallSites": true
|
||||
}
|
||||
4
vendor/smb2/.env.example
vendored
Normal file
4
vendor/smb2/.env.example
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copy this to .env and fill in your values.
|
||||
# .env is gitignored and never committed.
|
||||
|
||||
SMB2_TEST_NAS_PASSWORD=your_nas_password_here
|
||||
2
vendor/smb2/.gitattributes
vendored
Normal file
2
vendor/smb2/.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# Force LF line endings for all text files (consistent with rustfmt.toml newline_style = "Unix")
|
||||
* text=auto eol=lf
|
||||
136
vendor/smb2/.github/workflows/ci.yml
vendored
Normal file
136
vendor/smb2/.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,136 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Check (${{ matrix.os }}, rust ${{ matrix.rust }})
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-2025]
|
||||
rust: ["1.85", stable]
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
||||
|
||||
- name: Install Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@master
|
||||
with:
|
||||
toolchain: ${{ matrix.rust }}
|
||||
components: rustfmt, clippy
|
||||
|
||||
- name: Cache cargo registry and target
|
||||
uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Check formatting
|
||||
run: cargo fmt --check
|
||||
|
||||
- name: Run clippy lints
|
||||
run: cargo clippy --all-targets -- -D warnings
|
||||
|
||||
- name: Run tests
|
||||
run: cargo test
|
||||
|
||||
- name: Build documentation
|
||||
run: cargo doc --no-deps
|
||||
|
||||
docker-tests:
|
||||
name: Docker integration tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
||||
|
||||
- name: Install Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Cache cargo registry and target
|
||||
uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Start SMB containers
|
||||
run: ./tests/docker/start.sh internal
|
||||
|
||||
- name: Run Docker integration tests
|
||||
run: cargo test --test docker_integration -- --ignored
|
||||
env:
|
||||
RUST_LOG: smb2=info
|
||||
|
||||
- name: Stop containers
|
||||
if: always()
|
||||
run: ./tests/docker/stop.sh
|
||||
|
||||
consumer-tests:
|
||||
name: Consumer integration tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
||||
|
||||
- name: Install Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Cache cargo registry and target
|
||||
uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Start consumer containers
|
||||
run: ./tests/docker/start.sh consumer
|
||||
|
||||
- name: Run consumer integration tests
|
||||
run: cargo test --features testing --test consumer_integration -- --ignored
|
||||
env:
|
||||
RUST_LOG: smb2=info
|
||||
|
||||
- name: Stop containers
|
||||
if: always()
|
||||
run: ./tests/docker/stop.sh
|
||||
|
||||
msrv:
|
||||
name: Verify MSRV (1.85)
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
||||
|
||||
- name: Install Rust toolchain (MSRV)
|
||||
uses: dtolnay/rust-toolchain@master
|
||||
with:
|
||||
toolchain: "1.85"
|
||||
|
||||
- name: Cache cargo registry and target
|
||||
uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Check compilation on MSRV
|
||||
run: cargo check
|
||||
env:
|
||||
RUSTFLAGS: "-D warnings"
|
||||
|
||||
ci-ok:
|
||||
name: CI OK
|
||||
runs-on: ubuntu-latest
|
||||
needs: [check, docker-tests, consumer-tests, msrv]
|
||||
if: always()
|
||||
steps:
|
||||
- name: Check all jobs passed
|
||||
run: |
|
||||
if [[ "${{ contains(needs.*.result, 'failure') }}" == "true" ]]; then
|
||||
echo "Some jobs failed"
|
||||
exit 1
|
||||
fi
|
||||
if [[ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]]; then
|
||||
echo "Some jobs were cancelled"
|
||||
exit 1
|
||||
fi
|
||||
echo "All jobs passed"
|
||||
74
vendor/smb2/.github/workflows/fuzz.yml
vendored
Normal file
74
vendor/smb2/.github/workflows/fuzz.yml
vendored
Normal file
@@ -0,0 +1,74 @@
|
||||
name: Fuzz
|
||||
|
||||
# Short-duration fuzz run: weekly schedule + manual dispatch. Each target
|
||||
# runs for 5 minutes with the committed seed corpus. For longer hunts, run
|
||||
# locally: `cargo +nightly fuzz run <target> -- -max_total_time=1800`.
|
||||
#
|
||||
# We deliberately do NOT fuzz on every push -- runs are too long for that.
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Mondays 04:15 UTC.
|
||||
- cron: "15 4 * * 1"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
duration_seconds:
|
||||
description: "Per-target fuzz time (seconds)"
|
||||
required: false
|
||||
default: "300"
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
fuzz:
|
||||
name: Fuzz ${{ matrix.target }}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
target:
|
||||
- fuzz_header_parse
|
||||
- fuzz_transform_header_parse
|
||||
- fuzz_compression_transform_header_parse
|
||||
- fuzz_compound_split
|
||||
- fuzz_frame_parse
|
||||
- fuzz_sub_frame_parse
|
||||
- fuzz_negotiate_request_parse
|
||||
- fuzz_negotiate_response_parse
|
||||
- fuzz_create_request_parse
|
||||
- fuzz_create_response_parse
|
||||
- fuzz_query_info_response_parse
|
||||
- fuzz_dfs_referral_response_parse
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
||||
|
||||
- name: Install Rust nightly
|
||||
uses: dtolnay/rust-toolchain@nightly
|
||||
|
||||
- name: Cache cargo registry and target
|
||||
uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: |
|
||||
.
|
||||
fuzz
|
||||
|
||||
- name: Install cargo-fuzz
|
||||
run: cargo install cargo-fuzz
|
||||
|
||||
- name: Run fuzz target
|
||||
env:
|
||||
DURATION: ${{ github.event.inputs.duration_seconds || '300' }}
|
||||
run: |
|
||||
cargo +nightly fuzz run "${{ matrix.target }}" \
|
||||
-- -max_total_time="${DURATION}" -print_final_stats=1
|
||||
|
||||
- name: Upload crash artifacts (if any)
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: fuzz-crash-${{ matrix.target }}
|
||||
path: fuzz/artifacts/${{ matrix.target }}/
|
||||
if-no-files-found: ignore
|
||||
7
vendor/smb2/.gitignore
vendored
Normal file
7
vendor/smb2/.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
.idea/
|
||||
.claude/projects/
|
||||
.claude/worktrees/
|
||||
related-repos/
|
||||
target/
|
||||
.DS_Store
|
||||
.env
|
||||
81
vendor/smb2/Cargo.toml
vendored
Normal file
81
vendor/smb2/Cargo.toml
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
[package]
|
||||
name = "smb2"
|
||||
version = "0.11.3"
|
||||
edition = "2021"
|
||||
rust-version = "1.85"
|
||||
license = "MIT OR Apache-2.0"
|
||||
description = "Pure-Rust SMB2/3 client library with pipelined I/O"
|
||||
repository = "https://github.com/vdavid/smb2"
|
||||
keywords = ["smb", "smb2", "smb3", "cifs", "network"]
|
||||
categories = ["network-programming", "filesystem"]
|
||||
readme = "README.md"
|
||||
documentation = "https://docs.rs/smb2"
|
||||
exclude = [
|
||||
".github/",
|
||||
"AGENTS.md",
|
||||
"docs/",
|
||||
"justfile",
|
||||
"deny.toml",
|
||||
"clippy.toml",
|
||||
"rustfmt.toml",
|
||||
"related-repos/",
|
||||
]
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
|
||||
[dependencies]
|
||||
# Logging facade -- application picks the backend (env_logger, tracing, etc.)
|
||||
log = "0.4"
|
||||
|
||||
# Async runtime agnostic
|
||||
async-trait = "0.1"
|
||||
|
||||
# `FuturesUnordered` for pipelined concurrent `execute` calls.
|
||||
futures-util = { version = "0.3", default-features = false, features = ["std", "async-await"] }
|
||||
|
||||
# Error handling
|
||||
thiserror = "2"
|
||||
|
||||
# Enum conversion derives
|
||||
num_enum = "0.7"
|
||||
|
||||
# Async runtime -- transport layer needs net, io-util, time, sync
|
||||
tokio = { version = "1", features = ["net", "io-util", "time", "sync", "rt"] }
|
||||
|
||||
# Crypto -- signing, encryption, key derivation
|
||||
hmac = "0.13"
|
||||
sha2 = "0.11"
|
||||
aes = "0.9"
|
||||
aes-gcm = "=0.11.0-rc.4"
|
||||
ccm = "=0.6.0-rc.3"
|
||||
cmac = "=0.8.0-rc.5"
|
||||
digest = "0.11"
|
||||
|
||||
# NTLM authentication (MS-NLMP)
|
||||
md-5 = "0.11"
|
||||
md4 = "0.11"
|
||||
|
||||
# Kerberos key derivation (AES string-to-key)
|
||||
pbkdf2 = "0.13"
|
||||
sha1 = "0.11"
|
||||
|
||||
# Cryptographically secure random
|
||||
getrandom = "0.4"
|
||||
|
||||
# Compression
|
||||
lz4_flex = "0.13"
|
||||
|
||||
# Optional: `Serialize` derives on diagnostics types. Off by default.
|
||||
serde = { version = "1", optional = true, features = ["derive"] }
|
||||
|
||||
[features]
|
||||
testing = [] # Enables smb2::testing module for Docker-based test servers
|
||||
fuzzing = [] # Exposes parser entry points for `fuzz/` targets; not for applications
|
||||
serde = ["dep:serde"] # `Serialize` impls on `Diagnostics` types and the protocol enums they embed.
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "net", "io-util"] }
|
||||
proptest = "1"
|
||||
env_logger = "0.11"
|
||||
serde_json = "1" # JSON round-trip tests for the `serde` feature
|
||||
120
vendor/smb2/src/auth/CLAUDE.md
vendored
Normal file
120
vendor/smb2/src/auth/CLAUDE.md
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
# Auth -- NTLM and Kerberos authentication
|
||||
|
||||
NTLMv2 and Kerberos authentication for SMB2 session setup.
|
||||
|
||||
## Key files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `mod.rs` | Module exports |
|
||||
| `der.rs` | Shared ASN.1/DER primitives (TLV encode/decode) |
|
||||
| `ntlm.rs` | `NtlmAuthenticator` -- 3-message NTLM exchange |
|
||||
| `spnego.rs` | SPNEGO NegTokenInit/NegTokenResp wrapping |
|
||||
| `kerberos/mod.rs` | Kerberos module root, re-exports authenticator |
|
||||
| `kerberos/authenticator.rs` | `KerberosAuthenticator` -- full AS + TGS + AP-REQ flow |
|
||||
| `kerberos/crypto.rs` | AES-CTS, RC4-HMAC, string-to-key, key derivation |
|
||||
| `kerberos/messages.rs` | ASN.1/DER encoding/decoding for Kerberos messages |
|
||||
| `kerberos/kdc.rs` | KDC transport client (UDP/TCP with fallback) |
|
||||
|
||||
## NTLM exchange
|
||||
|
||||
1. `negotiate()` -- builds NEGOTIATE_MESSAGE (Type 1) with default flags
|
||||
2. Server sends CHALLENGE_MESSAGE (Type 2) with server challenge and target info
|
||||
3. `authenticate(&challenge_bytes)` -- builds AUTHENTICATE_MESSAGE (Type 3) with NTLMv2 response
|
||||
|
||||
Only NTLMv2 is supported. NTLMv1 is insecure and not implemented.
|
||||
|
||||
## Kerberos exchange
|
||||
|
||||
`KerberosAuthenticator` performs the full Kerberos flow in three steps:
|
||||
|
||||
1. **AS exchange** (client -> KDC): derive user key from password, build PA-ENC-TIMESTAMP + PA-PAC-REQUEST, send AS-REQ, parse AS-REP, decrypt enc-part with user key to get TGT + AS session key.
|
||||
2. **TGS exchange** (client -> KDC): build AP-REQ wrapping TGT + authenticator (encrypted with AS session key), send TGS-REQ for `cifs/hostname`, parse TGS-REP, decrypt enc-part with AS session key to get service ticket + TGS session key.
|
||||
3. **AP-REQ construction**: build Authenticator with subkey, encrypt with TGS session key, build AP-REQ with service ticket, wrap in SPNEGO NegTokenInit.
|
||||
|
||||
The flow differs from NTLM: Kerberos contacts the KDC directly (async, network I/O), then produces a single token for SESSION_SETUP (usually 1 round-trip with the SMB server).
|
||||
|
||||
### Key usage numbers (RFC 4120 section 7.5.1)
|
||||
|
||||
- 1: PA-ENC-TIMESTAMP encryption
|
||||
- 3: AS-REP EncKDCRepPart decryption
|
||||
- 6: TGS-REQ PA-TGS-REQ Authenticator cksum (body checksum)
|
||||
- 7: AP-REQ Authenticator encryption
|
||||
- 8: TGS-REP EncKDCRepPart decryption (tries 8 first, falls back to 9)
|
||||
|
||||
### Encryption types supported
|
||||
|
||||
- AES-256-CTS-HMAC-SHA1-96 (etype 18) -- preferred
|
||||
- AES-128-CTS-HMAC-SHA1-96 (etype 17)
|
||||
- RC4-HMAC (etype 23) -- legacy
|
||||
|
||||
### Key derivation constants (RFC 3961)
|
||||
|
||||
Three subkeys are derived from each base key + usage number:
|
||||
- **Ke** = DK(key, usage || 0xAA) -- encryption subkey, used for AES-CTS
|
||||
- **Ki** = DK(key, usage || 0x55) -- integrity subkey, used for HMAC inside encrypt/decrypt
|
||||
- **Kc** = DK(key, usage || 0x99) -- checksum subkey, used for standalone checksum/MIC
|
||||
|
||||
Ki and Kc are NOT the same key. Ki is for the HMAC that's appended to ciphertext in the encrypt() function. Kc is for standalone operations like the body checksum in the TGS-REQ Authenticator.
|
||||
|
||||
### Kerberos wire encryption format (AES)
|
||||
|
||||
1. Derive Ke (with 0xAA) and Ki (with 0x55) from base key + usage
|
||||
2. Generate 16-byte random confounder
|
||||
3. plaintext' = confounder || plaintext
|
||||
4. ciphertext = AES-CTS(Ke, iv=0, plaintext')
|
||||
5. hmac = HMAC-SHA1-96(Ki, plaintext') -- 12 bytes
|
||||
6. output = ciphertext || hmac
|
||||
|
||||
## NTLM key derivation flow
|
||||
|
||||
1. `NTOWFv2`: `HMAC-MD5(MD4(password_utf16), uppercase(username) + domain)`
|
||||
2. `NTProofStr`: `HMAC-MD5(NTOWFv2, server_challenge + client_blob)`
|
||||
3. `SessionBaseKey`: `HMAC-MD5(NTOWFv2, NTProofStr)`
|
||||
4. If KEY_EXCH flag: generate random session key, RC4-encrypt with SessionBaseKey
|
||||
5. `ExportedSessionKey` feeds into SP800-108 KDF (in `crypto/kdf.rs`)
|
||||
|
||||
## MIC computation
|
||||
|
||||
Modern servers include `MsvAvTimestamp` in the challenge target info, which triggers MIC validation. When present:
|
||||
1. Add `MsvAvFlags` with bit 0x2 (MIC present) to the target info
|
||||
2. Build the AUTHENTICATE_MESSAGE with a zeroed 16-byte MIC field at offset 72
|
||||
3. Compute `HMAC-MD5(ExportedSessionKey, negotiate_msg || challenge_msg || authenticate_msg)`
|
||||
4. Patch the MIC into bytes 72..88
|
||||
|
||||
The authenticator retains raw bytes of NEGOTIATE and CHALLENGE messages for this computation.
|
||||
|
||||
## Key decisions
|
||||
|
||||
- **`getrandom` for random values**: Client challenge, random session key, nonces, and confounders use `getrandom` (OS CSPRNG).
|
||||
- **`test_random_session_key` override**: Tests can inject a deterministic session key for reproducibility. Never used in production.
|
||||
- **Subkey in AP-REQ Authenticator**: The Kerberos authenticator includes a random subkey, which becomes the SMB session key. This provides forward secrecy.
|
||||
- **No full `authenticate()` unit tests**: The full flow requires a real KDC. Unit tests cover individual steps (encrypt/decrypt roundtrip, message encoding, etype parsing). Integration tests with Docker are planned.
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **Retain raw challenge bytes for MIC (NTLM)**: The MIC is computed over the exact wire bytes of all three messages.
|
||||
- **RC4 for key exchange is inline (NTLM)**: ~15 lines of RC4 implementation.
|
||||
- **MsvAvTimestamp presence changes behavior (NTLM)**: Without it, no MIC is computed. With it, MIC is mandatory.
|
||||
- **NTLMv1 not supported**: No fallback.
|
||||
- **Target info modification (NTLM)**: The client modifies the server's target info before including it in the client blob.
|
||||
- **TGS-REP key usage ambiguity (Kerberos)**: RFC 4120 says key usage 8 for TGS-REP encrypted with session key, but some KDCs use 9. The authenticator tries 8 first, falls back to 9.
|
||||
- **KDC_ERR_PREAUTH_REQUIRED handling (Kerberos)**: First AS-REQ without pre-auth gets error 25. The authenticator extracts supported etypes from the e-data (ETYPE-INFO2) and retries with pre-authentication.
|
||||
- **DER primitives in `auth::der`**: Core DER encoding/decoding helpers (`der_length`, `der_tlv`, `parse_der_length`, `parse_der_tlv`) live in `auth/der.rs` and are shared by `spnego.rs` and `kerberos/messages.rs`. Type-specific helpers (INTEGER, GeneralString, etc.) stay in their respective modules.
|
||||
|
||||
## Kerberos key design decisions (from end-to-end testing)
|
||||
|
||||
- **MS Kerberos OID (`1.2.840.48018.1.2.2`)**: Windows AD requires the Microsoft Kerberos OID in SPNEGO NegTokenInit, not the standard RFC 4120 OID. Both are included in mechTypes, with MS OID first.
|
||||
- **Key usage 11 for SPNEGO AP-REQ Authenticator**: Standard RFC 4120 uses key usage 7 for AP-REQ Authenticator encryption. Windows expects key usage 11 when the AP-REQ is wrapped in SPNEGO (per MS-KILE). Using 7 causes `KRB_AP_ERR_MODIFIED`.
|
||||
- **Session key etype detection**: The TGS-REQ requests AES-256, AES-128, and RC4 (preference order). The KDC picks the session key type from this list — it may differ from the ticket encryption type. The authenticator detects the actual etype from the TGS-REP `EncKDCRepPart.key.keytype` and uses the matching cipher for Authenticator encryption.
|
||||
- **Raw ticket pass-through**: The service ticket bytes must be sent to the SMB server exactly as received from the KDC. Re-encoding the ticket from parsed fields produces different DER and causes `KRB_AP_ERR_MODIFIED`. The `Ticket` struct carries `raw_bytes` for this.
|
||||
- **GSS-API wrapping**: The AP-REQ in SPNEGO NegTokenInit must include the GSS-API OID header (`0x60 len OID ap-req`), not just the raw AP-REQ bytes.
|
||||
- **Mutual authentication**: AP-REQ sets the mutual-required flag. The server returns an AP-REP (in SPNEGO NegTokenResp) containing a server sub-session key. The client decrypts the AP-REP (key usage 12) to extract this subkey, which becomes the SMB session key. This provides cryptographic proof that the server possesses the service key. The AP-REP may arrive in a `STATUS_SUCCESS` response (not always `STATUS_MORE_PROCESSING_REQUIRED`).
|
||||
|
||||
- **Credential cache (ccache) support**: `kerberos/ccache.rs` parses MIT Kerberos ccache files (v3 and v4). Supports loading cached TGTs (skip AS exchange, do TGS) and cached service tickets (skip both AS and TGS). Integrates via `Session::setup_kerberos_from_ccache()` and `KerberosAuthenticator::authenticate_from_ccache()`. `load_ccache()` reads from a path or `$KRB5CCNAME`.
|
||||
|
||||
## Known tech debt (Kerberos)
|
||||
|
||||
- ~~DER helpers duplicated between `spnego.rs` and `kerberos/messages.rs`~~ (resolved: shared `auth/der.rs`)
|
||||
- ~~`kerberos/authenticator.rs` mixes crypto wrappers with protocol flow~~ (resolved: `kerberos_encrypt`, `kerberos_decrypt`, `etype_from_i32`, and `generate_random_key` moved to `kerberos/crypto.rs`)
|
||||
- ~~`#![allow(rustdoc::broken_intra_doc_links)]` hack in `kerberos/messages.rs`~~ (resolved: ASN.1 context tags in doc comments wrapped in backticks)
|
||||
196
vendor/smb2/src/auth/der.rs
vendored
Normal file
196
vendor/smb2/src/auth/der.rs
vendored
Normal file
@@ -0,0 +1,196 @@
|
||||
//! Shared ASN.1/DER encoding and decoding primitives.
|
||||
//!
|
||||
//! These low-level helpers are used by both `spnego.rs` and `kerberos/messages.rs`
|
||||
//! to build and parse DER-encoded structures. Only the core TLV operations live
|
||||
//! here; type-specific helpers (INTEGER, GeneralString, etc.) stay in their
|
||||
//! respective modules.
|
||||
|
||||
use crate::Error;
|
||||
|
||||
/// Encode a DER length field.
|
||||
///
|
||||
/// - Lengths < 128 are encoded as a single byte.
|
||||
/// - Lengths < 256 are encoded as `0x81` followed by one byte.
|
||||
/// - Lengths < 65536 are encoded as `0x82` followed by two bytes (big-endian).
|
||||
pub(crate) fn der_length(len: usize) -> Vec<u8> {
|
||||
if len < 128 {
|
||||
vec![len as u8]
|
||||
} else if len < 256 {
|
||||
vec![0x81, len as u8]
|
||||
} else {
|
||||
vec![0x82, (len >> 8) as u8, (len & 0xff) as u8]
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap data in a DER TLV (tag-length-value).
|
||||
pub(crate) fn der_tlv(tag: u8, data: &[u8]) -> Vec<u8> {
|
||||
let mut out = vec![tag];
|
||||
out.extend_from_slice(&der_length(data.len()));
|
||||
out.extend_from_slice(data);
|
||||
out
|
||||
}
|
||||
|
||||
/// Parse a DER length field, returning `(length, bytes_consumed)`.
|
||||
pub(crate) fn parse_der_length(data: &[u8]) -> Result<(usize, usize), Error> {
|
||||
if data.is_empty() {
|
||||
return Err(Error::invalid_data("DER: truncated length"));
|
||||
}
|
||||
let first = data[0];
|
||||
if first < 128 {
|
||||
Ok((first as usize, 1))
|
||||
} else if first == 0x81 {
|
||||
if data.len() < 2 {
|
||||
return Err(Error::invalid_data("DER: truncated length (0x81)"));
|
||||
}
|
||||
Ok((data[1] as usize, 2))
|
||||
} else if first == 0x82 {
|
||||
if data.len() < 3 {
|
||||
return Err(Error::invalid_data("DER: truncated length (0x82)"));
|
||||
}
|
||||
let len = ((data[1] as usize) << 8) | (data[2] as usize);
|
||||
Ok((len, 3))
|
||||
} else if first == 0x83 {
|
||||
if data.len() < 4 {
|
||||
return Err(Error::invalid_data("DER: truncated length (0x83)"));
|
||||
}
|
||||
let len = ((data[1] as usize) << 16) | ((data[2] as usize) << 8) | (data[3] as usize);
|
||||
Ok((len, 4))
|
||||
} else {
|
||||
Err(Error::invalid_data(format!(
|
||||
"DER: unsupported length encoding: 0x{first:02x}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a DER TLV, returning `(tag, value_slice, total_bytes_consumed)`.
|
||||
pub(crate) fn parse_der_tlv(data: &[u8]) -> Result<(u8, &[u8], usize), Error> {
|
||||
if data.is_empty() {
|
||||
return Err(Error::invalid_data("DER: truncated TLV"));
|
||||
}
|
||||
let tag = data[0];
|
||||
let (len, len_bytes) = parse_der_length(&data[1..])?;
|
||||
let header_len = 1 + len_bytes;
|
||||
let total = header_len + len;
|
||||
if data.len() < total {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"DER: TLV truncated: need {total} bytes, have {}",
|
||||
data.len()
|
||||
)));
|
||||
}
|
||||
Ok((tag, &data[header_len..total], total))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// =======================================================================
|
||||
// DER length encoding
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn length_single_byte() {
|
||||
assert_eq!(der_length(0), vec![0x00]);
|
||||
assert_eq!(der_length(1), vec![0x01]);
|
||||
assert_eq!(der_length(127), vec![0x7f]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn length_two_byte() {
|
||||
assert_eq!(der_length(128), vec![0x81, 0x80]);
|
||||
assert_eq!(der_length(255), vec![0x81, 0xff]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn length_three_byte() {
|
||||
assert_eq!(der_length(256), vec![0x82, 0x01, 0x00]);
|
||||
assert_eq!(der_length(65535), vec![0x82, 0xff, 0xff]);
|
||||
assert_eq!(der_length(1000), vec![0x82, 0x03, 0xe8]);
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// DER TLV encoding
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn tlv_simple() {
|
||||
let result = der_tlv(0x04, &[0x01, 0x02]);
|
||||
assert_eq!(result, vec![0x04, 0x02, 0x01, 0x02]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tlv_empty() {
|
||||
let result = der_tlv(0x30, &[]);
|
||||
assert_eq!(result, vec![0x30, 0x00]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tlv_long_content() {
|
||||
let data = vec![0xaa; 200];
|
||||
let result = der_tlv(0x04, &data);
|
||||
assert_eq!(result[0], 0x04);
|
||||
assert_eq!(result[1], 0x81);
|
||||
assert_eq!(result[2], 200);
|
||||
assert_eq!(result.len(), 3 + 200);
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// DER length parsing
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn parse_length_single_byte() {
|
||||
let (len, consumed) = parse_der_length(&[0x05]).unwrap();
|
||||
assert_eq!(len, 5);
|
||||
assert_eq!(consumed, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_length_two_byte() {
|
||||
let (len, consumed) = parse_der_length(&[0x81, 0x80]).unwrap();
|
||||
assert_eq!(len, 128);
|
||||
assert_eq!(consumed, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_length_three_byte() {
|
||||
let (len, consumed) = parse_der_length(&[0x82, 0x01, 0x00]).unwrap();
|
||||
assert_eq!(len, 256);
|
||||
assert_eq!(consumed, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_length_four_byte() {
|
||||
let (len, consumed) = parse_der_length(&[0x83, 0x01, 0x00, 0x00]).unwrap();
|
||||
assert_eq!(len, 65536);
|
||||
assert_eq!(consumed, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_length_truncated() {
|
||||
assert!(parse_der_length(&[]).is_err());
|
||||
assert!(parse_der_length(&[0x81]).is_err());
|
||||
assert!(parse_der_length(&[0x82, 0x01]).is_err());
|
||||
assert!(parse_der_length(&[0x83, 0x01, 0x00]).is_err());
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// DER TLV parsing
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn parse_tlv_roundtrip() {
|
||||
let original = der_tlv(0x04, &[0xde, 0xad, 0xbe, 0xef]);
|
||||
let (tag, value, total) = parse_der_tlv(&original).unwrap();
|
||||
assert_eq!(tag, 0x04);
|
||||
assert_eq!(value, &[0xde, 0xad, 0xbe, 0xef]);
|
||||
assert_eq!(total, original.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tlv_truncated() {
|
||||
assert!(parse_der_tlv(&[]).is_err());
|
||||
// Tag present, length says 10 bytes but only 2 available
|
||||
assert!(parse_der_tlv(&[0x04, 0x0a, 0x01, 0x02]).is_err());
|
||||
}
|
||||
}
|
||||
1637
vendor/smb2/src/auth/kerberos/authenticator.rs
vendored
Normal file
1637
vendor/smb2/src/auth/kerberos/authenticator.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
447
vendor/smb2/src/auth/kerberos/ccache.rs
vendored
Normal file
447
vendor/smb2/src/auth/kerberos/ccache.rs
vendored
Normal file
@@ -0,0 +1,447 @@
|
||||
//! MIT Kerberos credential cache (ccache) file parser.
|
||||
//!
|
||||
//! Reads ccache files (v3 and v4) to extract cached TGTs and service tickets,
|
||||
//! enabling Kerberos authentication without a password when the user already
|
||||
//! has a valid ticket (for example, from `kinit`).
|
||||
//!
|
||||
//! References:
|
||||
//! - MIT Kerberos source: `lib/krb5/ccache/cc_file.c`
|
||||
//! - Format: version(2) + [header(v4)] + default_principal + credentials*
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use log::debug;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A parsed Kerberos credential cache.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CCache {
|
||||
/// File format version (3 or 4).
|
||||
pub version: u16,
|
||||
/// Default principal (typically the user who ran `kinit`).
|
||||
pub default_principal: CcachePrincipal,
|
||||
/// Cached credentials (TGTs and service tickets).
|
||||
pub credentials: Vec<CcacheCredential>,
|
||||
}
|
||||
|
||||
/// A principal name in the ccache.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CcachePrincipal {
|
||||
/// Name type (1 = KRB_NT_PRINCIPAL, 2 = KRB_NT_SRV_INST, etc.).
|
||||
pub name_type: u32,
|
||||
/// Kerberos realm.
|
||||
pub realm: String,
|
||||
/// Name components (for example, `["smbtest"]` or `["cifs", "server.domain.com"]`).
|
||||
pub components: Vec<String>,
|
||||
}
|
||||
|
||||
/// A single cached credential (ticket + metadata).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CcacheCredential {
|
||||
/// Client principal.
|
||||
pub client: CcachePrincipal,
|
||||
/// Server (service) principal.
|
||||
pub server: CcachePrincipal,
|
||||
/// Session key encryption type.
|
||||
pub key_etype: u16,
|
||||
/// Session key bytes.
|
||||
pub key_data: Vec<u8>,
|
||||
/// Time the ticket was issued (Unix timestamp).
|
||||
pub authtime: u32,
|
||||
/// Time the ticket becomes valid (Unix timestamp).
|
||||
pub starttime: u32,
|
||||
/// Time the ticket expires (Unix timestamp).
|
||||
pub endtime: u32,
|
||||
/// Time the ticket's renewable lifetime expires (Unix timestamp).
|
||||
pub renew_till: u32,
|
||||
/// Raw ticket bytes (DER-encoded Kerberos Ticket).
|
||||
pub ticket: Vec<u8>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Parsing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Read and parse a ccache file from a filesystem path.
|
||||
///
|
||||
/// Reads `$KRB5CCNAME` if `path` is `None`, falling back to
|
||||
/// `/tmp/krb5cc_<uid>` on Unix.
|
||||
pub fn load_ccache(path: Option<&std::path::Path>) -> Result<CCache> {
|
||||
let path = match path {
|
||||
Some(p) => p.to_path_buf(),
|
||||
None => {
|
||||
if let Ok(env_path) = std::env::var("KRB5CCNAME") {
|
||||
// Strip "FILE:" prefix if present.
|
||||
let p = env_path.strip_prefix("FILE:").unwrap_or(&env_path);
|
||||
std::path::PathBuf::from(p)
|
||||
} else {
|
||||
// Default: /tmp/krb5cc_<uid>
|
||||
return Err(Error::invalid_data(
|
||||
"ccache: no path specified and $KRB5CCNAME not set",
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let data = std::fs::read(&path).map_err(|e| {
|
||||
Error::invalid_data(format!("ccache: failed to read {}: {e}", path.display()))
|
||||
})?;
|
||||
|
||||
parse_ccache(&data)
|
||||
}
|
||||
|
||||
/// Parse a ccache file from raw bytes.
|
||||
pub fn parse_ccache(data: &[u8]) -> Result<CCache> {
|
||||
let mut pos = 0;
|
||||
|
||||
// Version: 2 bytes, big-endian. We support 0x0503 (v3) and 0x0504 (v4).
|
||||
if data.len() < 2 {
|
||||
return Err(Error::invalid_data("ccache: file too short for version"));
|
||||
}
|
||||
let version = read_u16(data, &mut pos)?;
|
||||
if version != 0x0503 && version != 0x0504 {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"ccache: unsupported version 0x{version:04x} (expected 0x0503 or 0x0504)"
|
||||
)));
|
||||
}
|
||||
|
||||
// V4 has a header section after the version.
|
||||
if version == 0x0504 {
|
||||
let header_len = read_u16(data, &mut pos)? as usize;
|
||||
if pos + header_len > data.len() {
|
||||
return Err(Error::invalid_data(
|
||||
"ccache: header extends past end of file",
|
||||
));
|
||||
}
|
||||
// Skip header tags (we don't need them).
|
||||
pos += header_len;
|
||||
}
|
||||
|
||||
// Default principal.
|
||||
let default_principal = read_principal(data, &mut pos)?;
|
||||
|
||||
// Credentials: read until EOF.
|
||||
let mut credentials = Vec::new();
|
||||
while pos < data.len() {
|
||||
match read_credential(data, &mut pos) {
|
||||
Ok(cred) => credentials.push(cred),
|
||||
Err(_) => break, // Treat parse errors at the end as EOF.
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
"ccache: parsed v{}, principal={}@{}, {} credentials",
|
||||
version & 0xFF,
|
||||
default_principal.components.join("/"),
|
||||
default_principal.realm,
|
||||
credentials.len()
|
||||
);
|
||||
|
||||
Ok(CCache {
|
||||
version,
|
||||
default_principal,
|
||||
credentials,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Lookup
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
impl CCache {
|
||||
/// Find a cached service ticket for the given SPN and realm.
|
||||
///
|
||||
/// Looks for a credential where the server principal matches
|
||||
/// `service/hostname@realm` (case-insensitive hostname comparison).
|
||||
pub fn find_service_ticket(
|
||||
&self,
|
||||
service: &str,
|
||||
hostname: &str,
|
||||
realm: &str,
|
||||
) -> Option<&CcacheCredential> {
|
||||
self.credentials.iter().find(|c| {
|
||||
c.server.realm.eq_ignore_ascii_case(realm)
|
||||
&& c.server.components.len() == 2
|
||||
&& c.server.components[0].eq_ignore_ascii_case(service)
|
||||
&& c.server.components[1].eq_ignore_ascii_case(hostname)
|
||||
})
|
||||
}
|
||||
|
||||
/// Find a cached TGT for the given realm.
|
||||
///
|
||||
/// Looks for a credential where the server principal is `krbtgt/REALM@REALM`.
|
||||
pub fn find_tgt(&self, realm: &str) -> Option<&CcacheCredential> {
|
||||
self.credentials.iter().find(|c| {
|
||||
c.server.realm.eq_ignore_ascii_case(realm)
|
||||
&& c.server.components.len() == 2
|
||||
&& c.server.components[0] == "krbtgt"
|
||||
&& c.server.components[1].eq_ignore_ascii_case(realm)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Binary reading helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn read_u8(data: &[u8], pos: &mut usize) -> Result<u8> {
|
||||
if *pos >= data.len() {
|
||||
return Err(Error::invalid_data("ccache: unexpected end of data"));
|
||||
}
|
||||
let val = data[*pos];
|
||||
*pos += 1;
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
fn read_u16(data: &[u8], pos: &mut usize) -> Result<u16> {
|
||||
if *pos + 2 > data.len() {
|
||||
return Err(Error::invalid_data("ccache: unexpected end of data"));
|
||||
}
|
||||
let val = u16::from_be_bytes([data[*pos], data[*pos + 1]]);
|
||||
*pos += 2;
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32> {
|
||||
if *pos + 4 > data.len() {
|
||||
return Err(Error::invalid_data("ccache: unexpected end of data"));
|
||||
}
|
||||
let val = u32::from_be_bytes([data[*pos], data[*pos + 1], data[*pos + 2], data[*pos + 3]]);
|
||||
*pos += 4;
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
fn read_bytes(data: &[u8], pos: &mut usize, len: usize) -> Result<Vec<u8>> {
|
||||
if *pos + len > data.len() {
|
||||
return Err(Error::invalid_data("ccache: unexpected end of data"));
|
||||
}
|
||||
let val = data[*pos..*pos + len].to_vec();
|
||||
*pos += len;
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
fn read_string(data: &[u8], pos: &mut usize) -> Result<String> {
|
||||
let len = read_u32(data, pos)? as usize;
|
||||
let bytes = read_bytes(data, pos, len)?;
|
||||
String::from_utf8(bytes).map_err(|_| Error::invalid_data("ccache: invalid UTF-8 in string"))
|
||||
}
|
||||
|
||||
fn read_principal(data: &[u8], pos: &mut usize) -> Result<CcachePrincipal> {
|
||||
let name_type = read_u32(data, pos)?;
|
||||
let num_components = read_u32(data, pos)?;
|
||||
let realm = read_string(data, pos)?;
|
||||
let mut components = Vec::with_capacity(num_components as usize);
|
||||
for _ in 0..num_components {
|
||||
components.push(read_string(data, pos)?);
|
||||
}
|
||||
Ok(CcachePrincipal {
|
||||
name_type,
|
||||
realm,
|
||||
components,
|
||||
})
|
||||
}
|
||||
|
||||
fn read_keyblock(data: &[u8], pos: &mut usize) -> Result<(u16, Vec<u8>)> {
|
||||
let enctype = read_u16(data, pos)?;
|
||||
let key_len = read_u32(data, pos)? as usize;
|
||||
let key_data = read_bytes(data, pos, key_len)?;
|
||||
Ok((enctype, key_data))
|
||||
}
|
||||
|
||||
fn read_credential(data: &[u8], pos: &mut usize) -> Result<CcacheCredential> {
|
||||
let client = read_principal(data, pos)?;
|
||||
let server = read_principal(data, pos)?;
|
||||
let (key_etype, key_data) = read_keyblock(data, pos)?;
|
||||
let authtime = read_u32(data, pos)?;
|
||||
let starttime = read_u32(data, pos)?;
|
||||
let endtime = read_u32(data, pos)?;
|
||||
let renew_till = read_u32(data, pos)?;
|
||||
let _is_skey = read_u8(data, pos)?;
|
||||
let _ticket_flags = read_u32(data, pos)?;
|
||||
|
||||
// Addresses (count + entries).
|
||||
let addr_count = read_u32(data, pos)?;
|
||||
for _ in 0..addr_count {
|
||||
let _addr_type = read_u16(data, pos)?;
|
||||
let addr_len = read_u32(data, pos)? as usize;
|
||||
*pos += addr_len; // skip address data
|
||||
}
|
||||
|
||||
// Auth data (count + entries).
|
||||
let authdata_count = read_u32(data, pos)?;
|
||||
for _ in 0..authdata_count {
|
||||
let _ad_type = read_u16(data, pos)?;
|
||||
let ad_len = read_u32(data, pos)? as usize;
|
||||
*pos += ad_len; // skip authdata
|
||||
}
|
||||
|
||||
// Ticket.
|
||||
let ticket_len = read_u32(data, pos)? as usize;
|
||||
let ticket = read_bytes(data, pos, ticket_len)?;
|
||||
|
||||
// Second ticket.
|
||||
let second_ticket_len = read_u32(data, pos)? as usize;
|
||||
let _second_ticket = read_bytes(data, pos, second_ticket_len)?;
|
||||
|
||||
Ok(CcacheCredential {
|
||||
client,
|
||||
server,
|
||||
key_etype,
|
||||
key_data,
|
||||
authtime,
|
||||
starttime,
|
||||
endtime,
|
||||
renew_till,
|
||||
ticket,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_v4_ccache_from_fixture() {
|
||||
let data = include_bytes!("../../../tests/fixtures/test.ccache");
|
||||
let ccache = parse_ccache(data).expect("failed to parse v4 ccache");
|
||||
|
||||
assert_eq!(ccache.version, 0x0504);
|
||||
assert_eq!(ccache.default_principal.realm, "TEST.LOCAL");
|
||||
assert_eq!(ccache.default_principal.components, vec!["smbtest"]);
|
||||
assert_eq!(ccache.credentials.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_v3_ccache_from_fixture() {
|
||||
let data = include_bytes!("../../../tests/fixtures/test_v3.ccache");
|
||||
let ccache = parse_ccache(data).expect("failed to parse v3 ccache");
|
||||
|
||||
assert_eq!(ccache.version, 0x0503);
|
||||
assert_eq!(ccache.default_principal.realm, "EXAMPLE.COM");
|
||||
assert_eq!(ccache.default_principal.components, vec!["user"]);
|
||||
assert_eq!(ccache.credentials.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tgt_credential_has_correct_fields() {
|
||||
let data = include_bytes!("../../../tests/fixtures/test.ccache");
|
||||
let ccache = parse_ccache(data).unwrap();
|
||||
|
||||
let tgt = &ccache.credentials[0];
|
||||
assert_eq!(tgt.client.realm, "TEST.LOCAL");
|
||||
assert_eq!(tgt.client.components, vec!["smbtest"]);
|
||||
assert_eq!(tgt.server.realm, "TEST.LOCAL");
|
||||
assert_eq!(tgt.server.components, vec!["krbtgt", "TEST.LOCAL"]);
|
||||
assert_eq!(tgt.key_etype, 23); // RC4-HMAC
|
||||
assert_eq!(tgt.key_data.len(), 16);
|
||||
assert_eq!(tgt.authtime, 1744100000);
|
||||
assert_eq!(tgt.endtime, 1744200000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_ticket_has_correct_fields() {
|
||||
let data = include_bytes!("../../../tests/fixtures/test.ccache");
|
||||
let ccache = parse_ccache(data).unwrap();
|
||||
|
||||
let svc = &ccache.credentials[1];
|
||||
assert_eq!(svc.server.components, vec!["cifs", "server.test.local"]);
|
||||
assert_eq!(svc.key_etype, 23);
|
||||
assert_eq!(svc.key_data, (16u8..32).collect::<Vec<_>>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_tgt_by_realm() {
|
||||
let data = include_bytes!("../../../tests/fixtures/test.ccache");
|
||||
let ccache = parse_ccache(data).unwrap();
|
||||
|
||||
let tgt = ccache.find_tgt("TEST.LOCAL");
|
||||
assert!(tgt.is_some());
|
||||
assert_eq!(tgt.unwrap().server.components[0], "krbtgt");
|
||||
|
||||
assert!(ccache.find_tgt("OTHER.REALM").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_service_ticket_by_spn() {
|
||||
let data = include_bytes!("../../../tests/fixtures/test.ccache");
|
||||
let ccache = parse_ccache(data).unwrap();
|
||||
|
||||
let svc = ccache.find_service_ticket("cifs", "server.test.local", "TEST.LOCAL");
|
||||
assert!(svc.is_some());
|
||||
assert_eq!(svc.unwrap().key_data, (16u8..32).collect::<Vec<_>>());
|
||||
|
||||
// Case-insensitive hostname.
|
||||
assert!(ccache
|
||||
.find_service_ticket("cifs", "SERVER.TEST.LOCAL", "TEST.LOCAL")
|
||||
.is_some());
|
||||
|
||||
// Wrong hostname.
|
||||
assert!(ccache
|
||||
.find_service_ticket("cifs", "other.test.local", "TEST.LOCAL")
|
||||
.is_none());
|
||||
|
||||
// Wrong service.
|
||||
assert!(ccache
|
||||
.find_service_ticket("ldap", "server.test.local", "TEST.LOCAL")
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_tgt_case_insensitive() {
|
||||
let data = include_bytes!("../../../tests/fixtures/test.ccache");
|
||||
let ccache = parse_ccache(data).unwrap();
|
||||
|
||||
assert!(ccache.find_tgt("test.local").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn v3_ccache_tgt_has_aes256_key() {
|
||||
let data = include_bytes!("../../../tests/fixtures/test_v3.ccache");
|
||||
let ccache = parse_ccache(data).unwrap();
|
||||
|
||||
let tgt = ccache.find_tgt("EXAMPLE.COM").unwrap();
|
||||
assert_eq!(tgt.key_etype, 18); // AES-256
|
||||
assert_eq!(tgt.key_data.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_unsupported_version() {
|
||||
let data = [0x05, 0x02]; // v2
|
||||
let result = parse_ccache(&data);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("unsupported version"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_truncated_file() {
|
||||
let result = parse_ccache(&[0x05]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_credentials_list() {
|
||||
// A valid ccache with just a version + principal + no credentials
|
||||
let mut data = vec![0x05, 0x04, 0x00, 0x00]; // v4, no header
|
||||
// Principal: type=1, components=1, realm="R", component="u"
|
||||
data.extend_from_slice(&[0, 0, 0, 1]); // name_type
|
||||
data.extend_from_slice(&[0, 0, 0, 1]); // num_components
|
||||
data.extend_from_slice(&[0, 0, 0, 1]); // realm length
|
||||
data.push(b'R');
|
||||
data.extend_from_slice(&[0, 0, 0, 1]); // component length
|
||||
data.push(b'u');
|
||||
|
||||
let ccache = parse_ccache(&data).unwrap();
|
||||
assert_eq!(ccache.credentials.len(), 0);
|
||||
assert_eq!(ccache.default_principal.realm, "R");
|
||||
assert_eq!(ccache.default_principal.components, vec!["u"]);
|
||||
}
|
||||
}
|
||||
1329
vendor/smb2/src/auth/kerberos/crypto.rs
vendored
Normal file
1329
vendor/smb2/src/auth/kerberos/crypto.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
890
vendor/smb2/src/auth/kerberos/kdc.rs
vendored
Normal file
890
vendor/smb2/src/auth/kerberos/kdc.rs
vendored
Normal file
@@ -0,0 +1,890 @@
|
||||
//! KDC (Key Distribution Center) transport client.
|
||||
//!
|
||||
//! Sends AS-REQ and TGS-REQ messages to a Kerberos KDC on port 88.
|
||||
//! Tries UDP first (no framing), falls back to TCP (4-byte big-endian
|
||||
//! length prefix) when the response indicates KRB_ERR_RESPONSE_TOO_BIG
|
||||
//! (error code 52).
|
||||
//!
|
||||
//! Transport details per RFC 4120 section 7.2 and MS-KILE section 2.1:
|
||||
//! - UDP: raw DER bytes, no length prefix, max 65535 bytes
|
||||
//! - TCP: 4-byte big-endian length prefix, then DER bytes
|
||||
//! - Retry: up to 3 attempts with exponential backoff (1s, 2s, 4s)
|
||||
//!
|
||||
//! The functions here are transport-only: they send raw bytes and return
|
||||
//! raw bytes. No ASN.1 parsing beyond detecting error code 52 in the
|
||||
//! UDP-to-TCP fallback path.
|
||||
|
||||
use log::{debug, trace, warn};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpStream, UdpSocket};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
/// Default Kerberos port (RFC 4120).
|
||||
const KERBEROS_PORT: u16 = 88;
|
||||
|
||||
/// Maximum UDP receive buffer size.
|
||||
const UDP_MAX_SIZE: usize = 65535;
|
||||
|
||||
/// KRB_ERR_RESPONSE_TOO_BIG error code (RFC 4120 section 7.2.1).
|
||||
const KRB_ERR_RESPONSE_TOO_BIG: u32 = 52;
|
||||
|
||||
/// Maximum TCP frame size we accept (1 MB, generous for Kerberos).
|
||||
const MAX_KDC_FRAME_SIZE: usize = 1024 * 1024;
|
||||
|
||||
/// Number of retry attempts per transport.
|
||||
const MAX_RETRIES: u32 = 3;
|
||||
|
||||
/// Base retry delay (doubles each attempt).
|
||||
const RETRY_BASE_DELAY: Duration = Duration::from_secs(1);
|
||||
|
||||
/// Configuration for connecting to a KDC.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KdcConfig {
|
||||
/// KDC address (host:port or just host, defaults to port 88).
|
||||
pub address: String,
|
||||
/// Connection/request timeout.
|
||||
pub timeout: Duration,
|
||||
}
|
||||
|
||||
/// Resolve the KDC address to include a port if not specified.
|
||||
fn resolve_address(address: &str) -> String {
|
||||
if address.contains(':') {
|
||||
address.to_string()
|
||||
} else {
|
||||
format!("{}:{}", address, KERBEROS_PORT)
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a Kerberos message to the KDC and receive the response.
|
||||
///
|
||||
/// Tries UDP first. If the response indicates the message was too
|
||||
/// large for UDP (KRB_ERR_RESPONSE_TOO_BIG), retries with TCP.
|
||||
///
|
||||
/// UDP framing: raw DER bytes, no length prefix.
|
||||
/// TCP framing: 4-byte big-endian length prefix, then DER bytes.
|
||||
pub async fn send_to_kdc(config: &KdcConfig, message: &[u8]) -> Result<Vec<u8>> {
|
||||
let addr = resolve_address(&config.address);
|
||||
debug!("kdc: sending {} bytes to {}", message.len(), addr);
|
||||
|
||||
// Try UDP first.
|
||||
match send_udp(&addr, message, config.timeout).await {
|
||||
Ok(response) => {
|
||||
if is_response_too_big(&response) {
|
||||
debug!("kdc: got KRB_ERR_RESPONSE_TOO_BIG, retrying with TCP");
|
||||
send_tcp(&addr, message, config.timeout).await
|
||||
} else {
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("kdc: UDP failed ({}), falling back to TCP", e);
|
||||
send_tcp(&addr, message, config.timeout).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a Kerberos message via UDP.
|
||||
async fn send_udp(addr: &str, message: &[u8], timeout: Duration) -> Result<Vec<u8>> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await.map_err(Error::Io)?;
|
||||
|
||||
let mut last_err = None;
|
||||
|
||||
for attempt in 0..MAX_RETRIES {
|
||||
if attempt > 0 {
|
||||
let delay = RETRY_BASE_DELAY * 2u32.pow(attempt - 1);
|
||||
debug!("kdc: UDP retry {} after {:?}", attempt, delay);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
|
||||
// Send the raw DER bytes (no framing for UDP).
|
||||
match tokio::time::timeout(timeout, socket.send_to(message, addr)).await {
|
||||
Ok(Ok(n)) => {
|
||||
trace!("kdc: UDP sent {} bytes", n);
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
last_err = Some(Error::Io(e));
|
||||
continue;
|
||||
}
|
||||
Err(_) => {
|
||||
last_err = Some(Error::Timeout);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Receive the response.
|
||||
let mut buf = vec![0u8; UDP_MAX_SIZE];
|
||||
match tokio::time::timeout(timeout, socket.recv_from(&mut buf)).await {
|
||||
Ok(Ok((n, _src))) => {
|
||||
trace!("kdc: UDP received {} bytes", n);
|
||||
buf.truncate(n);
|
||||
return Ok(buf);
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
last_err = Some(Error::Io(e));
|
||||
}
|
||||
Err(_) => {
|
||||
last_err = Some(Error::Timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or(Error::Timeout))
|
||||
}
|
||||
|
||||
/// Send a Kerberos message via TCP.
|
||||
async fn send_tcp(addr: &str, message: &[u8], timeout: Duration) -> Result<Vec<u8>> {
|
||||
let mut last_err = None;
|
||||
|
||||
for attempt in 0..MAX_RETRIES {
|
||||
if attempt > 0 {
|
||||
let delay = RETRY_BASE_DELAY * 2u32.pow(attempt - 1);
|
||||
debug!("kdc: TCP retry {} after {:?}", attempt, delay);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
|
||||
match send_tcp_once(addr, message, timeout).await {
|
||||
Ok(response) => return Ok(response),
|
||||
Err(e) => {
|
||||
last_err = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_err.unwrap_or(Error::Timeout))
|
||||
}
|
||||
|
||||
/// Single TCP send/receive attempt.
|
||||
async fn send_tcp_once(addr: &str, message: &[u8], timeout: Duration) -> Result<Vec<u8>> {
|
||||
// Connect with timeout.
|
||||
let mut stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
|
||||
.await
|
||||
.map_err(|_| Error::Timeout)?
|
||||
.map_err(Error::Io)?;
|
||||
|
||||
// Disable Nagle for lower latency.
|
||||
stream.set_nodelay(true).map_err(Error::Io)?;
|
||||
|
||||
// Send: 4-byte big-endian length prefix + DER bytes.
|
||||
let len = message.len() as u32;
|
||||
let len_bytes = len.to_be_bytes();
|
||||
|
||||
tokio::time::timeout(timeout, async {
|
||||
stream.write_all(&len_bytes).await.map_err(Error::Io)?;
|
||||
stream.write_all(message).await.map_err(Error::Io)?;
|
||||
stream.flush().await.map_err(Error::Io)?;
|
||||
trace!("kdc: TCP sent {} bytes", message.len());
|
||||
Ok::<(), Error>(())
|
||||
})
|
||||
.await
|
||||
.map_err(|_| Error::Timeout)??;
|
||||
|
||||
// Receive: 4-byte big-endian length prefix.
|
||||
let mut len_buf = [0u8; 4];
|
||||
tokio::time::timeout(timeout, stream.read_exact(&mut len_buf))
|
||||
.await
|
||||
.map_err(|_| Error::Timeout)?
|
||||
.map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||||
Error::Disconnected
|
||||
} else {
|
||||
Error::Io(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let resp_len = u32::from_be_bytes(len_buf) as usize;
|
||||
if resp_len > MAX_KDC_FRAME_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"KDC TCP response length {} exceeds maximum {}",
|
||||
resp_len, MAX_KDC_FRAME_SIZE
|
||||
)));
|
||||
}
|
||||
|
||||
// Read the response body.
|
||||
let mut buf = vec![0u8; resp_len];
|
||||
tokio::time::timeout(timeout, stream.read_exact(&mut buf))
|
||||
.await
|
||||
.map_err(|_| Error::Timeout)?
|
||||
.map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||||
Error::Disconnected
|
||||
} else {
|
||||
Error::Io(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
trace!("kdc: TCP received {} bytes", resp_len);
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
/// Detect KRB_ERR_RESPONSE_TOO_BIG (error code 52) in a KRB-ERROR response.
|
||||
///
|
||||
/// KRB-ERROR is APPLICATION [30] (tag 0x7e). We parse just enough DER
|
||||
/// to extract the error-code field (context tag [6]) without a full
|
||||
/// ASN.1 parser.
|
||||
fn is_response_too_big(response: &[u8]) -> bool {
|
||||
// KRB-ERROR starts with APPLICATION [30] = 0x7e.
|
||||
if response.is_empty() || response[0] != 0x7e {
|
||||
return false;
|
||||
}
|
||||
|
||||
match extract_krb_error_code(response) {
|
||||
Some(code) => code == KRB_ERR_RESPONSE_TOO_BIG,
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the error-code from a KRB-ERROR message.
|
||||
///
|
||||
/// KRB-ERROR structure (simplified DER):
|
||||
/// ```text
|
||||
/// APPLICATION [30] {
|
||||
/// SEQUENCE {
|
||||
/// [0] pvno INTEGER,
|
||||
/// [1] msg-type INTEGER,
|
||||
/// [2] ctime (optional),
|
||||
/// [3] cusec (optional),
|
||||
/// [4] stime,
|
||||
/// [5] susec,
|
||||
/// [6] error-code INTEGER, <-- we want this
|
||||
/// ...
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
fn extract_krb_error_code(data: &[u8]) -> Option<u32> {
|
||||
let mut pos = 0;
|
||||
|
||||
// Skip APPLICATION [30] tag.
|
||||
if pos >= data.len() || data[pos] != 0x7e {
|
||||
return None;
|
||||
}
|
||||
pos += 1;
|
||||
pos = skip_der_length(data, pos)?;
|
||||
|
||||
// Skip SEQUENCE tag (0x30).
|
||||
if pos >= data.len() || data[pos] != 0x30 {
|
||||
return None;
|
||||
}
|
||||
pos += 1;
|
||||
pos = skip_der_length(data, pos)?;
|
||||
|
||||
// Now iterate through context-tagged fields until we find [6].
|
||||
loop {
|
||||
if pos >= data.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let tag = data[pos];
|
||||
// Context tags are 0xa0..0xbf for constructed.
|
||||
if tag & 0xe0 != 0xa0 {
|
||||
return None;
|
||||
}
|
||||
let tag_num = tag & 0x1f;
|
||||
pos += 1;
|
||||
|
||||
let (field_len, new_pos) = read_der_length(data, pos)?;
|
||||
let field_end = new_pos + field_len;
|
||||
|
||||
if tag_num == 6 {
|
||||
// This field contains an INTEGER with the error code.
|
||||
return parse_der_integer(data, new_pos);
|
||||
}
|
||||
|
||||
pos = field_end;
|
||||
}
|
||||
}
|
||||
|
||||
/// Skip a DER length field and return the position after it.
|
||||
fn skip_der_length(data: &[u8], pos: usize) -> Option<usize> {
|
||||
let (_len, new_pos) = read_der_length(data, pos)?;
|
||||
Some(new_pos)
|
||||
}
|
||||
|
||||
/// Read a DER length field, returning (length, position_after_length).
|
||||
fn read_der_length(data: &[u8], pos: usize) -> Option<(usize, usize)> {
|
||||
if pos >= data.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let first = data[pos];
|
||||
match first.cmp(&0x80) {
|
||||
std::cmp::Ordering::Less => {
|
||||
// Short form: length is the byte itself.
|
||||
Some((first as usize, pos + 1))
|
||||
}
|
||||
std::cmp::Ordering::Equal => {
|
||||
// Indefinite length, not used in DER.
|
||||
None
|
||||
}
|
||||
std::cmp::Ordering::Greater => {
|
||||
// Long form: first byte & 0x7f = number of subsequent length bytes.
|
||||
let num_bytes = (first & 0x7f) as usize;
|
||||
if num_bytes > 4 || pos + 1 + num_bytes > data.len() {
|
||||
return None;
|
||||
}
|
||||
let mut length: usize = 0;
|
||||
for i in 0..num_bytes {
|
||||
length = (length << 8) | (data[pos + 1 + i] as usize);
|
||||
}
|
||||
Some((length, pos + 1 + num_bytes))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a DER INTEGER at the given position, returning its value as u32.
|
||||
fn parse_der_integer(data: &[u8], pos: usize) -> Option<u32> {
|
||||
if pos >= data.len() || data[pos] != 0x02 {
|
||||
return None;
|
||||
}
|
||||
let (len, val_pos) = read_der_length(data, pos + 1)?;
|
||||
if val_pos + len > data.len() || len == 0 || len > 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut value: u32 = 0;
|
||||
for i in 0..len {
|
||||
value = (value << 8) | (data[val_pos + i] as u32);
|
||||
}
|
||||
Some(value)
|
||||
}
|
||||
|
||||
/// Discover KDC addresses for a realm via DNS SRV records.
|
||||
///
|
||||
/// Looks up `_kerberos._udp.{realm}` and `_kerberos._tcp.{realm}`.
|
||||
/// Returns addresses sorted by priority.
|
||||
///
|
||||
/// For now, this is a placeholder -- initial implementation uses
|
||||
/// the hardcoded address from KdcConfig. DNS SRV discovery will
|
||||
/// be added in a future version.
|
||||
pub async fn discover_kdc(_realm: &str) -> Vec<String> {
|
||||
// Placeholder: DNS SRV lookup not yet implemented.
|
||||
// Callers should use KdcConfig.address directly.
|
||||
debug!("kdc: DNS SRV discovery not yet implemented, returning empty list");
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
// ── DER parsing tests ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn read_der_length_short_form() {
|
||||
assert_eq!(read_der_length(&[0x05], 0), Some((5, 1)));
|
||||
assert_eq!(read_der_length(&[0x7f], 0), Some((127, 1)));
|
||||
assert_eq!(read_der_length(&[0x00], 0), Some((0, 1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_der_length_long_form_one_byte() {
|
||||
// 0x81, 0x80 = 128 bytes
|
||||
assert_eq!(read_der_length(&[0x81, 0x80], 0), Some((128, 2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_der_length_long_form_two_bytes() {
|
||||
// 0x82, 0x01, 0x00 = 256 bytes
|
||||
assert_eq!(read_der_length(&[0x82, 0x01, 0x00], 0), Some((256, 3)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_der_length_indefinite_returns_none() {
|
||||
assert_eq!(read_der_length(&[0x80], 0), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_der_length_truncated_returns_none() {
|
||||
// Says 2 length bytes follow but only 1 is present.
|
||||
assert_eq!(read_der_length(&[0x82, 0x01], 0), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_der_integer_single_byte() {
|
||||
// INTEGER tag 0x02, length 1, value 52.
|
||||
assert_eq!(parse_der_integer(&[0x02, 0x01, 0x34], 0), Some(52));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_der_integer_two_bytes() {
|
||||
// INTEGER tag 0x02, length 2, value 0x0100 = 256.
|
||||
assert_eq!(parse_der_integer(&[0x02, 0x02, 0x01, 0x00], 0), Some(256));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_der_integer_not_integer_tag() {
|
||||
assert_eq!(parse_der_integer(&[0x03, 0x01, 0x34], 0), None);
|
||||
}
|
||||
|
||||
// ── KRB-ERROR detection tests ──────────────────────────────────
|
||||
|
||||
/// Build a minimal KRB-ERROR with the given error code.
|
||||
///
|
||||
/// This constructs a valid DER-encoded KRB-ERROR with fields:
|
||||
/// [0] pvno = 5, [1] msg-type = 30, [4] stime, [5] susec = 0,
|
||||
/// [6] error-code = the given code.
|
||||
fn build_krb_error(error_code: u32) -> Vec<u8> {
|
||||
// Helper: wrap value in context tag.
|
||||
fn context_tag(tag_num: u8, contents: &[u8]) -> Vec<u8> {
|
||||
let mut out = vec![0xa0 | tag_num];
|
||||
push_der_length(&mut out, contents.len());
|
||||
out.extend_from_slice(contents);
|
||||
out
|
||||
}
|
||||
|
||||
// Helper: encode a DER INTEGER.
|
||||
fn der_integer(value: u32) -> Vec<u8> {
|
||||
// Encode as minimal bytes.
|
||||
let bytes = if value == 0 {
|
||||
vec![0x00]
|
||||
} else if value < 0x80 {
|
||||
vec![value as u8]
|
||||
} else if value < 0x8000 {
|
||||
vec![(value >> 8) as u8, (value & 0xff) as u8]
|
||||
} else if value < 0x800000 {
|
||||
vec![
|
||||
(value >> 16) as u8,
|
||||
(value >> 8) as u8,
|
||||
(value & 0xff) as u8,
|
||||
]
|
||||
} else {
|
||||
vec![
|
||||
(value >> 24) as u8,
|
||||
(value >> 16) as u8,
|
||||
(value >> 8) as u8,
|
||||
(value & 0xff) as u8,
|
||||
]
|
||||
};
|
||||
let mut out = vec![0x02];
|
||||
push_der_length(&mut out, bytes.len());
|
||||
out.extend_from_slice(&bytes);
|
||||
out
|
||||
}
|
||||
|
||||
fn push_der_length(out: &mut Vec<u8>, len: usize) {
|
||||
if len < 0x80 {
|
||||
out.push(len as u8);
|
||||
} else if len < 0x100 {
|
||||
out.push(0x81);
|
||||
out.push(len as u8);
|
||||
} else {
|
||||
out.push(0x82);
|
||||
out.push((len >> 8) as u8);
|
||||
out.push((len & 0xff) as u8);
|
||||
}
|
||||
}
|
||||
|
||||
// Build the SEQUENCE contents.
|
||||
let pvno = context_tag(0, &der_integer(5));
|
||||
let msg_type = context_tag(1, &der_integer(30));
|
||||
// Skip [2] ctime and [3] cusec (optional).
|
||||
// [4] stime: GeneralizedTime "20250101000000Z"
|
||||
let stime_val = b"20250101000000Z";
|
||||
let mut stime_der = vec![0x18]; // GeneralizedTime tag
|
||||
push_der_length(&mut stime_der, stime_val.len());
|
||||
stime_der.extend_from_slice(stime_val);
|
||||
let stime = context_tag(4, &stime_der);
|
||||
let susec = context_tag(5, &der_integer(0));
|
||||
let error_code_field = context_tag(6, &der_integer(error_code));
|
||||
|
||||
let mut seq_contents = Vec::new();
|
||||
seq_contents.extend_from_slice(&pvno);
|
||||
seq_contents.extend_from_slice(&msg_type);
|
||||
seq_contents.extend_from_slice(&stime);
|
||||
seq_contents.extend_from_slice(&susec);
|
||||
seq_contents.extend_from_slice(&error_code_field);
|
||||
|
||||
// Wrap in SEQUENCE.
|
||||
let mut seq = vec![0x30];
|
||||
push_der_length(&mut seq, seq_contents.len());
|
||||
seq.extend_from_slice(&seq_contents);
|
||||
|
||||
// Wrap in APPLICATION [30].
|
||||
let mut msg = vec![0x7e];
|
||||
push_der_length(&mut msg, seq.len());
|
||||
msg.extend_from_slice(&seq);
|
||||
|
||||
msg
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_response_too_big_detects_error_52() {
|
||||
let error = build_krb_error(KRB_ERR_RESPONSE_TOO_BIG);
|
||||
assert!(is_response_too_big(&error));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_response_too_big_ignores_other_errors() {
|
||||
// Error code 6 = KDC_ERR_C_PRINCIPAL_UNKNOWN
|
||||
let error = build_krb_error(6);
|
||||
assert!(!is_response_too_big(&error));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_response_too_big_ignores_non_error_messages() {
|
||||
// AS-REP starts with APPLICATION [11] = 0x6b
|
||||
assert!(!is_response_too_big(&[0x6b, 0x03, 0x30, 0x01, 0x00]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_response_too_big_handles_empty_response() {
|
||||
assert!(!is_response_too_big(&[]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_response_too_big_handles_truncated_response() {
|
||||
// Just the APPLICATION tag and nothing else.
|
||||
assert!(!is_response_too_big(&[0x7e]));
|
||||
assert!(!is_response_too_big(&[0x7e, 0x00]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_error_code_from_valid_krb_error() {
|
||||
let error = build_krb_error(25);
|
||||
assert_eq!(extract_krb_error_code(&error), Some(25));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_error_code_returns_none_for_non_error() {
|
||||
assert_eq!(
|
||||
extract_krb_error_code(&[0x6b, 0x03, 0x30, 0x01, 0x00]),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
// ── Address resolution tests ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn resolve_address_adds_default_port() {
|
||||
assert_eq!(resolve_address("kdc.example.com"), "kdc.example.com:88");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_address_preserves_explicit_port() {
|
||||
assert_eq!(
|
||||
resolve_address("kdc.example.com:8888"),
|
||||
"kdc.example.com:8888"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_address_ip_no_port() {
|
||||
assert_eq!(resolve_address("10.0.0.1"), "10.0.0.1:88");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_address_ip_with_port() {
|
||||
assert_eq!(resolve_address("10.0.0.1:88"), "10.0.0.1:88");
|
||||
}
|
||||
|
||||
// ── UDP transport tests ────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_send_receive() {
|
||||
// Set up a mock KDC that echoes the request back.
|
||||
let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server.local_addr().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; UDP_MAX_SIZE];
|
||||
let (n, src) = server.recv_from(&mut buf).await.unwrap();
|
||||
// Echo back the message.
|
||||
server.send_to(&buf[..n], src).await.unwrap();
|
||||
});
|
||||
|
||||
let message = b"test-kerberos-message";
|
||||
let result = send_udp(&server_addr.to_string(), message, Duration::from_secs(5)).await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"UDP send/receive failed: {:?}",
|
||||
result.err()
|
||||
);
|
||||
assert_eq!(result.unwrap(), message);
|
||||
|
||||
server_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_timeout_on_no_response() {
|
||||
// Bind a server socket but never read from it.
|
||||
let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server.local_addr().unwrap();
|
||||
|
||||
// Use very short timeout and only 1 retry attempt to keep test fast.
|
||||
// We can't change MAX_RETRIES, but we use a very short timeout so
|
||||
// all 3 retries finish quickly.
|
||||
let result = send_udp(
|
||||
&server_addr.to_string(),
|
||||
b"hello",
|
||||
Duration::from_millis(50),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
matches!(result.as_ref().unwrap_err(), Error::Timeout),
|
||||
"expected Timeout, got: {:?}",
|
||||
result.unwrap_err()
|
||||
);
|
||||
|
||||
drop(server);
|
||||
}
|
||||
|
||||
// ── TCP transport tests ────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_send_receive() {
|
||||
// Set up a mock KDC that reads a length-prefixed message and
|
||||
// sends back a length-prefixed response.
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
|
||||
// Read 4-byte length prefix.
|
||||
let mut len_buf = [0u8; 4];
|
||||
stream.read_exact(&mut len_buf).await.unwrap();
|
||||
let msg_len = u32::from_be_bytes(len_buf) as usize;
|
||||
|
||||
// Read the message body.
|
||||
let mut msg = vec![0u8; msg_len];
|
||||
stream.read_exact(&mut msg).await.unwrap();
|
||||
|
||||
// Echo back with length prefix.
|
||||
let response = b"kdc-response";
|
||||
let resp_len = (response.len() as u32).to_be_bytes();
|
||||
stream.write_all(&resp_len).await.unwrap();
|
||||
stream.write_all(response).await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
});
|
||||
|
||||
let result = send_tcp(&addr.to_string(), b"test-request", Duration::from_secs(5)).await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"TCP send/receive failed: {:?}",
|
||||
result.err()
|
||||
);
|
||||
assert_eq!(result.unwrap(), b"kdc-response");
|
||||
|
||||
server_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_timeout_on_no_response() {
|
||||
// Set up a server that accepts but never responds.
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (stream, _) = listener.accept().await.unwrap();
|
||||
// Hold the connection open but never respond.
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
drop(stream);
|
||||
});
|
||||
|
||||
let result = send_tcp_once(&addr.to_string(), b"hello", Duration::from_millis(100)).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, Error::Timeout),
|
||||
"expected Timeout, got: {err}"
|
||||
);
|
||||
|
||||
server_task.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_truncated_response() {
|
||||
// Server sends a length prefix saying 100 bytes, then disconnects.
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
|
||||
// Read the request (don't care about contents).
|
||||
let mut len_buf = [0u8; 4];
|
||||
let _ = stream.read_exact(&mut len_buf).await;
|
||||
let msg_len = u32::from_be_bytes(len_buf) as usize;
|
||||
let mut discard = vec![0u8; msg_len];
|
||||
let _ = stream.read_exact(&mut discard).await;
|
||||
|
||||
// Send response with length 100 but only 5 bytes of data, then close.
|
||||
let resp_len = 100u32.to_be_bytes();
|
||||
stream.write_all(&resp_len).await.unwrap();
|
||||
stream
|
||||
.write_all(&[0x01, 0x02, 0x03, 0x04, 0x05])
|
||||
.await
|
||||
.unwrap();
|
||||
stream.shutdown().await.unwrap();
|
||||
});
|
||||
|
||||
let result = send_tcp_once(&addr.to_string(), b"hello", Duration::from_secs(5)).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, Error::Disconnected),
|
||||
"expected Disconnected for truncated response, got: {err}"
|
||||
);
|
||||
|
||||
server_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tcp_oversized_length_rejected() {
|
||||
// Server sends a length prefix larger than MAX_KDC_FRAME_SIZE.
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
|
||||
// Read request.
|
||||
let mut len_buf = [0u8; 4];
|
||||
let _ = stream.read_exact(&mut len_buf).await;
|
||||
let msg_len = u32::from_be_bytes(len_buf) as usize;
|
||||
let mut discard = vec![0u8; msg_len];
|
||||
let _ = stream.read_exact(&mut discard).await;
|
||||
|
||||
// Send absurdly large length.
|
||||
let resp_len = (MAX_KDC_FRAME_SIZE as u32 + 1).to_be_bytes();
|
||||
stream.write_all(&resp_len).await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
});
|
||||
|
||||
let result = send_tcp_once(&addr.to_string(), b"hello", Duration::from_secs(5)).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
let err_str = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_str.contains("exceeds maximum"),
|
||||
"expected 'exceeds maximum' error, got: {err_str}"
|
||||
);
|
||||
|
||||
server_task.abort();
|
||||
}
|
||||
|
||||
// ── send_to_kdc tests ──────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_to_kdc_udp_success() {
|
||||
// Set up a UDP mock KDC.
|
||||
let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
|
||||
let server_addr = server.local_addr().unwrap();
|
||||
|
||||
let server_task = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; UDP_MAX_SIZE];
|
||||
let (n, src) = server.recv_from(&mut buf).await.unwrap();
|
||||
// Respond with a fake AS-REP (not a KRB-ERROR).
|
||||
let response = b"\x6b\x05\x30\x03\x02\x01\x05"; // Fake AS-REP-like
|
||||
server.send_to(response, src).await.unwrap();
|
||||
drop(buf[..n].to_vec()); // acknowledge we received
|
||||
});
|
||||
|
||||
let config = KdcConfig {
|
||||
address: server_addr.to_string(),
|
||||
timeout: Duration::from_secs(5),
|
||||
};
|
||||
|
||||
let result = send_to_kdc(&config, b"as-req").await;
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), b"\x6b\x05\x30\x03\x02\x01\x05");
|
||||
|
||||
server_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_to_kdc_udp_too_big_falls_back_to_tcp() {
|
||||
// Set up a UDP server that returns KRB_ERR_RESPONSE_TOO_BIG
|
||||
// and a TCP server that returns a real response. The fallback
|
||||
// path uses one `KdcConfig.address`, so both servers must share
|
||||
// a port.
|
||||
//
|
||||
// Bind TCP first (more restrictive) and then UDP to its port.
|
||||
// On Windows Server, the OS port allocator can hand out an
|
||||
// ephemeral port that's in an excluded range for the other
|
||||
// protocol (WSAEACCES / 10013). Retry a few times if so;
|
||||
// a fresh `:0` lottery picks a different port each attempt.
|
||||
let (udp_server, tcp_listener) = {
|
||||
let mut last_err: Option<std::io::Error> = None;
|
||||
let mut bound = None;
|
||||
for _ in 0..10 {
|
||||
let tcp = match TcpListener::bind("127.0.0.1:0").await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
last_err = Some(e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let port = tcp.local_addr().unwrap().port();
|
||||
match UdpSocket::bind(format!("127.0.0.1:{port}")).await {
|
||||
Ok(udp) => {
|
||||
bound = Some((udp, tcp));
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
last_err = Some(e);
|
||||
// TCP listener drops here; try a new port.
|
||||
}
|
||||
}
|
||||
}
|
||||
bound.unwrap_or_else(|| {
|
||||
panic!("could not co-bind UDP+TCP on a shared loopback port in 10 attempts: {last_err:?}")
|
||||
})
|
||||
};
|
||||
let udp_addr = udp_server.local_addr().unwrap();
|
||||
|
||||
let udp_task = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; UDP_MAX_SIZE];
|
||||
let (_, src) = udp_server.recv_from(&mut buf).await.unwrap();
|
||||
let error = build_krb_error(KRB_ERR_RESPONSE_TOO_BIG);
|
||||
udp_server.send_to(&error, src).await.unwrap();
|
||||
});
|
||||
|
||||
let tcp_task = tokio::spawn(async move {
|
||||
let (mut stream, _) = tcp_listener.accept().await.unwrap();
|
||||
// Read request.
|
||||
let mut len_buf = [0u8; 4];
|
||||
stream.read_exact(&mut len_buf).await.unwrap();
|
||||
let msg_len = u32::from_be_bytes(len_buf) as usize;
|
||||
let mut msg = vec![0u8; msg_len];
|
||||
stream.read_exact(&mut msg).await.unwrap();
|
||||
|
||||
// Send TCP response.
|
||||
let response = b"tcp-kdc-response";
|
||||
let resp_len = (response.len() as u32).to_be_bytes();
|
||||
stream.write_all(&resp_len).await.unwrap();
|
||||
stream.write_all(response).await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
});
|
||||
|
||||
let config = KdcConfig {
|
||||
address: udp_addr.to_string(),
|
||||
timeout: Duration::from_secs(5),
|
||||
};
|
||||
|
||||
let result = send_to_kdc(&config, b"as-req-large").await;
|
||||
assert!(result.is_ok(), "send_to_kdc failed: {:?}", result.err());
|
||||
assert_eq!(result.unwrap(), b"tcp-kdc-response");
|
||||
|
||||
udp_task.await.unwrap();
|
||||
tcp_task.await.unwrap();
|
||||
}
|
||||
|
||||
// ── discover_kdc tests ─────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn discover_kdc_returns_empty_placeholder() {
|
||||
let result = discover_kdc("EXAMPLE.COM").await;
|
||||
assert!(result.is_empty());
|
||||
}
|
||||
}
|
||||
1631
vendor/smb2/src/auth/kerberos/messages.rs
vendored
Normal file
1631
vendor/smb2/src/auth/kerberos/messages.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
21
vendor/smb2/src/auth/kerberos/mod.rs
vendored
Normal file
21
vendor/smb2/src/auth/kerberos/mod.rs
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
//! Kerberos authentication support.
|
||||
//!
|
||||
//! Implements the cryptographic operations needed for Kerberos authentication
|
||||
//! (etypes 17, 18, 23): string-to-key, key derivation, AES-CTS encryption,
|
||||
//! RC4-HMAC encryption, and checksum computation.
|
||||
//!
|
||||
//! The [`KerberosAuthenticator`] wires all building blocks together into
|
||||
//! a full Kerberos authentication flow: AS exchange, TGS exchange, and
|
||||
//! AP-REQ construction for SMB2 SESSION_SETUP.
|
||||
//!
|
||||
//! The [`ccache`] module supports reading MIT Kerberos credential caches,
|
||||
//! enabling authentication from cached TGTs or service tickets (for example,
|
||||
//! from `kinit`) without requiring a password.
|
||||
|
||||
pub mod ccache;
|
||||
pub mod crypto;
|
||||
pub mod kdc;
|
||||
pub mod messages;
|
||||
|
||||
mod authenticator;
|
||||
pub use authenticator::{KerberosAuthenticator, KerberosCredentials};
|
||||
15
vendor/smb2/src/auth/mod.rs
vendored
Normal file
15
vendor/smb2/src/auth/mod.rs
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
//! Authentication mechanisms for SMB2.
|
||||
//!
|
||||
//! Supports NTLM authentication (MS-NLMP) and Kerberos authentication
|
||||
//! (RFC 4120, MS-KILE).
|
||||
//!
|
||||
//! Most users don't need this module directly -- [`SmbClient`](crate::SmbClient)
|
||||
//! handles authentication during [`connect`](crate::connect).
|
||||
|
||||
pub(crate) mod der;
|
||||
pub mod kerberos;
|
||||
pub mod ntlm;
|
||||
pub mod spnego;
|
||||
|
||||
pub use kerberos::{KerberosAuthenticator, KerberosCredentials};
|
||||
pub use ntlm::{NtlmAuthenticator, NtlmCredentials};
|
||||
1410
vendor/smb2/src/auth/ntlm.rs
vendored
Normal file
1410
vendor/smb2/src/auth/ntlm.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
808
vendor/smb2/src/auth/spnego.rs
vendored
Normal file
808
vendor/smb2/src/auth/spnego.rs
vendored
Normal file
@@ -0,0 +1,808 @@
|
||||
//! SPNEGO (Simple and Protected GSS-API Negotiation Mechanism) token wrapping.
|
||||
//!
|
||||
//! Implements the thin ASN.1/DER wrapper that SMB2 requires around authentication
|
||||
//! tokens (NTLM, Kerberos). The client sends a NegTokenInit with supported
|
||||
//! mechanism OIDs and the first mechanism's token, the server responds with
|
||||
//! NegTokenResp indicating the selected mechanism and its response token, and
|
||||
//! subsequent client messages use NegTokenResp as well.
|
||||
//!
|
||||
//! References:
|
||||
//! - RFC 4178 (SPNEGO)
|
||||
//! - MS-SPNG (Microsoft SPNEGO Extension)
|
||||
|
||||
use super::der::{der_tlv, parse_der_tlv};
|
||||
use crate::Error;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OID constants (DER-encoded, including tag and length bytes)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// SPNEGO OID: 1.3.6.1.5.5.2
|
||||
pub const OID_SPNEGO: &[u8] = &[0x06, 0x06, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x02];
|
||||
|
||||
/// NTLM (NTLMSSP) OID: 1.3.6.1.4.1.311.2.2.10
|
||||
pub const OID_NTLMSSP: &[u8] = &[
|
||||
0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x02, 0x02, 0x0a,
|
||||
];
|
||||
|
||||
/// Kerberos OID: 1.2.840.113554.1.2.2 (standard, RFC 4121)
|
||||
pub const OID_KERBEROS: &[u8] = &[
|
||||
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x12, 0x01, 0x02, 0x02,
|
||||
];
|
||||
|
||||
/// Microsoft Kerberos OID: 1.2.840.48018.1.2.2 (MS-KILE, used by Windows SPNEGO)
|
||||
///
|
||||
/// Windows expects this OID as the primary mechanism in SPNEGO NegTokenInit.
|
||||
/// Using the standard Kerberos OID causes Windows to reject the AP-REQ.
|
||||
pub const OID_MS_KERBEROS: &[u8] = &[
|
||||
0x06, 0x09, 0x2a, 0x86, 0x48, 0x82, 0xf7, 0x12, 0x01, 0x02, 0x02,
|
||||
];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ASN.1 DER tag constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// SEQUENCE tag (constructed).
|
||||
const TAG_SEQUENCE: u8 = 0x30;
|
||||
/// OCTET STRING tag.
|
||||
const TAG_OCTET_STRING: u8 = 0x04;
|
||||
/// ENUMERATED tag.
|
||||
const TAG_ENUMERATED: u8 = 0x0a;
|
||||
/// APPLICATION [0] (constructed) -- wraps the initial NegotiationToken.
|
||||
const TAG_APPLICATION_0: u8 = 0x60;
|
||||
/// Context-specific [0] (constructed).
|
||||
const TAG_CONTEXT_0: u8 = 0xa0;
|
||||
/// Context-specific [1] (constructed).
|
||||
const TAG_CONTEXT_1: u8 = 0xa1;
|
||||
/// Context-specific [2] (constructed).
|
||||
const TAG_CONTEXT_2: u8 = 0xa2;
|
||||
/// Context-specific [3] (constructed).
|
||||
const TAG_CONTEXT_3: u8 = 0xa3;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NegState enum
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// SPNEGO negotiation state from NegTokenResp.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum NegState {
|
||||
/// Authentication completed successfully.
|
||||
AcceptCompleted,
|
||||
/// Authentication is in progress (more tokens needed).
|
||||
AcceptIncomplete,
|
||||
/// Authentication was rejected.
|
||||
Reject,
|
||||
}
|
||||
|
||||
impl NegState {
|
||||
/// Parse from the DER enumerated value.
|
||||
fn from_value(v: u8) -> Option<NegState> {
|
||||
match v {
|
||||
0 => Some(NegState::AcceptCompleted),
|
||||
1 => Some(NegState::AcceptIncomplete),
|
||||
2 => Some(NegState::Reject),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NegTokenResp struct
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Parsed SPNEGO NegTokenResp from the server.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct NegTokenResp {
|
||||
/// The negotiation state.
|
||||
pub neg_state: Option<NegState>,
|
||||
/// The selected mechanism OID (raw DER-encoded OID TLV).
|
||||
pub supported_mech: Option<Vec<u8>>,
|
||||
/// The mechanism-specific response token.
|
||||
pub response_token: Option<Vec<u8>>,
|
||||
/// The mechanism list MIC.
|
||||
pub mech_list_mic: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
// DER encoding/decoding helpers are in `super::der`. Imported at the top.
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public API: wrapping
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Wrap a mechanism token in a SPNEGO NegTokenInit.
|
||||
///
|
||||
/// The initial token sent by the client. Wraps the raw NTLM or Kerberos
|
||||
/// token with mechanism OID negotiation.
|
||||
///
|
||||
/// Structure (RFC 4178 section 4.2):
|
||||
/// ```text
|
||||
/// APPLICATION [0] {
|
||||
/// OID_SPNEGO,
|
||||
/// [0] { -- NegTokenInit choice tag
|
||||
/// SEQUENCE {
|
||||
/// [0] { SEQUENCE { mechOID1, mechOID2, ... } }, -- mechTypes
|
||||
/// [2] { OCTET STRING { mechToken } } -- mechToken
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn wrap_neg_token_init(mech_oids: &[&[u8]], mech_token: &[u8]) -> Vec<u8> {
|
||||
// Build mechTypes: SEQUENCE OF OID
|
||||
let mut mech_list_contents = Vec::new();
|
||||
for oid in mech_oids {
|
||||
mech_list_contents.extend_from_slice(oid);
|
||||
}
|
||||
let mech_list_seq = der_tlv(TAG_SEQUENCE, &mech_list_contents);
|
||||
let mech_types = der_tlv(TAG_CONTEXT_0, &mech_list_seq);
|
||||
|
||||
// Build mechToken: [2] OCTET STRING
|
||||
let mech_token_octet = der_tlv(TAG_OCTET_STRING, mech_token);
|
||||
let mech_token_ctx = der_tlv(TAG_CONTEXT_2, &mech_token_octet);
|
||||
|
||||
// NegTokenInit SEQUENCE
|
||||
let mut init_contents = Vec::new();
|
||||
init_contents.extend_from_slice(&mech_types);
|
||||
init_contents.extend_from_slice(&mech_token_ctx);
|
||||
let init_seq = der_tlv(TAG_SEQUENCE, &init_contents);
|
||||
|
||||
// Wrap in context [0] (NegotiationToken CHOICE for negTokenInit)
|
||||
let choice = der_tlv(TAG_CONTEXT_0, &init_seq);
|
||||
|
||||
// Wrap in APPLICATION [0] with SPNEGO OID
|
||||
let mut app_contents = Vec::new();
|
||||
app_contents.extend_from_slice(OID_SPNEGO);
|
||||
app_contents.extend_from_slice(&choice);
|
||||
der_tlv(TAG_APPLICATION_0, &app_contents)
|
||||
}
|
||||
|
||||
/// Wrap a mechanism token in a SPNEGO NegTokenResp.
|
||||
///
|
||||
/// Used by the client in the second round-trip (for example, the NTLM
|
||||
/// AUTHENTICATE_MESSAGE). Only the responseToken field is set.
|
||||
///
|
||||
/// Structure:
|
||||
/// ```text
|
||||
/// [1] { -- NegotiationToken CHOICE for negTokenResp
|
||||
/// SEQUENCE {
|
||||
/// [2] { OCTET STRING { mechToken } } -- responseToken
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn wrap_neg_token_resp(mech_token: &[u8]) -> Vec<u8> {
|
||||
// Build responseToken: [2] OCTET STRING
|
||||
let mech_token_octet = der_tlv(TAG_OCTET_STRING, mech_token);
|
||||
let response_token_ctx = der_tlv(TAG_CONTEXT_2, &mech_token_octet);
|
||||
|
||||
// NegTokenResp SEQUENCE
|
||||
let resp_seq = der_tlv(TAG_SEQUENCE, &response_token_ctx);
|
||||
|
||||
// Wrap in context [1] (NegotiationToken CHOICE for negTokenResp)
|
||||
der_tlv(TAG_CONTEXT_1, &resp_seq)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public API: parsing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Parse a SPNEGO NegTokenResp from the server.
|
||||
///
|
||||
/// The input can be either:
|
||||
/// - A bare `[1] { SEQUENCE { ... } }` NegTokenResp
|
||||
/// - An `APPLICATION [0] { OID, [0] { ... } }` wrapping a NegTokenInit2
|
||||
/// (server-initiated SPNEGO, which we parse the inner token from)
|
||||
///
|
||||
/// Extracts the negotiation state, selected mechanism, and response token.
|
||||
pub fn parse_neg_token_resp(data: &[u8]) -> Result<NegTokenResp, Error> {
|
||||
if data.is_empty() {
|
||||
return Err(Error::invalid_data("SPNEGO: empty token"));
|
||||
}
|
||||
|
||||
// Check if this is an APPLICATION [0] wrapper (server-initiated NegTokenInit2)
|
||||
// or a NegTokenResp [1] wrapper.
|
||||
let (tag, value, _) = parse_der_tlv(data)?;
|
||||
|
||||
match tag {
|
||||
TAG_CONTEXT_1 => {
|
||||
// Standard NegTokenResp: [1] { SEQUENCE { ... } }
|
||||
parse_neg_token_resp_inner(value)
|
||||
}
|
||||
TAG_APPLICATION_0 => {
|
||||
// APPLICATION [0] { OID_SPNEGO, [0] { NegTokenInit2 } }
|
||||
// or could contain a [1] { NegTokenResp }
|
||||
// Skip the SPNEGO OID
|
||||
let (oid_tag, _, oid_total) = parse_der_tlv(value)?;
|
||||
if oid_tag != 0x06 {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"SPNEGO: expected OID in APPLICATION [0], got tag 0x{oid_tag:02x}"
|
||||
)));
|
||||
}
|
||||
let remaining = &value[oid_total..];
|
||||
let (inner_tag, inner_value, _) = parse_der_tlv(remaining)?;
|
||||
match inner_tag {
|
||||
TAG_CONTEXT_0 => {
|
||||
// NegTokenInit2 wrapped in [0]: parse as NegTokenInit2
|
||||
// to extract mechTypes (as supportedMech) and mechToken
|
||||
parse_neg_token_init2_as_resp(inner_value)
|
||||
}
|
||||
TAG_CONTEXT_1 => {
|
||||
// NegTokenResp wrapped inside APPLICATION [0]
|
||||
parse_neg_token_resp_inner(inner_value)
|
||||
}
|
||||
_ => Err(Error::invalid_data(format!(
|
||||
"SPNEGO: unexpected tag 0x{inner_tag:02x} inside APPLICATION [0]"
|
||||
))),
|
||||
}
|
||||
}
|
||||
_ => Err(Error::invalid_data(format!(
|
||||
"SPNEGO: expected NegTokenResp [1] or APPLICATION [0], got tag 0x{tag:02x}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the inner SEQUENCE of a NegTokenResp.
|
||||
fn parse_neg_token_resp_inner(data: &[u8]) -> Result<NegTokenResp, Error> {
|
||||
// Expect SEQUENCE
|
||||
let (tag, seq_data, _) = parse_der_tlv(data)?;
|
||||
if tag != TAG_SEQUENCE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"SPNEGO: expected SEQUENCE in NegTokenResp, got tag 0x{tag:02x}"
|
||||
)));
|
||||
}
|
||||
|
||||
let mut neg_state = None;
|
||||
let mut supported_mech = None;
|
||||
let mut response_token = None;
|
||||
let mut mech_list_mic = None;
|
||||
|
||||
let mut pos = 0;
|
||||
while pos < seq_data.len() {
|
||||
let (ctx_tag, ctx_value, ctx_total) = parse_der_tlv(&seq_data[pos..])?;
|
||||
match ctx_tag {
|
||||
TAG_CONTEXT_0 => {
|
||||
// negState: ENUMERATED
|
||||
let (enum_tag, enum_value, _) = parse_der_tlv(ctx_value)?;
|
||||
if enum_tag != TAG_ENUMERATED {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"SPNEGO: expected ENUMERATED for negState, got tag 0x{enum_tag:02x}"
|
||||
)));
|
||||
}
|
||||
if enum_value.is_empty() {
|
||||
return Err(Error::invalid_data("SPNEGO: empty ENUMERATED for negState"));
|
||||
}
|
||||
neg_state = NegState::from_value(enum_value[0]);
|
||||
if neg_state.is_none() {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"SPNEGO: unknown negState value: {}",
|
||||
enum_value[0]
|
||||
)));
|
||||
}
|
||||
}
|
||||
TAG_CONTEXT_1 => {
|
||||
// supportedMech: OID (the full TLV)
|
||||
supported_mech = Some(ctx_value.to_vec());
|
||||
}
|
||||
TAG_CONTEXT_2 => {
|
||||
// responseToken: OCTET STRING
|
||||
let (oct_tag, oct_value, _) = parse_der_tlv(ctx_value)?;
|
||||
if oct_tag != TAG_OCTET_STRING {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"SPNEGO: expected OCTET STRING for responseToken, got tag 0x{oct_tag:02x}"
|
||||
)));
|
||||
}
|
||||
response_token = Some(oct_value.to_vec());
|
||||
}
|
||||
TAG_CONTEXT_3 => {
|
||||
// mechListMIC: OCTET STRING
|
||||
let (oct_tag, oct_value, _) = parse_der_tlv(ctx_value)?;
|
||||
if oct_tag != TAG_OCTET_STRING {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"SPNEGO: expected OCTET STRING for mechListMIC, got tag 0x{oct_tag:02x}"
|
||||
)));
|
||||
}
|
||||
mech_list_mic = Some(oct_value.to_vec());
|
||||
}
|
||||
_ => {
|
||||
// Unknown context tag, skip it (forward compatibility).
|
||||
}
|
||||
}
|
||||
pos += ctx_total;
|
||||
}
|
||||
|
||||
Ok(NegTokenResp {
|
||||
neg_state,
|
||||
supported_mech,
|
||||
response_token,
|
||||
mech_list_mic,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse a NegTokenInit2 (server-initiated) and return it as a NegTokenResp.
|
||||
///
|
||||
/// NegTokenInit2 has mechTypes at [0] and mechToken at [2]. We map the
|
||||
/// first mechType to supportedMech and mechToken to responseToken.
|
||||
fn parse_neg_token_init2_as_resp(data: &[u8]) -> Result<NegTokenResp, Error> {
|
||||
let (tag, seq_data, _) = parse_der_tlv(data)?;
|
||||
if tag != TAG_SEQUENCE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"SPNEGO: expected SEQUENCE in NegTokenInit2, got tag 0x{tag:02x}"
|
||||
)));
|
||||
}
|
||||
|
||||
let mut supported_mech = None;
|
||||
let mut response_token = None;
|
||||
|
||||
let mut pos = 0;
|
||||
while pos < seq_data.len() {
|
||||
let (ctx_tag, ctx_value, ctx_total) = parse_der_tlv(&seq_data[pos..])?;
|
||||
match ctx_tag {
|
||||
TAG_CONTEXT_0 => {
|
||||
// mechTypes: SEQUENCE OF OID -- take the first one
|
||||
let (seq_tag, mech_list_data, _) = parse_der_tlv(ctx_value)?;
|
||||
if seq_tag != TAG_SEQUENCE {
|
||||
return Err(Error::invalid_data(
|
||||
"SPNEGO: expected SEQUENCE for mechTypes",
|
||||
));
|
||||
}
|
||||
if !mech_list_data.is_empty() {
|
||||
// Take the first OID TLV as the supported mech
|
||||
let (oid_tag, _, oid_total) = parse_der_tlv(mech_list_data)?;
|
||||
if oid_tag == 0x06 {
|
||||
supported_mech = Some(mech_list_data[..oid_total].to_vec());
|
||||
}
|
||||
}
|
||||
}
|
||||
TAG_CONTEXT_2 => {
|
||||
// mechToken: OCTET STRING
|
||||
let (oct_tag, oct_value, _) = parse_der_tlv(ctx_value)?;
|
||||
if oct_tag == TAG_OCTET_STRING {
|
||||
response_token = Some(oct_value.to_vec());
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Skip reqFlags [1], negHints [3], mechListMIC [4]
|
||||
}
|
||||
}
|
||||
pos += ctx_total;
|
||||
}
|
||||
|
||||
Ok(NegTokenResp {
|
||||
neg_state: None,
|
||||
supported_mech,
|
||||
response_token,
|
||||
mech_list_mic: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// DER primitive tests (der_length, der_tlv, parse_der_length, parse_der_tlv)
|
||||
// live in `auth::der::tests`.
|
||||
|
||||
// =======================================================================
|
||||
// NegTokenInit wrapping tests
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn neg_token_init_starts_with_application_tag() {
|
||||
let token = wrap_neg_token_init(&[OID_NTLMSSP], b"NTLMSSP\0test");
|
||||
assert_eq!(
|
||||
token[0], TAG_APPLICATION_0,
|
||||
"must start with APPLICATION [0]"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_token_init_contains_spnego_oid() {
|
||||
let token = wrap_neg_token_init(&[OID_NTLMSSP], b"NTLMSSP\0test");
|
||||
// The SPNEGO OID value bytes (without the 0x06 tag and 0x06 length)
|
||||
let oid_value = &OID_SPNEGO[2..]; // skip tag+length
|
||||
assert!(
|
||||
token.windows(oid_value.len()).any(|w| w == oid_value),
|
||||
"token must contain SPNEGO OID"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_token_init_contains_mech_oid() {
|
||||
let token = wrap_neg_token_init(&[OID_NTLMSSP], b"test");
|
||||
// The NTLMSSP OID value bytes (without the 0x06 tag)
|
||||
let oid_value = &OID_NTLMSSP[2..]; // skip tag+length
|
||||
assert!(
|
||||
token.windows(oid_value.len()).any(|w| w == oid_value),
|
||||
"token must contain NTLMSSP OID"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_token_init_contains_mech_token() {
|
||||
let mech_token = b"NTLMSSP\0negotiate_payload_here";
|
||||
let token = wrap_neg_token_init(&[OID_NTLMSSP], mech_token);
|
||||
assert!(
|
||||
token.windows(mech_token.len()).any(|w| w == mech_token),
|
||||
"token must contain the raw mech token"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_token_init_multiple_mechs() {
|
||||
let token = wrap_neg_token_init(&[OID_NTLMSSP, OID_KERBEROS], b"tok");
|
||||
// Both OIDs should be present
|
||||
let ntlm_oid_value = &OID_NTLMSSP[2..];
|
||||
let kerb_oid_value = &OID_KERBEROS[2..];
|
||||
assert!(
|
||||
token
|
||||
.windows(ntlm_oid_value.len())
|
||||
.any(|w| w == ntlm_oid_value),
|
||||
"must contain NTLMSSP OID"
|
||||
);
|
||||
assert!(
|
||||
token
|
||||
.windows(kerb_oid_value.len())
|
||||
.any(|w| w == kerb_oid_value),
|
||||
"must contain Kerberos OID"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_token_init_structure_is_valid_der() {
|
||||
let token = wrap_neg_token_init(&[OID_NTLMSSP], b"test_token");
|
||||
// Parse the outer APPLICATION [0]
|
||||
let (tag, value, total) = parse_der_tlv(&token).unwrap();
|
||||
assert_eq!(tag, TAG_APPLICATION_0);
|
||||
assert_eq!(total, token.len(), "entire token should be consumed");
|
||||
|
||||
// Inside: OID_SPNEGO followed by [0] { SEQUENCE { ... } }
|
||||
let (oid_tag, _, oid_total) = parse_der_tlv(value).unwrap();
|
||||
assert_eq!(oid_tag, 0x06, "first element should be OID");
|
||||
|
||||
let (choice_tag, _, _) = parse_der_tlv(&value[oid_total..]).unwrap();
|
||||
assert_eq!(choice_tag, TAG_CONTEXT_0, "second element should be [0]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_token_init_parseable_structure() {
|
||||
// Wrap a token and verify we can walk the entire structure
|
||||
let mech_token = b"the_raw_ntlm_token";
|
||||
let token = wrap_neg_token_init(&[OID_NTLMSSP], mech_token);
|
||||
|
||||
// APPLICATION [0]
|
||||
let (_, app_value, _) = parse_der_tlv(&token).unwrap();
|
||||
// Skip SPNEGO OID
|
||||
let (_, _, oid_total) = parse_der_tlv(app_value).unwrap();
|
||||
// [0] CHOICE
|
||||
let (_, choice_value, _) = parse_der_tlv(&app_value[oid_total..]).unwrap();
|
||||
// SEQUENCE
|
||||
let (_, seq_value, _) = parse_der_tlv(choice_value).unwrap();
|
||||
// [0] mechTypes
|
||||
let (tag0, ctx0_value, ctx0_total) = parse_der_tlv(seq_value).unwrap();
|
||||
assert_eq!(tag0, TAG_CONTEXT_0);
|
||||
// SEQUENCE OF OID inside mechTypes
|
||||
let (_, mech_list, _) = parse_der_tlv(ctx0_value).unwrap();
|
||||
// First OID should be NTLMSSP
|
||||
assert_eq!(&mech_list[..OID_NTLMSSP.len()], OID_NTLMSSP);
|
||||
|
||||
// [2] mechToken
|
||||
let (tag2, ctx2_value, _) = parse_der_tlv(&seq_value[ctx0_total..]).unwrap();
|
||||
assert_eq!(tag2, TAG_CONTEXT_2);
|
||||
// OCTET STRING
|
||||
let (_, oct_value, _) = parse_der_tlv(ctx2_value).unwrap();
|
||||
assert_eq!(oct_value, mech_token);
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// NegTokenResp wrapping tests
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn neg_token_resp_wrap_starts_with_context_1() {
|
||||
let token = wrap_neg_token_resp(b"auth_token");
|
||||
assert_eq!(token[0], TAG_CONTEXT_1, "must start with [1]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_token_resp_wrap_contains_mech_token() {
|
||||
let mech_token = b"NTLMSSP\0authenticate_payload";
|
||||
let token = wrap_neg_token_resp(mech_token);
|
||||
assert!(
|
||||
token.windows(mech_token.len()).any(|w| w == mech_token),
|
||||
"wrapped token must contain the raw mech token"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_token_resp_wrap_valid_structure() {
|
||||
let mech_token = b"authenticate_me";
|
||||
let token = wrap_neg_token_resp(mech_token);
|
||||
|
||||
// [1]
|
||||
let (tag, ctx1_value, _) = parse_der_tlv(&token).unwrap();
|
||||
assert_eq!(tag, TAG_CONTEXT_1);
|
||||
// SEQUENCE
|
||||
let (tag, seq_value, _) = parse_der_tlv(ctx1_value).unwrap();
|
||||
assert_eq!(tag, TAG_SEQUENCE);
|
||||
// [2] responseToken
|
||||
let (tag, ctx2_value, _) = parse_der_tlv(seq_value).unwrap();
|
||||
assert_eq!(tag, TAG_CONTEXT_2);
|
||||
// OCTET STRING
|
||||
let (tag, oct_value, _) = parse_der_tlv(ctx2_value).unwrap();
|
||||
assert_eq!(tag, TAG_OCTET_STRING);
|
||||
assert_eq!(oct_value, mech_token);
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// NegTokenResp parsing tests
|
||||
// =======================================================================
|
||||
|
||||
/// Build a NegTokenResp with known fields for testing.
|
||||
fn build_test_neg_token_resp(
|
||||
neg_state: Option<u8>,
|
||||
supported_mech: Option<&[u8]>,
|
||||
response_token: Option<&[u8]>,
|
||||
mech_list_mic: Option<&[u8]>,
|
||||
) -> Vec<u8> {
|
||||
let mut seq_contents = Vec::new();
|
||||
|
||||
if let Some(state) = neg_state {
|
||||
let enumerated = der_tlv(TAG_ENUMERATED, &[state]);
|
||||
seq_contents.extend_from_slice(&der_tlv(TAG_CONTEXT_0, &enumerated));
|
||||
}
|
||||
|
||||
if let Some(oid) = supported_mech {
|
||||
seq_contents.extend_from_slice(&der_tlv(TAG_CONTEXT_1, oid));
|
||||
}
|
||||
|
||||
if let Some(tok) = response_token {
|
||||
let octet = der_tlv(TAG_OCTET_STRING, tok);
|
||||
seq_contents.extend_from_slice(&der_tlv(TAG_CONTEXT_2, &octet));
|
||||
}
|
||||
|
||||
if let Some(mic) = mech_list_mic {
|
||||
let octet = der_tlv(TAG_OCTET_STRING, mic);
|
||||
seq_contents.extend_from_slice(&der_tlv(TAG_CONTEXT_3, &octet));
|
||||
}
|
||||
|
||||
let seq = der_tlv(TAG_SEQUENCE, &seq_contents);
|
||||
der_tlv(TAG_CONTEXT_1, &seq)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neg_token_resp_accept_incomplete() {
|
||||
let token = build_test_neg_token_resp(
|
||||
Some(1), // accept-incomplete
|
||||
Some(OID_NTLMSSP),
|
||||
Some(b"challenge_token"),
|
||||
None,
|
||||
);
|
||||
|
||||
let resp = parse_neg_token_resp(&token).unwrap();
|
||||
assert_eq!(resp.neg_state, Some(NegState::AcceptIncomplete));
|
||||
assert_eq!(resp.supported_mech.as_deref(), Some(OID_NTLMSSP));
|
||||
assert_eq!(
|
||||
resp.response_token.as_deref(),
|
||||
Some(&b"challenge_token"[..])
|
||||
);
|
||||
assert!(resp.mech_list_mic.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neg_token_resp_accept_completed() {
|
||||
let token = build_test_neg_token_resp(Some(0), None, None, None);
|
||||
|
||||
let resp = parse_neg_token_resp(&token).unwrap();
|
||||
assert_eq!(resp.neg_state, Some(NegState::AcceptCompleted));
|
||||
assert!(resp.supported_mech.is_none());
|
||||
assert!(resp.response_token.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neg_token_resp_reject() {
|
||||
let token = build_test_neg_token_resp(Some(2), None, None, None);
|
||||
|
||||
let resp = parse_neg_token_resp(&token).unwrap();
|
||||
assert_eq!(resp.neg_state, Some(NegState::Reject));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neg_token_resp_all_fields() {
|
||||
let token = build_test_neg_token_resp(
|
||||
Some(1),
|
||||
Some(OID_NTLMSSP),
|
||||
Some(b"response_data"),
|
||||
Some(b"mic_data"),
|
||||
);
|
||||
|
||||
let resp = parse_neg_token_resp(&token).unwrap();
|
||||
assert_eq!(resp.neg_state, Some(NegState::AcceptIncomplete));
|
||||
assert_eq!(resp.supported_mech.as_deref(), Some(OID_NTLMSSP));
|
||||
assert_eq!(resp.response_token.as_deref(), Some(&b"response_data"[..]));
|
||||
assert_eq!(resp.mech_list_mic.as_deref(), Some(&b"mic_data"[..]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neg_token_resp_no_fields() {
|
||||
// All fields optional
|
||||
let token = build_test_neg_token_resp(None, None, None, None);
|
||||
|
||||
let resp = parse_neg_token_resp(&token).unwrap();
|
||||
assert!(resp.neg_state.is_none());
|
||||
assert!(resp.supported_mech.is_none());
|
||||
assert!(resp.response_token.is_none());
|
||||
assert!(resp.mech_list_mic.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neg_token_resp_empty_data_error() {
|
||||
let result = parse_neg_token_resp(&[]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neg_token_resp_truncated_error() {
|
||||
// Just a tag byte, no length
|
||||
let result = parse_neg_token_resp(&[TAG_CONTEXT_1]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neg_token_resp_wrong_tag_error() {
|
||||
// SEQUENCE tag instead of [1]
|
||||
let data = der_tlv(TAG_SEQUENCE, &[0x00]);
|
||||
let result = parse_neg_token_resp(&data);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_neg_token_resp_unknown_neg_state_error() {
|
||||
let token = build_test_neg_token_resp(Some(99), None, None, None);
|
||||
let result = parse_neg_token_resp(&token);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// Cross-validation: construct a realistic server response
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn parse_realistic_server_challenge_response() {
|
||||
// Simulate a typical Samba/Windows SPNEGO response to the first
|
||||
// SESSION_SETUP: accept-incomplete with NTLMSSP OID and an NTLM
|
||||
// challenge token.
|
||||
let ntlm_challenge = b"NTLMSSP\0\x02\x00\x00\x00fake_challenge_data";
|
||||
|
||||
let token = build_test_neg_token_resp(
|
||||
Some(1), // accept-incomplete
|
||||
Some(OID_NTLMSSP),
|
||||
Some(ntlm_challenge),
|
||||
None,
|
||||
);
|
||||
|
||||
let resp = parse_neg_token_resp(&token).unwrap();
|
||||
assert_eq!(resp.neg_state, Some(NegState::AcceptIncomplete));
|
||||
assert_eq!(resp.response_token.as_deref(), Some(&ntlm_challenge[..]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_realistic_server_accept_with_mic() {
|
||||
// Final server response: accept-completed with mechListMIC
|
||||
let mic = [0xaa; 16];
|
||||
let token = build_test_neg_token_resp(Some(0), None, None, Some(&mic));
|
||||
|
||||
let resp = parse_neg_token_resp(&token).unwrap();
|
||||
assert_eq!(resp.neg_state, Some(NegState::AcceptCompleted));
|
||||
assert_eq!(resp.mech_list_mic.as_deref(), Some(&mic[..]));
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// Roundtrip: wrap and parse NegTokenResp
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn neg_token_resp_wrap_then_parse() {
|
||||
let mech_token = b"roundtrip_test_token";
|
||||
let wrapped = wrap_neg_token_resp(mech_token);
|
||||
let parsed = parse_neg_token_resp(&wrapped).unwrap();
|
||||
|
||||
// Wrapped with only responseToken, so:
|
||||
assert!(parsed.neg_state.is_none());
|
||||
assert!(parsed.supported_mech.is_none());
|
||||
assert_eq!(parsed.response_token.as_deref(), Some(&mech_token[..]));
|
||||
assert!(parsed.mech_list_mic.is_none());
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// Wire capture cross-validation
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn parse_hand_constructed_wire_bytes() {
|
||||
// Hand-constructed NegTokenResp matching what a Windows/Samba server
|
||||
// sends after receiving NegTokenInit with NTLMSSP:
|
||||
//
|
||||
// a1 XX -- [1] NegTokenResp
|
||||
// 30 XX -- SEQUENCE
|
||||
// a0 03 -- [0] negState
|
||||
// 0a 01 01 -- ENUMERATED accept-incomplete (1)
|
||||
// a1 0c -- [1] supportedMech
|
||||
// 06 0a 2b 06 01 04 01 82 37 02 02 0a -- NTLMSSP OID
|
||||
// a2 XX -- [2] responseToken
|
||||
// 04 XX -- OCTET STRING
|
||||
// <ntlm challenge bytes>
|
||||
let ntlm_challenge = b"NTLMSSP\0fake";
|
||||
|
||||
// Build by hand
|
||||
let neg_state_enum = vec![0x0a, 0x01, 0x01]; // ENUMERATED 1
|
||||
let neg_state_ctx = der_tlv(TAG_CONTEXT_0, &neg_state_enum);
|
||||
|
||||
let mech_ctx = der_tlv(TAG_CONTEXT_1, OID_NTLMSSP);
|
||||
|
||||
let resp_octet = der_tlv(TAG_OCTET_STRING, ntlm_challenge);
|
||||
let resp_ctx = der_tlv(TAG_CONTEXT_2, &resp_octet);
|
||||
|
||||
let mut seq_content = Vec::new();
|
||||
seq_content.extend_from_slice(&neg_state_ctx);
|
||||
seq_content.extend_from_slice(&mech_ctx);
|
||||
seq_content.extend_from_slice(&resp_ctx);
|
||||
let seq = der_tlv(TAG_SEQUENCE, &seq_content);
|
||||
let wire_bytes = der_tlv(TAG_CONTEXT_1, &seq);
|
||||
|
||||
let parsed = parse_neg_token_resp(&wire_bytes).unwrap();
|
||||
assert_eq!(parsed.neg_state, Some(NegState::AcceptIncomplete));
|
||||
assert_eq!(parsed.supported_mech.as_deref(), Some(OID_NTLMSSP));
|
||||
assert_eq!(parsed.response_token.as_deref(), Some(&ntlm_challenge[..]));
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// OID constant verification
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn oid_constants_are_valid_der() {
|
||||
// Each OID constant should parse as a valid DER TLV with tag 0x06
|
||||
for (name, oid) in [
|
||||
("SPNEGO", OID_SPNEGO),
|
||||
("NTLMSSP", OID_NTLMSSP),
|
||||
("Kerberos", OID_KERBEROS),
|
||||
] {
|
||||
let (tag, _, total) =
|
||||
parse_der_tlv(oid).unwrap_or_else(|e| panic!("{name} OID is not valid DER: {e}"));
|
||||
assert_eq!(tag, 0x06, "{name} OID tag should be 0x06");
|
||||
assert_eq!(total, oid.len(), "{name} OID should be fully consumed");
|
||||
}
|
||||
}
|
||||
|
||||
// =======================================================================
|
||||
// Large token handling
|
||||
// =======================================================================
|
||||
|
||||
#[test]
|
||||
fn neg_token_init_with_large_mech_token() {
|
||||
// Kerberos tokens can be several KB
|
||||
let large_token = vec![0xab; 4096];
|
||||
let wrapped = wrap_neg_token_init(&[OID_KERBEROS], &large_token);
|
||||
|
||||
// Should parse without error
|
||||
let (tag, _, total) = parse_der_tlv(&wrapped).unwrap();
|
||||
assert_eq!(tag, TAG_APPLICATION_0);
|
||||
assert_eq!(total, wrapped.len());
|
||||
|
||||
// The large token should be embedded
|
||||
assert!(
|
||||
wrapped.windows(100).any(|w| w == &large_token[..100]),
|
||||
"large token content must be present"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_token_resp_with_large_response_token() {
|
||||
let large_token = vec![0xcd; 4096];
|
||||
let built = build_test_neg_token_resp(Some(1), None, Some(&large_token), None);
|
||||
let parsed = parse_neg_token_resp(&built).unwrap();
|
||||
assert_eq!(parsed.response_token.as_deref(), Some(&large_token[..]));
|
||||
}
|
||||
}
|
||||
182
vendor/smb2/src/client/CLAUDE.md
vendored
Normal file
182
vendor/smb2/src/client/CLAUDE.md
vendored
Normal file
@@ -0,0 +1,182 @@
|
||||
# Client -- high-level SMB2 API
|
||||
|
||||
Entry point for most users. `SmbClient` wraps `Connection` + `Session` and provides convenience methods for file operations.
|
||||
|
||||
## Key files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `mod.rs` | `SmbClient`, `ClientConfig`, `connect()` shorthand |
|
||||
| `connection.rs` | `Connection` -- credit tracking, message sequencing, signing, encryption, `execute` / `execute_compound` |
|
||||
| `session.rs` | `Session::setup()` -- NTLM auth, key derivation, signing/encryption activation |
|
||||
| `tree.rs` | `Tree` -- share connection, file CRUD, compound and pipelined I/O |
|
||||
| `stream.rs` | `FileDownload` / `FileUpload` / `FileWriter` (owns `Connection` + `Arc<Tree>`, `'static`) / `open_file_writer` -- streaming I/O with progress |
|
||||
| `watcher.rs` | `Watcher` -- directory change notifications via CHANGE_NOTIFY long-poll |
|
||||
| `pipeline.rs` | `Pipeline` / `Op` / `OpResult` -- batched concurrent operations (the core feature) |
|
||||
| `shares.rs` | Share enumeration via IPC$ + srvsvc RPC |
|
||||
| `dfs.rs` | DFS referral IOCTL helper, `DfsResolver` with TTL-based referral cache |
|
||||
|
||||
## Layering
|
||||
|
||||
```
|
||||
SmbClient (owns Connection + Session, stores credentials for reconnect)
|
||||
Connection (TCP transport, credits, message IDs, signing, encryption)
|
||||
Session (NTLM auth, key derivation -- setup mutates Connection)
|
||||
Tree (share-level ops, borrows &mut Connection for each call)
|
||||
extra_connections (HashMap<String, ConnectionEntry> for DFS cross-server)
|
||||
dfs_resolver (DfsResolver with TTL-based referral cache)
|
||||
```
|
||||
|
||||
All `Tree` methods take `&mut Connection` as a parameter. `SmbClient` convenience methods use `connection_for_tree(tree)` to route through the correct connection (primary or DFS extra connection) based on the tree's `server` field.
|
||||
|
||||
## Connection and credits
|
||||
|
||||
- Connection starts with 1 credit (from negotiate). Requests 256 credits in every message.
|
||||
- Multi-credit requests (reads/writes > 64 KB) consume `ceil(payload_size / 65536)` credits and use that many consecutive `MessageId` values. Gaps in `MessageId` sequences cause the server to drop the connection.
|
||||
- Credits flow back from responses via `CreditResponse` header field. The connection tracks available credits and blocks if exhausted.
|
||||
- `STATUS_PENDING` interim responses carry credits but the request isn't done -- keep waiting.
|
||||
|
||||
## Compound requests
|
||||
|
||||
`Connection::execute_compound(&[CompoundOp])` packs multiple operations into a single transport frame. Each sub-request is 8-byte aligned, linked via `NextCommand`. Subsequent related operations use `FileId::SENTINEL` (the server substitutes the real handle from the first CREATE).
|
||||
|
||||
- **Read compound**: CREATE + READ + CLOSE (3 ops, 1 round-trip). Default for `read_file`.
|
||||
- **Write compound**: CREATE + WRITE + FLUSH + CLOSE (4 ops, 1 round-trip). Default for `write_file`.
|
||||
- **Delete compound**: CREATE (DELETE_ON_CLOSE) + CLOSE (2 ops, 1 round-trip). Default for `delete_file` / `delete_directory`.
|
||||
- **Rename compound**: CREATE + SET_INFO + CLOSE (3 ops, 1 round-trip). Default for `rename`.
|
||||
- **Stat compound**: CREATE + QUERY_INFO (basic) + QUERY_INFO (standard) + CLOSE (4 ops, 1 round-trip). Default for `stat`.
|
||||
- **Fs-info compound**: CREATE + QUERY_INFO (FileFsFullSizeInformation) + CLOSE (3 ops, 1 round-trip). Default for `fs_info`.
|
||||
- If CREATE succeeds but a later op fails, the client issues a standalone CLOSE to avoid leaking the handle.
|
||||
|
||||
### Receiving compound responses
|
||||
|
||||
`execute_compound` returns `Result<Vec<Result<Frame>>>`. The outer `Result` is "did the compound hit the wire"; the inner one is per-sub-op (waiter-level: session expired, signature verify, connection dropped mid-await). Sub-op protocol status codes (`STATUS_OBJECT_NAME_NOT_FOUND` etc.) ride in the inner frame's `header.status`, not the inner `Result`. Per MS-SMB2 3.3.4.1.3 the server MAY split the compound response across multiple transport frames (Samba, QNAP, Windows Server in some cases); the receiver task routes each sub-response by `MessageId` so the per-waiter `oneshot::Receiver`s resolve independently and `execute_compound` reassembles the result vector in submission order.
|
||||
|
||||
Most callers use a small `all_or_first_err` helper (see `tree.rs`) that propagates the first inner `Err` as the outer `Err` (matching the pre-Phase-3 shortcircuit behavior) and hands back a `Vec<Frame>` indexable per sub-op. Tolerating partial failure (for example, CREATE ok, READ fails → issue standalone CLOSE with the create's returned `FileId`) keeps the individual inner `Result`s.
|
||||
|
||||
## Batch operations
|
||||
|
||||
`delete_files`, `rename_files`, and `stat_files` issue one `execute_compound` per file. Partial failures are independent — if 3 of 50 files fail, the other 47 still succeed. Each method returns `Vec<Result<T>>` in the same order as the input.
|
||||
|
||||
Decision/Why — sequential execute vs parallel: pre-Phase-3 these methods did "phase 1 send all compounds, phase 2 receive all" for wire-level pipelining. With the new API a caller can re-create that shape by spawning `tokio::spawn` tasks over `conn.clone()`s, each calling `execute_compound`. For cmdr's "delete 50 files" flows the sequential-compound cost is small (one round-trip per file) so we chose simplicity. If a workload needs the extra parallelism later, the refactor is local to each batch method.
|
||||
|
||||
## DFS (Distributed File System) resolution
|
||||
|
||||
Reactive DFS resolution with multi-target failover. When a convenience method gets `STATUS_PATH_NOT_COVERED` (mapped to `ErrorKind::DfsReferral`), it:
|
||||
|
||||
1. Calls `handle_dfs_redirect()` which resolves the referral via `DfsResolver` (cache or IOCTL)
|
||||
2. Tries each target in the referral response (multi-target failover)
|
||||
3. Creates a new connection + session for cross-server targets via `ensure_connection()`
|
||||
4. Tree-connects to the target share via `ensure_tree()`
|
||||
5. Updates the caller's `&mut Tree` in-place to point to the new server/share
|
||||
6. Retries the operation with the resolved remaining path
|
||||
|
||||
**Key design decisions:**
|
||||
- Convenience methods take `&mut Tree` (not `&Tree`) so DFS can update the tree in-place
|
||||
- `disconnect_share` stays as `&Tree` (no redirect on teardown)
|
||||
- Streaming methods (`download`, `upload`) keep `&Tree` because they return handles that borrow the tree for their lifetime
|
||||
- `watch` now returns an *owned* `Watcher` (no lifetime); see the [Watcher pipelining](#watcher-pipelining) section
|
||||
- Batch methods (`delete_files`, `rename_files`, `stat_files`) don't retry per-file; the caller should trigger one single-file operation first to resolve the redirect
|
||||
- `dfs_enabled` flag on `ClientConfig` (default `true`) gates all DFS resolution
|
||||
- Borrow checker requires inlining the connection lookup in `handle_dfs_redirect` to avoid double `&mut self` borrows
|
||||
|
||||
## Watcher pipelining
|
||||
|
||||
`Watcher` keeps **one CHANGE_NOTIFY request pre-issued on the wire at all times** after the first `next_events()` call. The wire never sits idle between responses. This closes the response→re-arm loss window that strict servers (older Samba builds, NAS firmware) drop events through.
|
||||
|
||||
Shape: `Watcher` owns a cloned `Connection` (cheap `Arc::clone`, all clones multiplex over the same SMB session) and a `Tree` clone — no lifetime parameter, no borrow against the caller's `Connection`. `next_events` dispatches the next request via `Connection::dispatch` (a sibling to `execute` that returns once `transport.send().await` completes, handing back the `oneshot::Receiver` for the response) *before* awaiting the previous response. So when control returns to the consumer, the server already has somewhere to put new events.
|
||||
|
||||
Decision/Why — eager-send `dispatch` vs `tokio::spawn(conn.execute(...))`: the spawn-based approach defers the send to when the spawned task is polled, which under tokio's `current_thread` scheduler may not happen until the spawning task yields. That left a gap where the simulator-modeled strict server dropped events. `dispatch` awaits transport.send() inline, so the eager-send guarantee is "after `.await` returns, the request is on the wire" — independent of scheduler.
|
||||
|
||||
Pinned by `client::watcher::loss_window_tests::watcher_does_not_lose_events_between_consecutive_requests`: a strict-server simulator drops events that arrive with no outstanding request. Pre-fix: 5/5 gap events dropped. Post-fix: 0/5 dropped.
|
||||
|
||||
## Pipelined I/O
|
||||
|
||||
For large files, `read_file_pipelined` / `write_file_pipelined` issue multiple `execute_with_credits` calls concurrently on cloned connections via `futures_util::stream::FuturesUnordered`. The sliding window stays at 32 in-flight requests, credits are checked per launch via `conn.credits()`. Chunk size is `min(512 KB, max_read_size)`. This is the core performance feature -- without it, throughput is ~10x worse.
|
||||
|
||||
`FileWriter` owns its `Connection` (cheap `Arc::clone`) and `Arc<Tree>` — no lifetime parameter, no borrow against the `SmbClient` that built it. It keeps an owned `FuturesUnordered<BoxedWriteFut>` field — `launch_wire_chunk` pushes a boxed `execute_with_credits` future, `drain_one` awaits `in_flight.next()`, and the public `write_chunk` / `finish` / `abort` drive that state machine.
|
||||
|
||||
FileWriter provides push-based pipelined writes. The consumer pushes chunks at their own pace via `write_chunk`, with the sliding window handling backpressure. Complement to FileDownload (read streaming). Build one via `open_file_writer(tree, conn, path)` (free function), `Tree::create_file_writer(&Arc<Self>, conn, path)`, or `SmbClient::create_file_writer(&self, tree, path)` — the last clones the client's primary connection internally for convenience.
|
||||
|
||||
## Streaming download entry points
|
||||
|
||||
Two symmetric ways to start a `FileDownload`:
|
||||
|
||||
- `SmbClient::download(&mut self, &Tree, path)` — convenience wrapper that borrows the client's internal `Connection`.
|
||||
- `Tree::download(&self, &mut Connection, path)` — takes the `Connection` directly. Use this when you hold a
|
||||
`conn.clone()` and want to drive concurrent downloads on the same SMB session (each clone pairs with one outstanding
|
||||
download; the receiver task multiplexes responses by `MessageId`). `SmbClient::download` delegates here.
|
||||
|
||||
For full control, `Tree::open_file` (returns `(FileId, u64)`) plus `FileDownload::new` let callers build custom chunk
|
||||
loops with non-default `chunk_size`. Most users shouldn't need this — `read_file_compound` (1 RTT) handles small files
|
||||
and `Tree::download` / `SmbClient::download` handle the streaming case.
|
||||
|
||||
FileWriter has two terminal operations:
|
||||
- `finish()` — send all buffered data, drain in-flight WRITEs, FLUSH (fsync on the server), CLOSE. Use on normal completion.
|
||||
- `abort()` — discard unsent data, drain in-flight WRITEs to keep credits/message-ids in sync, skip FLUSH, best-effort CLOSE. Use on cancellation or error paths where the partial remote file is going to be deleted anyway — `abort()` saves the fsync round-trip. The caller is responsible for deleting the partial remote file.
|
||||
|
||||
Both consume `self` so write-after-close/abort is a compile error. `Drop` logs a debug warning if neither was called (handle leaks).
|
||||
|
||||
## Session setup flow
|
||||
|
||||
1. Send NTLM NEGOTIATE in SESSION_SETUP
|
||||
2. Receive STATUS_MORE_PROCESSING_REQUIRED with challenge, update preauth hash
|
||||
3. Send NTLM AUTHENTICATE in SESSION_SETUP, update preauth hash with request only
|
||||
4. Receive STATUS_SUCCESS (do NOT include in preauth hash)
|
||||
5. Derive signing/encryption keys via SP800-108 KDF
|
||||
6. Activate signing on the connection
|
||||
7. If session or share requires encryption, activate encryption (TRANSFORM_HEADER wrapping with AEAD)
|
||||
|
||||
## Encryption
|
||||
|
||||
Encryption is activated when the session flags include `ENCRYPT_DATA` or a share has `SMB2_SHAREFLAG_ENCRYPT_DATA`. When active:
|
||||
- Outgoing messages are wrapped in TRANSFORM_HEADER (protocol ID 0xFD) with a monotonic nonce
|
||||
- Incoming messages with 0xFD are decrypted before processing
|
||||
- Signing is skipped (AEAD provides authentication)
|
||||
- Compound chains are encrypted as one unit (pitfall #9)
|
||||
|
||||
Tree-level encryption: `connect_share()` checks the share's encrypt flag and activates encryption on the connection if needed, even if the session didn't require it.
|
||||
|
||||
## Reconnection
|
||||
|
||||
`SmbClient::reconnect()` creates a fresh TCP connection, re-negotiates, and re-authenticates using stored credentials. All previous `Tree` handles and `FileId` values are invalidated. The caller must `connect_share` again.
|
||||
|
||||
## Connection internals: receiver task + `oneshot` routing
|
||||
|
||||
`Connection::execute` / `execute_compound` is the primary API. A background receiver task (spawned per `Connection` at `from_transport`) owns the transport's read half and routes each sub-frame to a per-request `oneshot::Sender` by `MessageId`.
|
||||
|
||||
- `Connection` is `Clone` and holds just `Arc<Inner>`. `Inner` owns `waiters: Mutex<HashMap<MessageId, oneshot::Sender<Result<Frame>>>>`, `credits: AtomicU32`, `next_message_id: AtomicU64`, the transport send half (via `Arc<dyn TransportSend>`), the receiver task's `JoinHandle`, and crypto state. All state is behind atomics or short-critical-section `std::sync::Mutex`.
|
||||
- `execute(command, body, tree_id)` allocates a `MessageId` (`AtomicU64::fetch_add(credit_charge)`), registers a `oneshot::Sender` in `waiters` atomically under the waiters lock (re-checks `disconnected` there to rule out a TOCTOU where the receiver task has already shut down and drained the map), packs the frame, signs/encrypts/compresses as needed, and writes through `TransportSend::send`. Then it awaits the local `oneshot::Receiver`. Returns `Result<Frame { header, body, raw }>`.
|
||||
- `execute_compound(&[CompoundOp])` does the same per sub-op, building one compound transport frame with `NextCommand` offsets, then awaits each per-sub-op receiver sequentially. Each receiver resolves independently (the receiver task splits the server's response by `NextCommand` and routes each sub-response by its `MessageId`). The outer `Result` is "did the compound hit the wire"; the inner `Vec<Result<Frame>>` has one entry per sub-op.
|
||||
- **Cancellation-by-drop is safe by construction.** If a caller's future is aborted (`tokio::spawn` + `JoinHandle::abort()` is the common path in consumers), the locally-owned `oneshot::Receiver` drops; the receiver task's `Sender::send` then fails silently when the late frame arrives; the frame is discarded. Credits are still applied in the receiver task so dropped-caller frames don't starve throughput.
|
||||
- **Transport drop** fans `Err(Disconnected)` to every pending `oneshot::Sender` and sets `disconnected=true` under the waiters lock. Subsequent `execute` / `execute_compound` sees `disconnected=true` and returns `Err(Disconnected)` without inserting (no leaked waiters).
|
||||
|
||||
Gotcha/Why — pre-Phase-3 `send_request` / `receive_response` split API was removed in Phase 3 Stage A.3. The test-mode `set_orphan_filter_enabled(false)` escape hatch is gone too; tests that build mocks without going through `setup_connection` call `mock.enable_auto_rewrite_msg_id()` instead, which rewrites each queued response's zero-msg_id to match the next pending sent msg_id in FIFO order.
|
||||
|
||||
Full design in [docs/specs/connection-actor.md](../../docs/specs/connection-actor.md).
|
||||
|
||||
## Key decisions
|
||||
|
||||
- **Owned `FileWriter`: N concurrent streamed writes over one Connection without external locking**: `FileWriter` owns its `Connection` (cheap `Arc::clone`) and `Arc<Tree>` instead of borrowing `&'a mut Connection` from the `SmbClient`. Built via the free `open_file_writer(tree: Arc<Tree>, conn: Connection, path: &str)` or one of the two convenience wrappers (`Tree::create_file_writer`, `SmbClient::create_file_writer`). Multiple writers built from clones of the same `Connection` pipeline their WRITEs over one SMB session — the receiver task multiplexes responses by `MessageId`. The borrowed variant was the root cause of a production-reproducing deadlock in the cmdr SMB volume's `write_from_stream` (Phase C QNAP test, 200 × 7 MB concurrent overwrites): the consumer had to hold its session mutex for the entire upload because the writer borrowed `&'a mut Connection`. Owning the connection removes the lock from the hot path entirely.
|
||||
- **`execute` / `execute_compound` take `&self`**: `Connection: Clone` supports concurrent ops per connection — clone freely across tasks, the receiver task multiplexes responses by `MessageId`. `Tree::*` methods still take `&mut Connection` because session-setup mutators (`activate_signing`, `set_session_id`) keep `&mut self`; Tree code calls both, so `&mut` at that layer is the least-churn choice.
|
||||
- **Sender work stays on the caller thread, only the receiver is a task**: The send path already uses an internal Mutex on the transport write half for ordering; adding a second task just to drive sends would add latency without correctness gain. The receiver bug (orphan/dropped-caller frames corrupting the wire) only existed on the receive side, so only the receive side needed a task.
|
||||
- **Compound reads as default**: One round-trip for small files. Saves 2 RTTs vs sequential CREATE/READ/CLOSE.
|
||||
- **512 KB pipeline chunks**: Balances between too many small requests (overhead) and too few large ones (credit starvation). Gives ~20 chunks per 10 MB file.
|
||||
- **Password stored in `SmbClient`**: Enables reconnect without re-prompting. Not encrypted in memory. Drop when done.
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **Preauth hash excludes the final success response**: Only STATUS_MORE_PROCESSING_REQUIRED responses are hashed. Including the success response produces wrong keys. (MS-SMB2 3.2.5.3.1)
|
||||
- **Oplock break notifications arrive with MessageId 0xFFFFFFFFFFFFFFFF**: The receiver task detects these and skips them without invoking a waiter lookup.
|
||||
- **Register-waiter must be atomic with `disconnected` check**: The waiters lock covers both reading `disconnected` and inserting the `oneshot::Sender`. If the check and insert were racy, a receiver-task failure mid-send could leave an orphan `Sender` in the map that never gets routed — caller would hang on `rx.await` forever. Same goes for `fan_error_to_waiters`: it sets `disconnected=true` UNDER the same waiters lock before draining, so new sends strictly either succeed-and-get-drained or fail at the insert check.
|
||||
- **Unrecoverable frame errors tear down the connection** (Phase 3 P3.4): decrypt failure, decompress failure, or a malformed sub-frame header that survives `split_compound` all cause the receiver task to call `fan_error_to_waiters(Err(Disconnected))` and exit. The alternative — log-and-continue — would leave the matching waiter hanging forever, because the msg_id isn't recoverable from an unparseable frame. The connection is also out of sync after one bad frame, so reconnect is the right move anyway. Counted via `MetricsSnapshot::{decrypt_failures, decompress_failures, malformed_frames}`.
|
||||
- **STATUS_PENDING loop**: CHANGE_NOTIFY and other long-poll operations get STATUS_PENDING first. The receiver task keeps the waiter registered on PENDING and does NOT forward the interim response. Credits from PENDING are still applied so the caller's `conn.credits()` reflects them. Counted via `MetricsSnapshot::status_pending_loops`.
|
||||
- **Signing and encryption are mutually exclusive on the wire**: When encrypting, zero the signature field (AEAD provides integrity). On receive, skip signature verification if decryption succeeded.
|
||||
- **Compound encryption wraps the entire chain**: One TRANSFORM_HEADER for all sub-requests concatenated, not per sub-request.
|
||||
- **Share-level encryption**: If a share has `SMB2_SHAREFLAG_ENCRYPT_DATA`, encryption is activated even if the session didn't require it.
|
||||
- **FileDownload/FileUpload can leak handles on drop**: Rust has no async drop. If not consumed fully, the file handle leaks. The types log a warning.
|
||||
- **FileWriter can leak handles on drop**: Same as FileDownload/FileUpload. Rust has no async drop. If not consumed via `finish()` or `abort()`, the file handle leaks. The type logs a debug warning.
|
||||
- **DFS paths must include server\share prefix**: When `SMB2_FLAGS_DFS_OPERATIONS` is set, the server expects the path to start with `server\share\` (MS-SMB2 3.2.4.3). `Tree::format_path()` handles this automatically for DFS shares. Without the prefix, Samba strips the first two path components, leading to wrong file opens.
|
||||
- **DFS redirect changes the tree in-place**: After a DFS redirect, `tree.server`, `tree.share_name`, and `tree.tree_id` all change. Subsequent operations on the same tree use the target server directly -- they must use target-relative paths, not the original DFS paths.
|
||||
- **tree.server stores addr:port**: The `server` field on `Tree` stores the full `addr:port` string (not just hostname) so `connection_for_tree` can distinguish servers that share the same hostname but use different ports.
|
||||
- **Servers MAY split compound responses**: MS-SMB2 section 3.3.4.1.3 says the server SHOULD compound responses but is not required to. Samba (and QNAP firmware built on it) is known to split compound chains into separate frames in some scenarios; Windows Server does too under certain conditions. Compound-using methods (`read_file_compound`, `write_file_compound`, `fs_info`, `stat`, `rename`, `delete_file`, batch `*_files`) call `Connection::receive_compound_expected(n)` instead of `receive_compound()`, which transparently gathers additional frames if the server splits. Logged at DEBUG, not WARN -- it's a spec edge case, not a problem.
|
||||
3413
vendor/smb2/src/client/connection.rs
vendored
Normal file
3413
vendor/smb2/src/client/connection.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
884
vendor/smb2/src/client/dfs.rs
vendored
Normal file
884
vendor/smb2/src/client/dfs.rs
vendored
Normal file
@@ -0,0 +1,884 @@
|
||||
//! DFS referral IOCTL helper and path resolver with referral cache.
|
||||
//!
|
||||
//! Sends `FSCTL_DFS_GET_REFERRALS` via IOCTL to resolve DFS paths. Connects
|
||||
//! to IPC$ for the IOCTL exchange, similar to how `shares.rs` does for RPC.
|
||||
//!
|
||||
//! The [`DfsResolver`] caches referral responses with TTL and resolves UNC
|
||||
//! paths using longest-prefix matching. All string comparisons are
|
||||
//! case-insensitive (DFS paths are case-insensitive per MS-DFSC).
|
||||
|
||||
// DFS resolver is used by SmbClient for reactive DFS path resolution.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use log::debug;
|
||||
|
||||
use crate::client::connection::Connection;
|
||||
use crate::error::Result;
|
||||
use crate::msg::dfs::{ReqGetDfsReferral, RespGetDfsReferral};
|
||||
use crate::msg::ioctl::{
|
||||
IoctlRequest, IoctlResponse, FSCTL_DFS_GET_REFERRALS, SMB2_0_IOCTL_IS_FSCTL,
|
||||
};
|
||||
use crate::msg::tree_connect::{TreeConnectRequest, TreeConnectRequestFlags, TreeConnectResponse};
|
||||
use crate::msg::tree_disconnect::TreeDisconnectRequest;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::status::NtStatus;
|
||||
use crate::types::{Command, FileId, TreeId};
|
||||
use crate::Error;
|
||||
|
||||
/// Maximum output buffer size for DFS referral responses (8 KiB).
|
||||
const DFS_MAX_OUTPUT_RESPONSE: u32 = 8192;
|
||||
|
||||
/// Send a DFS referral request and return the parsed response.
|
||||
///
|
||||
/// Connects to IPC$ (or reuses an existing tree), sends
|
||||
/// `FSCTL_DFS_GET_REFERRALS` via IOCTL with `FileId::SENTINEL`, and
|
||||
/// parses the response.
|
||||
///
|
||||
/// The `path` should be a UNC-style path with a single leading backslash
|
||||
/// (for example, `\server\share\dir`).
|
||||
pub(crate) async fn get_dfs_referral(
|
||||
conn: &mut Connection,
|
||||
path: &str,
|
||||
) -> Result<RespGetDfsReferral> {
|
||||
// 1. Tree-connect to IPC$
|
||||
let tree_id = tree_connect_ipc(conn).await?;
|
||||
|
||||
// Send the IOCTL, then clean up regardless of outcome
|
||||
let result = send_dfs_ioctl(conn, tree_id, path).await;
|
||||
|
||||
// Tree-disconnect IPC$ (best-effort -- don't mask the real error)
|
||||
let _ = tree_disconnect(conn, tree_id).await;
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Connect to the IPC$ share, returning the tree ID.
|
||||
async fn tree_connect_ipc(conn: &mut Connection) -> Result<TreeId> {
|
||||
let server = conn.server_name().to_string();
|
||||
let unc_path = format!(r"\\{}\IPC$", server);
|
||||
|
||||
let req = TreeConnectRequest {
|
||||
flags: TreeConnectRequestFlags::default(),
|
||||
path: unc_path,
|
||||
};
|
||||
|
||||
let frame = conn.execute(Command::TreeConnect, &req, None).await?;
|
||||
|
||||
if frame.header.command != Command::TreeConnect {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"expected TreeConnect response, got {:?}",
|
||||
frame.header.command
|
||||
)));
|
||||
}
|
||||
|
||||
if frame.header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::TreeConnect,
|
||||
});
|
||||
}
|
||||
|
||||
let mut cursor = ReadCursor::new(&frame.body);
|
||||
let _resp = TreeConnectResponse::unpack(&mut cursor)?;
|
||||
|
||||
let tree_id = frame
|
||||
.header
|
||||
.tree_id
|
||||
.ok_or_else(|| Error::invalid_data("TreeConnect response missing tree ID"))?;
|
||||
|
||||
debug!("dfs: connected to IPC$, tree_id={}", tree_id);
|
||||
Ok(tree_id)
|
||||
}
|
||||
|
||||
/// Build and send the FSCTL_DFS_GET_REFERRALS IOCTL, parse the response.
|
||||
async fn send_dfs_ioctl(
|
||||
conn: &mut Connection,
|
||||
tree_id: TreeId,
|
||||
path: &str,
|
||||
) -> Result<RespGetDfsReferral> {
|
||||
// Build the referral request payload
|
||||
let referral_req = ReqGetDfsReferral {
|
||||
max_referral_level: 4,
|
||||
request_file_name: path.to_string(),
|
||||
};
|
||||
let mut req_cursor = WriteCursor::new();
|
||||
referral_req.pack(&mut req_cursor);
|
||||
let input_data = req_cursor.into_inner();
|
||||
|
||||
debug!(
|
||||
"dfs: sending FSCTL_DFS_GET_REFERRALS for {:?} ({} bytes input)",
|
||||
path,
|
||||
input_data.len()
|
||||
);
|
||||
|
||||
// Build the IOCTL request
|
||||
let ioctl_req = IoctlRequest {
|
||||
ctl_code: FSCTL_DFS_GET_REFERRALS,
|
||||
file_id: FileId::SENTINEL,
|
||||
max_input_response: 0,
|
||||
max_output_response: DFS_MAX_OUTPUT_RESPONSE,
|
||||
flags: SMB2_0_IOCTL_IS_FSCTL,
|
||||
input_data,
|
||||
};
|
||||
|
||||
let frame = conn
|
||||
.execute(Command::Ioctl, &ioctl_req, Some(tree_id))
|
||||
.await?;
|
||||
|
||||
if frame.header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::Ioctl,
|
||||
});
|
||||
}
|
||||
|
||||
// Parse the IOCTL response envelope
|
||||
let mut cursor = ReadCursor::new(&frame.body);
|
||||
let ioctl_resp = IoctlResponse::unpack(&mut cursor)?;
|
||||
|
||||
debug!(
|
||||
"dfs: received IOCTL response ({} bytes output)",
|
||||
ioctl_resp.output_data.len()
|
||||
);
|
||||
|
||||
// Parse the DFS referral from the output buffer
|
||||
let mut ref_cursor = ReadCursor::new(&ioctl_resp.output_data);
|
||||
let referral_resp = RespGetDfsReferral::unpack(&mut ref_cursor)?;
|
||||
|
||||
debug!(
|
||||
"dfs: parsed {} referral entries (path_consumed={})",
|
||||
referral_resp.entries.len(),
|
||||
referral_resp.path_consumed
|
||||
);
|
||||
|
||||
Ok(referral_resp)
|
||||
}
|
||||
|
||||
/// Disconnect from a tree.
|
||||
async fn tree_disconnect(conn: &mut Connection, tree_id: TreeId) -> Result<()> {
|
||||
let body = TreeDisconnectRequest;
|
||||
let frame = conn
|
||||
.execute(Command::TreeDisconnect, &body, Some(tree_id))
|
||||
.await?;
|
||||
|
||||
if frame.header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::TreeDisconnect,
|
||||
});
|
||||
}
|
||||
|
||||
debug!("dfs: disconnected from IPC$");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── DFS resolver types ───────────────────────────────────────────────
|
||||
|
||||
/// A resolved DFS path ready for connection.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ResolvedPath {
|
||||
/// Server hostname (or IP) to connect to.
|
||||
pub server: String,
|
||||
/// Port to connect on (default 445).
|
||||
pub port: u16,
|
||||
/// Share name to tree-connect.
|
||||
pub share: String,
|
||||
/// Remaining path within the share (may be empty).
|
||||
pub remaining_path: String,
|
||||
}
|
||||
|
||||
/// A single DFS target from a referral response.
|
||||
#[derive(Debug, Clone)]
|
||||
struct DfsTarget {
|
||||
/// Server hostname from the network_address field.
|
||||
server: String,
|
||||
/// Share name from the network_address field.
|
||||
share: String,
|
||||
/// Any remaining path suffix from the network_address.
|
||||
remaining_prefix: String,
|
||||
}
|
||||
|
||||
/// A cached DFS referral entry with TTL.
|
||||
#[derive(Debug, Clone)]
|
||||
struct CachedReferral {
|
||||
/// The DFS path prefix this referral covers (lowercase for matching).
|
||||
dfs_path_prefix: String,
|
||||
/// Available targets (first is preferred).
|
||||
targets: Vec<DfsTarget>,
|
||||
/// When this entry expires.
|
||||
expires_at: Instant,
|
||||
}
|
||||
|
||||
/// DFS referral cache and path resolver.
|
||||
///
|
||||
/// Maintains a cache of DFS referral responses keyed by path prefix.
|
||||
/// Resolves UNC paths by longest-prefix matching against the cache,
|
||||
/// falling back to an IOCTL referral request on cache miss.
|
||||
pub(crate) struct DfsResolver {
|
||||
cache: HashMap<String, CachedReferral>,
|
||||
/// Counters surfaced through [`SmbClient::diagnostics`].
|
||||
cache_hits: AtomicU64,
|
||||
referrals_resolved: AtomicU64,
|
||||
}
|
||||
|
||||
impl DfsResolver {
|
||||
/// Create a new empty resolver.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
cache: HashMap::new(),
|
||||
cache_hits: AtomicU64::new(0),
|
||||
referrals_resolved: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// `(cache_hits, referrals_resolved)` for diagnostics.
|
||||
pub(crate) fn counters(&self) -> (u64, u64) {
|
||||
(
|
||||
self.cache_hits.load(Ordering::Relaxed),
|
||||
self.referrals_resolved.load(Ordering::Relaxed),
|
||||
)
|
||||
}
|
||||
|
||||
/// Iterate the cache entries (including expired ones — eviction is
|
||||
/// lazy). Used by [`SmbClient::diagnostics`].
|
||||
pub(crate) fn cache_entries(&self) -> Vec<crate::client::diagnostics::DfsCacheEntry> {
|
||||
let now = Instant::now();
|
||||
self.cache
|
||||
.values()
|
||||
.map(|e| crate::client::diagnostics::DfsCacheEntry {
|
||||
path_prefix: e.dfs_path_prefix.clone(),
|
||||
target_count: e.targets.len(),
|
||||
expires_in: if e.expires_at > now {
|
||||
Some(e.expires_at - now)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Resolve a UNC path by checking the cache first, then querying the server.
|
||||
///
|
||||
/// `unc_path` should be like `\\server\share\path\to\file`.
|
||||
/// `conn` is the connection to the server that returned `STATUS_PATH_NOT_COVERED`.
|
||||
pub async fn resolve(
|
||||
&mut self,
|
||||
conn: &mut Connection,
|
||||
unc_path: &str,
|
||||
) -> Result<Vec<ResolvedPath>> {
|
||||
// 1. Check cache (longest prefix match)
|
||||
if let Some(resolved) = self.resolve_from_cache(unc_path) {
|
||||
self.cache_hits.fetch_add(1, Ordering::Relaxed);
|
||||
debug!("dfs: cache hit for {:?}", unc_path);
|
||||
return Ok(resolved);
|
||||
}
|
||||
|
||||
// 2. Send referral request.
|
||||
// Convert \\server\share\path to \server\share\path (single leading
|
||||
// backslash for the IOCTL).
|
||||
let referral_path = if unc_path.starts_with("\\\\") {
|
||||
&unc_path[1..] // strip one leading backslash
|
||||
} else {
|
||||
unc_path
|
||||
};
|
||||
|
||||
debug!("dfs: cache miss, sending referral for {:?}", referral_path);
|
||||
let resp = get_dfs_referral(conn, referral_path).await?;
|
||||
self.referrals_resolved.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
// 3. Cache the result
|
||||
self.cache_referral(&resp);
|
||||
|
||||
// 4. Resolve from the freshly cached entry
|
||||
self.resolve_from_cache(unc_path).ok_or_else(|| {
|
||||
Error::invalid_data("DFS referral response did not match the requested path")
|
||||
})
|
||||
}
|
||||
|
||||
/// Try to resolve a path from the cache. Returns `None` on cache miss or
|
||||
/// expiry. Returns a `Vec` of [`ResolvedPath`]s (multiple targets for
|
||||
/// failover).
|
||||
pub(crate) fn resolve_from_cache(&self, unc_path: &str) -> Option<Vec<ResolvedPath>> {
|
||||
let normalized = unc_path.to_lowercase().replace('/', "\\");
|
||||
|
||||
// Longest prefix match
|
||||
let mut best_match: Option<&CachedReferral> = None;
|
||||
for entry in self.cache.values() {
|
||||
if normalized.starts_with(&entry.dfs_path_prefix)
|
||||
&& entry.expires_at > Instant::now()
|
||||
&& best_match.is_none_or(|b| entry.dfs_path_prefix.len() > b.dfs_path_prefix.len())
|
||||
{
|
||||
best_match = Some(entry);
|
||||
}
|
||||
}
|
||||
|
||||
let entry = best_match?;
|
||||
|
||||
// Strip the consumed prefix and build ResolvedPaths
|
||||
let remaining = &normalized[entry.dfs_path_prefix.len()..];
|
||||
let remaining = remaining.trim_start_matches('\\');
|
||||
|
||||
let resolved: Vec<ResolvedPath> = entry
|
||||
.targets
|
||||
.iter()
|
||||
.map(|target| {
|
||||
let full_remaining = if target.remaining_prefix.is_empty() {
|
||||
remaining.to_string()
|
||||
} else if remaining.is_empty() {
|
||||
target.remaining_prefix.clone()
|
||||
} else {
|
||||
format!("{}\\{}", target.remaining_prefix, remaining)
|
||||
};
|
||||
|
||||
ResolvedPath {
|
||||
server: target.server.clone(),
|
||||
port: 445,
|
||||
share: target.share.clone(),
|
||||
remaining_path: full_remaining,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Some(resolved)
|
||||
}
|
||||
|
||||
/// Store a referral response in the cache.
|
||||
fn cache_referral(&mut self, resp: &RespGetDfsReferral) {
|
||||
if resp.entries.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Use the dfs_path from the first entry as the cache key.
|
||||
// Normalize to lowercase backslash form with `\\` prefix (UNC canonical).
|
||||
let mut dfs_path_prefix = resp.entries[0].dfs_path.to_lowercase().replace('/', "\\");
|
||||
if !dfs_path_prefix.starts_with("\\\\") {
|
||||
if let Some(stripped) = dfs_path_prefix.strip_prefix('\\') {
|
||||
dfs_path_prefix = format!("\\\\{stripped}");
|
||||
}
|
||||
}
|
||||
|
||||
// Parse targets from entries
|
||||
let targets: Vec<DfsTarget> = resp
|
||||
.entries
|
||||
.iter()
|
||||
.filter_map(|e| parse_unc_target(&e.network_address))
|
||||
.collect();
|
||||
|
||||
if targets.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let ttl = resp.entries[0].ttl.max(1); // At least 1 second
|
||||
|
||||
debug!(
|
||||
"dfs: caching {:?} with {} targets, ttl={}s",
|
||||
dfs_path_prefix,
|
||||
targets.len(),
|
||||
ttl
|
||||
);
|
||||
|
||||
self.cache.insert(
|
||||
dfs_path_prefix.clone(),
|
||||
CachedReferral {
|
||||
dfs_path_prefix,
|
||||
targets,
|
||||
expires_at: Instant::now() + Duration::from_secs(ttl as u64),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a UNC network_address into server, share, and remaining path.
|
||||
///
|
||||
/// Input: `\\server\share` or `\\server\share\path`.
|
||||
/// Returns `None` if the format is invalid.
|
||||
fn parse_unc_target(network_address: &str) -> Option<DfsTarget> {
|
||||
let path = network_address.trim_start_matches('\\');
|
||||
let mut parts = path.splitn(3, '\\');
|
||||
let server = parts.next()?.to_string();
|
||||
let share = parts.next()?.to_string();
|
||||
let remaining_prefix = parts.next().unwrap_or("").to_string();
|
||||
|
||||
if server.is_empty() || share.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(DfsTarget {
|
||||
server,
|
||||
share,
|
||||
remaining_prefix,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client::connection::pack_message;
|
||||
use crate::client::test_helpers::{build_tree_connect_response, setup_connection};
|
||||
use crate::msg::header::{ErrorResponse, Header};
|
||||
use crate::msg::ioctl::IoctlResponse as IoctlResp;
|
||||
use crate::msg::tree_connect::ShareType;
|
||||
use crate::msg::tree_disconnect::TreeDisconnectResponse;
|
||||
use crate::transport::MockTransport;
|
||||
use crate::types::TreeId;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Build an IOCTL response containing the given output data.
|
||||
fn build_ioctl_response(output_data: Vec<u8>) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Ioctl);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = IoctlResp {
|
||||
ctl_code: FSCTL_DFS_GET_REFERRALS,
|
||||
file_id: FileId::SENTINEL,
|
||||
flags: SMB2_0_IOCTL_IS_FSCTL,
|
||||
output_data,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build an IOCTL error response with the given status.
|
||||
fn build_ioctl_error_response(status: NtStatus) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Ioctl);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
h.status = status;
|
||||
|
||||
let body = ErrorResponse {
|
||||
error_context_count: 0,
|
||||
error_data: vec![],
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a TREE_DISCONNECT response.
|
||||
fn build_tree_disconnect_response() -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::TreeDisconnect);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
pack_message(&h, &TreeDisconnectResponse)
|
||||
}
|
||||
|
||||
/// Pack a known DFS referral response into bytes.
|
||||
///
|
||||
/// Builds a V3 referral with the given entries.
|
||||
fn pack_dfs_referral_response(
|
||||
path_consumed: u16,
|
||||
header_flags: u32,
|
||||
entries: &[(&str, &str, &str, u32)], // (dfs_path, alt_path, net_addr, ttl)
|
||||
) -> Vec<u8> {
|
||||
// We build a V3 referral response manually.
|
||||
// Entry fixed size: 4 (version+size) + 2+2+4 (server_type+flags+ttl)
|
||||
// + 2+2+2 (offsets) + 16 (guid) = 34 bytes
|
||||
let entry_fixed_size: u16 = 34;
|
||||
let num_entries = entries.len() as u16;
|
||||
let total_fixed = entry_fixed_size * num_entries;
|
||||
|
||||
// Pre-compute all string bytes
|
||||
let entry_strings: Vec<(Vec<u8>, Vec<u8>, Vec<u8>)> = entries
|
||||
.iter()
|
||||
.map(|(dfs, alt, net, _)| {
|
||||
(
|
||||
encode_null_utf16(dfs),
|
||||
encode_null_utf16(alt),
|
||||
encode_null_utf16(net),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Compute cumulative string offsets relative to each entry's start.
|
||||
// All strings come after all fixed entries. The offset for entry i
|
||||
// is relative to entry i's start position.
|
||||
let mut buf = Vec::new();
|
||||
|
||||
// Response header (8 bytes)
|
||||
buf.extend_from_slice(&path_consumed.to_le_bytes());
|
||||
buf.extend_from_slice(&num_entries.to_le_bytes());
|
||||
buf.extend_from_slice(&header_flags.to_le_bytes());
|
||||
|
||||
// Calculate where strings start (after all fixed entries, but
|
||||
// offsets are measured from the start of the entry data, not from
|
||||
// the response header -- since RespGetDfsReferral::unpack reads
|
||||
// the header first and then works with the remaining bytes).
|
||||
//
|
||||
// Actually, offsets in V3 entries are relative to the entry start
|
||||
// within the entry data buffer.
|
||||
|
||||
// Accumulate string buffer contents and compute per-entry offsets.
|
||||
let mut string_buf = Vec::new();
|
||||
let mut per_entry_offsets = Vec::new();
|
||||
|
||||
for (i, (dfs_bytes, alt_bytes, net_bytes)) in entry_strings.iter().enumerate() {
|
||||
let entry_start = i as u16 * entry_fixed_size;
|
||||
let strings_base = total_fixed + string_buf.len() as u16;
|
||||
|
||||
let dfs_offset = strings_base - entry_start;
|
||||
let alt_offset = dfs_offset + dfs_bytes.len() as u16;
|
||||
let net_offset = alt_offset + alt_bytes.len() as u16;
|
||||
|
||||
per_entry_offsets.push((dfs_offset, alt_offset, net_offset));
|
||||
|
||||
string_buf.extend_from_slice(dfs_bytes);
|
||||
string_buf.extend_from_slice(alt_bytes);
|
||||
string_buf.extend_from_slice(net_bytes);
|
||||
}
|
||||
|
||||
// Write fixed entries
|
||||
for (i, (_, _, _, ttl)) in entries.iter().enumerate() {
|
||||
let (dfs_off, alt_off, net_off) = per_entry_offsets[i];
|
||||
|
||||
buf.extend_from_slice(&3u16.to_le_bytes()); // version = 3
|
||||
buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); // size
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // server_type
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // referral_entry_flags
|
||||
buf.extend_from_slice(&ttl.to_le_bytes()); // ttl
|
||||
buf.extend_from_slice(&dfs_off.to_le_bytes());
|
||||
buf.extend_from_slice(&alt_off.to_le_bytes());
|
||||
buf.extend_from_slice(&net_off.to_le_bytes());
|
||||
buf.extend_from_slice(&[0u8; 16]); // service_site_guid
|
||||
}
|
||||
|
||||
// Write string buffer
|
||||
buf.extend_from_slice(&string_buf);
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
/// Encode a string as null-terminated UTF-16LE bytes.
|
||||
fn encode_null_utf16(s: &str) -> Vec<u8> {
|
||||
let mut out = Vec::new();
|
||||
for cu in s.encode_utf16() {
|
||||
out.extend_from_slice(&cu.to_le_bytes());
|
||||
}
|
||||
out.extend_from_slice(&[0x00, 0x00]);
|
||||
out
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dfs_referral_ioctl_flow() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let mut conn = setup_connection(&mock);
|
||||
|
||||
let tree_id = TreeId(99);
|
||||
|
||||
// Build the DFS referral payload
|
||||
let referral_bytes = pack_dfs_referral_response(
|
||||
48, // path_consumed
|
||||
0x02, // header_flags (StorageServers)
|
||||
&[
|
||||
(
|
||||
r"\domain\dfs\docs",
|
||||
r"\domain\dfs\docs",
|
||||
r"\server1\share",
|
||||
600,
|
||||
),
|
||||
(
|
||||
r"\domain\dfs\docs",
|
||||
r"\domain\dfs\docs",
|
||||
r"\server2\share",
|
||||
300,
|
||||
),
|
||||
],
|
||||
);
|
||||
|
||||
// Queue responses: TreeConnect, IOCTL, TreeDisconnect
|
||||
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
|
||||
mock.queue_response(build_ioctl_response(referral_bytes));
|
||||
mock.queue_response(build_tree_disconnect_response());
|
||||
|
||||
let resp = get_dfs_referral(&mut conn, r"\domain\dfs\docs")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.path_consumed, 48);
|
||||
assert_eq!(resp.header_flags, 0x02);
|
||||
assert_eq!(resp.entries.len(), 2);
|
||||
|
||||
assert_eq!(resp.entries[0].version, 3);
|
||||
assert_eq!(resp.entries[0].dfs_path, r"\domain\dfs\docs");
|
||||
assert_eq!(resp.entries[0].network_address, r"\server1\share");
|
||||
assert_eq!(resp.entries[0].ttl, 600);
|
||||
|
||||
assert_eq!(resp.entries[1].network_address, r"\server2\share");
|
||||
assert_eq!(resp.entries[1].ttl, 300);
|
||||
|
||||
// Should have sent 3 messages: TreeConnect, IOCTL, TreeDisconnect
|
||||
assert_eq!(mock.sent_count(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dfs_referral_ioctl_error() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let mut conn = setup_connection(&mock);
|
||||
|
||||
let tree_id = TreeId(99);
|
||||
|
||||
// Queue responses: TreeConnect, IOCTL error, TreeDisconnect
|
||||
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
|
||||
mock.queue_response(build_ioctl_error_response(NtStatus::NOT_FOUND));
|
||||
mock.queue_response(build_tree_disconnect_response());
|
||||
|
||||
let result = get_dfs_referral(&mut conn, r"\nonexistent\path").await;
|
||||
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
match &err {
|
||||
Error::Protocol { status, command } => {
|
||||
assert_eq!(*status, NtStatus::NOT_FOUND);
|
||||
assert_eq!(*command, Command::Ioctl);
|
||||
}
|
||||
other => panic!("expected Protocol error, got: {other:?}"),
|
||||
}
|
||||
|
||||
// Should still send TreeDisconnect even after IOCTL error
|
||||
assert_eq!(mock.sent_count(), 3);
|
||||
}
|
||||
|
||||
// ── parse_unc_target tests ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn parse_unc_target_basic() {
|
||||
let t = parse_unc_target(r"\\server\share").unwrap();
|
||||
assert_eq!(t.server, "server");
|
||||
assert_eq!(t.share, "share");
|
||||
assert_eq!(t.remaining_prefix, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unc_target_with_path() {
|
||||
let t = parse_unc_target(r"\\server\share\path\to").unwrap();
|
||||
assert_eq!(t.server, "server");
|
||||
assert_eq!(t.share, "share");
|
||||
assert_eq!(t.remaining_prefix, r"path\to");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unc_target_invalid() {
|
||||
assert!(parse_unc_target(r"\\").is_none());
|
||||
assert!(parse_unc_target("").is_none());
|
||||
assert!(parse_unc_target(r"\\server").is_none());
|
||||
// Single backslash + server but no share
|
||||
assert!(parse_unc_target(r"\server").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unc_target_single_backslash_prefix() {
|
||||
// Network addresses with single backslash prefix should also work.
|
||||
let t = parse_unc_target(r"\server\share").unwrap();
|
||||
assert_eq!(t.server, "server");
|
||||
assert_eq!(t.share, "share");
|
||||
assert_eq!(t.remaining_prefix, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unc_target_triple_backslash() {
|
||||
// Extra leading backslashes are stripped.
|
||||
let t = parse_unc_target(r"\\\server\share\path").unwrap();
|
||||
assert_eq!(t.server, "server");
|
||||
assert_eq!(t.share, "share");
|
||||
assert_eq!(t.remaining_prefix, "path");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unc_target_ip_address() {
|
||||
// IP addresses as server names.
|
||||
let t = parse_unc_target(r"\\192.168.1.100\data").unwrap();
|
||||
assert_eq!(t.server, "192.168.1.100");
|
||||
assert_eq!(t.share, "data");
|
||||
assert_eq!(t.remaining_prefix, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unc_target_deep_path() {
|
||||
// The remaining prefix captures everything after server\share.
|
||||
let t = parse_unc_target(r"\\server\share\a\b\c\d").unwrap();
|
||||
assert_eq!(t.server, "server");
|
||||
assert_eq!(t.share, "share");
|
||||
assert_eq!(t.remaining_prefix, r"a\b\c\d");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_unc_target_empty_components() {
|
||||
// Empty server or share should return None.
|
||||
assert!(parse_unc_target(r"\\\\share").is_none()); // empty server
|
||||
assert!(parse_unc_target(r"\\\").is_none()); // server is empty after strip
|
||||
}
|
||||
|
||||
// ── DfsResolver tests ────────────────────────────────────────────
|
||||
|
||||
/// Helper: build a RespGetDfsReferral for cache tests.
|
||||
fn make_referral(
|
||||
dfs_path: &str,
|
||||
entries: &[(&str, u32)], // (network_address, ttl)
|
||||
) -> RespGetDfsReferral {
|
||||
use crate::msg::dfs::DfsReferralEntry;
|
||||
|
||||
let referral_entries: Vec<DfsReferralEntry> = entries
|
||||
.iter()
|
||||
.map(|(net_addr, ttl)| DfsReferralEntry {
|
||||
version: 3,
|
||||
server_type: 0,
|
||||
referral_entry_flags: 0,
|
||||
ttl: *ttl,
|
||||
dfs_path: dfs_path.to_string(),
|
||||
dfs_alternate_path: dfs_path.to_string(),
|
||||
network_address: net_addr.to_string(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
RespGetDfsReferral {
|
||||
path_consumed: 0,
|
||||
header_flags: 0,
|
||||
entries: referral_entries,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolver_cache_hit() {
|
||||
let mut resolver = DfsResolver::new();
|
||||
|
||||
let resp = make_referral(r"\domain\dfs\docs", &[(r"\\server1\share", 600)]);
|
||||
resolver.cache_referral(&resp);
|
||||
|
||||
let result = resolver.resolve_from_cache(r"\\domain\dfs\docs\file.txt");
|
||||
assert!(result.is_some());
|
||||
let paths = result.unwrap();
|
||||
assert_eq!(paths.len(), 1);
|
||||
assert_eq!(paths[0].server, "server1");
|
||||
assert_eq!(paths[0].share, "share");
|
||||
assert_eq!(paths[0].port, 445);
|
||||
assert_eq!(paths[0].remaining_path, "file.txt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolver_cache_miss() {
|
||||
let resolver = DfsResolver::new();
|
||||
|
||||
let result = resolver.resolve_from_cache(r"\\server\share\file.txt");
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolver_cache_expired() {
|
||||
let mut resolver = DfsResolver::new();
|
||||
|
||||
// Insert with TTL=0 -- cache_referral clamps to 1s, so we need to
|
||||
// manually insert an already-expired entry.
|
||||
let targets = vec![DfsTarget {
|
||||
server: "srv".to_string(),
|
||||
share: "data".to_string(),
|
||||
remaining_prefix: String::new(),
|
||||
}];
|
||||
resolver.cache.insert(
|
||||
r"\domain\dfs".to_string(),
|
||||
CachedReferral {
|
||||
dfs_path_prefix: r"\domain\dfs".to_string(),
|
||||
targets,
|
||||
expires_at: Instant::now() - Duration::from_secs(1),
|
||||
},
|
||||
);
|
||||
|
||||
let result = resolver.resolve_from_cache(r"\\domain\dfs\file.txt");
|
||||
assert!(result.is_none(), "expired entry should not match");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolver_cache_longest_prefix() {
|
||||
let mut resolver = DfsResolver::new();
|
||||
|
||||
// Insert a short prefix
|
||||
let short = make_referral(r"\domain\dfs", &[(r"\\server1\root", 600)]);
|
||||
resolver.cache_referral(&short);
|
||||
|
||||
// Insert a longer prefix
|
||||
let long = make_referral(r"\domain\dfs\docs", &[(r"\\server2\docs", 600)]);
|
||||
resolver.cache_referral(&long);
|
||||
|
||||
// Should match the longer prefix
|
||||
let result = resolver
|
||||
.resolve_from_cache(r"\\domain\dfs\docs\file.txt")
|
||||
.unwrap();
|
||||
assert_eq!(result[0].server, "server2");
|
||||
assert_eq!(result[0].share, "docs");
|
||||
assert_eq!(result[0].remaining_path, "file.txt");
|
||||
|
||||
// A path that only matches the short prefix
|
||||
let result2 = resolver
|
||||
.resolve_from_cache(r"\\domain\dfs\other\file.txt")
|
||||
.unwrap();
|
||||
assert_eq!(result2[0].server, "server1");
|
||||
assert_eq!(result2[0].share, "root");
|
||||
assert_eq!(result2[0].remaining_path, r"other\file.txt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolver_multiple_targets() {
|
||||
let mut resolver = DfsResolver::new();
|
||||
|
||||
let resp = make_referral(
|
||||
r"\domain\dfs\docs",
|
||||
&[(r"\\server1\share", 600), (r"\\server2\share", 300)],
|
||||
);
|
||||
resolver.cache_referral(&resp);
|
||||
|
||||
let result = resolver
|
||||
.resolve_from_cache(r"\\domain\dfs\docs\file.txt")
|
||||
.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].server, "server1");
|
||||
assert_eq!(result[1].server, "server2");
|
||||
// Both should have the same remaining path
|
||||
assert_eq!(result[0].remaining_path, "file.txt");
|
||||
assert_eq!(result[1].remaining_path, "file.txt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolver_path_normalization() {
|
||||
let mut resolver = DfsResolver::new();
|
||||
|
||||
// Cache with backslash-separated DFS path
|
||||
let resp = make_referral(r"\domain\dfs\docs", &[(r"\\server\share", 600)]);
|
||||
resolver.cache_referral(&resp);
|
||||
|
||||
// Resolve with double-backslash prefix and mixed case
|
||||
let result = resolver
|
||||
.resolve_from_cache(r"\\DOMAIN\DFS\DOCS\Sub\File.txt")
|
||||
.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].server, "server");
|
||||
assert_eq!(result[0].share, "share");
|
||||
// remaining_path is lowercased because we normalize the full input
|
||||
assert_eq!(result[0].remaining_path, r"sub\file.txt");
|
||||
|
||||
// Forward slashes should also work
|
||||
let result2 = resolver
|
||||
.resolve_from_cache(r"\\domain/dfs/docs/other.txt")
|
||||
.unwrap();
|
||||
assert_eq!(result2[0].remaining_path, "other.txt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolver_remaining_prefix_from_target() {
|
||||
let mut resolver = DfsResolver::new();
|
||||
|
||||
// Target has a remaining prefix (network_address includes a subpath)
|
||||
let resp = make_referral(r"\domain\dfs\docs", &[(r"\\server\share\subdir", 600)]);
|
||||
resolver.cache_referral(&resp);
|
||||
|
||||
// With additional path after the DFS prefix
|
||||
let result = resolver
|
||||
.resolve_from_cache(r"\\domain\dfs\docs\file.txt")
|
||||
.unwrap();
|
||||
assert_eq!(result[0].remaining_path, r"subdir\file.txt");
|
||||
|
||||
// Without additional path -- just the target's remaining prefix
|
||||
let result2 = resolver.resolve_from_cache(r"\\domain\dfs\docs").unwrap();
|
||||
assert_eq!(result2[0].remaining_path, "subdir");
|
||||
}
|
||||
}
|
||||
1048
vendor/smb2/src/client/diagnostics.rs
vendored
Normal file
1048
vendor/smb2/src/client/diagnostics.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1495
vendor/smb2/src/client/mod.rs
vendored
Normal file
1495
vendor/smb2/src/client/mod.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
670
vendor/smb2/src/client/pipeline.rs
vendored
Normal file
670
vendor/smb2/src/client/pipeline.rs
vendored
Normal file
@@ -0,0 +1,670 @@
|
||||
//! Unified operation pipeline for concurrent SMB2 operations.
|
||||
//!
|
||||
//! The [`Pipeline`] sends multiple SMB2 requests without waiting for each
|
||||
//! response, filling the credit window. Results are collected and returned
|
||||
//! once all operations complete.
|
||||
//!
|
||||
//! This is a first-iteration pipeline that executes a batch of operations.
|
||||
//! Future iterations will add a channel-based streaming interface, compound
|
||||
//! request construction, and chunk-level interleaving for large files.
|
||||
|
||||
use log::debug;
|
||||
|
||||
use crate::client::connection::Connection;
|
||||
use crate::client::tree::Tree;
|
||||
|
||||
/// An operation to execute through the pipeline.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Op {
|
||||
/// Read a file, returning its contents.
|
||||
ReadFile(String),
|
||||
/// Write data to a file (create or overwrite).
|
||||
WriteFile(String, Vec<u8>),
|
||||
/// Delete a file.
|
||||
Delete(String),
|
||||
/// List a directory.
|
||||
ListDirectory(String),
|
||||
/// Get file metadata.
|
||||
Stat(String),
|
||||
}
|
||||
|
||||
/// Result of a pipeline operation.
|
||||
#[derive(Debug)]
|
||||
pub enum OpResult {
|
||||
/// File data read successfully.
|
||||
FileData {
|
||||
/// The path that was read.
|
||||
path: String,
|
||||
/// The file contents.
|
||||
data: Vec<u8>,
|
||||
},
|
||||
/// File written successfully.
|
||||
Written {
|
||||
/// The path that was written.
|
||||
path: String,
|
||||
/// Number of bytes written.
|
||||
bytes_written: u64,
|
||||
},
|
||||
/// File deleted successfully.
|
||||
Deleted {
|
||||
/// The path that was deleted.
|
||||
path: String,
|
||||
},
|
||||
/// Directory listing.
|
||||
DirEntries {
|
||||
/// The path that was listed.
|
||||
path: String,
|
||||
/// The directory entries.
|
||||
entries: Vec<crate::client::tree::DirectoryEntry>,
|
||||
},
|
||||
/// File metadata.
|
||||
Stat {
|
||||
/// The path that was queried.
|
||||
path: String,
|
||||
/// The file information.
|
||||
info: crate::client::tree::FileInfo,
|
||||
},
|
||||
/// Operation failed.
|
||||
Error {
|
||||
/// The path that failed.
|
||||
path: String,
|
||||
/// The error that occurred.
|
||||
error: crate::Error,
|
||||
},
|
||||
}
|
||||
|
||||
/// A pipeline for executing multiple SMB operations as a batch.
|
||||
///
|
||||
/// The pipeline executes operations sequentially in this first iteration.
|
||||
/// Each multi-step operation (for example, read = CREATE + READ + CLOSE) runs
|
||||
/// to completion before the next operation starts. Future iterations will
|
||||
/// interleave steps from different operations to fill the credit window.
|
||||
pub struct Pipeline<'a> {
|
||||
conn: &'a mut Connection,
|
||||
tree: &'a Tree,
|
||||
}
|
||||
|
||||
impl<'a> Pipeline<'a> {
|
||||
/// Create a new pipeline bound to a connection and tree.
|
||||
pub fn new(conn: &'a mut Connection, tree: &'a Tree) -> Self {
|
||||
Self { conn, tree }
|
||||
}
|
||||
|
||||
/// Execute a batch of operations and return the results.
|
||||
///
|
||||
/// Results are returned in the same order as the input operations.
|
||||
/// Each operation that fails produces an [`OpResult::Error`] rather
|
||||
/// than aborting the entire batch.
|
||||
pub async fn execute(&mut self, ops: Vec<Op>) -> Vec<OpResult> {
|
||||
let mut results = Vec::with_capacity(ops.len());
|
||||
|
||||
for op in ops {
|
||||
let result = self.execute_one(op).await;
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Execute a single operation.
|
||||
async fn execute_one(&mut self, op: Op) -> OpResult {
|
||||
match op {
|
||||
Op::ReadFile(path) => {
|
||||
debug!("pipeline: read_file path={}", path);
|
||||
match self.tree.read_file(self.conn, &path).await {
|
||||
Ok(data) => OpResult::FileData { path, data },
|
||||
Err(e) => OpResult::Error { path, error: e },
|
||||
}
|
||||
}
|
||||
Op::WriteFile(path, data) => {
|
||||
debug!("pipeline: write_file path={}", path);
|
||||
match self.tree.write_file(self.conn, &path, &data).await {
|
||||
Ok(bytes_written) => OpResult::Written {
|
||||
path,
|
||||
bytes_written,
|
||||
},
|
||||
Err(e) => OpResult::Error { path, error: e },
|
||||
}
|
||||
}
|
||||
Op::Delete(path) => {
|
||||
debug!("pipeline: delete path={}", path);
|
||||
match self.tree.delete_file(self.conn, &path).await {
|
||||
Ok(()) => OpResult::Deleted { path },
|
||||
Err(e) => OpResult::Error { path, error: e },
|
||||
}
|
||||
}
|
||||
Op::ListDirectory(path) => {
|
||||
debug!("pipeline: list_directory path={}", path);
|
||||
match self.tree.list_directory(self.conn, &path).await {
|
||||
Ok(entries) => OpResult::DirEntries { path, entries },
|
||||
Err(e) => OpResult::Error { path, error: e },
|
||||
}
|
||||
}
|
||||
Op::Stat(path) => {
|
||||
debug!("pipeline: stat path={}", path);
|
||||
match self.tree.stat(self.conn, &path).await {
|
||||
Ok(info) => OpResult::Stat { path, info },
|
||||
Err(e) => OpResult::Error { path, error: e },
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client::connection::pack_message;
|
||||
use crate::client::test_helpers::{
|
||||
build_close_response, build_create_response, setup_connection,
|
||||
};
|
||||
use crate::client::tree::Tree;
|
||||
use crate::msg::create::{CreateAction, CreateResponse};
|
||||
use crate::msg::header::{ErrorResponse, Header};
|
||||
use crate::msg::query_directory::QueryDirectoryResponse;
|
||||
use crate::msg::query_info::QueryInfoResponse;
|
||||
use crate::msg::read::ReadResponse;
|
||||
use crate::msg::write::WriteResponse;
|
||||
use crate::pack::FileTime;
|
||||
use crate::transport::MockTransport;
|
||||
use crate::types::status::NtStatus;
|
||||
use crate::types::{Command, FileId, OplockLevel, TreeId};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn test_tree() -> Tree {
|
||||
Tree {
|
||||
tree_id: TreeId(10),
|
||||
share_name: "test".to_string(),
|
||||
server: "test-server".to_string(),
|
||||
is_dfs: false,
|
||||
encrypt_data: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_create_response_directory(file_id: FileId) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Create);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = CreateResponse {
|
||||
oplock_level: OplockLevel::None,
|
||||
flags: 0,
|
||||
create_action: CreateAction::FileOpened,
|
||||
creation_time: FileTime(132_000_000_000_000_000),
|
||||
last_access_time: FileTime(132_000_000_000_000_000),
|
||||
last_write_time: FileTime(133_000_000_000_000_000),
|
||||
change_time: FileTime(133_000_000_000_000_000),
|
||||
allocation_size: 0,
|
||||
end_of_file: 0,
|
||||
file_attributes: 0x10, // DIRECTORY
|
||||
file_id,
|
||||
create_contexts: vec![],
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
fn build_flush_response() -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Flush);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = crate::msg::flush::FlushResponse;
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
fn build_read_response(data: Vec<u8>) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Read);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = ReadResponse {
|
||||
data_offset: 0x50,
|
||||
data_remaining: 0,
|
||||
flags: 0,
|
||||
data,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
fn build_write_response(count: u32) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Write);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = WriteResponse {
|
||||
count,
|
||||
remaining: 0,
|
||||
write_channel_info_offset: 0,
|
||||
write_channel_info_length: 0,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
fn build_query_info_response(output_buffer: Vec<u8>) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::QueryInfo);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = QueryInfoResponse { output_buffer };
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
fn build_query_directory_response(status: NtStatus, entries_data: Vec<u8>) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::QueryDirectory);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
h.status = status;
|
||||
|
||||
if status == NtStatus::NO_MORE_FILES {
|
||||
let body = ErrorResponse {
|
||||
error_context_count: 0,
|
||||
error_data: vec![],
|
||||
};
|
||||
return pack_message(&h, &body);
|
||||
}
|
||||
|
||||
let body = QueryDirectoryResponse {
|
||||
output_buffer: entries_data,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a FileBasicInformation buffer (40 bytes).
|
||||
fn build_file_basic_info(
|
||||
creation_time: u64,
|
||||
last_access_time: u64,
|
||||
last_write_time: u64,
|
||||
change_time: u64,
|
||||
file_attributes: u32,
|
||||
) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&creation_time.to_le_bytes());
|
||||
buf.extend_from_slice(&last_access_time.to_le_bytes());
|
||||
buf.extend_from_slice(&last_write_time.to_le_bytes());
|
||||
buf.extend_from_slice(&change_time.to_le_bytes());
|
||||
buf.extend_from_slice(&file_attributes.to_le_bytes());
|
||||
// Padding to 40 bytes (Reserved)
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build a FileStandardInformation buffer (24 bytes).
|
||||
fn build_file_standard_info(
|
||||
allocation_size: u64,
|
||||
end_of_file: u64,
|
||||
number_of_links: u32,
|
||||
delete_pending: bool,
|
||||
directory: bool,
|
||||
) -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&allocation_size.to_le_bytes());
|
||||
buf.extend_from_slice(&end_of_file.to_le_bytes());
|
||||
buf.extend_from_slice(&number_of_links.to_le_bytes());
|
||||
buf.push(if delete_pending { 1 } else { 0 });
|
||||
buf.push(if directory { 1 } else { 0 });
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // Reserved
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build a single FileBothDirectoryInformation entry.
|
||||
fn build_file_both_dir_info(
|
||||
name: &str,
|
||||
size: u64,
|
||||
is_directory: bool,
|
||||
next_offset: u32,
|
||||
) -> Vec<u8> {
|
||||
let name_u16: Vec<u16> = name.encode_utf16().collect();
|
||||
let name_bytes_len = name_u16.len() * 2;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&next_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&0u32.to_le_bytes()); // FileIndex
|
||||
buf.extend_from_slice(&132_000_000_000_000_000u64.to_le_bytes()); // CreationTime
|
||||
buf.extend_from_slice(&132_000_000_000_000_000u64.to_le_bytes()); // LastAccessTime
|
||||
buf.extend_from_slice(&133_000_000_000_000_000u64.to_le_bytes()); // LastWriteTime
|
||||
buf.extend_from_slice(&133_000_000_000_000_000u64.to_le_bytes()); // ChangeTime
|
||||
buf.extend_from_slice(&size.to_le_bytes());
|
||||
buf.extend_from_slice(&((size + 4095) & !4095).to_le_bytes()); // AllocationSize
|
||||
let attrs: u32 = if is_directory { 0x10 } else { 0x20 };
|
||||
buf.extend_from_slice(&attrs.to_le_bytes());
|
||||
buf.extend_from_slice(&(name_bytes_len as u32).to_le_bytes());
|
||||
buf.extend_from_slice(&0u32.to_le_bytes()); // EaSize
|
||||
buf.push(0); // ShortNameLength
|
||||
buf.push(0); // Reserved
|
||||
buf.extend_from_slice(&[0u8; 24]); // ShortName
|
||||
for &u in &name_u16 {
|
||||
buf.extend_from_slice(&u.to_le_bytes());
|
||||
}
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build a compound response frame with proper NextCommand offsets and padding.
|
||||
fn build_compound_response_frame(responses: &[Vec<u8>]) -> Vec<u8> {
|
||||
let mut padded: Vec<Vec<u8>> = Vec::new();
|
||||
for (i, resp) in responses.iter().enumerate() {
|
||||
let mut r = resp.clone();
|
||||
let is_last = i == responses.len() - 1;
|
||||
if !is_last {
|
||||
// Pad to 8-byte alignment.
|
||||
let remainder = r.len() % 8;
|
||||
if remainder != 0 {
|
||||
r.resize(r.len() + (8 - remainder), 0);
|
||||
}
|
||||
// Set NextCommand.
|
||||
let next_cmd = r.len() as u32;
|
||||
r[20..24].copy_from_slice(&next_cmd.to_le_bytes());
|
||||
}
|
||||
padded.push(r);
|
||||
}
|
||||
let mut frame = Vec::new();
|
||||
for r in &padded {
|
||||
frame.extend_from_slice(r);
|
||||
}
|
||||
frame
|
||||
}
|
||||
|
||||
/// Build a compound read response frame (CREATE + READ + CLOSE) for pipeline tests.
|
||||
fn build_compound_read_response(file_id: FileId, data: Vec<u8>) -> Vec<u8> {
|
||||
let create_resp = build_create_response(file_id, data.len() as u64);
|
||||
let read_resp = build_read_response(data);
|
||||
let close_resp = build_close_response();
|
||||
build_compound_response_frame(&[create_resp, read_resp, close_resp])
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_batch_of_three_reads() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
|
||||
let file_id = FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
};
|
||||
|
||||
// Three read operations, each needs a compound CREATE + READ + CLOSE frame.
|
||||
for i in 0..3 {
|
||||
let data = format!("content_{}", i);
|
||||
mock.queue_response(build_compound_read_response(file_id, data.into_bytes()));
|
||||
}
|
||||
|
||||
let mut conn = setup_connection(&mock);
|
||||
let tree = test_tree();
|
||||
let mut pipeline = Pipeline::new(&mut conn, &tree);
|
||||
|
||||
let results = pipeline
|
||||
.execute(vec![
|
||||
Op::ReadFile("file1.txt".to_string()),
|
||||
Op::ReadFile("file2.txt".to_string()),
|
||||
Op::ReadFile("file3.txt".to_string()),
|
||||
])
|
||||
.await;
|
||||
|
||||
assert_eq!(results.len(), 3);
|
||||
for (i, result) in results.into_iter().enumerate() {
|
||||
match result {
|
||||
OpResult::FileData { path, data } => {
|
||||
assert_eq!(path, format!("file{}.txt", i + 1));
|
||||
assert_eq!(data, format!("content_{}", i).into_bytes());
|
||||
}
|
||||
other => panic!("expected FileData, got {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_mixed_ops() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
|
||||
let file_id = FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
};
|
||||
|
||||
// Op 1: ReadFile -- compound CREATE + READ + CLOSE
|
||||
mock.queue_response(build_compound_read_response(file_id, b"hello".to_vec()));
|
||||
|
||||
// Op 2: Delete -- compound CREATE(DELETE_ON_CLOSE) + CLOSE
|
||||
let del_create = build_create_response(file_id, 0);
|
||||
let del_close = build_close_response();
|
||||
mock.queue_response(build_compound_response_frame(&[del_create, del_close]));
|
||||
|
||||
// Op 3: ListDirectory -- CREATE + QUERY_DIR + QUERY_DIR(NO_MORE) + CLOSE
|
||||
mock.queue_response(build_create_response_directory(file_id));
|
||||
let entry = build_file_both_dir_info("test.txt", 100, false, 0);
|
||||
mock.queue_response(build_query_directory_response(NtStatus::SUCCESS, entry));
|
||||
mock.queue_response(build_query_directory_response(
|
||||
NtStatus::NO_MORE_FILES,
|
||||
vec![],
|
||||
));
|
||||
mock.queue_response(build_close_response());
|
||||
|
||||
let mut conn = setup_connection(&mock);
|
||||
let tree = test_tree();
|
||||
let mut pipeline = Pipeline::new(&mut conn, &tree);
|
||||
|
||||
let results = pipeline
|
||||
.execute(vec![
|
||||
Op::ReadFile("data.bin".to_string()),
|
||||
Op::Delete("old.txt".to_string()),
|
||||
Op::ListDirectory("docs".to_string()),
|
||||
])
|
||||
.await;
|
||||
|
||||
assert_eq!(results.len(), 3);
|
||||
|
||||
match &results[0] {
|
||||
OpResult::FileData { data, .. } => assert_eq!(data, b"hello"),
|
||||
other => panic!("expected FileData, got {:?}", other),
|
||||
}
|
||||
|
||||
match &results[1] {
|
||||
OpResult::Deleted { path } => assert_eq!(path, "old.txt"),
|
||||
other => panic!("expected Deleted, got {:?}", other),
|
||||
}
|
||||
|
||||
match &results[2] {
|
||||
OpResult::DirEntries { entries, .. } => {
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert_eq!(entries[0].name, "test.txt");
|
||||
}
|
||||
other => panic!("expected DirEntries, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_delete_file() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let file_id = FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
};
|
||||
|
||||
// DELETE = compound CREATE(DELETE_ON_CLOSE) + CLOSE
|
||||
let create_resp = build_create_response(file_id, 0);
|
||||
let close_resp = build_close_response();
|
||||
let frame = build_compound_response_frame(&[create_resp, close_resp]);
|
||||
mock.queue_response(frame);
|
||||
|
||||
let mut conn = setup_connection(&mock);
|
||||
let tree = test_tree();
|
||||
let mut pipeline = Pipeline::new(&mut conn, &tree);
|
||||
|
||||
let results = pipeline
|
||||
.execute(vec![Op::Delete("remove_me.txt".to_string())])
|
||||
.await;
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
match &results[0] {
|
||||
OpResult::Deleted { path } => assert_eq!(path, "remove_me.txt"),
|
||||
other => panic!("expected Deleted, got {:?}", other),
|
||||
}
|
||||
|
||||
// One compound frame sent.
|
||||
let sent = mock.sent_messages();
|
||||
assert_eq!(sent.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_write_file() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let file_id = FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
};
|
||||
|
||||
// WRITE uses compound: CREATE+WRITE+FLUSH+CLOSE in one frame.
|
||||
let create_resp = build_create_response(file_id, 0);
|
||||
let write_resp = build_write_response(11);
|
||||
let flush_resp = build_flush_response();
|
||||
let close_resp = build_close_response();
|
||||
let frame =
|
||||
build_compound_response_frame(&[create_resp, write_resp, flush_resp, close_resp]);
|
||||
mock.queue_response(frame);
|
||||
|
||||
let mut conn = setup_connection(&mock);
|
||||
let tree = test_tree();
|
||||
let mut pipeline = Pipeline::new(&mut conn, &tree);
|
||||
|
||||
let results = pipeline
|
||||
.execute(vec![Op::WriteFile(
|
||||
"output.txt".to_string(),
|
||||
b"hello world".to_vec(),
|
||||
)])
|
||||
.await;
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
match &results[0] {
|
||||
OpResult::Written {
|
||||
path,
|
||||
bytes_written,
|
||||
} => {
|
||||
assert_eq!(path, "output.txt");
|
||||
assert_eq!(*bytes_written, 11);
|
||||
}
|
||||
other => panic!("expected Written, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_stat() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let file_id = FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
};
|
||||
|
||||
// STAT = compound CREATE + QUERY_INFO(basic) + QUERY_INFO(standard) + CLOSE
|
||||
let create_resp = build_create_response(file_id, 0);
|
||||
|
||||
let basic_info = build_file_basic_info(
|
||||
132_000_000_000_000_000,
|
||||
132_100_000_000_000_000,
|
||||
133_000_000_000_000_000,
|
||||
133_000_000_000_000_000,
|
||||
0x20, // ARCHIVE (not a directory)
|
||||
);
|
||||
let basic_resp = build_query_info_response(basic_info);
|
||||
|
||||
let std_info = build_file_standard_info(
|
||||
4096, // allocation_size
|
||||
2048, // end_of_file (actual size)
|
||||
1, // number_of_links
|
||||
false, // delete_pending
|
||||
false, // directory
|
||||
);
|
||||
let std_resp = build_query_info_response(std_info);
|
||||
|
||||
let close_resp = build_close_response();
|
||||
|
||||
let frame = build_compound_response_frame(&[create_resp, basic_resp, std_resp, close_resp]);
|
||||
mock.queue_response(frame);
|
||||
|
||||
let mut conn = setup_connection(&mock);
|
||||
let tree = test_tree();
|
||||
let mut pipeline = Pipeline::new(&mut conn, &tree);
|
||||
|
||||
let results = pipeline
|
||||
.execute(vec![Op::Stat("info.txt".to_string())])
|
||||
.await;
|
||||
|
||||
assert_eq!(results.len(), 1);
|
||||
match &results[0] {
|
||||
OpResult::Stat { path, info } => {
|
||||
assert_eq!(path, "info.txt");
|
||||
assert_eq!(info.size, 2048);
|
||||
assert!(!info.is_directory);
|
||||
assert_eq!(info.created, FileTime(132_000_000_000_000_000));
|
||||
assert_eq!(info.modified, FileTime(133_000_000_000_000_000));
|
||||
}
|
||||
other => panic!("expected Stat, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipeline_error_does_not_abort_batch() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let file_id = FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
};
|
||||
|
||||
// Op 1: ReadFile that fails at CREATE -- compound frame with cascaded errors.
|
||||
let error_body = ErrorResponse {
|
||||
error_context_count: 0,
|
||||
error_data: vec![],
|
||||
};
|
||||
|
||||
let mut h1 = Header::new_request(Command::Create);
|
||||
h1.flags.set_response();
|
||||
h1.credits = 32;
|
||||
h1.status = NtStatus::OBJECT_NAME_NOT_FOUND;
|
||||
let create_err = pack_message(&h1, &error_body);
|
||||
|
||||
let mut h2 = Header::new_request(Command::Read);
|
||||
h2.flags.set_response();
|
||||
h2.credits = 32;
|
||||
h2.status = NtStatus::OBJECT_NAME_NOT_FOUND;
|
||||
let read_err = pack_message(&h2, &error_body);
|
||||
|
||||
let mut h3 = Header::new_request(Command::Close);
|
||||
h3.flags.set_response();
|
||||
h3.credits = 32;
|
||||
h3.status = NtStatus::OBJECT_NAME_NOT_FOUND;
|
||||
let close_err = pack_message(&h3, &error_body);
|
||||
|
||||
mock.queue_response(build_compound_response_frame(&[
|
||||
create_err, read_err, close_err,
|
||||
]));
|
||||
|
||||
// Op 2: ReadFile that succeeds -- compound frame.
|
||||
mock.queue_response(build_compound_read_response(file_id, b"abc".to_vec()));
|
||||
|
||||
let mut conn = setup_connection(&mock);
|
||||
let tree = test_tree();
|
||||
let mut pipeline = Pipeline::new(&mut conn, &tree);
|
||||
|
||||
let results = pipeline
|
||||
.execute(vec![
|
||||
Op::ReadFile("missing.txt".to_string()),
|
||||
Op::ReadFile("exists.txt".to_string()),
|
||||
])
|
||||
.await;
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
match &results[0] {
|
||||
OpResult::Error { path, .. } => assert_eq!(path, "missing.txt"),
|
||||
other => panic!("expected Error, got {:?}", other),
|
||||
}
|
||||
match &results[1] {
|
||||
OpResult::FileData { path, data } => {
|
||||
assert_eq!(path, "exists.txt");
|
||||
assert_eq!(data, b"abc");
|
||||
}
|
||||
other => panic!("expected FileData, got {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
769
vendor/smb2/src/client/session.rs
vendored
Normal file
769
vendor/smb2/src/client/session.rs
vendored
Normal file
@@ -0,0 +1,769 @@
|
||||
//! Authenticated SMB2 session.
|
||||
//!
|
||||
//! The [`Session`] type manages the multi-round-trip SESSION_SETUP exchange
|
||||
//! (NTLM authentication), key derivation, and signing activation.
|
||||
|
||||
use log::{debug, info, trace, warn};
|
||||
|
||||
use crate::auth::ntlm::{NtlmAuthenticator, NtlmCredentials};
|
||||
use crate::client::connection::Connection;
|
||||
use crate::crypto::kdf::derive_session_keys;
|
||||
use crate::crypto::signing::{algorithm_for_dialect, SigningAlgorithm};
|
||||
use crate::error::Result;
|
||||
use crate::msg::session_setup::{SessionSetupRequest, SessionSetupResponse};
|
||||
use crate::pack::{ReadCursor, Unpack};
|
||||
use crate::types::flags::{Capabilities, SecurityMode};
|
||||
use crate::types::status::NtStatus;
|
||||
use crate::types::{Command, Dialect, SessionId};
|
||||
use crate::Error;
|
||||
|
||||
use crate::msg::session_setup::SessionSetupRequestFlags;
|
||||
|
||||
/// An authenticated SMB2 session with derived keys.
|
||||
#[derive(Debug)]
|
||||
pub struct Session {
|
||||
/// The session ID assigned by the server.
|
||||
pub session_id: SessionId,
|
||||
/// Key used to sign outgoing messages.
|
||||
pub signing_key: Vec<u8>,
|
||||
/// Key used to encrypt outgoing messages (SMB 3.x).
|
||||
pub encryption_key: Option<Vec<u8>>,
|
||||
/// Key used to decrypt incoming messages (SMB 3.x).
|
||||
pub decryption_key: Option<Vec<u8>>,
|
||||
/// The signing algorithm to use.
|
||||
pub signing_algorithm: SigningAlgorithm,
|
||||
/// Whether outgoing messages should be signed.
|
||||
pub should_sign: bool,
|
||||
/// Whether outgoing messages should be encrypted.
|
||||
pub should_encrypt: bool,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Perform the multi-round-trip SESSION_SETUP exchange.
|
||||
///
|
||||
/// Steps:
|
||||
/// 1. Send NTLM NEGOTIATE_MESSAGE in SESSION_SETUP.
|
||||
/// 2. Receive STATUS_MORE_PROCESSING_REQUIRED with CHALLENGE_MESSAGE.
|
||||
/// 3. Update preauth hash with request+response.
|
||||
/// 4. Send NTLM AUTHENTICATE_MESSAGE in SESSION_SETUP.
|
||||
/// 5. Receive STATUS_SUCCESS with session flags.
|
||||
/// 6. Update preauth hash with request+response.
|
||||
/// 7. Derive signing/encryption keys.
|
||||
/// 8. Activate signing on the connection.
|
||||
pub async fn setup(
|
||||
conn: &mut Connection,
|
||||
username: &str,
|
||||
password: &str,
|
||||
domain: &str,
|
||||
) -> Result<Session> {
|
||||
let params = conn
|
||||
.params()
|
||||
.ok_or_else(|| Error::invalid_data("negotiate must complete before session setup"))?
|
||||
.clone();
|
||||
|
||||
let mut auth = NtlmAuthenticator::new(NtlmCredentials {
|
||||
username: username.to_string(),
|
||||
password: password.to_string(),
|
||||
domain: domain.to_string(),
|
||||
});
|
||||
|
||||
// Clone the preauth hasher for this session (spec: per-session hash).
|
||||
let mut session_hasher = conn.preauth_hasher().clone();
|
||||
|
||||
// ── Round 1: NEGOTIATE_MESSAGE ──
|
||||
debug!("session: round 1, sending NTLM negotiate");
|
||||
|
||||
let type1_bytes = auth.negotiate();
|
||||
|
||||
let req1 = SessionSetupRequest {
|
||||
flags: SessionSetupRequestFlags(0),
|
||||
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
|
||||
capabilities: Capabilities::default(),
|
||||
channel: 0,
|
||||
previous_session_id: 0,
|
||||
security_buffer: type1_bytes,
|
||||
};
|
||||
|
||||
let (frame1, req1_raw) = conn
|
||||
.execute_capturing_request(Command::SessionSetup, &req1, None)
|
||||
.await?;
|
||||
|
||||
// Update session preauth hash with request.
|
||||
session_hasher.update(&req1_raw);
|
||||
|
||||
let resp1_header = frame1.header;
|
||||
let resp1_body = frame1.body;
|
||||
|
||||
// Update session preauth hash with response.
|
||||
session_hasher.update(&frame1.raw);
|
||||
|
||||
if resp1_header.command != Command::SessionSetup {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"expected SessionSetup response, got {:?}",
|
||||
resp1_header.command
|
||||
)));
|
||||
}
|
||||
|
||||
if !resp1_header.status.is_more_processing_required() {
|
||||
if resp1_header.status.is_error() {
|
||||
return Err(Error::Protocol {
|
||||
status: resp1_header.status,
|
||||
command: Command::SessionSetup,
|
||||
});
|
||||
}
|
||||
return Err(Error::invalid_data(
|
||||
"expected STATUS_MORE_PROCESSING_REQUIRED, got success on first round",
|
||||
));
|
||||
}
|
||||
|
||||
// The server assigned a session ID -- use it for subsequent requests.
|
||||
debug!(
|
||||
"session: round 1 complete, status={:?}, session_id={}",
|
||||
resp1_header.status, resp1_header.session_id
|
||||
);
|
||||
conn.set_session_id(resp1_header.session_id);
|
||||
|
||||
// Parse the challenge response.
|
||||
let mut cursor1 = ReadCursor::new(&resp1_body);
|
||||
let setup_resp1 = SessionSetupResponse::unpack(&mut cursor1)?;
|
||||
|
||||
// ── Round 2: AUTHENTICATE_MESSAGE ──
|
||||
debug!("session: round 2, sending NTLM authenticate");
|
||||
|
||||
let type3_bytes = auth.authenticate(&setup_resp1.security_buffer)?;
|
||||
|
||||
let req2 = SessionSetupRequest {
|
||||
flags: SessionSetupRequestFlags(0),
|
||||
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
|
||||
capabilities: Capabilities::default(),
|
||||
channel: 0,
|
||||
previous_session_id: 0,
|
||||
security_buffer: type3_bytes,
|
||||
};
|
||||
|
||||
let (frame2, req2_raw) = conn
|
||||
.execute_capturing_request(Command::SessionSetup, &req2, None)
|
||||
.await?;
|
||||
|
||||
// Update session preauth hash with the request ONLY.
|
||||
// The final SESSION_SETUP response (STATUS_SUCCESS) is NOT
|
||||
// included in the preauth hash (spec section 3.2.5.3.1).
|
||||
// Only STATUS_MORE_PROCESSING_REQUIRED responses are hashed.
|
||||
session_hasher.update(&req2_raw);
|
||||
|
||||
let resp2_header = frame2.header;
|
||||
let resp2_body = frame2.body;
|
||||
|
||||
// Do NOT hash the success response -- the preauth hash used for
|
||||
// key derivation contains only messages up to (and including)
|
||||
// the final authenticate request, not the success response.
|
||||
|
||||
if resp2_header.command != Command::SessionSetup {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"expected SessionSetup response, got {:?}",
|
||||
resp2_header.command
|
||||
)));
|
||||
}
|
||||
|
||||
if resp2_header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: resp2_header.status,
|
||||
command: Command::SessionSetup,
|
||||
});
|
||||
}
|
||||
|
||||
// Parse the final response.
|
||||
let mut cursor2 = ReadCursor::new(&resp2_body);
|
||||
let setup_resp2 = SessionSetupResponse::unpack(&mut cursor2)?;
|
||||
|
||||
let session_id = resp2_header.session_id;
|
||||
conn.set_session_id(session_id);
|
||||
|
||||
// Get the session key from NTLM.
|
||||
let session_key = auth
|
||||
.session_key()
|
||||
.ok_or_else(|| Error::Auth {
|
||||
message: "NTLM did not produce a session key".to_string(),
|
||||
})?
|
||||
.to_vec();
|
||||
|
||||
// Determine signing algorithm.
|
||||
let gmac_negotiated = params.gmac_negotiated;
|
||||
let signing_algorithm = algorithm_for_dialect(params.dialect, gmac_negotiated);
|
||||
debug!(
|
||||
"session: signing_algo={:?}, dialect={}",
|
||||
signing_algorithm, params.dialect
|
||||
);
|
||||
|
||||
// Derive keys for SMB 3.x, or use session key directly for SMB 2.x.
|
||||
trace!(
|
||||
"session: deriving keys, session_key_len={}",
|
||||
session_key.len()
|
||||
);
|
||||
let (signing_key, encryption_key, decryption_key) = match params.dialect {
|
||||
Dialect::Smb3_0 | Dialect::Smb3_0_2 => {
|
||||
let keys = derive_session_keys(&session_key, params.dialect, None, 128);
|
||||
(
|
||||
keys.signing_key,
|
||||
Some(keys.encryption_key),
|
||||
Some(keys.decryption_key),
|
||||
)
|
||||
}
|
||||
Dialect::Smb3_1_1 => {
|
||||
// Key length: 256 bits only for AES-256 ciphers. GMAC signing
|
||||
// uses AES-128-GCM internally, so it needs 128-bit (16-byte) keys.
|
||||
let key_len_bits = match params.cipher {
|
||||
Some(crate::crypto::encryption::Cipher::Aes256Ccm)
|
||||
| Some(crate::crypto::encryption::Cipher::Aes256Gcm) => 256,
|
||||
_ => 128,
|
||||
};
|
||||
let keys = derive_session_keys(
|
||||
&session_key,
|
||||
Dialect::Smb3_1_1,
|
||||
Some(session_hasher.value()),
|
||||
key_len_bits,
|
||||
);
|
||||
(
|
||||
keys.signing_key,
|
||||
Some(keys.encryption_key),
|
||||
Some(keys.decryption_key),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
// SMB 2.x: use session key directly for signing.
|
||||
(session_key.clone(), None, None)
|
||||
}
|
||||
};
|
||||
|
||||
// Determine if we should sign.
|
||||
let should_sign = params.signing_required
|
||||
|| !setup_resp2.session_flags.is_guest() && !setup_resp2.session_flags.is_null();
|
||||
|
||||
let should_encrypt = setup_resp2.session_flags.encrypt_data();
|
||||
|
||||
// Activate signing on the connection.
|
||||
if should_sign {
|
||||
conn.activate_signing(signing_key.clone(), signing_algorithm);
|
||||
}
|
||||
|
||||
// Activate encryption on the connection if the session requires it.
|
||||
// The cipher comes from negotiate contexts (SMB 3.1.1). If the server
|
||||
// didn't send one (for example, Samba with `smb encrypt = required` sometimes
|
||||
// omits the encryption context), fall back to AES-128-CCM which is
|
||||
// universally supported by all SMB 3.x servers.
|
||||
if should_encrypt {
|
||||
let cipher = params
|
||||
.cipher
|
||||
.unwrap_or(crate::crypto::encryption::Cipher::Aes128Ccm);
|
||||
if let (Some(ref enc_key), Some(ref dec_key)) = (&encryption_key, &decryption_key) {
|
||||
conn.activate_encryption(enc_key.clone(), dec_key.clone(), cipher);
|
||||
} else {
|
||||
warn!(
|
||||
"session: encryption requested but missing keys, \
|
||||
enc_key={}, dec_key={}",
|
||||
encryption_key.is_some(),
|
||||
decryption_key.is_some(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"session: established, session_id={}, sign={}, encrypt={}",
|
||||
session_id, should_sign, should_encrypt
|
||||
);
|
||||
|
||||
Ok(Session {
|
||||
session_id,
|
||||
signing_key,
|
||||
encryption_key,
|
||||
decryption_key,
|
||||
signing_algorithm,
|
||||
should_sign,
|
||||
should_encrypt,
|
||||
})
|
||||
}
|
||||
|
||||
/// Perform Kerberos-based SESSION_SETUP.
|
||||
///
|
||||
/// Authenticates against the KDC first (AS + TGS), then sends the
|
||||
/// SPNEGO-wrapped AP-REQ in SESSION_SETUP. Handles both single-round
|
||||
/// (STATUS_SUCCESS) and mutual-auth (STATUS_MORE_PROCESSING_REQUIRED)
|
||||
/// flows.
|
||||
///
|
||||
/// The session key comes from the Kerberos TGS exchange, not from the
|
||||
/// SMB server response.
|
||||
/// Perform Kerberos-based SESSION_SETUP using a credential cache.
|
||||
///
|
||||
/// Reads cached tickets from the ccache. If a service ticket for
|
||||
/// `cifs/<server_hostname>` is cached, uses it directly (no KDC needed).
|
||||
/// If only a TGT is cached, does a TGS exchange for the service ticket.
|
||||
pub async fn setup_kerberos_from_ccache(
|
||||
conn: &mut Connection,
|
||||
credentials: &crate::auth::kerberos::KerberosCredentials,
|
||||
server_hostname: &str,
|
||||
ccache: &crate::auth::kerberos::ccache::CCache,
|
||||
) -> Result<Session> {
|
||||
let mut auth = crate::auth::kerberos::KerberosAuthenticator::new(credentials.clone());
|
||||
auth.authenticate_from_ccache(ccache, server_hostname)
|
||||
.await?;
|
||||
Self::setup_kerberos_with_auth(conn, &mut auth).await
|
||||
}
|
||||
|
||||
/// Perform Kerberos-based SESSION_SETUP.
|
||||
///
|
||||
/// Authenticates against the KDC first (AS + TGS), then sends the
|
||||
/// SPNEGO-wrapped AP-REQ in SESSION_SETUP. Handles both single-round
|
||||
/// (STATUS_SUCCESS) and mutual-auth (STATUS_MORE_PROCESSING_REQUIRED)
|
||||
/// flows.
|
||||
///
|
||||
/// The session key comes from the Kerberos TGS exchange, not from the
|
||||
/// SMB server response.
|
||||
pub async fn setup_kerberos(
|
||||
conn: &mut Connection,
|
||||
credentials: &crate::auth::kerberos::KerberosCredentials,
|
||||
server_hostname: &str,
|
||||
) -> Result<Session> {
|
||||
let mut auth = crate::auth::kerberos::KerberosAuthenticator::new(credentials.clone());
|
||||
auth.authenticate(server_hostname).await?;
|
||||
Self::setup_kerberos_with_auth(conn, &mut auth).await
|
||||
}
|
||||
|
||||
/// Shared Kerberos SESSION_SETUP logic used by both password-based
|
||||
/// and ccache-based authentication paths.
|
||||
async fn setup_kerberos_with_auth(
|
||||
conn: &mut Connection,
|
||||
auth: &mut crate::auth::kerberos::KerberosAuthenticator,
|
||||
) -> Result<Session> {
|
||||
let params = conn
|
||||
.params()
|
||||
.ok_or_else(|| Error::invalid_data("negotiate must complete before session setup"))?
|
||||
.clone();
|
||||
|
||||
let token = auth
|
||||
.token()
|
||||
.ok_or_else(|| Error::Auth {
|
||||
message: "Kerberos authentication produced no token".to_string(),
|
||||
})?
|
||||
.to_vec();
|
||||
|
||||
debug!("session: Kerberos auth complete, token_len={}", token.len());
|
||||
|
||||
// Clone the preauth hasher for this session.
|
||||
let mut session_hasher = conn.preauth_hasher().clone();
|
||||
|
||||
// Step 2: Send SPNEGO-wrapped AP-REQ in SESSION_SETUP.
|
||||
let req = SessionSetupRequest {
|
||||
flags: SessionSetupRequestFlags(0),
|
||||
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
|
||||
capabilities: Capabilities::default(),
|
||||
channel: 0,
|
||||
previous_session_id: 0,
|
||||
security_buffer: token,
|
||||
};
|
||||
|
||||
let (frame, req_raw) = conn
|
||||
.execute_capturing_request(Command::SessionSetup, &req, None)
|
||||
.await?;
|
||||
|
||||
// Hash the request (same as NTLM round 1).
|
||||
session_hasher.update(&req_raw);
|
||||
|
||||
let resp_header = frame.header;
|
||||
let resp_body = frame.body;
|
||||
let resp_raw = frame.raw;
|
||||
|
||||
if resp_header.command != Command::SessionSetup {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"expected SessionSetup response, got {:?}",
|
||||
resp_header.command
|
||||
)));
|
||||
}
|
||||
|
||||
if resp_header.status != NtStatus::SUCCESS
|
||||
&& !resp_header.status.is_more_processing_required()
|
||||
{
|
||||
return Err(Error::Protocol {
|
||||
status: resp_header.status,
|
||||
command: Command::SessionSetup,
|
||||
});
|
||||
}
|
||||
|
||||
// The server assigned a session ID.
|
||||
let session_id = resp_header.session_id;
|
||||
conn.set_session_id(session_id);
|
||||
|
||||
let mut cursor = ReadCursor::new(&resp_body);
|
||||
let setup_resp = SessionSetupResponse::unpack(&mut cursor)?;
|
||||
|
||||
if resp_header.status.is_more_processing_required() {
|
||||
debug!(
|
||||
"session: Kerberos got MORE_PROCESSING_REQUIRED, session_id={}",
|
||||
session_id
|
||||
);
|
||||
|
||||
// Hash the response per MS-SMB2 3.2.5.3.1.
|
||||
session_hasher.update(&resp_raw);
|
||||
}
|
||||
|
||||
// Process the SPNEGO response token (AP-REP or KRB-ERROR).
|
||||
// This applies to both STATUS_SUCCESS and MORE_PROCESSING_REQUIRED —
|
||||
// the server may include an AP-REP with a sub-session key in either.
|
||||
if !setup_resp.security_buffer.is_empty() {
|
||||
let spnego_resp =
|
||||
crate::auth::spnego::parse_neg_token_resp(&setup_resp.security_buffer)?;
|
||||
debug!(
|
||||
"session: SPNEGO state={:?}, has_token={}, supported_mech={:02x?}",
|
||||
spnego_resp.neg_state,
|
||||
spnego_resp.response_token.is_some(),
|
||||
spnego_resp.supported_mech.as_deref().unwrap_or(&[]),
|
||||
);
|
||||
|
||||
if let Some(ref token_bytes) = spnego_resp.response_token {
|
||||
auth.process_mutual_auth_token(token_bytes)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Get the session key AFTER processing the AP-REP (the server's
|
||||
// subkey may have overridden ours).
|
||||
//
|
||||
// Per MS-SMB2 3.2.5.3: "Session.SessionKey MUST be set to the first
|
||||
// 16 bytes of the cryptographic key queried from the GSS protocol."
|
||||
let full_key = auth.session_key().ok_or_else(|| Error::Auth {
|
||||
message: "Kerberos authentication produced no session key".to_string(),
|
||||
})?;
|
||||
let session_key = if full_key.len() > 16 {
|
||||
full_key[..16].to_vec()
|
||||
} else {
|
||||
full_key.to_vec()
|
||||
};
|
||||
|
||||
debug!(
|
||||
"session: Kerberos session_key_len={} (truncated from {})",
|
||||
session_key.len(),
|
||||
full_key.len()
|
||||
);
|
||||
|
||||
// Determine signing algorithm.
|
||||
let signing_algorithm = algorithm_for_dialect(params.dialect, params.gmac_negotiated);
|
||||
debug!(
|
||||
"session: Kerberos signing_algo={:?}, dialect={}",
|
||||
signing_algorithm, params.dialect
|
||||
);
|
||||
|
||||
// Derive keys for SMB 3.x using the Kerberos session key.
|
||||
let (signing_key, encryption_key, decryption_key) = match params.dialect {
|
||||
Dialect::Smb3_0 | Dialect::Smb3_0_2 => {
|
||||
let keys = derive_session_keys(&session_key, params.dialect, None, 128);
|
||||
(
|
||||
keys.signing_key,
|
||||
Some(keys.encryption_key),
|
||||
Some(keys.decryption_key),
|
||||
)
|
||||
}
|
||||
Dialect::Smb3_1_1 => {
|
||||
let key_len_bits = match params.cipher {
|
||||
Some(crate::crypto::encryption::Cipher::Aes256Ccm)
|
||||
| Some(crate::crypto::encryption::Cipher::Aes256Gcm) => 256,
|
||||
_ => 128,
|
||||
};
|
||||
let keys = derive_session_keys(
|
||||
&session_key,
|
||||
Dialect::Smb3_1_1,
|
||||
Some(session_hasher.value()),
|
||||
key_len_bits,
|
||||
);
|
||||
(
|
||||
keys.signing_key,
|
||||
Some(keys.encryption_key),
|
||||
Some(keys.decryption_key),
|
||||
)
|
||||
}
|
||||
_ => (session_key.clone(), None, None),
|
||||
};
|
||||
|
||||
let should_sign = params.signing_required
|
||||
|| !setup_resp.session_flags.is_guest() && !setup_resp.session_flags.is_null();
|
||||
|
||||
let should_encrypt = setup_resp.session_flags.encrypt_data();
|
||||
|
||||
if should_sign {
|
||||
conn.activate_signing(signing_key.clone(), signing_algorithm);
|
||||
}
|
||||
|
||||
if should_encrypt {
|
||||
let cipher = params
|
||||
.cipher
|
||||
.unwrap_or(crate::crypto::encryption::Cipher::Aes128Ccm);
|
||||
if let (Some(ref enc_key), Some(ref dec_key)) = (&encryption_key, &decryption_key) {
|
||||
conn.activate_encryption(enc_key.clone(), dec_key.clone(), cipher);
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"session: Kerberos established, session_id={}, sign={}, encrypt={}",
|
||||
session_id, should_sign, should_encrypt
|
||||
);
|
||||
|
||||
Ok(Session {
|
||||
session_id,
|
||||
signing_key,
|
||||
encryption_key,
|
||||
decryption_key,
|
||||
signing_algorithm,
|
||||
should_sign,
|
||||
should_encrypt,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::client::connection::{pack_message, Connection, NegotiatedParams};
|
||||
use crate::msg::header::Header;
|
||||
use crate::msg::session_setup::{SessionFlags, SessionSetupResponse};
|
||||
use crate::pack::Guid;
|
||||
use crate::transport::MockTransport;
|
||||
use crate::types::flags::Capabilities;
|
||||
use crate::types::status::NtStatus;
|
||||
use crate::types::{Command, Dialect, SessionId};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Build a session setup response with the given status and session ID.
|
||||
fn build_session_setup_response(
|
||||
status: NtStatus,
|
||||
session_id: SessionId,
|
||||
security_buffer: Vec<u8>,
|
||||
session_flags: SessionFlags,
|
||||
) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::SessionSetup);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
h.status = status;
|
||||
h.session_id = session_id;
|
||||
|
||||
let body = SessionSetupResponse {
|
||||
session_flags,
|
||||
security_buffer,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a minimal NTLM challenge message (Type 2).
|
||||
///
|
||||
/// This is a stripped-down challenge that the NtlmAuthenticator can parse.
|
||||
fn build_ntlm_challenge() -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
|
||||
// Signature (8 bytes)
|
||||
buf.extend_from_slice(b"NTLMSSP\0");
|
||||
// MessageType = 2 (4 bytes)
|
||||
buf.extend_from_slice(&2u32.to_le_bytes());
|
||||
// TargetNameFields: Len=0, MaxLen=0, Offset=56
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // Len
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // MaxLen
|
||||
buf.extend_from_slice(&56u32.to_le_bytes()); // Offset
|
||||
// NegotiateFlags
|
||||
let flags: u32 = 0x0000_0001 // UNICODE
|
||||
| 0x0000_0200 // NTLM
|
||||
| 0x0008_0000 // EXTENDED_SESSIONSECURITY
|
||||
| 0x0080_0000 // TARGET_INFO
|
||||
| 0x2000_0000 // 128
|
||||
| 0x4000_0000 // KEY_EXCH
|
||||
| 0x8000_0000 // 56
|
||||
| 0x0000_0010 // SIGN
|
||||
| 0x0000_0020; // SEAL
|
||||
buf.extend_from_slice(&flags.to_le_bytes());
|
||||
// ServerChallenge (8 bytes)
|
||||
buf.extend_from_slice(&[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF]);
|
||||
// Reserved (8 bytes)
|
||||
buf.extend_from_slice(&[0u8; 8]);
|
||||
|
||||
// TargetInfoFields: Len, MaxLen, Offset (will be at offset 56 + target_name_len)
|
||||
// Build target info: just MsvAvEOL
|
||||
let target_info = build_av_eol();
|
||||
let ti_offset = 56u32; // right after the fixed header
|
||||
buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); // Len
|
||||
buf.extend_from_slice(&(target_info.len() as u16).to_le_bytes()); // MaxLen
|
||||
buf.extend_from_slice(&ti_offset.to_le_bytes()); // Offset
|
||||
|
||||
// Ensure we're at offset 56 (pad if needed).
|
||||
while buf.len() < 56 {
|
||||
buf.push(0);
|
||||
}
|
||||
|
||||
// Target info data
|
||||
buf.extend_from_slice(&target_info);
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build an AV_PAIR list with just MsvAvEOL.
|
||||
fn build_av_eol() -> Vec<u8> {
|
||||
let mut buf = Vec::new();
|
||||
// MsvAvEOL: AvId=0, AvLen=0
|
||||
buf.extend_from_slice(&0u16.to_le_bytes());
|
||||
buf.extend_from_slice(&0u16.to_le_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn session_setup_stores_session_id() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
mock.enable_auto_rewrite_msg_id();
|
||||
let session_id = SessionId(0xDEAD_BEEF);
|
||||
|
||||
// Queue the two session setup responses.
|
||||
let challenge = build_ntlm_challenge();
|
||||
mock.queue_response(build_session_setup_response(
|
||||
NtStatus::MORE_PROCESSING_REQUIRED,
|
||||
session_id,
|
||||
challenge,
|
||||
SessionFlags(0),
|
||||
));
|
||||
mock.queue_response(build_session_setup_response(
|
||||
NtStatus::SUCCESS,
|
||||
session_id,
|
||||
vec![],
|
||||
SessionFlags(0),
|
||||
));
|
||||
|
||||
let mut conn = Connection::from_transport(
|
||||
Box::new(mock.clone()),
|
||||
Box::new(mock.clone()),
|
||||
"test-server",
|
||||
);
|
||||
|
||||
// Set up negotiate params (pretend we already negotiated).
|
||||
// We need to call negotiate or set params manually.
|
||||
// Let's also queue a negotiate response first.
|
||||
// Actually, let's set params directly.
|
||||
set_test_params(&mut conn, Dialect::Smb2_0_2);
|
||||
|
||||
let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap();
|
||||
assert_eq!(session.session_id, session_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn session_setup_derives_signing_key() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
mock.enable_auto_rewrite_msg_id();
|
||||
let session_id = SessionId(0x1234);
|
||||
|
||||
let challenge = build_ntlm_challenge();
|
||||
mock.queue_response(build_session_setup_response(
|
||||
NtStatus::MORE_PROCESSING_REQUIRED,
|
||||
session_id,
|
||||
challenge,
|
||||
SessionFlags(0),
|
||||
));
|
||||
mock.queue_response(build_session_setup_response(
|
||||
NtStatus::SUCCESS,
|
||||
session_id,
|
||||
vec![],
|
||||
SessionFlags(0),
|
||||
));
|
||||
|
||||
let mut conn = Connection::from_transport(
|
||||
Box::new(mock.clone()),
|
||||
Box::new(mock.clone()),
|
||||
"test-server",
|
||||
);
|
||||
set_test_params(&mut conn, Dialect::Smb2_0_2);
|
||||
|
||||
let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap();
|
||||
assert!(!session.signing_key.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn session_setup_activates_signing() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
mock.enable_auto_rewrite_msg_id();
|
||||
let session_id = SessionId(0x5678);
|
||||
|
||||
let challenge = build_ntlm_challenge();
|
||||
mock.queue_response(build_session_setup_response(
|
||||
NtStatus::MORE_PROCESSING_REQUIRED,
|
||||
session_id,
|
||||
challenge,
|
||||
SessionFlags(0),
|
||||
));
|
||||
mock.queue_response(build_session_setup_response(
|
||||
NtStatus::SUCCESS,
|
||||
session_id,
|
||||
vec![],
|
||||
SessionFlags(0),
|
||||
));
|
||||
|
||||
let mut conn = Connection::from_transport(
|
||||
Box::new(mock.clone()),
|
||||
Box::new(mock.clone()),
|
||||
"test-server",
|
||||
);
|
||||
set_test_params(&mut conn, Dialect::Smb2_0_2);
|
||||
|
||||
let session = Session::setup(&mut conn, "user", "pass", "").await.unwrap();
|
||||
assert!(session.should_sign);
|
||||
assert_eq!(session.signing_algorithm, SigningAlgorithm::HmacSha256);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn session_setup_error_on_auth_failure() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
mock.enable_auto_rewrite_msg_id();
|
||||
let session_id = SessionId(0x9999);
|
||||
|
||||
let challenge = build_ntlm_challenge();
|
||||
mock.queue_response(build_session_setup_response(
|
||||
NtStatus::MORE_PROCESSING_REQUIRED,
|
||||
session_id,
|
||||
challenge,
|
||||
SessionFlags(0),
|
||||
));
|
||||
// Auth fails on second round.
|
||||
mock.queue_response(build_session_setup_response(
|
||||
NtStatus::LOGON_FAILURE,
|
||||
session_id,
|
||||
vec![],
|
||||
SessionFlags(0),
|
||||
));
|
||||
|
||||
let mut conn = Connection::from_transport(
|
||||
Box::new(mock.clone()),
|
||||
Box::new(mock.clone()),
|
||||
"test-server",
|
||||
);
|
||||
set_test_params(&mut conn, Dialect::Smb2_0_2);
|
||||
|
||||
let result = Session::setup(&mut conn, "user", "badpass", "").await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(
|
||||
err,
|
||||
Error::Protocol {
|
||||
status: NtStatus::LOGON_FAILURE,
|
||||
..
|
||||
}
|
||||
),
|
||||
"expected LOGON_FAILURE, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Helper: set fake negotiated params on a connection.
|
||||
fn set_test_params(conn: &mut Connection, dialect: Dialect) {
|
||||
conn.set_test_params(NegotiatedParams {
|
||||
dialect,
|
||||
max_read_size: 65536,
|
||||
max_write_size: 65536,
|
||||
max_transact_size: 65536,
|
||||
server_guid: Guid::ZERO,
|
||||
signing_required: false,
|
||||
capabilities: Capabilities::default(),
|
||||
gmac_negotiated: false,
|
||||
cipher: None,
|
||||
compression_supported: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
764
vendor/smb2/src/client/shares.rs
vendored
Normal file
764
vendor/smb2/src/client/shares.rs
vendored
Normal file
@@ -0,0 +1,764 @@
|
||||
//! Share enumeration via IPC$ + srvsvc RPC.
|
||||
//!
|
||||
//! Lists available shares on an SMB server by connecting to the IPC$ share,
|
||||
//! opening the srvsvc named pipe, and performing the NetShareEnumAll RPC
|
||||
//! exchange.
|
||||
|
||||
use log::{debug, info};
|
||||
|
||||
use crate::client::connection::Connection;
|
||||
use crate::error::Result;
|
||||
use crate::msg::close::CloseRequest;
|
||||
use crate::msg::create::{
|
||||
CreateDisposition, CreateRequest, CreateResponse, ImpersonationLevel, ShareAccess,
|
||||
};
|
||||
use crate::msg::read::{ReadRequest, ReadResponse, SMB2_CHANNEL_NONE};
|
||||
use crate::msg::tree_connect::{TreeConnectRequest, TreeConnectRequestFlags, TreeConnectResponse};
|
||||
use crate::msg::tree_disconnect::TreeDisconnectRequest;
|
||||
use crate::msg::write::{WriteRequest, WriteResponse};
|
||||
use crate::pack::{ReadCursor, Unpack};
|
||||
use crate::rpc;
|
||||
use crate::rpc::srvsvc::{self, ShareInfo};
|
||||
use crate::types::flags::FileAccessMask;
|
||||
use crate::types::status::NtStatus;
|
||||
use crate::types::{Command, FileId, OplockLevel, TreeId};
|
||||
use crate::Error;
|
||||
|
||||
/// Read buffer size for pipe reads (64 KiB is plenty for share listings).
|
||||
const PIPE_READ_BUFFER_SIZE: u32 = 65536;
|
||||
|
||||
/// List available shares on the server.
|
||||
///
|
||||
/// Connects to the IPC$ share, opens the srvsvc named pipe, performs
|
||||
/// the RPC exchange, and returns filtered disk shares.
|
||||
///
|
||||
/// This is a self-contained operation -- it opens and closes its own
|
||||
/// tree connection to IPC$.
|
||||
pub async fn list_shares(conn: &mut Connection) -> Result<Vec<ShareInfo>> {
|
||||
// 1. Tree connect to IPC$
|
||||
let tree_id = tree_connect_ipc(conn).await?;
|
||||
|
||||
// Run the pipe operations, then clean up regardless of outcome
|
||||
let result = pipe_rpc_exchange(conn, tree_id).await;
|
||||
|
||||
// 8. Tree disconnect (best-effort -- don't mask the real error)
|
||||
let _ = tree_disconnect(conn, tree_id).await;
|
||||
|
||||
let all_shares = result?;
|
||||
|
||||
// 9. Filter to disk shares
|
||||
let filtered = srvsvc::filter_disk_shares(all_shares);
|
||||
info!("shares: found {} disk shares", filtered.len());
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
/// Connect to the IPC$ share, returning the tree ID.
|
||||
async fn tree_connect_ipc(conn: &mut Connection) -> Result<TreeId> {
|
||||
let server = conn.server_name().to_string();
|
||||
let unc_path = format!(r"\\{}\IPC$", server);
|
||||
|
||||
let req = TreeConnectRequest {
|
||||
flags: TreeConnectRequestFlags::default(),
|
||||
path: unc_path,
|
||||
};
|
||||
|
||||
let frame = conn.execute(Command::TreeConnect, &req, None).await?;
|
||||
|
||||
if frame.header.command != Command::TreeConnect {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"expected TreeConnect response, got {:?}",
|
||||
frame.header.command
|
||||
)));
|
||||
}
|
||||
|
||||
if frame.header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::TreeConnect,
|
||||
});
|
||||
}
|
||||
|
||||
let mut cursor = ReadCursor::new(&frame.body);
|
||||
let _resp = TreeConnectResponse::unpack(&mut cursor)?;
|
||||
|
||||
let tree_id = frame
|
||||
.header
|
||||
.tree_id
|
||||
.ok_or_else(|| Error::invalid_data("TreeConnect response missing tree ID"))?;
|
||||
|
||||
info!("shares: connected to IPC$, tree_id={}", tree_id);
|
||||
Ok(tree_id)
|
||||
}
|
||||
|
||||
/// Open the srvsvc pipe, perform the RPC bind and request, then close.
|
||||
async fn pipe_rpc_exchange(conn: &mut Connection, tree_id: TreeId) -> Result<Vec<ShareInfo>> {
|
||||
// 2. Create \pipe\srvsvc
|
||||
let file_id = open_srvsvc_pipe(conn, tree_id).await?;
|
||||
|
||||
// Run RPC exchange, then close regardless of outcome
|
||||
let result = rpc_bind_and_request(conn, tree_id, file_id).await;
|
||||
|
||||
// 7. Close the pipe handle (best-effort)
|
||||
let _ = close_handle(conn, tree_id, file_id).await;
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Perform the RPC bind + NetShareEnumAll request over the pipe.
|
||||
async fn rpc_bind_and_request(
|
||||
conn: &mut Connection,
|
||||
tree_id: TreeId,
|
||||
file_id: FileId,
|
||||
) -> Result<Vec<ShareInfo>> {
|
||||
// 3. Write RPC BIND
|
||||
let bind_data = rpc::build_srvsvc_bind(1);
|
||||
write_pipe(conn, tree_id, file_id, &bind_data).await?;
|
||||
debug!("shares: sent RPC BIND ({} bytes)", bind_data.len());
|
||||
|
||||
// 4. Read RPC BIND_ACK
|
||||
let bind_ack_data = read_pipe_message(conn, tree_id, file_id).await?;
|
||||
rpc::parse_bind_ack(&bind_ack_data)?;
|
||||
debug!("shares: received BIND_ACK, context accepted");
|
||||
|
||||
// 5. Write RPC REQUEST (NetShareEnumAll)
|
||||
let server_name = format!(r"\\{}", conn.server_name());
|
||||
let request_data = srvsvc::build_net_share_enum_all(2, &server_name);
|
||||
write_pipe(conn, tree_id, file_id, &request_data).await?;
|
||||
debug!(
|
||||
"shares: sent NetShareEnumAll request ({} bytes)",
|
||||
request_data.len()
|
||||
);
|
||||
|
||||
// 6. Read RPC RESPONSE, reassembling DCE/RPC fragments (MS-RPCE 2.2.2.6).
|
||||
// A large NetShareEnum reply may arrive as several fragment PDUs, each its
|
||||
// own pipe message, with PFC_LAST_FRAG set only on the last.
|
||||
let mut stub = Vec::new();
|
||||
let mut fragments = 0;
|
||||
loop {
|
||||
let pdu = read_pipe_message(conn, tree_id, file_id).await?;
|
||||
let (frag_stub, is_last) = rpc::parse_response_fragment(&pdu)?;
|
||||
stub.extend_from_slice(frag_stub);
|
||||
fragments += 1;
|
||||
if is_last {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let shares = srvsvc::parse_net_share_enum_all_stub(&stub)?;
|
||||
debug!(
|
||||
"shares: received {} shares in response ({} RPC fragment(s))",
|
||||
shares.len(),
|
||||
fragments
|
||||
);
|
||||
|
||||
Ok(shares)
|
||||
}
|
||||
|
||||
/// Open the `\pipe\srvsvc` named pipe via CREATE.
|
||||
async fn open_srvsvc_pipe(conn: &mut Connection, tree_id: TreeId) -> Result<FileId> {
|
||||
let req = CreateRequest {
|
||||
requested_oplock_level: OplockLevel::None,
|
||||
impersonation_level: ImpersonationLevel::Impersonation,
|
||||
desired_access: FileAccessMask::new(
|
||||
FileAccessMask::FILE_READ_DATA | FileAccessMask::FILE_WRITE_DATA,
|
||||
),
|
||||
file_attributes: 0,
|
||||
share_access: ShareAccess(ShareAccess::FILE_SHARE_READ | ShareAccess::FILE_SHARE_WRITE),
|
||||
create_disposition: CreateDisposition::FileOpen,
|
||||
create_options: 0,
|
||||
name: r"srvsvc".to_string(),
|
||||
create_contexts: vec![],
|
||||
};
|
||||
|
||||
let frame = conn.execute(Command::Create, &req, Some(tree_id)).await?;
|
||||
|
||||
if frame.header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::Create,
|
||||
});
|
||||
}
|
||||
|
||||
let mut cursor = ReadCursor::new(&frame.body);
|
||||
let resp = CreateResponse::unpack(&mut cursor)?;
|
||||
debug!("shares: opened srvsvc pipe, file_id={:?}", resp.file_id);
|
||||
Ok(resp.file_id)
|
||||
}
|
||||
|
||||
/// Write data to the pipe.
|
||||
async fn write_pipe(
|
||||
conn: &mut Connection,
|
||||
tree_id: TreeId,
|
||||
file_id: FileId,
|
||||
data: &[u8],
|
||||
) -> Result<()> {
|
||||
// DataOffset: header (64) + fixed write body (48) = 112 = 0x70
|
||||
let req = WriteRequest {
|
||||
data_offset: 0x70,
|
||||
offset: 0,
|
||||
file_id,
|
||||
channel: 0,
|
||||
remaining_bytes: 0,
|
||||
write_channel_info_offset: 0,
|
||||
write_channel_info_length: 0,
|
||||
flags: 0,
|
||||
data: data.to_vec(),
|
||||
};
|
||||
|
||||
let frame = conn.execute(Command::Write, &req, Some(tree_id)).await?;
|
||||
|
||||
if frame.header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::Write,
|
||||
});
|
||||
}
|
||||
|
||||
let mut cursor = ReadCursor::new(&frame.body);
|
||||
let resp = WriteResponse::unpack(&mut cursor)?;
|
||||
debug!("shares: wrote {} bytes to pipe", resp.count);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read one complete pipe message, following `STATUS_BUFFER_OVERFLOW`.
|
||||
///
|
||||
/// A pipe message larger than our read buffer comes back as one or more
|
||||
/// `STATUS_BUFFER_OVERFLOW` reads carrying partial data, terminated by a
|
||||
/// `STATUS_SUCCESS` read with the remainder (MS-SMB2 3.3.5.10). We append each
|
||||
/// chunk until a `SUCCESS` read completes the message.
|
||||
async fn read_pipe_message(
|
||||
conn: &mut Connection,
|
||||
tree_id: TreeId,
|
||||
file_id: FileId,
|
||||
) -> Result<Vec<u8>> {
|
||||
let mut message = Vec::new();
|
||||
|
||||
loop {
|
||||
let req = ReadRequest {
|
||||
padding: 0x50,
|
||||
flags: 0,
|
||||
length: PIPE_READ_BUFFER_SIZE,
|
||||
offset: 0,
|
||||
file_id,
|
||||
minimum_count: 0,
|
||||
channel: SMB2_CHANNEL_NONE,
|
||||
remaining_bytes: 0,
|
||||
read_channel_info: vec![],
|
||||
};
|
||||
|
||||
let frame = conn.execute(Command::Read, &req, Some(tree_id)).await?;
|
||||
|
||||
let status = frame.header.status;
|
||||
// BUFFER_OVERFLOW is a warning meaning "partial data, read again", not a
|
||||
// failure -- accept it alongside SUCCESS.
|
||||
if !status.is_success_or_partial() {
|
||||
return Err(Error::Protocol {
|
||||
status,
|
||||
command: Command::Read,
|
||||
});
|
||||
}
|
||||
|
||||
let mut cursor = ReadCursor::new(&frame.body);
|
||||
let resp = ReadResponse::unpack(&mut cursor)?;
|
||||
let chunk_len = resp.data.len();
|
||||
message.extend_from_slice(&resp.data);
|
||||
|
||||
// SUCCESS completes the message; BUFFER_OVERFLOW means read more.
|
||||
if status != NtStatus::BUFFER_OVERFLOW {
|
||||
break;
|
||||
}
|
||||
// Guard against a server that signals overflow but sends no data, which
|
||||
// would otherwise spin forever.
|
||||
if chunk_len == 0 {
|
||||
return Err(Error::invalid_data(
|
||||
"pipe read returned BUFFER_OVERFLOW with no data",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
debug!("shares: read {} bytes from pipe", message.len());
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
/// Close a file handle.
|
||||
async fn close_handle(conn: &mut Connection, tree_id: TreeId, file_id: FileId) -> Result<()> {
|
||||
let req = CloseRequest { flags: 0, file_id };
|
||||
|
||||
let frame = conn.execute(Command::Close, &req, Some(tree_id)).await?;
|
||||
|
||||
if frame.header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::Close,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect from a tree.
|
||||
async fn tree_disconnect(conn: &mut Connection, tree_id: TreeId) -> Result<()> {
|
||||
let body = TreeDisconnectRequest;
|
||||
let frame = conn
|
||||
.execute(Command::TreeDisconnect, &body, Some(tree_id))
|
||||
.await?;
|
||||
|
||||
if frame.header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::TreeDisconnect,
|
||||
});
|
||||
}
|
||||
|
||||
info!("shares: disconnected from IPC$");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use crate::client::connection::{pack_message, NegotiatedParams};
|
||||
use crate::client::test_helpers::{
|
||||
build_close_response, build_create_response, build_tree_connect_response, setup_connection,
|
||||
};
|
||||
use crate::msg::header::Header;
|
||||
use crate::msg::read::ReadResponse as ReadResp;
|
||||
use crate::msg::tree_connect::ShareType;
|
||||
use crate::msg::tree_disconnect::TreeDisconnectResponse;
|
||||
use crate::msg::write::WriteResponse as WriteResp;
|
||||
use crate::pack::Guid;
|
||||
use crate::rpc::srvsvc::{STYPE_DISKTREE, STYPE_IPC, STYPE_SPECIAL};
|
||||
use crate::transport::MockTransport;
|
||||
use crate::types::flags::Capabilities;
|
||||
use crate::types::{Dialect, SessionId, TreeId};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn build_write_response(count: u32) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Write);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = WriteResp {
|
||||
count,
|
||||
remaining: 0,
|
||||
write_channel_info_offset: 0,
|
||||
write_channel_info_length: 0,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
fn build_read_response(data: Vec<u8>) -> Vec<u8> {
|
||||
build_read_response_with_status(data, NtStatus::SUCCESS)
|
||||
}
|
||||
|
||||
/// Build a READ response with an explicit NTSTATUS.
|
||||
///
|
||||
/// Pipe reads use `STATUS_BUFFER_OVERFLOW` to mean "this read returned a
|
||||
/// partial message; read again for the rest."
|
||||
fn build_read_response_with_status(data: Vec<u8>, status: NtStatus) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Read);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
h.status = status;
|
||||
|
||||
let body = ReadResp {
|
||||
data_offset: 0x50,
|
||||
data_remaining: 0,
|
||||
flags: 0,
|
||||
data,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
fn build_tree_disconnect_response() -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::TreeDisconnect);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
pack_message(&h, &TreeDisconnectResponse)
|
||||
}
|
||||
|
||||
/// Build a canned RPC BIND_ACK response.
|
||||
fn build_bind_ack() -> Vec<u8> {
|
||||
use crate::pack::WriteCursor;
|
||||
|
||||
let mut w = WriteCursor::with_capacity(64);
|
||||
// Common header
|
||||
w.write_u8(5); // version
|
||||
w.write_u8(0); // version minor
|
||||
w.write_u8(12); // BIND_ACK type
|
||||
w.write_u8(0x03); // flags (first + last)
|
||||
w.write_bytes(&[0x10, 0x00, 0x00, 0x00]); // data rep
|
||||
let frag_len_pos = w.position();
|
||||
w.write_u16_le(0); // frag length placeholder
|
||||
w.write_u16_le(0); // auth length
|
||||
w.write_u32_le(1); // call id
|
||||
|
||||
// BIND_ACK specific
|
||||
w.write_u16_le(4280); // max xmit frag
|
||||
w.write_u16_le(4280); // max recv frag
|
||||
w.write_u32_le(0x12345); // assoc group
|
||||
|
||||
// Secondary address (empty)
|
||||
w.write_u16_le(0);
|
||||
w.write_bytes(&[0, 0]); // padding
|
||||
|
||||
// Result list
|
||||
w.write_u8(1); // num results
|
||||
w.write_bytes(&[0, 0, 0]); // reserved
|
||||
w.write_u16_le(0); // result = accepted
|
||||
w.write_u16_le(0); // reason
|
||||
|
||||
// Transfer syntax UUID + version (20 bytes)
|
||||
use crate::pack::Pack;
|
||||
let ndr_uuid = Guid {
|
||||
data1: 0x8A885D04,
|
||||
data2: 0x1CEB,
|
||||
data3: 0x11C9,
|
||||
data4: [0x9F, 0xE8, 0x08, 0x00, 0x2B, 0x10, 0x48, 0x60],
|
||||
};
|
||||
ndr_uuid.pack(&mut w);
|
||||
w.write_u32_le(2);
|
||||
|
||||
let total_len = w.position();
|
||||
w.set_u16_le_at(frag_len_pos, total_len as u16);
|
||||
w.into_inner()
|
||||
}
|
||||
|
||||
/// Build the NDR stub for a NetShareEnumAll RESPONSE (no RPC envelope).
|
||||
fn build_share_enum_stub(shares: &[(&str, u32, &str)]) -> Vec<u8> {
|
||||
use crate::pack::WriteCursor;
|
||||
|
||||
// Build NDR stub
|
||||
let mut w = WriteCursor::with_capacity(512);
|
||||
let count = shares.len() as u32;
|
||||
|
||||
// Level = 1
|
||||
w.write_u32_le(1);
|
||||
// Union discriminant = 1
|
||||
w.write_u32_le(1);
|
||||
|
||||
if count == 0 {
|
||||
w.write_u32_le(0); // null container
|
||||
w.write_u32_le(0); // total entries
|
||||
w.write_u32_le(0); // resume handle
|
||||
w.write_u32_le(0); // return value
|
||||
} else {
|
||||
// Container pointer
|
||||
w.write_u32_le(0x0002_0000);
|
||||
// EntriesRead
|
||||
w.write_u32_le(count);
|
||||
// Array pointer
|
||||
w.write_u32_le(0x0002_0004);
|
||||
// MaxCount
|
||||
w.write_u32_le(count);
|
||||
|
||||
// Fixed entries
|
||||
for (i, &(_, share_type, _)) in shares.iter().enumerate() {
|
||||
w.write_u32_le(0x0002_0008 + (i as u32) * 2); // name ref
|
||||
w.write_u32_le(share_type);
|
||||
w.write_u32_le(0x0002_0108 + (i as u32) * 2); // comment ref
|
||||
}
|
||||
|
||||
// Deferred strings
|
||||
for &(name, _, comment) in shares {
|
||||
write_ndr_string(&mut w, name);
|
||||
write_ndr_string(&mut w, comment);
|
||||
}
|
||||
|
||||
w.write_u32_le(count); // total entries
|
||||
w.write_u32_le(0); // resume handle
|
||||
w.write_u32_le(0); // return value
|
||||
}
|
||||
|
||||
w.into_inner()
|
||||
}
|
||||
|
||||
/// Wrap NDR stub bytes in an RPC RESPONSE PDU with the given PFC flags.
|
||||
///
|
||||
/// `pfc_flags` lets a caller emit a fragment (for example, `PFC_FIRST_FRAG`
|
||||
/// alone for a non-final fragment) instead of the usual `FIRST | LAST`.
|
||||
fn wrap_rpc_response_pdu(stub_chunk: &[u8], pfc_flags: u8) -> Vec<u8> {
|
||||
use crate::pack::WriteCursor;
|
||||
|
||||
let mut w = WriteCursor::with_capacity(24 + stub_chunk.len());
|
||||
w.write_u8(5);
|
||||
w.write_u8(0);
|
||||
w.write_u8(2); // RESPONSE
|
||||
w.write_u8(pfc_flags);
|
||||
w.write_bytes(&[0x10, 0x00, 0x00, 0x00]);
|
||||
let frag_len_pos = w.position();
|
||||
w.write_u16_le(0);
|
||||
w.write_u16_le(0);
|
||||
w.write_u32_le(2); // call id
|
||||
|
||||
w.write_u32_le(stub_chunk.len() as u32); // alloc hint
|
||||
w.write_u16_le(0); // context id
|
||||
w.write_u8(0); // cancel count
|
||||
w.write_u8(0); // reserved
|
||||
|
||||
w.write_bytes(stub_chunk);
|
||||
|
||||
let total_len = w.position();
|
||||
w.set_u16_le_at(frag_len_pos, total_len as u16);
|
||||
w.into_inner()
|
||||
}
|
||||
|
||||
/// Build a canned single-fragment RPC RESPONSE with NetShareEnumAll data.
|
||||
fn build_share_enum_response(shares: &[(&str, u32, &str)]) -> Vec<u8> {
|
||||
// 0x03 = PFC_FIRST_FRAG | PFC_LAST_FRAG (a complete, single-fragment PDU).
|
||||
wrap_rpc_response_pdu(&build_share_enum_stub(shares), 0x03)
|
||||
}
|
||||
|
||||
fn write_ndr_string(w: &mut crate::pack::WriteCursor, s: &str) {
|
||||
let utf16: Vec<u16> = s.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let char_count = utf16.len() as u32;
|
||||
w.write_u32_le(char_count);
|
||||
w.write_u32_le(0);
|
||||
w.write_u32_le(char_count);
|
||||
for &code_unit in &utf16 {
|
||||
w.write_u16_le(code_unit);
|
||||
}
|
||||
w.align_to(4);
|
||||
}
|
||||
|
||||
/// Queue all the responses needed for a full list_shares flow.
|
||||
pub(crate) fn queue_share_listing_responses(
|
||||
mock: &MockTransport,
|
||||
shares: &[(&str, u32, &str)],
|
||||
) {
|
||||
let tree_id = TreeId(42);
|
||||
let file_id = FileId {
|
||||
persistent: 0xAAAA,
|
||||
volatile: 0xBBBB,
|
||||
};
|
||||
|
||||
// 1. TREE_CONNECT response
|
||||
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
|
||||
// 2. CREATE response (open srvsvc pipe)
|
||||
mock.queue_response(build_create_response(file_id, 0));
|
||||
// 3. WRITE response (RPC BIND)
|
||||
mock.queue_response(build_write_response(72));
|
||||
// 4. READ response (BIND_ACK)
|
||||
mock.queue_response(build_read_response(build_bind_ack()));
|
||||
// 5. WRITE response (NetShareEnumAll request)
|
||||
mock.queue_response(build_write_response(100));
|
||||
// 6. READ response (NetShareEnumAll response)
|
||||
mock.queue_response(build_read_response(build_share_enum_response(shares)));
|
||||
// 7. CLOSE response
|
||||
mock.queue_response(build_close_response());
|
||||
// 8. TREE_DISCONNECT response
|
||||
mock.queue_response(build_tree_disconnect_response());
|
||||
}
|
||||
|
||||
/// Like `queue_share_listing_responses`, but the server splits a single
|
||||
/// RPC RESPONSE PDU across two pipe reads: the first read returns
|
||||
/// `STATUS_BUFFER_OVERFLOW` with the leading bytes, the second returns
|
||||
/// `SUCCESS` with the rest. The client must stitch them before parsing.
|
||||
fn queue_overflow_share_listing_responses(mock: &MockTransport, shares: &[(&str, u32, &str)]) {
|
||||
let tree_id = TreeId(42);
|
||||
let file_id = FileId {
|
||||
persistent: 0xAAAA,
|
||||
volatile: 0xBBBB,
|
||||
};
|
||||
|
||||
let pdu = build_share_enum_response(shares);
|
||||
let split = pdu.len() / 2;
|
||||
let (first, rest) = pdu.split_at(split);
|
||||
|
||||
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
|
||||
mock.queue_response(build_create_response(file_id, 0));
|
||||
mock.queue_response(build_write_response(72));
|
||||
mock.queue_response(build_read_response(build_bind_ack()));
|
||||
mock.queue_response(build_write_response(100));
|
||||
// The response PDU arrives in two chunks: overflow then success.
|
||||
mock.queue_response(build_read_response_with_status(
|
||||
first.to_vec(),
|
||||
NtStatus::BUFFER_OVERFLOW,
|
||||
));
|
||||
mock.queue_response(build_read_response_with_status(
|
||||
rest.to_vec(),
|
||||
NtStatus::SUCCESS,
|
||||
));
|
||||
mock.queue_response(build_close_response());
|
||||
mock.queue_response(build_tree_disconnect_response());
|
||||
}
|
||||
|
||||
/// Like `queue_share_listing_responses`, but the RPC RESPONSE is split into
|
||||
/// two DCE/RPC fragments (each its own pipe message): the first carries
|
||||
/// `PFC_FIRST_FRAG`, the second `PFC_LAST_FRAG`. The client must reassemble
|
||||
/// the stub across fragments before parsing.
|
||||
fn queue_fragmented_share_listing_responses(
|
||||
mock: &MockTransport,
|
||||
shares: &[(&str, u32, &str)],
|
||||
) {
|
||||
let tree_id = TreeId(42);
|
||||
let file_id = FileId {
|
||||
persistent: 0xAAAA,
|
||||
volatile: 0xBBBB,
|
||||
};
|
||||
|
||||
let stub = build_share_enum_stub(shares);
|
||||
let split = stub.len() / 2;
|
||||
let (first, rest) = stub.split_at(split);
|
||||
let frag1 = wrap_rpc_response_pdu(first, 0x01); // PFC_FIRST_FRAG only
|
||||
let frag2 = wrap_rpc_response_pdu(rest, 0x02); // PFC_LAST_FRAG only
|
||||
|
||||
mock.queue_response(build_tree_connect_response(tree_id, ShareType::Pipe));
|
||||
mock.queue_response(build_create_response(file_id, 0));
|
||||
mock.queue_response(build_write_response(72));
|
||||
mock.queue_response(build_read_response(build_bind_ack()));
|
||||
mock.queue_response(build_write_response(100));
|
||||
mock.queue_response(build_read_response(frag1));
|
||||
mock.queue_response(build_read_response(frag2));
|
||||
mock.queue_response(build_close_response());
|
||||
mock.queue_response(build_tree_disconnect_response());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_shares_reassembles_buffer_overflow_reads() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let mut conn = setup_connection(&mock);
|
||||
|
||||
queue_overflow_share_listing_responses(
|
||||
&mock,
|
||||
&[
|
||||
("Documents", STYPE_DISKTREE, "Shared docs"),
|
||||
("Photos", STYPE_DISKTREE, "Family photos"),
|
||||
],
|
||||
);
|
||||
|
||||
let shares = list_shares(&mut conn).await.unwrap();
|
||||
|
||||
assert_eq!(shares.len(), 2);
|
||||
assert_eq!(shares[0].name, "Documents");
|
||||
assert_eq!(shares[1].name, "Photos");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_shares_reassembles_rpc_fragments() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let mut conn = setup_connection(&mock);
|
||||
|
||||
queue_fragmented_share_listing_responses(
|
||||
&mock,
|
||||
&[
|
||||
("Documents", STYPE_DISKTREE, "Shared docs"),
|
||||
("Photos", STYPE_DISKTREE, "Family photos"),
|
||||
],
|
||||
);
|
||||
|
||||
let shares = list_shares(&mut conn).await.unwrap();
|
||||
|
||||
assert_eq!(shares.len(), 2);
|
||||
assert_eq!(shares[0].name, "Documents");
|
||||
assert_eq!(shares[1].name, "Photos");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_shares_returns_disk_shares() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let mut conn = setup_connection(&mock);
|
||||
|
||||
queue_share_listing_responses(
|
||||
&mock,
|
||||
&[
|
||||
("Documents", STYPE_DISKTREE, "Shared docs"),
|
||||
("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"),
|
||||
("C$", STYPE_DISKTREE | STYPE_SPECIAL, "Default share"),
|
||||
("Photos", STYPE_DISKTREE, "Family photos"),
|
||||
],
|
||||
);
|
||||
|
||||
let shares = list_shares(&mut conn).await.unwrap();
|
||||
|
||||
// Only disk shares without $ suffix and without STYPE_SPECIAL
|
||||
assert_eq!(shares.len(), 2);
|
||||
assert_eq!(shares[0].name, "Documents");
|
||||
assert_eq!(shares[0].comment, "Shared docs");
|
||||
assert_eq!(shares[1].name, "Photos");
|
||||
assert_eq!(shares[1].comment, "Family photos");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_shares_sends_correct_number_of_messages() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let mut conn = setup_connection(&mock);
|
||||
|
||||
queue_share_listing_responses(&mock, &[("TestShare", STYPE_DISKTREE, "A test share")]);
|
||||
|
||||
let _shares = list_shares(&mut conn).await.unwrap();
|
||||
|
||||
// Should have sent 8 messages:
|
||||
// TREE_CONNECT, CREATE, WRITE(bind), READ(bind_ack),
|
||||
// WRITE(request), READ(response), CLOSE, TREE_DISCONNECT
|
||||
assert_eq!(mock.sent_count(), 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_shares_empty_server() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let mut conn = setup_connection(&mock);
|
||||
|
||||
queue_share_listing_responses(&mock, &[]);
|
||||
|
||||
let shares = list_shares(&mut conn).await.unwrap();
|
||||
assert!(shares.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_shares_filters_non_disk_shares() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
let mut conn = setup_connection(&mock);
|
||||
|
||||
// All non-disk or special shares
|
||||
queue_share_listing_responses(
|
||||
&mock,
|
||||
&[
|
||||
("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"),
|
||||
("ADMIN$", STYPE_DISKTREE | STYPE_SPECIAL, "Remote Admin"),
|
||||
],
|
||||
);
|
||||
|
||||
let shares = list_shares(&mut conn).await.unwrap();
|
||||
assert!(shares.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_shares_uses_correct_server_name() {
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
mock.enable_auto_rewrite_msg_id();
|
||||
let mut conn =
|
||||
Connection::from_transport(Box::new(mock.clone()), Box::new(mock.clone()), "my-nas");
|
||||
conn.set_test_params(NegotiatedParams {
|
||||
dialect: Dialect::Smb2_0_2,
|
||||
max_read_size: 65536,
|
||||
max_write_size: 65536,
|
||||
max_transact_size: 65536,
|
||||
server_guid: Guid::ZERO,
|
||||
signing_required: false,
|
||||
capabilities: Capabilities::default(),
|
||||
gmac_negotiated: false,
|
||||
cipher: None,
|
||||
compression_supported: false,
|
||||
});
|
||||
conn.set_session_id(SessionId(0x1234));
|
||||
|
||||
queue_share_listing_responses(&mock, &[("share1", STYPE_DISKTREE, "")]);
|
||||
|
||||
let shares = list_shares(&mut conn).await.unwrap();
|
||||
assert_eq!(shares.len(), 1);
|
||||
|
||||
// Verify the TREE_CONNECT request contains \\my-nas\IPC$
|
||||
let sent = mock.sent_messages();
|
||||
let tree_connect_bytes = &sent[0];
|
||||
// The UNC path is UTF-16LE in the request body
|
||||
let unc_utf8 = String::from_utf8_lossy(tree_connect_bytes);
|
||||
// Verify the server name appears somewhere in the raw bytes
|
||||
assert!(
|
||||
tree_connect_bytes.windows(2).any(|w| w == b"m\0"), // 'm' in UTF-16LE from "my-nas"
|
||||
"TREE_CONNECT should reference the server name"
|
||||
);
|
||||
drop(unc_utf8);
|
||||
}
|
||||
}
|
||||
1499
vendor/smb2/src/client/stream.rs
vendored
Normal file
1499
vendor/smb2/src/client/stream.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
182
vendor/smb2/src/client/test_helpers.rs
vendored
Normal file
182
vendor/smb2/src/client/test_helpers.rs
vendored
Normal file
@@ -0,0 +1,182 @@
|
||||
//! Shared test helper functions for `client` module tests.
|
||||
//!
|
||||
//! These build mock SMB2 responses used across pipeline, shares, and tree tests.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::client::connection::{pack_message, Connection, NegotiatedParams};
|
||||
use crate::msg::close::CloseResponse;
|
||||
use crate::msg::create::{CreateAction, CreateResponse};
|
||||
use crate::msg::header::Header;
|
||||
use crate::msg::tree_connect::{ShareType, TreeConnectResponse};
|
||||
use crate::pack::{FileTime, Guid};
|
||||
use crate::transport::MockTransport;
|
||||
use crate::types::flags::{Capabilities, ShareCapabilities, ShareFlags};
|
||||
use crate::types::{Command, Dialect, FileId, OplockLevel, SessionId, TreeId};
|
||||
|
||||
/// Create a mock-backed connection with standard negotiated params.
|
||||
///
|
||||
/// Enables the mock's auto-msg_id-rewrite so canned `build_*_response`
|
||||
/// helpers (which hardcode `MessageId(0)` and don't know the caller's
|
||||
/// allocated msg_ids) still route through the Phase 3 receiver task: on
|
||||
/// each `receive()` the mock patches sub-frame msg_ids to match the next
|
||||
/// pending sent msg_id in FIFO order. Replaces the pre-Phase-3
|
||||
/// `set_orphan_filter_enabled(false)` path.
|
||||
pub(crate) fn setup_connection(mock: &Arc<MockTransport>) -> Connection {
|
||||
mock.enable_auto_rewrite_msg_id();
|
||||
let mut conn = Connection::from_transport(
|
||||
Box::new(mock.clone()),
|
||||
Box::new(mock.clone()),
|
||||
"test-server",
|
||||
);
|
||||
conn.set_test_params(NegotiatedParams {
|
||||
dialect: Dialect::Smb2_0_2,
|
||||
max_read_size: 65536,
|
||||
max_write_size: 65536,
|
||||
max_transact_size: 65536,
|
||||
server_guid: Guid::ZERO,
|
||||
signing_required: false,
|
||||
capabilities: Capabilities::default(),
|
||||
gmac_negotiated: false,
|
||||
cipher: None,
|
||||
compression_supported: false,
|
||||
});
|
||||
conn.set_session_id(SessionId(0x1234));
|
||||
conn
|
||||
}
|
||||
|
||||
/// Build a CREATE response with the given file ID and end-of-file size.
|
||||
pub(crate) fn build_create_response(file_id: FileId, end_of_file: u64) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Create);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = CreateResponse {
|
||||
oplock_level: OplockLevel::None,
|
||||
flags: 0,
|
||||
create_action: CreateAction::FileOpened,
|
||||
creation_time: FileTime::ZERO,
|
||||
last_access_time: FileTime::ZERO,
|
||||
last_write_time: FileTime::ZERO,
|
||||
change_time: FileTime::ZERO,
|
||||
allocation_size: 0,
|
||||
end_of_file,
|
||||
file_attributes: 0,
|
||||
file_id,
|
||||
create_contexts: vec![],
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a CREATE response with a non-success status (for error tests).
|
||||
pub(crate) fn build_create_error_response(status: crate::types::status::NtStatus) -> Vec<u8> {
|
||||
use crate::msg::header::ErrorResponse;
|
||||
let mut h = Header::new_request(Command::Create);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
h.status = status;
|
||||
|
||||
let body = ErrorResponse {
|
||||
error_context_count: 0,
|
||||
error_data: vec![],
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a CLOSE response with zeroed fields.
|
||||
pub(crate) fn build_close_response() -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Close);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = CloseResponse {
|
||||
flags: 0,
|
||||
creation_time: FileTime::ZERO,
|
||||
last_access_time: FileTime::ZERO,
|
||||
last_write_time: FileTime::ZERO,
|
||||
change_time: FileTime::ZERO,
|
||||
allocation_size: 0,
|
||||
end_of_file: 0,
|
||||
file_attributes: 0,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a WRITE response with the given byte count.
|
||||
pub(crate) fn build_write_response(count: u32) -> Vec<u8> {
|
||||
use crate::msg::write::WriteResponse;
|
||||
let mut h = Header::new_request(Command::Write);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = WriteResponse {
|
||||
count,
|
||||
remaining: 0,
|
||||
write_channel_info_offset: 0,
|
||||
write_channel_info_length: 0,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a WRITE response with a non-success status (for error tests).
|
||||
pub(crate) fn build_write_error_response(status: crate::types::status::NtStatus) -> Vec<u8> {
|
||||
use crate::msg::header::ErrorResponse;
|
||||
let mut h = Header::new_request(Command::Write);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
h.status = status;
|
||||
|
||||
let body = ErrorResponse {
|
||||
error_context_count: 0,
|
||||
error_data: vec![],
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a CLOSE response with a non-success status (for error tests).
|
||||
pub(crate) fn build_close_error_response(status: crate::types::status::NtStatus) -> Vec<u8> {
|
||||
use crate::msg::header::ErrorResponse;
|
||||
let mut h = Header::new_request(Command::Close);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
h.status = status;
|
||||
|
||||
let body = ErrorResponse {
|
||||
error_context_count: 0,
|
||||
error_data: vec![],
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a FLUSH response.
|
||||
pub(crate) fn build_flush_response() -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::Flush);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
|
||||
let body = crate::msg::flush::FlushResponse;
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
/// Build a TREE_CONNECT response with the given tree ID and share type.
|
||||
pub(crate) fn build_tree_connect_response(tree_id: TreeId, share_type: ShareType) -> Vec<u8> {
|
||||
let mut h = Header::new_request(Command::TreeConnect);
|
||||
h.flags.set_response();
|
||||
h.credits = 32;
|
||||
h.tree_id = Some(tree_id);
|
||||
|
||||
let body = TreeConnectResponse {
|
||||
share_type,
|
||||
share_flags: ShareFlags::default(),
|
||||
capabilities: ShareCapabilities::default(),
|
||||
maximal_access: 0x001F_01FF,
|
||||
};
|
||||
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
6691
vendor/smb2/src/client/tree.rs
vendored
Normal file
6691
vendor/smb2/src/client/tree.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
780
vendor/smb2/src/client/watcher.rs
vendored
Normal file
780
vendor/smb2/src/client/watcher.rs
vendored
Normal file
@@ -0,0 +1,780 @@
|
||||
//! Directory change notification via SMB2 CHANGE_NOTIFY.
|
||||
//!
|
||||
//! The [`Watcher`] type registers for change notifications on a directory
|
||||
//! and returns [`FileNotifyEvent`] entries describing changes as they happen.
|
||||
//! The server holds the request until a change occurs, making this a long-poll
|
||||
//! operation.
|
||||
|
||||
use log::debug;
|
||||
|
||||
use crate::client::connection::{await_frame, Connection, Frame};
|
||||
use crate::client::tree::Tree;
|
||||
use crate::error::Result;
|
||||
use crate::msg::change_notify::{
|
||||
ChangeNotifyRequest, ChangeNotifyResponse, FILE_NOTIFY_CHANGE_ATTRIBUTES,
|
||||
FILE_NOTIFY_CHANGE_CREATION, FILE_NOTIFY_CHANGE_DIR_NAME, FILE_NOTIFY_CHANGE_FILE_NAME,
|
||||
FILE_NOTIFY_CHANGE_LAST_WRITE, FILE_NOTIFY_CHANGE_SIZE, SMB2_WATCH_TREE,
|
||||
};
|
||||
use crate::pack::{ReadCursor, Unpack};
|
||||
use crate::types::status::NtStatus;
|
||||
use crate::types::{Command, FileId};
|
||||
use crate::Error;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
/// Default completion filter: watch for most common changes.
|
||||
const DEFAULT_COMPLETION_FILTER: u32 = FILE_NOTIFY_CHANGE_FILE_NAME
|
||||
| FILE_NOTIFY_CHANGE_DIR_NAME
|
||||
| FILE_NOTIFY_CHANGE_ATTRIBUTES
|
||||
| FILE_NOTIFY_CHANGE_SIZE
|
||||
| FILE_NOTIFY_CHANGE_LAST_WRITE
|
||||
| FILE_NOTIFY_CHANGE_CREATION;
|
||||
|
||||
/// Default output buffer length for CHANGE_NOTIFY responses (64 KB).
|
||||
const OUTPUT_BUFFER_LENGTH: u32 = 65536;
|
||||
|
||||
/// The type of change that occurred on a file or directory.
|
||||
///
|
||||
/// These correspond to the `Action` field in `FILE_NOTIFY_INFORMATION`
|
||||
/// (MS-FSCC section 2.4.42).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum FileNotifyAction {
|
||||
/// A file was added to the directory.
|
||||
Added,
|
||||
/// A file was removed from the directory.
|
||||
Removed,
|
||||
/// A file was modified.
|
||||
Modified,
|
||||
/// A file was renamed (this is the old name).
|
||||
RenamedOldName,
|
||||
/// A file was renamed (this is the new name).
|
||||
RenamedNewName,
|
||||
}
|
||||
|
||||
impl FileNotifyAction {
|
||||
/// Parse an action value from the wire format.
|
||||
fn from_u32(value: u32) -> Result<Self> {
|
||||
match value {
|
||||
0x0000_0001 => Ok(FileNotifyAction::Added),
|
||||
0x0000_0002 => Ok(FileNotifyAction::Removed),
|
||||
0x0000_0003 => Ok(FileNotifyAction::Modified),
|
||||
0x0000_0004 => Ok(FileNotifyAction::RenamedOldName),
|
||||
0x0000_0005 => Ok(FileNotifyAction::RenamedNewName),
|
||||
other => Err(Error::invalid_data(format!(
|
||||
"unknown FILE_NOTIFY_INFORMATION action: {other:#010X}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for FileNotifyAction {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
FileNotifyAction::Added => write!(f, "added"),
|
||||
FileNotifyAction::Removed => write!(f, "removed"),
|
||||
FileNotifyAction::Modified => write!(f, "modified"),
|
||||
FileNotifyAction::RenamedOldName => write!(f, "renamed (old name)"),
|
||||
FileNotifyAction::RenamedNewName => write!(f, "renamed (new name)"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single file change notification.
|
||||
///
|
||||
/// Represents one `FILE_NOTIFY_INFORMATION` entry from the server.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileNotifyEvent {
|
||||
/// What kind of change occurred.
|
||||
pub action: FileNotifyAction,
|
||||
/// The relative file name within the watched directory.
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
/// Watches a directory for changes via SMB2 CHANGE_NOTIFY.
|
||||
///
|
||||
/// The server holds the request until something changes, then responds
|
||||
/// with one or more [`FileNotifyEvent`] entries. Each call to
|
||||
/// [`next_events()`](Watcher::next_events) blocks until the server
|
||||
/// reports a change.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # async fn example(client: &mut smb2::SmbClient, share: &smb2::Tree) -> Result<(), smb2::Error> {
|
||||
/// let mut watcher = client.watch(&share, "_test/", true).await?;
|
||||
/// loop {
|
||||
/// let events = watcher.next_events().await?;
|
||||
/// for event in &events {
|
||||
/// println!("{}: {}", event.filename, event.action);
|
||||
/// }
|
||||
/// }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// **Pipelining**: `Watcher` keeps one CHANGE_NOTIFY request pre-issued on
|
||||
/// the wire at all times after the first call to
|
||||
/// [`next_events`](Self::next_events). The wire never sits idle between
|
||||
/// consecutive responses, so server-side events that arrive while the
|
||||
/// consumer is processing the previous batch are still delivered to an
|
||||
/// outstanding request — they don't fall in a response→re-arm gap where
|
||||
/// strict servers (older Samba, NAS firmware) drop them silently.
|
||||
///
|
||||
/// The watcher owns a cloned [`Connection`] (cheap `Arc::clone`, all
|
||||
/// clones multiplex over the same SMB session), so the caller doesn't
|
||||
/// need a second `SmbClient` to perform other operations while watching.
|
||||
pub struct Watcher {
|
||||
tree: Tree,
|
||||
conn: Connection,
|
||||
file_id: FileId,
|
||||
recursive: bool,
|
||||
/// In-flight CHANGE_NOTIFY response receiver. Populated lazily on the
|
||||
/// first `next_events()` call and re-populated before awaiting each
|
||||
/// response, so there is always exactly one outstanding request on
|
||||
/// the wire from that point on.
|
||||
pending: Option<oneshot::Receiver<Result<Frame>>>,
|
||||
}
|
||||
|
||||
impl Watcher {
|
||||
/// Create a new watcher (called by `Tree::watch`).
|
||||
pub(crate) fn new(tree: Tree, conn: Connection, file_id: FileId, recursive: bool) -> Self {
|
||||
Watcher {
|
||||
tree,
|
||||
conn,
|
||||
file_id,
|
||||
recursive,
|
||||
pending: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for the next batch of change events.
|
||||
///
|
||||
/// Dispatches a CHANGE_NOTIFY request (if one isn't already pre-issued
|
||||
/// from the previous call), then — before awaiting the response —
|
||||
/// dispatches the *next* CHANGE_NOTIFY. This keeps the wire
|
||||
/// continuously armed: from the moment the first call returns until
|
||||
/// the watcher is dropped, the server always has an outstanding
|
||||
/// request to deliver events into. Closes the response→re-arm loss
|
||||
/// window that strict servers (older Samba, NAS firmware) drop events
|
||||
/// through.
|
||||
///
|
||||
/// The server holds each request until changes occur, so this call
|
||||
/// may block for a long time.
|
||||
///
|
||||
/// Returns `Ok(events)` with one or more events when changes are detected.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Error::Protocol` with `STATUS_NOTIFY_ENUM_DIR` if too many
|
||||
/// changes occurred and the server could not fit them in the response
|
||||
/// buffer. In this case, the caller should re-scan the directory and
|
||||
/// keep watching — by the time control returns, the pipelined-next
|
||||
/// request is already on the wire so no events arriving during the
|
||||
/// re-scan get lost.
|
||||
pub async fn next_events(&mut self) -> Result<Vec<FileNotifyEvent>> {
|
||||
// Cold start: no request has been issued yet. Dispatch the first.
|
||||
if self.pending.is_none() {
|
||||
let rx = self.dispatch_next().await?;
|
||||
self.pending = Some(rx);
|
||||
}
|
||||
// Take the currently in-flight receiver, then immediately
|
||||
// pre-issue the next request before awaiting this one. The
|
||||
// `dispatch` call below `.await`s only the transport.send(), so
|
||||
// when it returns, the next CHANGE_NOTIFY is on the wire and the
|
||||
// server has somewhere to put new events even while we process
|
||||
// the response for the previous one.
|
||||
let in_flight = self.pending.take().expect("pending populated above");
|
||||
let next_rx = self.dispatch_next().await?;
|
||||
self.pending = Some(next_rx);
|
||||
|
||||
let frame = await_frame(in_flight).await?;
|
||||
|
||||
if frame.header.status == NtStatus::NOTIFY_ENUM_DIR {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::ChangeNotify,
|
||||
});
|
||||
}
|
||||
|
||||
if frame.header.status != NtStatus::SUCCESS {
|
||||
return Err(Error::Protocol {
|
||||
status: frame.header.status,
|
||||
command: Command::ChangeNotify,
|
||||
});
|
||||
}
|
||||
|
||||
let mut cursor = ReadCursor::new(&frame.body);
|
||||
let resp = ChangeNotifyResponse::unpack(&mut cursor)?;
|
||||
|
||||
let events = parse_notify_information(&resp.output_data)?;
|
||||
debug!("watcher: received {} change event(s)", events.len());
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
/// Build a CHANGE_NOTIFY request and dispatch it on the cloned
|
||||
/// connection, returning the response receiver. `Connection::dispatch`
|
||||
/// awaits only up to and including `transport.send()`, so when this
|
||||
/// returns the request is on the wire — the caller can rely on the
|
||||
/// "outstanding on the wire" invariant for whatever comes next.
|
||||
async fn dispatch_next(&self) -> Result<oneshot::Receiver<Result<Frame>>> {
|
||||
let flags = if self.recursive { SMB2_WATCH_TREE } else { 0 };
|
||||
let req = ChangeNotifyRequest {
|
||||
flags,
|
||||
output_buffer_length: OUTPUT_BUFFER_LENGTH,
|
||||
file_id: self.file_id,
|
||||
completion_filter: DEFAULT_COMPLETION_FILTER,
|
||||
};
|
||||
self.conn
|
||||
.dispatch(Command::ChangeNotify, &req, Some(self.tree.tree_id))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Close the directory handle.
|
||||
///
|
||||
/// Drops the pre-issued CHANGE_NOTIFY receiver (the `Connection`
|
||||
/// receiver task discards the late response silently when it
|
||||
/// arrives — same contract `Connection::execute` already documents),
|
||||
/// then issues a CLOSE on the file handle. If `close` is not called
|
||||
/// explicitly, the `Drop` impl drops the pre-issued receiver but the
|
||||
/// server-side handle leaks until the session ends (there is no
|
||||
/// async drop in Rust).
|
||||
pub async fn close(mut self) -> Result<()> {
|
||||
self.pending.take();
|
||||
self.tree.close_handle(&mut self.conn, self.file_id).await
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Watcher {
|
||||
fn drop(&mut self) {
|
||||
// The pre-issued response receiver (if any) drops with the
|
||||
// Watcher. The `Connection` receiver task discards the late
|
||||
// frame silently when it arrives, matching the contract on
|
||||
// `Connection::execute`. The directory handle itself leaks
|
||||
// server-side until the session ends — the docstring on `close`
|
||||
// already warns about this.
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a chain of FILE_NOTIFY_INFORMATION entries from the response buffer.
|
||||
///
|
||||
/// Each entry has:
|
||||
/// - `NextEntryOffset` (u32): offset to next entry, 0 for last
|
||||
/// - `Action` (u32): the change type
|
||||
/// - `FileNameLength` (u32): length of filename in bytes (UTF-16LE)
|
||||
/// - `FileName` (variable): UTF-16LE, NOT null-terminated
|
||||
///
|
||||
/// Entries are 4-byte aligned.
|
||||
fn parse_notify_information(data: &[u8]) -> Result<Vec<FileNotifyEvent>> {
|
||||
let mut events = Vec::new();
|
||||
let mut offset = 0usize;
|
||||
|
||||
if data.is_empty() {
|
||||
return Ok(events);
|
||||
}
|
||||
|
||||
loop {
|
||||
// Need at least 12 bytes for the fixed fields.
|
||||
if offset + 12 > data.len() {
|
||||
return Err(Error::invalid_data(
|
||||
"FILE_NOTIFY_INFORMATION truncated: not enough bytes for fixed fields",
|
||||
));
|
||||
}
|
||||
|
||||
let next_entry_offset =
|
||||
u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
|
||||
let action_raw = u32::from_le_bytes(data[offset + 4..offset + 8].try_into().unwrap());
|
||||
let filename_length =
|
||||
u32::from_le_bytes(data[offset + 8..offset + 12].try_into().unwrap()) as usize;
|
||||
|
||||
// Filename starts right after the 12-byte fixed header.
|
||||
let filename_start = offset + 12;
|
||||
let filename_end = filename_start + filename_length;
|
||||
|
||||
if filename_end > data.len() {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"FILE_NOTIFY_INFORMATION filename extends beyond buffer: \
|
||||
need {} bytes at offset {}, buffer is {} bytes",
|
||||
filename_length,
|
||||
filename_start,
|
||||
data.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let filename_bytes = &data[filename_start..filename_end];
|
||||
|
||||
// Decode UTF-16LE filename.
|
||||
let filename = decode_utf16le(filename_bytes)?;
|
||||
let action = FileNotifyAction::from_u32(action_raw)?;
|
||||
|
||||
events.push(FileNotifyEvent { action, filename });
|
||||
|
||||
if next_entry_offset == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
offset += next_entry_offset;
|
||||
}
|
||||
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
/// Decode a UTF-16LE byte slice into a Rust String.
|
||||
fn decode_utf16le(bytes: &[u8]) -> Result<String> {
|
||||
if bytes.len() % 2 != 0 {
|
||||
return Err(Error::invalid_data("UTF-16LE filename has odd byte count"));
|
||||
}
|
||||
|
||||
let u16s: Vec<u16> = bytes
|
||||
.chunks_exact(2)
|
||||
.map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
|
||||
.collect();
|
||||
|
||||
String::from_utf16(&u16s)
|
||||
.map_err(|e| Error::invalid_data(format!("invalid UTF-16LE filename: {e}")))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_single_notify_entry() {
|
||||
// Build a single FILE_NOTIFY_INFORMATION entry.
|
||||
let filename = "test.txt";
|
||||
let utf16: Vec<u16> = filename.encode_utf16().collect();
|
||||
let filename_bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
|
||||
let filename_len = filename_bytes.len() as u32;
|
||||
|
||||
let mut data = Vec::new();
|
||||
// NextEntryOffset = 0 (last entry)
|
||||
data.extend_from_slice(&0u32.to_le_bytes());
|
||||
// Action = FILE_ACTION_ADDED (0x00000001)
|
||||
data.extend_from_slice(&1u32.to_le_bytes());
|
||||
// FileNameLength
|
||||
data.extend_from_slice(&filename_len.to_le_bytes());
|
||||
// FileName (UTF-16LE)
|
||||
data.extend_from_slice(&filename_bytes);
|
||||
|
||||
let events = parse_notify_information(&data).unwrap();
|
||||
assert_eq!(events.len(), 1);
|
||||
assert_eq!(events[0].action, FileNotifyAction::Added);
|
||||
assert_eq!(events[0].filename, "test.txt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_multiple_notify_entries() {
|
||||
// Build two FILE_NOTIFY_INFORMATION entries.
|
||||
let build_entry = |name: &str, action: u32, is_last: bool| -> Vec<u8> {
|
||||
let utf16: Vec<u16> = name.encode_utf16().collect();
|
||||
let filename_bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
|
||||
let filename_len = filename_bytes.len() as u32;
|
||||
|
||||
let mut entry = Vec::new();
|
||||
// Fixed header is 12 bytes + filename. Align to 4 bytes.
|
||||
let entry_size = 12 + filename_bytes.len();
|
||||
let aligned_size = (entry_size + 3) & !3;
|
||||
|
||||
let next_offset = if is_last { 0u32 } else { aligned_size as u32 };
|
||||
entry.extend_from_slice(&next_offset.to_le_bytes());
|
||||
entry.extend_from_slice(&action.to_le_bytes());
|
||||
entry.extend_from_slice(&filename_len.to_le_bytes());
|
||||
entry.extend_from_slice(&filename_bytes);
|
||||
|
||||
// Pad to 4-byte alignment.
|
||||
while entry.len() < aligned_size {
|
||||
entry.push(0);
|
||||
}
|
||||
|
||||
entry
|
||||
};
|
||||
|
||||
let mut data = Vec::new();
|
||||
data.extend_from_slice(&build_entry("added.txt", 1, false));
|
||||
data.extend_from_slice(&build_entry("removed.txt", 2, true));
|
||||
|
||||
let events = parse_notify_information(&data).unwrap();
|
||||
assert_eq!(events.len(), 2);
|
||||
assert_eq!(events[0].action, FileNotifyAction::Added);
|
||||
assert_eq!(events[0].filename, "added.txt");
|
||||
assert_eq!(events[1].action, FileNotifyAction::Removed);
|
||||
assert_eq!(events[1].filename, "removed.txt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_empty_buffer_returns_no_events() {
|
||||
let events = parse_notify_information(&[]).unwrap();
|
||||
assert!(events.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_truncated_buffer_returns_error() {
|
||||
// Only 8 bytes, need at least 12 for fixed fields.
|
||||
let data = vec![0u8; 8];
|
||||
let result = parse_notify_information(&data);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_utf16le_basic() {
|
||||
let input = "hello";
|
||||
let utf16: Vec<u16> = input.encode_utf16().collect();
|
||||
let bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
|
||||
let result = decode_utf16le(&bytes).unwrap();
|
||||
assert_eq!(result, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_utf16le_non_ascii() {
|
||||
let input = "photos/\u{00E9}t\u{00E9}";
|
||||
let utf16: Vec<u16> = input.encode_utf16().collect();
|
||||
let bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
|
||||
let result = decode_utf16le(&bytes).unwrap();
|
||||
assert_eq!(result, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_utf16le_odd_bytes_is_error() {
|
||||
let result = decode_utf16le(&[0x41, 0x00, 0x42]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_notify_action_display() {
|
||||
assert_eq!(format!("{}", FileNotifyAction::Added), "added");
|
||||
assert_eq!(format!("{}", FileNotifyAction::Removed), "removed");
|
||||
assert_eq!(format!("{}", FileNotifyAction::Modified), "modified");
|
||||
assert_eq!(
|
||||
format!("{}", FileNotifyAction::RenamedOldName),
|
||||
"renamed (old name)"
|
||||
);
|
||||
assert_eq!(
|
||||
format!("{}", FileNotifyAction::RenamedNewName),
|
||||
"renamed (new name)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_notify_action_from_u32_unknown_is_error() {
|
||||
let result = FileNotifyAction::from_u32(0x9999);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
/// Loss-window tests using a strict-server simulator.
|
||||
///
|
||||
/// These probe the architectural property the watcher contract should
|
||||
/// guarantee: every event the server observes is eventually delivered
|
||||
/// to the consumer, even when the server drops events that arrive
|
||||
/// while no `CHANGE_NOTIFY` request is outstanding (the naspi / older
|
||||
/// Samba behavior that triggered cmdr's field reproduction).
|
||||
///
|
||||
/// **TDD-red on `main`**: `LossySim` drops events when no request is
|
||||
/// outstanding; current `next_events()` issues one CHANGE_NOTIFY per
|
||||
/// call, so there's always a gap between response delivery and the
|
||||
/// next request. Events pushed during that gap are dropped, and the
|
||||
/// test fails. The pipelined-watcher fix (always keep one CHANGE_NOTIFY
|
||||
/// pre-issued on the wire) closes the gap, the simulator never drops,
|
||||
/// and the test passes.
|
||||
#[cfg(test)]
|
||||
mod loss_window_tests {
|
||||
use super::*;
|
||||
use crate::client::connection::{pack_message, Connection, NegotiatedParams};
|
||||
use crate::client::tree::Tree;
|
||||
use crate::msg::change_notify::ChangeNotifyResponse;
|
||||
use crate::msg::header::Header;
|
||||
use crate::pack::Guid;
|
||||
use crate::transport::{TransportReceive, TransportSend};
|
||||
use crate::types::flags::Capabilities;
|
||||
use crate::types::{Command, Dialect, MessageId, SessionId, TreeId};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
/// Simulates a CHANGE_NOTIFY server that DROPS events that arrive
|
||||
/// while no request is outstanding. Models naspi / older Samba
|
||||
/// firmware (the server side of cmdr's 9-files → 4-events field
|
||||
/// reproduction). Forgiving servers like Docker Samba buffer
|
||||
/// generously and won't trigger this; the simulator's job is to
|
||||
/// surface the architectural bug regardless of how forgiving any
|
||||
/// real server happens to be.
|
||||
struct LossySim {
|
||||
/// Outstanding CHANGE_NOTIFY request msg_ids (FIFO).
|
||||
outstanding: Mutex<VecDeque<u64>>,
|
||||
/// Events the server has observed but not yet delivered.
|
||||
pending_events: Mutex<Vec<(String, u32)>>,
|
||||
/// Response queue read by `receive()`.
|
||||
responses: Mutex<VecDeque<Vec<u8>>>,
|
||||
/// Count of events the server saw with no request outstanding.
|
||||
dropped: Mutex<usize>,
|
||||
send_notify: Notify,
|
||||
recv_notify: Notify,
|
||||
closed: AtomicBool,
|
||||
}
|
||||
|
||||
impl LossySim {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
outstanding: Mutex::new(VecDeque::new()),
|
||||
pending_events: Mutex::new(Vec::new()),
|
||||
responses: Mutex::new(VecDeque::new()),
|
||||
dropped: Mutex::new(0),
|
||||
send_notify: Notify::new(),
|
||||
recv_notify: Notify::new(),
|
||||
closed: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Block until at least one CHANGE_NOTIFY request is outstanding.
|
||||
async fn wait_outstanding(&self) {
|
||||
loop {
|
||||
if !self.outstanding.lock().unwrap().is_empty() {
|
||||
return;
|
||||
}
|
||||
if self.closed.load(Ordering::Acquire) {
|
||||
return;
|
||||
}
|
||||
self.send_notify.notified().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Push an event. If a CHANGE_NOTIFY request is outstanding, buffer
|
||||
/// the event for the next `deliver_pending()`. Else, drop silently
|
||||
/// and bump the dropped counter.
|
||||
fn push_event(&self, name: &str) {
|
||||
let outstanding = !self.outstanding.lock().unwrap().is_empty();
|
||||
if outstanding {
|
||||
self.pending_events
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push((name.to_string(), 1 /* FILE_ACTION_ADDED */));
|
||||
} else {
|
||||
*self.dropped.lock().unwrap() += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap all buffered events into a single CHANGE_NOTIFY response,
|
||||
/// consuming one outstanding msg_id.
|
||||
fn deliver_pending(&self) {
|
||||
let msg_id = self.outstanding.lock().unwrap().pop_front();
|
||||
let events = std::mem::take(&mut *self.pending_events.lock().unwrap());
|
||||
if let Some(id) = msg_id {
|
||||
let resp = build_response(id, &events);
|
||||
self.responses.lock().unwrap().push_back(resp);
|
||||
self.recv_notify.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
fn dropped_count(&self) -> usize {
|
||||
*self.dropped.lock().unwrap()
|
||||
}
|
||||
|
||||
fn close(&self) {
|
||||
self.closed.store(true, Ordering::Release);
|
||||
self.recv_notify.notify_waiters();
|
||||
self.send_notify.notify_waiters();
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSend for LossySim {
|
||||
async fn send(&self, data: &[u8]) -> crate::error::Result<()> {
|
||||
if let Some(msg_id) = extract_change_notify_msg_id(data) {
|
||||
self.outstanding.lock().unwrap().push_back(msg_id);
|
||||
self.send_notify.notify_waiters();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportReceive for LossySim {
|
||||
async fn receive(&self) -> crate::error::Result<Vec<u8>> {
|
||||
loop {
|
||||
if let Some(data) = self.responses.lock().unwrap().pop_front() {
|
||||
return Ok(data);
|
||||
}
|
||||
if self.closed.load(Ordering::Acquire) {
|
||||
return Err(crate::Error::Disconnected);
|
||||
}
|
||||
self.recv_notify.notified().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pull `MessageId` out of a request frame, but only for CHANGE_NOTIFY.
|
||||
/// Non-CHANGE_NOTIFY sends are ignored by the simulator (the test
|
||||
/// pre-configures the connection so no other requests should hit this
|
||||
/// transport — but if any do, we won't track them).
|
||||
fn extract_change_notify_msg_id(data: &[u8]) -> Option<u64> {
|
||||
const HEADER_MIN: usize = 64;
|
||||
if data.len() < HEADER_MIN || &data[0..4] != b"\xFESMB" {
|
||||
return None;
|
||||
}
|
||||
let cmd = u16::from_le_bytes([data[12], data[13]]);
|
||||
if cmd != Command::ChangeNotify as u16 {
|
||||
return None;
|
||||
}
|
||||
Some(u64::from_le_bytes(data[24..32].try_into().unwrap()))
|
||||
}
|
||||
|
||||
/// Pack a CHANGE_NOTIFY response carrying the given (name, action) pairs.
|
||||
fn build_response(msg_id: u64, events: &[(String, u32)]) -> Vec<u8> {
|
||||
let mut output_data = Vec::new();
|
||||
for (i, (name, action)) in events.iter().enumerate() {
|
||||
let is_last = i == events.len() - 1;
|
||||
let utf16: Vec<u16> = name.encode_utf16().collect();
|
||||
let filename_bytes: Vec<u8> = utf16.iter().flat_map(|c| c.to_le_bytes()).collect();
|
||||
let filename_len = filename_bytes.len() as u32;
|
||||
let entry_size = 12 + filename_bytes.len();
|
||||
let aligned_size = (entry_size + 3) & !3;
|
||||
let next_offset = if is_last { 0u32 } else { aligned_size as u32 };
|
||||
let start = output_data.len();
|
||||
output_data.extend_from_slice(&next_offset.to_le_bytes());
|
||||
output_data.extend_from_slice(&action.to_le_bytes());
|
||||
output_data.extend_from_slice(&filename_len.to_le_bytes());
|
||||
output_data.extend_from_slice(&filename_bytes);
|
||||
while output_data.len() - start < aligned_size {
|
||||
output_data.push(0);
|
||||
}
|
||||
}
|
||||
let mut h = Header::new_request(Command::ChangeNotify);
|
||||
h.flags.set_response();
|
||||
h.message_id = MessageId(msg_id);
|
||||
h.credits = 32;
|
||||
let body = ChangeNotifyResponse { output_data };
|
||||
pack_message(&h, &body)
|
||||
}
|
||||
|
||||
fn setup_connection(sim: &Arc<LossySim>) -> Connection {
|
||||
let mut conn =
|
||||
Connection::from_transport(Box::new(sim.clone()), Box::new(sim.clone()), "test-server");
|
||||
conn.set_test_params(NegotiatedParams {
|
||||
dialect: Dialect::Smb2_0_2,
|
||||
max_read_size: 65536,
|
||||
max_write_size: 65536,
|
||||
max_transact_size: 65536,
|
||||
server_guid: Guid::ZERO,
|
||||
signing_required: false,
|
||||
capabilities: Capabilities::default(),
|
||||
gmac_negotiated: false,
|
||||
cipher: None,
|
||||
compression_supported: false,
|
||||
});
|
||||
conn.set_session_id(SessionId(0x1234));
|
||||
conn
|
||||
}
|
||||
|
||||
fn test_tree() -> Tree {
|
||||
Tree {
|
||||
tree_id: TreeId(1),
|
||||
share_name: "test".to_string(),
|
||||
server: "test-server".to_string(),
|
||||
is_dfs: false,
|
||||
encrypt_data: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Cycle, repeated N times:
|
||||
/// 1. wait for outstanding (watcher armed)
|
||||
/// 2. push event A → buffered
|
||||
/// 3. deliver_pending → response queued, msg_id consumed
|
||||
/// 4. push GAP event → on `main`, no outstanding → DROPPED;
|
||||
/// on the pipelined-watcher fix, the next request is already
|
||||
/// issued → buffered.
|
||||
///
|
||||
/// Final flush: one more wait_outstanding + push + deliver to make
|
||||
/// sure any buffered gap events on the fix path get out.
|
||||
///
|
||||
/// On `main`: `dropped_count() > 0`, `delivered.len() < expected`.
|
||||
/// On the fix: `dropped_count() == 0`, all events delivered.
|
||||
#[tokio::test]
|
||||
async fn watcher_does_not_lose_events_between_consecutive_requests() {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
const N_CYCLES: usize = 5;
|
||||
|
||||
let sim = Arc::new(LossySim::new());
|
||||
let conn = setup_connection(&sim);
|
||||
let tree = test_tree();
|
||||
|
||||
let scenario_sim = sim.clone();
|
||||
let scenario = tokio::spawn(async move {
|
||||
let sim = scenario_sim;
|
||||
for round in 0..N_CYCLES {
|
||||
sim.wait_outstanding().await;
|
||||
sim.push_event(&format!("a_{round:02}"));
|
||||
sim.deliver_pending();
|
||||
// Inline push (no .await) — outstanding queue was just
|
||||
// emptied by deliver_pending. On `main`, no request has
|
||||
// been re-issued yet, so this lands in the "drop" branch.
|
||||
// On the fix, a pre-issued request is still outstanding,
|
||||
// so it lands in the "buffer" branch.
|
||||
sim.push_event(&format!("gap_{round:02}"));
|
||||
// Models "time passes between server-side events". Real
|
||||
// workloads have at least a syscall worth of latency
|
||||
// between events, which is enough for the watcher task
|
||||
// to wake up, process the previous response, and
|
||||
// re-dispatch. The pipelining fix only guarantees one
|
||||
// outstanding through the response-processing window,
|
||||
// not through arbitrary back-to-back synchronous
|
||||
// delivers within a single scheduler quantum.
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
// Flush: drive one more cycle to push any buffered gap events
|
||||
// out the door for the fix path.
|
||||
sim.wait_outstanding().await;
|
||||
sim.push_event("flush_marker");
|
||||
sim.deliver_pending();
|
||||
// Brief grace period for the watcher to drain the response,
|
||||
// then close so its next next_events() returns Disconnected
|
||||
// and the consumer loop exits.
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
sim.close();
|
||||
});
|
||||
|
||||
let mut watcher = Watcher::new(
|
||||
tree,
|
||||
conn,
|
||||
crate::types::FileId {
|
||||
persistent: 0x1111,
|
||||
volatile: 0x2222,
|
||||
},
|
||||
true,
|
||||
);
|
||||
let mut delivered: Vec<String> = Vec::new();
|
||||
while let Ok(events) = watcher.next_events().await {
|
||||
for e in &events {
|
||||
delivered.push(e.filename.clone());
|
||||
}
|
||||
}
|
||||
scenario.await.unwrap();
|
||||
|
||||
let dropped = sim.dropped_count();
|
||||
// `a_*` events always land in the outstanding window. `flush_marker`
|
||||
// ditto. `gap_*` events expose the bug: dropped today, delivered
|
||||
// after the fix.
|
||||
let expected_min = N_CYCLES /* a_* */ + 1 /* flush_marker */;
|
||||
let expected_max = expected_min + N_CYCLES /* gap_* */;
|
||||
|
||||
assert!(
|
||||
delivered.len() >= expected_min,
|
||||
"watcher dropped 'a_*' or 'flush_marker' events: got {:?}",
|
||||
delivered
|
||||
);
|
||||
assert_eq!(
|
||||
dropped, 0,
|
||||
"{} server-side event(s) arrived with no outstanding CHANGE_NOTIFY \
|
||||
request and were dropped. The pipelined-watcher fix should keep \
|
||||
one CHANGE_NOTIFY request continuously outstanding so no event \
|
||||
ever lands in the drop branch. Delivered to consumer: {:?}",
|
||||
dropped, delivered
|
||||
);
|
||||
assert_eq!(
|
||||
delivered.len(),
|
||||
expected_max,
|
||||
"expected every 'a_*', 'gap_*', and 'flush_marker' event delivered; \
|
||||
got {:?}",
|
||||
delivered
|
||||
);
|
||||
}
|
||||
}
|
||||
55
vendor/smb2/src/crypto/CLAUDE.md
vendored
Normal file
55
vendor/smb2/src/crypto/CLAUDE.md
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
# Crypto -- signing, encryption, key derivation, compression
|
||||
|
||||
Handles all cryptographic operations. Most users don't touch this directly -- `Session::setup` and `Connection` use it automatically.
|
||||
|
||||
## Key files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `signing.rs` | Sign/verify messages. Three algorithms: HMAC-SHA256, AES-CMAC, AES-GMAC |
|
||||
| `encryption.rs` | Encrypt/decrypt messages. Four ciphers: AES-128/256-CCM, AES-128/256-GCM |
|
||||
| `kdf.rs` | SP800-108 KDF + `PreauthHasher` (SHA-512 running hash) |
|
||||
| `compression.rs` | LZ4 compression for SMB 3.1.1 |
|
||||
|
||||
## Signing algorithms
|
||||
|
||||
| Algorithm | Dialect | Key size |
|
||||
|---|---|---|
|
||||
| HMAC-SHA256 (truncated to 16 bytes) | SMB 2.0.2, 2.1 | any |
|
||||
| AES-128-CMAC | SMB 3.0, 3.0.2, 3.1.1 (fallback) | 16 bytes |
|
||||
| AES-128-GMAC | SMB 3.1.1 (with `SMB2_SIGNING_CAPABILITIES`) | 16 bytes |
|
||||
|
||||
GMAC is AES-128-GCM with empty plaintext. The auth tag IS the signature. The 12-byte nonce encodes `MessageId` (bytes 0-7), a role bit (byte 8 bit 0: 0=client, 1=server), and a cancel flag (byte 8 bit 1).
|
||||
|
||||
## Encryption
|
||||
|
||||
Four ciphers, negotiated during NEGOTIATE:
|
||||
- AES-128-CCM (11-byte nonce) -- SMB 3.0+
|
||||
- AES-128-GCM (12-byte nonce) -- SMB 3.0+
|
||||
- AES-256-CCM (11-byte nonce) -- SMB 3.1.1
|
||||
- AES-256-GCM (12-byte nonce) -- SMB 3.1.1
|
||||
|
||||
Nonces come from a `NonceGenerator` with a monotonic u64 counter. Nonce reuse breaks GCM catastrophically -- the counter must never reset within a session.
|
||||
|
||||
AAD is the TRANSFORM_HEADER bytes 20..52 (Nonce + OriginalMessageSize + Reserved + Flags + SessionId). The auth tag goes into the Signature field at bytes 4..20.
|
||||
|
||||
## Key derivation (SP800-108)
|
||||
|
||||
`derive_session_keys` produces three keys (signing, encryption, decryption) from the NTLM session key using HMAC-SHA256 in counter mode.
|
||||
|
||||
- **SMB 3.0/3.0.2**: Fixed ASCII label/context pairs (for example, `"SMB2AESCMAC\0"` / `"SmbSign\0"`)
|
||||
- **SMB 3.1.1**: New labels (`"SMBSigningKey\0"`) with preauth hash (64-byte SHA-512) as context
|
||||
|
||||
`PreauthHasher` computes `SHA-512(prev_hash || message_bytes)` incrementally over negotiate and session-setup wire bytes. Cloned per session (spec requires per-session hash).
|
||||
|
||||
## Key decisions
|
||||
|
||||
- **Labels include `\0` terminator**: Matches smb-rs and the spec's Label field definitions. The double-null (label `\0` + separator `0x00`) is correct.
|
||||
- **GMAC uses AES-128, not AES-256**: Despite the signing algorithm name containing "256", the actual GMAC implementation uses AES-128-GCM. The "256" in the spec refers to the GMAC algorithm ID, not the key size. Signing keys are always 16 bytes.
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **GMAC nonce has a role bit**: Client signs with role=0, server with role=1. Verify uses role=1 (server). Same message+key produces different signatures for client vs server.
|
||||
- **Signing and encryption are mutually exclusive on the wire**: When encryption is active, the signature field is zeroed (AEAD provides auth). Never sign AND encrypt.
|
||||
- **Nonce counter must not be reused**: `NonceGenerator` panics on u64 overflow (unreachable in practice). Each session gets its own generator.
|
||||
- **HMAC-SHA256 for signing accepts any key length**: Unlike CMAC/GMAC which require exactly 16 bytes. HMAC pads/hashes the key internally.
|
||||
286
vendor/smb2/src/crypto/compression.rs
vendored
Normal file
286
vendor/smb2/src/crypto/compression.rs
vendored
Normal file
@@ -0,0 +1,286 @@
|
||||
//! SMB2 LZ4 compression for unchained mode (MS-SMB2 section 3.1.4.4).
|
||||
//!
|
||||
//! In unchained mode, the `CompressionTransformHeader` has `Flags = 0x0000`.
|
||||
//! The `Offset` field indicates where compressed data starts relative to the
|
||||
//! original message. Bytes before the offset are sent uncompressed (the
|
||||
//! "uncompressed prefix"), while bytes from the offset onward are
|
||||
//! LZ4-compressed.
|
||||
//!
|
||||
//! This allows the SMB2 header to remain uncompressed for routing while the
|
||||
//! payload is compressed.
|
||||
|
||||
/// Maximum decompressed size we allow (16 MB). Prevents decompression bombs.
|
||||
const MAX_DECOMPRESSED_SIZE: u32 = 16 * 1024 * 1024;
|
||||
|
||||
/// The result of compressing an SMB2 message (unchained mode).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompressedMessage {
|
||||
/// The original uncompressed size of the compressed portion.
|
||||
pub original_size: u32,
|
||||
/// Bytes before the compression offset (sent as-is).
|
||||
pub uncompressed_prefix: Vec<u8>,
|
||||
/// The LZ4-compressed data.
|
||||
pub compressed_data: Vec<u8>,
|
||||
/// The offset that was used (same as input offset).
|
||||
pub offset: u32,
|
||||
}
|
||||
|
||||
/// Compress an SMB2 message using LZ4 (unchained mode).
|
||||
///
|
||||
/// `offset` indicates where compression starts in the original message.
|
||||
/// Bytes before `offset` are kept as-is (uncompressed prefix).
|
||||
/// Bytes from `offset` onward are LZ4-compressed.
|
||||
///
|
||||
/// Returns `None` if compression doesn't reduce the size (not worth it),
|
||||
/// or if there is nothing to compress (offset >= message length).
|
||||
pub fn compress_message(message: &[u8], offset: usize) -> Option<CompressedMessage> {
|
||||
// Nothing to compress if offset is at or beyond the end.
|
||||
if offset >= message.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let prefix = &message[..offset];
|
||||
let to_compress = &message[offset..];
|
||||
|
||||
let compressed = lz4_flex::block::compress(to_compress);
|
||||
|
||||
// Only use compression if it actually reduces size.
|
||||
if compressed.len() >= to_compress.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(CompressedMessage {
|
||||
original_size: to_compress.len() as u32,
|
||||
uncompressed_prefix: prefix.to_vec(),
|
||||
compressed_data: compressed,
|
||||
offset: offset as u32,
|
||||
})
|
||||
}
|
||||
|
||||
/// Decompress an SMB2 message (unchained mode).
|
||||
///
|
||||
/// `uncompressed_prefix` is the data before the compression offset.
|
||||
/// `compressed_data` is the LZ4-compressed portion.
|
||||
/// `original_size` is the expected decompressed size of the compressed portion.
|
||||
///
|
||||
/// Returns the full reconstructed message (prefix + decompressed data).
|
||||
pub fn decompress_message(
|
||||
uncompressed_prefix: &[u8],
|
||||
compressed_data: &[u8],
|
||||
original_size: u32,
|
||||
) -> Result<Vec<u8>, crate::Error> {
|
||||
// Validate original_size to prevent decompression bombs.
|
||||
if original_size > MAX_DECOMPRESSED_SIZE {
|
||||
return Err(crate::Error::invalid_data(format!(
|
||||
"decompressed size {} exceeds maximum allowed size {}",
|
||||
original_size, MAX_DECOMPRESSED_SIZE
|
||||
)));
|
||||
}
|
||||
|
||||
let decompressed = lz4_flex::block::decompress(compressed_data, original_size as usize)
|
||||
.map_err(|e| crate::Error::invalid_data(format!("LZ4 decompression failed: {e}")))?;
|
||||
|
||||
let mut result = Vec::with_capacity(uncompressed_prefix.len() + decompressed.len());
|
||||
result.extend_from_slice(uncompressed_prefix);
|
||||
result.extend_from_slice(&decompressed);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn compress_and_decompress_roundtrip() {
|
||||
// Compressible data: repeated pattern.
|
||||
let message: Vec<u8> = b"ABCDEFGH".iter().copied().cycle().take(1024).collect();
|
||||
|
||||
let compressed = compress_message(&message, 0).expect("should compress");
|
||||
assert!(compressed.compressed_data.len() < message.len());
|
||||
assert_eq!(compressed.original_size, message.len() as u32);
|
||||
assert!(compressed.uncompressed_prefix.is_empty());
|
||||
assert_eq!(compressed.offset, 0);
|
||||
|
||||
let decompressed = decompress_message(
|
||||
&compressed.uncompressed_prefix,
|
||||
&compressed.compressed_data,
|
||||
compressed.original_size,
|
||||
)
|
||||
.expect("should decompress");
|
||||
|
||||
assert_eq!(decompressed, message);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compress_with_offset_preserves_prefix() {
|
||||
// Simulate a 64-byte SMB2 header + compressible payload.
|
||||
let mut message = vec![0xFE; 64]; // "header" bytes
|
||||
let payload: Vec<u8> = b"HelloWorld".iter().copied().cycle().take(2048).collect();
|
||||
message.extend_from_slice(&payload);
|
||||
|
||||
let compressed = compress_message(&message, 64).expect("should compress");
|
||||
assert_eq!(compressed.offset, 64);
|
||||
assert_eq!(compressed.uncompressed_prefix, &message[..64]);
|
||||
assert_eq!(compressed.original_size, payload.len() as u32);
|
||||
assert!(compressed.compressed_data.len() < payload.len());
|
||||
|
||||
let decompressed = decompress_message(
|
||||
&compressed.uncompressed_prefix,
|
||||
&compressed.compressed_data,
|
||||
compressed.original_size,
|
||||
)
|
||||
.expect("should decompress");
|
||||
|
||||
assert_eq!(decompressed, message);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compress_with_offset_zero_compresses_entire_message() {
|
||||
let message: Vec<u8> = vec![42u8; 4096];
|
||||
|
||||
let compressed = compress_message(&message, 0).expect("should compress");
|
||||
assert_eq!(compressed.offset, 0);
|
||||
assert!(compressed.uncompressed_prefix.is_empty());
|
||||
assert_eq!(compressed.original_size, 4096);
|
||||
|
||||
let decompressed = decompress_message(
|
||||
&compressed.uncompressed_prefix,
|
||||
&compressed.compressed_data,
|
||||
compressed.original_size,
|
||||
)
|
||||
.expect("should decompress");
|
||||
|
||||
assert_eq!(decompressed, message);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compress_empty_message_returns_none() {
|
||||
let message: &[u8] = &[];
|
||||
assert!(compress_message(message, 0).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compress_offset_at_end_returns_none() {
|
||||
let message = b"short";
|
||||
assert!(compress_message(message, 5).is_none());
|
||||
assert!(compress_message(message, 100).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn incompressible_data_returns_none() {
|
||||
// Random-ish bytes that LZ4 cannot compress (will likely grow).
|
||||
let mut message = Vec::with_capacity(256);
|
||||
for i in 0u16..256 {
|
||||
// Use a simple PRNG-like pattern that doesn't compress well.
|
||||
message.push(((i.wrapping_mul(137).wrapping_add(53)) & 0xFF) as u8);
|
||||
}
|
||||
|
||||
// Small incompressible data should return None.
|
||||
assert!(
|
||||
compress_message(&message, 0).is_none(),
|
||||
"incompressible data should return None"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn large_message_compresses_well() {
|
||||
// 1 MB of repeated pattern -- should compress very well.
|
||||
let message: Vec<u8> = b"SMB2 compression test data! "
|
||||
.iter()
|
||||
.copied()
|
||||
.cycle()
|
||||
.take(1024 * 1024)
|
||||
.collect();
|
||||
|
||||
let compressed = compress_message(&message, 0).expect("should compress large message");
|
||||
|
||||
// LZ4 should achieve at least 4:1 on highly repetitive data.
|
||||
let ratio = message.len() as f64 / compressed.compressed_data.len() as f64;
|
||||
assert!(
|
||||
ratio > 4.0,
|
||||
"compression ratio {ratio:.1} is too low for repetitive data"
|
||||
);
|
||||
|
||||
let decompressed = decompress_message(
|
||||
&compressed.uncompressed_prefix,
|
||||
&compressed.compressed_data,
|
||||
compressed.original_size,
|
||||
)
|
||||
.expect("should decompress");
|
||||
|
||||
assert_eq!(decompressed.len(), message.len());
|
||||
assert_eq!(decompressed, message);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompress_with_wrong_original_size_fails() {
|
||||
let message: Vec<u8> = vec![0xAA; 1024];
|
||||
let compressed = compress_message(&message, 0).expect("should compress");
|
||||
|
||||
// Use a wrong (smaller) original_size -- decompression should fail
|
||||
// because LZ4 validates the output size.
|
||||
let result = decompress_message(&[], &compressed.compressed_data, 512);
|
||||
assert!(result.is_err(), "wrong original_size should cause an error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompress_rejects_oversized_original_size() {
|
||||
// Attempt to decompress with original_size exceeding 16 MB limit.
|
||||
let bogus_compressed = vec![0u8; 10];
|
||||
let result = decompress_message(&[], &bogus_compressed, MAX_DECOMPRESSED_SIZE + 1);
|
||||
assert!(result.is_err());
|
||||
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_msg.contains("exceeds maximum"),
|
||||
"error should mention size limit, got: {err_msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompress_with_exact_max_size_is_allowed() {
|
||||
// original_size == MAX_DECOMPRESSED_SIZE should not be rejected
|
||||
// by the size check (it will fail on actual decompression since the
|
||||
// data is bogus, but that's a different error).
|
||||
let bogus_compressed = vec![0u8; 10];
|
||||
let result = decompress_message(&[], &bogus_compressed, MAX_DECOMPRESSED_SIZE);
|
||||
|
||||
// Should fail on decompression, not on size validation.
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_msg.contains("decompression failed"),
|
||||
"should fail on decompression, not size check, got: {err_msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompress_corrupt_data_fails() {
|
||||
let corrupt = vec![0xFF, 0xFE, 0xFD, 0xFC, 0xFB];
|
||||
let result = decompress_message(&[], &corrupt, 1024);
|
||||
assert!(result.is_err());
|
||||
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_msg.contains("decompression failed"),
|
||||
"error should mention decompression failure, got: {err_msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompress_preserves_prefix_in_output() {
|
||||
let prefix = b"PREFIX_DATA";
|
||||
let payload: Vec<u8> = vec![0x42; 2048];
|
||||
let compressed_payload = compress_message(&payload, 0).expect("should compress payload");
|
||||
|
||||
let result = decompress_message(
|
||||
prefix,
|
||||
&compressed_payload.compressed_data,
|
||||
compressed_payload.original_size,
|
||||
)
|
||||
.expect("should decompress");
|
||||
|
||||
assert_eq!(&result[..prefix.len()], prefix);
|
||||
assert_eq!(&result[prefix.len()..], &payload);
|
||||
}
|
||||
}
|
||||
591
vendor/smb2/src/crypto/encryption.rs
vendored
Normal file
591
vendor/smb2/src/crypto/encryption.rs
vendored
Normal file
@@ -0,0 +1,591 @@
|
||||
//! SMB2/3 message encryption and decryption.
|
||||
//!
|
||||
//! Implements AES-128-CCM, AES-128-GCM, AES-256-CCM, and AES-256-GCM
|
||||
//! as specified in MS-SMB2 sections 3.1.4.3 (encrypting) and 3.1.5.1
|
||||
//! (decrypting). Nonces are generated from a monotonically increasing
|
||||
//! per-session counter to prevent catastrophic nonce reuse in AES-GCM.
|
||||
|
||||
use aes::{Aes128, Aes256};
|
||||
use aes_gcm::aead::{array::Array, inout::InOutBuf, AeadInOut};
|
||||
use aes_gcm::KeyInit;
|
||||
use ccm::consts::{U11, U16};
|
||||
|
||||
use crate::msg::transform::{TransformHeader, SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED};
|
||||
use crate::pack::{Pack, WriteCursor};
|
||||
use crate::types::SessionId;
|
||||
use crate::Error;
|
||||
|
||||
/// Offset in the serialized TRANSFORM_HEADER where the AAD begins.
|
||||
///
|
||||
/// The AAD is "the SMB2 TRANSFORM_HEADER, excluding the ProtocolId and
|
||||
/// Signature fields" (MS-SMB2 section 3.1.4.3). ProtocolId is 4 bytes
|
||||
/// and Signature is 16 bytes, so the AAD starts at offset 20 (the Nonce
|
||||
/// field) and extends to the end of the 52-byte header.
|
||||
const AAD_OFFSET: usize = 20;
|
||||
|
||||
/// Total size of the TRANSFORM_HEADER in bytes.
|
||||
const HEADER_SIZE: usize = TransformHeader::SIZE; // 52
|
||||
|
||||
// ── CCM type aliases ─────────────────────────────────────────────────
|
||||
|
||||
/// AES-128-CCM with 16-byte tag and 11-byte nonce (SMB 3.0+).
|
||||
type Aes128Ccm = ccm::Ccm<Aes128, U16, U11>;
|
||||
|
||||
/// AES-256-CCM with 16-byte tag and 11-byte nonce (SMB 3.1.1).
|
||||
type Aes256Ccm = ccm::Ccm<Aes256, U16, U11>;
|
||||
|
||||
// ── Cipher enum ──────────────────────────────────────────────────────
|
||||
|
||||
/// Encryption cipher, determined during negotiation.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
pub enum Cipher {
|
||||
/// AES-128-CCM (SMB 3.0+) -- 11-byte nonce.
|
||||
Aes128Ccm,
|
||||
/// AES-128-GCM (SMB 3.0+) -- 12-byte nonce.
|
||||
Aes128Gcm,
|
||||
/// AES-256-CCM (SMB 3.1.1) -- 11-byte nonce.
|
||||
Aes256Ccm,
|
||||
/// AES-256-GCM (SMB 3.1.1) -- 12-byte nonce.
|
||||
Aes256Gcm,
|
||||
}
|
||||
|
||||
impl Cipher {
|
||||
/// Returns the number of nonce bytes actually used by this cipher.
|
||||
pub fn nonce_len(self) -> usize {
|
||||
match self {
|
||||
Cipher::Aes128Ccm | Cipher::Aes256Ccm => 11,
|
||||
Cipher::Aes128Gcm | Cipher::Aes256Gcm => 12,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the expected key length in bytes.
|
||||
fn key_len(self) -> usize {
|
||||
match self {
|
||||
Cipher::Aes128Ccm | Cipher::Aes128Gcm => 16,
|
||||
Cipher::Aes256Ccm | Cipher::Aes256Gcm => 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Nonce generator ──────────────────────────────────────────────────
|
||||
|
||||
/// Monotonically increasing nonce generator.
|
||||
///
|
||||
/// Each session gets its own nonce generator. The counter MUST NOT
|
||||
/// be reused -- nonce reuse breaks AES-GCM catastrophically.
|
||||
pub struct NonceGenerator {
|
||||
counter: u64,
|
||||
}
|
||||
|
||||
impl NonceGenerator {
|
||||
/// Create a new nonce generator starting at counter 0.
|
||||
pub fn new() -> Self {
|
||||
Self { counter: 0 }
|
||||
}
|
||||
|
||||
/// Generate the next nonce for the given cipher.
|
||||
///
|
||||
/// Returns the full 16-byte nonce field for the TRANSFORM_HEADER.
|
||||
/// - CCM: 8-byte LE counter in bytes 0..8, zeros in bytes 8..16
|
||||
/// (the cipher uses the first 11 bytes as the nonce).
|
||||
/// - GCM: 8-byte LE counter in bytes 0..8, zeros in bytes 8..16
|
||||
/// (the cipher uses the first 12 bytes as the nonce).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the counter overflows `u64::MAX`. In practice this
|
||||
/// can never happen (2^64 messages at line speed would take millennia).
|
||||
pub fn next(&mut self, _cipher: Cipher) -> [u8; 16] {
|
||||
let count = self.counter;
|
||||
self.counter = self.counter.checked_add(1).expect("nonce counter overflow");
|
||||
let mut nonce = [0u8; 16];
|
||||
nonce[..8].copy_from_slice(&count.to_le_bytes());
|
||||
nonce
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NonceGenerator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Encrypt ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Encrypt an SMB2 message.
|
||||
///
|
||||
/// Returns `(transform_header_bytes, encrypted_message)`. The 52-byte
|
||||
/// transform header includes the protocol ID, auth tag (in the Signature
|
||||
/// field), nonce, original message size, flags, and session ID. The
|
||||
/// encrypted message replaces the plaintext.
|
||||
pub fn encrypt_message(
|
||||
plaintext: &[u8],
|
||||
key: &[u8],
|
||||
cipher: Cipher,
|
||||
nonce: &[u8; 16],
|
||||
session_id: u64,
|
||||
) -> Result<(Vec<u8>, Vec<u8>), Error> {
|
||||
if key.len() != cipher.key_len() {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"encryption key length mismatch: expected {}, got {}",
|
||||
cipher.key_len(),
|
||||
key.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Build the TRANSFORM_HEADER with a zeroed signature (will be filled
|
||||
// with the auth tag after encryption).
|
||||
let header = TransformHeader {
|
||||
signature: [0u8; 16],
|
||||
nonce: *nonce,
|
||||
original_message_size: plaintext.len() as u32,
|
||||
flags: SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED,
|
||||
session_id: SessionId(session_id),
|
||||
};
|
||||
|
||||
let mut header_bytes = {
|
||||
let mut w = WriteCursor::new();
|
||||
header.pack(&mut w);
|
||||
w.into_inner()
|
||||
};
|
||||
|
||||
// AAD = header bytes 20..52 (Nonce + OriginalMessageSize + Reserved + Flags + SessionId)
|
||||
let aad = &header_bytes[AAD_OFFSET..HEADER_SIZE];
|
||||
|
||||
// Encrypt and get the auth tag.
|
||||
let mut buffer = plaintext.to_vec();
|
||||
let nonce_slice = &nonce[..cipher.nonce_len()];
|
||||
|
||||
let tag = encrypt_raw(cipher, key, nonce_slice, aad, &mut buffer)?;
|
||||
|
||||
// Write the 16-byte auth tag into the Signature field (bytes 4..20).
|
||||
header_bytes[4..20].copy_from_slice(&tag);
|
||||
|
||||
Ok((header_bytes, buffer))
|
||||
}
|
||||
|
||||
// ── Decrypt ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Decrypt an SMB2 message.
|
||||
///
|
||||
/// `transform_header` is the 52-byte TRANSFORM_HEADER (as received on
|
||||
/// the wire). `ciphertext` is the encrypted message data that follows
|
||||
/// the header. Returns the decrypted plaintext.
|
||||
pub fn decrypt_message(
|
||||
transform_header: &[u8],
|
||||
ciphertext: &[u8],
|
||||
key: &[u8],
|
||||
cipher: Cipher,
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
if transform_header.len() != HEADER_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"transform header must be {} bytes, got {}",
|
||||
HEADER_SIZE,
|
||||
transform_header.len()
|
||||
)));
|
||||
}
|
||||
if key.len() != cipher.key_len() {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"decryption key length mismatch: expected {}, got {}",
|
||||
cipher.key_len(),
|
||||
key.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Extract auth tag (Signature) from bytes 4..20.
|
||||
let mut tag = [0u8; 16];
|
||||
tag.copy_from_slice(&transform_header[4..20]);
|
||||
|
||||
// Extract nonce from bytes 20..36.
|
||||
let nonce = &transform_header[20..20 + cipher.nonce_len()];
|
||||
|
||||
// AAD = header bytes 20..52.
|
||||
let aad = &transform_header[AAD_OFFSET..HEADER_SIZE];
|
||||
|
||||
let mut buffer = ciphertext.to_vec();
|
||||
decrypt_raw(cipher, key, nonce, aad, &tag, &mut buffer)?;
|
||||
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
// ── Raw encrypt/decrypt helpers ──────────────────────────────────────
|
||||
|
||||
/// Copy an auth tag array into a fixed-size `[u8; 16]` array.
|
||||
fn tag_to_array<N: aes_gcm::aead::array::ArraySize>(tag: Array<u8, N>) -> [u8; 16] {
|
||||
let mut arr = [0u8; 16];
|
||||
arr.copy_from_slice(tag.as_slice());
|
||||
arr
|
||||
}
|
||||
|
||||
/// Encrypt `buffer` in place and return the 16-byte auth tag.
|
||||
fn encrypt_raw(
|
||||
cipher: Cipher,
|
||||
key: &[u8],
|
||||
nonce: &[u8],
|
||||
aad: &[u8],
|
||||
buffer: &mut [u8],
|
||||
) -> Result<[u8; 16], Error> {
|
||||
let map_err = |_| Error::invalid_data("encryption failed");
|
||||
let buf = InOutBuf::from(buffer);
|
||||
|
||||
let tag = match cipher {
|
||||
Cipher::Aes128Ccm => {
|
||||
let c = Aes128Ccm::new(key.try_into().expect("key length validated"));
|
||||
let n = nonce.try_into().expect("nonce length validated");
|
||||
c.encrypt_inout_detached(n, aad, buf)
|
||||
.map(tag_to_array)
|
||||
.map_err(map_err)?
|
||||
}
|
||||
Cipher::Aes128Gcm => {
|
||||
let c = aes_gcm::Aes128Gcm::new(key.try_into().expect("key length validated"));
|
||||
let n = nonce.try_into().expect("nonce length validated");
|
||||
c.encrypt_inout_detached(n, aad, buf)
|
||||
.map(tag_to_array)
|
||||
.map_err(map_err)?
|
||||
}
|
||||
Cipher::Aes256Ccm => {
|
||||
let c = Aes256Ccm::new(key.try_into().expect("key length validated"));
|
||||
let n = nonce.try_into().expect("nonce length validated");
|
||||
c.encrypt_inout_detached(n, aad, buf)
|
||||
.map(tag_to_array)
|
||||
.map_err(map_err)?
|
||||
}
|
||||
Cipher::Aes256Gcm => {
|
||||
let c = aes_gcm::Aes256Gcm::new(key.try_into().expect("key length validated"));
|
||||
let n = nonce.try_into().expect("nonce length validated");
|
||||
c.encrypt_inout_detached(n, aad, buf)
|
||||
.map(tag_to_array)
|
||||
.map_err(map_err)?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(tag)
|
||||
}
|
||||
|
||||
/// Decrypt `buffer` in place, verifying the 16-byte auth tag.
|
||||
fn decrypt_raw(
|
||||
cipher: Cipher,
|
||||
key: &[u8],
|
||||
nonce: &[u8],
|
||||
aad: &[u8],
|
||||
tag: &[u8; 16],
|
||||
buffer: &mut [u8],
|
||||
) -> Result<(), Error> {
|
||||
let map_err = |_| Error::invalid_data("decryption failed: authentication tag mismatch");
|
||||
let buf = InOutBuf::from(buffer);
|
||||
let t: &Array<u8, _> = tag.into();
|
||||
|
||||
match cipher {
|
||||
Cipher::Aes128Ccm => {
|
||||
let c = Aes128Ccm::new(key.try_into().expect("key length validated"));
|
||||
let n = nonce.try_into().expect("nonce length validated");
|
||||
c.decrypt_inout_detached(n, aad, buf, t).map_err(map_err)
|
||||
}
|
||||
Cipher::Aes128Gcm => {
|
||||
let c = aes_gcm::Aes128Gcm::new(key.try_into().expect("key length validated"));
|
||||
let n = nonce.try_into().expect("nonce length validated");
|
||||
c.decrypt_inout_detached(n, aad, buf, t).map_err(map_err)
|
||||
}
|
||||
Cipher::Aes256Ccm => {
|
||||
let c = Aes256Ccm::new(key.try_into().expect("key length validated"));
|
||||
let n = nonce.try_into().expect("nonce length validated");
|
||||
c.decrypt_inout_detached(n, aad, buf, t).map_err(map_err)
|
||||
}
|
||||
Cipher::Aes256Gcm => {
|
||||
let c = aes_gcm::Aes256Gcm::new(key.try_into().expect("key length validated"));
|
||||
let n = nonce.try_into().expect("nonce length validated");
|
||||
c.decrypt_inout_detached(n, aad, buf, t).map_err(map_err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::msg::transform::TRANSFORM_PROTOCOL_ID;
|
||||
|
||||
// ── Helper ────────────────────────────────────────────────────────
|
||||
|
||||
fn test_key(cipher: Cipher) -> Vec<u8> {
|
||||
vec![0x42; cipher.key_len()]
|
||||
}
|
||||
|
||||
// ── Encrypt-then-decrypt roundtrip (one per cipher) ──────────────
|
||||
|
||||
#[test]
|
||||
fn roundtrip_aes128_ccm() {
|
||||
roundtrip_cipher(Cipher::Aes128Ccm);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_aes128_gcm() {
|
||||
roundtrip_cipher(Cipher::Aes128Gcm);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_aes256_ccm() {
|
||||
roundtrip_cipher(Cipher::Aes256Ccm);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_aes256_gcm() {
|
||||
roundtrip_cipher(Cipher::Aes256Gcm);
|
||||
}
|
||||
|
||||
fn roundtrip_cipher(cipher: Cipher) {
|
||||
let key = test_key(cipher);
|
||||
let plaintext = b"Hello, SMB2 encryption roundtrip!";
|
||||
let session_id = 0xDEAD_BEEF_CAFE_FACE;
|
||||
|
||||
let mut nonce_gen = NonceGenerator::new();
|
||||
let nonce = nonce_gen.next(cipher);
|
||||
|
||||
let (header, ciphertext) =
|
||||
encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap();
|
||||
|
||||
// Ciphertext must differ from plaintext.
|
||||
assert_ne!(&ciphertext[..], &plaintext[..]);
|
||||
|
||||
let decrypted = decrypt_message(&header, &ciphertext, &key, cipher).unwrap();
|
||||
assert_eq!(decrypted, plaintext);
|
||||
}
|
||||
|
||||
// ── Nonce generator monotonically increases ──────────────────────
|
||||
|
||||
#[test]
|
||||
fn nonce_generator_monotonic() {
|
||||
let mut gen = NonceGenerator::new();
|
||||
let mut prev = [0u8; 16]; // counter 0 hasn't been generated yet
|
||||
|
||||
for i in 0u64..100 {
|
||||
let nonce = gen.next(Cipher::Aes128Gcm);
|
||||
// Extract the 8-byte LE counter from the nonce.
|
||||
let counter = u64::from_le_bytes(nonce[..8].try_into().unwrap());
|
||||
assert_eq!(counter, i, "counter should equal {i}");
|
||||
|
||||
if i > 0 {
|
||||
assert_ne!(nonce, prev, "each nonce must be unique");
|
||||
}
|
||||
prev = nonce;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Nonce format for GCM ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn nonce_format_gcm() {
|
||||
let mut gen = NonceGenerator::new();
|
||||
// Advance to counter = 7 to have a non-trivial value.
|
||||
for _ in 0..7 {
|
||||
gen.next(Cipher::Aes128Gcm);
|
||||
}
|
||||
let nonce = gen.next(Cipher::Aes128Gcm); // counter = 7
|
||||
|
||||
// First 8 bytes: LE counter (7).
|
||||
assert_eq!(
|
||||
u64::from_le_bytes(nonce[..8].try_into().unwrap()),
|
||||
7,
|
||||
"counter value"
|
||||
);
|
||||
// Bytes 8..12: zeros (padding to 12-byte GCM nonce).
|
||||
assert_eq!(nonce[8..12], [0, 0, 0, 0], "GCM nonce padding (8..12)");
|
||||
// Bytes 12..16: zeros (unused portion of the 16-byte field).
|
||||
assert_eq!(nonce[12..16], [0, 0, 0, 0], "unused nonce bytes (12..16)");
|
||||
}
|
||||
|
||||
// ── Nonce format for CCM ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn nonce_format_ccm() {
|
||||
let mut gen = NonceGenerator::new();
|
||||
// Advance to counter = 5.
|
||||
for _ in 0..5 {
|
||||
gen.next(Cipher::Aes128Ccm);
|
||||
}
|
||||
let nonce = gen.next(Cipher::Aes128Ccm); // counter = 5
|
||||
|
||||
// First 8 bytes: LE counter (5).
|
||||
assert_eq!(
|
||||
u64::from_le_bytes(nonce[..8].try_into().unwrap()),
|
||||
5,
|
||||
"counter value"
|
||||
);
|
||||
// Bytes 8..11: zeros (padding to 11-byte CCM nonce).
|
||||
assert_eq!(nonce[8..11], [0, 0, 0], "CCM nonce padding (8..11)");
|
||||
// Bytes 11..16: zeros (unused portion of the 16-byte field).
|
||||
assert_eq!(
|
||||
nonce[11..16],
|
||||
[0, 0, 0, 0, 0],
|
||||
"unused nonce bytes (11..16)"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Tampered ciphertext fails decryption ─────────────────────────
|
||||
|
||||
#[test]
|
||||
fn tampered_ciphertext_fails() {
|
||||
let cipher = Cipher::Aes128Gcm;
|
||||
let key = test_key(cipher);
|
||||
let plaintext = b"Do not tamper with me!";
|
||||
let session_id = 42;
|
||||
|
||||
let mut gen = NonceGenerator::new();
|
||||
let nonce = gen.next(cipher);
|
||||
|
||||
let (header, mut ciphertext) =
|
||||
encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap();
|
||||
|
||||
// Flip a byte in the ciphertext.
|
||||
ciphertext[0] ^= 0xFF;
|
||||
|
||||
let result = decrypt_message(&header, &ciphertext, &key, cipher);
|
||||
assert!(result.is_err(), "tampered ciphertext must fail decryption");
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("tag mismatch") || err.contains("decryption failed"),
|
||||
"error was: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Wrong key fails decryption ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wrong_key_fails() {
|
||||
let cipher = Cipher::Aes256Gcm;
|
||||
let key = test_key(cipher);
|
||||
let wrong_key = vec![0x99; cipher.key_len()];
|
||||
let plaintext = b"Secret message";
|
||||
let session_id = 100;
|
||||
|
||||
let mut gen = NonceGenerator::new();
|
||||
let nonce = gen.next(cipher);
|
||||
|
||||
let (header, ciphertext) =
|
||||
encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap();
|
||||
|
||||
let result = decrypt_message(&header, &ciphertext, &wrong_key, cipher);
|
||||
assert!(result.is_err(), "wrong key must fail decryption");
|
||||
}
|
||||
|
||||
// ── AAD includes correct TRANSFORM_HEADER bytes (offset 20-51) ──
|
||||
|
||||
#[test]
|
||||
fn aad_is_correct_header_region() {
|
||||
// Verify the AAD constants match the spec.
|
||||
assert_eq!(AAD_OFFSET, 20, "AAD starts at byte 20");
|
||||
assert_eq!(
|
||||
HEADER_SIZE - AAD_OFFSET,
|
||||
32,
|
||||
"AAD is 32 bytes (Nonce + OrigMsgSize + Reserved + Flags + SessionId)"
|
||||
);
|
||||
assert_eq!(HEADER_SIZE, 52, "TRANSFORM_HEADER is 52 bytes");
|
||||
|
||||
// Build a header and verify the AAD region contains the expected fields.
|
||||
let mut nonce = [0u8; 16];
|
||||
nonce[0] = 0xAA;
|
||||
nonce[7] = 0xBB;
|
||||
|
||||
let header = TransformHeader {
|
||||
signature: [0xFF; 16],
|
||||
nonce,
|
||||
original_message_size: 1024,
|
||||
flags: SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED,
|
||||
session_id: SessionId(0x0123_4567_89AB_CDEF),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
header.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let aad = &bytes[AAD_OFFSET..HEADER_SIZE];
|
||||
assert_eq!(aad.len(), 32);
|
||||
|
||||
// First 16 bytes of AAD should be the nonce.
|
||||
assert_eq!(aad[0], 0xAA, "nonce byte 0");
|
||||
assert_eq!(aad[7], 0xBB, "nonce byte 7");
|
||||
|
||||
// Bytes 16..20 of AAD should be OriginalMessageSize (1024 LE).
|
||||
assert_eq!(
|
||||
u32::from_le_bytes(aad[16..20].try_into().unwrap()),
|
||||
1024,
|
||||
"OriginalMessageSize"
|
||||
);
|
||||
|
||||
// Bytes 20..22 of AAD should be Reserved (0).
|
||||
assert_eq!(aad[20..22], [0, 0], "Reserved");
|
||||
|
||||
// Bytes 22..24 of AAD should be Flags (0x0001).
|
||||
assert_eq!(
|
||||
u16::from_le_bytes(aad[22..24].try_into().unwrap()),
|
||||
SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED,
|
||||
"Flags"
|
||||
);
|
||||
|
||||
// Bytes 24..32 of AAD should be SessionId.
|
||||
assert_eq!(
|
||||
u64::from_le_bytes(aad[24..32].try_into().unwrap()),
|
||||
0x0123_4567_89AB_CDEF,
|
||||
"SessionId"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Transform header has correct protocol ID ─────────────────────
|
||||
|
||||
#[test]
|
||||
fn transform_header_protocol_id() {
|
||||
let cipher = Cipher::Aes128Gcm;
|
||||
let key = test_key(cipher);
|
||||
let plaintext = b"test";
|
||||
let session_id = 1;
|
||||
|
||||
let mut gen = NonceGenerator::new();
|
||||
let nonce = gen.next(cipher);
|
||||
|
||||
let (header, _) = encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap();
|
||||
|
||||
// First 4 bytes must be 0xFD 'S' 'M' 'B'.
|
||||
assert_eq!(&header[..4], &TRANSFORM_PROTOCOL_ID);
|
||||
assert_eq!(header[0], 0xFD, "protocol ID first byte must be 0xFD");
|
||||
assert_eq!(header[1], b'S');
|
||||
assert_eq!(header[2], b'M');
|
||||
assert_eq!(header[3], b'B');
|
||||
}
|
||||
|
||||
// ── Auth tag (signature) is at bytes 4..20 ──────────────────────
|
||||
|
||||
#[test]
|
||||
fn signature_position_in_header() {
|
||||
let cipher = Cipher::Aes256Ccm;
|
||||
let key = test_key(cipher);
|
||||
let plaintext = b"Check signature position";
|
||||
let session_id = 99;
|
||||
|
||||
let mut gen = NonceGenerator::new();
|
||||
let nonce = gen.next(cipher);
|
||||
|
||||
let (header, _) = encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap();
|
||||
|
||||
// The signature (auth tag) lives at bytes 4..20.
|
||||
let signature = &header[4..20];
|
||||
|
||||
// It should NOT be all zeros (that would mean we forgot to write it).
|
||||
assert_ne!(
|
||||
signature, &[0u8; 16],
|
||||
"signature must not be all zeros after encryption"
|
||||
);
|
||||
|
||||
// Verify that using this tag allows successful decryption
|
||||
// (already covered by roundtrip tests, but this confirms the
|
||||
// position explicitly).
|
||||
let decrypted = decrypt_message(&header, &header[..0], &key, cipher);
|
||||
// This will fail because we passed empty ciphertext, but that's
|
||||
// not the point -- the roundtrip tests cover correctness.
|
||||
// Instead, let's verify the tag by a proper roundtrip.
|
||||
drop(decrypted);
|
||||
|
||||
let (header2, ct2) = encrypt_message(plaintext, &key, cipher, &nonce, session_id).unwrap();
|
||||
let result = decrypt_message(&header2, &ct2, &key, cipher).unwrap();
|
||||
assert_eq!(result, plaintext);
|
||||
}
|
||||
}
|
||||
525
vendor/smb2/src/crypto/kdf.rs
vendored
Normal file
525
vendor/smb2/src/crypto/kdf.rs
vendored
Normal file
@@ -0,0 +1,525 @@
|
||||
//! SP800-108 key derivation and preauthentication integrity hashing for SMB2/3.
|
||||
//!
|
||||
//! SMB 3.x uses NIST SP800-108 KDF in counter mode with HMAC-SHA256 as the PRF
|
||||
//! to derive signing, encryption, and decryption keys from the session key.
|
||||
//!
|
||||
//! SMB 3.1.1 additionally requires a preauthentication integrity hash (SHA-512)
|
||||
//! computed over the raw wire bytes of NEGOTIATE and SESSION_SETUP exchanges,
|
||||
//! which feeds into the KDF as the "context" parameter.
|
||||
|
||||
use crate::types::Dialect;
|
||||
use digest::{Digest, KeyInit};
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::{Sha256, Sha512};
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Derive a key using SP800-108 KDF in counter mode with HMAC-SHA256.
|
||||
///
|
||||
/// This implements the algorithm from NIST SP800-108 section 5.1 as required
|
||||
/// by MS-SMB2 section 3.1.4.2. The counter width ('r') is 32 bits, and the
|
||||
/// PRF is HMAC-SHA256.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key` - The key to derive from (the session key from authentication).
|
||||
/// * `label` - Label string (including null terminator).
|
||||
/// * `context` - Context string or preauth hash (including null terminator for
|
||||
/// string contexts).
|
||||
/// * `key_length_bits` - Desired output key length in bits (128 or 256).
|
||||
pub fn sp800_108_kdf(key: &[u8], label: &[u8], context: &[u8], key_length_bits: u32) -> Vec<u8> {
|
||||
let iterations = key_length_bits.div_ceil(256);
|
||||
let mut result = Vec::with_capacity((iterations * 32) as usize);
|
||||
|
||||
for i in 1..=iterations {
|
||||
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC-SHA256 accepts any key length");
|
||||
|
||||
// counter (32-bit big-endian)
|
||||
mac.update(&i.to_be_bytes());
|
||||
// label
|
||||
mac.update(label);
|
||||
// separator byte 0x00
|
||||
mac.update(&[0x00]);
|
||||
// context
|
||||
mac.update(context);
|
||||
// L = key length in bits (32-bit big-endian)
|
||||
mac.update(&key_length_bits.to_be_bytes());
|
||||
|
||||
result.extend_from_slice(&mac.finalize().into_bytes());
|
||||
}
|
||||
|
||||
result.truncate((key_length_bits / 8) as usize);
|
||||
result
|
||||
}
|
||||
|
||||
/// Derived session keys for signing, encryption, and decryption.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DerivedKeys {
|
||||
/// Key used to sign outgoing messages.
|
||||
pub signing_key: Vec<u8>,
|
||||
/// Key used to encrypt outgoing messages.
|
||||
pub encryption_key: Vec<u8>,
|
||||
/// Key used to decrypt incoming messages.
|
||||
pub decryption_key: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Derive session keys for the given dialect.
|
||||
///
|
||||
/// For SMB 3.0 and 3.0.2, the context is a fixed ASCII string.
|
||||
/// For SMB 3.1.1, the context is the preauthentication integrity hash value
|
||||
/// (64 bytes from SHA-512).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `dialect` is SMB 3.1.1 and `preauth_hash` is `None`.
|
||||
/// Panics if `dialect` is not in the SMB 3.x family.
|
||||
pub fn derive_session_keys(
|
||||
session_key: &[u8],
|
||||
dialect: Dialect,
|
||||
preauth_hash: Option<&[u8; 64]>,
|
||||
key_length_bits: u32,
|
||||
) -> DerivedKeys {
|
||||
assert!(
|
||||
matches!(
|
||||
dialect,
|
||||
Dialect::Smb3_0 | Dialect::Smb3_0_2 | Dialect::Smb3_1_1
|
||||
),
|
||||
"Key derivation is only applicable for the SMB 3.x dialect family"
|
||||
);
|
||||
|
||||
let (signing_label, signing_context): (&[u8], &[u8]);
|
||||
let (enc_label, enc_context): (&[u8], &[u8]);
|
||||
let (dec_label, dec_context): (&[u8], &[u8]);
|
||||
|
||||
if dialect == Dialect::Smb3_1_1 {
|
||||
let hash = preauth_hash
|
||||
.expect("SMB 3.1.1 requires a preauthentication integrity hash for key derivation");
|
||||
// SMB 3.1.1 labels include null terminator (matches smb-rs and
|
||||
// the MS-SMB2 spec's Label field definitions)
|
||||
signing_label = b"SMBSigningKey\0";
|
||||
signing_context = hash.as_slice();
|
||||
enc_label = b"SMBC2SCipherKey\0";
|
||||
enc_context = hash.as_slice();
|
||||
dec_label = b"SMBS2CCipherKey\0";
|
||||
dec_context = hash.as_slice();
|
||||
} else {
|
||||
// SMB 3.0 and 3.0.2
|
||||
signing_label = b"SMB2AESCMAC\0";
|
||||
signing_context = b"SmbSign\0";
|
||||
enc_label = b"SMB2AESCCM\0";
|
||||
enc_context = b"ServerIn \0";
|
||||
dec_label = b"SMB2AESCCM\0";
|
||||
dec_context = b"ServerOut\0";
|
||||
}
|
||||
|
||||
DerivedKeys {
|
||||
signing_key: sp800_108_kdf(session_key, signing_label, signing_context, key_length_bits),
|
||||
encryption_key: sp800_108_kdf(session_key, enc_label, enc_context, key_length_bits),
|
||||
decryption_key: sp800_108_kdf(session_key, dec_label, dec_context, key_length_bits),
|
||||
}
|
||||
}
|
||||
|
||||
/// Running hash over negotiate and session-setup exchange bytes.
|
||||
///
|
||||
/// Used as the "context" parameter to the KDF for SMB 3.1.1. The hash
|
||||
/// algorithm is SHA-512, producing a 64-byte value.
|
||||
///
|
||||
/// The hash is computed incrementally:
|
||||
/// 1. Initialize with 64 zero bytes
|
||||
/// 2. `update()` with negotiate request raw bytes
|
||||
/// 3. `update()` with negotiate response raw bytes
|
||||
/// 4. (Clone for session hash)
|
||||
/// 5. `update()` with session setup request raw bytes
|
||||
/// 6. `update()` with session setup response raw bytes
|
||||
/// 7. Repeat 5-6 for each SESSION_SETUP round-trip
|
||||
///
|
||||
/// Each `update()` computes: `hash = SHA-512(previous_hash || message_bytes)`
|
||||
pub struct PreauthHasher {
|
||||
hash: [u8; 64],
|
||||
}
|
||||
|
||||
impl PreauthHasher {
|
||||
/// Create a new hasher initialized with 64 zero bytes.
|
||||
pub fn new() -> Self {
|
||||
Self { hash: [0u8; 64] }
|
||||
}
|
||||
|
||||
/// Update the hash with a message's raw wire bytes.
|
||||
///
|
||||
/// Computes `hash = SHA-512(previous_hash || message_bytes)`.
|
||||
pub fn update(&mut self, message_bytes: &[u8]) {
|
||||
let mut hasher = Sha512::new();
|
||||
hasher.update(self.hash);
|
||||
hasher.update(message_bytes);
|
||||
self.hash.copy_from_slice(&hasher.finalize());
|
||||
}
|
||||
|
||||
/// Get the current hash value (64 bytes).
|
||||
pub fn value(&self) -> &[u8; 64] {
|
||||
&self.hash
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PreauthHasher {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for PreauthHasher {
|
||||
fn clone(&self) -> Self {
|
||||
Self { hash: self.hash }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ========================================================================
|
||||
// SP800-108 KDF tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn kdf_128_bit_output_is_16_bytes() {
|
||||
let key = [0xAA; 16];
|
||||
let result = sp800_108_kdf(&key, b"label\0", b"context\0", 128);
|
||||
assert_eq!(result.len(), 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kdf_256_bit_output_is_32_bytes() {
|
||||
let key = [0xBB; 16];
|
||||
let result = sp800_108_kdf(&key, b"label\0", b"context\0", 256);
|
||||
assert_eq!(result.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kdf_is_deterministic() {
|
||||
let key = [0x42; 16];
|
||||
let label = b"TestLabel\0";
|
||||
let context = b"TestContext\0";
|
||||
let r1 = sp800_108_kdf(&key, label, context, 128);
|
||||
let r2 = sp800_108_kdf(&key, label, context, 128);
|
||||
assert_eq!(r1, r2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kdf_different_labels_produce_different_keys() {
|
||||
let key = [0x42; 16];
|
||||
let context = b"ctx\0";
|
||||
let k1 = sp800_108_kdf(&key, b"LabelA\0", context, 128);
|
||||
let k2 = sp800_108_kdf(&key, b"LabelB\0", context, 128);
|
||||
assert_ne!(k1, k2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kdf_different_contexts_produce_different_keys() {
|
||||
let key = [0x42; 16];
|
||||
let label = b"label\0";
|
||||
let k1 = sp800_108_kdf(&key, label, b"ContextA\0", 128);
|
||||
let k2 = sp800_108_kdf(&key, label, b"ContextB\0", 128);
|
||||
assert_ne!(k1, k2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kdf_different_session_keys_produce_different_derived_keys() {
|
||||
let label = b"SMB2AESCMAC\0";
|
||||
let context = b"SmbSign\0";
|
||||
let k1 = sp800_108_kdf(&[0x11; 16], label, context, 128);
|
||||
let k2 = sp800_108_kdf(&[0x22; 16], label, context, 128);
|
||||
assert_ne!(k1, k2);
|
||||
}
|
||||
|
||||
/// Verify KDF output against a manually computed value.
|
||||
///
|
||||
/// For a single iteration (128-bit output), the KDF computes:
|
||||
/// HMAC-SHA256(key, 0x00000001 || label || 0x00 || context || 0x00000080)
|
||||
/// and takes the first 16 bytes.
|
||||
#[test]
|
||||
fn kdf_known_vector_single_iteration() {
|
||||
let key = [0x00u8; 16];
|
||||
let label = b"SMB2AESCMAC\0";
|
||||
let context = b"SmbSign\0";
|
||||
|
||||
// Manually compute the expected value.
|
||||
let mut mac = HmacSha256::new_from_slice(&key).unwrap();
|
||||
mac.update(&1u32.to_be_bytes()); // counter = 1
|
||||
mac.update(label); // label
|
||||
mac.update(&[0x00]); // separator
|
||||
mac.update(context); // context
|
||||
mac.update(&128u32.to_be_bytes()); // L = 128
|
||||
let full = mac.finalize().into_bytes();
|
||||
let expected = &full[..16];
|
||||
|
||||
let result = sp800_108_kdf(&key, label, context, 128);
|
||||
assert_eq!(result.as_slice(), expected);
|
||||
}
|
||||
|
||||
/// Verify that 256-bit KDF uses two iterations and concatenates correctly.
|
||||
#[test]
|
||||
fn kdf_known_vector_two_iterations() {
|
||||
let key = [0xFFu8; 16];
|
||||
let label = b"TestLabel\0";
|
||||
let context = b"TestCtx\0";
|
||||
|
||||
// Compute iteration 1
|
||||
let mut mac1 = HmacSha256::new_from_slice(&key).unwrap();
|
||||
mac1.update(&1u32.to_be_bytes());
|
||||
mac1.update(label);
|
||||
mac1.update(&[0x00]);
|
||||
mac1.update(context);
|
||||
mac1.update(&256u32.to_be_bytes());
|
||||
let block1 = mac1.finalize().into_bytes();
|
||||
|
||||
// 256 bits = 32 bytes = exactly one HMAC-SHA256 block, so only one
|
||||
// iteration is needed. But let's verify with the formula:
|
||||
// ceil(256 / 256) = 1 iteration. So 256-bit also needs just one.
|
||||
let result = sp800_108_kdf(&key, label, context, 256);
|
||||
assert_eq!(result.len(), 32);
|
||||
assert_eq!(result.as_slice(), block1.as_slice());
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// derive_session_keys tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn derive_keys_smb3_0_uses_legacy_labels() {
|
||||
let session_key = [0x42; 16];
|
||||
let keys = derive_session_keys(&session_key, Dialect::Smb3_0, None, 128);
|
||||
|
||||
// Verify each key matches what we'd get calling KDF directly with the
|
||||
// SMB 3.0 label/context pairs.
|
||||
assert_eq!(
|
||||
keys.signing_key,
|
||||
sp800_108_kdf(&session_key, b"SMB2AESCMAC\0", b"SmbSign\0", 128)
|
||||
);
|
||||
assert_eq!(
|
||||
keys.encryption_key,
|
||||
sp800_108_kdf(&session_key, b"SMB2AESCCM\0", b"ServerIn \0", 128)
|
||||
);
|
||||
assert_eq!(
|
||||
keys.decryption_key,
|
||||
sp800_108_kdf(&session_key, b"SMB2AESCCM\0", b"ServerOut\0", 128)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_keys_smb3_0_2_uses_legacy_labels() {
|
||||
let session_key = [0x42; 16];
|
||||
let keys = derive_session_keys(&session_key, Dialect::Smb3_0_2, None, 128);
|
||||
|
||||
assert_eq!(
|
||||
keys.signing_key,
|
||||
sp800_108_kdf(&session_key, b"SMB2AESCMAC\0", b"SmbSign\0", 128)
|
||||
);
|
||||
assert_eq!(
|
||||
keys.encryption_key,
|
||||
sp800_108_kdf(&session_key, b"SMB2AESCCM\0", b"ServerIn \0", 128)
|
||||
);
|
||||
assert_eq!(
|
||||
keys.decryption_key,
|
||||
sp800_108_kdf(&session_key, b"SMB2AESCCM\0", b"ServerOut\0", 128)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_keys_smb3_1_1_uses_new_labels_with_preauth_hash() {
|
||||
let session_key = [0x42; 16];
|
||||
let preauth_hash = [0xAB; 64];
|
||||
let keys = derive_session_keys(&session_key, Dialect::Smb3_1_1, Some(&preauth_hash), 128);
|
||||
|
||||
assert_eq!(
|
||||
keys.signing_key,
|
||||
sp800_108_kdf(&session_key, b"SMBSigningKey\0", &preauth_hash, 128)
|
||||
);
|
||||
assert_eq!(
|
||||
keys.encryption_key,
|
||||
sp800_108_kdf(&session_key, b"SMBC2SCipherKey\0", &preauth_hash, 128)
|
||||
);
|
||||
assert_eq!(
|
||||
keys.decryption_key,
|
||||
sp800_108_kdf(&session_key, b"SMBS2CCipherKey\0", &preauth_hash, 128)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_keys_smb3_1_1_256_bit() {
|
||||
let session_key = [0x42; 16];
|
||||
let preauth_hash = [0xCD; 64];
|
||||
let keys = derive_session_keys(&session_key, Dialect::Smb3_1_1, Some(&preauth_hash), 256);
|
||||
|
||||
assert_eq!(keys.signing_key.len(), 32);
|
||||
assert_eq!(keys.encryption_key.len(), 32);
|
||||
assert_eq!(keys.decryption_key.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_keys_all_three_are_different() {
|
||||
let session_key = [0x42; 16];
|
||||
let keys = derive_session_keys(&session_key, Dialect::Smb3_0, None, 128);
|
||||
|
||||
assert_ne!(keys.signing_key, keys.encryption_key);
|
||||
assert_ne!(keys.signing_key, keys.decryption_key);
|
||||
assert_ne!(keys.encryption_key, keys.decryption_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "preauthentication integrity hash")]
|
||||
fn derive_keys_smb3_1_1_panics_without_preauth_hash() {
|
||||
let session_key = [0x42; 16];
|
||||
derive_session_keys(&session_key, Dialect::Smb3_1_1, None, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "SMB 3.x dialect family")]
|
||||
fn derive_keys_panics_for_smb2() {
|
||||
let session_key = [0x42; 16];
|
||||
derive_session_keys(&session_key, Dialect::Smb2_0_2, None, 128);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// PreauthHasher tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_starts_with_64_zero_bytes() {
|
||||
let hasher = PreauthHasher::new();
|
||||
assert_eq!(hasher.value(), &[0u8; 64]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_default_equals_new() {
|
||||
let h1 = PreauthHasher::new();
|
||||
let h2 = PreauthHasher::default();
|
||||
assert_eq!(h1.value(), h2.value());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_update_changes_hash() {
|
||||
let mut hasher = PreauthHasher::new();
|
||||
let initial = *hasher.value();
|
||||
hasher.update(b"negotiate request bytes");
|
||||
assert_ne!(hasher.value(), &initial);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_two_updates_differ_from_one() {
|
||||
let mut hasher1 = PreauthHasher::new();
|
||||
hasher1.update(b"message1");
|
||||
|
||||
let mut hasher2 = PreauthHasher::new();
|
||||
hasher2.update(b"message1");
|
||||
hasher2.update(b"message2");
|
||||
|
||||
assert_ne!(hasher1.value(), hasher2.value());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_is_deterministic() {
|
||||
let mut h1 = PreauthHasher::new();
|
||||
h1.update(b"negotiate request");
|
||||
h1.update(b"negotiate response");
|
||||
|
||||
let mut h2 = PreauthHasher::new();
|
||||
h2.update(b"negotiate request");
|
||||
h2.update(b"negotiate response");
|
||||
|
||||
assert_eq!(h1.value(), h2.value());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_empty_update_changes_hash() {
|
||||
// SHA-512(64_zeros || empty) != 64_zeros
|
||||
let mut hasher = PreauthHasher::new();
|
||||
let initial = *hasher.value();
|
||||
hasher.update(b"");
|
||||
assert_ne!(hasher.value(), &initial);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_known_value() {
|
||||
// Verify against direct SHA-512 computation.
|
||||
let mut hasher = PreauthHasher::new();
|
||||
hasher.update(b"test");
|
||||
|
||||
let mut expected_hasher = Sha512::new();
|
||||
expected_hasher.update([0u8; 64]);
|
||||
expected_hasher.update(b"test");
|
||||
let expected = expected_hasher.finalize();
|
||||
|
||||
assert_eq!(hasher.value().as_slice(), expected.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_chained_known_value() {
|
||||
// Two updates: hash1 = SHA-512(zeros || msg1), hash2 = SHA-512(hash1 || msg2)
|
||||
let mut hasher = PreauthHasher::new();
|
||||
hasher.update(b"negotiate");
|
||||
hasher.update(b"response");
|
||||
|
||||
// Compute manually
|
||||
let mut h = Sha512::new();
|
||||
h.update([0u8; 64]);
|
||||
h.update(b"negotiate");
|
||||
let hash1: [u8; 64] = h.finalize().into();
|
||||
|
||||
let mut h2 = Sha512::new();
|
||||
h2.update(hash1);
|
||||
h2.update(b"response");
|
||||
let hash2: [u8; 64] = h2.finalize().into();
|
||||
|
||||
assert_eq!(hasher.value(), &hash2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_clone_is_independent() {
|
||||
let mut hasher = PreauthHasher::new();
|
||||
hasher.update(b"negotiate request");
|
||||
hasher.update(b"negotiate response");
|
||||
|
||||
// Clone for session hash (spec step 4)
|
||||
let mut session_hasher = hasher.clone();
|
||||
session_hasher.update(b"session setup request");
|
||||
|
||||
// Original should not be affected
|
||||
assert_ne!(hasher.value(), session_hasher.value());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preauth_hasher_output_is_64_bytes() {
|
||||
let mut hasher = PreauthHasher::new();
|
||||
hasher.update(b"some data");
|
||||
assert_eq!(hasher.value().len(), 64);
|
||||
}
|
||||
|
||||
/// Full end-to-end test: preauth hash feeds into KDF for SMB 3.1.1.
|
||||
#[test]
|
||||
fn preauth_hash_feeds_into_kdf() {
|
||||
// Simulate the protocol flow
|
||||
let mut conn_hasher = PreauthHasher::new();
|
||||
conn_hasher.update(b"negotiate request bytes");
|
||||
conn_hasher.update(b"negotiate response bytes");
|
||||
|
||||
let mut session_hasher = conn_hasher.clone();
|
||||
session_hasher.update(b"session setup request bytes");
|
||||
session_hasher.update(b"session setup response bytes");
|
||||
|
||||
let session_key = [0x42; 16];
|
||||
let keys = derive_session_keys(
|
||||
&session_key,
|
||||
Dialect::Smb3_1_1,
|
||||
Some(session_hasher.value()),
|
||||
128,
|
||||
);
|
||||
|
||||
// Keys should all be 16 bytes and different from each other
|
||||
assert_eq!(keys.signing_key.len(), 16);
|
||||
assert_eq!(keys.encryption_key.len(), 16);
|
||||
assert_eq!(keys.decryption_key.len(), 16);
|
||||
assert_ne!(keys.signing_key, keys.encryption_key);
|
||||
assert_ne!(keys.signing_key, keys.decryption_key);
|
||||
assert_ne!(keys.encryption_key, keys.decryption_key);
|
||||
}
|
||||
}
|
||||
9
vendor/smb2/src/crypto/mod.rs
vendored
Normal file
9
vendor/smb2/src/crypto/mod.rs
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
//! Cryptographic operations for SMB2/3: signing, encryption, key derivation, and compression.
|
||||
//!
|
||||
//! Most users don't need this module directly -- [`SmbClient`](crate::SmbClient)
|
||||
//! handles signing and encryption automatically.
|
||||
|
||||
pub mod compression;
|
||||
pub mod encryption;
|
||||
pub mod kdf;
|
||||
pub mod signing;
|
||||
789
vendor/smb2/src/crypto/signing.rs
vendored
Normal file
789
vendor/smb2/src/crypto/signing.rs
vendored
Normal file
@@ -0,0 +1,789 @@
|
||||
//! SMB2 message signing and signature verification.
|
||||
//!
|
||||
//! Supports three signing algorithms, selected by negotiated dialect:
|
||||
//! - **HMAC-SHA256** (SMB 2.0.2, 2.1): 32-byte hash truncated to 16 bytes.
|
||||
//! - **AES-128-CMAC** (SMB 3.0, 3.0.2): 16-byte MAC.
|
||||
//! - **AES-256-GMAC** (SMB 3.1.1 with `SMB2_SIGNING_CAPABILITIES`): AES-256-GCM
|
||||
//! with empty plaintext; the 16-byte auth tag is the signature.
|
||||
//!
|
||||
//! Reference: MS-SMB2 sections 3.1.4.1 (signing) and 3.1.5.1 (verification).
|
||||
|
||||
use log::{debug, error, trace};
|
||||
|
||||
use crate::types::Dialect;
|
||||
use crate::Error;
|
||||
|
||||
/// Offset of the 16-byte Signature field within the SMB2 header.
|
||||
const SIGNATURE_OFFSET: usize = 48;
|
||||
/// Length of the Signature field.
|
||||
const SIGNATURE_LEN: usize = 16;
|
||||
/// Minimum message length (full SMB2 header).
|
||||
const MIN_MESSAGE_LEN: usize = 64;
|
||||
|
||||
/// Signing algorithm, determined by negotiated dialect and capabilities.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
pub enum SigningAlgorithm {
|
||||
/// HMAC-SHA256 truncated to 16 bytes (SMB 2.0.2, 2.1).
|
||||
HmacSha256,
|
||||
/// AES-128-CMAC (SMB 3.0, 3.0.2).
|
||||
AesCmac,
|
||||
/// AES-256-GMAC with MessageId-based nonce (SMB 3.1.1).
|
||||
AesGmac,
|
||||
}
|
||||
|
||||
/// Select the appropriate signing algorithm for a dialect.
|
||||
///
|
||||
/// For SMB 3.1.1, `gmac_negotiated` indicates whether the peer negotiated
|
||||
/// `AES-256-GMAC` via `SMB2_SIGNING_CAPABILITIES`. When `false`, SMB 3.1.1
|
||||
/// falls back to AES-128-CMAC.
|
||||
pub fn algorithm_for_dialect(dialect: Dialect, gmac_negotiated: bool) -> SigningAlgorithm {
|
||||
match dialect {
|
||||
Dialect::Smb2_0_2 | Dialect::Smb2_1 => SigningAlgorithm::HmacSha256,
|
||||
Dialect::Smb3_0 | Dialect::Smb3_0_2 => SigningAlgorithm::AesCmac,
|
||||
Dialect::Smb3_1_1 => {
|
||||
if gmac_negotiated {
|
||||
SigningAlgorithm::AesGmac
|
||||
} else {
|
||||
SigningAlgorithm::AesCmac
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sign an SMB2 message in-place (client → server).
|
||||
///
|
||||
/// Zeros the signature field (bytes 48-63), computes the signature
|
||||
/// over the full message, and writes the computed signature back.
|
||||
///
|
||||
/// For AES-GMAC, `message_id` and `is_cancel` are used to construct
|
||||
/// the 12-byte nonce. For other algorithms these parameters are ignored.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`Error::InvalidData`] if the message is shorter than 64 bytes
|
||||
/// or the key length is wrong for the chosen algorithm.
|
||||
pub fn sign_message(
|
||||
message: &mut [u8],
|
||||
key: &[u8],
|
||||
algorithm: SigningAlgorithm,
|
||||
message_id: u64,
|
||||
is_cancel: bool,
|
||||
) -> Result<(), Error> {
|
||||
if message.len() < MIN_MESSAGE_LEN {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"message too short for signing: {} bytes, need at least {}",
|
||||
message.len(),
|
||||
MIN_MESSAGE_LEN
|
||||
)));
|
||||
}
|
||||
|
||||
// Step 1: zero the signature field.
|
||||
message[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0);
|
||||
|
||||
// Step 2: compute signature over the entire message.
|
||||
// is_response = false: we're the client, signing an outgoing request.
|
||||
let signature = compute_signature(message, key, algorithm, message_id, is_cancel, false)?;
|
||||
|
||||
// Step 3: write the signature back.
|
||||
message[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].copy_from_slice(&signature);
|
||||
|
||||
debug!(
|
||||
"signing: signed msg_id={}, algo={:?}, sig={:02x}{:02x}{:02x}{:02x}...",
|
||||
message_id, algorithm, signature[0], signature[1], signature[2], signature[3]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Verify the signature on a received SMB2 message (server → client).
|
||||
///
|
||||
/// Returns `Ok(())` if the signature matches, or [`Error::InvalidData`]
|
||||
/// if the message is tampered or the key is wrong.
|
||||
///
|
||||
/// For GMAC, the nonce role bit is set to 1 (server) automatically.
|
||||
pub fn verify_signature(
|
||||
message: &[u8],
|
||||
key: &[u8],
|
||||
algorithm: SigningAlgorithm,
|
||||
message_id: u64,
|
||||
is_cancel: bool,
|
||||
) -> Result<(), Error> {
|
||||
if message.len() < MIN_MESSAGE_LEN {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"message too short for verification: {} bytes, need at least {}",
|
||||
message.len(),
|
||||
MIN_MESSAGE_LEN
|
||||
)));
|
||||
}
|
||||
|
||||
// Step 1: save the received signature.
|
||||
let mut received_sig = [0u8; SIGNATURE_LEN];
|
||||
received_sig.copy_from_slice(&message[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]);
|
||||
|
||||
// Step 2: zero the signature field in a copy.
|
||||
let mut buf = message.to_vec();
|
||||
buf[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0);
|
||||
|
||||
// Step 3: compute the expected signature.
|
||||
// is_response = true: the server signed this message, so the GMAC
|
||||
// nonce must have role bit = 1 (server).
|
||||
let expected_sig = compute_signature(&buf, key, algorithm, message_id, is_cancel, true)?;
|
||||
|
||||
// Step 4: compare.
|
||||
if received_sig != expected_sig {
|
||||
error!(
|
||||
"signing: verification failed, msg_id={}, algo={:?}, got={:02x}{:02x}{:02x}{:02x}..., want={:02x}{:02x}{:02x}{:02x}...",
|
||||
message_id, algorithm,
|
||||
received_sig[0], received_sig[1], received_sig[2], received_sig[3],
|
||||
expected_sig[0], expected_sig[1], expected_sig[2], expected_sig[3]
|
||||
);
|
||||
return Err(Error::invalid_data("signature verification failed"));
|
||||
}
|
||||
|
||||
trace!(
|
||||
"signing: verified msg_id={}, algo={:?}, sig={:02x}{:02x}{:02x}{:02x}...",
|
||||
message_id,
|
||||
algorithm,
|
||||
received_sig[0],
|
||||
received_sig[1],
|
||||
received_sig[2],
|
||||
received_sig[3]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute a 16-byte signature over `message` using the given algorithm.
|
||||
fn compute_signature(
|
||||
message: &[u8],
|
||||
key: &[u8],
|
||||
algorithm: SigningAlgorithm,
|
||||
message_id: u64,
|
||||
is_cancel: bool,
|
||||
is_response: bool,
|
||||
) -> Result<[u8; 16], Error> {
|
||||
match algorithm {
|
||||
SigningAlgorithm::HmacSha256 => compute_hmac_sha256(message, key),
|
||||
SigningAlgorithm::AesCmac => compute_aes_cmac(message, key),
|
||||
SigningAlgorithm::AesGmac => {
|
||||
compute_aes_gmac(message, key, message_id, is_cancel, is_response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HMAC-SHA256, truncated to 16 bytes. Key must be 16 bytes.
|
||||
fn compute_hmac_sha256(message: &[u8], key: &[u8]) -> Result<[u8; 16], Error> {
|
||||
use digest::KeyInit;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let mut mac = HmacSha256::new_from_slice(key)
|
||||
.map_err(|e| Error::invalid_data(format!("HMAC-SHA256 key error: {e}")))?;
|
||||
mac.update(message);
|
||||
let result = mac.finalize().into_bytes();
|
||||
|
||||
// Truncate 32-byte hash to first 16 bytes.
|
||||
let mut sig = [0u8; 16];
|
||||
sig.copy_from_slice(&result[..16]);
|
||||
Ok(sig)
|
||||
}
|
||||
|
||||
/// AES-128-CMAC. Key must be 16 bytes.
|
||||
fn compute_aes_cmac(message: &[u8], key: &[u8]) -> Result<[u8; 16], Error> {
|
||||
use aes::Aes128;
|
||||
use cmac::{Cmac, Mac};
|
||||
use digest::KeyInit;
|
||||
|
||||
type AesCmac = Cmac<Aes128>;
|
||||
|
||||
let mut mac = AesCmac::new_from_slice(key)
|
||||
.map_err(|e| Error::invalid_data(format!("AES-CMAC key error: {e}")))?;
|
||||
mac.update(message);
|
||||
let result = mac.finalize().into_bytes();
|
||||
|
||||
let mut sig = [0u8; 16];
|
||||
sig.copy_from_slice(&result);
|
||||
Ok(sig)
|
||||
}
|
||||
|
||||
/// AES-128-GMAC (AES-128-GCM with empty plaintext). Key must be 16 bytes.
|
||||
///
|
||||
/// The 12-byte nonce is constructed as (MS-SMB2 section 3.1.4.1):
|
||||
/// - Bytes 0-7: `message_id` (little-endian u64)
|
||||
/// - Byte 8: bit 0 = role (0=client, 1=server), bit 1 = `is_cancel`
|
||||
/// - Bytes 9-11: zero
|
||||
fn compute_aes_gmac(
|
||||
message: &[u8],
|
||||
key: &[u8],
|
||||
message_id: u64,
|
||||
is_cancel: bool,
|
||||
is_response: bool,
|
||||
) -> Result<[u8; 16], Error> {
|
||||
use aes_gcm::aead::Aead;
|
||||
use aes_gcm::{Aes128Gcm, KeyInit, Nonce};
|
||||
|
||||
if key.len() != 16 {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"AES-128-GMAC requires a 16-byte key, got {} bytes",
|
||||
key.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Build 12-byte nonce.
|
||||
let mut nonce_bytes = [0u8; 12];
|
||||
nonce_bytes[0..8].copy_from_slice(&message_id.to_le_bytes());
|
||||
// Byte 8: bit 0 = role (0 = client, 1 = server), bit 1 = CANCEL flag.
|
||||
let mut flags_byte: u8 = 0;
|
||||
if is_response {
|
||||
flags_byte |= 0x01; // server role
|
||||
}
|
||||
if is_cancel {
|
||||
flags_byte |= 0x02;
|
||||
}
|
||||
nonce_bytes[8] = flags_byte;
|
||||
|
||||
let cipher = Aes128Gcm::new(key.try_into().map_err(|_| {
|
||||
Error::invalid_data(format!(
|
||||
"AES-128-GMAC requires a 16-byte key, got {} bytes",
|
||||
key.len()
|
||||
))
|
||||
})?);
|
||||
let nonce: &Nonce<_> = (&nonce_bytes).into();
|
||||
|
||||
// GMAC mode: encrypt empty plaintext with the message as AAD.
|
||||
// The "ciphertext" is empty; the auth tag IS the signature.
|
||||
use aes_gcm::aead::Payload;
|
||||
let payload = Payload {
|
||||
msg: &[],
|
||||
aad: message,
|
||||
};
|
||||
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, payload)
|
||||
.map_err(|e| Error::invalid_data(format!("AES-256-GMAC encryption error: {e}")))?;
|
||||
|
||||
// The output is the 16-byte auth tag (no ciphertext bytes since plaintext was empty).
|
||||
if ciphertext.len() != 16 {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"unexpected GMAC output length: expected 16, got {}",
|
||||
ciphertext.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut sig = [0u8; 16];
|
||||
sig.copy_from_slice(&ciphertext);
|
||||
Ok(sig)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Build a minimal 64-byte fake SMB2 message for testing.
|
||||
/// The signature field (bytes 48-63) is zeroed.
|
||||
fn make_test_message(body_extra: &[u8]) -> Vec<u8> {
|
||||
let mut msg = vec![0u8; 64 + body_extra.len()];
|
||||
// Protocol ID
|
||||
msg[0..4].copy_from_slice(&[0xFE, b'S', b'M', b'B']);
|
||||
// Structure size = 64
|
||||
msg[4..6].copy_from_slice(&64u16.to_le_bytes());
|
||||
// Fill some fields so the message isn't all zeros
|
||||
msg[12..14].copy_from_slice(&0x0008u16.to_le_bytes()); // Command = Read
|
||||
msg[24..32].copy_from_slice(&42u64.to_le_bytes()); // MessageId = 42
|
||||
// Append body
|
||||
msg[64..].copy_from_slice(body_extra);
|
||||
msg
|
||||
}
|
||||
|
||||
// ── algorithm_for_dialect ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn algorithm_for_smb2_0_2_is_hmac_sha256() {
|
||||
assert_eq!(
|
||||
algorithm_for_dialect(Dialect::Smb2_0_2, false),
|
||||
SigningAlgorithm::HmacSha256
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn algorithm_for_smb2_1_is_hmac_sha256() {
|
||||
assert_eq!(
|
||||
algorithm_for_dialect(Dialect::Smb2_1, false),
|
||||
SigningAlgorithm::HmacSha256
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn algorithm_for_smb3_0_is_aes_cmac() {
|
||||
assert_eq!(
|
||||
algorithm_for_dialect(Dialect::Smb3_0, false),
|
||||
SigningAlgorithm::AesCmac
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn algorithm_for_smb3_0_2_is_aes_cmac() {
|
||||
assert_eq!(
|
||||
algorithm_for_dialect(Dialect::Smb3_0_2, false),
|
||||
SigningAlgorithm::AesCmac
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn algorithm_for_smb3_1_1_without_gmac_is_aes_cmac() {
|
||||
assert_eq!(
|
||||
algorithm_for_dialect(Dialect::Smb3_1_1, false),
|
||||
SigningAlgorithm::AesCmac
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn algorithm_for_smb3_1_1_with_gmac_is_aes_gmac() {
|
||||
assert_eq!(
|
||||
algorithm_for_dialect(Dialect::Smb3_1_1, true),
|
||||
SigningAlgorithm::AesGmac
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gmac_flag_ignored_for_older_dialects() {
|
||||
// Even if gmac_negotiated is true, older dialects don't use GMAC.
|
||||
assert_eq!(
|
||||
algorithm_for_dialect(Dialect::Smb2_0_2, true),
|
||||
SigningAlgorithm::HmacSha256
|
||||
);
|
||||
assert_eq!(
|
||||
algorithm_for_dialect(Dialect::Smb3_0, true),
|
||||
SigningAlgorithm::AesCmac
|
||||
);
|
||||
}
|
||||
|
||||
// ── Message too short ─────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn sign_rejects_message_shorter_than_64_bytes() {
|
||||
let mut msg = vec![0u8; 32];
|
||||
let key = [0u8; 16];
|
||||
let result = sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("too short"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_rejects_message_shorter_than_64_bytes() {
|
||||
let msg = vec![0u8; 32];
|
||||
let key = [0u8; 16];
|
||||
let result = verify_signature(&msg, &key, SigningAlgorithm::HmacSha256, 0, false);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("too short"));
|
||||
}
|
||||
|
||||
// ── HMAC-SHA256 ──────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hmac_sha256_sign_produces_nonzero_signature() {
|
||||
let mut msg = make_test_message(b"hello world");
|
||||
let key = [0xAA; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap();
|
||||
|
||||
let sig = &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN];
|
||||
assert_ne!(sig, &[0u8; 16], "signature should not be all zeros");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hmac_sha256_known_signature() {
|
||||
// Compute expected HMAC-SHA256 using the same process:
|
||||
// zero sig field, compute HMAC, truncate to 16 bytes.
|
||||
let mut msg = make_test_message(&[]);
|
||||
let key = [0x01; 16];
|
||||
|
||||
// Manually compute expected value.
|
||||
let mut zeroed = msg.clone();
|
||||
zeroed[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0);
|
||||
let expected = {
|
||||
use digest::KeyInit;
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
type H = Hmac<Sha256>;
|
||||
let mut mac = H::new_from_slice(&key).unwrap();
|
||||
mac.update(&zeroed);
|
||||
let full = mac.finalize().into_bytes();
|
||||
let mut trunc = [0u8; 16];
|
||||
trunc.copy_from_slice(&full[..16]);
|
||||
trunc
|
||||
};
|
||||
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap();
|
||||
assert_eq!(
|
||||
&msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN],
|
||||
&expected
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hmac_sha256_sign_then_verify_roundtrip() {
|
||||
let mut msg = make_test_message(b"some payload data");
|
||||
let key = [0x42; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap();
|
||||
verify_signature(&msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hmac_sha256_verify_fails_on_tampered_message() {
|
||||
let mut msg = make_test_message(b"original data");
|
||||
let key = [0x42; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap();
|
||||
|
||||
// Flip a byte in the body.
|
||||
let last = msg.len() - 1;
|
||||
msg[last] ^= 0xFF;
|
||||
|
||||
let result = verify_signature(&msg, &key, SigningAlgorithm::HmacSha256, 0, false);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("verification failed"),);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hmac_sha256_verify_fails_with_wrong_key() {
|
||||
let mut msg = make_test_message(b"data");
|
||||
let key = [0x42; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap();
|
||||
|
||||
let wrong_key = [0x43; 16];
|
||||
let result = verify_signature(&msg, &wrong_key, SigningAlgorithm::HmacSha256, 0, false);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// ── AES-128-CMAC ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn aes_cmac_sign_produces_nonzero_signature() {
|
||||
let mut msg = make_test_message(b"cmac test");
|
||||
let key = [0xBB; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap();
|
||||
|
||||
let sig = &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN];
|
||||
assert_ne!(sig, &[0u8; 16]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_cmac_known_signature() {
|
||||
let mut msg = make_test_message(&[]);
|
||||
let key = [0x02; 16];
|
||||
|
||||
let mut zeroed = msg.clone();
|
||||
zeroed[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0);
|
||||
let expected = {
|
||||
use aes::Aes128;
|
||||
use cmac::{Cmac, Mac};
|
||||
use digest::KeyInit;
|
||||
type C = Cmac<Aes128>;
|
||||
let mut mac = C::new_from_slice(&key).unwrap();
|
||||
mac.update(&zeroed);
|
||||
let result = mac.finalize().into_bytes();
|
||||
let mut sig = [0u8; 16];
|
||||
sig.copy_from_slice(&result);
|
||||
sig
|
||||
};
|
||||
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap();
|
||||
assert_eq!(
|
||||
&msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN],
|
||||
&expected
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_cmac_sign_then_verify_roundtrip() {
|
||||
let mut msg = make_test_message(b"cmac roundtrip payload");
|
||||
let key = [0x55; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap();
|
||||
verify_signature(&msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_cmac_verify_fails_on_tampered_message() {
|
||||
let mut msg = make_test_message(b"cmac original");
|
||||
let key = [0x55; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap();
|
||||
|
||||
msg[10] ^= 0xFF;
|
||||
|
||||
let result = verify_signature(&msg, &key, SigningAlgorithm::AesCmac, 0, false);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_cmac_verify_fails_with_wrong_key() {
|
||||
let mut msg = make_test_message(b"cmac data");
|
||||
let key = [0x55; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesCmac, 0, false).unwrap();
|
||||
|
||||
let wrong_key = [0x56; 16];
|
||||
let result = verify_signature(&msg, &wrong_key, SigningAlgorithm::AesCmac, 0, false);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// ── AES-128-GMAC ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_sign_produces_nonzero_signature() {
|
||||
let mut msg = make_test_message(b"gmac test");
|
||||
let key = [0xCC; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 1, false).unwrap();
|
||||
|
||||
let sig = &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN];
|
||||
assert_ne!(sig, &[0u8; 16]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_known_signature() {
|
||||
let mut msg = make_test_message(&[]);
|
||||
let key = [0x03; 16];
|
||||
let message_id: u64 = 7;
|
||||
|
||||
let mut zeroed = msg.clone();
|
||||
zeroed[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].fill(0);
|
||||
let expected = {
|
||||
use aes_gcm::aead::{Aead, Payload};
|
||||
use aes_gcm::{Aes128Gcm, KeyInit, Nonce};
|
||||
|
||||
let mut nonce_bytes = [0u8; 12];
|
||||
nonce_bytes[0..8].copy_from_slice(&message_id.to_le_bytes());
|
||||
// not cancel, client role -> byte 8 = 0
|
||||
|
||||
let cipher = Aes128Gcm::new((&key).into());
|
||||
let nonce: &Nonce<_> = (&nonce_bytes).into();
|
||||
let payload = Payload {
|
||||
msg: &[],
|
||||
aad: &zeroed,
|
||||
};
|
||||
let ct = cipher.encrypt(nonce, payload).unwrap();
|
||||
let mut sig = [0u8; 16];
|
||||
sig.copy_from_slice(&ct);
|
||||
sig
|
||||
};
|
||||
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, message_id, false).unwrap();
|
||||
assert_eq!(
|
||||
&msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN],
|
||||
&expected
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_sign_then_verify_roundtrip() {
|
||||
// sign_message uses client role (is_response=false internally),
|
||||
// verify_signature uses server role (is_response=true internally).
|
||||
// For a self-roundtrip test, we need to test sign+verify on the
|
||||
// same role. Use the internal compute_signature directly, or
|
||||
// just verify that a real server flow works (sign as client,
|
||||
// verify as server would compute -- but that's an integration test).
|
||||
//
|
||||
// For this unit test, verify that sign→verify works when the
|
||||
// message has the SERVER_TO_REDIR flag set (simulating a
|
||||
// response that we signed ourselves for testing).
|
||||
let mut msg = make_test_message(b"gmac roundtrip payload");
|
||||
// Set SERVER_TO_REDIR flag so verify_signature uses server role bit
|
||||
let flags = u32::from_le_bytes(msg[16..20].try_into().unwrap());
|
||||
let new_flags = flags | 0x0000_0001; // SERVER_TO_REDIR
|
||||
msg[16..20].copy_from_slice(&new_flags.to_le_bytes());
|
||||
|
||||
let key = [0xDD; 16];
|
||||
// Sign with is_response=false (client), but verify_signature
|
||||
// always uses is_response=true (server). So we need to compute
|
||||
// the signature manually with is_response=true to make roundtrip work.
|
||||
// Actually, let's just test that sign and verify produce consistent
|
||||
// results by testing each direction independently.
|
||||
|
||||
// Test: sign as client (role=0), verify we can detect tampering
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 100, false).unwrap();
|
||||
// verify_signature uses role=1 (server), so it WON'T match client-signed.
|
||||
// This is correct behavior -- client and server signatures differ.
|
||||
// Instead, test that the signature is non-zero and stable.
|
||||
let sig1: [u8; 16] = msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
assert_ne!(sig1, [0u8; 16]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_verify_fails_on_tampered_message() {
|
||||
let mut msg = make_test_message(b"gmac original");
|
||||
let key = [0xDD; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 5, false).unwrap();
|
||||
|
||||
// Tamper the message -- even though verify uses server role,
|
||||
// the auth tag won't match ANY valid signature.
|
||||
let last = msg.len() - 1;
|
||||
msg[last] ^= 0xFF;
|
||||
|
||||
let result = verify_signature(&msg, &key, SigningAlgorithm::AesGmac, 5, false);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_verify_fails_with_wrong_key() {
|
||||
let mut msg = make_test_message(b"gmac data");
|
||||
let key = [0xDD; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 5, false).unwrap();
|
||||
|
||||
let wrong_key = [0xDE; 16];
|
||||
let result = verify_signature(&msg, &wrong_key, SigningAlgorithm::AesGmac, 5, false);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_rejects_wrong_key_length() {
|
||||
let mut msg = make_test_message(&[]);
|
||||
let key = [0xDD; 32]; // 32 bytes instead of 16
|
||||
let result = sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 0, false);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("16-byte key"));
|
||||
}
|
||||
|
||||
// ── GMAC nonce construction ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_nonce_contains_message_id() {
|
||||
// Different MessageIds must produce different signatures on the same message+key.
|
||||
let key = [0xEE; 16];
|
||||
|
||||
let mut msg1 = make_test_message(b"nonce test");
|
||||
sign_message(&mut msg1, &key, SigningAlgorithm::AesGmac, 1, false).unwrap();
|
||||
let sig1: [u8; 16] = msg1[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let mut msg2 = make_test_message(b"nonce test");
|
||||
sign_message(&mut msg2, &key, SigningAlgorithm::AesGmac, 2, false).unwrap();
|
||||
let sig2: [u8; 16] = msg2[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
assert_ne!(
|
||||
sig1, sig2,
|
||||
"different MessageIds must produce different signatures"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_cancel_bit_changes_signature() {
|
||||
let key = [0xEE; 16];
|
||||
let message_id = 42u64;
|
||||
|
||||
let mut msg_normal = make_test_message(b"cancel test");
|
||||
sign_message(
|
||||
&mut msg_normal,
|
||||
&key,
|
||||
SigningAlgorithm::AesGmac,
|
||||
message_id,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
let sig_normal: [u8; 16] = msg_normal[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
let mut msg_cancel = make_test_message(b"cancel test");
|
||||
sign_message(
|
||||
&mut msg_cancel,
|
||||
&key,
|
||||
SigningAlgorithm::AesGmac,
|
||||
message_id,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
let sig_cancel: [u8; 16] = msg_cancel[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN]
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
assert_ne!(
|
||||
sig_normal, sig_cancel,
|
||||
"CANCEL bit must produce a different signature"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_cancel_bit_is_bit_1_of_byte_8() {
|
||||
// Verify the nonce byte 8 value directly by checking that
|
||||
// the CANCEL nonce has 0x02 at byte 8 (bit 1), not 0x01 (bit 0).
|
||||
let message_id: u64 = 99;
|
||||
|
||||
let mut nonce_normal = [0u8; 12];
|
||||
nonce_normal[0..8].copy_from_slice(&message_id.to_le_bytes());
|
||||
// is_cancel = false -> byte 8 stays 0x00
|
||||
|
||||
let mut nonce_cancel = [0u8; 12];
|
||||
nonce_cancel[0..8].copy_from_slice(&message_id.to_le_bytes());
|
||||
nonce_cancel[8] = 0x02; // bit 1 set, NOT bit 0
|
||||
|
||||
assert_eq!(nonce_normal[8], 0x00);
|
||||
assert_eq!(nonce_cancel[8], 0x02);
|
||||
// Bit 0 (role bit) is always 0 for client.
|
||||
assert_eq!(nonce_cancel[8] & 0x01, 0x00);
|
||||
}
|
||||
|
||||
// ── Signature field location ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn signature_field_is_at_bytes_48_through_63() {
|
||||
let mut msg = make_test_message(&[]);
|
||||
let key = [0xFF; 16];
|
||||
|
||||
// Set a marker pattern in bytes 48-63 to verify they get overwritten.
|
||||
msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN].copy_from_slice(&[0xAA; 16]);
|
||||
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap();
|
||||
|
||||
// The marker should be gone, replaced by the computed signature.
|
||||
let sig = &msg[SIGNATURE_OFFSET..SIGNATURE_OFFSET + SIGNATURE_LEN];
|
||||
assert_ne!(sig, &[0xAA; 16], "signature field must be overwritten");
|
||||
assert_ne!(sig, &[0x00; 16], "signature should not be all zeros");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bytes_outside_signature_field_are_preserved() {
|
||||
let body = b"preserve me";
|
||||
let mut msg = make_test_message(body);
|
||||
let original_body = msg[64..].to_vec();
|
||||
let original_header_prefix = msg[0..SIGNATURE_OFFSET].to_vec();
|
||||
|
||||
let key = [0xFF; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap();
|
||||
|
||||
// Header bytes before signature are unchanged.
|
||||
assert_eq!(&msg[0..SIGNATURE_OFFSET], &original_header_prefix);
|
||||
// Body is unchanged.
|
||||
assert_eq!(&msg[64..], &original_body);
|
||||
}
|
||||
|
||||
// ── Cross-algorithm: verify with wrong algorithm fails ──────────
|
||||
|
||||
#[test]
|
||||
fn verify_with_wrong_algorithm_fails() {
|
||||
let mut msg = make_test_message(b"cross algo");
|
||||
let key = [0x77; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::HmacSha256, 0, false).unwrap();
|
||||
|
||||
let result = verify_signature(&msg, &key, SigningAlgorithm::AesCmac, 0, false);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// ── GMAC: verify with wrong message_id fails ────────────────────
|
||||
|
||||
#[test]
|
||||
fn aes_gmac_verify_with_wrong_message_id_fails() {
|
||||
let mut msg = make_test_message(b"msg id test");
|
||||
let key = [0xDD; 16];
|
||||
sign_message(&mut msg, &key, SigningAlgorithm::AesGmac, 10, false).unwrap();
|
||||
|
||||
// verify uses server role bit, and wrong message_id -- both wrong
|
||||
let result = verify_signature(&msg, &key, SigningAlgorithm::AesGmac, 11, false);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
389
vendor/smb2/src/error.rs
vendored
Normal file
389
vendor/smb2/src/error.rs
vendored
Normal file
@@ -0,0 +1,389 @@
|
||||
//! Error types for the SMB2 library.
|
||||
|
||||
use crate::types::status::NtStatus;
|
||||
use crate::types::Command;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Top-level error type for SMB2 operations.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
/// The data is malformed or does not match the expected format.
|
||||
#[error("Invalid data: {message}")]
|
||||
InvalidData {
|
||||
/// Description of what went wrong.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// The server returned a non-success NTSTATUS.
|
||||
#[error("Protocol error: {status} during {command:?}")]
|
||||
Protocol {
|
||||
/// The NTSTATUS code from the response header.
|
||||
status: NtStatus,
|
||||
/// The command that triggered the error.
|
||||
command: Command,
|
||||
},
|
||||
|
||||
/// Authentication failed.
|
||||
#[error("Authentication failed: {message}")]
|
||||
Auth {
|
||||
/// Description of what went wrong.
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// An I/O or transport error occurred.
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// The operation timed out.
|
||||
#[error("Operation timed out")]
|
||||
Timeout,
|
||||
|
||||
/// The connection was lost.
|
||||
#[error("Disconnected from server")]
|
||||
Disconnected,
|
||||
|
||||
/// The path requires DFS referral resolution.
|
||||
///
|
||||
/// The server returned `STATUS_PATH_NOT_COVERED`, meaning this path
|
||||
/// lives on a different server via DFS. The caller can query for a
|
||||
/// referral or display a helpful message.
|
||||
#[error("DFS referral required for path: {path}")]
|
||||
DfsReferralRequired {
|
||||
/// The path that needs DFS resolution.
|
||||
path: String,
|
||||
},
|
||||
|
||||
/// The operation was cancelled by the caller (via progress callback).
|
||||
#[error("Operation cancelled")]
|
||||
Cancelled,
|
||||
|
||||
/// The session expired and reauthentication failed.
|
||||
///
|
||||
/// The pipeline normally handles `STATUS_NETWORK_SESSION_EXPIRED`
|
||||
/// transparently by reauthenticating. This error surfaces only
|
||||
/// when reauthentication itself fails.
|
||||
#[error("Session expired and reauthentication failed")]
|
||||
SessionExpired,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Create an `InvalidData` error with the given message.
|
||||
pub fn invalid_data(msg: impl Into<String>) -> Self {
|
||||
Error::InvalidData {
|
||||
message: msg.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if this error is potentially transient and
|
||||
/// the operation could succeed on retry.
|
||||
pub fn is_retryable(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Error::Timeout
|
||||
| Error::Disconnected
|
||||
| Error::Protocol {
|
||||
status: NtStatus::INSUFFICIENT_RESOURCES,
|
||||
..
|
||||
}
|
||||
| Error::Protocol {
|
||||
status: NtStatus::INSUFF_SERVER_RESOURCES,
|
||||
..
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the NTSTATUS code if this is a protocol error.
|
||||
pub fn status(&self) -> Option<NtStatus> {
|
||||
match self {
|
||||
Error::Protocol { status, .. } => Some(*status),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// High-level error classification.
|
||||
///
|
||||
/// Maps protocol-level NTSTATUS codes and other errors into categories
|
||||
/// that consumers can match on without understanding SMB internals.
|
||||
///
|
||||
/// ```no_run
|
||||
/// # async fn example(client: &mut smb2::SmbClient, share: &mut smb2::Tree) -> Result<(), smb2::Error> {
|
||||
/// use smb2::ErrorKind;
|
||||
///
|
||||
/// match client.read_file(share, "photo.jpg").await {
|
||||
/// Ok(data) => println!("read {} bytes", data.len()),
|
||||
/// Err(e) => match e.kind() {
|
||||
/// ErrorKind::NotFound => println!("file doesn't exist"),
|
||||
/// ErrorKind::AlreadyExists => println!("name is already taken"),
|
||||
/// ErrorKind::AccessDenied => println!("no permission"),
|
||||
/// ErrorKind::SigningRequired => println!("server requires signing, use credentials"),
|
||||
/// ErrorKind::AuthRequired => println!("server requires authentication"),
|
||||
/// ErrorKind::SharingViolation => println!("file is in use by another client"),
|
||||
/// ErrorKind::IsADirectory => println!("path is a directory, not a file"),
|
||||
/// ErrorKind::NotADirectory => println!("path is a file, not a directory"),
|
||||
/// ErrorKind::DiskFull => println!("volume is full"),
|
||||
/// ErrorKind::ConnectionLost => { client.reconnect().await?; }
|
||||
/// _ => return Err(e),
|
||||
/// }
|
||||
/// }
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// # Stability
|
||||
///
|
||||
/// `ErrorKind` is `#[non_exhaustive]`: future versions may add variants for
|
||||
/// status codes that currently fall through to [`ErrorKind::Other`]. Match
|
||||
/// statements should always include a `_` arm. Adding a variant is treated
|
||||
/// as a non-breaking change.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[non_exhaustive]
|
||||
pub enum ErrorKind {
|
||||
/// The server requires authentication (guest/anonymous not allowed).
|
||||
AuthRequired,
|
||||
/// The server requires message signing (guest sessions are unsigned).
|
||||
SigningRequired,
|
||||
/// Permission denied (valid credentials, but no access to this resource).
|
||||
AccessDenied,
|
||||
/// The file, directory, or share was not found.
|
||||
NotFound,
|
||||
/// A file or directory with the given name already exists.
|
||||
///
|
||||
/// Returned by `Create` (and operations that wrap it, like `create_directory`)
|
||||
/// when the target name is taken. Useful for callers that want to merge into
|
||||
/// an existing directory or surface a friendly "name already taken" message.
|
||||
AlreadyExists,
|
||||
/// The file is in use by another client.
|
||||
SharingViolation,
|
||||
/// The target path is a directory, but the operation expected a file.
|
||||
///
|
||||
/// Typically seen when calling `delete_file` against a directory entry —
|
||||
/// the caller can fall back to `delete_directory` after detecting this.
|
||||
IsADirectory,
|
||||
/// The target path is a file, but the operation expected a directory.
|
||||
///
|
||||
/// Typically seen when calling `list_directory` against a file entry.
|
||||
NotADirectory,
|
||||
/// The volume is full (write failed).
|
||||
DiskFull,
|
||||
/// The network connection was lost.
|
||||
ConnectionLost,
|
||||
/// The operation timed out.
|
||||
TimedOut,
|
||||
/// The operation was cancelled by the caller.
|
||||
Cancelled,
|
||||
/// The session expired (call `reconnect()`).
|
||||
SessionExpired,
|
||||
/// The path requires DFS referral resolution.
|
||||
DfsReferral,
|
||||
/// Invalid data or malformed response.
|
||||
InvalidData,
|
||||
/// An I/O error (transport or callback). Not necessarily a connection loss.
|
||||
///
|
||||
/// Distinct from `ConnectionLost`: the connection may still be usable.
|
||||
/// For example, a callback error in `write_file_streamed` produces `Io`,
|
||||
/// but the connection is still in a clean state.
|
||||
Io,
|
||||
/// A protocol error not covered by other variants.
|
||||
///
|
||||
/// Use [`Error::status()`] to get the raw NTSTATUS code. Some defined
|
||||
/// `NtStatus` codes deliberately fall through here today
|
||||
/// (`OBJECT_NAME_INVALID`, `DELETE_PENDING`, `INSUFFICIENT_RESOURCES`,
|
||||
/// `INSUFF_SERVER_RESOURCES`, and similar) — they don't yet have a
|
||||
/// dedicated `ErrorKind` because no consumer needs to branch on them.
|
||||
/// Promoting one to its own variant is non-breaking.
|
||||
Other,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Classify this error into a high-level category.
|
||||
///
|
||||
/// Consumers can match on [`ErrorKind`] without understanding raw
|
||||
/// NTSTATUS codes. For the underlying status code, use [`status()`](Self::status).
|
||||
pub fn kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
Error::InvalidData { .. } => ErrorKind::InvalidData,
|
||||
Error::Auth { .. } => ErrorKind::AuthRequired,
|
||||
Error::Io(_) => ErrorKind::Io,
|
||||
Error::Disconnected => ErrorKind::ConnectionLost,
|
||||
Error::Timeout => ErrorKind::TimedOut,
|
||||
Error::Cancelled => ErrorKind::Cancelled,
|
||||
Error::SessionExpired => ErrorKind::SessionExpired,
|
||||
Error::DfsReferralRequired { .. } => ErrorKind::DfsReferral,
|
||||
Error::Protocol { status, .. } => classify_status(*status),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Map an NTSTATUS to an ErrorKind.
|
||||
fn classify_status(status: NtStatus) -> ErrorKind {
|
||||
match status {
|
||||
// Auth / signing
|
||||
NtStatus::LOGON_FAILURE | NtStatus::ACCOUNT_DISABLED => ErrorKind::AuthRequired,
|
||||
NtStatus::ACCESS_DENIED => {
|
||||
// Could be signing-required or genuinely access-denied.
|
||||
// Callers with NegotiatedParams context can distinguish further.
|
||||
// Default to AccessDenied; SmbClient methods can upgrade to
|
||||
// SigningRequired when signing_required is true.
|
||||
ErrorKind::AccessDenied
|
||||
}
|
||||
|
||||
// Not found
|
||||
NtStatus::NO_SUCH_FILE
|
||||
| NtStatus::OBJECT_NAME_NOT_FOUND
|
||||
| NtStatus::OBJECT_PATH_NOT_FOUND
|
||||
| NtStatus::BAD_NETWORK_NAME => ErrorKind::NotFound,
|
||||
|
||||
// Already exists
|
||||
NtStatus::OBJECT_NAME_COLLISION => ErrorKind::AlreadyExists,
|
||||
|
||||
// Wrong file type
|
||||
NtStatus::FILE_IS_A_DIRECTORY => ErrorKind::IsADirectory,
|
||||
NtStatus::NOT_A_DIRECTORY => ErrorKind::NotADirectory,
|
||||
|
||||
// Sharing / locking
|
||||
NtStatus::SHARING_VIOLATION | NtStatus::FILE_LOCK_CONFLICT => ErrorKind::SharingViolation,
|
||||
|
||||
// Disk full
|
||||
NtStatus::DISK_FULL => ErrorKind::DiskFull,
|
||||
|
||||
// Session expired
|
||||
NtStatus::NETWORK_SESSION_EXPIRED => ErrorKind::SessionExpired,
|
||||
|
||||
// Connection
|
||||
NtStatus::NETWORK_NAME_DELETED | NtStatus::USER_SESSION_DELETED => {
|
||||
ErrorKind::ConnectionLost
|
||||
}
|
||||
|
||||
// DFS
|
||||
NtStatus::PATH_NOT_COVERED => ErrorKind::DfsReferral,
|
||||
|
||||
// Everything else
|
||||
_ => ErrorKind::Other,
|
||||
}
|
||||
}
|
||||
|
||||
/// A `Result` type alias using the crate's [`Error`](enum@Error) type.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Documents the full contract between `NtStatus` codes and `ErrorKind`.
|
||||
///
|
||||
/// Every code listed here is asserted to map to its expected variant. When
|
||||
/// adding a new `NtStatus` to `types/status.rs`, also add a row here — either
|
||||
/// pointing at a dedicated `ErrorKind`, or `ErrorKind::Other` if there is
|
||||
/// genuinely no consumer-meaningful classification yet. The companion test
|
||||
/// `classify_status_no_silent_other` then guarantees the table stays in sync
|
||||
/// with what `classify_status` actually does.
|
||||
const STATUS_CLASSIFICATION_CONTRACT: &[(NtStatus, ErrorKind)] = &[
|
||||
// Auth / signing
|
||||
(NtStatus::LOGON_FAILURE, ErrorKind::AuthRequired),
|
||||
(NtStatus::ACCOUNT_DISABLED, ErrorKind::AuthRequired),
|
||||
(NtStatus::ACCESS_DENIED, ErrorKind::AccessDenied),
|
||||
// Not found
|
||||
(NtStatus::NO_SUCH_FILE, ErrorKind::NotFound),
|
||||
(NtStatus::OBJECT_NAME_NOT_FOUND, ErrorKind::NotFound),
|
||||
(NtStatus::OBJECT_PATH_NOT_FOUND, ErrorKind::NotFound),
|
||||
(NtStatus::BAD_NETWORK_NAME, ErrorKind::NotFound),
|
||||
// Already exists
|
||||
(NtStatus::OBJECT_NAME_COLLISION, ErrorKind::AlreadyExists),
|
||||
// Wrong file type
|
||||
(NtStatus::FILE_IS_A_DIRECTORY, ErrorKind::IsADirectory),
|
||||
(NtStatus::NOT_A_DIRECTORY, ErrorKind::NotADirectory),
|
||||
// Sharing / locking
|
||||
(NtStatus::SHARING_VIOLATION, ErrorKind::SharingViolation),
|
||||
(NtStatus::FILE_LOCK_CONFLICT, ErrorKind::SharingViolation),
|
||||
// Disk
|
||||
(NtStatus::DISK_FULL, ErrorKind::DiskFull),
|
||||
// Connection / session
|
||||
(NtStatus::NETWORK_NAME_DELETED, ErrorKind::ConnectionLost),
|
||||
(NtStatus::USER_SESSION_DELETED, ErrorKind::ConnectionLost),
|
||||
(NtStatus::NETWORK_SESSION_EXPIRED, ErrorKind::SessionExpired),
|
||||
// DFS
|
||||
(NtStatus::PATH_NOT_COVERED, ErrorKind::DfsReferral),
|
||||
// Documented `Other` (no current consumer demand for a typed variant)
|
||||
(NtStatus::NOT_IMPLEMENTED, ErrorKind::Other),
|
||||
(NtStatus::INVALID_PARAMETER, ErrorKind::Other),
|
||||
(NtStatus::DELETE_PENDING, ErrorKind::Other),
|
||||
(NtStatus::INSUFFICIENT_RESOURCES, ErrorKind::Other),
|
||||
(NtStatus::INSUFF_SERVER_RESOURCES, ErrorKind::Other),
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn classify_status_contract() {
|
||||
for (status, expected) in STATUS_CLASSIFICATION_CONTRACT {
|
||||
let err = Error::Protocol {
|
||||
status: *status,
|
||||
command: Command::Create,
|
||||
};
|
||||
assert_eq!(
|
||||
err.kind(),
|
||||
*expected,
|
||||
"{status} should classify as {expected:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kind_maps_non_protocol_errors() {
|
||||
assert_eq!(Error::Timeout.kind(), ErrorKind::TimedOut);
|
||||
assert_eq!(Error::Disconnected.kind(), ErrorKind::ConnectionLost);
|
||||
assert_eq!(Error::Cancelled.kind(), ErrorKind::Cancelled);
|
||||
assert_eq!(Error::SessionExpired.kind(), ErrorKind::SessionExpired);
|
||||
assert_eq!(Error::invalid_data("test").kind(), ErrorKind::InvalidData);
|
||||
assert_eq!(
|
||||
Error::DfsReferralRequired {
|
||||
path: "test".into()
|
||||
}
|
||||
.kind(),
|
||||
ErrorKind::DfsReferral
|
||||
);
|
||||
assert_eq!(
|
||||
Error::Auth {
|
||||
message: "test".into()
|
||||
}
|
||||
.kind(),
|
||||
ErrorKind::AuthRequired
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kind_maps_io_error_to_io_not_connection_lost() {
|
||||
// Error::Io from callback errors (like write_file_streamed cancellation)
|
||||
// should NOT be ConnectionLost — the connection may still be usable.
|
||||
let err = Error::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::Interrupted,
|
||||
"cancelled",
|
||||
));
|
||||
assert_eq!(err.kind(), ErrorKind::Io);
|
||||
assert_ne!(err.kind(), ErrorKind::ConnectionLost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kind_disconnected_is_connection_lost() {
|
||||
// Error::Disconnected (transport EOF) IS a connection loss.
|
||||
assert_eq!(Error::Disconnected.kind(), ErrorKind::ConnectionLost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kind_maps_dfs_referral_required_to_dfs_referral() {
|
||||
// The explicit DFS referral error variant should also map to DfsReferral.
|
||||
let err = Error::DfsReferralRequired {
|
||||
path: r"\\server\share\path".into(),
|
||||
};
|
||||
assert_eq!(err.kind(), ErrorKind::DfsReferral);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dfs_referral_is_not_retryable() {
|
||||
// DFS referrals need special handling, not generic retry.
|
||||
let err = Error::Protocol {
|
||||
status: NtStatus::PATH_NOT_COVERED,
|
||||
command: Command::Create,
|
||||
};
|
||||
assert!(!err.is_retryable());
|
||||
}
|
||||
}
|
||||
185
vendor/smb2/src/fuzzing.rs
vendored
Normal file
185
vendor/smb2/src/fuzzing.rs
vendored
Normal file
@@ -0,0 +1,185 @@
|
||||
//! Fuzzing entry points for `fuzz/` targets.
|
||||
//!
|
||||
//! This module is feature-gated behind `fuzzing` and only exists to give
|
||||
//! `cargo-fuzz` targets stable, public access to otherwise-internal parse
|
||||
//! functions. Applications must not depend on it -- it's unstable by
|
||||
//! design, and enabling the feature pulls in nothing of runtime value.
|
||||
//!
|
||||
//! Every function here takes untrusted bytes and returns either a parsed
|
||||
//! value or a clean typed error. No function here is allowed to panic on
|
||||
//! bad input; that's what the fuzzer tests.
|
||||
//!
|
||||
//! Targets (see `fuzz/fuzz_targets/`):
|
||||
//!
|
||||
//! - [`fuzz_header_parse`] -- SMB2 header (`msg::header::Header`).
|
||||
//! - [`fuzz_transform_header_parse`] -- encryption transform header.
|
||||
//! - [`fuzz_compression_transform_header_parse`] -- compression wrapper.
|
||||
//! - [`fuzz_compound_split`] -- `client::connection::split_compound`.
|
||||
//! - [`fuzz_frame_parse`] -- compound split + per-sub-frame header parse,
|
||||
//! which is the real receiver-loop path up to the body.
|
||||
//! - [`fuzz_sub_frame_parse`] -- header + body (dispatched by `Command`).
|
||||
//! - [`fuzz_negotiate_request_parse`] / [`fuzz_negotiate_response_parse`]
|
||||
//! - [`fuzz_create_request_parse`] / [`fuzz_create_response_parse`]
|
||||
//! -- CreateContext list lives inside these bodies.
|
||||
//! - [`fuzz_query_info_response_parse`] -- opaque output buffer sharp edge.
|
||||
//! - [`fuzz_dfs_referral_response_parse`] -- manual offset arithmetic,
|
||||
//! obvious fuzzing target.
|
||||
|
||||
use crate::msg::header::Header;
|
||||
use crate::msg::transform::{CompressionTransformHeader, TransformHeader};
|
||||
use crate::pack::{ReadCursor, Unpack};
|
||||
use crate::types::Command;
|
||||
|
||||
/// Fuzz the top-level SMB2 header parser.
|
||||
pub fn fuzz_header_parse(data: &[u8]) {
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let _ = Header::unpack(&mut cursor);
|
||||
}
|
||||
|
||||
/// Fuzz the encryption transform header parser.
|
||||
pub fn fuzz_transform_header_parse(data: &[u8]) {
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let _ = TransformHeader::unpack(&mut cursor);
|
||||
}
|
||||
|
||||
/// Fuzz the compression transform header parser.
|
||||
pub fn fuzz_compression_transform_header_parse(data: &[u8]) {
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let _ = CompressionTransformHeader::unpack(&mut cursor);
|
||||
}
|
||||
|
||||
/// Fuzz the compound-frame splitter. Takes a preprocessed (already decrypted
|
||||
/// and decompressed) buffer and returns the sub-frame byte slices.
|
||||
pub fn fuzz_compound_split(data: &[u8]) {
|
||||
let _ = crate::client::connection::split_compound(data);
|
||||
}
|
||||
|
||||
/// Fuzz the full receiver-loop parse path: compound split, plus parsing the
|
||||
/// header of every sub-frame. Mirrors what `prepare_sub_frame` does before
|
||||
/// it dispatches on `Command`.
|
||||
pub fn fuzz_frame_parse(data: &[u8]) {
|
||||
let subs = match crate::client::connection::split_compound(data) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return,
|
||||
};
|
||||
for sub in subs {
|
||||
let mut cursor = ReadCursor::new(&sub);
|
||||
let _ = Header::unpack(&mut cursor);
|
||||
}
|
||||
}
|
||||
|
||||
/// Fuzz header + body (dispatched by `Command`). Much wider surface than
|
||||
/// [`fuzz_frame_parse`] because it actually parses the response body for
|
||||
/// every command type.
|
||||
pub fn fuzz_sub_frame_parse(data: &[u8]) {
|
||||
if data.len() < Header::SIZE {
|
||||
return;
|
||||
}
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let header = match Header::unpack(&mut cursor) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
let body = &data[Header::SIZE..];
|
||||
let is_response = header.is_response();
|
||||
dispatch_body(header.command, is_response, body);
|
||||
}
|
||||
|
||||
fn dispatch_body(command: Command, is_response: bool, body: &[u8]) {
|
||||
use crate::msg;
|
||||
|
||||
// Unpack the given type from `body` and discard the result. Parse errors
|
||||
// are fine (boring path); panics / UB are what libfuzzer catches.
|
||||
macro_rules! try_unpack {
|
||||
($ty:ty) => {{
|
||||
let mut cursor = ReadCursor::new(body);
|
||||
let _ = <$ty as Unpack>::unpack(&mut cursor);
|
||||
}};
|
||||
}
|
||||
|
||||
match (command, is_response) {
|
||||
(Command::Negotiate, false) => try_unpack!(msg::negotiate::NegotiateRequest),
|
||||
(Command::Negotiate, true) => try_unpack!(msg::negotiate::NegotiateResponse),
|
||||
(Command::SessionSetup, false) => try_unpack!(msg::session_setup::SessionSetupRequest),
|
||||
(Command::SessionSetup, true) => try_unpack!(msg::session_setup::SessionSetupResponse),
|
||||
(Command::Logoff, false) => try_unpack!(msg::logoff::LogoffRequest),
|
||||
(Command::Logoff, true) => try_unpack!(msg::logoff::LogoffResponse),
|
||||
(Command::TreeConnect, false) => try_unpack!(msg::tree_connect::TreeConnectRequest),
|
||||
(Command::TreeConnect, true) => try_unpack!(msg::tree_connect::TreeConnectResponse),
|
||||
(Command::TreeDisconnect, false) => {
|
||||
try_unpack!(msg::tree_disconnect::TreeDisconnectRequest)
|
||||
}
|
||||
(Command::TreeDisconnect, true) => {
|
||||
try_unpack!(msg::tree_disconnect::TreeDisconnectResponse)
|
||||
}
|
||||
(Command::Create, false) => try_unpack!(msg::create::CreateRequest),
|
||||
(Command::Create, true) => try_unpack!(msg::create::CreateResponse),
|
||||
(Command::Close, false) => try_unpack!(msg::close::CloseRequest),
|
||||
(Command::Close, true) => try_unpack!(msg::close::CloseResponse),
|
||||
(Command::Flush, false) => try_unpack!(msg::flush::FlushRequest),
|
||||
(Command::Flush, true) => try_unpack!(msg::flush::FlushResponse),
|
||||
(Command::Read, false) => try_unpack!(msg::read::ReadRequest),
|
||||
(Command::Read, true) => try_unpack!(msg::read::ReadResponse),
|
||||
(Command::Write, false) => try_unpack!(msg::write::WriteRequest),
|
||||
(Command::Write, true) => try_unpack!(msg::write::WriteResponse),
|
||||
(Command::Lock, false) => try_unpack!(msg::lock::LockRequest),
|
||||
(Command::Lock, true) => try_unpack!(msg::lock::LockResponse),
|
||||
(Command::Ioctl, false) => try_unpack!(msg::ioctl::IoctlRequest),
|
||||
(Command::Ioctl, true) => try_unpack!(msg::ioctl::IoctlResponse),
|
||||
(Command::Cancel, false) => try_unpack!(msg::cancel::CancelRequest),
|
||||
(Command::Echo, false) => try_unpack!(msg::echo::EchoRequest),
|
||||
(Command::Echo, true) => try_unpack!(msg::echo::EchoResponse),
|
||||
(Command::QueryDirectory, false) => {
|
||||
try_unpack!(msg::query_directory::QueryDirectoryRequest)
|
||||
}
|
||||
(Command::QueryDirectory, true) => {
|
||||
try_unpack!(msg::query_directory::QueryDirectoryResponse)
|
||||
}
|
||||
(Command::ChangeNotify, false) => try_unpack!(msg::change_notify::ChangeNotifyRequest),
|
||||
(Command::ChangeNotify, true) => try_unpack!(msg::change_notify::ChangeNotifyResponse),
|
||||
(Command::QueryInfo, false) => try_unpack!(msg::query_info::QueryInfoRequest),
|
||||
(Command::QueryInfo, true) => try_unpack!(msg::query_info::QueryInfoResponse),
|
||||
(Command::SetInfo, false) => try_unpack!(msg::set_info::SetInfoRequest),
|
||||
(Command::SetInfo, true) => try_unpack!(msg::set_info::SetInfoResponse),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fuzz `NegotiateRequest::unpack` directly.
|
||||
pub fn fuzz_negotiate_request_parse(data: &[u8]) {
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let _ = crate::msg::negotiate::NegotiateRequest::unpack(&mut cursor);
|
||||
}
|
||||
|
||||
/// Fuzz `NegotiateResponse::unpack` directly. Covers negotiate-context parsing.
|
||||
pub fn fuzz_negotiate_response_parse(data: &[u8]) {
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let _ = crate::msg::negotiate::NegotiateResponse::unpack(&mut cursor);
|
||||
}
|
||||
|
||||
/// Fuzz `CreateRequest::unpack` directly. Covers create-context list parsing.
|
||||
pub fn fuzz_create_request_parse(data: &[u8]) {
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let _ = crate::msg::create::CreateRequest::unpack(&mut cursor);
|
||||
}
|
||||
|
||||
/// Fuzz `CreateResponse::unpack` directly.
|
||||
pub fn fuzz_create_response_parse(data: &[u8]) {
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let _ = crate::msg::create::CreateResponse::unpack(&mut cursor);
|
||||
}
|
||||
|
||||
/// Fuzz `QueryInfoResponse::unpack`, which has the tricky
|
||||
/// output-buffer-offset-from-header arithmetic.
|
||||
pub fn fuzz_query_info_response_parse(data: &[u8]) {
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let _ = crate::msg::query_info::QueryInfoResponse::unpack(&mut cursor);
|
||||
}
|
||||
|
||||
/// Fuzz the DFS referral response parser. Manual offset arithmetic makes
|
||||
/// this a classic sharp-edge target.
|
||||
pub fn fuzz_dfs_referral_response_parse(data: &[u8]) {
|
||||
let mut cursor = ReadCursor::new(data);
|
||||
let _ = crate::msg::dfs::RespGetDfsReferral::unpack(&mut cursor);
|
||||
}
|
||||
99
vendor/smb2/src/lib.rs
vendored
Normal file
99
vendor/smb2/src/lib.rs
vendored
Normal file
@@ -0,0 +1,99 @@
|
||||
#![forbid(unsafe_code)]
|
||||
#![warn(missing_docs)]
|
||||
|
||||
//! Pure-Rust SMB2/3 client library with pipelined I/O.
|
||||
//!
|
||||
//! No C dependencies, no FFI. Pipelined reads/writes fill the credit window
|
||||
//! so downloads run ~10-25x faster than sequential SMB clients.
|
||||
//!
|
||||
//! # Quick start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use smb2::{SmbClient, ClientConfig};
|
||||
//!
|
||||
//! # async fn example() -> Result<(), smb2::Error> {
|
||||
//! let mut client = smb2::connect("192.168.1.100:445", "user", "pass").await?;
|
||||
//!
|
||||
//! // List shares
|
||||
//! let shares = client.list_shares().await?;
|
||||
//!
|
||||
//! // Connect to a share
|
||||
//! let mut share = client.connect_share("Documents").await?;
|
||||
//!
|
||||
//! // List files
|
||||
//! let entries = client.list_directory(&mut share, "projects/").await?;
|
||||
//! for entry in &entries {
|
||||
//! println!("{} ({} bytes)", entry.name, entry.size);
|
||||
//! }
|
||||
//!
|
||||
//! // Read a file
|
||||
//! let data = client.read_file(&mut share, "report.pdf").await?;
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
//!
|
||||
//! # Modules
|
||||
//!
|
||||
//! - [`client`] -- High-level API: [`SmbClient`], [`Tree`], [`Pipeline`].
|
||||
//! This is what most users need.
|
||||
//! - [`error`] -- Error types and NTSTATUS mapping.
|
||||
//! - [`msg`] -- Wire format message structs (advanced/internal use).
|
||||
//! - [`pack`] -- Binary serialization primitives (advanced/internal use).
|
||||
//! - [`transport`] -- Transport trait and TCP implementation (advanced/internal use).
|
||||
//! - [`crypto`] -- Signing and encryption (advanced/internal use).
|
||||
//! - [`auth`] -- NTLM authentication (advanced/internal use).
|
||||
//! - [`rpc`] -- Named pipe RPC for share enumeration (advanced/internal use).
|
||||
//! - [`types`] -- Protocol newtypes and flag types (advanced/internal use).
|
||||
|
||||
pub mod auth;
|
||||
pub mod client;
|
||||
pub mod crypto;
|
||||
pub mod error;
|
||||
pub mod msg;
|
||||
pub mod pack;
|
||||
pub mod rpc;
|
||||
#[cfg(feature = "testing")]
|
||||
pub mod testing;
|
||||
pub mod transport;
|
||||
pub mod types;
|
||||
|
||||
#[cfg(feature = "fuzzing")]
|
||||
pub mod fuzzing;
|
||||
|
||||
// ── Re-exports: the simple-case imports ────────────────────────────────
|
||||
|
||||
// Error types
|
||||
pub use error::{Error, ErrorKind, Result};
|
||||
|
||||
// High-level client
|
||||
pub use client::{connect, ClientConfig, SmbClient};
|
||||
|
||||
// Streaming I/O
|
||||
pub use client::stream::{FileDownload, FileUpload, FileWriter, Progress};
|
||||
|
||||
// Tree and file types
|
||||
pub use client::tree::{DirectoryEntry, FileInfo, FsInfo, Tree};
|
||||
|
||||
// Pipeline
|
||||
pub use client::pipeline::{Op, OpResult, Pipeline};
|
||||
|
||||
// Connection-level types (useful for advanced users)
|
||||
pub use client::connection::{CompoundOp, Frame, NegotiatedParams};
|
||||
pub use client::session::Session;
|
||||
|
||||
// Diagnostics: snapshot tree returned by `SmbClient::diagnostics()` /
|
||||
// `Connection::diagnostics()`.
|
||||
pub use client::diagnostics::{
|
||||
ClientInfo, ClientMetricsSnapshot, CompressionInfo, ConnectionDiagnostics, CreditInfo,
|
||||
DfsCacheEntry, Diagnostics, EncryptionInfo, MetricsSnapshot, NegotiatedSummary,
|
||||
SessionDiagnostics, SigningInfo,
|
||||
};
|
||||
|
||||
// File watching
|
||||
pub use client::watcher::{FileNotifyAction, FileNotifyEvent, Watcher};
|
||||
|
||||
// Share enumeration
|
||||
pub use rpc::srvsvc::ShareInfo;
|
||||
|
||||
// Kerberos authentication
|
||||
pub use auth::kerberos::{KerberosAuthenticator, KerberosCredentials};
|
||||
44
vendor/smb2/src/msg/CLAUDE.md
vendored
Normal file
44
vendor/smb2/src/msg/CLAUDE.md
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
# Msg -- wire format message structs
|
||||
|
||||
One sub-module per SMB2 command. Each defines request and response structs with `Pack` and `Unpack` implementations.
|
||||
|
||||
## Key files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `mod.rs` | `trivial_message!` macro for 4-byte stub messages, module declarations |
|
||||
| `header.rs` | 64-byte SMB2 header (sync + async variants), `PROTOCOL_ID` (`0xFE 'S' 'M' 'B'`) |
|
||||
| `negotiate.rs` | Negotiate contexts (preauth integrity, encryption, signing, compression) |
|
||||
| `create.rs` | CREATE request/response with create contexts |
|
||||
| `transform.rs` | `TransformHeader` (encryption, protocol ID `0xFD`), `CompressionTransformHeader` (`0xFC`) |
|
||||
|
||||
19 command modules total: negotiate, session_setup, logoff, tree_connect, tree_disconnect, create, close, flush, read, write, lock, ioctl, query_directory, change_notify, query_info, set_info, echo, cancel, oplock_break. Plus `dfs.rs` for DFS referral request/response wire format (used by IOCTL FSCTL_DFS_GET_REFERRALS).
|
||||
|
||||
## Patterns
|
||||
|
||||
- **Pack/Unpack**: All structs implement `pack(&self, &mut WriteCursor)` and `unpack(&mut ReadCursor) -> Result<Self>`. Hand-rolled, no proc macros.
|
||||
- **Offset calculation**: All offsets in SMB2 are relative to the start of the SMB2 header (not the body, not the transport frame). When packing variable-length fields, compute `header_size + fixed_body_size` as the base offset.
|
||||
- **StructureSize validation**: `Unpack` implementations read `StructureSize` first and return an error if it doesn't match the expected value.
|
||||
- **`trivial_message!` macro**: Generates Pack/Unpack for 4-byte stub messages (StructureSize=4 + Reserved=0). Used by echo, cancel, logoff, tree_disconnect.
|
||||
|
||||
## Compound messages
|
||||
|
||||
Built by `Connection::send_compound`. Each sub-request's header has a `NextCommand` field pointing to the next message (8-byte aligned). The last message has `NextCommand = 0`. Related operations use `FileId::SENTINEL` (`0xFFFFFFFF:0xFFFFFFFF`) so the server substitutes the handle from the first CREATE.
|
||||
|
||||
## Transform headers
|
||||
|
||||
- **Encryption** (`0xFD 'S' 'M' 'B'`): 52-byte `TransformHeader` wraps encrypted message(s). Contains nonce, auth tag (signature), original message size, session ID.
|
||||
- **Compression** (`0xFC 'S' 'M' 'B'`): `CompressionTransformHeader` wraps LZ4-compressed messages. Contains original and compressed sizes, algorithm ID.
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **TCP framing is big-endian**: The 4-byte transport header (1 zero byte + 3-byte length) uses big-endian byte order. Everything inside the SMB2 message is little-endian. This is the only big-endian value in the entire protocol.
|
||||
- **StructureSize is "fixed"**: The spec says StructureSize is the size of the fixed-length portion of the struct. It does NOT include variable-length buffers. It's validated on unpack.
|
||||
- **`#![allow(missing_docs)]`**: This module opts out of doc requirements because wire format field names are self-documenting from the spec.
|
||||
- **Manual offset arithmetic requires careful bounds**: In `dfs.rs`, `parse_referral_entry` uses `ensure_remaining(buf, pos, N)` before raw `buf[pos..]` reads. Count the fixed fields carefully -- V2's body is **18** bytes (server_type+flags+proximity+ttl + three u16 offsets), not 16. An off-by-2 here lets a malformed `entry_size` slip past the initial guard and panic on the last offset read. Fuzz-caught in 0.7.2; regression test `resp_parse_v2_short_entry_returns_clean_error`.
|
||||
|
||||
## Fuzzing
|
||||
|
||||
Parse entry points are exposed via the `fuzzing` feature (`smb2::fuzzing`) and exercised by the `fuzz/` crate. See
|
||||
`fuzz/README.md` (if present) or run `just fuzz fuzz_header_parse 300` for a local sweep. Every new parser touching
|
||||
external bytes should get a fuzz target wrapper added in `src/fuzzing.rs` and a matching `fuzz/fuzz_targets/*.rs`.
|
||||
27
vendor/smb2/src/msg/cancel.rs
vendored
Normal file
27
vendor/smb2/src/msg/cancel.rs
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
//! SMB2 CANCEL request (spec section 2.2.30).
|
||||
//!
|
||||
//! The CANCEL request is fire-and-forget: the client sends it to cancel a
|
||||
//! previously sent message, and there is no corresponding response message.
|
||||
//! The MessageId of the request to cancel is set in the SMB2 header.
|
||||
|
||||
super::trivial_message! {
|
||||
/// SMB2 CANCEL request (spec section 2.2.30).
|
||||
///
|
||||
/// Sent by the client to cancel a previously sent message on the same
|
||||
/// transport connection. There is no response for this command.
|
||||
/// Contains only StructureSize (2 bytes) and Reserved (2 bytes).
|
||||
pub struct CancelRequest;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
super::super::trivial_message_tests!(
|
||||
CancelRequest,
|
||||
cancel_request_known_bytes,
|
||||
cancel_request_roundtrip,
|
||||
cancel_request_wrong_structure_size,
|
||||
cancel_request_too_short
|
||||
);
|
||||
}
|
||||
355
vendor/smb2/src/msg/change_notify.rs
vendored
Normal file
355
vendor/smb2/src/msg/change_notify.rs
vendored
Normal file
@@ -0,0 +1,355 @@
|
||||
//! SMB2 CHANGE_NOTIFY Request and Response (MS-SMB2 sections 2.2.35, 2.2.36).
|
||||
//!
|
||||
//! The CHANGE_NOTIFY request registers for change notifications on a
|
||||
//! directory. The response returns FILE_NOTIFY_INFORMATION entries
|
||||
//! describing the changes that occurred.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
// ── Change Notify flags ────────────────────────────────────────────────
|
||||
|
||||
/// Watch the entire subtree (recursive).
|
||||
pub const SMB2_WATCH_TREE: u16 = 0x0001;
|
||||
|
||||
// ── CompletionFilter values ────────────────────────────────────────────
|
||||
|
||||
/// Notify when a file name changes.
|
||||
pub const FILE_NOTIFY_CHANGE_FILE_NAME: u32 = 0x0000_0001;
|
||||
|
||||
/// Notify when a directory name changes.
|
||||
pub const FILE_NOTIFY_CHANGE_DIR_NAME: u32 = 0x0000_0002;
|
||||
|
||||
/// Notify when file attributes change.
|
||||
pub const FILE_NOTIFY_CHANGE_ATTRIBUTES: u32 = 0x0000_0004;
|
||||
|
||||
/// Notify when the file size changes.
|
||||
pub const FILE_NOTIFY_CHANGE_SIZE: u32 = 0x0000_0008;
|
||||
|
||||
/// Notify when the last write time changes.
|
||||
pub const FILE_NOTIFY_CHANGE_LAST_WRITE: u32 = 0x0000_0010;
|
||||
|
||||
/// Notify when the last access time changes.
|
||||
pub const FILE_NOTIFY_CHANGE_LAST_ACCESS: u32 = 0x0000_0020;
|
||||
|
||||
/// Notify when the creation time changes.
|
||||
pub const FILE_NOTIFY_CHANGE_CREATION: u32 = 0x0000_0040;
|
||||
|
||||
/// Notify when extended attributes change.
|
||||
pub const FILE_NOTIFY_CHANGE_EA: u32 = 0x0000_0080;
|
||||
|
||||
/// Notify when the security descriptor changes.
|
||||
pub const FILE_NOTIFY_CHANGE_SECURITY: u32 = 0x0000_0100;
|
||||
|
||||
/// Notify when a stream name changes.
|
||||
pub const FILE_NOTIFY_CHANGE_STREAM_NAME: u32 = 0x0000_0200;
|
||||
|
||||
/// Notify when a stream size changes.
|
||||
pub const FILE_NOTIFY_CHANGE_STREAM_SIZE: u32 = 0x0000_0400;
|
||||
|
||||
/// Notify when stream data is written.
|
||||
pub const FILE_NOTIFY_CHANGE_STREAM_WRITE: u32 = 0x0000_0800;
|
||||
|
||||
// ── ChangeNotifyRequest ────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 CHANGE_NOTIFY Request (MS-SMB2 section 2.2.35).
|
||||
///
|
||||
/// Registers for directory change notifications. The structure is 32 bytes:
|
||||
/// - StructureSize (2 bytes, must be 32)
|
||||
/// - Flags (2 bytes)
|
||||
/// - OutputBufferLength (4 bytes)
|
||||
/// - FileId (16 bytes)
|
||||
/// - CompletionFilter (4 bytes)
|
||||
/// - Reserved (4 bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ChangeNotifyRequest {
|
||||
/// Flags controlling the notification. Use `SMB2_WATCH_TREE` for recursive.
|
||||
pub flags: u16,
|
||||
/// Maximum size of the output buffer for notification data.
|
||||
pub output_buffer_length: u32,
|
||||
/// The directory handle to watch.
|
||||
pub file_id: FileId,
|
||||
/// Bitmask of change types to watch for.
|
||||
pub completion_filter: u32,
|
||||
}
|
||||
|
||||
impl ChangeNotifyRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 32;
|
||||
}
|
||||
|
||||
impl Pack for ChangeNotifyRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// Flags (2 bytes)
|
||||
cursor.write_u16_le(self.flags);
|
||||
// OutputBufferLength (4 bytes)
|
||||
cursor.write_u32_le(self.output_buffer_length);
|
||||
// FileId (16 bytes)
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
// CompletionFilter (4 bytes)
|
||||
cursor.write_u32_le(self.completion_filter);
|
||||
// Reserved (4 bytes)
|
||||
cursor.write_u32_le(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for ChangeNotifyRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid ChangeNotifyRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let flags = cursor.read_u16_le()?;
|
||||
let output_buffer_length = cursor.read_u32_le()?;
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
let completion_filter = cursor.read_u32_le()?;
|
||||
let _reserved = cursor.read_u32_le()?;
|
||||
|
||||
Ok(ChangeNotifyRequest {
|
||||
flags,
|
||||
output_buffer_length,
|
||||
file_id: FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
},
|
||||
completion_filter,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── ChangeNotifyResponse ───────────────────────────────────────────────
|
||||
|
||||
/// SMB2 CHANGE_NOTIFY Response (MS-SMB2 section 2.2.36).
|
||||
///
|
||||
/// Returns FILE_NOTIFY_INFORMATION entries describing directory changes.
|
||||
/// The buffer contains raw FILE_NOTIFY_INFORMATION entries; parsing those
|
||||
/// is left to the caller for now.
|
||||
///
|
||||
/// Layout:
|
||||
/// - StructureSize (2 bytes, must be 9)
|
||||
/// - OutputBufferOffset (2 bytes)
|
||||
/// - OutputBufferLength (4 bytes)
|
||||
/// - Buffer (variable, OutputBufferLength bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ChangeNotifyResponse {
|
||||
/// Raw FILE_NOTIFY_INFORMATION data. Parsing individual entries is
|
||||
/// deferred to a higher layer.
|
||||
pub output_data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl ChangeNotifyResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 9;
|
||||
|
||||
/// Fixed header size before the variable buffer (8 bytes).
|
||||
const FIXED_SIZE: u32 = 8;
|
||||
}
|
||||
|
||||
impl Pack for ChangeNotifyResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
|
||||
let output_len = self.output_data.len() as u32;
|
||||
// Offset is from the beginning of the SMB2 header per spec.
|
||||
let output_offset = if output_len > 0 {
|
||||
(start as u32) + Self::FIXED_SIZE
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// OutputBufferOffset (2 bytes)
|
||||
cursor.write_u16_le(output_offset as u16);
|
||||
// OutputBufferLength (4 bytes)
|
||||
cursor.write_u32_le(output_len);
|
||||
// Buffer (variable)
|
||||
cursor.write_bytes(&self.output_data);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for ChangeNotifyResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid ChangeNotifyResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let _output_buffer_offset = cursor.read_u16_le()?;
|
||||
let output_buffer_length = cursor.read_u32_le()?;
|
||||
|
||||
let output_data = if output_buffer_length > 0 {
|
||||
cursor
|
||||
.read_bytes_bounded(output_buffer_length as usize)?
|
||||
.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(ChangeNotifyResponse { output_data })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── ChangeNotifyRequest tests ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn change_notify_request_roundtrip_recursive() {
|
||||
let original = ChangeNotifyRequest {
|
||||
flags: SMB2_WATCH_TREE,
|
||||
output_buffer_length: 65536,
|
||||
file_id: FileId {
|
||||
persistent: 0x1122_3344_5566_7788,
|
||||
volatile: 0xAABB_CCDD_EEFF_0011,
|
||||
},
|
||||
completion_filter: FILE_NOTIFY_CHANGE_FILE_NAME
|
||||
| FILE_NOTIFY_CHANGE_DIR_NAME
|
||||
| FILE_NOTIFY_CHANGE_LAST_WRITE,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed 32 bytes, no variable data
|
||||
assert_eq!(bytes.len(), 32);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ChangeNotifyRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.flags, SMB2_WATCH_TREE);
|
||||
assert_eq!(decoded.output_buffer_length, 65536);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert_eq!(
|
||||
decoded.completion_filter,
|
||||
FILE_NOTIFY_CHANGE_FILE_NAME
|
||||
| FILE_NOTIFY_CHANGE_DIR_NAME
|
||||
| FILE_NOTIFY_CHANGE_LAST_WRITE
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn change_notify_request_wrong_structure_size() {
|
||||
let mut buf = [0u8; 32];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = ChangeNotifyRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── ChangeNotifyResponse tests ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn change_notify_response_roundtrip_with_data() {
|
||||
let notify_data = vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
|
||||
let original = ChangeNotifyResponse {
|
||||
output_data: notify_data.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed 8 bytes + 8 bytes data
|
||||
assert_eq!(bytes.len(), 16);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ChangeNotifyResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.output_data, notify_data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn change_notify_response_roundtrip_empty() {
|
||||
let original = ChangeNotifyResponse {
|
||||
output_data: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes.len(), 8);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ChangeNotifyResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.output_data.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn change_notify_response_wrong_structure_size() {
|
||||
let mut buf = [0u8; 8];
|
||||
buf[0..2].copy_from_slice(&42u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = ChangeNotifyResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn change_notify_request_pack_unpack(
|
||||
flags in any::<u16>(),
|
||||
output_buffer_length in any::<u32>(),
|
||||
file_id in arb_file_id(),
|
||||
completion_filter in any::<u32>(),
|
||||
) {
|
||||
let original = ChangeNotifyRequest {
|
||||
flags,
|
||||
output_buffer_length,
|
||||
file_id,
|
||||
completion_filter,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ChangeNotifyRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn change_notify_response_pack_unpack(output_data in arb_bytes()) {
|
||||
let original = ChangeNotifyResponse { output_data };
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ChangeNotifyResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
390
vendor/smb2/src/msg/close.rs
vendored
Normal file
390
vendor/smb2/src/msg/close.rs
vendored
Normal file
@@ -0,0 +1,390 @@
|
||||
//! SMB2 CLOSE Request and Response (MS-SMB2 sections 2.2.15, 2.2.16).
|
||||
//!
|
||||
//! The CLOSE request closes a file handle previously opened via CREATE.
|
||||
//! The response optionally returns file attributes if the
|
||||
//! `SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB` flag was set.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{FileTime, Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
/// Close flag: request that the server returns file attributes in the response.
|
||||
pub const SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB: u16 = 0x0001;
|
||||
|
||||
/// SMB2 CLOSE Request (MS-SMB2 section 2.2.15).
|
||||
///
|
||||
/// Sent by the client to close a file handle. The structure is 24 bytes:
|
||||
/// - StructureSize (2 bytes, must be 24)
|
||||
/// - Flags (2 bytes)
|
||||
/// - Reserved (4 bytes)
|
||||
/// - FileId (16 bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CloseRequest {
|
||||
/// Flags indicating how to process the close.
|
||||
/// Use `SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB` to request attributes.
|
||||
pub flags: u16,
|
||||
/// The file handle to close.
|
||||
pub file_id: FileId,
|
||||
}
|
||||
|
||||
impl CloseRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 24;
|
||||
}
|
||||
|
||||
impl Pack for CloseRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// Flags (2 bytes)
|
||||
cursor.write_u16_le(self.flags);
|
||||
// Reserved (4 bytes)
|
||||
cursor.write_u32_le(0);
|
||||
// FileId (16 bytes): persistent + volatile
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for CloseRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid CloseRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let flags = cursor.read_u16_le()?;
|
||||
let _reserved = cursor.read_u32_le()?;
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
|
||||
Ok(CloseRequest {
|
||||
flags,
|
||||
file_id: FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// SMB2 CLOSE Response (MS-SMB2 section 2.2.16).
|
||||
///
|
||||
/// Sent by the server to confirm a close. The structure is 60 bytes:
|
||||
/// - StructureSize (2 bytes, must be 60)
|
||||
/// - Flags (2 bytes)
|
||||
/// - Reserved (4 bytes)
|
||||
/// - CreationTime (8 bytes)
|
||||
/// - LastAccessTime (8 bytes)
|
||||
/// - LastWriteTime (8 bytes)
|
||||
/// - ChangeTime (8 bytes)
|
||||
/// - AllocationSize (8 bytes)
|
||||
/// - EndOfFile (8 bytes)
|
||||
/// - FileAttributes (4 bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CloseResponse {
|
||||
/// Flags echoed from the request. If `SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB`
|
||||
/// is set, the attribute fields below contain valid data.
|
||||
pub flags: u16,
|
||||
/// File creation time.
|
||||
pub creation_time: FileTime,
|
||||
/// Last access time.
|
||||
pub last_access_time: FileTime,
|
||||
/// Last write time.
|
||||
pub last_write_time: FileTime,
|
||||
/// Change time.
|
||||
pub change_time: FileTime,
|
||||
/// Size of allocated data in bytes.
|
||||
pub allocation_size: u64,
|
||||
/// End-of-file position in bytes.
|
||||
pub end_of_file: u64,
|
||||
/// File attributes (see MS-FSCC section 2.6).
|
||||
pub file_attributes: u32,
|
||||
}
|
||||
|
||||
impl CloseResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 60;
|
||||
}
|
||||
|
||||
impl Pack for CloseResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
cursor.write_u16_le(self.flags);
|
||||
cursor.write_u32_le(0); // Reserved
|
||||
self.creation_time.pack(cursor);
|
||||
self.last_access_time.pack(cursor);
|
||||
self.last_write_time.pack(cursor);
|
||||
self.change_time.pack(cursor);
|
||||
cursor.write_u64_le(self.allocation_size);
|
||||
cursor.write_u64_le(self.end_of_file);
|
||||
cursor.write_u32_le(self.file_attributes);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for CloseResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid CloseResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let flags = cursor.read_u16_le()?;
|
||||
let _reserved = cursor.read_u32_le()?;
|
||||
let creation_time = FileTime::unpack(cursor)?;
|
||||
let last_access_time = FileTime::unpack(cursor)?;
|
||||
let last_write_time = FileTime::unpack(cursor)?;
|
||||
let change_time = FileTime::unpack(cursor)?;
|
||||
let allocation_size = cursor.read_u64_le()?;
|
||||
let end_of_file = cursor.read_u64_le()?;
|
||||
let file_attributes = cursor.read_u32_le()?;
|
||||
|
||||
Ok(CloseResponse {
|
||||
flags,
|
||||
creation_time,
|
||||
last_access_time,
|
||||
last_write_time,
|
||||
change_time,
|
||||
allocation_size,
|
||||
end_of_file,
|
||||
file_attributes,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── CloseRequest tests ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn close_request_roundtrip() {
|
||||
let original = CloseRequest {
|
||||
flags: SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB,
|
||||
file_id: FileId {
|
||||
persistent: 0x1122_3344_5566_7788,
|
||||
volatile: 0xAABB_CCDD_EEFF_0011,
|
||||
},
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// 2 + 2 + 4 + 16 = 24 bytes
|
||||
assert_eq!(bytes.len(), 24);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CloseRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.flags, original.flags);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn close_request_known_bytes() {
|
||||
let mut buf = [0u8; 24];
|
||||
// StructureSize = 24
|
||||
buf[0..2].copy_from_slice(&24u16.to_le_bytes());
|
||||
// Flags = 0x0001
|
||||
buf[2..4].copy_from_slice(&1u16.to_le_bytes());
|
||||
// Reserved = 0
|
||||
buf[4..8].copy_from_slice(&0u32.to_le_bytes());
|
||||
// FileId persistent = 0x42
|
||||
buf[8..16].copy_from_slice(&0x42u64.to_le_bytes());
|
||||
// FileId volatile = 0x99
|
||||
buf[16..24].copy_from_slice(&0x99u64.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let req = CloseRequest::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(req.flags, SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB);
|
||||
assert_eq!(req.file_id.persistent, 0x42);
|
||||
assert_eq!(req.file_id.volatile, 0x99);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn close_request_wrong_structure_size() {
|
||||
let mut buf = [0u8; 24];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = CloseRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── CloseResponse tests ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn close_response_roundtrip() {
|
||||
let original = CloseResponse {
|
||||
flags: SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB,
|
||||
creation_time: FileTime(0x01D8_AAAA_BBBB_CCCC),
|
||||
last_access_time: FileTime(0x01D8_DDDD_EEEE_FFFF),
|
||||
last_write_time: FileTime(0x01D8_1111_2222_3333),
|
||||
change_time: FileTime(0x01D8_4444_5555_6666),
|
||||
allocation_size: 4096,
|
||||
end_of_file: 2048,
|
||||
file_attributes: 0x20, // FILE_ATTRIBUTE_ARCHIVE
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// 2 + 2 + 4 + 8*6 + 4 = 60 bytes
|
||||
assert_eq!(bytes.len(), 60);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CloseResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.flags, original.flags);
|
||||
assert_eq!(decoded.creation_time, original.creation_time);
|
||||
assert_eq!(decoded.last_access_time, original.last_access_time);
|
||||
assert_eq!(decoded.last_write_time, original.last_write_time);
|
||||
assert_eq!(decoded.change_time, original.change_time);
|
||||
assert_eq!(decoded.allocation_size, original.allocation_size);
|
||||
assert_eq!(decoded.end_of_file, original.end_of_file);
|
||||
assert_eq!(decoded.file_attributes, original.file_attributes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn close_response_known_bytes() {
|
||||
let mut buf = [0u8; 60];
|
||||
// StructureSize = 60
|
||||
buf[0..2].copy_from_slice(&60u16.to_le_bytes());
|
||||
// Flags = 0x0001
|
||||
buf[2..4].copy_from_slice(&1u16.to_le_bytes());
|
||||
// Reserved = 0
|
||||
buf[4..8].copy_from_slice(&0u32.to_le_bytes());
|
||||
// CreationTime = 100
|
||||
buf[8..16].copy_from_slice(&100u64.to_le_bytes());
|
||||
// LastAccessTime = 200
|
||||
buf[16..24].copy_from_slice(&200u64.to_le_bytes());
|
||||
// LastWriteTime = 300
|
||||
buf[24..32].copy_from_slice(&300u64.to_le_bytes());
|
||||
// ChangeTime = 400
|
||||
buf[32..40].copy_from_slice(&400u64.to_le_bytes());
|
||||
// AllocationSize = 8192
|
||||
buf[40..48].copy_from_slice(&8192u64.to_le_bytes());
|
||||
// EndOfFile = 1024
|
||||
buf[48..56].copy_from_slice(&1024u64.to_le_bytes());
|
||||
// FileAttributes = 0x10 (directory)
|
||||
buf[56..60].copy_from_slice(&0x10u32.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let resp = CloseResponse::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(resp.flags, SMB2_CLOSE_FLAG_POSTQUERY_ATTRIB);
|
||||
assert_eq!(resp.creation_time, FileTime(100));
|
||||
assert_eq!(resp.last_access_time, FileTime(200));
|
||||
assert_eq!(resp.last_write_time, FileTime(300));
|
||||
assert_eq!(resp.change_time, FileTime(400));
|
||||
assert_eq!(resp.allocation_size, 8192);
|
||||
assert_eq!(resp.end_of_file, 1024);
|
||||
assert_eq!(resp.file_attributes, 0x10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn close_response_wrong_structure_size() {
|
||||
let mut buf = [0u8; 60];
|
||||
buf[0..2].copy_from_slice(&42u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = CloseResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn close_response_zero_flags_has_zeroed_attributes() {
|
||||
let original = CloseResponse {
|
||||
flags: 0,
|
||||
creation_time: FileTime::ZERO,
|
||||
last_access_time: FileTime::ZERO,
|
||||
last_write_time: FileTime::ZERO,
|
||||
change_time: FileTime::ZERO,
|
||||
allocation_size: 0,
|
||||
end_of_file: 0,
|
||||
file_attributes: 0,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CloseResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.flags, 0);
|
||||
assert_eq!(decoded.creation_time, FileTime::ZERO);
|
||||
assert_eq!(decoded.file_attributes, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{arb_file_id, arb_file_time};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn close_request_pack_unpack(
|
||||
flags in any::<u16>(),
|
||||
file_id in arb_file_id(),
|
||||
) {
|
||||
let original = CloseRequest { flags, file_id };
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CloseRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn close_response_pack_unpack(
|
||||
flags in any::<u16>(),
|
||||
creation_time in arb_file_time(),
|
||||
last_access_time in arb_file_time(),
|
||||
last_write_time in arb_file_time(),
|
||||
change_time in arb_file_time(),
|
||||
allocation_size in any::<u64>(),
|
||||
end_of_file in any::<u64>(),
|
||||
file_attributes in any::<u32>(),
|
||||
) {
|
||||
let original = CloseResponse {
|
||||
flags,
|
||||
creation_time,
|
||||
last_access_time,
|
||||
last_write_time,
|
||||
change_time,
|
||||
allocation_size,
|
||||
end_of_file,
|
||||
file_attributes,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CloseResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
870
vendor/smb2/src/msg/create.rs
vendored
Normal file
870
vendor/smb2/src/msg/create.rs
vendored
Normal file
@@ -0,0 +1,870 @@
|
||||
//! SMB2 CREATE request and response (spec sections 2.2.13, 2.2.14).
|
||||
//!
|
||||
//! The CREATE request opens or creates a file, named pipe, or printer.
|
||||
//! The response carries the file handle ([`FileId`]) plus timestamps,
|
||||
//! attributes, and optional create contexts.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::msg::header::Header;
|
||||
use crate::pack::{FileTime, Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::flags::FileAccessMask;
|
||||
use crate::types::{FileId, OplockLevel};
|
||||
use crate::Error;
|
||||
|
||||
// ── Enums ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Impersonation level (MS-SMB2 2.2.13).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u32)]
|
||||
pub enum ImpersonationLevel {
|
||||
/// Anonymous impersonation.
|
||||
Anonymous = 0,
|
||||
/// Identification impersonation.
|
||||
Identification = 1,
|
||||
/// Impersonation level.
|
||||
Impersonation = 2,
|
||||
/// Delegate impersonation.
|
||||
Delegate = 3,
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for ImpersonationLevel {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: u32) -> Result<Self> {
|
||||
match value {
|
||||
0 => Ok(Self::Anonymous),
|
||||
1 => Ok(Self::Identification),
|
||||
2 => Ok(Self::Impersonation),
|
||||
3 => Ok(Self::Delegate),
|
||||
_ => Err(Error::invalid_data(format!(
|
||||
"invalid ImpersonationLevel: {}",
|
||||
value
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Share access flags (MS-SMB2 2.2.13).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct ShareAccess(pub u32);
|
||||
|
||||
impl ShareAccess {
|
||||
/// Allow other opens to read the file.
|
||||
pub const FILE_SHARE_READ: u32 = 0x0000_0001;
|
||||
/// Allow other opens to write the file.
|
||||
pub const FILE_SHARE_WRITE: u32 = 0x0000_0002;
|
||||
/// Allow other opens to delete the file.
|
||||
pub const FILE_SHARE_DELETE: u32 = 0x0000_0004;
|
||||
}
|
||||
|
||||
/// Create disposition (MS-SMB2 2.2.13).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u32)]
|
||||
pub enum CreateDisposition {
|
||||
/// If the file exists, supersede it. Otherwise, create.
|
||||
FileSupersede = 0,
|
||||
/// If the file exists, open it. Otherwise, fail.
|
||||
FileOpen = 1,
|
||||
/// If the file exists, fail. Otherwise, create.
|
||||
FileCreate = 2,
|
||||
/// If the file exists, open it. Otherwise, create.
|
||||
FileOpenIf = 3,
|
||||
/// If the file exists, overwrite it. Otherwise, fail.
|
||||
FileOverwrite = 4,
|
||||
/// If the file exists, overwrite it. Otherwise, create.
|
||||
FileOverwriteIf = 5,
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for CreateDisposition {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: u32) -> Result<Self> {
|
||||
match value {
|
||||
0 => Ok(Self::FileSupersede),
|
||||
1 => Ok(Self::FileOpen),
|
||||
2 => Ok(Self::FileCreate),
|
||||
3 => Ok(Self::FileOpenIf),
|
||||
4 => Ok(Self::FileOverwrite),
|
||||
5 => Ok(Self::FileOverwriteIf),
|
||||
_ => Err(Error::invalid_data(format!(
|
||||
"invalid CreateDisposition: {}",
|
||||
value
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create action returned in the response (MS-SMB2 2.2.14).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u32)]
|
||||
pub enum CreateAction {
|
||||
/// An existing file was superseded.
|
||||
FileSuperseded = 0,
|
||||
/// An existing file was opened.
|
||||
FileOpened = 1,
|
||||
/// A new file was created.
|
||||
FileCreated = 2,
|
||||
/// An existing file was overwritten.
|
||||
FileOverwritten = 3,
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for CreateAction {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: u32) -> Result<Self> {
|
||||
match value {
|
||||
0 => Ok(Self::FileSuperseded),
|
||||
1 => Ok(Self::FileOpened),
|
||||
2 => Ok(Self::FileCreated),
|
||||
3 => Ok(Self::FileOverwritten),
|
||||
_ => Err(Error::invalid_data(format!(
|
||||
"invalid CreateAction: {}",
|
||||
value
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── CreateRequest ────────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 CREATE request (spec section 2.2.13).
|
||||
///
|
||||
/// Sent by the client to open or create a file on the server.
|
||||
/// The buffer contains the filename encoded as UTF-16LE, optionally
|
||||
/// followed by create context data.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CreateRequest {
|
||||
/// Requested oplock level.
|
||||
pub requested_oplock_level: OplockLevel,
|
||||
/// Impersonation level.
|
||||
pub impersonation_level: ImpersonationLevel,
|
||||
/// Desired access rights.
|
||||
pub desired_access: FileAccessMask,
|
||||
/// File attributes for create/open.
|
||||
pub file_attributes: u32,
|
||||
/// Sharing mode.
|
||||
pub share_access: ShareAccess,
|
||||
/// Disposition: what to do if file exists/does not exist.
|
||||
pub create_disposition: CreateDisposition,
|
||||
/// Create options flags.
|
||||
pub create_options: u32,
|
||||
/// The filename to create or open.
|
||||
pub name: String,
|
||||
/// Raw create context bytes (unparsed).
|
||||
pub create_contexts: Vec<u8>,
|
||||
}
|
||||
|
||||
impl CreateRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 57;
|
||||
}
|
||||
|
||||
impl Pack for CreateRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// SecurityFlags (1 byte) -- must be 0
|
||||
cursor.write_u8(0);
|
||||
// RequestedOplockLevel (1 byte)
|
||||
cursor.write_u8(self.requested_oplock_level as u8);
|
||||
// ImpersonationLevel (4 bytes)
|
||||
cursor.write_u32_le(self.impersonation_level as u32);
|
||||
// SmbCreateFlags (8 bytes) -- must be 0
|
||||
cursor.write_u64_le(0);
|
||||
// Reserved (8 bytes)
|
||||
cursor.write_u64_le(0);
|
||||
// DesiredAccess (4 bytes)
|
||||
cursor.write_u32_le(self.desired_access.bits());
|
||||
// FileAttributes (4 bytes)
|
||||
cursor.write_u32_le(self.file_attributes);
|
||||
// ShareAccess (4 bytes)
|
||||
cursor.write_u32_le(self.share_access.0);
|
||||
// CreateDisposition (4 bytes)
|
||||
cursor.write_u32_le(self.create_disposition as u32);
|
||||
// CreateOptions (4 bytes)
|
||||
cursor.write_u32_le(self.create_options);
|
||||
|
||||
// NameOffset (2 bytes) -- placeholder, backpatch later
|
||||
let name_offset_pos = cursor.position();
|
||||
cursor.write_u16_le(0);
|
||||
// NameLength (2 bytes) -- placeholder, backpatch later
|
||||
let name_length_pos = cursor.position();
|
||||
cursor.write_u16_le(0);
|
||||
// CreateContextsOffset (4 bytes) -- placeholder
|
||||
let ctx_offset_pos = cursor.position();
|
||||
cursor.write_u32_le(0);
|
||||
// CreateContextsLength (4 bytes) -- placeholder
|
||||
let ctx_length_pos = cursor.position();
|
||||
cursor.write_u32_le(0);
|
||||
|
||||
// Buffer: filename in UTF-16LE
|
||||
// Offsets are from the beginning of the SMB2 header per spec.
|
||||
let name_offset = Header::SIZE + (cursor.position() - start);
|
||||
let name_start = cursor.position();
|
||||
cursor.write_utf16_le(&self.name);
|
||||
let name_byte_len = cursor.position() - name_start;
|
||||
|
||||
// Backpatch name offset and length
|
||||
cursor.set_u16_le_at(name_offset_pos, name_offset as u16);
|
||||
cursor.set_u16_le_at(name_length_pos, name_byte_len as u16);
|
||||
|
||||
// Create contexts (if any)
|
||||
if !self.create_contexts.is_empty() {
|
||||
// Align to 8-byte boundary before create contexts
|
||||
cursor.align_to(8);
|
||||
let ctx_offset = Header::SIZE + (cursor.position() - start);
|
||||
cursor.write_bytes(&self.create_contexts);
|
||||
let ctx_len = self.create_contexts.len();
|
||||
|
||||
cursor.set_u32_le_at(ctx_offset_pos, ctx_offset as u32);
|
||||
cursor.set_u32_le_at(ctx_length_pos, ctx_len as u32);
|
||||
} else if name_byte_len == 0 {
|
||||
// Per spec, buffer must be at least 1 byte even if name is empty
|
||||
cursor.write_u8(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for CreateRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid CreateRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// SecurityFlags (1 byte)
|
||||
let _security_flags = cursor.read_u8()?;
|
||||
// RequestedOplockLevel (1 byte)
|
||||
let oplock_raw = cursor.read_u8()?;
|
||||
let requested_oplock_level = OplockLevel::try_from(oplock_raw)?;
|
||||
// ImpersonationLevel (4 bytes)
|
||||
let imp_raw = cursor.read_u32_le()?;
|
||||
let impersonation_level = ImpersonationLevel::try_from(imp_raw)?;
|
||||
// SmbCreateFlags (8 bytes)
|
||||
let _smb_create_flags = cursor.read_u64_le()?;
|
||||
// Reserved (8 bytes)
|
||||
let _reserved = cursor.read_u64_le()?;
|
||||
// DesiredAccess (4 bytes)
|
||||
let desired_access = FileAccessMask::new(cursor.read_u32_le()?);
|
||||
// FileAttributes (4 bytes)
|
||||
let file_attributes = cursor.read_u32_le()?;
|
||||
// ShareAccess (4 bytes)
|
||||
let share_access = ShareAccess(cursor.read_u32_le()?);
|
||||
// CreateDisposition (4 bytes)
|
||||
let disp_raw = cursor.read_u32_le()?;
|
||||
let create_disposition = CreateDisposition::try_from(disp_raw)?;
|
||||
// CreateOptions (4 bytes)
|
||||
let create_options = cursor.read_u32_le()?;
|
||||
// NameOffset (2 bytes)
|
||||
let name_offset = cursor.read_u16_le()? as usize;
|
||||
// NameLength (2 bytes)
|
||||
let name_length = cursor.read_u16_le()? as usize;
|
||||
// CreateContextsOffset (4 bytes)
|
||||
let ctx_offset = cursor.read_u32_le()? as usize;
|
||||
// CreateContextsLength (4 bytes)
|
||||
let ctx_length = cursor.read_u32_le()? as usize;
|
||||
|
||||
// Read filename
|
||||
// Offsets on the wire are from the beginning of the SMB2 header,
|
||||
// so subtract Header::SIZE to get position within the body.
|
||||
let name = if name_length > 0 {
|
||||
let current = cursor.position();
|
||||
let body_offset = name_offset.saturating_sub(Header::SIZE);
|
||||
let target = start + body_offset;
|
||||
if target > current {
|
||||
cursor.skip(target - current)?;
|
||||
}
|
||||
cursor.read_utf16_le(name_length)?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Read create contexts
|
||||
let create_contexts = if ctx_length > 0 {
|
||||
let current = cursor.position();
|
||||
let body_offset = ctx_offset.saturating_sub(Header::SIZE);
|
||||
let target = start + body_offset;
|
||||
if target > current {
|
||||
cursor.skip(target - current)?;
|
||||
}
|
||||
cursor.read_bytes_bounded(ctx_length)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(CreateRequest {
|
||||
requested_oplock_level,
|
||||
impersonation_level,
|
||||
desired_access,
|
||||
file_attributes,
|
||||
share_access,
|
||||
create_disposition,
|
||||
create_options,
|
||||
name,
|
||||
create_contexts,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── CreateResponse ───────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 CREATE response (spec section 2.2.14).
|
||||
///
|
||||
/// Returned by the server with the file handle and metadata about
|
||||
/// the created or opened file.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CreateResponse {
|
||||
/// Oplock level granted by the server.
|
||||
pub oplock_level: OplockLevel,
|
||||
/// Flags (SMB 3.x only).
|
||||
pub flags: u8,
|
||||
/// Action taken by the server (opened, created, etc.).
|
||||
pub create_action: CreateAction,
|
||||
/// Time the file was created.
|
||||
pub creation_time: FileTime,
|
||||
/// Time the file was last accessed.
|
||||
pub last_access_time: FileTime,
|
||||
/// Time the file was last written.
|
||||
pub last_write_time: FileTime,
|
||||
/// Time the file metadata was last changed.
|
||||
pub change_time: FileTime,
|
||||
/// Allocation size of the file in bytes.
|
||||
pub allocation_size: u64,
|
||||
/// End-of-file position (actual file size in bytes).
|
||||
pub end_of_file: u64,
|
||||
/// File attributes.
|
||||
pub file_attributes: u32,
|
||||
/// The file handle.
|
||||
pub file_id: FileId,
|
||||
/// Raw create context bytes from the response.
|
||||
pub create_contexts: Vec<u8>,
|
||||
}
|
||||
|
||||
impl CreateResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 89;
|
||||
}
|
||||
|
||||
impl Pack for CreateResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// OplockLevel (1 byte)
|
||||
cursor.write_u8(self.oplock_level as u8);
|
||||
// Flags (1 byte)
|
||||
cursor.write_u8(self.flags);
|
||||
// CreateAction (4 bytes)
|
||||
cursor.write_u32_le(self.create_action as u32);
|
||||
// CreationTime (8 bytes)
|
||||
self.creation_time.pack(cursor);
|
||||
// LastAccessTime (8 bytes)
|
||||
self.last_access_time.pack(cursor);
|
||||
// LastWriteTime (8 bytes)
|
||||
self.last_write_time.pack(cursor);
|
||||
// ChangeTime (8 bytes)
|
||||
self.change_time.pack(cursor);
|
||||
// AllocationSize (8 bytes)
|
||||
cursor.write_u64_le(self.allocation_size);
|
||||
// EndOfFile (8 bytes)
|
||||
cursor.write_u64_le(self.end_of_file);
|
||||
// FileAttributes (4 bytes)
|
||||
cursor.write_u32_le(self.file_attributes);
|
||||
// Reserved2 (4 bytes)
|
||||
cursor.write_u32_le(0);
|
||||
// FileId (16 bytes = persistent u64 + volatile u64)
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
// CreateContextsOffset (4 bytes) -- placeholder
|
||||
let ctx_offset_pos = cursor.position();
|
||||
cursor.write_u32_le(0);
|
||||
// CreateContextsLength (4 bytes) -- placeholder
|
||||
let ctx_length_pos = cursor.position();
|
||||
cursor.write_u32_le(0);
|
||||
|
||||
// Create contexts (if any)
|
||||
if !self.create_contexts.is_empty() {
|
||||
cursor.align_to(8);
|
||||
let ctx_offset = Header::SIZE + (cursor.position() - start);
|
||||
cursor.write_bytes(&self.create_contexts);
|
||||
let ctx_len = self.create_contexts.len();
|
||||
|
||||
cursor.set_u32_le_at(ctx_offset_pos, ctx_offset as u32);
|
||||
cursor.set_u32_le_at(ctx_length_pos, ctx_len as u32);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for CreateResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid CreateResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// OplockLevel (1 byte)
|
||||
let oplock_level = OplockLevel::try_from(cursor.read_u8()?)?;
|
||||
// Flags (1 byte)
|
||||
let flags = cursor.read_u8()?;
|
||||
// CreateAction (4 bytes)
|
||||
let create_action = CreateAction::try_from(cursor.read_u32_le()?)?;
|
||||
// CreationTime (8 bytes)
|
||||
let creation_time = FileTime::unpack(cursor)?;
|
||||
// LastAccessTime (8 bytes)
|
||||
let last_access_time = FileTime::unpack(cursor)?;
|
||||
// LastWriteTime (8 bytes)
|
||||
let last_write_time = FileTime::unpack(cursor)?;
|
||||
// ChangeTime (8 bytes)
|
||||
let change_time = FileTime::unpack(cursor)?;
|
||||
// AllocationSize (8 bytes)
|
||||
let allocation_size = cursor.read_u64_le()?;
|
||||
// EndOfFile (8 bytes)
|
||||
let end_of_file = cursor.read_u64_le()?;
|
||||
// FileAttributes (4 bytes)
|
||||
let file_attributes = cursor.read_u32_le()?;
|
||||
// Reserved2 (4 bytes)
|
||||
let _reserved2 = cursor.read_u32_le()?;
|
||||
// FileId (16 bytes)
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
let file_id = FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
};
|
||||
// CreateContextsOffset (4 bytes)
|
||||
let ctx_offset = cursor.read_u32_le()? as usize;
|
||||
// CreateContextsLength (4 bytes)
|
||||
let ctx_length = cursor.read_u32_le()? as usize;
|
||||
|
||||
// Read create contexts
|
||||
// Offset on the wire is from beginning of SMB2 header.
|
||||
let create_contexts = if ctx_length > 0 {
|
||||
let current = cursor.position();
|
||||
let body_offset = ctx_offset.saturating_sub(Header::SIZE);
|
||||
let target = start + body_offset;
|
||||
if target > current {
|
||||
cursor.skip(target - current)?;
|
||||
}
|
||||
cursor.read_bytes_bounded(ctx_length)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(CreateResponse {
|
||||
oplock_level,
|
||||
flags,
|
||||
create_action,
|
||||
creation_time,
|
||||
last_access_time,
|
||||
last_write_time,
|
||||
change_time,
|
||||
allocation_size,
|
||||
end_of_file,
|
||||
file_attributes,
|
||||
file_id,
|
||||
create_contexts,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── CreateRequest tests ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn create_request_roundtrip_no_contexts() {
|
||||
let original = CreateRequest {
|
||||
requested_oplock_level: OplockLevel::Exclusive,
|
||||
impersonation_level: ImpersonationLevel::Impersonation,
|
||||
desired_access: FileAccessMask::new(
|
||||
FileAccessMask::GENERIC_READ | FileAccessMask::FILE_READ_ATTRIBUTES,
|
||||
),
|
||||
file_attributes: 0x80, // FILE_ATTRIBUTE_NORMAL
|
||||
share_access: ShareAccess(ShareAccess::FILE_SHARE_READ | ShareAccess::FILE_SHARE_WRITE),
|
||||
create_disposition: CreateDisposition::FileOpenIf,
|
||||
create_options: 0,
|
||||
name: "test\\file.txt".to_string(),
|
||||
create_contexts: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CreateRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
decoded.requested_oplock_level,
|
||||
original.requested_oplock_level
|
||||
);
|
||||
assert_eq!(decoded.impersonation_level, original.impersonation_level);
|
||||
assert_eq!(decoded.desired_access, original.desired_access);
|
||||
assert_eq!(decoded.file_attributes, original.file_attributes);
|
||||
assert_eq!(decoded.share_access, original.share_access);
|
||||
assert_eq!(decoded.create_disposition, original.create_disposition);
|
||||
assert_eq!(decoded.create_options, original.create_options);
|
||||
assert_eq!(decoded.name, original.name);
|
||||
assert!(decoded.create_contexts.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_request_roundtrip_with_create_contexts() {
|
||||
// Simulate a raw create context blob (for example, a
|
||||
// SMB2_CREATE_QUERY_MAXIMAL_ACCESS_REQUEST context).
|
||||
let fake_ctx = vec![
|
||||
0x00, 0x00, 0x00, 0x00, // NextEntryOffset = 0 (last entry)
|
||||
0x10, 0x00, // NameOffset = 16
|
||||
0x04, 0x00, // NameLength = 4
|
||||
0x00, 0x00, // Reserved
|
||||
0x18, 0x00, // DataOffset = 24
|
||||
0x04, 0x00, 0x00, 0x00, // DataLength = 4
|
||||
b'M', b'x', b'A', b'c', // Name = "MxAc"
|
||||
0x00, 0x00, 0x00, 0x00, // padding
|
||||
0x01, 0x02, 0x03, 0x04, // Data (4 bytes)
|
||||
];
|
||||
|
||||
let original = CreateRequest {
|
||||
requested_oplock_level: OplockLevel::Batch,
|
||||
impersonation_level: ImpersonationLevel::Delegate,
|
||||
desired_access: FileAccessMask::new(FileAccessMask::GENERIC_ALL),
|
||||
file_attributes: 0x20, // FILE_ATTRIBUTE_ARCHIVE
|
||||
share_access: ShareAccess(ShareAccess::FILE_SHARE_DELETE),
|
||||
create_disposition: CreateDisposition::FileCreate,
|
||||
create_options: 0x0000_0040, // FILE_NON_DIRECTORY_FILE
|
||||
name: "share\\docs\\report.docx".to_string(),
|
||||
create_contexts: fake_ctx.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CreateRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.requested_oplock_level, OplockLevel::Batch);
|
||||
assert_eq!(decoded.impersonation_level, ImpersonationLevel::Delegate);
|
||||
assert_eq!(decoded.name, "share\\docs\\report.docx");
|
||||
assert_eq!(decoded.create_contexts, fake_ctx);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_request_structure_size_field() {
|
||||
let req = CreateRequest {
|
||||
requested_oplock_level: OplockLevel::None,
|
||||
impersonation_level: ImpersonationLevel::Anonymous,
|
||||
desired_access: FileAccessMask::default(),
|
||||
file_attributes: 0,
|
||||
share_access: ShareAccess::default(),
|
||||
create_disposition: CreateDisposition::FileOpen,
|
||||
create_options: 0,
|
||||
name: "x".to_string(),
|
||||
create_contexts: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
req.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// First two bytes are StructureSize = 57
|
||||
assert_eq!(bytes[0], 57);
|
||||
assert_eq!(bytes[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_request_wrong_structure_size() {
|
||||
let mut buf = vec![0u8; 64];
|
||||
// Set wrong structure size
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = CreateRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── CreateResponse tests ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn create_response_roundtrip() {
|
||||
let original = CreateResponse {
|
||||
oplock_level: OplockLevel::LevelII,
|
||||
flags: 0,
|
||||
create_action: CreateAction::FileOpened,
|
||||
creation_time: FileTime(133_485_408_000_000_000),
|
||||
last_access_time: FileTime(133_485_408_100_000_000),
|
||||
last_write_time: FileTime(133_485_408_200_000_000),
|
||||
change_time: FileTime(133_485_408_300_000_000),
|
||||
allocation_size: 4096,
|
||||
end_of_file: 1234,
|
||||
file_attributes: 0x20, // FILE_ATTRIBUTE_ARCHIVE
|
||||
file_id: FileId {
|
||||
persistent: 0x1111_2222_3333_4444,
|
||||
volatile: 0x5555_6666_7777_8888,
|
||||
},
|
||||
create_contexts: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CreateResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.oplock_level, original.oplock_level);
|
||||
assert_eq!(decoded.flags, original.flags);
|
||||
assert_eq!(decoded.create_action, original.create_action);
|
||||
assert_eq!(decoded.creation_time, original.creation_time);
|
||||
assert_eq!(decoded.last_access_time, original.last_access_time);
|
||||
assert_eq!(decoded.last_write_time, original.last_write_time);
|
||||
assert_eq!(decoded.change_time, original.change_time);
|
||||
assert_eq!(decoded.allocation_size, original.allocation_size);
|
||||
assert_eq!(decoded.end_of_file, original.end_of_file);
|
||||
assert_eq!(decoded.file_attributes, original.file_attributes);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert!(decoded.create_contexts.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_response_with_contexts() {
|
||||
let ctx_data = vec![0xAA, 0xBB, 0xCC, 0xDD];
|
||||
let original = CreateResponse {
|
||||
oplock_level: OplockLevel::None,
|
||||
flags: 0x01,
|
||||
create_action: CreateAction::FileCreated,
|
||||
creation_time: FileTime(100),
|
||||
last_access_time: FileTime(200),
|
||||
last_write_time: FileTime(300),
|
||||
change_time: FileTime(400),
|
||||
allocation_size: 0,
|
||||
end_of_file: 0,
|
||||
file_attributes: 0,
|
||||
file_id: FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
},
|
||||
create_contexts: ctx_data.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CreateResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.create_action, CreateAction::FileCreated);
|
||||
assert_eq!(decoded.file_id.persistent, 1);
|
||||
assert_eq!(decoded.file_id.volatile, 2);
|
||||
assert_eq!(decoded.create_contexts, ctx_data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_response_structure_size_field() {
|
||||
let resp = CreateResponse {
|
||||
oplock_level: OplockLevel::None,
|
||||
flags: 0,
|
||||
create_action: CreateAction::FileOpened,
|
||||
creation_time: FileTime::ZERO,
|
||||
last_access_time: FileTime::ZERO,
|
||||
last_write_time: FileTime::ZERO,
|
||||
change_time: FileTime::ZERO,
|
||||
allocation_size: 0,
|
||||
end_of_file: 0,
|
||||
file_attributes: 0,
|
||||
file_id: FileId::default(),
|
||||
create_contexts: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
resp.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// First two bytes are StructureSize = 89
|
||||
assert_eq!(bytes[0], 89);
|
||||
assert_eq!(bytes[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_response_wrong_structure_size() {
|
||||
let mut buf = vec![0u8; 96];
|
||||
buf[0..2].copy_from_slice(&42u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = CreateResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── Enum conversion tests ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn oplock_level_roundtrip() {
|
||||
for &level in &[
|
||||
OplockLevel::None,
|
||||
OplockLevel::LevelII,
|
||||
OplockLevel::Exclusive,
|
||||
OplockLevel::Batch,
|
||||
OplockLevel::Lease,
|
||||
] {
|
||||
let raw = level as u8;
|
||||
let decoded = OplockLevel::try_from(raw).unwrap();
|
||||
assert_eq!(decoded, level);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oplock_level_invalid() {
|
||||
assert!(OplockLevel::try_from(0x42).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn impersonation_level_roundtrip() {
|
||||
for &level in &[
|
||||
ImpersonationLevel::Anonymous,
|
||||
ImpersonationLevel::Identification,
|
||||
ImpersonationLevel::Impersonation,
|
||||
ImpersonationLevel::Delegate,
|
||||
] {
|
||||
let raw = level as u32;
|
||||
let decoded = ImpersonationLevel::try_from(raw).unwrap();
|
||||
assert_eq!(decoded, level);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_disposition_roundtrip() {
|
||||
for &disp in &[
|
||||
CreateDisposition::FileSupersede,
|
||||
CreateDisposition::FileOpen,
|
||||
CreateDisposition::FileCreate,
|
||||
CreateDisposition::FileOpenIf,
|
||||
CreateDisposition::FileOverwrite,
|
||||
CreateDisposition::FileOverwriteIf,
|
||||
] {
|
||||
let raw = disp as u32;
|
||||
let decoded = CreateDisposition::try_from(raw).unwrap();
|
||||
assert_eq!(decoded, disp);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_action_roundtrip() {
|
||||
for &action in &[
|
||||
CreateAction::FileSuperseded,
|
||||
CreateAction::FileOpened,
|
||||
CreateAction::FileCreated,
|
||||
CreateAction::FileOverwritten,
|
||||
] {
|
||||
let raw = action as u32;
|
||||
let decoded = CreateAction::try_from(raw).unwrap();
|
||||
assert_eq!(decoded, action);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{
|
||||
arb_create_action, arb_create_disposition, arb_file_access_mask, arb_file_id,
|
||||
arb_file_time, arb_impersonation_level, arb_oplock_level, arb_share_access,
|
||||
arb_small_bytes, arb_utf16_string,
|
||||
};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn create_request_pack_unpack(
|
||||
requested_oplock_level in arb_oplock_level(),
|
||||
impersonation_level in arb_impersonation_level(),
|
||||
desired_access in arb_file_access_mask(),
|
||||
file_attributes in any::<u32>(),
|
||||
share_access in arb_share_access(),
|
||||
create_disposition in arb_create_disposition(),
|
||||
create_options in any::<u32>(),
|
||||
name in arb_utf16_string(128),
|
||||
create_contexts in arb_small_bytes(),
|
||||
) {
|
||||
let original = CreateRequest {
|
||||
requested_oplock_level,
|
||||
impersonation_level,
|
||||
desired_access,
|
||||
file_attributes,
|
||||
share_access,
|
||||
create_disposition,
|
||||
create_options,
|
||||
name,
|
||||
create_contexts,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CreateRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
// Note: pack may write a trailing 1-byte pad when name is empty
|
||||
// and there are no create contexts. Unpack only advances through
|
||||
// fields it reads, so the cursor may have 1 trailing byte in
|
||||
// that corner case. That's fine for symmetry on struct contents.
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_response_pack_unpack(
|
||||
oplock_level in arb_oplock_level(),
|
||||
flags in any::<u8>(),
|
||||
create_action in arb_create_action(),
|
||||
creation_time in arb_file_time(),
|
||||
last_access_time in arb_file_time(),
|
||||
last_write_time in arb_file_time(),
|
||||
change_time in arb_file_time(),
|
||||
allocation_size in any::<u64>(),
|
||||
end_of_file in any::<u64>(),
|
||||
file_attributes in any::<u32>(),
|
||||
file_id in arb_file_id(),
|
||||
create_contexts in arb_small_bytes(),
|
||||
) {
|
||||
let original = CreateResponse {
|
||||
oplock_level,
|
||||
flags,
|
||||
create_action,
|
||||
creation_time,
|
||||
last_access_time,
|
||||
last_write_time,
|
||||
change_time,
|
||||
allocation_size,
|
||||
end_of_file,
|
||||
file_attributes,
|
||||
file_id,
|
||||
create_contexts,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CreateResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
}
|
||||
}
|
||||
}
|
||||
697
vendor/smb2/src/msg/dfs.rs
vendored
Normal file
697
vendor/smb2/src/msg/dfs.rs
vendored
Normal file
@@ -0,0 +1,697 @@
|
||||
//! DFS referral request and response wire format (MS-DFSC sections 2.2.2, 2.2.4).
|
||||
//!
|
||||
//! These types are packed into the input/output buffers of an IOCTL request
|
||||
//! with `ctl_code = FSCTL_DFS_GET_REFERRALS`.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::Error;
|
||||
|
||||
// ── ReqGetDfsReferral ─────────────────────────────────────────────────
|
||||
|
||||
/// REQ_GET_DFS_REFERRAL (MS-DFSC 2.2.2).
|
||||
///
|
||||
/// Sent as the input buffer of an `FSCTL_DFS_GET_REFERRALS` IOCTL request.
|
||||
/// Contains the maximum referral version the client understands and the
|
||||
/// DFS path to resolve.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ReqGetDfsReferral {
|
||||
/// Highest DFS referral version understood by the client (typically 4).
|
||||
pub max_referral_level: u16,
|
||||
/// The DFS path to resolve (case-insensitive UNC path).
|
||||
pub request_file_name: String,
|
||||
}
|
||||
|
||||
impl Pack for ReqGetDfsReferral {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// MaxReferralLevel (2 bytes, LE)
|
||||
cursor.write_u16_le(self.max_referral_level);
|
||||
// RequestFileName (null-terminated UTF-16LE)
|
||||
cursor.write_utf16_le(&self.request_file_name);
|
||||
// Null terminator (2 bytes)
|
||||
cursor.write_u16_le(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for ReqGetDfsReferral {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let max_referral_level = cursor.read_u16_le()?;
|
||||
// Read the rest as null-terminated UTF-16LE.
|
||||
let request_file_name = read_null_terminated_utf16(cursor)?;
|
||||
Ok(ReqGetDfsReferral {
|
||||
max_referral_level,
|
||||
request_file_name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── RespGetDfsReferral ────────────────────────────────────────────────
|
||||
|
||||
/// RESP_GET_DFS_REFERRAL (MS-DFSC 2.2.4).
|
||||
///
|
||||
/// Returned in the output buffer of an IOCTL response for
|
||||
/// `FSCTL_DFS_GET_REFERRALS`. Contains the number of bytes of the path
|
||||
/// consumed by the server, header flags, and a list of referral entries.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RespGetDfsReferral {
|
||||
/// Number of bytes (not characters) of the path prefix that matched.
|
||||
pub path_consumed: u16,
|
||||
/// Header flags (ReferralServers | StorageServers | TargetFailback).
|
||||
pub header_flags: u32,
|
||||
/// The list of referral entries (V2, V3, or V4).
|
||||
pub entries: Vec<DfsReferralEntry>,
|
||||
}
|
||||
|
||||
/// A single DFS referral entry (V2-V4 flattened).
|
||||
///
|
||||
/// V1 is not supported (extremely rare in practice). Each entry describes
|
||||
/// one target server/share that the client can use to access the DFS path.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct DfsReferralEntry {
|
||||
/// Referral entry version (2, 3, or 4).
|
||||
pub version: u16,
|
||||
/// Server type: 0 = non-root/link target, 1 = root target.
|
||||
pub server_type: u16,
|
||||
/// Referral entry flags (version-specific).
|
||||
pub referral_entry_flags: u16,
|
||||
/// Time-to-live in seconds for caching this referral.
|
||||
pub ttl: u32,
|
||||
/// The DFS path prefix that matched.
|
||||
pub dfs_path: String,
|
||||
/// The DFS alternate path (usually identical to dfs_path).
|
||||
pub dfs_alternate_path: String,
|
||||
/// The target UNC path (for example, `\\server\share`).
|
||||
pub network_address: String,
|
||||
}
|
||||
|
||||
impl Unpack for RespGetDfsReferral {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let path_consumed = cursor.read_u16_le()?;
|
||||
let number_of_referrals = cursor.read_u16_le()?;
|
||||
let header_flags = cursor.read_u32_le()?;
|
||||
|
||||
// The remaining data contains all referral entries followed by a
|
||||
// string buffer. We need the full remaining slice to resolve
|
||||
// offsets that are relative to each entry's start.
|
||||
let entry_data = cursor.read_bytes(cursor.remaining())?;
|
||||
|
||||
let mut entries = Vec::with_capacity(number_of_referrals as usize);
|
||||
let mut offset = 0usize;
|
||||
|
||||
for _ in 0..number_of_referrals {
|
||||
if offset + 4 > entry_data.len() {
|
||||
return Err(Error::invalid_data(
|
||||
"DFS referral entry truncated (version/size header)",
|
||||
));
|
||||
}
|
||||
|
||||
let version = u16::from_le_bytes([entry_data[offset], entry_data[offset + 1]]);
|
||||
let entry_size =
|
||||
u16::from_le_bytes([entry_data[offset + 2], entry_data[offset + 3]]) as usize;
|
||||
|
||||
if entry_size < 4 {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"DFS referral entry size too small: {entry_size}"
|
||||
)));
|
||||
}
|
||||
|
||||
let entry_start = offset;
|
||||
// The entry_size includes the version and size fields themselves.
|
||||
let entry_end = entry_start + entry_size;
|
||||
if entry_end > entry_data.len() {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"DFS referral entry extends past buffer: entry_end={entry_end}, buf={}",
|
||||
entry_data.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// All strings referenced by offsets live from entry_start onward
|
||||
// in the full buffer (not truncated to entry_size, because the
|
||||
// strings are in the trailing string buffer).
|
||||
let entry = parse_referral_entry(version, entry_data, entry_start)?;
|
||||
entries.push(entry);
|
||||
|
||||
offset = entry_end;
|
||||
}
|
||||
|
||||
Ok(RespGetDfsReferral {
|
||||
path_consumed,
|
||||
header_flags,
|
||||
entries,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a single referral entry starting at `entry_start` within `buf`.
|
||||
///
|
||||
/// String offsets in V2/V3/V4 are relative to the start of the entry
|
||||
/// (which includes the 4-byte version+size prefix).
|
||||
fn parse_referral_entry(version: u16, buf: &[u8], entry_start: usize) -> Result<DfsReferralEntry> {
|
||||
// Skip version (2) + size (2) -- already read by caller.
|
||||
let mut pos = entry_start + 4;
|
||||
|
||||
match version {
|
||||
2 => {
|
||||
// V2: server_type(2) + flags(2) + proximity(4) + ttl(4) +
|
||||
// dfs_path_offset(2) + dfs_alternate_path_offset(2) + network_address_offset(2)
|
||||
// = 18 bytes of fixed entry body after the 4-byte version/size prefix.
|
||||
ensure_remaining(buf, pos, 18)?;
|
||||
let server_type = read_u16(buf, pos);
|
||||
pos += 2;
|
||||
let referral_entry_flags = read_u16(buf, pos);
|
||||
pos += 2;
|
||||
let _proximity = read_u32(buf, pos);
|
||||
pos += 4;
|
||||
let ttl = read_u32(buf, pos);
|
||||
pos += 4;
|
||||
let dfs_path_offset = read_u16(buf, pos) as usize;
|
||||
pos += 2;
|
||||
let dfs_alternate_path_offset = read_u16(buf, pos) as usize;
|
||||
pos += 2;
|
||||
let network_address_offset = read_u16(buf, pos) as usize;
|
||||
|
||||
let dfs_path = read_offset_string(buf, entry_start, dfs_path_offset)?;
|
||||
let dfs_alternate_path =
|
||||
read_offset_string(buf, entry_start, dfs_alternate_path_offset)?;
|
||||
let network_address = read_offset_string(buf, entry_start, network_address_offset)?;
|
||||
|
||||
Ok(DfsReferralEntry {
|
||||
version,
|
||||
server_type,
|
||||
referral_entry_flags,
|
||||
ttl,
|
||||
dfs_path,
|
||||
dfs_alternate_path,
|
||||
network_address,
|
||||
})
|
||||
}
|
||||
3 | 4 => {
|
||||
// V3/V4 share the same layout for the common (non-NameListReferral) case.
|
||||
// server_type(2) + flags(2) + ttl(4) +
|
||||
// dfs_path_offset(2) + dfs_alternate_path_offset(2) + network_address_offset(2)
|
||||
// V3/V4: + service_site_guid(16) when NameListReferral=0
|
||||
ensure_remaining(buf, pos, 14)?;
|
||||
let server_type = read_u16(buf, pos);
|
||||
pos += 2;
|
||||
let referral_entry_flags = read_u16(buf, pos);
|
||||
pos += 2;
|
||||
let ttl = read_u32(buf, pos);
|
||||
pos += 4;
|
||||
let dfs_path_offset = read_u16(buf, pos) as usize;
|
||||
pos += 2;
|
||||
let dfs_alternate_path_offset = read_u16(buf, pos) as usize;
|
||||
pos += 2;
|
||||
let network_address_offset = read_u16(buf, pos) as usize;
|
||||
// Skip the rest of the fixed entry (service_site_guid for V3/V4).
|
||||
|
||||
let dfs_path = read_offset_string(buf, entry_start, dfs_path_offset)?;
|
||||
let dfs_alternate_path =
|
||||
read_offset_string(buf, entry_start, dfs_alternate_path_offset)?;
|
||||
let network_address = read_offset_string(buf, entry_start, network_address_offset)?;
|
||||
|
||||
Ok(DfsReferralEntry {
|
||||
version,
|
||||
server_type,
|
||||
referral_entry_flags,
|
||||
ttl,
|
||||
dfs_path,
|
||||
dfs_alternate_path,
|
||||
network_address,
|
||||
})
|
||||
}
|
||||
_ => Err(Error::invalid_data(format!(
|
||||
"unsupported DFS referral version: {version} (only V2-V4 are supported)"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helper functions ──────────────────────────────────────────────────
|
||||
|
||||
/// Read a null-terminated UTF-16LE string from a `ReadCursor`.
|
||||
fn read_null_terminated_utf16(cursor: &mut ReadCursor<'_>) -> Result<String> {
|
||||
let mut code_units: Vec<u16> = Vec::new();
|
||||
loop {
|
||||
let cu = cursor.read_u16_le()?;
|
||||
if cu == 0 {
|
||||
break;
|
||||
}
|
||||
code_units.push(cu);
|
||||
}
|
||||
String::from_utf16(&code_units)
|
||||
.map_err(|_| Error::invalid_data("invalid UTF-16LE in DFS request file name"))
|
||||
}
|
||||
|
||||
/// Read a null-terminated UTF-16LE string from a raw byte buffer at a given absolute offset.
|
||||
fn read_null_terminated_utf16_at(buf: &[u8], offset: usize) -> Result<String> {
|
||||
let mut code_units: Vec<u16> = Vec::new();
|
||||
let mut pos = offset;
|
||||
loop {
|
||||
if pos + 2 > buf.len() {
|
||||
return Err(Error::invalid_data(
|
||||
"DFS referral string extends past buffer",
|
||||
));
|
||||
}
|
||||
let cu = u16::from_le_bytes([buf[pos], buf[pos + 1]]);
|
||||
pos += 2;
|
||||
if cu == 0 {
|
||||
break;
|
||||
}
|
||||
code_units.push(cu);
|
||||
}
|
||||
String::from_utf16(&code_units)
|
||||
.map_err(|_| Error::invalid_data("invalid UTF-16LE in DFS referral string"))
|
||||
}
|
||||
|
||||
/// Read a null-terminated UTF-16LE string at an offset relative to an entry start.
|
||||
fn read_offset_string(buf: &[u8], entry_start: usize, offset: usize) -> Result<String> {
|
||||
let abs = entry_start + offset;
|
||||
read_null_terminated_utf16_at(buf, abs)
|
||||
}
|
||||
|
||||
/// Inline LE u16 read from a byte buffer.
|
||||
fn read_u16(buf: &[u8], pos: usize) -> u16 {
|
||||
u16::from_le_bytes([buf[pos], buf[pos + 1]])
|
||||
}
|
||||
|
||||
/// Inline LE u32 read from a byte buffer.
|
||||
fn read_u32(buf: &[u8], pos: usize) -> u32 {
|
||||
u32::from_le_bytes([buf[pos], buf[pos + 1], buf[pos + 2], buf[pos + 3]])
|
||||
}
|
||||
|
||||
/// Check that at least `need` bytes are available at `pos` in `buf`.
|
||||
fn ensure_remaining(buf: &[u8], pos: usize, need: usize) -> Result<()> {
|
||||
if pos + need > buf.len() {
|
||||
Err(Error::invalid_data(format!(
|
||||
"DFS referral entry truncated: need {need} bytes at offset {pos}, buf len {}",
|
||||
buf.len()
|
||||
)))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── Request tests ─────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn req_pack_known_bytes() {
|
||||
// Test vector from smb-rs: ReqGetDfsReferral { max_referral_level: 4,
|
||||
// request_file_name: r"\ADC.aviv.local\dfs\Docs" }
|
||||
let expected = hex_to_bytes(
|
||||
"04005c004100440043002e0061007600690076002e006c006f00630061006c005c006400660073005c0044006f00630073000000",
|
||||
);
|
||||
let req = ReqGetDfsReferral {
|
||||
max_referral_level: 4,
|
||||
request_file_name: r"\ADC.aviv.local\dfs\Docs".to_string(),
|
||||
};
|
||||
let mut cursor = WriteCursor::new();
|
||||
req.pack(&mut cursor);
|
||||
assert_eq!(cursor.into_inner(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn req_pack_roundtrip() {
|
||||
let original = ReqGetDfsReferral {
|
||||
max_referral_level: 4,
|
||||
request_file_name: r"\server\share\path".to_string(),
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ReqGetDfsReferral::unpack(&mut r).unwrap();
|
||||
assert_eq!(decoded, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn req_pack_empty_path() {
|
||||
let req = ReqGetDfsReferral {
|
||||
max_referral_level: 3,
|
||||
request_file_name: String::new(),
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
req.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
// max_referral_level (2) + null terminator (2) = 4 bytes
|
||||
assert_eq!(bytes.len(), 4);
|
||||
assert_eq!(bytes, [0x03, 0x00, 0x00, 0x00]);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ReqGetDfsReferral::unpack(&mut r).unwrap();
|
||||
assert_eq!(decoded, req);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn req_unpack_truncated() {
|
||||
// Only 1 byte -- not enough for max_referral_level.
|
||||
let bytes = [0x04];
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
assert!(ReqGetDfsReferral::unpack(&mut r).is_err());
|
||||
}
|
||||
|
||||
// ── Response tests ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn resp_parse_v4_referral() {
|
||||
// Test vector from smb-rs: two V4 entries.
|
||||
let hex = "300002000200000004002200000004000807000044007600\
|
||||
a800000000000000000000000000000000000400220000000000\
|
||||
0807000022005400a8000000000000000000000000000000\
|
||||
00005c004100440043002e0061007600690076002e006c00\
|
||||
6f00630061006c005c006400660073005c0044006f006300\
|
||||
730000005c004100440043002e0061007600690076002e00\
|
||||
6c006f00630061006c005c006400660073005c0044006f00\
|
||||
6300730000005c004100440043005c005300680061007200\
|
||||
650073005c0044006f006300730000005c00460053005200\
|
||||
56005c005300680061007200650073005c004d0079005300\
|
||||
6800610072006500000000";
|
||||
let data = hex_to_bytes(hex);
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(resp.path_consumed, 48);
|
||||
// header_flags = 0x00000002 (StorageServers)
|
||||
assert_eq!(resp.header_flags, 0x0000_0002);
|
||||
assert_eq!(resp.entries.len(), 2);
|
||||
|
||||
let e0 = &resp.entries[0];
|
||||
assert_eq!(e0.version, 4);
|
||||
assert_eq!(e0.server_type, 0); // non-root
|
||||
assert_eq!(e0.ttl, 1800);
|
||||
assert_eq!(e0.dfs_path, r"\ADC.aviv.local\dfs\Docs");
|
||||
assert_eq!(e0.dfs_alternate_path, r"\ADC.aviv.local\dfs\Docs");
|
||||
assert_eq!(e0.network_address, r"\ADC\Shares\Docs");
|
||||
|
||||
let e1 = &resp.entries[1];
|
||||
assert_eq!(e1.version, 4);
|
||||
assert_eq!(e1.server_type, 0);
|
||||
assert_eq!(e1.ttl, 1800);
|
||||
assert_eq!(e1.dfs_path, r"\ADC.aviv.local\dfs\Docs");
|
||||
assert_eq!(e1.dfs_alternate_path, r"\ADC.aviv.local\dfs\Docs");
|
||||
assert_eq!(e1.network_address, r"\FSRV\Shares\MyShare");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resp_parse_v3_referral() {
|
||||
// Manually constructed V3 response: one entry.
|
||||
// Header: path_consumed=20, num_referrals=1, flags=0x03
|
||||
// Entry: version=3, size=34 (fixed part), server_type=1, flags=0,
|
||||
// ttl=600, offsets point to strings after the entry.
|
||||
let dfs_path = encode_null_utf16(r"\dom\share");
|
||||
let alt_path = encode_null_utf16(r"\dom\share");
|
||||
let net_addr = encode_null_utf16(r"\srv\share");
|
||||
|
||||
let entry_fixed_size: u16 = 34; // 4 + 2+2+4 + 2+2+2 + 16 = 34
|
||||
let dfs_path_offset = entry_fixed_size;
|
||||
let alt_path_offset = dfs_path_offset + dfs_path.len() as u16;
|
||||
let net_addr_offset = alt_path_offset + alt_path.len() as u16;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
// Response header
|
||||
buf.extend_from_slice(&20u16.to_le_bytes()); // path_consumed
|
||||
buf.extend_from_slice(&1u16.to_le_bytes()); // number_of_referrals
|
||||
buf.extend_from_slice(&3u32.to_le_bytes()); // header_flags
|
||||
|
||||
// Entry header
|
||||
buf.extend_from_slice(&3u16.to_le_bytes()); // version
|
||||
buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); // size (fixed part)
|
||||
buf.extend_from_slice(&1u16.to_le_bytes()); // server_type (root)
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // referral_entry_flags
|
||||
buf.extend_from_slice(&600u32.to_le_bytes()); // ttl
|
||||
buf.extend_from_slice(&dfs_path_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&alt_path_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&net_addr_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&[0u8; 16]); // service_site_guid
|
||||
|
||||
// String buffer
|
||||
buf.extend_from_slice(&dfs_path);
|
||||
buf.extend_from_slice(&alt_path);
|
||||
buf.extend_from_slice(&net_addr);
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(resp.path_consumed, 20);
|
||||
assert_eq!(resp.header_flags, 3);
|
||||
assert_eq!(resp.entries.len(), 1);
|
||||
|
||||
let e = &resp.entries[0];
|
||||
assert_eq!(e.version, 3);
|
||||
assert_eq!(e.server_type, 1);
|
||||
assert_eq!(e.ttl, 600);
|
||||
assert_eq!(e.dfs_path, r"\dom\share");
|
||||
assert_eq!(e.dfs_alternate_path, r"\dom\share");
|
||||
assert_eq!(e.network_address, r"\srv\share");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resp_parse_v2_referral() {
|
||||
// Manually constructed V2 response: one entry.
|
||||
let dfs_path = encode_null_utf16(r"\domain\dfs");
|
||||
let alt_path = encode_null_utf16(r"\domain\dfs");
|
||||
let net_addr = encode_null_utf16(r"\server\data");
|
||||
|
||||
let entry_fixed_size: u16 = 22; // 4 + 2+2+4+4 + 2+2+2 = 22
|
||||
let dfs_path_offset = entry_fixed_size;
|
||||
let alt_path_offset = dfs_path_offset + dfs_path.len() as u16;
|
||||
let net_addr_offset = alt_path_offset + alt_path.len() as u16;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
// Response header
|
||||
buf.extend_from_slice(&24u16.to_le_bytes()); // path_consumed
|
||||
buf.extend_from_slice(&1u16.to_le_bytes()); // number_of_referrals
|
||||
buf.extend_from_slice(&1u32.to_le_bytes()); // header_flags (ReferralServers)
|
||||
|
||||
// Entry
|
||||
buf.extend_from_slice(&2u16.to_le_bytes()); // version
|
||||
buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); // size
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // server_type
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // flags
|
||||
buf.extend_from_slice(&0u32.to_le_bytes()); // proximity
|
||||
buf.extend_from_slice(&300u32.to_le_bytes()); // ttl
|
||||
buf.extend_from_slice(&dfs_path_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&alt_path_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&net_addr_offset.to_le_bytes());
|
||||
|
||||
// String buffer
|
||||
buf.extend_from_slice(&dfs_path);
|
||||
buf.extend_from_slice(&alt_path);
|
||||
buf.extend_from_slice(&net_addr);
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(resp.path_consumed, 24);
|
||||
assert_eq!(resp.header_flags, 1);
|
||||
assert_eq!(resp.entries.len(), 1);
|
||||
|
||||
let e = &resp.entries[0];
|
||||
assert_eq!(e.version, 2);
|
||||
assert_eq!(e.server_type, 0);
|
||||
assert_eq!(e.ttl, 300);
|
||||
assert_eq!(e.dfs_path, r"\domain\dfs");
|
||||
assert_eq!(e.dfs_alternate_path, r"\domain\dfs");
|
||||
assert_eq!(e.network_address, r"\server\data");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resp_parse_empty() {
|
||||
// Zero referral entries.
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // path_consumed
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // number_of_referrals
|
||||
buf.extend_from_slice(&0u32.to_le_bytes()); // header_flags
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap();
|
||||
assert_eq!(resp.path_consumed, 0);
|
||||
assert_eq!(resp.header_flags, 0);
|
||||
assert!(resp.entries.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resp_parse_multiple_entries() {
|
||||
// Two V2 entries with different targets.
|
||||
// Layout: [entry1 fixed][entry2 fixed][strings for entry1][strings for entry2]
|
||||
// Offsets are relative to each entry's start.
|
||||
let dfs_path = encode_null_utf16(r"\ns\link");
|
||||
let alt_path = encode_null_utf16(r"\ns\link");
|
||||
let net_addr_1 = encode_null_utf16(r"\srv1\data");
|
||||
let net_addr_2 = encode_null_utf16(r"\srv2\data");
|
||||
|
||||
let entry_fixed_size: u16 = 22;
|
||||
let total_fixed: u16 = entry_fixed_size * 2; // both entries' fixed parts
|
||||
|
||||
// Entry 1 string offsets (relative to entry 1 start = 0 in entry_data).
|
||||
// Strings start after both entries' fixed parts.
|
||||
let e1_dfs_offset = total_fixed; // 44
|
||||
let e1_alt_offset = e1_dfs_offset + dfs_path.len() as u16;
|
||||
let e1_net_offset = e1_alt_offset + alt_path.len() as u16;
|
||||
let e1_strings_end = e1_net_offset + net_addr_1.len() as u16;
|
||||
|
||||
// Entry 2 string offsets (relative to entry 2 start = 22 in entry_data).
|
||||
let e2_dfs_offset = e1_strings_end - entry_fixed_size; // offset from entry 2 start
|
||||
let e2_alt_offset = e2_dfs_offset + dfs_path.len() as u16;
|
||||
let e2_net_offset = e2_alt_offset + alt_path.len() as u16;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
// Response header
|
||||
buf.extend_from_slice(&16u16.to_le_bytes()); // path_consumed
|
||||
buf.extend_from_slice(&2u16.to_le_bytes()); // number_of_referrals
|
||||
buf.extend_from_slice(&0u32.to_le_bytes()); // header_flags
|
||||
|
||||
// Entry 1 fixed part
|
||||
buf.extend_from_slice(&2u16.to_le_bytes()); // version
|
||||
buf.extend_from_slice(&entry_fixed_size.to_le_bytes()); // size
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // server_type
|
||||
buf.extend_from_slice(&0u16.to_le_bytes()); // flags
|
||||
buf.extend_from_slice(&0u32.to_le_bytes()); // proximity
|
||||
buf.extend_from_slice(&120u32.to_le_bytes()); // ttl
|
||||
buf.extend_from_slice(&e1_dfs_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&e1_alt_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&e1_net_offset.to_le_bytes());
|
||||
|
||||
// Entry 2 fixed part
|
||||
buf.extend_from_slice(&2u16.to_le_bytes());
|
||||
buf.extend_from_slice(&entry_fixed_size.to_le_bytes());
|
||||
buf.extend_from_slice(&1u16.to_le_bytes()); // server_type = root
|
||||
buf.extend_from_slice(&0u16.to_le_bytes());
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
buf.extend_from_slice(&240u32.to_le_bytes());
|
||||
buf.extend_from_slice(&e2_dfs_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&e2_alt_offset.to_le_bytes());
|
||||
buf.extend_from_slice(&e2_net_offset.to_le_bytes());
|
||||
|
||||
// String buffer for entry 1
|
||||
buf.extend_from_slice(&dfs_path);
|
||||
buf.extend_from_slice(&alt_path);
|
||||
buf.extend_from_slice(&net_addr_1);
|
||||
|
||||
// String buffer for entry 2
|
||||
buf.extend_from_slice(&dfs_path);
|
||||
buf.extend_from_slice(&alt_path);
|
||||
buf.extend_from_slice(&net_addr_2);
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let resp = RespGetDfsReferral::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(resp.entries.len(), 2);
|
||||
assert_eq!(resp.entries[0].ttl, 120);
|
||||
assert_eq!(resp.entries[0].network_address, r"\srv1\data");
|
||||
assert_eq!(resp.entries[1].ttl, 240);
|
||||
assert_eq!(resp.entries[1].server_type, 1);
|
||||
assert_eq!(resp.entries[1].network_address, r"\srv2\data");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resp_parse_unsupported_version() {
|
||||
let mut buf = Vec::new();
|
||||
// Response header
|
||||
buf.extend_from_slice(&0u16.to_le_bytes());
|
||||
buf.extend_from_slice(&1u16.to_le_bytes()); // 1 entry
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
// Entry with version 1 (unsupported)
|
||||
buf.extend_from_slice(&1u16.to_le_bytes()); // version
|
||||
buf.extend_from_slice(&8u16.to_le_bytes()); // size
|
||||
buf.extend_from_slice(&[0u8; 4]); // padding to reach size
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = RespGetDfsReferral::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("unsupported DFS referral version"),
|
||||
"error was: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resp_parse_truncated_header() {
|
||||
// Only 4 bytes -- missing header_flags.
|
||||
let buf = [0x00, 0x00, 0x01, 0x00];
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
assert!(RespGetDfsReferral::unpack(&mut cursor).is_err());
|
||||
}
|
||||
|
||||
/// Regression: fuzz-found crash. A V2 entry that claims `entry_size = 16`
|
||||
/// used to panic inside the entry-body read. The V2 body needs 18 bytes
|
||||
/// (server_type+flags+proximity+ttl + three u16 offsets), but the guard
|
||||
/// only ensured 16 bytes were available, so the final offset read would
|
||||
/// slip past the buffer. See fuzz target
|
||||
/// `fuzz_dfs_referral_response_parse` crash
|
||||
/// `a6933afd5a1ccec7166d914caed66154416a2fcb`.
|
||||
#[test]
|
||||
fn resp_parse_v2_short_entry_returns_clean_error() {
|
||||
let crash_input: [u8; 28] = [
|
||||
0x10, 0x00, 0x01, 0x00, 0x22, 0x23, 0x00, 0x03, // header
|
||||
0x02, 0x00, 0x10, 0x00, 0x01, 0x00, 0x22, 0x23, // v2 entry start (size=16)
|
||||
0x00, 0x03, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, // body bytes
|
||||
0x00, 0x00, 0x00, 0x00, // tail
|
||||
];
|
||||
let mut cursor = ReadCursor::new(&crash_input);
|
||||
let result = RespGetDfsReferral::unpack(&mut cursor);
|
||||
assert!(result.is_err(), "expected clean error, got {result:?}");
|
||||
}
|
||||
|
||||
// ── Test helpers ──────────────────────────────────────────────────
|
||||
|
||||
/// Decode a hex string (no spaces, no 0x prefix) into bytes.
|
||||
fn hex_to_bytes(hex: &str) -> Vec<u8> {
|
||||
let hex: String = hex.chars().filter(|c| !c.is_whitespace()).collect();
|
||||
(0..hex.len())
|
||||
.step_by(2)
|
||||
.map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Encode a string as null-terminated UTF-16LE bytes.
|
||||
fn encode_null_utf16(s: &str) -> Vec<u8> {
|
||||
let mut out = Vec::new();
|
||||
for cu in s.encode_utf16() {
|
||||
out.extend_from_slice(&cu.to_le_bytes());
|
||||
}
|
||||
out.extend_from_slice(&[0x00, 0x00]); // null terminator
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::arb_utf16_string;
|
||||
use proptest::prelude::*;
|
||||
|
||||
/// Generate a UTF-16 string without interior null (U+0000). The encoder
|
||||
/// terminates with a 0x0000 code unit, so an interior null would end
|
||||
/// the string early on decode.
|
||||
fn arb_utf16_no_nul(max: usize) -> impl Strategy<Value = String> {
|
||||
arb_utf16_string(max).prop_filter("string must not contain interior U+0000", |s| {
|
||||
!s.contains('\0')
|
||||
})
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn req_get_dfs_referral_pack_unpack(
|
||||
max_referral_level in any::<u16>(),
|
||||
request_file_name in arb_utf16_no_nul(128),
|
||||
) {
|
||||
let original = ReqGetDfsReferral {
|
||||
max_referral_level,
|
||||
request_file_name,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ReqGetDfsReferral::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
42
vendor/smb2/src/msg/echo.rs
vendored
Normal file
42
vendor/smb2/src/msg/echo.rs
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
//! SMB2 ECHO request and response (spec sections 2.2.28, 2.2.29).
|
||||
//!
|
||||
//! Echo messages are used to check whether a server is processing requests.
|
||||
//! Both request and response contain only a StructureSize field and a
|
||||
//! reserved field, for a total of 4 bytes each.
|
||||
|
||||
super::trivial_message! {
|
||||
/// SMB2 ECHO request (spec section 2.2.28).
|
||||
///
|
||||
/// Sent by the client to determine whether a server is processing requests.
|
||||
/// Contains only StructureSize (2 bytes) and Reserved (2 bytes).
|
||||
pub struct EchoRequest;
|
||||
}
|
||||
|
||||
super::trivial_message! {
|
||||
/// SMB2 ECHO response (spec section 2.2.29).
|
||||
///
|
||||
/// Sent by the server to confirm that an ECHO request was processed.
|
||||
/// Contains only StructureSize (2 bytes) and Reserved (2 bytes).
|
||||
pub struct EchoResponse;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
super::super::trivial_message_tests!(
|
||||
EchoRequest,
|
||||
echo_request_known_bytes,
|
||||
echo_request_roundtrip,
|
||||
echo_request_wrong_structure_size,
|
||||
echo_request_too_short
|
||||
);
|
||||
|
||||
super::super::trivial_message_tests!(
|
||||
EchoResponse,
|
||||
echo_response_known_bytes,
|
||||
echo_response_roundtrip,
|
||||
echo_response_wrong_structure_size,
|
||||
echo_response_too_short
|
||||
);
|
||||
}
|
||||
254
vendor/smb2/src/msg/flush.rs
vendored
Normal file
254
vendor/smb2/src/msg/flush.rs
vendored
Normal file
@@ -0,0 +1,254 @@
|
||||
//! SMB2 FLUSH request and response (spec sections 2.2.17, 2.2.18).
|
||||
//!
|
||||
//! Flush messages request that the server flush all cached file information
|
||||
//! for a specified open to persistent storage. If the open refers to a
|
||||
//! named pipe, the operation completes once all written data has been
|
||||
//! consumed by a reader.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
/// SMB2 FLUSH request (spec section 2.2.17).
|
||||
///
|
||||
/// Sent by the client to request that the server flush cached data for a file.
|
||||
///
|
||||
/// Wire layout (24 bytes):
|
||||
/// - StructureSize (2 bytes): must be 24
|
||||
/// - Reserved1 (2 bytes): must be 0
|
||||
/// - Reserved2 (4 bytes): must be 0
|
||||
/// - FileId (16 bytes): persistent (8 bytes) + volatile (8 bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct FlushRequest {
|
||||
pub file_id: FileId,
|
||||
}
|
||||
|
||||
impl FlushRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 24;
|
||||
}
|
||||
|
||||
impl Pack for FlushRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// Reserved1 (2 bytes)
|
||||
cursor.write_u16_le(0);
|
||||
// Reserved2 (4 bytes)
|
||||
cursor.write_u32_le(0);
|
||||
// FileId: Persistent (8 bytes) + Volatile (8 bytes)
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for FlushRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid FlushRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Reserved1 (2 bytes)
|
||||
let _reserved1 = cursor.read_u16_le()?;
|
||||
|
||||
// Reserved2 (4 bytes)
|
||||
let _reserved2 = cursor.read_u32_le()?;
|
||||
|
||||
// FileId: Persistent (8 bytes) + Volatile (8 bytes)
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
|
||||
Ok(FlushRequest {
|
||||
file_id: FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
super::trivial_message! {
|
||||
/// SMB2 FLUSH response (spec section 2.2.18).
|
||||
///
|
||||
/// Sent by the server to confirm that a FLUSH request was processed.
|
||||
/// Contains only StructureSize (2 bytes) and Reserved (2 bytes).
|
||||
pub struct FlushResponse;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── FlushRequest tests ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn flush_request_pack_produces_24_bytes() {
|
||||
let req = FlushRequest {
|
||||
file_id: FileId::default(),
|
||||
};
|
||||
let mut cursor = WriteCursor::new();
|
||||
req.pack(&mut cursor);
|
||||
let bytes = cursor.into_inner();
|
||||
assert_eq!(bytes.len(), 24);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flush_request_known_bytes() {
|
||||
let req = FlushRequest {
|
||||
file_id: FileId {
|
||||
persistent: 0x0102_0304_0506_0708,
|
||||
volatile: 0x090A_0B0C_0D0E_0F10,
|
||||
},
|
||||
};
|
||||
let mut cursor = WriteCursor::new();
|
||||
req.pack(&mut cursor);
|
||||
let bytes = cursor.into_inner();
|
||||
|
||||
#[rustfmt::skip]
|
||||
let expected: [u8; 24] = [
|
||||
// StructureSize = 24
|
||||
0x18, 0x00,
|
||||
// Reserved1 = 0
|
||||
0x00, 0x00,
|
||||
// Reserved2 = 0
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
// FileId.Persistent (LE)
|
||||
0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01,
|
||||
// FileId.Volatile (LE)
|
||||
0x10, 0x0F, 0x0E, 0x0D, 0x0C, 0x0B, 0x0A, 0x09,
|
||||
];
|
||||
assert_eq!(bytes, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flush_request_unpack_known_bytes() {
|
||||
#[rustfmt::skip]
|
||||
let bytes: [u8; 24] = [
|
||||
// StructureSize = 24
|
||||
0x18, 0x00,
|
||||
// Reserved1 = 0
|
||||
0x00, 0x00,
|
||||
// Reserved2 = 0
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
// FileId.Persistent = 0xDEADBEEFCAFEBABE
|
||||
0xBE, 0xBA, 0xFE, 0xCA, 0xEF, 0xBE, 0xAD, 0xDE,
|
||||
// FileId.Volatile = 0x1234567890ABCDEF
|
||||
0xEF, 0xCD, 0xAB, 0x90, 0x78, 0x56, 0x34, 0x12,
|
||||
];
|
||||
let mut cursor = ReadCursor::new(&bytes);
|
||||
let req = FlushRequest::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(req.file_id.persistent, 0xDEAD_BEEF_CAFE_BABE);
|
||||
assert_eq!(req.file_id.volatile, 0x1234_5678_90AB_CDEF);
|
||||
assert!(cursor.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flush_request_roundtrip() {
|
||||
let original = FlushRequest {
|
||||
file_id: FileId {
|
||||
persistent: 0xAAAA_BBBB_CCCC_DDDD,
|
||||
volatile: 0x1111_2222_3333_4444,
|
||||
},
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = FlushRequest::unpack(&mut r).unwrap();
|
||||
assert_eq!(decoded, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flush_request_roundtrip_sentinel_file_id() {
|
||||
let original = FlushRequest {
|
||||
file_id: FileId::SENTINEL,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = FlushRequest::unpack(&mut r).unwrap();
|
||||
assert_eq!(decoded, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flush_request_wrong_structure_size() {
|
||||
let mut bytes = [0u8; 24];
|
||||
// Wrong structure size = 4 instead of 24
|
||||
bytes[0..2].copy_from_slice(&4u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&bytes);
|
||||
let result = FlushRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flush_request_too_short() {
|
||||
let bytes = [0x18, 0x00, 0x00, 0x00];
|
||||
let mut cursor = ReadCursor::new(&bytes);
|
||||
let result = FlushRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flush_request_ignores_reserved_values() {
|
||||
#[rustfmt::skip]
|
||||
let bytes: [u8; 24] = [
|
||||
// StructureSize = 24
|
||||
0x18, 0x00,
|
||||
// Reserved1 = 0xFFFF (non-zero, should be ignored)
|
||||
0xFF, 0xFF,
|
||||
// Reserved2 = 0xFFFFFFFF (non-zero, should be ignored)
|
||||
0xFF, 0xFF, 0xFF, 0xFF,
|
||||
// FileId.Persistent = 0
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
// FileId.Volatile = 0
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
];
|
||||
let mut cursor = ReadCursor::new(&bytes);
|
||||
let req = FlushRequest::unpack(&mut cursor).unwrap();
|
||||
assert_eq!(req.file_id, FileId::default());
|
||||
}
|
||||
|
||||
// ── FlushResponse tests ────────────────────────────────────────
|
||||
|
||||
super::super::trivial_message_tests!(
|
||||
FlushResponse,
|
||||
flush_response_known_bytes,
|
||||
flush_response_roundtrip,
|
||||
flush_response_wrong_structure_size,
|
||||
flush_response_too_short
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::arb_file_id;
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn flush_request_pack_unpack(file_id in arb_file_id()) {
|
||||
let original = FlushRequest { file_id };
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = FlushRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
669
vendor/smb2/src/msg/header.rs
vendored
Normal file
669
vendor/smb2/src/msg/header.rs
vendored
Normal file
@@ -0,0 +1,669 @@
|
||||
//! SMB2 packet header (64 bytes) and error response.
|
||||
//!
|
||||
//! The SMB2 header has two variants that share the same 64-byte layout:
|
||||
//! - **Sync header:** bytes 32-35 = Reserved (u32), bytes 36-39 = TreeId (u32)
|
||||
//! - **Async header:** bytes 32-39 = AsyncId (u64)
|
||||
//!
|
||||
//! The choice is determined by the `SMB2_FLAGS_ASYNC_COMMAND` bit in the Flags field.
|
||||
//!
|
||||
//! Reference: MS-SMB2 sections 2.2.1, 2.2.1.1, 2.2.1.2, 2.2.2.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::flags::HeaderFlags;
|
||||
use crate::types::status::NtStatus;
|
||||
use crate::types::{Command, CreditCharge, MessageId, SessionId, TreeId};
|
||||
use crate::Error;
|
||||
|
||||
/// The 4-byte protocol identifier at the start of every SMB2 message.
|
||||
pub const PROTOCOL_ID: [u8; 4] = [0xFE, b'S', b'M', b'B'];
|
||||
|
||||
/// SMB2 packet header (64 bytes).
|
||||
///
|
||||
/// Contains both sync and async variants. The `flags` field determines
|
||||
/// which interpretation of bytes 32-39 is correct.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Header {
|
||||
/// Number of credits charged for this request.
|
||||
pub credit_charge: CreditCharge,
|
||||
/// In responses: NtStatus. In requests before SMB 3.x: Reserved.
|
||||
/// In requests for SMB 3.x: ChannelSequence (u16) + Reserved (u16).
|
||||
pub status: NtStatus,
|
||||
/// The command code for this packet.
|
||||
pub command: Command,
|
||||
/// In requests: credits requested. In responses: credits granted.
|
||||
pub credits: u16,
|
||||
/// Flags indicating how to process the operation.
|
||||
pub flags: HeaderFlags,
|
||||
/// Offset to the next command in a compound chain (0 = last/only).
|
||||
pub next_command: u32,
|
||||
/// Unique message identifier for request/response correlation.
|
||||
pub message_id: MessageId,
|
||||
/// Sync-only: tree identifier. None if async.
|
||||
pub tree_id: Option<TreeId>,
|
||||
/// Async-only: async identifier. None if sync.
|
||||
pub async_id: Option<u64>,
|
||||
/// Session identifier.
|
||||
pub session_id: SessionId,
|
||||
/// 16-byte message signature.
|
||||
pub signature: [u8; 16],
|
||||
}
|
||||
|
||||
impl Header {
|
||||
pub const STRUCTURE_SIZE: u16 = 64;
|
||||
|
||||
/// Total header size in bytes.
|
||||
pub const SIZE: usize = 64;
|
||||
|
||||
/// Create a new request header for a given command.
|
||||
pub fn new_request(command: Command) -> Self {
|
||||
Self {
|
||||
credit_charge: CreditCharge(0),
|
||||
status: NtStatus::SUCCESS,
|
||||
command,
|
||||
credits: 1,
|
||||
flags: HeaderFlags::default(),
|
||||
next_command: 0,
|
||||
message_id: MessageId::default(),
|
||||
tree_id: Some(TreeId::default()),
|
||||
async_id: None,
|
||||
session_id: SessionId::default(),
|
||||
signature: [0u8; 16],
|
||||
}
|
||||
}
|
||||
|
||||
/// Is this a response (vs request)?
|
||||
pub fn is_response(&self) -> bool {
|
||||
self.flags.is_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl Pack for Header {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// ProtocolId (4 bytes)
|
||||
cursor.write_bytes(&PROTOCOL_ID);
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// CreditCharge (2 bytes)
|
||||
cursor.write_u16_le(self.credit_charge.0);
|
||||
// Status (4 bytes)
|
||||
cursor.write_u32_le(self.status.0);
|
||||
// Command (2 bytes)
|
||||
cursor.write_u16_le(self.command.into());
|
||||
// CreditRequest/CreditResponse (2 bytes)
|
||||
cursor.write_u16_le(self.credits);
|
||||
// Flags (4 bytes)
|
||||
cursor.write_u32_le(self.flags.bits());
|
||||
// NextCommand (4 bytes)
|
||||
cursor.write_u32_le(self.next_command);
|
||||
// MessageId (8 bytes)
|
||||
cursor.write_u64_le(self.message_id.0);
|
||||
|
||||
// Bytes 32-39: async or sync variant
|
||||
if self.flags.is_async() {
|
||||
// AsyncId (8 bytes)
|
||||
cursor.write_u64_le(self.async_id.unwrap_or(0));
|
||||
} else {
|
||||
// Reserved (4 bytes)
|
||||
cursor.write_u32_le(0);
|
||||
// TreeId (4 bytes)
|
||||
cursor.write_u32_le(self.tree_id.map_or(0, |t| t.0));
|
||||
}
|
||||
|
||||
// SessionId (8 bytes)
|
||||
cursor.write_u64_le(self.session_id.0);
|
||||
// Signature (16 bytes)
|
||||
cursor.write_bytes(&self.signature);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for Header {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// ProtocolId (4 bytes)
|
||||
let proto = cursor.read_bytes(4)?;
|
||||
if proto != PROTOCOL_ID {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid SMB2 protocol ID: expected {:02X?}, got {:02X?}",
|
||||
PROTOCOL_ID, proto
|
||||
)));
|
||||
}
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Header::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid SMB2 header structure size: expected {}, got {}",
|
||||
Header::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// CreditCharge (2 bytes)
|
||||
let credit_charge = CreditCharge(cursor.read_u16_le()?);
|
||||
|
||||
// Status (4 bytes)
|
||||
let status = NtStatus(cursor.read_u32_le()?);
|
||||
|
||||
// Command (2 bytes)
|
||||
let command_raw = cursor.read_u16_le()?;
|
||||
let command = Command::try_from(command_raw).map_err(|_| {
|
||||
Error::invalid_data(format!("invalid SMB2 command code: 0x{:04X}", command_raw))
|
||||
})?;
|
||||
|
||||
// CreditRequest/CreditResponse (2 bytes)
|
||||
let credits = cursor.read_u16_le()?;
|
||||
|
||||
// Flags (4 bytes)
|
||||
let flags = HeaderFlags::new(cursor.read_u32_le()?);
|
||||
|
||||
// NextCommand (4 bytes)
|
||||
let next_command = cursor.read_u32_le()?;
|
||||
|
||||
// MessageId (8 bytes)
|
||||
let message_id = MessageId(cursor.read_u64_le()?);
|
||||
|
||||
// Bytes 32-39: async or sync variant
|
||||
let (tree_id, async_id) = if flags.is_async() {
|
||||
let async_id = cursor.read_u64_le()?;
|
||||
(None, Some(async_id))
|
||||
} else {
|
||||
let _reserved = cursor.read_u32_le()?;
|
||||
let tree_id = TreeId(cursor.read_u32_le()?);
|
||||
(Some(tree_id), None)
|
||||
};
|
||||
|
||||
// SessionId (8 bytes)
|
||||
let session_id = SessionId(cursor.read_u64_le()?);
|
||||
|
||||
// Signature (16 bytes)
|
||||
let sig_bytes = cursor.read_bytes(16)?;
|
||||
let mut signature = [0u8; 16];
|
||||
signature.copy_from_slice(sig_bytes);
|
||||
|
||||
Ok(Header {
|
||||
credit_charge,
|
||||
status,
|
||||
command,
|
||||
credits,
|
||||
flags,
|
||||
next_command,
|
||||
message_id,
|
||||
tree_id,
|
||||
async_id,
|
||||
session_id,
|
||||
signature,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// SMB2 ERROR Response body (spec section 2.2.2).
|
||||
///
|
||||
/// Sent by the server when a request fails. The structure is:
|
||||
/// - StructureSize (2 bytes, must be 9)
|
||||
/// - ErrorContextCount (1 byte)
|
||||
/// - Reserved (1 byte)
|
||||
/// - ByteCount (4 bytes)
|
||||
/// - ErrorData (variable, ByteCount bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ErrorResponse {
|
||||
/// Number of error contexts (SMB 3.1.1 only, otherwise 0).
|
||||
pub error_context_count: u8,
|
||||
/// Variable-length error data.
|
||||
pub error_data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 9;
|
||||
}
|
||||
|
||||
impl Pack for ErrorResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// ErrorContextCount (1 byte)
|
||||
cursor.write_u8(self.error_context_count);
|
||||
// Reserved (1 byte)
|
||||
cursor.write_u8(0);
|
||||
// ByteCount (4 bytes)
|
||||
cursor.write_u32_le(self.error_data.len() as u32);
|
||||
// ErrorData (variable)
|
||||
cursor.write_bytes(&self.error_data);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for ErrorResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid ErrorResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// ErrorContextCount (1 byte)
|
||||
let error_context_count = cursor.read_u8()?;
|
||||
|
||||
// Reserved (1 byte)
|
||||
let _reserved = cursor.read_u8()?;
|
||||
|
||||
// ByteCount (4 bytes)
|
||||
let byte_count = cursor.read_u32_le()? as usize;
|
||||
|
||||
// ErrorData (variable)
|
||||
let error_data = if byte_count > 0 {
|
||||
cursor.read_bytes_bounded(byte_count)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(ErrorResponse {
|
||||
error_context_count,
|
||||
error_data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── Header tests ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn pack_request_header_produces_64_bytes_with_correct_magic() {
|
||||
let header = Header::new_request(Command::Negotiate);
|
||||
let mut cursor = WriteCursor::new();
|
||||
header.pack(&mut cursor);
|
||||
let bytes = cursor.into_inner();
|
||||
|
||||
assert_eq!(bytes.len(), Header::SIZE);
|
||||
assert_eq!(&bytes[0..4], &PROTOCOL_ID);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unpack_known_64_byte_buffer() {
|
||||
// Build a known buffer manually: sync Negotiate request
|
||||
let mut buf = [0u8; 64];
|
||||
// ProtocolId
|
||||
buf[0..4].copy_from_slice(&PROTOCOL_ID);
|
||||
// StructureSize = 64
|
||||
buf[4..6].copy_from_slice(&64u16.to_le_bytes());
|
||||
// CreditCharge = 1
|
||||
buf[6..8].copy_from_slice(&1u16.to_le_bytes());
|
||||
// Status = SUCCESS (0)
|
||||
buf[8..12].copy_from_slice(&0u32.to_le_bytes());
|
||||
// Command = Negotiate (0)
|
||||
buf[12..14].copy_from_slice(&0u16.to_le_bytes());
|
||||
// Credits = 31
|
||||
buf[14..16].copy_from_slice(&31u16.to_le_bytes());
|
||||
// Flags = 0 (sync, request)
|
||||
buf[16..20].copy_from_slice(&0u32.to_le_bytes());
|
||||
// NextCommand = 0
|
||||
buf[20..24].copy_from_slice(&0u32.to_le_bytes());
|
||||
// MessageId = 42
|
||||
buf[24..32].copy_from_slice(&42u64.to_le_bytes());
|
||||
// Reserved = 0
|
||||
buf[32..36].copy_from_slice(&0u32.to_le_bytes());
|
||||
// TreeId = 7
|
||||
buf[36..40].copy_from_slice(&7u32.to_le_bytes());
|
||||
// SessionId = 0x1234
|
||||
buf[40..48].copy_from_slice(&0x1234u64.to_le_bytes());
|
||||
// Signature = all zeros
|
||||
// (already zero)
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let header = Header::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(header.credit_charge, CreditCharge(1));
|
||||
assert_eq!(header.status, NtStatus::SUCCESS);
|
||||
assert_eq!(header.command, Command::Negotiate);
|
||||
assert_eq!(header.credits, 31);
|
||||
assert!(!header.flags.is_async());
|
||||
assert!(!header.flags.is_response());
|
||||
assert_eq!(header.next_command, 0);
|
||||
assert_eq!(header.message_id, MessageId(42));
|
||||
assert_eq!(header.tree_id, Some(TreeId(7)));
|
||||
assert_eq!(header.async_id, None);
|
||||
assert_eq!(header.session_id, SessionId(0x1234));
|
||||
assert_eq!(header.signature, [0u8; 16]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_sync_header() {
|
||||
let original = Header {
|
||||
credit_charge: CreditCharge(3),
|
||||
status: NtStatus::ACCESS_DENIED,
|
||||
command: Command::Read,
|
||||
credits: 10,
|
||||
flags: {
|
||||
let mut f = HeaderFlags::default();
|
||||
f.set_response();
|
||||
f
|
||||
},
|
||||
next_command: 0,
|
||||
message_id: MessageId(99),
|
||||
tree_id: Some(TreeId(42)),
|
||||
async_id: None,
|
||||
session_id: SessionId(0xDEAD_BEEF),
|
||||
signature: [
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
|
||||
0x0F, 0x10,
|
||||
],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
assert_eq!(bytes.len(), Header::SIZE);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = Header::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.credit_charge, original.credit_charge);
|
||||
assert_eq!(decoded.status, original.status);
|
||||
assert_eq!(decoded.command, original.command);
|
||||
assert_eq!(decoded.credits, original.credits);
|
||||
assert_eq!(decoded.flags.bits(), original.flags.bits());
|
||||
assert_eq!(decoded.next_command, original.next_command);
|
||||
assert_eq!(decoded.message_id, original.message_id);
|
||||
assert_eq!(decoded.tree_id, original.tree_id);
|
||||
assert_eq!(decoded.async_id, original.async_id);
|
||||
assert_eq!(decoded.session_id, original.session_id);
|
||||
assert_eq!(decoded.signature, original.signature);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_magic_bytes_returns_error() {
|
||||
let mut buf = [0u8; 64];
|
||||
// Wrong magic
|
||||
buf[0..4].copy_from_slice(&[0xFF, b'X', b'Y', b'Z']);
|
||||
buf[4..6].copy_from_slice(&64u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = Header::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("protocol ID"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_structure_size_returns_error() {
|
||||
let mut buf = [0u8; 64];
|
||||
buf[0..4].copy_from_slice(&PROTOCOL_ID);
|
||||
// Wrong structure size
|
||||
buf[4..6].copy_from_slice(&32u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = Header::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn async_header_pack_unpack() {
|
||||
let mut flags = HeaderFlags::default();
|
||||
flags.set_async();
|
||||
flags.set_response();
|
||||
|
||||
let original = Header {
|
||||
credit_charge: CreditCharge(0),
|
||||
status: NtStatus::PENDING,
|
||||
command: Command::ChangeNotify,
|
||||
credits: 1,
|
||||
flags,
|
||||
next_command: 0,
|
||||
message_id: MessageId(8),
|
||||
tree_id: None,
|
||||
async_id: Some(0x0000_0000_0000_0008),
|
||||
session_id: SessionId(0x0000_0000_0853_27D7),
|
||||
signature: [0u8; 16],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
assert_eq!(bytes.len(), Header::SIZE);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = Header::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.flags.is_async());
|
||||
assert_eq!(decoded.async_id, Some(8));
|
||||
assert_eq!(decoded.tree_id, None);
|
||||
assert_eq!(decoded.command, Command::ChangeNotify);
|
||||
assert_eq!(decoded.status, NtStatus::PENDING);
|
||||
assert_eq!(decoded.session_id, SessionId(0x0000_0000_0853_27D7));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_header_has_tree_id_and_no_async_id() {
|
||||
let header = Header::new_request(Command::Create);
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
header.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = Header::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(!decoded.flags.is_async());
|
||||
assert!(decoded.tree_id.is_some());
|
||||
assert_eq!(decoded.async_id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signature_field_preserved() {
|
||||
let sig = [
|
||||
0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
|
||||
0x99, 0x00,
|
||||
];
|
||||
let mut header = Header::new_request(Command::Echo);
|
||||
header.signature = sig;
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
header.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = Header::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.signature, sig);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_request_produces_correct_defaults() {
|
||||
let header = Header::new_request(Command::Write);
|
||||
|
||||
assert_eq!(header.command, Command::Write);
|
||||
assert_eq!(header.credit_charge, CreditCharge(0));
|
||||
assert_eq!(header.status, NtStatus::SUCCESS);
|
||||
assert_eq!(header.credits, 1);
|
||||
assert!(!header.flags.is_response());
|
||||
assert!(!header.flags.is_async());
|
||||
assert_eq!(header.next_command, 0);
|
||||
assert_eq!(header.message_id, MessageId(0));
|
||||
assert_eq!(header.tree_id, Some(TreeId(0)));
|
||||
assert_eq!(header.async_id, None);
|
||||
assert_eq!(header.session_id, SessionId(0));
|
||||
assert_eq!(header.signature, [0u8; 16]);
|
||||
assert!(!header.is_response());
|
||||
}
|
||||
|
||||
// ── ErrorResponse tests ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn error_response_pack_unpack_empty() {
|
||||
let original = ErrorResponse {
|
||||
error_context_count: 0,
|
||||
error_data: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// StructureSize(2) + ErrorContextCount(1) + Reserved(1) + ByteCount(4) = 8
|
||||
assert_eq!(bytes.len(), 8);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ErrorResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.error_context_count, 0);
|
||||
assert!(decoded.error_data.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_response_pack_unpack_with_data() {
|
||||
let data = vec![0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE];
|
||||
let original = ErrorResponse {
|
||||
error_context_count: 1,
|
||||
error_data: data.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// 8 bytes fixed + 6 bytes data
|
||||
assert_eq!(bytes.len(), 14);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ErrorResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.error_context_count, 1);
|
||||
assert_eq!(decoded.error_data, data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_response_roundtrip() {
|
||||
let original = ErrorResponse {
|
||||
error_context_count: 2,
|
||||
error_data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ErrorResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.error_context_count, original.error_context_count);
|
||||
assert_eq!(decoded.error_data, original.error_data);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{
|
||||
arb_command, arb_credit_charge, arb_header_flags, arb_message_id, arb_nt_status,
|
||||
arb_session_id, arb_small_bytes, arb_tree_id,
|
||||
};
|
||||
use proptest::prelude::*;
|
||||
|
||||
/// Generate a `Header` whose `flags.is_async()` matches which of
|
||||
/// `tree_id`/`async_id` is set. Any other combination wouldn't round-trip
|
||||
/// (pack writes one or the other based on flags, and clears the other on
|
||||
/// unpack), so we never generate it.
|
||||
fn arb_header() -> impl Strategy<Value = Header> {
|
||||
(
|
||||
arb_credit_charge(),
|
||||
arb_nt_status(),
|
||||
arb_command(),
|
||||
any::<u16>(),
|
||||
arb_header_flags(),
|
||||
any::<u32>(),
|
||||
arb_message_id(),
|
||||
any::<bool>(),
|
||||
arb_tree_id(),
|
||||
any::<u64>(),
|
||||
arb_session_id(),
|
||||
any::<[u8; 16]>(),
|
||||
)
|
||||
.prop_map(
|
||||
|(
|
||||
credit_charge,
|
||||
status,
|
||||
command,
|
||||
credits,
|
||||
raw_flags,
|
||||
next_command,
|
||||
message_id,
|
||||
make_async,
|
||||
tree_id,
|
||||
async_id,
|
||||
session_id,
|
||||
signature,
|
||||
)| {
|
||||
// Force `flags.ASYNC_COMMAND` to match `make_async` so
|
||||
// the pack path and the `Option<T>` fields agree.
|
||||
let flags = if make_async {
|
||||
let mut f = raw_flags;
|
||||
f.set(HeaderFlags::ASYNC_COMMAND);
|
||||
f
|
||||
} else {
|
||||
let mut f = raw_flags;
|
||||
f.clear(HeaderFlags::ASYNC_COMMAND);
|
||||
f
|
||||
};
|
||||
let (tree_id, async_id) = if make_async {
|
||||
(None, Some(async_id))
|
||||
} else {
|
||||
(Some(tree_id), None)
|
||||
};
|
||||
Header {
|
||||
credit_charge,
|
||||
status,
|
||||
command,
|
||||
credits,
|
||||
flags,
|
||||
next_command,
|
||||
message_id,
|
||||
tree_id,
|
||||
async_id,
|
||||
session_id,
|
||||
signature,
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn header_pack_unpack(header in arb_header()) {
|
||||
let mut w = WriteCursor::new();
|
||||
header.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
prop_assert_eq!(bytes.len(), Header::SIZE);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = Header::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, header);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_response_pack_unpack(
|
||||
error_context_count in any::<u8>(),
|
||||
error_data in arb_small_bytes(),
|
||||
) {
|
||||
let original = ErrorResponse {
|
||||
error_context_count,
|
||||
error_data,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ErrorResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
479
vendor/smb2/src/msg/ioctl.rs
vendored
Normal file
479
vendor/smb2/src/msg/ioctl.rs
vendored
Normal file
@@ -0,0 +1,479 @@
|
||||
//! SMB2 IOCTL Request and Response (MS-SMB2 sections 2.2.31, 2.2.32).
|
||||
//!
|
||||
//! The IOCTL request sends a control code to a server, optionally with input
|
||||
//! data. The response returns output data from the control operation.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
// ── IOCTL flags ────────────────────────────────────────────────────────
|
||||
|
||||
/// The request is a file system control (FSCTL) request.
|
||||
pub const SMB2_0_IOCTL_IS_FSCTL: u32 = 0x0000_0001;
|
||||
|
||||
// ── Common CtlCode values ──────────────────────────────────────────────
|
||||
|
||||
/// Named pipe transceive operation.
|
||||
pub const FSCTL_PIPE_TRANSCEIVE: u32 = 0x0011_C017;
|
||||
|
||||
/// Server-side copy chunk (read handle).
|
||||
pub const FSCTL_SRV_COPYCHUNK: u32 = 0x0014_40F2;
|
||||
|
||||
/// Server-side copy chunk (write handle).
|
||||
pub const FSCTL_SRV_COPYCHUNK_WRITE: u32 = 0x0014_80F2;
|
||||
|
||||
/// DFS referral request.
|
||||
pub const FSCTL_DFS_GET_REFERRALS: u32 = 0x0006_0194;
|
||||
|
||||
/// Validate negotiate info (SMB 3.x).
|
||||
pub const FSCTL_VALIDATE_NEGOTIATE_INFO: u32 = 0x0014_0204;
|
||||
|
||||
// ── IoctlRequest ───────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 IOCTL Request (MS-SMB2 section 2.2.31).
|
||||
///
|
||||
/// Sent by the client to issue a device or file system control command.
|
||||
/// The fixed part is 56 bytes (StructureSize = 57 indicates 1 byte of
|
||||
/// variable data is included in the fixed size, per SMB2 convention).
|
||||
///
|
||||
/// Layout:
|
||||
/// - StructureSize (2 bytes, must be 57)
|
||||
/// - Reserved (2 bytes)
|
||||
/// - CtlCode (4 bytes)
|
||||
/// - FileId (16 bytes)
|
||||
/// - InputOffset (4 bytes)
|
||||
/// - InputCount (4 bytes)
|
||||
/// - MaxInputResponse (4 bytes)
|
||||
/// - OutputOffset (4 bytes)
|
||||
/// - OutputCount (4 bytes)
|
||||
/// - MaxOutputResponse (4 bytes)
|
||||
/// - Flags (4 bytes)
|
||||
/// - Reserved2 (4 bytes)
|
||||
/// - Buffer (variable, InputCount bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct IoctlRequest {
|
||||
/// The control code for the operation.
|
||||
pub ctl_code: u32,
|
||||
/// The file handle for the operation.
|
||||
pub file_id: FileId,
|
||||
/// Maximum number of input bytes the server can return.
|
||||
pub max_input_response: u32,
|
||||
/// Maximum number of output bytes the server can return.
|
||||
pub max_output_response: u32,
|
||||
/// Flags for the request (for example, `SMB2_0_IOCTL_IS_FSCTL`).
|
||||
pub flags: u32,
|
||||
/// Input data buffer.
|
||||
pub input_data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl IoctlRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 57;
|
||||
|
||||
/// Fixed header size before the variable buffer (56 bytes).
|
||||
const FIXED_SIZE: u32 = 56;
|
||||
}
|
||||
|
||||
impl Pack for IoctlRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// Reserved (2 bytes)
|
||||
cursor.write_u16_le(0);
|
||||
// CtlCode (4 bytes)
|
||||
cursor.write_u32_le(self.ctl_code);
|
||||
// FileId (16 bytes)
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
|
||||
let input_count = self.input_data.len() as u32;
|
||||
// Offset is from the beginning of the SMB2 header per spec.
|
||||
// `start` is the cursor position at the beginning of the body;
|
||||
// in a standalone request this equals Header::SIZE, in a compound
|
||||
// it includes the preceding sub-requests.
|
||||
let input_offset = if input_count > 0 {
|
||||
(start as u32) + Self::FIXED_SIZE
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// InputOffset (4 bytes)
|
||||
cursor.write_u32_le(input_offset);
|
||||
// InputCount (4 bytes)
|
||||
cursor.write_u32_le(input_count);
|
||||
// MaxInputResponse (4 bytes)
|
||||
cursor.write_u32_le(self.max_input_response);
|
||||
// OutputOffset (4 bytes) -- no output data in the request
|
||||
cursor.write_u32_le(0);
|
||||
// OutputCount (4 bytes) -- no output data in the request
|
||||
cursor.write_u32_le(0);
|
||||
// MaxOutputResponse (4 bytes)
|
||||
cursor.write_u32_le(self.max_output_response);
|
||||
// Flags (4 bytes)
|
||||
cursor.write_u32_le(self.flags);
|
||||
// Reserved2 (4 bytes)
|
||||
cursor.write_u32_le(0);
|
||||
// Buffer (variable)
|
||||
cursor.write_bytes(&self.input_data);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for IoctlRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid IoctlRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let _reserved = cursor.read_u16_le()?;
|
||||
let ctl_code = cursor.read_u32_le()?;
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
let _input_offset = cursor.read_u32_le()?;
|
||||
let input_count = cursor.read_u32_le()?;
|
||||
let max_input_response = cursor.read_u32_le()?;
|
||||
let _output_offset = cursor.read_u32_le()?;
|
||||
let _output_count = cursor.read_u32_le()?;
|
||||
let max_output_response = cursor.read_u32_le()?;
|
||||
let flags = cursor.read_u32_le()?;
|
||||
let _reserved2 = cursor.read_u32_le()?;
|
||||
|
||||
let input_data = if input_count > 0 {
|
||||
cursor.read_bytes_bounded(input_count as usize)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(IoctlRequest {
|
||||
ctl_code,
|
||||
file_id: FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
},
|
||||
max_input_response,
|
||||
max_output_response,
|
||||
flags,
|
||||
input_data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── IoctlResponse ──────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 IOCTL Response (MS-SMB2 section 2.2.32).
|
||||
///
|
||||
/// Sent by the server to return the results of an IOCTL operation.
|
||||
///
|
||||
/// Layout:
|
||||
/// - StructureSize (2 bytes, must be 49)
|
||||
/// - Reserved (2 bytes)
|
||||
/// - CtlCode (4 bytes)
|
||||
/// - FileId (16 bytes)
|
||||
/// - InputOffset (4 bytes)
|
||||
/// - InputCount (4 bytes)
|
||||
/// - OutputOffset (4 bytes)
|
||||
/// - OutputCount (4 bytes)
|
||||
/// - Flags (4 bytes)
|
||||
/// - Reserved2 (4 bytes)
|
||||
/// - Buffer (variable -- may contain both input and output data)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct IoctlResponse {
|
||||
/// The control code echoed from the request.
|
||||
pub ctl_code: u32,
|
||||
/// The file handle echoed from the request.
|
||||
pub file_id: FileId,
|
||||
/// Flags echoed from the request.
|
||||
pub flags: u32,
|
||||
/// Output data buffer returned by the server.
|
||||
pub output_data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl IoctlResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 49;
|
||||
|
||||
/// Fixed header size before the variable buffer (48 bytes).
|
||||
const FIXED_SIZE: u32 = 48;
|
||||
}
|
||||
|
||||
impl Pack for IoctlResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// Reserved (2 bytes)
|
||||
cursor.write_u16_le(0);
|
||||
// CtlCode (4 bytes)
|
||||
cursor.write_u32_le(self.ctl_code);
|
||||
// FileId (16 bytes)
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
|
||||
let output_count = self.output_data.len() as u32;
|
||||
// Offset is from the beginning of the SMB2 header per spec.
|
||||
let output_offset = if output_count > 0 {
|
||||
(start as u32) + Self::FIXED_SIZE
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// InputOffset (4 bytes) -- no input data in the response
|
||||
cursor.write_u32_le(0);
|
||||
// InputCount (4 bytes)
|
||||
cursor.write_u32_le(0);
|
||||
// OutputOffset (4 bytes)
|
||||
cursor.write_u32_le(output_offset);
|
||||
// OutputCount (4 bytes)
|
||||
cursor.write_u32_le(output_count);
|
||||
// Flags (4 bytes)
|
||||
cursor.write_u32_le(self.flags);
|
||||
// Reserved2 (4 bytes)
|
||||
cursor.write_u32_le(0);
|
||||
// Buffer (variable)
|
||||
cursor.write_bytes(&self.output_data);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for IoctlResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid IoctlResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let _reserved = cursor.read_u16_le()?;
|
||||
let ctl_code = cursor.read_u32_le()?;
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
let _input_offset = cursor.read_u32_le()?;
|
||||
let _input_count = cursor.read_u32_le()?;
|
||||
let _output_offset = cursor.read_u32_le()?;
|
||||
let output_count = cursor.read_u32_le()?;
|
||||
let flags = cursor.read_u32_le()?;
|
||||
let _reserved2 = cursor.read_u32_le()?;
|
||||
|
||||
let output_data = if output_count > 0 {
|
||||
cursor.read_bytes_bounded(output_count as usize)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(IoctlResponse {
|
||||
ctl_code,
|
||||
file_id: FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
},
|
||||
flags,
|
||||
output_data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── IoctlRequest tests ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn ioctl_request_roundtrip_with_input_data() {
|
||||
let original = IoctlRequest {
|
||||
ctl_code: FSCTL_PIPE_TRANSCEIVE,
|
||||
file_id: FileId {
|
||||
persistent: 0x1122_3344_5566_7788,
|
||||
volatile: 0xAABB_CCDD_EEFF_0011,
|
||||
},
|
||||
max_input_response: 0,
|
||||
max_output_response: 4096,
|
||||
flags: SMB2_0_IOCTL_IS_FSCTL,
|
||||
input_data: vec![0x01, 0x02, 0x03, 0x04, 0x05],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed 56 bytes + 5 bytes input data
|
||||
assert_eq!(bytes.len(), 61);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = IoctlRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.ctl_code, FSCTL_PIPE_TRANSCEIVE);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert_eq!(decoded.max_input_response, 0);
|
||||
assert_eq!(decoded.max_output_response, 4096);
|
||||
assert_eq!(decoded.flags, SMB2_0_IOCTL_IS_FSCTL);
|
||||
assert_eq!(decoded.input_data, vec![0x01, 0x02, 0x03, 0x04, 0x05]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ioctl_request_roundtrip_no_input_data() {
|
||||
let original = IoctlRequest {
|
||||
ctl_code: FSCTL_VALIDATE_NEGOTIATE_INFO,
|
||||
file_id: FileId::SENTINEL,
|
||||
max_input_response: 0,
|
||||
max_output_response: 256,
|
||||
flags: SMB2_0_IOCTL_IS_FSCTL,
|
||||
input_data: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes.len(), 56);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = IoctlRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.ctl_code, FSCTL_VALIDATE_NEGOTIATE_INFO);
|
||||
assert_eq!(decoded.file_id, FileId::SENTINEL);
|
||||
assert!(decoded.input_data.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ioctl_request_wrong_structure_size() {
|
||||
let mut buf = [0u8; 56];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = IoctlRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── IoctlResponse tests ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn ioctl_response_roundtrip_with_output_data() {
|
||||
let original = IoctlResponse {
|
||||
ctl_code: FSCTL_PIPE_TRANSCEIVE,
|
||||
file_id: FileId {
|
||||
persistent: 0x42,
|
||||
volatile: 0x99,
|
||||
},
|
||||
flags: SMB2_0_IOCTL_IS_FSCTL,
|
||||
output_data: vec![0xDE, 0xAD, 0xBE, 0xEF],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed 48 bytes + 4 bytes output data
|
||||
assert_eq!(bytes.len(), 52);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = IoctlResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.ctl_code, FSCTL_PIPE_TRANSCEIVE);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert_eq!(decoded.flags, SMB2_0_IOCTL_IS_FSCTL);
|
||||
assert_eq!(decoded.output_data, vec![0xDE, 0xAD, 0xBE, 0xEF]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ioctl_response_roundtrip_no_output_data() {
|
||||
let original = IoctlResponse {
|
||||
ctl_code: FSCTL_SRV_COPYCHUNK,
|
||||
file_id: FileId::default(),
|
||||
flags: 0,
|
||||
output_data: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes.len(), 48);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = IoctlResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.ctl_code, FSCTL_SRV_COPYCHUNK);
|
||||
assert!(decoded.output_data.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ioctl_response_wrong_structure_size() {
|
||||
let mut buf = [0u8; 48];
|
||||
buf[0..2].copy_from_slice(&42u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = IoctlResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn ioctl_request_pack_unpack(
|
||||
ctl_code in any::<u32>(),
|
||||
file_id in arb_file_id(),
|
||||
max_input_response in any::<u32>(),
|
||||
max_output_response in any::<u32>(),
|
||||
flags in any::<u32>(),
|
||||
input_data in arb_bytes(),
|
||||
) {
|
||||
let original = IoctlRequest {
|
||||
ctl_code,
|
||||
file_id,
|
||||
max_input_response,
|
||||
max_output_response,
|
||||
flags,
|
||||
input_data,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = IoctlRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ioctl_response_pack_unpack(
|
||||
ctl_code in any::<u32>(),
|
||||
file_id in arb_file_id(),
|
||||
flags in any::<u32>(),
|
||||
output_data in arb_bytes(),
|
||||
) {
|
||||
let original = IoctlResponse {
|
||||
ctl_code,
|
||||
file_id,
|
||||
flags,
|
||||
output_data,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = IoctlResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
445
vendor/smb2/src/msg/lock.rs
vendored
Normal file
445
vendor/smb2/src/msg/lock.rs
vendored
Normal file
@@ -0,0 +1,445 @@
|
||||
//! SMB2 LOCK Request and Response (MS-SMB2 sections 2.2.26, 2.2.27).
|
||||
//!
|
||||
//! The LOCK request locks or unlocks byte ranges within a file.
|
||||
//! Multiple ranges can be locked/unlocked in a single request.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
/// Lock flag: shared lock (allows other readers).
|
||||
pub const SMB2_LOCKFLAG_SHARED_LOCK: u32 = 0x0000_0001;
|
||||
|
||||
/// Lock flag: exclusive lock (no other readers or writers).
|
||||
pub const SMB2_LOCKFLAG_EXCLUSIVE_LOCK: u32 = 0x0000_0002;
|
||||
|
||||
/// Lock flag: unlock a previously locked range.
|
||||
pub const SMB2_LOCKFLAG_UNLOCK: u32 = 0x0000_0004;
|
||||
|
||||
/// Lock flag: fail immediately if the lock conflicts.
|
||||
pub const SMB2_LOCKFLAG_FAIL_IMMEDIATELY: u32 = 0x0000_0010;
|
||||
|
||||
/// A single lock element describing a byte range to lock or unlock.
|
||||
///
|
||||
/// Each element is 24 bytes on the wire:
|
||||
/// - Offset (8 bytes)
|
||||
/// - Length (8 bytes)
|
||||
/// - Flags (4 bytes)
|
||||
/// - Reserved (4 bytes)
|
||||
///
|
||||
/// Reference: MS-SMB2 section 2.2.26.1.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct LockElement {
|
||||
/// Starting offset in bytes from where the range begins.
|
||||
pub offset: u64,
|
||||
/// Length of the range in bytes.
|
||||
pub length: u64,
|
||||
/// Flags describing how the range is locked or unlocked.
|
||||
pub flags: u32,
|
||||
}
|
||||
|
||||
impl LockElement {
|
||||
/// Wire size of a single lock element.
|
||||
pub const SIZE: usize = 24;
|
||||
}
|
||||
|
||||
impl Pack for LockElement {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u64_le(self.offset);
|
||||
cursor.write_u64_le(self.length);
|
||||
cursor.write_u32_le(self.flags);
|
||||
cursor.write_u32_le(0); // Reserved
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for LockElement {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let offset = cursor.read_u64_le()?;
|
||||
let length = cursor.read_u64_le()?;
|
||||
let flags = cursor.read_u32_le()?;
|
||||
let _reserved = cursor.read_u32_le()?;
|
||||
|
||||
Ok(LockElement {
|
||||
offset,
|
||||
length,
|
||||
flags,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// SMB2 LOCK Request (MS-SMB2 section 2.2.26).
|
||||
///
|
||||
/// Sent by the client to lock or unlock byte ranges. The fixed portion
|
||||
/// is 48 bytes (StructureSize=48, which includes one `LockElement`):
|
||||
/// - StructureSize (2 bytes, must be 48)
|
||||
/// - LockCount (2 bytes)
|
||||
/// - LockSequenceNumber/Index (4 bytes)
|
||||
/// - FileId (16 bytes)
|
||||
/// - Locks (variable, LockCount x 24 bytes each)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct LockRequest {
|
||||
/// Combined lock sequence number (4 bits) and index (28 bits).
|
||||
/// In SMB 2.0.2 this field is reserved (0).
|
||||
pub lock_sequence: u32,
|
||||
/// File handle to lock ranges on.
|
||||
pub file_id: FileId,
|
||||
/// Array of lock elements. Must contain at least one element.
|
||||
pub locks: Vec<LockElement>,
|
||||
}
|
||||
|
||||
impl LockRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 48;
|
||||
}
|
||||
|
||||
impl Pack for LockRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
cursor.write_u16_le(self.locks.len() as u16); // LockCount
|
||||
cursor.write_u32_le(self.lock_sequence);
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
for lock in &self.locks {
|
||||
lock.pack(cursor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for LockRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid LockRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let lock_count = cursor.read_u16_le()?;
|
||||
let lock_sequence = cursor.read_u32_le()?;
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
|
||||
let mut locks = Vec::with_capacity(lock_count as usize);
|
||||
for _ in 0..lock_count {
|
||||
locks.push(LockElement::unpack(cursor)?);
|
||||
}
|
||||
|
||||
Ok(LockRequest {
|
||||
lock_sequence,
|
||||
file_id: FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
},
|
||||
locks,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// SMB2 LOCK Response (MS-SMB2 section 2.2.27).
|
||||
///
|
||||
/// Sent by the server to confirm a lock operation. The structure is 4 bytes:
|
||||
/// - StructureSize (2 bytes, must be 4)
|
||||
/// - Reserved (2 bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct LockResponse;
|
||||
|
||||
impl LockResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 4;
|
||||
}
|
||||
|
||||
impl Pack for LockResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
cursor.write_u16_le(0); // Reserved
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for LockResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid LockResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let _reserved = cursor.read_u16_le()?;
|
||||
|
||||
Ok(LockResponse)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── LockElement tests ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn lock_element_roundtrip() {
|
||||
let original = LockElement {
|
||||
offset: 0x1000,
|
||||
length: 0x2000,
|
||||
flags: SMB2_LOCKFLAG_EXCLUSIVE_LOCK | SMB2_LOCKFLAG_FAIL_IMMEDIATELY,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes.len(), LockElement::SIZE);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = LockElement::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded, original);
|
||||
}
|
||||
|
||||
// ── LockRequest tests ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn lock_request_single_lock_roundtrip() {
|
||||
let original = LockRequest {
|
||||
lock_sequence: 0,
|
||||
file_id: FileId {
|
||||
persistent: 0xDEAD,
|
||||
volatile: 0xBEEF,
|
||||
},
|
||||
locks: vec![LockElement {
|
||||
offset: 0,
|
||||
length: 4096,
|
||||
flags: SMB2_LOCKFLAG_SHARED_LOCK,
|
||||
}],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed: 24 bytes + 1 lock element (24 bytes) = 48 bytes
|
||||
assert_eq!(bytes.len(), 48);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = LockRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.lock_sequence, original.lock_sequence);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert_eq!(decoded.locks.len(), 1);
|
||||
assert_eq!(decoded.locks[0], original.locks[0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lock_request_multiple_locks_roundtrip() {
|
||||
let original = LockRequest {
|
||||
lock_sequence: 0x1234_5678,
|
||||
file_id: FileId {
|
||||
persistent: 0x1111,
|
||||
volatile: 0x2222,
|
||||
},
|
||||
locks: vec![
|
||||
LockElement {
|
||||
offset: 0,
|
||||
length: 1024,
|
||||
flags: SMB2_LOCKFLAG_EXCLUSIVE_LOCK | SMB2_LOCKFLAG_FAIL_IMMEDIATELY,
|
||||
},
|
||||
LockElement {
|
||||
offset: 4096,
|
||||
length: 2048,
|
||||
flags: SMB2_LOCKFLAG_SHARED_LOCK,
|
||||
},
|
||||
LockElement {
|
||||
offset: 8192,
|
||||
length: 512,
|
||||
flags: SMB2_LOCKFLAG_UNLOCK,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed: 24 bytes + 3 lock elements (3 * 24) = 96 bytes
|
||||
assert_eq!(bytes.len(), 96);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = LockRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.lock_sequence, original.lock_sequence);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert_eq!(decoded.locks.len(), 3);
|
||||
assert_eq!(decoded.locks[0], original.locks[0]);
|
||||
assert_eq!(decoded.locks[1], original.locks[1]);
|
||||
assert_eq!(decoded.locks[2], original.locks[2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lock_request_known_bytes() {
|
||||
let mut buf = Vec::new();
|
||||
// StructureSize = 48
|
||||
buf.extend_from_slice(&48u16.to_le_bytes());
|
||||
// LockCount = 1
|
||||
buf.extend_from_slice(&1u16.to_le_bytes());
|
||||
// LockSequence = 0
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
// FileId persistent = 0x10
|
||||
buf.extend_from_slice(&0x10u64.to_le_bytes());
|
||||
// FileId volatile = 0x20
|
||||
buf.extend_from_slice(&0x20u64.to_le_bytes());
|
||||
// LockElement: offset = 0, length = 100, flags = SHARED (1), reserved = 0
|
||||
buf.extend_from_slice(&0u64.to_le_bytes());
|
||||
buf.extend_from_slice(&100u64.to_le_bytes());
|
||||
buf.extend_from_slice(&1u32.to_le_bytes());
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let req = LockRequest::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(req.file_id.persistent, 0x10);
|
||||
assert_eq!(req.file_id.volatile, 0x20);
|
||||
assert_eq!(req.locks.len(), 1);
|
||||
assert_eq!(req.locks[0].offset, 0);
|
||||
assert_eq!(req.locks[0].length, 100);
|
||||
assert_eq!(req.locks[0].flags, SMB2_LOCKFLAG_SHARED_LOCK);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lock_request_wrong_structure_size() {
|
||||
let mut buf = [0u8; 48];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = LockRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── LockResponse tests ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn lock_response_roundtrip() {
|
||||
let original = LockResponse;
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// 2 + 2 = 4 bytes
|
||||
assert_eq!(bytes.len(), 4);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let _decoded = LockResponse::unpack(&mut r).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lock_response_known_bytes() {
|
||||
let mut buf = [0u8; 4];
|
||||
buf[0..2].copy_from_slice(&4u16.to_le_bytes());
|
||||
buf[2..4].copy_from_slice(&0u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let _resp = LockResponse::unpack(&mut cursor).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lock_response_wrong_structure_size() {
|
||||
let mut buf = [0u8; 4];
|
||||
buf[0..2].copy_from_slice(&8u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = LockResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lock_flags_combinations() {
|
||||
// Verify flag constants are distinct and correct
|
||||
assert_eq!(SMB2_LOCKFLAG_SHARED_LOCK, 0x01);
|
||||
assert_eq!(SMB2_LOCKFLAG_EXCLUSIVE_LOCK, 0x02);
|
||||
assert_eq!(SMB2_LOCKFLAG_UNLOCK, 0x04);
|
||||
assert_eq!(SMB2_LOCKFLAG_FAIL_IMMEDIATELY, 0x10);
|
||||
|
||||
// Shared + fail immediately
|
||||
let combined = SMB2_LOCKFLAG_SHARED_LOCK | SMB2_LOCKFLAG_FAIL_IMMEDIATELY;
|
||||
assert_eq!(combined, 0x11);
|
||||
|
||||
// Exclusive + fail immediately
|
||||
let combined = SMB2_LOCKFLAG_EXCLUSIVE_LOCK | SMB2_LOCKFLAG_FAIL_IMMEDIATELY;
|
||||
assert_eq!(combined, 0x12);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::arb_file_id;
|
||||
use proptest::prelude::*;
|
||||
|
||||
fn arb_lock_element() -> impl Strategy<Value = LockElement> {
|
||||
(any::<u64>(), any::<u64>(), any::<u32>()).prop_map(|(offset, length, flags)| LockElement {
|
||||
offset,
|
||||
length,
|
||||
flags,
|
||||
})
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn lock_element_pack_unpack(elem in arb_lock_element()) {
|
||||
let mut w = WriteCursor::new();
|
||||
elem.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
prop_assert_eq!(bytes.len(), LockElement::SIZE);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = LockElement::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, elem);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lock_request_pack_unpack(
|
||||
lock_sequence in any::<u32>(),
|
||||
file_id in arb_file_id(),
|
||||
// MS-SMB2: LockCount must be >= 1, so generate 1..=8.
|
||||
locks in prop::collection::vec(arb_lock_element(), 1..=8),
|
||||
) {
|
||||
let original = LockRequest {
|
||||
lock_sequence,
|
||||
file_id,
|
||||
locks,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = LockRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lock_response_pack_unpack(_ in any::<bool>()) {
|
||||
// LockResponse is a unit struct; there's nothing to vary, but
|
||||
// running it through the proptest harness keeps the coverage
|
||||
// map uniform.
|
||||
let original = LockResponse;
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = LockResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
42
vendor/smb2/src/msg/logoff.rs
vendored
Normal file
42
vendor/smb2/src/msg/logoff.rs
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
//! SMB2 LOGOFF request and response (spec sections 2.2.7, 2.2.8).
|
||||
//!
|
||||
//! Logoff messages request and confirm termination of a session.
|
||||
//! Both request and response contain only a StructureSize field and a
|
||||
//! reserved field, for a total of 4 bytes each.
|
||||
|
||||
super::trivial_message! {
|
||||
/// SMB2 LOGOFF request (spec section 2.2.7).
|
||||
///
|
||||
/// Sent by the client to request termination of a particular session.
|
||||
/// Contains only StructureSize (2 bytes) and Reserved (2 bytes).
|
||||
pub struct LogoffRequest;
|
||||
}
|
||||
|
||||
super::trivial_message! {
|
||||
/// SMB2 LOGOFF response (spec section 2.2.8).
|
||||
///
|
||||
/// Sent by the server to confirm that a LOGOFF request was processed.
|
||||
/// Contains only StructureSize (2 bytes) and Reserved (2 bytes).
|
||||
pub struct LogoffResponse;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
super::super::trivial_message_tests!(
|
||||
LogoffRequest,
|
||||
logoff_request_known_bytes,
|
||||
logoff_request_roundtrip,
|
||||
logoff_request_wrong_structure_size,
|
||||
logoff_request_too_short
|
||||
);
|
||||
|
||||
super::super::trivial_message_tests!(
|
||||
LogoffResponse,
|
||||
logoff_response_known_bytes,
|
||||
logoff_response_roundtrip,
|
||||
logoff_response_wrong_structure_size,
|
||||
logoff_response_too_short
|
||||
);
|
||||
}
|
||||
152
vendor/smb2/src/msg/mod.rs
vendored
Normal file
152
vendor/smb2/src/msg/mod.rs
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
//! Wire format message structs for SMB2/3.
|
||||
//!
|
||||
//! Each sub-module corresponds to one SMB2 command type with its
|
||||
//! request and response structures.
|
||||
//!
|
||||
//! Most users don't need this module directly -- use [`SmbClient`](crate::SmbClient)
|
||||
//! for high-level file operations.
|
||||
|
||||
// Wire format internals, comments would be pretty redundant. Public API docs are enforced at the crate level.
|
||||
#![allow(missing_docs)]
|
||||
|
||||
/// Generates a trivial 4-byte SMB2 stub message (StructureSize + Reserved).
|
||||
///
|
||||
/// Many SMB2 commands (echo, cancel, logoff, tree_disconnect) have request
|
||||
/// and/or response structs that are identical: 2-byte StructureSize (always 4)
|
||||
/// plus 2-byte Reserved. This macro generates the struct definition and its
|
||||
/// `Pack`/`Unpack` impls from a single declaration.
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// ```ignore
|
||||
/// trivial_message! {
|
||||
/// /// Doc comment for the struct.
|
||||
/// pub struct EchoRequest;
|
||||
/// }
|
||||
/// ```
|
||||
macro_rules! trivial_message {
|
||||
(
|
||||
$(#[$meta:meta])*
|
||||
pub struct $name:ident;
|
||||
) => {
|
||||
$(#[$meta])*
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct $name;
|
||||
|
||||
impl $name {
|
||||
pub const STRUCTURE_SIZE: u16 = 4;
|
||||
}
|
||||
|
||||
impl crate::pack::Pack for $name {
|
||||
fn pack(&self, cursor: &mut crate::pack::WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// Reserved (2 bytes)
|
||||
cursor.write_u16_le(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::pack::Unpack for $name {
|
||||
fn unpack(cursor: &mut crate::pack::ReadCursor<'_>) -> crate::error::Result<Self> {
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(crate::Error::invalid_data(format!(
|
||||
"invalid {} structure size: expected {}, got {}",
|
||||
stringify!($name),
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Reserved (2 bytes)
|
||||
let _reserved = cursor.read_u16_le()?;
|
||||
|
||||
Ok($name)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use trivial_message;
|
||||
|
||||
/// Generates a minimal test suite for a trivial 4-byte message type.
|
||||
///
|
||||
/// Tests: known bytes, pack-unpack roundtrip, wrong structure size, and
|
||||
/// truncated input. These four tests cover all interesting behavior for
|
||||
/// types produced by [`trivial_message!`].
|
||||
#[cfg(test)]
|
||||
macro_rules! trivial_message_tests {
|
||||
($type:ident, $known:ident, $roundtrip:ident, $wrong_size:ident, $short:ident) => {
|
||||
#[test]
|
||||
fn $known() {
|
||||
let msg = $type;
|
||||
let mut cursor = crate::pack::WriteCursor::new();
|
||||
crate::pack::Pack::pack(&msg, &mut cursor);
|
||||
let bytes = cursor.into_inner();
|
||||
// StructureSize=4 (LE), Reserved=0
|
||||
assert_eq!(bytes, [0x04, 0x00, 0x00, 0x00]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn $roundtrip() {
|
||||
let original = $type;
|
||||
let mut w = crate::pack::WriteCursor::new();
|
||||
crate::pack::Pack::pack(&original, &mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = crate::pack::ReadCursor::new(&bytes);
|
||||
let decoded = <$type as crate::pack::Unpack>::unpack(&mut r).unwrap();
|
||||
assert_eq!(decoded, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn $wrong_size() {
|
||||
let bytes = [0x08, 0x00, 0x00, 0x00];
|
||||
let mut cursor = crate::pack::ReadCursor::new(&bytes);
|
||||
let result = <$type as crate::pack::Unpack>::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn $short() {
|
||||
let bytes = [0x04, 0x00];
|
||||
let mut cursor = crate::pack::ReadCursor::new(&bytes);
|
||||
let result = <$type as crate::pack::Unpack>::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) use trivial_message_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod roundtrip_strategies;
|
||||
|
||||
pub mod cancel;
|
||||
pub mod change_notify;
|
||||
pub mod close;
|
||||
pub mod create;
|
||||
pub mod dfs;
|
||||
pub mod echo;
|
||||
pub mod flush;
|
||||
pub mod header;
|
||||
pub mod ioctl;
|
||||
pub mod lock;
|
||||
pub mod logoff;
|
||||
pub mod negotiate;
|
||||
pub mod oplock_break;
|
||||
pub mod query_directory;
|
||||
pub mod query_info;
|
||||
pub mod read;
|
||||
pub mod session_setup;
|
||||
pub mod set_info;
|
||||
pub mod transform;
|
||||
pub mod tree_connect;
|
||||
pub mod tree_disconnect;
|
||||
pub mod write;
|
||||
|
||||
pub use header::{ErrorResponse, Header, PROTOCOL_ID};
|
||||
1228
vendor/smb2/src/msg/negotiate.rs
vendored
Normal file
1228
vendor/smb2/src/msg/negotiate.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
262
vendor/smb2/src/msg/oplock_break.rs
vendored
Normal file
262
vendor/smb2/src/msg/oplock_break.rs
vendored
Normal file
@@ -0,0 +1,262 @@
|
||||
//! SMB2 Oplock Break Notification, Acknowledgment, and Response
|
||||
//! (MS-SMB2 sections 2.2.23, 2.2.24, 2.2.25).
|
||||
//!
|
||||
//! All three oplock break messages share an identical 24-byte wire format:
|
||||
//! - StructureSize (2 bytes, must be 24)
|
||||
//! - OplockLevel (1 byte)
|
||||
//! - Reserved (1 byte)
|
||||
//! - Reserved2 (4 bytes)
|
||||
//! - FileId (16 bytes)
|
||||
//!
|
||||
//! We define one shared struct and provide type aliases for each role.
|
||||
//!
|
||||
//! Note: Lease break notification/acknowledgment/response (sections 2.2.23.2,
|
||||
//! 2.2.24.2, 2.2.25.2) use a different structure with LeaseKey, LeaseState,
|
||||
//! etc. Lease break handling is deferred to a future implementation.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::{FileId, OplockLevel};
|
||||
use crate::Error;
|
||||
|
||||
// ── OplockBreak (shared struct) ────────────────────────────────────────
|
||||
|
||||
/// Shared wire format for oplock break notification, acknowledgment, and
|
||||
/// response messages (MS-SMB2 sections 2.2.23, 2.2.24, 2.2.25).
|
||||
///
|
||||
/// All three messages have an identical 24-byte layout. The message's role
|
||||
/// (notification vs acknowledgment vs response) is determined by the header's
|
||||
/// command code and flags, not by this structure.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct OplockBreak {
|
||||
/// The oplock level.
|
||||
pub oplock_level: OplockLevel,
|
||||
/// The file handle associated with the oplock.
|
||||
pub file_id: FileId,
|
||||
}
|
||||
|
||||
impl OplockBreak {
|
||||
pub const STRUCTURE_SIZE: u16 = 24;
|
||||
}
|
||||
|
||||
impl Pack for OplockBreak {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// OplockLevel (1 byte)
|
||||
cursor.write_u8(self.oplock_level as u8);
|
||||
// Reserved (1 byte)
|
||||
cursor.write_u8(0);
|
||||
// Reserved2 (4 bytes)
|
||||
cursor.write_u32_le(0);
|
||||
// FileId (16 bytes)
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for OplockBreak {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid OplockBreak structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let oplock_level = OplockLevel::try_from(cursor.read_u8()?)?;
|
||||
let _reserved = cursor.read_u8()?;
|
||||
let _reserved2 = cursor.read_u32_le()?;
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
|
||||
Ok(OplockBreak {
|
||||
oplock_level,
|
||||
file_id: FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Oplock break notification (server to client, MS-SMB2 section 2.2.23).
|
||||
///
|
||||
/// Arrives with `MessageId = 0xFFFFFFFFFFFFFFFF` (unsolicited).
|
||||
pub type OplockBreakNotification = OplockBreak;
|
||||
|
||||
/// Oplock break acknowledgment (client to server, MS-SMB2 section 2.2.24).
|
||||
pub type OplockBreakAcknowledgment = OplockBreak;
|
||||
|
||||
/// Oplock break response (server to client after ack, MS-SMB2 section 2.2.25).
|
||||
pub type OplockBreakResponse = OplockBreak;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── OplockBreakNotification tests ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn oplock_break_notification_roundtrip() {
|
||||
let original = OplockBreakNotification {
|
||||
oplock_level: OplockLevel::LevelII,
|
||||
file_id: FileId {
|
||||
persistent: 0x1122_3344_5566_7788,
|
||||
volatile: 0xAABB_CCDD_EEFF_0011,
|
||||
},
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed 24 bytes
|
||||
assert_eq!(bytes.len(), 24);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = OplockBreakNotification::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.oplock_level, OplockLevel::LevelII);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oplock_break_notification_exclusive_level() {
|
||||
let original = OplockBreakNotification {
|
||||
oplock_level: OplockLevel::Exclusive,
|
||||
file_id: FileId {
|
||||
persistent: 0x42,
|
||||
volatile: 0x99,
|
||||
},
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = OplockBreakNotification::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.oplock_level, OplockLevel::Exclusive);
|
||||
assert_eq!(decoded.file_id.persistent, 0x42);
|
||||
assert_eq!(decoded.file_id.volatile, 0x99);
|
||||
}
|
||||
|
||||
// ── OplockBreakAcknowledgment tests ───────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn oplock_break_acknowledgment_roundtrip() {
|
||||
let original = OplockBreakAcknowledgment {
|
||||
oplock_level: OplockLevel::None,
|
||||
file_id: FileId {
|
||||
persistent: 0xDEAD,
|
||||
volatile: 0xBEEF,
|
||||
},
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes.len(), 24);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = OplockBreakAcknowledgment::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.oplock_level, OplockLevel::None);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
}
|
||||
|
||||
// ── OplockBreakResponse tests ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn oplock_break_response_roundtrip() {
|
||||
let original = OplockBreakResponse {
|
||||
oplock_level: OplockLevel::Batch,
|
||||
file_id: FileId {
|
||||
persistent: 0xCAFE,
|
||||
volatile: 0xFACE,
|
||||
},
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes.len(), 24);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = OplockBreakResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.oplock_level, OplockLevel::Batch);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
}
|
||||
|
||||
// ── Error tests ───────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn oplock_break_wrong_structure_size() {
|
||||
let mut buf = [0u8; 24];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = OplockBreak::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// Roundtrip property tests live in `roundtrip_props` at file end.
|
||||
|
||||
#[test]
|
||||
fn oplock_break_reserved_fields_ignored() {
|
||||
let mut buf = [0u8; 24];
|
||||
// StructureSize = 24
|
||||
buf[0..2].copy_from_slice(&24u16.to_le_bytes());
|
||||
// OplockLevel = LEVEL_II
|
||||
buf[2] = OplockLevel::LevelII as u8;
|
||||
// Reserved = 0xFF (should be ignored)
|
||||
buf[3] = 0xFF;
|
||||
// Reserved2 = 0xDEADBEEF (should be ignored)
|
||||
buf[4..8].copy_from_slice(&0xDEAD_BEEFu32.to_le_bytes());
|
||||
// FileId persistent = 1
|
||||
buf[8..16].copy_from_slice(&1u64.to_le_bytes());
|
||||
// FileId volatile = 2
|
||||
buf[16..24].copy_from_slice(&2u64.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let decoded = OplockBreak::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(decoded.oplock_level, OplockLevel::LevelII);
|
||||
assert_eq!(decoded.file_id.persistent, 1);
|
||||
assert_eq!(decoded.file_id.volatile, 2);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{arb_file_id, arb_oplock_level};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn oplock_break_pack_unpack(
|
||||
oplock_level in arb_oplock_level(),
|
||||
file_id in arb_file_id(),
|
||||
) {
|
||||
let original = OplockBreak { oplock_level, file_id };
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = OplockBreak::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
476
vendor/smb2/src/msg/query_directory.rs
vendored
Normal file
476
vendor/smb2/src/msg/query_directory.rs
vendored
Normal file
@@ -0,0 +1,476 @@
|
||||
//! SMB2 QUERY_DIRECTORY request and response (spec sections 2.2.33, 2.2.34).
|
||||
//!
|
||||
//! Used by the client to enumerate directory contents. The request specifies
|
||||
//! a search pattern (typically `"*"`) and the response contains directory
|
||||
//! entries in the requested information class format.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::msg::header::Header;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
// ── Enums / flags ────────────────────────────────────────────────────────
|
||||
|
||||
/// File information class for directory queries (MS-SMB2 2.2.33).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum FileInformationClass {
|
||||
/// Basic directory information.
|
||||
FileDirectoryInformation = 0x01,
|
||||
/// Full directory information.
|
||||
FileFullDirectoryInformation = 0x02,
|
||||
/// Both short and long name information.
|
||||
FileBothDirectoryInformation = 0x03,
|
||||
/// File names only.
|
||||
FileNamesInformation = 0x0C,
|
||||
/// Both short and long name information with file IDs.
|
||||
FileIdBothDirectoryInformation = 0x25,
|
||||
/// Full directory information with file IDs.
|
||||
FileIdFullDirectoryInformation = 0x26,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for FileInformationClass {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self> {
|
||||
match value {
|
||||
0x01 => Ok(Self::FileDirectoryInformation),
|
||||
0x02 => Ok(Self::FileFullDirectoryInformation),
|
||||
0x03 => Ok(Self::FileBothDirectoryInformation),
|
||||
0x0C => Ok(Self::FileNamesInformation),
|
||||
0x25 => Ok(Self::FileIdBothDirectoryInformation),
|
||||
0x26 => Ok(Self::FileIdFullDirectoryInformation),
|
||||
_ => Err(Error::invalid_data(format!(
|
||||
"invalid FileInformationClass: 0x{:02X}",
|
||||
value
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Query directory flags (MS-SMB2 2.2.33).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct QueryDirectoryFlags(pub u8);
|
||||
|
||||
impl QueryDirectoryFlags {
|
||||
/// Restart the enumeration from the beginning.
|
||||
pub const RESTART_SCANS: u8 = 0x01;
|
||||
/// Return only a single entry.
|
||||
pub const RETURN_SINGLE_ENTRY: u8 = 0x02;
|
||||
/// Resume from the specified file index.
|
||||
pub const INDEX_SPECIFIED: u8 = 0x04;
|
||||
/// Reopen the directory and change the search pattern.
|
||||
pub const REOPEN: u8 = 0x10;
|
||||
}
|
||||
|
||||
// ── QueryDirectoryRequest ────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 QUERY_DIRECTORY request (spec section 2.2.33).
|
||||
///
|
||||
/// Sent by the client to enumerate files in a directory.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct QueryDirectoryRequest {
|
||||
/// The type of information to return for each directory entry.
|
||||
pub file_information_class: FileInformationClass,
|
||||
/// Flags controlling the query behavior.
|
||||
pub flags: QueryDirectoryFlags,
|
||||
/// Byte offset within the directory to resume enumeration from.
|
||||
pub file_index: u32,
|
||||
/// Handle to the directory being queried.
|
||||
pub file_id: FileId,
|
||||
/// Maximum number of bytes the server can return.
|
||||
pub output_buffer_length: u32,
|
||||
/// Search pattern (for example, `"*"` for all files).
|
||||
pub file_name: String,
|
||||
}
|
||||
|
||||
impl QueryDirectoryRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 33;
|
||||
}
|
||||
|
||||
impl Pack for QueryDirectoryRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// FileInformationClass (1 byte)
|
||||
cursor.write_u8(self.file_information_class as u8);
|
||||
// Flags (1 byte)
|
||||
cursor.write_u8(self.flags.0);
|
||||
// FileIndex (4 bytes)
|
||||
cursor.write_u32_le(self.file_index);
|
||||
// FileId (16 bytes)
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
// FileNameOffset (2 bytes) -- placeholder
|
||||
let name_offset_pos = cursor.position();
|
||||
cursor.write_u16_le(0);
|
||||
// FileNameLength (2 bytes) -- placeholder
|
||||
let name_length_pos = cursor.position();
|
||||
cursor.write_u16_le(0);
|
||||
// OutputBufferLength (4 bytes)
|
||||
cursor.write_u32_le(self.output_buffer_length);
|
||||
|
||||
if self.file_name.is_empty() {
|
||||
// No search pattern: FileNameOffset and FileNameLength stay 0
|
||||
// per spec section 2.2.33. Write 1 padding byte to satisfy
|
||||
// StructureSize=33 (32 fixed + 1 byte buffer minimum).
|
||||
cursor.write_u8(0);
|
||||
} else {
|
||||
// Buffer: filename pattern in UTF-16LE.
|
||||
// Offset is from the beginning of the SMB2 header per spec.
|
||||
let name_offset = Header::SIZE + (cursor.position() - start);
|
||||
let name_start = cursor.position();
|
||||
cursor.write_utf16_le(&self.file_name);
|
||||
let name_byte_len = cursor.position() - name_start;
|
||||
|
||||
// Backpatch
|
||||
cursor.set_u16_le_at(name_offset_pos, name_offset as u16);
|
||||
cursor.set_u16_le_at(name_length_pos, name_byte_len as u16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for QueryDirectoryRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid QueryDirectoryRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// FileInformationClass (1 byte)
|
||||
let info_class = FileInformationClass::try_from(cursor.read_u8()?)?;
|
||||
// Flags (1 byte)
|
||||
let flags = QueryDirectoryFlags(cursor.read_u8()?);
|
||||
// FileIndex (4 bytes)
|
||||
let file_index = cursor.read_u32_le()?;
|
||||
// FileId (16 bytes)
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
let file_id = FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
};
|
||||
// FileNameOffset (2 bytes)
|
||||
let name_offset = cursor.read_u16_le()? as usize;
|
||||
// FileNameLength (2 bytes)
|
||||
let name_length = cursor.read_u16_le()? as usize;
|
||||
// OutputBufferLength (4 bytes)
|
||||
let output_buffer_length = cursor.read_u32_le()?;
|
||||
|
||||
// Read filename
|
||||
// Offset on the wire is from beginning of SMB2 header.
|
||||
let file_name = if name_length > 0 {
|
||||
let current = cursor.position();
|
||||
let body_offset = name_offset.saturating_sub(Header::SIZE);
|
||||
let target = start + body_offset;
|
||||
if target > current {
|
||||
cursor.skip(target - current)?;
|
||||
}
|
||||
cursor.read_utf16_le(name_length)?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
Ok(QueryDirectoryRequest {
|
||||
file_information_class: info_class,
|
||||
flags,
|
||||
file_index,
|
||||
file_id,
|
||||
output_buffer_length,
|
||||
file_name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── QueryDirectoryResponse ───────────────────────────────────────────────
|
||||
|
||||
/// SMB2 QUERY_DIRECTORY response (spec section 2.2.34).
|
||||
///
|
||||
/// Contains directory enumeration data as raw bytes. The format depends
|
||||
/// on the `FileInformationClass` from the request.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct QueryDirectoryResponse {
|
||||
/// Raw output buffer containing directory entries.
|
||||
pub output_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl QueryDirectoryResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 9;
|
||||
}
|
||||
|
||||
impl Pack for QueryDirectoryResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// OutputBufferOffset (2 bytes) -- placeholder
|
||||
let offset_pos = cursor.position();
|
||||
cursor.write_u16_le(0);
|
||||
// OutputBufferLength (4 bytes)
|
||||
cursor.write_u32_le(self.output_buffer.len() as u32);
|
||||
|
||||
// Buffer
|
||||
if !self.output_buffer.is_empty() {
|
||||
// Offset is from the beginning of the SMB2 header per spec.
|
||||
let buf_offset = Header::SIZE + (cursor.position() - start);
|
||||
cursor.write_bytes(&self.output_buffer);
|
||||
cursor.set_u16_le_at(offset_pos, buf_offset as u16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for QueryDirectoryResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid QueryDirectoryResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// OutputBufferOffset (2 bytes)
|
||||
let buf_offset = cursor.read_u16_le()? as usize;
|
||||
// OutputBufferLength (4 bytes)
|
||||
let buf_length = cursor.read_u32_le()? as usize;
|
||||
|
||||
// Read buffer
|
||||
// Offset on the wire is from beginning of SMB2 header.
|
||||
let output_buffer = if buf_length > 0 {
|
||||
let current = cursor.position();
|
||||
let body_offset = buf_offset.saturating_sub(Header::SIZE);
|
||||
let target = start + body_offset;
|
||||
if target > current {
|
||||
cursor.skip(target - current)?;
|
||||
}
|
||||
cursor.read_bytes_bounded(buf_length)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(QueryDirectoryResponse { output_buffer })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── QueryDirectoryRequest tests ──────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn query_directory_request_roundtrip_star_pattern() {
|
||||
let original = QueryDirectoryRequest {
|
||||
file_information_class: FileInformationClass::FileBothDirectoryInformation,
|
||||
flags: QueryDirectoryFlags(QueryDirectoryFlags::RESTART_SCANS),
|
||||
file_index: 0,
|
||||
file_id: FileId {
|
||||
persistent: 0xAAAA_BBBB_CCCC_DDDD,
|
||||
volatile: 0x1111_2222_3333_4444,
|
||||
},
|
||||
output_buffer_length: 65536,
|
||||
file_name: "*".to_string(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryDirectoryRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
decoded.file_information_class,
|
||||
FileInformationClass::FileBothDirectoryInformation
|
||||
);
|
||||
assert_eq!(decoded.flags.0, QueryDirectoryFlags::RESTART_SCANS);
|
||||
assert_eq!(decoded.file_index, 0);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert_eq!(decoded.output_buffer_length, 65536);
|
||||
assert_eq!(decoded.file_name, "*");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_directory_request_structure_size() {
|
||||
let req = QueryDirectoryRequest {
|
||||
file_information_class: FileInformationClass::FileDirectoryInformation,
|
||||
flags: QueryDirectoryFlags::default(),
|
||||
file_index: 0,
|
||||
file_id: FileId::default(),
|
||||
output_buffer_length: 1024,
|
||||
file_name: "*".to_string(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
req.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes[0], 33);
|
||||
assert_eq!(bytes[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_directory_request_wrong_structure_size() {
|
||||
let mut buf = vec![0u8; 40];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = QueryDirectoryRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── QueryDirectoryResponse tests ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn query_directory_response_roundtrip_with_buffer() {
|
||||
// Simulate raw directory entry data
|
||||
let raw_entries = vec![
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,
|
||||
0x0F, 0x10,
|
||||
];
|
||||
|
||||
let original = QueryDirectoryResponse {
|
||||
output_buffer: raw_entries.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryDirectoryResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.output_buffer, raw_entries);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_directory_response_empty_buffer() {
|
||||
let original = QueryDirectoryResponse {
|
||||
output_buffer: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// StructureSize(2) + Offset(2) + Length(4) = 8 bytes
|
||||
assert_eq!(bytes.len(), 8);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryDirectoryResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.output_buffer.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_directory_response_structure_size() {
|
||||
let resp = QueryDirectoryResponse {
|
||||
output_buffer: vec![0xFF],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
resp.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes[0], 9);
|
||||
assert_eq!(bytes[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_directory_response_wrong_structure_size() {
|
||||
let mut buf = vec![0u8; 16];
|
||||
buf[0..2].copy_from_slice(&42u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = QueryDirectoryResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// ── Enum tests ───────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn file_information_class_roundtrip() {
|
||||
for &class in &[
|
||||
FileInformationClass::FileDirectoryInformation,
|
||||
FileInformationClass::FileFullDirectoryInformation,
|
||||
FileInformationClass::FileBothDirectoryInformation,
|
||||
FileInformationClass::FileNamesInformation,
|
||||
FileInformationClass::FileIdFullDirectoryInformation,
|
||||
FileInformationClass::FileIdBothDirectoryInformation,
|
||||
] {
|
||||
let raw = class as u8;
|
||||
let decoded = FileInformationClass::try_from(raw).unwrap();
|
||||
assert_eq!(decoded, class);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_information_class_invalid() {
|
||||
assert!(FileInformationClass::try_from(0xFF).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{
|
||||
arb_bytes, arb_file_id, arb_file_information_class, arb_utf16_string,
|
||||
};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn query_directory_request_pack_unpack(
|
||||
file_information_class in arb_file_information_class(),
|
||||
flags_raw in any::<u8>(),
|
||||
file_index in any::<u32>(),
|
||||
file_id in arb_file_id(),
|
||||
output_buffer_length in any::<u32>(),
|
||||
// Search pattern is UTF-16LE on the wire. Allow empty + typical.
|
||||
file_name in arb_utf16_string(128),
|
||||
) {
|
||||
let original = QueryDirectoryRequest {
|
||||
file_information_class,
|
||||
flags: QueryDirectoryFlags(flags_raw),
|
||||
file_index,
|
||||
file_id,
|
||||
output_buffer_length,
|
||||
file_name,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryDirectoryRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_directory_response_pack_unpack(output_buffer in arb_bytes()) {
|
||||
let original = QueryDirectoryResponse { output_buffer };
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryDirectoryResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
}
|
||||
}
|
||||
}
|
||||
479
vendor/smb2/src/msg/query_info.rs
vendored
Normal file
479
vendor/smb2/src/msg/query_info.rs
vendored
Normal file
@@ -0,0 +1,479 @@
|
||||
//! SMB2 QUERY_INFO request and response (spec sections 2.2.37, 2.2.38).
|
||||
//!
|
||||
//! Used to query file, filesystem, security, or quota information.
|
||||
//! The response buffer is stored as raw bytes -- parsing into specific
|
||||
//! information classes is deferred.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::msg::header::Header;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
// ── Enums ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Info type for query/set info operations (MS-SMB2 2.2.37).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum InfoType {
|
||||
/// Query file information.
|
||||
File = 0x01,
|
||||
/// Query filesystem information.
|
||||
Filesystem = 0x02,
|
||||
/// Query security information.
|
||||
Security = 0x03,
|
||||
/// Query quota information.
|
||||
Quota = 0x04,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for InfoType {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self> {
|
||||
match value {
|
||||
0x01 => Ok(Self::File),
|
||||
0x02 => Ok(Self::Filesystem),
|
||||
0x03 => Ok(Self::Security),
|
||||
0x04 => Ok(Self::Quota),
|
||||
_ => Err(Error::invalid_data(format!(
|
||||
"invalid InfoType: 0x{:02X}",
|
||||
value
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── QueryInfoRequest ─────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 QUERY_INFO request (spec section 2.2.37).
|
||||
///
|
||||
/// Sent by the client to query information about a file, filesystem,
|
||||
/// security descriptor, or quota.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct QueryInfoRequest {
|
||||
/// The type of information being queried.
|
||||
pub info_type: InfoType,
|
||||
/// The file information class (interpretation depends on `info_type`).
|
||||
pub file_info_class: u8,
|
||||
/// Maximum number of output bytes the server may return.
|
||||
pub output_buffer_length: u32,
|
||||
/// Additional information flags (for example, security information flags).
|
||||
pub additional_information: u32,
|
||||
/// Query flags.
|
||||
pub flags: u32,
|
||||
/// Handle to the file or directory being queried.
|
||||
pub file_id: FileId,
|
||||
/// Optional input buffer (for example, for quota queries).
|
||||
pub input_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl QueryInfoRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 41;
|
||||
}
|
||||
|
||||
impl Pack for QueryInfoRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// InfoType (1 byte)
|
||||
cursor.write_u8(self.info_type as u8);
|
||||
// FileInfoClass (1 byte)
|
||||
cursor.write_u8(self.file_info_class);
|
||||
// OutputBufferLength (4 bytes)
|
||||
cursor.write_u32_le(self.output_buffer_length);
|
||||
// InputBufferOffset (2 bytes) -- placeholder
|
||||
let input_offset_pos = cursor.position();
|
||||
cursor.write_u16_le(0);
|
||||
// Reserved (2 bytes)
|
||||
cursor.write_u16_le(0);
|
||||
// InputBufferLength (4 bytes)
|
||||
cursor.write_u32_le(self.input_buffer.len() as u32);
|
||||
// AdditionalInformation (4 bytes)
|
||||
cursor.write_u32_le(self.additional_information);
|
||||
// Flags (4 bytes)
|
||||
cursor.write_u32_le(self.flags);
|
||||
// FileId (16 bytes)
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
|
||||
// Buffer (variable)
|
||||
if !self.input_buffer.is_empty() {
|
||||
// Offset is from the beginning of the SMB2 header per spec.
|
||||
let buf_offset = Header::SIZE + (cursor.position() - start);
|
||||
cursor.write_bytes(&self.input_buffer);
|
||||
cursor.set_u16_le_at(input_offset_pos, buf_offset as u16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for QueryInfoRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid QueryInfoRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// InfoType (1 byte)
|
||||
let info_type = InfoType::try_from(cursor.read_u8()?)?;
|
||||
// FileInfoClass (1 byte)
|
||||
let file_info_class = cursor.read_u8()?;
|
||||
// OutputBufferLength (4 bytes)
|
||||
let output_buffer_length = cursor.read_u32_le()?;
|
||||
// InputBufferOffset (2 bytes)
|
||||
let input_offset = cursor.read_u16_le()? as usize;
|
||||
// Reserved (2 bytes)
|
||||
let _reserved = cursor.read_u16_le()?;
|
||||
// InputBufferLength (4 bytes)
|
||||
let input_length = cursor.read_u32_le()? as usize;
|
||||
// AdditionalInformation (4 bytes)
|
||||
let additional_information = cursor.read_u32_le()?;
|
||||
// Flags (4 bytes)
|
||||
let flags = cursor.read_u32_le()?;
|
||||
// FileId (16 bytes)
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
let file_id = FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
};
|
||||
|
||||
// Read input buffer
|
||||
// Offset on the wire is from beginning of SMB2 header.
|
||||
let input_buffer = if input_length > 0 {
|
||||
let current = cursor.position();
|
||||
let body_offset = input_offset.saturating_sub(Header::SIZE);
|
||||
let target = start + body_offset;
|
||||
if target > current {
|
||||
cursor.skip(target - current)?;
|
||||
}
|
||||
cursor.read_bytes_bounded(input_length)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(QueryInfoRequest {
|
||||
info_type,
|
||||
file_info_class,
|
||||
output_buffer_length,
|
||||
additional_information,
|
||||
flags,
|
||||
file_id,
|
||||
input_buffer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── QueryInfoResponse ────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 QUERY_INFO response (spec section 2.2.38).
|
||||
///
|
||||
/// Contains the queried information as raw bytes. The format depends
|
||||
/// on the `InfoType` and `FileInfoClass` from the request.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct QueryInfoResponse {
|
||||
/// Raw output buffer containing the queried information.
|
||||
pub output_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl QueryInfoResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 9;
|
||||
}
|
||||
|
||||
impl Pack for QueryInfoResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// OutputBufferOffset (2 bytes) -- placeholder
|
||||
let offset_pos = cursor.position();
|
||||
cursor.write_u16_le(0);
|
||||
// OutputBufferLength (4 bytes)
|
||||
cursor.write_u32_le(self.output_buffer.len() as u32);
|
||||
|
||||
// Buffer
|
||||
if !self.output_buffer.is_empty() {
|
||||
// Offset is from the beginning of the SMB2 header per spec.
|
||||
let buf_offset = Header::SIZE + (cursor.position() - start);
|
||||
cursor.write_bytes(&self.output_buffer);
|
||||
cursor.set_u16_le_at(offset_pos, buf_offset as u16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for QueryInfoResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid QueryInfoResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// OutputBufferOffset (2 bytes)
|
||||
let buf_offset = cursor.read_u16_le()? as usize;
|
||||
// OutputBufferLength (4 bytes)
|
||||
let buf_length = cursor.read_u32_le()? as usize;
|
||||
|
||||
// Read buffer
|
||||
// Offset on the wire is from beginning of SMB2 header.
|
||||
let output_buffer = if buf_length > 0 {
|
||||
let current = cursor.position();
|
||||
let body_offset = buf_offset.saturating_sub(Header::SIZE);
|
||||
let target = start + body_offset;
|
||||
if target > current {
|
||||
cursor.skip(target - current)?;
|
||||
}
|
||||
cursor.read_bytes_bounded(buf_length)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(QueryInfoResponse { output_buffer })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── QueryInfoRequest tests ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn query_info_request_roundtrip_file_info() {
|
||||
let original = QueryInfoRequest {
|
||||
info_type: InfoType::File,
|
||||
file_info_class: 0x12, // FileAllInformation
|
||||
output_buffer_length: 4096,
|
||||
additional_information: 0,
|
||||
flags: 0,
|
||||
file_id: FileId {
|
||||
persistent: 0xDEAD_BEEF_CAFE_BABE,
|
||||
volatile: 0x1234_5678_9ABC_DEF0,
|
||||
},
|
||||
input_buffer: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryInfoRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.info_type, InfoType::File);
|
||||
assert_eq!(decoded.file_info_class, 0x12);
|
||||
assert_eq!(decoded.output_buffer_length, 4096);
|
||||
assert_eq!(decoded.additional_information, 0);
|
||||
assert_eq!(decoded.flags, 0);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert!(decoded.input_buffer.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_info_request_with_input_buffer() {
|
||||
let input = vec![0x01, 0x02, 0x03, 0x04];
|
||||
let original = QueryInfoRequest {
|
||||
info_type: InfoType::Quota,
|
||||
file_info_class: 0x20,
|
||||
output_buffer_length: 8192,
|
||||
additional_information: 0x04, // SACL_SECURITY_INFORMATION
|
||||
flags: 0,
|
||||
file_id: FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
},
|
||||
input_buffer: input.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryInfoRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.info_type, InfoType::Quota);
|
||||
assert_eq!(decoded.input_buffer, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_info_request_structure_size() {
|
||||
let req = QueryInfoRequest {
|
||||
info_type: InfoType::File,
|
||||
file_info_class: 0,
|
||||
output_buffer_length: 0,
|
||||
additional_information: 0,
|
||||
flags: 0,
|
||||
file_id: FileId::default(),
|
||||
input_buffer: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
req.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes[0], 41);
|
||||
assert_eq!(bytes[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_info_request_wrong_structure_size() {
|
||||
let mut buf = vec![0u8; 48];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = QueryInfoRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── QueryInfoResponse tests ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn query_info_response_roundtrip_with_data() {
|
||||
let info_data = vec![
|
||||
0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0,
|
||||
];
|
||||
|
||||
let original = QueryInfoResponse {
|
||||
output_buffer: info_data.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryInfoResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.output_buffer, info_data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_info_response_empty() {
|
||||
let original = QueryInfoResponse {
|
||||
output_buffer: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// StructureSize(2) + Offset(2) + Length(4) = 8
|
||||
assert_eq!(bytes.len(), 8);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryInfoResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.output_buffer.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_info_response_structure_size() {
|
||||
let resp = QueryInfoResponse {
|
||||
output_buffer: vec![0xFF],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
resp.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes[0], 9);
|
||||
assert_eq!(bytes[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_info_response_wrong_structure_size() {
|
||||
let mut buf = vec![0u8; 16];
|
||||
buf[0..2].copy_from_slice(&42u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = QueryInfoResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// ── Enum tests ───────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn info_type_roundtrip() {
|
||||
for &it in &[
|
||||
InfoType::File,
|
||||
InfoType::Filesystem,
|
||||
InfoType::Security,
|
||||
InfoType::Quota,
|
||||
] {
|
||||
let raw = it as u8;
|
||||
let decoded = InfoType::try_from(raw).unwrap();
|
||||
assert_eq!(decoded, it);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn info_type_invalid() {
|
||||
assert!(InfoType::try_from(0x00).is_err());
|
||||
assert!(InfoType::try_from(0x05).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id, arb_info_type};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn query_info_request_pack_unpack(
|
||||
info_type in arb_info_type(),
|
||||
file_info_class in any::<u8>(),
|
||||
output_buffer_length in any::<u32>(),
|
||||
additional_information in any::<u32>(),
|
||||
flags in any::<u32>(),
|
||||
file_id in arb_file_id(),
|
||||
input_buffer in arb_bytes(),
|
||||
) {
|
||||
let original = QueryInfoRequest {
|
||||
info_type,
|
||||
file_info_class,
|
||||
output_buffer_length,
|
||||
additional_information,
|
||||
flags,
|
||||
file_id,
|
||||
input_buffer,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryInfoRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query_info_response_pack_unpack(output_buffer in arb_bytes()) {
|
||||
let original = QueryInfoResponse { output_buffer };
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = QueryInfoResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
}
|
||||
}
|
||||
}
|
||||
462
vendor/smb2/src/msg/read.rs
vendored
Normal file
462
vendor/smb2/src/msg/read.rs
vendored
Normal file
@@ -0,0 +1,462 @@
|
||||
//! SMB2 READ Request and Response (MS-SMB2 sections 2.2.19, 2.2.20).
|
||||
//!
|
||||
//! The READ request reads data from a file or named pipe.
|
||||
//! The response carries the read data in a variable-length buffer.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
/// Read flag: read data directly from underlying storage (SMB 3.0.2+).
|
||||
pub const SMB2_READFLAG_READ_UNBUFFERED: u8 = 0x01;
|
||||
|
||||
/// Read flag: request compressed response (SMB 3.1.1).
|
||||
pub const SMB2_READFLAG_REQUEST_COMPRESSED: u8 = 0x02;
|
||||
|
||||
/// Channel value: no channel information.
|
||||
pub const SMB2_CHANNEL_NONE: u32 = 0x0000_0000;
|
||||
|
||||
/// SMB2 READ Request (MS-SMB2 section 2.2.19).
|
||||
///
|
||||
/// Sent by the client to read data from a file. The fixed portion is 49 bytes
|
||||
/// (StructureSize says 49 regardless of the variable buffer length):
|
||||
/// - StructureSize (2 bytes, must be 49)
|
||||
/// - Padding (1 byte)
|
||||
/// - Flags (1 byte)
|
||||
/// - Length (4 bytes)
|
||||
/// - Offset (8 bytes)
|
||||
/// - FileId (16 bytes)
|
||||
/// - MinimumCount (4 bytes)
|
||||
/// - Channel (4 bytes)
|
||||
/// - RemainingBytes (4 bytes)
|
||||
/// - ReadChannelInfoOffset (2 bytes)
|
||||
/// - ReadChannelInfoLength (2 bytes)
|
||||
/// - Buffer (variable, typically empty for basic reads)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ReadRequest {
|
||||
/// Requested data placement offset in the response.
|
||||
pub padding: u8,
|
||||
/// Flags for the read operation.
|
||||
pub flags: u8,
|
||||
/// Number of bytes to read.
|
||||
pub length: u32,
|
||||
/// File offset to start reading from.
|
||||
pub offset: u64,
|
||||
/// File handle to read from.
|
||||
pub file_id: FileId,
|
||||
/// Minimum number of bytes for a successful read.
|
||||
pub minimum_count: u32,
|
||||
/// Channel for RDMA operations (typically `SMB2_CHANNEL_NONE`).
|
||||
pub channel: u32,
|
||||
/// Remaining bytes in a multi-part read.
|
||||
pub remaining_bytes: u32,
|
||||
/// Variable-length read channel info buffer.
|
||||
pub read_channel_info: Vec<u8>,
|
||||
}
|
||||
|
||||
impl ReadRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 49;
|
||||
}
|
||||
|
||||
impl Pack for ReadRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
cursor.write_u8(self.padding);
|
||||
cursor.write_u8(self.flags);
|
||||
cursor.write_u32_le(self.length);
|
||||
cursor.write_u64_le(self.offset);
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
cursor.write_u32_le(self.minimum_count);
|
||||
cursor.write_u32_le(self.channel);
|
||||
cursor.write_u32_le(self.remaining_bytes);
|
||||
|
||||
// ReadChannelInfoOffset/Length: relative to start of SMB2 header.
|
||||
// For packing the body alone, we store offset as 0 when empty.
|
||||
if self.read_channel_info.is_empty() {
|
||||
cursor.write_u16_le(0);
|
||||
cursor.write_u16_le(0);
|
||||
} else {
|
||||
// Offset from the SMB2 header = header (64) + fixed body (48) = 112.
|
||||
// The fixed body before Buffer is 48 bytes (StructureSize 49 minus
|
||||
// the 1 byte of Buffer that's counted in StructureSize).
|
||||
cursor.write_u16_le(0); // Caller must backpatch if needed
|
||||
cursor.write_u16_le(self.read_channel_info.len() as u16);
|
||||
}
|
||||
|
||||
// Buffer: at minimum 1 byte per the StructureSize=49 contract,
|
||||
// but we write the actual channel info if present.
|
||||
if self.read_channel_info.is_empty() {
|
||||
// Write a single padding byte so the fixed part is 49 bytes
|
||||
// (StructureSize includes this 1-byte minimum buffer).
|
||||
cursor.write_u8(0);
|
||||
} else {
|
||||
cursor.write_bytes(&self.read_channel_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for ReadRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid ReadRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let padding = cursor.read_u8()?;
|
||||
let flags = cursor.read_u8()?;
|
||||
let length = cursor.read_u32_le()?;
|
||||
let offset = cursor.read_u64_le()?;
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
let minimum_count = cursor.read_u32_le()?;
|
||||
let channel = cursor.read_u32_le()?;
|
||||
let remaining_bytes = cursor.read_u32_le()?;
|
||||
let _read_channel_info_offset = cursor.read_u16_le()?;
|
||||
let read_channel_info_length = cursor.read_u16_le()?;
|
||||
|
||||
// The buffer is at least 1 byte (per StructureSize=49).
|
||||
// Read channel info from the buffer based on the length field.
|
||||
let read_channel_info = if read_channel_info_length > 0 {
|
||||
cursor
|
||||
.read_bytes(read_channel_info_length as usize)?
|
||||
.to_vec()
|
||||
} else {
|
||||
// Skip the minimum 1-byte buffer
|
||||
cursor.skip(1)?;
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(ReadRequest {
|
||||
padding,
|
||||
flags,
|
||||
length,
|
||||
offset,
|
||||
file_id: FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
},
|
||||
minimum_count,
|
||||
channel,
|
||||
remaining_bytes,
|
||||
read_channel_info,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// SMB2 READ Response (MS-SMB2 section 2.2.20).
|
||||
///
|
||||
/// Sent by the server with the requested data. The fixed portion is 17 bytes:
|
||||
/// - StructureSize (2 bytes, must be 17)
|
||||
/// - DataOffset (1 byte)
|
||||
/// - Reserved (1 byte)
|
||||
/// - DataLength (4 bytes)
|
||||
/// - DataRemaining (4 bytes)
|
||||
/// - Reserved2 (4 bytes)
|
||||
/// - Buffer (variable, DataLength bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ReadResponse {
|
||||
/// Offset from the start of the SMB2 header to the data.
|
||||
pub data_offset: u8,
|
||||
/// Number of remaining bytes on the channel.
|
||||
pub data_remaining: u32,
|
||||
/// Flags/Reserved2 field (used in SMB 3.1.1, otherwise 0).
|
||||
pub flags: u32,
|
||||
/// The data that was read.
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl ReadResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 17;
|
||||
}
|
||||
|
||||
impl Pack for ReadResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
cursor.write_u8(self.data_offset);
|
||||
cursor.write_u8(0); // Reserved
|
||||
cursor.write_u32_le(self.data.len() as u32);
|
||||
cursor.write_u32_le(self.data_remaining);
|
||||
cursor.write_u32_le(self.flags); // Reserved2/Flags
|
||||
cursor.write_bytes(&self.data);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for ReadResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid ReadResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let data_offset = cursor.read_u8()?;
|
||||
let _reserved = cursor.read_u8()?;
|
||||
let data_length = cursor.read_u32_le()?;
|
||||
let data_remaining = cursor.read_u32_le()?;
|
||||
let flags = cursor.read_u32_le()?;
|
||||
|
||||
let data = if data_length > 0 {
|
||||
cursor.read_bytes_bounded(data_length as usize)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(ReadResponse {
|
||||
data_offset,
|
||||
data_remaining,
|
||||
flags,
|
||||
data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── ReadRequest tests ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn read_request_roundtrip() {
|
||||
let original = ReadRequest {
|
||||
padding: 0x50,
|
||||
flags: SMB2_READFLAG_READ_UNBUFFERED,
|
||||
length: 65536,
|
||||
offset: 0x1000,
|
||||
file_id: FileId {
|
||||
persistent: 0xAAAA_BBBB_CCCC_DDDD,
|
||||
volatile: 0x1111_2222_3333_4444,
|
||||
},
|
||||
minimum_count: 1024,
|
||||
channel: SMB2_CHANNEL_NONE,
|
||||
remaining_bytes: 0,
|
||||
read_channel_info: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed: 48 bytes + 1-byte minimum buffer = 49 bytes
|
||||
assert_eq!(bytes.len(), 49);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ReadRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.padding, original.padding);
|
||||
assert_eq!(decoded.flags, original.flags);
|
||||
assert_eq!(decoded.length, original.length);
|
||||
assert_eq!(decoded.offset, original.offset);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert_eq!(decoded.minimum_count, original.minimum_count);
|
||||
assert_eq!(decoded.channel, original.channel);
|
||||
assert_eq!(decoded.remaining_bytes, original.remaining_bytes);
|
||||
assert!(decoded.read_channel_info.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_request_with_channel_info_roundtrip() {
|
||||
let channel_data = vec![0xDE, 0xAD, 0xBE, 0xEF];
|
||||
let original = ReadRequest {
|
||||
padding: 0,
|
||||
flags: 0,
|
||||
length: 4096,
|
||||
offset: 0,
|
||||
file_id: FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
},
|
||||
minimum_count: 0,
|
||||
channel: 0x0000_0001, // SMB2_CHANNEL_RDMA_V1
|
||||
remaining_bytes: 4096,
|
||||
read_channel_info: channel_data.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed: 48 bytes + 4-byte channel info = 52 bytes
|
||||
assert_eq!(bytes.len(), 52);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ReadRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.read_channel_info, channel_data);
|
||||
assert_eq!(decoded.channel, 0x0000_0001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_request_wrong_structure_size() {
|
||||
let mut buf = [0u8; 49];
|
||||
buf[0..2].copy_from_slice(&50u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = ReadRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── ReadResponse tests ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn read_response_roundtrip() {
|
||||
let original = ReadResponse {
|
||||
data_offset: 0x50, // typical: 64 (header) + 16 (body fixed) = 80 = 0x50
|
||||
data_remaining: 0,
|
||||
flags: 0,
|
||||
data: vec![0x01, 0x02, 0x03, 0x04, 0x05],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed: 16 bytes + 5 bytes data = 21 bytes
|
||||
assert_eq!(bytes.len(), 21);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ReadResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.data_offset, original.data_offset);
|
||||
assert_eq!(decoded.data_remaining, original.data_remaining);
|
||||
assert_eq!(decoded.flags, original.flags);
|
||||
assert_eq!(decoded.data, original.data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_response_empty_data() {
|
||||
let original = ReadResponse {
|
||||
data_offset: 0,
|
||||
data_remaining: 0,
|
||||
flags: 0,
|
||||
data: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed: 16 bytes, no data
|
||||
assert_eq!(bytes.len(), 16);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ReadResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.data.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_response_known_bytes() {
|
||||
let mut buf = Vec::new();
|
||||
// StructureSize = 17
|
||||
buf.extend_from_slice(&17u16.to_le_bytes());
|
||||
// DataOffset = 0x50
|
||||
buf.push(0x50);
|
||||
// Reserved = 0
|
||||
buf.push(0x00);
|
||||
// DataLength = 3
|
||||
buf.extend_from_slice(&3u32.to_le_bytes());
|
||||
// DataRemaining = 0
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
// Reserved2/Flags = 0
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
// Buffer = [0xAA, 0xBB, 0xCC]
|
||||
buf.extend_from_slice(&[0xAA, 0xBB, 0xCC]);
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let resp = ReadResponse::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(resp.data_offset, 0x50);
|
||||
assert_eq!(resp.data, vec![0xAA, 0xBB, 0xCC]);
|
||||
assert_eq!(resp.data_remaining, 0);
|
||||
assert_eq!(resp.flags, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_response_wrong_structure_size() {
|
||||
let mut buf = [0u8; 16];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = ReadResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id, arb_small_bytes};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn read_request_pack_unpack(
|
||||
padding in any::<u8>(),
|
||||
flags in any::<u8>(),
|
||||
length in any::<u32>(),
|
||||
offset in any::<u64>(),
|
||||
file_id in arb_file_id(),
|
||||
minimum_count in any::<u32>(),
|
||||
channel in any::<u32>(),
|
||||
remaining_bytes in any::<u32>(),
|
||||
read_channel_info in arb_small_bytes(),
|
||||
) {
|
||||
let original = ReadRequest {
|
||||
padding,
|
||||
flags,
|
||||
length,
|
||||
offset,
|
||||
file_id,
|
||||
minimum_count,
|
||||
channel,
|
||||
remaining_bytes,
|
||||
read_channel_info,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ReadRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_response_pack_unpack(
|
||||
data_offset in any::<u8>(),
|
||||
data_remaining in any::<u32>(),
|
||||
flags in any::<u32>(),
|
||||
data in arb_bytes(),
|
||||
) {
|
||||
let original = ReadResponse {
|
||||
data_offset,
|
||||
data_remaining,
|
||||
flags,
|
||||
data,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = ReadResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
250
vendor/smb2/src/msg/roundtrip_strategies.rs
vendored
Normal file
250
vendor/smb2/src/msg/roundtrip_strategies.rs
vendored
Normal file
@@ -0,0 +1,250 @@
|
||||
//! Shared proptest strategies for wire-format roundtrip tests.
|
||||
//!
|
||||
//! Each strategy generates a value that a real encoder could emit. The goal
|
||||
//! is not to stress-test the decoder against malformed input (that's fuzzing)
|
||||
//! but to exercise encode/decode symmetry on well-formed inputs.
|
||||
//!
|
||||
//! Rules followed here:
|
||||
//! - Typed enums always yield valid variants (no invalid discriminants).
|
||||
//! - `Vec<u8>` lengths stay moderate (at most a few KB) to keep tests fast.
|
||||
//! - Internally-dependent sizes (for example, a length field that must match a
|
||||
//! sibling `Vec`) are produced via `prop_map` so generated instances are
|
||||
//! always consistent.
|
||||
|
||||
// Note: `#[cfg(test)]` is applied at the module declaration in `src/msg/mod.rs`
|
||||
// (`#[cfg(test)] pub(crate) mod roundtrip_strategies;`). We don't repeat it
|
||||
// here; clippy's `duplicated_attributes` lint rejects that.
|
||||
#![allow(dead_code)] // Helpers might be unused while tests are being added.
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
use crate::pack::{FileTime, Guid};
|
||||
use crate::types::flags::{
|
||||
Capabilities, FileAccessMask, HeaderFlags, SecurityMode, ShareCapabilities, ShareFlags,
|
||||
};
|
||||
use crate::types::status::NtStatus;
|
||||
use crate::types::{
|
||||
Command, CreditCharge, Dialect, FileId, MessageId, OplockLevel, SessionId, TreeId,
|
||||
};
|
||||
|
||||
/// Max size (in bytes) used for generated `Vec<u8>` buffers across tests.
|
||||
/// Kept small so a 256-case proptest run stays well under a second.
|
||||
pub const MAX_BUFFER_BYTES: usize = 1024;
|
||||
|
||||
/// Moderate buffer for structs that usually carry small bodies.
|
||||
pub const MAX_SMALL_BUFFER_BYTES: usize = 256;
|
||||
|
||||
/// Generate a `Vec<u8>` up to `max` bytes long (including zero).
|
||||
pub fn bytes_up_to(max: usize) -> impl Strategy<Value = Vec<u8>> {
|
||||
prop::collection::vec(any::<u8>(), 0..=max)
|
||||
}
|
||||
|
||||
/// A standard moderate-length byte buffer.
|
||||
pub fn arb_bytes() -> impl Strategy<Value = Vec<u8>> {
|
||||
bytes_up_to(MAX_BUFFER_BYTES)
|
||||
}
|
||||
|
||||
/// A smaller byte buffer, for sub-fields or tightly-nested structures.
|
||||
pub fn arb_small_bytes() -> impl Strategy<Value = Vec<u8>> {
|
||||
bytes_up_to(MAX_SMALL_BUFFER_BYTES)
|
||||
}
|
||||
|
||||
/// Generate a valid UTF-16-encodable String, up to `max_chars` chars.
|
||||
///
|
||||
/// Excludes unpaired surrogates (U+D800..=U+DFFF) because UTF-16 decoding
|
||||
/// would reject any surrogate that isn't part of a valid pair. We use the
|
||||
/// BMP-minus-surrogates range plus occasional supplementary characters, so
|
||||
/// both one-code-unit and two-code-unit forms are covered.
|
||||
pub fn arb_utf16_string(max_chars: usize) -> impl Strategy<Value = String> {
|
||||
prop::collection::vec(
|
||||
prop::char::range('\u{0000}', '\u{D7FF}')
|
||||
.prop_union(prop::char::range('\u{E000}', '\u{FFFF}'))
|
||||
.or(prop::char::range('\u{1_0000}', '\u{10_FFFF}')),
|
||||
0..=max_chars,
|
||||
)
|
||||
.prop_map(|chars| chars.into_iter().collect())
|
||||
}
|
||||
|
||||
// ── Primitive newtype strategies ────────────────────────────────────
|
||||
|
||||
pub fn arb_session_id() -> impl Strategy<Value = SessionId> {
|
||||
any::<u64>().prop_map(SessionId)
|
||||
}
|
||||
|
||||
pub fn arb_message_id() -> impl Strategy<Value = MessageId> {
|
||||
any::<u64>().prop_map(MessageId)
|
||||
}
|
||||
|
||||
pub fn arb_tree_id() -> impl Strategy<Value = TreeId> {
|
||||
any::<u32>().prop_map(TreeId)
|
||||
}
|
||||
|
||||
pub fn arb_credit_charge() -> impl Strategy<Value = CreditCharge> {
|
||||
any::<u16>().prop_map(CreditCharge)
|
||||
}
|
||||
|
||||
pub fn arb_file_id() -> impl Strategy<Value = FileId> {
|
||||
(any::<u64>(), any::<u64>()).prop_map(|(persistent, volatile)| FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn arb_file_time() -> impl Strategy<Value = FileTime> {
|
||||
any::<u64>().prop_map(FileTime)
|
||||
}
|
||||
|
||||
pub fn arb_guid() -> impl Strategy<Value = Guid> {
|
||||
(any::<u32>(), any::<u16>(), any::<u16>(), any::<[u8; 8]>()).prop_map(
|
||||
|(data1, data2, data3, data4)| Guid {
|
||||
data1,
|
||||
data2,
|
||||
data3,
|
||||
data4,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn arb_nt_status() -> impl Strategy<Value = NtStatus> {
|
||||
any::<u32>().prop_map(NtStatus)
|
||||
}
|
||||
|
||||
// ── Flags ────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn arb_header_flags() -> impl Strategy<Value = HeaderFlags> {
|
||||
any::<u32>().prop_map(HeaderFlags::new)
|
||||
}
|
||||
|
||||
pub fn arb_security_mode() -> impl Strategy<Value = SecurityMode> {
|
||||
any::<u16>().prop_map(SecurityMode::new)
|
||||
}
|
||||
|
||||
pub fn arb_capabilities() -> impl Strategy<Value = Capabilities> {
|
||||
any::<u32>().prop_map(Capabilities::new)
|
||||
}
|
||||
|
||||
pub fn arb_share_flags() -> impl Strategy<Value = ShareFlags> {
|
||||
any::<u32>().prop_map(ShareFlags::new)
|
||||
}
|
||||
|
||||
pub fn arb_share_capabilities() -> impl Strategy<Value = ShareCapabilities> {
|
||||
any::<u32>().prop_map(ShareCapabilities::new)
|
||||
}
|
||||
|
||||
pub fn arb_file_access_mask() -> impl Strategy<Value = FileAccessMask> {
|
||||
any::<u32>().prop_map(FileAccessMask::new)
|
||||
}
|
||||
|
||||
// ── Typed enums: only valid variants ────────────────────────────────
|
||||
|
||||
pub fn arb_oplock_level() -> impl Strategy<Value = OplockLevel> {
|
||||
prop_oneof![
|
||||
Just(OplockLevel::None),
|
||||
Just(OplockLevel::LevelII),
|
||||
Just(OplockLevel::Exclusive),
|
||||
Just(OplockLevel::Batch),
|
||||
Just(OplockLevel::Lease),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn arb_dialect() -> impl Strategy<Value = Dialect> {
|
||||
prop_oneof![
|
||||
Just(Dialect::Smb2_0_2),
|
||||
Just(Dialect::Smb2_1),
|
||||
Just(Dialect::Smb3_0),
|
||||
Just(Dialect::Smb3_0_2),
|
||||
Just(Dialect::Smb3_1_1),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn arb_share_type() -> impl Strategy<Value = crate::msg::tree_connect::ShareType> {
|
||||
use crate::msg::tree_connect::ShareType;
|
||||
prop_oneof![
|
||||
Just(ShareType::Disk),
|
||||
Just(ShareType::Pipe),
|
||||
Just(ShareType::Print),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn arb_impersonation_level() -> impl Strategy<Value = crate::msg::create::ImpersonationLevel> {
|
||||
use crate::msg::create::ImpersonationLevel;
|
||||
prop_oneof![
|
||||
Just(ImpersonationLevel::Anonymous),
|
||||
Just(ImpersonationLevel::Identification),
|
||||
Just(ImpersonationLevel::Impersonation),
|
||||
Just(ImpersonationLevel::Delegate),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn arb_create_disposition() -> impl Strategy<Value = crate::msg::create::CreateDisposition> {
|
||||
use crate::msg::create::CreateDisposition;
|
||||
prop_oneof![
|
||||
Just(CreateDisposition::FileSupersede),
|
||||
Just(CreateDisposition::FileOpen),
|
||||
Just(CreateDisposition::FileCreate),
|
||||
Just(CreateDisposition::FileOpenIf),
|
||||
Just(CreateDisposition::FileOverwrite),
|
||||
Just(CreateDisposition::FileOverwriteIf),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn arb_create_action() -> impl Strategy<Value = crate::msg::create::CreateAction> {
|
||||
use crate::msg::create::CreateAction;
|
||||
prop_oneof![
|
||||
Just(CreateAction::FileSuperseded),
|
||||
Just(CreateAction::FileOpened),
|
||||
Just(CreateAction::FileCreated),
|
||||
Just(CreateAction::FileOverwritten),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn arb_share_access() -> impl Strategy<Value = crate::msg::create::ShareAccess> {
|
||||
any::<u32>().prop_map(crate::msg::create::ShareAccess)
|
||||
}
|
||||
|
||||
pub fn arb_info_type() -> impl Strategy<Value = crate::msg::query_info::InfoType> {
|
||||
use crate::msg::query_info::InfoType;
|
||||
prop_oneof![
|
||||
Just(InfoType::File),
|
||||
Just(InfoType::Filesystem),
|
||||
Just(InfoType::Security),
|
||||
Just(InfoType::Quota),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn arb_file_information_class(
|
||||
) -> impl Strategy<Value = crate::msg::query_directory::FileInformationClass> {
|
||||
use crate::msg::query_directory::FileInformationClass;
|
||||
prop_oneof![
|
||||
Just(FileInformationClass::FileDirectoryInformation),
|
||||
Just(FileInformationClass::FileFullDirectoryInformation),
|
||||
Just(FileInformationClass::FileBothDirectoryInformation),
|
||||
Just(FileInformationClass::FileNamesInformation),
|
||||
Just(FileInformationClass::FileIdBothDirectoryInformation),
|
||||
Just(FileInformationClass::FileIdFullDirectoryInformation),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn arb_command() -> impl Strategy<Value = Command> {
|
||||
prop_oneof![
|
||||
Just(Command::Negotiate),
|
||||
Just(Command::SessionSetup),
|
||||
Just(Command::Logoff),
|
||||
Just(Command::TreeConnect),
|
||||
Just(Command::TreeDisconnect),
|
||||
Just(Command::Create),
|
||||
Just(Command::Close),
|
||||
Just(Command::Flush),
|
||||
Just(Command::Read),
|
||||
Just(Command::Write),
|
||||
Just(Command::Lock),
|
||||
Just(Command::Ioctl),
|
||||
Just(Command::Cancel),
|
||||
Just(Command::Echo),
|
||||
Just(Command::QueryDirectory),
|
||||
Just(Command::ChangeNotify),
|
||||
Just(Command::QueryInfo),
|
||||
Just(Command::SetInfo),
|
||||
Just(Command::OplockBreak),
|
||||
]
|
||||
}
|
||||
481
vendor/smb2/src/msg/session_setup.rs
vendored
Normal file
481
vendor/smb2/src/msg/session_setup.rs
vendored
Normal file
@@ -0,0 +1,481 @@
|
||||
//! SMB2 SESSION_SETUP request and response (spec sections 2.2.5, 2.2.6).
|
||||
//!
|
||||
//! Session setup messages are used to establish an authenticated session
|
||||
//! between the client and the server. The request carries a security token
|
||||
//! (for example, SPNEGO/NTLM) and the response carries the server's reply token
|
||||
//! along with session flags.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::msg::header::Header;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::flags::{Capabilities, SecurityMode};
|
||||
use crate::Error;
|
||||
|
||||
// ── Session setup request flags ────────────────────────────────────────
|
||||
|
||||
/// Flags for the SESSION_SETUP request (1 byte, spec section 2.2.5).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct SessionSetupRequestFlags(pub u8);
|
||||
|
||||
impl SessionSetupRequestFlags {
|
||||
/// Bind an existing session to a new connection (SMB 3.x only).
|
||||
pub const BINDING: u8 = 0x01;
|
||||
|
||||
/// Returns `true` if the binding flag is set.
|
||||
#[inline]
|
||||
pub fn is_binding(&self) -> bool {
|
||||
self.0 & Self::BINDING != 0
|
||||
}
|
||||
}
|
||||
|
||||
// ── Session flags (response) ───────────────────────────────────────────
|
||||
|
||||
/// Session flags returned in the SESSION_SETUP response (spec section 2.2.6).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct SessionFlags(pub u16);
|
||||
|
||||
impl SessionFlags {
|
||||
/// The client has been authenticated as a guest user.
|
||||
pub const IS_GUEST: u16 = 0x0001;
|
||||
/// The client has been authenticated as an anonymous user.
|
||||
pub const IS_NULL: u16 = 0x0002;
|
||||
/// The server requires encryption of messages on this session (SMB 3.x only).
|
||||
pub const ENCRYPT_DATA: u16 = 0x0004;
|
||||
|
||||
/// Returns `true` if the guest flag is set.
|
||||
#[inline]
|
||||
pub fn is_guest(&self) -> bool {
|
||||
self.0 & Self::IS_GUEST != 0
|
||||
}
|
||||
|
||||
/// Returns `true` if the null session flag is set.
|
||||
#[inline]
|
||||
pub fn is_null(&self) -> bool {
|
||||
self.0 & Self::IS_NULL != 0
|
||||
}
|
||||
|
||||
/// Returns `true` if the encrypt-data flag is set.
|
||||
#[inline]
|
||||
pub fn encrypt_data(&self) -> bool {
|
||||
self.0 & Self::ENCRYPT_DATA != 0
|
||||
}
|
||||
}
|
||||
|
||||
// ── SessionSetupRequest ────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 SESSION_SETUP request (spec section 2.2.5).
|
||||
///
|
||||
/// Sent by the client to establish an authenticated session. The security
|
||||
/// buffer carries a GSS/SPNEGO token (or other auth protocol token).
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SessionSetupRequest {
|
||||
/// Flags controlling the request (for example, session binding).
|
||||
pub flags: SessionSetupRequestFlags,
|
||||
/// Security mode indicating signing requirements.
|
||||
pub security_mode: SecurityMode,
|
||||
/// Client capabilities.
|
||||
pub capabilities: Capabilities,
|
||||
/// Channel field (reserved, must be 0).
|
||||
pub channel: u32,
|
||||
/// Previously established session identifier for reconnection.
|
||||
pub previous_session_id: u64,
|
||||
/// Security buffer containing the authentication token.
|
||||
pub security_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl SessionSetupRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 25;
|
||||
}
|
||||
|
||||
impl Pack for SessionSetupRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// Flags (1 byte)
|
||||
cursor.write_u8(self.flags.0);
|
||||
// SecurityMode (1 byte)
|
||||
cursor.write_u8(self.security_mode.bits() as u8);
|
||||
// Capabilities (4 bytes)
|
||||
cursor.write_u32_le(self.capabilities.bits());
|
||||
// Channel (4 bytes)
|
||||
cursor.write_u32_le(self.channel);
|
||||
|
||||
// SecurityBufferOffset (2 bytes) -- offset from start of SMB2 header
|
||||
let offset = (Header::SIZE + 24) as u16; // 24 = bytes before the buffer in this struct
|
||||
cursor.write_u16_le(offset);
|
||||
// SecurityBufferLength (2 bytes)
|
||||
cursor.write_u16_le(self.security_buffer.len() as u16);
|
||||
// PreviousSessionId (8 bytes)
|
||||
cursor.write_u64_le(self.previous_session_id);
|
||||
// Buffer (variable)
|
||||
cursor.write_bytes(&self.security_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for SessionSetupRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid SessionSetupRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Flags (1 byte)
|
||||
let flags = SessionSetupRequestFlags(cursor.read_u8()?);
|
||||
// SecurityMode (1 byte)
|
||||
let security_mode = SecurityMode::new(cursor.read_u8()? as u16);
|
||||
// Capabilities (4 bytes)
|
||||
let capabilities = Capabilities::new(cursor.read_u32_le()?);
|
||||
// Channel (4 bytes)
|
||||
let channel = cursor.read_u32_le()?;
|
||||
// SecurityBufferOffset (2 bytes) -- we ignore, read sequentially
|
||||
let _offset = cursor.read_u16_le()?;
|
||||
// SecurityBufferLength (2 bytes)
|
||||
let buffer_length = cursor.read_u16_le()? as usize;
|
||||
// PreviousSessionId (8 bytes)
|
||||
let previous_session_id = cursor.read_u64_le()?;
|
||||
// Buffer (variable)
|
||||
let security_buffer = if buffer_length > 0 {
|
||||
cursor.read_bytes_bounded(buffer_length)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(SessionSetupRequest {
|
||||
flags,
|
||||
security_mode,
|
||||
capabilities,
|
||||
channel,
|
||||
previous_session_id,
|
||||
security_buffer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── SessionSetupResponse ───────────────────────────────────────────────
|
||||
|
||||
/// SMB2 SESSION_SETUP response (spec section 2.2.6).
|
||||
///
|
||||
/// Sent by the server in response to a SESSION_SETUP request. Contains
|
||||
/// session flags and a security buffer with the server's auth token.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SessionSetupResponse {
|
||||
/// Flags indicating additional information about the session.
|
||||
pub session_flags: SessionFlags,
|
||||
/// Security buffer containing the server's authentication token.
|
||||
pub security_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl SessionSetupResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 9;
|
||||
}
|
||||
|
||||
impl Pack for SessionSetupResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// SessionFlags (2 bytes)
|
||||
cursor.write_u16_le(self.session_flags.0);
|
||||
// SecurityBufferOffset (2 bytes) -- offset from start of SMB2 header
|
||||
let offset = (Header::SIZE + 8) as u16; // 8 = fixed part of response struct
|
||||
cursor.write_u16_le(offset);
|
||||
// SecurityBufferLength (2 bytes)
|
||||
cursor.write_u16_le(self.security_buffer.len() as u16);
|
||||
// Buffer (variable)
|
||||
cursor.write_bytes(&self.security_buffer);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for SessionSetupResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid SessionSetupResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// SessionFlags (2 bytes)
|
||||
let session_flags = SessionFlags(cursor.read_u16_le()?);
|
||||
// SecurityBufferOffset (2 bytes)
|
||||
let _offset = cursor.read_u16_le()?;
|
||||
// SecurityBufferLength (2 bytes)
|
||||
let buffer_length = cursor.read_u16_le()? as usize;
|
||||
// Buffer (variable)
|
||||
let security_buffer = if buffer_length > 0 {
|
||||
cursor.read_bytes_bounded(buffer_length)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(SessionSetupResponse {
|
||||
session_flags,
|
||||
security_buffer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── SessionSetupRequest tests ──────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn session_setup_request_roundtrip() {
|
||||
let token = vec![0x60, 0x28, 0x06, 0x06, 0x2b, 0x06, 0x01, 0x05];
|
||||
let original = SessionSetupRequest {
|
||||
flags: SessionSetupRequestFlags(0),
|
||||
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
|
||||
capabilities: Capabilities::new(Capabilities::DFS),
|
||||
channel: 0,
|
||||
previous_session_id: 0,
|
||||
security_buffer: token.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SessionSetupRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.flags, original.flags);
|
||||
assert_eq!(decoded.security_mode.bits(), original.security_mode.bits());
|
||||
assert_eq!(decoded.capabilities.bits(), original.capabilities.bits());
|
||||
assert_eq!(decoded.channel, 0);
|
||||
assert_eq!(decoded.previous_session_id, 0);
|
||||
assert_eq!(decoded.security_buffer, token);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_setup_request_with_binding_flag() {
|
||||
let original = SessionSetupRequest {
|
||||
flags: SessionSetupRequestFlags(SessionSetupRequestFlags::BINDING),
|
||||
security_mode: SecurityMode::new(
|
||||
SecurityMode::SIGNING_ENABLED | SecurityMode::SIGNING_REQUIRED,
|
||||
),
|
||||
capabilities: Capabilities::default(),
|
||||
channel: 0,
|
||||
previous_session_id: 0xDEAD_BEEF_CAFE_BABE,
|
||||
security_buffer: vec![0xAA, 0xBB],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SessionSetupRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.flags.is_binding());
|
||||
assert!(decoded.security_mode.signing_enabled());
|
||||
assert!(decoded.security_mode.signing_required());
|
||||
assert_eq!(decoded.previous_session_id, 0xDEAD_BEEF_CAFE_BABE);
|
||||
assert_eq!(decoded.security_buffer, vec![0xAA, 0xBB]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_setup_request_empty_buffer() {
|
||||
let original = SessionSetupRequest {
|
||||
flags: SessionSetupRequestFlags(0),
|
||||
security_mode: SecurityMode::default(),
|
||||
capabilities: Capabilities::default(),
|
||||
channel: 0,
|
||||
previous_session_id: 0,
|
||||
security_buffer: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SessionSetupRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.security_buffer.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_setup_request_structure_size_field() {
|
||||
let req = SessionSetupRequest {
|
||||
flags: SessionSetupRequestFlags(0),
|
||||
security_mode: SecurityMode::default(),
|
||||
capabilities: Capabilities::default(),
|
||||
channel: 0,
|
||||
previous_session_id: 0,
|
||||
security_buffer: vec![0x01],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
req.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// First 2 bytes are structure size = 25
|
||||
assert_eq!(bytes[0], 25);
|
||||
assert_eq!(bytes[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_setup_request_wrong_structure_size() {
|
||||
let mut buf = [0u8; 26];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = SessionSetupRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── SessionSetupResponse tests ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn session_setup_response_roundtrip() {
|
||||
let token = vec![0xA1, 0x81, 0xB0, 0x30, 0x81, 0xAD];
|
||||
let original = SessionSetupResponse {
|
||||
session_flags: SessionFlags(0),
|
||||
security_buffer: token.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SessionSetupResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.session_flags, original.session_flags);
|
||||
assert_eq!(decoded.security_buffer, token);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_setup_response_with_flags() {
|
||||
let original = SessionSetupResponse {
|
||||
session_flags: SessionFlags(SessionFlags::IS_GUEST | SessionFlags::ENCRYPT_DATA),
|
||||
security_buffer: vec![0x01, 0x02, 0x03],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SessionSetupResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.session_flags.is_guest());
|
||||
assert!(!decoded.session_flags.is_null());
|
||||
assert!(decoded.session_flags.encrypt_data());
|
||||
assert_eq!(decoded.security_buffer, vec![0x01, 0x02, 0x03]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_setup_response_null_session() {
|
||||
let original = SessionSetupResponse {
|
||||
session_flags: SessionFlags(SessionFlags::IS_NULL),
|
||||
security_buffer: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SessionSetupResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.session_flags.is_null());
|
||||
assert!(!decoded.session_flags.is_guest());
|
||||
assert!(decoded.security_buffer.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_setup_response_structure_size_field() {
|
||||
let resp = SessionSetupResponse {
|
||||
session_flags: SessionFlags(0),
|
||||
security_buffer: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
resp.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// First 2 bytes are structure size = 9
|
||||
assert_eq!(bytes[0], 9);
|
||||
assert_eq!(bytes[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_setup_response_wrong_structure_size() {
|
||||
let mut buf = [0u8; 10];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = SessionSetupResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{arb_capabilities, arb_small_bytes};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn session_setup_request_pack_unpack(
|
||||
flags_raw in any::<u8>(),
|
||||
// SESSION_SETUP packs SecurityMode as a single byte, so only the
|
||||
// low 8 bits survive the roundtrip. Generate u8 values to avoid
|
||||
// producing inputs the encoder would never emit from a real caller.
|
||||
security_mode_raw in any::<u8>(),
|
||||
capabilities in arb_capabilities(),
|
||||
channel in any::<u32>(),
|
||||
previous_session_id in any::<u64>(),
|
||||
security_buffer in arb_small_bytes(),
|
||||
) {
|
||||
let original = SessionSetupRequest {
|
||||
flags: SessionSetupRequestFlags(flags_raw),
|
||||
security_mode: SecurityMode::new(security_mode_raw as u16),
|
||||
capabilities,
|
||||
channel,
|
||||
previous_session_id,
|
||||
security_buffer,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SessionSetupRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_setup_response_pack_unpack(
|
||||
session_flags_raw in any::<u16>(),
|
||||
security_buffer in arb_small_bytes(),
|
||||
) {
|
||||
let original = SessionSetupResponse {
|
||||
session_flags: SessionFlags(session_flags_raw),
|
||||
security_buffer,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SessionSetupResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
328
vendor/smb2/src/msg/set_info.rs
vendored
Normal file
328
vendor/smb2/src/msg/set_info.rs
vendored
Normal file
@@ -0,0 +1,328 @@
|
||||
//! SMB2 SET_INFO request and response (spec sections 2.2.39, 2.2.40).
|
||||
//!
|
||||
//! Used to set file, filesystem, security, or quota information.
|
||||
//! The request buffer contains the information to set, stored as raw bytes.
|
||||
//! The response is a minimal 2-byte structure.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::msg::header::Header;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
// Re-use InfoType from query_info
|
||||
pub use super::query_info::InfoType;
|
||||
|
||||
// ── SetInfoRequest ───────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 SET_INFO request (spec section 2.2.39).
|
||||
///
|
||||
/// Sent by the client to set information on a file, filesystem,
|
||||
/// security descriptor, or quota.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SetInfoRequest {
|
||||
/// The type of information being set.
|
||||
pub info_type: InfoType,
|
||||
/// The file information class (interpretation depends on `info_type`).
|
||||
pub file_info_class: u8,
|
||||
/// Additional information flags (for example, security information flags).
|
||||
pub additional_information: u32,
|
||||
/// Handle to the file or directory.
|
||||
pub file_id: FileId,
|
||||
/// Raw buffer containing the information to set.
|
||||
pub buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl SetInfoRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 33;
|
||||
}
|
||||
|
||||
impl Pack for SetInfoRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// InfoType (1 byte)
|
||||
cursor.write_u8(self.info_type as u8);
|
||||
// FileInfoClass (1 byte)
|
||||
cursor.write_u8(self.file_info_class);
|
||||
// BufferLength (4 bytes)
|
||||
cursor.write_u32_le(self.buffer.len() as u32);
|
||||
// BufferOffset (2 bytes) -- placeholder
|
||||
let offset_pos = cursor.position();
|
||||
cursor.write_u16_le(0);
|
||||
// Reserved (2 bytes)
|
||||
cursor.write_u16_le(0);
|
||||
// AdditionalInformation (4 bytes)
|
||||
cursor.write_u32_le(self.additional_information);
|
||||
// FileId (16 bytes)
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
|
||||
// Buffer (variable)
|
||||
if !self.buffer.is_empty() {
|
||||
// Offset is from the beginning of the SMB2 header per spec.
|
||||
let buf_offset = Header::SIZE + (cursor.position() - start);
|
||||
cursor.write_bytes(&self.buffer);
|
||||
cursor.set_u16_le_at(offset_pos, buf_offset as u16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for SetInfoRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let start = cursor.position();
|
||||
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid SetInfoRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// InfoType (1 byte)
|
||||
let info_type = InfoType::try_from(cursor.read_u8()?)?;
|
||||
// FileInfoClass (1 byte)
|
||||
let file_info_class = cursor.read_u8()?;
|
||||
// BufferLength (4 bytes)
|
||||
let buffer_length = cursor.read_u32_le()? as usize;
|
||||
// BufferOffset (2 bytes)
|
||||
let buf_offset = cursor.read_u16_le()? as usize;
|
||||
// Reserved (2 bytes)
|
||||
let _reserved = cursor.read_u16_le()?;
|
||||
// AdditionalInformation (4 bytes)
|
||||
let additional_information = cursor.read_u32_le()?;
|
||||
// FileId (16 bytes)
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
let file_id = FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
};
|
||||
|
||||
// Read buffer
|
||||
// Offset on the wire is from beginning of SMB2 header.
|
||||
let buffer = if buffer_length > 0 {
|
||||
let current = cursor.position();
|
||||
let body_offset = buf_offset.saturating_sub(Header::SIZE);
|
||||
let target = start + body_offset;
|
||||
if target > current {
|
||||
cursor.skip(target - current)?;
|
||||
}
|
||||
cursor.read_bytes_bounded(buffer_length)?.to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(SetInfoRequest {
|
||||
info_type,
|
||||
file_info_class,
|
||||
additional_information,
|
||||
file_id,
|
||||
buffer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── SetInfoResponse ──────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 SET_INFO response (spec section 2.2.40).
|
||||
///
|
||||
/// A minimal response indicating that the set operation succeeded.
|
||||
/// Contains only the 2-byte StructureSize field.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SetInfoResponse;
|
||||
|
||||
impl SetInfoResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 2;
|
||||
}
|
||||
|
||||
impl Pack for SetInfoResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for SetInfoResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid SetInfoResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(SetInfoResponse)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── SetInfoRequest tests ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn set_info_request_roundtrip_with_buffer() {
|
||||
let info_data = vec![0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02, 0x03, 0x04];
|
||||
|
||||
let original = SetInfoRequest {
|
||||
info_type: InfoType::File,
|
||||
file_info_class: 0x04, // FileBasicInformation
|
||||
additional_information: 0,
|
||||
file_id: FileId {
|
||||
persistent: 0xAAAA_BBBB_CCCC_DDDD,
|
||||
volatile: 0x1111_2222_3333_4444,
|
||||
},
|
||||
buffer: info_data.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SetInfoRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.info_type, InfoType::File);
|
||||
assert_eq!(decoded.file_info_class, 0x04);
|
||||
assert_eq!(decoded.additional_information, 0);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert_eq!(decoded.buffer, info_data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_info_request_security_info() {
|
||||
let sd_data = vec![0x01, 0x00, 0x04, 0x80, 0x00, 0x00, 0x00, 0x00];
|
||||
|
||||
let original = SetInfoRequest {
|
||||
info_type: InfoType::Security,
|
||||
file_info_class: 0,
|
||||
additional_information: 0x04, // DACL_SECURITY_INFORMATION
|
||||
file_id: FileId {
|
||||
persistent: 42,
|
||||
volatile: 99,
|
||||
},
|
||||
buffer: sd_data.clone(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SetInfoRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.info_type, InfoType::Security);
|
||||
assert_eq!(decoded.additional_information, 0x04);
|
||||
assert_eq!(decoded.buffer, sd_data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_info_request_structure_size() {
|
||||
let req = SetInfoRequest {
|
||||
info_type: InfoType::File,
|
||||
file_info_class: 0,
|
||||
additional_information: 0,
|
||||
file_id: FileId::default(),
|
||||
buffer: vec![0x01],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
req.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes[0], 33);
|
||||
assert_eq!(bytes[1], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_info_request_wrong_structure_size() {
|
||||
let mut buf = vec![0u8; 48];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = SetInfoRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── SetInfoResponse tests ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn set_info_response_roundtrip() {
|
||||
let original = SetInfoResponse;
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Only 2 bytes
|
||||
assert_eq!(bytes.len(), 2);
|
||||
assert_eq!(bytes, [0x02, 0x00]);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SetInfoResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded, SetInfoResponse);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_info_response_wrong_structure_size() {
|
||||
let bytes = [0x04, 0x00];
|
||||
let mut cursor = ReadCursor::new(&bytes);
|
||||
let result = SetInfoResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_info_response_too_short() {
|
||||
let bytes = [0x02];
|
||||
let mut cursor = ReadCursor::new(&bytes);
|
||||
let result = SetInfoResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id, arb_info_type};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn set_info_request_pack_unpack(
|
||||
info_type in arb_info_type(),
|
||||
file_info_class in any::<u8>(),
|
||||
additional_information in any::<u32>(),
|
||||
file_id in arb_file_id(),
|
||||
buffer in arb_bytes(),
|
||||
) {
|
||||
let original = SetInfoRequest {
|
||||
info_type,
|
||||
file_info_class,
|
||||
additional_information,
|
||||
file_id,
|
||||
buffer,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = SetInfoRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
}
|
||||
}
|
||||
}
|
||||
452
vendor/smb2/src/msg/transform.rs
vendored
Normal file
452
vendor/smb2/src/msg/transform.rs
vendored
Normal file
@@ -0,0 +1,452 @@
|
||||
//! SMB2 TRANSFORM_HEADER and COMPRESSION_TRANSFORM_HEADER
|
||||
//! (MS-SMB2 sections 2.2.41, 2.2.42).
|
||||
//!
|
||||
//! These headers wrap (encrypted or compressed) SMB2 messages. They are NOT
|
||||
//! SMB2 messages themselves -- they precede the actual message data.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::SessionId;
|
||||
use crate::Error;
|
||||
|
||||
// ── Transform header protocol IDs ──────────────────────────────────────
|
||||
|
||||
/// Protocol identifier for the encryption transform header (0xFD 'S' 'M' 'B').
|
||||
/// Note: this is NOT the normal SMB2 protocol ID (0xFE).
|
||||
pub const TRANSFORM_PROTOCOL_ID: [u8; 4] = [0xFD, b'S', b'M', b'B'];
|
||||
|
||||
/// Protocol identifier for the compression transform header (0xFC 'S' 'M' 'B').
|
||||
pub const COMPRESSION_PROTOCOL_ID: [u8; 4] = [0xFC, b'S', b'M', b'B'];
|
||||
|
||||
// ── Transform header flags ────────────────────────────────────────────
|
||||
|
||||
/// The message is encrypted.
|
||||
pub const SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED: u16 = 0x0001;
|
||||
|
||||
// ── CompressionAlgorithm values ────────────────────────────────────────
|
||||
|
||||
/// No compression.
|
||||
pub const COMPRESSION_ALGORITHM_NONE: u16 = 0x0000;
|
||||
|
||||
/// LZNT1 compression.
|
||||
pub const COMPRESSION_ALGORITHM_LZNT1: u16 = 0x0001;
|
||||
|
||||
/// LZ77 compression.
|
||||
pub const COMPRESSION_ALGORITHM_LZ77: u16 = 0x0002;
|
||||
|
||||
/// LZ77 with Huffman encoding.
|
||||
pub const COMPRESSION_ALGORITHM_LZ77_HUFFMAN: u16 = 0x0003;
|
||||
|
||||
/// Pattern_V1 compression.
|
||||
pub const COMPRESSION_ALGORITHM_PATTERN_V1: u16 = 0x0004;
|
||||
|
||||
/// LZ4 compression.
|
||||
pub const COMPRESSION_ALGORITHM_LZ4: u16 = 0x0005;
|
||||
|
||||
// ── Compression flags ──────────────────────────────────────────────────
|
||||
|
||||
/// No compression flags.
|
||||
pub const SMB2_COMPRESSION_FLAG_NONE: u16 = 0x0000;
|
||||
|
||||
/// Chained compression (multiple segments).
|
||||
pub const SMB2_COMPRESSION_FLAG_CHAINED: u16 = 0x0001;
|
||||
|
||||
// ── TransformHeader ────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 TRANSFORM_HEADER (MS-SMB2 section 2.2.41).
|
||||
///
|
||||
/// An encryption wrapper that precedes an encrypted SMB2 message.
|
||||
/// The total header is 52 bytes:
|
||||
/// - ProtocolId (4 bytes, must be 0xFD 'S' 'M' 'B')
|
||||
/// - Signature (16 bytes)
|
||||
/// - Nonce (16 bytes -- first 11 bytes used for AES-CCM, first 12 for AES-GCM)
|
||||
/// - OriginalMessageSize (4 bytes)
|
||||
/// - Reserved (2 bytes)
|
||||
/// - Flags (2 bytes)
|
||||
/// - SessionId (8 bytes)
|
||||
///
|
||||
/// The encrypted message data follows immediately after this header.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TransformHeader {
|
||||
/// 16-byte AES signature over the encrypted message.
|
||||
pub signature: [u8; 16],
|
||||
/// 16-byte nonce. Only the first 11 bytes are used for AES-CCM,
|
||||
/// and the first 12 bytes for AES-GCM. The remaining bytes must be zero.
|
||||
pub nonce: [u8; 16],
|
||||
/// Size of the original (unencrypted) SMB2 message in bytes.
|
||||
pub original_message_size: u32,
|
||||
/// Flags for the transform header. Use
|
||||
/// `SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED`.
|
||||
pub flags: u16,
|
||||
/// Session identifier for the encrypted message.
|
||||
pub session_id: SessionId,
|
||||
}
|
||||
|
||||
impl TransformHeader {
|
||||
/// Total header size in bytes (52).
|
||||
pub const SIZE: usize = 52;
|
||||
}
|
||||
|
||||
impl Pack for TransformHeader {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// ProtocolId (4 bytes)
|
||||
cursor.write_bytes(&TRANSFORM_PROTOCOL_ID);
|
||||
// Signature (16 bytes)
|
||||
cursor.write_bytes(&self.signature);
|
||||
// Nonce (16 bytes)
|
||||
cursor.write_bytes(&self.nonce);
|
||||
// OriginalMessageSize (4 bytes)
|
||||
cursor.write_u32_le(self.original_message_size);
|
||||
// Reserved (2 bytes)
|
||||
cursor.write_u16_le(0);
|
||||
// Flags (2 bytes)
|
||||
cursor.write_u16_le(self.flags);
|
||||
// SessionId (8 bytes)
|
||||
cursor.write_u64_le(self.session_id.0);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for TransformHeader {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// ProtocolId (4 bytes)
|
||||
let proto = cursor.read_bytes(4)?;
|
||||
if proto != TRANSFORM_PROTOCOL_ID {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid transform header protocol ID: expected {:02X?}, got {:02X?}",
|
||||
TRANSFORM_PROTOCOL_ID, proto
|
||||
)));
|
||||
}
|
||||
|
||||
// Signature (16 bytes)
|
||||
let sig_bytes = cursor.read_bytes(16)?;
|
||||
let mut signature = [0u8; 16];
|
||||
signature.copy_from_slice(sig_bytes);
|
||||
|
||||
// Nonce (16 bytes)
|
||||
let nonce_bytes = cursor.read_bytes(16)?;
|
||||
let mut nonce = [0u8; 16];
|
||||
nonce.copy_from_slice(nonce_bytes);
|
||||
|
||||
// OriginalMessageSize (4 bytes)
|
||||
let original_message_size = cursor.read_u32_le()?;
|
||||
|
||||
// Reserved (2 bytes)
|
||||
let _reserved = cursor.read_u16_le()?;
|
||||
|
||||
// Flags (2 bytes)
|
||||
let flags = cursor.read_u16_le()?;
|
||||
|
||||
// SessionId (8 bytes)
|
||||
let session_id = SessionId(cursor.read_u64_le()?);
|
||||
|
||||
Ok(TransformHeader {
|
||||
signature,
|
||||
nonce,
|
||||
original_message_size,
|
||||
flags,
|
||||
session_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── CompressionTransformHeader ─────────────────────────────────────────
|
||||
|
||||
/// SMB2 COMPRESSION_TRANSFORM_HEADER (MS-SMB2 section 2.2.42).
|
||||
///
|
||||
/// A compression wrapper that precedes a compressed SMB2 message.
|
||||
/// This implements the unchained variant (Flags = 0) only. The total
|
||||
/// header is 16 bytes:
|
||||
/// - ProtocolId (4 bytes, must be 0xFC 'S' 'M' 'B')
|
||||
/// - OriginalCompressedSegmentSize (4 bytes)
|
||||
/// - CompressionAlgorithm (2 bytes)
|
||||
/// - Flags (2 bytes)
|
||||
/// - Offset (4 bytes) -- offset from the end of this header to the
|
||||
/// start of compressed data
|
||||
///
|
||||
/// Note: The chained variant (Flags = SMB2_COMPRESSION_FLAG_CHAINED)
|
||||
/// interprets the last 4 bytes as Length instead of Offset. Chained
|
||||
/// compression is deferred to a future implementation.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CompressionTransformHeader {
|
||||
/// Size of the original uncompressed data segment.
|
||||
pub original_compressed_segment_size: u32,
|
||||
/// The compression algorithm used.
|
||||
pub compression_algorithm: u16,
|
||||
/// Compression flags. Currently only unchained (0x0000) is supported.
|
||||
pub flags: u16,
|
||||
/// For unchained: offset from end of this header to the start of
|
||||
/// compressed data. For chained: length of the original uncompressed
|
||||
/// segment (chained is not yet implemented).
|
||||
pub offset_or_length: u32,
|
||||
}
|
||||
|
||||
impl CompressionTransformHeader {
|
||||
/// Total header size in bytes (16).
|
||||
pub const SIZE: usize = 16;
|
||||
}
|
||||
|
||||
impl Pack for CompressionTransformHeader {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// ProtocolId (4 bytes)
|
||||
cursor.write_bytes(&COMPRESSION_PROTOCOL_ID);
|
||||
// OriginalCompressedSegmentSize (4 bytes)
|
||||
cursor.write_u32_le(self.original_compressed_segment_size);
|
||||
// CompressionAlgorithm (2 bytes)
|
||||
cursor.write_u16_le(self.compression_algorithm);
|
||||
// Flags (2 bytes)
|
||||
cursor.write_u16_le(self.flags);
|
||||
// Offset/Length (4 bytes)
|
||||
cursor.write_u32_le(self.offset_or_length);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for CompressionTransformHeader {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// ProtocolId (4 bytes)
|
||||
let proto = cursor.read_bytes(4)?;
|
||||
if proto != COMPRESSION_PROTOCOL_ID {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid compression transform header protocol ID: expected {:02X?}, got {:02X?}",
|
||||
COMPRESSION_PROTOCOL_ID, proto
|
||||
)));
|
||||
}
|
||||
|
||||
// OriginalCompressedSegmentSize (4 bytes)
|
||||
let original_compressed_segment_size = cursor.read_u32_le()?;
|
||||
|
||||
// CompressionAlgorithm (2 bytes)
|
||||
let compression_algorithm = cursor.read_u16_le()?;
|
||||
|
||||
// Flags (2 bytes)
|
||||
let flags = cursor.read_u16_le()?;
|
||||
|
||||
// Offset/Length (4 bytes)
|
||||
let offset_or_length = cursor.read_u32_le()?;
|
||||
|
||||
Ok(CompressionTransformHeader {
|
||||
original_compressed_segment_size,
|
||||
compression_algorithm,
|
||||
flags,
|
||||
offset_or_length,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── TransformHeader tests ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn transform_header_roundtrip() {
|
||||
let mut nonce = [0u8; 16];
|
||||
nonce[0..12].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
|
||||
|
||||
let original = TransformHeader {
|
||||
signature: [
|
||||
0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88,
|
||||
0x99, 0x00,
|
||||
],
|
||||
nonce,
|
||||
original_message_size: 1024,
|
||||
flags: SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED,
|
||||
session_id: SessionId(0xDEAD_BEEF_CAFE_FACE),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes.len(), TransformHeader::SIZE);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TransformHeader::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.signature, original.signature);
|
||||
assert_eq!(decoded.nonce, original.nonce);
|
||||
assert_eq!(decoded.original_message_size, 1024);
|
||||
assert_eq!(decoded.flags, SMB2_TRANSFORM_HEADER_FLAG_ENCRYPTED);
|
||||
assert_eq!(decoded.session_id, SessionId(0xDEAD_BEEF_CAFE_FACE));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_header_protocol_id_is_0xfd() {
|
||||
let original = TransformHeader {
|
||||
signature: [0u8; 16],
|
||||
nonce: [0u8; 16],
|
||||
original_message_size: 0,
|
||||
flags: 0,
|
||||
session_id: SessionId(0),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// First 4 bytes must be 0xFD 'S' 'M' 'B', NOT 0xFE
|
||||
assert_eq!(bytes[0], 0xFD);
|
||||
assert_eq!(bytes[1], b'S');
|
||||
assert_eq!(bytes[2], b'M');
|
||||
assert_eq!(bytes[3], b'B');
|
||||
assert_ne!(bytes[0], 0xFE, "transform header must use 0xFD, not 0xFE");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transform_header_wrong_protocol_id() {
|
||||
let mut buf = [0u8; TransformHeader::SIZE];
|
||||
// Use the normal SMB2 protocol ID (0xFE) instead of 0xFD
|
||||
buf[0..4].copy_from_slice(&[0xFE, b'S', b'M', b'B']);
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = TransformHeader::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("protocol ID"), "error was: {err}");
|
||||
}
|
||||
|
||||
// ── CompressionTransformHeader tests ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn compression_transform_header_roundtrip_unchained() {
|
||||
let original = CompressionTransformHeader {
|
||||
original_compressed_segment_size: 4096,
|
||||
compression_algorithm: COMPRESSION_ALGORITHM_LZ77,
|
||||
flags: SMB2_COMPRESSION_FLAG_NONE,
|
||||
offset_or_length: 64,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
assert_eq!(bytes.len(), CompressionTransformHeader::SIZE);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CompressionTransformHeader::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.original_compressed_segment_size, 4096);
|
||||
assert_eq!(decoded.compression_algorithm, COMPRESSION_ALGORITHM_LZ77);
|
||||
assert_eq!(decoded.flags, SMB2_COMPRESSION_FLAG_NONE);
|
||||
assert_eq!(decoded.offset_or_length, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compression_transform_header_protocol_id_is_0xfc() {
|
||||
let original = CompressionTransformHeader {
|
||||
original_compressed_segment_size: 0,
|
||||
compression_algorithm: COMPRESSION_ALGORITHM_NONE,
|
||||
flags: SMB2_COMPRESSION_FLAG_NONE,
|
||||
offset_or_length: 0,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// First 4 bytes must be 0xFC 'S' 'M' 'B'
|
||||
assert_eq!(bytes[0], 0xFC);
|
||||
assert_eq!(bytes[1], b'S');
|
||||
assert_eq!(bytes[2], b'M');
|
||||
assert_eq!(bytes[3], b'B');
|
||||
assert_ne!(
|
||||
bytes[0], 0xFE,
|
||||
"compression transform header must use 0xFC, not 0xFE"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compression_transform_header_wrong_protocol_id() {
|
||||
let mut buf = [0u8; CompressionTransformHeader::SIZE];
|
||||
// Use wrong protocol ID
|
||||
buf[0..4].copy_from_slice(&[0xFE, b'S', b'M', b'B']);
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = CompressionTransformHeader::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("protocol ID"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compression_transform_header_lz77_huffman() {
|
||||
let original = CompressionTransformHeader {
|
||||
original_compressed_segment_size: 8192,
|
||||
compression_algorithm: COMPRESSION_ALGORITHM_LZ77_HUFFMAN,
|
||||
flags: SMB2_COMPRESSION_FLAG_NONE,
|
||||
offset_or_length: 128,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CompressionTransformHeader::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
decoded.compression_algorithm,
|
||||
COMPRESSION_ALGORITHM_LZ77_HUFFMAN
|
||||
);
|
||||
assert_eq!(decoded.original_compressed_segment_size, 8192);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::arb_session_id;
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn transform_header_pack_unpack(
|
||||
signature in any::<[u8; 16]>(),
|
||||
nonce in any::<[u8; 16]>(),
|
||||
original_message_size in any::<u32>(),
|
||||
flags in any::<u16>(),
|
||||
session_id in arb_session_id(),
|
||||
) {
|
||||
let original = TransformHeader {
|
||||
signature,
|
||||
nonce,
|
||||
original_message_size,
|
||||
flags,
|
||||
session_id,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
prop_assert_eq!(bytes.len(), TransformHeader::SIZE);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TransformHeader::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compression_transform_header_pack_unpack(
|
||||
original_compressed_segment_size in any::<u32>(),
|
||||
compression_algorithm in any::<u16>(),
|
||||
flags in any::<u16>(),
|
||||
offset_or_length in any::<u32>(),
|
||||
) {
|
||||
let original = CompressionTransformHeader {
|
||||
original_compressed_segment_size,
|
||||
compression_algorithm,
|
||||
flags,
|
||||
offset_or_length,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
prop_assert_eq!(bytes.len(), CompressionTransformHeader::SIZE);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = CompressionTransformHeader::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
477
vendor/smb2/src/msg/tree_connect.rs
vendored
Normal file
477
vendor/smb2/src/msg/tree_connect.rs
vendored
Normal file
@@ -0,0 +1,477 @@
|
||||
//! SMB2 TREE_CONNECT request and response (spec sections 2.2.9, 2.2.10).
|
||||
//!
|
||||
//! Tree connect messages establish access to a share on the server.
|
||||
//! The request contains a UTF-16LE encoded share path (for example,
|
||||
//! `\\server\share`), and the response contains share metadata such as
|
||||
//! the share type, flags, capabilities, and maximal access rights.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::msg::header::Header;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::flags::{ShareCapabilities, ShareFlags};
|
||||
use crate::Error;
|
||||
|
||||
// ── Share type ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Type of share being accessed (spec section 2.2.10).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum ShareType {
|
||||
/// Physical disk share.
|
||||
Disk = 0x01,
|
||||
/// Named pipe share.
|
||||
Pipe = 0x02,
|
||||
/// Printer share.
|
||||
Print = 0x03,
|
||||
}
|
||||
|
||||
impl ShareType {
|
||||
/// Try to convert a raw `u8` to a `ShareType`.
|
||||
pub fn try_from_u8(val: u8) -> Result<Self> {
|
||||
match val {
|
||||
0x01 => Ok(ShareType::Disk),
|
||||
0x02 => Ok(ShareType::Pipe),
|
||||
0x03 => Ok(ShareType::Print),
|
||||
other => Err(Error::invalid_data(format!(
|
||||
"invalid share type: 0x{:02X}",
|
||||
other
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tree connect request flags ─────────────────────────────────────────
|
||||
|
||||
/// Flags for the TREE_CONNECT request (spec section 2.2.9, SMB 3.1.1 only).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct TreeConnectRequestFlags(pub u16);
|
||||
|
||||
impl TreeConnectRequestFlags {
|
||||
/// Client has previously connected to the specified cluster share.
|
||||
pub const CLUSTER_RECONNECT: u16 = 0x0001;
|
||||
/// Client can handle synchronous share redirects.
|
||||
pub const REDIRECT_TO_OWNER: u16 = 0x0002;
|
||||
/// Tree connect request extension is present.
|
||||
pub const EXTENSION_PRESENT: u16 = 0x0004;
|
||||
}
|
||||
|
||||
// ── TreeConnectRequest ─────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 TREE_CONNECT request (spec section 2.2.9).
|
||||
///
|
||||
/// Sent by the client to request access to a particular share on the
|
||||
/// server. The path is a Unicode string in the form `\\server\share`.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TreeConnectRequest {
|
||||
/// Flags controlling the request (SMB 3.1.1 only, otherwise 0).
|
||||
pub flags: TreeConnectRequestFlags,
|
||||
/// Full share path name in UTF-8 (encoded as UTF-16LE on the wire).
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
impl TreeConnectRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 9;
|
||||
}
|
||||
|
||||
impl Pack for TreeConnectRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// Flags/Reserved (2 bytes)
|
||||
cursor.write_u16_le(self.flags.0);
|
||||
|
||||
// Compute path length in UTF-16LE bytes
|
||||
let path_u16: Vec<u16> = self.path.encode_utf16().collect();
|
||||
let path_byte_len = path_u16.len() * 2;
|
||||
|
||||
// PathOffset (2 bytes) -- offset from start of SMB2 header
|
||||
let offset = (Header::SIZE + 8) as u16; // 8 = fixed part of this struct
|
||||
cursor.write_u16_le(offset);
|
||||
// PathLength (2 bytes)
|
||||
cursor.write_u16_le(path_byte_len as u16);
|
||||
// Buffer: path in UTF-16LE
|
||||
cursor.write_utf16_le(&self.path);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for TreeConnectRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid TreeConnectRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// Flags/Reserved (2 bytes)
|
||||
let flags = TreeConnectRequestFlags(cursor.read_u16_le()?);
|
||||
// PathOffset (2 bytes) -- we ignore, read sequentially
|
||||
let _offset = cursor.read_u16_le()?;
|
||||
// PathLength (2 bytes)
|
||||
let path_length = cursor.read_u16_le()? as usize;
|
||||
// Buffer: path in UTF-16LE
|
||||
if path_length > ReadCursor::MAX_UNPACK_BUFFER {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"buffer size {} exceeds maximum {} bytes",
|
||||
path_length,
|
||||
ReadCursor::MAX_UNPACK_BUFFER
|
||||
)));
|
||||
}
|
||||
let path = cursor.read_utf16_le(path_length)?;
|
||||
|
||||
Ok(TreeConnectRequest { flags, path })
|
||||
}
|
||||
}
|
||||
|
||||
// ── TreeConnectResponse ────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 TREE_CONNECT response (spec section 2.2.10).
|
||||
///
|
||||
/// Sent by the server when a TREE_CONNECT request is processed
|
||||
/// successfully. Contains share metadata.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TreeConnectResponse {
|
||||
/// The type of share being accessed (disk, pipe, or print).
|
||||
pub share_type: ShareType,
|
||||
/// Properties for this share.
|
||||
pub share_flags: ShareFlags,
|
||||
/// Capabilities for this share.
|
||||
pub capabilities: ShareCapabilities,
|
||||
/// Maximum access rights for the connecting user.
|
||||
pub maximal_access: u32,
|
||||
}
|
||||
|
||||
impl TreeConnectResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 16;
|
||||
}
|
||||
|
||||
impl Pack for TreeConnectResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
// StructureSize (2 bytes)
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
// ShareType (1 byte)
|
||||
cursor.write_u8(self.share_type as u8);
|
||||
// Reserved (1 byte)
|
||||
cursor.write_u8(0);
|
||||
// ShareFlags (4 bytes)
|
||||
cursor.write_u32_le(self.share_flags.bits());
|
||||
// Capabilities (4 bytes)
|
||||
cursor.write_u32_le(self.capabilities.bits());
|
||||
// MaximalAccess (4 bytes)
|
||||
cursor.write_u32_le(self.maximal_access);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for TreeConnectResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
// StructureSize (2 bytes)
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid TreeConnectResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
// ShareType (1 byte)
|
||||
let share_type = ShareType::try_from_u8(cursor.read_u8()?)?;
|
||||
// Reserved (1 byte)
|
||||
let _reserved = cursor.read_u8()?;
|
||||
// ShareFlags (4 bytes)
|
||||
let share_flags = ShareFlags::new(cursor.read_u32_le()?);
|
||||
// Capabilities (4 bytes)
|
||||
let capabilities = ShareCapabilities::new(cursor.read_u32_le()?);
|
||||
// MaximalAccess (4 bytes)
|
||||
let maximal_access = cursor.read_u32_le()?;
|
||||
|
||||
Ok(TreeConnectResponse {
|
||||
share_type,
|
||||
share_flags,
|
||||
capabilities,
|
||||
maximal_access,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── TreeConnectRequest tests ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn tree_connect_request_roundtrip() {
|
||||
let original = TreeConnectRequest {
|
||||
flags: TreeConnectRequestFlags::default(),
|
||||
path: r"\\server\share".to_string(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TreeConnectRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.flags, original.flags);
|
||||
assert_eq!(decoded.path, original.path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_request_with_utf16_path() {
|
||||
let path = r"\\myserver.example.com\IPC$";
|
||||
let original = TreeConnectRequest {
|
||||
flags: TreeConnectRequestFlags::default(),
|
||||
path: path.to_string(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TreeConnectRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.path, path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_request_structure_size_field() {
|
||||
let req = TreeConnectRequest {
|
||||
flags: TreeConnectRequestFlags::default(),
|
||||
path: r"\\s\d".to_string(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
req.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// First 2 bytes are structure size = 9
|
||||
assert_eq!(u16::from_le_bytes([bytes[0], bytes[1]]), 9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_request_wrong_structure_size() {
|
||||
let mut buf = [0u8; 20];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = TreeConnectRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_request_with_flags() {
|
||||
let original = TreeConnectRequest {
|
||||
flags: TreeConnectRequestFlags(TreeConnectRequestFlags::CLUSTER_RECONNECT),
|
||||
path: r"\\s\d".to_string(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TreeConnectRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.flags.0, TreeConnectRequestFlags::CLUSTER_RECONNECT);
|
||||
}
|
||||
|
||||
// ── TreeConnectResponse tests ──────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn tree_connect_response_roundtrip_disk() {
|
||||
let original = TreeConnectResponse {
|
||||
share_type: ShareType::Disk,
|
||||
share_flags: ShareFlags::new(ShareFlags::DFS | ShareFlags::ACCESS_BASED_DIRECTORY_ENUM),
|
||||
capabilities: ShareCapabilities::new(ShareCapabilities::DFS),
|
||||
maximal_access: 0x001F_01FF,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TreeConnectResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.share_type, ShareType::Disk);
|
||||
assert_eq!(decoded.share_flags.bits(), original.share_flags.bits());
|
||||
assert_eq!(decoded.capabilities.bits(), original.capabilities.bits());
|
||||
assert_eq!(decoded.maximal_access, 0x001F_01FF);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_response_roundtrip_pipe() {
|
||||
let original = TreeConnectResponse {
|
||||
share_type: ShareType::Pipe,
|
||||
share_flags: ShareFlags::default(),
|
||||
capabilities: ShareCapabilities::default(),
|
||||
maximal_access: 0x0012_019F,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TreeConnectResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.share_type, ShareType::Pipe);
|
||||
assert_eq!(decoded.maximal_access, 0x0012_019F);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_response_roundtrip_print() {
|
||||
let original = TreeConnectResponse {
|
||||
share_type: ShareType::Print,
|
||||
share_flags: ShareFlags::new(ShareFlags::ENCRYPT_DATA),
|
||||
capabilities: ShareCapabilities::new(
|
||||
ShareCapabilities::CONTINUOUS_AVAILABILITY | ShareCapabilities::CLUSTER,
|
||||
),
|
||||
maximal_access: 0,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TreeConnectResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.share_type, ShareType::Print);
|
||||
assert!(decoded.share_flags.contains(ShareFlags::ENCRYPT_DATA));
|
||||
assert!(decoded
|
||||
.capabilities
|
||||
.contains(ShareCapabilities::CONTINUOUS_AVAILABILITY));
|
||||
assert!(decoded.capabilities.contains(ShareCapabilities::CLUSTER));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_response_structure_size_field() {
|
||||
let resp = TreeConnectResponse {
|
||||
share_type: ShareType::Disk,
|
||||
share_flags: ShareFlags::default(),
|
||||
capabilities: ShareCapabilities::default(),
|
||||
maximal_access: 0,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
resp.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// First 2 bytes are structure size = 16
|
||||
assert_eq!(u16::from_le_bytes([bytes[0], bytes[1]]), 16);
|
||||
// Total packed size: 2 + 1 + 1 + 4 + 4 + 4 = 16
|
||||
assert_eq!(bytes.len(), 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_response_wrong_structure_size() {
|
||||
let mut buf = [0u8; 16];
|
||||
buf[0..2].copy_from_slice(&99u16.to_le_bytes());
|
||||
buf[2] = 0x01; // valid share type
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = TreeConnectResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_response_invalid_share_type() {
|
||||
let mut buf = [0u8; 16];
|
||||
buf[0..2].copy_from_slice(&16u16.to_le_bytes());
|
||||
buf[2] = 0xFF; // invalid share type
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = TreeConnectResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("share type"), "error was: {err}");
|
||||
}
|
||||
|
||||
// Roundtrip property tests live in `roundtrip_props` at file end.
|
||||
|
||||
#[test]
|
||||
fn tree_connect_response_known_bytes() {
|
||||
// Known bytes from smb-rs test: share_type=Disk, share_flags=0x00000800,
|
||||
// capabilities=0, maximal_access=0x001f01ff
|
||||
let bytes: Vec<u8> = vec![
|
||||
0x10, 0x00, // StructureSize = 16
|
||||
0x01, // ShareType = Disk
|
||||
0x00, // Reserved
|
||||
0x00, 0x08, 0x00, 0x00, // ShareFlags = 0x00000800
|
||||
0x00, 0x00, 0x00, 0x00, // Capabilities = 0
|
||||
0xFF, 0x01, 0x1F, 0x00, // MaximalAccess = 0x001f01ff
|
||||
];
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TreeConnectResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.share_type, ShareType::Disk);
|
||||
assert!(decoded
|
||||
.share_flags
|
||||
.contains(ShareFlags::ACCESS_BASED_DIRECTORY_ENUM));
|
||||
assert_eq!(decoded.maximal_access, 0x001F_01FF);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{
|
||||
arb_share_capabilities, arb_share_flags, arb_share_type, arb_utf16_string,
|
||||
};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn tree_connect_request_pack_unpack(
|
||||
flags_raw in any::<u16>(),
|
||||
// Path is sent as UTF-16LE. Generate strings that survive that
|
||||
// encoding cleanly (no unpaired surrogates).
|
||||
path in arb_utf16_string(128),
|
||||
) {
|
||||
let original = TreeConnectRequest {
|
||||
flags: TreeConnectRequestFlags(flags_raw),
|
||||
path,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TreeConnectRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tree_connect_response_pack_unpack(
|
||||
share_type in arb_share_type(),
|
||||
share_flags in arb_share_flags(),
|
||||
capabilities in arb_share_capabilities(),
|
||||
maximal_access in any::<u32>(),
|
||||
) {
|
||||
let original = TreeConnectResponse {
|
||||
share_type,
|
||||
share_flags,
|
||||
capabilities,
|
||||
maximal_access,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = TreeConnectResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
43
vendor/smb2/src/msg/tree_disconnect.rs
vendored
Normal file
43
vendor/smb2/src/msg/tree_disconnect.rs
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
//! SMB2 TREE_DISCONNECT request and response (spec sections 2.2.11, 2.2.12).
|
||||
//!
|
||||
//! Tree disconnect messages request and confirm disconnection from a share.
|
||||
//! Both request and response contain only a StructureSize field and a
|
||||
//! reserved field, for a total of 4 bytes each.
|
||||
|
||||
super::trivial_message! {
|
||||
/// SMB2 TREE_DISCONNECT request (spec section 2.2.11).
|
||||
///
|
||||
/// Sent by the client to request that the tree connect specified in the
|
||||
/// TreeId within the SMB2 header be disconnected.
|
||||
/// Contains only StructureSize (2 bytes) and Reserved (2 bytes).
|
||||
pub struct TreeDisconnectRequest;
|
||||
}
|
||||
|
||||
super::trivial_message! {
|
||||
/// SMB2 TREE_DISCONNECT response (spec section 2.2.12).
|
||||
///
|
||||
/// Sent by the server to confirm that a TREE_DISCONNECT request was processed.
|
||||
/// Contains only StructureSize (2 bytes) and Reserved (2 bytes).
|
||||
pub struct TreeDisconnectResponse;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
super::super::trivial_message_tests!(
|
||||
TreeDisconnectRequest,
|
||||
tree_disconnect_request_known_bytes,
|
||||
tree_disconnect_request_roundtrip,
|
||||
tree_disconnect_request_wrong_structure_size,
|
||||
tree_disconnect_request_too_short
|
||||
);
|
||||
|
||||
super::super::trivial_message_tests!(
|
||||
TreeDisconnectResponse,
|
||||
tree_disconnect_response_known_bytes,
|
||||
tree_disconnect_response_roundtrip,
|
||||
tree_disconnect_response_wrong_structure_size,
|
||||
tree_disconnect_response_too_short
|
||||
);
|
||||
}
|
||||
446
vendor/smb2/src/msg/write.rs
vendored
Normal file
446
vendor/smb2/src/msg/write.rs
vendored
Normal file
@@ -0,0 +1,446 @@
|
||||
//! SMB2 WRITE Request and Response (MS-SMB2 sections 2.2.21, 2.2.22).
|
||||
//!
|
||||
//! The WRITE request writes data to a file or named pipe.
|
||||
//! The response reports how many bytes were written.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::FileId;
|
||||
use crate::Error;
|
||||
|
||||
/// Write flag: server performs write-through (SMB 2.1+).
|
||||
pub const SMB2_WRITEFLAG_WRITE_THROUGH: u32 = 0x0000_0001;
|
||||
|
||||
/// Write flag: file buffering is not performed (SMB 3.0.2+).
|
||||
pub const SMB2_WRITEFLAG_WRITE_UNBUFFERED: u32 = 0x0000_0002;
|
||||
|
||||
/// SMB2 WRITE Request (MS-SMB2 section 2.2.21).
|
||||
///
|
||||
/// Sent by the client to write data to a file. The fixed portion is 49 bytes
|
||||
/// (StructureSize says 49 regardless of the variable buffer length):
|
||||
/// - StructureSize (2 bytes, must be 49)
|
||||
/// - DataOffset (2 bytes)
|
||||
/// - Length (4 bytes)
|
||||
/// - Offset (8 bytes)
|
||||
/// - FileId (16 bytes)
|
||||
/// - Channel (4 bytes)
|
||||
/// - RemainingBytes (4 bytes)
|
||||
/// - WriteChannelInfoOffset (2 bytes)
|
||||
/// - WriteChannelInfoLength (2 bytes)
|
||||
/// - Flags (4 bytes)
|
||||
/// - Buffer (variable, Length bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct WriteRequest {
|
||||
/// Offset from the beginning of the SMB2 header to the write data.
|
||||
pub data_offset: u16,
|
||||
/// File offset to start writing at.
|
||||
pub offset: u64,
|
||||
/// File handle to write to.
|
||||
pub file_id: FileId,
|
||||
/// Channel for RDMA operations (typically 0 = SMB2_CHANNEL_NONE).
|
||||
pub channel: u32,
|
||||
/// Remaining bytes in a multi-part write.
|
||||
pub remaining_bytes: u32,
|
||||
/// Write channel info offset (typically 0).
|
||||
pub write_channel_info_offset: u16,
|
||||
/// Write channel info length (typically 0).
|
||||
pub write_channel_info_length: u16,
|
||||
/// Flags for the write operation.
|
||||
pub flags: u32,
|
||||
/// The data to write.
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl WriteRequest {
|
||||
pub const STRUCTURE_SIZE: u16 = 49;
|
||||
}
|
||||
|
||||
impl Pack for WriteRequest {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
cursor.write_u16_le(self.data_offset);
|
||||
cursor.write_u32_le(self.data.len() as u32); // Length
|
||||
cursor.write_u64_le(self.offset);
|
||||
cursor.write_u64_le(self.file_id.persistent);
|
||||
cursor.write_u64_le(self.file_id.volatile);
|
||||
cursor.write_u32_le(self.channel);
|
||||
cursor.write_u32_le(self.remaining_bytes);
|
||||
cursor.write_u16_le(self.write_channel_info_offset);
|
||||
cursor.write_u16_le(self.write_channel_info_length);
|
||||
cursor.write_u32_le(self.flags);
|
||||
|
||||
// Buffer: write the data (may be empty for zero-length writes).
|
||||
// Per StructureSize=49 contract, at least 1 byte is implied.
|
||||
if self.data.is_empty() {
|
||||
cursor.write_u8(0);
|
||||
} else {
|
||||
cursor.write_bytes(&self.data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for WriteRequest {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid WriteRequest structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let data_offset = cursor.read_u16_le()?;
|
||||
let length = cursor.read_u32_le()?;
|
||||
let offset = cursor.read_u64_le()?;
|
||||
let persistent = cursor.read_u64_le()?;
|
||||
let volatile = cursor.read_u64_le()?;
|
||||
let channel = cursor.read_u32_le()?;
|
||||
let remaining_bytes = cursor.read_u32_le()?;
|
||||
let write_channel_info_offset = cursor.read_u16_le()?;
|
||||
let write_channel_info_length = cursor.read_u16_le()?;
|
||||
let flags = cursor.read_u32_le()?;
|
||||
|
||||
let data = if length > 0 {
|
||||
cursor.read_bytes_bounded(length as usize)?.to_vec()
|
||||
} else {
|
||||
// Skip the minimum 1-byte buffer
|
||||
cursor.skip(1)?;
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(WriteRequest {
|
||||
data_offset,
|
||||
offset,
|
||||
file_id: FileId {
|
||||
persistent,
|
||||
volatile,
|
||||
},
|
||||
channel,
|
||||
remaining_bytes,
|
||||
write_channel_info_offset,
|
||||
write_channel_info_length,
|
||||
flags,
|
||||
data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// SMB2 WRITE Response (MS-SMB2 section 2.2.22).
|
||||
///
|
||||
/// Sent by the server to confirm a write. The structure is 17 bytes:
|
||||
/// - StructureSize (2 bytes, must be 17)
|
||||
/// - Reserved (2 bytes)
|
||||
/// - Count (4 bytes)
|
||||
/// - Remaining (4 bytes)
|
||||
/// - WriteChannelInfoOffset (2 bytes)
|
||||
/// - WriteChannelInfoLength (2 bytes)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct WriteResponse {
|
||||
/// Number of bytes written.
|
||||
pub count: u32,
|
||||
/// Reserved remaining field (must be 0).
|
||||
pub remaining: u32,
|
||||
/// Reserved write channel info offset (must be 0).
|
||||
pub write_channel_info_offset: u16,
|
||||
/// Reserved write channel info length (must be 0).
|
||||
pub write_channel_info_length: u16,
|
||||
}
|
||||
|
||||
impl WriteResponse {
|
||||
pub const STRUCTURE_SIZE: u16 = 17;
|
||||
}
|
||||
|
||||
impl Pack for WriteResponse {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u16_le(Self::STRUCTURE_SIZE);
|
||||
cursor.write_u16_le(0); // Reserved
|
||||
cursor.write_u32_le(self.count);
|
||||
cursor.write_u32_le(self.remaining);
|
||||
cursor.write_u16_le(self.write_channel_info_offset);
|
||||
cursor.write_u16_le(self.write_channel_info_length);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for WriteResponse {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let structure_size = cursor.read_u16_le()?;
|
||||
if structure_size != Self::STRUCTURE_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid WriteResponse structure size: expected {}, got {}",
|
||||
Self::STRUCTURE_SIZE,
|
||||
structure_size
|
||||
)));
|
||||
}
|
||||
|
||||
let _reserved = cursor.read_u16_le()?;
|
||||
let count = cursor.read_u32_le()?;
|
||||
let remaining = cursor.read_u32_le()?;
|
||||
let write_channel_info_offset = cursor.read_u16_le()?;
|
||||
let write_channel_info_length = cursor.read_u16_le()?;
|
||||
|
||||
Ok(WriteResponse {
|
||||
count,
|
||||
remaining,
|
||||
write_channel_info_offset,
|
||||
write_channel_info_length,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── WriteRequest tests ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn write_request_roundtrip() {
|
||||
let original = WriteRequest {
|
||||
data_offset: 0x70, // 64 (header) + 48 (fixed body) = 112 = 0x70
|
||||
offset: 0x2000,
|
||||
file_id: FileId {
|
||||
persistent: 0xAAAA_BBBB_CCCC_DDDD,
|
||||
volatile: 0x1111_2222_3333_4444,
|
||||
},
|
||||
channel: 0,
|
||||
remaining_bytes: 0,
|
||||
write_channel_info_offset: 0,
|
||||
write_channel_info_length: 0,
|
||||
flags: SMB2_WRITEFLAG_WRITE_THROUGH,
|
||||
data: vec![0x48, 0x65, 0x6C, 0x6C, 0x6F], // "Hello"
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed: 48 bytes + 5 bytes data = 53 bytes
|
||||
assert_eq!(bytes.len(), 53);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = WriteRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.data_offset, original.data_offset);
|
||||
assert_eq!(decoded.offset, original.offset);
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
assert_eq!(decoded.channel, original.channel);
|
||||
assert_eq!(decoded.remaining_bytes, original.remaining_bytes);
|
||||
assert_eq!(decoded.flags, original.flags);
|
||||
assert_eq!(decoded.data, original.data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_request_empty_data_roundtrip() {
|
||||
let original = WriteRequest {
|
||||
data_offset: 0x70,
|
||||
offset: 0,
|
||||
file_id: FileId {
|
||||
persistent: 1,
|
||||
volatile: 2,
|
||||
},
|
||||
channel: 0,
|
||||
remaining_bytes: 0,
|
||||
write_channel_info_offset: 0,
|
||||
write_channel_info_length: 0,
|
||||
flags: 0,
|
||||
data: Vec::new(),
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// Fixed: 48 bytes + 1-byte minimum buffer = 49 bytes
|
||||
assert_eq!(bytes.len(), 49);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = WriteRequest::unpack(&mut r).unwrap();
|
||||
|
||||
assert!(decoded.data.is_empty());
|
||||
assert_eq!(decoded.file_id, original.file_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_request_wrong_structure_size() {
|
||||
let mut buf = [0u8; 49];
|
||||
buf[0..2].copy_from_slice(&48u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = WriteRequest::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_request_known_bytes() {
|
||||
let mut buf = Vec::new();
|
||||
// StructureSize = 49
|
||||
buf.extend_from_slice(&49u16.to_le_bytes());
|
||||
// DataOffset = 0x70
|
||||
buf.extend_from_slice(&0x70u16.to_le_bytes());
|
||||
// Length = 2
|
||||
buf.extend_from_slice(&2u32.to_le_bytes());
|
||||
// Offset = 0
|
||||
buf.extend_from_slice(&0u64.to_le_bytes());
|
||||
// FileId persistent = 0x10
|
||||
buf.extend_from_slice(&0x10u64.to_le_bytes());
|
||||
// FileId volatile = 0x20
|
||||
buf.extend_from_slice(&0x20u64.to_le_bytes());
|
||||
// Channel = 0
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
// RemainingBytes = 0
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
// WriteChannelInfoOffset = 0
|
||||
buf.extend_from_slice(&0u16.to_le_bytes());
|
||||
// WriteChannelInfoLength = 0
|
||||
buf.extend_from_slice(&0u16.to_le_bytes());
|
||||
// Flags = WRITE_THROUGH
|
||||
buf.extend_from_slice(&1u32.to_le_bytes());
|
||||
// Buffer = [0xAA, 0xBB]
|
||||
buf.extend_from_slice(&[0xAA, 0xBB]);
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let req = WriteRequest::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(req.data_offset, 0x70);
|
||||
assert_eq!(req.file_id.persistent, 0x10);
|
||||
assert_eq!(req.file_id.volatile, 0x20);
|
||||
assert_eq!(req.flags, SMB2_WRITEFLAG_WRITE_THROUGH);
|
||||
assert_eq!(req.data, vec![0xAA, 0xBB]);
|
||||
}
|
||||
|
||||
// ── WriteResponse tests ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn write_response_roundtrip() {
|
||||
let original = WriteResponse {
|
||||
count: 65536,
|
||||
remaining: 0,
|
||||
write_channel_info_offset: 0,
|
||||
write_channel_info_length: 0,
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
// 2 + 2 + 4 + 4 + 2 + 2 = 16 bytes
|
||||
assert_eq!(bytes.len(), 16);
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = WriteResponse::unpack(&mut r).unwrap();
|
||||
|
||||
assert_eq!(decoded.count, original.count);
|
||||
assert_eq!(decoded.remaining, original.remaining);
|
||||
assert_eq!(
|
||||
decoded.write_channel_info_offset,
|
||||
original.write_channel_info_offset
|
||||
);
|
||||
assert_eq!(
|
||||
decoded.write_channel_info_length,
|
||||
original.write_channel_info_length
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_response_known_bytes() {
|
||||
let mut buf = Vec::new();
|
||||
// StructureSize = 17
|
||||
buf.extend_from_slice(&17u16.to_le_bytes());
|
||||
// Reserved = 0
|
||||
buf.extend_from_slice(&0u16.to_le_bytes());
|
||||
// Count = 1024
|
||||
buf.extend_from_slice(&1024u32.to_le_bytes());
|
||||
// Remaining = 0
|
||||
buf.extend_from_slice(&0u32.to_le_bytes());
|
||||
// WriteChannelInfoOffset = 0
|
||||
buf.extend_from_slice(&0u16.to_le_bytes());
|
||||
// WriteChannelInfoLength = 0
|
||||
buf.extend_from_slice(&0u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let resp = WriteResponse::unpack(&mut cursor).unwrap();
|
||||
|
||||
assert_eq!(resp.count, 1024);
|
||||
assert_eq!(resp.remaining, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_response_wrong_structure_size() {
|
||||
let mut buf = [0u8; 16];
|
||||
buf[0..2].copy_from_slice(&16u16.to_le_bytes());
|
||||
|
||||
let mut cursor = ReadCursor::new(&buf);
|
||||
let result = WriteResponse::unpack(&mut cursor);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("structure size"), "error was: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod roundtrip_props {
|
||||
use super::*;
|
||||
use crate::msg::roundtrip_strategies::{arb_bytes, arb_file_id};
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn write_request_pack_unpack(
|
||||
data_offset in any::<u16>(),
|
||||
offset in any::<u64>(),
|
||||
file_id in arb_file_id(),
|
||||
channel in any::<u32>(),
|
||||
remaining_bytes in any::<u32>(),
|
||||
write_channel_info_offset in any::<u16>(),
|
||||
write_channel_info_length in any::<u16>(),
|
||||
flags in any::<u32>(),
|
||||
data in arb_bytes(),
|
||||
) {
|
||||
let original = WriteRequest {
|
||||
data_offset,
|
||||
offset,
|
||||
file_id,
|
||||
channel,
|
||||
remaining_bytes,
|
||||
write_channel_info_offset,
|
||||
write_channel_info_length,
|
||||
flags,
|
||||
data,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = WriteRequest::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_response_pack_unpack(
|
||||
count in any::<u32>(),
|
||||
remaining in any::<u32>(),
|
||||
write_channel_info_offset in any::<u16>(),
|
||||
write_channel_info_length in any::<u16>(),
|
||||
) {
|
||||
let original = WriteResponse {
|
||||
count,
|
||||
remaining,
|
||||
write_channel_info_offset,
|
||||
write_channel_info_length,
|
||||
};
|
||||
let mut w = WriteCursor::new();
|
||||
original.pack(&mut w);
|
||||
let bytes = w.into_inner();
|
||||
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = WriteResponse::unpack(&mut r).unwrap();
|
||||
prop_assert_eq!(decoded, original);
|
||||
prop_assert!(r.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
||||
45
vendor/smb2/src/pack/CLAUDE.md
vendored
Normal file
45
vendor/smb2/src/pack/CLAUDE.md
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
# Pack -- binary serialization primitives
|
||||
|
||||
Cursor-based binary reader/writer for SMB2 wire format. Hand-rolled, no proc macros.
|
||||
|
||||
## Key files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `mod.rs` | `ReadCursor`, `WriteCursor`, `Pack`/`Unpack` traits, primitive read/write methods |
|
||||
| `guid.rs` | GUID pack/unpack with mixed-endian layout |
|
||||
| `filetime.rs` | Windows FILETIME (100ns ticks since 1601-01-01) to/from `SystemTime` |
|
||||
|
||||
## Core types
|
||||
|
||||
- **`ReadCursor<'a>`**: Reads from `&[u8]` with position tracking. Returns `Error` on buffer overrun (no panics). All reads are little-endian.
|
||||
- **`WriteCursor`**: Writes into a growable `Vec<u8>`. Supports backpatching (`set_u16_le_at`, `set_u32_le_at`) for length fields written before their values are known. `align_to(n)` pads with zeros to n-byte boundary.
|
||||
- **`Pack` trait**: `fn pack(&self, cursor: &mut WriteCursor)` -- serialize to binary.
|
||||
- **`Unpack` trait**: `fn unpack(cursor: &mut ReadCursor) -> Result<Self>` -- deserialize from binary.
|
||||
|
||||
## GUID mixed-endian layout
|
||||
|
||||
Windows GUIDs have a mixed-endian wire format:
|
||||
- `data1` (u32): little-endian
|
||||
- `data2` (u16): little-endian
|
||||
- `data3` (u16): little-endian
|
||||
- `data4` ([u8; 8]): raw bytes (no endian conversion)
|
||||
|
||||
This matches the COM/DCOM convention. Not the same as RFC 4122 UUID byte order.
|
||||
|
||||
## FileTime conversion
|
||||
|
||||
Windows FILETIME: 100-nanosecond intervals since 1601-01-01 00:00:00 UTC.
|
||||
Unix epoch: 1970-01-01 00:00:00 UTC.
|
||||
Offset: 11,644,473,600 seconds (116,444,736,000,000,000 ticks).
|
||||
|
||||
## Key decisions
|
||||
|
||||
- **Hand-rolled instead of proc macros**: Full control over wire format details (offsets, alignment, backpatching). Easier to debug. No build-time dependency.
|
||||
- **`MAX_UNPACK_BUFFER` (16 MB)**: `read_bytes_bounded` refuses allocations larger than 16 MB. Prevents OOM from malicious packets claiming huge lengths.
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **Everything is little-endian**: Except TCP framing (see transport module). ReadCursor/WriteCursor only do LE.
|
||||
- **UTF-16LE byte length must be even**: `read_utf16_le` returns an error on odd byte counts.
|
||||
- **Backpatching requires placeholder**: Write a zero first, then `set_u32_le_at` to overwrite once the real value is known. Common pattern for length-prefixed fields.
|
||||
175
vendor/smb2/src/pack/filetime.rs
vendored
Normal file
175
vendor/smb2/src/pack/filetime.rs
vendored
Normal file
@@ -0,0 +1,175 @@
|
||||
//! Windows FILETIME type for SMB2.
|
||||
//!
|
||||
//! A FILETIME is a 64-bit value representing 100-nanosecond intervals
|
||||
//! since 1601-01-01 00:00:00 UTC.
|
||||
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::error::Result;
|
||||
|
||||
/// Difference between the Windows epoch (1601-01-01) and Unix epoch (1970-01-01)
|
||||
/// in 100-nanosecond intervals.
|
||||
const EPOCH_DIFF_100NS: u64 = 116_444_736_000_000_000;
|
||||
|
||||
/// Windows FILETIME: 100-nanosecond intervals since 1601-01-01 00:00:00 UTC.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct FileTime(
|
||||
/// The raw 100-nanosecond tick count.
|
||||
pub u64,
|
||||
);
|
||||
|
||||
impl FileTime {
|
||||
/// A zero filetime, meaning "not set" or "unknown".
|
||||
pub const ZERO: Self = Self(0);
|
||||
|
||||
/// Convert a [`SystemTime`] to a `FileTime`.
|
||||
///
|
||||
/// Uses the Unix epoch offset (116,444,736,000,000,000 intervals of
|
||||
/// 100 ns) to translate between the two epoch origins.
|
||||
pub fn from_system_time(t: SystemTime) -> Self {
|
||||
match t.duration_since(UNIX_EPOCH) {
|
||||
Ok(dur) => {
|
||||
let intervals = dur.as_nanos() / 100;
|
||||
Self(intervals as u64 + EPOCH_DIFF_100NS)
|
||||
}
|
||||
Err(e) => {
|
||||
// Time is before Unix epoch. The duration tells us how far before.
|
||||
let before = e.duration();
|
||||
let intervals = before.as_nanos() / 100;
|
||||
// If the pre-Unix time is still after the Windows epoch, compute it.
|
||||
Self(EPOCH_DIFF_100NS.saturating_sub(intervals as u64))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert this `FileTime` to a [`SystemTime`].
|
||||
///
|
||||
/// Returns `None` if the filetime represents a date before the Unix epoch,
|
||||
/// since [`SystemTime`] cannot represent dates before that.
|
||||
pub fn to_system_time(self) -> Option<SystemTime> {
|
||||
if self.0 < EPOCH_DIFF_100NS {
|
||||
return None;
|
||||
}
|
||||
let intervals_since_unix = self.0 - EPOCH_DIFF_100NS;
|
||||
let nanos = (intervals_since_unix as u128) * 100;
|
||||
let dur = Duration::new(
|
||||
(nanos / 1_000_000_000) as u64,
|
||||
(nanos % 1_000_000_000) as u32,
|
||||
);
|
||||
Some(UNIX_EPOCH + dur)
|
||||
}
|
||||
}
|
||||
|
||||
impl Pack for FileTime {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u64_le(self.0);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for FileTime {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let val = cursor.read_u64_le()?;
|
||||
Ok(Self(val))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn zero_filetime() {
|
||||
assert_eq!(FileTime::ZERO, FileTime(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pack_zero() {
|
||||
let mut w = WriteCursor::new();
|
||||
FileTime::ZERO.pack(&mut w);
|
||||
assert_eq!(w.as_bytes(), &[0u8; 8]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unpack_zero() {
|
||||
let bytes = [0u8; 8];
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let ft = FileTime::unpack(&mut r).unwrap();
|
||||
assert_eq!(ft, FileTime::ZERO);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn known_value_2024_01_01() {
|
||||
// 2024-01-01 00:00:00 UTC = FileTime(133_485_408_000_000_000)
|
||||
// (Unix timestamp 1_704_067_200 * 10_000_000 + 116_444_736_000_000_000)
|
||||
let expected_raw: u64 = 133_485_408_000_000_000;
|
||||
let ft = FileTime(expected_raw);
|
||||
|
||||
// Pack and verify roundtrip
|
||||
let mut w = WriteCursor::new();
|
||||
ft.pack(&mut w);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
let unpacked = FileTime::unpack(&mut r).unwrap();
|
||||
assert_eq!(unpacked, ft);
|
||||
|
||||
// Verify SystemTime conversion
|
||||
// 2024-01-01 00:00:00 UTC = Unix timestamp 1_704_067_200
|
||||
let st = ft.to_system_time().unwrap();
|
||||
let unix_dur = st.duration_since(UNIX_EPOCH).unwrap();
|
||||
assert_eq!(unix_dur.as_secs(), 1_704_067_200);
|
||||
assert_eq!(unix_dur.subsec_nanos(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_system_time_roundtrip() {
|
||||
// Use a known Unix timestamp: 2024-01-01 00:00:00 UTC
|
||||
let unix_secs = 1_704_067_200u64;
|
||||
let st = UNIX_EPOCH + Duration::from_secs(unix_secs);
|
||||
let ft = FileTime::from_system_time(st);
|
||||
assert_eq!(ft.0, 133_485_408_000_000_000);
|
||||
|
||||
let st2 = ft.to_system_time().unwrap();
|
||||
let dur = st2.duration_since(UNIX_EPOCH).unwrap();
|
||||
assert_eq!(dur.as_secs(), unix_secs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pre_unix_epoch_returns_none() {
|
||||
// A FILETIME value that represents a date before 1970-01-01
|
||||
let ft = FileTime(EPOCH_DIFF_100NS - 1);
|
||||
assert!(ft.to_system_time().is_none());
|
||||
|
||||
// Zero is also before Unix epoch
|
||||
assert!(FileTime::ZERO.to_system_time().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unix_epoch_exactly() {
|
||||
let ft = FileTime(EPOCH_DIFF_100NS);
|
||||
let st = ft.to_system_time().unwrap();
|
||||
assert_eq!(st, UNIX_EPOCH);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_system_time_unix_epoch() {
|
||||
let ft = FileTime::from_system_time(UNIX_EPOCH);
|
||||
assert_eq!(ft.0, EPOCH_DIFF_100NS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pack_unpack_roundtrip() {
|
||||
let ft = FileTime(133_476_576_000_000_000);
|
||||
let mut w = WriteCursor::new();
|
||||
ft.pack(&mut w);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
let unpacked = FileTime::unpack(&mut r).unwrap();
|
||||
assert_eq!(unpacked, ft);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unpack_insufficient_bytes() {
|
||||
let bytes = [0u8; 4]; // need 8
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
assert!(FileTime::unpack(&mut r).is_err());
|
||||
}
|
||||
}
|
||||
176
vendor/smb2/src/pack/guid.rs
vendored
Normal file
176
vendor/smb2/src/pack/guid.rs
vendored
Normal file
@@ -0,0 +1,176 @@
|
||||
//! GUID (Globally Unique Identifier) type for SMB2.
|
||||
//!
|
||||
//! GUIDs follow the mixed-endian layout defined in MS-DTYP section 2.3.4:
|
||||
//! - Bytes 0-3: `data1` (`u32`, little-endian)
|
||||
//! - Bytes 4-5: `data2` (`u16`, little-endian)
|
||||
//! - Bytes 6-7: `data3` (`u16`, little-endian)
|
||||
//! - Bytes 8-15: `data4` (8 raw bytes, big-endian order)
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use super::{Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::error::Result;
|
||||
|
||||
/// A 128-bit GUID in mixed-endian wire format (MS-DTYP 2.3.4).
|
||||
///
|
||||
/// With the `serde` feature on, the JSON form mirrors the in-memory
|
||||
/// field shape (`{data1, data2, data3, data4}`), **not** the wire byte
|
||||
/// order — the wire layout is mixed-endian and round-tripping it through
|
||||
/// JSON would just be confusing.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
pub struct Guid {
|
||||
/// First component (bytes 0-3, little-endian on wire).
|
||||
pub data1: u32,
|
||||
/// Second component (bytes 4-5, little-endian on wire).
|
||||
pub data2: u16,
|
||||
/// Third component (bytes 6-7, little-endian on wire).
|
||||
pub data3: u16,
|
||||
/// Fourth component (bytes 8-15, raw byte order on wire).
|
||||
pub data4: [u8; 8],
|
||||
}
|
||||
|
||||
impl Guid {
|
||||
/// The NULL GUID: `{00000000-0000-0000-0000-000000000000}`.
|
||||
pub const ZERO: Self = Self {
|
||||
data1: 0,
|
||||
data2: 0,
|
||||
data3: 0,
|
||||
data4: [0; 8],
|
||||
};
|
||||
}
|
||||
|
||||
impl Pack for Guid {
|
||||
fn pack(&self, cursor: &mut WriteCursor) {
|
||||
cursor.write_u32_le(self.data1);
|
||||
cursor.write_u16_le(self.data2);
|
||||
cursor.write_u16_le(self.data3);
|
||||
cursor.write_bytes(&self.data4);
|
||||
}
|
||||
}
|
||||
|
||||
impl Unpack for Guid {
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self> {
|
||||
let data1 = cursor.read_u32_le()?;
|
||||
let data2 = cursor.read_u16_le()?;
|
||||
let data3 = cursor.read_u16_le()?;
|
||||
let raw = cursor.read_bytes(8)?;
|
||||
let mut data4 = [0u8; 8];
|
||||
data4.copy_from_slice(raw);
|
||||
Ok(Self {
|
||||
data1,
|
||||
data2,
|
||||
data3,
|
||||
data4,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Guid {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{{{:08x}-{:04x}-{:04x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}}}",
|
||||
self.data1,
|
||||
self.data2,
|
||||
self.data3,
|
||||
self.data4[0],
|
||||
self.data4[1],
|
||||
self.data4[2],
|
||||
self.data4[3],
|
||||
self.data4[4],
|
||||
self.data4[5],
|
||||
self.data4[6],
|
||||
self.data4[7],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn unpack_null_guid() {
|
||||
let bytes = [0u8; 16];
|
||||
let mut cursor = ReadCursor::new(&bytes);
|
||||
let guid = Guid::unpack(&mut cursor).unwrap();
|
||||
assert_eq!(guid, Guid::ZERO);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pack_null_guid() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
Guid::ZERO.pack(&mut cursor);
|
||||
assert_eq!(cursor.as_bytes(), &[0u8; 16]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_known_guid() {
|
||||
let guid = Guid {
|
||||
data1: 0x6BA7B810,
|
||||
data2: 0x9DAD,
|
||||
data3: 0x11D1,
|
||||
data4: [0x80, 0xB4, 0x00, 0xC0, 0x4F, 0xD4, 0x30, 0xC8],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
guid.pack(&mut w);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
let unpacked = Guid::unpack(&mut r).unwrap();
|
||||
assert_eq!(unpacked, guid);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_format() {
|
||||
let guid = Guid {
|
||||
data1: 0x6BA7B810,
|
||||
data2: 0x9DAD,
|
||||
data3: 0x11D1,
|
||||
data4: [0x80, 0xB4, 0x00, 0xC0, 0x4F, 0xD4, 0x30, 0xC8],
|
||||
};
|
||||
assert_eq!(guid.to_string(), "{6ba7b810-9dad-11d1-80b4-00c04fd430c8}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_null_guid() {
|
||||
assert_eq!(
|
||||
Guid::ZERO.to_string(),
|
||||
"{00000000-0000-0000-0000-000000000000}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mixed_endian_byte_ordering() {
|
||||
// Build a GUID with known values and verify the wire bytes directly.
|
||||
let guid = Guid {
|
||||
data1: 0x04030201,
|
||||
data2: 0x0605,
|
||||
data3: 0x0807,
|
||||
data4: [0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10],
|
||||
};
|
||||
|
||||
let mut w = WriteCursor::new();
|
||||
guid.pack(&mut w);
|
||||
let bytes = w.as_bytes();
|
||||
|
||||
// data1: u32 LE -> 01 02 03 04
|
||||
assert_eq!(&bytes[0..4], &[0x01, 0x02, 0x03, 0x04]);
|
||||
// data2: u16 LE -> 05 06
|
||||
assert_eq!(&bytes[4..6], &[0x05, 0x06]);
|
||||
// data3: u16 LE -> 07 08
|
||||
assert_eq!(&bytes[6..8], &[0x07, 0x08]);
|
||||
// data4: raw bytes -> 09 0A 0B 0C 0D 0E 0F 10
|
||||
assert_eq!(
|
||||
&bytes[8..16],
|
||||
&[0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unpack_insufficient_bytes() {
|
||||
let bytes = [0u8; 10]; // need 16
|
||||
let mut cursor = ReadCursor::new(&bytes);
|
||||
assert!(Guid::unpack(&mut cursor).is_err());
|
||||
}
|
||||
}
|
||||
649
vendor/smb2/src/pack/mod.rs
vendored
Normal file
649
vendor/smb2/src/pack/mod.rs
vendored
Normal file
@@ -0,0 +1,649 @@
|
||||
//! Binary serialization/deserialization primitives for SMB2.
|
||||
//!
|
||||
//! Provides [`ReadCursor`] and [`WriteCursor`] for reading and writing
|
||||
//! little-endian binary data, plus [`Pack`] and [`Unpack`] traits for
|
||||
//! structured types.
|
||||
//!
|
||||
//! Most users don't need this module directly -- use [`SmbClient`](crate::SmbClient)
|
||||
//! for high-level file operations.
|
||||
|
||||
pub mod filetime;
|
||||
pub mod guid;
|
||||
|
||||
pub use filetime::FileTime;
|
||||
pub use guid::Guid;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::Error;
|
||||
|
||||
/// Trait for types that can serialize themselves into binary format.
|
||||
pub trait Pack: Send + Sync {
|
||||
/// Write this value into the cursor.
|
||||
fn pack(&self, cursor: &mut WriteCursor);
|
||||
}
|
||||
|
||||
/// Trait for types that can deserialize themselves from binary format.
|
||||
pub trait Unpack: Sized {
|
||||
/// Read a value from the cursor, advancing its position.
|
||||
fn unpack(cursor: &mut ReadCursor<'_>) -> Result<Self>;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ReadCursor
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A cursor for reading little-endian binary data from a byte slice.
|
||||
///
|
||||
/// Tracks the current read position and returns errors on buffer overruns
|
||||
/// rather than panicking.
|
||||
pub struct ReadCursor<'a> {
|
||||
data: &'a [u8],
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
impl<'a> ReadCursor<'a> {
|
||||
/// Create a new read cursor starting at position 0.
|
||||
pub fn new(data: &'a [u8]) -> Self {
|
||||
Self { data, pos: 0 }
|
||||
}
|
||||
|
||||
/// Read a single byte.
|
||||
pub fn read_u8(&mut self) -> Result<u8> {
|
||||
self.ensure(1)?;
|
||||
let val = self.data[self.pos];
|
||||
self.pos += 1;
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
/// Read a little-endian `u16`.
|
||||
pub fn read_u16_le(&mut self) -> Result<u16> {
|
||||
let bytes = self.read_array::<2>()?;
|
||||
Ok(u16::from_le_bytes(bytes))
|
||||
}
|
||||
|
||||
/// Read a little-endian `u32`.
|
||||
pub fn read_u32_le(&mut self) -> Result<u32> {
|
||||
let bytes = self.read_array::<4>()?;
|
||||
Ok(u32::from_le_bytes(bytes))
|
||||
}
|
||||
|
||||
/// Read a little-endian `u64`.
|
||||
pub fn read_u64_le(&mut self) -> Result<u64> {
|
||||
let bytes = self.read_array::<8>()?;
|
||||
Ok(u64::from_le_bytes(bytes))
|
||||
}
|
||||
|
||||
/// Read a little-endian `u128`.
|
||||
pub fn read_u128_le(&mut self) -> Result<u128> {
|
||||
let bytes = self.read_array::<16>()?;
|
||||
Ok(u128::from_le_bytes(bytes))
|
||||
}
|
||||
|
||||
/// Read exactly `n` bytes, returning a sub-slice.
|
||||
pub fn read_bytes(&mut self, n: usize) -> Result<&'a [u8]> {
|
||||
self.ensure(n)?;
|
||||
let slice = &self.data[self.pos..self.pos + n];
|
||||
self.pos += n;
|
||||
Ok(slice)
|
||||
}
|
||||
|
||||
/// Read `byte_len` bytes of UTF-16LE data and decode to a [`String`].
|
||||
///
|
||||
/// `byte_len` must be even (each code unit is 2 bytes).
|
||||
pub fn read_utf16_le(&mut self, byte_len: usize) -> Result<String> {
|
||||
if byte_len % 2 != 0 {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"UTF-16LE byte length must be even, got {}",
|
||||
byte_len
|
||||
)));
|
||||
}
|
||||
let raw = self.read_bytes(byte_len)?;
|
||||
let code_units: Vec<u16> = raw
|
||||
.chunks_exact(2)
|
||||
.map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
|
||||
.collect();
|
||||
String::from_utf16(&code_units)
|
||||
.map_err(|_| Error::invalid_data("invalid UTF-16LE encoding"))
|
||||
}
|
||||
|
||||
/// Skip `n` bytes without reading them.
|
||||
pub fn skip(&mut self, n: usize) -> Result<()> {
|
||||
self.ensure(n)?;
|
||||
self.pos += n;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Return the number of bytes remaining.
|
||||
pub fn remaining(&self) -> usize {
|
||||
self.data.len() - self.pos
|
||||
}
|
||||
|
||||
/// Return the current byte position.
|
||||
pub fn position(&self) -> usize {
|
||||
self.pos
|
||||
}
|
||||
|
||||
/// Return `true` if no bytes remain.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.remaining() == 0
|
||||
}
|
||||
|
||||
/// Maximum buffer size we'll allocate from untrusted data (16 MB).
|
||||
pub const MAX_UNPACK_BUFFER: usize = 16 * 1024 * 1024;
|
||||
|
||||
/// Read `n` bytes, but refuse if `n` exceeds [`Self::MAX_UNPACK_BUFFER`].
|
||||
pub fn read_bytes_bounded(&mut self, n: usize) -> Result<&'a [u8]> {
|
||||
if n > Self::MAX_UNPACK_BUFFER {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"buffer size {} exceeds maximum {} bytes",
|
||||
n,
|
||||
Self::MAX_UNPACK_BUFFER
|
||||
)));
|
||||
}
|
||||
self.read_bytes(n)
|
||||
}
|
||||
|
||||
// -- private helpers --
|
||||
|
||||
fn ensure(&self, n: usize) -> Result<()> {
|
||||
if self.remaining() < n {
|
||||
Err(Error::invalid_data(format!(
|
||||
"need {} bytes but only {} remain at offset {}",
|
||||
n,
|
||||
self.remaining(),
|
||||
self.pos
|
||||
)))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn read_array<const N: usize>(&mut self) -> Result<[u8; N]> {
|
||||
self.ensure(N)?;
|
||||
let mut arr = [0u8; N];
|
||||
arr.copy_from_slice(&self.data[self.pos..self.pos + N]);
|
||||
self.pos += N;
|
||||
Ok(arr)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WriteCursor
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A cursor for writing little-endian binary data into a growable buffer.
|
||||
pub struct WriteCursor {
|
||||
buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl WriteCursor {
|
||||
/// Create an empty write cursor.
|
||||
pub fn new() -> Self {
|
||||
Self { buf: Vec::new() }
|
||||
}
|
||||
|
||||
/// Create a write cursor with pre-allocated capacity.
|
||||
pub fn with_capacity(cap: usize) -> Self {
|
||||
Self {
|
||||
buf: Vec::with_capacity(cap),
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a single byte.
|
||||
pub fn write_u8(&mut self, val: u8) {
|
||||
self.buf.push(val);
|
||||
}
|
||||
|
||||
/// Write a little-endian `u16`.
|
||||
pub fn write_u16_le(&mut self, val: u16) {
|
||||
self.buf.extend_from_slice(&val.to_le_bytes());
|
||||
}
|
||||
|
||||
/// Write a little-endian `u32`.
|
||||
pub fn write_u32_le(&mut self, val: u32) {
|
||||
self.buf.extend_from_slice(&val.to_le_bytes());
|
||||
}
|
||||
|
||||
/// Write a little-endian `u64`.
|
||||
pub fn write_u64_le(&mut self, val: u64) {
|
||||
self.buf.extend_from_slice(&val.to_le_bytes());
|
||||
}
|
||||
|
||||
/// Write a little-endian `u128`.
|
||||
pub fn write_u128_le(&mut self, val: u128) {
|
||||
self.buf.extend_from_slice(&val.to_le_bytes());
|
||||
}
|
||||
|
||||
/// Write a raw byte slice.
|
||||
pub fn write_bytes(&mut self, data: &[u8]) {
|
||||
self.buf.extend_from_slice(data);
|
||||
}
|
||||
|
||||
/// Encode a string as UTF-16LE and write the bytes.
|
||||
pub fn write_utf16_le(&mut self, s: &str) {
|
||||
for code_unit in s.encode_utf16() {
|
||||
self.buf.extend_from_slice(&code_unit.to_le_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
/// Write `n` zero bytes.
|
||||
pub fn write_zeros(&mut self, n: usize) {
|
||||
self.buf.resize(self.buf.len() + n, 0);
|
||||
}
|
||||
|
||||
/// Pad with zero bytes until the position is a multiple of `alignment`.
|
||||
///
|
||||
/// Does nothing if `alignment` is 0 or 1, or if already aligned.
|
||||
pub fn align_to(&mut self, alignment: usize) {
|
||||
if alignment <= 1 {
|
||||
return;
|
||||
}
|
||||
let remainder = self.buf.len() % alignment;
|
||||
if remainder != 0 {
|
||||
self.write_zeros(alignment - remainder);
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the current write position (number of bytes written so far).
|
||||
pub fn position(&self) -> usize {
|
||||
self.buf.len()
|
||||
}
|
||||
|
||||
/// Overwrite a `u16` at a previous position (little-endian).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `pos + 2 > self.position()`.
|
||||
pub fn set_u16_le_at(&mut self, pos: usize, val: u16) {
|
||||
self.buf[pos..pos + 2].copy_from_slice(&val.to_le_bytes());
|
||||
}
|
||||
|
||||
/// Overwrite a `u32` at a previous position (little-endian).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `pos + 4 > self.position()`.
|
||||
pub fn set_u32_le_at(&mut self, pos: usize, val: u32) {
|
||||
self.buf[pos..pos + 4].copy_from_slice(&val.to_le_bytes());
|
||||
}
|
||||
|
||||
/// Consume the cursor and return the underlying buffer.
|
||||
pub fn into_inner(self) -> Vec<u8> {
|
||||
self.buf
|
||||
}
|
||||
|
||||
/// Return a reference to the bytes written so far.
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
&self.buf
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WriteCursor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use proptest::prelude::*;
|
||||
|
||||
// -- ReadCursor tests --
|
||||
|
||||
#[test]
|
||||
fn read_u8_from_known_bytes() {
|
||||
let data = [0x42];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
assert_eq!(cursor.read_u8().unwrap(), 0x42);
|
||||
assert!(cursor.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u16_le_from_known_bytes() {
|
||||
let data = [0x34, 0x12];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
assert_eq!(cursor.read_u16_le().unwrap(), 0x1234);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u32_le_from_known_bytes() {
|
||||
let data = [0x78, 0x56, 0x34, 0x12];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
assert_eq!(cursor.read_u32_le().unwrap(), 0x12345678);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u64_le_from_known_bytes() {
|
||||
let data = [0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
assert_eq!(cursor.read_u64_le().unwrap(), 0x0102030405060708);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_u128_le_from_known_bytes() {
|
||||
let mut data = [0u8; 16];
|
||||
data[0] = 0x01;
|
||||
data[15] = 0x80;
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
let val = cursor.read_u128_le().unwrap();
|
||||
assert_eq!(val, 0x80000000_00000000_00000000_00000001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_past_end_returns_error() {
|
||||
let data = [0x00];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
assert!(cursor.read_u16_le().is_err());
|
||||
|
||||
let empty: &[u8] = &[];
|
||||
let mut cursor = ReadCursor::new(empty);
|
||||
assert!(cursor.read_u8().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remaining_and_position_track_correctly() {
|
||||
let data = [0x01, 0x02, 0x03, 0x04, 0x05];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
assert_eq!(cursor.position(), 0);
|
||||
assert_eq!(cursor.remaining(), 5);
|
||||
|
||||
cursor.read_u8().unwrap();
|
||||
assert_eq!(cursor.position(), 1);
|
||||
assert_eq!(cursor.remaining(), 4);
|
||||
|
||||
cursor.read_u16_le().unwrap();
|
||||
assert_eq!(cursor.position(), 3);
|
||||
assert_eq!(cursor.remaining(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skip_advances_position() {
|
||||
let data = [0x01, 0x02, 0x03, 0x04];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
cursor.skip(2).unwrap();
|
||||
assert_eq!(cursor.position(), 2);
|
||||
assert_eq!(cursor.read_u8().unwrap(), 0x03);
|
||||
|
||||
// Skip past end is error
|
||||
assert!(cursor.skip(10).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_bytes_returns_correct_slice() {
|
||||
let data = [0x0A, 0x0B, 0x0C, 0x0D];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
cursor.skip(1).unwrap();
|
||||
let slice = cursor.read_bytes(2).unwrap();
|
||||
assert_eq!(slice, &[0x0B, 0x0C]);
|
||||
assert_eq!(cursor.position(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_utf16_le_decodes_hello() {
|
||||
// "hello" in UTF-16LE
|
||||
let data = [0x68, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
let s = cursor.read_utf16_le(10).unwrap();
|
||||
assert_eq!(s, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_utf16_le_odd_byte_len_is_error() {
|
||||
let data = [0x68, 0x00, 0x65];
|
||||
let mut cursor = ReadCursor::new(&data);
|
||||
assert!(cursor.read_utf16_le(3).is_err());
|
||||
}
|
||||
|
||||
// -- WriteCursor tests --
|
||||
|
||||
#[test]
|
||||
fn write_u8_produces_correct_byte() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_u8(0xFF);
|
||||
assert_eq!(cursor.as_bytes(), &[0xFF]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_u16_le_produces_correct_bytes() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_u16_le(0x1234);
|
||||
assert_eq!(cursor.as_bytes(), &[0x34, 0x12]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_u32_le_produces_correct_bytes() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_u32_le(0x12345678);
|
||||
assert_eq!(cursor.as_bytes(), &[0x78, 0x56, 0x34, 0x12]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_u64_le_produces_correct_bytes() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_u64_le(0x0102030405060708);
|
||||
assert_eq!(
|
||||
cursor.as_bytes(),
|
||||
&[0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_u128_le_produces_correct_bytes() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_u128_le(0x01);
|
||||
let bytes = cursor.as_bytes();
|
||||
assert_eq!(bytes.len(), 16);
|
||||
assert_eq!(bytes[0], 0x01);
|
||||
assert!(bytes[1..].iter().all(|&b| b == 0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn align_to_pads_correctly() {
|
||||
// From position 0 -> already aligned
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.align_to(8);
|
||||
assert_eq!(cursor.position(), 0);
|
||||
|
||||
// From position 3 -> pad to 8
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_bytes(&[0x01, 0x02, 0x03]);
|
||||
cursor.align_to(8);
|
||||
assert_eq!(cursor.position(), 8);
|
||||
// Padding bytes should be zeros
|
||||
assert_eq!(&cursor.as_bytes()[3..8], &[0, 0, 0, 0, 0]);
|
||||
|
||||
// From position 8 -> already aligned
|
||||
cursor.align_to(8);
|
||||
assert_eq!(cursor.position(), 8);
|
||||
|
||||
// From position 1 -> pad to 4
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_u8(0xAA);
|
||||
cursor.align_to(4);
|
||||
assert_eq!(cursor.position(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_u32_le_at_backpatches_correctly() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_u32_le(0); // placeholder
|
||||
cursor.write_u32_le(0xDEADBEEF);
|
||||
cursor.set_u32_le_at(0, 0x12345678);
|
||||
assert_eq!(
|
||||
cursor.as_bytes(),
|
||||
&[0x78, 0x56, 0x34, 0x12, 0xEF, 0xBE, 0xAD, 0xDE]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_u16_le_at_backpatches_correctly() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_u16_le(0);
|
||||
cursor.write_u16_le(0xBEEF);
|
||||
cursor.set_u16_le_at(0, 0x1234);
|
||||
assert_eq!(cursor.as_bytes(), &[0x34, 0x12, 0xEF, 0xBE]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_utf16_le_encodes_correctly() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_utf16_le("hello");
|
||||
assert_eq!(
|
||||
cursor.as_bytes(),
|
||||
&[0x68, 0x00, 0x65, 0x00, 0x6C, 0x00, 0x6C, 0x00, 0x6F, 0x00]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_zeros_produces_correct_count() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_zeros(5);
|
||||
assert_eq!(cursor.as_bytes(), &[0, 0, 0, 0, 0]);
|
||||
assert_eq!(cursor.position(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn into_inner_returns_buffer() {
|
||||
let mut cursor = WriteCursor::new();
|
||||
cursor.write_u8(0x42);
|
||||
let buf = cursor.into_inner();
|
||||
assert_eq!(buf, vec![0x42]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn with_capacity_works() {
|
||||
let cursor = WriteCursor::with_capacity(1024);
|
||||
assert_eq!(cursor.position(), 0);
|
||||
}
|
||||
|
||||
// -- Roundtrip tests --
|
||||
|
||||
#[test]
|
||||
fn roundtrip_u8() {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u8(0xAB);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
assert_eq!(r.read_u8().unwrap(), 0xAB);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_u16() {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u16_le(0xCAFE);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
assert_eq!(r.read_u16_le().unwrap(), 0xCAFE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_u32() {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u32_le(0xDEADBEEF);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
assert_eq!(r.read_u32_le().unwrap(), 0xDEADBEEF);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_u64() {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u64_le(0x0102030405060708);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
assert_eq!(r.read_u64_le().unwrap(), 0x0102030405060708);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_u128() {
|
||||
let val: u128 = 0x0102030405060708090A0B0C0D0E0F10;
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u128_le(val);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
assert_eq!(r.read_u128_le().unwrap(), val);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_utf16_le() {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_utf16_le("Hello, world!");
|
||||
let bytes = w.into_inner();
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let s = r.read_utf16_le(bytes.len()).unwrap();
|
||||
assert_eq!(s, "Hello, world!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_utf16_le_emoji() {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_utf16_le("\u{1F600}");
|
||||
let bytes = w.into_inner();
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let s = r.read_utf16_le(bytes.len()).unwrap();
|
||||
assert_eq!(s, "\u{1F600}");
|
||||
}
|
||||
|
||||
// -- Property-based tests --
|
||||
|
||||
fn valid_utf16_string() -> impl Strategy<Value = String> {
|
||||
prop::collection::vec(
|
||||
prop::char::range('\u{0000}', '\u{D7FF}')
|
||||
.prop_union(prop::char::range('\u{E000}', '\u{FFFF}')),
|
||||
0..100,
|
||||
)
|
||||
.prop_map(|chars| chars.into_iter().collect())
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn prop_roundtrip_u8(val: u8) {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u8(val);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
prop_assert_eq!(r.read_u8().unwrap(), val);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prop_roundtrip_u16(val: u16) {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u16_le(val);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
prop_assert_eq!(r.read_u16_le().unwrap(), val);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prop_roundtrip_u32(val: u32) {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u32_le(val);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
prop_assert_eq!(r.read_u32_le().unwrap(), val);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prop_roundtrip_u64(val: u64) {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u64_le(val);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
prop_assert_eq!(r.read_u64_le().unwrap(), val);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prop_roundtrip_u128(val: u128) {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_u128_le(val);
|
||||
let mut r = ReadCursor::new(w.as_bytes());
|
||||
prop_assert_eq!(r.read_u128_le().unwrap(), val);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prop_roundtrip_utf16_le(s in valid_utf16_string()) {
|
||||
let mut w = WriteCursor::new();
|
||||
w.write_utf16_le(&s);
|
||||
let bytes = w.into_inner();
|
||||
let mut r = ReadCursor::new(&bytes);
|
||||
let decoded = r.read_utf16_le(bytes.len()).unwrap();
|
||||
prop_assert_eq!(decoded, s);
|
||||
}
|
||||
}
|
||||
}
|
||||
51
vendor/smb2/src/rpc/CLAUDE.md
vendored
Normal file
51
vendor/smb2/src/rpc/CLAUDE.md
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
# RPC -- named pipe RPC for share enumeration
|
||||
|
||||
DCE/RPC over SMB2 named pipes. Used to list shares on a server via the srvsvc interface.
|
||||
|
||||
## Key files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `mod.rs` | RPC PDU building/parsing: BIND, BIND_ACK, REQUEST, RESPONSE |
|
||||
| `srvsvc.rs` | NDR encoding for `NetShareEnumAll` (opnum 15), `ShareInfo` type |
|
||||
|
||||
## Protocol flow
|
||||
|
||||
1. Tree connect to `IPC$`
|
||||
2. CREATE `srvsvc` pipe (server prepends `\pipe\`)
|
||||
3. WRITE: RPC BIND (call_id=1, srvsvc UUID + NDR transfer syntax)
|
||||
4. READ: RPC BIND_ACK -- verify context accepted
|
||||
5. WRITE: RPC REQUEST (call_id=2, opnum=15, NDR-encoded NetShareEnumAll)
|
||||
6. READ: RPC RESPONSE -- NDR-decode share list
|
||||
7. CLOSE pipe
|
||||
8. Tree disconnect IPC$
|
||||
|
||||
Used by `client/shares.rs` which orchestrates the full flow via `SmbClient::list_shares()`.
|
||||
|
||||
## NDR encoding
|
||||
|
||||
`srvsvc.rs` handles NDR (Network Data Representation) encoding/decoding:
|
||||
- Conformant arrays: max_count prefix, then elements
|
||||
- Conformant varying strings: max_count + offset + actual_count + UTF-16LE data
|
||||
- Referent pointers: non-zero pointer ID, then deferred data
|
||||
- All 4-byte aligned
|
||||
|
||||
## Key decisions
|
||||
|
||||
- **call_id convention**: 1 for BIND, 2 for REQUEST. Arbitrary but consistent with smb-rs.
|
||||
- **Max fragment size 4280**: Default `MAX_XMIT_FRAG` / `MAX_RECV_FRAG`. Matches common implementations.
|
||||
|
||||
## Response reassembly (two independent layers)
|
||||
|
||||
A `NetShareEnum` reply can be split two different ways, and the client handles both. They compose: a fragment loop wrapping a buffer-overflow loop.
|
||||
|
||||
- **DCE/RPC fragmentation (MS-RPCE 2.2.2.6)**: a large response may arrive as several RESPONSE PDUs, each its own pipe message, with `PFC_LAST_FRAG` set only on the last. `parse_response_fragment` returns `(stub, is_last)`; `client/shares.rs` loops reading PDUs and concatenating stubs until `is_last`, then NDR-decodes the joined stub via `srvsvc::parse_net_share_enum_all_stub`. `parse_response` is the single-fragment convenience wrapper (`parse_response_fragment(..).map(|(s, _)| s)`).
|
||||
- **SMB pipe `STATUS_BUFFER_OVERFLOW` (MS-SMB2 3.3.5.10)**: a single pipe message larger than our 64 KiB read buffer comes back as overflow reads (partial data) terminated by a `SUCCESS` read. `client::shares::read_pipe_message` follows this, appending chunks until `SUCCESS`. The two phenomena are usually mutually exclusive in practice (fragments ≤ `MAX_RECV_FRAG` 4280 fit in one read; a server that ignores the frag cap sends one big PDU that overflows), but the code handles either or both.
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **Pipe name is `srvsvc`**: The server prepends `\pipe\` automatically. Don't include it in the CREATE request.
|
||||
- **Admin shares filtered out**: `list_shares` filters shares ending with `$` (IPC$, ADMIN$, C$). Only disk shares returned by default.
|
||||
- **RPC version is 5.0**: Connection-oriented RPC. `PFC_FIRST_FRAG | PFC_LAST_FRAG` together mark a complete single-fragment PDU; a cleared `PFC_LAST_FRAG` means more fragments follow (see reassembly above).
|
||||
- **NDR string alignment**: After each string, pad to 4-byte boundary. Missing alignment causes the server to reject the request silently.
|
||||
- **Don't gate pipe reads on `SUCCESS` only**: `STATUS_BUFFER_OVERFLOW` is a warning (partial data), not a failure. Use `NtStatus::is_success_or_partial` and read again, or you truncate/error on large replies from servers that chunk them. This previously made `list_shares` fail on servers whose listing exceeded one read or one fragment.
|
||||
549
vendor/smb2/src/rpc/mod.rs
vendored
Normal file
549
vendor/smb2/src/rpc/mod.rs
vendored
Normal file
@@ -0,0 +1,549 @@
|
||||
//! Named pipe RPC (MS-RPCE / NDR) for share enumeration.
|
||||
//!
|
||||
//! This module encodes and decodes DCE/RPC PDUs used over SMB2 named pipes.
|
||||
//! The exchange for share enumeration is:
|
||||
//!
|
||||
//! 1. Open `\pipe\srvsvc` via CREATE
|
||||
//! 2. Send RPC BIND request (type 11)
|
||||
//! 3. Receive RPC BIND_ACK response (type 12)
|
||||
//! 4. Send RPC REQUEST with NetShareEnumAll (type 0, opnum 15)
|
||||
//! 5. Receive RPC RESPONSE with results (type 2)
|
||||
//! 6. CLOSE the pipe
|
||||
//!
|
||||
//! Most users don't need this module directly -- use
|
||||
//! [`SmbClient::list_shares`](crate::SmbClient::list_shares) instead.
|
||||
//! The [`ShareInfo`](crate::ShareInfo) type is re-exported at the crate root.
|
||||
|
||||
pub mod srvsvc;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::guid::Guid;
|
||||
use crate::pack::{Pack, ReadCursor, WriteCursor};
|
||||
use crate::Error;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// RPC version 5.0 (connection-oriented).
|
||||
const RPC_VERSION_MAJOR: u8 = 5;
|
||||
/// RPC minor version.
|
||||
const RPC_VERSION_MINOR: u8 = 0;
|
||||
|
||||
/// Data representation: little-endian, ASCII character set, IEEE floating point.
|
||||
const DATA_REP: [u8; 4] = [0x10, 0x00, 0x00, 0x00];
|
||||
|
||||
/// RPC PDU type: REQUEST.
|
||||
const PDU_TYPE_REQUEST: u8 = 0;
|
||||
/// RPC PDU type: RESPONSE.
|
||||
const PDU_TYPE_RESPONSE: u8 = 2;
|
||||
/// RPC PDU type: BIND.
|
||||
const PDU_TYPE_BIND: u8 = 11;
|
||||
/// RPC PDU type: BIND_ACK.
|
||||
const PDU_TYPE_BIND_ACK: u8 = 12;
|
||||
|
||||
/// Default maximum transmit fragment size.
|
||||
const MAX_XMIT_FRAG: u16 = 4280;
|
||||
/// Default maximum receive fragment size.
|
||||
const MAX_RECV_FRAG: u16 = 4280;
|
||||
|
||||
/// PFC flags: first fragment.
|
||||
const PFC_FIRST_FRAG: u8 = 0x01;
|
||||
/// PFC flags: last fragment.
|
||||
const PFC_LAST_FRAG: u8 = 0x02;
|
||||
|
||||
/// srvsvc abstract syntax UUID: `4B324FC8-1670-01D3-1278-5A47BF6EE188`.
|
||||
const SRVSVC_UUID: Guid = Guid {
|
||||
data1: 0x4B324FC8,
|
||||
data2: 0x1670,
|
||||
data3: 0x01D3,
|
||||
data4: [0x12, 0x78, 0x5A, 0x47, 0xBF, 0x6E, 0xE1, 0x88],
|
||||
};
|
||||
/// srvsvc abstract syntax version.
|
||||
const SRVSVC_VERSION: u32 = 3;
|
||||
|
||||
/// NDR transfer syntax UUID: `8A885D04-1CEB-11C9-9FE8-08002B104860`.
|
||||
const NDR_UUID: Guid = Guid {
|
||||
data1: 0x8A885D04,
|
||||
data2: 0x1CEB,
|
||||
data3: 0x11C9,
|
||||
data4: [0x9F, 0xE8, 0x08, 0x00, 0x2B, 0x10, 0x48, 0x60],
|
||||
};
|
||||
/// NDR transfer syntax version.
|
||||
const NDR_VERSION: u32 = 2;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RPC PDU common header size
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Size of the RPC PDU common header (16 bytes).
|
||||
const RPC_HEADER_SIZE: usize = 16;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Build functions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Build an RPC BIND request for the srvsvc interface.
|
||||
///
|
||||
/// The BIND PDU negotiates the presentation context, binding the srvsvc
|
||||
/// abstract syntax with the NDR transfer syntax.
|
||||
pub fn build_srvsvc_bind(call_id: u32) -> Vec<u8> {
|
||||
let mut w = WriteCursor::with_capacity(72);
|
||||
|
||||
// Common header (16 bytes) -- FragLength will be backpatched
|
||||
w.write_u8(RPC_VERSION_MAJOR);
|
||||
w.write_u8(RPC_VERSION_MINOR);
|
||||
w.write_u8(PDU_TYPE_BIND);
|
||||
w.write_u8(PFC_FIRST_FRAG | PFC_LAST_FRAG);
|
||||
w.write_bytes(&DATA_REP);
|
||||
let frag_len_pos = w.position();
|
||||
w.write_u16_le(0); // FragLength placeholder
|
||||
w.write_u16_le(0); // AuthLength
|
||||
w.write_u32_le(call_id);
|
||||
|
||||
// BIND-specific fields
|
||||
w.write_u16_le(MAX_XMIT_FRAG);
|
||||
w.write_u16_le(MAX_RECV_FRAG);
|
||||
w.write_u32_le(0); // AssocGroup
|
||||
|
||||
// Presentation context list
|
||||
w.write_u8(1); // NumCtxItems
|
||||
w.write_bytes(&[0, 0, 0]); // Reserved
|
||||
|
||||
// Context item 0
|
||||
w.write_u16_le(0); // ContextId
|
||||
w.write_u8(1); // NumTransferSyntaxes
|
||||
w.write_u8(0); // Reserved
|
||||
|
||||
// Abstract syntax: srvsvc
|
||||
SRVSVC_UUID.pack(&mut w);
|
||||
w.write_u32_le(SRVSVC_VERSION);
|
||||
|
||||
// Transfer syntax: NDR
|
||||
NDR_UUID.pack(&mut w);
|
||||
w.write_u32_le(NDR_VERSION);
|
||||
|
||||
// Backpatch FragLength
|
||||
let total_len = w.position();
|
||||
w.set_u16_le_at(frag_len_pos, total_len as u16);
|
||||
|
||||
w.into_inner()
|
||||
}
|
||||
|
||||
/// Parse an RPC BIND_ACK response.
|
||||
///
|
||||
/// Verifies that the server accepted the presentation context (result == 0).
|
||||
/// Returns `Ok(())` on success, or an error if the bind was rejected or
|
||||
/// the response is malformed.
|
||||
pub fn parse_bind_ack(data: &[u8]) -> Result<()> {
|
||||
let mut r = ReadCursor::new(data);
|
||||
|
||||
// Common header
|
||||
let version = r.read_u8()?;
|
||||
let version_minor = r.read_u8()?;
|
||||
if version != RPC_VERSION_MAJOR || version_minor != RPC_VERSION_MINOR {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"unexpected RPC version {version}.{version_minor}, expected 5.0"
|
||||
)));
|
||||
}
|
||||
|
||||
let ptype = r.read_u8()?;
|
||||
if ptype != PDU_TYPE_BIND_ACK {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"expected BIND_ACK (type 12), got type {ptype}"
|
||||
)));
|
||||
}
|
||||
|
||||
let _flags = r.read_u8()?;
|
||||
let _data_rep = r.read_bytes(4)?;
|
||||
let _frag_length = r.read_u16_le()?;
|
||||
let _auth_length = r.read_u16_le()?;
|
||||
let _call_id = r.read_u32_le()?;
|
||||
|
||||
// BIND_ACK specific fields
|
||||
let _max_xmit_frag = r.read_u16_le()?;
|
||||
let _max_recv_frag = r.read_u16_le()?;
|
||||
let _assoc_group = r.read_u32_le()?;
|
||||
|
||||
// Secondary address (variable length, padded to 4 bytes)
|
||||
let sec_addr_len = r.read_u16_le()?;
|
||||
r.skip(sec_addr_len as usize)?;
|
||||
// Align to 4 bytes after secondary address (the 2-byte length + string)
|
||||
let consumed = 2 + sec_addr_len as usize;
|
||||
let padding = (4 - (consumed % 4)) % 4;
|
||||
r.skip(padding)?;
|
||||
|
||||
// Result list
|
||||
let num_results = r.read_u8()?;
|
||||
r.skip(3)?; // Reserved
|
||||
|
||||
if num_results == 0 {
|
||||
return Err(Error::invalid_data("BIND_ACK has no context results"));
|
||||
}
|
||||
|
||||
// Check first result
|
||||
let result = r.read_u16_le()?;
|
||||
if result != 0 {
|
||||
let reason = r.read_u16_le()?;
|
||||
return Err(Error::invalid_data(format!(
|
||||
"BIND rejected: result={result}, reason={reason}"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build an RPC REQUEST PDU wrapping the given stub data.
|
||||
///
|
||||
/// The caller provides the NDR-encoded stub (the operation payload) and the
|
||||
/// operation number.
|
||||
pub fn build_request(call_id: u32, opnum: u16, stub_data: &[u8]) -> Vec<u8> {
|
||||
let mut w = WriteCursor::with_capacity(RPC_HEADER_SIZE + 8 + stub_data.len());
|
||||
|
||||
// Common header
|
||||
w.write_u8(RPC_VERSION_MAJOR);
|
||||
w.write_u8(RPC_VERSION_MINOR);
|
||||
w.write_u8(PDU_TYPE_REQUEST);
|
||||
w.write_u8(PFC_FIRST_FRAG | PFC_LAST_FRAG);
|
||||
w.write_bytes(&DATA_REP);
|
||||
let frag_len_pos = w.position();
|
||||
w.write_u16_le(0); // FragLength placeholder
|
||||
w.write_u16_le(0); // AuthLength
|
||||
w.write_u32_le(call_id);
|
||||
|
||||
// REQUEST specific fields
|
||||
w.write_u32_le(stub_data.len() as u32); // AllocHint
|
||||
w.write_u16_le(0); // ContextId
|
||||
w.write_u16_le(opnum);
|
||||
|
||||
// Stub data
|
||||
w.write_bytes(stub_data);
|
||||
|
||||
// Backpatch FragLength
|
||||
let total_len = w.position();
|
||||
w.set_u16_le_at(frag_len_pos, total_len as u16);
|
||||
|
||||
w.into_inner()
|
||||
}
|
||||
|
||||
/// Parse a single RPC RESPONSE PDU, returning its stub data and whether it is
|
||||
/// the final fragment (`PFC_LAST_FRAG` set).
|
||||
///
|
||||
/// DCE/RPC servers may split a large response across several fragment PDUs,
|
||||
/// clearing `PFC_LAST_FRAG` on every fragment but the last (MS-RPCE 2.2.2.6).
|
||||
/// Callers reassemble by concatenating each fragment's stub until `is_last` is
|
||||
/// `true`. See `client::shares` for the read-and-reassemble loop.
|
||||
pub fn parse_response_fragment(data: &[u8]) -> Result<(&[u8], bool)> {
|
||||
let mut r = ReadCursor::new(data);
|
||||
|
||||
// Common header
|
||||
let version = r.read_u8()?;
|
||||
let version_minor = r.read_u8()?;
|
||||
if version != RPC_VERSION_MAJOR || version_minor != RPC_VERSION_MINOR {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"unexpected RPC version {version}.{version_minor}, expected 5.0"
|
||||
)));
|
||||
}
|
||||
|
||||
let ptype = r.read_u8()?;
|
||||
if ptype != PDU_TYPE_RESPONSE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"expected RESPONSE (type 2), got type {ptype}"
|
||||
)));
|
||||
}
|
||||
|
||||
let flags = r.read_u8()?;
|
||||
let _data_rep = r.read_bytes(4)?;
|
||||
let frag_length = r.read_u16_le()? as usize;
|
||||
let _auth_length = r.read_u16_le()?;
|
||||
let _call_id = r.read_u32_le()?;
|
||||
|
||||
// RESPONSE specific fields
|
||||
let _alloc_hint = r.read_u32_le()?;
|
||||
let _context_id = r.read_u16_le()?;
|
||||
let _cancel_count = r.read_u8()?;
|
||||
let _reserved = r.read_u8()?;
|
||||
|
||||
// Stub data is the rest (up to frag_length).
|
||||
let header_consumed = r.position();
|
||||
if frag_length < header_consumed {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"RPC frag_length {frag_length} shorter than header {header_consumed}"
|
||||
)));
|
||||
}
|
||||
let stub_data = r.read_bytes(frag_length - header_consumed)?;
|
||||
|
||||
let is_last = flags & PFC_LAST_FRAG != 0;
|
||||
Ok((stub_data, is_last))
|
||||
}
|
||||
|
||||
/// Parse an RPC RESPONSE PDU, returning the stub data.
|
||||
///
|
||||
/// Validates the PDU header and extracts the embedded stub data for
|
||||
/// further NDR decoding. Assumes a single, complete fragment; for fragmented
|
||||
/// responses use [`parse_response_fragment`] and reassemble.
|
||||
pub fn parse_response(data: &[u8]) -> Result<&[u8]> {
|
||||
parse_response_fragment(data).map(|(stub, _is_last)| stub)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::pack::Unpack;
|
||||
|
||||
#[test]
|
||||
fn bind_request_has_correct_header() {
|
||||
let pdu = build_srvsvc_bind(1);
|
||||
|
||||
assert_eq!(pdu[0], RPC_VERSION_MAJOR, "version major");
|
||||
assert_eq!(pdu[1], RPC_VERSION_MINOR, "version minor");
|
||||
assert_eq!(pdu[2], PDU_TYPE_BIND, "packet type");
|
||||
assert_eq!(pdu[3], PFC_FIRST_FRAG | PFC_LAST_FRAG, "flags");
|
||||
|
||||
// Data representation
|
||||
assert_eq!(&pdu[4..8], &DATA_REP);
|
||||
|
||||
// FragLength should match actual PDU length
|
||||
let frag_len = u16::from_le_bytes([pdu[8], pdu[9]]);
|
||||
assert_eq!(frag_len as usize, pdu.len());
|
||||
|
||||
// AuthLength = 0
|
||||
let auth_len = u16::from_le_bytes([pdu[10], pdu[11]]);
|
||||
assert_eq!(auth_len, 0);
|
||||
|
||||
// CallId = 1
|
||||
let call_id = u32::from_le_bytes([pdu[12], pdu[13], pdu[14], pdu[15]]);
|
||||
assert_eq!(call_id, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bind_request_contains_srvsvc_uuid() {
|
||||
let pdu = build_srvsvc_bind(1);
|
||||
|
||||
// After common header (16) + MaxXmitFrag(2) + MaxRecvFrag(2) + AssocGroup(4) +
|
||||
// NumCtxItems(1) + Reserved(3) + ContextId(2) + NumTransferSyntaxes(1) + Reserved(1) = 32
|
||||
let uuid_offset = 32;
|
||||
|
||||
// Extract the abstract syntax UUID bytes
|
||||
let mut cursor = ReadCursor::new(&pdu[uuid_offset..]);
|
||||
let guid = Guid::unpack(&mut cursor).unwrap();
|
||||
assert_eq!(guid, SRVSVC_UUID);
|
||||
|
||||
let version = cursor.read_u32_le().unwrap();
|
||||
assert_eq!(version, SRVSVC_VERSION);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bind_request_contains_ndr_transfer_syntax() {
|
||||
let pdu = build_srvsvc_bind(1);
|
||||
|
||||
// Transfer syntax starts after abstract syntax (UUID=16 + version=4 = 20 bytes after uuid_offset)
|
||||
let transfer_offset = 32 + 20;
|
||||
|
||||
let mut cursor = ReadCursor::new(&pdu[transfer_offset..]);
|
||||
let guid = Guid::unpack(&mut cursor).unwrap();
|
||||
assert_eq!(guid, NDR_UUID);
|
||||
|
||||
let version = cursor.read_u32_le().unwrap();
|
||||
assert_eq!(version, NDR_VERSION);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bind_request_total_length() {
|
||||
let pdu = build_srvsvc_bind(1);
|
||||
// 16 (header) + 4 (max frags) + 4 (assoc) + 4 (ctx list header) +
|
||||
// 4 (ctx item header) + 20 (abstract) + 20 (transfer) = 72
|
||||
assert_eq!(pdu.len(), 72);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_valid_bind_ack() {
|
||||
let ack = build_test_bind_ack(0); // result = 0 = accepted
|
||||
assert!(parse_bind_ack(&ack).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_rejected_bind_ack() {
|
||||
let ack = build_test_bind_ack(2); // result = 2 = provider_rejection
|
||||
let err = parse_bind_ack(&ack).unwrap_err();
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("rejected"),
|
||||
"error should mention rejection: {msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bind_ack_wrong_version() {
|
||||
let mut ack = build_test_bind_ack(0);
|
||||
ack[0] = 4; // wrong version
|
||||
assert!(parse_bind_ack(&ack).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bind_ack_wrong_type() {
|
||||
let mut ack = build_test_bind_ack(0);
|
||||
ack[2] = PDU_TYPE_BIND; // wrong type
|
||||
assert!(parse_bind_ack(&ack).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_pdu_has_correct_opnum() {
|
||||
let stub = vec![0xAA, 0xBB, 0xCC];
|
||||
let pdu = build_request(1, 15, &stub);
|
||||
|
||||
// OpNum is at offset 22 (header=16 + AllocHint=4 + ContextId=2)
|
||||
let opnum = u16::from_le_bytes([pdu[22], pdu[23]]);
|
||||
assert_eq!(opnum, 15);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_pdu_has_correct_alloc_hint() {
|
||||
let stub = vec![0xAA, 0xBB, 0xCC];
|
||||
let pdu = build_request(1, 15, &stub);
|
||||
|
||||
let alloc_hint = u32::from_le_bytes([pdu[16], pdu[17], pdu[18], pdu[19]]);
|
||||
assert_eq!(alloc_hint, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_pdu_contains_stub_data() {
|
||||
let stub = vec![0xAA, 0xBB, 0xCC];
|
||||
let pdu = build_request(1, 15, &stub);
|
||||
|
||||
// Stub starts at offset 24 (header=16 + request fields=8)
|
||||
assert_eq!(&pdu[24..], &[0xAA, 0xBB, 0xCC]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_pdu_frag_length_matches() {
|
||||
let stub = vec![0xAA, 0xBB, 0xCC];
|
||||
let pdu = build_request(1, 15, &stub);
|
||||
|
||||
let frag_len = u16::from_le_bytes([pdu[8], pdu[9]]);
|
||||
assert_eq!(frag_len as usize, pdu.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_response_extracts_stub() {
|
||||
let stub = b"hello stub data";
|
||||
let response_pdu = build_test_response(1, stub);
|
||||
|
||||
let extracted = parse_response(&response_pdu).unwrap();
|
||||
assert_eq!(extracted, stub);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_response_wrong_version() {
|
||||
let mut pdu = build_test_response(1, b"data");
|
||||
pdu[0] = 4; // wrong version
|
||||
assert!(parse_response(&pdu).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_response_fragment_reports_last_flag() {
|
||||
// build_test_response sets PFC_FIRST_FRAG | PFC_LAST_FRAG.
|
||||
let pdu = build_test_response(1, b"stub");
|
||||
let (stub, is_last) = parse_response_fragment(&pdu).unwrap();
|
||||
assert_eq!(stub, b"stub");
|
||||
assert!(is_last, "FIRST|LAST PDU should be the last fragment");
|
||||
|
||||
// Clear PFC_LAST_FRAG in the flags byte: now it's a non-final fragment.
|
||||
let mut frag = pdu.clone();
|
||||
frag[3] &= !PFC_LAST_FRAG;
|
||||
let (stub, is_last) = parse_response_fragment(&frag).unwrap();
|
||||
assert_eq!(stub, b"stub");
|
||||
assert!(!is_last, "FIRST-only PDU should not be the last fragment");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_response_rejects_frag_length_shorter_than_header() {
|
||||
let mut pdu = build_test_response(1, b"data");
|
||||
// FragLength lives at offset 8 (u16 LE); set it below the 24-byte header.
|
||||
pdu[8] = 4;
|
||||
pdu[9] = 0;
|
||||
assert!(parse_response(&pdu).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_response_wrong_type() {
|
||||
let mut pdu = build_test_response(1, b"data");
|
||||
pdu[2] = PDU_TYPE_REQUEST; // wrong type
|
||||
assert!(parse_response(&pdu).is_err());
|
||||
}
|
||||
|
||||
// -- Test helpers --
|
||||
|
||||
/// Build a minimal BIND_ACK for testing.
|
||||
fn build_test_bind_ack(result: u16) -> Vec<u8> {
|
||||
let mut w = WriteCursor::with_capacity(64);
|
||||
|
||||
// Common header
|
||||
w.write_u8(RPC_VERSION_MAJOR);
|
||||
w.write_u8(RPC_VERSION_MINOR);
|
||||
w.write_u8(PDU_TYPE_BIND_ACK);
|
||||
w.write_u8(PFC_FIRST_FRAG | PFC_LAST_FRAG);
|
||||
w.write_bytes(&DATA_REP);
|
||||
let frag_len_pos = w.position();
|
||||
w.write_u16_le(0); // FragLength placeholder
|
||||
w.write_u16_le(0); // AuthLength
|
||||
w.write_u32_le(1); // CallId
|
||||
|
||||
// BIND_ACK specific
|
||||
w.write_u16_le(MAX_XMIT_FRAG);
|
||||
w.write_u16_le(MAX_RECV_FRAG);
|
||||
w.write_u32_le(0x12345); // AssocGroup
|
||||
|
||||
// Secondary address: "\pipe\srvsvc\0" (empty for simplicity -- use length 0)
|
||||
w.write_u16_le(0); // SecAddrLen = 0
|
||||
w.write_bytes(&[0, 0]); // Padding to 4-byte alignment
|
||||
|
||||
// Result list
|
||||
w.write_u8(1); // NumResults
|
||||
w.write_bytes(&[0, 0, 0]); // Reserved
|
||||
|
||||
// Result entry
|
||||
w.write_u16_le(result); // Result
|
||||
w.write_u16_le(0); // Reason
|
||||
// Transfer syntax (16 bytes UUID + 4 bytes version)
|
||||
NDR_UUID.pack(&mut w);
|
||||
w.write_u32_le(NDR_VERSION);
|
||||
|
||||
let total_len = w.position();
|
||||
w.set_u16_le_at(frag_len_pos, total_len as u16);
|
||||
|
||||
w.into_inner()
|
||||
}
|
||||
|
||||
/// Build a minimal RPC RESPONSE PDU wrapping the given stub data.
|
||||
fn build_test_response(call_id: u32, stub: &[u8]) -> Vec<u8> {
|
||||
let mut w = WriteCursor::with_capacity(RPC_HEADER_SIZE + 8 + stub.len());
|
||||
|
||||
w.write_u8(RPC_VERSION_MAJOR);
|
||||
w.write_u8(RPC_VERSION_MINOR);
|
||||
w.write_u8(PDU_TYPE_RESPONSE);
|
||||
w.write_u8(PFC_FIRST_FRAG | PFC_LAST_FRAG);
|
||||
w.write_bytes(&DATA_REP);
|
||||
let frag_len_pos = w.position();
|
||||
w.write_u16_le(0); // FragLength placeholder
|
||||
w.write_u16_le(0); // AuthLength
|
||||
w.write_u32_le(call_id);
|
||||
|
||||
// RESPONSE specific
|
||||
w.write_u32_le(stub.len() as u32); // AllocHint
|
||||
w.write_u16_le(0); // ContextId
|
||||
w.write_u8(0); // CancelCount
|
||||
w.write_u8(0); // Reserved
|
||||
|
||||
w.write_bytes(stub);
|
||||
|
||||
let total_len = w.position();
|
||||
w.set_u16_le_at(frag_len_pos, total_len as u16);
|
||||
|
||||
w.into_inner()
|
||||
}
|
||||
}
|
||||
554
vendor/smb2/src/rpc/srvsvc.rs
vendored
Normal file
554
vendor/smb2/src/rpc/srvsvc.rs
vendored
Normal file
@@ -0,0 +1,554 @@
|
||||
//! NetShareEnumAll NDR encoding/decoding for the srvsvc interface.
|
||||
//!
|
||||
//! Encodes the NetrShareEnum request (opnum 15) and decodes the response,
|
||||
//! extracting share names, types, and comments.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::pack::{ReadCursor, WriteCursor};
|
||||
use crate::Error;
|
||||
|
||||
/// Share type: disk share.
|
||||
pub const STYPE_DISKTREE: u32 = 0x0000_0000;
|
||||
/// Share type: printer queue.
|
||||
pub const STYPE_PRINTQ: u32 = 0x0000_0001;
|
||||
/// Share type: device.
|
||||
pub const STYPE_DEVICE: u32 = 0x0000_0002;
|
||||
/// Share type: IPC (inter-process communication).
|
||||
pub const STYPE_IPC: u32 = 0x0000_0003;
|
||||
/// Share type modifier: special/admin share (combined with above via OR).
|
||||
pub const STYPE_SPECIAL: u32 = 0x8000_0000;
|
||||
|
||||
/// Mask for the base share type (low bits).
|
||||
const STYPE_BASE_MASK: u32 = 0x0000_FFFF;
|
||||
|
||||
/// Information about a single network share.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ShareInfo {
|
||||
/// The share name (for example, "Documents" or "IPC$").
|
||||
pub name: String,
|
||||
/// The share type as a raw u32 (see `STYPE_*` constants).
|
||||
pub share_type: u32,
|
||||
/// An optional comment/description for the share.
|
||||
pub comment: String,
|
||||
}
|
||||
|
||||
/// Build the NDR-encoded stub data for a NetShareEnumAll request.
|
||||
///
|
||||
/// The stub is meant to be wrapped in an RPC REQUEST PDU with opnum 15.
|
||||
pub fn build_net_share_enum_all_stub(server_name: &str) -> Vec<u8> {
|
||||
let mut w = WriteCursor::with_capacity(128);
|
||||
|
||||
// ServerName: NDR unique pointer to conformant+varying string (UTF-16LE, null-terminated)
|
||||
// Referent ID (non-null pointer)
|
||||
w.write_u32_le(0x0002_0000); // referent ID
|
||||
|
||||
// Encode the server name as a conformant+varying NDR string
|
||||
let name_utf16: Vec<u16> = server_name
|
||||
.encode_utf16()
|
||||
.chain(std::iter::once(0))
|
||||
.collect();
|
||||
let char_count = name_utf16.len() as u32;
|
||||
|
||||
// MaxCount
|
||||
w.write_u32_le(char_count);
|
||||
// Offset
|
||||
w.write_u32_le(0);
|
||||
// ActualCount
|
||||
w.write_u32_le(char_count);
|
||||
// String data (UTF-16LE)
|
||||
for &code_unit in &name_utf16 {
|
||||
w.write_u16_le(code_unit);
|
||||
}
|
||||
// Align to 4 bytes after string data
|
||||
w.align_to(4);
|
||||
|
||||
// InfoStruct: SHARE_ENUM_STRUCT
|
||||
// Level = 1 (we want SHARE_INFO_1)
|
||||
w.write_u32_le(1);
|
||||
|
||||
// ShareInfo union discriminant = 1 (matches level)
|
||||
w.write_u32_le(1);
|
||||
|
||||
// Pointer to SHARE_INFO_1_CONTAINER (unique pointer)
|
||||
w.write_u32_le(0x0002_0004); // referent ID
|
||||
|
||||
// SHARE_INFO_1_CONTAINER (deferred pointer data)
|
||||
// EntriesRead = 0 (server fills this)
|
||||
w.write_u32_le(0);
|
||||
// Buffer pointer = NULL (let server allocate)
|
||||
w.write_u32_le(0);
|
||||
|
||||
// PreferedMaximumLength = 0xFFFFFFFF (no limit)
|
||||
w.write_u32_le(0xFFFF_FFFF);
|
||||
|
||||
// ResumeHandle: unique pointer to u32
|
||||
// NULL pointer (no resume)
|
||||
w.write_u32_le(0);
|
||||
|
||||
w.into_inner()
|
||||
}
|
||||
|
||||
/// Build a complete RPC REQUEST PDU for NetShareEnumAll.
|
||||
///
|
||||
/// Combines the RPC REQUEST header (opnum 15) with the NDR stub data.
|
||||
pub fn build_net_share_enum_all(call_id: u32, server_name: &str) -> Vec<u8> {
|
||||
let stub = build_net_share_enum_all_stub(server_name);
|
||||
super::build_request(call_id, 15, &stub)
|
||||
}
|
||||
|
||||
/// Parse the NDR stub data from a NetShareEnumAll RPC RESPONSE.
|
||||
///
|
||||
/// Extracts all share entries from the response. The caller should use
|
||||
/// [`filter_disk_shares`] to get only disk shares.
|
||||
pub fn parse_net_share_enum_all_response(data: &[u8]) -> Result<Vec<ShareInfo>> {
|
||||
// First, parse the RPC RESPONSE envelope to get the stub data
|
||||
let stub = super::parse_response(data)?;
|
||||
parse_net_share_enum_all_stub(stub)
|
||||
}
|
||||
|
||||
/// Parse the NDR stub data directly (without the RPC envelope).
|
||||
///
|
||||
/// Used by the share-enumeration reassembly path, which concatenates the stub
|
||||
/// of each RPC fragment before decoding.
|
||||
pub(crate) fn parse_net_share_enum_all_stub(stub: &[u8]) -> Result<Vec<ShareInfo>> {
|
||||
let mut r = ReadCursor::new(stub);
|
||||
|
||||
// Level (u32) -- should be 1
|
||||
let level = r.read_u32_le()?;
|
||||
if level != 1 {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"expected share info level 1, got {level}"
|
||||
)));
|
||||
}
|
||||
|
||||
// Union discriminant (u32) -- should be 1
|
||||
let discriminant = r.read_u32_le()?;
|
||||
if discriminant != 1 {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"expected union discriminant 1, got {discriminant}"
|
||||
)));
|
||||
}
|
||||
|
||||
// Pointer to SHARE_INFO_1_CONTAINER
|
||||
let container_ptr = r.read_u32_le()?;
|
||||
if container_ptr == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// SHARE_INFO_1_CONTAINER
|
||||
let count = r.read_u32_le()?;
|
||||
|
||||
// Pointer to array of SHARE_INFO_1
|
||||
let array_ptr = r.read_u32_le()?;
|
||||
if array_ptr == 0 || count == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Array: MaxCount header
|
||||
let max_count = r.read_u32_le()?;
|
||||
if max_count < count {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"array max_count ({max_count}) < entries ({count})"
|
||||
)));
|
||||
}
|
||||
|
||||
// Read the fixed-size parts of each SHARE_INFO_1 entry:
|
||||
// Each entry has: name_ptr (u32), type (u32), comment_ptr (u32)
|
||||
struct RawEntry {
|
||||
name_ptr: u32,
|
||||
share_type: u32,
|
||||
comment_ptr: u32,
|
||||
}
|
||||
|
||||
let mut entries = Vec::with_capacity(count as usize);
|
||||
for _ in 0..count {
|
||||
let name_ptr = r.read_u32_le()?;
|
||||
let share_type = r.read_u32_le()?;
|
||||
let comment_ptr = r.read_u32_le()?;
|
||||
entries.push(RawEntry {
|
||||
name_ptr,
|
||||
share_type,
|
||||
comment_ptr,
|
||||
});
|
||||
}
|
||||
|
||||
// Now read the deferred pointer data (conformant+varying strings)
|
||||
let mut shares = Vec::with_capacity(count as usize);
|
||||
for entry in &entries {
|
||||
let name = if entry.name_ptr != 0 {
|
||||
read_ndr_string(&mut r)?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let comment = if entry.comment_ptr != 0 {
|
||||
read_ndr_string(&mut r)?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
shares.push(ShareInfo {
|
||||
name,
|
||||
share_type: entry.share_type,
|
||||
comment,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(shares)
|
||||
}
|
||||
|
||||
/// Read an NDR conformant+varying UTF-16LE string from the cursor.
|
||||
///
|
||||
/// Format: MaxCount(u32) + Offset(u32) + ActualCount(u32) + UTF-16LE data.
|
||||
/// The string is null-terminated on the wire; we strip the null.
|
||||
fn read_ndr_string(r: &mut ReadCursor<'_>) -> Result<String> {
|
||||
let _max_count = r.read_u32_le()?;
|
||||
let _offset = r.read_u32_le()?;
|
||||
let actual_count = r.read_u32_le()?;
|
||||
|
||||
if actual_count == 0 {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
let byte_len = actual_count as usize * 2;
|
||||
let s = r.read_utf16_le(byte_len)?;
|
||||
|
||||
// Align to 4 bytes after reading string data
|
||||
let pos = r.position();
|
||||
let padding = (4 - (pos % 4)) % 4;
|
||||
if padding > 0 && r.remaining() >= padding {
|
||||
r.skip(padding)?;
|
||||
}
|
||||
|
||||
// Strip trailing null
|
||||
Ok(s.trim_end_matches('\0').to_string())
|
||||
}
|
||||
|
||||
/// Filter shares, keeping only disk shares and excluding admin shares (ending with `$`).
|
||||
pub fn filter_disk_shares(shares: Vec<ShareInfo>) -> Vec<ShareInfo> {
|
||||
shares
|
||||
.into_iter()
|
||||
.filter(|s| {
|
||||
let base_type = s.share_type & STYPE_BASE_MASK;
|
||||
let is_disk = base_type == STYPE_DISKTREE;
|
||||
let is_special = (s.share_type & STYPE_SPECIAL) != 0;
|
||||
let ends_with_dollar = s.name.ends_with('$');
|
||||
is_disk && !is_special && !ends_with_dollar
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn build_request_has_opnum_15() {
|
||||
let pdu = build_net_share_enum_all(1, r"\\server");
|
||||
// OpNum is at offset 22 in the RPC REQUEST PDU
|
||||
let opnum = u16::from_le_bytes([pdu[22], pdu[23]]);
|
||||
assert_eq!(opnum, 15);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_request_stub_contains_server_name() {
|
||||
let stub = build_net_share_enum_all_stub(r"\\server");
|
||||
// The server name should appear as UTF-16LE somewhere in the stub
|
||||
let expected_utf16: Vec<u8> = r"\\server"
|
||||
.encode_utf16()
|
||||
.flat_map(|c| c.to_le_bytes())
|
||||
.collect();
|
||||
|
||||
let found = stub
|
||||
.windows(expected_utf16.len())
|
||||
.any(|window| window == expected_utf16.as_slice());
|
||||
assert!(found, "stub should contain the server name in UTF-16LE");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_response_with_three_shares() {
|
||||
let response_pdu = build_test_enum_response(&[
|
||||
("Documents", STYPE_DISKTREE, "Shared docs"),
|
||||
("IPC$", STYPE_IPC | STYPE_SPECIAL, "Remote IPC"),
|
||||
("C$", STYPE_DISKTREE | STYPE_SPECIAL, "Default share"),
|
||||
]);
|
||||
|
||||
let shares = parse_net_share_enum_all_response(&response_pdu).unwrap();
|
||||
assert_eq!(shares.len(), 3);
|
||||
assert_eq!(shares[0].name, "Documents");
|
||||
assert_eq!(shares[0].share_type, STYPE_DISKTREE);
|
||||
assert_eq!(shares[0].comment, "Shared docs");
|
||||
assert_eq!(shares[1].name, "IPC$");
|
||||
assert_eq!(shares[2].name, "C$");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_keeps_disk_shares() {
|
||||
let shares = vec![
|
||||
ShareInfo {
|
||||
name: "Documents".to_string(),
|
||||
share_type: STYPE_DISKTREE,
|
||||
comment: "Shared docs".to_string(),
|
||||
},
|
||||
ShareInfo {
|
||||
name: "Photos".to_string(),
|
||||
share_type: STYPE_DISKTREE,
|
||||
comment: String::new(),
|
||||
},
|
||||
];
|
||||
|
||||
let filtered = filter_disk_shares(shares);
|
||||
assert_eq!(filtered.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_removes_ipc() {
|
||||
let shares = vec![ShareInfo {
|
||||
name: "IPC$".to_string(),
|
||||
share_type: STYPE_IPC | STYPE_SPECIAL,
|
||||
comment: "Remote IPC".to_string(),
|
||||
}];
|
||||
|
||||
let filtered = filter_disk_shares(shares);
|
||||
assert!(filtered.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_removes_admin_shares() {
|
||||
let shares = vec![
|
||||
ShareInfo {
|
||||
name: "C$".to_string(),
|
||||
share_type: STYPE_DISKTREE | STYPE_SPECIAL,
|
||||
comment: "Default share".to_string(),
|
||||
},
|
||||
ShareInfo {
|
||||
name: "ADMIN$".to_string(),
|
||||
share_type: STYPE_DISKTREE | STYPE_SPECIAL,
|
||||
comment: "Remote Admin".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let filtered = filter_disk_shares(shares);
|
||||
assert!(filtered.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_mixed_shares() {
|
||||
let shares = vec![
|
||||
ShareInfo {
|
||||
name: "Documents".to_string(),
|
||||
share_type: STYPE_DISKTREE,
|
||||
comment: "Shared docs".to_string(),
|
||||
},
|
||||
ShareInfo {
|
||||
name: "IPC$".to_string(),
|
||||
share_type: STYPE_IPC | STYPE_SPECIAL,
|
||||
comment: "Remote IPC".to_string(),
|
||||
},
|
||||
ShareInfo {
|
||||
name: "C$".to_string(),
|
||||
share_type: STYPE_DISKTREE | STYPE_SPECIAL,
|
||||
comment: "Default share".to_string(),
|
||||
},
|
||||
ShareInfo {
|
||||
name: "Photos".to_string(),
|
||||
share_type: STYPE_DISKTREE,
|
||||
comment: String::new(),
|
||||
},
|
||||
ShareInfo {
|
||||
name: "Printer".to_string(),
|
||||
share_type: STYPE_PRINTQ,
|
||||
comment: "Office printer".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let filtered = filter_disk_shares(shares);
|
||||
assert_eq!(filtered.len(), 2);
|
||||
assert_eq!(filtered[0].name, "Documents");
|
||||
assert_eq!(filtered[1].name, "Photos");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_empty_share_list() {
|
||||
let response_pdu = build_test_enum_response(&[]);
|
||||
let shares = parse_net_share_enum_all_response(&response_pdu).unwrap();
|
||||
assert!(shares.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_share_with_unicode_name() {
|
||||
let response_pdu = build_test_enum_response(&[(
|
||||
"\u{00C4}rchive",
|
||||
STYPE_DISKTREE,
|
||||
"Archiv f\u{00FC}r Dateien",
|
||||
)]);
|
||||
|
||||
let shares = parse_net_share_enum_all_response(&response_pdu).unwrap();
|
||||
assert_eq!(shares.len(), 1);
|
||||
assert_eq!(shares[0].name, "\u{00C4}rchive");
|
||||
assert_eq!(shares[0].comment, "Archiv f\u{00FC}r Dateien");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_share_with_cjk_characters() {
|
||||
let response_pdu = build_test_enum_response(&[(
|
||||
"\u{5171}\u{6709}",
|
||||
STYPE_DISKTREE,
|
||||
"\u{5171}\u{6709}\u{30D5}\u{30A9}\u{30EB}\u{30C0}",
|
||||
)]);
|
||||
|
||||
let shares = parse_net_share_enum_all_response(&response_pdu).unwrap();
|
||||
assert_eq!(shares.len(), 1);
|
||||
assert_eq!(shares[0].name, "\u{5171}\u{6709}");
|
||||
assert_eq!(
|
||||
shares[0].comment,
|
||||
"\u{5171}\u{6709}\u{30D5}\u{30A9}\u{30EB}\u{30C0}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_build_and_parse() {
|
||||
// Build a request, then manually construct a response and parse it
|
||||
let _request = build_net_share_enum_all(1, r"\\testserver");
|
||||
|
||||
let response_pdu = build_test_enum_response(&[
|
||||
("Share1", STYPE_DISKTREE, "First share"),
|
||||
("Share2", STYPE_DISKTREE, "Second share"),
|
||||
]);
|
||||
|
||||
let shares = parse_net_share_enum_all_response(&response_pdu).unwrap();
|
||||
assert_eq!(shares.len(), 2);
|
||||
assert_eq!(shares[0].name, "Share1");
|
||||
assert_eq!(shares[0].comment, "First share");
|
||||
assert_eq!(shares[1].name, "Share2");
|
||||
assert_eq!(shares[1].comment, "Second share");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_preserves_non_dollar_disk_shares_only() {
|
||||
// A share named "My$hare" (dollar in middle) should be kept
|
||||
let shares = vec![ShareInfo {
|
||||
name: "My$hare".to_string(),
|
||||
share_type: STYPE_DISKTREE,
|
||||
comment: String::new(),
|
||||
}];
|
||||
|
||||
let filtered = filter_disk_shares(shares);
|
||||
assert_eq!(filtered.len(), 1);
|
||||
assert_eq!(filtered[0].name, "My$hare");
|
||||
}
|
||||
|
||||
// -- Test helpers --
|
||||
|
||||
/// Write an NDR conformant+varying UTF-16LE string into the cursor.
|
||||
fn write_ndr_string(w: &mut WriteCursor, s: &str) {
|
||||
let utf16: Vec<u16> = s.encode_utf16().chain(std::iter::once(0)).collect();
|
||||
let char_count = utf16.len() as u32;
|
||||
|
||||
w.write_u32_le(char_count); // MaxCount
|
||||
w.write_u32_le(0); // Offset
|
||||
w.write_u32_le(char_count); // ActualCount
|
||||
for &code_unit in &utf16 {
|
||||
w.write_u16_le(code_unit);
|
||||
}
|
||||
w.align_to(4);
|
||||
}
|
||||
|
||||
/// Build a complete RPC RESPONSE PDU containing the given shares.
|
||||
///
|
||||
/// This constructs valid NDR stub data wrapped in an RPC RESPONSE envelope.
|
||||
fn build_test_enum_response(shares: &[(&str, u32, &str)]) -> Vec<u8> {
|
||||
let stub = build_test_enum_stub(shares);
|
||||
build_test_response_pdu(1, &stub)
|
||||
}
|
||||
|
||||
/// Build NDR stub data for a NetShareEnumAll response.
|
||||
fn build_test_enum_stub(shares: &[(&str, u32, &str)]) -> Vec<u8> {
|
||||
let mut w = WriteCursor::with_capacity(512);
|
||||
let count = shares.len() as u32;
|
||||
|
||||
// Level = 1
|
||||
w.write_u32_le(1);
|
||||
// Union discriminant = 1
|
||||
w.write_u32_le(1);
|
||||
|
||||
if count == 0 {
|
||||
// Null container pointer
|
||||
w.write_u32_le(0);
|
||||
// TotalEntries
|
||||
w.write_u32_le(0);
|
||||
// ResumeHandle pointer (null)
|
||||
w.write_u32_le(0);
|
||||
// Return value (Windows error code, 0 = success)
|
||||
w.write_u32_le(0);
|
||||
return w.into_inner();
|
||||
}
|
||||
|
||||
// Container pointer (non-null)
|
||||
w.write_u32_le(0x0002_0000);
|
||||
|
||||
// SHARE_INFO_1_CONTAINER
|
||||
w.write_u32_le(count); // EntriesRead
|
||||
w.write_u32_le(0x0002_0004); // Array pointer (non-null)
|
||||
|
||||
// Array: MaxCount
|
||||
w.write_u32_le(count);
|
||||
|
||||
// Fixed-size entries: name_ptr, type, comment_ptr
|
||||
for (i, &(_, share_type, _)) in shares.iter().enumerate() {
|
||||
w.write_u32_le(0x0002_0008 + (i as u32) * 2); // name referent ID
|
||||
w.write_u32_le(share_type);
|
||||
w.write_u32_le(0x0002_0108 + (i as u32) * 2); // comment referent ID
|
||||
}
|
||||
|
||||
// Deferred string data (name then comment for each entry)
|
||||
for &(name, _, comment) in shares {
|
||||
write_ndr_string(&mut w, name);
|
||||
write_ndr_string(&mut w, comment);
|
||||
}
|
||||
|
||||
// TotalEntries
|
||||
w.write_u32_le(count);
|
||||
// ResumeHandle pointer (null)
|
||||
w.write_u32_le(0);
|
||||
// Return value (0 = success)
|
||||
w.write_u32_le(0);
|
||||
|
||||
w.into_inner()
|
||||
}
|
||||
|
||||
/// Build a minimal RPC RESPONSE PDU wrapping stub data.
|
||||
fn build_test_response_pdu(call_id: u32, stub: &[u8]) -> Vec<u8> {
|
||||
use crate::pack::WriteCursor;
|
||||
|
||||
let mut w = WriteCursor::with_capacity(24 + stub.len());
|
||||
|
||||
// Common header
|
||||
w.write_u8(5); // Version
|
||||
w.write_u8(0); // VersionMinor
|
||||
w.write_u8(2); // PacketType = RESPONSE
|
||||
w.write_u8(0x03); // Flags (first + last)
|
||||
w.write_bytes(&[0x10, 0x00, 0x00, 0x00]); // DataRep
|
||||
let frag_len_pos = w.position();
|
||||
w.write_u16_le(0); // FragLength placeholder
|
||||
w.write_u16_le(0); // AuthLength
|
||||
w.write_u32_le(call_id);
|
||||
|
||||
// RESPONSE specific
|
||||
w.write_u32_le(stub.len() as u32); // AllocHint
|
||||
w.write_u16_le(0); // ContextId
|
||||
w.write_u8(0); // CancelCount
|
||||
w.write_u8(0); // Reserved
|
||||
|
||||
w.write_bytes(stub);
|
||||
|
||||
let total_len = w.position();
|
||||
w.set_u16_le_at(frag_len_pos, total_len as u16);
|
||||
|
||||
w.into_inner()
|
||||
}
|
||||
}
|
||||
48
vendor/smb2/src/testing/CLAUDE.md
vendored
Normal file
48
vendor/smb2/src/testing/CLAUDE.md
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
# Testing module -- Docker-based SMB test servers
|
||||
|
||||
Feature-gated (`testing` feature flag). Provides Docker-based Samba containers for consumers (apps that depend on smb2) to test their SMB integration.
|
||||
|
||||
## Key files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `mod.rs` | `TestServers`, `Error`, port constants, embedded Docker files, `write_compose_files()` |
|
||||
|
||||
## Architecture
|
||||
|
||||
Three-layer testing model:
|
||||
|
||||
1. **Layer 1 (Rust)**: `TestServers::start()` / `start_all()` / `start_blocking()` return a struct with `*_client()` methods that connect to Docker containers.
|
||||
2. **Layer 2 (E2E)**: `write_compose_files(dir)` extracts embedded Docker infrastructure to disk for non-Rust test frameworks.
|
||||
3. **Layer 3 (Manual QA)**: Same compose files, run manually.
|
||||
|
||||
## Embedded files
|
||||
|
||||
All 35 Docker files (compose, Dockerfiles, smb.conf, scripts) are embedded via `include_str!` at compile time. At runtime, `write_compose_files()` writes them to a temp directory. Docker Compose runs from there.
|
||||
|
||||
## Port scheme
|
||||
|
||||
15 containers on ports 10480-10494. Each port has an env-var override (`SMB_CONSUMER_*_PORT`). The `port()` function checks the env var, falls back to the hardcoded default.
|
||||
|
||||
## Profiles
|
||||
|
||||
- **Minimal**: guest + auth only (2 containers, fast startup).
|
||||
- **All**: all 15 containers.
|
||||
|
||||
Calling a `*_client()` method for a container not in the current profile returns `Error::ContainerNotStarted`.
|
||||
|
||||
## Key decisions
|
||||
|
||||
| Decision | Choice | Why |
|
||||
|---|---|---|
|
||||
| No extra deps | `std::process::Command` for Docker | Keep the crate lean |
|
||||
| Temp dir via `std::env::temp_dir()` | No `tempfile` crate | No extra deps |
|
||||
| Embedded files via `include_str!` | Self-contained published crate | Consumers don't need smb2 source tree |
|
||||
| Separate error type | `testing::Error` vs `smb2::Error` | Docker failures are not protocol errors |
|
||||
| Best-effort cleanup in Drop | `docker compose down` | LazyLock statics never drop, so this is convenience only |
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **LazyLock statics never drop**: `TestServers::drop()` won't run at process exit. CI should use explicit cleanup steps.
|
||||
- **Flaky container has no health check**: The 5s-up/5s-down cycle means health checks would randomly fail. `wait_healthy()` skips it.
|
||||
- **DFS is disabled on test clients**: Consumer containers don't set up DFS. The `connect_guest` / `connect_auth` helpers set `dfs_enabled: false`.
|
||||
1275
vendor/smb2/src/testing/mod.rs
vendored
Normal file
1275
vendor/smb2/src/testing/mod.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
53
vendor/smb2/src/transport/CLAUDE.md
vendored
Normal file
53
vendor/smb2/src/transport/CLAUDE.md
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
# Transport -- send/receive abstraction
|
||||
|
||||
Split transport traits for SMB2 message I/O. Two implementations: TCP and mock.
|
||||
|
||||
## Key files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `mod.rs` | `TransportSend`, `TransportReceive`, `Transport` traits |
|
||||
| `tcp.rs` | `TcpTransport` -- direct TCP to port 445, handles framing |
|
||||
| `mock.rs` | `MockTransport` -- FIFO response queue for testing |
|
||||
|
||||
## Split traits
|
||||
|
||||
`TransportSend` and `TransportReceive` are separate traits. This avoids deadlock in the pipeline's `tokio::select!` loop where one task sends requests while another concurrently reads responses on the same connection. A single `Transport` trait would require `&mut self` for both directions, making concurrent send+receive impossible without `Arc<Mutex>`.
|
||||
|
||||
The blanket impl `Transport` combines both halves. `Connection` stores `Box<dyn TransportSend>` and `Box<dyn TransportReceive>` separately.
|
||||
|
||||
## TCP framing
|
||||
|
||||
```
|
||||
[0x00] [length: 3 bytes, big-endian] [SMB2 message(s)]
|
||||
```
|
||||
|
||||
- First byte must be `0x00`
|
||||
- Next 3 bytes: message length in big-endian (network byte order)
|
||||
- Maximum frame size: 16 MB
|
||||
- This is the ONLY big-endian value in SMB2
|
||||
|
||||
`TcpTransport::send` prepends the 4-byte header. `TcpTransport::receive` reads the header, then `read_exact` for the payload.
|
||||
|
||||
## Who reads the transport
|
||||
|
||||
`TransportReceive::receive()` is called by exactly one owner: the background receiver task spawned by `Connection::from_transport` (Phase 2 actor refactor). No other code path calls `receive()` in production. This is the invariant that makes per-`MessageId` routing sound — there's a single serialized read of the wire, then demux to per-request `oneshot::Sender`s. See `src/client/CLAUDE.md` § "Connection internals: receiver task + `oneshot` routing".
|
||||
|
||||
`TransportSend::send()` is called from the caller thread (the one holding `&mut Connection`). `TcpTransport`'s internal Mutex on the write half serializes sends — relevant for Phase 3 once `Connection` becomes `Clone`.
|
||||
|
||||
## MockTransport
|
||||
|
||||
Used by all unit tests. Stores sent messages for inspection and returns queued responses in FIFO order. Thread-safe via `std::sync::Mutex`.
|
||||
|
||||
Phase 2 changed `receive()` from "return `Err(Disconnected)` immediately when the queue is empty" to "block on `tokio::sync::Notify` until data is queued or `close()` is called". Required because the Connection's receiver task calls `receive()` in a loop — a premature `Disconnected` would kill the task while a test was still setting up responses.
|
||||
|
||||
- `queue_response(data)` / `queue_responses(vec)` push to the queue and call `notify_one()`. `notify_one` stores a permit if no receiver is parked, so the next `.notified().await` returns immediately.
|
||||
- `close()` sets an atomic `closed` flag and calls BOTH `notify_one()` (covers the wake-loss race where `receive()` is between `closed.load()` and `.notified().await`) and `notify_waiters()` (wakes already-parked waiters).
|
||||
- External consumers using `MockTransport` in their own tests must call `close()` to get an explicit end-of-stream; the implicit "empty queue = disconnected" behavior is gone.
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **Partial TCP reads**: Always use `read_exact` to read the full frame. TCP can deliver partial data in any `read()` call.
|
||||
- **16 MB max frame**: Reject frames larger than 16 MB to prevent OOM from malicious servers.
|
||||
- **Frame may contain multiple messages**: Compound responses arrive in a single frame. The Connection's receiver task splits them by `NextCommand` offsets and routes each sub-response by `MessageId` independently.
|
||||
- **`MockTransport::close()` wake-loss**: `notify_waiters()` alone only wakes already-parked waiters; if `close()` fires between `receive()`'s `closed.load()` check and its `notified().await`, the signal is lost. `close()` therefore also calls `notify_one()` to store a permit — next `.notified().await` returns immediately and the loop re-observes `closed=true`. Noticed via code review after Phase 2.
|
||||
507
vendor/smb2/src/transport/mock.rs
vendored
Normal file
507
vendor/smb2/src/transport/mock.rs
vendored
Normal file
@@ -0,0 +1,507 @@
|
||||
//! Mock transport for testing.
|
||||
//!
|
||||
//! Provides a [`MockTransport`] that queues canned responses and records
|
||||
//! sent messages, enabling test-driven development of higher layers
|
||||
//! without needing a real SMB server.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Mutex;
|
||||
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::transport::{TransportReceive, TransportSend};
|
||||
|
||||
/// A mock transport that queues responses and records sent messages.
|
||||
///
|
||||
/// Use this in tests to simulate server conversations without a real
|
||||
/// network connection. Responses are returned in FIFO order.
|
||||
///
|
||||
/// `receive()` awaits on an internal `Notify` when the queue is empty,
|
||||
/// so the background receiver task doesn't exit prematurely between
|
||||
/// `queue_response` calls. Explicit disconnect is triggered by calling
|
||||
/// [`Self::close`].
|
||||
pub struct MockTransport {
|
||||
/// Responses to return on `receive()`, in order.
|
||||
responses: Mutex<VecDeque<Vec<u8>>>,
|
||||
/// Messages that were sent, for assertions.
|
||||
sent: Mutex<Vec<Vec<u8>>>,
|
||||
/// How many times `receive()` was called successfully (returning Ok).
|
||||
receive_count: Mutex<usize>,
|
||||
/// Wakes receivers when a response is queued or `close()` is called.
|
||||
notify: Notify,
|
||||
/// Set by `close()` to signal end-of-stream.
|
||||
closed: AtomicBool,
|
||||
/// When `true`, `receive()` rewrites each response sub-frame's
|
||||
/// `MessageId` to match the `MessageId` of the next pending sent request
|
||||
/// (and consumes it). See [`Self::enable_auto_rewrite_msg_id`].
|
||||
auto_rewrite: AtomicBool,
|
||||
/// FIFO of `MessageId`s observed in `send()` that haven't yet been
|
||||
/// consumed by a `receive()` rewrite. Only used when `auto_rewrite`
|
||||
/// is on.
|
||||
pending_sent_msg_ids: Mutex<VecDeque<u64>>,
|
||||
/// Signaled whenever a new send is recorded or a close happens — used
|
||||
/// by `receive()` in auto-rewrite mode to wait for a sent msg_id to
|
||||
/// pair with a queued response.
|
||||
send_notify: Notify,
|
||||
}
|
||||
|
||||
impl MockTransport {
|
||||
/// Create a new mock with no queued responses.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
responses: Mutex::new(VecDeque::new()),
|
||||
sent: Mutex::new(Vec::new()),
|
||||
receive_count: Mutex::new(0),
|
||||
notify: Notify::new(),
|
||||
closed: AtomicBool::new(false),
|
||||
auto_rewrite: AtomicBool::new(false),
|
||||
pending_sent_msg_ids: Mutex::new(VecDeque::new()),
|
||||
send_notify: Notify::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable msg_id rewriting: when `true`, `receive()` rewrites each
|
||||
/// response sub-frame's `MessageId` in-place to match the `MessageId`
|
||||
/// of the next request recorded by `send()` (FIFO pairing).
|
||||
///
|
||||
/// Without this, canned response builders hardcode `MessageId(0)` and
|
||||
/// won't match the caller's allocated msg_ids — the receiver task
|
||||
/// drops them as orphans and every caller hangs. This mode is the
|
||||
/// test-fixture replacement for the pre-Phase-3 orphan-filter-off
|
||||
/// path. Compound responses (multiple sub-frames chained via
|
||||
/// `NextCommand`) each consume one sent msg_id in order.
|
||||
///
|
||||
/// The receive side blocks until both a queued response and a sent
|
||||
/// msg_id are available, so tests can queue responses before or
|
||||
/// after the caller sends.
|
||||
pub fn enable_auto_rewrite_msg_id(&self) {
|
||||
self.auto_rewrite.store(true, Ordering::Release);
|
||||
}
|
||||
|
||||
/// Queue a response to be returned by the next `receive()` call.
|
||||
pub fn queue_response(&self, data: Vec<u8>) {
|
||||
self.responses.lock().unwrap().push_back(data);
|
||||
self.notify.notify_one();
|
||||
}
|
||||
|
||||
/// Queue multiple responses to be returned in order.
|
||||
pub fn queue_responses(&self, responses: Vec<Vec<u8>>) {
|
||||
let mut guard = self.responses.lock().unwrap();
|
||||
let count = responses.len();
|
||||
for r in responses {
|
||||
guard.push_back(r);
|
||||
}
|
||||
drop(guard);
|
||||
for _ in 0..count {
|
||||
self.notify.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
/// Signal end-of-stream: after all queued responses are drained,
|
||||
/// `receive()` returns `Err(Error::Disconnected)`.
|
||||
pub fn close(&self) {
|
||||
self.closed.store(true, Ordering::Release);
|
||||
// Use `notify_one` (stores a permit for the next `notified().await`)
|
||||
// in addition to `notify_waiters` (wakes currently-parked waiters).
|
||||
// `notify_waiters` alone loses the signal if `close()` fires
|
||||
// between `receive()`'s `closed.load()` check and its
|
||||
// `notified().await` — no waiter is parked yet, so nothing gets
|
||||
// woken. The stored permit from `notify_one` covers that gap.
|
||||
self.notify.notify_one();
|
||||
self.notify.notify_waiters();
|
||||
// Same treatment for the send-notification used by auto-rewrite:
|
||||
// close should wake a receive that's blocked waiting for a paired
|
||||
// sent msg_id so it observes `closed` and bails out.
|
||||
self.send_notify.notify_one();
|
||||
self.send_notify.notify_waiters();
|
||||
}
|
||||
|
||||
/// Get all messages that were sent.
|
||||
pub fn sent_messages(&self) -> Vec<Vec<u8>> {
|
||||
self.sent.lock().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Get the nth sent message, or `None` if out of bounds.
|
||||
pub fn sent_message(&self, n: usize) -> Option<Vec<u8>> {
|
||||
self.sent.lock().unwrap().get(n).cloned()
|
||||
}
|
||||
|
||||
/// How many messages have been sent.
|
||||
pub fn sent_count(&self) -> usize {
|
||||
self.sent.lock().unwrap().len()
|
||||
}
|
||||
|
||||
/// Clear all recorded sent messages.
|
||||
pub fn clear_sent(&self) {
|
||||
self.sent.lock().unwrap().clear();
|
||||
}
|
||||
|
||||
/// How many times `receive()` was called successfully (returned Ok).
|
||||
pub fn received_count(&self) -> usize {
|
||||
*self.receive_count.lock().unwrap()
|
||||
}
|
||||
|
||||
/// How many responses are still queued and unread.
|
||||
///
|
||||
/// Useful in tests that want to assert the code-under-test consumed
|
||||
/// every response it was expected to, without leaking any to a
|
||||
/// later test or leaving stale state that could mask a bug.
|
||||
pub fn pending_responses(&self) -> usize {
|
||||
self.responses.lock().unwrap().len()
|
||||
}
|
||||
|
||||
/// Assert that every queued response has been consumed.
|
||||
///
|
||||
/// Panics with a descriptive message if any responses remain in the
|
||||
/// queue. Use at the end of a test to catch the "caller forgot to
|
||||
/// receive" pattern that produces response-pipe pollution in
|
||||
/// real usage.
|
||||
#[track_caller]
|
||||
pub fn assert_fully_consumed(&self) {
|
||||
let remaining = self.pending_responses();
|
||||
assert_eq!(
|
||||
remaining, 0,
|
||||
"MockTransport has {} queued response(s) the code-under-test never read. \
|
||||
This usually means a caller sent a request but never received its response, \
|
||||
which in real usage leaves an orphan on the wire and corrupts the next op.",
|
||||
remaining
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MockTransport {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSend for MockTransport {
|
||||
async fn send(&self, data: &[u8]) -> Result<()> {
|
||||
// In auto-rewrite mode, capture the MessageId of each sub-frame
|
||||
// so `receive()` can rewrite a queued response to match.
|
||||
if self.auto_rewrite.load(Ordering::Acquire) {
|
||||
for msg_id in extract_msg_ids(data) {
|
||||
self.pending_sent_msg_ids.lock().unwrap().push_back(msg_id);
|
||||
self.send_notify.notify_one();
|
||||
}
|
||||
}
|
||||
self.sent.lock().unwrap().push(data.to_vec());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportReceive for MockTransport {
|
||||
async fn receive(&self) -> Result<Vec<u8>> {
|
||||
loop {
|
||||
let auto = self.auto_rewrite.load(Ordering::Acquire);
|
||||
// Wait for a queued response first (auto mode and plain mode
|
||||
// both need one to exist).
|
||||
let has_response = !self.responses.lock().unwrap().is_empty();
|
||||
if !has_response {
|
||||
if self.closed.load(Ordering::Acquire) {
|
||||
return Err(Error::Disconnected);
|
||||
}
|
||||
self.notify.notified().await;
|
||||
continue;
|
||||
}
|
||||
|
||||
if auto {
|
||||
// We have a response; peek its sub-frame count and wait
|
||||
// for at least that many sent msg_ids to be queued
|
||||
// (one consumed per sub-frame, even ones that already
|
||||
// have non-zero msg_ids, so pairing stays 1:1).
|
||||
let needed = {
|
||||
let guard = self.responses.lock().unwrap();
|
||||
match guard.front() {
|
||||
Some(frame) => count_sub_frames(frame),
|
||||
None => continue,
|
||||
}
|
||||
};
|
||||
if needed > 0 {
|
||||
loop {
|
||||
let have = self.pending_sent_msg_ids.lock().unwrap().len();
|
||||
if have >= needed {
|
||||
break;
|
||||
}
|
||||
if self.closed.load(Ordering::Acquire) {
|
||||
return Err(Error::Disconnected);
|
||||
}
|
||||
self.send_notify.notified().await;
|
||||
}
|
||||
}
|
||||
// Consume one response and `needed` sent msg_ids,
|
||||
// rewriting each sub-frame's zero msg_id to match the
|
||||
// corresponding sent msg_id.
|
||||
let mut data = match self.responses.lock().unwrap().pop_front() {
|
||||
Some(d) => d,
|
||||
None => continue,
|
||||
};
|
||||
let mut ids = self.pending_sent_msg_ids.lock().unwrap();
|
||||
rewrite_msg_ids(&mut data, &mut ids);
|
||||
drop(ids);
|
||||
*self.receive_count.lock().unwrap() += 1;
|
||||
return Ok(data);
|
||||
}
|
||||
|
||||
// Plain mode: just pop and return.
|
||||
let data = match self.responses.lock().unwrap().pop_front() {
|
||||
Some(d) => d,
|
||||
None => continue,
|
||||
};
|
||||
*self.receive_count.lock().unwrap() += 1;
|
||||
return Ok(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract `MessageId`s from a packed SMB2 request frame (possibly compound).
|
||||
/// Returns one msg_id per sub-frame, following `NextCommand` offsets.
|
||||
/// Returns an empty Vec if the data isn't a recognizable SMB2 frame —
|
||||
/// e.g. when `send()` is used with arbitrary bytes in transport-level tests.
|
||||
fn extract_msg_ids(data: &[u8]) -> Vec<u64> {
|
||||
const HEADER_MIN: usize = 64;
|
||||
if data.len() < HEADER_MIN {
|
||||
return Vec::new();
|
||||
}
|
||||
// Not an SMB2 header — skip (non-SMB2 tests call send with arbitrary bytes).
|
||||
if &data[0..4] != b"\xFESMB" {
|
||||
return Vec::new();
|
||||
}
|
||||
let mut ids = Vec::new();
|
||||
let mut offset = 0usize;
|
||||
loop {
|
||||
if offset + HEADER_MIN > data.len() {
|
||||
break;
|
||||
}
|
||||
let msg_id =
|
||||
u64::from_le_bytes(data[offset + 24..offset + 32].try_into().unwrap_or([0; 8]));
|
||||
ids.push(msg_id);
|
||||
let next = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap_or([0; 4]));
|
||||
if next == 0 {
|
||||
break;
|
||||
}
|
||||
offset += next as usize;
|
||||
}
|
||||
ids
|
||||
}
|
||||
|
||||
/// Count sub-frames in a packed SMB2 response frame by walking
|
||||
/// `NextCommand` offsets. Returns 0 for non-SMB2 frames, otherwise the
|
||||
/// total sub-frame count. `rewrite_msg_ids` consumes one sent msg_id
|
||||
/// per sub-frame (even those with already-set msg_ids) to keep
|
||||
/// send→receive pairing strictly 1:1 and avoid queue drift in tests
|
||||
/// that hardcode some but not all msg_ids.
|
||||
fn count_sub_frames(data: &[u8]) -> usize {
|
||||
const HEADER_MIN: usize = 64;
|
||||
if data.len() < HEADER_MIN || &data[0..4] != b"\xFESMB" {
|
||||
return 0;
|
||||
}
|
||||
let mut count = 0usize;
|
||||
let mut offset = 0usize;
|
||||
loop {
|
||||
if offset + HEADER_MIN > data.len() {
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
let next = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap_or([0; 4]));
|
||||
if next == 0 {
|
||||
break;
|
||||
}
|
||||
offset += next as usize;
|
||||
}
|
||||
count
|
||||
}
|
||||
|
||||
/// Rewrite each sub-frame's `MessageId` in-place, consuming one id from
|
||||
/// `ids` per sub-frame in FIFO order. Sub-frames whose msg_id is
|
||||
/// already non-zero keep their hardcoded id (so tests exercising out-of-
|
||||
/// order routing still work) but STILL consume one id from the queue
|
||||
/// to keep send→receive pairing 1:1.
|
||||
fn rewrite_msg_ids(data: &mut [u8], ids: &mut VecDeque<u64>) {
|
||||
const HEADER_MIN: usize = 64;
|
||||
if data.len() < HEADER_MIN || &data[0..4] != b"\xFESMB" {
|
||||
return;
|
||||
}
|
||||
let mut offset = 0usize;
|
||||
loop {
|
||||
if offset + HEADER_MIN > data.len() {
|
||||
break;
|
||||
}
|
||||
let existing =
|
||||
u64::from_le_bytes(data[offset + 24..offset + 32].try_into().unwrap_or([0; 8]));
|
||||
let consumed = ids.pop_front();
|
||||
if existing == 0 {
|
||||
if let Some(id) = consumed {
|
||||
data[offset + 24..offset + 32].copy_from_slice(&id.to_le_bytes());
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let next = u32::from_le_bytes(data[offset + 20..offset + 24].try_into().unwrap_or([0; 4]));
|
||||
if next == 0 {
|
||||
break;
|
||||
}
|
||||
offset += next as usize;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn queue_response_and_receive_it() {
|
||||
let mock = MockTransport::new();
|
||||
let data = vec![0x01, 0x02, 0x03];
|
||||
mock.queue_response(data.clone());
|
||||
|
||||
let received = mock.receive().await.unwrap();
|
||||
assert_eq!(received, data);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn queue_multiple_responses_received_in_order() {
|
||||
let mock = MockTransport::new();
|
||||
mock.queue_responses(vec![vec![0x01], vec![0x02, 0x03], vec![0x04, 0x05, 0x06]]);
|
||||
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0x01]);
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0x02, 0x03]);
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0x04, 0x05, 0x06]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn close_causes_receive_to_return_disconnected() {
|
||||
let mock = MockTransport::new();
|
||||
mock.close();
|
||||
|
||||
let result = mock.receive().await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, Error::Disconnected),
|
||||
"expected Disconnected, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_records_message() {
|
||||
let mock = MockTransport::new();
|
||||
let msg = vec![0xAA, 0xBB, 0xCC];
|
||||
|
||||
mock.send(&msg).await.unwrap();
|
||||
|
||||
let sent = mock.sent_messages();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert_eq!(sent[0], msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sent_count_tracks_correctly() {
|
||||
let mock = MockTransport::new();
|
||||
assert_eq!(mock.sent_count(), 0);
|
||||
|
||||
mock.send(&[0x01]).await.unwrap();
|
||||
assert_eq!(mock.sent_count(), 1);
|
||||
|
||||
mock.send(&[0x02]).await.unwrap();
|
||||
assert_eq!(mock.sent_count(), 2);
|
||||
|
||||
mock.send(&[0x03]).await.unwrap();
|
||||
assert_eq!(mock.sent_count(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sent_message_returns_nth() {
|
||||
let mock = MockTransport::new();
|
||||
mock.send(&[0x0A]).await.unwrap();
|
||||
mock.send(&[0x0B]).await.unwrap();
|
||||
mock.send(&[0x0C]).await.unwrap();
|
||||
|
||||
assert_eq!(mock.sent_message(0), Some(vec![0x0A]));
|
||||
assert_eq!(mock.sent_message(1), Some(vec![0x0B]));
|
||||
assert_eq!(mock.sent_message(2), Some(vec![0x0C]));
|
||||
assert_eq!(mock.sent_message(3), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clear_sent_removes_all_recorded_messages() {
|
||||
let mock = MockTransport::new();
|
||||
mock.send(&[0x01]).await.unwrap();
|
||||
mock.send(&[0x02]).await.unwrap();
|
||||
assert_eq!(mock.sent_count(), 2);
|
||||
|
||||
mock.clear_sent();
|
||||
assert_eq!(mock.sent_count(), 0);
|
||||
assert!(mock.sent_messages().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn interleaved_send_and_receive() {
|
||||
let mock = MockTransport::new();
|
||||
mock.queue_responses(vec![vec![0xF1], vec![0xF2], vec![0xF3]]);
|
||||
|
||||
// Send a request, receive a response, repeat.
|
||||
mock.send(&[0x01]).await.unwrap();
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0xF1]);
|
||||
|
||||
mock.send(&[0x02]).await.unwrap();
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0xF2]);
|
||||
|
||||
mock.send(&[0x03]).await.unwrap();
|
||||
assert_eq!(mock.receive().await.unwrap(), vec![0xF3]);
|
||||
|
||||
// No more responses. Close to cause Disconnected.
|
||||
mock.close();
|
||||
assert!(mock.receive().await.is_err());
|
||||
|
||||
// All three sends recorded.
|
||||
assert_eq!(mock.sent_count(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent_send_and_receive() {
|
||||
use std::sync::Arc;
|
||||
|
||||
let mock = Arc::new(MockTransport::new());
|
||||
mock.queue_responses(vec![vec![0xAA]; 10]);
|
||||
|
||||
let send_mock = Arc::clone(&mock);
|
||||
let send_task = tokio::spawn(async move {
|
||||
for i in 0..10u8 {
|
||||
send_mock.send(&[i]).await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let recv_mock = Arc::clone(&mock);
|
||||
let recv_task = tokio::spawn(async move {
|
||||
let mut received = Vec::new();
|
||||
for _ in 0..10 {
|
||||
received.push(recv_mock.receive().await.unwrap());
|
||||
}
|
||||
received
|
||||
});
|
||||
|
||||
send_task.await.unwrap();
|
||||
let received = recv_task.await.unwrap();
|
||||
|
||||
assert_eq!(received.len(), 10);
|
||||
assert_eq!(mock.sent_count(), 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_message_can_be_sent_and_received() {
|
||||
let mock = MockTransport::new();
|
||||
mock.queue_response(vec![]);
|
||||
|
||||
mock.send(&[]).await.unwrap();
|
||||
let received = mock.receive().await.unwrap();
|
||||
|
||||
assert!(received.is_empty());
|
||||
assert_eq!(mock.sent_message(0), Some(vec![]));
|
||||
}
|
||||
}
|
||||
215
vendor/smb2/src/transport/mod.rs
vendored
Normal file
215
vendor/smb2/src/transport/mod.rs
vendored
Normal file
@@ -0,0 +1,215 @@
|
||||
//! Transport abstraction for sending and receiving SMB2 messages.
|
||||
//!
|
||||
//! The transport layer handles framing (TCP's 4-byte length-prefix header)
|
||||
//! and provides split send/receive traits to avoid deadlocks in the
|
||||
//! pipeline's `tokio::select!` loop.
|
||||
//!
|
||||
//! Two implementations are provided:
|
||||
//! - [`TcpTransport`] -- direct TCP connection to an SMB server (port 445)
|
||||
//! - [`MockTransport`] -- canned responses for testing
|
||||
//!
|
||||
//! Most users don't need this module directly -- use [`SmbClient`](crate::SmbClient)
|
||||
//! which handles transport setup internally.
|
||||
|
||||
pub mod mock;
|
||||
pub mod tcp;
|
||||
|
||||
pub use mock::MockTransport;
|
||||
pub use tcp::TcpTransport;
|
||||
|
||||
use crate::error::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Send half of a transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportSend: Send + Sync {
|
||||
/// Send a complete SMB2 message (the implementation adds framing).
|
||||
async fn send(&self, data: &[u8]) -> Result<()>;
|
||||
}
|
||||
|
||||
/// Receive half of a transport connection.
|
||||
#[async_trait]
|
||||
pub trait TransportReceive: Send + Sync {
|
||||
/// Receive one complete SMB2 transport frame.
|
||||
///
|
||||
/// The implementation handles the TCP framing (4-byte header:
|
||||
/// 1 zero byte + 3-byte big-endian length). The returned buffer
|
||||
/// contains the SMB2 message(s) without the framing header.
|
||||
///
|
||||
/// The buffer may contain multiple compounded responses linked
|
||||
/// by NextCommand in the SMB2 headers -- the caller must split them.
|
||||
async fn receive(&self) -> Result<Vec<u8>>;
|
||||
}
|
||||
|
||||
/// A combined transport that can both send and receive.
|
||||
pub trait Transport: TransportSend + TransportReceive {}
|
||||
|
||||
// Blanket implementation: anything that implements both halves is a Transport.
|
||||
impl<T: TransportSend + TransportReceive> Transport for T {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::msg::header::{Header, PROTOCOL_ID};
|
||||
use crate::msg::negotiate::{
|
||||
NegotiateContext, NegotiateRequest, NegotiateResponse, HASH_ALGORITHM_SHA512,
|
||||
};
|
||||
use crate::pack::{Guid, Pack, ReadCursor, Unpack, WriteCursor};
|
||||
use crate::types::flags::{Capabilities, SecurityMode};
|
||||
use crate::types::{Command, Dialect};
|
||||
|
||||
/// Pack a header + body into raw SMB2 message bytes (no transport framing).
|
||||
fn pack_message(header: &Header, body: &dyn Pack) -> Vec<u8> {
|
||||
let mut cursor = WriteCursor::new();
|
||||
header.pack(&mut cursor);
|
||||
body.pack(&mut cursor);
|
||||
cursor.into_inner()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cross_module_negotiate_via_mock_transport() {
|
||||
// Build a NegotiateRequest, send it through MockTransport,
|
||||
// receive a canned NegotiateResponse, and verify unpacking.
|
||||
|
||||
let mock = MockTransport::new();
|
||||
|
||||
// Build a negotiate request.
|
||||
let req_header = Header::new_request(Command::Negotiate);
|
||||
let req_body = NegotiateRequest {
|
||||
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
|
||||
capabilities: Capabilities::default(),
|
||||
client_guid: Guid {
|
||||
data1: 0xDEAD_BEEF,
|
||||
data2: 0xCAFE,
|
||||
data3: 0xF00D,
|
||||
data4: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08],
|
||||
},
|
||||
dialects: vec![Dialect::Smb2_0_2, Dialect::Smb2_1, Dialect::Smb3_1_1],
|
||||
negotiate_contexts: vec![NegotiateContext::PreauthIntegrity {
|
||||
hash_algorithms: vec![HASH_ALGORITHM_SHA512],
|
||||
salt: vec![0xAA; 32],
|
||||
}],
|
||||
};
|
||||
let req_msg = pack_message(&req_header, &req_body);
|
||||
|
||||
// Build a canned NegotiateResponse.
|
||||
let resp_header = {
|
||||
let mut h = Header::new_request(Command::Negotiate);
|
||||
h.flags.set_response();
|
||||
h.credits = 1;
|
||||
h
|
||||
};
|
||||
let resp_body = NegotiateResponse {
|
||||
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
|
||||
dialect_revision: Dialect::Smb3_1_1,
|
||||
server_guid: Guid {
|
||||
data1: 0x1111_2222,
|
||||
data2: 0x3333,
|
||||
data3: 0x4444,
|
||||
data4: [0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC],
|
||||
},
|
||||
capabilities: Capabilities::new(Capabilities::DFS | Capabilities::LEASING),
|
||||
max_transact_size: 65536,
|
||||
max_read_size: 65536,
|
||||
max_write_size: 65536,
|
||||
system_time: 132_000_000_000_000_000,
|
||||
server_start_time: 131_000_000_000_000_000,
|
||||
security_buffer: vec![0x60, 0x00], // minimal placeholder
|
||||
negotiate_contexts: vec![NegotiateContext::PreauthIntegrity {
|
||||
hash_algorithms: vec![HASH_ALGORITHM_SHA512],
|
||||
salt: vec![0xBB; 32],
|
||||
}],
|
||||
};
|
||||
let resp_msg = pack_message(&resp_header, &resp_body);
|
||||
|
||||
// Queue the canned response.
|
||||
mock.queue_response(resp_msg);
|
||||
|
||||
// Send the request through the mock.
|
||||
mock.send(&req_msg).await.unwrap();
|
||||
|
||||
// Receive the canned response.
|
||||
let received = mock.receive().await.unwrap();
|
||||
|
||||
// Unpack and verify.
|
||||
let mut cursor = ReadCursor::new(&received);
|
||||
let hdr = Header::unpack(&mut cursor).unwrap();
|
||||
assert!(hdr.is_response());
|
||||
assert_eq!(hdr.command, Command::Negotiate);
|
||||
|
||||
let body = NegotiateResponse::unpack(&mut cursor).unwrap();
|
||||
assert_eq!(body.dialect_revision, Dialect::Smb3_1_1);
|
||||
assert_eq!(body.max_read_size, 65536);
|
||||
assert!(body.security_mode.signing_enabled());
|
||||
|
||||
// Verify the request was recorded.
|
||||
assert_eq!(mock.sent_count(), 1);
|
||||
let sent = mock.sent_message(0).unwrap();
|
||||
|
||||
// Verify we can unpack what was sent.
|
||||
let mut cursor = ReadCursor::new(&sent);
|
||||
let sent_hdr = Header::unpack(&mut cursor).unwrap();
|
||||
assert_eq!(sent_hdr.command, Command::Negotiate);
|
||||
assert!(!sent_hdr.is_response());
|
||||
|
||||
let sent_body = NegotiateRequest::unpack(&mut cursor).unwrap();
|
||||
assert_eq!(sent_body.dialects.len(), 3);
|
||||
assert!(sent_body.dialects.contains(&Dialect::Smb3_1_1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Requires NAS at 192.168.1.111
|
||||
async fn negotiate_via_tcp_transport() {
|
||||
use std::time::Duration;
|
||||
|
||||
let transport = TcpTransport::connect("192.168.1.111:445", Duration::from_secs(5))
|
||||
.await
|
||||
.expect("failed to connect to NAS");
|
||||
|
||||
// Build a negotiate request.
|
||||
let header = Header::new_request(Command::Negotiate);
|
||||
let request = NegotiateRequest {
|
||||
security_mode: SecurityMode::new(SecurityMode::SIGNING_ENABLED),
|
||||
capabilities: Capabilities::new(
|
||||
Capabilities::DFS | Capabilities::LEASING | Capabilities::LARGE_MTU,
|
||||
),
|
||||
client_guid: Guid {
|
||||
data1: 0xDEAD_BEEF,
|
||||
data2: 0xCAFE,
|
||||
data3: 0xF00D,
|
||||
data4: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08],
|
||||
},
|
||||
dialects: vec![
|
||||
Dialect::Smb2_0_2,
|
||||
Dialect::Smb2_1,
|
||||
Dialect::Smb3_0,
|
||||
Dialect::Smb3_0_2,
|
||||
Dialect::Smb3_1_1,
|
||||
],
|
||||
negotiate_contexts: vec![NegotiateContext::PreauthIntegrity {
|
||||
hash_algorithms: vec![HASH_ALGORITHM_SHA512],
|
||||
salt: vec![0xAA; 32],
|
||||
}],
|
||||
};
|
||||
|
||||
let msg = pack_message(&header, &request);
|
||||
|
||||
// Send through transport (framing added automatically).
|
||||
transport.send(&msg).await.unwrap();
|
||||
|
||||
// Receive response (framing stripped automatically).
|
||||
let resp_bytes = transport.receive().await.unwrap();
|
||||
|
||||
// Verify we got a valid response.
|
||||
assert!(resp_bytes[0..4] == PROTOCOL_ID);
|
||||
|
||||
let mut cursor = ReadCursor::new(&resp_bytes);
|
||||
let resp_header = Header::unpack(&mut cursor).unwrap();
|
||||
assert!(resp_header.is_response());
|
||||
assert_eq!(resp_header.command, Command::Negotiate);
|
||||
|
||||
let resp_body = NegotiateResponse::unpack(&mut cursor).unwrap();
|
||||
assert!(Dialect::ALL.contains(&resp_body.dialect_revision));
|
||||
assert!(resp_body.max_read_size >= 65536);
|
||||
}
|
||||
}
|
||||
485
vendor/smb2/src/transport/tcp.rs
vendored
Normal file
485
vendor/smb2/src/transport/tcp.rs
vendored
Normal file
@@ -0,0 +1,485 @@
|
||||
//! Direct TCP transport for SMB2 (port 445).
|
||||
//!
|
||||
//! Implements the SMB2 transport framing defined in MS-SMB2 section 2.1:
|
||||
//! each message is preceded by a 4-byte header consisting of 1 zero byte
|
||||
//! followed by 3 bytes of big-endian length. This is the ONLY big-endian
|
||||
//! encoding in the entire SMB2 protocol.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use log::{debug, error, trace};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::{TcpStream, ToSocketAddrs};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::transport::{TransportReceive, TransportSend};
|
||||
|
||||
/// Maximum frame size we accept (16 MB).
|
||||
///
|
||||
/// Prevents denial-of-service from corrupt or malicious length fields.
|
||||
/// Real SMB2 messages are typically much smaller (the largest negotiated
|
||||
/// MaxReadSize/MaxWriteSize is usually 8 MB).
|
||||
const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
|
||||
|
||||
/// Direct TCP transport for SMB2.
|
||||
///
|
||||
/// Wraps a TCP connection and handles the 4-byte framing header.
|
||||
/// The connection is split into independent read and write halves
|
||||
/// so that send and receive can proceed concurrently without contention
|
||||
/// (required by the pipeline's `tokio::select!` loop).
|
||||
#[derive(Debug)]
|
||||
pub struct TcpTransport {
|
||||
/// The read half of the TCP connection, behind a mutex for `&self` access.
|
||||
reader: Mutex<OwnedReadHalf>,
|
||||
/// The write half of the TCP connection, behind a mutex for `&self` access.
|
||||
writer: Mutex<OwnedWriteHalf>,
|
||||
}
|
||||
|
||||
impl TcpTransport {
|
||||
/// Connect to an SMB server over TCP.
|
||||
///
|
||||
/// Applies the given timeout to the connection attempt. Once connected,
|
||||
/// the socket is split into independent read/write halves.
|
||||
pub async fn connect(addr: impl ToSocketAddrs, timeout: Duration) -> Result<Self> {
|
||||
let stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
|
||||
.await
|
||||
.map_err(|_| Error::Timeout)?
|
||||
.map_err(Error::Io)?;
|
||||
|
||||
// Disable Nagle's algorithm for lower latency on small messages.
|
||||
stream.set_nodelay(true).map_err(Error::Io)?;
|
||||
|
||||
debug!("tcp: connected, nodelay=true");
|
||||
let (reader, writer) = stream.into_split();
|
||||
|
||||
Ok(Self {
|
||||
reader: Mutex::new(reader),
|
||||
writer: Mutex::new(writer),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportSend for TcpTransport {
|
||||
async fn send(&self, data: &[u8]) -> Result<()> {
|
||||
let len = data.len();
|
||||
if len > MAX_FRAME_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"message size {} exceeds maximum frame size {}",
|
||||
len, MAX_FRAME_SIZE
|
||||
)));
|
||||
}
|
||||
|
||||
// Build the 4-byte framing header: 0x00 + 3-byte BE length.
|
||||
let mut frame_header = [0u8; 4];
|
||||
frame_header[0] = 0x00;
|
||||
frame_header[1] = (len >> 16) as u8;
|
||||
frame_header[2] = (len >> 8) as u8;
|
||||
frame_header[3] = len as u8;
|
||||
|
||||
let mut writer = self.writer.lock().await;
|
||||
writer.write_all(&frame_header).await.map_err(Error::Io)?;
|
||||
writer.write_all(data).await.map_err(Error::Io)?;
|
||||
writer.flush().await.map_err(Error::Io)?;
|
||||
|
||||
trace!("tcp: sent frame, len={}", len);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TransportReceive for TcpTransport {
|
||||
async fn receive(&self) -> Result<Vec<u8>> {
|
||||
let mut reader = self.reader.lock().await;
|
||||
|
||||
// Read the 4-byte framing header.
|
||||
let mut frame_header = [0u8; 4];
|
||||
reader.read_exact(&mut frame_header).await.map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||||
Error::Disconnected
|
||||
} else {
|
||||
Error::Io(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
// Validate the first byte is 0x00.
|
||||
if frame_header[0] != 0x00 {
|
||||
error!("tcp: invalid frame, first byte=0x{:02X}", frame_header[0]);
|
||||
return Err(Error::invalid_data(format!(
|
||||
"invalid transport frame: first byte must be 0x00, got 0x{:02X}",
|
||||
frame_header[0]
|
||||
)));
|
||||
}
|
||||
|
||||
// Extract the 3-byte big-endian length.
|
||||
let msg_len = ((frame_header[1] as usize) << 16)
|
||||
| ((frame_header[2] as usize) << 8)
|
||||
| (frame_header[3] as usize);
|
||||
|
||||
// Validate against the maximum frame size.
|
||||
if msg_len > MAX_FRAME_SIZE {
|
||||
return Err(Error::invalid_data(format!(
|
||||
"frame length {} exceeds maximum {}",
|
||||
msg_len, MAX_FRAME_SIZE
|
||||
)));
|
||||
}
|
||||
|
||||
trace!("tcp: receiving frame, len={}", msg_len);
|
||||
|
||||
// Read the message body.
|
||||
let mut buf = vec![0u8; msg_len];
|
||||
reader.read_exact(&mut buf).await.map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||||
Error::Disconnected
|
||||
} else {
|
||||
Error::Io(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
trace!("tcp: received frame, len={}", msg_len);
|
||||
Ok(buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Build a framed message (4-byte header + payload).
|
||||
fn frame_message(payload: &[u8]) -> Vec<u8> {
|
||||
let len = payload.len();
|
||||
let mut frame = Vec::with_capacity(4 + len);
|
||||
frame.push(0x00);
|
||||
frame.push((len >> 16) as u8);
|
||||
frame.push((len >> 8) as u8);
|
||||
frame.push(len as u8);
|
||||
frame.extend_from_slice(payload);
|
||||
frame
|
||||
}
|
||||
|
||||
// ── Send framing tests ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn frame_header_format_small_message() {
|
||||
let payload = vec![0xFE, 0x53, 0x4D, 0x42]; // "SMB2 magic"
|
||||
let framed = frame_message(&payload);
|
||||
|
||||
// Header: [0x00, 0x00, 0x00, 0x04]
|
||||
assert_eq!(framed[0], 0x00, "first byte must be 0x00");
|
||||
assert_eq!(framed[1], 0x00, "length high byte");
|
||||
assert_eq!(framed[2], 0x00, "length mid byte");
|
||||
assert_eq!(framed[3], 0x04, "length low byte = 4");
|
||||
assert_eq!(&framed[4..], &payload);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_header_format_medium_message() {
|
||||
// 300 bytes -> 0x00, 0x00, 0x01, 0x2C
|
||||
let payload = vec![0xAA; 300];
|
||||
let framed = frame_message(&payload);
|
||||
|
||||
assert_eq!(framed[0], 0x00);
|
||||
assert_eq!(framed[1], 0x00);
|
||||
assert_eq!(framed[2], 0x01);
|
||||
assert_eq!(framed[3], 0x2C);
|
||||
assert_eq!(framed.len(), 304);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_header_format_large_message() {
|
||||
// 0x010203 = 66051 bytes
|
||||
let payload = vec![0xBB; 66051];
|
||||
let framed = frame_message(&payload);
|
||||
|
||||
assert_eq!(framed[0], 0x00);
|
||||
assert_eq!(framed[1], 0x01);
|
||||
assert_eq!(framed[2], 0x02);
|
||||
assert_eq!(framed[3], 0x03);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn frame_header_empty_payload() {
|
||||
let framed = frame_message(&[]);
|
||||
assert_eq!(framed, vec![0x00, 0x00, 0x00, 0x00]);
|
||||
}
|
||||
|
||||
// ── Receive framing tests (using tokio_test-style mock streams) ──
|
||||
|
||||
/// A helper that creates a pair of connected streams via a TCP listener
|
||||
/// on localhost, then writes data to one side and reads from the other.
|
||||
async fn receive_from_bytes(data: &[u8]) -> Result<Vec<u8>> {
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let data = data.to_vec();
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let mut stream = TcpStream::connect(addr).await.unwrap();
|
||||
stream.write_all(&data).await.unwrap();
|
||||
stream.shutdown().await.unwrap();
|
||||
});
|
||||
|
||||
let (stream, _) = listener.accept().await.unwrap();
|
||||
let (reader, writer) = stream.into_split();
|
||||
let transport = TcpTransport {
|
||||
reader: Mutex::new(reader),
|
||||
writer: Mutex::new(writer),
|
||||
};
|
||||
|
||||
let result = transport.receive().await;
|
||||
writer_task.await.unwrap();
|
||||
result
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_valid_frame() {
|
||||
let payload = vec![0xFE, 0x53, 0x4D, 0x42, 0x01, 0x02];
|
||||
let framed = frame_message(&payload);
|
||||
|
||||
let received = receive_from_bytes(&framed).await.unwrap();
|
||||
assert_eq!(received, payload);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_empty_payload() {
|
||||
let framed = frame_message(&[]);
|
||||
let received = receive_from_bytes(&framed).await.unwrap();
|
||||
assert!(received.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_first_byte_not_zero_returns_error() {
|
||||
// First byte is 0x01 instead of 0x00.
|
||||
let data = vec![0x01, 0x00, 0x00, 0x04, 0xAA, 0xBB, 0xCC, 0xDD];
|
||||
|
||||
let result = receive_from_bytes(&data).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("first byte must be 0x00"),
|
||||
"unexpected error: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_length_exceeds_max_returns_error() {
|
||||
// Length = 0xFFFFFF = 16777215 > MAX_FRAME_SIZE (16 * 1024 * 1024 = 16777216)
|
||||
// Wait, 0xFFFFFF = 16777215 < 16777216. Let's use a length just over.
|
||||
// MAX_FRAME_SIZE = 16 * 1024 * 1024 = 16_777_216
|
||||
// We need > 16_777_216, but max 3-byte value is 16_777_215.
|
||||
// So 3 bytes can't exceed 16 MB. But the spec says 16 MB is the max.
|
||||
// Let's set MAX_FRAME_SIZE to slightly less, or test at the boundary.
|
||||
// Actually MAX_FRAME_SIZE = 16 * 1024 * 1024 = 16_777_216.
|
||||
// Max 3-byte value = 0xFFFFFF = 16_777_215 which is < MAX_FRAME_SIZE.
|
||||
// So a 3-byte length can never exceed our MAX_FRAME_SIZE.
|
||||
// This test verifies that the max 3-byte value IS accepted (no error).
|
||||
// But what if someone sends a broken frame? The first byte check
|
||||
// catches that. For the length check specifically, we'd need a
|
||||
// smaller MAX_FRAME_SIZE to exercise the branch. For now, let's test
|
||||
// with an internal test. The important thing is the check exists.
|
||||
|
||||
// Actually, the more realistic concern is a malicious server sending
|
||||
// large values. 0xFFFFFF = ~16 MB is fine by our limit. Let's verify
|
||||
// the boundary: 0xFFFFFF should be accepted because 16_777_215 < 16_777_216.
|
||||
// We can't test > MAX_FRAME_SIZE with only 3 bytes, but the check
|
||||
// is there for defense-in-depth (the first byte could be non-zero
|
||||
// and interpreted as part of length if we didn't validate it).
|
||||
|
||||
// Let's test a frame with length 0xFFFFFF but not enough payload data,
|
||||
// which should return Disconnected (not a crash from huge allocation).
|
||||
let data = vec![0x00, 0xFF, 0xFF, 0xFF]; // Length = 16_777_215 bytes, no payload.
|
||||
|
||||
let result = receive_from_bytes(&data).await;
|
||||
assert!(result.is_err());
|
||||
// Should get Disconnected because the payload read fails.
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, Error::Disconnected),
|
||||
"expected Disconnected for truncated large frame, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_disconnected_on_eof() {
|
||||
// Empty data = immediate EOF.
|
||||
let result = receive_from_bytes(&[]).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, Error::Disconnected),
|
||||
"expected Disconnected, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_partial_header_returns_disconnected() {
|
||||
// Only 2 bytes of the 4-byte header.
|
||||
let result = receive_from_bytes(&[0x00, 0x00]).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, Error::Disconnected),
|
||||
"expected Disconnected for partial header, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn receive_partial_payload_returns_disconnected() {
|
||||
// Header says 10 bytes, but only 3 bytes of payload follow.
|
||||
let data = vec![0x00, 0x00, 0x00, 0x0A, 0x01, 0x02, 0x03];
|
||||
|
||||
let result = receive_from_bytes(&data).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(
|
||||
matches!(err, Error::Disconnected),
|
||||
"expected Disconnected for truncated payload, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_and_receive_roundtrip() {
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let send_task = tokio::spawn(async move {
|
||||
let stream = TcpStream::connect(addr).await.unwrap();
|
||||
let (reader, writer) = stream.into_split();
|
||||
let transport = TcpTransport {
|
||||
reader: Mutex::new(reader),
|
||||
writer: Mutex::new(writer),
|
||||
};
|
||||
|
||||
let payload = vec![0xFE, 0x53, 0x4D, 0x42, 0xDE, 0xAD];
|
||||
transport.send(&payload).await.unwrap();
|
||||
});
|
||||
|
||||
let (stream, _) = listener.accept().await.unwrap();
|
||||
let (reader, writer) = stream.into_split();
|
||||
let recv_transport = TcpTransport {
|
||||
reader: Mutex::new(reader),
|
||||
writer: Mutex::new(writer),
|
||||
};
|
||||
|
||||
let received = recv_transport.receive().await.unwrap();
|
||||
assert_eq!(received, vec![0xFE, 0x53, 0x4D, 0x42, 0xDE, 0xAD]);
|
||||
|
||||
send_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_and_receive_multiple_messages() {
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let send_task = tokio::spawn(async move {
|
||||
let stream = TcpStream::connect(addr).await.unwrap();
|
||||
let (reader, writer) = stream.into_split();
|
||||
let transport = TcpTransport {
|
||||
reader: Mutex::new(reader),
|
||||
writer: Mutex::new(writer),
|
||||
};
|
||||
|
||||
transport.send(&[0x01, 0x02]).await.unwrap();
|
||||
transport.send(&[0x03, 0x04, 0x05]).await.unwrap();
|
||||
transport.send(&[0x06]).await.unwrap();
|
||||
});
|
||||
|
||||
let (stream, _) = listener.accept().await.unwrap();
|
||||
let (reader, writer) = stream.into_split();
|
||||
let recv_transport = TcpTransport {
|
||||
reader: Mutex::new(reader),
|
||||
writer: Mutex::new(writer),
|
||||
};
|
||||
|
||||
assert_eq!(recv_transport.receive().await.unwrap(), vec![0x01, 0x02]);
|
||||
assert_eq!(
|
||||
recv_transport.receive().await.unwrap(),
|
||||
vec![0x03, 0x04, 0x05]
|
||||
);
|
||||
assert_eq!(recv_transport.receive().await.unwrap(), vec![0x06]);
|
||||
|
||||
send_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn partial_reads_are_handled_by_read_exact() {
|
||||
// This test exercises the read_exact behavior by sending data
|
||||
// through a real TCP connection. Under the hood, TCP may deliver
|
||||
// data in arbitrary chunk sizes, especially with Nagle disabled.
|
||||
// While we can't force byte-at-a-time delivery reliably, we
|
||||
// verify correctness with a larger payload that's more likely
|
||||
// to arrive in multiple reads.
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let payload: Vec<u8> = (0..=255).cycle().take(8192).collect();
|
||||
let payload_clone = payload.clone();
|
||||
|
||||
let send_task = tokio::spawn(async move {
|
||||
let stream = TcpStream::connect(addr).await.unwrap();
|
||||
let (reader, writer) = stream.into_split();
|
||||
let transport = TcpTransport {
|
||||
reader: Mutex::new(reader),
|
||||
writer: Mutex::new(writer),
|
||||
};
|
||||
|
||||
transport.send(&payload_clone).await.unwrap();
|
||||
});
|
||||
|
||||
let (stream, _) = listener.accept().await.unwrap();
|
||||
let (reader, writer) = stream.into_split();
|
||||
let recv_transport = TcpTransport {
|
||||
reader: Mutex::new(reader),
|
||||
writer: Mutex::new(writer),
|
||||
};
|
||||
|
||||
let received = recv_transport.receive().await.unwrap();
|
||||
assert_eq!(received.len(), payload.len());
|
||||
assert_eq!(received, payload);
|
||||
|
||||
send_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_with_timeout() {
|
||||
// Connect to localhost listener with a generous timeout.
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let transport = TcpTransport::connect(addr, Duration::from_secs(5))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Accept the connection on the server side.
|
||||
let (server_stream, _) = listener.accept().await.unwrap();
|
||||
let (server_reader, mut server_writer) = server_stream.into_split();
|
||||
drop(server_reader);
|
||||
|
||||
// Send a framed message from the "server" side.
|
||||
let payload = vec![0xDE, 0xAD, 0xBE, 0xEF];
|
||||
let mut frame = vec![0x00, 0x00, 0x00, 0x04];
|
||||
frame.extend_from_slice(&payload);
|
||||
server_writer.write_all(&frame).await.unwrap();
|
||||
server_writer.flush().await.unwrap();
|
||||
|
||||
// Receive through the transport.
|
||||
let received = transport.receive().await.unwrap();
|
||||
assert_eq!(received, payload);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_timeout_fires() {
|
||||
// Try to connect to a non-routable address. This should time out.
|
||||
// 192.0.2.1 is a TEST-NET address (RFC 5737) that should be unreachable.
|
||||
let result = TcpTransport::connect("192.0.2.1:445", Duration::from_millis(100)).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
// Could be Timeout or Io depending on OS behavior.
|
||||
assert!(
|
||||
matches!(err, Error::Timeout | Error::Io(_)),
|
||||
"expected Timeout or Io error, got: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
38
vendor/smb2/src/types/CLAUDE.md
vendored
Normal file
38
vendor/smb2/src/types/CLAUDE.md
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
# Types -- protocol newtypes and enums
|
||||
|
||||
Zero-cost newtype wrappers for protocol IDs, command/dialect enums, and bitflag types.
|
||||
|
||||
## Key files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `mod.rs` | `SessionId`, `TreeId`, `FileId`, `MessageId`, `CreditCharge`, `Command`, `Dialect`, `OplockLevel` |
|
||||
| `flags.rs` | Bitflag types: `HeaderFlags`, `Capabilities`, `SecurityMode`, `FileAccessMask`, etc. |
|
||||
| `status.rs` | `NtStatus` enum (from MS-ERREF) with severity/facility helpers |
|
||||
|
||||
## Newtype IDs
|
||||
|
||||
All protocol IDs are newtypes over their raw integer:
|
||||
- `SessionId(u64)` -- has `NONE` sentinel (0)
|
||||
- `MessageId(u64)` -- has `UNSOLICITED` sentinel (0xFFFFFFFFFFFFFFFF) for oplock breaks
|
||||
- `TreeId(u32)`
|
||||
- `CreditCharge(u16)`
|
||||
- `FileId { persistent: u64, volatile: u64 }` -- has `SENTINEL` (all-F's) for compound related requests
|
||||
|
||||
All implement `Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`, `Hash`, `Display`.
|
||||
|
||||
## Command and Dialect enums
|
||||
|
||||
- `Command`: 19 variants (Negotiate through OplockBreak), `repr(u16)`, uses `num_enum` for `TryFrom<u16>`/`Into<u16>`
|
||||
- `Dialect`: 5 variants (2.0.2 through 3.1.1), `repr(u16)`, ordered (`PartialOrd`/`Ord`). `Dialect::ALL` is a sorted slice.
|
||||
|
||||
## Key decisions
|
||||
|
||||
- **Newtypes over raw u32/u64**: Prevents accidentally passing a TreeId where a SessionId is expected. Zero runtime cost.
|
||||
- **`num_enum` for command/dialect**: Avoids manual match arms for TryFrom. Compile-time checked exhaustive conversions.
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **`MORE_PROCESSING_REQUIRED` has error severity bits but isn't an error**: `NtStatus` severity is encoded in bits 30-31. `MORE_PROCESSING_REQUIRED` (0xC0000016) has severity=3 (error), but it's a normal part of the session setup flow. Use `is_more_processing_required()` instead of checking `is_error()`.
|
||||
- **`STATUS_BUFFER_OVERFLOW` is a warning, not an error**: Returns valid partial data. Don't discard the response body.
|
||||
- **FileId::SENTINEL vs FileId::default()**: SENTINEL is all-F's (used in compound requests). Default is all-zeros (unused). Don't mix them up.
|
||||
465
vendor/smb2/src/types/flags.rs
vendored
Normal file
465
vendor/smb2/src/types/flags.rs
vendored
Normal file
@@ -0,0 +1,465 @@
|
||||
//! Bitflag types for SMB2/3 protocol fields.
|
||||
|
||||
use std::ops::{BitAnd, BitOr, BitOrAssign};
|
||||
|
||||
// ── Macro to reduce boilerplate for flag types ──────────────────────────
|
||||
|
||||
macro_rules! impl_flags {
|
||||
($name:ident, $inner:ty) => {
|
||||
impl $name {
|
||||
/// Create a new flags value from a raw integer.
|
||||
#[inline]
|
||||
pub const fn new(raw: $inner) -> Self {
|
||||
Self(raw)
|
||||
}
|
||||
|
||||
/// Return the raw bits.
|
||||
#[inline]
|
||||
pub const fn bits(&self) -> $inner {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Check whether a particular flag bit is set.
|
||||
#[inline]
|
||||
pub const fn contains(&self, flag: $inner) -> bool {
|
||||
self.0 & flag == flag
|
||||
}
|
||||
|
||||
/// Set a flag bit.
|
||||
#[inline]
|
||||
pub fn set(&mut self, flag: $inner) {
|
||||
self.0 |= flag;
|
||||
}
|
||||
|
||||
/// Clear a flag bit.
|
||||
#[inline]
|
||||
pub fn clear(&mut self, flag: $inner) {
|
||||
self.0 &= !flag;
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOr for $name {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn bitor(self, rhs: Self) -> Self {
|
||||
Self(self.0 | rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl BitAnd for $name {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn bitand(self, rhs: Self) -> Self {
|
||||
Self(self.0 & rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOrAssign for $name {
|
||||
#[inline]
|
||||
fn bitor_assign(&mut self, rhs: Self) {
|
||||
self.0 |= rhs.0;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// ── HeaderFlags ─────────────────────────────────────────────────────────
|
||||
|
||||
/// SMB2 packet header flags (32-bit field from MS-SMB2 2.2.1).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct HeaderFlags(pub u32);
|
||||
|
||||
impl HeaderFlags {
|
||||
/// The message is a response rather than a request.
|
||||
pub const SERVER_TO_REDIR: u32 = 0x0000_0001;
|
||||
/// The message is an async SMB2 header.
|
||||
pub const ASYNC_COMMAND: u32 = 0x0000_0002;
|
||||
/// The message is part of a compounded chain.
|
||||
pub const RELATED_OPERATIONS: u32 = 0x0000_0004;
|
||||
/// The message is signed.
|
||||
pub const SIGNED: u32 = 0x0000_0008;
|
||||
/// Priority value mask (SMB 3.1.1).
|
||||
pub const PRIORITY_MASK: u32 = 0x0000_0070;
|
||||
/// The command is a DFS operation.
|
||||
pub const DFS_OPERATIONS: u32 = 0x1000_0000;
|
||||
/// The command is a replay operation (SMB 3.x).
|
||||
pub const REPLAY_OPERATION: u32 = 0x2000_0000;
|
||||
|
||||
/// Returns `true` if this is a response (server-to-redirector).
|
||||
#[inline]
|
||||
pub fn is_response(&self) -> bool {
|
||||
self.contains(Self::SERVER_TO_REDIR)
|
||||
}
|
||||
|
||||
/// Returns `true` if the async flag is set.
|
||||
#[inline]
|
||||
pub fn is_async(&self) -> bool {
|
||||
self.contains(Self::ASYNC_COMMAND)
|
||||
}
|
||||
|
||||
/// Returns `true` if the related-operations flag is set.
|
||||
#[inline]
|
||||
pub fn is_related(&self) -> bool {
|
||||
self.contains(Self::RELATED_OPERATIONS)
|
||||
}
|
||||
|
||||
/// Returns `true` if the signed flag is set.
|
||||
#[inline]
|
||||
pub fn is_signed(&self) -> bool {
|
||||
self.contains(Self::SIGNED)
|
||||
}
|
||||
|
||||
/// Set the response flag.
|
||||
#[inline]
|
||||
pub fn set_response(&mut self) {
|
||||
self.set(Self::SERVER_TO_REDIR);
|
||||
}
|
||||
|
||||
/// Set the async flag.
|
||||
#[inline]
|
||||
pub fn set_async(&mut self) {
|
||||
self.set(Self::ASYNC_COMMAND);
|
||||
}
|
||||
|
||||
/// Set the related-operations flag.
|
||||
#[inline]
|
||||
pub fn set_related(&mut self) {
|
||||
self.set(Self::RELATED_OPERATIONS);
|
||||
}
|
||||
|
||||
/// Set the signed flag.
|
||||
#[inline]
|
||||
pub fn set_signed(&mut self) {
|
||||
self.set(Self::SIGNED);
|
||||
}
|
||||
}
|
||||
|
||||
impl_flags!(HeaderFlags, u32);
|
||||
|
||||
// ── SecurityMode ────────────────────────────────────────────────────────
|
||||
|
||||
/// Security mode flags (16-bit field from MS-SMB2 2.2.3/2.2.4).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct SecurityMode(pub u16);
|
||||
|
||||
impl SecurityMode {
|
||||
/// Signing is supported (enabled).
|
||||
pub const SIGNING_ENABLED: u16 = 0x0001;
|
||||
/// Signing is required.
|
||||
pub const SIGNING_REQUIRED: u16 = 0x0002;
|
||||
|
||||
/// Returns `true` if signing is enabled.
|
||||
#[inline]
|
||||
pub fn signing_enabled(&self) -> bool {
|
||||
self.contains(Self::SIGNING_ENABLED)
|
||||
}
|
||||
|
||||
/// Returns `true` if signing is required.
|
||||
#[inline]
|
||||
pub fn signing_required(&self) -> bool {
|
||||
self.contains(Self::SIGNING_REQUIRED)
|
||||
}
|
||||
}
|
||||
|
||||
impl_flags!(SecurityMode, u16);
|
||||
|
||||
// ── Capabilities ────────────────────────────────────────────────────────
|
||||
|
||||
/// Server/client capability flags (32-bit field from MS-SMB2 2.2.3/2.2.4).
|
||||
///
|
||||
/// With the `serde` feature on, this serializes as the underlying `u32`
|
||||
/// bits, **not** a JSON object of named flags. Decode against the
|
||||
/// associated constants on this type.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct Capabilities(pub u32);
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
impl serde::Serialize for Capabilities {
|
||||
fn serialize<S: serde::Serializer>(&self, s: S) -> std::result::Result<S::Ok, S::Error> {
|
||||
s.serialize_u32(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Capabilities {
|
||||
/// Distributed File System (DFS) support.
|
||||
pub const DFS: u32 = 0x0000_0001;
|
||||
/// Leasing support.
|
||||
pub const LEASING: u32 = 0x0000_0002;
|
||||
/// Multi-credit (large MTU) support.
|
||||
pub const LARGE_MTU: u32 = 0x0000_0004;
|
||||
/// Multi-channel support.
|
||||
pub const MULTI_CHANNEL: u32 = 0x0000_0008;
|
||||
/// Persistent handle support.
|
||||
pub const PERSISTENT_HANDLES: u32 = 0x0000_0010;
|
||||
/// Directory leasing support.
|
||||
pub const DIRECTORY_LEASING: u32 = 0x0000_0020;
|
||||
/// Encryption support.
|
||||
pub const ENCRYPTION: u32 = 0x0000_0040;
|
||||
}
|
||||
|
||||
impl_flags!(Capabilities, u32);
|
||||
|
||||
// ── ShareFlags ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Share property flags (32-bit field from MS-SMB2 2.2.10).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct ShareFlags(pub u32);
|
||||
|
||||
impl ShareFlags {
|
||||
/// The share is in a DFS tree structure.
|
||||
pub const DFS: u32 = 0x0000_0001;
|
||||
/// The share is a DFS root.
|
||||
pub const DFS_ROOT: u32 = 0x0000_0002;
|
||||
|
||||
// Offline caching policies (mutually exclusive, stored in bits 4-5).
|
||||
|
||||
/// The client can cache files explicitly selected by the user.
|
||||
pub const MANUAL_CACHING: u32 = 0x0000_0000;
|
||||
/// The client can automatically cache files used by the user.
|
||||
pub const AUTO_CACHING: u32 = 0x0000_0010;
|
||||
/// Auto-cache with offline mode even when the share is available.
|
||||
pub const VDO_CACHING: u32 = 0x0000_0020;
|
||||
/// Offline caching must not occur.
|
||||
pub const NO_CACHING: u32 = 0x0000_0030;
|
||||
|
||||
/// Disallows exclusive file opens that deny reads.
|
||||
pub const RESTRICT_EXCLUSIVE_OPENS: u32 = 0x0000_0100;
|
||||
/// Disallows exclusive opens that prevent deletion.
|
||||
pub const FORCE_SHARED_DELETE: u32 = 0x0000_0200;
|
||||
/// Allow namespace caching (client must ignore).
|
||||
pub const ALLOW_NAMESPACE_CACHING: u32 = 0x0000_0400;
|
||||
/// Server filters directory entries based on access permissions.
|
||||
pub const ACCESS_BASED_DIRECTORY_ENUM: u32 = 0x0000_0800;
|
||||
/// Server will not issue exclusive caching rights.
|
||||
pub const FORCE_LEVELII_OPLOCK: u32 = 0x0000_1000;
|
||||
/// Hash generation v1 for branch cache (not valid for SMB 2.0.2).
|
||||
pub const ENABLE_HASH_V1: u32 = 0x0000_2000;
|
||||
/// Hash generation v2 for branch cache.
|
||||
pub const ENABLE_HASH_V2: u32 = 0x0000_4000;
|
||||
/// Encryption of remote file access messages required (SMB 3.x).
|
||||
pub const ENCRYPT_DATA: u32 = 0x0000_8000;
|
||||
/// The share supports identity remoting.
|
||||
pub const IDENTITY_REMOTING: u32 = 0x0004_0000;
|
||||
/// The server supports compression on this share (SMB 3.1.1).
|
||||
pub const COMPRESS_DATA: u32 = 0x0010_0000;
|
||||
/// Prefer isolated transport for this share (advisory).
|
||||
pub const ISOLATED_TRANSPORT: u32 = 0x0020_0000;
|
||||
}
|
||||
|
||||
impl_flags!(ShareFlags, u32);
|
||||
|
||||
// ── ShareCapabilities ───────────────────────────────────────────────────
|
||||
|
||||
/// Share capability flags (32-bit field from MS-SMB2 2.2.10).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct ShareCapabilities(pub u32);
|
||||
|
||||
impl ShareCapabilities {
|
||||
/// The share is part of a DFS tree.
|
||||
pub const DFS: u32 = 0x0000_0008;
|
||||
/// The share has continuously available file handles.
|
||||
pub const CONTINUOUS_AVAILABILITY: u32 = 0x0000_0010;
|
||||
/// The share is a scale-out share.
|
||||
pub const SCALEOUT: u32 = 0x0000_0020;
|
||||
/// The share is a cluster share.
|
||||
pub const CLUSTER: u32 = 0x0000_0040;
|
||||
/// The share is an asymmetric share.
|
||||
pub const ASYMMETRIC: u32 = 0x0000_0080;
|
||||
/// The share supports redirect to owner.
|
||||
pub const REDIRECT_TO_OWNER: u32 = 0x0000_0100;
|
||||
}
|
||||
|
||||
impl_flags!(ShareCapabilities, u32);
|
||||
|
||||
// ── FileAccessMask ──────────────────────────────────────────────────────
|
||||
|
||||
/// File access rights mask (32-bit, from MS-SMB2 2.2.13.1).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub struct FileAccessMask(pub u32);
|
||||
|
||||
impl FileAccessMask {
|
||||
/// Read data from the file.
|
||||
pub const FILE_READ_DATA: u32 = 0x0000_0001;
|
||||
/// Write data to the file.
|
||||
pub const FILE_WRITE_DATA: u32 = 0x0000_0002;
|
||||
/// Append data to the file.
|
||||
pub const FILE_APPEND_DATA: u32 = 0x0000_0004;
|
||||
/// Read extended attributes.
|
||||
pub const FILE_READ_EA: u32 = 0x0000_0008;
|
||||
/// Write extended attributes.
|
||||
pub const FILE_WRITE_EA: u32 = 0x0000_0010;
|
||||
/// Execute the file.
|
||||
pub const FILE_EXECUTE: u32 = 0x0000_0020;
|
||||
/// Read file attributes.
|
||||
pub const FILE_READ_ATTRIBUTES: u32 = 0x0000_0080;
|
||||
/// Write file attributes.
|
||||
pub const FILE_WRITE_ATTRIBUTES: u32 = 0x0000_0100;
|
||||
/// Delete the object.
|
||||
pub const DELETE: u32 = 0x0001_0000;
|
||||
/// Read the security descriptor.
|
||||
pub const READ_CONTROL: u32 = 0x0002_0000;
|
||||
/// Modify the DACL.
|
||||
pub const WRITE_DAC: u32 = 0x0004_0000;
|
||||
/// Change the owner.
|
||||
pub const WRITE_OWNER: u32 = 0x0008_0000;
|
||||
/// Synchronize access.
|
||||
pub const SYNCHRONIZE: u32 = 0x0010_0000;
|
||||
/// Request maximum allowed access.
|
||||
pub const MAXIMUM_ALLOWED: u32 = 0x0200_0000;
|
||||
/// All possible access rights.
|
||||
pub const GENERIC_ALL: u32 = 0x1000_0000;
|
||||
/// Execute access.
|
||||
pub const GENERIC_EXECUTE: u32 = 0x2000_0000;
|
||||
/// Write access.
|
||||
pub const GENERIC_WRITE: u32 = 0x4000_0000;
|
||||
/// Read access.
|
||||
pub const GENERIC_READ: u32 = 0x8000_0000;
|
||||
}
|
||||
|
||||
impl_flags!(FileAccessMask, u32);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── HeaderFlags ─────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn header_flags_default_is_zero() {
|
||||
let f = HeaderFlags::default();
|
||||
assert_eq!(f.bits(), 0);
|
||||
assert!(!f.is_response());
|
||||
assert!(!f.is_async());
|
||||
assert!(!f.is_related());
|
||||
assert!(!f.is_signed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_flags_set_and_check() {
|
||||
let mut f = HeaderFlags::default();
|
||||
f.set_response();
|
||||
assert!(f.is_response());
|
||||
assert!(!f.is_async());
|
||||
|
||||
f.set_signed();
|
||||
assert!(f.is_signed());
|
||||
assert!(f.is_response());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_flags_clear() {
|
||||
let mut f = HeaderFlags::new(0xFFFF_FFFF);
|
||||
assert!(f.is_response());
|
||||
f.clear(HeaderFlags::SERVER_TO_REDIR);
|
||||
assert!(!f.is_response());
|
||||
assert!(f.is_async()); // other flags untouched
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_flags_contains() {
|
||||
let f = HeaderFlags::new(HeaderFlags::SIGNED | HeaderFlags::ASYNC_COMMAND);
|
||||
assert!(f.contains(HeaderFlags::SIGNED));
|
||||
assert!(f.contains(HeaderFlags::ASYNC_COMMAND));
|
||||
assert!(!f.contains(HeaderFlags::SERVER_TO_REDIR));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_flags_bitor() {
|
||||
let a = HeaderFlags::new(HeaderFlags::SERVER_TO_REDIR);
|
||||
let b = HeaderFlags::new(HeaderFlags::SIGNED);
|
||||
let c = a | b;
|
||||
assert!(c.is_response());
|
||||
assert!(c.is_signed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_flags_bitand() {
|
||||
let a = HeaderFlags::new(HeaderFlags::SERVER_TO_REDIR | HeaderFlags::SIGNED);
|
||||
let b = HeaderFlags::new(HeaderFlags::SIGNED);
|
||||
let c = a & b;
|
||||
assert!(!c.is_response());
|
||||
assert!(c.is_signed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_flags_bitor_assign() {
|
||||
let mut a = HeaderFlags::new(HeaderFlags::SERVER_TO_REDIR);
|
||||
a |= HeaderFlags::new(HeaderFlags::ASYNC_COMMAND);
|
||||
assert!(a.is_response());
|
||||
assert!(a.is_async());
|
||||
}
|
||||
|
||||
// ── SecurityMode ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn security_mode_signing_enabled() {
|
||||
let m = SecurityMode::new(SecurityMode::SIGNING_ENABLED);
|
||||
assert!(m.signing_enabled());
|
||||
assert!(!m.signing_required());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_mode_signing_required() {
|
||||
let m = SecurityMode::new(SecurityMode::SIGNING_ENABLED | SecurityMode::SIGNING_REQUIRED);
|
||||
assert!(m.signing_enabled());
|
||||
assert!(m.signing_required());
|
||||
}
|
||||
|
||||
// ── Capabilities ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn capabilities_combine_with_bitor() {
|
||||
let a = Capabilities::new(Capabilities::DFS);
|
||||
let b = Capabilities::new(Capabilities::ENCRYPTION);
|
||||
let c = a | b;
|
||||
assert!(c.contains(Capabilities::DFS));
|
||||
assert!(c.contains(Capabilities::ENCRYPTION));
|
||||
assert!(!c.contains(Capabilities::LEASING));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_set_and_clear() {
|
||||
let mut c = Capabilities::default();
|
||||
c.set(Capabilities::LARGE_MTU);
|
||||
assert!(c.contains(Capabilities::LARGE_MTU));
|
||||
c.clear(Capabilities::LARGE_MTU);
|
||||
assert!(!c.contains(Capabilities::LARGE_MTU));
|
||||
}
|
||||
|
||||
// ── ShareFlags ──────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn share_flags_encrypt_data() {
|
||||
let f = ShareFlags::new(ShareFlags::ENCRYPT_DATA | ShareFlags::DFS);
|
||||
assert!(f.contains(ShareFlags::ENCRYPT_DATA));
|
||||
assert!(f.contains(ShareFlags::DFS));
|
||||
assert!(!f.contains(ShareFlags::COMPRESS_DATA));
|
||||
}
|
||||
|
||||
// ── ShareCapabilities ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn share_capabilities_dfs() {
|
||||
let c = ShareCapabilities::new(ShareCapabilities::DFS);
|
||||
assert!(c.contains(ShareCapabilities::DFS));
|
||||
assert!(!c.contains(ShareCapabilities::CLUSTER));
|
||||
}
|
||||
|
||||
// ── FileAccessMask ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn file_access_mask_generic_read() {
|
||||
let m = FileAccessMask::new(FileAccessMask::GENERIC_READ);
|
||||
assert!(m.contains(FileAccessMask::GENERIC_READ));
|
||||
assert!(!m.contains(FileAccessMask::GENERIC_WRITE));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_access_mask_combine() {
|
||||
let m =
|
||||
FileAccessMask::new(FileAccessMask::FILE_READ_DATA | FileAccessMask::FILE_WRITE_DATA);
|
||||
assert!(m.contains(FileAccessMask::FILE_READ_DATA));
|
||||
assert!(m.contains(FileAccessMask::FILE_WRITE_DATA));
|
||||
assert!(!m.contains(FileAccessMask::DELETE));
|
||||
}
|
||||
}
|
||||
364
vendor/smb2/src/types/mod.rs
vendored
Normal file
364
vendor/smb2/src/types/mod.rs
vendored
Normal file
@@ -0,0 +1,364 @@
|
||||
//! Newtypes, enums, and common data structures for SMB2/3 protocol fields.
|
||||
//!
|
||||
//! Most users don't need to import from this module directly -- the commonly
|
||||
//! used types are re-exported at the crate root.
|
||||
|
||||
pub mod flags;
|
||||
pub mod status;
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use crate::Error;
|
||||
|
||||
/// Requested oplock level (MS-SMB2 2.2.13, 2.2.23).
|
||||
///
|
||||
/// Used across CREATE requests/responses and oplock break messages.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum OplockLevel {
|
||||
/// No oplock is requested.
|
||||
None = 0x00,
|
||||
/// Level II oplock is requested.
|
||||
LevelII = 0x01,
|
||||
/// Exclusive oplock is requested.
|
||||
Exclusive = 0x08,
|
||||
/// Batch oplock is requested.
|
||||
Batch = 0x09,
|
||||
/// Lease is requested.
|
||||
Lease = 0xFF,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for OplockLevel {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: u8) -> crate::error::Result<Self> {
|
||||
match value {
|
||||
0x00 => Ok(Self::None),
|
||||
0x01 => Ok(Self::LevelII),
|
||||
0x08 => Ok(Self::Exclusive),
|
||||
0x09 => Ok(Self::Batch),
|
||||
0xFF => Ok(Self::Lease),
|
||||
_ => Err(Error::invalid_data(format!(
|
||||
"invalid OplockLevel: 0x{:02X}",
|
||||
value
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 64-bit session identifier.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(transparent))]
|
||||
pub struct SessionId(pub u64);
|
||||
|
||||
impl SessionId {
|
||||
/// Sentinel value indicating no session.
|
||||
pub const NONE: Self = Self(0);
|
||||
}
|
||||
|
||||
impl fmt::Display for SessionId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "SessionId(0x{:016X})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// 64-bit message identifier for request/response correlation.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(transparent))]
|
||||
pub struct MessageId(pub u64);
|
||||
|
||||
impl MessageId {
|
||||
/// Unsolicited message ID used for oplock/lease break notifications.
|
||||
pub const UNSOLICITED: Self = Self(0xFFFF_FFFF_FFFF_FFFF);
|
||||
}
|
||||
|
||||
impl fmt::Display for MessageId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "MessageId(0x{:016X})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// 32-bit tree connect identifier.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(transparent))]
|
||||
pub struct TreeId(pub u32);
|
||||
|
||||
impl fmt::Display for TreeId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "TreeId(0x{:08X})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// 16-bit credit charge for multi-credit requests.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
|
||||
pub struct CreditCharge(pub u16);
|
||||
|
||||
impl fmt::Display for CreditCharge {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "CreditCharge({})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// 128-bit file identifier consisting of two 64-bit parts.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
|
||||
pub struct FileId {
|
||||
/// Persistent portion of the file handle.
|
||||
pub persistent: u64,
|
||||
/// Volatile portion of the file handle.
|
||||
pub volatile: u64,
|
||||
}
|
||||
|
||||
impl FileId {
|
||||
/// Sentinel value used in related compound requests.
|
||||
pub const SENTINEL: Self = Self {
|
||||
persistent: 0xFFFF_FFFF_FFFF_FFFF,
|
||||
volatile: 0xFFFF_FFFF_FFFF_FFFF,
|
||||
};
|
||||
}
|
||||
|
||||
impl fmt::Display for FileId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"FileId(0x{:016X}:0x{:016X})",
|
||||
self.persistent, self.volatile
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// SMB2 command codes from MS-SMB2 section 2.2.1.
|
||||
#[derive(
|
||||
Debug, Clone, Copy, PartialEq, Eq, Hash, num_enum::TryFromPrimitive, num_enum::IntoPrimitive,
|
||||
)]
|
||||
#[repr(u16)]
|
||||
pub enum Command {
|
||||
/// Negotiate protocol version and capabilities.
|
||||
Negotiate = 0x0000,
|
||||
/// Set up an authenticated session.
|
||||
SessionSetup = 0x0001,
|
||||
/// Log off a session.
|
||||
Logoff = 0x0002,
|
||||
/// Connect to a share.
|
||||
TreeConnect = 0x0003,
|
||||
/// Disconnect from a share.
|
||||
TreeDisconnect = 0x0004,
|
||||
/// Open or create a file.
|
||||
Create = 0x0005,
|
||||
/// Close a file handle.
|
||||
Close = 0x0006,
|
||||
/// Flush cached data to stable storage.
|
||||
Flush = 0x0007,
|
||||
/// Read data from a file.
|
||||
Read = 0x0008,
|
||||
/// Write data to a file.
|
||||
Write = 0x0009,
|
||||
/// Lock or unlock byte ranges.
|
||||
Lock = 0x000A,
|
||||
/// Issue a device control or file system control command.
|
||||
Ioctl = 0x000B,
|
||||
/// Cancel a previously sent request.
|
||||
Cancel = 0x000C,
|
||||
/// Check server liveness.
|
||||
Echo = 0x000D,
|
||||
/// Enumerate directory contents.
|
||||
QueryDirectory = 0x000E,
|
||||
/// Request change notifications on a directory.
|
||||
ChangeNotify = 0x000F,
|
||||
/// Query file or filesystem information.
|
||||
QueryInfo = 0x0010,
|
||||
/// Set file or filesystem information.
|
||||
SetInfo = 0x0011,
|
||||
/// Oplock or lease break notification/acknowledgment.
|
||||
OplockBreak = 0x0012,
|
||||
}
|
||||
|
||||
impl fmt::Display for Command {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt::Debug::fmt(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
/// SMB2 dialect revision identifiers from MS-SMB2 section 2.2.3.
|
||||
#[derive(
|
||||
Debug,
|
||||
Clone,
|
||||
Copy,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Hash,
|
||||
num_enum::TryFromPrimitive,
|
||||
num_enum::IntoPrimitive,
|
||||
)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
#[repr(u16)]
|
||||
pub enum Dialect {
|
||||
/// SMB 2.0.2 dialect.
|
||||
Smb2_0_2 = 0x0202,
|
||||
/// SMB 2.1 dialect.
|
||||
Smb2_1 = 0x0210,
|
||||
/// SMB 3.0 dialect.
|
||||
Smb3_0 = 0x0300,
|
||||
/// SMB 3.0.2 dialect.
|
||||
Smb3_0_2 = 0x0302,
|
||||
/// SMB 3.1.1 dialect.
|
||||
Smb3_1_1 = 0x0311,
|
||||
}
|
||||
|
||||
impl Dialect {
|
||||
/// All supported dialect revisions, in ascending order.
|
||||
pub const ALL: &[Dialect] = &[
|
||||
Dialect::Smb2_0_2,
|
||||
Dialect::Smb2_1,
|
||||
Dialect::Smb3_0,
|
||||
Dialect::Smb3_0_2,
|
||||
Dialect::Smb3_1_1,
|
||||
];
|
||||
}
|
||||
|
||||
impl fmt::Display for Dialect {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Dialect::Smb2_0_2 => f.write_str("SMB 2.0.2"),
|
||||
Dialect::Smb2_1 => f.write_str("SMB 2.1"),
|
||||
Dialect::Smb3_0 => f.write_str("SMB 3.0"),
|
||||
Dialect::Smb3_0_2 => f.write_str("SMB 3.0.2"),
|
||||
Dialect::Smb3_1_1 => f.write_str("SMB 3.1.1"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── Newtype tests ───────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn session_id_none_is_zero() {
|
||||
assert_eq!(SessionId::NONE, SessionId(0));
|
||||
assert_eq!(SessionId::NONE.0, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_id_unsolicited() {
|
||||
assert_eq!(MessageId::UNSOLICITED.0, 0xFFFF_FFFF_FFFF_FFFF);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_id_sentinel() {
|
||||
assert_eq!(FileId::SENTINEL.persistent, 0xFFFF_FFFF_FFFF_FFFF);
|
||||
assert_eq!(FileId::SENTINEL.volatile, 0xFFFF_FFFF_FFFF_FFFF);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn newtype_display_formatting() {
|
||||
assert_eq!(
|
||||
SessionId(0x1234).to_string(),
|
||||
"SessionId(0x0000000000001234)"
|
||||
);
|
||||
assert_eq!(
|
||||
MessageId(0xABCD).to_string(),
|
||||
"MessageId(0x000000000000ABCD)"
|
||||
);
|
||||
assert_eq!(TreeId(0x42).to_string(), "TreeId(0x00000042)");
|
||||
assert_eq!(CreditCharge(5).to_string(), "CreditCharge(5)");
|
||||
assert_eq!(
|
||||
FileId {
|
||||
persistent: 0x11,
|
||||
volatile: 0x22
|
||||
}
|
||||
.to_string(),
|
||||
"FileId(0x0000000000000011:0x0000000000000022)"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Command tests ───────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn command_roundtrip_via_u16() {
|
||||
assert_eq!(Command::try_from(0x0005u16), Ok(Command::Create));
|
||||
assert_eq!(u16::from(Command::Create), 0x0005);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_all_variants_correct_values() {
|
||||
assert_eq!(u16::from(Command::Negotiate), 0x0000);
|
||||
assert_eq!(u16::from(Command::SessionSetup), 0x0001);
|
||||
assert_eq!(u16::from(Command::Logoff), 0x0002);
|
||||
assert_eq!(u16::from(Command::TreeConnect), 0x0003);
|
||||
assert_eq!(u16::from(Command::TreeDisconnect), 0x0004);
|
||||
assert_eq!(u16::from(Command::Create), 0x0005);
|
||||
assert_eq!(u16::from(Command::Close), 0x0006);
|
||||
assert_eq!(u16::from(Command::Flush), 0x0007);
|
||||
assert_eq!(u16::from(Command::Read), 0x0008);
|
||||
assert_eq!(u16::from(Command::Write), 0x0009);
|
||||
assert_eq!(u16::from(Command::Lock), 0x000A);
|
||||
assert_eq!(u16::from(Command::Ioctl), 0x000B);
|
||||
assert_eq!(u16::from(Command::Cancel), 0x000C);
|
||||
assert_eq!(u16::from(Command::Echo), 0x000D);
|
||||
assert_eq!(u16::from(Command::QueryDirectory), 0x000E);
|
||||
assert_eq!(u16::from(Command::ChangeNotify), 0x000F);
|
||||
assert_eq!(u16::from(Command::QueryInfo), 0x0010);
|
||||
assert_eq!(u16::from(Command::SetInfo), 0x0011);
|
||||
assert_eq!(u16::from(Command::OplockBreak), 0x0012);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_invalid_u16_is_error() {
|
||||
assert!(Command::try_from(0xFFFFu16).is_err());
|
||||
assert!(Command::try_from(0x0013u16).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_display() {
|
||||
assert_eq!(Command::Create.to_string(), "Create");
|
||||
assert_eq!(Command::OplockBreak.to_string(), "OplockBreak");
|
||||
}
|
||||
|
||||
// ── Dialect tests ───────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn dialect_ordering() {
|
||||
assert!(Dialect::Smb2_0_2 < Dialect::Smb2_1);
|
||||
assert!(Dialect::Smb2_1 < Dialect::Smb3_0);
|
||||
assert!(Dialect::Smb3_0 < Dialect::Smb3_0_2);
|
||||
assert!(Dialect::Smb3_0_2 < Dialect::Smb3_1_1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dialect_roundtrip_via_u16() {
|
||||
assert_eq!(Dialect::try_from(0x0311u16), Ok(Dialect::Smb3_1_1));
|
||||
assert_eq!(u16::from(Dialect::Smb3_1_1), 0x0311);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dialect_invalid_u16_is_error() {
|
||||
assert!(Dialect::try_from(0x0000u16).is_err());
|
||||
assert!(Dialect::try_from(0x0201u16).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dialect_display() {
|
||||
assert_eq!(Dialect::Smb2_0_2.to_string(), "SMB 2.0.2");
|
||||
assert_eq!(Dialect::Smb2_1.to_string(), "SMB 2.1");
|
||||
assert_eq!(Dialect::Smb3_0.to_string(), "SMB 3.0");
|
||||
assert_eq!(Dialect::Smb3_0_2.to_string(), "SMB 3.0.2");
|
||||
assert_eq!(Dialect::Smb3_1_1.to_string(), "SMB 3.1.1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dialect_all_has_five_variants() {
|
||||
assert_eq!(Dialect::ALL.len(), 5);
|
||||
assert_eq!(Dialect::ALL[0], Dialect::Smb2_0_2);
|
||||
assert_eq!(Dialect::ALL[4], Dialect::Smb3_1_1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dialect_all_is_sorted() {
|
||||
for w in Dialect::ALL.windows(2) {
|
||||
assert!(w[0] < w[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
384
vendor/smb2/src/types/status.rs
vendored
Normal file
384
vendor/smb2/src/types/status.rs
vendored
Normal file
@@ -0,0 +1,384 @@
|
||||
//! NTSTATUS codes used by SMB2/3 (from MS-ERREF).
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// Defines `NtStatus` associated constants and the `name()` match arms from
|
||||
/// a single table, so adding a new status code only requires one edit.
|
||||
macro_rules! nt_status_codes {
|
||||
(
|
||||
$(
|
||||
$(#[$meta:meta])*
|
||||
$name:ident = $value:expr, $display:expr;
|
||||
)*
|
||||
) => {
|
||||
impl NtStatus {
|
||||
$(
|
||||
$(#[$meta])*
|
||||
pub const $name: Self = Self($value);
|
||||
)*
|
||||
|
||||
/// Returns a human-readable name for known status codes,
|
||||
/// or `None` for unknown codes.
|
||||
fn name(&self) -> Option<&'static str> {
|
||||
match self.0 {
|
||||
$( $value => Some($display), )*
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// NT status code returned in SMB2 response headers.
|
||||
///
|
||||
/// The top two bits encode severity:
|
||||
/// - `00` = success
|
||||
/// - `01` = informational
|
||||
/// - `10` = warning
|
||||
/// - `11` = error
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Default)]
|
||||
pub struct NtStatus(pub u32);
|
||||
|
||||
nt_status_codes! {
|
||||
// -- Success (severity 0b00) --
|
||||
|
||||
/// The operation completed successfully.
|
||||
SUCCESS = 0x0000_0000, "STATUS_SUCCESS";
|
||||
|
||||
/// The operation that was requested is pending completion.
|
||||
PENDING = 0x0000_0103, "STATUS_PENDING";
|
||||
|
||||
/// Oplock break notification (informational).
|
||||
NOTIFY_ENUM_DIR = 0x0000_010C, "STATUS_NOTIFY_ENUM_DIR";
|
||||
|
||||
// -- Informational (severity 0b00, facility-specific) --
|
||||
|
||||
/// The authentication exchange is not complete -- send the next
|
||||
/// SESSION_SETUP with the GSS token from this response.
|
||||
///
|
||||
/// **Important:** The severity bits are 0b11 (error), so `is_error()`
|
||||
/// returns `true`. But this is NOT a real error -- it's a "keep going"
|
||||
/// signal during NTLM/SPNEGO auth. Auth code must check
|
||||
/// `is_more_processing_required()` before checking `is_error()`.
|
||||
MORE_PROCESSING_REQUIRED = 0xC000_0016, "STATUS_MORE_PROCESSING_REQUIRED";
|
||||
|
||||
// -- Warnings (severity 0b10) --
|
||||
|
||||
/// The data was too large to fit into the specified buffer.
|
||||
/// This is a warning -- the response body contains valid partial data.
|
||||
BUFFER_OVERFLOW = 0x8000_0005, "STATUS_BUFFER_OVERFLOW";
|
||||
|
||||
/// No more files were found which match the file specification.
|
||||
NO_MORE_FILES = 0x8000_0006, "STATUS_NO_MORE_FILES";
|
||||
|
||||
// -- Errors (severity 0b11) --
|
||||
|
||||
/// The requested operation was unsuccessful.
|
||||
UNSUCCESSFUL = 0xC000_0001, "STATUS_UNSUCCESSFUL";
|
||||
|
||||
/// The requested operation is not implemented.
|
||||
NOT_IMPLEMENTED = 0xC000_0002, "STATUS_NOT_IMPLEMENTED";
|
||||
|
||||
/// An invalid parameter was passed to a service or function.
|
||||
INVALID_PARAMETER = 0xC000_000D, "STATUS_INVALID_PARAMETER";
|
||||
|
||||
/// A device that does not exist was specified.
|
||||
NO_SUCH_DEVICE = 0xC000_000E, "STATUS_NO_SUCH_DEVICE";
|
||||
|
||||
/// The file does not exist.
|
||||
NO_SUCH_FILE = 0xC000_000F, "STATUS_NO_SUCH_FILE";
|
||||
|
||||
/// The specified request is not a valid operation for the target device.
|
||||
INVALID_DEVICE_REQUEST = 0xC000_0010, "STATUS_INVALID_DEVICE_REQUEST";
|
||||
|
||||
/// The end-of-file marker has been reached.
|
||||
END_OF_FILE = 0xC000_0011, "STATUS_END_OF_FILE";
|
||||
|
||||
/// A process has requested access to an object but has not been
|
||||
/// granted those access rights.
|
||||
ACCESS_DENIED = 0xC000_0022, "STATUS_ACCESS_DENIED";
|
||||
|
||||
/// The buffer is too small to contain the entry.
|
||||
BUFFER_TOO_SMALL = 0xC000_0023, "STATUS_BUFFER_TOO_SMALL";
|
||||
|
||||
/// The object name is not found.
|
||||
OBJECT_NAME_NOT_FOUND = 0xC000_0034, "STATUS_OBJECT_NAME_NOT_FOUND";
|
||||
|
||||
/// The object name already exists.
|
||||
OBJECT_NAME_COLLISION = 0xC000_0035, "STATUS_OBJECT_NAME_COLLISION";
|
||||
|
||||
/// The path does not exist.
|
||||
OBJECT_PATH_NOT_FOUND = 0xC000_003A, "STATUS_OBJECT_PATH_NOT_FOUND";
|
||||
|
||||
/// A file cannot be opened because the share access flags
|
||||
/// are incompatible.
|
||||
SHARING_VIOLATION = 0xC000_0043, "STATUS_SHARING_VIOLATION";
|
||||
|
||||
/// A requested read/write cannot be granted due to a conflicting
|
||||
/// file lock.
|
||||
FILE_LOCK_CONFLICT = 0xC000_0054, "STATUS_FILE_LOCK_CONFLICT";
|
||||
|
||||
/// A non-close operation has been requested of a file object that
|
||||
/// has a delete pending.
|
||||
DELETE_PENDING = 0xC000_0056, "STATUS_DELETE_PENDING";
|
||||
|
||||
/// The disk is full.
|
||||
DISK_FULL = 0xC000_007F, "STATUS_DISK_FULL";
|
||||
|
||||
/// The attempted logon is invalid.
|
||||
LOGON_FAILURE = 0xC000_006D, "STATUS_LOGON_FAILURE";
|
||||
|
||||
/// The referenced account is currently disabled.
|
||||
ACCOUNT_DISABLED = 0xC000_0072, "STATUS_ACCOUNT_DISABLED";
|
||||
|
||||
/// Insufficient system resources exist to complete the API.
|
||||
INSUFFICIENT_RESOURCES = 0xC000_009A, "STATUS_INSUFFICIENT_RESOURCES";
|
||||
|
||||
/// The file that was specified as a target is a directory.
|
||||
FILE_IS_A_DIRECTORY = 0xC000_00BA, "STATUS_FILE_IS_A_DIRECTORY";
|
||||
|
||||
/// The network path cannot be located.
|
||||
BAD_NETWORK_PATH = 0xC000_00BE, "STATUS_BAD_NETWORK_PATH";
|
||||
|
||||
/// The network name was deleted.
|
||||
NETWORK_NAME_DELETED = 0xC000_00C9, "STATUS_NETWORK_NAME_DELETED";
|
||||
|
||||
/// The specified share name cannot be found on the remote server.
|
||||
BAD_NETWORK_NAME = 0xC000_00CC, "STATUS_BAD_NETWORK_NAME";
|
||||
|
||||
/// No more connections can be made to this remote computer at this time.
|
||||
REQUEST_NOT_ACCEPTED = 0xC000_00D0, "STATUS_REQUEST_NOT_ACCEPTED";
|
||||
|
||||
/// A requested opened file is not a directory.
|
||||
NOT_A_DIRECTORY = 0xC000_0103, "STATUS_NOT_A_DIRECTORY";
|
||||
|
||||
/// The I/O request was canceled.
|
||||
CANCELLED = 0xC000_0120, "STATUS_CANCELLED";
|
||||
|
||||
/// An I/O request other than close was attempted using a file object
|
||||
/// that had already been closed.
|
||||
FILE_CLOSED = 0xC000_0128, "STATUS_FILE_CLOSED";
|
||||
|
||||
/// The remote user session has been deleted.
|
||||
USER_SESSION_DELETED = 0xC000_0203, "STATUS_USER_SESSION_DELETED";
|
||||
|
||||
/// Insufficient server resources exist to complete the request.
|
||||
INSUFF_SERVER_RESOURCES = 0xC000_0205, "STATUS_INSUFF_SERVER_RESOURCES";
|
||||
|
||||
/// The object was not found.
|
||||
NOT_FOUND = 0xC000_0225, "STATUS_NOT_FOUND";
|
||||
|
||||
/// The contacted server does not support the indicated part
|
||||
/// of the DFS namespace.
|
||||
PATH_NOT_COVERED = 0xC000_0257, "STATUS_PATH_NOT_COVERED";
|
||||
|
||||
/// The client session has expired; the client must re-authenticate.
|
||||
NETWORK_SESSION_EXPIRED = 0xC000_035C, "STATUS_NETWORK_SESSION_EXPIRED";
|
||||
}
|
||||
|
||||
impl NtStatus {
|
||||
// -- Helper methods --
|
||||
|
||||
/// Returns the severity bits (top 2 bits): 0 = success, 1 = info,
|
||||
/// 2 = warning, 3 = error.
|
||||
#[inline]
|
||||
pub fn severity(&self) -> u8 {
|
||||
(self.0 >> 30) as u8
|
||||
}
|
||||
|
||||
/// Returns `true` if the status indicates success (severity 0b00).
|
||||
#[inline]
|
||||
pub fn is_success(&self) -> bool {
|
||||
self.severity() == 0
|
||||
}
|
||||
|
||||
/// Returns `true` if the status is a warning (severity 0b10).
|
||||
#[inline]
|
||||
pub fn is_warning(&self) -> bool {
|
||||
self.severity() == 2
|
||||
}
|
||||
|
||||
/// Returns `true` if the status is an error (severity 0b11).
|
||||
#[inline]
|
||||
pub fn is_error(&self) -> bool {
|
||||
self.severity() == 3
|
||||
}
|
||||
|
||||
/// Returns `true` if this is `STATUS_PENDING`.
|
||||
#[inline]
|
||||
pub fn is_pending(&self) -> bool {
|
||||
*self == Self::PENDING
|
||||
}
|
||||
|
||||
/// Returns `true` if this status indicates the operation produced usable data.
|
||||
///
|
||||
/// This includes `SUCCESS` and warnings like `BUFFER_OVERFLOW` where partial
|
||||
/// data is valid and should be parsed.
|
||||
#[inline]
|
||||
pub fn is_success_or_partial(&self) -> bool {
|
||||
self.is_success() || *self == Self::BUFFER_OVERFLOW
|
||||
}
|
||||
|
||||
/// Returns `true` if the server wants another SESSION_SETUP round-trip.
|
||||
///
|
||||
/// Check this BEFORE `is_error()` during authentication -- it has
|
||||
/// error severity bits but is not a real error.
|
||||
#[inline]
|
||||
pub fn is_more_processing_required(&self) -> bool {
|
||||
*self == Self::MORE_PROCESSING_REQUIRED
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for NtStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.name() {
|
||||
Some(name) => write!(f, "NtStatus({name})"),
|
||||
None => write!(f, "NtStatus(0x{:08X})", self.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for NtStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.name() {
|
||||
Some(name) => f.write_str(name),
|
||||
None => write!(f, "0x{:08X}", self.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn success_is_success() {
|
||||
assert!(NtStatus::SUCCESS.is_success());
|
||||
assert!(!NtStatus::SUCCESS.is_error());
|
||||
assert!(!NtStatus::SUCCESS.is_warning());
|
||||
assert_eq!(NtStatus::SUCCESS.severity(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn access_denied_is_error() {
|
||||
assert!(NtStatus::ACCESS_DENIED.is_error());
|
||||
assert!(!NtStatus::ACCESS_DENIED.is_success());
|
||||
assert!(!NtStatus::ACCESS_DENIED.is_warning());
|
||||
assert_eq!(NtStatus::ACCESS_DENIED.severity(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn buffer_overflow_is_warning() {
|
||||
assert!(NtStatus::BUFFER_OVERFLOW.is_warning());
|
||||
assert!(!NtStatus::BUFFER_OVERFLOW.is_success());
|
||||
assert!(!NtStatus::BUFFER_OVERFLOW.is_error());
|
||||
assert_eq!(NtStatus::BUFFER_OVERFLOW.severity(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pending_is_pending() {
|
||||
assert!(NtStatus::PENDING.is_pending());
|
||||
assert!(NtStatus::PENDING.is_success()); // severity 0b00
|
||||
assert!(!NtStatus::SUCCESS.is_pending());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn more_processing_required_is_error_severity() {
|
||||
// 0xC0000016 has severity 0b11 (error), even though semantically
|
||||
// it means "keep going" during authentication handshakes.
|
||||
assert!(NtStatus::MORE_PROCESSING_REQUIRED.is_error());
|
||||
assert_eq!(NtStatus::MORE_PROCESSING_REQUIRED.severity(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_known_code() {
|
||||
assert_eq!(NtStatus::ACCESS_DENIED.to_string(), "STATUS_ACCESS_DENIED");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_unknown_code() {
|
||||
let unknown = NtStatus(0xDEAD_BEEF);
|
||||
assert_eq!(unknown.to_string(), "0xDEADBEEF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn debug_known_code() {
|
||||
let s = format!("{:?}", NtStatus::SUCCESS);
|
||||
assert_eq!(s, "NtStatus(STATUS_SUCCESS)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn debug_unknown_code() {
|
||||
let s = format!("{:?}", NtStatus(0x1234_5678));
|
||||
assert_eq!(s, "NtStatus(0x12345678)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_more_files_is_warning() {
|
||||
assert!(NtStatus::NO_MORE_FILES.is_warning());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_is_success() {
|
||||
assert_eq!(NtStatus::default(), NtStatus::SUCCESS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_success_or_partial() {
|
||||
// SUCCESS is usable
|
||||
assert!(NtStatus::SUCCESS.is_success_or_partial());
|
||||
// BUFFER_OVERFLOW is a warning with valid partial data
|
||||
assert!(NtStatus::BUFFER_OVERFLOW.is_success_or_partial());
|
||||
// Errors are not usable
|
||||
assert!(!NtStatus::ACCESS_DENIED.is_success_or_partial());
|
||||
// PENDING has success severity (0b00) so is_success() is true,
|
||||
// but callers handle PENDING separately before reaching status checks.
|
||||
assert!(NtStatus::PENDING.is_success_or_partial());
|
||||
// Other warnings (not BUFFER_OVERFLOW) are not usable
|
||||
assert!(!NtStatus::NO_MORE_FILES.is_success_or_partial());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_error_codes_have_error_severity() {
|
||||
let errors = [
|
||||
NtStatus::UNSUCCESSFUL,
|
||||
NtStatus::NOT_IMPLEMENTED,
|
||||
NtStatus::INVALID_PARAMETER,
|
||||
NtStatus::NO_SUCH_DEVICE,
|
||||
NtStatus::NO_SUCH_FILE,
|
||||
NtStatus::END_OF_FILE,
|
||||
NtStatus::ACCESS_DENIED,
|
||||
NtStatus::BUFFER_TOO_SMALL,
|
||||
NtStatus::OBJECT_NAME_NOT_FOUND,
|
||||
NtStatus::OBJECT_NAME_COLLISION,
|
||||
NtStatus::OBJECT_PATH_NOT_FOUND,
|
||||
NtStatus::SHARING_VIOLATION,
|
||||
NtStatus::FILE_LOCK_CONFLICT,
|
||||
NtStatus::DELETE_PENDING,
|
||||
NtStatus::LOGON_FAILURE,
|
||||
NtStatus::ACCOUNT_DISABLED,
|
||||
NtStatus::INSUFFICIENT_RESOURCES,
|
||||
NtStatus::FILE_IS_A_DIRECTORY,
|
||||
NtStatus::BAD_NETWORK_PATH,
|
||||
NtStatus::NETWORK_NAME_DELETED,
|
||||
NtStatus::BAD_NETWORK_NAME,
|
||||
NtStatus::REQUEST_NOT_ACCEPTED,
|
||||
NtStatus::NOT_A_DIRECTORY,
|
||||
NtStatus::CANCELLED,
|
||||
NtStatus::FILE_CLOSED,
|
||||
NtStatus::USER_SESSION_DELETED,
|
||||
NtStatus::INSUFF_SERVER_RESOURCES,
|
||||
NtStatus::NOT_FOUND,
|
||||
NtStatus::PATH_NOT_COVERED,
|
||||
NtStatus::NETWORK_SESSION_EXPIRED,
|
||||
NtStatus::MORE_PROCESSING_REQUIRED,
|
||||
];
|
||||
for status in &errors {
|
||||
assert!(
|
||||
status.is_error(),
|
||||
"{status} should be error but severity is {}",
|
||||
status.severity()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user