diff --git a/Sources/MarkBase/Vision/VisionTower12B.swift b/Sources/MarkBase/Vision/VisionTower12B.swift index b4f0bd5..e8404c2 100644 --- a/Sources/MarkBase/Vision/VisionTower12B.swift +++ b/Sources/MarkBase/Vision/VisionTower12B.swift @@ -236,7 +236,7 @@ public final class VisionTower12B { output: MTLBuffer, cmdBuf: MTLCommandBuffer ) throws { - let pso = try engine.pipeline(named: "quantized_matmul") + let pso = try engine.pipeline(named: "quantized_matmul_seq") let enc = cmdBuf.makeComputeCommandEncoder()! enc.setComputePipelineState(pso) @@ -244,22 +244,22 @@ public final class VisionTower12B { enc.setBuffer(weight, offset: 0, index: 1) enc.setBuffer(scales, offset: 0, index: 2) enc.setBuffer(biases, offset: 0, index: 3) - enc.setBuffer(output, offset: 0, index: 4) + enc.setBuffer(bias ?? biases, offset: 0, index: 4) + enc.setBuffer(output, offset: 0, index: 5) var inD = UInt32(inDim) - enc.setBytes(&inD, length: MemoryLayout.size, index: 5) + enc.setBytes(&inD, length: 4, index: 6) var outD = UInt32(outDim) - enc.setBytes(&outD, length: MemoryLayout.size, index: 6) + enc.setBytes(&outD, length: 4, index: 7) + var hasBias = bias != nil + enc.setBytes(&hasBias, length: 1, index: 8) + var sl = UInt32(seqLen) + enc.setBytes(&sl, length: 4, index: 9) - let grid = MTLSize(width: outDim * seqLen, height: 1, depth: 1) - let tg = engine.threadgroupSize1D(pso, count: max(outDim, seqLen)) + let grid = MTLSize(width: outDim, height: seqLen, depth: 1) + let tg = engine.threadgroupSize2D(pso, grid: (outDim, seqLen)) enc.dispatchThreads(grid, threadsPerThreadgroup: tg) enc.endEncoding() - - // Add unquantized bias if present - if let b = bias { - try eltwiseAdd(input: output, bias: b, seqLen: seqLen, dim: outDim, cmdBuf: cmdBuf) - } } private func rmsNormSeq( diff --git a/Tests/01_Model/Multimodal12BTest.swift b/Tests/01_Model/Multimodal12BTest.swift new file mode 100644 index 0000000..604aa26 --- /dev/null +++ b/Tests/01_Model/Multimodal12BTest.swift @@ -0,0 +1,143 @@ +import XCTest +@testable import MarkBase + +final class Multimodal12BTest: XCTestCase { + + var engine: MarkBaseEngine! + var multimodal: MultimodalModel! + let modelDir = "/Users/accusys/MarkBaseEngine/models/gemma-4-12b-it-4bit" + let maxCtx = 64 + + override func setUp() { + super.setUp() + guard FileManager.default.fileExists(atPath: modelDir + "/model.safetensors.index.json") else { + return + } + engine = try? MarkBaseEngine(autoCompile: true) + multimodal = try? MultimodalModel(modelDir: modelDir, engine: engine, maxContextLength: maxCtx) + } + + func testModelLoads() throws { + try XCTSkipIf(multimodal == nil, "12B model not found") + XCTAssertEqual(multimodal!.textModel.hiddenSize, 3840) + XCTAssertEqual(multimodal!.textModel.numHiddenLayers, 48) + XCTAssertNotNil(multimodal!.visionTower, "VisionTower12B should load") + XCTAssertNotNil(multimodal!.audioTower, "AudioTower12B should load") + } + + func testVisionTowerForward() throws { + try XCTSkipIf(multimodal?.visionTower == nil, "Vision tower not loaded") + let tower = multimodal!.visionTower! + let numPatches = 8 + let patchDim = tower.patchDim + + var patches = [Float](repeating: 0, count: numPatches * patchDim) + for i in 0..