xx/spmd/spmd.go
2025-12-06 22:33:23 -07:00

149 lines
3.1 KiB
Go

// Package spmd contains useful primitives for SPMD (single program, multiple
// data) programs.
//
// Usage:
//
//	func main() {
//		// Create N execution lanes (<= 0 for GOMAXPROCS lanes),
//		// and run 'Compute' across N lanes.
//		spmd.Run(-1, Compute)
//	}
//
//	func Compute(lane spmd.Lane) {
//		log.Printf("Lane %d/%d is executing", lane.Index, lane.Count)
//
//		// Execute this code on the lane locked to the main thread (i.e. lane.Index == 0).
//		// One lane will always be locked to the main thread.
//		if lane.Main() {
//			data, err := os.ReadFile(...)
//			if err != nil {
//				panic(err)
//			}
//
//			// Send data to all lanes ("DATA" can be any value)
//			lane.Store("DATA", string(data))
//		}
//
//		// Wait until all lanes are at this point
//		lane.Sync()
//
//		// Load stored data
//		data := lane.Load("DATA").(string)
//
//		// Get lane-specific access range for data
//		lo, hi := lane.Range(len(data))
//		for i := lo; i < hi; i++ {
//			// ...
//		}
//	}
package spmd
import (
"runtime"
"sync"
"sync/atomic"
"git.brut.systems/judah/xx/osthread"
)
// Run starts fn on nLanes execution lanes and blocks until all of them have
// returned.
//
// If nLanes is <= 0, runtime.GOMAXPROCS(0) lanes are used.
//
// Lane 0 is dispatched through osthread.Go so that it runs on the main
// thread (see Lane.Main); every other lane runs on an ordinary goroutine.
// All lanes share one state value, so Sync, Store, and Load coordinate
// across the whole set of lanes started by this call.
//
// Run must be called from the program's main function.
func Run(nLanes int, fn func(lane Lane)) {
	if nLanes <= 0 {
		nLanes = runtime.GOMAXPROCS(0)
	}

	osthread.Start(func() {
		s := new(state)
		s.cond = sync.NewCond(&s.mtx)
		s.total = uint64(nLanes)

		var wg sync.WaitGroup
		for i := range s.total {
			lane := Lane{state: s, Index: uint32(i), Count: uint32(s.total)}

			if i != 0 {
				// Everyone else gets scheduled like usual.
				wg.Go(func() { fn(lane) })
				continue
			}

			// Lane 0 is always on the main thread.
			wg.Add(1)
			osthread.Go(func() {
				// Deferred (as wg.Go does for the other lanes) so the
				// WaitGroup is released even if fn does not return
				// normally, e.g. via runtime.Goexit; otherwise wg.Wait
				// below would block forever.
				defer wg.Done()
				fn(lane)
			})
		}
		wg.Wait()
	})
}
// state is the data shared by every lane started by a single Run call.
type state struct {
	// mtx serializes the barrier logic in Lane.Sync and is the Locker behind cond.
	mtx sync.Mutex
	// cond parks lanes in Lane.Sync until the last lane arrives and broadcasts.
	cond *sync.Cond
	// waiting counts lanes that have reached the current Sync barrier; the
	// last arriving lane resets it to 0.
	waiting atomic.Uint64
	// total is the number of lanes Run started.
	total uint64
	// userdata backs Lane.Store and Lane.Load.
	userdata sync.Map
}
// Lane is a single execution lane created by Run. Each invocation of the
// function passed to Run receives its own Lane value.
type Lane struct {
	state *state

	// Index identifies this lane within [0, Count); index 0 is the lane
	// Run places on the main thread. Count is the number of lanes Run started.
	Index, Count uint32
}
// Main reports whether l is the lane locked to the main thread. Run always
// assigns that lane index 0.
func (l Lane) Main() bool {
	const mainLaneIndex uint32 = 0
	return l.Index == mainLaneIndex
}
// Sync is a barrier: it blocks the calling lane until every lane started by
// Run has called Sync, then releases them all. The arrival counter is reset
// on release, so the barrier can be reused; every lane must therefore call
// Sync the same number of times, or the remaining lanes wait forever.
func (l Lane) Sync() {
	s := l.state
	s.mtx.Lock()
	defer s.mtx.Unlock()

	if arrived := s.waiting.Add(1); arrived < s.total {
		// Not everyone is here yet: park until the last lane broadcasts.
		s.cond.Wait()
		return
	}

	// This is the last lane to arrive: reset for the next barrier and wake
	// everyone who is parked.
	s.waiting.Store(0)
	s.cond.Broadcast()
}
// Store records value under key in storage shared by every lane of the
// same Run call, so any lane can retrieve it with Load. Storing to a key
// that already exists replaces its value.
//
// Store is safe to call concurrently.
func (l Lane) Store(key, value any) {
	shared := &l.state.userdata
	shared.Store(key, value)
}
// Load returns the value stored under key (via Store on any lane of the same
// Run call), or nil if nothing has been stored under that key.
//
// Load is safe to call concurrently.
func (l Lane) Load(key any) any {
	if value, found := l.state.userdata.Load(key); found {
		return value
	}
	return nil
}
// Range returns the half-open interval [lo, hi) of indices this lane should
// process out of length items.
//
// Items are split as evenly as possible across all lanes: writing
// length = size*total + rem, the first rem lanes receive size+1 items and
// the rest receive size. The ranges of all lanes are disjoint and together
// cover [0, length). A length <= 0 yields the empty range [0, 0) for every
// lane, instead of wrapping a negative length around to a huge uint.
func (l Lane) Range(length int) (lo, hi uint) {
	if length <= 0 {
		return 0, 0
	}

	n := uint(length)
	lanes := uint(l.state.total)
	idx := uint(l.Index)
	size, rem := n/lanes, n%lanes

	if idx < rem {
		// Lanes before rem each take one extra item.
		lo = idx * (size + 1)
		return lo, lo + size + 1
	}
	// Later lanes are offset by the rem extra items handed out above.
	lo = idx*size + rem
	return lo, lo + size
}