v2: remove remaining logit scaling hacks from batch/optimized paths
This commit is contained in:
@@ -161,12 +161,6 @@ extension E4BModel {
|
||||
cmdBuf: cmdBuf
|
||||
)
|
||||
|
||||
// Logits scaling
|
||||
if embedWeight.groupSize == 32 && embedWeight.inDim == hiddenSize {
|
||||
let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
|
||||
try scaleBufferOptimized(logitsBuffer, scale: logitsScale, count: vocabSize, cmdBuf: cmdBuf)
|
||||
}
|
||||
|
||||
// Softcapping
|
||||
if let cap = finalLogitSoftcapping {
|
||||
try applyLogitSoftcappingOptimized(
|
||||
|
||||
@@ -160,26 +160,6 @@ embedCmdBuf.waitUntilCompleted()
|
||||
encLM.dispatchThreads(gridLM, threadsPerThreadgroup: tgLM)
|
||||
encLM.endEncoding()
|
||||
|
||||
// Logits scaling and softcapping (batch)
|
||||
if embedWeight.groupSize == 32 {
|
||||
let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
|
||||
// Use eltwise_scale for batch scaling
|
||||
let pso = try engine.pipeline(named: "eltwise_scale")
|
||||
let enc = layerCmdBuf.makeComputeCommandEncoder()!
|
||||
enc.setComputePipelineState(pso)
|
||||
|
||||
enc.setBuffer(context.batchOutputBuffer, offset: 0, index: 0)
|
||||
var ls = logitsScale
|
||||
enc.setBytes(&ls, length: 4, index: 1)
|
||||
var total = UInt32(batchSize * vocabSize)
|
||||
enc.setBytes(&total, length: 4, index: 2)
|
||||
|
||||
let tg = MTLSize(width: 256, height: 1, depth: 1)
|
||||
let grid = MTLSize(width: batchSize * vocabSize, height: 1, depth: 1)
|
||||
enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
|
||||
enc.endEncoding()
|
||||
}
|
||||
|
||||
// Softcapping (skip if kernel not found)
|
||||
if let cap = finalLogitSoftcapping {
|
||||
// Try to use tanh_scale kernel
|
||||
|
||||
@@ -110,12 +110,6 @@ extension E4BModel {
|
||||
try quantizedMatmulOptimized(input: lmInput, weights: embedWeight,
|
||||
output: logitsBuffer, cmdBuf: cmdBuf3)
|
||||
|
||||
// Logits scaling (if needed)
|
||||
if embedWeight.groupSize == 32 && embedWeight.inDim == hiddenSize {
|
||||
let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
|
||||
try scaleBufferOptimized(logitsBuffer, scale: logitsScale, count: vocabSize, cmdBuf: cmdBuf3)
|
||||
}
|
||||
|
||||
// Logit softcapping
|
||||
if let cap = finalLogitSoftcapping {
|
||||
try applyLogitSoftcappingOptimized(buffer: logitsBuffer, cap: cap,
|
||||
|
||||
Reference in New Issue
Block a user