Files
markbaseengine/Sources/MarkBase/Metal/StableRMSNorm.metal
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

236 lines
7.8 KiB
Metal

#include <metal_stdlib>
using namespace metal;
// ═══════════════════════════════════════════════
// Numerically Stable RMSNorm Kernel
// ═══════════════════════════════════════════════
// RMSNorm over a single vector of N floats, with guard clamps on the
// statistics and on the output.
//
// Every threadgroup reduces the (clamped) sum of squares over the WHOLE
// input, so each threadgroup is self-contained; each thread then normalizes
// the one element at its grid position.
//
// Buffers:
//   x   [[buffer(0)]]  input, N floats
//   w   [[buffer(1)]]  per-element scale, N floats; skipped when the pointer
//                      tests null
//   y   [[buffer(2)]]  output, N floats
//   N   [[buffer(3)]]  element count
//   eps [[buffer(4)]]  floor and offset applied to the mean square
//
// Launch: 1-D dispatch covering at least N threads. Any threadgroup size is
// accepted (it need not be a power of two nor divide N); at most the first
// 1024 threads of a threadgroup take part in the reduction.
kernel void rms_norm_stable(
    device const float *x   [[buffer(0)]], // [N]
    device const float *w   [[buffer(1)]], // [N] weight (can be null)
    device float *y         [[buffer(2)]], // [N]
    constant uint &N        [[buffer(3)]],
    constant float &eps     [[buffer(4)]],
    uint tid    [[thread_position_in_threadgroup]],
    uint gid    [[thread_position_in_grid]],
    uint tgsize [[threads_per_threadgroup]]
) {
    // Scratch for per-thread partial sums (one slot per reducing thread).
    threadgroup float partialSums[1024];

    // Number of threads that take part in the reduction. Capping at the
    // scratch size keeps every shared-memory access in bounds.
    const uint lanes = min(tgsize, 1024u);

    // NOTE: no early return above this point. Every thread of the threadgroup
    // must reach each threadgroup_barrier below, including threads whose grid
    // position is past N.

    // Step 1: strided partial sum of squares. Lane `tid` visits
    // tid, tid + lanes, tid + 2*lanes, ... so all N elements are covered
    // exactly once, whatever the relation between N and the threadgroup size.
    if (tid < lanes) {
        float localSum = 0.0f;
        for (uint i = tid; i < N; i += lanes) {
            // Inputs are clamped for the statistic only, to bound the squares.
            float xi = clamp(x[i], -20.0f, 20.0f);
            localSum += xi * xi;
        }
        partialSums[tid] = localSum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: tree reduction that works for any lane count. Each round folds
    // the upper part [keep, active) onto the lower part and shrinks the
    // active range to `keep` = ceil(active / 2).
    for (uint active = lanes; active > 1u; ) {
        const uint keep = (active + 1u) / 2u;
        if (tid + keep < active) {
            partialSums[tid] += partialSums[tid + keep];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = keep;
    }

    // Threads past the end of the vector only helped with the reduction.
    if (gid >= N) return;

    // Step 3: scale factor from the total sum of squares.
    float meanSq = partialSums[0] / float(N);
    meanSq = max(meanSq, eps);          // keep the mean square at least eps
    meanSq = min(meanSq, 10000.0f);     // cap extreme magnitudes
    float rms = rsqrt(meanSq + eps);    // reciprocal RMS (multiplier)
    rms = clamp(rms, 0.01f, 100.0f);

    // Step 4: normalize this thread's element (unclamped input) and apply
    // the optional weight.
    float yi = x[gid] * rms;
    if (w) {
        yi *= w[gid];
    }

    // Output guard: values beyond +/-50 saturate at +/-50; values between
    // 20 and 50 in magnitude are compressed with slope 0.2 around +/-20.
    float yiFinal = yi;
    if (yiFinal > 50.0f)       yiFinal = 50.0f;
    else if (yiFinal < -50.0f) yiFinal = -50.0f;
    else if (yiFinal > 20.0f)  yiFinal = 20.0f + (yiFinal - 20.0f) * 0.2f;
    else if (yiFinal < -20.0f) yiFinal = -20.0f + (yiFinal + 20.0f) * 0.2f;

    y[gid] = yiFinal;
}
// ═══════════════════════════════════════════════
// Numerically Stable Softmax Kernel
// ═══════════════════════════════════════════════
// Softmax over a single vector of N logits, with guard clamps.
//
// Logits are first passed through softmax_compress_logit; the maximum, the
// exponential sum and the per-element numerator are all computed from the
// SAME compressed values so the outputs form a consistent distribution.

// Piecewise logit compression shared by every pass of softmax_stable:
// magnitudes above 30 saturate at +/-30, magnitudes in (10, 30] are
// compressed with slope 0.3 around +/-10, everything else passes through.
static inline float softmax_compress_logit(float v) {
    if (v > 30.0f)  return 30.0f;
    if (v < -30.0f) return -30.0f;
    if (v > 10.0f)  return 10.0f + (v - 10.0f) * 0.3f;
    if (v < -10.0f) return -10.0f + (v + 10.0f) * 0.3f;
    return v;
}

// Buffers:
//   logits [[buffer(0)]]  input, N floats
//   probs  [[buffer(1)]]  output, N floats
//   N      [[buffer(2)]]  element count
//
// Launch: 1-D dispatch covering at least N threads. Every threadgroup
// reduces over the whole input; any threadgroup size is accepted, and at
// most the first 1024 threads of a threadgroup take part in the reductions.
kernel void softmax_stable(
    device const float *logits [[buffer(0)]], // [N]
    device float *probs        [[buffer(1)]], // [N]
    constant uint &N           [[buffer(2)]],
    uint tid    [[thread_position_in_threadgroup]],
    uint gid    [[thread_position_in_grid]],
    uint tgsize [[threads_per_threadgroup]]
) {
    threadgroup float sharedMax[1024];
    threadgroup float sharedSumExp[1024];

    // Threads that take part in the reductions (bounded by scratch size).
    const uint lanes = min(tgsize, 1024u);

    // NOTE: no early return above the last barrier. All threads of the
    // threadgroup must reach every threadgroup_barrier, including threads
    // whose grid position is past N.

    // Step 1: maximum of the compressed logits. Lane `tid` visits
    // tid, tid + lanes, ... so every element is covered exactly once.
    if (tid < lanes) {
        float localMax = -INFINITY;
        for (uint i = tid; i < N; i += lanes) {
            localMax = max(localMax, softmax_compress_logit(logits[i]));
        }
        sharedMax[tid] = localMax;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction valid for any lane count: fold [keep, active) onto the
    // lower part, then shrink the active range to keep = ceil(active / 2).
    for (uint active = lanes; active > 1u; ) {
        const uint keep = (active + 1u) / 2u;
        if (tid + keep < active) {
            sharedMax[tid] = max(sharedMax[tid], sharedMax[tid + keep]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = keep;
    }
    const float globalMax = sharedMax[0];

    // Step 2: sum of exponentials of (compressed logit - max). The difference
    // is clamped to [-10, 10] exactly as in the output pass below.
    if (tid < lanes) {
        float localSumExp = 0.0f;
        for (uint i = tid; i < N; i += lanes) {
            float d = softmax_compress_logit(logits[i]) - globalMax;
            localSumExp += exp(clamp(d, -10.0f, 10.0f));
        }
        sharedSumExp[tid] = localSumExp;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint active = lanes; active > 1u; ) {
        const uint keep = (active + 1u) / 2u;
        if (tid + keep < active) {
            sharedSumExp[tid] += sharedSumExp[tid + keep];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = keep;
    }

    // Threads past the end of the vector only helped with the reductions.
    if (gid >= N) return;

    // Guard the denominator against zero.
    const float totalSumExp = max(sharedSumExp[0], 1e-6f);

    // Step 3: this thread's probability, using the same compression and
    // difference clamp as the sum pass.
    float diff = softmax_compress_logit(logits[gid]) - globalMax;
    diff = clamp(diff, -10.0f, 10.0f);
    probs[gid] = exp(diff) / totalSumExp;
}
// Alternative: block-wise RMSNorm.
//
// The vector is split into consecutive blocks of `blockSize` elements (the
// last block may be shorter). Each thread recomputes the sum of squares of
// the block containing its own element and normalizes that one element with
// the block's statistic. No threadgroup memory or barriers are used.
//
// Buffers:
//   x         [[buffer(0)]]  input, N floats
//   w         [[buffer(1)]]  per-element scale, N floats; skipped when the
//                            pointer tests null
//   y         [[buffer(2)]]  output, N floats
//   N         [[buffer(3)]]  element count
//   eps       [[buffer(4)]]  floor and offset applied to the mean square
//   blockSize [[buffer(5)]]  elements per block; 0 is treated as one block
//                            spanning the whole vector
kernel void rms_norm_blockwise(
    device const float *x     [[buffer(0)]],
    device const float *w     [[buffer(1)]],
    device float *y           [[buffer(2)]],
    constant uint &N          [[buffer(3)]],
    constant float &eps       [[buffer(4)]],
    constant uint &blockSize  [[buffer(5)]],
    uint gid [[thread_position_in_grid]]
) {
    // No barriers in this kernel, so an early exit is safe.
    if (gid >= N) return;

    // A zero block size would divide by zero below; fall back to one block
    // covering all N elements.
    const uint span = (blockSize == 0u) ? N : blockSize;

    // Locate the block that contains this thread's element. The end is
    // derived from the remaining length so the addition cannot wrap.
    const uint blockStart = (gid / span) * span;
    const uint blockLen   = min(span, N - blockStart);
    const uint blockEnd   = blockStart + blockLen;

    // Sum of squares for this block only; inputs are clamped for the
    // statistic to bound the squares.
    float blockSum = 0.0f;
    for (uint i = blockStart; i < blockEnd; i++) {
        float xi = clamp(x[i], -100.0f, 100.0f);
        blockSum += xi * xi;
    }

    // Mean square over the actual block length, floored at eps.
    float meanSq = blockSum / float(blockLen);
    meanSq = max(meanSq, eps);
    float rms = rsqrt(meanSq + eps);    // reciprocal RMS (multiplier)
    rms = clamp(rms, 0.01f, 100.0f);

    // Normalize this thread's element (unclamped input), apply the optional
    // weight, and bound the result.
    float yi = x[gid] * rms;
    if (w) yi *= w[gid];
    y[gid] = clamp(yi, -1000.0f, 1000.0f);
}