Files
markbaseengine/Sources/MarkBase/Weights/SafeTensors.swift
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

127 lines
4.9 KiB
Swift

import Foundation
/// Reader for a single `.safetensors` file.
///
/// The JSON header is parsed once at initialisation and the file handle is kept
/// open so tensors can be read repeatedly. A static helper converts BF16 payloads
/// to `Float`. This type reads one file only; sharded checkpoints (index files)
/// and quantized tensor grouping are not implemented here.
public final class SafeTensorsReader {
    /// Upper bound on the JSON header length that will be accepted.
    /// Matches the limit used by the reference safetensors implementation and
    /// prevents a corrupt length prefix from triggering a gigantic read.
    private static let maxHeaderSize = 100_000_000

    /// Location of the file this reader was opened on.
    public let fileURL: URL
    /// Length in bytes of the JSON header (excluding the 8-byte length prefix).
    private let headerSize: Int
    /// Parsed top-level JSON header, keyed by tensor name.
    private let rawHeader: [String: Any]
    /// Kept open for fast repeated reads; closed in `deinit`.
    private let fileHandle: FileHandle
    /// Serialises seek + read pairs on `fileHandle`.
    private let lock = NSLock()

    // ── Init ──────────────────────────────────────────

    /// Open a single .safetensors file and parse its header.
    ///
    /// - Parameter path: File system path of the `.safetensors` file.
    /// - Throws: Whatever `FileHandle(forReadingFrom:)` or `JSONSerialization`
    ///   throw, or `WeightError.invalidHeader` when the length prefix or header
    ///   is missing, truncated, implausibly large, or not a JSON dictionary.
    public init(path: String) throws {
        let url = URL(fileURLWithPath: path)
        let handle = try FileHandle(forReadingFrom: url)

        // Close the handle on every failure path, including errors thrown by
        // JSONSerialization, without repeating the call before each `throw`.
        var opened = false
        defer {
            if !opened { try? handle.close() }
        }

        // The file starts with the header length as a little-endian UInt64.
        let prefixLength = MemoryLayout<UInt64>.size
        let lenData = handle.readData(ofLength: prefixLength)
        guard lenData.count == prefixLength else {
            throw WeightError.invalidHeader("File is too short to contain the 8-byte header length")
        }
        let declaredSize = UInt64(
            littleEndian: lenData.withUnsafeBytes { $0.loadUnaligned(as: UInt64.self) }
        )
        guard let size = Int(exactly: declaredSize),
              size <= SafeTensorsReader.maxHeaderSize
        else {
            throw WeightError.invalidHeader("Header length \(declaredSize) exceeds the supported maximum")
        }

        let jsonData = handle.readData(ofLength: size)
        guard jsonData.count == size else {
            throw WeightError.invalidHeader("Header is truncated: expected \(size) bytes, got \(jsonData.count)")
        }
        guard let json = try JSONSerialization.jsonObject(with: jsonData) as? [String: Any] else {
            throw WeightError.invalidHeader("Top-level JSON is not a dictionary")
        }

        self.fileURL = url
        self.headerSize = size
        self.rawHeader = json
        self.fileHandle = handle
        opened = true
    }

    deinit {
        try? fileHandle.close()
    }

    // ── Tensor listing ────────────────────────────────

    /// All tensor descriptors in this file.
    ///
    /// Header entries that do not look like a tensor record (missing or
    /// unrecognised `dtype`, missing `shape`, or `data_offsets` that is not a
    /// two-element integer array) are skipped.
    public var allTensors: [TensorDescriptor] {
        rawHeader.compactMap { makeDescriptor(name: $0.key, entry: $0.value) }
    }

    /// All tensor descriptors (convenience).
    public func allDescriptors() -> [TensorDescriptor] { allTensors }

    /// Look up a specific tensor by name.
    ///
    /// - Returns: The descriptor, or `nil` when the header has no usable entry
    ///   under `name`.
    public func tensor(named name: String) -> TensorDescriptor? {
        // Header keys are unique, so a direct lookup yields the same result as
        // scanning every descriptor, without re-parsing the whole header.
        guard let entry = rawHeader[name] else { return nil }
        return makeDescriptor(name: name, entry: entry)
    }

    /// Build a descriptor from one header entry, or `nil` if the entry is not a
    /// well-formed tensor record.
    private func makeDescriptor(name: String, entry: Any) -> TensorDescriptor? {
        guard let info = entry as? [String: Any],
              let dtypeStr = info["dtype"] as? String,
              let dtype = TensorDType.from(dtype: dtypeStr),
              let shape = info["shape"] as? [Int],
              let offsets = info["data_offsets"] as? [Int],
              offsets.count == 2
        else { return nil }

        // `data_offsets` are relative to the end of the header, so shift them by
        // the 8-byte length prefix plus the header itself.
        return TensorDescriptor(
            name: name, dtype: dtype, shape: shape,
            dataOffset: headerSize + 8 + offsets[0],
            dataSize: offsets[1] - offsets[0]
        )
    }

    // ── Reading raw data ──────────────────────────────

    /// Read raw bytes for a tensor.
    ///
    /// - Throws: `WeightError.readFailed` when the descriptor's byte range is
    ///   negative or the file ends before the full payload could be read;
    ///   rethrows errors from seeking.
    public func read(tensor: TensorDescriptor) throws -> Data {
        // Converting a negative offset to UInt64 would trap, so reject it first.
        guard tensor.dataOffset >= 0, tensor.dataSize >= 0 else {
            throw WeightError.readFailed("Tensor '\(tensor.name)' has an invalid byte range")
        }

        lock.lock()
        defer { lock.unlock() }

        try fileHandle.seek(toOffset: UInt64(tensor.dataOffset))
        let payload = fileHandle.readData(ofLength: tensor.dataSize)
        // A short read means the file is truncated; returning it would hand the
        // caller silently corrupted weights.
        guard payload.count == tensor.dataSize else {
            throw WeightError.readFailed(
                "Tensor '\(tensor.name)': expected \(tensor.dataSize) bytes, got \(payload.count)"
            )
        }
        return payload
    }

    /// Read a specific tensor by name.
    ///
    /// - Throws: `WeightError.tensorNotFound` when no descriptor exists for
    ///   `name`, plus anything `read(tensor:)` throws.
    public func read(named name: String) throws -> Data {
        guard let desc = tensor(named: name) else {
            throw WeightError.tensorNotFound(name)
        }
        return try read(tensor: desc)
    }

    /// Read raw bytes for a tensor as uint32 array
    ///
    /// The bytes are reinterpreted in host byte order; trailing bytes that do
    /// not fill a whole `UInt32` are ignored.
    ///
    /// - Throws: Same errors as `read(named:)`.
    public func readUint32(named name: String) throws -> [UInt32] {
        let data = try read(named: name)
        return data.withUnsafeBytes { raw in
            Array(raw.bindMemory(to: UInt32.self))
        }
    }

    // ── BF16 → Float32 conversion ─────────────────────

    /// Convert BF16 binary data to Float32 array.
    ///
    /// Each little-endian 16-bit value becomes the upper half of a 32-bit float
    /// bit pattern. A trailing odd byte is ignored.
    ///
    /// - Parameter data: Packed BF16 values; may be a slice with any alignment.
    /// - Returns: One `Float` per complete 16-bit value in `data`.
    public static func bf16ToFloat32(_ data: Data) -> [Float] {
        let stride = MemoryLayout<UInt16>.size
        let count = data.count / stride
        return data.withUnsafeBytes { raw in
            (0..<count).map { index in
                // Unaligned load: a `Data` slice is not guaranteed to start on a
                // 2-byte boundary.
                let half = UInt16(
                    littleEndian: raw.loadUnaligned(fromByteOffset: index * stride, as: UInt16.self)
                )
                return Float(bitPattern: UInt32(half) << 16)
            }
        }
    }
}
// ── Errors ────────────────────────────────────────────
/// Failures that can occur while locating, parsing, or loading model weights.
public enum WeightError: Error, LocalizedError {
    /// The SafeTensors header could not be interpreted; payload explains why.
    case invalidHeader(String)
    /// No tensor with the given name exists.
    case tensorNotFound(String)
    /// The tensor uses a dtype that is not supported; payload is the dtype string.
    case unsupportedDtype(String)
    /// The file at the given path does not exist.
    case fileNotFound(String)
    /// Reading tensor data failed; payload explains why.
    case readFailed(String)
    /// A Metal buffer could not be created; payload identifies the buffer.
    case bufferCreationFailed(String)

    /// Human-readable message for the error, surfaced through `LocalizedError`.
    public var errorDescription: String? {
        let message: String
        switch self {
        case let .invalidHeader(reason):
            message = "Invalid SafeTensors header: \(reason)"
        case let .tensorNotFound(tensorName):
            message = "Tensor '\(tensorName)' not found"
        case let .unsupportedDtype(typeName):
            message = "Unsupported dtype: \(typeName)"
        case let .fileNotFound(filePath):
            message = "File not found: \(filePath)"
        case let .readFailed(reason):
            message = "Read failed: \(reason)"
        case let .bufferCreationFailed(bufferName):
            message = "Failed to create Metal buffer: \(bufferName)"
        }
        return message
    }
}