Chore/bugfixes (#6783)

* added unit tests for MemoryEcachTicketbookManager

* bugfix: propagate socks5 proxy errors instead of panicking

* introduce guard against providing too short verification keyduring signature validation

* add checked overflow checks for icmp packet construction

* fix kcp loggin

* forbid construction of illegal sphninx fragments

* fix division by zero in packet statistics calculations
This commit is contained in:
Jędrzej Stuczyński
2026-05-22 15:30:24 +01:00
committed by GitHub
parent d2833c76c0
commit 6b0a904d10
10 changed files with 563 additions and 38 deletions
@@ -230,8 +230,8 @@ impl MemoryEcachTicketbookManager {
expiration_date: t.ticketbook.expiration_date(),
ticketbook_type: t.ticketbook.ticketbook_type().to_string(),
epoch_id: t.ticketbook.epoch_id() as u32,
total_tickets: t.ticketbook.spent_tickets() as u32,
used_tickets: t.total_tickets,
total_tickets: t.total_tickets,
used_tickets: t.ticketbook.spent_tickets() as u32,
})
.collect()
}
@@ -333,3 +333,339 @@ impl MemoryEcachTicketbookManager {
guard.emergency_credentials.remove(typ);
}
}
#[cfg(test)]
mod tests {
use super::*;
use nym_compact_ecash::tests::helpers::generate_expiration_date_signatures;
use nym_compact_ecash::{issue, ttp_keygen};
use nym_credentials_interface::TicketType;
use nym_crypto::asymmetric::ed25519;
use nym_ecash_time::EcashTime;
use nym_test_utils::helpers::deterministic_rng;
fn mock_issuance(deposit_id: u32) -> IssuanceTicketBook {
let identifier = "foomp";
let mut rng = deterministic_rng();
let key = ed25519::PrivateKey::new(&mut rng);
let typ = TicketType::V1MixnetEntry;
IssuanceTicketBook::new(deposit_id, identifier, key, typ)
}
fn mock_ticketbook() -> anyhow::Result<IssuedTicketBook> {
let signing_keys = ttp_keygen(1, 1)?.remove(0);
let issuance = mock_issuance(42);
let expiration_date = issuance.expiration_date();
let sig_req = issuance.prepare_for_signing();
let _exp_date_sigs = generate_expiration_date_signatures(
sig_req.expiration_date.ecash_unix_timestamp(),
&[signing_keys.secret_key()],
&[signing_keys.verification_key()],
&signing_keys.verification_key(),
&[1],
)?;
let blind_sig = issue(
signing_keys.secret_key(),
sig_req.ecash_pub_key,
&sig_req.withdrawal_request,
expiration_date.ecash_unix_timestamp(),
issuance.ticketbook_type().encode(),
)?;
let partial_wallet =
issuance.unblind_signature(&signing_keys.verification_key(), &sig_req, blind_sig, 1)?;
let wallet = issuance.aggregate_signature_shares(
&signing_keys.verification_key(),
&[partial_wallet],
sig_req,
)?;
Ok(issuance.into_issued_ticketbook(wallet, 1))
}
fn mock_verification_key() -> VerificationKeyAuth {
ttp_keygen(1, 1).unwrap().remove(0).verification_key()
}
#[tokio::test]
async fn get_ticketbooks_info_empty() {
let manager = MemoryEcachTicketbookManager::new();
let info = manager.get_ticketbooks_info().await;
assert!(info.is_empty());
}
#[tokio::test]
async fn get_ticketbooks_info_maps_inserted_ticketbook() -> anyhow::Result<()> {
let manager = MemoryEcachTicketbookManager::new();
let ticketbook = mock_ticketbook()?;
let total_tickets = 100;
let used_tickets = 25;
manager
.insert_new_ticketbook(&ticketbook, total_tickets, used_tickets)
.await;
let info = manager.get_ticketbooks_info().await;
assert_eq!(info.len(), 1);
let entry = &info[0];
assert_eq!(entry.id, 0);
assert_eq!(entry.expiration_date, ticketbook.expiration_date());
assert_eq!(
entry.ticketbook_type,
ticketbook.ticketbook_type().to_string()
);
assert_eq!(entry.epoch_id, ticketbook.epoch_id() as u32);
assert_eq!(entry.total_tickets, total_tickets);
assert_eq!(entry.used_tickets, used_tickets);
Ok(())
}
#[tokio::test]
async fn contains_ticketbook_reflects_insertion() -> anyhow::Result<()> {
let manager = MemoryEcachTicketbookManager::new();
let ticketbook = mock_ticketbook()?;
assert!(!manager.contains_ticketbook(&ticketbook).await);
manager.insert_new_ticketbook(&ticketbook, 100, 0).await;
assert!(manager.contains_ticketbook(&ticketbook).await);
Ok(())
}
#[tokio::test]
async fn insert_new_ticketbook_assigns_incrementing_ids() -> anyhow::Result<()> {
let manager = MemoryEcachTicketbookManager::new();
let ticketbook = mock_ticketbook()?;
manager.insert_new_ticketbook(&ticketbook, 100, 0).await;
manager.insert_new_ticketbook(&ticketbook, 100, 0).await;
let mut ids: Vec<i64> = manager
.get_ticketbooks_info()
.await
.into_iter()
.map(|i| i.id)
.collect();
ids.sort();
assert_eq!(ids, vec![0, 1]);
Ok(())
}
#[tokio::test]
async fn get_next_unspent_ticketbook_updates_spent_and_exhausts() -> anyhow::Result<()> {
let manager = MemoryEcachTicketbookManager::new();
let ticketbook = mock_ticketbook()?;
let typ = ticketbook.ticketbook_type().to_string();
// total = 3, used = 0 — leaves 3 tickets available
manager.insert_new_ticketbook(&ticketbook, 3, 0).await;
let first = manager
.get_next_unspent_ticketbook_and_update(typ.clone(), 2)
.await;
assert!(first.is_some());
let first = first.unwrap();
assert_eq!(first.total_tickets, 3);
// returned ticketbook reflects state *before* the update
assert_eq!(first.ticketbook.spent_tickets(), 0);
// next withdrawal of 2 should be rejected (only 1 left)
let second = manager
.get_next_unspent_ticketbook_and_update(typ.clone(), 2)
.await;
assert!(second.is_none());
// but a withdrawal of 1 succeeds
let third = manager
.get_next_unspent_ticketbook_and_update(typ.clone(), 1)
.await;
assert!(third.is_some());
// and now nothing left
let fourth = manager.get_next_unspent_ticketbook_and_update(typ, 1).await;
assert!(fourth.is_none());
Ok(())
}
#[tokio::test]
async fn get_next_unspent_ticketbook_filters_by_type() -> anyhow::Result<()> {
let manager = MemoryEcachTicketbookManager::new();
let ticketbook = mock_ticketbook()?;
manager.insert_new_ticketbook(&ticketbook, 5, 0).await;
let mismatched = manager
.get_next_unspent_ticketbook_and_update("nonexistent_type".to_string(), 1)
.await;
assert!(mismatched.is_none());
Ok(())
}
#[tokio::test]
async fn revert_ticketbook_withdrawal_resets_spent_only_when_expected_matches(
) -> anyhow::Result<()> {
let manager = MemoryEcachTicketbookManager::new();
let ticketbook = mock_ticketbook()?;
let typ = ticketbook.ticketbook_type().to_string();
manager.insert_new_ticketbook(&ticketbook, 10, 0).await;
manager
.get_next_unspent_ticketbook_and_update(typ.clone(), 4)
.await
.expect("should withdraw");
// stale expected_current_total_spent — should be rejected
assert!(!manager.revert_ticketbook_withdrawal(0, 4, 99).await);
// spent_tickets unchanged
let used_after_failed = manager.get_ticketbooks_info().await[0].used_tickets;
assert_eq!(used_after_failed, 4);
// matching expected — should succeed and restore
assert!(manager.revert_ticketbook_withdrawal(0, 4, 4).await);
let used_after_revert = manager.get_ticketbooks_info().await[0].used_tickets;
assert_eq!(used_after_revert, 0);
// unknown ticketbook_id is rejected
assert!(!manager.revert_ticketbook_withdrawal(999, 1, 0).await);
Ok(())
}
#[tokio::test]
async fn pending_ticketbook_round_trip() {
let manager = MemoryEcachTicketbookManager::new();
let issuance = mock_issuance(7);
let deposit_id = issuance.deposit_id() as i64;
assert!(manager.get_pending_ticketbooks().await.is_empty());
manager.insert_pending_ticketbook(&issuance).await;
let pending = manager.get_pending_ticketbooks().await;
assert_eq!(pending.len(), 1);
assert_eq!(pending[0].pending_id, deposit_id);
assert_eq!(
pending[0].pending_ticketbook.deposit_id(),
issuance.deposit_id()
);
manager.remove_pending_ticketbook(deposit_id).await;
assert!(manager.get_pending_ticketbooks().await.is_empty());
// removing a non-existent id is a no-op
manager.remove_pending_ticketbook(999).await;
}
#[tokio::test]
async fn emergency_credential_lifecycle() {
let manager = MemoryEcachTicketbookManager::new();
let cred_a = EmergencyCredentialContent {
typ: "type-a".to_string(),
content: vec![1, 2, 3],
expiration: None,
};
let cred_b = EmergencyCredentialContent {
typ: "type-a".to_string(),
content: vec![4, 5, 6],
expiration: None,
};
let cred_c = EmergencyCredentialContent {
typ: "type-b".to_string(),
content: vec![7, 8, 9],
expiration: None,
};
assert!(manager.get_emergency_credential("type-a").await.is_none());
manager.insert_emergency_credential(&cred_a).await;
manager.insert_emergency_credential(&cred_b).await;
manager.insert_emergency_credential(&cred_c).await;
// get returns the first inserted entry for the type
let first = manager.get_emergency_credential("type-a").await.unwrap();
assert_eq!(first.id, 0);
assert_eq!(first.data.content, vec![1, 2, 3]);
// remove by id drops only that entry; type-a now exposes cred_b
manager.remove_emergency_credential(0).await;
let after_remove = manager.get_emergency_credential("type-a").await.unwrap();
assert_eq!(after_remove.id, 1);
assert_eq!(after_remove.data.content, vec![4, 5, 6]);
// remove by type clears the bucket entirely
manager.remove_emergency_credentials_of_type("type-a").await;
assert!(manager.get_emergency_credential("type-a").await.is_none());
// unrelated type is untouched
assert!(manager.get_emergency_credential("type-b").await.is_some());
}
#[tokio::test]
async fn master_verification_key_round_trip() {
let manager = MemoryEcachTicketbookManager::new();
let key = mock_verification_key();
let epoch = EpochVerificationKey {
epoch_id: 7,
key: key.clone(),
};
assert!(manager.get_master_verification_key(7).await.is_none());
manager.insert_master_verification_key(&epoch).await;
assert_eq!(manager.get_master_verification_key(7).await, Some(key));
assert!(manager.get_master_verification_key(8).await.is_none());
}
#[tokio::test]
async fn coin_index_signatures_round_trip() {
let manager = MemoryEcachTicketbookManager::new();
let sigs = AggregatedCoinIndicesSignatures {
epoch_id: 3,
signatures: vec![],
};
assert!(manager.get_coin_index_signatures(3).await.is_none());
manager.insert_coin_index_signatures(&sigs).await;
let retrieved = manager.get_coin_index_signatures(3).await;
assert!(retrieved.is_some());
assert!(retrieved.unwrap().is_empty());
assert!(manager.get_coin_index_signatures(4).await.is_none());
}
#[tokio::test]
async fn expiration_date_signatures_round_trip() {
let manager = MemoryEcachTicketbookManager::new();
let date = nym_ecash_time::ecash_today().date();
let sigs = AggregatedExpirationDateSignatures {
epoch_id: 5,
expiration_date: date,
signatures: vec![],
};
assert!(manager
.get_expiration_date_signatures(date, 5)
.await
.is_none());
manager.insert_expiration_date_signatures(&sigs).await;
let retrieved = manager.get_expiration_date_signatures(date, 5).await;
assert!(retrieved.is_some());
assert!(retrieved.unwrap().is_empty());
// wrong epoch / wrong date → miss
assert!(manager
.get_expiration_date_signatures(date, 6)
.await
.is_none());
}
}
@@ -25,6 +25,9 @@ pub enum Error {
#[error("failed to create ipv4 packet")]
Ipv4PacketCreationFailure,
#[error("packet length {length} exceeds the u16 IP header field")]
PacketLengthOverflow { length: usize },
}
// Result type based on our error type
@@ -79,9 +79,14 @@ pub fn wrap_icmp_in_ipv4(
let mut ipv4_packet =
MutableIpv4Packet::owned(ipv4_buffer).ok_or(Error::Ipv4PacketCreationFailure)?;
let total_length_u16 =
u16::try_from(total_length).map_err(|_| Error::PacketLengthOverflow {
length: total_length,
})?;
ipv4_packet.set_version(4);
ipv4_packet.set_header_length(5);
ipv4_packet.set_total_length(total_length as u16);
ipv4_packet.set_total_length(total_length_u16);
ipv4_packet.set_ttl(64);
ipv4_packet.set_next_level_protocol(pnet_packet::ip::IpNextHeaderProtocols::Icmp);
ipv4_packet.set_source(source);
@@ -101,12 +106,18 @@ pub fn wrap_icmp_in_ipv6(
source: Ipv6Addr,
destination: Ipv6Addr,
) -> Result<Ipv6Packet> {
let ipv6_buffer = vec![0u8; 40 + icmp_echo_request.packet().len()];
let payload_length = icmp_echo_request.packet().len();
let payload_length_u16 =
u16::try_from(payload_length).map_err(|_| Error::PacketLengthOverflow {
length: payload_length,
})?;
let ipv6_buffer = vec![0u8; 40 + payload_length];
let mut ipv6_packet =
MutableIpv6Packet::owned(ipv6_buffer).ok_or(Error::Ipv4PacketCreationFailure)?;
ipv6_packet.set_version(6);
ipv6_packet.set_payload_length(icmp_echo_request.packet().len() as u16);
ipv6_packet.set_payload_length(payload_length_u16);
ipv6_packet.set_next_header(pnet_packet::ip::IpNextHeaderProtocols::Icmpv6);
ipv6_packet.set_hop_limit(64);
ipv6_packet.set_source(source);
@@ -164,3 +175,122 @@ pub(crate) fn is_icmp_v6_echo_reply(packet: &Bytes) -> Option<(u16, Ipv6Addr, Ip
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use pnet_packet::icmp::IcmpTypes;
use pnet_packet::icmpv6::Icmpv6Types;
use pnet_packet::ip::IpNextHeaderProtocols;
const V4_SRC: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 1);
const V4_DST: Ipv4Addr = Ipv4Addr::new(10, 0, 0, 2);
const V6_SRC: Ipv6Addr = Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1);
const V6_DST: Ipv6Addr = Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 2);
#[test]
fn icmpv4_echo_request_sets_fields_and_valid_checksum() {
let echo = create_icmpv4_echo_request(42, 7).unwrap();
assert_eq!(echo.get_sequence_number(), 42);
assert_eq!(echo.get_identifier(), 7);
assert_eq!(echo.get_icmp_type(), IcmpTypes::EchoRequest);
// pnet's `checksum` skips the checksum word, so recomputing on the produced
// packet must equal the stored value.
let icmp = IcmpPacket::new(echo.packet()).unwrap();
assert_eq!(echo.get_checksum(), pnet_packet::icmp::checksum(&icmp));
}
#[test]
fn icmpv6_echo_request_sets_fields_and_valid_checksum() {
let echo = create_icmpv6_echo_request(99, 12, &V6_SRC, &V6_DST).unwrap();
assert_eq!(echo.get_sequence_number(), 99);
assert_eq!(echo.get_identifier(), 12);
assert_eq!(echo.get_icmpv6_type(), Icmpv6Types::EchoRequest);
let icmpv6 = icmpv6::Icmpv6Packet::new(echo.packet()).unwrap();
assert_eq!(
echo.get_checksum(),
pnet_packet::icmpv6::checksum(&icmpv6, &V6_SRC, &V6_DST)
);
}
#[test]
fn wrap_icmp_in_ipv4_sets_headers_and_payload() {
let echo = create_icmpv4_echo_request(1, 2).unwrap();
let echo_bytes = echo.packet().to_vec();
let packet = wrap_icmp_in_ipv4(echo, V4_SRC, V4_DST).unwrap();
assert_eq!(packet.get_version(), 4);
assert_eq!(packet.get_header_length(), 5);
assert_eq!(packet.get_total_length() as usize, 20 + echo_bytes.len());
assert_eq!(packet.get_ttl(), 64);
assert_eq!(
packet.get_next_level_protocol(),
IpNextHeaderProtocols::Icmp
);
assert_eq!(packet.get_source(), V4_SRC);
assert_eq!(packet.get_destination(), V4_DST);
assert_eq!(packet.payload(), echo_bytes.as_slice());
}
#[test]
fn wrap_icmp_in_ipv6_sets_headers_and_payload() {
let echo = create_icmpv6_echo_request(1, 2, &V6_SRC, &V6_DST).unwrap();
let echo_bytes = echo.packet().to_vec();
let packet = wrap_icmp_in_ipv6(echo, V6_SRC, V6_DST).unwrap();
assert_eq!(packet.get_version(), 6);
assert_eq!(packet.get_payload_length() as usize, echo_bytes.len());
assert_eq!(packet.get_next_header(), IpNextHeaderProtocols::Icmpv6);
assert_eq!(packet.get_hop_limit(), 64);
assert_eq!(packet.get_source(), V6_SRC);
assert_eq!(packet.get_destination(), V6_DST);
assert_eq!(packet.payload(), echo_bytes.as_slice());
}
#[test]
fn compute_ipv4_checksum_is_zero_on_correctly_checksummed_packet() {
let echo = create_icmpv4_echo_request(1, 2).unwrap();
let packet = wrap_icmp_in_ipv4(echo, V4_SRC, V4_DST).unwrap();
// RFC 1071: summing every 16-bit word of a header that already contains its
// own checksum yields all-ones; the one's complement is therefore zero.
assert_eq!(compute_ipv4_checksum(&packet), 0);
}
#[test]
fn is_icmp_echo_reply_extracts_identifier_and_addresses() {
// pnet's EchoReply/EchoRequest share the same byte layout (only the ICMP
// type field differs) and `is_icmp_echo_reply` does not check the type,
// so a wrapped echo *request* exercises the same parsing path.
let identifier = 1234;
let echo = create_icmpv4_echo_request(7, identifier).unwrap();
let packet = wrap_icmp_in_ipv4(echo, V4_SRC, V4_DST).unwrap();
let bytes = Bytes::copy_from_slice(packet.packet());
assert_eq!(
is_icmp_echo_reply(&bytes),
Some((identifier, V4_SRC, V4_DST))
);
}
#[test]
fn is_icmp_v6_echo_reply_extracts_identifier_and_addresses() {
let identifier = 5678;
let echo = create_icmpv6_echo_request(7, identifier, &V6_SRC, &V6_DST).unwrap();
let packet = wrap_icmp_in_ipv6(echo, V6_SRC, V6_DST).unwrap();
let bytes = Bytes::copy_from_slice(packet.packet());
assert_eq!(
is_icmp_v6_echo_reply(&bytes),
Some((identifier, V6_SRC, V6_DST))
);
}
#[test]
fn is_icmp_echo_reply_returns_none_for_undersized_bytes() {
let bytes = Bytes::from_static(&[0u8; 4]);
assert!(is_icmp_echo_reply(&bytes).is_none());
assert!(is_icmp_v6_echo_reply(&bytes).is_none());
}
}
+14 -15
View File
@@ -6,7 +6,7 @@ use std::{
use ansi_term::Color::Yellow;
use bytes::{Buf, BytesMut};
use log::{debug, error, warn};
use log::{debug, error, trace, warn};
use std::thread;
use crate::MAX_RTO;
@@ -499,21 +499,9 @@ impl KcpSession {
self.snd_buf.len(),
post_retain_sns
);
// Corrected format string arguments for the removed count log
debug!(
"[ConvID: {}, Thread: {:?}] parse_una(una={}): Removed {} segment(s) from snd_buf ({} -> {}). Remaining sns: {:?}",
self.conv,
thread::current().id(),
una,
removed_count,
original_len,
self.snd_buf.len(),
post_retain_sns
);
if removed_count > 0 {
// Use trace level if no segments were removed but buffer wasn't empty
debug!(
if removed_count == 0 {
trace!(
"[ConvID: {}, Thread: {:?}] parse_una(una={}): No segments removed from snd_buf (len={}). Remaining sns: {:?}",
self.conv,
thread::current().id(),
@@ -521,6 +509,17 @@ impl KcpSession {
original_len,
self.snd_buf.iter().map(|s| s.sn).collect::<Vec<_>>()
);
} else {
debug!(
"[ConvID: {}, Thread: {:?}] parse_una(una={}): Removed {} segment(s) from snd_buf ({} -> {}). Remaining sns: {:?}",
self.conv,
thread::current().id(),
una,
removed_count,
original_len,
self.snd_buf.len(),
post_retain_sns
);
}
// Update the known acknowledged sequence number.
@@ -719,6 +719,10 @@ impl Payment {
return Err(CompactEcashError::SpendSignaturesValidity);
}
if verification_key.beta_g2.len() < 4 {
return Err(CompactEcashError::VerificationKeyTooShort);
}
let kappa_type = self.kappa + verification_key.beta_g2[3] * type_scalar(self.t_type);
if !check_bilinear_pairing(
&self.sig.h.to_affine(),
+31 -4
View File
@@ -386,8 +386,13 @@ impl FragmentHeader {
{
return Err(ChunkingError::MalformedHeaderError);
}
// post-link requires total == current == u8::MAX so the constructor
// stays in lockstep with `try_from_bytes`'s deserialiser check.
if let Some(nfid) = next_fragments_set_id
&& (nfid <= 0 || current_fragment != total_fragments || nfid == id)
&& (nfid <= 0
|| current_fragment != u8::MAX
|| total_fragments != u8::MAX
|| nfid == id)
{
return Err(ChunkingError::MalformedHeaderError);
}
@@ -1124,9 +1129,13 @@ mod fragment_header {
}
#[test]
fn can_only_be_post_linked_for_last_fragment() {
assert!(FragmentHeader::try_new(12345, 10, 10, None, Some(1234)).is_ok());
assert!(FragmentHeader::try_new(12345, u8::MAX, u8::MAX, None, Some(1234),).is_ok());
fn can_only_be_post_linked_for_last_fragment_of_full_set() {
// post-linking requires total == current == u8::MAX (a *full* set)
assert!(FragmentHeader::try_new(12345, u8::MAX, u8::MAX, None, Some(1234)).is_ok());
assert!(FragmentHeader::try_new(12345, 10, 10, None, Some(1234)).is_err());
assert!(
FragmentHeader::try_new(12345, u8::MAX - 1, u8::MAX - 1, None, Some(1234)).is_err()
);
assert!(FragmentHeader::try_new(12345, 10, 2, Some(1234), None).is_err());
}
@@ -1192,5 +1201,23 @@ mod fragment_header {
assert_eq!(fragmented_header, recovered_header);
assert_eq!(LINKED_FRAGMENTED_HEADER_LEN, bytes_used);
}
#[test]
fn post_linked_with_non_max_total_is_rejected_by_constructor_and_deserialiser() {
// Regression: try_new used to accept post-linked headers where
// total/current != u8::MAX, but try_from_bytes rejects them, so
// such headers could never round-trip.
assert!(FragmentHeader::try_new(12345, 10, 10, None, Some(1234)).is_err());
// The deserialiser must still reject the corresponding bytes if
// some future change tries to emit them. Build the malformed bytes
// by hand from a valid post-linked header, then overwrite the
// total/current fragment counts.
let valid = FragmentHeader::try_new(12345, u8::MAX, u8::MAX, None, Some(1234)).unwrap();
let mut malformed = valid.to_bytes();
malformed[4] = 10; // total_fragments
malformed[5] = 10; // current_fragment
assert!(FragmentHeader::try_from_bytes(&malformed).is_err());
}
}
}
+10 -3
View File
@@ -411,7 +411,7 @@ impl SocksClient {
let recipient = self.service_provider;
let packet_type = self.packet_type;
let (stream, _) = ProxyRunner::new(
let proxy_result = ProxyRunner::new(
stream,
local_stream_remote,
remote_proxy_target,
@@ -449,8 +449,15 @@ impl SocksClient {
)
}
})
.await
.into_inner();
.await;
let (stream, _) = match proxy_result {
Ok(runner) => runner.into_inner(),
Err(err) => {
log::error!("proxy runner for connection {connection_id} failed: {err}");
return;
}
};
// recover stream from the proxy
self.stream.finish_proxy(stream)
}
@@ -8,6 +8,7 @@ use nym_task::connections::LaneQueueLengths;
use nym_task::ShutdownTracker;
use std::fmt::Debug;
use std::{sync::Arc, time::Duration};
use tokio::task::JoinError;
use tokio::{net::TcpStream, sync::Notify};
mod inbound;
@@ -92,7 +93,7 @@ where
// The `adapter_fn` is used to transform whatever was read into appropriate
// request/response as required by entity running particular side of the proxy.
pub async fn run<F>(mut self, adapter_fn: F) -> Self
pub async fn run<F>(mut self, adapter_fn: F) -> Result<Self, JoinError>
where
F: Fn(SocketData) -> S + Send + Sync + 'static,
{
@@ -148,16 +149,22 @@ where
let (inbound_result, outbound_result) =
futures::future::join(handle_inbound, handle_outbound).await;
if inbound_result.is_err() || outbound_result.is_err() {
panic!("TODO: some future error?")
}
let read_half = inbound_result.unwrap();
let (write_half, mix_receiver) = outbound_result.unwrap();
let read_half = inbound_result.inspect_err(|err| {
log::error!(
"inbound proxy task for connection {} failed: {err}",
self.connection_id
)
})?;
let (write_half, mix_receiver) = outbound_result.inspect_err(|err| {
log::error!(
"outbound proxy task for connection {} failed: {err}",
self.connection_id
)
})?;
self.socket = Some(write_half.reunite(read_half).unwrap());
self.mix_receiver = Some(mix_receiver);
self
Ok(self)
}
pub fn into_inner(mut self) -> (TcpStream, ConnectionReceiver) {
@@ -490,6 +490,11 @@ impl PacketStatisticsControl {
// Do basic averaging over the entire history, which just uses the first and last
if let Some((start, start_stats)) = self.history.front() {
let duration_secs = Instant::now().duration_since(start).as_secs_f64();
// skip when only one entry was just pushed in this tick: dividing by 0
// would yield inf/NaN rates that downstream consumers treat as real values.
if duration_secs == 0.0 {
return None;
}
let delta = self.stats.clone() - start_stats.clone();
let rates = PacketRates::from(delta) / duration_secs;
Some(rates)
@@ -53,7 +53,7 @@ impl Connection {
let remote_source_address = "???".to_string(); // we don't know ip address of requester
let connection_id = self.id;
let return_address = self.return_address.clone();
let (stream, _) = ProxyRunner::new(
let proxy_result = ProxyRunner::new(
stream,
self.address.clone(),
remote_source_address,
@@ -76,8 +76,15 @@ impl Connection {
socket_data.header.local_socket_closed,
)
})
.await
.into_inner();
.await;
let (stream, _) = match proxy_result {
Ok(runner) => runner.into_inner(),
Err(err) => {
log::error!("proxy runner for connection {connection_id} failed: {err}");
return;
}
};
self.conn = Some(stream);
}
}