388 lines
15 KiB
Swift
388 lines
15 KiB
Swift
import Metal
|
|
|
|
// Simplified vision tower for the 12B model.
// The 12B vision path is `vision_embedder` followed by an optional
// `embed_vision.embedding_projection`.
|
|
|
|
/// Shape parameters for the 12B vision tower.
///
/// Every field has a default, so `VisionConfig12B()` yields the stock 12B layout:
/// hidden width 3840, patch size 16, 1120 positions, output width 3840.
public struct VisionConfig12B {
    /// Hidden width of the vision embedder (default 3840).
    public let hiddenDim: Int
    /// Patch size (default 16).
    public let patchSize: Int
    /// Number of rows in the position-embedding table (default 1120).
    public let numPositions: Int
    /// Width of the projected output embeddings (default 3840).
    public let outputDim: Int

    /// Creates a configuration; omitted arguments fall back to the 12B defaults.
    public init(
        hiddenDim: Int = 3840,
        patchSize: Int = 16,
        numPositions: Int = 1120,
        outputDim: Int = 3840
    ) {
        self.outputDim = outputDim
        self.numPositions = numPositions
        self.patchSize = patchSize
        self.hiddenDim = hiddenDim
    }
}
|
|
|
|
/// Metal buffers holding the 12B vision-embedder and embedding-projection weights.
///
/// Float tensors are uploaded as Float32 arrays. Quantized matrices are uploaded as
/// the packed `UInt32` words exactly as provided.
public struct VisionWeights12B {
    /// Errors thrown while building the weight buffers.
    public enum LoadError: Error {
        /// A required tensor was absent from the provided dictionaries.
        case missingTensor(String)
        /// `MTLDevice.makeBuffer` returned nil for this tensor (e.g. an empty tensor).
        case bufferAllocationFailed(String)
        /// Only some of the embedding-projection weight/scales/biases were provided.
        case incompleteProjection
    }

    // patch_dense (quantized)
    public let patchDenseWeight: MTLBuffer
    public let patchDenseScales: MTLBuffer
    public let patchDenseBiases: MTLBuffer
    public let patchDenseBias: MTLBuffer

    // patch_ln1
    public let patchLn1Weight: MTLBuffer
    public let patchLn1Bias: MTLBuffer

    // patch_ln2
    public let patchLn2Weight: MTLBuffer
    public let patchLn2Bias: MTLBuffer

    // pos_embedding
    public let posEmbedding: MTLBuffer

    // pos_norm
    public let posNormWeight: MTLBuffer
    public let posNormBias: MTLBuffer

    // embedding_projection (quantized). Either all three are set or all three are nil.
    public let embeddingProjectionWeight: MTLBuffer?
    public let embeddingProjectionScales: MTLBuffer?
    public let embeddingProjectionBiases: MTLBuffer?

    /// Uploads the vision tensors to `device`.
    ///
    /// - Parameters:
    ///   - device: Device that allocates the new buffers.
    ///   - tensors: Float tensors keyed by checkpoint name.
    ///   - packedWeights: Packed `UInt32` tensors keyed by checkpoint name.
    /// - Throws:
    ///   - `LoadError.missingTensor` if a required `vision_embedder.*` tensor is absent.
    ///   - `LoadError.bufferAllocationFailed` if a buffer cannot be created.
    ///   - `LoadError.incompleteProjection` if the embedding projection is only partially present.
    public init(device: MTLDevice, tensors: [String: [Float]], packedWeights: [String: [UInt32]]) throws {
        patchDenseWeight = try Self.requiredBuffer("vision_embedder.patch_dense.weight", in: packedWeights, device: device)
        patchDenseScales = try Self.requiredBuffer("vision_embedder.patch_dense.scales", in: tensors, device: device)
        patchDenseBiases = try Self.requiredBuffer("vision_embedder.patch_dense.biases", in: tensors, device: device)
        patchDenseBias = try Self.requiredBuffer("vision_embedder.patch_dense.bias", in: tensors, device: device)

        patchLn1Weight = try Self.requiredBuffer("vision_embedder.patch_ln1.weight", in: tensors, device: device)
        patchLn1Bias = try Self.requiredBuffer("vision_embedder.patch_ln1.bias", in: tensors, device: device)
        patchLn2Weight = try Self.requiredBuffer("vision_embedder.patch_ln2.weight", in: tensors, device: device)
        patchLn2Bias = try Self.requiredBuffer("vision_embedder.patch_ln2.bias", in: tensors, device: device)

        posEmbedding = try Self.requiredBuffer("vision_embedder.pos_embedding", in: tensors, device: device)
        posNormWeight = try Self.requiredBuffer("vision_embedder.pos_norm.weight", in: tensors, device: device)
        posNormBias = try Self.requiredBuffer("vision_embedder.pos_norm.bias", in: tensors, device: device)

        let projWeight = packedWeights["embed_vision.embedding_projection.weight"]
        let projScales = tensors["embed_vision.embedding_projection.scales"]
        let projBiases = tensors["embed_vision.embedding_projection.biases"]

        switch (projWeight, projScales, projBiases) {
        case let (weight?, scales?, biases?):
            embeddingProjectionWeight = try Self.upload(
                weight, name: "embed_vision.embedding_projection.weight", device: device)
            embeddingProjectionScales = try Self.upload(
                scales, name: "embed_vision.embedding_projection.scales", device: device)
            embeddingProjectionBiases = try Self.upload(
                biases, name: "embed_vision.embedding_projection.biases", device: device)
        case (nil, nil, nil):
            // No projection in this checkpoint; the tower passes its hidden state through.
            embeddingProjectionWeight = nil
            embeddingProjectionScales = nil
            embeddingProjectionBiases = nil
        default:
            // A partial set would otherwise be silently skipped at inference time,
            // changing the output width without any error.
            throw LoadError.incompleteProjection
        }
    }

    /// Looks up `name` in `source` and uploads it; throws `missingTensor` if it is absent.
    private static func requiredBuffer<T>(_ name: String,
                                          in source: [String: [T]],
                                          device: MTLDevice) throws -> MTLBuffer {
        guard let values = source[name] else {
            throw LoadError.missingTensor(name)
        }
        return try upload(values, name: name, device: device)
    }

    /// Copies `values` into a new buffer of `count * stride` bytes.
    ///
    /// Throws `bufferAllocationFailed` when the device returns nil. This includes an
    /// empty array, since a zero-length buffer cannot be created.
    private static func upload<T>(_ values: [T], name: String, device: MTLDevice) throws -> MTLBuffer {
        guard let buffer = device.makeBuffer(bytes: values,
                                             length: values.count * MemoryLayout<T>.stride) else {
            throw LoadError.bufferAllocationFailed(name)
        }
        return buffer
    }
}
|
|
|
|
/// Metal implementation of the simplified 12B vision tower.
///
/// `forward` encodes these stages on one command buffer:
/// 1. patch_dense (quantized matmul)
/// 2. patch_ln1
/// 3. position-embedding add
/// 4. patch_ln2
/// 5. pos_norm
/// 6. optional embedding_projection (quantized matmul)
public final class VisionTower12B {
    /// Errors thrown by `forward` before or while encoding GPU work.
    public enum ForwardError: Error {
        /// `numPatches` is negative or exceeds the scratch-buffer / position-table capacity.
        case invalidPatchCount(Int, limit: Int)
        /// A caller-provided buffer is smaller than the requested work needs.
        case bufferTooSmall(name: String, required: Int, actual: Int)
        /// The engine's command queue could not create a command buffer.
        case commandBufferUnavailable
        /// A compute or blit command encoder could not be created.
        case encoderUnavailable
    }

    public let config: VisionConfig12B
    public let weights: VisionWeights12B
    public let engine: MarkBaseEngine

    // Derived dimensions

    /// Input width of patch_dense, derived from the packed weight size.
    public let patchDim: Int
    /// Output width of patch_dense (length of its bias vector).
    public let hiddenDim: Int
    /// Per-position width of pos_embedding: its element count / `config.numPositions`.
    public let posDim: Int
    /// Output width of embedding_projection (`config.outputDim`).
    public let outputDim: Int
    /// Largest `numPatches` the scratch buffers can hold.
    public let maxPatches: Int

    // Scratch buffers (Float32, shared storage).
    // They are instance state reused by every `forward` call, so concurrent `forward`
    // calls on the same instance would overwrite each other's intermediates.
    private let denseOut: MTLBuffer
    private let normBuf: MTLBuffer

    /// Creates the tower and allocates scratch space for up to `maxPatches` patches.
    ///
    /// - Parameters:
    ///   - config: Shape configuration; `numPositions` and `outputDim` are read here.
    ///   - engine: Supplies the device, command queue and compute pipelines.
    ///   - weights: Uploaded weights; `hiddenDim` and `patchDim` are derived from their sizes.
    ///   - maxPatches: Scratch capacity in patches (default 1024, the previous fixed value).
    public init(config: VisionConfig12B, engine: MarkBaseEngine, weights: VisionWeights12B,
                maxPatches: Int = 1024) {
        precondition(maxPatches > 0, "VisionTower12B: maxPatches must be positive")
        precondition(config.numPositions > 0, "VisionTower12B: config.numPositions must be positive")

        self.config = config
        self.weights = weights
        self.engine = engine
        self.maxPatches = maxPatches

        // Derive dimensions from the weight buffer sizes.
        let floatSize = MemoryLayout<Float>.stride
        let outDim = weights.patchDenseBias.length / floatSize
        precondition(outDim > 0, "VisionTower12B: patch_dense bias is empty")
        let packedLen = weights.patchDenseWeight.length / MemoryLayout<UInt32>.stride
        // NOTE(review): assumes 8 quantized values per packed UInt32 (4-bit quantization).
        // Confirm against the quantizer.
        self.patchDim = (packedLen / outDim) * 8
        self.hiddenDim = outDim
        self.posDim = weights.posEmbedding.length / floatSize / config.numPositions
        self.outputDim = config.outputDim

        self.denseOut = VisionTower12B.makeScratchBuffer(
            device: engine.device, floatCount: maxPatches * outDim)
        self.normBuf = VisionTower12B.makeScratchBuffer(
            device: engine.device, floatCount: maxPatches * max(outDim, config.outputDim))
    }

    /// Allocates a shared-storage Float32 scratch buffer, trapping with a message on failure.
    private static func makeScratchBuffer(device: MTLDevice, floatCount: Int) -> MTLBuffer {
        guard let buffer = device.makeBuffer(length: floatCount * MemoryLayout<Float>.stride,
                                             options: .storageModeShared) else {
            fatalError("VisionTower12B: failed to allocate a \(floatCount)-float scratch buffer")
        }
        return buffer
    }

    /// The embedding_projection buffers, or nil unless all three are loaded.
    private var projection: (weight: MTLBuffer, scales: MTLBuffer, biases: MTLBuffer)? {
        guard let weight = weights.embeddingProjectionWeight,
              let scales = weights.embeddingProjectionScales,
              let biases = weights.embeddingProjectionBiases else {
            return nil
        }
        return (weight: weight, scales: scales, biases: biases)
    }

    /// Row width `forward` writes to `outputBuffer`.
    ///
    /// This is `outputDim` with a projection, or `hiddenDim` when the pos_norm output
    /// is copied through unchanged.
    private var outputWidth: Int {
        projection == nil ? hiddenDim : outputDim
    }

    /// Processes vision patches and blocks until the GPU work has completed.
    ///
    /// - Parameters:
    ///   - patchEmbeddings: Float32 input of shape `[numPatches, patchDim]`.
    ///   - numPatches: Number of patch rows. Must be in `0...min(maxPatches, position-table rows)`.
    ///   - outputBuffer: Receives Float32 `[numPatches, outputDim]` when the projection is
    ///     loaded, otherwise `[numPatches, hiddenDim]`.
    /// - Throws:
    ///   - `ForwardError` for invalid sizes or when a command buffer or encoder cannot be created.
    ///   - Any error thrown by `engine.pipeline(named:)`.
    ///   - The command buffer's `error` if GPU execution fails.
    public func forward(patchEmbeddings: MTLBuffer, numPatches: Int, outputBuffer: MTLBuffer) throws {
        try validateForward(patchEmbeddings: patchEmbeddings, numPatches: numPatches,
                            outputBuffer: outputBuffer)
        // Nothing to compute; also avoids dispatching zero-sized grids.
        guard numPatches > 0 else { return }

        guard let cmdBuf = engine.commandQueue.makeCommandBuffer() else {
            throw ForwardError.commandBufferUnavailable
        }

        // 1. patch_dense: quantized matmul [numPatches, patchDim] -> [numPatches, hiddenDim]
        try quantizedMatmul(
            input: patchEmbeddings,
            weight: weights.patchDenseWeight,
            scales: weights.patchDenseScales,
            biases: weights.patchDenseBiases,
            bias: weights.patchDenseBias,
            inDim: patchDim, outDim: hiddenDim,
            seqLen: numPatches,
            output: denseOut,
            cmdBuf: cmdBuf
        )

        // 2. patch_ln1: RMS norm on hiddenDim
        try rmsNormSeq(input: denseOut, weight: weights.patchLn1Weight, bias: weights.patchLn1Bias,
                       normDim: hiddenDim, seqLen: numPatches, output: normBuf, cmdBuf: cmdBuf)

        // 3. pos_embedding: add position embeddings
        try addPositionEmbedding(input: normBuf, posEmbedding: weights.posEmbedding,
                                 numPatches: numPatches, hiddenDim: hiddenDim,
                                 output: denseOut, cmdBuf: cmdBuf)

        // 4. patch_ln2: RMS norm on hiddenDim
        try rmsNormSeq(input: denseOut, weight: weights.patchLn2Weight, bias: weights.patchLn2Bias,
                       normDim: hiddenDim, seqLen: numPatches, output: normBuf, cmdBuf: cmdBuf)

        // 5. pos_norm: position normalization
        try rmsNormSeq(input: normBuf, weight: weights.posNormWeight, bias: weights.posNormBias,
                       normDim: hiddenDim, seqLen: numPatches, output: denseOut, cmdBuf: cmdBuf)

        // 6. embedding_projection (optional): [numPatches, hiddenDim] -> [numPatches, outputDim]
        if let proj = projection {
            try quantizedMatmul(
                input: denseOut,
                weight: proj.weight,
                scales: proj.scales,
                biases: proj.biases,
                bias: nil,
                inDim: hiddenDim, outDim: outputDim,
                seqLen: numPatches,
                output: outputBuffer,
                cmdBuf: cmdBuf
            )
        } else {
            // No projection: copy the pos_norm output straight to outputBuffer.
            guard let blitEnc = cmdBuf.makeBlitCommandEncoder() else {
                throw ForwardError.encoderUnavailable
            }
            blitEnc.copy(from: denseOut, sourceOffset: 0,
                         to: outputBuffer, destinationOffset: 0,
                         size: numPatches * hiddenDim * MemoryLayout<Float>.stride)
            blitEnc.endEncoding()
        }

        // Commit only after every stage encoded successfully.
        // If anything above threw, the uncommitted buffer is simply discarded.
        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
        if let gpuError = cmdBuf.error {
            throw gpuError
        }
    }

    /// Rejects patch counts and buffer sizes that would make the kernels read or write out of bounds.
    private func validateForward(patchEmbeddings: MTLBuffer, numPatches: Int,
                                 outputBuffer: MTLBuffer) throws {
        let floatSize = MemoryLayout<Float>.stride
        // Rows of `hiddenDim` floats available in the position table.
        // NOTE(review): assumes vision_add_pos_embed reads numPatches * hiddenDim floats.
        let positionRows = weights.posEmbedding.length / (hiddenDim * floatSize)
        let limit = min(maxPatches, positionRows)
        guard numPatches >= 0, numPatches <= limit else {
            throw ForwardError.invalidPatchCount(numPatches, limit: limit)
        }

        let inputBytes = numPatches * patchDim * floatSize
        guard patchEmbeddings.length >= inputBytes else {
            throw ForwardError.bufferTooSmall(name: "patchEmbeddings", required: inputBytes,
                                              actual: patchEmbeddings.length)
        }

        let outputBytes = numPatches * outputWidth * floatSize
        guard outputBuffer.length >= outputBytes else {
            throw ForwardError.bufferTooSmall(name: "outputBuffer", required: outputBytes,
                                              actual: outputBuffer.length)
        }
    }

    // ── GPU kernel dispatches ─────────────────────────

    /// Encodes `quantized_matmul_seq`: `[seqLen, inDim]` × packed weight -> `[seqLen, outDim]`.
    ///
    /// When `bias` is nil, `biases` is bound at index 4 and the kernel receives `hasBias = false`.
    /// Presumably this only keeps that buffer slot populated.
    private func quantizedMatmul(
        input: MTLBuffer,
        weight: MTLBuffer,
        scales: MTLBuffer,
        biases: MTLBuffer,
        bias: MTLBuffer?,
        inDim: Int, outDim: Int,
        seqLen: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws {
        let pso = try engine.pipeline(named: "quantized_matmul_seq")
        guard let enc = cmdBuf.makeComputeCommandEncoder() else {
            throw ForwardError.encoderUnavailable
        }
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(scales, offset: 0, index: 2)
        enc.setBuffer(biases, offset: 0, index: 3)
        enc.setBuffer(bias ?? biases, offset: 0, index: 4)
        enc.setBuffer(output, offset: 0, index: 5)

        var inD = UInt32(inDim)
        enc.setBytes(&inD, length: MemoryLayout<UInt32>.size, index: 6)
        var outD = UInt32(outDim)
        enc.setBytes(&outD, length: MemoryLayout<UInt32>.size, index: 7)
        var hasBias = bias != nil
        enc.setBytes(&hasBias, length: MemoryLayout<Bool>.size, index: 8)
        var sl = UInt32(seqLen)
        enc.setBytes(&sl, length: MemoryLayout<UInt32>.size, index: 9)

        let grid = MTLSize(width: outDim, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (outDim, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `rms_norm_seq` over `seqLen` rows of `normDim` Float32 values, with eps = 1e-6.
    ///
    /// NOTE(review): `bias` is accepted but never bound to the kernel. The biases loaded for
    /// patch_ln1, patch_ln2 and pos_norm are therefore currently ignored. Confirm whether the
    /// reference model applies them, and whether these are RMS or LayerNorm.
    private func rmsNormSeq(
        input: MTLBuffer,
        weight: MTLBuffer,
        bias: MTLBuffer,
        normDim: Int,
        seqLen: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws {
        let pso = try engine.pipeline(named: "rms_norm_seq")
        guard let enc = cmdBuf.makeComputeCommandEncoder() else {
            throw ForwardError.encoderUnavailable
        }
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)

        var N = UInt32(normDim)
        enc.setBytes(&N, length: MemoryLayout<UInt32>.size, index: 3)
        var eps: Float = 1e-6
        enc.setBytes(&eps, length: MemoryLayout<Float>.size, index: 4)
        var sl = UInt32(seqLen)
        enc.setBytes(&sl, length: MemoryLayout<UInt32>.size, index: 5)

        let grid = MTLSize(width: normDim, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (normDim, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `vision_add_pos_embed` over a `[numPatches, hiddenDim]` grid, writing to `output`.
    private func addPositionEmbedding(
        input: MTLBuffer,
        posEmbedding: MTLBuffer,
        numPatches: Int,
        hiddenDim: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws {
        let pso = try engine.pipeline(named: "vision_add_pos_embed")
        guard let enc = cmdBuf.makeComputeCommandEncoder() else {
            throw ForwardError.encoderUnavailable
        }
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(posEmbedding, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)

        var hd = UInt32(hiddenDim)
        enc.setBytes(&hd, length: MemoryLayout<UInt32>.size, index: 3)
        var np = UInt32(numPatches)
        enc.setBytes(&np, length: MemoryLayout<UInt32>.size, index: 4)

        let grid = MTLSize(width: hiddenDim, height: numPatches, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (hiddenDim, numPatches))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `eltwise_add` over `seqLen * dim` elements.
    ///
    /// `input` is bound as both the first operand and the output, so the update is in place.
    /// Not called anywhere in this file.
    private func eltwiseAdd(
        input: MTLBuffer,
        bias: MTLBuffer,
        seqLen: Int,
        dim: Int,
        cmdBuf: MTLCommandBuffer
    ) throws {
        let pso = try engine.pipeline(named: "eltwise_add")
        guard let enc = cmdBuf.makeComputeCommandEncoder() else {
            throw ForwardError.encoderUnavailable
        }
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(bias, offset: 0, index: 1)
        enc.setBuffer(input, offset: 0, index: 2)

        var count = UInt32(seqLen * dim)
        enc.setBytes(&count, length: MemoryLayout<UInt32>.size, index: 3)

        let tg = engine.threadgroupSize1D(pso, count: seqLen * dim)
        enc.dispatchThreads(MTLSize(width: seqLen * dim, height: 1, depth: 1),
                            threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Loads the vision tower from a safetensors shard in `modelDir`.
    ///
    /// - Parameters:
    ///   - modelDir: Directory containing the shard.
    ///   - engine: Engine whose device receives the weights.
    ///   - shardFile: Shard file that holds the vision tensors (defaults to the previous fixed name).
    /// - Throws: Reader errors, or `VisionWeights12B.LoadError` if required tensors are missing.
    public static func load(modelDir: String, engine: MarkBaseEngine,
                            shardFile: String = "model-00002-of-00002.safetensors") throws -> VisionTower12B {
        let reader = try SafeTensorsReader(path: "\(modelDir)/\(shardFile)")

        var floatTensors: [String: [Float]] = [:]
        var packedWeights: [String: [UInt32]] = [:]

        let visionKeys = [
            "vision_embedder.patch_dense.weight",
            "vision_embedder.patch_dense.bias",
            "vision_embedder.patch_dense.scales",
            "vision_embedder.patch_dense.biases",
            "vision_embedder.patch_ln1.weight",
            "vision_embedder.patch_ln1.bias",
            "vision_embedder.patch_ln2.weight",
            "vision_embedder.patch_ln2.bias",
            "vision_embedder.pos_embedding",
            "vision_embedder.pos_norm.weight",
            "vision_embedder.pos_norm.bias",
            "embed_vision.embedding_projection.weight",
            "embed_vision.embedding_projection.scales",
            "embed_vision.embedding_projection.biases"
        ]

        // Missing keys are skipped here; VisionWeights12B decides which ones are required.
        for name in visionKeys {
            guard let desc = reader.tensor(named: name) else { continue }

            if desc.dtype == TensorDType.u32 {
                packedWeights[name] = try reader.readUint32(named: name)
            } else {
                // NOTE(review): every non-u32 tensor is decoded as bf16. Confirm no f16/f32
                // tensors appear in this shard.
                let raw = try reader.read(named: name)
                floatTensors[name] = SafeTensorsReader.bf16ToFloat32(raw)
            }
        }

        let weights = try VisionWeights12B(device: engine.device, tensors: floatTensors,
                                           packedWeights: packedWeights)
        return VisionTower12B(config: VisionConfig12B(), engine: engine, weights: weights)
    }
}
|