Files
markbaseengine/Sources/MarkBase/Audio/AudioTowerE2B.swift
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

610 lines
30 KiB
Swift

import Metal
// The E2B audio tower checkpoint stores unquantized bfloat16 weights.
// They are expanded to float32 before upload, so linear weights are plain float buffers, not uint32-packed.
/// Float32 Metal buffers holding the tensors of one `audio_tower.layers.<index>` block.
public struct AudioLayerWeightsE2B {
    // Layer-level norms
    public let normPreAttn: MTLBuffer
    public let normPostAttn: MTLBuffer
    public let normOut: MTLBuffer

    // Self-attention
    public let selfAttnQProjWeight: MTLBuffer
    public let selfAttnKProjWeight: MTLBuffer
    public let selfAttnVProjWeight: MTLBuffer
    public let selfAttnPostWeight: MTLBuffer
    public let selfAttnRelativeKProj: MTLBuffer
    public let selfAttnPerDimScale: MTLBuffer

    // LConv1D
    public let lconv1dPreLayerNorm: MTLBuffer
    public let lconv1dConvNorm: MTLBuffer
    public let lconv1dDepthwiseConv: MTLBuffer
    public let lconv1dLinearStartWeight: MTLBuffer
    public let lconv1dLinearEndWeight: MTLBuffer

    // FeedForward 1
    public let feedForward1Layer1Weight: MTLBuffer
    public let feedForward1Layer2Weight: MTLBuffer
    public let feedForward1PreLayerNorm: MTLBuffer
    public let feedForward1PostLayerNorm: MTLBuffer

    // FeedForward 2
    public let feedForward2Layer1Weight: MTLBuffer
    public let feedForward2Layer2Weight: MTLBuffer
    public let feedForward2PreLayerNorm: MTLBuffer
    public let feedForward2PostLayerNorm: MTLBuffer

    /// Uploads every tensor of layer `layerIdx` from `floats` to `device`.
    ///
    /// - Parameters:
    ///   - device: Metal device that allocates the buffers.
    ///   - layerIdx: Index used to form the `audio_tower.layers.<index>.` key prefix.
    ///   - floats: Tensor name to float32 values.
    /// - Throws: `WeightError.tensorNotFound` when a tensor is missing from `floats`
    ///   or when its buffer cannot be created.
    public init(device: MTLDevice, layerIdx: Int,
                floats: [String: [Float]]) throws {
        let prefix = "audio_tower.layers.\(layerIdx)."

        // Resolves a layer-relative tensor name and uploads it.
        func upload(_ suffix: String) throws -> MTLBuffer {
            try AudioLayerWeightsE2B.buffer(device, floats, prefix + suffix)
        }

        normPreAttn = try upload("norm_pre_attn.weight")
        normPostAttn = try upload("norm_post_attn.weight")
        normOut = try upload("norm_out.weight")

        // E2B attention projections carry a `linear.weight` suffix.
        selfAttnQProjWeight = try upload("self_attn.q_proj.linear.weight")
        selfAttnKProjWeight = try upload("self_attn.k_proj.linear.weight")
        selfAttnVProjWeight = try upload("self_attn.v_proj.linear.weight")
        selfAttnPostWeight = try upload("self_attn.post.linear.weight")
        selfAttnRelativeKProj = try upload("self_attn.relative_k_proj.weight")
        selfAttnPerDimScale = try upload("self_attn.per_dim_scale")

        lconv1dPreLayerNorm = try upload("lconv1d.pre_layer_norm.weight")
        lconv1dConvNorm = try upload("lconv1d.conv_norm.weight")
        lconv1dDepthwiseConv = try upload("lconv1d.depthwise_conv1d.weight")
        lconv1dLinearStartWeight = try upload("lconv1d.linear_start.linear.weight")
        lconv1dLinearEndWeight = try upload("lconv1d.linear_end.linear.weight")

        feedForward1Layer1Weight = try upload("feed_forward1.ffw_layer_1.linear.weight")
        feedForward1Layer2Weight = try upload("feed_forward1.ffw_layer_2.linear.weight")
        feedForward1PreLayerNorm = try upload("feed_forward1.pre_layer_norm.weight")
        feedForward1PostLayerNorm = try upload("feed_forward1.post_layer_norm.weight")

        feedForward2Layer1Weight = try upload("feed_forward2.ffw_layer_1.linear.weight")
        feedForward2Layer2Weight = try upload("feed_forward2.ffw_layer_2.linear.weight")
        feedForward2PreLayerNorm = try upload("feed_forward2.pre_layer_norm.weight")
        feedForward2PostLayerNorm = try upload("feed_forward2.post_layer_norm.weight")
    }

    /// Copies the float32 values stored under `key` into a new buffer on `device`.
    private static func buffer(_ device: MTLDevice, _ floats: [String: [Float]],
                               _ key: String) throws -> MTLBuffer {
        guard let values = floats[key] else {
            throw WeightError.tensorNotFound(key)
        }
        let byteCount = values.count * MemoryLayout<Float>.stride
        if let uploaded = device.makeBuffer(bytes: values, length: byteCount) {
            return uploaded
        }
        throw WeightError.tensorNotFound("Failed to create buffer for \(key)")
    }
}
/// Float32 Metal buffers for the whole E2B audio tower: the two subsample conv
/// layers, the input and output projections, and every encoder layer.
public struct AudioWeightsE2B {
    public let subsampleConvLayer0: SubsampleConvLayer
    public let subsampleConvLayer1: SubsampleConvLayer
    public let inputProjLinearWeight: MTLBuffer
    public let outputProjWeight: MTLBuffer
    public let outputProjBias: MTLBuffer
    public let layers: [AudioLayerWeightsE2B]

    /// Uploads all `audio_tower.*` tensors needed by the tower.
    ///
    /// - Parameters:
    ///   - device: Metal device that allocates the buffers.
    ///   - config: Supplies `numHiddenLayers`, the number of layer blocks to load.
    ///   - floats: Tensor name to float32 values.
    /// - Throws: `WeightError.tensorNotFound` when a tensor is missing from `floats`
    ///   or when its buffer cannot be created.
    public init(device: MTLDevice, config: AudioConfig,
                floats: [String: [Float]]) throws {
        let subsamplePrefix = "audio_tower.subsample_conv_projection."

        let conv0 = try Self.buffer(device, floats, subsamplePrefix + "layer0.conv.weight")
        let norm0 = try Self.buffer(device, floats, subsamplePrefix + "layer0.norm.weight")
        subsampleConvLayer0 = SubsampleConvLayer(convWeight: conv0, normWeight: norm0)

        let conv1 = try Self.buffer(device, floats, subsamplePrefix + "layer1.conv.weight")
        let norm1 = try Self.buffer(device, floats, subsamplePrefix + "layer1.norm.weight")
        subsampleConvLayer1 = SubsampleConvLayer(convWeight: conv1, normWeight: norm1)

        inputProjLinearWeight = try Self.buffer(device, floats, subsamplePrefix + "input_proj_linear.weight")
        outputProjWeight = try Self.buffer(device, floats, "audio_tower.output_proj.weight")
        outputProjBias = try Self.buffer(device, floats, "audio_tower.output_proj.bias")

        // Layers are loaded in index order; the first failure aborts the whole load.
        layers = try (0..<config.numHiddenLayers).map { index in
            try AudioLayerWeightsE2B(device: device, layerIdx: index, floats: floats)
        }
    }

    /// Copies the float32 values stored under `key` into a new buffer on `device`.
    private static func buffer(_ device: MTLDevice, _ floats: [String: [Float]],
                               _ key: String) throws -> MTLBuffer {
        guard let values = floats[key] else {
            throw WeightError.tensorNotFound(key)
        }
        let byteCount = values.count * MemoryLayout<Float>.stride
        if let uploaded = device.makeBuffer(bytes: values, length: byteCount) {
            return uploaded
        }
        throw WeightError.tensorNotFound("Failed to create buffer for \(key)")
    }
}
// E2B AudioTower - uses float32 weights (bfloat16 converted to float32)
/// Encodes the E2B audio tower forward pass: conv subsampling, input projection,
/// the encoder layer stack, and the output projection.
///
/// All kernels of one `forward` call go into a single command buffer, each in its
/// own compute encoder. Every kernel writes to a scratch buffer that is distinct
/// from the buffers it reads, and a value stays in its buffer until its last
/// reader has been encoded. The scratch buffers are shared mutable state, so one
/// instance must not run `forward` concurrently.
public final class AudioTowerE2B {
    /// Failures raised while allocating scratch memory or encoding a forward pass.
    public enum TowerError: Error {
        /// `MTLDevice.makeBuffer` returned nil for a scratch buffer of `byteCount` bytes.
        case bufferAllocationFailed(byteCount: Int)
        /// `seqLen` was not in `1...maximum`.
        case invalidSequenceLength(seqLen: Int, maximum: Int)
        /// The command queue could not provide a command buffer.
        case commandBufferUnavailable
        /// The command buffer could not provide a compute encoder for `kernel`.
        case encoderUnavailable(kernel: String)
    }

    public let config: AudioConfig
    public let engine: MarkBaseEngine
    public let weights: AudioWeightsE2B

    /// Largest number of input mel frames the scratch buffers are sized for.
    private static let maxSequenceLength = 4096
    /// Number of mel bins per input frame.
    private static let melBinCount = 128
    /// Width of the feed-forward intermediate activation.
    private static let feedForwardDim = 4096
    /// Output channel counts of the two subsample conv layers.
    private static let subsampleChannels0 = 128
    private static let subsampleChannels1 = 32

    // Scratch buffers of hiddenSize * maxSequenceLength floats.
    private let normBuffer: MTLBuffer
    private let qBuffer: MTLBuffer
    private let kBuffer: MTLBuffer
    private let vBuffer: MTLBuffer
    private let attnOutBuffer: MTLBuffer
    // Scratch buffers wide enough for the feed-forward and conv intermediates.
    private let ffnBuffer: MTLBuffer
    private let tempBuffer: MTLBuffer
    private let subsampleBuf: MTLBuffer

    /// Allocates the scratch buffers used by `forward`.
    ///
    /// - Throws: `TowerError.bufferAllocationFailed` when the device cannot allocate a buffer.
    public init(config: AudioConfig, engine: MarkBaseEngine, weights: AudioWeightsE2B) throws {
        self.config = config
        self.engine = engine
        self.weights = weights

        let device = engine.device
        let maxSeqLen = AudioTowerE2B.maxSequenceLength
        let hiddenSize = config.hiddenSize
        let ffnDim = AudioTowerE2B.feedForwardDim

        // Allocates room for `rowWidth` floats per frame at the maximum frame count.
        func scratch(rowWidth: Int) throws -> MTLBuffer {
            let byteCount = rowWidth * maxSeqLen * MemoryLayout<Float>.stride
            guard let buffer = device.makeBuffer(length: byteCount) else {
                throw TowerError.bufferAllocationFailed(byteCount: byteCount)
            }
            return buffer
        }

        normBuffer = try scratch(rowWidth: hiddenSize)
        qBuffer = try scratch(rowWidth: hiddenSize)
        kBuffer = try scratch(rowWidth: hiddenSize)
        vBuffer = try scratch(rowWidth: hiddenSize)
        attnOutBuffer = try scratch(rowWidth: hiddenSize)
        ffnBuffer = try scratch(rowWidth: ffnDim)
        tempBuffer = try scratch(rowWidth: max(hiddenSize, ffnDim))
        subsampleBuf = try scratch(rowWidth: max(hiddenSize, 128 * 64))
    }

    /// Runs the tower on `seqLen` mel frames and blocks until the GPU has finished.
    ///
    /// - Parameters:
    ///   - inputBuffer: Mel input read by the `transpose_2d` kernel as 128 x `seqLen`
    ///     values. NOTE(review): element type and layout are defined by that kernel — confirm.
    ///   - seqLen: Number of input frames; must be in `1...4096`.
    ///   - outputBuffer: Destination of the output projection. The caller sizes it for
    ///     `config.outputProjDims` values per subsampled frame.
    /// - Throws: `TowerError` for invalid input or unavailable Metal objects, the
    ///   error thrown by `engine.pipeline(named:)`, or the command buffer's error
    ///   when GPU execution fails.
    public func forward(inputBuffer: MTLBuffer, seqLen: Int, outputBuffer: MTLBuffer) throws {
        let maxLen = AudioTowerE2B.maxSequenceLength
        // The scratch buffers are sized for at most maxLen frames; longer input would overrun them.
        guard seqLen > 0, seqLen <= maxLen else {
            throw TowerError.invalidSequenceLength(seqLen: seqLen, maximum: maxLen)
        }
        guard let cmdBuf = engine.commandQueue.makeCommandBuffer() else {
            throw TowerError.commandBufferUnavailable
        }

        // 1. Subsample conv: result lands in tempBuffer.
        let (features, featureDim, frames) = try applySubsampleConv(
            melInput: inputBuffer, nMels: AudioTowerE2B.melBinCount, seqLen: seqLen, cmdBuf: cmdBuf
        )

        // 2. Input projection into the layer state buffer. subsampleBuf is free again
        //    once the flatten step has consumed it.
        let state = subsampleBuf
        try encodeLinear(input: features, weight: weights.inputProjLinearWeight, bias: nil,
                         seqLen: frames, inDim: featureDim, outDim: config.hiddenSize,
                         output: state, cmdBuf: cmdBuf)

        // 3. Audio layers: each one reads and finally rewrites `state`.
        for layerWeights in weights.layers {
            try applyLayer(state: state, weights: layerWeights, seqLen: frames, cmdBuf: cmdBuf)
        }

        // 4. Output projection (with bias) into the caller's buffer.
        try encodeLinear(input: state, weight: weights.outputProjWeight, bias: weights.outputProjBias,
                         seqLen: frames, inDim: config.hiddenSize, outDim: config.outputProjDims,
                         output: outputBuffer, cmdBuf: cmdBuf)

        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
        // A failed GPU execution must not look like a successful forward pass.
        if let error = cmdBuf.error {
            throw error
        }
    }

    // MARK: - Stages

    /// Transposes the mel input, applies both subsample conv layers and flattens the result.
    ///
    /// Each conv layer halves height and width (rounding up). Buffers alternate
    /// between `subsampleBuf` and `tempBuffer`, so no kernel reads its own output.
    /// - Returns: The flattened features (in `tempBuffer`), the feature width per
    ///   frame (channels x height), and the subsampled frame count.
    private func applySubsampleConv(melInput: MTLBuffer, nMels: Int, seqLen: Int,
                                    cmdBuf: MTLCommandBuffer) throws -> (MTLBuffer, Int, Int) {
        let channels0 = AudioTowerE2B.subsampleChannels0
        let channels1 = AudioTowerE2B.subsampleChannels1

        try encodeTranspose(input: melInput, nMels: nMels, seqLen: seqLen,
                            output: subsampleBuf, cmdBuf: cmdBuf)

        try encodeConv2D(input: subsampleBuf, inCh: 1, height: nMels, width: seqLen,
                         convWeight: weights.subsampleConvLayer0.convWeight,
                         normWeight: weights.subsampleConvLayer0.normWeight,
                         outChannels: channels0, output: tempBuffer, cmdBuf: cmdBuf)
        let h1 = (nMels + 1) / 2
        let w1 = (seqLen + 1) / 2

        try encodeConv2D(input: tempBuffer, inCh: channels0, height: h1, width: w1,
                         convWeight: weights.subsampleConvLayer1.convWeight,
                         normWeight: weights.subsampleConvLayer1.normWeight,
                         outChannels: channels1, output: subsampleBuf, cmdBuf: cmdBuf)
        let h2 = (h1 + 1) / 2
        let w2 = (w1 + 1) / 2

        try encodeFlatten(input: subsampleBuf, C: channels1, H: h2, W: w2,
                          output: tempBuffer, cmdBuf: cmdBuf)
        return (tempBuffer, channels1 * h2, w2)
    }

    /// Encodes one encoder layer. `state` holds the layer input and receives the layer output.
    ///
    /// Buffer routing (`state` must not be any of the class's hidden-size scratch buffers):
    /// pre-attention norm -> normBuffer; q/k/v -> qBuffer/kBuffer/vBuffer;
    /// attention -> attnOutBuffer; post projection -> normBuffer;
    /// residual sums alternate between qBuffer and `state`.
    private func applyLayer(state: MTLBuffer, weights: AudioLayerWeightsE2B,
                            seqLen: Int, cmdBuf: MTLCommandBuffer) throws {
        let hidden = config.hiddenSize

        // 1. Norm pre-attn
        try encodeRMSNorm(input: state, weight: weights.normPreAttn, seqLen: seqLen,
                          hiddenSize: hidden, output: normBuffer, cmdBuf: cmdBuf)

        // 2. Self-attention: q, k and v each get their own buffer.
        try encodeLinear(input: normBuffer, weight: weights.selfAttnQProjWeight, bias: nil,
                         seqLen: seqLen, inDim: hidden, outDim: hidden,
                         output: qBuffer, cmdBuf: cmdBuf)
        try encodeLinear(input: normBuffer, weight: weights.selfAttnKProjWeight, bias: nil,
                         seqLen: seqLen, inDim: hidden, outDim: hidden,
                         output: kBuffer, cmdBuf: cmdBuf)
        try encodeLinear(input: normBuffer, weight: weights.selfAttnVProjWeight, bias: nil,
                         seqLen: seqLen, inDim: hidden, outDim: hidden,
                         output: vBuffer, cmdBuf: cmdBuf)
        try encodeAttention(q: qBuffer, k: kBuffer, v: vBuffer,
                            relativeKProj: weights.selfAttnRelativeKProj,
                            perDimScale: weights.selfAttnPerDimScale,
                            seqLen: seqLen, output: attnOutBuffer, cmdBuf: cmdBuf)
        // The normed input is no longer needed, so normBuffer can take the post projection.
        try encodeLinear(input: attnOutBuffer, weight: weights.selfAttnPostWeight, bias: nil,
                         seqLen: seqLen, inDim: hidden, outDim: hidden,
                         output: normBuffer, cmdBuf: cmdBuf)

        // 3. Residual + norm
        try encodeResidualAdd(input: state, add: normBuffer, seqLen: seqLen,
                              hiddenSize: hidden, output: qBuffer, cmdBuf: cmdBuf)
        try encodeRMSNorm(input: qBuffer, weight: weights.normPostAttn, seqLen: seqLen,
                          hiddenSize: hidden, output: state, cmdBuf: cmdBuf)

        // 4. LConv1D, added onto the post-attention normed value held in `state`.
        let lconvOut = try applyLConv1D(input: state, weights: weights, seqLen: seqLen, cmdBuf: cmdBuf)
        try encodeResidualAdd(input: state, add: lconvOut, seqLen: seqLen,
                              hiddenSize: hidden, output: qBuffer, cmdBuf: cmdBuf)

        // 5. FeedForward 1
        let ff1Out = try applyFeedForward(input: qBuffer,
                                          layer1Weight: weights.feedForward1Layer1Weight,
                                          layer2Weight: weights.feedForward1Layer2Weight,
                                          preNorm: weights.feedForward1PreLayerNorm,
                                          postNorm: weights.feedForward1PostLayerNorm,
                                          seqLen: seqLen, cmdBuf: cmdBuf)
        try encodeResidualAdd(input: qBuffer, add: ff1Out, seqLen: seqLen,
                              hiddenSize: hidden, output: state, cmdBuf: cmdBuf)

        // 6. FeedForward 2
        let ff2Out = try applyFeedForward(input: state,
                                          layer1Weight: weights.feedForward2Layer1Weight,
                                          layer2Weight: weights.feedForward2Layer2Weight,
                                          preNorm: weights.feedForward2PreLayerNorm,
                                          postNorm: weights.feedForward2PostLayerNorm,
                                          seqLen: seqLen, cmdBuf: cmdBuf)
        try encodeResidualAdd(input: state, add: ff2Out, seqLen: seqLen,
                              hiddenSize: hidden, output: qBuffer, cmdBuf: cmdBuf)

        // Final layer norm writes the layer output back into `state`.
        try encodeRMSNorm(input: qBuffer, weight: weights.normOut, seqLen: seqLen,
                          hiddenSize: hidden, output: state, cmdBuf: cmdBuf)
    }

    /// Encodes norm -> linear (2x hidden) -> SiLU -> depthwise conv -> linear (hidden).
    ///
    /// Uses normBuffer, kBuffer and vBuffer as scratch, so `input` must be none of
    /// them. The 2x-hidden intermediates fit in hidden-size buffers because the
    /// subsampled frame count is at most a quarter of the maximum input length.
    /// - Returns: `normBuffer`, holding the result.
    private func applyLConv1D(input: MTLBuffer, weights: AudioLayerWeightsE2B,
                              seqLen: Int, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        let hidden = config.hiddenSize
        let expanded = hidden * 2

        try encodeRMSNorm(input: input, weight: weights.lconv1dPreLayerNorm, seqLen: seqLen,
                          hiddenSize: hidden, output: normBuffer, cmdBuf: cmdBuf)
        try encodeLinear(input: normBuffer, weight: weights.lconv1dLinearStartWeight, bias: nil,
                         seqLen: seqLen, inDim: hidden, outDim: expanded,
                         output: kBuffer, cmdBuf: cmdBuf)
        try encodeSiLU(input: kBuffer, count: seqLen * expanded, output: vBuffer, cmdBuf: cmdBuf)
        try encodeDepthwiseConv1D(input: vBuffer, weight: weights.lconv1dDepthwiseConv,
                                  norm: weights.lconv1dConvNorm, seqLen: seqLen,
                                  channels: expanded, kernelSize: config.convKernelSize,
                                  output: kBuffer, cmdBuf: cmdBuf)
        try encodeLinear(input: kBuffer, weight: weights.lconv1dLinearEndWeight, bias: nil,
                         seqLen: seqLen, inDim: expanded, outDim: hidden,
                         output: normBuffer, cmdBuf: cmdBuf)
        return normBuffer
    }

    /// Encodes norm -> linear (4096) -> SiLU -> linear (hidden) -> norm.
    ///
    /// Uses normBuffer, ffnBuffer, tempBuffer and attnOutBuffer as scratch, so
    /// `input` must be none of them.
    /// - Returns: `attnOutBuffer`, holding the result.
    private func applyFeedForward(input: MTLBuffer, layer1Weight: MTLBuffer, layer2Weight: MTLBuffer,
                                  preNorm: MTLBuffer, postNorm: MTLBuffer,
                                  seqLen: Int, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        let hidden = config.hiddenSize
        let inner = AudioTowerE2B.feedForwardDim

        try encodeRMSNorm(input: input, weight: preNorm, seqLen: seqLen,
                          hiddenSize: hidden, output: normBuffer, cmdBuf: cmdBuf)
        try encodeLinear(input: normBuffer, weight: layer1Weight, bias: nil,
                         seqLen: seqLen, inDim: hidden, outDim: inner,
                         output: ffnBuffer, cmdBuf: cmdBuf)
        try encodeSiLU(input: ffnBuffer, count: seqLen * inner, output: tempBuffer, cmdBuf: cmdBuf)
        try encodeLinear(input: tempBuffer, weight: layer2Weight, bias: nil,
                         seqLen: seqLen, inDim: inner, outDim: hidden,
                         output: normBuffer, cmdBuf: cmdBuf)
        try encodeRMSNorm(input: normBuffer, weight: postNorm, seqLen: seqLen,
                          hiddenSize: hidden, output: attnOutBuffer, cmdBuf: cmdBuf)
        return attnOutBuffer
    }

    // MARK: - Kernel encoders

    /// Looks up `kernel` and opens a compute encoder with that pipeline bound.
    private func makeEncoder(_ kernel: String,
                             cmdBuf: MTLCommandBuffer) throws -> (MTLComputeCommandEncoder, MTLComputePipelineState) {
        let pso = try engine.pipeline(named: kernel)
        guard let enc = cmdBuf.makeComputeCommandEncoder() else {
            throw TowerError.encoderUnavailable(kernel: kernel)
        }
        enc.setComputePipelineState(pso)
        return (enc, pso)
    }

    /// Binds a 32-bit unsigned kernel constant at `index`.
    private func setUInt32(_ value: UInt32, on enc: MTLComputeCommandEncoder, index: Int) {
        var raw = value
        enc.setBytes(&raw, length: MemoryLayout<UInt32>.size, index: index)
    }

    /// Encodes `transpose_2d` with `nMels` rows and `seqLen` columns.
    private func encodeTranspose(input: MTLBuffer, nMels: Int, seqLen: Int,
                                 output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("transpose_2d", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(output, offset: 0, index: 1)
        setUInt32(UInt32(nMels), on: enc, index: 2)
        setUInt32(UInt32(seqLen), on: enc, index: 3)
        let grid = MTLSize(width: seqLen, height: nMels, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (seqLen, nMels))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `audio_subsample_conv_2d`; one thread per (output channel, output row, output column).
    private func encodeConv2D(input: MTLBuffer, inCh: Int, height: Int, width: Int,
                              convWeight: MTLBuffer, normWeight: MTLBuffer,
                              outChannels: Int, output: MTLBuffer,
                              cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("audio_subsample_conv_2d", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(convWeight, offset: 0, index: 1)
        enc.setBuffer(normWeight, offset: 0, index: 2)
        enc.setBuffer(output, offset: 0, index: 3)
        setUInt32(UInt32(inCh), on: enc, index: 4)
        setUInt32(UInt32(outChannels), on: enc, index: 5)
        setUInt32(UInt32(height), on: enc, index: 6)
        setUInt32(UInt32(width), on: enc, index: 7)
        let outH = (height + 1) / 2
        let outW = (width + 1) / 2
        let grid = MTLSize(width: outChannels, height: outH, depth: outW)
        // 8 x 8 x 4 as before, with the depth reduced when the pipeline allows fewer than 256 threads.
        let depth = max(1, min(4, pso.maxTotalThreadsPerThreadgroup / 64))
        let tg = MTLSize(width: 8, height: 8, depth: depth)
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `audio_flatten_chw` over a (C*H) x W grid.
    private func encodeFlatten(input: MTLBuffer, C: Int, H: Int, W: Int,
                               output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("audio_flatten_chw", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(output, offset: 0, index: 1)
        setUInt32(UInt32(C), on: enc, index: 2)
        setUInt32(UInt32(H), on: enc, index: 3)
        setUInt32(UInt32(W), on: enc, index: 4)
        let grid = MTLSize(width: C * H, height: W, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (C * H, W))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `audio_linear_seq`, with or without a bias. `output` must differ from `input`.
    private func encodeLinear(input: MTLBuffer, weight: MTLBuffer, bias: MTLBuffer?,
                              seqLen: Int, inDim: Int, outDim: Int,
                              output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("audio_linear_seq", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        // Without a bias the weight buffer is bound as a placeholder so the slot is
        // never left unbound. NOTE(review): assumes the kernel does not read
        // index 2 when hasBias is false — confirm against the shader.
        enc.setBuffer(bias ?? weight, offset: 0, index: 2)
        enc.setBuffer(output, offset: 0, index: 3)
        setUInt32(UInt32(inDim), on: enc, index: 4)
        setUInt32(UInt32(outDim), on: enc, index: 5)
        var hasBias = bias != nil
        enc.setBytes(&hasBias, length: MemoryLayout<Bool>.size, index: 6)
        setUInt32(UInt32(seqLen), on: enc, index: 7)
        let grid = MTLSize(width: outDim, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (outDim, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `rms_norm_seq` over a hiddenSize x seqLen grid. `output` must differ from `input`.
    private func encodeRMSNorm(input: MTLBuffer, weight: MTLBuffer, seqLen: Int,
                               hiddenSize: Int, output: MTLBuffer,
                               cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("rms_norm_seq", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)
        setUInt32(UInt32(hiddenSize), on: enc, index: 3)
        // NOTE(review): 4 bytes are passed, which presumes rmsNormEps is a 32-bit value — confirm.
        var eps = config.rmsNormEps
        enc.setBytes(&eps, length: 4, index: 4)
        setUInt32(UInt32(seqLen), on: enc, index: 5)
        let grid = MTLSize(width: hiddenSize, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (hiddenSize, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `audio_attention_full` over a (heads * headDim) x seqLen grid.
    private func encodeAttention(q: MTLBuffer, k: MTLBuffer, v: MTLBuffer,
                                 relativeKProj: MTLBuffer, perDimScale: MTLBuffer,
                                 seqLen: Int, output: MTLBuffer,
                                 cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("audio_attention_full", cmdBuf: cmdBuf)
        enc.setBuffer(q, offset: 0, index: 0)
        enc.setBuffer(k, offset: 0, index: 1)
        enc.setBuffer(v, offset: 0, index: 2)
        enc.setBuffer(relativeKProj, offset: 0, index: 3)
        enc.setBuffer(perDimScale, offset: 0, index: 4)
        enc.setBuffer(output, offset: 0, index: 5)
        setUInt32(UInt32(seqLen), on: enc, index: 6)
        setUInt32(UInt32(config.numAttentionHeads), on: enc, index: 7)
        setUInt32(UInt32(config.headDim), on: enc, index: 8)
        setUInt32(UInt32(config.attentionContextLeft), on: enc, index: 9)
        // NOTE(review): 4 bytes are passed, which presumes attentionLogitCap is a 32-bit value — confirm.
        var logitCap = config.attentionLogitCap
        enc.setBytes(&logitCap, length: 4, index: 10)
        let width = config.numAttentionHeads * config.headDim
        let grid = MTLSize(width: width, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (width, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `audio_depthwise_conv1d` over a channels x seqLen grid. `output` must differ from `input`.
    private func encodeDepthwiseConv1D(input: MTLBuffer, weight: MTLBuffer, norm: MTLBuffer,
                                       seqLen: Int, channels: Int, kernelSize: Int,
                                       output: MTLBuffer, cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("audio_depthwise_conv1d", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(norm, offset: 0, index: 2)
        enc.setBuffer(output, offset: 0, index: 3)
        setUInt32(UInt32(channels), on: enc, index: 4)
        setUInt32(UInt32(kernelSize), on: enc, index: 5)
        setUInt32(UInt32(seqLen), on: enc, index: 6)
        let grid = MTLSize(width: channels, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (channels, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `silu` over `count` elements. `output` must differ from `input`.
    private func encodeSiLU(input: MTLBuffer, count: Int, output: MTLBuffer,
                            cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("silu", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(output, offset: 0, index: 1)
        setUInt32(UInt32(count), on: enc, index: 2)
        let grid = MTLSize(width: count, height: 1, depth: 1)
        let tg = engine.threadgroupSize1D(pso, count: count)
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }

    /// Encodes `residual_add` over seqLen * hiddenSize elements, passing `config.residualWeight`.
    /// `output` must differ from both inputs.
    private func encodeResidualAdd(input: MTLBuffer, add: MTLBuffer, seqLen: Int,
                                   hiddenSize: Int, output: MTLBuffer,
                                   cmdBuf: MTLCommandBuffer) throws {
        let (enc, pso) = try makeEncoder("residual_add", cmdBuf: cmdBuf)
        let count = seqLen * hiddenSize
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(add, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)
        setUInt32(UInt32(count), on: enc, index: 3)
        // NOTE(review): 4 bytes are passed, which presumes residualWeight is a 32-bit value — confirm.
        var weight = config.residualWeight
        enc.setBytes(&weight, length: 4, index: 4)
        let grid = MTLSize(width: count, height: 1, depth: 1)
        let tg = engine.threadgroupSize1D(pso, count: count)
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()
    }
}
// E2B audio tower loader function
/// Loads every `audio_tower.*` tensor from `reader`, converts it to float32 and
/// builds the E2B audio tower.
///
/// Tensors are read concurrently. Each read result is stored through a serial
/// queue, so the shared result array is never mutated from two threads at once.
///
/// - Parameters:
///   - reader: Source of tensor descriptors and raw tensor bytes.
///   - config: Audio configuration forwarded to the weights and the tower.
///   - engine: Provides the Metal device the weights are uploaded to.
/// - Returns: A tower ready to run `forward`.
/// - Throws: `WeightError.readFailed` for the first tensor (in descriptor order)
///   whose read failed, `WeightError.tensorNotFound` when no usable tensor was
///   found or a required one is missing, and any error from building the tower.
func loadAudioTowerE2B(reader: SafeTensorsReader, config: AudioConfig,
                       engine: MarkBaseEngine) throws -> AudioTowerE2B {
    print("Loading E2B Audio Tower with preload optimization...")
    let startTime = Date()

    // Collect all audio tensor descriptors
    let audioPrefix = "audio_tower."
    let audioDescriptors = reader.allDescriptors().filter {
        $0.name.hasPrefix(audioPrefix)
    }
    print(" Found \(audioDescriptors.count) audio tensors")

    // Parallel preload all audio tensors
    let dispatchGroup = DispatchGroup()
    let loadQueue = DispatchQueue(label: "audio-preload-e2b", attributes: .concurrent)
    // Serial queue guarding `outcomes`: concurrent writes to a Swift array are a data race.
    let outcomeQueue = DispatchQueue(label: "audio-preload-e2b.results")
    var outcomes: [Result<Data, Error>?] = Array(repeating: nil, count: audioDescriptors.count)

    for (idx, desc) in audioDescriptors.enumerated() {
        dispatchGroup.enter()
        loadQueue.async {
            let outcome: Result<Data, Error> = Result { try reader.read(tensor: desc) }
            outcomeQueue.sync { outcomes[idx] = outcome }
            dispatchGroup.leave()
        }
    }
    dispatchGroup.wait()

    // Check for errors, reporting the first failure in descriptor order
    var preloaded: [(name: String, outcomeIndex: Int, data: Data)] = []
    preloaded.reserveCapacity(audioDescriptors.count)
    for (idx, outcome) in outcomes.enumerated() {
        switch outcome {
        case .failure(let err)?:
            throw WeightError.readFailed("Failed to preload audio tensor \(audioDescriptors[idx].name): \(err)")
        case .success(let data)?:
            preloaded.append((name: audioDescriptors[idx].name, outcomeIndex: idx, data: data))
        case nil:
            // Every task stores a result before leaving the group, so this is not expected.
            continue
        }
    }

    let preloadTime = Date().timeIntervalSince(startTime) * 1000
    print(" ✓ Parallel preloaded \(audioDescriptors.count) audio tensors in \(String(format: "%.1f", preloadTime))ms")

    // Convert to floats dictionary
    var floats: [String: [Float]] = [:]
    floats.reserveCapacity(preloaded.count)
    for entry in preloaded {
        let dtype = audioDescriptors[entry.outcomeIndex].dtype
        switch dtype {
        case .bf16:
            floats[entry.name] = SafeTensorsReader.bf16ToFloat32(entry.data)
        case .f32:
            // Copy the bytes instead of rebinding them: Data storage is not
            // guaranteed to be aligned for Float.
            let count = entry.data.count / MemoryLayout<Float>.stride
            var values = [Float](repeating: 0, count: count)
            _ = values.withUnsafeMutableBufferPointer { entry.data.copyBytes(to: $0) }
            floats[entry.name] = values
        default:
            // Other dtypes are still skipped, but visibly: a required tensor that is
            // skipped here shows up later as "tensor not found".
            print(" ⚠ Skipping audio tensor \(entry.name): unsupported dtype \(dtype)")
        }
    }

    guard !floats.isEmpty else {
        throw WeightError.tensorNotFound("Audio tower tensors")
    }

    let weights = try AudioWeightsE2B(device: engine.device, config: config, floats: floats)
    let totalTime = Date().timeIntervalSince(startTime) * 1000
    print(" ✓ E2B Audio Tower loaded in \(String(format: "%.1f", totalTime))ms")
    return try AudioTowerE2B(config: config, engine: engine, weights: weights)
}