Files
markbaseengine/Sources/MarkBase/Metal/BatchKernelsFixed.metal
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

154 lines
5.7 KiB
Metal

#include <metal_stdlib>
using namespace metal;
// ═══════════════════════════════════════════════════════════════
// Batch Metal Kernels - Process multiple tokens simultaneously
// ═══════════════════════════════════════════════════════════════
// Batch quantized matmul - process N tokens with shared weights.
//
// For every (token b, output row o) pair computes
//   batchOutput[b, o] = biases[o]
//                     + sum_i batchInput[b, i] * (weights[o, i] - 128) * scales[o, i / groupSize]
//
// Dispatch: one thread per output element. gid.x = token index in [0, batchSize),
// gid.y = output row in [0, outDim); threads outside that range exit immediately.
//
// Layouts (row-major):
//   batchInput  [batchSize, inDim]                     float
//   weights     [outDim, inDim]                        one unsigned byte per weight, zero point 128
//   scales      [outDim, ceil(inDim / groupSize)]      float, one scale per group of groupSize inputs
//   biases      [outDim]                               float
//   batchOutput [batchSize, outDim]                    float
//
// Preconditions: groupSize > 0. inDim need not be a multiple of 4 or of groupSize.
// NOTE(review): weights are decoded as 8-bit values with a 128 zero point;
// confirm this matches the host-side weight packer.
kernel void quantized_matmul_batch(
    const device float* batchInput [[buffer(0)]],  // [batchSize, inDim]
    const device uint8_t* weights [[buffer(1)]],   // [outDim, inDim], one byte per weight
    const device float* scales [[buffer(2)]],      // [outDim, groups]
    const device float* biases [[buffer(3)]],      // [outDim]
    device float* batchOutput [[buffer(4)]],       // [batchSize, outDim]
    constant uint32_t& inDim [[buffer(5)]],
    constant uint32_t& outDim [[buffer(6)]],
    constant uint32_t& groupSize [[buffer(7)]],
    constant uint32_t& batchSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint batchIdx = gid.x;
    const uint outIdx = gid.y;
    if (batchIdx >= batchSize || outIdx >= outDim) return;

    const device float* input = batchInput + batchIdx * inDim;
    const device uint8_t* weightRow = weights + outIdx * inDim;

    // Ceil-div so a trailing partial group keeps its own scale slot; this equals
    // inDim / groupSize whenever inDim is a multiple of groupSize.
    const uint groupsPerRow = (inDim + groupSize - 1) / groupSize;
    const device float* rowScales = scales + outIdx * groupsPerRow;

    float sum = biases[outIdx];
    // Each byte is one weight, so weights are read individually. Loading a
    // byte into a uint and shifting it would yield 0 for the upper "lanes".
    // Iterating element by element also covers inDim values that are not a
    // multiple of 4 without reading past the end of the row.
    for (uint i = 0; i < inDim; ++i) {
        const float w = float(int(weightRow[i]) - 128);
        sum += input[i] * w * rowScales[i / groupSize];
    }
    batchOutput[batchIdx * outDim + outIdx] = sum;
}
// Batch RMS norm - process N tokens simultaneously.
//
// For token row b of length N:
//   batchOutput[b, j] = batchInput[b, j] / sqrt(mean_k(batchInput[b, k]^2) + eps) * weights[j]
//
// Dispatch: gid.x = token index in [0, batchSize), gid.y = element index in [0, N);
// out-of-range threads return without writing.
// Performance note: every thread recomputes the whole sum of squares for its row,
// i.e. O(N) work per thread and O(N^2) per row. Removing that would need a
// threadgroup reduction and a different dispatch shape.
kernel void rms_norm_batch(
    device float* batchInput [[buffer(0)]], // [batchSize, N]
    device float* weights [[buffer(1)]], // [N]
    device float* batchOutput [[buffer(2)]], // [batchSize, N]
    constant uint32_t& N [[buffer(3)]],
    constant float& eps [[buffer(4)]],
    constant uint32_t& batchSize [[buffer(5)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint token = gid.x;
    const uint elem = gid.y;
    const bool inRange = (token < batchSize) && (elem < N);
    if (!inRange) {
        return;
    }

    device const float* row = batchInput + token * N;

    // Sum of squares in ascending index order.
    float sumOfSquares = 0.0f;
    uint k = 0;
    while (k < N) {
        const float v = row[k];
        sumOfSquares += v * v;
        ++k;
    }

    const float meanSquare = sumOfSquares / float(N);
    const float rms = sqrt(meanSquare + eps);
    batchOutput[token * N + elem] = row[elem] / rms * weights[elem];
}
// Scaled dot product (query . key) / scaleDenom over headDim contiguous floats.
static inline float sliding_attention_scaled_dot(const device float* query,
                                                 const device float* key,
                                                 uint headDim,
                                                 float scaleDenom)
{
    float score = 0.0f;
    for (uint d = 0; d < headDim; d++) {
        score += query[d] * key[d];
    }
    return score / scaleDenom;
}

// Batch sliding-window attention (simplified reference implementation: one
// thread per output scalar, no threadgroup cooperation). Full KV cache
// management (writes, ring buffers, etc.) is handled elsewhere, if at all.
//
// Dispatch: gid.x = token in [0, batchSize), gid.y = query head in [0, nHeads),
// gid.z = output dimension in [0, headDim); out-of-range threads return.
//
// Layouts:
//   batchQuery  [batchSize, nHeads, headDim]
//   kvCache     [maxSeqLen, 2, nKvHeads, headDim]  (slot 0 = keys, slot 1 = values)
//   batchOutput [batchSize, nHeads, headDim]
//   positions   [batchSize]
//
// Token b with p = positions[b] attends to cache rows t in [max(p - windowSize, 0), p).
// NOTE(review): row p itself is excluded. This assumes positions[b] counts the
// cache rows visible to token b (its own K/V at row p - 1); confirm against the caller.
// Query heads share KV heads in contiguous groups of nHeads / nKvHeads.
// Preconditions: nKvHeads > 0 and nHeads is a non-zero multiple of nKvHeads.
// An empty window (p == 0 or windowSize == 0) writes 0.
kernel void sliding_attention_batch(
    const device float* batchQuery [[buffer(0)]], // [batchSize, nHeads, headDim]
    const device float* kvCache [[buffer(1)]],    // [maxSeqLen, 2, nKvHeads, headDim]
    device float* batchOutput [[buffer(2)]],      // [batchSize, nHeads, headDim]
    constant uint32_t* positions [[buffer(3)]],   // [batchSize]
    constant uint32_t& nHeads [[buffer(4)]],
    constant uint32_t& nKvHeads [[buffer(5)]],
    constant uint32_t& headDim [[buffer(6)]],
    constant uint32_t& batchSize [[buffer(7)]],
    constant uint32_t& windowSize [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint batchIdx = gid.x;
    const uint headIdx = gid.y;
    const uint dimIdx = gid.z;
    if (batchIdx >= batchSize || headIdx >= nHeads || dimIdx >= headDim) return;

    const uint pos = positions[batchIdx];
    const uint kvHeadIdx = headIdx / (nHeads / nKvHeads);
    const device float* query = batchQuery + (batchIdx * nHeads + headIdx) * headDim;

    // pos and windowSize are unsigned: compare before subtracting so a short
    // prefix (pos < windowSize) starts at row 0 instead of wrapping to a huge
    // start index and silently skipping the whole window.
    const uint start = pos > windowSize ? pos - windowSize : 0u;
    const uint end = pos;

    const uint rowStride = 2 * nKvHeads * headDim;               // one cache row: key block, then value block
    const uint keyOffset = kvHeadIdx * headDim;
    const uint valueOffset = nKvHeads * headDim + kvHeadIdx * headDim;
    const float scaleDenom = sqrt(float(headDim));

    // Pass 1: maximum score, subtracted before exp() for numerical stability.
    float maxScore = -INFINITY;
    for (uint t = start; t < end; t++) {
        const device float* cacheRow = kvCache + t * rowStride;
        maxScore = max(maxScore,
                       sliding_attention_scaled_dot(query, cacheRow + keyOffset, headDim, scaleDenom));
    }

    // Pass 2: accumulate the softmax denominator and the unnormalised weighted
    // value together, then normalise once at the end.
    float expSum = 0.0f;
    float weightedValue = 0.0f;
    for (uint t = start; t < end; t++) {
        const device float* cacheRow = kvCache + t * rowStride;
        const float score = sliding_attention_scaled_dot(query, cacheRow + keyOffset, headDim, scaleDenom);
        const float w = exp(score - maxScore);
        expSum += w;
        weightedValue += w * cacheRow[valueOffset + dimIdx];
    }

    batchOutput[(batchIdx * nHeads + headIdx) * headDim + dimIdx] =
        expSum > 0.0f ? weightedValue / expSum : 0.0f;
}