31427770b1
- Tokenizer fix: collect <0xXX> bytes and decode as UTF-8 (fixes Chinese/non-ASCII character decoding) - BPETokenizer + HuggingFaceTokenizer: both updated - Engine.swift: added writeFloats() utility method - FloatWeights struct added to Layer.swift (bf16 support) - attnQBits/KBits/VBits/OBits detection added to Model.swift - bf16 layer weight support from commit 48c0347 cherry-picked
438 lines
16 KiB
Swift
438 lines
16 KiB
Swift
import Foundation
|
|
|
|
// ─────────────────────────────────────────────────────────────
|
|
// Complete BPE Tokenizer for HuggingFace tokenizer.json format
|
|
// Supports Gemma, Llama, Qwen and other modern models
|
|
// ─────────────────────────────────────────────────────────────
|
|
|
|
/// Byte-pair-encoding tokenizer loaded from a HuggingFace `tokenizer.json` file.
///
/// The tokenizer reads the `model.vocab`, `model.merges`, `added_tokens` and
/// `pre_tokenizer` sections of the file. All state is immutable after `init`.
public final class BPETokenizer: Tokenizer, @unchecked Sendable {
    /// An ordered pair of symbols joined by one merge rule.
    ///
    /// Rules are keyed by the pair itself rather than by the concatenated text:
    /// different pairs can concatenate to the same string (`"a" + "bc"` and
    /// `"ab" + "c"`), and a text key would make them indistinguishable.
    private struct MergePair: Hashable {
        let left: String
        let right: String
    }

    /// Spelled-out length of a byte-fallback token such as `<0x0A>`.
    private static let byteTokenLength = 6

    private let vocab: [String: Int]
    private let reverseVocab: [Int: String]
    private let mergeRanks: [MergePair: Int]
    private let addedTokens: [String: Int]
    private let addedTokensReverse: [Int: String]
    private let preTokenizer: PreTokenizer?

    public let vocabSize: Int
    public let bosTokenId: Int
    public let eosTokenId: Int
    public let eosTokenIds: Set<Int>
    public let padTokenId: Int
    public let unkTokenId: Int

    /// Loads a tokenizer from the `tokenizer.json` file at `jsonPath`.
    ///
    /// - Parameter jsonPath: File-system path of the JSON file.
    /// - Throws: Any error raised while reading or parsing the file, or
    ///   `TokenizerError.invalidModelFormat` when the `model` object or its
    ///   `vocab` dictionary is missing.
    public init(jsonPath: String) throws {
        let data = try Data(contentsOf: URL(fileURLWithPath: jsonPath))
        let json = try JSONSerialization.jsonObject(with: data) as? [String: Any]

        guard let model = json?["model"] as? [String: Any],
              let vocabDict = model["vocab"] as? [String: Int] else {
            throw TokenizerError.invalidModelFormat
        }

        // Two tokens sharing one ID would trap with `uniqueKeysWithValues`;
        // resolve the clash deterministically instead of crashing on bad input.
        let reverse = Dictionary(vocabDict.map { ($0.value, $0.key) },
                                 uniquingKeysWith: { min($0, $1) })

        // Merge rules: the rank is the rule's position in the list (lower wins).
        var ranks: [MergePair: Int] = [:]
        if let pairs = model["merges"] as? [[String]] {
            // Pair format: every rule is a two-element array ["left", "right"].
            for (rank, pair) in pairs.enumerated() where pair.count == 2 {
                let key = MergePair(left: pair[0], right: pair[1])
                if ranks[key] == nil { ranks[key] = rank }
            }
        } else if let lines = model["merges"] as? [String] {
            // Legacy format: every rule is a single string "left right".
            for (rank, line) in lines.enumerated() {
                let parts = line.split(separator: " ", maxSplits: 1)
                guard parts.count == 2 else { continue }
                let key = MergePair(left: String(parts[0]), right: String(parts[1]))
                if ranks[key] == nil { ranks[key] = rank }
            }
        }

        // Added (special) tokens: content -> id.
        let added = (json?["added_tokens"] as? [[String: Any]])?
            .reduce(into: [String: Int]()) { dict, tokenInfo in
                if let content = tokenInfo["content"] as? String,
                   let id = tokenInfo["id"] as? Int {
                    dict[content] = id
                }
            } ?? [:]
        let addedReverse = Dictionary(added.map { ($0.value, $0.key) },
                                      uniquingKeysWith: { min($0, $1) })

        // Pre-tokenizer: only `type` and `behavior` are read from the JSON.
        let parsedPreTokenizer: PreTokenizer?
        if let section = json?["pre_tokenizer"] as? [String: Any],
           let type = section["type"] as? String {
            parsedPreTokenizer = PreTokenizer(type: type, behavior: section["behavior"] as? String)
        } else {
            parsedPreTokenizer = nil
        }

        // Special token IDs, with numeric fallbacks when the names are absent.
        let bos = added["<bos>"] ?? added["<start_of_turn>"] ?? vocabDict["<bos>"] ?? 2
        let eos = added["<eos>"] ?? added["<end_of_turn>"] ?? vocabDict["<eos>"] ?? 1
        var eosIds: Set<Int> = [eos]
        if let turnEnd = added["<turn|>"] ?? vocabDict["<turn|>"] { eosIds.insert(turnEnd) }
        if let toolResponse = added["<|tool_response>"] ?? vocabDict["<|tool_response>"] {
            eosIds.insert(toolResponse)
        }

        self.vocab = vocabDict
        self.reverseVocab = reverse
        self.vocabSize = vocabDict.count
        self.mergeRanks = ranks
        self.addedTokens = added
        self.addedTokensReverse = addedReverse
        self.preTokenizer = parsedPreTokenizer
        self.bosTokenId = bos
        self.eosTokenId = eos
        self.eosTokenIds = eosIds
        self.padTokenId = added["<pad>"] ?? vocabDict["<pad>"] ?? 0
        self.unkTokenId = vocabDict["<unk>"] ?? 3
    }

    /// Encodes `text` into token IDs; the result always starts with `bosTokenId`.
    public func encode(text: String) -> [Int] {
        var tokens: [Int] = [bosTokenId]
        tokens.append(contentsOf: encodeBPE(text))
        return tokens
    }

    /// Returns the unprocessed token string for `id`, preferring added tokens
    /// over the base vocabulary, or `nil` when the ID is unknown.
    public func rawToken(for id: Int) -> String? {
        addedTokensReverse[id] ?? reverseVocab[id]
    }

    /// Decodes token IDs to text.
    ///
    /// BOS, EOS and PAD tokens and unknown IDs are skipped. The concatenated
    /// token text is then post-processed by `cleanupText(_:)`.
    public func decode(tokens: [Int]) -> String {
        var text = ""
        for tokenId in tokens {
            if tokenId == bosTokenId || eosTokenIds.contains(tokenId) || tokenId == padTokenId {
                continue
            }
            if let token = addedTokensReverse[tokenId] ?? reverseVocab[tokenId] {
                text += token
            }
        }
        return cleanupText(text)
    }

    // ─────────────────────────────────────────────────────────────
    // BPE Encoding
    // ─────────────────────────────────────────────────────────────

    /// Pre-tokenizes `text`, runs the merge loop on every word and maps the
    /// resulting symbols to IDs (`unkTokenId` for symbols that have no ID).
    private func encodeBPE(_ text: String) -> [Int] {
        let words = preTokenizer?.pretokenize(text) ?? [text]

        var ids: [Int] = []
        for word in words where !word.isEmpty {
            let symbols = applyMerges(wordToTokens(word))
            for symbol in symbols {
                ids.append(vocab[symbol] ?? addedTokens[symbol] ?? unkTokenId)
            }
        }
        return ids
    }

    /// Splits a word into its initial symbols.
    ///
    /// Every Unicode scalar that exists in the vocabulary is used as-is, so
    /// merge rules can apply to it. Scalars missing from the vocabulary fall
    /// back to one `<0xHH>` byte token per UTF-8 byte.
    private func wordToTokens(_ word: String) -> [String] {
        var symbols: [String] = []
        for scalar in word.unicodeScalars {
            let symbol = String(scalar)
            if vocab[symbol] != nil {
                symbols.append(symbol)
            } else {
                for byte in symbol.utf8 {
                    symbols.append(String(format: "<0x%02X>", byte))
                }
            }
        }
        return symbols
    }

    /// Repeatedly merges the adjacent pair with the lowest rank until no
    /// adjacent pair matches a merge rule. Ties go to the leftmost pair.
    private func applyMerges(_ symbols: [String]) -> [String] {
        var current = symbols

        while current.count > 1 {
            var best: (rank: Int, position: Int)?
            for position in 0..<(current.count - 1) {
                let pair = MergePair(left: current[position], right: current[position + 1])
                if let rank = mergeRanks[pair], rank < (best?.rank ?? Int.max) {
                    best = (rank, position)
                }
            }

            guard let chosen = best else { break }
            current[chosen.position] += current[chosen.position + 1]
            current.remove(at: chosen.position + 1)
        }

        return current
    }

    /// Turns concatenated token text into display text: `▁` becomes a space,
    /// byte tokens are decoded, runs of spaces collapse to one space and
    /// leading/trailing whitespace is trimmed.
    private func cleanupText(_ text: String) -> String {
        var cleaned = text.replacingOccurrences(of: "▁", with: " ")
        cleaned = decodeByteTokens(cleaned)
        cleaned = cleaned.replacingOccurrences(of: " +", with: " ", options: .regularExpression)
        return cleaned.trimmingCharacters(in: .whitespaces)
    }

    /// Replaces every well-formed `<0xHH>` token in `text` with the byte it
    /// names and decodes the result as UTF-8.
    ///
    /// Anything that is not exactly `<0x` + two hex digits + `>` is kept as
    /// literal text. Bytes that do not form valid UTF-8 become U+FFFD instead
    /// of being dropped.
    private func decodeByteTokens(_ text: String) -> String {
        let source = Array(text.utf8)
        var decoded: [UInt8] = []
        decoded.reserveCapacity(source.count)

        var index = 0
        while index < source.count {
            if let byte = BPETokenizer.byteTokenValue(in: source, at: index) {
                decoded.append(byte)
                index += BPETokenizer.byteTokenLength
            } else {
                // Literal text: its UTF-8 bytes pass through unchanged.
                decoded.append(source[index])
                index += 1
            }
        }

        return String(decoding: decoded, as: UTF8.self)
    }

    /// Returns the byte named by a `<0xHH>` token that starts at `start`, or
    /// `nil` when the bytes at that position do not spell a complete token.
    private static func byteTokenValue(in utf8: [UInt8], at start: Int) -> UInt8? {
        guard start + byteTokenLength <= utf8.count,
              utf8[start] == UInt8(ascii: "<"),
              utf8[start + 1] == UInt8(ascii: "0"),
              utf8[start + 2] == UInt8(ascii: "x"),
              utf8[start + 5] == UInt8(ascii: ">"),
              let high = Character(UnicodeScalar(utf8[start + 3])).hexDigitValue,
              let low = Character(UnicodeScalar(utf8[start + 4])).hexDigitValue else {
            return nil
        }
        return UInt8((high << 4) | low)
    }
}
|
|
|
|
// ─────────────────────────────────────────────────────────────
|
|
// Streaming Decoder
|
|
// ─────────────────────────────────────────────────────────────
|
|
|
|
/// Incremental detokenizer for streaming generation.
///
/// Byte-fallback tokens (`<0xHH>`) carry a single byte each, so a multi-byte
/// character such as a Chinese ideograph (three bytes) arrives spread over
/// several tokens. The decoder buffers bytes and emits text only once complete
/// UTF-8 sequences are available, which avoids the missing or garbled
/// characters produced by decoding every byte on its own.
///
/// Text is emitted exactly as spelled by `rawToken(for:)`; no `▁`-to-space
/// substitution happens here.
public struct StreamingDecoder {
    private let tokenizer: any Tokenizer
    /// Bytes received so far that have not been emitted yet.
    private var byteBuffer: [UInt8] = []

    /// Creates a decoder that resolves token IDs through `tokenizer`.
    public init(tokenizer: any Tokenizer) {
        self.tokenizer = tokenizer
    }

    /// Consumes one token and returns any text that became complete.
    ///
    /// - Parameter tokenId: The next generated token ID.
    /// - Returns: Newly decodable text; empty for BOS/EOS/PAD tokens, unknown
    ///   IDs, or while a multi-byte character is still incomplete.
    public mutating func consume(tokenId: Int) -> String {
        let isSpecial = tokenizer.bosTokenId == tokenId
            || tokenizer.eosTokenIds.contains(tokenId)
            || tokenizer.padTokenId == tokenId
        guard !isSpecial, let raw = tokenizer.rawToken(for: tokenId) else {
            return ""
        }
        // Handles both <0xHH> byte tokens and literal text.
        extractBytes(from: raw, into: &byteBuffer)
        return drainUTF8()
    }

    /// Returns the still-buffered bytes as a lossy UTF-8 string.
    ///
    /// The buffer is not cleared; call `reset()` afterwards.
    public func flush() -> String {
        byteBuffer.isEmpty ? "" : String(decoding: byteBuffer, as: UTF8.self)
    }

    /// Discards all buffered bytes (e.g. when generation completes).
    public mutating func reset() {
        byteBuffer.removeAll()
    }

    // ── Private helpers ──

    /// Removes and decodes everything in the buffer except a trailing,
    /// still-incomplete multi-byte sequence.
    ///
    /// Malformed bytes (a stray continuation byte, an invalid lead byte, or a
    /// lead byte followed by a non-continuation byte) are consumed as well and
    /// appear as U+FFFD in the output. Leaving them in place would block the
    /// buffer head and silence all subsequent output.
    private mutating func drainUTF8() -> String {
        var cursor = 0

        scan: while cursor < byteBuffer.count {
            let expectedLength: Int
            switch byteBuffer[cursor] {
            case 0x00...0x7F: expectedLength = 1
            case 0xC2...0xDF: expectedLength = 2
            case 0xE0...0xEF: expectedLength = 3
            case 0xF0...0xF4: expectedLength = 4
            default:
                // Continuation byte without a lead, or a byte that can never
                // start a sequence: consume it so the stream keeps moving.
                cursor += 1
                continue scan
            }

            // Count the lead byte plus the continuation bytes present so far.
            var matched = 1
            while matched < expectedLength,
                  cursor + matched < byteBuffer.count,
                  byteBuffer[cursor + matched] & 0xC0 == 0x80 {
                matched += 1
            }

            if matched == expectedLength {
                cursor += expectedLength
            } else if cursor + matched == byteBuffer.count {
                // Ran out of bytes mid-sequence: keep the tail for the next token.
                break scan
            } else {
                // Sequence interrupted by a non-continuation byte: malformed.
                cursor += matched
            }
        }

        guard cursor > 0 else { return "" }
        let text = String(decoding: byteBuffer[..<cursor], as: UTF8.self)
        byteBuffer.removeFirst(cursor)
        return text
    }
}
|
|
|
|
/// Appends the bytes represented by a raw token string to `buffer`.
///
/// A well-formed byte token (`<0x` + two hex digits + `>`) contributes the
/// single byte it names. Everything else, including truncated or malformed
/// look-alikes such as `<0x4` or `<0x41b`, is treated as literal text and
/// contributes its UTF-8 bytes.
///
/// - Parameters:
///   - text: Raw token text, possibly containing `<0xHH>` byte tokens.
///   - buffer: Destination the bytes are appended to.
private func extractBytes(from text: String, into buffer: inout [UInt8]) {
    let tokenLength = 6 // "<0xHH>"

    /// Returns the byte named by a `<0xHH>` token starting at `start`, or
    /// `nil` when the bytes there do not spell a complete token.
    func byteToken(in utf8: [UInt8], at start: Int) -> UInt8? {
        guard start + tokenLength <= utf8.count,
              utf8[start] == UInt8(ascii: "<"),
              utf8[start + 1] == UInt8(ascii: "0"),
              utf8[start + 2] == UInt8(ascii: "x"),
              utf8[start + 5] == UInt8(ascii: ">"),
              let high = Character(UnicodeScalar(utf8[start + 3])).hexDigitValue,
              let low = Character(UnicodeScalar(utf8[start + 4])).hexDigitValue else {
            return nil
        }
        return UInt8((high << 4) | low)
    }

    let utf8 = Array(text.utf8)
    var index = 0
    while index < utf8.count {
        if let byte = byteToken(in: utf8, at: index) {
            buffer.append(byte)
            index += tokenLength
        } else {
            // Literal text: pass its UTF-8 bytes through unchanged.
            buffer.append(utf8[index])
            index += 1
        }
    }
}
|
|
|
|
// ─────────────────────────────────────────────────────────────
|
|
// Pre-Tokenization
|
|
// ─────────────────────────────────────────────────────────────
|
|
|
|
/// Splits input text into words before BPE merging.
///
/// The strategy is selected by `type`: `"Split"` and `"ByteLevel"` have
/// dedicated implementations; every other type returns the text unchanged as
/// a single word. `behavior` is stored but not consulted by any strategy.
final class PreTokenizer {
    let type: String
    let behavior: String?

    init(type: String, behavior: String? = nil) {
        self.type = type
        self.behavior = behavior
    }

    /// Returns the words `text` is split into according to `type`.
    func pretokenize(_ text: String) -> [String] {
        switch type {
        case "Split":
            return splitPreTokenize(text)
        case "ByteLevel":
            return byteLevelPreTokenize(text)
        default:
            return [text]
        }
    }

    /// Splits on spaces and marks words with the `▁` prefix.
    ///
    /// Runs of spaces produce no empty words. A word that is terminated by a
    /// space always gets the prefix; the final word gets it only when another
    /// word precedes it.
    ///
    /// NOTE(review): as a result `"hello"` yields `["hello"]` while
    /// `"hello world"` yields `["▁hello", "▁world"]`, so the first word is
    /// prefixed inconsistently. Behavior is kept as-is; confirm what the
    /// models expect before changing it.
    private func splitPreTokenize(_ text: String) -> [String] {
        var words: [String] = []
        var currentWord = ""

        // Single pass over the characters; no per-character index lookup.
        for char in text {
            guard char == " " else {
                currentWord.append(char)
                continue
            }
            if !currentWord.isEmpty {
                words.append("▁" + currentWord)
                currentWord = ""
            }
        }

        if !currentWord.isEmpty {
            words.append(words.isEmpty ? currentWord : "▁" + currentWord)
        }

        return words
    }

    /// Replaces every space with `▁` and returns the whole text as one word.
    private func byteLevelPreTokenize(_ text: String) -> [String] {
        [text.replacingOccurrences(of: " ", with: "▁")]
    }
}
|