31427770b1
- Tokenizer fix: collect <0xXX> bytes and decode as UTF-8 (fixes Chinese/non-ASCII character decoding) - BPETokenizer + HuggingFaceTokenizer: both updated - Engine.swift: added writeFloats() utility method - FloatWeights struct added to Layer.swift (bf16 support) - attnQBits/KBits/VBits/OBits detection added to Model.swift - bf16 layer weight support from commit 48c0347 cherry-picked
1379 lines
66 KiB
Swift
1379 lines
66 KiB
Swift
import Metal
|
|
|
|
// ── Quantized Weights ────────────────────────────
|
|
|
|
/// Weight matrix of a linear projection stored in group-quantized, bit-packed form.
///
/// Every `groupSize` consecutive input elements of an output row share one entry in
/// `scales` and one entry in `biases`.
public struct QuantizedWeights {
    /// Packed values as `UInt32` words, laid out `[outDim, inDim / (32 / bits)]`.
    public let weight: MTLBuffer
    /// Per-group scales as `Float32`, laid out `[outDim, inDim / groupSize]`.
    public let scales: MTLBuffer
    /// Per-group biases as `Float32`, laid out `[outDim, inDim / groupSize]`.
    public let biases: MTLBuffer
    /// Number of input features.
    public let inDim: Int
    /// Number of output features.
    public let outDim: Int
    /// Bits per quantized value (4 or 8).
    public let bits: Int
    /// Quantization group size (e.g. 32 or 64 input elements per scale/bias pair).
    public let groupSize: Int

    /// Creates a descriptor for already-loaded quantized weight buffers.
    ///
    /// Declared explicitly so that it is `public`. The synthesized memberwise initializer
    /// of a public struct is only `internal`, which kept this type from being constructed
    /// outside the module, unlike `MoEExpert` and `MoEExpertGroup`, which declare public
    /// initializers. Labels and order match the memberwise initializer, so existing call
    /// sites compile unchanged.
    public init(weight: MTLBuffer, scales: MTLBuffer, biases: MTLBuffer,
                inDim: Int, outDim: Int, bits: Int, groupSize: Int) {
        self.weight = weight
        self.scales = scales
        self.biases = biases
        self.inDim = inDim
        self.outDim = outDim
        self.bits = bits
        self.groupSize = groupSize
    }
}
|
|
|
|
// ── Float Weights (non-quantized bf16/f32) ────────────────────────────
|
|
|
|
/// Dense (non-quantized) weight matrix of a linear projection.
///
/// Used for bf16/f32 checkpoints. The buffer itself holds `Float32` values
/// (presumably bf16 tensors are widened to `Float32` at load time — confirm in the loader).
public struct FloatWeights {
    /// Weights as `Float32`, laid out `[outDim, inDim]`.
    public let weight: MTLBuffer
    /// Number of input features.
    public let inDim: Int
    /// Number of output features.
    public let outDim: Int

    /// Creates a descriptor for an already-loaded dense weight buffer.
    ///
    /// Declared explicitly so that it is `public` (the synthesized memberwise initializer
    /// of a public struct is only `internal`), consistent with `MoEExpert` and
    /// `MoEExpertGroup`. Labels and order match the memberwise initializer, so existing
    /// call sites compile unchanged.
    public init(weight: MTLBuffer, inDim: Int, outDim: Int) {
        self.weight = weight
        self.inDim = inDim
        self.outDim = outDim
    }
}
|
|
|
|
// ── Layer Configuration ──────────────────────────
|
|
|
|
/// Static shape and positional-encoding parameters for one decoder layer.
///
/// Build it with `sliding(...)` for local (sliding-window) attention layers or
/// `full(...)` for global attention layers. The RoPE parameters of both factories default
/// to the values that were previously hard-coded here, so existing calls behave the same.
public struct E4BLayerConfig {
    /// `true` for sliding-window (local) attention, `false` for full (global) attention.
    public let isSliding: Bool
    /// Per-head dimension of Q/K/V.
    public let headDim: Int
    /// Width of the feed-forward hidden layer.
    public let intermediateSize: Int
    /// Number of head dimensions the RoPE kernels rotate.
    /// They launch `rotatedDim / 2` threads per head.
    public let rotatedDim: Int
    /// RoPE base frequency.
    public let ropeTheta: Float
    /// RoPE position scale.
    public let ropeScale: Float
    /// Sliding-window length for local layers; the maximum position for global layers.
    public let windowSize: Int
    /// Number of query heads.
    public let nHeads: Int
    /// Number of key/value heads.
    public let nKvHeads: Int
    /// Model hidden size.
    public let hiddenSize: Int

    /// Configuration for a sliding-window (local) attention layer.
    ///
    /// - Parameters:
    ///   - windowSize: Sliding-window length (default 512).
    ///   - ropeTheta: RoPE base frequency (default 10 000).
    ///   - ropeScale: RoPE position scale (default 1.0).
    ///   - partialRotaryFactor: Fraction of `headDim` that is rotated. The default, 0.5,
    ///     rotates half of the head dimensions as before.
    public static func sliding(hiddenSize: Int, headDim: Int,
                               intermediateSize: Int,
                               nHeads: Int, nKvHeads: Int,
                               windowSize: Int = 512,
                               ropeTheta: Float = 10000.0,
                               ropeScale: Float = 1.0,
                               partialRotaryFactor: Float = 0.5) -> E4BLayerConfig {
        E4BLayerConfig(
            isSliding: true, headDim: headDim,
            intermediateSize: intermediateSize,
            rotatedDim: rotatedDimCount(headDim: headDim, factor: partialRotaryFactor),
            ropeTheta: ropeTheta, ropeScale: ropeScale,
            windowSize: windowSize,
            nHeads: nHeads, nKvHeads: nKvHeads,
            hiddenSize: hiddenSize
        )
    }

    /// Configuration for a full (global) attention layer.
    ///
    /// - Parameters:
    ///   - maxPosition: Stored as `windowSize` (default 8192).
    ///   - ropeTheta: RoPE base frequency (default 1 000 000).
    ///   - ropeScale: RoPE position scale (default 1.0).
    ///   - partialRotaryFactor: Fraction of `headDim` that is rotated (default 0.25).
    public static func full(hiddenSize: Int, headDim: Int,
                            intermediateSize: Int,
                            nHeads: Int, nKvHeads: Int,
                            maxPosition: Int = 8192,
                            ropeTheta: Float = 1000000.0,
                            ropeScale: Float = 1.0,
                            partialRotaryFactor: Float = 0.25) -> E4BLayerConfig {
        E4BLayerConfig(
            isSliding: false, headDim: headDim,
            intermediateSize: intermediateSize,
            rotatedDim: rotatedDimCount(headDim: headDim, factor: partialRotaryFactor),
            ropeTheta: ropeTheta, ropeScale: ropeScale,
            windowSize: maxPosition,
            nHeads: nHeads, nKvHeads: nKvHeads,
            hiddenSize: hiddenSize
        )
    }

    /// Number of rotated dimensions for `headDim` and a partial-rotary `factor`.
    /// The result is truncated toward zero. For realistic head sizes this equals the
    /// former `headDim / 2` (factor 0.5) and `Int(Float(headDim) * 0.25)` (factor 0.25).
    private static func rotatedDimCount(headDim: Int, factor: Float) -> Int {
        Int(Float(headDim) * factor)
    }
}
|
|
|
|
// ── Temp buffers shared across forward pass ──────
|
|
|
|
/// Scratch GPU buffers shared across a forward pass.
///
/// Every buffer holds `Float` elements and is allocated with `.storageModeShared`.
/// The element count of each buffer is noted on its property.
public struct ForwardTemps {
    /// Query scratch: `nHeads * maxHeadDim` floats.
    public let q: MTLBuffer
    /// Key scratch: `nKvHeads * maxHeadDim` floats.
    public let k: MTLBuffer
    /// Value scratch: `nKvHeads * maxHeadDim` floats.
    public let v: MTLBuffer
    /// `hiddenSize` floats of scratch reserved for the feed-forward path.
    public let h: MTLBuffer
    /// `hiddenSize` floats of scratch reserved for the attention path, so attention
    /// does not overwrite `h`.
    public let attnH: MTLBuffer
    /// Gate scratch: `maxIntermediate` floats.
    public let gate: MTLBuffer
    /// Up scratch: `maxIntermediate` floats.
    public let up: MTLBuffer
    /// Attention output: `nHeads * maxHeadDim` floats.
    public let attn: MTLBuffer
    /// Per-layer gating scratch: `gatingSize` floats (256 by default).
    public let gating: MTLBuffer
    /// Norm scratch: `max(hiddenSize, nHeads * maxHeadDim)` floats.
    public let ns: MTLBuffer
    /// Dedicated layer input/output buffer (`hiddenSize` floats), kept separate from
    /// the scratch buffers.
    public let io: MTLBuffer

    /// Allocates every scratch buffer on `device`.
    ///
    /// - Parameters:
    ///   - device: Device to allocate on.
    ///   - maxHeadDim: Largest per-head dimension of any layer.
    ///   - maxIntermediate: Largest feed-forward width of any layer.
    ///   - hiddenSize: Model hidden size.
    ///   - nHeads: Number of query heads.
    ///   - nKvHeads: Number of key/value heads.
    ///   - gatingSize: Element count of the per-layer gating scratch buffer. Defaults to
    ///     256, the size that was previously hard-coded.
    /// - Throws: `E4BError.bufferCreationFailed` if any allocation fails.
    public init(device: MTLDevice,
                maxHeadDim: Int = 512,
                maxIntermediate: Int = 20480,
                hiddenSize: Int = 2560,
                nHeads: Int = 8,
                nKvHeads: Int = 2,
                gatingSize: Int = 256) throws {
        // Allocates a shared buffer large enough for `n` floats.
        func buf(_ n: Int) throws -> MTLBuffer {
            guard let b = device.makeBuffer(length: n * MemoryLayout<Float>.stride,
                                            options: .storageModeShared)
            else { throw E4BError.bufferCreationFailed }
            return b
        }

        q = try buf(nHeads * maxHeadDim)
        k = try buf(nKvHeads * maxHeadDim)
        v = try buf(nKvHeads * maxHeadDim)
        h = try buf(hiddenSize)
        attnH = try buf(hiddenSize)  // attention-only buffer, avoids overwriting `h`
        gate = try buf(maxIntermediate)
        up = try buf(maxIntermediate)
        attn = try buf(nHeads * maxHeadDim)
        gating = try buf(gatingSize)
        ns = try buf(max(hiddenSize, nHeads * maxHeadDim))
        io = try buf(hiddenSize)
    }
}
|
|
|
|
// ── MoE structures ──────────────────────────────
|
|
|
|
/// The three quantized projections of a single mixture-of-experts feed-forward expert.
public struct MoEExpert {
    /// Gate projection weights.
    public let gateProj: QuantizedWeights
    /// Up projection weights.
    public let upProj: QuantizedWeights
    /// Down projection weights.
    public let downProj: QuantizedWeights

    /// Bundles the gate, up, and down projections of one expert.
    public init(gateProj: QuantizedWeights, upProj: QuantizedWeights, downProj: QuantizedWeights) {
        self.downProj = downProj
        self.upProj = upProj
        self.gateProj = gateProj
    }
}
|
|
|
|
/// A bank of quantized expert projections packed into shared GPU buffers.
///
/// All experts are stored back to back: packed weights as
/// `[numExperts, outDim, inDimPacked]` `UInt32` words and scales/biases as
/// `[numExperts, outDim, numGroups]` `Float32` values. A single expert is reached by a
/// byte offset of `expertIndex * weightStride` (weights) or
/// `expertIndex * scalesStride` (scales/biases).
public struct MoEExpertGroup {
    /// Packed weights of every expert: `[numExperts * expertOutDim, expertInDimPacked]` `UInt32`.
    public let weight: MTLBuffer
    /// Scales of every expert: `[numExperts * expertOutDim, numGroups]` `Float32`.
    public let scales: MTLBuffer
    /// Biases of every expert, laid out like `scales`.
    public let biases: MTLBuffer
    /// Output rows produced by one expert.
    public let expertOutDim: Int
    /// Input features consumed by each expert (the hidden size).
    public let expertInDim: Int
    /// Quantization groups per output row (`inDim / groupSize`).
    public let numGroups: Int
    /// Number of experts held in the buffers.
    public let numExperts: Int
    /// Bits per quantized value (4 or 8).
    public let bits: Int

    /// Bytes between the starts of consecutive experts in `weight`.
    public var weightStride: Int {
        let wordsPerRow = expertInDim * bits / 32
        return expertOutDim * wordsPerRow * MemoryLayout<UInt32>.stride
    }

    /// Bytes between the starts of consecutive experts in `scales` and `biases`.
    public var scalesStride: Int {
        let valuesPerExpert = expertOutDim * numGroups
        return valuesPerExpert * MemoryLayout<Float32>.stride
    }

    /// Wraps pre-packed expert buffers; `bits` defaults to 4.
    public init(weight: MTLBuffer, scales: MTLBuffer, biases: MTLBuffer,
                expertOutDim: Int, expertInDim: Int, numGroups: Int, numExperts: Int,
                bits: Int = 4) {
        self.weight = weight
        self.scales = scales
        self.biases = biases
        self.bits = bits
        self.numExperts = numExperts
        self.numGroups = numGroups
        self.expertInDim = expertInDim
        self.expertOutDim = expertOutDim
    }
}
|
|
|
|
// ── Layer forward pass ───────────────────────────
|
|
|
|
public final class E4BLayer {
|
|
// MARK: Configuration

/// Static shape and RoPE settings for this layer.
let config: E4BLayerConfig
/// RMS-norm epsilon, fixed at 1e-6.
let rmsNormEps: Float = 1e-6
/// Position of this layer in the stack; kept for debug logging.
let layerIdx: Int

// MARK: Norm weights (Float32)

let inputLayernorm: MTLBuffer?
let postAttentionLayernorm: MTLBuffer?
let preFeedforwardLayernorm: MTLBuffer?
let postFeedforwardLayernorm: MTLBuffer?
/// Norm applied after the per-layer gating step.
let postPerLayerInputNorm: MTLBuffer?
let qNorm: MTLBuffer?
let kNorm: MTLBuffer?
/// Value norm; expected to be `nil` (the no-scale variant).
let vNorm: MTLBuffer?

// MARK: Quantized projections

let qProj: QuantizedWeights?
let kProj: QuantizedWeights?
let vProj: QuantizedWeights?
let oProj: QuantizedWeights?
let gateProj: QuantizedWeights?
let upProj: QuantizedWeights?
let downProj: QuantizedWeights?
let perLayerGate: QuantizedWeights?
let perLayerProjection: QuantizedWeights?

// MARK: Dense projections (bf16 models)

let qProjFloat: FloatWeights?
let kProjFloat: FloatWeights?
let vProjFloat: FloatWeights?
let oProjFloat: FloatWeights?
let gateProjFloat: FloatWeights?
let upProjFloat: FloatWeights?
let downProjFloat: FloatWeights?
let perLayerGateFloat: FloatWeights?
let perLayerProjectionFloat: FloatWeights?

// MARK: Mixture of experts

/// `true` when this layer uses the mixture-of-experts fields below.
let useMoE: Bool
let routerProj: QuantizedWeights?
let routerScale: Float
let perExpertScale: [Float]?
let expertGate: MoEExpertGroup?
let expertUp: MoEExpertGroup?
let expertDown: MoEExpertGroup?
/// Top-k routing width.
let topK: Int

// MARK: Attention sharing

/// K=V sharing for full attention layers (Gemma 4).
let kEqualsV: Bool

// MARK: Per-layer constants

let perLayerInput: MTLBuffer?
let perLayerInputScale: Float
let perLayerProjectionScale: Float
let layerScalar: Float

/// Builds a decoder layer from its configuration and loaded tensors.
///
/// Each projection may be supplied quantized (`qProj`, …) or dense (`qProjFloat`, …).
/// Every projection and MoE argument has a default, so callers pass only what the
/// checkpoint provides.
public init(config: E4BLayerConfig,
            layerIdx: Int = 0,
            inputLayernorm: MTLBuffer?,
            postAttentionLayernorm: MTLBuffer?,
            preFeedforwardLayernorm: MTLBuffer?,
            postFeedforwardLayernorm: MTLBuffer?,
            postPerLayerInputNorm: MTLBuffer?,
            qNorm: MTLBuffer?,
            kNorm: MTLBuffer?,
            vNorm: MTLBuffer?,
            qProj: QuantizedWeights? = nil,
            kProj: QuantizedWeights? = nil,
            vProj: QuantizedWeights? = nil,
            oProj: QuantizedWeights? = nil,
            gateProj: QuantizedWeights? = nil,
            upProj: QuantizedWeights? = nil,
            downProj: QuantizedWeights? = nil,
            perLayerGate: QuantizedWeights? = nil,
            perLayerProjection: QuantizedWeights? = nil,
            qProjFloat: FloatWeights? = nil,
            kProjFloat: FloatWeights? = nil,
            vProjFloat: FloatWeights? = nil,
            oProjFloat: FloatWeights? = nil,
            gateProjFloat: FloatWeights? = nil,
            upProjFloat: FloatWeights? = nil,
            downProjFloat: FloatWeights? = nil,
            perLayerGateFloat: FloatWeights? = nil,
            perLayerProjectionFloat: FloatWeights? = nil,
            perLayerInput: MTLBuffer?,
            perLayerInputScale: Float,
            perLayerProjectionScale: Float,
            layerScalar: Float,
            useMoE: Bool = false,
            routerProj: QuantizedWeights? = nil,
            routerScale: Float = 1.0,
            perExpertScale: [Float]? = nil,
            expertGate: MoEExpertGroup? = nil,
            expertUp: MoEExpertGroup? = nil,
            expertDown: MoEExpertGroup? = nil,
            topK: Int = 8,
            kEqualsV: Bool = false) {
    // Configuration and attention sharing
    self.config = config
    self.layerIdx = layerIdx
    self.kEqualsV = kEqualsV

    // Norms
    self.inputLayernorm = inputLayernorm
    self.postAttentionLayernorm = postAttentionLayernorm
    self.preFeedforwardLayernorm = preFeedforwardLayernorm
    self.postFeedforwardLayernorm = postFeedforwardLayernorm
    self.postPerLayerInputNorm = postPerLayerInputNorm
    self.qNorm = qNorm
    self.kNorm = kNorm
    self.vNorm = vNorm

    // Quantized projections
    self.qProj = qProj
    self.kProj = kProj
    self.vProj = vProj
    self.oProj = oProj
    self.gateProj = gateProj
    self.upProj = upProj
    self.downProj = downProj
    self.perLayerGate = perLayerGate
    self.perLayerProjection = perLayerProjection

    // Dense projections
    self.qProjFloat = qProjFloat
    self.kProjFloat = kProjFloat
    self.vProjFloat = vProjFloat
    self.oProjFloat = oProjFloat
    self.gateProjFloat = gateProjFloat
    self.upProjFloat = upProjFloat
    self.downProjFloat = downProjFloat
    self.perLayerGateFloat = perLayerGateFloat
    self.perLayerProjectionFloat = perLayerProjectionFloat

    // Per-layer constants
    self.perLayerInput = perLayerInput
    self.perLayerInputScale = perLayerInputScale
    self.perLayerProjectionScale = perLayerProjectionScale
    self.layerScalar = layerScalar

    // Mixture of experts
    self.useMoE = useMoE
    self.routerProj = routerProj
    self.routerScale = routerScale
    self.perExpertScale = perExpertScale
    self.expertGate = expertGate
    self.expertUp = expertUp
    self.expertDown = expertDown
    self.topK = topK
}
|
|
|
|
// ── Kernel dispatch helpers (optimized versions preferred) ──────────────────
|
|
|
|
/// Encodes the `rms_norm` kernel over `count` floats, reading `input` and writing `output`.
///
/// `weight` is bound at index 1 exactly as given, and may be `nil`. `count` and `eps` are
/// passed as a `UInt32` and a `Float`. The launch uses `count` threads in threadgroups
/// of at most 256 threads.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func rmsNorm(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
             input: MTLBuffer, weight: MTLBuffer?,
             output: MTLBuffer, count: Int, eps: Float) throws {
    let pipeline = try engine.pipeline(named: "rms_norm")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    defer { encoder.endEncoding() }

    encoder.setComputePipelineState(pipeline)
    let bindings: [MTLBuffer?] = [input, weight, output]
    for (slot, buffer) in bindings.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    var elementCount = UInt32(count)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 3)
    var epsilon = eps
    encoder.setBytes(&epsilon, length: MemoryLayout<Float>.size, index: 4)

    let threadsPerGroup = MTLSize(width: min(256, count), height: 1, depth: 1)
    encoder.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                            threadsPerThreadgroup: threadsPerGroup)
}
|
|
|
|
/// Encodes the `rms_norm_grouped` kernel over `count` floats, with `groupSize` passed to
/// the kernel (presumably each group of `groupSize` elements is normalized on its own).
///
/// Binds `input`, `weight` (may be `nil`), and `output` at indices 0–2. It binds `count`,
/// `groupSize` (as `UInt32`), and `eps` at indices 3–5, then launches `count` threads.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func groupedRmsNorm(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                    input: MTLBuffer, weight: MTLBuffer?,
                    output: MTLBuffer, count: Int,
                    groupSize: Int, eps: Float) throws {
    let pipeline = try engine.pipeline(named: "rms_norm_grouped")
    let encoder = cmdBuf.makeComputeCommandEncoder()!

    // Copies `value` into the argument table at `index`.
    func setConstant<T>(_ value: T, at index: Int) {
        var copy = value
        encoder.setBytes(&copy, length: MemoryLayout<T>.size, index: index)
    }

    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(input, offset: 0, index: 0)
    encoder.setBuffer(weight, offset: 0, index: 1)
    encoder.setBuffer(output, offset: 0, index: 2)
    setConstant(UInt32(count), at: 3)
    setConstant(UInt32(groupSize), at: 4)
    setConstant(eps, at: 5)

    let grid = MTLSize(width: count, height: 1, depth: 1)
    encoder.dispatchThreads(grid,
                            threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: count))
    encoder.endEncoding()
}
|
|
|
|
/// Encodes the `gelu_approx` kernel, mapping `count` floats from `input` into `output`.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func gelu(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
          input: MTLBuffer, output: MTLBuffer, count: Int) throws {
    let pipeline = try engine.pipeline(named: "gelu_approx")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    defer { encoder.endEncoding() }

    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(input, offset: 0, index: 0)
    encoder.setBuffer(output, offset: 0, index: 1)
    var elementCount = UInt32(count)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 2)

    let groupShape = engine.threadgroupSize1D(pipeline, count: count)
    encoder.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                            threadsPerThreadgroup: groupShape)
}
|
|
|
|
/// Encodes `output = weights · input` for a single token using group-quantized weights.
///
/// Binds `input`, the packed weights, scales, biases, and `output` at indices 0–4, and the
/// `inDim`, `outDim`, and `groupSize` constants (as `UInt32`) at indices 5–7. It then
/// launches one thread per output row (`weights.outDim` threads).
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func quantizedMatmul(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                     input: MTLBuffer,
                     weights: QuantizedWeights,
                     output: MTLBuffer) throws {
    // The former optimized branch was guarded by `if false` ("temporarily use fallback
    // kernel for testing"). It could never execute, so it was removed along with its
    // unused kernel-name selection.
    // NOTE(review): this kernel receives no `bits` argument, and `weights.bits` is not
    // consulted here. Confirm that 8-bit weights never reach this path, or restore a
    // dedicated 8-bit kernel (the removed branch referenced `quantized_matmul_8bit`).
    let pso = try engine.pipeline(named: "quantized_matmul")
    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    enc.setBuffer(input, offset: 0, index: 0)
    enc.setBuffer(weights.weight, offset: 0, index: 1)
    enc.setBuffer(weights.scales, offset: 0, index: 2)
    enc.setBuffer(weights.biases, offset: 0, index: 3)
    enc.setBuffer(output, offset: 0, index: 4)
    var inDim = UInt32(weights.inDim)
    enc.setBytes(&inDim, length: MemoryLayout<UInt32>.size, index: 5)
    var outDim = UInt32(weights.outDim)
    enc.setBytes(&outDim, length: MemoryLayout<UInt32>.size, index: 6)
    var groupSize = UInt32(weights.groupSize)  // quantization group size from the weights
    enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 7)
    let count = weights.outDim
    let tg = engine.threadgroupSize1D(pso, count: count)
    enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                        threadsPerThreadgroup: tg)
    enc.endEncoding()
}
|
|
|
|
/// Encodes a dense `Float32` matrix-vector product with the `matmul_f32` kernel.
///
/// Binds `input`, `weights.weight`, and `output` at indices 0–2. It binds the GEMM sizes
/// `M = 1` (single token), `K = weights.inDim`, and `N = weights.outDim` as `UInt32` at
/// indices 3–5, and launches one thread per output element.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func matmulFloat(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                 input: MTLBuffer,
                 weights: FloatWeights,
                 output: MTLBuffer) throws {
    let pipeline = try engine.pipeline(named: "matmul_f32")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    let buffers: [MTLBuffer] = [input, weights.weight, output]
    for (slot, buffer) in buffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    // M (single token), K (input features), N (output features).
    let dimensions = [1, weights.inDim, weights.outDim]
    for (offset, dimension) in dimensions.enumerated() {
        var value = UInt32(dimension)
        encoder.setBytes(&value, length: MemoryLayout<UInt32>.size, index: 3 + offset)
    }

    let rows = weights.outDim
    encoder.dispatchThreads(MTLSize(width: rows, height: 1, depth: 1),
                            threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: rows))
    encoder.endEncoding()
}
|
|
|
|
/// Encodes a matmul with whichever weight form is present, preferring quantized weights.
///
/// When both `weightsQ` and `weightsF` are `nil`, nothing is encoded and `output` is left
/// untouched.
///
/// - Throws: Any error from the selected matmul helper.
func matmulAny(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
               input: MTLBuffer,
               weightsQ: QuantizedWeights?,
               weightsF: FloatWeights?,
               output: MTLBuffer) throws {
    if let quantized = weightsQ {
        try quantizedMatmul(engine: engine, cmdBuf: cmdBuf, input: input,
                            weights: quantized, output: output)
        return
    }
    guard let dense = weightsF else { return }
    try matmulFloat(engine: engine, cmdBuf: cmdBuf, input: input,
                    weights: dense, output: output)
}
|
|
|
|
/// Encodes the `apply_rope_q` kernel on the query buffer `q` for token `position`.
///
/// `q` is the only buffer bound, so the rotation is presumably applied in place. The
/// kernel receives `nHeads`, `headDim`, `rotatedDim`, `ropeTheta`, `ropeScale`, and the
/// position, and `nHeads * (rotatedDim / 2)` threads are launched.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func applyRoPEQ(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                q: MTLBuffer, position: Int) throws {
    let pipeline = try engine.pipeline(named: "apply_rope_q")
    let encoder = cmdBuf.makeComputeCommandEncoder()!

    // Copies `value` into the argument table at `index`.
    func setConstant<T>(_ value: T, at index: Int) {
        var copy = value
        encoder.setBytes(&copy, length: MemoryLayout<T>.size, index: index)
    }

    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(q, offset: 0, index: 0)
    setConstant(UInt32(config.nHeads), at: 1)
    setConstant(UInt32(config.headDim), at: 2)
    setConstant(UInt32(config.rotatedDim), at: 3)
    setConstant(config.ropeTheta, at: 4)
    setConstant(config.ropeScale, at: 5)
    setConstant(Int32(position), at: 6)

    let threadCount = config.nHeads * (config.rotatedDim / 2)
    encoder.dispatchThreads(MTLSize(width: threadCount, height: 1, depth: 1),
                            threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: threadCount))
    encoder.endEncoding()
}
|
|
|
|
/// Encodes the `apply_rope_k` kernel on the key buffer `k` for token `position`.
///
/// Mirrors `applyRoPEQ`, but uses the key/value head count: the kernel receives
/// `nKvHeads`, `headDim`, `rotatedDim`, `ropeTheta`, `ropeScale`, and the position, and
/// `nKvHeads * (rotatedDim / 2)` threads are launched.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func applyRoPEK(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                k: MTLBuffer, position: Int) throws {
    let pipeline = try engine.pipeline(named: "apply_rope_k")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    defer { encoder.endEncoding() }

    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(k, offset: 0, index: 0)

    var kvHeadCount = UInt32(config.nKvHeads)
    var dimPerHead = UInt32(config.headDim)
    var rotated = UInt32(config.rotatedDim)
    var base = config.ropeTheta
    var positionScale = config.ropeScale
    var tokenPosition = Int32(position)
    encoder.setBytes(&kvHeadCount, length: MemoryLayout<UInt32>.size, index: 1)
    encoder.setBytes(&dimPerHead, length: MemoryLayout<UInt32>.size, index: 2)
    encoder.setBytes(&rotated, length: MemoryLayout<UInt32>.size, index: 3)
    encoder.setBytes(&base, length: MemoryLayout<Float>.size, index: 4)
    encoder.setBytes(&positionScale, length: MemoryLayout<Float>.size, index: 5)
    encoder.setBytes(&tokenPosition, length: MemoryLayout<Int32>.size, index: 6)

    let threadCount = config.nKvHeads * (config.rotatedDim / 2)
    let groupShape = engine.threadgroupSize1D(pipeline, count: threadCount)
    encoder.dispatchThreads(MTLSize(width: threadCount, height: 1, depth: 1),
                            threadsPerThreadgroup: groupShape)
}
|
|
|
|
/// Encodes sliding-window attention of `q` against the keys and values stored in `cache`,
/// writing the result to `attnBuf`.
///
/// The `sliding_attention_simd` kernel (softcapping removed) is preferred. It stages K and
/// V in two threadgroup buffers of `windowSize * nKvHeads * (headDim / 4) * 16` bytes each
/// and runs 8×16 threadgroups. Asking for more threadgroup memory than the device
/// provides is an invalid launch, so the SIMD kernel is used only when both buffers, plus
/// the pipeline's static threadgroup memory, fit in `maxThreadgroupMemoryLength`, and
/// 128 threads fit in `maxTotalThreadsPerThreadgroup`. A 512-position window already
/// needs ≥ 32 KiB per buffer for any `headDim >= 16`. Otherwise the `sliding_attention`
/// kernel is used.
///
/// - Throws: Any error from `engine.pipeline(named:)` when loading the fallback kernel.
func slidingAttention(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                      q: MTLBuffer, cache: KVCache,
                      position: Int) throws {
    if let pso = try? engine.pipeline(named: "sliding_attention_simd") {
        // Threadgroup memory for each of K and V (float4 = 16 bytes).
        let kvCacheSize = config.windowSize * config.nKvHeads * (config.headDim / 4) * 16
        let tg = MTLSize(width: 8, height: 16, depth: 1)  // Tune for cache efficiency
        let threadgroupBudget = pso.device.maxThreadgroupMemoryLength - pso.staticThreadgroupMemoryLength

        if 2 * kvCacheSize <= threadgroupBudget,
           tg.width * tg.height <= pso.maxTotalThreadsPerThreadgroup {
            let enc = cmdBuf.makeComputeCommandEncoder()!
            enc.setComputePipelineState(pso)
            enc.setBuffer(q, offset: 0, index: 0)
            enc.setBuffer(cache.buffer, offset: cache.keyBaseOffset, index: 1)
            enc.setBuffer(cache.buffer, offset: cache.valueBaseOffset, index: 2)
            enc.setBuffer(attnBuf, offset: 0, index: 3)
            var nHeads = UInt32(config.nHeads)
            enc.setBytes(&nHeads, length: MemoryLayout<UInt32>.size, index: 4)
            var nKvHeads = UInt32(config.nKvHeads)
            enc.setBytes(&nKvHeads, length: MemoryLayout<UInt32>.size, index: 5)
            var headDim = UInt32(config.headDim)
            enc.setBytes(&headDim, length: MemoryLayout<UInt32>.size, index: 6)
            var windowSize = UInt32(config.windowSize)
            enc.setBytes(&windowSize, length: MemoryLayout<UInt32>.size, index: 7)
            var off = Int32(position)
            enc.setBytes(&off, length: MemoryLayout<Int32>.size, index: 8)
            enc.setThreadgroupMemoryLength(kvCacheSize, index: 0)
            enc.setThreadgroupMemoryLength(kvCacheSize, index: 1)
            let grid = MTLSize(width: config.nHeads, height: config.headDim / 4, depth: 1)
            enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
            enc.endEncoding()
            return
        }
    }

    // Fallback: the scalar kernel needs no threadgroup staging.
    let pso = try engine.pipeline(named: "sliding_attention")
    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    enc.setBuffer(q, offset: 0, index: 0)
    enc.setBuffer(cache.buffer, offset: cache.keyBaseOffset, index: 1)
    enc.setBuffer(cache.buffer, offset: cache.valueBaseOffset, index: 2)
    enc.setBuffer(attnBuf, offset: 0, index: 3)
    var nHeads = UInt32(config.nHeads)
    enc.setBytes(&nHeads, length: MemoryLayout<UInt32>.size, index: 4)
    var nKvHeads = UInt32(config.nKvHeads)
    enc.setBytes(&nKvHeads, length: MemoryLayout<UInt32>.size, index: 5)
    var headDim = UInt32(config.headDim)
    enc.setBytes(&headDim, length: MemoryLayout<UInt32>.size, index: 6)
    var windowSize = UInt32(config.windowSize)
    enc.setBytes(&windowSize, length: MemoryLayout<UInt32>.size, index: 7)
    var off = Int32(position)
    enc.setBytes(&off, length: MemoryLayout<Int32>.size, index: 8)
    let grid = MTLSize(width: config.nHeads, height: config.headDim, depth: 1)
    let tg = engine.threadgroupSize2D(pso, grid: (config.nHeads, config.headDim))
    enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
    enc.endEncoding()
}
|
|
|
|
/// Encodes `sliding_attention_with_current`. It computes attention of `q` over the entries
/// in `cache` plus the current token's `curK` / `curV`, writing to `attnBuf`.
///
/// The cache capacity (`cache.maxLength`) is passed as the window size. The number of
/// cached entries is passed as `cache.currentLength`, so a layer that shares another
/// layer's cache reads every position the owner has already stored.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func slidingAttentionWithCurrent(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                                 q: MTLBuffer, cache: KVCache,
                                 curK: MTLBuffer, curV: MTLBuffer,
                                 position: Int) throws {
    let pipeline = try engine.pipeline(named: "sliding_attention_with_current")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    defer { encoder.endEncoding() }

    // Copies `value` into the argument table at `index`.
    func setConstant<T>(_ value: T, at index: Int) {
        var copy = value
        encoder.setBytes(&copy, length: MemoryLayout<T>.size, index: index)
    }

    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(q, offset: 0, index: 0)
    encoder.setBuffer(cache.buffer, offset: cache.keyBaseOffset, index: 1)
    encoder.setBuffer(cache.buffer, offset: cache.valueBaseOffset, index: 2)
    encoder.setBuffer(curK, offset: 0, index: 3)
    encoder.setBuffer(curV, offset: 0, index: 4)
    encoder.setBuffer(attnBuf, offset: 0, index: 5)

    setConstant(UInt32(config.nHeads), at: 6)
    setConstant(UInt32(config.nKvHeads), at: 7)
    setConstant(UInt32(config.headDim), at: 8)
    setConstant(UInt32(cache.maxLength), at: 9)       // window size = cache capacity
    setConstant(UInt32(cache.currentLength), at: 10)  // entries stored by the owner layer
    setConstant(Int32(position), at: 11)

    encoder.dispatchThreads(MTLSize(width: config.nHeads, height: config.headDim, depth: 1),
                            threadsPerThreadgroup: engine.threadgroupSize2D(pipeline, grid: (config.nHeads, config.headDim)))
}
|
|
|
|
// Output buffer that the attention helpers bind for their results. It is not allocated
// by the layer; it is assigned externally (from `ForwardTemps`) before the helpers run.
var attnBuf: MTLBuffer!

/// Encodes full (global) attention of `q` against the keys and values in `cache`,
/// writing the result to `attnBuf`.
///
/// `full_attention_simd` (no softcapping for text models) is preferred. It receives
/// `seqLen = position + 1` and stages K and V in two threadgroup buffers of
/// `seqLen * nKvHeads * (headDim / 4) * 16` bytes each. That requirement grows with every
/// token and soon exceeds the device's threadgroup memory, which would make the launch
/// invalid. The SIMD kernel is therefore used only while both buffers, plus the
/// pipeline's static threadgroup memory, fit in `maxThreadgroupMemoryLength`, and its
/// 8×16 threadgroup fits in `maxTotalThreadsPerThreadgroup`. Otherwise the
/// `full_attention` kernel is used.
///
/// - Throws: Any error from `engine.pipeline(named:)` when loading the fallback kernel.
func fullAttention(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                   q: MTLBuffer, cache: KVCache,
                   position: Int) throws {
    if let pso = try? engine.pipeline(named: "full_attention_simd") {
        let seqLen = position + 1  // attend over the current length, not the cache capacity
        // Threadgroup memory for each of K and V (float4 = 16 bytes).
        let kvCacheSize = seqLen * config.nKvHeads * (config.headDim / 4) * 16
        let tg = MTLSize(width: 8, height: 16, depth: 1)  // Tune for cache efficiency
        let threadgroupBudget = pso.device.maxThreadgroupMemoryLength - pso.staticThreadgroupMemoryLength

        if 2 * kvCacheSize <= threadgroupBudget,
           tg.width * tg.height <= pso.maxTotalThreadsPerThreadgroup {
            let enc = cmdBuf.makeComputeCommandEncoder()!
            enc.setComputePipelineState(pso)
            enc.setBuffer(q, offset: 0, index: 0)
            enc.setBuffer(cache.buffer, offset: cache.keyBaseOffset, index: 1)
            enc.setBuffer(cache.buffer, offset: cache.valueBaseOffset, index: 2)
            enc.setBuffer(attnBuf, offset: 0, index: 3)
            var nHeads = UInt32(config.nHeads)
            enc.setBytes(&nHeads, length: MemoryLayout<UInt32>.size, index: 4)
            var nKvHeads = UInt32(config.nKvHeads)
            enc.setBytes(&nKvHeads, length: MemoryLayout<UInt32>.size, index: 5)
            var headDim = UInt32(config.headDim)
            enc.setBytes(&headDim, length: MemoryLayout<UInt32>.size, index: 6)
            var seqLenArg = UInt32(seqLen)
            enc.setBytes(&seqLenArg, length: MemoryLayout<UInt32>.size, index: 7)
            enc.setThreadgroupMemoryLength(kvCacheSize, index: 0)
            enc.setThreadgroupMemoryLength(kvCacheSize, index: 1)
            let grid = MTLSize(width: config.nHeads, height: config.headDim / 4, depth: 1)
            enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
            enc.endEncoding()
            return
        }
    }

    // Fallback: the scalar kernel reads the cache directly and needs no staging.
    let pso = try engine.pipeline(named: "full_attention")
    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    enc.setBuffer(q, offset: 0, index: 0)
    enc.setBuffer(cache.buffer, offset: cache.keyBaseOffset, index: 1)
    enc.setBuffer(cache.buffer, offset: cache.valueBaseOffset, index: 2)
    enc.setBuffer(attnBuf, offset: 0, index: 3)
    var nHeads = UInt32(config.nHeads)
    enc.setBytes(&nHeads, length: MemoryLayout<UInt32>.size, index: 4)
    var nKvHeads = UInt32(config.nKvHeads)
    enc.setBytes(&nKvHeads, length: MemoryLayout<UInt32>.size, index: 5)
    var headDim = UInt32(config.headDim)
    enc.setBytes(&headDim, length: MemoryLayout<UInt32>.size, index: 6)
    var maxPos = UInt32(cache.maxLength)
    enc.setBytes(&maxPos, length: MemoryLayout<UInt32>.size, index: 7)
    var off = Int32(position)
    enc.setBytes(&off, length: MemoryLayout<Int32>.size, index: 8)
    let grid = MTLSize(width: config.nHeads, height: config.headDim, depth: 1)
    let tg = engine.threadgroupSize2D(pso, grid: (config.nHeads, config.headDim))
    enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
    enc.endEncoding()
}
|
|
|
|
/// Encodes `full_attention_with_current`. It computes attention of `q` over the entries in
/// `cache` plus the current token's `curK` / `curV`, writing to `attnBuf`.
///
/// The number of cached entries is passed as `cache.currentLength`, so a layer that
/// shares another layer's cache reads every position the owner has already stored.
///
/// - Throws: Any error from `engine.pipeline(named:)`.
func fullAttentionWithCurrent(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                              q: MTLBuffer, cache: KVCache,
                              curK: MTLBuffer, curV: MTLBuffer,
                              position: Int) throws {
    let pipeline = try engine.pipeline(named: "full_attention_with_current")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    defer { encoder.endEncoding() }
    encoder.setComputePipelineState(pipeline)

    // Slots 0–5: query, cached K region, cached V region, current K, current V, output.
    let bindings: [(buffer: MTLBuffer?, offset: Int)] = [
        (q, 0),
        (cache.buffer, cache.keyBaseOffset),
        (cache.buffer, cache.valueBaseOffset),
        (curK, 0),
        (curV, 0),
        (attnBuf, 0),
    ]
    for (slot, binding) in bindings.enumerated() {
        encoder.setBuffer(binding.buffer, offset: binding.offset, index: slot)
    }

    var heads = UInt32(config.nHeads)
    var kvHeads = UInt32(config.nKvHeads)
    var dimPerHead = UInt32(config.headDim)
    var cachedEntries = UInt32(cache.currentLength)
    var tokenPosition = Int32(position)
    encoder.setBytes(&heads, length: MemoryLayout<UInt32>.size, index: 6)
    encoder.setBytes(&kvHeads, length: MemoryLayout<UInt32>.size, index: 7)
    encoder.setBytes(&dimPerHead, length: MemoryLayout<UInt32>.size, index: 8)
    encoder.setBytes(&cachedEntries, length: MemoryLayout<UInt32>.size, index: 9)
    encoder.setBytes(&tokenPosition, length: MemoryLayout<Int32>.size, index: 10)

    let groupShape = engine.threadgroupSize2D(pipeline, grid: (config.nHeads, config.headDim))
    encoder.dispatchThreads(MTLSize(width: config.nHeads, height: config.headDim, depth: 1),
                            threadsPerThreadgroup: groupShape)
}
|
|
|
|
/// Encodes element-wise addition of `count` floats from `a` and `b` into `output`.
///
/// Prefers `eltwise_add_simd`, launched with `(count + 3) / 4` threads (four floats per
/// thread). It falls back to `eltwise_add` with `count` threads when the SIMD pipeline
/// cannot be created.
///
/// - Throws: Any error from `engine.pipeline(named:)` for the fallback kernel.
func eltwiseAdd(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                a: MTLBuffer, b: MTLBuffer,
                output: MTLBuffer, count: Int) throws {
    let vectorized = try? engine.pipeline(named: "eltwise_add_simd")
    let pipeline = try vectorized ?? engine.pipeline(named: "eltwise_add")
    let threadCount = vectorized == nil ? count : (count + 3) / 4

    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(a, offset: 0, index: 0)
    encoder.setBuffer(b, offset: 0, index: 1)
    encoder.setBuffer(output, offset: 0, index: 2)
    var elementCount = UInt32(count)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 3)
    encoder.dispatchThreads(MTLSize(width: threadCount, height: 1, depth: 1),
                            threadsPerThreadgroup: engine.threadgroupSize1D(pipeline, count: threadCount))
    encoder.endEncoding()
}
|
|
|
|
/// Encodes element-wise multiplication of `count` floats from `a` and `b` into `output`.
///
/// Each buffer is accessed starting at its byte offset (`aOffset`, `bOffset`,
/// `outputOffset`). The kernel is `eltwise_mul_simd` if available, launched with
/// `(count + 3) / 4` threads (four floats per thread). Otherwise it is `eltwise_mul`
/// with `count` threads. Both kernels receive the same offsets.
///
/// - Throws: Any error from `engine.pipeline(named:)` for the fallback kernel.
func eltwiseMul(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                a: MTLBuffer, aOffset: Int = 0,
                b: MTLBuffer, bOffset: Int = 0,
                output: MTLBuffer, outputOffset: Int = 0,
                count: Int) throws {
    let simdPipeline = try? engine.pipeline(named: "eltwise_mul_simd")
    let pso = try simdPipeline ?? engine.pipeline(named: "eltwise_mul")
    // The SIMD kernel processes a float4 per thread; the scalar kernel one float.
    let threads = simdPipeline == nil ? count : (count + 3) / 4

    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    // The caller's offsets must reach whichever kernel runs. Binding at offset 0 would
    // silently operate on the wrong elements.
    // NOTE(review): if the SIMD kernel loads float4 values, these offsets presumably need
    // 16-byte alignment — confirm against the kernel source.
    enc.setBuffer(a, offset: aOffset, index: 0)
    enc.setBuffer(b, offset: bOffset, index: 1)
    enc.setBuffer(output, offset: outputOffset, index: 2)
    var N = UInt32(count)
    enc.setBytes(&N, length: MemoryLayout<UInt32>.size, index: 3)
    enc.dispatchThreads(MTLSize(width: threads, height: 1, depth: 1),
                        threadsPerThreadgroup: engine.threadgroupSize1D(pso, count: threads))
    enc.endEncoding()
}
|
|
|
|
/// Encodes the `eltwise_add_scaled` kernel over `count` elements, one thread each.
///
/// Callers in this file use it as a scaled residual add (`scaleA: 1.0`,
/// `scaleB: layerScalar`). The formula `output = scaleA * a + scaleB * b` is
/// assumed from the argument names. NOTE(review): confirm against the shader source.
///
/// - Parameters:
///   - engine: Metal engine that provides the pipeline.
///   - cmdBuf: Command buffer to encode into; it is not committed here.
///   - a: First operand (bound at offset 0).
///   - scaleA: Coefficient passed for `a`.
///   - b: Second operand (bound at offset 0).
///   - scaleB: Coefficient passed for `b`.
///   - output: Destination buffer (may be the same buffer as `a`, as callers do).
///   - count: Number of elements.
/// - Throws: The error from `engine.pipeline(named:)` if the pipeline is unavailable.
func eltwiseAddScaled(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                      a: MTLBuffer, scaleA: Float,
                      b: MTLBuffer, scaleB: Float,
                      output: MTLBuffer, count: Int) throws {
    let pipeline = try engine.pipeline(named: "eltwise_add_scaled")
    let threadsPerGroup = engine.threadgroupSize1D(pipeline, count: count)
    let grid = MTLSize(width: count, height: 1, depth: 1)

    var weightA = scaleA
    var weightB = scaleB
    var elementCount = UInt32(count)

    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    // Argument table: 0 = a, 1 = scaleA, 2 = b, 3 = scaleB, 4 = output, 5 = N.
    encoder.setBuffer(a, offset: 0, index: 0)
    encoder.setBytes(&weightA, length: MemoryLayout<Float>.size, index: 1)
    encoder.setBuffer(b, offset: 0, index: 2)
    encoder.setBytes(&weightB, length: MemoryLayout<Float>.size, index: 3)
    encoder.setBuffer(output, offset: 0, index: 4)
    encoder.setBytes(&elementCount, length: MemoryLayout<UInt32>.size, index: 5)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: threadsPerGroup)
    encoder.endEncoding()
}
|
|
|
|
/// Encodes the dense FFN gate/up stage for one token. The gated intermediate
/// activations are written into `output` for the down projection that follows.
///
/// - Float (non-quantized) weights: runs `gateProjFloat` and `upProjFloat` as two
///   `matmulFloat` passes, applies `gelu` to the gate in place, then multiplies it
///   element-wise by the up projection. The up result goes to a transient scratch
///   buffer allocated on `output`'s device.
/// - Quantized weights: a single fused kernel reads both weight sets. It prefers
///   `quantized_matmul_gate_up_opt[_8bit]` (input cached in threadgroup memory,
///   uint4 loads, fixed 256-thread groups) and falls back to
///   `quantized_matmul_gate_up[_8bit]`.
///
/// Nothing is encoded when neither the float pair nor the quantized pair is loaded.
///
/// NOTE(review): the float path assumes the fused quantized kernel produces
/// `gelu(gate) * up`, matching the `gelu` helper used for per-layer gating.
/// Confirm against the shader source.
///
/// - Parameters:
///   - engine: Metal engine that provides the pipelines.
///   - cmdBuf: Command buffer to encode into; it is not committed here.
///   - input: Normalized hidden state (assumed Float32 `[inDim]`).
///   - output: Destination for the `[outDim]` intermediate activations.
/// - Throws: Pipeline-creation errors from the encoded kernels.
func fusedGateUp(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                 input: MTLBuffer,
                 output: MTLBuffer) throws {
    // Float path: separate matmuls for gate and up, then activation and gating.
    if let gf = gateProjFloat, let uf = upProjFloat {
        let rows = gf.outDim
        let scratchBytes = rows * MemoryLayout<Float>.stride
        // The down projection consumes the activated, gated vector. Writing only
        // the raw gate here would ignore `upProjFloat` and skip the nonlinearity.
        guard let upScratch = output.device.makeBuffer(length: scratchBytes,
                                                       options: .storageModePrivate) else {
            preconditionFailure("fusedGateUp: could not allocate \(scratchBytes)-byte up-projection buffer")
        }
        try matmulFloat(engine: engine, cmdBuf: cmdBuf, input: input, weights: gf, output: output)
        try matmulFloat(engine: engine, cmdBuf: cmdBuf, input: input, weights: uf, output: upScratch)
        try gelu(engine: engine, cmdBuf: cmdBuf, input: output, output: output, count: rows)
        try eltwiseMul(engine: engine, cmdBuf: cmdBuf,
                       a: output, b: upScratch, output: output, count: rows)
        return
    }

    // Quantized path: fused kernel.
    guard let gp = gateProj, let up = upProj else { return }

    let is8Bit = gp.bits == 8
    let optimizedPSO = try? engine.pipeline(
        named: is8Bit ? "quantized_matmul_gate_up_opt_8bit" : "quantized_matmul_gate_up_opt")
    // The older kernel is only looked up when the optimized one is unavailable.
    let pso = try optimizedPSO ?? engine.pipeline(
        named: is8Bit ? "quantized_matmul_gate_up_8bit" : "quantized_matmul_gate_up")

    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    // Both kernels share this argument layout.
    enc.setBuffer(input, offset: 0, index: 0)
    enc.setBuffer(gp.weight, offset: 0, index: 1)
    enc.setBuffer(gp.scales, offset: 0, index: 2)
    enc.setBuffer(gp.biases, offset: 0, index: 3)
    enc.setBuffer(up.weight, offset: 0, index: 4)
    enc.setBuffer(up.scales, offset: 0, index: 5)
    enc.setBuffer(up.biases, offset: 0, index: 6)
    enc.setBuffer(output, offset: 0, index: 7)
    var inDim = UInt32(gp.inDim)
    enc.setBytes(&inDim, length: MemoryLayout<UInt32>.size, index: 8)
    var outDim = UInt32(gp.outDim)
    enc.setBytes(&outDim, length: MemoryLayout<UInt32>.size, index: 9)
    var groupSize = UInt32(gp.groupSize)
    enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 10)

    let count = gp.outDim
    let tg: MTLSize
    if optimizedPSO != nil {
        // The optimized kernel caches the whole input vector in threadgroup memory.
        enc.setThreadgroupMemoryLength(gp.inDim * MemoryLayout<Float>.stride, index: 0)
        tg = MTLSize(width: 256, height: 1, depth: 1)
    } else {
        tg = engine.threadgroupSize1D(pso, count: count)
    }
    enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                        threadsPerThreadgroup: tg)
    enc.endEncoding()
}
|
|
|
|
// ── MoE forward helpers ──────────────────────
|
|
|
|
/// Quantized matmul for a specific expert slice within a 3D tensor.
///
/// Binds `expert.weight` at byte offset `weightStride * expertIdx` and
/// `scales`/`biases` at `scalesStride * expertIdx`, so the kernel only sees expert
/// `expertIdx`'s slice. Dispatches one thread per output row (`expertOutDim`).
///
/// Prefers `quantized_matmul_simd[_8bit]` (input cached in threadgroup memory,
/// 256-thread groups). If that pipeline cannot be created, it logs an error and
/// falls back to `quantized_matmul_seq` with the same argument layout.
///
/// - Parameters:
///   - engine: Metal engine that provides the pipelines.
///   - cmdBuf: Command buffer to encode into; it is not committed here.
///   - input: Input activations (assumed Float32 `[expertInDim]`).
///   - expert: Stacked quantized weights for one projection of all experts.
///   - expertIdx: Index of the expert whose slice is used.
///   - output: Destination for the `expertOutDim` results.
/// - Throws: The pipeline-creation error for `quantized_matmul_seq` when the SIMD
///   kernel is also unavailable.
func quantizedMatmulExpert(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                           input: MTLBuffer,
                           expert: MoEExpertGroup, expertIdx: Int,
                           output: MTLBuffer) throws {
    let simdName = expert.bits == 8 ? "quantized_matmul_simd_8bit" : "quantized_matmul_simd"
    let simdPSO = try? engine.pipeline(named: simdName)
    if simdPSO == nil {
        print(" [ERROR] quantizedMatmulExpert: Shader \(simdName) not found! Falling back to quantized_matmul_seq")
    }
    let pso = try simdPSO ?? engine.pipeline(named: "quantized_matmul_seq")

    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    enc.setBuffer(input, offset: 0, index: 0)
    enc.setBuffer(expert.weight, offset: expert.weightStride * expertIdx, index: 1)
    enc.setBuffer(expert.scales, offset: expert.scalesStride * expertIdx, index: 2)
    enc.setBuffer(expert.biases, offset: expert.scalesStride * expertIdx, index: 3)
    enc.setBuffer(output, offset: 0, index: 4)
    var inDim = UInt32(expert.expertInDim)
    enc.setBytes(&inDim, length: MemoryLayout<UInt32>.size, index: 5)
    var outDim = UInt32(expert.expertOutDim)
    enc.setBytes(&outDim, length: MemoryLayout<UInt32>.size, index: 6)
    // Both kernels take the quantization group SIZE (elements per group), i.e.
    // inDim / numGroups. `expertInDim / 64` would be the group COUNT for
    // 64-element groups, which is why the fallback now shares this value.
    var groupSize = UInt32(expert.expertInDim / expert.numGroups)
    enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 7)

    let count = expert.expertOutDim
    let tg: MTLSize
    if simdPSO != nil {
        // The SIMD kernel caches the input vector in threadgroup memory.
        enc.setThreadgroupMemoryLength(expert.expertInDim * MemoryLayout<Float>.stride, index: 0)
        tg = MTLSize(width: 256, height: 1, depth: 1)
    } else {
        tg = engine.threadgroupSize1D(pso, count: count)
    }
    enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                        threadsPerThreadgroup: tg)
    enc.endEncoding()
}
|
|
|
|
/// Fused gate+up matmul for a specific expert slice.
///
/// Encodes `quantized_matmul_gate_up[_8bit]` over expert `expertIdx`'s gate and up
/// projections. Each projection is bound at `stride * expertIdx` within its stacked
/// tensor. The dispatch uses one thread per intermediate row (`gate.expertOutDim`)
/// and writes to `output`.
///
/// - Parameters:
///   - engine: Metal engine that provides the pipeline.
///   - cmdBuf: Command buffer to encode into; it is not committed here.
///   - input: Input activations (assumed Float32 `[expertInDim]`).
///   - gate: Stacked expert gate projections.
///   - up: Stacked expert up projections.
///   - expertIdx: Index of the expert whose slices are used.
///   - output: Destination for the `gate.expertOutDim` results.
/// - Throws: The pipeline-creation error if the kernel is unavailable.
func expertFusedGateUp(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                       input: MTLBuffer,
                       gate: MoEExpertGroup, up: MoEExpertGroup,
                       expertIdx: Int,
                       output: MTLBuffer) throws {
    let kernelName = gate.bits == 8 ? "quantized_matmul_gate_up_8bit" : "quantized_matmul_gate_up"
    let pso = try engine.pipeline(named: kernelName)

    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    enc.setBuffer(input, offset: 0, index: 0)
    enc.setBuffer(gate.weight, offset: gate.weightStride * expertIdx, index: 1)
    enc.setBuffer(gate.scales, offset: gate.scalesStride * expertIdx, index: 2)
    enc.setBuffer(gate.biases, offset: gate.scalesStride * expertIdx, index: 3)
    enc.setBuffer(up.weight, offset: up.weightStride * expertIdx, index: 4)
    enc.setBuffer(up.scales, offset: up.scalesStride * expertIdx, index: 5)
    enc.setBuffer(up.biases, offset: up.scalesStride * expertIdx, index: 6)
    enc.setBuffer(output, offset: 0, index: 7)
    var inDim = UInt32(gate.expertInDim)
    enc.setBytes(&inDim, length: MemoryLayout<UInt32>.size, index: 8)
    var outDim = UInt32(gate.expertOutDim)
    enc.setBytes(&outDim, length: MemoryLayout<UInt32>.size, index: 9)
    // Index 10 is the quantization group SIZE (elements per group); the dense
    // caller of this same kernel passes `QuantizedWeights.groupSize` there.
    // inDim / numGroups is that size, the same convention `expertFusedGateUpDown`
    // uses. `expertInDim / 64` would instead be the group COUNT for 64-element groups.
    var groupSize = UInt32(gate.expertInDim / gate.numGroups)
    enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 10)

    let count = gate.expertOutDim
    let tg = engine.threadgroupSize1D(pso, count: count)
    enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                        threadsPerThreadgroup: tg)
    enc.endEncoding()
}
|
|
|
|
/// Fused gate+up+down for a specific expert slice.
/// Replaces: fusedGateUp + blit + downMatmul + scaledAdd with a single kernel.
///
/// Binds expert `expertIdx`'s slices of the stacked gate/up/down tensors (buffer
/// offset `stride * expertIdx`) together with `accum` and the routing `weight`.
/// Per the summary above, the kernel performs the scaled add into `accum` itself.
/// Threadgroup scratch and the grid are both sized by the wider of the hidden and
/// intermediate dimensions.
///
/// - Parameters:
///   - engine: Metal engine that provides the pipeline.
///   - cmdBuf: Command buffer to encode into; it is not committed here.
///   - input: Input activations for the expert.
///   - gate: Stacked expert gate projections (also supplies the dimensions and group size).
///   - up: Stacked expert up projections.
///   - down: Stacked expert down projections.
///   - expertIdx: Index of the expert whose slices are used.
///   - accum: Accumulation buffer for the weighted expert output.
///   - weight: Routing weight for this expert.
/// - Throws: The pipeline-creation error if the kernel is unavailable.
func expertFusedGateUpDown(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                           input: MTLBuffer,
                           gate: MoEExpertGroup, up: MoEExpertGroup, down: MoEExpertGroup,
                           expertIdx: Int,
                           accum: MTLBuffer, weight: Float) throws {
    let pipeline = try engine.pipeline(named: gate.bits == 8
                                       ? "quantized_matmul_gate_up_down_8bit"
                                       : "quantized_matmul_gate_up_down")

    // Buffer slots 0-10, in order: input, gate (w, s, b), up (w, s, b), down (w, s, b), accum.
    let bufferBindings: [(buffer: MTLBuffer?, offset: Int)] = [
        (input, 0),
        (gate.weight, gate.weightStride * expertIdx),
        (gate.scales, gate.scalesStride * expertIdx),
        (gate.biases, gate.scalesStride * expertIdx),
        (up.weight, up.weightStride * expertIdx),
        (up.scales, up.scalesStride * expertIdx),
        (up.biases, up.scalesStride * expertIdx),
        (down.weight, down.weightStride * expertIdx),
        (down.scales, down.scalesStride * expertIdx),
        (down.biases, down.scalesStride * expertIdx),
        (accum, 0),
    ]

    var hiddenSize = UInt32(gate.expertInDim)
    var moeIntermediate = UInt32(gate.expertOutDim)
    var groupSize = UInt32(gate.expertInDim / gate.numGroups)
    var routingWeight = weight

    // One thread per element of the wider dimension; scratch holds that many floats.
    let width = max(gate.expertInDim, gate.expertOutDim)

    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)
    for (slot, binding) in bufferBindings.enumerated() {
        encoder.setBuffer(binding.buffer, offset: binding.offset, index: slot)
    }
    encoder.setBytes(&hiddenSize, length: MemoryLayout<UInt32>.size, index: 11)
    encoder.setBytes(&moeIntermediate, length: MemoryLayout<UInt32>.size, index: 12)
    encoder.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 13)
    encoder.setBytes(&routingWeight, length: MemoryLayout<Float>.size, index: 14)
    encoder.setThreadgroupMemoryLength(width * 4, index: 0)
    encoder.dispatchThreads(MTLSize(width: width, height: 1, depth: 1),
                            threadsPerThreadgroup: MTLSize(width: 256, height: 1, depth: 1))
    encoder.endEncoding()
}
|
|
|
|
/// Encodes the fused MoE "mega kernel" for one token: router matmul, softmax,
/// top-k selection and the selected experts' gate/up/down, accumulated into `accum`.
///
/// The full stacked router/gate/up/down tensors are bound at offset 0, so the kernel
/// indexes experts itself. Threadgroup memory is sized for
/// `max(hiddenSize, moeIntermediate)` floats of scratch, plus `numExperts` router
/// logits and two `topK`-length arrays. The two arrays are presumably the selected
/// indices and weights; confirm against the shader.
///
/// - Parameters:
///   - engine: Metal engine that provides the pipeline.
///   - cmdBuf: Command buffer to encode into; it is not committed here.
///   - input: Token hidden state fed to the router and the experts.
///   - router: Router projection weights.
///   - gate: Stacked expert gate projections (also supplies dimensions/expert count).
///   - up: Stacked expert up projections.
///   - down: Stacked expert down projections.
///   - accum: Accumulator for the weighted expert sum.
/// - Returns: `false` when the `moe_mega_kernel` pipeline is unavailable. Nothing is
///   encoded in that case, so the caller can route on the CPU. Returns `true` once
///   the dispatch has been encoded.
func moeMegaKernel(engine: MarkBaseEngine, cmdBuf: MTLCommandBuffer,
                   input: MTLBuffer,
                   router: QuantizedWeights,
                   gate: MoEExpertGroup, up: MoEExpertGroup, down: MoEExpertGroup,
                   accum: MTLBuffer) throws -> Bool {
    guard let pso = try? engine.pipeline(named: "moe_mega_kernel") else { return false }

    let enc = cmdBuf.makeComputeCommandEncoder()!
    enc.setComputePipelineState(pso)
    enc.setBuffer(input, offset: 0, index: 0)
    // Router: full 3D weight buffer [numExperts * outDim, inDimPacked]
    enc.setBuffer(router.weight, offset: 0, index: 1)
    enc.setBuffer(router.scales, offset: 0, index: 2)
    enc.setBuffer(router.biases, offset: 0, index: 3)
    // Gate: full 3D weight buffer [numExperts * expertOutDim, expertInDimPacked]
    enc.setBuffer(gate.weight, offset: 0, index: 4)
    enc.setBuffer(gate.scales, offset: 0, index: 5)
    enc.setBuffer(gate.biases, offset: 0, index: 6)
    // Up: full 3D weight buffer
    enc.setBuffer(up.weight, offset: 0, index: 7)
    enc.setBuffer(up.scales, offset: 0, index: 8)
    enc.setBuffer(up.biases, offset: 0, index: 9)
    // Down: full 3D weight buffer [numExperts * expertOutDim, expertInDimPacked]
    enc.setBuffer(down.weight, offset: 0, index: 10)
    enc.setBuffer(down.scales, offset: 0, index: 11)
    enc.setBuffer(down.biases, offset: 0, index: 12)
    enc.setBuffer(accum, offset: 0, index: 13)

    // Scalar arguments. The locals are named so that none shadows the
    // `routerScale` / `topK` properties they are read from. `var topK = UInt32(topK)`
    // referenced the very name it was declaring.
    var hiddenSize = UInt32(gate.expertInDim)
    enc.setBytes(&hiddenSize, length: MemoryLayout<UInt32>.size, index: 14)
    var moeIntermediate = UInt32(gate.expertOutDim)
    enc.setBytes(&moeIntermediate, length: MemoryLayout<UInt32>.size, index: 15)
    var numExperts = UInt32(gate.numExperts)
    enc.setBytes(&numExperts, length: MemoryLayout<UInt32>.size, index: 16)
    var routerScaleArg = routerScale
    enc.setBytes(&routerScaleArg, length: MemoryLayout<Float>.size, index: 17)
    var topKArg = UInt32(topK)
    enc.setBytes(&topKArg, length: MemoryLayout<UInt32>.size, index: 18)

    let count = max(gate.expertInDim, gate.expertOutDim)
    let logitStorage = gate.numExperts + 2 * Int(topKArg)
    // Metal requires the threadgroup memory length to be a multiple of 16 bytes;
    // the sum of heterogeneous terms above is not guaranteed to be one.
    let tgMemSize = ((count + logitStorage) * MemoryLayout<Float>.stride + 15) / 16 * 16
    enc.setThreadgroupMemoryLength(tgMemSize, index: 0)

    let tg = MTLSize(width: 256, height: 1, depth: 1)
    enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                        threadsPerThreadgroup: tg)
    enc.endEncoding()
    return true
}
|
|
|
|
/// Encodes the mixture-of-experts FFN for one token and adds its output to `input`.
///
/// 1. Copies `ns` into `temps.io`, the experts' input, kept apart from the
///    accumulator. Zeroes the accumulator `temps.h`.
/// 2. If `moeMegaKernel` can encode the fused router + softmax + top-k + experts
///    dispatch, everything runs on `cmdBuf` with no CPU round trip. Otherwise:
///    - the router runs on a separate command buffer that is committed and waited on;
///    - the CPU applies `routerScale`, a softmax and a top-k selection;
///    - each chosen expert is encoded on `cmdBuf` via `expertFusedGateUpDown`, with
///      its weight renormalized over the selection.
/// 3. Adds `temps.h` to `input`, scaled by `layerScalar` when it is not 1.
///
/// - Important: The CPU-routing fallback reads `ns` on its own command buffer,
///   which executes before `cmdBuf` is committed. Any GPU work that writes `ns`
///   must therefore already have completed. `forward` guarantees this by
///   committing its attention work first whenever `moe_mega_kernel` is unavailable.
/// - Parameters:
///   - input: Residual stream (`hiddenSize` elements), updated in place.
///   - ns: Pre-FFN normalized hidden state fed to the router and experts.
///   - temps: Scratch buffers (`io`, `h`; `gate` receives router logits on the fallback).
///   - cmdBuf: Command buffer for the expert and residual work; not committed here.
///   - engine: Metal engine providing pipelines and the command queue.
/// - Throws: Pipeline-creation errors from the encoded kernels.
func moeForward(input: MTLBuffer, ns: MTLBuffer,
                temps: ForwardTemps,
                cmdBuf: MTLCommandBuffer,
                engine: MarkBaseEngine) throws {
    guard let router = routerProj,
          let eGate = expertGate, let eUp = expertUp, let eDown = expertDown
    else { return }

    let hs = config.hiddenSize
    let hiddenBytes = hs * MemoryLayout<Float>.stride

    // ── Step 1: Copy MoE input ns → io (preserve from overwrite by expert output)
    //    and zero the accumulation buffer temps.h ──
    let moeInput = temps.io
    let blit = cmdBuf.makeBlitCommandEncoder()!
    blit.copy(from: ns, sourceOffset: 0,
              to: moeInput, destinationOffset: 0,
              size: hiddenBytes)
    blit.fill(buffer: temps.h, range: 0..<hiddenBytes, value: 0)
    blit.endEncoding()

    // ── Step 2: Try GPU mega-kernel (fused router + softmax + topK + experts) ──
    let encodedOnGPU = try moeMegaKernel(engine: engine, cmdBuf: cmdBuf,
                                         input: moeInput,
                                         router: router, gate: eGate, up: eUp, down: eDown,
                                         accum: temps.h)
    if !encodedOnGPU {
        // ── Fallback: CPU routing ──
        // The router runs on its own command buffer so its logits can be read back
        // now. It reads `ns` directly: the ns → io copy above is only encoded on the
        // not-yet-committed `cmdBuf`, so `moeInput` would be stale at this point.
        let numExperts = eGate.numExperts
        let routerCmdBuf = engine.commandQueue.makeCommandBuffer()!
        try quantizedMatmul(engine: engine, cmdBuf: routerCmdBuf,
                            input: ns, weights: router,
                            output: temps.gate)
        routerCmdBuf.commit()
        routerCmdBuf.waitUntilCompleted() // CPU read required

        // CPU softmax over the scaled logits (max-subtracted for numerical stability).
        let logits = engine.readFloats(from: temps.gate, count: numExperts).map { $0 * routerScale }
        let maxLogit = logits.max() ?? 0
        var probs = logits.map { exp($0 - maxLogit) }
        let total = probs.reduce(0, +)
        if total > 0 {
            probs = probs.map { $0 / total }
        }

        // Top-k experts by probability. The selection is deliberately not named
        // `topK`, which would shadow the property read on the next line.
        let k = min(topK, numExperts)
        let selected = probs.enumerated().sorted { $0.element > $1.element }.prefix(k)
        let selectedSum = selected.reduce(Float(0)) { $0 + $1.element }
        // Degenerate routing (all-zero or NaN probabilities): the accumulator is
        // still zero, so returning here leaves `input` unchanged.
        guard selectedSum > 0 else { return }

        // Encode the experts on the caller's cmdBuf (batched with the rest of the
        // layer). Each is weighted by its probability renormalized over the selection.
        for (expertIdx, probability) in selected {
            try expertFusedGateUpDown(engine: engine, cmdBuf: cmdBuf,
                                      input: moeInput,
                                      gate: eGate, up: eUp, down: eDown,
                                      expertIdx: expertIdx,
                                      accum: temps.h, weight: probability / selectedSum)
        }
    }

    // ── Step 3: Residual: input += moe_output (temps.h) scaled by layerScalar ──
    if layerScalar != 1.0 {
        try eltwiseAddScaled(engine: engine, cmdBuf: cmdBuf,
                             a: input, scaleA: 1.0,
                             b: temps.h, scaleB: layerScalar,
                             output: input, count: hs)
    } else {
        try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                       a: input, b: temps.h,
                       output: input, count: hs)
    }
}

// ── Main forward ─────────────────────────────

/// Run one E4B layer forward pass.
///
/// Encodes attention, the FFN (mixture-of-experts when `useMoE`, dense otherwise)
/// and optional per-layer gating, then commits and blocks until the GPU finishes.
///
/// - Parameters:
///   - input: [hiddenSize] — **modified in-place** to become the output
///   - position: absolute token position (0-based)
///   - kvCache: KV cache read by attention; written at `position` when `shouldStoreKV` is true
///   - shouldStoreKV: compute K,V and write into kvCache when true
///   - temps: pre-allocated scratch buffers
///   - engine: Metal engine
///   - perLayerInput: optional per-layer input; gating runs only when it and the
///     per-layer gate/projection weights (quantized or float) are present
///   - perLayerInputOffset: byte offset of this layer's slice within `perLayerInput`
/// - Throws: Pipeline-creation errors from the encoded kernels.
public func forward(input: MTLBuffer, position: Int,
                    kvCache: KVCache,
                    shouldStoreKV: Bool,
                    temps: ForwardTemps,
                    engine: MarkBaseEngine,
                    perLayerInput: MTLBuffer? = nil,
                    perLayerInputOffset: Int = 0) throws {
    self.attnBuf = temps.attn

    if useMoE {
        var cmdBuf = engine.commandQueue.makeCommandBuffer()!
        try attentionForward(input: input, position: position,
                             kvCache: kvCache, shouldStoreKV: shouldStoreKV,
                             temps: temps, engine: engine, cmdBuf: cmdBuf)

        // With `moe_mega_kernel`, the whole layer stays in one command buffer with no
        // CPU round trip. Without it, `moeForward` routes on the CPU from `temps.ns`
        // via a separate command buffer. The attention work that produces
        // `temps.ns` must therefore be committed and finished first.
        let hasMegaKernel = (try? engine.pipeline(named: "moe_mega_kernel")) != nil
        if !hasMegaKernel {
            cmdBuf.commit()
            cmdBuf.waitUntilCompleted()
            cmdBuf = engine.commandQueue.makeCommandBuffer()!
        }

        try moeForward(input: input, ns: temps.ns, temps: temps,
                       cmdBuf: cmdBuf, engine: engine)

        try postFfnForward(input: input, temps: temps, engine: engine,
                           cmdBuf: cmdBuf,
                           perLayerInput: perLayerInput,
                           perLayerInputOffset: perLayerInputOffset)

        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
    } else {
        // Dense path: unified flow for all positions
        let cmdBuf = engine.commandQueue.makeCommandBuffer()!
        let hs = config.hiddenSize
        try attentionForward(input: input, position: position,
                             kvCache: kvCache, shouldStoreKV: shouldStoreKV,
                             temps: temps, engine: engine, cmdBuf: cmdBuf)

        // FFN: gate+up fused → down → residual (scaled by layerScalar)
        try fusedGateUp(engine: engine, cmdBuf: cmdBuf,
                        input: temps.ns, output: temps.gate)
        try matmulAny(engine: engine, cmdBuf: cmdBuf,
                      input: temps.gate, weightsQ: downProj, weightsF: downProjFloat, output: temps.h)
        if layerScalar != 1.0 {
            try eltwiseAddScaled(engine: engine, cmdBuf: cmdBuf,
                                 a: input, scaleA: 1.0,
                                 b: temps.h, scaleB: layerScalar,
                                 output: input, count: hs)
        } else {
            try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                           a: input, b: temps.h,
                           output: input, count: hs)
        }

        // Per-layer gating for dense path. Either weight format qualifies. Binding
        // only the quantized tensors with `if let` would skip gating for layers whose
        // per-layer weights are float-only.
        let hasPerLayerGate = perLayerGate != nil || perLayerGateFloat != nil
        let hasPerLayerProjection = perLayerProjection != nil || perLayerProjectionFloat != nil
        if hasPerLayerGate, hasPerLayerProjection, let pl = perLayerInput {
            try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                        input: input, weight: postFeedforwardLayernorm,
                        output: temps.h, count: hs, eps: rmsNormEps)
            try matmulAny(engine: engine, cmdBuf: cmdBuf,
                          input: temps.h, weightsQ: perLayerGate, weightsF: perLayerGateFloat,
                          output: temps.gating)
            try gelu(engine: engine, cmdBuf: cmdBuf,
                     input: temps.gating, output: temps.gating, count: 256)
            try eltwiseMul(engine: engine, cmdBuf: cmdBuf,
                           a: temps.gating, aOffset: 0,
                           b: pl, bOffset: perLayerInputOffset,
                           output: temps.gating, outputOffset: 0,
                           count: 256)
            try matmulAny(engine: engine, cmdBuf: cmdBuf,
                          input: temps.gating, weightsQ: perLayerProjection,
                          weightsF: perLayerProjectionFloat,
                          output: temps.h)
            if let ppn = postPerLayerInputNorm {
                try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                            input: temps.h, weight: ppn,
                            output: temps.h, count: hs, eps: rmsNormEps)
            }
            try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                           a: input, b: temps.h,
                           output: input, count: hs)
        }

        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
    }
}
|
|
|
|
// ── Attention forward (steps 1-13) ──

/// Encodes the attention half of the layer (steps 1-13) onto `cmdBuf`; nothing is
/// committed here.
///
/// Once the command buffer executes:
/// - `input` has received the attention residual (the O-projection output, scaled by
///   `layerScalar` when it is not 1);
/// - `temps.h` holds `post_attention_layernorm(input)`;
/// - `temps.ns` holds `pre_feedforward_layernorm(temps.h)`, the FFN input.
///
/// Scratch usage:
/// - `temps.q` → `temps.ns` for Q (raw → normed + RoPE);
/// - `temps.k` → `temps.up` for K (raw → normed + RoPE);
/// - `temps.v` (→ `temps.gate` when `vNorm` is set) for V;
/// - `temps.attn` is read as the attention output (presumably written by the
///   attention kernels via `attnBuf`).
///
/// - Parameters:
///   - input: Residual stream, updated in place by step 11.
///   - position: Absolute token position used for RoPE and the KV cache.
///   - kvCache: Cache read by attention; written at `position` when `shouldStoreKV`.
///   - shouldStoreKV: Store this token's K/V and attend from the cache when true;
///     otherwise attend with the current K/V passed alongside the cache.
///   - temps: Pre-allocated scratch buffers.
///   - engine: Metal engine that provides pipelines.
///   - cmdBuf: Command buffer to encode into.
/// - Throws: Pipeline-creation errors from the encoded kernels.
private func attentionForward(input: MTLBuffer, position: Int,
                              kvCache: KVCache,
                              shouldStoreKV: Bool,
                              temps: ForwardTemps,
                              engine: MarkBaseEngine,
                              cmdBuf: MTLCommandBuffer) throws {
    let hs = config.hiddenSize
    let kvCount = config.nKvHeads * config.headDim
    let kvBytes = kvCount * MemoryLayout<Float>.stride

    // ── 1. input_layernorm(x) → temps.h ──
    try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                input: input, weight: inputLayernorm,
                output: temps.h, count: hs, eps: rmsNormEps)

    // ── 2. Q = q_proj(temps.h) → temps.q ──
    try matmulAny(engine: engine, cmdBuf: cmdBuf,
                  input: temps.h, weightsQ: qProj, weightsF: qProjFloat, output: temps.q)

    // ── 3. Q = q_norm(Q) → ns (per-head RMSNorm) ──
    try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
                       input: temps.q, weight: qNorm,
                       output: temps.ns,
                       count: config.nHeads * config.headDim,
                       groupSize: config.headDim, eps: rmsNormEps)

    // ── 4. RoPE(Q) on ns ──
    try applyRoPEQ(engine: engine, cmdBuf: cmdBuf,
                   q: temps.ns, position: position)

    // ── 5. K,V projections ──
    try matmulAny(engine: engine, cmdBuf: cmdBuf,
                  input: temps.h, weightsQ: kProj, weightsF: kProjFloat, output: temps.k)
    if vProj != nil || vProjFloat != nil {
        // Either weight format suffices. Requiring both quantized AND float V weights
        // (as a combined `if let` would) skips the V projection for every layer that
        // only loads one of them.
        try matmulAny(engine: engine, cmdBuf: cmdBuf,
                      input: temps.h, weightsQ: vProj, weightsF: vProjFloat, output: temps.v)
    } else if kEqualsV {
        // `kEqualsV`: V reuses the raw (pre-norm) K projection.
        let blit = cmdBuf.makeBlitCommandEncoder()!
        blit.copy(from: temps.k, sourceOffset: 0,
                  to: temps.v, destinationOffset: 0,
                  size: kvBytes)
        blit.endEncoding()
    } else {
        // No V weights: zero V. A fill is used instead of `0 * v + 0 * v`, which
        // would turn any NaN/Inf left in the scratch buffer into NaN.
        let blit = cmdBuf.makeBlitCommandEncoder()!
        blit.fill(buffer: temps.v, range: 0..<kvBytes, value: 0)
        blit.endEncoding()
    }

    // ── 6. K,V norms ──
    try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
                       input: temps.k, weight: kNorm,
                       output: temps.up,
                       count: kvCount,
                       groupSize: config.headDim, eps: rmsNormEps)
    if let vn = vNorm {
        try groupedRmsNorm(engine: engine, cmdBuf: cmdBuf,
                           input: temps.v, weight: vn,
                           output: temps.gate,
                           count: kvCount,
                           groupSize: config.headDim, eps: rmsNormEps)
    }

    // ── 7. RoPE(K) ──
    try applyRoPEK(engine: engine, cmdBuf: cmdBuf,
                   k: temps.up, position: position)

    // ── 8. Store K,V ──
    let curK = temps.up
    let curV = vNorm != nil ? temps.gate : temps.v
    if shouldStoreKV {
        kvCache.store(key: curK, keySrcOffset: 0,
                      value: curV, valueSrcOffset: 0,
                      position: position, commandBuffer: cmdBuf)
    }

    // ── 9. Attention ──
    if config.isSliding {
        if shouldStoreKV {
            try slidingAttention(engine: engine, cmdBuf: cmdBuf,
                                 q: temps.ns, cache: kvCache, position: position)
        } else {
            try slidingAttentionWithCurrent(engine: engine, cmdBuf: cmdBuf,
                                            q: temps.ns, cache: kvCache,
                                            curK: curK, curV: curV,
                                            position: position)
        }
    } else {
        if shouldStoreKV {
            try fullAttention(engine: engine, cmdBuf: cmdBuf,
                              q: temps.ns, cache: kvCache, position: position)
        } else {
            try fullAttentionWithCurrent(engine: engine, cmdBuf: cmdBuf,
                                         q: temps.ns, cache: kvCache,
                                         curK: curK, curV: curV,
                                         position: position)
        }
    }

    // ── 10. O projection ──
    try matmulAny(engine: engine, cmdBuf: cmdBuf,
                  input: temps.attn, weightsQ: oProj, weightsF: oProjFloat, output: temps.h)

    // ── 11. Residual 1 (scaled by layerScalar) ──
    if layerScalar != 1.0 {
        try eltwiseAddScaled(engine: engine, cmdBuf: cmdBuf,
                             a: input, scaleA: 1.0,
                             b: temps.h, scaleB: layerScalar,
                             output: input, count: hs)
    } else {
        try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                       a: input, b: temps.h,
                       output: input, count: hs)
    }

    // ── 12. post_attention_layernorm → temps.h ──
    try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                input: input, weight: postAttentionLayernorm,
                output: temps.h, count: hs, eps: rmsNormEps)

    // ── 13. pre_feedforward_layernorm → ns ──
    try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                input: temps.h, weight: preFeedforwardLayernorm,
                output: temps.ns, count: hs, eps: rmsNormEps)
}
|
|
|
|
// ── Post-FFN forward (steps 17-19) ──

/// Encodes the post-FFN stage (steps 17-19) onto `cmdBuf`; nothing is committed here.
///
/// 1. Normalizes `input` with `post_feedforward_layernorm` into `temps.h`.
/// 2. Optional per-layer gating. It runs when `perLayerInput` and the per-layer
///    gate/projection weights (quantized or float) are present, and replaces
///    `temps.h` with `projection(gelu(gate(temps.h)) * perLayerInput[slice])`,
///    optionally normalized by `postPerLayerInputNorm`.
/// 3. Adds `temps.h` into `input` in either case.
///
/// - Parameters:
///   - input: Residual stream, updated in place.
///   - temps: Scratch buffers (`h`, `gating`).
///   - engine: Metal engine that provides pipelines.
///   - cmdBuf: Command buffer to encode into.
///   - perLayerInput: Optional per-layer input buffer.
///   - perLayerInputOffset: Byte offset of this layer's slice within `perLayerInput`.
/// - Throws: Pipeline-creation errors from the encoded kernels.
private func postFfnForward(input: MTLBuffer, temps: ForwardTemps,
                            engine: MarkBaseEngine,
                            cmdBuf: MTLCommandBuffer,
                            perLayerInput: MTLBuffer?,
                            perLayerInputOffset: Int) throws {
    let hs = config.hiddenSize

    // ── 17. post_feedforward_layernorm → temps.h ──
    try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                input: input, weight: postFeedforwardLayernorm,
                output: temps.h, count: hs, eps: rmsNormEps)

    // ── 18. Per-layer gating (optional) ──
    // Accept either weight format. Binding only the quantized tensors with `if let`
    // would silently skip gating for layers whose per-layer weights are float-only.
    let hasGate = perLayerGate != nil || perLayerGateFloat != nil
    let hasProjection = perLayerProjection != nil || perLayerProjectionFloat != nil
    if hasGate, hasProjection, let pl = perLayerInput {
        try matmulAny(engine: engine, cmdBuf: cmdBuf,
                      input: temps.h, weightsQ: perLayerGate, weightsF: perLayerGateFloat,
                      output: temps.gating)
        try gelu(engine: engine, cmdBuf: cmdBuf,
                 input: temps.gating, output: temps.gating, count: 256)

        try eltwiseMul(engine: engine, cmdBuf: cmdBuf,
                       a: temps.gating, aOffset: 0,
                       b: pl, bOffset: perLayerInputOffset,
                       output: temps.gating, outputOffset: 0,
                       count: 256)

        try matmulAny(engine: engine, cmdBuf: cmdBuf,
                      input: temps.gating, weightsQ: perLayerProjection,
                      weightsF: perLayerProjectionFloat,
                      output: temps.h)

        if let ppn = postPerLayerInputNorm {
            try rmsNorm(engine: engine, cmdBuf: cmdBuf,
                        input: temps.h, weight: ppn,
                        output: temps.h, count: hs, eps: rmsNormEps)
        }
    }

    // ── 19. Residual: input += temps.h (the same add for both cases above) ──
    try eltwiseAdd(engine: engine, cmdBuf: cmdBuf,
                   a: input, b: temps.h,
                   output: input, count: hs)
}
|
|
}
|