Files
markbaseengine/Sources/MarkBase/BatchTemps.swift
T
MarkBase Admin 8a66b9086a
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions
v2: Initial clean branch with unit tests + CI/CD pipeline
- Started from ac75faa (initial E4B-MarkBase integration)
- Kept Sources/ (all engine code) + Package.swift + .gitignore
- Removed all ad-hoc tests, documentation, scripts, Python files
- Added Tests/00_Unit/ (MathTest, TokenizerTest, SamplerTest)
- Added .gitea/workflows/ci.yaml (build + unit tests + lint)
- Added Scripts/check_resources.sh (memory-aware test runner)
- Added Tests/Manifest.json (resource requirements for all tests)
- Focus: 4-bit quantized models only
2026-07-05 13:29:25 +08:00

58 lines
2.3 KiB
Swift

import Metal
// ══════════════════════════════════════════════════════════════════
// Batch Forward Temps - Extended buffers for batch processing
// ══════════════════════════════════════════════════════════════════
extension ForwardTemps {
    /// Builds a fresh `BatchTemps` holding the scratch buffers used for batch processing.
    ///
    /// This is a pure forwarding factory: every argument is handed unchanged to
    /// `BatchTemps.init`, and no state of the receiver is read or modified, so the
    /// returned buffers are distinct allocations from anything this instance holds.
    ///
    /// - Parameters:
    ///   - device: Metal device the buffers are allocated on.
    ///   - batchSize: Batch dimension shared by every buffer.
    ///   - hiddenSize: Per-item element count of the hidden-state buffer.
    ///   - nHeads: Head count; multiplied by `headDim` for the query and norm-scratch buffers.
    ///   - headDim: Per-head element count.
    ///   - intermediateSize: Per-item element count of the intermediate buffer.
    /// - Returns: A newly allocated `BatchTemps`.
    /// - Throws: Whatever `BatchTemps.init` throws when an allocation fails.
    public func createBatchBuffers(
        device: MTLDevice,
        batchSize: Int,
        hiddenSize: Int,
        nHeads: Int,
        headDim: Int,
        intermediateSize: Int
    ) throws -> BatchTemps {
        try BatchTemps(
            device: device,
            batchSize: batchSize,
            hiddenSize: hiddenSize,
            nHeads: nHeads,
            headDim: headDim,
            intermediateSize: intermediateSize
        )
    }
}
/// Temporary Metal buffers sized for processing a whole batch at once.
///
/// Each buffer is allocated with shared storage and sized in units of
/// `MemoryLayout<Float>.stride`.
public struct BatchTemps {
    /// Hidden-state batch, `[batchSize, hiddenSize]`.
    public let hBatch: MTLBuffer
    /// Query batch, `[batchSize, nHeads * headDim]`.
    public let qBatch: MTLBuffer
    /// Norm scratch batch, `[batchSize, nHeads * headDim]`.
    public let nsBatch: MTLBuffer
    /// Intermediate batch, `[batchSize, intermediateSize]`.
    public let interBatch: MTLBuffer

    /// Allocates all four batch buffers on `device`.
    ///
    /// - Parameters:
    ///   - device: Metal device used for every allocation.
    ///   - batchSize: Batch dimension shared by every buffer.
    ///   - hiddenSize: Per-item element count of `hBatch`.
    ///   - nHeads: Head count; `nHeads * headDim` is the per-item size of `qBatch` and `nsBatch`.
    ///   - headDim: Per-head element count.
    ///   - intermediateSize: Per-item element count of `interBatch`.
    /// - Throws: An `NSError` (domain `"BatchTemps"`, code `-1`) as soon as any
    ///   single allocation returns `nil`.
    public init(
        device: MTLDevice,
        batchSize: Int,
        hiddenSize: Int,
        nHeads: Int,
        headDim: Int,
        intermediateSize: Int
    ) throws {
        // Buffers are allocated one after another in declaration order; the
        // first failed allocation aborts initialization.
        hBatch = try Self.makeFloatBuffer(on: device, elementCount: batchSize * hiddenSize)
        qBatch = try Self.makeFloatBuffer(on: device, elementCount: batchSize * nHeads * headDim)
        nsBatch = try Self.makeFloatBuffer(on: device, elementCount: batchSize * nHeads * headDim)
        interBatch = try Self.makeFloatBuffer(on: device, elementCount: batchSize * intermediateSize)
    }

    /// Allocates one shared-storage buffer with room for `elementCount` floats.
    private static func makeFloatBuffer(on device: MTLDevice, elementCount: Int) throws -> MTLBuffer {
        let byteLength = elementCount * MemoryLayout<Float>.stride
        if let buffer = device.makeBuffer(length: byteLength, options: .storageModeShared) {
            return buffer
        }
        throw NSError(
            domain: "BatchTemps",
            code: -1,
            userInfo: [NSLocalizedDescriptionKey: "Buffer creation failed"]
        )
    }
}