Files
markbaseengine/Sources/MarkBase/Audio/AudioTower.swift
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

740 lines
24 KiB
Swift

import Metal
/// Audio encoder that runs a mel-spectrogram through a subsampling convolution
/// stack, an input projection, a stack of audio layers and an output projection.
/// Every stage is dispatched as a Metal compute kernel obtained from `engine`.
///
/// Intermediate results live in scratch buffers that are allocated once in `init`
/// and reused by every call to `forward`. Overlapping `forward` calls on the same
/// instance would therefore overwrite each other's intermediates.
public final class AudioTower {
    // MARK: - Errors

    /// Failures raised by `AudioTower` itself. Kernel lookup errors are rethrown
    /// from `engine.pipeline(named:)` unchanged.
    public enum TowerError: Error {
        /// `makeBuffer` returned `nil` for a scratch buffer of `byteCount` bytes.
        case bufferAllocationFailed(byteCount: Int)
        /// `forward` was called with a frame count outside `1...maximum`.
        case invalidSequenceLength(Int, maximum: Int)
        /// Metal could not create a command buffer or a compute encoder.
        case commandEncodingUnavailable
    }

    // MARK: - Constants

    /// Largest number of input frames the scratch buffers are sized for.
    private static let maxSeqLen = 4096
    /// Number of mel bins per input frame handed to the subsampling stack.
    private static let melBinCount = 128
    /// Per-frame column budget of the wide scratch buffers.
    private static let wideWidth = 4096
    /// Per-frame column budget needed by the subsampling stack.
    private static let subsampleWidth = 128 * 64
    /// Feature count entering and leaving the input projection.
    private static let projectionFeatures = 1024

    // MARK: - Stored properties

    public let config: AudioConfig
    public let engine: MarkBaseEngine
    public let weights: AudioWeights

    /// Destination of every RMS norm that feeds a sub-module.
    private let normBuffer: MTLBuffer
    private let qBuffer: MTLBuffer
    private let kBuffer: MTLBuffer
    private let vBuffer: MTLBuffer
    private let attnOutBuffer: MTLBuffer
    /// Wide intermediate (first linear of the conv / feed-forward sub-modules).
    private let ffnBuffer: MTLBuffer
    /// Destination of the last linear of each sub-module.
    private let tempBuffer: MTLBuffer
    /// Subsampling scratch; afterwards one of the two residual-stream buffers.
    private let subsampleBuf: MTLBuffer
    /// The other residual-stream buffer used by the audio layers.
    private let layerBuffer: MTLBuffer

    // MARK: - Lifecycle

    /// Allocates all scratch buffers on `engine.device`.
    ///
    /// - Parameters:
    ///   - config: Model hyper-parameters; `hiddenSize` determines buffer sizes.
    ///   - engine: Provides the Metal device, command queue and pipelines.
    ///   - weights: Weight buffers referenced while encoding `forward`.
    /// - Throws: `TowerError.bufferAllocationFailed` if the device cannot provide
    ///   one of the scratch buffers.
    public init(config: AudioConfig, engine: MarkBaseEngine, weights: AudioWeights) throws {
        self.config = config
        self.engine = engine
        self.weights = weights

        let device = engine.device
        let hiddenSize = config.hiddenSize

        // Each scratch buffer holds `width` 4-byte elements for every one of
        // `maxSeqLen` frames.
        func makeScratch(width: Int) throws -> MTLBuffer {
            let byteCount = width * AudioTower.maxSeqLen * MemoryLayout<Float>.stride
            guard let buffer = device.makeBuffer(length: byteCount) else {
                throw TowerError.bufferAllocationFailed(byteCount: byteCount)
            }
            return buffer
        }

        normBuffer = try makeScratch(width: hiddenSize)
        qBuffer = try makeScratch(width: hiddenSize)
        kBuffer = try makeScratch(width: hiddenSize)
        vBuffer = try makeScratch(width: hiddenSize)
        attnOutBuffer = try makeScratch(width: hiddenSize)
        ffnBuffer = try makeScratch(width: AudioTower.wideWidth)
        tempBuffer = try makeScratch(width: max(hiddenSize, AudioTower.wideWidth))
        subsampleBuf = try makeScratch(width: max(hiddenSize, AudioTower.subsampleWidth))
        layerBuffer = try makeScratch(width: max(hiddenSize, AudioTower.wideWidth))
    }

    // MARK: - Forward pass

    /// Encodes and runs the whole tower, blocking until the GPU has finished.
    ///
    /// All stages are encoded into one command buffer so that every stage is
    /// actually submitted; encoders in a command buffer execute in encoding order.
    ///
    /// - Parameters:
    ///   - inputBuffer: Mel features, expected row-major as `[seqLen, 128]`.
    ///   - seqLen: Number of input frames; must be in `1...4096`.
    ///   - outputBuffer: Receives `[ceil(ceil(seqLen / 2) / 2), outputProj.outDim]`
    ///     values. The caller is responsible for making it large enough.
    /// - Throws: `TowerError.invalidSequenceLength` for an unsupported `seqLen`,
    ///   `TowerError.commandEncodingUnavailable` if Metal cannot create encoding
    ///   objects, any error from `engine.pipeline(named:)`, or the command
    ///   buffer's own error if GPU execution failed.
    public func forward(inputBuffer: MTLBuffer, seqLen: Int, outputBuffer: MTLBuffer) throws {
        // The scratch buffers are sized for `maxSeqLen` frames; anything longer
        // would make the kernels write out of bounds, and zero frames would
        // produce an empty dispatch grid.
        guard seqLen > 0, seqLen <= AudioTower.maxSeqLen else {
            throw TowerError.invalidSequenceLength(seqLen, maximum: AudioTower.maxSeqLen)
        }
        guard let cmdBuf = engine.commandQueue.makeCommandBuffer() else {
            throw TowerError.commandEncodingUnavailable
        }

        // 1. Subsample conv: mel [seqLen, 128] -> [seqLen/4, 1024]
        let (flattened, frameCount) = try applySubsampleConv(
            melInput: inputBuffer,
            nMels: AudioTower.melBinCount,
            seqLen: seqLen,
            cmdBuf: cmdBuf
        )

        // 2. Input projection: [seqLen/4, 1024] -> [seqLen/4, 1024]
        var current = try applyInputProjection(input: flattened, seqLen: frameCount, cmdBuf: cmdBuf)

        // 3. Audio layers
        for layerWeights in weights.layers {
            current = try applyLayer(
                input: current,
                weights: layerWeights,
                seqLen: frameCount,
                cmdBuf: cmdBuf
            )
        }

        // 4. Output projection into the caller's buffer
        try applyOutputProjection(input: current, seqLen: frameCount, output: outputBuffer, cmdBuf: cmdBuf)

        cmdBuf.commit()
        cmdBuf.waitUntilCompleted()
        if let gpuError = cmdBuf.error {
            throw gpuError
        }
    }

    // MARK: - Encoding helpers

    /// Looks up the pipeline for `kernel` and opens a compute encoder bound to it.
    ///
    /// The pipeline is resolved before the encoder is created so that a failed
    /// lookup never leaves an open encoder on `cmdBuf`.
    private func makeEncoder(
        kernel: String,
        cmdBuf: MTLCommandBuffer
    ) throws -> (encoder: MTLComputeCommandEncoder, pipeline: MTLComputePipelineState) {
        let pipeline = try engine.pipeline(named: kernel)
        guard let encoder = cmdBuf.makeComputeCommandEncoder() else {
            throw TowerError.commandEncodingUnavailable
        }
        encoder.setComputePipelineState(pipeline)
        return (encoder, pipeline)
    }

    /// Binds a plain value (`UInt32`, `Float`, `Bool`) as inline kernel bytes.
    private func setScalar<T>(_ value: T, at index: Int, on encoder: MTLComputeCommandEncoder) {
        var copy = value
        encoder.setBytes(&copy, length: MemoryLayout<T>.size, index: index)
    }

    /// Dispatches a `width` x `height` grid and closes the encoder.
    private func dispatch2D(
        _ encoder: MTLComputeCommandEncoder,
        pipeline: MTLComputePipelineState,
        width: Int,
        height: Int
    ) {
        let grid = MTLSize(width: width, height: height, depth: 1)
        let threadgroup = engine.threadgroupSize2D(pipeline, grid: (width, height))
        encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
        encoder.endEncoding()
    }

    /// Dispatches a one-dimensional grid of `count` threads and closes the encoder.
    private func dispatch1D(
        _ encoder: MTLComputeCommandEncoder,
        pipeline: MTLComputePipelineState,
        count: Int
    ) {
        let grid = MTLSize(width: count, height: 1, depth: 1)
        let threadgroup = engine.threadgroupSize1D(pipeline, count: count)
        encoder.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
        encoder.endEncoding()
    }

    // MARK: - Subsampling

    /// Runs transpose -> conv layer 0 -> conv layer 1 -> flatten.
    ///
    /// Buffers alternate `subsampleBuf` -> `tempBuffer` -> `subsampleBuf` ->
    /// `tempBuffer`, so no stage reads the buffer it writes.
    ///
    /// - Returns: The flattened features (in `tempBuffer`) and the number of
    ///   frames left after both stride-2 stages.
    private func applySubsampleConv(
        melInput: MTLBuffer,
        nMels: Int,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> (MTLBuffer, Int) {
        // Step 1: [seqLen, nMels] row-major -> CHW [1, nMels, seqLen]
        let chwInput = try transposeMelToCHW(input: melInput, nMels: nMels, seqLen: seqLen, cmdBuf: cmdBuf)

        // Step 2: layer 0, 1 -> 128 channels, spatial size halved (rounded up)
        let layer0Out = try applyConv2DLayer(
            input: chwInput,
            inCh: 1,
            height: nMels,
            width: seqLen,
            convWeight: weights.subsampleConvLayer0.convWeight,
            normWeight: weights.subsampleConvLayer0.normWeight,
            outChannels: 128,
            outputBuffer: tempBuffer,
            cmdBuf: cmdBuf
        )
        let h1 = (nMels + 1) / 2
        let w1 = (seqLen + 1) / 2

        // Step 3: layer 1, 128 -> 32 channels, spatial size halved again
        let layer1Out = try applyConv2DLayer(
            input: layer0Out,
            inCh: 128,
            height: h1,
            width: w1,
            convWeight: weights.subsampleConvLayer1.convWeight,
            normWeight: weights.subsampleConvLayer1.normWeight,
            outChannels: 32,
            outputBuffer: subsampleBuf,
            cmdBuf: cmdBuf
        )
        let h2 = (h1 + 1) / 2
        let w2 = (w1 + 1) / 2

        // Step 4: [32, h2, w2] -> [w2, 32 * h2]
        let flatOutput = try flattenCHW(
            input: layer1Out,
            C: 32,
            H: h2,
            W: w2,
            outputBuffer: tempBuffer,
            cmdBuf: cmdBuf
        )
        return (flatOutput, w2)
    }

    /// Encodes `transpose_2d` from `input` ([seqLen, nMels]) into `subsampleBuf`.
    private func transposeMelToCHW(
        input: MTLBuffer,
        nMels: Int,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let output = subsampleBuf
        let (enc, pso) = try makeEncoder(kernel: "transpose_2d", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(output, offset: 0, index: 1)
        // The source matrix has `seqLen` rows and `nMels` columns.
        setScalar(UInt32(seqLen), at: 2, on: enc)
        setScalar(UInt32(nMels), at: 3, on: enc)
        // One thread per element of the transposed [nMels, seqLen] result.
        dispatch2D(enc, pipeline: pso, width: nMels, height: seqLen)
        return output
    }

    /// Encodes one `audio_subsample_conv_2d` stage writing into `outputBuffer`.
    ///
    /// The grid covers `outChannels` x `ceil(height / 2)` x `ceil(width / 2)`.
    /// NOTE(review): the fixed 8x8x4 threadgroup needs a pipeline that allows
    /// 256 threads per threadgroup - confirm for this kernel.
    private func applyConv2DLayer(
        input: MTLBuffer,
        inCh: Int,
        height: Int,
        width: Int,
        convWeight: MTLBuffer,
        normWeight: MTLBuffer,
        outChannels: Int,
        outputBuffer: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let (enc, _) = try makeEncoder(kernel: "audio_subsample_conv_2d", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(convWeight, offset: 0, index: 1)
        enc.setBuffer(normWeight, offset: 0, index: 2)
        enc.setBuffer(outputBuffer, offset: 0, index: 3)
        setScalar(UInt32(inCh), at: 4, on: enc)
        setScalar(UInt32(outChannels), at: 5, on: enc)
        setScalar(UInt32(height), at: 6, on: enc)
        setScalar(UInt32(width), at: 7, on: enc)

        let outH = (height + 1) / 2
        let outW = (width + 1) / 2
        let grid = MTLSize(width: outChannels, height: outH, depth: outW)
        let threadgroup = MTLSize(width: 8, height: 8, depth: 4)
        enc.dispatchThreads(grid, threadsPerThreadgroup: threadgroup)
        enc.endEncoding()
        return outputBuffer
    }

    /// Encodes `audio_flatten_chw`, turning `[C, H, W]` into `W` rows of `C * H`.
    private func flattenCHW(
        input: MTLBuffer,
        C: Int,
        H: Int,
        W: Int,
        outputBuffer: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let (enc, pso) = try makeEncoder(kernel: "audio_flatten_chw", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(outputBuffer, offset: 0, index: 1)
        setScalar(UInt32(C), at: 2, on: enc)
        setScalar(UInt32(H), at: 3, on: enc)
        setScalar(UInt32(W), at: 4, on: enc)
        dispatch2D(enc, pipeline: pso, width: C * H, height: W)
        return outputBuffer
    }

    /// Encodes the bias-free `audio_linear_seq` input projection.
    ///
    /// Reads `input` (the flattened features in `tempBuffer`) and writes
    /// `[seqLen, 1024]` into `subsampleBuf`, which is free again once the
    /// flatten stage has consumed it.
    private func applyInputProjection(input: MTLBuffer, seqLen: Int, cmdBuf: MTLCommandBuffer) throws -> MTLBuffer {
        let output = subsampleBuf
        let features = AudioTower.projectionFeatures
        let (enc, pso) = try makeEncoder(kernel: "audio_linear_seq", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weights.inputProjLinearWeight, offset: 0, index: 1)
        enc.setBuffer(nil, offset: 0, index: 2) // No bias
        enc.setBuffer(output, offset: 0, index: 3)
        setScalar(UInt32(features), at: 4, on: enc) // input features
        setScalar(UInt32(features), at: 5, on: enc) // output features
        setScalar(false, at: 6, on: enc)            // hasBias
        setScalar(UInt32(seqLen), at: 7, on: enc)
        dispatch2D(enc, pipeline: pso, width: features, height: seqLen)
        return output
    }

    // MARK: - Audio layer

    /// Encodes one audio layer: attention, local conv and two feed-forward
    /// sub-modules, each wrapped in a residual connection.
    ///
    /// The residual stream alternates between `input` ("stream") and the other
    /// of `layerBuffer` / `subsampleBuf` ("spare"), so that no kernel reads the
    /// buffer it writes and every residual add still sees the unmodified value
    /// it is supposed to add to. The result ends up back in `input`.
    ///
    /// - Parameter input: Must be `layerBuffer` or `subsampleBuf`; its contents
    ///   are overwritten with the layer output.
    /// - Returns: The buffer holding the layer output (same object as `input`).
    private func applyLayer(
        input: MTLBuffer,
        weights: AudioLayerWeights,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        assert(input === layerBuffer || input === subsampleBuf,
               "applyLayer expects the residual stream in layerBuffer or subsampleBuf")
        let hiddenSize = config.hiddenSize
        let residualWeight = config.residualWeight
        let stream = input
        let spare: MTLBuffer = (input === layerBuffer) ? subsampleBuf : layerBuffer

        // 1. Norm pre-attn: stream -> normBuffer (stream keeps the residual input)
        let preAttn = try applyRMSNorm(
            input: stream,
            weight: weights.normPreAttn,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            output: normBuffer,
            cmdBuf: cmdBuf
        )

        // 2. Self-attention with relative position: normBuffer -> tempBuffer
        let attnOut = try applySelfAttention(
            input: preAttn,
            weights: weights,
            seqLen: seqLen,
            cmdBuf: cmdBuf
        )

        // 3. Residual (-> spare) + norm post-attn (-> stream)
        let attnResidual = try applyResidualAdd(
            input: stream,
            add: attnOut,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            residualWeight: residualWeight,
            output: spare,
            cmdBuf: cmdBuf
        )
        var current = try applyRMSNorm(
            input: attnResidual,
            weight: weights.normPostAttn,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            output: stream,
            cmdBuf: cmdBuf
        )

        // 4. Local conv1d: stream -> tempBuffer, spare is free scratch here
        let lconvOut = try applyLConv1D(
            input: current,
            weights: weights,
            seqLen: seqLen,
            scratch: spare,
            cmdBuf: cmdBuf
        )

        // 5. Residual: stream + tempBuffer -> spare
        current = try applyResidualAdd(
            input: current,
            add: lconvOut,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            residualWeight: residualWeight,
            output: spare,
            cmdBuf: cmdBuf
        )

        // 6. Feed-forward 1: spare -> normBuffer, stream is free scratch here
        let ff1Out = try applyFeedForward(
            input: current,
            weights: weights.feedForward1,
            seqLen: seqLen,
            scratch: stream,
            cmdBuf: cmdBuf
        )

        // 7. Residual: spare + normBuffer -> stream
        current = try applyResidualAdd(
            input: current,
            add: ff1Out,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            residualWeight: residualWeight,
            output: stream,
            cmdBuf: cmdBuf
        )

        // 8. Feed-forward 2: stream -> normBuffer, spare is free scratch here
        let ff2Out = try applyFeedForward(
            input: current,
            weights: weights.feedForward2,
            seqLen: seqLen,
            scratch: spare,
            cmdBuf: cmdBuf
        )

        // 9. Residual (-> spare) + norm out (-> stream)
        current = try applyResidualAdd(
            input: current,
            add: ff2Out,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            residualWeight: residualWeight,
            output: spare,
            cmdBuf: cmdBuf
        )
        return try applyRMSNorm(
            input: current,
            weight: weights.normOut,
            seqLen: seqLen,
            hiddenSize: hiddenSize,
            output: stream,
            cmdBuf: cmdBuf
        )
    }

    // MARK: - Attention

    /// Encodes Q/K/V projections, the attention kernel and the post projection.
    ///
    /// Uses `qBuffer`, `kBuffer`, `vBuffer` and `attnOutBuffer` as intermediates.
    /// - Returns: `tempBuffer`, holding the post-projection result.
    private func applySelfAttention(
        input: MTLBuffer,
        weights: AudioLayerWeights,
        seqLen: Int,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        // Q, K, V projections
        let q = try applyQuantizedLinear(
            input: input,
            weights: weights.selfAttnQProj,
            seqLen: seqLen,
            output: qBuffer,
            cmdBuf: cmdBuf
        )
        let k = try applyQuantizedLinear(
            input: input,
            weights: weights.selfAttnKProj,
            seqLen: seqLen,
            output: kBuffer,
            cmdBuf: cmdBuf
        )
        let v = try applyQuantizedLinear(
            input: input,
            weights: weights.selfAttnVProj,
            seqLen: seqLen,
            output: vBuffer,
            cmdBuf: cmdBuf
        )

        // Attention with relative position and context
        let attended = try applyAudioAttention(
            q: q,
            k: k,
            v: v,
            relativeKProj: weights.selfAttnRelativeKProj,
            perDimScale: weights.selfAttnPerDimScale,
            seqLen: seqLen,
            numHeads: config.numAttentionHeads,
            headDim: config.headDim,
            contextLeft: config.attentionContextLeft,
            logitCap: config.attentionLogitCap,
            output: attnOutBuffer,
            cmdBuf: cmdBuf
        )

        // Post projection
        return try applyQuantizedLinear(
            input: attended,
            weights: weights.selfAttnPost,
            seqLen: seqLen,
            output: tempBuffer,
            cmdBuf: cmdBuf
        )
    }

    /// Encodes `audio_attention_full` over a `numHeads * headDim` x `seqLen` grid.
    private func applyAudioAttention(
        q: MTLBuffer,
        k: MTLBuffer,
        v: MTLBuffer,
        relativeKProj: MTLBuffer,
        perDimScale: MTLBuffer,
        seqLen: Int,
        numHeads: Int,
        headDim: Int,
        contextLeft: Int,
        logitCap: Float,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let (enc, pso) = try makeEncoder(kernel: "audio_attention_full", cmdBuf: cmdBuf)
        enc.setBuffer(q, offset: 0, index: 0)
        enc.setBuffer(k, offset: 0, index: 1)
        enc.setBuffer(v, offset: 0, index: 2)
        enc.setBuffer(relativeKProj, offset: 0, index: 3)
        enc.setBuffer(perDimScale, offset: 0, index: 4)
        enc.setBuffer(output, offset: 0, index: 5)
        setScalar(UInt32(seqLen), at: 6, on: enc)
        setScalar(UInt32(numHeads), at: 7, on: enc)
        setScalar(UInt32(headDim), at: 8, on: enc)
        setScalar(UInt32(contextLeft), at: 9, on: enc)
        setScalar(logitCap, at: 10, on: enc)
        dispatch2D(enc, pipeline: pso, width: numHeads * headDim, height: seqLen)
        return output
    }

    // MARK: - Local convolution

    /// Encodes pre-norm -> linear start -> SiLU -> depthwise conv -> linear end.
    ///
    /// Buffer chain: `input` -> `normBuffer` -> `ffnBuffer` -> `scratch` ->
    /// `ffnBuffer` -> `tempBuffer`.
    ///
    /// - Parameter scratch: A buffer, distinct from `input`, with room for
    ///   `seqLen * hiddenSize * 2` elements; its contents are overwritten.
    /// - Returns: `tempBuffer`, holding the linear-end result.
    private func applyLConv1D(
        input: MTLBuffer,
        weights: AudioLayerWeights,
        seqLen: Int,
        scratch: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let channels = config.hiddenSize * 2

        // Pre-layer norm
        let normed = try applyRMSNorm(
            input: input,
            weight: weights.lconv1dPreLayerNorm,
            seqLen: seqLen,
            hiddenSize: config.hiddenSize,
            output: normBuffer,
            cmdBuf: cmdBuf
        )
        // Linear start: [seqLen, hiddenSize] -> [seqLen, 2 * hiddenSize]
        let expanded = try applyQuantizedLinear(
            input: normed,
            weights: weights.lconv1dLinearStart,
            seqLen: seqLen,
            output: ffnBuffer,
            cmdBuf: cmdBuf
        )
        // SiLU activation
        let activated = try applySiLU(
            input: expanded,
            count: seqLen * channels,
            output: scratch,
            cmdBuf: cmdBuf
        )
        // Depthwise conv1d; ffnBuffer is free again once SiLU has read it
        let convolved = try applyDepthwiseConv1D(
            input: activated,
            weight: weights.lconv1dDepthwiseConv,
            norm: weights.lconv1dConvNorm,
            seqLen: seqLen,
            channels: channels,
            kernelSize: config.convKernelSize,
            output: ffnBuffer,
            cmdBuf: cmdBuf
        )
        // Linear end: [seqLen, 2 * hiddenSize] -> [seqLen, hiddenSize]
        return try applyQuantizedLinear(
            input: convolved,
            weights: weights.lconv1dLinearEnd,
            seqLen: seqLen,
            output: tempBuffer,
            cmdBuf: cmdBuf
        )
    }

    /// Encodes `audio_depthwise_conv1d` over a `channels` x `seqLen` grid.
    ///
    /// `output` must be a different buffer from `input`: neighbouring frames are
    /// presumably read by each thread, so writing in place would race.
    private func applyDepthwiseConv1D(
        input: MTLBuffer,
        weight: MTLBuffer,
        norm: MTLBuffer,
        seqLen: Int,
        channels: Int,
        kernelSize: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let (enc, pso) = try makeEncoder(kernel: "audio_depthwise_conv1d", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(norm, offset: 0, index: 2)
        enc.setBuffer(output, offset: 0, index: 3)
        setScalar(UInt32(channels), at: 4, on: enc)
        setScalar(UInt32(kernelSize), at: 5, on: enc)
        setScalar(UInt32(seqLen), at: 6, on: enc)
        dispatch2D(enc, pipeline: pso, width: channels, height: seqLen)
        return output
    }

    // MARK: - Feed-forward

    /// Encodes pre-norm -> layer 1 -> SiLU -> layer 2 -> post-norm.
    ///
    /// Buffer chain: `input` -> `normBuffer` -> `ffnBuffer` -> `scratch` ->
    /// `tempBuffer` -> `normBuffer`.
    ///
    /// - Parameter scratch: A buffer, distinct from `input`, with room for
    ///   `seqLen * ffwLayer1.outDim` elements; its contents are overwritten.
    /// - Returns: `normBuffer`, holding the post-norm result.
    private func applyFeedForward(
        input: MTLBuffer,
        weights: FeedForwardWeights,
        seqLen: Int,
        scratch: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        // Pre-layer norm
        let normed = try applyRMSNorm(
            input: input,
            weight: weights.preLayerNorm,
            seqLen: seqLen,
            hiddenSize: config.hiddenSize,
            output: normBuffer,
            cmdBuf: cmdBuf
        )
        // Layer 1: [seqLen, hiddenSize] -> [seqLen, ffwLayer1.outDim]
        let widened = try applyQuantizedLinear(
            input: normed,
            weights: weights.ffwLayer1,
            seqLen: seqLen,
            output: ffnBuffer,
            cmdBuf: cmdBuf
        )
        // SiLU activation over everything layer 1 produced
        let activated = try applySiLU(
            input: widened,
            count: seqLen * weights.ffwLayer1.outDim,
            output: scratch,
            cmdBuf: cmdBuf
        )
        // Layer 2: back down to the hidden width
        let narrowed = try applyQuantizedLinear(
            input: activated,
            weights: weights.ffwLayer2,
            seqLen: seqLen,
            output: tempBuffer,
            cmdBuf: cmdBuf
        )
        // Post-layer norm; normBuffer is free once layer 1 has read it
        return try applyRMSNorm(
            input: narrowed,
            weight: weights.postLayerNorm,
            seqLen: seqLen,
            hiddenSize: config.hiddenSize,
            output: normBuffer,
            cmdBuf: cmdBuf
        )
    }

    // MARK: - Primitive kernels

    /// Encodes `rms_norm_seq` over a `hiddenSize` x `seqLen` grid.
    ///
    /// `output` must be a different buffer from `input`, so that no thread can
    /// overwrite a row element another thread still has to read.
    private func applyRMSNorm(
        input: MTLBuffer,
        weight: MTLBuffer,
        seqLen: Int,
        hiddenSize: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let (enc, pso) = try makeEncoder(kernel: "rms_norm_seq", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weight, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)
        setScalar(UInt32(hiddenSize), at: 3, on: enc)
        // Bound as exactly four bytes, matching the original `length: 4`.
        setScalar(Float(config.rmsNormEps), at: 4, on: enc)
        setScalar(UInt32(seqLen), at: 5, on: enc)
        dispatch2D(enc, pipeline: pso, width: hiddenSize, height: seqLen)
        return output
    }

    /// Encodes `quantized_matmul_seq`, writing `[seqLen, weights.outDim]` into
    /// `output`. `bias` is optional; its presence is passed as a one-byte flag.
    @discardableResult
    private func applyQuantizedLinear(
        input: MTLBuffer,
        weights: QuantizedWeights,
        seqLen: Int,
        output: MTLBuffer,
        bias: MTLBuffer? = nil,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let (enc, pso) = try makeEncoder(kernel: "quantized_matmul_seq", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(weights.weight, offset: 0, index: 1)
        enc.setBuffer(weights.scales, offset: 0, index: 2)
        enc.setBuffer(weights.biases, offset: 0, index: 3)
        enc.setBuffer(bias, offset: 0, index: 4)
        enc.setBuffer(output, offset: 0, index: 5)
        setScalar(UInt32(weights.inDim), at: 6, on: enc)
        setScalar(UInt32(weights.outDim), at: 7, on: enc)
        setScalar(bias != nil, at: 8, on: enc)
        setScalar(UInt32(seqLen), at: 9, on: enc)
        dispatch2D(enc, pipeline: pso, width: weights.outDim, height: seqLen)
        return output
    }

    /// Encodes `silu` over `count` elements from `input` into `output`.
    private func applySiLU(
        input: MTLBuffer,
        count: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let (enc, pso) = try makeEncoder(kernel: "silu", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(output, offset: 0, index: 1)
        setScalar(UInt32(count), at: 2, on: enc)
        dispatch1D(enc, pipeline: pso, count: count)
        return output
    }

    /// Encodes `residual_add` over `seqLen * hiddenSize` elements, combining
    /// `input` and `add` (scaled by `residualWeight` inside the kernel) into
    /// `output`.
    private func applyResidualAdd(
        input: MTLBuffer,
        add: MTLBuffer,
        seqLen: Int,
        hiddenSize: Int,
        residualWeight: Float,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws -> MTLBuffer {
        let count = seqLen * hiddenSize
        let (enc, pso) = try makeEncoder(kernel: "residual_add", cmdBuf: cmdBuf)
        enc.setBuffer(input, offset: 0, index: 0)
        enc.setBuffer(add, offset: 0, index: 1)
        enc.setBuffer(output, offset: 0, index: 2)
        setScalar(UInt32(count), at: 3, on: enc)
        setScalar(residualWeight, at: 4, on: enc)
        dispatch1D(enc, pipeline: pso, count: count)
        return output
    }

    /// Encodes the biased output projection straight into the caller's buffer.
    private func applyOutputProjection(
        input: MTLBuffer,
        seqLen: Int,
        output: MTLBuffer,
        cmdBuf: MTLCommandBuffer
    ) throws {
        try applyQuantizedLinear(
            input: input,
            weights: weights.outputProj,
            seqLen: seqLen,
            output: output,
            bias: weights.outputProjBias,
            cmdBuf: cmdBuf
        )
    }
}