8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
170 lines
6.6 KiB
Metal
#include <metal_stdlib>
|
|
using namespace metal;
|
|
|
|
// ════════════════════════════════════════════════════════
|
|
// Float16 Metal Kernels
|
|
// ════════════════════════════════════════════════════════
|
|
|
|
// ── Float16 Quantized Matmul ──────────────────────────
// Affine 4-bit dequantize + matrix-vector product:
//   out[r] = sum_k (q[r,k] * s[r, k/groupSize] + b[r, k/groupSize]) * x[k]
// where q[r,k] is the 4-bit value held in nibble (k % 8) of word
// w[r * (inDim/8) + k/8], lowest nibble first.
//
// Dispatch: one thread per output row (gid); rows >= outDim are ignored.
// Every threadgroup copies the whole input vector into threadgroup memory,
// so threadgroup memory at index 0 must hold at least inDim halves.
//
// Preconditions (not checked): groupSize > 0, groupSize is a multiple of 8,
// inDim is a multiple of groupSize (any partial trailing group is ignored).
//
// Input, scales and biases are half precision; products are accumulated in
// float and the result is written as float.
kernel void quantized_matmul_f16(
    device const half *x        [[buffer(0)]],      // Input [inDim]
    device const uint *w        [[buffer(1)]],      // Packed 4-bit weights [outDim, inDim/8]
    device const half *s        [[buffer(2)]],      // Scales [outDim, inDim/groupSize]
    device const half *b        [[buffer(3)]],      // Biases [outDim, inDim/groupSize]
    device float *out           [[buffer(4)]],      // Output [outDim] - Float32 for accuracy
    constant uint &inDim        [[buffer(5)]],
    constant uint &outDim       [[buffer(6)]],
    constant uint &groupSize    [[buffer(7)]],
    threadgroup half *shared_x  [[threadgroup(0)]], // Input cache, >= inDim halves
    uint gid    [[thread_position_in_grid]],
    uint tid    [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    // Cooperative load of the input vector. Every thread of the threadgroup
    // must take part, including threads whose row is past outDim. An early
    // return here would leave their strided slots of shared_x unwritten and
    // would skip the barrier below, which all threads must reach.
    for (uint i = tid; i < inDim; i += tgSize) {
        shared_x[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint outRow = gid;
    if (outRow >= outDim) return;   // Safe: no barrier follows.

    const uint numGroups     = inDim / groupSize;
    const uint wordsPerGroup = groupSize / 8;   // 8 nibbles per uint32
    const uint wordsPerRow   = inDim / 8;

    // Nibble offsets of the low and high four values inside one packed word.
    const uint4 loShift = uint4(0, 4, 8, 12);
    const uint4 hiShift = uint4(16, 20, 24, 28);

    float sum = 0.0f;
    for (uint g = 0; g < numGroups; g++) {
        const half scale = s[outRow * numGroups + g];
        const half bias  = b[outRow * numGroups + g];
        const uint packedBase = outRow * wordsPerRow + g * wordsPerGroup;

        // One packed word covers 8 consecutive inputs. The word count follows
        // groupSize instead of assuming 64-element groups.
        for (uint p = 0; p < wordsPerGroup; p++) {
            const uint packed = w[packedBase + p];
            const uint xBase  = g * groupSize + p * 8;

            const half4 xLo = half4(shared_x[xBase + 0], shared_x[xBase + 1],
                                    shared_x[xBase + 2], shared_x[xBase + 3]);
            const half4 xHi = half4(shared_x[xBase + 4], shared_x[xBase + 5],
                                    shared_x[xBase + 6], shared_x[xBase + 7]);

            // Dequantize: q * scale + bias.
            const half4 qLo = half4((uint4(packed) >> loShift) & 0xFu) * scale + bias;
            const half4 qHi = half4((uint4(packed) >> hiShift) & 0xFu) * scale + bias;

            // Widen before the dot product so the multiply-adds run in float.
            // A half dot product can overflow (> 65504) or lose precision.
            sum += dot(float4(qLo), float4(xLo)) + dot(float4(qHi), float4(xHi));
        }
    }

    out[outRow] = sum;
}
|
|
|
|
// ── Float16 RMS Norm ──────────────────────────────────
// y[i] = x[i] * rsqrt(sum(x^2) / N + eps) * w[i], with w[i] taken as 1 when
// w is null.
//
// Only the threadgroup-local index is used, so each threadgroup normalises
// the entire vector. This presumably expects a single-threadgroup dispatch;
// confirm with the host code.
// Preconditions (unchanged): tgSize is a power of two, and threadgroup
// memory at index 0 holds at least tgSize halves.
//
// Storage stays half, but the sum of squares, the reduction and the scaling
// are computed in float. A half accumulator overflows to inf once
// sum(x^2) > 65504 (e.g. |x| ~ 6 over 2048 elements), and rsqrt(inf) = 0
// would then zero the whole output.
kernel void rms_norm_f16(
    device const half *x [[buffer(0)]],  // Input [N]
    device const half *w [[buffer(1)]],  // Weight [N]
    device half *y [[buffer(2)]],        // Output [N]
    constant uint &N [[buffer(3)]],
    constant half &eps [[buffer(4)]],
    threadgroup half *partial_sums [[threadgroup(0)]],
    uint tid [[thread_position_in_threadgroup]],
    uint tgSize [[threads_per_threadgroup]]
) {
    // Phase 1: each thread accumulates its strided share of x^2 in float.
    float localSum = 0.0f;
    for (uint i = tid; i < N; i += tgSize) {
        const float v = float(x[i]);
        localSum += v * v;
    }

    // Phase 2: threadgroup-wide reduction in float. tgSize is the same for
    // every thread, so this branch (and the barriers inside it) is uniform.
    float total = localSum;
    if (tgSize > 1) {
        // The scratch buffer is sized for tgSize halves, i.e. tgSize/2 floats.
        // First fold the upper half of the threads into the lower half so
        // every float fits, then tree-reduce the remaining slots.
        threadgroup float *sums = (threadgroup float *)partial_sums;
        const uint slots = tgSize / 2;

        if (tid >= slots) {
            sums[tid - slots] = localSum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid < slots) {
            sums[tid] += localSum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint stride = slots / 2; stride > 0; stride >>= 1) {
            if (tid < stride) {
                sums[tid] += sums[tid + stride];
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
        total = sums[0];
    }

    // Phase 3: scale factor and normalisation, each thread writing its
    // strided share of y.
    const float rms = rsqrt(total / float(N) + float(eps));
    for (uint i = tid; i < N; i += tgSize) {
        const float weight = w ? float(w[i]) : 1.0f;
        y[i] = half(float(x[i]) * rms * weight);
    }
}
|
|
|
|
// ── Float16 Elementwise Operations ────────────────────
|
|
|
|
// Element-wise product: out[i] = a[i] * b[i] for i < count.
// Each thread handles up to four consecutive elements starting at 4 * id, so
// the grid needs at least ceil(count / 4) threads. Extra threads exit.
// NOTE(review): the arguments carry no explicit [[buffer(n)]] indices;
// confirm the host binds them in the order the compiler assigns.
kernel void eltwise_mul_f16(
    device const half *a,
    device const half *b,
    device half *out,
    constant uint &count,
    uint id [[thread_position_in_grid]]
) {
    const uint idx = id * 4;
    if (idx >= count) return;

    if (count - idx >= 4) {
        // Full quad: all four elements are in range.
        const half4 aVec = half4(a[idx], a[idx + 1], a[idx + 2], a[idx + 3]);
        const half4 bVec = half4(b[idx], b[idx + 1], b[idx + 2], b[idx + 3]);
        const half4 r = aVec * bVec;
        out[idx]     = r.x;
        out[idx + 1] = r.y;
        out[idx + 2] = r.z;
        out[idx + 3] = r.w;
    } else {
        // Tail when count is not a multiple of 4. Only in-range elements are
        // loaded, so a and b are never read past element count - 1.
        for (uint i = idx; i < count; i++) {
            out[i] = a[i] * b[i];
        }
    }
}
|
|
|
|
// Element-wise sum: out[i] = a[i] + b[i] for i < count.
// Each thread handles up to four consecutive elements starting at 4 * id, so
// the grid needs at least ceil(count / 4) threads. Extra threads exit.
// NOTE(review): the arguments carry no explicit [[buffer(n)]] indices;
// confirm the host binds them in the order the compiler assigns.
kernel void eltwise_add_f16(
    device const half *a,
    device const half *b,
    device half *out,
    constant uint &count,
    uint id [[thread_position_in_grid]]
) {
    const uint idx = id * 4;
    if (idx >= count) return;

    if (count - idx >= 4) {
        // Full quad: all four elements are in range.
        const half4 aVec = half4(a[idx], a[idx + 1], a[idx + 2], a[idx + 3]);
        const half4 bVec = half4(b[idx], b[idx + 1], b[idx + 2], b[idx + 3]);
        const half4 r = aVec + bVec;
        out[idx]     = r.x;
        out[idx + 1] = r.y;
        out[idx + 2] = r.z;
        out[idx + 3] = r.w;
    } else {
        // Tail when count is not a multiple of 4. Only in-range elements are
        // loaded, so a and b are never read past element count - 1.
        for (uint i = idx; i < count; i++) {
            out[i] = a[i] + b[i];
        }
    }
}
|