// XCTest suite for the MarkBase E4B model (Swift).
import XCTest
|
|
@testable import MarkBase
|
|
|
|
/// Smoke and sanity tests for `E4BModel` forward passes.
///
/// The tests need the model weights on disk. When `model.safetensors` is not
/// present in `modelDir`, every test in this class is skipped. Any other
/// failure while creating the engine or the model is reported as an error
/// instead of being silently turned into a skip.
final class ModelTest: XCTestCase {

    // Implicitly unwrapped because they are assigned in `setUpWithError()`,
    // which must succeed before any test body runs.
    var engine: MarkBaseEngine!
    var model: E4BModel!

    /// Directory holding the model files.
    ///
    /// Can be overridden with the `MARKBASE_MODEL_DIR` environment variable;
    /// falls back to the original hard-coded location otherwise.
    let modelDir: String = ProcessInfo.processInfo.environment["MARKBASE_MODEL_DIR"]
        ?? "/Users/accusys/MarkBaseEngine/models/E4B-MarkBase"

    /// Value passed as `maxContextLength` when constructing the model.
    let maxCtx = 256

    // MARK: - Set-up / tear-down

    /// Loads the engine and the model, or skips the test when the weights file is missing.
    ///
    /// - Throws: `XCTSkip` when `model.safetensors` does not exist in `modelDir`;
    ///   otherwise rethrows any error raised while creating the engine or the model.
    override func setUpWithError() throws {
        try super.setUpWithError()

        let weightsPath = URL(fileURLWithPath: modelDir)
            .appendingPathComponent("model.safetensors")
            .path
        guard FileManager.default.fileExists(atPath: weightsPath) else {
            throw XCTSkip("E4B-MarkBase model not found")
        }

        // Plain `try` (not `try?`): a real load failure must surface as an error,
        // not masquerade as a "model not found" skip or leave `engine` nil.
        engine = try MarkBaseEngine(autoCompile: true)
        model = try E4BModel(modelDir: modelDir, engine: engine, maxContextLength: maxCtx)
    }

    /// Drops the references to the model and engine once a test has finished,
    /// so this test-case instance does not keep them alive afterwards.
    override func tearDown() {
        model = nil
        engine = nil
        super.tearDown()
    }

    // MARK: - Helpers

    /// Records a failure at the caller's location if `logits` contains any NaN.
    ///
    /// - Parameters:
    ///   - logits: Values to inspect.
    ///   - message: Failure description attached to the assertion.
    ///   - file: Call-site file, forwarded so the failure is attributed to the test.
    ///   - line: Call-site line, forwarded so the failure is attributed to the test.
    private func assertNoNaN(
        _ logits: [Float],
        _ message: String = "No NaN values in logits",
        file: StaticString = #filePath,
        line: UInt = #line
    ) {
        // Lazy filter: count the NaNs without allocating a second array.
        let nanCount = logits.lazy.filter { $0.isNaN }.count
        XCTAssertEqual(nanCount, 0, message, file: file, line: line)
    }

    // MARK: - Model Loading

    /// The model loads and reports the expected configuration values.
    func testModelLoads() throws {
        XCTAssertNotNil(model)
        XCTAssertEqual(model.vocabSize, 262_144)
        XCTAssertEqual(model.hiddenSize, 2560)
        XCTAssertEqual(model.numHiddenLayers, 42)
    }

    // MARK: - Forward Pass

    /// A forward pass for token id 2 at position 0 yields one logit per vocabulary
    /// entry, with no NaN and every value strictly inside (-50, 50).
    func testBosTokenLogits() throws {
        let logits = try model.forward(tokenId: 2, position: 0)
        XCTAssertEqual(logits.count, model.vocabSize)
        assertNoNaN(logits)

        // Bounding the minimum from below and the maximum from above bounds every
        // element. The fallbacks make the assertions fail when `logits` is empty.
        XCTAssertGreaterThan(logits.min() ?? -Float.infinity, -50)
        XCTAssertLessThan(logits.max() ?? Float.infinity, 50)
    }

    /// Every logit magnitude stays within the softcap (30) plus a small tolerance.
    func testLogitSoftcapping() throws {
        let logits = try model.forward(tokenId: 2, position: 0)
        let softcap: Float = 30.0
        let limit = softcap + 1e-3

        // Negated `<=` so that NaN (for which every comparison is false) is
        // treated as a violation, as a per-element `<=` assertion would.
        let offenders = logits.filter { !(abs($0) <= limit) }

        // Report once instead of emitting one failure per offending logit.
        if let first = offenders.first {
            XCTFail("Logit \(first) exceeds softcap \(softcap) (\(offenders.count) offending logits in total)")
        }
    }

    /// Feeds a short token sequence at consecutive positions and checks that each
    /// step produces NaN-free logits.
    ///
    /// NOTE(review): despite its name, this test does not compare repeated runs;
    /// it only checks the number of results and the absence of NaN.
    func testMultipleTokensDeterministic() throws {
        let tokens = [2, 1024, 2048, 4096]
        var allLogits: [[Float]] = []
        allLogits.reserveCapacity(tokens.count)

        for (pos, tokenId) in tokens.enumerated() {
            allLogits.append(try model.forward(tokenId: tokenId, position: pos))
        }

        XCTAssertEqual(allLogits.count, tokens.count)
        for logits in allLogits {
            assertNoNaN(logits)
        }
    }

    /// Two forward passes with identical arguments agree within a loose tolerance
    /// (max element difference < 2.0, mean difference < 0.1).
    func testDeterministicOutput() throws {
        let r1 = try model.forward(tokenId: 99, position: 0)
        let r2 = try model.forward(tokenId: 99, position: 0)
        XCTAssertEqual(r1.count, r2.count)

        let differences = zip(r1, r2).map { abs($0 - $1) }
        let maxDiff = differences.max() ?? 0
        // With no elements this is 0 / 0 = NaN, which fails the `<` check below.
        let avgDiff = differences.reduce(0, +) / Float(differences.count)

        XCTAssertLessThan(maxDiff, 2.0, "GPU determinism: max diff \(maxDiff) too large")
        XCTAssertLessThan(avgDiff, 0.1, "GPU determinism: avg diff \(avgDiff) too large")
    }

    /// Successive positions with different tokens produce different, NaN-free logits.
    func testKVCacheIncrements() throws {
        let r0 = try model.forward(tokenId: 2, position: 0)
        let r1 = try model.forward(tokenId: 1024, position: 1)
        let r2 = try model.forward(tokenId: 2048, position: 2)

        XCTAssertFalse(r0.elementsEqual(r1))
        XCTAssertFalse(r1.elementsEqual(r2))

        for logits in [r0, r1, r2] {
            assertNoNaN(logits)
        }
    }

    /// Different token ids at the same position produce different logits.
    func testDifferentTokensDifferentLogits() throws {
        let tokenA: [Float] = try model.forward(tokenId: 100, position: 0)
        let tokenB: [Float] = try model.forward(tokenId: 200, position: 0)
        XCTAssertNotEqual(tokenA, tokenB)
    }

    /// A spread of token ids each yields a full-size, NaN-free logit vector.
    func testRandomTokenId() throws {
        for tokenId in [0, 1, 100, 1000, 10000, 100000] {
            let logits = try model.forward(tokenId: tokenId, position: 0)
            assertNoNaN(logits, "No NaN for tokenId=\(tokenId)")
            XCTAssertEqual(logits.count, model.vocabSize)
        }
    }

    // MARK: - Batched context test

    /// Runs a 33-token prompt one position at a time and checks each step for NaN.
    func testFullContextForward() throws {
        let promptTokens = [2] + Array(repeating: 1024, count: 32)
        for (pos, tokenId) in promptTokens.enumerated() {
            let logits = try model.forward(tokenId: tokenId, position: pos)
            assertNoNaN(logits, "NaN at position \(pos)")
        }
    }
}
|