pibd: bound segment decoding allocations (#3850)

This commit is contained in:
David Burkett
2026-06-10 12:02:01 -04:00
committed by GitHub
parent db2923a147
commit 62e5ace442
5 changed files with 250 additions and 75 deletions
+115 -34
View File
@@ -280,6 +280,102 @@ pub struct BitmapSegment {
proof: SegmentProof,
}
impl BitmapSegment {
// Matches the upper end of the currently served PIBD bitmap segment range.
const MAX_SEGMENT_HEIGHT: u8 = 13;
fn max_chunks(identifier: &SegmentIdentifier) -> Result<usize, ser::Error> {
if identifier.height > Self::MAX_SEGMENT_HEIGHT {
return Err(ser::Error::TooLargeReadErr);
}
1usize
.checked_shl(identifier.height as u32)
.ok_or(ser::Error::TooLargeReadErr)
}
fn leaf_offset(identifier: &SegmentIdentifier) -> Result<u64, ser::Error> {
let segment_capacity = 1u64
.checked_shl(identifier.height as u32)
.ok_or(ser::Error::TooLargeReadErr)?;
segment_capacity
.checked_mul(identifier.idx)
.ok_or(ser::Error::TooLargeReadErr)
}
fn n_chunks(blocks: &[BitmapBlock]) -> Result<usize, ser::Error> {
let (last, full_blocks) = blocks.split_last().ok_or(ser::Error::CorruptedData)?;
for block in full_blocks {
if block.try_n_chunks()? != BitmapBlock::NCHUNKS {
return Err(ser::Error::CorruptedData);
}
}
let last_chunks = last.try_n_chunks()?;
if last_chunks == 0 {
return Err(ser::Error::CorruptedData);
}
full_blocks
.len()
.checked_mul(BitmapBlock::NCHUNKS)
.and_then(|n| n.checked_add(last_chunks))
.ok_or(ser::Error::TooLargeReadErr)
}
fn validate_blocks(
identifier: &SegmentIdentifier,
blocks: &[BitmapBlock],
) -> Result<usize, ser::Error> {
let offset = Self::leaf_offset(identifier)?;
let n_chunks = Self::n_chunks(blocks)?;
if n_chunks > Self::max_chunks(identifier)? {
return Err(ser::Error::TooLargeReadErr);
}
offset
.checked_add((n_chunks - 1) as u64)
.ok_or(ser::Error::TooLargeReadErr)?;
Ok(n_chunks)
}
/// Convert this bitmap segment into a PMMR segment, validating its encoded shape.
pub fn into_segment(self) -> Result<Segment<BitmapChunk>, ser::Error> {
let BitmapSegment {
identifier,
blocks,
proof,
} = self;
let n_chunks = Self::validate_blocks(&identifier, &blocks)?;
let mut leaf_pos = Vec::with_capacity(n_chunks);
let mut chunks = Vec::with_capacity(n_chunks);
let offset = Self::leaf_offset(&identifier)?;
for i in 0..(n_chunks as u64) {
let insertion_idx = offset.checked_add(i).ok_or(ser::Error::TooLargeReadErr)?;
leaf_pos.push(pmmr::insertion_to_pmmr_index(insertion_idx));
chunks.push(BitmapChunk::new());
}
for (block_idx, block) in blocks.into_iter().enumerate() {
block.try_n_chunks()?;
let offset = block_idx * BitmapBlock::NCHUNKS;
for (i, _) in block.inner.iter().enumerate().filter(|&(_, v)| v) {
chunks
.get_mut(offset + i / BitmapChunk::LEN_BITS)
.ok_or(ser::Error::CorruptedData)?
.0
.set(i % BitmapChunk::LEN_BITS, true);
}
}
Ok(Segment::from_parts(
identifier,
Vec::new(),
Vec::new(),
leaf_pos,
chunks,
proof,
))
}
}
impl Writeable for BitmapSegment {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ser::Error> {
Writeable::write(&self.identifier, writer)?;
@@ -297,10 +393,20 @@ impl Readable for BitmapSegment {
let identifier: SegmentIdentifier = Readable::read(reader)?;
let n_blocks = reader.read_u16()? as usize;
if n_blocks == 0 {
return Err(ser::Error::CorruptedData);
}
let max_blocks = (BitmapSegment::max_chunks(&identifier)? + BitmapBlock::NCHUNKS - 1)
/ BitmapBlock::NCHUNKS;
if n_blocks > max_blocks {
return Err(ser::Error::TooLargeReadErr);
}
BitmapSegment::leaf_offset(&identifier)?;
let mut blocks = Vec::<BitmapBlock>::with_capacity(n_blocks);
for _ in 0..n_blocks {
blocks.push(Readable::read(reader)?);
}
BitmapSegment::validate_blocks(&identifier, &blocks)?;
let proof = Readable::read(reader)?;
Ok(Self {
@@ -348,36 +454,7 @@ impl From<Segment<BitmapChunk>> for BitmapSegment {
// TODO: this can be sped up with some `unsafe` code
impl From<BitmapSegment> for Segment<BitmapChunk> {
fn from(segment: BitmapSegment) -> Self {
let BitmapSegment {
identifier,
blocks,
proof,
} = segment;
// Count the number of chunks taking into account that the final block might be smaller
let n_chunks = (blocks.len() - 1) * BitmapBlock::NCHUNKS
+ blocks.last().map(|b| b.n_chunks()).unwrap_or(0);
let mut leaf_pos = Vec::with_capacity(n_chunks);
let mut chunks = Vec::with_capacity(n_chunks);
let offset = (1 << identifier.height) * identifier.idx;
for i in 0..(n_chunks as u64) {
leaf_pos.push(pmmr::insertion_to_pmmr_index(offset + i));
chunks.push(BitmapChunk::new());
}
for (block_idx, block) in blocks.into_iter().enumerate() {
assert!(block.inner.len() <= BitmapBlock::NBITS as usize);
let offset = block_idx * BitmapBlock::NCHUNKS;
for (i, _) in block.inner.iter().enumerate().filter(|&(_, v)| v) {
chunks
.get_mut(offset + i / BitmapChunk::LEN_BITS)
.unwrap()
.0
.set(i % BitmapChunk::LEN_BITS, true);
}
}
Segment::from_parts(identifier, Vec::new(), Vec::new(), leaf_pos, chunks, proof)
segment.into_segment().expect("valid bitmap segment")
}
}
@@ -401,12 +478,16 @@ impl BitmapBlock {
}
}
fn n_chunks(&self) -> usize {
fn try_n_chunks(&self) -> Result<usize, ser::Error> {
let length = self.inner.len();
assert_eq!(length % BitmapChunk::LEN_BITS, 0);
if length % BitmapChunk::LEN_BITS != 0 {
return Err(ser::Error::CorruptedData);
}
let n_chunks = length / BitmapChunk::LEN_BITS;
assert!(n_chunks <= BitmapBlock::NCHUNKS);
n_chunks
if n_chunks > BitmapBlock::NCHUNKS {
return Err(ser::Error::TooLargeReadErr);
}
Ok(n_chunks)
}
}
+61 -2
View File
@@ -1,7 +1,7 @@
use self::chain::txhashset::{BitmapAccumulator, BitmapSegment};
use self::core::core::pmmr::segment::{Segment, SegmentIdentifier};
use self::core::ser::{
BinReader, BinWriter, DeserializationMode, ProtocolVersion, Readable, Writeable,
self, BinReader, BinWriter, DeserializationMode, ProtocolVersion, Readable, Writeable,
};
use croaring::Bitmap;
use grin_chain as chain;
@@ -10,6 +10,29 @@ use grin_util::secp::rand::Rng;
use rand::thread_rng;
use std::io::Cursor;
fn push_u16(bytes: &mut Vec<u8>, n: u16) {
bytes.extend_from_slice(&n.to_be_bytes());
}
fn push_u64(bytes: &mut Vec<u8>, n: u64) {
bytes.extend_from_slice(&n.to_be_bytes());
}
fn bitmap_segment_header(height: u8, idx: u64, n_blocks: u16) -> Vec<u8> {
let mut bytes = vec![height];
push_u64(&mut bytes, idx);
push_u16(&mut bytes, n_blocks);
bytes
}
fn read_bitmap_segment(bytes: &[u8]) -> Result<BitmapSegment, ser::Error> {
ser::deserialize(
&mut &bytes[..],
ProtocolVersion(1),
DeserializationMode::default(),
)
}
fn test_roundtrip(entries: usize) {
let mut rng = thread_rng();
@@ -63,7 +86,7 @@ fn test_roundtrip(entries: usize) {
assert_eq!(bms, bms2);
// Convert back to `Segment`
let segment2 = Segment::from(bms2);
let segment2 = bms2.into_segment().unwrap();
assert_eq!(segment, segment2);
}
@@ -83,3 +106,39 @@ fn abundant_segment_ser_roundtrip() {
let max = 1 << 16;
test_roundtrip(thread_rng().gen_range(max - 4096, max - 1024));
}
#[test]
fn bitmap_segment_read_rejects_empty_blocks() {
let bytes = bitmap_segment_header(9, 0, 0);
assert_eq!(
read_bitmap_segment(&bytes).err(),
Some(ser::Error::CorruptedData)
);
}
#[test]
fn bitmap_segment_read_rejects_too_many_blocks() {
let bytes = bitmap_segment_header(9, 0, 9);
assert_eq!(
read_bitmap_segment(&bytes).err(),
Some(ser::Error::TooLargeReadErr)
);
}
#[test]
fn bitmap_segment_read_rejects_too_large_height() {
let bytes = bitmap_segment_header(14, 0, 1);
assert_eq!(
read_bitmap_segment(&bytes).err(),
Some(ser::Error::TooLargeReadErr)
);
}
#[test]
fn bitmap_segment_read_rejects_offset_overflow() {
let bytes = bitmap_segment_header(13, u64::MAX, 1);
assert_eq!(
read_bitmap_segment(&bytes).err(),
Some(ser::Error::TooLargeReadErr)
);
}
+41 -38
View File
@@ -21,6 +21,39 @@ use croaring::Bitmap;
use std::cmp::min;
use std::fmt::Debug;
const MAX_SEGMENT_READ_ITEMS: u64 = 1_000_000;
const SEGMENT_READ_PREALLOC_ITEMS: u64 = 1024;
fn read_segment_item_count<R: Reader>(reader: &mut R) -> Result<u64, Error> {
let count = reader.read_u64()?;
if count > MAX_SEGMENT_READ_ITEMS {
return Err(Error::TooLargeReadErr);
}
Ok(count)
}
fn read_segment_positions<R: Reader>(reader: &mut R, count: u64) -> Result<Vec<u64>, Error> {
let mut positions = Vec::with_capacity(min(count, SEGMENT_READ_PREALLOC_ITEMS) as usize);
let mut last_pos = 0;
for _ in 0..count {
let pos = reader.read_u64()?;
if pos <= last_pos {
return Err(Error::SortError);
}
last_pos = pos;
positions.push(pos - 1);
}
Ok(positions)
}
fn read_segment_items<T: Readable, R: Reader>(reader: &mut R, count: u64) -> Result<Vec<T>, Error> {
let mut items = Vec::with_capacity(min(count, SEGMENT_READ_PREALLOC_ITEMS) as usize);
for _ in 0..count {
items.push(T::read(reader)?);
}
Ok(items)
}
#[derive(Clone, Debug, Eq, PartialEq)]
/// Possible segment types, according to this desegmenter
pub enum SegmentType {
@@ -568,39 +601,13 @@ impl<T: Readable> Readable for Segment<T> {
fn read<R: Reader>(reader: &mut R) -> Result<Self, Error> {
let identifier = Readable::read(reader)?;
let n_hashes = reader.read_u64()? as usize;
let mut hash_pos = Vec::with_capacity(n_hashes);
let mut last_pos = 0;
for _ in 0..n_hashes {
let pos = reader.read_u64()?;
if pos <= last_pos {
return Err(Error::SortError);
}
last_pos = pos;
hash_pos.push(pos - 1);
}
let n_hashes = read_segment_item_count(reader)?;
let hash_pos = read_segment_positions(reader, n_hashes)?;
let hashes = read_segment_items(reader, n_hashes)?;
let mut hashes = Vec::<Hash>::with_capacity(n_hashes);
for _ in 0..n_hashes {
hashes.push(Readable::read(reader)?);
}
let n_leaves = reader.read_u64()? as usize;
let mut leaf_pos = Vec::with_capacity(n_leaves);
last_pos = 0;
for _ in 0..n_leaves {
let pos = reader.read_u64()?;
if pos <= last_pos {
return Err(Error::SortError);
}
last_pos = pos;
leaf_pos.push(pos - 1);
}
let mut leaf_data = Vec::<T>::with_capacity(n_leaves);
for _ in 0..n_leaves {
leaf_data.push(Readable::read(reader)?);
}
let n_leaves = read_segment_item_count(reader)?;
let leaf_pos = read_segment_positions(reader, n_leaves)?;
let leaf_data = read_segment_items(reader, n_leaves)?;
let proof = Readable::read(reader)?;
@@ -823,12 +830,8 @@ impl SegmentProof {
impl Readable for SegmentProof {
fn read<R: Reader>(reader: &mut R) -> Result<Self, Error> {
let n_hashes = reader.read_u64()? as usize;
let mut hashes = Vec::with_capacity(n_hashes);
for _ in 0..n_hashes {
let hash: Hash = Readable::read(reader)?;
hashes.push(hash);
}
let n_hashes = read_segment_item_count(reader)?;
let hashes = read_segment_items(reader, n_hashes)?;
Ok(Self { hashes })
}
}
+32
View File
@@ -16,10 +16,15 @@ mod common;
use self::core::core::pmmr;
use self::core::core::{Segment, SegmentIdentifier};
use self::core::ser::{self, DeserializationMode, ProtocolVersion};
use common::TestElem;
use grin_core as core;
use grin_core::core::pmmr::ReadablePMMR;
fn push_u64(bytes: &mut Vec<u8>, n: u64) {
bytes.extend_from_slice(&n.to_be_bytes());
}
fn test_unprunable_size(height: u8, n_leaves: u32) {
let size = 1u64 << height;
let n_segments = (n_leaves as u64 + size - 1) / size;
@@ -59,3 +64,30 @@ fn unprunable_mmr() {
test_unprunable_size(3, i);
}
}
#[test]
fn segment_read_rejects_large_hash_count() {
let mut bytes = vec![1];
push_u64(&mut bytes, 0);
push_u64(&mut bytes, 1_000_001);
let res: Result<Segment<TestElem>, _> = ser::deserialize(
&mut &bytes[..],
ProtocolVersion(1),
DeserializationMode::default(),
);
assert_eq!(res.err(), Some(ser::Error::TooLargeReadErr));
}
#[test]
fn segment_proof_read_rejects_large_hash_count() {
let mut bytes = vec![];
push_u64(&mut bytes, 1_000_001);
let res: Result<self::core::core::SegmentProof, _> = ser::deserialize(
&mut &bytes[..],
ProtocolVersion(1),
DeserializationMode::default(),
);
assert_eq!(res.err(), Some(ser::Error::TooLargeReadErr));
}
+1 -1
View File
@@ -382,7 +382,7 @@ impl MessageHandler for Protocol {
block_hash,
output_root
);
adapter.receive_bitmap_segment(block_hash, output_root, segment.into())?;
adapter.receive_bitmap_segment(block_hash, output_root, segment.into_segment()?)?;
Consumed::None
}
Message::OutputSegment(req) => {