Files
markbaseengine/Sources/MarkBase/Engine.swift
T
MarkBase Admin 31427770b1
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Apply tokenizer UTF-8 fix + Engine writeFloats helper
- Tokenizer fix: collect <0xXX> bytes and decode as UTF-8
  (fixes Chinese/non-ASCII character decoding)
- BPETokenizer + HuggingFaceTokenizer: both updated
- Engine.swift: added writeFloats() utility method
- FloatWeights struct added to Layer.swift (bf16 support)
- attnQBits/KBits/VBits/OBits detection added to Model.swift
- bf16 layer weight support from commit 48c0347 cherry-picked
2026-07-05 13:41:48 +08:00

297 lines
12 KiB
Swift

@preconcurrency import Foundation
@preconcurrency import Metal
/// Errors thrown by the Metal engine.
///
/// Conforms to `LocalizedError` so each case carries a human-readable
/// message via `errorDescription`.
public enum E4BError: Error, LocalizedError {
    /// No Metal device could be obtained.
    case noMetalDevice
    /// No Metal library is available.
    case libraryNotFound
    /// Metal source compilation failed; the payload is the failure detail.
    case sourceCompilationFailed(String)
    /// Pipeline creation failed; the payload is the failure detail.
    case pipelineCreationFailed(String)
    /// The named kernel does not exist in the library; the payload is the kernel name.
    case pipelineNotFound(String)
    /// A Metal buffer could not be allocated.
    case bufferCreationFailed
    /// A command buffer could not be created.
    case commandBufferCreationFailed
    /// A command encoder could not be created.
    case commandEncoderCreationFailed

    /// Human-readable message for this error.
    public var errorDescription: String? {
        let message: String
        switch self {
        case .noMetalDevice:
            message = "No Metal-capable GPU found"
        case .libraryNotFound:
            message = "Metal library not found"
        case let .sourceCompilationFailed(reason):
            message = "Metal source compilation failed: \(reason)"
        case let .pipelineCreationFailed(reason):
            message = "Pipeline creation failed: \(reason)"
        case let .pipelineNotFound(kernel):
            message = "Kernel '\(kernel)' not found in library"
        case .bufferCreationFailed:
            message = "Failed to allocate Metal buffer"
        case .commandBufferCreationFailed:
            message = "Failed to create command buffer"
        case .commandEncoderCreationFailed:
            message = "Failed to create command encoder"
        }
        return message
    }
}
/// Pure-Swift Metal engine with pipeline cache, proper threadgroup sizing,
/// buffer pool for reuse, and support for both runtime source compilation and pre-compiled metallib.
///
/// The class is declared `@unchecked Sendable`. Mutations of `library` and the
/// pipeline cache made through this class are serialized by an internal lock;
/// everything else relies on the underlying Metal objects and `BufferPool`.
public final class MarkBaseEngine: @unchecked Sendable {
    public let device: MTLDevice
    public let commandQueue: MTLCommandQueue
    public let bufferPool: BufferPool
    /// The currently installed Metal library, if any.
    public private(set) var library: MTLLibrary?
    /// Pipeline states keyed by kernel function name; only valid for the current `library`.
    private var pipelineCache: [String: MTLComputePipelineState] = [:]
    /// Guards `library` and `pipelineCache`.
    private let stateLock = NSLock()

    /// Creates an engine on the system default Metal device, without a library.
    ///
    /// - Throws: `E4BError.noMetalDevice` when no default device exists;
    ///   `E4BError.bufferCreationFailed` when the command queue cannot be created.
    public init() throws {
        guard let device = MTLCreateSystemDefaultDevice() else {
            throw E4BError.noMetalDevice
        }
        self.device = device
        // NOTE(review): a command-queue failure is reported as `.bufferCreationFailed`
        // because `E4BError` has no dedicated case; kept unchanged for caller compatibility.
        guard let queue = device.makeCommandQueue() else {
            throw E4BError.bufferCreationFailed
        }
        self.commandQueue = queue
        self.bufferPool = BufferPool(device: device)
    }

    /// Initialize with Metal kernels auto-compiled from source.
    ///
    /// - Parameter autoCompile: When `true`, compiles
    ///   `MetalKernels.fullOptimizedSourceWithFusion` right after device setup.
    /// - Throws: Any error from `init()` or `compileSource(_:)`.
    public convenience init(autoCompile: Bool = true) throws {
        try self.init()
        if autoCompile {
            try compileSource(MetalKernels.fullOptimizedSourceWithFusion)
        }
    }

    // ── Library ──────────────────────────────────────

    /// Compile Metal source at runtime and make it the current library.
    ///
    /// The previous library and its cached pipelines are only replaced once
    /// compilation has succeeded, so a failed compile leaves the engine usable.
    ///
    /// - Throws: The error reported by `MTLDevice.makeLibrary(source:options:)`.
    public func compileSource(_ source: String) throws {
        let compiled = try device.makeLibrary(source: source, options: nil)
        installLibrary(compiled)
    }

    /// Load a pre-compiled .metallib from a file path.
    ///
    /// - Throws: File-read errors, or any error from `loadMetallib(data:)`.
    public func loadMetallib(path: String) throws {
        let data = try Data(contentsOf: URL(fileURLWithPath: path))
        try loadMetallib(data: data)
    }

    /// Load a pre-compiled .metallib from Data and make it the current library.
    ///
    /// As with `compileSource(_:)`, existing state is kept when loading fails.
    ///
    /// - Throws: The error reported by `MTLDevice.makeLibrary(data:)`.
    public func loadMetallib(data: Data) throws {
        let loaded: MTLLibrary = try data.withUnsafeBytes { rawPtr in
            try device.makeLibrary(data: DispatchData(bytes: rawPtr))
        }
        installLibrary(loaded)
    }

    /// Swaps in a new library and drops pipelines built from the old one.
    private func installLibrary(_ newLibrary: MTLLibrary) {
        stateLock.lock()
        defer { stateLock.unlock() }
        library = newLibrary
        pipelineCache.removeAll()
    }

    // ── Pipeline (with cache) ─────────────────────────

    /// Return a cached or newly-created compute pipeline state.
    ///
    /// - Parameter name: Kernel function name to look up in the current library.
    /// - Throws: `E4BError.libraryNotFound` when no library is installed,
    ///   `E4BError.pipelineNotFound` when the function is missing, or the error
    ///   from `MTLDevice.makeComputePipelineState(function:)`.
    public func pipeline(named name: String) throws -> MTLComputePipelineState {
        // The lock is held across creation so concurrent callers cannot race on
        // the cache dictionary or build the same pipeline twice.
        stateLock.lock()
        defer { stateLock.unlock() }

        if let cached = pipelineCache[name] { return cached }
        guard let lib = library else { throw E4BError.libraryNotFound }
        guard let function = lib.makeFunction(name: name) else {
            throw E4BError.pipelineNotFound(name)
        }
        let state = try device.makeComputePipelineState(function: function)
        pipelineCache[name] = state
        return state
    }

    /// Clear the pipeline cache (e.g. after recompiling the library).
    public func clearPipelineCache() {
        stateLock.lock()
        defer { stateLock.unlock() }
        pipelineCache.removeAll()
    }

    // ── Buffers ───────────────────────────────────────

    /// Acquire a buffer from the pool (or create new if none available)
    public func acquireBuffer(length: Int) -> MTLBuffer {
        return bufferPool.acquire(length: length)
    }

    /// Release a buffer back to the pool for reuse
    public func releaseBuffer(_ buffer: MTLBuffer) {
        bufferPool.release(buffer)
    }

    /// Create a shared-storage buffer initialized with a copy of `values`.
    ///
    /// - Throws: `E4BError.bufferCreationFailed` when `values` is empty (there
    ///   are no bytes to copy) or when the device cannot allocate the buffer.
    public func makeBuffer<T>(_ values: [T]) throws -> MTLBuffer {
        let byteCount = values.count * MemoryLayout<T>.stride
        let created = values.withUnsafeBytes { rawPtr -> MTLBuffer? in
            // An empty array may expose a nil base address; report that as an
            // allocation failure instead of trapping on a force-unwrap.
            guard byteCount > 0, let base = rawPtr.baseAddress else { return nil }
            return device.makeBuffer(bytes: base, length: byteCount,
                                     options: .storageModeShared)
        }
        guard let buffer = created else { throw E4BError.bufferCreationFailed }
        return buffer
    }

    /// Create an uninitialized shared-storage buffer of `length` bytes.
    ///
    /// - Throws: `E4BError.bufferCreationFailed` when allocation fails.
    public func makeBuffer(length: Int) throws -> MTLBuffer {
        guard let buffer = device.makeBuffer(length: length, options: .storageModeShared) else {
            throw E4BError.bufferCreationFailed
        }
        return buffer
    }

    // ── Threadgroup sizing ────────────────────────────

    /// 1D threadgroup size for a given pipeline: `count` capped at the
    /// pipeline's `maxTotalThreadsPerThreadgroup`, and never smaller than 1.
    public func threadgroupSize1D(_ pipeline: MTLComputePipelineState,
                                  count: Int) -> MTLSize {
        let width = max(1, min(pipeline.maxTotalThreadsPerThreadgroup, count))
        return MTLSize(width: width, height: 1, depth: 1)
    }

    /// 2D threadgroup size for a given pipeline: `threadExecutionWidth` wide and
    /// as tall as `maxTotalThreadsPerThreadgroup` allows, clamped to the grid
    /// and never smaller than 1 in either dimension.
    public func threadgroupSize2D(_ pipeline: MTLComputePipelineState,
                                  grid: (width: Int, height: Int)) -> MTLSize {
        let laneWidth = pipeline.threadExecutionWidth
        let rows = pipeline.maxTotalThreadsPerThreadgroup / laneWidth
        return MTLSize(
            width: max(1, min(laneWidth, grid.width)),
            height: max(1, min(rows, grid.height)),
            depth: 1
        )
    }

    // ── Encoding helpers ──────────────────────────────

    /// Binds `pipeline` and `buffers` (buffer `i` at index `i`, offset 0) and
    /// encodes one `dispatchThreads` call. Grids with a non-positive dimension
    /// are skipped so the command buffer simply carries no work for them.
    private func encodeDispatch(_ pipeline: MTLComputePipelineState,
                                buffers: [MTLBuffer],
                                grid: MTLSize,
                                threadgroup: MTLSize,
                                into encoder: MTLComputeCommandEncoder) {
        encoder.setComputePipelineState(pipeline)
        for (index, buffer) in buffers.enumerated() {
            encoder.setBuffer(buffer, offset: 0, index: index)
        }
        guard grid.width > 0, grid.height > 0, grid.depth > 0 else { return }
        encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    }

    /// Commits `commandBuffer` and suspends until its completion handler runs,
    /// rethrowing the command buffer's `error` if it reports one.
    private func commitAndAwaitCompletion(_ commandBuffer: MTLCommandBuffer) async throws {
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
            // The handler must be registered before commit().
            commandBuffer.addCompletedHandler { finished in
                if let failure = finished.error {
                    continuation.resume(throwing: failure)
                } else {
                    continuation.resume()
                }
            }
            commandBuffer.commit()
        }
    }

    // ── Synchronous dispatch ──────────────────────────

    /// Synchronous 1D dispatch with proper threadgroup sizing.
    ///
    /// Blocks until the command buffer completes.
    ///
    /// - Returns: The completed command buffer; inspect its `error`/`status`
    ///   to detect GPU-side failures.
    /// - Throws: `E4BError.commandBufferCreationFailed` or
    ///   `E4BError.commandEncoderCreationFailed`.
    @discardableResult
    public func dispatch1D(
        _ pipeline: MTLComputePipelineState,
        buffers: [MTLBuffer],
        count: Int
    ) throws -> MTLCommandBuffer {
        let (commandBuffer, encoder) = try makeBatchEncoder()
        encodeDispatch(pipeline,
                       buffers: buffers,
                       grid: MTLSize(width: count, height: 1, depth: 1),
                       threadgroup: threadgroupSize1D(pipeline, count: count),
                       into: encoder)
        encoder.endEncoding()
        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()
        return commandBuffer
    }

    /// Synchronous 2D dispatch with proper threadgroup sizing.
    ///
    /// Blocks until the command buffer completes.
    ///
    /// - Returns: The completed command buffer; inspect its `error`/`status`
    ///   to detect GPU-side failures.
    /// - Throws: `E4BError.commandBufferCreationFailed` or
    ///   `E4BError.commandEncoderCreationFailed`.
    @discardableResult
    public func dispatch2D(
        _ pipeline: MTLComputePipelineState,
        buffers: [MTLBuffer],
        grid: (width: Int, height: Int)
    ) throws -> MTLCommandBuffer {
        let (commandBuffer, encoder) = try makeBatchEncoder()
        encodeDispatch(pipeline,
                       buffers: buffers,
                       grid: MTLSize(width: grid.width, height: grid.height, depth: 1),
                       threadgroup: threadgroupSize2D(pipeline, grid: grid),
                       into: encoder)
        encoder.endEncoding()
        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()
        return commandBuffer
    }

    // ── Batch dispatch ────────────────────────────────

    /// Batch dispatch: execute multiple kernel operations in a single command buffer
    /// Returns after all GPU work completes
    ///
    /// Each operation is `(pipeline, buffers, grid)`. The threadgroup for every
    /// operation is sized in 1D from the grid's width.
    ///
    /// - Throws: `E4BError.commandBufferCreationFailed`,
    ///   `E4BError.commandEncoderCreationFailed`, or the command buffer's
    ///   `error` when execution fails.
    public func batchDispatch(_ operations: [(MTLComputePipelineState, [MTLBuffer], MTLSize)]) throws {
        let (commandBuffer, encoder) = try makeBatchEncoder()
        for (pso, buffers, gridSize) in operations {
            encodeDispatch(pso,
                           buffers: buffers,
                           grid: gridSize,
                           threadgroup: threadgroupSize1D(pso, count: gridSize.width),
                           into: encoder)
        }
        encoder.endEncoding()
        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()
        // This method returns nothing, so a failed command buffer would
        // otherwise go unnoticed by the caller.
        if let failure = commandBuffer.error {
            throw failure
        }
    }

    /// Create a batch encoder for manual use
    /// Caller is responsible for calling endEncoding() and commit/wait on returned command buffer
    ///
    /// - Throws: `E4BError.commandBufferCreationFailed` or
    ///   `E4BError.commandEncoderCreationFailed`.
    public func makeBatchEncoder() throws -> (MTLCommandBuffer, MTLComputeCommandEncoder) {
        guard let commandBuffer = commandQueue.makeCommandBuffer() else {
            throw E4BError.commandBufferCreationFailed
        }
        guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
            throw E4BError.commandEncoderCreationFailed
        }
        return (commandBuffer, encoder)
    }

    // ── Async dispatch ────────────────────────────────

    /// Asynchronous 1D dispatch with proper threadgroup sizing.
    /// Returns when GPU work completes; buffer contents are safe to read after.
    ///
    /// - Throws: `E4BError.commandBufferCreationFailed`,
    ///   `E4BError.commandEncoderCreationFailed`, or the command buffer's
    ///   `error` when execution fails.
    public func dispatch1DAsync(
        _ pipeline: MTLComputePipelineState,
        buffers: [MTLBuffer],
        count: Int
    ) async throws {
        let (commandBuffer, encoder) = try makeBatchEncoder()
        encodeDispatch(pipeline,
                       buffers: buffers,
                       grid: MTLSize(width: count, height: 1, depth: 1),
                       threadgroup: threadgroupSize1D(pipeline, count: count),
                       into: encoder)
        encoder.endEncoding()
        try await commitAndAwaitCompletion(commandBuffer)
    }

    /// Asynchronous 2D dispatch with proper threadgroup sizing.
    /// Returns when GPU work completes; buffer contents are safe to read after.
    ///
    /// - Throws: `E4BError.commandBufferCreationFailed`,
    ///   `E4BError.commandEncoderCreationFailed`, or the command buffer's
    ///   `error` when execution fails.
    public func dispatch2DAsync(
        _ pipeline: MTLComputePipelineState,
        buffers: [MTLBuffer],
        grid: (width: Int, height: Int)
    ) async throws {
        let (commandBuffer, encoder) = try makeBatchEncoder()
        encodeDispatch(pipeline,
                       buffers: buffers,
                       grid: MTLSize(width: grid.width, height: grid.height, depth: 1),
                       threadgroup: threadgroupSize2D(pipeline, grid: grid),
                       into: encoder)
        encoder.endEncoding()
        try await commitAndAwaitCompletion(commandBuffer)
    }

    // ── Read-back ─────────────────────────────────────

    /// Copy `count` floats out of `buffer`, starting `offset` floats from its start.
    ///
    /// - Parameters:
    ///   - offset: Start position measured in `Float` elements, not bytes.
    ///   - count: Number of `Float` elements to read.
    /// - Precondition: `count` and `offset` are non-negative and the requested
    ///   range lies within `buffer.length`.
    public func readFloats(from buffer: MTLBuffer, offset: Int = 0, count: Int) -> [Float] {
        precondition(count >= 0, "readFloats: count must be non-negative")
        guard count > 0 else { return [] }
        precondition(offset >= 0 && (offset + count) * MemoryLayout<Float>.stride <= buffer.length,
                     "readFloats: range exceeds buffer length")
        let base = buffer.contents().assumingMemoryBound(to: Float.self)
        return Array(UnsafeBufferPointer(start: base + offset, count: count))
    }

    /// Copy `values` into `buffer`, starting `offset` floats from its start.
    ///
    /// - Parameters:
    ///   - offset: Destination position measured in `Float` elements, not bytes.
    /// - Precondition: `offset` is non-negative and the written range lies
    ///   within `buffer.length` (not checked when `values` is empty).
    public func writeFloats(to buffer: MTLBuffer, values: [Float], offset: Int = 0) {
        guard !values.isEmpty else { return }
        let stride = MemoryLayout<Float>.stride
        precondition(offset >= 0 && (offset + values.count) * stride <= buffer.length,
                     "writeFloats: range exceeds buffer length")
        let destination = buffer.contents() + offset * stride
        // One bulk copy instead of a per-element store loop.
        values.withUnsafeBytes { source in
            guard let start = source.baseAddress else { return }
            destination.copyMemory(from: start, byteCount: source.count)
        }
    }
}