8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
181 lines
6.9 KiB
Metal
181 lines
6.9 KiB
Metal
#include <metal_stdlib>
|
|
using namespace metal;
|
|
|
|
// ═══════════════════════════════════════════════════════════════
|
|
// Batch Layer Processing Kernels
|
|
// Process entire layer for multiple tokens simultaneously
|
|
// ═══════════════════════════════════════════════════════════════
|
|
|
|
// Batch RMS Norm for layer input
|
|
// Process [batchSize, hiddenSize] with shared weights
|
|
// RMS normalization applied independently to every row of a batch.
//
// Thread mapping: gid.x selects the row (batch element) and gid.y selects the
// element inside that row. Threads that fall outside batchSize x hiddenSize
// return without touching memory.
//
// Each thread walks its entire input row to accumulate the sum of squares and
// then writes exactly one output element:
//     out[row][col] = in[row][col] / sqrt(mean(in[row]^2) + eps) * weights[col]
//
// NOTE(review): every thread of a row repeats the same full-row reduction.
// NOTE(review): threads read the whole input row while sibling threads write
// the output row, so passing the same buffer for batchInput and batchOutput
// looks unsafe - confirm against the host-side encoder.
kernel void batch_layer_rms_norm(
    device float* batchInput [[buffer(0)]],        // [batchSize, hiddenSize]
    device float* weights [[buffer(1)]],           // [hiddenSize]
    device float* batchOutput [[buffer(2)]],       // [batchSize, hiddenSize]
    constant uint32_t& hiddenSize [[buffer(3)]],
    constant float& eps [[buffer(4)]],
    constant uint32_t& batchSize [[buffer(5)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint row = gid.x;
    const uint col = gid.y;

    // Discard threads from a grid that was rounded up past the data extent.
    if (row >= batchSize) return;
    if (col >= hiddenSize) return;

    // Both views start at the first element of this thread's row.
    const uint rowStart = row * hiddenSize;
    device float* src = batchInput + rowStart;
    device float* dst = batchOutput + rowStart;

    // Sum of squares across the full row.
    float sumSquares = 0.0f;
    uint k = 0;
    while (k < hiddenSize) {
        sumSquares += src[k] * src[k];
        ++k;
    }

    // eps is added to the mean square before the square root.
    const float meanSquare = sumSquares / float(hiddenSize);
    const float rms = sqrt(meanSquare + eps);

    dst[col] = src[col] / rms * weights[col];
}
|
|
|
|
// Batch Quantized Matmul for layer projections
|
|
// Process [batchSize, outDim] with shared quantized weights
|
|
// Batched matrix-vector product against a shared, group-quantized weight
// matrix.
//
// Thread mapping: gid.x selects the batch row, gid.y selects the output
// feature. Each thread produces one output value:
//     out[b][o] = biases[o] + sum_i in[b][i] * (weights[o][i] - 128) * scale(o, i)
// where scale(o, i) = scales[o * (inDim / groupSize) + i / groupSize].
//
// Weights are read one byte per element and re-centred by subtracting 128.
//
// NOTE(review): the scale indexing uses integer division, so it presumably
// expects groupSize > 0 and inDim to be a multiple of groupSize - confirm
// with the code that packs the scales.
kernel void batch_layer_quantized_matmul(
    device float* batchInput [[buffer(0)]],        // [batchSize, inDim]
    device uint8_t* weights [[buffer(1)]],         // [outDim, inDim] packed
    device float* scales [[buffer(2)]],            // [outDim, groups]
    device float* biases [[buffer(3)]],            // [outDim]
    device float* batchOutput [[buffer(4)]],       // [batchSize, outDim]
    constant uint32_t& inDim [[buffer(5)]],
    constant uint32_t& outDim [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint row = gid.x;
    const uint col = gid.y;

    // Only threads that map onto a real (row, output feature) pair do work.
    if (!(row < batchSize && col < outDim)) return;

    device float* activations = batchInput + row * inDim;
    device float* result = batchOutput + row * outDim;

    // First scale entry belonging to this output feature.
    const uint groupsPerRow = inDim / groupSize;
    const uint scaleBase = col * groupsPerRow;

    // The accumulator starts from the bias of this output feature.
    float acc = biases[col];

    for (uint k = 0; k < inDim; ++k) {
        // Stored byte, shifted from [0, 255] to [-128, 127].
        const int centered = int(weights[col * inDim + k]) - 128;

        // One scale is shared by groupSize consecutive input positions.
        const float groupScale = scales[scaleBase + k / groupSize];

        acc += activations[k] * centered * groupScale;
    }

    result[col] = acc;
}
|
|
|
|
// Batch Elementwise Add for residual connections
|
|
// Process [batchSize, size]
|
|
// Element-wise sum of two batched buffers: out[b][i] = A[b][i] + B[b][i].
//
// Thread mapping: gid.x selects the batch row and gid.y the element within
// the row; all three buffers are addressed with the same flat offset
// b * size + i. Threads outside batchSize x size do nothing.
kernel void batch_eltwise_add(
    device float* batchA [[buffer(0)]],            // [batchSize, size]
    device float* batchB [[buffer(1)]],            // [batchSize, size]
    device float* batchOutput [[buffer(2)]],       // [batchSize, size]
    constant uint32_t& size [[buffer(3)]],
    constant uint32_t& batchSize [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]])
{
    const uint row = gid.x;
    const uint col = gid.y;

    if (row < batchSize && col < size) {
        // Same flat index into both operands and the destination.
        const uint flat = row * size + col;
        const float lhs = batchA[flat];
        const float rhs = batchB[flat];
        batchOutput[flat] = lhs + rhs;
    }
}
|
|
|
|
// Batch Gated FFN (fused gate + up projection)
|
|
// Process [batchSize, intermediateSize]
|
|
// Fused gate/up projection of a gated feed-forward block.
//
// Thread mapping: gid.x selects the batch row, gid.y selects the intermediate
// feature. Each thread computes two quantized dot products over the same
// input row,
//     gate = gateBiases[j] + sum_i in[b][i] * (gateWeights[j][i] - 128) * gateScale(j, i)
//     up   = upBiases[j]   + sum_i in[b][i] * (upWeights[j][i]   - 128) * upScale(j, i)
// and stores  gate * sigmoid(gate) * up  (SiLU(gate) * up) at out[b][j].
//
// Both projections use identical weight and scale indexing, so they are
// accumulated in one pass: the input element and the group index are fetched
// and computed once per position instead of once per projection. Each
// accumulator still sums its terms in ascending i, exactly as two separate
// loops would.
//
// NOTE(review): the scale indexing uses integer division, so it presumably
// expects groupSize > 0 and hiddenSize to be a multiple of groupSize -
// confirm with the code that packs the scales.
kernel void batch_fused_gate_up(
    device float* batchInput [[buffer(0)]],        // [batchSize, hiddenSize]
    device uint8_t* gateWeights [[buffer(1)]],     // [intermediateSize, hiddenSize]
    device float* gateScales [[buffer(2)]],
    device float* gateBiases [[buffer(3)]],
    device uint8_t* upWeights [[buffer(4)]],       // [intermediateSize, hiddenSize]
    device float* upScales [[buffer(5)]],
    device float* upBiases [[buffer(6)]],
    device float* batchOutput [[buffer(7)]],       // [batchSize, intermediateSize]
    constant uint32_t& hiddenSize [[buffer(8)]],
    constant uint32_t& intermediateSize [[buffer(9)]],
    constant uint32_t& groupSize [[buffer(10)]],
    constant uint32_t& batchSize [[buffer(11)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint batchIdx = gid.x;
    const uint interIdx = gid.y;

    // Skip threads from a grid rounded up past the data extent.
    if (batchIdx >= batchSize || interIdx >= intermediateSize) return;

    device float* input = batchInput + batchIdx * hiddenSize;
    device float* output = batchOutput + batchIdx * intermediateSize;

    // Offsets shared by the gate and up projections for this feature.
    const uint weightBase = interIdx * hiddenSize;
    const uint scaleBase = interIdx * (hiddenSize / groupSize);

    // Both accumulators start from their respective bias.
    float gate = gateBiases[interIdx];
    float up = upBiases[interIdx];

    for (uint i = 0; i < hiddenSize; i++) {
        const float x = input[i];
        const uint weightIdx = weightBase + i;
        const uint scaleIdx = scaleBase + i / groupSize;

        // Stored bytes are shifted from [0, 255] to [-128, 127] before scaling.
        const int gateQ = int(gateWeights[weightIdx]) - 128;
        const int upQ = int(upWeights[weightIdx]) - 128;

        gate += x * gateQ * gateScales[scaleIdx];
        up += x * upQ * upScales[scaleIdx];
    }

    // Fused activation: gate * sigmoid(gate) * up
    const float sigmoidGate = 1.0f / (1.0f + exp(-gate));
    output[interIdx] = gate * sigmoidGate * up;
}
|
|
|
|
// Batch Down Projection (FFN output)
|
|
// Process [batchSize, hiddenSize]
|
|
// Down projection of the feed-forward block: maps each intermediate row back
// to hiddenSize features using group-quantized weights.
//
// Thread mapping: gid.x selects the batch row, gid.y selects the output
// (hidden) feature. Each thread produces one value:
//     out[b][h] = downBiases[h]
//               + sum_i inter[b][i] * (downWeights[h][i] - 128) * scale(h, i)
// where scale(h, i) = downScales[h * (intermediateSize / groupSize) + i / groupSize].
//
// NOTE(review): the scale indexing uses integer division, so it presumably
// expects groupSize > 0 and intermediateSize to be a multiple of groupSize -
// confirm with the code that packs the scales.
kernel void batch_down_projection(
    device float* batchInter [[buffer(0)]],        // [batchSize, intermediateSize]
    device uint8_t* downWeights [[buffer(1)]],     // [hiddenSize, intermediateSize]
    device float* downScales [[buffer(2)]],
    device float* downBiases [[buffer(3)]],
    device float* batchOutput [[buffer(4)]],       // [batchSize, hiddenSize]
    constant uint32_t& hiddenSize [[buffer(5)]],
    constant uint32_t& intermediateSize [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint row = gid.x;
    const uint feature = gid.y;

    // Threads beyond the batch or the hidden width have nothing to compute.
    if (row >= batchSize) return;
    if (feature >= hiddenSize) return;

    device float* hiddenAct = batchInter + row * intermediateSize;
    device float* dst = batchOutput + row * hiddenSize;

    // First scale entry for this output feature.
    const uint scaleBase = feature * (intermediateSize / groupSize);

    // Start from the bias, then add one dequantized product per position.
    float total = downBiases[feature];

    uint pos = 0;
    while (pos < intermediateSize) {
        // Stored byte, shifted from [0, 255] to [-128, 127].
        const int centered = int(downWeights[feature * intermediateSize + pos]) - 128;
        const float groupScale = downScales[scaleBase + pos / groupSize];

        total += hiddenAct[pos] * centered * groupScale;
        ++pos;
    }

    dst[feature] = total;
}