diff --git a/chain/src/txhashset/bitmap_accumulator.rs b/chain/src/txhashset/bitmap_accumulator.rs index cae45f7d..6d6b388d 100644 --- a/chain/src/txhashset/bitmap_accumulator.rs +++ b/chain/src/txhashset/bitmap_accumulator.rs @@ -280,6 +280,102 @@ pub struct BitmapSegment { proof: SegmentProof, } +impl BitmapSegment { + // Matches the upper end of the currently served PIBD bitmap segment range. + const MAX_SEGMENT_HEIGHT: u8 = 13; + + fn max_chunks(identifier: &SegmentIdentifier) -> Result { + if identifier.height > Self::MAX_SEGMENT_HEIGHT { + return Err(ser::Error::TooLargeReadErr); + } + 1usize + .checked_shl(identifier.height as u32) + .ok_or(ser::Error::TooLargeReadErr) + } + + fn leaf_offset(identifier: &SegmentIdentifier) -> Result { + let segment_capacity = 1u64 + .checked_shl(identifier.height as u32) + .ok_or(ser::Error::TooLargeReadErr)?; + segment_capacity + .checked_mul(identifier.idx) + .ok_or(ser::Error::TooLargeReadErr) + } + + fn n_chunks(blocks: &[BitmapBlock]) -> Result { + let (last, full_blocks) = blocks.split_last().ok_or(ser::Error::CorruptedData)?; + for block in full_blocks { + if block.try_n_chunks()? != BitmapBlock::NCHUNKS { + return Err(ser::Error::CorruptedData); + } + } + let last_chunks = last.try_n_chunks()?; + if last_chunks == 0 { + return Err(ser::Error::CorruptedData); + } + full_blocks + .len() + .checked_mul(BitmapBlock::NCHUNKS) + .and_then(|n| n.checked_add(last_chunks)) + .ok_or(ser::Error::TooLargeReadErr) + } + + fn validate_blocks( + identifier: &SegmentIdentifier, + blocks: &[BitmapBlock], + ) -> Result { + let offset = Self::leaf_offset(identifier)?; + let n_chunks = Self::n_chunks(blocks)?; + if n_chunks > Self::max_chunks(identifier)? { + return Err(ser::Error::TooLargeReadErr); + } + offset + .checked_add((n_chunks - 1) as u64) + .ok_or(ser::Error::TooLargeReadErr)?; + Ok(n_chunks) + } + + /// Convert this bitmap segment into a PMMR segment, validating its encoded shape. + pub fn into_segment(self) -> Result, ser::Error> { + let BitmapSegment { + identifier, + blocks, + proof, + } = self; + + let n_chunks = Self::validate_blocks(&identifier, &blocks)?; + let mut leaf_pos = Vec::with_capacity(n_chunks); + let mut chunks = Vec::with_capacity(n_chunks); + let offset = Self::leaf_offset(&identifier)?; + for i in 0..(n_chunks as u64) { + let insertion_idx = offset.checked_add(i).ok_or(ser::Error::TooLargeReadErr)?; + leaf_pos.push(pmmr::insertion_to_pmmr_index(insertion_idx)); + chunks.push(BitmapChunk::new()); + } + + for (block_idx, block) in blocks.into_iter().enumerate() { + block.try_n_chunks()?; + let offset = block_idx * BitmapBlock::NCHUNKS; + for (i, _) in block.inner.iter().enumerate().filter(|&(_, v)| v) { + chunks + .get_mut(offset + i / BitmapChunk::LEN_BITS) + .ok_or(ser::Error::CorruptedData)? + .0 + .set(i % BitmapChunk::LEN_BITS, true); + } + } + + Ok(Segment::from_parts( + identifier, + Vec::new(), + Vec::new(), + leaf_pos, + chunks, + proof, + )) + } +} + impl Writeable for BitmapSegment { fn write(&self, writer: &mut W) -> Result<(), ser::Error> { Writeable::write(&self.identifier, writer)?; @@ -297,10 +393,20 @@ impl Readable for BitmapSegment { let identifier: SegmentIdentifier = Readable::read(reader)?; let n_blocks = reader.read_u16()? as usize; + if n_blocks == 0 { + return Err(ser::Error::CorruptedData); + } + let max_blocks = (BitmapSegment::max_chunks(&identifier)? + BitmapBlock::NCHUNKS - 1) + / BitmapBlock::NCHUNKS; + if n_blocks > max_blocks { + return Err(ser::Error::TooLargeReadErr); + } + BitmapSegment::leaf_offset(&identifier)?; let mut blocks = Vec::::with_capacity(n_blocks); for _ in 0..n_blocks { blocks.push(Readable::read(reader)?); } + BitmapSegment::validate_blocks(&identifier, &blocks)?; let proof = Readable::read(reader)?; Ok(Self { @@ -348,36 +454,7 @@ impl From> for BitmapSegment { // TODO: this can be sped up with some `unsafe` code impl From for Segment { fn from(segment: BitmapSegment) -> Self { - let BitmapSegment { - identifier, - blocks, - proof, - } = segment; - - // Count the number of chunks taking into account that the final block might be smaller - let n_chunks = (blocks.len() - 1) * BitmapBlock::NCHUNKS - + blocks.last().map(|b| b.n_chunks()).unwrap_or(0); - let mut leaf_pos = Vec::with_capacity(n_chunks); - let mut chunks = Vec::with_capacity(n_chunks); - let offset = (1 << identifier.height) * identifier.idx; - for i in 0..(n_chunks as u64) { - leaf_pos.push(pmmr::insertion_to_pmmr_index(offset + i)); - chunks.push(BitmapChunk::new()); - } - - for (block_idx, block) in blocks.into_iter().enumerate() { - assert!(block.inner.len() <= BitmapBlock::NBITS as usize); - let offset = block_idx * BitmapBlock::NCHUNKS; - for (i, _) in block.inner.iter().enumerate().filter(|&(_, v)| v) { - chunks - .get_mut(offset + i / BitmapChunk::LEN_BITS) - .unwrap() - .0 - .set(i % BitmapChunk::LEN_BITS, true); - } - } - - Segment::from_parts(identifier, Vec::new(), Vec::new(), leaf_pos, chunks, proof) + segment.into_segment().expect("valid bitmap segment") } } @@ -401,12 +478,16 @@ impl BitmapBlock { } } - fn n_chunks(&self) -> usize { + fn try_n_chunks(&self) -> Result { let length = self.inner.len(); - assert_eq!(length % BitmapChunk::LEN_BITS, 0); + if length % BitmapChunk::LEN_BITS != 0 { + return Err(ser::Error::CorruptedData); + } let n_chunks = length / BitmapChunk::LEN_BITS; - assert!(n_chunks <= BitmapBlock::NCHUNKS); - n_chunks + if n_chunks > BitmapBlock::NCHUNKS { + return Err(ser::Error::TooLargeReadErr); + } + Ok(n_chunks) } } diff --git a/chain/tests/bitmap_segment.rs b/chain/tests/bitmap_segment.rs index 960625ba..39092b8e 100644 --- a/chain/tests/bitmap_segment.rs +++ b/chain/tests/bitmap_segment.rs @@ -1,7 +1,7 @@ use self::chain::txhashset::{BitmapAccumulator, BitmapSegment}; use self::core::core::pmmr::segment::{Segment, SegmentIdentifier}; use self::core::ser::{ - BinReader, BinWriter, DeserializationMode, ProtocolVersion, Readable, Writeable, + self, BinReader, BinWriter, DeserializationMode, ProtocolVersion, Readable, Writeable, }; use croaring::Bitmap; use grin_chain as chain; @@ -10,6 +10,29 @@ use grin_util::secp::rand::Rng; use rand::thread_rng; use std::io::Cursor; +fn push_u16(bytes: &mut Vec, n: u16) { + bytes.extend_from_slice(&n.to_be_bytes()); +} + +fn push_u64(bytes: &mut Vec, n: u64) { + bytes.extend_from_slice(&n.to_be_bytes()); +} + +fn bitmap_segment_header(height: u8, idx: u64, n_blocks: u16) -> Vec { + let mut bytes = vec![height]; + push_u64(&mut bytes, idx); + push_u16(&mut bytes, n_blocks); + bytes +} + +fn read_bitmap_segment(bytes: &[u8]) -> Result { + ser::deserialize( + &mut &bytes[..], + ProtocolVersion(1), + DeserializationMode::default(), + ) +} + fn test_roundtrip(entries: usize) { let mut rng = thread_rng(); @@ -63,7 +86,7 @@ fn test_roundtrip(entries: usize) { assert_eq!(bms, bms2); // Convert back to `Segment` - let segment2 = Segment::from(bms2); + let segment2 = bms2.into_segment().unwrap(); assert_eq!(segment, segment2); } @@ -83,3 +106,39 @@ fn abundant_segment_ser_roundtrip() { let max = 1 << 16; test_roundtrip(thread_rng().gen_range(max - 4096, max - 1024)); } + +#[test] +fn bitmap_segment_read_rejects_empty_blocks() { + let bytes = bitmap_segment_header(9, 0, 0); + assert_eq!( + read_bitmap_segment(&bytes).err(), + Some(ser::Error::CorruptedData) + ); +} + +#[test] +fn bitmap_segment_read_rejects_too_many_blocks() { + let bytes = bitmap_segment_header(9, 0, 9); + assert_eq!( + read_bitmap_segment(&bytes).err(), + Some(ser::Error::TooLargeReadErr) + ); +} + +#[test] +fn bitmap_segment_read_rejects_too_large_height() { + let bytes = bitmap_segment_header(14, 0, 1); + assert_eq!( + read_bitmap_segment(&bytes).err(), + Some(ser::Error::TooLargeReadErr) + ); +} + +#[test] +fn bitmap_segment_read_rejects_offset_overflow() { + let bytes = bitmap_segment_header(13, u64::MAX, 1); + assert_eq!( + read_bitmap_segment(&bytes).err(), + Some(ser::Error::TooLargeReadErr) + ); +} diff --git a/core/src/core/pmmr/segment.rs b/core/src/core/pmmr/segment.rs index 8df8b8f9..f9f555bd 100644 --- a/core/src/core/pmmr/segment.rs +++ b/core/src/core/pmmr/segment.rs @@ -21,6 +21,39 @@ use croaring::Bitmap; use std::cmp::min; use std::fmt::Debug; +const MAX_SEGMENT_READ_ITEMS: u64 = 1_000_000; +const SEGMENT_READ_PREALLOC_ITEMS: u64 = 1024; + +fn read_segment_item_count(reader: &mut R) -> Result { + let count = reader.read_u64()?; + if count > MAX_SEGMENT_READ_ITEMS { + return Err(Error::TooLargeReadErr); + } + Ok(count) +} + +fn read_segment_positions(reader: &mut R, count: u64) -> Result, Error> { + let mut positions = Vec::with_capacity(min(count, SEGMENT_READ_PREALLOC_ITEMS) as usize); + let mut last_pos = 0; + for _ in 0..count { + let pos = reader.read_u64()?; + if pos <= last_pos { + return Err(Error::SortError); + } + last_pos = pos; + positions.push(pos - 1); + } + Ok(positions) +} + +fn read_segment_items(reader: &mut R, count: u64) -> Result, Error> { + let mut items = Vec::with_capacity(min(count, SEGMENT_READ_PREALLOC_ITEMS) as usize); + for _ in 0..count { + items.push(T::read(reader)?); + } + Ok(items) +} + #[derive(Clone, Debug, Eq, PartialEq)] /// Possible segment types, according to this desegmenter pub enum SegmentType { @@ -568,39 +601,13 @@ impl Readable for Segment { fn read(reader: &mut R) -> Result { let identifier = Readable::read(reader)?; - let n_hashes = reader.read_u64()? as usize; - let mut hash_pos = Vec::with_capacity(n_hashes); - let mut last_pos = 0; - for _ in 0..n_hashes { - let pos = reader.read_u64()?; - if pos <= last_pos { - return Err(Error::SortError); - } - last_pos = pos; - hash_pos.push(pos - 1); - } + let n_hashes = read_segment_item_count(reader)?; + let hash_pos = read_segment_positions(reader, n_hashes)?; + let hashes = read_segment_items(reader, n_hashes)?; - let mut hashes = Vec::::with_capacity(n_hashes); - for _ in 0..n_hashes { - hashes.push(Readable::read(reader)?); - } - - let n_leaves = reader.read_u64()? as usize; - let mut leaf_pos = Vec::with_capacity(n_leaves); - last_pos = 0; - for _ in 0..n_leaves { - let pos = reader.read_u64()?; - if pos <= last_pos { - return Err(Error::SortError); - } - last_pos = pos; - leaf_pos.push(pos - 1); - } - - let mut leaf_data = Vec::::with_capacity(n_leaves); - for _ in 0..n_leaves { - leaf_data.push(Readable::read(reader)?); - } + let n_leaves = read_segment_item_count(reader)?; + let leaf_pos = read_segment_positions(reader, n_leaves)?; + let leaf_data = read_segment_items(reader, n_leaves)?; let proof = Readable::read(reader)?; @@ -823,12 +830,8 @@ impl SegmentProof { impl Readable for SegmentProof { fn read(reader: &mut R) -> Result { - let n_hashes = reader.read_u64()? as usize; - let mut hashes = Vec::with_capacity(n_hashes); - for _ in 0..n_hashes { - let hash: Hash = Readable::read(reader)?; - hashes.push(hash); - } + let n_hashes = read_segment_item_count(reader)?; + let hashes = read_segment_items(reader, n_hashes)?; Ok(Self { hashes }) } } diff --git a/core/tests/segment.rs b/core/tests/segment.rs index cb81402f..3ab04f7a 100644 --- a/core/tests/segment.rs +++ b/core/tests/segment.rs @@ -16,10 +16,15 @@ mod common; use self::core::core::pmmr; use self::core::core::{Segment, SegmentIdentifier}; +use self::core::ser::{self, DeserializationMode, ProtocolVersion}; use common::TestElem; use grin_core as core; use grin_core::core::pmmr::ReadablePMMR; +fn push_u64(bytes: &mut Vec, n: u64) { + bytes.extend_from_slice(&n.to_be_bytes()); +} + fn test_unprunable_size(height: u8, n_leaves: u32) { let size = 1u64 << height; let n_segments = (n_leaves as u64 + size - 1) / size; @@ -59,3 +64,30 @@ fn unprunable_mmr() { test_unprunable_size(3, i); } } + +#[test] +fn segment_read_rejects_large_hash_count() { + let mut bytes = vec![1]; + push_u64(&mut bytes, 0); + push_u64(&mut bytes, 1_000_001); + + let res: Result, _> = ser::deserialize( + &mut &bytes[..], + ProtocolVersion(1), + DeserializationMode::default(), + ); + assert_eq!(res.err(), Some(ser::Error::TooLargeReadErr)); +} + +#[test] +fn segment_proof_read_rejects_large_hash_count() { + let mut bytes = vec![]; + push_u64(&mut bytes, 1_000_001); + + let res: Result = ser::deserialize( + &mut &bytes[..], + ProtocolVersion(1), + DeserializationMode::default(), + ); + assert_eq!(res.err(), Some(ser::Error::TooLargeReadErr)); +} diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index 31e06932..2699b304 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -382,7 +382,7 @@ impl MessageHandler for Protocol { block_hash, output_root ); - adapter.receive_bitmap_segment(block_hash, output_root, segment.into())?; + adapter.receive_bitmap_segment(block_hash, output_root, segment.into_segment()?)?; Consumed::None } Message::OutputSegment(req) => {