v2: add 12B model test (Model12BTest)
This commit is contained in:
@@ -7,5 +7,6 @@ Package.resolved
|
||||
*.xcodeproj/
|
||||
*.xcworkspace/
|
||||
.DS_Store
|
||||
blobs/
|
||||
test_summary.md.runner
|
||||
.runner
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import XCTest
|
||||
@testable import MarkBase
|
||||
|
||||
final class Model12BTest: XCTestCase {
|
||||
|
||||
var engine: MarkBaseEngine!
|
||||
var model: E4BModel!
|
||||
let modelDir = "/Users/accusys/MarkBaseEngine/models/gemma-4-12b-it-4bit"
|
||||
let maxCtx = 64
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
guard FileManager.default.fileExists(atPath: modelDir + "/model.safetensors.index.json") else {
|
||||
return
|
||||
}
|
||||
engine = try? MarkBaseEngine(autoCompile: true)
|
||||
model = try? E4BModel(modelDir: modelDir, engine: engine, maxContextLength: maxCtx)
|
||||
}
|
||||
|
||||
func testModelLoads() throws {
|
||||
try XCTSkipIf(model == nil, "gemma-4-12b-it-4bit model not found")
|
||||
XCTAssertNotNil(model)
|
||||
XCTAssertEqual(model.hiddenSize, 3840)
|
||||
XCTAssertEqual(model.numHiddenLayers, 48)
|
||||
XCTAssertEqual(model.vocabSize, 262144)
|
||||
}
|
||||
|
||||
func testBosTokenLogitsNoNaN() throws {
|
||||
try XCTSkipIf(model == nil, "gemma-4-12b-it-4bit model not found")
|
||||
let logits = try model.forward(tokenId: 2, position: 0)
|
||||
XCTAssertEqual(logits.count, model.vocabSize)
|
||||
let nanCount = logits.filter { $0.isNaN }.count
|
||||
XCTAssertEqual(nanCount, 0, "No NaN values in logits")
|
||||
}
|
||||
|
||||
func testLogitSoftcapping() throws {
|
||||
try XCTSkipIf(model == nil, "gemma-4-12b-it-4bit model not found")
|
||||
let logits = try model.forward(tokenId: 2, position: 0)
|
||||
let softcap: Float = 30.0
|
||||
for logit in logits {
|
||||
XCTAssertLessThanOrEqual(abs(logit), softcap + 0.1,
|
||||
"Logit \(logit) exceeds softcap \(softcap)")
|
||||
}
|
||||
}
|
||||
|
||||
func testMultipleTokensProduceDifferentLogits() throws {
|
||||
try XCTSkipIf(model == nil, "gemma-4-12b-it-4bit model not found")
|
||||
let tokens = [2, 100, 1000]
|
||||
for (pos, tokenId) in tokens.enumerated() {
|
||||
let logits = try model.forward(tokenId: tokenId, position: pos)
|
||||
let nanCount = logits.filter { $0.isNaN }.count
|
||||
XCTAssertEqual(nanCount, 0, "NaN for token=\(tokenId) pos=\(pos)")
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user