Files
markbaseengine/Sources/MarkBase/Metal/FusionKernels.metal
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

133 lines
5.7 KiB
Metal

#include <metal_stdlib>
using namespace metal;
// ════════════════════════════════════════════════════════
// Kernel Fusion Optimizations - Reduce dispatch overhead
// ════════════════════════════════════════════════════════
// Use SIMD_WIDTH from OptimizedKernels.metal (already defined as uint = 4)
// ── Fused RMS Norm + Quantized Matmul ────────────────
// Combines norm and projection in single kernel
// Saves 1 dispatch per layer (42 layers = 42 fewer dispatches)
//
// For each output row r (one thread per row, r = thread_position_in_grid):
//   xn[i]  = x[i] * rsqrt(mean(x^2) + eps) * (normW ? normW[i] : 1)
//   out[r] = sum_i (q[r, i] * scale[r, i / groupSize] + bias[r, i / groupSize]) * xn[i]
// where q[r, i] is the 4-bit value stored in nibble (i % 8) of
// w[r * (inDim / 8) + i / 8], lowest nibble first.
//
// Launch / memory requirements (derived from the indexing below):
//   * threadgroup(0) must provide at least inDim floats (normalized-input cache).
//   * Every threadgroup recomputes the norm of the whole input vector itself.
//   * Any threadgroup size is accepted; only the first 256 threads take part
//     in the sum-of-squares reduction.
//   * groupSize is expected to be a multiple of 8 and inDim a multiple of
//     groupSize; elements past the last complete 8-value word / group are
//     not accumulated.

// Adds to `acc` the dot product of xs[0..7] with the eight 4-bit values packed
// in `packed` (lowest nibble first), each dequantized as q * scale + bias.
// The low and high halves are accumulated separately, in that order.
static inline float fused_accumulate_packed8(float acc,
                                             uint packed,
                                             threadgroup const float *xs,
                                             float scale,
                                             float bias) {
    const float4 xLo = float4(xs[0], xs[1], xs[2], xs[3]);
    const float4 xHi = float4(xs[4], xs[5], xs[6], xs[7]);
    const float4 qLo =
        float4((uint4(packed) >> uint4(0, 4, 8, 12)) & 0xFu) * scale + bias;
    const float4 qHi =
        float4((uint4(packed) >> uint4(16, 20, 24, 28)) & 0xFu) * scale + bias;
    acc += dot(qLo, xLo);
    acc += dot(qHi, xHi);
    return acc;
}

kernel void rms_norm_matmul_fused(
    device const float *x [[buffer(0)]],      // Input [inDim]
    device const float *normW [[buffer(1)]],  // Norm weight [inDim]
    device const uint *w [[buffer(2)]],       // Packed weights [outDim, inDim/8]
    device const float *s [[buffer(3)]],      // Scales [outDim, inDim/groupSize]
    device const float *b [[buffer(4)]],      // Biases [outDim, inDim/groupSize]
    device float *out [[buffer(5)]],          // Output [outDim]
    constant uint &inDim [[buffer(6)]],
    constant uint &outDim [[buffer(7)]],
    constant float &eps [[buffer(8)]],
    constant uint &groupSize [[buffer(9)]],
    threadgroup float *shared_norm_x [[threadgroup(0)]],  // Normed input cache
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    const uint outRow = gid;
    // Do NOT return early for out-of-range rows: every thread of the
    // threadgroup has to reach the barriers below and has to contribute its
    // slice of the cooperative norm. The row check is applied after Phase 1.
    const bool rowInRange = outRow < outDim;

    // ── Phase 1: RMS Norm (cooperative) ───────────────────────
    // At most kReduceSlots threads take part in the reduction so that the
    // scratch array can never be indexed out of bounds, whatever the
    // threadgroup size is.
    constexpr uint kReduceSlots = 256;
    threadgroup float partial_sums[kReduceSlots];
    const uint reducers = min(tgSize, kReduceSlots);

    if (tid < reducers) {
        float localSum = 0.0;
        for (uint i = tid; i < inDim; i += reducers) {
            float val = x[i];
            localSum += val * val;
        }
        partial_sums[tid] = localSum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction that is also correct when `reducers` is not a power of
    // two: each pass folds the upper floor(live/2) entries onto the lower
    // ones and keeps ceil(live/2) entries. `live` is uniform across the
    // threadgroup, so all threads execute the same number of barriers.
    for (uint live = reducers; live > 1; ) {
        const uint keep = (live + 1) / 2;
        if (tid < live / 2) {
            partial_sums[tid] += partial_sums[tid + keep];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        live = keep;
    }

    // Reciprocal RMS of the input vector.
    const float rms = rsqrt(partial_sums[0] / float(inDim) + eps);

    // Store normed values in threadgroup cache (all threads participate).
    for (uint i = tid; i < inDim; i += tgSize) {
        shared_norm_x[i] = x[i] * rms * (normW ? normW[i] : 1.0);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so surplus threads may leave now.
    if (!rowInRange) return;

    // ── Phase 2: Quantized Matmul ─────────────────────────────
    // Each thread processes one output row.
    const uint wordsPerGroup = groupSize / 8;  // eight 4-bit values per uint
    const uint numGroups = (groupSize > 0) ? (inDim / groupSize) : 0;
    const uint rowWords = inDim / 8;

    float sum = 0.0;
    for (uint g = 0; g < numGroups; g++) {
        const float scale = s[outRow * numGroups + g];
        const float bias = b[outRow * numGroups + g];
        const uint packedBase = outRow * rowWords + g * wordsPerGroup;
        const uint xGroupBase = g * groupSize;

        // Walk every packed word of this group; the trip count follows
        // groupSize instead of assuming 64 values per group.
        for (uint p = 0; p < wordsPerGroup; p++) {
            sum = fused_accumulate_packed8(sum,
                                           w[packedBase + p],
                                           shared_norm_x + xGroupBase + p * 8,
                                           scale,
                                           bias);
        }
    }
    out[outRow] = sum;
}
// Note: batch_matmul_8 is not implemented here because an array of buffer
// pointers cannot be passed directly as a kernel parameter in Metal.
// Alternatives: use an Argument Buffer (Metal 2.0+) or issue separate dispatches.