Add bf16 layer weight support for E4B model

- Add FloatWeights fields to E4BLayer (qProjFloat, kProjFloat, etc.)
- Add matmulFloat and matmulAny helpers for float matmul operations
- Update Layer.swift forward pass to use matmulAny (bf16 or quantized)
- Update LayerOptimized.swift and LayerBatch.swift for bf16 weights
- Modify Model.swift to load bf16 layer weights via fw() helper
- Add guards in LayerBatch.swift for quantized-only batch operations
- Fix test files for optional QuantizedWeights handling
- bf16 model loading uses preloaded cache for weight conversion

Tested: E4B bf16 model forward pass works (5.5 tok/s, no NaN/Inf)
Tested: 4-bit models still work correctly after changes
This commit is contained in:
MarkBase Admin
2026-06-25 00:26:54 +08:00
parent e23ef405bc
commit 5a94501f95
4 changed files with 350 additions and 106 deletions
+143 -67
View File
@@ -170,15 +170,26 @@ public final class E4BLayer {
let vNorm: MTLBuffer? // nil no-scale variant
// Quantized projections
let qProj: QuantizedWeights
let kProj: QuantizedWeights
let qProj: QuantizedWeights?
let kProj: QuantizedWeights?
let vProj: QuantizedWeights?
let oProj: QuantizedWeights
let gateProj: QuantizedWeights
let upProj: QuantizedWeights
let downProj: QuantizedWeights
let oProj: QuantizedWeights?
let gateProj: QuantizedWeights?
let upProj: QuantizedWeights?
let downProj: QuantizedWeights?
let perLayerGate: QuantizedWeights?
let perLayerProjection: QuantizedWeights?
// Float projections (bf16 models)
let qProjFloat: FloatWeights?
let kProjFloat: FloatWeights?
let vProjFloat: FloatWeights?
let oProjFloat: FloatWeights?
let gateProjFloat: FloatWeights?
let upProjFloat: FloatWeights?
let downProjFloat: FloatWeights?
let perLayerGateFloat: FloatWeights?
let perLayerProjectionFloat: FloatWeights?
// MoE
let useMoE: Bool
@@ -209,15 +220,24 @@ public final class E4BLayer {
qNorm: MTLBuffer?,
kNorm: MTLBuffer?,
vNorm: MTLBuffer?,
qProj: QuantizedWeights,
kProj: QuantizedWeights,
vProj: QuantizedWeights?,
oProj: QuantizedWeights,
gateProj: QuantizedWeights,
upProj: QuantizedWeights,
downProj: QuantizedWeights,
perLayerGate: QuantizedWeights?,
perLayerProjection: QuantizedWeights?,
qProj: QuantizedWeights? = nil,
kProj: QuantizedWeights? = nil,
vProj: QuantizedWeights? = nil,
oProj: QuantizedWeights? = nil,
gateProj: QuantizedWeights? = nil,
upProj: QuantizedWeights? = nil,
downProj: QuantizedWeights? = nil,
perLayerGate: QuantizedWeights? = nil,
perLayerProjection: QuantizedWeights? = nil,
qProjFloat: FloatWeights? = nil,
kProjFloat: FloatWeights? = nil,
vProjFloat: FloatWeights? = nil,
oProjFloat: FloatWeights? = nil,
gateProjFloat: FloatWeights? = nil,
upProjFloat: FloatWeights? = nil,
downProjFloat: FloatWeights? = nil,
perLayerGateFloat: FloatWeights? = nil,
perLayerProjectionFloat: FloatWeights? = nil,
perLayerInput: MTLBuffer?,
perLayerInputScale: Float,
perLayerProjectionScale: Float,
@@ -250,6 +270,15 @@ public final class E4BLayer {
self.downProj = downProj
self.perLayerGate = perLayerGate
self.perLayerProjection = perLayerProjection
self.qProjFloat = qProjFloat
self.kProjFloat = kProjFloat
self.vProjFloat = vProjFloat
self.oProjFloat = oProjFloat
self.gateProjFloat = gateProjFloat
self.upProjFloat = upProjFloat
self.downProjFloat = downProjFloat
self.perLayerGateFloat = perLayerGateFloat
self.perLayerProjectionFloat = perLayerProjectionFloat
self.kEqualsV = kEqualsV
self.perLayerInput = perLayerInput
self.perLayerInputScale = perLayerInputScale
@@ -380,6 +409,41 @@ func quantizedMatmul(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
enc.endEncoding()
}
func matmulFloat(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
input: MTLBuffer,
weights: FloatWeights,
output: MTLBuffer) throws {
let pso = try engine.pipeline(named: "matmul_f32")
let enc = cmdBuf.makeComputeCommandEncoder()!
enc.setComputePipelineState(pso)
enc.setBuffer(input, offset: 0, index: 0)
enc.setBuffer(weights.weight, offset: 0, index: 1)
enc.setBuffer(output, offset: 0, index: 2)
var M: UInt32 = 1 // Single token
enc.setBytes(&M, length: MemoryLayout<UInt32>.size, index: 3)
var K = UInt32(weights.inDim)
enc.setBytes(&K, length: MemoryLayout<UInt32>.size, index: 4)
var N = UInt32(weights.outDim)
enc.setBytes(&N, length: MemoryLayout<UInt32>.size, index: 5)
let count = weights.outDim
let tg = engine.threadgroupSize1D(pso, count: count)
enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
threadsPerThreadgroup: tg)
enc.endEncoding()
}
func matmulAny(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
input: MTLBuffer,
weightsQ: QuantizedWeights?,
weightsF: FloatWeights?,
output: MTLBuffer) throws {
if let qw = weightsQ {
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf, input: input, weights: qw, output: output)
} else if let fw = weightsF {
try matmulFloat(engine: engine, cmdBuf: cmdBuf, input: input, weights: fw, output: output)
}
}
func applyRoPEQ(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
q: MTLBuffer, position: Int) throws {
let pso = try engine.pipeline(named: "apply_rope_q")
@@ -708,53 +772,63 @@ func slidingAttention(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
func fusedGateUp(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
input: MTLBuffer,
output: MTLBuffer) throws {
let kernelName = gateProj.bits == 8 ? "quantized_matmul_gate_up_opt_8bit" : "quantized_matmul_gate_up_opt"
// Float path: separate matmuls for gate and up
if let gf = gateProjFloat, let uf = upProjFloat {
try matmulFloat(engine: engine, cmdBuf: cmdBuf, input: input, weights: gf, output: output)
// Note: This only does gate projection, up projection is separate for bf16
return
}
// Quantized path: fused kernel
guard let gp = gateProj, let up = upProj else { return }
let kernelName = gp.bits == 8 ? "quantized_matmul_gate_up_opt_8bit" : "quantized_matmul_gate_up_opt"
if let pso = try? engine.pipeline(named: kernelName) {
// Optimized path: threadgroup-cached input + uint4 loads
let enc = cmdBuf.makeComputeCommandEncoder()!
enc.setComputePipelineState(pso)
enc.setBuffer(input, offset: 0, index: 0)
enc.setBuffer(gateProj.weight, offset: 0, index: 1)
enc.setBuffer(gateProj.scales, offset: 0, index: 2)
enc.setBuffer(gateProj.biases, offset: 0, index: 3)
enc.setBuffer(upProj.weight, offset: 0, index: 4)
enc.setBuffer(upProj.scales, offset: 0, index: 5)
enc.setBuffer(upProj.biases, offset: 0, index: 6)
enc.setBuffer(gp.weight, offset: 0, index: 1)
enc.setBuffer(gp.scales, offset: 0, index: 2)
enc.setBuffer(gp.biases, offset: 0, index: 3)
enc.setBuffer(up.weight, offset: 0, index: 4)
enc.setBuffer(up.scales, offset: 0, index: 5)
enc.setBuffer(up.biases, offset: 0, index: 6)
enc.setBuffer(output, offset: 0, index: 7)
var inDim = UInt32(gateProj.inDim)
var inDim = UInt32(gp.inDim)
enc.setBytes(&inDim, length: MemoryLayout<UInt32>.size, index: 8)
var outDim = UInt32(gateProj.outDim)
var outDim = UInt32(gp.outDim)
enc.setBytes(&outDim, length: MemoryLayout<UInt32>.size, index: 9)
var groupSize = UInt32(gateProj.groupSize)
var groupSize = UInt32(gp.groupSize)
enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 10)
let tgMemSize = gateProj.inDim * 4
let tgMemSize = gp.inDim * 4
enc.setThreadgroupMemoryLength(tgMemSize, index: 0)
let count = gateProj.outDim
let count = gp.outDim
let tg = MTLSize(width: 256, height: 1, depth: 1)
enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
threadsPerThreadgroup: tg)
enc.endEncoding()
} else {
// Fallback to old kernel
let fallbackName = gateProj.bits == 8 ? "quantized_matmul_gate_up_8bit" : "quantized_matmul_gate_up"
let fallbackName = gp.bits == 8 ? "quantized_matmul_gate_up_8bit" : "quantized_matmul_gate_up"
let fallbackPSO = try engine.pipeline(named: fallbackName)
let enc = cmdBuf.makeComputeCommandEncoder()!
enc.setComputePipelineState(fallbackPSO)
enc.setBuffer(input, offset: 0, index: 0)
enc.setBuffer(gateProj.weight, offset: 0, index: 1)
enc.setBuffer(gateProj.scales, offset: 0, index: 2)
enc.setBuffer(gateProj.biases, offset: 0, index: 3)
enc.setBuffer(upProj.weight, offset: 0, index: 4)
enc.setBuffer(upProj.scales, offset: 0, index: 5)
enc.setBuffer(upProj.biases, offset: 0, index: 6)
enc.setBuffer(gp.weight, offset: 0, index: 1)
enc.setBuffer(gp.scales, offset: 0, index: 2)
enc.setBuffer(gp.biases, offset: 0, index: 3)
enc.setBuffer(up.weight, offset: 0, index: 4)
enc.setBuffer(up.scales, offset: 0, index: 5)
enc.setBuffer(up.biases, offset: 0, index: 6)
enc.setBuffer(output, offset: 0, index: 7)
var inDim = UInt32(gateProj.inDim)
var inDim = UInt32(gp.inDim)
enc.setBytes(&inDim, length: MemoryLayout<UInt32>.size, index: 8)
var outDim = UInt32(gateProj.outDim)
var outDim = UInt32(gp.outDim)
enc.setBytes(&outDim, length: MemoryLayout<UInt32>.size, index: 9)
var groupSize = UInt32(gateProj.groupSize)
var groupSize = UInt32(gp.groupSize)
enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 10)
let count = gateProj.outDim
let count = gp.outDim
let tg = engine.threadgroupSize1D(fallbackPSO, count: count)
enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
threadsPerThreadgroup: tg)
@@ -1074,10 +1148,10 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
temps: temps, engine: engine, cmdBuf: cmdBuf)
// FFN: gate+up fused down residual (scaled by layerScalar)
try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
input: temps.ns, output: temps.gate)
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.gate, weights: downProj, output: temps.h)
try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
input: temps.ns, output: temps.gate)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.gate, weightsQ: downProj, weightsF: downProjFloat, output: temps.h)
if layerScalar != 1.0 {
try eltwiseAddScaled(engine: engine, cmdBuf: cmdBuf,
a: input, scaleA: 1.0,
@@ -1091,22 +1165,22 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
// Per-layer gating for dense path
if let pg = perLayerGate, let pp = perLayerProjection, let pl = perLayerInput {
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
input: input, weight: postFeedforwardLayernorm,
output: temps.h, count: config.hiddenSize, eps: rmsNormEps)
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weights: pg,
output: temps.gating)
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
input: input, weight: postFeedforwardLayernorm,
output: temps.h, count: config.hiddenSize, eps: rmsNormEps)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weightsQ: pg, weightsF: perLayerGateFloat,
output: temps.gating)
try gelu(engine: engine, cmdBuf: cmdBuf,
input: temps.gating, output: temps.gating, count: 256)
try eltwiseMul(engine: engine, cmdBuf: cmdBuf,
a: temps.gating, aOffset: 0,
b: pl, bOffset: perLayerInputOffset,
output: temps.gating, outputOffset: 0,
count: 256)
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.gating, weights: pp,
output: temps.h)
count: 256)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.gating, weightsQ: pp, weightsF: perLayerProjectionFloat,
output: temps.h)
if let ppn = postPerLayerInputNorm {
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weight: ppn,
@@ -1135,8 +1209,8 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
output: temps.h, count: config.hiddenSize, eps: rmsNormEps)
// 2. Q = q_proj(temps.h) temps.q
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weights: qProj, output: temps.q)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weightsQ: qProj, weightsF: qProjFloat, output: temps.q)
// 3. Q = q_norm(Q) ns (per-head RMSNorm)
try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
@@ -1150,11 +1224,13 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
q: temps.ns, position: position)
// 5. K,V projections
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weights: kProj, output: temps.k)
if let vp = vProj {
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weights: vp, output: temps.v)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weightsQ: kProj, weightsF: kProjFloat, output: temps.k)
if let vp = vProj, let vpF = vProjFloat {
if vp != nil || vpF != nil {
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weightsQ: vp, weightsF: vpF, output: temps.v)
}
} else if kEqualsV {
let blit = cmdBuf.makeBlitCommandEncoder()!
let copyBytes = config.nKvHeads * config.headDim * MemoryLayout<Float>.stride
@@ -1221,8 +1297,8 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
}
// 10. O projection
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.attn, weights: oProj, output: temps.h)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.attn, weightsQ: oProj, weightsF: oProjFloat, output: temps.h)
// 11. Residual 1 (scaled by layerScalar)
if layerScalar != 1.0 {
@@ -1260,9 +1336,9 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
// 18. Per-layer gating (optional)
if let pg = perLayerGate, let pp = perLayerProjection, let pl = perLayerInput {
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weights: pg,
output: temps.gating)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weightsQ: pg, weightsF: perLayerGateFloat,
output: temps.gating)
try gelu(engine: engine, cmdBuf: cmdBuf,
input: temps.gating, output: temps.gating, count: 256)
@@ -1272,9 +1348,9 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
output: temps.gating, outputOffset: 0,
count: 256)
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.gating, weights: pp,
output: temps.h)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.gating, weightsQ: pp, weightsF: perLayerProjectionFloat,
output: temps.h)
if let ppn = postPerLayerInputNorm {
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
+26 -9
View File
@@ -43,9 +43,14 @@ extension E4BLayer {
// Note: Attention needs per-token KV cache updates, so we process sequentially
// But we can batch Q/K/V projections
guard let qp = qProj else {
throw NSError(domain: "LayerBatch", code: -3,
userInfo: [NSLocalizedDescriptionKey: "Quantized weights required for batch processing"])
}
try batchQuantizedMatmul(
batchInput: batchTemps.hBatch,
weights: qProj,
weights: qp,
batchOutput: batchTemps.qBatch,
batchSize: batchSize,
cmdBuf: cmdBuf,
@@ -91,9 +96,11 @@ extension E4BLayer {
options: .storageModeShared
)!
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf, input: hToken, weights: kProj, output: temps.k)
if let vp = vProj {
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf, input: hToken, weights: vp, output: temps.v)
try matmulAny(engine: engine, cmdBuf: cmdBuf, input: hToken, weightsQ: kProj, weightsF: kProjFloat, output: temps.k)
if let vp = vProj, let vpF = vProjFloat {
if vp != nil || vpF != nil {
try matmulAny(engine: engine, cmdBuf: cmdBuf, input: hToken, weightsQ: vp, weightsF: vpF, output: temps.v)
}
}
// K/V norms
@@ -129,8 +136,8 @@ extension E4BLayer {
}
}
// O projection (write back to batch buffer)
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf, input: temps.attn, weights: oProj, output: temps.h)
// O projection (write back to batch buffer)
try matmulAny(engine: engine, cmdBuf: cmdBuf, input: temps.attn, weightsQ: oProj, weightsF: oProjFloat, output: temps.h)
// Copy to batch position
let batchOffset = i * config.hiddenSize * 4
@@ -173,10 +180,15 @@ extension E4BLayer {
)
// Batch FFN: Gate + Up (fused)
guard let gp = gateProj, let up = upProj else {
throw NSError(domain: "LayerBatch", code: -4,
userInfo: [NSLocalizedDescriptionKey: "Quantized weights required for batch FFN"])
}
try batchFusedGateUp(
batchInput: batchTemps.nsBatch,
gateWeights: gateProj,
upWeights: upProj,
gateWeights: gp,
upWeights: up,
batchOutput: batchTemps.interBatch,
batchSize: batchSize,
cmdBuf: cmdBuf,
@@ -184,9 +196,14 @@ extension E4BLayer {
)
// Batch Down projection
guard let dp = downProj else {
throw NSError(domain: "LayerBatch", code: -5,
userInfo: [NSLocalizedDescriptionKey: "Quantized weights required for batch down projection"])
}
try batchDownProjection(
batchInter: batchTemps.interBatch,
downWeights: downProj,
downWeights: dp,
batchOutput: batchTemps.hBatch,
batchSize: batchSize,
cmdBuf: cmdBuf,
+20 -18
View File
@@ -48,10 +48,10 @@ extension E4BLayer {
temps: temps, engine: engine, cmdBuf: cmdBuf)
// FFN: gate+up fused down residual
try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
input: temps.ns, output: temps.gate)
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.gate, weights: downProj, output: temps.h)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.gate, weightsQ: downProj, weightsF: downProjFloat, output: temps.h)
try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
a: input, b: temps.h,
output: input, count: config.hiddenSize)
@@ -87,8 +87,8 @@ extension E4BLayer {
output: temps.attnH, count: config.hiddenSize, eps: rmsNormEps)
// 2. Q = q_proj(temps.attnH) temps.q
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.attnH, weights: qProj, output: temps.q)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.attnH, weightsQ: qProj, weightsF: qProjFloat, output: temps.q)
// 3. Q = q_norm(Q) ns (per-head RMSNorm)
try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
@@ -102,11 +102,13 @@ extension E4BLayer {
q: temps.ns, position: position)
// 5. K,V projections
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.attnH, weights: kProj, output: temps.k)
if let vp = vProj {
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.attnH, weights: vp, output: temps.v)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.attnH, weightsQ: kProj, weightsF: kProjFloat, output: temps.k)
if let vp = vProj, let vpF = vProjFloat {
if vp != nil || vpF != nil {
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.attnH, weightsQ: vp, weightsF: vpF, output: temps.v)
}
} else if kEqualsV {
let blit = cmdBuf.makeBlitCommandEncoder()!
let copyBytes = config.nKvHeads * config.headDim * MemoryLayout<Float>.stride
@@ -168,8 +170,8 @@ extension E4BLayer {
}
// 10. O projection
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.attn, weights: oProj, output: temps.attnH)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.attn, weightsQ: oProj, weightsF: oProjFloat, output: temps.attnH)
// 11. Residual 1
try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
@@ -210,9 +212,9 @@ extension E4BLayer {
// 18. Per-layer gating (optional)
if let pg = perLayerGate, let pp = perLayerProjection, let pl = perLayerInput {
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weights: pg,
output: temps.gating)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.h, weightsQ: pg, weightsF: perLayerGateFloat,
output: temps.gating)
try gelu(engine: engine, cmdBuf: cmdBuf,
input: temps.gating, output: temps.gating, count: 256)
@@ -222,9 +224,9 @@ extension E4BLayer {
output: temps.gating, outputOffset: 0,
count: 256)
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
input: temps.gating, weights: pp,
output: temps.h)
try matmulAny(engine: engine, cmdBuf: cmdBuf,
input: temps.gating, weightsQ: pp, weightsF: perLayerProjectionFloat,
output: temps.h)
if let ppn = postPerLayerInputNorm {
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
+161 -12
View File
@@ -657,6 +657,28 @@ readers = readersDict
index: index, readers: readers,
device: engine.device, bits: bits)
}
func fw(_ name: String) throws -> FloatWeights? {
let fullName = "\(prefix).\(name)"
let wName = "\(fullName).weight"
// Check if weight is in preloaded cache
if let wData = preloadedDataCache[wName] {
let wDesc = allTensors.first(where: { $0.name == wName })
if let desc = wDesc, desc.dtype == .bf16 {
let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
let outDim = desc.shape[0]
let inDim = desc.shape[1]
if let wBuf = engine.device.makeBuffer(
bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
options: .storageModeShared
) {
return FloatWeights(weight: wBuf, inDim: inDim, outDim: outDim)
}
}
}
return nil
}
/// Infer quantization bits from weight tensor shape vs expected input dimension.
/// Returns 4 or 8, defaulting to `defaultBits` if neither matches.
@@ -698,12 +720,23 @@ readers = readersDict
let mlpGateBits = detectBits(for: "mlp.gate_proj", expectedInDim: hiddenSize, defaultBits: 4)
let mlpDownBits = detectBits(for: "mlp.down_proj", expectedInDim: intermediate, defaultBits: 4)
// Check attention projections (required for all layers)
guard let qp = try qwFromCache("self_attn.q_proj"),
let kp = try qwFromCache("self_attn.k_proj"),
let op = try qwFromCache("self_attn.o_proj")
// Try bf16 weights first (for bf16 models)
let qpFloat = try fw("self_attn.q_proj")
let kpFloat = try fw("self_attn.k_proj")
let vpFloat = try fw("self_attn.v_proj")
let opFloat = try fw("self_attn.o_proj")
// Then try quantized weights (for quantized models)
let qpQuant = try qwFromCache("self_attn.q_proj", bits: attnQBits)
let kpQuant = try qwFromCache("self_attn.k_proj", bits: attnKBits)
let vpQuant = try qwFromCache("self_attn.v_proj", bits: attnVBits)
let opQuant = try qwFromCache("self_attn.o_proj", bits: attnOBits)
guard qpQuant != nil || qpFloat != nil,
kpQuant != nil || kpFloat != nil,
opQuant != nil || opFloat != nil
else {
throw WeightError.tensorNotFound("Missing quantized weight for layer \(layerIdx)")
throw WeightError.tensorNotFound("Missing weights for layer \(layerIdx)")
}
// MoE loading (auto-detect from tensor structure)
@@ -725,6 +758,9 @@ readers = readersDict
var gp = try qwFromCache("mlp.gate_proj", bits: mlpGateBits)
var up = try qwFromCache("mlp.up_proj", bits: mlpGateBits)
var dp = try qwFromCache("mlp.down_proj", bits: mlpDownBits)
var gpFloat = try fw("mlp.gate_proj")
var upFloat = try fw("mlp.up_proj")
var dpFloat = try fw("mlp.down_proj")
// If MLP weights missing and this is MoE layer, create dummy weights
if useMoE && numExperts > 0 {
@@ -743,9 +779,9 @@ readers = readersDict
if up == nil { up = dummyQuantizedWeights }
if dp == nil { dp = dummyQuantizedWeights }
}
} else if gp == nil || up == nil || dp == nil {
// Dense layer requires MLP weights
throw WeightError.tensorNotFound("Missing quantized weight for layer \(layerIdx)")
} else if (gp == nil || up == nil || dp == nil) && (gpFloat == nil || upFloat == nil || dpFloat == nil) {
// Dense layer requires either quantized or bf16 MLP weights
throw WeightError.tensorNotFound("Missing MLP weights for layer \(layerIdx)")
}
// v_proj is optional - full attention layers in 12B don't have it
@@ -838,9 +874,13 @@ readers = readersDict
qNorm: try normStrided("self_attn.q_norm.weight", nHeads: lcfg.nHeads, hd: hd),
kNorm: try normStrided("self_attn.k_norm.weight", nHeads: lcfg.nKvHeads, hd: hd),
vNorm: try normStrided("self_attn.v_norm.weight", nHeads: lcfg.nKvHeads, hd: hd),
qProj: qp, kProj: kp, vProj: vp, oProj: op,
gateProj: gp!, upProj: up!, downProj: dp!, // Force unwrap (guaranteed to have value after dummy creation)
qProj: qpQuant, kProj: kpQuant, vProj: vpQuant, oProj: opQuant,
gateProj: gp, upProj: up, downProj: dp,
perLayerGate: pg, perLayerProjection: pp,
qProjFloat: qpFloat, kProjFloat: kpFloat, vProjFloat: vpFloat, oProjFloat: opFloat,
gateProjFloat: gpFloat, upProjFloat: upFloat, downProjFloat: dpFloat,
perLayerGateFloat: try fw("per_layer_input_gate"),
perLayerProjectionFloat: try fw("per_layer_projection"),
perLayerInput: plSlice,
perLayerInputScale: perLayerInputScaleVal,
perLayerProjectionScale: perLayerModelProjectionScaleVal,
@@ -853,8 +893,7 @@ readers = readersDict
expertUp: expertUp,
expertDown: expertDown,
topK: topK,
// For models without v_proj on full attention layers, use k_eq_v=true
kEqualsV: (vp == nil && isFull) || (cfg.attentionKEqualsV ?? false)
kEqualsV: (vpQuant == nil && vpFloat == nil && isFull) || (cfg.attentionKEqualsV ?? false)
)
builtLayers.append(layer)
}
@@ -1214,6 +1253,116 @@ readers = readersDict
inDim: inDim, outDim: outDim, bits: bits, groupSize: groupSize)
}
/// Load non-quantized bf16 embedding weights as FloatWeights
private static func loadFloatEmbed(named: String, from tensors: [TensorDescriptor],
index: SafeTensorsIndex?,
readers: [String: SafeTensorsReader],
device: MTLDevice,
hiddenSize: Int) throws -> FloatWeights? {
let tensorMap = Dictionary(uniqueKeysWithValues: tensors.map { ($0.name, $0) })
let prefix = "language_model.model."
let modelPrefix = "model.language_model.model."
let modelPrefixShort = "model.language_model."
let tensorMapWithPrefix = tensors.reduce(into: [String: TensorDescriptor]()) { dict, desc in
dict[desc.name] = desc
if desc.name.hasPrefix(prefix) {
dict[String(desc.name.dropFirst(prefix.count))] = desc
}
if desc.name.hasPrefix(modelPrefix) {
dict[String(desc.name.dropFirst(modelPrefix.count))] = desc
}
if desc.name.hasPrefix(modelPrefixShort) {
dict[String(desc.name.dropFirst(modelPrefixShort.count))] = desc
}
}
func findTensor(_ name: String) -> TensorDescriptor? {
if let desc = tensorMapWithPrefix[name] { return desc }
return tensorMap[name]
}
let wName = "\(named).weight"
guard let wDesc = findTensor(wName) else {
return nil
}
if wDesc.dtype != .bf16 {
return nil
}
let wReader: SafeTensorsReader
if let idx = index {
let actualWName = wDesc.name
guard let wShard = idx.weightMap[actualWName] else { return nil }
wReader = readers[wShard]!
} else {
wReader = readers["model.safetensors"]!
}
let wData = try wReader.read(tensor: wDesc)
let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
let outDim = wDesc.shape[0]
let inDim = wDesc.shape[1]
guard let wBuf = device.makeBuffer(
bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
options: .storageModeShared
) else { return nil }
return FloatWeights(weight: wBuf, inDim: inDim, outDim: outDim)
}
/// Load non-quantized bf16 layer weights as FloatWeights
private static func loadFloatWeight(named: String, from tensors: [TensorDescriptor],
index: SafeTensorsIndex?,
readers: [String: SafeTensorsReader],
device: MTLDevice) throws -> FloatWeights? {
let tensorMap = Dictionary(uniqueKeysWithValues: tensors.map { ($0.name, $0) })
let prefix = "language_model.model."
let modelPrefix = "model.language_model."
let tensorMapWithPrefix = tensors.reduce(into: [String: TensorDescriptor]()) { dict, desc in
dict[desc.name] = desc
if desc.name.hasPrefix(prefix) {
dict[String(desc.name.dropFirst(prefix.count))] = desc
}
if desc.name.hasPrefix(modelPrefix) {
dict[String(desc.name.dropFirst(modelPrefix.count))] = desc
}
}
func findTensor(_ name: String) -> TensorDescriptor? {
if let desc = tensorMapWithPrefix[name] { return desc }
return tensorMap[name]
}
let wName = "\(named).weight"
guard let wDesc = findTensor(wName) else { return nil }
if wDesc.dtype != .bf16 {
return nil
}
let wReader: SafeTensorsReader
if let idx = index {
let actualWName = wDesc.name
guard let wShard = idx.weightMap[actualWName] else { return nil }
wReader = readers[wShard]!
} else {
wReader = readers["model.safetensors"]!
}
let wData = try wReader.read(tensor: wDesc)
let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
let outDim = wDesc.shape[0]
let inDim = wDesc.shape[1]
guard let wBuf = device.makeBuffer(
bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
options: .storageModeShared
) else { return nil }
return FloatWeights(weight: wBuf, inDim: inDim, outDim: outDim)
}
/// Load a 3D expert tensor [numExperts, expertOutDim, inDimPacked] as a contiguous MoEExpertGroup.
/// The data layout is: expert0[outDim, inDimPacked], expert1[outDim, inDimPacked], ...
/// Per-expert access is done via byte offsets into the shared buffers.