8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
52 lines · 1.5 KiB · Swift
import Foundation
|
|
import MarkBase
|
|
|
|
// Simple CLI test for forward pass.
//
// Usage: <executable> [modelDir]
// The model directory defaults to ./models/gemma-4-12b-it-4bit. An optional
// first command-line argument overrides it, so the same binary can be pointed
// at another checkpoint without editing the source.
let modelDir = CommandLine.arguments.dropFirst().first ?? "./models/gemma-4-12b-it-4bit"

// Stop early with a non-zero exit status when config.json is absent, before
// any engine construction or weight loading is attempted. The diagnostic goes
// to stderr so wrapper scripts can keep it apart from normal progress output.
let configPath = URL(fileURLWithPath: modelDir).appendingPathComponent("config.json").path
guard FileManager.default.fileExists(atPath: configPath) else {
    FileHandle.standardError.write(Data("Model not found at \(modelDir)\n".utf8))
    exit(1)
}

// Create the MarkBase engine, load the model from `modelDir` with a 128-token
// context window, and report the model's layer count, hidden size and
// vocabulary size.
//
// Load time is measured with `ProcessInfo.systemUptime`, which is unaffected
// by wall-clock adjustments (NTP sync, manual clock changes). `Date()` can
// jump while the weights are loading and skew the reported duration.
print("Loading engine...")
let engine = try MarkBaseEngine(autoCompile: true)
print("✓ Engine created")

print("\nLoading 12B model...")
let start = ProcessInfo.processInfo.systemUptime
let model = try E4BModel(modelDir: modelDir, engine: engine, maxContextLength: 128)
let loadTime = ProcessInfo.processInfo.systemUptime - start

print("✓ Model loaded in \(String(format: "%.1f", loadTime))s")
print(" Layers: \(model.numHiddenLayers)")
print(" Hidden: \(model.hiddenSize)")
print(" Vocab: \(model.vocabSize)")

// Forward-pass smoke check: run a single step on token id 2 (reported as BOS
// in the output) at position 0 with the model's debug output enabled, then
// report how many logits came back.
print(
    "\n=== Testing forward pass ===",
    "Testing token 2 (BOS) at position 0...",
    separator: "\n"
)
let logits = try model.forward(tokenId: 2, position: 0, debug: true)
print("\n✓ Forward pass complete: \(logits.count) logits")

// Summary statistics over the returned logits.
//
// NaNs are counted once and excluded from the extrema. `max()`/`min()`
// compare with `<`, and every comparison involving NaN is false. On an array
// containing NaN the result therefore depends on element order: it is NaN only
// when the first element is NaN, and otherwise the NaNs are silently skipped.
// That would make the reported range meaningless for a broken forward pass.
// -999 is the placeholder printed when there is no non-NaN logit at all
// (e.g. an empty result).
let nanCount = logits.lazy.filter(\.isNaN).count
let hasNaN = nanCount > 0
let comparableLogits = logits.lazy.filter { !$0.isNaN }
let maxLogit = comparableLogits.max() ?? -999
let minLogit = comparableLogits.min() ?? -999

print(" Max logit: \(maxLogit)")
print(" Min logit: \(minLogit)")
print(" NaN count: \(nanCount)/\(logits.count)")
print(" Has NaN: \(hasNaN)")

// On a clean pass, list the ten highest-scoring token ids with their logits.
//
// A pass that produced NaNs, or returned no logits at all, is a failure: it is
// reported on stderr with a non-zero exit status. Previously the success banner
// was printed unconditionally, so scripts and CI running this binary could not
// detect a broken forward pass.
if !hasNaN {
    let top10 = logits.enumerated()
        .sorted { $0.element > $1.element }
        .prefix(10)
    print("\n Top 10 tokens:")
    for (tokenId, logit) in top10 {
        print(" Token \(tokenId): \(String(format: "%.4f", logit))")
    }
}

if hasNaN || logits.isEmpty {
    let reason = hasNaN
        ? "\(nanCount) of \(logits.count) logits are NaN"
        : "forward pass returned no logits"
    FileHandle.standardError.write(Data("\n❌ CLI test failed: \(reason)\n".utf8))
    exit(1)
}

print("\n✅ CLI test complete!")