8a66b9086a
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
154 lines
5.7 KiB
Metal
154 lines
5.7 KiB
Metal
#include <metal_stdlib>
|
|
using namespace metal;
|
|
|
|
// ═══════════════════════════════════════════════════════════════
// Batch Metal kernels: each kernel processes several tokens in a
// single dispatch (the grid's x coordinate selects the token).
// ═══════════════════════════════════════════════════════════════

// Batched dequantize-and-multiply for N tokens sharing one weight matrix:
//   out[b][o] = biases[o] + sum_i in[b][i] * (W[o][i] - 128) * scale[o][i / groupSize]
//
// Dispatch: 2-D grid covering at least (batchSize, outDim); gid.x selects the
// token row, gid.y the output feature. Out-of-range threads return immediately.
//
// Weight encoding as decoded here: one uint8_t per weight, row-major
// [outDim, inDim], offset by 128. Every run of `groupSize` consecutive input
// columns in a row shares one scale.
// NOTE(review): the commit notes mention 4-bit models, but this kernel decodes
// 8-bit offset weights (one byte each) — confirm against the host packing code.
//
// Buffers:
//   batchInput  [batchSize, inDim]
//   weights     [outDim, inDim]                  (uint8_t)
//   scales      [outDim, ceil(inDim / groupSize)]
//   biases      [outDim]
//   batchOutput [batchSize, outDim]
// Preconditions: groupSize > 0. inDim need not be a multiple of 4 or of groupSize.
kernel void quantized_matmul_batch(
    const device float*   batchInput  [[buffer(0)]],
    const device uint8_t* weights     [[buffer(1)]],
    const device float*   scales      [[buffer(2)]],
    const device float*   biases      [[buffer(3)]],
    device float*         batchOutput [[buffer(4)]],
    constant uint32_t&    inDim       [[buffer(5)]],
    constant uint32_t&    outDim      [[buffer(6)]],
    constant uint32_t&    groupSize   [[buffer(7)]],
    constant uint32_t&    batchSize   [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint batchIdx = gid.x;
    const uint outIdx   = gid.y;
    if (batchIdx >= batchSize || outIdx >= outDim) return;

    const device float*   input     = batchInput + batchIdx * inDim;
    // Weights are stored one per byte, so each column is addressed directly.
    // Loading a single uint8_t and shifting it cannot yield four weights.
    const device uint8_t* weightRow = weights + outIdx * inDim;

    // Round up so a trailing partial group owns its scale instead of aliasing
    // the first scale of the next row.
    const uint groupsPerRow = (inDim + groupSize - 1) / groupSize;
    const device float* scaleRow = scales + outIdx * groupsPerRow;

    float sum = biases[outIdx];

    // Walk group by group: one scale load per group. The inner bound is clamped
    // to inDim, so no element past the end of the row is ever read.
    for (uint g = 0; g < groupsPerRow; ++g) {
        const float scale = scaleRow[g];
        const uint begin  = g * groupSize;
        const uint end    = min(begin + groupSize, (uint)inDim);
        for (uint i = begin; i < end; ++i) {
            const float w = float(int(weightRow[i]) - 128);
            sum += input[i] * w * scale;
        }
    }

    batchOutput[batchIdx * outDim + outIdx] = sum;
}
|
|
|
|
// RMS-normalise every token row of a batch:
//   out[b][j] = in[b][j] / sqrt(mean_k(in[b][k]^2) + eps) * weights[j]
// Here N is the row (feature) length and batchSize is the number of tokens.
//
// Dispatch: 2-D grid covering at least (batchSize, N); gid.x picks the token
// row, gid.y the element within it. Out-of-range threads exit early.
// Note: each thread re-reduces its whole row (O(N) work per output element).
// Removing that redundancy needs a threadgroup reduction and a different dispatch.
kernel void rms_norm_batch(
    device float* batchInput [[buffer(0)]], // [batchSize, N]
    device float* weights [[buffer(1)]], // [N]
    device float* batchOutput [[buffer(2)]], // [batchSize, N]
    constant uint32_t& N [[buffer(3)]],
    constant float& eps [[buffer(4)]],
    constant uint32_t& batchSize [[buffer(5)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint row = gid.x;
    const uint col = gid.y;

    const bool inRange = (row < batchSize) && (col < N);
    if (!inRange) {
        return;
    }

    device float* rowIn = batchInput + row * N;

    // Sum of squares over the row, accumulated front to back.
    float sumOfSquares = 0.0f;
    uint k = 0;
    while (k < N) {
        const float v = rowIn[k];
        sumOfSquares += v * v;
        ++k;
    }

    const float denom = sqrt(sumOfSquares / float(N) + eps);
    const float normalized = rowIn[col] / denom;
    batchOutput[row * N + col] = normalized * weights[col];
}
|
|
|
|
// Dot product of two n-element float vectors in device memory, summed in index order.
static inline float attention_dot(const device float* a, const device float* b, uint n)
{
    float acc = 0.0f;
    for (uint d = 0; d < n; ++d) {
        acc += a[d] * b[d];
    }
    return acc;
}

// Sliding-window attention for a batch of single query tokens (simplified:
// any KV-cache management beyond plain row indexing is up to the caller).
//
// Dispatch: 3-D grid covering at least (batchSize, nHeads, headDim). Each thread
// writes one output element out[b][h][d]; out-of-range threads exit early.
//
// Buffers:
//   batchQuery  [batchSize, nHeads, headDim]
//   kvCache     [maxSeqLen, 2, nKvHeads, headDim]   (slot 0 = keys, slot 1 = values)
//   batchOutput [batchSize, nHeads, headDim]
//   positions   [batchSize]  upper bound of the window for each query
//
// Query head h reads KV head h / (nHeads / nKvHeads) (grouped-query mapping).
// Requires nKvHeads > 0 and nHeads to be a multiple of nKvHeads.
// Attended cache rows: [max(pos - windowSize, 0), pos). An empty window yields 0.
// NOTE(review): row `pos` itself is excluded, so pos == 0 produces an all-zero
// output. Confirm whether the caller writes the current token's K/V before this
// dispatch and expects it to be attended.
kernel void sliding_attention_batch(
    const device float* batchQuery  [[buffer(0)]],
    const device float* kvCache     [[buffer(1)]],
    device float*       batchOutput [[buffer(2)]],
    constant uint32_t*  positions   [[buffer(3)]],
    constant uint32_t&  nHeads      [[buffer(4)]],
    constant uint32_t&  nKvHeads    [[buffer(5)]],
    constant uint32_t&  headDim     [[buffer(6)]],
    constant uint32_t&  batchSize   [[buffer(7)]],
    constant uint32_t&  windowSize  [[buffer(8)]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint batchIdx = gid.x;
    const uint headIdx  = gid.y;
    const uint dimIdx   = gid.z;
    if (batchIdx >= batchSize || headIdx >= nHeads || dimIdx >= headDim) return;

    const uint pos       = positions[batchIdx];
    const uint kvHeadIdx = headIdx / (nHeads / nKvHeads);

    const device float* query = batchQuery + (batchIdx * nHeads + headIdx) * headDim;

    // `pos - windowSize` on uint wraps around when pos < windowSize. That would
    // make start > end and silently drop every key, so clamp at 0 explicitly.
    const uint start = (pos > windowSize) ? pos - windowSize : 0u;
    const uint end   = pos;

    // One cache row holds the key block followed by the value block.
    const uint  rowStride   = 2u * nKvHeads * headDim;
    const uint  keyOffset   = kvHeadIdx * headDim;
    const uint  valueOffset = nKvHeads * headDim + keyOffset;
    const float sqrtHeadDim = sqrt(float(headDim));

    // Pass 1: maximum scaled score, for a numerically stable softmax. Seeding
    // with -INFINITY (not a finite sentinel) guarantees the max term is exp(0) = 1.
    float maxScore = -INFINITY;
    for (uint t = start; t < end; ++t) {
        const device float* key = kvCache + t * rowStride + keyOffset;
        maxScore = max(maxScore, attention_dot(query, key, headDim) / sqrtHeadDim);
    }

    // Pass 2: softmax denominator and unnormalised weighted value sum together.
    float expSum = 0.0f;
    float acc    = 0.0f;
    for (uint t = start; t < end; ++t) {
        const device float* row = kvCache + t * rowStride;
        const float score = attention_dot(query, row + keyOffset, headDim) / sqrtHeadDim;
        const float w = exp(score - maxScore);
        expSum += w;
        acc    += w * row[valueOffset + dimIdx];
    }

    // An empty window leaves both sums at zero; emit 0 rather than 0/0 = NaN.
    batchOutput[(batchIdx * nHeads + headIdx) * headDim + dimIdx] =
        (expSum > 0.0f) ? acc / expSum : 0.0f;
}