From a52f1e5e726cdb0365cc56d227419b5f58a257f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20Hru=C5=A1ka?= Date: Mon, 21 Sep 2020 23:28:48 +0200 Subject: [PATCH] fix fallthrough in cond branches --- csn_asm/src/instr/mod.rs | 17 +++++++++++++---- csn_asm/src/lib.rs | 9 +++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/csn_asm/src/instr/mod.rs b/csn_asm/src/instr/mod.rs index 2fe0e87..7489cdc 100644 --- a/csn_asm/src/instr/mod.rs +++ b/csn_asm/src/instr/mod.rs @@ -32,21 +32,30 @@ impl Flatten for Instr { if let Some(branches) = self.branches { let labels = HashMap::::new(); - let _branch_count = branches.len(); - for (_cnt, (cond, branch)) in branches.into_iter().enumerate() { + let branch_count = branches.len(); + let end_lbl = Label::unique(label_num); + for (cnt, (cond, branch)) in branches.into_iter().enumerate() { if labels.contains_key(&cond) { return Err(Error::Asm(AsmError::ConditionalAlreadyUsed(cond))); } - let next_lbl = Label::unique(label_num); + let next_lbl = if cnt == branch_count - 1 { + end_lbl.clone() + } else { + Label::unique(label_num) + }; ops.push(Op::JumpIf(!cond, next_lbl.clone())); for branch_instr in branch { ops.extend(branch_instr.flatten(label_num)?); } - ops.push(Op::Label(next_lbl)); + if cnt != branch_count - 1 { + ops.push(Op::Jump(end_lbl.clone())); + ops.push(Op::Label(next_lbl)); + } } + ops.push(Op::Label(end_lbl)); } Ok(ops) diff --git a/csn_asm/src/lib.rs b/csn_asm/src/lib.rs index 2a69d10..6959932 100644 --- a/csn_asm/src/lib.rs +++ b/csn_asm/src/lib.rs @@ -233,7 +233,7 @@ mod tests { Rd::new(SrcDisp::Register(Register::Gen(0))), Rd::new(SrcDisp::Register(Register::Gen(1))), ), - Op::JumpIf(Cond::NotEqual, Label::Numbered(0)), + Op::JumpIf(Cond::NotEqual, Label::Numbered(1)), Op::Mov( Wr::new(DstDisp::Register(Register::Gen(0))), Rd::new(SrcDisp::Register(Register::Gen(0))), @@ -242,8 +242,9 @@ mod tests { Wr::new(DstDisp::Register(Register::Gen(1))), Rd::new(SrcDisp::Register(Register::Gen(2))), ), - Op::Label(Label::Numbered(0)), - Op::JumpIf(Cond::LessOrEqual, Label::Numbered(1)), + Op::Jump(Label::Numbered(0)), + Op::Label(Label::Numbered(1)), + Op::JumpIf(Cond::LessOrEqual, Label::Numbered(0)), Op::Mov( Wr::new(DstDisp::Register(Register::Gen(0))), Rd::new(SrcDisp::Register(Register::Gen(0))), @@ -252,7 +253,7 @@ mod tests { Wr::new(DstDisp::Register(Register::Gen(1))), Rd::new(SrcDisp::Register(Register::Gen(1))), ), - Op::Label(Label::Numbered(1)), + Op::Label(Label::Numbered(0)), ], parsed); } }