1use alloc::vec::Vec;
13use core::fmt;
14use digest::Digest;
15use digest::typenum::U32;
16
17use crate::merkle::{merkleize, mix_in_length, pack_bytes};
18use crate::{BYTES_PER_LENGTH_OFFSET, Decode, DecodeError, Encode, HashTreeRoot};
19
/// Fixed-length sequence of exactly `N` bits, stored packed into bytes.
///
/// Bit `i` is stored in byte `i / 8` at bit position `i % 8`, counting from
/// the least-significant bit. The constructors in this file always allocate
/// `bitvec_bytes(N)` bytes of storage.
#[derive(Clone, PartialEq, Eq)]
pub struct Bitvector<const N: usize> {
    /// Packed bit storage.
    bytes: Vec<u8>,
}
33
/// Returns how many bytes are needed to hold `n` bits (`n / 8`, rounded up).
#[inline]
const fn bitvec_bytes(n: usize) -> usize {
    const BITS_PER_BYTE: usize = 8;
    n.div_ceil(BITS_PER_BYTE)
}
38
39impl<const N: usize> Default for Bitvector<N> {
40 fn default() -> Self {
41 Self {
42 bytes: alloc::vec![0u8; bitvec_bytes(N)],
43 }
44 }
45}
46
47impl<const N: usize> Bitvector<N> {
48 pub fn from_slice(bytes: &[u8]) -> Result<Self, DecodeError> {
50 let needed = bitvec_bytes(N);
51 if bytes.len() != needed {
52 return Err(DecodeError::UnexpectedEof {
53 expected: needed,
54 actual: bytes.len(),
55 });
56 }
57 validate_trailing_zero_bits(bytes, N)?;
58 Ok(Self {
59 bytes: bytes.to_vec(),
60 })
61 }
62
63 pub fn get(&self, i: usize) -> bool {
65 assert!(i < N, "bit index out of bounds");
66 (self.bytes[i / 8] >> (i % 8)) & 1 == 1
67 }
68
69 pub fn set(&mut self, i: usize, v: bool) {
71 assert!(i < N, "bit index out of bounds");
72 let mask = 1u8 << (i % 8);
73 if v {
74 self.bytes[i / 8] |= mask;
75 } else {
76 self.bytes[i / 8] &= !mask;
77 }
78 }
79
80 pub fn as_bytes(&self) -> &[u8] {
82 &self.bytes
83 }
84
85 pub const fn len(&self) -> usize {
87 N
88 }
89
90 pub const fn is_empty(&self) -> bool {
92 N == 0
93 }
94}
95
96impl<const N: usize> fmt::Debug for Bitvector<N> {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 f.debug_struct("Bitvector")
99 .field("n", &N)
100 .field("bytes", &&self.bytes[..])
101 .finish()
102 }
103}
104
105impl<const N: usize> Encode for Bitvector<N> {
106 fn is_ssz_fixed_len() -> bool {
107 true
108 }
109 fn ssz_fixed_len() -> usize {
110 bitvec_bytes(N)
111 }
112 fn ssz_bytes_len(&self) -> usize {
113 bitvec_bytes(N)
114 }
115 fn ssz_append(&self, buf: &mut Vec<u8>) {
116 buf.extend_from_slice(&self.bytes);
117 }
118}
119
120impl<const N: usize> Decode for Bitvector<N> {
121 fn is_ssz_fixed_len() -> bool {
122 true
123 }
124 fn ssz_fixed_len() -> usize {
125 bitvec_bytes(N)
126 }
127 fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
128 Self::from_slice(bytes)
129 }
130}
131
132impl<const N: usize> HashTreeRoot for Bitvector<N> {
133 fn hash_tree_root<D: Digest<OutputSize = U32>>(&self) -> [u8; 32] {
134 let chunks = pack_bytes(&self.bytes);
135 let chunk_limit = bitvec_bytes(N).div_ceil(32).max(1);
136 merkleize::<D>(&chunks, chunk_limit)
137 }
138}
139
140#[inline]
141fn validate_trailing_zero_bits(bytes: &[u8], n: usize) -> Result<(), DecodeError> {
142 if n == 0 {
143 if bytes.iter().any(|&b| b != 0) {
146 return Err(DecodeError::ExcessBits);
147 }
148 return Ok(());
149 }
150 let last_byte_bits = n % 8;
151 if last_byte_bits == 0 {
152 return Ok(());
153 }
154 let mask = !((1u8 << last_byte_bits) - 1);
155 let last = bytes[bytes.len() - 1];
156 if last & mask != 0 {
157 return Err(DecodeError::ExcessBits);
158 }
159 Ok(())
160}
161
/// Variable-length sequence of bits whose length is bounded by `N`.
///
/// The constructors in this file reject lengths greater than `N` and keep
/// `bytes` at exactly `bit_len / 8` (rounded up) bytes, with every bit at
/// index `bit_len` or above cleared.
pub struct Bitlist<const N: u64> {
    /// Packed bits, least-significant bit first within each byte.
    bytes: Vec<u8>,
    /// Number of valid bits stored in `bytes`.
    bit_len: u64,
}
175
176impl<const N: u64> Bitlist<N> {
177 pub fn new() -> Self {
179 Self {
180 bytes: Vec::new(),
181 bit_len: 0,
182 }
183 }
184
185 pub fn from_bits(bits: &[bool]) -> Result<Self, DecodeError> {
187 if (bits.len() as u64) > N {
188 return Err(DecodeError::BoundExceeded {
189 len: bits.len() as u64,
190 bound: N,
191 });
192 }
193 let byte_len = bits.len().div_ceil(8);
194 let mut bytes: Vec<u8> = alloc::vec![0u8; byte_len];
195 for (i, b) in bits.iter().enumerate() {
196 if *b {
197 bytes[i / 8] |= 1 << (i % 8);
198 }
199 }
200 Ok(Self {
201 bytes,
202 bit_len: bits.len() as u64,
203 })
204 }
205
206 pub fn len(&self) -> u64 {
208 self.bit_len
209 }
210
211 pub fn is_empty(&self) -> bool {
213 self.bit_len == 0
214 }
215
216 pub fn get(&self, i: u64) -> Option<bool> {
218 if i >= self.bit_len {
219 return None;
220 }
221 let byte = self.bytes[(i / 8) as usize];
222 Some((byte >> (i % 8)) & 1 == 1)
223 }
224
225 pub fn data_bytes(&self) -> &[u8] {
227 &self.bytes
228 }
229
230 fn wire_byte_len(&self) -> usize {
232 ((self.bit_len + 1) as usize).div_ceil(8)
233 }
234}
235
236impl<const N: u64> Default for Bitlist<N> {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
242impl<const N: u64> fmt::Debug for Bitlist<N> {
243 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
244 f.debug_struct("Bitlist")
245 .field("cap", &N)
246 .field("bit_len", &self.bit_len)
247 .field("bytes", &&self.bytes[..])
248 .finish()
249 }
250}
251
252impl<const N: u64> Clone for Bitlist<N> {
253 fn clone(&self) -> Self {
254 Self {
255 bytes: self.bytes.clone(),
256 bit_len: self.bit_len,
257 }
258 }
259}
260
261impl<const N: u64> PartialEq for Bitlist<N> {
262 fn eq(&self, other: &Self) -> bool {
263 if self.bit_len != other.bit_len {
264 return false;
265 }
266 self.bytes == other.bytes
267 }
268}
269
270impl<const N: u64> Eq for Bitlist<N> {}
271
272impl<const N: u64> Encode for Bitlist<N> {
273 fn is_ssz_fixed_len() -> bool {
274 false
275 }
276 fn ssz_fixed_len() -> usize {
277 BYTES_PER_LENGTH_OFFSET
278 }
279 fn ssz_bytes_len(&self) -> usize {
280 self.wire_byte_len()
281 }
282 fn ssz_append(&self, buf: &mut Vec<u8>) {
283 let wire_len = self.wire_byte_len();
285 let start = buf.len();
286 buf.resize(start + wire_len, 0u8);
287 for i in 0..self.bit_len {
289 if (self.bytes[(i / 8) as usize] >> (i % 8)) & 1 == 1 {
290 buf[start + (i / 8) as usize] |= 1 << (i % 8);
291 }
292 }
293 let sb = start + (self.bit_len / 8) as usize;
295 buf[sb] |= 1u8 << (self.bit_len % 8);
296 }
297}
298
299impl<const N: u64> Decode for Bitlist<N> {
300 fn is_ssz_fixed_len() -> bool {
301 false
302 }
303 fn ssz_fixed_len() -> usize {
304 BYTES_PER_LENGTH_OFFSET
305 }
306 fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
307 if bytes.is_empty() {
308 return Err(DecodeError::MissingBitlistSentinel);
309 }
310 let last = bytes[bytes.len() - 1];
311 if last == 0 {
312 return Err(DecodeError::MissingBitlistSentinel);
313 }
314 let sentinel_bit_in_byte = 7 - last.leading_zeros() as usize; let bit_len = ((bytes.len() - 1) * 8 + sentinel_bit_in_byte) as u64;
316 if bit_len > N {
317 return Err(DecodeError::BoundExceeded {
318 len: bit_len,
319 bound: N,
320 });
321 }
322
323 let mut data: Vec<u8> = Vec::with_capacity(bytes.len());
326 data.extend_from_slice(bytes);
327 if let Some(last_byte) = data.last_mut() {
328 let keep_mask = (1u8 << sentinel_bit_in_byte).wrapping_sub(1);
329 *last_byte &= keep_mask;
330 }
331 let needed_data_bytes = (bit_len as usize).div_ceil(8);
334 while data.len() > needed_data_bytes {
335 data.pop();
336 }
337 Ok(Self {
342 bytes: data,
343 bit_len,
344 })
345 }
346}
347
348impl<const N: u64> HashTreeRoot for Bitlist<N> {
349 fn hash_tree_root<D: Digest<OutputSize = U32>>(&self) -> [u8; 32] {
350 let data = &self.bytes[..];
351 let chunks = pack_bytes(data);
352 let cap_bytes = (N as usize).div_ceil(8);
353 let chunk_limit = cap_bytes.div_ceil(32).max(1);
354 let inner = merkleize::<D>(&chunks, chunk_limit);
355 mix_in_length::<D>(inner, self.bit_len)
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::Sha256;

    /// Bits written with `set` are read back by `get`; untouched bits stay clear.
    #[test]
    fn bitvector_set_get() {
        let mut bv: Bitvector<10> = Bitvector::default();
        bv.set(0, true);
        bv.set(9, true);
        assert!(bv.get(0));
        assert!(!bv.get(1));
        assert!(bv.get(9));
    }

    /// Encoding then decoding a bitvector yields an equal value.
    #[test]
    fn bitvector_round_trip() {
        let mut bv: Bitvector<10> = Bitvector::default();
        bv.set(0, true);
        bv.set(3, true);
        bv.set(9, true);
        let bytes = bv.as_ssz_bytes();
        let decoded = Bitvector::<10>::from_ssz_bytes(&bytes).unwrap();
        assert_eq!(bv, decoded);
    }

    /// A set bit at index 4 is padding for `Bitvector<4>` and must be rejected.
    #[test]
    fn bitvector_rejects_excess_bits() {
        let raw = [0b00010000u8];
        assert!(Bitvector::<4>::from_slice(&raw).is_err());
    }

    /// `Bitvector<10>` needs exactly two bytes; one byte is a length error.
    #[test]
    fn bitvector_rejects_wrong_length() {
        assert!(matches!(
            Bitvector::<10>::from_slice(&[0u8]),
            Err(DecodeError::UnexpectedEof {
                expected: 2,
                actual: 1
            })
        ));
    }

    /// An empty bitlist encodes to a lone sentinel byte and decodes back.
    #[test]
    fn bitlist_empty_round_trip() {
        let bl: Bitlist<256> = Bitlist::new();
        let bytes = bl.as_ssz_bytes();
        assert_eq!(bytes, vec![0x01]);
        let decoded = Bitlist::<256>::from_ssz_bytes(&bytes).unwrap();
        assert_eq!(bl, decoded);
        assert_eq!(decoded.len(), 0);
    }

    /// A nine-bit list (partial last byte) survives an encode/decode cycle.
    #[test]
    fn bitlist_round_trip() {
        let bits = [true, false, true, true, false, true, false, false, true];
        let bl: Bitlist<256> = Bitlist::from_bits(&bits).unwrap();
        let bytes = bl.as_ssz_bytes();
        let decoded = Bitlist::<256>::from_ssz_bytes(&bytes).unwrap();
        assert_eq!(bl, decoded);
        assert_eq!(decoded.len(), 9);
        for (i, b) in bits.iter().enumerate() {
            assert_eq!(decoded.get(i as u64), Some(*b));
        }
    }

    /// A byte-aligned list puts the sentinel in an extra byte of its own.
    #[test]
    fn bitlist_byte_aligned_round_trip() {
        let bl: Bitlist<256> = Bitlist::from_bits(&[true; 8]).unwrap();
        let bytes = bl.as_ssz_bytes();
        assert_eq!(bytes, vec![0xFF, 0x01]);
        let decoded = Bitlist::<256>::from_ssz_bytes(&bytes).unwrap();
        assert_eq!(bl, decoded);
        assert_eq!(decoded.len(), 8);
        assert_eq!(decoded.get(8), None);
    }

    /// Empty input and a zero final byte both lack the sentinel bit.
    #[test]
    fn bitlist_rejects_missing_sentinel() {
        assert!(matches!(
            Bitlist::<256>::from_ssz_bytes(&[]),
            Err(DecodeError::MissingBitlistSentinel)
        ));
        assert!(matches!(
            Bitlist::<256>::from_ssz_bytes(&[0xFF, 0x00]),
            Err(DecodeError::MissingBitlistSentinel)
        ));
    }

    /// Lengths above the bound `N` are rejected by both constructors.
    #[test]
    fn bitlist_rejects_length_over_bound() {
        assert!(matches!(
            Bitlist::<4>::from_ssz_bytes(&[0b0010_0000]),
            Err(DecodeError::BoundExceeded { len: 5, bound: 4 })
        ));
        assert!(matches!(
            Bitlist::<4>::from_bits(&[true; 5]),
            Err(DecodeError::BoundExceeded { len: 5, bound: 4 })
        ));
        assert!(Bitlist::<4>::from_bits(&[true; 4]).is_ok());
    }

    /// The hash tree root is the same before and after an encode/decode cycle.
    #[test]
    fn bitlist_hash_matches_after_round_trip() {
        let bits = [true, false, true, false, true];
        let bl: Bitlist<256> = Bitlist::from_bits(&bits).unwrap();
        let h1 = bl.hash_tree_root::<Sha256>();
        let bytes = bl.as_ssz_bytes();
        let bl2 = Bitlist::<256>::from_ssz_bytes(&bytes).unwrap();
        let h2 = bl2.hash_tree_root::<Sha256>();
        assert_eq!(h1, h2);
    }
}