v2: add multimodal 12B test, fix VisionTower12B kernel dispatch
This commit is contained in:
@@ -236,7 +236,7 @@ public final class VisionTower12B {
|
||||
output: MTLBuffer,
|
||||
cmdBuf: MTLCommandBuffer
|
||||
) throws {
|
||||
let pso = try engine.pipeline(named: "quantized_matmul")
|
||||
let pso = try engine.pipeline(named: "quantized_matmul_seq")
|
||||
let enc = cmdBuf.makeComputeCommandEncoder()!
|
||||
enc.setComputePipelineState(pso)
|
||||
|
||||
@@ -244,22 +244,22 @@ public final class VisionTower12B {
|
||||
enc.setBuffer(weight, offset: 0, index: 1)
|
||||
enc.setBuffer(scales, offset: 0, index: 2)
|
||||
enc.setBuffer(biases, offset: 0, index: 3)
|
||||
enc.setBuffer(output, offset: 0, index: 4)
|
||||
enc.setBuffer(bias ?? biases, offset: 0, index: 4)
|
||||
enc.setBuffer(output, offset: 0, index: 5)
|
||||
|
||||
var inD = UInt32(inDim)
|
||||
enc.setBytes(&inD, length: MemoryLayout<UInt32>.size, index: 5)
|
||||
enc.setBytes(&inD, length: 4, index: 6)
|
||||
var outD = UInt32(outDim)
|
||||
enc.setBytes(&outD, length: MemoryLayout<UInt32>.size, index: 6)
|
||||
enc.setBytes(&outD, length: 4, index: 7)
|
||||
var hasBias = bias != nil
|
||||
enc.setBytes(&hasBias, length: 1, index: 8)
|
||||
var sl = UInt32(seqLen)
|
||||
enc.setBytes(&sl, length: 4, index: 9)
|
||||
|
||||
let grid = MTLSize(width: outDim * seqLen, height: 1, depth: 1)
|
||||
let tg = engine.threadgroupSize1D(pso, count: max(outDim, seqLen))
|
||||
let grid = MTLSize(width: outDim, height: seqLen, depth: 1)
|
||||
let tg = engine.threadgroupSize2D(pso, grid: (outDim, seqLen))
|
||||
enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
|
||||
enc.endEncoding()
|
||||
|
||||
// Add unquantized bias if present
|
||||
if let b = bias {
|
||||
try eltwiseAdd(input: output, bias: b, seqLen: seqLen, dim: outDim, cmdBuf: cmdBuf)
|
||||
}
|
||||
}
|
||||
|
||||
private func rmsNormSeq(
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
import XCTest
|
||||
@testable import MarkBase
|
||||
|
||||
final class Multimodal12BTest: XCTestCase {
|
||||
|
||||
var engine: MarkBaseEngine!
|
||||
var multimodal: MultimodalModel!
|
||||
let modelDir = "/Users/accusys/MarkBaseEngine/models/gemma-4-12b-it-4bit"
|
||||
let maxCtx = 64
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
guard FileManager.default.fileExists(atPath: modelDir + "/model.safetensors.index.json") else {
|
||||
return
|
||||
}
|
||||
engine = try? MarkBaseEngine(autoCompile: true)
|
||||
multimodal = try? MultimodalModel(modelDir: modelDir, engine: engine, maxContextLength: maxCtx)
|
||||
}
|
||||
|
||||
func testModelLoads() throws {
|
||||
try XCTSkipIf(multimodal == nil, "12B model not found")
|
||||
XCTAssertEqual(multimodal!.textModel.hiddenSize, 3840)
|
||||
XCTAssertEqual(multimodal!.textModel.numHiddenLayers, 48)
|
||||
XCTAssertNotNil(multimodal!.visionTower, "VisionTower12B should load")
|
||||
XCTAssertNotNil(multimodal!.audioTower, "AudioTower12B should load")
|
||||
}
|
||||
|
||||
func testVisionTowerForward() throws {
|
||||
try XCTSkipIf(multimodal?.visionTower == nil, "Vision tower not loaded")
|
||||
let tower = multimodal!.visionTower!
|
||||
let numPatches = 8
|
||||
let patchDim = tower.patchDim
|
||||
|
||||
var patches = [Float](repeating: 0, count: numPatches * patchDim)
|
||||
for i in 0..<patches.count { patches[i] = Float.random(in: -0.5...0.5) }
|
||||
|
||||
let inputBuf = engine.device.makeBuffer(bytes: patches, length: patches.count * 4)!
|
||||
let outBuf = engine.device.makeBuffer(length: numPatches * tower.hiddenDim * 4)!
|
||||
|
||||
try tower.forward(patchEmbeddings: inputBuf, numPatches: numPatches, outputBuffer: outBuf)
|
||||
|
||||
let out = engine.readFloats(from: outBuf, count: numPatches * tower.hiddenDim)
|
||||
let nanCount = out.filter { $0.isNaN }.count
|
||||
XCTAssertEqual(nanCount, 0, "No NaN in vision output")
|
||||
|
||||
let maxAbs = out.map { abs($0) }.max() ?? 0
|
||||
XCTAssertLessThan(maxAbs, 1e6, "Vision output magnitude should be reasonable")
|
||||
XCTAssertGreaterThan(maxAbs, 0, "Vision output should have non-zero values")
|
||||
}
|
||||
|
||||
func testAudioTowerForward() throws {
|
||||
try XCTSkipIf(multimodal?.audioTower == nil, "Audio tower not loaded")
|
||||
let tower = multimodal!.audioTower!
|
||||
let numFrames = 16
|
||||
|
||||
var features = [Float](repeating: 0, count: numFrames * 640)
|
||||
for i in 0..<features.count { features[i] = Float.random(in: -1.0...1.0) }
|
||||
|
||||
let inputBuf = engine.device.makeBuffer(bytes: features, length: features.count * 4)!
|
||||
let outBuf = engine.device.makeBuffer(length: numFrames * tower.outDim * 4)!
|
||||
|
||||
try tower.forward(inputBuffer: inputBuf, seqLen: numFrames, outputBuffer: outBuf)
|
||||
|
||||
let out = engine.readFloats(from: outBuf, count: numFrames * tower.outDim)
|
||||
let nanCount = out.filter { $0.isNaN }.count
|
||||
XCTAssertEqual(nanCount, 0, "No NaN in audio output")
|
||||
}
|
||||
|
||||
func testTextBackboneForwardAfterVisionInjection() throws {
|
||||
try XCTSkipIf(multimodal?.visionTower == nil, "Vision tower not loaded")
|
||||
let tower = multimodal!.visionTower!
|
||||
let numPatches = 4
|
||||
let patchDim = tower.patchDim
|
||||
|
||||
var patches = [Float](repeating: 0, count: numPatches * patchDim)
|
||||
for i in 0..<patches.count { patches[i] = Float.random(in: -0.5...0.5) }
|
||||
|
||||
let inputBuf = engine.device.makeBuffer(bytes: patches, length: patches.count * 4)!
|
||||
let visionOut = engine.device.makeBuffer(length: numPatches * 3840 * 4)!
|
||||
try tower.forward(patchEmbeddings: inputBuf, numPatches: numPatches, outputBuffer: visionOut)
|
||||
|
||||
for i in 0..<numPatches {
|
||||
let offset = i * 3840 * 4
|
||||
let logits = try multimodal!.textModel.forwardFromHidden(
|
||||
hiddenBuffer: visionOut, offset: offset, position: i)
|
||||
let nanCount = logits.filter { $0.isNaN }.count
|
||||
XCTAssertEqual(nanCount, 0, "No NaN after vision injection pos=\(i)")
|
||||
}
|
||||
}
|
||||
|
||||
func testTextBackboneForwardAfterAudioInjection() throws {
|
||||
try XCTSkipIf(multimodal?.audioTower == nil, "Audio tower not loaded")
|
||||
let tower = multimodal!.audioTower!
|
||||
let numFrames = 4
|
||||
|
||||
var features = [Float](repeating: 0, count: numFrames * 640)
|
||||
for i in 0..<features.count { features[i] = Float.random(in: -1.0...1.0) }
|
||||
|
||||
let inputBuf = engine.device.makeBuffer(bytes: features, length: features.count * 4)!
|
||||
let audioOut = engine.device.makeBuffer(length: numFrames * 3840 * 4)!
|
||||
try tower.forward(inputBuffer: inputBuf, seqLen: numFrames, outputBuffer: audioOut)
|
||||
|
||||
for i in 0..<numFrames {
|
||||
let offset = i * 3840 * 4
|
||||
let logits = try multimodal!.textModel.forwardFromHidden(
|
||||
hiddenBuffer: audioOut, offset: offset, position: i)
|
||||
let nanCount = logits.filter { $0.isNaN }.count
|
||||
XCTAssertEqual(nanCount, 0, "No NaN after audio injection pos=\(i)")
|
||||
}
|
||||
}
|
||||
|
||||
func testMultimodalInferenceGenerate() throws {
|
||||
try XCTSkipIf(multimodal?.visionTower == nil, "Vision tower not loaded")
|
||||
let inference = try MultimodalInference(model: multimodal!)
|
||||
|
||||
let numPatches = 8
|
||||
let patchDim = multimodal!.visionTower!.patchDim
|
||||
var patches = [Float](repeating: 0, count: numPatches * patchDim)
|
||||
for i in 0..<patches.count { patches[i] = Float.random(in: -0.5...0.5) }
|
||||
|
||||
let audioDim = 640
|
||||
var audioFeatures = [[Float]]()
|
||||
for _ in 0..<32 {
|
||||
var frame = [Float](repeating: 0, count: audioDim)
|
||||
for j in 0..<audioDim { frame[j] = Float.random(in: -1.0...1.0) }
|
||||
audioFeatures.append(frame)
|
||||
}
|
||||
|
||||
let result = try inference.generate(
|
||||
textTokens: [2],
|
||||
audioFeatures: audioFeatures,
|
||||
imagePatches: patches,
|
||||
numImagePatches: numPatches,
|
||||
maxTokens: 5
|
||||
)
|
||||
|
||||
XCTAssertGreaterThan(result.count, 1, "Should generate at least one token")
|
||||
for token in result {
|
||||
XCTAssertGreaterThanOrEqual(token, 0, "Token ID should be non-negative")
|
||||
XCTAssertLessThan(token, multimodal!.textModel.vocabSize, "Token ID should be within vocab range")
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user