Files
markbaseengine/Tests/01_Model/LongContext12BTest.swift
T
MarkBase Admin 16c16b9bee
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: add 1024-token long context test
2026-07-06 01:11:50 +08:00

90 lines
3.2 KiB
Swift

import XCTest
@testable import MarkBase
/// Long-context smoke tests for the Gemma 4 12B (4-bit) checkpoint.
///
/// Each test feeds a synthetic prompt through `E4BModel.forward(tokenId:position:)` one token at a
/// time, then greedily decodes a few tokens, asserting that no NaN logits appear along the way.
/// Tests are skipped when the checkpoint is not present on this machine.
final class LongContext12BTest: XCTestCase {
    var engine: MarkBaseEngine!
    var model: E4BModel!

    /// Checkpoint directory. Set `MARKBASE_12B_MODEL_DIR` to point at a different location;
    /// otherwise the original developer path is used.
    let modelDir = ProcessInfo.processInfo.environment["MARKBASE_12B_MODEL_DIR"]
        ?? "/Users/accusys/MarkBaseEngine/models/gemma-4-12b-it-4bit"

    /// Context length the model is built with; prompt + generated tokens must fit inside it.
    let maxCtx = 2048

    /// Number of greedy decode steps run after each prompt.
    private let generationSteps = 5

    override func setUpWithError() throws {
        try super.setUpWithError()
        let indexPath = URL(fileURLWithPath: modelDir)
            .appendingPathComponent("model.safetensors.index.json").path
        // Skip (rather than fail) on machines without the checkpoint, e.g. CI runners.
        guard FileManager.default.fileExists(atPath: indexPath) else {
            throw XCTSkip("12B model not found")
        }
        // The checkpoint exists, so a load failure is a genuine error: surface it instead of
        // letting `try?` turn it into a misleading "model not found" skip.
        engine = try MarkBaseEngine(autoCompile: true)
        model = try E4BModel(modelDir: modelDir, engine: engine, maxContextLength: maxCtx)
    }

    override func tearDownWithError() throws {
        // XCTest keeps every test-case instance alive until the whole run finishes, so release
        // the (presumably multi-gigabyte) model explicitly; otherwise each test method would pin
        // its own loaded copy for the rest of the run.
        model = nil
        engine = nil
        try super.tearDownWithError()
    }

    func testLongContext256Tokens() throws {
        try runLongContext(promptLength: 256, checkpointStride: 64, logPrefixCount: 5)
    }

    func testLongContext1024Tokens() throws {
        try runLongContext(promptLength: 1024, checkpointStride: 128, logPrefixCount: 3)
    }

    /// Prefills a synthetic prompt token by token, then greedily decodes `generationSteps` tokens,
    /// asserting that the logits contain no NaN at each checkpoint and at every decode step.
    ///
    /// - Parameters:
    ///   - promptLength: Number of prompt tokens; ids are `100 + (i % 1000)`.
    ///   - checkpointStride: Prompt logits are NaN-checked and logged at position 0, at the last
    ///     prompt position, and at every multiple of this stride. Must be positive.
    ///   - logPrefixCount: How many leading logits to print at each checkpoint.
    private func runLongContext(promptLength: Int, checkpointStride: Int, logPrefixCount: Int) throws {
        let model = try XCTUnwrap(self.model, "12B model not loaded")
        guard promptLength > 0, checkpointStride > 0, promptLength + generationSteps <= maxCtx else {
            XCTFail("invalid config: prompt=\(promptLength) stride=\(checkpointStride) gen=\(generationSteps) maxCtx=\(maxCtx)")
            return
        }
        let prompt = (0..<promptLength).map { 100 + ($0 % 1000) }

        // Prefill one token at a time. After the loop, `logits` holds the output for the final
        // prompt token, i.e. the distribution for the first generated token.
        var logits = try model.forward(tokenId: prompt[0], position: 0)
        for (position, tokenId) in prompt.enumerated() {
            if position > 0 {
                logits = try model.forward(tokenId: tokenId, position: position)
            }
            let isCheckpoint = position == 0
                || position == promptLength - 1
                || position % checkpointStride == 0
            guard isCheckpoint else { continue }
            let nanCount = logits.filter { $0.isNaN }.count
            XCTAssertEqual(nanCount, 0, "NaN at pos=\(position)")
            print(" pos=\(position): logits[0..\(logPrefixCount)]=\(logits.prefix(logPrefixCount)) NaN=\(nanCount)")
        }

        // Greedy decode. Step 0 reuses the final prefill logits instead of re-running the last
        // prompt token at a position it already occupies (which presumably re-writes its cache
        // slot); each later step feeds the previously chosen token at the next position.
        var generated: [Int] = []
        for step in 0..<generationSteps {
            if let previous = generated.last {
                logits = try model.forward(tokenId: previous, position: promptLength + generated.count - 1)
            }
            let nanCount = logits.filter { $0.isNaN }.count
            XCTAssertEqual(nanCount, 0, "NaN at gen step \(step)")
            // `max(by:)` keeps the first of equal maxima and never adopts a NaN, matching the
            // original strict-`>` scan; it returns nil (instead of trapping) on empty logits.
            let best = try XCTUnwrap(
                logits.enumerated().max { $0.element < $1.element },
                "empty logits at gen step \(step)"
            )
            generated.append(best.offset)
            print(" gen[\(step)]: token=\(best.offset) logit=\(best.element)")
        }
    }
}