Files
markbaseengine/Sources/MarkBase/Tokenizer/BPETokenizer.swift
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

429 lines
16 KiB
Swift

import Foundation
// ─────────────────────────────────────────────────────────────
// Complete BPE Tokenizer for HuggingFace tokenizer.json format
// Supports Gemma, Llama, Qwen and other modern models
// ─────────────────────────────────────────────────────────────
/// Byte-pair-encoding tokenizer backed by a HuggingFace `tokenizer.json` file.
///
/// `model.vocab` and `model.merges` drive encoding, `added_tokens` supplies the
/// special tokens, and the optional `pre_tokenizer` entry selects how text is
/// split before merging. Every stored property is a `let`, so an instance is
/// immutable once `init` returns.
public final class BPETokenizer: Tokenizer, @unchecked Sendable {
    /// An ordered pair of adjacent symbols that BPE may fuse into one symbol.
    ///
    /// Ranks are keyed by the pair itself rather than by the concatenated text:
    /// ("a", "bc") and ("ab", "c") are different merges with different ranks.
    private struct MergePair: Hashable {
        let left: String
        let right: String
    }

    private let vocab: [String: Int]
    private let reverseVocab: [Int: String]
    private let mergeRanks: [MergePair: Int]
    private let addedTokens: [String: Int]
    private let addedTokensReverse: [Int: String]
    private let preTokenizer: PreTokenizer?

    public let vocabSize: Int
    public let bosTokenId: Int
    public let eosTokenId: Int
    public let eosTokenIds: Set<Int>
    public let padTokenId: Int
    public let unkTokenId: Int

    /// Loads a tokenizer from the `tokenizer.json` file at `jsonPath`.
    ///
    /// - Parameter jsonPath: File-system path of the JSON document.
    /// - Throws: Any error raised while reading or parsing the file, or
    ///   `TokenizerError.invalidModelFormat` when the top-level object, its
    ///   `model` entry or `model.vocab` is missing or has the wrong shape.
    public init(jsonPath: String) throws {
        let data = try Data(contentsOf: URL(fileURLWithPath: jsonPath))
        guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any],
              let model = json["model"] as? [String: Any],
              let vocabDict = model["vocab"] as? [String: Int] else {
            throw TokenizerError.invalidModelFormat
        }

        // Merge ranks: the position in the list is the priority (lower merges first).
        var ranks: [MergePair: Int] = [:]
        if let pairs = model["merges"] as? [[String]] {
            // Newer format: each merge is a two-element array ["left", "right"].
            ranks.reserveCapacity(pairs.count)
            for (rank, pair) in pairs.enumerated() where pair.count == 2 {
                let key = MergePair(left: pair[0], right: pair[1])
                // A repeated pair keeps its first (highest-priority) rank.
                if ranks[key] == nil { ranks[key] = rank }
            }
        } else if let lines = model["merges"] as? [String] {
            // Legacy format: each merge is a single string "left right".
            ranks.reserveCapacity(lines.count)
            for (rank, line) in lines.enumerated() {
                let parts = line.split(separator: " ", maxSplits: 1)
                guard parts.count == 2 else { continue }
                let key = MergePair(left: String(parts[0]), right: String(parts[1]))
                if ranks[key] == nil { ranks[key] = rank }
            }
        }

        // Added (special) tokens; entries lacking a string `content` or integer `id` are skipped.
        var added: [String: Int] = [:]
        for info in (json["added_tokens"] as? [[String: Any]]) ?? [] {
            if let content = info["content"] as? String, let id = info["id"] as? Int {
                added[content] = id
            }
        }

        // Two entries sharing an id must not crash the load; the lexicographically
        // smaller string wins so the outcome does not depend on dictionary order.
        let reverse = Dictionary(vocabDict.map { ($0.value, $0.key) },
                                 uniquingKeysWith: { min($0, $1) })
        let addedReverse = Dictionary(added.map { ($0.value, $0.key) },
                                      uniquingKeysWith: { min($0, $1) })

        let parsedPreTokenizer: PreTokenizer?
        if let spec = json["pre_tokenizer"] as? [String: Any],
           let type = spec["type"] as? String {
            parsedPreTokenizer = PreTokenizer(type: type, behavior: spec["behavior"] as? String)
        } else {
            parsedPreTokenizer = nil
        }

        // Special token ids, with hard-coded fallbacks when the file names none.
        let bos = added["<bos>"] ?? added["<start_of_turn>"] ?? vocabDict["<bos>"] ?? 2
        let eos = added["<eos>"] ?? added["<end_of_turn>"] ?? vocabDict["<eos>"] ?? 1
        var eosIds: Set<Int> = [eos]
        for stopToken in ["<turn|>", "<|tool_response>"] {
            if let id = added[stopToken] ?? vocabDict[stopToken] {
                eosIds.insert(id)
            }
        }

        self.vocab = vocabDict
        self.reverseVocab = reverse
        self.vocabSize = vocabDict.count
        self.mergeRanks = ranks
        self.addedTokens = added
        self.addedTokensReverse = addedReverse
        self.preTokenizer = parsedPreTokenizer
        self.bosTokenId = bos
        self.eosTokenId = eos
        self.eosTokenIds = eosIds
        self.padTokenId = added["<pad>"] ?? vocabDict["<pad>"] ?? 0
        self.unkTokenId = vocabDict["<unk>"] ?? 3
    }

    /// Encodes `text` into token ids, always starting with `bosTokenId`.
    ///
    /// Added tokens that appear literally in `text` are not matched up front;
    /// they go through the same character/merge path as any other text.
    public func encode(text: String) -> [Int] {
        var tokens: [Int] = [bosTokenId]
        tokens.append(contentsOf: encodeBPE(text))
        return tokens
    }

    /// Returns the undecoded token string for `id`, preferring added tokens,
    /// or `nil` when the id is unknown.
    public func rawToken(for id: Int) -> String? {
        addedTokensReverse[id] ?? reverseVocab[id]
    }

    /// Decodes token ids back into text.
    ///
    /// BOS, every EOS id and PAD are skipped, as are unknown ids. Byte-fallback
    /// tokens (`<0xXX>`) contribute their raw byte and all other tokens their
    /// UTF-8 bytes; the combined byte stream is then decoded as UTF-8, so a
    /// character spread over several byte tokens is reassembled correctly.
    /// Invalid byte sequences become U+FFFD.
    public func decode(tokens: [Int]) -> String {
        var bytes: [UInt8] = []
        for tokenId in tokens {
            if tokenId == bosTokenId || eosTokenIds.contains(tokenId) || tokenId == padTokenId {
                continue
            }
            guard let token = rawToken(for: tokenId) else { continue }
            if let byte = BPETokenizer.byteValue(ofByteToken: token) {
                bytes.append(byte)
            } else {
                bytes.append(contentsOf: token.utf8)
            }
        }
        return cleanupText(String(decoding: bytes, as: UTF8.self))
    }

    // ─────────────────────────────────────────────────────────────
    // BPE Encoding
    // ─────────────────────────────────────────────────────────────

    /// Pre-tokenizes `text`, runs the merge loop on each piece and maps the
    /// resulting symbols to ids (`unkTokenId` for symbols in neither table).
    private func encodeBPE(_ text: String) -> [Int] {
        let words = preTokenizer?.pretokenize(text) ?? [text]
        var ids: [Int] = []
        for word in words where !word.isEmpty {
            for symbol in applyMerges(wordToTokens(word)) {
                ids.append(vocab[symbol] ?? addedTokens[symbol] ?? unkTokenId)
            }
        }
        return ids
    }

    /// Splits one pre-tokenized piece into the initial symbols for merging.
    ///
    /// - `▁` and printable ASCII (0x21...0x7E except backslash and backtick)
    ///   become one symbol per character.
    /// - Non-ASCII characters are looked up in the vocabulary one Unicode
    ///   scalar at a time; byte-fallback tokens are used only for scalars the
    ///   vocabulary does not contain.
    /// - Everything else becomes `<0xXX>` byte tokens.
    ///
    /// NOTE(review): backslash, backtick, space and control characters are
    /// always emitted as byte tokens, even if the vocabulary has an entry for
    /// them — presumably deliberate; confirm against reference tokenizations.
    private func wordToTokens(_ word: String) -> [String] {
        var symbols: [String] = []
        for char in word {
            if char == "▁" {
                symbols.append("▁")
            } else if let ascii = char.asciiValue {
                if (0x21...0x7E).contains(ascii) && ascii != 0x5C && ascii != 0x60 {
                    symbols.append(String(char))
                } else {
                    symbols.append(contentsOf: byteTokens(for: String(char)))
                }
            } else {
                for scalar in char.unicodeScalars {
                    let piece = String(scalar)
                    if vocab[piece] != nil {
                        symbols.append(piece)
                    } else {
                        symbols.append(contentsOf: byteTokens(for: piece))
                    }
                }
            }
        }
        return symbols
    }

    /// Spells the UTF-8 bytes of `text` as `<0xXX>` byte-fallback tokens.
    private func byteTokens(for text: String) -> [String] {
        text.utf8.map { String(format: "<0x%02X>", $0) }
    }

    /// Repeatedly fuses the adjacent pair with the lowest merge rank (leftmost
    /// on ties) until no adjacent pair is a known merge.
    private func applyMerges(_ symbols: [String]) -> [String] {
        var current = symbols
        while current.count > 1 {
            var bestRank = Int.max
            var bestPosition: Int?
            for position in 0..<(current.count - 1) {
                let pair = MergePair(left: current[position], right: current[position + 1])
                if let rank = mergeRanks[pair], rank < bestRank {
                    bestRank = rank
                    bestPosition = position
                }
            }
            guard let position = bestPosition else { break }
            current[position] += current[position + 1]
            current.remove(at: position + 1)
        }
        return current
    }

    /// Maps `▁` back to a space, collapses runs of spaces into one and trims
    /// leading/trailing whitespace.
    private func cleanupText(_ text: String) -> String {
        let spaced = text.replacingOccurrences(of: "▁", with: " ")
        let collapsed = spaced.replacingOccurrences(of: " +", with: " ", options: .regularExpression)
        return collapsed.trimmingCharacters(in: .whitespaces)
    }

    /// Returns the byte encoded by `token` when it is exactly a byte-fallback
    /// token of the form `<0xXX>` (two hex digits), otherwise `nil`.
    private static func byteValue(ofByteToken token: String) -> UInt8? {
        // Six UTF-8 bytes guarantees six ASCII characters once prefix/suffix match.
        guard token.utf8.count == 6, token.hasPrefix("<0x"), token.hasSuffix(">") else {
            return nil
        }
        let digits = token.dropFirst(3).dropLast()
        // Reject signs such as "+F", which the radix initializer would accept.
        guard digits.allSatisfy({ $0.isHexDigit }) else { return nil }
        return UInt8(digits, radix: 16)
    }
}
// ─────────────────────────────────────────────────────────────
// Streaming Decoder
// ─────────────────────────────────────────────────────────────
/// Incremental decoder that turns a stream of token ids into text.
///
/// The bytes of each token are buffered and text is emitted only once complete
/// UTF-8 sequences are available. This fixes the missing-character problem
/// where a character spread over several byte tokens (for example a 3-byte
/// Chinese character) was emitted as individual Latin-1 garbage characters.
///
/// Malformed bytes never stall the stream: they are emitted as U+FFFD as soon
/// as it is clear that no further input could complete them.
///
/// NOTE(review): token text is passed through unchanged, so a `▁` marker in a
/// token is not mapped back to a space here — presumably the caller handles
/// that; confirm.
public struct StreamingDecoder {
    private let tokenizer: any Tokenizer
    private var byteBuffer: [UInt8] = []

    /// Creates a decoder that resolves token ids through `tokenizer`.
    public init(tokenizer: any Tokenizer) {
        self.tokenizer = tokenizer
    }

    /// Consumes one token and returns any text that became complete.
    ///
    /// Returns "" for BOS, EOS and PAD ids, for ids the tokenizer does not
    /// know, and while the buffer only holds an unfinished UTF-8 sequence.
    public mutating func consume(tokenId: Int) -> String {
        let isSpecial = tokenizer.bosTokenId == tokenId
            || tokenizer.eosTokenIds.contains(tokenId)
            || tokenizer.padTokenId == tokenId
        guard !isSpecial, let raw = tokenizer.rawToken(for: tokenId) else {
            return ""
        }
        // Handles both <0xXX> byte tokens and literal text.
        extractBytes(from: raw, into: &byteBuffer)
        return drainUTF8()
    }

    /// Returns the still-buffered bytes as a lossy UTF-8 string.
    ///
    /// This does not clear the buffer; call `reset()` afterwards if the decoder
    /// is going to be reused.
    public func flush() -> String {
        byteBuffer.isEmpty ? "" : String(decoding: byteBuffer, as: UTF8.self)
    }

    /// Discards all buffered bytes (e.g., when generation completes).
    public mutating func reset() {
        byteBuffer.removeAll()
    }

    // ── Private helpers ──

    /// Emits everything in the buffer except a trailing sequence that more
    /// input could still complete, and removes the emitted bytes.
    private mutating func drainUTF8() -> String {
        guard !byteBuffer.isEmpty else { return "" }
        let readyCount = byteBuffer.count - pendingSuffixLength()
        guard readyCount > 0 else { return "" }
        // Lossy decoding: invalid bytes become U+FFFD instead of blocking output.
        let text = String(decoding: byteBuffer[..<readyCount], as: UTF8.self)
        byteBuffer.removeFirst(readyCount)
        return text
    }

    /// Number of bytes at the end of the buffer that form the start of a
    /// multi-byte UTF-8 sequence whose remaining bytes have not arrived yet.
    /// The buffer must not be empty.
    private func pendingSuffixLength() -> Int {
        let count = byteBuffer.count
        // A sequence is at most 4 bytes long, so an unfinished one occupies at
        // most the last 3 bytes of the buffer.
        for distance in 1...min(3, count) {
            let byte = byteBuffer[count - distance]
            if byte & 0xC0 == 0x80 {
                continue // continuation byte: keep looking for its lead byte
            }
            let expectedLength: Int
            switch byte {
            case 0xC2...0xDF: expectedLength = 2
            case 0xE0...0xEF: expectedLength = 3
            case 0xF0...0xF4: expectedLength = 4
            default: expectedLength = 1 // ASCII, or a byte that can never start a sequence
            }
            return expectedLength > distance ? distance : 0
        }
        // Only stray continuation bytes at the end: nothing can complete them.
        return 0
    }
}
/// Appends the bytes that `text` stands for to `buffer`.
///
/// A well-formed byte token — exactly `<0x`, two hexadecimal digits, `>` —
/// contributes the single byte it names. All other text, including malformed
/// or truncated escapes such as `<0x4` or `<0x41`, is appended as its literal
/// UTF-8 bytes.
///
/// - Parameters:
///   - text: Raw token text.
///   - buffer: Destination that the bytes are appended to.
private func extractBytes(from text: String, into buffer: inout [UInt8]) {
    /// Value of an ASCII hexadecimal digit, or `nil` for any other byte.
    func hexValue(_ unit: UInt8) -> UInt8? {
        switch unit {
        case UInt8(ascii: "0")...UInt8(ascii: "9"): return unit - UInt8(ascii: "0")
        case UInt8(ascii: "a")...UInt8(ascii: "f"): return unit - UInt8(ascii: "a") + 10
        case UInt8(ascii: "A")...UInt8(ascii: "F"): return unit - UInt8(ascii: "A") + 10
        default: return nil
        }
    }

    // Scanning UTF-8 code units is safe: ASCII bytes never occur inside a
    // multi-byte sequence, and literal text needs its UTF-8 bytes anyway.
    let units = Array(text.utf8)
    let escapeLength = 6 // "<0xXX>"
    var position = 0
    while position < units.count {
        if position + escapeLength <= units.count,
           units[position] == UInt8(ascii: "<"),
           units[position + 1] == UInt8(ascii: "0"),
           units[position + 2] == UInt8(ascii: "x"),
           units[position + 5] == UInt8(ascii: ">"),
           let high = hexValue(units[position + 3]),
           let low = hexValue(units[position + 4]) {
            buffer.append(high << 4 | low)
            position += escapeLength
        } else {
            // Literal byte (part of ordinary text or of a malformed escape).
            buffer.append(units[position])
            position += 1
        }
    }
}
// ─────────────────────────────────────────────────────────────
// Pre-Tokenization
// ─────────────────────────────────────────────────────────────
/// Splits text into the pieces that BPE merging is applied to independently.
///
/// The strategy is chosen by `type`: `"Split"` and `"ByteLevel"` have dedicated
/// handling, any other value passes the text through as a single piece.
/// `behavior` is stored but not consulted by any strategy in this class.
final class PreTokenizer {
    let type: String
    let behavior: String?

    /// - Parameters:
    ///   - type: Value of the `type` field of the `pre_tokenizer` JSON entry.
    ///   - behavior: Value of its `behavior` field, if any.
    init(type: String, behavior: String? = nil) {
        self.type = type
        self.behavior = behavior
    }

    /// Returns the pieces of `text` according to `type`.
    func pretokenize(_ text: String) -> [String] {
        switch type {
        case "Split":
            return splitPreTokenize(text)
        case "ByteLevel":
            return byteLevelPreTokenize(text)
        default:
            return [text]
        }
    }

    /// Splits on spaces, turning every space into a `▁` that is attached to
    /// the front of the word that follows it.
    ///
    /// No character is lost: joining the returned pieces yields `text` with
    /// each space replaced by `▁`. Consequently a word is prefixed exactly
    /// when a space precedes it (the first word of "hello world" is not),
    /// repeated spaces yield repeated `▁`, and trailing spaces form a final
    /// piece of their own.
    private func splitPreTokenize(_ text: String) -> [String] {
        var pieces: [String] = []
        var piece = ""
        var pieceHasWord = false // true once `piece` holds a non-space character
        for character in text {
            if character == " " {
                // A space after word characters closes the current piece and
                // opens the next one.
                if pieceHasWord {
                    pieces.append(piece)
                    piece = ""
                    pieceHasWord = false
                }
                piece.append("▁")
            } else {
                piece.append(character)
                pieceHasWord = true
            }
        }
        if !piece.isEmpty {
            pieces.append(piece)
        }
        return pieces
    }

    /// Replaces every space with `▁` and returns the whole text as one piece.
    ///
    /// NOTE(review): this is not the GPT-2 style byte-to-unicode mapping that
    /// the name "ByteLevel" suggests — confirm it matches the models in use.
    private func byteLevelPreTokenize(_ text: String) -> [String] {
        [text.replacingOccurrences(of: " ", with: "▁")]
    }
}