diff --git a/spmd/spmd.go b/spmd/spmd.go new file mode 100644 index 0000000..17192ec --- /dev/null +++ b/spmd/spmd.go @@ -0,0 +1,147 @@ +// Package spmd contains useful primitives for SPMD programs. +// +// Usage: +// +// func main() { +// // Create N execution lanes (<= 0 for GOMAXPROCS lanes), +// // and run 'Compute' across N lanes. +// spmd.Run(-1, Compute) +// } +// +// func Compute(lane spmd.Lane) { +// log.Printf("Lane %d/%d is executing", lane.Index, lane.Count) +// +// // Execute this code on a lane locked to the main thread (aka lane.Index == 0) +// // One lane will always be locked to the main thread +// if lane.Main() { +// data, err := os.ReadFile(...) +// if err != nil { +// panic(err) +// } +// +// // Send data to all lanes ("DATA" can be any value) +// lane.Store("DATA", string(data)) +// } +// +// // Wait until all lanes are at this point +// lane.Sync() +// +// // Load stored data +// data := lane.Load("DATA").(string) +// +// // Get lane-specific access range for data +// lo, hi := lane.Range(len(data)) +// for i := lo; i < hi; i++ { +// // ... +// } +// } +package spmd + +import ( + "runtime" + "sync" + "sync/atomic" + + "git.brut.systems/judah/xx/osthread" +) + +// Run will start executing the given function across N execution lanes, +// blocking until they have all finished executing. +// +// If nLanes is <= 0, GOMAXPROCS will be used. +// +// Run must be called from the program's main function. +func Run(nLanes int, fn func(lane Lane)) { + if nLanes <= 0 { + nLanes = runtime.GOMAXPROCS(0) + } + + osthread.Start(func() { + s := new(state) + s.cond = sync.NewCond(&s.mtx) + s.total = uint64(nLanes) + + var wg sync.WaitGroup + for i := range s.total { + if i == 0 { // Lane 0 is always on the main thread + wg.Add(1) + osthread.Go(func() { + fn(Lane{state: s, Index: uint32(i), Count: uint32(s.total)}) + wg.Done() + }) + } else { // Everyone else gets scheduled like usual + wg.Go(func() { + fn(Lane{state: s, Index: uint32(i), Count: uint32(s.total)}) + }) + } + } + + wg.Wait() + }) +} + +type state struct { + mtx sync.Mutex + cond *sync.Cond + waiting atomic.Uint64 + total uint64 + userdata sync.Map +} + +type Lane struct { + state *state + Index uint32 + Count uint32 +} + +// Main returns if the lane is locked to the main thread. +func (l Lane) Main() bool { + return l.Index == 0 +} + +// Sync pauses the current lane until all lanes are at the same sync point. +func (l Lane) Sync() { + l.state.mtx.Lock() + defer l.state.mtx.Unlock() + + if l.state.waiting.Add(1) >= l.state.total { + l.state.waiting.Store(0) + l.state.cond.Broadcast() + return + } + + l.state.cond.Wait() +} + +// Store sends 'value' to all lanes. +// Store can be called concurrently. +func (l Lane) Store(key, value any) { + l.state.userdata.Store(key, value) +} + +// Load fetches a named value, returning nil if it does not exist. +// Load can be called concurrently. +func (l Lane) Load(key any) any { + v, ok := l.state.userdata.Load(key) + if !ok { + return nil + } + + return v +} + +// Range returns a lane's data range for the given length. +func (l Lane) Range(length int) (lo, hi uint) { + size := uint(length) / uint(l.state.total) + rem := uint(length) % uint(l.state.total) + + if uint(l.Index) < rem { + lo = uint(l.Index) * (size + 1) + hi = lo + size + 1 + } else { + lo = uint(l.Index)*size + rem + hi = lo + size + } + + return +}