Files
markbaseengine/Sources/MarkBase/Metal/MetalKernels.metal
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

2053 lines
77 KiB
Metal
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include <metal_stdlib>
using namespace metal;
// ═══════════════════════════════════════════════
// E4B Inference Kernels
// ═══════════════════════════════════════════════
// ── Constants ──────────────────────────────────
constant uint GROUP_SIZE = 64; // MLX quantization group size (E4B 4-bit)
// ── 1. RMSNorm ─────────────────────────────────
// y[i] = x[i] * rsqrt(mean(x^2) + eps) * w[i]
// Every thread re-accumulates the sum of squares over the whole of x[0..N)
// and then writes a single output element.
// NOTE: NOT safe for in-place use when dispatched with multiple threadgroups,
// because a thread may read elements another thread has already overwritten.
// The Swift layer always passes separate input/output buffers.
kernel void rms_norm(
    device const float *x   [[buffer(0)]], // [N] input
    device const float *w   [[buffer(1)]], // [N] weight (can be null)
    device float       *y   [[buffer(2)]], // [N] output
    constant uint      &N   [[buffer(3)]],
    constant float     &eps [[buffer(4)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= N) return;

    // Sum of squares across the full vector (identical in every thread).
    float sumSquares = 0.0f;
    for (uint k = 0; k < N; ++k) {
        sumSquares += x[k] * x[k];
    }

    const float invRms = rsqrt(sumSquares / float(N) + eps);
    const float normalized = x[id] * invRms;

    // A null weight pointer means "no learned scale".
    if (w) {
        y[id] = normalized * w[id];
    } else {
        y[id] = normalized;
    }
}
// ── Sampling Kernels ────────────────────────────
// Softmax: probs[i] = exp(logits[i] - max) / sum_j exp(logits[j] - max)
// Each thread scans the whole logits vector twice (max, then sum of
// exponentials) and writes one probability.
kernel void softmax(
    device const float *logits [[buffer(0)]], // [N]
    device float       *probs  [[buffer(1)]], // [N]
    constant uint      &N      [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= N) return;

    // Largest logit; subtracting it keeps exp() from overflowing.
    float peak = -INFINITY;
    for (uint j = 0; j < N; ++j) {
        peak = max(peak, logits[j]);
    }

    // Normalisation constant.
    float denom = 0.0f;
    for (uint j = 0; j < N; ++j) {
        denom += exp(logits[j] - peak);
    }

    const float numer = exp(logits[id] - peak);
    probs[id] = numer / denom;
}
// Temperature scaling: logits[i] /= temperature
// One thread per element; the buffer is modified in place.
kernel void temperature_scale(
    device float   *logits      [[buffer(0)]], // [N] in-place
    constant uint  &N           [[buffer(1)]],
    constant float &temperature [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < N) {
        const float scaled = logits[id] / temperature;
        logits[id] = scaled;
    }
}
// Argmax: find index of maximum value
// One thread per logit. The running maximum lives in `bestVal` and its index in
// `bestIdx`; both are read before being written, so the host presumably
// initialises them before dispatch (e.g. bestVal = -INFINITY) — TODO confirm.
//
// The two words cannot be updated as a single atomic unit, so the index is
// published with its own compare-and-swap loop that compares against
// logits[current index]. A thread that merely lost the value race therefore
// cannot overwrite the index of a larger element, which a plain store after the
// value CAS would allow.
// Only threads that actually raised `bestVal` compete for `bestIdx`; if no logit
// exceeds the initial `bestVal`, `bestIdx` is left untouched.
kernel void argmax(
    device const float  *logits  [[buffer(0)]], // [N]
    device atomic_uint  *bestIdx [[buffer(1)]], // single atomic uint
    device atomic_float *bestVal [[buffer(2)]], // single atomic float
    constant uint       &N       [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= N) return;
    const float val = logits[id];

    // Step 1: atomic max on bestVal via compare-and-swap.
    // On failure `seen` is refreshed with the current value and we retry
    // only while our value is still strictly larger.
    bool raised = false;
    float seen = atomic_load_explicit(bestVal, memory_order_relaxed);
    while (val > seen) {
        if (atomic_compare_exchange_weak_explicit(
                bestVal, &seen, val,
                memory_order_relaxed, memory_order_relaxed)) {
            raised = true;
            break;
        }
    }
    if (!raised) return;

    // Step 2: install our index unless the currently published index already
    // refers to an element that is at least as large. `!(a >= b)` is used so a
    // NaN at the published index is also replaced. An out-of-range published
    // index (>= N) is always replaced and never dereferenced.
    uint cur = atomic_load_explicit(bestIdx, memory_order_relaxed);
    while (cur >= N || !(logits[cur] >= val)) {
        if (atomic_compare_exchange_weak_explicit(
                bestIdx, &cur, id,
                memory_order_relaxed, memory_order_relaxed)) {
            break;
        }
        // CAS failed: `cur` now holds the latest published index; re-check.
    }
}
// Top-k mask (placeholder): intended to set logits outside the top-k to -inf.
// The kernel body below performs NO masking: after the bounds check it returns
// without reading `k` or writing `logits`. It only reserves the entry point.
kernel void top_k_mask(
device float *logits [[buffer(0)]], // [N] in-place
constant uint &N [[buffer(1)]],
constant uint &k [[buffer(2)]],
uint id [[thread_position_in_grid]]
) {
if (id >= N) return;
// Intentionally empty.
// Author notes: several GPU strategies were considered (an O(N*k)
// bubble-sort-style selection, a per-thread threshold, a bucket-based
// threshold search), but top-k selection is done on the CPU in the Swift
// layer instead. A proper GPU sort would be needed for large vocabularies.
}
// ── 1b. Grouped RMSNorm (per-head norm) ──────────
// Groups are contiguous blocks of `groupSize` elements.
// Each group computes its own RMS independently.
// weight buffer layout: [groupSize], replicated across groups (same weight for each head).
// If N is not a multiple of groupSize, the last group is shorter and its mean
// is taken over the elements it actually contains.
// A groupSize of 0 is rejected (no output is written) instead of dividing by zero.
kernel void rms_norm_grouped(
    device const float *x         [[buffer(0)]], // [N]
    device const float *w         [[buffer(1)]], // [groupSize] weight replicated across groups (can be null)
    device float       *y         [[buffer(2)]], // [N]
    constant uint      &N         [[buffer(3)]],
    constant uint      &groupSize [[buffer(4)]],
    constant float     &eps       [[buffer(5)]],
    uint id [[thread_position_in_grid]]
) {
    if (groupSize == 0 || id >= N) return;

    const uint lane  = id % groupSize;            // weight index within the group
    const uint first = id - lane;                 // first element of this group
    const uint last  = min(first + groupSize, N); // one past the last element

    float sumSquares = 0.0f;
    for (uint i = first; i < last; ++i) {
        sumSquares += x[i] * x[i];
    }

    // Divide by the number of elements summed; equals groupSize for full groups.
    const float invRms = rsqrt(sumSquares / float(last - first) + eps);
    y[id] = (w ? x[id] * invRms * w[lane] : x[id] * invRms);
}
// ── 2. GELU Approximation ──────────────────────
// gelu(x) ≈ x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// For |x| > 10 the tanh term is skipped and the result is max(x, 0), which
// avoids forming x^3 for large magnitudes.
kernel void gelu_approx(
    device const float *x [[buffer(0)]],
    device float       *y [[buffer(1)]],
    constant uint      &N [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= N) return;

    const float v = x[id];
    const float sqrtTwoOverPi = M_SQRT2_F * M_2_SQRTPI_F * 0.5f; // sqrt(2/pi)
    const float magnitude = v > 0 ? v : -v;

    float activated;
    if (magnitude > 10.0f) {
        // Saturated region: GELU ≈ ReLU.
        activated = v > 0 ? v : 0.0f;
    } else {
        const float cube = v * v * v;
        activated = 0.5f * v * (1.0f + tanh(sqrtTwoOverPi * (v + 0.044715f * cube)));
    }
    y[id] = activated;
}
// ── 3. Quantized MatMul (MLX U32-packed format) ─
// out[outDim] = dequant(weight[outDim, inDim/8]) @ x[inDim]
// scales/biases: [outDim, inDim/GROUP_SIZE]
// One thread per output row. Each uint32 of `w` carries eight 4-bit values,
// lowest nibble first; a value q dequantizes to q * scale + bias using the
// scale/bias of its group.
kernel void quantized_matmul(
    device const float *x         [[buffer(0)]],
    device const uint  *w         [[buffer(1)]],
    device const float *s         [[buffer(2)]],
    device const float *b         [[buffer(3)]],
    device float       *out       [[buffer(4)]],
    constant uint      &inDim     [[buffer(5)]],
    constant uint      &outDim    [[buffer(6)]],
    constant uint      &groupSize [[buffer(7)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    const uint groupsPerRow  = inDim / groupSize;
    const uint wordsPerRow   = inDim / 8;      // 8 × 4-bit per U32
    const uint wordsPerGroup = groupSize / 8;
    const uint wordBase      = id * wordsPerRow;
    const uint groupBase     = id * groupsPerRow;

    float acc = 0.0f;
    for (uint grp = 0; grp < groupsPerRow; ++grp) {
        const float scale = s[groupBase + grp];
        const float bias  = b[groupBase + grp];
        const uint  xBase = grp * groupSize;

        for (uint lane = 0; lane < groupSize; ++lane) {
            const uint word   = w[wordBase + (grp * wordsPerGroup + lane / 8)];
            const uint nibble = (word >> ((lane % 8) * 4)) & 0xF;
            const float weight = float(nibble) * scale + bias;
            acc += weight * x[xBase + lane];
        }
    }
    out[id] = acc;
}
// ── 3b. Quantized MatMul with fused GELU ───────
// out = gelu(quantized_matmul(x, w, s, b))
// One thread per output row: 4-bit dequantized dot product with x, followed by
// the tanh GELU approximation (max(v, 0) shortcut for |v| > 10).
kernel void quantized_matmul_gelu(
    device const float *x         [[buffer(0)]],
    device const uint  *w         [[buffer(1)]],
    device const float *s         [[buffer(2)]],
    device const float *b         [[buffer(3)]],
    device float       *out       [[buffer(4)]],
    constant uint      &inDim     [[buffer(5)]],
    constant uint      &outDim    [[buffer(6)]],
    constant uint      &groupSize [[buffer(7)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    const uint nGroups = inDim / groupSize;
    const uint rowWords = inDim / 8; // eight 4-bit values per uint32

    // --- dequantized dot product ---
    float dot = 0.0f;
    uint grp = 0;
    while (grp < nGroups) {
        const float groupScale = s[id * nGroups + grp];
        const float groupBias  = b[id * nGroups + grp];
        uint lane = 0;
        while (lane < groupSize) {
            const uint wordIdx = grp * (groupSize / 8) + lane / 8;
            const uint bitPos  = (lane % 8) * 4;
            const uint q = (w[id * rowWords + wordIdx] >> bitPos) & 0xF;
            const float dq = float(q) * groupScale + groupBias;
            dot += dq * x[grp * groupSize + lane];
            ++lane;
        }
        ++grp;
    }

    // --- GELU (tanh approximation) ---
    const float v = dot;
    const float sqrtTwoOverPi = M_SQRT2_F * M_2_SQRTPI_F * 0.5f;
    const float magnitude = v > 0 ? v : -v;
    if (magnitude > 10.0f) {
        out[id] = v > 0 ? v : 0.0f;
        return;
    }
    const float cube = v * v * v;
    out[id] = 0.5f * v * (1.0f + tanh(sqrtTwoOverPi * (v + 0.044715f * cube)));
}
// ── 4. Quantized MatMul + Mul (for SwiGLU-style) ─
// Fused gate/up projection for 4-bit weights (8 values per uint32).
// For every output row `id` (one thread per row) the kernel computes
//   gate = gelu_tanh(clamp(dequant(w_gate[id]) · x, -100, 100))
//   up   =           clamp(dequant(w_up[id])   · x, -100, 100)
//   out[id] = clamp(gate * up, -10, 10), replaced by 0 if NaN/Inf
// i.e. `out` receives outDim elements holding the already-multiplied product.
// It does NOT use a [gate | up] layout of 2*outDim elements, and no further
// element-wise multiply is required from the caller.
// NOTE(review): the ±100 / ±10 clamps are kept exactly as found; the original
// code marks the product output as "for testing" — confirm they are intended.

// Dot product of x[0..inDim) with row `row` of a 4-bit group-quantized matrix.
static inline float gate_up_q4_row_dot(
    device const float *x,
    device const uint  *w,
    device const float *s,
    device const float *b,
    uint row,
    uint inDim,
    uint groupSize
) {
    const uint numGroups    = inDim / groupSize;
    const uint packedPerOut = inDim / 8; // 8 × 4-bit per U32
    float acc = 0.0f;
    for (uint g = 0; g < numGroups; g++) {
        const float scale = s[row * numGroups + g];
        const float bias  = b[row * numGroups + g];
        for (uint j = 0; j < groupSize; j++) {
            const uint packedIdx = g * (groupSize / 8) + j / 8;
            const uint shift     = (j % 8) * 4;
            const uint qval      = (w[row * packedPerOut + packedIdx] >> shift) & 0xF;
            acc += (float(qval) * scale + bias) * x[g * groupSize + j];
        }
    }
    return acc;
}

// Tanh-approximated GELU with the max(v, 0) shortcut for |v| > 10.
static inline float gate_up_q4_gelu_tanh(float v) {
    const float c = M_SQRT2_F * M_2_SQRTPI_F * 0.5f; // sqrt(2/pi)
    const float absv = v > 0 ? v : -v;
    if (absv > 10.0f) {
        return v > 0 ? v : 0.0f;
    }
    const float v3 = v * v * v;
    return 0.5f * v * (1.0f + tanh(c * (v + 0.044715f * v3)));
}

kernel void quantized_matmul_gate_up(
    device const float *x [[buffer(0)]],
    // gate projection
    device const uint  *w_gate [[buffer(1)]],
    device const float *s_gate [[buffer(2)]],
    device const float *b_gate [[buffer(3)]],
    // up projection
    device const uint  *w_up [[buffer(4)]],
    device const float *s_up [[buffer(5)]],
    device const float *b_up [[buffer(6)]],
    device float       *out  [[buffer(7)]],
    constant uint &inDim     [[buffer(8)]],
    constant uint &outDim    [[buffer(9)]],
    constant uint &groupSize [[buffer(10)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    // Gate projection, clamped, then GELU.
    // Explicit comparisons (not clamp()) so a NaN passes through unchanged and
    // is caught by the final NaN check, exactly as before.
    float gateSum = gate_up_q4_row_dot(x, w_gate, s_gate, b_gate, id, inDim, groupSize);
    if (gateSum > 100.0f) gateSum = 100.0f;
    if (gateSum < -100.0f) gateSum = -100.0f;
    const float gate = gate_up_q4_gelu_tanh(gateSum);

    // Up projection (no activation), clamped.
    float upSum = gate_up_q4_row_dot(x, w_up, s_up, b_up, id, inDim, groupSize);
    if (upSum > 100.0f) upSum = 100.0f;
    if (upSum < -100.0f) upSum = -100.0f;

    // Element-wise product, clamped; non-finite results become 0.
    float product = gate * upSum;
    if (product > 10.0f) product = 10.0f;
    if (product < -10.0f) product = -10.0f;
    if (isnan(product) || isinf(product)) product = 0.0f;

    out[id] = product;
}
// ── 8-bit Fused Gate+Up Matmul ───────────────
// Same as quantized_matmul_gate_up but for 8-bit weights (4 values per uint32, mask 0xFF)
// One thread per output row; writes the clamped product gelu(gate) * up.

// Dot product of x with row `row` of an 8-bit group-quantized matrix
// (four values per uint32, lowest byte first).
static inline float gate_up_q8_row_dot(
    device const float *x,
    device const uint  *w,
    device const float *s,
    device const float *b,
    uint row,
    uint inDim,
    uint groupSize
) {
    const uint nGroups  = inDim / groupSize;
    const uint rowWords = inDim / 4;
    float total = 0.0f;
    for (uint grp = 0; grp < nGroups; ++grp) {
        const float groupScale = s[row * nGroups + grp];
        const float groupBias  = b[row * nGroups + grp];
        for (uint lane = 0; lane < groupSize; ++lane) {
            const uint wordIdx = grp * (groupSize / 4) + lane / 4;
            const uint bitPos  = (lane % 4) * 8;
            const uint q = (w[row * rowWords + wordIdx] >> bitPos) & 0xFF;
            total += (float(q) * groupScale + groupBias) * x[grp * groupSize + lane];
        }
    }
    return total;
}

kernel void quantized_matmul_gate_up_8bit(
    device const float *x      [[buffer(0)]],
    device const uint  *w_gate [[buffer(1)]],
    device const float *s_gate [[buffer(2)]],
    device const float *b_gate [[buffer(3)]],
    device const uint  *w_up   [[buffer(4)]],
    device const float *s_up   [[buffer(5)]],
    device const float *b_up   [[buffer(6)]],
    device float       *out    [[buffer(7)]],
    constant uint &inDim       [[buffer(8)]],
    constant uint &outDim      [[buffer(9)]],
    constant uint &groupSize   [[buffer(10)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    // Gate branch: projection, clamp to [-100, 100], tanh GELU.
    float gatePre = gate_up_q8_row_dot(x, w_gate, s_gate, b_gate, id, inDim, groupSize);
    if (gatePre > 100.0f) gatePre = 100.0f;
    if (gatePre < -100.0f) gatePre = -100.0f;

    const float sqrtTwoOverPi = M_SQRT2_F * M_2_SQRTPI_F * 0.5f;
    const float magnitude = gatePre > 0 ? gatePre : -gatePre;
    float gate;
    if (magnitude > 10.0f) {
        gate = gatePre > 0 ? gatePre : 0.0f;
    } else {
        const float cube = gatePre * gatePre * gatePre;
        gate = 0.5f * gatePre * (1.0f + tanh(sqrtTwoOverPi * (gatePre + 0.044715f * cube)));
    }

    // Up branch: projection, clamp to [-100, 100], no activation.
    float up = gate_up_q8_row_dot(x, w_up, s_up, b_up, id, inDim, groupSize);
    if (up > 100.0f) up = 100.0f;
    if (up < -100.0f) up = -100.0f;

    // Product, clamped to [-10, 10]; non-finite values are zeroed.
    float result = gate * up;
    if (result > 10.0f) result = 10.0f;
    if (result < -10.0f) result = -10.0f;
    if (isnan(result) || isinf(result)) result = 0.0f;
    out[id] = result;
}
// ── 5a. Half-split RoPE (Q only) ─────────────
// Gemma uses rotate_half style: pairs are (d, d + headDim/2), not (d1, d2)
// One thread per (head, pair); rotatedDim/2 pairs are rotated in each head and
// q is updated in place. A rotatedDim below 2 rotates nothing.
// Assumes pair + headDim/2 stays inside the head, i.e. rotatedDim <= headDim
// — TODO confirm with the caller.
// NOTE(review): the frequency is multiplied by pow(scale, position), which is
// 1 only when scale == 1; for any other scale it grows/decays exponentially
// with position. Confirm this is the intended RoPE scaling.
kernel void apply_rope_q(
    device float   *q          [[buffer(0)]], // [nHeads, headDim] — in-place
    constant uint  &nHeads     [[buffer(1)]],
    constant uint  &headDim    [[buffer(2)]],
    constant uint  &rotatedDim [[buffer(3)]],
    constant float &theta      [[buffer(4)]],
    constant float &scale      [[buffer(5)]],
    constant int   &position   [[buffer(6)]],
    uint id [[thread_position_in_grid]]
) {
    const uint nPairs = rotatedDim / 2;
    // Must be tested before id is divided by nPairs below.
    if (nPairs == 0) return;

    const uint head = id / nPairs;
    if (head >= nHeads) return;
    const uint pair = id % nPairs;

    // Half rotation: pair i corresponds to (i, i + halfDim)
    const uint halfDim = headDim / 2;
    const uint d1 = pair;
    const uint d2 = pair + halfDim;

    const float pos = float(position);
    const float freqBase = pow(theta, -2.0 * float(pair) / float(headDim));
    const float freq = freqBase * pow(scale, pos);
    const float c = cos(pos * freq);
    const float s = sin(pos * freq);

    device float *h = q + head * headDim;
    const float v1 = h[d1];
    const float v2 = h[d2];
    // 2-D rotation of (v1, v2) by the angle pos * freq.
    h[d1] = v1 * c - v2 * s;
    h[d2] = v1 * s + v2 * c;
}
// ── 5b. Half-split RoPE (K only) ─────────────
// Gemma uses rotate_half style: pairs are (d, d + headDim/2), not (d1, d2)
// One thread per (kvHead, pair); rotatedDim/2 pairs are rotated in each head
// and k is updated in place. A rotatedDim below 2 rotates nothing.
// Assumes pair + headDim/2 stays inside the head, i.e. rotatedDim <= headDim
// — TODO confirm with the caller.
// NOTE(review): the frequency is multiplied by pow(scale, position), which is
// 1 only when scale == 1; for any other scale it grows/decays exponentially
// with position. Confirm this is the intended RoPE scaling.
kernel void apply_rope_k(
    device float   *k          [[buffer(0)]], // [nKvHeads, headDim] — in-place
    constant uint  &nKvHeads   [[buffer(1)]],
    constant uint  &headDim    [[buffer(2)]],
    constant uint  &rotatedDim [[buffer(3)]],
    constant float &theta      [[buffer(4)]],
    constant float &scale      [[buffer(5)]],
    constant int   &position   [[buffer(6)]],
    uint id [[thread_position_in_grid]]
) {
    const uint nPairs = rotatedDim / 2;
    // Must be tested before id is divided by nPairs below.
    if (nPairs == 0) return;

    const uint kvHead = id / nPairs;
    if (kvHead >= nKvHeads) return;
    const uint pair = id % nPairs;

    // Half rotation: pair i corresponds to (i, i + halfDim)
    const uint halfDim = headDim / 2;
    const uint d1 = pair;
    const uint d2 = pair + halfDim;

    const float pos = float(position);
    const float freqBase = pow(theta, -2.0 * float(pair) / float(headDim));
    const float freq = freqBase * pow(scale, pos);
    const float c = cos(pos * freq);
    const float s = sin(pos * freq);

    device float *h = k + kvHead * headDim;
    const float v1 = h[d1];
    const float v2 = h[d2];
    // 2-D rotation of (v1, v2) by the angle pos * freq.
    h[d1] = v1 * c - v2 * s;
    h[d2] = v1 * s + v2 * c;
}
// ── 6. Scaled Dot-Product Attention (Sliding, rotating buffer) ──
// O = softmax(Q K^T / sqrt(headDim)) V with sliding window (rotating) and GQA.
// The window covers logical positions offset-actualWindow+1 .. offset, where
// actualWindow = min(offset + 1, windowSize); a logical position p lives in
// cache slot p % windowSize.
// One thread per (head, dim): each thread recomputes the q·k scores twice
// (max pass, then softmax/accumulate pass) and writes one output element.
// Degenerate inputs (offset < 0, windowSize == 0, nKvHeads == 0) have nothing
// to attend to; the output element is set to 0 rather than NaN / garbage.
// NOTE(review): query head h reads KV head (h % nKvHeads). The usual GQA
// convention is h / (nHeads / nKvHeads) — confirm this mapping is intended.
kernel void sliding_attention(
    device const float *q   [[buffer(0)]], // [nHeads, headDim]
    device const float *k   [[buffer(1)]], // [windowSize, nKvHeads, headDim]
    device const float *v   [[buffer(2)]], // [windowSize, nKvHeads, headDim]
    device float       *out [[buffer(3)]], // [nHeads, headDim]
    constant uint &nHeads     [[buffer(4)]],
    constant uint &nKvHeads   [[buffer(5)]],
    constant uint &headDim    [[buffer(6)]],
    constant uint &windowSize [[buffer(7)]],
    constant int  &offset     [[buffer(8)]], // current logical position
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x;
    const uint dim  = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    const uint outIdx = head * headDim + dim;
    if (nKvHeads == 0 || windowSize == 0 || offset < 0) {
        out[outIdx] = 0.0f;
        return;
    }

    const uint kvHead = head % nKvHeads;
    const uint seqLen = uint(offset) + 1u;
    const uint actualWindow = min(seqLen, windowSize);
    // actualWindow <= offset + 1, so the first logical position is never negative.
    const uint base = seqLen - actualWindow;

    // Cache layout: [windowSize, nKvHeads, headDim] flat
    //   k[(slot * nKvHeads + h) * headDim + d], same for v
    const float scale = 1.0f / sqrt(float(headDim));
    device const float *qRow = q + head * headDim;

    // Pass 1: find max score
    float maxScore = -INFINITY;
    for (uint t = 0; t < actualWindow; t++) {
        const uint cacheIdx = (base + t) % windowSize;
        device const float *kRow = k + (cacheIdx * nKvHeads + kvHead) * headDim;
        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        // Text model has NO attention softcapping
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum
    float sumExp = 0.0f;
    float result = 0.0f;
    for (uint t = 0; t < actualWindow; t++) {
        const uint cacheIdx = (base + t) % windowSize;
        const uint rowBase = (cacheIdx * nKvHeads + kvHead) * headDim;
        device const float *kRow = k + rowBase;
        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        result += expVal * v[rowBase + dim];
    }
    out[outIdx] = result / sumExp;
}
// ── 6b. Sliding attention with current K,V appended ──
// For non-owner sliding layers: use cache K,V + current layer's K,V
// Cache entries are positions 0..cacheLen-1, current K/V is at position = cacheLen
// The attended sequence has cacheLen + 1 entries; only the last
// min(cacheLen + 1, windowSize) of them are visited, and entries whose logical
// position exceeds `position` are skipped (causal mask). Cached entry p is read
// from slot p % windowSize.
// One thread per (head, dim). If every entry is masked out (or the window is
// empty) the output element is 0 instead of 0/0 = NaN; nKvHeads == 0 also
// yields 0 instead of a modulo by zero.
// NOTE(review): query head h reads KV head (h % nKvHeads). The usual GQA
// convention is h / (nHeads / nKvHeads) — confirm this mapping is intended.
kernel void sliding_attention_with_current(
    device const float *q      [[buffer(0)]], // [nHeads, headDim]
    device const float *cacheK [[buffer(1)]], // [windowSize, nKvHeads, headDim]
    device const float *cacheV [[buffer(2)]], // [windowSize, nKvHeads, headDim]
    device const float *curK   [[buffer(3)]], // [nKvHeads, headDim]
    device const float *curV   [[buffer(4)]], // [nKvHeads, headDim]
    device float       *out    [[buffer(5)]], // [nHeads, headDim]
    constant uint &nHeads     [[buffer(6)]],
    constant uint &nKvHeads   [[buffer(7)]],
    constant uint &headDim    [[buffer(8)]],
    constant uint &windowSize [[buffer(9)]],
    constant uint &cacheLen   [[buffer(10)]], // number of entries in cache
    constant int  &position   [[buffer(11)]], // current position for causal mask
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x;
    const uint dim  = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    const uint outIdx = head * headDim + dim;
    if (nKvHeads == 0) {
        out[outIdx] = 0.0f;
        return;
    }

    const uint kvHead = head % nKvHeads;
    const uint seqLen = cacheLen + 1; // cache entries + current K,V
    const uint actualWindow = min(seqLen, windowSize);
    const int base = int(seqLen) - int(actualWindow);
    const float scale = 1.0f / sqrt(float(headDim));

    device const float *qRow    = q + head * headDim;
    device const float *curKRow = curK + kvHead * headDim;

    // Pass 1: find max score (with causal mask)
    float maxScore = -INFINITY;
    for (uint t = 0; t < actualWindow; t++) {
        const int logicalPos = base + int(t);
        // Causal mask: only attend to positions <= current position
        if (logicalPos > position) continue;

        // Anything that is not a cached position is the current K,V.
        const bool fromCache = logicalPos >= 0 && uint(logicalPos) < cacheLen;
        device const float *kRow = curKRow;
        if (fromCache) {
            const uint cacheIdx = uint(logicalPos) % windowSize;
            kRow = cacheK + (cacheIdx * nKvHeads + kvHead) * headDim;
        }

        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        // Text model has NO attention softcapping
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum (with causal mask)
    float sumExp = 0.0f;
    float result = 0.0f;
    for (uint t = 0; t < actualWindow; t++) {
        const int logicalPos = base + int(t);
        if (logicalPos > position) continue;

        const bool fromCache = logicalPos >= 0 && uint(logicalPos) < cacheLen;
        device const float *kRow = curKRow;
        float value = 0.0f;
        if (fromCache) {
            const uint rowBase = ((uint(logicalPos) % windowSize) * nKvHeads + kvHead) * headDim;
            kRow  = cacheK + rowBase;
            value = cacheV[rowBase + dim];
        } else {
            value = curV[kvHead * headDim + dim];
        }

        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        result += expVal * value;
    }

    // sumExp is exactly 0 only when no entry was visited.
    out[outIdx] = (sumExp == 0.0f) ? 0.0f : result / sumExp;
}
// ── 7. Scaled Dot-Product Attention (Full, causal) ──
// Two-pass: find max, then softmax + weighted sum (no score array splat).
// Supports arbitrary sequence length.
// Attends to cache positions 0..offset inclusive. One thread per (head, dim);
// each thread recomputes the q·k scores in both passes.
// `maxPos` is not read by the kernel; the caller presumably guarantees
// offset < maxPos so the cache reads stay in bounds — TODO confirm.
// offset < 0 or nKvHeads == 0 leaves nothing to attend to: the output element
// is set to 0 instead of reading with a wrapped sequence length or taking a
// modulo by zero.
// NOTE(review): query head h reads KV head (h % nKvHeads). The usual GQA
// convention is h / (nHeads / nKvHeads) — confirm this mapping is intended.
kernel void full_attention(
    device const float *q   [[buffer(0)]], // [nHeads, headDim]
    device const float *k   [[buffer(1)]], // [maxPos, nKvHeads, headDim]
    device const float *v   [[buffer(2)]], // [maxPos, nKvHeads, headDim]
    device float       *out [[buffer(3)]], // [nHeads, headDim]
    constant uint &nHeads   [[buffer(4)]],
    constant uint &nKvHeads [[buffer(5)]],
    constant uint &headDim  [[buffer(6)]],
    constant uint &maxPos   [[buffer(7)]],
    constant int  &offset   [[buffer(8)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x;
    const uint dim  = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    const uint outIdx = head * headDim + dim;
    if (nKvHeads == 0 || offset < 0) {
        out[outIdx] = 0.0f;
        return;
    }

    const uint kvHead = head % nKvHeads;
    const uint seqLen = uint(offset) + 1u;

    // Cache layout: [maxPos, nKvHeads, headDim] flat
    //   k[(t * nKvHeads + h) * headDim + d], same for v
    const float scale = 1.0f / sqrt(float(headDim));
    device const float *qRow = q + head * headDim;

    // Pass 1: max
    float maxScore = -INFINITY;
    for (uint t = 0; t < seqLen; t++) {
        device const float *kRow = k + (t * nKvHeads + kvHead) * headDim;
        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        // Text model has NO attention softcapping
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum
    float sumExp = 0.0f;
    float result = 0.0f;
    for (uint t = 0; t < seqLen; t++) {
        const uint rowBase = (t * nKvHeads + kvHead) * headDim;
        device const float *kRow = k + rowBase;
        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        result += expVal * v[rowBase + dim];
    }
    out[outIdx] = result / sumExp;
}
// ── 7b. Full attention with current K,V appended ──
// For non-owner layers: use cache K,V + current layer's K,V
// Cache entries are positions 0..cacheLen-1, current K/V is at position = cacheLen
// Causal mask: only attend to positions <= current position
// One thread per (head, dim). Entry t < cacheLen is read from the cache, entry
// t == cacheLen is the current K,V; any entry with t > position is skipped.
// If every entry is skipped the output element is 0 instead of 0/0 = NaN;
// nKvHeads == 0 also yields 0 instead of a modulo by zero.
// NOTE(review): with the mask `t > position`, the current K,V (t == cacheLen)
// is only attended when cacheLen <= position. The parameter comment describes
// cacheLen as "position + 1", which would always mask it out — confirm which
// convention the caller uses.
// NOTE(review): query head h reads KV head (h % nKvHeads). The usual GQA
// convention is h / (nHeads / nKvHeads) — confirm this mapping is intended.
kernel void full_attention_with_current(
    device const float *q      [[buffer(0)]], // [nHeads, headDim]
    device const float *cacheK [[buffer(1)]], // [maxPos, nKvHeads, headDim]
    device const float *cacheV [[buffer(2)]], // [maxPos, nKvHeads, headDim]
    device const float *curK   [[buffer(3)]], // [nKvHeads, headDim]
    device const float *curV   [[buffer(4)]], // [nKvHeads, headDim]
    device float       *out    [[buffer(5)]], // [nHeads, headDim]
    constant uint &nHeads   [[buffer(6)]],
    constant uint &nKvHeads [[buffer(7)]],
    constant uint &headDim  [[buffer(8)]],
    constant uint &cacheLen [[buffer(9)]],  // number of entries in cache (position + 1)
    constant int  &position [[buffer(10)]], // current position for causal mask
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x;
    const uint dim  = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    const uint outIdx = head * headDim + dim;
    if (nKvHeads == 0) {
        out[outIdx] = 0.0f;
        return;
    }

    const uint kvHead = head % nKvHeads;
    const float scale = 1.0f / sqrt(float(headDim));
    const uint seqLen = cacheLen + 1; // cache entries + current K,V

    device const float *qRow    = q + head * headDim;
    device const float *curKRow = curK + kvHead * headDim;

    // Pass 1: max score (with causal mask)
    float maxScore = -INFINITY;
    for (uint t = 0; t < seqLen; t++) {
        if (int(t) > position) continue; // causal mask: skip future positions

        device const float *kRow = curKRow;
        if (t < cacheLen) {
            kRow = cacheK + (t * nKvHeads + kvHead) * headDim;
        }

        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        // Text model has NO attention softcapping
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum (with causal mask)
    float sumExp = 0.0f;
    float result = 0.0f;
    for (uint t = 0; t < seqLen; t++) {
        if (int(t) > position) continue;

        device const float *kRow = curKRow;
        float value = 0.0f;
        if (t < cacheLen) {
            const uint rowBase = (t * nKvHeads + kvHead) * headDim;
            kRow  = cacheK + rowBase;
            value = cacheV[rowBase + dim];
        } else {
            value = curV[kvHead * headDim + dim];
        }

        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        result += expVal * value;
    }

    // sumExp is exactly 0 only when no entry was visited.
    out[outIdx] = (sumExp == 0.0f) ? 0.0f : result / sumExp;
}
// ── 8. Dequantize a single row ─────────────────
// Expands row `rowIdx` of a 4-bit group-quantized matrix into floats:
// out[c] = q(rowIdx, c) * scale(rowIdx, group(c)) + bias(rowIdx, group(c)).
// One thread per column; eight 4-bit values per uint32, lowest nibble first.
kernel void dequantize_row(
    device const uint  *w   [[buffer(0)]], // [nRows, nCols/8]
    device const float *s   [[buffer(1)]], // [nRows, numGroups]
    device const float *b   [[buffer(2)]], // [nRows, numGroups]
    device float       *out [[buffer(3)]], // [nCols]
    constant uint &nCols     [[buffer(4)]],
    constant int  &rowIdx    [[buffer(5)]],
    constant uint &groupSize [[buffer(6)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= nCols) return;

    // Locate the column inside its quantization group.
    const uint group  = id / groupSize;
    const uint within = id % groupSize;

    // Fetch the packed word and isolate this column's nibble.
    const uint wordIdx = group * (groupSize / 8) + within / 8;
    const uint word    = w[rowIdx * (nCols / 8) + wordIdx];
    const uint nibble  = (word >> ((within % 8) * 4)) & 0xF;

    // Per-group affine parameters for this row.
    const uint groupsPerRow = nCols / groupSize;
    const uint paramIdx     = rowIdx * groupsPerRow + group;

    out[id] = float(nibble) * s[paramIdx] + b[paramIdx];
}
// ── 9. Element-wise helpers ────────────────────
// out[i] = a[i] + b[i] for i < n
kernel void eltwise_add(
    device const float *a   [[buffer(0)]],
    device const float *b   [[buffer(1)]],
    device float       *out [[buffer(2)]],
    constant uint      &n   [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= n) return;
    out[id] = a[id] + b[id];
}
// out[i] = a[i] * b[i] for i < n
kernel void eltwise_mul(
    device const float *a   [[buffer(0)]],
    device const float *b   [[buffer(1)]],
    device float       *out [[buffer(2)]],
    constant uint      &n   [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= n) return;
    out[id] = a[id] * b[id];
}
// buf[i] *= scale for i < n (in place)
kernel void eltwise_scale(
    device float   *buf   [[buffer(0)]],
    constant float &scale [[buffer(1)]],
    constant uint  &n     [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= n) return;
    buf[id] = buf[id] * scale;
}
// out = a * scaleA + b * scaleB
// One thread per element, for i < n.
kernel void eltwise_add_scaled(
    device const float *a      [[buffer(0)]],
    constant float     &scaleA [[buffer(1)]],
    device const float *b      [[buffer(2)]],
    constant float     &scaleB [[buffer(3)]],
    device float       *out    [[buffer(4)]],
    constant uint      &n      [[buffer(5)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= n) return;
    out[id] = a[id] * scaleA + b[id] * scaleB;
}
// ── 10. Tanh scaling (logit softcapping) ──────
// out[i] = tanh(in[i] / cap) * cap
// One thread per element, for i < n.
kernel void tanh_scale(
    device const float *inp [[buffer(0)]],
    device float       *out [[buffer(1)]],
    constant float     &cap [[buffer(2)]],
    constant uint      &n   [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < n) {
        const float squashed = tanh(inp[id] / cap);
        out[id] = squashed * cap;
    }
}
// ══════════════════════════════════════════════════════
// Audio Processing Kernels
// ══════════════════════════════════════════════════════
// ── Audio Subsample Convolution ──
// 2D convolution for audio feature extraction
// Unpadded ("valid") strided convolution, one thread per output element:
//   outHeight = (inHeight - kernelH) / strideH + 1
//   outWidth  = (inWidth  - kernelW) / strideW + 1
// A kernel larger than the input, or a zero stride, produces no output at all.
// The sizes are computed this way because the all-unsigned form
// (inHeight - kernelH + strideH) / strideH wraps around when
// kernelH > inHeight, which would let threads index far outside the buffers.
kernel void audio_conv2d(
    device const float *input  [[buffer(0)]], // [inChannels, height, width]
    device const float *weight [[buffer(1)]], // [outChannels, inChannels, kernelH, kernelW]
    device float       *output [[buffer(2)]], // [outChannels, outHeight, outWidth]
    constant uint &inChannels  [[buffer(3)]],
    constant uint &outChannels [[buffer(4)]],
    constant uint &inHeight    [[buffer(5)]],
    constant uint &inWidth     [[buffer(6)]],
    constant uint &kernelH     [[buffer(7)]],
    constant uint &kernelW     [[buffer(8)]],
    constant uint &strideH     [[buffer(9)]],
    constant uint &strideW     [[buffer(10)]],
    uint3 gid [[thread_position_in_grid]]
) {
    const uint oc = gid.x; // output channel
    const uint oh = gid.y; // output height
    const uint ow = gid.z; // output width

    // Reject configurations with an empty or undefined output extent.
    if (strideH == 0 || strideW == 0) return;
    if (kernelH > inHeight || kernelW > inWidth) return;

    const uint outHeight = (inHeight - kernelH) / strideH + 1;
    const uint outWidth  = (inWidth - kernelW) / strideW + 1;
    if (oc >= outChannels || oh >= outHeight || ow >= outWidth) return;

    const uint kernelArea = kernelH * kernelW;
    float sum = 0.0f;
    for (uint ic = 0; ic < inChannels; ic++) {
        const uint inPlane = ic * inHeight * inWidth;
        const uint wPlane  = oc * inChannels * kernelArea + ic * kernelArea;
        for (uint kh = 0; kh < kernelH; kh++) {
            const uint ih = oh * strideH + kh;
            for (uint kw = 0; kw < kernelW; kw++) {
                const uint iw = ow * strideW + kw;
                sum += input[inPlane + ih * inWidth + iw] * weight[wPlane + kh * kernelW + kw];
            }
        }
    }

    output[oc * outHeight * outWidth + oh * outWidth + ow] = sum;
}
// ── Audio RMS Norm ──
// Per-channel RMS normalization for audio features
// One thread per (channel, feature). The RMS is taken over the featureSize
// values of the thread's channel, and a single weight per channel is applied:
//   output[ch, f] = input[ch, f] / sqrt(mean(input[ch, :]^2) + eps) * weight[ch]
kernel void audio_rms_norm(
    device const float *input  [[buffer(0)]],
    device const float *weight [[buffer(1)]],
    device float       *output [[buffer(2)]],
    constant uint  &channels    [[buffer(3)]],
    constant uint  &featureSize [[buffer(4)]],
    constant float &eps         [[buffer(5)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint ch   = gid.x;
    const uint feat = gid.y;
    if (ch >= channels || feat >= featureSize) return;

    const uint rowStart = ch * featureSize;

    // Sum of squares over this channel's features.
    float sumSquares = 0.0f;
    for (uint f = 0; f < featureSize; ++f) {
        const float sample = input[rowStart + f];
        sumSquares += sample * sample;
    }
    const float rms = sqrt(sumSquares / float(featureSize) + eps);

    // Normalize, then apply the per-channel weight.
    const uint idx = rowStart + feat;
    const float normalized = input[idx] / rms;
    output[idx] = normalized * weight[ch];
}
// ── Audio Linear Projection ──
// Linear projection for audio features
// output[o] = (hasBias ? bias[o] : 0) + sum_i input[i] * weight[o, i]
// One thread per output feature; `bias` is only read when hasBias is true.
kernel void audio_linear(
    device const float *input  [[buffer(0)]], // [inFeatures]
    device const float *weight [[buffer(1)]], // [outFeatures, inFeatures]
    device const float *bias   [[buffer(2)]], // [outFeatures] (optional)
    device float       *output [[buffer(3)]], // [outFeatures]
    constant uint &inFeatures  [[buffer(4)]],
    constant uint &outFeatures [[buffer(5)]],
    constant bool &hasBias     [[buffer(6)]],
    uint gid [[thread_position_in_grid]]
) {
    const uint row = gid;
    if (row >= outFeatures) return;

    // Start from the bias (if any), then accumulate the dot product.
    float acc = 0.0f;
    if (hasBias) {
        acc = bias[row];
    }

    const uint rowStart = row * inFeatures;
    for (uint col = 0; col < inFeatures; ++col) {
        acc += input[col] * weight[rowStart + col];
    }
    output[row] = acc;
}
// ── Audio Attention ──
// Sliding window attention for audio encoder
// Attends to the last min(chunkSize, seqLen) key/value positions. Scores are
// scaled by 1/sqrt(headDim) and soft-capped with tanh(score / logitCap) * logitCap
// before the softmax. One thread per (head, dim).
// `contextLeft` is accepted but not used by this kernel.
// An empty window (seqLen == 0 or chunkSize == 0) or nKvHeads == 0 yields 0
// rather than NaN / a modulo by zero.
// Assumes logitCap != 0 — TODO confirm with the caller.
// NOTE(review): query head h reads KV head (h % nKvHeads). The usual GQA
// convention is h / (nHeads / nKvHeads) — confirm this mapping is intended.
kernel void audio_attention(
    device const float *q   [[buffer(0)]], // [nHeads, headDim]
    device const float *k   [[buffer(1)]], // [seqLen, nKvHeads, headDim]
    device const float *v   [[buffer(2)]], // [seqLen, nKvHeads, headDim]
    device float       *out [[buffer(3)]], // [nHeads, headDim]
    constant uint  &nHeads      [[buffer(4)]],
    constant uint  &nKvHeads    [[buffer(5)]],
    constant uint  &headDim     [[buffer(6)]],
    constant uint  &seqLen      [[buffer(7)]],
    constant uint  &chunkSize   [[buffer(8)]],
    constant uint  &contextLeft [[buffer(9)]],
    constant float &logitCap    [[buffer(10)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x;
    const uint dim  = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    const uint outIdx = head * headDim + dim;
    if (nKvHeads == 0) {
        out[outIdx] = 0.0f;
        return;
    }

    const uint kvHead = head % nKvHeads;
    const float scale = 1.0f / sqrt(float(headDim));
    device const float *qRow = q + head * headDim;

    // First attended position. Written as a comparison because the unsigned
    // subtraction seqLen - chunkSize wraps around when chunkSize > seqLen.
    const uint startPos = (seqLen > chunkSize) ? (seqLen - chunkSize) : 0u;

    // Pass 1: maximum soft-capped score over the chunk.
    float maxScore = -INFINITY;
    for (uint t = startPos; t < seqLen; t++) {
        device const float *kRow = k + (t * nKvHeads + kvHead) * headDim;
        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        score = tanh(score / logitCap) * logitCap;
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum.
    float sumExp = 0.0f;
    float result = 0.0f;
    for (uint t = startPos; t < seqLen; t++) {
        const uint rowBase = (t * nKvHeads + kvHead) * headDim;
        device const float *kRow = k + rowBase;
        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += qRow[d] * kRow[d];
        }
        score *= scale;
        score = tanh(score / logitCap) * logitCap;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        result += expVal * v[rowBase + dim];
    }

    // sumExp is exactly 0 only when the window was empty.
    out[outIdx] = (sumExp == 0.0f) ? 0.0f : result / sumExp;
}
// ── Audio FFN (Feed-Forward Network) ──
// Two-layer feed-forward block: output = silu(input @ weight1^T) @ weight2^T.
// weight1 is indexed as [hiddenFeatures, inFeatures], weight2 as [inFeatures, hiddenFeatures].
// One thread per output feature (gid < inFeatures); each thread recomputes every
// hidden activation on its own.
kernel void audio_ffn(
    device const float *input       [[buffer(0)]],
    device const float *weight1     [[buffer(1)]],
    device const float *weight2     [[buffer(2)]],
    device float *output            [[buffer(3)]],
    constant uint &inFeatures       [[buffer(4)]],
    constant uint &hiddenFeatures   [[buffer(5)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= inFeatures) return;

    const uint w2Base = gid * hiddenFeatures;
    float acc = 0.0f;

    for (uint j = 0; j < hiddenFeatures; ++j) {
        // Pre-activation of hidden unit j.
        const uint w1Base = j * inFeatures;
        float pre = 0.0f;
        for (uint i = 0; i < inFeatures; ++i) {
            pre += input[i] * weight1[w1Base + i];
        }

        // SiLU: x * sigmoid(x)
        float gate = 1.0f / (1.0f + exp(-pre));
        float hidden = pre * gate;

        acc += hidden * weight2[w2Base + j];
    }

    output[gid] = acc;
}
// ── Audio 1D Convolution (Local Context) ──
// 1D convolution over the sequence axis, centred on each position.
// Taps that fall outside [0, seqLen) contribute nothing.
// input is indexed [inChannels, seqLen], weight [outChannels, inChannels, kernelSize],
// output [outChannels, seqLen]. Grid: (outChannel, position).
kernel void audio_conv1d(
    device const float *input   [[buffer(0)]],
    device const float *weight  [[buffer(1)]],
    device float *output        [[buffer(2)]],
    constant uint &inChannels   [[buffer(3)]],
    constant uint &outChannels  [[buffer(4)]],
    constant uint &kernelSize   [[buffer(5)]],
    constant uint &seqLen       [[buffer(6)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint oc = gid.x;
    const uint pos = gid.y;
    if (oc >= outChannels || pos >= seqLen) return;

    // Leftmost tap sits `radius` samples before the output position.
    const int radius = int(kernelSize / 2);
    const int firstTap = int(pos) - radius;
    const int limit = int(seqLen);

    float acc = 0.0f;
    for (uint ic = 0; ic < inChannels; ++ic) {
        const uint inBase = ic * seqLen;
        const uint wBase = (oc * inChannels + ic) * kernelSize;
        for (uint tap = 0; tap < kernelSize; ++tap) {
            const int src = firstTap + int(tap);
            if (src < 0 || src >= limit) continue;
            acc += input[inBase + uint(src)] * weight[wBase + tap];
        }
    }

    output[oc * seqLen + pos] = acc;
}
// ═══════════════════════════════════════════════
// Audio Tower Kernels
// ═══════════════════════════════════════════════
// Audio subsample conv 2D: 3x3 convolution with stride 2 and a one-pixel border,
// followed by a per-output-channel multiply by normWeight[oc].
// (Only that multiply happens here; no mean/variance statistics are computed.)
// Input is treated as a 2D surface [inCh, H, W] with H = nMels, W = seqLen.
// Conv weight layout: [outCh, kernelH, kernelW, inCh] = [outCh, 3, 3, inCh].
// Output: [outCh, outH, outW] CHW flat, where outH = (H+1)/2 and outW = (W+1)/2.
// Grid: (outChannel, outRow, outCol).
kernel void audio_subsample_conv_2d(
    device const float *input       [[buffer(0)]],  // [inCh, H, W] CHW flat
    device const float *convWeight  [[buffer(1)]],  // [outCh, 3, 3, inCh]
    device const float *normWeight  [[buffer(2)]],  // [outCh]
    device float *output            [[buffer(3)]],  // [outCh, outH, outW] CHW flat
    constant uint &inChannels       [[buffer(4)]],
    constant uint &outChannels      [[buffer(5)]],
    constant uint &height           [[buffer(6)]],  // H = nMels
    constant uint &width            [[buffer(7)]],  // W = seqLen
    uint3 gid [[thread_position_in_grid]]
) {
    const uint oc = gid.x;
    const uint oy = gid.y;
    const uint ox = gid.z;

    const uint outH = (height + 1) / 2;
    const uint outW = (width + 1) / 2;
    if (oc >= outChannels || oy >= outH || ox >= outW) return;

    // Top-left corner of the 3x3 receptive field (may be -1 at the border).
    const int rowOrigin = int(oy * 2) - 1;
    const int colOrigin = int(ox * 2) - 1;

    float acc = 0.0f;
    for (uint ic = 0; ic < inChannels; ++ic) {
        const uint planeBase = ic * height * width;
        for (uint ky = 0; ky < 3; ++ky) {
            const int iy = rowOrigin + int(ky);
            if (iy < 0 || iy >= int(height)) continue;   // whole kernel row is padding
            for (uint kx = 0; kx < 3; ++kx) {
                const int ix = colOrigin + int(kx);
                if (ix < 0 || ix >= int(width)) continue;
                const uint src = planeBase + uint(iy) * width + uint(ix);
                const uint tap = ((oc * 3 + ky) * 3 + kx) * inChannels + ic;
                acc += input[src] * convWeight[tap];
            }
        }
    }

    acc = acc * normWeight[oc];
    output[oc * outH * outW + oy * outW + ox] = acc;
}
// Out-of-place transpose of a row-major matrix: [rows, cols] -> [cols, rows].
// Grid: gid.x walks columns of the input, gid.y walks rows.
kernel void transpose_2d(
    device const float *input   [[buffer(0)]],  // [rows, cols] row-major
    device float *output        [[buffer(1)]],  // [cols, rows] row-major
    constant uint &rows         [[buffer(2)]],
    constant uint &cols         [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint srcCol = gid.x;
    const uint srcRow = gid.y;
    if (srcRow < rows && srcCol < cols) {
        const float value = input[srcRow * cols + srcCol];
        output[srcCol * rows + srcRow] = value;
    }
}
// Re-layout CHW [C, H, W] into row-major [W, C*H]:
//   output[w, c*H + h] = input[c, h, w]
// Grid: gid.x is the fused (c, h) index, gid.y is w.
kernel void audio_flatten_chw(
    device const float *input   [[buffer(0)]],  // [C, H, W] CHW flat
    device float *output        [[buffer(1)]],  // [W, C*H] row-major
    constant uint &C            [[buffer(2)]],
    constant uint &H            [[buffer(3)]],
    constant uint &W            [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint fused = gid.x;   // c * H + h
    const uint col = gid.y;     // w
    const uint rowLen = C * H;
    if (fused >= rowLen || col >= W) return;

    // c*H*W + h*W + w collapses to (c*H + h)*W + w, so the fused index
    // addresses the source directly without splitting it back into (c, h).
    output[col * rowLen + fused] = input[fused * W + col];
}
// Dense projection applied to every row of a sequence:
//   [seqLen, inFeatures] -> [seqLen, outFeatures]
//   output[s, o] = (hasBias ? bias[o] : 0) + sum_i input[s, i] * weight[o, i]
// Grid: (outFeature, sequencePosition).
kernel void audio_linear_seq(
    device const float *input   [[buffer(0)]],
    device const float *weight  [[buffer(1)]],
    device const float *bias    [[buffer(2)]],
    device float *output        [[buffer(3)]],
    constant uint &inFeatures   [[buffer(4)]],
    constant uint &outFeatures  [[buffer(5)]],
    constant bool &hasBias      [[buffer(6)]],
    constant uint &seqLen       [[buffer(7)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint o = gid.x;
    const uint row = gid.y;
    if (o >= outFeatures || row >= seqLen) return;

    const uint xBase = row * inFeatures;
    const uint wBase = o * inFeatures;

    float acc = 0.0f;
    if (hasBias) {
        acc = bias[o];
    }
    for (uint i = 0; i < inFeatures; ++i) {
        acc += input[xBase + i] * weight[wBase + i];
    }
    output[row * outFeatures + o] = acc;
}
// 4-bit quantized matmul over a sequence of rows.
//   output[s, o] = (hasBias ? bias[o] : 0) + sum_i input[s, i] * deq(weight[o][i])
//   deq(q) = q * scales[o, group] + biases_q[o, group], one group per GROUP_SIZE inputs.
// Each uint of `weight` packs eight 4-bit values, lowest nibble first.
// Only inDim / GROUP_SIZE whole groups are processed.
// Grid: (outFeature, sequencePosition).
kernel void quantized_matmul_seq(
    device const float *input       [[buffer(0)]],  // [seqLen, inDim]
    device const uint *weight       [[buffer(1)]],  // [outDim, inDim/8]
    device const float *scales      [[buffer(2)]],  // [outDim, inDim/64]
    device const float *biases_q    [[buffer(3)]],  // [outDim, inDim/64]
    device const float *bias        [[buffer(4)]],  // [outDim] optional output bias
    device float *output            [[buffer(5)]],  // [seqLen, outDim]
    constant uint &inDim            [[buffer(6)]],
    constant uint &outDim           [[buffer(7)]],
    constant bool &hasBias          [[buffer(8)]],
    constant uint &seqLen           [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint o = gid.x;
    const uint row = gid.y;
    if (o >= outDim || row >= seqLen) return;

    const uint groupCount = inDim / GROUP_SIZE;
    const uint wordsPerGroup = GROUP_SIZE / 8;
    const uint wBase = o * (inDim / 8);
    const uint qBase = o * groupCount;
    const uint xBase = row * inDim;

    float acc = 0.0f;
    if (hasBias) {
        acc = bias[o];
    }

    for (uint g = 0; g < groupCount; ++g) {
        const float groupScale = scales[qBase + g];
        const float groupOffset = biases_q[qBase + g];

        // Walk the group's packed words, then the eight nibbles of each word.
        for (uint wd = 0; wd < wordsPerGroup; ++wd) {
            const uint packedIdx = g * wordsPerGroup + wd;
            for (uint n = 0; n < 8; ++n) {
                const uint q4 = (weight[wBase + packedIdx] >> (n * 4)) & 0xF;
                const float w = float(q4) * groupScale + groupOffset;
                acc += w * input[xBase + g * GROUP_SIZE + wd * 8 + n];
            }
        }
    }

    output[row * outDim + o] = acc;
}
// Linear layer with optional affine range mapping on the input and output.
// When both inputMin and inputMax are non-null, each input x becomes
//   x * (inputMax[0] - inputMin[0]) / 255 + inputMin[0]
// before the dot product; when both outputMin and outputMax are non-null the
// accumulated sum is mapped the same way. Only element 0 of each range buffer is read.
// Grid: (outFeature, sequencePosition).
kernel void audio_quantized_linear(
    device const float *input       [[buffer(0)]],
    device const float *weight      [[buffer(1)]],
    device const float *inputMin    [[buffer(2)]],
    device const float *inputMax    [[buffer(3)]],
    device const float *outputMin   [[buffer(4)]],
    device const float *outputMax   [[buffer(5)]],
    device float *output            [[buffer(6)]],
    constant uint &inFeatures       [[buffer(7)]],
    constant uint &outFeatures      [[buffer(8)]],
    constant uint &seqLen           [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint o = gid.x;
    const uint row = gid.y;
    if (o >= outFeatures || row >= seqLen) return;

    // Range terms do not depend on the loop index, so read them once.
    const bool mapInput = inputMin && inputMax;
    float inStep = 0.0f;
    float inLow = 0.0f;
    if (mapInput) {
        inStep = (inputMax[0] - inputMin[0]) / 255.0f;
        inLow = inputMin[0];
    }

    const uint xBase = row * inFeatures;
    const uint wBase = o * inFeatures;

    float acc = 0.0f;
    for (uint i = 0; i < inFeatures; ++i) {
        float x = input[xBase + i];
        if (mapInput) {
            x = x * inStep + inLow;
        }
        acc += x * weight[wBase + i];
    }

    if (outputMin && outputMax) {
        const float outStep = (outputMax[0] - outputMin[0]) / 255.0f;
        acc = acc * outStep + outputMin[0];
    }

    output[row * outFeatures + o] = acc;
}
// Causal, left-windowed attention evaluated independently per (head, component).
// For output element (pos, head, d) the score against position p is
//   clamp(q[pos,head,d] * k[p,head,d] * perDimScale[head*headDim + d], -logitCap, logitCap)
// over p in [max(0, pos - contextLeft), pos], followed by a softmax and a weighted
// sum of v[p,head,d].
// NOTE(review): the score uses a single q/k component rather than a dot product
// over headDim, and `relativeK` is bound but never read - confirm this is intended.
// Grid: gid.x = head * headDim + d, gid.y = pos.
kernel void audio_attention_full(
    device const float *q           [[buffer(0)]],
    device const float *k           [[buffer(1)]],
    device const float *v           [[buffer(2)]],
    device const float *relativeK   [[buffer(3)]],
    device const float *perDimScale [[buffer(4)]],
    device float *output            [[buffer(5)]],
    constant uint &seqLen           [[buffer(6)]],
    constant uint &numHeads         [[buffer(7)]],
    constant uint &headDim          [[buffer(8)]],
    constant uint &contextLeft      [[buffer(9)]],
    constant float &logitCap        [[buffer(10)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint flat = gid.x;
    const uint pos = gid.y;
    const uint head = flat / headDim;
    const uint d = flat % headDim;
    if (head >= numHeads || pos >= seqLen) return;

    const uint rowStride = numHeads * headDim;   // elements per sequence position
    const uint column = head * headDim + d;      // offset inside one position

    const float qVal = q[pos * rowStride + column];

    // Window: positions [pos - contextLeft, pos], clipped at 0.
    const int last = int(pos);
    const int first = max(0, last - int(contextLeft));

    // Pass 1: running maximum of the capped scores.
    float peak = -INFINITY;
    for (int p = first; p <= last; ++p) {
        float s = qVal * k[uint(p) * rowStride + column] * perDimScale[column];
        s = max(min(s, logitCap), -logitCap);
        peak = max(peak, s);
    }

    // Pass 2: softmax denominator.
    float denom = 0.0f;
    for (int p = first; p <= last; ++p) {
        float s = qVal * k[uint(p) * rowStride + column] * perDimScale[column];
        s = max(min(s, logitCap), -logitCap);
        denom += exp(s - peak);
    }

    // Pass 3: probability-weighted sum of the values.
    float acc = 0.0f;
    for (int p = first; p <= last; ++p) {
        const uint at = uint(p) * rowStride + column;
        float s = qVal * k[at] * perDimScale[column];
        s = max(min(s, logitCap), -logitCap);
        const float prob = exp(s - peak) / denom;
        acc += prob * v[at];
    }

    output[pos * rowStride + column] = acc;
}
// Depthwise 1D convolution: every channel is filtered with its own kernel,
// centred on the output position, then multiplied by norm[c].
// input/output are indexed [seqLen, channels]; weight is [channels, kernelSize].
// Taps outside [0, seqLen) are skipped. Grid: (channel, position).
kernel void audio_depthwise_conv1d(
    device const float *input   [[buffer(0)]],
    device const float *weight  [[buffer(1)]],
    device const float *norm    [[buffer(2)]],
    device float *output        [[buffer(3)]],
    constant uint &channels     [[buffer(4)]],
    constant uint &kernelSize   [[buffer(5)]],
    constant uint &seqLen       [[buffer(6)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint ch = gid.x;
    const uint pos = gid.y;
    if (ch >= channels || pos >= seqLen) return;

    const int firstTap = int(pos) - int(kernelSize) / 2;
    const int limit = int(seqLen);
    const uint wBase = ch * kernelSize;

    float acc = 0.0f;
    for (uint tap = 0; tap < kernelSize; ++tap) {
        const int src = firstTap + int(tap);
        if (src < 0 || src >= limit) continue;
        acc += input[uint(src) * channels + ch] * weight[wBase + tap];
    }

    // Per-channel scale applied after the convolution.
    acc = acc * norm[ch];
    output[pos * channels + ch] = acc;
}
// ═══════════════════════════════════════════════
// GPU Mel Spectrogram Extraction
// ═══════════════════════════════════════════════
// DFT magnitude spectrum for all frames in parallel.
// Grid: [numFrames, spectrumSize] where spectrumSize = nFft/2 + 1.
// Each thread evaluates one bin of one frame directly (O(nFft) per thread):
// samples are multiplied by a symmetric Hann window, and samples past
// `audioLen` are treated as zero.
//
// The twiddle phase is reduced with integer arithmetic before it is converted
// to float. Passing -2*pi*bin*i/nFft straight to cos/sin yields arguments of up
// to ~2*pi*bin radians, where float spacing is coarse and the fast-math
// trigonometric functions lose accuracy; after reduction every argument lies
// in [-pi, pi].
// Precondition: bin * i must fit in 32 bits (holds for nFft <= 65536).
kernel void audio_dft_magnitude(
    device const float *audioData   [[buffer(0)]],  // [audioLen]
    device float *spectrum          [[buffer(1)]],  // [numFrames * spectrumSize]
    constant uint &nFft             [[buffer(2)]],
    constant uint &hopLength        [[buffer(3)]],
    constant uint &numFrames        [[buffer(4)]],
    constant uint &spectrumSize     [[buffer(5)]],
    constant uint &audioLen         [[buffer(6)]],  // total audio length for bounds check
    uint2 gid [[thread_position_in_grid]]
) {
    uint frame = gid.x;
    uint bin = gid.y;
    if (frame >= numFrames || bin >= spectrumSize) return;

    uint start = frame * hopLength;
    float real = 0.0f;
    float imag = 0.0f;

    for (uint i = 0; i < nFft; i++) {
        // exp(-j*2*pi*bin*i/nFft) is periodic in bin*i with period nFft.
        uint phaseIdx = (bin * i) % nFft;
        float turn = float(phaseIdx) / float(nFft);   // [0, 1)
        if (turn > 0.5f) turn -= 1.0f;                // (-0.5, 0.5]
        float angle = -2.0f * M_PI_F * turn;          // [-pi, pi)

        float sample = (start + i < audioLen) ? audioData[start + i] : 0.0f;

        // Symmetric Hann window. A one-point window is defined as 1, which
        // avoids the 0/0 that nFft == 1 would otherwise produce.
        float window = 1.0f;
        if (nFft > 1) {
            float wTurn = float(i) / float(nFft - 1); // [0, 1]
            if (wTurn > 0.5f) wTurn -= 1.0f;          // same cosine, argument in [-pi, pi]
            window = 0.5f * (1.0f - cos(2.0f * M_PI_F * wTurn));
        }

        float val = sample * window;
        real += val * cos(angle);
        imag += val * sin(angle);
    }

    spectrum[frame * spectrumSize + bin] = sqrt(real * real + imag * imag);
}
// Project a magnitude spectrum onto a precomputed mel filterbank and take log10.
//   melSpec[frame, mel] = log10(max(sum_bin spectrum[frame, bin] * filterbank[mel, bin], 1e-10))
// Grid: [numFrames, nMels]
kernel void audio_mel_filterbank(
    device const float *spectrum    [[buffer(0)]],  // [numFrames * spectrumSize]
    device const float *filterbank  [[buffer(1)]],  // [nMels * spectrumSize] precomputed
    device float *melSpec           [[buffer(2)]],  // [numFrames * nMels]
    constant uint &spectrumSize     [[buffer(3)]],
    constant uint &nMels            [[buffer(4)]],
    constant uint &numFrames        [[buffer(5)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint frame = gid.x;
    const uint mel = gid.y;
    if (frame >= numFrames || mel >= nMels) return;

    const uint specBase = frame * spectrumSize;
    const uint bankBase = mel * spectrumSize;

    float energy = 0.0f;
    for (uint b = 0; b < spectrumSize; ++b) {
        energy += spectrum[specBase + b] * filterbank[bankBase + b];
    }

    // Floor the energy so the logarithm stays finite for silent bands.
    const float floored = max(energy, 1e-10f);
    melSpec[frame * nMels + mel] = log10(floored);
}
// Row-wise RMS normalisation for a [seqLen, N] tensor:
//   y[s, i] = x[s, i] * rsqrt(mean_j(x[s, j]^2) + eps) * w[i]
// Grid: (element, sequencePosition). Every thread re-reads its whole row of x
// to compute the mean square.
kernel void rms_norm_seq(
    device const float *x   [[buffer(0)]],
    device const float *w   [[buffer(1)]],
    device float *y         [[buffer(2)]],
    constant uint &N        [[buffer(3)]],
    constant float &eps     [[buffer(4)]],
    constant uint &seqLen   [[buffer(5)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint col = gid.x;
    const uint row = gid.y;
    if (col >= N || row >= seqLen) return;

    const uint base = row * N;

    float sumSq = 0.0f;
    for (uint j = 0; j < N; ++j) {
        const float e = x[base + j];
        sumSq += e * e;
    }

    const float invRms = rsqrt(sumSq / float(N) + eps);
    y[base + col] = x[base + col] * invRms * w[col];
}
// Elementwise SiLU: y[i] = x[i] * sigmoid(x[i]) for i < count.
kernel void silu(
    device const float *x   [[buffer(0)]],
    device float *y         [[buffer(1)]],
    constant uint &count    [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < count) {
        const float v = x[id];
        const float sig = 1.0f / (1.0f + exp(-v));
        y[id] = v * sig;
    }
}
// Weighted residual connection: output[i] = input[i] + weight * add[i] for i < count.
kernel void residual_add(
    device const float *input   [[buffer(0)]],
    device const float *add     [[buffer(1)]],
    device float *output        [[buffer(2)]],
    constant uint &count        [[buffer(3)]],
    constant float &weight      [[buffer(4)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < count) {
        const float scaled = weight * add[id];
        output[id] = input[id] + scaled;
    }
}
// ═══════════════════════════════════════════════
// Vision Tower Kernels
// ═══════════════════════════════════════════════
// Add a learned position embedding to each patch:
//   output[p, h] = input[p, h] + positionEmbed[p * hiddenSize + h]
// Only the leading numPatches * hiddenSize entries of positionEmbed are read,
// i.e. table 0 of a [2, 10240, 768] tensor when hiddenSize is 768.
// Grid: (hiddenIndex, patchIndex).
kernel void vision_add_pos_embed(
    device const float *input           [[buffer(0)]],
    device const float *positionEmbed   [[buffer(1)]],  // [2, 10240, 768]
    device float *output                [[buffer(2)]],
    constant uint &hiddenSize           [[buffer(3)]],
    constant uint &numPatches           [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint feature = gid.x;
    const uint patch = gid.y;
    if (feature >= hiddenSize || patch >= numPatches) return;

    // Input, output and the embedding table share the same flat index.
    const uint at = patch * hiddenSize + feature;
    output[at] = input[at] + positionEmbed[at];
}
// RMS normalisation applied separately to each attention head:
//   y[s, head, d] = x[s, head, d] * rsqrt(mean_i(x[s, head, i]^2) + eps) * w[d]
// Grid: gid.x = head * headDim + d, gid.y = sequence position.
kernel void vision_head_norm(
    device const float *x       [[buffer(0)]],  // [seqLen, numHeads, headDim]
    device const float *w       [[buffer(1)]],  // [headDim]
    device float *y             [[buffer(2)]],
    constant uint &numHeads     [[buffer(3)]],
    constant uint &headDim      [[buffer(4)]],
    constant uint &seqLen       [[buffer(5)]],
    constant float &eps         [[buffer(6)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint flat = gid.x;
    const uint row = gid.y;
    const uint head = flat / headDim;
    const uint d = flat % headDim;
    if (head >= numHeads || row >= seqLen) return;

    // First element of this (position, head) vector.
    const uint base = row * numHeads * headDim + head * headDim;

    float sumSq = 0.0f;
    for (uint i = 0; i < headDim; ++i) {
        const float e = x[base + i];
        sumSq += e * e;
    }

    const float invRms = rsqrt(sumSq / float(headDim) + eps);
    y[base + d] = x[base + d] * invRms * w[d];
}
// Global (unmasked) softmax attention for the vision tower.
// Every query position attends to all numPatches key positions; scores are
// q.k / sqrt(headDim). Each thread produces one output component and
// recomputes the full score row in each of its three passes.
// Grid: gid.x = head * headDim + d, gid.y = query position.
kernel void vision_attention(
    device const float *q       [[buffer(0)]],  // [numPatches, numHeads, headDim]
    device const float *k       [[buffer(1)]],  // [numPatches, numHeads, headDim]
    device const float *v       [[buffer(2)]],  // [numPatches, numHeads, headDim]
    device float *output        [[buffer(3)]],
    constant uint &numPatches   [[buffer(4)]],
    constant uint &numHeads     [[buffer(5)]],
    constant uint &headDim      [[buffer(6)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint flat = gid.x;
    const uint pos = gid.y;
    const uint head = flat / headDim;
    const uint d = flat % headDim;
    if (head >= numHeads || pos >= numPatches) return;

    const uint rowStride = numHeads * headDim;
    const uint headOff = head * headDim;
    const uint qBase = pos * rowStride + headOff;
    const float scale = 1.0f / sqrt(float(headDim));

    // Pass 1: maximum score, for a stable softmax.
    float peak = -INFINITY;
    for (uint p = 0; p < numPatches; ++p) {
        const uint kBase = p * rowStride + headOff;
        float dot = 0.0f;
        for (uint i = 0; i < headDim; ++i) {
            dot += q[qBase + i] * k[kBase + i];
        }
        dot *= scale;
        peak = max(peak, dot);
    }

    // Pass 2: softmax denominator.
    float denom = 0.0f;
    for (uint p = 0; p < numPatches; ++p) {
        const uint kBase = p * rowStride + headOff;
        float dot = 0.0f;
        for (uint i = 0; i < headDim; ++i) {
            dot += q[qBase + i] * k[kBase + i];
        }
        dot *= scale;
        denom += exp(dot - peak);
    }

    // Pass 3: probability-weighted sum of component d of the values.
    float acc = 0.0f;
    for (uint p = 0; p < numPatches; ++p) {
        const uint kBase = p * rowStride + headOff;
        float dot = 0.0f;
        for (uint i = 0; i < headDim; ++i) {
            dot += q[qBase + i] * k[kBase + i];
        }
        dot *= scale;
        const float prob = exp(dot - peak) / denom;
        acc += prob * v[kBase + d];
    }

    output[qBase + d] = acc;
}
// SwiGLU combine step: output[i] = SiLU(gate[i]) * up[i] for i < count.
kernel void vision_gate_multiply(
    device const float *gate    [[buffer(0)]],
    device const float *up      [[buffer(1)]],
    device float *output        [[buffer(2)]],
    constant uint &count        [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < count) {
        const float g = gate[id];
        const float upVal = up[id];
        // SiLU(g) = g * sigmoid(g)
        const float activated = g * (1.0f / (1.0f + exp(-g)));
        output[id] = activated * upVal;
    }
}
// Unweighted residual connection: output[i] = input[i] + add[i] for i < count.
kernel void vision_residual_add(
    device const float *input   [[buffer(0)]],
    device const float *add     [[buffer(1)]],
    device float *output        [[buffer(2)]],
    constant uint &count        [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < count) {
        const float lhs = input[id];
        const float rhs = add[id];
        output[id] = lhs + rhs;
    }
}
// Copy each patch row into a wider output row, zero-filling the tail.
// Output rows are fixed at 2560 floats; columns >= hiddenSize are written as 0.
// Grid: (outputColumn, patchIndex).
kernel void vision_copy_output(
    device const float *input   [[buffer(0)]],  // [numPatches, 768]
    device float *output        [[buffer(1)]],  // [numPatches, 2560]
    constant uint &numPatches   [[buffer(2)]],
    constant uint &hiddenSize   [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]]
) {
    // Width of one output row.
    const uint paddedWidth = 2560;

    const uint col = gid.x;
    const uint patch = gid.y;
    if (col >= paddedWidth || patch >= numPatches) return;

    float value = 0.0f;   // padding by default
    if (col < hiddenSize) {
        value = input[patch * hiddenSize + col];
    }
    output[patch * paddedWidth + col] = value;
}
// Vision embedding projection with 4-bit quantized weights.
// weight: [outFeatures, packedSize] uint32 (each uint32 holds 8 4-bit values, lowest nibble first)
// scales: [outFeatures, numGroups] float
// biases: [outFeatures, numGroups] float
// input:  [numPatches, inFeatures] float (inFeatures = packedSize * 8)
// output: [numPatches, outFeatures] float
//
// Dequantization is affine: w = nibble * scale + bias. This is the same
// convention used by the other 4-bit kernels in this file
// (quantized_matmul_seq, dequantize_row_batch), which operate on the same
// packed layout. The previous form, scale * (nibble - bias), treated the
// additive bias as a zero point and produced different weights from an
// identical tensor.
// Grid: (outFeature, patchIndex).
kernel void vision_embedding_projection_quantized(
    device const float *input       [[buffer(0)]],
    device const uint32_t *weight   [[buffer(1)]],  // packed uint32
    device const float *scales      [[buffer(2)]],
    device const float *biases      [[buffer(3)]],
    device float *output            [[buffer(4)]],
    constant uint &inFeatures       [[buffer(5)]],  // 768
    constant uint &outFeatures      [[buffer(6)]],  // 2560
    constant uint &numPatches       [[buffer(7)]],
    constant uint &packedSize       [[buffer(8)]],  // 96 (inFeatures / 8)
    constant uint &groupSize        [[buffer(9)]],  // 64
    constant uint &numGroups        [[buffer(10)]], // 12 (inFeatures / groupSize)
    uint2 gid [[thread_position_in_grid]]
) {
    uint of = gid.x;    // output feature
    uint p = gid.y;     // patch position
    if (of >= outFeatures || p >= numPatches) return;

    uint wBase = of * packedSize;
    uint qBase = of * numGroups;
    uint xBase = p * inFeatures;

    float sum = 0.0f;
    for (uint packedIdx = 0; packedIdx < packedSize; packedIdx++) {
        uint32_t packed = weight[wBase + packedIdx];

        // Unpack 8 4-bit values from the uint32.
        for (uint nibbleIdx = 0; nibbleIdx < 8; nibbleIdx++) {
            uint elementIdx = packedIdx * 8 + nibbleIdx;
            // Never read past the end of the input row if packedSize * 8
            // exceeds inFeatures.
            if (elementIdx >= inFeatures) break;

            uint nibble = (packed >> (nibbleIdx * 4)) & 0xF;

            // Quantization group of this element, clamped to the last group.
            uint group = elementIdx / groupSize;
            if (group >= numGroups) group = numGroups - 1;

            float scale = scales[qBase + group];
            float bias = biases[qBase + group];
            float dequantized = float(nibble) * scale + bias;

            sum += input[xBase + elementIdx] * dequantized;
        }
    }

    output[p * outFeatures + of] = sum;
}
// ── Matmul for F32 weights (not quantized) ──
// Matrix multiplication with unquantized weights: output = input @ weight^T
// input: [M, K], weight: [N, K] (stored row-major), output: [M, N]
//
// The 1-D thread index is a flat index into the [M, N] output, so thread `id`
// computes output[id / N, id % N]. For the single-token case (M = 1) this is
// exactly the previous behaviour: thread `id` computes column `id` of row 0.
// M was ignored before, so M == 0 is treated as one row to keep such callers working.
kernel void matmul_f32(
    device const float *input   [[buffer(0)]],  // [M, K]
    device const float *weight  [[buffer(1)]],  // [N, K]
    device float *output        [[buffer(2)]],  // [M, N]
    constant uint &M            [[buffer(3)]],
    constant uint &K            [[buffer(4)]],
    constant uint &N            [[buffer(5)]],
    uint id [[thread_position_in_grid]]
) {
    uint rowCount = max(M, 1u);

    // Also rejects every thread when N == 0, so the division below is safe.
    if (id >= rowCount * N) return;

    uint row = id / N;
    uint col = id - row * N;

    uint xBase = row * K;
    uint wBase = col * K;

    float sum = 0.0f;
    for (uint k = 0; k < K; k++) {
        sum += input[xBase + k] * weight[wBase + k];
    }
    output[row * N + col] = sum;
}
// ═══════════════════════════════════════════════════════════════
// Batch Metal Kernels - Process multiple tokens simultaneously
// ═══════════════════════════════════════════════════════════════
// Batch quantized matmul - process several tokens against shared 8-bit weights.
//   batchOutput[b, o] = biases[o] + sum_i batchInput[b, i] * (weights[o, i] - 128) * scales[o, i / groupSize]
// Weights are one byte per element, stored with an offset of 128.
// Grid: gid.x = batch index, gid.y = output index (gid.z unused).
//
// The previous version loaded a single byte into a 32-bit value and then
// extracted bits 8..31 from it, so three of every four weights were read as 0
// (dequantized to -128). It also read input[i+1..i+3] unconditionally, which
// runs past the row when inDim is not a multiple of 4. Each byte is now read
// individually, matching batch_layer_quantized_matmul.
kernel void quantized_matmul_batch(
    device float* batchInput    [[buffer(0)]],  // [batchSize, inDim]
    device uint8_t* weights     [[buffer(1)]],  // [outDim, inDim], one byte per weight
    device float* scales        [[buffer(2)]],  // [outDim, groups]
    device float* biases        [[buffer(3)]],  // [outDim]
    device float* batchOutput   [[buffer(4)]],  // [batchSize, outDim]
    constant uint32_t& inDim    [[buffer(5)]],
    constant uint32_t& outDim   [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    uint batchIdx = gid.x;
    uint outIdx = gid.y;
    if (batchIdx >= batchSize || outIdx >= outDim) return;

    device float* input = batchInput + batchIdx * inDim;
    uint weightBase = outIdx * inDim;
    uint groupIdx = outIdx * (inDim / groupSize);   // first scale of this output row

    float sum = biases[outIdx];
    for (uint i = 0; i < inDim; i++) {
        uint8_t w = weights[weightBase + i];
        float scale = scales[groupIdx + i / groupSize];
        // uint8_t promotes to int, so (w - 128) spans [-128, 127].
        sum += input[i] * (w - 128) * scale;
    }

    batchOutput[batchIdx * outDim + outIdx] = sum;
}
// RMS normalisation over a batch of rows sharing one weight vector:
//   batchOutput[b, e] = batchInput[b, e] / sqrt(mean_i(batchInput[b, i]^2) + eps) * weights[e]
// Grid: gid.x = batch index, gid.y = element index (gid.z unused).
kernel void rms_norm_batch(
    device float* batchInput    [[buffer(0)]],  // [batchSize, N]
    device float* weights       [[buffer(1)]],  // [N]
    device float* batchOutput   [[buffer(2)]],  // [batchSize, N]
    constant uint32_t& N        [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint32_t& batchSize [[buffer(5)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint b = gid.x;
    const uint e = gid.y;
    if (b >= batchSize || e >= N) return;

    const uint base = b * N;

    float sumSq = 0.0f;
    for (uint i = 0; i < N; ++i) {
        sumSq += batchInput[base + i] * batchInput[base + i];
    }

    const float rms = sqrt(sumSq / float(N) + eps);
    batchOutput[base + e] = batchInput[base + e] / rms * weights[e];
}
// Batch sliding-window attention (simplified - for demonstration).
// A full implementation would require complex KV cache management.
//
// For batch entry b at position pos = positions[b], attends to cache entries
// t in [max(pos - windowSize, 0), pos). The entry at `pos` itself is not read.
// Cache layout per position t: nKvHeads key vectors followed by nKvHeads value vectors.
// Query heads map onto KV heads in contiguous groups of nHeads / nKvHeads.
// Grid: (batch index, head index, component index).
//
// `pos - windowSize` is unsigned and used to wrap whenever pos < windowSize,
// which left the window empty and zeroed the output for every early token.
// The window start is now clamped explicitly.
kernel void sliding_attention_batch(
    device float* batchQuery    [[buffer(0)]],  // [batchSize, nHeads, headDim]
    device float* kvCache       [[buffer(1)]],  // [maxSeqLen, 2, nKvHeads, headDim]
    device float* batchOutput   [[buffer(2)]],  // [batchSize, nHeads, headDim]
    constant uint32_t* positions [[buffer(3)]], // [batchSize]
    constant uint32_t& nHeads   [[buffer(4)]],
    constant uint32_t& nKvHeads [[buffer(5)]],
    constant uint32_t& headDim  [[buffer(6)]],
    constant uint32_t& batchSize [[buffer(7)]],
    constant uint32_t& windowSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    uint batchIdx = gid.x;
    uint headIdx = gid.y;
    uint dimIdx = gid.z;
    if (batchIdx >= batchSize || headIdx >= nHeads || dimIdx >= headDim) return;

    uint pos = positions[batchIdx];
    uint kvHeadIdx = headIdx / (nHeads / nKvHeads);
    device float* query = batchQuery + batchIdx * nHeads * headDim + headIdx * headDim;

    // Clamp instead of relying on max(0u, pos - windowSize), which cannot
    // detect unsigned wrap-around.
    uint start = (pos > windowSize) ? (pos - windowSize) : 0u;
    uint end = pos;

    uint posStride = 2 * nKvHeads * headDim;    // floats per cached position
    uint keyOff = kvHeadIdx * headDim;
    uint valueOff = nKvHeads * headDim + kvHeadIdx * headDim;
    float norm = sqrt(float(headDim));

    // Pass 1: maximum score for a stable softmax.
    float maxScore = -1e10f;
    for (uint t = start; t < end; t++) {
        device float* key = kvCache + t * posStride + keyOff;
        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += query[d] * key[d];
        }
        score /= norm;
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax denominator.
    float expSum = 0.0f;
    for (uint t = start; t < end; t++) {
        device float* key = kvCache + t * posStride + keyOff;
        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += query[d] * key[d];
        }
        score /= norm;
        expSum += exp(score - maxScore);
    }

    // Pass 3: probability-weighted sum of the value component. With an empty
    // window (pos == 0) the loop does not run and the output stays 0.
    float output = 0.0f;
    for (uint t = start; t < end; t++) {
        device float* key = kvCache + t * posStride + keyOff;
        device float* value = kvCache + t * posStride + valueOff;
        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += query[d] * key[d];
        }
        score /= norm;
        float weight = exp(score - maxScore) / expSum;
        output += weight * value[dimIdx];
    }

    batchOutput[batchIdx * nHeads * headDim + headIdx * headDim + dimIdx] = output;
}
// ═══════════════════════════════════════════════════════════════
// Batch Layer Processing Kernels
// Process entire layer for multiple tokens simultaneously
// ═══════════════════════════════════════════════════════════════
// Batch RMS norm for the layer input, with weights shared across the batch:
//   out[b, e] = in[b, e] / sqrt(mean_i(in[b, i]^2) + eps) * weights[e]
// Operates on [batchSize, hiddenSize].
// Grid: gid.x = batch index, gid.y = element index (gid.z unused).
kernel void batch_layer_rms_norm(
    device float* batchInput    [[buffer(0)]],  // [batchSize, hiddenSize]
    device float* weights       [[buffer(1)]],  // [hiddenSize]
    device float* batchOutput   [[buffer(2)]],  // [batchSize, hiddenSize]
    constant uint32_t& hiddenSize [[buffer(3)]],
    constant float& eps         [[buffer(4)]],
    constant uint32_t& batchSize [[buffer(5)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint b = gid.x;
    const uint e = gid.y;
    if (b >= batchSize || e >= hiddenSize) return;

    const uint base = b * hiddenSize;

    // Mean square of the whole row; every thread of the row recomputes it.
    float sumSq = 0.0f;
    for (uint i = 0; i < hiddenSize; ++i) {
        sumSq += batchInput[base + i] * batchInput[base + i];
    }

    const float rms = sqrt(sumSq / float(hiddenSize) + eps);
    batchOutput[base + e] = batchInput[base + e] / rms * weights[e];
}
// Batch quantized matmul for layer projections, with shared 8-bit weights:
//   out[b, o] = biases[o] + sum_i in[b, i] * (weights[o, i] - 128) * scales[o, i / groupSize]
// Weights are one byte per element, stored with an offset of 128.
// Grid: gid.x = batch index, gid.y = output index (gid.z unused).
kernel void batch_layer_quantized_matmul(
    device float* batchInput    [[buffer(0)]],  // [batchSize, inDim]
    device uint8_t* weights     [[buffer(1)]],  // [outDim, inDim] packed
    device float* scales        [[buffer(2)]],  // [outDim, groups]
    device float* biases        [[buffer(3)]],  // [outDim]
    device float* batchOutput   [[buffer(4)]],  // [batchSize, outDim]
    constant uint32_t& inDim    [[buffer(5)]],
    constant uint32_t& outDim   [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint b = gid.x;
    const uint o = gid.y;
    if (b >= batchSize || o >= outDim) return;

    const uint xBase = b * inDim;
    const uint wBase = o * inDim;
    const uint sBase = o * (inDim / groupSize);   // first scale of this output row

    float acc = biases[o];
    for (uint i = 0; i < inDim; ++i) {
        const uint8_t q = weights[wBase + i];
        const float groupScale = scales[sBase + i / groupSize];
        // uint8_t promotes to int, so (q - 128) spans [-128, 127].
        acc += batchInput[xBase + i] * (q - 128) * groupScale;
    }

    batchOutput[b * outDim + o] = acc;
}
// Batch elementwise add for residual connections over [batchSize, size]:
//   batchOutput[b, e] = batchA[b, e] + batchB[b, e]
// Grid: gid.x = batch index, gid.y = element index.
kernel void batch_eltwise_add(
    device float* batchA        [[buffer(0)]],  // [batchSize, size]
    device float* batchB        [[buffer(1)]],  // [batchSize, size]
    device float* batchOutput   [[buffer(2)]],  // [batchSize, size]
    constant uint32_t& size     [[buffer(3)]],
    constant uint32_t& batchSize [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]])
{
    const uint b = gid.x;
    const uint e = gid.y;
    if (b < batchSize && e < size) {
        const uint at = b * size + e;
        batchOutput[at] = batchA[at] + batchB[at];
    }
}
// Batch gated FFN: fused gate and up projections followed by SiLU(gate) * up.
// Both projections use 8-bit weights stored with an offset of 128 and one
// scale per group of `groupSize` inputs:
//   proj[o] = bias[o] + sum_i in[i] * (W[o, i] - 128) * scale[o, i / groupSize]
// Output is [batchSize, intermediateSize].
// Grid: gid.x = batch index, gid.y = intermediate index (gid.z unused).
kernel void batch_fused_gate_up(
    device float* batchInput    [[buffer(0)]],  // [batchSize, hiddenSize]
    device uint8_t* gateWeights [[buffer(1)]],  // [intermediateSize, hiddenSize]
    device float* gateScales    [[buffer(2)]],
    device float* gateBiases    [[buffer(3)]],
    device uint8_t* upWeights   [[buffer(4)]],  // [intermediateSize, hiddenSize]
    device float* upScales      [[buffer(5)]],
    device float* upBiases      [[buffer(6)]],
    device float* batchOutput   [[buffer(7)]],  // [batchSize, intermediateSize]
    constant uint32_t& hiddenSize [[buffer(8)]],
    constant uint32_t& intermediateSize [[buffer(9)]],
    constant uint32_t& groupSize [[buffer(10)]],
    constant uint32_t& batchSize [[buffer(11)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint b = gid.x;
    const uint o = gid.y;
    if (b >= batchSize || o >= intermediateSize) return;

    const uint xBase = b * hiddenSize;
    const uint wBase = o * hiddenSize;
    const uint sBase = o * (hiddenSize / groupSize);   // same layout for gate and up scales

    // Both projections walk the same input row, so accumulate them side by
    // side; the two accumulators stay independent.
    float gateAcc = gateBiases[o];
    float upAcc = upBiases[o];
    for (uint i = 0; i < hiddenSize; ++i) {
        const float x = batchInput[xBase + i];
        const uint grp = i / groupSize;

        const uint8_t gq = gateWeights[wBase + i];
        gateAcc += x * (gq - 128) * gateScales[sBase + grp];

        const uint8_t uq = upWeights[wBase + i];
        upAcc += x * (uq - 128) * upScales[sBase + grp];
    }

    // Fused activation: gate * sigmoid(gate) * up
    const float sig = 1.0f / (1.0f + exp(-gateAcc));
    batchOutput[b * intermediateSize + o] = gateAcc * sig * upAcc;
}
// Batch down projection (FFN output) with 8-bit weights stored with an offset of 128:
//   out[b, o] = downBiases[o] + sum_i inter[b, i] * (W[o, i] - 128) * downScales[o, i / groupSize]
// Output is [batchSize, hiddenSize].
// Grid: gid.x = batch index, gid.y = output index (gid.z unused).
kernel void batch_down_projection(
    device float* batchInter    [[buffer(0)]],  // [batchSize, intermediateSize]
    device uint8_t* downWeights [[buffer(1)]],  // [hiddenSize, intermediateSize]
    device float* downScales    [[buffer(2)]],
    device float* downBiases    [[buffer(3)]],
    device float* batchOutput   [[buffer(4)]],  // [batchSize, hiddenSize]
    constant uint32_t& hiddenSize [[buffer(5)]],
    constant uint32_t& intermediateSize [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint b = gid.x;
    const uint o = gid.y;
    if (b >= batchSize || o >= hiddenSize) return;

    const uint xBase = b * intermediateSize;
    const uint wBase = o * intermediateSize;
    const uint sBase = o * (intermediateSize / groupSize);

    float acc = downBiases[o];
    for (uint i = 0; i < intermediateSize; ++i) {
        const uint8_t q = downWeights[wBase + i];
        const float groupScale = downScales[sBase + i / groupSize];
        acc += batchInter[xBase + i] * (q - 128) * groupScale;
    }

    batchOutput[b * hiddenSize + o] = acc;
}
// ══════════════════════════════════════════════════════════════════
// Batch Embedding Lookup - Process multiple token embeddings in parallel
// Eliminates sequential waitUntilCompleted() bottleneck
// ══════════════════════════════════════════════════════════════════
// Batch version of dequantize_row - gathers and dequantizes one embedding row
// per token in a single dispatch.
//
// For every (batchIdx, colIdx) pair:
//   tokenId = tokenIds[batchIdx]
//   g       = colIdx / groupSize
//   out[batchIdx, colIdx] = q * s[tokenId, g] + b[tokenId, g]
// where q is the 4-bit code of column colIdx in row tokenId.
//
// Packing: eight 4-bit codes per 32-bit word, lowest nibble first, so column c
// lives in word (c / 8) at bit offset 4 * (c % 8). The word index is taken
// directly from colIdx, so it does not depend on groupSize being a multiple of 8.
//
// Dispatch: gid.x = position in the batch, gid.y = embedding column. Threads
// outside [batchSize, nCols] exit early.
//
// Preconditions:
//   - groupSize must be non-zero; a zero groupSize makes the thread return
//     without writing instead of dividing by zero.
//   - nCols is expected to be a multiple of 8 and of groupSize, since the row
//     strides (nCols / 8) and (nCols / groupSize) truncate.
//   - tokenIds entries are not range-checked (no vocabulary size is passed in);
//     the caller must supply valid row indices.
kernel void dequantize_row_batch(
    device const uint *w [[buffer(0)]], // [vocabSize, nCols/8]
    device const float *s [[buffer(1)]], // [vocabSize, numGroups]
    device const float *b [[buffer(2)]], // [vocabSize, numGroups]
    device const uint *tokenIds [[buffer(3)]], // [batchSize] - which rows to lookup
    device float *out [[buffer(4)]], // [batchSize, nCols]
    constant uint &nCols [[buffer(5)]],
    constant uint &batchSize [[buffer(6)]],
    constant uint &groupSize [[buffer(7)]],
    uint3 gid [[thread_position_in_grid]]
) {
    const uint batchIdx = gid.x; // Which token in batch
    const uint colIdx = gid.y;   // Which column in embedding
    if (batchIdx >= batchSize || colIdx >= nCols) return;
    if (groupSize == 0) return;  // avoid integer division by zero below

    const uint tokenId = tokenIds[batchIdx];
    const uint wordsPerRow = nCols / 8;
    const uint groupsPerRow = nCols / groupSize;

    // Fetch the packed word holding this column and isolate its nibble.
    const uint packed = w[tokenId * wordsPerRow + (colIdx >> 3)];
    const uint qval = (packed >> ((colIdx & 7u) << 2)) & 0xFu;

    // Scale and bias share the same (row, group) index.
    const uint sbIdx = tokenId * groupsPerRow + colIdx / groupSize;

    // Write to batch output buffer [batchIdx, colIdx]
    out[batchIdx * nCols + colIdx] = float(qval) * s[sbIdx] + b[sbIdx];
}
// Batch version with scale applied (fused dequantize + scale).
//
// For every (batchIdx, colIdx) pair:
//   tokenId = tokenIds[batchIdx]
//   g       = colIdx / groupSize
//   out[batchIdx, colIdx] = (q * s[tokenId, g] + b[tokenId, g]) * embedScale
// where q is the 4-bit code of column colIdx in row tokenId.
//
// Packing: eight 4-bit codes per 32-bit word, lowest nibble first, so column c
// lives in word (c / 8) at bit offset 4 * (c % 8). The word index is taken
// directly from colIdx, so it does not depend on groupSize being a multiple of 8.
//
// Dispatch: gid.x = position in the batch, gid.y = embedding column. Threads
// outside [batchSize, nCols] exit early.
//
// Preconditions:
//   - groupSize must be non-zero; a zero groupSize makes the thread return
//     without writing instead of dividing by zero.
//   - nCols is expected to be a multiple of 8 and of groupSize, since the row
//     strides (nCols / 8) and (nCols / groupSize) truncate.
//   - tokenIds entries are not range-checked (no vocabulary size is passed in);
//     the caller must supply valid row indices.
kernel void dequantize_row_batch_scaled(
    device const uint *w [[buffer(0)]], // [vocabSize, nCols/8]
    device const float *s [[buffer(1)]], // [vocabSize, numGroups]
    device const float *b [[buffer(2)]], // [vocabSize, numGroups]
    device const uint *tokenIds [[buffer(3)]], // [batchSize] - which rows to lookup
    device float *out [[buffer(4)]], // [batchSize, nCols]
    constant uint &nCols [[buffer(5)]],
    constant uint &batchSize [[buffer(6)]],
    constant uint &groupSize [[buffer(7)]],
    constant float &embedScale [[buffer(8)]], // Global embedding scale
    uint3 gid [[thread_position_in_grid]]
) {
    const uint batchIdx = gid.x;
    const uint colIdx = gid.y;
    if (batchIdx >= batchSize || colIdx >= nCols) return;
    if (groupSize == 0) return;  // avoid integer division by zero below

    const uint tokenId = tokenIds[batchIdx];
    const uint wordsPerRow = nCols / 8;
    const uint groupsPerRow = nCols / groupSize;

    // Fetch the packed word holding this column and isolate its nibble.
    const uint packed = w[tokenId * wordsPerRow + (colIdx >> 3)];
    const uint qval = (packed >> ((colIdx & 7u) << 2)) & 0xFu;

    // Scale and bias share the same (row, group) index.
    const uint sbIdx = tokenId * groupsPerRow + colIdx / groupSize;

    // Apply embedding scale (fused)
    out[batchIdx * nCols + colIdx] = (float(qval) * s[sbIdx] + b[sbIdx]) * embedScale;
}