8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2053 lines
77 KiB
Metal
2053 lines
77 KiB
Metal
#include <metal_stdlib>
|
||
using namespace metal;
|
||
|
||
// ═══════════════════════════════════════════════
|
||
// E4B Inference Kernels
|
||
// ═══════════════════════════════════════════════
|
||
|
||
// ── Constants ──────────────────────────────────
// Quantization group size of the MLX 4-bit E4B checkpoints: every run of this
// many consecutive weights in a row shares one (scale, bias) pair.
// NOTE(review): the quantized kernels in this section receive `groupSize` as a
// runtime buffer argument instead of reading this constant — confirm the host
// passes the same value.
constant uint GROUP_SIZE = 64u;
|
||
|
||
// ── 1. RMSNorm ─────────────────────────────────
// y[i] = x[i] * rsqrt(mean(x^2) + eps) * w[i]
//
// One thread per output element; threads with id >= N exit immediately.
// Every thread re-scans the whole input to form the sum of squares, so no
// threadgroup memory or barriers are needed (at the cost of O(N) reads per
// thread). When `w` is null the weight multiply is skipped.
// NOTE: NOT safe for in-place use when dispatched with multiple threadgroups:
// other threads may still be reading x while y[id] is written.
// The Swift layer always passes separate input/output buffers.
kernel void rms_norm(
    device const float *x   [[buffer(0)]], // [N]
    device const float *w   [[buffer(1)]], // [N] weight (can be null)
    device float       *y   [[buffer(2)]], // [N]
    constant uint      &N   [[buffer(3)]],
    constant float     &eps [[buffer(4)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < N) {
        float sumSq = 0.0;
        for (uint j = 0; j < N; ++j) {
            sumSq += x[j] * x[j];
        }
        const float invRms = rsqrt(sumSq / float(N) + eps);
        const float normed = x[id] * invRms;
        y[id] = w ? normed * w[id] : normed;
    }
}
|
||
|
||
// ── Sampling Kernels ────────────────────────────

// Softmax: probs[i] = exp(logits[i] - max) / sum(exp)
// One thread per output element; threads with id >= N exit.
// Each thread scans all N logits twice on its own (once for the maximum, once
// for the sum of shifted exponentials), so no threadgroup memory or barriers
// are used; per-thread cost is O(N). Subtracting the maximum keeps exp() from
// overflowing.
kernel void softmax(
    device const float *logits [[buffer(0)]], // [N]
    device float       *probs  [[buffer(1)]], // [N]
    constant uint      &N      [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= N) {
        return;
    }

    float peak = -INFINITY;
    for (uint j = 0; j < N; ++j) {
        peak = max(peak, logits[j]);
    }

    float denom = 0.0;
    for (uint j = 0; j < N; ++j) {
        denom += exp(logits[j] - peak);
    }

    probs[id] = exp(logits[id] - peak) / denom;
}
|
||
|
||
// Temperature scaling, in place: logits[i] /= temperature.
// One thread per element; threads with id >= N do nothing.
// NOTE(review): temperature == 0 produces ±inf/NaN — presumably the host uses
// greedy argmax instead of dispatching this kernel in that case; confirm.
kernel void temperature_scale(
    device float   *logits      [[buffer(0)]], // [N] in-place
    constant uint  &N           [[buffer(1)]],
    constant float &temperature [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < N) {
        logits[id] /= temperature;
    }
}
|
||
|
||
// Argmax: find the index of the maximum value.
//
// One thread per logit; threads with id >= N exit.
// Host contract: *bestVal and *bestIdx are initialised before dispatch
// (presumably to -INFINITY and 0 — confirm against the Swift caller); the
// kernel only ever raises them.
//
// *bestVal receives the running maximum through a float CAS loop.
// *bestIdx is maintained by its own CAS loop that compares against
// logits[*bestIdx] directly, so the stored index always names an element that
// holds the winning value. (Storing the index with a separate plain store
// after a successful value CAS lets a delayed thread overwrite a newer
// winner's index.) Ties resolve to the lowest index, matching a CPU argmax;
// NaN logits never win.
kernel void argmax(
    device const float  *logits  [[buffer(0)]], // [N]
    device atomic_uint  *bestIdx [[buffer(1)]], // single atomic uint
    device atomic_float *bestVal [[buffer(2)]], // single atomic float
    constant uint       &N       [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= N) return;

    const float val = logits[id];
    if (isnan(val)) return; // NaN compares false against everything

    // 1) Running maximum value.
    float seenVal = atomic_load_explicit(bestVal, memory_order_relaxed);
    while (val > seenVal) {
        if (atomic_compare_exchange_weak_explicit(
                bestVal, &seenVal, val,
                memory_order_relaxed, memory_order_relaxed)) {
            break;
        }
        // seenVal now holds the value another thread stored; retry.
    }

    // 2) Running argmax index. `logits` is read-only during the dispatch, so
    // logits[curIdx] is the value the currently stored index stands for.
    uint curIdx = atomic_load_explicit(bestIdx, memory_order_relaxed);
    while (true) {
        // Tolerate an out-of-range initial sentinel in *bestIdx.
        const bool curValid = curIdx < N;
        float curVal = curValid ? logits[curIdx] : -INFINITY;
        if (isnan(curVal)) curVal = -INFINITY;

        const bool wins = (val > curVal) ||
                          (curValid && val == curVal && id < curIdx);
        if (!wins) break;

        if (atomic_compare_exchange_weak_explicit(
                bestIdx, &curIdx, id,
                memory_order_relaxed, memory_order_relaxed)) {
            break;
        }
        // curIdx now holds the index another thread stored; re-evaluate.
    }
}
|
||
|
||
// Top-k mask — placeholder, currently a NO-OP.
//
// Intended contract: set logits outside the top-k to -inf, in place.
// The body does not read or write `logits`. Per the notes that accompanied
// this kernel, top-k filtering is done on the CPU in Swift for efficiency (a
// simple O(N*k) GPU approach only suits small vocabularies; large ones would
// need a proper GPU sort). The kernel only reserves the name and argument
// layout.
kernel void top_k_mask(
    device float  *logits [[buffer(0)]], // [N] in-place (currently untouched)
    constant uint &N      [[buffer(1)]],
    constant uint &k      [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < N) {
        // Intentionally empty: top-k is applied on the CPU.
    }
}
|
||
|
||
// ── 1b. Grouped RMSNorm (per-head norm) ──────────
// Groups are contiguous blocks of `groupSize` elements; each group computes
// its own RMS independently:
//   y[i] = x[i] * rsqrt(mean over i's group of x^2 + eps) * w[i % groupSize]
// Weight buffer layout: [groupSize], shared by every group (same weight for
// each head). When `w` is null the weight multiply is skipped.
// The mean is taken over the elements actually present in the group, so a
// trailing partial group (N not a multiple of groupSize) is normalised by its
// own length rather than by groupSize.
// One thread per element. Precondition: groupSize > 0 — otherwise threads
// exit without writing (instead of dividing by zero).
kernel void rms_norm_grouped(
    device const float *x         [[buffer(0)]], // [N]
    device const float *w         [[buffer(1)]], // [groupSize] weight shared by all groups
    device float       *y         [[buffer(2)]], // [N]
    constant uint      &N         [[buffer(3)]],
    constant uint      &groupSize [[buffer(4)]],
    constant float     &eps       [[buffer(5)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= N || groupSize == 0) return;

    const uint g = id / groupSize;
    const uint start = g * groupSize;
    const uint end = min(start + groupSize, N);
    const uint wIdx = id - start; // weight index within group (== id % groupSize)

    float ss = 0.0;
    for (uint i = start; i < end; i++) ss += x[i] * x[i];

    // end > start because start <= id < end.
    const float rms = rsqrt(ss / float(end - start) + eps);
    y[id] = (w ? x[id] * rms * w[wIdx] : x[id] * rms);
}
|
||
|
||
// ── 2. GELU Approximation ──────────────────────
// gelu(x) ≈ x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Element-wise; one thread per element, threads with id >= N exit.
// For |x| > 10 the tanh term is saturated, so the result is taken directly as
// max(x, 0); this also keeps x^3 away from very large magnitudes.
kernel void gelu_approx(
    device const float *x [[buffer(0)]],
    device float       *y [[buffer(1)]],
    constant uint      &N [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= N) {
        return;
    }

    const float xv = x[id];
    const float sqrt2OverPi = M_SQRT2_F * M_2_SQRTPI_F * 0.5; // sqrt(2/pi)

    float result;
    if (fabs(xv) > 10.0) {
        result = (xv > 0) ? xv : 0.0;
    } else {
        const float cube = xv * xv * xv;
        result = 0.5 * xv * (1.0 + tanh(sqrt2OverPi * (xv + 0.044715 * cube)));
    }
    y[id] = result;
}
|
||
|
||
// ── 3. Quantized MatMul (MLX U32-packed format) ─
// out[outDim] = dequant(weight[outDim, inDim/8]) @ x[inDim]
// scales/biases: [outDim, inDim/groupSize]
//
// One thread per output row; threads with id >= outDim exit.
// Each uint32 holds eight 4-bit codes, lowest nibble first; a weight is
// dequantised as code * scale + bias using its group's scale and bias.
// Each packed word is fetched once and reused for its 8 nibbles; the
// accumulation order is unchanged (j ascending within each group).
// Preconditions (not checked): groupSize > 0, groupSize % 8 == 0 and
// inDim % groupSize == 0 — any tail past the last full group is ignored.
kernel void quantized_matmul(
    device const float *x         [[buffer(0)]],
    device const uint  *w         [[buffer(1)]],
    device const float *s         [[buffer(2)]],
    device const float *b         [[buffer(3)]],
    device float       *out       [[buffer(4)]],
    constant uint      &inDim     [[buffer(5)]],
    constant uint      &outDim    [[buffer(6)]],
    constant uint      &groupSize [[buffer(7)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    const uint numGroups = inDim / groupSize;
    const uint packedPerOut = inDim / 8;      // 8 × 4-bit per U32
    const uint wordsPerGroup = groupSize / 8;

    device const uint  *wRow = w + id * packedPerOut;
    device const float *sRow = s + id * numGroups;
    device const float *bRow = b + id * numGroups;

    float sum = 0.0;
    for (uint g = 0; g < numGroups; g++) {
        const float scale = sRow[g];
        const float bias = bRow[g];
        device const uint  *wGroup = wRow + g * wordsPerGroup;
        device const float *xGroup = x + g * groupSize;

        uint word = 0;
        for (uint j = 0; j < groupSize; j++) {
            // Load a new packed word only when entering its 8-code span.
            if ((j & 7u) == 0) word = wGroup[j >> 3];
            const uint qval = (word >> ((j & 7u) * 4u)) & 0xFu;
            const float dq = float(qval) * scale + bias;
            sum += dq * xGroup[j];
        }
    }

    out[id] = sum;
}
|
||
|
||
// ── 3b. Quantized MatMul with fused GELU ───────
// out = gelu(quantized_matmul(x, w, s, b))
//
// One thread per output row; threads with id >= outDim exit.
// Weights are 4-bit MLX-packed: eight codes per uint32, lowest nibble first,
// dequantised as code * scale + bias per group of `groupSize` inputs.
// Each packed word is fetched once and reused for its 8 nibbles; the
// accumulation order is unchanged. GELU uses the tanh approximation and is
// taken as max(v, 0) when |v| > 10 (tanh saturated).
// Preconditions (not checked): groupSize > 0, groupSize % 8 == 0 and
// inDim % groupSize == 0.
kernel void quantized_matmul_gelu(
    device const float *x         [[buffer(0)]],
    device const uint  *w         [[buffer(1)]],
    device const float *s         [[buffer(2)]],
    device const float *b         [[buffer(3)]],
    device float       *out       [[buffer(4)]],
    constant uint      &inDim     [[buffer(5)]],
    constant uint      &outDim    [[buffer(6)]],
    constant uint      &groupSize [[buffer(7)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    const uint numGroups = inDim / groupSize;
    const uint packedPerOut = inDim / 8;      // 8 × 4-bit per U32
    const uint wordsPerGroup = groupSize / 8;

    device const uint  *wRow = w + id * packedPerOut;
    device const float *sRow = s + id * numGroups;
    device const float *bRow = b + id * numGroups;

    float sum = 0.0;
    for (uint g = 0; g < numGroups; g++) {
        const float scale = sRow[g];
        const float bias = bRow[g];
        device const uint  *wGroup = wRow + g * wordsPerGroup;
        device const float *xGroup = x + g * groupSize;

        uint word = 0;
        for (uint j = 0; j < groupSize; j++) {
            if ((j & 7u) == 0) word = wGroup[j >> 3];
            const uint qval = (word >> ((j & 7u) * 4u)) & 0xFu;
            const float dq = float(qval) * scale + bias;
            sum += dq * xGroup[j];
        }
    }

    // GELU (tanh approximation).
    const float c = M_SQRT2_F * M_2_SQRTPI_F * 0.5; // sqrt(2/pi)
    if (fabs(sum) > 10.0) {
        out[id] = sum > 0 ? sum : 0.0;
    } else {
        const float v3 = sum * sum * sum;
        out[id] = 0.5 * sum * (1.0 + tanh(c * (sum + 0.044715 * v3)));
    }
}
|
||
|
||
// ── 4. Fused quantized gate/up projection (GELU-gated MLP input) ─
// out[id] = gelu(gate_row_id · x) * (up_row_id · x), for id < outDim.
// The element-wise product is computed here and written to out[0:outDim].
//
// Gate and up weights are 4-bit MLX-packed: eight codes per uint32, lowest
// nibble first, dequantised as code * scale + bias per group of `groupSize`
// inputs. Both projections are accumulated in one pass over x, and each
// packed word is fetched once for its 8 codes.
// No magnitude clamping is applied: activations keep their true values
// (clamping the product to a small range such as ±10 distorts the MLP
// output; float32 cannot overflow at such magnitudes). A NaN/Inf product is
// still written as 0 as a last-resort guard.
// Preconditions (not checked): groupSize > 0, groupSize % 8 == 0 and
// inDim % groupSize == 0.
kernel void quantized_matmul_gate_up(
    device const float *x         [[buffer(0)]],
    // gate projection
    device const uint  *w_gate    [[buffer(1)]],
    device const float *s_gate    [[buffer(2)]],
    device const float *b_gate    [[buffer(3)]],
    // up projection
    device const uint  *w_up      [[buffer(4)]],
    device const float *s_up      [[buffer(5)]],
    device const float *b_up      [[buffer(6)]],
    device float       *out       [[buffer(7)]],
    constant uint      &inDim     [[buffer(8)]],
    constant uint      &outDim    [[buffer(9)]],
    constant uint      &groupSize [[buffer(10)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    const uint numGroups = inDim / groupSize;
    const uint packedPerOut = inDim / 8;      // 8 × 4-bit per U32
    const uint wordsPerGroup = groupSize / 8;

    device const uint  *gRow = w_gate + id * packedPerOut;
    device const uint  *uRow = w_up + id * packedPerOut;
    device const float *gScales = s_gate + id * numGroups;
    device const float *gBiases = b_gate + id * numGroups;
    device const float *uScales = s_up + id * numGroups;
    device const float *uBiases = b_up + id * numGroups;

    float gateSum = 0.0;
    float upSum = 0.0;
    for (uint g = 0; g < numGroups; g++) {
        const float gScale = gScales[g];
        const float gBias = gBiases[g];
        const float uScale = uScales[g];
        const float uBias = uBiases[g];
        device const uint  *gWords = gRow + g * wordsPerGroup;
        device const uint  *uWords = uRow + g * wordsPerGroup;
        device const float *xGroup = x + g * groupSize;

        uint gWord = 0;
        uint uWord = 0;
        for (uint j = 0; j < groupSize; j++) {
            if ((j & 7u) == 0) {
                gWord = gWords[j >> 3];
                uWord = uWords[j >> 3];
            }
            const uint shift = (j & 7u) * 4u;
            const float xv = xGroup[j];
            gateSum += (float((gWord >> shift) & 0xFu) * gScale + gBias) * xv;
            upSum   += (float((uWord >> shift) & 0xFu) * uScale + uBias) * xv;
        }
    }

    // GELU (tanh approximation) on the gate; taken as max(v, 0) for |v| > 10.
    const float c = M_SQRT2_F * M_2_SQRTPI_F * 0.5; // sqrt(2/pi)
    float gate;
    if (fabs(gateSum) > 10.0) {
        gate = gateSum > 0 ? gateSum : 0.0;
    } else {
        const float v3 = gateSum * gateSum * gateSum;
        gate = 0.5 * gateSum * (1.0 + tanh(c * (gateSum + 0.044715 * v3)));
    }

    float product = gate * upSum;
    if (isnan(product) || isinf(product)) product = 0.0;
    out[id] = product;
}
|
||
|
||
// ── 8-bit Fused Gate+Up Matmul ───────────────
// out[id] = gelu(gate_row_id · x) * (up_row_id · x), for id < outDim.
// Weights are 8-bit packed: four codes per uint32, lowest byte first
// (mask 0xFF), dequantised as code * scale + bias per group of `groupSize`
// inputs. Both projections are accumulated in one pass over x, and each
// packed word is fetched once for its 4 codes.
// No magnitude clamping is applied: activations keep their true values
// (clamping the product to a small range such as ±10 distorts the MLP
// output). A NaN/Inf product is still written as 0 as a last-resort guard.
// Preconditions (not checked): groupSize > 0, groupSize % 4 == 0 and
// inDim % groupSize == 0.
kernel void quantized_matmul_gate_up_8bit(
    device const float *x         [[buffer(0)]],
    device const uint  *w_gate    [[buffer(1)]],
    device const float *s_gate    [[buffer(2)]],
    device const float *b_gate    [[buffer(3)]],
    device const uint  *w_up      [[buffer(4)]],
    device const float *s_up      [[buffer(5)]],
    device const float *b_up      [[buffer(6)]],
    device float       *out       [[buffer(7)]],
    constant uint      &inDim     [[buffer(8)]],
    constant uint      &outDim    [[buffer(9)]],
    constant uint      &groupSize [[buffer(10)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    const uint numGroups = inDim / groupSize;
    const uint packedPerOut = inDim / 4;      // 4 × 8-bit per U32
    const uint wordsPerGroup = groupSize / 4;

    device const uint  *gRow = w_gate + id * packedPerOut;
    device const uint  *uRow = w_up + id * packedPerOut;
    device const float *gScales = s_gate + id * numGroups;
    device const float *gBiases = b_gate + id * numGroups;
    device const float *uScales = s_up + id * numGroups;
    device const float *uBiases = b_up + id * numGroups;

    float gateSum = 0.0;
    float upSum = 0.0;
    for (uint g = 0; g < numGroups; g++) {
        const float gScale = gScales[g];
        const float gBias = gBiases[g];
        const float uScale = uScales[g];
        const float uBias = uBiases[g];
        device const uint  *gWords = gRow + g * wordsPerGroup;
        device const uint  *uWords = uRow + g * wordsPerGroup;
        device const float *xGroup = x + g * groupSize;

        uint gWord = 0;
        uint uWord = 0;
        for (uint j = 0; j < groupSize; j++) {
            if ((j & 3u) == 0) {
                gWord = gWords[j >> 2];
                uWord = uWords[j >> 2];
            }
            const uint shift = (j & 3u) * 8u;
            const float xv = xGroup[j];
            gateSum += (float((gWord >> shift) & 0xFFu) * gScale + gBias) * xv;
            upSum   += (float((uWord >> shift) & 0xFFu) * uScale + uBias) * xv;
        }
    }

    // GELU (tanh approximation) on the gate; taken as max(v, 0) for |v| > 10.
    const float c = M_SQRT2_F * M_2_SQRTPI_F * 0.5; // sqrt(2/pi)
    float gate;
    if (fabs(gateSum) > 10.0) {
        gate = gateSum > 0 ? gateSum : 0.0;
    } else {
        const float v3 = gateSum * gateSum * gateSum;
        gate = 0.5 * gateSum * (1.0 + tanh(c * (gateSum + 0.044715 * v3)));
    }

    float product = gate * upSum;
    if (isnan(product) || isinf(product)) product = 0.0;
    out[id] = product;
}
|
||
|
||
// ── 5a. Half-split RoPE (Q only) ─────────────
// Gemma uses rotate_half style: pair i rotates elements (i, i + headDim/2)
// of each head, in place:
//   [v1, v2] -> [v1*cos(a) - v2*sin(a), v1*sin(a) + v2*cos(a)]
// with a = position * theta^(-2i/headDim) * scale^position.
// Grid: one thread per (head, pair), id = head * nPairs + pair, where
// nPairs = min(rotatedDim, headDim) / 2 (the clamp keeps i + headDim/2
// inside the head). Extra threads exit.
// NOTE(review): `scale` enters as pow(scale, position); with scale == 1 it
// has no effect, but any other value changes the frequency exponentially with
// position — confirm this is the intended RoPE scaling rule.
kernel void apply_rope_q(
    device float   *q          [[buffer(0)]], // [nHeads, headDim] — in-place
    constant uint  &nHeads     [[buffer(1)]],
    constant uint  &headDim    [[buffer(2)]],
    constant uint  &rotatedDim [[buffer(3)]],
    constant float &theta      [[buffer(4)]],
    constant float &scale      [[buffer(5)]],
    constant int   &position   [[buffer(6)]],
    uint id [[thread_position_in_grid]]
) {
    const uint halfDim = headDim / 2;
    const uint nPairs = min(rotatedDim, headDim) / 2;
    // Check before dividing: id / 0 is undefined.
    if (nPairs == 0) return;

    const uint head = id / nPairs;
    const uint pair = id % nPairs;
    if (head >= nHeads) return;

    // Half rotation: pair i corresponds to (i, i + halfDim)
    const uint d1 = pair;
    const uint d2 = pair + halfDim;

    const float freqBase = pow(theta, -2.0 * float(pair) / float(headDim));
    const float freq = freqBase * pow(scale, float(position));
    const float c = cos(float(position) * freq);
    const float s = sin(float(position) * freq);

    device float *h = q + head * headDim;
    const float v1 = h[d1];
    const float v2 = h[d2];
    h[d1] = v1 * c - v2 * s;
    h[d2] = v1 * s + v2 * c;
}
|
||
|
||
// ── 5b. Half-split RoPE (K only) ─────────────
// Gemma uses rotate_half style: pair i rotates elements (i, i + headDim/2)
// of each KV head, in place:
//   [v1, v2] -> [v1*cos(a) - v2*sin(a), v1*sin(a) + v2*cos(a)]
// with a = position * theta^(-2i/headDim) * scale^position.
// Grid: one thread per (kvHead, pair), id = kvHead * nPairs + pair, where
// nPairs = min(rotatedDim, headDim) / 2 (the clamp keeps i + headDim/2
// inside the head). Extra threads exit.
// NOTE(review): `scale` enters as pow(scale, position); with scale == 1 it
// has no effect, but any other value changes the frequency exponentially with
// position — confirm this is the intended RoPE scaling rule.
kernel void apply_rope_k(
    device float   *k          [[buffer(0)]], // [nKvHeads, headDim] — in-place
    constant uint  &nKvHeads   [[buffer(1)]],
    constant uint  &headDim    [[buffer(2)]],
    constant uint  &rotatedDim [[buffer(3)]],
    constant float &theta      [[buffer(4)]],
    constant float &scale      [[buffer(5)]],
    constant int   &position   [[buffer(6)]],
    uint id [[thread_position_in_grid]]
) {
    const uint halfDim = headDim / 2;
    const uint nPairs = min(rotatedDim, headDim) / 2;
    // Check before dividing: id / 0 is undefined.
    if (nPairs == 0) return;

    const uint kvHead = id / nPairs;
    const uint pair = id % nPairs;
    if (kvHead >= nKvHeads) return;

    // Half rotation: pair i corresponds to (i, i + halfDim)
    const uint d1 = pair;
    const uint d2 = pair + halfDim;

    const float freqBase = pow(theta, -2.0 * float(pair) / float(headDim));
    const float freq = freqBase * pow(scale, float(position));
    const float c = cos(float(position) * freq);
    const float s = sin(float(position) * freq);

    device float *h = k + kvHead * headDim;
    const float v1 = h[d1];
    const float v2 = h[d2];
    h[d1] = v1 * c - v2 * s;
    h[d2] = v1 * s + v2 * c;
}
|
||
|
||
// ── 6. Scaled Dot-Product Attention (Sliding, rotating buffer) ──
// O = softmax(Q K^T / sqrt(headDim)) V over the last min(offset + 1, windowSize)
// logical positions, with GQA. K/V live in a rotating buffer: logical
// position p is stored in slot p % windowSize.
// Grid: 2-D, gid.x = query head, gid.y = output dimension; one output
// element per thread. Scores are computed twice (max pass, then softmax +
// weighted-sum pass) so no per-thread score array is needed.
// GQA: query head h reads KV head h / (nHeads / nKvHeads), i.e. consecutive
// query heads share one KV head (the HF `repeat_kv` / MLX layout).
// Degenerate inputs (offset < 0, windowSize == 0, nKvHeads == 0) write 0
// instead of wrapping `offset + 1` or dividing 0 by 0.
// NOTE(review): scores are scaled by 1/sqrt(headDim) — confirm this matches
// the model's attention scaling. Text model has NO attention softcapping.
kernel void sliding_attention(
    device const float *q          [[buffer(0)]], // [nHeads, headDim]
    device const float *k          [[buffer(1)]], // [windowSize, nKvHeads, headDim]
    device const float *v          [[buffer(2)]], // [windowSize, nKvHeads, headDim]
    device float       *out        [[buffer(3)]], // [nHeads, headDim]
    constant uint      &nHeads     [[buffer(4)]],
    constant uint      &nKvHeads   [[buffer(5)]],
    constant uint      &headDim    [[buffer(6)]],
    constant uint      &windowSize [[buffer(7)]],
    constant int       &offset     [[buffer(8)]], // current logical position
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x;
    const uint dim = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    if (offset < 0 || windowSize == 0 || nKvHeads == 0) {
        out[head * headDim + dim] = 0.0;
        return;
    }

    const uint nRep = max(nHeads / nKvHeads, 1u);
    const uint kvHead = min(head / nRep, nKvHeads - 1);

    const uint seqLen = uint(offset) + 1;
    const uint actualWindow = min(seqLen, windowSize);
    const uint base = seqLen - actualWindow; // first logical position in the window

    const uint slotStride = nKvHeads * headDim; // elements per cache slot
    device const float *qRow = q + head * headDim;
    device const float *kHead = k + kvHead * headDim;
    device const float *vHead = v + kvHead * headDim;
    const float scale = 1.0 / sqrt(float(headDim));

    // Pass 1: find max score
    float maxScore = -INFINITY;
    for (uint t = 0; t < actualWindow; t++) {
        const uint slot = (base + t) % windowSize;
        device const float *kRow = kHead + slot * slotStride;
        float score = 0.0;
        for (uint d = 0; d < headDim; d++) score += qRow[d] * kRow[d];
        score *= scale;
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum
    float sumExp = 0.0;
    float result = 0.0;
    for (uint t = 0; t < actualWindow; t++) {
        const uint slot = (base + t) % windowSize;
        device const float *kRow = kHead + slot * slotStride;
        float score = 0.0;
        for (uint d = 0; d < headDim; d++) score += qRow[d] * kRow[d];
        score *= scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        result += expVal * vHead[slot * slotStride + dim];
    }

    out[head * headDim + dim] = result / sumExp;
}
|
||
|
||
// ── 6b. Sliding attention with current K,V appended ──
// For non-owner sliding layers: use cache K,V + current layer's K,V.
// Logical positions 0..cacheLen-1 come from the rotating cache (slot
// p % windowSize); logical position cacheLen is the current K/V. Only the
// last min(cacheLen + 1, windowSize) logical positions are visited, and any
// with logicalPos > position are skipped (causal mask).
// Grid: 2-D, gid.x = query head, gid.y = output dimension.
// GQA: query head h reads KV head h / (nHeads / nKvHeads), i.e. consecutive
// query heads share one KV head (the HF `repeat_kv` / MLX layout).
// If every position is masked (or nKvHeads == 0) the output is 0 instead of
// 0/0 = NaN.
// NOTE(review): scores are scaled by 1/sqrt(headDim) — confirm this matches
// the model's attention scaling. Text model has NO attention softcapping.
kernel void sliding_attention_with_current(
    device const float *q          [[buffer(0)]],  // [nHeads, headDim]
    device const float *cacheK     [[buffer(1)]],  // [windowSize, nKvHeads, headDim]
    device const float *cacheV     [[buffer(2)]],  // [windowSize, nKvHeads, headDim]
    device const float *curK       [[buffer(3)]],  // [nKvHeads, headDim]
    device const float *curV       [[buffer(4)]],  // [nKvHeads, headDim]
    device float       *out        [[buffer(5)]],  // [nHeads, headDim]
    constant uint      &nHeads     [[buffer(6)]],
    constant uint      &nKvHeads   [[buffer(7)]],
    constant uint      &headDim    [[buffer(8)]],
    constant uint      &windowSize [[buffer(9)]],
    constant uint      &cacheLen   [[buffer(10)]], // number of entries in cache
    constant int       &position   [[buffer(11)]], // current position for causal mask
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x;
    const uint dim = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    if (nKvHeads == 0) {
        out[head * headDim + dim] = 0.0;
        return;
    }

    const uint nRep = max(nHeads / nKvHeads, 1u);
    const uint kvHead = min(head / nRep, nKvHeads - 1);

    const uint seqLen = cacheLen + 1; // cache entries + current K,V
    const uint actualWindow = min(seqLen, windowSize);
    const int base = int(seqLen) - int(actualWindow);

    device const float *qRow = q + head * headDim;
    device const float *curKRow = curK + kvHead * headDim;
    device const float *curVRow = curV + kvHead * headDim;
    const float scale = 1.0 / sqrt(float(headDim));

    // Pass 1: find max score (with causal mask)
    float maxScore = -INFINITY;
    for (uint t = 0; t < actualWindow; t++) {
        const int logicalPos = base + int(t);
        if (logicalPos > position) continue; // skip future positions

        device const float *kRow = curKRow;
        if (logicalPos >= 0 && uint(logicalPos) < cacheLen) {
            const uint slot = uint(logicalPos) % windowSize;
            kRow = cacheK + (slot * nKvHeads + kvHead) * headDim;
        }

        float score = 0.0;
        for (uint d = 0; d < headDim; d++) score += qRow[d] * kRow[d];
        score *= scale;
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum (with causal mask)
    float sumExp = 0.0;
    float result = 0.0;
    for (uint t = 0; t < actualWindow; t++) {
        const int logicalPos = base + int(t);
        if (logicalPos > position) continue; // skip future positions

        device const float *kRow = curKRow;
        device const float *vRow = curVRow;
        if (logicalPos >= 0 && uint(logicalPos) < cacheLen) {
            const uint slot = uint(logicalPos) % windowSize;
            const uint rowOff = (slot * nKvHeads + kvHead) * headDim;
            kRow = cacheK + rowOff;
            vRow = cacheV + rowOff;
        }

        float score = 0.0;
        for (uint d = 0; d < headDim; d++) score += qRow[d] * kRow[d];
        score *= scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        result += expVal * vRow[dim];
    }

    out[head * headDim + dim] = sumExp > 0.0 ? result / sumExp : 0.0;
}
|
||
|
||
// ── 7. Scaled Dot-Product Attention (Full, causal) ──
// Two-pass: find max, then softmax + weighted sum (no per-thread score array).
// Attends to logical positions 0..offset of the cache; supports arbitrary
// sequence length.
// Grid: 2-D, gid.x = query head, gid.y = output dimension.
// GQA: query head h reads KV head h / (nHeads / nKvHeads), i.e. consecutive
// query heads share one KV head (the HF `repeat_kv` / MLX layout).
// offset < 0 or nKvHeads == 0 writes 0 instead of wrapping `offset + 1`.
// NOTE(review): maxPos is not read; the caller must ensure offset < maxPos
// or the K/V reads run past the cache.
// NOTE(review): scores are scaled by 1/sqrt(headDim) — confirm this matches
// the model's attention scaling. Text model has NO attention softcapping.
kernel void full_attention(
    device const float *q        [[buffer(0)]], // [nHeads, headDim]
    device const float *k        [[buffer(1)]], // [maxPos, nKvHeads, headDim]
    device const float *v        [[buffer(2)]], // [maxPos, nKvHeads, headDim]
    device float       *out      [[buffer(3)]], // [nHeads, headDim]
    constant uint      &nHeads   [[buffer(4)]],
    constant uint      &nKvHeads [[buffer(5)]],
    constant uint      &headDim  [[buffer(6)]],
    constant uint      &maxPos   [[buffer(7)]],
    constant int       &offset   [[buffer(8)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x;
    const uint dim = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    if (offset < 0 || nKvHeads == 0) {
        out[head * headDim + dim] = 0.0;
        return;
    }

    const uint nRep = max(nHeads / nKvHeads, 1u);
    const uint kvHead = min(head / nRep, nKvHeads - 1);
    const uint seqLen = uint(offset) + 1;

    // Cache layout: [maxPos, nKvHeads, headDim] flat
    const uint posStride = nKvHeads * headDim;
    device const float *qRow = q + head * headDim;
    device const float *kHead = k + kvHead * headDim;
    device const float *vHead = v + kvHead * headDim;
    const float scale = 1.0 / sqrt(float(headDim));

    // Pass 1: max
    float maxScore = -INFINITY;
    for (uint t = 0; t < seqLen; t++) {
        device const float *kRow = kHead + t * posStride;
        float score = 0.0;
        for (uint d = 0; d < headDim; d++) score += qRow[d] * kRow[d];
        score *= scale;
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum
    float sumExp = 0.0;
    float result = 0.0;
    for (uint t = 0; t < seqLen; t++) {
        device const float *kRow = kHead + t * posStride;
        float score = 0.0;
        for (uint d = 0; d < headDim; d++) score += qRow[d] * kRow[d];
        score *= scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        result += expVal * vHead[t * posStride + dim];
    }

    out[head * headDim + dim] = result / sumExp;
}
|
||
|
||
// ── 7b. Full attention with current K,V appended ──
// For non-owner layers: use cache K,V + current layer's K,V.
// Logical positions 0..cacheLen-1 read row t of cacheK/cacheV; logical
// position cacheLen reads curK/curV. Positions > `position` are skipped
// (causal mask). For example, with cacheLen == position + 1 the appended
// current K/V is masked out and only the cache is attended.
// Grid: 2-D, gid.x = query head, gid.y = output dimension.
// GQA: query head h reads KV head h / (nHeads / nKvHeads), i.e. consecutive
// query heads share one KV head (the HF `repeat_kv` / MLX layout).
// If every position is masked (or nKvHeads == 0) the output is 0 instead of
// 0/0 = NaN.
// NOTE(review): scores are scaled by 1/sqrt(headDim) — confirm this matches
// the model's attention scaling. Text model has NO attention softcapping.
kernel void full_attention_with_current(
    device const float *q        [[buffer(0)]],  // [nHeads, headDim]
    device const float *cacheK   [[buffer(1)]],  // [maxPos, nKvHeads, headDim]
    device const float *cacheV   [[buffer(2)]],  // [maxPos, nKvHeads, headDim]
    device const float *curK     [[buffer(3)]],  // [nKvHeads, headDim]
    device const float *curV     [[buffer(4)]],  // [nKvHeads, headDim]
    device float       *out      [[buffer(5)]],  // [nHeads, headDim]
    constant uint      &nHeads   [[buffer(6)]],
    constant uint      &nKvHeads [[buffer(7)]],
    constant uint      &headDim  [[buffer(8)]],
    constant uint      &cacheLen [[buffer(9)]],  // number of entries in cache (position + 1)
    constant int       &position [[buffer(10)]], // current position for causal mask
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x;
    const uint dim = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    if (nKvHeads == 0) {
        out[head * headDim + dim] = 0.0;
        return;
    }

    const uint nRep = max(nHeads / nKvHeads, 1u);
    const uint kvHead = min(head / nRep, nKvHeads - 1);

    const uint posStride = nKvHeads * headDim;
    device const float *qRow = q + head * headDim;
    device const float *cacheKHead = cacheK + kvHead * headDim;
    device const float *cacheVHead = cacheV + kvHead * headDim;
    device const float *curKRow = curK + kvHead * headDim;
    device const float *curVRow = curV + kvHead * headDim;
    const float scale = 1.0 / sqrt(float(headDim));

    // Cache entries + current K,V (the current entry is always considered).
    const uint seqLen = cacheLen + 1;

    // Pass 1: max score (with causal mask)
    float maxScore = -INFINITY;
    for (uint t = 0; t < seqLen; t++) {
        if (int(t) > position) continue; // causal mask: skip future positions

        device const float *kRow = (t < cacheLen) ? cacheKHead + t * posStride : curKRow;
        float score = 0.0;
        for (uint d = 0; d < headDim; d++) score += qRow[d] * kRow[d];
        score *= scale;
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum (with causal mask)
    float sumExp = 0.0;
    float result = 0.0;
    for (uint t = 0; t < seqLen; t++) {
        if (int(t) > position) continue; // causal mask: skip future positions

        const bool fromCache = t < cacheLen;
        device const float *kRow = fromCache ? cacheKHead + t * posStride : curKRow;
        device const float *vRow = fromCache ? cacheVHead + t * posStride : curVRow;
        float score = 0.0;
        for (uint d = 0; d < headDim; d++) score += qRow[d] * kRow[d];
        score *= scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        result += expVal * vRow[dim];
    }

    out[head * headDim + dim] = sumExp > 0.0 ? result / sumExp : 0.0;
}
|
||
|
||
// ── 8. Dequantize a single row ─────────────────
// Expands row `rowIdx` of a 4-bit packed matrix (eight codes per uint32,
// lowest nibble first) into floats:
//   out[c] = code(rowIdx, c) * s[rowIdx, c / groupSize] + b[rowIdx, c / groupSize]
// One thread per column; threads with id >= nCols exit.
kernel void dequantize_row(
    device const uint  *w         [[buffer(0)]], // [nRows, nCols/8]
    device const float *s         [[buffer(1)]], // [nRows, numGroups]
    device const float *b         [[buffer(2)]], // [nRows, numGroups]
    device float       *out       [[buffer(3)]], // [nCols]
    constant uint      &nCols     [[buffer(4)]],
    constant int       &rowIdx    [[buffer(5)]],
    constant uint      &groupSize [[buffer(6)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < nCols) {
        const uint group = id / groupSize;
        const uint lane = id % groupSize;
        const uint groupsPerRow = nCols / groupSize;

        const uint word = w[rowIdx * (nCols / 8) + (group * (groupSize / 8) + lane / 8)];
        const uint code = (word >> ((lane % 8) * 4)) & 0xF;

        const float groupScale = s[rowIdx * groupsPerRow + group];
        const float groupBias = b[rowIdx * groupsPerRow + group];
        out[id] = float(code) * groupScale + groupBias;
    }
}
|
||
|
||
// ── 9. Element-wise helpers ────────────────────
// out[i] = a[i] + b[i] for i < n; one thread per element.
kernel void eltwise_add(
    device const float *a   [[buffer(0)]],
    device const float *b   [[buffer(1)]],
    device float       *out [[buffer(2)]],
    constant uint      &n   [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= n) {
        return;
    }
    out[id] = a[id] + b[id];
}
|
||
|
||
// out[i] = a[i] * b[i] for i < n; one thread per element.
kernel void eltwise_mul(
    device const float *a   [[buffer(0)]],
    device const float *b   [[buffer(1)]],
    device float       *out [[buffer(2)]],
    constant uint      &n   [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= n) {
        return;
    }
    out[id] = a[id] * b[id];
}
|
||
|
||
// In-place scaling: buf[i] = buf[i] * scale for i < n; one thread per element.
kernel void eltwise_scale(
    device float   *buf   [[buffer(0)]],
    constant float &scale [[buffer(1)]],
    constant uint  &n     [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= n) {
        return;
    }
    buf[id] = buf[id] * scale;
}
|
||
|
||
// Weighted sum of two vectors: out[i] = a[i] * scaleA + b[i] * scaleB for
// i < n; one thread per element.
kernel void eltwise_add_scaled(
    device const float *a      [[buffer(0)]],
    constant float     &scaleA [[buffer(1)]],
    device const float *b      [[buffer(2)]],
    constant float     &scaleB [[buffer(3)]],
    device float       *out    [[buffer(4)]],
    constant uint      &n      [[buffer(5)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= n) {
        return;
    }
    out[id] = a[id] * scaleA + b[id] * scaleB;
}
|
||
|
||
// ── 10. Tanh scaling (logit softcapping) ──────
// out[i] = tanh(in[i] / cap) * cap — for cap > 0 this smoothly squashes
// values into [-cap, cap]. One thread per element; threads with id >= n exit.
// NOTE(review): cap == 0 divides by zero (NaN/Inf) — presumably the host only
// dispatches this kernel when softcapping is enabled; confirm.
kernel void tanh_scale(
    device const float *inp [[buffer(0)]],
    device float       *out [[buffer(1)]],
    constant float     &cap [[buffer(2)]],
    constant uint      &n   [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < n) {
        const float squashed = tanh(inp[id] / cap);
        out[id] = cap * squashed;
    }
}
|
||
|
||
// ══════════════════════════════════════════════════════
|
||
// Audio Processing Kernels
|
||
// ══════════════════════════════════════════════════════
|
||
|
||
// ── Audio Subsample Convolution ──
// 2D convolution for audio feature extraction: valid (unpadded), strided,
// no bias:
//   output[oc, oh, ow] = Σ_ic Σ_kh Σ_kw input[ic, oh*strideH + kh, ow*strideW + kw]
//                                        * weight[oc, ic, kh, kw]
// Grid: 3-D, gid = (oc, oh, ow); threads outside the output extent exit.
// Output extent: outH = (inHeight - kernelH) / strideH + 1 (likewise for W).
// It is 0 when the kernel is larger than the input or the stride is 0, so the
// unsigned subtraction cannot wrap to a huge extent and let threads read
// outside `input`.
kernel void audio_conv2d(
    device const float *input       [[buffer(0)]], // [inChannels, height, width]
    device const float *weight      [[buffer(1)]], // [outChannels, inChannels, kernelH, kernelW]
    device float       *output      [[buffer(2)]], // [outChannels, outHeight, outWidth]
    constant uint      &inChannels  [[buffer(3)]],
    constant uint      &outChannels [[buffer(4)]],
    constant uint      &inHeight    [[buffer(5)]],
    constant uint      &inWidth     [[buffer(6)]],
    constant uint      &kernelH     [[buffer(7)]],
    constant uint      &kernelW     [[buffer(8)]],
    constant uint      &strideH     [[buffer(9)]],
    constant uint      &strideW     [[buffer(10)]],
    uint3 gid [[thread_position_in_grid]]
) {
    const uint oc = gid.x; // output channel
    const uint oh = gid.y; // output height
    const uint ow = gid.z; // output width

    const uint outH = (strideH == 0 || inHeight < kernelH)
                          ? 0u : (inHeight - kernelH) / strideH + 1;
    const uint outW = (strideW == 0 || inWidth < kernelW)
                          ? 0u : (inWidth - kernelW) / strideW + 1;
    if (oc >= outChannels || oh >= outH || ow >= outW) return;

    const uint kArea = kernelH * kernelW;
    float sum = 0.0;
    for (uint ic = 0; ic < inChannels; ic++) {
        device const float *inPlane = input + ic * inHeight * inWidth;
        device const float *wPlane = weight + (oc * inChannels + ic) * kArea;
        for (uint kh = 0; kh < kernelH; kh++) {
            const uint ih = oh * strideH + kh;
            for (uint kw = 0; kw < kernelW; kw++) {
                const uint iw = ow * strideW + kw;
                sum += inPlane[ih * inWidth + iw] * wPlane[kh * kernelW + kw];
            }
        }
    }

    output[(oc * outH + oh) * outW + ow] = sum;
}
|
||
|
||
// ── Audio RMS Norm ──
// Normalises each channel row of `input` ([channels, featureSize], row-major)
// by its root-mean-square and scales it by a per-channel weight:
//   output[c, f] = input[c, f] / sqrt(mean_f(input[c, :]^2) + eps) * weight[c]
// Grid: (channels, featureSize). Every thread recomputes its row's sum of squares.
kernel void audio_rms_norm(
    device const float *input [[buffer(0)]],
    device const float *weight [[buffer(1)]],
    device float *output [[buffer(2)]],
    constant uint &channels [[buffer(3)]],
    constant uint &featureSize [[buffer(4)]],
    constant float &eps [[buffer(5)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint ch = gid.x;
    const uint feat = gid.y;
    if (!(ch < channels && feat < featureSize)) return;

    device const float *row = input + ch * featureSize;

    float sumSq = 0.0f;
    for (uint f = 0; f < featureSize; ++f) {
        const float v = row[f];
        sumSq += v * v;
    }
    const float denom = sqrt(sumSq / float(featureSize) + eps);

    output[ch * featureSize + feat] = row[feat] / denom * weight[ch];
}
|
||
|
||
// ── Audio Linear Projection ──
// Dense projection of a single feature vector:
//   output[o] = (hasBias ? bias[o] : 0) + sum_i input[i] * weight[o, i]
// weight is row-major [outFeatures, inFeatures]; bias is read only when hasBias.
// Grid: 1-D, one thread per output feature.
kernel void audio_linear(
    device const float *input [[buffer(0)]],   // [inFeatures]
    device const float *weight [[buffer(1)]],  // [outFeatures, inFeatures]
    device const float *bias [[buffer(2)]],    // [outFeatures] (optional)
    device float *output [[buffer(3)]],        // [outFeatures]
    constant uint &inFeatures [[buffer(4)]],
    constant uint &outFeatures [[buffer(5)]],
    constant bool &hasBias [[buffer(6)]],
    uint gid [[thread_position_in_grid]]
) {
    const uint row = gid;
    if (row >= outFeatures) return;

    device const float *wRow = weight + row * inFeatures;

    float acc = 0.0f;
    if (hasBias) {
        acc = bias[row];
    }
    for (uint i = 0; i < inFeatures; ++i) {
        acc += input[i] * wRow[i];
    }
    output[row] = acc;
}
|
||
|
||
// ── Audio Attention ──
// Single-query attention over the last `chunkSize` positions for the audio encoder.
// Layouts (row-major):
//   q   [nHeads, headDim]
//   k,v [seqLen, nKvHeads, headDim]
//   out [nHeads, headDim]
// score(t) = softcap(dot(q[head], k[t, kvHead]) / sqrt(headDim)), where
// softcap(x) = tanh(x / logitCap) * logitCap when logitCap > 0 and the identity
// otherwise (a zero cap would divide by zero).
// Query heads map onto KV heads in contiguous groups of nHeads / nKvHeads
// (standard grouped-query attention, same mapping as sliding_attention_batch);
// with nHeads == nKvHeads every head uses its own KV head.
// Grid: (nHeads, headDim) - one thread per output element.
// An empty window (seqLen == 0 or chunkSize == 0) writes 0.
// Precondition: nKvHeads > 0.
// NOTE(review): `contextLeft` is accepted but not used by this kernel.

// Scaled, optionally soft-capped dot product of one query row and one key row.
static inline float audio_attention_score(device const float *qRow,
                                          device const float *kRow,
                                          uint headDim,
                                          float scale,
                                          float logitCap) {
    float s = 0.0f;
    for (uint d = 0; d < headDim; d++) {
        s += qRow[d] * kRow[d];
    }
    s *= scale;
    if (logitCap > 0.0f) {
        s = tanh(s / logitCap) * logitCap;
    }
    return s;
}

kernel void audio_attention(
    device const float *q [[buffer(0)]],    // [nHeads, headDim]
    device const float *k [[buffer(1)]],    // [seqLen, nKvHeads, headDim]
    device const float *v [[buffer(2)]],    // [seqLen, nKvHeads, headDim]
    device float *out [[buffer(3)]],        // [nHeads, headDim]
    constant uint &nHeads [[buffer(4)]],
    constant uint &nKvHeads [[buffer(5)]],
    constant uint &headDim [[buffer(6)]],
    constant uint &seqLen [[buffer(7)]],
    constant uint &chunkSize [[buffer(8)]],
    constant uint &contextLeft [[buffer(9)]],
    constant float &logitCap [[buffer(10)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint head = gid.x;
    uint dim = gid.y;
    if (head >= nHeads || dim >= headDim) return;

    uint outIdx = head * headDim + dim;

    // `seqLen - chunkSize` wraps around in unsigned arithmetic when
    // seqLen < chunkSize (max(0u, ...) cannot catch that), so compare first.
    uint startPos = (seqLen > chunkSize) ? (seqLen - chunkSize) : 0u;
    if (startPos >= seqLen) {
        out[outIdx] = 0.0f;  // empty window: avoid 0/0
        return;
    }

    uint headsPerKv = max(nHeads / nKvHeads, 1u);
    uint kvHead = min(head / headsPerKv, nKvHeads - 1u);

    uint kvStride = nKvHeads * headDim;  // distance between consecutive timesteps
    device const float *qRow = q + head * headDim;
    device const float *kCol = k + kvHead * headDim;
    device const float *vCol = v + kvHead * headDim + dim;
    float scale = 1.0f / sqrt(float(headDim));

    // Pass 1: maximum score, for a numerically stable softmax.
    float maxScore = -INFINITY;
    for (uint t = startPos; t < seqLen; t++) {
        maxScore = max(maxScore,
                       audio_attention_score(qRow, kCol + t * kvStride, headDim, scale, logitCap));
    }

    // Pass 2: softmax normaliser and weighted sum of values.
    float sumExp = 0.0f;
    float result = 0.0f;
    for (uint t = startPos; t < seqLen; t++) {
        float w = exp(audio_attention_score(qRow, kCol + t * kvStride, headDim, scale, logitCap)
                      - maxScore);
        sumExp += w;
        result += w * vCol[t * kvStride];
    }

    out[outIdx] = result / sumExp;
}
|
||
|
||
// ── Audio FFN (Feed-Forward Network) ──
// Two-layer SiLU feed-forward block applied to a single vector, no biases:
//   hidden[h] = silu(sum_i input[i] * weight1[h, i])   weight1: [hiddenFeatures, inFeatures]
//   output[o] = sum_h hidden[h] * weight2[o, h]        weight2: [inFeatures, hiddenFeatures]
// The output width equals inFeatures.
// Grid: 1-D, one thread per output element. Each thread recomputes the whole
// hidden activation vector, so total work is O(inFeatures^2 * hiddenFeatures).
kernel void audio_ffn(
    device const float *input [[buffer(0)]],
    device const float *weight1 [[buffer(1)]],
    device const float *weight2 [[buffer(2)]],
    device float *output [[buffer(3)]],
    constant uint &inFeatures [[buffer(4)]],
    constant uint &hiddenFeatures [[buffer(5)]],
    uint gid [[thread_position_in_grid]]
) {
    const uint o = gid;
    if (o >= inFeatures) return;

    device const float *w2Row = weight2 + o * hiddenFeatures;

    float acc = 0.0f;
    for (uint hf = 0; hf < hiddenFeatures; ++hf) {
        device const float *w1Row = weight1 + hf * inFeatures;

        float pre = 0.0f;
        for (uint i = 0; i < inFeatures; ++i) {
            pre += input[i] * w1Row[i];
        }

        // SiLU: x * sigmoid(x)
        const float gated = pre * (1.0f / (1.0f + exp(-pre)));
        acc += gated * w2Row[hf];
    }

    output[o] = acc;
}
|
||
|
||
// ── Audio 1D Convolution (Local Context) ──
// Same-length 1-D convolution with zero padding; the kernel window for output
// position `pos` starts kernelSize/2 samples to its left.
//   input  [inChannels, seqLen]
//   weight [outChannels, inChannels, kernelSize]
//   output [outChannels, seqLen]
// Grid: (outChannels, seqLen).
kernel void audio_conv1d(
    device const float *input [[buffer(0)]],
    device const float *weight [[buffer(1)]],
    device float *output [[buffer(2)]],
    constant uint &inChannels [[buffer(3)]],
    constant uint &outChannels [[buffer(4)]],
    constant uint &kernelSize [[buffer(5)]],
    constant uint &seqLen [[buffer(6)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint oc = gid.x;
    const uint pos = gid.y;
    if (oc >= outChannels || pos >= seqLen) return;

    const int left = int(pos) - int(kernelSize / 2);

    float acc = 0.0f;
    for (uint ic = 0; ic < inChannels; ++ic) {
        device const float *inRow = input + ic * seqLen;
        device const float *taps = weight + (oc * inChannels + ic) * kernelSize;
        for (uint tap = 0; tap < kernelSize; ++tap) {
            const int src = left + int(tap);
            if (src < 0 || src >= int(seqLen)) continue;  // zero padding
            acc += inRow[uint(src)] * taps[tap];
        }
    }

    output[oc * seqLen + pos] = acc;
}
|
||
|
||
// ═══════════════════════════════════════════════
|
||
// Audio Tower Kernels
|
||
// ═══════════════════════════════════════════════
|
||
|
||
// Audio subsample conv 2D: 3x3 convolution, stride 2, zero padding 1, followed
// by a per-output-channel scale.
// Treats the mel spectrogram as a 2-D surface: input is [inCh, H, W] (CHW, flat)
// with H = nMels and W = seqLen.
// Conv weight (safetensors layout): [outCh, 3, 3, inCh] - input channel fastest.
// normWeight: [outCh]; each output value is multiplied by normWeight[oc].
// NOTE(review): despite the name, no normalisation statistics (mean/variance)
// are computed here; if a group norm is intended it must happen elsewhere - confirm.
// Output: [outCh, outH, outW] CHW flat, outH = (H + 1) / 2, outW = (W + 1) / 2.
// Grid: (outCh, outH, outW).
kernel void audio_subsample_conv_2d(
    device const float *input [[buffer(0)]],       // [inCh, H, W] CHW flat
    device const float *convWeight [[buffer(1)]],  // [outCh, 3, 3, inCh]
    device const float *normWeight [[buffer(2)]],  // [outCh]
    device float *output [[buffer(3)]],            // [outCh, outH, outW] CHW flat
    constant uint &inChannels [[buffer(4)]],
    constant uint &outChannels [[buffer(5)]],
    constant uint &height [[buffer(6)]],  // H = nMels
    constant uint &width [[buffer(7)]],   // W = seqLen
    uint3 gid [[thread_position_in_grid]]
) {
    const uint outH = (height + 1) / 2;
    const uint outW = (width + 1) / 2;
    const uint oc = gid.x;
    const uint oh = gid.y;
    const uint ow = gid.z;
    if (oc >= outChannels || oh >= outH || ow >= outW) return;

    // Top-left input coordinate of the 3x3 window (padding of 1).
    const int top = int(oh * 2) - 1;
    const int left = int(ow * 2) - 1;
    device const float *wOc = convWeight + oc * 9 * inChannels;

    float acc = 0.0f;
    for (uint ic = 0; ic < inChannels; ++ic) {
        device const float *plane = input + ic * height * width;
        for (uint kh = 0; kh < 3; ++kh) {
            const int ih = top + int(kh);
            for (uint kw = 0; kw < 3; ++kw) {
                const int iw = left + int(kw);
                const bool inside = ih >= 0 && ih < int(height) && iw >= 0 && iw < int(width);
                if (!inside) continue;
                acc += plane[uint(ih) * width + uint(iw)] * wOc[(kh * 3 + kw) * inChannels + ic];
            }
        }
    }

    output[(oc * outH + oh) * outW + ow] = acc * normWeight[oc];
}
|
||
|
||
// Transpose a row-major matrix: input [rows, cols] -> output [cols, rows].
// Used to turn a mel spectrogram laid out as [seqLen, nMels] into the CHW
// form [1, nMels, seqLen].
// Grid: x indexes columns, y indexes rows.
kernel void transpose_2d(
    device const float *input [[buffer(0)]],  // [rows, cols] row-major
    device float *output [[buffer(1)]],       // [cols, rows] row-major
    constant uint &rows [[buffer(2)]],
    constant uint &cols [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint col = gid.x;
    const uint row = gid.y;
    if (row < rows && col < cols) {
        output[col * rows + row] = input[row * cols + col];
    }
}
|
||
|
||
// Flatten CHW [C, H, W] -> row-major [W, C*H]: each width position w becomes
// one output row holding every (c, h) pair, with c as the slower index.
// Useful after the subsample convolution to prepare for a linear projection.
// Grid: x indexes the flattened c * H + h, y indexes w.
kernel void audio_flatten_chw(
    device const float *input [[buffer(0)]],  // [C, H, W] CHW flat
    device float *output [[buffer(1)]],       // [W, C*H] row-major
    constant uint &C [[buffer(2)]],
    constant uint &H [[buffer(3)]],
    constant uint &W [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint flat = gid.x;
    const uint w = gid.y;
    const uint rowLen = C * H;
    if (flat >= rowLen || w >= W) return;

    // flat == c * H + h, hence c * H * W + h * W + w == flat * W + w.
    output[w * rowLen + flat] = input[flat * W + w];
}
|
||
|
||
// Audio linear seq: [seqLen, inFeatures] -> [seqLen, outFeatures]
//   output[s, o] = (hasBias ? bias[o] : 0) + sum_i input[s, i] * weight[o, i]
// weight is row-major [outFeatures, inFeatures]; bias is read only when hasBias.
// Grid: x = output feature, y = sequence position.
kernel void audio_linear_seq(
    device const float *input [[buffer(0)]],
    device const float *weight [[buffer(1)]],
    device const float *bias [[buffer(2)]],
    device float *output [[buffer(3)]],
    constant uint &inFeatures [[buffer(4)]],
    constant uint &outFeatures [[buffer(5)]],
    constant bool &hasBias [[buffer(6)]],
    constant uint &seqLen [[buffer(7)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint o = gid.x;
    const uint s = gid.y;
    if (o >= outFeatures || s >= seqLen) return;

    device const float *x = input + s * inFeatures;
    device const float *wRow = weight + o * inFeatures;

    float acc = 0.0f;
    if (hasBias) {
        acc = bias[o];
    }
    for (uint i = 0; i < inFeatures; ++i) {
        acc += x[i] * wRow[i];
    }
    output[s * outFeatures + o] = acc;
}
|
||
|
||
// Audio quantized matmul with sequence dimension (batched), 4-bit affine format.
// Each uint32 of `weight` packs 8 consecutive 4-bit codes, least-significant
// nibble first; element j of group g in row o dequantizes to
//   code * scales[o, g] + biases_q[o, g]
// output[s * outDim + of] = bias[of] + sum_i input[s * inDim + i] * deq(weight[of][i])
// (bias only when hasBias). Only the first (inDim / GROUP_SIZE) * GROUP_SIZE
// input features contribute.
// Grid: x = output row, y = sequence position.
kernel void quantized_matmul_seq(
    device const float *input [[buffer(0)]],     // [seqLen, inDim]
    device const uint *weight [[buffer(1)]],     // [outDim, inDim/8]
    device const float *scales [[buffer(2)]],    // [outDim, inDim/64]
    device const float *biases_q [[buffer(3)]],  // [outDim, inDim/64]
    device const float *bias [[buffer(4)]],      // [outDim] optional output bias
    device float *output [[buffer(5)]],          // [seqLen, outDim]
    constant uint &inDim [[buffer(6)]],
    constant uint &outDim [[buffer(7)]],
    constant bool &hasBias [[buffer(8)]],
    constant uint &seqLen [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint o = gid.x;
    const uint s = gid.y;
    if (o >= outDim || s >= seqLen) return;

    const uint groups = inDim / GROUP_SIZE;
    device const uint *wRow = weight + o * (inDim / 8);
    device const float *scaleRow = scales + o * groups;
    device const float *offsetRow = biases_q + o * groups;
    device const float *x = input + s * inDim;

    float acc = hasBias ? bias[o] : 0.0f;
    for (uint g = 0; g < groups; ++g) {
        const float sc = scaleRow[g];
        const float off = offsetRow[g];
        const uint base = g * GROUP_SIZE;
        for (uint j = 0; j < GROUP_SIZE; ++j) {
            const uint elem = base + j;
            const uint code = (wRow[elem / 8] >> ((elem % 8) * 4)) & 0xFu;
            const float w = float(code) * sc + off;
            acc += w * x[elem];
        }
    }

    output[s * outDim + o] = acc;
}
|
||
|
||
// Audio quantized linear with 8-bit quantization scales.
// Dense projection (weight row-major [outFeatures, inFeatures], float) with
// optional affine range transforms:
//   - when inputMin and inputMax are both bound, every input value x becomes
//     x * (inputMax[0] - inputMin[0]) / 255 + inputMin[0];
//   - when outputMin and outputMax are both bound, the dot product is mapped
//     the same way with the output range.
// NOTE(review): this treats float activations as if they were 0..255 codes; if
// these tensors are meant as clipping bounds, a clamp would be expected instead
// - confirm against the model reference.
// Grid: x = output feature, y = sequence position.
kernel void audio_quantized_linear(
    device const float *input [[buffer(0)]],
    device const float *weight [[buffer(1)]],
    device const float *inputMin [[buffer(2)]],
    device const float *inputMax [[buffer(3)]],
    device const float *outputMin [[buffer(4)]],
    device const float *outputMax [[buffer(5)]],
    device float *output [[buffer(6)]],
    constant uint &inFeatures [[buffer(7)]],
    constant uint &outFeatures [[buffer(8)]],
    constant uint &seqLen [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint o = gid.x;
    const uint s = gid.y;
    if (o >= outFeatures || s >= seqLen) return;

    // The input transform is the same for every element, so evaluate it once.
    const bool mapIn = inputMin && inputMax;
    const float inScale = mapIn ? (inputMax[0] - inputMin[0]) / 255.0f : 0.0f;
    const float inOffset = mapIn ? inputMin[0] : 0.0f;

    device const float *x = input + s * inFeatures;
    device const float *wRow = weight + o * inFeatures;

    float acc = 0.0f;
    for (uint i = 0; i < inFeatures; ++i) {
        float xi = x[i];
        if (mapIn) {
            xi = xi * inScale + inOffset;
        }
        acc += xi * wRow[i];
    }

    if (outputMin && outputMax) {
        const float outScale = (outputMax[0] - outputMin[0]) / 255.0f;
        acc = acc * outScale + outputMin[0];
    }

    output[s * outFeatures + o] = acc;
}
|
||
|
||
// Audio attention with relative position and context window.
// Layouts (row-major):
//   q, k, v, output [seqLen, numHeads, headDim]
//   perDimScale     [numHeads, headDim]   (indexed head * headDim + i)
// For query position `pos` and head `h`, keys at positions
// [max(0, pos - contextLeft), pos] are scored with a per-dimension-scaled dot
// product over the whole head:
//   score(p) = sum_i q[pos,h,i] * perDimScale[h,i] * k[p,h,i]
// clamped to [-logitCap, logitCap] when logitCap > 0. The scores are
// softmax-normalised and used to weight v[p, h, d].
// A score must reduce over all of headDim; scoring each dimension on its own
// would run an unrelated softmax per output dimension.
// Grid: x = head * headDim + d, y = query position.
// NOTE(review): `relativeK` (relative-position keys) is accepted but not yet
// applied - TODO. No 1/sqrt(headDim) factor is applied here, presumably because
// it is folded into perDimScale - confirm.

// Per-dimension-scaled dot product of one query row and one key row, optionally clamped.
static inline float audio_attention_full_score(device const float *qRow,
                                               device const float *kRow,
                                               device const float *scaleRow,
                                               uint headDim,
                                               float logitCap) {
    float s = 0.0f;
    for (uint i = 0; i < headDim; i++) {
        s += qRow[i] * scaleRow[i] * kRow[i];
    }
    if (logitCap > 0.0f) {
        s = clamp(s, -logitCap, logitCap);
    }
    return s;
}

kernel void audio_attention_full(
    device const float *q [[buffer(0)]],
    device const float *k [[buffer(1)]],
    device const float *v [[buffer(2)]],
    device const float *relativeK [[buffer(3)]],
    device const float *perDimScale [[buffer(4)]],
    device float *output [[buffer(5)]],
    constant uint &seqLen [[buffer(6)]],
    constant uint &numHeads [[buffer(7)]],
    constant uint &headDim [[buffer(8)]],
    constant uint &contextLeft [[buffer(9)]],
    constant float &logitCap [[buffer(10)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint idx = gid.x;
    uint pos = gid.y;

    uint head = idx / headDim;
    uint d = idx % headDim;

    if (head >= numHeads || pos >= seqLen) return;

    uint rowStride = numHeads * headDim;  // distance between consecutive positions
    uint headOff = head * headDim;
    device const float *qRow = q + pos * rowStride + headOff;
    device const float *scaleRow = perDimScale + headOff;

    // Causal left-context window [startPos, pos]; always contains pos itself,
    // so the softmax normaliser below is at least 1.
    uint startPos = (pos > contextLeft) ? (pos - contextLeft) : 0u;

    // Pass 1: maximum score, for a numerically stable softmax.
    float maxScore = -INFINITY;
    for (uint p = startPos; p <= pos; p++) {
        float s = audio_attention_full_score(qRow, k + p * rowStride + headOff,
                                             scaleRow, headDim, logitCap);
        maxScore = max(maxScore, s);
    }

    // Pass 2: normaliser and weighted sum of values.
    float expSum = 0.0f;
    float outVal = 0.0f;
    for (uint p = startPos; p <= pos; p++) {
        float s = audio_attention_full_score(qRow, k + p * rowStride + headOff,
                                             scaleRow, headDim, logitCap);
        float w = exp(s - maxScore);
        expSum += w;
        outVal += w * v[p * rowStride + headOff + d];
    }

    output[pos * rowStride + headOff + d] = outVal / expSum;
}
|
||
|
||
// Audio depthwise conv1d: each channel is convolved with its own kernel
// (zero padding, window starting kernelSize/2 to the left of `pos`), then
// multiplied by a per-channel scale.
//   input, output [seqLen, channels]   (channel is the fastest index)
//   weight        [channels, kernelSize]
//   norm          [channels]
// Grid: x = channel, y = sequence position.
kernel void audio_depthwise_conv1d(
    device const float *input [[buffer(0)]],
    device const float *weight [[buffer(1)]],
    device const float *norm [[buffer(2)]],
    device float *output [[buffer(3)]],
    constant uint &channels [[buffer(4)]],
    constant uint &kernelSize [[buffer(5)]],
    constant uint &seqLen [[buffer(6)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint c = gid.x;
    const uint pos = gid.y;
    if (c >= channels || pos >= seqLen) return;

    device const float *taps = weight + c * kernelSize;
    const int first = int(pos) - int(kernelSize) / 2;

    float acc = 0.0f;
    for (uint tap = 0; tap < kernelSize; ++tap) {
        const int src = first + int(tap);
        if (src >= 0 && src < int(seqLen)) {
            acc += input[uint(src) * channels + c] * taps[tap];
        }
    }

    output[pos * channels + c] = acc * norm[c];
}
|
||
|
||
// ═══════════════════════════════════════════════
|
||
// GPU Mel Spectrogram Extraction
|
||
// ═══════════════════════════════════════════════
|
||
|
||
// DFT magnitude spectrum for all frames in parallel (direct DFT, O(nFft) per bin).
// Frame f covers samples [f * hopLength, f * hopLength + nFft); samples at or
// beyond audioLen are treated as zero. Each frame is multiplied by a symmetric
// Hann window w[i] = 0.5 * (1 - cos(2*pi*i / (nFft - 1))) before transforming,
// and the result is |X[bin]| = sqrt(re^2 + im^2).
// NOTE(review): many STFT front-ends use the periodic Hann window (denominator
// nFft); confirm which one the reference feature extractor uses.
// Grid: [numFrames, spectrumSize] where spectrumSize = nFft/2 + 1
kernel void audio_dft_magnitude(
    device const float *audioData [[buffer(0)]],  // [audioLen]
    device float *spectrum [[buffer(1)]],         // [numFrames * spectrumSize]
    constant uint &nFft [[buffer(2)]],
    constant uint &hopLength [[buffer(3)]],
    constant uint &numFrames [[buffer(4)]],
    constant uint &spectrumSize [[buffer(5)]],
    constant uint &audioLen [[buffer(6)]],        // total audio length for bounds check
    uint2 gid [[thread_position_in_grid]]
) {
    uint frame = gid.x;
    uint bin = gid.y;
    if (frame >= numFrames || bin >= spectrumSize) return;

    uint start = frame * hopLength;
    float twoPi = 2.0f * M_PI_F;

    // nFft == 1 would make the Hann denominator zero and the window NaN.
    float windowDenom = float(max(nFft, 2u) - 1u);
    float phaseToAngle = -twoPi / float(max(nFft, 1u));

    // Track (bin * i) mod nFft in integers so the twiddle angle stays in
    // (-2*pi, 0]. The unreduced angle grows to roughly pi * nFft, where float
    // rounding and (fast-math) sin/cos lose accuracy.
    uint binStep = (nFft > 0u) ? (bin % nFft) : 0u;
    uint phase = 0u;

    float real = 0.0f;
    float imag = 0.0f;
    for (uint i = 0; i < nFft; i++) {
        uint idx = start + i;
        float sample = (idx < audioLen) ? audioData[idx] : 0.0f;
        float window = 0.5f * (1.0f - cos(twoPi * float(i) / windowDenom));
        float val = sample * window;

        float angle = phaseToAngle * float(phase);
        real += val * cos(angle);
        imag += val * sin(angle);

        phase += binStep;  // both operands < nFft, so no overflow
        if (phase >= nFft) phase -= nFft;
    }

    spectrum[frame * spectrumSize + bin] = sqrt(real * real + imag * imag);
}
|
||
|
||
// Apply mel filterbank to DFT magnitude spectrum, then take log10 with a floor:
//   melSpec[f, m] = log10(max(sum_b spectrum[f, b] * filterbank[m, b], 1e-10))
// spectrum [numFrames, spectrumSize], filterbank [nMels, spectrumSize],
// melSpec [numFrames, nMels]; all row-major.
// Grid: [numFrames, nMels]
kernel void audio_mel_filterbank(
    device const float *spectrum [[buffer(0)]],    // [numFrames * spectrumSize]
    device const float *filterbank [[buffer(1)]],  // [nMels * spectrumSize] precomputed
    device float *melSpec [[buffer(2)]],           // [numFrames * nMels]
    constant uint &spectrumSize [[buffer(3)]],
    constant uint &nMels [[buffer(4)]],
    constant uint &numFrames [[buffer(5)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint frame = gid.x;
    const uint mel = gid.y;
    if (frame >= numFrames || mel >= nMels) return;

    device const float *mag = spectrum + frame * spectrumSize;
    device const float *filt = filterbank + mel * spectrumSize;

    float energy = 0.0f;
    for (uint b = 0; b < spectrumSize; ++b) {
        energy += mag[b] * filt[b];
    }

    melSpec[frame * nMels + mel] = log10(max(energy, 1e-10f));
}
|
||
|
||
// RMS norm with seqLen support: each of seqLen rows of width N is normalised
// independently:
//   y[s, i] = x[s, i] * rsqrt(mean_j(x[s, j]^2) + eps) * w[i]
// As in the single-row rms_norm kernel, `w` may be null, in which case the
// weight multiply is skipped.
// Not safe in place: other threads of the same row may still be reading x
// while this thread writes y.
// Grid: x = element index within the row, y = row (sequence position).
kernel void rms_norm_seq(
    device const float *x [[buffer(0)]],
    device const float *w [[buffer(1)]],
    device float *y [[buffer(2)]],
    constant uint &N [[buffer(3)]],
    constant float &eps [[buffer(4)]],
    constant uint &seqLen [[buffer(5)]],
    uint2 gid [[thread_position_in_grid]]
) {
    uint i = gid.x;
    uint s = gid.y;
    if (i >= N || s >= seqLen) return;

    device const float *row = x + s * N;

    float ss = 0.0f;
    for (uint j = 0; j < N; j++) {
        float val = row[j];
        ss += val * val;
    }
    float rms = rsqrt(ss / float(N) + eps);

    y[s * N + i] = w ? row[i] * rms * w[i] : row[i] * rms;
}
|
||
|
||
// SiLU activation, elementwise: y = x * sigmoid(x).
// Grid: 1-D, one thread per element; threads with id >= count do nothing.
kernel void silu(
    device const float *x [[buffer(0)]],
    device float *y [[buffer(1)]],
    constant uint &count [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < count) {
        const float v = x[id];
        const float sig = 1.0f / (1.0f + exp(-v));
        y[id] = v * sig;
    }
}
|
||
|
||
// Residual add with weight, elementwise: output[i] = input[i] + weight * add[i].
// Grid: 1-D, one thread per element; threads with id >= count do nothing.
kernel void residual_add(
    device const float *input [[buffer(0)]],
    device const float *add [[buffer(1)]],
    device float *output [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    constant float &weight [[buffer(4)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < count) {
        output[id] = input[id] + weight * add[id];
    }
}
|
||
|
||
// ═══════════════════════════════════════════════
|
||
// Vision Tower Kernels
|
||
// ═══════════════════════════════════════════════
|
||
|
||
// Vision add position embedding:
//   output[p, h] = input[p, h] + positionEmbed[p, h]
// positionEmbed is read with row stride hiddenSize starting at offset 0, i.e.
// only the first table of the buffer (commented as [2, 10240, 768]) is used.
// NOTE(review): the second table is never read and patch p always uses row p;
// if the two tables encode separate row/column coordinates, both should be
// added - confirm against the model reference.
// Grid: x = hidden index, y = patch.
kernel void vision_add_pos_embed(
    device const float *input [[buffer(0)]],
    device const float *positionEmbed [[buffer(1)]],  // [2, 10240, 768]
    device float *output [[buffer(2)]],
    constant uint &hiddenSize [[buffer(3)]],
    constant uint &numPatches [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint h = gid.x;
    const uint p = gid.y;
    if (h < hiddenSize && p < numPatches) {
        const uint at = p * hiddenSize + h;
        output[at] = input[at] + positionEmbed[at];
    }
}
|
||
|
||
// Vision head norm (RMS norm per head):
//   y[s, h, d] = x[s, h, d] * rsqrt(mean_i(x[s, h, i]^2) + eps) * w[d]
// x, y [seqLen, numHeads, headDim]; w [headDim] shared by all heads.
// Grid: x = head * headDim + d, y = sequence position.
kernel void vision_head_norm(
    device const float *x [[buffer(0)]],  // [seqLen, numHeads, headDim]
    device const float *w [[buffer(1)]],  // [headDim]
    device float *y [[buffer(2)]],
    constant uint &numHeads [[buffer(3)]],
    constant uint &headDim [[buffer(4)]],
    constant uint &seqLen [[buffer(5)]],
    constant float &eps [[buffer(6)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint s = gid.y;
    const uint head = gid.x / headDim;
    const uint d = gid.x % headDim;
    if (head >= numHeads || s >= seqLen) return;

    const uint base = (s * numHeads + head) * headDim;
    device const float *row = x + base;

    float sumSq = 0.0f;
    for (uint i = 0; i < headDim; ++i) {
        const float t = row[i];
        sumSq += t * t;
    }
    const float inv = rsqrt(sumSq / float(headDim) + eps);

    y[base + d] = row[d] * inv * w[d];
}
|
||
|
||
// Vision attention (global, no causal mask): every patch attends to every patch
// of the same head.
//   q, k, v, output [numPatches, numHeads, headDim] row-major
//   output[pos, h, d] = sum_p softmax_p(q[pos,h] . k[p,h] / sqrt(headDim)) * v[p, h, d]
// Grid: x = head * headDim + d, y = query patch. Each thread recomputes the
// full score row three times (max, normaliser, weighted sum).

// Dot product of one query row with one key row, multiplied by `scale`.
static inline float vision_attention_logit(device const float *qRow,
                                           device const float *kRow,
                                           uint headDim,
                                           float scale) {
    float dot = 0.0f;
    for (uint i = 0; i < headDim; ++i) {
        dot += qRow[i] * kRow[i];
    }
    return dot * scale;
}

kernel void vision_attention(
    device const float *q [[buffer(0)]],  // [numPatches, numHeads, headDim]
    device const float *k [[buffer(1)]],  // [numPatches, numHeads, headDim]
    device const float *v [[buffer(2)]],  // [numPatches, numHeads, headDim]
    device float *output [[buffer(3)]],
    constant uint &numPatches [[buffer(4)]],
    constant uint &numHeads [[buffer(5)]],
    constant uint &headDim [[buffer(6)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint head = gid.x / headDim;
    const uint d = gid.x % headDim;
    const uint pos = gid.y;
    if (head >= numHeads || pos >= numPatches) return;

    const uint rowStride = numHeads * headDim;
    const uint headOffset = head * headDim;
    device const float *qRow = q + pos * rowStride + headOffset;
    const float scale = 1.0f / sqrt(float(headDim));

    // Pass 1: largest logit, for a numerically stable softmax.
    float peak = -INFINITY;
    for (uint p = 0; p < numPatches; ++p) {
        const float logit = vision_attention_logit(qRow, k + p * rowStride + headOffset, headDim, scale);
        peak = max(peak, logit);
    }

    // Pass 2: softmax denominator.
    float denom = 0.0f;
    for (uint p = 0; p < numPatches; ++p) {
        const float logit = vision_attention_logit(qRow, k + p * rowStride + headOffset, headDim, scale);
        denom += exp(logit - peak);
    }

    // Pass 3: probability-weighted sum of values.
    float acc = 0.0f;
    for (uint p = 0; p < numPatches; ++p) {
        const float logit = vision_attention_logit(qRow, k + p * rowStride + headOffset, headDim, scale);
        const float prob = exp(logit - peak) / denom;
        acc += prob * v[p * rowStride + headOffset + d];
    }

    output[pos * rowStride + headOffset + d] = acc;
}
|
||
|
||
// Vision gate multiply (SwiGLU), elementwise: output = SiLU(gate) * up,
// with SiLU(g) = g * sigmoid(g).
// Grid: 1-D, one thread per element; threads with id >= count do nothing.
kernel void vision_gate_multiply(
    device const float *gate [[buffer(0)]],
    device const float *up [[buffer(1)]],
    device float *output [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= count) return;

    const float g = gate[id];
    const float u = up[id];
    // Named `activated` rather than `silu` to avoid shadowing the silu kernel.
    const float activated = g * (1.0f / (1.0f + exp(-g)));
    output[id] = activated * u;
}
|
||
|
||
// Vision residual add (no weight), elementwise: output[i] = input[i] + add[i].
// Grid: 1-D, one thread per element; threads with id >= count do nothing.
kernel void vision_residual_add(
    device const float *input [[buffer(0)]],
    device const float *add [[buffer(1)]],
    device float *output [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    if (id < count) {
        output[id] = input[id] + add[id];
    }
}
|
||
|
||
// Vision copy output (pad hiddenSize to text hiddenSize).
// Copies vision features [numPatches, hiddenSize] into a fixed-width
// [numPatches, 2560] buffer, zero-filling columns hiddenSize..2559. If
// hiddenSize exceeds 2560, only the first 2560 input columns are copied.
// Grid: x = output column (threads with x >= 2560 do nothing), y = patch.
kernel void vision_copy_output(
    device const float *input [[buffer(0)]],  // [numPatches, 768]
    device float *output [[buffer(1)]],       // [numPatches, 2560]
    constant uint &numPatches [[buffer(2)]],
    constant uint &hiddenSize [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]]
) {
    const uint kTextHidden = 2560;  // output row width (text hidden size)
    const uint col = gid.x;
    const uint p = gid.y;
    if (col >= kTextHidden || p >= numPatches) return;

    output[p * kTextHidden + col] = (col < hiddenSize) ? input[p * hiddenSize + col] : 0.0f;
}
|
||
|
||
// Vision embedding projection with 4-bit quantization (MLX affine format).
// weight: [outFeatures, packedSize] uint32 (each uint32 holds 8 4-bit values,
//         least-significant nibble first)
// scales: [outFeatures, numGroups] float
// biases: [outFeatures, numGroups] float
// input: [numPatches, inFeatures] float (inFeatures = packedSize * 8 expected)
// output: [numPatches, outFeatures] float
// Element e of row `of` dequantizes to
//   code * scales[of, g] + biases[of, g],   g = min(e / groupSize, numGroups - 1)
// i.e. the bias is an additive offset, the same affine form quantized_matmul_seq
// uses; `scale * (code - bias)` would treat it as a zero point instead.
// A zero groupSize or numGroups writes 0 rather than dividing by zero.
// Grid: x = output feature, y = patch.
kernel void vision_embedding_projection_quantized(
    device const float *input [[buffer(0)]],
    device const uint32_t *weight [[buffer(1)]],  // packed uint32
    device const float *scales [[buffer(2)]],
    device const float *biases [[buffer(3)]],
    device float *output [[buffer(4)]],
    constant uint &inFeatures [[buffer(5)]],   // 768
    constant uint &outFeatures [[buffer(6)]],  // 2560
    constant uint &numPatches [[buffer(7)]],
    constant uint &packedSize [[buffer(8)]],   // 96 (inFeatures / 8)
    constant uint &groupSize [[buffer(9)]],    // 64
    constant uint &numGroups [[buffer(10)]],   // 12 (inFeatures / groupSize)
    uint2 gid [[thread_position_in_grid]]
) {
    uint of = gid.x;  // output feature
    uint p = gid.y;   // patch position
    if (of >= outFeatures || p >= numPatches) return;

    uint outIdx = p * outFeatures + of;

    // Degenerate configuration: avoid division by zero / `numGroups - 1` wrap-around.
    if (groupSize == 0u || numGroups == 0u) {
        output[outIdx] = 0.0f;
        return;
    }

    device const uint32_t *wRow = weight + of * packedSize;
    device const float *scaleRow = scales + of * numGroups;
    device const float *biasRow = biases + of * numGroups;
    device const float *x = input + p * inFeatures;

    float sum = 0.0f;
    for (uint packedIdx = 0; packedIdx < packedSize; packedIdx++) {
        uint32_t packed = wRow[packedIdx];
        // Unpack 8 4-bit codes from the uint32.
        for (uint nibbleIdx = 0; nibbleIdx < 8; nibbleIdx++) {
            uint elementIdx = packedIdx * 8 + nibbleIdx;
            uint nibble = (packed >> (nibbleIdx * 4)) & 0xFu;
            // Elements past numGroups * groupSize reuse the last group.
            uint group = min(elementIdx / groupSize, numGroups - 1u);
            float dequantized = float(nibble) * scaleRow[group] + biasRow[group];
            sum += x[elementIdx] * dequantized;
        }
    }

    output[outIdx] = sum;
}
|
||
|
||
// ── Matmul for F32 weights (not quantized) ──
// output = input @ weight^T
//   input [M, K], weight [N, K] (row-major), output [M, N]
// Grid: 1-D; thread `id` computes output[id / N, id % N]. Dispatching N
// threads therefore computes row 0 only (the single-token case); dispatch
// M * N threads to compute every row. M == 0 is treated as a single row so
// callers that always relied on row 0 being produced keep working.
kernel void matmul_f32(
    device const float *input [[buffer(0)]],   // [M, K]
    device const float *weight [[buffer(1)]],  // [N, K]
    device float *output [[buffer(2)]],        // [M, N]
    constant uint &M [[buffer(3)]],
    constant uint &K [[buffer(4)]],
    constant uint &N [[buffer(5)]],
    uint id [[thread_position_in_grid]]
) {
    if (N == 0u) return;

    uint rows = max(M, 1u);
    uint row = id / N;
    uint col = id % N;
    if (row >= rows) return;

    device const float *x = input + row * K;
    device const float *wRow = weight + col * K;

    float sum = 0.0f;
    for (uint k = 0; k < K; k++) {
        sum += x[k] * wRow[k];
    }
    output[row * N + col] = sum;
}
|
||
|
||
// ═══════════════════════════════════════════════════════════════
|
||
// Batch Metal Kernels - Process multiple tokens simultaneously
|
||
// ═══════════════════════════════════════════════════════════════
|
||
|
||
// Batch quantized matmul - process N tokens with shared 8-bit weights.
// Layouts (row-major):
//   batchInput  [batchSize, inDim]
//   weights     [outDim, inDim]   one uint8 code per element, zero point 128
//   scales      [outDim, inDim / groupSize]
//   biases      [outDim]          added to each dot product
//   batchOutput [batchSize, outDim]
//   out[b, o] = biases[o] + sum_i in[b, i] * (weights[o, i] - 128) * scales[o, i / groupSize]
// Each code is loaded as its own byte: a single uint8 load cannot be unpacked
// into four codes (bytes 1..3 of it are always zero). Processing one element
// per iteration also handles inDim values that are not a multiple of 4.
// Grid: x = batch index, y = output row. Precondition: groupSize > 0.
kernel void quantized_matmul_batch(
    device const float* batchInput [[buffer(0)]],  // [batchSize, inDim]
    device const uint8_t* weights [[buffer(1)]],   // [outDim, inDim] 8-bit codes
    device const float* scales [[buffer(2)]],      // [outDim, groups]
    device const float* biases [[buffer(3)]],      // [outDim]
    device float* batchOutput [[buffer(4)]],       // [batchSize, outDim]
    constant uint32_t& inDim [[buffer(5)]],
    constant uint32_t& outDim [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    uint batchIdx = gid.x;
    uint outIdx = gid.y;
    if (batchIdx >= batchSize || outIdx >= outDim) return;

    device const float* input = batchInput + batchIdx * inDim;
    device const uint8_t* wRow = weights + outIdx * inDim;
    device const float* scaleRow = scales + outIdx * (inDim / groupSize);

    float sum = biases[outIdx];
    for (uint i = 0; i < inDim; i++) {
        float w = float(int(wRow[i]) - 128);  // remove zero point
        sum += input[i] * w * scaleRow[i / groupSize];
    }

    batchOutput[batchIdx * outDim + outIdx] = sum;
}
|
||
|
||
// Batch RMS norm - process N tokens simultaneously. Each of batchSize rows of
// width N is normalised independently:
//   out[b, i] = in[b, i] / sqrt(mean_j(in[b, j]^2) + eps) * weights[i]
// Grid: x = batch row, y = element index.
kernel void rms_norm_batch(
    device float* batchInput [[buffer(0)]],   // [batchSize, N]
    device float* weights [[buffer(1)]],      // [N]
    device float* batchOutput [[buffer(2)]],  // [batchSize, N]
    constant uint32_t& N [[buffer(3)]],
    constant float& eps [[buffer(4)]],
    constant uint32_t& batchSize [[buffer(5)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint row = gid.x;
    const uint col = gid.y;
    if (row >= batchSize || col >= N) return;

    device float* x = batchInput + row * N;

    float sumSq = 0.0f;
    for (uint j = 0; j < N; ++j) {
        sumSq += x[j] * x[j];
    }
    const float rms = sqrt(sumSq / float(N) + eps);

    batchOutput[row * N + col] = x[col] / rms * weights[col];
}
|
||
|
||
// Batch sliding-window attention (simplified reference kernel; it only reads
// the KV cache and does no cache management).
//
// Dispatch: 3-D grid with gid.x = batch index, gid.y = query head,
// gid.z = element of the head dimension. One thread writes one output float.
// Out-of-range threads exit immediately.
//
// Cache layout, as implied by the indexing below: for each position t the
// keys of all KV heads come first, followed by the values:
//   key(t, h)   = kvCache + t * 2*nKvHeads*headDim + h*headDim
//   value(t, h) = key(t, h) + nKvHeads*headDim
// Query head q reads KV head q / (nHeads / nKvHeads). This assumes nHeads is
// a multiple of nKvHeads.
//
// Attended positions: [max(pos - windowSize, 0), pos), i.e. at most
// windowSize keys strictly before pos.
// NOTE(review): position `pos` itself is excluded. Confirm this matches when
// the caller writes the current token's K/V (use `pos + 1` as the end if it
// is written before this kernel runs).
//
// Softmax uses the single-pass "online" formulation (running max plus a
// rescaled running sum). Each key/value row is therefore read once instead of
// three times. An empty window produces 0.
kernel void sliding_attention_batch(
    device float* batchQuery [[buffer(0)]],     // [batchSize, nHeads, headDim]
    device float* kvCache [[buffer(1)]],        // [maxSeqLen, 2, nKvHeads, headDim]
    device float* batchOutput [[buffer(2)]],    // [batchSize, nHeads, headDim]
    constant uint32_t* positions [[buffer(3)]], // [batchSize]
    constant uint32_t& nHeads [[buffer(4)]],
    constant uint32_t& nKvHeads [[buffer(5)]],
    constant uint32_t& headDim [[buffer(6)]],
    constant uint32_t& batchSize [[buffer(7)]],
    constant uint32_t& windowSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    uint batchIdx = gid.x;
    uint headIdx = gid.y;
    uint dimIdx = gid.z;

    if (batchIdx >= batchSize || headIdx >= nHeads || dimIdx >= headDim) return;

    uint pos = positions[batchIdx];
    uint kvHeadIdx = headIdx / (nHeads / nKvHeads);

    device const float* query = batchQuery + (batchIdx * nHeads + headIdx) * headDim;

    // pos and windowSize are unsigned, so `pos - windowSize` wraps around when
    // pos < windowSize. That makes max(0u, ...) useless as a clamp, so clamp
    // explicitly before subtracting.
    uint start = (pos > windowSize) ? (pos - windowSize) : 0u;
    uint end = pos;

    const uint kvStride = 2u * nKvHeads * headDim;  // floats per cache position
    const uint valueOffset = nKvHeads * headDim;    // V block follows K block
    const float invSqrtDim = rsqrt(float(headDim)); // loop-invariant score scale

    float runningMax = -1e10f; // same finite sentinel as a two-pass max
    float runningSum = 0.0f;   // sum of exp(score - runningMax)
    float acc = 0.0f;          // sum of exp(score - runningMax) * value[dimIdx]

    for (uint t = start; t < end; t++) {
        device const float* key = kvCache + t * kvStride + kvHeadIdx * headDim;
        device const float* value = key + valueOffset;

        float score = 0.0f;
        for (uint d = 0; d < headDim; d++) {
            score += query[d] * key[d];
        }
        score *= invSqrtDim;

        // Rescale the previous partial sums to the new maximum, then add this key.
        float newMax = max(runningMax, score);
        float correction = exp(runningMax - newMax);
        float p = exp(score - newMax);
        runningSum = runningSum * correction + p;
        acc = acc * correction + p * value[dimIdx];
        runningMax = newMax;
    }

    float output = (runningSum > 0.0f) ? (acc / runningSum) : 0.0f;
    batchOutput[(batchIdx * nHeads + headIdx) * headDim + dimIdx] = output;
}
|
||
// ═══════════════════════════════════════════════════════════════
|
||
// Batch Layer Processing Kernels
|
||
// Process entire layer for multiple tokens simultaneously
|
||
// ═══════════════════════════════════════════════════════════════
|
||
|
||
// Batch RMS Norm for layer input.
// Normalizes each [hiddenSize] row of a row-major [batchSize, hiddenSize]
// matrix using one shared weight vector.
//
// Dispatch: gid.x = token index within the batch, gid.y = hidden index.
//   out[b, j] = in[b, j] / sqrt(sum_k in[b, k]^2 / hiddenSize + eps) * weights[j]
// Every thread re-sums its entire row (no threadgroup reduction).
// `weights` is dereferenced unconditionally.
kernel void batch_layer_rms_norm(
    device float* batchInput [[buffer(0)]],  // [batchSize, hiddenSize]
    device float* weights [[buffer(1)]],     // [hiddenSize]
    device float* batchOutput [[buffer(2)]], // [batchSize, hiddenSize]
    constant uint32_t& hiddenSize [[buffer(3)]],
    constant float& eps [[buffer(4)]],
    constant uint32_t& batchSize [[buffer(5)]],
    uint3 gid [[thread_position_in_grid]])
{
    if (gid.x >= batchSize) return;
    if (gid.y >= hiddenSize) return;

    const uint token = gid.x;
    const uint lane = gid.y;
    device const float* src = batchInput + token * hiddenSize;
    device float* dst = batchOutput + token * hiddenSize;

    float sumSq = 0.0f;
    for (uint k = 0; k != hiddenSize; ++k) {
        sumSq += src[k] * src[k];
    }

    const float rms = sqrt(sumSq / float(hiddenSize) + eps);
    dst[lane] = src[lane] / rms * weights[lane];
}
|
||
|
||
// Batch Quantized Matmul for layer projections.
// out[b, o] = bias[o] + sum_i in[b, i] * (q[o, i] - 128) * scale[o, i / groupSize]
//
// Weights: one unsigned byte per element, row-major [outDim, inDim], stored
// with a +128 offset. Scales: one float per (row, group), rows of
// inDim / groupSize entries. Biases: one float per output row.
// Dispatch: gid.x = token index within the batch, gid.y = output feature.
// NOTE(review): if inDim is not a multiple of groupSize, the tail elements
// index one scale past the row's inDim / groupSize entries. Confirm the
// host's dimensions always divide evenly.
kernel void batch_layer_quantized_matmul(
    device float* batchInput [[buffer(0)]],  // [batchSize, inDim]
    device uint8_t* weights [[buffer(1)]],   // [outDim, inDim] packed
    device float* scales [[buffer(2)]],      // [outDim, groups]
    device float* biases [[buffer(3)]],      // [outDim]
    device float* batchOutput [[buffer(4)]], // [batchSize, outDim]
    constant uint32_t& inDim [[buffer(5)]],
    constant uint32_t& outDim [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint row = gid.x; // token within the batch
    const uint col = gid.y; // output feature
    const bool outside = (row >= batchSize) || (col >= outDim);
    if (outside) {
        return;
    }

    device const float* x = batchInput + row * inDim;
    device const uint8_t* q = weights + col * inDim;
    device const float* rowScales = scales + col * (inDim / groupSize);

    float acc = biases[col];
    for (uint k = 0; k < inDim; ++k) {
        const float centered = float(int(q[k]) - 128);
        acc += x[k] * centered * rowScales[k / groupSize];
    }

    batchOutput[row * outDim + col] = acc;
}
|
||
|
||
// Batch Elementwise Add for residual connections.
// out[b, j] = a[b, j] + b[b, j] over row-major [batchSize, size] matrices.
// Dispatch: gid.x = row, gid.y = column. Each thread reads and writes only its
// own element, so the output may alias either input.
kernel void batch_eltwise_add(
    device float* batchA [[buffer(0)]],      // [batchSize, size]
    device float* batchB [[buffer(1)]],      // [batchSize, size]
    device float* batchOutput [[buffer(2)]], // [batchSize, size]
    constant uint32_t& size [[buffer(3)]],
    constant uint32_t& batchSize [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]])
{
    const bool outside = gid.x >= batchSize || gid.y >= size;
    if (outside) {
        return;
    }

    const uint flat = gid.x * size + gid.y;
    const float lhs = batchA[flat];
    const float rhs = batchB[flat];
    batchOutput[flat] = lhs + rhs;
}
|
||
|
||
// Batch Gated FFN (fused gate + up projection).
// For each token b and intermediate unit u:
//   gate = gateBias[u] + sum_i x[b, i] * (gq[u, i] - 128) * gateScale[u, i / groupSize]
//   up   = upBias[u]   + sum_i x[b, i] * (uq[u, i] - 128) * upScale[u, i / groupSize]
//   out[b, u] = gate * sigmoid(gate) * up
// Both projections use one byte per weight with a +128 offset, and rows of
// hiddenSize / groupSize scales.
// Dispatch: gid.x = token index within the batch, gid.y = intermediate unit.
// The two dot products share a single pass over the input row; each
// accumulator still sums in index order.
kernel void batch_fused_gate_up(
    device float* batchInput [[buffer(0)]],    // [batchSize, hiddenSize]
    device uint8_t* gateWeights [[buffer(1)]], // [intermediateSize, hiddenSize]
    device float* gateScales [[buffer(2)]],
    device float* gateBiases [[buffer(3)]],
    device uint8_t* upWeights [[buffer(4)]],   // [intermediateSize, hiddenSize]
    device float* upScales [[buffer(5)]],
    device float* upBiases [[buffer(6)]],
    device float* batchOutput [[buffer(7)]],   // [batchSize, intermediateSize]
    constant uint32_t& hiddenSize [[buffer(8)]],
    constant uint32_t& intermediateSize [[buffer(9)]],
    constant uint32_t& groupSize [[buffer(10)]],
    constant uint32_t& batchSize [[buffer(11)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint token = gid.x;
    const uint unit = gid.y;
    if (token >= batchSize || unit >= intermediateSize) {
        return;
    }

    device const float* x = batchInput + token * hiddenSize;
    const uint weightRow = unit * hiddenSize;
    const uint scaleRow = unit * (hiddenSize / groupSize);

    float gate = gateBiases[unit];
    float up = upBiases[unit];
    for (uint k = 0; k < hiddenSize; ++k) {
        const uint sIdx = scaleRow + k / groupSize;
        const float xk = x[k];
        gate += xk * float(int(gateWeights[weightRow + k]) - 128) * gateScales[sIdx];
        up   += xk * float(int(upWeights[weightRow + k]) - 128) * upScales[sIdx];
    }

    // SiLU(gate) * up
    const float sig = 1.0f / (1.0f + exp(-gate));
    batchOutput[token * intermediateSize + unit] = gate * sig * up;
}
|
||
|
||
// Batch Down Projection (FFN output).
// out[b, o] = downBias[o]
//           + sum_i inter[b, i] * (q[o, i] - 128) * downScale[o, i / groupSize]
// Weights: one byte per element, row-major [hiddenSize, intermediateSize],
// +128 offset. Scales: rows of intermediateSize / groupSize entries.
// Dispatch: gid.x = token index within the batch, gid.y = output hidden index.
kernel void batch_down_projection(
    device float* batchInter [[buffer(0)]],    // [batchSize, intermediateSize]
    device uint8_t* downWeights [[buffer(1)]], // [hiddenSize, intermediateSize]
    device float* downScales [[buffer(2)]],
    device float* downBiases [[buffer(3)]],
    device float* batchOutput [[buffer(4)]],   // [batchSize, hiddenSize]
    constant uint32_t& hiddenSize [[buffer(5)]],
    constant uint32_t& intermediateSize [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    if (gid.x >= batchSize || gid.y >= hiddenSize) return;

    const uint token = gid.x;
    const uint feature = gid.y;
    const uint groupsPerRow = intermediateSize / groupSize;

    device const float* act = batchInter + token * intermediateSize;
    device const uint8_t* qRow = downWeights + feature * intermediateSize;
    device const float* sRow = downScales + feature * groupsPerRow;

    float total = downBiases[feature];
    for (uint k = 0; k < intermediateSize; ++k) {
        const int centered = int(qRow[k]) - 128;
        total += act[k] * float(centered) * sRow[k / groupSize];
    }

    batchOutput[token * hiddenSize + feature] = total;
}
|
||
|
||
// ══════════════════════════════════════════════════════════════════
|
||
// Batch Embedding Lookup - Process multiple token embeddings in parallel
|
||
// Eliminates sequential waitUntilCompleted() bottleneck
|
||
// ══════════════════════════════════════════════════════════════════
|
||
|
||
// Batch version of dequantize_row - gathers and dequantizes one embedding row
// per token id.
//
// Weights pack eight 4-bit values per 32-bit word (lowest nibble first), with
// nCols / 8 words per vocabulary row. Scales and biases hold
// nCols / groupSize floats per vocabulary row. Each element is reconstructed
// affinely:
//   out[slot, c] = q * scale[row, c / groupSize] + bias[row, c / groupSize]
// Dispatch: gid.x = position in the batch, gid.y = embedding column.
// Assumes nCols and groupSize are multiples of 8. tokenIds are not
// range-checked here.
kernel void dequantize_row_batch(
    device const uint *w [[buffer(0)]],        // [vocabSize, nCols/8]
    device const float *s [[buffer(1)]],       // [vocabSize, numGroups]
    device const float *b [[buffer(2)]],       // [vocabSize, numGroups]
    device const uint *tokenIds [[buffer(3)]], // [batchSize] - which rows to lookup
    device float *out [[buffer(4)]],           // [batchSize, nCols]
    constant uint &nCols [[buffer(5)]],
    constant uint &batchSize [[buffer(6)]],
    constant uint &groupSize [[buffer(7)]],
    uint3 gid [[thread_position_in_grid]]
) {
    const uint slot = gid.x;
    const uint col = gid.y;
    if (slot >= batchSize || col >= nCols) {
        return;
    }

    const uint row = tokenIds[slot];
    const uint group = col / groupSize;
    const uint lane = col % groupSize; // position inside the group

    const uint wordsPerRow = nCols / 8;
    const uint word = group * (groupSize / 8) + lane / 8;
    const uint nibble = lane % 8;
    const uint q = (w[row * wordsPerRow + word] >> (nibble * 4)) & 0xFu;

    const uint paramIdx = row * (nCols / groupSize) + group;
    out[slot * nCols + col] = float(q) * s[paramIdx] + b[paramIdx];
}
|
||
|
||
// Batch version with scale applied (fused dequantize + scale).
// Same gather and 4-bit affine dequantization as the unscaled batch lookup,
// followed by a multiply with the global `embedScale`:
//   out[slot, c] = (q * scale[row, g] + bias[row, g]) * embedScale,
//   with g = c / groupSize.
// Packing: eight 4-bit values per 32-bit word, lowest nibble first, nCols / 8
// words per vocabulary row. Assumes nCols and groupSize are multiples of 8.
// Dispatch: gid.x = position in the batch, gid.y = embedding column.
kernel void dequantize_row_batch_scaled(
    device const uint *w [[buffer(0)]],        // [vocabSize, nCols/8]
    device const float *s [[buffer(1)]],       // [vocabSize, numGroups]
    device const float *b [[buffer(2)]],       // [vocabSize, numGroups]
    device const uint *tokenIds [[buffer(3)]], // [batchSize] - which rows to lookup
    device float *out [[buffer(4)]],           // [batchSize, nCols]
    constant uint &nCols [[buffer(5)]],
    constant uint &batchSize [[buffer(6)]],
    constant uint &groupSize [[buffer(7)]],
    constant float &embedScale [[buffer(8)]],  // Global embedding scale
    uint3 gid [[thread_position_in_grid]]
) {
    if (gid.x >= batchSize || gid.y >= nCols) return;

    const uint tokenId = tokenIds[gid.x];
    const uint g = gid.y / groupSize;
    const uint inG = gid.y % groupSize;

    device const uint *rowWords = w + tokenId * (nCols / 8);
    const uint packed = rowWords[g * (groupSize / 8) + inG / 8];
    const float q = float((packed >> ((inG % 8) * 4)) & 0xF);

    const uint param = tokenId * (nCols / groupSize) + g;
    const float dequantized = q * s[param] + b[param];
    out[gid.x * nCols + gid.y] = dequantized * embedScale;
}