spmd: move from old repo and simplify (aka. use osthread)
This commit is contained in:
parent
346bb3a2e9
commit
bc751cc791
1 changed files with 147 additions and 0 deletions
147
spmd/spmd.go
Normal file
147
spmd/spmd.go
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
// Package spmd contains useful primitives for SPMD programs.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// func main() {
|
||||
// // Create N execution lanes (<= 0 for GOMAXPROCS lanes),
|
||||
// // and run 'Compute' across N lanes.
|
||||
// spmd.Run(-1, Compute)
|
||||
// }
|
||||
//
|
||||
// func Compute(lane spmd.Lane) {
|
||||
// log.Printf("Lane %d/%d is executing", lane.Index, lane.Count)
|
||||
//
|
||||
// // Execute this code on a lane locked to the main thread (aka lane.Index == 0)
|
||||
// // One lane will always be locked to the main thread
|
||||
// if lane.Main() {
|
||||
// data, err := os.ReadFile(...)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
//
|
||||
// // Send data to all lanes ("DATA" can be any value)
|
||||
// lane.Store("DATA", string(data))
|
||||
// }
|
||||
//
|
||||
// // Wait until all lanes are at this point
|
||||
// lane.Sync()
|
||||
//
|
||||
// // Load stored data
|
||||
// data := lane.Load("DATA").(string)
|
||||
//
|
||||
// // Get lane-specific access range for data
|
||||
// lo, hi := lane.Range(len(data))
|
||||
// for i := lo; i < hi; i++ {
|
||||
// // ...
|
||||
// }
|
||||
// }
|
||||
package spmd
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"git.brut.systems/judah/xx/osthread"
|
||||
)
|
||||
|
||||
// Run will start executing the given function across N execution lanes,
|
||||
// blocking until they have all finished executing.
|
||||
//
|
||||
// If nLanes is <= 0, GOMAXPROCS will be used.
|
||||
//
|
||||
// Run must be called from the program's main function.
|
||||
func Run(nLanes int, fn func(lane Lane)) {
|
||||
if nLanes <= 0 {
|
||||
nLanes = runtime.GOMAXPROCS(0)
|
||||
}
|
||||
|
||||
osthread.Start(func() {
|
||||
s := new(state)
|
||||
s.cond = sync.NewCond(&s.mtx)
|
||||
s.total = uint64(nLanes)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := range s.total {
|
||||
if i == 0 { // Lane 0 is always on the main thread
|
||||
wg.Add(1)
|
||||
osthread.Go(func() {
|
||||
fn(Lane{state: s, Index: uint32(i), Count: uint32(s.total)})
|
||||
wg.Done()
|
||||
})
|
||||
} else { // Everyone else gets scheduled like usual
|
||||
wg.Go(func() {
|
||||
fn(Lane{state: s, Index: uint32(i), Count: uint32(s.total)})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
// state is the bookkeeping shared by all lanes started by a single Run call.
type state struct {
	// Barrier used by Lane.Sync.
	mtx     sync.Mutex    // locker for cond; held while the barrier is updated
	cond    *sync.Cond    // lanes block here until the last lane arrives
	waiting atomic.Uint64 // lanes that have reached the current sync point
	total   uint64        // number of lanes created by Run

	// Key/value storage behind Lane.Store and Lane.Load.
	userdata sync.Map
}
|
||||
|
||||
type Lane struct {
|
||||
state *state
|
||||
Index uint32
|
||||
Count uint32
|
||||
}
|
||||
|
||||
// Main returns if the lane is locked to the main thread.
|
||||
func (l Lane) Main() bool {
|
||||
return l.Index == 0
|
||||
}
|
||||
|
||||
// Sync pauses the current lane until all lanes are at the same sync point.
|
||||
func (l Lane) Sync() {
|
||||
l.state.mtx.Lock()
|
||||
defer l.state.mtx.Unlock()
|
||||
|
||||
if l.state.waiting.Add(1) >= l.state.total {
|
||||
l.state.waiting.Store(0)
|
||||
l.state.cond.Broadcast()
|
||||
return
|
||||
}
|
||||
|
||||
l.state.cond.Wait()
|
||||
}
|
||||
|
||||
// Store sends 'value' to all lanes.
|
||||
// Store can be called concurrently.
|
||||
func (l Lane) Store(key, value any) {
|
||||
l.state.userdata.Store(key, value)
|
||||
}
|
||||
|
||||
// Load fetches a named value, returning nil if it does not exist.
|
||||
// Load can be called concurrently.
|
||||
func (l Lane) Load(key any) any {
|
||||
v, ok := l.state.userdata.Load(key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// Range returns a lane's data range for the given length.
|
||||
func (l Lane) Range(length int) (lo, hi uint) {
|
||||
size := uint(length) / uint(l.state.total)
|
||||
rem := uint(length) % uint(l.state.total)
|
||||
|
||||
if uint(l.Index) < rem {
|
||||
lo = uint(l.Index) * (size + 1)
|
||||
hi = lo + size + 1
|
||||
} else {
|
||||
lo = uint(l.Index)*size + rem
|
||||
hi = lo + size
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
Loading…
Reference in a new issue