331 lines
15 KiB
Swift
331 lines
15 KiB
Swift
import Metal
|
|
|
|
/// Metal-backed vision encoder: quantized input projection, learned position
/// embedding, a stack of sandwich-norm transformer layers, and a quantized
/// projection into the text model's embedding space.
///
/// All intermediate activations live in scratch buffers owned by the instance
/// and sized for `maxPatches` rows, so `forward` must not be called
/// concurrently on the same instance.
public final class VisionTower {

    /// Errors raised while allocating scratch memory or encoding GPU work.
    public enum Failure: Error {
        /// `MTLDevice.makeBuffer` returned nil for a scratch buffer of `length` bytes.
        case bufferAllocationFailed(length: Int)
        /// The engine's command queue did not return a command buffer.
        case commandBufferUnavailable
        /// The command buffer did not return a compute encoder for `kernel`.
        case encoderUnavailable(kernel: String)
        /// `forward` was called with more patches than the scratch buffers hold.
        case tooManyPatches(requested: Int, limit: Int)
        /// A caller-supplied buffer is shorter than the kernels will access.
        case bufferTooSmall(name: String, required: Int, actual: Int)
    }

    /// Grid layout of a single compute dispatch.
    private enum DispatchShape {
        /// One thread per element of a flat range of the given length.
        case linear(Int)
        /// A `width` x `height` grid of threads.
        case planar(width: Int, height: Int)
    }

    /// Model hyper-parameters (hidden/intermediate sizes, heads, RMSNorm epsilon).
    public let config: VisionConfig

    /// Provides the Metal device, command queue and compute pipelines.
    public let engine: MarkBaseEngine

    /// Quantized weights for every stage of the tower.
    public let weights: VisionWeights

    /// Largest `numPatches` accepted by `forward`; every scratch buffer is sized for it.
    public let maxPatches: Int

    // Bytes per activation element. Scratch sizing and the caller-buffer checks
    // assume 4-byte elements (presumably Float32 in the shaders — TODO confirm).
    private static let bytesPerElement = 4

    // Output width of the embedding projection (the text model's hidden size).
    private static let textHiddenSize = 2560

    // Quantization group size of the embedding-projection weights.
    private static let embeddingGroupSize = 64

    // Quantized weights per packed 32-bit word in the embedding projection
    // (presumably 4-bit values, 8 per UInt32 — TODO confirm against the shader).
    private static let weightsPerPackedWord = 8

    // ── Scratch buffers ───────────────────────────────
    // Within one dispatch no buffer is both read and written, except the
    // element-wise gate multiply. Reuse across dispatches relies on Metal's
    // default (tracked) hazard tracking between encoders of one command buffer.

    // Q / K / V projections.
    private let qBuffer: MTLBuffer
    private let kBuffer: MTLBuffer
    private let vBuffer: MTLBuffer

    // Head-normalized Q and K, kept apart from the projections so the
    // head-norm kernel never runs in place.
    private let qNormBuffer: MTLBuffer
    private let kNormBuffer: MTLBuffer

    // Attention output before the O projection.
    private let attnOutBuffer: MTLBuffer

    // MLP gate projection, then overwritten with the gated activation.
    private let mlpBuffer: MTLBuffer

    // Matmul outputs: input projection, O projection, MLP up and down projections.
    private let tempBuffer: MTLBuffer

    // RMSNorm outputs.
    private let normBuffer: MTLBuffer

    // Residual stream between layers; holds the tower's final hidden states.
    private let residualBuffer: MTLBuffer

    // Residual stream between the attention and MLP halves of a layer.
    private let streamBuffer: MTLBuffer

    /// Allocates scratch buffers for up to `maxPatches` patches.
    ///
    /// - Parameters:
    ///   - config: Vision model hyper-parameters.
    ///   - engine: Metal engine used for allocation and pipelines.
    ///   - weights: Quantized vision weights.
    ///   - maxPatches: Row capacity of every scratch buffer (default 4096).
    /// - Throws: `Failure.bufferAllocationFailed` if the device cannot allocate a buffer.
    public init(config: VisionConfig, engine: MarkBaseEngine, weights: VisionWeights, maxPatches: Int = 4096) throws {
        self.config = config
        self.engine = engine
        self.weights = weights
        self.maxPatches = maxPatches

        let device = engine.device
        let hiddenSize = config.hiddenSize
        let intermediateSize = config.intermediateSize
        // Attention tensors are `numHeads * headDim` wide. Never allocate less
        // than `hiddenSize`, which is what these buffers used to be sized by.
        let attentionWidth = max(hiddenSize, config.numAttentionHeads * config.headDim)

        // Allocates a buffer holding `maxPatches` rows of `width` elements.
        func makeScratch(width: Int) throws -> MTLBuffer {
            let length = width * maxPatches * VisionTower.bytesPerElement
            guard let buffer = device.makeBuffer(length: length) else {
                throw Failure.bufferAllocationFailed(length: length)
            }
            return buffer
        }

        qBuffer = try makeScratch(width: attentionWidth)
        kBuffer = try makeScratch(width: attentionWidth)
        vBuffer = try makeScratch(width: attentionWidth)
        qNormBuffer = try makeScratch(width: attentionWidth)
        kNormBuffer = try makeScratch(width: attentionWidth)
        attnOutBuffer = try makeScratch(width: attentionWidth)
        mlpBuffer = try makeScratch(width: intermediateSize)
        tempBuffer = try makeScratch(width: max(hiddenSize, intermediateSize))
        normBuffer = try makeScratch(width: hiddenSize)
        residualBuffer = try makeScratch(width: hiddenSize)
        streamBuffer = try makeScratch(width: hiddenSize)
    }

    /// Runs the tower on `numPatches` patch rows and writes the projected
    /// embeddings into `outputBuffer`. Blocks until the GPU work completes.
    ///
    /// - Parameters:
    ///   - patchEmbeddings: `numPatches` rows of `weights.inputProj.inDim` elements.
    ///   - numPatches: Number of patches. `0` is a no-op; more than `maxPatches` throws.
    ///   - outputBuffer: Receives `numPatches` rows of 2560 elements.
    /// - Throws: `Failure` for invalid sizes or Metal object creation failures,
    ///   errors from `engine.pipeline(named:)`, and the command buffer's error
    ///   if GPU execution fails.
    public func forward(patchEmbeddings: MTLBuffer, numPatches: Int, outputBuffer: MTLBuffer) throws {
        // Nothing to compute; zero-sized dispatches are also invalid in Metal.
        guard numPatches > 0 else { return }
        guard numPatches <= maxPatches else {
            throw Failure.tooManyPatches(requested: numPatches, limit: maxPatches)
        }
        // Kernels do no bounds checking, so short caller buffers would be overrun.
        try requireCapacity(of: patchEmbeddings, named: "patchEmbeddings",
                            elements: numPatches * Int(weights.inputProj.inDim))
        try requireCapacity(of: outputBuffer, named: "outputBuffer",
                            elements: numPatches * VisionTower.textHiddenSize)

        guard let cmdBuf = engine.commandQueue.makeCommandBuffer() else {
            throw Failure.commandBufferUnavailable
        }

        // Input projection: [numPatches, inDim] -> [numPatches, hidden]
        let projected = try applyQuantizedMatmul(input: patchEmbeddings, weights: weights.inputProj,
                                                 seqLen: numPatches, output: tempBuffer, cmdBuf: cmdBuf)

        // Add position embedding; the result seeds the residual stream.
        var hidden = try addPositionEmbedding(input: projected, numPatches: numPatches, cmdBuf: cmdBuf)

        for layerWeights in weights.layers {
            hidden = try applyLayer(input: hidden, weights: layerWeights, numPatches: numPatches, cmdBuf: cmdBuf)
        }

        // Embedding projection: [numPatches, hidden] -> [numPatches, 2560]
        try applyEmbeddingProjection(input: hidden, numPatches: numPatches, output: outputBuffer, cmdBuf: cmdBuf)

        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
        // GPU-side failures are only reported through the command buffer.
        if let error = cmdBuf.error {
            throw error
        }
    }

    // ── Encoding helpers ──────────────────────────────

    /// Encodes one dispatch of the kernel `name` into `cmdBuf`.
    /// `bind` sets the kernel's buffers and constants before dispatch.
    private func encode(kernel name: String,
                        shape: DispatchShape,
                        cmdBuf: MTLCommandBuffer,
                        bind: (MTLComputeCommandEncoder) -> Void) throws {
        let pso = try engine.pipeline(named: name)
        guard let enc = cmdBuf.makeComputeCommandEncoder() else {
            throw Failure.encoderUnavailable(kernel: name)
        }
        enc.setComputePipelineState(pso)
        bind(enc)

        let grid: MTLSize
        let threadgroup: MTLSize
        switch shape {
        case .linear(let count):
            grid = MTLSize(width: count, height: 1, depth: 1)
            threadgroup = engine.threadgroupSize1D(pso, count: count)
        case .planar(let width, let height):
            grid = MTLSize(width: width, height: height, depth: 1)
            threadgroup = engine.threadgroupSize2D(pso, grid: (width, height))
        }
        enc.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
        enc.endEncoding()
    }

    /// Binds `value` as a 32-bit unsigned constant at `index` (traps if it does not fit).
    private static func setUInt32<T: BinaryInteger>(_ value: T, on enc: MTLComputeCommandEncoder, at index: Int) {
        var word = UInt32(value)
        enc.setBytes(&word, length: MemoryLayout<UInt32>.size, index: index)
    }

    /// Binds `value` as a 32-bit float constant at `index`.
    private static func setFloat(_ value: Float, on enc: MTLComputeCommandEncoder, at index: Int) {
        var scalar = value
        enc.setBytes(&scalar, length: MemoryLayout<Float>.size, index: index)
    }

    /// Throws `Failure.bufferTooSmall` unless `buffer` holds `elements` activations.
    private func requireCapacity(of buffer: MTLBuffer, named name: String, elements: Int) throws {
        let required = elements * VisionTower.bytesPerElement
        guard buffer.length >= required else {
            throw Failure.bufferTooSmall(name: name, required: required, actual: buffer.length)
        }
    }

    // ── Quantized matmul (sequence-aware) ─────────────

    /// Encodes `output[seqLen, outDim] = input[seqLen, inDim] x W^T` with grouped
    /// quantized weights. One thread per output element. Returns `output`.
    /// `input` and `output` must be different buffers.
    private func applyQuantizedMatmul(input: MTLBuffer, weights: QuantizedWeights,
                                      seqLen: Int, output: MTLBuffer,
                                      cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        // Threadgroup sizing uses the real thread count (outDim * seqLen).
        try encode(kernel: "quantized_matmul", shape: .linear(weights.outDim * seqLen), cmdBuf: cmdBuf) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(weights.weight, offset: 0, index: 1)
            enc.setBuffer(weights.scales, offset: 0, index: 2)
            enc.setBuffer(weights.biases, offset: 0, index: 3)
            enc.setBuffer(output, offset: 0, index: 4)
            VisionTower.setUInt32(weights.inDim, on: enc, at: 5)
            VisionTower.setUInt32(weights.outDim, on: enc, at: 6)
            VisionTower.setUInt32(weights.groupSize, on: enc, at: 7)
        }
        return output
    }

    // ── Position embedding ────────────────────────────

    /// Encodes `input + positionEmbedding` into `residualBuffer` and returns it.
    private func addPositionEmbedding(input: MTLBuffer, numPatches: Int, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        let output = residualBuffer
        let hiddenSize = config.hiddenSize

        try encode(kernel: "vision_add_pos_embed",
                   shape: .planar(width: hiddenSize, height: numPatches), cmdBuf: cmdBuf) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(weights.positionEmbedding, offset: 0, index: 1)
            enc.setBuffer(output, offset: 0, index: 2)
            VisionTower.setUInt32(hiddenSize, on: enc, at: 3)
            VisionTower.setUInt32(numPatches, on: enc, at: 4)
        }
        return output
    }

    // ── Layer ─────────────────────────────────────────

    /// Encodes one sandwich-norm transformer layer:
    ///
    ///     x = x + postAttnNorm(attention(inputNorm(x)))
    ///     x = x + postFFNorm(mlp(preFFNorm(x)))
    ///
    /// Reads `input` and returns `residualBuffer` holding the updated stream.
    /// `input` may be `residualBuffer` itself, but must not be one of the
    /// scratch buffers the layer writes before its last read of `input`.
    private func applyLayer(input: MTLBuffer, weights: VisionLayerWeights, numPatches: Int, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        assert(input !== streamBuffer && input !== normBuffer && input !== tempBuffer,
               "layer input aliases a scratch buffer written inside the layer")
        let hiddenSize = config.hiddenSize

        // 1. Attention sub-block: the branch is normalized before being added to the residual.
        let attnIn = try applyRMSNorm(input: input, weight: weights.inputLayernorm, seqLen: numPatches,
                                      hiddenSize: hiddenSize, output: normBuffer, cmdBuf: cmdBuf)
        let attnOut = try applyVisionAttention(input: attnIn, weights: weights, numPatches: numPatches, cmdBuf: cmdBuf)
        let attnBranch = try applyRMSNorm(input: attnOut, weight: weights.postAttentionLayernorm, seqLen: numPatches,
                                          hiddenSize: hiddenSize, output: normBuffer, cmdBuf: cmdBuf)
        let afterAttn = try applyResidualAdd(input: input, add: attnBranch, seqLen: numPatches,
                                             hiddenSize: hiddenSize, output: streamBuffer, cmdBuf: cmdBuf)

        // 2. MLP sub-block: same pattern, with the residual taken from the post-attention stream.
        let mlpIn = try applyRMSNorm(input: afterAttn, weight: weights.preFeedforwardLayernorm, seqLen: numPatches,
                                     hiddenSize: hiddenSize, output: normBuffer, cmdBuf: cmdBuf)
        let mlpOut = try applyVisionMLP(input: mlpIn, weights: weights, numPatches: numPatches, cmdBuf: cmdBuf)
        let mlpBranch = try applyRMSNorm(input: mlpOut, weight: weights.postFeedforwardLayernorm, seqLen: numPatches,
                                         hiddenSize: hiddenSize, output: normBuffer, cmdBuf: cmdBuf)
        return try applyResidualAdd(input: afterAttn, add: mlpBranch, seqLen: numPatches,
                                    hiddenSize: hiddenSize, output: residualBuffer, cmdBuf: cmdBuf)
    }

    /// Encodes Q/K/V projections, per-head Q/K norm, attention and the O
    /// projection. Returns `tempBuffer` holding the O-projection output.
    private func applyVisionAttention(input: MTLBuffer, weights: VisionLayerWeights, numPatches: Int, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        let numHeads = config.numAttentionHeads
        let headDim = config.headDim

        // Q, K, V projections
        let q = try applyQuantizedMatmul(input: input, weights: weights.selfAttnQProj, seqLen: numPatches, output: qBuffer, cmdBuf: cmdBuf)
        let k = try applyQuantizedMatmul(input: input, weights: weights.selfAttnKProj, seqLen: numPatches, output: kBuffer, cmdBuf: cmdBuf)
        let v = try applyQuantizedMatmul(input: input, weights: weights.selfAttnVProj, seqLen: numPatches, output: vBuffer, cmdBuf: cmdBuf)

        // Q/K norm (out of place)
        let qNormed = try applyHeadNorm(input: q, weight: weights.qNorm, seqLen: numPatches, numHeads: numHeads,
                                        headDim: headDim, output: qNormBuffer, cmdBuf: cmdBuf)
        let kNormed = try applyHeadNorm(input: k, weight: weights.kNorm, seqLen: numPatches, numHeads: numHeads,
                                        headDim: headDim, output: kNormBuffer, cmdBuf: cmdBuf)

        // Attention
        let attnOut = try applyAttention(q: qNormed, k: kNormed, v: v, numPatches: numPatches, numHeads: numHeads,
                                         headDim: headDim, output: attnOutBuffer, cmdBuf: cmdBuf)

        // O projection
        return try applyQuantizedMatmul(input: attnOut, weights: weights.selfAttnOProj, seqLen: numPatches, output: tempBuffer, cmdBuf: cmdBuf)
    }

    /// Encodes per-head RMS normalization of a `[seqLen, numHeads * headDim]`
    /// tensor into `output` and returns it. `output` must differ from `input`,
    /// because the kernel presumably reduces across a head and would race in place.
    private func applyHeadNorm(input: MTLBuffer, weight: MTLBuffer, seqLen: Int, numHeads: Int, headDim: Int,
                               output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        try encode(kernel: "vision_head_norm",
                   shape: .planar(width: numHeads * headDim, height: seqLen), cmdBuf: cmdBuf) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(weight, offset: 0, index: 1)
            enc.setBuffer(output, offset: 0, index: 2)
            VisionTower.setUInt32(numHeads, on: enc, at: 3)
            VisionTower.setUInt32(headDim, on: enc, at: 4)
            VisionTower.setUInt32(seqLen, on: enc, at: 5)
            VisionTower.setFloat(Float(config.rmsNormEps), on: enc, at: 6)
        }
        return output
    }

    /// Encodes the `vision_attention` kernel over all patches and heads into `output` and returns it.
    private func applyAttention(q: MTLBuffer, k: MTLBuffer, v: MTLBuffer, numPatches: Int, numHeads: Int, headDim: Int,
                                output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        try encode(kernel: "vision_attention",
                   shape: .planar(width: numHeads * headDim, height: numPatches), cmdBuf: cmdBuf) { enc in
            enc.setBuffer(q, offset: 0, index: 0)
            enc.setBuffer(k, offset: 0, index: 1)
            enc.setBuffer(v, offset: 0, index: 2)
            enc.setBuffer(output, offset: 0, index: 3)
            VisionTower.setUInt32(numPatches, on: enc, at: 4)
            VisionTower.setUInt32(numHeads, on: enc, at: 5)
            VisionTower.setUInt32(headDim, on: enc, at: 6)
        }
        return output
    }

    /// Encodes the gated MLP (gate/up projections, gate multiply, down
    /// projection). Returns `tempBuffer` holding `[numPatches, hidden]`.
    /// `input` must not be `tempBuffer` or `mlpBuffer`.
    private func applyVisionMLP(input: MTLBuffer, weights: VisionLayerWeights, numPatches: Int, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        // Gate projection: [numPatches, hidden] -> [numPatches, intermediate]
        let gate = try applyQuantizedMatmul(input: input, weights: weights.mlpGateProj, seqLen: numPatches, output: mlpBuffer, cmdBuf: cmdBuf)

        // Up projection: [numPatches, hidden] -> [numPatches, intermediate]
        let up = try applyQuantizedMatmul(input: input, weights: weights.mlpUpProj, seqLen: numPatches, output: tempBuffer, cmdBuf: cmdBuf)

        // activation(gate) * up (the kernel defines the activation; previously noted as SiLU)
        let gated = try applyGateMultiply(gate: gate, up: up, count: numPatches * config.intermediateSize, cmdBuf: cmdBuf)

        // Down projection: [numPatches, intermediate] -> [numPatches, hidden]
        return try applyQuantizedMatmul(input: gated, weights: weights.mlpDownProj, seqLen: numPatches, output: tempBuffer, cmdBuf: cmdBuf)
    }

    /// Encodes the element-wise gate multiply into `mlpBuffer` and returns it.
    /// Writing over `gate` (which is `mlpBuffer`) is acceptable only because
    /// each output element depends on the same index of its inputs.
    private func applyGateMultiply(gate: MTLBuffer, up: MTLBuffer, count: Int, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        let output = mlpBuffer

        try encode(kernel: "vision_gate_multiply", shape: .linear(count), cmdBuf: cmdBuf) { enc in
            enc.setBuffer(gate, offset: 0, index: 0)
            enc.setBuffer(up, offset: 0, index: 1)
            enc.setBuffer(output, offset: 0, index: 2)
            VisionTower.setUInt32(count, on: enc, at: 3)
        }
        return output
    }

    // ── Utility kernels ───────────────────────────────

    /// Encodes row-wise RMSNorm of `[seqLen, hiddenSize]` into `output` and
    /// returns it. `output` must differ from `input` (row reduction).
    private func applyRMSNorm(input: MTLBuffer, weight: MTLBuffer, seqLen: Int, hiddenSize: Int,
                              output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        try encode(kernel: "rms_norm_seq",
                   shape: .planar(width: hiddenSize, height: seqLen), cmdBuf: cmdBuf) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(weight, offset: 0, index: 1)
            enc.setBuffer(output, offset: 0, index: 2)
            VisionTower.setUInt32(hiddenSize, on: enc, at: 3)
            VisionTower.setFloat(Float(config.rmsNormEps), on: enc, at: 4)
            VisionTower.setUInt32(seqLen, on: enc, at: 5)
        }
        return output
    }

    /// Encodes `output = input + add` over `seqLen * hiddenSize` elements and returns `output`.
    private func applyResidualAdd(input: MTLBuffer, add: MTLBuffer, seqLen: Int, hiddenSize: Int,
                                  output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        let count = seqLen * hiddenSize

        try encode(kernel: "vision_residual_add", shape: .linear(count), cmdBuf: cmdBuf) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(add, offset: 0, index: 1)
            enc.setBuffer(output, offset: 0, index: 2)
            VisionTower.setUInt32(count, on: enc, at: 3)
        }
        return output
    }

    /// Encodes the quantized projection `[numPatches, hidden] -> [numPatches, 2560]` into `output`.
    private func applyEmbeddingProjection(input: MTLBuffer, numPatches: Int, output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws {
        let inFeatures = config.hiddenSize
        let outFeatures = VisionTower.textHiddenSize
        let groupSize = VisionTower.embeddingGroupSize
        precondition(inFeatures % VisionTower.weightsPerPackedWord == 0 && inFeatures % groupSize == 0,
                     "vision hidden size must be a multiple of the packing and quantization group sizes")

        try encode(kernel: "vision_embedding_projection_quantized",
                   shape: .planar(width: outFeatures, height: numPatches), cmdBuf: cmdBuf) { enc in
            enc.setBuffer(input, offset: 0, index: 0)
            enc.setBuffer(weights.embeddingProjectionWeight, offset: 0, index: 1)
            enc.setBuffer(weights.embeddingProjectionScales, offset: 0, index: 2)
            enc.setBuffer(weights.embeddingProjectionBiases, offset: 0, index: 3)
            enc.setBuffer(output, offset: 0, index: 4)
            VisionTower.setUInt32(inFeatures, on: enc, at: 5)
            VisionTower.setUInt32(outFeatures, on: enc, at: 6)
            VisionTower.setUInt32(numPatches, on: enc, at: 7)
            VisionTower.setUInt32(inFeatures / VisionTower.weightsPerPackedWord, on: enc, at: 8)  // packed words per row
            VisionTower.setUInt32(groupSize, on: enc, at: 9)
            VisionTower.setUInt32(inFeatures / groupSize, on: enc, at: 10)                       // groups per row
        }
    }
}
|