v2: add full context 2048-token, repeated tokens, edge token tests
This commit is contained in:
@@ -54,6 +54,67 @@ final class LongContext12BTest: XCTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testFullContext2048Tokens() throws {
|
||||||
|
try XCTSkipIf(model == nil, "12B model not found")
|
||||||
|
|
||||||
|
let promptLength = maxCtx
|
||||||
|
var tokens = [Int]()
|
||||||
|
for i in 0..<promptLength {
|
||||||
|
tokens.append(100 + (i % 1000))
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastLogits: [Float]?
|
||||||
|
for (pos, tokenId) in tokens.enumerated() {
|
||||||
|
let logits = try model.forward(tokenId: tokenId, position: pos)
|
||||||
|
if pos == 0 || pos == promptLength - 1 || pos % 256 == 0 {
|
||||||
|
let nanCount = logits.filter { $0.isNaN }.count
|
||||||
|
XCTAssertEqual(nanCount, 0, "NaN at pos=\(pos)")
|
||||||
|
print(" pos=\(pos): logits[0..3]=\(logits.prefix(3)) NaN=\(nanCount)")
|
||||||
|
}
|
||||||
|
lastLogits = logits
|
||||||
|
}
|
||||||
|
|
||||||
|
var genTokens = tokens
|
||||||
|
for i in 0..<3 {
|
||||||
|
let logits = try model.forward(tokenId: genTokens.last ?? 0, position: genTokens.count - 1)
|
||||||
|
let nanCount = logits.filter { $0.isNaN }.count
|
||||||
|
XCTAssertEqual(nanCount, 0, "NaN at gen step \(i)")
|
||||||
|
var maxIdx = 0
|
||||||
|
var maxVal = logits[0]
|
||||||
|
for j in 1..<logits.count {
|
||||||
|
if logits[j] > maxVal { maxVal = logits[j]; maxIdx = j }
|
||||||
|
}
|
||||||
|
genTokens.append(maxIdx)
|
||||||
|
print(" gen[\(i)]: token=\(maxIdx) logit=\(maxVal)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testRepeatedTokensFullContext() throws {
|
||||||
|
try XCTSkipIf(model == nil, "12B model not found")
|
||||||
|
|
||||||
|
let promptLength = maxCtx / 2
|
||||||
|
for (pos, _) in (0..<promptLength).enumerated() {
|
||||||
|
let logits = try model.forward(tokenId: 100, position: pos)
|
||||||
|
if pos == 0 || pos == promptLength - 1 || pos % 256 == 0 {
|
||||||
|
let nanCount = logits.filter { $0.isNaN }.count
|
||||||
|
XCTAssertEqual(nanCount, 0, "NaN at pos=\(pos) (repeated tokens)")
|
||||||
|
print(" repeat pos=\(pos): logits[0..3]=\(logits.prefix(3)) NaN=\(nanCount)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTokenIdBoundaries() throws {
|
||||||
|
try XCTSkipIf(model == nil, "12B model not found")
|
||||||
|
|
||||||
|
let edgeTokens = [0, 1, 2, model.vocabSize - 1]
|
||||||
|
for (pos, tokenId) in edgeTokens.enumerated() {
|
||||||
|
let logits = try model.forward(tokenId: tokenId, position: pos)
|
||||||
|
let nanCount = logits.filter { $0.isNaN }.count
|
||||||
|
XCTAssertEqual(nanCount, 0, "NaN for tokenId=\(tokenId)")
|
||||||
|
print(" edge token=\(tokenId): logits[0..3]=\(logits.prefix(3)) NaN=\(nanCount)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func testLongContext1024Tokens() throws {
|
func testLongContext1024Tokens() throws {
|
||||||
try XCTSkipIf(model == nil, "12B model not found")
|
try XCTSkipIf(model == nil, "12B model not found")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user