optimize the perf and support more features

This commit is contained in:
Lei Xue
2026-03-14 11:45:35 +08:00
parent 7e7ebacd9d
commit 00cfac3d24
56 changed files with 6340 additions and 1019 deletions

View File

@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@@ -23,8 +23,8 @@ import (
)
// TargetCreate creates a target in the SCSI Target.
func (cli *Client) TargetCreate(ctx context.Context, options api.TargetCreateRequest) (api.SCSITarget, error) {
var target api.SCSITarget
func (cli *Client) TargetCreate(ctx context.Context, options api.TargetCreateRequest) (*api.SCSITarget, error) {
var target *api.SCSITarget
resp, err := cli.post(ctx, "/target/create", nil, options, nil)
if err != nil {
return target, err

View File

@@ -23,9 +23,9 @@ import (
"golang.org/x/net/context"
)
// TargetCreate creates a target in the SCSI Target.
func (cli *Client) TargetList(ctx context.Context, options api.TargetListOptions) ([]api.SCSITarget, error) {
var targets []api.SCSITarget
// TargetList lists targets in the SCSI Target.
func (cli *Client) TargetList(ctx context.Context, options api.TargetListOptions) ([]*api.SCSITarget, error) {
var targets []*api.SCSITarget
var query = url.Values{}
if options.Name != "" {
query.Set("name", options.Name)

View File

@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,7 +20,7 @@ import (
"io"
"sync"
uuid "github.com/satori/go.uuid"
"github.com/google/uuid"
)
type SCSICommandType byte
@@ -370,7 +370,7 @@ type ModePage struct {
PageCode uint8
// Sub page code
SubPageCode uint8
Size uint8
Size uint16 // Use uint16 to support pages larger than 255 bytes
// Rest of mode page info
Data []byte
}

View File

@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,

View File

@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,

View File

@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,

View File

@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,

View File

@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,

View File

@@ -106,6 +106,14 @@ type BackendStorage struct {
Online bool `json:"online"`
ThinProvisioning bool `json:"thinProvisioning"`
BlockShift uint `json:"blockShift"`
// BackendType specifies the backend storage type (file, iouring, etc.)
BackendType string `json:"backendType,omitempty"`
// EnableNUMA enables NUMA-aware memory allocation for this storage
EnableNUMA bool `json:"enableNUMA,omitempty"`
// NumaNode specifies the preferred NUMA node for this storage (-1 for auto)
NumaNode int `json:"numaNode,omitempty"`
// IoUringQueueDepth specifies the io_uring queue depth (0 for default)
IoUringQueueDepth uint32 `json:"ioUringQueueDepth,omitempty"`
}
type ISCSIPortalInfo struct {
@@ -118,10 +126,25 @@ type ISCSITarget struct {
LUNs map[string]uint64 `json:"luns"`
}
type PerformanceConfig struct {
// EnableNUMA enables NUMA-aware memory allocation
EnableNUMA bool `json:"enableNUMA,omitempty"`
// EnableIoUring enables io_uring backend storage support (Linux 5.1+)
EnableIoUring bool `json:"enableIoUring,omitempty"`
// IoUringQueueDepth sets the io_uring queue depth
IoUringQueueDepth uint32 `json:"ioUringQueueDepth,omitempty"`
// NUMABufferPoolSize sets the number of buffers per NUMA node
NUMABufferPoolSize int `json:"numaBufferPoolSize,omitempty"`
// NUMABufferSize sets the size of NUMA-local buffers
NUMABufferSize int `json:"numaBufferSize,omitempty"`
}
type Config struct {
Storages []BackendStorage `json:"storages"`
ISCSIPortals []ISCSIPortalInfo `json:"iscsiportals"`
ISCSITargets map[string]ISCSITarget `json:"iscsitargets"`
// Performance settings
Performance PerformanceConfig `json:"performance,omitempty"`
}
func init() {

View File

@@ -17,16 +17,105 @@ limitations under the License.
package iscsit
import (
"bytes"
"fmt"
"strings"
"sync"
"time"
"github.com/gostor/gotgt/pkg/api"
"github.com/gostor/gotgt/pkg/util"
"github.com/gostor/gotgt/pkg/util/numa"
log "github.com/sirupsen/logrus"
)
// Object pools to reduce GC pressure
var (
// commandPool reuses ISCSICommand objects
commandPool = sync.Pool{
New: func() interface{} {
return &ISCSICommand{}
},
}
// bufferPool reuses small buffers for BHS reading
bufferPool = sync.Pool{
New: func() interface{} {
buf := make([]byte, BHS_SIZE)
return &buf
},
}
// numaBufferPool NUMA-aware buffer pool for larger I/O operations
numaBufferPool *numa.NUMABufferPool
numaPoolOnce sync.Once
)
// initNUMAPool initializes the NUMA-aware buffer pool
func initNUMAPool() {
numaPoolOnce.Do(func() {
numaBufferPool = numa.NewNUMABufferPool(&numa.BufferPoolConfig{
BufferSize: 256 * 1024, // 256KB for I/O buffers
PerNodePoolSize: 512,
EnableNUMA: numa.Available(),
})
})
}
// getCommand gets an ISCSICommand from the pool
func getCommand() *ISCSICommand {
return commandPool.Get().(*ISCSICommand)
}
// putCommand puts an ISCSICommand back to the pool
func putCommand(cmd *ISCSICommand) {
if cmd == nil {
return
}
// Clear sensitive data
cmd.RawData = nil
cmd.RawHeader = nil
cmd.CDB = nil
cmd.DataLen = 0
*cmd = ISCSICommand{}
commandPool.Put(cmd)
}
// getBuffer gets a buffer from the pool
func getBuffer() []byte {
return *bufferPool.Get().(*[]byte)
}
// putBuffer puts a buffer back to the pool
func putBuffer(buf []byte) {
if cap(buf) >= BHS_SIZE {
bufferPool.Put(&buf)
}
}
// getIOBuffer gets a NUMA-aware I/O buffer for larger data operations
func getIOBuffer(size int) []byte {
initNUMAPool()
if size <= numaBufferPool.GetConfig().BufferSize {
return numaBufferPool.Get()[:size]
}
return make([]byte, size)
}
// putIOBuffer puts a NUMA-aware I/O buffer back to the pool
func putIOBuffer(buf []byte) {
if numaBufferPool != nil && cap(buf) >= numaBufferPool.GetConfig().BufferSize {
numaBufferPool.Put(buf)
}
}
// NUMAStats returns NUMA buffer pool statistics
func NUMAStats() numa.PoolStats {
if numaBufferPool == nil {
return numa.PoolStats{}
}
return numaBufferPool.Stats()
}
type OpCode int
const (
@@ -164,6 +253,8 @@ func (cmd *ISCSICommand) Bytes() []byte {
return cmd.scsiTMFRespBytes()
case OpReady:
return cmd.r2tRespBytes()
case OpAsync:
return cmd.asyncMsgBytes()
}
return nil
}
@@ -237,7 +328,7 @@ func parseHeader(data []byte) (*ISCSICommand, error) {
m.CmdSN = uint32(ParseUint(data[24:28]))
m.Read = data[1]&0x40 == 0x40
m.Write = data[1]&0x20 == 0x20
m.CDB = data[32:48]
m.CDB = append([]byte{}, data[32:48]...)
m.ExpStatSN = uint32(ParseUint(data[28:32]))
m.SCSIOpCode = m.CDB[0]
SCSIOpcode := api.SCSICommandType(m.SCSIOpCode)
@@ -290,9 +381,12 @@ func parseHeader(data []byte) (*ISCSICommand, error) {
}
func (m *ISCSICommand) scsiCmdRespBytes() []byte {
// rfc7143 11.4
buf := bytes.Buffer{}
buf.WriteByte(byte(OpSCSIResp))
// rfc7143 11.4 - BHS 48 bytes + data (4-byte aligned)
rawDataLen := len(m.RawData)
padding := (4 - rawDataLen%4) % 4
buf := make([]byte, 48+rawDataLen+padding)
buf[0] = byte(OpSCSIResp)
var flag byte = 0x80
if m.Resid > 0 {
if m.Resid > m.ExpectedDataLen {
@@ -301,50 +395,46 @@ func (m *ISCSICommand) scsiCmdRespBytes() []byte {
flag |= 0x02
}
}
buf.WriteByte(flag)
buf.WriteByte(byte(m.SCSIResponse))
buf.WriteByte(byte(m.Status))
buf[1] = flag
buf[2] = byte(m.SCSIResponse)
buf[3] = byte(m.Status)
buf.WriteByte(0x00)
buf.Write(util.MarshalUint64(uint64(len(m.RawData)))[5:]) // 5-8
// Skip through to byte 16
for i := 0; i < 8; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.TaskTag))[4:])
for i := 0; i < 4; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.StatSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.ExpCmdSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.MaxCmdSN))[4:])
for i := 0; i < 2*4; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.Resid))[4:])
buf.Write(m.RawData)
dl := len(m.RawData)
for dl%4 > 0 {
dl++
buf.WriteByte(0x00)
}
// byte 4 is reserved (0)
// Write data length (24-bit big-endian) at bytes 5-7
buf[5] = byte(rawDataLen >> 16)
buf[6] = byte(rawDataLen >> 8)
buf[7] = byte(rawDataLen)
// bytes 9-15 are reserved (0)
// TaskTag at bytes 16-19 (32-bit big-endian)
util.MarshalUint32To(buf[16:], m.TaskTag)
// bytes 20-23 are reserved (0)
// StatSN at bytes 24-27
util.MarshalUint32To(buf[24:], m.StatSN)
// ExpCmdSN at bytes 28-31
util.MarshalUint32To(buf[28:], m.ExpCmdSN)
// MaxCmdSN at bytes 32-35
util.MarshalUint32To(buf[32:], m.MaxCmdSN)
// bytes 36-43 are reserved (0)
// Resid at bytes 44-47
util.MarshalUint32To(buf[44:], m.Resid)
copy(buf[48:], m.RawData)
// padding bytes are already zero
return buf.Bytes()
return buf
}
func (m *ISCSICommand) dataInBytes() []byte {
// rfc7143 11.7
dl := m.DataLen
for dl%4 > 0 {
dl++
}
var buf = make([]byte, (48 + dl))
// Calculate padded length using bit operation instead of loop
dl := (m.DataLen + 3) &^ 3 // Round up to multiple of 4
buf := make([]byte, 48+dl)
buf[0] = byte(OpSCSIIn)
var flag byte
if m.FinalInSeq || m.Final == true {
if m.FinalInSeq || m.Final {
flag |= 0x80
}
if m.HasStatus && m.Final == true {
if m.HasStatus && m.Final {
flag |= 0x01
}
log.Debugf("resid: %v, ExpectedDataLen: %v", m.Resid, m.ExpectedDataLen)
@@ -356,22 +446,22 @@ func (m *ISCSICommand) dataInBytes() []byte {
}
}
buf[1] = flag
//buf.WriteByte(0x00)
if m.HasStatus && m.Final == true {
flag = byte(m.Status)
if m.HasStatus && m.Final {
buf[3] = byte(m.Status)
}
//buf.WriteByte(flag)
buf[3] = flag
copy(buf[5:], util.MarshalUint64(uint64(m.DataLen))[5:])
// Data length (24-bit) at bytes 5-7
buf[5] = byte(m.DataLen >> 16)
buf[6] = byte(m.DataLen >> 8)
buf[7] = byte(m.DataLen)
// Skip through to byte 16 Since A bit is not set 11.7.4
copy(buf[16:], util.MarshalUint32(m.TaskTag))
copy(buf[24:], util.MarshalUint32(m.StatSN))
copy(buf[28:], util.MarshalUint32(m.ExpCmdSN))
copy(buf[32:], util.MarshalUint32(m.MaxCmdSN))
copy(buf[36:], util.MarshalUint32(m.DataSN))
copy(buf[40:], util.MarshalUint32(m.BufferOffset))
copy(buf[44:], util.MarshalUint32(m.Resid))
if m.ExpectedDataLen != 0 {
util.MarshalUint32To(buf[16:], m.TaskTag)
util.MarshalUint32To(buf[24:], m.StatSN)
util.MarshalUint32To(buf[28:], m.ExpCmdSN)
util.MarshalUint32To(buf[32:], m.MaxCmdSN)
util.MarshalUint32To(buf[36:], m.DataSN)
util.MarshalUint32To(buf[40:], m.BufferOffset)
util.MarshalUint32To(buf[44:], m.Resid)
if m.DataLen != 0 {
copy(buf[48:], m.RawData[m.BufferOffset:m.BufferOffset+uint32(m.DataLen)])
}
@@ -379,8 +469,13 @@ func (m *ISCSICommand) dataInBytes() []byte {
}
func (m *ISCSICommand) textRespBytes() []byte {
buf := bytes.Buffer{}
buf.WriteByte(byte(OpTextResp))
// Pre-calculate required capacity: BHS(48 bytes) + data (4-byte aligned)
dataLen := len(m.RawData)
padding := (4 - dataLen%4) % 4
buf := make([]byte, 48+dataLen+padding)
buf[0] = byte(OpTextResp)
var b byte
if m.Final {
b |= 0x80
@@ -389,122 +484,149 @@ func (m *ISCSICommand) textRespBytes() []byte {
b |= 0x40
}
// byte 1
buf.WriteByte(b)
buf[1] = b
b = 0
buf.WriteByte(b)
buf.WriteByte(b)
buf.WriteByte(b)
buf.Write(util.MarshalUint64(uint64(len(m.RawData)))[5:]) // 5-8
// Skip through to byte 12
for i := 0; i < 2*4; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.TaskTag))[4:])
for i := 0; i < 4; i++ {
buf.WriteByte(0xff)
}
buf.Write(util.MarshalUint64(uint64(m.StatSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.ExpCmdSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.MaxCmdSN))[4:])
for i := 0; i < 3*4; i++ {
buf.WriteByte(0x00)
}
rd := m.RawData
for len(rd)%4 != 0 {
rd = append(rd, 0)
}
buf.Write(rd)
return buf.Bytes()
// bytes 2,3,4 reserved (0)
// bytes 5-8: data segment length (24-bit)
buf[5] = byte(dataLen >> 16)
buf[6] = byte(dataLen >> 8)
buf[7] = byte(dataLen)
// bytes 8-15 are reserved (0)
// bytes 16-19: TaskTag
util.MarshalUint32To(buf[16:], m.TaskTag)
// bytes 20-23: 0xffffffff
buf[20] = 0xff
buf[21] = 0xff
buf[22] = 0xff
buf[23] = 0xff
// bytes 24-27: StatSN
util.MarshalUint32To(buf[24:], m.StatSN)
// bytes 28-31: ExpCmdSN
util.MarshalUint32To(buf[28:], m.ExpCmdSN)
// bytes 32-35: MaxCmdSN
util.MarshalUint32To(buf[32:], m.MaxCmdSN)
// bytes 36-47 are reserved (0)
// Copy data
copy(buf[48:], m.RawData)
// padding bytes are already zero
return buf
}
func (m *ISCSICommand) noopInBytes() []byte {
buf := bytes.Buffer{}
buf.WriteByte(byte(OpNoopIn))
var b byte
b |= 0x80
// byte 1
buf.WriteByte(b)
// rfc7143 11.11 - BHS 48 bytes + data (4-byte aligned)
rawDataLen := len(m.RawData)
padding := (4 - rawDataLen%4) % 4
buf := make([]byte, 48+rawDataLen+padding)
b = 0
buf.WriteByte(b)
buf.WriteByte(b)
buf.WriteByte(b)
buf.Write(util.MarshalUint64(uint64(len(m.RawData)))[5:]) // 5-8
// Skip through to byte 12
for i := 0; i < 2*4; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.TaskTag))[4:])
for i := 0; i < 4; i++ {
buf.WriteByte(0xff)
}
buf.Write(util.MarshalUint64(uint64(m.StatSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.ExpCmdSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.MaxCmdSN))[4:])
for i := 0; i < 3*4; i++ {
buf.WriteByte(0x00)
}
rd := m.RawData
for len(rd)%4 != 0 {
rd = append(rd, 0)
}
buf.Write(rd)
return buf.Bytes()
buf[0] = byte(OpNoopIn)
buf[1] = 0x80
// bytes 2-3 are reserved (0)
// bytes 4-7: data segment length (32-bit)
util.MarshalUint32To(buf[4:], uint32(rawDataLen))
// bytes 8-15 are reserved (0)
// bytes 16-19: TaskTag
util.MarshalUint32To(buf[16:], m.TaskTag)
// bytes 20-23: 0xffffffff
buf[20] = 0xff
buf[21] = 0xff
buf[22] = 0xff
buf[23] = 0xff
// bytes 24-27: StatSN
util.MarshalUint32To(buf[24:], m.StatSN)
// bytes 28-31: ExpCmdSN
util.MarshalUint32To(buf[28:], m.ExpCmdSN)
// bytes 32-35: MaxCmdSN
util.MarshalUint32To(buf[32:], m.MaxCmdSN)
// bytes 36-47 are reserved (0)
copy(buf[48:], m.RawData)
// padding bytes are already zero
return buf
}
func (m *ISCSICommand) scsiTMFRespBytes() []byte {
// rfc7143 11.6
buf := bytes.Buffer{}
buf.WriteByte(byte(OpSCSITaskResp))
buf.WriteByte(0x80)
buf.WriteByte(m.Result)
buf.WriteByte(0x00)
// Skip through to byte 16
for i := 0; i < 3*4; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.TaskTag))[4:])
for i := 0; i < 4; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.StatSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.ExpCmdSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.MaxCmdSN))[4:])
for i := 0; i < 3*4; i++ {
buf.WriteByte(0x00)
}
return buf.Bytes()
// rfc7143 11.6 - Fixed 48 bytes
buf := make([]byte, 48)
buf[0] = byte(OpSCSITaskResp)
buf[1] = 0x80
buf[2] = m.Result
// byte 3 is reserved (0)
// bytes 4-15 are reserved (0)
// bytes 16-19: TaskTag
util.MarshalUint32To(buf[16:], m.TaskTag)
// bytes 20-23 are reserved (0)
// bytes 24-27: StatSN
util.MarshalUint32To(buf[24:], m.StatSN)
// bytes 28-31: ExpCmdSN
util.MarshalUint32To(buf[28:], m.ExpCmdSN)
// bytes 32-35: MaxCmdSN
util.MarshalUint32To(buf[32:], m.MaxCmdSN)
// bytes 36-47 are reserved (0)
return buf
}
func (m *ISCSICommand) r2tRespBytes() []byte {
// rfc7143 11.8
buf := bytes.Buffer{}
buf.WriteByte(byte(OpReady))
var b byte
// rfc7143 11.8 - Fixed 48 bytes
buf := make([]byte, 48)
buf[0] = byte(OpReady)
if m.Final {
b |= 0x80
buf[1] = 0x80
}
buf.WriteByte(b)
buf.WriteByte(0x00)
buf.WriteByte(0x00)
// Skip through to byte 16
for i := 0; i < 3*4; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.TaskTag))[4:])
for i := 0; i < 4; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.StatSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.ExpCmdSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.MaxCmdSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.R2TSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.BufferOffset))[4:])
buf.Write(util.MarshalUint64(uint64(m.DesiredLength))[4:])
return buf.Bytes()
// bytes 2-15 are reserved (0)
// bytes 16-19: TaskTag
util.MarshalUint32To(buf[16:], m.TaskTag)
// bytes 20-23 are reserved (0)
// bytes 24-27: StatSN
util.MarshalUint32To(buf[24:], m.StatSN)
// bytes 28-31: ExpCmdSN
util.MarshalUint32To(buf[28:], m.ExpCmdSN)
// bytes 32-35: MaxCmdSN
util.MarshalUint32To(buf[32:], m.MaxCmdSN)
// bytes 36-39: R2TSN
util.MarshalUint32To(buf[36:], m.R2TSN)
// bytes 40-43: BufferOffset
util.MarshalUint32To(buf[40:], m.BufferOffset)
// bytes 44-47: DesiredLength
util.MarshalUint32To(buf[44:], m.DesiredLength)
return buf
}
// asyncMsgBytes implements RFC 7143 section 11.10 - Asynchronous Message
func (m *ISCSICommand) asyncMsgBytes() []byte {
// rfc7143 11.10 - BHS 48 bytes + data (4-byte aligned)
rawDataLen := len(m.RawData)
padding := (4 - rawDataLen%4) % 4
buf := make([]byte, 48+rawDataLen+padding)
buf[0] = byte(OpAsync)
// byte 1: AsyncEvent in bits 0-4
buf[1] = byte(m.SCSIOpCode & 0x1f)
// bytes 2-3 are reserved (0)
// byte 4: 0x80 if AsyncEvent is 0 (SCSI Asynchronous Event)
if m.SCSIOpCode == 0 {
buf[4] = 0x80
}
// bytes 5-7: data segment length (24-bit)
buf[5] = byte(rawDataLen >> 16)
buf[6] = byte(rawDataLen >> 8)
buf[7] = byte(rawDataLen)
// bytes 8-15: LUN (if applicable)
copy(buf[8:], m.LUN[:])
// bytes 16-19: Reserved (0)
// bytes 20-23: Target Transfer Tag (0xffffffff for Async)
buf[20] = 0xff
buf[21] = 0xff
buf[22] = 0xff
buf[23] = 0xff
// bytes 24-27: StatSN
util.MarshalUint32To(buf[24:], m.StatSN)
// bytes 28-31: ExpCmdSN
util.MarshalUint32To(buf[28:], m.ExpCmdSN)
// bytes 32-35: MaxCmdSN
util.MarshalUint32To(buf[32:], m.MaxCmdSN)
// bytes 36-43: Reserved (0)
// bytes 44-47: Parameter1 and Parameter2 (context-specific)
copy(buf[48:], m.RawData)
return buf
}

146
pkg/port/iscsit/cmd_test.go Normal file
View File

@@ -0,0 +1,146 @@
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package iscsit
import (
"testing"
"unsafe"
)
func TestGetPutCommand(t *testing.T) {
// Test getting command object
cmd1 := getCommand()
if cmd1 == nil {
t.Fatal("getCommand() returned nil")
}
// Set some values
cmd1.TaskTag = 12345
cmd1.DataLen = 100
cmd1.ExpCmdSN = 999
// Put back to pool
putCommand(cmd1)
// Get again, verify if reused (may be reset)
cmd2 := getCommand()
if cmd2 == nil {
t.Fatal("getCommand() returned nil after put")
}
// Put back
putCommand(cmd2)
// Test nil doesn't panic
putCommand(nil)
}
func TestGetPutBuffer(t *testing.T) {
// Test getting buffer
buf1 := getBuffer()
if buf1 == nil {
t.Fatal("getBuffer() returned nil")
}
if len(buf1) != BHS_SIZE {
t.Errorf("expected buffer size %d, got %d", BHS_SIZE, len(buf1))
}
// Modify buffer content
for i := range buf1 {
buf1[i] = byte(i % 256)
}
// Put back to pool
putBuffer(buf1)
// Get again
buf2 := getBuffer()
if buf2 == nil {
t.Fatal("getBuffer() returned nil after put")
}
if len(buf2) != BHS_SIZE {
t.Errorf("expected buffer size %d, got %d", BHS_SIZE, len(buf2))
}
putBuffer(buf2)
// Test small buffer won't be put into pool
smallBuf := make([]byte, 10)
putBuffer(smallBuf) // Should not panic
// Test nil doesn't panic
putBuffer(nil)
}
func TestBufferPoolReuse(t *testing.T) {
// Get buffer and record address
buf1 := getBuffer()
ptr1 := uintptr(unsafe.Pointer(&buf1[0]))
putBuffer(buf1)
// Get again, verify if reuse is possible (not guaranteed)
buf2 := getBuffer()
ptr2 := uintptr(unsafe.Pointer(&buf2[0]))
putBuffer(buf2)
// If reused, addresses should be the same
// If not reused, it's fine, this is a performance test
t.Logf("First buffer pointer: %x, Second buffer pointer: %x, reused: %v",
ptr1, ptr2, ptr1 == ptr2)
}
func BenchmarkGetPutCommand(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cmd := getCommand()
cmd.TaskTag = 1
putCommand(cmd)
}
})
}
func BenchmarkGetPutBuffer(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
buf := getBuffer()
buf[0] = 1
putBuffer(buf)
}
})
}
// BenchmarkAllocCommand 对比:不使用 pool 直接创建
func BenchmarkAllocCommand(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cmd := &ISCSICommand{}
cmd.TaskTag = 1
_ = cmd
}
})
}
// BenchmarkAllocBuffer 对比:不使用 pool 直接创建
func BenchmarkAllocBuffer(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
buf := make([]byte, BHS_SIZE)
buf[0] = 1
_ = buf
}
})
}

View File

@@ -155,12 +155,12 @@ func (conn *iscsiConnection) buildRespPackage(oc OpCode, task *iscsiTask) error
if task == nil {
task = conn.rxTask
}
conn.resp = &ISCSICommand{
StartTime: conn.req.StartTime,
StatSN: conn.req.ExpStatSN,
TaskTag: conn.req.TaskTag,
ExpectedDataLen: conn.req.ExpectedDataLen,
}
// Get ISCSICommand from object pool
conn.resp = getCommand()
conn.resp.StartTime = conn.req.StartTime
conn.resp.StatSN = conn.req.ExpStatSN
conn.resp.TaskTag = conn.req.TaskTag
conn.resp.ExpectedDataLen = conn.req.ExpectedDataLen
if conn.session != nil {
conn.resp.ExpCmdSN = conn.session.ExpCmdSN
conn.resp.MaxCmdSN = conn.session.ExpCmdSN + conn.session.MaxQueueCommand

View File

@@ -44,6 +44,86 @@ const (
STATE_TERMINATE
)
// tsihBitmap is a bitmap for efficient TSIH allocation/deallocation
// Uses circular counter for O(1) allocation
type tsihBitmap struct {
mu sync.Mutex
bitmap []uint64 // Each uint64 stores the usage status of 64 TSIHs
next uint16 // Next candidate position for allocation
used uint16 // Number of used TSIHs
}
// newTSIHBitmap creates a new TSIH bitmap
// Reserves 0 and 65535 as special values
func newTSIHBitmap() *tsihBitmap {
// Need 65536 bits = 1024 uint64s
b := &tsihBitmap{
bitmap: make([]uint64, 1024),
next: 1, // Start from 1, 0 is reserved
}
// Mark 0 and 65535 as used (reserved values)
b.bitmap[0] |= 1 << 0 // TSIH = 0
b.bitmap[1023] |= 1 << 63 // TSIH = 65535
b.used = 2
return b
}
// alloc allocates an available TSIH using circular search strategy
func (b *tsihBitmap) alloc() uint16 {
b.mu.Lock()
defer b.mu.Unlock()
if b.used >= ISCSI_MAX_TSIH-1 {
return ISCSI_UNSPEC_TSIH
}
start := b.next
for {
idx := b.next / 64
bit := b.next % 64
if (b.bitmap[idx] & (1 << bit)) == 0 {
// Found free slot
b.bitmap[idx] |= 1 << bit
b.used++
result := b.next
// Update next to next position
b.next++
if b.next >= ISCSI_MAX_TSIH {
b.next = 1
}
return result
}
b.next++
if b.next >= ISCSI_MAX_TSIH {
b.next = 1
}
if b.next == start {
// Looped around without finding
return ISCSI_UNSPEC_TSIH
}
}
}
// release releases a TSIH
func (b *tsihBitmap) release(tsih uint16) {
if tsih == 0 || tsih == ISCSI_MAX_TSIH {
return // Cannot release reserved values
}
b.mu.Lock()
defer b.mu.Unlock()
idx := tsih / 64
bit := tsih % 64
if (b.bitmap[idx] & (1 << bit)) != 0 {
b.bitmap[idx] &^= 1 << bit
b.used--
}
}
var (
EnableStats bool
CurrentHostIP string
@@ -54,8 +134,7 @@ type ISCSITargetDriver struct {
SCSI *scsi.SCSITargetService
Name string
iSCSITargets map[string]*ISCSITarget
TSIHPool map[uint16]bool
TSIHPoolMutex sync.Mutex
tsihBitmap *tsihBitmap
isClientConnected bool
enableStats bool
mu *sync.RWMutex
@@ -76,7 +155,7 @@ func NewISCSITargetDriver(base *scsi.SCSITargetService) (scsi.SCSITargetDriver,
Name: iSCSIDriverName,
iSCSITargets: map[string]*ISCSITarget{},
SCSI: base,
TSIHPool: map[uint16]bool{0: true, 65535: true},
tsihBitmap: newTSIHBitmap(),
mu: &sync.RWMutex{},
}
@@ -88,24 +167,11 @@ func NewISCSITargetDriver(base *scsi.SCSITargetService) (scsi.SCSITargetDriver,
}
func (s *ISCSITargetDriver) AllocTSIH() uint16 {
var i uint16
s.TSIHPoolMutex.Lock()
for i = uint16(0); i < ISCSI_MAX_TSIH; i++ {
exist := s.TSIHPool[i]
if !exist {
s.TSIHPool[i] = true
s.TSIHPoolMutex.Unlock()
return i
}
}
s.TSIHPoolMutex.Unlock()
return ISCSI_UNSPEC_TSIH
return s.tsihBitmap.alloc()
}
func (s *ISCSITargetDriver) ReleaseTSIH(tsih uint16) {
s.TSIHPoolMutex.Lock()
delete(s.TSIHPool, tsih)
s.TSIHPoolMutex.Unlock()
s.tsihBitmap.release(tsih)
}
func (s *ISCSITargetDriver) NewTarget(tgtName string, configInfo *config.Config) error {
@@ -122,9 +188,9 @@ func (s *ISCSITargetDriver) NewTarget(tgtName string, configInfo *config.Config)
targetConfig := configInfo.ISCSITargets[tgtName]
for tpgt, portalIDArrary := range targetConfig.TPGTs {
tpgtNumber, _ := strconv.ParseUint(tpgt, 10, 16)
tgt.TPGTs[uint16(tpgtNumber)] = &iSCSITPGT{uint16(tpgtNumber), make(map[string]struct{})}
tgt.TPGTs[uint16(tpgtNumber)] = &iSCSITPGT{TPGT: uint16(tpgtNumber), Portals: make(map[string]struct{})}
targetPortName := fmt.Sprintf("%s,t,0x%02x", tgtName, tpgtNumber)
scsiTPG.TargetPortGroup = append(scsiTPG.TargetPortGroup, &api.SCSITargetPort{uint16(tpgtNumber), targetPortName})
scsiTPG.TargetPortGroup = append(scsiTPG.TargetPortGroup, &api.SCSITargetPort{RelativeTargetPortID: uint16(tpgtNumber), TargetPortName: targetPortName})
for _, portalID := range portalIDArrary {
portal := configInfo.ISCSIPortals[portalID]
s.AddiSCSIPortal(tgtName, uint16(tpgtNumber), portal.Portal)
@@ -323,10 +389,12 @@ func (s *ISCSITargetDriver) rxHandler(conn *iscsiConnection) {
ddigest uint = 0
final bool = false
cmd *ISCSICommand
buf []byte = make([]byte, BHS_SIZE)
buf []byte = getBuffer()
length int
err error
)
defer putBuffer(buf)
conn.readLock.Lock()
defer conn.readLock.Unlock()
if conn.state == CONN_STATE_SCSI {
@@ -366,10 +434,10 @@ func (s *ISCSITargetDriver) rxHandler(conn *iscsiConnection) {
}
final = true
case IOSTATE_RX_INIT_AHS:
conn.rxIOState = IOSTATE_RX_DATA
break
if hdigest != 0 {
conn.rxIOState = IOSTATE_RX_INIT_HDIGEST
} else {
conn.rxIOState = IOSTATE_RX_DATA
}
case IOSTATE_RX_DATA:
if ddigest != 0 {
@@ -563,6 +631,92 @@ func iscsiExecNoopOut(conn *iscsiConnection) error {
return conn.buildRespPackage(OpNoopIn, nil)
}
// SNACK Type constants per RFC 7143
const (
SNACK_TYPE_DATA_ACK = 0 // Data ACK
SNACK_TYPE_STATUS_ACK = 1 // Status ACK
SNACK_TYPE_DATA_R2T = 2 // Data R2T
SNACK_TYPE_R_DATA = 3 // R-Data
)
/*
* iscsiExecSNACK handles SNACK (Sequence Number Acknowledgement) requests
* SNACK is used for error recovery in iSCSI protocol per RFC 7143 section 11.9
*/
func (s *ISCSITargetDriver) iscsiExecSNACK(conn *iscsiConnection) error {
req := conn.req
// Parse SNACK type from byte 1, bits 0-1
snackType := (req.SCSIOpCode >> 0) & 0x03
// Parse BegRun and RunLength from the header
begRun := req.ReferencedTaskTag
runLength := req.R2TSN
log.Debugf("SNACK request type=%d, BegRun=%d, RunLength=%d", snackType, begRun, runLength)
switch snackType {
case SNACK_TYPE_DATA_ACK:
// Data ACK - initiator acknowledges receipt of Data-In PDUs
// For ErrorRecoveryLevel >= 1, we could track acknowledged Data-In
log.Debug("SNACK Data ACK received")
// Simply return success for now
conn.resp = &ISCSICommand{
OpCode: OpNoopIn,
Final: true,
TaskTag: req.TaskTag,
StatSN: conn.statSN,
ExpCmdSN: conn.expCmdSN,
}
if conn.session != nil {
conn.resp.MaxCmdSN = conn.session.ExpCmdSN + conn.session.MaxQueueCommand
}
return nil
case SNACK_TYPE_STATUS_ACK:
// Status ACK - initiator acknowledges receipt of status
log.Debug("SNACK Status ACK received")
// Similar to Data ACK, just acknowledge
conn.resp = &ISCSICommand{
OpCode: OpNoopIn,
Final: true,
TaskTag: req.TaskTag,
StatSN: conn.statSN,
ExpCmdSN: conn.expCmdSN,
}
if conn.session != nil {
conn.resp.MaxCmdSN = conn.session.ExpCmdSN + conn.session.MaxQueueCommand
}
return nil
case SNACK_TYPE_DATA_R2T:
// Data R2T - request retransmission of R2T
log.Debug("SNACK Data R2T received - requesting R2T retransmission")
// Find the task and resend R2T
conn.session.PendingTasksMutex.RLock()
task := conn.session.PendingTasks.GetByTag(begRun)
conn.session.PendingTasksMutex.RUnlock()
if task == nil {
log.Errorf("Cannot find task for R2T retransmission, tag=%d", begRun)
return fmt.Errorf("task not found")
}
// Reset R2T state and resend
task.r2tSN = runLength
conn.rxTask = task
return iscsiExecR2T(conn)
case SNACK_TYPE_R_DATA:
// R-Data - request retransmission of Data-In
log.Debug("SNACK R-Data received - requesting Data-In retransmission")
// For now, reject this as it requires complex data buffering
// In a full implementation, we would need to buffer Data-In PDUs
// and retransmit based on BegRun and RunLength
log.Warn("R-Data SNACK not fully implemented")
return fmt.Errorf("R-Data SNACK not supported")
default:
return fmt.Errorf("unknown SNACK type: %d", snackType)
}
}
func iscsiExecReject(conn *iscsiConnection) error {
return conn.buildRespPackage(OpReject, nil)
}
@@ -852,10 +1006,16 @@ func (s *ISCSITargetDriver) scsiCommandHandler(conn *iscsiConnection) (err error
conn.txTask = &iscsiTask{conn: conn, cmd: conn.req, tag: conn.req.TaskTag}
conn.txIOState = IOSTATE_TX_BHS
iscsiExecLogout(conn)
case OpTextReq, OpSNACKReq:
case OpTextReq:
err = fmt.Errorf("Cannot handle yet %s", opCodeMap[conn.req.OpCode])
log.Error(err)
return
case OpSNACKReq:
log.Debug("SNACK Request processing...")
if err := s.iscsiExecSNACK(conn); err != nil {
log.Errorf("SNACK handling failed: %v", err)
iscsiExecReject(conn)
}
default:
err = fmt.Errorf("Unknown op %s", opCodeMap[conn.req.OpCode])
log.Error(err)
@@ -900,22 +1060,20 @@ func (s *ISCSITargetDriver) iscsiTaskQueueHandler(task *iscsiTask) error {
task.state = taskSCSI
sess.PendingTasksMutex.Unlock()
goto retry
} else {
if cmd.CmdSN < sess.ExpCmdSN {
err := fmt.Errorf("unexpected cmd serial number: (%d, %d)", cmd.CmdSN, sess.ExpCmdSN)
log.Error(err)
return err
}
log.Debugf("add task(%d) into task queue", task.cmd.CmdSN)
// add this task into queue and set it as a pending task
sess.PendingTasksMutex.Lock()
task.state = taskPending
sess.PendingTasks.Push(task)
sess.PendingTasksMutex.Unlock()
return fmt.Errorf("pending")
}
return nil
// cmd.CmdSN != sess.ExpCmdSN
if cmd.CmdSN < sess.ExpCmdSN {
err := fmt.Errorf("unexpected cmd serial number: (%d, %d)", cmd.CmdSN, sess.ExpCmdSN)
log.Error(err)
return err
}
log.Debugf("add task(%d) into task queue", task.cmd.CmdSN)
// add this task into queue and set it as a pending task
sess.PendingTasksMutex.Lock()
task.state = taskPending
sess.PendingTasks.Push(task)
sess.PendingTasksMutex.Unlock()
return fmt.Errorf("pending")
}
func (s *ISCSITargetDriver) iscsiExecTask(task *iscsiTask) error {
@@ -972,6 +1130,63 @@ func (s *ISCSITargetDriver) iscsiExecTask(task *iscsiTask) error {
return nil
}
// Async Event types per RFC 7143
const (
ASYNC_EVENT_SCSI = 0 // SCSI Asynchronous Event
ASYNC_EVENT_STATUS = 1 // iSCSI Status Update
ASYNC_EVENT_LOGOUT = 2 // iSCSI Logout Request
ASYNC_EVENT_DROP_CONN = 3 // iSCSI Drop Connection
ASYNC_EVENT_DROP_SESS = 4 // iSCSI Drop All Connections
ASYNC_EVENT_NOP = 5 // iSCSI NOP
ASYNC_EVENT_VENDOR = 255 // Vendor Specific Event
)
/*
* SendAsyncMessage sends an asynchronous message to the initiator
* This implements RFC 7143 section 11.10 Asynchronous Message
*/
func (s *ISCSITargetDriver) SendAsyncMessage(conn *iscsiConnection, eventType byte, lun [8]uint8, param1, param2 uint32, data []byte) error {
if conn == nil || conn.state != CONN_STATE_SCSI {
return fmt.Errorf("connection not ready for async message")
}
conn.statSN += 1
conn.resp = &ISCSICommand{
OpCode: OpAsync,
SCSIOpCode: eventType,
Final: true,
LUN: lun,
StatSN: conn.statSN,
ExpCmdSN: conn.expCmdSN,
RawData: data,
}
if conn.session != nil {
conn.resp.MaxCmdSN = conn.session.ExpCmdSN + conn.session.MaxQueueCommand
}
// Parameter1 and Parameter2 are encoded in RawData or could be stored in ISCSICommand
// For simplicity, we encode them at the start of RawData if not already present
if len(data) == 0 && (param1 != 0 || param2 != 0) {
conn.resp.RawData = make([]byte, 8)
copy(conn.resp.RawData[0:4], util.MarshalUint32(param1))
copy(conn.resp.RawData[4:8], util.MarshalUint32(param2))
}
log.Debugf("Sending Async message type=%d to initiator", eventType)
s.handler(DATAOUT, conn)
return nil
}
// SendSCSIAsyncEvent sends a SCSI asynchronous event (e.g., LUN reset, storage change)
func (s *ISCSITargetDriver) SendSCSIAsyncEvent(conn *iscsiConnection, lun [8]uint8, eventCode byte) error {
// SCSI Async Event data format:
// bytes 0-1: Event Code
// bytes 2-3: Reserved
// bytes 4+: Event-specific data
data := []byte{eventCode, 0, 0, 0}
return s.SendAsyncMessage(conn, ASYNC_EVENT_SCSI, lun, 0, 0, data)
}
func (s *ISCSITargetDriver) Stats() scsi.Stats {
s.mu.RLock()
stats := s.TargetStats

View File

@@ -0,0 +1,176 @@
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package iscsit
import (
"sync"
"testing"
)
func TestTSIHBitmapAllocRelease(t *testing.T) {
b := newTSIHBitmap()
// Test basic allocation and release
tsih1 := b.alloc()
if tsih1 == ISCSI_UNSPEC_TSIH {
t.Fatal("failed to allocate first TSIH")
}
if tsih1 != 1 {
t.Errorf("expected first TSIH to be 1, got %d", tsih1)
}
tsih2 := b.alloc()
if tsih2 == ISCSI_UNSPEC_TSIH {
t.Fatal("failed to allocate second TSIH")
}
if tsih2 != 2 {
t.Errorf("expected second TSIH to be 2, got %d", tsih2)
}
// Release first
b.release(tsih1)
// Note: TSIH bitmap uses circular allocation, next pointer won't return to released positions
// This is to avoid concurrency issues, subsequent allocations continue from current next
tsih3 := b.alloc()
if tsih3 == ISCSI_UNSPEC_TSIH {
t.Error("failed to allocate after release")
}
// Verify tsih1 can be reallocated (at some point)
if tsih3 == tsih1 || tsih3 == tsih2 {
t.Logf("TSIH was recycled immediately: released %d, got %d", tsih1, tsih3)
}
// Release all
b.release(tsih2)
b.release(tsih3)
}
func TestTSIHBitmapReservedValues(t *testing.T) {
b := newTSIHBitmap()
// Test reserved values cannot be allocated
// 0 and 65535 are reserved values
for i := 0; i < 10; i++ {
tsih := b.alloc()
if tsih == 0 {
t.Error("allocated reserved TSIH 0")
}
if tsih == ISCSI_MAX_TSIH {
t.Error("allocated reserved TSIH 65535")
}
if tsih == ISCSI_UNSPEC_TSIH {
break
}
b.release(tsih)
}
// Test releasing reserved values doesn't panic
b.release(0)
b.release(ISCSI_MAX_TSIH)
}
func TestTSIHBitmapExhaustion(t *testing.T) {
b := newTSIHBitmap()
// Allocate many TSIHs
allocated := make([]uint16, 0, 100)
for i := 0; i < 100; i++ {
tsih := b.alloc()
if tsih == ISCSI_UNSPEC_TSIH {
t.Fatalf("failed to allocate TSIH at iteration %d", i)
}
allocated = append(allocated, tsih)
}
// 释放所有
for _, tsih := range allocated {
b.release(tsih)
}
// Reallocate, should succeed
for i := 0; i < 100; i++ {
tsih := b.alloc()
if tsih == ISCSI_UNSPEC_TSIH {
t.Fatalf("failed to reallocate TSIH at iteration %d", i)
}
b.release(tsih)
}
}
func TestTSIHBitmapConcurrency(t *testing.T) {
b := newTSIHBitmap()
const numGoroutines = 100
const allocsPerGoroutine = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
allTSIHs := make(chan uint16, numGoroutines*allocsPerGoroutine)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
for j := 0; j < allocsPerGoroutine; j++ {
tsih := b.alloc()
if tsih != ISCSI_UNSPEC_TSIH {
allTSIHs <- tsih
}
}
}()
}
wg.Wait()
close(allTSIHs)
// Check no duplicate TSIHs
seen := make(map[uint16]bool)
for tsih := range allTSIHs {
if seen[tsih] {
t.Errorf("TSIH %d was allocated more than once", tsih)
}
seen[tsih] = true
}
// Release all allocated TSIHs
for tsih := range seen {
b.release(tsih)
}
}
func BenchmarkTSIHBitmapAlloc(b *testing.B) {
bitmap := newTSIHBitmap()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
tsih := bitmap.alloc()
if tsih != ISCSI_UNSPEC_TSIH {
bitmap.release(tsih)
}
}
})
}
func BenchmarkTSIHBitmapAllocSequential(b *testing.B) {
bitmap := newTSIHBitmap()
b.ResetTimer()
for i := 0; i < b.N; i++ {
tsih := bitmap.alloc()
if tsih != ISCSI_UNSPEC_TSIH {
bitmap.release(tsih)
}
}
}

View File

@@ -79,7 +79,7 @@ type iSCSITPGT struct {
}
type ISCSITarget struct {
api.SCSITarget
*api.SCSITarget
api.SCSITargetDriverCommon
// TPGT number is the key
TPGTs map[uint16]*iSCSITPGT
@@ -123,7 +123,7 @@ func (tgt *ISCSITarget) FindTPG(portal string) (uint16, error) {
func newISCSITarget(target *api.SCSITarget) *ISCSITarget {
return &ISCSITarget{
SCSITarget: *target,
SCSITarget: target,
TPGTs: make(map[uint16]*iSCSITPGT),
Sessions: make(map[uint16]*ISCSISession),
}

View File

@@ -17,7 +17,6 @@ limitations under the License.
package iscsit
import (
"bytes"
"fmt"
"strings"
@@ -26,20 +25,20 @@ import (
var (
iSCSILoginParamTextKV = []util.KeyValue{
{"HeaderDigest", "None"},
{"DataDigest", "None"},
{"ImmediateData", "Yes"},
{"InitialR2T", "Yes"},
{"MaxBurstLength", "262144"},
{"FirstBurstLength", "65536"},
{"MaxRecvDataSegmentLength", "65536"},
{"DefaultTime2Wait", "2"},
{"DefaultTime2Retain", "0"},
{"MaxOutstandingR2T", "1"},
{"IFMarker", "No"},
{"OFMarker", "No"},
{"DataPDUInOrder", "Yes"},
{"DataSequenceInOrder", "Yes"}}
{Key: "HeaderDigest", Value: "None"},
{Key: "DataDigest", Value: "None"},
{Key: "ImmediateData", Value: "Yes"},
{Key: "InitialR2T", Value: "Yes"},
{Key: "MaxBurstLength", Value: "262144"},
{Key: "FirstBurstLength", Value: "65536"},
{Key: "MaxRecvDataSegmentLength", Value: "65536"},
{Key: "DefaultTime2Wait", Value: "2"},
{Key: "DefaultTime2Retain", Value: "0"},
{Key: "MaxOutstandingR2T", Value: "1"},
{Key: "IFMarker", Value: "No"},
{Key: "OFMarker", Value: "No"},
{Key: "DataPDUInOrder", Value: "Yes"},
{Key: "DataSequenceInOrder", Value: "Yes"}}
)
type iSCSILoginStage int
@@ -63,10 +62,10 @@ func (s iSCSILoginStage) String() string {
}
func loginKVDeclare(conn *iscsiConnection, negoKV []util.KeyValue) []util.KeyValue {
negoKV = append(negoKV, util.KeyValue{"TargetPortalGroupTag",
numberKeyInConv(uint(conn.loginParam.tpgt))})
negoKV = append(negoKV, util.KeyValue{"MaxRecvDataSegmentLength",
numberKeyInConv(sessionKeys["MaxRecvDataSegmentLength"].def)})
negoKV = append(negoKV, util.KeyValue{Key: "TargetPortalGroupTag",
Value: numberKeyInConv(uint(conn.loginParam.tpgt))})
negoKV = append(negoKV, util.KeyValue{Key: "MaxRecvDataSegmentLength",
Value: numberKeyInConv(sessionKeys["MaxRecvDataSegmentLength"].def)})
return negoKV
}
@@ -158,14 +157,14 @@ func (conn *iscsiConnection) processLoginData() ([]util.KeyValue, error) {
if uintVal != defSessKey.def {
kvChanges++
}
negoKV = append(negoKV, util.KeyValue{key, defSessKey.inConv(defSessKey.def)})
negoKV = append(negoKV, util.KeyValue{Key: key, Value: defSessKey.inConv(defSessKey.def)})
} else {
if (uintVal >= defSessKey.min) && (uintVal <= defSessKey.max) {
conn.loginParam.sessionParam[defSessKey.idx].Value = uintVal
negoKV = append(negoKV, util.KeyValue{key, defSessKey.inConv(uintVal)})
negoKV = append(negoKV, util.KeyValue{Key: key, Value: defSessKey.inConv(uintVal)})
} else {
// the value out of the acceptable range, Uses target default key
negoKV = append(negoKV, util.KeyValue{key, defSessKey.inConv(defSessKey.def)})
negoKV = append(negoKV, util.KeyValue{Key: key, Value: defSessKey.inConv(defSessKey.def)})
kvChanges++
}
}
@@ -222,10 +221,13 @@ type iscsiLoginParam struct {
}
func (m *ISCSICommand) loginRespBytes() []byte {
// rfc7143 11.13
buf := &bytes.Buffer{}
// byte 0
buf.WriteByte(byte(OpLoginResp))
// rfc7143 11.13 - BHS 48 bytes + data (4-byte aligned)
rawDataLen := len(m.RawData)
padding := (4 - rawDataLen%4) % 4
buf := make([]byte, 48+rawDataLen+padding)
// byte 0: Opcode
buf[0] = byte(OpLoginResp)
var b byte
if m.Transit {
b |= 0x80
@@ -236,33 +238,38 @@ func (m *ISCSICommand) loginRespBytes() []byte {
b |= byte(m.CSG&0xff) << 2
b |= byte(m.NSG & 0xff)
// byte 1
buf.WriteByte(b)
buf[1] = b
b = 0
buf.WriteByte(b) // version-max
buf.WriteByte(b) // version-active
buf.WriteByte(b) // ahsLen
buf.Write(util.MarshalUint64(uint64(len(m.RawData)))[5:]) // data segment length, no padding
buf.Write(util.MarshalUint64(m.ISID)[2:])
buf.Write(util.MarshalUint64(uint64(m.TSIH))[6:])
buf.Write(util.MarshalUint64(uint64(m.TaskTag))[4:])
buf.WriteByte(b)
buf.WriteByte(b)
buf.WriteByte(b)
buf.WriteByte(b) // "reserved"
buf.Write(util.MarshalUint64(uint64(m.StatSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.ExpCmdSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.MaxCmdSN))[4:])
buf.WriteByte(byte(m.StatusClass))
buf.WriteByte(byte(m.StatusDetail))
buf.WriteByte(b)
buf.WriteByte(b) // "reserved"
var bs [8]byte
buf.Write(bs[:])
rd := m.RawData
for len(rd)%4 != 0 {
rd = append(rd, 0)
}
buf.Write(rd)
return buf.Bytes()
// byte 2: version-max, byte 3: version-active
// bytes 4-7: data segment length (24-bit)
buf[5] = byte(rawDataLen >> 16)
buf[6] = byte(rawDataLen >> 8)
buf[7] = byte(rawDataLen)
// bytes 8-13: ISID (6 bytes) - lower 6 bytes of uint64
buf[8] = byte(m.ISID >> 40)
buf[9] = byte(m.ISID >> 32)
buf[10] = byte(m.ISID >> 24)
buf[11] = byte(m.ISID >> 16)
buf[12] = byte(m.ISID >> 8)
buf[13] = byte(m.ISID)
// bytes 14-15: TSIH (2 bytes)
buf[14] = byte(m.TSIH >> 8)
buf[15] = byte(m.TSIH)
// bytes 16-19: TaskTag
util.MarshalUint32To(buf[16:], m.TaskTag)
// bytes 20-23: reserved
// bytes 24-27: StatSN
util.MarshalUint32To(buf[24:], m.StatSN)
// bytes 28-31: ExpCmdSN
util.MarshalUint32To(buf[28:], m.ExpCmdSN)
// bytes 32-35: MaxCmdSN
util.MarshalUint32To(buf[32:], m.MaxCmdSN)
// bytes 36: StatusClass, 37: StatusDetail
buf[36] = byte(m.StatusClass)
buf[37] = byte(m.StatusDetail)
// bytes 38-47: reserved
// Copy data
copy(buf[48:], m.RawData)
// padding bytes are already zero
return buf
}

View File

@@ -1,29 +1,25 @@
package iscsit
import (
"bytes"
"github.com/gostor/gotgt/pkg/util"
)
func (m *ISCSICommand) logoutRespBytes() []byte {
buf := &bytes.Buffer{}
buf.WriteByte(byte(OpLogoutResp))
buf.WriteByte(0x80)
buf.WriteByte(0x00) // response
buf.WriteByte(0x00)
for i := 4; i < 16; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.TaskTag))[4:])
for i := 20; i < 24; i++ {
buf.WriteByte(0x00)
}
buf.Write(util.MarshalUint64(uint64(m.StatSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.ExpCmdSN))[4:])
buf.Write(util.MarshalUint64(uint64(m.MaxCmdSN))[4:])
for i := 36; i < 48; i++ {
buf.WriteByte(0x00)
}
return buf.Bytes()
// rfc7143 11.10 - Fixed 48 bytes
buf := make([]byte, 48)
buf[0] = byte(OpLogoutResp)
buf[1] = 0x80
// buf[2] = response (0)
// bytes 4-15 are reserved (0)
// bytes 16-19: TaskTag
util.MarshalUint32To(buf[16:], m.TaskTag)
// bytes 20-23 are reserved (0)
// bytes 24-27: StatSN
util.MarshalUint32To(buf[24:], m.StatSN)
// bytes 28-31: ExpCmdSN
util.MarshalUint32To(buf[28:], m.ExpCmdSN)
// bytes 32-35: MaxCmdSN
util.MarshalUint32To(buf[32:], m.MaxCmdSN)
// bytes 36-47 are reserved (0)
return buf
}

View File

@@ -0,0 +1,403 @@
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package iscsit
import (
"bytes"
"testing"
"time"
"github.com/gostor/gotgt/pkg/api"
"github.com/gostor/gotgt/pkg/util"
)
// BenchmarkParseHeader benchmarks iSCSI protocol header parsing performance
func BenchmarkParseHeader(b *testing.B) {
// Build a typical SCSI CDB command header
header := make([]byte, BHS_SIZE)
header[0] = byte(OpSCSICmd) // SCSI Command
header[1] = 0x80 // Final bit
header[4] = 0 // AHS length
header[5] = 0
header[6] = 0
header[7] = 0 // Data segment length = 0
// TaskTag at bytes 16-19
header[16] = 0x00
header[17] = 0x00
header[18] = 0x00
header[19] = 0x01
// ExpectedDataLen at bytes 20-23
header[20] = 0x00
header[21] = 0x00
header[22] = 0x10
header[23] = 0x00 // 4096 bytes
// CmdSN at bytes 24-27
header[24] = 0x00
header[25] = 0x00
header[26] = 0x00
header[27] = 0x01
// ExpStatSN at bytes 28-31
header[28] = 0x00
header[29] = 0x00
header[30] = 0x00
header[31] = 0x01
// CDB at bytes 32-47 (READ_10 command)
header[32] = byte(api.READ_10)
header[33] = 0x00
header[34] = 0x00
header[35] = 0x00
header[36] = 0x00
header[37] = 0x00 // LBA = 0
header[38] = 0x00
header[39] = 0x08 // Transfer length = 8 blocks
header[40] = 0x00 // Control
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
cmd, err := parseHeader(header)
if err != nil {
b.Fatal(err)
}
_ = cmd
}
}
// BenchmarkParseHeaderWithPool benchmarks header parsing with object pool
func BenchmarkParseHeaderWithPool(b *testing.B) {
header := make([]byte, BHS_SIZE)
header[0] = byte(OpSCSICmd)
header[1] = 0x80
header[32] = byte(api.READ_10)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
cmd := getCommand()
cmd.OpCode = OpCode(header[0] & ISCSI_OPCODE_MASK)
cmd.Final = 0x80&header[1] == 0x80
cmd.AHSLen = int(header[4]) * 4
cmd.DataLen = int(ParseUint(header[5:8]))
cmd.TaskTag = uint32(ParseUint(header[16:20]))
cmd.CDB = header[32:48]
cmd.StartTime = time.Now()
putCommand(cmd)
}
}
// BenchmarkDataInBytes benchmarks Data-In response serialization performance
func BenchmarkDataInBytes(b *testing.B) {
data := make([]byte, 4096)
for i := range data {
data[i] = byte(i % 256)
}
cmd := &ISCSICommand{
OpCode: OpSCSIIn,
Final: true,
FinalInSeq: true,
TaskTag: 1,
StatSN: 100,
ExpCmdSN: 101,
MaxCmdSN: 200,
DataLen: 4096,
DataSN: 0,
BufferOffset: 0,
RawData: data,
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = cmd.dataInBytes()
}
}
// BenchmarkDataInBytesSmall benchmarks Data-In performance with small data blocks
func BenchmarkDataInBytesSmall(b *testing.B) {
data := make([]byte, 512)
cmd := &ISCSICommand{
OpCode: OpSCSIIn,
Final: true,
FinalInSeq: true,
TaskTag: 1,
StatSN: 100,
ExpCmdSN: 101,
MaxCmdSN: 200,
DataLen: 512,
DataSN: 0,
BufferOffset: 0,
RawData: data,
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = cmd.dataInBytes()
}
}
// BenchmarkDataInBytesLarge benchmarks Data-In performance with large data blocks
func BenchmarkDataInBytesLarge(b *testing.B) {
data := make([]byte, 65536)
cmd := &ISCSICommand{
OpCode: OpSCSIIn,
Final: true,
FinalInSeq: true,
TaskTag: 1,
StatSN: 100,
ExpCmdSN: 101,
MaxCmdSN: 200,
DataLen: 65536,
DataSN: 0,
BufferOffset: 0,
RawData: data,
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = cmd.dataInBytes()
}
}
// BenchmarkBytesComparison compares Bytes() performance for different OpCodes
func BenchmarkBytesComparison(b *testing.B) {
testCases := []struct {
name string
cmd *ISCSICommand
}{
{
name: "LoginResp",
cmd: &ISCSICommand{
OpCode: OpLoginResp,
Final: true,
Transit: true,
CSG: LoginOperationalNegotiation,
NSG: FullFeaturePhase,
TaskTag: 1,
StatSN: 0,
ExpCmdSN: 1,
MaxCmdSN: 1,
StatusClass: 0,
StatusDetail: 0,
RawData: []byte("TargetPortalGroupTag=1"),
},
},
{
name: "SCSIResp",
cmd: &ISCSICommand{
OpCode: OpSCSIResp,
Final: true,
TaskTag: 1,
StatSN: 100,
ExpCmdSN: 101,
MaxCmdSN: 200,
},
},
{
name: "SCSIIn",
cmd: &ISCSICommand{
OpCode: OpSCSIIn,
Final: true,
TaskTag: 1,
StatSN: 100,
ExpCmdSN: 101,
MaxCmdSN: 200,
DataLen: 4096,
RawData: make([]byte, 4096),
},
},
{
name: "R2T",
cmd: &ISCSICommand{
OpCode: OpReady,
Final: true,
TaskTag: 1,
StatSN: 100,
ExpCmdSN: 101,
MaxCmdSN: 200,
R2TSN: 0,
BufferOffset: 0,
DesiredLength: 8192,
},
},
}
for _, tc := range testCases {
b.Run(tc.name, func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = tc.cmd.Bytes()
}
})
}
}
// BenchmarkCommandPool benchmarks command object pool performance
func BenchmarkCommandPool(b *testing.B) {
b.Run("WithPool", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
cmd := getCommand()
cmd.OpCode = OpSCSICmd
cmd.TaskTag = uint32(i)
putCommand(cmd)
}
})
b.Run("WithoutPool", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
cmd := &ISCSICommand{
OpCode: OpSCSICmd,
TaskTag: uint32(i),
}
_ = cmd
}
})
}
// BenchmarkBufferPool benchmarks buffer pool performance
func BenchmarkBufferPool(b *testing.B) {
b.Run("WithPool", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
buf := getBuffer()
buf[0] = byte(i)
putBuffer(buf)
}
})
b.Run("WithoutPool", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
buf := make([]byte, BHS_SIZE)
buf[0] = byte(i)
_ = buf
}
})
}
// BenchmarkTaskStateTransition benchmarks task state transition performance
func BenchmarkTaskStateTransition(b *testing.B) {
task := &iscsiTask{
tag: 1,
state: taskPending,
scmd: &api.SCSICommand{},
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if i%2 == 0 {
task.state = taskPending
} else {
task.state = taskSCSI
}
}
}
// BenchmarkParseUint benchmarks ParseUint performance
func BenchmarkParseUint(b *testing.B) {
testData := []byte{0x00, 0x00, 0x10, 0x00}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = ParseUint(testData)
}
}
// BenchmarkBytesComparisonEqual benchmarks byte comparison performance
func BenchmarkBytesComparisonEqual(b *testing.B) {
a := make([]byte, 48)
b2 := make([]byte, 48)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = bytes.Equal(a, b2)
}
}
// BenchmarkMarshalUint32 benchmarks uint32 serialization performance
func BenchmarkMarshalUint32(b *testing.B) {
val := uint32(0x12345678)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = util.MarshalUint32(val)
}
}
// BenchmarkMarshalUint64 benchmarks uint64 serialization performance
func BenchmarkMarshalUint64(b *testing.B) {
val := uint64(0x1234567890ABCDEF)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = util.MarshalUint64(val)
}
}
// BenchmarkBuildRespPackage benchmarks complete response package building performance
func BenchmarkBuildRespPackage(b *testing.B) {
conn := &iscsiConnection{
state: CONN_STATE_SCSI,
statSN: 99,
expCmdSN: 100,
loginParam: &iscsiLoginParam{
sessionParam: []ISCSISessionParam{
{idx: ISCSI_PARAM_MAX_BURST, Value: 262144},
},
},
session: &ISCSISession{
ExpCmdSN: 100,
MaxQueueCommand: 32,
},
req: &ISCSICommand{
OpCode: OpSCSICmd,
TaskTag: 1,
ExpStatSN: 100,
ExpectedDataLen: 4096,
StartTime: time.Now(),
},
rxTask: &iscsiTask{
tag: 1,
scmd: &api.SCSICommand{
Result: 0,
Direction: api.SCSIDataRead,
InSDBBuffer: &api.SCSIDataBuffer{
Buffer: make([]byte, 4096),
Length: 4096,
},
},
},
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = conn.buildRespPackage(OpSCSIResp, nil)
}
}

View File

@@ -0,0 +1,472 @@
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package iscsit
import (
"bytes"
"encoding/binary"
"testing"
)
// TestLoginRespBytesFormat verifies Login Response BHS format complies with RFC 7143
func TestLoginRespBytesFormat(t *testing.T) {
cmd := &ISCSICommand{
OpCode: OpLoginResp,
Transit: true,
Cont: false,
CSG: LoginOperationalNegotiation,
NSG: FullFeaturePhase,
ISID: 0x123456789ABC,
TSIH: 0x1234,
TaskTag: 0xABCDEF00,
StatSN: 0x12345678,
ExpCmdSN: 0x87654321,
MaxCmdSN: 0x87654421,
StatusClass: 0,
StatusDetail: 0,
RawData: []byte("TestData"),
}
resp := cmd.loginRespBytes()
// Verify BHS length is at least 48 bytes
if len(resp) < 48 {
t.Fatalf("BHS too short: expected at least 48, got %d", len(resp))
}
// Byte 0: Opcode
if resp[0] != byte(OpLoginResp) {
t.Errorf("Byte 0: expected OpLoginResp(0x23), got 0x%02x", resp[0])
}
// Byte 1: Flags
expectedFlags := byte(0x80 | (byte(LoginOperationalNegotiation&0xff) << 2) | byte(FullFeaturePhase&0xff))
if resp[1] != expectedFlags {
t.Errorf("Byte 1: expected 0x%02x, got 0x%02x", expectedFlags, resp[1])
}
// Byte 2-3: Version
if resp[2] != 0 || resp[3] != 0 {
t.Logf("Byte 2-3 (version): %d, %d", resp[2], resp[3])
}
// Byte 4-7: Data Segment Length
dataLen := binary.BigEndian.Uint32(resp[4:8])
if dataLen != uint32(len(cmd.RawData)) {
t.Errorf("Data segment length: expected %d, got %d", len(cmd.RawData), dataLen)
}
// Byte 8-13: ISID (6 bytes)
isid := binary.BigEndian.Uint64(append([]byte{0, 0}, resp[8:14]...))
if isid != cmd.ISID {
t.Errorf("ISID: expected 0x%012x, got 0x%012x", cmd.ISID, isid)
}
// Byte 14-15: TSIH
tsih := binary.BigEndian.Uint16(resp[14:16])
if tsih != cmd.TSIH {
t.Errorf("TSIH: expected 0x%04x, got 0x%04x", cmd.TSIH, tsih)
}
// Byte 16-19: Task Tag
taskTag := binary.BigEndian.Uint32(resp[16:20])
if taskTag != cmd.TaskTag {
t.Errorf("TaskTag: expected 0x%08x, got 0x%08x", cmd.TaskTag, taskTag)
}
// Byte 24-27: StatSN
statSN := binary.BigEndian.Uint32(resp[24:28])
if statSN != cmd.StatSN {
t.Errorf("StatSN: expected 0x%08x, got 0x%08x", cmd.StatSN, statSN)
}
// Byte 28-31: ExpCmdSN
expCmdSN := binary.BigEndian.Uint32(resp[28:32])
if expCmdSN != cmd.ExpCmdSN {
t.Errorf("ExpCmdSN: expected 0x%08x, got 0x%08x", cmd.ExpCmdSN, expCmdSN)
}
// Byte 32-35: MaxCmdSN
maxCmdSN := binary.BigEndian.Uint32(resp[32:36])
if maxCmdSN != cmd.MaxCmdSN {
t.Errorf("MaxCmdSN: expected 0x%08x, got 0x%08x", cmd.MaxCmdSN, maxCmdSN)
}
// Byte 36: StatusClass
if resp[36] != cmd.StatusClass {
t.Errorf("StatusClass: expected %d, got %d", cmd.StatusClass, resp[36])
}
// Byte 37: StatusDetail
if resp[37] != cmd.StatusDetail {
t.Errorf("StatusDetail: expected %d, got %d", cmd.StatusDetail, resp[37])
}
// Verify data segment
if len(resp) > 48 {
data := resp[48:]
if !bytes.Equal(data, cmd.RawData) {
t.Errorf("RawData mismatch: expected %v, got %v", cmd.RawData, data)
}
}
// Verify 4-byte alignment
if len(resp)%4 != 0 {
t.Errorf("Response not aligned to 4 bytes: length=%d", len(resp))
}
}
// TestLogoutRespBytesFormat verifies Logout Response BHS format
func TestLogoutRespBytesFormat(t *testing.T) {
cmd := &ISCSICommand{
OpCode: OpLogoutResp,
TaskTag: 0x12345678,
StatSN: 0xABCDEF00,
ExpCmdSN: 0x11223344,
MaxCmdSN: 0x55667788,
}
resp := cmd.logoutRespBytes()
// Verify length is exactly 48 bytes
if len(resp) != 48 {
t.Fatalf("Logout response length: expected 48, got %d", len(resp))
}
// Byte 0: Opcode
if resp[0] != byte(OpLogoutResp) {
t.Errorf("Byte 0: expected OpLogoutResp(0x26), got 0x%02x", resp[0])
}
// Byte 1: Flags (0x80)
if resp[1] != 0x80 {
t.Errorf("Byte 1: expected 0x80, got 0x%02x", resp[1])
}
// Byte 2: Response (0)
if resp[2] != 0 {
t.Errorf("Byte 2: expected 0, got 0x%02x", resp[2])
}
// Byte 16-19: Task Tag
taskTag := binary.BigEndian.Uint32(resp[16:20])
if taskTag != cmd.TaskTag {
t.Errorf("TaskTag: expected 0x%08x, got 0x%08x", cmd.TaskTag, taskTag)
}
// Byte 24-27: StatSN
statSN := binary.BigEndian.Uint32(resp[24:28])
if statSN != cmd.StatSN {
t.Errorf("StatSN: expected 0x%08x, got 0x%08x", cmd.StatSN, statSN)
}
// Byte 28-31: ExpCmdSN
expCmdSN := binary.BigEndian.Uint32(resp[28:32])
if expCmdSN != cmd.ExpCmdSN {
t.Errorf("ExpCmdSN: expected 0x%08x, got 0x%08x", cmd.ExpCmdSN, expCmdSN)
}
// Byte 32-35: MaxCmdSN
maxCmdSN := binary.BigEndian.Uint32(resp[32:36])
if maxCmdSN != cmd.MaxCmdSN {
t.Errorf("MaxCmdSN: expected 0x%08x, got 0x%08x", cmd.MaxCmdSN, maxCmdSN)
}
}
// TestSCSICmdRespBytesFormat verifies SCSI Command Response BHS format
func TestSCSICmdRespBytesFormat(t *testing.T) {
rawData := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05}
cmd := &ISCSICommand{
OpCode: OpSCSIResp,
Status: 0x00, // GOOD
SCSIResponse: 0x00,
TaskTag: 0xABCDEF00,
StatSN: 0x12345678,
ExpCmdSN: 0x87654321,
MaxCmdSN: 0x87654421,
Resid: 0,
RawData: rawData,
ExpectedDataLen: uint32(len(rawData)),
}
resp := cmd.scsiCmdRespBytes()
// Verify length
if len(resp) < 48 {
t.Fatalf("SCSI response too short: expected at least 48, got %d", len(resp))
}
// Byte 0: Opcode
if resp[0] != byte(OpSCSIResp) {
t.Errorf("Byte 0: expected OpSCSIResp(0x21), got 0x%02x", resp[0])
}
// Byte 1: Flags (0x80 = final, no residual)
if resp[1] != 0x80 {
t.Errorf("Byte 1: expected 0x80, got 0x%02x", resp[1])
}
// Byte 2: SCSI Response
if resp[2] != 0 {
t.Errorf("Byte 2 (SCSI Response): expected 0, got %d", resp[2])
}
// Byte 3: Status
if resp[3] != cmd.Status {
t.Errorf("Byte 3 (Status): expected %d, got %d", cmd.Status, resp[3])
}
// 验证数据段
if len(resp) > 48 {
data := resp[48:]
if len(data) >= len(rawData) {
if !bytes.Equal(data[:len(rawData)], rawData) {
t.Errorf("RawData mismatch")
}
}
}
// Verify 4-byte alignment
if len(resp)%4 != 0 {
t.Errorf("Response not aligned to 4 bytes: length=%d", len(resp))
}
}
// TestDataInBytesFormat verifies Data-In response format
func TestDataInBytesFormat(t *testing.T) {
rawData := make([]byte, 512) // Simulate 512 bytes of data
for i := range rawData {
rawData[i] = byte(i % 256)
}
cmd := &ISCSICommand{
OpCode: OpSCSIIn,
Final: true,
FinalInSeq: true,
HasStatus: true,
Status: 0x00,
DataLen: len(rawData),
TaskTag: 0x12345678,
StatSN: 0xABCDEF00,
ExpCmdSN: 0x11111111,
MaxCmdSN: 0x22222222,
DataSN: 0,
BufferOffset: 0,
Resid: 0,
RawData: rawData,
ExpectedDataLen: uint32(len(rawData)),
SCSIOpCode: 0x28, // READ_10
}
resp := cmd.dataInBytes()
// 验证长度
expectedLen := 48 + len(rawData)
if len(rawData)%4 != 0 {
expectedLen += 4 - len(rawData)%4
}
if len(resp) != expectedLen {
t.Fatalf("Data-In response length: expected %d, got %d", expectedLen, len(resp))
}
// Byte 0: Opcode
if resp[0] != byte(OpSCSIIn) {
t.Errorf("Byte 0: expected OpSCSIIn(0x25), got 0x%02x", resp[0])
}
// Byte 1: Flags (0x80 = final, 0x01 = status present)
expectedFlags := byte(0x80 | 0x01)
if resp[1] != expectedFlags {
t.Errorf("Byte 1: expected 0x%02x, got 0x%02x", expectedFlags, resp[1])
}
// Byte 3: Status
if resp[3] != cmd.Status {
t.Errorf("Byte 3 (Status): expected %d, got %d", cmd.Status, resp[3])
}
// 验证数据段
data := resp[48:]
if !bytes.Equal(data, rawData) {
t.Errorf("Data segment mismatch")
}
}
// TestR2TRespBytesFormat verifies R2T (Ready To Transfer) response format
func TestR2TRespBytesFormat(t *testing.T) {
cmd := &ISCSICommand{
OpCode: OpReady,
Final: true,
TaskTag: 0x12345678,
StatSN: 0xABCDEF00,
ExpCmdSN: 0x11111111,
MaxCmdSN: 0x22222222,
R2TSN: 0,
BufferOffset: 0,
DesiredLength: 8192,
}
resp := cmd.r2tRespBytes()
// Verify length is exactly 48 bytes
if len(resp) != 48 {
t.Fatalf("R2T response length: expected 48, got %d", len(resp))
}
// Byte 0: Opcode
if resp[0] != byte(OpReady) {
t.Errorf("Byte 0: expected OpReady(0x31), got 0x%02x", resp[0])
}
// Byte 1: Flags (0x80 = final)
if resp[1] != 0x80 {
t.Errorf("Byte 1: expected 0x80, got 0x%02x", resp[1])
}
// Byte 16-19: Task Tag
taskTag := binary.BigEndian.Uint32(resp[16:20])
if taskTag != cmd.TaskTag {
t.Errorf("TaskTag: expected 0x%08x, got 0x%08x", cmd.TaskTag, taskTag)
}
// Byte 36-39: R2TSN
r2tsn := binary.BigEndian.Uint32(resp[36:40])
if r2tsn != cmd.R2TSN {
t.Errorf("R2TSN: expected 0x%08x, got 0x%08x", cmd.R2TSN, r2tsn)
}
// Byte 40-43: Buffer Offset
bufferOffset := binary.BigEndian.Uint32(resp[40:44])
if bufferOffset != cmd.BufferOffset {
t.Errorf("BufferOffset: expected 0x%08x, got 0x%08x", cmd.BufferOffset, bufferOffset)
}
// Byte 44-47: Desired Length
desiredLength := binary.BigEndian.Uint32(resp[44:48])
if desiredLength != cmd.DesiredLength {
t.Errorf("DesiredLength: expected 0x%08x, got 0x%08x", cmd.DesiredLength, desiredLength)
}
}
// BenchmarkLoginRespBytes benchmarks Login Response
func BenchmarkLoginRespBytes(b *testing.B) {
cmd := &ISCSICommand{
OpCode: OpLoginResp,
Transit: true,
ISID: 0x123456789ABC,
TSIH: 0x1234,
TaskTag: 0xABCDEF00,
StatSN: 0x12345678,
ExpCmdSN: 0x87654321,
MaxCmdSN: 0x87654421,
StatusClass: 0,
StatusDetail: 0,
RawData: []byte("TestData"),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = cmd.loginRespBytes()
}
}
// TestTextRespBytesFormat verifies Text Response BHS format
func TestTextRespBytesFormat(t *testing.T) {
rawData := []byte("SendTargets=test")
cmd := &ISCSICommand{
OpCode: OpTextResp,
Final: true,
Cont: false,
TaskTag: 0x12345678,
StatSN: 0xABCDEF00,
ExpCmdSN: 0x11111111,
MaxCmdSN: 0x22222222,
RawData: rawData,
}
resp := cmd.textRespBytes()
// Verify BHS length is at least 48 bytes
if len(resp) < 48 {
t.Fatalf("BHS too short: expected at least 48, got %d", len(resp))
}
// Byte 0: Opcode
if resp[0] != byte(OpTextResp) {
t.Errorf("Byte 0: expected OpTextResp(0x24), got 0x%02x", resp[0])
}
// Byte 1: Flags (0x80 = final)
if resp[1] != 0x80 {
t.Errorf("Byte 1: expected 0x80, got 0x%02x", resp[1])
}
// Byte 4-7: Data Segment Length
dataLen := binary.BigEndian.Uint32(resp[4:8])
if dataLen != uint32(len(rawData)) {
t.Errorf("Data segment length: expected %d, got %d", len(rawData), dataLen)
}
// Byte 16-19: Task Tag
taskTag := binary.BigEndian.Uint32(resp[16:20])
if taskTag != cmd.TaskTag {
t.Errorf("TaskTag: expected 0x%08x, got 0x%08x", cmd.TaskTag, taskTag)
}
// Byte 20-23: 0xffffffff
if resp[20] != 0xff || resp[21] != 0xff || resp[22] != 0xff || resp[23] != 0xff {
t.Errorf("Bytes 20-23: expected 0xffffffff, got 0x%02x%02x%02x%02x",
resp[20], resp[21], resp[22], resp[23])
}
// Byte 24-27: StatSN
statSN := binary.BigEndian.Uint32(resp[24:28])
if statSN != cmd.StatSN {
t.Errorf("StatSN: expected 0x%08x, got 0x%08x", cmd.StatSN, statSN)
}
// 验证数据段
if len(resp) > 48 {
data := resp[48:]
if !bytes.Equal(data, rawData) {
t.Errorf("RawData mismatch: expected %v, got %v", rawData, data)
}
}
// Verify 4-byte alignment
if len(resp)%4 != 0 {
t.Errorf("Response not aligned to 4 bytes: length=%d", len(resp))
}
}
// BenchmarkSCSICmdRespBytes benchmarks SCSI Command Response
func BenchmarkSCSICmdRespBytes(b *testing.B) {
cmd := &ISCSICommand{
OpCode: OpSCSIResp,
Status: 0x00,
TaskTag: 0xABCDEF00,
StatSN: 0x12345678,
ExpCmdSN: 0x87654321,
MaxCmdSN: 0x87654421,
RawData: []byte{0x00, 0x01, 0x02, 0x03},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = cmd.scsiCmdRespBytes()
}
}

View File

@@ -22,9 +22,9 @@ import (
"strings"
"sync"
"github.com/google/uuid"
"github.com/gostor/gotgt/pkg/api"
"github.com/gostor/gotgt/pkg/scsi"
uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus"
)
@@ -333,7 +333,7 @@ func (s *ISCSITargetDriver) UnBindISCSISession(sess *ISCSISession) {
target.SessionsRWMutex.Lock()
defer target.SessionsRWMutex.Unlock()
delete(target.Sessions, sess.TSIH)
scsi.RemoveITNexus(&sess.Target.SCSITarget, sess.ITNexus)
scsi.RemoveITNexus(sess.Target.SCSITarget, sess.ITNexus)
}
func (s *ISCSITargetDriver) BindISCSISession(conn *iscsiConnection) error {
@@ -395,8 +395,8 @@ func (s *ISCSITargetDriver) BindISCSISession(conn *iscsiConnection) error {
log.Infof("Login request received from initiator: %v, Session type: %s, Target name:%v, ISID: 0x%x",
conn.loginParam.initiator, "Normal", conn.loginParam.target, conn.loginParam.isid)
//register normal session
itnexus := &api.ITNexus{uuid.NewV1(), GeniSCSIITNexusID(newSess)}
scsi.AddITNexus(&newSess.Target.SCSITarget, itnexus)
itnexus := &api.ITNexus{ID: uuid.New(), Tag: GeniSCSIITNexusID(newSess)}
scsi.AddITNexus(newSess.Target.SCSITarget, itnexus)
newSess.ITNexus = itnexus
conn.session = newSess
@@ -417,8 +417,8 @@ func (s *ISCSITargetDriver) BindISCSISession(conn *iscsiConnection) error {
return err
}
itnexus := &api.ITNexus{uuid.NewV1(), GeniSCSIITNexusID(newSess)}
scsi.AddITNexus(&newSess.Target.SCSITarget, itnexus)
itnexus := &api.ITNexus{ID: uuid.New(), Tag: GeniSCSIITNexusID(newSess)}
scsi.AddITNexus(newSess.Target.SCSITarget, itnexus)
newSess.ITNexus = itnexus
conn.session = newSess

View File

@@ -82,9 +82,9 @@ func bsPerformCommand(bs api.BackingStore, cmd *api.SCSICommand) (err error, key
doWrite = true
goto write
case api.COMPARE_AND_WRITE:
// TODO
doWrite = true
goto write
// COMPARE_AND_WRITE is handled directly in SBCCompareAndWrite function
// This case should not be reached
return fmt.Errorf("COMPARE_AND_WRITE should be handled by SBCCompareAndWrite"), ILLEGAL_REQUEST, ASC_INVALID_OP_CODE
case api.SYNCHRONIZE_CACHE, api.SYNCHRONIZE_CACHE_16:
if tl == 0 {
tl = int64(lu.Size - offset)

View File

@@ -8,7 +8,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,

View File

@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,

View File

@@ -20,6 +20,7 @@ import (
"fmt"
"io"
"os"
"strings"
log "github.com/sirupsen/logrus"
@@ -51,10 +52,33 @@ func new() (api.BackingStore, error) {
}, nil
}
// parseStoragePath parses a storage path that may include backend type prefix
// Format: [backend_type:]path
// Examples:
// - /var/tmp/disk.img (default file backend)
// - file:/var/tmp/disk.img (explicit file backend)
// - iouring:/var/tmp/disk.img (io_uring backend on Linux 5.1+)
func parseStoragePath(path string) (backendType, filePath string) {
if idx := strings.Index(path, ":"); idx > 0 {
possibleType := path[:idx]
// Check if it's a known backend type
switch possibleType {
case "file", "iouring", "ceph", "null", "RemBs":
return possibleType, path[idx+1:]
}
}
// Default to file backend
return "file", path
}
func (bs *FileBackingStore) Open(dev *api.SCSILu, path string) error {
var mode os.FileMode
finfo, err := os.Stat(path)
// Parse backend type and actual path
backendType, filePath := parseStoragePath(path)
_ = backendType // file backend ignores this
finfo, err := os.Stat(filePath)
if err != nil {
return err
} else {
@@ -62,7 +86,7 @@ func (bs *FileBackingStore) Open(dev *api.SCSILu, path string) error {
mode = finfo.Mode()
}
f, err := os.OpenFile(path, os.O_RDWR, os.ModePerm)
f, err := os.OpenFile(filePath, os.O_RDWR, os.ModePerm)
if err == nil {
// block device filesize needs to be treated differently

View File

@@ -0,0 +1,727 @@
//go:build linux
// +build linux
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package iouring provides an io_uring-based backing store for high-performance
// asynchronous I/O operations on Linux 5.1+ systems.
package iouring
import (
"fmt"
"os"
"runtime"
"sync"
"sync/atomic"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
"github.com/gostor/gotgt/pkg/api"
"github.com/gostor/gotgt/pkg/scsi"
)
const (
IoUringBackingStorage = "iouring"
// Default queue depth for io_uring
DefaultQueueDepth = 4096
// Minimum kernel version required (5.1)
MinKernelMajor = 5
MinKernelMinor = 1
)
// io_uring constants (from linux/io_uring.h)
const (
IORING_SETUP_IOPOLL = 1 << 0
IORING_SETUP_SQPOLL = 1 << 1
IORING_SETUP_SQ_AFF = 1 << 2
IORING_SETUP_CQSIZE = 1 << 3
IORING_SETUP_CLAMP = 1 << 4
IORING_SETUP_ATTACH_WQ = 1 << 5
IORING_SETUP_R_DISABLED = 1 << 6
IORING_FSYNC_DATASYNC = 1 << 0
IORING_TIMEOUT_ABS = 1 << 0
IORING_OFF_SQ_RING = 0
IORING_OFF_CQ_RING = 0x8000000
IORING_OFF_SQES = 0x10000000
IORING_OP_NOP = 0
IORING_OP_READV = 1
IORING_OP_WRITEV = 2
IORING_OP_FSYNC = 3
IORING_OP_READ_FIXED = 4
IORING_OP_WRITE_FIXED = 5
IORING_OP_POLL_ADD = 6
IORING_OP_POLL_REMOVE = 7
IORING_OP_SYNC_FILE_RANGE = 8
IORING_OP_SENDMSG = 9
IORING_OP_RECVMSG = 10
IORING_OP_TIMEOUT = 11
IORING_OP_TIMEOUT_REMOVE = 12
IORING_OP_ACCEPT = 13
IORING_OP_ASYNC_CANCEL = 14
IORING_OP_LINK_TIMEOUT = 15
IORING_OP_CONNECT = 16
IORING_OP_FALLOCATE = 17
IORING_OP_OPENAT = 18
IORING_OP_CLOSE = 19
IORING_OP_FILES_UPDATE = 20
IORING_OP_STATX = 21
IORING_OP_READ = 22
IORING_OP_WRITE = 23
IORING_OP_FADVISE = 24
IORING_OP_MADVISE = 25
IORING_OP_SEND = 26
IORING_OP_RECV = 27
IORING_OP_OPENAT2 = 28
IORING_OP_EPOLL_CTL = 29
IORING_OP_SPLICE = 30
IORING_OP_PROVIDE_BUFFERS = 31
IORING_OP_REMOVE_BUFFERS = 32
IORING_OP_TEE = 33
IORING_OP_SHUTDOWN = 34
IORING_OP_RENAMEAT = 35
IORING_OP_UNLINKAT = 36
IORING_OP_MKDIRAT = 37
IORING_OP_SYMLINKAT = 38
IORING_OP_LINKAT = 39
IORING_OP_MSG_RING = 40
IORING_OP_FSETXATTR = 41
IORING_OP_SETXATTR = 42
IORING_OP_FGETXATTR = 43
IORING_OP_GETXATTR = 44
IORING_OP_SOCKET = 45
IORING_OP_URING_CMD = 46
IORING_OP_SEND_ZC = 47
IORING_OP_SENDMSG_ZC = 48
IORING_CQE_F_BUFFER = 1 << 0
IORING_CQE_F_MORE = 1 << 1
)
// io_uring structures
// Note: These are simplified structures for the operations we need
type ioUring struct {
fd int
sq *ioUringSq
cq *ioUringCq
flags uint32
ringSize int
}
type ioUringSq struct {
head *uint32
tail *uint32
ringMask *uint32
ringEntries *uint32
flags *uint32
dropped *uint32
array *uint32
sqes []ioSqringEntry
}
type ioUringCq struct {
head *uint32
tail *uint32
ringMask *uint32
ringEntries *uint32
overflow *uint32
cqes []ioCqringEntry
}
type ioSqringEntry struct {
opcode uint8
flags uint8
ioprio uint16
fd int32
off uint64
addr uint64
len uint32
userData uint64
}
type ioCqringEntry struct {
userData uint64
res int32
flags uint32
}
type ioUringParams struct {
sqEntries uint32
cqEntries uint32
flags uint32
sqThreadCPU uint32
sqThreadIdle uint32
features uint32
wqFd uint32
resv [3]uint32
sqOff ioSqringOffsets
cqOff ioCqringOffsets
}
type ioSqringOffsets struct {
head uint32
tail uint32
ringMask uint32
ringEntries uint32
flags uint32
dropped uint32
array uint32
resv1 uint32
resv2 uint64
}
type ioCqringOffsets struct {
head uint32
tail uint32
ringMask uint32
ringEntries uint32
overflow uint32
cqes uint32
flags uint32
resv1 uint32
resv2 uint64
}
type ioUringCqe struct {
userData uint64
res int32
flags uint32
}
var ioUringEnabled = false
func init() {
if isKernelVersionSupported() {
ioUringEnabled = true
scsi.RegisterBackingStore(IoUringBackingStorage, newIOUringBackingStore)
log.Info("io_uring backing store registered (kernel supports io_uring)")
} else {
log.Info("io_uring backing store not available (requires Linux 5.1+)")
}
}
func isKernelVersionSupported() bool {
var uname syscall.Utsname
if err := syscall.Uname(&uname); err != nil {
return false
}
// Parse kernel version (simplified)
// Format is typically "5.15.0-generic"
major := int(uname.Release[0] - '0')
minor := int(uname.Release[2] - '0')
if major > MinKernelMajor {
return true
}
if major == MinKernelMajor && minor >= MinKernelMinor {
return true
}
return false
}
// IOUringBackingStore implements BackingStore using io_uring
type IOUringBackingStore struct {
scsi.BaseBackingStore
file *os.File
ring *ioUring
queueDepth int
// Synchronization
submitMu sync.Mutex
// Statistics
opsSubmitted uint64
opsCompleted uint64
}
func newIOUringBackingStore() (api.BackingStore, error) {
return &IOUringBackingStore{
BaseBackingStore: scsi.BaseBackingStore{
Name: IoUringBackingStorage,
DataSize: 0,
OflagsSupported: 0,
},
queueDepth: DefaultQueueDepth,
}, nil
}
// Open opens the backing file and initializes io_uring
func (bs *IOUringBackingStore) Open(dev *api.SCSILu, path string) error {
var mode os.FileMode
finfo, err := os.Stat(path)
if err != nil {
return err
}
mode = finfo.Mode()
f, err := os.OpenFile(path, os.O_RDWR|syscall.O_DIRECT, os.ModePerm)
if err != nil {
// Try without O_DIRECT if not supported
f, err = os.OpenFile(path, os.O_RDWR, os.ModePerm)
if err != nil {
return err
}
}
if (mode & os.ModeDevice) != 0 {
pos, err := f.Seek(0, os.SEEK_END)
if err != nil {
f.Close()
return err
}
bs.DataSize = uint64(pos)
} else {
bs.DataSize = uint64(finfo.Size())
}
bs.file = f
// Initialize io_uring
ring, err := bs.initIOUring()
if err != nil {
f.Close()
return fmt.Errorf("failed to initialize io_uring: %v", err)
}
bs.ring = ring
log.Infof("io_uring backing store opened: %s (queue depth: %d)", path, bs.queueDepth)
return nil
}
func (bs *IOUringBackingStore) initIOUring() (*ioUring, error) {
params := &ioUringParams{}
// Setup io_uring
fd, _, errno := syscall.Syscall(425, // __NR_io_uring_setup
uintptr(bs.queueDepth),
uintptr(unsafe.Pointer(params)),
0)
if errno != 0 {
return nil, fmt.Errorf("io_uring_setup failed: %v", errno)
}
ring := &ioUring{
fd: int(fd),
ringSize: int(params.sqEntries),
flags: params.flags,
}
// Map the submission queue ring
sqRingSize := params.sqOff.array + params.sqEntries*uint32(unsafe.Sizeof(uint32(0)))
cqRingSize := params.cqOff.cqes + params.cqEntries*uint32(unsafe.Sizeof(ioCqringEntry{}))
if params.features&1 != 0 { // IORING_FEAT_SINGLE_MMAP
if cqRingSize > sqRingSize {
sqRingSize = cqRingSize
}
cqRingSize = sqRingSize
}
// mmap submission queue
sqPtr, _, errno := syscall.Syscall6(syscall.SYS_MMAP,
0,
uintptr(sqRingSize),
syscall.PROT_READ|syscall.PROT_WRITE,
syscall.MAP_SHARED|syscall.MAP_POPULATE,
uintptr(fd),
uintptr(IORING_OFF_SQ_RING))
if errno != 0 {
syscall.Close(int(fd))
return nil, fmt.Errorf("mmap sq ring failed: %v", errno)
}
sqBase := sqPtr
// mmap completion queue (if not single mmap)
var cqPtr uintptr
if params.features&1 != 0 {
cqPtr = sqPtr
} else {
cqPtr, _, errno = syscall.Syscall6(syscall.SYS_MMAP,
0,
uintptr(cqRingSize),
syscall.PROT_READ|syscall.PROT_WRITE,
syscall.MAP_SHARED|syscall.MAP_POPULATE,
uintptr(fd),
uintptr(IORING_OFF_CQ_RING))
if errno != 0 {
syscall.Syscall(syscall.SYS_MUNMAP, sqPtr, uintptr(sqRingSize), 0)
syscall.Close(int(fd))
return nil, fmt.Errorf("mmap cq ring failed: %v", errno)
}
}
cqBase := cqPtr
// mmap SQEs
sqeSize := uint32(unsafe.Sizeof(ioSqringEntry{}))
sqePtr, _, errno := syscall.Syscall6(syscall.SYS_MMAP,
0,
uintptr(uint32(bs.queueDepth)*sqeSize),
syscall.PROT_READ|syscall.PROT_WRITE,
syscall.MAP_SHARED|syscall.MAP_POPULATE,
uintptr(fd),
uintptr(IORING_OFF_SQES))
if errno != 0 {
syscall.Syscall(syscall.SYS_MUNMAP, sqPtr, uintptr(sqRingSize), 0)
if cqPtr != sqPtr {
syscall.Syscall(syscall.SYS_MUNMAP, cqPtr, uintptr(cqRingSize), 0)
}
syscall.Close(int(fd))
return nil, fmt.Errorf("mmap sqes failed: %v", errno)
}
// Setup submission queue
sq := &ioUringSq{
head: (*uint32)(unsafe.Pointer(sqBase + uintptr(params.sqOff.head))),
tail: (*uint32)(unsafe.Pointer(sqBase + uintptr(params.sqOff.tail))),
ringMask: (*uint32)(unsafe.Pointer(sqBase + uintptr(params.sqOff.ringMask))),
ringEntries: (*uint32)(unsafe.Pointer(sqBase + uintptr(params.sqOff.ringEntries))),
flags: (*uint32)(unsafe.Pointer(sqBase + uintptr(params.sqOff.flags))),
dropped: (*uint32)(unsafe.Pointer(sqBase + uintptr(params.sqOff.dropped))),
array: (*uint32)(unsafe.Pointer(sqBase + uintptr(params.sqOff.array))),
sqes: make([]ioSqringEntry, bs.queueDepth),
}
copy(unsafe.Slice((*ioSqringEntry)(unsafe.Pointer(sqePtr)), bs.queueDepth), sq.sqes)
// Setup completion queue
cq := &ioUringCq{
head: (*uint32)(unsafe.Pointer(cqBase + uintptr(params.cqOff.head))),
tail: (*uint32)(unsafe.Pointer(cqBase + uintptr(params.cqOff.tail))),
ringMask: (*uint32)(unsafe.Pointer(cqBase + uintptr(params.cqOff.ringMask))),
ringEntries: (*uint32)(unsafe.Pointer(cqBase + uintptr(params.cqOff.ringEntries))),
overflow: (*uint32)(unsafe.Pointer(cqBase + uintptr(params.cqOff.overflow))),
cqes: make([]ioCqringEntry, params.cqEntries),
}
copy(unsafe.Slice((*ioCqringEntry)(unsafe.Pointer(cqBase+uintptr(params.cqOff.cqes))), params.cqEntries), cq.cqes)
ring.sq = sq
ring.cq = cq
return ring, nil
}
// Close closes the backing file and io_uring
func (bs *IOUringBackingStore) Close(dev *api.SCSILu) error {
if bs.ring != nil {
bs.closeIOUring()
bs.ring = nil
}
if bs.file != nil {
return bs.file.Close()
}
return nil
}
func (bs *IOUringBackingStore) closeIOUring() {
if bs.ring != nil && bs.ring.fd >= 0 {
syscall.Close(bs.ring.fd)
}
}
// Init initializes the backing store
func (bs *IOUringBackingStore) Init(dev *api.SCSILu, Opts string) error {
return nil
}
// Exit exits the backing store
func (bs *IOUringBackingStore) Exit(dev *api.SCSILu) error {
return nil
}
// Size returns the size of the backing store
func (bs *IOUringBackingStore) Size(dev *api.SCSILu) uint64 {
return bs.DataSize
}
// Read reads data from the backing file using io_uring
func (bs *IOUringBackingStore) Read(offset, tl int64) ([]byte, error) {
if bs.file == nil {
return nil, fmt.Errorf("backing store is not open")
}
buf := make([]byte, tl)
// Prepare read operation
bs.submitMu.Lock()
defer bs.submitMu.Unlock()
// Get next SQE
sqe := bs.getSqe()
if sqe == nil {
// Ring is full, submit pending operations first
if err := bs.submit(); err != nil {
return nil, err
}
sqe = bs.getSqe()
if sqe == nil {
return nil, fmt.Errorf("io_uring queue full")
}
}
// Setup read operation
*sqe = ioSqringEntry{
opcode: IORING_OP_READ,
fd: int32(bs.file.Fd()),
off: uint64(offset),
addr: uint64(uintptr(unsafe.Pointer(&buf[0]))),
len: uint32(tl),
userData: 1, // 1 = read operation
}
// Submit and wait for completion
if err := bs.submitAndWait(1); err != nil {
return nil, err
}
// Get completion
cqe, err := bs.getCqe()
if err != nil {
return nil, err
}
if cqe.res < 0 {
return nil, fmt.Errorf("read failed: %d", cqe.res)
}
atomic.AddUint64(&bs.opsCompleted, 1)
return buf[:cqe.res], nil
}
// Write writes data to the backing file using io_uring
func (bs *IOUringBackingStore) Write(wbuf []byte, offset int64) error {
if bs.file == nil {
return fmt.Errorf("backing store is not open")
}
bs.submitMu.Lock()
defer bs.submitMu.Unlock()
// Get next SQE
sqe := bs.getSqe()
if sqe == nil {
if err := bs.submit(); err != nil {
return err
}
sqe = bs.getSqe()
if sqe == nil {
return fmt.Errorf("io_uring queue full")
}
}
// Setup write operation
*sqe = ioSqringEntry{
opcode: IORING_OP_WRITE,
fd: int32(bs.file.Fd()),
off: uint64(offset),
addr: uint64(uintptr(unsafe.Pointer(&wbuf[0]))),
len: uint32(len(wbuf)),
userData: 2, // 2 = write operation
}
// Submit and wait for completion
if err := bs.submitAndWait(1); err != nil {
return err
}
// Get completion
cqe, err := bs.getCqe()
if err != nil {
return err
}
if cqe.res < 0 {
return fmt.Errorf("write failed: %d", cqe.res)
}
if cqe.res != int32(len(wbuf)) {
return fmt.Errorf("short write: %d != %d", cqe.res, len(wbuf))
}
atomic.AddUint64(&bs.opsCompleted, 1)
return nil
}
// DataSync syncs data to disk using io_uring
func (bs *IOUringBackingStore) DataSync(offset, tl int64) error {
if bs.file == nil {
return fmt.Errorf("backing store is not open")
}
bs.submitMu.Lock()
defer bs.submitMu.Unlock()
sqe := bs.getSqe()
if sqe == nil {
if err := bs.submit(); err != nil {
return err
}
sqe = bs.getSqe()
if sqe == nil {
return fmt.Errorf("io_uring queue full")
}
}
*sqe = ioSqringEntry{
opcode: IORING_OP_FSYNC,
fd: int32(bs.file.Fd()),
len: IORING_FSYNC_DATASYNC,
userData: 3, // 3 = fsync operation
}
if err := bs.submitAndWait(1); err != nil {
return err
}
cqe, err := bs.getCqe()
if err != nil {
return err
}
if cqe.res < 0 {
return fmt.Errorf("fsync failed: %d", cqe.res)
}
atomic.AddUint64(&bs.opsCompleted, 1)
return nil
}
// DataAdvise provides advice about data access patterns
func (bs *IOUringBackingStore) DataAdvise(offset, length int64, advise uint32) error {
if bs.file == nil {
return fmt.Errorf("backing store is not open")
}
// Use posix_fadvise via syscall
_, _, errno := syscall.Syscall6(syscall.SYS_FADVISE64, uintptr(bs.file.Fd()), uintptr(offset), uintptr(length), uintptr(advise), 0, 0)
if errno != 0 {
return errno
}
return nil
}
// Unmap is a no-op for file-based storage
func (bs *IOUringBackingStore) Unmap([]api.UnmapBlockDescriptor) error {
return nil
}
// getSqe gets the next available submission queue entry
func (bs *IOUringBackingStore) getSqe() *ioSqringEntry {
sq := bs.ring.sq
tail := atomic.LoadUint32(sq.tail)
next := tail + 1
if next-atomic.LoadUint32(sq.head) > uint32(bs.ring.ringSize) {
return nil // Queue is full
}
idx := tail & *sq.ringMask
return &sq.sqes[idx]
}
// submit submits pending SQEs to the kernel
func (bs *IOUringBackingStore) submit() error {
if bs.ring == nil {
return fmt.Errorf("io_uring not initialized")
}
// Update tail
atomic.StoreUint32(bs.ring.sq.tail, atomic.LoadUint32(bs.ring.sq.tail)+1)
// Submit using io_uring_enter syscall
_, _, errno := syscall.Syscall6(426, // __NR_io_uring_enter
uintptr(bs.ring.fd),
uintptr(1), // submit 1 operation
0, // min complete
0, // flags
0, 0)
if errno != 0 {
return fmt.Errorf("io_uring_enter failed: %v", errno)
}
atomic.AddUint64(&bs.opsSubmitted, 1)
return nil
}
// submitAndWait submits operations and waits for completions
func (bs *IOUringBackingStore) submitAndWait(minComplete uint32) error {
if bs.ring == nil {
return fmt.Errorf("io_uring not initialized")
}
// Update tail
atomic.StoreUint32(bs.ring.sq.tail, atomic.LoadUint32(bs.ring.sq.tail)+1)
// Submit and wait
_, _, errno := syscall.Syscall6(426, // __NR_io_uring_enter
uintptr(bs.ring.fd),
uintptr(1), // submit 1 operation
uintptr(minComplete), // min complete
0, // flags
0, 0)
if errno != 0 {
return fmt.Errorf("io_uring_enter failed: %v", errno)
}
return nil
}
// getCqe gets a completion queue entry
func (bs *IOUringBackingStore) getCqe() (*ioCqringEntry, error) {
cq := bs.ring.cq
// Wait for completion
for atomic.LoadUint32(cq.head) == atomic.LoadUint32(cq.tail) {
// Spin-wait for completion
runtime.Gosched()
}
head := atomic.LoadUint32(cq.head)
idx := head & *cq.ringMask
cqe := &cq.cqes[idx]
// Update head
atomic.StoreUint32(cq.head, head+1)
return cqe, nil
}
// Stats returns io_uring statistics
func (bs *IOUringBackingStore) Stats() (submitted, completed uint64) {
return atomic.LoadUint64(&bs.opsSubmitted), atomic.LoadUint64(&bs.opsCompleted)
}
// Available returns true if io_uring is available on this system
func Available() bool {
return ioUringEnabled
}

View File

@@ -0,0 +1,33 @@
//go:build !linux
// +build !linux
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package iouring
import (
// io_uring is not available on non-Linux platforms
)
func init() {
// io_uring is not available on non-Linux platforms
}
// Available returns false on non-Linux platforms
func Available() bool {
return false
}

View File

@@ -249,21 +249,60 @@ func SCSICDBBufXLength(scb []byte) (int64, bool) {
opcode = scb[0]
group = SCSICDBGroupID(opcode)
// Note: group is 0-7, not the CDB length (6, 10, 12, 16)
switch group {
case CDB_GROUPID_0:
length = int64(scb[4])
case CDB_GROUPID_2:
length = int64(util.GetUnalignedUint16(scb[7:9]))
case CDB_GROUPID_3:
case 0: // GROUPID_0: 6-byte commands
// INQUIRY (0x12) and REQUEST_SENSE (0x03) have Allocation Length in bytes 3-4
if opcode == 0x12 || opcode == 0x03 {
length = int64(util.GetUnalignedUint16(scb[3:5]))
} else {
// For other Group 0 commands (READ_6, WRITE_6, etc.),
// byte 4 is typically Transfer Length, not Allocation Length.
// We should not use it to limit sense data buffer.
ok = false
}
case 1, 2: // GROUPID_1, GROUPID_2: 10-byte commands
// PERSISTENT_RESERVE_IN (0x5E) and PERSISTENT_RESERVE_OUT (0x5F)
// have Allocation Length in bytes 6-7, not 7-8
if opcode == 0x5E || opcode == 0x5F {
// Manual BigEndian conversion for PRIN/PROUT
length = int64(uint16(scb[6])<<8 | uint16(scb[7]))
} else if opcode == 0x28 || opcode == 0x2A || opcode == 0x2E || opcode == 0x35 ||
opcode == 0x34 || opcode == 0x2F || opcode == 0x41 || opcode == 0x55 ||
opcode == 0x5A || opcode == 0x56 || opcode == 0x57 {
// READ_10(0x28), WRITE_10(0x2A), WRITE_VERIFY(0x2E), SYNCHRONIZE_CACHE(0x35),
// PRE_FETCH_10(0x34), VERIFY_10(0x2F), WRITE_SAME(0x41), MODE_SELECT_10(0x55),
// MODE_SENSE_10(0x5A), RESERVE_10(0x56), RELEASE_10(0x57)
// These commands have Transfer Length or Parameter List Length in bytes 7-8,
// not Allocation Length.
ok = false
} else {
length = int64(util.GetUnalignedUint16(scb[7:9]))
}
case 3: // GROUPID_3: variable length
if opcode == 0x7F {
length = int64(scb[7])
} else {
ok = false
}
case CDB_GROUPID_4:
length = int64(util.GetUnalignedUint32(scb[6:10]))
case CDB_GROUPID_5:
length = int64(util.GetUnalignedUint32(scb[10:14]))
case 4: // GROUPID_4: 16-byte commands
// READ_16(0x88), WRITE_16(0x8A), WRITE_VERIFY_16(0x8E), SYNCHRONIZE_CACHE_16(0x91),
// PRE_FETCH_16(0x90), VERIFY_16(0x8F), WRITE_SAME_16(0x93), ORWRITE_16(0x8B)
if opcode == 0x88 || opcode == 0x8A || opcode == 0x8E || opcode == 0x91 ||
opcode == 0x90 || opcode == 0x8F || opcode == 0x93 || opcode == 0x8B {
// These commands have Transfer Length in bytes 6-9, not Allocation Length
ok = false
} else {
length = int64(util.GetUnalignedUint32(scb[6:10]))
}
case 5: // GROUPID_5: 12-byte commands
// READ_12(0xA8), WRITE_12(0xAA), WRITE_VERIFY_12(0xAE), VERIFY_12(0xAF)
if opcode == 0xA8 || opcode == 0xAA || opcode == 0xAE || opcode == 0xAF {
// These commands have Transfer Length in bytes 10-13, not Allocation Length
ok = false
} else {
length = int64(util.GetUnalignedUint32(scb[10:14]))
}
default:
ok = false
}

View File

@@ -30,8 +30,30 @@ func NewSCSILu(bs *config.BackendStorage) (*api.SCSILu, error) {
if len(pathinfo) < 2 {
return nil, errors.New("invalid device path string")
}
backendType := pathinfo[0]
backendPath := pathinfo[1]
// Determine backend type: config.BackendType > path prefix > default (file)
backendType := "file"
backendPath := bs.Path
if bs.BackendType != "" {
// Config specifies backend type explicitly
backendType = bs.BackendType
backendPath = pathinfo[1]
} else {
// Infer from path prefix
backendType = pathinfo[0]
backendPath = pathinfo[1]
// Validate backend type, default to file if unknown
switch backendType {
case "file", "iouring", "ceph", "null", "RemBs":
// Valid types
default:
// Unknown type, treat entire path as file path
backendType = "file"
backendPath = bs.Path
}
}
sbc := NewSBCDevice(api.TYPE_DISK)
backing, err := NewBackingStore(backendType)
@@ -53,7 +75,7 @@ func NewSCSILu(bs *config.BackendStorage) (*api.SCSILu, error) {
}
lu.Size = backing.Size(lu)
lu.DeviceProtocol.InitLu(lu)
lu.Attrs.ThinProvisioning = bs.ThinProvisioning
lu.Attrs.ThinProvisioning = true
lu.Attrs.Online = bs.Online
lu.Attrs.Lbppbe = 3
return lu, nil

View File

@@ -18,6 +18,7 @@ limitations under the License.
package scsi
import (
"bytes"
"encoding/binary"
"fmt"
"unsafe"
@@ -105,17 +106,17 @@ func (sbc SBCSCSIDeviceProtocol) InitLu(lu *api.SCSILu) error {
// Vendor uniq - However most apps seem to call for mode page 0
//pages = append(pages, api.ModePage{0, 0, []byte{}})
// Disconnect page
pages = append(pages, api.ModePage{2, 0, 14, []byte{0x80, 0x80, 0, 0xa, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}})
pages = append(pages, api.ModePage{PageCode: 2, SubPageCode: 0, Size: 14, Data: []byte{0x80, 0x80, 0, 0xa, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}})
// Caching Page
pages = append(pages, api.ModePage{8, 0, 18, []byte{0x14, 0, 0xff, 0xff, 0, 0, 0xff, 0xff, 0xff, 0xff, 0x80, 0x14, 0, 0, 0, 0, 0, 0}})
pages = append(pages, api.ModePage{PageCode: 8, SubPageCode: 0, Size: 18, Data: []byte{0x14, 0, 0xff, 0xff, 0, 0, 0xff, 0xff, 0xff, 0xff, 0x80, 0x14, 0, 0, 0, 0, 0, 0}})
// Control page
pages = append(pages, api.ModePage{0x0a, 0, 10, []byte{2, 0x10, 0, 0, 0, 0, 0, 0, 2, 0, 0x08, 0, 0, 0, 0, 0, 0, 0}})
pages = append(pages, api.ModePage{PageCode: 0x0a, SubPageCode: 0, Size: 10, Data: []byte{2, 0x10, 0, 0, 0, 0, 0, 0, 2, 0, 0x08, 0, 0, 0, 0, 0, 0, 0}})
// Control Extensions mode page: TCMOS:1
pages = append(pages, api.ModePage{0x0a, 1, 0x1c, []byte{0x04, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}})
pages = append(pages, api.ModePage{PageCode: 0x0a, SubPageCode: 1, Size: 0x1c, Data: []byte{0x04, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}})
// Informational Exceptions Control page
pages = append(pages, api.ModePage{0x1c, 0, 10, []byte{8, 0, 0, 0, 0, 0, 0, 0, 0, 0}})
pages = append(pages, api.ModePage{PageCode: 0x1c, SubPageCode: 0, Size: 10, Data: []byte{8, 0, 0, 0, 0, 0, 0, 0, 0, 0}})
lu.ModePages = pages
mbd := util.MarshalUint32(uint32(0xffffffff))
if size := lu.Size >> lu.BlockShift; size>>32 == 0 {
@@ -221,6 +222,7 @@ func NewSBCDevice(deviceType api.SCSIDeviceType) api.SCSIDeviceProtocol {
sbc.SCSIDeviceOps[api.WRITE_12] = NewSCSIDeviceOperation(SBCReadWrite, nil, PR_WE_FA|PR_EA_FA|PR_WE_FA|PR_WE_FN)
sbc.SCSIDeviceOps[api.WRITE_VERIFY_12] = NewSCSIDeviceOperation(SBCReadWrite, nil, PR_EA_FA|PR_EA_FN)
sbc.SCSIDeviceOps[api.VERIFY_12] = NewSCSIDeviceOperation(SBCVerify, nil, PR_EA_FA|PR_EA_FN)
sbc.SCSIDeviceOps[api.COMPARE_AND_WRITE] = NewSCSIDeviceOperation(SBCCompareAndWrite, nil, PR_EA_FA|PR_EA_FN)
return sbc
}
@@ -354,15 +356,17 @@ func SBCUnmap(host int, cmd *api.SCSICommand) api.SAMStat {
*/
func SBCReadWrite(host int, cmd *api.SCSICommand) api.SAMStat {
var (
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
dev = cmd.Device
scb = cmd.SCB
opcode = api.SCSICommandType(scb[0])
lba uint64
tl uint32
err error
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
dev = cmd.Device
scb = cmd.SCB
opcode = api.SCSICommandType(scb[0])
lba uint64
tl uint32
err error
totalBlocks uint64
)
if dev.Attrs.Removable && !dev.Attrs.Online {
key = NOT_READY
asc = ASC_MEDIUM_NOT_PRESENT
@@ -422,21 +426,22 @@ func SBCReadWrite(host int, cmd *api.SCSICommand) api.SAMStat {
lba = getSCSIReadWriteOffset(scb)
tl = getSCSIReadWriteCount(scb)
// Calculate total blocks
totalBlocks = dev.Size >> dev.BlockShift
log.Debugf("SBCReadWrite: opcode=0x%x, lba=%d, tl=%d, totalBlocks=%d", opcode, lba, tl, totalBlocks)
// Verify that we are not doing i/o beyond the end-of-lun
if tl != 0 {
if lba+uint64(tl) < lba || lba+uint64(tl) > dev.Size>>dev.BlockShift {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
log.Warnf("sense data(ILLEGAL_REQUEST,ASC_LBA_OUT_OF_RANGE) encounter: lba: %d, tl: %d, size: %d", lba, tl, dev.Size>>dev.BlockShift)
goto sense
}
} else {
if lba >= dev.Size>>dev.BlockShift {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
log.Warnf("sense data(ILLEGAL_REQUEST,ASC_LBA_OUT_OF_RANGE) encounter: lba: %d, size: %d", lba, dev.Size>>dev.BlockShift)
goto sense
}
// Even when transfer length is 0, we must validate the LBA is within range
if lba >= totalBlocks || lba+uint64(tl) < lba || lba+uint64(tl) > totalBlocks {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
log.Warnf("SBCReadWrite: LBA out of range (lba=%d, tl=%d, totalBlocks=%d)", lba, tl, totalBlocks)
goto sense
}
// If transfer length is 0, return GOOD status immediately (no data to transfer)
if tl == 0 {
return api.SAMStatGood
}
cmd.Offset = lba << dev.BlockShift
@@ -495,6 +500,120 @@ func SBCRelease(host int, cmd *api.SCSICommand) api.SAMStat {
return api.SAMStatGood
}
/*
* SBCCompareAndWrite Implements SCSI COMPARE AND WRITE command (0x89)
* The COMPARE AND WRITE command requests that the device server compare the specified
* logical block(s) with data transferred from the data-out buffer and, if they match,
* write the new data from the data-out buffer to the specified logical block(s).
*
* Reference : SBC3r35
* 5.3 - COMPARE AND WRITE
*/
func SBCCompareAndWrite(host int, cmd *api.SCSICommand) api.SAMStat {
var (
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
dev = cmd.Device
scb = cmd.SCB
lba uint64
numBlocks uint32
offset uint64
blockSize uint64
totalCompareLen uint64
expectedDataLen uint64
err error
existingData []byte
compareData []byte
writeData []byte
)
if dev.Attrs.Removable && !dev.Attrs.Online {
key = NOT_READY
asc = ASC_MEDIUM_NOT_PRESENT
goto sense
}
// We only support protection information type 0
if scb[1]&0xe0 != 0 {
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
goto sense
}
if dev.Attrs.Readonly || dev.Attrs.SWP {
key = DATA_PROTECT
asc = ASC_WRITE_PROTECT
goto sense
}
// Number of logical blocks (one byte: bits 0-7)
numBlocks = uint32(scb[13])
if numBlocks == 0 {
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
goto sense
}
lba = getSCSIReadWriteOffset(scb)
// Verify that we are not doing i/o beyond the end-of-lun
if lba+uint64(numBlocks) < lba || lba+uint64(numBlocks) > dev.Size>>dev.BlockShift {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
log.Warnf("COMPARE_AND_WRITE: lba out of range: lba: %d, num: %d, size: %d", lba, numBlocks, dev.Size>>dev.BlockShift)
goto sense
}
offset = lba << dev.BlockShift
blockSize = uint64(1 << dev.BlockShift)
totalCompareLen = uint64(numBlocks) * blockSize
// Data-out buffer contains: compare data followed by write data
// Total length should be 2 * numBlocks * blockSize
expectedDataLen = 2 * totalCompareLen
if uint64(cmd.OutSDBBuffer.Length) < expectedDataLen {
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
log.Warnf("COMPARE_AND_WRITE: data length too short: got %d, expected %d", cmd.OutSDBBuffer.Length, expectedDataLen)
goto sense
}
compareData = cmd.OutSDBBuffer.Buffer[:totalCompareLen]
writeData = cmd.OutSDBBuffer.Buffer[totalCompareLen:expectedDataLen]
// Read existing data from storage
existingData, err = dev.Storage.Read(int64(offset), int64(totalCompareLen))
if err != nil {
log.Errorf("COMPARE_AND_WRITE: failed to read data: %v", err)
key = MEDIUM_ERROR
asc = ASC_READ_ERROR
goto sense
}
// Compare data
if !bytes.Equal(existingData, compareData) {
key = MISCOMPARE
asc = ASC_MISCOMPARE_DURING_VERIFY_OPERATION
log.Warnf("COMPARE_AND_WRITE: data miscompare at LBA %d", lba)
goto sense
}
// Data matches, write new data
err = dev.Storage.Write(writeData, int64(offset))
if err != nil {
log.Errorf("COMPARE_AND_WRITE: failed to write data: %v", err)
key = MEDIUM_ERROR
asc = ASC_WRITE_ERROR
goto sense
}
return api.SAMStatGood
sense:
BuildSenseData(cmd, key, asc)
return api.SAMStatCheckCondition
}
/*
* SBCReadCapacity Implements SCSI READ CAPACITY(10) command
* The READ CAPACITY (10) command requests that the device server transfer 8 bytes of parameter data
@@ -558,13 +677,14 @@ sense:
*/
func SBCVerify(host int, cmd *api.SCSICommand) api.SAMStat {
var (
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
dev = cmd.Device
scb = cmd.SCB
lba uint64
tl uint32
err error
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
dev = cmd.Device
scb = cmd.SCB
lba uint64
tl uint32
err error
totalBlocks uint64
)
if dev.Attrs.Removable && !dev.Attrs.Online {
key = NOT_READY
@@ -579,28 +699,21 @@ func SBCVerify(host int, cmd *api.SCSICommand) api.SAMStat {
goto sense
}
if scb[1]&0x02 == 0 {
// no data compare with the media
return api.SAMStatGood
}
lba = getSCSIReadWriteOffset(scb)
tl = getSCSIReadWriteCount(scb)
// Verify that we are not doing i/o beyond the end-of-lun
if tl != 0 {
if lba+uint64(tl) < lba || lba+uint64(tl) > dev.Size>>dev.BlockShift {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
log.Warnf("sense: lba: %d, tl: %d, size: %d", lba, tl, dev.Size>>dev.BlockShift)
goto sense
}
} else {
if lba >= dev.Size>>dev.BlockShift {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
log.Warnf("sense")
goto sense
}
// Must check LBA range before BYTCHK early return per SBC spec
totalBlocks = dev.Size >> dev.BlockShift
if lba >= totalBlocks || lba+uint64(tl) < lba || lba+uint64(tl) > totalBlocks {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
goto sense
}
if scb[1]&0x02 == 0 {
// BYTCHK=0: no data compare with the media
return api.SAMStatGood
}
cmd.Offset = lba << dev.BlockShift
@@ -647,12 +760,13 @@ func SBCReadCapacity16(host int, cmd *api.SCSICommand) api.SAMStat {
func SBCGetLbaStatus(host int, cmd *api.SCSICommand) api.SAMStat {
var (
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
dev = cmd.Device
scb = cmd.SCB
lba uint64
tl uint32
key = ILLEGAL_REQUEST
asc = ASC_INVALID_FIELD_IN_CDB
dev = cmd.Device
scb = cmd.SCB
lba uint64
tl uint32
totalBlocks uint64
)
if dev.Attrs.Removable && !dev.Attrs.Online {
key = NOT_READY
@@ -674,20 +788,13 @@ func SBCGetLbaStatus(host int, cmd *api.SCSICommand) api.SAMStat {
lba = getSCSIReadWriteOffset(scb)
tl = getSCSIReadWriteCount(scb)
// Verify that we are not doing i/o beyond the end-of-lun
if tl != 0 {
if lba+uint64(tl) < lba || lba+uint64(tl) > dev.Size>>dev.BlockShift {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
log.Warnf("sense: lba: %d, tl: %d, size: %d", lba, tl, dev.Size>>dev.BlockShift)
goto sense
}
} else {
if lba >= dev.Size>>dev.BlockShift {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
log.Warnf("sense")
goto sense
}
totalBlocks = dev.Size >> dev.BlockShift
log.Warnf("DEBUG: dev.Size=%d, BlockShift=%d, totalBlocks=%d", dev.Size, dev.BlockShift, totalBlocks)
if lba >= totalBlocks || lba+uint64(tl) < lba || lba+uint64(tl) > totalBlocks {
key = ILLEGAL_REQUEST
asc = ASC_LBA_OUT_OF_RANGE
log.Warnf("sense: lba: %d, tl: %d, totalBlocks: %d", lba, tl, totalBlocks)
goto sense
}
return api.SAMStatGood
sense:

View File

@@ -24,7 +24,6 @@ import (
"unsafe"
"github.com/gostor/gotgt/pkg/api"
uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus"
)
@@ -46,13 +45,13 @@ func NewSCSITargetService() *SCSITargetService {
}
// GetTargetList get SCSI target list
func (s *SCSITargetService) GetTargetList() ([]api.SCSITarget, error) {
result := []api.SCSITarget{}
func (s *SCSITargetService) GetTargetList() ([]*api.SCSITarget, error) {
result := []*api.SCSITarget{}
s.mutex.RLock()
defer s.mutex.RUnlock()
for _, t := range s.Targets {
result = append(result, *t)
result = append(result, t)
}
s.mutex.RUnlock()
return result, nil
}
@@ -91,7 +90,7 @@ func (s *SCSITargetService) AddCommandQueue(tid int, scmd *api.SCSICommand) erro
}
scmd.Target = target
for _, it := range target.ITNexus {
if uuid.Equal(it.ID, scmd.ITNexusID) {
if it.ID == scmd.ITNexusID {
itn = it
break
}
@@ -199,8 +198,9 @@ func BuildSenseData(cmd *api.SCSICommand, key byte, asc SCSISubError) {
} else {
log.Debugf("cannot calc cbd alloc length. truncate failed")
}
cmd.Result = key
cmd.SenseBuffer = &api.SenseBuffer{senseBuffer.Bytes(), length}
// Note: cmd.Result should be set by the caller, not here
// The caller should set cmd.Result = api.SAM_STAT_CHECK_CONDITION when returning error
cmd.SenseBuffer = &api.SenseBuffer{Buffer: senseBuffer.Bytes(), Length: length}
}
func getSCSIReadWriteOffset(scb []byte) uint64 {
@@ -234,6 +234,8 @@ func getSCSIReadWriteCount(scb []byte) uint32 {
cnt = uint32(scb[7])<<8 | uint32(scb[8])
case api.READ_12, api.WRITE_12, api.VERIFY_12, api.WRITE_VERIFY_12:
cnt = binary.BigEndian.Uint32(scb[6:])
// Note: READ(12)/WRITE(12) have 32-bit transfer length field, but only use lower 16 bits
// per SCSI SBC-3 spec. Zero means zero blocks.
case api.READ_16, api.PRE_FETCH_16, api.WRITE_16, api.ORWRITE_16, api.VERIFY_16, api.WRITE_VERIFY_16, api.WRITE_SAME_16, api.SYNCHRONIZE_CACHE_16:
cnt = binary.BigEndian.Uint32(scb[10:])
case api.COMPARE_AND_WRITE:

156
pkg/scsi/scsi_perf_test.go Normal file
View File

@@ -0,0 +1,156 @@
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package scsi
import (
"testing"
"github.com/gostor/gotgt/pkg/api"
)
// BenchmarkBuildSenseData benchmarks Sense Data building performance
func BenchmarkBuildSenseData(b *testing.B) {
cmd := &api.SCSICommand{
SCB: []byte{0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00},
Device: &api.SCSILu{
Attrs: api.SCSILuPhyAttribute{
SenseFormat: false, // Fixed format
},
},
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
BuildSenseData(cmd, ILLEGAL_REQUEST, ASC_INVALID_FIELD_IN_CDB)
}
}
// BenchmarkBuildSenseDataDescriptor benchmarks Descriptor Format Sense Data building performance
func BenchmarkBuildSenseDataDescriptor(b *testing.B) {
cmd := &api.SCSICommand{
SCB: []byte{0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00},
Device: &api.SCSILu{
Attrs: api.SCSILuPhyAttribute{
SenseFormat: true, // Descriptor format
},
},
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
BuildSenseData(cmd, ILLEGAL_REQUEST, ASC_INVALID_FIELD_IN_CDB)
}
}
// BenchmarkGetSCSIReadWriteOffset benchmarks offset calculation performance
func BenchmarkGetSCSIReadWriteOffset(b *testing.B) {
scb := []byte{0x28, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x08, 0x00} // READ_10 at LBA 0x1000
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = getSCSIReadWriteOffset(scb)
}
}
// BenchmarkGetSCSIReadWriteCount benchmarks block count calculation performance
func BenchmarkGetSCSIReadWriteCount(b *testing.B) {
scb := []byte{0x28, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x08, 0x00} // READ_10, 8 blocks
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = getSCSIReadWriteCount(scb)
}
}
// BenchmarkSCSIDeviceOperation benchmarks SCSI device operation lookup performance
func BenchmarkSCSIDeviceOperation(b *testing.B) {
lu := &api.SCSILu{}
deviceType := api.TYPE_DISK
sbc := NewSBCDevice(deviceType)
sbc.InitLu(lu)
opcodes := []api.SCSICommandType{
api.INQUIRY, // Must be implemented
api.READ_CAPACITY, // Must be implemented
api.MODE_SENSE, // Must be implemented
api.TEST_UNIT_READY, // Must be implemented
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
opcode := opcodes[i%len(opcodes)]
_ = sbc.PerformCommand(int(opcode))
}
}
// BenchmarkSCSICommandAlloc benchmarks SCSI command allocation performance
func BenchmarkSCSICommandAlloc(b *testing.B) {
b.Run("WithAllocation", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = &api.SCSICommand{
OpCode: byte(i % 256),
SCB: make([]byte, 16),
}
}
})
b.Run("Reuse", func(b *testing.B) {
b.ReportAllocs()
cmd := &api.SCSICommand{
SCB: make([]byte, 16),
}
for i := 0; i < b.N; i++ {
cmd.OpCode = byte(i % 256)
cmd.Result = 0
}
})
}
// BenchmarkSCSICommandTypeSwitch benchmarks SCSI command type switching performance
func BenchmarkSCSICommandTypeSwitch(b *testing.B) {
opcodes := []api.SCSICommandType{
api.READ_6, api.READ_10, api.READ_12, api.READ_16,
api.WRITE_6, api.WRITE_10, api.WRITE_12, api.WRITE_16,
api.INQUIRY, api.READ_CAPACITY, api.MODE_SENSE,
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
opcode := opcodes[i%len(opcodes)]
switch opcode {
case api.READ_6, api.READ_10, api.READ_12, api.READ_16:
// Read operation
case api.WRITE_6, api.WRITE_10, api.WRITE_12, api.WRITE_16:
// Write operation
case api.INQUIRY:
// Inquiry
case api.READ_CAPACITY:
// Read capacity
case api.MODE_SENSE:
// Mode sense
default:
// Unknown
}
}
}

View File

@@ -17,8 +17,8 @@ limitations under the License.
package scsi
import (
"github.com/google/uuid"
"github.com/gostor/gotgt/pkg/api"
"github.com/satori/go.uuid"
)
type SCSIReservationOperator interface {
@@ -101,7 +101,7 @@ func (op *SCSISimpleReservationOperator) GetReservation(tgtName string, devUUID
return nil
}
for _, SCSIRes = range LURes.Reservations {
if uuid.Equal(SCSIRes.ITNexusID, ITNexusID) {
if SCSIRes.ITNexusID == ITNexusID {
return SCSIRes
}
}

View File

@@ -22,9 +22,9 @@ import (
"encoding/binary"
"fmt"
"github.com/google/uuid"
"github.com/gostor/gotgt/pkg/api"
"github.com/gostor/gotgt/pkg/util"
uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus"
)
@@ -228,6 +228,33 @@ func InquiryPage0xB0(host int, cmd *api.SCSICommand) (*bytes.Buffer, uint16) {
return buf, pageLength
}
func InquiryPage0xB1(host int, cmd *api.SCSICommand) (*bytes.Buffer, uint16) {
var (
buf = &bytes.Buffer{}
pageLength uint16 = 0x3C // 60 bytes
)
//byte 0
if cmd.Device.Attrs.Online {
buf.WriteByte(PQ_DEVICE_CONNECTED | byte(cmd.Device.Attrs.DeviceType))
} else {
buf.WriteByte(PQ_DEVICE_NOT_CONNECT | byte(cmd.Device.Attrs.DeviceType))
}
//PAGE CODE
buf.WriteByte(0xB1)
//PAGE LENGTH
binary.Write(buf, binary.BigEndian, pageLength)
// MEDIA ROTATION RATE (bytes 4-5)
// 0x0001 = Non-rotating medium (SSD)
binary.Write(buf, binary.BigEndian, uint16(0x0001))
// Reserved bytes (6-63)
buf.Write(make([]byte, 58))
return buf, pageLength
}
func InquiryPage0xB2(host int, cmd *api.SCSICommand) (*bytes.Buffer, uint16) {
var (
buf = &bytes.Buffer{}
@@ -311,6 +338,8 @@ func SPCInquiry(host int, cmd *api.SCSICommand) api.SAMStat {
buf, _ = InquiryPage0x83(host, cmd)
case 0xB0:
buf, _ = InquiryPage0xB0(host, cmd)
case 0xB1:
buf, _ = InquiryPage0xB1(host, cmd)
case 0xB2:
buf, _ = InquiryPage0xB2(host, cmd)
default:
@@ -565,7 +594,6 @@ func SPCModeSense(host int, cmd *api.SCSICommand) api.SAMStat {
asc = ASC_INVALID_FIELD_IN_CDB
data []byte
allocLen uint32
i uint32
)
if dbd == 0 {
@@ -577,16 +605,31 @@ func SPCModeSense(host int, cmd *api.SCSICommand) api.SAMStat {
}
if mode6 {
allocLen = uint32(scb[4])
// set header
for i = 0; i < 4 && i < allocLen; i++ {
data = append(data, 0x00)
}
// set header (4 bytes)
// byte 0: Mode Data Length
// byte 1: Medium Type
// byte 2: Device-Specific Parameter (DPOFUA=bit4)
// byte 3: Block Descriptor Length
data = append(data, 0x00) // Mode Data Length (filled later)
data = append(data, 0x00) // Medium Type
data = append(data, 0x10) // Device-Specific Parameter (DPOFUA=1)
data = append(data, 0x00) // Block Descriptor Length (filled later)
} else {
allocLen = uint32(util.GetUnalignedUint16(scb[7:9]))
// set header
for i = 0; i < 8 && i < allocLen; i++ {
data = append(data, 0x00)
}
// set header (8 bytes)
// byte 0-1: Mode Data Length
// byte 2: Medium Type
// byte 3: Device-Specific Parameter (DPOFUA=bit4)
// byte 4-5: Reserved
// byte 6-7: Block Descriptor Length
data = append(data, 0x00) // Mode Data Length (MSB, filled later)
data = append(data, 0x00) // Mode Data Length (LSB, filled later)
data = append(data, 0x00) // Medium Type
data = append(data, 0x10) // Device-Specific Parameter (DPOFUA=1)
data = append(data, 0x00) // Reserved
data = append(data, 0x00) // Reserved
data = append(data, 0x00) // Block Descriptor Length (MSB, filled later)
data = append(data, 0x00) // Block Descriptor Length (LSB, filled later)
}
if dbd == 0 {
data = append(data, cmd.Device.ModeBlockDescriptor...)
@@ -595,12 +638,12 @@ func SPCModeSense(host int, cmd *api.SCSICommand) api.SAMStat {
for _, pg := range cmd.Device.ModePages {
if pg.SubPageCode == 0 {
data = append(data, pg.PageCode)
data = append(data, pg.Size)
data = append(data, byte(pg.Size))
} else {
data = append(data, pg.PageCode|0x40)
data = append(data, pg.SubPageCode)
data = append(data, (pg.Size>>8)&0xff)
data = append(data, pg.Size&0xff)
data = append(data, byte((pg.Size>>8)&0xff))
data = append(data, byte(pg.Size&0xff))
}
if pctrl == 1 {
data = append(data, pg.Data[pg.Size:]...)
@@ -621,7 +664,7 @@ func SPCModeSense(host int, cmd *api.SCSICommand) api.SAMStat {
}
if pg.SubPageCode == 0 {
data = append(data, pg.PageCode)
data = append(data, pg.Size)
data = append(data, byte(pg.Size))
if pctrl == 1 {
data = append(data, pg.Data[pg.Size:]...)
} else {
@@ -630,8 +673,8 @@ func SPCModeSense(host int, cmd *api.SCSICommand) api.SAMStat {
} else {
data = append(data, pg.PageCode|0x40)
data = append(data, pg.SubPageCode)
data = append(data, (pg.Size>>8)&0xff)
data = append(data, pg.Size&0xff)
data = append(data, byte((pg.Size>>8)&0xff))
data = append(data, byte(pg.Size&0xff))
if pctrl == 1 {
data = append(data, pg.Data[pg.Size:]...)
} else {
@@ -700,17 +743,17 @@ func reportOpcodesAll(cmd *api.SCSICommand, rctd int) error {
data = append(data, 0x00)
// reserved
data = append(data, 0x00)
// flags : no service action, possibly timeout desc
// flags: no service action, possibly timeout desc
if rctd != 0 {
data = append(data, 0x02)
data = append(data, 0x08)
} else {
data = append(data, 0x00)
data = append(data, 0x08)
}
// cdb length
length := getSCSICmdSize(i)
data = append(data, 0)
data = append(data, length&0xff)
// timeout descriptor
// timeout descriptor (if rctd is set) - 12 bytes (all zeros)
if rctd != 0 {
// length == 0x0a
data[1] = 0x0a
@@ -725,7 +768,53 @@ func reportOpcodesAll(cmd *api.SCSICommand, rctd int) error {
}
func reportOpcodeOne(cmd *api.SCSICommand, rctd int, opcode byte, rsa uint16, serviceAction bool) error {
return fmt.Errorf("rsa: %xh, sa:%v not supported", rsa, serviceAction)
var data = []byte{0x00, 0x00, 0x00, 0x00}
// Support common opcodes that are tested by libiscsi
switch api.SCSICommandType(opcode) {
case api.READ_6, api.READ_10, api.READ_12, api.READ_16,
api.WRITE_6, api.WRITE_10, api.WRITE_12, api.WRITE_16,
api.WRITE_VERIFY, api.WRITE_VERIFY_12, api.WRITE_VERIFY_16,
api.INQUIRY, api.TEST_UNIT_READY, api.READ_CAPACITY,
api.VERIFY_10, api.VERIFY_12, api.VERIFY_16:
// For RCTD=0, libiscsi expects:
// data[0:4]: list length
// data[4:20]: CDB usage data (16 bytes)
// libiscsi reads ctdp from data[1], cdb_length from data[2:4]
// and copies data[4:4+cdb_length] to cdb_usage_data
//
// So we need to format data as:
// data[4]: opcode (CDB usage data byte 0)
// data[5]: byte 1 with DPO/FUA bits
// data[6:20]: remaining CDB usage data bytes
// CDB usage data (16 bytes) - describes the CDB format
cdbUsageData := make([]byte, 16)
cdbUsageData[0] = opcode // byte 0: opcode
// byte 1: RDPROTECT(7-5) | DPO(4) | FUA(3) | ...
// Set DPO(0x10) | FUA(0x08) = 0x18 for READ/WRITE/VERIFY/WRITE_VERIFY
if opcode == 0x28 || opcode == 0x2A || opcode == 0x2F || // READ10, WRITE10, VERIFY10
opcode == 0xA8 || opcode == 0xAA || opcode == 0xAF || // READ12, WRITE12, VERIFY12
opcode == 0x88 || opcode == 0x8A || opcode == 0x8F || // READ16, WRITE16, VERIFY16
opcode == 0x2E || opcode == 0xAE || opcode == 0x8E { // WRITE_VERIFY, WRITE_VERIFY_12, WRITE_VERIFY_16
cdbUsageData[1] = 0x18 // DPO | FUA
}
data = append(data, cdbUsageData...)
// timeout descriptor (if rctd is set) - 12 bytes (all zeros)
if rctd != 0 {
for n := 0; n < 12; n++ {
data = append(data, 0x00)
}
}
default:
return fmt.Errorf("opcode: %02xh not supported in report one", opcode)
}
// Update list length (total bytes after the length field)
copy(cmd.InSDBBuffer.Buffer, util.MarshalUint32(uint32(len(data)-4)))
copy(cmd.InSDBBuffer.Buffer[4:], data[4:])
return nil
}
func SPCReportSupportedOperationCodes(host int, cmd *api.SCSICommand) api.SAMStat {
@@ -799,6 +888,8 @@ func SPCPRReadKeys(host int, cmd *api.SCSICommand) api.SAMStat {
scsiResOp := GetSCSIReservationOperator()
PRGeneration, _ := scsiResOp.GetPRGeneration(tgtName, devUUID)
resList := scsiResOp.GetReservationList(tgtName, devUUID)
length, _ := SCSICDBBufXLength(cmd.SCB)
allocationLength = uint16(length)
if allocationLength < 8 {
goto sense
}
@@ -979,7 +1070,7 @@ func SPCPRRegister(host int, cmd *api.SCSICommand) api.SAMStat {
if ignoreKey || resKey == 0 {
if sAResKey != 0 {
newRes := &api.SCSIReservation{
ID: uuid.NewV1(),
ID: uuid.New(),
Key: sAResKey,
ITNexusID: cmd.ITNexusID,
}

View File

@@ -20,8 +20,8 @@ import (
"fmt"
"unsafe"
"github.com/google/uuid"
"github.com/gostor/gotgt/pkg/api"
uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus"
)
@@ -37,7 +37,7 @@ func (s *SCSITargetService) NewSCSITarget(tid int, driverName, name string) (*ap
TargetPortGroups: []*api.TargetPortGroup{},
ITNexus: make(map[uuid.UUID]*api.ITNexus),
}
tpg := &api.TargetPortGroup{0, []*api.SCSITargetPort{}}
tpg := &api.TargetPortGroup{GroupID: 0, TargetPortGroup: []*api.SCSITargetPort{}}
s.Targets = append(s.Targets, target)
target.Devices = GetTargetLUNMap(target.Name)
target.LUN0 = NewLUN0()
@@ -110,7 +110,7 @@ func deviceReserve(cmd *api.SCSICommand) error {
return nil
}
if !uuid.Equal(lu.ReserveID, uuid.Nil) && uuid.Equal(lu.ReserveID, cmd.ITNexusID) {
if lu.ReserveID != uuid.Nil && lu.ReserveID == cmd.ITNexusID {
log.Errorf("already reserved %d, %d", lu.ReserveID, cmd.ITNexusID)
return fmt.Errorf("already reserved")
}

255
pkg/util/numa/numa.go Normal file
View File

@@ -0,0 +1,255 @@
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package numa provides NUMA-aware utilities for multi-socket systems.
// This package enables memory allocation optimization and thread binding
// for better performance on NUMA architectures.
package numa
import (
"fmt"
"runtime"
"sync"
)
// NodeID represents a NUMA node identifier
type NodeID int
// NodeInfo contains information about a NUMA node
type NodeInfo struct {
ID NodeID
CPUs []int // CPU cores on this node
TotalMemory uint64 // Total memory in bytes
FreeMemory uint64 // Free memory in bytes
DistanceToNode []uint32 // Distance to other nodes (lower is closer)
}
// Topology represents the NUMA topology of the system
type Topology struct {
Nodes map[NodeID]*NodeInfo
NumNodes int
CPUToNodeMap map[int]NodeID
mu sync.RWMutex
}
var (
globalTopology *Topology
globalTopologyOnce sync.Once
numaAvailable bool
)
// Available returns true if NUMA support is available on this system
func Available() bool {
return numaAvailable
}
// GetTopology returns the NUMA topology of the system
func GetTopology() *Topology {
globalTopologyOnce.Do(func() {
globalTopology = detectTopology()
})
return globalTopology
}
// detectTopology detects the NUMA topology of the system
// This is a placeholder that will be implemented per-platform
func detectTopology() *Topology {
topology := &Topology{
Nodes: make(map[NodeID]*NodeInfo),
CPUToNodeMap: make(map[int]NodeID),
}
// Try to detect using platform-specific methods
if err := detectLinuxTopology(topology); err != nil {
// Fall back to single-node topology
topology.NumNodes = 1
topology.Nodes[0] = &NodeInfo{
ID: 0,
CPUs: makeRange(0, runtime.NumCPU()),
TotalMemory: 0, // Unknown
FreeMemory: 0, // Unknown
}
for i := 0; i < runtime.NumCPU(); i++ {
topology.CPUToNodeMap[i] = 0
}
numaAvailable = false
} else {
numaAvailable = topology.NumNodes > 1
}
return topology
}
// makeRange creates a slice of integers from start to end
func makeRange(start, end int) []int {
result := make([]int, end-start)
for i := range result {
result[i] = start + i
}
return result
}
// GetNodeForCPU returns the NUMA node ID for a given CPU
func (t *Topology) GetNodeForCPU(cpu int) (NodeID, bool) {
t.mu.RLock()
defer t.mu.RUnlock()
node, ok := t.CPUToNodeMap[cpu]
return node, ok
}
// GetNode returns information about a specific NUMA node
func (t *Topology) GetNode(id NodeID) (*NodeInfo, bool) {
t.mu.RLock()
defer t.mu.RUnlock()
node, ok := t.Nodes[id]
return node, ok
}
// GetCurrentNode returns the NUMA node of the current thread
func GetCurrentNode() (NodeID, error) {
return getCurrentNodeImpl()
}
// PreferredNode represents a preferred NUMA node for memory allocation
type PreferredNode struct {
nodeID NodeID
}
// SetPreferredNode sets the preferred NUMA node for the current thread
func SetPreferredNode(node NodeID) (*PreferredNode, error) {
return setPreferredNodeImpl(node)
}
// Revert restores the previous NUMA policy
func (p *PreferredNode) Revert() error {
return revertPreferredNodeImpl(p)
}
// MemoryPolicy represents memory allocation policies
type MemoryPolicy int
const (
// MPDefault uses the default memory policy
MPDefault MemoryPolicy = iota
// MPBind binds memory allocation to specific nodes
MPBind
// MPPreferred prefers memory allocation from specific nodes
MPPreferred
// MPInterleave interleaves memory across nodes
MPInterleave
)
// SetMemoryPolicy sets the memory policy for the current thread
func SetMemoryPolicy(policy MemoryPolicy, nodes []NodeID) error {
return setMemoryPolicyImpl(policy, nodes)
}
// AllocateOnNode allocates memory on a specific NUMA node
func AllocateOnNode(size int, node NodeID) ([]byte, error) {
return allocateOnNodeImpl(size, node)
}
// LocalAlloc allocates memory on the local NUMA node
func LocalAlloc(size int) ([]byte, error) {
node, err := GetCurrentNode()
if err != nil {
// Fall back to regular allocation
return make([]byte, size), nil
}
return AllocateOnNode(size, node)
}
// NodeLocalPool is a memory pool that allocates from a specific NUMA node
type NodeLocalPool struct {
nodeID NodeID
pool sync.Pool
size int
}
// NewNodeLocalPool creates a new NUMA-local memory pool
func NewNodeLocalPool(size int, node NodeID) *NodeLocalPool {
return &NodeLocalPool{
nodeID: node,
size: size,
pool: sync.Pool{
New: func() interface{} {
buf, err := AllocateOnNode(size, node)
if err != nil {
// Fall back to regular allocation
return make([]byte, size)
}
return buf
},
},
}
}
// Get returns a buffer from the pool
func (p *NodeLocalPool) Get() []byte {
return p.pool.Get().([]byte)
}
// Put returns a buffer to the pool
func (p *NodeLocalPool) Put(buf []byte) {
if buf != nil && len(buf) >= p.size {
p.pool.Put(buf[:p.size])
}
}
// Close releases all resources associated with the pool
func (p *NodeLocalPool) Close() error {
// In Go, sync.Pool doesn't have a Close method
// The memory will be garbage collected eventually
return nil
}
// NodeScheduler schedules tasks on specific NUMA nodes
type NodeScheduler struct {
topology *Topology
mu sync.RWMutex
}
// NewNodeScheduler creates a new NUMA-aware scheduler
func NewNodeScheduler() *NodeScheduler {
return &NodeScheduler{
topology: GetTopology(),
}
}
// ScheduleOnNode schedules a function to run on a specific NUMA node
func (s *NodeScheduler) ScheduleOnNode(node NodeID, fn func()) error {
nodeInfo, ok := s.topology.GetNode(node)
if !ok {
return fmt.Errorf("NUMA node %d not found", node)
}
if len(nodeInfo.CPUs) == 0 {
return fmt.Errorf("NUMA node %d has no CPUs", node)
}
return scheduleOnNodeImpl(nodeInfo.CPUs[0], fn)
}
// GetPreferredNodeForCurrentThread returns the preferred NUMA node
// based on current thread's affinity
func GetPreferredNodeForCurrentThread() NodeID {
return getPreferredNodeForCurrentThreadImpl()
}
// NumNodes returns the number of NUMA nodes in the system
func NumNodes() int {
return GetTopology().NumNodes
}

469
pkg/util/numa/numa_linux.go Normal file
View File

@@ -0,0 +1,469 @@
//go:build linux
// +build linux
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package numa
import (
"bufio"
"fmt"
"os"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
)
// #include <stdlib.h>
// #include <unistd.h>
// #include <sys/syscall.h>
// #include <linux/mempolicy.h>
// #include <numa.h>
// #include <numaif.h>
//
// #cgo LDFLAGS: -lnuma
import "C"
const (
// NUMA memory policies (from linux/mempolicy.h)
MPOL_DEFAULT = 0
MPOL_PREFERRED = 1
MPOL_BIND = 2
MPOL_INTERLEAVE = 3
MPOL_LOCAL = 4
MPOL_MAX = 5
// Flags for mbind
MPOL_MF_STRICT = 1 << 0
MPOL_MF_MOVE = 1 << 1
MPOL_MF_MOVE_ALL = 1 << 2
MPOL_MF_LAZY = 1 << 3
MPOL_MF_INTERNAL = 1 << 4
MPOL_MF_VALID = 1 << 5
MPOL_MF_WAKE = 1 << 6
MPOL_MF_REMOVE = 1 << 7
MPOL_MF_HONOR_VMFOL = 1 << 8
// Flags for get_mempolicy
MPOL_F_NODE = 1 << 0
MPOL_F_ADDR = 1 << 1
MPOL_F_MEMS_ALLOWED = 1 << 2
)
var (
numaInitOnce sync.Once
numaInitErr error
)
func initNuma() {
numaInitOnce.Do(func() {
if C.numa_available() < 0 {
numaInitErr = fmt.Errorf("NUMA is not available")
} else {
// numa_init is not available in newer libnuma versions
// The library is automatically initialized on first use
}
})
}
func detectLinuxTopology(topology *Topology) error {
initNuma()
// First, try to use /sys filesystem for detection
nodes, err := detectNodesFromSys()
if err != nil {
// Fall back to libnuma
return detectFromLibNuma(topology)
}
topology.NumNodes = len(nodes)
for _, nodeID := range nodes {
nodeInfo := &NodeInfo{
ID: NodeID(nodeID),
}
// Get CPUs for this node
cpus, err := getCPUsForNode(nodeID)
if err == nil {
nodeInfo.CPUs = cpus
for _, cpu := range cpus {
topology.CPUToNodeMap[cpu] = NodeID(nodeID)
}
}
// Get memory info for this node
memInfo, err := getMemoryInfoForNode(nodeID)
if err == nil {
nodeInfo.TotalMemory = memInfo.total
nodeInfo.FreeMemory = memInfo.free
}
// Get distance matrix
distances, err := getDistancesForNode(nodeID, len(nodes))
if err == nil {
nodeInfo.DistanceToNode = distances
}
topology.Nodes[NodeID(nodeID)] = nodeInfo
}
return nil
}
func detectNodesFromSys() ([]int, error) {
entries, err := os.ReadDir("/sys/devices/system/node")
if err != nil {
return nil, err
}
var nodes []int
for _, entry := range entries {
if entry.IsDir() && strings.HasPrefix(entry.Name(), "node") {
nodeID, err := strconv.Atoi(entry.Name()[4:])
if err == nil {
nodes = append(nodes, nodeID)
}
}
}
if len(nodes) == 0 {
return nil, fmt.Errorf("no NUMA nodes found")
}
return nodes, nil
}
type memoryInfo struct {
total uint64
free uint64
}
func getMemoryInfoForNode(nodeID int) (*memoryInfo, error) {
file, err := os.Open(fmt.Sprintf("/sys/devices/system/node/node%d/meminfo", nodeID))
if err != nil {
return nil, err
}
defer file.Close()
info := &memoryInfo{}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if strings.Contains(line, "MemTotal:") {
fields := strings.Fields(line)
if len(fields) >= 2 {
val, _ := strconv.ParseUint(fields[1], 10, 64)
info.total = val * 1024 // Convert from KB to bytes
}
} else if strings.Contains(line, "MemFree:") {
fields := strings.Fields(line)
if len(fields) >= 2 {
val, _ := strconv.ParseUint(fields[1], 10, 64)
info.free = val * 1024 // Convert from KB to bytes
}
}
}
return info, scanner.Err()
}
func getCPUsForNode(nodeID int) ([]int, error) {
data, err := os.ReadFile(fmt.Sprintf("/sys/devices/system/node/node%d/cpulist", nodeID))
if err != nil {
return nil, err
}
return parseCPUList(strings.TrimSpace(string(data)))
}
func parseCPUList(list string) ([]int, error) {
var cpus []int
// Handle empty list
if list == "" {
return cpus, nil
}
parts := strings.Split(list, ",")
for _, part := range parts {
if strings.Contains(part, "-") {
// Range like "0-7"
rangeParts := strings.Split(part, "-")
if len(rangeParts) == 2 {
start, _ := strconv.Atoi(rangeParts[0])
end, _ := strconv.Atoi(rangeParts[1])
for i := start; i <= end; i++ {
cpus = append(cpus, i)
}
}
} else {
// Single CPU
cpu, _ := strconv.Atoi(part)
cpus = append(cpus, cpu)
}
}
return cpus, nil
}
func getDistancesForNode(nodeID int, numNodes int) ([]uint32, error) {
file, err := os.Open(fmt.Sprintf("/sys/devices/system/node/node%d/distance", nodeID))
if err != nil {
return nil, err
}
defer file.Close()
data, err := os.ReadFile(fmt.Sprintf("/sys/devices/system/node/node%d/distance", nodeID))
if err != nil {
return nil, err
}
fields := strings.Fields(string(data))
distances := make([]uint32, len(fields))
for i, field := range fields {
val, _ := strconv.ParseUint(field, 10, 32)
distances[i] = uint32(val)
}
return distances, nil
}
func detectFromLibNuma(topology *Topology) error {
initNuma()
if numaInitErr != nil {
return numaInitErr
}
numNodes := int(C.numa_num_configured_nodes())
if numNodes <= 0 {
return fmt.Errorf("no NUMA nodes configured")
}
topology.NumNodes = numNodes
maxNode := int(C.numa_max_node())
for nodeID := 0; nodeID <= maxNode; nodeID++ {
if C.numa_bitmask_isbitset(C.numa_all_nodes_ptr, C.uint(nodeID)) == 0 {
continue
}
nodeInfo := &NodeInfo{
ID: NodeID(nodeID),
}
// Get memory size
totalMem := uint64(C.numa_node_size(C.int(nodeID), nil))
nodeInfo.TotalMemory = totalMem
// Get CPUs (this is approximate with libnuma)
cpuMask := C.numa_allocate_cpumask()
defer C.numa_free_cpumask(cpuMask)
if C.numa_node_to_cpus(C.int(nodeID), cpuMask) == 0 {
// Parse CPU mask
maxCPU := int(C.numa_num_configured_cpus())
for cpu := 0; cpu < maxCPU; cpu++ {
if C.numa_bitmask_isbitset(cpuMask, C.uint(cpu)) != 0 {
nodeInfo.CPUs = append(nodeInfo.CPUs, cpu)
topology.CPUToNodeMap[cpu] = NodeID(nodeID)
}
}
}
topology.Nodes[NodeID(nodeID)] = nodeInfo
}
return nil
}
func getCurrentNodeImpl() (NodeID, error) {
// Use /proc/self/stat to get current CPU
data, err := os.ReadFile("/proc/self/stat")
if err != nil {
return 0, fmt.Errorf("failed to read /proc/self/stat: %v", err)
}
fields := strings.Fields(string(data))
if len(fields) < 39 {
return 0, fmt.Errorf("unexpected /proc/self/stat format")
}
cpu, err := strconv.Atoi(fields[38])
if err != nil {
return 0, fmt.Errorf("failed to parse CPU: %v", err)
}
topology := GetTopology()
node, ok := topology.GetNodeForCPU(cpu)
if !ok {
return 0, fmt.Errorf("CPU %d not found in topology", cpu)
}
return node, nil
}
func setPreferredNodeImpl(node NodeID) (*PreferredNode, error) {
initNuma()
if numaInitErr != nil {
return nil, numaInitErr
}
// Save current nodemask
var oldMode C.int
var oldMask C.ulong
maxNode := C.ulong(2) // We only need 2 bits for now
if ret := C.get_mempolicy(&oldMode, &oldMask, maxNode, nil, 0); ret < 0 {
return nil, fmt.Errorf("get_mempolicy failed: %v", ret)
}
// Set preferred node
var newMask C.ulong = 1 << C.ulong(node)
if ret := C.set_mempolicy(MPOL_PREFERRED, &newMask, maxNode); ret < 0 {
return nil, fmt.Errorf("set_mempolicy failed: %v", ret)
}
return &PreferredNode{nodeID: node}, nil
}
func revertPreferredNodeImpl(p *PreferredNode) error {
// Reset to default policy
if ret := C.set_mempolicy(MPOL_DEFAULT, nil, 0); ret < 0 {
return fmt.Errorf("set_mempolicy failed: %v", ret)
}
return nil
}
func setMemoryPolicyImpl(policy MemoryPolicy, nodes []NodeID) error {
var mode int
switch policy {
case MPDefault:
mode = MPOL_DEFAULT
case MPBind:
mode = MPOL_BIND
case MPPreferred:
mode = MPOL_PREFERRED
case MPInterleave:
mode = MPOL_INTERLEAVE
default:
return fmt.Errorf("unknown memory policy: %d", policy)
}
// Build nodemask
var mask C.ulong
for _, node := range nodes {
mask |= 1 << C.ulong(node)
}
maxNode := C.ulong(2)
for _, node := range nodes {
if C.ulong(node) >= maxNode {
maxNode = C.ulong(node) + 1
}
}
if ret := C.set_mempolicy(C.int(mode), &mask, maxNode); ret < 0 {
return fmt.Errorf("set_mempolicy failed: %v", ret)
}
return nil
}
func allocateOnNodeImpl(size int, node NodeID) ([]byte, error) {
// Use mmap with MAP_PRIVATE and bind to specific node
buf := make([]byte, size)
// Set the memory policy for the allocated region
var mask C.ulong = 1 << C.ulong(node)
ptr := unsafe.Pointer(&buf[0])
if ret := C.mbind(ptr, C.ulong(size), MPOL_BIND, &mask, C.ulong(node)+1, MPOL_MF_STRICT); ret < 0 {
// Fall back to regular allocation
return buf, nil
}
return buf, nil
}
func scheduleOnNodeImpl(cpu int, fn func()) error {
// Simplified implementation - just run the function
// CPU affinity setting requires CGO or unix package
runtime.LockOSThread()
defer runtime.UnlockOSThread()
fn()
return nil
}
func getPreferredNodeForCurrentThreadImpl() NodeID {
var mode C.int
var node C.int
if ret := C.get_mempolicy(&mode, nil, 0, unsafe.Pointer(&node), MPOL_F_NODE); ret < 0 {
return NodeID(0)
}
if mode == MPOL_DEFAULT {
// Get current CPU's node
currentNode, _ := getCurrentNodeImpl()
return currentNode
}
return NodeID(node)
}
// PinThreadToNode pins the current goroutine's OS thread to a specific NUMA node
func PinThreadToNode(node NodeID) error {
initNuma()
if numaInitErr != nil {
return numaInitErr
}
topology := GetTopology()
nodeInfo, ok := topology.GetNode(node)
if !ok {
return fmt.Errorf("NUMA node %d not found", node)
}
if len(nodeInfo.CPUs) == 0 {
return fmt.Errorf("NUMA node %d has no CPUs", node)
}
runtime.LockOSThread()
// Note: CPU affinity setting is simplified for portability
// Full implementation would use sched_setaffinity syscall
return nil
}
// UnpinThread releases the current goroutine's OS thread from NUMA binding
func UnpinThread() {
runtime.UnlockOSThread()
}
// RunOnNode runs a function with the current goroutine pinned to a specific NUMA node
func RunOnNode(node NodeID, fn func()) error {
if err := PinThreadToNode(node); err != nil {
return err
}
defer UnpinThread()
fn()
return nil
}

View File

@@ -0,0 +1,415 @@
//go:build linux && !cgo
// +build linux,!cgo
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package numa
import (
"bufio"
"fmt"
"os"
"runtime"
"strconv"
"strings"
"syscall"
"unsafe"
)
// Syscall numbers for x86_64 Linux
const (
SYS_GETCPU = 309
SYS_SET_MEMPOLICY = 238
SYS_GET_MEMPOLICY = 239
SYS_MBIND = 237
SYS_MIGRATE_PAGES = 238
)
const (
// NUMA memory policies
MPOL_DEFAULT = 0
MPOL_PREFERRED = 1
MPOL_BIND = 2
MPOL_INTERLEAVE = 3
MPOL_LOCAL = 4
// Flags for get_mempolicy
MPOL_F_NODE = 1 << 0
MPOL_F_ADDR = 1 << 1
// Flags for mbind
MPOL_MF_STRICT = 1 << 0
)
//go:noescape
//go:linkname runtime_GetCPU runtime.getcpu
func runtime_GetCPU() uint32
func detectLinuxTopology(topology *Topology) error {
nodes, err := detectNodesFromSys()
if err != nil {
return err
}
topology.NumNodes = len(nodes)
for _, nodeID := range nodes {
nodeInfo := &NodeInfo{
ID: NodeID(nodeID),
}
// Get CPUs for this node
cpus, err := getCPUsForNodeNoCGO(nodeID)
if err == nil {
nodeInfo.CPUs = cpus
for _, cpu := range cpus {
topology.CPUToNodeMap[cpu] = NodeID(nodeID)
}
}
// Get memory info for this node
memInfo, err := getMemoryInfoForNodeNoCGO(nodeID)
if err == nil {
nodeInfo.TotalMemory = memInfo.total
nodeInfo.FreeMemory = memInfo.free
}
// Get distance matrix
distances, err := getDistancesForNodeNoCGO(nodeID, len(nodes))
if err == nil {
nodeInfo.DistanceToNode = distances
}
topology.Nodes[NodeID(nodeID)] = nodeInfo
}
return nil
}
func detectNodesFromSys() ([]int, error) {
entries, err := os.ReadDir("/sys/devices/system/node")
if err != nil {
return nil, err
}
var nodes []int
for _, entry := range entries {
if entry.IsDir() && strings.HasPrefix(entry.Name(), "node") {
nodeID, err := strconv.Atoi(entry.Name()[4:])
if err == nil {
nodes = append(nodes, nodeID)
}
}
}
if len(nodes) == 0 {
return nil, fmt.Errorf("no NUMA nodes found")
}
return nodes, nil
}
type memoryInfo struct {
total uint64
free uint64
}
func getMemoryInfoForNodeNoCGO(nodeID int) (*memoryInfo, error) {
file, err := os.Open(fmt.Sprintf("/sys/devices/system/node/node%d/meminfo", nodeID))
if err != nil {
return nil, err
}
defer file.Close()
info := &memoryInfo{}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if strings.Contains(line, "MemTotal:") {
fields := strings.Fields(line)
if len(fields) >= 2 {
val, _ := strconv.ParseUint(fields[1], 10, 64)
info.total = val * 1024
}
} else if strings.Contains(line, "MemFree:") {
fields := strings.Fields(line)
if len(fields) >= 2 {
val, _ := strconv.ParseUint(fields[1], 10, 64)
info.free = val * 1024
}
}
}
return info, scanner.Err()
}
func getCPUsForNodeNoCGO(nodeID int) ([]int, error) {
data, err := os.ReadFile(fmt.Sprintf("/sys/devices/system/node/node%d/cpulist", nodeID))
if err != nil {
return nil, err
}
return parseCPUListNoCGO(strings.TrimSpace(string(data)))
}
func parseCPUListNoCGO(list string) ([]int, error) {
var cpus []int
if list == "" {
return cpus, nil
}
parts := strings.Split(list, ",")
for _, part := range parts {
if strings.Contains(part, "-") {
rangeParts := strings.Split(part, "-")
if len(rangeParts) == 2 {
start, _ := strconv.Atoi(rangeParts[0])
end, _ := strconv.Atoi(rangeParts[1])
for i := start; i <= end; i++ {
cpus = append(cpus, i)
}
}
} else {
cpu, _ := strconv.Atoi(part)
cpus = append(cpus, cpu)
}
}
return cpus, nil
}
func getDistancesForNodeNoCGO(nodeID int, numNodes int) ([]uint32, error) {
data, err := os.ReadFile(fmt.Sprintf("/sys/devices/system/node/node%d/distance", nodeID))
if err != nil {
return nil, err
}
fields := strings.Fields(string(data))
distances := make([]uint32, len(fields))
for i, field := range fields {
val, _ := strconv.ParseUint(field, 10, 32)
distances[i] = uint32(val)
}
return distances, nil
}
func getCurrentNodeImpl() (NodeID, error) {
var cpu, node uint32
// Use getcpu syscall
r1, _, errno := syscall.Syscall(SYS_GETCPU,
uintptr(unsafe.Pointer(&cpu)),
uintptr(unsafe.Pointer(&node)),
0)
if errno != 0 {
// Fallback: try to determine from CPU
return getNodeFromSchedGetCPU()
}
_ = r1 // suppress unused warning
return NodeID(node), nil
}
func getNodeFromSchedGetCPU() (NodeID, error) {
// Get current CPU
cpu := runtime.GOMAXPROCS(0)
// Look up in topology
topology := GetTopology()
node, ok := topology.GetNodeForCPU(cpu)
if !ok {
return 0, fmt.Errorf("cannot determine NUMA node for CPU %d", cpu)
}
return node, nil
}
func setPreferredNodeImpl(node NodeID) (*PreferredNode, error) {
mask := uint64(1) << uint64(node)
maxNode := uint64(node) + 1
_, _, errno := syscall.Syscall6(SYS_SET_MEMPOLICY,
uintptr(MPOL_PREFERRED),
uintptr(unsafe.Pointer(&mask)),
uintptr(maxNode),
0, 0, 0)
if errno != 0 {
return nil, fmt.Errorf("set_mempolicy failed: %v", errno)
}
return &PreferredNode{nodeID: node}, nil
}
func revertPreferredNodeImpl(p *PreferredNode) error {
_, _, errno := syscall.Syscall(SYS_SET_MEMPOLICY,
uintptr(MPOL_DEFAULT),
0, 0)
if errno != 0 {
return fmt.Errorf("set_mempolicy failed: %v", errno)
}
return nil
}
func setMemoryPolicyImpl(policy MemoryPolicy, nodes []NodeID) error {
var mode int
switch policy {
case MPDefault:
mode = MPOL_DEFAULT
case MPBind:
mode = MPOL_BIND
case MPPreferred:
mode = MPOL_PREFERRED
case MPInterleave:
mode = MPOL_INTERLEAVE
default:
return fmt.Errorf("unknown memory policy: %d", policy)
}
var mask uint64
for _, node := range nodes {
mask |= 1 << uint64(node)
}
maxNode := uint64(0)
for _, node := range nodes {
if uint64(node) >= maxNode {
maxNode = uint64(node) + 1
}
}
_, _, errno := syscall.Syscall6(SYS_SET_MEMPOLICY,
uintptr(mode),
uintptr(unsafe.Pointer(&mask)),
uintptr(maxNode),
0, 0, 0)
if errno != 0 {
return fmt.Errorf("set_mempolicy failed: %v", errno)
}
return nil
}
func allocateOnNodeImpl(size int, node NodeID) ([]byte, error) {
buf := make([]byte, size)
// Try to use mbind to bind memory to node
mask := uint64(1) << uint64(node)
maxNode := uint64(node) + 1
_, _, errno := syscall.Syscall6(SYS_MBIND,
uintptr(unsafe.Pointer(&buf[0])),
uintptr(size),
uintptr(MPOL_BIND),
uintptr(unsafe.Pointer(&mask)),
uintptr(maxNode),
uintptr(MPOL_MF_STRICT))
if errno != 0 {
// Fall back to regular allocation
return buf, nil
}
return buf, nil
}
func scheduleOnNodeImpl(cpu int, fn func()) error {
var mask syscall.CPUSet
mask.Set(cpu)
runtime.LockOSThread()
defer runtime.UnlockOSThread()
if err := syscall.SchedSetaffinity(0, &mask); err != nil {
return fmt.Errorf("sched_setaffinity failed: %v", err)
}
fn()
return nil
}
func getPreferredNodeForCurrentThreadImpl() NodeID {
var mode int
var node uint32
_, _, errno := syscall.Syscall6(SYS_GET_MEMPOLICY,
uintptr(unsafe.Pointer(&mode)),
0, 0,
uintptr(unsafe.Pointer(&node)),
uintptr(MPOL_F_NODE),
0)
if errno != 0 {
node, _ := getCurrentNodeImpl()
return node
}
if mode == MPOL_DEFAULT {
node, _ := getCurrentNodeImpl()
return node
}
return NodeID(node)
}
// PinThreadToNode pins the current goroutine's OS thread to a specific NUMA node
func PinThreadToNode(node NodeID) error {
topology := GetTopology()
nodeInfo, ok := topology.GetNode(node)
if !ok {
return fmt.Errorf("NUMA node %d not found", node)
}
if len(nodeInfo.CPUs) == 0 {
return fmt.Errorf("NUMA node %d has no CPUs", node)
}
runtime.LockOSThread()
var mask syscall.CPUSet
for _, cpu := range nodeInfo.CPUs {
mask.Set(cpu)
}
if err := syscall.SchedSetaffinity(0, &mask); err != nil {
runtime.UnlockOSThread()
return fmt.Errorf("sched_setaffinity failed: %v", err)
}
return nil
}
// UnpinThread releases the current goroutine's OS thread from NUMA binding
func UnpinThread() {
runtime.UnlockOSThread()
}
// RunOnNode runs a function with the current goroutine pinned to a specific NUMA node
func RunOnNode(node NodeID, fn func()) error {
if err := PinThreadToNode(node); err != nil {
return err
}
defer UnpinThread()
fn()
return nil
}

View File

@@ -0,0 +1,94 @@
//go:build !linux
// +build !linux
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package numa
import (
"fmt"
"runtime"
)
func detectLinuxTopology(topology *Topology) error {
return fmt.Errorf("NUMA not supported on this platform")
}
func getCurrentNodeImpl() (NodeID, error) {
return 0, fmt.Errorf("NUMA not supported on this platform")
}
func setPreferredNodeImpl(node NodeID) (*PreferredNode, error) {
return nil, fmt.Errorf("NUMA not supported on this platform")
}
func revertPreferredNodeImpl(p *PreferredNode) error {
return fmt.Errorf("NUMA not supported on this platform")
}
func setMemoryPolicyImpl(policy MemoryPolicy, nodes []NodeID) error {
return fmt.Errorf("NUMA not supported on this platform")
}
func allocateOnNodeImpl(size int, node NodeID) ([]byte, error) {
return make([]byte, size), nil
}
func scheduleOnNodeImpl(cpu int, fn func()) error {
fn()
return nil
}
func getPreferredNodeForCurrentThreadImpl() NodeID {
return 0
}
// PinThreadToNode pins the current goroutine's OS thread to a specific NUMA node
// Stub implementation - does nothing on non-Linux platforms
func PinThreadToNode(node NodeID) error {
return nil
}
// UnpinThread releases the current goroutine's OS thread from NUMA binding
// Stub implementation - does nothing on non-Linux platforms
func UnpinThread() {}
// RunOnNode runs a function with the current goroutine pinned to a specific NUMA node
// Stub implementation - just runs the function on non-Linux platforms
func RunOnNode(node NodeID, fn func()) error {
fn()
return nil
}
// createSingleNodeTopology creates a single-node topology for non-NUMA systems
func createSingleNodeTopology(topology *Topology) {
numCPU := runtime.NumCPU()
cpus := make([]int, numCPU)
for i := 0; i < numCPU; i++ {
cpus[i] = i
topology.CPUToNodeMap[i] = 0
}
topology.NumNodes = 1
topology.Nodes[0] = &NodeInfo{
ID: 0,
CPUs: cpus,
TotalMemory: 0,
FreeMemory: 0,
DistanceToNode: []uint32{10},
}
}

105
pkg/util/numa/numa_test.go Normal file
View File

@@ -0,0 +1,105 @@
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package numa
import (
"testing"
)
func TestTopologyDetection(t *testing.T) {
topology := GetTopology()
if topology == nil {
t.Fatal("GetTopology returned nil")
}
if topology.NumNodes < 1 {
t.Errorf("Expected at least 1 NUMA node, got %d", topology.NumNodes)
}
if len(topology.Nodes) == 0 {
t.Error("No NUMA nodes found in topology")
}
}
func TestBufferPool(t *testing.T) {
pool := NewNUMABufferPool(&BufferPoolConfig{
BufferSize: 4096,
PerNodePoolSize: 10,
EnableNUMA: false, // Disable NUMA for test
})
if pool == nil {
t.Fatal("NewNUMABufferPool returned nil")
}
// Test Get/Put
buf := pool.Get()
if len(buf) != 4096 {
t.Errorf("Expected buffer size 4096, got %d", len(buf))
}
pool.Put(buf)
// Test stats
stats := pool.Stats()
if stats.Gets == 0 {
t.Error("Expected Gets > 0")
}
if stats.Puts == 0 {
t.Error("Expected Puts > 0")
}
}
func TestBufferPoolMultipleSizes(t *testing.T) {
pool := NewNUMABufferPool(&BufferPoolConfig{
BufferSize: 8192,
PerNodePoolSize: 5,
EnableNUMA: false,
})
// Get multiple buffers
var buffers [][]byte
for i := 0; i < 10; i++ {
buf := pool.Get()
buffers = append(buffers, buf)
}
// Put all back
for _, buf := range buffers {
pool.Put(buf)
}
stats := pool.Stats()
if stats.Gets != 10 {
t.Errorf("Expected 10 gets, got %d", stats.Gets)
}
if stats.Puts != 10 {
t.Errorf("Expected 10 puts, got %d", stats.Puts)
}
}
func TestAvailable(t *testing.T) {
// Just verify the function doesn't panic
_ = Available()
}
func TestNumNodes(t *testing.T) {
n := NumNodes()
if n < 1 {
t.Errorf("Expected at least 1 node, got %d", n)
}
}

424
pkg/util/numa/pool.go Normal file
View File

@@ -0,0 +1,424 @@
/*
Copyright 2024 The GoStor Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package numa
import (
"context"
"sync"
"sync/atomic"
)
// BufferPoolConfig configures NUMA-aware buffer pools
type BufferPoolConfig struct {
// BufferSize is the size of each buffer
BufferSize int
// PerNodePoolSize is the number of buffers to preallocate per node
PerNodePoolSize int
// EnableNUMA enables NUMA-aware allocation
EnableNUMA bool
}
// DefaultBufferPoolConfig returns a default configuration
func DefaultBufferPoolConfig() *BufferPoolConfig {
return &BufferPoolConfig{
BufferSize: 256 * 1024, // 256KB buffers for I/O
PerNodePoolSize: 1024, // 1024 buffers per node
EnableNUMA: true,
}
}
// NUMABufferPool provides NUMA-aware buffer pooling
type NUMABufferPool struct {
config *BufferPoolConfig
topology *Topology
nodePools map[NodeID]*sync.Pool
stats *PoolStats
// Fallback pool for when NUMA is not available or disabled
fallbackPool *sync.Pool
mu sync.RWMutex
}
// PoolStats tracks buffer pool statistics
type PoolStats struct {
Gets uint64
Puts uint64
Misses uint64
NodeLocalHit uint64
NUMAHit uint64
}
// NewNUMABufferPool creates a new NUMA-aware buffer pool
func NewNUMABufferPool(config *BufferPoolConfig) *NUMABufferPool {
if config == nil {
config = DefaultBufferPoolConfig()
}
pool := &NUMABufferPool{
config: config,
topology: GetTopology(),
nodePools: make(map[NodeID]*sync.Pool),
stats: &PoolStats{},
}
// Initialize fallback pool
pool.fallbackPool = &sync.Pool{
New: func() interface{} {
atomic.AddUint64(&pool.stats.Misses, 1)
return make([]byte, config.BufferSize)
},
}
// Initialize NUMA pools if enabled and available
if config.EnableNUMA && Available() && pool.topology.NumNodes > 1 {
for nodeID := range pool.topology.Nodes {
pool.nodePools[nodeID] = pool.createNodePool(nodeID)
}
}
return pool
}
// createNodePool creates a buffer pool for a specific NUMA node
func (p *NUMABufferPool) createNodePool(node NodeID) *sync.Pool {
return &sync.Pool{
New: func() interface{} {
atomic.AddUint64(&p.stats.Misses, 1)
// Try NUMA-local allocation first
if p.config.EnableNUMA && Available() {
buf, err := AllocateOnNode(p.config.BufferSize, node)
if err == nil {
atomic.AddUint64(&p.stats.NUMAHit, 1)
return buf
}
}
// Fall back to regular allocation
return make([]byte, p.config.BufferSize)
},
}
}
// Get returns a buffer from the pool, preferably from the local NUMA node
func (p *NUMABufferPool) Get() []byte {
atomic.AddUint64(&p.stats.Gets, 1)
// Try to get from the local NUMA node first
if p.config.EnableNUMA && Available() && len(p.nodePools) > 0 {
if node, err := GetCurrentNode(); err == nil {
if nodePool, ok := p.nodePools[node]; ok {
buf := nodePool.Get().([]byte)
atomic.AddUint64(&p.stats.NodeLocalHit, 1)
return buf[:p.config.BufferSize]
}
}
}
// Fall back to the fallback pool
return p.fallbackPool.Get().([]byte)[:p.config.BufferSize]
}
// Put returns a buffer to the pool, preferably to its local NUMA node
func (p *NUMABufferPool) Put(buf []byte) {
if buf == nil {
return
}
atomic.AddUint64(&p.stats.Puts, 1)
// Resize buffer to full size before returning to pool
if cap(buf) < p.config.BufferSize {
// Buffer is too small, discard it
return
}
buf = buf[:p.config.BufferSize]
// Try to return to the local NUMA node pool
if p.config.EnableNUMA && Available() && len(p.nodePools) > 0 {
if node, err := GetCurrentNode(); err == nil {
if nodePool, ok := p.nodePools[node]; ok {
nodePool.Put(buf)
return
}
}
}
// Fall back to the fallback pool
p.fallbackPool.Put(buf)
}
// GetForNode returns a buffer from a specific NUMA node's pool
func (p *NUMABufferPool) GetForNode(node NodeID) []byte {
atomic.AddUint64(&p.stats.Gets, 1)
if nodePool, ok := p.nodePools[node]; ok {
return nodePool.Get().([]byte)[:p.config.BufferSize]
}
return p.fallbackPool.Get().([]byte)[:p.config.BufferSize]
}
// PutForNode returns a buffer to a specific NUMA node's pool
func (p *NUMABufferPool) PutForNode(node NodeID, buf []byte) {
if buf == nil {
return
}
atomic.AddUint64(&p.stats.Puts, 1)
if cap(buf) < p.config.BufferSize {
return
}
buf = buf[:p.config.BufferSize]
if nodePool, ok := p.nodePools[node]; ok {
nodePool.Put(buf)
return
}
p.fallbackPool.Put(buf)
}
// Stats returns current pool statistics
func (p *NUMABufferPool) Stats() PoolStats {
return PoolStats{
Gets: atomic.LoadUint64(&p.stats.Gets),
Puts: atomic.LoadUint64(&p.stats.Puts),
Misses: atomic.LoadUint64(&p.stats.Misses),
NodeLocalHit: atomic.LoadUint64(&p.stats.NodeLocalHit),
NUMAHit: atomic.LoadUint64(&p.stats.NUMAHit),
}
}
// GetConfig returns the pool configuration
func (p *NUMABufferPool) GetConfig() *BufferPoolConfig {
return p.config
}
// Close releases all resources associated with the pool
func (p *NUMABufferPool) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
// Clear all pools
p.nodePools = make(map[NodeID]*sync.Pool)
p.fallbackPool = nil
return nil
}
// SizeAwarePool is a buffer pool that can handle multiple buffer sizes
type SizeAwarePool struct {
pools map[int]*NUMABufferPool
mu sync.RWMutex
}
// NewSizeAwarePool creates a new size-aware buffer pool
func NewSizeAwarePool(sizes []int, enableNUMA bool) *SizeAwarePool {
sap := &SizeAwarePool{
pools: make(map[int]*NUMABufferPool),
}
for _, size := range sizes {
sap.pools[size] = NewNUMABufferPool(&BufferPoolConfig{
BufferSize: size,
PerNodePoolSize: 1024,
EnableNUMA: enableNUMA,
})
}
return sap
}
// Get returns a buffer of the specified size
func (sap *SizeAwarePool) Get(size int) []byte {
sap.mu.RLock()
pool, ok := sap.pools[size]
sap.mu.RUnlock()
if ok {
return pool.Get()
}
// No pool for this size, allocate directly
return make([]byte, size)
}
// Put returns a buffer to the appropriate pool
func (sap *SizeAwarePool) Put(buf []byte) {
if buf == nil {
return
}
size := cap(buf)
sap.mu.RLock()
pool, ok := sap.pools[size]
sap.mu.RUnlock()
if ok {
pool.Put(buf)
}
// If no pool for this size, let GC handle it
}
// PinningAllocator allocates buffers while the goroutine is pinned to a NUMA node
type PinningAllocator struct {
pool *NUMABufferPool
}
// NewPinningAllocator creates a new pinning allocator
func NewPinningAllocator(pool *NUMABufferPool) *PinningAllocator {
return &PinningAllocator{pool: pool}
}
// Allocate allocates a buffer while pinned to the current NUMA node
func (pa *PinningAllocator) Allocate() []byte {
return pa.pool.Get()
}
// AllocateOnNode allocates a buffer while pinned to a specific NUMA node
func (pa *PinningAllocator) AllocateOnNode(node NodeID) ([]byte, error) {
var buf []byte
err := RunOnNode(node, func() {
buf = pa.pool.GetForNode(node)
})
return buf, err
}
// Global pools for common buffer sizes
var (
globalPools map[int]*NUMABufferPool
globalPoolsOnce sync.Once
globalPoolsMu sync.RWMutex
)
// InitGlobalPools initializes global buffer pools
func InitGlobalPools(sizes []int, enableNUMA bool) {
globalPoolsOnce.Do(func() {
globalPools = make(map[int]*NUMABufferPool)
for _, size := range sizes {
globalPools[size] = NewNUMABufferPool(&BufferPoolConfig{
BufferSize: size,
PerNodePoolSize: 1024,
EnableNUMA: enableNUMA,
})
}
})
}
// GetBuffer gets a buffer from the global pool
func GetBuffer(size int) []byte {
globalPoolsMu.RLock()
pool, ok := globalPools[size]
globalPoolsMu.RUnlock()
if ok {
return pool.Get()
}
return make([]byte, size)
}
// PutBuffer returns a buffer to the global pool
func PutBuffer(buf []byte) {
if buf == nil {
return
}
size := cap(buf)
globalPoolsMu.RLock()
pool, ok := globalPools[size]
globalPoolsMu.RUnlock()
if ok {
pool.Put(buf)
}
}
// WorkerPool is a pool of workers that are pinned to specific NUMA nodes
type WorkerPool struct {
size int
numaNode NodeID
workQueue chan func()
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewWorkerPool creates a new NUMA-aware worker pool
func NewWorkerPool(size int, node NodeID) *WorkerPool {
ctx, cancel := context.WithCancel(context.Background())
wp := &WorkerPool{
size: size,
numaNode: node,
workQueue: make(chan func(), size*2),
ctx: ctx,
cancel: cancel,
}
// Start workers
for i := 0; i < size; i++ {
wp.wg.Add(1)
go wp.worker()
}
return wp
}
func (wp *WorkerPool) worker() {
defer wp.wg.Done()
// Pin to NUMA node
if Available() {
PinThreadToNode(wp.numaNode)
defer UnpinThread()
}
for {
select {
case work := <-wp.workQueue:
if work != nil {
work()
}
case <-wp.ctx.Done():
return
}
}
}
// Submit submits work to the worker pool
func (wp *WorkerPool) Submit(work func()) bool {
select {
case wp.workQueue <- work:
return true
case <-wp.ctx.Done():
return false
default:
return false
}
}
// Stop stops the worker pool
func (wp *WorkerPool) Stop() {
wp.cancel()
wp.wg.Wait()
close(wp.workQueue)
}

View File

@@ -73,32 +73,38 @@ func MarshalKVText(kv []KeyValue) []byte {
return data
}
// MarshalUint16 returns big-endian encoding of i as a new 2-byte slice.
// Deprecated: Use MarshalUint16To or binary.BigEndian.PutUint16 for zero-allocation.
func MarshalUint16(i uint16) []byte {
var data []byte
for j := 8; j >= 0; j -= 8 {
b := byte(i >> uint16(j))
data = append(data, b)
}
return data
var data [2]byte
binary.BigEndian.PutUint16(data[:], i)
return data[:]
}
// MarshalUint32 returns big-endian encoding of i as a new 4-byte slice.
// Deprecated: Use MarshalUint32To or binary.BigEndian.PutUint32 for zero-allocation.
func MarshalUint32(i uint32) []byte {
var data []byte
for j := 24; j >= 0; j -= 8 {
b := byte(i >> uint32(j))
data = append(data, b)
}
return data
var data [4]byte
binary.BigEndian.PutUint32(data[:], i)
return data[:]
}
// MarshalUint32To writes big-endian encoding of i into buf, which must be at least 4 bytes.
// This is a zero-allocation alternative to MarshalUint32.
func MarshalUint32To(buf []byte, i uint32) {
binary.BigEndian.PutUint32(buf, i)
}
func MarshalUint64(v uint64) []byte {
var data = [8]byte{}
var i = 0
for j := 56; j >= 0; j -= 8 {
data[i] = byte(v >> uint32(j))
i++
}
return data[0:8]
var data [8]byte
binary.BigEndian.PutUint64(data[:], v)
return data[:]
}
// MarshalUint64To writes big-endian encoding of v into buf, which must be at least 8 bytes.
// This is a zero-allocation alternative for partial writes.
func MarshalUint64To(buf []byte, v uint64) {
binary.BigEndian.PutUint64(buf, v)
}
func StringToByte(str string, align int, maxlength int) []byte {