Files
markbaseengine/Sources/MarkBase/Model.swift
T
MarkBase Admin 239474bef0
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: fix 26B activation explosion — normalize groupSize=32 scales, fix hardcoded loops
2026-07-05 19:52:47 +08:00

1818 lines
85 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import Foundation
import Metal
// ═══════════════════════════════════════════════════
// E4B Text Model — 42-layer forward pass
// ═══════════════════════════════════════════════════
/// Builds the KV-sharing map: for each shared layer, which non-shared layer's cache it reads.
///
/// Layers `0..<(numHiddenLayers - numKVShared)` own their KV caches ("concrete" layers).
/// Every later layer is paired with a concrete layer of the same attention type
/// (full vs. sliding), matching in reverse order. The first shared layer of a type reads
/// from the last concrete layer of that type, the second from the second-to-last, and so on.
/// A shared layer gets no entry once its type has run out of concrete layers.
///
/// - Parameters:
///   - numHiddenLayers: Total number of decoder layers.
///   - numKVShared: Number of trailing layers that reuse another layer's cache. Values
///     outside `0...numHiddenLayers` are clamped instead of trapping on an inverted range.
///   - layerTypesIsFull: Attention type per layer (`true` = full attention). Layers past the
///     end of this array are ignored rather than indexed out of bounds.
/// - Returns: Map from shared layer index to the concrete layer index whose cache it reads.
///
/// NOTE(review): Hugging Face's Gemma3n reference appears to map *every* shared layer to the
/// last concrete layer of its type, instead of stepping backwards as done here. Confirm which
/// pairing the target checkpoints expect.
func computeKVSourceMap(numHiddenLayers: Int, numKVShared: Int, layerTypesIsFull: [Bool]) -> [Int: Int] {
    // Only layers with a known type can take part. Clamp so neither range below is inverted
    // (an inverted `a..<b` traps at runtime).
    let layerCount = max(0, min(numHiddenLayers, layerTypesIsFull.count))
    let firstShared = min(max(numHiddenLayers - numKVShared, 0), layerCount)

    // Concrete layer indices grouped by attention type, each list in ascending order.
    var sources = Dictionary(grouping: 0..<firstShared, by: { layerTypesIsFull[$0] })

    var map: [Int: Int] = [:]
    for layer in firstShared..<layerCount {
        // Consuming from the back pairs the k-th shared layer of a type with the
        // k-th-from-last concrete layer of that type. Yields nil once that list is exhausted.
        if let source = sources[layerTypesIsFull[layer]]?.popLast() {
            map[layer] = source
        }
    }
    return map
}
public final class E4BModel: @unchecked Sendable {
public let engine: MarkBaseEngine
public let config: ModelConfig
public let layers: [E4BLayer]
public let kvCaches: [KVCache]
public let temps: ForwardTemps
public let embedWeight: QuantizedWeights
public let embedTokensPerLayerWeight: QuantizedWeights?
public let perLayerModelProjection: MTLBuffer? // Context-aware projection (BF16 as F32)
public let perLayerProjectionNorm: MTLBuffer? // Norm for context projection
public let perLayerInputSize: Int
public let perLayerModelProjectionOutDim: Int // 10752 = numLayers * perLayerSize
// Convenience
public let numHiddenLayers: Int
public let hiddenSize: Int
public let vocabSize: Int
public let firstKVShared: Int
public let numKvShared: Int
public let kvSourceMap: [Int: Int]
public let rmsNormEps: Float
public let finalLogitSoftcapping: Float?
public let embedScale: Float
public let perLayerInputScaleVal: Float
public let perLayerModelProjectionScaleVal: Float // 1/sqrt(hiddenSize)
public let layerTypesIsFull: [Bool] // true = full_attention, false = sliding_attention
public let maxContext: Int
// LM head output buffer
let logitsBuffer: MTLBuffer
// Per-layer embedding buffer (for current token)
let perLayerEmbedBuffer: MTLBuffer?
let perLayerContextBuffer: MTLBuffer? // Context-aware projection result
// Final RMSNorm (before LM head)
let finalNorm: MTLBuffer?
// ── Init ──────────────────────────────────────────
public init(modelDir: String, engine: MarkBaseEngine, maxContextLength: Int = 8192) throws {
self.engine = engine
self.maxContext = maxContextLength
// Load config
let cfg = try ModelConfig.load(from: modelDir)
self.config = cfg
self.numHiddenLayers = cfg.numHiddenLayers ?? 42
self.hiddenSize = cfg.hiddenSize ?? 2560
self.vocabSize = cfg.vocabSize ?? 262144
self.rmsNormEps = cfg.rmsNormEps ?? 1e-6
self.finalLogitSoftcapping = cfg.finalLogitSoftcapping
self.embedScale = cfg.embedScale ?? sqrt(Float(cfg.hiddenSize ?? 2560))
self.numKvShared = cfg.numKvSharedLayers ?? 0 // 0 = no KV sharing (default for models without this config)
self.firstKVShared = numHiddenLayers - numKvShared
let pattern = cfg.slidingWindowPattern ?? 6 // Gemma 4: 5 sliding + 1 full = pattern 6
self.perLayerInputScaleVal = cfg.perLayerInputScale ?? sqrt(0.5)
// Determine layer types from config, or derive from pattern
if let lt = cfg.layerTypes {
self.layerTypesIsFull = lt.map { $0 == "full_attention" }
} else {
var derived: [Bool] = []
for i in 0..<numHiddenLayers {
derived.append((i % pattern) == (pattern - 1))
}
self.layerTypesIsFull = derived
}
self.kvSourceMap = computeKVSourceMap(
numHiddenLayers: numHiddenLayers,
numKVShared: numKvShared,
layerTypesIsFull: layerTypesIsFull
)
// Load all tensors (support both single-file and sharded)
let singleFile = "\(modelDir)/model.safetensors"
let index: SafeTensorsIndex?
let readers: [String: SafeTensorsReader] // shard file -> reader
if FileManager.default.fileExists(atPath: singleFile) {
// Single-file model (E4B)
index = nil
let reader = try SafeTensorsReader(path: singleFile)
readers = ["model.safetensors": reader]
} else {
// Sharded model (12B): load all shards in parallel
let indexPath = "\(modelDir)/model.safetensors.index.json"
if !FileManager.default.fileExists(atPath: indexPath) {
throw WeightError.readFailed("No model.safetensors or index.json found")
}
// Load index
index = try SafeTensorsIndex(modelDir: modelDir)
// Parallel shard loading (critical optimization for large models)
print("Loading \(index!.shardFiles.count) shards in parallel...")
let loadStart = Date()
let shardFiles = index!.shardFiles.sorted()
var loadedReaders: [SafeTensorsReader?] = Array(repeating: nil, count: shardFiles.count)
var loadErrors: [Error?] = Array(repeating: nil, count: shardFiles.count)
// Use DispatchGroup for parallel loading (thread-safe array access)
let dispatchGroup = DispatchGroup()
let queue = DispatchQueue(label: "shard-loading", attributes: .concurrent)
for (idx, shardFile) in shardFiles.enumerated() {
dispatchGroup.enter()
queue.async {
do {
let shardPath = "\(modelDir)/\(shardFile)"
let reader = try SafeTensorsReader(path: shardPath)
loadedReaders[idx] = reader // Thread-safe: each thread writes to different index
} catch {
loadErrors[idx] = error
}
dispatchGroup.leave()
}
}
dispatchGroup.wait()
// Check for errors and build dictionary (sequential, thread-safe)
var readersDict: [String: SafeTensorsReader] = [:]
for (idx, error) in loadErrors.enumerated() {
if let err = error {
throw WeightError.readFailed("Failed to load shard \(shardFiles[idx]): \(err)")
}
if let reader = loadedReaders[idx] {
readersDict[shardFiles[idx]] = reader
}
}
let loadTime = Date().timeIntervalSince(loadStart) * 1000
print("✓ Parallel loaded \(readersDict.count) shards in \(String(format: "%.1f", loadTime))ms")
print(" Shards: \(shardFiles)")
readers = readersDict
}
// Helper functions for unified tensor access
func getReader(forTensor name: String) -> SafeTensorsReader? {
if let idx = index {
guard let shardFile = idx.weightMap[name] else { return nil }
return readers[shardFile]
} else {
return readers["model.safetensors"]
}
}
func getAllTensorDescriptors() -> [TensorDescriptor] {
if index != nil {
// Sharded: collect all tensors from all shards
var all: [TensorDescriptor] = []
for reader in readers.values {
all.append(contentsOf: reader.allTensors)
}
return all
} else {
// Single file
return readers["model.safetensors"]!.allTensors
}
}
let allTensors = getAllTensorDescriptors()
print("✓ Total tensors: \(allTensors.count)")
// E4B MLX models use "language_model.model." prefix, but converted models may omit it
// Detect which format by checking for "layers.0.self_attn.q_proj.weight"
let P: String
if allTensors.contains(where: { $0.name == "layers.0.self_attn.q_proj.weight" }) {
P = ""
print(" Using short prefix (no language_model.model.)")
} else {
P = "language_model.model."
print(" Using long prefix (language_model.model.)")
}
// Config values may be wrong (e.g. 26B-standard has nHeads=8 but q_proj has 16 heads).
var effectiveNHeads = cfg.numAttentionHeads ?? 16
var effectiveNKvHeads = cfg.numKeyValueHeads ?? 8
var effectiveGlobalKvHeads = cfg.numGlobalKeyValueHeads
let slidingHd = cfg.slidingHeadDim ?? cfg.headDim ?? 256
let globalHd = cfg.globalHeadDim ?? cfg.headDim ?? 512
if let qDesc = allTensors.first(where: { $0.name.contains("language_model") && $0.name.hasSuffix("self_attn.q_proj.weight") }) {
let qOut = qDesc.shape[0]
if qOut > 0 {
let detected = qOut % slidingHd == 0 ? qOut / slidingHd : qOut / globalHd
if detected != effectiveNHeads {
print(" ⚠ q_proj out_dim=\(qOut) → nHeads=\(detected) (config says \(effectiveNHeads))")
effectiveNHeads = detected
}
}
}
if let kDesc = allTensors.first(where: { $0.name.contains("language_model") && $0.name.hasSuffix("self_attn.k_proj.weight") }) {
let kOut = kDesc.shape[0]
if kOut > 0 && kOut % slidingHd == 0 {
let detected = kOut / slidingHd
if detected != effectiveNKvHeads {
print(" ⚠ k_proj out_dim=\(kOut), head_dim=\(slidingHd) → nKvHeads=\(detected) (config says \(effectiveNKvHeads))")
effectiveNKvHeads = detected
}
}
}
// Detect global kv heads from first full attention layer
// Also detect globalHeadDim from k_norm.weight shape
var detectedGlobalHd: Int? = nil
if let firstFullIdx = layerTypesIsFull.firstIndex(of: true) {
// Detect globalHeadDim from k_norm.weight shape
let kNormName = "\(P)layers.\(firstFullIdx).self_attn.k_norm.weight"
if let kNormDesc = allTensors.first(where: { $0.name == kNormName }) {
let kNormShape = kNormDesc.shape[0]
if kNormShape > 0 {
detectedGlobalHd = kNormShape
print(" ⚠ Detected globalHeadDim=\(kNormShape) from k_norm.weight (config: \(cfg.globalHeadDim ?? -1))")
}
}
// Detect globalKvHeads from k_proj.weight
let fullKName = "\(P)layers.\(firstFullIdx).self_attn.k_proj.weight"
if let fullKDesc = allTensors.first(where: { $0.name == fullKName }) {
let kOut = fullKDesc.shape[0]
// Use detectedGlobalHd if available, otherwise fallback
let actualGlobalHd = detectedGlobalHd ?? globalHd
if kOut > 0 && kOut % actualGlobalHd == 0 {
let detected = kOut / actualGlobalHd
if effectiveGlobalKvHeads == nil || detected != effectiveGlobalKvHeads {
print(" ⚠ (full) k_proj out_dim=\(kOut), global_head_dim=\(actualGlobalHd) → globalKvHeads=\(detected) (config: \(effectiveGlobalKvHeads.map(String.init) ?? "nil"))")
effectiveGlobalKvHeads = detected
}
}
}
}
if effectiveNHeads != (cfg.numAttentionHeads ?? 16) || effectiveNKvHeads != (cfg.numKeyValueHeads ?? 8) {
print(" → Using effective: nHeads=\(effectiveNHeads), nKvHeads=\(effectiveNKvHeads), globalKvHeads=\(effectiveGlobalKvHeads.map(String.init) ?? "nil")")
}
// ── Load embed tokens ──
print("Loading embed_tokens...")
// Debug: Check what embed_tensors exist
let embedTensors = allTensors.filter { $0.name.contains("embed_tokens") }
print(" Found \(embedTensors.count) embed_tokens tensors:")
for t in embedTensors.prefix(5) {
print(" - \(t.name): dtype=\(t.dtype), shape=\(t.shape)")
}
// Try without prefix first (converted format), then with prefix (original format)
var embedGroup = try Self.quantizedGroup(named: "embed_tokens", from: allTensors,
index: index, readers: readers,
device: engine.device)
if embedGroup == nil {
print(" Trying with prefix...")
embedGroup = try Self.quantizedGroup(named: "\(P)embed_tokens", from: allTensors,
index: index, readers: readers,
device: engine.device)
}
// Handle optional missing scales/biases (non-quantized embedding)
if let eg = embedGroup {
print(" ✓ embed_tokens loaded")
// Note: groupSize=32 scale normalization now done in quantizedGroup
self.embedWeight = eg
} else {
// Non-quantized: create dummy quantized wrapper (all 0 scales=1.0, biases=0.0)
// Actually, if not quantized, we'd need to treat it as f32 weights
// For E4B 4-bit, it should be quantized
throw WeightError.unsupportedDtype("Embed tokens not quantized")
}
// ── Load embed_tokens_per_layer ──
print("Loading embed_tokens_per_layer...")
let perLayerSize = cfg.hiddenSizePerLayerInput ?? 256
self.perLayerInputSize = perLayerSize
print(" cfg.hiddenSizePerLayerInput: \(String(describing: cfg.hiddenSizePerLayerInput))")
print(" perLayerSize: \(perLayerSize), numHiddenLayers: \(numHiddenLayers)")
self.perLayerModelProjectionScaleVal = 1.0 / sqrt(Float(hiddenSize))
if perLayerSize > 0 {
// Load the quantized per-layer embedding table
let plWeight = try Self.quantizedGroup(named: "\(P)embed_tokens_per_layer", from: allTensors,
index: index, readers: readers,
device: engine.device)
if let pw = plWeight {
print(" ✓ embed_tokens_per_layer loaded: outDim=\(pw.outDim), inDim=\(pw.inDim)")
self.embedTokensPerLayerWeight = pw
// Create buffer for per-layer embedding lookup result
let totalPerLayer = perLayerSize * numHiddenLayers
guard let plBuf = engine.device.makeBuffer(length: totalPerLayer * MemoryLayout<Float>.stride,
options: .storageModeShared) else {
throw E4BError.bufferCreationFailed
}
self.perLayerEmbedBuffer = plBuf
// Context-aware projection buffer
guard let ctxBuf = engine.device.makeBuffer(length: totalPerLayer * MemoryLayout<Float>.stride,
options: .storageModeShared) else {
throw E4BError.bufferCreationFailed
}
self.perLayerContextBuffer = ctxBuf
print(" ✓ Per-layer buffers created: \(totalPerLayer) Floats each")
} else {
print(" ✗ Failed to load embed_tokens_per_layer")
self.embedTokensPerLayerWeight = nil
self.perLayerEmbedBuffer = nil
self.perLayerContextBuffer = nil
}
// Load per_layer_model_projection (context-aware projection, BF16)
// This is NOT quantized - it's a regular BF16 linear weight
let projName = "\(P)per_layer_model_projection.weight"
if let projDesc = allTensors.first(where: { $0.name == projName }) {
let projReader: SafeTensorsReader
if let idx = index, let shard = idx.weightMap[projName] {
projReader = readers[shard]!
} else {
projReader = readers["model.safetensors"]!
}
let projData = try projReader.read(tensor: projDesc)
let projFloats = SafeTensorsReader.bf16ToFloat32(projData)
let outDim = projDesc.shape[0] // 10752
let inDim = projDesc.shape[1] // 2560
guard let projBuf = engine.device.makeBuffer(
bytes: projFloats, length: projFloats.count * MemoryLayout<Float>.stride,
options: .storageModeShared
) else {
throw E4BError.bufferCreationFailed
}
self.perLayerModelProjection = projBuf
self.perLayerModelProjectionOutDim = outDim
print(" ✓ per_layer_model_projection loaded: shape=[\(outDim), \(inDim)], dtype=BF16→F32")
} else {
print(" ✗ Failed to load per_layer_model_projection")
self.perLayerModelProjection = nil
self.perLayerModelProjectionOutDim = 0
}
// Load per_layer_projection_norm
let normWeight = try Self.loadNorm(named: "\(P)per_layer_projection_norm.weight", from: allTensors,
index: index, readers: readers,
device: engine.device)
if let nw = normWeight {
print(" ✓ per_layer_projection_norm loaded")
self.perLayerProjectionNorm = nw
} else {
print(" ✗ Failed to load per_layer_projection_norm")
self.perLayerProjectionNorm = nil
}
} else {
// 12B doesn't use per-layer input
print(" Per-layer input disabled (size=0)")
self.embedTokensPerLayerWeight = nil
self.perLayerEmbedBuffer = nil
self.perLayerContextBuffer = nil
self.perLayerModelProjection = nil
self.perLayerProjectionNorm = nil
self.perLayerModelProjectionOutDim = 0
}
// ── Build per-layer data ──
// ── Optimized: Pre-read all layer weights in parallel before layer construction ──
// This is the major bottleneck optimization: parallel file reads
print("\nPre-reading all layer weights in parallel...")
let preloadStart = Date()
//方案C: 直接收集allTensors中实际存在的layer权重
var allWeightNames: [String] = []
var debugCounts = (language: 0, vision: 0, audio: 0, other: 0)
for layerIdx in 0..<numHiddenLayers {
// 查找所有包含此layer索引的tensor(精确匹配,避免误匹配layers.1x)
let layerPrefix = "\(P)layers.\(layerIdx)."
let layerTensors = allTensors.filter { tensor in
tensor.name.hasPrefix(layerPrefix) &&
!tensor.name.contains("vision_tower") &&
!tensor.name.contains("audio_tower")
}
for tensor in layerTensors {
allWeightNames.append(tensor.name)
// Debug counting
if tensor.name.contains("language_model.model.layers") {
debugCounts.language += 1
} else if tensor.name.contains("vision_tower") {
debugCounts.vision += 1
} else if tensor.name.contains("audio_tower") {
debugCounts.audio += 1
} else {
debugCounts.other += 1
}
}
}
print(" Collected \(allWeightNames.count) weight names from allTensors")
// Parallel weight loading using DispatchGroup (thread-safe array access)
let dispatchGroup = DispatchGroup()
let loadQueue = DispatchQueue(label: "weight-preloading", attributes: .concurrent)
var loadedWeights: [Data?] = Array(repeating: nil, count: allWeightNames.count)
var loadErrors: [Error?] = Array(repeating: nil, count: allWeightNames.count)
var notFoundCount = 0
var successCount = 0
for (weightIndex, name) in allWeightNames.enumerated() {
dispatchGroup.enter()
loadQueue.async {
do {
guard let desc = allTensors.first(where: { $0.name == name }) else {
notFoundCount += 1
dispatchGroup.leave()
return
}
let reader: SafeTensorsReader
if let idx = index, let shardFile = idx.weightMap[name] {
reader = readers[shardFile]!
} else {
reader = readers["model.safetensors"]!
}
let data = try reader.read(tensor: desc)
loadedWeights[weightIndex] = data
successCount += 1
} catch {
loadErrors[weightIndex] = error
}
dispatchGroup.leave()
}
}
dispatchGroup.wait()
print(" Parallel load completed: success=\(successCount), notFound=\(notFoundCount), errors=\(loadErrors.compactMap { $0 }.count)")
// Create weight data cache for fast lookup
var preloadedDataCache: [String: Data] = [:]
for (weightIndex, name) in allWeightNames.enumerated() {
if let data = loadedWeights[weightIndex] {
preloadedDataCache[name] = data
}
}
print(" Loaded weights: \(loadedWeights.filter { $0 != nil }.count)/\(loadedWeights.count)")
let preloadTime = Date().timeIntervalSince(preloadStart) * 1000
print("✓ Parallel preloaded \(preloadedDataCache.count) weights in \(String(format: "%.1f", preloadTime))ms")
// Check for errors
for (weightIndex, error) in loadErrors.enumerated() {
if let err = error {
print("⚠ Failed to preload weight \(allWeightNames[weightIndex]): \(err)")
}
}
print("\nBuilding layers...")
var builtLayers: [E4BLayer] = []
builtLayers.reserveCapacity(numHiddenLayers)
for layerIdx in 0..<numHiddenLayers {
let prefix = "\(P)layers.\(layerIdx)"
// Optimized helper methods that use preloaded cache
func normFromCache(_ name: String) throws -> MTLBuffer? {
let fullName = "\(prefix).\(name)"
if let data = preloadedDataCache[fullName] {
let desc = allTensors.first(where: { $0.name == fullName })
let floats: [Float]
if desc?.dtype == .bf16 {
floats = SafeTensorsReader.bf16ToFloat32(data)
} else if desc?.dtype == .f32 {
floats = data.withUnsafeBytes { Array($0.assumingMemoryBound(to: Float.self)) }
} else {
return nil
}
guard let buf = engine.device.makeBuffer(
bytes: floats, length: floats.count * MemoryLayout<Float>.stride,
options: .storageModeShared
) else { return nil }
return buf
}
return try Self.loadNorm(named: fullName, from: allTensors,
index: index, readers: readers,
device: engine.device)
}
func qwFromCache(_ name: String, bits: Int = 4) throws -> QuantizedWeights? {
let fullName = "\(prefix).\(name)"
let wName = "\(fullName).weight"
let sName = "\(fullName).scales"
let bName = "\(fullName).biases"
if let wData = preloadedDataCache[wName], let sData = preloadedDataCache[sName], fullName.contains("embed") == false {
let wDesc = allTensors.first(where: { $0.name == wName })
let sDesc = allTensors.first(where: { $0.name == sName })
let wShape = wDesc?.shape ?? []
let sShape = sDesc?.shape ?? []
let outDim = wShape.count > 0 ? wShape[0] : 0
let packedDim = wShape.count > 1 ? wShape[1] : 0
let inDim = packedDim * (bits == 4 ? 8 : 4)
let groupSize = (sShape.count > 1 && sShape[1] > 0) ? inDim / sShape[1] : 64
let bData = preloadedDataCache[bName]
let wBuf = wData.withUnsafeBytes { ptr in
engine.device.makeBuffer(bytes: ptr.baseAddress!, length: wData.count, options: .storageModeShared)
}
let sBuf: MTLBuffer?
if sDesc?.dtype == .bf16 {
var sFloats = SafeTensorsReader.bf16ToFloat32(sData)
if groupSize == 32 {
for i in 0..<sFloats.count {
sFloats[i] = sFloats[i] / Float(inDim)
}
}
sBuf = engine.device.makeBuffer(
bytes: sFloats, length: sFloats.count * MemoryLayout<Float>.stride,
options: .storageModeShared
)
} else {
sBuf = sData.withUnsafeBytes { ptr in
engine.device.makeBuffer(bytes: ptr.baseAddress!, length: sData.count, options: .storageModeShared)
}
}
let bBuf: MTLBuffer?
if let bData = bData {
if let bDesc = allTensors.first(where: { $0.name == bName }), bDesc.dtype == .bf16 {
let bFloats = SafeTensorsReader.bf16ToFloat32(bData)
bBuf = engine.device.makeBuffer(
bytes: bFloats, length: bFloats.count * MemoryLayout<Float>.stride,
options: .storageModeShared
)
} else {
bBuf = bData.withUnsafeBytes { ptr in
engine.device.makeBuffer(bytes: ptr.baseAddress!, length: bData.count, options: .storageModeShared)
}
}
} else {
let sCount = sDesc?.shape.reduce(1, *) ?? 0
let bFloatsZero = [Float](repeating: 0.0, count: sCount)
bBuf = engine.device.makeBuffer(
bytes: bFloatsZero,
length: bFloatsZero.count * MemoryLayout<Float>.stride,
options: .storageModeShared
)
}
guard let wBufSafe = wBuf, let sBufSafe = sBuf, let bBufSafe = bBuf else {
return nil
}
return QuantizedWeights(
weight: wBufSafe,
scales: sBufSafe,
biases: bBufSafe,
inDim: inDim,
outDim: outDim,
bits: bits,
groupSize: groupSize
)
}
return try Self.quantizedGroup(named: fullName, from: allTensors,
index: index, readers: readers,
device: engine.device, bits: bits)
}
let isFull = layerTypesIsFull[layerIdx]
// E4B uses intermediateSize=10240 for all layers, not doubled for shared KV
let intermediate = cfg.intermediateSize ?? 10240
// Use detected globalHeadDim for full layers if available
let hd = isFull ? (detectedGlobalHd ?? cfg.globalHeadDim ?? cfg.headDim ?? 512) : (cfg.slidingHeadDim ?? cfg.headDim ?? 256)
let nHeads = effectiveNHeads
let nKvHeads = isFull ? (effectiveGlobalKvHeads ?? effectiveNKvHeads) : effectiveNKvHeads
print(" isFull: \(isFull), headDim: \(hd), intermediate: \(intermediate), nHeads: \(nHeads), nKvHeads: \(nKvHeads)")
fflush(stdout)
let lcfg: E4BLayerConfig = isFull
? .full(hiddenSize: hiddenSize, headDim: hd, intermediateSize: intermediate,
nHeads: nHeads, nKvHeads: nKvHeads, maxPosition: maxContextLength)
: .sliding(hiddenSize: hiddenSize, headDim: hd, intermediateSize: intermediate,
nHeads: nHeads, nKvHeads: nKvHeads, windowSize: cfg.slidingWindow ?? 512)
let maxHeadDim = cfg.headDim ?? 512
func norm(_ name: String) throws -> MTLBuffer? {
try Self.loadNorm(named: "\(prefix).\(name)", from: allTensors,
index: index, readers: readers,
device: engine.device)
}
func normStrided(_ name: String, nHeads: Int, hd: Int) throws -> MTLBuffer? {
try Self.loadNormStrided(named: "\(prefix).\(name)", from: allTensors,
index: index, readers: readers,
device: engine.device,
nHeads: nHeads, headDim: hd, maxHeadDim: maxHeadDim)
}
func qw(_ name: String, bits: Int = 4) throws -> QuantizedWeights? {
try Self.quantizedGroup(named: "\(prefix).\(name)", from: allTensors,
index: index, readers: readers,
device: engine.device, bits: bits)
}
func fw(_ name: String) throws -> FloatWeights? {
let fullName = "\(prefix).\(name)"
let wName = "\(fullName).weight"
// Check if weight is in preloaded cache
if let wData = preloadedDataCache[wName] {
let wDesc = allTensors.first(where: { $0.name == wName })
if let desc = wDesc, desc.dtype == .bf16 {
let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
let outDim = desc.shape[0]
let inDim = desc.shape[1]
if let wBuf = engine.device.makeBuffer(
bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
options: .storageModeShared
) {
return FloatWeights(weight: wBuf, inDim: inDim, outDim: outDim)
}
}
}
return nil
}
/// Infer quantization bits from weight tensor shape vs expected input dimension.
/// Returns 4 or 8, defaulting to `defaultBits` if neither matches.
func detectBits(for weightName: String, expectedInDim: Int, defaultBits: Int = 4) -> Int {
guard let wDesc = allTensors.first(where: { $0.name == "\(prefix).\(weightName).weight" }) else {
return defaultBits
}
let packedDim = wDesc.shape[1]
if packedDim * 8 == expectedInDim { return 4 }
if packedDim * 4 == expectedInDim { return 8 }
return defaultBits
}
// Layer scalar (optional; defaults to 1.0)
let scalar: Float
if let sDesc = allTensors.first(where: { $0.name == "\(prefix).layer_scalar" }) {
let sReader: SafeTensorsReader
if let idx = index {
guard let shardFile = idx.weightMap[sDesc.name] else {
scalar = 1.0
continue
}
sReader = readers[shardFile]!
} else {
sReader = readers["model.safetensors"]!
}
let sData = try sReader.read(tensor: sDesc)
let sFloats = SafeTensorsReader.bf16ToFloat32(sData)
scalar = sFloats.first ?? 1.0
print(" layer_scalar: \(scalar)")
fflush(stdout)
} else {
scalar = 1.0
print(" layer_scalar: NOT FOUND (using 1.0)")
}
// Detect quantization bits from weight shape (supports both uniform 4-bit and 8-bit MLP/router)
let mlpGateBits = detectBits(for: "mlp.gate_proj", expectedInDim: hiddenSize, defaultBits: 4)
let mlpDownBits = detectBits(for: "mlp.down_proj", expectedInDim: intermediate, defaultBits: 4)
let attnQBits = detectBits(for: "self_attn.q_proj", expectedInDim: hiddenSize, defaultBits: 4)
let attnKBits = detectBits(for: "self_attn.k_proj", expectedInDim: hiddenSize, defaultBits: 4)
let attnVBits = detectBits(for: "self_attn.v_proj", expectedInDim: hiddenSize, defaultBits: 4)
let attnOBits = detectBits(for: "self_attn.o_proj", expectedInDim: hiddenSize, defaultBits: 4)
// Try bf16 weights first (for bf16 models)
let qpFloat = try fw("self_attn.q_proj")
let kpFloat = try fw("self_attn.k_proj")
let vpFloat = try fw("self_attn.v_proj")
let opFloat = try fw("self_attn.o_proj")
// Then try quantized weights (for quantized models)
let qpQuant = try qwFromCache("self_attn.q_proj", bits: attnQBits)
let kpQuant = try qwFromCache("self_attn.k_proj", bits: attnKBits)
let vpQuant = try qwFromCache("self_attn.v_proj", bits: attnVBits)
let opQuant = try qwFromCache("self_attn.o_proj", bits: attnOBits)
guard qpQuant != nil || qpFloat != nil,
kpQuant != nil || kpFloat != nil,
opQuant != nil || opFloat != nil
else {
throw WeightError.tensorNotFound("Missing weights for layer \(layerIdx)")
}
// ── MoE loading (auto-detect from tensor structure) ──
// Auto-detect MoE by checking if router.proj.weight exists
let hasMoETensors = allTensors.contains { $0.name.contains("\(prefix).router.proj") }
let useMoE = cfg.enableMoEBlock ?? false || hasMoETensors
// Infer numExperts from expert tensor shape if not in config
var numExperts = cfg.numExperts ?? 0
if numExperts == 0 && hasMoETensors {
// Try to infer from experts.switch_glu tensor shape
let expertTensor = allTensors.first { $0.name.contains("\(prefix).experts.switch_glu") }
if let expertShape = expertTensor?.shape, expertShape.count == 3 {
numExperts = expertShape[0] // First dimension is numExperts
}
}
// MLP weights: load real weights if available, create dummy only if missing in MoE layer
var gp = try qwFromCache("mlp.gate_proj", bits: mlpGateBits)
var up = try qwFromCache("mlp.up_proj", bits: mlpGateBits)
var dp = try qwFromCache("mlp.down_proj", bits: mlpDownBits)
var gpFloat = try fw("mlp.gate_proj")
var upFloat = try fw("mlp.up_proj")
var dpFloat = try fw("mlp.down_proj")
// If MLP weights missing and this is MoE layer, create dummy weights
if useMoE && numExperts > 0 {
if gp == nil || up == nil || dp == nil {
// Create minimal dummy weights for MoE layer (won't be used in forward if experts available)
let dummyWeight = engine.device.makeBuffer(length: 4, options: .storageModeShared)!
let dummyScales = engine.device.makeBuffer(length: 4, options: .storageModeShared)!
let dummyBiases = engine.device.makeBuffer(length: 4, options: .storageModeShared)!
let dummyQuantizedWeights = QuantizedWeights(
weight: dummyWeight, scales: dummyScales, biases: dummyBiases,
inDim: 1, outDim: 1, bits: 4, groupSize: 1
)
if gp == nil { gp = dummyQuantizedWeights }
if up == nil { up = dummyQuantizedWeights }
if dp == nil { dp = dummyQuantizedWeights }
}
} else if (gp == nil || up == nil || dp == nil) && (gpFloat == nil || upFloat == nil || dpFloat == nil) {
// Dense layer requires either quantized or bf16 MLP weights
throw WeightError.tensorNotFound("Missing MLP weights for layer \(layerIdx)")
}
// v_proj is optional - full attention layers in 12B don't have it
let vp = try qwFromCache("self_attn.v_proj")
// Per-layer weights are optional (12B doesn't have them)
let pg = try qwFromCache("per_layer_input_gate")
let pp = try qwFromCache("per_layer_projection")
// Per-layer input: nil for now (will be computed dynamically in forward)
let plSlice: MTLBuffer? = nil
let topK = cfg.topKExperts ?? 8
let moeIntermediate = cfg.moeIntermediateSize ?? 704
var routerProj: QuantizedWeights? = nil
var routerScale: Float = 1.0
var perExpertScale: [Float]? = nil
var expertGate: MoEExpertGroup? = nil
var expertUp: MoEExpertGroup? = nil
var expertDown: MoEExpertGroup? = nil
if useMoE && numExperts > 0 {
let routerBits = detectBits(for: "router.proj", expectedInDim: hiddenSize, defaultBits: 4)
routerProj = try Self.quantizedGroup(named: "\(prefix).router.proj", from: allTensors,
index: index, readers: readers,
device: engine.device, bits: routerBits)
// Load router.scale (scalar)
if let rsDesc = allTensors.first(where: { $0.name == "\(prefix).router.scale" }) {
let rsReader: SafeTensorsReader
if let idx = index, let shard = idx.weightMap[rsDesc.name] {
rsReader = readers[shard]!
} else {
rsReader = readers["model.safetensors"]!
}
let rsData = try rsReader.read(tensor: rsDesc)
let rsFloats = SafeTensorsReader.bf16ToFloat32(rsData)
let rawRouterScale = rsFloats.first ?? 1.0
// Normalize router scale by hidden_size (similar to scales normalization for 26B-Standard)
// This prevents softmax overflow in MoE router computation
routerScale = rawRouterScale / Float(hiddenSize)
}
// Load per_expert_scale ([numExperts])
if let pesDesc = allTensors.first(where: { $0.name == "\(prefix).router.per_expert_scale" }) {
let pesReader: SafeTensorsReader
if let idx = index, let shard = idx.weightMap[pesDesc.name] {
pesReader = readers[shard]!
} else {
pesReader = readers["model.safetensors"]!
}
let pesData = try pesReader.read(tensor: pesDesc)
let pesFloats = SafeTensorsReader.bf16ToFloat32(pesData)
perExpertScale = pesFloats
}
// Load expert 3D tensors as MoEExpertGroup
let ep = "\(prefix).experts.switch_glu"
expertGate = try Self.loadExpertGroup(named: "\(ep).gate_proj", from: allTensors,
index: index, readers: readers,
device: engine.device,
numExperts: numExperts,
expertOutDim: moeIntermediate,
expertInDim: hiddenSize)
expertUp = try Self.loadExpertGroup(named: "\(ep).up_proj", from: allTensors,
index: index, readers: readers,
device: engine.device,
numExperts: numExperts,
expertOutDim: moeIntermediate,
expertInDim: hiddenSize)
expertDown = try Self.loadExpertGroup(named: "\(ep).down_proj", from: allTensors,
index: index, readers: readers,
device: engine.device,
numExperts: numExperts,
expertOutDim: hiddenSize,
expertInDim: moeIntermediate)
let loaded = (expertGate != nil && expertUp != nil && expertDown != nil) ? numExperts : 0
print(" MoE: \(loaded)/\(numExperts) experts loaded")
}
let layer = E4BLayer(
config: lcfg,
layerIdx: layerIdx,
inputLayernorm: try norm("input_layernorm.weight"),
postAttentionLayernorm: try norm("post_attention_layernorm.weight"),
preFeedforwardLayernorm: try norm("pre_feedforward_layernorm.weight"),
postFeedforwardLayernorm: try norm("post_feedforward_layernorm.weight"),
postPerLayerInputNorm: try norm("post_per_layer_input_norm.weight"),
qNorm: try normStrided("self_attn.q_norm.weight", nHeads: lcfg.nHeads, hd: hd),
kNorm: try normStrided("self_attn.k_norm.weight", nHeads: lcfg.nKvHeads, hd: hd),
vNorm: try normStrided("self_attn.v_norm.weight", nHeads: lcfg.nKvHeads, hd: hd),
qProj: qpQuant, kProj: kpQuant, vProj: vpQuant, oProj: opQuant,
gateProj: gp, upProj: up, downProj: dp,
perLayerGate: pg, perLayerProjection: pp,
qProjFloat: qpFloat, kProjFloat: kpFloat, vProjFloat: vpFloat, oProjFloat: opFloat,
gateProjFloat: gpFloat, upProjFloat: upFloat, downProjFloat: dpFloat,
perLayerGateFloat: try fw("per_layer_input_gate"),
perLayerProjectionFloat: try fw("per_layer_projection"),
perLayerInput: plSlice,
perLayerInputScale: perLayerInputScaleVal,
perLayerProjectionScale: perLayerModelProjectionScaleVal,
layerScalar: scalar,
useMoE: useMoE && expertGate != nil,
routerProj: routerProj,
routerScale: routerScale,
perExpertScale: perExpertScale,
expertGate: expertGate,
expertUp: expertUp,
expertDown: expertDown,
topK: topK,
kEqualsV: (vpQuant == nil && vpFloat == nil && isFull) || (cfg.attentionKEqualsV ?? false)
)
builtLayers.append(layer)
}
self.layers = builtLayers
// ── KV caches ──
var caches: [KVCache] = []
caches.reserveCapacity(numHiddenLayers)
for layerIdx in 0..<numHiddenLayers {
let isFull = layerTypesIsFull[layerIdx]
let hd = isFull ? (cfg.globalHeadDim ?? cfg.headDim ?? 512) : (cfg.slidingHeadDim ?? cfg.headDim ?? 256)
let cacheNKvHeads = isFull ? (effectiveGlobalKvHeads ?? effectiveNKvHeads) : effectiveNKvHeads
let cache = KVCache(
device: engine.device,
isSliding: !isFull,
maxContextLength: maxContextLength,
nKvHeads: cacheNKvHeads,
headDim: hd
)
caches.append(cache)
}
self.kvCaches = caches
// ── Temps ──
print("\nCreating ForwardTemps...")
fflush(stdout)
let headDim = cfg.headDim ?? 256
// Use detectedGlobalHd if available
let globalHeadDim = detectedGlobalHd ?? cfg.globalHeadDim ?? headDim
let maxHeadDim = max(headDim, globalHeadDim)
// Max intermediate: shared layers have 2x intermediate_size; safe upper bound
let baseIntermediate = cfg.intermediateSize ?? 15360
let maxIntermediate = baseIntermediate * 2
let nHeads = effectiveNHeads
let nKvHeads = max(effectiveNKvHeads, effectiveGlobalKvHeads ?? 0)
print(" headDim: \(headDim), globalHeadDim: \(globalHeadDim), maxHeadDim: \(maxHeadDim), intermediateSize: \(baseIntermediate), maxIntermediate: \(maxIntermediate), nHeads: \(nHeads), nKvHeads: \(nKvHeads)")
self.temps = try ForwardTemps(device: engine.device,
maxHeadDim: maxHeadDim,
maxIntermediate: maxIntermediate,
hiddenSize: hiddenSize,
nHeads: nHeads,
nKvHeads: nKvHeads)
print(" ✓ ForwardTemps created")
fflush(stdout)
// ── Logits buffer ──
print("Creating logits buffer...")
print(" vocabSize: \(vocabSize), size: \(vocabSize * 4) bytes")
fflush(stdout)
guard let lb = engine.device.makeBuffer(
length: vocabSize * MemoryLayout<Float>.stride,
options: .storageModeShared
) else {
print(" ✗ Failed to create logits buffer")
throw E4BError.bufferCreationFailed
}
print(" ✓ Logits buffer created")
self.logitsBuffer = lb
// ── Final RMSNorm (before LM head) ──
print("Loading final norm...")
let finalNormName = "\(P)norm.weight"
self.finalNorm = try Self.loadNorm(named: finalNormName, from: allTensors,
index: index, readers: readers,
device: engine.device)
print(" ✓ Final norm loaded")
print("\n✓ Model initialization completed successfully\n")
}
// ── Kernel dispatch helpers ───────────────────────
/// Runs the `rms_norm` compute kernel over `count` floats and blocks until it finishes.
///
/// - Parameters:
///   - input: Source buffer, bound at slot 0 starting at `inputOffset` bytes.
///   - weight: Optional gain buffer, bound at slot 1 starting at `weightOffset` bytes
///     (`nil` leaves the slot unbound).
///   - output: Destination buffer, bound at slot 2 starting at `outputOffset` bytes.
///   - count: Element count, passed to the kernel as `UInt32` and used as the grid width.
///   - eps: Epsilon forwarded to the kernel.
///   - inputOffset: Byte offset into `input`.
///   - weightOffset: Byte offset into `weight`.
///   - outputOffset: Byte offset into `output`.
/// - Throws: Any error from `engine.pipeline(named:)`.
func rmsNorm(input: MTLBuffer, weight: MTLBuffer?,
             output: MTLBuffer, count: Int, eps: Float,
             inputOffset: Int = 0, weightOffset: Int = 0, outputOffset: Int = 0) throws {
    let pipeline = try engine.pipeline(named: "rms_norm")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Buffer slots 0...2 are input, weight and output, each with its own byte offset.
    let bindings: [(buffer: MTLBuffer?, offset: Int)] = [
        (input, inputOffset),
        (weight, weightOffset),
        (output, outputOffset),
    ]
    for (slot, binding) in bindings.enumerated() {
        encoder.setBuffer(binding.buffer, offset: binding.offset, index: slot)
    }

    // Scalar arguments follow the buffers: element count (slot 3), epsilon (slot 4).
    var elementCount = UInt32(count)
    var epsilon = eps
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 3)
    encoder.setBytes(&epsilon, length: MemoryLayout<Float>.size, index: 4)

    let grid = MTLSize(width: count, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: count))
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
/// Dispatches the `eltwise_add_scaled` kernel over `count` elements and waits for it.
///
/// Presumably computes `output = scaleA * a + scaleB * b` element-wise (the kernel source
/// is not visible here). All buffers are bound at offset 0. The argument layout is:
/// a (0), scaleA (1), b (2), scaleB (3), output (4), count (5).
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func eltwiseAddScaledModel(a: MTLBuffer, scaleA: Float,
                           b: MTLBuffer, scaleB: Float,
                           output: MTLBuffer, count: Int) throws {
    let pipeline = try engine.pipeline(named: "eltwise_add_scaled")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Copies a scalar into the encoder's argument table at `slot`.
    func setScalar<T>(_ value: T, at slot: Int) {
        var local = value
        encoder.setBytes(&local, length: MemoryLayout<T>.size, index: slot)
    }

    encoder.setBuffer(a, offset: 0, index: 0)
    setScalar(scaleA, at: 1)
    encoder.setBuffer(b, offset: 0, index: 2)
    setScalar(scaleB, at: 3)
    encoder.setBuffer(output, offset: 0, index: 4)
    setScalar(UInt32(count), at: 5)

    let grid = MTLSize(width: count, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: count))
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
/// Multiplies one input vector by an unquantized weight matrix using the `matmul_f32` kernel.
///
/// Despite the name, the weight buffer is consumed as F32. Callers such as the per-layer
/// model projection pass weights already widened from BF16. The kernel runs with a batch
/// of one, one thread per output element (`outDim` threads), and threadgroups of 32.
///
/// - Parameters:
///   - input: Input vector (slot 0).
///   - weight: F32 weight matrix (slot 1).
///   - output: Output vector (slot 2).
///   - inDim: Reduction length K (slot 4).
///   - outDim: Output length N (slot 5), also the grid width.
/// - Throws: Any error from `engine.pipeline(named:)`.
func matmulBF16(input: MTLBuffer, weight: MTLBuffer, output: MTLBuffer,
                inDim: Int, outDim: Int) throws {
    let pipeline = try engine.pipeline(named: "matmul_f32")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    for (slot, buffer) in [input, weight, output].enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    // Dimensions M (batch rows, always 1), K and N occupy slots 3, 4 and 5.
    for (slot, dimension) in [(3, 1), (4, inDim), (5, outDim)] {
        var word = UInt32(dimension)
        encoder.setBytes(&word, length: MemoryLayout<UInt32>.size, index: slot)
    }

    let grid = MTLSize(width: outDim, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: MTLSize(width: 32, height: 1, depth: 1))
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
// ── Weight loading helpers ────────────────────────
/// Loads a norm weight tensor and uploads it as a shared-storage `Float` buffer.
///
/// BF16 tensors are widened to F32. F32 tensors are copied byte-for-byte. Any other dtype
/// yields `nil`.
///
/// - Parameters:
///   - named: Tensor name. A `language_model.model.` prefix is also tried stripped.
///   - tensors: All tensor descriptors known to the model.
///   - index: Shard index for sharded checkpoints, or `nil` for a single `model.safetensors`.
///   - readers: Open readers keyed by shard file name.
///   - device: Metal device used to allocate the buffer.
/// - Returns: The uploaded buffer, or `nil` in any of these cases:
///   - the tensor is missing or absent from the shard index;
///   - its dtype is unsupported;
///   - it has no elements;
///   - buffer allocation fails.
/// - Throws: Errors from `SafeTensorsReader.read(tensor:)`.
private static func loadNorm(named: String, from tensors: [TensorDescriptor],
                             index: SafeTensorsIndex?,
                             readers: [String: SafeTensorsReader],
                             device: MTLDevice) throws -> MTLBuffer? {
    guard let desc = findTensor(named, in: tensors) else { return nil }
    // Resolve the reader holding this tensor (sharded vs single-file checkpoint).
    let reader: SafeTensorsReader
    if let idx = index {
        guard let shardFile = idx.weightMap[desc.name] else { return nil }
        reader = readers[shardFile]!
    } else {
        reader = readers["model.safetensors"]!
    }
    let data = try reader.read(tensor: desc)
    let floats: [Float]
    if desc.dtype == .bf16 {
        floats = SafeTensorsReader.bf16ToFloat32(data)
    } else if desc.dtype == .f32 {
        // Copy instead of rebinding the bytes. `Data` storage is not guaranteed to be
        // 4-byte aligned, and treating misaligned memory as `Float` is undefined behaviour.
        var decoded = [Float](repeating: 0, count: data.count / MemoryLayout<Float>.stride)
        _ = decoded.withUnsafeMutableBufferPointer { data.copyBytes(to: $0) }
        floats = decoded
    } else {
        return nil
    }
    // A zero-length Metal buffer is not a usable norm weight.
    guard !floats.isEmpty else { return nil }
    return device.makeBuffer(
        bytes: floats, length: floats.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    )
}
/// Load a norm weight, handling per-head repeating.
///
/// The tensor's element count selects one of three cases:
/// - `headDim` elements (e.g. [256]): repeat `nHeads` times → [nHeads * headDim].
/// - More than `nHeads * headDim` elements: treat it as stored at `maxHeadDim` stride and
///   keep the first `headDim` values of each head.
/// - Otherwise: use as-is, assumed to already be [nHeads * headDim].
///
/// BF16 is widened to F32 and F32 is copied. Any other dtype yields `nil`.
///
/// - Returns: A shared-storage `Float` buffer, or `nil` in any of these cases:
///   - the tensor is missing or absent from the shard index;
///   - its dtype is unsupported;
///   - it has no elements;
///   - buffer allocation fails.
/// - Throws: Errors from `SafeTensorsReader.read(tensor:)`.
private static func loadNormStrided(named: String, from tensors: [TensorDescriptor],
                                    index: SafeTensorsIndex?,
                                    readers: [String: SafeTensorsReader],
                                    device: MTLDevice,
                                    nHeads: Int, headDim: Int,
                                    maxHeadDim: Int) throws -> MTLBuffer? {
    guard let desc = findTensor(named, in: tensors) else { return nil }
    // Resolve the reader holding this tensor (sharded vs single-file checkpoint).
    let reader: SafeTensorsReader
    if let idx = index {
        guard let shardFile = idx.weightMap[desc.name] else { return nil }
        reader = readers[shardFile]!
    } else {
        reader = readers["model.safetensors"]!
    }
    let data = try reader.read(tensor: desc)
    let floats: [Float]
    if desc.dtype == .bf16 {
        floats = SafeTensorsReader.bf16ToFloat32(data)
    } else if desc.dtype == .f32 {
        // Copy instead of rebinding: `Data` storage is not guaranteed to be 4-byte aligned.
        var decoded = [Float](repeating: 0, count: data.count / MemoryLayout<Float>.stride)
        _ = decoded.withUnsafeMutableBufferPointer { data.copyBytes(to: $0) }
        floats = decoded
    } else {
        return nil
    }
    // A zero-length Metal buffer is not a usable norm weight.
    guard !floats.isEmpty else { return nil }

    // Uploads `values` into a new shared-storage buffer.
    func upload(_ values: [Float]) -> MTLBuffer? {
        device.makeBuffer(bytes: values, length: values.count * MemoryLayout<Float>.stride,
                          options: .storageModeShared)
    }

    let actualCount = nHeads * headDim
    // Case 1: one headDim-sized weight shared by every head; tile it once per head.
    if floats.count == headDim && floats.count < actualCount {
        return upload((0..<nHeads).flatMap { _ in floats })
    }
    // Case 2: stored at maxHeadDim stride (larger than needed).
    // NOTE(review): assumes at least (nHeads - 1) * maxHeadDim + headDim values are present;
    // a shorter tensor traps on the index below.
    if floats.count > actualCount {
        var extracted = [Float](repeating: 0, count: actualCount)
        for h in 0..<nHeads {
            for d in 0..<headDim {
                extracted[h * headDim + d] = floats[h * maxHeadDim + d]
            }
        }
        return upload(extracted)
    }
    // Case 3: already correct size.
    return upload(floats)
}
/// Loads an affine-quantized linear weight and uploads its three parts to Metal.
///
/// The parts are `<named>.weight` (packed `UInt32`), `<named>.scales` and the optional
/// `<named>.biases`.
///
/// Shapes read from the checkpoint:
/// - weight: `[outDim, inDim / (32 / bits)]`
/// - scales: `[outDim, numGroups]`, from which `groupSize = inDim / numGroups`
///
/// Scales and biases are decoded as F32 when stored as F32 and as BF16 otherwise. When no
/// biases tensor exists, an all-zero bias buffer the size of the scales is created.
///
/// - Parameters:
///   - named: Tensor base name. Stored names with a `language_model.model.` prefix are also
///     matched by their stripped form.
///   - tensors: All tensor descriptors in the checkpoint.
///   - index: Shard index, or `nil` for a single `model.safetensors` file.
///   - readers: Open readers keyed by shard file name.
///   - device: Metal device used for allocation.
///   - bits: Bits per quantized value (default 4).
/// - Returns: The uploaded weights, or `nil` in any of these cases:
///   - weight or scales are missing;
///   - either tensor has fewer than 2 dimensions;
///   - `bits` or the group count is not positive;
///   - a shard is unknown;
///   - buffer allocation fails.
/// - Throws: Errors from `SafeTensorsReader.read(tensor:)`.
private static func quantizedGroup(named: String, from tensors: [TensorDescriptor],
                                   index: SafeTensorsIndex?,
                                   readers: [String: SafeTensorsReader],
                                   device: MTLDevice,
                                   bits: Int = 4) throws -> QuantizedWeights? {
    // Name -> descriptor map. Every stored name is inserted as-is. Names carrying the
    // `language_model.model.` prefix are also inserted under their stripped form, so this
    // one map serves both exact and stripped lookups.
    let prefix = "language_model.model."
    let lookup = tensors.reduce(into: [String: TensorDescriptor]()) { dict, desc in
        dict[desc.name] = desc
        if desc.name.hasPrefix(prefix) {
            dict[String(desc.name.dropFirst(prefix.count))] = desc
        }
    }
    guard let wDesc = lookup["\(named).weight"],
          let sDesc = lookup["\(named).scales"]
    else {
        return nil
    }
    // Biases are optional (e.g., MLX 26B embed_tensors has no biases)
    let bDesc = lookup["\(named).biases"]

    // Reject layouts the index math below cannot handle, instead of trapping on them.
    guard bits > 0, wDesc.shape.count >= 2, sDesc.shape.count >= 2, sDesc.shape[1] > 0 else {
        return nil
    }

    // Resolve the reader for each tensor. In sharded checkpoints they may live in different files.
    let wReader: SafeTensorsReader
    let sReader: SafeTensorsReader
    let bReader: SafeTensorsReader?
    if let idx = index {
        guard let wShard = idx.weightMap[wDesc.name],
              let sShard = idx.weightMap[sDesc.name] else {
            return nil
        }
        wReader = readers[wShard]!
        sReader = readers[sShard]!
        if let bName = bDesc?.name, let bShard = idx.weightMap[bName] {
            bReader = readers[bShard]
        } else {
            bReader = nil
        }
    } else {
        // Single file
        wReader = readers["model.safetensors"]!
        sReader = wReader
        bReader = wReader
    }

    // F32 data is copied (alignment-safe); any other dtype is decoded as BF16.
    func decodeFloats(_ data: Data, of desc: TensorDescriptor) -> [Float] {
        guard desc.dtype == .f32 else { return SafeTensorsReader.bf16ToFloat32(data) }
        var values = [Float](repeating: 0, count: data.count / MemoryLayout<Float>.stride)
        _ = values.withUnsafeMutableBufferPointer { data.copyBytes(to: $0) }
        return values
    }

    let wData = try wReader.read(tensor: wDesc)
    var sFloats = try decodeFloats(sReader.read(tensor: sDesc), of: sDesc)
    let bFloats: [Float]?
    if let bDesc = bDesc, let bReader = bReader {
        bFloats = try decodeFloats(bReader.read(tensor: bDesc), of: bDesc)
    } else {
        bFloats = nil
    }

    let outDim = wDesc.shape[0]
    // inDim = packed dim * (32 / bits) (e.g. 4-bit: 8 vals/u32, 8-bit: 4 vals/u32)
    let valsPerU32 = 32 / bits
    let inDim = wDesc.shape[1] * valsPerU32
    // groupSize comes from the scales shape: scales.shape[1] = inDim / groupSize
    let numGroups = sDesc.shape[1]
    let groupSize = inDim / numGroups

    // Normalize scales for groupSize=32 custom quantization, whose scales are stored inflated.
    // NOTE(review): the divisor is `inDim`, which equals hiddenSize only for projections fed
    // by the residual stream (not e.g. down_proj). Biases are not rescaled. Confirm both
    // against the quantizer that produced these checkpoints.
    if groupSize == 32 {
        let divisor = Float(inDim)
        for i in sFloats.indices {
            sFloats[i] /= divisor
        }
    }

    // Uploads `values` into a new shared-storage buffer.
    func makeFloatBuffer(_ values: [Float]) -> MTLBuffer? {
        device.makeBuffer(bytes: values, length: values.count * MemoryLayout<Float>.stride,
                          options: .storageModeShared)
    }

    guard let wBuf = device.makeBuffer(
        bytes: (wData as NSData).bytes, length: wData.count,
        options: .storageModeShared
    ) else { return nil }
    guard let sBuf = makeFloatBuffer(sFloats) else { return nil }
    // Missing biases become zeros shaped like the scales, so a bias buffer always exists.
    guard let bBuf = makeFloatBuffer(bFloats ?? [Float](repeating: 0.0, count: sFloats.count)) else {
        return nil
    }
    return QuantizedWeights(weight: wBuf, scales: sBuf, biases: bBuf,
                            inDim: inDim, outDim: outDim, bits: bits, groupSize: groupSize)
}
/// Load non-quantized bf16 embedding weights as FloatWeights.
///
/// Looks up `<named>.weight`, widens the BF16 data to F32, and uploads it as a matrix of
/// `[shape[0], shape[1]]` (outDim × inDim). The lookup also accepts stored names with any
/// of these prefixes, matched by their stripped form:
/// - `language_model.model.`
/// - `model.language_model.model.`
/// - `model.language_model.`
///
/// - Parameters:
///   - named: Tensor base name.
///   - tensors: All tensor descriptors in the checkpoint.
///   - index: Shard index, or `nil` for a single `model.safetensors` file.
///   - readers: Open readers keyed by shard file name.
///   - device: Metal device used for allocation.
///   - hiddenSize: Not used by this loader; kept for call-site compatibility.
/// - Returns: The uploaded weights, or `nil` in any of these cases:
///   - the tensor is missing;
///   - it is not BF16;
///   - it is not 2-D;
///   - it is absent from the shard index;
///   - buffer allocation fails.
/// - Throws: Errors from `SafeTensorsReader.read(tensor:)`.
private static func loadFloatEmbed(named: String, from tensors: [TensorDescriptor],
                                   index: SafeTensorsIndex?,
                                   readers: [String: SafeTensorsReader],
                                   device: MTLDevice,
                                   hiddenSize: Int) throws -> FloatWeights? {
    // Each stored name is registered as-is, plus once stripped for every matching prefix.
    // Later prefixes overwrite earlier ones on key collisions.
    let strippablePrefixes = ["language_model.model.", "model.language_model.model.", "model.language_model."]
    let lookup = tensors.reduce(into: [String: TensorDescriptor]()) { dict, desc in
        dict[desc.name] = desc
        for prefix in strippablePrefixes where desc.name.hasPrefix(prefix) {
            dict[String(desc.name.dropFirst(prefix.count))] = desc
        }
    }
    // Both dimensions come from the shape, so anything other than a 2-D BF16 tensor is rejected.
    guard let wDesc = lookup["\(named).weight"],
          wDesc.dtype == .bf16,
          wDesc.shape.count == 2 else {
        return nil
    }
    let wReader: SafeTensorsReader
    if let idx = index {
        guard let wShard = idx.weightMap[wDesc.name] else { return nil }
        wReader = readers[wShard]!
    } else {
        wReader = readers["model.safetensors"]!
    }
    let wData = try wReader.read(tensor: wDesc)
    let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
    let outDim = wDesc.shape[0]
    let inDim = wDesc.shape[1]
    guard let wBuf = device.makeBuffer(
        bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    ) else { return nil }
    return FloatWeights(weight: wBuf, inDim: inDim, outDim: outDim)
}
/// Load non-quantized bf16 layer weights as FloatWeights.
///
/// Looks up `<named>.weight`, widens the BF16 data to F32, and uploads it as a matrix of
/// `[shape[0], shape[1]]` (outDim × inDim). The lookup also accepts stored names with a
/// `language_model.model.` or `model.language_model.` prefix, matched by their stripped form.
///
/// - Parameters:
///   - named: Tensor base name.
///   - tensors: All tensor descriptors in the checkpoint.
///   - index: Shard index, or `nil` for a single `model.safetensors` file.
///   - readers: Open readers keyed by shard file name.
///   - device: Metal device used for allocation.
/// - Returns: The uploaded weights, or `nil` in any of these cases:
///   - the tensor is missing;
///   - it is not BF16;
///   - it is not 2-D;
///   - it is absent from the shard index;
///   - buffer allocation fails.
/// - Throws: Errors from `SafeTensorsReader.read(tensor:)`.
private static func loadFloatWeight(named: String, from tensors: [TensorDescriptor],
                                    index: SafeTensorsIndex?,
                                    readers: [String: SafeTensorsReader],
                                    device: MTLDevice) throws -> FloatWeights? {
    // Each stored name is registered as-is, plus once stripped for every matching prefix.
    // Later prefixes overwrite earlier ones on key collisions.
    let strippablePrefixes = ["language_model.model.", "model.language_model."]
    let lookup = tensors.reduce(into: [String: TensorDescriptor]()) { dict, desc in
        dict[desc.name] = desc
        for prefix in strippablePrefixes where desc.name.hasPrefix(prefix) {
            dict[String(desc.name.dropFirst(prefix.count))] = desc
        }
    }
    // Both dimensions come from the shape, so anything other than a 2-D BF16 tensor is rejected.
    guard let wDesc = lookup["\(named).weight"],
          wDesc.dtype == .bf16,
          wDesc.shape.count == 2 else {
        return nil
    }
    let wReader: SafeTensorsReader
    if let idx = index {
        guard let wShard = idx.weightMap[wDesc.name] else { return nil }
        wReader = readers[wShard]!
    } else {
        wReader = readers["model.safetensors"]!
    }
    let wData = try wReader.read(tensor: wDesc)
    let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
    let outDim = wDesc.shape[0]
    let inDim = wDesc.shape[1]
    guard let wBuf = device.makeBuffer(
        bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    ) else { return nil }
    return FloatWeights(weight: wBuf, inDim: inDim, outDim: outDim)
}
/// Load a 3D expert tensor [numExperts, expertOutDim, inDimPacked] as a contiguous MoEExpertGroup.
/// The data layout is: expert0[outDim, inDimPacked], expert1[outDim, inDimPacked], ...
/// Per-expert access is done via byte offsets into the shared buffers.
///
/// Scales (and optional biases) are expected as `[numExperts, expertOutDim, numGroups]`.
/// If the scales tensor is not 3-D, `numGroups` falls back to `expertInDim / 64`.
/// F32 scales and biases are copied; any other dtype is decoded as BF16. Missing biases
/// become a zero buffer the same size as the scales.
///
/// - Returns: The expert group, or `nil` in any of these cases:
///   - weight or scales are missing;
///   - the weight is not 3-D;
///   - the group count or `bits` is not positive;
///   - a shard is unknown;
///   - buffer allocation fails.
///
///   A mismatch against the expected `[numExperts, expertOutDim, inDimPacked]` is only logged.
/// - Throws: Errors from `SafeTensorsReader.read(tensor:)`.
private static func loadExpertGroup(named: String, from tensors: [TensorDescriptor],
                                    index: SafeTensorsIndex?,
                                    readers: [String: SafeTensorsReader],
                                    device: MTLDevice,
                                    numExperts: Int,
                                    expertOutDim: Int,
                                    expertInDim: Int,
                                    bits: Int = 4) throws -> MoEExpertGroup? {
    let wName = "\(named).weight"
    let sName = "\(named).scales"
    let bName = "\(named).biases"
    guard let wDesc = findTensor(wName, in: tensors),
          let sDesc = findTensor(sName, in: tensors)
    else {
        print(" loadExpertGroup: missing weight or scales for \(named)")
        return nil
    }
    // Weight: [numExperts, expertOutDim, inDimPacked] uint32
    guard wDesc.shape.count == 3 else {
        print(" loadExpertGroup: expected 3D weight, got \(wDesc.shape)")
        return nil
    }
    // Scales: [numExperts, expertOutDim, numGroups] bf16; biases share that shape.
    let numGroups = sDesc.shape.count > 2 ? sDesc.shape[2] : expertInDim / 64
    // Avoid division-by-zero traps on malformed scales, a tiny expertInDim, or bad `bits`.
    guard numGroups > 0, bits > 0 else {
        print(" loadExpertGroup: invalid numGroups \(numGroups) or bits \(bits) for \(named)")
        return nil
    }
    let expertGroupSize = expertInDim / numGroups
    // Look the optional biases up once; the result drives both reader resolution and reading.
    let bDesc = findTensor(bName, in: tensors)

    // Resolve readers (tensors may live in different shards).
    let wReader: SafeTensorsReader
    let sReader: SafeTensorsReader
    let bReader: SafeTensorsReader?
    if let idx = index {
        guard let wShard = idx.weightMap[wDesc.name],
              let sShard = idx.weightMap[sDesc.name] else { return nil }
        wReader = readers[wShard]!
        sReader = readers[sShard]!
        if let bDesc = bDesc, let bShard = idx.weightMap[bDesc.name] {
            bReader = readers[bShard]
        } else {
            bReader = nil
        }
    } else {
        wReader = readers["model.safetensors"]!
        sReader = wReader
        bReader = wReader
    }

    // F32 data is copied (alignment-safe); any other dtype is decoded as BF16.
    func decodeFloats(_ data: Data, of desc: TensorDescriptor) -> [Float] {
        guard desc.dtype == .f32 else { return SafeTensorsReader.bf16ToFloat32(data) }
        var values = [Float](repeating: 0, count: data.count / MemoryLayout<Float>.stride)
        _ = values.withUnsafeMutableBufferPointer { data.copyBytes(to: $0) }
        return values
    }

    let wData = try wReader.read(tensor: wDesc)
    var sFloats = try decodeFloats(sReader.read(tensor: sDesc), of: sDesc)
    let bFloats: [Float]?
    if let bDesc = bDesc, let bReader = bReader {
        bFloats = try decodeFloats(bReader.read(tensor: bDesc), of: bDesc)
    } else {
        bFloats = nil
    }

    // Normalize scales for groupSize=32 custom quantization.
    // NOTE(review): biases are not rescaled alongside the scales. Confirm against the quantizer.
    if expertGroupSize == 32 {
        let divisor = Float(expertInDim)
        for i in sFloats.indices {
            sFloats[i] /= divisor
        }
    }

    let valsPerU32 = 32 / bits
    let inDimPacked = expertInDim / valsPerU32
    // Expected: wDesc.shape = [numExperts, expertOutDim, inDimPacked]. A mismatch is only
    // logged. NOTE(review): the buffers are still built from the actual tensor sizes, so
    // per-expert offsets computed from the expected dims would not match them.
    if wDesc.shape[0] != numExperts ||
        wDesc.shape[1] != expertOutDim ||
        wDesc.shape[2] != inDimPacked {
        print(" loadExpertGroup: shape mismatch: got \(wDesc.shape), expected [\(numExperts), \(expertOutDim), \(inDimPacked)]")
    }

    // Uploads `values` into a new shared-storage buffer.
    func makeFloatBuffer(_ values: [Float]) -> MTLBuffer? {
        device.makeBuffer(bytes: values, length: values.count * MemoryLayout<Float>.stride,
                          options: .storageModeShared)
    }

    // Weight buffer: the full 3D tensor, flattened.
    guard let wBuf = device.makeBuffer(
        bytes: (wData as NSData).bytes, length: wData.count,
        options: .storageModeShared
    ) else { return nil }
    guard let sBuf = makeFloatBuffer(sFloats) else { return nil }
    // Missing biases become zeros sized like the loaded scales, so both buffers share a layout.
    guard let bBuf = makeFloatBuffer(bFloats ?? [Float](repeating: 0.0, count: sFloats.count)) else {
        return nil
    }
    return MoEExpertGroup(
        weight: wBuf, scales: sBuf, biases: bBuf,
        expertOutDim: expertOutDim,
        expertInDim: expertInDim,
        numGroups: numGroups,
        numExperts: numExperts,
        bits: bits
    )
}
/// Finds a tensor descriptor by name, with a fallback for prefixed names.
///
/// The exact name is tried first. If `name` begins with `language_model.model.`, the name
/// with that prefix removed is tried next. Both lookups are linear scans of `tensors`.
///
/// - Returns: The first descriptor matching either candidate, or `nil`.
private static func findTensor(_ name: String, in tensors: [TensorDescriptor]) -> TensorDescriptor? {
    let prefix = "language_model.model."
    var candidates = [name]
    if name.hasPrefix(prefix) {
        candidates.append(String(name.dropFirst(prefix.count)))
    }
    for candidate in candidates {
        if let match = tensors.first(where: { $0.name == candidate }) {
            return match
        }
    }
    return nil
}
// ── Forward ───────────────────────────────────────
/// Run one step of the model: token → logits.
///
/// Dequantizes row `tokenId` of the tied embedding matrix into `temps.io` and applies
/// `embedScale`. When the model has per-layer embeddings, it also fills
/// `perLayerEmbedBuffer` with the combined per-layer inputs. The layer stack, final norm,
/// LM head, softcapping and readback are delegated to `forwardAfterEmbedding`.
///
/// - Parameters:
///   - tokenId: Row of `embedWeight` to dequantize. NOTE(review): not bounds-checked here.
///   - position: Sequence position. Position 0 also triggers the NaN diagnostics below.
///   - debug: Enables extra debug printing at position 0.
/// - Returns: The `vocabSize` logits read back from `logitsBuffer`.
/// - Throws: Errors from kernel pipeline lookup in the dispatch helpers.
public func forward(tokenId: Int, position: Int, debug: Bool = false) throws -> [Float] {
    let h = temps.io
    // ── 1. Embedding lookup ──
    try dequantizeRow(weight: embedWeight, tokenId: tokenId, output: h)
    // Check embedding for NaN (position 0 only; printed even when `debug` is false)
    if position == 0 {
        let embedVals = engine.readFloats(from: h, count: min(20, hiddenSize))
        let hasNaN = embedVals.contains { $0.isNaN }
        let nanCount = embedVals.filter { $0.isNaN }.count
        print("TEXT Embedding: sample=\(embedVals.prefix(10)), NaN=\(nanCount)/\(min(20, hiddenSize)), hasNaN=\(hasNaN)")
    }
    if debug && position == 0 {
        // Reads h through its CPU-visible contents pointer.
        let hPtr = h.contents().bindMemory(to: Float.self, capacity: hiddenSize)
        print("Pos \(position) token \(tokenId): BEFORE embed_scale h[0:5]=[\(hPtr[0]), \(hPtr[1]), \(hPtr[2]), \(hPtr[3]), \(hPtr[4])]")
        print(" embedScale = \(embedScale)")
    }
    // ── 2. Embedding scale ── (skipped when the scale is exactly 1)
    if embedScale != 1.0 {
        try scaleBuffer(h, scale: embedScale, count: hiddenSize)
    }
    // ── 2b. Per-layer embedding (E4B only) ──
    // Only use per-layer embedding if model has it (embedTokensPerLayerWeight != nil)
    let usePerLayer = embedTokensPerLayerWeight != nil && perLayerEmbedBuffer != nil && perLayerContextBuffer != nil
    if usePerLayer, let plWeight = embedTokensPerLayerWeight,
       let plBuf = perLayerEmbedBuffer,
       let ctxBuf = perLayerContextBuffer {
        // One perLayerInputSize-wide slice per layer, laid out back to back.
        let totalPerLayer = perLayerInputSize * numHiddenLayers
        // Step 1: Token-identity component
        // get_per_layer_inputs: embed_tokens_per_layer(input_ids) with scale sqrt(perLayerSize)
        try dequantizeRow(weight: plWeight, tokenId: tokenId, output: plBuf, nCols: totalPerLayer)
        let plEmbedScale = sqrt(Float(perLayerInputSize))
        try scaleBuffer(plBuf, scale: plEmbedScale, count: totalPerLayer)
        // Step 2: Context-aware projection
        // project_per_layer_inputs: per_layer_model_projection(inputs_embeds) * scale
        // NOTE(review): without a projection, plBuf keeps only the scaled token identity and
        // perLayerInputScaleVal is never applied. Confirm that this is intended.
        if let projBuf = perLayerModelProjection {
            // Regular matmul (not quantized), e.g. [10752, 2560] @ [2560] -> [10752] for E4B
            try matmulBF16(input: h, weight: projBuf, output: ctxBuf,
                           inDim: hiddenSize, outDim: perLayerModelProjectionOutDim)
            // Scale by perLayerModelProjectionScaleVal (expected to be 1/sqrt(hiddenSize))
            try scaleBuffer(ctxBuf, scale: perLayerModelProjectionScaleVal, count: totalPerLayer)
            // Apply per_layer_projection_norm (RMSNorm on each layer's slice)
            // CRITICAL: RMSNorm is NOT safe for in-place with multiple threadgroups
            // Must use separate input/output buffers. Use plBuf as temp.
            if let norm = perLayerProjectionNorm {
                // Norm each layer's slice: ctxBuf -> plBuf (offsets are in bytes, 4 per Float)
                for layerIdx in 0..<numHiddenLayers {
                    let offset = layerIdx * perLayerInputSize
                    try rmsNorm(input: ctxBuf, weight: norm, output: plBuf,
                                count: perLayerInputSize, eps: rmsNormEps,
                                inputOffset: offset * 4, weightOffset: 0, outputOffset: offset * 4)
                }
                // Copy normed result back to ctxBuf
                let cmdBufNorm = engine.commandQueue.makeCommandBuffer()!
                let blitNorm = cmdBufNorm.makeBlitCommandEncoder()!
                blitNorm.copy(from: plBuf, sourceOffset: 0,
                              to: ctxBuf, destinationOffset: 0,
                              size: totalPerLayer * 4)
                blitNorm.endEncoding()
                cmdBufNorm.commit()
                cmdBufNorm.waitUntilCompleted()
                // Re-compute token identity (was overwritten by norm output)
                try dequantizeRow(weight: plWeight, tokenId: tokenId, output: plBuf, nCols: totalPerLayer)
                try scaleBuffer(plBuf, scale: plEmbedScale, count: totalPerLayer)
            }
            // Combine: (context_projection + token_identity) * per_layer_input_scale
            // Python: return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
            try eltwiseAddScaledModel(a: ctxBuf, scaleA: 1.0,
                                      b: plBuf, scaleB: 1.0,
                                      output: ctxBuf, count: totalPerLayer)
            try scaleBuffer(ctxBuf, scale: perLayerInputScaleVal, count: totalPerLayer)
            // Copy to plBuf for layer use (layers read perLayerEmbedBuffer)
            let cmdBuf = engine.commandQueue.makeCommandBuffer()!
            let blit = cmdBuf.makeBlitCommandEncoder()!
            blit.copy(from: ctxBuf, sourceOffset: 0,
                      to: plBuf, destinationOffset: 0,
                      size: totalPerLayer * 4)
            blit.endEncoding()
            cmdBuf.commit()
            cmdBuf.waitUntilCompleted()
        }
    }
    // ── 3–7. Layers, final norm, LM head, softcapping, readback (forwardAfterEmbedding) ──
    return try forwardAfterEmbedding(position: position, debug: debug)
}
/// Forward pass using a pre-computed hidden state (skips embedding lookup).
///
/// Blits one hidden state (`hiddenSize` floats) from `hiddenBuffer` into `temps.io`. It then
/// runs the shared path: layers → final norm → LM head → readback.
///
/// - Parameters:
///   - hiddenBuffer: MTLBuffer containing one or more hidden states.
///   - offset: Byte offset into hiddenBuffer for this position's hidden state.
///   - position: Sequence position (used for KV cache indexing).
///   - debug: Enables extra debug printing at position 0.
/// - Returns: The `vocabSize` logits.
/// - Precondition: `offset >= 0` and `offset + hiddenSize * 4 <= hiddenBuffer.length`.
///   A blit past the end of the source buffer is invalid Metal usage.
/// - Throws: Errors from the layer stack and kernel dispatch helpers.
public func forwardFromHidden(hiddenBuffer: MTLBuffer, offset: Int = 0, position: Int, debug: Bool = false) throws -> [Float] {
    let h = temps.io
    let copySize = hiddenSize * MemoryLayout<Float>.stride
    // Fail loudly on an out-of-range source window instead of issuing an invalid blit.
    precondition(offset >= 0 && offset <= hiddenBuffer.length - copySize,
                 "forwardFromHidden: offset \(offset) + \(copySize) bytes exceeds buffer length \(hiddenBuffer.length)")
    let cmdBuf = engine.commandQueue.makeCommandBuffer()!
    let blit = cmdBuf.makeBlitCommandEncoder()!
    blit.copy(from: hiddenBuffer, sourceOffset: offset,
              to: h, destinationOffset: 0,
              size: copySize)
    blit.endEncoding()
    cmdBuf.commit()
    cmdBuf.waitUntilCompleted()
    return try forwardAfterEmbedding(position: position, debug: debug)
}
/// Shared logic after embedding is in temps.io: layers → LM head → readback.
///
/// Steps, in order:
/// 1. Runs every decoder layer on `temps.io`.
/// 2. Applies the final RMSNorm into `temps.ns` when `finalNorm` is present.
/// 3. Projects through the tied embedding matrix into `logitsBuffer`.
/// 4. Applies logit softcapping when `finalLogitSoftcapping` is set.
/// 5. Reads back `vocabSize` floats.
///
/// - Parameters:
///   - position: Sequence position. It is passed to each layer and gates the position-0 diagnostics.
///   - debug: Enables extra debug printing.
/// - Returns: The logits, `vocabSize` values.
/// - Throws: Errors from layer forwards and kernel dispatch helpers.
private func forwardAfterEmbedding(position: Int, debug: Bool = false) throws -> [Float] {
    let h = temps.io
    // ── 3. Layer loop ──
    if debug && position == 0 {
        print(" KV config: firstKVShared=\(firstKVShared), numKvShared=\(numKvShared), numLayers=\(numHiddenLayers)")
    }
    if position == 0 {
        print("TEXT Starting layer loop: numHiddenLayers=\(numHiddenLayers)")
        fflush(stdout)
    }
    for layerIdx in 0..<numHiddenLayers {
        // Layers before firstKVShared own and write their cache. Later layers read the cache
        // of the layer given by kvSourceMap, falling back to layerIdx - numKvShared.
        let isOwner = layerIdx < firstKVShared
        let cacheIdx = isOwner ? layerIdx : (kvSourceMap[layerIdx] ?? (layerIdx - numKvShared))
        let cache = kvCaches[cacheIdx]
        // Per-layer input offset (E4B only), in bytes
        let plOffset = perLayerInputSize > 0 ? layerIdx * perLayerInputSize * MemoryLayout<Float>.stride : 0
        try layers[layerIdx].forward(
            input: h, position: position,
            kvCache: cache, shouldStoreKV: isOwner,
            temps: temps, engine: engine,
            perLayerInput: perLayerEmbedBuffer,
            perLayerInputOffset: plOffset
        )
        // Debug: check for NaN after each layer (only on position 0, first 5 layers)
        if position == 0 && layerIdx < 5 {
            let hPtr = h.contents().bindMemory(to: Float.self, capacity: hiddenSize)
            let sample = (0..<min(10, hiddenSize)).map { hPtr[$0] }
            let hasNaN = sample.contains { $0.isNaN }
            let nanCount = sample.filter { $0.isNaN }.count
            print("TEXT After Layer \(layerIdx): sample=\(sample.prefix(5)), NaN=\(nanCount)/\(min(10, hiddenSize)), hasNaN=\(hasNaN)")
            fflush(stdout)
        }
        if debug && position == 0 && layerIdx < 10 {
            let hPtr = h.contents().bindMemory(to: Float.self, capacity: hiddenSize)
            let ls = layers[layerIdx].layerScalar
            // L2 norm of the first five components only (a quick magnitude probe)
            let magnitude = sqrt(hPtr[0]*hPtr[0] + hPtr[1]*hPtr[1] + hPtr[2]*hPtr[2] + hPtr[3]*hPtr[3] + hPtr[4]*hPtr[4])
            print(" ✓ After Layer \(layerIdx): h[0:5]=[\(hPtr[0]), \(hPtr[1]), \(hPtr[2]), \(hPtr[3]), \(hPtr[4])], scalar=\(ls), mag=\(magnitude)")
        }
    }
    // ── 4. Final RMSNorm ──
    // lmInput stays `h` when the model has no final norm; otherwise it becomes temps.ns.
    var lmInput = h
    // Check hidden state after layers
    if position == 0 {
        let hVals = engine.readFloats(from: h, count: min(20, hiddenSize))
        let hasNaN = hVals.contains { $0.isNaN }
        let nanCount = hVals.filter { $0.isNaN }.count
        print("TEXT After layers: sample=\(hVals.prefix(10)), NaN=\(nanCount)/\(min(20, hiddenSize)), hasNaN=\(hasNaN)")
    }
    if let fn = finalNorm {
        try rmsNorm(input: h, weight: fn, output: temps.ns,
                    count: hiddenSize, eps: rmsNormEps)
        lmInput = temps.ns
        // Debug: check hidden state after norm
        if position == 0 {
            let hiddenVals = engine.readFloats(from: temps.ns, count: min(20, hiddenSize))
            let hasNaN = hiddenVals.contains { $0.isNaN }
            let nanCount = hiddenVals.filter { $0.isNaN }.count
            print("TEXT After finalNorm: sample=\(hiddenVals.prefix(10)), NaN=\(nanCount)/\(min(20, hiddenSize)), hasNaN=\(hasNaN)")
        }
    }
    // ── 5. LM head (tied embeddings): logits = embedWeight · lmInput ──
    try quantizedMatmulModel(input: lmInput, weights: embedWeight, output: logitsBuffer)
    // Check logits after LM head
    if position == 0 {
        let logitsVals = engine.readFloats(from: logitsBuffer, count: min(20, vocabSize))
        let hasNaN = logitsVals.contains { $0.isNaN }
        let nanCount = logitsVals.filter { $0.isNaN }.count
        print("TEXT After LM head: sample=\(logitsVals.prefix(10)), NaN=\(nanCount)/\(min(20, vocabSize)), hasNaN=\(hasNaN)")
    }
    // ── 5b. Logits scaling for custom quantization (groupSize=32) ──
    // groupSize=32 checkpoints used to produce logits ~200x too large. That is now corrected
    // at load time (scale normalization in quantizedGroup/loadExpertGroup), so no extra
    // logit scaling is applied here.
    // ── 6. Logit softcapping ──
    if let cap = finalLogitSoftcapping {
        if debug && position == 0 {
            print(" Applying logit softcapping with cap=\(cap)")
        }
        try applyLogitSoftcapping(buffer: logitsBuffer, cap: cap, count: vocabSize)
    } else {
        if debug && position == 0 {
            print(" No logit softcapping (cap is nil)")
        }
    }
    // ── 7. Read back ──
    let logits = engine.readFloats(from: logitsBuffer, count: vocabSize)
    if debug && position < 3 {
        let maxLogit = logits.max() ?? 0
        let minLogit = logits.min() ?? 0
        // Full sort of the vocabulary; debug-only cost.
        let sorted = logits.enumerated().sorted(by: { $0.element > $1.element })
        let top5 = sorted.prefix(5).map { "\($0.offset): \($0.element)" }
        print(" Final logits: max=\(maxLogit), min=\(minLogit), top5: \(top5)")
    }
    return logits
}
// ── Model-level kernel dispatches ─────────────────
/// Dequantizes one row of a quantized matrix into `output` with the `dequantize_row` kernel.
///
/// Slot layout:
/// - Buffers: packed weights (0), scales (1), biases (2), output (3).
/// - Scalars: column count (4, `UInt32`), row index (5, `Int32`), group size (6, `UInt32`).
///
/// One thread is launched per output column. The call blocks until the GPU finishes.
///
/// - Parameters:
///   - weight: Quantized matrix to read from.
///   - tokenId: Row index to dequantize. NOTE(review): not bounds-checked against `weight.outDim`.
///   - output: Destination buffer, written from offset 0.
///   - nCols: Number of values to produce. Defaults to `hiddenSize`.
/// - Throws: Any error from `engine.pipeline(named:)`.
func dequantizeRow(weight: QuantizedWeights, tokenId: Int,
                   output: MTLBuffer, nCols: Int? = nil) throws {
    let pipeline = try engine.pipeline(named: "dequantize_row")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    let buffers: [MTLBuffer?] = [weight.weight, weight.scales, weight.biases, output]
    for (slot, buffer) in buffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    let columns = nCols ?? hiddenSize
    var columnWord = UInt32(columns)
    var rowIndex = Int32(tokenId)
    var groupWord = UInt32(weight.groupSize)
    encoder.setBytes(&columnWord, length: MemoryLayout<UInt32>.size, index: 4)
    encoder.setBytes(&rowIndex, length: MemoryLayout<Int32>.size, index: 5)
    encoder.setBytes(&groupWord, length: MemoryLayout<UInt32>.size, index: 6)

    let grid = MTLSize(width: columns, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: columns))
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
/// Multiplies the first `count` floats of `buf` by `scale` with the `eltwise_scale` kernel.
///
/// Only `buf` is bound (slot 0), so the operation is presumably in place. Callers in this
/// file rely on that. The scale goes in slot 1 and the count in slot 2. The call blocks
/// until the GPU finishes.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func scaleBuffer(_ buf: MTLBuffer, scale: Float, count: Int) throws {
    let pipeline = try engine.pipeline(named: "eltwise_scale")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(buf, offset: 0, index: 0)

    var factor = scale
    var elementCount = UInt32(count)
    encoder.setBytes(&factor, length: MemoryLayout<Float>.size, index: 1)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 2)

    let grid = MTLSize(width: count, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: count))
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
/// Multiplies one input vector by a quantized matrix with the `quantized_matmul` kernel.
///
/// Slot layout:
/// - Buffers: input (0), packed weights (1), scales (2), biases (3), output (4).
/// - Scalars, as `UInt32`: `inDim` (5), `outDim` (6), `groupSize` (7).
///
/// One thread is launched per output row (`weights.outDim`). The call blocks until the GPU finishes.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func quantizedMatmulModel(input: MTLBuffer, weights: QuantizedWeights,
                          output: MTLBuffer) throws {
    let pipeline = try engine.pipeline(named: "quantized_matmul")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    let buffers: [MTLBuffer?] = [input, weights.weight, weights.scales, weights.biases, output]
    for (slot, buffer) in buffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    // Shape metadata directly follows the five buffers.
    let metadata: [UInt32] = [UInt32(weights.inDim), UInt32(weights.outDim), UInt32(weights.groupSize)]
    for (i, value) in metadata.enumerated() {
        var word = value
        encoder.setBytes(&word, length: MemoryLayout<UInt32>.size, index: 5 + i)
    }

    let rows = weights.outDim
    let grid = MTLSize(width: rows, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: rows))
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
/// Applies logit softcapping to `buffer` in place using the `tanh_scale` kernel.
///
/// The same buffer is bound as input (slot 0) and output (slot 1). `cap` goes in slot 2 and
/// `count` in slot 3. The kernel presumably computes `cap * tanh(x / cap)`; its source is
/// not visible here. The call blocks until the GPU finishes.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func applyLogitSoftcapping(buffer: MTLBuffer, cap: Float, count: Int) throws {
    let pipeline = try engine.pipeline(named: "tanh_scale")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // In-place: source and destination are the same buffer.
    for slot in 0...1 {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    var capValue = cap
    var elementCount = UInt32(count)
    encoder.setBytes(&capValue, length: MemoryLayout<Float>.size, index: 2)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 3)

    let grid = MTLSize(width: count, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: count))
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
}