8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
1228 lines
71 KiB
Metal
1228 lines
71 KiB
Metal
#include <metal_stdlib>
|
||
using namespace metal;
|
||
|
||
// ═══════════════════════════════════════════════
|
||
// SIMD Optimized Kernels - Phase 1
|
||
// ═══════════════════════════════════════════════
|
||
|
||
// Number of floats handled per vector operation: the attention kernels in this
// file split every head into headDim / SIMD_WIDTH float4 blocks.
constant uint SIMD_WIDTH = 4;
|
||
// NOTE(review): not referenced by the kernels visible in this part of the file;
// presumably the default per-head dimension — confirm before relying on it.
constant uint HEAD_DIM = 128;
|
||
// NOTE(review): only mentioned in the shared_k / shared_v layout comments of the
// sliding attention kernel; presumably the largest supported window — confirm.
constant uint MAX_WINDOW = 4096;
|
||
|
||
// Full-headDim dot product between one query head (device memory) and one
// cached key row (threadgroup memory), accumulated one float4 block at a time.
static inline float sliding_attention_qk_dot(
    device const float *qHead,
    threadgroup const float4 *kRow,
    uint blocksPerHead
) {
    float acc = 0.0f;
    for (uint db = 0; db < blocksPerHead; db++) {
        const uint qo = db * SIMD_WIDTH;
        const float4 qVec = float4(qHead[qo], qHead[qo + 1], qHead[qo + 2], qHead[qo + 3]);
        acc += dot(qVec, kRow[db]);
    }
    return acc;
}

// ── SIMD Optimized Sliding Attention ───────────────
// Single-query attention over the most recent min(offset + 1, windowSize)
// positions of a K/V cache that is addressed as a ring buffer
// (slot = position % windowSize).
//
// Dispatch layout: 2-D grid, gid.x = query head, gid.y = float4 block inside
// the head (headDim / SIMD_WIDTH blocks per head). Each thread writes the four
// output floats of its (head, block) pair.
//
// Buffers (layouts as indexed by this kernel):
//   q    [nHeads, headDim]            query of the current position
//   k,v  [slot, nKvHeads, headDim]    cache
//   out  [nHeads, headDim]
//   shared_k / shared_v               threadgroup staging; each must hold
//        actualWindow * nKvHeads * (headDim / SIMD_WIDTH) float4 values
//
// Fixes relative to the previous version:
//   * The attention score is the dot product over the WHOLE head. Previously
//     every thread ran softmax on the partial dot product of its own four
//     lanes, so each float4 block used different (incorrect) weights.
//   * Threads outside the valid range no longer return before the cooperative
//     load and the barrier; they help with the load and exit afterwards.
//   * An empty window writes zeros instead of dividing by zero.
//
// NOTE(review): assumes headDim is a multiple of SIMD_WIDTH; trailing
// dimensions would otherwise be ignored — confirm against the host code.
// NOTE(review): the KV head is selected as head % nKvHeads, unchanged from the
// previous version; confirm this matches the model's grouped-query layout.
// NOTE(review): the staging buffers grow with window * nKvHeads * headDim and
// threadgroup memory is small — confirm the host never exceeds the device limit.
kernel void sliding_attention_simd(
    device const float *q           [[buffer(0)]],
    device const float *k           [[buffer(1)]],
    device const float *v           [[buffer(2)]],
    device float *out               [[buffer(3)]],
    constant uint &nHeads           [[buffer(4)]],
    constant uint &nKvHeads         [[buffer(5)]],
    constant uint &headDim          [[buffer(6)]],
    constant uint &windowSize       [[buffer(7)]],
    constant int &offset            [[buffer(8)]],
    threadgroup float4 *shared_k    [[threadgroup(0)]],
    threadgroup float4 *shared_v    [[threadgroup(1)]],
    uint2 gid                       [[thread_position_in_grid]],
    uint2 tid                       [[thread_position_in_threadgroup]],
    uint2 tgSize                    [[threads_per_threadgroup]]
) {
    const uint head = gid.x;
    const uint dimBlock = gid.y; // index of the float4 block inside the head
    const uint blocksPerHead = headDim / SIMD_WIDTH;
    const uint blocksPerStep = nKvHeads * blocksPerHead; // float4 blocks per cached position

    // Out-of-range threads must still take part in the load and the barrier.
    const bool active = (head < nHeads) && (dimBlock < blocksPerHead);

    const uint seqLen = uint(offset + 1);
    const uint actualWindow = min(seqLen, windowSize);
    const int base = int(offset) - int(actualWindow) + 1; // oldest position in the window

    // ── Threadgroup Cache Loading ───────────────────
    // Every thread of the threadgroup copies a strided share of the window.
    const uint loadStride = tgSize.x * tgSize.y;
    const uint totalElements = actualWindow * blocksPerStep;

    for (uint idx = tid.y * tgSize.x + tid.x; idx < totalElements; idx += loadStride) {
        const uint t = idx / blocksPerStep;
        const uint h = (idx % blocksPerStep) / blocksPerHead;
        const uint db = idx % blocksPerHead;

        const int logicalPos = base + int(t);
        const uint cacheIdx = logicalPos >= 0 ? uint(logicalPos) % windowSize : 0;

        // K and V are indexed identically, so one offset serves both.
        const uint src = (cacheIdx * nKvHeads + h) * headDim + db * SIMD_WIDTH;

        shared_k[idx] = float4(k[src], k[src + 1], k[src + 2], k[src + 3]);
        shared_v[idx] = float4(v[src], v[src + 1], v[src + 2], v[src + 3]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (!active) return;

    // ── Scores ──────────────────────────────────────
    const uint kvHead = head % nKvHeads;
    const float scale = 1.0f / sqrt(float(headDim));
    device const float *qHead = q + head * headDim;
    const uint headBase = kvHead * blocksPerHead;

    // Pass 1: find max score (for a numerically stable softmax)
    // Text model has NO attention softcapping (same as original kernel)
    float maxScore = -INFINITY;
    for (uint t = 0; t < actualWindow; t++) {
        const uint rowBase = t * blocksPerStep + headBase;
        const float score = sliding_attention_qk_dot(qHead, shared_k + rowBase, blocksPerHead) * scale;
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum of this thread's float4 block of V
    float sumExp = 0.0f;
    float4 resultVec = float4(0.0f);

    for (uint t = 0; t < actualWindow; t++) {
        const uint rowBase = t * blocksPerStep + headBase;
        const float score = sliding_attention_qk_dot(qHead, shared_k + rowBase, blocksPerHead) * scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        resultVec += expVal * shared_v[rowBase + dimBlock];
    }

    // An empty window leaves sumExp at zero; emit zeros rather than NaN.
    if (sumExp > 0.0f) {
        resultVec /= sumExp;
    }

    // Write output
    const uint outOffset = head * headDim + dimBlock * SIMD_WIDTH;
    out[outOffset]     = resultVec.x;
    out[outOffset + 1] = resultVec.y;
    out[outOffset + 2] = resultVec.z;
    out[outOffset + 3] = resultVec.w;
}
|
||
|
||
// Full-headDim dot product between one query head (device memory) and one
// cached key row (threadgroup memory), accumulated one float4 block at a time.
static inline float full_attention_qk_dot(
    device const float *qHead,
    threadgroup const float4 *kRow,
    uint blocksPerHead
) {
    float acc = 0.0f;
    for (uint db = 0; db < blocksPerHead; db++) {
        const uint qo = db * SIMD_WIDTH;
        const float4 qVec = float4(qHead[qo], qHead[qo + 1], qHead[qo + 2], qHead[qo + 3]);
        acc += dot(qVec, kRow[db]);
    }
    return acc;
}

// ── SIMD Optimized Full Attention ──────────────────
// Single-query attention over all seqLen cached positions.
//
// Dispatch layout: 2-D grid, gid.x = query head, gid.y = float4 block inside
// the head (headDim / SIMD_WIDTH blocks per head). Each thread writes the four
// output floats of its (head, block) pair.
//
// Buffers (layouts as indexed by this kernel):
//   q    [nHeads, headDim]
//   k,v  [seqLen, nKvHeads, headDim]
//   out  [nHeads, headDim]
//   shared_k / shared_v   threadgroup staging; each must hold
//        seqLen * nKvHeads * (headDim / SIMD_WIDTH) float4 values
//
// Fixes relative to the previous version:
//   * The attention score is the dot product over the WHOLE head. Previously
//     every thread ran softmax on the partial dot product of its own four
//     lanes, so each float4 block used different (incorrect) weights.
//   * Threads outside the valid range no longer return before the cooperative
//     load and the barrier; they help with the load and exit afterwards.
//   * seqLen == 0 writes zeros instead of dividing by zero.
//
// NOTE(review): assumes headDim is a multiple of SIMD_WIDTH — confirm.
// NOTE(review): the KV head is selected as head % nKvHeads, unchanged from the
// previous version; confirm this matches the model's grouped-query layout.
// NOTE(review): the staging buffers grow with seqLen * nKvHeads * headDim and
// threadgroup memory is small — confirm the host never exceeds the device limit.
kernel void full_attention_simd(
    device const float *q           [[buffer(0)]],
    device const float *k           [[buffer(1)]],
    device const float *v           [[buffer(2)]],
    device float *out               [[buffer(3)]],
    constant uint &nHeads           [[buffer(4)]],
    constant uint &nKvHeads         [[buffer(5)]],
    constant uint &headDim          [[buffer(6)]],
    constant uint &seqLen           [[buffer(7)]],
    threadgroup float4 *shared_k    [[threadgroup(0)]],
    threadgroup float4 *shared_v    [[threadgroup(1)]],
    uint2 gid                       [[thread_position_in_grid]],
    uint2 tid                       [[thread_position_in_threadgroup]],
    uint2 tgSize                    [[threads_per_threadgroup]]
) {
    const uint head = gid.x;
    const uint dimBlock = gid.y; // index of the float4 block inside the head
    const uint blocksPerHead = headDim / SIMD_WIDTH;
    const uint blocksPerStep = nKvHeads * blocksPerHead; // float4 blocks per cached position

    // Out-of-range threads must still take part in the load and the barrier.
    const bool active = (head < nHeads) && (dimBlock < blocksPerHead);

    // ── Threadgroup cache loading ───────────────────
    const uint loadStride = tgSize.x * tgSize.y;
    const uint totalElements = seqLen * blocksPerStep;

    for (uint idx = tid.y * tgSize.x + tid.x; idx < totalElements; idx += loadStride) {
        const uint t = idx / blocksPerStep;
        const uint h = (idx % blocksPerStep) / blocksPerHead;
        const uint db = idx % blocksPerHead;

        // K and V are indexed identically, so one offset serves both.
        const uint src = (t * nKvHeads + h) * headDim + db * SIMD_WIDTH;

        shared_k[idx] = float4(k[src], k[src + 1], k[src + 2], k[src + 3]);
        shared_v[idx] = float4(v[src], v[src + 1], v[src + 2], v[src + 3]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (!active) return;

    const uint kvHead = head % nKvHeads;
    const float scale = 1.0f / sqrt(float(headDim));
    device const float *qHead = q + head * headDim;
    const uint headBase = kvHead * blocksPerHead;

    // Pass 1: max score (for a numerically stable softmax)
    // Text model has NO attention softcapping
    float maxScore = -INFINITY;
    for (uint t = 0; t < seqLen; t++) {
        const uint rowBase = t * blocksPerStep + headBase;
        const float score = full_attention_qk_dot(qHead, shared_k + rowBase, blocksPerHead) * scale;
        maxScore = max(maxScore, score);
    }

    // Pass 2: softmax + weighted sum of this thread's float4 block of V
    float sumExp = 0.0f;
    float4 resultVec = float4(0.0f);

    for (uint t = 0; t < seqLen; t++) {
        const uint rowBase = t * blocksPerStep + headBase;
        const float score = full_attention_qk_dot(qHead, shared_k + rowBase, blocksPerHead) * scale;
        const float expVal = exp(score - maxScore);
        sumExp += expVal;
        resultVec += expVal * shared_v[rowBase + dimBlock];
    }

    // seqLen == 0 leaves sumExp at zero; emit zeros rather than NaN.
    if (sumExp > 0.0f) {
        resultVec /= sumExp;
    }

    const uint outOffset = head * headDim + dimBlock * SIMD_WIDTH;
    out[outOffset]     = resultVec.x;
    out[outOffset + 1] = resultVec.y;
    out[outOffset + 2] = resultVec.z;
    out[outOffset + 3] = resultVec.w;
}
|
||
|
||
// ── Block-based Quantized Matmul (High Performance) ────
// Matrix-vector product out = dequant(w) · x for 4-bit group-quantized weights.
//
// Dispatch layout: 2-D grid. gid.x selects the output row; the tgSize.y threads
// that share a row (tid.y) each accumulate a contiguous range of quantization
// groups, and the thread with tid.y == 0 adds the partial sums and writes the
// result. A threadgroup may hold several rows (tgSize.x > 1).
//
// Buffers (layouts as indexed by this kernel):
//   x    [inDim]                        input vector
//   w    [outDim, inDim / 8]            eight 4-bit values per uint32, lowest nibble first
//   s,b  [outDim, inDim / groupSize]    per-group scale and bias (value * scale + bias)
//   out  [outDim]
//   partial_sums   threadgroup scratch, tgSize.x * tgSize.y floats
//   shared_x       threadgroup copy of x, inDim floats
//
// Fixes relative to the previous version:
//   * Threads whose row is out of range no longer return before the cooperative
//     load and the barriers; they help load x and contribute a zero partial sum.
//   * The slice of a row handled by a thread is derived from tid.y (position
//     inside the threadgroup) rather than gid.y, matching numThreadsPerRow.
//   * Partial sums are stored per (row, slice). Previously all rows of one
//     threadgroup shared partial_sums[tid.y], which raced when tgSize.x > 1.
//   * The inner loop covers groupSize / 8 packed words instead of a fixed 8,
//     so group sizes other than 64 are handled.
//
// NOTE(review): assumes inDim is a multiple of groupSize and groupSize is a
// multiple of 16 (two packed words per iteration) — confirm with the host code.
// NOTE(review): partial_sums must now hold tgSize.x * tgSize.y floats; this is
// unchanged when tgSize.x == 1 — confirm the host-side allocation.
kernel void quantized_matmul_block(
    device const float *x           [[buffer(0)]],
    device const uint *w            [[buffer(1)]],
    device const float *s           [[buffer(2)]],
    device const float *b           [[buffer(3)]],
    device float *out               [[buffer(4)]],
    constant uint &inDim            [[buffer(5)]],
    constant uint &outDim           [[buffer(6)]],
    constant uint &groupSize        [[buffer(7)]],  // 64
    threadgroup float *partial_sums [[threadgroup(0)]],  // For reduction
    threadgroup float *shared_x     [[threadgroup(1)]],  // Input vector cache
    uint2 gid                       [[thread_position_in_grid]],
    uint2 tid                       [[thread_position_in_threadgroup]],
    uint2 tgSize                    [[threads_per_threadgroup]]
) {
    const uint outRow = gid.x;             // Output row index
    const uint threadInRow = tid.y;        // Slice index within the row (0..tgSize.y-1)
    const uint numThreadsPerRow = tgSize.y;
    const bool active = outRow < outDim;

    const uint numGroups = inDim / groupSize;
    const uint wordsPerGroup = groupSize / 8; // packed uint32 words per group

    // ── Cooperative loading of input vector ──────────────────
    const uint threadsInGroup = tgSize.x * tgSize.y;
    for (uint i = tid.y * tgSize.x + tid.x; i < inDim; i += threadsInGroup) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Compute partial dot product for this thread ────────────────
    float localSum = 0.0f;

    if (active) {
        // Contiguous range of groups per thread; the last thread takes the remainder.
        const uint groupsPerThread = numGroups / numThreadsPerRow;
        const uint startGroup = threadInRow * groupsPerThread;
        const uint endGroup = (threadInRow == numThreadsPerRow - 1)
            ? numGroups
            : startGroup + groupsPerThread;

        for (uint g = startGroup; g < endGroup; g++) {
            const float scale = s[outRow * numGroups + g];
            const float bias = b[outRow * numGroups + g];
            const uint packedBase = outRow * (inDim / 8) + g * wordsPerGroup;

            // Process 2 packed uint32 at a time (16 nibbles)
            for (uint p = 0; p < wordsPerGroup; p += 2) {
                for (uint wi = 0; wi < 2; wi++) {
                    const uint word = w[packedBase + p + wi];
                    const uint xo = g * groupSize + (p + wi) * 8;

                    const float4 qLo = float4(
                        float((word >> 0) & 0xF), float((word >> 4) & 0xF),
                        float((word >> 8) & 0xF), float((word >> 12) & 0xF)
                    ) * scale + bias;
                    const float4 qHi = float4(
                        float((word >> 16) & 0xF), float((word >> 20) & 0xF),
                        float((word >> 24) & 0xF), float((word >> 28) & 0xF)
                    ) * scale + bias;

                    const float4 xLo = float4(shared_x[xo + 0], shared_x[xo + 1],
                                              shared_x[xo + 2], shared_x[xo + 3]);
                    const float4 xHi = float4(shared_x[xo + 4], shared_x[xo + 5],
                                              shared_x[xo + 6], shared_x[xo + 7]);

                    localSum += dot(qLo, xLo);
                    localSum += dot(qHi, xHi);
                }
            }
        }
    }

    // ── Reduction within threadgroup ───────────────────
    // One slot per (row, slice) so rows sharing a threadgroup do not collide.
    const uint rowSlot = tid.x * numThreadsPerRow;
    partial_sums[rowSlot + tid.y] = localSum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Reduce to final sum (one thread per row writes the result)
    if (active && tid.y == 0) {
        float finalSum = 0.0f;
        for (uint t = 0; t < numThreadsPerRow; t++) {
            finalSum += partial_sums[rowSlot + t];
        }
        out[outRow] = finalSum;
    }
}
|
||
|
||
// ── SIMD Optimized Quantized Matmul (Legacy) ───────────────
// Kept for backward compatibility
//
// Matrix-vector product out = dequant(w) · x for 4-bit group-quantized weights.
// 1-D dispatch: one thread per output row (gid). All threads of a threadgroup
// first copy x into threadgroup memory, then each walks its own row.
//
// Buffers (layouts as indexed by this kernel):
//   x    [inDim]
//   w    [outDim, inDim / 8]            eight 4-bit values per uint32, lowest nibble first
//   s,b  [outDim, inDim / groupSize]    per-group scale and bias (value * scale + bias)
//   out  [outDim]
//   shared_x   threadgroup copy of x, inDim floats
//
// Fixes relative to the previous version:
//   * Threads whose row is out of range no longer return before the cooperative
//     load and the barrier. Previously the part of x assigned to those threads
//     was never copied in the last threadgroup, and the barrier was divergent.
//   * The packed weights are read through a const-qualified vector pointer.
//   * The inner loop covers groupSize / 8 packed words instead of a fixed 8,
//     so group sizes other than 64 are handled.
//
// NOTE(review): the uint4 load assumes 16-byte alignment, i.e. inDim / 8 and
// groupSize / 8 are multiples of 4 and the buffer base is aligned — confirm.
kernel void quantized_matmul_simd(
    device const float *x           [[buffer(0)]],
    device const uint *w            [[buffer(1)]],
    device const float *s           [[buffer(2)]],
    device const float *b           [[buffer(3)]],
    device float *out               [[buffer(4)]],
    constant uint &inDim            [[buffer(5)]],
    constant uint &outDim           [[buffer(6)]],
    constant uint &groupSize        [[buffer(7)]],  // 64
    threadgroup float *shared_x     [[threadgroup(0)]],
    uint gid                        [[thread_position_in_grid]],
    uint tid                        [[thread_position_in_threadgroup]],
    uint tgSize                     [[threads_per_threadgroup]]
) {
    const uint outRow = gid;
    const bool active = outRow < outDim;

    // ── Cooperative input load ────────────────────────────
    for (uint i = tid; i < inDim; i += tgSize) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (!active) return;

    // ── Compute dot product ────────────────────────────────
    const uint numGroups = inDim / groupSize;
    const uint wordsPerGroup = groupSize / 8; // packed uint32 words per group
    float sum = 0.0f;

    for (uint g = 0; g < numGroups; g++) {
        const float scale = s[outRow * numGroups + g];
        const float bias = b[outRow * numGroups + g];

        const uint packedBase = outRow * (inDim / 8) + g * wordsPerGroup;
        const uint xBase = g * groupSize;

        // Process 4 uint32 per iteration (32 nibbles)
        for (uint p = 0; p < wordsPerGroup; p += 4) {
            // Vectorized uint4 load (reduces load instructions)
            const uint4 packed = *(device const uint4 *)(&w[packedBase + p]);

            // Accumulate the 8 float4 dot products of this chunk, then add once.
            float chunk = 0.0f;
            for (uint wi = 0; wi < 4; wi++) {
                const uint word = packed[wi];
                const uint xo = xBase + (p + wi) * 8;

                const float4 qLo = float4(
                    float((word >> 0) & 0xF), float((word >> 4) & 0xF),
                    float((word >> 8) & 0xF), float((word >> 12) & 0xF)
                ) * scale + bias;
                const float4 qHi = float4(
                    float((word >> 16) & 0xF), float((word >> 20) & 0xF),
                    float((word >> 24) & 0xF), float((word >> 28) & 0xF)
                ) * scale + bias;

                const float4 xLo = float4(shared_x[xo + 0], shared_x[xo + 1],
                                          shared_x[xo + 2], shared_x[xo + 3]);
                const float4 xHi = float4(shared_x[xo + 4], shared_x[xo + 5],
                                          shared_x[xo + 6], shared_x[xo + 7]);

                chunk += dot(qLo, xLo);
                chunk += dot(qHi, xHi);
            }
            sum += chunk;
        }
    }

    out[outRow] = sum;
}
|
||
|
||
// ── 8-bit SIMD Quantized Matmul ────────────────────
// Same as quantized_matmul_simd but for 8-bit weights (4 values per uint32, mask 0xFF)
//
// 1-D dispatch: one thread per output row (gid). All threads of a threadgroup
// first copy x into threadgroup memory, then each walks its own row.
//
// Buffers (layouts as indexed by this kernel):
//   x    [inDim]
//   w    [outDim, inDim / 4]            four 8-bit values per uint32, lowest byte first
//   s,b  [outDim, inDim / groupSize]    per-group scale and bias (value * scale + bias)
//   out  [outDim]
//   shared_x   threadgroup copy of x, inDim floats
//
// Fixes relative to the previous version:
//   * Threads whose row is out of range no longer return before the cooperative
//     load and the barrier. Previously the part of x assigned to those threads
//     was never copied in the last threadgroup, and the barrier was divergent.
//   * The packed weights are read through a const-qualified vector pointer.
//
// NOTE(review): the uint4 load assumes 16-byte alignment, i.e. inDim / 4 and
// groupSize / 4 are multiples of 4 and the buffer base is aligned — confirm.
kernel void quantized_matmul_simd_8bit(
    device const float *x           [[buffer(0)]],
    device const uint *w            [[buffer(1)]],
    device const float *s           [[buffer(2)]],
    device const float *b           [[buffer(3)]],
    device float *out               [[buffer(4)]],
    constant uint &inDim            [[buffer(5)]],
    constant uint &outDim           [[buffer(6)]],
    constant uint &groupSize        [[buffer(7)]],  // 64
    threadgroup float *shared_x     [[threadgroup(0)]],
    uint gid                        [[thread_position_in_grid]],
    uint tid                        [[thread_position_in_threadgroup]],
    uint tgSize                     [[threads_per_threadgroup]]
) {
    const uint outRow = gid;
    const bool active = outRow < outDim;

    // Cooperative input load: every thread of the threadgroup participates.
    for (uint i = tid; i < inDim; i += tgSize) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (!active) return;

    const uint numGroups = inDim / groupSize;
    const uint wordsPerGroup = groupSize / 4; // packed uint32 words per group
    float sum = 0.0f;

    for (uint g = 0; g < numGroups; g++) {
        const float scale = s[outRow * numGroups + g];
        const float bias = b[outRow * numGroups + g];

        const uint packedBase = outRow * (inDim / 4) + g * wordsPerGroup;
        const uint xBase = g * groupSize;

        // Process 4 uint32 per iteration (16 × 8-bit values)
        for (uint p = 0; p < wordsPerGroup; p += 4) {
            const uint4 packed = *(device const uint4 *)(&w[packedBase + p]);

            // Accumulate the 4 float4 dot products of this chunk, then add once.
            float chunk = 0.0f;
            for (uint wi = 0; wi < 4; wi++) {
                const uint word = packed[wi];
                const uint xo = xBase + (p + wi) * 4;

                const float4 qVec = float4(
                    float((word >> 0) & 0xFF), float((word >> 8) & 0xFF),
                    float((word >> 16) & 0xFF), float((word >> 24) & 0xFF)
                ) * scale + bias;
                const float4 xVec = float4(shared_x[xo + 0], shared_x[xo + 1],
                                           shared_x[xo + 2], shared_x[xo + 3]);

                chunk += dot(qVec, xVec);
            }
            sum += chunk;
        }
    }

    out[outRow] = sum;
}
|
||
|
||
// Unpacks four consecutive 4-bit values of `word`, starting at bit `shift`, and
// applies the affine dequantization value * scale + bias to each of them.
static inline float4 moe_gud4_dequant_nibbles(uint word, uint shift, float scale, float bias) {
    return float4(
        float((word >> shift) & 0xF),
        float((word >> (shift + 4)) & 0xF),
        float((word >> (shift + 8)) & 0xF),
        float((word >> (shift + 12)) & 0xF)
    ) * scale + bias;
}

// ── Fused Gate+Up+Down for MoE Experts (4-bit) ────
// Single kernel replaces: fusedGateUp + blit + downMatmul + scaledAdd
// Phase 1: compute gelu(gate(x)) * up(x) → threadgroup intermediate
// Phase 2: compute down(intermediate) → accum += expertWeight * result
//
// 1-D dispatch. Thread gid produces intermediate row gid in phase 1 (when
// gid < moeIntermediate) and output row gid in phase 2 (when gid < hiddenSize).
//
// Buffers (layouts as indexed by this kernel):
//   x                    [hiddenSize]
//   w_gate, w_up         [moeIntermediate, hiddenSize / 8]        packed 4-bit
//   s_gate/b_gate, s_up/b_up   [moeIntermediate, hiddenSize / groupSize]
//   w_down               [hiddenSize, moeIntermediate / 8]        packed 4-bit
//   s_down/b_down        [hiddenSize, moeIntermediate / groupSize]
//   accum                [hiddenSize], updated in place
//   shared_space         threadgroup scratch; holds x first, then the
//                        intermediate, so it needs max(hiddenSize, moeIntermediate) floats
//
// The intermediate is stored at shared_space[gid] and phase 2 reads every
// intermediate row from threadgroup memory, so results are only complete when
// a single threadgroup covers all rows (gid == tid).
// NOTE(review): confirm the host dispatches exactly one threadgroup.
//
// Fixes relative to the previous version:
//   * The phase-1 result is kept in a register until every thread has finished
//     reading x. Previously it was written straight into shared_space[gid]
//     while other threads could still be reading the same slot as x[gid].
//   * The packed weights are read through const-qualified vector pointers.
//   * The inner loops cover groupSize / 8 packed words instead of a fixed 8,
//     so group sizes other than 64 are handled.
//
// NOTE(review): the uint4 loads assume 16-byte alignment, i.e. the packed row
// lengths and groupSize / 8 are multiples of 4 — confirm with the host code.
kernel void quantized_matmul_gate_up_down(
    device const float *x           [[buffer(0)]],
    device const uint *w_gate       [[buffer(1)]],
    device const float *s_gate      [[buffer(2)]],
    device const float *b_gate      [[buffer(3)]],
    device const uint *w_up         [[buffer(4)]],
    device const float *s_up        [[buffer(5)]],
    device const float *b_up        [[buffer(6)]],
    device const uint *w_down       [[buffer(7)]],
    device const float *s_down      [[buffer(8)]],
    device const float *b_down      [[buffer(9)]],
    device float *accum             [[buffer(10)]],
    constant uint &hiddenSize       [[buffer(11)]],
    constant uint &moeIntermediate  [[buffer(12)]],
    constant uint &groupSize        [[buffer(13)]],
    constant float &expertWeight    [[buffer(14)]],
    threadgroup float *shared_space [[threadgroup(0)]],
    uint gid                        [[thread_position_in_grid]],
    uint tid                        [[thread_position_in_threadgroup]],
    uint tgSize                     [[threads_per_threadgroup]]
) {
    // ── Cooperative input load ────────────────────
    for (uint i = tid; i < hiddenSize; i += tgSize) {
        shared_space[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint numGroupsIn = hiddenSize / groupSize;
    const uint numGroupsOut = moeIntermediate / groupSize;
    const uint packedPerIn = hiddenSize / 8;
    const uint packedPerOut = moeIntermediate / 8;
    const uint wordsPerGroup = groupSize / 8; // packed uint32 words per group

    // ── Phase 1: gate(x) + up(x) ─────────────────
    const bool hasIntermediateRow = gid < moeIntermediate;
    float product = 0.0f;

    if (hasIntermediateRow) {
        float gateSum = 0.0f;
        float upSum = 0.0f;

        for (uint g = 0; g < numGroupsIn; g++) {
            const float gScale = s_gate[gid * numGroupsIn + g];
            const float gBias = b_gate[gid * numGroupsIn + g];
            const float uScale = s_up[gid * numGroupsIn + g];
            const float uBias = b_up[gid * numGroupsIn + g];

            const uint wBase = gid * packedPerIn + g * wordsPerGroup;
            const uint xBase = g * groupSize;

            for (uint p = 0; p < wordsPerGroup; p += 4) {
                const uint4 gP = *(device const uint4 *)(&w_gate[wBase + p]);
                const uint4 uP = *(device const uint4 *)(&w_up[wBase + p]);

                // Accumulate the 8 float4 dot products of this chunk, then add once.
                float gChunk = 0.0f;
                float uChunk = 0.0f;
                for (uint wi = 0; wi < 4; wi++) {
                    threadgroup const float *xs = shared_space + xBase + (p + wi) * 8;
                    const float4 xLo = float4(xs[0], xs[1], xs[2], xs[3]);
                    const float4 xHi = float4(xs[4], xs[5], xs[6], xs[7]);

                    gChunk += dot(moe_gud4_dequant_nibbles(gP[wi], 0, gScale, gBias), xLo);
                    gChunk += dot(moe_gud4_dequant_nibbles(gP[wi], 16, gScale, gBias), xHi);
                    uChunk += dot(moe_gud4_dequant_nibbles(uP[wi], 0, uScale, uBias), xLo);
                    uChunk += dot(moe_gud4_dequant_nibbles(uP[wi], 16, uScale, uBias), xHi);
                }
                gateSum += gChunk;
                upSum += uChunk;
            }
        }

        // GELU activation (same formula as existing quantized_matmul_gate_up)
        if (gateSum > 100.0f) gateSum = 100.0f;
        if (gateSum < -100.0f) gateSum = -100.0f;
        const float gateVal = gateSum;
        const float absGate = gateVal > 0 ? gateVal : -gateVal;
        float geluVal;
        if (absGate > 10.0f) {
            // Far from zero GELU is numerically the identity / zero.
            geluVal = gateVal > 0 ? gateVal : 0.0f;
        } else {
            const float cube = gateVal * gateVal * gateVal;
            geluVal = 0.5f * gateVal * (1.0f + tanh(0.7978845608028654f * (gateVal + 0.044715f * cube)));
        }

        // Clamp upSum
        if (upSum > 100.0f) upSum = 100.0f;
        if (upSum < -100.0f) upSum = -100.0f;

        product = geluVal * upSum;
        if (product > 10.0f) product = 10.0f;
        if (product < -10.0f) product = -10.0f;
        if (isnan(product) || isinf(product)) product = 0.0f;
    }

    // Every thread must be done reading x before the buffer is reused.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (hasIntermediateRow) {
        shared_space[gid] = product;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 2: down(intermediate) + accumulate ─
    if (gid < hiddenSize) {
        float sum = 0.0f;

        for (uint g = 0; g < numGroupsOut; g++) {
            const float scale = s_down[gid * numGroupsOut + g];
            const float bias = b_down[gid * numGroupsOut + g];

            const uint wBase = gid * packedPerOut + g * wordsPerGroup;
            const uint iBase = g * groupSize;

            for (uint p = 0; p < wordsPerGroup; p += 4) {
                const uint4 packed = *(device const uint4 *)(&w_down[wBase + p]);

                float chunk = 0.0f;
                for (uint wi = 0; wi < 4; wi++) {
                    threadgroup const float *is = shared_space + iBase + (p + wi) * 8;
                    const float4 iLo = float4(is[0], is[1], is[2], is[3]);
                    const float4 iHi = float4(is[4], is[5], is[6], is[7]);

                    chunk += dot(moe_gud4_dequant_nibbles(packed[wi], 0, scale, bias), iLo);
                    chunk += dot(moe_gud4_dequant_nibbles(packed[wi], 16, scale, bias), iHi);
                }
                sum += chunk;
            }
        }

        accum[gid] += expertWeight * sum;
    }
}
|
||
|
||
// ── Fused Gate+Up+Down for MoE Experts (8-bit) ────
// One expert's feed-forward pass for a single token, with 8-bit affine-quantized
// weights (four weights per uint, one scale/bias pair per `groupSize` weights of a row):
//   intermediate[r] = clamp(gelu_tanh(W_gate[r]·x) * (W_up[r]·x), -10, 10)
//   accum[c]       += expertWeight * (W_down[c]·intermediate)
//
// Threadgroup memory: shared_space first caches x[0..hiddenSize-1] and is then reused
// for intermediate[0..moeIntermediate-1]; it must hold max(hiddenSize, moeIntermediate)
// floats.
// Preconditions implied by the indexing: hiddenSize and moeIntermediate are multiples
// of groupSize, groupSize is a multiple of 16, and every weight row is 16-byte aligned
// (weights are fetched as uint4).
// NOTE(review): the intermediate is stored at shared_space[gid] and read back by every
// thread, which only works if the whole grid is a single threadgroup — confirm against
// the host-side dispatch.

// Dequantizes the four bytes of `word` (lowest byte first): w = q * scale + bias.
static inline float4 gud8_dequant_bytes(uint word, float scale, float bias) {
    return float4(float((word >>  0) & 0xFF) * scale + bias,
                  float((word >>  8) & 0xFF) * scale + bias,
                  float((word >> 16) & 0xFF) * scale + bias,
                  float((word >> 24) & 0xFF) * scale + bias);
}

// Dot product of the 16 dequantized weights held in `packed` with v[0..15].
static inline float gud8_dot16(uint4 packed, float scale, float bias,
                               threadgroup const float *v) {
    return dot(gud8_dequant_bytes(packed.x, scale, bias), float4(v[0],  v[1],  v[2],  v[3]))
         + dot(gud8_dequant_bytes(packed.y, scale, bias), float4(v[4],  v[5],  v[6],  v[7]))
         + dot(gud8_dequant_bytes(packed.z, scale, bias), float4(v[8],  v[9],  v[10], v[11]))
         + dot(gud8_dequant_bytes(packed.w, scale, bias), float4(v[12], v[13], v[14], v[15]));
}

kernel void quantized_matmul_gate_up_down_8bit(
    device const float *x [[buffer(0)]],
    device const uint *w_gate [[buffer(1)]],
    device const float *s_gate [[buffer(2)]],
    device const float *b_gate [[buffer(3)]],
    device const uint *w_up [[buffer(4)]],
    device const float *s_up [[buffer(5)]],
    device const float *b_up [[buffer(6)]],
    device const uint *w_down [[buffer(7)]],
    device const float *s_down [[buffer(8)]],
    device const float *b_down [[buffer(9)]],
    device float *accum [[buffer(10)]],
    constant uint &hiddenSize [[buffer(11)]],
    constant uint &moeIntermediate [[buffer(12)]],
    constant uint &groupSize [[buffer(13)]],
    constant float &expertWeight [[buffer(14)]],
    threadgroup float *shared_space [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    // Cooperative load of the input vector into threadgroup memory.
    for (uint i = tid; i < hiddenSize; i += tgSize) {
        shared_space[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint numGroupsIn    = hiddenSize / groupSize;
    const uint numGroupsOut   = moeIntermediate / groupSize;
    const uint packedPerIn    = hiddenSize / 4;       // uints per gate/up row
    const uint packedPerOut   = moeIntermediate / 4;  // uints per down row
    const uint packedPerGroup = groupSize / 4;        // uints per quantization group

    // ── Gate + Up: one intermediate row per thread ──
    // The result stays in a register until every thread has finished reading x.
    float product = 0.0f;
    if (gid < moeIntermediate) {
        float gateSum = 0.0f, upSum = 0.0f;

        for (uint g = 0; g < numGroupsIn; g++) {
            const uint  sIdx   = gid * numGroupsIn + g;
            const float gScale = s_gate[sIdx];
            const float gBias  = b_gate[sIdx];
            const float uScale = s_up[sIdx];
            const float uBias  = b_up[sIdx];

            const uint wBase = gid * packedPerIn + g * packedPerGroup;
            const uint xBase = g * groupSize;

            // Each step consumes one uint4 = 16 weights, hence x advances by p*4.
            for (uint p = 0; p < packedPerGroup; p += 4) {
                const uint4 gP = *(device const uint4 *)(&w_gate[wBase + p]);
                const uint4 uP = *(device const uint4 *)(&w_up[wBase + p]);
                threadgroup const float *xv = &shared_space[xBase + p * 4];

                gateSum += gud8_dot16(gP, gScale, gBias, xv);
                upSum   += gud8_dot16(uP, uScale, uBias, xv);
            }
        }

        // Clamp the pre-activation, then tanh-approximated GELU. Beyond |v| > 10 the
        // approximation is replaced by its asymptote (v for positive, 0 for negative).
        if (gateSum > 100.0f) gateSum = 100.0f;
        if (gateSum < -100.0f) gateSum = -100.0f;
        float geluVal;
        if (fabs(gateSum) > 10.0f) {
            geluVal = gateSum > 0 ? gateSum : 0.0f;
        } else {
            const float v3 = gateSum * gateSum * gateSum;
            geluVal = 0.5f * gateSum * (1.0f + tanh(0.7978845608028654f * (gateSum + 0.044715f * v3)));
        }

        if (upSum > 100.0f) upSum = 100.0f;
        if (upSum < -100.0f) upSum = -100.0f;

        product = geluVal * upSum;
        if (product > 10.0f) product = 10.0f;
        if (product < -10.0f) product = -10.0f;
        if (isnan(product) || isinf(product)) product = 0.0f;
    }

    // The intermediate aliases the region that holds x. Wait until every thread has
    // finished its gate/up dot products before overwriting it; without this barrier a
    // fast thread could clobber x values a slower thread has not read yet.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (gid < moeIntermediate) {
        shared_space[gid] = product;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Down projection: one output element per thread ──
    if (gid < hiddenSize) {
        float sum = 0.0f;

        for (uint g = 0; g < numGroupsOut; g++) {
            const uint  sIdx  = gid * numGroupsOut + g;
            const float scale = s_down[sIdx];
            const float bias  = b_down[sIdx];

            const uint wBase = gid * packedPerOut + g * packedPerGroup;
            const uint iBase = g * groupSize;

            for (uint p = 0; p < packedPerGroup; p += 4) {
                const uint4 packed = *(device const uint4 *)(&w_down[wBase + p]);
                sum += gud8_dot16(packed, scale, bias, &shared_space[iBase + p * 4]);
            }
        }

        // Accumulate (not overwrite): the caller sums several experts into accum.
        accum[gid] += expertWeight * sum;
    }
}
|
||
|
||
// ── Parallel RMS Norm ──────────────────────────────
// y[i] = x[i] * rsqrt(mean(x²) + eps) * w[i], with w treated as all-ones when null.
//
// Every thread of the threadgroup cooperates on the same N-element vector: the work
// is indexed purely by `tid`, and `gid` is accepted but not used.
// Threadgroup memory: partial_sums must hold `tgSize` floats.
// Any threadgroup size is supported; it does not have to be a power of two.
kernel void rms_norm_parallel(
    device const float *x [[buffer(0)]],
    device const float *w [[buffer(1)]],
    device float *y [[buffer(2)]],
    constant uint &N [[buffer(3)]],
    constant float &eps [[buffer(4)]],
    threadgroup float *partial_sums [[threadgroup(0)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]],
    uint gid [[thread_position_in_grid]]
) {
    // Phase 1: each thread accumulates the squares of its strided slice of x.
    float localSum = 0.0f;
    for (uint i = tid; i < N; i += tgSize) {
        localSum += x[i] * x[i];
    }
    partial_sums[tid] = localSum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: tree reduction into partial_sums[0].
    // Start from the smallest power of two >= tgSize and bounds-check the partner
    // slot, so that no partial sum is dropped when tgSize is not a power of two.
    // For power-of-two sizes this visits exactly the strides tgSize/2, tgSize/4, ...
    uint stride = 1;
    while (stride < tgSize) {
        stride <<= 1;
    }
    for (stride >>= 1; stride > 0; stride >>= 1) {
        if (tid < stride && tid + stride < tgSize) {
            partial_sums[tid] += partial_sums[tid + stride];
        }
        // Outside the `if`: every thread must reach the barrier.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Phase 3: inverse RMS, then each thread normalizes its slice.
    const float sumSquares = partial_sums[0];
    const float invRms = rsqrt(sumSquares / float(N) + eps);

    for (uint i = tid; i < N; i += tgSize) {
        y[i] = x[i] * invRms * (w ? w[i] : 1.0f);
    }
}
|
||
|
||
// ── MoE Mega Kernel (CPU-free Router + All Experts) ─
// Single kernel replaces: router matmul + CPU softmax/topk + 8× expert dispatch
// Eliminates 30 CPU syncs per token for MoE models
//
// All weights are 4-bit affine-quantized with a fixed group size of 64
// (eight weights per uint, one scale/bias pair per 64 weights of a row).
// Expert tensors are indexed as [expert][row][...]:
//   w_gate / w_up : stride moeIntermediate * hiddenSize/8  uints  per expert
//   s_* / b_* (gate, up) : stride moeIntermediate * hiddenSize/64 floats per expert
//   w_down : stride hiddenSize * moeIntermediate/8  uints  per expert
//   s_down / b_down : stride hiddenSize * moeIntermediate/64 floats per expert
//
// Threadgroup memory layout (floats):
//   [0 .. hiddenSize-1]                  x input, reloaded at the start of each expert
//                                        iteration; [0 .. moeIntermediate-1] is then
//                                        overwritten with the gate*up intermediate
//   [hiddenSize .. +numExperts-1]        router logits
//   [hiddenSize+numExperts .. +topK-1]   selected expert indices (stored as float)
//   [.. +topK .. +2*topK-1]              renormalized weights of the selected experts
//
// At most 128 experts are considered by the router (size of the per-thread scratch
// arrays); topK is clamped to the number of experts considered.
// NOTE(review): the intermediate is stored at shared_space[gid] and the router uses
// `tid`, so this appears to require a single-threadgroup dispatch — confirm with the
// host code.
// NOTE(review): if moeIntermediate > hiddenSize the intermediate would overlap the
// routing scratch area — confirm that configuration never occurs.

// Dequantizes the four low nibbles of `bits` (lowest first): w = q * scale + bias.
static inline float4 moe_mega_dequant4(uint bits, float scale, float bias) {
    return float4(float((bits >>  0) & 0xF) * scale + bias,
                  float((bits >>  4) & 0xF) * scale + bias,
                  float((bits >>  8) & 0xF) * scale + bias,
                  float((bits >> 12) & 0xF) * scale + bias);
}

// Dot product of the 32 dequantized 4-bit weights held in `packed` with v[0..31].
static inline float moe_mega_dot32(uint4 packed, float scale, float bias,
                                   threadgroup const float *v) {
    return dot(moe_mega_dequant4(packed.x,       scale, bias), float4(v[0],  v[1],  v[2],  v[3]))
         + dot(moe_mega_dequant4(packed.x >> 16, scale, bias), float4(v[4],  v[5],  v[6],  v[7]))
         + dot(moe_mega_dequant4(packed.y,       scale, bias), float4(v[8],  v[9],  v[10], v[11]))
         + dot(moe_mega_dequant4(packed.y >> 16, scale, bias), float4(v[12], v[13], v[14], v[15]))
         + dot(moe_mega_dequant4(packed.z,       scale, bias), float4(v[16], v[17], v[18], v[19]))
         + dot(moe_mega_dequant4(packed.z >> 16, scale, bias), float4(v[20], v[21], v[22], v[23]))
         + dot(moe_mega_dequant4(packed.w,       scale, bias), float4(v[24], v[25], v[26], v[27]))
         + dot(moe_mega_dequant4(packed.w >> 16, scale, bias), float4(v[28], v[29], v[30], v[31]));
}

kernel void moe_mega_kernel(
    device const float *x [[buffer(0)]],
    device const uint *w_router [[buffer(1)]],
    device const float *s_router [[buffer(2)]],
    device const float *b_router [[buffer(3)]],
    device const uint *w_gate [[buffer(4)]],
    device const float *s_gate [[buffer(5)]],
    device const float *b_gate [[buffer(6)]],
    device const uint *w_up [[buffer(7)]],
    device const float *s_up [[buffer(8)]],
    device const float *b_up [[buffer(9)]],
    device const uint *w_down [[buffer(10)]],
    device const float *s_down [[buffer(11)]],
    device const float *b_down [[buffer(12)]],
    device float *accum [[buffer(13)]],
    constant uint &hiddenSize [[buffer(14)]],
    constant uint &moeIntermediate [[buffer(15)]],
    constant uint &numExperts [[buffer(16)]],
    constant float &routerScale [[buffer(17)]],
    constant uint &topK [[buffer(18)]],
    threadgroup float *shared_space [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    const uint numGroupsIn  = hiddenSize / 64;
    const uint numGroupsOut = moeIntermediate / 64;
    const uint packedPerIn  = hiddenSize / 8;
    const uint packedPerOut = moeIntermediate / 8;

    // Routing scratch lives after the x / intermediate region.
    const uint logitBase = hiddenSize;
    const uint topKBase  = hiddenSize + numExperts;

    // Bounds shared by every thread so that all loops (and the barriers inside them)
    // stay uniform across the threadgroup. 128 is the size of vals[] / idx[] below.
    const uint routedExperts = min(numExperts, 128u);
    const uint activeK       = min(topK, routedExperts);

    // ── Phase 0: Cooperative load x ────────────
    for (uint i = tid; i < hiddenSize; i += tgSize) {
        shared_space[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 0: Router matmul ─────────────────
    // Strided over experts so it also works when numExperts > tgSize.
    for (uint ex = tid; ex < numExperts; ex += tgSize) {
        float logit = 0.0f;
        for (uint g = 0; g < numGroupsIn; g++) {
            const float scale = s_router[ex * numGroupsIn + g];
            const float bias  = b_router[ex * numGroupsIn + g];
            const uint wBase = ex * packedPerIn + g * 8;
            const uint xBase = g * 64;

            // Two uint4 loads cover the 8 uints (64 weights) of one group.
            for (uint p = 0; p < 8; p += 4) {
                const uint4 packed = *(device const uint4 *)(&w_router[wBase + p]);
                logit += moe_mega_dot32(packed, scale, bias, &shared_space[xBase + p * 8]);
            }
        }
        shared_space[logitBase + ex] = logit * routerScale;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 0: Softmax + Top-K ───────────────
    // Done serially by thread 0 on a private copy of the logits. The expert count is
    // tiny (<= 128), and a private copy avoids the in-place tree reductions that
    // overwrote the logits and read scratch slots beyond numExperts.
    if (tid == 0) {
        float vals[128];
        uint idx[128];

        float maxLogit = -FLT_MAX;
        for (uint i = 0; i < routedExperts; i++) {
            vals[i] = shared_space[logitBase + i];
            idx[i] = i;
            maxLogit = fmax(maxLogit, vals[i]);
        }

        float expSum = 0.0f;
        for (uint i = 0; i < routedExperts; i++) {
            vals[i] = exp(vals[i] - maxLogit);
            expSum += vals[i];
        }
        if (expSum <= 0) expSum = 1.0f;
        for (uint i = 0; i < routedExperts; i++) {
            vals[i] /= expSum;
        }

        // Partial selection sort: after step i, vals[0..i] hold the i+1 largest
        // probabilities in descending order and idx[] the matching expert ids.
        for (uint i = 0; i < activeK; i++) {
            uint best = i;
            for (uint j = i + 1; j < routedExperts; j++) {
                if (vals[j] > vals[best]) { best = j; }
            }
            const float tmpV = vals[i]; vals[i] = vals[best]; vals[best] = tmpV;
            const uint tmpI = idx[i]; idx[i] = idx[best]; idx[best] = tmpI;
            shared_space[topKBase + i] = float(idx[i]);
        }

        // Renormalize the selected probabilities so they sum to 1.
        float topKSum = 0.0f;
        for (uint i = 0; i < activeK; i++) { topKSum += vals[i]; }
        if (topKSum <= 0) topKSum = 1.0f;
        for (uint i = 0; i < activeK; i++) {
            shared_space[topKBase + topK + i] = vals[i] / topKSum;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phases 1-8: Expert dispatch ─────────────
    for (uint e = 0; e < activeK; e++) {
        const uint expertIdx = uint(shared_space[topKBase + e]);
        const float expertWeight = shared_space[topKBase + topK + e];

        // Reload x for this expert (the intermediate from the previous iteration
        // overwrote shared_space[0..moeIntermediate-1]).
        for (uint i = tid; i < hiddenSize; i += tgSize) {
            shared_space[i] = x[i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Phase 1: gate+up → intermediate (kept in a register for now)
        float product = 0.0f;
        if (gid < moeIntermediate) {
            float gateSum = 0.0f, upSum = 0.0f;
            // Weight and scale buffers have different per-expert strides.
            const uint wExpertBase = expertIdx * moeIntermediate * packedPerIn;
            const uint sExpertBase = expertIdx * moeIntermediate * numGroupsIn;

            for (uint g = 0; g < numGroupsIn; g++) {
                const uint sIdx = sExpertBase + gid * numGroupsIn + g;
                const float gScale = s_gate[sIdx];
                const float gBias  = b_gate[sIdx];
                const float uScale = s_up[sIdx];
                const float uBias  = b_up[sIdx];

                const uint wb = wExpertBase + gid * packedPerIn + g * 8;
                const uint xBase = g * 64;

                for (uint p = 0; p < 8; p += 4) {
                    const uint4 gP = *(device const uint4 *)(&w_gate[wb + p]);
                    const uint4 uP = *(device const uint4 *)(&w_up[wb + p]);
                    threadgroup const float *xv = &shared_space[xBase + p * 8];

                    gateSum += moe_mega_dot32(gP, gScale, gBias, xv);
                    upSum   += moe_mega_dot32(uP, uScale, uBias, xv);
                }
            }

            // GELU (tanh approximation); beyond |v| > 10 use the asymptote.
            if (gateSum > 100.0f) gateSum = 100.0f;
            if (gateSum < -100.0f) gateSum = -100.0f;
            float geluVal;
            if (fabs(gateSum) > 10.0f) {
                geluVal = gateSum > 0 ? gateSum : 0.0f;
            } else {
                const float v3 = gateSum * gateSum * gateSum;
                geluVal = 0.5f * gateSum * (1.0f + tanh(0.7978845608028654f * (gateSum + 0.044715f * v3)));
            }

            if (upSum > 100.0f) upSum = 100.0f;
            if (upSum < -100.0f) upSum = -100.0f;

            product = geluVal * upSum;
            if (product > 10.0f) product = 10.0f;
            if (product < -10.0f) product = -10.0f;
            if (isnan(product) || isinf(product)) product = 0.0f;
        }

        // The intermediate aliases x: wait until every thread has finished reading x
        // before any thread overwrites it.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (gid < moeIntermediate) {
            shared_space[gid] = product;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Phase 2: down projection + accumulate
        if (gid < hiddenSize) {
            float sum = 0.0f;
            // Separate per-expert strides for packed weights and for scales/biases.
            // Scales were previously indexed with the weight stride, which is only
            // correct for expert 0.
            const uint wDownBase = expertIdx * hiddenSize * packedPerOut;
            const uint sDownBase = expertIdx * hiddenSize * numGroupsOut;

            for (uint g = 0; g < numGroupsOut; g++) {
                const uint sIdx = sDownBase + gid * numGroupsOut + g;
                const float scale = s_down[sIdx];
                const float bias  = b_down[sIdx];

                const uint wb = wDownBase + gid * packedPerOut + g * 8;
                const uint iBase = g * 64;

                for (uint p = 0; p < 8; p += 4) {
                    const uint4 packed = *(device const uint4 *)(&w_down[wb + p]);
                    sum += moe_mega_dot32(packed, scale, bias, &shared_space[iBase + p * 8]);
                }
            }

            accum[gid] += expertWeight * sum;
        }
        // All reads of the intermediate must finish before the next x reload.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
}
|
||
|
||
// ── Optimized Fused Gate+Up (Dense Models) ───────
// Threadgroup-cached input + uint4 loads
// Used by: 26B-Standard, 31B, E4B-MarkBase
//
// out[r] = clamp(gelu_tanh(W_gate[r]·x) * (W_up[r]·x), -10, 10) for r < outDim, with
// 4-bit affine-quantized weights (eight per uint, one scale/bias per `groupSize`
// weights of a row).
//
// Threadgroup memory: shared_x must hold inDim floats; each threadgroup loads its
// own copy of x.
// Preconditions implied by the indexing: inDim is a multiple of groupSize, groupSize
// is a multiple of 32, and every weight row is 16-byte aligned (uint4 loads).

// Dequantizes the four low nibbles of `bits` (lowest first): w = q * scale + bias.
static inline float4 gate_up_opt_dequant4(uint bits, float scale, float bias) {
    return float4(float((bits >>  0) & 0xF) * scale + bias,
                  float((bits >>  4) & 0xF) * scale + bias,
                  float((bits >>  8) & 0xF) * scale + bias,
                  float((bits >> 12) & 0xF) * scale + bias);
}

// Dot product of the 32 dequantized 4-bit weights held in `packed` with v[0..31].
static inline float gate_up_opt_dot32(uint4 packed, float scale, float bias,
                                      threadgroup const float *v) {
    return dot(gate_up_opt_dequant4(packed.x,       scale, bias), float4(v[0],  v[1],  v[2],  v[3]))
         + dot(gate_up_opt_dequant4(packed.x >> 16, scale, bias), float4(v[4],  v[5],  v[6],  v[7]))
         + dot(gate_up_opt_dequant4(packed.y,       scale, bias), float4(v[8],  v[9],  v[10], v[11]))
         + dot(gate_up_opt_dequant4(packed.y >> 16, scale, bias), float4(v[12], v[13], v[14], v[15]))
         + dot(gate_up_opt_dequant4(packed.z,       scale, bias), float4(v[16], v[17], v[18], v[19]))
         + dot(gate_up_opt_dequant4(packed.z >> 16, scale, bias), float4(v[20], v[21], v[22], v[23]))
         + dot(gate_up_opt_dequant4(packed.w,       scale, bias), float4(v[24], v[25], v[26], v[27]))
         + dot(gate_up_opt_dequant4(packed.w >> 16, scale, bias), float4(v[28], v[29], v[30], v[31]));
}

kernel void quantized_matmul_gate_up_opt(
    device const float *x [[buffer(0)]],
    device const uint *w_gate [[buffer(1)]],
    device const float *s_gate [[buffer(2)]],
    device const float *b_gate [[buffer(3)]],
    device const uint *w_up [[buffer(4)]],
    device const float *s_up [[buffer(5)]],
    device const float *b_up [[buffer(6)]],
    device float *out [[buffer(7)]],
    constant uint &inDim [[buffer(8)]],
    constant uint &outDim [[buffer(9)]],
    constant uint &groupSize [[buffer(10)]],
    threadgroup float *shared_x [[threadgroup(0)]],
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    // Every thread of the threadgroup takes part in the cooperative load and in the
    // barrier, including threads whose row index is past outDim. Returning before
    // this point would leave part of shared_x unwritten in the last threadgroup and
    // place the barrier in divergent control flow.
    for (uint i = tid; i < inDim; i += tgSize) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tail threads have no output row; safe to leave now that the barrier is passed.
    if (gid >= outDim) return;

    const uint numGroups      = inDim / groupSize;
    const uint packedPerRow   = inDim / 8;       // uints per weight row
    const uint packedPerGroup = groupSize / 8;   // uints per quantization group

    float gateSum = 0.0f, upSum = 0.0f;
    for (uint g = 0; g < numGroups; g++) {
        const uint  sIdx   = gid * numGroups + g;
        const float gScale = s_gate[sIdx];
        const float gBias  = b_gate[sIdx];
        const float uScale = s_up[sIdx];
        const float uBias  = b_up[sIdx];

        const uint wBase = gid * packedPerRow + g * packedPerGroup;
        const uint xBase = g * groupSize;

        // One uint4 = 32 weights per step. The bound follows groupSize instead of
        // assuming 64-wide groups, so the loop never runs into the next group's
        // weights with this group's scale.
        for (uint p = 0; p < packedPerGroup; p += 4) {
            const uint4 gP = *(device const uint4 *)(&w_gate[wBase + p]);
            const uint4 uP = *(device const uint4 *)(&w_up[wBase + p]);
            threadgroup const float *xv = &shared_x[xBase + p * 8];

            gateSum += gate_up_opt_dot32(gP, gScale, gBias, xv);
            upSum   += gate_up_opt_dot32(uP, uScale, uBias, xv);
        }
    }

    // Clamp the pre-activation, then tanh-approximated GELU. Beyond |v| > 10 the
    // approximation is replaced by its asymptote (v for positive, 0 for negative).
    if (gateSum > 100.0f) gateSum = 100.0f;
    if (gateSum < -100.0f) gateSum = -100.0f;
    float gate;
    if (fabs(gateSum) > 10.0f) {
        gate = gateSum > 0 ? gateSum : 0.0f;
    } else {
        const float v3 = gateSum * gateSum * gateSum;
        gate = 0.5f * gateSum * (1.0f + tanh(0.7978845608028654f * (gateSum + 0.044715f * v3)));
    }

    if (upSum > 100.0f) upSum = 100.0f;
    if (upSum < -100.0f) upSum = -100.0f;

    float product = gate * upSum;
    if (product > 10.0f) product = 10.0f;
    if (product < -10.0f) product = -10.0f;

    out[gid] = product;
}
|
||
|
||
// ── SIMD Elementwise Operations ─────────────────────
|
||
|
||
// Elementwise product: out[i] = a[i] * b[i] for i < count.
// Each thread handles SIMD_WIDTH consecutive elements, so the grid needs at least
// ceil(count / SIMD_WIDTH) threads. `count` does not have to be a multiple of
// SIMD_WIDTH: the last partial chunk is processed element by element so that no
// input is read past the end of the buffers.
kernel void eltwise_mul_simd(
    device const float *a,
    device const float *b,
    device float *out,
    constant uint &count,
    uint id [[thread_position_in_grid]]
) {
    const uint idx = id * SIMD_WIDTH;
    if (idx >= count) return;

    // Number of valid elements in this thread's chunk (1..SIMD_WIDTH).
    const uint remaining = count - idx;

    if (remaining >= 4) {
        // Full chunk: all four lanes are in range.
        const float4 aVec = float4(a[idx], a[idx+1], a[idx+2], a[idx+3]);
        const float4 bVec = float4(b[idx], b[idx+1], b[idx+2], b[idx+3]);
        const float4 outVec = aVec * bVec;

        out[idx]   = outVec.x;
        out[idx+1] = outVec.y;
        out[idx+2] = outVec.z;
        out[idx+3] = outVec.w;
    } else {
        // Tail chunk: touch only the elements that exist.
        for (uint k = 0; k < remaining; k++) {
            out[idx + k] = a[idx + k] * b[idx + k];
        }
    }
}
|
||
|
||
// Elementwise sum: out[i] = a[i] + b[i] for i < count.
// Each thread handles SIMD_WIDTH consecutive elements, so the grid needs at least
// ceil(count / SIMD_WIDTH) threads. `count` does not have to be a multiple of
// SIMD_WIDTH: the last partial chunk is processed element by element so that no
// input is read past the end of the buffers.
kernel void eltwise_add_simd(
    device const float *a,
    device const float *b,
    device float *out,
    constant uint &count,
    uint id [[thread_position_in_grid]]
) {
    const uint idx = id * SIMD_WIDTH;
    if (idx >= count) return;

    // Number of valid elements in this thread's chunk (1..SIMD_WIDTH).
    const uint remaining = count - idx;

    if (remaining >= 4) {
        // Full chunk: all four lanes are in range.
        const float4 aVec = float4(a[idx], a[idx+1], a[idx+2], a[idx+3]);
        const float4 bVec = float4(b[idx], b[idx+1], b[idx+2], b[idx+3]);
        const float4 outVec = aVec + bVec;

        out[idx]   = outVec.x;
        out[idx+1] = outVec.y;
        out[idx+2] = outVec.z;
        out[idx+3] = outVec.w;
    } else {
        // Tail chunk: touch only the elements that exist.
        for (uint k = 0; k < remaining; k++) {
            out[idx + k] = a[idx + k] + b[idx + k];
        }
    }
}