Files
markbaseengine/Sources/MarkBase/Metal/OptimizedKernels.metal
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

1228 lines
71 KiB
Metal
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include <metal_stdlib>
using namespace metal;
// ═══════════════════════════════════════════════
// SIMD Optimized Kernels - Phase 1
// ═══════════════════════════════════════════════
// Number of float lanes in a float4: the kernels below walk head dimensions
// and input vectors SIMD_WIDTH floats at a time.
constant uint SIMD_WIDTH = sizeof(float4) / sizeof(float);
// Head-dimension constant (presumably the model's attention head size).
// NOTE(review): not referenced by the kernels visible in this review; they
// receive headDim as a buffer argument instead.
constant uint HEAD_DIM = 128u;
// Window-length constant (presumably the largest sliding-attention window).
// NOTE(review): not referenced by the kernels visible in this review.
constant uint MAX_WINDOW = 4096u;
// ── SIMD Optimized Sliding Attention ───────────────
// Single-query attention over the most recent cached tokens of a sliding
// window. K/V are staged in threadgroup memory and processed as float4s.
//
// Dispatch: 2-D grid. gid.x = query head (< nHeads); gid.y = float4 block of
// the head dimension (< headDim / SIMD_WIDTH). Each thread writes SIMD_WIDTH
// consecutive outputs of one head.
//
// Buffers:
//   q, out : [nHeads][headDim]
//   k, v   : [windowSize slots][nKvHeads][headDim]; logical position p is read
//            from slot p % windowSize.
//   offset : logical position of the current token; the attended positions
//            are offset - min(offset + 1, windowSize) + 1 .. offset.
//
// Threadgroup memory: shared_k and shared_v must EACH hold
//   min(offset + 1, windowSize) * nKvHeads * (headDim / SIMD_WIDTH) float4s,
// laid out [t][kvHead][block]. NOTE(review): this grows with the window and
// can exceed the per-threadgroup memory limit; confirm the host allocation.
//
// Preconditions: headDim % SIMD_WIDTH == 0, offset >= 0.
kernel void sliding_attention_simd(
    device const float *q [[buffer(0)]],
    device const float *k [[buffer(1)]],
    device const float *v [[buffer(2)]],
    device float *out [[buffer(3)]],
    constant uint &nHeads [[buffer(4)]],
    constant uint &nKvHeads [[buffer(5)]],
    constant uint &headDim [[buffer(6)]],
    constant uint &windowSize [[buffer(7)]],
    constant int &offset [[buffer(8)]],
    threadgroup float4 *shared_k [[threadgroup(0)]], // [window, nKvHeads, headDim/SIMD_WIDTH]
    threadgroup float4 *shared_v [[threadgroup(1)]], // [window, nKvHeads, headDim/SIMD_WIDTH]
    uint2 gid [[thread_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint2 tgSize [[threads_per_threadgroup]]
) {
    const uint blocksPerHead = headDim / SIMD_WIDTH;      // float4 blocks per head row
    const uint blocksPerToken = nKvHeads * blocksPerHead; // float4 blocks per cached token
    const uint head = gid.x;
    const uint dimBlock = gid.y;
    // Out-of-range threads must NOT return before the barrier. The cooperative
    // load strides by the full threadgroup size, so every thread owns part of
    // the cache. All threads of a threadgroup must also reach threadgroup_barrier.
    const bool active = head < nHeads && dimBlock < blocksPerHead;

    const uint seqLen = uint(offset + 1);
    const uint actualWindow = min(seqLen, windowSize);
    const int base = offset - int(actualWindow) + 1;  // first attended logical position

    // ── Threadgroup Cache Loading ───────────────────
    const uint loadStride = tgSize.x * tgSize.y;
    const uint totalElements = actualWindow * blocksPerToken;
    for (uint idx = tid.y * tgSize.x + tid.x; idx < totalElements; idx += loadStride) {
        const uint t = idx / blocksPerToken;
        const uint h = (idx % blocksPerToken) / blocksPerHead;
        const uint db = idx % blocksPerHead;
        const int logicalPos = base + int(t);
        const uint slot = logicalPos >= 0 ? uint(logicalPos) % windowSize : 0;
        const uint src = (slot * nKvHeads + h) * headDim + db * SIMD_WIDTH;
        shared_k[idx] = float4(k[src], k[src + 1], k[src + 2], k[src + 3]);
        shared_v[idx] = float4(v[src], v[src + 1], v[src + 2], v[src + 3]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (!active) return;

    // NOTE(review): HF/MLX-style grouped-query attention maps query head h to
    // KV head h / (nHeads / nKvHeads); the modulo below assumes an interleaved
    // head layout. Confirm against the reference kernel.
    const uint kvHead = head % nKvHeads;
    const float scale = 1.0f / sqrt(float(headDim));
    device const float *qRow = q + head * headDim;

    // Pass 1: maximum score (for a numerically stable softmax). A score is the
    // dot product over the WHOLE head dimension, so every thread of a head sees
    // the same attention weights. A per-float4-block dot product would give each
    // thread a different, wrong softmax.
    float maxScore = -INFINITY;
    for (uint t = 0; t < actualWindow; t++) {
        threadgroup const float4 *kRow = shared_k + t * blocksPerToken + kvHead * blocksPerHead;
        float dotQK = 0.0f;
        for (uint d = 0; d < blocksPerHead; d++) {
            const uint o = d * SIMD_WIDTH;
            dotQK += dot(float4(qRow[o], qRow[o + 1], qRow[o + 2], qRow[o + 3]), kRow[d]);
        }
        // Text model has NO attention softcapping (same as original kernel)
        maxScore = max(maxScore, dotQK * scale);
    }

    // Pass 2: softmax weights applied to this thread's float4 slice of V.
    // The score is recomputed with exactly the same expression as in pass 1.
    float sumExp = 0.0f;
    float4 resultVec = float4(0.0f);
    for (uint t = 0; t < actualWindow; t++) {
        const uint rowBase = t * blocksPerToken + kvHead * blocksPerHead;
        threadgroup const float4 *kRow = shared_k + rowBase;
        float dotQK = 0.0f;
        for (uint d = 0; d < blocksPerHead; d++) {
            const uint o = d * SIMD_WIDTH;
            dotQK += dot(float4(qRow[o], qRow[o + 1], qRow[o + 2], qRow[o + 3]), kRow[d]);
        }
        const float expVal = exp(dotQK * scale - maxScore);
        sumExp += expVal;
        resultVec += expVal * shared_v[rowBase + dimBlock];
    }
    // An empty window (sumExp == 0) yields zeros instead of 0/0 = NaN.
    if (sumExp > 0.0f) {
        resultVec /= sumExp;
    }

    // Write output
    const uint outOffset = head * headDim + dimBlock * SIMD_WIDTH;
    out[outOffset] = resultVec.x;
    out[outOffset + 1] = resultVec.y;
    out[outOffset + 2] = resultVec.z;
    out[outOffset + 3] = resultVec.w;
}
// ── SIMD Optimized Full Attention ──────────────────
// Single-query attention over all `seqLen` cached tokens (no masking). K/V are
// staged in threadgroup memory and processed as float4s.
//
// Dispatch: 2-D grid. gid.x = query head (< nHeads); gid.y = float4 block of
// the head dimension (< headDim / SIMD_WIDTH). Each thread writes SIMD_WIDTH
// consecutive outputs of one head.
//
// Buffers:
//   q, out : [nHeads][headDim]
//   k, v   : [seqLen][nKvHeads][headDim]
//
// Threadgroup memory: shared_k and shared_v must EACH hold
//   seqLen * nKvHeads * (headDim / SIMD_WIDTH) float4s, laid out
// [t][kvHead][block]. NOTE(review): this grows with seqLen and can exceed the
// per-threadgroup memory limit; confirm the host allocation.
//
// Preconditions: headDim % SIMD_WIDTH == 0.
kernel void full_attention_simd(
    device const float *q [[buffer(0)]],
    device const float *k [[buffer(1)]],
    device const float *v [[buffer(2)]],
    device float *out [[buffer(3)]],
    constant uint &nHeads [[buffer(4)]],
    constant uint &nKvHeads [[buffer(5)]],
    constant uint &headDim [[buffer(6)]],
    constant uint &seqLen [[buffer(7)]],
    threadgroup float4 *shared_k [[threadgroup(0)]],
    threadgroup float4 *shared_v [[threadgroup(1)]],
    uint2 gid [[thread_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint2 tgSize [[threads_per_threadgroup]]
) {
    const uint blocksPerHead = headDim / SIMD_WIDTH;      // float4 blocks per head row
    const uint blocksPerToken = nKvHeads * blocksPerHead; // float4 blocks per cached token
    const uint head = gid.x;
    const uint dimBlock = gid.y;
    // Out-of-range threads must NOT return before the barrier. The cooperative
    // load strides by the full threadgroup size, so every thread owns part of
    // the cache. All threads of a threadgroup must also reach threadgroup_barrier.
    const bool active = head < nHeads && dimBlock < blocksPerHead;

    // Threadgroup cache loading
    const uint loadStride = tgSize.x * tgSize.y;
    const uint totalElements = seqLen * blocksPerToken;
    for (uint idx = tid.y * tgSize.x + tid.x; idx < totalElements; idx += loadStride) {
        const uint t = idx / blocksPerToken;
        const uint h = (idx % blocksPerToken) / blocksPerHead;
        const uint db = idx % blocksPerHead;
        const uint src = (t * nKvHeads + h) * headDim + db * SIMD_WIDTH;
        shared_k[idx] = float4(k[src], k[src + 1], k[src + 2], k[src + 3]);
        shared_v[idx] = float4(v[src], v[src + 1], v[src + 2], v[src + 3]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (!active) return;

    // NOTE(review): HF/MLX-style grouped-query attention maps query head h to
    // KV head h / (nHeads / nKvHeads); the modulo below assumes an interleaved
    // head layout. Confirm against the reference kernel.
    const uint kvHead = head % nKvHeads;
    const float scale = 1.0f / sqrt(float(headDim));
    device const float *qRow = q + head * headDim;

    // Pass 1: max score. Scores are full-head dot products, so every thread of
    // a head uses the same softmax weights.
    float maxScore = -INFINITY;
    for (uint t = 0; t < seqLen; t++) {
        threadgroup const float4 *kRow = shared_k + t * blocksPerToken + kvHead * blocksPerHead;
        float dotQK = 0.0f;
        for (uint d = 0; d < blocksPerHead; d++) {
            const uint o = d * SIMD_WIDTH;
            dotQK += dot(float4(qRow[o], qRow[o + 1], qRow[o + 2], qRow[o + 3]), kRow[d]);
        }
        // Text model has NO attention softcapping
        maxScore = max(maxScore, dotQK * scale);
    }

    // Pass 2: softmax + weighted sum of this thread's float4 slice of V.
    float sumExp = 0.0f;
    float4 resultVec = float4(0.0f);
    for (uint t = 0; t < seqLen; t++) {
        const uint rowBase = t * blocksPerToken + kvHead * blocksPerHead;
        threadgroup const float4 *kRow = shared_k + rowBase;
        float dotQK = 0.0f;
        for (uint d = 0; d < blocksPerHead; d++) {
            const uint o = d * SIMD_WIDTH;
            dotQK += dot(float4(qRow[o], qRow[o + 1], qRow[o + 2], qRow[o + 3]), kRow[d]);
        }
        const float expVal = exp(dotQK * scale - maxScore);
        sumExp += expVal;
        resultVec += expVal * shared_v[rowBase + dimBlock];
    }
    // seqLen == 0 yields zeros instead of 0/0 = NaN.
    if (sumExp > 0.0f) {
        resultVec /= sumExp;
    }

    const uint outOffset = head * headDim + dimBlock * SIMD_WIDTH;
    out[outOffset] = resultVec.x;
    out[outOffset + 1] = resultVec.y;
    out[outOffset + 2] = resultVec.z;
    out[outOffset + 3] = resultVec.w;
}
// ── Block-based Quantized Matmul (High Performance) ────
// 4-bit affine-quantized matrix-vector product with several threads per row:
//   out[row] = sum_i (q[row, i] * s[row, g(i)] + b[row, g(i)]) * x[i],
// where g(i) = i / groupSize. Each uint32 of `w` packs 8 nibbles, least
// significant nibble first.
//
// Dispatch: 2-D. gid.x = output row. Within a threadgroup, tid.x selects the
// row and the tgSize.y threads along y split that row's quantization groups.
// Their partial sums are then reduced through threadgroup memory.
//
// Threadgroup memory:
//   partial_sums : tgSize.x * tgSize.y floats (one slot per thread)
//   shared_x     : inDim floats (input vector cache)
//
// Preconditions: inDim % groupSize == 0, groupSize % 8 == 0.
kernel void quantized_matmul_block(
    device const float *x [[buffer(0)]],
    device const uint *w [[buffer(1)]],
    device const float *s [[buffer(2)]],
    device const float *b [[buffer(3)]],
    device float *out [[buffer(4)]],
    constant uint &inDim [[buffer(5)]],
    constant uint &outDim [[buffer(6)]],
    constant uint &groupSize [[buffer(7)]], // 64
    threadgroup float *partial_sums [[threadgroup(0)]], // For reduction
    threadgroup float *shared_x [[threadgroup(1)]], // Input vector cache
    uint2 gid [[thread_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint2 tgSize [[threads_per_threadgroup]]
) {
    const uint outRow = gid.x;
    // Lane within the row. This uses the threadgroup-local position so it always
    // addresses this threadgroup's partial_sums, whatever the grid height.
    const uint threadInRow = tid.y;
    const uint numThreadsPerRow = tgSize.y;
    const uint numGroups = inDim / groupSize;
    const uint packedPerGroup = groupSize / 8;  // 8 nibbles per uint32
    // Threads whose row is past outDim must not return early. They still help
    // fill shared_x (the load loop strides by the whole threadgroup), and they
    // must reach both barriers.
    const bool active = outRow < outDim;

    // ── Cooperative loading of input vector ──────────────────
    for (uint i = tid.y * tgSize.x + tid.x; i < inDim; i += tgSize.x * tgSize.y) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Partial dot product for this thread ────────────────
    const uint4 loShifts = uint4(0u, 4u, 8u, 12u);
    const uint4 hiShifts = uint4(16u, 20u, 24u, 28u);
    float localSum = 0.0f;
    if (active) {
        // Round-robin over groups keeps the work balanced even when numGroups is
        // not a multiple of (or is smaller than) numThreadsPerRow.
        for (uint g = threadInRow; g < numGroups; g += numThreadsPerRow) {
            const float scale = s[outRow * numGroups + g];
            const float bias = b[outRow * numGroups + g];
            const uint packedBase = outRow * (inDim / 8) + g * packedPerGroup;
            const uint xBase = g * groupSize;
            for (uint p = 0; p < packedPerGroup; p++) {
                const uint4 word = uint4(w[packedBase + p]);
                threadgroup const float *xs = shared_x + xBase + p * 8;
                const float4 wLo = float4((word >> loShifts) & 0xFu) * scale + bias;
                const float4 wHi = float4((word >> hiShifts) & 0xFu) * scale + bias;
                localSum += dot(wLo, float4(xs[0], xs[1], xs[2], xs[3]));
                localSum += dot(wHi, float4(xs[4], xs[5], xs[6], xs[7]));
            }
        }
    }

    // ── Parallel reduction within threadgroup ───────────────────
    // Each row (tid.x) owns its own slice of partial_sums. Indexing by tid.y
    // alone would let rows sharing a threadgroup overwrite each other.
    threadgroup float *rowSums = partial_sums + tid.x * numThreadsPerRow;
    rowSums[threadInRow] = localSum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // One thread per row writes the final result.
    if (active && threadInRow == 0) {
        float finalSum = 0.0f;
        for (uint t = 0; t < numThreadsPerRow; t++) {
            finalSum += rowSums[t];
        }
        out[outRow] = finalSum;
    }
}
// ── SIMD Optimized Quantized Matmul (Legacy) ───────────────
// Kept for backward compatibility.
//
// 4-bit affine-quantized matrix-vector product, one thread per output row:
//   out[row] = sum_i (q[row, i] * s[row, g(i)] + b[row, g(i)]) * x[i],
// where g(i) = i / groupSize. Each uint32 of `w` packs 8 nibbles, least
// significant nibble first.
//
// Dispatch: 1-D grid; gid = output row.
// Threadgroup memory: shared_x must hold inDim floats.
// Preconditions: inDim % groupSize == 0; groupSize % 32 == 0 (each inner step
// reads one uint4 = 32 nibbles); `w` 16-byte aligned for the uint4 loads.
kernel void quantized_matmul_simd(
    device const float *x [[buffer(0)]],
    device const uint *w [[buffer(1)]],
    device const float *s [[buffer(2)]],
    device const float *b [[buffer(3)]],
    device float *out [[buffer(4)]],
    constant uint &inDim [[buffer(5)]],
    constant uint &outDim [[buffer(6)]],
    constant uint &groupSize [[buffer(7)]], // 64
    threadgroup float *shared_x [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    const uint outRow = gid;

    // ── Cooperative input load ────────────────────────────
    // Every thread, including those past outDim, must take its share. The loop
    // strides by tgSize, so a thread that returned early would leave holes in
    // shared_x, and all threads must reach the barrier.
    for (uint i = tid; i < inDim; i += tgSize) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (outRow >= outDim) return;

    // ── Compute dot product ────────────────────────────────
    const uint numGroups = inDim / groupSize;
    const uint packedPerGroup = groupSize / 8;  // 8 nibbles per uint32
    const uint4 loShifts = uint4(0u, 4u, 8u, 12u);
    const uint4 hiShifts = uint4(16u, 20u, 24u, 28u);
    float sum = 0.0f;
    for (uint g = 0; g < numGroups; g++) {
        const float scale = s[outRow * numGroups + g];
        const float bias = b[outRow * numGroups + g];
        const uint packedBase = outRow * (inDim / 8) + g * packedPerGroup;
        const uint xBase = g * groupSize;
        // One vectorized uint4 load = 4 packed words = 32 weights per iteration.
        for (uint p = 0; p < packedPerGroup; p += 4) {
            const uint4 packed = *((device const uint4 *)(w + packedBase + p));
            threadgroup const float *xs = shared_x + xBase + p * 8;
            for (uint j = 0; j < 4; j++) {
                const uint4 word = uint4(packed[j]);
                threadgroup const float *xw = xs + j * 8;  // 8 inputs per packed word
                const float4 wLo = float4((word >> loShifts) & 0xFu) * scale + bias;
                const float4 wHi = float4((word >> hiShifts) & 0xFu) * scale + bias;
                sum += dot(wLo, float4(xw[0], xw[1], xw[2], xw[3]))
                     + dot(wHi, float4(xw[4], xw[5], xw[6], xw[7]));
            }
        }
    }
    out[outRow] = sum;
}
// ── 8-bit SIMD Quantized Matmul ────────────────────
// Same as quantized_matmul_simd but for 8-bit weights: 4 values per uint32,
// least significant byte first, mask 0xFF.
//   out[row] = sum_i (q[row, i] * s[row, g(i)] + b[row, g(i)]) * x[i],
// where g(i) = i / groupSize.
//
// Dispatch: 1-D grid; gid = output row.
// Threadgroup memory: shared_x must hold inDim floats.
// Preconditions: inDim % groupSize == 0; groupSize % 16 == 0 (each inner step
// reads one uint4 = 16 weights); `w` 16-byte aligned for the uint4 loads.
kernel void quantized_matmul_simd_8bit(
    device const float *x [[buffer(0)]],
    device const uint *w [[buffer(1)]],
    device const float *s [[buffer(2)]],
    device const float *b [[buffer(3)]],
    device float *out [[buffer(4)]],
    constant uint &inDim [[buffer(5)]],
    constant uint &outDim [[buffer(6)]],
    constant uint &groupSize [[buffer(7)]], // 64
    threadgroup float *shared_x [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    const uint outRow = gid;

    // Every thread, including those past outDim, must take its share of the
    // load (the loop strides by tgSize) and reach the barrier before exiting.
    for (uint i = tid; i < inDim; i += tgSize) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (outRow >= outDim) return;

    const uint numGroups = inDim / groupSize;
    const uint packedPerGroup = groupSize / 4;  // 4 bytes per uint32
    const uint4 byteShifts = uint4(0u, 8u, 16u, 24u);
    float sum = 0.0f;
    for (uint g = 0; g < numGroups; g++) {
        const float scale = s[outRow * numGroups + g];
        const float bias = b[outRow * numGroups + g];
        const uint packedBase = outRow * (inDim / 4) + g * packedPerGroup;
        const uint xBase = g * groupSize;
        // One uint4 load = 4 packed words = 16 weights per iteration.
        for (uint p = 0; p < packedPerGroup; p += 4) {
            const uint4 packed = *((device const uint4 *)(w + packedBase + p));
            threadgroup const float *xs = shared_x + xBase + p * 4;
            for (uint j = 0; j < 4; j++) {
                threadgroup const float *xw = xs + j * 4;  // 4 inputs per packed word
                const float4 wv = float4((uint4(packed[j]) >> byteShifts) & 0xFFu) * scale + bias;
                sum += dot(wv, float4(xw[0], xw[1], xw[2], xw[3]));
            }
        }
    }
    out[outRow] = sum;
}
// ── Fused Gate+Up+Down for MoE Experts (4-bit) ────
// Single kernel replaces: fusedGateUp + blit + downMatmul + scaledAdd
//   Phase 1: h[j] = clampedGELU(gate_j . x) * (up_j . x)   for j < moeIntermediate
//   Phase 2: accum[i] += expertWeight * (down_i . h)        for i < hiddenSize
// All weight matrices are 4-bit affine-quantized: 8 nibbles per uint32, least
// significant first, one (scale, bias) pair per `groupSize` weights of a row.
//
// Dispatch: 1-D. Thread gid computes intermediate row gid (phase 1) and hidden
// row gid (phase 2). Phase 2 reads the WHOLE intermediate vector from
// threadgroup memory, which only this threadgroup's threads wrote. Results are
// therefore complete only when one threadgroup covers
// max(hiddenSize, moeIntermediate) threads.
// NOTE(review): confirm the host dispatches exactly one such threadgroup.
//
// Threadgroup memory: shared_space must hold max(hiddenSize, moeIntermediate)
// floats. It holds x during phase 1 and is then overwritten with h.
//
// Preconditions: hiddenSize and moeIntermediate divisible by groupSize;
// groupSize % 32 == 0 (each inner step reads one uint4 = 32 nibbles); weight
// buffers 16-byte aligned for the uint4 loads.
kernel void quantized_matmul_gate_up_down(
    device const float *x [[buffer(0)]],
    device const uint *w_gate [[buffer(1)]],
    device const float *s_gate [[buffer(2)]],
    device const float *b_gate [[buffer(3)]],
    device const uint *w_up [[buffer(4)]],
    device const float *s_up [[buffer(5)]],
    device const float *b_up [[buffer(6)]],
    device const uint *w_down [[buffer(7)]],
    device const float *s_down [[buffer(8)]],
    device const float *b_down [[buffer(9)]],
    device float *accum [[buffer(10)]],
    constant uint &hiddenSize [[buffer(11)]],
    constant uint &moeIntermediate [[buffer(12)]],
    constant uint &groupSize [[buffer(13)]],
    constant float &expertWeight [[buffer(14)]],
    threadgroup float *shared_space [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    // ── Cooperative input load ────────────────────
    for (uint i = tid; i < hiddenSize; i += tgSize) {
        shared_space[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint numGroupsIn = hiddenSize / groupSize;
    const uint numGroupsOut = moeIntermediate / groupSize;
    const uint packedPerIn = hiddenSize / 8;
    const uint packedPerOut = moeIntermediate / 8;
    const uint packedPerGroup = groupSize / 8;  // 8 nibbles per uint32
    const uint4 loShifts = uint4(0u, 4u, 8u, 12u);
    const uint4 hiShifts = uint4(16u, 20u, 24u, 28u);

    // ── Phase 1: gate(x) + up(x) ─────────────────
    const bool hasIntermediateRow = gid < moeIntermediate;
    float product = 0.0f;
    if (hasIntermediateRow) {
        float gateSum = 0.0f, upSum = 0.0f;
        for (uint g = 0; g < numGroupsIn; g++) {
            const uint sIdx = gid * numGroupsIn + g;
            const float gScale = s_gate[sIdx];
            const float gBias = b_gate[sIdx];
            const float uScale = s_up[sIdx];
            const float uBias = b_up[sIdx];
            const uint wBase = gid * packedPerIn + g * packedPerGroup;
            const uint xBase = g * groupSize;
            for (uint p = 0; p < packedPerGroup; p += 4) {
                const uint4 gP = *((device const uint4 *)(w_gate + wBase + p));
                const uint4 uP = *((device const uint4 *)(w_up + wBase + p));
                threadgroup const float *xs = shared_space + xBase + p * 8;
                for (uint j = 0; j < 4; j++) {
                    threadgroup const float *xw = xs + j * 8;  // 8 inputs per packed word
                    const float4 xLo = float4(xw[0], xw[1], xw[2], xw[3]);
                    const float4 xHi = float4(xw[4], xw[5], xw[6], xw[7]);
                    const uint4 gw = uint4(gP[j]);
                    const uint4 uw = uint4(uP[j]);
                    gateSum += dot(float4((gw >> loShifts) & 0xFu) * gScale + gBias, xLo)
                             + dot(float4((gw >> hiShifts) & 0xFu) * gScale + gBias, xHi);
                    upSum += dot(float4((uw >> loShifts) & 0xFu) * uScale + uBias, xLo)
                           + dot(float4((uw >> hiShifts) & 0xFu) * uScale + uBias, xHi);
                }
            }
        }
        // GELU activation (same formula as existing quantized_matmul_gate_up).
        // Explicit comparisons, not clamp(), so a NaN still propagates to the
        // isnan check below and is zeroed.
        if (gateSum > 100.0f) gateSum = 100.0f;
        if (gateSum < -100.0f) gateSum = -100.0f;
        const float gv = gateSum;
        float geluVal;
        const float absv = gv > 0 ? gv : -gv;
        if (absv > 10.0f) {
            geluVal = gv > 0 ? gv : 0.0f;
        } else {
            const float gv3 = gv * gv * gv;
            geluVal = 0.5f * gv * (1.0f + tanh(0.7978845608028654f * (gv + 0.044715f * gv3)));
        }
        // Clamp upSum
        if (upSum > 100.0f) upSum = 100.0f;
        if (upSum < -100.0f) upSum = -100.0f;
        product = geluVal * upSum;
        if (product > 10.0f) product = 10.0f;
        if (product < -10.0f) product = -10.0f;
        if (isnan(product) || isinf(product)) product = 0.0f;
    }
    // Every thread must finish reading x from shared_space before any
    // intermediate value overwrites it. Without this barrier, a fast thread's
    // store to shared_space[gid] races with slower threads still reading x.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (hasIntermediateRow) {
        shared_space[gid] = product;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 2: down(intermediate) + accumulate ─
    if (gid < hiddenSize) {
        float sum = 0.0f;
        for (uint g = 0; g < numGroupsOut; g++) {
            const float scale = s_down[gid * numGroupsOut + g];
            const float bias = b_down[gid * numGroupsOut + g];
            const uint wBase = gid * packedPerOut + g * packedPerGroup;
            const uint iBase = g * groupSize;
            for (uint p = 0; p < packedPerGroup; p += 4) {
                const uint4 packed = *((device const uint4 *)(w_down + wBase + p));
                threadgroup const float *hs = shared_space + iBase + p * 8;
                for (uint j = 0; j < 4; j++) {
                    threadgroup const float *hw = hs + j * 8;  // 8 intermediates per word
                    const uint4 word = uint4(packed[j]);
                    sum += dot(float4((word >> loShifts) & 0xFu) * scale + bias,
                               float4(hw[0], hw[1], hw[2], hw[3]))
                         + dot(float4((word >> hiShifts) & 0xFu) * scale + bias,
                               float4(hw[4], hw[5], hw[6], hw[7]));
                }
            }
        }
        accum[gid] += expertWeight * sum;
    }
}
// ── Fused Gate+Up+Down for MoE Experts (8-bit) ────
kernel void quantized_matmul_gate_up_down_8bit(
device const float *x [[buffer(0)]],
device const uint *w_gate [[buffer(1)]],
device const float *s_gate [[buffer(2)]],
device const float *b_gate [[buffer(3)]],
device const uint *w_up [[buffer(4)]],
device const float *s_up [[buffer(5)]],
device const float *b_up [[buffer(6)]],
device const uint *w_down [[buffer(7)]],
device const float *s_down [[buffer(8)]],
device const float *b_down [[buffer(9)]],
device float *accum [[buffer(10)]],
constant uint &hiddenSize [[buffer(11)]],
constant uint &moeIntermediate [[buffer(12)]],
constant uint &groupSize [[buffer(13)]],
constant float &expertWeight [[buffer(14)]],
threadgroup float *shared_space [[threadgroup(0)]],
uint gid [[thread_position_in_grid]],
uint tid [[thread_position_in_threadgroup]],
uint tgSize [[threads_per_threadgroup]]
) {
for (uint i = tid; i < hiddenSize; i += tgSize) {
shared_space[i] = x[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
uint numGroupsIn = hiddenSize / groupSize;
uint numGroupsOut = moeIntermediate / groupSize;
uint packedPerIn = hiddenSize / 4;
uint packedPerOut = moeIntermediate / 4;
if (gid < moeIntermediate) {
float gateSum = 0.0, upSum = 0.0;
for (uint g = 0; g < numGroupsIn; g++) {
float gScale = s_gate[gid * numGroupsIn + g];
float gBias = b_gate[gid * numGroupsIn + g];
float uScale = s_up[gid * numGroupsIn + g];
float uBias = b_up[gid * numGroupsIn + g];
uint wBase = gid * packedPerIn + g * (groupSize / 4);
uint xBase = g * groupSize;
for (uint p = 0; p < groupSize / 4; p += 4) {
device uint4 *gPtr = (device uint4*)(&w_gate[wBase + p]);
device uint4 *uPtr = (device uint4*)(&w_up[wBase + p]);
uint4 gP = *gPtr;
uint4 uP = *uPtr;
float4 x0 = float4(shared_space[xBase + p*4], shared_space[xBase + p*4 + 1], shared_space[xBase + p*4 + 2], shared_space[xBase + p*4 + 3]);
float4 x1 = float4(shared_space[xBase + p*4 + 4], shared_space[xBase + p*4 + 5], shared_space[xBase + p*4 + 6], shared_space[xBase + p*4 + 7]);
float4 x2 = float4(shared_space[xBase + p*4 + 8], shared_space[xBase + p*4 + 9], shared_space[xBase + p*4 + 10], shared_space[xBase + p*4 + 11]);
float4 x3 = float4(shared_space[xBase + p*4 + 12], shared_space[xBase + p*4 + 13], shared_space[xBase + p*4 + 14], shared_space[xBase + p*4 + 15]);
float4 g0 = float4(float((gP.x >> 0) & 0xFF) * gScale + gBias, float((gP.x >> 8) & 0xFF) * gScale + gBias, float((gP.x >> 16) & 0xFF) * gScale + gBias, float((gP.x >> 24) & 0xFF) * gScale + gBias);
float4 g1 = float4(float((gP.y >> 0) & 0xFF) * gScale + gBias, float((gP.y >> 8) & 0xFF) * gScale + gBias, float((gP.y >> 16) & 0xFF) * gScale + gBias, float((gP.y >> 24) & 0xFF) * gScale + gBias);
float4 g2 = float4(float((gP.z >> 0) & 0xFF) * gScale + gBias, float((gP.z >> 8) & 0xFF) * gScale + gBias, float((gP.z >> 16) & 0xFF) * gScale + gBias, float((gP.z >> 24) & 0xFF) * gScale + gBias);
float4 g3 = float4(float((gP.w >> 0) & 0xFF) * gScale + gBias, float((gP.w >> 8) & 0xFF) * gScale + gBias, float((gP.w >> 16) & 0xFF) * gScale + gBias, float((gP.w >> 24) & 0xFF) * gScale + gBias);
float4 u0 = float4(float((uP.x >> 0) & 0xFF) * uScale + uBias, float((uP.x >> 8) & 0xFF) * uScale + uBias, float((uP.x >> 16) & 0xFF) * uScale + uBias, float((uP.x >> 24) & 0xFF) * uScale + uBias);
float4 u1 = float4(float((uP.y >> 0) & 0xFF) * uScale + uBias, float((uP.y >> 8) & 0xFF) * uScale + uBias, float((uP.y >> 16) & 0xFF) * uScale + uBias, float((uP.y >> 24) & 0xFF) * uScale + uBias);
float4 u2 = float4(float((uP.z >> 0) & 0xFF) * uScale + uBias, float((uP.z >> 8) & 0xFF) * uScale + uBias, float((uP.z >> 16) & 0xFF) * uScale + uBias, float((uP.z >> 24) & 0xFF) * uScale + uBias);
float4 u3 = float4(float((uP.w >> 0) & 0xFF) * uScale + uBias, float((uP.w >> 8) & 0xFF) * uScale + uBias, float((uP.w >> 16) & 0xFF) * uScale + uBias, float((uP.w >> 24) & 0xFF) * uScale + uBias);
gateSum += dot(g0, x0) + dot(g1, x1) + dot(g2, x2) + dot(g3, x3);
upSum += dot(u0, x0) + dot(u1, x1) + dot(u2, x2) + dot(u3, x3);
}
}
if (gateSum > 100.0) gateSum = 100.0;
if (gateSum < -100.0) gateSum = -100.0;
float v = gateSum;
float geluVal;
float absv = v > 0 ? v : -v;
if (absv > 10.0) {
geluVal = v > 0 ? v : 0.0;
} else {
float v3 = v * v * v;
geluVal = 0.5 * v * (1.0 + tanh(0.7978845608028654 * (v + 0.044715 * v3)));
}
if (upSum > 100.0) upSum = 100.0;
if (upSum < -100.0) upSum = -100.0;
float product = geluVal * upSum;
if (product > 10.0) product = 10.0;
if (product < -10.0) product = -10.0;
if (isnan(product) || isinf(product)) product = 0.0;
shared_space[gid] = product;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (gid < hiddenSize) {
float sum = 0.0;
for (uint g = 0; g < numGroupsOut; g++) {
float scale = s_down[gid * numGroupsOut + g];
float bias = b_down[gid * numGroupsOut + g];
uint wBase = gid * packedPerOut + g * (groupSize / 4);
uint iBase = g * groupSize;
for (uint p = 0; p < groupSize / 4; p += 4) {
device uint4 *wPtr = (device uint4*)(&w_down[wBase + p]);
uint4 packed = *wPtr;
float4 i0 = float4(shared_space[iBase + p*4], shared_space[iBase + p*4 + 1], shared_space[iBase + p*4 + 2], shared_space[iBase + p*4 + 3]);
float4 i1 = float4(shared_space[iBase + p*4 + 4], shared_space[iBase + p*4 + 5], shared_space[iBase + p*4 + 6], shared_space[iBase + p*4 + 7]);
float4 i2 = float4(shared_space[iBase + p*4 + 8], shared_space[iBase + p*4 + 9], shared_space[iBase + p*4 + 10], shared_space[iBase + p*4 + 11]);
float4 i3 = float4(shared_space[iBase + p*4 + 12], shared_space[iBase + p*4 + 13], shared_space[iBase + p*4 + 14], shared_space[iBase + p*4 + 15]);
float4 q0 = float4(float((packed.x >> 0) & 0xFF) * scale + bias, float((packed.x >> 8) & 0xFF) * scale + bias, float((packed.x >> 16) & 0xFF) * scale + bias, float((packed.x >> 24) & 0xFF) * scale + bias);
float4 q1 = float4(float((packed.y >> 0) & 0xFF) * scale + bias, float((packed.y >> 8) & 0xFF) * scale + bias, float((packed.y >> 16) & 0xFF) * scale + bias, float((packed.y >> 24) & 0xFF) * scale + bias);
float4 q2 = float4(float((packed.z >> 0) & 0xFF) * scale + bias, float((packed.z >> 8) & 0xFF) * scale + bias, float((packed.z >> 16) & 0xFF) * scale + bias, float((packed.z >> 24) & 0xFF) * scale + bias);
float4 q3 = float4(float((packed.w >> 0) & 0xFF) * scale + bias, float((packed.w >> 8) & 0xFF) * scale + bias, float((packed.w >> 16) & 0xFF) * scale + bias, float((packed.w >> 24) & 0xFF) * scale + bias);
sum += dot(q0, i0) + dot(q1, i1) + dot(q2, i2) + dot(q3, i3);
}
}
accum[gid] += expertWeight * sum;
}
}
// ── Parallel RMS Norm ──────────────────────────────
// y[i] = x[i] * rsqrt(mean(x^2) + eps) * w[i]   (w[i] taken as 1.0 when w is null)
// Every thread of a threadgroup strides over the whole vector of N floats, so a
// single threadgroup normalizes the entire vector; if several threadgroups are
// dispatched they each compute and write the same y.
//
// Threadgroup memory: threadgroup(0) must hold at least tgSize floats.
// The tree reduction is correct for any threadgroup size, power of two or not.
kernel void rms_norm_parallel(
    device const float *x [[buffer(0)]],
    device const float *w [[buffer(1)]],
    device float *y [[buffer(2)]],
    constant uint &N [[buffer(3)]],
    constant float &eps [[buffer(4)]],
    threadgroup float *partial_sums [[threadgroup(0)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]],
    uint gid [[thread_position_in_grid]]
) {
    // Phase 1: each thread accumulates the squares of its strided slice.
    float localSum = 0.0;
    for (uint i = tid; i < N; i += tgSize) {
        localSum += x[i] * x[i];
    }
    partial_sums[tid] = localSum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: tree reduction. Start from the smallest power of two >= tgSize
    // and skip partners past the end, so non-power-of-two threadgroup sizes
    // still fold every partial sum (starting at tgSize/2 would drop some).
    uint span = 1;
    while (span < tgSize) {
        span <<= 1;
    }
    for (uint stride = span >> 1; stride > 0; stride >>= 1) {
        if (tid < stride && tid + stride < tgSize) {
            partial_sums[tid] += partial_sums[tid + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Phase 3: scale by the reciprocal RMS and the optional per-element weight.
    float invRms = rsqrt(partial_sums[0] / float(N) + eps);
    for (uint i = tid; i < N; i += tgSize) {
        y[i] = x[i] * invRms * (w ? w[i] : 1.0);
    }
}
// Dot product of one 64-value, 4-bit quantized weight group with 64 floats held
// in threadgroup memory. `wq` points at the group's 8 packed uints (8 nibbles
// per uint, least-significant nibble first); a nibble q dequantizes to
// q * scale + bias. wq is read as two uint4 loads, so it must be 16-byte aligned.
static inline float moe_mega_q4_dot64(device const uint *wq,
                                      threadgroup const float *xs,
                                      float scale,
                                      float bias) {
    float acc = 0.0;
    for (uint p = 0; p < 8; p += 4) {
        uint4 packed = *((device const uint4 *)(wq + p));
        threadgroup const float *xp = xs + p * 8;
        // Word c of this uint4 covers inputs [p*8 + c*8, p*8 + c*8 + 8).
        for (uint c = 0; c < 4; c++) {
            uint word = packed[c];
            float4 wLo = float4(float(word & 0xF), float((word >> 4) & 0xF),
                                float((word >> 8) & 0xF), float((word >> 12) & 0xF)) * scale + bias;
            float4 wHi = float4(float((word >> 16) & 0xF), float((word >> 20) & 0xF),
                                float((word >> 24) & 0xF), float((word >> 28) & 0xF)) * scale + bias;
            threadgroup const float *xw = xp + c * 8;
            float4 xLo = float4(xw[0], xw[1], xw[2], xw[3]);
            float4 xHi = float4(xw[4], xw[5], xw[6], xw[7]);
            acc += dot(wLo, xLo) + dot(wHi, xHi);
        }
    }
    return acc;
}

// GELU(gate) * up for one intermediate row, with range guards:
// gate and up clamped to [-100, 100], the tanh-approximation GELU replaced by
// max(gate, 0) for |gate| > 10, the product clamped to [-10, 10], and a
// NaN/inf product replaced by 0.
static inline float moe_mega_gelu_gate_up(float gateSum, float upSum) {
    if (gateSum > 100.0) gateSum = 100.0;
    if (gateSum < -100.0) gateSum = -100.0;
    float absv = gateSum > 0 ? gateSum : -gateSum;
    float geluVal;
    if (absv > 10.0) {
        geluVal = gateSum > 0 ? gateSum : 0.0;
    } else {
        float v3 = gateSum * gateSum * gateSum;
        geluVal = 0.5 * gateSum * (1.0 + tanh(0.7978845608028654 * (gateSum + 0.044715 * v3)));
    }
    if (upSum > 100.0) upSum = 100.0;
    if (upSum < -100.0) upSum = -100.0;
    float product = geluVal * upSum;
    if (product > 10.0) product = 10.0;
    if (product < -10.0) product = -10.0;
    if (isnan(product) || isinf(product)) product = 0.0;
    return product;
}

// ── MoE Mega Kernel (CPU-free Router + All Experts) ─
// One dispatch replaces: router matmul + CPU softmax/top-K + one dispatch per
// selected expert, removing the per-token CPU syncs of the MoE path.
//
// For the current token x (hiddenSize floats) it computes
//   p      = softmax(routerScale * (W_router · x))                 (numExperts)
//   top    = the min(topK, numExperts) largest entries of p, with their
//            weights renormalized to sum to 1
//   accum += Σ_e weight_e * W_down[e] · (GELU(W_gate[e] · x) * (W_up[e] · x))
// All weights are 4-bit quantized (8 nibbles per uint, low nibble first) with
// one (scale, bias) pair per 64 inputs.
//
// Launch preconditions (not checked by the kernel):
//   * ONE threadgroup (so gid == tid) of at least
//     max(hiddenSize, moeIntermediate, numExperts) threads. Thread `gid` writes
//     intermediate row `gid` into threadgroup memory and every thread's down
//     projection reads all moeIntermediate rows back, so every row has to be
//     produced inside the same threadgroup.
//   * hiddenSize and moeIntermediate are multiples of 64.
//   * threadgroup(0) holds at least
//     max(hiddenSize + numExperts, moeIntermediate) + 2 * topK floats.
//
// Threadgroup memory layout (in floats):
//   [0, hiddenSize)                        x, reloaded before every expert
//   [0, moeIntermediate)                   GELU(gate) * up of the current expert,
//                                          written over x once all threads are
//                                          done reading it
//   [hiddenSize, hiddenSize + numExperts)  router logits / probabilities
//                                          (only used while routing)
//   [topKBase, topKBase + topK)            selected expert indices (as float)
//   [topKBase + topK, topKBase + 2*topK)   their renormalized weights
// where topKBase = max(hiddenSize + numExperts, moeIntermediate), so the
// intermediate activations can never overwrite the routing table.
kernel void moe_mega_kernel(
    device const float *x [[buffer(0)]],
    device const uint *w_router [[buffer(1)]],
    device const float *s_router [[buffer(2)]],
    device const float *b_router [[buffer(3)]],
    device const uint *w_gate [[buffer(4)]],
    device const float *s_gate [[buffer(5)]],
    device const float *b_gate [[buffer(6)]],
    device const uint *w_up [[buffer(7)]],
    device const float *s_up [[buffer(8)]],
    device const float *b_up [[buffer(9)]],
    device const uint *w_down [[buffer(10)]],
    device const float *s_down [[buffer(11)]],
    device const float *b_down [[buffer(12)]],
    device float *accum [[buffer(13)]],
    constant uint &hiddenSize [[buffer(14)]],
    constant uint &moeIntermediate [[buffer(15)]],
    constant uint &numExperts [[buffer(16)]],
    constant float &routerScale [[buffer(17)]],
    constant uint &topK [[buffer(18)]],
    threadgroup float *shared_space [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    uint numGroupsIn = hiddenSize / 64;       // scale groups per router/gate/up row
    uint numGroupsOut = moeIntermediate / 64; // scale groups per down row
    uint packedPerIn = hiddenSize / 8;        // packed uints per router/gate/up row
    uint packedPerOut = moeIntermediate / 8;  // packed uints per down row

    uint logitBase = hiddenSize;
    uint routingEnd = hiddenSize + numExperts;
    uint topKBase = routingEnd > moeIntermediate ? routingEnd : moeIntermediate;
    // Never select more experts than exist.
    uint kEff = topK < numExperts ? topK : numExperts;

    // ── Phase 0a: cooperative load of x ────────
    for (uint i = tid; i < hiddenSize; i += tgSize) {
        shared_space[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 0b: router logits, one expert per thread ─
    if (tid < numExperts) {
        uint rowW = tid * packedPerIn;
        uint rowS = tid * numGroupsIn;
        float logit = 0.0;
        for (uint g = 0; g < numGroupsIn; g++) {
            logit += moe_mega_q4_dot64(w_router + rowW + g * 8, shared_space + g * 64,
                                       s_router[rowS + g], b_router[rowS + g]);
        }
        shared_space[logitBase + tid] = logit * routerScale;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 0c: softmax + top-K on one thread ─
    // numExperts is small, so a serial pass is cheap next to the expert matmuls
    // and needs no scratch beyond the logits themselves. (A shared-memory tree
    // reduction would have to keep the logits intact and must not read slots
    // past numExperts that no thread wrote.)
    if (tid == 0) {
        float maxLogit = -FLT_MAX;
        for (uint i = 0; i < numExperts; i++) {
            maxLogit = fmax(maxLogit, shared_space[logitBase + i]);
        }
        float expSum = 0.0;
        for (uint i = 0; i < numExperts; i++) {
            float p = exp(shared_space[logitBase + i] - maxLogit);
            shared_space[logitBase + i] = p;
            expSum += p;
        }
        if (expSum <= 0) expSum = 1.0;
        for (uint i = 0; i < numExperts; i++) {
            shared_space[logitBase + i] /= expSum;
        }

        // Repeated arg-max: a picked probability is overwritten with -1 so it is
        // never picked again (probabilities are >= 0). Ties go to the lower index.
        float topKSum = 0.0;
        for (uint k = 0; k < kEff; k++) {
            uint best = 0;
            float bestVal = -1.0;
            for (uint j = 0; j < numExperts; j++) {
                float pj = shared_space[logitBase + j];
                if (pj > bestVal) {
                    best = j;
                    bestVal = pj;
                }
            }
            // bestVal stays -1 only if every remaining entry is NaN; such a
            // slot gets zero weight so it contributes nothing.
            float weight = bestVal > 0.0f ? bestVal : 0.0f;
            shared_space[logitBase + best] = -1.0;
            shared_space[topKBase + k] = float(best);
            shared_space[topKBase + topK + k] = weight;
            topKSum += weight;
        }
        if (topKSum <= 0) topKSum = 1.0;
        for (uint k = 0; k < kEff; k++) {
            shared_space[topKBase + topK + k] /= topKSum;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phases 1-2: per-expert FFN, accumulated into accum ─
    for (uint e = 0; e < kEff; e++) {
        uint expertIdx = uint(shared_space[topKBase + e]);
        float expertWeight = shared_space[topKBase + topK + e];

        // The previous expert's intermediate overwrote shared_space[0, moeIntermediate);
        // restore x before the gate/up projection reads it.
        for (uint i = tid; i < hiddenSize; i += tgSize) {
            shared_space[i] = x[i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Phase 1: gate/up projection for intermediate row `gid`.
        bool ownsIntermediateRow = gid < moeIntermediate;
        float product = 0.0;
        if (ownsIntermediateRow) {
            uint row = expertIdx * moeIntermediate + gid;
            uint rowW = row * packedPerIn;  // packed weights: packedPerIn uints per row
            uint rowS = row * numGroupsIn;  // scales/biases: numGroupsIn per row
            float gateSum = 0.0, upSum = 0.0;
            for (uint g = 0; g < numGroupsIn; g++) {
                threadgroup const float *xs = shared_space + g * 64;
                gateSum += moe_mega_q4_dot64(w_gate + rowW + g * 8, xs, s_gate[rowS + g], b_gate[rowS + g]);
                upSum += moe_mega_q4_dot64(w_up + rowW + g * 8, xs, s_up[rowS + g], b_up[rowS + g]);
            }
            product = moe_mega_gelu_gate_up(gateSum, upSum);
        }
        // All threads must finish reading x before any intermediate row is
        // stored over it; otherwise a fast SIMD-group could overwrite inputs a
        // slower one has not consumed yet.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (ownsIntermediateRow) {
            shared_space[gid] = product;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Phase 2: down projection for output row `gid`, weighted by the router.
        if (gid < hiddenSize) {
            uint row = expertIdx * hiddenSize + gid;
            uint rowW = row * packedPerOut;  // weight stride: packedPerOut uints per row
            uint rowS = row * numGroupsOut;  // scale stride: numGroupsOut per row
            float sum = 0.0;
            for (uint g = 0; g < numGroupsOut; g++) {
                sum += moe_mega_q4_dot64(w_down + rowW + g * 8, shared_space + g * 64,
                                         s_down[rowS + g], b_down[rowS + g]);
            }
            accum[gid] += expertWeight * sum;
        }
        // Phase 2 must finish reading the intermediate before the next reload of x.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
}
// ── Optimized Fused Gate+Up (Dense Models) ───────
// For every output row gid < outDim:
//   out[gid] = GELU(gate_gid · x) * (up_gid · x)
// with range guards: gate/up sums clamped to [-100, 100], the tanh-approximation
// GELU replaced by max(gate, 0) for |gate| > 10, the product clamped to
// [-10, 10] and a NaN product replaced by 0.
// Weights are 4-bit quantized, 8 nibbles per uint (least-significant first),
// with one (scale, bias) pair per `groupSize` inputs. x is staged once per
// threadgroup in threadgroup memory and weights are fetched with uint4 loads.
// Used by: 26B-Standard, 31B, E4B-MarkBase
//
// Preconditions (not checked): threadgroup(0) holds at least inDim floats;
// groupSize is a multiple of 32 that divides inDim (keeps the uint4 loads
// 16-byte aligned). Threads with gid >= outDim are allowed in the last
// threadgroup.
kernel void quantized_matmul_gate_up_opt(
    device const float *x [[buffer(0)]],
    device const uint *w_gate [[buffer(1)]],
    device const float *s_gate [[buffer(2)]],
    device const float *b_gate [[buffer(3)]],
    device const uint *w_up [[buffer(4)]],
    device const float *s_up [[buffer(5)]],
    device const float *b_up [[buffer(6)]],
    device float *out [[buffer(7)]],
    constant uint &inDim [[buffer(8)]],
    constant uint &outDim [[buffer(9)]],
    constant uint &groupSize [[buffer(10)]],
    threadgroup float *shared_x [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    // Stage x BEFORE the bounds check: threads past outDim in the last
    // threadgroup must still load their stripe of shared_x and reach the
    // barrier, or the remaining threads would read unloaded entries.
    for (uint i = tid; i < inDim; i += tgSize) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (gid >= outDim) return;

    uint numGroups = inDim / groupSize;
    uint packedPerOut = inDim / 8;        // packed uints per weight row
    uint packedPerGroup = groupSize / 8;  // packed uints per quantization group
    float gateSum = 0.0, upSum = 0.0;
    for (uint g = 0; g < numGroups; g++) {
        float gScale = s_gate[gid * numGroups + g];
        float gBias = b_gate[gid * numGroups + g];
        float uScale = s_up[gid * numGroups + g];
        float uBias = b_up[gid * numGroups + g];
        uint wBase = gid * packedPerOut + g * packedPerGroup;
        uint xBase = g * groupSize;
        // Each pass consumes one uint4 = 4 packed uints = 32 inputs; iterate over
        // the whole group rather than a hardcoded 64 inputs.
        for (uint p = 0; p < packedPerGroup; p += 4) {
            device const uint4 *gPtr = (device const uint4*)(&w_gate[wBase + p]);
            device const uint4 *uPtr = (device const uint4*)(&w_up[wBase + p]);
            uint4 gP = *gPtr;
            uint4 uP = *uPtr;
            float4 xv0 = float4(shared_x[xBase + p*8], shared_x[xBase + p*8 + 1], shared_x[xBase + p*8 + 2], shared_x[xBase + p*8 + 3]);
            float4 xv1 = float4(shared_x[xBase + p*8 + 4], shared_x[xBase + p*8 + 5], shared_x[xBase + p*8 + 6], shared_x[xBase + p*8 + 7]);
            float4 xv2 = float4(shared_x[xBase + p*8 + 8], shared_x[xBase + p*8 + 9], shared_x[xBase + p*8 + 10], shared_x[xBase + p*8 + 11]);
            float4 xv3 = float4(shared_x[xBase + p*8 + 12], shared_x[xBase + p*8 + 13], shared_x[xBase + p*8 + 14], shared_x[xBase + p*8 + 15]);
            float4 xv4 = float4(shared_x[xBase + p*8 + 16], shared_x[xBase + p*8 + 17], shared_x[xBase + p*8 + 18], shared_x[xBase + p*8 + 19]);
            float4 xv5 = float4(shared_x[xBase + p*8 + 20], shared_x[xBase + p*8 + 21], shared_x[xBase + p*8 + 22], shared_x[xBase + p*8 + 23]);
            float4 xv6 = float4(shared_x[xBase + p*8 + 24], shared_x[xBase + p*8 + 25], shared_x[xBase + p*8 + 26], shared_x[xBase + p*8 + 27]);
            float4 xv7 = float4(shared_x[xBase + p*8 + 28], shared_x[xBase + p*8 + 29], shared_x[xBase + p*8 + 30], shared_x[xBase + p*8 + 31]);
            float4 g0 = float4(float((gP.x >> 0) & 0xF) * gScale + gBias, float((gP.x >> 4) & 0xF) * gScale + gBias, float((gP.x >> 8) & 0xF) * gScale + gBias, float((gP.x >> 12) & 0xF) * gScale + gBias);
            float4 g1 = float4(float((gP.x >> 16) & 0xF) * gScale + gBias, float((gP.x >> 20) & 0xF) * gScale + gBias, float((gP.x >> 24) & 0xF) * gScale + gBias, float((gP.x >> 28) & 0xF) * gScale + gBias);
            float4 g2 = float4(float((gP.y >> 0) & 0xF) * gScale + gBias, float((gP.y >> 4) & 0xF) * gScale + gBias, float((gP.y >> 8) & 0xF) * gScale + gBias, float((gP.y >> 12) & 0xF) * gScale + gBias);
            float4 g3 = float4(float((gP.y >> 16) & 0xF) * gScale + gBias, float((gP.y >> 20) & 0xF) * gScale + gBias, float((gP.y >> 24) & 0xF) * gScale + gBias, float((gP.y >> 28) & 0xF) * gScale + gBias);
            float4 g4 = float4(float((gP.z >> 0) & 0xF) * gScale + gBias, float((gP.z >> 4) & 0xF) * gScale + gBias, float((gP.z >> 8) & 0xF) * gScale + gBias, float((gP.z >> 12) & 0xF) * gScale + gBias);
            float4 g5 = float4(float((gP.z >> 16) & 0xF) * gScale + gBias, float((gP.z >> 20) & 0xF) * gScale + gBias, float((gP.z >> 24) & 0xF) * gScale + gBias, float((gP.z >> 28) & 0xF) * gScale + gBias);
            float4 g6 = float4(float((gP.w >> 0) & 0xF) * gScale + gBias, float((gP.w >> 4) & 0xF) * gScale + gBias, float((gP.w >> 8) & 0xF) * gScale + gBias, float((gP.w >> 12) & 0xF) * gScale + gBias);
            float4 g7 = float4(float((gP.w >> 16) & 0xF) * gScale + gBias, float((gP.w >> 20) & 0xF) * gScale + gBias, float((gP.w >> 24) & 0xF) * gScale + gBias, float((gP.w >> 28) & 0xF) * gScale + gBias);
            float4 u0 = float4(float((uP.x >> 0) & 0xF) * uScale + uBias, float((uP.x >> 4) & 0xF) * uScale + uBias, float((uP.x >> 8) & 0xF) * uScale + uBias, float((uP.x >> 12) & 0xF) * uScale + uBias);
            float4 u1 = float4(float((uP.x >> 16) & 0xF) * uScale + uBias, float((uP.x >> 20) & 0xF) * uScale + uBias, float((uP.x >> 24) & 0xF) * uScale + uBias, float((uP.x >> 28) & 0xF) * uScale + uBias);
            float4 u2 = float4(float((uP.y >> 0) & 0xF) * uScale + uBias, float((uP.y >> 4) & 0xF) * uScale + uBias, float((uP.y >> 8) & 0xF) * uScale + uBias, float((uP.y >> 12) & 0xF) * uScale + uBias);
            float4 u3 = float4(float((uP.y >> 16) & 0xF) * uScale + uBias, float((uP.y >> 20) & 0xF) * uScale + uBias, float((uP.y >> 24) & 0xF) * uScale + uBias, float((uP.y >> 28) & 0xF) * uScale + uBias);
            float4 u4 = float4(float((uP.z >> 0) & 0xF) * uScale + uBias, float((uP.z >> 4) & 0xF) * uScale + uBias, float((uP.z >> 8) & 0xF) * uScale + uBias, float((uP.z >> 12) & 0xF) * uScale + uBias);
            float4 u5 = float4(float((uP.z >> 16) & 0xF) * uScale + uBias, float((uP.z >> 20) & 0xF) * uScale + uBias, float((uP.z >> 24) & 0xF) * uScale + uBias, float((uP.z >> 28) & 0xF) * uScale + uBias);
            float4 u6 = float4(float((uP.w >> 0) & 0xF) * uScale + uBias, float((uP.w >> 4) & 0xF) * uScale + uBias, float((uP.w >> 8) & 0xF) * uScale + uBias, float((uP.w >> 12) & 0xF) * uScale + uBias);
            float4 u7 = float4(float((uP.w >> 16) & 0xF) * uScale + uBias, float((uP.w >> 20) & 0xF) * uScale + uBias, float((uP.w >> 24) & 0xF) * uScale + uBias, float((uP.w >> 28) & 0xF) * uScale + uBias);
            gateSum += dot(g0, xv0) + dot(g1, xv1) + dot(g2, xv2) + dot(g3, xv3)
                     + dot(g4, xv4) + dot(g5, xv5) + dot(g6, xv6) + dot(g7, xv7);
            upSum += dot(u0, xv0) + dot(u1, xv1) + dot(u2, xv2) + dot(u3, xv3)
                   + dot(u4, xv4) + dot(u5, xv5) + dot(u6, xv6) + dot(u7, xv7);
        }
    }
    if (gateSum > 100.0) gateSum = 100.0;
    if (gateSum < -100.0) gateSum = -100.0;
    float v = gateSum;
    float absv = v > 0 ? v : -v;
    float gate;
    if (absv > 10.0) {
        gate = v > 0 ? v : 0.0;
    } else {
        float v3 = v * v * v;
        gate = 0.5 * v * (1.0 + tanh(0.7978845608028654 * (v + 0.044715 * v3)));
    }
    if (upSum > 100.0) upSum = 100.0;
    if (upSum < -100.0) upSum = -100.0;
    float product = gate * upSum;
    if (product > 10.0) product = 10.0;
    if (product < -10.0) product = -10.0;
    // The clamps above do not catch NaN (comparisons with NaN are false);
    // zero it, as the fused MoE gate/up paths do.
    if (isnan(product) || isinf(product)) product = 0.0;
    out[gid] = product;
}
// ── SIMD Elementwise Operations ─────────────────────
// out[i] = a[i] * b[i] for i < count.
// Each thread owns SIMD_WIDTH consecutive elements starting at id * SIMD_WIDTH,
// so ceil(count / SIMD_WIDTH) threads cover the vector. A full chunk is handled
// as one float4 (assumes SIMD_WIDTH == 4); a partial tail chunk is handled
// element by element so nothing at or beyond `count` is read or written.
kernel void eltwise_mul_simd(
    device const float *a,
    device const float *b,
    device float *out,
    constant uint &count,
    uint id [[thread_position_in_grid]]
) {
    uint idx = id * SIMD_WIDTH;
    if (idx >= count) return;
    // `count - idx` cannot underflow here and avoids overflow of idx + SIMD_WIDTH.
    if (count - idx >= SIMD_WIDTH) {
        float4 aVec = float4(a[idx], a[idx+1], a[idx+2], a[idx+3]);
        float4 bVec = float4(b[idx], b[idx+1], b[idx+2], b[idx+3]);
        float4 outVec = aVec * bVec;
        out[idx] = outVec.x;
        out[idx+1] = outVec.y;
        out[idx+2] = outVec.z;
        out[idx+3] = outVec.w;
        return;
    }
    // Partial tail chunk: only touch the elements that exist.
    for (uint i = idx; i < count; i++) {
        out[i] = a[i] * b[i];
    }
}
// out[i] = a[i] + b[i] for i < count.
// Each thread owns SIMD_WIDTH consecutive elements starting at id * SIMD_WIDTH,
// so ceil(count / SIMD_WIDTH) threads cover the vector. A full chunk is handled
// as one float4 (assumes SIMD_WIDTH == 4); a partial tail chunk is handled
// element by element so nothing at or beyond `count` is read or written.
kernel void eltwise_add_simd(
    device const float *a,
    device const float *b,
    device float *out,
    constant uint &count,
    uint id [[thread_position_in_grid]]
) {
    uint idx = id * SIMD_WIDTH;
    if (idx >= count) return;
    // `count - idx` cannot underflow here and avoids overflow of idx + SIMD_WIDTH.
    if (count - idx >= SIMD_WIDTH) {
        float4 aVec = float4(a[idx], a[idx+1], a[idx+2], a[idx+3]);
        float4 bVec = float4(b[idx], b[idx+1], b[idx+2], b[idx+3]);
        float4 outVec = aVec + bVec;
        out[idx] = outVec.x;
        out[idx+1] = outVec.y;
        out[idx+2] = outVec.z;
        out[idx+3] = outVec.w;
        return;
    }
    // Partial tail chunk: only touch the elements that exist.
    for (uint i = idx; i < count; i++) {
        out[i] = a[i] + b[i];
    }
}