Skip to main content

javm_recompiler_x86/
predecode.rs

1//! Pre-decode PVM bytecode into a flat instruction stream for fast codegen.
2//!
3//! Replaces the byte-by-byte bitmask scan in the codegen loop with a single
4//! upfront decode pass. The codegen then iterates a `&[PreDecodedInst]` slice,
5//! eliminating redundant `compute_skip()` and `decode_args()` calls.
6
7use alloc::vec;
8use alloc::vec::Vec;
9use javm_exec::args::{self, Args};
10use javm_exec::instruction::Opcode;
11pub use javm_exec::predecoded::PreDecodedInst;
12
13/// Pre-decode all instructions from raw code+bitmask into a flat array.
14///
15/// Three passes:
16/// 1. Decode each instruction (opcode, args, pc, next_pc)
17/// 2. Identify gas block boundaries (branch targets, post-terminators, jump table)
18/// 3. Compute gas cost for each gas block start
19pub fn predecode(code: &[u8], bitmask: &[u8], jump_table: &[u32]) -> Vec<PreDecodedInst> {
20    // --- Pass 1: Decode instructions ---
21    let estimated_count = bitmask.iter().filter(|&&b| b == 1).count();
22    let mut instrs: Vec<PreDecodedInst> = Vec::with_capacity(estimated_count);
23
24    let mut pc: usize = 0;
25    while pc < code.len() {
26        if pc < bitmask.len() && bitmask[pc] != 1 {
27            pc += 1;
28            continue;
29        }
30
31        let opcode = Opcode::from_byte(code[pc]).unwrap_or(Opcode::Trap);
32        let skip = compute_skip(pc, bitmask);
33        let next_pc = pc + 1 + skip;
34        let category = opcode.category();
35        let args = args::decode_args(code, pc, skip, category);
36
37        // Extract flat register fields for fast gas cost lookup
38        let (ra, rb, rd) = match args {
39            Args::ThreeReg { ra, rb, rd } => (ra as u8, rb as u8, rd as u8),
40            Args::TwoReg { rd: d, ra: a } => (a as u8, 0xFF, d as u8),
41            Args::TwoRegImm { ra, rb, .. }
42            | Args::TwoRegOffset { ra, rb, .. }
43            | Args::TwoRegTwoImm { ra, rb, .. } => (ra as u8, rb as u8, 0xFF),
44            Args::RegImm { ra, .. }
45            | Args::RegExtImm { ra, .. }
46            | Args::RegTwoImm { ra, .. }
47            | Args::RegImmOffset { ra, .. } => (ra as u8, 0xFF, 0xFF),
48            _ => (0xFF, 0xFF, 0xFF),
49        };
50        instrs.push(PreDecodedInst {
51            opcode,
52            args,
53            pc: pc as u32,
54            next_pc: next_pc as u32,
55            gas_cost: 0,
56            is_gas_block_start: false,
57            ra,
58            rb,
59            rd,
60        });
61
62        pc = next_pc;
63    }
64
65    // --- Pass 2: Mark gas block starts ---
66    // Build PC → instruction index map for O(1) target lookup.
67    let mut pc_to_idx: Vec<u32> = vec![u32::MAX; code.len() + 1];
68    for (i, instr) in instrs.iter().enumerate() {
69        pc_to_idx[instr.pc as usize] = i as u32;
70    }
71
72    let mut is_gas_start = vec![false; instrs.len()];
73
74    // PC=0 always starts a gas block
75    if !instrs.is_empty() {
76        is_gas_start[0] = true;
77    }
78
79    // Jump table entries
80    for &target in jump_table {
81        let t = target as usize;
82        if t < pc_to_idx.len() && pc_to_idx[t] != u32::MAX {
83            is_gas_start[pc_to_idx[t] as usize] = true;
84        }
85    }
86
87    // Branch/jump targets and post-terminator fallthroughs
88    for i in 0..instrs.len() {
89        let instr = &instrs[i];
90
91        // Extract branch/jump target from decoded args
92        let target_pc = match instr.args {
93            Args::Offset { offset } => Some(offset as usize),
94            Args::RegImmOffset { offset, .. } => Some(offset as usize),
95            Args::TwoRegOffset { offset, .. } => Some(offset as usize),
96            _ => None,
97        };
98        if let Some(t) = target_pc
99            && t < pc_to_idx.len()
100            && pc_to_idx[t] != u32::MAX
101        {
102            is_gas_start[pc_to_idx[t] as usize] = true;
103        }
104
105        // Fallthrough after terminator
106        if instr.opcode.is_terminator() && i + 1 < instrs.len() {
107            is_gas_start[i + 1] = true;
108        }
109
110        // Ecalli: next instruction is a re-entry point
111        if matches!(instr.opcode, Opcode::Ecalli) && i + 1 < instrs.len() {
112            is_gas_start[i + 1] = true;
113        }
114    }
115
116    // --- Mark gas block start flags on instructions ---
117    // Gas costs are computed inline during codegen (single-pass).
118    for i in 0..instrs.len() {
119        if is_gas_start[i] {
120            instrs[i].is_gas_block_start = true;
121        }
122    }
123
124    instrs
125}
126
127/// Compute gas block start bitmap from raw code+bitmask (no full Args decoding).
128/// Returns `Vec<bool>` indexed by PVM byte offset. True = this PC starts a gas block.
129pub fn compute_gas_blocks(code: &[u8], bitmask: &[u8], jump_table: &[u32]) -> Vec<bool> {
130    let mut gas_starts = vec![false; code.len()];
131
132    // PC=0 always starts a gas block
133    if !code.is_empty() {
134        gas_starts[0] = true;
135    }
136
137    // Jump table entries
138    for &target in jump_table {
139        let t = target as usize;
140        if t < code.len() && t < bitmask.len() && bitmask[t] == 1 {
141            gas_starts[t] = true;
142        }
143    }
144
145    // Scan instructions for branch targets and terminators
146    let mut pc: usize = 0;
147    while pc < code.len() {
148        if pc < bitmask.len() && bitmask[pc] != 1 {
149            pc += 1;
150            continue;
151        }
152
153        let opcode = Opcode::from_byte(code[pc]);
154        let skip = compute_skip(pc, bitmask);
155        let next_pc = pc + 1 + skip;
156
157        if let Some(op) = opcode {
158            // Extract branch/jump targets from raw bytes
159            let category = op.category();
160            let target_pc = match category {
161                javm_exec::instruction::InstructionCategory::OneOffset => {
162                    // Jump: offset is signed, relative to pc
163                    let raw = args::decode_args(code, pc, skip, category);
164                    match raw {
165                        Args::Offset { offset } => Some(offset as usize),
166                        _ => None,
167                    }
168                }
169                javm_exec::instruction::InstructionCategory::OneRegImmOffset => {
170                    let raw = args::decode_args(code, pc, skip, category);
171                    match raw {
172                        Args::RegImmOffset { offset, .. } => Some(offset as usize),
173                        _ => None,
174                    }
175                }
176                javm_exec::instruction::InstructionCategory::TwoRegOneOffset => {
177                    let raw = args::decode_args(code, pc, skip, category);
178                    match raw {
179                        Args::TwoRegOffset { offset, .. } => Some(offset as usize),
180                        _ => None,
181                    }
182                }
183                _ => None,
184            };
185            if let Some(t) = target_pc
186                && t < code.len()
187                && t < bitmask.len()
188                && bitmask[t] == 1
189            {
190                gas_starts[t] = true;
191            }
192
193            // Post-terminator fallthrough
194            if op.is_terminator() && next_pc < code.len() {
195                gas_starts[next_pc] = true;
196            }
197
198            // Post-ecalli
199            if matches!(op, Opcode::Ecalli) && next_pc < code.len() {
200                gas_starts[next_pc] = true;
201            }
202        }
203
204        pc = next_pc;
205    }
206
207    gas_starts
208}
209
/// Compute skip(pc): how many bytes follow the opcode at `pc` before the next
/// instruction start.
///
/// A position counts as an instruction start when its bitmask entry is exactly
/// 1 or when it lies past the end of `bitmask`. The result is capped at 24.
fn compute_skip(pc: usize, bitmask: &[u8]) -> usize {
    const MAX_SKIP: usize = 24;

    let first_candidate = pc + 1;
    (0..MAX_SKIP)
        .find(|&distance| {
            bitmask
                .get(first_candidate + distance)
                .map_or(true, |&bit| bit == 1)
        })
        .unwrap_or(MAX_SKIP)
}
221
#[cfg(test)]
mod tests {
    use super::*;

    // === predecode ===

    #[test]
    fn test_predecode_empty() {
        let instrs = predecode(&[], &[], &[]);
        assert!(instrs.is_empty());
    }

    #[test]
    fn test_predecode_single_trap() {
        // Trap = opcode 0, no arguments
        let code = vec![0u8];
        let bitmask = vec![1u8];
        let instrs = predecode(&code, &bitmask, &[]);

        assert_eq!(instrs.len(), 1);
        assert_eq!(instrs[0].opcode, Opcode::Trap);
        assert_eq!(instrs[0].pc, 0);
        assert_eq!(instrs[0].next_pc, 1);
        assert!(
            instrs[0].is_gas_block_start,
            "PC=0 is always a gas block start"
        );
    }

    #[test]
    fn test_predecode_sequence() {
        // load_imm(51) r0, 42; ecalli(10) 0
        let code = vec![51, 0, 42, 10, 0];
        let bitmask = vec![1, 0, 0, 1, 0];

        let instrs = predecode(&code, &bitmask, &[]);
        assert_eq!(instrs.len(), 2);
        assert_eq!(instrs[0].opcode, Opcode::LoadImm);
        assert_eq!(instrs[0].pc, 0);
        assert_eq!(instrs[0].next_pc, 3);
        assert_eq!(instrs[1].opcode, Opcode::Ecalli);
        assert_eq!(instrs[1].pc, 3);
    }

    #[test]
    fn test_predecode_gas_block_after_terminator() {
        // trap(0); load_imm(51) r0, 1
        // Trap is a terminator, so load_imm starts a new gas block
        let code = vec![0, 51, 0, 1];
        let bitmask = vec![1, 1, 0, 0];

        let instrs = predecode(&code, &bitmask, &[]);
        assert_eq!(instrs.len(), 2);
        assert!(instrs[0].is_gas_block_start, "PC=0 always");
        assert!(
            instrs[1].is_gas_block_start,
            "post-terminator should be gas block start"
        );
    }

    #[test]
    fn test_predecode_gas_block_after_ecalli() {
        // ecalli(10) 0; load_imm(51) r0, 1
        let code = vec![10, 0, 51, 0, 1];
        let bitmask = vec![1, 0, 1, 0, 0];

        let instrs = predecode(&code, &bitmask, &[]);
        assert_eq!(instrs.len(), 2);
        assert!(
            instrs[1].is_gas_block_start,
            "post-ecalli should be gas block start"
        );
    }

    #[test]
    fn test_predecode_branch_target_is_gas_start() {
        // jump(40) with a zero offset at PC=0 (a self-loop onto PC=0),
        // followed by load_imm(51) r0, 1.
        // NOTE: both assertions below also hold for reasons other than the
        // branch target (PC=0 is always a start, and a jump is a terminator),
        // so this test cannot isolate branch-target marking by itself.
        let offset: i32 = 0;
        let mut code = vec![40u8];
        code.extend_from_slice(&offset.to_le_bytes());
        code.extend_from_slice(&[51, 0, 1]); // load_imm after the jump
        let bitmask = vec![1, 0, 0, 0, 0, 1, 0, 0];

        let instrs = predecode(&code, &bitmask, &[]);
        assert_eq!(instrs.len(), 2);
        // PC=0 is both the first instruction AND the jump's target
        assert!(instrs[0].is_gas_block_start);
        // Post-terminator (jump is a terminator)
        assert!(instrs[1].is_gas_block_start);
    }

    #[test]
    fn test_predecode_jump_table_target_is_gas_start() {
        // Two instructions: load_imm at PC=0, load_imm at PC=3
        // Jump table says PC=3 is a target
        let code = vec![51, 0, 1, 51, 1, 2];
        let bitmask = vec![1, 0, 0, 1, 0, 0];

        let instrs = predecode(&code, &bitmask, &[3]);
        assert_eq!(instrs.len(), 2);
        assert!(instrs[0].is_gas_block_start, "PC=0 always");
        assert!(
            instrs[1].is_gas_block_start,
            "jump table target should be gas block start"
        );
    }

    #[test]
    fn test_predecode_jump_table_ignores_invalid_entries() {
        // Entry 1 points into the middle of the first load_imm and entry 100
        // is past the end of the code; neither may mark anything or panic.
        let code = vec![51, 0, 1, 51, 1, 2];
        let bitmask = vec![1, 0, 0, 1, 0, 0];

        let instrs = predecode(&code, &bitmask, &[1, 100]);
        assert_eq!(instrs.len(), 2);
        assert!(instrs[0].is_gas_block_start, "PC=0 always");
        assert!(
            !instrs[1].is_gas_block_start,
            "invalid jump table entries must not create gas block starts"
        );
    }

    #[test]
    fn test_predecode_non_target_not_gas_start() {
        // Two consecutive load_imm instructions, no branches
        let code = vec![51, 0, 1, 51, 1, 2];
        let bitmask = vec![1, 0, 0, 1, 0, 0];

        let instrs = predecode(&code, &bitmask, &[]);
        assert_eq!(instrs.len(), 2);
        assert!(instrs[0].is_gas_block_start, "PC=0 always");
        assert!(
            !instrs[1].is_gas_block_start,
            "not a target, not post-terminator"
        );
    }

    // === compute_gas_blocks ===

    #[test]
    fn test_gas_blocks_empty() {
        let blocks = compute_gas_blocks(&[], &[], &[]);
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_gas_blocks_pc0_always_start() {
        let code = vec![51, 0, 1];
        let bitmask = vec![1, 0, 0];
        let blocks = compute_gas_blocks(&code, &bitmask, &[]);
        assert!(blocks[0], "PC=0 should always be a gas block start");
    }

    #[test]
    fn test_gas_blocks_jump_table() {
        let code = vec![51, 0, 1, 51, 1, 2];
        let bitmask = vec![1, 0, 0, 1, 0, 0];
        let blocks = compute_gas_blocks(&code, &bitmask, &[3]);
        assert!(blocks[0]);
        assert!(blocks[3], "jump table target should be gas block start");
    }

    #[test]
    fn test_gas_blocks_jump_table_ignores_invalid_entries() {
        // Entry 1 is a continuation byte, entry 100 is out of range.
        let code = vec![51, 0, 1, 51, 1, 2];
        let bitmask = vec![1, 0, 0, 1, 0, 0];
        let blocks = compute_gas_blocks(&code, &bitmask, &[1, 100]);
        assert_eq!(blocks.len(), code.len());
        assert!(blocks[0]);
        assert!(!blocks[1], "continuation byte must not be a gas block start");
        assert!(!blocks[3], "untargeted instruction must not be a start");
    }

    #[test]
    fn test_gas_blocks_post_terminator() {
        let code = vec![0, 51, 0, 1]; // trap; load_imm
        let bitmask = vec![1, 1, 0, 0];
        let blocks = compute_gas_blocks(&code, &bitmask, &[]);
        assert!(blocks[0]);
        assert!(blocks[1], "post-terminator should be gas block start");
    }

    #[test]
    fn test_gas_blocks_post_ecalli() {
        // ecalli(10) 0; load_imm(51) r0, 1
        let code = vec![10, 0, 51, 0, 1];
        let bitmask = vec![1, 0, 1, 0, 0];
        let blocks = compute_gas_blocks(&code, &bitmask, &[]);
        assert!(blocks[0]);
        assert!(blocks[2], "post-ecalli should be gas block start");
    }

    #[test]
    fn test_gas_blocks_non_target_not_start() {
        // Two consecutive load_imm instructions, no branches
        let code = vec![51, 0, 1, 51, 1, 2];
        let bitmask = vec![1, 0, 0, 1, 0, 0];
        let blocks = compute_gas_blocks(&code, &bitmask, &[]);
        assert!(blocks[0], "PC=0 always");
        assert!(!blocks[3], "not a target, not post-terminator");
    }

    // === compute_skip ===

    #[test]
    fn test_skip_single_byte() {
        // Next byte is an instruction start
        assert_eq!(compute_skip(0, &[1, 1]), 0);
    }

    #[test]
    fn test_skip_multi_byte() {
        // Two continuation bytes before next instruction start
        assert_eq!(compute_skip(0, &[1, 0, 0, 1]), 2);
    }

    #[test]
    fn test_skip_at_end() {
        // Past end of bitmask → treated as instruction start
        assert_eq!(compute_skip(0, &[1]), 0);
    }

    #[test]
    fn test_skip_capped_at_24() {
        // No instruction start within reach → result is capped at 24
        let mut bitmask = vec![0u8; 40];
        bitmask[0] = 1;
        assert_eq!(compute_skip(0, &bitmask), 24);
    }
}