jc/vm/interp.jai
2025-06-04 22:59:22 -06:00

221 lines
6.4 KiB
Text

// State for one interpreter run.
Interp :: struct {
    // Allocator used by make_interp_value for every Interp_Value it creates.
    allocator: Allocator;
    // Top-level bindings: name -> value. interp_program inserts variables and
    // procedures here; interp_expr looks names up here.
    symbols: kv.Kv(string, *Interp_Value);
    // Top-level nodes, executed in order by interp_program.
    toplevel: []*Node;
}
// Tagged runtime value. `kind` says which member of the anonymous union holds
// meaningful data; all union members share storage, so reading a member that
// does not match `kind` reinterprets the stored bits.
Interp_Value :: struct {
    kind: Kind;
    union {
        b: bool;               // read when kind == .bool
        i: s64;                // read when kind == .int
        u: u64;                // NOTE(review): no Kind selects this member in this file — confirm its use
        f: float64;            // read when kind == .float
        s: string;             // read when kind == .string
        p: *void;              // presumably for kind == .pointer; not read in this file
        proc: *Node_Procedure; // read when kind == .procedure (the callee's declaration node)
    }
    Kind :: enum {
        none;      // interp_program prints only the trailing newline for this kind
        nil;
        bool;
        int;
        float;
        string;
        pointer;
        procedure;
    }
}
// Prepares `i` for use and creates the shared nil/true/false singletons.
//
// `allocator` is stored on the interpreter; make_interp_value allocates every
// Interp_Value from it, including the singletons created here.
//
// NOTE: value_nil/value_true/value_false are file-scope globals, so calling
// init again (for example on a second Interp) overwrites them for everyone.
init :: (i: *Interp, allocator: Allocator) {
    // Must be stored before any allocation: make_interp_value reads i.allocator.
    i.allocator = allocator;

    value_nil = make_interp_value(i, .nil);

    value_true = make_interp_value(i, .bool);
    value_true.b = true;

    value_false = make_interp_value(i, .bool);
    value_false.b = false;
}
// Runs every top-level node in declaration order.
//
//   .procedure  binds the procedure's name to a fresh .procedure value.
//   .variable   binds the variable's name to its initializer's value, or to
//               value_nil when there is no initializer.
//   .print      evaluates the expression and writes it followed by "\n";
//               a null result is skipped entirely.
//
// Redeclaring an existing name, printing an unsupported value kind, or
// meeting any other node kind trips basic.assert.
interp_program :: (i: *Interp) {
    for node: i.toplevel {
        if node.kind == {
            case .procedure;
                decl := node.(*Node_Procedure);
                name := decl.symbol.str;
                basic.assert(!kv.exists(*i.symbols, name), "redeclaring procedure '%'", name);

                bound := make_interp_value(i, .procedure);
                bound.proc = decl;
                kv.set(*i.symbols, name, bound);

            case .variable;
                decl := node.(*Node_Var);
                name := decl.symbol.str;
                basic.assert(!kv.exists(*i.symbols, name), "redeclaring symbol '%'", name); // @errors

                // Declarations without an initializer share the nil singleton.
                bound := value_nil;
                if decl.value_expr != null {
                    bound = interp_expr(i, decl.value_expr);
                    basic.assert(bound != null); // @errors
                }
                kv.set(*i.symbols, name, bound);

            case .print;
                stmt := node.(*Node_Print);
                shown := interp_expr(i, stmt.expr);
                if shown == null continue;

                if shown.kind == {
                    case .none;   // writes nothing here; the newline below is still emitted
                    case .nil;    basic.print("nil");
                    case .bool;   basic.print("%", shown.b);
                    case .int;    basic.print("%", shown.i);
                    case .float;  basic.print("%", shown.f);
                    case .string; basic.print("%", shown.s);
                    case;         basic.assert(false, "unhandled interp value kind: %", shown.kind);
                }
                basic.print("\n");

            case;
                basic.assert(false, "unhandled node kind: %", node.kind); // @errors
        }
    }
}
// Evaluates an expression node and returns its value.
//
//   .procedure_call  Callee must be a bare symbol bound to a .procedure value.
//                    Returns the value of the first `return` in the callee's
//                    body, or value_nil when there is none or it returns nothing.
//                    Arguments are not evaluated yet (see @todo).
//   .unary           '+' / '-' on an int or float operand; result has the operand's kind.
//   .binary          '+', '-', '*', '/' on two operands of the same int/float kind,
//                    and '%' on ints; result has the operands' kind.
//   .symbol          The value currently bound to the name (shared, not copied).
//   .literal         Fresh int or float value.
//
// Every error trips basic.assert. Unhandled kinds fall through to `return null`,
// which is reachable only if basic.assert returns.
interp_expr :: (i: *Interp, expr: *Node) -> *Interp_Value {
    if expr.kind == {
        case .procedure_call;
            call := expr.(*Node_Procedure_Call);
            // @temp: only calls through a bare name are supported.
            sym := call.call_expr.(*Node_Symbol);
            basic.assert(sym.kind == .symbol);

            value, ok := kv.get(*i.symbols, sym.str);
            basic.assert(ok, "procedure doesn't exist '%'", sym.str); // @errors
            basic.assert(value.kind == .procedure, "attempt to call non procedure '%'", sym.str); // @errors

            // @todo(judah): check arity, create scope, map args (call.all_arguments) to locals, exec
            // Until then only the first `return` statement in the body is evaluated;
            // every statement before it is skipped.
            result := value_nil;
            for stmt: value.proc.block.body {
                if stmt.kind != .return_ continue;

                ret := stmt.(*Node_Return);
                if ret.values.count != 0 {
                    result = interp_expr(i, ret.values[0]);
                }
                break;
            }
            return result;

        case .unary;
            // Expands `code` once per supported operand kind with `right` bound to
            // the operand, storing the result into the matching member of `res`.
            do_unop :: (code: Code) #expand {
                if rhs.kind == {
                    case .int;
                        right := rhs.i;
                        res.i = #insert,scope() code;
                    case .float;
                        right := rhs.f;
                        res.f = #insert,scope() code;
                    case;
                        basic.assert(false, "cannot use unary operator '%' on values of type '%'", un.op.str, rhs.kind); // @errors
                }
            }
            un := expr.(*Node_Unary);
            rhs := interp_expr(i, un.right);
            res := make_interp_value(i, rhs.kind);
            if un.op.kind == {
                // NOTE(review): unary '+' yields the absolute value here rather than
                // acting as the identity — confirm this is the intended semantics.
                case .plus;  do_unop(#code ifx right < 0 then -right else right);
                case .minus; do_unop(#code -right);
                case;        basic.assert(false, "unhandled unary operator '%'", un.op.str); // @errors
            }
            return res;

        case .binary;
            bin := expr.(*Node_Binary);
            lhs := interp_expr(i, bin.left);
            rhs := interp_expr(i, bin.right);
            basic.assert(lhs.kind == rhs.kind, "type mismatch % vs. %", lhs.kind, rhs.kind); // @errors

            // Expands `code` once per supported operand kind with `left`/`right`
            // bound to the operands, storing the result into the matching member of `res`.
            do_binop :: (code: Code) #expand {
                if lhs.kind == {
                    case .int;
                        left := lhs.i;
                        right := rhs.i;
                        res.i = #insert,scope() code;
                    case .float;
                        left := lhs.f;
                        right := rhs.f;
                        res.f = #insert,scope() code;
                    case;
                        basic.assert(false, "cannot use binary operator '%' on values of type '%'", bin.op.str, lhs.kind); // @errors
                }
            }
            res := make_interp_value(i, lhs.kind);
            if bin.op.kind == {
                case .plus;  do_binop(#code left + right);
                case .minus; do_binop(#code left - right);
                case .star;  do_binop(#code left * right);
                case .f_slash;
                    // Test the divisor through the union member matching its kind.
                    // Reading .i on a float only inspects raw bits, which misses -0.0.
                    if lhs.kind == {
                        case .int;   basic.assert(rhs.i != 0, "divide by zero"); // @errors
                        case .float; basic.assert(rhs.f != 0, "divide by zero"); // @errors
                    }
                    do_binop(#code left / right);
                case .percent;
                    basic.assert(lhs.kind == .int, "cannot use binary operator '%%' on values of type '%'", lhs.kind); // @errors
                    // An integer remainder by zero is undefined; report it like division does.
                    basic.assert(rhs.i != 0, "modulo by zero"); // @errors
                    res.i = lhs.i % rhs.i;
                case;
                    basic.assert(false, "unhandled binary operator '%'", bin.op.str); // @errors
            }
            return res;

        case .symbol;
            sym := expr.(*Node_Symbol);
            value, ok := kv.get(*i.symbols, sym.str);
            basic.assert(ok, "use of undeclared symbol '%'", sym.str); // @errors
            return value;

        case .literal;
            lit := expr.(*Node_Literal);
            if lit.value_kind == {
                case .int;
                    value := make_interp_value(i, .int);
                    value.i = lit.i;
                    return value;
                case .float;
                    value := make_interp_value(i, .float);
                    value.f = lit.f;
                    return value;
                case;
                    basic.assert(false, "unhandled literal kind: %", lit.value_kind); // @errors
            }

        case;
            basic.assert(false, "unhandled node kind: %", expr.kind); // @errors
    }
    return null;
}
#scope_file;

// Singleton values created by init() from the interpreter's allocator.
// value_nil is bound to variables declared without an initializer
// (interp_program) and is the result of a procedure call that finds no
// `return` value (interp_expr). value_true/value_false are only assigned in
// init in the code shown.
// NOTE: these are file-scope globals shared by every Interp, not
// per-interpreter state; calling init again replaces them.
value_nil: *Interp_Value;
value_true: *Interp_Value;
value_false: *Interp_Value;
// Allocates a new Interp_Value from the interpreter's allocator and tags it
// with `kind`. Filling in the matching union member is left to the caller.
make_interp_value :: (i: *Interp, kind: Interp_Value.Kind) -> *Interp_Value {
    fresh := mem.request_memory(Interp_Value,, allocator = i.allocator);
    fresh.kind = kind;
    return fresh;
}