31427770b1
- Tokenizer fix: collect <0xXX> bytes and decode as UTF-8 (fixes Chinese/non-ASCII character decoding) - BPETokenizer + HuggingFaceTokenizer: both updated - Engine.swift: added writeFloats() utility method - FloatWeights struct added to Layer.swift (bf16 support) - attnQBits/KBits/VBits/OBits detection added to Model.swift - bf16 layer weight support from commit 48c0347 cherry-picked
297 lines
12 KiB
Swift
297 lines
12 KiB
Swift
@preconcurrency import Foundation
|
|
@preconcurrency import Metal
|
|
|
|
/// Failures reported by the Metal engine while acquiring a device, loading
/// kernels, building pipelines, or allocating and encoding GPU work.
public enum E4BError: Error, LocalizedError {
    /// No Metal device could be obtained.
    case noMetalDevice
    /// A pipeline was requested before any library was loaded.
    case libraryNotFound
    /// Runtime compilation of Metal source failed; carries a detail string.
    case sourceCompilationFailed(String)
    /// A compute pipeline state could not be built; carries a detail string.
    case pipelineCreationFailed(String)
    /// The named kernel function is missing from the loaded library.
    case pipelineNotFound(String)
    /// A Metal buffer could not be allocated.
    case bufferCreationFailed
    /// The command queue did not hand out a command buffer.
    case commandBufferCreationFailed
    /// The command buffer did not hand out a compute encoder.
    case commandEncoderCreationFailed

    /// Human-readable message for each case, surfaced through `LocalizedError`.
    public var errorDescription: String? {
        switch self {
        case .noMetalDevice:
            return "No Metal-capable GPU found"
        case .libraryNotFound:
            return "Metal library not found"
        case let .sourceCompilationFailed(detail):
            return "Metal source compilation failed: \(detail)"
        case let .pipelineCreationFailed(detail):
            return "Pipeline creation failed: \(detail)"
        case let .pipelineNotFound(name):
            return "Kernel '\(name)' not found in library"
        case .bufferCreationFailed:
            return "Failed to allocate Metal buffer"
        case .commandBufferCreationFailed:
            return "Failed to create command buffer"
        case .commandEncoderCreationFailed:
            return "Failed to create command encoder"
        }
    }
}
|
|
|
|
/// Pure-Swift Metal engine with pipeline cache, proper threadgroup sizing,
|
|
/// buffer pool for reuse, and support for both runtime source compilation and pre-compiled metallib.
|
|
public final class MarkBaseEngine: @unchecked Sendable {
|
|
public let device: MTLDevice
|
|
public let commandQueue: MTLCommandQueue
|
|
public let bufferPool: BufferPool
|
|
public private(set) var library: MTLLibrary?
|
|
private var pipelineCache: [String: MTLComputePipelineState] = [:]
|
|
|
|
/// Creates an engine bound to the system default Metal device.
///
/// Sets up the command queue and the buffer pool; no shader library is loaded.
///
/// - Throws: `E4BError.noMetalDevice` when no default device is available, and
///   `E4BError.bufferCreationFailed` when the command queue cannot be created
///   (the error enum has no queue-specific case, so this one is reused).
public init() throws {
    guard let defaultDevice = MTLCreateSystemDefaultDevice() else {
        throw E4BError.noMetalDevice
    }
    guard let queue = defaultDevice.makeCommandQueue() else {
        throw E4BError.bufferCreationFailed
    }
    device = defaultDevice
    commandQueue = queue
    bufferPool = BufferPool(device: defaultDevice)
}
|
|
|
|
/// Creates an engine and, unless told otherwise, compiles the built-in kernel source.
///
/// - Parameter autoCompile: When `true` (the default), compiles
///   `MetalKernels.fullOptimizedSourceWithFusion` right after the base setup
///   (per the original note, the original, optimized, and fusion kernels).
///   When `false`, the engine starts without a library.
/// - Throws: Any error from `init()` or `compileSource(_:)`.
public convenience init(autoCompile: Bool = true) throws {
    try self.init()
    guard autoCompile else { return }
    try compileSource(MetalKernels.fullOptimizedSourceWithFusion)
}
|
|
|
|
// ── Library ──────────────────────────────────────
|
|
|
|
/// Compiles Metal source at runtime and installs the result as the active library.
///
/// Compilation happens first and engine state is only modified on success,
/// so a failed compile leaves the previous library and its cached pipelines
/// intact instead of discarding a still-valid cache.
///
/// - Parameter source: Metal shading-language source text.
/// - Throws: The error raised by `MTLDevice.makeLibrary(source:options:)`.
public func compileSource(_ source: String) throws {
    let compiled = try device.makeLibrary(source: source, options: nil)
    // Cached pipelines were built from the previous library's functions,
    // so they are dropped together with it.
    pipelineCache.removeAll()
    library = compiled
}
|
|
|
|
/// Reads a pre-compiled `.metallib` from disk and installs it as the active library.
///
/// - Parameter path: File-system path of the `.metallib` file.
/// - Throws: The file-read error from `Data(contentsOf:)`, or any error from
///   `loadMetallib(data:)`.
public func loadMetallib(path: String) throws {
    let fileURL = URL(fileURLWithPath: path)
    let contents = try Data(contentsOf: fileURL)
    try loadMetallib(data: contents)
}
|
|
|
|
/// Installs a pre-compiled `.metallib`, supplied as in-memory bytes, as the active library.
///
/// The library is created first and engine state is only modified on
/// success, so a failed load leaves the previous library and its cached
/// pipelines intact.
///
/// - Parameter data: Raw contents of a `.metallib` file.
/// - Throws: The error raised by `MTLDevice.makeLibrary(data:)`.
public func loadMetallib(data: Data) throws {
    let loaded = try data.withUnsafeBytes { rawPtr in
        try device.makeLibrary(data: DispatchData(bytes: rawPtr))
    }
    // Cached pipelines belong to the previous library; drop them with it.
    pipelineCache.removeAll()
    library = loaded
}
|
|
|
|
// ── Pipeline (with cache) ─────────────────────────
|
|
|
|
/// Looks up the compute pipeline for a kernel, building and caching it on first use.
///
/// - Parameter name: Name of the kernel function in the loaded library.
/// - Returns: The cached pipeline state, or a newly created one.
/// - Throws: `E4BError.libraryNotFound` when no library is loaded,
///   `E4BError.pipelineNotFound` when the library has no function with that
///   name, or the error from `makeComputePipelineState(function:)`.
///
/// NOTE(review): the cache is a plain dictionary with no synchronization
/// visible here — confirm callers do not invoke this concurrently.
public func pipeline(named name: String) throws -> MTLComputePipelineState {
    if let hit = pipelineCache[name] {
        return hit
    }
    guard let activeLibrary = library else {
        throw E4BError.libraryNotFound
    }
    guard let kernel = activeLibrary.makeFunction(name: name) else {
        throw E4BError.pipelineNotFound(name)
    }
    let state = try device.makeComputePipelineState(function: kernel)
    pipelineCache[name] = state
    return state
}
|
|
|
|
/// Drops every cached pipeline state (e.g. after recompiling the library).
public func clearPipelineCache() {
    pipelineCache.removeAll(keepingCapacity: false)
}
|
|
|
|
// ── Buffers ───────────────────────────────────────
|
|
|
|
/// Obtains a buffer of the requested length from the engine's buffer pool.
///
/// - Parameter length: Requested length, forwarded unchanged to the pool.
/// - Returns: Whatever `bufferPool.acquire(length:)` hands back.
public func acquireBuffer(length: Int) -> MTLBuffer {
    bufferPool.acquire(length: length)
}
|
|
|
|
/// Hands a buffer back to the engine's buffer pool so it can be reused.
///
/// - Parameter buffer: The buffer to return; forwarded to `bufferPool.release(_:)`.
public func releaseBuffer(_ buffer: MTLBuffer) { bufferPool.release(buffer) }
|
|
|
|
/// Creates a shared-storage buffer initialized with a copy of the array's bytes.
///
/// - Parameter values: Elements to copy; the buffer length is
///   `values.count * MemoryLayout<T>.stride`. Presumably `T` is a plain
///   value type whose in-memory bytes are meaningful to the GPU — confirm
///   at call sites.
/// - Returns: The new buffer.
/// - Throws: `E4BError.bufferCreationFailed` when `values` is empty or the
///   device does not return a buffer.
public func makeBuffer<T>(_ values: [T]) throws -> MTLBuffer {
    let byteCount = values.count * MemoryLayout<T>.stride
    let created = values.withUnsafeBytes { bytes -> MTLBuffer? in
        // An empty array may expose a nil base address; report that as a
        // creation failure instead of trapping on a force-unwrap.
        guard byteCount > 0, let base = bytes.baseAddress else { return nil }
        return device.makeBuffer(bytes: base, length: byteCount,
                                 options: .storageModeShared)
    }
    guard let buffer = created else { throw E4BError.bufferCreationFailed }
    return buffer
}
|
|
|
|
/// Allocates a shared-storage buffer of the given length.
///
/// - Parameter length: Buffer length passed to `MTLDevice.makeBuffer(length:options:)`.
/// - Returns: The new buffer.
/// - Throws: `E4BError.bufferCreationFailed` when the device returns no buffer.
public func makeBuffer(length: Int) throws -> MTLBuffer {
    if let allocated = device.makeBuffer(length: length, options: .storageModeShared) {
        return allocated
    }
    throw E4BError.bufferCreationFailed
}
|
|
|
|
// ── Threadgroup sizing ────────────────────────────
|
|
|
|
/// Picks a 1D threadgroup size for a pipeline and a thread count.
///
/// The width is the smaller of the pipeline's
/// `maxTotalThreadsPerThreadgroup` and `count`, clamped to at least 1 so
/// that a zero or negative `count` never yields a zero-sized threadgroup.
///
/// - Parameters:
///   - pipeline: Pipeline whose thread limit bounds the width.
///   - count: Number of threads the caller intends to dispatch.
/// - Returns: A size of `width × 1 × 1`.
public func threadgroupSize1D(_ pipeline: MTLComputePipelineState,
                              count: Int) -> MTLSize {
    let width = max(1, min(pipeline.maxTotalThreadsPerThreadgroup, count))
    return MTLSize(width: width, height: 1, depth: 1)
}
|
|
|
|
/// Picks a 2D threadgroup size for a pipeline and a grid.
///
/// The width starts from the pipeline's `threadExecutionWidth` and the
/// height from `maxTotalThreadsPerThreadgroup / threadExecutionWidth`; each
/// is capped by the matching grid dimension and clamped to at least 1 so an
/// empty grid never yields a zero-sized threadgroup.
///
/// - Parameters:
///   - pipeline: Pipeline whose execution width and thread limit are used.
///   - grid: Grid dimensions the caller intends to dispatch.
/// - Returns: A size of `width × height × 1`.
public func threadgroupSize2D(_ pipeline: MTLComputePipelineState,
                              grid: (width: Int, height: Int)) -> MTLSize {
    let executionWidth = pipeline.threadExecutionWidth
    let rows = pipeline.maxTotalThreadsPerThreadgroup / executionWidth
    return MTLSize(
        width: max(1, min(executionWidth, grid.width)),
        height: max(1, min(rows, grid.height)),
        depth: 1
    )
}
|
|
|
|
// ── Synchronous dispatch ──────────────────────────
|
|
|
|
/// Encodes a single 1D compute dispatch, commits it, and blocks until it completes.
///
/// - Parameters:
///   - pipeline: Compute pipeline to run.
///   - buffers: Buffers bound at indices `0..<buffers.count`, each with offset 0.
///   - count: Number of threads in the grid.
/// - Returns: The completed command buffer; any error recorded on it is not
///   inspected here.
/// - Throws: `E4BError.commandBufferCreationFailed` or
///   `E4BError.commandEncoderCreationFailed`.
@discardableResult
public func dispatch1D(
    _ pipeline: MTLComputePipelineState,
    buffers: [MTLBuffer],
    count: Int
) throws -> MTLCommandBuffer {
    guard let commandBuffer = commandQueue.makeCommandBuffer() else {
        throw E4BError.commandBufferCreationFailed
    }
    guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
        throw E4BError.commandEncoderCreationFailed
    }

    encoder.setComputePipelineState(pipeline)
    for slot in buffers.indices {
        encoder.setBuffer(buffers[slot], offset: 0, index: slot)
    }
    let gridSize = MTLSize(width: count, height: 1, depth: 1)
    let groupSize = threadgroupSize1D(pipeline, count: count)
    encoder.dispatchThreads(gridSize, threadsPerThreadgroup: groupSize)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
    return commandBuffer
}
|
|
|
|
/// Encodes a single 2D compute dispatch, commits it, and blocks until it completes.
///
/// - Parameters:
///   - pipeline: Compute pipeline to run.
///   - buffers: Buffers bound at indices `0..<buffers.count`, each with offset 0.
///   - grid: Width and height of the thread grid (depth is 1).
/// - Returns: The completed command buffer; any error recorded on it is not
///   inspected here.
/// - Throws: `E4BError.commandBufferCreationFailed` or
///   `E4BError.commandEncoderCreationFailed`.
@discardableResult
public func dispatch2D(
    _ pipeline: MTLComputePipelineState,
    buffers: [MTLBuffer],
    grid: (width: Int, height: Int)
) throws -> MTLCommandBuffer {
    guard let commandBuffer = commandQueue.makeCommandBuffer() else {
        throw E4BError.commandBufferCreationFailed
    }
    guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
        throw E4BError.commandEncoderCreationFailed
    }

    encoder.setComputePipelineState(pipeline)
    for slot in buffers.indices {
        encoder.setBuffer(buffers[slot], offset: 0, index: slot)
    }
    let gridSize = MTLSize(width: grid.width, height: grid.height, depth: 1)
    let groupSize = threadgroupSize2D(pipeline, grid: grid)
    encoder.dispatchThreads(gridSize, threadsPerThreadgroup: groupSize)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
    return commandBuffer
}
|
|
|
|
// ── Batch & async dispatch ────────────────────────
|
|
|
|
/// Runs several kernel dispatches in one command buffer and blocks until the GPU finishes.
///
/// Each operation is a `(pipeline, buffers, gridSize)` tuple. Buffers are
/// bound at indices `0..<buffers.count` with offset 0. The threadgroup size
/// for every operation comes from `threadgroupSize1D` using only the grid's
/// width, even when the grid has a height or depth greater than 1.
///
/// - Parameter operations: Dispatches to encode, in order.
/// - Throws: `E4BError.commandBufferCreationFailed`,
///   `E4BError.commandEncoderCreationFailed`, or the command buffer's own
///   error when execution fails. Since this method returns nothing, throwing
///   is the only way a caller can learn that the GPU work did not succeed.
public func batchDispatch(_ operations: [(MTLComputePipelineState, [MTLBuffer], MTLSize)]) throws {
    guard let cmdBuf = commandQueue.makeCommandBuffer() else {
        throw E4BError.commandBufferCreationFailed
    }
    guard let enc = cmdBuf.makeComputeCommandEncoder() else {
        throw E4BError.commandEncoderCreationFailed
    }

    for (pso, buffers, gridSize) in operations {
        enc.setComputePipelineState(pso)
        for (i, buf) in buffers.enumerated() {
            enc.setBuffer(buf, offset: 0, index: i)
        }
        let groupSize = threadgroupSize1D(pso, count: Int(gridSize.width))
        enc.dispatchThreads(gridSize, threadsPerThreadgroup: groupSize)
    }

    enc.endEncoding()
    cmdBuf.commit()
    cmdBuf.waitUntilCompleted()

    // Surface GPU-side failures instead of returning as if the work succeeded.
    if let failure = cmdBuf.error {
        throw failure
    }
}
|
|
|
|
/// Creates a command buffer and a compute encoder on it for manual encoding.
///
/// The caller must call `endEncoding()` on the encoder and then commit (and,
/// if needed, wait on) the returned command buffer.
///
/// - Returns: The new command buffer and its compute encoder.
/// - Throws: `E4BError.commandBufferCreationFailed` or
///   `E4BError.commandEncoderCreationFailed`.
public func makeBatchEncoder() throws -> (MTLCommandBuffer, MTLComputeCommandEncoder) {
    guard let commandBuffer = commandQueue.makeCommandBuffer() else {
        throw E4BError.commandBufferCreationFailed
    }
    guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
        throw E4BError.commandEncoderCreationFailed
    }
    return (commandBuffer, encoder)
}
|
|
|
|
/// Encodes a single 1D compute dispatch and suspends until the GPU work completes.
///
/// On a normal return the command buffer finished without error. If it
/// finished with an error, that error is thrown instead of being dropped,
/// so callers do not read results from a failed dispatch.
///
/// - Parameters:
///   - pipeline: Compute pipeline to run.
///   - buffers: Buffers bound at indices `0..<buffers.count`, each with offset 0.
///   - count: Number of threads in the grid.
/// - Throws: `E4BError.commandBufferCreationFailed`,
///   `E4BError.commandEncoderCreationFailed`, or the command buffer's error.
public func dispatch1DAsync(
    _ pipeline: MTLComputePipelineState,
    buffers: [MTLBuffer],
    count: Int
) async throws {
    guard let cmdBuf = commandQueue.makeCommandBuffer() else {
        throw E4BError.commandBufferCreationFailed
    }
    guard let enc = cmdBuf.makeComputeCommandEncoder() else {
        throw E4BError.commandEncoderCreationFailed
    }
    let tg = threadgroupSize1D(pipeline, count: count)
    enc.setComputePipelineState(pipeline)
    for (i, buf) in buffers.enumerated() { enc.setBuffer(buf, offset: 0, index: i) }
    enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                        threadsPerThreadgroup: tg)
    enc.endEncoding()

    try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
        // The handler is registered before commit and resumes exactly once.
        cmdBuf.addCompletedHandler { finished in
            if let failure = finished.error {
                continuation.resume(throwing: failure)
            } else {
                continuation.resume()
            }
        }
        cmdBuf.commit()
    }
}
|
|
|
|
/// Encodes a single 2D compute dispatch and suspends until the GPU work completes.
///
/// On a normal return the command buffer finished without error. If it
/// finished with an error, that error is thrown instead of being dropped,
/// so callers do not read results from a failed dispatch.
///
/// - Parameters:
///   - pipeline: Compute pipeline to run.
///   - buffers: Buffers bound at indices `0..<buffers.count`, each with offset 0.
///   - grid: Width and height of the thread grid (depth is 1).
/// - Throws: `E4BError.commandBufferCreationFailed`,
///   `E4BError.commandEncoderCreationFailed`, or the command buffer's error.
public func dispatch2DAsync(
    _ pipeline: MTLComputePipelineState,
    buffers: [MTLBuffer],
    grid: (width: Int, height: Int)
) async throws {
    guard let cmdBuf = commandQueue.makeCommandBuffer() else {
        throw E4BError.commandBufferCreationFailed
    }
    guard let enc = cmdBuf.makeComputeCommandEncoder() else {
        throw E4BError.commandEncoderCreationFailed
    }
    let tg = threadgroupSize2D(pipeline, grid: grid)
    enc.setComputePipelineState(pipeline)
    for (i, buf) in buffers.enumerated() { enc.setBuffer(buf, offset: 0, index: i) }
    enc.dispatchThreads(MTLSize(width: grid.width, height: grid.height, depth: 1),
                        threadsPerThreadgroup: tg)
    enc.endEncoding()

    try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
        // The handler is registered before commit and resumes exactly once.
        cmdBuf.addCompletedHandler { finished in
            if let failure = finished.error {
                continuation.resume(throwing: failure)
            } else {
                continuation.resume()
            }
        }
        cmdBuf.commit()
    }
}
|
|
|
|
// ── Read-back & upload ────────────────────────────
|
|
|
|
/// Copies `Float` values out of a buffer's CPU-visible contents.
///
/// - Parameters:
///   - buffer: Buffer to read from.
///   - offset: Index of the first element to read, counted in `Float`
///     elements (not bytes).
///   - count: Number of elements to read.
/// - Returns: An array holding the `count` elements starting at `offset`.
/// - Precondition: `offset` and `count` are non-negative and the range fits
///   within `buffer.length`; an out-of-range request traps instead of
///   reading past the allocation.
public func readFloats(from buffer: MTLBuffer, offset: Int = 0, count: Int) -> [Float] {
    let capacity = buffer.length / MemoryLayout<Float>.stride
    precondition(offset >= 0 && count >= 0,
                 "readFloats: offset and count must be non-negative")
    precondition(offset <= capacity && count <= capacity - offset,
                 "readFloats: range exceeds buffer length")
    guard count > 0 else { return [] }
    let base = buffer.contents().assumingMemoryBound(to: Float.self)
    return Array(UnsafeBufferPointer(start: base + offset, count: count))
}
|
|
|
|
/// Copies `Float` values into a buffer's CPU-visible contents.
///
/// - Parameters:
///   - buffer: Destination buffer.
///   - values: Elements to write.
///   - offset: Index of the first destination element, counted in `Float`
///     elements (not bytes).
/// - Precondition: `offset` is non-negative and the written range fits
///   within `buffer.length`; an out-of-range request traps instead of
///   writing past the allocation.
public func writeFloats(to buffer: MTLBuffer, values: [Float], offset: Int = 0) {
    let stride = MemoryLayout<Float>.stride
    let capacity = buffer.length / stride
    precondition(offset >= 0 && offset <= capacity && values.count <= capacity - offset,
                 "writeFloats: range exceeds buffer length")
    values.withUnsafeBytes { source in
        // Nothing to copy for an empty array (its base address may be nil).
        guard let sourceBase = source.baseAddress, source.count > 0 else { return }
        buffer.contents()
            .advanced(by: offset * stride)
            .copyMemory(from: sourceBase, byteCount: source.count)
    }
}
|
|
}
|