From ee16774fe7cf72e5eba62f90bdb99ccbb77eba7e Mon Sep 17 00:00:00 2001 From: Eduard Urbach Date: Sun, 7 Jul 2024 21:55:32 +0200 Subject: [PATCH] Improved branch compilation --- examples/hello/hello.q | 2 +- src/build/asm/Assembler.go | 14 ++++- src/build/asm/Mnemonic.go | 12 ++--- src/build/core/Compare.go | 47 +++++++++++++++++ src/build/core/CompileCondition.go | 82 ++++++++++++++++++++++++++++++ src/build/core/CompileIf.go | 43 ++-------------- src/build/core/Evaluate.go | 41 --------------- src/build/core/state.go | 10 ++-- tests/programs/branch.q | 22 ++------ 9 files changed, 160 insertions(+), 113 deletions(-) create mode 100644 src/build/core/Compare.go create mode 100644 src/build/core/CompileCondition.go delete mode 100644 src/build/core/Evaluate.go diff --git a/examples/hello/hello.q b/examples/hello/hello.q index f6e8e26..7cdf32f 100644 --- a/examples/hello/hello.q +++ b/examples/hello/hello.q @@ -1,7 +1,7 @@ main() { x := f(1) + f(2) + f(3) - if x != f(8) { + if x != f(8) || x != 9 || x == 6 { exit(42) } diff --git a/src/build/asm/Assembler.go b/src/build/asm/Assembler.go index caf600a..18c536d 100644 --- a/src/build/asm/Assembler.go +++ b/src/build/asm/Assembler.go @@ -57,7 +57,12 @@ func (a Assembler) Finalize() ([]byte, []byte) { Position: Address(len(code) - size), Size: uint8(size), Resolve: func() Address { - destination := labels[label.Name] + destination, exists := labels[label.Name] + + if !exists { + panic("unknown call label") + } + distance := destination - nextInstructionAddress return Address(distance) }, @@ -100,7 +105,12 @@ func (a Assembler) Finalize() ([]byte, []byte) { Position: Address(len(code) - size), Size: uint8(size), Resolve: func() Address { - destination := labels[label.Name] + destination, exists := labels[label.Name] + + if !exists { + panic("unknown jump label") + } + distance := destination - nextInstructionAddress return Address(distance) }, diff --git a/src/build/asm/Mnemonic.go b/src/build/asm/Mnemonic.go index 06bd87a..d25bc99 100644 --- a/src/build/asm/Mnemonic.go +++ b/src/build/asm/Mnemonic.go @@ -42,17 +42,17 @@ func (m Mnemonic) String() string { case JUMP: return "jump" case JE: - return "jump ==" + return "jump if ==" case JNE: - return "jump !=" + return "jump if !=" case JL: - return "jump <" + return "jump if <" case JG: - return "jump >" + return "jump if >" case JLE: - return "jump <=" + return "jump if <=" case JGE: - return "jump >=" + return "jump if >=" case LABEL: return "label" case MOVE: diff --git a/src/build/core/Compare.go b/src/build/core/Compare.go new file mode 100644 index 0000000..955f67e --- /dev/null +++ b/src/build/core/Compare.go @@ -0,0 +1,47 @@ +package core + +import ( + "git.akyoto.dev/cli/q/src/build/ast" + "git.akyoto.dev/cli/q/src/build/errors" + "git.akyoto.dev/cli/q/src/build/expression" + "git.akyoto.dev/cli/q/src/build/token" +) + +// Compare evaluates a boolean expression. +func (f *Function) Compare(comparison *expression.Expression) error { + left := comparison.Children[0] + right := comparison.Children[1] + + if left.IsLeaf() && left.Token.Kind == token.Identifier { + name := left.Token.Text() + variable, exists := f.variables[name] + + if !exists { + return errors.New(&errors.UnknownIdentifier{Name: name}, f.File, left.Token.Position) + } + + defer f.useVariable(variable) + return f.Execute(comparison.Token, variable.Register, right) + } + + if ast.IsFunctionCall(left) && right.IsLeaf() { + err := f.CompileCall(left) + + if err != nil { + return err + } + + return f.Execute(comparison.Token, f.cpu.Output[0], right) + } + + tmp := f.cpu.MustFindFree(f.cpu.General) + err := f.ExpressionToRegister(left, tmp) + + if err != nil { + return err + } + + f.cpu.Use(tmp) + defer f.cpu.Free(tmp) + return f.Execute(comparison.Token, tmp, right) +} diff --git a/src/build/core/CompileCondition.go b/src/build/core/CompileCondition.go new file mode 100644 index 0000000..3a50df7 --- /dev/null +++ b/src/build/core/CompileCondition.go @@ -0,0 +1,82 @@ +package core + +import ( + "git.akyoto.dev/cli/q/src/build/asm" + "git.akyoto.dev/cli/q/src/build/expression" +) + +// CompileCondition inserts code to jump to the start label or end label depending on the truth of the condition. +func (f *Function) CompileCondition(condition *expression.Expression, startLabel string, endLabel string) error { + switch condition.Token.Text() { + case "||": + left := condition.Children[0] + err := f.CompileCondition(left, startLabel, endLabel) + + if err != nil { + return err + } + + f.JumpIfTrue(left.Token.Text(), startLabel) + + right := condition.Children[1] + err = f.CompileCondition(right, startLabel, endLabel) + + if err != nil { + return err + } + + if condition.Parent == nil { + f.JumpIfFalse(right.Token.Text(), endLabel) + } else { + f.JumpIfTrue(right.Token.Text(), startLabel) + } + + return nil + case "&&": + return nil + default: + err := f.Compare(condition) + + if condition.Parent == nil { + f.JumpIfFalse(condition.Token.Text(), endLabel) + } + + return err + } +} + +// JumpIfFalse jumps to the label if the previous comparison was false. +func (f *Function) JumpIfFalse(operator string, label string) { + switch operator { + case "==": + f.assembler.Label(asm.JNE, label) + case "!=": + f.assembler.Label(asm.JE, label) + case ">": + f.assembler.Label(asm.JLE, label) + case "<": + f.assembler.Label(asm.JGE, label) + case ">=": + f.assembler.Label(asm.JL, label) + case "<=": + f.assembler.Label(asm.JG, label) + } +} + +// JumpIfTrue jumps to the label if the previous comparison was true. +func (f *Function) JumpIfTrue(operator string, label string) { + switch operator { + case "==": + f.assembler.Label(asm.JE, label) + case "!=": + f.assembler.Label(asm.JNE, label) + case ">": + f.assembler.Label(asm.JG, label) + case "<": + f.assembler.Label(asm.JL, label) + case ">=": + f.assembler.Label(asm.JGE, label) + case "<=": + f.assembler.Label(asm.JLE, label) + } +} diff --git a/src/build/core/CompileIf.go b/src/build/core/CompileIf.go index 0b9b584..c73faba 100644 --- a/src/build/core/CompileIf.go +++ b/src/build/core/CompileIf.go @@ -9,51 +9,16 @@ import ( // CompileIf compiles a branch instruction. func (f *Function) CompileIf(branch *ast.If) error { - err := f.Evaluate(branch.Condition) + startLabel := fmt.Sprintf("%s_if_start_%d", f.Name, f.count.branch) + endLabel := fmt.Sprintf("%s_if_end_%d", f.Name, f.count.branch) + err := f.CompileCondition(branch.Condition, startLabel, endLabel) if err != nil { return err } - endLabel := fmt.Sprintf("%s_end_if_%d", f.Name, f.count.branch) - f.JumpIfFalse(branch.Condition.Token.Text(), endLabel) + f.assembler.Label(asm.LABEL, startLabel) defer f.assembler.Label(asm.LABEL, endLabel) f.count.branch++ return f.CompileAST(branch.Body) } - -// JumpIfFalse jumps to the label if the previous comparison was false. -func (f *Function) JumpIfFalse(operator string, label string) { - switch operator { - case "==": - f.assembler.Label(asm.JNE, label) - case "!=": - f.assembler.Label(asm.JE, label) - case ">": - f.assembler.Label(asm.JLE, label) - case "<": - f.assembler.Label(asm.JGE, label) - case ">=": - f.assembler.Label(asm.JL, label) - case "<=": - f.assembler.Label(asm.JG, label) - } -} - -// JumpIfTrue jumps to the label if the previous comparison was true. -func (f *Function) JumpIfTrue(operator string, label string) { - switch operator { - case "==": - f.assembler.Label(asm.JE, label) - case "!=": - f.assembler.Label(asm.JNE, label) - case ">": - f.assembler.Label(asm.JG, label) - case "<": - f.assembler.Label(asm.JL, label) - case ">=": - f.assembler.Label(asm.JGE, label) - case "<=": - f.assembler.Label(asm.JLE, label) - } -} diff --git a/src/build/core/Evaluate.go b/src/build/core/Evaluate.go deleted file mode 100644 index 5c3a60c..0000000 --- a/src/build/core/Evaluate.go +++ /dev/null @@ -1,41 +0,0 @@ -package core - -import ( - "git.akyoto.dev/cli/q/src/build/ast" - "git.akyoto.dev/cli/q/src/build/expression" - "git.akyoto.dev/cli/q/src/build/token" -) - -// Evaluate evaluates an expression. -func (f *Function) Evaluate(value *expression.Expression) error { - left := value.Children[0] - right := value.Children[1] - - if left.IsLeaf() && left.Token.Kind == token.Identifier { - variable := f.variables[left.Token.Text()] - register := variable.Register - defer f.useVariable(variable) - return f.Execute(value.Token, register, right) - } - - if ast.IsFunctionCall(left) && right.IsLeaf() { - err := f.CompileCall(left) - - if err != nil { - return err - } - - return f.Execute(value.Token, f.cpu.Output[0], right) - } - - tmp := f.cpu.MustFindFree(f.cpu.General) - err := f.ExpressionToRegister(left, tmp) - - if err != nil { - return err - } - - f.cpu.Use(tmp) - defer f.cpu.Free(tmp) - return f.Execute(value.Token, tmp, right) -} diff --git a/src/build/core/state.go b/src/build/core/state.go index 8ae6838..3127eda 100644 --- a/src/build/core/state.go +++ b/src/build/core/state.go @@ -27,20 +27,20 @@ type counter struct { // PrintInstructions shows the assembly instructions. func (s *state) PrintInstructions() { - ansi.Dim.Println("╭────────────────────────────────────────────────╮") + ansi.Dim.Println("╭────────────────────────────────────────────────────╮") for _, x := range s.assembler.Instructions { ansi.Dim.Print("│ ") switch x.Mnemonic { case asm.LABEL: - ansi.Yellow.Printf("%-46s", x.Data.String()+":") + ansi.Yellow.Printf("%-50s", x.Data.String()+":") case asm.COMMENT: - ansi.Dim.Printf("%-46s", x.Data.String()) + ansi.Dim.Printf("%-50s", x.Data.String()) default: - ansi.Green.Printf("%-8s", x.Mnemonic.String()) + ansi.Green.Printf("%-12s", x.Mnemonic.String()) if x.Data != nil { fmt.Printf("%-38s", x.Data.String()) @@ -52,5 +52,5 @@ func (s *state) PrintInstructions() { ansi.Dim.Print(" │\n") } - ansi.Dim.Println("╰────────────────────────────────────────────────╯") + ansi.Dim.Println("╰────────────────────────────────────────────────────╯") } diff --git a/tests/programs/branch.q b/tests/programs/branch.q index 547fcfc..5a2830e 100644 --- a/tests/programs/branch.q +++ b/tests/programs/branch.q @@ -25,31 +25,15 @@ main() { exit(1) } - if x >= 1 { + if x >= 1 || 1 <= x { exit(1) } - if 1 <= x { + if x == inc(x) || x == dec(x) { exit(1) } - if x == inc(x) { - exit(1) - } - - if x == dec(x) { - exit(1) - } - - if inc(0) == x { - exit(1) - } - - if dec(0) == x { - exit(1) - } - - if inc(x) == dec(x) { + if inc(0) == x || dec(0) == x || inc(x) == dec(x) { exit(1) }