31427770b1
- Tokenizer fix: collect <0xXX> bytes and decode as UTF-8 (fixes Chinese/non-ASCII character decoding) - BPETokenizer + HuggingFaceTokenizer: both updated - Engine.swift: added writeFloats() utility method - FloatWeights struct added to Layer.swift (bf16 support) - attnQBits/KBits/VBits/OBits detection added to Model.swift - bf16 layer weight support from commit 48c0347 cherry-picked
1831 lines
86 KiB
Swift
1831 lines
86 KiB
Swift
import Foundation
|
||
import Metal
|
||
|
||
// ═══════════════════════════════════════════════════
|
||
// E4B Text Model — 42-layer forward pass
|
||
// ═══════════════════════════════════════════════════
|
||
|
||
/// KV sharing map: for each shared layer, which non-shared cache to read from.
///
/// The last `numKVShared` layers are treated as "shared". Each shared layer is
/// paired with a non-shared layer of the same attention type, walking the
/// same-type non-shared layers in reverse order: the first shared layer of a
/// type gets the last non-shared layer of that type, the next shared layer of
/// that type gets the one before it, and so on.
///
/// - Parameters:
///   - numHiddenLayers: Total number of layers. Negative values are treated as 0.
///   - numKVShared: Number of trailing layers that share KV. Clamped to
///     `0...numHiddenLayers`, so an inconsistent config yields an empty map
///     instead of trapping on an invalid range.
///   - layerTypesIsFull: Per-layer attention type (`true` = full attention).
///     Layers beyond the end of this array have no known type and are skipped.
/// - Returns: A dictionary from shared layer index to source layer index.
///   A shared layer gets no entry when the non-shared layers of its type have
///   already been exhausted.
func computeKVSourceMap(numHiddenLayers: Int, numKVShared: Int, layerTypesIsFull: [Bool]) -> [Int: Int] {
    let layerCount = max(0, numHiddenLayers)
    let sharedCount = min(max(0, numKVShared), layerCount)
    let firstShared = layerCount - sharedCount
    // Never index past the types we were actually given.
    let typedCount = min(layerCount, layerTypesIsFull.count)

    // Non-shared layer indices grouped by attention type, in ascending order.
    var sourcesByType: [Bool: [Int]] = [:]
    for src in 0..<min(firstShared, typedCount) {
        sourcesByType[layerTypesIsFull[src], default: []].append(src)
    }

    // remaining[type] counts down once per shared layer of that type; the value
    // after the decrement is the position (within that type's source list) of
    // the layer to read from. It may go negative once the sources run out.
    var remaining = sourcesByType.mapValues { $0.count }

    var map: [Int: Int] = [:]
    // stride(from:to:by:) is simply empty when firstShared >= typedCount.
    for layer in stride(from: firstShared, to: typedCount, by: 1) {
        let isFull = layerTypesIsFull[layer]
        let candidates = sourcesByType[isFull] ?? []
        let position = (remaining[isFull] ?? 0) - 1
        remaining[isFull] = position
        if candidates.indices.contains(position) {
            map[layer] = candidates[position]
        }
    }
    return map
}
|
||
|
||
public final class E4BModel: @unchecked Sendable {
|
||
public let engine: MarkBaseEngine
|
||
public let config: ModelConfig
|
||
public let layers: [E4BLayer]
|
||
public let kvCaches: [KVCache]
|
||
public let temps: ForwardTemps
|
||
public let embedWeight: QuantizedWeights
|
||
public let embedTokensPerLayerWeight: QuantizedWeights?
|
||
public let perLayerModelProjection: MTLBuffer? // Context-aware projection (BF16 as F32)
|
||
public let perLayerProjectionNorm: MTLBuffer? // Norm for context projection
|
||
public let perLayerInputSize: Int
|
||
public let perLayerModelProjectionOutDim: Int // 10752 = numLayers * perLayerSize
|
||
|
||
// Convenience
|
||
public let numHiddenLayers: Int
|
||
public let hiddenSize: Int
|
||
public let vocabSize: Int
|
||
public let firstKVShared: Int
|
||
public let numKvShared: Int
|
||
public let kvSourceMap: [Int: Int]
|
||
public let rmsNormEps: Float
|
||
public let finalLogitSoftcapping: Float?
|
||
public let embedScale: Float
|
||
public let perLayerInputScaleVal: Float
|
||
public let perLayerModelProjectionScaleVal: Float // 1/sqrt(hiddenSize)
|
||
public let layerTypesIsFull: [Bool] // true = full_attention, false = sliding_attention
|
||
public let maxContext: Int
|
||
|
||
// LM head output buffer
|
||
let logitsBuffer: MTLBuffer
|
||
|
||
// Per-layer embedding buffer (for current token)
|
||
let perLayerEmbedBuffer: MTLBuffer?
|
||
let perLayerContextBuffer: MTLBuffer? // Context-aware projection result
|
||
|
||
// Final RMSNorm (before LM head)
|
||
let finalNorm: MTLBuffer?
|
||
|
||
// ── Init ──────────────────────────────────────────
|
||
|
||
public init(modelDir: String, engine: MarkBaseEngine, maxContextLength: Int = 8192) throws {
|
||
self.engine = engine
|
||
self.maxContext = maxContextLength
|
||
|
||
// Load config
|
||
let cfg = try ModelConfig.load(from: modelDir)
|
||
self.config = cfg
|
||
self.numHiddenLayers = cfg.numHiddenLayers ?? 42
|
||
self.hiddenSize = cfg.hiddenSize ?? 2560
|
||
self.vocabSize = cfg.vocabSize ?? 262144
|
||
self.rmsNormEps = cfg.rmsNormEps ?? 1e-6
|
||
self.finalLogitSoftcapping = cfg.finalLogitSoftcapping
|
||
self.embedScale = cfg.embedScale ?? sqrt(Float(cfg.hiddenSize ?? 2560))
|
||
self.numKvShared = cfg.numKvSharedLayers ?? 0 // 0 = no KV sharing (default for models without this config)
|
||
self.firstKVShared = numHiddenLayers - numKvShared
|
||
let pattern = cfg.slidingWindowPattern ?? 6 // Gemma 4: 5 sliding + 1 full = pattern 6
|
||
self.perLayerInputScaleVal = cfg.perLayerInputScale ?? sqrt(0.5)
|
||
// Determine layer types from config, or derive from pattern
|
||
if let lt = cfg.layerTypes {
|
||
self.layerTypesIsFull = lt.map { $0 == "full_attention" }
|
||
} else {
|
||
var derived: [Bool] = []
|
||
for i in 0..<numHiddenLayers {
|
||
derived.append((i % pattern) == (pattern - 1))
|
||
}
|
||
self.layerTypesIsFull = derived
|
||
}
|
||
self.kvSourceMap = computeKVSourceMap(
|
||
numHiddenLayers: numHiddenLayers,
|
||
numKVShared: numKvShared,
|
||
layerTypesIsFull: layerTypesIsFull
|
||
)
|
||
|
||
// Load all tensors (support both single-file and sharded)
|
||
let singleFile = "\(modelDir)/model.safetensors"
|
||
let index: SafeTensorsIndex?
|
||
let readers: [String: SafeTensorsReader] // shard file -> reader
|
||
|
||
if FileManager.default.fileExists(atPath: singleFile) {
|
||
// Single-file model (E4B)
|
||
index = nil
|
||
let reader = try SafeTensorsReader(path: singleFile)
|
||
readers = ["model.safetensors": reader]
|
||
} else {
|
||
// Sharded model (12B): load all shards in parallel
|
||
let indexPath = "\(modelDir)/model.safetensors.index.json"
|
||
if !FileManager.default.fileExists(atPath: indexPath) {
|
||
throw WeightError.readFailed("No model.safetensors or index.json found")
|
||
}
|
||
|
||
// Load index
|
||
index = try SafeTensorsIndex(modelDir: modelDir)
|
||
|
||
// Parallel shard loading (critical optimization for large models)
|
||
print("Loading \(index!.shardFiles.count) shards in parallel...")
|
||
let loadStart = Date()
|
||
|
||
let shardFiles = index!.shardFiles.sorted()
|
||
var loadedReaders: [SafeTensorsReader?] = Array(repeating: nil, count: shardFiles.count)
|
||
var loadErrors: [Error?] = Array(repeating: nil, count: shardFiles.count)
|
||
|
||
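// Use DispatchGroup for parallel loading. NOTE(review): each task writes a distinct index of the captured `loadedReaders`/`loadErrors` arrays, but concurrent element writes to a Swift Array are not guaranteed to be thread-safe — consider a lock or collecting per-task results.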
|
||
let dispatchGroup = DispatchGroup()
|
||
let queue = DispatchQueue(label: "shard-loading", attributes: .concurrent)
|
||
|
||
for (idx, shardFile) in shardFiles.enumerated() {
|
||
dispatchGroup.enter()
|
||
queue.async {
|
||
do {
|
||
let shardPath = "\(modelDir)/\(shardFile)"
|
||
let reader = try SafeTensorsReader(path: shardPath)
|
||
loadedReaders[idx] = reader // Thread-safe: each thread writes to different index
|
||
} catch {
|
||
loadErrors[idx] = error
|
||
}
|
||
dispatchGroup.leave()
|
||
}
|
||
}
|
||
|
||
dispatchGroup.wait()
|
||
|
||
// Check for errors and build dictionary (sequential, thread-safe)
|
||
var readersDict: [String: SafeTensorsReader] = [:]
|
||
for (idx, error) in loadErrors.enumerated() {
|
||
if let err = error {
|
||
throw WeightError.readFailed("Failed to load shard \(shardFiles[idx]): \(err)")
|
||
}
|
||
if let reader = loadedReaders[idx] {
|
||
readersDict[shardFiles[idx]] = reader
|
||
}
|
||
}
|
||
|
||
let loadTime = Date().timeIntervalSince(loadStart) * 1000
|
||
print("✓ Parallel loaded \(readersDict.count) shards in \(String(format: "%.1f", loadTime))ms")
|
||
print(" Shards: \(shardFiles)")
|
||
readers = readersDict
|
||
}
|
||
|
||
// Helper functions for unified tensor access
|
||
/// Returns the reader holding tensor `name`, or nil when it cannot be resolved.
/// Without a shard index the single-file reader is returned for every name;
/// with one, the tensor's shard is looked up in the index's weight map.
func getReader(forTensor name: String) -> SafeTensorsReader? {
    guard let shardIndex = index else {
        return readers["model.safetensors"]
    }
    return shardIndex.weightMap[name].flatMap { readers[$0] }
}
|
||
|
||
/// Returns the descriptors of every tensor across all loaded readers.
///
/// A single-file model has exactly one reader and a sharded model has one per
/// shard, so both cases reduce to concatenating each reader's `allTensors`.
/// This avoids force-unwrapping the "model.safetensors" entry. For sharded
/// models the order follows dictionary iteration order, as before.
func getAllTensorDescriptors() -> [TensorDescriptor] {
    return readers.values.flatMap { $0.allTensors }
}
|
||
|
||
let allTensors = getAllTensorDescriptors()
|
||
print("✓ Total tensors: \(allTensors.count)")
|
||
|
||
// E4B MLX models use "language_model.model." prefix, but converted models may omit it
|
||
// Detect which format by checking for "layers.0.self_attn.q_proj.weight"
|
||
let P: String
|
||
if allTensors.contains(where: { $0.name == "layers.0.self_attn.q_proj.weight" }) {
|
||
P = ""
|
||
print(" Using short prefix (no language_model.model.)")
|
||
} else {
|
||
P = "language_model.model."
|
||
print(" Using long prefix (language_model.model.)")
|
||
}
|
||
|
||
// Config values may be wrong (e.g. 26B-standard has nHeads=8 but q_proj has 16 heads).
|
||
var effectiveNHeads = cfg.numAttentionHeads ?? 16
|
||
var effectiveNKvHeads = cfg.numKeyValueHeads ?? 8
|
||
var effectiveGlobalKvHeads = cfg.numGlobalKeyValueHeads
|
||
let slidingHd = cfg.slidingHeadDim ?? cfg.headDim ?? 256
|
||
let globalHd = cfg.globalHeadDim ?? cfg.headDim ?? 512
|
||
if let qDesc = allTensors.first(where: { $0.name.contains("language_model") && $0.name.hasSuffix("self_attn.q_proj.weight") }) {
|
||
let qOut = qDesc.shape[0]
|
||
if qOut > 0 {
|
||
let detected = qOut % slidingHd == 0 ? qOut / slidingHd : qOut / globalHd
|
||
if detected != effectiveNHeads {
|
||
print(" ⚠ q_proj out_dim=\(qOut) → nHeads=\(detected) (config says \(effectiveNHeads))")
|
||
effectiveNHeads = detected
|
||
}
|
||
}
|
||
}
|
||
if let kDesc = allTensors.first(where: { $0.name.contains("language_model") && $0.name.hasSuffix("self_attn.k_proj.weight") }) {
|
||
let kOut = kDesc.shape[0]
|
||
if kOut > 0 && kOut % slidingHd == 0 {
|
||
let detected = kOut / slidingHd
|
||
if detected != effectiveNKvHeads {
|
||
print(" ⚠ k_proj out_dim=\(kOut), head_dim=\(slidingHd) → nKvHeads=\(detected) (config says \(effectiveNKvHeads))")
|
||
effectiveNKvHeads = detected
|
||
}
|
||
}
|
||
}
|
||
// Detect global kv heads from first full attention layer
|
||
// Also detect globalHeadDim from k_norm.weight shape
|
||
var detectedGlobalHd: Int? = nil
|
||
if let firstFullIdx = layerTypesIsFull.firstIndex(of: true) {
|
||
// Detect globalHeadDim from k_norm.weight shape
|
||
let kNormName = "\(P)layers.\(firstFullIdx).self_attn.k_norm.weight"
|
||
if let kNormDesc = allTensors.first(where: { $0.name == kNormName }) {
|
||
let kNormShape = kNormDesc.shape[0]
|
||
if kNormShape > 0 {
|
||
detectedGlobalHd = kNormShape
|
||
print(" ⚠ Detected globalHeadDim=\(kNormShape) from k_norm.weight (config: \(cfg.globalHeadDim ?? -1))")
|
||
}
|
||
}
|
||
|
||
// Detect globalKvHeads from k_proj.weight
|
||
let fullKName = "\(P)layers.\(firstFullIdx).self_attn.k_proj.weight"
|
||
if let fullKDesc = allTensors.first(where: { $0.name == fullKName }) {
|
||
let kOut = fullKDesc.shape[0]
|
||
// Use detectedGlobalHd if available, otherwise fallback
|
||
let actualGlobalHd = detectedGlobalHd ?? globalHd
|
||
if kOut > 0 && kOut % actualGlobalHd == 0 {
|
||
let detected = kOut / actualGlobalHd
|
||
if effectiveGlobalKvHeads == nil || detected != effectiveGlobalKvHeads {
|
||
print(" ⚠ (full) k_proj out_dim=\(kOut), global_head_dim=\(actualGlobalHd) → globalKvHeads=\(detected) (config: \(effectiveGlobalKvHeads.map(String.init) ?? "nil"))")
|
||
effectiveGlobalKvHeads = detected
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if effectiveNHeads != (cfg.numAttentionHeads ?? 16) || effectiveNKvHeads != (cfg.numKeyValueHeads ?? 8) {
|
||
print(" → Using effective: nHeads=\(effectiveNHeads), nKvHeads=\(effectiveNKvHeads), globalKvHeads=\(effectiveGlobalKvHeads.map(String.init) ?? "nil")")
|
||
}
|
||
|
||
// ── Load embed tokens ──
|
||
print("Loading embed_tokens...")
|
||
|
||
// Debug: Check what embed_tensors exist
|
||
let embedTensors = allTensors.filter { $0.name.contains("embed_tokens") }
|
||
print(" Found \(embedTensors.count) embed_tokens tensors:")
|
||
for t in embedTensors.prefix(5) {
|
||
print(" - \(t.name): dtype=\(t.dtype), shape=\(t.shape)")
|
||
}
|
||
|
||
// Try without prefix first (converted format), then with prefix (original format)
|
||
var embedGroup = try Self.quantizedGroup(named: "embed_tokens", from: allTensors,
|
||
index: index, readers: readers,
|
||
device: engine.device)
|
||
if embedGroup == nil {
|
||
print(" Trying with prefix...")
|
||
embedGroup = try Self.quantizedGroup(named: "\(P)embed_tokens", from: allTensors,
|
||
index: index, readers: readers,
|
||
device: engine.device)
|
||
}
|
||
|
||
// Handle optional missing scales/biases (non-quantized embedding)
|
||
if let eg = embedGroup {
|
||
print(" ✓ embed_tokens loaded")
|
||
// Check if scales need normalization for custom quantization
|
||
// For groupSize=32 models, scales are ~3000x larger than standard
|
||
// Need to divide by hiddenSize to get correct values
|
||
if eg.groupSize == 32 && eg.inDim == hiddenSize {
|
||
print(" ⚠ Detected groupSize=32 custom quantization, normalizing scales...")
|
||
let scaleCorrection = Float(hiddenSize)
|
||
let pso = try engine.pipeline(named: "eltwise_scale")
|
||
let cmdBuf = engine.commandQueue.makeCommandBuffer()!
|
||
let enc = cmdBuf.makeComputeCommandEncoder()!
|
||
enc.setComputePipelineState(pso)
|
||
enc.setBuffer(eg.scales, offset: 0, index: 0)
|
||
var s = 1.0 / scaleCorrection
|
||
enc.setBytes(&s, length: MemoryLayout<Float>.size, index: 1)
|
||
let count = eg.scales.length / MemoryLayout<Float>.stride
|
||
var N = UInt32(count)
|
||
enc.setBytes(&N, length: MemoryLayout<UInt32>.size, index: 2)
|
||
let tg = engine.threadgroupSize1D(pso, count: count)
|
||
enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
|
||
threadsPerThreadgroup: tg)
|
||
enc.endEncoding()
|
||
cmdBuf.commit()
|
||
cmdBuf.waitUntilCompleted()
|
||
print(" ✓ Scales normalized (divided by \(scaleCorrection))")
|
||
}
|
||
self.embedWeight = eg
|
||
} else {
|
||
// Non-quantized: create dummy quantized wrapper (all 0 scales=1.0, biases=0.0)
|
||
// Actually, if not quantized, we'd need to treat it as f32 weights
|
||
// For E4B 4-bit, it should be quantized
|
||
throw WeightError.unsupportedDtype("Embed tokens not quantized")
|
||
}
|
||
|
||
// ── Load embed_tokens_per_layer ──
|
||
print("Loading embed_tokens_per_layer...")
|
||
let perLayerSize = cfg.hiddenSizePerLayerInput ?? 256
|
||
self.perLayerInputSize = perLayerSize
|
||
print(" cfg.hiddenSizePerLayerInput: \(String(describing: cfg.hiddenSizePerLayerInput))")
|
||
print(" perLayerSize: \(perLayerSize), numHiddenLayers: \(numHiddenLayers)")
|
||
self.perLayerModelProjectionScaleVal = 1.0 / sqrt(Float(hiddenSize))
|
||
|
||
if perLayerSize > 0 {
|
||
// Load the quantized per-layer embedding table
|
||
let plWeight = try Self.quantizedGroup(named: "\(P)embed_tokens_per_layer", from: allTensors,
|
||
index: index, readers: readers,
|
||
device: engine.device)
|
||
if let pw = plWeight {
|
||
print(" ✓ embed_tokens_per_layer loaded: outDim=\(pw.outDim), inDim=\(pw.inDim)")
|
||
self.embedTokensPerLayerWeight = pw
|
||
// Create buffer for per-layer embedding lookup result
|
||
let totalPerLayer = perLayerSize * numHiddenLayers
|
||
guard let plBuf = engine.device.makeBuffer(length: totalPerLayer * MemoryLayout<Float>.stride,
|
||
options: .storageModeShared) else {
|
||
throw E4BError.bufferCreationFailed
|
||
}
|
||
self.perLayerEmbedBuffer = plBuf
|
||
|
||
// Context-aware projection buffer
|
||
guard let ctxBuf = engine.device.makeBuffer(length: totalPerLayer * MemoryLayout<Float>.stride,
|
||
options: .storageModeShared) else {
|
||
throw E4BError.bufferCreationFailed
|
||
}
|
||
self.perLayerContextBuffer = ctxBuf
|
||
print(" ✓ Per-layer buffers created: \(totalPerLayer) Floats each")
|
||
} else {
|
||
print(" ✗ Failed to load embed_tokens_per_layer")
|
||
self.embedTokensPerLayerWeight = nil
|
||
self.perLayerEmbedBuffer = nil
|
||
self.perLayerContextBuffer = nil
|
||
}
|
||
|
||
// Load per_layer_model_projection (context-aware projection, BF16)
|
||
// This is NOT quantized - it's a regular BF16 linear weight
|
||
let projName = "\(P)per_layer_model_projection.weight"
|
||
if let projDesc = allTensors.first(where: { $0.name == projName }) {
|
||
let projReader: SafeTensorsReader
|
||
if let idx = index, let shard = idx.weightMap[projName] {
|
||
projReader = readers[shard]!
|
||
} else {
|
||
projReader = readers["model.safetensors"]!
|
||
}
|
||
|
||
let projData = try projReader.read(tensor: projDesc)
|
||
let projFloats = SafeTensorsReader.bf16ToFloat32(projData)
|
||
let outDim = projDesc.shape[0] // 10752
|
||
let inDim = projDesc.shape[1] // 2560
|
||
|
||
guard let projBuf = engine.device.makeBuffer(
|
||
bytes: projFloats, length: projFloats.count * MemoryLayout<Float>.stride,
|
||
options: .storageModeShared
|
||
) else {
|
||
throw E4BError.bufferCreationFailed
|
||
}
|
||
self.perLayerModelProjection = projBuf
|
||
self.perLayerModelProjectionOutDim = outDim
|
||
print(" ✓ per_layer_model_projection loaded: shape=[\(outDim), \(inDim)], dtype=BF16→F32")
|
||
} else {
|
||
print(" ✗ Failed to load per_layer_model_projection")
|
||
self.perLayerModelProjection = nil
|
||
self.perLayerModelProjectionOutDim = 0
|
||
}
|
||
|
||
// Load per_layer_projection_norm
|
||
let normWeight = try Self.loadNorm(named: "\(P)per_layer_projection_norm.weight", from: allTensors,
|
||
index: index, readers: readers,
|
||
device: engine.device)
|
||
if let nw = normWeight {
|
||
print(" ✓ per_layer_projection_norm loaded")
|
||
self.perLayerProjectionNorm = nw
|
||
} else {
|
||
print(" ✗ Failed to load per_layer_projection_norm")
|
||
self.perLayerProjectionNorm = nil
|
||
}
|
||
} else {
|
||
// 12B doesn't use per-layer input
|
||
print(" Per-layer input disabled (size=0)")
|
||
self.embedTokensPerLayerWeight = nil
|
||
self.perLayerEmbedBuffer = nil
|
||
self.perLayerContextBuffer = nil
|
||
self.perLayerModelProjection = nil
|
||
self.perLayerProjectionNorm = nil
|
||
self.perLayerModelProjectionOutDim = 0
|
||
}
|
||
|
||
// ── Build per-layer data ──
|
||
// ── Optimized: Pre-read all layer weights in parallel before layer construction ──
|
||
// This is the major bottleneck optimization: parallel file reads
|
||
print("\nPre-reading all layer weights in parallel...")
|
||
let preloadStart = Date()
|
||
|
||
// Option C: collect the layer weights that actually exist in allTensors
|
||
var allWeightNames: [String] = []
|
||
var debugCounts = (language: 0, vision: 0, audio: 0, other: 0)
|
||
|
||
for layerIdx in 0..<numHiddenLayers {
|
||
// Find every tensor belonging to this layer index (exact prefix match, so that layers.1 does not also match layers.1x)
|
||
let layerPrefix = "\(P)layers.\(layerIdx)."
|
||
let layerTensors = allTensors.filter { tensor in
|
||
tensor.name.hasPrefix(layerPrefix) &&
|
||
!tensor.name.contains("vision_tower") &&
|
||
!tensor.name.contains("audio_tower")
|
||
}
|
||
|
||
for tensor in layerTensors {
|
||
allWeightNames.append(tensor.name)
|
||
|
||
// Debug counting
|
||
if tensor.name.contains("language_model.model.layers") {
|
||
debugCounts.language += 1
|
||
} else if tensor.name.contains("vision_tower") {
|
||
debugCounts.vision += 1
|
||
} else if tensor.name.contains("audio_tower") {
|
||
debugCounts.audio += 1
|
||
} else {
|
||
debugCounts.other += 1
|
||
}
|
||
}
|
||
}
|
||
|
||
print(" Collected \(allWeightNames.count) weight names from allTensors")
|
||
|
||
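// Parallel weight loading using DispatchGroup. NOTE(review): the tasks below increment `notFoundCount`/`successCount` and write into `loadedWeights`/`loadErrors` from a concurrent queue without synchronization — this is a data race; guard them with a lock or a serial queue.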
|
||
let dispatchGroup = DispatchGroup()
|
||
let loadQueue = DispatchQueue(label: "weight-preloading", attributes: .concurrent)
|
||
var loadedWeights: [Data?] = Array(repeating: nil, count: allWeightNames.count)
|
||
var loadErrors: [Error?] = Array(repeating: nil, count: allWeightNames.count)
|
||
var notFoundCount = 0
|
||
var successCount = 0
|
||
|
||
for (weightIndex, name) in allWeightNames.enumerated() {
|
||
dispatchGroup.enter()
|
||
loadQueue.async {
|
||
do {
|
||
guard let desc = allTensors.first(where: { $0.name == name }) else {
|
||
notFoundCount += 1
|
||
dispatchGroup.leave()
|
||
return
|
||
}
|
||
|
||
let reader: SafeTensorsReader
|
||
if let idx = index, let shardFile = idx.weightMap[name] {
|
||
reader = readers[shardFile]!
|
||
} else {
|
||
reader = readers["model.safetensors"]!
|
||
}
|
||
|
||
let data = try reader.read(tensor: desc)
|
||
loadedWeights[weightIndex] = data
|
||
successCount += 1
|
||
} catch {
|
||
loadErrors[weightIndex] = error
|
||
}
|
||
dispatchGroup.leave()
|
||
}
|
||
}
|
||
|
||
dispatchGroup.wait()
|
||
|
||
print(" Parallel load completed: success=\(successCount), notFound=\(notFoundCount), errors=\(loadErrors.compactMap { $0 }.count)")
|
||
|
||
// Create weight data cache for fast lookup
|
||
var preloadedDataCache: [String: Data] = [:]
|
||
for (weightIndex, name) in allWeightNames.enumerated() {
|
||
if let data = loadedWeights[weightIndex] {
|
||
preloadedDataCache[name] = data
|
||
}
|
||
}
|
||
|
||
print(" Loaded weights: \(loadedWeights.filter { $0 != nil }.count)/\(loadedWeights.count)")
|
||
|
||
let preloadTime = Date().timeIntervalSince(preloadStart) * 1000
|
||
print("✓ Parallel preloaded \(preloadedDataCache.count) weights in \(String(format: "%.1f", preloadTime))ms")
|
||
|
||
// Check for errors
|
||
for (weightIndex, error) in loadErrors.enumerated() {
|
||
if let err = error {
|
||
print("⚠ Failed to preload weight \(allWeightNames[weightIndex]): \(err)")
|
||
}
|
||
}
|
||
|
||
print("\nBuilding layers...")
|
||
var builtLayers: [E4BLayer] = []
|
||
builtLayers.reserveCapacity(numHiddenLayers)
|
||
|
||
for layerIdx in 0..<numHiddenLayers {
|
||
let prefix = "\(P)layers.\(layerIdx)"
|
||
|
||
// Optimized helper methods that use preloaded cache
|
||
/// Builds the Float32 buffer for norm tensor `<prefix>.<name>`.
///
/// Preloaded bytes are used when present: BF16 is widened to Float32, F32 is
/// copied as-is, and any other dtype yields nil. When the tensor was not
/// preloaded, loading is delegated to `Self.loadNorm`.
func normFromCache(_ name: String) throws -> MTLBuffer? {
    let tensorName = "\(prefix).\(name)"

    guard let cached = preloadedDataCache[tensorName] else {
        return try Self.loadNorm(named: tensorName, from: allTensors,
                                 index: index, readers: readers,
                                 device: engine.device)
    }

    let dtype = allTensors.first(where: { $0.name == tensorName })?.dtype
    let values: [Float]
    if dtype == .bf16 {
        values = SafeTensorsReader.bf16ToFloat32(cached)
    } else if dtype == .f32 {
        values = cached.withUnsafeBytes { Array($0.assumingMemoryBound(to: Float.self)) }
    } else {
        // Unknown descriptor or unsupported dtype.
        return nil
    }

    // makeBuffer returns nil on failure, which is passed straight to the caller.
    return engine.device.makeBuffer(
        bytes: values, length: values.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    )
}
|
||
|
||
/// Builds the quantized weight group `<prefix>.<name>` (`.weight`, `.scales`,
/// optional `.biases`) from the preloaded data cache.
///
/// - Parameters:
///   - name: Tensor group name relative to the current layer prefix.
///   - bits: Quantization width; 4 means eight values per packed element,
///     anything else is treated as four values per packed element.
/// - Returns: The weight group, or nil when a Metal buffer could not be
///   created (including for empty tensor data). When weight or scales were
///   not preloaded, the result of `Self.quantizedGroup` is returned instead.
/// - Throws: `WeightError.readFailed` when the cached weight has no
///   descriptor or fewer than two dimensions (previously an out-of-bounds
///   crash), plus anything `Self.quantizedGroup` throws.
func qwFromCache(_ name: String, bits: Int = 4) throws -> QuantizedWeights? {
    let fullName = "\(prefix).\(name)"
    let wName = "\(fullName).weight"
    let sName = "\(fullName).scales"
    let bName = "\(fullName).biases"

    // Both packed weights and scales must be cached; otherwise use the regular loader.
    guard let wData = preloadedDataCache[wName], let sData = preloadedDataCache[sName] else {
        return try Self.quantizedGroup(named: fullName, from: allTensors,
                                       index: index, readers: readers,
                                       device: engine.device, bits: bits)
    }

    // The dimensions below index shape[0] and shape[1]; fail with an error, not a trap.
    guard let wDesc = allTensors.first(where: { $0.name == wName }), wDesc.shape.count >= 2 else {
        throw WeightError.readFailed("Missing or malformed descriptor for quantized weight \(wName)")
    }
    let sDesc = allTensors.first(where: { $0.name == sName })

    // Copies the bytes unchanged into a shared buffer; nil for empty data.
    func rawBuffer(_ data: Data) -> MTLBuffer? {
        data.withUnsafeBytes { bytes in
            bytes.baseAddress.flatMap {
                engine.device.makeBuffer(bytes: $0, length: data.count, options: .storageModeShared)
            }
        }
    }

    // Uploads Float32 values into a shared buffer.
    func floatBuffer(_ values: [Float]) -> MTLBuffer? {
        engine.device.makeBuffer(
            bytes: values, length: values.count * MemoryLayout<Float>.stride,
            options: .storageModeShared
        )
    }

    let wBuf = rawBuffer(wData)

    // BF16 scales/biases are widened to Float32; any other dtype is copied
    // byte-for-byte. NOTE(review): presumably those are already F32 — confirm
    // for checkpoints that store F16 scales.
    let sBuf: MTLBuffer?
    if sDesc?.dtype == .bf16 {
        sBuf = floatBuffer(SafeTensorsReader.bf16ToFloat32(sData))
    } else {
        sBuf = rawBuffer(sData)
    }

    let bBuf: MTLBuffer?
    if let bData = preloadedDataCache[bName] {
        if let bDesc = allTensors.first(where: { $0.name == bName }), bDesc.dtype == .bf16 {
            bBuf = floatBuffer(SafeTensorsReader.bf16ToFloat32(bData))
        } else {
            bBuf = rawBuffer(bData)
        }
    } else {
        // No bias data: use zero biases, one per scale element.
        let sCount = sDesc?.shape.reduce(1, *) ?? 0
        bBuf = floatBuffer([Float](repeating: 0.0, count: sCount))
    }

    guard let weight = wBuf, let scales = sBuf, let biases = bBuf else {
        return nil
    }

    let sShape = sDesc?.shape ?? []
    let outDim = wDesc.shape[0]
    let packedDim = wDesc.shape[1]
    let inDim = packedDim * (bits == 4 ? 8 : 4)
    // Group size follows from the number of scale columns; 64 when that is unknown.
    let groupSize = (sShape.count > 1 && sShape[1] > 0) ? inDim / sShape[1] : 64

    return QuantizedWeights(
        weight: weight,
        scales: scales,
        biases: biases,
        inDim: inDim,
        outDim: outDim,
        bits: bits,
        groupSize: groupSize
    )
}
|
||
|
||
let isFull = layerTypesIsFull[layerIdx]
|
||
// E4B uses intermediateSize=10240 for all layers, not doubled for shared KV
|
||
let intermediate = cfg.intermediateSize ?? 10240
|
||
// Use detected globalHeadDim for full layers if available
|
||
let hd = isFull ? (detectedGlobalHd ?? cfg.globalHeadDim ?? cfg.headDim ?? 512) : (cfg.slidingHeadDim ?? cfg.headDim ?? 256)
|
||
let nHeads = effectiveNHeads
|
||
let nKvHeads = isFull ? (effectiveGlobalKvHeads ?? effectiveNKvHeads) : effectiveNKvHeads
|
||
print(" isFull: \(isFull), headDim: \(hd), intermediate: \(intermediate), nHeads: \(nHeads), nKvHeads: \(nKvHeads)")
|
||
fflush(stdout)
|
||
let lcfg: E4BLayerConfig = isFull
|
||
? .full(hiddenSize: hiddenSize, headDim: hd, intermediateSize: intermediate,
|
||
nHeads: nHeads, nKvHeads: nKvHeads, maxPosition: maxContextLength)
|
||
: .sliding(hiddenSize: hiddenSize, headDim: hd, intermediateSize: intermediate,
|
||
nHeads: nHeads, nKvHeads: nKvHeads, windowSize: cfg.slidingWindow ?? 512)
|
||
|
||
let maxHeadDim = cfg.headDim ?? 512
|
||
|
||
/// Loads norm tensor `<prefix>.<name>` through `Self.loadNorm`, bypassing the preload cache.
func norm(_ name: String) throws -> MTLBuffer? {
    let tensorName = "\(prefix).\(name)"
    return try Self.loadNorm(
        named: tensorName,
        from: allTensors,
        index: index,
        readers: readers,
        device: engine.device
    )
}
|
||
|
||
/// Loads norm tensor `<prefix>.<name>` through `Self.loadNormStrided`, forwarding
/// the head count and head dimension together with this layer's `maxHeadDim`.
func normStrided(_ name: String, nHeads: Int, hd: Int) throws -> MTLBuffer? {
    let tensorName = "\(prefix).\(name)"
    return try Self.loadNormStrided(
        named: tensorName,
        from: allTensors,
        index: index,
        readers: readers,
        device: engine.device,
        nHeads: nHeads,
        headDim: hd,
        maxHeadDim: maxHeadDim
    )
}
|
||
|
||
/// Loads quantized group `<prefix>.<name>` through `Self.quantizedGroup`, bypassing the preload cache.
func qw(_ name: String, bits: Int = 4) throws -> QuantizedWeights? {
    let groupName = "\(prefix).\(name)"
    return try Self.quantizedGroup(
        named: groupName,
        from: allTensors,
        index: index,
        readers: readers,
        device: engine.device,
        bits: bits
    )
}
|
||
|
||
/// Builds unquantized weights for `<prefix>.<name>.weight` from the preload cache.
///
/// - Returns: Float32 weights when the cached tensor is BF16 with at least two
///   dimensions (`shape[0]` = outDim, `shape[1]` = inDim). Returns nil when the
///   tensor is not cached, has no descriptor, is not BF16, has fewer than two
///   dimensions (previously an out-of-bounds crash), or the Metal buffer
///   could not be created.
func fw(_ name: String) throws -> FloatWeights? {
    let wName = "\(prefix).\(name).weight"

    guard let wData = preloadedDataCache[wName],
          let desc = allTensors.first(where: { $0.name == wName }),
          desc.dtype == .bf16,
          desc.shape.count >= 2 else {
        return nil
    }

    let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
    guard let wBuf = engine.device.makeBuffer(
        bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    ) else {
        return nil
    }
    return FloatWeights(weight: wBuf, inDim: desc.shape[1], outDim: desc.shape[0])
}
|
||
|
||
/// Infer quantization bits from weight tensor shape vs expected input dimension.
///
/// Compares `shape[1]` of `<prefix>.<weightName>.weight` with `expectedInDim`:
/// eight values per packed element means 4-bit, four values means 8-bit.
///
/// - Returns: 4 or 8, or `defaultBits` when the tensor is missing, has fewer
///   than two dimensions (previously an out-of-bounds crash), or matches
///   neither packing ratio.
func detectBits(for weightName: String, expectedInDim: Int, defaultBits: Int = 4) -> Int {
    let tensorName = "\(prefix).\(weightName).weight"
    guard let wDesc = allTensors.first(where: { $0.name == tensorName }),
          wDesc.shape.count >= 2 else {
        return defaultBits
    }
    let packedDim = wDesc.shape[1]
    if packedDim * 8 == expectedInDim { return 4 }
    if packedDim * 4 == expectedInDim { return 8 }
    return defaultBits
}
|
||
|
||
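// Layer scalar (optional; defaults to 1.0). NOTE(review): in the sharded branch below, a missing weightMap entry hits `continue`, which skips the rest of this loop iteration instead of only defaulting the scalar — confirm that skipping the layer is intended.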
|
||
let scalar: Float
|
||
if let sDesc = allTensors.first(where: { $0.name == "\(prefix).layer_scalar" }) {
|
||
let sReader: SafeTensorsReader
|
||
if let idx = index {
|
||
guard let shardFile = idx.weightMap[sDesc.name] else {
|
||
scalar = 1.0
|
||
continue
|
||
}
|
||
sReader = readers[shardFile]!
|
||
} else {
|
||
sReader = readers["model.safetensors"]!
|
||
}
|
||
|
||
let sData = try sReader.read(tensor: sDesc)
|
||
let sFloats = SafeTensorsReader.bf16ToFloat32(sData)
|
||
scalar = sFloats.first ?? 1.0
|
||
print(" layer_scalar: \(scalar)")
|
||
fflush(stdout)
|
||
} else {
|
||
scalar = 1.0
|
||
print(" layer_scalar: NOT FOUND (using 1.0)")
|
||
}
|
||
|
||
// Detect quantization bits from weight shape (supports both uniform 4-bit and 8-bit MLP/router)
|
||
let mlpGateBits = detectBits(for: "mlp.gate_proj", expectedInDim: hiddenSize, defaultBits: 4)
|
||
let mlpDownBits = detectBits(for: "mlp.down_proj", expectedInDim: intermediate, defaultBits: 4)
|
||
let attnQBits = detectBits(for: "self_attn.q_proj", expectedInDim: hiddenSize, defaultBits: 4)
|
||
let attnKBits = detectBits(for: "self_attn.k_proj", expectedInDim: hiddenSize, defaultBits: 4)
|
||
let attnVBits = detectBits(for: "self_attn.v_proj", expectedInDim: hiddenSize, defaultBits: 4)
|
||
let attnOBits = detectBits(for: "self_attn.o_proj", expectedInDim: hiddenSize, defaultBits: 4)
|
||
|
||
// Try bf16 weights first (for bf16 models)
|
||
let qpFloat = try fw("self_attn.q_proj")
|
||
let kpFloat = try fw("self_attn.k_proj")
|
||
let vpFloat = try fw("self_attn.v_proj")
|
||
let opFloat = try fw("self_attn.o_proj")
|
||
|
||
// Then try quantized weights (for quantized models)
|
||
let qpQuant = try qwFromCache("self_attn.q_proj", bits: attnQBits)
|
||
let kpQuant = try qwFromCache("self_attn.k_proj", bits: attnKBits)
|
||
let vpQuant = try qwFromCache("self_attn.v_proj", bits: attnVBits)
|
||
let opQuant = try qwFromCache("self_attn.o_proj", bits: attnOBits)
|
||
|
||
guard qpQuant != nil || qpFloat != nil,
|
||
kpQuant != nil || kpFloat != nil,
|
||
opQuant != nil || opFloat != nil
|
||
else {
|
||
throw WeightError.tensorNotFound("Missing weights for layer \(layerIdx)")
|
||
}
|
||
|
||
// ── MoE loading (auto-detect from tensor structure) ──
|
||
// Auto-detect MoE by checking if router.proj.weight exists
|
||
let hasMoETensors = allTensors.contains { $0.name.contains("\(prefix).router.proj") }
|
||
let useMoE = cfg.enableMoEBlock ?? false || hasMoETensors
|
||
|
||
// Infer numExperts from expert tensor shape if not in config
|
||
var numExperts = cfg.numExperts ?? 0
|
||
if numExperts == 0 && hasMoETensors {
|
||
// Try to infer from experts.switch_glu tensor shape
|
||
let expertTensor = allTensors.first { $0.name.contains("\(prefix).experts.switch_glu") }
|
||
if let expertShape = expertTensor?.shape, expertShape.count == 3 {
|
||
numExperts = expertShape[0] // First dimension is numExperts
|
||
}
|
||
}
|
||
|
||
// MLP weights: load real weights if available, create dummy only if missing in MoE layer
|
||
var gp = try qwFromCache("mlp.gate_proj", bits: mlpGateBits)
|
||
var up = try qwFromCache("mlp.up_proj", bits: mlpGateBits)
|
||
var dp = try qwFromCache("mlp.down_proj", bits: mlpDownBits)
|
||
var gpFloat = try fw("mlp.gate_proj")
|
||
var upFloat = try fw("mlp.up_proj")
|
||
var dpFloat = try fw("mlp.down_proj")
|
||
|
||
// If MLP weights missing and this is MoE layer, create dummy weights
|
||
if useMoE && numExperts > 0 {
|
||
if gp == nil || up == nil || dp == nil {
|
||
// Create minimal dummy weights for MoE layer (won't be used in forward if experts available)
|
||
let dummyWeight = engine.device.makeBuffer(length: 4, options: .storageModeShared)!
|
||
let dummyScales = engine.device.makeBuffer(length: 4, options: .storageModeShared)!
|
||
let dummyBiases = engine.device.makeBuffer(length: 4, options: .storageModeShared)!
|
||
|
||
let dummyQuantizedWeights = QuantizedWeights(
|
||
weight: dummyWeight, scales: dummyScales, biases: dummyBiases,
|
||
inDim: 1, outDim: 1, bits: 4, groupSize: 1
|
||
)
|
||
|
||
if gp == nil { gp = dummyQuantizedWeights }
|
||
if up == nil { up = dummyQuantizedWeights }
|
||
if dp == nil { dp = dummyQuantizedWeights }
|
||
}
|
||
} else if (gp == nil || up == nil || dp == nil) && (gpFloat == nil || upFloat == nil || dpFloat == nil) {
|
||
// Dense layer requires either quantized or bf16 MLP weights
|
||
throw WeightError.tensorNotFound("Missing MLP weights for layer \(layerIdx)")
|
||
}
|
||
|
||
// v_proj is optional - full attention layers in 12B don't have it
|
||
let vp = try qwFromCache("self_attn.v_proj")
|
||
|
||
// Per-layer weights are optional (12B doesn't have them)
|
||
let pg = try qwFromCache("per_layer_input_gate")
|
||
let pp = try qwFromCache("per_layer_projection")
|
||
|
||
// Per-layer input: nil for now (will be computed dynamically in forward)
|
||
let plSlice: MTLBuffer? = nil
|
||
|
||
let topK = cfg.topKExperts ?? 8
|
||
let moeIntermediate = cfg.moeIntermediateSize ?? 704
|
||
var routerProj: QuantizedWeights? = nil
|
||
var routerScale: Float = 1.0
|
||
var perExpertScale: [Float]? = nil
|
||
var expertGate: MoEExpertGroup? = nil
|
||
var expertUp: MoEExpertGroup? = nil
|
||
var expertDown: MoEExpertGroup? = nil
|
||
|
||
if useMoE && numExperts > 0 {
|
||
let routerBits = detectBits(for: "router.proj", expectedInDim: hiddenSize, defaultBits: 4)
|
||
routerProj = try Self.quantizedGroup(named: "\(prefix).router.proj", from: allTensors,
|
||
index: index, readers: readers,
|
||
device: engine.device, bits: routerBits)
|
||
|
||
// Load router.scale (scalar)
|
||
if let rsDesc = allTensors.first(where: { $0.name == "\(prefix).router.scale" }) {
|
||
let rsReader: SafeTensorsReader
|
||
if let idx = index, let shard = idx.weightMap[rsDesc.name] {
|
||
rsReader = readers[shard]!
|
||
} else {
|
||
rsReader = readers["model.safetensors"]!
|
||
}
|
||
let rsData = try rsReader.read(tensor: rsDesc)
|
||
let rsFloats = SafeTensorsReader.bf16ToFloat32(rsData)
|
||
let rawRouterScale = rsFloats.first ?? 1.0
|
||
// Normalize router scale by hidden_size (similar to scales normalization for 26B-Standard)
|
||
// This prevents softmax overflow in MoE router computation
|
||
routerScale = rawRouterScale / Float(hiddenSize)
|
||
}
|
||
|
||
// Load per_expert_scale ([numExperts])
|
||
if let pesDesc = allTensors.first(where: { $0.name == "\(prefix).router.per_expert_scale" }) {
|
||
let pesReader: SafeTensorsReader
|
||
if let idx = index, let shard = idx.weightMap[pesDesc.name] {
|
||
pesReader = readers[shard]!
|
||
} else {
|
||
pesReader = readers["model.safetensors"]!
|
||
}
|
||
let pesData = try pesReader.read(tensor: pesDesc)
|
||
let pesFloats = SafeTensorsReader.bf16ToFloat32(pesData)
|
||
perExpertScale = pesFloats
|
||
}
|
||
|
||
// Load expert 3D tensors as MoEExpertGroup
|
||
let ep = "\(prefix).experts.switch_glu"
|
||
expertGate = try Self.loadExpertGroup(named: "\(ep).gate_proj", from: allTensors,
|
||
index: index, readers: readers,
|
||
device: engine.device,
|
||
numExperts: numExperts,
|
||
expertOutDim: moeIntermediate,
|
||
expertInDim: hiddenSize)
|
||
expertUp = try Self.loadExpertGroup(named: "\(ep).up_proj", from: allTensors,
|
||
index: index, readers: readers,
|
||
device: engine.device,
|
||
numExperts: numExperts,
|
||
expertOutDim: moeIntermediate,
|
||
expertInDim: hiddenSize)
|
||
expertDown = try Self.loadExpertGroup(named: "\(ep).down_proj", from: allTensors,
|
||
index: index, readers: readers,
|
||
device: engine.device,
|
||
numExperts: numExperts,
|
||
expertOutDim: hiddenSize,
|
||
expertInDim: moeIntermediate)
|
||
|
||
let loaded = (expertGate != nil && expertUp != nil && expertDown != nil) ? numExperts : 0
|
||
print(" MoE: \(loaded)/\(numExperts) experts loaded")
|
||
}
|
||
|
||
let layer = E4BLayer(
|
||
config: lcfg,
|
||
layerIdx: layerIdx,
|
||
inputLayernorm: try norm("input_layernorm.weight"),
|
||
postAttentionLayernorm: try norm("post_attention_layernorm.weight"),
|
||
preFeedforwardLayernorm: try norm("pre_feedforward_layernorm.weight"),
|
||
postFeedforwardLayernorm: try norm("post_feedforward_layernorm.weight"),
|
||
postPerLayerInputNorm: try norm("post_per_layer_input_norm.weight"),
|
||
qNorm: try normStrided("self_attn.q_norm.weight", nHeads: lcfg.nHeads, hd: hd),
|
||
kNorm: try normStrided("self_attn.k_norm.weight", nHeads: lcfg.nKvHeads, hd: hd),
|
||
vNorm: try normStrided("self_attn.v_norm.weight", nHeads: lcfg.nKvHeads, hd: hd),
|
||
qProj: qpQuant, kProj: kpQuant, vProj: vpQuant, oProj: opQuant,
|
||
gateProj: gp, upProj: up, downProj: dp,
|
||
perLayerGate: pg, perLayerProjection: pp,
|
||
qProjFloat: qpFloat, kProjFloat: kpFloat, vProjFloat: vpFloat, oProjFloat: opFloat,
|
||
gateProjFloat: gpFloat, upProjFloat: upFloat, downProjFloat: dpFloat,
|
||
perLayerGateFloat: try fw("per_layer_input_gate"),
|
||
perLayerProjectionFloat: try fw("per_layer_projection"),
|
||
perLayerInput: plSlice,
|
||
perLayerInputScale: perLayerInputScaleVal,
|
||
perLayerProjectionScale: perLayerModelProjectionScaleVal,
|
||
layerScalar: scalar,
|
||
useMoE: useMoE && expertGate != nil,
|
||
routerProj: routerProj,
|
||
routerScale: routerScale,
|
||
perExpertScale: perExpertScale,
|
||
expertGate: expertGate,
|
||
expertUp: expertUp,
|
||
expertDown: expertDown,
|
||
topK: topK,
|
||
kEqualsV: (vpQuant == nil && vpFloat == nil && isFull) || (cfg.attentionKEqualsV ?? false)
|
||
)
|
||
builtLayers.append(layer)
|
||
}
|
||
self.layers = builtLayers
|
||
|
||
// ── KV caches ──
|
||
var caches: [KVCache] = []
|
||
caches.reserveCapacity(numHiddenLayers)
|
||
for layerIdx in 0..<numHiddenLayers {
|
||
let isFull = layerTypesIsFull[layerIdx]
|
||
let hd = isFull ? (cfg.globalHeadDim ?? cfg.headDim ?? 512) : (cfg.slidingHeadDim ?? cfg.headDim ?? 256)
|
||
let cacheNKvHeads = isFull ? (effectiveGlobalKvHeads ?? effectiveNKvHeads) : effectiveNKvHeads
|
||
let cache = KVCache(
|
||
device: engine.device,
|
||
isSliding: !isFull,
|
||
maxContextLength: maxContextLength,
|
||
nKvHeads: cacheNKvHeads,
|
||
headDim: hd
|
||
)
|
||
caches.append(cache)
|
||
}
|
||
self.kvCaches = caches
|
||
|
||
// ── Temps ──
|
||
print("\nCreating ForwardTemps...")
|
||
fflush(stdout)
|
||
let headDim = cfg.headDim ?? 256
|
||
// Use detectedGlobalHd if available
|
||
let globalHeadDim = detectedGlobalHd ?? cfg.globalHeadDim ?? headDim
|
||
let maxHeadDim = max(headDim, globalHeadDim)
|
||
// Max intermediate: shared layers have 2x intermediate_size; safe upper bound
|
||
let baseIntermediate = cfg.intermediateSize ?? 15360
|
||
let maxIntermediate = baseIntermediate * 2
|
||
let nHeads = effectiveNHeads
|
||
let nKvHeads = max(effectiveNKvHeads, effectiveGlobalKvHeads ?? 0)
|
||
print(" headDim: \(headDim), globalHeadDim: \(globalHeadDim), maxHeadDim: \(maxHeadDim), intermediateSize: \(baseIntermediate), maxIntermediate: \(maxIntermediate), nHeads: \(nHeads), nKvHeads: \(nKvHeads)")
|
||
|
||
self.temps = try ForwardTemps(device: engine.device,
|
||
maxHeadDim: maxHeadDim,
|
||
maxIntermediate: maxIntermediate,
|
||
hiddenSize: hiddenSize,
|
||
nHeads: nHeads,
|
||
nKvHeads: nKvHeads)
|
||
print(" ✓ ForwardTemps created")
|
||
fflush(stdout)
|
||
|
||
// ── Logits buffer ──
|
||
print("Creating logits buffer...")
|
||
print(" vocabSize: \(vocabSize), size: \(vocabSize * 4) bytes")
|
||
fflush(stdout)
|
||
guard let lb = engine.device.makeBuffer(
|
||
length: vocabSize * MemoryLayout<Float>.stride,
|
||
options: .storageModeShared
|
||
) else {
|
||
print(" ✗ Failed to create logits buffer")
|
||
throw E4BError.bufferCreationFailed
|
||
}
|
||
print(" ✓ Logits buffer created")
|
||
self.logitsBuffer = lb
|
||
|
||
// ── Final RMSNorm (before LM head) ──
|
||
print("Loading final norm...")
|
||
let finalNormName = "\(P)norm.weight"
|
||
self.finalNorm = try Self.loadNorm(named: finalNormName, from: allTensors,
|
||
index: index, readers: readers,
|
||
device: engine.device)
|
||
print(" ✓ Final norm loaded")
|
||
|
||
print("\n✓ Model initialization completed successfully\n")
|
||
}
|
||
|
||
// ── Kernel dispatch helpers ───────────────────────
|
||
|
||
/// Encodes the `rms_norm` compute pipeline, submits it, and blocks until the GPU has finished.
///
/// Kernel argument slots: 0 = `input`, 1 = `weight`, 2 = `output`,
/// 3 = element count as `UInt32`, 4 = epsilon as `Float`.
///
/// - Parameters:
///   - input: Buffer bound at slot 0.
///   - weight: Optional buffer bound at slot 1.
///   - output: Buffer bound at slot 2.
///   - count: Number of threads dispatched; also passed to the kernel as its element count.
///   - eps: Epsilon value forwarded to the kernel.
///   - inputOffset: Byte offset applied when binding `input`.
///   - weightOffset: Byte offset applied when binding `weight`.
///   - outputOffset: Byte offset applied when binding `output`.
/// - Throws: Any error raised by `engine.pipeline(named:)`.
func rmsNorm(input: MTLBuffer, weight: MTLBuffer?,
             output: MTLBuffer, count: Int, eps: Float,
             inputOffset: Int = 0, weightOffset: Int = 0, outputOffset: Int = 0) throws {
    let pipeline = try engine.pipeline(named: "rms_norm")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Buffer arguments.
    encoder.setBuffer(input, offset: inputOffset, index: 0)
    encoder.setBuffer(weight, offset: weightOffset, index: 1)
    encoder.setBuffer(output, offset: outputOffset, index: 2)

    // Scalar arguments.
    var elementCount = UInt32(count)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 3)
    var epsilon = eps
    encoder.setBytes(&epsilon, length: MemoryLayout<Float>.size, index: 4)

    // One thread per element.
    let threadsPerGroup = engine.threadgroupSize1D(pipeline, count: count)
    let grid = MTLSize(width: count, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadsPerGroup)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
|
||
|
||
/// Encodes the `eltwise_add_scaled` compute pipeline, submits it, and blocks until completion.
///
/// Kernel argument slots: 0 = `a`, 1 = `scaleA`, 2 = `b`, 3 = `scaleB`,
/// 4 = `output`, 5 = element count as `UInt32`. All buffers are bound at byte offset 0.
///
/// - Parameters:
///   - a: First operand buffer.
///   - scaleA: Scalar passed alongside `a`.
///   - b: Second operand buffer.
///   - scaleB: Scalar passed alongside `b`.
///   - output: Destination buffer.
///   - count: Number of threads dispatched; also passed to the kernel as its element count.
/// - Throws: Any error raised by `engine.pipeline(named:)`.
func eltwiseAddScaledModel(a: MTLBuffer, scaleA: Float,
                           b: MTLBuffer, scaleB: Float,
                           output: MTLBuffer, count: Int) throws {
    let pipeline = try engine.pipeline(named: "eltwise_add_scaled")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // First operand and its scale.
    var firstScale = scaleA
    encoder.setBuffer(a, offset: 0, index: 0)
    encoder.setBytes(&firstScale, length: MemoryLayout<Float>.size, index: 1)

    // Second operand and its scale.
    var secondScale = scaleB
    encoder.setBuffer(b, offset: 0, index: 2)
    encoder.setBytes(&secondScale, length: MemoryLayout<Float>.size, index: 3)

    // Destination and element count.
    var elementCount = UInt32(count)
    encoder.setBuffer(output, offset: 0, index: 4)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 5)

    let threadsPerGroup = engine.threadgroupSize1D(pipeline, count: count)
    let grid = MTLSize(width: count, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadsPerGroup)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
|
||
|
||
/// Matrix-vector product for non-quantized weights.
///
/// Despite the name, this dispatches the `matmul_f32` pipeline. Kernel argument slots:
/// 0 = `input`, 1 = `weight`, 2 = `output`, 3 = M (fixed to 1), 4 = K (`inDim`), 5 = N (`outDim`),
/// the last three as `UInt32`. The call blocks until the GPU has finished.
///
/// - Parameters:
///   - input: Input buffer, bound at byte offset 0.
///   - weight: Weight buffer, bound at byte offset 0.
///   - output: Output buffer, bound at byte offset 0.
///   - inDim: Value passed to the kernel as K.
///   - outDim: Value passed to the kernel as N; also the number of threads dispatched.
/// - Throws: Any error raised by `engine.pipeline(named:)`.
func matmulBF16(input: MTLBuffer, weight: MTLBuffer, output: MTLBuffer,
                inDim: Int, outDim: Int) throws {
    let pipeline = try engine.pipeline(named: "matmul_f32")
    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    encoder.setBuffer(input, offset: 0, index: 0)
    encoder.setBuffer(weight, offset: 0, index: 1)
    encoder.setBuffer(output, offset: 0, index: 2)

    // Problem dimensions: a single row (batch size 1) times a K x N weight.
    var rows = UInt32(1)
    encoder.setBytes(&rows, length: MemoryLayout<UInt32>.size, index: 3)
    var inner = UInt32(inDim)
    encoder.setBytes(&inner, length: MemoryLayout<UInt32>.size, index: 4)
    var columns = UInt32(outDim)
    encoder.setBytes(&columns, length: MemoryLayout<UInt32>.size, index: 5)

    // One thread per output element, in fixed threadgroups of 32.
    let grid = MTLSize(width: outDim, height: 1, depth: 1)
    let threadsPerGroup = MTLSize(width: 32, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadsPerGroup)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
|
||
|
||
// ── Weight loading helpers ────────────────────────
|
||
|
||
/// Loads a norm weight tensor into a shared-storage Metal buffer of `Float` values.
///
/// The tensor is located with `findTensor(_:in:)`. When `index` is provided, the shard file is
/// resolved through `index.weightMap`; otherwise the reader registered under
/// `"model.safetensors"` is used. bf16 data is converted with
/// `SafeTensorsReader.bf16ToFloat32`; f32 data is copied unchanged.
///
/// - Parameters:
///   - named: Tensor name to look up.
///   - tensors: All known tensor descriptors.
///   - index: Optional shard index mapping tensor names to shard files.
///   - readers: Readers keyed by file name.
///   - device: Metal device used to allocate the buffer.
/// - Returns: The buffer, or `nil` when the tensor is absent, has no shard entry, has a dtype
///   other than bf16/f32, or the buffer cannot be created.
/// - Throws: `WeightError.tensorNotFound` when no reader is loaded for the resolved shard file
///   (this used to be a force-unwrap crash), or any error raised while reading the tensor.
private static func loadNorm(named: String, from tensors: [TensorDescriptor],
                             index: SafeTensorsIndex?,
                             readers: [String: SafeTensorsReader],
                             device: MTLDevice) throws -> MTLBuffer? {
    guard let desc = findTensor(named, in: tensors) else { return nil }

    // Resolve which file holds the tensor.
    let shardFile: String
    if let idx = index {
        guard let mapped = idx.weightMap[desc.name] else { return nil }
        shardFile = mapped
    } else {
        shardFile = "model.safetensors"
    }
    guard let reader = readers[shardFile] else {
        throw WeightError.tensorNotFound("No reader loaded for shard \(shardFile) (needed by \(desc.name))")
    }

    let data = try reader.read(tensor: desc)
    let floats: [Float]
    if desc.dtype == .bf16 {
        floats = SafeTensorsReader.bf16ToFloat32(data)
    } else if desc.dtype == .f32 {
        // Copy bytes into a properly aligned array instead of reinterpreting the Data storage,
        // which is not guaranteed to be Float-aligned.
        var decoded = [Float](repeating: 0, count: data.count / MemoryLayout<Float>.stride)
        decoded.withUnsafeMutableBufferPointer { destination in
            _ = data.copyBytes(to: destination)
        }
        floats = decoded
    } else {
        return nil
    }

    return device.makeBuffer(bytes: floats,
                             length: floats.count * MemoryLayout<Float>.stride,
                             options: .storageModeShared)
}
|
||
|
||
/// Loads a per-head norm weight and lays it out as `nHeads * headDim` floats where possible.
///
/// Layout handling, based on the number of elements in the tensor:
/// - exactly `headDim` elements (and fewer than `nHeads * headDim`): the vector is repeated
///   once per head;
/// - more than `nHeads * headDim` elements: the first `headDim` values of each head are
///   extracted, reading heads at a stride of `maxHeadDim`;
/// - otherwise: the data is uploaded unchanged.
///
/// - Parameters:
///   - named: Tensor name to look up with `findTensor(_:in:)`.
///   - tensors: All known tensor descriptors.
///   - index: Optional shard index mapping tensor names to shard files.
///   - readers: Readers keyed by file name.
///   - device: Metal device used to allocate the buffer.
///   - nHeads: Number of heads in the target layout.
///   - headDim: Per-head dimension in the target layout.
///   - maxHeadDim: Per-head stride used when the stored tensor is larger than needed.
/// - Returns: The buffer, or `nil` when the tensor is absent, has no shard entry, has a dtype
///   other than bf16/f32, or the buffer cannot be created.
/// - Throws: `WeightError.tensorNotFound` when no reader is loaded for the resolved shard file,
///   or when the tensor is too small for the strided extraction (both used to crash), plus any
///   error raised while reading the tensor.
private static func loadNormStrided(named: String, from tensors: [TensorDescriptor],
                                    index: SafeTensorsIndex?,
                                    readers: [String: SafeTensorsReader],
                                    device: MTLDevice,
                                    nHeads: Int, headDim: Int,
                                    maxHeadDim: Int) throws -> MTLBuffer? {
    guard let desc = findTensor(named, in: tensors) else { return nil }

    // Resolve which file holds the tensor.
    let shardFile: String
    if let idx = index {
        guard let mapped = idx.weightMap[desc.name] else { return nil }
        shardFile = mapped
    } else {
        shardFile = "model.safetensors"
    }
    guard let reader = readers[shardFile] else {
        throw WeightError.tensorNotFound("No reader loaded for shard \(shardFile) (needed by \(desc.name))")
    }

    let data = try reader.read(tensor: desc)
    let floats: [Float]
    if desc.dtype == .bf16 {
        floats = SafeTensorsReader.bf16ToFloat32(data)
    } else if desc.dtype == .f32 {
        // Copy bytes into a properly aligned array instead of reinterpreting the Data storage,
        // which is not guaranteed to be Float-aligned.
        var decoded = [Float](repeating: 0, count: data.count / MemoryLayout<Float>.stride)
        decoded.withUnsafeMutableBufferPointer { destination in
            _ = data.copyBytes(to: destination)
        }
        floats = decoded
    } else {
        return nil
    }

    let actualCount = nHeads * headDim
    let packed: [Float]

    if floats.count == headDim && floats.count < actualCount {
        // Case 1: one shared headDim-sized vector, repeated for every head.
        packed = (0..<nHeads).flatMap { _ in floats }
    } else if floats.count > actualCount {
        // Case 2: stored with a per-head stride of maxHeadDim; keep the first headDim values
        // of each head. Validate the furthest read first so a mismatched tensor raises an
        // error instead of trapping on an out-of-range index.
        let readEnd = (nHeads - 1) * maxHeadDim + headDim
        guard nHeads > 0, headDim > 0, maxHeadDim >= 0, readEnd <= floats.count else {
            throw WeightError.tensorNotFound(
                "Norm tensor \(desc.name) has \(floats.count) elements; cannot extract \(nHeads) heads of \(headDim) at stride \(maxHeadDim)")
        }
        packed = (0..<nHeads).flatMap { head -> ArraySlice<Float> in
            let start = head * maxHeadDim
            return floats[start..<(start + headDim)]
        }
    } else {
        // Case 3: already in the expected layout (or smaller); upload unchanged.
        packed = floats
    }

    return device.makeBuffer(bytes: packed,
                             length: packed.count * MemoryLayout<Float>.stride,
                             options: .storageModeShared)
}
|
||
|
||
/// Loads a quantized weight group (`<named>.weight`, `<named>.scales`, optional
/// `<named>.biases`) into Metal buffers.
///
/// A requested tensor name matches a descriptor stored under that exact name or under the same
/// name prefixed with `"language_model.model."`. Scales and biases are converted with
/// `SafeTensorsReader.bf16ToFloat32`. When no biases are available, a zero-filled bias buffer
/// with the same element count as the scales is created.
///
/// Derived dimensions: `outDim = weight.shape[0]`, `inDim = weight.shape[1] * (32 / bits)`,
/// `groupSize = inDim / scales.shape[1]`.
///
/// - Parameters:
///   - named: Base tensor name, without the `.weight` / `.scales` / `.biases` suffix.
///   - tensors: All known tensor descriptors.
///   - index: Optional shard index mapping tensor names to shard files.
///   - readers: Readers keyed by file name (`"model.safetensors"` when not sharded).
///   - device: Metal device used to allocate buffers.
///   - bits: Quantization bit width used to unpack `inDim`.
/// - Returns: The quantized weights, or `nil` when weight or scales are absent, a shard entry is
///   missing, `bits` or the tensor shapes are unusable, or a buffer cannot be created.
/// - Throws: `WeightError.tensorNotFound` when the index points at a shard with no loaded reader
///   (this used to crash, or silently drop the biases), plus any error raised while reading.
private static func quantizedGroup(named: String, from tensors: [TensorDescriptor],
                                   index: SafeTensorsIndex?,
                                   readers: [String: SafeTensorsReader],
                                   device: MTLDevice,
                                   bits: Int = 4) throws -> QuantizedWeights? {
    let optionalPrefix = "language_model.model."

    // Direct scan instead of rebuilding name dictionaries on every call; the previous
    // `Dictionary(uniqueKeysWithValues:)` was redundant and trapped on duplicate names.
    // `last(where:)` keeps the "later descriptor wins" behaviour of the old dictionary.
    func lookup(_ name: String) -> TensorDescriptor? {
        let prefixed = optionalPrefix + name
        return tensors.last(where: { $0.name == name || $0.name == prefixed })
    }

    func reader(forShard shard: String) throws -> SafeTensorsReader {
        guard let found = readers[shard] else {
            throw WeightError.tensorNotFound("No reader loaded for shard \(shard) (needed by \(named))")
        }
        return found
    }

    guard let wDesc = lookup("\(named).weight"),
          let sDesc = lookup("\(named).scales")
    else {
        return nil
    }
    // Biases are optional (e.g., MLX 26B embed_tensors has no biases).
    let bDesc = lookup("\(named).biases")

    // Reject inputs that would otherwise trap below (division by zero, out-of-range shape index).
    guard bits > 0,
          wDesc.shape.count >= 2,
          sDesc.shape.count >= 2,
          sDesc.shape[1] > 0
    else {
        return nil
    }

    // Each tensor may live in a different shard.
    let wReader: SafeTensorsReader
    let sReader: SafeTensorsReader
    var bReader: SafeTensorsReader? = nil
    if let idx = index {
        guard let wShard = idx.weightMap[wDesc.name],
              let sShard = idx.weightMap[sDesc.name]
        else {
            return nil
        }
        wReader = try reader(forShard: wShard)
        sReader = try reader(forShard: sShard)
        if let biasDesc = bDesc, let bShard = idx.weightMap[biasDesc.name] {
            bReader = try reader(forShard: bShard)
        }
    } else {
        wReader = try reader(forShard: "model.safetensors")
        sReader = wReader
        bReader = wReader
    }

    let wData = try wReader.read(tensor: wDesc)
    let sData = try sReader.read(tensor: sDesc)
    var bFloats: [Float]? = nil
    if let biasDesc = bDesc, let biasReader = bReader {
        bFloats = SafeTensorsReader.bf16ToFloat32(try biasReader.read(tensor: biasDesc))
    }
    let sFloats = SafeTensorsReader.bf16ToFloat32(sData)

    let outDim = wDesc.shape[0]
    // inDim = packed dim * (32 / bits)  (e.g. 4-bit: 8 values per UInt32, 8-bit: 4 values)
    let valsPerU32 = 32 / bits
    let inDim = wDesc.shape[1] * valsPerU32
    // scales.shape[1] = inDim / groupSize
    let groupSize = inDim / sDesc.shape[1]

    // Upload the packed weights from a pointer that is valid for the whole call, rather than
    // from the bytes of a bridged NSData temporary.
    let packedWeights: MTLBuffer? = wData.withUnsafeBytes { raw -> MTLBuffer? in
        guard let base = raw.baseAddress else { return nil }
        return device.makeBuffer(bytes: base, length: raw.count, options: .storageModeShared)
    }
    guard let wBuf = packedWeights else { return nil }

    guard let sBuf = device.makeBuffer(
        bytes: sFloats, length: sFloats.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    ) else { return nil }

    // Zero biases (same element count as scales) when none were loaded.
    let biasValues = bFloats ?? [Float](repeating: 0.0, count: sFloats.count)
    guard let bBuf = device.makeBuffer(
        bytes: biasValues, length: biasValues.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    ) else { return nil }

    return QuantizedWeights(weight: wBuf, scales: sBuf, biases: bBuf,
                            inDim: inDim, outDim: outDim, bits: bits, groupSize: groupSize)
}
|
||
|
||
/// Loads a non-quantized bf16 embedding matrix (`<named>.weight`) as `FloatWeights`.
///
/// The requested name matches a descriptor stored under that exact name or under the name
/// prefixed with `"language_model.model."`, `"model.language_model.model."` or
/// `"model.language_model."`. The bf16 data is expanded to `Float` and uploaded to a
/// shared-storage buffer; `outDim` / `inDim` are taken from the first two shape entries.
///
/// - Parameters:
///   - named: Base tensor name, without the `.weight` suffix.
///   - tensors: All known tensor descriptors.
///   - index: Optional shard index mapping tensor names to shard files.
///   - readers: Readers keyed by file name (`"model.safetensors"` when not sharded).
///   - device: Metal device used to allocate the buffer.
///   - hiddenSize: Currently unused; kept for call-site compatibility.
/// - Returns: The weights, or `nil` when the tensor is absent, is not bf16, has fewer than two
///   dimensions, has no shard entry, or the buffer cannot be created.
/// - Throws: `WeightError.tensorNotFound` when no reader is loaded for the resolved file (this
///   used to be a force-unwrap crash), plus any error raised while reading.
private static func loadFloatEmbed(named: String, from tensors: [TensorDescriptor],
                                   index: SafeTensorsIndex?,
                                   readers: [String: SafeTensorsReader],
                                   device: MTLDevice,
                                   hiddenSize: Int) throws -> FloatWeights? {
    let wName = "\(named).weight"

    // Direct scan instead of rebuilding name dictionaries on every call; the previous
    // `Dictionary(uniqueKeysWithValues:)` was redundant and trapped on duplicate names.
    // `last(where:)` keeps the "later descriptor wins" behaviour of the old dictionary.
    let acceptedNames: Set<String> = [
        wName,
        "language_model.model." + wName,
        "model.language_model.model." + wName,
        "model.language_model." + wName
    ]
    guard let wDesc = tensors.last(where: { acceptedNames.contains($0.name) }) else {
        return nil
    }

    // Only bf16 matrices are handled here; 1-D tensors would trap on shape[1] below.
    guard wDesc.dtype == .bf16, wDesc.shape.count >= 2 else {
        return nil
    }

    let shardFile: String
    if let idx = index {
        guard let mapped = idx.weightMap[wDesc.name] else { return nil }
        shardFile = mapped
    } else {
        shardFile = "model.safetensors"
    }
    guard let wReader = readers[shardFile] else {
        throw WeightError.tensorNotFound("No reader loaded for shard \(shardFile) (needed by \(wDesc.name))")
    }

    let wFloats = SafeTensorsReader.bf16ToFloat32(try wReader.read(tensor: wDesc))

    guard let wBuf = device.makeBuffer(
        bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    ) else { return nil }

    return FloatWeights(weight: wBuf, inDim: wDesc.shape[1], outDim: wDesc.shape[0])
}
|
||
|
||
/// Loads a non-quantized bf16 layer weight matrix (`<named>.weight`) as `FloatWeights`.
///
/// The requested name matches a descriptor stored under that exact name or under the name
/// prefixed with `"language_model.model."` or `"model.language_model."`. The bf16 data is
/// expanded to `Float` and uploaded to a shared-storage buffer; `outDim` / `inDim` are taken
/// from the first two shape entries.
///
/// - Parameters:
///   - named: Base tensor name, without the `.weight` suffix.
///   - tensors: All known tensor descriptors.
///   - index: Optional shard index mapping tensor names to shard files.
///   - readers: Readers keyed by file name (`"model.safetensors"` when not sharded).
///   - device: Metal device used to allocate the buffer.
/// - Returns: The weights, or `nil` when the tensor is absent, is not bf16, has fewer than two
///   dimensions, has no shard entry, or the buffer cannot be created.
/// - Throws: `WeightError.tensorNotFound` when no reader is loaded for the resolved file (this
///   used to be a force-unwrap crash), plus any error raised while reading.
private static func loadFloatWeight(named: String, from tensors: [TensorDescriptor],
                                    index: SafeTensorsIndex?,
                                    readers: [String: SafeTensorsReader],
                                    device: MTLDevice) throws -> FloatWeights? {
    let wName = "\(named).weight"

    // Direct scan instead of rebuilding name dictionaries on every call; the previous
    // `Dictionary(uniqueKeysWithValues:)` was redundant and trapped on duplicate names.
    // `last(where:)` keeps the "later descriptor wins" behaviour of the old dictionary.
    let acceptedNames: Set<String> = [
        wName,
        "language_model.model." + wName,
        "model.language_model." + wName
    ]
    guard let wDesc = tensors.last(where: { acceptedNames.contains($0.name) }) else {
        return nil
    }

    // Only bf16 matrices are handled here; 1-D tensors would trap on shape[1] below.
    guard wDesc.dtype == .bf16, wDesc.shape.count >= 2 else {
        return nil
    }

    let shardFile: String
    if let idx = index {
        guard let mapped = idx.weightMap[wDesc.name] else { return nil }
        shardFile = mapped
    } else {
        shardFile = "model.safetensors"
    }
    guard let wReader = readers[shardFile] else {
        throw WeightError.tensorNotFound("No reader loaded for shard \(shardFile) (needed by \(wDesc.name))")
    }

    let wFloats = SafeTensorsReader.bf16ToFloat32(try wReader.read(tensor: wDesc))

    guard let wBuf = device.makeBuffer(
        bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    ) else { return nil }

    return FloatWeights(weight: wBuf, inDim: wDesc.shape[1], outDim: wDesc.shape[0])
}
|
||
/// Loads a 3D expert tensor `[numExperts, expertOutDim, inDimPacked]` as one contiguous
/// `MoEExpertGroup`.
///
/// The weight data is uploaded as stored (expert 0, expert 1, ...). Scales and optional biases
/// are converted with `SafeTensorsReader.bf16ToFloat32`; when no biases are available a
/// zero-filled buffer of `numExperts * expertOutDim * numGroups` floats is created.
/// `numGroups` is computed as `expertInDim / 64`.
/// NOTE(review): the group size of 64 is hard-coded here rather than derived from the scales
/// shape — confirm it holds for every supported checkpoint.
///
/// A weight shape that differs from `[numExperts, expertOutDim, expertInDim / (32 / bits)]` is
/// only reported with a log line; loading continues.
///
/// - Parameters:
///   - named: Base tensor name, without the `.weight` / `.scales` / `.biases` suffix.
///   - tensors: All known tensor descriptors.
///   - index: Optional shard index mapping tensor names to shard files.
///   - readers: Readers keyed by file name (`"model.safetensors"` when not sharded).
///   - device: Metal device used to allocate buffers.
///   - numExperts: Expected number of experts.
///   - expertOutDim: Output dimension per expert.
///   - expertInDim: Unpacked input dimension per expert.
///   - bits: Quantization bit width.
/// - Returns: The expert group, or `nil` when weight or scales are absent, the weight is not 3D,
///   `bits` is not positive, a shard entry is missing, or a buffer cannot be created.
/// - Throws: `WeightError.tensorNotFound` when the index points at a shard with no loaded reader
///   (this used to crash, or silently drop the biases), plus any error raised while reading.
private static func loadExpertGroup(named: String, from tensors: [TensorDescriptor],
                                    index: SafeTensorsIndex?,
                                    readers: [String: SafeTensorsReader],
                                    device: MTLDevice,
                                    numExperts: Int,
                                    expertOutDim: Int,
                                    expertInDim: Int,
                                    bits: Int = 4) throws -> MoEExpertGroup? {
    let wName = "\(named).weight"
    let sName = "\(named).scales"
    let bName = "\(named).biases"

    func reader(forShard shard: String) throws -> SafeTensorsReader {
        guard let found = readers[shard] else {
            throw WeightError.tensorNotFound("No reader loaded for shard \(shard) (needed by \(named))")
        }
        return found
    }

    guard let wDesc = findTensor(wName, in: tensors),
          let sDesc = findTensor(sName, in: tensors)
    else {
        print(" loadExpertGroup: missing weight or scales for \(named)")
        return nil
    }

    // Weight: [numExperts, expertOutDim, inDimPacked] uint32
    guard wDesc.shape.count == 3 else {
        print(" loadExpertGroup: expected 3D weight, got \(wDesc.shape)")
        return nil
    }

    // `32 / bits` below would trap on zero.
    guard bits > 0 else {
        print(" loadExpertGroup: invalid bit width \(bits) for \(named)")
        return nil
    }

    // Scales: [numExperts, expertOutDim, numGroups] bf16; biases share that shape.
    let groupSize = 64
    let numGroups = expertInDim / groupSize

    // Biases are optional; look the descriptor up once.
    let bDesc = findTensor(bName, in: tensors)

    // Each tensor may live in a different shard.
    let wReader: SafeTensorsReader
    let sReader: SafeTensorsReader
    var bReader: SafeTensorsReader? = nil
    if let idx = index {
        guard let wShard = idx.weightMap[wDesc.name],
              let sShard = idx.weightMap[sDesc.name] else { return nil }
        wReader = try reader(forShard: wShard)
        sReader = try reader(forShard: sShard)
        if let biasDesc = bDesc, let bShard = idx.weightMap[biasDesc.name] {
            bReader = try reader(forShard: bShard)
        }
    } else {
        wReader = try reader(forShard: "model.safetensors")
        sReader = wReader
        bReader = wReader
    }

    let wData = try wReader.read(tensor: wDesc)
    let sData = try sReader.read(tensor: sDesc)
    var bFloats: [Float]? = nil
    if let biasDesc = bDesc, let biasReader = bReader {
        bFloats = SafeTensorsReader.bf16ToFloat32(try biasReader.read(tensor: biasDesc))
    }
    let sFloats = SafeTensorsReader.bf16ToFloat32(sData)

    let valsPerU32 = 32 / bits
    let inDimPacked = expertInDim / valsPerU32

    // Expected: wDesc.shape = [numExperts, expertOutDim, inDimPacked]; a mismatch is only logged.
    if wDesc.shape[0] != numExperts ||
       wDesc.shape[1] != expertOutDim ||
       wDesc.shape[2] != inDimPacked {
        print(" loadExpertGroup: shape mismatch: got \(wDesc.shape), expected [\(numExperts), \(expertOutDim), \(inDimPacked)]")
    }

    // Upload the flattened 3D weights from a pointer that is valid for the whole call, rather
    // than from the bytes of a bridged NSData temporary.
    let packedWeights: MTLBuffer? = wData.withUnsafeBytes { raw -> MTLBuffer? in
        guard let base = raw.baseAddress else { return nil }
        return device.makeBuffer(bytes: base, length: raw.count, options: .storageModeShared)
    }
    guard let wBuf = packedWeights else { return nil }

    guard let sBuf = device.makeBuffer(
        bytes: sFloats, length: sFloats.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    ) else { return nil }

    // Zero biases when none were loaded.
    let biasValues = bFloats ?? [Float](repeating: 0.0, count: numExperts * expertOutDim * numGroups)
    guard let bBuf = device.makeBuffer(
        bytes: biasValues, length: biasValues.count * MemoryLayout<Float>.stride,
        options: .storageModeShared
    ) else { return nil }

    return MoEExpertGroup(
        weight: wBuf, scales: sBuf, biases: bBuf,
        expertOutDim: expertOutDim,
        expertInDim: expertInDim,
        numGroups: numGroups,
        numExperts: numExperts,
        bits: bits
    )
}
|
||
|
||
/// Finds a tensor descriptor by name, with a fallback for prefixed requests.
///
/// The first descriptor whose name equals `name` wins. If there is none and `name` starts with
/// `"language_model.model."`, the lookup is retried with that prefix removed.
///
/// - Parameters:
///   - name: Requested tensor name.
///   - tensors: Descriptors to search, in order.
/// - Returns: The matching descriptor, or `nil` when neither lookup succeeds.
private static func findTensor(_ name: String, in tensors: [TensorDescriptor]) -> TensorDescriptor? {
    func firstNamed(_ wanted: String) -> TensorDescriptor? {
        return tensors.first(where: { $0.name == wanted })
    }

    if let exactMatch = firstNamed(name) {
        return exactMatch
    }

    let optionalPrefix = "language_model.model."
    guard name.hasPrefix(optionalPrefix) else {
        return nil
    }
    return firstNamed(String(name.dropFirst(optionalPrefix.count)))
}
|
||
|
||
// ── Forward ───────────────────────────────────────
|
||
|
||
/// Runs one step of the model for a single token: token → logits.
///
/// Steps performed here before handing over to `forwardAfterEmbedding(position:debug:)`:
/// 1. Dequantize the token's embedding row into `temps.io`.
/// 2. Scale it by `embedScale` (skipped when the scale is exactly 1.0).
/// 3. If the per-layer embedding weight and both per-layer buffers exist, build the per-layer
///    inputs in `perLayerEmbedBuffer`.
///
/// At position 0 a short embedding sample and NaN count are always printed; additional values
/// are printed when `debug` is set.
///
/// - Parameters:
///   - tokenId: Token whose embedding row is looked up.
///   - position: Sequence position of the token.
///   - debug: Enables extra diagnostic output.
/// - Returns: The result of `forwardAfterEmbedding(position:debug:)`.
/// - Throws: Any error raised by the kernel helpers or by `forwardAfterEmbedding`.
public func forward(tokenId: Int, position: Int, debug: Bool = false) throws -> [Float] {
    let h = temps.io
    let floatBytes = MemoryLayout<Float>.stride

    // Blocking GPU copy of `byteCount` bytes from the start of one buffer to the start of another.
    func copyBuffer(_ source: MTLBuffer, into destination: MTLBuffer, byteCount: Int) {
        let commands = engine.commandQueue.makeCommandBuffer()!
        let blitter = commands.makeBlitCommandEncoder()!
        blitter.copy(from: source, sourceOffset: 0,
                     to: destination, destinationOffset: 0,
                     size: byteCount)
        blitter.endEncoding()
        commands.commit()
        commands.waitUntilCompleted()
    }

    // ── 1. Embedding lookup ──
    try dequantizeRow(weight: embedWeight, tokenId: tokenId, output: h)

    let atFirstPosition = position == 0

    // Always sample the embedding for NaNs on the first position.
    if atFirstPosition {
        let embedVals = engine.readFloats(from: h, count: min(20, hiddenSize))
        let nanCount = embedVals.filter { $0.isNaN }.count
        let hasNaN = embedVals.contains { $0.isNaN }
        print("TEXT Embedding: sample=\(embedVals.prefix(10)), NaN=\(nanCount)/\(min(20, hiddenSize)), hasNaN=\(hasNaN)")
    }

    if debug && atFirstPosition {
        let hPtr = h.contents().bindMemory(to: Float.self, capacity: hiddenSize)
        print("Pos \(position) token \(tokenId): BEFORE embed_scale h[0:5]=[\(hPtr[0]), \(hPtr[1]), \(hPtr[2]), \(hPtr[3]), \(hPtr[4])]")
        print(" embedScale = \(embedScale)")
    }

    // ── 2. Embedding scale ──
    if embedScale != 1.0 {
        try scaleBuffer(h, scale: embedScale, count: hiddenSize)
    }

    // ── 2b. Per-layer embedding (only for models that ship embed_tokens_per_layer) ──
    if let plWeight = embedTokensPerLayerWeight,
       let plBuf = perLayerEmbedBuffer,
       let ctxBuf = perLayerContextBuffer {

        let totalPerLayer = perLayerInputSize * numHiddenLayers
        let totalBytes = totalPerLayer * floatBytes
        let plEmbedScale = sqrt(Float(perLayerInputSize))

        // Token-identity component: embed_tokens_per_layer(token) * sqrt(perLayerInputSize),
        // written into plBuf.
        func writeTokenIdentity() throws {
            try dequantizeRow(weight: plWeight, tokenId: tokenId, output: plBuf, nCols: totalPerLayer)
            try scaleBuffer(plBuf, scale: plEmbedScale, count: totalPerLayer)
        }

        // Step 1: token identity.
        try writeTokenIdentity()

        // Step 2: context-aware projection of the scaled embedding, when the model has one.
        if let projBuf = perLayerModelProjection {
            // Non-quantized matmul, e.g. [10752, 2560] @ [2560] -> [10752].
            try matmulBF16(input: h, weight: projBuf, output: ctxBuf,
                           inDim: hiddenSize, outDim: perLayerModelProjectionOutDim)
            try scaleBuffer(ctxBuf, scale: perLayerModelProjectionScaleVal, count: totalPerLayer)

            // Optional RMSNorm over each layer's slice. Per the original note the norm kernel
            // must not run in place, so plBuf serves as scratch output and the result is copied
            // back into ctxBuf afterwards.
            if let norm = perLayerProjectionNorm {
                for layerIdx in 0..<numHiddenLayers {
                    let sliceBytes = layerIdx * perLayerInputSize * floatBytes
                    try rmsNorm(input: ctxBuf, weight: norm, output: plBuf,
                                count: perLayerInputSize, eps: rmsNormEps,
                                inputOffset: sliceBytes, weightOffset: 0, outputOffset: sliceBytes)
                }
                copyBuffer(plBuf, into: ctxBuf, byteCount: totalBytes)

                // plBuf was overwritten by the norm output; rebuild the token identity.
                try writeTokenIdentity()
            }

            // Combine: (context_projection + token_identity) * per_layer_input_scale
            try eltwiseAddScaledModel(a: ctxBuf, scaleA: 1.0,
                                      b: plBuf, scaleB: 1.0,
                                      output: ctxBuf, count: totalPerLayer)
            try scaleBuffer(ctxBuf, scale: perLayerInputScaleVal, count: totalPerLayer)

            // The layers read their per-layer input from plBuf.
            copyBuffer(ctxBuf, into: plBuf, byteCount: totalBytes)
        }
    }

    // ── 3–7. Layers, LM head, readback ──
    return try forwardAfterEmbedding(position: position, debug: debug)
}
|
||
|
||
/// Forward pass using a pre-computed hidden state (skips embedding lookup).
///
/// Blits `hiddenSize` floats from `hiddenBuffer` (starting at `offset`) into
/// `temps.io`, blocks until the copy has completed, and then runs the shared
/// post-embedding path (`forwardAfterEmbedding`).
///
/// - Parameters:
///   - hiddenBuffer: MTLBuffer containing one or more hidden states.
///   - offset: Byte offset into hiddenBuffer for this position's hidden state.
///   - position: Sequence position (used for KV cache indexing).
///   - debug: Forwarded unchanged to `forwardAfterEmbedding`.
/// - Returns: The logits returned by `forwardAfterEmbedding`.
/// - Throws: Any error thrown by `forwardAfterEmbedding`.
/// - Precondition: `offset` is non-negative and `hiddenBuffer` contains at least
///   `hiddenSize` floats starting at `offset`.
public func forwardFromHidden(hiddenBuffer: MTLBuffer, offset: Int = 0, position: Int, debug: Bool = false) throws -> [Float] {
    let h = temps.io
    let copySize = hiddenSize * MemoryLayout<Float>.stride

    // Validate the source window before encoding the blit. Without this check a
    // bad offset makes the GPU copy read outside `hiddenBuffer`, which surfaces
    // (at best) as an opaque Metal validation failure and (at worst) as garbage
    // hidden state. The subtraction form avoids overflow in `offset + copySize`.
    precondition(
        offset >= 0 && offset <= hiddenBuffer.length && copySize <= hiddenBuffer.length - offset,
        "forwardFromHidden: \(copySize) bytes at offset \(offset) exceed hiddenBuffer length \(hiddenBuffer.length)"
    )

    let cmdBuf = engine.commandQueue.makeCommandBuffer()!
    let blit = cmdBuf.makeBlitCommandEncoder()!
    blit.copy(from: hiddenBuffer, sourceOffset: offset,
              to: h, destinationOffset: 0,
              size: copySize)
    blit.endEncoding()
    cmdBuf.commit()
    // The layer loop reads temps.io, so the copy must have landed before continuing.
    cmdBuf.waitUntilCompleted()

    return try forwardAfterEmbedding(position: position, debug: debug)
}
|
||
|
||
/// Shared logic after embedding is in temps.io: layers → LM head → readback.
///
/// Steps, in order:
/// 1. Run every layer in `layers` over `temps.io`, choosing each layer's KV cache
///    (its own for layers below `firstKVShared`, otherwise the one named by
///    `kvSourceMap`, falling back to `layerIdx - numKvShared`).
/// 2. If `finalNorm` is set, RMS-normalise `temps.io` into `temps.ns`.
/// 3. Project through `embedWeight` into `logitsBuffer` (tied LM head).
/// 4. If `embedWeight.groupSize == 32` and its `inDim` equals `hiddenSize`,
///    rescale the logits.
/// 5. If `finalLogitSoftcapping` is set, apply soft-capping.
/// 6. Read `vocabSize` floats back from `logitsBuffer` and return them.
///
/// - Parameters:
///   - position: Sequence position passed to every layer.
///   - debug: Enables the extra indented diagnostics.
/// - Returns: `vocabSize` logits read from `logitsBuffer`.
/// - Throws: Any error thrown by the layer forward passes or kernel dispatches.
///
/// NOTE(review): the "TEXT …" diagnostics below are printed whenever
/// `position == 0`, independent of `debug`. That is the existing behavior and is
/// kept as-is here.
private func forwardAfterEmbedding(position: Int, debug: Bool = false) throws -> [Float] {
    let hidden = temps.io
    let atFirstPosition = position == 0
    let verboseFirstPosition = debug && atFirstPosition
    let floatStride = MemoryLayout<Float>.stride

    // Counts NaNs in a sampled slice; shared by all "TEXT …" diagnostics.
    func nanSummary(_ values: [Float]) -> (nanCount: Int, hasNaN: Bool) {
        let count = values.filter { $0.isNaN }.count
        let any = values.contains { $0.isNaN }
        return (count, any)
    }

    // ── 3. Layer loop ──
    if verboseFirstPosition {
        print(" KV config: firstKVShared=\(firstKVShared), numKvShared=\(numKvShared), numLayers=\(numHiddenLayers)")
    }

    if atFirstPosition {
        print("TEXT Starting layer loop: numHiddenLayers=\(numHiddenLayers)")
        fflush(stdout)
    }

    for layerIdx in 0..<numHiddenLayers {
        // Layers below firstKVShared own (and write) their cache; the rest read a shared one.
        let ownsCache = layerIdx < firstKVShared
        let cacheIndex: Int
        if ownsCache {
            cacheIndex = layerIdx
        } else {
            cacheIndex = kvSourceMap[layerIdx] ?? (layerIdx - numKvShared)
        }

        // Per-layer input offset in bytes (E4B only); zero when there is no per-layer input.
        let perLayerByteOffset = perLayerInputSize > 0
            ? layerIdx * perLayerInputSize * floatStride
            : 0

        try layers[layerIdx].forward(
            input: hidden,
            position: position,
            kvCache: kvCaches[cacheIndex],
            shouldStoreKV: ownsCache,
            temps: temps,
            engine: engine,
            perLayerInput: perLayerEmbedBuffer,
            perLayerInputOffset: perLayerByteOffset
        )

        // NaN check after each layer (only on position 0, first 5 layers).
        if atFirstPosition && layerIdx < 5 {
            let hPtr = hidden.contents().bindMemory(to: Float.self, capacity: hiddenSize)
            let sample = (0..<min(10, hiddenSize)).map { hPtr[$0] }
            let (nanCount, hasNaN) = nanSummary(sample)
            print("TEXT After Layer \(layerIdx): sample=\(sample.prefix(5)), NaN=\(nanCount)/\(min(10, hiddenSize)), hasNaN=\(hasNaN)")
            fflush(stdout)
        }

        // Verbose per-layer dump: first five hidden values, the layer scalar and their L2 norm.
        if verboseFirstPosition && layerIdx < 10 {
            let hPtr = hidden.contents().bindMemory(to: Float.self, capacity: hiddenSize)
            let ls = layers[layerIdx].layerScalar
            let magnitude = sqrt(hPtr[0]*hPtr[0] + hPtr[1]*hPtr[1] + hPtr[2]*hPtr[2] + hPtr[3]*hPtr[3] + hPtr[4]*hPtr[4])
            print(" ✓ After Layer \(layerIdx): h[0:5]=[\(hPtr[0]), \(hPtr[1]), \(hPtr[2]), \(hPtr[3]), \(hPtr[4])], scalar=\(ls), mag=\(magnitude)")
        }
    }

    // ── 4. Final RMSNorm ──
    // The LM head reads the raw hidden state unless a final norm replaces it below.
    var lmInput = hidden

    if atFirstPosition {
        let hVals = engine.readFloats(from: hidden, count: min(20, hiddenSize))
        let (nanCount, hasNaN) = nanSummary(hVals)
        print("TEXT After layers: sample=\(hVals.prefix(10)), NaN=\(nanCount)/\(min(20, hiddenSize)), hasNaN=\(hasNaN)")
    }

    if let fn = finalNorm {
        // Written to a separate buffer (temps.ns) rather than in place.
        try rmsNorm(input: hidden, weight: fn, output: temps.ns,
                    count: hiddenSize, eps: rmsNormEps)
        lmInput = temps.ns

        if atFirstPosition {
            let hiddenVals = engine.readFloats(from: temps.ns, count: min(20, hiddenSize))
            let (nanCount, hasNaN) = nanSummary(hiddenVals)
            print("TEXT After finalNorm: sample=\(hiddenVals.prefix(10)), NaN=\(nanCount)/\(min(20, hiddenSize)), hasNaN=\(hasNaN)")
        }
    }

    // ── 5. LM head (tied embeddings) ──
    try quantizedMatmulModel(input: lmInput, weights: embedWeight, output: logitsBuffer)

    if atFirstPosition {
        let logitsVals = engine.readFloats(from: logitsBuffer, count: min(20, vocabSize))
        let (nanCount, hasNaN) = nanSummary(logitsVals)
        print("TEXT After LM head: sample=\(logitsVals.prefix(10)), NaN=\(nanCount)/\(min(20, vocabSize)), hasNaN=\(hasNaN)")
    }

    // ── 5b. Logits scaling for custom quantization (groupSize=32) ──
    // For groupSize=32 models, logits are ~200x larger than standard, so they are
    // scaled by 1/sqrt(hidden_size) * (30/116) to bring them into an E4B-like range.
    let usesCustomQuantization = embedWeight.groupSize == 32 && embedWeight.inDim == hiddenSize
    if usesCustomQuantization {
        let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
        if atFirstPosition {
            print(" ⚠ Scaling logits by \(logitsScale) for groupSize=32 custom quantization")
            fflush(stdout)
        }
        try scaleBuffer(logitsBuffer, scale: logitsScale, count: vocabSize)
    }

    // ── 6. Logit softcapping ──
    if let cap = finalLogitSoftcapping {
        if verboseFirstPosition {
            print(" Applying logit softcapping with cap=\(cap)")
        }
        try applyLogitSoftcapping(buffer: logitsBuffer, cap: cap, count: vocabSize)
    } else if verboseFirstPosition {
        print(" No logit softcapping (cap is nil)")
    }

    // ── 7. Read back ──
    let logits = engine.readFloats(from: logitsBuffer, count: vocabSize)

    // Verbose summary for the first three positions: range plus the five largest logits.
    if debug && position < 3 {
        let maxLogit = logits.max() ?? 0
        let minLogit = logits.min() ?? 0
        let ranked = logits.enumerated().sorted(by: { $0.element > $1.element })
        let top5 = ranked.prefix(5).map { "\($0.offset): \($0.element)" }
        print(" Final logits: max=\(maxLogit), min=\(minLogit), top5: \(top5)")
    }

    return logits
}
|
||
|
||
// ── Model-level kernel dispatches ─────────────────
|
||
|
||
/// Dispatches the `dequantize_row` kernel for a single row of `weight` and blocks
/// until the GPU has finished.
///
/// Kernel bindings: weight/scales/biases at buffer indices 0–2, `output` at 3,
/// column count (UInt32) at 4, row index (Int32) at 5, group size (UInt32) at 6.
/// One thread is dispatched per column.
///
/// - Parameters:
///   - weight: Quantized matrix to read the row from.
///   - tokenId: Row index handed to the kernel.
///   - output: Destination buffer (presumably receives the dequantized row as
///     floats — confirm against the kernel source).
///   - nCols: Number of columns to process; defaults to `hiddenSize` when `nil`.
/// - Throws: If the pipeline lookup fails.
func dequantizeRow(weight: QuantizedWeights, tokenId: Int,
                   output: MTLBuffer, nCols: Int? = nil) throws {
    let pipeline = try engine.pipeline(named: "dequantize_row")

    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Buffer arguments occupy consecutive indices starting at 0.
    let boundBuffers: [MTLBuffer?] = [weight.weight, weight.scales, weight.biases, output]
    for (slot, buffer) in boundBuffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    // Scalar constants follow the buffers.
    let columnCount = nCols ?? hiddenSize
    var columnCountArg = UInt32(columnCount)
    encoder.setBytes(&columnCountArg, length: MemoryLayout<UInt32>.size, index: 4)
    var rowArg = Int32(tokenId)
    encoder.setBytes(&rowArg, length: MemoryLayout<Int32>.size, index: 5)
    var groupSizeArg = UInt32(weight.groupSize)
    encoder.setBytes(&groupSizeArg, length: MemoryLayout<UInt32>.size, index: 6)

    // One thread per output column.
    let threadsPerGroup = engine.threadgroupSize1D(pipeline, count: columnCount)
    let grid = MTLSize(width: columnCount, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadsPerGroup)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
|
||
|
||
/// Dispatches the `eltwise_scale` kernel over `count` elements of `buf` and blocks
/// until the GPU has finished.
///
/// Kernel bindings: `buf` at buffer index 0, `scale` (Float) at 1, `count`
/// (UInt32) at 2. The kernel is given a single buffer, so it presumably scales
/// the contents in place — confirm against the kernel source.
///
/// - Parameters:
///   - buf: Buffer passed to the kernel.
///   - scale: Scalar factor passed to the kernel.
///   - count: Number of threads dispatched, also passed as the element count.
/// - Throws: If the pipeline lookup fails.
func scaleBuffer(_ buf: MTLBuffer, scale: Float, count: Int) throws {
    let pipeline = try engine.pipeline(named: "eltwise_scale")

    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(buf, offset: 0, index: 0)

    var factor = scale
    encoder.setBytes(&factor, length: MemoryLayout<Float>.size, index: 1)
    var elementCount = UInt32(count)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 2)

    // One thread per element.
    let threadsPerGroup = engine.threadgroupSize1D(pipeline, count: count)
    let grid = MTLSize(width: count, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadsPerGroup)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
|
||
|
||
/// Dispatches the `quantized_matmul` kernel for `input` against `weights`, writing
/// to `output`, and blocks until the GPU has finished.
///
/// Kernel bindings: `input` at buffer index 0, weight/scales/biases at 1–3,
/// `output` at 4, then `inDim`, `outDim` and `groupSize` (all UInt32) at 5–7.
/// One thread is dispatched per output element (`weights.outDim`).
///
/// - Parameters:
///   - input: Input vector buffer.
///   - weights: Quantized matrix and its dimensions.
///   - output: Destination buffer for the product.
/// - Throws: If the pipeline lookup fails.
func quantizedMatmulModel(input: MTLBuffer, weights: QuantizedWeights,
                          output: MTLBuffer) throws {
    let pipeline = try engine.pipeline(named: "quantized_matmul")

    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Buffer arguments occupy consecutive indices starting at 0.
    let boundBuffers: [MTLBuffer?] = [input, weights.weight, weights.scales, weights.biases, output]
    for (slot, buffer) in boundBuffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    // Copies one UInt32 constant into the encoder at the given argument index.
    func bindScalar(_ value: UInt32, at index: Int) {
        var scalar = value
        encoder.setBytes(&scalar, length: MemoryLayout<UInt32>.size, index: index)
    }
    bindScalar(UInt32(weights.inDim), at: 5)
    bindScalar(UInt32(weights.outDim), at: 6)
    bindScalar(UInt32(weights.groupSize), at: 7)

    // One thread per output row.
    let rowCount = weights.outDim
    let threadsPerGroup = engine.threadgroupSize1D(pipeline, count: rowCount)
    let grid = MTLSize(width: rowCount, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadsPerGroup)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
|
||
|
||
/// Dispatches the `tanh_scale` kernel over `count` elements of `buffer` and blocks
/// until the GPU has finished.
///
/// `buffer` is bound as both argument 0 and argument 1, so the kernel runs in
/// place. `cap` (Float) is bound at index 2 and `count` (UInt32) at index 3.
/// NOTE(review): presumably the kernel computes `cap * tanh(x / cap)` — confirm
/// against the kernel source.
///
/// - Parameters:
///   - buffer: Logits buffer, used as both kernel input and output.
///   - cap: Soft-capping value passed to the kernel.
///   - count: Number of threads dispatched, also passed as the element count.
/// - Throws: If the pipeline lookup fails.
func applyLogitSoftcapping(buffer: MTLBuffer, cap: Float, count: Int) throws {
    let pipeline = try engine.pipeline(named: "tanh_scale")

    let commandBuffer = engine.commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Same buffer for input (index 0) and output (index 1): in-place.
    for slot in 0...1 {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    var capArg = cap
    encoder.setBytes(&capArg, length: MemoryLayout<Float>.size, index: 2)
    var elementCount = UInt32(count)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 3)

    // One thread per logit.
    let threadsPerGroup = engine.threadgroupSize1D(pipeline, count: count)
    let grid = MTLSize(width: count, height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadsPerGroup)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
}
|
||
}
|