Files
markbaseengine/Sources/MarkBase/Metal/Float16Kernels.metal
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

170 lines
6.6 KiB
Metal

#include <metal_stdlib>
using namespace metal;
// ════════════════════════════════════════════════════════
// Float16 Metal Kernels
// ════════════════════════════════════════════════════════
// ── Float16 Quantized Matmul ──────────────────────────
// Half-precision input, scales and biases; 4-bit packed weights; float32 output.
// Computes out[row] = sum_k dequant(w[row, k]) * x[k] with one thread per output row.
//
// Dispatch: 1-D grid; gid is the output row. Every threadgroup first copies the
// whole input vector x into threadgroup memory, so:
//   - threadgroup(0) must provide at least inDim * sizeof(half) bytes;
//   - w is row-major [outDim, inDim / 8] (eight 4-bit values per uint32);
//   - s and b are row-major [outDim, inDim / groupSize];
//   - groupSize must be a multiple of 8 and inDim a multiple of groupSize
//     (a trailing partial group is ignored, as inDim / groupSize truncates).
// Dequantization is affine: value = q * scale + bias, with q in [0, 15] taken
// from nibble j (bits 4j..4j+3) of each packed word.
kernel void quantized_matmul_f16(
    device const half *x [[buffer(0)]],       // Input [inDim]
    device const uint *w [[buffer(1)]],       // Packed weights [outDim, inDim/8]
    device const half *s [[buffer(2)]],       // Scales [outDim, inDim/groupSize]
    device const half *b [[buffer(3)]],       // Biases [outDim, inDim/groupSize]
    device float *out [[buffer(4)]],          // Output [outDim] - Float32 for accuracy
    constant uint &inDim [[buffer(5)]],
    constant uint &outDim [[buffer(6)]],
    constant uint &groupSize [[buffer(7)]],
    threadgroup half *shared_x [[threadgroup(0)]], // Input cache in half
    uint gid [[thread_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    // Every thread, including ones past outDim in the last threadgroup, must
    // help fill shared_x and reach the barrier. Returning before this point
    // leaves holes in shared_x (the stride loop assumes tgSize participants)
    // and makes the barrier non-uniform, which is undefined behaviour.
    for (uint i = tid; i < inDim; i += tgSize) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint outRow = gid;
    if (outRow >= outDim) return;

    const uint numGroups = inDim / groupSize;
    const uint packedPerGroup = groupSize / 8;   // uint32 words per quantization group
    const uint packedPerRow = inDim / 8;         // uint32 words per weight row
    const uint4 lowShifts = uint4(0u, 4u, 8u, 12u);
    const uint4 highShifts = uint4(16u, 20u, 24u, 28u);

    // Dequantization and dot products run in float so the accumulation
    // really is float32 (the half dot products previously lost precision).
    float sum = 0.0f;
    for (uint g = 0; g < numGroups; g++) {
        const float scale = float(s[outRow * numGroups + g]);
        const float bias = float(b[outRow * numGroups + g]);
        const uint packedBase = outRow * packedPerRow + g * packedPerGroup;
        const uint xGroupBase = g * groupSize;

        // Loop over however many words the group has instead of assuming 8
        // (i.e. groupSize == 64).
        for (uint p = 0; p < packedPerGroup; p++) {
            const uint packed = w[packedBase + p];
            const uint xBase = xGroupBase + p * 8;

            const float4 q0 = float4((uint4(packed) >> lowShifts) & 0xFu) * scale + bias;
            const float4 q1 = float4((uint4(packed) >> highShifts) & 0xFu) * scale + bias;

            const float4 x0 = float4(half4(shared_x[xBase + 0], shared_x[xBase + 1],
                                           shared_x[xBase + 2], shared_x[xBase + 3]));
            const float4 x1 = float4(half4(shared_x[xBase + 4], shared_x[xBase + 5],
                                           shared_x[xBase + 6], shared_x[xBase + 7]));

            sum += dot(q0, x0) + dot(q1, x1);
        }
    }
    out[outRow] = sum;
}
// ── Float16 RMS Norm ──────────────────────────────────
// RMS-normalizes x: y[i] = x[i] * rsqrt(mean(x^2) + eps) * w[i].
//
// Dispatch: indexing uses only thread_position_in_threadgroup, so one 1-D
// threadgroup (at most 1024 threads, the Metal per-threadgroup limit) handles
// the whole vector. Any additional threadgroups would just repeat the same work.
// The sum of squares is accumulated and reduced in float32, because a half
// accumulator overflows (max 65504) for realistic N and activation sizes.
// The reduction uses the kernel-scope float array below. partial_sums is
// unused, but stays in the signature so host code that allocates
// threadgroup(0) keeps working.
kernel void rms_norm_f16(
    device const half *x [[buffer(0)]],  // Input [N]
    device const half *w [[buffer(1)]],  // Weight [N]
    device half *y [[buffer(2)]],        // Output [N]
    constant uint &N [[buffer(3)]],
    constant half &eps [[buffer(4)]],
    threadgroup half *partial_sums [[threadgroup(0)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    threadgroup float reduce_buf[1024];
    (void)partial_sums;

    // Phase 1: each thread accumulates its strided share of x^2 in float.
    float localSum = 0.0f;
    for (uint i = tid; i < N; i += tgSize) {
        const float v = float(x[i]);
        localSum += v * v;
    }
    reduce_buf[tid] = localSum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: tree reduction that also works when tgSize is not a power of
    // two. Each round folds the upper part [upper, active) onto [0, active - upper).
    // `active` depends only on tgSize, so every thread reaches every barrier.
    for (uint active = tgSize; active > 1; ) {
        const uint upper = (active + 1) / 2;
        if (tid + upper < active) {
            reduce_buf[tid] += reduce_buf[tid + upper];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = upper;
    }

    // Phase 3: compute the inverse RMS and write this thread's outputs.
    const float meanSq = reduce_buf[0] / float(N);
    const float rms = rsqrt(meanSq + float(eps));
    for (uint i = tid; i < N; i += tgSize) {
        // A null weight pointer means no scaling. NOTE(review): confirm the
        // host actually binds a null buffer here rather than leaving it unset.
        const float weight = w ? float(w[i]) : 1.0f;
        y[i] = half(float(x[i]) * rms * weight);
    }
}
// ── Float16 Elementwise Operations ────────────────────
// Elementwise product: out[i] = a[i] * b[i] for i in [0, count).
//
// Dispatch: 1-D grid with at least ceil(count / 4) threads. Thread `id` owns
// elements [4 * id, 4 * id + 4). The last owner may hold fewer than four
// elements and must neither load nor store past `count`.
// Buffer indices 0..3 match the sequential order the compiler assigned when
// they were implicit.
kernel void eltwise_mul_f16(
    device const half *a [[buffer(0)]],
    device const half *b [[buffer(1)]],
    device half *out [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    const uint idx = id * 4;
    if (idx >= count) return;

    const uint remaining = count - idx;
    if (remaining >= 4) {
        // Full quad: vectorized path.
        const half4 aVec = half4(a[idx], a[idx + 1], a[idx + 2], a[idx + 3]);
        const half4 bVec = half4(b[idx], b[idx + 1], b[idx + 2], b[idx + 3]);
        const half4 outVec = aVec * bVec;
        out[idx]     = outVec.x;
        out[idx + 1] = outVec.y;
        out[idx + 2] = outVec.z;
        out[idx + 3] = outVec.w;
    } else {
        // Tail: touch only the elements that exist. The old code loaded all
        // four lanes unconditionally and read past the end of a and b.
        for (uint k = 0; k < remaining; k++) {
            out[idx + k] = a[idx + k] * b[idx + k];
        }
    }
}
// Elementwise sum: out[i] = a[i] + b[i] for i in [0, count).
//
// Dispatch: 1-D grid with at least ceil(count / 4) threads. Thread `id` owns
// elements [4 * id, 4 * id + 4). The last owner may hold fewer than four
// elements and must neither load nor store past `count`.
// Buffer indices 0..3 match the sequential order the compiler assigned when
// they were implicit.
kernel void eltwise_add_f16(
    device const half *a [[buffer(0)]],
    device const half *b [[buffer(1)]],
    device half *out [[buffer(2)]],
    constant uint &count [[buffer(3)]],
    uint id [[thread_position_in_grid]]
) {
    const uint idx = id * 4;
    if (idx >= count) return;

    const uint remaining = count - idx;
    if (remaining >= 4) {
        // Full quad: vectorized path.
        const half4 aVec = half4(a[idx], a[idx + 1], a[idx + 2], a[idx + 3]);
        const half4 bVec = half4(b[idx], b[idx + 1], b[idx + 2], b[idx + 3]);
        const half4 outVec = aVec + bVec;
        out[idx]     = outVec.x;
        out[idx + 1] = outVec.y;
        out[idx + 2] = outVec.z;
        out[idx + 3] = outVec.w;
    } else {
        // Tail: touch only the elements that exist. The old code loaded all
        // four lanes unconditionally and read past the end of a and b.
        for (uint k = 0; k < remaining; k++) {
            out[idx + k] = a[idx + k] + b[idx + k];
        }
    }
}