Skip to main content

javm_transpiler/
lib.rs

1//! RISC-V ELF to JAM PVM transpiler.
2//!
3//! Converts RISC-V rv64em ELF binaries into PVM program blobs
4//! suitable for execution by the JAR PVM (Appendix A).
5//!
6//! Also provides utilities to hand-assemble PVM programs directly.
7
8pub mod assembler;
9pub mod emitter;
10pub mod layout;
11pub mod linker;
12pub mod program;
13pub mod riscv;
14
15use thiserror::Error;
16
/// Decode a little-endian, sign-extended immediate from PVM bytecode.
///
/// Up to `min(lx, 8)` bytes are taken from `code[start..]`; positions that fall
/// past the end of `code` are read as zero. When `lx` is between 1 and 8 and the
/// most significant byte taken has its top bit set, the remaining high bytes are
/// filled with `0xFF` so the value is sign-extended to `i64`. An `lx` of zero
/// yields `0`.
///
/// Peephole passes use this to read `load_imm` values and memory offsets.
fn parse_signed_imm(code: &[u8], start: usize, lx: usize) -> i64 {
    let width = lx.min(8);
    let mut raw = [0u8; 8];
    for (k, slot) in raw.iter_mut().enumerate().take(width) {
        if let Some(&byte) = code.get(start + k) {
            *slot = byte;
        }
    }
    // Only widths 1..=8 carry a sign bit; wider requests already fill all 8 bytes.
    let negative = (1..=8).contains(&lx) && raw[width - 1] & 0x80 != 0;
    if negative {
        raw[width..].fill(0xFF);
    }
    i64::from_le_bytes(raw)
}
35
/// Errors reported by the transpiler.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute
/// (generated by `thiserror`).
#[derive(Error, Debug)]
pub enum TranspileError {
    /// The input could not be parsed as an ELF file; the payload is the
    /// description inserted into the message.
    #[error("ELF parse error: {0}")]
    ElfParse(String),
    /// A RISC-V instruction at `offset` is not supported; `detail` describes it.
    #[error("unsupported RISC-V instruction at offset {offset:#x}: {detail}")]
    UnsupportedInstruction { offset: usize, detail: String },
    /// A relocation is not supported; the payload describes the relocation.
    #[error("unsupported relocation: {0}")]
    UnsupportedRelocation(String),
    /// The given RISC-V register number has no PVM equivalent.
    #[error("register mapping error: RISC-V register {0} has no PVM equivalent")]
    RegisterMapping(u8),
    /// The code is too large; the payload is its size in bytes.
    #[error("code too large: {0} bytes")]
    CodeTooLarge(usize),
    /// A section is invalid; the payload describes the problem.
    #[error("invalid section: {0}")]
    InvalidSection(String),
}
51
/// Link a RISC-V rv64em ELF binary into a v3 chain
/// [`javm_cap::image::Image`]. The Image carries the PVM CODE sub-blob
/// in its `code` field, populated endpoints (from any
/// `.subsoil.endpoints` ELF section, or a single PC-0 fallback for
/// `subsoil::entry!`-based guests), and standard kernel-ABI slot
/// conventions.
///
/// This is a thin crate-root entry point: it forwards `elf_data` unchanged to
/// [`linker::link_elf`] and returns that function's result as-is.
///
/// # Errors
///
/// Returns whatever [`TranspileError`] `linker::link_elf` reports.
pub fn link_elf(elf_data: &[u8]) -> Result<javm_cap::image::Image, TranspileError> {
    // All of the work happens in the `linker` module.
    linker::link_elf(elf_data)
}
61
/// Skip distance of the instruction at `pc`: the number of continuation bytes
/// that follow its opcode byte according to `bitmask`.
///
/// Looks at the 25 positions after `pc` and returns the index `j` (0..=24) of the
/// first one that is either past the end of `bitmask` or marked `1` (the next
/// instruction start). If none of those 25 positions qualifies, returns `0`.
fn skip_for(bitmask: &[u8], pc: usize) -> usize {
    (0..25usize)
        .find(|&j| matches!(bitmask.get(pc + 1 + j), None | Some(1)))
        .unwrap_or(0)
}
72
73/// Collect all branch targets and jump table entries from PVM code.
74///
75/// Returns a set of byte offsets that are branch/jump destinations.
76/// Used by peephole passes to avoid fusing across branch boundaries.
77fn collect_branch_targets(
78    code: &[u8],
79    bitmask: &[u8],
80    jump_table: &[u32],
81) -> std::collections::HashSet<usize> {
82    let len = code.len();
83    let mut targets = std::collections::HashSet::new();
84    let mut i = 0;
85    while i < len {
86        if i >= bitmask.len() || bitmask[i] != 1 {
87            i += 1;
88            continue;
89        }
90        let op = code[i];
91        let s = skip_for(bitmask, i);
92        // jump (40): 4-byte offset
93        if op == 40 && i + 5 <= len {
94            let off = i32::from_le_bytes([code[i + 1], code[i + 2], code[i + 3], code[i + 4]]);
95            let t = (i as i64 + off as i64) as usize;
96            if t < len {
97                targets.insert(t);
98            }
99        }
100        // branch_eq..branch_ge_u (170-175): 4-byte offset at +2
101        if (170..=175).contains(&op) && i + 6 <= len {
102            let off = i32::from_le_bytes([code[i + 2], code[i + 3], code[i + 4], code[i + 5]]);
103            let t = (i as i64 + off as i64) as usize;
104            if t < len {
105                targets.insert(t);
106            }
107        }
108        // branch_*_imm (80-90): variable-length offset
109        if (80..=90).contains(&op) && i + 2 <= len {
110            let reg_byte = code[i + 1];
111            let lx = ((reg_byte as usize / 16) % 8).min(4);
112            let ly = if s > lx + 1 { (s - lx - 1).min(4) } else { 0 };
113            let off_start = i + 2 + lx;
114            if ly > 0 && off_start + ly <= len {
115                let mut buf = [0u8; 4];
116                buf[..ly].copy_from_slice(&code[off_start..off_start + ly]);
117                if ly < 4 && buf[ly - 1] & 0x80 != 0 {
118                    for b in &mut buf[ly..4] {
119                        *b = 0xFF;
120                    }
121                }
122                let off = i32::from_le_bytes(buf);
123                let t = (i as i64 + off as i64) as usize;
124                if t < len {
125                    targets.insert(t);
126                }
127            }
128        }
129        i += 1 + s;
130    }
131    for &jt in jump_table {
132        targets.insert(jt as usize);
133    }
134    targets
135}
136
137/// Peephole pass: fuse `load_imm(51) + ThreeReg ALU` into `TwoRegOneImm` immediate form.
138///
139/// Scans the PVM code for consecutive pairs where:
140/// 1. First instruction is `load_imm` (opcode 51)
141/// 2. Second instruction is a ThreeReg ALU op with an immediate-form equivalent
142/// 3. The load destination register equals the ALU output register (dead after ALU)
143/// 4. The load value fits in i32 (4-byte immediate)
144/// 5. Neither instruction is a branch target
145///
146/// When fusable, rewrites the pair in-place: the first instruction becomes the
147/// TwoRegOneImm form with a 4-byte immediate, and all remaining bytes through
148/// the end of the second instruction become bitmask=0 continuation bytes.
149pub fn peephole_fuse_load_imm_alu(
150    code: &mut [u8],
151    bitmask: &mut [u8],
152    jump_table: &[u32],
153) -> usize {
154    let len = code.len();
155    if len < 4 {
156        return 0;
157    }
158
159    let targets = collect_branch_targets(code, bitmask, jump_table);
160
161    // ThreeReg ALU → TwoRegOneImm immediate form mapping
162    let imm_opcode = |three_reg_op: u8| -> Option<u8> {
163        match three_reg_op {
164            200 => Some(149), // add_64 → add_imm_64
165            202 => Some(150), // mul_64 → mul_imm_64
166            207 => Some(151), // shl_64 → shl_imm_64
167            208 => Some(152), // shr_64 → shr_imm_64
168            209 => Some(153), // sar_64 → sar_imm_64
169            210 => Some(132), // and → and_imm
170            211 => Some(133), // xor → xor_imm
171            212 => Some(134), // or → or_imm
172            // set_lt_u (216) and set_lt_s (217) are non-commutative — handled below
173            _ => None,
174        }
175    };
176
177    let mut fused = 0;
178    let mut i = 0;
179    while i < len {
180        if i >= bitmask.len() || bitmask[i] != 1 {
181            i += 1;
182            continue;
183        }
184        let op = code[i];
185        let s = skip_for(bitmask, i);
186        let next_i = i + 1 + s;
187
188        // Look for load_imm (51) followed by a ThreeReg ALU
189        if op == 51 && next_i < len && bitmask[next_i] == 1 && !targets.contains(&next_i) {
190            let alu_op = code[next_i];
191            let alu_s = skip_for(bitmask, next_i);
192
193            // Check if the following instruction is a ThreeReg ALU we can fuse with.
194            // Covers: commutative ops (add, mul, shift, logic), sub_64, set_lt.
195            let is_fusable_alu =
196                imm_opcode(alu_op).is_some() || alu_op == 201 || alu_op == 216 || alu_op == 217;
197
198            if is_fusable_alu && i + 1 < len && next_i + 2 < len {
199                // Parse load_imm: [51, reg_byte, imm...] — OneRegOneImm
200                let load_reg_byte = code[i + 1];
201                let load_rd = load_reg_byte & 0x0F;
202                let lx = s.saturating_sub(1);
203                let load_val = parse_signed_imm(code, i + 2, lx);
204
205                // Parse ThreeReg ALU: [op, ra|(rb<<4), rd]
206                let alu_reg1 = code[next_i + 1];
207                let alu_ra = alu_reg1 & 0x0F;
208                let alu_rb = (alu_reg1 >> 4) & 0x0F;
209                let alu_rd = code[next_i + 2].min(12);
210
211                let fits_i32 = load_val >= i32::MIN as i64 && load_val <= i32::MAX as i64;
212                let matches_ra = load_rd == alu_ra;
213                let matches_rb = load_rd == alu_rb;
214                let end_of_pair = next_i + 1 + alu_s;
215
216                /// Write a fused TwoRegOneImm instruction in-place.
217                /// `reg_byte` is `rd | (base << 4)`.
218                fn emit_fused(
219                    code: &mut [u8],
220                    bitmask: &mut [u8],
221                    i: usize,
222                    end: usize,
223                    fused_op: u8,
224                    reg_byte: u8,
225                    imm: i32,
226                ) -> bool {
227                    if end >= i + 6 {
228                        code[i] = fused_op;
229                        code[i + 1] = reg_byte;
230                        let imm_bytes = imm.to_le_bytes();
231                        code[i + 2] = imm_bytes[0];
232                        code[i + 3] = imm_bytes[1];
233                        code[i + 4] = imm_bytes[2];
234                        code[i + 5] = imm_bytes[3];
235                        for k in 6..(end - i) {
236                            code[i + k] = 0;
237                        }
238                        for b in &mut bitmask[(i + 1)..end] {
239                            *b = 0;
240                        }
241                        true
242                    } else {
243                        false
244                    }
245                }
246
247                // Non-commutative: set_lt_u (216) and set_lt_s (217).
248                // rd = ra < rb: constant as rb → set_lt_imm, constant as ra → set_gt_imm
249                if (alu_op == 216 || alu_op == 217)
250                    && fits_i32
251                    && load_rd == alu_rd
252                    && (matches_ra != matches_rb)
253                {
254                    let (cmp_imm_op, base) = if matches_rb {
255                        let op = if alu_op == 216 { 136u8 } else { 137u8 };
256                        (op, alu_ra)
257                    } else {
258                        let op = if alu_op == 216 { 142u8 } else { 143u8 };
259                        (op, alu_rb)
260                    };
261                    if emit_fused(
262                        code,
263                        bitmask,
264                        i,
265                        end_of_pair,
266                        cmp_imm_op,
267                        alu_rd | (base << 4),
268                        load_val as i32,
269                    ) {
270                        fused += 1;
271                        i = end_of_pair;
272                        continue;
273                    }
274                }
275
276                // Special case: sub_64 (201) is non-commutative.
277                // load_imm rd, K; sub_64 rd, ra, rb (rd = ra - rb):
278                //   rd==rb (constant subtrahend): rd = ra - K → add_imm_64(rd, ra, -K)
279                //   rd==ra (constant minuend):    rd = K - rb → neg_add_imm_64(rd, rb, K)
280                if alu_op == 201 && fits_i32 && load_rd == alu_rd && (matches_ra != matches_rb) {
281                    let result = if matches_rb {
282                        let neg_k = -(load_val as i32) as i64;
283                        if neg_k >= i32::MIN as i64 && neg_k <= i32::MAX as i64 {
284                            Some((149u8, alu_ra, neg_k as i32))
285                        } else {
286                            None
287                        }
288                    } else {
289                        Some((154u8, alu_rb, load_val as i32))
290                    };
291                    if let Some((sub_imm_op, base, imm32)) = result
292                        && emit_fused(
293                            code,
294                            bitmask,
295                            i,
296                            end_of_pair,
297                            sub_imm_op,
298                            alu_rd | (base << 4),
299                            imm32,
300                        )
301                    {
302                        fused += 1;
303                        i = end_of_pair;
304                        continue;
305                    }
306                }
307
308                // General commutative ALU ops with immediate form
309                if let Some(imm_op) = imm_opcode(alu_op)
310                    && fits_i32
311                    && load_rd == alu_rd
312                    && (matches_ra || matches_rb)
313                {
314                    let base = if matches_ra { alu_rb } else { alu_ra };
315                    if emit_fused(
316                        code,
317                        bitmask,
318                        i,
319                        end_of_pair,
320                        imm_op,
321                        alu_rd | (base << 4),
322                        load_val as i32,
323                    ) {
324                        fused += 1;
325                        i = end_of_pair;
326                        continue;
327                    }
328                }
329            }
330        }
331        i += 1 + s;
332    }
333    fused
334}
335
336/// Peephole pass: fuse `load_imm` + indirect memory op into direct memory op.
337///
338/// When `load_imm rd, K` is immediately followed by `load_ind_X dest, rd, offset`
339/// or `store_ind_X [rd + offset], val`, and `K + offset` fits in i32, the pair is
340/// replaced by the direct `load_X dest, K+offset` or `store_X [K+offset], val`.
341/// This eliminates the intermediate address register load.
342pub fn peephole_fuse_load_imm_memory(
343    code: &mut [u8],
344    bitmask: &mut [u8],
345    jump_table: &[u32],
346) -> usize {
347    let len = code.len();
348    if len < 4 {
349        return 0;
350    }
351
352    let targets = collect_branch_targets(code, bitmask, jump_table);
353
354    // Map indirect opcode → direct opcode
355    let direct_opcode = |ind_op: u8| -> Option<u8> {
356        match ind_op {
357            124 => Some(52), // load_ind_u8  → load_u8
358            125 => Some(53), // load_ind_i8  → load_i8
359            126 => Some(54), // load_ind_u16 → load_u16
360            127 => Some(55), // load_ind_i16 → load_i16
361            128 => Some(56), // load_ind_u32 → load_u32
362            129 => Some(57), // load_ind_i32 → load_i32
363            130 => Some(58), // load_ind_u64 → load_u64
364            120 => Some(59), // store_ind_u8  → store_u8
365            121 => Some(60), // store_ind_u16 → store_u16
366            122 => Some(61), // store_ind_u32 → store_u32
367            123 => Some(62), // store_ind_u64 → store_u64
368            _ => None,
369        }
370    };
371
372    // For load_ind: rd is dest, ra is base. We fuse when base == load_imm's rd.
373    // For store_ind: rd is value, ra is base. We fuse when base == load_imm's rd.
374    // In both cases, ra (high nibble of reg byte) must match load_imm's destination.
375    let is_load_ind = |op: u8| -> bool { (124..=130).contains(&op) };
376
377    let mut fused = 0;
378    let mut i = 0;
379    while i < len {
380        if i >= bitmask.len() || bitmask[i] != 1 {
381            i += 1;
382            continue;
383        }
384        let op = code[i];
385        let s = skip_for(bitmask, i);
386        let next_i = i + 1 + s;
387
388        // Look for load_imm (51) followed by load_ind or store_ind
389        if op == 51 && next_i < len && bitmask[next_i] == 1 && !targets.contains(&next_i) {
390            let mem_op = code[next_i];
391            let mem_s = skip_for(bitmask, next_i);
392            if let Some(dir_op) = direct_opcode(mem_op) {
393                // Parse load_imm: [51, reg_byte, imm...]
394                if i + 1 < len {
395                    let load_rd = code[i + 1] & 0x0F;
396                    let lx = s.saturating_sub(1);
397                    let load_val = parse_signed_imm(code, i + 2, lx);
398
399                    // Parse memory op: [mem_op, rd|(ra<<4), imm0-3]
400                    if next_i + 2 < len {
401                        let mem_reg_byte = code[next_i + 1];
402                        let mem_rd = mem_reg_byte & 0x0F; // dest (load) or value (store)
403                        let mem_ra = (mem_reg_byte >> 4) & 0x0F; // base address register
404
405                        // Fuse if: load_imm's rd == memory op's base register (ra)
406                        // AND the loaded register is not also used as the value in a store
407                        // (i.e., for store_ind: load_rd must be ra, not rd)
408                        let base_matches = load_rd == mem_ra;
409
410                        // For load_ind: also check load_rd != mem_rd if load_rd == mem_ra,
411                        // because the direct form loses ra. But we keep mem_rd as the dest,
412                        // so the only constraint is base_matches.
413                        // For store_ind: load_rd == mem_ra is sufficient. If load_rd == mem_rd
414                        // too, the value is the same constant — still valid since the store
415                        // reads mem_rd BEFORE we'd clobber it.
416                        //
417                        // Additional safety: if load_rd is used as BOTH base and value in
418                        // store_ind (mem_ra == mem_rd == load_rd), the direct store still
419                        // reads the value from the register correctly.
420                        let is_load = is_load_ind(mem_op);
421                        let safe = if is_load {
422                            // For load_ind: load_rd == mem_ra. If load_rd == mem_rd too,
423                            // the load overwrites the base register — but we don't need
424                            // the base anymore since we're using the direct address.
425                            base_matches
426                        } else {
427                            // For store_ind: load_rd == mem_ra. Must also ensure
428                            // load_rd != mem_rd OR the store doesn't need the original
429                            // register value (it needs the LOADED constant, which is fine).
430                            base_matches
431                        };
432
433                        // Parse memory op's offset
434                        let ly = mem_s.saturating_sub(1);
435                        let offset = parse_signed_imm(code, next_i + 2, ly);
436
437                        let combined = load_val.wrapping_add(offset);
438                        let fits_u32 = combined >= 0 && combined <= u32::MAX as i64;
439                        let end_of_pair = next_i + 1 + mem_s;
440
441                        if safe && fits_u32 && end_of_pair >= next_i + 6 {
442                            // Rewrite memory op in-place as direct form
443                            code[next_i] = dir_op;
444                            // Direct form: [dir_op, rd, imm0-3] (OneRegOneImm)
445                            // rd is the dest (load) or value (store) register
446                            code[next_i + 1] = mem_rd;
447                            let addr_bytes = (combined as u32).to_le_bytes();
448                            code[next_i + 2] = addr_bytes[0];
449                            code[next_i + 3] = addr_bytes[1];
450                            code[next_i + 4] = addr_bytes[2];
451                            code[next_i + 5] = addr_bytes[3];
452                            // Zero remaining bytes
453                            for k in 6..(end_of_pair - next_i) {
454                                code[next_i + k] = 0;
455                            }
456                            // Clear continuation bitmask for memory op
457                            for b in &mut bitmask[(next_i + 1)..end_of_pair] {
458                                *b = 0;
459                            }
460
461                            // NOP the load_imm by clearing its bitmask
462                            bitmask[i] = 0;
463                            for b in code[i..next_i].iter_mut() {
464                                *b = 0;
465                            }
466
467                            fused += 1;
468                            i = end_of_pair;
469                            continue;
470                        }
471                    }
472                }
473            }
474        }
475        i += 1 + s;
476    }
477    fused
478}
479
480/// Peephole pass: eliminate dead `load_imm` instructions.
481///
482/// When a `load_imm` (opcode 51) or `load_imm_64` (opcode 20) writes to register R,
483/// and the immediately following instruction also writes to R without reading it
484/// (another load_imm/load_imm_64, or move_reg with R as destination), the first
485/// instruction is dead and can be replaced with a no-op (bitmask cleared).
486///
487/// The second instruction must not be a branch target (otherwise the first
488/// load_imm could be reached independently via a different path).
489pub fn peephole_eliminate_dead_load_imm(
490    code: &mut [u8],
491    bitmask: &mut [u8],
492    jump_table: &[u32],
493) -> usize {
494    let len = code.len();
495    if len < 4 {
496        return 0;
497    }
498
499    let targets = collect_branch_targets(code, bitmask, jump_table);
500
501    /// Extract the destination register from a load_imm (51) or load_imm_64 (20).
502    /// Returns None if the instruction doesn't write to a register or is malformed.
503    fn load_dest_reg(code: &[u8], pc: usize) -> Option<u8> {
504        let op = code[pc];
505        if (op == 51 || op == 20) && pc + 1 < code.len() {
506            Some(code[pc + 1] & 0x0F)
507        } else {
508            None
509        }
510    }
511
512    /// Check if an instruction at `pc` unconditionally writes to register `rd`
513    /// without reading it first. Covers: load_imm(51), load_imm_64(20), move_reg(100).
514    fn writes_without_reading(code: &[u8], pc: usize, rd: u8) -> bool {
515        if pc >= code.len() {
516            return false;
517        }
518        let op = code[pc];
519        match op {
520            // load_imm / load_imm_64: dest is bits 0-3 of reg_byte
521            51 | 20 => pc + 1 < code.len() && (code[pc + 1] & 0x0F) == rd,
522            // move_reg: [100, rd|(rs<<4)] — writes rd, reads rs
523            // Safe only if rd != rs (otherwise it reads rd too, but move to self is still dead)
524            100 => pc + 1 < code.len() && (code[pc + 1] & 0x0F) == rd,
525            _ => false,
526        }
527    }
528
529    let mut eliminated = 0;
530    let mut i = 0;
531    while i < len {
532        if i >= bitmask.len() || bitmask[i] != 1 {
533            i += 1;
534            continue;
535        }
536        let s = skip_for(bitmask, i);
537        let next_i = i + 1 + s;
538
539        if let Some(rd) = load_dest_reg(code, i)
540            && next_i < len
541            && bitmask[next_i] == 1
542            && !targets.contains(&next_i)
543            && writes_without_reading(code, next_i, rd)
544        {
545            // First load_imm is dead — NOP it by clearing its bitmask
546            bitmask[i] = 0;
547            // Zero out the instruction bytes
548            for b in code[i..next_i].iter_mut() {
549                *b = 0;
550            }
551            eliminated += 1;
552            i = next_i;
553            continue;
554        }
555        i += 1 + s;
556    }
557    eliminated
558}
559
560/// Post-pass: ensure all PVM branch targets are basic block starts (ϖ).
561///
562/// Scans the PVM code for branch/jump instructions, extracts their targets,
563/// and inserts `fallthrough` (opcode 1) before any target not preceded by a
564/// terminator. Adjusts all branch offsets and jump table entries to account
565/// for the inserted bytes.
566///
567/// This guarantees the JAM spec invariant: all branch targets ∈ ϖ.
568pub fn ensure_branch_targets_are_block_starts(
569    code: &mut Vec<u8>,
570    bitmask: &mut Vec<u8>,
571    jump_table: &mut [u32],
572) {
573    let terminators: &[u8] = &[0, 1, 2, 10, 40, 50, 80, 180];
574    let is_terminator = |op: u8| -> bool {
575        terminators.contains(&op) || (81..=90).contains(&op) || (170..=175).contains(&op)
576    };
577
578    // Helper: compute skip from bitmask (next instruction start after pc)
579    let skip_for = |bm: &[u8], pc: usize| -> usize {
580        for j in 0..25 {
581            let idx = pc + 1 + j;
582            if idx >= bm.len() || bm[idx] == 1 {
583                return j;
584            }
585        }
586        0
587    };
588
589    // Pass 1: find all branch target PCs and check which need fallthrough.
590    let len = code.len();
591    let mut insert_positions: Vec<usize> = Vec::new(); // PVM offsets to insert fallthrough BEFORE
592
593    // Build post-terminator set for checking
594    let mut post_term = std::collections::HashSet::new();
595    post_term.insert(0usize);
596    {
597        let mut i = 0;
598        while i < len {
599            if i >= bitmask.len() || bitmask[i] != 1 {
600                i += 1;
601                continue;
602            }
603            let op = code[i];
604            let s = skip_for(bitmask, i);
605            if is_terminator(op) {
606                let nxt = i + 1 + s;
607                if nxt < len && nxt < bitmask.len() && bitmask[nxt] == 1 {
608                    post_term.insert(nxt);
609                }
610            }
611            i += 1 + s;
612        }
613    }
614
615    // Collect branch targets
616    let mut branch_targets = std::collections::HashSet::new();
617    {
618        let mut i = 0;
619        while i < len {
620            if i >= bitmask.len() || bitmask[i] != 1 {
621                i += 1;
622                continue;
623            }
624            let op = code[i];
625            let s = skip_for(bitmask, i);
626
627            // OneOffset: opcode 40 (jump), 80 (load_imm_jump)
628            if op == 40 && i + 5 <= len {
629                let off = i32::from_le_bytes([code[i + 1], code[i + 2], code[i + 3], code[i + 4]]);
630                let t = (i as i64 + off as i64) as usize;
631                if t < len && t < bitmask.len() && bitmask[t] == 1 {
632                    branch_targets.insert(t);
633                }
634            }
635            // TwoRegOneOffset: opcodes 170-175
636            if (170..=175).contains(&op) && i + 6 <= len {
637                let off = i32::from_le_bytes([code[i + 2], code[i + 3], code[i + 4], code[i + 5]]);
638                let t = (i as i64 + off as i64) as usize;
639                if t < len && t < bitmask.len() && bitmask[t] == 1 {
640                    branch_targets.insert(t);
641                }
642            }
643            // OneRegImmOffset: opcodes 80-90
644            if (80..=90).contains(&op) && i + 2 <= len {
645                let reg_byte = code[i + 1];
646                let lx = ((reg_byte as usize / 16) % 8).min(4);
647                let ly = if s > lx + 1 { (s - lx - 1).min(4) } else { 0 };
648                let off_start = i + 2 + lx;
649                if ly > 0 && off_start + ly <= len {
650                    let mut buf = [0u8; 4];
651                    buf[..ly].copy_from_slice(&code[off_start..off_start + ly]);
652                    if ly < 4 && buf[ly - 1] & 0x80 != 0 {
653                        for b in &mut buf[ly..4] {
654                            *b = 0xFF;
655                        }
656                    }
657                    let off = i32::from_le_bytes(buf);
658                    let t = (i as i64 + off as i64) as usize;
659                    if t < len && t < bitmask.len() && bitmask[t] == 1 {
660                        branch_targets.insert(t);
661                    }
662                }
663            }
664            i += 1 + s;
665        }
666    }
667
668    // Find branch targets not in post_term
669    for &t in &branch_targets {
670        if !post_term.contains(&t) {
671            insert_positions.push(t);
672        }
673    }
674    // Also check jump table entries
675    for &jt_entry in jump_table.iter() {
676        let t = jt_entry as usize;
677        if t < len
678            && t < bitmask.len()
679            && bitmask[t] == 1
680            && !post_term.contains(&t)
681            && !insert_positions.contains(&t)
682        {
683            insert_positions.push(t);
684        }
685    }
686
687    if insert_positions.is_empty() {
688        return;
689    }
690
691    insert_positions.sort();
692    insert_positions.dedup();
693
694    // Pass 2: build new code/bitmask with fallthroughs inserted.
695    // Also build an offset map: old_pc → new_pc.
696    let new_len = len + insert_positions.len();
697    let mut new_code = Vec::with_capacity(new_len);
698    let mut new_bitmask = Vec::with_capacity(new_len);
699    let mut offset_map = vec![0u32; len + 1]; // old_pc → new_pc
700
701    let mut insert_idx = 0;
702    for old_pc in 0..len {
703        // Insert fallthrough before this PC if needed
704        while insert_idx < insert_positions.len() && insert_positions[insert_idx] == old_pc {
705            new_code.push(1); // fallthrough opcode
706            new_bitmask.push(1); // instruction start
707            insert_idx += 1;
708        }
709        offset_map[old_pc] = new_code.len() as u32;
710        new_code.push(code[old_pc]);
711        new_bitmask.push(bitmask[old_pc]);
712    }
713    offset_map[len] = new_code.len() as u32;
714
715    // Pass 3: fix all PC-relative branch offsets in the new code.
716    // Scan for branch instructions and recalculate their offsets.
717    {
718        let mut i = 0;
719        while i < new_code.len() {
720            if i >= new_bitmask.len() || new_bitmask[i] != 1 {
721                i += 1;
722                continue;
723            }
724            let op = new_code[i];
725            let s = {
726                let mut s = 0;
727                for j in 0..25 {
728                    let idx = i + 1 + j;
729                    if idx >= new_bitmask.len() || new_bitmask[idx] == 1 {
730                        s = j;
731                        break;
732                    }
733                }
734                s
735            };
736
737            // OneOffset with fixed 4-byte immediate: opcode 40 (jump)
738            if op == 40 && i + 5 <= new_code.len() {
739                let _old_off = i32::from_le_bytes([
740                    new_code[i + 1],
741                    new_code[i + 2],
742                    new_code[i + 3],
743                    new_code[i + 4],
744                ]);
745                // Find old PC for this instruction
746                // The instruction at new_pc=i maps back to some old_pc.
747                // old_target = old_pc + old_off. new_target = offset_map[old_target].
748                // new_off = new_target - new_pc = offset_map[old_target] - i.
749                // But we need old_pc. We can compute: old_target was in the original code.
750                // Since new code has extra bytes, old_off referenced old positions.
751                // Actually, the offset was already resolved in the old code. old_target = old_inst_pc + old_off.
752                // We need to map old_inst_pc back. But that's complex.
753                // Simpler: compute old target from old offset, then remap.
754                // We need to find which old_pc maps to this new i.
755                // Build reverse map:
756                // Actually let's just do this with a reverse lookup.
757            }
758
759            i += 1 + s;
760        }
761    }
762
763    // This approach is getting complex. Use a simpler strategy:
764    // rebuild fixups from scratch by scanning old code, computing old targets,
765    // and patching new code with remapped offsets.
766
767    // Actually, let's use the offset_map directly on the old code's branch instructions.
768    {
769        let mut old_i = 0;
770        while old_i < len {
771            if old_i >= bitmask.len() || bitmask[old_i] != 1 {
772                old_i += 1;
773                continue;
774            }
775            let op = code[old_i];
776            let s = skip_for(bitmask, old_i);
777            let new_i = offset_map[old_i] as usize;
778
779            // Fix OneOffset: opcode 40
780            if op == 40 && old_i + 5 <= len {
781                let old_off = i32::from_le_bytes([
782                    code[old_i + 1],
783                    code[old_i + 2],
784                    code[old_i + 3],
785                    code[old_i + 4],
786                ]);
787                let old_target = (old_i as i64 + old_off as i64) as usize;
788                if old_target <= len {
789                    let new_target = offset_map[old_target] as i64;
790                    let new_off = (new_target - new_i as i64) as i32;
791                    new_code[new_i + 1..new_i + 5].copy_from_slice(&new_off.to_le_bytes());
792                }
793            }
794            // Fix TwoRegOneOffset: opcodes 170-175
795            if (170..=175).contains(&op) && old_i + 6 <= len {
796                let old_off = i32::from_le_bytes([
797                    code[old_i + 2],
798                    code[old_i + 3],
799                    code[old_i + 4],
800                    code[old_i + 5],
801                ]);
802                let old_target = (old_i as i64 + old_off as i64) as usize;
803                if old_target <= len {
804                    let new_target = offset_map[old_target] as i64;
805                    let new_off = (new_target - new_i as i64) as i32;
806                    new_code[new_i + 2..new_i + 6].copy_from_slice(&new_off.to_le_bytes());
807                }
808            }
809            // Fix OneRegImmOffset: opcodes 80-90
810            if (80..=90).contains(&op) && old_i + 2 <= len {
811                let reg_byte = code[old_i + 1];
812                let lx = ((reg_byte as usize / 16) % 8).min(4);
813                let ly = if s > lx + 1 { (s - lx - 1).min(4) } else { 0 };
814                let off_start_old = old_i + 2 + lx;
815                if ly > 0 && off_start_old + ly <= len {
816                    let mut buf = [0u8; 4];
817                    buf[..ly].copy_from_slice(&code[off_start_old..off_start_old + ly]);
818                    if ly < 4 && buf[ly - 1] & 0x80 != 0 {
819                        for b in &mut buf[ly..4] {
820                            *b = 0xFF;
821                        }
822                    }
823                    let old_off = i32::from_le_bytes(buf);
824                    let old_target = (old_i as i64 + old_off as i64) as usize;
825                    if old_target <= len {
826                        let new_target = offset_map[old_target] as i64;
827                        let new_off = (new_target - new_i as i64) as i32;
828                        // Write back with same length ly
829                        let new_bytes = new_off.to_le_bytes();
830                        let off_start_new = new_i + 2 + lx;
831                        new_code[off_start_new..off_start_new + ly]
832                            .copy_from_slice(&new_bytes[..ly]);
833                    }
834                }
835            }
836
837            old_i += 1 + s;
838        }
839    }
840
841    // Fix jump table entries
842    for entry in jump_table.iter_mut() {
843        let old_pc = *entry as usize;
844        if old_pc <= len {
845            *entry = offset_map[old_pc];
846        }
847    }
848
849    *code = new_code;
850    *bitmask = new_bitmask;
851}
852
853#[cfg(test)]
854mod tests {
855    use super::*;
856
857    // === peephole_fuse_load_imm_alu tests ===
858
859    #[test]
860    fn test_fuse_load_imm_add64() {
861        // load_imm φ[2], 42 (rd=2, imm=42)
862        // add_64 φ[2] = φ[0] + φ[2] (ra=0, rb=2, rd=2)
863        // → add_imm_64 φ[2] = φ[0] + 42
864        let mut code = vec![
865            51, 2, 42, // load_imm rd=2, imm=42 (skip=1)
866            200, 0x20, 2, // add_64 ra=0, rb=2, rd=2
867        ];
868        let mut bitmask = vec![1, 0, 0, 1, 0, 0];
869
870        let fused = peephole_fuse_load_imm_alu(&mut code, &mut bitmask, &[]);
871        assert_eq!(fused, 1);
872        assert_eq!(code[0], 149); // add_imm_64
873        assert_eq!(code[1] & 0x0F, 2); // rd=2
874        assert_eq!(code[1] >> 4, 0); // base=0
875        assert_eq!(code[2], 42); // imm low byte
876        assert_eq!(bitmask[0], 1);
877        assert_eq!(bitmask[3], 0); // old ALU start cleared
878    }
879
880    #[test]
881    fn test_fuse_load_imm_mul64() {
882        // load_imm φ[3], 7 → mul_64 φ[3] = φ[1] * φ[3]
883        // → mul_imm_64 φ[3] = φ[1] * 7
884        let mut code = vec![
885            51, 3, 7, // load_imm rd=3, imm=7
886            202, 0x31, 3, // mul_64 ra=1, rb=3, rd=3
887        ];
888        let mut bitmask = vec![1, 0, 0, 1, 0, 0];
889
890        let fused = peephole_fuse_load_imm_alu(&mut code, &mut bitmask, &[]);
891        assert_eq!(fused, 1);
892        assert_eq!(code[0], 150); // mul_imm_64
893    }
894
895    #[test]
896    fn test_fuse_skips_branch_target() {
897        // Same pattern but ALU is a branch target → should NOT fuse
898        let mut code = vec![
899            51, 2, 42, 200, 0x20, 2, 40, 253, 255, 255,
900            255, // jump -3 (targets offset 3 = the add_64)
901        ];
902        let mut bitmask = vec![1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0];
903
904        let fused = peephole_fuse_load_imm_alu(&mut code, &mut bitmask, &[]);
905        assert_eq!(fused, 0, "should not fuse when ALU is a branch target");
906    }
907
908    #[test]
909    fn test_fuse_no_match_different_rd() {
910        // load_imm writes to rd=2 but ALU rd=3 → should NOT fuse
911        let mut code = vec![
912            51, 2, 42, // load_imm rd=2
913            200, 0x20, 3, // add_64 rd=3 (not 2)
914        ];
915        let mut bitmask = vec![1, 0, 0, 1, 0, 0];
916
917        let fused = peephole_fuse_load_imm_alu(&mut code, &mut bitmask, &[]);
918        assert_eq!(fused, 0);
919    }
920
921    #[test]
922    fn test_fuse_sub64_constant_subtrahend() {
923        // load_imm φ[2], 5; sub_64 φ[2] = φ[0] - φ[2]
924        // → add_imm_64 φ[2] = φ[0] + (-5)
925        let mut code = vec![
926            51, 2, 5, // load_imm rd=2, imm=5
927            201, 0x20, 2, // sub_64 ra=0, rb=2, rd=2
928        ];
929        let mut bitmask = vec![1, 0, 0, 1, 0, 0];
930
931        let fused = peephole_fuse_load_imm_alu(&mut code, &mut bitmask, &[]);
932        assert_eq!(fused, 1);
933        assert_eq!(code[0], 149); // add_imm_64
934        // imm should be -5 as i32 LE
935        let imm = i32::from_le_bytes([code[2], code[3], code[4], code[5]]);
936        assert_eq!(imm, -5);
937    }
938
939    // === peephole_eliminate_dead_load_imm tests ===
940
941    #[test]
942    fn test_eliminate_dead_load_imm() {
943        // load_imm φ[2], 99 (dead — immediately overwritten)
944        // load_imm φ[2], 42
945        let mut code = vec![
946            51, 2, 99, // dead load_imm rd=2
947            51, 2, 42, // overwrites rd=2
948        ];
949        let mut bitmask = vec![1, 0, 0, 1, 0, 0];
950
951        let eliminated = peephole_eliminate_dead_load_imm(&mut code, &mut bitmask, &[]);
952        assert_eq!(eliminated, 1);
953        assert_eq!(bitmask[0], 0, "dead instruction bitmask cleared");
954        assert_eq!(code[0], 0, "dead instruction bytes zeroed");
955        assert_eq!(code[3], 51, "second load_imm preserved");
956    }
957
958    #[test]
959    fn test_eliminate_dead_load_imm_branch_target() {
960        // load_imm φ[2], 99; load_imm φ[2], 42
961        // BUT second is a branch target → should NOT eliminate
962        let mut code = vec![
963            51, 2, 99, 51, 2, 42, 40, 253, 255, 255, 255, // jump -3 (targets offset 3)
964        ];
965        let mut bitmask = vec![1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0];
966
967        let eliminated = peephole_eliminate_dead_load_imm(&mut code, &mut bitmask, &[]);
968        assert_eq!(
969            eliminated, 0,
970            "should not eliminate when next is branch target"
971        );
972    }
973
974    #[test]
975    fn test_eliminate_dead_load_imm_different_reg() {
976        // load_imm φ[2], 99; load_imm φ[3], 42 → NOT dead (different registers)
977        let mut code = vec![
978            51, 2, 99, // rd=2
979            51, 3, 42, // rd=3
980        ];
981        let mut bitmask = vec![1, 0, 0, 1, 0, 0];
982
983        let eliminated = peephole_eliminate_dead_load_imm(&mut code, &mut bitmask, &[]);
984        assert_eq!(eliminated, 0);
985    }
986
987    // === peephole_fuse_load_imm_memory tests ===
988
989    #[test]
990    fn test_fuse_load_imm_load_ind() {
991        // load_imm φ[3], 0x100; load_ind_u32 φ[2], φ[3], 4
992        // → NOP; load_u32 φ[2], 0x104
993        // load_ind_u32 format: [128, rd|(ra<<4), offset_bytes]
994        // rd=2 (dest), ra=3 (base). reg byte = 2 | (3<<4) = 0x32.
995        let mut code = vec![
996            51, 3, 0, 1, // load_imm rd=3, imm=0x100 (skip=2: reg+2 imm bytes)
997            128, 0x32, 4, 0, 0, 0, // load_ind_u32 rd=2, ra=3, offset=4
998        ];
999        let mut bitmask = vec![1, 0, 0, 0, 1, 0, 0, 0, 0, 0];
1000
1001        let fused = peephole_fuse_load_imm_memory(&mut code, &mut bitmask, &[]);
1002        assert_eq!(fused, 1);
1003        // load_imm is NOP'd (bitmask[0]=0), memory op rewritten in-place
1004        assert_eq!(bitmask[0], 0, "load_imm should be NOP'd");
1005        assert_eq!(code[4], 56, "load_ind_u32(128) → load_u32(56)");
1006        assert_eq!(code[5], 2, "dest register preserved");
1007        // Combined address: 0x100 + 4 = 0x104
1008        let addr = u32::from_le_bytes([code[6], code[7], code[8], code[9]]);
1009        assert_eq!(addr, 0x104);
1010    }
1011
1012    // === parse_signed_imm tests ===
1013
1014    #[test]
1015    fn test_parse_signed_imm_positive() {
1016        let code = [42, 0];
1017        assert_eq!(parse_signed_imm(&code, 0, 2), 42);
1018    }
1019
1020    #[test]
1021    fn test_parse_signed_imm_negative() {
1022        // -1 in 1 byte = 0xFF, sign-extended
1023        let code = [0xFF];
1024        assert_eq!(parse_signed_imm(&code, 0, 1), -1);
1025    }
1026
1027    #[test]
1028    fn test_parse_signed_imm_zero_length() {
1029        let code = [42];
1030        assert_eq!(parse_signed_imm(&code, 0, 0), 0);
1031    }
1032
1033    mod proptests {
1034        use super::*;
1035        use proptest::prelude::*;
1036
1037        /// Generate random PVM-like bytecode: instruction starts at every 3rd byte
1038        /// (simulating load_imm + ALU patterns).
1039        fn random_pvm_program() -> impl Strategy<Value = (Vec<u8>, Vec<u8>)> {
1040            // Generate 3-30 instructions, each 1-6 bytes
1041            proptest::collection::vec(
1042                (
1043                    0u8..=255u8,                                // opcode
1044                    proptest::collection::vec(0u8..=255, 0..5), // operand bytes
1045                ),
1046                3..30,
1047            )
1048            .prop_map(|instrs| {
1049                let mut code = Vec::new();
1050                let mut bitmask = Vec::new();
1051                for (opcode, operands) in &instrs {
1052                    code.push(*opcode);
1053                    bitmask.push(1u8);
1054                    for &b in operands {
1055                        code.push(b);
1056                        bitmask.push(0u8);
1057                    }
1058                }
1059                (code, bitmask)
1060            })
1061        }
1062
1063        proptest! {
1064            #![proptest_config(ProptestConfig::with_cases(256))]
1065
1066            /// Peephole ALU fusion is idempotent: applying twice produces the same
1067            /// code as applying once.
1068            #[test]
1069            fn alu_fusion_idempotent((code, bitmask) in random_pvm_program()) {
1070                let mut c1 = code.clone();
1071                let mut b1 = bitmask.clone();
1072                peephole_fuse_load_imm_alu(&mut c1, &mut b1, &[]);
1073
1074                let mut c2 = c1.clone();
1075                let mut b2 = b1.clone();
1076                peephole_fuse_load_imm_alu(&mut c2, &mut b2, &[]);
1077
1078                prop_assert_eq!(&c1, &c2, "code should not change on second pass");
1079                prop_assert_eq!(&b1, &b2, "bitmask should not change on second pass");
1080            }
1081
1082            /// Peephole memory fusion is idempotent.
1083            #[test]
1084            fn memory_fusion_idempotent((code, bitmask) in random_pvm_program()) {
1085                let mut c1 = code.clone();
1086                let mut b1 = bitmask.clone();
1087                peephole_fuse_load_imm_memory(&mut c1, &mut b1, &[]);
1088
1089                let mut c2 = c1.clone();
1090                let mut b2 = b1.clone();
1091                peephole_fuse_load_imm_memory(&mut c2, &mut b2, &[]);
1092
1093                prop_assert_eq!(&c1, &c2);
1094                prop_assert_eq!(&b1, &b2);
1095            }
1096
1097            /// Dead load_imm elimination is idempotent.
1098            #[test]
1099            fn dead_load_imm_idempotent((code, bitmask) in random_pvm_program()) {
1100                let mut c1 = code.clone();
1101                let mut b1 = bitmask.clone();
1102                peephole_eliminate_dead_load_imm(&mut c1, &mut b1, &[]);
1103
1104                let mut c2 = c1.clone();
1105                let mut b2 = b1.clone();
1106                peephole_eliminate_dead_load_imm(&mut c2, &mut b2, &[]);
1107
1108                prop_assert_eq!(&c1, &c2);
1109                prop_assert_eq!(&b1, &b2);
1110            }
1111
1112            /// Full peephole pipeline is idempotent.
1113            #[test]
1114            fn full_pipeline_idempotent((code, bitmask) in random_pvm_program()) {
1115                let apply = |c: &mut Vec<u8>, b: &mut Vec<u8>| {
1116                    peephole_fuse_load_imm_alu(c, b, &[]);
1117                    peephole_fuse_load_imm_memory(c, b, &[]);
1118                    peephole_eliminate_dead_load_imm(c, b, &[]);
1119                };
1120
1121                let mut c1 = code.clone();
1122                let mut b1 = bitmask.clone();
1123                apply(&mut c1, &mut b1);
1124
1125                let mut c2 = c1.clone();
1126                let mut b2 = b1.clone();
1127                apply(&mut c2, &mut b2);
1128
1129                prop_assert_eq!(&c1, &c2, "full pipeline should be idempotent");
1130                prop_assert_eq!(&b1, &b2);
1131            }
1132
1133            /// Peephole passes never increase code/bitmask length.
1134            #[test]
1135            fn passes_never_grow((code, bitmask) in random_pvm_program()) {
1136                let orig_len = code.len();
1137                let mut c = code;
1138                let mut b = bitmask;
1139                peephole_fuse_load_imm_alu(&mut c, &mut b, &[]);
1140                peephole_fuse_load_imm_memory(&mut c, &mut b, &[]);
1141                peephole_eliminate_dead_load_imm(&mut c, &mut b, &[]);
1142                prop_assert_eq!(c.len(), orig_len, "code length should not change");
1143                prop_assert_eq!(b.len(), orig_len, "bitmask length should not change");
1144            }
1145        }
1146    }
1147}