v2: add long context 12B test (256 tokens)
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
import XCTest
|
||||
@testable import MarkBase
|
||||
|
||||
final class LongContext12BTest: XCTestCase {
|
||||
|
||||
var engine: MarkBaseEngine!
|
||||
var model: E4BModel!
|
||||
let modelDir = "/Users/accusys/MarkBaseEngine/models/gemma-4-12b-it-4bit"
|
||||
let maxCtx = 2048
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
guard FileManager.default.fileExists(atPath: modelDir + "/model.safetensors.index.json") else {
|
||||
return
|
||||
}
|
||||
engine = try? MarkBaseEngine(autoCompile: true)
|
||||
model = try? E4BModel(modelDir: modelDir, engine: engine, maxContextLength: maxCtx)
|
||||
}
|
||||
|
||||
func testLongContext256Tokens() throws {
|
||||
try XCTSkipIf(model == nil, "12B model not found")
|
||||
|
||||
let promptLength = 256
|
||||
var tokens = [Int]()
|
||||
for i in 0..<promptLength {
|
||||
tokens.append(100 + (i % 1000))
|
||||
}
|
||||
|
||||
for (pos, tokenId) in tokens.enumerated() {
|
||||
let logits = try model.forward(tokenId: tokenId, position: pos)
|
||||
if pos == 0 || pos == promptLength - 1 {
|
||||
let nanCount = logits.filter { $0.isNaN }.count
|
||||
XCTAssertEqual(nanCount, 0, "NaN at pos=\(pos)")
|
||||
}
|
||||
if pos % 64 == 0 {
|
||||
let sample = logits.prefix(5)
|
||||
let nanCount = logits.filter { $0.isNaN }.count
|
||||
print(" pos=\(pos): logits[0..5]=\(sample) NaN=\(nanCount)")
|
||||
}
|
||||
}
|
||||
|
||||
var genTokens = tokens
|
||||
for i in 0..<5 {
|
||||
let logits = try model.forward(tokenId: genTokens.last ?? 0, position: genTokens.count - 1)
|
||||
let nanCount = logits.filter { $0.isNaN }.count
|
||||
XCTAssertEqual(nanCount, 0, "NaN at gen step \(i)")
|
||||
var maxIdx = 0
|
||||
var maxVal = logits[0]
|
||||
for j in 1..<logits.count {
|
||||
if logits[j] > maxVal { maxVal = logits[j]; maxIdx = j }
|
||||
}
|
||||
genTokens.append(maxIdx)
|
||||
print(" gen[\(i)]: token=\(maxIdx) logit=\(maxVal)")
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user