Skip to main content

ssz/
bits.rs

1//! `Bitvector<N>` and `Bitlist<N>` — bit-packed homogeneous boolean storage.
2//!
//! Both use little-endian (LSB-first) bit packing within each byte: bit `i`
//! of the logical bitstream is bit `i % 8` of `bytes[i / 8]`, i.e. it is set
//! iff `bytes[i / 8] & (1 << (i % 8)) != 0`.
5//!
6//! * `Bitvector<N>` has length exactly `N`, takes `(N + 7) / 8` bytes; any
7//!   bits beyond `N` in the final byte must be zero.
8//! * `Bitlist<N>` has variable length up to `N`. The wire form appends a
9//!   sentinel `1` bit immediately after the data bits; the decoder finds
10//!   the highest set bit in the final byte to recover the length.
11
12use alloc::vec::Vec;
13use core::fmt;
14use digest::Digest;
15use digest::typenum::U32;
16
17use crate::merkle::{merkleize, mix_in_length, pack_bytes};
18use crate::{BYTES_PER_LENGTH_OFFSET, Decode, DecodeError, Encode, HashTreeRoot};
19
20// --------------------------------------------------------------------------
21// Bitvector<N>
22// --------------------------------------------------------------------------
23
/// Fixed-length SSZ bitvector holding exactly `N` bits.
///
/// The bits are packed eight per byte into a heap-allocated `Vec<u8>`, with
/// logical bit `i` stored at bit position `i % 8` of byte `i / 8`. A
/// const-sized array would need `generic_const_exprs` (still unstable) to
/// express `(N + 7) / 8`, so a `Vec` is used and `N` stays in the type.
#[derive(Clone, PartialEq, Eq)]
pub struct Bitvector<const N: usize> {
    /// Packed bit storage.
    bytes: Vec<u8>,
}
33
/// Number of bytes required to hold `n` bits, i.e. `n / 8` rounded up.
#[inline]
const fn bitvec_bytes(n: usize) -> usize {
    let whole_bytes = n / 8;
    if n % 8 == 0 {
        whole_bytes
    } else {
        // A partially used byte still occupies a full byte of storage.
        whole_bytes + 1
    }
}
38
39impl<const N: usize> Default for Bitvector<N> {
40    fn default() -> Self {
41        Self {
42            bytes: alloc::vec![0u8; bitvec_bytes(N)],
43        }
44    }
45}
46
47impl<const N: usize> Bitvector<N> {
48    /// Build from a slice (must have exact length `(N+7)/8`).
49    pub fn from_slice(bytes: &[u8]) -> Result<Self, DecodeError> {
50        let needed = bitvec_bytes(N);
51        if bytes.len() != needed {
52            return Err(DecodeError::UnexpectedEof {
53                expected: needed,
54                actual: bytes.len(),
55            });
56        }
57        validate_trailing_zero_bits(bytes, N)?;
58        Ok(Self {
59            bytes: bytes.to_vec(),
60        })
61    }
62
63    /// Get bit `i`. Panics if `i >= N`.
64    pub fn get(&self, i: usize) -> bool {
65        assert!(i < N, "bit index out of bounds");
66        (self.bytes[i / 8] >> (i % 8)) & 1 == 1
67    }
68
69    /// Set bit `i` to `v`. Panics if `i >= N`.
70    pub fn set(&mut self, i: usize, v: bool) {
71        assert!(i < N, "bit index out of bounds");
72        let mask = 1u8 << (i % 8);
73        if v {
74            self.bytes[i / 8] |= mask;
75        } else {
76            self.bytes[i / 8] &= !mask;
77        }
78    }
79
80    /// Borrow the raw packed bytes.
81    pub fn as_bytes(&self) -> &[u8] {
82        &self.bytes
83    }
84
85    /// Length in bits (always `N`).
86    pub const fn len(&self) -> usize {
87        N
88    }
89
90    /// `true` iff `N == 0`.
91    pub const fn is_empty(&self) -> bool {
92        N == 0
93    }
94}
95
96impl<const N: usize> fmt::Debug for Bitvector<N> {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.debug_struct("Bitvector")
99            .field("n", &N)
100            .field("bytes", &&self.bytes[..])
101            .finish()
102    }
103}
104
105impl<const N: usize> Encode for Bitvector<N> {
106    fn is_ssz_fixed_len() -> bool {
107        true
108    }
109    fn ssz_fixed_len() -> usize {
110        bitvec_bytes(N)
111    }
112    fn ssz_bytes_len(&self) -> usize {
113        bitvec_bytes(N)
114    }
115    fn ssz_append(&self, buf: &mut Vec<u8>) {
116        buf.extend_from_slice(&self.bytes);
117    }
118}
119
120impl<const N: usize> Decode for Bitvector<N> {
121    fn is_ssz_fixed_len() -> bool {
122        true
123    }
124    fn ssz_fixed_len() -> usize {
125        bitvec_bytes(N)
126    }
127    fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
128        Self::from_slice(bytes)
129    }
130}
131
132impl<const N: usize> HashTreeRoot for Bitvector<N> {
133    fn hash_tree_root<D: Digest<OutputSize = U32>>(&self) -> [u8; 32] {
134        let chunks = pack_bytes(&self.bytes);
135        let chunk_limit = bitvec_bytes(N).div_ceil(32).max(1);
136        merkleize::<D>(&chunks, chunk_limit)
137    }
138}
139
140#[inline]
141fn validate_trailing_zero_bits(bytes: &[u8], n: usize) -> Result<(), DecodeError> {
142    if n == 0 {
143        // Even with N=0 we expect a zero-length byte slice; if any byte
144        // exists it must be all zeros (defensive).
145        if bytes.iter().any(|&b| b != 0) {
146            return Err(DecodeError::ExcessBits);
147        }
148        return Ok(());
149    }
150    let last_byte_bits = n % 8;
151    if last_byte_bits == 0 {
152        return Ok(());
153    }
154    let mask = !((1u8 << last_byte_bits) - 1);
155    let last = bytes[bytes.len() - 1];
156    if last & mask != 0 {
157        return Err(DecodeError::ExcessBits);
158    }
159    Ok(())
160}
161
162// --------------------------------------------------------------------------
163// Bitlist<N>
164// --------------------------------------------------------------------------
165
/// Variable-length SSZ bitlist holding at most `N` bits.
///
/// On the wire the data bits are packed LSB-first within each byte and are
/// followed directly by a single sentinel `1` bit. The sentinel only marks
/// where the data ends; it is not one of the list's bits and is not kept in
/// `bytes`.
pub struct Bitlist<const N: u64> {
    /// Packed data bits, without the sentinel.
    bytes: Vec<u8>,
    /// Number of logical bits stored in `bytes`.
    bit_len: u64,
}
175
176impl<const N: u64> Bitlist<N> {
177    /// Build an empty bitlist.
178    pub fn new() -> Self {
179        Self {
180            bytes: Vec::new(),
181            bit_len: 0,
182        }
183    }
184
185    /// Build from a logical bit vector.
186    pub fn from_bits(bits: &[bool]) -> Result<Self, DecodeError> {
187        if (bits.len() as u64) > N {
188            return Err(DecodeError::BoundExceeded {
189                len: bits.len() as u64,
190                bound: N,
191            });
192        }
193        let byte_len = bits.len().div_ceil(8);
194        let mut bytes: Vec<u8> = alloc::vec![0u8; byte_len];
195        for (i, b) in bits.iter().enumerate() {
196            if *b {
197                bytes[i / 8] |= 1 << (i % 8);
198            }
199        }
200        Ok(Self {
201            bytes,
202            bit_len: bits.len() as u64,
203        })
204    }
205
206    /// Logical bit length.
207    pub fn len(&self) -> u64 {
208        self.bit_len
209    }
210
211    /// `true` iff there are no bits.
212    pub fn is_empty(&self) -> bool {
213        self.bit_len == 0
214    }
215
216    /// Get bit `i`. Returns `None` if `i >= len`.
217    pub fn get(&self, i: u64) -> Option<bool> {
218        if i >= self.bit_len {
219            return None;
220        }
221        let byte = self.bytes[(i / 8) as usize];
222        Some((byte >> (i % 8)) & 1 == 1)
223    }
224
225    /// Borrow the raw packed data bytes (without the sentinel layer).
226    pub fn data_bytes(&self) -> &[u8] {
227        &self.bytes
228    }
229
230    /// Total wire byte length (data + sentinel byte if no spare bit).
231    fn wire_byte_len(&self) -> usize {
232        ((self.bit_len + 1) as usize).div_ceil(8)
233    }
234}
235
236impl<const N: u64> Default for Bitlist<N> {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
242impl<const N: u64> fmt::Debug for Bitlist<N> {
243    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
244        f.debug_struct("Bitlist")
245            .field("cap", &N)
246            .field("bit_len", &self.bit_len)
247            .field("bytes", &&self.bytes[..])
248            .finish()
249    }
250}
251
252impl<const N: u64> Clone for Bitlist<N> {
253    fn clone(&self) -> Self {
254        Self {
255            bytes: self.bytes.clone(),
256            bit_len: self.bit_len,
257        }
258    }
259}
260
261impl<const N: u64> PartialEq for Bitlist<N> {
262    fn eq(&self, other: &Self) -> bool {
263        if self.bit_len != other.bit_len {
264            return false;
265        }
266        self.bytes == other.bytes
267    }
268}
269
270impl<const N: u64> Eq for Bitlist<N> {}
271
272impl<const N: u64> Encode for Bitlist<N> {
273    fn is_ssz_fixed_len() -> bool {
274        false
275    }
276    fn ssz_fixed_len() -> usize {
277        BYTES_PER_LENGTH_OFFSET
278    }
279    fn ssz_bytes_len(&self) -> usize {
280        self.wire_byte_len()
281    }
282    fn ssz_append(&self, buf: &mut Vec<u8>) {
283        // Layout: copy data bytes, then set sentinel bit at position `bit_len`.
284        let wire_len = self.wire_byte_len();
285        let start = buf.len();
286        buf.resize(start + wire_len, 0u8);
287        // Copy data bits.
288        for i in 0..self.bit_len {
289            if (self.bytes[(i / 8) as usize] >> (i % 8)) & 1 == 1 {
290                buf[start + (i / 8) as usize] |= 1 << (i % 8);
291            }
292        }
293        // Sentinel.
294        let sb = start + (self.bit_len / 8) as usize;
295        buf[sb] |= 1u8 << (self.bit_len % 8);
296    }
297}
298
299impl<const N: u64> Decode for Bitlist<N> {
300    fn is_ssz_fixed_len() -> bool {
301        false
302    }
303    fn ssz_fixed_len() -> usize {
304        BYTES_PER_LENGTH_OFFSET
305    }
306    fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
307        if bytes.is_empty() {
308            return Err(DecodeError::MissingBitlistSentinel);
309        }
310        let last = bytes[bytes.len() - 1];
311        if last == 0 {
312            return Err(DecodeError::MissingBitlistSentinel);
313        }
314        let sentinel_bit_in_byte = 7 - last.leading_zeros() as usize; // highest set bit
315        let bit_len = ((bytes.len() - 1) * 8 + sentinel_bit_in_byte) as u64;
316        if bit_len > N {
317            return Err(DecodeError::BoundExceeded {
318                len: bit_len,
319                bound: N,
320            });
321        }
322
323        // Extract data: copy all bytes, then clear the sentinel bit and the
324        // bits above it in the last byte.
325        let mut data: Vec<u8> = Vec::with_capacity(bytes.len());
326        data.extend_from_slice(bytes);
327        if let Some(last_byte) = data.last_mut() {
328            let keep_mask = (1u8 << sentinel_bit_in_byte).wrapping_sub(1);
329            *last_byte &= keep_mask;
330        }
331        // Trim the trailing data byte if it ended up all-zero AND we don't
332        // need it for data bits.
333        let needed_data_bytes = (bit_len as usize).div_ceil(8);
334        while data.len() > needed_data_bytes {
335            data.pop();
336        }
337        // If the needed data byte count is shorter than the wire byte count
338        // by 1 and we already popped, the data array now holds exactly
339        // `needed_data_bytes` bytes. Otherwise we keep `data.len() ==
340        // bytes.len()` (sentinel was alone in its byte, masked off).
341        Ok(Self {
342            bytes: data,
343            bit_len,
344        })
345    }
346}
347
348impl<const N: u64> HashTreeRoot for Bitlist<N> {
349    fn hash_tree_root<D: Digest<OutputSize = U32>>(&self) -> [u8; 32] {
350        let data = &self.bytes[..];
351        let chunks = pack_bytes(data);
352        let cap_bytes = (N as usize).div_ceil(8);
353        let chunk_limit = cap_bytes.div_ceil(32).max(1);
354        let inner = merkleize::<D>(&chunks, chunk_limit);
355        mix_in_length::<D>(inner, self.bit_len)
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::Sha256;

    /// Bits written with `set` are read back by `get`, including a bit in
    /// the second storage byte.
    #[test]
    fn bitvector_set_get() {
        let mut bv: Bitvector<10> = Bitvector::default();
        bv.set(0, true);
        bv.set(9, true);
        assert!(bv.get(0));
        assert!(!bv.get(1));
        assert!(bv.get(9));
    }

    /// Encoding then decoding yields an equal bitvector.
    #[test]
    fn bitvector_round_trip() {
        let mut bv: Bitvector<10> = Bitvector::default();
        bv.set(0, true);
        bv.set(3, true);
        bv.set(9, true);
        let bytes = bv.as_ssz_bytes();
        let decoded = Bitvector::<10>::from_ssz_bytes(&bytes).unwrap();
        assert_eq!(bv, decoded);
    }

    /// A set bit in the padding region is rejected.
    #[test]
    fn bitvector_rejects_excess_bits() {
        // N=4, but high nibble has a set bit
        let raw = [0b00010000u8];
        assert!(Bitvector::<4>::from_slice(&raw).is_err());
    }

    /// Input that is shorter or longer than `(N + 7) / 8` bytes is rejected.
    #[test]
    fn bitvector_rejects_wrong_length() {
        assert!(Bitvector::<10>::from_slice(&[0u8]).is_err());
        assert!(Bitvector::<10>::from_slice(&[0u8, 0, 0]).is_err());
    }

    /// The empty bitlist encodes to a lone sentinel byte and decodes back.
    #[test]
    fn bitlist_empty_round_trip() {
        let bl: Bitlist<256> = Bitlist::new();
        let bytes = bl.as_ssz_bytes();
        // Empty bitlist: sentinel bit at position 0 → byte 0x01.
        assert_eq!(bytes, vec![0x01]);
        let decoded = Bitlist::<256>::from_ssz_bytes(&bytes).unwrap();
        assert_eq!(bl, decoded);
        assert_eq!(decoded.len(), 0);
    }

    /// A nine-bit list (data spills into a second byte) survives a round
    /// trip bit for bit.
    #[test]
    fn bitlist_round_trip() {
        let bits = [true, false, true, true, false, true, false, false, true];
        let bl: Bitlist<256> = Bitlist::from_bits(&bits).unwrap();
        let bytes = bl.as_ssz_bytes();
        let decoded = Bitlist::<256>::from_ssz_bytes(&bytes).unwrap();
        assert_eq!(bl, decoded);
        assert_eq!(decoded.len(), 9);
        for (i, b) in bits.iter().enumerate() {
            assert_eq!(decoded.get(i as u64), Some(*b));
        }
    }

    /// With exactly eight data bits the sentinel occupies a byte of its
    /// own, which must be stripped again on decode.
    #[test]
    fn bitlist_byte_aligned_round_trip() {
        let bits = [true; 8];
        let bl: Bitlist<256> = Bitlist::from_bits(&bits).unwrap();
        let bytes = bl.as_ssz_bytes();
        assert_eq!(bytes, vec![0xFF, 0x01]);
        let decoded = Bitlist::<256>::from_ssz_bytes(&bytes).unwrap();
        assert_eq!(bl, decoded);
        assert_eq!(decoded.len(), 8);
        assert_eq!(decoded.data_bytes(), &[0xFFu8][..]);
        assert_eq!(decoded.get(8), None);
    }

    /// Empty input, or a zero final byte, carries no sentinel.
    #[test]
    fn bitlist_rejects_missing_sentinel() {
        assert!(Bitlist::<256>::from_ssz_bytes(&[]).is_err());
        assert!(Bitlist::<256>::from_ssz_bytes(&[0x00]).is_err());
        assert!(Bitlist::<256>::from_ssz_bytes(&[0x01, 0x00]).is_err());
    }

    /// Lengths above the cap are rejected both when building and decoding.
    #[test]
    fn bitlist_rejects_over_bound() {
        assert!(Bitlist::<4>::from_bits(&[true; 5]).is_err());
        // Sentinel at bit 5 → five data bits, one more than the cap of 4.
        assert!(Bitlist::<4>::from_ssz_bytes(&[0b0010_0000]).is_err());
        // Sentinel at bit 4 → exactly four data bits, which is allowed.
        assert!(Bitlist::<4>::from_ssz_bytes(&[0b0001_0000]).is_ok());
    }

    /// The hash tree root is unchanged by an encode/decode round trip.
    #[test]
    fn bitlist_hash_matches_after_round_trip() {
        let bits = [true, false, true, false, true];
        let bl: Bitlist<256> = Bitlist::from_bits(&bits).unwrap();
        let h1 = bl.hash_tree_root::<Sha256>();
        let bytes = bl.as_ssz_bytes();
        let bl2 = Bitlist::<256>::from_ssz_bytes(&bytes).unwrap();
        let h2 = bl2.hash_tree_root::<Sha256>();
        assert_eq!(h1, h2);
    }
}
426}