Files
markbaseengine/Sources/MarkBase/Layers/Layer.swift
T
MarkBase Admin 239474bef0
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: fix 26B activation explosion — normalize groupSize=32 scales, fix hardcoded loops
2026-07-05 19:52:47 +08:00

1383 lines
66 KiB
Swift

import Metal
// ── Quantized Weights ────────────────────────────
/// A quantized weight matrix: the packed values plus the scale and bias buffers
/// that accompany them, and the dimensions needed to interpret the buffers.
/// This type only stores references and sizes; the layouts noted on each field are
/// the author's annotations and are not validated here.
public struct QuantizedWeights {
public let weight: MTLBuffer // U32 packed [outDim, inDim/(32/bits)]
public let scales: MTLBuffer // Float32 [outDim, inDim/groupSize]
public let biases: MTLBuffer // Float32 [outDim, inDim/groupSize]
// Input (column) dimension of the logical matrix.
public let inDim: Int
// Output (row) dimension of the logical matrix.
public let outDim: Int
public let bits: Int // 4 or 8
public let groupSize: Int // Quantization group size (32, 64, etc.)
}
// ── Float Weights (non-quantized bf16/f32) ────────────────────────────
/// A non-quantized weight matrix together with its dimensions.
/// Stores the buffer reference and sizes only; no validation is performed.
public struct FloatWeights {
public let weight: MTLBuffer // Float32 [outDim, inDim]
// Input (column) dimension of the matrix.
public let inDim: Int
// Output (row) dimension of the matrix.
public let outDim: Int
}
// ── Layer Configuration ──────────────────────────
/// Static hyper-parameters describing one transformer layer.
///
/// Instances are normally produced by the `sliding(...)` and `full(...)` factories,
/// which fill in the RoPE constants appropriate to each attention flavour.
public struct E4BLayerConfig {
    /// `true` for layers built by `sliding(...)`, `false` for layers built by `full(...)`.
    public let isSliding: Bool
    /// Width of a single attention head.
    public let headDim: Int
    /// Width of the feed-forward intermediate activation.
    public let intermediateSize: Int
    /// Number of head dimensions passed to the RoPE kernels as the rotated width.
    public let rotatedDim: Int
    /// RoPE base frequency.
    public let ropeTheta: Float
    /// RoPE scale factor.
    public let ropeScale: Float
    /// Attention window length (`full(...)` stores its `maxPosition` here).
    public let windowSize: Int
    /// Number of query heads.
    public let nHeads: Int
    /// Number of key/value heads.
    public let nKvHeads: Int
    /// Model hidden width.
    public let hiddenSize: Int

    /// Builds the configuration for a sliding-window attention layer.
    ///
    /// Half of each head is rotated, with theta 10000 and scale 1.
    /// - Parameter windowSize: Attention window length; defaults to 512.
    public static func sliding(hiddenSize: Int, headDim: Int,
                               intermediateSize: Int,
                               nHeads: Int, nKvHeads: Int,
                               windowSize: Int = 512) -> E4BLayerConfig {
        // default RoPE: half of dimensions rotated
        let rotated = headDim / 2
        return .init(isSliding: true,
                     headDim: headDim,
                     intermediateSize: intermediateSize,
                     rotatedDim: rotated,
                     ropeTheta: 10000.0,
                     ropeScale: 1.0,
                     windowSize: windowSize,
                     nHeads: nHeads,
                     nKvHeads: nKvHeads,
                     hiddenSize: hiddenSize)
    }

    /// Builds the configuration for a full-attention layer.
    ///
    /// A quarter of each head is rotated (truncated toward zero), with theta 1e6 and scale 1.
    /// - Parameter maxPosition: Stored as `windowSize`; defaults to 8192.
    public static func full(hiddenSize: Int, headDim: Int,
                            intermediateSize: Int,
                            nHeads: Int, nKvHeads: Int,
                            maxPosition: Int = 8192) -> E4BLayerConfig {
        let quarterOfHead = Float(headDim) * 0.25
        let rotated = Int(quarterOfHead)
        return .init(isSliding: false,
                     headDim: headDim,
                     intermediateSize: intermediateSize,
                     rotatedDim: rotated,
                     ropeTheta: 1000000.0,
                     ropeScale: 1.0,
                     windowSize: maxPosition,
                     nHeads: nHeads,
                     nKvHeads: nKvHeads,
                     hiddenSize: hiddenSize)
    }
}
// ── Temp buffers shared across forward pass ──────
/// Scratch buffers allocated once and reused by every layer during a forward pass.
///
/// All buffers are allocated with shared storage and sized in `Float` elements.
public struct ForwardTemps {
    public let q: MTLBuffer      // [maxHeads * maxHeadDim]
    public let k: MTLBuffer      // [maxKvHeads * maxHeadDim]
    public let v: MTLBuffer      // [maxKvHeads * maxHeadDim]
    public let h: MTLBuffer      // [hiddenSize] scratch (dedicated to the FFN)
    public let attnH: MTLBuffer  // [hiddenSize] dedicated to attention (avoids overwriting h)
    public let gate: MTLBuffer   // [maxIntermediateSize]
    public let up: MTLBuffer     // [maxIntermediateSize]
    public let attn: MTLBuffer   // [maxHeads * maxHeadDim]
    public let gating: MTLBuffer // [256] per-layer gating scratch
    public let ns: MTLBuffer     // [max(hiddenSize, nHeads*headDim)] norm scratch
    public let io: MTLBuffer     // [hiddenSize] dedicated layer I/O (separate from scratch)

    /// Allocates every scratch buffer on `device`.
    ///
    /// - Parameters:
    ///   - device: Metal device that owns the buffers.
    ///   - maxHeadDim: Largest head width any layer will use.
    ///   - maxIntermediate: Largest feed-forward intermediate width.
    ///   - hiddenSize: Model hidden width.
    ///   - nHeads: Number of query heads.
    ///   - nKvHeads: Number of key/value heads.
    /// - Throws: `E4BError.bufferCreationFailed` if any allocation returns `nil`.
    public init(device: MTLDevice,
                maxHeadDim: Int = 512,
                maxIntermediate: Int = 20480,
                hiddenSize: Int = 2560,
                nHeads: Int = 8,
                nKvHeads: Int = 2) throws {
        // Allocates a shared-storage buffer holding `floatCount` Float values.
        let allocate: (Int) throws -> MTLBuffer = { floatCount in
            let byteLength = floatCount * MemoryLayout<Float>.stride
            if let buffer = device.makeBuffer(length: byteLength,
                                              options: .storageModeShared) {
                return buffer
            }
            throw E4BError.bufferCreationFailed
        }

        let queryFloats = nHeads * maxHeadDim
        let keyValueFloats = nKvHeads * maxHeadDim

        q = try allocate(queryFloats)
        k = try allocate(keyValueFloats)
        v = try allocate(keyValueFloats)
        h = try allocate(hiddenSize)
        attnH = try allocate(hiddenSize) // attention-only buffer
        gate = try allocate(maxIntermediate)
        up = try allocate(maxIntermediate)
        attn = try allocate(queryFloats)
        gating = try allocate(256)
        ns = try allocate(max(hiddenSize, queryFloats))
        io = try allocate(hiddenSize)
    }
}
// ── MoE structures ──────────────────────────────
/// The three quantized projections (gate, up, down) belonging to a single expert.
public struct MoEExpert {
public let gateProj: QuantizedWeights
public let upProj: QuantizedWeights
public let downProj: QuantizedWeights
/// Stores the three projections unchanged.
public init(gateProj: QuantizedWeights, upProj: QuantizedWeights, downProj: QuantizedWeights) {
self.gateProj = gateProj
self.upProj = upProj
self.downProj = downProj
}
}
/// Quantized weights for every expert of one projection, packed back to back.
///
/// Weights are laid out as a contiguous 3D tensor [numExperts, outDim, inDimPacked];
/// scales and biases as [numExperts, outDim, numGroups]. A single expert is addressed
/// by multiplying its index by `weightStride` / `scalesStride` to get a byte offset
/// into the shared buffers.
public struct MoEExpertGroup {
    /// Packed weights for all experts: [numExperts * expertOutDim, expertInDimPacked] uint32.
    public let weight: MTLBuffer
    /// Scales for all experts: [numExperts * expertOutDim, numGroups] float32.
    public let scales: MTLBuffer
    /// Biases for all experts, shaped like `scales`.
    public let biases: MTLBuffer
    /// Output dimension of one expert.
    public let expertOutDim: Int
    /// Input dimension (hidden size).
    public let expertInDim: Int
    /// Groups per output row, i.e. inDim / groupSize.
    public let numGroups: Int
    /// How many experts the buffers contain.
    public let numExperts: Int
    /// Quantization bit width (4 or 8).
    public let bits: Int

    /// Bytes occupied by one expert inside `weight`.
    public var weightStride: Int {
        // Each 32-bit word carries 32 / bits packed values.
        let packedWordsPerRow = expertInDim * bits / 32
        return expertOutDim * packedWordsPerRow * 4
    }

    /// Bytes occupied by one expert inside `scales` (and `biases`).
    public var scalesStride: Int {
        let entriesPerExpert = expertOutDim * numGroups
        return entriesPerExpert * 4
    }

    /// Stores the shared buffers and the geometry needed to slice them per expert.
    public init(weight: MTLBuffer, scales: MTLBuffer, biases: MTLBuffer,
                expertOutDim: Int, expertInDim: Int, numGroups: Int, numExperts: Int,
                bits: Int = 4) {
        // Geometry
        self.numExperts = numExperts
        self.expertOutDim = expertOutDim
        self.expertInDim = expertInDim
        self.numGroups = numGroups
        self.bits = bits
        // Storage
        self.weight = weight
        self.scales = scales
        self.biases = biases
    }
}
// ── Layer forward pass ───────────────────────────
public final class E4BLayer {
// Static hyper-parameters for this layer (head counts, head width, RoPE constants, window size).
let config: E4BLayerConfig
// Epsilon constant for RMS normalisation.
// NOTE(review): the norm helpers take `eps` as an argument — confirm callers pass this value.
let rmsNormEps: Float = 1e-6
let layerIdx: Int // For debug logging
// Norm weights (Float32)
let inputLayernorm: MTLBuffer?
let postAttentionLayernorm: MTLBuffer?
let preFeedforwardLayernorm: MTLBuffer?
let postFeedforwardLayernorm: MTLBuffer?
let postPerLayerInputNorm: MTLBuffer? // after per-layer gating
let qNorm: MTLBuffer?
let kNorm: MTLBuffer?
let vNorm: MTLBuffer? // nil — no-scale variant
// Quantized projections
let qProj: QuantizedWeights?
let kProj: QuantizedWeights?
let vProj: QuantizedWeights?
let oProj: QuantizedWeights?
let gateProj: QuantizedWeights?
let upProj: QuantizedWeights?
let downProj: QuantizedWeights?
let perLayerGate: QuantizedWeights?
let perLayerProjection: QuantizedWeights?
// Float projections (bf16 models)
// Each mirrors the quantized property of the same name without the `Float` suffix.
let qProjFloat: FloatWeights?
let kProjFloat: FloatWeights?
let vProjFloat: FloatWeights?
let oProjFloat: FloatWeights?
let gateProjFloat: FloatWeights?
let upProjFloat: FloatWeights?
let downProjFloat: FloatWeights?
let perLayerGateFloat: FloatWeights?
let perLayerProjectionFloat: FloatWeights?
// MoE
let useMoE: Bool
let routerProj: QuantizedWeights?
let routerScale: Float
let perExpertScale: [Float]?
// Gate / up / down weights for all experts, each packed into shared buffers.
let expertGate: MoEExpertGroup?
let expertUp: MoEExpertGroup?
let expertDown: MoEExpertGroup?
// NOTE(review): presumably the number of experts selected per token — confirm against the router code.
let topK: Int
// K=V sharing for full attention layers (Gemma 4)
let kEqualsV: Bool
// Per-layer constants
let perLayerInput: MTLBuffer?
let perLayerInputScale: Float
let perLayerProjectionScale: Float
let layerScalar: Float
/// Creates a layer from its configuration, norm weights, projections and constants.
///
/// Every argument is stored unchanged in the property of the same name; nothing is
/// validated or derived here. Quantized and float projections are both optional and
/// default to `nil`, as do all MoE arguments. `attnBuf` is not assigned by this initializer.
public init(config: E4BLayerConfig,
layerIdx: Int = 0,
inputLayernorm: MTLBuffer?,
postAttentionLayernorm: MTLBuffer?,
preFeedforwardLayernorm: MTLBuffer?,
postFeedforwardLayernorm: MTLBuffer?,
postPerLayerInputNorm: MTLBuffer?,
qNorm: MTLBuffer?,
kNorm: MTLBuffer?,
vNorm: MTLBuffer?,
qProj: QuantizedWeights? = nil,
kProj: QuantizedWeights? = nil,
vProj: QuantizedWeights? = nil,
oProj: QuantizedWeights? = nil,
gateProj: QuantizedWeights? = nil,
upProj: QuantizedWeights? = nil,
downProj: QuantizedWeights? = nil,
perLayerGate: QuantizedWeights? = nil,
perLayerProjection: QuantizedWeights? = nil,
qProjFloat: FloatWeights? = nil,
kProjFloat: FloatWeights? = nil,
vProjFloat: FloatWeights? = nil,
oProjFloat: FloatWeights? = nil,
gateProjFloat: FloatWeights? = nil,
upProjFloat: FloatWeights? = nil,
downProjFloat: FloatWeights? = nil,
perLayerGateFloat: FloatWeights? = nil,
perLayerProjectionFloat: FloatWeights? = nil,
perLayerInput: MTLBuffer?,
perLayerInputScale: Float,
perLayerProjectionScale: Float,
layerScalar: Float,
useMoE: Bool = false,
routerProj: QuantizedWeights? = nil,
routerScale: Float = 1.0,
perExpertScale: [Float]? = nil,
expertGate: MoEExpertGroup? = nil,
expertUp: MoEExpertGroup? = nil,
expertDown: MoEExpertGroup? = nil,
topK: Int = 8,
kEqualsV: Bool = false) {
self.layerIdx = layerIdx
self.config = config
// Norm weights
self.inputLayernorm = inputLayernorm
self.postAttentionLayernorm = postAttentionLayernorm
self.preFeedforwardLayernorm = preFeedforwardLayernorm
self.postFeedforwardLayernorm = postFeedforwardLayernorm
self.postPerLayerInputNorm = postPerLayerInputNorm
self.qNorm = qNorm
self.kNorm = kNorm
self.vNorm = vNorm
// Quantized projections
self.qProj = qProj
self.kProj = kProj
self.vProj = vProj
self.oProj = oProj
self.gateProj = gateProj
self.upProj = upProj
self.downProj = downProj
self.perLayerGate = perLayerGate
self.perLayerProjection = perLayerProjection
// Float projections
self.qProjFloat = qProjFloat
self.kProjFloat = kProjFloat
self.vProjFloat = vProjFloat
self.oProjFloat = oProjFloat
self.gateProjFloat = gateProjFloat
self.upProjFloat = upProjFloat
self.downProjFloat = downProjFloat
self.perLayerGateFloat = perLayerGateFloat
self.perLayerProjectionFloat = perLayerProjectionFloat
self.kEqualsV = kEqualsV
// Per-layer constants
self.perLayerInput = perLayerInput
self.perLayerInputScale = perLayerInputScale
self.perLayerProjectionScale = perLayerProjectionScale
self.layerScalar = layerScalar
// MoE
self.useMoE = useMoE
self.routerProj = routerProj
self.routerScale = routerScale
self.perExpertScale = perExpertScale
self.expertGate = expertGate
self.expertUp = expertUp
self.expertDown = expertDown
self.topK = topK
}
// ── Kernel dispatch helpers (optimized versions preferred) ──────────────────
/// Encodes one dispatch of the `rms_norm` kernel onto `cmdBuf`.
///
/// Argument slots: input (0), weight (1, may be `nil`), output (2),
/// element count as `UInt32` (3), epsilon (4). One thread is dispatched per element.
///
/// - Parameters:
///   - engine: Source of the compiled pipeline state.
///   - cmdBuf: Command buffer that receives the compute encoder.
///   - input: Values to normalise.
///   - weight: Optional norm weights; bound as-is, including `nil`.
///   - output: Destination buffer.
///   - count: Number of elements (and threads).
///   - eps: Epsilon forwarded to the kernel.
/// - Throws: Whatever `engine.pipeline(named:)` throws when the kernel is unavailable.
func rmsNorm(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
             input: MTLBuffer, weight: MTLBuffer?,
             output: MTLBuffer, count: Int, eps: Float) throws {
    let pso = try engine.pipeline(named: "rms_norm")
    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    enc.setBuffer(input, offset: 0, index: 0)
    enc.setBuffer(weight, offset: 0, index: 1)
    enc.setBuffer(output, offset: 0, index: 2)
    var N = UInt32(count)
    enc.setBytes(&N, length: MemoryLayout<UInt32>.size, index: 3)
    var e = eps
    enc.setBytes(&e, length: MemoryLayout<Float>.size, index: 4)
    let gridSize = MTLSize(width: count, height: 1, depth: 1)
    // Keep the preferred width of 256 but never exceed what this pipeline supports:
    // a threadgroup larger than maxTotalThreadsPerThreadgroup is rejected by Metal.
    let tgSize = min(256, count, pso.maxTotalThreadsPerThreadgroup)
    let tg = MTLSize(width: tgSize, height: 1, depth: 1)
    enc.dispatchThreads(gridSize, threadsPerThreadgroup: tg)
    enc.endEncoding()
}
/// Encodes one dispatch of the `rms_norm_grouped` kernel onto `cmdBuf`.
///
/// Argument slots: input (0), weight (1, may be `nil`), output (2),
/// element count (3), group size (4), epsilon (5). One thread is dispatched per element.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the kernel is unavailable.
func groupedRmsNorm(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                    input: MTLBuffer, weight: MTLBuffer?,
                    output: MTLBuffer, count: Int,
                    groupSize: Int, eps: Float) throws {
    let pipeline = try engine.pipeline(named: "rms_norm_grouped")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Buffers take slots 0...2 in declaration order.
    let buffers: [MTLBuffer?] = [input, weight, output]
    for (slot, buffer) in buffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    // Scalar arguments follow the buffers.
    var elementCount = UInt32(count)
    var elementsPerGroup = UInt32(groupSize)
    var epsilon = eps
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 3)
    encoder.setBytes(&elementsPerGroup, length: MemoryLayout<UInt32>.size, index: 4)
    encoder.setBytes(&epsilon, length: MemoryLayout<Float>.size, index: 5)

    let grid = MTLSize(width: count, height: 1, depth: 1)
    let threadgroup = engine.threadgroupSize1D(pipeline, count: count)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
/// Encodes one dispatch of the `gelu_approx` kernel onto `cmdBuf`.
///
/// Argument slots: input (0), output (1), element count (2).
/// One thread is dispatched per element.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the kernel is unavailable.
func gelu(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
          input: MTLBuffer, output: MTLBuffer, count: Int) throws {
    let pipeline = try engine.pipeline(named: "gelu_approx")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    for (slot, buffer) in [input, output].enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }
    var elementCount = UInt32(count)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 2)

    let grid = MTLSize(width: count, height: 1, depth: 1)
    let threadgroup = engine.threadgroupSize1D(pipeline, count: count)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
/// Encodes a quantized matrix-vector product onto `cmdBuf`.
///
/// The kernel is chosen by bit width: `quantized_matmul_simd_8bit` for 8-bit weights,
/// `quantized_matmul` otherwise. Argument slots: input (0), packed weights (1),
/// scales (2), biases (3), output (4), inDim (5), outDim (6), groupSize (7).
/// `inDim * 4` bytes of threadgroup memory are reserved at index 0, and one thread
/// is dispatched per output row in threadgroups of 256.
///
/// There is deliberately no cross-bit-width fallback. The kernel receives no bit-width
/// argument, so running `quantized_matmul` over 8-bit packed data when the 8-bit kernel
/// is missing would silently misread the weights. For non-8-bit weights a fallback would
/// only re-request the same kernel name. A missing kernel therefore surfaces as a thrown error.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the selected kernel is unavailable.
func quantizedMatmul(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                     input: MTLBuffer,
                     weights: QuantizedWeights,
                     output: MTLBuffer) throws {
    // Select kernel based on quantization bits
    let kernelName = weights.bits == 8 ? "quantized_matmul_simd_8bit" : "quantized_matmul"
    let pso = try engine.pipeline(named: kernelName)

    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    enc.setBuffer(input, offset: 0, index: 0)
    enc.setBuffer(weights.weight, offset: 0, index: 1)
    enc.setBuffer(weights.scales, offset: 0, index: 2)
    enc.setBuffer(weights.biases, offset: 0, index: 3)
    enc.setBuffer(output, offset: 0, index: 4)
    var inDim = UInt32(weights.inDim)
    enc.setBytes(&inDim, length: MemoryLayout<UInt32>.size, index: 5)
    var outDim = UInt32(weights.outDim)
    enc.setBytes(&outDim, length: MemoryLayout<UInt32>.size, index: 6)
    var groupSize = UInt32(weights.groupSize) // quantization group size from weights
    enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 7)

    // Threadgroup memory for input vector cache (Float32 per input element)
    let tgMemSize = weights.inDim * 4
    enc.setThreadgroupMemoryLength(tgMemSize, index: 0)

    let count = weights.outDim
    let tg = MTLSize(width: 256, height: 1, depth: 1)
    enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                        threadsPerThreadgroup: tg)
    enc.endEncoding()
}
/// Encodes a float matrix-vector product with the `matmul_f32` kernel.
///
/// Argument slots: input (0), weights (1), output (2), then the three `UInt32`
/// dimensions M = 1 (3), K = inDim (4), N = outDim (5). One thread per output element.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the kernel is unavailable.
func matmulFloat(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                 input: MTLBuffer,
                 weights: FloatWeights,
                 output: MTLBuffer) throws {
    let pipeline = try engine.pipeline(named: "matmul_f32")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    for (slot, buffer) in [input, weights.weight, output].enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    // M is fixed at 1: a single token is processed per call.
    let dimensions: [UInt32] = [1, UInt32(weights.inDim), UInt32(weights.outDim)]
    for (i, dimension) in dimensions.enumerated() {
        var value = dimension
        encoder.setBytes(&value, length: MemoryLayout<UInt32>.size, index: 3 + i)
    }

    let outputCount = weights.outDim
    let grid = MTLSize(width: outputCount, height: 1, depth: 1)
    let threadgroup = engine.threadgroupSize1D(pipeline, count: outputCount)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
/// Dispatches to `quantizedMatmul` or `matmulFloat` depending on which weights are present.
///
/// Quantized weights win when both are supplied. When neither is supplied nothing
/// is encoded and the call returns normally.
///
/// - Throws: Rethrows errors from the selected matmul helper.
func matmulAny(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
               input: MTLBuffer,
               weightsQ: QuantizedWeights?,
               weightsF: FloatWeights?,
               output: MTLBuffer) throws {
    switch (weightsQ, weightsF) {
    case (let quantized?, _):
        try quantizedMatmul(engine: engine, cmdBuf: cmdBuf,
                            input: input, weights: quantized, output: output)
    case (nil, let dense?):
        try matmulFloat(engine: engine, cmdBuf: cmdBuf,
                        input: input, weights: dense, output: output)
    case (nil, nil):
        // No weights of either kind: leave `output` untouched.
        break
    }
}
/// Encodes the `apply_rope_q` kernel, which operates on `q` in place.
///
/// Argument slots: q (0), nHeads (1), headDim (2), rotatedDim (3), theta (4),
/// scale (5), position as `Int32` (6). `nHeads * (rotatedDim / 2)` threads are dispatched.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the kernel is unavailable.
func applyRoPEQ(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                q: MTLBuffer, position: Int) throws {
    let pipeline = try engine.pipeline(named: "apply_rope_q")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(q, offset: 0, index: 0)

    // Integer geometry occupies slots 1...3.
    let geometry = [config.nHeads, config.headDim, config.rotatedDim]
    for (i, dimension) in geometry.enumerated() {
        var packed = UInt32(dimension)
        encoder.setBytes(&packed, length: MemoryLayout<UInt32>.size, index: 1 + i)
    }

    // Float RoPE constants occupy slots 4 and 5.
    for (i, constant) in [config.ropeTheta, config.ropeScale].enumerated() {
        var value = constant
        encoder.setBytes(&value, length: MemoryLayout<Float>.size, index: 4 + i)
    }

    var tokenPosition = Int32(position)
    encoder.setBytes(&tokenPosition, length: MemoryLayout<Int32>.size, index: 6)

    let threadCount = config.nHeads * (config.rotatedDim / 2)
    let grid = MTLSize(width: threadCount, height: 1, depth: 1)
    let threadgroup = engine.threadgroupSize1D(pipeline, count: threadCount)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
/// Encodes the `apply_rope_k` kernel, which operates on `k` in place.
///
/// Argument slots: k (0), nKvHeads (1), headDim (2), rotatedDim (3), theta (4),
/// scale (5), position as `Int32` (6). `nKvHeads * (rotatedDim / 2)` threads are dispatched.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the kernel is unavailable.
func applyRoPEK(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                k: MTLBuffer, position: Int) throws {
    let pipeline = try engine.pipeline(named: "apply_rope_k")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(k, offset: 0, index: 0)

    // Integer geometry occupies slots 1...3.
    let geometry = [config.nKvHeads, config.headDim, config.rotatedDim]
    for (i, dimension) in geometry.enumerated() {
        var packed = UInt32(dimension)
        encoder.setBytes(&packed, length: MemoryLayout<UInt32>.size, index: 1 + i)
    }

    // Float RoPE constants occupy slots 4 and 5.
    for (i, constant) in [config.ropeTheta, config.ropeScale].enumerated() {
        var value = constant
        encoder.setBytes(&value, length: MemoryLayout<Float>.size, index: 4 + i)
    }

    var tokenPosition = Int32(position)
    encoder.setBytes(&tokenPosition, length: MemoryLayout<Int32>.size, index: 6)

    let threadCount = config.nKvHeads * (config.rotatedDim / 2)
    let grid = MTLSize(width: threadCount, height: 1, depth: 1)
    let threadgroup = engine.threadgroupSize1D(pipeline, count: threadCount)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
/// Encodes sliding-window attention for one token, writing the result to `attnBuf`.
///
/// Prefers the `sliding_attention_simd` kernel and falls back to `sliding_attention`.
/// Both kernels receive the same arguments: q (0), cache keys (1), cache values (2),
/// output (3), nHeads (4), nKvHeads (5), headDim (6), windowSize (7), position (8).
///
/// The SIMD path reserves two threadgroup allocations of
/// `windowSize * nKvHeads * (headDim / 4) * 16` bytes each. That request is now checked
/// against the device's threadgroup-memory limit. When it does not fit, the scalar
/// kernel is used instead of submitting an encoder Metal would reject.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the fallback kernel is needed
///   but unavailable.
func slidingAttention(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                      q: MTLBuffer, cache: KVCache,
                      position: Int) throws {
    // Arguments 0...8 are identical for the SIMD kernel and the scalar fallback.
    func bindArguments(_ enc: MTLComputeCommandEncoder) {
        enc.setBuffer(q, offset: 0, index: 0)
        enc.setBuffer(cache.buffer, offset: cache.keyBaseOffset, index: 1)
        enc.setBuffer(cache.buffer, offset: cache.valueBaseOffset, index: 2)
        enc.setBuffer(attnBuf, offset: 0, index: 3)
        var nHeads = UInt32(config.nHeads)
        enc.setBytes(&nHeads, length: MemoryLayout<UInt32>.size, index: 4)
        var nKvHeads = UInt32(config.nKvHeads)
        enc.setBytes(&nKvHeads, length: MemoryLayout<UInt32>.size, index: 5)
        var headDim = UInt32(config.headDim)
        enc.setBytes(&headDim, length: MemoryLayout<UInt32>.size, index: 6)
        var windowSize = UInt32(config.windowSize)
        enc.setBytes(&windowSize, length: MemoryLayout<UInt32>.size, index: 7)
        var off = Int32(position)
        enc.setBytes(&off, length: MemoryLayout<Int32>.size, index: 8)
    }

    // Bytes requested for each of the two threadgroup K/V staging areas (float4 = 16 bytes).
    let kvCacheSize = config.windowSize * config.nKvHeads * (config.headDim/4) * 16

    // Try optimized SIMD version first (softcapping removed), but only when both
    // staging areas plus the kernel's static usage fit in threadgroup memory.
    if let pso = try? engine.pipeline(named: "sliding_attention_simd"),
       2 * kvCacheSize + pso.staticThreadgroupMemoryLength <= pso.device.maxThreadgroupMemoryLength {
        let enc = cmdBuf.makeComputeCommandEncoder()!
        enc.setComputePipelineState(pso)
        bindArguments(enc)
        enc.setThreadgroupMemoryLength(kvCacheSize, index: 0)
        enc.setThreadgroupMemoryLength(kvCacheSize, index: 1)
        let grid = MTLSize(width: config.nHeads, height: config.headDim/4, depth: 1)
        let tg = MTLSize(width: 8, height: 16, depth: 1) // Tune for cache efficiency
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
        return
    }

    // Fallback to original
    let pso = try engine.pipeline(named: "sliding_attention")
    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    bindArguments(enc)
    let grid = MTLSize(width: config.nHeads, height: config.headDim, depth: 1)
    let tg = engine.threadgroupSize2D(pso, grid: (config.nHeads, config.headDim))
    enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
    enc.endEncoding()
}
/// Encodes `sliding_attention_with_current`: attention over a cache plus separately
/// supplied K/V buffers for the current token. The result goes to `attnBuf`.
///
/// Argument slots: q (0), cache keys (1), cache values (2), curK (3), curV (4),
/// output (5), nHeads (6), nKvHeads (7), headDim (8), cache.maxLength (9),
/// cache.currentLength (10), position as `Int32` (11).
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the kernel is unavailable.
func slidingAttentionWithCurrent(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                                 q: MTLBuffer, cache: KVCache,
                                 curK: MTLBuffer, curV: MTLBuffer,
                                 position: Int) throws {
    let pipeline = try engine.pipeline(named: "sliding_attention_with_current")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Copies one UInt32 scalar into the given argument slot.
    func bindUInt32(_ value: UInt32, at slot: Int) {
        var packed = value
        encoder.setBytes(&packed, length: MemoryLayout<UInt32>.size, index: slot)
    }

    encoder.setBuffer(q, offset: 0, index: 0)
    encoder.setBuffer(cache.buffer, offset: cache.keyBaseOffset, index: 1)
    encoder.setBuffer(cache.buffer, offset: cache.valueBaseOffset, index: 2)
    encoder.setBuffer(curK, offset: 0, index: 3)
    encoder.setBuffer(curV, offset: 0, index: 4)
    encoder.setBuffer(attnBuf, offset: 0, index: 5)

    bindUInt32(UInt32(config.nHeads), at: 6)
    bindUInt32(UInt32(config.nKvHeads), at: 7)
    bindUInt32(UInt32(config.headDim), at: 8)
    // The window passed here is the cache capacity, not config.windowSize.
    bindUInt32(UInt32(cache.maxLength), at: 9)
    // For shared layers: read ALL entries from owner's cache (positions 0..N).
    // The owner has already stored at all positions up to current position.
    bindUInt32(UInt32(cache.currentLength), at: 10)

    var tokenPosition = Int32(position)
    encoder.setBytes(&tokenPosition, length: MemoryLayout<Int32>.size, index: 11)

    let grid = MTLSize(width: config.nHeads, height: config.headDim, depth: 1)
    let threadgroup = engine.threadgroupSize2D(pipeline, grid: (config.nHeads, config.headDim))
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
// Temp buffer for attention output — set externally from ForwardTemps
// Implicitly unwrapped and not assigned in `init`: it must be set before any of the
// attention helpers run, because each of them binds it as the kernel's output buffer.
var attnBuf: MTLBuffer!
/// Encodes full (non-windowed) attention for one token, writing the result to `attnBuf`.
///
/// Prefers the `full_attention_simd` kernel and falls back to `full_attention`.
/// Shared argument slots: q (0), cache keys (1), cache values (2), output (3),
/// nHeads (4), nKvHeads (5), headDim (6). The SIMD kernel then takes
/// seqLen = position + 1 at slot 7. The fallback takes cache.maxLength at slot 7
/// and position at slot 8.
///
/// The SIMD path reserves two threadgroup allocations of
/// `seqLen * nKvHeads * (headDim / 4) * 16` bytes each, which grows with every token.
/// That request is now checked against the device's threadgroup-memory limit. Once it
/// no longer fits, the scalar kernel is used instead of submitting an encoder Metal
/// would reject.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the fallback kernel is needed
///   but unavailable.
func fullAttention(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                   q: MTLBuffer, cache: KVCache,
                   position: Int) throws {
    // Arguments 0...6 are identical for the SIMD kernel and the scalar fallback.
    func bindSharedArguments(_ enc: MTLComputeCommandEncoder) {
        enc.setBuffer(q, offset: 0, index: 0)
        enc.setBuffer(cache.buffer, offset: cache.keyBaseOffset, index: 1)
        enc.setBuffer(cache.buffer, offset: cache.valueBaseOffset, index: 2)
        enc.setBuffer(attnBuf, offset: 0, index: 3)
        var nHeads = UInt32(config.nHeads)
        enc.setBytes(&nHeads, length: MemoryLayout<UInt32>.size, index: 4)
        var nKvHeads = UInt32(config.nKvHeads)
        enc.setBytes(&nKvHeads, length: MemoryLayout<UInt32>.size, index: 5)
        var headDim = UInt32(config.headDim)
        enc.setBytes(&headDim, length: MemoryLayout<UInt32>.size, index: 6)
    }

    // Bytes requested for each threadgroup K/V staging area (use seqLen, not maxPos;
    // float4 = 16 bytes).
    let kvCacheSize = (position + 1) * config.nKvHeads * (config.headDim/4) * 16

    // Try optimized SIMD version first (no softcapping for text models), but only when
    // both staging areas plus the kernel's static usage fit in threadgroup memory.
    if let pso = try? engine.pipeline(named: "full_attention_simd"),
       2 * kvCacheSize + pso.staticThreadgroupMemoryLength <= pso.device.maxThreadgroupMemoryLength {
        let enc = cmdBuf.makeComputeCommandEncoder()!
        enc.setComputePipelineState(pso)
        bindSharedArguments(enc)
        var seqLen = UInt32(position + 1) // use seqLen, not maxPos
        enc.setBytes(&seqLen, length: MemoryLayout<UInt32>.size, index: 7)
        enc.setThreadgroupMemoryLength(kvCacheSize, index: 0)
        enc.setThreadgroupMemoryLength(kvCacheSize, index: 1)
        let grid = MTLSize(width: config.nHeads, height: config.headDim/4, depth: 1)
        let tg = MTLSize(width: 8, height: 16, depth: 1) // Tune for cache efficiency
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
        return
    }

    // Fallback to original
    let pso = try engine.pipeline(named: "full_attention")
    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    bindSharedArguments(enc)
    var maxPos = UInt32(cache.maxLength)
    enc.setBytes(&maxPos, length: MemoryLayout<UInt32>.size, index: 7)
    var off = Int32(position)
    enc.setBytes(&off, length: MemoryLayout<Int32>.size, index: 8)
    let grid = MTLSize(width: config.nHeads, height: config.headDim, depth: 1)
    let tg = engine.threadgroupSize2D(pso, grid: (config.nHeads, config.headDim))
    enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
    enc.endEncoding()
}
/// Encodes `full_attention_with_current`: attention over a cache plus separately
/// supplied K/V buffers for the current token. The result goes to `attnBuf`.
///
/// Argument slots: q (0), cache keys (1), cache values (2), curK (3), curV (4),
/// output (5), nHeads (6), nKvHeads (7), headDim (8), cache.currentLength (9),
/// position as `Int32` (10).
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the kernel is unavailable.
func fullAttentionWithCurrent(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                              q: MTLBuffer, cache: KVCache,
                              curK: MTLBuffer, curV: MTLBuffer,
                              position: Int) throws {
    let pipeline = try engine.pipeline(named: "full_attention_with_current")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Copies one UInt32 scalar into the given argument slot.
    func bindUInt32(_ value: UInt32, at slot: Int) {
        var packed = value
        encoder.setBytes(&packed, length: MemoryLayout<UInt32>.size, index: slot)
    }

    encoder.setBuffer(q, offset: 0, index: 0)
    encoder.setBuffer(cache.buffer, offset: cache.keyBaseOffset, index: 1)
    encoder.setBuffer(cache.buffer, offset: cache.valueBaseOffset, index: 2)
    encoder.setBuffer(curK, offset: 0, index: 3)
    encoder.setBuffer(curV, offset: 0, index: 4)
    encoder.setBuffer(attnBuf, offset: 0, index: 5)

    bindUInt32(UInt32(config.nHeads), at: 6)
    bindUInt32(UInt32(config.nKvHeads), at: 7)
    bindUInt32(UInt32(config.headDim), at: 8)
    // For shared layers: read ALL entries from owner's cache (positions 0..N).
    // The owner has already stored at all positions up to current position.
    bindUInt32(UInt32(cache.currentLength), at: 9)

    var tokenPosition = Int32(position)
    encoder.setBytes(&tokenPosition, length: MemoryLayout<Int32>.size, index: 10)

    let grid = MTLSize(width: config.nHeads, height: config.headDim, depth: 1)
    let threadgroup = engine.threadgroupSize2D(pipeline, grid: (config.nHeads, config.headDim))
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
/// Encodes an element-wise add of `a` and `b` into `output`.
///
/// Uses `eltwise_add_simd` when that pipeline is available, dispatching
/// `(count + 3) / 4` threads (float4 processing). Otherwise uses `eltwise_add`
/// with one thread per element. Argument slots are the same for both:
/// a (0), b (1), output (2), element count (3).
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when neither kernel is available.
func eltwiseAdd(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                a: MTLBuffer, b: MTLBuffer,
                output: MTLBuffer, count: Int) throws {
    // Prefer the SIMD kernel; only look up the scalar one if it is missing.
    let simdPipeline = try? engine.pipeline(named: "eltwise_add_simd")
    let usesSimd = simdPipeline != nil
    let pipeline = try simdPipeline ?? engine.pipeline(named: "eltwise_add")
    let threadCount = usesSimd ? (count + 3) / 4 : count

    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    for (slot, buffer) in [a, b, output].enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }
    var elementCount = UInt32(count)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 3)

    let grid = MTLSize(width: threadCount, height: 1, depth: 1)
    let threadgroup = engine.threadgroupSize1D(pipeline, count: threadCount)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
/// Encodes an element-wise multiply of `a` and `b` into `output`.
///
/// Uses `eltwise_mul_simd` when that pipeline is available, dispatching
/// `(count + 3) / 4` threads (float4 processing). Otherwise uses `eltwise_mul`
/// with one thread per element. Argument slots are the same for both:
/// a (0), b (1), output (2), element count (3).
///
/// The byte offsets are applied on both paths. Previously the fallback bound every
/// buffer at offset 0, so it read and wrote the wrong region whenever a caller
/// passed a non-zero offset and the SIMD kernel was unavailable.
///
/// - Parameters:
///   - aOffset: Byte offset into `a`.
///   - bOffset: Byte offset into `b`.
///   - outputOffset: Byte offset into `output`.
///   - count: Number of elements to multiply.
/// - Throws: Whatever `engine.pipeline(named:)` throws when neither kernel is available.
func eltwiseMul(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                a: MTLBuffer, aOffset: Int = 0,
                b: MTLBuffer, bOffset: Int = 0,
                output: MTLBuffer, outputOffset: Int = 0,
                count: Int) throws {
    // Try optimized SIMD version first; fall back to the scalar kernel if it is missing.
    let simdPSO = try? engine.pipeline(named: "eltwise_mul_simd")
    let pso = try simdPSO ?? engine.pipeline(named: "eltwise_mul")
    // The SIMD kernel handles four floats per thread.
    let threadCount = simdPSO != nil ? (count + 3) / 4 : count

    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    enc.setBuffer(a, offset: aOffset, index: 0)
    enc.setBuffer(b, offset: bOffset, index: 1)
    enc.setBuffer(output, offset: outputOffset, index: 2)
    var N = UInt32(count)
    enc.setBytes(&N, length: MemoryLayout<UInt32>.size, index: 3)

    let grid = MTLSize(width: threadCount, height: 1, depth: 1)
    let tg = engine.threadgroupSize1D(pso, count: threadCount)
    enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
    enc.endEncoding()
}
/// Encodes the `eltwise_add_scaled` kernel, which receives two buffers with a
/// scale factor each and writes to `output`.
///
/// Argument slots interleave buffers and scales: a (0), scaleA (1), b (2),
/// scaleB (3), output (4), element count (5). One thread per element.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws when the kernel is unavailable.
func eltwiseAddScaled(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                      a: MTLBuffer, scaleA: Float,
                      b: MTLBuffer, scaleB: Float,
                      output: MTLBuffer, count: Int) throws {
    let pipeline = try engine.pipeline(named: "eltwise_add_scaled")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Each operand is followed immediately by its scale factor.
    let operands: [(buffer: MTLBuffer, scale: Float)] = [(a, scaleA), (b, scaleB)]
    for (i, operand) in operands.enumerated() {
        var factor = operand.scale
        encoder.setBuffer(operand.buffer, offset: 0, index: 2 * i)
        encoder.setBytes(&factor, length: MemoryLayout<Float>.size, index: 2 * i + 1)
    }
    encoder.setBuffer(output, offset: 0, index: 4)

    var elementCount = UInt32(count)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 5)

    let grid = MTLSize(width: count, height: 1, depth: 1)
    let threadgroup = engine.threadgroupSize1D(pipeline, count: count)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
/// Encodes the feed-forward gate/up projection into `output`.
///
/// - Float path (both `gateProjFloat` and `upProjFloat` set): only the gate
///   projection is encoded via `matmulFloat`; the up projection is handled
///   separately for bf16 models.
/// - Quantized path (both `gateProj` and `upProj` set): a fused kernel is encoded.
///   The `_opt` kernel is preferred; it gets `inDim * 4` bytes of threadgroup memory
///   and threadgroups of 256. If it is unavailable the older fused kernel is used
///   with an engine-chosen threadgroup size. Kernel names carry an `_8bit` suffix
///   when the gate weights are 8-bit.
/// - If neither pair of weights is present nothing is encoded.
///
/// Fused-kernel argument slots: input (0), gate weight/scales/biases (1-3),
/// up weight/scales/biases (4-6), output (7), inDim (8), outDim (9), groupSize (10).
/// The three dimensions are all taken from the gate weights.
///
/// - Throws: Errors from `matmulFloat` or from `engine.pipeline(named:)` when the
///   fallback kernel is unavailable.
func fusedGateUp(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                 input: MTLBuffer,
                 output: MTLBuffer) throws {
    // Float path: separate matmuls for gate and up.
    if let gateFloat = gateProjFloat, upProjFloat != nil {
        // Note: this only does gate projection, up projection is separate for bf16.
        try matmulFloat(engine: engine, cmdBuf: cmdBuf,
                        input: input, weights: gateFloat, output: output)
        return
    }

    // Quantized path: fused kernel.
    guard let gateQ = gateProj, let upQ = upProj else { return }
    let is8Bit = gateQ.bits == 8

    // Optimized path: threadgroup-cached input + uint4 loads.
    let optimizedName = is8Bit ? "quantized_matmul_gate_up_opt_8bit" : "quantized_matmul_gate_up_opt"
    let optimized = try? engine.pipeline(named: optimizedName)
    let usesOptimized = optimized != nil
    // Fallback to old kernel when the optimized one cannot be loaded.
    let pipeline = try optimized ?? engine.pipeline(
        named: is8Bit ? "quantized_matmul_gate_up_8bit" : "quantized_matmul_gate_up")

    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    let buffers: [MTLBuffer] = [input,
                                gateQ.weight, gateQ.scales, gateQ.biases,
                                upQ.weight, upQ.scales, upQ.biases,
                                output]
    for (slot, buffer) in buffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    let dimensions = [gateQ.inDim, gateQ.outDim, gateQ.groupSize]
    for (i, dimension) in dimensions.enumerated() {
        var packed = UInt32(dimension)
        encoder.setBytes(&packed, length: MemoryLayout<UInt32>.size, index: 8 + i)
    }

    let rowCount = gateQ.outDim
    let threadgroup: MTLSize
    if usesOptimized {
        encoder.setThreadgroupMemoryLength(gateQ.inDim * 4, index: 0)
        threadgroup = MTLSize(width: 256, height: 1, depth: 1)
    } else {
        threadgroup = engine.threadgroupSize1D(pipeline, count: rowCount)
    }
    encoder.dispatchThreads(MTLSize(width: rowCount, height: 1, depth: 1),
                            threadsPerThreadgroup: threadgroup)
    encoder.endEncoding()
}
// ── MoE forward helpers ──────────────────────
/// Encodes a quantized matmul for one expert slice of a stacked (3D) expert tensor.
///
/// The slice is selected by binding the weight buffer at
/// `weightStride * expertIdx` and the scales/biases buffers at
/// `scalesStride * expertIdx`. The `quantized_matmul_simd` pipeline (or its
/// `_8bit` variant when `expert.bits == 8`) is preferred; when it cannot be
/// obtained an error line is printed and `quantized_matmul_seq` is encoded
/// instead.
///
/// - Parameters:
///   - engine: Supplies pipelines and, for the fallback, the threadgroup size.
///   - cmdBuf: Command buffer the compute pass is encoded into.
///   - input: Bound at buffer index 0.
///   - expert: Stacked expert weights, scales and biases.
///   - expertIdx: Index of the expert slice to use.
///   - output: Bound at buffer index 4.
/// - Throws: Rethrows the error from `engine.pipeline(named:)` when the
///   fallback pipeline is unavailable as well.
func quantizedMatmulExpert(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                           input: MTLBuffer,
                           expert: MoEExpertGroup, expertIdx: Int,
                           output: MTLBuffer) throws {
    let kernelName = expert.bits == 8 ? "quantized_matmul_simd_8bit" : "quantized_matmul_simd"
    let preferredPSO = try? engine.pipeline(named: kernelName)
    let usesPreferredKernel = preferredPSO != nil
    if !usesPreferredKernel {
        print(" [ERROR] quantizedMatmulExpert: Shader \(kernelName) not found! Falling back to quantized_matmul_seq")
    }
    // NOTE(review): the fallback name does not depend on `expert.bits` —
    // confirm that `quantized_matmul_seq` is valid for 8-bit experts.
    let activePSO = try preferredPSO ?? engine.pipeline(named: "quantized_matmul_seq")

    let rowCount = expert.expertOutDim
    let weightOffset = expert.weightStride * expertIdx
    let scaleOffset = expert.scalesStride * expertIdx

    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(activePSO)
    enc.setBuffer(input, offset: 0, index: 0)
    enc.setBuffer(expert.weight, offset: weightOffset, index: 1)
    enc.setBuffer(expert.scales, offset: scaleOffset, index: 2)
    // Biases are bound with the same per-expert stride as the scales.
    enc.setBuffer(expert.biases, offset: scaleOffset, index: 3)
    enc.setBuffer(output, offset: 0, index: 4)

    // Scalar constants occupy indices 5, 6, 7: inDim, outDim, groupSize.
    let constants: [UInt32] = [
        UInt32(expert.expertInDim),
        UInt32(rowCount),
        UInt32(expert.expertInDim / expert.numGroups),
    ]
    for (slot, value) in constants.enumerated() {
        var scalar = value
        enc.setBytes(&scalar, length: MemoryLayout<UInt32>.size, index: 5 + slot)
    }

    let threadsPerGroup: MTLSize
    if usesPreferredKernel {
        // Preferred kernel gets `expertInDim * 4` bytes of threadgroup memory
        // (presumably one Float per input element — confirm against the shader).
        enc.setThreadgroupMemoryLength(expert.expertInDim * 4, index: 0)
        threadsPerGroup = MTLSize(width: 256, height: 1, depth: 1)
    } else {
        threadsPerGroup = engine.threadgroupSize1D(activePSO, count: rowCount)
    }
    enc.dispatchThreads(MTLSize(width: rowCount, height: 1, depth: 1),
                        threadsPerThreadgroup: threadsPerGroup)
    enc.endEncoding()
}
/// Encodes the fused gate+up kernel for one expert slice.
///
/// Buffer layout: `input` at 0, the gate weight/scales/biases slice at 1–3,
/// the up slice at 4–6 and `output` at 7, followed by `inDim`, `outDim` and
/// `groupSize` as `UInt32` constants at 8–10. `gate.expertOutDim` threads
/// are dispatched.
///
/// - Parameters:
///   - engine: Supplies the pipeline and the threadgroup size.
///   - cmdBuf: Command buffer the compute pass is encoded into.
///   - input: Activation buffer read by the kernel.
///   - gate: Stacked gate-projection experts; also determines the kernel
///     variant (`bits`) and the dimension constants.
///   - up: Stacked up-projection experts.
///   - expertIdx: Index of the expert slice to use.
///   - output: Buffer receiving the fused result.
/// - Throws: Rethrows the error from `engine.pipeline(named:)`.
func expertFusedGateUp(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                       input: MTLBuffer,
                       gate: MoEExpertGroup, up: MoEExpertGroup,
                       expertIdx: Int,
                       output: MTLBuffer) throws {
    let kernelName = gate.bits == 8 ? "quantized_matmul_gate_up_8bit" : "quantized_matmul_gate_up"
    let pipeline = try engine.pipeline(named: kernelName)
    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pipeline)

    // Position in this table is the buffer index; biases share the scales stride.
    let bindings: [(MTLBuffer?, Int)] = [
        (input, 0),
        (gate.weight, gate.weightStride * expertIdx),
        (gate.scales, gate.scalesStride * expertIdx),
        (gate.biases, gate.scalesStride * expertIdx),
        (up.weight, up.weightStride * expertIdx),
        (up.scales, up.scalesStride * expertIdx),
        (up.biases, up.scalesStride * expertIdx),
        (output, 0),
    ]
    for (slot, binding) in bindings.enumerated() {
        enc.setBuffer(binding.0, offset: binding.1, index: slot)
    }

    // Constants at indices 8, 9, 10: inDim, outDim, groupSize.
    let constants: [UInt32] = [
        UInt32(gate.expertInDim),
        UInt32(gate.expertOutDim),
        UInt32(gate.expertInDim / gate.numGroups),
    ]
    for (slot, value) in constants.enumerated() {
        var scalar = value
        enc.setBytes(&scalar, length: MemoryLayout<UInt32>.size, index: 8 + slot)
    }

    let threadCount = gate.expertOutDim
    enc.dispatchThreads(MTLSize(width: threadCount, height: 1, depth: 1),
                        threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: threadCount))
    enc.endEncoding()
}
/// Encodes the fused gate+up+down kernel for one expert slice.
///
/// Stands in for the separate fusedGateUp + blit + downMatmul + scaledAdd
/// sequence with a single kernel launch.
///
/// Buffer layout: `input` at 0, gate slice at 1–3, up slice at 4–6, down
/// slice at 7–9 and `accum` at 10. Constants: `hiddenSize` (11),
/// `moeIntermediate` (12) and `groupSize` (13) as `UInt32`, then `weight`
/// as `Float` (14).
///
/// - Parameters:
///   - engine: Supplies the pipeline.
///   - cmdBuf: Command buffer the compute pass is encoded into.
///   - input: Activation buffer read by the kernel.
///   - gate: Stacked gate experts; also determines the kernel variant
///     (`bits`) and all dimension constants.
///   - up: Stacked up experts.
///   - down: Stacked down experts.
///   - expertIdx: Index of the expert slice to use.
///   - accum: Accumulation buffer bound at index 10.
///   - weight: Routing weight passed to the kernel for this expert.
/// - Throws: Rethrows the error from `engine.pipeline(named:)`.
func expertFusedGateUpDown(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                           input: MTLBuffer,
                           gate: MoEExpertGroup, up: MoEExpertGroup, down: MoEExpertGroup,
                           expertIdx: Int,
                           accum: MTLBuffer, weight: Float) throws {
    let kernelName = gate.bits == 8 ? "quantized_matmul_gate_up_down_8bit" : "quantized_matmul_gate_up_down"
    let pipeline = try engine.pipeline(named: kernelName)
    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pipeline)

    // Position in this table is the buffer index; biases share the scales stride.
    let bindings: [(MTLBuffer?, Int)] = [
        (input, 0),
        (gate.weight, gate.weightStride * expertIdx),
        (gate.scales, gate.scalesStride * expertIdx),
        (gate.biases, gate.scalesStride * expertIdx),
        (up.weight, up.weightStride * expertIdx),
        (up.scales, up.scalesStride * expertIdx),
        (up.biases, up.scalesStride * expertIdx),
        (down.weight, down.weightStride * expertIdx),
        (down.scales, down.scalesStride * expertIdx),
        (down.biases, down.scalesStride * expertIdx),
        (accum, 0),
    ]
    for (slot, binding) in bindings.enumerated() {
        enc.setBuffer(binding.0, offset: binding.1, index: slot)
    }

    let hiddenSize = UInt32(gate.expertInDim)
    let moeIntermediate = UInt32(gate.expertOutDim)
    let groupSize = UInt32(gate.expertInDim / gate.numGroups)
    for (slot, value) in [hiddenSize, moeIntermediate, groupSize].enumerated() {
        var scalar = value
        enc.setBytes(&scalar, length: MemoryLayout<UInt32>.size, index: 11 + slot)
    }
    var routingWeight = weight
    enc.setBytes(&routingWeight, length: MemoryLayout<Float>.size, index: 14)

    // The larger of the two dimensions sizes both the grid and the
    // threadgroup scratch allocation (4 bytes per element).
    let threadCount = Int(max(hiddenSize, moeIntermediate))
    enc.setThreadgroupMemoryLength(threadCount * 4, index: 0)
    enc.dispatchThreads(MTLSize(width: threadCount, height: 1, depth: 1),
                        threadsPerThreadgroup: MTLSize(width: 256, height: 1, depth: 1))
    enc.endEncoding()
}
/// Tries to encode the whole MoE block as a single `moe_mega_kernel` launch.
///
/// Nothing is encoded and `false` is returned when the pipeline cannot be
/// obtained, when the router is not 4-bit, or when the expert group size
/// (`gate.expertInDim / gate.numGroups`) is not 64.
///
/// Buffer layout: `input` at 0, router weight/scales/biases at 1–3, gate at
/// 4–6, up at 7–9, down at 10–12 (all bound at offset 0, i.e. the full
/// stacked buffers) and `accum` at 13. Constants: `hiddenSize` (14),
/// `moeIntermediate` (15), `numExperts` (16), `routerScale` (17, `Float`)
/// and `topK` (18).
///
/// - Parameters:
///   - engine: Supplies the pipeline.
///   - cmdBuf: Command buffer the compute pass is encoded into.
///   - input: Activation buffer read by the kernel.
///   - router: Router projection weights.
///   - gate: Stacked gate experts; also supplies the dimension constants.
///   - up: Stacked up experts.
///   - down: Stacked down experts.
///   - accum: Accumulation buffer bound at index 13.
/// - Returns: `true` when the kernel was encoded, `false` when the caller
///   has to use another path.
func moeMegaKernel(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                   input: MTLBuffer,
                   router: QuantizedWeights,
                   gate: MoEExpertGroup, up: MoEExpertGroup, down: MoEExpertGroup,
                   accum: MTLBuffer) throws -> Bool {
    guard let pipeline = try? engine.pipeline(named: "moe_mega_kernel") else { return false }
    // Mega kernel supports only a 4-bit router with groupSize=64 experts.
    if router.bits != 4 { return false }
    if gate.expertInDim / gate.numGroups != 64 { return false }

    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pipeline)
    enc.setBuffer(input, offset: 0, index: 0)

    // Each row is (weight, scales, biases) and takes three consecutive
    // buffer indices starting at 1: router, gate, up, down.
    let tensors: [(MTLBuffer?, MTLBuffer?, MTLBuffer?)] = [
        (router.weight, router.scales, router.biases),
        (gate.weight, gate.scales, gate.biases),
        (up.weight, up.scales, up.biases),
        (down.weight, down.scales, down.biases),
    ]
    for (row, tensor) in tensors.enumerated() {
        let base = 1 + 3 * row
        enc.setBuffer(tensor.0, offset: 0, index: base)
        enc.setBuffer(tensor.1, offset: 0, index: base + 1)
        enc.setBuffer(tensor.2, offset: 0, index: base + 2)
    }
    enc.setBuffer(accum, offset: 0, index: 13)

    let hiddenSize = UInt32(gate.expertInDim)
    let moeIntermediate = UInt32(gate.expertOutDim)
    let expertCount = UInt32(gate.numExperts)
    let selectedCount = UInt32(topK)

    var hiddenArg = hiddenSize
    enc.setBytes(&hiddenArg, length: MemoryLayout<UInt32>.size, index: 14)
    var intermediateArg = moeIntermediate
    enc.setBytes(&intermediateArg, length: MemoryLayout<UInt32>.size, index: 15)
    var expertCountArg = expertCount
    enc.setBytes(&expertCountArg, length: MemoryLayout<UInt32>.size, index: 16)
    var routerScaleArg = routerScale
    enc.setBytes(&routerScaleArg, length: MemoryLayout<Float>.size, index: 17)
    var selectedCountArg = selectedCount
    enc.setBytes(&selectedCountArg, length: MemoryLayout<UInt32>.size, index: 18)

    // Threadgroup scratch: max(hiddenSize, moeIntermediate) elements plus
    // numExperts + 2 * topK extra slots, 4 bytes each.
    let threadCount = Int(max(hiddenSize, moeIntermediate))
    let routingSlots = Int(expertCount) + Int(selectedCount) + Int(selectedCount)
    enc.setThreadgroupMemoryLength((threadCount + routingSlots) * 4, index: 0)
    enc.dispatchThreads(MTLSize(width: threadCount, height: 1, depth: 1),
                        threadsPerThreadgroup: MTLSize(width: 256, height: 1, depth: 1))
    enc.endEncoding()
    return true
}
/// Reports whether `moeForward` will have to route on the CPU.
///
/// Mirrors the guards at the top of `moeMegaKernel`: the mega kernel is used
/// only when its pipeline exists, the router is 4-bit and the expert group
/// size is 64. Returns `false` when the MoE weights are missing, because
/// `moeForward` is a no-op in that case.
private func moeRequiresCPURouting(engine: MarkBaseEngine) -> Bool {
    guard let router = routerProj, let gate = expertGate,
          expertUp != nil, expertDown != nil
    else { return false }
    guard (try? engine.pipeline(named: "moe_mega_kernel")) != nil else { return true }
    return router.bits != 4 || gate.expertInDim / gate.numGroups != 64
}

/// Encodes the mixture-of-experts feed-forward block and adds its result
/// to `input`.
///
/// Work encoded on `cmdBuf`, in order: copy `ns` into `temps.io`, zero
/// `temps.h`, accumulate the expert outputs into `temps.h`, then
/// `input += temps.h * layerScalar`.
///
/// Routing uses `moeMegaKernel` when it accepts the weights. Otherwise the
/// router matmul runs on a private command buffer that is committed and
/// waited on, and softmax + top-k selection happen on the CPU.
///
/// - Important: On the CPU-routing path the router reads `ns` directly and
///   executes *before* anything encoded on `cmdBuf`, because `cmdBuf` is
///   never committed here. The GPU work that produces `ns` must therefore
///   already be committed by the caller; `forward` does this.
///
/// - Parameters:
///   - input: Residual stream, updated in place.
///   - ns: Normalised activations fed to the router and the experts.
///   - temps: Scratch buffers; `io`, `h` and `gate` are overwritten.
///   - cmdBuf: Command buffer receiving the expert and residual work. It is
///     not committed by this method.
///   - engine: Metal engine.
/// - Throws: Rethrows errors from the kernel-encoding helpers.
func moeForward(input: MTLBuffer, ns: MTLBuffer,
                temps: ForwardTemps,
                cmdBuf: MTLCommandBuffer,
                engine: MarkBaseEngine) throws {
    guard let router = routerProj,
          let eGate = expertGate, let eUp = expertUp, let eDown = expertDown
    else { return }
    let hs = config.hiddenSize

    // ── Step 1: Copy MoE input ns → io (preserve from overwrite by expert output) ──
    let moeInput = temps.io
    let blit = cmdBuf.makeBlitCommandEncoder()!
    blit.copy(from: ns, sourceOffset: 0,
              to: moeInput, destinationOffset: 0,
              size: hs * 4)
    blit.endEncoding()

    // Zero accumulation buffer temps.h
    let zeroBlit = cmdBuf.makeBlitCommandEncoder()!
    zeroBlit.fill(buffer: temps.h, range: 0..<hs * 4, value: 0)
    zeroBlit.endEncoding()

    // ── Step 2: Try GPU mega-kernel (fused router + softmax + topK + experts) ──
    let routedOnGPU = try moeMegaKernel(engine: engine, cmdBuf: cmdBuf,
                                        input: moeInput,
                                        router: router, gate: eGate, up: eUp, down: eDown,
                                        accum: temps.h)
    if !routedOnGPU {
        // ── Fallback: CPU routing (needs the router logits on the host) ──
        let numExperts = eGate.numExperts
        let routerCmdBuf = engine.commandQueue.makeCommandBuffer()!
        // Read `ns`, not `temps.io`: the ns → io copy above is still sitting
        // in the uncommitted `cmdBuf`, so `temps.io` holds stale data when
        // this buffer executes.
        try quantizedMatmul(engine: engine, cmdBuf: routerCmdBuf,
                            input: ns, weights: router,
                            output: temps.gate)
        routerCmdBuf.commit()
        routerCmdBuf.waitUntilCompleted() // CPU read required

        // CPU softmax over the scaled logits (max-subtracted for stability).
        let logits = engine.readFloats(from: temps.gate, count: numExperts)
        let scaledLogits = logits.map { $0 * routerScale }
        let peak = scaledLogits.max() ?? 0
        var probabilities = scaledLogits.map { exp($0 - peak) }
        let total = probabilities.reduce(0, +)
        if total > 0 {
            probabilities = probabilities.map { $0 / total }
        }

        // Keep the top-k experts and renormalise their weights to sum to 1.
        let k = min(topK, numExperts)
        let selected = Array(probabilities.enumerated())
            .sorted { $0.element > $1.element }
            .prefix(k)
        let selectedMass = selected.reduce(0) { $0 + $1.element }
        guard selectedMass > 0 else { return }

        // Experts are encoded on the caller's buffer so they can be batched
        // with the rest of the layer.
        for (expertIdx, rawWeight) in selected {
            try expertFusedGateUpDown(engine: engine, cmdBuf: cmdBuf,
                                      input: moeInput,
                                      gate: eGate, up: eUp, down: eDown,
                                      expertIdx: expertIdx,
                                      accum: temps.h, weight: rawWeight / selectedMass)
        }
    }

    // ── Step 5: Residual: input += moe_output (temps.h) scaled by layerScalar ──
    if layerScalar != 1.0 {
        try eltwiseAddScaled(engine: engine, cmdBuf: cmdBuf,
                             a: input, scaleA: 1.0,
                             b: temps.h, scaleB: layerScalar,
                             output: input, count: hs)
    } else {
        try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                       a: input, b: temps.h,
                       output: input, count: hs)
    }
}

// ── Main forward ─────────────────────────────
/// Run one E4B layer forward pass.
///
/// The call is synchronous: the layer's command buffer is committed and
/// waited on before returning.
///
/// - Parameters:
///   - input: [hiddenSize] — **modified in-place** to become the output
///   - position: absolute token position (0-based)
///   - kvCache: KV cache for reading attention.
///   - shouldStoreKV: compute K,V and write into kvCache when true
///   - temps: pre-allocated scratch buffers
///   - engine: Metal engine
///   - perLayerInput: optional buffer multiplied into the per-layer gating
///     branch; the branch is skipped when this is nil
///   - perLayerInputOffset: offset into `perLayerInput`, passed through to
///     the element-wise multiply
/// - Throws: Rethrows errors from the kernel-encoding helpers.
public func forward(input: MTLBuffer, position: Int,
                    kvCache: KVCache,
                    shouldStoreKV: Bool,
                    temps: ForwardTemps,
                    engine: MarkBaseEngine,
                    perLayerInput: MTLBuffer? = nil,
                    perLayerInputOffset: Int = 0) throws {
    self.attnBuf = temps.attn
    if useMoE {
        // ── MoE path ──
        // With the GPU mega kernel the whole layer is one command buffer.
        var cmdBuf = engine.commandQueue.makeCommandBuffer()!
        try attentionForward(input: input, position: position,
                             kvCache: kvCache, shouldStoreKV: shouldStoreKV,
                             temps: temps, engine: engine, cmdBuf: cmdBuf)
        if moeRequiresCPURouting(engine: engine) {
            // CPU routing reads the router logits on the host. Flush the
            // attention work first so the router sees this token's
            // activations instead of whatever was left in the scratch
            // buffers, then continue on a fresh command buffer.
            cmdBuf.commit()
            cmdBuf.waitUntilCompleted()
            cmdBuf = engine.commandQueue.makeCommandBuffer()!
        }
        try moeForward(input: input, ns: temps.ns, temps: temps,
                       cmdBuf: cmdBuf, engine: engine)
        try postFfnForward(input: input, temps: temps, engine: engine,
                           cmdBuf: cmdBuf,
                           perLayerInput: perLayerInput,
                           perLayerInputOffset: perLayerInputOffset)
        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
    } else {
        // Dense path: unified flow for all positions
        let cmdBuf = engine.commandQueue.makeCommandBuffer()!
        try attentionForward(input: input, position: position,
                             kvCache: kvCache, shouldStoreKV: shouldStoreKV,
                             temps: temps, engine: engine, cmdBuf: cmdBuf)

        // FFN: gate+up fused → down → residual (scaled by layerScalar)
        try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
                        input: temps.ns, output: temps.gate)
        try matmulAny(engine: engine, cmdBuf: cmdBuf,
                      input: temps.gate, weightsQ: downProj, weightsF: downProjFloat, output: temps.h)
        if layerScalar != 1.0 {
            try eltwiseAddScaled(engine: engine, cmdBuf: cmdBuf,
                                 a: input, scaleA: 1.0,
                                 b: temps.h, scaleB: layerScalar,
                                 output: input, count: config.hiddenSize)
        } else {
            try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                           a: input, b: temps.h,
                           output: input, count: config.hiddenSize)
        }

        // Per-layer gating for dense path. Unlike `postFfnForward`, nothing
        // is added to `input` when the gating weights or input are absent.
        if let pg = perLayerGate, let pp = perLayerProjection, let pl = perLayerInput {
            try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                        input: input, weight: postFeedforwardLayernorm,
                        output: temps.h, count: config.hiddenSize, eps: rmsNormEps)
            try matmulAny(engine: engine, cmdBuf: cmdBuf,
                          input: temps.h, weightsQ: pg, weightsF: perLayerGateFloat,
                          output: temps.gating)
            try gelu(engine: engine, cmdBuf: cmdBuf,
                     input: temps.gating, output: temps.gating, count: 256)
            try eltwiseMul(engine: engine, cmdBuf: cmdBuf,
                           a: temps.gating, aOffset: 0,
                           b: pl, bOffset: perLayerInputOffset,
                           output: temps.gating, outputOffset: 0,
                           count: 256)
            try matmulAny(engine: engine, cmdBuf: cmdBuf,
                          input: temps.gating, weightsQ: pp, weightsF: perLayerProjectionFloat,
                          output: temps.h)
            if let ppn = postPerLayerInputNorm {
                try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                            input: temps.h, weight: ppn,
                            output: temps.h, count: config.hiddenSize, eps: rmsNormEps)
            }
            try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                           a: input, b: temps.h,
                           output: input, count: config.hiddenSize)
        }
        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
    }
}
// ── Attention forward (steps 1-13) ──
/// Encodes the attention half of the layer onto `cmdBuf`.
///
/// After the encoded work has run, `input` holds the residual stream
/// including the attention contribution and `temps.ns` holds the
/// pre-feed-forward normalised activations (step 13).
///
/// V is produced in one of three ways:
/// 1. projected from `temps.h` when V weights exist (quantized or float),
/// 2. copied from the raw K projection when `kEqualsV` is set,
/// 3. otherwise cleared to zero.
///
/// - Parameters:
///   - input: Residual stream, updated in place.
///   - position: Token position passed to RoPE, the KV cache and attention.
///   - kvCache: Cache read by the attention kernels and, when
///     `shouldStoreKV` is true, written with this position's K/V.
///   - shouldStoreKV: When false, the `…WithCurrent` attention variants
///     receive the current K/V directly instead of via the cache.
///   - temps: Scratch buffers; `h`, `q`, `ns`, `k`, `v`, `up` and `gate`
///     are overwritten.
///   - engine: Metal engine.
///   - cmdBuf: Command buffer all work is encoded into.
/// - Throws: Rethrows errors from the kernel-encoding helpers.
private func attentionForward(input: MTLBuffer, position: Int,
                              kvCache: KVCache,
                              shouldStoreKV: Bool,
                              temps: ForwardTemps,
                              engine: MarkBaseEngine,
                              cmdBuf: MTLCommandBuffer) throws {
    let kvCount = config.nKvHeads * config.headDim

    // ── 1. input_layernorm(x) → temps.h ──
    try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                input: input, weight: inputLayernorm,
                output: temps.h, count: config.hiddenSize, eps: rmsNormEps)

    // ── 2. Q = q_proj(temps.h) → temps.q ──
    try matmulAny(engine: engine, cmdBuf: cmdBuf,
                  input: temps.h, weightsQ: qProj, weightsF: qProjFloat, output: temps.q)

    // ── 3. Q = q_norm(Q) → ns (per-head RMSNorm) ──
    try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
                       input: temps.q, weight: qNorm,
                       output: temps.ns,
                       count: config.nHeads * config.headDim,
                       groupSize: config.headDim, eps: rmsNormEps)

    // ── 4. RoPE(Q) on ns ──
    try applyRoPEQ(engine: engine, cmdBuf: cmdBuf,
                   q: temps.ns, position: position)

    // ── 5. K,V projections ──
    try matmulAny(engine: engine, cmdBuf: cmdBuf,
                  input: temps.h, weightsQ: kProj, weightsF: kProjFloat, output: temps.k)
    if vProj != nil || vProjFloat != nil {
        // Either weight format is enough, exactly as for K above. Requiring
        // both at once skipped the projection for layers that carry only
        // one format.
        try matmulAny(engine: engine, cmdBuf: cmdBuf,
                      input: temps.h, weightsQ: vProj, weightsF: vProjFloat, output: temps.v)
    } else if kEqualsV {
        // V shares the raw (pre-norm, pre-RoPE) K projection.
        let blit = cmdBuf.makeBlitCommandEncoder()!
        blit.copy(from: temps.k, sourceOffset: 0,
                  to: temps.v, destinationOffset: 0,
                  size: kvCount * MemoryLayout<Float>.stride)
        blit.endEncoding()
    } else {
        // No V source: clear the buffer. A byte fill is used instead of
        // scaling the old contents by 0, because 0 * NaN/Inf stays NaN and
        // would leak stale non-finite values into attention.
        let blit = cmdBuf.makeBlitCommandEncoder()!
        blit.fill(buffer: temps.v,
                  range: 0..<kvCount * MemoryLayout<Float>.stride,
                  value: 0)
        blit.endEncoding()
    }

    // ── 6. K,V norms ──
    try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
                       input: temps.k, weight: kNorm,
                       output: temps.up,
                       count: kvCount,
                       groupSize: config.headDim, eps: rmsNormEps)
    if let vn = vNorm {
        try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
                           input: temps.v, weight: vn,
                           output: temps.gate,
                           count: kvCount,
                           groupSize: config.headDim, eps: rmsNormEps)
    }

    // ── 7. RoPE(K) ──
    try applyRoPEK(engine: engine, cmdBuf: cmdBuf,
                   k: temps.up, position: position)

    // Current-token K/V: K lives in temps.up; V is the normalised copy in
    // temps.gate when a V norm exists, else the raw temps.v.
    let curK = temps.up
    let curV = vNorm != nil ? temps.gate : temps.v

    // ── 8. Store K,V ──
    if shouldStoreKV {
        kvCache.store(key: curK, keySrcOffset: 0,
                      value: curV, valueSrcOffset: 0,
                      position: position, commandBuffer: cmdBuf)
    }

    // ── 9. Attention ──
    if config.isSliding {
        if shouldStoreKV {
            try slidingAttention(engine: engine, cmdBuf: cmdBuf,
                                 q: temps.ns, cache: kvCache, position: position)
        } else {
            try slidingAttentionWithCurrent(engine: engine, cmdBuf: cmdBuf,
                                            q: temps.ns, cache: kvCache,
                                            curK: curK, curV: curV,
                                            position: position)
        }
    } else {
        if shouldStoreKV {
            try fullAttention(engine: engine, cmdBuf: cmdBuf,
                              q: temps.ns, cache: kvCache, position: position)
        } else {
            try fullAttentionWithCurrent(engine: engine, cmdBuf: cmdBuf,
                                         q: temps.ns, cache: kvCache,
                                         curK: curK, curV: curV,
                                         position: position)
        }
    }

    // ── 10. O projection ──
    try matmulAny(engine: engine, cmdBuf: cmdBuf,
                  input: temps.attn, weightsQ: oProj, weightsF: oProjFloat, output: temps.h)

    // ── 11. Residual 1 (scaled by layerScalar) ──
    if layerScalar != 1.0 {
        try eltwiseAddScaled(engine: engine, cmdBuf: cmdBuf,
                             a: input, scaleA: 1.0,
                             b: temps.h, scaleB: layerScalar,
                             output: input, count: config.hiddenSize)
    } else {
        try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                       a: input, b: temps.h,
                       output: input, count: config.hiddenSize)
    }

    // ── 12. post_attention_layernorm → temps.h ──
    try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                input: input, weight: postAttentionLayernorm,
                output: temps.h, count: config.hiddenSize, eps: rmsNormEps)

    // ── 13. pre_feedforward_layernorm → ns ──
    try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                input: temps.h, weight: preFeedforwardLayernorm,
                output: temps.ns, count: config.hiddenSize, eps: rmsNormEps)
}
// ── Post-FFN forward (steps 17-19) ──
/// Encodes the post-feed-forward stage onto `cmdBuf`.
///
/// `input` is normalised with `postFeedforwardLayernorm` into `temps.h`.
/// When the per-layer gate, projection and input are all present, `temps.h`
/// is replaced by the gated per-layer projection (optionally normalised by
/// `postPerLayerInputNorm`). In every case `temps.h` is then added to
/// `input` in place.
///
/// NOTE(review): without per-layer gating this adds the normalised `input`
/// back onto `input`, while the dense path in `forward` adds nothing in the
/// same situation — confirm which behaviour is intended.
///
/// - Parameters:
///   - input: Residual stream, updated in place.
///   - temps: Scratch buffers; `h` and (with gating) `gating` are overwritten.
///   - engine: Metal engine.
///   - cmdBuf: Command buffer all work is encoded into.
///   - perLayerInput: Optional buffer multiplied into the gating branch.
///   - perLayerInputOffset: Offset into `perLayerInput` for that multiply.
/// - Throws: Rethrows errors from the kernel-encoding helpers.
private func postFfnForward(input: MTLBuffer, temps: ForwardTemps,
                            engine: MarkBaseEngine,
                            cmdBuf: MTLCommandBuffer,
                            perLayerInput: MTLBuffer?,
                            perLayerInputOffset: Int) throws {
    let hidden = config.hiddenSize

    // ── 17. post_feedforward_layernorm → temps.h ──
    try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                input: input, weight: postFeedforwardLayernorm,
                output: temps.h, count: hidden, eps: rmsNormEps)

    // ── 18. Per-layer gating (optional) ──
    if let gateWeights = perLayerGate,
       let projectionWeights = perLayerProjection,
       let layerInput = perLayerInput {
        // Element count used for the GELU and the multiply.
        let gatingCount = 256
        try matmulAny(engine: engine, cmdBuf: cmdBuf,
                      input: temps.h, weightsQ: gateWeights, weightsF: perLayerGateFloat,
                      output: temps.gating)
        try gelu(engine: engine, cmdBuf: cmdBuf,
                 input: temps.gating, output: temps.gating, count: gatingCount)
        try eltwiseMul(engine: engine, cmdBuf: cmdBuf,
                       a: temps.gating, aOffset: 0,
                       b: layerInput, bOffset: perLayerInputOffset,
                       output: temps.gating, outputOffset: 0,
                       count: gatingCount)
        try matmulAny(engine: engine, cmdBuf: cmdBuf,
                      input: temps.gating, weightsQ: projectionWeights, weightsF: perLayerProjectionFloat,
                      output: temps.h)
        if let projectionNorm = postPerLayerInputNorm {
            try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                        input: temps.h, weight: projectionNorm,
                        output: temps.h, count: hidden, eps: rmsNormEps)
        }
    }

    // ── 19. Residual: input += temps.h (same add on both paths) ──
    try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                   a: input, b: temps.h,
                   output: input, count: hidden)
}
}