v2: add E4B multimodal test, fix VisionTower missing groupSize
This commit is contained in:
@@ -77,6 +77,8 @@ public final class VisionTower {
|
|||||||
enc.setBytes(&inD, length: MemoryLayout<UInt32>.size, index: 5)
|
enc.setBytes(&inD, length: MemoryLayout<UInt32>.size, index: 5)
|
||||||
var outD = UInt32(weights.outDim)
|
var outD = UInt32(weights.outDim)
|
||||||
enc.setBytes(&outD, length: MemoryLayout<UInt32>.size, index: 6)
|
enc.setBytes(&outD, length: MemoryLayout<UInt32>.size, index: 6)
|
||||||
|
var groupSize = UInt32(weights.groupSize)
|
||||||
|
enc.setBytes(&groupSize, length: MemoryLayout<UInt32>.size, index: 7)
|
||||||
|
|
||||||
let grid = MTLSize(width: weights.outDim * seqLen, height: 1, depth: 1)
|
let grid = MTLSize(width: weights.outDim * seqLen, height: 1, depth: 1)
|
||||||
let tg = engine.threadgroupSize1D(pso, count: max(weights.outDim, seqLen))
|
let tg = engine.threadgroupSize1D(pso, count: max(weights.outDim, seqLen))
|
||||||
|
|||||||
@@ -0,0 +1,118 @@
|
|||||||
|
import XCTest
|
||||||
|
@testable import MarkBase
|
||||||
|
|
||||||
|
final class MultimodalE4BTest: XCTestCase {
|
||||||
|
|
||||||
|
var engine: MarkBaseEngine!
|
||||||
|
var multimodal: MultimodalModel!
|
||||||
|
let modelDir = "/Users/accusys/MarkBaseEngine/models/E4B-MarkBase"
|
||||||
|
let maxCtx = 64
|
||||||
|
|
||||||
|
override func setUp() {
|
||||||
|
super.setUp()
|
||||||
|
guard FileManager.default.fileExists(atPath: modelDir + "/model.safetensors") else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
engine = try? MarkBaseEngine(autoCompile: true)
|
||||||
|
multimodal = try? MultimodalModel(modelDir: modelDir, engine: engine, maxContextLength: maxCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testModelLoads() throws {
|
||||||
|
try XCTSkipIf(multimodal == nil, "E4B-MarkBase not found")
|
||||||
|
XCTAssertEqual(multimodal!.textModel.hiddenSize, 2560)
|
||||||
|
XCTAssertNotNil(multimodal!.visionTowerFull, "Full VisionTower should load")
|
||||||
|
XCTAssertNotNil(multimodal!.audioTowerFull, "Full AudioTower should load")
|
||||||
|
}
|
||||||
|
|
||||||
|
func testVisionTowerForward() throws {
|
||||||
|
try XCTSkipIf(multimodal?.visionTowerFull == nil, "Vision tower not loaded")
|
||||||
|
let tower = multimodal!.visionTowerFull!
|
||||||
|
let numPatches = 4
|
||||||
|
let patchDim = 768
|
||||||
|
let hs = tower.config.hiddenSize // 768
|
||||||
|
|
||||||
|
var patches = [Float](repeating: 0, count: numPatches * patchDim)
|
||||||
|
for i in 0..<patches.count { patches[i] = Float.random(in: -0.5...0.5) }
|
||||||
|
|
||||||
|
let inputBuf = engine.device.makeBuffer(bytes: patches, length: patches.count * 4)!
|
||||||
|
let outBuf = engine.device.makeBuffer(length: numPatches * hs * 4)!
|
||||||
|
|
||||||
|
try tower.forward(patchEmbeddings: inputBuf, numPatches: numPatches, outputBuffer: outBuf)
|
||||||
|
|
||||||
|
let out = engine.readFloats(from: outBuf, count: numPatches * hs)
|
||||||
|
let nanCount = out.filter { $0.isNaN }.count
|
||||||
|
XCTAssertEqual(nanCount, 0, "No NaN in vision output")
|
||||||
|
let maxAbs = out.map { abs($0) }.max() ?? 0
|
||||||
|
XCTAssertGreaterThan(maxAbs, 0, "Vision output should have non-zero values")
|
||||||
|
print(" vision: maxAbs=\(maxAbs)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAudioTowerForward() throws {
|
||||||
|
try XCTSkipIf(multimodal?.audioTowerFull == nil, "Audio tower not loaded")
|
||||||
|
let tower = multimodal!.audioTowerFull!
|
||||||
|
let numFrames = 16
|
||||||
|
let audioDim = 128
|
||||||
|
|
||||||
|
var features = [Float](repeating: 0, count: numFrames * audioDim)
|
||||||
|
for i in 0..<features.count { features[i] = Float.random(in: -1.0...1.0) }
|
||||||
|
|
||||||
|
let inputBuf = engine.device.makeBuffer(bytes: features, length: features.count * 4)!
|
||||||
|
let hs = tower.config.outputProjDims
|
||||||
|
let outBuf = engine.device.makeBuffer(length: numFrames / 4 * hs * 4)!
|
||||||
|
|
||||||
|
try tower.forward(inputBuffer: inputBuf, seqLen: numFrames, outputBuffer: outBuf)
|
||||||
|
|
||||||
|
let out = engine.readFloats(from: outBuf, count: numFrames / 4 * hs)
|
||||||
|
let nanCount = out.filter { $0.isNaN }.count
|
||||||
|
XCTAssertEqual(nanCount, 0, "No NaN in audio output")
|
||||||
|
let maxAbs = out.map { abs($0) }.max() ?? 0
|
||||||
|
XCTAssertGreaterThan(maxAbs, 0, "Audio output should have non-zero values")
|
||||||
|
print(" audio: maxAbs=\(maxAbs)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTextBackboneForwardAfterVisionInjection() throws {
|
||||||
|
try XCTSkipIf(multimodal?.visionTowerFull == nil, "Vision tower not loaded")
|
||||||
|
let tower = multimodal!.visionTowerFull!
|
||||||
|
let numPatches = 4
|
||||||
|
let patchDim = 768
|
||||||
|
let hs = tower.config.hiddenSize
|
||||||
|
|
||||||
|
var patches = [Float](repeating: 0, count: numPatches * patchDim)
|
||||||
|
for i in 0..<patches.count { patches[i] = Float.random(in: -0.5...0.5) }
|
||||||
|
|
||||||
|
let inputBuf = engine.device.makeBuffer(bytes: patches, length: patches.count * 4)!
|
||||||
|
let visionOut = engine.device.makeBuffer(length: numPatches * multimodal!.textModel.hiddenSize * 4)!
|
||||||
|
try tower.forward(patchEmbeddings: inputBuf, numPatches: numPatches, outputBuffer: visionOut)
|
||||||
|
|
||||||
|
for i in 0..<numPatches {
|
||||||
|
let offset = i * multimodal!.textModel.hiddenSize * 4
|
||||||
|
let logits = try multimodal!.textModel.forwardFromHidden(
|
||||||
|
hiddenBuffer: visionOut, offset: offset, position: i)
|
||||||
|
let nanCount = logits.filter { $0.isNaN }.count
|
||||||
|
XCTAssertEqual(nanCount, 0, "No NaN after vision injection pos=\(i)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTextBackboneForwardAfterAudioInjection() throws {
|
||||||
|
try XCTSkipIf(multimodal?.audioTowerFull == nil, "Audio tower not loaded")
|
||||||
|
let tower = multimodal!.audioTowerFull!
|
||||||
|
let numFrames = 16
|
||||||
|
let audioDim = 128
|
||||||
|
|
||||||
|
var features = [Float](repeating: 0, count: numFrames * audioDim)
|
||||||
|
for i in 0..<features.count { features[i] = Float.random(in: -1.0...1.0) }
|
||||||
|
|
||||||
|
let inputBuf = engine.device.makeBuffer(bytes: features, length: features.count * 4)!
|
||||||
|
let hs = tower.config.outputProjDims
|
||||||
|
let audioOut = engine.device.makeBuffer(length: numFrames / 4 * hs * 4)!
|
||||||
|
try tower.forward(inputBuffer: inputBuf, seqLen: numFrames, outputBuffer: audioOut)
|
||||||
|
|
||||||
|
for i in 0..<min(4, numFrames / 4) {
|
||||||
|
let offset = i * hs * 4
|
||||||
|
let logits = try multimodal!.textModel.forwardFromHidden(
|
||||||
|
hiddenBuffer: audioOut, offset: offset, position: i)
|
||||||
|
let nanCount = logits.filter { $0.isNaN }.count
|
||||||
|
XCTAssertEqual(nanCount, 0, "No NaN after audio injection pos=\(i)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user