v2: remove remaining logit scaling hacks from batch/optimized paths
This commit is contained in:
@@ -161,12 +161,6 @@ extension E4BModel {
|
|||||||
cmdBuf: cmdBuf
|
cmdBuf: cmdBuf
|
||||||
)
|
)
|
||||||
|
|
||||||
// Logits scaling
|
|
||||||
if embedWeight.groupSize == 32 && embedWeight.inDim == hiddenSize {
|
|
||||||
let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
|
|
||||||
try scaleBufferOptimized(logitsBuffer, scale: logitsScale, count: vocabSize, cmdBuf: cmdBuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Softcapping
|
// Softcapping
|
||||||
if let cap = finalLogitSoftcapping {
|
if let cap = finalLogitSoftcapping {
|
||||||
try applyLogitSoftcappingOptimized(
|
try applyLogitSoftcappingOptimized(
|
||||||
|
|||||||
@@ -160,26 +160,6 @@ embedCmdBuf.waitUntilCompleted()
|
|||||||
encLM.dispatchThreads(gridLM, threadsPerThreadgroup: tgLM)
|
encLM.dispatchThreads(gridLM, threadsPerThreadgroup: tgLM)
|
||||||
encLM.endEncoding()
|
encLM.endEncoding()
|
||||||
|
|
||||||
// Logits scaling and softcapping (batch)
|
|
||||||
if embedWeight.groupSize == 32 {
|
|
||||||
let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
|
|
||||||
// Use eltwise_scale for batch scaling
|
|
||||||
let pso = try engine.pipeline(named: "eltwise_scale")
|
|
||||||
let enc = layerCmdBuf.makeComputeCommandEncoder()!
|
|
||||||
enc.setComputePipelineState(pso)
|
|
||||||
|
|
||||||
enc.setBuffer(context.batchOutputBuffer, offset: 0, index: 0)
|
|
||||||
var ls = logitsScale
|
|
||||||
enc.setBytes(&ls, length: 4, index: 1)
|
|
||||||
var total = UInt32(batchSize * vocabSize)
|
|
||||||
enc.setBytes(&total, length: 4, index: 2)
|
|
||||||
|
|
||||||
let tg = MTLSize(width: 256, height: 1, depth: 1)
|
|
||||||
let grid = MTLSize(width: batchSize * vocabSize, height: 1, depth: 1)
|
|
||||||
enc.dispatchThreads(grid, threadsPerThreadgroup: tg)
|
|
||||||
enc.endEncoding()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Softcapping (skip if kernel not found)
|
// Softcapping (skip if kernel not found)
|
||||||
if let cap = finalLogitSoftcapping {
|
if let cap = finalLogitSoftcapping {
|
||||||
// Try to use tanh_scale kernel
|
// Try to use tanh_scale kernel
|
||||||
|
|||||||
@@ -110,12 +110,6 @@ extension E4BModel {
|
|||||||
try quantizedMatmulOptimized(input: lmInput, weights: embedWeight,
|
try quantizedMatmulOptimized(input: lmInput, weights: embedWeight,
|
||||||
output: logitsBuffer, cmdBuf: cmdBuf3)
|
output: logitsBuffer, cmdBuf: cmdBuf3)
|
||||||
|
|
||||||
// Logits scaling (if needed)
|
|
||||||
if embedWeight.groupSize == 32 && embedWeight.inDim == hiddenSize {
|
|
||||||
let logitsScale = Float(30.0 / 116.23 / sqrt(Float(hiddenSize)))
|
|
||||||
try scaleBufferOptimized(logitsBuffer, scale: logitsScale, count: vocabSize, cmdBuf: cmdBuf3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logit softcapping
|
// Logit softcapping
|
||||||
if let cap = finalLogitSoftcapping {
|
if let cap = finalLogitSoftcapping {
|
||||||
try applyLogitSoftcappingOptimized(buffer: logitsBuffer, cap: cap,
|
try applyLogitSoftcappingOptimized(buffer: logitsBuffer, cap: cap,
|
||||||
|
|||||||
Reference in New Issue
Block a user