8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
129 lines
4.7 KiB
Swift
129 lines
4.7 KiB
Swift
import Metal
|
|
|
|
/// Static dimensions of the audio embedding projection.
public struct AudioConfig12B {
    /// Number of output features produced per sequence position.
    public let outputDim: Int
    /// Number of input audio features consumed per sequence position.
    public let audioDim: Int
    /// Quantization group size.
    /// NOTE(review): presumably the number of input values sharing one scale/bias pair — confirm against the kernel.
    public let groupSize: Int

    /// Creates a configuration; every argument has a default.
    /// - Parameters:
    ///   - outputDim: Output feature count (default 3840).
    ///   - audioDim: Input feature count (default 640).
    ///   - groupSize: Quantization group size (default 64).
    public init(outputDim: Int = 3840, audioDim: Int = 640, groupSize: Int = 64) {
        self.groupSize = groupSize
        self.audioDim = audioDim
        self.outputDim = outputDim
    }
}
|
|
|
|
/// Metal buffers holding the quantized audio projection tensors.
public struct AudioWeights12B {
    /// Raised when a Metal buffer cannot be created for one of the tensors.
    public enum AllocationError: Error {
        /// `name` identifies the tensor; `byteCount` is the size that was requested.
        case bufferCreationFailed(name: String, byteCount: Int)
    }

    /// Packed quantized weights, uploaded from `[UInt32]`.
    public let projectionWeight: MTLBuffer
    /// Quantization scales, uploaded from `[Float]`.
    public let projectionScales: MTLBuffer
    /// Quantization biases, uploaded from `[Float]`.
    public let projectionBiases: MTLBuffer
    /// Group count supplied by the caller; stored as-is.
    public let numGroups: Int
    /// `true` exactly when `outputBias` is non-nil.
    public let hasOutputBias: Bool
    /// Optional additive output bias buffer.
    public let outputBias: MTLBuffer?

    /// Uploads the projection tensors to `device`.
    ///
    /// - Parameters:
    ///   - device: Metal device that allocates the buffers.
    ///   - weightData: Packed quantized weights.
    ///   - scalesData: Quantization scales.
    ///   - biasesData: Quantization biases.
    ///   - numGroups: Group count, stored unchanged.
    ///   - outputBias: Optional output bias values; `nil` means no bias.
    /// - Throws: `AllocationError.bufferCreationFailed` when a tensor is empty or the
    ///   device cannot allocate its buffer (previously this crashed on a force-unwrap,
    ///   or silently produced `hasOutputBias == true` with a nil bias buffer).
    public init(device: MTLDevice,
                weightData: [UInt32],
                scalesData: [Float],
                biasesData: [Float],
                numGroups: Int,
                outputBias: [Float]? = nil) throws {
        projectionWeight = try Self.makeBuffer(on: device, from: weightData, name: "projectionWeight")
        projectionScales = try Self.makeBuffer(on: device, from: scalesData, name: "projectionScales")
        projectionBiases = try Self.makeBuffer(on: device, from: biasesData, name: "projectionBiases")
        self.numGroups = numGroups

        // Derive the flag from the buffer itself so the two can never disagree.
        let biasBuffer = try outputBias.map {
            try Self.makeBuffer(on: device, from: $0, name: "outputBias")
        }
        self.outputBias = biasBuffer
        self.hasOutputBias = biasBuffer != nil
    }

    /// Copies `values` into a new Metal buffer, throwing instead of returning nil.
    private static func makeBuffer<Element>(on device: MTLDevice,
                                            from values: [Element],
                                            name: String) throws -> MTLBuffer {
        let byteCount = values.count * MemoryLayout<Element>.stride
        let buffer: MTLBuffer? = values.withUnsafeBytes { raw in
            // An empty array has no base address; treat it as an allocation failure.
            guard let base = raw.baseAddress, raw.count > 0 else { return nil }
            return device.makeBuffer(bytes: base, length: raw.count)
        }
        guard let buffer = buffer else {
            throw AllocationError.bufferCreationFailed(name: name, byteCount: byteCount)
        }
        return buffer
    }
}
|
|
|
|
/// Projects audio features to the output embedding width with one quantized matmul kernel.
public final class AudioTower12B {
    /// Failures raised while running or loading the projection.
    public enum TowerError: Error {
        /// The engine's command queue returned no command buffer.
        case commandBufferUnavailable
        /// The command buffer returned no compute encoder.
        case commandEncoderUnavailable
        /// The command buffer finished with an error.
        case executionFailed(underlying: Error)
        /// The packed weight tensor is empty or not a whole number of rows.
        case invalidProjectionShape(weightCount: Int, packedInDim: Int)
    }

    public let config: AudioConfig12B
    public let weights: AudioWeights12B
    public let engine: MarkBaseEngine

    /// Input feature count, taken from `config.audioDim`.
    public let inDim: Int
    /// Output feature count, taken from `config.outputDim`.
    public let outDim: Int

    /// Creates a tower from an already-loaded configuration and weight set.
    public init(config: AudioConfig12B, engine: MarkBaseEngine, weights: AudioWeights12B) {
        self.config = config
        self.weights = weights
        self.engine = engine

        self.inDim = config.audioDim
        self.outDim = config.outputDim
    }

    /// Encodes and runs the `quantized_matmul_seq` kernel, blocking until the GPU finishes.
    ///
    /// The dispatch grid is `outDim` x `seqLen`, i.e. one thread per
    /// (output feature, sequence position) pair.
    ///
    /// - Parameters:
    ///   - inputBuffer: Bound at kernel buffer index 0.
    ///   - seqLen: Number of sequence positions; `0` returns immediately without GPU work.
    ///   - outputBuffer: Bound at kernel buffer index 5.
    /// - Throws: Whatever `engine.pipeline(named:)` throws, or a `TowerError` when the
    ///   command buffer/encoder cannot be created or the GPU reports an execution error.
    public func forward(inputBuffer: MTLBuffer, seqLen: Int, outputBuffer: MTLBuffer) throws {
        // No rows to project, so there is nothing to encode or dispatch.
        guard seqLen != 0 else { return }

        // Resolve the pipeline first so a lookup failure does not leave an
        // empty command buffer to commit.
        let pso = try engine.pipeline(named: "quantized_matmul_seq")
        guard let cmdBuf = engine.commandQueue.makeCommandBuffer() else {
            throw TowerError.commandBufferUnavailable
        }
        guard let enc = cmdBuf.makeComputeCommandEncoder() else {
            throw TowerError.commandEncoderUnavailable
        }
        enc.setComputePipelineState(pso)

        enc.setBuffer(inputBuffer, offset: 0, index: 0)
        enc.setBuffer(weights.projectionWeight, offset: 0, index: 1)
        enc.setBuffer(weights.projectionScales, offset: 0, index: 2)
        enc.setBuffer(weights.projectionBiases, offset: 0, index: 3)

        // Index 4 always gets a buffer; without an output bias the quantization
        // biases are bound as a placeholder.
        // NOTE(review): presumably the kernel ignores index 4 when the flag at index 8 is false — confirm.
        let biasBuffer = weights.outputBias
        enc.setBuffer(biasBuffer ?? weights.projectionBiases, offset: 0, index: 4)

        enc.setBuffer(outputBuffer, offset: 0, index: 5)

        var inD = UInt32(inDim)
        enc.setBytes(&inD, length: MemoryLayout<UInt32>.size, index: 6)
        var outD = UInt32(outDim)
        enc.setBytes(&outD, length: MemoryLayout<UInt32>.size, index: 7)
        // The flag mirrors which buffer was actually bound at index 4.
        var hasBias = biasBuffer != nil
        enc.setBytes(&hasBias, length: MemoryLayout<Bool>.size, index: 8)
        var sl = UInt32(seqLen)
        enc.setBytes(&sl, length: MemoryLayout<UInt32>.size, index: 9)

        let grid = MTLSize(width: outDim, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (outDim, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
        // Surface GPU-side failures instead of returning as if the output were valid.
        if let failure = cmdBuf.error {
            throw TowerError.executionFailed(underlying: failure)
        }
    }

    /// Loads the audio projection tensors from the second safetensors shard in `modelDir`.
    ///
    /// The input width is fixed at 640. The output width is derived from the packed
    /// weight count divided by the packed row length (`640 / 8` words).
    ///
    /// - Parameters:
    ///   - modelDir: Directory containing `model-00002-of-00002.safetensors`.
    ///   - engine: Engine whose Metal device receives the weight buffers.
    /// - Returns: A tower ready for `forward`.
    /// - Throws: Reader or buffer-creation errors, or `TowerError.invalidProjectionShape`
    ///   when the weight tensor is empty or not a whole number of packed rows.
    public static func load(modelDir: String, engine: MarkBaseEngine) throws -> AudioTower12B {
        let device = engine.device
        let shardFile = "model-00002-of-00002.safetensors"
        let reader = try SafeTensorsReader(path: "\(modelDir)/\(shardFile)")

        let weightData = try reader.readUint32(named: "embed_audio.embedding_projection.weight")
        let scalesData = SafeTensorsReader.bf16ToFloat32(
            try reader.read(named: "embed_audio.embedding_projection.scales"))
        let biasesData = SafeTensorsReader.bf16ToFloat32(
            try reader.read(named: "embed_audio.embedding_projection.biases"))

        let audioDim = 640
        let groupSize = 64
        // Eight values per UInt32 word (4-bit packing), so a row is audioDim / 8 words.
        let valuesPerWord = 8
        let packedInDim = audioDim / valuesPerWord

        // A truncating division here would silently produce a wrong output width.
        guard !weightData.isEmpty, weightData.count % packedInDim == 0 else {
            throw TowerError.invalidProjectionShape(weightCount: weightData.count,
                                                    packedInDim: packedInDim)
        }
        let outDim = weightData.count / packedInDim
        // Same value as packedInDim / 8: one group per `groupSize` input values.
        let numGroups = audioDim / groupSize

        let weights = try AudioWeights12B(
            device: device,
            weightData: weightData,
            scalesData: scalesData,
            biasesData: biasesData,
            numGroups: numGroups,
            outputBias: nil
        )

        let config = AudioConfig12B(outputDim: outDim, audioDim: audioDim, groupSize: groupSize)
        print(" AudioTower12B: inDim=\(audioDim), outDim=\(outDim), numGroups=\(numGroups)")
        return AudioTower12B(config: config, engine: engine, weights: weights)
    }
}