66 lines
2.6 KiB
Swift
66 lines
2.6 KiB
Swift
import XCTest
|
|
@testable import MarkBase
|
|
|
|
/// Integration tests for the `gemma-4-26b-standard` checkpoint loaded through `E4BModel`.
///
/// The tests need the model weights on local disk. When `model.safetensors` is not
/// present under `modelDir`, every test is skipped rather than failed. When the
/// weights are present but the engine or model cannot be created, the tests fail
/// with the underlying error.
final class Model26BTest: XCTestCase {

    // Fixtures are implicitly unwrapped because XCTest populates them in
    // `setUpWithError()` rather than in `init`; `tearDownWithError()` clears them.
    var engine: MarkBaseEngine!
    var model: E4BModel!

    /// Directory that is expected to contain `model.safetensors`.
    let modelDir = "/Users/accusys/MarkBaseEngine/models/gemma-4-26b-standard"

    /// Value passed as `maxContextLength` when the model is created.
    let maxCtx = 128

    /// Token id fed to the single-token tests (the id these tests treat as BOS).
    private let bosTokenId = 2

    // MARK: - Fixture lifecycle

    /// Skips the test when the weights are missing, otherwise loads engine and model.
    ///
    /// - Throws: `XCTSkip` when `model.safetensors` does not exist, or whatever error
    ///   the `MarkBaseEngine` / `E4BModel` initialisers throw.
    override func setUpWithError() throws {
        try super.setUpWithError()

        let weightsPath = URL(fileURLWithPath: modelDir)
            .appendingPathComponent("model.safetensors")
            .path
        try XCTSkipUnless(
            FileManager.default.fileExists(atPath: weightsPath),
            "gemma-4-26b-standard model not found"
        )

        // Plain `try` (not `try?`): once the weights are known to exist, a load
        // failure is a real error and must not be reported as "model not found".
        engine = try MarkBaseEngine(autoCompile: true)
        model = try E4BModel(modelDir: modelDir, engine: engine, maxContextLength: maxCtx)
    }

    /// Releases the fixtures created in `setUpWithError()`.
    override func tearDownWithError() throws {
        // XCTest creates one test-case instance per test method and keeps them
        // until the run finishes, so the fixtures are dropped explicitly to avoid
        // holding one loaded model per test.
        model = nil
        engine = nil
        try super.tearDownWithError()
    }

    // MARK: - Helpers

    /// Returns how many elements of `values` are NaN.
    private func nanCount<S: Sequence>(in values: S) -> Int where S.Element: FloatingPoint {
        values.reduce(into: 0) { count, value in
            if value.isNaN { count += 1 }
        }
    }

    // MARK: - Tests

    /// Checks that the model loaded and reports the expected dimensions.
    func testModelLoads() throws {
        XCTAssertNotNil(model)
        XCTAssertEqual(model.hiddenSize, 2816)
        XCTAssertEqual(model.numHiddenLayers, 30)
        XCTAssertEqual(model.vocabSize, 262144)
    }

    /// A forward pass on the BOS token yields one logit per vocabulary entry, none NaN.
    func testBosTokenLogitsNoNaN() throws {
        let logits = try model.forward(tokenId: bosTokenId, position: 0)

        XCTAssertEqual(logits.count, model.vocabSize)
        XCTAssertEqual(nanCount(in: logits), 0, "No NaN values in logits")
    }

    /// The logits must take many distinct values after rounding to one decimal.
    func testLogitsNotAllSaturated() throws {
        let logits = try model.forward(tokenId: bosTokenId, position: 0)

        // 26B has no softcapping, so logits should have variation
        let uniqueCount = Set(logits.map { ($0 * 10).rounded() / 10 }).count
        XCTAssertGreaterThan(uniqueCount, 100, "Logits should have meaningful variation")
    }

    /// The extreme logits stay inside fixed sanity bounds and are not all equal.
    func testLogitsReasonableRange() throws {
        let logits = try model.forward(tokenId: bosTokenId, position: 0)

        // Unwrap instead of defaulting to 0 so an empty result fails with a clear reason.
        let maxVal = try XCTUnwrap(logits.max(), "forward() returned no logits")
        let minVal = try XCTUnwrap(logits.min(), "forward() returned no logits")

        // NOTE(review): thresholds are kept as found; they look like loose sanity
        // limits rather than model-derived values - confirm before tightening.
        XCTAssertGreaterThan(maxVal, -100)
        XCTAssertLessThan(maxVal, 100000)
        XCTAssertGreaterThan(minVal, -100000)
        XCTAssertLessThan(minVal, 25000)
        XCTAssertGreaterThan(maxVal, minVal, "Logits should have dynamic range")
    }

    /// Feeds several tokens at consecutive positions and checks that each result is
    /// NaN-free and differs from every earlier result.
    func testMultipleTokensProduceDifferentLogits() throws {
        let tokens = [2, 100, 1000, 10000]

        // Forward passes run in the same order as before (position 0, 1, 2, 3).
        let outputs = try tokens.enumerated().map { position, tokenId in
            try Array(model.forward(tokenId: tokenId, position: position))
        }

        for (position, logits) in outputs.enumerated() {
            let tokenId = tokens[position]
            XCTAssertEqual(nanCount(in: logits), 0, "NaN for token=\(tokenId) pos=\(position)")

            // Boolean assertion on purpose: XCTAssertNotEqual would dump both
            // full logit arrays into the failure message.
            for earlier in 0..<position {
                XCTAssertTrue(
                    logits != outputs[earlier],
                    "Identical logits for token=\(tokenId) pos=\(position) and token=\(tokens[earlier]) pos=\(earlier)"
                )
            }
        }
    }
}
|