8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
610 lines
30 KiB
Swift
610 lines
30 KiB
Swift
import Metal
|
|
|
|
// E2B audio tower uses bfloat16 weights (not quantized)
|
|
// Linear weights are dense bfloat16 tensors in the checkpoint (not uint32-packed); they are widened to float32 before upload
|
|
|
|
/// Weights of a single E2B audio-tower layer, each held in its own Metal buffer.
///
/// Every buffer is created from a `[Float]` array, so the GPU-side element
/// stride is `MemoryLayout<Float>.stride`.
public struct AudioLayerWeightsE2B {
    // Layer-level norms
    public let normPreAttn: MTLBuffer
    public let normPostAttn: MTLBuffer
    public let normOut: MTLBuffer

    // Self-attention
    public let selfAttnQProjWeight: MTLBuffer
    public let selfAttnKProjWeight: MTLBuffer
    public let selfAttnVProjWeight: MTLBuffer
    public let selfAttnPostWeight: MTLBuffer
    public let selfAttnRelativeKProj: MTLBuffer
    public let selfAttnPerDimScale: MTLBuffer

    // Light convolution (LConv1D)
    public let lconv1dPreLayerNorm: MTLBuffer
    public let lconv1dConvNorm: MTLBuffer
    public let lconv1dDepthwiseConv: MTLBuffer
    public let lconv1dLinearStartWeight: MTLBuffer
    public let lconv1dLinearEndWeight: MTLBuffer

    // Feed-forward block 1
    public let feedForward1Layer1Weight: MTLBuffer
    public let feedForward1Layer2Weight: MTLBuffer
    public let feedForward1PreLayerNorm: MTLBuffer
    public let feedForward1PostLayerNorm: MTLBuffer

    // Feed-forward block 2
    public let feedForward2Layer1Weight: MTLBuffer
    public let feedForward2Layer2Weight: MTLBuffer
    public let feedForward2PreLayerNorm: MTLBuffer
    public let feedForward2PostLayerNorm: MTLBuffer

    /// Copies the tensor stored under `key` into a freshly created Metal buffer.
    ///
    /// - Throws: `WeightError.tensorNotFound` when `key` is absent from `floats`
    ///   or when the device cannot create the buffer.
    private static func buffer(_ device: MTLDevice, _ floats: [String: [Float]],
                               _ key: String) throws -> MTLBuffer {
        guard let values = floats[key] else {
            throw WeightError.tensorNotFound(key)
        }
        let byteCount = values.count * MemoryLayout<Float>.stride
        guard let uploaded = device.makeBuffer(bytes: values, length: byteCount) else {
            throw WeightError.tensorNotFound("Failed to create buffer for \(key)")
        }
        return uploaded
    }

    /// Loads all tensors of layer `layerIdx` from `floats`.
    ///
    /// - Parameters:
    ///   - device: Device that creates the buffers.
    ///   - layerIdx: Index interpolated into the `audio_tower.layers.<idx>.` key prefix.
    ///   - floats: Tensor name to float values.
    /// - Throws: `WeightError.tensorNotFound` for the first tensor that is missing
    ///   or cannot be uploaded.
    public init(device: MTLDevice, layerIdx: Int,
                floats: [String: [Float]]) throws {
        let layerPrefix = "audio_tower.layers.\(layerIdx)."
        let load: (String) throws -> MTLBuffer = { suffix in
            try AudioLayerWeightsE2B.buffer(device, floats, layerPrefix + suffix)
        }

        normPreAttn = try load("norm_pre_attn.weight")
        normPostAttn = try load("norm_post_attn.weight")
        normOut = try load("norm_out.weight")

        // Attention projections use the `.linear.weight` suffix in E2B checkpoints.
        selfAttnQProjWeight = try load("self_attn.q_proj.linear.weight")
        selfAttnKProjWeight = try load("self_attn.k_proj.linear.weight")
        selfAttnVProjWeight = try load("self_attn.v_proj.linear.weight")
        selfAttnPostWeight = try load("self_attn.post.linear.weight")
        selfAttnRelativeKProj = try load("self_attn.relative_k_proj.weight")
        selfAttnPerDimScale = try load("self_attn.per_dim_scale")

        lconv1dPreLayerNorm = try load("lconv1d.pre_layer_norm.weight")
        lconv1dConvNorm = try load("lconv1d.conv_norm.weight")
        lconv1dDepthwiseConv = try load("lconv1d.depthwise_conv1d.weight")
        lconv1dLinearStartWeight = try load("lconv1d.linear_start.linear.weight")
        lconv1dLinearEndWeight = try load("lconv1d.linear_end.linear.weight")

        feedForward1Layer1Weight = try load("feed_forward1.ffw_layer_1.linear.weight")
        feedForward1Layer2Weight = try load("feed_forward1.ffw_layer_2.linear.weight")
        feedForward1PreLayerNorm = try load("feed_forward1.pre_layer_norm.weight")
        feedForward1PostLayerNorm = try load("feed_forward1.post_layer_norm.weight")

        feedForward2Layer1Weight = try load("feed_forward2.ffw_layer_1.linear.weight")
        feedForward2Layer2Weight = try load("feed_forward2.ffw_layer_2.linear.weight")
        feedForward2PreLayerNorm = try load("feed_forward2.pre_layer_norm.weight")
        feedForward2PostLayerNorm = try load("feed_forward2.post_layer_norm.weight")
    }
}
|
|
|
|
/// All weights of the E2B audio tower: the subsample-conv front end, the
/// output projection, and one `AudioLayerWeightsE2B` per hidden layer.
public struct AudioWeightsE2B {
    public let subsampleConvLayer0: SubsampleConvLayer
    public let subsampleConvLayer1: SubsampleConvLayer
    public let inputProjLinearWeight: MTLBuffer

    public let outputProjWeight: MTLBuffer
    public let outputProjBias: MTLBuffer

    public let layers: [AudioLayerWeightsE2B]

    /// Uploads every tower-level tensor, then the layers `0..<config.numHiddenLayers`.
    ///
    /// - Parameters:
    ///   - device: Device that creates the buffers.
    ///   - config: Supplies `numHiddenLayers`.
    ///   - floats: Tensor name to float values.
    /// - Throws: `WeightError.tensorNotFound` for the first tensor that is missing
    ///   or cannot be uploaded.
    public init(device: MTLDevice, config: AudioConfig,
                floats: [String: [Float]]) throws {
        let towerPrefix = "audio_tower."
        let subsamplePrefix = towerPrefix + "subsample_conv_projection."

        let conv0 = try Self.buffer(device, floats, subsamplePrefix + "layer0.conv.weight")
        let norm0 = try Self.buffer(device, floats, subsamplePrefix + "layer0.norm.weight")
        subsampleConvLayer0 = SubsampleConvLayer(convWeight: conv0, normWeight: norm0)

        let conv1 = try Self.buffer(device, floats, subsamplePrefix + "layer1.conv.weight")
        let norm1 = try Self.buffer(device, floats, subsamplePrefix + "layer1.norm.weight")
        subsampleConvLayer1 = SubsampleConvLayer(convWeight: conv1, normWeight: norm1)

        inputProjLinearWeight = try Self.buffer(device, floats, subsamplePrefix + "input_proj_linear.weight")

        outputProjWeight = try Self.buffer(device, floats, towerPrefix + "output_proj.weight")
        outputProjBias = try Self.buffer(device, floats, towerPrefix + "output_proj.bias")

        layers = try (0..<config.numHiddenLayers).map { index in
            try AudioLayerWeightsE2B(device: device, layerIdx: index, floats: floats)
        }
    }

    /// Copies the tensor stored under `key` into a freshly created Metal buffer.
    ///
    /// - Throws: `WeightError.tensorNotFound` when `key` is absent from `floats`
    ///   or when the device cannot create the buffer.
    private static func buffer(_ device: MTLDevice, _ floats: [String: [Float]],
                               _ key: String) throws -> MTLBuffer {
        guard let values = floats[key] else {
            throw WeightError.tensorNotFound(key)
        }
        let byteCount = values.count * MemoryLayout<Float>.stride
        guard let uploaded = device.makeBuffer(bytes: values, length: byteCount) else {
            throw WeightError.tensorNotFound("Failed to create buffer for \(key)")
        }
        return uploaded
    }
}
|
|
|
|
// E2B AudioTower - uses float32 weights (bfloat16 converted to float32)
|
|
/// Encodes and runs the E2B audio tower on the GPU.
///
/// The forward pass is: subsample conv -> input projection -> audio layers ->
/// output projection, all encoded into one command buffer.
///
/// Every kernel helper takes an explicit `output` buffer. No dispatch reads and
/// writes the same buffer, and the q/k/v projections each get their own storage.
/// The residual stream ping-pongs between `streamA` and `streamB`.
public final class AudioTowerE2B {
    /// Failures raised while allocating scratch memory or encoding a forward pass.
    public enum RuntimeError: Error, CustomStringConvertible {
        case bufferAllocationFailed(label: String, bytes: Int)
        case commandBufferUnavailable
        case encoderUnavailable(kernel: String)
        case invalidSequenceLength(Int, maximum: Int)

        public var description: String {
            switch self {
            case .bufferAllocationFailed(let label, let bytes):
                return "Failed to allocate \(bytes)-byte scratch buffer '\(label)'"
            case .commandBufferUnavailable:
                return "Failed to create a Metal command buffer"
            case .encoderUnavailable(let kernel):
                return "Failed to create a compute encoder for kernel '\(kernel)'"
            case .invalidSequenceLength(let length, let maximum):
                return "Sequence length \(length) is outside 1...\(maximum)"
            }
        }
    }

    public let config: AudioConfig
    public let engine: MarkBaseEngine
    public let weights: AudioWeightsE2B

    /// Largest number of input frames the scratch buffers are sized for.
    private static let maxSeqLen = 4096
    /// Width used for both sides of the input projection.
    /// NOTE(review): hard-coded as in the original; confirm it matches `config.hiddenSize`.
    private static let inputProjDim = 1024
    /// Inner width of the feed-forward blocks.
    /// NOTE(review): hard-coded as in the original; confirm against the checkpoint.
    private static let ffnDim = 4096

    // Residual stream (ping-pong pair).
    private let streamA: MTLBuffer
    private let streamB: MTLBuffer
    // Hidden-width scratch.
    private let normBuffer: MTLBuffer
    private let qBuffer: MTLBuffer
    private let kBuffer: MTLBuffer
    private let vBuffer: MTLBuffer
    private let attnOutBuffer: MTLBuffer
    private let branchBuffer: MTLBuffer
    // Wide scratch (feed-forward inner width / 2x hidden for LConv1D / conv front end).
    private let ffnBuffer: MTLBuffer
    private let tempBuffer: MTLBuffer
    private let subsampleBuf: MTLBuffer

    /// Allocates all scratch buffers for up to `maxSeqLen` frames.
    ///
    /// - Throws: `RuntimeError.bufferAllocationFailed` when the device cannot
    ///   provide a scratch buffer.
    public init(config: AudioConfig, engine: MarkBaseEngine, weights: AudioWeightsE2B) throws {
        self.config = config
        self.engine = engine
        self.weights = weights

        let device = engine.device
        let bytesPerFloat = MemoryLayout<Float>.stride
        let frames = AudioTowerE2B.maxSeqLen
        let hiddenSize = config.hiddenSize

        let hiddenBytes = max(hiddenSize, AudioTowerE2B.inputProjDim) * frames * bytesPerFloat
        // Wide buffers must hold both the feed-forward inner activations and the
        // 2x-hidden LConv1D activations.
        let wideBytes = max(hiddenSize * 2, AudioTowerE2B.ffnDim) * frames * bytesPerFloat
        let subsampleBytes = max(hiddenSize, 128 * 64) * frames * bytesPerFloat

        streamA = try AudioTowerE2B.makeScratch(device, "audio.streamA", hiddenBytes)
        streamB = try AudioTowerE2B.makeScratch(device, "audio.streamB", hiddenBytes)
        normBuffer = try AudioTowerE2B.makeScratch(device, "audio.norm", hiddenBytes)
        qBuffer = try AudioTowerE2B.makeScratch(device, "audio.q", hiddenBytes)
        kBuffer = try AudioTowerE2B.makeScratch(device, "audio.k", hiddenBytes)
        vBuffer = try AudioTowerE2B.makeScratch(device, "audio.v", hiddenBytes)
        attnOutBuffer = try AudioTowerE2B.makeScratch(device, "audio.attnOut", hiddenBytes)
        branchBuffer = try AudioTowerE2B.makeScratch(device, "audio.branch", hiddenBytes)
        ffnBuffer = try AudioTowerE2B.makeScratch(device, "audio.ffn", wideBytes)
        tempBuffer = try AudioTowerE2B.makeScratch(device, "audio.temp", wideBytes)
        subsampleBuf = try AudioTowerE2B.makeScratch(device, "audio.subsample", subsampleBytes)
    }

    /// Creates one labelled scratch buffer or throws instead of crashing.
    private static func makeScratch(_ device: MTLDevice, _ label: String, _ bytes: Int) throws -> MTLBuffer {
        guard let buffer = device.makeBuffer(length: bytes) else {
            throw RuntimeError.bufferAllocationFailed(label: label, bytes: bytes)
        }
        buffer.label = label
        return buffer
    }

    /// Runs the full tower and blocks until the GPU has finished.
    ///
    /// - Parameters:
    ///   - inputBuffer: Mel features; handed to the `transpose_2d` kernel with
    ///     128 rows and `seqLen` columns.
    ///   - seqLen: Number of input frames; must be in `1...4096`.
    ///   - outputBuffer: Destination of the output projection.
    /// - Throws: `RuntimeError` for invalid lengths or encoding failures, any error
    ///   from `engine.pipeline(named:)`, or the command buffer's own error.
    public func forward(inputBuffer: MTLBuffer, seqLen: Int, outputBuffer: MTLBuffer) throws {
        guard seqLen > 0, seqLen <= Self.maxSeqLen else {
            throw RuntimeError.invalidSequenceLength(seqLen, maximum: Self.maxSeqLen)
        }
        guard let cmdBuf = engine.commandQueue.makeCommandBuffer() else {
            throw RuntimeError.commandBufferUnavailable
        }

        // 1. Subsample conv
        let (projInput, projLen) = try applySubsampleConv(
            melInput: inputBuffer, nMels: 128, seqLen: seqLen, cmdBuf: cmdBuf
        )

        // 2. Input projection (written to the residual stream, not back into its own input)
        try applyFloatLinear(input: projInput, weight: weights.inputProjLinearWeight, bias: nil,
                             output: streamA, seqLen: projLen,
                             inDim: Self.inputProjDim, outDim: Self.inputProjDim, cmdBuf: cmdBuf)

        // 3. Audio layers; each layer leaves its result in `streamA` again.
        for layerWeights in weights.layers {
            try applyLayer(stream: streamA, spare: streamB, weights: layerWeights,
                           seqLen: projLen, cmdBuf: cmdBuf)
        }

        // 4. Output projection
        try applyFloatLinear(input: streamA, weight: weights.outputProjWeight,
                             bias: weights.outputProjBias, output: outputBuffer,
                             seqLen: projLen, inDim: config.hiddenSize,
                             outDim: config.outputProjDims, cmdBuf: cmdBuf)

        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
        if let gpuError = cmdBuf.error {
            throw gpuError
        }
    }

    // MARK: - Encoding helpers

    /// Resolves `kernel`, opens a compute encoder on `cmdBuf` and binds the pipeline.
    private func makeEncoder(_ kernel: String, on cmdBuf: MTLCommandBuffer) throws
        -> (MTLComputeCommandEncoder, MTLComputePipelineState) {
        let pso = try engine.pipeline(named: kernel)
        guard let enc = cmdBuf.makeComputeCommandEncoder() else {
            throw RuntimeError.encoderUnavailable(kernel: kernel)
        }
        enc.setComputePipelineState(pso)
        return (enc, pso)
    }

    /// Binds `value` as a 4-byte unsigned constant at `index`.
    private func setUInt32(_ value: Int, on enc: MTLComputeCommandEncoder, index: Int) {
        var constant = UInt32(value)
        enc.setBytes(&constant, length: 4, index: index)
    }

    // MARK: - Subsample front end

    /// Transposes the mel input, applies the two stride-2 conv layers and flattens
    /// the result. Returns the flattened buffer and the reduced frame count.
    private func applySubsampleConv(melInput: MTLBuffer, nMels: Int, seqLen: Int,
                                    cmdBuf: MTLCommandBuffer) throws -> (MTLBuffer, Int) {
        try transposeMelToCHW(input: melInput, output: subsampleBuf,
                              nMels: nMels, seqLen: seqLen, cmdBuf: cmdBuf)

        try applyConv2DLayer(input: subsampleBuf, inCh: 1, height: nMels, width: seqLen,
                             convWeight: weights.subsampleConvLayer0.convWeight,
                             normWeight: weights.subsampleConvLayer0.normWeight,
                             outChannels: 128, outputBuffer: tempBuffer, cmdBuf: cmdBuf)
        let h1 = (nMels + 1) / 2
        let w1 = (seqLen + 1) / 2

        try applyConv2DLayer(input: tempBuffer, inCh: 128, height: h1, width: w1,
                             convWeight: weights.subsampleConvLayer1.convWeight,
                             normWeight: weights.subsampleConvLayer1.normWeight,
                             outChannels: 32, outputBuffer: subsampleBuf, cmdBuf: cmdBuf)
        let h2 = (h1 + 1) / 2
        let w2 = (w1 + 1) / 2

        try flattenCHW(input: subsampleBuf, C: 32, H: h2, W: w2,
                       outputBuffer: tempBuffer, cmdBuf: cmdBuf)
        return (tempBuffer, w2)
    }

    private func transposeMelToCHW(input: MTLBuffer, output: MTLBuffer, nMels: Int, seqLen: Int,
                                   cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("transpose_2d", on: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(output, offset: 0, index: 1)
        setUInt32(nMels, on: enc, index: 2)
        setUInt32(seqLen, on: enc, index: 3)
        let grid = MTLSize(width: seqLen, height: nMels, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (seqLen, nMels))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    private func applyConv2DLayer(input: MTLBuffer, inCh: Int, height: Int, width: Int,
                                  convWeight: MTLBuffer, normWeight: MTLBuffer,
                                  outChannels: Int, outputBuffer: MTLBuffer,
                                  cmdBuf: MTLCommandBuffer) throws {
        let (enc, _) = try makeEncoder("audio_subsample_conv_2d", on: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(convWeight, offset: 0, index: 1)
        enc.setBuffer(normWeight, offset: 0, index: 2)
        enc.setBuffer(outputBuffer, offset: 0, index: 3)
        setUInt32(inCh, on: enc, index: 4)
        setUInt32(outChannels, on: enc, index: 5)
        setUInt32(height, on: enc, index: 6)
        setUInt32(width, on: enc, index: 7)
        // Output extent is halved (rounded up) along both spatial axes.
        let outH = (height + 1) / 2
        let outW = (width + 1) / 2
        let grid = MTLSize(width: outChannels, height: outH, depth: outW)
        let tg = MTLSize(width: 8, height: 8, depth: 4)
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    private func flattenCHW(input: MTLBuffer, C: Int, H: Int, W: Int,
                            outputBuffer: MTLBuffer, cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("audio_flatten_chw", on: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(outputBuffer, offset: 0, index: 1)
        setUInt32(C, on: enc, index: 2)
        setUInt32(H, on: enc, index: 3)
        setUInt32(W, on: enc, index: 4)
        let grid = MTLSize(width: C * H, height: W, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (C * H, W))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    // MARK: - Kernels

    /// Encodes `audio_linear_seq`. Pass `bias: nil` for a bias-free projection.
    /// `output` must not be the same buffer as `input`.
    private func applyFloatLinear(input: MTLBuffer, weight: MTLBuffer, bias: MTLBuffer?,
                                  output: MTLBuffer, seqLen: Int, inDim: Int, outDim: Int,
                                  cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("audio_linear_seq", on: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(bias, offset: 0, index: 2)
        enc.setBuffer(output, offset: 0, index: 3)
        setUInt32(inDim, on: enc, index: 4)
        setUInt32(outDim, on: enc, index: 5)
        var hasBias = bias != nil
        enc.setBytes(&hasBias, length: 1, index: 6)
        setUInt32(seqLen, on: enc, index: 7)
        let grid = MTLSize(width: outDim, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (outDim, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `rms_norm_seq`. `output` must not be the same buffer as `input`.
    private func applyRMSNorm(input: MTLBuffer, weight: MTLBuffer, output: MTLBuffer,
                              seqLen: Int, hiddenSize: Int, cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("rms_norm_seq", on: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)
        setUInt32(hiddenSize, on: enc, index: 3)
        var eps = config.rmsNormEps
        enc.setBytes(&eps, length: 4, index: 4)
        setUInt32(seqLen, on: enc, index: 5)
        let grid = MTLSize(width: hiddenSize, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (hiddenSize, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `audio_attention_full` over distinct q/k/v buffers.
    private func applyAudioAttention(q: MTLBuffer, k: MTLBuffer, v: MTLBuffer,
                                     relativeKProj: MTLBuffer, perDimScale: MTLBuffer,
                                     output: MTLBuffer, seqLen: Int,
                                     cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("audio_attention_full", on: cmdBuf)
        enc.setBuffer(q, offset: 0, index: 0)
        enc.setBuffer(k, offset: 0, index: 1)
        enc.setBuffer(v, offset: 0, index: 2)
        enc.setBuffer(relativeKProj, offset: 0, index: 3)
        enc.setBuffer(perDimScale, offset: 0, index: 4)
        enc.setBuffer(output, offset: 0, index: 5)
        setUInt32(seqLen, on: enc, index: 6)
        setUInt32(config.numAttentionHeads, on: enc, index: 7)
        setUInt32(config.headDim, on: enc, index: 8)
        var contextLeft = UInt32(config.attentionContextLeft)
        enc.setBytes(&contextLeft, length: 4, index: 9)
        var logitCap = config.attentionLogitCap
        enc.setBytes(&logitCap, length: 4, index: 10)
        let width = config.numAttentionHeads * config.headDim
        let grid = MTLSize(width: width, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (width, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `audio_depthwise_conv1d`. `output` must not be the same buffer as `input`.
    private func applyDepthwiseConv1D(input: MTLBuffer, weight: MTLBuffer, norm: MTLBuffer,
                                      output: MTLBuffer, seqLen: Int, channels: Int,
                                      kernelSize: Int, cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("audio_depthwise_conv1d", on: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(norm, offset: 0, index: 2)
        enc.setBuffer(output, offset: 0, index: 3)
        setUInt32(channels, on: enc, index: 4)
        setUInt32(kernelSize, on: enc, index: 5)
        setUInt32(seqLen, on: enc, index: 6)
        let grid = MTLSize(width: channels, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (channels, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `silu` over `count` elements.
    private func applySiLU(input: MTLBuffer, output: MTLBuffer, count: Int,
                           cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("silu", on: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(output, offset: 0, index: 1)
        setUInt32(count, on: enc, index: 2)
        let grid = MTLSize(width: count, height: 1, depth: 1)
        let tg = engine.threadgroupSize1D(pso, count: count)
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `residual_add` of `add` onto `input`, with `config.residualWeight`
    /// bound as the kernel's weight constant.
    private func applyResidualAdd(input: MTLBuffer, add: MTLBuffer, output: MTLBuffer,
                                  seqLen: Int, hiddenSize: Int,
                                  cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("residual_add", on: cmdBuf)
        let count = seqLen * hiddenSize
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(add, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)
        setUInt32(count, on: enc, index: 3)
        var weight = config.residualWeight
        enc.setBytes(&weight, length: 4, index: 4)
        let grid = MTLSize(width: count, height: 1, depth: 1)
        let tg = engine.threadgroupSize1D(pso, count: count)
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    // MARK: - Layer composition

    /// Encodes one audio layer.
    ///
    /// On entry `stream` holds the layer input; on exit it holds the layer output.
    /// `spare` is the second half of the ping-pong pair and is overwritten.
    private func applyLayer(stream: MTLBuffer, spare: MTLBuffer, weights: AudioLayerWeightsE2B,
                            seqLen: Int, cmdBuf: MTLCommandBuffer) throws {
        let hidden = config.hiddenSize

        // 1. Norm pre-attn: stream -> normBuffer
        try applyRMSNorm(input: stream, weight: weights.normPreAttn, output: normBuffer,
                         seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)

        // 2. Self-attention: q, k and v each land in their own buffer.
        try applyFloatLinear(input: normBuffer, weight: weights.selfAttnQProjWeight, bias: nil,
                             output: qBuffer, seqLen: seqLen, inDim: hidden, outDim: hidden,
                             cmdBuf: cmdBuf)
        try applyFloatLinear(input: normBuffer, weight: weights.selfAttnKProjWeight, bias: nil,
                             output: kBuffer, seqLen: seqLen, inDim: hidden, outDim: hidden,
                             cmdBuf: cmdBuf)
        try applyFloatLinear(input: normBuffer, weight: weights.selfAttnVProjWeight, bias: nil,
                             output: vBuffer, seqLen: seqLen, inDim: hidden, outDim: hidden,
                             cmdBuf: cmdBuf)
        try applyAudioAttention(q: qBuffer, k: kBuffer, v: vBuffer,
                                relativeKProj: weights.selfAttnRelativeKProj,
                                perDimScale: weights.selfAttnPerDimScale,
                                output: attnOutBuffer, seqLen: seqLen, cmdBuf: cmdBuf)
        try applyFloatLinear(input: attnOutBuffer, weight: weights.selfAttnPostWeight, bias: nil,
                             output: branchBuffer, seqLen: seqLen, inDim: hidden, outDim: hidden,
                             cmdBuf: cmdBuf)

        // 3. Residual + norm: (stream + branch) -> spare -> stream
        try applyResidualAdd(input: stream, add: branchBuffer, output: spare,
                             seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)
        try applyRMSNorm(input: spare, weight: weights.normPostAttn, output: stream,
                         seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)

        // 4. LConv1D: (stream + lconv(stream)) -> spare
        try applyLConv1D(input: stream, output: branchBuffer, weights: weights,
                         seqLen: seqLen, cmdBuf: cmdBuf)
        try applyResidualAdd(input: stream, add: branchBuffer, output: spare,
                             seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)

        // 5. FeedForward 1: (spare + ff1(spare)) -> stream
        try applyFeedForward(input: spare, output: branchBuffer,
                             layer1Weight: weights.feedForward1Layer1Weight,
                             layer2Weight: weights.feedForward1Layer2Weight,
                             preNorm: weights.feedForward1PreLayerNorm,
                             postNorm: weights.feedForward1PostLayerNorm,
                             seqLen: seqLen, cmdBuf: cmdBuf)
        try applyResidualAdd(input: spare, add: branchBuffer, output: stream,
                             seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)

        // 6. FeedForward 2 + output norm: (stream + ff2(stream)) -> spare -> stream
        try applyFeedForward(input: stream, output: branchBuffer,
                             layer1Weight: weights.feedForward2Layer1Weight,
                             layer2Weight: weights.feedForward2Layer2Weight,
                             preNorm: weights.feedForward2PreLayerNorm,
                             postNorm: weights.feedForward2PostLayerNorm,
                             seqLen: seqLen, cmdBuf: cmdBuf)
        try applyResidualAdd(input: stream, add: branchBuffer, output: spare,
                             seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)
        try applyRMSNorm(input: spare, weight: weights.normOut, output: stream,
                         seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)
    }

    /// Pre-norm -> linear (2x hidden) -> SiLU -> depthwise conv -> linear (hidden).
    /// Uses `normBuffer`, `ffnBuffer` and `tempBuffer` as scratch.
    private func applyLConv1D(input: MTLBuffer, output: MTLBuffer, weights: AudioLayerWeightsE2B,
                              seqLen: Int, cmdBuf: MTLCommandBuffer) throws {
        let hidden = config.hiddenSize
        let expanded = hidden * 2

        try applyRMSNorm(input: input, weight: weights.lconv1dPreLayerNorm, output: normBuffer,
                         seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)
        try applyFloatLinear(input: normBuffer, weight: weights.lconv1dLinearStartWeight, bias: nil,
                             output: ffnBuffer, seqLen: seqLen, inDim: hidden, outDim: expanded,
                             cmdBuf: cmdBuf)
        try applySiLU(input: ffnBuffer, output: tempBuffer, count: seqLen * expanded, cmdBuf: cmdBuf)
        try applyDepthwiseConv1D(input: tempBuffer, weight: weights.lconv1dDepthwiseConv,
                                 norm: weights.lconv1dConvNorm, output: ffnBuffer,
                                 seqLen: seqLen, channels: expanded,
                                 kernelSize: config.convKernelSize, cmdBuf: cmdBuf)
        try applyFloatLinear(input: ffnBuffer, weight: weights.lconv1dLinearEndWeight, bias: nil,
                             output: output, seqLen: seqLen, inDim: expanded, outDim: hidden,
                             cmdBuf: cmdBuf)
    }

    /// Pre-norm -> linear (ffnDim) -> SiLU -> linear (hidden) -> post-norm.
    /// Uses `normBuffer`, `ffnBuffer` and `tempBuffer` as scratch.
    private func applyFeedForward(input: MTLBuffer, output: MTLBuffer,
                                  layer1Weight: MTLBuffer, layer2Weight: MTLBuffer,
                                  preNorm: MTLBuffer, postNorm: MTLBuffer,
                                  seqLen: Int, cmdBuf: MTLCommandBuffer) throws {
        let hidden = config.hiddenSize
        let inner = Self.ffnDim

        try applyRMSNorm(input: input, weight: preNorm, output: normBuffer,
                         seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)
        try applyFloatLinear(input: normBuffer, weight: layer1Weight, bias: nil,
                             output: ffnBuffer, seqLen: seqLen, inDim: hidden, outDim: inner,
                             cmdBuf: cmdBuf)
        try applySiLU(input: ffnBuffer, output: tempBuffer, count: seqLen * inner, cmdBuf: cmdBuf)
        // The pre-norm result in `normBuffer` is no longer needed, so reuse it.
        try applyFloatLinear(input: tempBuffer, weight: layer2Weight, bias: nil,
                             output: normBuffer, seqLen: seqLen, inDim: inner, outDim: hidden,
                             cmdBuf: cmdBuf)
        try applyRMSNorm(input: normBuffer, weight: postNorm, output: output,
                         seqLen: seqLen, hiddenSize: hidden, cmdBuf: cmdBuf)
    }
}
|
|
|
|
// E2B audio tower loader function
|
|
/// Loads every `audio_tower.` tensor from `reader`, converts it to float32 and
/// builds an `AudioTowerE2B`.
///
/// Tensors are read in parallel; results are written into shared arrays under a
/// lock so the concurrent tasks never mutate them unsynchronised.
/// NOTE(review): this assumes `reader.read(tensor:)` may be called concurrently,
/// as the original code did; confirm against `SafeTensorsReader`.
///
/// - Parameters:
///   - reader: Source of tensor descriptors and raw tensor bytes.
///   - config: Audio tower configuration forwarded to the weights and the tower.
///   - engine: Provides the Metal device used for weight and scratch buffers.
/// - Returns: A tower ready for `forward(inputBuffer:seqLen:outputBuffer:)`.
/// - Throws: `WeightError.readFailed` for the first (lowest-index) tensor that
///   failed to read, `WeightError.tensorNotFound` when no tensor could be
///   converted or a required tensor is missing, and any error from the tower
///   initialiser.
func loadAudioTowerE2B(reader: SafeTensorsReader, config: AudioConfig,
                       engine: MarkBaseEngine) throws -> AudioTowerE2B {
    print("Loading E2B Audio Tower with preload optimization...")
    let startTime = Date()

    // Collect all audio tensor descriptors
    let audioPrefix = "audio_tower."
    let audioDescriptors = reader.allDescriptors().filter {
        $0.name.hasPrefix(audioPrefix)
    }

    print(" Found \(audioDescriptors.count) audio tensors")

    // Parallel preload. `concurrentPerform` blocks until every iteration is done;
    // the lock serialises all writes to the two result arrays.
    let resultsLock = NSLock()
    var loadedData: [Data?] = Array(repeating: nil, count: audioDescriptors.count)
    var loadErrors: [Error?] = Array(repeating: nil, count: audioDescriptors.count)

    DispatchQueue.concurrentPerform(iterations: audioDescriptors.count) { idx in
        let outcome = Result { try reader.read(tensor: audioDescriptors[idx]) }
        resultsLock.lock()
        defer { resultsLock.unlock() }
        switch outcome {
        case .success(let data):
            loadedData[idx] = data
        case .failure(let error):
            loadErrors[idx] = error
        }
    }

    // Report the lowest-index failure, matching the original error ordering.
    for case let (idx, failure?) in loadErrors.enumerated() {
        throw WeightError.readFailed("Failed to preload audio tensor \(audioDescriptors[idx].name): \(failure)")
    }

    let preloadTime = Date().timeIntervalSince(startTime) * 1000
    print(" ✓ Parallel preloaded \(audioDescriptors.count) audio tensors in \(String(format: "%.1f", preloadTime))ms")

    // Convert to floats dictionary
    var floats: [String: [Float]] = [:]
    floats.reserveCapacity(audioDescriptors.count)

    for (idx, desc) in audioDescriptors.enumerated() {
        guard let data = loadedData[idx] else { continue }
        // Drop the array's reference so the raw bytes are freed after this iteration.
        loadedData[idx] = nil

        switch desc.dtype {
        case .bf16:
            floats[desc.name] = SafeTensorsReader.bf16ToFloat32(data)
        case .f32:
            // Copy bytes instead of reinterpreting the pointer: `Data` storage is
            // not guaranteed to be aligned for `Float`.
            var values = [Float](repeating: 0, count: data.count / MemoryLayout<Float>.stride)
            if !values.isEmpty {
                values.withUnsafeMutableBufferPointer { destination in
                    _ = data.copyBytes(to: destination)
                }
            }
            floats[desc.name] = values
        default:
            // Still skipped (best effort), but no longer silently.
            print(" ! Skipping audio tensor \(desc.name): unsupported dtype \(desc.dtype)")
        }
    }

    guard !floats.isEmpty else {
        throw WeightError.tensorNotFound("Audio tower tensors")
    }

    let weights = try AudioWeightsE2B(device: engine.device, config: config, floats: floats)

    let totalTime = Date().timeIntervalSince(startTime) * 1000
    print(" ✓ E2B Audio Tower loaded in \(String(format: "%.1f", totalTime))ms")

    return try AudioTowerE2B(config: config, engine: engine, weights: weights)
}
|