8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
133 lines
5.7 KiB
Metal
133 lines
5.7 KiB
Metal
#include <metal_stdlib>
|
|
using namespace metal;
|
|
|
|
// ════════════════════════════════════════════════════════
|
|
// Kernel Fusion Optimizations - Reduce dispatch overhead
|
|
// ════════════════════════════════════════════════════════
|
|
|
|
// SIMD_WIDTH is already defined in OptimizedKernels.metal (as uint = 4), so it is not redefined here.
|
|
|
|
// ── Fused RMS Norm + Quantized Matmul ────────────────
// Combines norm and projection in a single kernel.
// Saves 1 dispatch per layer (42 layers = 42 fewer dispatches).
//
// Computes, for each output row r:
//   out[r] = sum_k (q[r,k] * scale[r,g(k)] + bias[r,g(k)]) * xn[k]
//   xn[k]  = x[k] * rsqrt(mean(x^2) + eps) * normW[k]
// where q[r,k] is an unsigned 4-bit value and g(k) = k / groupSize.
//
// Launch layout:
//   * 1-D grid; the thread with grid position `gid` produces out[gid].
//     Threads with gid >= outDim still run the cooperative norm phase
//     (they must reach every threadgroup barrier) but write nothing.
//   * Every threadgroup recomputes the RMS norm of the whole input.
//   * threadgroup(0) must be at least inDim * sizeof(float) bytes
//     (set host-side with setThreadgroupMemoryLength:atIndex:).
//   * Any threadgroup size works; at most 256 threads take part in the
//     sum-of-squares reduction.
// Preconditions (not checked here):
//   * groupSize is a non-zero multiple of 8 and divides inDim; columns past
//     the last full group are ignored. groupSize == 0 yields out[r] = 0.
//   * w packs 8 values per uint, lowest nibble first; each row holds
//     inDim / 8 words.
//   * normW may be null, in which case a norm weight of 1 is used.
kernel void rms_norm_matmul_fused(
    device const float *x        [[buffer(0)]],  // Input [inDim]
    device const float *normW    [[buffer(1)]],  // Norm weight [inDim], or null
    device const uint  *w        [[buffer(2)]],  // Packed weights [outDim, inDim/8]
    device const float *s        [[buffer(3)]],  // Scales [outDim, inDim/groupSize]
    device const float *b        [[buffer(4)]],  // Biases [outDim, inDim/groupSize]
    device float       *out      [[buffer(5)]],  // Output [outDim]
    constant uint      &inDim    [[buffer(6)]],
    constant uint      &outDim   [[buffer(7)]],
    constant float     &eps      [[buffer(8)]],
    constant uint      &groupSize [[buffer(9)]],
    threadgroup float  *shared_norm_x [[threadgroup(0)]],  // Normed input cache [inDim]
    uint gid    [[thread_position_in_grid]],
    uint tid    [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    uint outRow = gid;
    // No early return here: every thread of the threadgroup must reach the
    // threadgroup_barrier calls below, or behavior is undefined.
    bool hasRow = outRow < outDim;

    // ── Phase 1: RMS Norm (cooperative) ───────────────────────
    // Only the first `reducers` threads accumulate, so partial_sums is never
    // indexed past its end when the threadgroup has more than 256 threads.
    constexpr uint kMaxReducers = 256;
    threadgroup float partial_sums[kMaxReducers];
    uint reducers = min(tgSize, kMaxReducers);

    if (tid < reducers) {
        float localSum = 0.0f;
        for (uint i = tid; i < inDim; i += reducers) {
            float val = x[i];
            localSum += val * val;
        }
        partial_sums[tid] = localSum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction that also handles a non-power-of-two `reducers`: each
    // round folds the upper part of the active range onto the lower part.
    // `reducers` is uniform across the threadgroup, so every thread runs the
    // same number of rounds and hits the same barriers.
    for (uint active = reducers; active > 1; ) {
        uint step = (active + 1) / 2;
        if (tid < active - step) {
            partial_sums[tid] += partial_sums[tid + step];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = step;
    }

    // Compute RMS and cache the normalized (and weighted) input.
    float rms = rsqrt(partial_sums[0] / float(inDim) + eps);

    for (uint i = tid; i < inDim; i += tgSize) {
        float normScale = normW ? normW[i] : 1.0f;
        shared_norm_x[i] = x[i] * rms * normScale;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Safe to leave now: no barriers follow.
    if (!hasRow) return;

    // ── Phase 2: Quantized Matmul ─────────────────────────────
    // Each thread processes one output row.
    uint numGroups = (groupSize != 0) ? inDim / groupSize : 0;  // avoid div-by-zero
    uint wordsPerGroup = groupSize / 8;                         // 8 nibbles per uint
    uint rowWordBase = outRow * (inDim / 8);
    uint rowGroupBase = outRow * numGroups;
    float sum = 0.0f;

    for (uint g = 0; g < numGroups; g++) {
        float scale = s[rowGroupBase + g];
        float bias  = b[rowGroupBase + g];
        uint packedBase = rowWordBase + g * wordsPerGroup;
        uint xGroupBase = g * groupSize;

        // One packed word (8 input columns) per iteration. The bound comes
        // from groupSize rather than a hard-coded 64-wide group.
        for (uint p = 0; p < wordsPerGroup; p++) {
            uint packed = w[packedBase + p];
            uint xBase = xGroupBase + p * 8;

            float4 xLo = float4(
                shared_norm_x[xBase + 0], shared_norm_x[xBase + 1],
                shared_norm_x[xBase + 2], shared_norm_x[xBase + 3]
            );
            float4 xHi = float4(
                shared_norm_x[xBase + 4], shared_norm_x[xBase + 5],
                shared_norm_x[xBase + 6], shared_norm_x[xBase + 7]
            );

            float4 qLo = float4(
                float((packed >>  0) & 0xF) * scale + bias,
                float((packed >>  4) & 0xF) * scale + bias,
                float((packed >>  8) & 0xF) * scale + bias,
                float((packed >> 12) & 0xF) * scale + bias
            );
            float4 qHi = float4(
                float((packed >> 16) & 0xF) * scale + bias,
                float((packed >> 20) & 0xF) * scale + bias,
                float((packed >> 24) & 0xF) * scale + bias,
                float((packed >> 28) & 0xF) * scale + bias
            );

            sum += dot(qLo, xLo);
            sum += dot(qHi, xHi);
        }
    }

    out[outRow] = sum;
}
|
|
|
|
// Note: batch_matmul_8 is not implemented here because Metal kernel parameters cannot be arrays of device pointers.
// Alternatives: pass the pointers through an argument buffer (Metal 2.0+), or issue separate dispatches.
|