Files
markbaseengine/Sources/MarkBase/Vision/VisionTower12B.swift
T
MarkBase Admin af1d10737e
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: add multimodal 12B test, fix VisionTower12B kernel dispatch
2026-07-05 23:58:42 +08:00

388 lines
15 KiB
Swift

import Metal
// Simplified vision tower for 12B
// 12B vision structure: vision_embedder + embed_vision.embedding_projection
/// Shape parameters for the 12B vision tower.
///
/// Every initializer argument has a default, so `VisionConfig12B()` yields the
/// stock configuration (3840 / 16 / 1120 / 3840).
public struct VisionConfig12B {
    /// Hidden width (default 3840).
    public let hiddenDim: Int
    /// Patch size (default 16).
    public let patchSize: Int
    /// Number of positions (default 1120).
    public let numPositions: Int
    /// Width of the output embeddings (default 3840).
    public let outputDim: Int

    /// Creates a configuration, falling back to the 12B defaults for any
    /// argument that is omitted.
    public init(
        hiddenDim: Int = 3840,
        patchSize: Int = 16,
        numPositions: Int = 1120,
        outputDim: Int = 3840
    ) {
        self.outputDim = outputDim
        self.numPositions = numPositions
        self.patchSize = patchSize
        self.hiddenDim = hiddenDim
    }
}
/// GPU-resident weights for the 12B vision tower.
///
/// Every tensor is uploaded with `MTLDevice.makeBuffer(bytes:length:)`.
/// The `vision_embedder.*` tensors are required; the three
/// `embed_vision.embedding_projection.*` tensors are optional and are stored
/// as `nil` when their key is absent.
public struct VisionWeights12B {
    /// Failures raised while uploading weights to the device.
    public enum LoadError: Error, Equatable {
        /// A required tensor name was absent from the supplied dictionaries.
        case missingTensor(name: String)
        /// The tensor was empty, or the device returned `nil` for the buffer.
        case bufferAllocationFailed(name: String)
    }

    // patch_dense (quantized)
    public let patchDenseWeight: MTLBuffer
    public let patchDenseScales: MTLBuffer
    public let patchDenseBiases: MTLBuffer
    public let patchDenseBias: MTLBuffer
    // patch_ln1
    public let patchLn1Weight: MTLBuffer
    public let patchLn1Bias: MTLBuffer
    // patch_ln2
    public let patchLn2Weight: MTLBuffer
    public let patchLn2Bias: MTLBuffer
    // pos_embedding
    public let posEmbedding: MTLBuffer
    // pos_norm
    public let posNormWeight: MTLBuffer
    public let posNormBias: MTLBuffer
    // embedding_projection (quantized, optional)
    public let embeddingProjectionWeight: MTLBuffer?
    public let embeddingProjectionScales: MTLBuffer?
    public let embeddingProjectionBiases: MTLBuffer?

    /// Uploads the vision-tower tensors to `device`.
    ///
    /// - Parameters:
    ///   - device: Metal device that allocates the buffers.
    ///   - tensors: Float tensors keyed by tensor name.
    ///   - packedWeights: Packed `UInt32` weight tensors keyed by tensor name.
    /// - Throws: `LoadError.missingTensor` when a required key is absent, and
    ///   `LoadError.bufferAllocationFailed` when a tensor that is present is
    ///   empty or cannot be allocated (this also applies to the optional
    ///   projection tensors when their key exists).
    public init(device: MTLDevice, tensors: [String: [Float]], packedWeights: [String: [UInt32]]) throws {
        patchDenseWeight = try Self.requiredBuffer(
            "vision_embedder.patch_dense.weight", in: packedWeights, device: device)
        patchDenseScales = try Self.requiredBuffer(
            "vision_embedder.patch_dense.scales", in: tensors, device: device)
        patchDenseBiases = try Self.requiredBuffer(
            "vision_embedder.patch_dense.biases", in: tensors, device: device)
        patchDenseBias = try Self.requiredBuffer(
            "vision_embedder.patch_dense.bias", in: tensors, device: device)

        patchLn1Weight = try Self.requiredBuffer(
            "vision_embedder.patch_ln1.weight", in: tensors, device: device)
        patchLn1Bias = try Self.requiredBuffer(
            "vision_embedder.patch_ln1.bias", in: tensors, device: device)
        patchLn2Weight = try Self.requiredBuffer(
            "vision_embedder.patch_ln2.weight", in: tensors, device: device)
        patchLn2Bias = try Self.requiredBuffer(
            "vision_embedder.patch_ln2.bias", in: tensors, device: device)

        posEmbedding = try Self.requiredBuffer(
            "vision_embedder.pos_embedding", in: tensors, device: device)
        posNormWeight = try Self.requiredBuffer(
            "vision_embedder.pos_norm.weight", in: tensors, device: device)
        posNormBias = try Self.requiredBuffer(
            "vision_embedder.pos_norm.bias", in: tensors, device: device)

        embeddingProjectionWeight = try Self.optionalBuffer(
            "embed_vision.embedding_projection.weight", in: packedWeights, device: device)
        embeddingProjectionScales = try Self.optionalBuffer(
            "embed_vision.embedding_projection.scales", in: tensors, device: device)
        embeddingProjectionBiases = try Self.optionalBuffer(
            "embed_vision.embedding_projection.biases", in: tensors, device: device)
    }

    /// Copies `values` into a new device buffer, throwing instead of trapping
    /// when the array is empty or the allocation fails.
    private static func uploadBuffer<Element>(
        _ values: [Element],
        named name: String,
        device: MTLDevice
    ) throws -> MTLBuffer {
        let byteCount = values.count * MemoryLayout<Element>.stride
        // A zero-length buffer cannot be created, so treat it as an allocation failure.
        guard byteCount > 0,
              let buffer = device.makeBuffer(bytes: values, length: byteCount)
        else {
            throw LoadError.bufferAllocationFailed(name: name)
        }
        return buffer
    }

    /// Looks up a tensor that must exist and uploads it.
    private static func requiredBuffer<Element>(
        _ name: String,
        in table: [String: [Element]],
        device: MTLDevice
    ) throws -> MTLBuffer {
        guard let values = table[name] else {
            throw LoadError.missingTensor(name: name)
        }
        return try uploadBuffer(values, named: name, device: device)
    }

    /// Uploads a tensor when its key exists; returns `nil` when it is absent.
    private static func optionalBuffer<Element>(
        _ name: String,
        in table: [String: [Element]],
        device: MTLDevice
    ) throws -> MTLBuffer? {
        guard let values = table[name] else { return nil }
        return try uploadBuffer(values, named: name, device: device)
    }
}
public final class VisionTower12B {
public let config: VisionConfig12B
public let weights: VisionWeights12B
public let engine: MarkBaseEngine
// Derived dimensions
public let patchDim: Int
public let hiddenDim: Int
public let posDim: Int
public let outputDim: Int
// Scratch buffers
private let denseOut: MTLBuffer
private let normBuf: MTLBuffer
private let embedBuf: MTLBuffer
/// Wraps already-uploaded weights and allocates the scratch buffers used by `forward`.
///
/// Dimensions are derived from buffer byte lengths, not copied from `config`
/// (only `config.numPositions` and `config.outputDim` are read here):
/// - `hiddenDim` = number of `Float` entries in `patchDenseBias`.
/// - `patchDim`  = packed `UInt32` words per output row × 8
///   (presumably eight 4-bit values per word — confirm against the kernel).
/// - `posDim`    = `Float` entries in `posEmbedding` divided by `config.numPositions`.
///
/// Scratch buffers are sized for 1024 patches and use shared storage.
public init(config: VisionConfig12B, engine: MarkBaseEngine, weights: VisionWeights12B) {
    let floatStride = MemoryLayout<Float>.stride
    let wordStride = MemoryLayout<UInt32>.stride

    // Derive the layer widths from the weight buffers.
    let denseWidth = weights.patchDenseBias.length / floatStride
    let packedWordCount = weights.patchDenseWeight.length / wordStride
    let wordsPerRow = packedWordCount / denseWidth
    let positionWidth = weights.posEmbedding.length / floatStride / config.numPositions
    let projectedWidth = config.outputDim

    // Fixed scratch capacity, in patches.
    let patchCapacity = 1024
    let makeScratch: (Int) -> MTLBuffer = { floatsPerPatch in
        engine.device.makeBuffer(
            length: patchCapacity * floatsPerPatch * floatStride,
            options: .storageModeShared
        )!
    }
    let denseScratch = makeScratch(denseWidth)
    let normScratch = makeScratch(max(denseWidth, projectedWidth))
    // No method visible in this class reads `embedBuf`; it is still allocated here.
    let embedScratch = makeScratch(projectedWidth)

    self.config = config
    self.weights = weights
    self.engine = engine
    self.patchDim = wordsPerRow * 8
    self.hiddenDim = denseWidth
    self.posDim = positionWidth
    self.outputDim = projectedWidth
    self.denseOut = denseScratch
    self.normBuf = normScratch
    self.embedBuf = embedScratch
}
/// Failures reported by `forward(patchEmbeddings:numPatches:outputBuffer:)`.
public enum ForwardError: Error {
    /// `numPatches` was negative or larger than the scratch buffers can hold.
    case invalidPatchCount(requested: Int, capacity: Int)
    /// A caller-supplied buffer is smaller than the data this pass reads or writes.
    case bufferTooSmall(name: String, requiredBytes: Int, actualBytes: Int)
    /// Metal returned `nil` for a command buffer or an encoder.
    case commandEncodingFailed
}

/// Runs the vision tower over a batch of patches and blocks until the GPU finishes.
///
/// Stages: patch_dense → patch_ln1 → add position embedding → patch_ln2 →
/// pos_norm → embedding_projection (only when all three projection tensors exist;
/// otherwise the pos_norm result is copied to `outputBuffer` unchanged).
///
/// - Parameters:
///   - patchEmbeddings: Input laid out as `[numPatches, patchDim]` Float32.
///   - numPatches: Number of patches to process. Zero is a no-op.
///   - outputBuffer: Receives `[numPatches, outputDim]` Float32 when the
///     projection is applied, otherwise `[numPatches, hiddenDim]` Float32.
/// - Throws: `ForwardError` for an invalid patch count, undersized buffers, or
///   a failure to create Metal objects; any error thrown while looking up a
///   pipeline; and the command buffer's own error if GPU execution fails.
public func forward(patchEmbeddings: MTLBuffer, numPatches: Int, outputBuffer: MTLBuffer) throws {
    let floatStride = MemoryLayout<Float>.stride

    // The scratch buffers have a fixed row capacity; writing more rows than that
    // would run past their end on the GPU, so reject oversized batches up front.
    let scratchRowBytes = max(hiddenDim, 1) * floatStride
    let patchCapacity = min(denseOut.length, normBuf.length) / scratchRowBytes
    guard numPatches >= 0, numPatches <= patchCapacity else {
        throw ForwardError.invalidPatchCount(requested: numPatches, capacity: patchCapacity)
    }
    // Nothing to encode; also avoids dispatching an empty grid.
    guard numPatches > 0 else { return }

    let projection: (weight: MTLBuffer, scales: MTLBuffer, biases: MTLBuffer)?
    if let projWeight = weights.embeddingProjectionWeight,
       let projScales = weights.embeddingProjectionScales,
       let projBiases = weights.embeddingProjectionBiases {
        projection = (projWeight, projScales, projBiases)
    } else {
        projection = nil
    }
    let outputWidth = projection == nil ? hiddenDim : outputDim

    func requireBytes(_ buffer: MTLBuffer, _ name: String, _ requiredBytes: Int) throws {
        guard buffer.length >= requiredBytes else {
            throw ForwardError.bufferTooSmall(
                name: name, requiredBytes: requiredBytes, actualBytes: buffer.length)
        }
    }
    try requireBytes(patchEmbeddings, "patchEmbeddings", numPatches * patchDim * floatStride)
    try requireBytes(outputBuffer, "outputBuffer", numPatches * outputWidth * floatStride)

    guard let cmdBuf = engine.commandQueue.makeCommandBuffer() else {
        throw ForwardError.commandEncodingFailed
    }

    // 1. patch_dense: quantized matmul [numPatches, patchDim] -> [numPatches, hiddenDim]
    try quantizedMatmul(
        input: patchEmbeddings,
        weight: weights.patchDenseWeight,
        scales: weights.patchDenseScales,
        biases: weights.patchDenseBiases,
        bias: weights.patchDenseBias,
        inDim: patchDim, outDim: hiddenDim,
        seqLen: numPatches,
        output: denseOut,
        cmdBuf: cmdBuf
    )
    // 2. patch_ln1: norm over hiddenDim
    try rmsNormSeq(
        input: denseOut,
        weight: weights.patchLn1Weight,
        bias: weights.patchLn1Bias,
        normDim: hiddenDim,
        seqLen: numPatches,
        output: normBuf,
        cmdBuf: cmdBuf
    )
    // 3. pos_embedding: add position embeddings
    try addPositionEmbedding(
        input: normBuf,
        posEmbedding: weights.posEmbedding,
        numPatches: numPatches,
        hiddenDim: hiddenDim,
        output: denseOut,
        cmdBuf: cmdBuf
    )
    // 4. patch_ln2: norm over hiddenDim
    try rmsNormSeq(
        input: denseOut,
        weight: weights.patchLn2Weight,
        bias: weights.patchLn2Bias,
        normDim: hiddenDim,
        seqLen: numPatches,
        output: normBuf,
        cmdBuf: cmdBuf
    )
    // 5. pos_norm: position normalization
    try rmsNormSeq(
        input: normBuf,
        weight: weights.posNormWeight,
        bias: weights.posNormBias,
        normDim: hiddenDim,
        seqLen: numPatches,
        output: denseOut,
        cmdBuf: cmdBuf
    )
    // 6. embedding_projection (optional): [numPatches, hiddenDim] -> [numPatches, outputDim]
    if let projection {
        try quantizedMatmul(
            input: denseOut,
            weight: projection.weight,
            scales: projection.scales,
            biases: projection.biases,
            bias: nil,
            inDim: hiddenDim, outDim: outputDim,
            seqLen: numPatches,
            output: outputBuffer,
            cmdBuf: cmdBuf
        )
    } else {
        // No projection: copy the pos_norm result straight to the output.
        guard let blitEnc = cmdBuf.makeBlitCommandEncoder() else {
            throw ForwardError.commandEncodingFailed
        }
        blitEnc.copy(from: denseOut, sourceOffset: 0,
                     to: outputBuffer, destinationOffset: 0,
                     size: numPatches * hiddenDim * floatStride)
        blitEnc.endEncoding()
    }

    // Commit only after every stage encoded successfully, so a thrown error
    // never submits a half-built pipeline to the GPU.
    cmdBuf.commit()
    cmdBuf.waitUntilCompleted()
    if let gpuError = cmdBuf.error {
        throw gpuError
    }
}
// ── GPU kernel dispatches ─────────────────────────
/// Encodes one `quantized_matmul_seq` dispatch over an `outDim` × `seqLen` grid.
///
/// Buffer bindings: 0 = input, 1 = packed weight, 2 = scales, 3 = biases,
/// 4 = bias (or `biases` again as a placeholder when `bias` is `nil`), 5 = output.
/// Constant bindings: 6 = inDim (UInt32), 7 = outDim (UInt32),
/// 8 = whether `bias` was supplied (1-byte Bool), 9 = seqLen (UInt32).
///
/// - Throws: Whatever `engine.pipeline(named:)` throws.
private func quantizedMatmul(
    input: MTLBuffer,
    weight: MTLBuffer,
    scales: MTLBuffer,
    biases: MTLBuffer,
    bias: MTLBuffer?,
    inDim: Int, outDim: Int,
    seqLen: Int,
    output: MTLBuffer,
    cmdBuf: MTLCommandBuffer
) throws {
    let pipeline = try engine.pipeline(named: "quantized_matmul_seq")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    // Slot 4 always needs a buffer bound, so fall back to `biases` when there is no bias.
    let boundBuffers: [MTLBuffer] = [input, weight, scales, biases, bias ?? biases, output]
    for (slot, buffer) in boundBuffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    var inputWidth = UInt32(inDim)
    var outputWidth = UInt32(outDim)
    var biasPresent = (bias != nil)
    var rowCount = UInt32(seqLen)
    encoder.setBytes(&inputWidth, length: MemoryLayout<UInt32>.size, index: 6)
    encoder.setBytes(&outputWidth, length: MemoryLayout<UInt32>.size, index: 7)
    encoder.setBytes(&biasPresent, length: MemoryLayout<Bool>.size, index: 8)
    encoder.setBytes(&rowCount, length: MemoryLayout<UInt32>.size, index: 9)

    let threadsPerGroup = engine.threadgroupSize2D(pipeline, grid: (outDim, seqLen))
    encoder.dispatchThreads(
        MTLSize(width: outDim, height: seqLen, depth: 1),
        threadsPerThreadgroup: threadsPerGroup
    )
    encoder.endEncoding()
}
/// Encodes one `rms_norm_seq` dispatch over a `normDim` × `seqLen` grid.
///
/// Buffer bindings: 0 = input, 1 = weight, 2 = output.
/// Constant bindings: 3 = normDim (UInt32), 4 = epsilon (Float, 1e-6),
/// 5 = seqLen (UInt32).
///
/// NOTE(review): `bias` is accepted but never bound to the encoder, so the
/// kernel does not receive it — confirm this is intended.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws.
private func rmsNormSeq(
    input: MTLBuffer,
    weight: MTLBuffer,
    bias: MTLBuffer,
    normDim: Int,
    seqLen: Int,
    output: MTLBuffer,
    cmdBuf: MTLCommandBuffer
) throws {
    let pipeline = try engine.pipeline(named: "rms_norm_seq")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    let boundBuffers: [MTLBuffer] = [input, weight, output]
    for (slot, buffer) in boundBuffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    var width = UInt32(normDim)
    var epsilon: Float = 1e-6
    var rowCount = UInt32(seqLen)
    encoder.setBytes(&width, length: MemoryLayout<UInt32>.size, index: 3)
    encoder.setBytes(&epsilon, length: MemoryLayout<Float>.size, index: 4)
    encoder.setBytes(&rowCount, length: MemoryLayout<UInt32>.size, index: 5)

    let threadsPerGroup = engine.threadgroupSize2D(pipeline, grid: (normDim, seqLen))
    encoder.dispatchThreads(
        MTLSize(width: normDim, height: seqLen, depth: 1),
        threadsPerThreadgroup: threadsPerGroup
    )
    encoder.endEncoding()
}
/// Encodes one `vision_add_pos_embed` dispatch over a `hiddenDim` × `numPatches` grid.
///
/// Buffer bindings: 0 = input, 1 = position-embedding table, 2 = output.
/// Constant bindings: 3 = hiddenDim (UInt32), 4 = numPatches (UInt32).
///
/// - Throws: Whatever `engine.pipeline(named:)` throws.
private func addPositionEmbedding(
    input: MTLBuffer,
    posEmbedding: MTLBuffer,
    numPatches: Int,
    hiddenDim: Int,
    output: MTLBuffer,
    cmdBuf: MTLCommandBuffer
) throws {
    let pipeline = try engine.pipeline(named: "vision_add_pos_embed")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    let boundBuffers: [MTLBuffer] = [input, posEmbedding, output]
    for (slot, buffer) in boundBuffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    var width = UInt32(hiddenDim)
    var patchCount = UInt32(numPatches)
    encoder.setBytes(&width, length: MemoryLayout<UInt32>.size, index: 3)
    encoder.setBytes(&patchCount, length: MemoryLayout<UInt32>.size, index: 4)

    let threadsPerGroup = engine.threadgroupSize2D(pipeline, grid: (hiddenDim, numPatches))
    encoder.dispatchThreads(
        MTLSize(width: hiddenDim, height: numPatches, depth: 1),
        threadsPerThreadgroup: threadsPerGroup
    )
    encoder.endEncoding()
}
/// Encodes one `eltwise_add` dispatch over `seqLen * dim` threads.
///
/// `input` is bound as both the first operand (slot 0) and the destination
/// (slot 2), so the result overwrites `input`; `bias` is bound at slot 1 and
/// the element count (UInt32) at slot 3.
///
/// - Throws: Whatever `engine.pipeline(named:)` throws.
private func eltwiseAdd(
    input: MTLBuffer,
    bias: MTLBuffer,
    seqLen: Int,
    dim: Int,
    cmdBuf: MTLCommandBuffer
) throws {
    let pipeline = try engine.pipeline(named: "eltwise_add")
    let encoder = cmdBuf.makeComputeCommandEncoder()!
    encoder.setComputePipelineState(pipeline)

    let boundBuffers: [MTLBuffer] = [input, bias, input]
    for (slot, buffer) in boundBuffers.enumerated() {
        encoder.setBuffer(buffer, offset: 0, index: slot)
    }

    let elementCount = seqLen * dim
    var elementCount32 = UInt32(elementCount)
    encoder.setBytes(&elementCount32, length: MemoryLayout<UInt32>.size, index: 3)

    let threadsPerGroup = engine.threadgroupSize1D(pipeline, count: elementCount)
    encoder.dispatchThreads(
        MTLSize(width: elementCount, height: 1, depth: 1),
        threadsPerThreadgroup: threadsPerGroup
    )
    encoder.endEncoding()
}
/// Loads the vision tower from the default shard,
/// `model-00002-of-00002.safetensors`, inside `modelDir`.
///
/// - Parameters:
///   - modelDir: Directory that contains the safetensors shard.
///   - engine: Engine whose Metal device receives the weights.
/// - Returns: A tower built with the default `VisionConfig12B`.
/// - Throws: Errors from opening or reading the shard, or from uploading the weights.
public static func load(modelDir: String, engine: MarkBaseEngine) throws -> VisionTower12B {
    try load(modelDir: modelDir, engine: engine, shardFile: "model-00002-of-00002.safetensors")
}

/// Loads the vision tower from a named safetensors shard inside `modelDir`.
///
/// Tensors whose dtype is `u32` are read as packed `UInt32` words. Every other
/// tensor is read raw and converted with `SafeTensorsReader.bf16ToFloat32`
/// (NOTE(review): this presumably assumes all non-`u32` vision tensors are
/// stored as bf16 — confirm against the checkpoint).
///
/// Names that the shard does not contain are skipped here and left for
/// `VisionWeights12B.init` to deal with.
///
/// - Parameters:
///   - modelDir: Directory that contains the safetensors shard.
///   - engine: Engine whose Metal device receives the weights.
///   - shardFile: File name of the shard holding the vision tensors.
/// - Returns: A tower built with the default `VisionConfig12B`.
/// - Throws: Errors from opening or reading the shard, or from uploading the weights.
public static func load(
    modelDir: String,
    engine: MarkBaseEngine,
    shardFile: String
) throws -> VisionTower12B {
    let reader = try SafeTensorsReader(path: "\(modelDir)/\(shardFile)")

    let visionKeys = [
        "vision_embedder.patch_dense.weight",
        "vision_embedder.patch_dense.bias",
        "vision_embedder.patch_dense.scales",
        "vision_embedder.patch_dense.biases",
        "vision_embedder.patch_ln1.weight",
        "vision_embedder.patch_ln1.bias",
        "vision_embedder.patch_ln2.weight",
        "vision_embedder.patch_ln2.bias",
        "vision_embedder.pos_embedding",
        "vision_embedder.pos_norm.weight",
        "vision_embedder.pos_norm.bias",
        "embed_vision.embedding_projection.weight",
        "embed_vision.embedding_projection.scales",
        "embed_vision.embedding_projection.biases"
    ]

    var floatTensors: [String: [Float]] = [:]
    var packedWeights: [String: [UInt32]] = [:]
    for name in visionKeys {
        guard let desc = reader.tensor(named: name) else { continue }
        if desc.dtype == TensorDType.u32 {
            packedWeights[name] = try reader.readUint32(named: name)
        } else {
            let raw = try reader.read(named: name)
            floatTensors[name] = SafeTensorsReader.bf16ToFloat32(raw)
        }
    }

    let weights = try VisionWeights12B(
        device: engine.device,
        tensors: floatTensors,
        packedWeights: packedWeights
    )
    return VisionTower12B(config: VisionConfig12B(), engine: engine, weights: weights)
}
}