// Swift source (file-viewer header: 90 lines, 3.2 KiB)
import XCTest
|
|
@testable import MarkBase
|
|
|
|
/// Long-context smoke tests for the 12B checkpoint.
///
/// Each test feeds a synthetic prompt through `E4BModel.forward(tokenId:position:)` one token at a
/// time, then greedily decodes a few tokens, asserting at checkpoints that the logits contain no NaN.
///
/// Tests are skipped when `model.safetensors.index.json` is absent from `modelDir`. The directory
/// defaults to the original developer path and can be overridden with the
/// `MARKBASE_12B_MODEL_DIR` environment variable.
final class LongContext12BTest: XCTestCase {

    var engine: MarkBaseEngine!

    var model: E4BModel!

    /// Checkpoint directory; `MARKBASE_12B_MODEL_DIR` overrides the hard-coded default.
    let modelDir = ProcessInfo.processInfo.environment["MARKBASE_12B_MODEL_DIR"]
        ?? "/Users/accusys/MarkBaseEngine/models/gemma-4-12b-it-4bit"

    /// Context length the model is created with; prompt + generation must fit inside it.
    let maxCtx = 2048

    /// Number of greedy decode steps run after every prompt.
    private let generationSteps = 5

    override func setUpWithError() throws {
        try super.setUpWithError()

        let indexPath = URL(fileURLWithPath: modelDir)
            .appendingPathComponent("model.safetensors.index.json").path
        // A missing checkpoint is an environment issue, so the test is skipped.
        try XCTSkipUnless(FileManager.default.fileExists(atPath: indexPath),
                          "12B model not found at \(modelDir)")

        // A checkpoint that exists but fails to load is a real failure, so the error propagates
        // (the previous `try?` hid it, and a nil engine was then implicitly force-unwrapped).
        let engine = try MarkBaseEngine(autoCompile: true)
        self.engine = engine
        model = try E4BModel(modelDir: modelDir, engine: engine, maxContextLength: maxCtx)
    }

    override func tearDown() {
        // XCTest keeps every test-case instance alive until the run finishes. Without this,
        // each test method would keep its own copy of the model resident for the whole run.
        model = nil
        engine = nil
        super.tearDown()
    }

    func testLongContext256Tokens() throws {
        try runLongContext(promptLength: 256, checkpointInterval: 64, sampleCount: 5)
    }

    func testLongContext1024Tokens() throws {
        try runLongContext(promptLength: 1024, checkpointInterval: 128, sampleCount: 3)
    }

    // MARK: - Helpers

    /// Feeds a synthetic prompt token by token, then greedily decodes `generationSteps` tokens.
    ///
    /// - Parameters:
    ///   - promptLength: Number of prompt tokens; must be positive and leave room for generation
    ///     within `maxCtx`.
    ///   - checkpointInterval: Prompt positions that are multiples of this value (plus the last
    ///     prompt position) are NaN-checked and logged.
    ///   - sampleCount: How many leading logits to print at each prompt checkpoint.
    /// - Throws: Errors from `forward`, or an `XCTUnwrap` failure when the model is missing or
    ///   returns empty logits.
    private func runLongContext(promptLength: Int, checkpointInterval: Int, sampleCount: Int) throws {
        precondition(promptLength > 0 && checkpointInterval > 0,
                     "prompt length and checkpoint interval must be positive")
        precondition(promptLength + generationSteps <= maxCtx,
                     "prompt plus generation exceeds the model's context length")
        let model = try XCTUnwrap(self.model, "12B model not found")

        // Synthetic prompt: token ids cycle through 100..<1100.
        let tokens = (0..<promptLength).map { 100 + ($0 % 1000) }
        let lastPos = promptLength - 1

        // Every prompt token but the last: only the numerical health of the logits matters here.
        for pos in 0..<lastPos {
            let logits = try model.forward(tokenId: tokens[pos], position: pos)
            if pos % checkpointInterval == 0 {
                checkPromptLogits(logits, pos: pos, sampleCount: sampleCount)
            }
        }

        // The final prompt token's logits already predict the first generated token, so keep them
        // rather than re-running the same position.
        var logits = try model.forward(tokenId: tokens[lastPos], position: lastPos)
        checkPromptLogits(logits, pos: lastPos, sampleCount: sampleCount)

        // Greedy decode: each chosen token is fed at the next unused position.
        for step in 0..<generationSteps {
            let best = try XCTUnwrap(Self.argmax(logits), "empty logits before gen step \(step)")
            print("  gen[\(step)]: token=\(best.index) logit=\(best.value)")
            logits = try model.forward(tokenId: best.index, position: promptLength + step)
            XCTAssertEqual(Self.nanCount(logits), 0, "NaN at gen step \(step)")
        }
    }

    /// Asserts a prompt-position logits vector is NaN-free and logs its first `sampleCount` entries.
    private func checkPromptLogits<Logits: Collection>(_ logits: Logits, pos: Int, sampleCount: Int)
        where Logits.Element: FloatingPoint {
        let nans = Self.nanCount(logits)
        XCTAssertEqual(nans, 0, "NaN at pos=\(pos)")
        print("  pos=\(pos): logits[0..\(sampleCount)]=\(Array(logits.prefix(sampleCount))) NaN=\(nans)")
    }

    /// Number of NaN entries in `logits`.
    private static func nanCount<Logits: Sequence>(_ logits: Logits) -> Int
        where Logits.Element: FloatingPoint {
        return logits.reduce(into: 0) { count, value in
            if value.isNaN { count += 1 }
        }
    }

    /// Offset and value of the first maximal logit, or `nil` when `logits` is empty.
    /// NaN ordering is unspecified here; callers NaN-check the same logits separately.
    private static func argmax<Logits: Sequence>(_ logits: Logits) -> (index: Int, value: Logits.Element)?
        where Logits.Element: FloatingPoint {
        // `max(by:)` only replaces the running best on a strict increase, so ties keep the first index.
        guard let best = logits.enumerated().max(by: { $0.element < $1.element }) else {
            return nil
        }
        return (index: best.offset, value: best.element)
    }
}
|