Files
markbaseengine/Sources/MarkBase/Vision/VisionTower.swift
T
MarkBase Admin 96fe213bc4
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: add E4B multimodal test, fix VisionTower missing groupSize
2026-07-06 02:53:49 +08:00

331 lines
15 KiB
Swift

import Metal
/// GPU implementation of the vision encoder.
///
/// `forward` encodes, into a single Metal command buffer: the input projection,
/// the position-embedding add, one transformer layer per entry in
/// `weights.layers`, and a final quantized projection from `config.hiddenSize`
/// to `textHiddenSize`.
///
/// Activations live in scratch buffers allocated once in `init`, each sized for
/// `maxPatches` rows of 4-byte elements. Every kernel except the gate multiply
/// writes to a buffer distinct from the buffers it reads, so results never
/// depend on the order in which threads of one dispatch run.
///
/// The scratch buffers are reused by every call, so a single instance must not
/// run `forward` concurrently.
public final class VisionTower {

    /// Errors thrown while building or running the tower.
    public enum VisionTowerError: Error, Equatable {
        /// `makeBuffer` returned nil for a scratch buffer of `length` bytes.
        case bufferAllocationFailed(length: Int)
        /// The command queue could not create a command buffer.
        case commandBufferUnavailable
        /// A command buffer could not create a compute command encoder.
        case encoderUnavailable
        /// `numPatches` was negative or larger than the capacity chosen in `init`.
        case patchCountOutOfRange(requested: Int, maximum: Int)
        /// An `init` parameter is inconsistent with the config; the message says which.
        case invalidConfiguration(String)
        /// The command buffer completed with `MTLCommandBufferStatus.error`.
        case gpuExecutionFailed(String)
    }

    /// Shape of a compute dispatch.
    private enum DispatchShape {
        /// `count` threads in one dimension; `threadgroupHint` is what gets
        /// forwarded to `engine.threadgroupSize1D`.
        case linear(count: Int, threadgroupHint: Int)
        /// A `width × height` grid sized via `engine.threadgroupSize2D`.
        case grid(width: Int, height: Int)
    }

    public let config: VisionConfig
    public let engine: MarkBaseEngine
    public let weights: VisionWeights
    /// Largest `numPatches` accepted by `forward`; scratch buffers are sized for it.
    public let maxPatches: Int
    /// Output width of the final embedding projection (the text model's hidden size).
    public let textHiddenSize: Int
    /// Quantization group size of the embedding-projection weights.
    public let embeddingProjectionGroupSize: Int

    /// Bytes per activation element in every scratch buffer.
    private static let bytesPerElement = 4

    // Residual stream. Each layer moves it streamBuffer -> scratchBuffer -> streamBuffer.
    private let streamBuffer: MTLBuffer
    // Sub-layer outputs (input/O/down projections) and the mid-layer residual stream.
    private let scratchBuffer: MTLBuffer
    // RMS-norm outputs.
    private let normBuffer: MTLBuffer
    private let qBuffer: MTLBuffer
    private let kBuffer: MTLBuffer
    private let vBuffer: MTLBuffer
    private let qNormBuffer: MTLBuffer
    private let kNormBuffer: MTLBuffer
    private let attnOutBuffer: MTLBuffer
    // MLP gate projection; also receives the gated activation.
    private let mlpBuffer: MTLBuffer
    // MLP up projection.
    private let upBuffer: MTLBuffer

    /// Allocates all scratch buffers for up to `maxPatches` patches.
    ///
    /// - Parameters:
    ///   - config: Vision hyper-parameters (hidden/intermediate sizes, heads, eps).
    ///   - engine: Supplies the Metal device, command queue and pipelines.
    ///   - weights: Quantized model weights.
    ///   - maxPatches: Capacity of every scratch buffer, in rows.
    ///   - textHiddenSize: Width of the final embedding projection output.
    ///   - embeddingProjectionGroupSize: Quantization group size of that projection.
    /// - Throws: `VisionTowerError.invalidConfiguration` for inconsistent sizes,
    ///   `VisionTowerError.bufferAllocationFailed` if Metal cannot allocate.
    public init(config: VisionConfig, engine: MarkBaseEngine, weights: VisionWeights,
                maxPatches: Int = 4096, textHiddenSize: Int = 2560,
                embeddingProjectionGroupSize: Int = 64) throws {
        let hiddenSize = config.hiddenSize
        guard maxPatches > 0 else {
            throw VisionTowerError.invalidConfiguration("maxPatches must be positive, got \(maxPatches)")
        }
        guard textHiddenSize > 0 else {
            throw VisionTowerError.invalidConfiguration("textHiddenSize must be positive, got \(textHiddenSize)")
        }
        // The projection kernel receives hiddenSize / 8 packed words per row and
        // hiddenSize / groupSize groups, so both divisions must be exact.
        guard hiddenSize % 8 == 0,
              embeddingProjectionGroupSize > 0,
              hiddenSize % embeddingProjectionGroupSize == 0 else {
            throw VisionTowerError.invalidConfiguration(
                "hiddenSize \(hiddenSize) must be divisible by 8 and by group size \(embeddingProjectionGroupSize)")
        }

        self.config = config
        self.engine = engine
        self.weights = weights
        self.maxPatches = maxPatches
        self.textHiddenSize = textHiddenSize
        self.embeddingProjectionGroupSize = embeddingProjectionGroupSize

        let device = engine.device
        func scratch(width: Int) throws -> MTLBuffer {
            let length = width * maxPatches * VisionTower.bytesPerElement
            guard let buffer = device.makeBuffer(length: length) else {
                throw VisionTowerError.bufferAllocationFailed(length: length)
            }
            return buffer
        }

        // Q/K/V/attention width is heads * headDim, which need not equal hiddenSize.
        let projectionWidth = max(hiddenSize, config.numAttentionHeads * config.headDim)

        streamBuffer = try scratch(width: hiddenSize)
        scratchBuffer = try scratch(width: hiddenSize)
        normBuffer = try scratch(width: hiddenSize)
        qBuffer = try scratch(width: projectionWidth)
        kBuffer = try scratch(width: projectionWidth)
        vBuffer = try scratch(width: projectionWidth)
        qNormBuffer = try scratch(width: projectionWidth)
        kNormBuffer = try scratch(width: projectionWidth)
        attnOutBuffer = try scratch(width: projectionWidth)
        mlpBuffer = try scratch(width: config.intermediateSize)
        upBuffer = try scratch(width: config.intermediateSize)
    }

    /// Runs the whole tower and blocks until the GPU has finished.
    ///
    /// - Parameters:
    ///   - patchEmbeddings: Input consumed by `weights.inputProj`
    ///     (presumably `numPatches × inputProj.inDim` 4-byte values — confirm).
    ///   - numPatches: Number of rows, in `0...maxPatches`. Zero is a no-op.
    ///   - outputBuffer: Receives the embedding projection output
    ///     (`numPatches × textHiddenSize` elements).
    /// - Throws: `VisionTowerError` for an out-of-range patch count, Metal object
    ///   creation failures, or a command buffer that completes with an error;
    ///   also propagates errors from `engine.pipeline(named:)`.
    public func forward(patchEmbeddings: MTLBuffer, numPatches: Int, outputBuffer: MTLBuffer) throws {
        guard (0...maxPatches).contains(numPatches) else {
            throw VisionTowerError.patchCountOutOfRange(requested: numPatches, maximum: maxPatches)
        }
        // No patches means no work; skip encoding rather than issue zero-width dispatches.
        guard numPatches > 0 else { return }

        guard let cmdBuf = engine.commandQueue.makeCommandBuffer() else {
            throw VisionTowerError.commandBufferUnavailable
        }

        // Input projection into scratch, then add position embeddings to seed the residual stream.
        try encodeQuantizedMatmul(input: patchEmbeddings, weights: weights.inputProj,
                                  seqLen: numPatches, output: scratchBuffer, cmdBuf: cmdBuf)
        try encodePositionEmbedding(input: scratchBuffer, numPatches: numPatches,
                                    output: streamBuffer, cmdBuf: cmdBuf)

        for layer in weights.layers {
            try encodeLayer(layer, numPatches: numPatches, cmdBuf: cmdBuf)
        }

        try encodeEmbeddingProjection(input: streamBuffer, numPatches: numPatches,
                                      output: outputBuffer, cmdBuf: cmdBuf)

        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
        if cmdBuf.status == .error {
            let detail = cmdBuf.error.map { String(describing: $0) } ?? "unknown error"
            throw VisionTowerError.gpuExecutionFailed(detail)
        }
    }

    // MARK: - Layer

    /// Encodes one transformer layer. The residual stream is in `streamBuffer`
    /// on entry and on exit.
    ///
    /// Uses the sandwich-norm order implied by the four norm weights:
    ///
    ///     h = h + postAttentionNorm(attention(inputNorm(h)))
    ///     h = h + postFeedforwardNorm(mlp(preFeedforwardNorm(h)))
    ///
    /// NOTE(review): this is the Gemma decoder-layer layout; confirm it matches
    /// the reference vision model.
    private func encodeLayer(_ layer: VisionLayerWeights, numPatches: Int, cmdBuf: MTLCommandBuffer) throws {
        let count = numPatches * config.hiddenSize

        // Attention half: stream(A) -> norm(N) -> attention(B) -> post-norm(N); B = A + N.
        try encodeRMSNorm(input: streamBuffer, weight: layer.inputLayernorm,
                          seqLen: numPatches, output: normBuffer, cmdBuf: cmdBuf)
        try encodeSelfAttention(input: normBuffer, layer: layer, numPatches: numPatches,
                                output: scratchBuffer, cmdBuf: cmdBuf)
        try encodeRMSNorm(input: scratchBuffer, weight: layer.postAttentionLayernorm,
                          seqLen: numPatches, output: normBuffer, cmdBuf: cmdBuf)
        try encodeResidualAdd(streamBuffer, normBuffer, count: count,
                              output: scratchBuffer, cmdBuf: cmdBuf)

        // Feed-forward half: stream(B) -> norm(N) -> MLP(A) -> post-norm(N); A = B + N.
        // A's previous contents are no longer needed, so the MLP may overwrite it.
        try encodeRMSNorm(input: scratchBuffer, weight: layer.preFeedforwardLayernorm,
                          seqLen: numPatches, output: normBuffer, cmdBuf: cmdBuf)
        try encodeMLP(input: normBuffer, layer: layer, numPatches: numPatches,
                      output: streamBuffer, cmdBuf: cmdBuf)
        try encodeRMSNorm(input: streamBuffer, weight: layer.postFeedforwardLayernorm,
                          seqLen: numPatches, output: normBuffer, cmdBuf: cmdBuf)
        try encodeResidualAdd(scratchBuffer, normBuffer, count: count,
                              output: streamBuffer, cmdBuf: cmdBuf)
    }

    /// Q/K/V projections, per-head Q/K norm, attention, and O projection into `output`.
    private func encodeSelfAttention(input: MTLBuffer, layer: VisionLayerWeights, numPatches: Int,
                                     output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws {
        let heads = config.numAttentionHeads
        let headDim = config.headDim

        try encodeQuantizedMatmul(input: input, weights: layer.selfAttnQProj, seqLen: numPatches,
                                  output: qBuffer, cmdBuf: cmdBuf)
        try encodeQuantizedMatmul(input: input, weights: layer.selfAttnKProj, seqLen: numPatches,
                                  output: kBuffer, cmdBuf: cmdBuf)
        try encodeQuantizedMatmul(input: input, weights: layer.selfAttnVProj, seqLen: numPatches,
                                  output: vBuffer, cmdBuf: cmdBuf)

        // Head norms write to separate buffers so no thread reads an already-normalized value.
        try encodeHeadNorm(input: qBuffer, weight: layer.qNorm, seqLen: numPatches,
                           numHeads: heads, headDim: headDim, output: qNormBuffer, cmdBuf: cmdBuf)
        try encodeHeadNorm(input: kBuffer, weight: layer.kNorm, seqLen: numPatches,
                           numHeads: heads, headDim: headDim, output: kNormBuffer, cmdBuf: cmdBuf)

        try encodeAttentionKernel(q: qNormBuffer, k: kNormBuffer, v: vBuffer, numPatches: numPatches,
                                  numHeads: heads, headDim: headDim, output: attnOutBuffer, cmdBuf: cmdBuf)

        try encodeQuantizedMatmul(input: attnOutBuffer, weights: layer.selfAttnOProj, seqLen: numPatches,
                                  output: output, cmdBuf: cmdBuf)
    }

    /// Gate and up projections, `vision_gate_multiply`, then the down projection into `output`.
    ///
    /// The activation applied to the gate is defined by the `vision_gate_multiply` kernel.
    private func encodeMLP(input: MTLBuffer, layer: VisionLayerWeights, numPatches: Int,
                           output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws {
        try encodeQuantizedMatmul(input: input, weights: layer.mlpGateProj, seqLen: numPatches,
                                  output: mlpBuffer, cmdBuf: cmdBuf)
        try encodeQuantizedMatmul(input: input, weights: layer.mlpUpProj, seqLen: numPatches,
                                  output: upBuffer, cmdBuf: cmdBuf)

        // Writes the gated result over the gate buffer. This assumes the kernel is
        // element-wise (one thread per element, matching the 1-D dispatch) — confirm.
        let count = numPatches * config.intermediateSize
        try encodeKernel("vision_gate_multiply", on: cmdBuf,
                         shape: .linear(count: count, threadgroupHint: count)) { enc in
            enc.setBuffer(mlpBuffer, offset: 0, index: 0)
            enc.setBuffer(upBuffer, offset: 0, index: 1)
            enc.setBuffer(mlpBuffer, offset: 0, index: 2)
            setUInt32(count, at: 3, on: enc)
        }

        try encodeQuantizedMatmul(input: mlpBuffer, weights: layer.mlpDownProj, seqLen: numPatches,
                                  output: output, cmdBuf: cmdBuf)
    }

    // MARK: - Kernels

    /// `output = input · W` for `seqLen` rows using the `quantized_matmul` kernel.
    private func encodeQuantizedMatmul(input: MTLBuffer, weights quantized: QuantizedWeights,
                                       seqLen: Int, output: MTLBuffer,
                                       cmdBuf: MTLCommandBuffer) throws {
        try encodeKernel("quantized_matmul", on: cmdBuf,
                         shape: .linear(count: quantized.outDim * seqLen,
                                        threadgroupHint: max(quantized.outDim, seqLen))) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(quantized.weight, offset: 0, index: 1)
            enc.setBuffer(quantized.scales, offset: 0, index: 2)
            enc.setBuffer(quantized.biases, offset: 0, index: 3)
            enc.setBuffer(output, offset: 0, index: 4)
            setUInt32(quantized.inDim, at: 5, on: enc)
            setUInt32(quantized.outDim, at: 6, on: enc)
            setUInt32(quantized.groupSize, at: 7, on: enc)
        }
    }

    /// Adds `weights.positionEmbedding` to `input`, writing into `output`.
    private func encodePositionEmbedding(input: MTLBuffer, numPatches: Int, output: MTLBuffer,
                                         cmdBuf: MTLCommandBuffer) throws {
        let hidden = config.hiddenSize
        try encodeKernel("vision_add_pos_embed", on: cmdBuf,
                         shape: .grid(width: hidden, height: numPatches)) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(weights.positionEmbedding, offset: 0, index: 1)
            enc.setBuffer(output, offset: 0, index: 2)
            setUInt32(hidden, at: 3, on: enc)
            setUInt32(numPatches, at: 4, on: enc)
        }
    }

    /// Per-head normalization (`vision_head_norm`) from `input` into a distinct `output`.
    private func encodeHeadNorm(input: MTLBuffer, weight: MTLBuffer, seqLen: Int, numHeads: Int,
                                headDim: Int, output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws {
        try encodeKernel("vision_head_norm", on: cmdBuf,
                         shape: .grid(width: numHeads * headDim, height: seqLen)) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(weight, offset: 0, index: 1)
            enc.setBuffer(output, offset: 0, index: 2)
            setUInt32(numHeads, at: 3, on: enc)
            setUInt32(headDim, at: 4, on: enc)
            setUInt32(seqLen, at: 5, on: enc)
            setFloat(Float(config.rmsNormEps), at: 6, on: enc)
        }
    }

    /// Multi-head attention (`vision_attention`) over `numPatches` positions.
    private func encodeAttentionKernel(q: MTLBuffer, k: MTLBuffer, v: MTLBuffer, numPatches: Int,
                                       numHeads: Int, headDim: Int, output: MTLBuffer,
                                       cmdBuf: MTLCommandBuffer) throws {
        try encodeKernel("vision_attention", on: cmdBuf,
                         shape: .grid(width: numHeads * headDim, height: numPatches)) { enc in
            enc.setBuffer(q, offset: 0, index: 0)
            enc.setBuffer(k, offset: 0, index: 1)
            enc.setBuffer(v, offset: 0, index: 2)
            enc.setBuffer(output, offset: 0, index: 3)
            setUInt32(numPatches, at: 4, on: enc)
            setUInt32(numHeads, at: 5, on: enc)
            setUInt32(headDim, at: 6, on: enc)
        }
    }

    /// Row-wise RMS norm (`rms_norm_seq`) of `seqLen × hiddenSize` values into a distinct `output`.
    private func encodeRMSNorm(input: MTLBuffer, weight: MTLBuffer, seqLen: Int, output: MTLBuffer,
                               cmdBuf: MTLCommandBuffer) throws {
        let hidden = config.hiddenSize
        try encodeKernel("rms_norm_seq", on: cmdBuf,
                         shape: .grid(width: hidden, height: seqLen)) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(weight, offset: 0, index: 1)
            enc.setBuffer(output, offset: 0, index: 2)
            setUInt32(hidden, at: 3, on: enc)
            setFloat(Float(config.rmsNormEps), at: 4, on: enc)
            setUInt32(seqLen, at: 5, on: enc)
        }
    }

    /// `output = lhs + rhs` over `count` elements (`vision_residual_add`); `output` is distinct from both inputs.
    private func encodeResidualAdd(_ lhs: MTLBuffer, _ rhs: MTLBuffer, count: Int, output: MTLBuffer,
                                   cmdBuf: MTLCommandBuffer) throws {
        try encodeKernel("vision_residual_add", on: cmdBuf,
                         shape: .linear(count: count, threadgroupHint: count)) { enc in
            enc.setBuffer(lhs, offset: 0, index: 0)
            enc.setBuffer(rhs, offset: 0, index: 1)
            enc.setBuffer(output, offset: 0, index: 2)
            setUInt32(count, at: 3, on: enc)
        }
    }

    /// Quantized projection from `config.hiddenSize` to `textHiddenSize`.
    private func encodeEmbeddingProjection(input: MTLBuffer, numPatches: Int, output: MTLBuffer,
                                           cmdBuf: MTLCommandBuffer) throws {
        let inFeatures = config.hiddenSize
        let outFeatures = textHiddenSize
        let groupSize = embeddingProjectionGroupSize
        try encodeKernel("vision_embedding_projection_quantized", on: cmdBuf,
                         shape: .grid(width: outFeatures, height: numPatches)) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(weights.embeddingProjectionWeight, offset: 0, index: 1)
            enc.setBuffer(weights.embeddingProjectionScales, offset: 0, index: 2)
            enc.setBuffer(weights.embeddingProjectionBiases, offset: 0, index: 3)
            enc.setBuffer(output, offset: 0, index: 4)
            setUInt32(inFeatures, at: 5, on: enc)
            setUInt32(outFeatures, at: 6, on: enc)
            setUInt32(numPatches, at: 7, on: enc)
            // Packed words per weight row: 8 weights per word (presumably 4-bit
            // quantization packed into UInt32 — confirm against the kernel).
            setUInt32(inFeatures / 8, at: 8, on: enc)
            setUInt32(groupSize, at: 9, on: enc)
            setUInt32(inFeatures / groupSize, at: 10, on: enc)
        }
    }

    // MARK: - Encoding helpers

    /// Looks up `name`, opens a compute encoder, lets `bind` set arguments,
    /// dispatches `shape`, and always ends the encoder before returning.
    private func encodeKernel(_ name: String, on cmdBuf: MTLCommandBuffer, shape: DispatchShape,
                              bind: (MTLComputeCommandEncoder) -> Void) throws {
        // Resolve the pipeline first so a lookup failure never leaves an encoder open.
        let pso = try engine.pipeline(named: name)
        guard let enc = cmdBuf.makeComputeCommandEncoder() else {
            throw VisionTowerError.encoderUnavailable
        }
        enc.setComputePipelineState(pso)
        bind(enc)
        switch shape {
        case .linear(let count, let threadgroupHint):
            enc.dispatchThreads(MTLSize(width: count, height: 1, depth: 1),
                                threadsPerThreadgroup: engine.threadgroupSize1D(pso, count: threadgroupHint))
        case .grid(let width, let height):
            enc.dispatchThreads(MTLSize(width: width, height: height, depth: 1),
                                threadsPerThreadgroup: engine.threadgroupSize2D(pso, grid: (width, height)))
        }
        enc.endEncoding()
    }

    /// Binds `value` as a 32-bit unsigned constant (traps if it does not fit, as `UInt32(_:)` does).
    private func setUInt32<T: BinaryInteger>(_ value: T, at index: Int, on enc: MTLComputeCommandEncoder) {
        var raw = UInt32(value)
        enc.setBytes(&raw, length: MemoryLayout<UInt32>.size, index: index)
    }

    /// Binds `value` as a 32-bit float constant.
    private func setFloat(_ value: Float, at index: Int, on enc: MTLComputeCommandEncoder) {
        var raw = value
        enc.setBytes(&raw, length: MemoryLayout<Float>.size, index: index)
    }
}