// Package spmd contains useful primitives for SPMD programs. // // Usage: // // func main() { // // Create N execution lanes (<= 0 for GOMAXPROCS lanes), // // and run 'Compute' across N lanes. // spmd.Run(-1, Compute) // } // // func Compute(lane spmd.Lane) { // log.Printf("Lane %d/%d is executing", lane.Index, lane.Count) // // // Execute this code on a lane locked to the main thread (aka lane.Index == 0) // // One lane will always be locked to the main thread // if lane.Main() { // data, err := os.ReadFile(...) // if err != nil { // panic(err) // } // // // Send data to all lanes ("DATA" can be any value) // lane.Store("DATA", string(data)) // } // // // Wait until all lanes are at this point // lane.Sync() // // // Load stored data // data := lane.Load("DATA").(string) // // // Get lane-specific access range for data // lo, hi := lane.Range(len(data)) // for i := lo; i < hi; i++ { // // ... // } // } package spmd import ( "runtime" "sync" "sync/atomic" "git.brut.systems/judah/xx/osthread" ) // Run will start executing the given function across N execution lanes, // blocking until they have all finished executing. // // If nLanes is <= 0, GOMAXPROCS will be used. // // Run must be called from the program's main function. func Run(nLanes int, fn func(lane Lane)) { if nLanes <= 0 { nLanes = runtime.GOMAXPROCS(0) } osthread.Start(func() { s := new(state) s.cond = sync.NewCond(&s.mtx) s.total = uint64(nLanes) var wg sync.WaitGroup for i := range s.total { if i == 0 { // Lane 0 is always on the main thread wg.Add(1) osthread.Go(func() { fn(Lane{state: s, Index: uint32(i), Count: uint32(s.total)}) wg.Done() }) } else { // Everyone else gets scheduled like usual wg.Go(func() { fn(Lane{state: s, Index: uint32(i), Count: uint32(s.total)}) }) } } wg.Wait() }) } type state struct { mtx sync.Mutex cond *sync.Cond waiting atomic.Uint64 total uint64 userdata sync.Map } type Lane struct { state *state Index uint32 Count uint32 } // Main returns if the lane is locked to the main thread. func (l Lane) Main() bool { return l.Index == 0 } // Sync pauses the current lane until all lanes are at the same sync point. func (l Lane) Sync() { l.state.mtx.Lock() defer l.state.mtx.Unlock() if l.state.waiting.Add(1) >= l.state.total { l.state.waiting.Store(0) l.state.cond.Broadcast() return } l.state.cond.Wait() } // Store sends 'value' to all lanes. // // Store can be called concurrently. func (l Lane) Store(key, value any) { l.state.userdata.Store(key, value) } // Load fetches a named value, returning nil if it does not exist. // // Load can be called concurrently. func (l Lane) Load(key any) any { v, ok := l.state.userdata.Load(key) if !ok { return nil } return v } // Range returns a lane's data range for the given length. func (l Lane) Range(length int) (lo, hi uint) { size := uint(length) / uint(l.state.total) rem := uint(length) % uint(l.state.total) if uint(l.Index) < rem { lo = uint(l.Index) * (size + 1) hi = lo + size + 1 } else { lo = uint(l.Index)*size + rem hi = lo + size } return }