Files
markbaseengine/Sources/MarkBase/Metal/BatchLayerKernels.metal
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

181 lines
6.9 KiB
Metal

#include <metal_stdlib>
using namespace metal;
// ═══════════════════════════════════════════════════════════════
// Batch Layer Processing Kernels
// Process entire layer for multiple tokens simultaneously
// ═══════════════════════════════════════════════════════════════
// Batched RMS normalization of a layer input.
//
// One thread produces one output element: gid.x selects the batch row and
// gid.y selects the element inside that row. Every thread walks its whole row
// to obtain the sum of squares before scaling its own element, so the row
// reduction is repeated by each thread of the row.
//
//   batchInput   [batchSize, hiddenSize]  values to normalize
//   weights      [hiddenSize]             per-element gain, shared by all rows
//   batchOutput  [batchSize, hiddenSize]  normalized values
//   eps          added to the mean square before the square root
kernel void batch_layer_rms_norm(
    device float* batchInput [[buffer(0)]],      // [batchSize, hiddenSize]
    device float* weights [[buffer(1)]],         // [hiddenSize]
    device float* batchOutput [[buffer(2)]],     // [batchSize, hiddenSize]
    constant uint32_t& hiddenSize [[buffer(3)]],
    constant float& eps [[buffer(4)]],
    constant uint32_t& batchSize [[buffer(5)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint row = gid.x;
    const uint col = gid.y;

    // Threads that fall outside the [batchSize, hiddenSize] extent do nothing.
    if (row >= batchSize) return;
    if (col >= hiddenSize) return;

    const uint rowStart = row * hiddenSize;
    device float* src = batchInput + rowStart;
    device float* dst = batchOutput + rowStart;

    // Sum of squares over the full row, accumulated front to back.
    float sumSquares = 0.0f;
    uint k = 0;
    while (k < hiddenSize) {
        sumSquares += src[k] * src[k];
        ++k;
    }

    const float rms = sqrt(sumSquares / float(hiddenSize) + eps);
    dst[col] = src[col] / rms * weights[col];
}
// Batched quantized matmul for layer projections.
//
// One thread produces one output value: gid.x selects the batch row and gid.y
// selects the output feature. The result is
//   biases[out] + sum_i input[i] * (weights[out, i] - 128) * scale(out, i / groupSize)
//
//   batchInput   [batchSize, inDim]
//   weights      [outDim, inDim]   read as one unsigned byte per weight, zero point 128
//   scales       [outDim, groups]  groups = ceil(inDim / groupSize)
//   biases       [outDim]          added once per output value
//   batchOutput  [batchSize, outDim]
//
// Precondition: groupSize must be non-zero (it is used as a divisor).
// NOTE(review): the weights buffer is annotated "packed", but this kernel reads
// one byte per weight - confirm the host supplies unpacked 8-bit values.
kernel void batch_layer_quantized_matmul(
    device float* batchInput [[buffer(0)]],      // [batchSize, inDim]
    device uint8_t* weights [[buffer(1)]],       // [outDim, inDim] packed
    device float* scales [[buffer(2)]],          // [outDim, groups]
    device float* biases [[buffer(3)]],          // [outDim]
    device float* batchOutput [[buffer(4)]],     // [batchSize, outDim]
    constant uint32_t& inDim [[buffer(5)]],
    constant uint32_t& outDim [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    uint batchIdx = gid.x;
    uint outIdx = gid.y;
    if (batchIdx >= batchSize || outIdx >= outDim) return;

    device float* input = batchInput + batchIdx * inDim;
    device float* output = batchOutput + batchIdx * outDim;
    device uint8_t* weightRow = weights + outIdx * inDim;

    // Round the group count up so a trailing partial group owns its own scale.
    // A rounded-down stride makes the tail of a row read the next row's first
    // scale (and read past the buffer on the last row) whenever inDim is not a
    // multiple of groupSize. Both forms agree when it divides evenly.
    uint groupsPerRow = inDim / groupSize + ((inDim % groupSize) != 0 ? 1u : 0u);
    device float* rowScales = scales + outIdx * groupsPerRow;

    float sum = biases[outIdx];

    // Walk the row one quantization group at a time so the scale is loaded once
    // per group; elements are still accumulated in ascending index order.
    for (uint g = 0; g < groupsPerRow; g++) {
        float scale = rowScales[g];
        uint groupStart = g * groupSize;
        uint groupEnd = min(groupStart + groupSize, inDim);
        for (uint i = groupStart; i < groupEnd; i++) {
            uint8_t w = weightRow[i];
            // (w - 128) is evaluated as a signed int, re-centering the byte on zero.
            sum += input[i] * (w - 128) * scale;
        }
    }

    output[outIdx] = sum;
}
// Batched element-wise addition, used for residual connections.
//
// One thread adds one pair of elements: gid.x selects the batch row and gid.y
// the element inside that row. All three buffers are indexed with the same
// flat offset row * size + element.
//
//   batchA, batchB  [batchSize, size]  operands
//   batchOutput     [batchSize, size]  batchA + batchB
kernel void batch_eltwise_add(
    device float* batchA [[buffer(0)]],          // [batchSize, size]
    device float* batchB [[buffer(1)]],          // [batchSize, size]
    device float* batchOutput [[buffer(2)]],     // [batchSize, size]
    constant uint32_t& size [[buffer(3)]],
    constant uint32_t& batchSize [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]])
{
    const uint row = gid.x;
    const uint col = gid.y;

    // Only threads inside the [batchSize, size] extent write a result.
    const bool inRange = (row < batchSize) && (col < size);
    if (inRange) {
        const uint flat = row * size + col;
        batchOutput[flat] = batchA[flat] + batchB[flat];
    }
}
// Batched gated FFN: fused gate and up projections plus activation.
//
// One thread produces one intermediate value: gid.x selects the batch row and
// gid.y selects the intermediate feature. Two quantized dot products over the
// same input row are computed (gate and up), then combined as
//   output = gate * sigmoid(gate) * up        (i.e. SiLU(gate) * up)
//
//   batchInput               [batchSize, hiddenSize]
//   gateWeights, upWeights   [intermediateSize, hiddenSize], one unsigned byte
//                            per weight, zero point 128
//   gateScales, upScales     indexed as [intermediateSize, groups] with
//                            groups = ceil(hiddenSize / groupSize)
//   gateBiases, upBiases     indexed per intermediate feature, added once
//   batchOutput              [batchSize, intermediateSize]
//
// Precondition: groupSize must be non-zero (it is used as a divisor).
kernel void batch_fused_gate_up(
    device float* batchInput [[buffer(0)]],      // [batchSize, hiddenSize]
    device uint8_t* gateWeights [[buffer(1)]],   // [intermediateSize, hiddenSize]
    device float* gateScales [[buffer(2)]],
    device float* gateBiases [[buffer(3)]],
    device uint8_t* upWeights [[buffer(4)]],     // [intermediateSize, hiddenSize]
    device float* upScales [[buffer(5)]],
    device float* upBiases [[buffer(6)]],
    device float* batchOutput [[buffer(7)]],     // [batchSize, intermediateSize]
    constant uint32_t& hiddenSize [[buffer(8)]],
    constant uint32_t& intermediateSize [[buffer(9)]],
    constant uint32_t& groupSize [[buffer(10)]],
    constant uint32_t& batchSize [[buffer(11)]],
    uint3 gid [[thread_position_in_grid]])
{
    uint batchIdx = gid.x;
    uint interIdx = gid.y;
    if (batchIdx >= batchSize || interIdx >= intermediateSize) return;

    device float* input = batchInput + batchIdx * hiddenSize;
    device float* output = batchOutput + batchIdx * intermediateSize;
    device uint8_t* gateRow = gateWeights + interIdx * hiddenSize;
    device uint8_t* upRow = upWeights + interIdx * hiddenSize;

    // Round the group count up so a trailing partial group owns its own scale.
    // A rounded-down stride makes the tail of a row read the next row's first
    // scale (and read past the buffer on the last row) whenever hiddenSize is
    // not a multiple of groupSize. Both forms agree when it divides evenly.
    uint groupsPerRow = hiddenSize / groupSize + ((hiddenSize % groupSize) != 0 ? 1u : 0u);
    uint scaleBase = interIdx * groupsPerRow;

    float gate = gateBiases[interIdx];
    float up = upBiases[interIdx];

    // Both projections share the input row and the group layout, so they are
    // accumulated in a single pass. The two accumulators are independent and
    // each still sums its terms in ascending index order.
    for (uint g = 0; g < groupsPerRow; g++) {
        float gateScale = gateScales[scaleBase + g];
        float upScale = upScales[scaleBase + g];
        uint groupStart = g * groupSize;
        uint groupEnd = min(groupStart + groupSize, hiddenSize);
        for (uint i = groupStart; i < groupEnd; i++) {
            float x = input[i];
            uint8_t gw = gateRow[i];
            uint8_t uw = upRow[i];
            // (w - 128) is evaluated as a signed int, re-centering the byte on zero.
            gate += x * (gw - 128) * gateScale;
            up += x * (uw - 128) * upScale;
        }
    }

    // Fused activation: gate * sigmoid(gate) * up
    float sigmoidGate = 1.0f / (1.0f + exp(-gate));
    output[interIdx] = gate * sigmoidGate * up;
}
// Batched down projection (FFN output).
//
// One thread produces one output value: gid.x selects the batch row and gid.y
// selects the hidden feature. The result is
//   downBiases[out] + sum_i inter[i] * (downWeights[out, i] - 128) * scale(out, i / groupSize)
//
//   batchInter   [batchSize, intermediateSize]
//   downWeights  [hiddenSize, intermediateSize], one unsigned byte per weight,
//                zero point 128
//   downScales   indexed as [hiddenSize, groups] with
//                groups = ceil(intermediateSize / groupSize)
//   downBiases   indexed per hidden feature, added once per output value
//   batchOutput  [batchSize, hiddenSize]
//
// Precondition: groupSize must be non-zero (it is used as a divisor).
kernel void batch_down_projection(
    device float* batchInter [[buffer(0)]],      // [batchSize, intermediateSize]
    device uint8_t* downWeights [[buffer(1)]],   // [hiddenSize, intermediateSize]
    device float* downScales [[buffer(2)]],
    device float* downBiases [[buffer(3)]],
    device float* batchOutput [[buffer(4)]],     // [batchSize, hiddenSize]
    constant uint32_t& hiddenSize [[buffer(5)]],
    constant uint32_t& intermediateSize [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    uint batchIdx = gid.x;
    uint outIdx = gid.y;
    if (batchIdx >= batchSize || outIdx >= hiddenSize) return;

    device float* inter = batchInter + batchIdx * intermediateSize;
    device float* output = batchOutput + batchIdx * hiddenSize;
    device uint8_t* weightRow = downWeights + outIdx * intermediateSize;

    // Round the group count up so a trailing partial group owns its own scale.
    // A rounded-down stride makes the tail of a row read the next row's first
    // scale (and read past the buffer on the last row) whenever
    // intermediateSize is not a multiple of groupSize. Both forms agree when it
    // divides evenly.
    uint groupsPerRow = intermediateSize / groupSize
                      + ((intermediateSize % groupSize) != 0 ? 1u : 0u);
    device float* rowScales = downScales + outIdx * groupsPerRow;

    float sum = downBiases[outIdx];

    // Walk the row one quantization group at a time so the scale is loaded once
    // per group; elements are still accumulated in ascending index order.
    for (uint g = 0; g < groupsPerRow; g++) {
        float scale = rowScales[g];
        uint groupStart = g * groupSize;
        uint groupEnd = min(groupStart + groupSize, intermediateSize);
        for (uint i = groupStart; i < groupEnd; i++) {
            uint8_t w = weightRow[i];
            // (w - 128) is evaluated as a signed int, re-centering the byte on zero.
            sum += inter[i] * (w - 128) * scale;
        }
    }

    output[outIdx] = sum;
}