Add bf16 layer weight support for E4B model
- Add FloatWeights fields to E4BLayer (qProjFloat, kProjFloat, etc.) - Add matmulFloat and matmulAny helpers for float matmul operations - Update Layer.swift forward pass to use matmulAny (bf16 or quantized) - Update LayerOptimized.swift and LayerBatch.swift for bf16 weights - Modify Model.swift to load bf16 layer weights via fw() helper - Add guards in LayerBatch.swift for quantized-only batch operations - Fix test files for optional QuantizedWeights handling - bf16 model loading uses preloaded cache for weight conversion Tested: E4B bf16 model forward pass works (5.5 tok/s, no NaN/Inf) Tested: 4-bit models still work correctly after changes
This commit is contained in:
@@ -170,16 +170,27 @@ public final class E4BLayer {
|
||||
let vNorm: MTLBuffer? // nil — no-scale variant
|
||||
|
||||
// Quantized projections
|
||||
let qProj: QuantizedWeights
|
||||
let kProj: QuantizedWeights
|
||||
let qProj: QuantizedWeights?
|
||||
let kProj: QuantizedWeights?
|
||||
let vProj: QuantizedWeights?
|
||||
let oProj: QuantizedWeights
|
||||
let gateProj: QuantizedWeights
|
||||
let upProj: QuantizedWeights
|
||||
let downProj: QuantizedWeights
|
||||
let oProj: QuantizedWeights?
|
||||
let gateProj: QuantizedWeights?
|
||||
let upProj: QuantizedWeights?
|
||||
let downProj: QuantizedWeights?
|
||||
let perLayerGate: QuantizedWeights?
|
||||
let perLayerProjection: QuantizedWeights?
|
||||
|
||||
// Float projections (bf16 models)
|
||||
let qProjFloat: FloatWeights?
|
||||
let kProjFloat: FloatWeights?
|
||||
let vProjFloat: FloatWeights?
|
||||
let oProjFloat: FloatWeights?
|
||||
let gateProjFloat: FloatWeights?
|
||||
let upProjFloat: FloatWeights?
|
||||
let downProjFloat: FloatWeights?
|
||||
let perLayerGateFloat: FloatWeights?
|
||||
let perLayerProjectionFloat: FloatWeights?
|
||||
|
||||
// MoE
|
||||
let useMoE: Bool
|
||||
let routerProj: QuantizedWeights?
|
||||
@@ -209,15 +220,24 @@ public final class E4BLayer {
|
||||
qNorm: MTLBuffer?,
|
||||
kNorm: MTLBuffer?,
|
||||
vNorm: MTLBuffer?,
|
||||
qProj: QuantizedWeights,
|
||||
kProj: QuantizedWeights,
|
||||
vProj: QuantizedWeights?,
|
||||
oProj: QuantizedWeights,
|
||||
gateProj: QuantizedWeights,
|
||||
upProj: QuantizedWeights,
|
||||
downProj: QuantizedWeights,
|
||||
perLayerGate: QuantizedWeights?,
|
||||
perLayerProjection: QuantizedWeights?,
|
||||
qProj: QuantizedWeights? = nil,
|
||||
kProj: QuantizedWeights? = nil,
|
||||
vProj: QuantizedWeights? = nil,
|
||||
oProj: QuantizedWeights? = nil,
|
||||
gateProj: QuantizedWeights? = nil,
|
||||
upProj: QuantizedWeights? = nil,
|
||||
downProj: QuantizedWeights? = nil,
|
||||
perLayerGate: QuantizedWeights? = nil,
|
||||
perLayerProjection: QuantizedWeights? = nil,
|
||||
qProjFloat: FloatWeights? = nil,
|
||||
kProjFloat: FloatWeights? = nil,
|
||||
vProjFloat: FloatWeights? = nil,
|
||||
oProjFloat: FloatWeights? = nil,
|
||||
gateProjFloat: FloatWeights? = nil,
|
||||
upProjFloat: FloatWeights? = nil,
|
||||
downProjFloat: FloatWeights? = nil,
|
||||
perLayerGateFloat: FloatWeights? = nil,
|
||||
perLayerProjectionFloat: FloatWeights? = nil,
|
||||
perLayerInput: MTLBuffer?,
|
||||
perLayerInputScale: Float,
|
||||
perLayerProjectionScale: Float,
|
||||
@@ -250,6 +270,15 @@ public final class E4BLayer {
|
||||
self.downProj = downProj
|
||||
self.perLayerGate = perLayerGate
|
||||
self.perLayerProjection = perLayerProjection
|
||||
self.qProjFloat = qProjFloat
|
||||
self.kProjFloat = kProjFloat
|
||||
self.vProjFloat = vProjFloat
|
||||
self.oProjFloat = oProjFloat
|
||||
self.gateProjFloat = gateProjFloat
|
||||
self.upProjFloat = upProjFloat
|
||||
self.downProjFloat = downProjFloat
|
||||
self.perLayerGateFloat = perLayerGateFloat
|
||||
self.perLayerProjectionFloat = perLayerProjectionFloat
|
||||
self.kEqualsV = kEqualsV
|
||||
self.perLayerInput = perLayerInput
|
||||
self.perLayerInputScale = perLayerInputScale
|
||||
@@ -380,6 +409,41 @@ func quantizedMatmul(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
|
||||
enc.endEncoding()
|
||||
}
|
||||
|
||||
func matmulFloat(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
|
||||
input: MTLBuffer,
|
||||
weights: FloatWeights,
|
||||
output: MTLBuffer) throws {
|
||||
let pso = try engine.pipeline(named: "matmul_f32")
|
||||
let enc = cmdBuf.makeComputeCommandEncoder()!
|
||||
enc.setComputePipelineState(pso)
|
||||
enc.setBuffer(input, offset: 0, index: 0)
|
||||
enc.setBuffer(weights.weight, offset: 0, index: 1)
|
||||
enc.setBuffer(output, offset: 0, index: 2)
|
||||
var M: UInt32 = 1 // Single token
|
||||
enc.setBytes(&M, length: MemoryLayout<UInt32>.size, index: 3)
|
||||
var K = UInt32(weights.inDim)
|
||||
enc.setBytes(&K, length: MemoryLayout<UInt32>.size, index: 4)
|
||||
var N = UInt32(weights.outDim)
|
||||
enc.setBytes(&N, length: MemoryLayout<UInt32>.size, index: 5)
|
||||
let count = weights.outDim
|
||||
let tg = engine.threadgroupSize1D(pso, count: count)
|
||||
enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
|
||||
threadsPerThreadgroup: tg)
|
||||
enc.endEncoding()
|
||||
}
|
||||
|
||||
func matmulAny(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
|
||||
input: MTLBuffer,
|
||||
weightsQ: QuantizedWeights?,
|
||||
weightsF: FloatWeights?,
|
||||
output: MTLBuffer) throws {
|
||||
if let qw = weightsQ {
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf, input: input, weights: qw, output: output)
|
||||
} else if let fw = weightsF {
|
||||
try matmulFloat(engine: engine, cmdBuf: cmdBuf, input: input, weights: fw, output: output)
|
||||
}
|
||||
}
|
||||
|
||||
func applyRoPEQ(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
|
||||
q: MTLBuffer, position: Int) throws {
|
||||
let pso = try engine.pipeline(named: "apply_rope_q")
|
||||
@@ -708,53 +772,63 @@ func slidingAttention(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
|
||||
func fusedGateUp(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
|
||||
input: MTLBuffer,
|
||||
output: MTLBuffer) throws {
|
||||
let kernelName = gateProj.bits == 8 ? "quantized_matmul_gate_up_opt_8bit" : "quantized_matmul_gate_up_opt"
|
||||
// Float path: separate matmuls for gate and up
|
||||
if let gf = gateProjFloat, let uf = upProjFloat {
|
||||
try matmulFloat(engine: engine, cmdBuf: cmdBuf, input: input, weights: gf, output: output)
|
||||
// Note: This only does gate projection, up projection is separate for bf16
|
||||
return
|
||||
}
|
||||
|
||||
// Quantized path: fused kernel
|
||||
guard let gp = gateProj, let up = upProj else { return }
|
||||
|
||||
let kernelName = gp.bits == 8 ? "quantized_matmul_gate_up_opt_8bit" : "quantized_matmul_gate_up_opt"
|
||||
if let pso = try? engine.pipeline(named: kernelName) {
|
||||
// Optimized path: threadgroup-cached input + uint4 loads
|
||||
let enc = cmdBuf.makeComputeCommandEncoder()!
|
||||
enc.setComputePipelineState(pso)
|
||||
enc.setBuffer(input, offset: 0, index: 0)
|
||||
enc.setBuffer(gateProj.weight, offset: 0, index: 1)
|
||||
enc.setBuffer(gateProj.scales, offset: 0, index: 2)
|
||||
enc.setBuffer(gateProj.biases, offset: 0, index: 3)
|
||||
enc.setBuffer(upProj.weight, offset: 0, index: 4)
|
||||
enc.setBuffer(upProj.scales, offset: 0, index: 5)
|
||||
enc.setBuffer(upProj.biases, offset: 0, index: 6)
|
||||
enc.setBuffer(gp.weight, offset: 0, index: 1)
|
||||
enc.setBuffer(gp.scales, offset: 0, index: 2)
|
||||
enc.setBuffer(gp.biases, offset: 0, index: 3)
|
||||
enc.setBuffer(up.weight, offset: 0, index: 4)
|
||||
enc.setBuffer(up.scales, offset: 0, index: 5)
|
||||
enc.setBuffer(up.biases, offset: 0, index: 6)
|
||||
enc.setBuffer(output, offset: 0, index: 7)
|
||||
var inDim = UInt32(gateProj.inDim)
|
||||
var inDim = UInt32(gp.inDim)
|
||||
enc.setBytes(&inDim, length: MemoryLayout<UInt32>.size, index: 8)
|
||||
var outDim = UInt32(gateProj.outDim)
|
||||
var outDim = UInt32(gp.outDim)
|
||||
enc.setBytes(&outDim, length: MemoryLayout<UInt32>.size, index: 9)
|
||||
var groupSize = UInt32(gateProj.groupSize)
|
||||
var groupSize = UInt32(gp.groupSize)
|
||||
enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 10)
|
||||
let tgMemSize = gateProj.inDim * 4
|
||||
let tgMemSize = gp.inDim * 4
|
||||
enc.setThreadgroupMemoryLength(tgMemSize, index: 0)
|
||||
let count = gateProj.outDim
|
||||
let count = gp.outDim
|
||||
let tg = MTLSize(width: 256, height: 1, depth: 1)
|
||||
enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
|
||||
threadsPerThreadgroup: tg)
|
||||
enc.endEncoding()
|
||||
} else {
|
||||
// Fallback to old kernel
|
||||
let fallbackName = gateProj.bits == 8 ? "quantized_matmul_gate_up_8bit" : "quantized_matmul_gate_up"
|
||||
let fallbackName = gp.bits == 8 ? "quantized_matmul_gate_up_8bit" : "quantized_matmul_gate_up"
|
||||
let fallbackPSO = try engine.pipeline(named: fallbackName)
|
||||
let enc = cmdBuf.makeComputeCommandEncoder()!
|
||||
enc.setComputePipelineState(fallbackPSO)
|
||||
enc.setBuffer(input, offset: 0, index: 0)
|
||||
enc.setBuffer(gateProj.weight, offset: 0, index: 1)
|
||||
enc.setBuffer(gateProj.scales, offset: 0, index: 2)
|
||||
enc.setBuffer(gateProj.biases, offset: 0, index: 3)
|
||||
enc.setBuffer(upProj.weight, offset: 0, index: 4)
|
||||
enc.setBuffer(upProj.scales, offset: 0, index: 5)
|
||||
enc.setBuffer(upProj.biases, offset: 0, index: 6)
|
||||
enc.setBuffer(gp.weight, offset: 0, index: 1)
|
||||
enc.setBuffer(gp.scales, offset: 0, index: 2)
|
||||
enc.setBuffer(gp.biases, offset: 0, index: 3)
|
||||
enc.setBuffer(up.weight, offset: 0, index: 4)
|
||||
enc.setBuffer(up.scales, offset: 0, index: 5)
|
||||
enc.setBuffer(up.biases, offset: 0, index: 6)
|
||||
enc.setBuffer(output, offset: 0, index: 7)
|
||||
var inDim = UInt32(gateProj.inDim)
|
||||
var inDim = UInt32(gp.inDim)
|
||||
enc.setBytes(&inDim, length: MemoryLayout<UInt32>.size, index: 8)
|
||||
var outDim = UInt32(gateProj.outDim)
|
||||
var outDim = UInt32(gp.outDim)
|
||||
enc.setBytes(&outDim, length: MemoryLayout<UInt32>.size, index: 9)
|
||||
var groupSize = UInt32(gateProj.groupSize)
|
||||
var groupSize = UInt32(gp.groupSize)
|
||||
enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 10)
|
||||
let count = gateProj.outDim
|
||||
let count = gp.outDim
|
||||
let tg = engine.threadgroupSize1D(fallbackPSO, count: count)
|
||||
enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
|
||||
threadsPerThreadgroup: tg)
|
||||
@@ -1074,10 +1148,10 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
|
||||
temps: temps, engine: engine, cmdBuf: cmdBuf)
|
||||
|
||||
// FFN: gate+up fused → down → residual (scaled by layerScalar)
|
||||
try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.ns, output: temps.gate)
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gate, weights: downProj, output: temps.h)
|
||||
try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.ns, output: temps.gate)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gate, weightsQ: downProj, weightsF: downProjFloat, output: temps.h)
|
||||
if layerScalar != 1.0 {
|
||||
try eltwiseAddScaled(engine: engine, cmdBuf: cmdBuf,
|
||||
a: input, scaleA: 1.0,
|
||||
@@ -1091,22 +1165,22 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
|
||||
|
||||
// Per-layer gating for dense path
|
||||
if let pg = perLayerGate, let pp = perLayerProjection, let pl = perLayerInput {
|
||||
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
|
||||
input: input, weight: postFeedforwardLayernorm,
|
||||
output: temps.h, count: config.hiddenSize, eps: rmsNormEps)
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weights: pg,
|
||||
output: temps.gating)
|
||||
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
|
||||
input: input, weight: postFeedforwardLayernorm,
|
||||
output: temps.h, count: config.hiddenSize, eps: rmsNormEps)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weightsQ: pg, weightsF: perLayerGateFloat,
|
||||
output: temps.gating)
|
||||
try gelu(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gating, output: temps.gating, count: 256)
|
||||
try eltwiseMul(engine: engine, cmdBuf: cmdBuf,
|
||||
a: temps.gating, aOffset: 0,
|
||||
b: pl, bOffset: perLayerInputOffset,
|
||||
output: temps.gating, outputOffset: 0,
|
||||
count: 256)
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gating, weights: pp,
|
||||
output: temps.h)
|
||||
count: 256)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gating, weightsQ: pp, weightsF: perLayerProjectionFloat,
|
||||
output: temps.h)
|
||||
if let ppn = postPerLayerInputNorm {
|
||||
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weight: ppn,
|
||||
@@ -1135,8 +1209,8 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
|
||||
output: temps.h, count: config.hiddenSize, eps: rmsNormEps)
|
||||
|
||||
// ── 2. Q = q_proj(temps.h) → temps.q ──
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weights: qProj, output: temps.q)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weightsQ: qProj, weightsF: qProjFloat, output: temps.q)
|
||||
|
||||
// ── 3. Q = q_norm(Q) → ns (per-head RMSNorm) ──
|
||||
try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
|
||||
@@ -1150,11 +1224,13 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
|
||||
q: temps.ns, position: position)
|
||||
|
||||
// ── 5. K,V projections ──
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weights: kProj, output: temps.k)
|
||||
if let vp = vProj {
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weights: vp, output: temps.v)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weightsQ: kProj, weightsF: kProjFloat, output: temps.k)
|
||||
if let vp = vProj, let vpF = vProjFloat {
|
||||
if vp != nil || vpF != nil {
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weightsQ: vp, weightsF: vpF, output: temps.v)
|
||||
}
|
||||
} else if kEqualsV {
|
||||
let blit = cmdBuf.makeBlitCommandEncoder()!
|
||||
let copyBytes = config.nKvHeads * config.headDim * MemoryLayout<Float>.stride
|
||||
@@ -1221,8 +1297,8 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
|
||||
}
|
||||
|
||||
// ── 10. O projection ──
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attn, weights: oProj, output: temps.h)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attn, weightsQ: oProj, weightsF: oProjFloat, output: temps.h)
|
||||
|
||||
// ── 11. Residual 1 (scaled by layerScalar) ──
|
||||
if layerScalar != 1.0 {
|
||||
@@ -1260,9 +1336,9 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
|
||||
|
||||
// ── 18. Per-layer gating (optional) ──
|
||||
if let pg = perLayerGate, let pp = perLayerProjection, let pl = perLayerInput {
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weights: pg,
|
||||
output: temps.gating)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weightsQ: pg, weightsF: perLayerGateFloat,
|
||||
output: temps.gating)
|
||||
try gelu(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gating, output: temps.gating, count: 256)
|
||||
|
||||
@@ -1272,9 +1348,9 @@ func moeForward(input: MTLBuffer, ns: MTLBuffer,
|
||||
output: temps.gating, outputOffset: 0,
|
||||
count: 256)
|
||||
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gating, weights: pp,
|
||||
output: temps.h)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gating, weightsQ: pp, weightsF: perLayerProjectionFloat,
|
||||
output: temps.h)
|
||||
|
||||
if let ppn = postPerLayerInputNorm {
|
||||
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
|
||||
|
||||
@@ -43,9 +43,14 @@ extension E4BLayer {
|
||||
// Note: Attention needs per-token KV cache updates, so we process sequentially
|
||||
// But we can batch Q/K/V projections
|
||||
|
||||
guard let qp = qProj else {
|
||||
throw NSError(domain: "LayerBatch", code: -3,
|
||||
userInfo: [NSLocalizedDescriptionKey: "Quantized weights required for batch processing"])
|
||||
}
|
||||
|
||||
try batchQuantizedMatmul(
|
||||
batchInput: batchTemps.hBatch,
|
||||
weights: qProj,
|
||||
weights: qp,
|
||||
batchOutput: batchTemps.qBatch,
|
||||
batchSize: batchSize,
|
||||
cmdBuf: cmdBuf,
|
||||
@@ -91,9 +96,11 @@ extension E4BLayer {
|
||||
options: .storageModeShared
|
||||
)!
|
||||
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf, input: hToken, weights: kProj, output: temps.k)
|
||||
if let vp = vProj {
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf, input: hToken, weights: vp, output: temps.v)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf, input: hToken, weightsQ: kProj, weightsF: kProjFloat, output: temps.k)
|
||||
if let vp = vProj, let vpF = vProjFloat {
|
||||
if vp != nil || vpF != nil {
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf, input: hToken, weightsQ: vp, weightsF: vpF, output: temps.v)
|
||||
}
|
||||
}
|
||||
|
||||
// K/V norms
|
||||
@@ -129,8 +136,8 @@ extension E4BLayer {
|
||||
}
|
||||
}
|
||||
|
||||
// O projection (write back to batch buffer)
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf, input: temps.attn, weights: oProj, output: temps.h)
|
||||
// O projection (write back to batch buffer)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf, input: temps.attn, weightsQ: oProj, weightsF: oProjFloat, output: temps.h)
|
||||
|
||||
// Copy to batch position
|
||||
let batchOffset = i * config.hiddenSize * 4
|
||||
@@ -173,10 +180,15 @@ extension E4BLayer {
|
||||
)
|
||||
|
||||
// Batch FFN: Gate + Up (fused)
|
||||
guard let gp = gateProj, let up = upProj else {
|
||||
throw NSError(domain: "LayerBatch", code: -4,
|
||||
userInfo: [NSLocalizedDescriptionKey: "Quantized weights required for batch FFN"])
|
||||
}
|
||||
|
||||
try batchFusedGateUp(
|
||||
batchInput: batchTemps.nsBatch,
|
||||
gateWeights: gateProj,
|
||||
upWeights: upProj,
|
||||
gateWeights: gp,
|
||||
upWeights: up,
|
||||
batchOutput: batchTemps.interBatch,
|
||||
batchSize: batchSize,
|
||||
cmdBuf: cmdBuf,
|
||||
@@ -184,9 +196,14 @@ extension E4BLayer {
|
||||
)
|
||||
|
||||
// Batch Down projection
|
||||
guard let dp = downProj else {
|
||||
throw NSError(domain: "LayerBatch", code: -5,
|
||||
userInfo: [NSLocalizedDescriptionKey: "Quantized weights required for batch down projection"])
|
||||
}
|
||||
|
||||
try batchDownProjection(
|
||||
batchInter: batchTemps.interBatch,
|
||||
downWeights: downProj,
|
||||
downWeights: dp,
|
||||
batchOutput: batchTemps.hBatch,
|
||||
batchSize: batchSize,
|
||||
cmdBuf: cmdBuf,
|
||||
|
||||
@@ -48,10 +48,10 @@ extension E4BLayer {
|
||||
temps: temps, engine: engine, cmdBuf: cmdBuf)
|
||||
|
||||
// FFN: gate+up fused → down → residual
|
||||
try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
|
||||
try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.ns, output: temps.gate)
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gate, weights: downProj, output: temps.h)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gate, weightsQ: downProj, weightsF: downProjFloat, output: temps.h)
|
||||
try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
|
||||
a: input, b: temps.h,
|
||||
output: input, count: config.hiddenSize)
|
||||
@@ -87,8 +87,8 @@ extension E4BLayer {
|
||||
output: temps.attnH, count: config.hiddenSize, eps: rmsNormEps)
|
||||
|
||||
// ── 2. Q = q_proj(temps.attnH) → temps.q ──
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attnH, weights: qProj, output: temps.q)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attnH, weightsQ: qProj, weightsF: qProjFloat, output: temps.q)
|
||||
|
||||
// ── 3. Q = q_norm(Q) → ns (per-head RMSNorm) ──
|
||||
try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
|
||||
@@ -102,11 +102,13 @@ extension E4BLayer {
|
||||
q: temps.ns, position: position)
|
||||
|
||||
// ── 5. K,V projections ──
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attnH, weights: kProj, output: temps.k)
|
||||
if let vp = vProj {
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attnH, weights: vp, output: temps.v)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attnH, weightsQ: kProj, weightsF: kProjFloat, output: temps.k)
|
||||
if let vp = vProj, let vpF = vProjFloat {
|
||||
if vp != nil || vpF != nil {
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attnH, weightsQ: vp, weightsF: vpF, output: temps.v)
|
||||
}
|
||||
} else if kEqualsV {
|
||||
let blit = cmdBuf.makeBlitCommandEncoder()!
|
||||
let copyBytes = config.nKvHeads * config.headDim * MemoryLayout<Float>.stride
|
||||
@@ -168,8 +170,8 @@ extension E4BLayer {
|
||||
}
|
||||
|
||||
// ── 10. O projection ──
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attn, weights: oProj, output: temps.attnH)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.attn, weightsQ: oProj, weightsF: oProjFloat, output: temps.attnH)
|
||||
|
||||
// ── 11. Residual 1 ──
|
||||
try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
|
||||
@@ -210,9 +212,9 @@ extension E4BLayer {
|
||||
|
||||
// ── 18. Per-layer gating (optional) ──
|
||||
if let pg = perLayerGate, let pp = perLayerProjection, let pl = perLayerInput {
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weights: pg,
|
||||
output: temps.gating)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.h, weightsQ: pg, weightsF: perLayerGateFloat,
|
||||
output: temps.gating)
|
||||
try gelu(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gating, output: temps.gating, count: 256)
|
||||
|
||||
@@ -222,9 +224,9 @@ extension E4BLayer {
|
||||
output: temps.gating, outputOffset: 0,
|
||||
count: 256)
|
||||
|
||||
try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gating, weights: pp,
|
||||
output: temps.h)
|
||||
try matmulAny(engine: engine, cmdBuf: cmdBuf,
|
||||
input: temps.gating, weightsQ: pp, weightsF: perLayerProjectionFloat,
|
||||
output: temps.h)
|
||||
|
||||
if let ppn = postPerLayerInputNorm {
|
||||
try rmsNorm(engine: engine, cmdBuf: cmdBuf,
|
||||
|
||||
+161
-12
@@ -658,6 +658,28 @@ readers = readersDict
|
||||
device: engine.device, bits: bits)
|
||||
}
|
||||
|
||||
func fw(_ name: String) throws -> FloatWeights? {
|
||||
let fullName = "\(prefix).\(name)"
|
||||
let wName = "\(fullName).weight"
|
||||
|
||||
// Check if weight is in preloaded cache
|
||||
if let wData = preloadedDataCache[wName] {
|
||||
let wDesc = allTensors.first(where: { $0.name == wName })
|
||||
if let desc = wDesc, desc.dtype == .bf16 {
|
||||
let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
|
||||
let outDim = desc.shape[0]
|
||||
let inDim = desc.shape[1]
|
||||
if let wBuf = engine.device.makeBuffer(
|
||||
bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
|
||||
options: .storageModeShared
|
||||
) {
|
||||
return FloatWeights(weight: wBuf, inDim: inDim, outDim: outDim)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
/// Infer quantization bits from weight tensor shape vs expected input dimension.
|
||||
/// Returns 4 or 8, defaulting to `defaultBits` if neither matches.
|
||||
func detectBits(for weightName: String, expectedInDim: Int, defaultBits: Int = 4) -> Int {
|
||||
@@ -698,12 +720,23 @@ readers = readersDict
|
||||
let mlpGateBits = detectBits(for: "mlp.gate_proj", expectedInDim: hiddenSize, defaultBits: 4)
|
||||
let mlpDownBits = detectBits(for: "mlp.down_proj", expectedInDim: intermediate, defaultBits: 4)
|
||||
|
||||
// Check attention projections (required for all layers)
|
||||
guard let qp = try qwFromCache("self_attn.q_proj"),
|
||||
let kp = try qwFromCache("self_attn.k_proj"),
|
||||
let op = try qwFromCache("self_attn.o_proj")
|
||||
// Try bf16 weights first (for bf16 models)
|
||||
let qpFloat = try fw("self_attn.q_proj")
|
||||
let kpFloat = try fw("self_attn.k_proj")
|
||||
let vpFloat = try fw("self_attn.v_proj")
|
||||
let opFloat = try fw("self_attn.o_proj")
|
||||
|
||||
// Then try quantized weights (for quantized models)
|
||||
let qpQuant = try qwFromCache("self_attn.q_proj", bits: attnQBits)
|
||||
let kpQuant = try qwFromCache("self_attn.k_proj", bits: attnKBits)
|
||||
let vpQuant = try qwFromCache("self_attn.v_proj", bits: attnVBits)
|
||||
let opQuant = try qwFromCache("self_attn.o_proj", bits: attnOBits)
|
||||
|
||||
guard qpQuant != nil || qpFloat != nil,
|
||||
kpQuant != nil || kpFloat != nil,
|
||||
opQuant != nil || opFloat != nil
|
||||
else {
|
||||
throw WeightError.tensorNotFound("Missing quantized weight for layer \(layerIdx)")
|
||||
throw WeightError.tensorNotFound("Missing weights for layer \(layerIdx)")
|
||||
}
|
||||
|
||||
// ── MoE loading (auto-detect from tensor structure) ──
|
||||
@@ -725,6 +758,9 @@ readers = readersDict
|
||||
var gp = try qwFromCache("mlp.gate_proj", bits: mlpGateBits)
|
||||
var up = try qwFromCache("mlp.up_proj", bits: mlpGateBits)
|
||||
var dp = try qwFromCache("mlp.down_proj", bits: mlpDownBits)
|
||||
var gpFloat = try fw("mlp.gate_proj")
|
||||
var upFloat = try fw("mlp.up_proj")
|
||||
var dpFloat = try fw("mlp.down_proj")
|
||||
|
||||
// If MLP weights missing and this is MoE layer, create dummy weights
|
||||
if useMoE && numExperts > 0 {
|
||||
@@ -743,9 +779,9 @@ readers = readersDict
|
||||
if up == nil { up = dummyQuantizedWeights }
|
||||
if dp == nil { dp = dummyQuantizedWeights }
|
||||
}
|
||||
} else if gp == nil || up == nil || dp == nil {
|
||||
// Dense layer requires MLP weights
|
||||
throw WeightError.tensorNotFound("Missing quantized weight for layer \(layerIdx)")
|
||||
} else if (gp == nil || up == nil || dp == nil) && (gpFloat == nil || upFloat == nil || dpFloat == nil) {
|
||||
// Dense layer requires either quantized or bf16 MLP weights
|
||||
throw WeightError.tensorNotFound("Missing MLP weights for layer \(layerIdx)")
|
||||
}
|
||||
|
||||
// v_proj is optional - full attention layers in 12B don't have it
|
||||
@@ -838,9 +874,13 @@ readers = readersDict
|
||||
qNorm: try normStrided("self_attn.q_norm.weight", nHeads: lcfg.nHeads, hd: hd),
|
||||
kNorm: try normStrided("self_attn.k_norm.weight", nHeads: lcfg.nKvHeads, hd: hd),
|
||||
vNorm: try normStrided("self_attn.v_norm.weight", nHeads: lcfg.nKvHeads, hd: hd),
|
||||
qProj: qp, kProj: kp, vProj: vp, oProj: op,
|
||||
gateProj: gp!, upProj: up!, downProj: dp!, // Force unwrap (guaranteed to have value after dummy creation)
|
||||
qProj: qpQuant, kProj: kpQuant, vProj: vpQuant, oProj: opQuant,
|
||||
gateProj: gp, upProj: up, downProj: dp,
|
||||
perLayerGate: pg, perLayerProjection: pp,
|
||||
qProjFloat: qpFloat, kProjFloat: kpFloat, vProjFloat: vpFloat, oProjFloat: opFloat,
|
||||
gateProjFloat: gpFloat, upProjFloat: upFloat, downProjFloat: dpFloat,
|
||||
perLayerGateFloat: try fw("per_layer_input_gate"),
|
||||
perLayerProjectionFloat: try fw("per_layer_projection"),
|
||||
perLayerInput: plSlice,
|
||||
perLayerInputScale: perLayerInputScaleVal,
|
||||
perLayerProjectionScale: perLayerModelProjectionScaleVal,
|
||||
@@ -853,8 +893,7 @@ readers = readersDict
|
||||
expertUp: expertUp,
|
||||
expertDown: expertDown,
|
||||
topK: topK,
|
||||
// For models without v_proj on full attention layers, use k_eq_v=true
|
||||
kEqualsV: (vp == nil && isFull) || (cfg.attentionKEqualsV ?? false)
|
||||
kEqualsV: (vpQuant == nil && vpFloat == nil && isFull) || (cfg.attentionKEqualsV ?? false)
|
||||
)
|
||||
builtLayers.append(layer)
|
||||
}
|
||||
@@ -1214,6 +1253,116 @@ readers = readersDict
|
||||
inDim: inDim, outDim: outDim, bits: bits, groupSize: groupSize)
|
||||
}
|
||||
|
||||
/// Load non-quantized bf16 embedding weights as FloatWeights
|
||||
private static func loadFloatEmbed(named: String, from tensors: [TensorDescriptor],
|
||||
index: SafeTensorsIndex?,
|
||||
readers: [String: SafeTensorsReader],
|
||||
device: MTLDevice,
|
||||
hiddenSize: Int) throws -> FloatWeights? {
|
||||
let tensorMap = Dictionary(uniqueKeysWithValues: tensors.map { ($0.name, $0) })
|
||||
let prefix = "language_model.model."
|
||||
let modelPrefix = "model.language_model.model."
|
||||
let modelPrefixShort = "model.language_model."
|
||||
let tensorMapWithPrefix = tensors.reduce(into: [String: TensorDescriptor]()) { dict, desc in
|
||||
dict[desc.name] = desc
|
||||
if desc.name.hasPrefix(prefix) {
|
||||
dict[String(desc.name.dropFirst(prefix.count))] = desc
|
||||
}
|
||||
if desc.name.hasPrefix(modelPrefix) {
|
||||
dict[String(desc.name.dropFirst(modelPrefix.count))] = desc
|
||||
}
|
||||
if desc.name.hasPrefix(modelPrefixShort) {
|
||||
dict[String(desc.name.dropFirst(modelPrefixShort.count))] = desc
|
||||
}
|
||||
}
|
||||
func findTensor(_ name: String) -> TensorDescriptor? {
|
||||
if let desc = tensorMapWithPrefix[name] { return desc }
|
||||
return tensorMap[name]
|
||||
}
|
||||
|
||||
let wName = "\(named).weight"
|
||||
guard let wDesc = findTensor(wName) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
if wDesc.dtype != .bf16 {
|
||||
return nil
|
||||
}
|
||||
|
||||
let wReader: SafeTensorsReader
|
||||
if let idx = index {
|
||||
let actualWName = wDesc.name
|
||||
guard let wShard = idx.weightMap[actualWName] else { return nil }
|
||||
wReader = readers[wShard]!
|
||||
} else {
|
||||
wReader = readers["model.safetensors"]!
|
||||
}
|
||||
|
||||
let wData = try wReader.read(tensor: wDesc)
|
||||
let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
|
||||
|
||||
let outDim = wDesc.shape[0]
|
||||
let inDim = wDesc.shape[1]
|
||||
|
||||
guard let wBuf = device.makeBuffer(
|
||||
bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
|
||||
options: .storageModeShared
|
||||
) else { return nil }
|
||||
|
||||
return FloatWeights(weight: wBuf, inDim: inDim, outDim: outDim)
|
||||
}
|
||||
|
||||
/// Load non-quantized bf16 layer weights as FloatWeights
|
||||
private static func loadFloatWeight(named: String, from tensors: [TensorDescriptor],
|
||||
index: SafeTensorsIndex?,
|
||||
readers: [String: SafeTensorsReader],
|
||||
device: MTLDevice) throws -> FloatWeights? {
|
||||
let tensorMap = Dictionary(uniqueKeysWithValues: tensors.map { ($0.name, $0) })
|
||||
let prefix = "language_model.model."
|
||||
let modelPrefix = "model.language_model."
|
||||
let tensorMapWithPrefix = tensors.reduce(into: [String: TensorDescriptor]()) { dict, desc in
|
||||
dict[desc.name] = desc
|
||||
if desc.name.hasPrefix(prefix) {
|
||||
dict[String(desc.name.dropFirst(prefix.count))] = desc
|
||||
}
|
||||
if desc.name.hasPrefix(modelPrefix) {
|
||||
dict[String(desc.name.dropFirst(modelPrefix.count))] = desc
|
||||
}
|
||||
}
|
||||
func findTensor(_ name: String) -> TensorDescriptor? {
|
||||
if let desc = tensorMapWithPrefix[name] { return desc }
|
||||
return tensorMap[name]
|
||||
}
|
||||
|
||||
let wName = "\(named).weight"
|
||||
guard let wDesc = findTensor(wName) else { return nil }
|
||||
|
||||
if wDesc.dtype != .bf16 {
|
||||
return nil
|
||||
}
|
||||
|
||||
let wReader: SafeTensorsReader
|
||||
if let idx = index {
|
||||
let actualWName = wDesc.name
|
||||
guard let wShard = idx.weightMap[actualWName] else { return nil }
|
||||
wReader = readers[wShard]!
|
||||
} else {
|
||||
wReader = readers["model.safetensors"]!
|
||||
}
|
||||
|
||||
let wData = try wReader.read(tensor: wDesc)
|
||||
let wFloats = SafeTensorsReader.bf16ToFloat32(wData)
|
||||
|
||||
let outDim = wDesc.shape[0]
|
||||
let inDim = wDesc.shape[1]
|
||||
|
||||
guard let wBuf = device.makeBuffer(
|
||||
bytes: wFloats, length: wFloats.count * MemoryLayout<Float>.stride,
|
||||
options: .storageModeShared
|
||||
) else { return nil }
|
||||
|
||||
return FloatWeights(weight: wBuf, inDim: inDim, outDim: outDim)
|
||||
}
|
||||
/// Load a 3D expert tensor [numExperts, expertOutDim, inDimPacked] as a contiguous MoEExpertGroup.
|
||||
/// The data layout is: expert0[outDim, inDimPacked], expert1[outDim, inDimPacked], ...
|
||||
/// Per-expert access is done via byte offsets into the shared buffers.
|
||||
|
||||
Reference in New Issue
Block a user