8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
236 lines
7.8 KiB
Metal
236 lines
7.8 KiB
Metal
#include <metal_stdlib>
|
|
using namespace metal;
|
|
|
|
// ═══════════════════════════════════════════════
|
|
// Numerically Stable RMSNorm Kernel
|
|
// ═══════════════════════════════════════════════
|
|
|
|
// Optimized RMSNorm with numerical stability
|
|
// Uses threadgroup parallel reduction to avoid overflow
|
|
// RMSNorm over one vector of N floats, with the file's "stability" clamps.
//
// Computation (constants unchanged from the previous implementation):
//   sumSq  = sum_i clamp(x[i], -20, 20)^2          (over ALL i in [0, N))
//   meanSq = min(max(sumSq / N, eps), 10000)
//   scale  = clamp(rsqrt(meanSq + eps), 0.01, 100)
//   y[gid] = squash(x[gid] * scale * w[gid])       (w is skipped when null)
//
// Launch contract:
//   * Every threadgroup reduces the whole vector by itself, so results are
//     correct for any threadgroup size and any number of threadgroups.
//   * Only elements with gid < N are written; the grid must cover N threads
//     for the whole output to be produced.
//   * At most 256 lanes per threadgroup take part in the reduction (size of
//     the threadgroup scratch array); extra threads only wait at barriers.
//
// Parameters:
//   x    [[buffer(0)]]  input vector, N floats
//   w    [[buffer(1)]]  optional per-element weight, N floats (may be null)
//   y    [[buffer(2)]]  output vector, N floats
//   N    [[buffer(3)]]  element count
//   eps  [[buffer(4)]]  epsilon used both as floor for and addend to meanSq
kernel void rms_norm_stable(
    device const float *x [[buffer(0)]],   // [N]
    device const float *w [[buffer(1)]],   // [N] weight (can be null)
    device float *y [[buffer(2)]],         // [N]
    constant uint &N [[buffer(3)]],
    constant float &eps [[buffer(4)]],
    uint tid [[thread_position_in_threadgroup]],
    uint gid [[thread_position_in_grid]],
    uint tgsize [[threads_per_threadgroup]]
) {
    // Threadgroup scratch for the per-lane partial sums.
    threadgroup float partialSums[256];

    // N comes from a constant buffer, so this exit is uniform across the
    // threadgroup and cannot strand any thread at a barrier.
    if (N == 0) return;

    // Number of threads that actually accumulate; capped by the scratch size.
    const uint lanes = min(tgsize, 256u);

    // Step 1: strided accumulation. Lane t visits t, t + lanes, t + 2*lanes...
    // so every element is covered exactly once, including the tail that an
    // N / tgsize chunking would drop, and the case N < tgsize.
    if (tid < lanes) {
        float localSum = 0.0f;
        for (uint i = tid; i < N; i += lanes) {
            // Inputs are clamped before squaring (existing behaviour).
            const float xi = clamp(x[i], -20.0f, 20.0f);
            localSum += xi * xi;
        }
        partialSums[tid] = localSum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: tree reduction that is valid for any lane count, not only
    // powers of two. Each round folds the upper half onto the lower half;
    // with an odd count the middle slot is carried over untouched.
    // All threads execute every barrier (the loop bounds are uniform).
    for (uint active = lanes; active > 1; ) {
        const uint upper = (active + 1) / 2;
        if (tid < active / 2) {
            partialSums[tid] += partialSums[tid + upper];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = upper;
    }

    // Step 3: derive the normalisation scale from the total.
    float meanSq = partialSums[0] / float(N);
    meanSq = max(meanSq, eps);
    meanSq = min(meanSq, 10000.0f);  // Prevent extreme RMS values

    // Note: despite the name used historically ("rms"), this is 1 / RMS.
    float invRms = rsqrt(meanSq + eps);
    invRms = clamp(invRms, 0.01f, 100.0f);

    // Threads past the end of the vector were only needed for the barriers
    // above; it is safe for them to leave now that no barrier follows.
    if (gid >= N) return;

    // Step 4: normalise this thread's element and apply the optional weight.
    float yi = x[gid] * invRms;
    if (w) {
        yi *= w[gid];
    }

    // Output squash (existing behaviour, kept bit-for-bit in its mapping):
    //   |yi| > 50        -> +/-50
    //   20 < |yi| <= 50  -> +/-(20 + (|yi| - 20) * 0.2)
    // NOTE(review): this mapping is not monotonic at |yi| = 50 (50 -> 26,
    // 50.1 -> 50). Left unchanged here because it alters model outputs;
    // confirm the intended curve with the model owner.
    float yiFinal = yi;
    if (yiFinal > 50.0f) {
        yiFinal = 50.0f;
    } else if (yiFinal < -50.0f) {
        yiFinal = -50.0f;
    } else if (yiFinal > 20.0f) {
        yiFinal = 20.0f + (yiFinal - 20.0f) * 0.2f;
    } else if (yiFinal < -20.0f) {
        yiFinal = -20.0f + (yiFinal + 20.0f) * 0.2f;
    }
    y[gid] = yiFinal;
}
|
|
|
|
// ═══════════════════════════════════════════════
|
|
// Numerically Stable Softmax Kernel
|
|
// ═══════════════════════════════════════════════
|
|
|
|
// Stable softmax with numerical overflow protection
|
|
// Logit shaping shared by every pass of softmax_stable, so the maximum, the
// denominator and the numerators are all computed from the same values:
//   |v| > 30        -> +/-30
//   10 < |v| <= 30  -> +/-(10 + (|v| - 10) * 0.3)
//   otherwise       -> v
// NOTE(review): the mapping is not monotonic at |v| = 30 (30 -> 16,
// 30.1 -> 30). It is reproduced unchanged because it affects sampling
// results; confirm the intended curve with the model owner.
static inline float softmax_stable_shape_logit(float v) {
    if (v > 30.0f) return 30.0f;
    if (v < -30.0f) return -30.0f;
    if (v > 10.0f) return 10.0f + (v - 10.0f) * 0.3f;
    if (v < -10.0f) return -10.0f + (v + 10.0f) * 0.3f;
    return v;
}

// Softmax over one vector of N logits.
//
//   s[i]     = shape(logits[i])                      (see helper above)
//   m        = max_i s[i]
//   e[i]     = exp(clamp(s[i] - m, -10, 10))
//   probs[i] = e[i] / max(sum_j e[j], 1e-6)
//
// Because numerator and denominator use the same e[i], the outputs sum to 1
// (up to rounding) whenever the grid covers all N elements.
//
// Launch contract:
//   * Every threadgroup reduces the whole vector by itself, so results are
//     correct for any threadgroup size and any number of threadgroups.
//   * Only elements with gid < N are written.
//   * At most 256 lanes per threadgroup take part in the reductions (size of
//     the threadgroup scratch arrays); extra threads only wait at barriers.
//
// Parameters:
//   logits [[buffer(0)]]  input, N floats
//   probs  [[buffer(1)]]  output, N floats
//   N      [[buffer(2)]]  element count
kernel void softmax_stable(
    device const float *logits [[buffer(0)]],  // [N]
    device float *probs [[buffer(1)]],         // [N]
    constant uint &N [[buffer(2)]],
    uint tid [[thread_position_in_threadgroup]],
    uint gid [[thread_position_in_grid]],
    uint tgsize [[threads_per_threadgroup]]
) {
    threadgroup float sharedMax[256];
    threadgroup float sharedSumExp[256];

    // Uniform exit (N is read from a constant buffer): no thread is left
    // waiting at a barrier.
    if (N == 0) return;

    // Number of threads that actually accumulate; capped by the scratch size.
    const uint lanes = min(tgsize, 256u);

    // Step 1: per-lane maximum over a strided slice, so every element is
    // visited exactly once regardless of how N relates to the lane count.
    if (tid < lanes) {
        float localMax = -INFINITY;
        for (uint i = tid; i < N; i += lanes) {
            localMax = max(localMax, softmax_stable_shape_logit(logits[i]));
        }
        sharedMax[tid] = localMax;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction valid for any lane count: fold the upper half onto the
    // lower half; with an odd count the middle slot is carried over.
    for (uint active = lanes; active > 1; ) {
        const uint upper = (active + 1) / 2;
        if (tid < active / 2) {
            sharedMax[tid] = max(sharedMax[tid], sharedMax[tid + upper]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = upper;
    }
    const float globalMax = sharedMax[0];

    // Step 2: per-lane sum of exponentials, using exactly the expression the
    // output step uses. No padding lanes exist here, so nothing but real
    // elements contributes to the denominator.
    if (tid < lanes) {
        float localSumExp = 0.0f;
        for (uint i = tid; i < N; i += lanes) {
            const float d = clamp(softmax_stable_shape_logit(logits[i]) - globalMax,
                                  -10.0f, 10.0f);
            localSumExp += exp(d);
        }
        sharedSumExp[tid] = localSumExp;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint active = lanes; active > 1; ) {
        const uint upper = (active + 1) / 2;
        if (tid < active / 2) {
            sharedSumExp[tid] += sharedSumExp[tid + upper];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = upper;
    }

    // Guard against division by zero (kept from the previous implementation).
    const float totalSumExp = max(sharedSumExp[0], 1e-6f);

    // Threads past the end of the vector were only needed for the barriers
    // above; no barrier follows, so they may leave now.
    if (gid >= N) return;

    // Step 3: this thread's probability.
    const float diff = clamp(softmax_stable_shape_logit(logits[gid]) - globalMax,
                             -10.0f, 10.0f);
    probs[gid] = exp(diff) / totalSumExp;
}
|
|
|
|
// Alternative: Block-wise RMSNorm for very large N
|
|
// Block-wise RMSNorm: the vector is split into consecutive blocks of
// `blockSize` elements (the last block may be shorter) and each element is
// normalised by the statistics of its own block only.
//
//   meanSq = max(mean(clamp(x[block], -100, 100)^2), eps)
//   scale  = clamp(rsqrt(meanSq + eps), 0.01, 100)
//   y[gid] = clamp(x[gid] * scale * w[gid], -1000, 1000)   (w skipped if null)
//
// Each thread recomputes its block's sum on its own, so no threadgroup
// memory or barriers are involved and any launch configuration is valid.
// Only elements with gid < N are written.
//
// Parameters:
//   x         [[buffer(0)]]  input, N floats
//   w         [[buffer(1)]]  optional per-element weight, N floats (may be null)
//   y         [[buffer(2)]]  output, N floats
//   N         [[buffer(3)]]  element count
//   eps       [[buffer(4)]]  epsilon used both as floor for and addend to meanSq
//   blockSize [[buffer(5)]]  elements per block; 0 is treated as "one block
//                            spanning the whole vector" instead of dividing
//                            by zero
kernel void rms_norm_blockwise(
    device const float *x [[buffer(0)]],
    device const float *w [[buffer(1)]],
    device float *y [[buffer(2)]],
    constant uint &N [[buffer(3)]],
    constant float &eps [[buffer(4)]],
    constant uint &blockSize [[buffer(5)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= N) return;

    // Integer division by zero is undefined; fall back to a single block.
    const uint span = (blockSize == 0) ? N : blockSize;

    // Locate this thread's block. blockStart <= gid < N, so the subtraction
    // below cannot underflow and blockLen is at least 1. Computing the length
    // first avoids the uint wrap-around that blockStart + blockSize could hit
    // for very large block sizes.
    const uint blockStart = (gid / span) * span;
    const uint blockLen = min(span, N - blockStart);
    const uint blockEnd = blockStart + blockLen;

    // Sum of squares for this block only (inputs clamped before squaring).
    float blockSum = 0.0f;
    for (uint i = blockStart; i < blockEnd; i++) {
        const float v = clamp(x[i], -100.0f, 100.0f);
        blockSum += v * v;
    }

    // Mean over the actual block length, floored at eps.
    float meanSq = blockSum / float(blockLen);
    meanSq = max(meanSq, eps);

    // Reciprocal RMS, limited to a fixed range.
    float invRms = rsqrt(meanSq + eps);
    invRms = clamp(invRms, 0.01f, 100.0f);

    // Normalise and apply the optional weight.
    float yi = x[gid] * invRms;
    if (w) yi *= w[gid];

    y[gid] = clamp(yi, -1000.0f, 1000.0f);
}