Files
markbaseengine/Sources/MarkBase/Metal/FusedKernels.metal
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

201 lines
7.0 KiB
Metal

#include <metal_stdlib>
using namespace metal;
// ─────────────────────────────────────────────────────────────────────
// Kernel Fusion: Combine multiple operations into single kernels
// Goal: Reduce kernel dispatches for common patterns
// ─────────────────────────────────────────────────────────────────────
// ── Fused Embedding Dequantize + Scale ──
// Combines: dequantize_row + eltwise_scale
// Eliminates one kernel dispatch
// Dequantizes one row of a 4-bit packed matrix and multiplies the result by an
// extra scalar (embedding scale or per-layer scale).
//
//   output[c] = (q[row, c] * scales[row, c / groupSize] + biases[row, c / groupSize]) * scale
//
// Dispatch: 1D grid with at least nCols threads; thread `id` writes output[id].
// Layout (same convention as the quantized matmul kernels in this file):
//   weight         : 8 nibbles per uint32, nCols / 8 words per row
//   scales, biases : nCols / groupSize entries per row
// Preconditions: nCols % 8 == 0, nCols % groupSize == 0, row >= 0.
kernel void fused_dequantize_scale(
    device const uint32_t* weight [[buffer(0)]],
    device const float* scales [[buffer(1)]],
    device const float* biases [[buffer(2)]],
    device float* output [[buffer(3)]],
    constant uint& nCols [[buffer(4)]],
    constant int& row [[buffer(5)]],
    constant uint& groupSize [[buffer(6)]],
    constant float& scale [[buffer(7)]], // Extra scale to apply
    uint id [[thread_position_in_grid]]
) {
    if (id >= nCols) return;

    const uint rowIdx = uint(row);
    const uint numGroups = nCols / groupSize;

    // The weight buffer is indexed from the start of the full matrix, so the
    // per-group quantization parameters must be offset by the same row;
    // otherwise every row would be dequantized with row 0's scales/biases.
    const uint packedIdx = rowIdx * (nCols / 8) + id / 8;
    const uint groupIdx = rowIdx * numGroups + id / groupSize;

    const uint shift = (id % 8) * 4;
    const uint32_t qval = (weight[packedIdx] >> shift) & 0xF;
    const float val = float(qval) * scales[groupIdx] + biases[groupIdx];

    output[id] = val * scale;
}
// ── Fused RMS Norm + Residual Add ──
// Combines: rmsNorm + eltwiseAdd
// Eliminates one kernel dispatch
// RMS-normalizes `input`, scales it by `weight`, and adds `residual`:
//
//   output[i] = residual[i] + input[i] / sqrt(sum_j(input[j]^2) / N + eps) * weight[i]
//
// Dispatch: 1D grid; threadgroup size must be <= 1024 (the sharedSum size).
// Every threadgroup reduces the sum of squares over all N elements by itself,
// because threadgroups cannot see each other's partial sums within one pass.
// Each thread then writes elements tid, tid + gridSize, ... below N, so any
// grid size produces a complete output (a grid of >= N threads writes one
// element per thread, as before).
kernel void fused_rms_norm_residual(
    device const float* input [[buffer(0)]],
    device const float* residual [[buffer(1)]],
    device const float* weight [[buffer(2)]],
    device float* output [[buffer(3)]],
    constant uint& N [[buffer(4)]],
    constant float& eps [[buffer(5)]],
    uint tid [[thread_position_in_grid]],
    uint threadgroupId [[threadgroup_position_in_grid]],
    uint threadgroupSize [[threads_per_threadgroup]],
    uint lid [[thread_position_in_threadgroup]],
    uint gridSize [[threads_per_grid]]
) {
    // One slot per lane; 1024 is the maximum threads per threadgroup in Metal.
    threadgroup float sharedSum[1024];

    // Step 1: each lane accumulates a strided slice of the whole vector.
    float sumSq = 0.0f;
    for (uint i = lid; i < N; i += threadgroupSize) {
        const float v = input[i];
        sumSq += v * v;
    }
    sharedSum[lid] = sumSq;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: tree-reduce the partials. Each round folds the upper ceil-half
    // onto the lower half, so non-power-of-two threadgroup sizes work. The
    // loop trip count is uniform across the threadgroup, so every lane reaches
    // every barrier. Writes (index < upper) never alias reads (index >= upper).
    for (uint active = threadgroupSize; active > 1; ) {
        const uint upper = (active + 1) / 2;
        if (lid + upper < active) {
            sharedSum[lid] += sharedSum[lid + upper];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = upper;
    }

    // Step 3: normalize, apply the weight, and add the residual.
    const float invRms = rsqrt(sharedSum[0] / float(N) + eps);
    for (uint i = tid; i < N; i += gridSize) {
        output[i] = residual[i] + input[i] * invRms * weight[i];
    }
}
// ── Fused Matmul + GELU + Residual ──
// Combines: quantized_matmul + gelu + eltwiseAdd
// Computes one output element of a 4-bit quantized matrix-vector product,
// applies tanh-approximated GELU, and adds the residual:
//
//   y         = sum_i input[i] * (q[id, i] * scales[id, g(i)] + biases[id, g(i)])
//   output[id] = residual[id] + gelu(y)
//
// Dispatch: 1D grid with at least outDim threads; one output row per thread.
// Layout: 8 nibbles per uint32, inDim / 8 words per row; inDim / groupSize
// scale/bias entries per row.
// Preconditions: groupSize % 8 == 0 and inDim % groupSize == 0 (any trailing
// columns beyond the last full group are ignored).
kernel void fused_matmul_gelu_residual(
    device const float* input [[buffer(0)]],
    device const uint32_t* weight [[buffer(1)]],
    device const float* scales [[buffer(2)]],
    device const float* biases [[buffer(3)]],
    device const float* residual [[buffer(4)]],
    device float* output [[buffer(5)]],
    constant uint& inDim [[buffer(6)]],
    constant uint& outDim [[buffer(7)]],
    constant uint& groupSize [[buffer(8)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) return;

    const uint numGroups = inDim / groupSize;
    const uint wordsPerGroup = groupSize / 8;
    device const uint32_t* rowWeights = weight + id * (inDim / 8);
    device const float* rowScales = scales + id * numGroups;
    device const float* rowBiases = biases + id * numGroups;

    float sum = 0.0f;
    for (uint g = 0; g < numGroups; g++) {
        // Within a group w = q * scale + bias, so
        //   sum_k x_k * w_k = scale * sum_k x_k * q_k + bias * sum_k x_k,
        // which saves one multiply-add per weight.
        device const float* x = input + g * groupSize;
        float qDot = 0.0f;
        float xSum = 0.0f;
        for (uint j = 0; j < wordsPerGroup; j++) {
            const uint32_t packed = rowWeights[g * wordsPerGroup + j];
            for (uint k = 0; k < 8; k++) {
                const float xv = x[j * 8 + k];
                qDot += xv * float((packed >> (k * 4)) & 0xF);
                xSum += xv;
            }
        }
        sum += rowScales[g] * qDot + rowBiases[g] * xSum;
    }

    // GELU (tanh form): 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    // precise::tanh is used deliberately: the default fast-math tanh on Metal
    // can return NaN for large-magnitude arguments.
    const float inner = 0.7978845608f * sum * (1.0f + 0.044715f * sum * sum);
    const float gelu = 0.5f * sum * (1.0f + precise::tanh(inner));

    // Residual add
    output[id] = residual[id] + gelu;
}
// ── Batch RMS Norm for Multiple Layers ──
// Runs RMS norm on numLayers independent rows (one per layer) in a single dispatch
// RMS-normalizes numLayers independent rows of length hiddenSize and scales
// each element by its own weight:
//
//   outputs[l*H + d] = inputs[l*H + d] / sqrt(sum_k(inputs[l*H + k]^2) / H + eps)
//                      * weights[l*H + d]          (H = hiddenSize)
//
// Dispatch: 2D grid; x spans hidden dims (>= hiddenSize), y spans layers
// (>= numLayers). Threads outside that range exit immediately.
// Each thread recomputes its row's full sum of squares (O(hiddenSize) reads per
// thread, with no threadgroup sharing).
kernel void batch_rms_norm_layers(
    device const float* inputs [[buffer(0)]], // [numLayers * hiddenSize] flattened
    device const float* weights [[buffer(1)]], // [numLayers * hiddenSize] flattened
    device float* outputs [[buffer(2)]], // [numLayers * hiddenSize] flattened
    constant uint& hiddenSize [[buffer(3)]],
    constant uint& numLayers [[buffer(4)]],
    constant float& eps [[buffer(5)]],
    uint2 id [[thread_position_in_grid]]
) {
    const bool outsideGrid = (id.y >= numLayers) || (id.x >= hiddenSize);
    if (outsideGrid) {
        return;
    }

    const uint rowBase = id.y * hiddenSize;
    device const float* rowIn = inputs + rowBase;

    // Sum of squares over this thread's whole row, accumulated front to back.
    float squares = 0.0f;
    uint n = 0;
    while (n < hiddenSize) {
        const float x = rowIn[n];
        squares += x * x;
        ++n;
    }

    const float rmsValue = sqrt(squares / hiddenSize + eps);
    const uint flat = rowBase + id.x;
    outputs[flat] = inputs[flat] / rmsValue * weights[flat];
}
// ── Fused Quantized Matmul + Bias Add ──
// One output element of a 4-bit quantized matrix-vector product, plus an
// optional unquantized bias:
//
//   output[id] = sum_i input[i] * (q[id, i] * scales[id, g(i)] + biases_quant[id, g(i)])
//                (+ bias_unquant[id] when hasBias)
//
// Dispatch: 1D grid with at least outDim threads; one output row per thread.
// Layout: 8 nibbles per uint32 (low nibble first), inDim / 8 words per row;
// inDim / groupSize scale/bias entries per row.
// Preconditions: groupSize % 8 == 0 and inDim % groupSize == 0 (any trailing
// columns beyond the last full group are ignored). bias_unquant is only read
// when hasBias is true.
kernel void fused_quantized_matmul_bias(
    device const float* input [[buffer(0)]],
    device const uint32_t* weight [[buffer(1)]],
    device const float* scales [[buffer(2)]],
    device const float* biases_quant [[buffer(3)]],
    device const float* bias_unquant [[buffer(4)]], // Optional unquantized bias
    device float* output [[buffer(5)]],
    constant uint& inDim [[buffer(6)]],
    constant uint& outDim [[buffer(7)]],
    constant uint& groupSize [[buffer(8)]],
    constant bool& hasBias [[buffer(9)]],
    uint id [[thread_position_in_grid]]
) {
    if (id >= outDim) {
        return;
    }

    const uint groupsPerRow = inDim / groupSize;
    const uint wordsPerGroup = groupSize / 8;
    device const uint32_t* rowWords = weight + id * (inDim / 8);
    device const float* rowScales = scales + id * groupsPerRow;
    device const float* rowOffsets = biases_quant + id * groupsPerRow;

    float acc = 0.0f;
    for (uint grp = 0; grp < groupsPerRow; ++grp) {
        const float s = rowScales[grp];
        const float b = rowOffsets[grp];
        device const uint32_t* groupWords = rowWords + grp * wordsPerGroup;
        device const float* groupInput = input + grp * groupSize;

        for (uint w = 0; w < wordsPerGroup; ++w) {
            // Consume the word one nibble at a time, lowest nibble first.
            uint32_t bits = groupWords[w];
            for (uint lane = 0; lane < 8; ++lane, bits >>= 4) {
                const float dequant = float(bits & 0xF) * s + b;
                acc += groupInput[w * 8 + lane] * dequant;
            }
        }
    }

    if (hasBias) {
        acc += bias_unquant[id];
    }
    output[id] = acc;
}