v2: remove remaining logit scaling hacks from batch/optimized paths
CI / build (push) Waiting to run
CI / unit-tests (push) Blocked by required conditions
CI / lint (push) Blocked by required conditions

This commit is contained in:
MarkBase Admin
2026-07-05 22:41:49 +08:00
parent 239474bef0
commit 7a8edf77ee
3 changed files with 0 additions and 32 deletions
-6
View File
@@ -161,12 +161,6 @@ extension E4BModel {
cmdBuf: cmdBuf
)
// Logits scaling
if embedWeight.groupSize == 32 && embedWeight.inDim == hiddenSize {
let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
try scaleBufferOptimized(logitsBuffer, scale: logitsScale, count: vocabSize, cmdBuf: cmdBuf)
}
// Softcapping
if let cap = finalLogitSoftcapping {
try applyLogitSoftcappingOptimized(
@@ -160,26 +160,6 @@ embedCmdBuf.waitUntilCompleted()
encLM.dispatchThreads(gridLM, threadsPerThreadgroup: tgLM)
encLM.endEncoding()
// Logits scaling and softcapping (batch)
if embedWeight.groupSize == 32 {
let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
// Use eltwise_scale for batch scaling
let pso = try engine.pipeline(named: "eltwise_scale")
let enc = layerCmdBuf.makeComputeCommandEncoder()!
enc.setComputePipelineState(pso)
enc.setBuffer(context.batchOutputBuffer, offset: 0, index: 0)
var ls = logitsScale
enc.setBytes(&ls, length: 4, index: 1)
var total = UInt32(batchSize * vocabSize)
enc.setBytes(&total, length: 4, index: 2)
let tg = MTLSize(width: 256, height: 1, depth: 1)
let grid = MTLSize(width: batchSize * vocabSize, height: 1, depth: 1)
enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
enc.endEncoding()
}
// Softcapping (skip if kernel not found)
if let cap = finalLogitSoftcapping {
// Try to use tanh_scale kernel
-6
View File
@@ -110,12 +110,6 @@ extension E4BModel {
try quantizedMatmulOptimized(input: lmInput, weights: embedWeight,
output: logitsBuffer, cmdBuf: cmdBuf3)
// Logits scaling (if needed)
if embedWeight.groupSize == 32 && embedWeight.inDim == hiddenSize {
let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
try scaleBufferOptimized(logitsBuffer, scale: logitsScale, count: vocabSize, cmdBuf: cmdBuf3)
}
// Logit softcapping
if let cap = finalLogitSoftcapping {
try applyLogitSoftcappingOptimized(buffer: logitsBuffer, cap: cap,