8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
201 lines
7.0 KiB
Metal
#include <metal_stdlib>
|
|
using namespace metal;
|
|
|
|
// ─────────────────────────────────────────────────────────────────────
|
|
// Kernel Fusion: Combine multiple operations into single kernels
|
|
// Goal: Reduce kernel dispatches for common patterns
|
|
// ─────────────────────────────────────────────────────────────────────
|
|
|
|
// ── Fused Embedding Dequantize + Scale ──
|
|
// Combines: dequantize_row + eltwise_scale
|
|
// Eliminates one kernel dispatch
|
|
// Dequantizes one row of a 4-bit packed weight matrix and multiplies every
// element by an extra scalar, in a single dispatch.
//
// Layout (per-row, same as the quantized matmul kernels in this file):
//   weight : nCols / 8 uint32 words per row, 8 nibbles per word, low nibble first
//   scales : nCols / groupSize floats per row, indexed [row * numGroups + group]
//   biases : same layout as scales
//   output : nCols floats (the dequantized, scaled row)
//
// Dispatch: 1-D, one thread per column; threads with id >= nCols exit early.
// Preconditions (not checked): nCols % 8 == 0, nCols % groupSize == 0, row >= 0.
kernel void fused_dequantize_scale(
    device const uint32_t* weight    [[buffer(0)]],
    device const float*    scales    [[buffer(1)]],
    device const float*    biases    [[buffer(2)]],
    device float*          output    [[buffer(3)]],
    constant uint&         nCols     [[buffer(4)]],
    constant int&          row       [[buffer(5)]],
    constant uint&         groupSize [[buffer(6)]],
    constant float&        scale     [[buffer(7)]],  // Extra scale to apply
    uint id [[thread_position_in_grid]]
) {
    if (id >= nCols) return;

    const uint rowIdx    = uint(row);
    const uint numGroups = nCols / groupSize;
    const uint groupIdx  = id / groupSize;

    // Locate the packed word holding this column and extract its 4-bit value.
    const uint packedIdx = rowIdx * (nCols / 8) + id / 8;
    const uint shift     = (id % 8) * 4;
    const uint32_t qval  = (weight[packedIdx] >> shift) & 0xF;

    // Scales/biases are stored per row (one entry per group per row), so they
    // must be offset by the row exactly like the packed weights are.
    const uint paramIdx = rowIdx * numGroups + groupIdx;
    const float val = float(qval) * scales[paramIdx] + biases[paramIdx];

    // Apply extra scale (embedding scale or per-layer scale).
    output[id] = val * scale;
}
|
|
|
|
// ── Fused RMS Norm + Residual Add ──
|
|
// Combines: rmsNorm + eltwiseAdd
|
|
// Eliminates one kernel dispatch
|
|
// Fused RMS norm + residual add:
//   output[i] = residual[i] + input[i] * rsqrt(sum(input^2) / N + eps) * weight[i]
//
// Dispatch: 1-D grid with at least N threads; threadgroups of at most 1024
// threads (the size of the reduction scratch buffer below). Every threadgroup
// reduces the whole input vector on its own, so the result is correct for any
// number of threadgroups (a single threadgroup avoids redundant reads).
//
// All threads, including those with tid >= N, must take part in the reduction
// because it contains threadgroup barriers; they only skip the final store.
kernel void fused_rms_norm_residual(
    device const float* input    [[buffer(0)]],
    device const float* residual [[buffer(1)]],
    device const float* weight   [[buffer(2)]],
    device float*       output   [[buffer(3)]],
    constant uint&      N        [[buffer(4)]],
    constant float&     eps      [[buffer(5)]],
    uint tid             [[thread_position_in_grid]],
    uint threadgroupSize [[threads_per_threadgroup]],
    uint laneId          [[thread_position_in_threadgroup]]
) {
    constexpr uint kMaxThreadgroupSize = 1024;
    threadgroup float partialSums[kMaxThreadgroupSize];

    // 1) Each lane accumulates a strided slice covering the full vector.
    float sumSq = 0.0f;
    for (uint i = laneId; i < N; i += threadgroupSize) {
        const float v = input[i];
        sumSq += v * v;
    }
    partialSums[laneId] = sumSq;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // 2) Tree reduction into partialSums[0]. The loop bound is uniform so every
    //    thread reaches each barrier; the `laneId + stride < threadgroupSize`
    //    guard skips never-written slots, which keeps it correct for
    //    non-power-of-two threadgroup sizes.
    for (uint stride = kMaxThreadgroupSize / 2; stride > 0; stride >>= 1) {
        if (laneId < stride && laneId + stride < threadgroupSize) {
            partialSums[laneId] += partialSums[laneId + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // 3) Normalize this thread's element, apply the weight, add the residual.
    const float invRms = rsqrt(partialSums[0] / float(N) + eps);
    if (tid < N) {
        output[tid] = residual[tid] + input[tid] * invRms * weight[tid];
    }
}
|
|
|
|
// ── Fused Matmul + GELU + Residual ──
|
|
// Combines: quantized_matmul + gelu + eltwiseAdd
|
|
// Fused 4-bit quantized matrix-vector product + GELU + residual add:
//   output[r] = residual[r] + gelu(sum_c W[r][c] * input[c])
// where W[r][c] = q * scales[r * numGroups + g] + biases[r * numGroups + g],
// q is the c-th 4-bit value of row r (8 per uint32, low nibble first) and
// g = c / groupSize.
//
// Dispatch: 1-D, one thread per output row; threads with id >= outDim exit.
// Preconditions (not checked): inDim % groupSize == 0, groupSize % 8 == 0.
kernel void fused_matmul_gelu_residual(
    device const float*    input     [[buffer(0)]],
    device const uint32_t* weight    [[buffer(1)]],
    device const float*    scales    [[buffer(2)]],
    device const float*    biases    [[buffer(3)]],
    device const float*    residual  [[buffer(4)]],
    device float*          output    [[buffer(5)]],
    constant uint&         inDim     [[buffer(6)]],
    constant uint&         outDim    [[buffer(7)]],
    constant uint&         groupSize [[buffer(8)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    const uint numGroups     = inDim / groupSize;
    const uint wordsPerRow   = inDim / 8;
    const uint wordsPerGroup = groupSize / 8;
    float sum = 0.0f;

    for (uint g = 0; g < numGroups; g++) {
        const float scale = scales[id * numGroups + g];
        const float bias  = biases[id * numGroups + g];

        for (uint j = 0; j < wordsPerGroup; j++) {
            const uint32_t packed = weight[id * wordsPerRow + g * wordsPerGroup + j];

            for (uint k = 0; k < 8; k++) {
                const uint inputIdx = g * groupSize + j * 8 + k;
                const uint32_t qval = (packed >> (k * 4)) & 0xF;
                sum += input[inputIdx] * (float(qval) * scale + bias);
            }
        }
    }

    // tanh-approximation GELU: 0.5 x (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3))).
    // precise::tanh is used deliberately: the default fast-math tanh has been
    // observed to return NaN for large arguments.
    const float inner = 0.7978845608f * sum * (1.0f + 0.044715f * sum * sum);
    const float gelu  = 0.5f * sum * (1.0f + precise::tanh(inner));

    // Residual add.
    output[id] = residual[id] + gelu;
}
|
|
|
|
// ── Batch RMS Norm for Multiple Layers ──
// Processes the RMS norms of all numLayers layers in a single dispatch
|
|
// Applies RMS norm to numLayers independent vectors in one dispatch.
//   outputs[l][c] = inputs[l][c] / sqrt(sum_k inputs[l][k]^2 / hiddenSize + eps) * weights[l][c]
// inputs, weights and outputs are flattened [numLayers * hiddenSize] arrays,
// layer-major. Dispatch: 2-D grid, x = element within a layer, y = layer;
// threads outside [hiddenSize, numLayers] exit early.
// Note: every thread recomputes its layer's full sum of squares (no
// threadgroup reduction), so the work per layer is O(hiddenSize^2).
kernel void batch_rms_norm_layers(
    device const float* inputs     [[buffer(0)]],  // [numLayers * hiddenSize] flattened
    device const float* weights    [[buffer(1)]],  // [numLayers * hiddenSize] flattened
    device float*       outputs    [[buffer(2)]],  // [numLayers * hiddenSize] flattened
    constant uint&      hiddenSize [[buffer(3)]],
    constant uint&      numLayers  [[buffer(4)]],
    constant float&     eps        [[buffer(5)]],
    uint2 id [[thread_position_in_grid]]
) {
    const uint col   = id.x;
    const uint layer = id.y;

    const bool inRange = (layer < numLayers) && (col < hiddenSize);
    if (!inRange) return;

    const uint base = layer * hiddenSize;
    device const float* x = inputs + base;

    float acc = 0.0f;
    for (uint k = 0; k < hiddenSize; ++k) {
        acc += x[k] * x[k];
    }
    const float rms = sqrt(acc / hiddenSize + eps);

    outputs[base + col] = x[col] / rms * weights[base + col];
}
|
|
|
|
// ── Fused Quantized Matmul + Bias Add ──
|
|
// 4-bit quantized matrix-vector product with an optional unquantized bias:
//   output[r] = sum_c W[r][c] * input[c]   (+ bias_unquant[r] when hasBias)
// where W[r][c] = q * scales[r * numGroups + g] + biases_quant[r * numGroups + g],
// q is the c-th 4-bit value of row r (8 per uint32, low nibble first) and
// g = c / groupSize.
//
// Dispatch: 1-D, one thread per output row; threads with id >= outDim exit.
// bias_unquant is only read when hasBias is true.
// Preconditions (not checked): inDim % groupSize == 0, groupSize % 8 == 0.
kernel void fused_quantized_matmul_bias(
    device const float*    input        [[buffer(0)]],
    device const uint32_t* weight       [[buffer(1)]],
    device const float*    scales       [[buffer(2)]],
    device const float*    biases_quant [[buffer(3)]],
    device const float*    bias_unquant [[buffer(4)]],  // Optional unquantized bias
    device float*          output       [[buffer(5)]],
    constant uint&         inDim        [[buffer(6)]],
    constant uint&         outDim       [[buffer(7)]],
    constant uint&         groupSize    [[buffer(8)]],
    constant bool&         hasBias      [[buffer(9)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    const uint groups        = inDim / groupSize;
    const uint wordsPerGroup = groupSize / 8;

    // Per-row views into the packed weights and the per-group parameters.
    device const uint32_t* wRow = weight + id * (inDim / 8);
    device const float*    sRow = scales + id * groups;
    device const float*    bRow = biases_quant + id * groups;

    float acc = 0.0f;
    for (uint g = 0; g < groups; ++g) {
        const float s = sRow[g];
        const float b = bRow[g];
        device const uint32_t* wGroup = wRow + g * wordsPerGroup;
        device const float*    xGroup = input + g * groupSize;

        for (uint w = 0; w < wordsPerGroup; ++w) {
            uint32_t bits = wGroup[w];
            device const float* x = xGroup + w * 8;

            // Walk the 8 nibbles low-to-high by shifting the word each step.
            for (uint n = 0; n < 8; ++n, bits >>= 4) {
                const float dq = float(bits & 0xF) * s + b;
                acc += x[n] * dq;
            }
        }
    }

    output[id] = hasBias ? acc + bias_unquant[id] : acc;
}