8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
740 lines
24 KiB
Swift
740 lines
24 KiB
Swift
import Metal
|
|
|
|
/// Runs the audio encoder ("audio tower") on the GPU.
///
/// The forward pass encodes four stages, each into its own command buffer that is
/// committed in order on `engine.commandQueue`:
///
/// 1. subsample convolution: mel `[seqLen, 128]` -> `[seqLen/4, 1024]`
/// 2. input projection: `[seqLen/4, 1024]` -> `[seqLen/4, 1024]`
/// 3. the audio layers in `weights.layers`
/// 4. output projection into the caller's buffer
///
/// All intermediate results live in scratch buffers owned by the instance, so a single
/// `AudioTower` must not run two forward passes concurrently.
public final class AudioTower {

    // MARK: - Errors

    /// Failures raised by the tower itself (pipeline lookup errors from the engine and
    /// GPU execution errors are rethrown unchanged).
    public enum RuntimeError: Swift.Error, Equatable {
        /// `seqLen` was not in `1...maximum`; the scratch buffers cannot hold it.
        case invalidSequenceLength(Int, maximum: Int)
        /// The device could not allocate a scratch buffer of the given size.
        case bufferAllocationFailed(byteCount: Int)
        /// The command queue could not provide a command buffer.
        case commandBufferUnavailable
        /// The command buffer could not provide a compute encoder for `kernel`.
        case encoderUnavailable(kernel: String)
    }

    // MARK: - Constants

    /// Largest number of input frames the scratch buffers are sized for.
    public static let maxSequenceLength = 4096

    /// Number of mel bins per input frame.
    private static let melBinCount = 128
    /// Row width after flattening the subsample output (32 channels * 32 rows) and
    /// the width of the float32 input-projection weight.
    private static let projectionWidth = 1024
    /// Row width the wide scratch buffers are sized for.
    private static let wideRowWidth = 4096
    /// Scratch buffers are sized assuming 4-byte elements.
    private static let bytesPerElement = 4

    // MARK: - Public state

    public let config: AudioConfig
    public let engine: MarkBaseEngine
    public let weights: AudioWeights

    // MARK: - Scratch buffers
    //
    // Roles (no kernel launched by this class reads and writes the same buffer):
    //   normBuffer     RMS-norm output feeding the next projection / feed-forward result
    //   q/k/vBuffer    attention projections
    //   attnOutBuffer  attention result before the post projection
    //   ffnBuffer      wide intermediate rows (expansion, depthwise-conv output)
    //   tempBuffer     activation output and sub-block results
    //   subsampleBuf   subsample stage output; one half of the residual-stream ping-pong
    //   layerBuffer    other half of the residual-stream ping-pong

    private let normBuffer: MTLBuffer
    private let qBuffer: MTLBuffer
    private let kBuffer: MTLBuffer
    private let vBuffer: MTLBuffer
    private let attnOutBuffer: MTLBuffer
    private let ffnBuffer: MTLBuffer
    private let tempBuffer: MTLBuffer
    private let subsampleBuf: MTLBuffer
    private let layerBuffer: MTLBuffer

    // MARK: - Init

    /// Allocates every scratch buffer for `maxSequenceLength` rows.
    ///
    /// - Throws: `RuntimeError.bufferAllocationFailed` if the device cannot allocate one
    ///   of the buffers.
    public init(config: AudioConfig, engine: MarkBaseEngine, weights: AudioWeights) throws {
        self.config = config
        self.engine = engine
        self.weights = weights

        let device = engine.device
        let hiddenSize = config.hiddenSize
        let wide = AudioTower.wideRowWidth

        // Allocates `rowWidth * maxSequenceLength` 4-byte elements.
        func allocate(rowWidth: Int) throws -> MTLBuffer {
            let byteCount = rowWidth * AudioTower.maxSequenceLength * AudioTower.bytesPerElement
            guard let buffer = device.makeBuffer(length: byteCount) else {
                throw RuntimeError.bufferAllocationFailed(byteCount: byteCount)
            }
            return buffer
        }

        normBuffer = try allocate(rowWidth: hiddenSize)
        qBuffer = try allocate(rowWidth: hiddenSize)
        kBuffer = try allocate(rowWidth: hiddenSize)
        vBuffer = try allocate(rowWidth: hiddenSize)
        attnOutBuffer = try allocate(rowWidth: hiddenSize)
        ffnBuffer = try allocate(rowWidth: wide)
        tempBuffer = try allocate(rowWidth: max(hiddenSize, wide))
        // 128 * 64: widest row needed while the subsample stage is in flight.
        subsampleBuf = try allocate(rowWidth: max(hiddenSize, 128 * 64))
        layerBuffer = try allocate(rowWidth: max(hiddenSize, wide))
    }

    // MARK: - Forward pass

    /// Encodes and runs the full audio tower, blocking until the GPU has finished.
    ///
    /// - Parameters:
    ///   - inputBuffer: Mel features, `[seqLen, 128]` row-major.
    ///   - seqLen: Number of input frames; must be in `1...maxSequenceLength`.
    ///   - outputBuffer: Receives the output projection. NOTE(review): the caller is
    ///     responsible for sizing it for `ceil(ceil(seqLen/2)/2)` rows of
    ///     `weights.outputProj.outDim` elements — confirm the element type against the
    ///     `quantized_matmul_seq` kernel.
    /// - Throws: `RuntimeError` for invalid arguments or unavailable Metal objects, any
    ///   error from `engine.pipeline(named:)`, and the command buffer's own error if GPU
    ///   execution failed.
    public func forward(inputBuffer: MTLBuffer, seqLen: Int, outputBuffer: MTLBuffer) throws {
        let limit = AudioTower.maxSequenceLength
        guard seqLen >= 1, seqLen <= limit else {
            throw RuntimeError.invalidSequenceLength(seqLen, maximum: limit)
        }

        // Every stage must be committed; an uncommitted command buffer never executes.
        // Buffers committed to one queue run in commit order, so later stages see the
        // results of earlier ones.
        var submitted: [MTLCommandBuffer] = []
        func submit(_ cmdBuf: MTLCommandBuffer) {
            cmdBuf.commit()
            submitted.append(cmdBuf)
        }

        // 1. Subsample conv: mel [seqLen, 128] -> [seqLen/4, 1024]
        let subsampleCmd = try makeCommandBuffer()
        let (features, frameCount) = try applySubsampleConv(
            melInput: inputBuffer,
            nMels: AudioTower.melBinCount,
            seqLen: seqLen,
            cmdBuf: subsampleCmd
        )
        submit(subsampleCmd)

        // 2. Input projection: [seqLen/4, 1024] -> [seqLen/4, 1024]
        let projectionCmd = try makeCommandBuffer()
        var hidden = try applyInputProjection(input: features, seqLen: frameCount, cmdBuf: projectionCmd)
        submit(projectionCmd)

        // 3. Audio layers
        let layersCmd = try makeCommandBuffer()
        for layerWeights in weights.layers {
            hidden = try applyLayer(
                input: hidden,
                weights: layerWeights,
                seqLen: frameCount,
                cmdBuf: layersCmd
            )
        }
        submit(layersCmd)

        // 4. Output projection into the caller's buffer
        let outputCmd = try makeCommandBuffer()
        try applyOutputProjection(input: hidden, seqLen: frameCount, output: outputBuffer, cmdBuf: outputCmd)
        submit(outputCmd)

        // Wait for every stage and surface the first GPU failure instead of silently
        // returning an output buffer with undefined contents.
        for cmdBuf in submitted {
            cmdBuf.waitUntilCompleted()
            if let failure = cmdBuf.error {
                throw failure
            }
        }
    }

    // MARK: - Metal helpers

    /// Returns a fresh command buffer from the engine's queue.
    private func makeCommandBuffer() throws -> MTLCommandBuffer {
        guard let cmdBuf = engine.commandQueue.makeCommandBuffer() else {
            throw RuntimeError.commandBufferUnavailable
        }
        return cmdBuf
    }

    /// Returns a compute encoder on `cmdBuf`; `kernel` is only used for diagnostics.
    private func makeEncoder(for kernel: String, on cmdBuf: MTLCommandBuffer) throws -> MTLComputeCommandEncoder {
        guard let encoder = cmdBuf.makeComputeCommandEncoder() else {
            throw RuntimeError.encoderUnavailable(kernel: kernel)
        }
        return encoder
    }

    /// Binds `value` as a 32-bit unsigned constant. Traps if `value` does not fit.
    private func setUInt32(_ value: Int, at index: Int, on encoder: MTLComputeCommandEncoder) {
        var raw = UInt32(value)
        encoder.setBytes(&raw, length: MemoryLayout<UInt32>.size, index: index)
    }

    /// Binds `value` as a 32-bit float constant.
    private func setFloat(_ value: Float, at index: Int, on encoder: MTLComputeCommandEncoder) {
        var raw = value
        encoder.setBytes(&raw, length: MemoryLayout<Float>.size, index: index)
    }

    /// Binds `value` as a one-byte boolean constant.
    private func setBool(_ value: Bool, at index: Int, on encoder: MTLComputeCommandEncoder) {
        var raw = value
        encoder.setBytes(&raw, length: MemoryLayout<Bool>.size, index: index)
    }

    // MARK: - Subsample stage

    /// Encodes transpose -> conv2d -> conv2d -> flatten.
    ///
    /// - Returns: The flattened features (in `tempBuffer`) and the number of output
    ///   frames, `ceil(ceil(seqLen/2)/2)`.
    private func applySubsampleConv(
        melInput: MTLBuffer,
        nMels: Int,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> (MTLBuffer, Int) {
        // Input mel: [seqLen, 128] row-major.
        // Step 1: transpose to CHW [1, 128, seqLen] (lands in subsampleBuf).
        let chwInput = try transposeMelToCHW(input: melInput, nMels: nMels, seqLen: seqLen, cmdBuf: cmdBuf)

        // Step 2: layer 0 conv2d [1, 128, seqLen] -> [128, 64, seqLen/2]
        let layer0Out = try applyConv2DLayer(
            input: chwInput,
            inCh: 1,
            height: nMels,
            width: seqLen,
            convWeight: weights.subsampleConvLayer0.convWeight,
            normWeight: weights.subsampleConvLayer0.normWeight,
            outChannels: 128,
            outputBuffer: tempBuffer,
            cmdBuf: cmdBuf
        )
        let h1 = (nMels + 1) / 2
        let w1 = (seqLen + 1) / 2

        // Step 3: layer 1 conv2d [128, 64, seqLen/2] -> [32, 32, seqLen/4]
        let layer1Out = try applyConv2DLayer(
            input: layer0Out,
            inCh: 128,
            height: h1,
            width: w1,
            convWeight: weights.subsampleConvLayer1.convWeight,
            normWeight: weights.subsampleConvLayer1.normWeight,
            outChannels: 32,
            outputBuffer: subsampleBuf,
            cmdBuf: cmdBuf
        )
        let h2 = (h1 + 1) / 2
        let w2 = (w1 + 1) / 2

        // Step 4: flatten [32, 32, w2] -> [w2, 1024]
        let flatOutput = try flattenCHW(
            input: layer1Out,
            C: 32,
            H: h2,
            W: w2,
            outputBuffer: tempBuffer,
            cmdBuf: cmdBuf
        )

        return (flatOutput, w2)
    }

    /// Encodes `transpose_2d` from `[seqLen, nMels]` to `[nMels, seqLen]` into
    /// `subsampleBuf`.
    private func transposeMelToCHW(
        input: MTLBuffer,
        nMels: Int,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let kernel = "transpose_2d"
        let output = subsampleBuf

        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(output, offset: 0, index: 1)
        // Source matrix is seqLen rows by nMels columns.
        setUInt32(seqLen, at: 2, on: enc)
        setUInt32(nMels, at: 3, on: enc)

        // One thread per element of the [nMels, seqLen] result.
        let grid = MTLSize(width: nMels, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (nMels, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return output
    }

    /// Encodes one `audio_subsample_conv_2d` launch.
    ///
    /// The grid covers `outChannels x ceil(height/2) x ceil(width/2)` output positions.
    /// `input` and `outputBuffer` must be different buffers.
    private func applyConv2DLayer(
        input: MTLBuffer,
        inCh: Int,
        height: Int,
        width: Int,
        convWeight: MTLBuffer,
        normWeight: MTLBuffer,
        outChannels: Int,
        outputBuffer: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let kernel = "audio_subsample_conv_2d"
        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(convWeight, offset: 0, index: 1)
        enc.setBuffer(normWeight, offset: 0, index: 2)
        enc.setBuffer(outputBuffer, offset: 0, index: 3)

        setUInt32(inCh, at: 4, on: enc)
        setUInt32(outChannels, at: 5, on: enc)
        setUInt32(height, at: 6, on: enc)
        setUInt32(width, at: 7, on: enc)

        let outH = (height + 1) / 2
        let outW = (width + 1) / 2
        let grid = MTLSize(width: outChannels, height: outH, depth: outW)
        // NOTE(review): fixed 256-thread group; assumes the pipeline's
        // maxTotalThreadsPerThreadgroup is at least 256 — confirm for this kernel.
        let tg = MTLSize(width: 8, height: 8, depth: 4)
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return outputBuffer
    }

    /// Encodes `audio_flatten_chw`, turning `[C, H, W]` into `[W, C*H]`.
    /// `input` and `outputBuffer` must be different buffers.
    private func flattenCHW(
        input: MTLBuffer,
        C: Int,
        H: Int,
        W: Int,
        outputBuffer: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let kernel = "audio_flatten_chw"
        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(outputBuffer, offset: 0, index: 1)

        setUInt32(C, at: 2, on: enc)
        setUInt32(H, at: 3, on: enc)
        setUInt32(W, at: 4, on: enc)

        let grid = MTLSize(width: C * H, height: W, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (C * H, W))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return outputBuffer
    }

    // MARK: - Input / output projections

    /// Encodes the bias-free float32 input projection `[seqLen, 1024] -> [seqLen, 1024]`.
    ///
    /// The result is written to `subsampleBuf` because the input (the flattened
    /// subsample output) lives in `tempBuffer`.
    private func applyInputProjection(input: MTLBuffer, seqLen: Int, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        let kernel = "audio_linear_seq"
        let width = AudioTower.projectionWidth
        let output = subsampleBuf

        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weights.inputProjLinearWeight, offset: 0, index: 1)
        enc.setBuffer(nil, offset: 0, index: 2) // No bias
        enc.setBuffer(output, offset: 0, index: 3)

        setUInt32(width, at: 4, on: enc) // in features
        setUInt32(width, at: 5, on: enc) // out features
        setBool(false, at: 6, on: enc)   // has bias
        setUInt32(seqLen, at: 7, on: enc)

        let grid = MTLSize(width: width, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (width, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return output
    }

    /// Encodes the quantized output projection (with bias) into `output`.
    private func applyOutputProjection(
        input: MTLBuffer,
        seqLen: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws {
        _ = try applyQuantizedLinear(
            input: input,
            weights: weights.outputProj,
            seqLen: seqLen,
            output: output,
            bias: weights.outputProjBias,
            cmdBuf: cmdBuf
        )
    }

    // MARK: - Audio layer

    /// Encodes one audio layer and returns the buffer holding its output.
    ///
    /// The residual stream ping-pongs between `input`'s buffer ("home") and a second
    /// stream buffer ("away"). Every step writes to a buffer that none of its inputs live
    /// in, so the residual is still intact when the residual add reads it. The layer's
    /// output ends up back in the home buffer, so consecutive layers can be chained.
    ///
    /// - Precondition: `input` must not be one of the scratch buffers the sub-blocks use
    ///   (`normBuffer`, `qBuffer`, `kBuffer`, `vBuffer`, `attnOutBuffer`, `ffnBuffer`,
    ///   `tempBuffer`).
    private func applyLayer(
        input: MTLBuffer,
        weights: AudioLayerWeights,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let hiddenSize = config.hiddenSize
        let residualWeight = config.residualWeight
        let home = input
        let away: MTLBuffer = (input === layerBuffer) ? subsampleBuf : layerBuffer

        // 1. Norm pre-attn
        let preAttn = try applyRMSNorm(
            input: home,
            weight: weights.normPreAttn,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            output: normBuffer,
            cmdBuf: cmdBuf
        )

        // 2. Self-attention with relative position (result in tempBuffer)
        let attnOut = try applySelfAttention(
            input: preAttn,
            weights: weights,
            seqLen: seqLen,
            cmdBuf: cmdBuf
        )

        // 3. Residual against the un-normalised layer input, then norm post-attn
        var stream = try applyResidualAdd(
            input: home,
            add: attnOut,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            residualWeight: residualWeight,
            output: away,
            cmdBuf: cmdBuf
        )
        stream = try applyRMSNorm(
            input: stream,
            weight: weights.normPostAttn,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            output: home,
            cmdBuf: cmdBuf
        )

        // 4. Local conv1d (result in tempBuffer)
        let lconvOut = try applyLConv1D(
            input: stream,
            weights: weights,
            seqLen: seqLen,
            cmdBuf: cmdBuf
        )

        // 5. Residual
        stream = try applyResidualAdd(
            input: stream,
            add: lconvOut,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            residualWeight: residualWeight,
            output: away,
            cmdBuf: cmdBuf
        )

        // 6. Feed-forward 1 (result in normBuffer)
        let ff1Out = try applyFeedForward(
            input: stream,
            weights: weights.feedForward1,
            seqLen: seqLen,
            cmdBuf: cmdBuf
        )

        // 7. Residual
        stream = try applyResidualAdd(
            input: stream,
            add: ff1Out,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            residualWeight: residualWeight,
            output: home,
            cmdBuf: cmdBuf
        )

        // 8. Feed-forward 2 (result in normBuffer)
        let ff2Out = try applyFeedForward(
            input: stream,
            weights: weights.feedForward2,
            seqLen: seqLen,
            cmdBuf: cmdBuf
        )

        // 9. Residual + norm out
        stream = try applyResidualAdd(
            input: stream,
            add: ff2Out,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            residualWeight: residualWeight,
            output: away,
            cmdBuf: cmdBuf
        )
        return try applyRMSNorm(
            input: stream,
            weight: weights.normOut,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            output: home,
            cmdBuf: cmdBuf
        )
    }

    /// Encodes Q/K/V projections, attention and the post projection.
    ///
    /// Uses `qBuffer`, `kBuffer`, `vBuffer` and `attnOutBuffer`; the result is returned in
    /// `tempBuffer`. `input` must not be any of those buffers.
    private func applySelfAttention(
        input: MTLBuffer,
        weights: AudioLayerWeights,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        // Q, K, V projections
        let q = try applyQuantizedLinear(
            input: input,
            weights: weights.selfAttnQProj,
            seqLen: seqLen,
            output: qBuffer,
            cmdBuf: cmdBuf
        )
        let k = try applyQuantizedLinear(
            input: input,
            weights: weights.selfAttnKProj,
            seqLen: seqLen,
            output: kBuffer,
            cmdBuf: cmdBuf
        )
        let v = try applyQuantizedLinear(
            input: input,
            weights: weights.selfAttnVProj,
            seqLen: seqLen,
            output: vBuffer,
            cmdBuf: cmdBuf
        )

        // Attention with relative position and context
        let attnOut = try applyAudioAttention(
            q: q,
            k: k,
            v: v,
            relativeKProj: weights.selfAttnRelativeKProj,
            perDimScale: weights.selfAttnPerDimScale,
            seqLen: seqLen,
            numHeads: config.numAttentionHeads,
            headDim: config.headDim,
            contextLeft: config.attentionContextLeft,
            logitCap: config.attentionLogitCap,
            output: attnOutBuffer,
            cmdBuf: cmdBuf
        )

        // Post projection
        return try applyQuantizedLinear(
            input: attnOut,
            weights: weights.selfAttnPost,
            seqLen: seqLen,
            output: tempBuffer,
            cmdBuf: cmdBuf
        )
    }

    /// Encodes one `audio_attention_full` launch with one thread per
    /// `(head * headDim, position)` output element.
    private func applyAudioAttention(
        q: MTLBuffer,
        k: MTLBuffer,
        v: MTLBuffer,
        relativeKProj: MTLBuffer,
        perDimScale: MTLBuffer,
        seqLen: Int,
        numHeads: Int,
        headDim: Int,
        contextLeft: Int,
        logitCap: Float,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let kernel = "audio_attention_full"
        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(q, offset: 0, index: 0)
        enc.setBuffer(k, offset: 0, index: 1)
        enc.setBuffer(v, offset: 0, index: 2)
        enc.setBuffer(relativeKProj, offset: 0, index: 3)
        enc.setBuffer(perDimScale, offset: 0, index: 4)
        enc.setBuffer(output, offset: 0, index: 5)

        setUInt32(seqLen, at: 6, on: enc)
        setUInt32(numHeads, at: 7, on: enc)
        setUInt32(headDim, at: 8, on: enc)
        setUInt32(contextLeft, at: 9, on: enc)
        setFloat(logitCap, at: 10, on: enc)

        let rowWidth = numHeads * headDim
        let grid = MTLSize(width: rowWidth, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (rowWidth, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return output
    }

    /// Encodes the local-convolution sub-block:
    /// norm -> linear start -> SiLU -> depthwise conv1d -> linear end.
    ///
    /// Uses `normBuffer` and `ffnBuffer`; the result is returned in `tempBuffer`.
    /// `input` must not be any of those buffers.
    private func applyLConv1D(
        input: MTLBuffer,
        weights: AudioLayerWeights,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        // Width of the expanded rows between linear start and linear end.
        let expandedWidth = config.hiddenSize * 2

        // Pre-layer norm
        let normed = try applyRMSNorm(
            input: input,
            weight: weights.lconv1dPreLayerNorm,
            seqLen: seqLen,
            hiddenSize: config.hiddenSize,
            output: normBuffer,
            cmdBuf: cmdBuf
        )

        // Linear start: [seqLen, 1024] -> [seqLen, 2048]
        let linearStart = try applyQuantizedLinear(
            input: normed,
            weights: weights.lconv1dLinearStart,
            seqLen: seqLen,
            output: ffnBuffer,
            cmdBuf: cmdBuf
        )

        // SiLU activation
        let activated = try applySiLU(
            input: linearStart,
            count: seqLen * expandedWidth,
            output: tempBuffer,
            cmdBuf: cmdBuf
        )

        // Depthwise conv1d. Written to a different buffer than it reads from, since
        // each output position is launched as its own thread over a shared input.
        let convOut = try applyDepthwiseConv1D(
            input: activated,
            weight: weights.lconv1dDepthwiseConv,
            norm: weights.lconv1dConvNorm,
            seqLen: seqLen,
            channels: expandedWidth,
            kernelSize: config.convKernelSize,
            output: ffnBuffer,
            cmdBuf: cmdBuf
        )

        // Linear end: [seqLen, 2048] -> [seqLen, 1024]
        return try applyQuantizedLinear(
            input: convOut,
            weights: weights.lconv1dLinearEnd,
            seqLen: seqLen,
            output: tempBuffer,
            cmdBuf: cmdBuf
        )
    }

    /// Encodes one `audio_depthwise_conv1d` launch with one thread per
    /// `(channel, position)`. `input` and `output` must be different buffers.
    private func applyDepthwiseConv1D(
        input: MTLBuffer,
        weight: MTLBuffer,
        norm: MTLBuffer,
        seqLen: Int,
        channels: Int,
        kernelSize: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let kernel = "audio_depthwise_conv1d"
        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(norm, offset: 0, index: 2)
        enc.setBuffer(output, offset: 0, index: 3)

        setUInt32(channels, at: 4, on: enc)
        setUInt32(kernelSize, at: 5, on: enc)
        setUInt32(seqLen, at: 6, on: enc)

        let grid = MTLSize(width: channels, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (channels, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return output
    }

    /// Encodes one feed-forward sub-block:
    /// norm -> layer 1 -> SiLU -> layer 2 -> norm.
    ///
    /// Uses `ffnBuffer` and `tempBuffer`; the result is returned in `normBuffer`.
    /// `input` must not be any of those buffers.
    private func applyFeedForward(
        input: MTLBuffer,
        weights: FeedForwardWeights,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        // Pre-layer norm
        let normed = try applyRMSNorm(
            input: input,
            weight: weights.preLayerNorm,
            seqLen: seqLen,
            hiddenSize: config.hiddenSize,
            output: normBuffer,
            cmdBuf: cmdBuf
        )

        // Layer 1: [seqLen, 1024] -> [seqLen, 4096]
        let layer1 = try applyQuantizedLinear(
            input: normed,
            weights: weights.ffwLayer1,
            seqLen: seqLen,
            output: ffnBuffer,
            cmdBuf: cmdBuf
        )

        // SiLU activation over exactly the elements layer 1 produced.
        let activated = try applySiLU(
            input: layer1,
            count: seqLen * weights.ffwLayer1.outDim,
            output: tempBuffer,
            cmdBuf: cmdBuf
        )

        // Layer 2: [seqLen, 4096] -> [seqLen, 1024]. ffnBuffer is free again here.
        let layer2 = try applyQuantizedLinear(
            input: activated,
            weights: weights.ffwLayer2,
            seqLen: seqLen,
            output: ffnBuffer,
            cmdBuf: cmdBuf
        )

        // Post-layer norm
        return try applyRMSNorm(
            input: layer2,
            weight: weights.postLayerNorm,
            seqLen: seqLen,
            hiddenSize: config.hiddenSize,
            output: normBuffer,
            cmdBuf: cmdBuf
        )
    }

    // MARK: - Primitive kernels

    /// Encodes one `rms_norm_seq` launch over `seqLen` rows of `hiddenSize` elements.
    ///
    /// `input` and `output` must be different buffers: the launch uses one thread per
    /// element, so an in-place norm could read values that other threads have already
    /// replaced (NOTE(review): confirm against the kernel source).
    private func applyRMSNorm(
        input: MTLBuffer,
        weight: MTLBuffer,
        seqLen: Int,
        hiddenSize: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let kernel = "rms_norm_seq"
        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)

        setUInt32(hiddenSize, at: 3, on: enc)
        setFloat(Float(config.rmsNormEps), at: 4, on: enc)
        setUInt32(seqLen, at: 5, on: enc)

        let grid = MTLSize(width: hiddenSize, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (hiddenSize, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return output
    }

    /// Encodes one `quantized_matmul_seq` launch with one thread per
    /// `(output feature, row)`.
    ///
    /// - Parameter bias: Optional bias buffer; when `nil` the kernel's bias flag is
    ///   cleared and no buffer is bound at index 4.
    /// - Returns: `output`, for chaining.
    private func applyQuantizedLinear(
        input: MTLBuffer,
        weights: QuantizedWeights,
        seqLen: Int,
        output: MTLBuffer,
        bias: MTLBuffer? = nil,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let kernel = "quantized_matmul_seq"
        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weights.weight, offset: 0, index: 1)
        enc.setBuffer(weights.scales, offset: 0, index: 2)
        enc.setBuffer(weights.biases, offset: 0, index: 3)
        enc.setBuffer(bias, offset: 0, index: 4)
        enc.setBuffer(output, offset: 0, index: 5)

        setUInt32(weights.inDim, at: 6, on: enc)
        setUInt32(weights.outDim, at: 7, on: enc)
        setBool(bias != nil, at: 8, on: enc)
        setUInt32(seqLen, at: 9, on: enc)

        let grid = MTLSize(width: weights.outDim, height: seqLen, depth: 1)
        let tg = engine.threadgroupSize2D(pso, grid: (weights.outDim, seqLen))
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return output
    }

    /// Encodes one `silu` launch over `count` elements, reading `input` and writing
    /// `output`.
    private func applySiLU(
        input: MTLBuffer,
        count: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let kernel = "silu"
        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(output, offset: 0, index: 1)
        setUInt32(count, at: 2, on: enc)

        let grid = MTLSize(width: count, height: 1, depth: 1)
        let tg = engine.threadgroupSize1D(pso, count: count)
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return output
    }

    /// Encodes one `residual_add` launch over `seqLen * hiddenSize` elements, combining
    /// `input` and `add` (with `residualWeight` passed to the kernel) into `output`.
    private func applyResidualAdd(
        input: MTLBuffer,
        add: MTLBuffer,
        seqLen: Int,
        hiddenSize: Int,
        residualWeight: Float,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let kernel = "residual_add"
        let count = seqLen * hiddenSize

        let pso = try engine.pipeline(named: kernel)
        let enc = try makeEncoder(for: kernel, on: cmdBuf)
        enc.setComputePipelineState(pso)

        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(add, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)

        setUInt32(count, at: 3, on: enc)
        setFloat(residualWeight, at: 4, on: enc)

        let grid = MTLSize(width: count, height: 1, depth: 1)
        let tg = engine.threadgroupSize1D(pso, count: count)
        enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
        enc.endEncoding()

        return output
    }
}