diff --git a/vm/interp.jai b/vm/interp.jai index 0b1a15f..6c1762a 100644 --- a/vm/interp.jai +++ b/vm/interp.jai @@ -63,7 +63,7 @@ interp_statement :: (i: *Interp, stmt: *Node, scope: *Interp_Scope) { value := value_nil; if var.value_expr != null { - value = interp_expr(i, var.value_expr, scope); + value = interp_expression(i, var.value_expr, scope); basic.assert(value != null); // @errors } @@ -72,7 +72,7 @@ interp_statement :: (i: *Interp, stmt: *Node, scope: *Interp_Scope) { case .assign; assign := stmt.(*Node_Assign); - src := interp_expr(i, assign.src, scope); + src := interp_expression(i, assign.src, scope); basic.assert(src != null); // @errors dst := interp_lvalue(i, assign.dst, scope); @@ -80,10 +80,8 @@ interp_statement :: (i: *Interp, stmt: *Node, scope: *Interp_Scope) { basic.assert(dst.kind == .value); // @errors // @todo: typechecking - if assign.op.kind == { - case .equal; - dst.val.* = src.*; - } + basic.assert(assign.op.kind == .equal, "unexpected assign op: %", assign.op.kind); + dst.val.* = src.*; case .procedure; proc := stmt.(*Node_Procedure); @@ -96,7 +94,7 @@ interp_statement :: (i: *Interp, stmt: *Node, scope: *Interp_Scope) { case .print; print := stmt.(*Node_Print); - expr := interp_expr(i, print.expr, scope); + expr := interp_expression(i, print.expr, scope); if expr == null return; if expr.kind == { @@ -111,8 +109,15 @@ interp_statement :: (i: *Interp, stmt: *Node, scope: *Interp_Scope) { basic.print("\n"); + case .assert; + assert := stmt.(*Node_Assert); + expr := interp_expression(i, assert.expr, scope); + basic.assert(expr != null, "runtime assertion failed"); + basic.assert(expr.kind == .bool, "non-boolean expression given to assert"); + basic.assert(expr.b, "runtime assertion failed"); + case; - interp_expr(i, stmt, scope); + interp_expression(i, stmt, scope); // basic.assert(false, "unhandled node kind: %", stmt.kind); // @errors } } @@ -133,7 +138,7 @@ interp_lvalue :: (i: *Interp, expr: *Node, scope: *Interp_Scope) -> *Interp_Valu return null; } -interp_expr :: (i: *Interp, expr: *Node, scope: *Interp_Scope) -> *Interp_Value { +interp_expression :: (i: *Interp, expr: *Node, scope: *Interp_Scope) -> *Interp_Value { if expr.kind == { case .procedure_call; call := expr.(*Node_Procedure_Call); @@ -153,14 +158,14 @@ interp_expr :: (i: *Interp, expr: *Node, scope: *Interp_Scope) -> *Interp_Value proc_scope := make_scope(scope); for proc.arguments { - kv.set(*proc_scope.bindings, it.symbol.str, interp_expr(i, args[it_index], scope)); + kv.set(*proc_scope.bindings, it.symbol.str, interp_expression(i, args[it_index], scope)); } for expr: proc.block.body { if expr.kind == .return_ { ret := expr.(*Node_Return); if ret.values.count != 0 { - result = interp_expr(i, ret.values[0], proc_scope); + result = interp_expression(i, ret.values[0], proc_scope); } break; @@ -187,7 +192,7 @@ interp_expr :: (i: *Interp, expr: *Node, scope: *Interp_Scope) -> *Interp_Value } un := expr.(*Node_Unary); - rhs := interp_expr(i, un.right, scope); + rhs := interp_expression(i, un.right, scope); res := make_interp_value(i, rhs.kind); if un.op.kind == { case .plus; do_unop(#code ifx right < 0 then -right else right); @@ -199,8 +204,8 @@ interp_expr :: (i: *Interp, expr: *Node, scope: *Interp_Scope) -> *Interp_Value case .binary; bin := expr.(*Node_Binary); - lhs := interp_expr(i, bin.left, scope); - rhs := interp_expr(i, bin.right, scope); + lhs := interp_expression(i, bin.left, scope); + rhs := interp_expression(i, bin.right, scope); basic.assert(lhs.kind == rhs.kind, "type mismatch % vs. %", lhs.kind, rhs.kind); // @errors do_binop :: (code: Code) #expand { @@ -220,16 +225,38 @@ interp_expr :: (i: *Interp, expr: *Node, scope: *Interp_Scope) -> *Interp_Value res := make_interp_value(i, lhs.kind); if bin.op.kind == { - case .plus; do_binop(#code left + right); - case .minus; do_binop(#code left - right); - case .star; do_binop(#code left * right); + case .plus; do_binop(#code left + right); + case .minus; do_binop(#code left - right); + case .star; do_binop(#code left * right); + case .f_slash; basic.assert(rhs.i != 0, "divide by zero"); // @errors do_binop(#code left / right); + case .percent; basic.assert(lhs.kind == .int, "cannot use binary operator '%%' on values of type '%'", lhs.kind); res.i = lhs.i % rhs.i; + // @todo: typechecking + case .equal_equal; + res.kind = .bool; + res.b = lhs.i == rhs.i; + case .bang_equal; + res.kind = .bool; + res.b = lhs.i != rhs.i; + case .less; + res.kind = .bool; + res.b = lhs.i < rhs.i; + case .less_equal; + res.kind = .bool; + res.b = lhs.i <= rhs.i; + case .more; + res.kind = .bool; + res.b = lhs.i > rhs.i; + case .more_equal; + res.kind = .bool; + res.b = lhs.i >= rhs.i; + case; basic.assert(false, "unhandled binary operator '%'", bin.op.str); } diff --git a/vm/module.jai b/vm/module.jai index 4992ae4..3ffe284 100644 --- a/vm/module.jai +++ b/vm/module.jai @@ -28,6 +28,7 @@ strings :: #import "String"; // @future var x = 21.0 var y = 22.0 + var z = x + y x = x + 1.0 / 2.0 print x @@ -35,7 +36,12 @@ strings :: #import "String"; // @future x = add(x, div(1.0, 2.0)) print x - print add(x, y) + print x == x + print x == y + print x == z + + assert x == y + assert x != z END); interp: Interp; diff --git a/vm/parser.jai b/vm/parser.jai index 1f0ca5b..293b178 100644 --- a/vm/parser.jai +++ b/vm/parser.jai @@ -31,6 +31,23 @@ Token :: struct { kw_false; kw_print; + kw_assert; + + equal_equal; // == + bang_equal; // != + and_equal; // &= + or_equal; // |= + less_equal; // <= + more_equal; // >= + + plus_equal; // += + minus_equal; // -= + star_equal; // *= + f_slash_equal; // /= + percent_equal; // %= + + and_and_equal; // &&= + or_or_equal; // ||= equal :: #char "="; plus :: #char "+"; @@ -39,8 +56,11 @@ Token :: struct { percent :: #char "%"; bang :: #char "!"; and :: #char "&"; + or :: #char "|"; f_slash :: #char "/"; b_slash :: #char "\\"; + less :: #char "<"; + more :: #char ">"; l_paren :: #char "("; r_paren :: #char ")"; @@ -52,6 +72,7 @@ Token :: struct { dot :: #char "."; colon :: #char ":"; semicolon :: #char ";"; + } } @@ -66,6 +87,7 @@ Node :: struct { stmt_start; print; + assert; return_; assign; stmt_end; @@ -218,6 +240,13 @@ Node_Print :: struct { expr: *Node; } +Node_Assert :: struct { + #as using n: Node; + n.kind = .assert; + + expr: *Node; +} + Node_Procedure_Call :: struct { #as using n: Node; n.kind = .procedure_call; @@ -353,6 +382,18 @@ parse_statement :: (p: *Parser) -> *Node { node.expr = expr; return node; + // assert(cond) + // assert cond + case .kw_assert; + consume_token(p); + + expr := parse_expression(p); + basic.assert(expr != null, "expected expression"); // @errors + + node := make_node(p, Node_Assert); + node.expr = expr; + return node; + // fn symbol(arg0, ..argN) do ... end case .kw_fn; consume_token(p); @@ -412,7 +453,12 @@ parse_simple_statement :: (p: *Parser) -> *Node { t := peek_token(p); if t.kind == { - case .equal; + case .equal; #through; + case .plus_equal; #through; + case .minus_equal; #through; + case .star_equal; #through; + case .f_slash_equal; #through; + case .percent_equal; consume_token(p); src := parse_expression(p); @@ -421,7 +467,29 @@ parse_simple_statement :: (p: *Parser) -> *Node { node := make_node(p, Node_Assign); node.op = t; node.dst = dst; - node.src = src; + + if t.kind == .equal { + node.src = src; + } + // transform these into 'dst = dst op src' + else { + bin := make_node(p, Node_Binary); + bin.left = dst; + bin.right = src; + bin.op = t; + + if t.kind == { + case .plus_equal; bin.op.kind = .plus; + case .minus_equal; bin.op.kind = .minus; + case .star_equal; bin.op.kind = .star; + case .f_slash_equal; bin.op.kind = .f_slash; + case .percent_equal; bin.op.kind = .percent; + } + + node.op.kind = .equal; + node.src = bin; + } + return node; } @@ -516,13 +584,13 @@ parse_expression :: (p: *Parser, min_precedence := 1) -> *Node { case .minus; return 3; - // case .equal_equal; #through; - // case .bang_equal; #through; - // case .less; #through; - // case .less_equal; #through; - // case .more; #through; - // case .more_equal; - // return 2; + case .equal_equal; #through; + case .bang_equal; #through; + case .less; #through; + case .less_equal; #through; + case .more; #through; + case .more_equal; + return 2; } return 0; @@ -802,8 +870,9 @@ consume_token :: (p: *Parser) -> Token { case "true"; t.kind = .kw_true; case "false"; t.kind = .kw_false; - case "print"; t.kind = .kw_print; - case; t.kind = .symbol; + case "print"; t.kind = .kw_print; + case "assert"; t.kind = .kw_assert; + case; t.kind = .symbol; } return t; @@ -821,15 +890,36 @@ consume_token :: (p: *Parser) -> Token { return t; } + with_optional_equal :: (current_kind: Token.Kind, equal_kind: Token.Kind) -> Token #expand { + token := Token.{ + kind = current_kind, + str = string.{ data = `p.source.data + `p.offset, count = 1 }, + }; + + `p.offset += 1; + if `p.offset < `p.source.count && `p.source[`p.offset] == #char "=" { + `p.offset += 1; + + token.kind = equal_kind; + token.str.count = 2; + } + + return token; + } + if c == { - case "+"; #through; - case "-"; #through; - case "*"; #through; - case "/"; #through; - case "="; #through; - case "%"; #through; - case "!"; #through; - case "&"; #through; + case "+"; return with_optional_equal(c.(Token.Kind), .plus_equal); + case "-"; return with_optional_equal(c.(Token.Kind), .minus_equal); + case "*"; return with_optional_equal(c.(Token.Kind), .star_equal); + case "/"; return with_optional_equal(c.(Token.Kind), .f_slash_equal); + case "="; return with_optional_equal(c.(Token.Kind), .equal_equal); + case "%"; return with_optional_equal(c.(Token.Kind), .percent_equal); + case "!"; return with_optional_equal(c.(Token.Kind), .bang_equal); + case "&"; return with_optional_equal(c.(Token.Kind), .and_equal); + case "|"; return with_optional_equal(c.(Token.Kind), .or_equal); + case "<"; return with_optional_equal(c.(Token.Kind), .less_equal); + case ">"; return with_optional_equal(c.(Token.Kind), .more_equal); + case ","; #through; case "."; #through; case ":"; #through;