192 lines
3.8 KiB
Go
192 lines
3.8 KiB
Go
package union_test
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"git.brut.systems/judah/xx/union"
|
|
)
|
|
|
|
func TestUnion_BasicGetSet(t *testing.T) {
|
|
type Numbers = union.Of[struct {
|
|
uint8
|
|
bool
|
|
}]
|
|
|
|
var num Numbers
|
|
union.Set[uint8](&num, 1)
|
|
|
|
b := union.Get[bool](num)
|
|
if !b {
|
|
t.Errorf("expected bool value to be true, was %v", b)
|
|
}
|
|
|
|
union.Set(&num, false)
|
|
|
|
i := union.Get[uint8](num)
|
|
if i != 0 {
|
|
t.Errorf("expected uint8 value to be 0, was %v", i)
|
|
}
|
|
}
|
|
|
|
type (
|
|
expr = union.Of[struct {
|
|
binaryExpr
|
|
intExpr
|
|
floatExpr
|
|
}]
|
|
binaryExpr struct {
|
|
Op string
|
|
Lhs expr
|
|
Rhs expr
|
|
}
|
|
intExpr int64
|
|
floatExpr float64
|
|
)
|
|
|
|
func TestUnion_OfStructs(t *testing.T) {
|
|
makeInt := func(value int64) (e expr) {
|
|
union.Set(&e, intExpr(value))
|
|
return
|
|
}
|
|
makeFloat := func(value float64) (e expr) {
|
|
union.Set(&e, floatExpr(value))
|
|
return
|
|
}
|
|
makeBinop := func(op string, lhs, rhs expr) (e expr) {
|
|
union.Set(&e, binaryExpr{
|
|
Op: op,
|
|
Lhs: lhs,
|
|
Rhs: rhs,
|
|
})
|
|
return
|
|
}
|
|
|
|
expr1 := makeBinop("+", makeInt(10), makeInt(20))
|
|
bin1 := union.Get[binaryExpr](expr1)
|
|
if bin1.Op != "+" {
|
|
t.Errorf("incorrect op returned from union: %s", bin1.Op)
|
|
}
|
|
if lhs := union.Get[intExpr](bin1.Lhs); lhs != 10 {
|
|
t.Errorf("incorrect lhs returned from union: %v", lhs)
|
|
}
|
|
if rhs := union.Get[intExpr](bin1.Rhs); rhs != 20 {
|
|
t.Errorf("incorrect rhs returned from union: %v", rhs)
|
|
}
|
|
|
|
expr2 := makeBinop("-", expr1, makeFloat(3.14))
|
|
bin2 := union.Get[binaryExpr](expr2)
|
|
if bin2.Op != "-" {
|
|
t.Errorf("incorrect op returned from union of union: %s", bin2.Op)
|
|
}
|
|
if lhs := union.Get[binaryExpr](bin2.Lhs); lhs.Op != "+" {
|
|
t.Errorf("incorrect lhs returned from union of union: %v", lhs)
|
|
}
|
|
if rhs := union.Get[floatExpr](bin2.Rhs); rhs != 3.14 {
|
|
t.Errorf("incorrect rhs returned from union of union: %v", rhs)
|
|
}
|
|
}
|
|
|
|
func TestUnion_OfPointers(t *testing.T) {
|
|
type Value = union.Of[struct {
|
|
*float64
|
|
*uint64
|
|
}]
|
|
|
|
var (
|
|
original uint64 = 100
|
|
value Value
|
|
)
|
|
|
|
if union.Is[*uint64](value) || union.Is[*float64](value) {
|
|
t.Error("union internal type was incorrect before usage")
|
|
}
|
|
|
|
union.Set(&value, &original)
|
|
|
|
if !union.Is[*uint64](value) {
|
|
t.Error("union internal type was incorrect after Set")
|
|
}
|
|
|
|
fptr := union.Get[*float64](value)
|
|
*fptr = 3.14
|
|
|
|
if original == 100 {
|
|
t.Error("original value did not change")
|
|
}
|
|
|
|
uptr := union.Get[*uint64](value)
|
|
*uptr = 200
|
|
|
|
if *fptr == 3.14 {
|
|
t.Error("float pointer value did not change after modification")
|
|
}
|
|
|
|
if original != 200 {
|
|
t.Errorf("original value was incorrect: %v", original)
|
|
}
|
|
}
|
|
|
|
func TestUnion_ToString(t *testing.T) {
|
|
type (
|
|
Struct = union.Of[struct {
|
|
int32
|
|
uint32
|
|
}]
|
|
Interface = union.Of[interface {
|
|
Int()
|
|
Bool()
|
|
}]
|
|
Bool = union.Of[bool]
|
|
)
|
|
|
|
var (
|
|
s Struct
|
|
i Interface
|
|
b Bool
|
|
)
|
|
|
|
if s.String() != "union[none] { int32; uint32 }" {
|
|
t.Errorf("valid union had invalid stringification: %s", s.String())
|
|
}
|
|
|
|
if i.String() != b.String() {
|
|
t.Errorf("invalid union had invalid stringification: %s, %s", i.String(), b.String())
|
|
}
|
|
|
|
union.Set[int32](&s, 10)
|
|
|
|
if s.String() != "union[int32] { int32; uint32 }" {
|
|
t.Errorf("valid union had invalid stringification after Set: %s", s.String())
|
|
}
|
|
}
|
|
|
|
func TestUnion_SafeUsage(t *testing.T) {
|
|
type Value = union.Of[struct {
|
|
int32
|
|
uint32
|
|
float32
|
|
}]
|
|
|
|
var v Value
|
|
if _, err := union.GetSafe[int32](v); err == nil {
|
|
t.Errorf("GetSafe did not error for an uninitialized union")
|
|
}
|
|
|
|
if err := union.SetSafe(&v, false); err == nil {
|
|
t.Error("SetSafe allowed invalid type")
|
|
}
|
|
|
|
if err := union.SetSafe[int32](&v, 10); err != nil {
|
|
t.Errorf("SetSafe failed with valid type: %s", err)
|
|
}
|
|
|
|
if _, err := union.GetSafe[bool](v); err == nil {
|
|
t.Errorf("GetSafe allowed invalid type")
|
|
}
|
|
|
|
if v, err := union.GetSafe[int32](v); err != nil {
|
|
t.Errorf("GetSafe failed with valid type: %s", err)
|
|
} else if v != 10 {
|
|
t.Errorf("GetSafe returned invalid value: %v", v)
|
|
}
|
|
}
|