diff --git a/asm/asm.go b/asm/asm.go new file mode 100644 index 0000000..6ec7d2a --- /dev/null +++ b/asm/asm.go @@ -0,0 +1,26 @@ +package asm + +type Call struct { + registers + fn uintptr + sp uintptr + stack [128]uintptr +} + +func PrepareCall(fn uintptr) Call { + return Call{fn: fn} +} + +func (c *Call) Push(v uintptr) { + c.stack[c.sp] = v + c.sp += 1 +} + +func (c *Call) Pop() uintptr { + c.sp -= 1 + return c.stack[c.sp] +} + +func (c *Call) Do() { + perform_call(c) +} diff --git a/asm/asm_arm64.go b/asm/asm_arm64.go new file mode 100644 index 0000000..3eea56a --- /dev/null +++ b/asm/asm_arm64.go @@ -0,0 +1,16 @@ +package asm + +type registers struct { + R0, R1, R2, R3, R4, + R5, R6, R7, R8, R9, + R10, R11, R12, R13, R14, + R15, R16, R17, R19, R20, + R21, R22, R23, R24, R25, + R26, R27, R29, R30, R31 uintptr + + F0, F1, F2, F3, F4, + F5, F6, F7, F8, F9, + F10 float64 +} + +func perform_call(c *Call) diff --git a/asm/asm_arm64.s b/asm/asm_arm64.s new file mode 100644 index 0000000..e69de29 diff --git a/asm/asm_arm64_test.go b/asm/asm_arm64_test.go new file mode 100644 index 0000000..d6ecc6d --- /dev/null +++ b/asm/asm_arm64_test.go @@ -0,0 +1,23 @@ +package asm_test + +import ( + "testing" + + _ "unsafe" + + "git.brut.systems/judah/xx/asm" +) + +func TestAdd(t *testing.T) { + call := asm.PrepareCall(0) + call.R1 = 10 + call.R2 = 10 + call.Do() + + result := call.R0 + expected := uintptr(8) + + if result != expected { + t.Errorf("Expected %d, got %d", expected, result) + } +} diff --git a/stable/array.go b/containerx/bucket/array.go similarity index 99% rename from stable/array.go rename to containerx/bucket/array.go index 503ff3a..3dac56e 100644 --- a/stable/array.go +++ b/containerx/bucket/array.go @@ -1,4 +1,4 @@ -package stable +package bucket import ( "iter" @@ -10,9 +10,9 @@ const DefaultElementsPerBucket = 32 // This means it is safe to take a pointer to a value within the array // while continuing to append to it. type Array[T any] struct { - buckets []bucket[T] last int elements_per_bucket int + buckets []bucket[T] } func (s *Array[T]) Init() { diff --git a/stable/array_test.go b/containerx/bucket/array_test.go similarity index 55% rename from stable/array_test.go rename to containerx/bucket/array_test.go index ff17a89..8dd3356 100644 --- a/stable/array_test.go +++ b/containerx/bucket/array_test.go @@ -1,10 +1,11 @@ -package stable_test +package bucket_test import ( "runtime" "testing" - "git.brut.systems/judah/xx/stable" + "git.brut.systems/judah/xx/containerx/bucket" + "git.brut.systems/judah/xx/testx" ) func TestArray_StableWithGC(t *testing.T) { @@ -13,7 +14,7 @@ func TestArray_StableWithGC(t *testing.T) { ptr *int } - var arr stable.Array[valuewithptr] + var arr bucket.Array[valuewithptr] aptr := arr.Append(valuewithptr{value: 10, ptr: nil}) bptr := arr.Append(valuewithptr{value: 20, ptr: &aptr.value}) @@ -23,27 +24,27 @@ func TestArray_StableWithGC(t *testing.T) { runtime.GC() } - expect(t, arr.Get(0) == aptr) - expect(t, arr.Get(1) == bptr) - expect(t, arr.Len() == N+2, "len was %d", arr.Len()) - expect(t, bptr.ptr != nil && bptr.value == 20) - expect(t, bptr.ptr == &aptr.value, "%p vs. %p", bptr.ptr, &aptr.value) + testx.Expect(t, arr.Get(0) == aptr) + testx.Expect(t, arr.Get(1) == bptr) + testx.Expect(t, arr.Len() == N+2, "len was %d", arr.Len()) + testx.Expect(t, bptr.ptr != nil && bptr.value == 20) + testx.Expect(t, bptr.ptr == &aptr.value, "%p vs. %p", bptr.ptr, &aptr.value) } func BenchmarkArray_RandomAccess(b *testing.B) { - var arr stable.Array[int] + var arr bucket.Array[int] for i := range b.N { arr.Append(i * i) } b.ResetTimer() for i := range b.N { - arr.Get(i % 10000) + arr.Get(i % b.N) } } func BenchmarkArray_Append(b *testing.B) { - var arr stable.Array[int] + var arr bucket.Array[int] for i := range b.N { arr.Append(i * i) } @@ -55,7 +56,7 @@ func BenchmarkArray_Append(b *testing.B) { } func BenchmarkArray_Iteration(b *testing.B) { - var arr stable.Array[int] + var arr bucket.Array[int] for i := range b.N { arr.Append(i * i) } @@ -67,16 +68,3 @@ func BenchmarkArray_Iteration(b *testing.B) { sum += v } } - -func expect(t *testing.T, cond bool, message ...any) { - t.Helper() - - if !cond { - if len(message) == 0 { - message = append(message, "assertion failed") - } - - str := message[0].(string) - t.Fatalf(str, message[1:]...) - } -} diff --git a/stable/xar.go b/containerx/xar/xar.go similarity index 99% rename from stable/xar.go rename to containerx/xar/xar.go index 3163266..fd2b40b 100644 --- a/stable/xar.go +++ b/containerx/xar/xar.go @@ -1,4 +1,4 @@ -package stable +package xar import ( "iter" diff --git a/containerx/xar/xar_test.go b/containerx/xar/xar_test.go new file mode 100644 index 0000000..ecb1c46 --- /dev/null +++ b/containerx/xar/xar_test.go @@ -0,0 +1,105 @@ +package xar_test + +import ( + "runtime" + "testing" + + "git.brut.systems/judah/xx/containerx/xar" + "git.brut.systems/judah/xx/testx" +) + +func TestXar_StableWithGC(t *testing.T) { + type valuewithptr struct { + value int + ptr *int + } + + var x xar.Xar[valuewithptr] + x.InitWithSize(8) + + aptr := x.Append(valuewithptr{value: 10, ptr: nil}) + bptr := x.Append(valuewithptr{value: 20, ptr: &aptr.value}) + + const N = 1000 + for i := range N { + x.Append(valuewithptr{value: i}) + runtime.GC() + } + + testx.Expect(t, x.Get(0) == bptr) + testx.Expect(t, x.Get(1) == bptr) + testx.Expect(t, x.Len() == N+2, "len was %d", x.Len()) + testx.Expect(t, bptr.ptr != nil && bptr.value == 20) + testx.Expect(t, bptr.ptr == &aptr.value, "%p vs. %p", bptr.ptr, &aptr.value) +} + +func TestXar_ResetAndReuse(t *testing.T) { + var x xar.Xar[int] + start := x.Append(60) + x.AppendMany(10, 20, 30, 40, 50) + + x.Reset() + runtime.GC() + + testx.Expect(t, x.Cap() != 0) + testx.Expect(t, x.Len() == 0) + + x.Append(0xFF) + x.Append(0xFC) + x.Append(0xFB) + + testx.Expect(t, x.Get(0) == start) + testx.Expect(t, x.Len() == 3) +} + +func TestXar_Iterators(t *testing.T) { + var x xar.Xar[int] + x.AppendMany(0, 1, 2, 3, 4, 5) + + iterations := 0 + for i, v := range x.Values() { + iterations += 1 + testx.Expect(t, v == i, "v: %d, i: %d", v, i) + } + + testx.Expect(t, iterations == x.Len()) +} + +func BenchmarkXar_Append(b *testing.B) { + var x xar.Xar[int] + for i := range b.N { + x.Append(i * i) + } + + x.Reset() + for i := range b.N { + x.Append(i * i) + } +} + +func BenchmarkXar_RandomAccess(b *testing.B) { + var x xar.Xar[int] + for i := range b.N { + x.Append(i * i) + } + + b.ResetTimer() + + for i := range b.N { + x.Get(i % b.N) + } +} + +func BenchmarkXar_Iteration(b *testing.B) { + var x xar.Xar[int] + for i := range b.N { + x.Append(i * i) + } + + b.ResetTimer() + + sum := 0 + for _, v := range x.Values() { + sum += v + } +} diff --git a/go.mod b/go.mod index 00758d6..97aca90 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module git.brut.systems/judah/xx go 1.25.0 -require github.com/ebitengine/purego v0.9.1 // indirect +require github.com/ebitengine/purego v0.9.1 diff --git a/mem/mem.go b/mem/mem.go index 5a89be0..128c917 100644 --- a/mem/mem.go +++ b/mem/mem.go @@ -8,16 +8,19 @@ import ( // // Not to be confused with [unsafe.Sizeof] which returns the size of a type via an expression. func SizeOf[T any]() uintptr { - var zero T - return unsafe.Sizeof(zero) + return unsafe.Sizeof(ZeroValue[T]()) } // AlignOf returns the alignment (in bytes) of the given type. // // Not to be confused with [unsafe.AlignOf] which returns the alignment of a type via an expression. func AlignOf[T any]() uintptr { - var zero T - return unsafe.Alignof(zero) + return unsafe.Alignof(ZeroValue[T]()) +} + +// ZeroValue returns the zero value of a given type. +func ZeroValue[T any]() (_ T) { + return } // BitCast performs a bit conversion between two types of the same size. @@ -43,7 +46,7 @@ func Copy(dst, src unsafe.Pointer, size uintptr) unsafe.Pointer { // Returns dst. func Clear(dst unsafe.Pointer, value byte, count uintptr) unsafe.Pointer { b := (*byte)(dst) - for range count { + for range count { // @todo: loop unroll/maybe use asm? *b = value b = (*byte)(unsafe.Add(dst, 1)) } @@ -72,3 +75,11 @@ func AlignBackward(address uintptr, alignment uintptr) uintptr { } return address &^ (alignment - 1) } + +// Aligned returns if the address is aligned to the given power-of-two alignment. +func Aligned(address uintptr, alignment uintptr) bool { + if alignment == 0 || (alignment&(alignment-1)) != 0 { + panic("aligned: alignment must be a power of two") + } + return address&(alignment-1) == 0 +} diff --git a/osthread/osthread.go b/osthread/osthread.go index c855f66..a10f1b6 100644 --- a/osthread/osthread.go +++ b/osthread/osthread.go @@ -20,6 +20,8 @@ import "runtime" // // Start must be called from the program's main function. Once called, it blocks until entrypoint returns. func Start(entrypoint func()) { + defer runtime.UnlockOSThread() + done := make(chan any) // Run entrypoint in a separate goroutine. diff --git a/pointer/pinned.go b/pointer/pinned.go new file mode 100644 index 0000000..3e852f0 --- /dev/null +++ b/pointer/pinned.go @@ -0,0 +1,87 @@ +package pointer + +import ( + "runtime" + "unsafe" + + "git.brut.systems/judah/xx/mem" +) + +type Pinned[T any] struct { + base unsafe.Pointer + pinner runtime.Pinner +} + +func Pin[T any](ptr *T) (r Pinned[T]) { + r.pinner.Pin(ptr) + r.base = unsafe.Pointer(ptr) + return +} + +func Cast[TOut, TIn any](p Pinned[TIn]) Pinned[TOut] { + return Pinned[TOut]{ + base: unsafe.Pointer(p.base), + pinner: p.pinner, + } +} + +func (p Pinned[T]) Unpin() { + p.pinner.Unpin() + p.base = nil +} + +func (p Pinned[T]) Pointer() unsafe.Pointer { + return p.base +} + +func (p Pinned[T]) Address() uintptr { + return uintptr(p.base) +} + +func (p Pinned[T]) Nil() bool { + return p.base == nil +} + +func (p Pinned[T]) Add(amount uintptr) Pinned[T] { + return Pinned[T]{ + base: unsafe.Pointer(uintptr(p.base) + amount), + pinner: p.pinner, + } +} + +func (p Pinned[T]) Sub(amount uintptr) Pinned[T] { + return Pinned[T]{ + base: unsafe.Pointer(uintptr(p.base) - amount), + pinner: p.pinner, + } +} + +func (p Pinned[T]) Aligned() bool { + return mem.Aligned(uintptr(p.base), mem.AlignOf[T]()) +} + +func (p Pinned[T]) AlignForward() Pinned[T] { + return Pinned[T]{ + base: unsafe.Pointer(mem.AlignForward(uintptr(p.base), mem.AlignOf[T]())), + pinner: p.pinner, + } +} + +func (p Pinned[T]) AlignBackward() Pinned[T] { + return Pinned[T]{ + base: unsafe.Pointer(mem.AlignBackward(uintptr(p.base), mem.AlignOf[T]())), + pinner: p.pinner, + } +} + +func (p Pinned[T]) Load() T { + return *(*T)(p.base) +} + +func (p Pinned[T]) Store(value T) { + *(*T)(p.base) = value +} + +func (p Pinned[T]) Nth(index int) T { + return p.Add(uintptr(index) * mem.SizeOf[T]()).Load() +} diff --git a/stable/pointer_test.go b/pointer/pinned_test.go similarity index 66% rename from stable/pointer_test.go rename to pointer/pinned_test.go index 1587e8c..5c0a659 100644 --- a/stable/pointer_test.go +++ b/pointer/pinned_test.go @@ -1,36 +1,35 @@ -package stable_test +package pointer_test import ( "testing" - "git.brut.systems/judah/xx/stable" + "git.brut.systems/judah/xx/pointer" ) func TestPointer_AlignForward(t *testing.T) { tests := []struct { - name string - offset uintptr - alignment uintptr + name string + offset uintptr }{ - {"align 8 bytes", 1, 8}, - {"align 16 bytes", 3, 16}, - {"align 32 bytes", 7, 32}, - {"align 64 bytes", 15, 64}, + {"align 8 bytes", 1}, + {"align 16 bytes", 3}, + {"align 32 bytes", 7}, + {"align 64 bytes", 15}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { value := int32(789) - pinned := stable.PointerPin(&value) + pinned := pointer.Pin(&value) defer pinned.Unpin() // Add offset to misalign misaligned := pinned.Add(tt.offset) - aligned := misaligned.AlignForward(tt.alignment) + aligned := misaligned.AlignForward() // Check alignment - if aligned.Address()%tt.alignment != 0 { - t.Errorf("Address %d is not aligned to %d bytes", aligned.Address(), tt.alignment) + if !aligned.Aligned() { + t.Errorf("Address %d is not aligned", aligned.Address()) } // Check it's forward aligned (greater or equal) @@ -43,29 +42,28 @@ func TestPointer_AlignForward(t *testing.T) { func TestPointer_AlignBackward(t *testing.T) { tests := []struct { - name string - offset uintptr - alignment uintptr + name string + offset uintptr }{ - {"align 8 bytes", 5, 8}, - {"align 16 bytes", 10, 16}, - {"align 32 bytes", 20, 32}, - {"align 64 bytes", 40, 64}, + {"align 8 bytes", 5}, + {"align 16 bytes", 10}, + {"align 32 bytes", 20}, + {"align 64 bytes", 40}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { value := int32(321) - pinned := stable.PointerPin(&value) + pinned := pointer.Pin(&value) defer pinned.Unpin() // Add offset to misalign misaligned := pinned.Add(tt.offset) - aligned := misaligned.AlignBackward(tt.alignment) + aligned := misaligned.AlignBackward() // Check alignment - if aligned.Address()%tt.alignment != 0 { - t.Errorf("Address %d is not aligned to %d bytes", aligned.Address(), tt.alignment) + if !aligned.Aligned() { + t.Errorf("Address %d is not aligned", aligned.Address()) } // Check it's backward aligned (less or equal) @@ -79,7 +77,7 @@ func TestPointer_AlignBackward(t *testing.T) { func TestPointer_Nth(t *testing.T) { // Test with int32 array arr := []int32{10, 20, 30, 40, 50} - pinned := stable.PointerPin(&arr[0]) + pinned := pointer.Pin(&arr[0]) defer pinned.Unpin() for i := 0; i < len(arr); i++ { @@ -93,7 +91,7 @@ func TestPointer_Nth(t *testing.T) { func TestPointer_NthFloat64(t *testing.T) { // Test with a float64 array arr := []float64{1.1, 2.2, 3.3, 4.4, 5.5} - pinned := stable.PointerPin(&arr[0]) + pinned := pointer.Pin(&arr[0]) defer pinned.Unpin() for i := 0; i < len(arr); i++ { @@ -106,7 +104,7 @@ func TestPointer_NthFloat64(t *testing.T) { func TestPointerArithmeticChain(t *testing.T) { value := int32(888) - pinned := stable.PointerPin(&value) + pinned := pointer.Pin(&value) defer pinned.Unpin() // Test chaining operations @@ -121,10 +119,10 @@ func TestPointerArithmeticChain(t *testing.T) { func TestPointer_Cast(t *testing.T) { value := int32(123) - i32 := stable.PointerPin(&value) + i32 := pointer.Pin(&value) defer i32.Unpin() - f32 := stable.PointerCast[float32](i32) + f32 := pointer.Cast[float32](i32) f32.Store(3.14) if value == 123 { diff --git a/stable/pointer.go b/stable/pointer.go deleted file mode 100644 index f971f06..0000000 --- a/stable/pointer.go +++ /dev/null @@ -1,83 +0,0 @@ -package stable - -import ( - "runtime" - "unsafe" - - "git.brut.systems/judah/xx/mem" -) - -type Pointer[T any] struct { - base unsafe.Pointer - pinner runtime.Pinner -} - -func PointerPin[T any](ptr *T) (r Pointer[T]) { - r.pinner.Pin(ptr) - r.base = unsafe.Pointer(ptr) - return -} - -func PointerCast[TOut, TIn any](p Pointer[TIn]) Pointer[TOut] { - return Pointer[TOut]{ - base: unsafe.Pointer(p.base), - pinner: p.pinner, - } -} - -func (p Pointer[T]) Unpin() { - p.pinner.Unpin() - p.base = nil -} - -func (p Pointer[T]) Pointer() unsafe.Pointer { - return p.base -} - -func (p Pointer[T]) Address() uintptr { - return uintptr(p.base) -} - -func (p Pointer[T]) Nil() bool { - return p.base == nil -} - -func (p Pointer[T]) Add(amount uintptr) Pointer[T] { - return Pointer[T]{ - base: unsafe.Pointer(uintptr(p.base) + amount), - pinner: p.pinner, - } -} - -func (p Pointer[T]) Sub(amount uintptr) Pointer[T] { - return Pointer[T]{ - base: unsafe.Pointer(uintptr(p.base) - amount), - pinner: p.pinner, - } -} - -func (p Pointer[T]) AlignForward(alignment uintptr) Pointer[T] { - return Pointer[T]{ - base: unsafe.Pointer(mem.AlignForward(uintptr(p.base), alignment)), - pinner: p.pinner, - } -} - -func (p Pointer[T]) AlignBackward(alignment uintptr) Pointer[T] { - return Pointer[T]{ - base: unsafe.Pointer(mem.AlignBackward(uintptr(p.base), alignment)), - pinner: p.pinner, - } -} - -func (p Pointer[T]) Load() T { - return *(*T)(p.base) -} - -func (p Pointer[T]) Store(value T) { - *(*T)(p.base) = value -} - -func (p Pointer[T]) Nth(index int) T { - return p.Add(uintptr(index) * mem.SizeOf[T]()).Load() -} diff --git a/stable/xar_test.go b/stable/xar_test.go deleted file mode 100644 index 134a4e2..0000000 --- a/stable/xar_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package stable_test - -import ( - "runtime" - "testing" - - "git.brut.systems/judah/xx/stable" -) - -func TestXar_StableWithGC(t *testing.T) { - type valuewithptr struct { - value int - ptr *int - } - - var xar stable.Xar[valuewithptr] - xar.InitWithSize(8) - - aptr := xar.Append(valuewithptr{value: 10, ptr: nil}) - bptr := xar.Append(valuewithptr{value: 20, ptr: &aptr.value}) - - const N = 1000 - for i := range N { - xar.Append(valuewithptr{value: i}) - runtime.GC() - } - - expect(t, xar.Get(0) == aptr) - expect(t, xar.Get(1) == bptr) - expect(t, xar.Len() == N+2, "len was %d", xar.Len()) - expect(t, bptr.ptr != nil && bptr.value == 20) - expect(t, bptr.ptr == &aptr.value, "%p vs. %p", bptr.ptr, &aptr.value) -} - -func TestXar_ResetAndReuse(t *testing.T) { - var xar stable.Xar[int] - start := xar.Append(60) - xar.AppendMany(10, 20, 30, 40, 50) - - xar.Reset() - runtime.GC() - - expect(t, xar.Cap() != 0) - expect(t, xar.Len() == 0) - - xar.Append(0xFF) - xar.Append(0xFC) - xar.Append(0xFB) - - expect(t, xar.Get(0) == start) - expect(t, xar.Len() == 3) -} - -func TestXar_Iterators(t *testing.T) { - var xar stable.Xar[int] - xar.AppendMany(0, 1, 2, 3, 4, 5) - - iterations := 0 - for i, v := range xar.Values() { - iterations += 1 - expect(t, v == i, "v: %d, i: %d", v, i) - } - - expect(t, iterations == xar.Len()) -} - -func BenchmarkXar_Append(b *testing.B) { - var xar stable.Xar[int] - for i := range b.N { - xar.Append(i * i) - } - - xar.Reset() - for i := range b.N { - xar.Append(i * i) - } -} - -func BenchmarkXar_RandomAccess(b *testing.B) { - var xar stable.Xar[int] - for i := range b.N { - xar.Append(i * i) - } - - b.ResetTimer() - - for i := range b.N { - xar.Get(i % 10000) - } -} - -func BenchmarkXar_Iteration(b *testing.B) { - var xar stable.Xar[int] - for i := range b.N { - xar.Append(i * i) - } - - b.ResetTimer() - - sum := 0 - for _, v := range xar.Values() { - sum += v - } -} diff --git a/testx/testx.go b/testx/testx.go new file mode 100644 index 0000000..c60f499 --- /dev/null +++ b/testx/testx.go @@ -0,0 +1,16 @@ +package testx + +import "testing" + +func Expect(t *testing.T, cond bool, message ...any) { + t.Helper() + + if !cond { + if len(message) == 0 { + message = append(message, "assertion failed") + } + + str := message[0].(string) + t.Fatalf(str, message[1:]...) + } +} diff --git a/union/union.go b/union/union.go new file mode 100644 index 0000000..88a6b89 --- /dev/null +++ b/union/union.go @@ -0,0 +1,164 @@ +package union + +import ( + "errors" + "fmt" + "reflect" + "strings" + "unsafe" + + "git.brut.systems/judah/xx/mem" +) + +var ( + ErrUninitializedAccess = errors.New("access of uninitialized union") + ErrInvalidType = errors.New("type does not exist within union") +) + +// anystruct represents a struct type with any members. +// +// Note: because Go's type constraint system can't enforce +// this, anystruct is here for documentation purposes. +type anystruct any + +// @note(judah): is there a way to declare the type parameters +// to allow 'type Value union.Of[...]' so users can define their +// own methods? + +// Of represents a union of different types. +// +// Since members are accessed by type instead of name, +// T is expected to be a struct of types like so: +// +// type Value = union.Of[struct { +// int32 +// uint32 +// float32 +// }) +type Of[T anystruct] struct { + typ reflect.Kind + mem []byte +} + +func (u Of[T]) Size() uintptr { + return mem.SizeOf[T]() +} + +// String returns the string representation of a union. +func (u Of[T]) String() string { + var b strings.Builder + b.WriteString("union[") + if u.typ == reflect.Invalid { + b.WriteString("none") + } else { + b.WriteString(u.typ.String()) + } + b.WriteString("] {") + + t := reflect.TypeFor[T]() + if t.Kind() == reflect.Struct { + b.WriteByte(' ') + fields := getInternalFields(u) + for i, field := range fields { + b.WriteString(field.Type.String()) + if i < len(fields)-1 { + b.WriteString("; ") + } + } + b.WriteByte(' ') + } + + b.WriteByte('}') + return b.String() +} + +// Is returns true if the given type is currently stored in the union. +func Is[E any, T anystruct](u Of[T]) bool { + // Explicit invalid check to make sure invalid types don't result in false-positives. + if u.typ == reflect.Invalid { + return false + } + + return u.typ == reflect.TypeFor[E]().Kind() +} + +// Set overwrites the backing memory of a union with the given value; initializing the union if uninitialized. +// +// Set is unsafe and will not verify if the backing memory has enough capacity to store the value. +// Use [SetSafe] for more safety checks. +func Set[V any, T anystruct](u *Of[T], value V) { + if u.mem == nil { + u.mem = make([]byte, mem.SizeOf[T]()) + } + + *(*V)(unsafe.Pointer(&u.mem[0])) = value + u.typ = reflect.TypeFor[V]().Kind() +} + +// SetSafe overwrites the backing memory of a union with the given value, +// returning an error if the value cannot be stored in the union. +// +// Use [Set] for fewer safety checks. +func SetSafe[V any, T anystruct](u *Of[T], value V) error { + if u.mem == nil { + u.mem = make([]byte, mem.SizeOf[T]()) + } + + vt := reflect.TypeFor[V]() + for _, field := range getInternalFields(*u) { + if field.Type == vt { + *(*V)(unsafe.Pointer(&u.mem[0])) = value + u.typ = reflect.TypeFor[V]().Kind() + return nil + } + } + + return fmt.Errorf("%s - %w", vt, ErrInvalidType) +} + +// Get returns the union's backing memory interpreted as a value of type V, panicking if the union is uninitialized. +// +// Get is unsafe and will not verify if the type exists within the union. +// Use [GetSafe] for more safety checks. +func Get[V any, T anystruct](u Of[T]) V { + if u.mem == nil { + panic(ErrUninitializedAccess) + } + + return *(*V)(unsafe.Pointer(&u.mem[0])) +} + +// GetSafe returns the union's backing memory interpreted as a value of type V, returning an error if the type +// does not exist within the union or the union is uninitialized. +// +// Use [Get] for fewer safety checks. +func GetSafe[V any, T anystruct](u Of[T]) (V, error) { + if u.mem == nil { + return mem.ZeroValue[V](), ErrUninitializedAccess + } + + vt := reflect.TypeFor[V]() + for _, field := range getInternalFields(u) { + if field.Type == vt { + return *(*V)(unsafe.Pointer(&u.mem[0])), nil + } + } + + return mem.ZeroValue[V](), ErrInvalidType +} + +// getInternalFields returns an array of reflect.StructField belonging +// to the internal type of a union. +func getInternalFields[U Of[T], T anystruct](_ U) []reflect.StructField { + backing := reflect.TypeFor[T]() + if backing.Kind() != reflect.Struct { + return nil + } + + var fields []reflect.StructField + for i := range backing.NumField() { + fields = append(fields, backing.Field(i)) + } + + return fields +} diff --git a/union/union_test.go b/union/union_test.go new file mode 100644 index 0000000..11b93a5 --- /dev/null +++ b/union/union_test.go @@ -0,0 +1,192 @@ +package union_test + +import ( + "testing" + + "git.brut.systems/judah/xx/union" +) + +func TestUnion_BasicGetSet(t *testing.T) { + type Numbers = union.Of[struct { + uint8 + bool + }] + + var num Numbers + union.Set[uint8](&num, 1) + + b := union.Get[bool](num) + if !b { + t.Errorf("expected bool value to be true, was %v", b) + } + + union.Set(&num, false) + + i := union.Get[uint8](num) + if i != 0 { + t.Errorf("expected uint8 value to be 0, was %v", i) + } +} + +type ( + expr = union.Of[struct { + binaryExpr + intExpr + floatExpr + }] + binaryExpr struct { + Op string + Lhs expr + Rhs expr + } + intExpr int64 + floatExpr float64 +) + +func TestUnion_OfStructs(t *testing.T) { + makeInt := func(value int64) (e expr) { + union.Set(&e, intExpr(value)) + return + } + makeFloat := func(value float64) (e expr) { + union.Set(&e, floatExpr(value)) + return + } + makeBinop := func(op string, lhs, rhs expr) (e expr) { + union.Set(&e, binaryExpr{ + Op: op, + Lhs: lhs, + Rhs: rhs, + }) + return + } + + expr1 := makeBinop("+", makeInt(10), makeInt(20)) + bin1 := union.Get[binaryExpr](expr1) + if bin1.Op != "+" { + t.Errorf("incorrect op returned from union: %s", bin1.Op) + } + if lhs := union.Get[intExpr](bin1.Lhs); lhs != 10 { + t.Errorf("incorrect lhs returned from union: %v", lhs) + } + if rhs := union.Get[intExpr](bin1.Rhs); rhs != 20 { + t.Errorf("incorrect rhs returned from union: %v", rhs) + } + + expr2 := makeBinop("-", expr1, makeFloat(3.14)) + bin2 := union.Get[binaryExpr](expr2) + if bin2.Op != "-" { + t.Errorf("incorrect op returned from union of union: %s", bin2.Op) + } + if lhs := union.Get[binaryExpr](bin2.Lhs); lhs.Op != "+" { + t.Errorf("incorrect lhs returned from union of union: %v", lhs) + } + if rhs := union.Get[floatExpr](bin2.Rhs); rhs != 3.14 { + t.Errorf("incorrect rhs returned from union of union: %v", rhs) + } +} + +func TestUnion_OfPointers(t *testing.T) { + type Value = union.Of[struct { + *float64 + *uint64 + }] + + var ( + original uint64 = 100 + value Value + ) + + if union.Is[*uint64](value) || union.Is[*float64](value) { + t.Error("union internal type was incorrect before usage") + } + + union.Set(&value, &original) + + if !union.Is[*uint64](value) { + t.Error("union internal type was incorrect after Set") + } + + fptr := union.Get[*float64](value) + *fptr = 3.14 + + if original == 100 { + t.Error("original value did not change") + } + + uptr := union.Get[*uint64](value) + *uptr = 200 + + if *fptr == 3.14 { + t.Error("float pointer value did not change after modification") + } + + if original != 200 { + t.Errorf("original value was incorrect: %v", original) + } +} + +func TestUnion_ToString(t *testing.T) { + type ( + Struct = union.Of[struct { + int32 + uint32 + }] + Interface = union.Of[interface { + Int() + Bool() + }] + Bool = union.Of[bool] + ) + + var ( + s Struct + i Interface + b Bool + ) + + if s.String() != "union[none] { int32; uint32 }" { + t.Errorf("valid union had invalid stringification: %s", s.String()) + } + + if i.String() != b.String() { + t.Errorf("invalid union had invalid stringification: %s, %s", i.String(), b.String()) + } + + union.Set[int32](&s, 10) + + if s.String() != "union[int32] { int32; uint32 }" { + t.Errorf("valid union had invalid stringification after Set: %s", s.String()) + } +} + +func TestUnion_SafeUsage(t *testing.T) { + type Value = union.Of[struct { + int32 + uint32 + float32 + }] + + var v Value + if _, err := union.GetSafe[int32](v); err == nil { + t.Errorf("GetSafe did not error for an uninitialized union") + } + + if err := union.SetSafe(&v, false); err == nil { + t.Error("SetSafe allowed invalid type") + } + + if err := union.SetSafe[int32](&v, 10); err != nil { + t.Errorf("SetSafe failed with valid type: %s", err) + } + + if _, err := union.GetSafe[bool](v); err == nil { + t.Errorf("GetSafe allowed invalid type") + } + + if v, err := union.GetSafe[int32](v); err != nil { + t.Errorf("GetSafe failed with valid type: %s", err) + } else if v != 10 { + t.Errorf("GetSafe returned invalid value: %v", v) + } +}