31427770b1
- Tokenizer fix: collect <0xXX> bytes and decode as UTF-8 (fixes Chinese/non-ASCII character decoding) - BPETokenizer + HuggingFaceTokenizer: both updated - Engine.swift: added writeFloats() utility method - FloatWeights struct added to Layer.swift (bf16 support) - attnQBits/KBits/VBits/OBits detection added to Model.swift - bf16 layer weight support from commit 48c0347 cherry-picked
379 lines
13 KiB
Swift
379 lines
13 KiB
Swift
import Foundation
|
|
|
|
// ─────────────────────────────────────────────────────────────
|
|
// HuggingFace Tokenizer (tokenizer.json format)
|
|
// Used by Gemma-4 and most modern models
|
|
// ─────────────────────────────────────────────────────────────
|
|
|
|
/// Tokenizer backed by a HuggingFace `tokenizer.json` file whose `model` section holds a
/// BPE vocabulary (`vocab`) and an ordered merge list (`merges`).
///
/// Special-token IDs are resolved from `SpecialTokens.gemma4`, looking in `added_tokens`
/// first, then in the model vocabulary, then falling back to fixed defaults.
public final class HuggingFaceTokenizer: Tokenizer {

    /// An ordered pair of adjacent symbols that a BPE merge combines.
    ///
    /// Merges are keyed by the pair rather than by the concatenated text: different pairs
    /// can concatenate to the same string (`"ab" + "c"` vs `"a" + "bc"`), and only the pair
    /// actually listed in `merges` may be merged.
    private struct MergePair: Hashable {
        let left: String
        let right: String
    }

    /// A piece of input text after added-token splitting.
    private enum Segment {
        /// Plain text that still has to go through BPE.
        case text(String)
        /// An exact occurrence of an added token, already resolved to its ID.
        case addedToken(Int)
    }

    private let vocab: [String: Int]
    private let reverseVocab: [Int: String]
    /// Merge priority per pair; a lower rank (earlier position in `merges`) is applied first.
    private let mergeRanks: [MergePair: Int]
    private let addedTokens: [String: Int]
    private let addedTokensReverse: [Int: String]
    /// Added-token strings grouped by their first character, longest first, so input can be
    /// matched greedily without testing every added token at every position.
    private let addedTokensByFirstCharacter: [Character: [String]]

    /// Number of entries in `model.vocab` (added tokens are not counted separately).
    public let vocabSize: Int
    public let bosTokenId: Int
    public let eosTokenId: Int
    /// `eosTokenId` plus any extra stop markers defined by the file.
    public let eosTokenIds: Set<Int>
    public let padTokenId: Int
    public let unkTokenId: Int

    private let specialTokens: SpecialTokens

    /// Loads a tokenizer from a HuggingFace `tokenizer.json` file.
    ///
    /// - Parameter jsonPath: Filesystem path of the `tokenizer.json` file.
    /// - Throws: Errors from reading or parsing the file, or
    ///   `TokenizerError.invalidModelFormat` when the top-level `model` object or its
    ///   `vocab` dictionary is missing.
    public init(jsonPath: String) throws {
        let data = try Data(contentsOf: URL(fileURLWithPath: jsonPath))
        let json = try JSONSerialization.jsonObject(with: data) as? [String: Any]

        guard let model = json?["model"] as? [String: Any],
              let vocabDict = model["vocab"] as? [String: Int] else {
            throw TokenizerError.invalidModelFormat
        }

        // Base vocabulary. `uniquingKeysWith` keeps a malformed file with duplicate IDs
        // from trapping the way `Dictionary(uniqueKeysWithValues:)` would.
        vocab = vocabDict
        reverseVocab = Dictionary(vocabDict.map { ($0.value, $0.key) }, uniquingKeysWith: { first, _ in first })
        vocabSize = vocabDict.count

        mergeRanks = HuggingFaceTokenizer.parseMergeRanks(model["merges"])

        let added = HuggingFaceTokenizer.parseAddedTokens(json?["added_tokens"])
        addedTokens = added
        addedTokensReverse = Dictionary(added.map { ($0.value, $0.key) }, uniquingKeysWith: { first, _ in first })

        var byFirstCharacter: [Character: [String]] = [:]
        for content in added.keys {
            guard let first = content.first else { continue }
            byFirstCharacter[first, default: []].append(content)
        }
        addedTokensByFirstCharacter = byFirstCharacter.mapValues { candidates in
            candidates.sorted { $0.count > $1.count }
        }

        // Special-token IDs: added_tokens first, then the base vocabulary, then a fixed default.
        let special = SpecialTokens.gemma4
        specialTokens = special
        let lookup: (String) -> Int? = { added[$0] ?? vocabDict[$0] }

        bosTokenId = lookup(special.bosToken) ?? 2
        let eos = lookup(special.eosToken) ?? 1
        eosTokenId = eos

        // Extra stop markers, included only when the file defines them.
        var stopIds: Set<Int> = [eos]
        for marker in ["<turn|>", "<|tool_response>"] {
            if let id = lookup(marker) { stopIds.insert(id) }
        }
        eosTokenIds = stopIds

        padTokenId = lookup(special.padToken) ?? 0
        unkTokenId = lookup("<unk>") ?? 3
    }

    /// Builds the merge-rank table from the raw `model.merges` JSON value.
    ///
    /// Accepts both serializations used by `tokenizer.json`: a list of `"left right"`
    /// strings and a list of `[left, right]` arrays. Malformed entries are skipped but still
    /// consume their rank (position). Returns an empty table when `merges` is absent.
    private static func parseMergeRanks(_ rawMerges: Any?) -> [MergePair: Int] {
        var ranks: [MergePair: Int] = [:]
        if let lines = rawMerges as? [String] {
            for (rank, line) in lines.enumerated() {
                let parts = line.split(separator: " ", maxSplits: 1)
                guard parts.count == 2 else { continue }
                ranks[MergePair(left: String(parts[0]), right: String(parts[1]))] = rank
            }
        } else if let pairs = rawMerges as? [[String]] {
            for (rank, parts) in pairs.enumerated() where parts.count == 2 {
                ranks[MergePair(left: parts[0], right: parts[1])] = rank
            }
        }
        return ranks
    }

    /// Reads the top-level `added_tokens` array into a `content -> id` map.
    /// Entries missing `content` or `id` are ignored.
    private static func parseAddedTokens(_ rawTokens: Any?) -> [String: Int] {
        guard let entries = rawTokens as? [[String: Any]] else { return [:] }
        var tokens: [String: Int] = [:]
        for entry in entries {
            if let content = entry["content"] as? String, let id = entry["id"] as? Int {
                tokens[content] = id
            }
        }
        return tokens
    }

    /// Encodes `text` into token IDs.
    ///
    /// The result is always framed as `[bosTokenId] + <text tokens> + [eosTokenId]`.
    /// Exact occurrences of added tokens in `text` are emitted as their IDs; the remaining
    /// text is BPE-encoded.
    public func encode(text: String) -> [Int] {
        [bosTokenId] + encodeBPE(text) + [eosTokenId]
    }

    /// Full encoding pipeline (without BOS/EOS framing):
    /// added-token split -> whitespace pre-tokenization -> initial symbols -> BPE merges -> IDs.
    ///
    /// NOTE(review): the `normalizer` / `pre_tokenizer` sections of tokenizer.json are not
    /// applied, so spaces are never rewritten to the `▁` marker that `decode` maps back to a
    /// space. Confirm the produced IDs against the reference tokenizer.
    private func encodeBPE(_ text: String) -> [Int] {
        var ids: [Int] = []
        for segment in splitOnAddedTokens(text) {
            switch segment {
            case .addedToken(let id):
                ids.append(id)
            case .text(let run):
                for word in pretokenize(run) {
                    for symbol in applyMerges(wordToSymbols(word)) {
                        ids.append(vocab[symbol] ?? addedTokens[symbol] ?? unkTokenId)
                    }
                }
            }
        }
        return ids
    }

    /// Splits `text` into plain-text runs and added-token IDs.
    ///
    /// Scans left to right. At each position, the longest added token that starts there wins.
    private func splitOnAddedTokens(_ text: String) -> [Segment] {
        guard !addedTokensByFirstCharacter.isEmpty else { return [.text(text)] }

        var segments: [Segment] = []
        var runStart = text.startIndex
        var cursor = text.startIndex
        while cursor < text.endIndex {
            if let candidates = addedTokensByFirstCharacter[text[cursor]],
               let match = candidates.first(where: { text[cursor...].hasPrefix($0) }),
               let id = addedTokens[match] {
                if runStart < cursor {
                    segments.append(.text(String(text[runStart..<cursor])))
                }
                segments.append(.addedToken(id))
                cursor = text.index(cursor, offsetBy: match.count)
                runStart = cursor
            } else {
                cursor = text.index(after: cursor)
            }
        }
        if runStart < text.endIndex {
            segments.append(.text(String(text[runStart...])))
        }
        return segments
    }

    /// Pre-tokenization: splits text into alternating runs of whitespace and
    /// non-whitespace characters. Whitespace runs are kept as their own words.
    private func pretokenize(_ text: String) -> [String] {
        var words: [String] = []
        var current = ""
        var currentIsWhitespace = false

        for character in text {
            let isWhitespace = character.isWhitespace
            if !current.isEmpty && isWhitespace != currentIsWhitespace {
                words.append(current)
                current = ""
            }
            current.append(character)
            currentIsWhitespace = isWhitespace
        }

        if !current.isEmpty {
            words.append(current)
        }
        return words
    }

    /// Converts a word into its initial BPE symbols.
    ///
    /// - Each Unicode scalar found in the vocabulary becomes one symbol.
    /// - Any other scalar is emitted as its UTF-8 bytes in `<0xXX>` byte-fallback form.
    ///
    /// Keeping characters whole is what lets merges apply to non-ASCII text.
    private func wordToSymbols(_ word: String) -> [String] {
        var symbols: [String] = []
        for scalar in word.unicodeScalars {
            let piece = String(scalar)
            if vocab[piece] != nil {
                symbols.append(piece)
            } else {
                for byte in piece.utf8 {
                    symbols.append(String(format: "<0x%02X>", byte))
                }
            }
        }
        return symbols
    }

    /// Repeatedly applies the lowest-ranked (highest-priority) merge, leftmost first,
    /// until no adjacent pair appears in `merges`.
    private func applyMerges(_ symbols: [String]) -> [String] {
        var parts = symbols

        while parts.count > 1 {
            var bestRank = Int.max
            var bestIndex: Int?

            for index in parts.indices.dropLast() {
                let pair = MergePair(left: parts[index], right: parts[index + 1])
                if let rank = mergeRanks[pair], rank < bestRank {
                    bestRank = rank
                    bestIndex = index
                }
            }

            guard let index = bestIndex else { break }
            parts[index] += parts[index + 1]
            parts.remove(at: index + 1)
        }

        return parts
    }

    /// Returns the stored string for `id`, preferring added tokens over the base vocabulary.
    public func rawToken(for id: Int) -> String? {
        addedTokensReverse[id] ?? reverseVocab[id]
    }

    /// Decodes token IDs back to text.
    ///
    /// - BOS, EOS/stop and padding IDs are skipped; IDs with no stored string are ignored.
    /// - Byte-fallback tokens (`<0xXX>`) are recognised per token. Consecutive ones are
    ///   decoded together as UTF-8, and malformed bytes become U+FFFD.
    /// - A multi-byte sequence cut off at the very end is omitted, since later tokens may
    ///   complete it.
    /// - `▁` in regular tokens becomes a space.
    /// - The result is trimmed of leading/trailing spaces and tabs.
    public func decode(tokens: [Int]) -> String {
        var text = ""
        var pendingBytes: [UInt8] = []

        for tokenId in tokens {
            if tokenId == bosTokenId || eosTokenIds.contains(tokenId) || tokenId == padTokenId {
                continue
            }
            guard let piece = rawToken(for: tokenId) else { continue }

            if let byte = HuggingFaceTokenizer.byteValue(ofByteToken: piece) {
                pendingBytes.append(byte)
                continue
            }
            if !pendingBytes.isEmpty {
                text += HuggingFaceTokenizer.decodeUTF8(pendingBytes, dropIncompleteTail: false)
                pendingBytes.removeAll()
            }
            // Convert the SentencePiece space marker back to a space.
            text += piece.replacingOccurrences(of: "▁", with: " ")
        }

        text += HuggingFaceTokenizer.decodeUTF8(pendingBytes, dropIncompleteTail: true)
        return text.trimmingCharacters(in: .whitespaces)
    }

    /// Returns the byte encoded by an exact `<0xHH>` byte-fallback token, or `nil` for any
    /// other token string.
    private static func byteValue(ofByteToken token: String) -> UInt8? {
        guard token.count == 6, token.hasPrefix("<0x"), token.hasSuffix(">") else { return nil }
        let hex = token.dropFirst(3).prefix(2)
        guard hex.allSatisfy({ $0.isHexDigit }) else { return nil }
        return UInt8(hex, radix: 16)
    }

    /// Decodes `bytes` as UTF-8, substituting U+FFFD for malformed sequences.
    /// When `dropIncompleteTail` is true, a multi-byte sequence truncated at the very end
    /// is omitted instead of being replaced.
    private static func decodeUTF8(_ bytes: [UInt8], dropIncompleteTail: Bool) -> String {
        var end = bytes.count
        if dropIncompleteTail {
            end -= incompleteTailLength(bytes)
        }
        return String(decoding: bytes[..<end], as: UTF8.self)
    }

    /// Number of trailing bytes that form the start of a multi-byte UTF-8 sequence whose
    /// remaining continuation bytes are missing (0 when the tail is complete or not UTF-8).
    private static func incompleteTailLength(_ bytes: [UInt8]) -> Int {
        var continuationCount = 0
        var index = bytes.count - 1
        while index >= 0, continuationCount < 3, (bytes[index] & 0xC0) == 0x80 {
            continuationCount += 1
            index -= 1
        }
        guard index >= 0 else { return 0 }

        let expectedLength: Int
        switch bytes[index] {
        case 0xC2...0xDF: expectedLength = 2
        case 0xE0...0xEF: expectedLength = 3
        case 0xF0...0xF4: expectedLength = 4
        default: return 0
        }
        let present = continuationCount + 1
        return present < expectedLength ? present : 0
    }
}
|
|
|
|
// ─────────────────────────────────────────────────────────────
|
|
// Simple Tokenizer (Fallback for testing)
|
|
// ─────────────────────────────────────────────────────────────
|
|
|
|
/// Minimal character-level tokenizer used as a fallback for testing.
///
/// IDs 0–3 are `<bos>`, `<eos>`, `<pad>`, `<unk>`. They are followed by one ID per
/// character of `a–z`, `A–Z`, `0–9` and the space character, in that order, starting at 4.
public final class SimpleTokenizer: Tokenizer {
    private let vocab: [String: Int]
    private let reverseVocab: [Int: String]

    public let vocabSize: Int
    public let bosTokenId: Int = 0
    public let eosTokenId: Int = 1
    public let eosTokenIds: Set<Int> = [1]
    public let padTokenId: Int = 2

    /// Builds the fixed test vocabulary.
    public init() {
        let specialEntries = ["<bos>", "<eos>", "<pad>", "<unk>"]
        let alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 "
        let entries = specialEntries + alphabet.map { String($0) }

        var forward: [String: Int] = [:]
        var backward: [Int: String] = [:]
        for (id, entry) in entries.enumerated() {
            forward[entry] = id
            backward[id] = entry
        }

        vocab = forward
        reverseVocab = backward
        vocabSize = forward.count
    }

    /// Encodes each character found in the vocabulary, framed by BOS and EOS.
    /// Characters outside the vocabulary are dropped.
    public func encode(text: String) -> [Int] {
        let body = text.compactMap { vocab[String($0)] }
        return [bosTokenId] + body + [eosTokenId]
    }

    /// Returns the vocabulary string for `id`, if any.
    public func rawToken(for id: Int) -> String? {
        reverseVocab[id]
    }

    /// Concatenates the strings for `tokens`, skipping BOS, EOS and padding IDs and any
    /// ID not in the vocabulary.
    public func decode(tokens: [Int]) -> String {
        let ignored = eosTokenIds.union([bosTokenId, padTokenId])
        return tokens
            .filter { !ignored.contains($0) }
            .compactMap { reverseVocab[$0] }
            .joined()
    }
}