// Package spmd contains useful primitives for SPMD programs.
//
// Usage:
//
//	func main() {
//		// Create N execution lanes (<= 0 for GOMAXPROCS lanes),
//		// and run 'Compute' across N lanes.
//		spmd.Run(-1, Compute)
//	}
//
//	func Compute(lane spmd.Lane) {
//		log.Printf("Lane %d/%d is executing", lane.Index, lane.Count)
//
//		// Execute this code on the lane locked to the main thread (lane.Index == 0).
//		// One lane will always be locked to the main thread.
//		if lane.Main() {
//			data, err := os.ReadFile(...)
//			if err != nil {
//				panic(err)
//			}
//
//			// Send data to all lanes ("DATA" can be any value).
//			lane.Store("DATA", string(data))
//		}
//
//		// Wait until all lanes have reached this point.
//		lane.Sync()
//
//		// Load the stored data.
//		data := lane.Load("DATA").(string)
//
//		// Get this lane's access range for the data.
//		lo, hi := lane.Range(len(data))
//		for i := lo; i < hi; i++ {
//			// ...
//		}
//	}
package spmd

import (
	"runtime"
	"sync"
	"sync/atomic"

	"git.brut.systems/judah/xx/osthread"
)

// Run executes the given function across nLanes execution lanes, blocking
// until all of them have finished.
//
// If nLanes is <= 0, GOMAXPROCS is used.
//
// Run must be called from the program's main function.
func Run(nLanes int, fn func(lane Lane)) {
	if nLanes <= 0 {
		nLanes = runtime.GOMAXPROCS(0)
	}

	osthread.Start(func() {
		s := new(state)
		s.cond = sync.NewCond(&s.mtx)
		s.total = uint64(nLanes)

		var wg sync.WaitGroup
		for i := range s.total {
			if i == 0 { // Lane 0 is always on the main thread
				wg.Add(1)
				osthread.Go(func() {
					fn(Lane{state: s, Index: uint32(i), Count: uint32(s.total)})
					wg.Done()
				})
			} else { // Everyone else gets scheduled like usual
				wg.Go(func() {
					fn(Lane{state: s, Index: uint32(i), Count: uint32(s.total)})
				})
			}
		}

		wg.Wait()
	})
}

// state holds the synchronization primitives and shared data for one Run call.
type state struct {
	mtx      sync.Mutex
	cond     *sync.Cond
	waiting  atomic.Uint64
	total    uint64
	userdata sync.Map
}

// A Lane is one of the execution lanes created by Run.
type Lane struct {
	state *state

	// Index is this lane's index, in [0, Count).
	Index uint32
	// Count is the total number of lanes.
	Count uint32
}

// Main reports whether the lane is locked to the main thread.
func (l Lane) Main() bool {
	return l.Index == 0
}

// Sync pauses the current lane until all lanes are at the same sync point.
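//
// A typical use is as a barrier between a write phase and a read phase, as in
// this sketch, where results is a slice shared by all lanes (for example,
// published beforehand with Store) and compute is hypothetical:
//
//	results[lane.Index] = compute(lane)
//	lane.Sync()
//	// Every lane can now safely read every entry of results.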
func (l Lane) Sync() {
	l.state.mtx.Lock()
	defer l.state.mtx.Unlock()

	// The last lane to arrive resets the counter and wakes all waiting lanes.
	if l.state.waiting.Add(1) >= l.state.total {
		l.state.waiting.Store(0)
		l.state.cond.Broadcast()
		return
	}

	// Go's Cond.Wait only returns after a Broadcast or Signal, so no
	// predicate loop is needed here.
	l.state.cond.Wait()
}

// Store makes 'value' available to all lanes under 'key'.
// Store can be called concurrently.
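//
// A value stored on one lane is visible to the other lanes once they have
// passed a subsequent Sync, as in this sketch (loadConfig and Config are
// hypothetical):
//
//	if lane.Main() {
//		lane.Store("cfg", loadConfig())
//	}
//	lane.Sync()
//	cfg := lane.Load("cfg").(Config)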
func (l Lane) Store(key, value any) {
	l.state.userdata.Store(key, value)
}

// Load fetches a named value, returning nil if it does not exist.
// Load can be called concurrently.
func (l Lane) Load(key any) any {
	v, ok := l.state.userdata.Load(key)
	if !ok {
		return nil
	}

	return v
}

// Range returns this lane's half-open index range [lo, hi) for data of the
// given length. The remainder, if any, is spread across the lowest-indexed
// lanes.
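//
// For example, with 3 lanes and length 10, the lanes cover [0, 4), [4, 7),
// and [7, 10), so every element is assigned to exactly one lane.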
func (l Lane) Range(length int) (lo, hi uint) {
	size := uint(length) / uint(l.state.total)
	rem := uint(length) % uint(l.state.total)

	// The first 'rem' lanes each take one extra element.
	if uint(l.Index) < rem {
		lo = uint(l.Index) * (size + 1)
		hi = lo + size + 1
	} else {
		lo = uint(l.Index)*size + rem
		hi = lo + size
	}

	return
}