Feature/cancellation migration (#6014)

* squashing work on using cancellation in nym crates

making nym-task wasm compilable

removed sending of status messages

replaced TaskManager with ShutdownManager in the validator rewarder

additional helpers for ShutdownManager

simplified ShutdownToken by removing the name field

TaskClient => ShutdownToken within all client tasks

wip: remove TaskHandle

* track all long-living client tasks

* add task tracking for most top level tasks within nym-node

* improved default builder

* split up cancellation module

* module documentation and unit tests

* nym node fixes and naming consistency

* wasm fixes

* assert_eq => assert

* wasm fixes and made 'run_until_shutdown' take reference instead of ownership

* linux-specific fixes to IpPacketRouter

* post rebasing fixes for signing monitor

* add ShutdownManager constructor to build it from an external token

* applying PR review suggestions
This commit is contained in:
Jędrzej Stuczyński
2025-09-10 13:56:39 +01:00
committed by GitHub
parent d3cdaf373b
commit 0ee387d983
125 changed files with 2701 additions and 1927 deletions
Generated
+3
View File
@@ -6410,6 +6410,7 @@ dependencies = [
"thiserror 2.0.12",
"time",
"tokio",
"tokio-stream",
"tokio-util",
"toml 0.8.23",
"tower-http 0.5.2",
@@ -7157,9 +7158,11 @@ dependencies = [
name = "nym-task"
version = "0.1.0"
dependencies = [
"anyhow",
"cfg-if",
"futures",
"log",
"nym-test-utils",
"thiserror 2.0.12",
"tokio",
"tokio-util",
+14 -10
View File
@@ -11,7 +11,7 @@ use nym_client_core::client::base_client::{
BaseClientBuilder, ClientInput, ClientOutput, ClientState,
};
use nym_sphinx::params::PacketType;
use nym_task::TaskHandle;
use nym_task::ShutdownManager;
use nym_validator_client::QueryHttpRpcNyxdClient;
use std::error::Error;
use std::path::PathBuf;
@@ -29,6 +29,8 @@ pub struct SocketClient {
/// Optional path to a .json file containing standalone network details.
custom_mixnet: Option<PathBuf>,
shutdown_manager: ShutdownManager,
}
impl SocketClient {
@@ -40,6 +42,7 @@ impl SocketClient {
SocketClient {
config,
custom_mixnet,
shutdown_manager: Default::default(),
}
}
@@ -49,7 +52,7 @@ impl SocketClient {
client_output: ClientOutput,
client_state: ClientState,
self_address: &Recipient,
task_client: nym_task::TaskClient,
shutdown_token: nym_task::ShutdownToken,
packet_type: PacketType,
) {
info!("Starting websocket listener...");
@@ -77,24 +80,24 @@ impl SocketClient {
shared_lane_queue_lengths,
reply_controller_sender,
Some(packet_type),
task_client.fork("websocket_handler"),
shutdown_token.clone(),
);
websocket::Listener::new(
config.socket.host,
config.socket.listening_port,
task_client.with_suffix("websocket_listener"),
shutdown_token.child_token(),
)
.start(websocket_handler);
}
/// blocking version of `start_socket` method. Will run forever (or until SIGINT is sent)
pub async fn run_socket_forever(self) -> Result<(), Box<dyn Error + Send + Sync>> {
let shutdown = self.start_socket().await?;
let mut shutdown = self.start_socket().await?;
let res = shutdown.wait_for_shutdown().await;
shutdown.run_until_shutdown().await;
log::info!("Stopping nym-client");
res
Ok(())
}
async fn initialise_storage(&self) -> Result<OnDiskPersistent, ClientError> {
@@ -119,6 +122,7 @@ impl SocketClient {
let mut base_client =
BaseClientBuilder::new(self.config().base(), storage, dkg_query_client)
.with_shutdown(self.shutdown_manager.shutdown_tracker_owned())
.with_user_agent(user_agent);
if let Some(custom_mixnet) = &self.custom_mixnet {
@@ -128,7 +132,7 @@ impl SocketClient {
Ok(base_client)
}
pub async fn start_socket(self) -> Result<TaskHandle, ClientError> {
pub async fn start_socket(self) -> Result<ShutdownManager, ClientError> {
if !self.config.socket.socket_type.is_websocket() {
return Err(ClientError::InvalidSocketMode);
}
@@ -147,13 +151,13 @@ impl SocketClient {
client_output,
client_state,
&self_address,
started_client.task_handle.get_handle(),
self.shutdown_manager.child_shutdown_token(),
packet_type,
);
info!("Client startup finished!");
info!("The address of this client is: {self_address}");
Ok(started_client.task_handle)
Ok(self.shutdown_manager)
}
}
+21 -27
View File
@@ -19,7 +19,7 @@ use nym_sphinx::receiver::ReconstructedMessage;
use nym_task::connections::{
ConnectionCommand, ConnectionCommandSender, ConnectionId, LaneQueueLengths, TransmissionLane,
};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::Instant;
@@ -44,7 +44,7 @@ pub(crate) struct HandlerBuilder {
lane_queue_lengths: LaneQueueLengths,
reply_controller_sender: ReplyControllerSender,
packet_type: Option<PacketType>,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl HandlerBuilder {
@@ -57,7 +57,7 @@ impl HandlerBuilder {
lane_queue_lengths: LaneQueueLengths,
reply_controller_sender: ReplyControllerSender,
packet_type: Option<PacketType>,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self {
Self {
msg_input,
@@ -67,14 +67,13 @@ impl HandlerBuilder {
lane_queue_lengths,
reply_controller_sender,
packet_type,
task_client,
shutdown_token,
}
}
// TODO: make sure we only ever have one active handler
pub fn create_active_handler(&self) -> Handler {
let mut task_client = self.task_client.fork("active_handler");
task_client.disarm();
let shutdown_token = self.shutdown_token.clone();
Handler {
msg_input: self.msg_input.clone(),
client_connection_tx: self.client_connection_tx.clone(),
@@ -85,7 +84,7 @@ impl HandlerBuilder {
lane_queue_lengths: self.lane_queue_lengths.clone(),
reply_controller_sender: self.reply_controller_sender.clone(),
packet_type: self.packet_type,
task_client,
shutdown_token,
}
}
}
@@ -100,19 +99,14 @@ pub(crate) struct Handler {
lane_queue_lengths: LaneQueueLengths,
reply_controller_sender: ReplyControllerSender,
packet_type: Option<PacketType>,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl Drop for Handler {
fn drop(&mut self) {
if let Err(err) = self
let _ = self
.buffer_requester
.unbounded_send(ReceivedBufferMessage::ReceiverDisconnect)
{
if !self.task_client.is_shutdown_poll() {
error!("failed to disconnect the receiver from the buffer: {err}");
}
}
.unbounded_send(ReceivedBufferMessage::ReceiverDisconnect);
}
}
@@ -142,7 +136,7 @@ impl Handler {
{
Ok(length) => length,
Err(err) => {
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!(
"Failed to get reply queue length for connection {connection_id}: {err}"
);
@@ -192,7 +186,7 @@ impl Handler {
// the ack control is now responsible for chunking, etc.
let input_msg = InputMessage::new_regular(recipient, message, lane, self.packet_type);
if let Err(err) = self.msg_input.send(input_msg).await {
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!("Failed to send message to the input buffer: {err}");
}
}
@@ -225,7 +219,7 @@ impl Handler {
let input_msg =
InputMessage::new_anonymous(recipient, message, reply_surbs, lane, self.packet_type);
if let Err(err) = self.msg_input.send(input_msg).await {
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!("Failed to send anonymous message to the input buffer: {err}");
}
}
@@ -253,7 +247,7 @@ impl Handler {
let input_msg = InputMessage::new_reply(recipient_tag, message, lane, self.packet_type);
if let Err(err) = self.msg_input.send(input_msg).await {
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!("Failed to send reply message to the input buffer: {err}");
}
}
@@ -275,7 +269,7 @@ impl Handler {
.client_connection_tx
.unbounded_send(ConnectionCommand::Close(connection_id))
{
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!("Failed to send close connection command: {err}");
}
}
@@ -394,11 +388,14 @@ impl Handler {
}
async fn listen_for_requests(&mut self, mut msg_receiver: ReconstructedMessagesReceiver) {
let mut task_client = self.task_client.fork("select");
task_client.disarm();
let shutdown_token = self.shutdown_token.clone();
while !task_client.is_shutdown() {
loop {
tokio::select! {
_ = shutdown_token.cancelled() => {
log::trace!("Websocket handler: Received shutdown");
break;
}
// we can either get a client request from the websocket
socket_msg = self.next_websocket_request() => {
if socket_msg.is_none() {
@@ -436,9 +433,6 @@ impl Handler {
break;
}
}
_ = task_client.recv() => {
log::trace!("Websocket handler: Received shutdown");
}
}
}
log::debug!("Websocket handler: Exiting");
@@ -464,7 +458,7 @@ impl Handler {
reconstructed_sender,
))
{
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!("failed to announce the receiver to the buffer: {err}");
}
}
+7 -7
View File
@@ -3,7 +3,7 @@
use super::handler::HandlerBuilder;
use log::*;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use std::net::IpAddr;
use std::{net::SocketAddr, process, sync::Arc};
use tokio::io::AsyncWriteExt;
@@ -23,15 +23,15 @@ impl State {
pub(crate) struct Listener {
address: SocketAddr,
state: State,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl Listener {
pub(crate) fn new(host: IpAddr, port: u16, task_client: TaskClient) -> Self {
pub(crate) fn new(host: IpAddr, port: u16, shutdown_token: ShutdownToken) -> Self {
Listener {
address: SocketAddr::new(host, port),
state: State::AwaitingConnection,
task_client,
shutdown_token,
}
}
@@ -46,11 +46,11 @@ impl Listener {
let notify = Arc::new(Notify::new());
while !self.task_client.is_shutdown() {
while !self.shutdown_token.is_cancelled() {
tokio::select! {
// When the handler finishes we check if shutdown is signalled
_ = notify.notified() => {
if self.task_client.is_shutdown() {
if self.shutdown_token.is_cancelled() {
log::trace!("Websocket listener: detected shutdown after connection closed");
break;
}
@@ -59,7 +59,7 @@ impl Listener {
}
// ... but when there is no connected client at the time of shutdown being
// signalled, we handle it here.
_ = self.task_client.recv() => {
_ = self.shutdown_token.cancelled() => {
if !self.state.is_connected() {
log::trace!("Not connected: shutting down");
break;
@@ -1,7 +1,9 @@
// Copyright 2023 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: Apache-2.0
use crate::error::ClientCoreError;
use crate::{client::replies::reply_storage, config::DebugConfig};
use nym_task::{ShutdownManager, ShutdownToken, ShutdownTracker};
pub fn setup_empty_reply_surb_backend(debug_config: &DebugConfig) -> reply_storage::Empty {
reply_storage::Empty {
@@ -13,3 +15,49 @@ pub fn setup_empty_reply_surb_backend(debug_config: &DebugConfig) -> reply_stora
.maximum_reply_surb_storage_threshold,
}
}
// old 'TaskHandle'
/// Shutdown handle used while building the base client.
///
/// The variant records where the shutdown machinery came from (see
/// `ShutdownHelper::new`): either it was created locally or it was handed in
/// by the caller.
pub(crate) enum ShutdownHelper {
/// No tracker was supplied, so a fresh `ShutdownManager` was created here.
Internal(ShutdownManager),
/// A `ShutdownTracker` supplied by the caller.
External(ShutdownTracker),
}
fn new_shutdown_manager() -> Result<ShutdownManager, ClientCoreError> {
cfg_if::cfg_if! {
if #[cfg(not(target_arch = "wasm32"))] {
Ok(ShutdownManager::build_new_default()?)
} else {
Ok(ShutdownManager::new_without_signals())
}
}
}
impl ShutdownHelper {
    /// Wraps the provided tracker as `External`, or, when `None` is given,
    /// builds a new manager via `new_shutdown_manager` and wraps it as `Internal`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `new_shutdown_manager` if a new manager
    /// has to be created and that creation fails.
    pub(crate) fn new(shutdown_tracker: Option<ShutdownTracker>) -> Result<Self, ClientCoreError> {
        let helper = if let Some(external) = shutdown_tracker {
            Self::External(external)
        } else {
            Self::Internal(new_shutdown_manager()?)
        };
        Ok(helper)
    }

    /// Consumes the helper, yielding the locally created manager.
    ///
    /// Returns `None` for the `External` variant.
    pub(crate) fn into_internal(self) -> Option<ShutdownManager> {
        if let Self::Internal(owned_manager) = self {
            Some(owned_manager)
        } else {
            None
        }
    }

    /// Returns a shutdown token obtained from `clone_shutdown_token` on
    /// whichever value the helper wraps.
    pub(crate) fn shutdown_token(&self) -> ShutdownToken {
        match self {
            Self::Internal(manager) => manager.clone_shutdown_token(),
            Self::External(tracker) => tracker.clone_shutdown_token(),
        }
    }

    /// Borrows the tracker: the wrapped one for `External`, or the manager's
    /// own (`shutdown_tracker()`) for `Internal`.
    pub(crate) fn tracker(&self) -> &ShutdownTracker {
        match self {
            Self::Internal(manager) => manager.shutdown_tracker(),
            Self::External(tracker) => tracker,
        }
    }
}
+128 -70
View File
@@ -4,6 +4,7 @@
use super::mix_traffic::ClientRequestSender;
use super::received_buffer::ReceivedBufferMessage;
use super::statistics_control::StatisticsControl;
use crate::client::base_client::helpers::ShutdownHelper;
use crate::client::base_client::storage::helpers::store_client_keys;
use crate::client::base_client::storage::MixnetClientStorage;
use crate::client::cover_traffic_stream::LoopCoverTrafficStream;
@@ -27,13 +28,13 @@ use crate::client::topology_control::nym_api_provider::NymApiTopologyProvider;
use crate::client::topology_control::{
TopologyAccessor, TopologyRefresher, TopologyRefresherConfig,
};
use crate::config;
use crate::config::{Config, DebugConfig};
use crate::error::ClientCoreError;
use crate::init::{
setup_gateway,
types::{GatewaySetup, InitialisationResult},
};
use crate::{config, spawn_future};
use futures::channel::mpsc;
use nym_bandwidth_controller::BandwidthController;
use nym_client_core_config_types::{ForgetMe, RememberMe};
@@ -48,12 +49,11 @@ use nym_gateway_client::{
use nym_sphinx::acknowledgements::AckKey;
use nym_sphinx::addressing::clients::Recipient;
use nym_sphinx::addressing::nodes::NodeIdentity;
use nym_sphinx::params::PacketType;
use nym_sphinx::receiver::{ReconstructedMessage, SphinxMessageReceiver};
use nym_statistics_common::clients::ClientStatsSender;
use nym_statistics_common::generate_client_stats_id;
use nym_task::connections::{ConnectionCommandReceiver, ConnectionCommandSender, LaneQueueLengths};
use nym_task::{TaskClient, TaskHandle};
use nym_task::{ShutdownManager, ShutdownTracker};
use nym_topology::provider_trait::TopologyProvider;
use nym_topology::HardcodedTopologyProvider;
use nym_validator_client::nym_api::NymApiClientExt;
@@ -95,7 +95,6 @@ impl ClientInput {
}
}
#[derive(Clone)]
pub struct ClientOutput {
pub received_buffer_request_sender: ReceivedBufferRequestSender,
}
@@ -195,7 +194,7 @@ pub struct BaseClientBuilder<C, S: MixnetClientStorage> {
wait_for_gateway: bool,
custom_topology_provider: Option<Box<dyn TopologyProvider + Send + Sync>>,
custom_gateway_transceiver: Option<Box<dyn GatewayTransceiver + Send>>,
shutdown: Option<TaskClient>,
shutdown: Option<ShutdownTracker>,
user_agent: Option<UserAgent>,
setup_method: GatewaySetup,
@@ -281,7 +280,7 @@ where
}
#[must_use]
pub fn with_shutdown(mut self, shutdown: TaskClient) -> Self {
pub fn with_shutdown(mut self, shutdown: ShutdownTracker) -> Self {
self.shutdown = Some(shutdown);
self
}
@@ -325,11 +324,11 @@ where
topology_accessor: TopologyAccessor,
mix_tx: BatchMixMessageSender,
stats_tx: ClientStatsSender,
task_client: TaskClient,
shutdown_tracker: &ShutdownTracker,
) {
info!("Starting loop cover traffic stream...");
let stream = LoopCoverTrafficStream::new(
let mut stream = LoopCoverTrafficStream::new(
ack_key,
debug_config.acknowledgements.average_ack_delay,
mix_tx,
@@ -338,10 +337,9 @@ where
debug_config.traffic,
debug_config.cover_traffic,
stats_tx,
task_client,
);
stream.start();
shutdown_tracker
.try_spawn_named_with_shutdown(async move { stream.run().await }, "CoverTrafficStream");
}
#[allow(clippy::too_many_arguments)]
@@ -357,13 +355,12 @@ where
reply_controller_receiver: ReplyControllerReceiver,
lane_queue_lengths: LaneQueueLengths,
client_connection_rx: ConnectionCommandReceiver,
task_client: TaskClient,
packet_type: PacketType,
stats_tx: ClientStatsSender,
shutdown_tracker: &ShutdownTracker,
) {
info!("Starting real traffic stream...");
RealMessagesController::new(
let real_messages_controller = RealMessagesController::new(
controller_config,
key_rotation_config,
ack_receiver,
@@ -376,9 +373,63 @@ where
lane_queue_lengths,
client_connection_rx,
stats_tx,
task_client,
)
.start(packet_type);
shutdown_tracker.clone_shutdown_token(),
);
// break out all the subtasks
let (mut out_queue_control, mut reply_controller, ack_controller) =
real_messages_controller.into_tasks();
let (
mut ack_listener,
mut input_listener,
mut retransmission_listener,
mut sent_notification_listener,
mut ack_action_controller,
) = ack_controller.into_tasks();
shutdown_tracker.try_spawn_named(
async move { out_queue_control.run().await },
"RealMessagesController::OutQueueControl",
);
let shutdown_token = shutdown_tracker.clone_shutdown_token();
shutdown_tracker.try_spawn_named(
async move { reply_controller.run(shutdown_token).await },
"RealMessagesController::ReplyController",
);
let shutdown_token = shutdown_tracker.clone_shutdown_token();
shutdown_tracker.try_spawn_named(
async move { ack_listener.run(shutdown_token).await },
"AcknowledgementController::AcknowledgementListener",
);
let shutdown_token = shutdown_tracker.clone_shutdown_token();
shutdown_tracker.try_spawn_named(
async move { input_listener.run(shutdown_token).await },
"AcknowledgementController::InputMessageListener",
);
let shutdown_token = shutdown_tracker.clone_shutdown_token();
shutdown_tracker.try_spawn_named(
async move { retransmission_listener.run(shutdown_token).await },
"AcknowledgementController::RetransmissionRequestListener",
);
shutdown_tracker.try_spawn_named_with_shutdown(
async move {
sent_notification_listener.run().await;
},
"AcknowledgementController::SentNotificationListener",
);
let shutdown_token = shutdown_tracker.clone_shutdown_token();
shutdown_tracker.try_spawn_named(
async move { ack_action_controller.run(shutdown_token).await },
"AcknowledgementController::ActionController",
);
// .start(packet_type);
}
// buffer controlling all messages fetched from provider
@@ -389,21 +440,29 @@ where
mixnet_receiver: MixnetMessageReceiver,
reply_key_storage: SentReplyKeys,
reply_controller_sender: ReplyControllerSender,
shutdown: TaskClient,
metrics_reporter: ClientStatsSender,
shutdown_tracker: &ShutdownTracker,
) {
info!("Starting received messages buffer controller...");
let controller: ReceivedMessagesBufferController<SphinxMessageReceiver> =
ReceivedMessagesBufferController::new(
local_encryption_keypair,
query_receiver,
mixnet_receiver,
reply_key_storage,
reply_controller_sender,
metrics_reporter,
shutdown,
);
controller.start()
let controller = ReceivedMessagesBufferController::<SphinxMessageReceiver>::new(
local_encryption_keypair,
query_receiver,
mixnet_receiver,
reply_key_storage,
reply_controller_sender,
metrics_reporter,
shutdown_tracker.clone_shutdown_token(),
);
let (mut msg_receiver, mut req_receiver) = controller.into_tasks();
shutdown_tracker.try_spawn_named(
async move { msg_receiver.run().await },
"ReceivedMessagesBufferController::FragmentedMessageReceiver",
);
shutdown_tracker.try_spawn_named(
async move { req_receiver.run().await },
"ReceivedMessagesBufferController::RequestReceiver",
);
}
#[allow(clippy::too_many_arguments)]
@@ -415,7 +474,7 @@ where
packet_router: PacketRouter,
stats_reporter: ClientStatsSender,
#[cfg(unix)] connection_fd_callback: Option<Arc<dyn Fn(RawFd) + Send + Sync>>,
shutdown: TaskClient,
shutdown_tracker: &ShutdownTracker,
) -> Result<GatewayClient<C, S::CredentialStore>, ClientCoreError>
where
<S::KeyStore as KeyStore>::StorageError: Send + Sync + 'static,
@@ -434,7 +493,7 @@ where
packet_router,
bandwidth_controller,
stats_reporter,
shutdown,
shutdown_tracker.clone_shutdown_token(),
)
} else {
let cfg = GatewayConfig::new(
@@ -459,7 +518,7 @@ where
stats_reporter,
#[cfg(unix)]
connection_fd_callback,
shutdown,
shutdown_tracker.clone_shutdown_token(),
)
};
@@ -522,7 +581,7 @@ where
packet_router: PacketRouter,
stats_reporter: ClientStatsSender,
#[cfg(unix)] connection_fd_callback: Option<Arc<dyn Fn(RawFd) + Send + Sync>>,
mut shutdown: TaskClient,
shutdown_tracker: &ShutdownTracker,
) -> Result<Box<dyn GatewayTransceiver + Send>, ClientCoreError>
where
<S::KeyStore as KeyStore>::StorageError: Send + Sync + 'static,
@@ -539,7 +598,6 @@ where
Err(ClientCoreError::CustomGatewaySelectionExpected)
} else {
// and make sure to invalidate the task client, so we wouldn't cause premature shutdown
shutdown.disarm();
custom_gateway_transceiver.set_packet_router(packet_router)?;
Ok(custom_gateway_transceiver)
};
@@ -555,7 +613,7 @@ where
stats_reporter,
#[cfg(unix)]
connection_fd_callback,
shutdown,
shutdown_tracker,
)
.await?;
@@ -586,22 +644,20 @@ where
topology_accessor: TopologyAccessor,
local_gateway: NodeIdentity,
wait_for_gateway: bool,
mut task_client: TaskClient,
shutdown_tracker: &ShutdownTracker,
) -> Result<(), ClientCoreError> {
let topology_refresher_config =
TopologyRefresherConfig::new(topology_config.topology_refresh_rate);
if topology_config.disable_refreshing {
// if we're not spawning the refresher, don't cause shutdown immediately
info!("The background topology refesher is not going to be started");
task_client.disarm();
info!("The background topology refresher is not going to be started");
}
let mut topology_refresher = TopologyRefresher::new(
topology_refresher_config,
topology_accessor,
topology_provider,
task_client,
);
// before returning, block entire runtime to refresh the current network view so that any
// components depending on topology would see a non-empty view
@@ -646,7 +702,10 @@ where
// don't spawn the refresher if we don't want to be refreshing the topology.
// only use the initial values obtained
info!("Starting topology refresher...");
topology_refresher.start();
shutdown_tracker.try_spawn_named_with_shutdown(
async move { topology_refresher.run().await },
"TopologyRefresher",
);
}
Ok(())
@@ -657,7 +716,7 @@ where
user_agent: Option<UserAgent>,
client_stats_id: String,
input_sender: Sender<InputMessage>,
task_client: TaskClient,
shutdown_tracker: &ShutdownTracker,
) -> ClientStatsSender {
info!("Starting statistics control...");
StatisticsControl::create_and_start(
@@ -667,18 +726,23 @@ where
.unwrap_or("unknown".to_string()),
client_stats_id,
input_sender.clone(),
task_client,
shutdown_tracker,
)
}
fn start_mix_traffic_controller(
gateway_transceiver: Box<dyn GatewayTransceiver + Send>,
shutdown: TaskClient,
shutdown_tracker: &ShutdownTracker,
) -> (BatchMixMessageSender, ClientRequestSender) {
info!("Starting mix traffic controller...");
let (mix_traffic_controller, mix_tx, client_tx) =
MixTrafficController::new(gateway_transceiver, shutdown);
mix_traffic_controller.start();
let (mut mix_traffic_controller, mix_tx, client_tx) =
MixTrafficController::new(gateway_transceiver, shutdown_tracker.clone_shutdown_token());
shutdown_tracker.try_spawn_named(
async move { mix_traffic_controller.run().await },
"MixTrafficController",
);
(mix_tx, client_tx)
}
@@ -686,7 +750,7 @@ where
async fn setup_persistent_reply_storage(
backend: S::ReplyStore,
key_rotation_config: KeyRotationConfig,
shutdown: TaskClient,
shutdown_tracker: &ShutdownTracker,
) -> Result<CombinedReplyStorage, ClientCoreError>
where
<S::ReplyStore as ReplyStorageBackend>::StorageError: Sync + Send,
@@ -711,13 +775,14 @@ where
})?;
let store_clone = mem_store.clone();
spawn_future!(
let shutdown_token = shutdown_tracker.clone_shutdown_token();
shutdown_tracker.try_spawn_named(
async move {
persistent_storage
.flush_on_shutdown(store_clone, shutdown)
.flush_on_shutdown(store_clone, shutdown_token)
.await
},
"PersistentReplyStorage::flush_on_shutdown"
"PersistentReplyStorage::flush_on_shutdown",
);
Ok(mem_store)
@@ -809,11 +874,7 @@ where
TopologyAccessor::new(self.config.debug.topology.ignore_egress_epoch_role);
// Shutdown notifier for signalling tasks to stop
let shutdown = self
.shutdown
.map(Into::<TaskHandle>::into)
.unwrap_or_default()
.name_if_unnamed("BaseNymClient");
let shutdown = ShutdownHelper::new(self.shutdown)?;
// channels responsible for dealing with reply-related fun
let (reply_controller_sender, reply_controller_receiver) =
@@ -845,7 +906,7 @@ where
self.user_agent.clone(),
generate_client_stats_id(*self_address.identity()),
input_sender.clone(),
shutdown.fork("statistics_control"),
shutdown.tracker(),
);
// needs to be started as the first thing to block if required waiting for the gateway
@@ -855,14 +916,14 @@ where
shared_topology_accessor.clone(),
self_address.gateway(),
self.wait_for_gateway,
shutdown.fork("topology_refresher"),
shutdown.tracker(),
)
.await?;
let gateway_packet_router = PacketRouter::new(
ack_sender,
mixnet_messages_sender,
shutdown.get_handle().named("gateway-packet-router"),
shutdown.shutdown_token(),
);
let gateway_transceiver = Self::setup_gateway_transceiver(
@@ -875,7 +936,7 @@ where
stats_reporter.clone(),
#[cfg(unix)]
self.connection_fd_callback,
shutdown.fork("gateway_transceiver"),
shutdown.tracker(),
)
.await?;
let gateway_ws_fd = gateway_transceiver.ws_fd();
@@ -883,7 +944,7 @@ where
let reply_storage = Self::setup_persistent_reply_storage(
reply_storage_backend,
key_rotation_config,
shutdown.fork("persistent_reply_storage"),
shutdown.tracker(),
)
.await?;
@@ -893,8 +954,8 @@ where
mixnet_messages_receiver,
reply_storage.key_storage(),
reply_controller_sender.clone(),
shutdown.fork("received_messages_buffer"),
stats_reporter.clone(),
shutdown.tracker(),
);
// The message_sender is the transmitter for any component generating sphinx packets
@@ -902,10 +963,8 @@ where
// traffic stream.
// The MixTrafficController then sends the actual traffic
let (message_sender, client_request_sender) = Self::start_mix_traffic_controller(
gateway_transceiver,
shutdown.fork("mix_traffic_controller"),
);
let (message_sender, client_request_sender) =
Self::start_mix_traffic_controller(gateway_transceiver, shutdown.tracker());
// Channels that the websocket listener can use to signal downstream to the real traffic
// controller that connections are closed.
@@ -933,9 +992,8 @@ where
reply_controller_receiver,
shared_lane_queue_lengths.clone(),
client_connection_rx,
shutdown.fork("real_traffic_controller"),
self.config.debug.traffic.packet_type,
stats_reporter.clone(),
shutdown.tracker(),
);
if !self
@@ -951,7 +1009,7 @@ where
shared_topology_accessor.clone(),
message_sender,
stats_reporter.clone(),
shutdown.fork("cover_traffic_stream"),
shutdown.tracker(),
);
}
@@ -979,7 +1037,7 @@ where
gateway_connection: GatewayConnection { gateway_ws_fd },
},
stats_reporter,
task_handle: shutdown,
shutdown_handle: shutdown.into_internal(),
client_request_sender,
forget_me: self.config.debug.forget_me,
remember_me: self.config.debug.remember_me,
@@ -995,7 +1053,7 @@ pub struct BaseClient {
pub client_state: ClientState,
pub stats_reporter: ClientStatsSender,
pub client_request_sender: ClientRequestSender,
pub task_handle: TaskHandle,
pub shutdown_handle: Option<ShutdownManager>,
pub forget_me: ForgetMe,
pub remember_me: RememberMe,
}
@@ -3,7 +3,7 @@
use crate::client::mix_traffic::BatchMixMessageSender;
use crate::client::topology_control::TopologyAccessor;
use crate::{config, spawn_future};
use crate::config;
use futures::task::{Context, Poll};
use futures::{Future, Stream, StreamExt};
use nym_sphinx::acknowledgements::AckKey;
@@ -12,7 +12,6 @@ use nym_sphinx::cover::generate_loop_cover_packet;
use nym_sphinx::params::{PacketSize, PacketType};
use nym_sphinx::utils::sample_poisson_duration;
use nym_statistics_common::clients::{packet_statistics::PacketStatisticsEvent, ClientStatsSender};
use nym_task::TaskClient;
use rand::{rngs::OsRng, CryptoRng, Rng};
use std::pin::Pin;
use std::sync::Arc;
@@ -69,8 +68,6 @@ where
packet_type: PacketType,
stats_tx: ClientStatsSender,
task_client: TaskClient,
}
impl<R> Stream for LoopCoverTrafficStream<R>
@@ -117,7 +114,6 @@ impl LoopCoverTrafficStream<OsRng> {
traffic_config: config::Traffic,
cover_config: config::CoverTraffic,
stats_tx: ClientStatsSender,
task_client: TaskClient,
) -> Self {
let rng = OsRng;
@@ -137,7 +133,6 @@ impl LoopCoverTrafficStream<OsRng> {
use_legacy_sphinx_format: traffic_config.use_legacy_sphinx_format,
packet_type: traffic_config.packet_type,
stats_tx,
task_client,
}
}
@@ -235,12 +230,13 @@ impl LoopCoverTrafficStream<OsRng> {
tokio::task::yield_now().await;
}
// it's fine if cover traffic stream task gets killed whilst processing next message
#[allow(clippy::panic)]
pub fn start(mut self) {
pub async fn run(&mut self) {
if self.cover_traffic.disable_loop_cover_traffic_stream {
// we should have never got here in the first place - the task should have never been created to begin with
// so panic and review the code that led to this branch
panic!("attempted to start LoopCoverTrafficStream while config explicitly disabled it.")
panic!("attempted to run LoopCoverTrafficStream while config explicitly disabled it.")
}
// we should set initial delay only when we actually start the stream
@@ -250,32 +246,11 @@ impl LoopCoverTrafficStream<OsRng> {
);
self.set_next_delay(sampled);
let mut shutdown = self.task_client.fork("select");
while self.next().await.is_some() {
self.on_new_message().await;
}
spawn_future!(
async move {
debug!("Started LoopCoverTrafficStream with graceful shutdown support");
while !shutdown.is_shutdown() {
tokio::select! {
biased;
_ = shutdown.recv() => {
tracing::trace!("LoopCoverTrafficStream: Received shutdown");
}
next = self.next() => {
if next.is_some() {
self.on_new_message().await;
} else {
tracing::trace!("LoopCoverTrafficStream: Stopping since channel closed");
break;
}
}
}
}
shutdown.recv_timeout().await;
tracing::debug!("LoopCoverTrafficStream: Exiting");
},
"LoopCoverTrafficStream"
)
// this should never get triggered
error!("cover traffic stream has been exhausted!")
}
}
@@ -2,11 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
use crate::client::mix_traffic::transceiver::GatewayTransceiver;
use crate::error::ClientCoreError;
use crate::spawn_future;
use nym_gateway_requests::ClientRequest;
use nym_sphinx::forwarding::packet::MixPacket;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use tracing::*;
use transceiver::ErasedGatewayError;
@@ -34,13 +32,13 @@ pub struct MixTrafficController {
// in long run `gateway_client` will be moved away from `MixTrafficController` anyway.
consecutive_gateway_failure_count: usize,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl MixTrafficController {
pub fn new<T>(
gateway_transceiver: T,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> (
MixTrafficController,
BatchMixMessageSender,
@@ -60,7 +58,7 @@ impl MixTrafficController {
mix_rx: message_receiver,
client_rx: client_receiver,
consecutive_gateway_failure_count: 0,
task_client,
shutdown_token,
},
message_sender,
client_sender,
@@ -69,7 +67,7 @@ impl MixTrafficController {
pub fn new_dynamic(
gateway_transceiver: Box<dyn GatewayTransceiver + Send>,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> (
MixTrafficController,
BatchMixMessageSender,
@@ -84,7 +82,7 @@ impl MixTrafficController {
mix_rx: message_receiver,
client_rx: client_receiver,
consecutive_gateway_failure_count: 0,
task_client,
shutdown_token,
},
message_sender,
client_sender,
@@ -107,7 +105,7 @@ impl MixTrafficController {
tokio::select! {
biased;
_ = self.task_client.recv() => {
_ = self.shutdown_token.cancelled() => {
trace!("received shutdown while handling messages");
Ok(())
}
@@ -127,7 +125,7 @@ impl MixTrafficController {
async fn on_client_request(&mut self, client_request: ClientRequest) {
tokio::select! {
biased;
_ = self.task_client.recv() => {
_ = self.shutdown_token.cancelled() => {
trace!("received shutdown while handling client request");
}
result = self.gateway_transceiver.send_client_request(client_request) => {
@@ -138,52 +136,44 @@ impl MixTrafficController {
}
}
pub fn start(mut self) {
spawn_future!(
async move {
debug!("Started MixTrafficController with graceful shutdown support");
while !self.task_client.is_shutdown() {
tokio::select! {
biased;
_ = self.task_client.recv() => {
tracing::trace!("MixTrafficController: Received shutdown");
break;
}
mix_packets = self.mix_rx.recv() => match mix_packets {
Some(mix_packets) => {
if let Err(err) = self.on_messages(mix_packets).await {
error!("Failed to send sphinx packet(s) to the gateway: {err}");
if self.consecutive_gateway_failure_count == MAX_FAILURE_COUNT {
// Disconnect from the gateway. If we should try to re-connect
// is handled at a higher layer.
error!("Failed to send sphinx packet to the gateway {MAX_FAILURE_COUNT} times in a row - assuming the gateway is dead");
// Do we need to handle the embedded mixnet client case
// separately?
self.task_client.send_we_stopped(Box::new(ClientCoreError::GatewayFailedToForwardMessages));
break;
}
}
},
None => {
tracing::trace!("MixTrafficController: Stopping since channel closed");
pub async fn run(&mut self) {
debug!("Started MixTrafficController with graceful shutdown support");
loop {
tokio::select! {
biased;
_ = self.shutdown_token.cancelled() => {
trace!("MixTrafficController: Received shutdown");
break;
}
mix_packets = self.mix_rx.recv() => match mix_packets {
Some(mix_packets) => {
if let Err(err) = self.on_messages(mix_packets).await {
error!("Failed to send sphinx packet(s) to the gateway: {err}");
if self.consecutive_gateway_failure_count == MAX_FAILURE_COUNT {
// Disconnect from the gateway. If we should try to re-connect
// is handled at a higher layer.
error!("Failed to send sphinx packet to the gateway {MAX_FAILURE_COUNT} times in a row - assuming the gateway is dead");
// Do we need to handle the embedded mixnet client case
// separately?
break;
}
},
client_request = self.client_rx.recv() => match client_request {
Some(client_request) => {
self.on_client_request(client_request).await;
},
None => {
tracing::trace!("MixTrafficController, client request channel closed");
break
}
},
}
},
None => {
trace!("MixTrafficController: Stopping since channel closed");
break;
}
}
self.task_client.recv_timeout().await;
tracing::debug!("MixTrafficController: Exiting");
},
"MixTrafficController"
);
},
client_request = self.client_rx.recv() => match client_request {
Some(client_request) => {
self.on_client_request(client_request).await;
},
None => {
trace!("MixTrafficController, client request channel closed");
break}
},
}
}
debug!("MixTrafficController: Exiting");
}
}
@@ -10,18 +10,17 @@ use nym_sphinx::{
acknowledgements::{identifier::recover_identifier, AckKey},
chunking::fragment::{FragmentIdentifier, COVER_FRAG_ID},
};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use std::sync::Arc;
use tracing::*;
/// Module responsible for listening for any data resembling acknowledgements from the network
/// and firing actions to remove them from the 'Pending' state.
pub(super) struct AcknowledgementListener {
pub(crate) struct AcknowledgementListener {
ack_key: Arc<AckKey>,
ack_receiver: AcknowledgementReceiver,
action_sender: AckActionSender,
stats_tx: ClientStatsSender,
task_client: TaskClient,
}
impl AcknowledgementListener {
@@ -30,14 +29,12 @@ impl AcknowledgementListener {
ack_receiver: AcknowledgementReceiver,
action_sender: AckActionSender,
stats_tx: ClientStatsSender,
task_client: TaskClient,
) -> Self {
AcknowledgementListener {
ack_key,
ack_receiver,
action_sender,
stats_tx,
task_client,
}
}
@@ -68,14 +65,9 @@ impl AcknowledgementListener {
trace!("Received {frag_id} from the mix network");
self.stats_tx
.report(PacketStatisticsEvent::RealAckReceived(ack_content.len()).into());
if let Err(err) = self
let _ = self
.action_sender
.unbounded_send(Action::new_remove(frag_id))
{
if !self.task_client.is_shutdown_poll() {
error!("Failed to send remove action to action controller: {err}");
}
}
.unbounded_send(Action::new_remove(frag_id));
}
async fn handle_ack_receiver_item(&mut self, item: Vec<Vec<u8>>) {
@@ -85,11 +77,16 @@ impl AcknowledgementListener {
}
}
pub(super) async fn run(&mut self) {
pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) {
debug!("Started AcknowledgementListener with graceful shutdown support");
while !self.task_client.is_shutdown() {
loop {
tokio::select! {
biased;
_ = shutdown_token.cancelled() => {
tracing::trace!("AcknowledgementListener: Received shutdown");
break;
}
acks = self.ack_receiver.next() => match acks {
Some(acks) => self.handle_ack_receiver_item(acks).await,
None => {
@@ -97,12 +94,9 @@ impl AcknowledgementListener {
break;
}
},
_ = self.task_client.recv() => {
tracing::trace!("AcknowledgementListener: Received shutdown");
}
}
}
self.task_client.recv_timeout().await;
tracing::debug!("AcknowledgementListener: Exiting");
}
}
@@ -8,7 +8,7 @@ use futures::StreamExt;
use nym_nonexhaustive_delayqueue::{Expired, NonExhaustiveDelayQueue, QueueKey};
use nym_sphinx::chunking::fragment::FragmentIdentifier;
use nym_sphinx::Delay as SphinxDelay;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
@@ -82,7 +82,7 @@ impl Config {
}
}
pub(super) struct ActionController {
pub(crate) struct ActionController {
/// Configurable parameters of the `ActionController`
config: Config,
@@ -102,8 +102,6 @@ pub(super) struct ActionController {
/// Channel for notifying `RetransmissionRequestListener` about expired acknowledgements.
retransmission_sender: RetransmissionRequestSender,
task_client: TaskClient,
}
impl ActionController {
@@ -111,7 +109,6 @@ impl ActionController {
config: Config,
retransmission_sender: RetransmissionRequestSender,
incoming_actions: AckActionReceiver,
task_client: TaskClient,
) -> Self {
ActionController {
config,
@@ -119,7 +116,6 @@ impl ActionController {
pending_acks_timers: NonExhaustiveDelayQueue::new(),
incoming_actions,
retransmission_sender,
task_client,
}
}
@@ -226,14 +222,9 @@ impl ActionController {
// downgrading an arc and then upgrading vs cloning is difference of 30ns vs 15ns
// so it's literally a NO difference while it might prevent us from unnecessarily
// resending data (in maybe 1 in 1 million cases, but it's something)
if let Err(err) = self
let _ = self
.retransmission_sender
.unbounded_send(Arc::downgrade(pending_ack_data))
{
if !self.task_client.is_shutdown_poll() {
tracing::error!("Failed to send pending ack for retransmission: {err}");
}
}
.unbounded_send(Arc::downgrade(pending_ack_data));
} else {
// this shouldn't cause any issues but shouldn't have happened to begin with!
error!("An already removed pending ack has expired")
@@ -251,11 +242,16 @@ impl ActionController {
}
}
pub(super) async fn run(&mut self) {
pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) {
debug!("Started ActionController with graceful shutdown support");
while !self.task_client.is_shutdown() {
loop {
tokio::select! {
biased;
_ = shutdown_token.cancelled() => {
tracing::trace!("ActionController: Received shutdown");
break;
}
action = self.incoming_actions.next() => match action {
Some(action) => self.process_action(action),
None => {
@@ -272,13 +268,8 @@ impl ActionController {
break;
}
},
_ = self.task_client.recv() => {
tracing::trace!("ActionController: Received shutdown");
break;
}
}
}
self.task_client.recv_timeout().await;
tracing::debug!("ActionController: Exiting");
}
}
@@ -10,21 +10,20 @@ use nym_sphinx::anonymous_replies::requests::AnonymousSenderTag;
use nym_sphinx::forwarding::packet::MixPacket;
use nym_sphinx::params::PacketType;
use nym_task::connections::TransmissionLane;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use rand::{CryptoRng, Rng};
use tracing::*;
/// Module responsible for dealing with the received messages: splitting them, creating acknowledgements,
/// putting everything into sphinx packets, etc.
/// It also makes an initial sending attempt for said messages.
pub(super) struct InputMessageListener<R>
pub(crate) struct InputMessageListener<R>
where
R: CryptoRng + Rng,
{
input_receiver: InputMessageReceiver,
message_handler: MessageHandler<R>,
reply_controller_sender: ReplyControllerSender,
task_client: TaskClient,
}
impl<R> InputMessageListener<R>
@@ -38,13 +37,11 @@ where
input_receiver: InputMessageReceiver,
message_handler: MessageHandler<R>,
reply_controller_sender: ReplyControllerSender,
task_client: TaskClient,
) -> Self {
InputMessageListener {
input_receiver,
message_handler,
reply_controller_sender,
task_client,
}
}
@@ -68,14 +65,9 @@ where
max_retransmissions: Option<u32>,
) {
// offload reply handling to the dedicated task
if let Err(err) =
let _ =
self.reply_controller_sender
.send_reply(recipient_tag, data, lane, max_retransmissions)
{
if !self.task_client.is_shutdown_poll() {
error!("failed to send a reply - {err}");
}
}
.send_reply(recipient_tag, data, lane, max_retransmissions);
}
async fn handle_plain_message(
@@ -221,13 +213,13 @@ where
};
}
pub(super) async fn run(&mut self) {
pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) {
debug!("Started InputMessageListener with graceful shutdown support");
while !self.task_client.is_shutdown() {
loop {
tokio::select! {
biased;
_ = self.task_client.recv() => {
_ = shutdown_token.cancelled() => {
tracing::trace!("InputMessageListener: Received shutdown");
break;
}
@@ -243,7 +235,6 @@ where
}
}
self.task_client.recv_timeout().await;
tracing::debug!("InputMessageListener: Exiting");
}
}
@@ -10,7 +10,6 @@ use self::{
use crate::client::inbound_messages::InputMessageReceiver;
use crate::client::real_messages_control::message_handler::MessageHandler;
use crate::client::replies::reply_controller::ReplyControllerSender;
use crate::spawn_future;
use action_controller::AckActionReceiver;
use futures::channel::mpsc;
use nym_gateway_client::AcknowledgementReceiver;
@@ -23,13 +22,11 @@ use nym_sphinx::{
Delay as SphinxDelay,
};
use nym_statistics_common::clients::ClientStatsSender;
use nym_task::TaskClient;
use rand::{CryptoRng, Rng};
use std::{
sync::{Arc, Weak},
time::Duration,
};
use tracing::*;
pub(crate) use action_controller::{AckActionSender, Action};
@@ -190,6 +187,9 @@ pub(super) struct Config {
/// Predefined packet size used for the encapsulated messages.
packet_size: PacketSize,
/// Type of packets used for retransmissions
packet_type: PacketType,
}
impl Config {
@@ -197,12 +197,14 @@ impl Config {
maximum_retransmissions: Option<u32>,
ack_wait_addition: Duration,
ack_wait_multiplier: f64,
packet_type: PacketType,
) -> Self {
Config {
maximum_retransmissions,
ack_wait_addition,
ack_wait_multiplier,
packet_size: Default::default(),
packet_type,
}
}
@@ -212,7 +214,7 @@ impl Config {
}
}
pub(super) struct AcknowledgementController<R>
pub(crate) struct AcknowledgementController<R>
where
R: CryptoRng + Rng,
{
@@ -234,7 +236,6 @@ where
message_handler: MessageHandler<R>,
reply_controller_sender: ReplyControllerSender,
stats_tx: ClientStatsSender,
task_client: TaskClient,
) -> Self {
let (retransmission_tx, retransmission_rx) = mpsc::unbounded();
@@ -244,7 +245,6 @@ where
action_config,
retransmission_tx,
connectors.ack_action_receiver,
task_client.fork("action_controller"),
);
// will listen for any acks coming from the network
@@ -253,7 +253,6 @@ where
connectors.ack_receiver,
connectors.ack_action_sender.clone(),
stats_tx,
task_client.fork("acknowledgement_listener"),
);
// will listen for any new messages from the client
@@ -261,7 +260,6 @@ where
connectors.input_receiver,
message_handler.clone(),
reply_controller_sender.clone(),
task_client.fork("input_message_listener"),
);
// will listen for any ack timeouts and trigger retransmission
@@ -271,16 +269,13 @@ where
message_handler,
retransmission_rx,
reply_controller_sender,
task_client.fork("retransmission_request_listener"),
config.packet_type,
);
// will listen for events indicating the packet was sent through the network so that
// the retransmission timer should be started.
let sent_notification_listener = SentNotificationListener::new(
connectors.sent_notifier,
connectors.ack_action_sender,
task_client.with_suffix("sent_notification_listener"),
);
let sent_notification_listener =
SentNotificationListener::new(connectors.sent_notifier, connectors.ack_action_sender);
AcknowledgementController {
acknowledgement_listener,
@@ -291,51 +286,21 @@ where
}
}
pub(super) fn start(self, packet_type: PacketType) {
let mut acknowledgement_listener = self.acknowledgement_listener;
let mut input_message_listener = self.input_message_listener;
let mut retransmission_request_listener = self.retransmission_request_listener;
let mut sent_notification_listener = self.sent_notification_listener;
let mut action_controller = self.action_controller;
spawn_future!(
async move {
acknowledgement_listener.run().await;
debug!("The acknowledgement listener has finished execution!");
},
"AcknowledgementController::AcknowledgementListener"
);
spawn_future!(
async move {
input_message_listener.run().await;
debug!("The input listener has finished execution!");
},
"AcknowledgementController::InputMessageListener"
);
spawn_future!(
async move {
retransmission_request_listener.run(packet_type).await;
debug!("The retransmission request listener has finished execution!");
},
"AcknowledgementController::RetransmissionRequestListener"
);
spawn_future!(
async move {
sent_notification_listener.run().await;
debug!("The sent notification listener has finished execution!");
},
"AcknowledgementController::SentNotificationListener"
);
spawn_future!(
async move {
action_controller.run().await;
debug!("The controller has finished execution!");
},
"AcknowledgementController::ActionController"
);
pub(crate) fn into_tasks(
self,
) -> (
AcknowledgementListener,
InputMessageListener<R>,
RetransmissionRequestListener<R>,
SentNotificationListener,
ActionController,
) {
(
self.acknowledgement_listener,
self.input_message_listener,
self.retransmission_request_listener,
self.sent_notification_listener,
self.action_controller,
)
}
}
@@ -13,19 +13,19 @@ use futures::StreamExt;
use nym_sphinx::chunking::fragment::Fragment;
use nym_sphinx::preparer::PreparedFragment;
use nym_sphinx::{addressing::clients::Recipient, params::PacketType};
use nym_task::{connections::TransmissionLane, TaskClient};
use nym_task::{connections::TransmissionLane, ShutdownToken};
use rand::{CryptoRng, Rng};
use std::sync::{Arc, Weak};
use tracing::*;
// responsible for packet retransmission upon fired timer
pub(super) struct RetransmissionRequestListener<R> {
pub(crate) struct RetransmissionRequestListener<R> {
maximum_retransmissions: Option<u32>,
action_sender: AckActionSender,
message_handler: MessageHandler<R>,
request_receiver: RetransmissionRequestReceiver,
reply_controller_sender: ReplyControllerSender,
task_client: TaskClient,
packet_type: PacketType,
}
impl<R> RetransmissionRequestListener<R>
@@ -38,7 +38,7 @@ where
message_handler: MessageHandler<R>,
request_receiver: RetransmissionRequestReceiver,
reply_controller_sender: ReplyControllerSender,
task_client: TaskClient,
packet_type: PacketType,
) -> Self {
RetransmissionRequestListener {
maximum_retransmissions,
@@ -46,7 +46,7 @@ where
message_handler,
request_receiver,
reply_controller_sender,
task_client,
packet_type,
}
}
@@ -67,7 +67,6 @@ where
async fn on_retransmission_request(
&mut self,
weak_timed_out_ack: Weak<PendingAcknowledgement>,
packet_type: PacketType,
) {
let timed_out_ack = match weak_timed_out_ack.upgrade() {
Some(timed_out_ack) => timed_out_ack,
@@ -97,22 +96,18 @@ where
} => {
// if this is retransmission for reply, offload it to the dedicated task
// that deals with all the surbs
if let Err(err) = self.reply_controller_sender.send_retransmission_data(
let _ = self.reply_controller_sender.send_retransmission_data(
*recipient_tag,
weak_timed_out_ack,
*extra_surb_request,
) {
if !self.task_client.is_shutdown_poll() {
error!("Failed to send retransmission data to the reply controller: {err}");
}
}
);
return;
}
PacketDestination::KnownRecipient(recipient) => {
self.prepare_normal_retransmission_chunk(
**recipient,
timed_out_ack.message_chunk.clone(),
packet_type,
self.packet_type,
)
.await
}
@@ -153,14 +148,9 @@ where
// is sent to the `OutQueueControl` and has gone through its internal queue
// with the additional poisson delay.
// And since Actions are executed in order `UpdateTimer` will HAVE TO be executed before `StartTimer`
if let Err(err) = self
let _ = self
.action_sender
.unbounded_send(Action::new_update_pending_ack(frag_id, new_delay))
{
if !self.task_client.is_shutdown_poll() {
error!("Failed to send update pending ack action to the controller: {err}");
}
}
.unbounded_send(Action::new_update_pending_ack(frag_id, new_delay));
// send to `OutQueueControl` to eventually send to the mix network
self.message_handler
@@ -174,18 +164,18 @@ where
.await
}
pub(super) async fn run(&mut self, packet_type: PacketType) {
pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) {
debug!("Started RetransmissionRequestListener with graceful shutdown support");
while !self.task_client.is_shutdown() {
loop {
tokio::select! {
biased;
_ = self.task_client.recv() => {
_ = shutdown_token.cancelled() => {
tracing::trace!("RetransmissionRequestListener: Received shutdown");
break;
}
timed_out_ack = self.request_receiver.next() => match timed_out_ack {
Some(timed_out_ack) => self.on_retransmission_request(timed_out_ack, packet_type).await,
Some(timed_out_ack) => self.on_retransmission_request(timed_out_ack).await,
None => {
tracing::trace!("RetransmissionRequestListener: Stopping since channel closed");
break;
@@ -194,7 +184,6 @@ where
}
}
self.task_client.recv_timeout().await;
tracing::debug!("RetransmissionRequestListener: Exiting");
}
}
@@ -5,29 +5,25 @@ use super::action_controller::{AckActionSender, Action};
use super::SentPacketNotificationReceiver;
use futures::StreamExt;
use nym_sphinx::chunking::fragment::{FragmentIdentifier, COVER_FRAG_ID};
use nym_task::TaskClient;
use tracing::*;
/// Module responsible for starting up retransmission timers.
/// It is required because when we send our packet to the `real traffic stream` controlled
/// by a poisson timer, there's no guarantee the message will be sent immediately, so we might
/// accidentally fire retransmission way quicker than we should have.
pub(super) struct SentNotificationListener {
pub(crate) struct SentNotificationListener {
sent_notifier: SentPacketNotificationReceiver,
action_sender: AckActionSender,
task_client: TaskClient,
}
impl SentNotificationListener {
pub(super) fn new(
sent_notifier: SentPacketNotificationReceiver,
action_sender: AckActionSender,
task_client: TaskClient,
) -> Self {
SentNotificationListener {
sent_notifier,
action_sender,
task_client,
}
}
@@ -36,37 +32,18 @@ impl SentNotificationListener {
trace!("sent off a cover message - no need to start retransmission timer!");
return;
}
if let Err(err) = self
let _ = self
.action_sender
.unbounded_send(Action::new_start_timer(frag_id))
{
if !self.task_client.is_shutdown_poll() {
error!("Failed to send start timer action to action controller: {err}");
}
}
.unbounded_send(Action::new_start_timer(frag_id));
}
pub(super) async fn run(&mut self) {
pub(crate) async fn run(&mut self) {
debug!("Started SentNotificationListener with graceful shutdown support");
while !self.task_client.is_shutdown() {
tokio::select! {
frag_id = self.sent_notifier.next() => match frag_id {
Some(frag_id) => {
self.on_sent_message(frag_id).await;
}
None => {
tracing::trace!("SentNotificationListener: Stopping since channel closed");
break;
}
},
_ = self.task_client.recv() => {
tracing::trace!("SentNotificationListener: Received shutdown");
break;
}
}
while let Some(frag_id) = self.sent_notifier.next().await {
self.on_sent_message(frag_id).await;
}
assert!(self.task_client.is_shutdown_poll());
tracing::debug!("SentNotificationListener: Exiting");
}
}
@@ -20,7 +20,7 @@ use nym_sphinx::params::{PacketSize, PacketType};
use nym_sphinx::preparer::{MessagePreparer, PreparedFragment};
use nym_sphinx::Delay;
use nym_task::connections::TransmissionLane;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use nym_topology::{NymRouteProvider, NymTopologyError};
use rand::{CryptoRng, Rng};
use std::collections::HashMap;
@@ -189,7 +189,7 @@ pub(crate) struct MessageHandler<R> {
topology_access: TopologyAccessor,
reply_key_storage: SentReplyKeys,
tag_storage: UsedSenderTags,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl<R> MessageHandler<R>
@@ -205,7 +205,7 @@ where
topology_access: TopologyAccessor,
reply_key_storage: SentReplyKeys,
tag_storage: UsedSenderTags,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self
where
R: Copy,
@@ -228,7 +228,7 @@ where
topology_access,
reply_key_storage,
tag_storage,
task_client,
shutdown_token,
}
}
@@ -712,7 +712,7 @@ where
.action_sender
.unbounded_send(Action::UpdatePendingAck(id, new_delay))
{
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!("Failed to send update action to the controller: {err}");
}
}
@@ -723,7 +723,7 @@ where
.action_sender
.unbounded_send(Action::new_insert(pending_acks))
{
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!("Failed to send insert action to the controller: {err}");
}
}
@@ -737,7 +737,7 @@ where
) {
tokio::select! {
biased;
_ = self.task_client.recv() => {
_ = self.shutdown_token.cancelled() => {
trace!("received shutdown while attempting to forward mixnet messages");
}
sending_res = self.real_message_sender.send((messages, transmission_lane)) => {
@@ -14,26 +14,21 @@ use crate::client::replies::reply_controller::{
ReplyController, ReplyControllerReceiver, ReplyControllerSender,
};
use crate::client::replies::reply_storage::CombinedReplyStorage;
use crate::config;
use crate::{
client::{
inbound_messages::InputMessageReceiver, mix_traffic::BatchMixMessageSender,
real_messages_control::acknowledgement_control::AcknowledgementControllerConnectors,
topology_control::TopologyAccessor,
},
spawn_future,
use crate::client::{
inbound_messages::InputMessageReceiver, mix_traffic::BatchMixMessageSender,
real_messages_control::acknowledgement_control::AcknowledgementControllerConnectors,
topology_control::TopologyAccessor,
};
use crate::config;
use futures::channel::mpsc;
use nym_gateway_client::AcknowledgementReceiver;
use nym_sphinx::acknowledgements::AckKey;
use nym_sphinx::addressing::clients::Recipient;
use nym_sphinx::params::PacketType;
use nym_statistics_common::clients::ClientStatsSender;
use nym_task::connections::{ConnectionCommandReceiver, LaneQueueLengths};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use rand::{rngs::OsRng, CryptoRng, Rng};
use std::sync::Arc;
use tracing::*;
use crate::client::replies::reply_controller::key_rotation_helpers::KeyRotationConfig;
pub(crate) use acknowledgement_control::{AckActionSender, Action};
@@ -69,6 +64,7 @@ impl<'a> From<&'a Config> for acknowledgement_control::Config {
cfg.traffic.maximum_number_of_retransmissions,
cfg.acks.ack_wait_addition,
cfg.acks.ack_wait_multiplier,
cfg.traffic.packet_type,
)
.with_custom_packet_size(cfg.traffic.primary_packet_size)
}
@@ -146,7 +142,7 @@ impl RealMessagesController<OsRng> {
lane_queue_lengths: LaneQueueLengths,
client_connection_rx: ConnectionCommandReceiver,
stats_tx: ClientStatsSender,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self {
let rng = OsRng;
@@ -178,7 +174,7 @@ impl RealMessagesController<OsRng> {
topology_access.clone(),
reply_storage.key_storage(),
reply_storage.tags_storage(),
task_client.fork("message_handler"),
shutdown_token.clone(),
);
let ack_control = AcknowledgementController::new(
@@ -188,7 +184,6 @@ impl RealMessagesController<OsRng> {
message_handler.clone(),
reply_controller_sender,
stats_tx.clone(),
task_client.fork("ack_control"),
);
let reply_control = ReplyController::new(
@@ -196,7 +191,6 @@ impl RealMessagesController<OsRng> {
message_handler,
reply_storage,
reply_controller_receiver,
task_client.fork("reply_controller"),
);
let out_queue_control = OutQueueControl::new(
@@ -209,7 +203,7 @@ impl RealMessagesController<OsRng> {
lane_queue_lengths,
client_connection_rx,
stats_tx,
task_client.with_suffix("out_queue_control"),
shutdown_token.clone(),
);
RealMessagesController {
@@ -219,26 +213,13 @@ impl RealMessagesController<OsRng> {
}
}
pub fn start(self, packet_type: PacketType) {
let mut out_queue_control = self.out_queue_control;
let ack_control = self.ack_control;
let mut reply_control = self.reply_control;
spawn_future!(
async move {
out_queue_control.run().await;
debug!("The out queue controller has finished execution!");
},
"RealMessagesController::OutQueueControl)"
);
spawn_future!(
async move {
reply_control.run().await;
debug!("The reply controller has finished execution!");
},
"RealMessagesController::ReplyController"
);
ack_control.start(packet_type);
pub fn into_tasks(
self,
) -> (
OutQueueControl<OsRng>,
ReplyController<OsRng>,
AcknowledgementController<OsRng>,
) {
(self.out_queue_control, self.reply_control, self.ack_control)
}
}
@@ -21,7 +21,7 @@ use nym_statistics_common::clients::{packet_statistics::PacketStatisticsEvent, C
use nym_task::connections::{
ConnectionCommand, ConnectionCommandReceiver, ConnectionId, LaneQueueLengths, TransmissionLane,
};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use rand::{CryptoRng, Rng};
use std::pin::Pin;
use std::sync::Arc;
@@ -119,7 +119,7 @@ where
/// Channel used for sending metrics events (specifically `PacketStatistics` events) to the metrics tracker.
stats_tx: ClientStatsSender,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
#[derive(Debug)]
@@ -179,7 +179,7 @@ where
lane_queue_lengths: LaneQueueLengths,
client_connection_rx: ConnectionCommandReceiver,
stats_tx: ClientStatsSender,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self {
OutQueueControl {
config,
@@ -194,7 +194,7 @@ where
client_connection_rx,
lane_queue_lengths,
stats_tx,
task_client,
shutdown_token,
}
}
@@ -282,7 +282,7 @@ where
let sending_res = tokio::select! {
biased;
_ = self.task_client.recv() => {
_ = self.shutdown_token.cancelled() => {
trace!("received shutdown signal while attempting to send mix message");
return
}
@@ -293,7 +293,7 @@ where
match sending_res {
Err(_) => {
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
tracing::error!(
"failed to send mixnet packet due to closed channel (outside of shutdown!)"
);
@@ -536,9 +536,7 @@ where
}
#[cfg(not(target_arch = "wasm32"))]
fn log_status(&self, shutdown: &mut TaskClient) {
use crate::error::ClientCoreStatusMessage;
fn log_status(&self) {
let packets = self.transmission_buffer.total_size();
let lanes = self.transmission_buffer.lanes();
let mult = self.sending_delay_controller.current_multiplier();
@@ -567,32 +565,33 @@ where
tracing::debug!("{status_str}");
}
// Send status message to whoever is listening (possibly UI)
if mult == self.sending_delay_controller.max_multiplier() {
shutdown.send_status_msg(Box::new(ClientCoreStatusMessage::GatewayIsVerySlow));
} else if mult > self.sending_delay_controller.min_multiplier() {
shutdown.send_status_msg(Box::new(ClientCoreStatusMessage::GatewayIsSlow));
}
// leave the code commented in case somebody wanted to restore this logic with a different channel
// // Send status message to whoever is listening (possibly UI)
// if mult == self.sending_delay_controller.max_multiplier() {
// shutdown.send_status_msg(Box::new(ClientCoreStatusMessage::GatewayIsVerySlow));
// } else if mult > self.sending_delay_controller.min_multiplier() {
// shutdown.send_status_msg(Box::new(ClientCoreStatusMessage::GatewayIsSlow));
// }
}
pub(super) async fn run(&mut self) {
pub(crate) async fn run(&mut self) {
debug!("Started OutQueueControl with graceful shutdown support");
let mut shutdown = self.task_client.fork("select");
// avoid borrow on self
let shutdown_token = self.shutdown_token.clone();
#[cfg(not(target_arch = "wasm32"))]
{
let mut status_timer = tokio::time::interval(Duration::from_secs(5));
while !shutdown.is_shutdown() {
loop {
tokio::select! {
biased;
_ = shutdown.recv() => {
_ = shutdown_token.cancelled() => {
tracing::trace!("OutQueueControl: Received shutdown");
break;
}
_ = status_timer.tick() => {
self.log_status(&mut shutdown);
self.log_status();
}
next_message = self.next() => if let Some(next_message) = next_message {
self.on_message(next_message).await;
@@ -602,16 +601,16 @@ where
}
}
}
shutdown.recv_timeout().await;
}
#[cfg(target_arch = "wasm32")]
{
while !shutdown.is_shutdown() {
loop {
tokio::select! {
biased;
_ = shutdown.recv() => {
_ = shutdown_token.cancelled() => {
tracing::trace!("OutQueueControl: Received shutdown");
break;
}
next_message = self.next() => if let Some(next_message) = next_message {
self.on_message(next_message).await;
@@ -83,11 +83,13 @@ impl SendingDelayController {
self.current_multiplier
}
#[allow(dead_code)]
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn min_multiplier(&self) -> u32 {
self.lower_bound
}
#[allow(dead_code)]
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn max_multiplier(&self) -> u32 {
self.upper_bound
@@ -5,7 +5,6 @@ use crate::client::helpers::get_time_now;
use crate::client::replies::{
reply_controller::ReplyControllerSender, reply_storage::SentReplyKeys,
};
use crate::spawn_future;
use futures::channel::mpsc;
use futures::lock::Mutex;
use futures::StreamExt;
@@ -20,7 +19,7 @@ use nym_sphinx::message::{NymMessage, PlainMessage};
use nym_sphinx::params::ReplySurbKeyDigestAlgorithm;
use nym_sphinx::receiver::{MessageReceiver, MessageRecoveryError, ReconstructedMessage};
use nym_statistics_common::clients::{packet_statistics::PacketStatisticsEvent, ClientStatsSender};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
@@ -172,7 +171,7 @@ struct ReceivedMessagesBuffer<R: MessageReceiver> {
inner: Arc<Mutex<ReceivedMessagesBufferInner<R>>>,
reply_key_storage: SentReplyKeys,
reply_controller_sender: ReplyControllerSender,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl<R: MessageReceiver> ReceivedMessagesBuffer<R> {
@@ -181,7 +180,7 @@ impl<R: MessageReceiver> ReceivedMessagesBuffer<R> {
reply_key_storage: SentReplyKeys,
reply_controller_sender: ReplyControllerSender,
stats_tx: ClientStatsSender,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self {
ReceivedMessagesBuffer {
inner: Arc::new(Mutex::new(ReceivedMessagesBufferInner {
@@ -195,7 +194,7 @@ impl<R: MessageReceiver> ReceivedMessagesBuffer<R> {
})),
reply_key_storage,
reply_controller_sender,
task_client,
shutdown_token,
}
}
@@ -316,7 +315,7 @@ impl<R: MessageReceiver> ReceivedMessagesBuffer<R> {
reply_surbs,
from_surb_request,
) {
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!("{err}");
}
}
@@ -339,7 +338,7 @@ impl<R: MessageReceiver> ReceivedMessagesBuffer<R> {
.reply_controller_sender
.send_additional_surbs_request(*recipient, amount)
{
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
error!("{err}");
}
}
@@ -466,22 +465,22 @@ pub enum ReceivedBufferMessage {
ReceiverDisconnect,
}
struct RequestReceiver<R: MessageReceiver> {
pub(crate) struct RequestReceiver<R: MessageReceiver> {
received_buffer: ReceivedMessagesBuffer<R>,
query_receiver: ReceivedBufferRequestReceiver,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl<R: MessageReceiver> RequestReceiver<R> {
fn new(
received_buffer: ReceivedMessagesBuffer<R>,
query_receiver: ReceivedBufferRequestReceiver,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self {
RequestReceiver {
received_buffer,
query_receiver,
task_client,
shutdown_token,
}
}
@@ -496,66 +495,70 @@ impl<R: MessageReceiver> RequestReceiver<R> {
}
}
async fn run(&mut self) {
pub(crate) async fn run(&mut self) {
debug!("Started RequestReceiver with graceful shutdown support");
while !self.task_client.is_shutdown() {
loop {
tokio::select! {
biased;
_ = self.task_client.recv() => {
_ = self.shutdown_token.cancelled() => {
tracing::trace!("RequestReceiver: Received shutdown");
break;
}
request = self.query_receiver.next() => {
if let Some(message) = request {
self.handle_message(message).await
} else {
tracing::trace!("RequestReceiver: Stopping since channel closed");
self.shutdown_token.cancelled().await;
break;
}
},
}
}
self.task_client.recv().await;
tracing::debug!("RequestReceiver: Exiting");
}
}
struct FragmentedMessageReceiver<R: MessageReceiver> {
pub(crate) struct FragmentedMessageReceiver<R: MessageReceiver> {
received_buffer: ReceivedMessagesBuffer<R>,
mixnet_packet_receiver: MixnetMessageReceiver,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl<R: MessageReceiver> FragmentedMessageReceiver<R> {
fn new(
received_buffer: ReceivedMessagesBuffer<R>,
mixnet_packet_receiver: MixnetMessageReceiver,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self {
FragmentedMessageReceiver {
received_buffer,
mixnet_packet_receiver,
task_client,
shutdown_token,
}
}
async fn run(&mut self) -> Result<(), MessageRecoveryError> {
pub(crate) async fn run(&mut self) -> Result<(), MessageRecoveryError> {
debug!("Started FragmentedMessageReceiver with graceful shutdown support");
while !self.task_client.is_shutdown() {
loop {
tokio::select! {
biased;
_ = self.shutdown_token.cancelled() => {
tracing::trace!("FragmentedMessageReceiver: Received shutdown");
break;
}
new_messages = self.mixnet_packet_receiver.next() => {
if let Some(new_messages) = new_messages {
self.received_buffer.handle_new_received(new_messages).await?;
} else {
tracing::trace!("FragmentedMessageReceiver: Stopping since channel closed");
self.shutdown_token.cancelled().await;
break;
}
},
_ = self.task_client.recv_with_delay() => {
tracing::trace!("FragmentedMessageReceiver: Received shutdown");
}
}
}
self.task_client.recv_timeout().await;
tracing::debug!("FragmentedMessageReceiver: Exiting");
Ok(())
}
@@ -574,48 +577,31 @@ impl<R: MessageReceiver + Clone + Send + 'static> ReceivedMessagesBufferControll
reply_key_storage: SentReplyKeys,
reply_controller_sender: ReplyControllerSender,
metrics_reporter: ClientStatsSender,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self {
let received_buffer = ReceivedMessagesBuffer::new(
local_encryption_keypair,
reply_key_storage,
reply_controller_sender,
metrics_reporter,
task_client.fork("received_messages_buffer"),
shutdown_token.clone(),
);
ReceivedMessagesBufferController {
fragmented_message_receiver: FragmentedMessageReceiver::new(
received_buffer.clone(),
mixnet_packet_receiver,
task_client.fork("fragmented_message_receiver"),
shutdown_token.clone(),
),
request_receiver: RequestReceiver::new(
received_buffer,
query_receiver,
task_client.with_suffix("request_receiver"),
shutdown_token.clone(),
),
}
}
pub fn start(self) {
let mut fragmented_message_receiver = self.fragmented_message_receiver;
let mut request_receiver = self.request_receiver;
spawn_future!(
async move {
match fragmented_message_receiver.run().await {
Ok(_) => {}
Err(e) => error!("{e}"),
}
},
"ReceivedMessagesBufferController::FragmentedMessageReceiver"
);
spawn_future!(
async move {
request_receiver.run().await;
},
"ReceivedMessagesBufferController::RequestReceiver"
);
pub(crate) fn into_tasks(self) -> (FragmentedMessageReceiver<R>, RequestReceiver<R>) {
(self.fragmented_message_receiver, self.request_receiver)
}
}
@@ -7,7 +7,7 @@ use crate::client::replies::reply_controller::key_rotation_helpers::KeyRotationC
use crate::client::replies::reply_storage::CombinedReplyStorage;
use crate::config;
use futures::StreamExt;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use rand::rngs::OsRng;
use rand::{CryptoRng, Rng};
use std::time::Duration;
@@ -60,9 +60,6 @@ pub struct ReplyController<R> {
receiver_controller: ReceiverReplyController<R>,
request_receiver: ReplyControllerReceiver,
// Listen for shutdown signals
task_client: TaskClient,
}
impl ReplyController<OsRng> {
@@ -71,7 +68,6 @@ impl ReplyController<OsRng> {
message_handler: MessageHandler<OsRng>,
full_reply_storage: CombinedReplyStorage,
request_receiver: ReplyControllerReceiver,
task_client: TaskClient,
) -> Self {
ReplyController {
config,
@@ -86,7 +82,6 @@ impl ReplyController<OsRng> {
message_handler,
),
request_receiver,
task_client,
}
}
}
@@ -148,22 +143,21 @@ where
self.sender_controller.inspect_and_clear_stale_data(now)
}
pub(crate) async fn run(&mut self) {
pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) {
debug!("Started ReplyController with graceful shutdown support");
let mut shutdown = self.task_client.fork("reply-controller");
let polling_rate = Duration::from_secs(5);
let mut stale_inspection = new_interval_stream(polling_rate);
let polling_rate = self.config.key_rotation.epoch_duration / 8;
let mut invalidation_inspection = new_interval_stream(polling_rate);
while !shutdown.is_shutdown() {
loop {
tokio::select! {
biased;
_ = shutdown.recv() => {
_ = shutdown_token.cancelled() => {
tracing::trace!("ReplyController: Received shutdown");
break;
},
req = self.request_receiver.next() => match req {
Some(req) => self.handle_request(req).await,
@@ -181,7 +175,6 @@ where
}
}
}
assert!(shutdown.is_shutdown_poll());
tracing::debug!("ReplyController: Exiting");
}
}
@@ -16,21 +16,17 @@
#![warn(clippy::todo)]
#![warn(clippy::dbg_macro)]
use crate::client::inbound_messages::{InputMessage, InputMessageSender};
use futures::StreamExt;
use nym_client_core_config_types::StatsReporting;
use nym_sphinx::addressing::Recipient;
use nym_statistics_common::clients::{
ClientStatsController, ClientStatsReceiver, ClientStatsSender,
};
use nym_task::{connections::TransmissionLane, TaskClient};
use nym_task::{connections::TransmissionLane, ShutdownToken, ShutdownTracker};
use std::time::Duration;
use crate::{
client::inbound_messages::{InputMessage, InputMessageSender},
spawn_future,
};
/// Time interval between reporting statistics locally (logging/task_client)
/// Time interval between reporting statistics locally (logging/shutdown_token)
const LOCAL_REPORT_INTERVAL: Duration = Duration::from_secs(2);
/// Interval for taking snapshots of the statistics
const SNAPSHOT_INTERVAL: Duration = Duration::from_millis(500);
@@ -51,9 +47,6 @@ pub(crate) struct StatisticsControl {
/// Config for stats reporting (enabled, address, interval)
reporting_config: StatsReporting,
/// Task client for listening for shutdown
task_client: TaskClient,
}
impl StatisticsControl {
@@ -62,24 +55,20 @@ impl StatisticsControl {
client_type: String,
client_stats_id: String,
report_tx: InputMessageSender,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> (Self, ClientStatsSender) {
let (stats_tx, stats_rx) = tokio::sync::mpsc::unbounded_channel();
let stats = ClientStatsController::new(client_stats_id, client_type);
let mut task_client_stats_sender = task_client.fork("stats_sender");
task_client_stats_sender.disarm();
(
StatisticsControl {
stats,
stats_rx,
report_tx,
reporting_config,
task_client,
},
ClientStatsSender::new(Some(stats_tx), task_client_stats_sender),
ClientStatsSender::new(Some(stats_tx), shutdown_token),
)
}
@@ -99,7 +88,8 @@ impl StatisticsControl {
}
}
async fn run(&mut self) {
// manually control the shutdown mechanism as we don't want to get interrupted mid-snapshot
pub async fn run(&mut self, shutdown_token: ShutdownToken) {
tracing::debug!("Started StatisticsControl with graceful shutdown support");
#[cfg(not(target_arch = "wasm32"))]
@@ -129,10 +119,10 @@ impl StatisticsControl {
let mut snapshot_interval =
gloo_timers::future::IntervalStream::new(SNAPSHOT_INTERVAL.as_millis() as u32);
while !self.task_client.is_shutdown() {
loop {
tokio::select! {
biased;
_ = self.task_client.recv() => {
_ = shutdown_token.cancelled() => {
tracing::trace!("StatisticsControl: Received shutdown");
break;
},
@@ -157,37 +147,34 @@ impl StatisticsControl {
}
_ = local_report_interval.next() => {
self.stats.local_report(&mut self.task_client);
self.stats.local_report();
}
}
}
tracing::debug!("StatisticsControl: Exiting");
}
pub(crate) fn start(mut self) {
spawn_future!(
async move {
self.run().await;
},
"StatisticsControl"
)
}
pub(crate) fn create_and_start(
reporting_config: StatsReporting,
client_type: String,
client_stats_id: String,
report_tx: InputMessageSender,
task_client: TaskClient,
shutdown_tracker: &ShutdownTracker,
) -> ClientStatsSender {
let (controller, sender) = Self::create(
let (mut controller, sender) = Self::create(
reporting_config,
client_type,
client_stats_id,
report_tx,
task_client,
shutdown_tracker.child_shutdown_token(),
);
let shutdown_token = shutdown_tracker.clone_shutdown_token();
shutdown_tracker.try_spawn_named(
async move {
controller.run(shutdown_token).await;
},
"StatisticsControl",
);
controller.start();
sender
}
}
@@ -1,11 +1,9 @@
// Copyright 2021-2023 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: Apache-2.0
use crate::spawn_future;
pub(crate) use accessor::{TopologyAccessor, TopologyReadPermit};
use futures::StreamExt;
use nym_sphinx::addressing::nodes::NodeIdentity;
use nym_task::TaskClient;
use nym_topology::NymTopologyError;
use std::time::Duration;
use tracing::*;
@@ -41,8 +39,6 @@ pub struct TopologyRefresher {
refresh_rate: Duration,
consecutive_failure_count: usize,
task_client: TaskClient,
}
impl TopologyRefresher {
@@ -50,14 +46,12 @@ impl TopologyRefresher {
cfg: TopologyRefresherConfig,
topology_accessor: TopologyAccessor,
topology_provider: Box<dyn TopologyProvider + Send + Sync>,
task_client: TaskClient,
) -> Self {
TopologyRefresher {
topology_provider,
topology_accessor,
refresh_rate: cfg.refresh_rate,
consecutive_failure_count: 0,
task_client,
}
}
@@ -144,40 +138,30 @@ impl TopologyRefresher {
}
}
pub fn start(mut self) {
spawn_future!(
async move {
debug!("Started TopologyRefresher with graceful shutdown support");
// it's perfectly fine if task is interrupted mid-refresh
// there's no data to persist or send over
pub async fn run(&mut self) {
debug!("Started TopologyRefresher with graceful shutdown support");
#[cfg(not(target_arch = "wasm32"))]
let mut interval = tokio_stream::wrappers::IntervalStream::new(
tokio::time::interval(self.refresh_rate),
);
#[cfg(not(target_arch = "wasm32"))]
let mut interval =
tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(self.refresh_rate));
#[cfg(target_arch = "wasm32")]
let mut interval =
gloo_timers::future::IntervalStream::new(self.refresh_rate.as_millis() as u32);
#[cfg(target_arch = "wasm32")]
let mut interval =
gloo_timers::future::IntervalStream::new(self.refresh_rate.as_millis() as u32);
// We already have an initial topology, so no need to refresh it immediately.
// My understanding is that js setInterval does not fire immediately, so it's not
// needed there.
#[cfg(not(target_arch = "wasm32"))]
interval.next().await;
// We already have an initial topology, so no need to refresh it immediately.
// My understanding is that js setInterval does not fire immediately, so it's not
// needed there.
#[cfg(not(target_arch = "wasm32"))]
interval.next().await;
while !self.task_client.is_shutdown() {
tokio::select! {
_ = interval.next() => {
self.try_refresh().await;
},
_ = self.task_client.recv() => {
tracing::trace!("TopologyRefresher: Received shutdown");
},
}
}
self.task_client.recv_timeout().await;
tracing::debug!("TopologyRefresher: Exiting");
},
"TopologyRefresher"
)
while interval.next().await.is_some() {
self.try_refresh().await;
}
// this should never get triggered
error!("topology refresher interval has been exhausted!")
}
}
+3 -35
View File
@@ -17,7 +17,9 @@ pub use nym_topology::{
HardcodedTopologyProvider, NymRouteProvider, NymTopology, NymTopologyError, TopologyProvider,
};
#[deprecated(note = "use spawn_future from nym_task crate instead")]
#[cfg(target_arch = "wasm32")]
#[track_caller]
pub fn spawn_future<F>(future: F)
where
F: Future<Output = ()> + 'static,
@@ -25,9 +27,7 @@ where
wasm_bindgen_futures::spawn_local(future);
}
// TODO: expose similar API to the rest of the codebase,
// perhaps with some simple trait for a task to define its name
#[deprecated(note = "use spawn_future from nym_task crate instead")]
#[cfg(not(target_arch = "wasm32"))]
#[track_caller]
pub fn spawn_future<F>(future: F)
@@ -37,35 +37,3 @@ where
{
tokio::spawn(future);
}
#[cfg(not(target_arch = "wasm32"))]
#[track_caller]
pub fn spawn_named_future<F>(future: F, name: &str)
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
cfg_if::cfg_if! {if #[cfg(tokio_unstable)] {
#[allow(clippy::expect_used)]
tokio::task::Builder::new().name(name).spawn(future).expect("failed to spawn future");
} else {
let _ = name;
tracing::debug!(r#"the underlying binary hasn't been built with `RUSTFLAGS="--cfg tokio_unstable"` - the future naming won't do anything"#);
spawn_future(future);
}}
}
#[macro_export]
macro_rules! spawn_future {
($future:expr) => {{
$crate::spawn_future($future)
}};
($future:expr, $name:expr) => {{
cfg_if::cfg_if! {if #[cfg(not(target_arch = "wasm32"))] {
$crate::spawn_named_future($future, $name)
} else {
let _ = $name;
$crate::spawn_future($future)
}}
}};
}
+2 -2
View File
@@ -40,7 +40,7 @@ where
pub async fn flush_on_shutdown(
mut self,
mem_state: CombinedReplyStorage,
mut shutdown: nym_task::TaskClient,
shutdown: nym_task::ShutdownToken,
) {
use tracing::{debug, error, info};
@@ -50,7 +50,7 @@ where
return;
}
shutdown.recv().await;
shutdown.cancelled().await;
info!("PersistentReplyStorage is flushing all reply-related data to underlying storage");
if let Err(err) = self.backend.flush_surb_storage(&mem_state).await {
@@ -12,7 +12,7 @@ use crate::socket_state::{ws_fd, PartiallyDelegatedHandle, SocketState};
use crate::traits::GatewayPacketRouter;
use crate::{cleanup_socket_message, try_decrypt_binary_message};
use futures::{SinkExt, StreamExt};
use nym_bandwidth_controller::{BandwidthController, BandwidthStatusMessage};
use nym_bandwidth_controller::BandwidthController;
use nym_credential_storage::ephemeral_storage::EphemeralStorage as EphemeralCredentialStorage;
use nym_credential_storage::storage::Storage as CredentialStorage;
use nym_credentials::CredentialSpendingData;
@@ -27,7 +27,7 @@ use nym_gateway_requests::{
use nym_sphinx::forwarding::packet::MixPacket;
use nym_statistics_common::clients::connection::ConnectionStatsEvent;
use nym_statistics_common::clients::ClientStatsSender;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use nym_validator_client::nyxd::contract_traits::DkgQueryClient;
use rand::rngs::OsRng;
use std::sync::Arc;
@@ -109,7 +109,7 @@ pub struct GatewayClient<C, St = EphemeralCredentialStorage> {
connection_fd_callback: Option<Arc<dyn Fn(RawFd) + Send + Sync>>,
/// Listen to shutdown messages and send notifications back to the task manager
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl<C, St> GatewayClient<C, St> {
@@ -124,7 +124,7 @@ impl<C, St> GatewayClient<C, St> {
bandwidth_controller: Option<BandwidthController<C, St>>,
stats_reporter: ClientStatsSender,
#[cfg(unix)] connection_fd_callback: Option<Arc<dyn Fn(RawFd) + Send + Sync>>,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self {
GatewayClient {
cfg,
@@ -141,7 +141,7 @@ impl<C, St> GatewayClient<C, St> {
negotiated_protocol: None,
#[cfg(unix)]
connection_fd_callback,
task_client,
shutdown_token,
}
}
@@ -293,7 +293,7 @@ impl<C, St> GatewayClient<C, St> {
loop {
tokio::select! {
_ = self.task_client.recv() => {
_ = self.shutdown_token.cancelled() => {
log::trace!("GatewayClient control response: Received shutdown");
log::debug!("GatewayClient control response: Exiting");
break Err(GatewayClientError::ConnectionClosedGatewayShutdown);
@@ -514,7 +514,7 @@ impl<C, St> GatewayClient<C, St> {
self.cfg.bandwidth.require_tickets,
derive_aes256_gcm_siv_key,
#[cfg(not(target_arch = "wasm32"))]
self.task_client.clone(),
self.shutdown_token.clone(),
)
.await
.map_err(GatewayClientError::RegistrationFailure),
@@ -631,9 +631,6 @@ impl<C, St> GatewayClient<C, St> {
self.negotiated_protocol = protocol_version;
log::debug!("authenticated: {status}, bandwidth remaining: {bandwidth_remaining}");
self.task_client.send_status_msg(Box::new(
BandwidthStatusMessage::RemainingBandwidth(bandwidth_remaining),
));
Ok(())
}
ServerResponse::Error { message } => Err(GatewayClientError::GatewayError(message)),
@@ -1069,7 +1066,7 @@ impl<C, St> GatewayClient<C, St> {
.expect("no shared key present even though we're authenticated!"),
),
self.bandwidth.clone(),
self.task_client.clone(),
self.shutdown_token.clone(),
)
}
_ => unreachable!(),
@@ -1143,8 +1140,8 @@ impl GatewayClient<InitOnly, EphemeralCredentialStorage> {
// perfectly fine here, because it's not meant to be used
let (ack_tx, _) = mpsc::unbounded();
let (mix_tx, _) = mpsc::unbounded();
let task_client = TaskClient::dummy();
let packet_router = PacketRouter::new(ack_tx, mix_tx, task_client.clone());
let shutdown_token = ShutdownToken::default();
let packet_router = PacketRouter::new(ack_tx, mix_tx, shutdown_token.clone());
GatewayClient {
cfg: GatewayClientConfig::default().with_disabled_credentials_mode(true),
@@ -1157,11 +1154,11 @@ impl GatewayClient<InitOnly, EphemeralCredentialStorage> {
connection: SocketState::NotConnected,
packet_router,
bandwidth_controller: None,
stats_reporter: ClientStatsSender::new(None, task_client.clone()),
stats_reporter: ClientStatsSender::new(None, shutdown_token.clone()),
negotiated_protocol: None,
#[cfg(unix)]
connection_fd_callback,
task_client,
shutdown_token,
}
}
@@ -1170,7 +1167,7 @@ impl GatewayClient<InitOnly, EphemeralCredentialStorage> {
packet_router: PacketRouter,
bandwidth_controller: Option<BandwidthController<C, St>>,
stats_reporter: ClientStatsSender,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> GatewayClient<C, St> {
// invariants that can't be broken
// (unless somebody decided to expose some field that wasn't meant to be exposed)
@@ -1193,7 +1190,7 @@ impl GatewayClient<InitOnly, EphemeralCredentialStorage> {
negotiated_protocol: self.negotiated_protocol,
#[cfg(unix)]
connection_fd_callback: self.connection_fd_callback,
task_client,
shutdown_token,
}
}
}
@@ -7,7 +7,7 @@
use crate::error::GatewayClientError;
use crate::GatewayPacketRouter;
use futures::channel::mpsc;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
pub type MixnetMessageSender = mpsc::UnboundedSender<Vec<Vec<u8>>>;
pub type MixnetMessageReceiver = mpsc::UnboundedReceiver<Vec<Vec<u8>>>;
@@ -19,14 +19,14 @@ pub type AcknowledgementReceiver = mpsc::UnboundedReceiver<Vec<Vec<u8>>>;
pub struct PacketRouter {
ack_sender: AcknowledgementSender,
mixnet_message_sender: MixnetMessageSender,
shutdown: TaskClient,
shutdown: ShutdownToken,
}
impl PacketRouter {
pub fn new(
ack_sender: AcknowledgementSender,
mixnet_message_sender: MixnetMessageSender,
shutdown: TaskClient,
shutdown: ShutdownToken,
) -> Self {
PacketRouter {
ack_sender,
@@ -42,7 +42,7 @@ impl PacketRouter {
if let Err(err) = self.mixnet_message_sender.unbounded_send(received_messages) {
// check if the failure is due to the shutdown being in progress and thus the receiver channel
// having already been dropped
if self.shutdown.is_shutdown_poll() || self.shutdown.is_dummy() {
if self.shutdown.is_cancelled() {
// This should ideally not happen, but it's ok
tracing::warn!("Failed to send mixnet messages due to receiver task shutdown");
return Err(GatewayClientError::ShutdownInProgress);
@@ -58,7 +58,7 @@ impl PacketRouter {
if let Err(err) = self.ack_sender.unbounded_send(received_acks) {
// check if the failure is due to the shutdown being in progress and thus the receiver channel
// having already been dropped
if self.shutdown.is_shutdown_poll() || self.shutdown.is_dummy() {
if self.shutdown.is_cancelled() {
// This should ideally not happen, but it's ok
tracing::warn!("Failed to send acks due to receiver task shutdown");
return Err(GatewayClientError::ShutdownInProgress);
@@ -69,10 +69,6 @@ impl PacketRouter {
}
Ok(())
}
pub fn disarm(&mut self) {
self.shutdown.disarm();
}
}
impl GatewayPacketRouter for PacketRouter {
@@ -11,7 +11,7 @@ use futures::stream::{SplitSink, SplitStream};
use futures::{SinkExt, StreamExt};
use nym_gateway_requests::shared_key::SharedGatewayKey;
use nym_gateway_requests::{SensitiveServerResponse, ServerResponse, SimpleGatewayRequestsError};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use si_scale::helpers::bibytes2;
use std::os::raw::c_int as RawFd;
use std::sync::Arc;
@@ -87,13 +87,13 @@ impl PartiallyDelegatedRouter {
}
}
async fn run(mut self, mut split_stream: SplitStream<WsConn>, mut task_client: TaskClient) {
async fn run(mut self, mut split_stream: SplitStream<WsConn>, shutdown_token: ShutdownToken) {
let mut chunked_stream = (&mut split_stream).ready_chunks(8);
let ret: Result<_, GatewayClientError> = loop {
tokio::select! {
biased;
// received system-wide shutdown
_ = task_client.recv() => {
_ = shutdown_token.cancelled() => {
log::trace!("GatewayClient listener: Received shutdown");
log::debug!("GatewayClient listener: Exiting");
return;
@@ -118,11 +118,7 @@ impl PartiallyDelegatedRouter {
let return_res = match ret {
Err(err) => self.stream_return.send(Err(err)),
Ok(_) => {
self.packet_router.disarm();
task_client.disarm();
self.stream_return.send(Ok(split_stream))
}
Ok(_) => self.stream_return.send(Ok(split_stream)),
};
if return_res.is_err() {
@@ -266,8 +262,8 @@ impl PartiallyDelegatedRouter {
Ok(plaintexts)
}
fn spawn(self, split_stream: SplitStream<WsConn>, task_client: TaskClient) {
let fut = async move { self.run(split_stream, task_client).await };
fn spawn(self, split_stream: SplitStream<WsConn>, shutdown_token: ShutdownToken) {
let fut = async move { self.run(split_stream, shutdown_token).await };
#[cfg(target_arch = "wasm32")]
wasm_bindgen_futures::spawn_local(fut);
@@ -283,7 +279,7 @@ impl PartiallyDelegatedHandle {
packet_router: PacketRouter,
shared_key: Arc<SharedGatewayKey>,
client_bandwidth: ClientBandwidth,
shutdown: TaskClient,
shutdown: ShutdownToken,
) -> Self {
// when called for, it NEEDS TO yield back the stream so that we could merge it and
// read control request responses.
@@ -126,7 +126,7 @@ pub struct CredentialHandlerConfig {
pub maximum_time_between_redemption: Duration,
}
pub(crate) struct CredentialHandler {
pub struct CredentialHandler {
config: CredentialHandlerConfig,
multisig_threshold: f32,
ticket_receiver: UnboundedReceiver<ClientTicket>,
@@ -907,7 +907,7 @@ impl CredentialHandler {
Ok(())
}
async fn run(mut self, mut shutdown: nym_task::TaskClient) {
pub async fn run(mut self, shutdown: nym_task::ShutdownToken) {
info!("Starting Ecash CredentialSender");
// attempt to clear any pending operations
@@ -919,11 +919,12 @@ impl CredentialHandler {
let start = Instant::now() + self.config.pending_poller;
let mut resolver_interval = interval_at(start, self.config.pending_poller);
while !shutdown.is_shutdown() {
loop {
tokio::select! {
biased;
_ = shutdown.recv() => {
_ = shutdown.cancelled() => {
trace!("client_handling::credentialSender : received shutdown");
break
},
Some(ticket) = self.ticket_receiver.next() => {
let (queued_up, _) = self.ticket_receiver.size_hint();
@@ -946,8 +947,4 @@ impl CredentialHandler {
}
}
}
pub(crate) fn start(self, shutdown: nym_task::TaskClient) {
tokio::spawn(async move { self.run(shutdown).await });
}
}
@@ -82,9 +82,8 @@ impl EcashManager {
credential_handler_cfg: CredentialHandlerConfig,
nyxd_client: DirectSigningHttpRpcNyxdClient,
pk_bytes: [u8; 32],
shutdown: nym_task::TaskClient,
storage: GatewayStorage,
) -> Result<Self, Error> {
) -> Result<(Self, CredentialHandler), Error> {
let shared_state = SharedState::new(nyxd_client, Box::new(storage)).await?;
let (cred_sender, cred_receiver) = mpsc::unbounded();
@@ -92,14 +91,16 @@ impl EcashManager {
let cs =
CredentialHandler::new(credential_handler_cfg, cred_receiver, shared_state.clone())
.await?;
cs.start(shutdown);
Ok(EcashManager {
shared_state,
pk_bytes,
pay_infos: Default::default(),
cred_sender,
})
Ok((
EcashManager {
shared_state,
pk_bytes,
pay_infos: Default::default(),
cred_sender,
},
cs,
))
}
pub async fn verify_pay_info(&self, pay_info: NymPayInfo) -> Result<usize, EcashTicketError> {
@@ -14,7 +14,7 @@ use std::task::{Context, Poll};
use tungstenite::{Error as WsError, Message as WsMessage};
#[cfg(not(target_arch = "wasm32"))]
use nym_task::TaskClient;
use nym_task::ShutdownToken;
pub(crate) type WsItem = Result<WsMessage, WsError>;
@@ -52,7 +52,7 @@ pub fn client_handshake<'a, S, R>(
gateway_pubkey: ed25519::PublicKey,
expects_credential_usage: bool,
derive_aes256_gcm_siv_key: bool,
#[cfg(not(target_arch = "wasm32"))] shutdown: TaskClient,
#[cfg(not(target_arch = "wasm32"))] shutdown_token: ShutdownToken,
) -> GatewayHandshake<'a>
where
S: Stream<Item = WsItem> + Sink<WsMessage> + Unpin + Send + 'a,
@@ -64,7 +64,7 @@ where
identity,
Some(gateway_pubkey),
#[cfg(not(target_arch = "wasm32"))]
shutdown,
shutdown_token,
)
.with_credential_usage(expects_credential_usage)
.with_aes256_gcm_siv_key(derive_aes256_gcm_siv_key);
@@ -80,13 +80,13 @@ pub fn gateway_handshake<'a, S, R>(
ws_stream: &'a mut S,
identity: &'a ed25519::KeyPair,
received_init_payload: Vec<u8>,
shutdown: TaskClient,
shutdown_token: ShutdownToken,
) -> GatewayHandshake<'a>
where
S: Stream<Item = WsItem> + Sink<WsMessage> + Unpin + Send + 'a,
R: CryptoRng + RngCore + Send,
{
let state = State::new(rng, ws_stream, identity, None, shutdown);
let state = State::new(rng, ws_stream, identity, None, shutdown_token);
GatewayHandshake {
handshake_future: Box::pin(state.perform_gateway_handshake(received_init_payload)),
}
@@ -149,7 +149,7 @@ mod tests {
*gateway_keys.public_key(),
false,
true,
TaskClient::dummy(),
ShutdownToken::default(),
);
let client_fut = handshake_client.spawn_timeboxed();
@@ -176,7 +176,7 @@ mod tests {
gateway_ws,
gateway_keys,
init_msg,
TaskClient::dummy(),
ShutdownToken::default(),
);
let gateway_fut = handshake_gateway.spawn_timeboxed();
@@ -24,7 +24,7 @@ use tracing::log::*;
use tungstenite::Message as WsMessage;
#[cfg(not(target_arch = "wasm32"))]
use nym_task::TaskClient;
use nym_task::ShutdownToken;
#[cfg(not(target_arch = "wasm32"))]
use tokio::time::timeout;
@@ -63,7 +63,7 @@ pub(crate) struct State<'a, S, R> {
// channel to receive shutdown signal
#[cfg(not(target_arch = "wasm32"))]
shutdown: TaskClient,
shutdown_token: ShutdownToken,
}
impl<'a, S, R> State<'a, S, R> {
@@ -72,7 +72,7 @@ impl<'a, S, R> State<'a, S, R> {
ws_stream: &'a mut S,
identity: &'a ed25519::KeyPair,
remote_pubkey: Option<ed25519::PublicKey>,
#[cfg(not(target_arch = "wasm32"))] shutdown: TaskClient,
#[cfg(not(target_arch = "wasm32"))] shutdown_token: ShutdownToken,
) -> Self
where
R: CryptoRng + RngCore,
@@ -89,7 +89,7 @@ impl<'a, S, R> State<'a, S, R> {
expects_credential_usage: false,
derive_aes256_gcm_siv_key: false,
#[cfg(not(target_arch = "wasm32"))]
shutdown,
shutdown_token,
}
}
@@ -306,7 +306,7 @@ impl<'a, S, R> State<'a, S, R> {
loop {
tokio::select! {
biased;
_ = self.shutdown.recv() => return Err(HandshakeError::ReceivedShutdown),
_ = self.shutdown_token.cancelled() => return Err(HandshakeError::ReceivedShutdown),
msg = self.ws_stream.next() => {
let Some(ret) = Self::on_wg_msg(msg)? else {
continue;
+8 -7
View File
@@ -9,7 +9,7 @@ use futures::StreamExt;
use nym_crypto::asymmetric::x25519;
use nym_sphinx::acknowledgements::AckKey;
use nym_sphinx::receiver::{MessageReceiver, SphinxMessageReceiver};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use serde::de::DeserializeOwned;
use std::sync::Arc;
@@ -24,7 +24,7 @@ pub struct SimpleMessageReceiver<T, R: MessageReceiver = SphinxMessageReceiver>
acks_receiver: mpsc::UnboundedReceiver<Vec<Vec<u8>>>,
received_sender: ReceivedSender<T>,
shutdown: TaskClient,
shutdown: ShutdownToken,
}
impl<T> SimpleMessageReceiver<T, SphinxMessageReceiver> {
@@ -34,7 +34,7 @@ impl<T> SimpleMessageReceiver<T, SphinxMessageReceiver> {
mixnet_message_receiver: mpsc::UnboundedReceiver<Vec<Vec<u8>>>,
acks_receiver: mpsc::UnboundedReceiver<Vec<Vec<u8>>>,
received_sender: ReceivedSender<T>,
shutdown: TaskClient,
shutdown: ShutdownToken,
) -> Self {
Self::new(
local_encryption_keypair,
@@ -54,7 +54,7 @@ impl<T, R: MessageReceiver> SimpleMessageReceiver<T, R> {
mixnet_message_receiver: mpsc::UnboundedReceiver<Vec<Vec<u8>>>,
acks_receiver: mpsc::UnboundedReceiver<Vec<Vec<u8>>>,
received_sender: ReceivedSender<T>,
shutdown: TaskClient,
shutdown: ShutdownToken,
) -> Self {
SimpleMessageReceiver {
message_processor: TestPacketProcessor::new(local_encryption_keypair, ack_key),
@@ -91,11 +91,12 @@ impl<T, R: MessageReceiver> SimpleMessageReceiver<T, R> {
where
T: DeserializeOwned,
{
while !self.shutdown.is_shutdown() {
loop {
tokio::select! {
biased;
_ = self.shutdown.recv() => {
log_info!("SimpleMessageReceiver: received shutdown")
_ = self.shutdown.cancelled() => {
log_info!("SimpleMessageReceiver: received shutdown");
break
}
mixnet_messages = self.mixnet_message_receiver.next() => {
let Some(mixnet_messages) = mixnet_messages else {
+32 -44
View File
@@ -23,9 +23,7 @@ use nym_client_core::init::types::GatewaySetup;
use nym_credential_storage::storage::Storage as CredentialStorage;
use nym_sphinx::addressing::clients::Recipient;
use nym_sphinx::params::PacketType;
use nym_task::{TaskClient, TaskHandle, TaskStatus};
use anyhow::anyhow;
use nym_task::{ShutdownManager, ShutdownTracker};
use nym_validator_client::UserAgent;
use std::error::Error;
use std::path::PathBuf;
@@ -46,7 +44,7 @@ pub enum Socks5ControlMessage {
pub struct StartedSocks5Client {
/// Handle for managing graceful shutdown of this client. If dropped, the client will be stopped.
pub shutdown_handle: TaskHandle,
pub shutdown_handle: ShutdownManager,
/// Address of the started client
pub address: Recipient,
@@ -65,6 +63,8 @@ pub struct NymClient<S> {
/// Optional path to a .json file containing standalone network details.
custom_mixnet: Option<PathBuf>,
shutdown_manager: ShutdownManager,
}
impl<S> NymClient<S>
@@ -92,6 +92,7 @@ where
setup_method: GatewaySetup::MustLoad { gateway_id: None },
user_agent,
custom_mixnet,
shutdown_manager: Default::default(),
}
}
@@ -108,7 +109,7 @@ where
client_output: ClientOutput,
client_status: ClientState,
self_address: Recipient,
shutdown: TaskClient,
shutdown: ShutdownTracker,
packet_type: PacketType,
) {
info!("Starting socks5 listener...");
@@ -148,51 +149,39 @@ where
socks5_config.send_anonymously,
socks5_config.socks5_debug,
),
shutdown.clone(),
shutdown,
packet_type,
);
nym_task::spawn_with_report_error(
async move {
sphinx_socks
.serve(
input_sender,
received_buffer_request_sender,
connection_command_sender,
)
.await
},
shutdown,
);
nym_task::spawn_future(async move {
sphinx_socks
.serve(
input_sender,
received_buffer_request_sender,
connection_command_sender,
)
.await
});
}
/// blocking version of `start` method. Will run forever (or until SIGINT is sent)
pub async fn run_forever(self) -> Result<(), Box<dyn Error + Send + Sync>> {
let started = self.start().await?;
let mut started = self.start().await?;
let res = started.shutdown_handle.wait_for_shutdown().await;
started.shutdown_handle.run_until_shutdown().await;
log::info!("Stopping nym-socks5-client");
res
Ok(())
}
// Variant of `run_forever` that listens for remote control messages
pub async fn run_and_listen(
self,
mut receiver: Socks5ControlMessageReceiver,
sender: nym_task::StatusSender,
) -> Result<(), Box<dyn Error + Send + Sync>> {
// Start the main task
let started = self.start().await?;
let mut shutdown = started
.shutdown_handle
.try_into_task_manager()
.ok_or(anyhow!(
"attempted to use `run_and_listen` without owning shutdown handle"
))?;
let mut task_manager = started.shutdown_handle;
// Listen to status messages from task, that we forward back to the caller
shutdown
.start_status_listener(sender, TaskStatus::Ready)
.await;
let mut shutdown_signals = task_manager.detach_shutdown_signals();
let res = tokio::select! {
biased;
@@ -207,22 +196,20 @@ where
}
}
Ok(())
}
Some(msg) = shutdown.wait_for_error() => {
log::info!("Task error: {msg:?}");
Err(msg)
}
_ = tokio::signal::ctrl_c() => {
log::info!("Received SIGINT");
},
_ = shutdown_signals.wait_for_signal() => {
log::info!("Received shutdown signal");
Ok(())
},
};
log::info!("Sending shutdown");
shutdown.signal_shutdown().ok();
if !task_manager.is_cancelled() {
log::info!("Sending shutdown");
task_manager.send_cancellation();
}
log::info!("Waiting for tasks to finish... (Press ctrl-c to force)");
shutdown.wait_for_shutdown().await;
task_manager.perform_shutdown().await;
log::info!("Stopping nym-socks5-client");
res
@@ -238,6 +225,7 @@ where
let mut base_builder =
BaseClientBuilder::new(self.config.base(), self.storage, dkg_query_client)
.with_shutdown(self.shutdown_manager.shutdown_tracker_owned())
.with_gateway_setup(self.setup_method)
.with_user_agent(self.user_agent);
@@ -261,7 +249,7 @@ where
client_output,
client_state,
self_address,
started_client.task_handle.get_handle(),
self.shutdown_manager.shutdown_tracker_owned(),
packet_type,
);
@@ -269,7 +257,7 @@ where
info!("The address of this client is: {self_address}");
Ok(StartedSocks5Client {
shutdown_handle: started_client.task_handle,
shutdown_handle: self.shutdown_manager,
address: self_address,
})
}
@@ -21,7 +21,7 @@ use nym_sphinx::addressing::clients::Recipient;
use nym_sphinx::params::PacketSize;
use nym_sphinx::params::PacketType;
use nym_task::connections::{LaneQueueLengths, TransmissionLane};
use nym_task::TaskClient;
use nym_task::ShutdownTracker;
use pin_project::pin_project;
use rand::RngCore;
use std::io;
@@ -185,7 +185,7 @@ pub(crate) struct SocksClient {
self_address: Recipient,
started_proxy: bool,
lane_queue_lengths: LaneQueueLengths,
shutdown_listener: TaskClient,
shutdown_listener: ShutdownTracker,
packet_type: Option<PacketType>,
}
@@ -214,12 +214,9 @@ impl SocksClient {
controller_sender: ControllerSender,
self_address: &Recipient,
lane_queue_lengths: LaneQueueLengths,
mut shutdown_listener: TaskClient,
shutdown_listener: ShutdownTracker,
packet_type: Option<PacketType>,
) -> Self {
// If this task fails and exits, we don't want to send shutdown signal
shutdown_listener.disarm();
let connection_id = Self::generate_random();
SocksClient {
@@ -294,7 +291,6 @@ impl SocksClient {
.shutdown()
.await
.map_err(|source| SocksProxyError::SocketShutdownFailure { source })?;
self.shutdown_listener.disarm();
Ok(())
}
@@ -13,13 +13,13 @@ use nym_service_providers_common::interface::{ControlResponse, ResponseContent};
use nym_socks5_proxy_helpers::connection_controller::{ControllerCommand, ControllerSender};
use nym_socks5_requests::{Socks5ProviderResponse, Socks5Response, Socks5ResponseContent};
use nym_sphinx::receiver::ReconstructedMessage;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
pub(crate) struct MixnetResponseListener {
buffer_requester: ReceivedBufferRequestSender,
mix_response_receiver: ReconstructedMessagesReceiver,
controller_sender: ControllerSender,
shutdown: TaskClient,
shutdown: ShutdownToken,
}
impl Drop for MixnetResponseListener {
@@ -28,7 +28,7 @@ impl Drop for MixnetResponseListener {
.buffer_requester
.unbounded_send(ReceivedBufferMessage::ReceiverDisconnect)
{
if self.shutdown.is_shutdown_poll() {
if self.shutdown.is_cancelled() {
log::debug!("The buffer request failed: {err}");
} else {
log::error!("The buffer request failed: {err}");
@@ -41,7 +41,7 @@ impl MixnetResponseListener {
pub(crate) fn new(
buffer_requester: ReceivedBufferRequestSender,
controller_sender: ControllerSender,
shutdown: TaskClient,
shutdown: ShutdownToken,
) -> Self {
let (mix_response_sender, mix_response_receiver) = mpsc::unbounded();
buffer_requester
@@ -130,13 +130,18 @@ impl MixnetResponseListener {
}
pub(crate) async fn run(&mut self) {
while !self.shutdown.is_shutdown() {
loop {
tokio::select! {
biased;
_ = self.shutdown.cancelled() => {
log::trace!("MixnetResponseListener: Received shutdown");
break;
}
received_responses = self.mix_response_receiver.next() => {
if let Some(received_responses) = received_responses {
for reconstructed_message in received_responses {
if let Err(err) = self.on_message(reconstructed_message) {
self.shutdown.send_status_msg(Box::new(err));
debug!("message handling error: {err}")
}
}
} else {
@@ -144,12 +149,8 @@ impl MixnetResponseListener {
break;
}
},
_ = self.shutdown.recv() => {
log::trace!("MixnetResponseListener: Received shutdown");
}
}
}
self.shutdown.recv_timeout().await;
log::debug!("MixnetResponseListener: Exiting");
}
}
+31 -24
View File
@@ -12,7 +12,7 @@ use nym_socks5_proxy_helpers::connection_controller::Controller;
use nym_sphinx::addressing::clients::Recipient;
use nym_sphinx::params::PacketType;
use nym_task::connections::{ConnectionCommandSender, LaneQueueLengths};
use nym_task::TaskClient;
use nym_task::ShutdownTracker;
use std::net::SocketAddr;
use tap::TapFallible;
use tokio::net::TcpListener;
@@ -25,7 +25,7 @@ pub struct NymSocksServer {
self_address: Recipient,
client_config: client::Config,
lane_queue_lengths: LaneQueueLengths,
shutdown: TaskClient,
shutdown: ShutdownTracker,
packet_type: PacketType,
}
@@ -39,7 +39,7 @@ impl NymSocksServer {
self_address: Recipient,
lane_queue_lengths: LaneQueueLengths,
client_config: client::Config,
shutdown: TaskClient,
shutdown: ShutdownTracker,
packet_type: PacketType,
) -> Self {
info!("Listening on {bind_address}");
@@ -72,7 +72,7 @@ impl NymSocksServer {
let (mut active_streams_controller, controller_sender) = Controller::new(
client_connection_tx,
//BroadcastActiveConnections::Off,
self.shutdown.clone(),
self.shutdown.clone_shutdown_token(),
);
tokio::spawn(async move {
active_streams_controller.run().await;
@@ -82,20 +82,30 @@ impl NymSocksServer {
let mut mixnet_response_listener = MixnetResponseListener::new(
buffer_requester,
controller_sender.clone(),
self.shutdown.clone(),
self.shutdown.clone_shutdown_token(),
);
self.shutdown.try_spawn_named(
async move {
mixnet_response_listener.run().await;
},
"Socks5MixnetListener",
);
tokio::spawn(async move {
mixnet_response_listener.run().await;
});
// TODO: if required, there should be another task here responsible for control requests.
// it should get `input_sender` to send actual requests into the mixnet
// and some channel that connects it from `MixnetResponseListener` to receive
// any control responses
let shutdown = self.shutdown.clone_shutdown_token();
loop {
tokio::select! {
Ok((stream, _remote)) = listener.accept() => {
biased;
_ = shutdown.cancelled() => {
log::trace!("NymSocksServer: Received shutdown");
log::debug!("NymSocksServer: Exiting");
return Ok(());
}
Ok((stream, remote)) = listener.accept() => {
let mut client = SocksClient::new(
self.client_config,
stream,
@@ -109,23 +119,20 @@ impl NymSocksServer {
Some(self.packet_type)
);
tokio::spawn(async move {
if let Err(err) = client.run().await {
error!("Error! {err}");
if client.send_error(err).await.is_err() {
warn!("Failed to send error code");
self.shutdown.try_spawn_named(
async move {
if let Err(err) = client.run().await {
error!("Error! {err}");
if client.send_error(err).await.is_err() {
warn!("Failed to send error code");
};
if client.shutdown().await.is_err() {
warn!("Failed to shutdown TcpStream");
};
};
if client.shutdown().await.is_err() {
warn!("Failed to shutdown TcpStream");
};
}
});
}, &format!("Socks5Client::{remote}")
);
},
_ = self.shutdown.recv() => {
log::trace!("NymSocksServer: Received shutdown");
log::debug!("NymSocksServer: Exiting");
return Ok(());
}
}
}
}
@@ -7,7 +7,7 @@ use log::*;
use nym_ordered_buffer::{OrderedMessageBuffer, ReadContiguousData};
use nym_socks5_requests::{ConnectionId, SocketData};
use nym_task::connections::{ConnectionCommand, ConnectionCommandSender};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use std::collections::{HashMap, HashSet};
/// A generic message produced after reading from a socket/connection.
@@ -101,13 +101,13 @@ pub struct Controller {
// un-order messages. Note we don't ever expect to have more than 1-2 messages per connection here
pending_messages: HashMap<ConnectionId, Vec<SocketData>>,
shutdown: TaskClient,
shutdown: ShutdownToken,
}
impl Controller {
pub fn new(
client_connection_tx: ConnectionCommandSender,
shutdown: TaskClient,
shutdown: ShutdownToken,
) -> (Self, ControllerSender) {
let (sender, receiver) = mpsc::unbounded();
(
@@ -155,7 +155,7 @@ impl Controller {
.client_connection_tx
.unbounded_send(ConnectionCommand::Close(conn_id))
{
if self.shutdown.is_shutdown_poll() {
if self.shutdown.is_cancelled() {
log::debug!("Failed to send: {err}");
} else {
log::error!("Failed to send: {err}");
@@ -230,7 +230,6 @@ impl Controller {
},
}
}
self.shutdown.recv_timeout().await;
log::debug!("SOCKS5 Controller: Exiting");
}
}
@@ -11,7 +11,7 @@ use log::*;
use nym_socks5_requests::{ConnectionId, SocketData};
use nym_task::connections::LaneQueueLengths;
use nym_task::connections::TransmissionLane;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use std::sync::Arc;
use std::time::Duration;
use tokio::select;
@@ -81,7 +81,7 @@ pub(super) async fn run_inbound<F, S>(
available_plaintext_per_mix_packet: usize,
shutdown_notify: Arc<Notify>,
lane_queue_lengths: Option<LaneQueueLengths>,
mut shutdown_listener: TaskClient,
shutdown_listener: ShutdownToken,
) -> OwnedReadHalf
where
F: Fn(SocketData) -> S + Send + 'static,
@@ -129,7 +129,7 @@ where
message_sender.send_empty_close().await;
break;
}
_ = shutdown_listener.recv() => {
_ = shutdown_listener.cancelled() => {
log::trace!("ProxyRunner inbound: Received shutdown");
break;
}
@@ -171,6 +171,5 @@ where
trace!("{connection_id} - inbound closed");
shutdown_notify.notify_one();
shutdown_listener.disarm();
reader
}
@@ -5,7 +5,7 @@ use crate::connection_controller::ConnectionReceiver;
use crate::ordered_sender::OrderedMessageSender;
use nym_socks5_requests::{ConnectionId, SocketData};
use nym_task::connections::LaneQueueLengths;
use nym_task::TaskClient;
use nym_task::ShutdownTracker;
use std::fmt::Debug;
use std::{sync::Arc, time::Duration};
use tokio::{net::TcpStream, sync::Notify};
@@ -57,7 +57,8 @@ pub struct ProxyRunner<S> {
available_plaintext_per_mix_packet: usize,
// Listens to shutdown commands from higher up
shutdown_listener: TaskClient,
// and spawn new tracked tasks
shutdown_tracker: ShutdownTracker,
}
impl<S> ProxyRunner<S>
@@ -74,7 +75,7 @@ where
available_plaintext_per_mix_packet: usize,
connection_id: ConnectionId,
lane_queue_lengths: Option<LaneQueueLengths>,
shutdown_listener: TaskClient,
shutdown_tracker: ShutdownTracker,
) -> Self {
ProxyRunner {
mix_receiver: Some(mix_receiver),
@@ -85,7 +86,7 @@ where
connection_id,
lane_queue_lengths,
available_plaintext_per_mix_packet,
shutdown_listener,
shutdown_tracker,
}
}
@@ -113,7 +114,7 @@ where
self.available_plaintext_per_mix_packet,
Arc::clone(&shutdown_notify),
self.lane_queue_lengths.clone(),
self.shutdown_listener.clone(),
self.shutdown_tracker.clone_shutdown_token(),
);
let outbound_future = outbound::run_outbound(
@@ -123,14 +124,26 @@ where
self.mix_receiver.take().unwrap(),
self.connection_id,
shutdown_notify,
self.shutdown_listener.clone(),
self.shutdown_tracker.clone_shutdown_token(),
);
// TODO: this shouldn't really have to spawn tasks inside "library" code, but
// if we used join directly, stuff would have been executed on the same thread
// (it's not bad, but an unnecessary slowdown)
let handle_inbound = tokio::spawn(inbound_future);
let handle_outbound = tokio::spawn(outbound_future);
let handle_inbound = self.shutdown_tracker.try_spawn_named(
inbound_future,
&format!(
"Socks5Inbound::{}::{}",
self.remote_source_address, self.connection_id
),
);
let handle_outbound = self.shutdown_tracker.try_spawn_named(
outbound_future,
&format!(
"Socks5Outbound::{}::{}",
self.remote_source_address, self.connection_id
),
);
let (inbound_result, outbound_result) =
futures::future::join(handle_inbound, handle_outbound).await;
@@ -148,7 +161,6 @@ where
}
pub fn into_inner(mut self) -> (TcpStream, ConnectionReceiver) {
self.shutdown_listener.disarm();
(
self.socket.take().unwrap(),
self.mix_receiver.take().unwrap(),
@@ -7,7 +7,7 @@ use futures::FutureExt;
use futures::StreamExt;
use log::*;
use nym_socks5_requests::ConnectionId;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use std::{sync::Arc, time::Duration};
use tokio::io::AsyncWriteExt;
use tokio::select;
@@ -51,7 +51,7 @@ pub(super) async fn run_outbound(
mut mix_receiver: ConnectionReceiver,
connection_id: ConnectionId,
shutdown_notify: Arc<Notify>,
mut shutdown_listener: TaskClient,
shutdown_listener: ShutdownToken,
) -> (OwnedWriteHalf, ConnectionReceiver) {
let shutdown_future = shutdown_notify.notified().then(|_| sleep(SHUTDOWN_TIMEOUT));
tokio::pin!(shutdown_future);
@@ -60,6 +60,11 @@ pub(super) async fn run_outbound(
loop {
select! {
biased;
_ = shutdown_listener.cancelled() => {
log::trace!("ProxyRunner outbound: Received shutdown");
break;
}
connection_message = mix_receiver.next() => {
if let Some(connection_message) = connection_message {
if deal_with_message(connection_message, &mut writer, &local_destination_address, &remote_source_address, connection_id).await {
@@ -80,16 +85,11 @@ pub(super) async fn run_outbound(
debug!("closing outbound proxy after inbound was closed {SHUTDOWN_TIMEOUT:?} ago");
break;
}
_ = shutdown_listener.recv() => {
log::trace!("ProxyRunner outbound: Received shutdown");
break;
}
}
}
trace!("{connection_id} - outbound closed");
shutdown_notify.notify_one();
shutdown_listener.disarm();
(writer, mix_receiver)
}
+7 -7
View File
@@ -3,7 +3,7 @@
use crate::report::client::{ClientStatsReport, OsInformation};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use time::{OffsetDateTime, Time};
use tokio::sync::mpsc::UnboundedSender;
@@ -25,18 +25,18 @@ pub type ClientStatsReceiver = tokio::sync::mpsc::UnboundedReceiver<ClientStatsE
#[derive(Clone)]
pub struct ClientStatsSender {
stats_tx: Option<UnboundedSender<ClientStatsEvents>>,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl ClientStatsSender {
/// Create a new statistics Sender
pub fn new(
stats_tx: Option<UnboundedSender<ClientStatsEvents>>,
task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Self {
ClientStatsSender {
stats_tx,
task_client,
shutdown_token,
}
}
@@ -44,7 +44,7 @@ impl ClientStatsSender {
pub fn report(&self, event: ClientStatsEvents) {
if let Some(tx) = &self.stats_tx {
if let Err(err) = tx.send(event) {
if !self.task_client.is_shutdown_poll() {
if !self.shutdown_token.is_cancelled() {
log::error!("Failed to send stats event: {err}");
}
}
@@ -137,8 +137,8 @@ impl ClientStatsController {
self.packet_stats.snapshot();
}
pub fn local_report(&mut self, task_client: &mut TaskClient) {
self.packet_stats.local_report(task_client);
pub fn local_report(&mut self) {
self.packet_stats.local_report();
self.gateway_conn_stats.local_report();
self.nym_api_stats.local_report();
}
@@ -449,15 +449,16 @@ impl PacketStatisticsControl {
self.stats.clone()
}
pub(crate) fn local_report(&mut self, task_client: &mut nym_task::TaskClient) {
let rates = self.report_rates();
pub(crate) fn local_report(&mut self) {
let _rates = self.report_rates();
self.check_for_notable_events();
self.report_counters();
// Report our current bandwidth used to e.g a GUI client
if let Some(rates) = rates {
task_client.send_status_msg(Box::new(MixnetBandwidthStatisticsEvent::new(rates)));
}
// the code below is intentionally left commented out in case somebody wants to restore this logic with a different channel
// // Report our current bandwidth used to e.g a GUI client
// if let Some(rates) = rates {
// task_client.send_status_msg(Box::new(MixnetBandwidthStatisticsEvent::new(rates)));
// }
}
// Add the current stats to the history, and remove old ones.
+8
View File
@@ -30,5 +30,13 @@ workspace = true
workspace = true
features = ["tokio"]
[features]
tokio-tracing = ["tokio/tracing"]
[dev-dependencies]
anyhow = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread", "net", "signal", "test-util", "macros"] }
nym-test-utils = { path = "../test-utils" }
[lints]
workspace = true
-414
View File
@@ -1,414 +0,0 @@
// Copyright 2025 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: Apache-2.0
use crate::{TaskClient, TaskManager};
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use std::future::Future;
use std::mem;
use std::ops::Deref;
use std::pin::Pin;
use std::time::Duration;
use tokio::task::JoinSet;
use tokio::time::sleep;
use tokio_util::sync::{CancellationToken, DropGuard};
use tokio_util::task::TaskTracker;
use tracing::{debug, info, trace};
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
/// Default maximum shutdown duration: 5 seconds.
// NOTE(review): `ShutdownManager::new` in this file hard-codes 10 seconds instead of
// using this constant - confirm which value is actually intended.
pub const DEFAULT_MAX_SHUTDOWN_DURATION: Duration = Duration::from_secs(5);
/// Returns an owned copy of the provided name, falling back to `"unknown"`
/// when no name has been set.
pub fn token_name(name: &Option<String>) -> String {
    match name {
        Some(known) => known.clone(),
        None => "unknown".to_string(),
    }
}
/// A wrapper around `tokio_util`'s `CancellationToken` that adds optional `name` information
/// to more easily track down sources of shutdown.
#[derive(Debug, Default)]
pub struct ShutdownToken {
    // optional human-readable label of this token; `None` is reported as "unknown" by `token_name`
    name: Option<String>,
    // the wrapped cancellation token
    inner: CancellationToken,
}
impl Clone for ShutdownToken {
fn clone(&self) -> Self {
// make sure to not accidentally overflow the stack if we keep cloning the handle
let name = if let Some(name) = &self.name {
if name != Self::OVERFLOW_NAME && name.len() < Self::MAX_NAME_LENGTH {
Some(format!("{name}-child"))
} else {
Some(Self::OVERFLOW_NAME.to_string())
}
} else {
None
};
ShutdownToken {
name,
inner: self.inner.clone(),
}
}
}
/// Exposes the methods of the wrapped `CancellationToken` directly on `ShutdownToken`.
impl Deref for ShutdownToken {
    type Target = CancellationToken;

    // borrows the inner token; the name is not involved
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
impl ShutdownToken {
    /// Names at or above this length are no longer extended with `-child` when cloning.
    const MAX_NAME_LENGTH: usize = 128;
    /// Marker name given to clones once the name length limit has been reached.
    const OVERFLOW_NAME: &'static str = "reached maximum ShutdownToken children name depth";

    /// Creates a new, independent token with the provided name.
    pub fn new(name: impl Into<String>) -> Self {
        ShutdownToken {
            name: Some(name.into()),
            inner: CancellationToken::new(),
        }
    }

    /// Creates a new, independent token named `"ephemeral-token"`.
    pub fn ephemeral() -> Self {
        ShutdownToken::new("ephemeral-token")
    }

    /// Builds `"<name>-<suffix>"`, using `"unknown"` as the base for unnamed tokens.
    fn suffixed_name(&self, suffix: impl Into<String>) -> String {
        let suffix = suffix.into();
        match &self.name {
            Some(base) => format!("{base}-{suffix}"),
            None => format!("unknown-{suffix}"),
        }
    }

    /// Creates a ShutdownToken which will get cancelled whenever the current token gets cancelled.
    /// Unlike a cloned/forked ShutdownToken, cancelling a child token does not cancel the parent token.
    #[must_use]
    pub fn child_token<S: Into<String>>(&self, child_suffix: S) -> Self {
        ShutdownToken {
            name: Some(self.suffixed_name(child_suffix)),
            inner: self.inner.child_token(),
        }
    }

    /// Creates a clone of the ShutdownToken which will get cancelled whenever the current token
    /// gets cancelled, and vice versa.
    #[must_use]
    pub fn clone_with_suffix<S: Into<String>>(&self, child_suffix: S) -> Self {
        // the name is built directly rather than via `Clone`, whose `-child` name
        // would be immediately overwritten anyway
        ShutdownToken {
            name: Some(self.suffixed_name(child_suffix)),
            inner: self.inner.clone(),
        }
    }

    /// Exposed method with the old name for easier migration.
    /// It will eventually be removed so please try to use `.clone_with_suffix` instead.
    #[must_use]
    #[deprecated(note = "use .clone_with_suffix instead")]
    pub fn fork<S: Into<String>>(&self, child_suffix: S) -> Self {
        self.clone_with_suffix(child_suffix)
    }

    /// Exposed method with the old name for easier migration.
    /// It will eventually be removed so please try to use `.clone().named(name)` instead.
    #[must_use]
    #[deprecated(note = "use .clone().named(name) instead")]
    pub fn fork_named<S: Into<String>>(&self, name: S) -> Self {
        self.clone().named(name)
    }

    /// Replaces the name of this token with the provided one.
    #[must_use]
    pub fn named<S: Into<String>>(mut self, name: S) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Appends `-<suffix>` to the name of this token (or to `"unknown"` if it has no name).
    #[must_use]
    pub fn add_suffix<S: Into<String>>(self, suffix: S) -> Self {
        let name = self.suffixed_name(suffix);
        self.named(name)
    }

    /// Returned guard will cancel this token (and all its children) on drop unless disarmed.
    pub fn drop_guard(self) -> ShutdownDropGuard {
        ShutdownDropGuard {
            name: self.name,
            inner: self.inner.drop_guard(),
        }
    }

    /// Returns the name of this token, or `"unknown"` if it has none.
    pub fn name(&self) -> String {
        token_name(&self.name)
    }

    /// Runs the provided future until it either completes or this token gets cancelled.
    ///
    /// Returns `Some(output)` if the future completed and `None` if the token
    /// got cancelled first.
    pub async fn run_until_cancelled<F>(&self, fut: F) -> Option<F::Output>
    where
        F: Future,
    {
        let res = self.inner.run_until_cancelled(fut).await;
        // only report a cancellation if one has actually happened;
        // `Some(_)` means the future has run to completion
        if res.is_none() {
            trace!("'{}' got cancelled", self.name());
        }
        res
    }
}
/// A named wrapper around `tokio_util`'s `DropGuard`, obtained via `ShutdownToken::drop_guard`.
pub struct ShutdownDropGuard {
    // name carried over from the token this guard was created from
    name: Option<String>,
    // the wrapped drop guard
    inner: DropGuard,
}
/// Exposes the wrapped `DropGuard` by reference.
impl Deref for ShutdownDropGuard {
    type Target = DropGuard;

    // borrows the inner guard; the name is not involved
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
impl ShutdownDropGuard {
pub fn disarm(self) -> ShutdownToken {
ShutdownToken {
name: self.name,
inner: self.inner.disarm(),
}
}
pub fn name(&self) -> String {
token_name(&self.name)
}
}
/// Set of spawned futures whose completion is treated as a shutdown signal.
#[derive(Default)]
pub struct ShutdownSignals(JoinSet<()>);

impl ShutdownSignals {
    /// Waits until any of the registered signal tasks has finished.
    /// The result of the finished task is discarded.
    pub async fn wait_for_signal(&mut self) {
        self.0.join_next().await;
    }
}
/// Owner of the root `ShutdownToken`, the registered shutdown signals
/// and the tracker of the tasks that have to finish during shutdown.
pub struct ShutdownManager {
    /// Root token that gets cancelled once any of the registered shutdown signals completes.
    pub root_token: ShutdownToken,

    // optional legacy `TaskManager`, only set via `with_legacy_task_manager`
    legacy_task_manager: Option<TaskManager>,

    // futures that trigger cancellation of the root token upon completion
    shutdown_signals: ShutdownSignals,

    // the reason I'm not using a `JoinSet` is because it forces us to use futures with the same `::Output` type
    tracker: TaskTracker,

    // how long `finish_shutdown` waits for the tasks before giving up
    max_shutdown_duration: Duration,
}
/// Exposes the methods of the inner `TaskTracker` (such as spawning tracked tasks)
/// directly on the `ShutdownManager`.
impl Deref for ShutdownManager {
    type Target = TaskTracker;

    // borrows the inner task tracker
    fn deref(&self) -> &Self::Target {
        &self.tracker
    }
}
impl ShutdownManager {
    /// Creates a new manager with a freshly created root token of the provided name,
    /// no legacy task manager and a 10 second maximum shutdown duration.
    pub fn new(root_token_name: impl Into<String>) -> Self {
        let manager = ShutdownManager {
            root_token: ShutdownToken::new(root_token_name),
            legacy_task_manager: None,
            shutdown_signals: Default::default(),
            tracker: Default::default(),
            max_shutdown_duration: Duration::from_secs(10),
        };

        // we need to add an explicit watcher for the cancellation token being cancelled
        // so that we could cancel all legacy tasks
        let cancel_watcher = manager.root_token.clone();
        manager.with_shutdown(async move { cancel_watcher.cancelled().await })
    }

    /// Creates a manager with an ephemeral root token, no registered shutdown signals
    /// and a zero maximum shutdown duration.
    pub fn empty_mock() -> Self {
        ShutdownManager {
            root_token: ShutdownToken::ephemeral(),
            legacy_task_manager: None,
            shutdown_signals: Default::default(),
            tracker: Default::default(),
            max_shutdown_duration: Default::default(),
        }
    }

    /// Attaches a legacy `TaskManager` to this manager and registers a shutdown signal
    /// that fires once the legacy manager reports either a task error or a task drop.
    pub fn with_legacy_task_manager(mut self) -> Self {
        let mut legacy_manager =
            TaskManager::default().named(format!("{}-legacy", self.root_token.name()));
        let mut legacy_error_rx = legacy_manager.task_return_error_rx();
        let mut legacy_drop_rx = legacy_manager.task_drop_rx();

        self.legacy_task_manager = Some(legacy_manager);

        // add a task that listens for legacy task clients being dropped to trigger cancellation
        self.with_shutdown(async move {
            tokio::select! {
                _ = legacy_error_rx.recv() => (),
                _ = legacy_drop_rx.recv() => (),
            }

            info!("received legacy shutdown signal");
        })
    }

    /// Registers the default shutdown signals:
    /// interrupt, terminate and quit on unix targets and only interrupt otherwise.
    ///
    /// # Errors
    ///
    /// Propagates the error of registering the unix signal listeners.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn with_default_shutdown_signals(self) -> std::io::Result<Self> {
        cfg_if::cfg_if! {
            if #[cfg(unix)] {
                self.with_interrupt_signal()
                    .with_terminate_signal()?
                    .with_quit_signal()
            } else {
                Ok(self.with_interrupt_signal())
            }
        }
    }

    /// Registers a future that, upon completion, cancels the root token.
    #[must_use]
    #[track_caller]
    pub fn with_shutdown<F>(mut self, shutdown: F) -> Self
    where
        F: Future<Output = ()>,
        F: Send + 'static,
    {
        let shutdown_token = self.root_token.clone();
        self.shutdown_signals.0.spawn(async move {
            shutdown.await;

            info!("sending cancellation after receiving shutdown signal");
            shutdown_token.cancel();
        });
        self
    }

    /// Registers the provided unix signal as a shutdown signal.
    ///
    /// # Errors
    ///
    /// Returns an error if the signal listener could not be created.
    #[cfg(unix)]
    #[track_caller]
    pub fn with_shutdown_signal(self, signal_kind: SignalKind) -> std::io::Result<Self> {
        let mut sig = signal(signal_kind)?;
        Ok(self.with_shutdown(async move {
            sig.recv().await;
        }))
    }

    /// Registers ctrl-c as a shutdown signal.
    #[cfg(not(target_arch = "wasm32"))]
    #[track_caller]
    pub fn with_interrupt_signal(self) -> Self {
        self.with_shutdown(async move {
            let _ = tokio::signal::ctrl_c().await;
        })
    }

    /// Registers the unix terminate signal as a shutdown signal.
    #[cfg(unix)]
    #[track_caller]
    pub fn with_terminate_signal(self) -> std::io::Result<Self> {
        self.with_shutdown_signal(SignalKind::terminate())
    }

    /// Registers the unix quit signal as a shutdown signal.
    #[cfg(unix)]
    #[track_caller]
    pub fn with_quit_signal(self) -> std::io::Result<Self> {
        self.with_shutdown_signal(SignalKind::quit())
    }

    /// Overrides the maximum time `finish_shutdown` waits for the tasks to finish.
    #[must_use]
    pub fn with_shutdown_duration(mut self, duration: Duration) -> Self {
        self.max_shutdown_duration = duration;
        self
    }

    /// Creates a child of the root token (see `ShutdownToken::child_token`).
    pub fn child_token<S: Into<String>>(&self, child_suffix: S) -> ShutdownToken {
        self.root_token.child_token(child_suffix)
    }

    /// Creates a clone of the root token (see `ShutdownToken::clone_with_suffix`).
    pub fn clone_token<S: Into<String>>(&self, child_suffix: S) -> ShutdownToken {
        self.root_token.clone_with_suffix(child_suffix)
    }

    /// Subscribes a new `TaskClient` to the legacy task manager.
    ///
    /// # Panics
    ///
    /// Panics if the legacy support has not been enabled via `with_legacy_task_manager`.
    #[must_use]
    pub fn subscribe_legacy<S: Into<String>>(&self, child_suffix: S) -> TaskClient {
        // alternatively we could have set self.legacy_task_manager = Some(TaskManager::default());
        // on demand if it wasn't unavailable, but then we'd have to use mutable reference
        #[allow(clippy::expect_used)]
        self.legacy_task_manager
            .as_ref()
            .expect("did not enable legacy shutdown support")
            .subscribe_named(child_suffix)
    }

    /// Waits for whichever happens first: ctrl-c, the shutdown timeout
    /// or all tracked (and legacy) tasks finishing.
    async fn finish_shutdown(mut self) {
        // `TaskTracker::wait` only resolves once the tracker has been closed,
        // so make sure the graceful branch below is actually able to complete.
        // closing is idempotent and does not prevent spawning further tasks
        self.tracker.close();

        let mut wait_futures = FuturesUnordered::<Pin<Box<dyn Future<Output = ()>>>>::new();

        // force shutdown via ctrl-c
        wait_futures.push(Box::pin(async move {
            #[cfg(not(target_arch = "wasm32"))]
            let interrupt_future = tokio::signal::ctrl_c();

            #[cfg(target_arch = "wasm32")]
            let interrupt_future = futures::future::pending::<()>();

            let _ = interrupt_future.await;
            info!("received interrupt - forcing shutdown");
        }));

        // timeout
        wait_futures.push(Box::pin(async move {
            sleep(self.max_shutdown_duration).await;
            info!("timeout reached, forcing shutdown");
        }));

        // graceful
        wait_futures.push(Box::pin(async move {
            self.tracker.wait().await;
            debug!("migrated tasks successfully shutdown");
            if let Some(legacy) = self.legacy_task_manager.as_mut() {
                legacy.wait_for_graceful_shutdown().await;
                debug!("legacy tasks successfully shutdown");
            }

            info!("all registered tasks successfully shutdown")
        }));

        wait_futures.next().await;
    }

    /// Takes the registered shutdown signals out of this manager, leaving an empty set behind.
    pub fn detach_shutdown_signals(&mut self) -> ShutdownSignals {
        mem::take(&mut self.shutdown_signals)
    }

    /// Replaces the registered shutdown signals with the provided set.
    pub fn replace_shutdown_signals(&mut self, signals: ShutdownSignals) {
        self.shutdown_signals = signals;
    }

    /// Waits until any of the registered shutdown signals has completed.
    // cancellation safe
    pub async fn wait_for_shutdown_signal(&mut self) {
        self.shutdown_signals.0.join_next().await;
    }

    /// Signals shutdown to the legacy tasks (if any) and waits for all tasks to finish,
    /// up to the maximum shutdown duration or until ctrl-c is received.
    pub async fn perform_shutdown(mut self) {
        if let Some(legacy_manager) = self.legacy_task_manager.as_mut() {
            info!("attempting to shutdown legacy tasks");
            let _ = legacy_manager.signal_shutdown();
        }

        info!("waiting for tasks to finish... (press ctrl-c to force)");
        self.finish_shutdown().await;
    }

    /// Waits for a shutdown signal and then performs the shutdown.
    pub async fn run_until_shutdown(mut self) {
        self.wait_for_shutdown_signal().await;

        self.perform_shutdown().await;
    }
}
+744
View File
@@ -0,0 +1,744 @@
// Copyright 2025 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: Apache-2.0
use crate::cancellation::tracker::{Cancelled, ShutdownTracker};
use crate::spawn::JoinHandle;
use crate::ShutdownToken;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use log::error;
use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::time::Duration;
use tracing::info;
#[cfg(not(target_arch = "wasm32"))]
use tokio::time::sleep;
#[cfg(target_arch = "wasm32")]
use wasmtimer::tokio::sleep;
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
use tokio::task::JoinSet;
/// A top level structure responsible for controlling process shutdown by listening to
/// the underlying registered signals and issuing cancellation to tasks derived from its root cancellation token.
// NOTE(review): `allow(deprecated)` is presumably needed because `crate::TaskManager` is deprecated - confirm
#[allow(deprecated)]
pub struct ShutdownManager {
    /// Optional reference to the legacy [TaskManager](crate::TaskManager) to allow easier
    /// transition to the new system.
    pub(crate) legacy_task_manager: Option<crate::TaskManager>,

    /// Registered [ShutdownSignals] that will trigger process shutdown if detected.
    pub(crate) shutdown_signals: ShutdownSignals,

    /// Combined [TaskTracker](tokio_util::task::TaskTracker) and [ShutdownToken]
    /// for spawning and tracking tasks associated with this ShutdownManager.
    pub(crate) tracker: ShutdownTracker,

    /// The maximum duration given to the tracked tasks to gracefully exit
    /// before forcing the shutdown.
    pub(crate) max_shutdown_duration: Duration,
}
/// Wrapper around futures that upon completion will trigger binary shutdown.
#[derive(Default)]
pub struct ShutdownSignals(JoinSet<()>);

impl ShutdownSignals {
    /// Wait for any of the registered signals to be ready.
    /// The result of the completed signal task is discarded.
    pub async fn wait_for_signal(&mut self) {
        self.0.join_next().await;
    }
}
// note: the default instance ONLY listens for SIGINT; SIGTERM and SIGQUIT are not registered,
// because registering them is fallible and `Default::default` cannot return a `Result`
#[cfg(not(target_arch = "wasm32"))]
impl Default for ShutdownManager {
    /// Builds a manager that listens for the interrupt signal and cancels on panic.
    fn default() -> Self {
        let manager = ShutdownManager::new_without_signals();
        let manager = manager.with_interrupt_signal();
        manager.with_cancel_on_panic()
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl ShutdownManager {
/// Build a ShutdownManager with the most sensible defaults, i.e.:
/// - shutdown is triggered by SIGINT, SIGTERM (unix only) or SIGQUIT (unix only)
/// - shutdown is triggered by any task panicking
pub fn build_new_default() -> std::io::Result<Self> {
    let manager = ShutdownManager::new_without_signals().with_default_shutdown_signals()?;
    Ok(manager.with_cancel_on_panic())
}
/// Register an additional shutdown signal: once the provided future completes,
/// the shutdown token of the underlying tracker gets cancelled.
#[must_use]
#[track_caller]
pub fn with_shutdown<F>(mut self, shutdown: F) -> Self
where
    F: Future<Output = ()> + Send + 'static,
{
    let cancel_handle = self.tracker.clone_shutdown_token();
    let signal_task = async move {
        shutdown.await;

        info!("sending cancellation after receiving shutdown signal");
        cancel_handle.cancel();
    };
    self.shutdown_signals.0.spawn(signal_task);
    self
}
/// Include support for the legacy [TaskManager](crate::TaskManager) to this instance of the ShutdownManager.
/// This will allow issuing [TaskClient](crate::TaskClient) for tasks that still require them.
///
/// A shutdown signal is registered that fires as soon as the legacy manager
/// reports either a task error or a task drop.
#[allow(deprecated)]
pub fn with_legacy_task_manager(mut self) -> Self {
    let mut legacy_manager = crate::TaskManager::default().named("legacy-task-manager");
    // the receivers have to be obtained before the manager is moved into `self`
    let mut legacy_error_rx = legacy_manager.task_return_error_rx();
    let mut legacy_drop_rx = legacy_manager.task_drop_rx();

    self.legacy_task_manager = Some(legacy_manager);

    // add a task that listens for legacy task clients being dropped to trigger cancellation
    self.with_shutdown(async move {
        tokio::select! {
            _ = legacy_error_rx.recv() => (),
            _ = legacy_drop_rx.recv() => (),
        }

        info!("received legacy shutdown signal");
    })
}
/// Register the specified unix signal as a shutdown signal, i.e. receiving it
/// triggers cancellation of all registered tasks.
///
/// Fails if the listener for the signal could not be created.
#[cfg(unix)]
#[track_caller]
pub fn with_shutdown_signal(self, signal_kind: SignalKind) -> std::io::Result<Self> {
    let mut listener = signal(signal_kind)?;
    let signal_received = async move {
        listener.recv().await;
    };
    Ok(self.with_shutdown(signal_received))
}
/// Register SIGTERM as a shutdown signal, i.e. receiving it
/// triggers cancellation of all registered tasks.
#[cfg(unix)]
#[track_caller]
pub fn with_terminate_signal(self) -> std::io::Result<Self> {
    let sigterm = SignalKind::terminate();
    self.with_shutdown_signal(sigterm)
}
/// Register SIGQUIT as a shutdown signal, i.e. receiving it
/// triggers cancellation of all registered tasks.
#[cfg(unix)]
#[track_caller]
pub fn with_quit_signal(self) -> std::io::Result<Self> {
    let sigquit = SignalKind::quit();
    self.with_shutdown_signal(sigquit)
}
/// Add default signals to the set of the currently registered shutdown signals that will trigger
/// cancellation of all registered tasks.
/// This includes SIGINT, SIGTERM and SIGQUIT for unix-based platforms and SIGINT for other targets (such as windows).
///
/// On unix, an error from registering the SIGTERM or SIGQUIT listener is propagated to the caller.
pub fn with_default_shutdown_signals(self) -> std::io::Result<Self> {
    cfg_if::cfg_if! {
        if #[cfg(unix)] {
            self.with_interrupt_signal()
                .with_terminate_signal()?
                .with_quit_signal()
        } else {
            Ok(self.with_interrupt_signal())
        }
    }
}
/// Add the SIGINT (ctrl-c) signal to the currently registered shutdown signals that will trigger
/// cancellation of all registered tasks.
#[track_caller]
pub fn with_interrupt_signal(self) -> Self {
self.with_shutdown(async move {
let _ = tokio::signal::ctrl_c().await;
})
}
/// Spawn the provided future on the current Tokio runtime, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
///
/// This is a thin wrapper that delegates to the inner [ShutdownTracker].
#[track_caller]
pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    self.tracker.spawn(task)
}
/// Spawn the provided future on the current Tokio runtime,
/// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
/// Furthermore, attach a name to the spawned task to more easily track it within a [tokio console](https://github.com/tokio-rs/console)
///
/// Note that this is no different from [spawn](Self::spawn) if the underlying binary
/// has not been built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"`
///
/// This is a thin wrapper that delegates to the inner [ShutdownTracker].
#[track_caller]
pub fn try_spawn_named<F>(&self, task: F, name: &str) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    self.tracker.try_spawn_named(task, name)
}
/// Spawn the provided future on the provided Tokio runtime,
/// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
///
/// This is a thin wrapper that delegates to the inner [ShutdownTracker].
#[track_caller]
pub fn spawn_on<F>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    self.tracker.spawn_on(task, handle)
}
/// Spawn `task` on the current [LocalSet](tokio::task::LocalSet) and register it with the
/// underlying [TaskTracker](tokio_util::task::TaskTracker).
///
/// The future is not required to be `Send`.
#[track_caller]
pub fn spawn_local<F>(&self, task: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
    F::Output: 'static,
{
    let tracker = &self.tracker;
    tracker.spawn_local(task)
}
/// Run the blocking closure `task` on the current Tokio runtime and register it with the
/// underlying [TaskTracker](tokio_util::task::TaskTracker).
#[track_caller]
pub fn spawn_blocking<F, T>(&self, task: F) -> JoinHandle<T>
where
    F: FnOnce() -> T,
    F: Send + 'static,
    T: Send + 'static,
{
    let tracker = &self.tracker;
    tracker.spawn_blocking(task)
}
/// Run the blocking closure `task` on the Tokio runtime identified by `handle` and register
/// it with the underlying [TaskTracker](tokio_util::task::TaskTracker).
#[track_caller]
pub fn spawn_blocking_on<F, T>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<T>
where
    F: FnOnce() -> T,
    F: Send + 'static,
    T: Send + 'static,
{
    let tracker = &self.tracker;
    tracker.spawn_blocking_on(task, handle)
}
/// Spawn a named `task` on the current Tokio runtime that gets cancelled once a global
/// shutdown signal is detected, and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
///
/// The returned handle resolves to `Err(Cancelled)` if the task was cancelled before completing.
///
/// Note that to fully use the naming feature, such as tracking within a [tokio console](https://github.com/tokio-rs/console),
/// the underlying binary has to be built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"`.
#[track_caller]
pub fn try_spawn_named_with_shutdown<F>(
    &self,
    task: F,
    name: &str,
) -> JoinHandle<Result<F::Output, Cancelled>>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let tracker = &self.tracker;
    tracker.try_spawn_named_with_shutdown(task, name)
}
/// Spawn `task` on the current Tokio runtime so that it gets cancelled once a global
/// shutdown signal is detected, and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
///
/// The returned handle resolves to `Err(Cancelled)` if the task was cancelled before completing.
#[track_caller]
pub fn spawn_with_shutdown<F>(&self, task: F) -> JoinHandle<Result<F::Output, Cancelled>>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let tracker = &self.tracker;
    tracker.spawn_with_shutdown(task)
}
}
#[cfg(target_arch = "wasm32")]
impl ShutdownManager {
/// Run `task` on the current thread and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
#[track_caller]
pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
{
    let tracker = &self.tracker;
    tracker.spawn(task)
}
/// Run `task` on the current thread and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
///
/// Behaves exactly like [spawn](Self::spawn); it is only provided so that wasm32
/// exposes the same interface as the other targets.
#[track_caller]
pub fn try_spawn_named<F>(&self, task: F, name: &str) -> JoinHandle<F::Output>
where
    F: Future + 'static,
{
    let tracker = &self.tracker;
    tracker.try_spawn_named(task, name)
}
/// Run the provided future on the current thread
/// that will get cancelled once a global shutdown signal is detected,
/// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
/// It has exactly the same behaviour as [spawn_with_shutdown](Self::spawn_with_shutdown) and it only exists to provide
/// the same interface as non-wasm32 targets.
///
/// The future is not required to be `Send`: the bounds mirror those of the
/// wasm32 [ShutdownTracker] method this call is forwarded to.
///
/// The returned handle resolves to `Err(Cancelled)` if the task was cancelled before completing.
#[track_caller]
pub fn try_spawn_named_with_shutdown<F>(
    &self,
    task: F,
    name: &str,
) -> JoinHandle<Result<F::Output, Cancelled>>
where
    F: Future<Output = ()> + 'static,
{
    self.tracker.try_spawn_named_with_shutdown(task, name)
}
/// Run the provided future on the current thread
/// that will get cancelled once a global shutdown signal is detected,
/// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
///
/// The future is not required to be `Send`: the bounds mirror those of the
/// wasm32 [ShutdownTracker] method this call is forwarded to.
///
/// The returned handle resolves to `Err(Cancelled)` if the task was cancelled before completing.
#[track_caller]
pub fn spawn_with_shutdown<F>(&self, task: F) -> JoinHandle<Result<F::Output, Cancelled>>
where
    F: Future<Output = ()> + 'static,
{
    self.tracker.spawn_with_shutdown(task)
}
}
impl ShutdownManager {
/// Build a ShutdownManager around a freshly created [ShutdownToken], with no external
/// shutdown signals registered.
///
/// Such a manager will only attempt to wait for all tasks spawned on its tracker to
/// gracefully finish execution.
pub fn new_without_signals() -> Self {
    let root_token = ShutdownToken::new();
    Self::new_from_external_shutdown_token(root_token)
}
/// Create new instance of the ShutdownManager using an external shutdown token.
///
/// The manager starts without a legacy task manager and with a maximum
/// shutdown duration of 10 seconds.
///
/// Note: it will not listen to any external shutdown signals!
/// You might want to further customise it with [shutdown signals](Self::with_shutdown)
/// (or just use [the default set](Self::with_default_shutdown_signals)).
/// Similarly, you might want to include [cancellation on panic](Self::with_cancel_on_panic)
/// to make sure everything gets cancelled if one of the tasks panics.
pub fn new_from_external_shutdown_token(shutdown_token: ShutdownToken) -> Self {
let manager = ShutdownManager {
legacy_task_manager: None,
shutdown_signals: Default::default(),
tracker: ShutdownTracker::new_from_external_shutdown_token(shutdown_token),
max_shutdown_duration: Duration::from_secs(10),
};
// we need to add an explicit watcher for the cancellation token being cancelled
// so that we could cancel all legacy tasks
// (on wasm32 no watcher is registered and the manager is returned as-is)
cfg_if::cfg_if! {if #[cfg(not(target_arch = "wasm32"))] {
let cancel_watcher = manager.tracker.clone_shutdown_token();
manager.with_shutdown(async move { cancel_watcher.cancelled().await })
} else {
manager
}}
}
/// Create an empty testing mock of the ShutdownManager with no signals registered.
pub fn empty_mock() -> Self {
ShutdownManager {
legacy_task_manager: None,
shutdown_signals: Default::default(),
tracker: Default::default(),
max_shutdown_duration: Default::default(),
}
}
/// Add additional panic hook such that upon triggering, the root [ShutdownToken](ShutdownToken) gets cancelled.
///
/// The previously installed hook is preserved and invoked first; afterwards the panic
/// location and message are logged and the cancellation is issued.
/// The message is extracted from both `&str` and `String` panic payloads
/// (the latter is produced by formatted panics such as `panic!("{x}")`);
/// any other payload type is logged as an empty message.
///
/// Note: an unfortunate limitation of this is that graceful shutdown will no longer be possible
/// since the task that has panicked will not exit and thus all shutdowns will have to be either forced
/// or will have to time out.
#[must_use]
pub fn with_cancel_on_panic(self) -> Self {
    let current_hook = std::panic::take_hook();
    let shutdown_token = self.clone_shutdown_token();

    std::panic::set_hook(Box::new(move |panic_info| {
        // 1. call existing hook
        current_hook(panic_info);

        let location = panic_info
            .location()
            .map(|l| l.to_string())
            .unwrap_or_else(|| "<unknown>".to_string());

        // literal panics carry `&'static str`, formatted ones carry `String`
        let raw_payload = panic_info.payload();
        let payload: &str = if let Some(message) = raw_payload.downcast_ref::<&str>() {
            *message
        } else if let Some(message) = raw_payload.downcast_ref::<String>() {
            message.as_str()
        } else {
            ""
        };

        // 2. issue cancellation
        error!("panicked at {location}: {payload}. issuing global cancellation");
        shutdown_token.cancel();
    }));

    self
}
/// Override the maximum time tracked tasks are given to exit gracefully
/// before the shutdown is forced.
#[must_use]
pub fn with_shutdown_duration(self, duration: Duration) -> Self {
    let mut manager = self;
    manager.max_shutdown_duration = duration;
    manager
}
/// Check whether the root [ShutdownToken](ShutdownToken) has already been cancelled.
pub fn is_cancelled(&self) -> bool {
    let root_token = &self.tracker.root_cancellation_token;
    root_token.is_cancelled()
}
/// Get a reference to the used [ShutdownTracker](ShutdownTracker)
pub fn shutdown_tracker(&self) -> &ShutdownTracker {
&self.tracker
}
/// Get a cloned instance of the used [ShutdownTracker](ShutdownTracker)
pub fn shutdown_tracker_owned(&self) -> ShutdownTracker {
self.tracker.clone()
}
/// Wait for the underlying [TaskTracker](tokio_util::task::TaskTracker) to be both closed and empty.
///
/// Returns immediately if that is already the case when this method is called.
pub async fn wait_for_tracker(&self) {
    let tracker = &self.tracker;
    tracker.wait_for_tracker().await;
}
/// Close the underlying [TaskTracker](tokio_util::task::TaskTracker).
///
/// Closing lets [`wait_for_tracker`](ShutdownTracker::wait_for_tracker) futures complete.
/// It does not prevent new tasks from being spawned.
///
/// Returns `true` if this call closed the tracker, or `false` if it was already closed.
pub fn close_tracker(&self) -> bool {
    let tracker = &self.tracker;
    tracker.close_tracker()
}
/// Reopen the underlying [TaskTracker](tokio_util::task::TaskTracker).
///
/// While open, [`wait_for_tracker`](ShutdownTracker::wait_for_tracker) futures do not complete,
/// even if the tracker is empty.
///
/// Returns `true` if this call reopened the tracker, or `false` if it was already open.
pub fn reopen_tracker(&self) -> bool {
    let tracker = &self.tracker;
    tracker.reopen_tracker()
}
/// Check whether the underlying [TaskTracker](tokio_util::task::TaskTracker) has been
/// [closed](Self::close_tracker).
pub fn is_tracker_closed(&self) -> bool {
    let tracker = &self.tracker;
    tracker.is_tracker_closed()
}
/// Number of tasks currently tracked by the underlying [TaskTracker](tokio_util::task::TaskTracker).
pub fn tracked_tasks(&self) -> usize {
    let tracker = &self.tracker;
    tracker.tracked_tasks()
}
/// Check whether the underlying [TaskTracker](tokio_util::task::TaskTracker) tracks no tasks.
pub fn is_tracker_empty(&self) -> bool {
    let tracker = &self.tracker;
    tracker.is_tracker_empty()
}
/// Derive a new [ShutdownToken](crate::cancellation::ShutdownToken) that is a child of the root token.
pub fn child_shutdown_token(&self) -> ShutdownToken {
    let root_token = &self.tracker.root_cancellation_token;
    root_token.child_token()
}
/// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) on the same hierarchical structure as the root token
pub fn clone_shutdown_token(&self) -> ShutdownToken {
self.tracker.root_cancellation_token.clone()
}
/// Attempt to create a handle to a legacy [TaskClient] to support tasks that haven't migrated
/// from the legacy [TaskManager].
/// Note: to use this method [ShutdownManager] must be built with `.with_legacy_task_manager()`.
///
/// # Panics
/// Panics if no legacy task manager has been set on this [ShutdownManager].
#[must_use]
#[deprecated]
#[allow(deprecated)]
pub fn subscribe_legacy<S: Into<String>>(&self, child_suffix: S) -> crate::TaskClient {
// alternatively we could have set self.legacy_task_manager = Some(TaskManager::default());
// on demand if it wasn't available, but then we'd have to use a mutable reference
#[allow(clippy::expect_used)]
self.legacy_task_manager
.as_ref()
.expect("did not enable legacy shutdown support")
.subscribe_named(child_suffix)
}
/// Finalise the shutdown procedure by waiting until either:
/// - all tracked tasks have terminated
/// - timeout has been reached
/// - shutdown has been forced (by sending SIGINT)
///
/// The three conditions are raced against each other: the method returns as soon as
/// the first of them resolves, and the remaining futures are dropped on return.
async fn finish_shutdown(&mut self) {
let mut wait_futures = FuturesUnordered::<Pin<Box<dyn Future<Output = ()> + Send>>>::new();
// force shutdown via ctrl-c
wait_futures.push(Box::pin(async move {
#[cfg(not(target_arch = "wasm32"))]
let interrupt_future = tokio::signal::ctrl_c();
// on wasm32 there is no ctrl-c to wait for, so this branch never resolves
#[cfg(target_arch = "wasm32")]
let interrupt_future = futures::future::pending::<()>();
let _ = interrupt_future.await;
info!("received interrupt - forcing shutdown");
}));
// timeout
let max_shutdown = self.max_shutdown_duration;
wait_futures.push(Box::pin(async move {
sleep(max_shutdown).await;
info!("timeout reached - forcing shutdown");
}));
// graceful
let tracker = self.tracker.clone();
wait_futures.push(Box::pin(async move {
tracker.wait_for_tracker().await;
info!("all tracked tasks successfully shutdown");
// legacy tasks (if any) are only awaited once the tracker has drained
if let Some(legacy) = self.legacy_task_manager.as_mut() {
legacy.wait_for_graceful_shutdown().await;
info!("all legacy tasks successfully shutdown");
}
info!("all registered tasks successfully shutdown")
}));
// whichever future completes first decides how the shutdown ended
wait_futures.next().await;
}
/// Take the current set of [ShutdownSignals] out of this [ShutdownManager],
/// leaving a default (empty) set in its place.
///
/// This is potentially useful if one wishes to start listening for the signals
/// before the whole process has been fully set up.
pub fn detach_shutdown_signals(&mut self) -> ShutdownSignals {
    let signals = &mut self.shutdown_signals;
    mem::take(signals)
}
/// Swap in a new set of [ShutdownSignals] used for determining
/// whether the underlying process should be stopped.
///
/// The previously stored set is dropped.
pub fn replace_shutdown_signals(&mut self, signals: ShutdownSignals) {
    let slot = &mut self.shutdown_signals;
    *slot = signals;
}
/// Notify every registered task that it should stop.
///
/// If a legacy [TaskManager] is present, it is signalled first (the result of that call is
/// ignored); afterwards the root token is cancelled.
pub fn send_cancellation(&self) {
    if let Some(legacy_manager) = &self.legacy_task_manager {
        info!("attempting to shutdown legacy tasks");
        let _ = legacy_manager.signal_shutdown();
    }

    let root_token = &self.tracker.root_cancellation_token;
    root_token.cancel();
}
/// Wait until receiving one of the registered shutdown signals.
/// This method is cancellation safe.
///
/// On non-wasm32 targets this awaits `join_next` on the registered signal set;
/// on wasm32 it awaits cancellation of the root token instead.
// NOTE(review): `join_next` presumably resolves right away when no signals are
// registered (e.g. after `detach_shutdown_signals`) - confirm against the signal set type.
pub async fn wait_for_shutdown_signal(&mut self) {
#[cfg(not(target_arch = "wasm32"))]
self.shutdown_signals.0.join_next().await;
#[cfg(target_arch = "wasm32")]
self.tracker.root_cancellation_token.cancelled().await;
}
/// Perform system shutdown by sending relevant signals and waiting until either:
/// - all tracked tasks have terminated
/// - timeout has been reached
/// - shutdown has been forced (by sending SIGINT)
pub async fn perform_shutdown(&mut self) {
// cancellation has to be issued first, otherwise there is nothing for the tasks to react to
self.send_cancellation();
info!("waiting for tasks to finish... (press ctrl-c to force)");
self.finish_shutdown().await;
}
/// Wait until a shutdown signal has been received and trigger system shutdown.
///
/// The tracker is closed up front, then the method waits for a shutdown signal
/// and finally runs the full shutdown procedure.
pub async fn run_until_shutdown(&mut self) {
// closing the tracker is what allows `wait_for_tracker` futures to complete
self.close_tracker();
self.wait_for_shutdown_signal().await;
self.perform_shutdown().await;
}
}
#[cfg(test)]
mod tests {
use super::*;
use nym_test_utils::traits::{ElapsedExt, Timeboxed};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
// Without any signal nothing cancels the root token, so waiting is expected to time out;
// once the token is cancelled up front, the shutdown is expected to complete in time.
#[tokio::test]
async fn shutdown_with_no_tracked_tasks_and_signals() -> anyhow::Result<()> {
let mut manager = ShutdownManager::new_without_signals();
let res = manager.run_until_shutdown().timeboxed().await;
assert!(res.has_elapsed());
let mut manager = ShutdownManager::new_without_signals();
let shutdown = manager.clone_shutdown_token();
shutdown.cancel();
let res = manager.run_until_shutdown().timeboxed().await;
assert!(!res.has_elapsed());
Ok(())
}
// A custom shutdown future (a 100ms sleep) acts as the shutdown signal.
#[tokio::test]
async fn shutdown_signal() -> anyhow::Result<()> {
let timeout_shutdown = sleep(Duration::from_millis(100));
let mut manager = ShutdownManager::new_without_signals().with_shutdown(timeout_shutdown);
// execution finishes after the sleep finishes
let res = manager
.run_until_shutdown()
.execute_with_deadline(Duration::from_millis(200))
.await;
assert!(!res.has_elapsed());
Ok(())
}
// A panicking task is expected to trigger the shutdown through the installed panic hook.
#[tokio::test]
async fn panic_hook() -> anyhow::Result<()> {
let mut manager = ShutdownManager::new_without_signals().with_cancel_on_panic();
manager.spawn_with_shutdown(async move {
sleep(Duration::from_millis(10000)).await;
});
manager.spawn_with_shutdown(async move {
sleep(Duration::from_millis(10)).await;
panic!("panicking");
});
// execution finishes after the panic gets triggered
let res = manager
.run_until_shutdown()
.execute_with_deadline(Duration::from_millis(200))
.await;
assert!(!res.has_elapsed());
Ok(())
}
// Both tasks wait on clones of the root token and flip their flag once it is cancelled.
#[tokio::test]
async fn task_cancellation() -> anyhow::Result<()> {
let timeout_shutdown = sleep(Duration::from_millis(100));
let mut manager = ShutdownManager::new_without_signals().with_shutdown(timeout_shutdown);
let cancelled1 = Arc::new(AtomicBool::new(false));
let cancelled1_clone = cancelled1.clone();
let cancelled2 = Arc::new(AtomicBool::new(false));
let cancelled2_clone = cancelled2.clone();
let shutdown = manager.clone_shutdown_token();
manager.spawn(async move {
shutdown.cancelled().await;
cancelled1_clone.store(true, std::sync::atomic::Ordering::Relaxed);
});
let shutdown = manager.clone_shutdown_token();
manager.spawn(async move {
shutdown.cancelled().await;
cancelled2_clone.store(true, std::sync::atomic::Ordering::Relaxed);
});
let res = manager
.run_until_shutdown()
.execute_with_deadline(Duration::from_millis(200))
.await;
assert!(!res.has_elapsed());
assert!(cancelled1.load(std::sync::atomic::Ordering::Relaxed));
assert!(cancelled2.load(std::sync::atomic::Ordering::Relaxed));
Ok(())
}
// Cancellation issued by one task through a cloned token is observed by the other task
// and by the manager itself.
#[tokio::test]
async fn cancellation_within_task() -> anyhow::Result<()> {
let mut manager = ShutdownManager::new_without_signals();
let cancelled1 = Arc::new(AtomicBool::new(false));
let cancelled1_clone = cancelled1.clone();
let shutdown = manager.clone_shutdown_token();
manager.spawn(async move {
shutdown.cancelled().await;
cancelled1_clone.store(true, std::sync::atomic::Ordering::Relaxed);
});
let shutdown = manager.clone_shutdown_token();
manager.spawn(async move {
sleep(Duration::from_millis(10)).await;
shutdown.cancel();
});
let res = manager
.run_until_shutdown()
.execute_with_deadline(Duration::from_millis(200))
.await;
assert!(!res.has_elapsed());
assert!(cancelled1.load(std::sync::atomic::Ordering::Relaxed));
Ok(())
}
// A task that ignores cancellation keeps the shutdown pending until the configured
// maximum shutdown duration runs out.
#[tokio::test]
async fn shutdown_timeout() -> anyhow::Result<()> {
// case 1: the 1000ms shutdown allowance outlives the 200ms deadline
let timeout_shutdown = sleep(Duration::from_millis(50));
let mut manager = ShutdownManager::new_without_signals()
.with_shutdown(timeout_shutdown)
.with_shutdown_duration(Duration::from_millis(1000));
// ignore shutdown signals
manager.spawn(async move {
sleep(Duration::from_millis(1000)).await;
});
let res = manager
.run_until_shutdown()
.execute_with_deadline(Duration::from_millis(200))
.await;
assert!(res.has_elapsed());
// case 2: the 100ms shutdown allowance forces the shutdown before the 200ms deadline
let timeout_shutdown = sleep(Duration::from_millis(50));
let mut manager = ShutdownManager::new_without_signals()
.with_shutdown(timeout_shutdown)
.with_shutdown_duration(Duration::from_millis(100));
// ignore shutdown signals
manager.spawn(async move {
sleep(Duration::from_millis(1000)).await;
});
let res = manager
.run_until_shutdown()
.execute_with_deadline(Duration::from_millis(200))
.await;
assert!(!res.has_elapsed());
Ok(())
}
}
+54
View File
@@ -0,0 +1,54 @@
// Copyright 2025 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: Apache-2.0
//! A [CancellationToken](tokio_util::sync::CancellationToken)-backed shutdown mechanism for Nym binaries.
//!
//! It allows creation of a centralised manager for keeping track of all signals that are meant
//! to trigger exit of all associated tasks and sending cancellation to the aforementioned futures.
//!
//! # Default usage
//!
//! ```no_run
//! use std::time::Duration;
//! use tokio::time::sleep;
//! use nym_task::{ShutdownManager, ShutdownToken};
//!
//! async fn my_task() {
//! loop {
//! sleep(Duration::from_secs(5)).await
//! // do some periodic work that can be easily interrupted
//! }
//! }
//!
//! async fn important_work_that_cant_be_interrupted() {}
//!
//! async fn my_managed_task(shutdown_token: ShutdownToken) {
//! tokio::select! {
//! _ = shutdown_token.cancelled() => {}
//! _ = important_work_that_cant_be_interrupted() => {}
//! }
//! }
//! #[tokio::main]
//! async fn main() {
//! let mut shutdown_manager = ShutdownManager::build_new_default().expect("failed to register default shutdown signals");
//!
//! let shutdown_token = shutdown_manager.child_shutdown_token();
//! shutdown_manager.try_spawn_named(async move { my_managed_task(shutdown_token).await }, "important-managed-task");
//! shutdown_manager.try_spawn_named_with_shutdown(my_task(), "another-task");
//!
//! // wait for shutdown signal
//! shutdown_manager.run_until_shutdown().await;
//! }
//! ```
use std::time::Duration;
pub mod manager;
pub mod token;
pub mod tracker;
pub use manager::ShutdownManager;
pub use token::{ShutdownDropGuard, ShutdownToken};
pub use tracker::ShutdownTracker;
/// Default maximum shutdown duration: 5 seconds.
// NOTE(review): no use of this constant is visible in this module - confirm which
// callers rely on this default.
pub const DEFAULT_MAX_SHUTDOWN_DURATION: Duration = Duration::from_secs(5);
+150
View File
@@ -0,0 +1,150 @@
// Copyright 2025 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: Apache-2.0
use crate::event::SentStatus;
use std::future::Future;
use tokio_util::sync::{
CancellationToken, DropGuard, WaitForCancellationFuture, WaitForCancellationFutureOwned,
};
use tracing::warn;
/// A wrapped [CancellationToken](tokio_util::sync::CancellationToken) that is used for
/// signalling and listening for cancellation requests.
// We don't use CancellationToken directly in case we want to include additional
// fields/methods down the line.
#[derive(Debug, Clone, Default)]
pub struct ShutdownToken {
/// The wrapped tokio-util cancellation token.
inner: CancellationToken,
}
impl ShutdownToken {
/// A drop-in no-op replacement for `send_status_msg` for easier migration from [TaskClient](crate::TaskClient).
///
/// Nothing is sent: the call only logs a warning containing the caller's location
/// and the status that was meant to be sent.
#[deprecated]
#[track_caller]
pub fn send_status_msg(&self, status: SentStatus) {
// thanks to `#[track_caller]` this is the location of the code calling this method
let caller = std::panic::Location::caller();
warn!("{caller} attempted to send {status} - there are no more listeners of those");
}
/// Creates a new ShutdownToken in the non-cancelled state.
pub fn new() -> Self {
ShutdownToken {
inner: CancellationToken::new(),
}
}
/// Gets reference to the underlying [CancellationToken](tokio_util::sync::CancellationToken).
pub fn inner(&self) -> &CancellationToken {
&self.inner
}
/// Derive a child `ShutdownToken` that gets cancelled whenever the current token
/// gets cancelled. Unlike with a cloned `ShutdownToken`, cancelling the child
/// does not cancel its parent.
///
/// If the current token is already cancelled, the child token is returned
/// in the cancelled state.
pub fn child_token(&self) -> ShutdownToken {
    let inner = self.inner.child_token();
    ShutdownToken { inner }
}
/// Cancel the underlying [CancellationToken](tokio_util::sync::CancellationToken)
/// together with all child tokens derived from it.
///
/// All tasks waiting for the cancellation get woken up.
pub fn cancel(&self) {
    let token = &self.inner;
    token.cancel();
}
/// Check whether the underlying [CancellationToken](tokio_util::sync::CancellationToken)
/// has been cancelled.
pub fn is_cancelled(&self) -> bool {
    let token = &self.inner;
    token.is_cancelled()
}
/// Obtain a `Future` that resolves once cancellation has been requested.
///
/// If the token is already cancelled when this method is called,
/// the future completes immediately.
///
/// # Cancel safety
///
/// This method is cancel safe.
pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
    let token = &self.inner;
    token.cancelled()
}
/// Returns a `Future` that gets fulfilled when cancellation is requested.
///
/// The future will complete immediately if the token is already cancelled
/// when this method is called.
///
/// The function takes self by value and returns a future that owns the
/// token.
///
/// # Cancel safety
///
/// This method is cancel safe.
pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned {
self.inner.cancelled_owned()
}
/// Turn this token into a `ShutdownDropGuard`.
///
/// Unless disarmed, the returned guard cancels this token (and all its children)
/// when it gets dropped.
pub fn drop_guard(self) -> ShutdownDropGuard {
    let guard = self.inner.drop_guard();
    ShutdownDropGuard { inner: guard }
}
/// Drive `fut` to completion and return its output wrapped in `Some`,
/// unless the `ShutdownToken` gets cancelled first. In that case `None` is returned
/// and the future gets dropped.
///
/// # Cancel safety
///
/// This method is only cancel safe if `fut` is cancel safe.
pub async fn run_until_cancelled<F>(&self, fut: F) -> Option<F::Output>
where
    F: Future,
{
    let token = &self.inner;
    token.run_until_cancelled(fut).await
}
/// Runs a future to completion and returns its result wrapped inside an `Option`
/// unless the `ShutdownToken` is cancelled. In that case the function returns
/// `None` and the future gets dropped.
///
/// The function takes self by value and returns a future that owns the token.
///
/// # Cancel safety
///
/// This method is only cancel safe if `fut` is cancel safe.
pub async fn run_until_cancelled_owned<F>(self, fut: F) -> Option<F::Output>
where
F: Future,
{
self.inner.run_until_cancelled_owned(fut).await
}
}
/// A wrapper for [DropGuard](tokio_util::sync::DropGuard) that wraps around a cancellation token
/// which automatically cancels it on drop.
/// It is created using the `drop_guard` method on the `ShutdownToken`.
pub struct ShutdownDropGuard {
/// The wrapped tokio-util drop guard.
inner: DropGuard,
}
impl ShutdownDropGuard {
    /// Consume this guard and hand back the stored [ShutdownToken](ShutdownToken),
    /// so that the token is no longer cancelled by this guard.
    /// Other guards for this token are not affected.
    pub fn disarm(self) -> ShutdownToken {
        let Self { inner: guard } = self;
        let inner = guard.disarm();
        ShutdownToken { inner }
    }
}
+317
View File
@@ -0,0 +1,317 @@
// Copyright 2025 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: Apache-2.0
use crate::cancellation::token::ShutdownToken;
use crate::spawn::{spawn_named_future, JoinHandle};
use crate::spawn_future;
use std::future::Future;
use thiserror::Error;
use tokio_util::task::TaskTracker;
use tracing::{debug, trace};
/// Error returned by tasks spawned through the `*_with_shutdown` methods
/// when the shutdown token got cancelled before the task finished.
#[derive(Debug, Error)]
#[error("task got cancelled")]
pub struct Cancelled;
/// Extracted [TaskTracker](tokio_util::task::TaskTracker) and [ShutdownToken](ShutdownToken) to more easily allow tracking nested tasks
/// without having to pass the whole [ShutdownManager](ShutdownManager) around.
///
/// The type is `Clone`, so it can be handed over to nested tasks.
#[derive(Clone, Default, Debug)]
pub struct ShutdownTracker {
/// The root [ShutdownToken](ShutdownToken) that will trigger all derived tasks
/// to receive cancellation signal.
pub(crate) root_cancellation_token: ShutdownToken,
// Note: the reason we're not using a `JoinSet` is
// because it forces us to use futures with the same `::Output` type,
// which is not really a desirable property in this instance.
/// Tracker used for keeping track of all registered tasks
/// so that they could be stopped gracefully before ending the process.
pub(crate) tracker: TaskTracker,
}
#[cfg(not(target_arch = "wasm32"))]
impl ShutdownTracker {
/// Spawn `task` on the current Tokio runtime and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
///
/// The future is first wrapped by the tracker and then handed over to `spawn_future`.
#[track_caller]
pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    spawn_future(self.tracker.track_future(task))
}
/// Spawn `task` on the current Tokio runtime, register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker) and attach `name` to it, so that it is easier
/// to find within a [tokio console](https://github.com/tokio-rs/console).
///
/// Note that this is no different from [spawn](Self::spawn) if the underlying binary
/// has not been built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"`.
#[track_caller]
pub fn try_spawn_named<F>(&self, task: F, name: &str) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    trace!("attempting to spawn task {name}");
    spawn_named_future(self.tracker.track_future(task), name)
}
/// Spawn `task` on the Tokio runtime identified by `handle` and register it with the
/// underlying [TaskTracker](tokio_util::task::TaskTracker).
#[track_caller]
pub fn spawn_on<F>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let task_tracker = &self.tracker;
    task_tracker.spawn_on(task, handle)
}
/// Spawn `task` on the current [LocalSet](tokio::task::LocalSet) and register it with the
/// underlying [TaskTracker](tokio_util::task::TaskTracker).
///
/// The future is not required to be `Send`.
#[track_caller]
pub fn spawn_local<F>(&self, task: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
    F::Output: 'static,
{
    let task_tracker = &self.tracker;
    task_tracker.spawn_local(task)
}
/// Run the blocking closure `task` on the current Tokio runtime and register it with the
/// underlying [TaskTracker](tokio_util::task::TaskTracker).
#[track_caller]
pub fn spawn_blocking<F, T>(&self, task: F) -> JoinHandle<T>
where
    F: FnOnce() -> T,
    F: Send + 'static,
    T: Send + 'static,
{
    let task_tracker = &self.tracker;
    task_tracker.spawn_blocking(task)
}
/// Run the blocking closure `task` on the Tokio runtime identified by `handle` and register
/// it with the underlying [TaskTracker](tokio_util::task::TaskTracker).
#[track_caller]
pub fn spawn_blocking_on<F, T>(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle<T>
where
    F: FnOnce() -> T,
    F: Send + 'static,
    T: Send + 'static,
{
    let task_tracker = &self.tracker;
    task_tracker.spawn_blocking_on(task, handle)
}
/// Spawn a named `task` on the current Tokio runtime that gets cancelled once a global
/// shutdown signal is detected, and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
///
/// The spawned task resolves to `Ok(output)` if `task` ran to completion, or to
/// `Err(Cancelled)` if the shutdown token got cancelled first.
///
/// Note that to fully use the naming feature, such as tracking within a [tokio console](https://github.com/tokio-rs/console),
/// the underlying binary has to be built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"`.
#[track_caller]
pub fn try_spawn_named_with_shutdown<F>(
    &self,
    task: F,
    name: &str,
) -> JoinHandle<Result<F::Output, Cancelled>>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    trace!("attempting to spawn task {name} (with top-level cancellation)");

    // captured up front: the spawned future has to be 'static
    let caller = std::panic::Location::caller();
    let shutdown_token = self.clone_shutdown_token();
    let name_owned = name.to_string();

    let cancellable = async move {
        let Some(result) = shutdown_token.run_until_cancelled_owned(task).await else {
            trace!("{name_owned} @ {caller}: shutdown signal received, shutting down");
            return Err(Cancelled);
        };
        debug!("{name_owned} @ {caller}: task has finished execution");
        Ok(result)
    };

    spawn_named_future(self.tracker.track_future(cancellable), name)
}
/// Spawn `task` on the current Tokio runtime so that it gets cancelled once a global
/// shutdown signal is detected, and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
///
/// The spawned task resolves to `Ok(output)` if `task` ran to completion, or to
/// `Err(Cancelled)` if the shutdown token got cancelled first.
#[track_caller]
pub fn spawn_with_shutdown<F>(&self, task: F) -> JoinHandle<Result<F::Output, Cancelled>>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let caller = std::panic::Location::caller();
    let shutdown_token = self.clone_shutdown_token();

    let cancellable = async move {
        let Some(result) = shutdown_token.run_until_cancelled_owned(task).await else {
            trace!("{caller}: shutdown signal received, shutting down");
            return Err(Cancelled);
        };
        debug!("{caller}: task has finished execution");
        Ok(result)
    };

    self.tracker.spawn(cancellable)
}
}
#[cfg(target_arch = "wasm32")]
impl ShutdownTracker {
/// Run `task` on the current thread and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
#[track_caller]
pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
{
    spawn_future(self.tracker.track_future(task))
}
/// Run `task` on the current thread and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
///
/// Behaves exactly like [spawn](Self::spawn); it is only provided so that wasm32
/// exposes the same interface as the other targets.
#[track_caller]
pub fn try_spawn_named<F>(&self, task: F, name: &str) -> JoinHandle<F::Output>
where
    F: Future + 'static,
{
    spawn_named_future(self.tracker.track_future(task), name)
}
/// Run the provided future on the current thread
/// that will get cancelled once a global shutdown signal is detected,
/// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker).
/// It behaves like [spawn_with_shutdown](Self::spawn_with_shutdown), except that the
/// provided `name` is included in the emitted log messages, the same way the
/// non-wasm32 implementation does it.
///
/// The spawned task resolves to `Ok(())` if `task` ran to completion, or to
/// `Err(Cancelled)` if the shutdown token got cancelled first.
#[track_caller]
pub fn try_spawn_named_with_shutdown<F>(
    &self,
    task: F,
    name: &str,
) -> JoinHandle<Result<F::Output, Cancelled>>
where
    F: Future<Output = ()> + 'static,
{
    trace!("attempting to spawn task {name} (with top-level cancellation)");
    let caller = std::panic::Location::caller();
    let shutdown_token = self.clone_shutdown_token();
    // the spawned future has to be 'static, so it needs its own copy of the name
    let name_owned = name.to_string();

    let tracked = self.tracker.track_future(async move {
        match shutdown_token.run_until_cancelled_owned(task).await {
            Some(result) => {
                debug!("{name_owned} @ {caller}: task has finished execution");
                Ok(result)
            }
            None => {
                trace!("{name_owned} @ {caller}: shutdown signal received, shutting down");
                Err(Cancelled)
            }
        }
    });
    spawn_named_future(tracked, name)
}
/// Run `task` on the current thread so that it gets cancelled once a global
/// shutdown signal is detected, and register it with the underlying
/// [TaskTracker](tokio_util::task::TaskTracker).
///
/// The spawned task resolves to `Ok(())` if `task` ran to completion, or to
/// `Err(Cancelled)` if the shutdown token got cancelled first.
#[track_caller]
pub fn spawn_with_shutdown<F>(&self, task: F) -> JoinHandle<Result<F::Output, Cancelled>>
where
    F: Future<Output = ()> + 'static,
{
    let caller = std::panic::Location::caller();
    let shutdown_token = self.clone_shutdown_token();

    let cancellable = async move {
        let Some(result) = shutdown_token.run_until_cancelled_owned(task).await else {
            trace!("{caller}: shutdown signal received, shutting down");
            return Err(Cancelled);
        };
        debug!("{caller}: task has finished execution");
        Ok(result)
    };

    spawn_future(self.tracker.track_future(cancellable))
}
}
impl ShutdownTracker {
/// Create new instance of the ShutdownTracker using an external shutdown token.
/// This could be useful in situations where shutdown is being managed by an external entity
/// that is not [ShutdownManager](ShutdownManager), but interface requires providing a ShutdownTracker,
/// such as client-core tasks
pub fn new_from_external_shutdown_token(shutdown_token: ShutdownToken) -> Self {
ShutdownTracker {
root_cancellation_token: shutdown_token,
tracker: Default::default(),
}
}
/// Waits until the underlying [TaskTracker](tokio_util::task::TaskTracker) is both closed and empty.
///
/// If the underlying [TaskTracker](tokio_util::task::TaskTracker) is already closed and empty when this method is called, then it
/// returns immediately.
pub async fn wait_for_tracker(&self) {
self.tracker.wait().await;
}
/// Close the underlying [TaskTracker](tokio_util::task::TaskTracker).
///
/// This allows [`wait_for_tracker`] futures to complete. It does not prevent you from spawning new tasks.
///
/// Returns `true` if this closed the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already closed.
///
/// [`wait_for_tracker`]: Self::wait_for_tracker
pub fn close_tracker(&self) -> bool {
self.tracker.close()
}
/// Reopen the underlying [TaskTracker](tokio_util::task::TaskTracker).
///
/// This prevents [`wait_for_tracker`] futures from completing even if the underlying [TaskTracker](tokio_util::task::TaskTracker) is empty.
///
/// Returns `true` if this reopened the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already open.
///
/// [`wait_for_tracker`]: Self::wait_for_tracker
pub fn reopen_tracker(&self) -> bool {
self.tracker.reopen()
}
/// Returns `true` if the underlying [TaskTracker](tokio_util::task::TaskTracker) is [closed](Self::close_tracker).
pub fn is_tracker_closed(&self) -> bool {
self.tracker.is_closed()
}
/// Returns the number of tasks tracked by the underlying [TaskTracker](tokio_util::task::TaskTracker).
pub fn tracked_tasks(&self) -> usize {
self.tracker.len()
}
/// Returns `true` if there are no tasks in the underlying [TaskTracker](tokio_util::task::TaskTracker).
pub fn is_tracker_empty(&self) -> bool {
self.tracker.is_empty()
}
/// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) that is a child of the root token
pub fn child_shutdown_token(&self) -> ShutdownToken {
self.root_cancellation_token.child_token()
}
/// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) on the same hierarchical structure as the root token
pub fn clone_shutdown_token(&self) -> ShutdownToken {
self.root_cancellation_token.clone()
}
}
+17 -20
View File
@@ -2,12 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
use futures::channel::mpsc;
use std::{
collections::HashMap,
time::{Duration, Instant},
};
use std::collections::HashMap;
const LANE_CONSIDERED_CLEAR: usize = 10;
// const LANE_CONSIDERED_CLEAR: usize = 10;
pub type ConnectionId = u64;
@@ -83,21 +80,21 @@ impl LaneQueueLengths {
}
}
pub async fn wait_until_clear(&self, lane: &TransmissionLane, timeout: Option<Duration>) {
let total_time_waited = Instant::now();
loop {
let lane_length = self.get(lane).unwrap_or_default();
if lane_length < LANE_CONSIDERED_CLEAR {
break;
}
if timeout.is_some_and(|timeout| total_time_waited.elapsed() > timeout) {
log::warn!("Timeout reached while waiting for queue to clear");
break;
}
log::trace!("Waiting for queue to clear ({lane_length} items left)");
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
// pub async fn wait_until_clear(&self, lane: &TransmissionLane, timeout: Option<Duration>) {
// let total_time_waited = Instant::now();
// loop {
// let lane_length = self.get(lane).unwrap_or_default();
// if lane_length < LANE_CONSIDERED_CLEAR {
// break;
// }
// if timeout.is_some_and(|timeout| total_time_waited.elapsed() > timeout) {
// log::warn!("Timeout reached while waiting for queue to clear");
// break;
// }
// log::trace!("Waiting for queue to clear ({lane_length} items left)");
// tokio::time::sleep(Duration::from_millis(100)).await;
// }
// }
}
impl Default for LaneQueueLengths {
+4 -3
View File
@@ -9,10 +9,11 @@ pub mod manager;
pub mod signal;
pub mod spawn;
pub use cancellation::{ShutdownDropGuard, ShutdownManager, ShutdownToken};
pub use cancellation::{ShutdownDropGuard, ShutdownManager, ShutdownToken, ShutdownTracker};
pub use event::{StatusReceiver, StatusSender, TaskStatus, TaskStatusEvent};
pub use manager::{TaskClient, TaskHandle, TaskManager};
pub use spawn::{spawn, spawn_with_report_error};
#[allow(deprecated)]
pub use manager::{TaskClient, TaskManager};
pub use spawn::spawn_future;
pub use tokio_util::task::TaskTracker;
#[cfg(not(target_arch = "wasm32"))]
+20 -1
View File
@@ -44,6 +44,7 @@ enum TaskError {
/// Listens to status and error messages from tasks, as well as notifying them to gracefully
/// shutdown. Keeps track of if task stop unexpectedly, such as in a panic.
#[deprecated(note = "use ShutdownManager instead")]
#[derive(Debug)]
pub struct TaskManager {
// optional name assigned to the task manager that all subscribed task clients will inherit
@@ -72,6 +73,7 @@ pub struct TaskManager {
task_status_rx: Option<StatusReceiver>,
}
#[allow(deprecated)]
impl Default for TaskManager {
fn default() -> Self {
let (notify_tx, notify_rx) = watch::channel(());
@@ -95,6 +97,8 @@ impl Default for TaskManager {
}
}
#[allow(deprecated)]
#[allow(clippy::expect_used)]
impl TaskManager {
pub fn new(shutdown_timer_secs: u64) -> Self {
Self {
@@ -168,7 +172,7 @@ impl TaskManager {
if let Some(mut task_status_rx) = self.task_status_rx.take() {
log::info!("Starting status message listener");
crate::spawn::spawn(async move {
crate::spawn::spawn_future(async move {
loop {
if let Some(msg) = task_status_rx.next().await {
log::trace!("Got msg: {msg}");
@@ -186,12 +190,14 @@ impl TaskManager {
}
// used for compatibility with the ShutdownManager
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn task_return_error_rx(&mut self) -> ErrorReceiver {
self.task_return_error_rx
.take()
.expect("unable to get error channel: attempt to wait twice?")
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn task_drop_rx(&mut self) -> ErrorReceiver {
self.task_drop_rx
.take()
@@ -259,6 +265,7 @@ impl TaskManager {
/// Listen for shutdown notifications, and can send error and status messages back to the
/// `TaskManager`
#[derive(Debug)]
#[deprecated(note = "use ShutdownToken instead")]
pub struct TaskClient {
// optional name assigned to the shutdown handle
name: Option<String>,
@@ -286,6 +293,7 @@ pub struct TaskClient {
mode: ClientOperatingMode,
}
#[allow(deprecated)]
impl Clone for TaskClient {
fn clone(&self) -> Self {
// make sure to not accidentally overflow the stack if we keep cloning the handle
@@ -313,6 +321,7 @@ impl Clone for TaskClient {
}
}
#[allow(deprecated)]
impl TaskClient {
const MAX_NAME_LENGTH: usize = 128;
const OVERFLOW_NAME: &'static str = "reached maximum TaskClient children name depth";
@@ -433,6 +442,8 @@ impl TaskClient {
.await
}
// legacy code
#[allow(clippy::panic)]
pub async fn recv_timeout(&mut self) {
if self.mode.is_dummy() {
return pending().await;
@@ -505,6 +516,7 @@ impl TaskClient {
}
}
#[allow(deprecated)]
impl Drop for TaskClient {
fn drop(&mut self) {
if !self.mode.should_signal_on_drop() {
@@ -572,6 +584,8 @@ impl ClientOperatingMode {
}
}
#[deprecated]
#[allow(deprecated)]
#[derive(Debug)]
pub enum TaskHandle {
/// Full [`TaskManager`] that was created by the underlying task.
@@ -581,24 +595,28 @@ pub enum TaskHandle {
External(TaskClient),
}
#[allow(deprecated)]
impl From<TaskManager> for TaskHandle {
fn from(value: TaskManager) -> Self {
TaskHandle::Internal(value)
}
}
#[allow(deprecated)]
impl From<TaskClient> for TaskHandle {
fn from(value: TaskClient) -> Self {
TaskHandle::External(value)
}
}
#[allow(deprecated)]
impl Default for TaskHandle {
fn default() -> Self {
TaskHandle::Internal(TaskManager::default())
}
}
#[allow(deprecated)]
impl TaskHandle {
#[must_use]
pub fn name_if_unnamed<S: Into<String>>(self, name: S) -> Self {
@@ -666,6 +684,7 @@ mod tests {
use super::*;
#[tokio::test]
#[allow(deprecated)]
async fn signal_shutdown() {
let shutdown = TaskManager::default();
let mut listener = shutdown.subscribe();
+7 -3
View File
@@ -1,6 +1,7 @@
use crate::{manager::SentError, TaskManager};
use crate::manager::SentError;
#[cfg(unix)]
#[allow(clippy::expect_used)]
pub async fn wait_for_signal() {
use tokio::signal::unix::{signal, SignalKind};
let mut sigterm = signal(SignalKind::terminate()).expect("Failed to setup SIGTERM channel");
@@ -28,8 +29,10 @@ pub async fn wait_for_signal() {
}
}
#[allow(deprecated)]
#[cfg(unix)]
pub async fn wait_for_signal_and_error(shutdown: &mut TaskManager) -> Result<(), SentError> {
#[allow(clippy::expect_used)]
pub async fn wait_for_signal_and_error(shutdown: &mut crate::TaskManager) -> Result<(), SentError> {
use tokio::signal::unix::{signal, SignalKind};
let mut sigterm = signal(SignalKind::terminate()).expect("Failed to setup SIGTERM channel");
@@ -55,8 +58,9 @@ pub async fn wait_for_signal_and_error(shutdown: &mut TaskManager) -> Result<(),
}
}
#[allow(deprecated)]
#[cfg(not(unix))]
pub async fn wait_for_signal_and_error(shutdown: &mut TaskManager) -> Result<(), SentError> {
pub async fn wait_for_signal_and_error(shutdown: &mut crate::TaskManager) -> Result<(), SentError> {
tokio::select! {
_ = tokio::signal::ctrl_c() => {
log::info!("Received SIGINT");
+60 -16
View File
@@ -1,35 +1,79 @@
use crate::TaskClient;
use std::future::Future;
#[cfg(not(target_arch = "wasm32"))]
pub type JoinHandle<F> = tokio::task::JoinHandle<F>;
// no JoinHandle equivalent in wasm
#[cfg(target_arch = "wasm32")]
pub fn spawn<F>(future: F)
#[derive(Clone, Copy)]
pub struct FakeJoinHandle<F> {
_p: std::marker::PhantomData<F>,
}
#[cfg(target_arch = "wasm32")]
pub type JoinHandle<F> = FakeJoinHandle<F>;
#[cfg(target_arch = "wasm32")]
#[track_caller]
pub fn spawn_future<F>(future: F) -> JoinHandle<F::Output>
where
F: Future<Output = ()> + 'static,
F: Future + 'static,
{
wasm_bindgen_futures::spawn_local(future);
wasm_bindgen_futures::spawn_local(async move {
// make sure the future outputs `()`
future.await;
});
FakeJoinHandle {
_p: std::marker::PhantomData,
}
}
// Note: prefer spawning tasks directly on the ShutdownManager
#[cfg(not(target_arch = "wasm32"))]
#[track_caller]
pub fn spawn<F>(future: F)
pub fn spawn_future<F>(future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
tokio::spawn(future);
tokio::spawn(future)
}
// Note: prefer spawning tasks directly on the ShutdownManager
#[cfg(not(target_arch = "wasm32"))]
#[track_caller]
pub fn spawn_with_report_error<F, T, E>(future: F, mut shutdown: TaskClient)
pub fn spawn_named_future<F>(future: F, name: &str) -> JoinHandle<F::Output>
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: 'static,
E: std::error::Error + Send + Sync + 'static,
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let future_that_sends = async move {
if let Err(err) = future.await {
shutdown.send_we_stopped(Box::new(err));
}
};
spawn(future_that_sends);
cfg_if::cfg_if! {if #[cfg(all(tokio_unstable, feature="tokio-tracing"))] {
#[allow(clippy::expect_used)]
tokio::task::Builder::new().name(name).spawn(future).expect("failed to spawn future")
} else {
let _ = name;
tracing::debug!(r#"the underlying binary hasn't been built with `RUSTFLAGS="--cfg tokio_unstable"` - the future naming won't do anything"#);
spawn_future(future)
}}
}
/// wasm32 counterpart of the named spawner.
///
/// Task naming is not available on this target, so `name` is discarded
/// and the future is handed straight to [`spawn_future`].
#[cfg(target_arch = "wasm32")]
#[track_caller]
pub fn spawn_named_future<F>(future: F, name: &str) -> JoinHandle<F::Output>
where
    F: Future + 'static,
{
    // naming is unsupported in wasm - the argument only exists for interface parity
    let _unused_name = name;
    spawn_future(future)
}
/// Spawn a future using this crate's spawning helpers.
///
/// - `spawn_future!(future)` expands to a call to `spawn::spawn_future(future)`
/// - `spawn_future!(future, name)` expands to a call to `spawn::spawn_named_future(future, name)`
///
/// Both helpers are addressed through the public `spawn` module they are defined in,
/// so the expansion does not depend on either of them being re-exported at the crate root.
#[macro_export]
macro_rules! spawn_future {
    ($future:expr) => {{
        $crate::spawn::spawn_future($future)
    }};
    ($future:expr, $name:expr) => {{
        $crate::spawn::spawn_named_future($future, $name)
    }};
}
+10
View File
@@ -32,6 +32,16 @@ pub trait Timeboxed: IntoFuture + Sized {
impl<T> Timeboxed for T where T: IntoFuture + Sized {}
/// Convenience extension for inspecting results whose error type is `Elapsed`.
pub trait ElapsedExt {
/// Returns `true` if the value represents the `Elapsed` (error) outcome.
fn has_elapsed(&self) -> bool;
}
impl<T> ElapsedExt for Result<T, Elapsed> {
// the only possible error is `Elapsed`, so being an `Err` is equivalent to having elapsed
fn has_elapsed(&self) -> bool {
self.is_err()
}
}
// those are internal testing traits so we're not concerned about auto traits
#[allow(async_fn_in_trait)]
pub trait Spawnable: Future + Sized + Send + 'static {
+8 -6
View File
@@ -51,24 +51,26 @@ impl PacketListener {
info!("Started listening for echo packets on {}", self.address);
while !self.shutdown_token.is_cancelled() {
loop {
// cloning the arc as each accepted socket is handled in separate task
let connection_handler = Arc::clone(&self.connection_handler);
tokio::select! {
biased;
_ = self.shutdown_token.cancelled() => {
trace!("PacketListener: Received shutdown");
break;
}
socket = listener.accept() => {
match socket {
Ok((socket, remote_addr)) => {
debug!("New verloc connection from {remote_addr}");
let cancel = self.shutdown_token.child_token(format!("handler_{remote_addr}"));
tokio::spawn(async move { cancel.run_until_cancelled(connection_handler.handle_connection(socket, remote_addr)).await });
let cancel = self.shutdown_token.child_token();
tokio::spawn(cancel.run_until_cancelled_owned(connection_handler.handle_connection(socket, remote_addr)));
}
Err(err) => warn!("Failed to accept incoming connection - {err}"),
}
},
_ = self.shutdown_token.cancelled() => {
trace!("PacketListener: Received shutdown");
}
}
}
}
+9 -7
View File
@@ -40,12 +40,12 @@ impl VerlocMeasurer {
config.packet_timeout,
config.connection_timeout,
config.delay_between_packets,
shutdown_token.clone_with_suffix("packet_sender"),
shutdown_token.clone(),
)),
packet_listener: Arc::new(PacketListener::new(
config.listening_address,
Arc::clone(&identity),
shutdown_token.clone_with_suffix("packet_listener"),
shutdown_token.clone(),
)),
shutdown_token,
config,
@@ -92,8 +92,13 @@ impl VerlocMeasurer {
.collect::<FuturesUnordered<_>>();
// exhaust the results
while !self.shutdown_token.is_cancelled() {
loop {
tokio::select! {
biased;
_ = self.shutdown_token.cancelled() => {
trace!("Shutdown received while measuring");
return MeasurementOutcome::Shutdown;
}
measurement_result = measurement_chunk.next() => {
let Some(result) = measurement_result else {
// if the stream has finished, it means we got everything we could have gotten
@@ -117,10 +122,6 @@ impl VerlocMeasurer {
};
chunk_results.push(VerlocNodeResult::new(identity, measurement_result));
},
_ = self.shutdown_token.cancelled() => {
trace!("Shutdown received while measuring");
return MeasurementOutcome::Shutdown;
}
}
}
@@ -208,6 +209,7 @@ impl VerlocMeasurer {
_ = sleep(self.config.testing_interval) => {},
_ = self.shutdown_token.cancelled() => {
trace!("Shutdown received while sleeping");
break;
}
}
}
@@ -1,12 +1,9 @@
// Copyright 2025 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use nym_wireguard::WgApiWrapper;
use std::sync::Arc;
use tokio::task::JoinHandle;
pub(crate) mod openapi;
pub(crate) mod router;
@@ -20,7 +17,6 @@ pub(crate) mod state;
/// AFTER you have shut down BG tasks (or past their grace period).
#[allow(unused)]
pub struct ShutdownHandles {
axum_shutdown_button: CancellationToken,
/// Tokio JoinHandle for axum server's task
axum_join_handle: AxumJoinHandle,
/// Wireguard API for kernel interactions
@@ -30,13 +26,8 @@ pub struct ShutdownHandles {
impl ShutdownHandles {
/// Cancellation token is given to Axum server constructor. When the token
/// receives a shutdown signal, Axum server will shut down gracefully.
pub fn new(
axum_join_handle: AxumJoinHandle,
wg_api: Arc<WgApiWrapper>,
axum_shutdown_button: CancellationToken,
) -> Self {
pub fn new(axum_join_handle: AxumJoinHandle, wg_api: Arc<WgApiWrapper>) -> Self {
Self {
axum_shutdown_button,
axum_join_handle,
wg_api,
}
@@ -7,8 +7,8 @@ use axum::routing::get;
use axum::Router;
use core::net::SocketAddr;
use nym_http_api_common::middleware::logging::log_request_info;
use std::future::Future;
use tokio::net::TcpListener;
use tokio_util::sync::WaitForCancellationFutureOwned;
use tower_http::cors::CorsLayer;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
@@ -88,14 +88,17 @@ pub struct ApiHttpServer {
}
impl ApiHttpServer {
pub async fn run(self, receiver: WaitForCancellationFutureOwned) -> Result<(), std::io::Error> {
pub async fn run<F>(self, signal: F) -> Result<(), std::io::Error>
where
F: Future<Output = ()> + Send + 'static,
{
// into_make_service_with_connect_info allows us to see client ip address
axum::serve(
self.listener,
self.router
.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(receiver)
.with_graceful_shutdown(signal)
.await
}
}
+2 -2
View File
@@ -163,7 +163,7 @@ pub async fn start_wireguard(
ecash_manager: Arc<EcashManager>,
metrics: nym_node_metrics::NymNodeMetrics,
peers: Vec<Peer>,
task_client: nym_task::TaskClient,
shutdown_token: nym_task::ShutdownToken,
wireguard_data: WireguardData,
) -> Result<std::sync::Arc<WgApiWrapper>, Box<dyn std::error::Error + Send + Sync + 'static>> {
use base64::{prelude::BASE64_STANDARD, Engine};
@@ -250,7 +250,7 @@ pub async fn start_wireguard(
peer_bandwidth_managers,
wireguard_data.inner.peer_tx.clone(),
wireguard_data.peer_rx,
task_client,
shutdown_token,
);
tokio::spawn(async move { controller.run().await });
+17 -18
View File
@@ -84,7 +84,7 @@ pub struct PeerController {
host_information: Arc<RwLock<Host>>,
bw_storage_managers: HashMap<Key, SharedBandwidthStorageManager>,
timeout_check_interval: IntervalStream,
task_client: nym_task::TaskClient,
shutdown_token: nym_task::ShutdownToken,
}
impl PeerController {
@@ -97,11 +97,10 @@ impl PeerController {
bw_storage_managers: HashMap<Key, (SharedBandwidthStorageManager, Peer)>,
request_tx: mpsc::Sender<PeerControlRequest>,
request_rx: mpsc::Receiver<PeerControlRequest>,
task_client: nym_task::TaskClient,
shutdown_token: nym_task::ShutdownToken,
) -> Self {
let timeout_check_interval = tokio_stream::wrappers::IntervalStream::new(
tokio::time::interval(DEFAULT_PEER_TIMEOUT_CHECK),
);
let timeout_check_interval =
IntervalStream::new(tokio::time::interval(DEFAULT_PEER_TIMEOUT_CHECK));
let host_information = Arc::new(RwLock::new(initial_host_information));
for (public_key, (bandwidth_storage_manager, peer)) in bw_storage_managers.iter() {
let cached_peer_manager = CachedPeerManager::new(peer);
@@ -111,7 +110,7 @@ impl PeerController {
cached_peer_manager,
bandwidth_storage_manager.clone(),
request_tx.clone(),
&task_client,
&shutdown_token,
);
let public_key = public_key.clone();
tokio::spawn(async move {
@@ -132,7 +131,7 @@ impl PeerController {
request_tx,
request_rx,
timeout_check_interval,
task_client,
shutdown_token,
metrics,
}
}
@@ -191,7 +190,7 @@ impl PeerController {
cached_peer_manager,
bandwidth_storage_manager.clone(),
self.request_tx.clone(),
&self.task_client,
&self.shutdown_token,
);
self.bw_storage_managers
.insert(peer.public_key.clone(), bandwidth_storage_manager);
@@ -383,7 +382,7 @@ impl PeerController {
*self.host_information.write().await = host;
}
_ = self.task_client.recv() => {
_ = self.shutdown_token.cancelled() => {
log::trace!("PeerController handler: Received shutdown");
break;
}
@@ -513,7 +512,7 @@ pub fn start_controller(
request_rx: mpsc::Receiver<PeerControlRequest>,
) -> (
Arc<RwLock<nym_gateway_storage::traits::mock::MockGatewayStorage>>,
nym_task::TaskManager,
nym_task::ShutdownManager,
) {
use std::sync::Arc;
@@ -524,7 +523,7 @@ pub fn start_controller(
Box::new(storage.clone()),
));
let wg_api = Arc::new(MockWgApi::default());
let task_manager = nym_task::TaskManager::default();
let shutdown_manager = nym_task::ShutdownManager::empty_mock();
let mut peer_controller = PeerController::new(
ecash_manager,
Default::default(),
@@ -533,17 +532,17 @@ pub fn start_controller(
Default::default(),
request_tx,
request_rx,
task_manager.subscribe(),
shutdown_manager.child_shutdown_token(),
);
tokio::spawn(async move { peer_controller.run().await });
(storage, task_manager)
(storage, shutdown_manager)
}
#[cfg(feature = "mock")]
pub async fn stop_controller(mut task_manager: nym_task::TaskManager) {
task_manager.signal_shutdown().unwrap();
task_manager.wait_for_shutdown().await;
pub async fn stop_controller(mut shutdown_manager: nym_task::ShutdownManager) {
shutdown_manager.send_cancellation();
shutdown_manager.run_until_shutdown().await;
}
#[cfg(test)]
@@ -553,7 +552,7 @@ mod tests {
#[tokio::test]
async fn start_and_stop() {
let (request_tx, request_rx) = mpsc::channel(1);
let (_, task_manager) = start_controller(request_tx.clone(), request_rx);
stop_controller(task_manager).await;
let (_, shutdown_manager) = start_controller(request_tx.clone(), request_rx);
stop_controller(shutdown_manager).await;
}
}
+16 -16
View File
@@ -7,7 +7,7 @@ use crate::peer_storage_manager::{CachedPeerManager, PeerInformation};
use defguard_wireguard_rs::{host::Host, key::Key, net::IpAddrMask};
use futures::channel::oneshot;
use nym_credential_verification::bandwidth_storage_manager::BandwidthStorageManager;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use nym_wireguard_types::DEFAULT_PEER_TIMEOUT_CHECK;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
@@ -43,7 +43,7 @@ pub struct PeerHandle {
bandwidth_storage_manager: SharedBandwidthStorageManager,
request_tx: mpsc::Sender<PeerControlRequest>,
timeout_check_interval: IntervalStream,
task_client: TaskClient,
shutdown_token: ShutdownToken,
}
impl PeerHandle {
@@ -53,13 +53,12 @@ impl PeerHandle {
cached_peer: CachedPeerManager,
bandwidth_storage_manager: SharedBandwidthStorageManager,
request_tx: mpsc::Sender<PeerControlRequest>,
task_client: &TaskClient,
shutdown_token: &ShutdownToken,
) -> Self {
let timeout_check_interval = tokio_stream::wrappers::IntervalStream::new(
tokio::time::interval(DEFAULT_PEER_TIMEOUT_CHECK),
);
let mut task_client = task_client.fork(format!("peer_{public_key}"));
task_client.disarm();
let shutdown_token = shutdown_token.clone();
PeerHandle {
public_key,
host_information,
@@ -67,7 +66,7 @@ impl PeerHandle {
bandwidth_storage_manager,
request_tx,
timeout_check_interval,
task_client,
shutdown_token,
}
}
@@ -181,8 +180,18 @@ impl PeerHandle {
}
pub async fn run(&mut self) {
while !self.task_client.is_shutdown() {
loop {
tokio::select! {
biased;
_ = self.shutdown_token.cancelled() => {
log::trace!("PeerHandle: Received shutdown");
if let Err(e) = self.bandwidth_storage_manager.inner().write().await.sync_storage_bandwidth().await {
log::error!("Storage sync failed - {e}, unaccounted bandwidth might have been consumed");
}
log::trace!("PeerHandle: Finished shutdown");
break;
}
_ = self.timeout_check_interval.next() => {
match self.continue_checking().await {
Ok(true) => continue,
@@ -201,15 +210,6 @@ impl PeerHandle {
},
}
}
_ = self.task_client.recv() => {
log::trace!("PeerHandle: Received shutdown");
if let Err(e) = self.bandwidth_storage_manager.inner().write().await.sync_storage_bandwidth().await {
log::error!("Storage sync failed - {e}, unaccounted bandwidth might have been consumed");
}
log::trace!("PeerHandle: Finished shutdown");
}
}
}
}
+1
View File
@@ -1151,6 +1151,7 @@ dependencies = [
name = "nym-crypto"
version = "0.4.0"
dependencies = [
"base64 0.22.1",
"bs58",
"ed25519-dalek",
"nym-pemstore",
@@ -8,8 +8,7 @@ use futures::StreamExt;
use nym_network_requester::{GatewayPacketRouter, PacketRouter};
use nym_sphinx::addressing::clients::Recipient;
use nym_sphinx::DestinationAddressBytes;
use nym_task::TaskClient;
use tokio::task::JoinHandle;
use nym_task::ShutdownToken;
use tracing::{debug, error, trace};
#[derive(Debug)]
@@ -53,10 +52,6 @@ impl MessageRouter {
}
}
pub(crate) fn start_with_shutdown(self, shutdown: TaskClient) -> JoinHandle<()> {
tokio::spawn(self.run_with_shutdown(shutdown))
}
fn handle_received_messages(&self, messages: Vec<Vec<u8>>) {
if let Err(err) = self.packet_router.route_received(messages) {
// TODO: what should we do here? I don't think this could/should ever fail.
@@ -65,10 +60,15 @@ impl MessageRouter {
}
}
pub(crate) async fn run_with_shutdown(mut self, mut shutdown: TaskClient) {
pub(crate) async fn run_with_shutdown(mut self, shutdown: ShutdownToken) {
debug!("Started embedded client message router with graceful shutdown support");
while !shutdown.is_shutdown() {
loop {
tokio::select! {
biased;
_ = shutdown.cancelled() => {
trace!("embedded_clients::MessageRouter: Received shutdown");
break;
}
messages = self.mix_receiver.next() => match messages {
Some(messages) => self.handle_received_messages(messages),
None => {
@@ -76,11 +76,6 @@ impl MessageRouter {
break;
}
},
_ = shutdown.recv_with_delay() => {
trace!("embedded_clients::MessageRouter: Received shutdown");
debug_assert!(shutdown.is_shutdown());
break
}
}
}
@@ -29,7 +29,6 @@ use nym_gateway_storage::traits::SharedKeyGatewayStorage;
use nym_node_metrics::events::MetricsEvent;
use nym_sphinx::forwarding::packet::MixPacket;
use nym_statistics_common::{gateways::GatewaySessionEvent, types::SessionType};
use nym_task::TaskClient;
use nym_validator_client::coconut::EcashApiError;
use rand::{random, CryptoRng, Rng};
use std::{process, time::Duration};
@@ -583,7 +582,7 @@ impl<R, S> AuthenticatedHandler<R, S> {
/// Simultaneously listens for incoming client requests, which realistically should only be
/// binary requests to forward sphinx packets or increase bandwidth
/// and for sphinx packets received from the mix network that should be sent back to the client.
pub(crate) async fn listen_for_requests(mut self, mut shutdown: TaskClient)
pub(crate) async fn listen_for_requests(mut self)
where
R: Rng + CryptoRng,
S: AsyncRead + AsyncWrite + Unpin,
@@ -593,11 +592,8 @@ impl<R, S> AuthenticatedHandler<R, S> {
// Ping timeout future used to check if the client responded to our ping request
let mut ping_timeout: OptionFuture<_> = None.into();
while !shutdown.is_shutdown() {
loop {
tokio::select! {
_ = shutdown.recv() => {
trace!("client_handling::AuthenticatedHandler: received shutdown");
},
// Received a request to ping the client to check if it's still active
tx = self.is_active_request_receiver.next() => {
match tx {
@@ -32,7 +32,7 @@ use nym_gateway_storage::traits::InboxGatewayStorage;
use nym_gateway_storage::traits::SharedKeyGatewayStorage;
use nym_node_metrics::events::MetricsEvent;
use nym_sphinx::DestinationAddressBytes;
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use rand::CryptoRng;
use std::net::SocketAddr;
use std::time::Duration;
@@ -127,7 +127,7 @@ pub(crate) struct FreshHandler<R, S> {
pub(crate) shared_state: CommonHandlerState,
pub(crate) socket_connection: SocketStream<S>,
pub(crate) peer_address: SocketAddr,
pub(crate) shutdown: TaskClient,
pub(crate) shutdown: ShutdownToken,
// currently unused (but populated)
pub(crate) negotiated_protocol: Option<u8>,
@@ -145,7 +145,7 @@ impl<R, S> FreshHandler<R, S> {
conn: S,
shared_state: CommonHandlerState,
peer_address: SocketAddr,
shutdown: TaskClient,
shutdown: ShutdownToken,
) -> Self {
FreshHandler {
rng,
@@ -917,60 +917,49 @@ impl<R, S> FreshHandler<R, S> {
pub(crate) async fn handle_until_authenticated_or_failure(
mut self,
shutdown: &mut TaskClient,
) -> Option<AuthenticatedHandler<R, S>>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
R: CryptoRng + RngCore + Send,
{
while !shutdown.is_shutdown() {
let req = tokio::select! {
biased;
_ = shutdown.recv() => {
return None
},
req = self.wait_for_initial_message() => req,
};
let initial_request = match req {
Ok(req) => req,
Err(err) => {
self.send_and_forget_error_response(err).await;
return None;
}
};
// see if we managed to register the client through this request
let maybe_auth_res = match self.handle_initial_client_request(initial_request).await {
Ok(maybe_auth_res) => maybe_auth_res,
Err(err) => {
debug!("initial client request handling error: {err}");
self.send_and_forget_error_response(err).await;
return None;
}
};
if let Some(registration_details) = maybe_auth_res {
let (mix_sender, mix_receiver) = mpsc::unbounded();
// Channel for handlers to ask other handlers if they are still active.
let (is_active_request_sender, is_active_request_receiver) = mpsc::unbounded();
self.shared_state.active_clients_store.insert_remote(
registration_details.address,
mix_sender,
is_active_request_sender,
registration_details.session_request_timestamp,
);
return AuthenticatedHandler::upgrade(
self,
registration_details,
mix_receiver,
is_active_request_receiver,
)
.await
.inspect_err(|err| error!("failed to upgrade client handler: {err}"))
.ok();
let initial_request = match self.wait_for_initial_message().await {
Ok(req) => req,
Err(err) => {
self.send_and_forget_error_response(err).await;
return None;
}
};
// see if we managed to register the client through this request
let maybe_auth_res = match self.handle_initial_client_request(initial_request).await {
Ok(maybe_auth_res) => maybe_auth_res,
Err(err) => {
debug!("initial client request handling error: {err}");
self.send_and_forget_error_response(err).await;
return None;
}
};
if let Some(registration_details) = maybe_auth_res {
    let (mix_sender, mix_receiver) = mpsc::unbounded();
    // Channel for handlers to ask other handlers if they are still active.
    let (is_active_request_sender, is_active_request_receiver) = mpsc::unbounded();
    self.shared_state.active_clients_store.insert_remote(
        registration_details.address,
        mix_sender,
        is_active_request_sender,
        registration_details.session_request_timestamp,
    );
    // the upgraded handler has to be handed back to the caller;
    // evaluating it as a plain statement would drop it and fall through to the trailing `None`,
    // leaving the freshly registered client without anything serving its requests
    return AuthenticatedHandler::upgrade(
        self,
        registration_details,
        mix_receiver,
        is_active_request_receiver,
    )
    .await
    .inspect_err(|err| error!("failed to upgrade client handler: {err}"))
    .ok();
}
None
@@ -1031,6 +1020,15 @@ impl<R, S> FreshHandler<R, S> {
S: AsyncRead + AsyncWrite + Unpin + Send,
R: CryptoRng + RngCore + Send,
{
super::handle_connection(self).await
let remote = self.peer_address;
let shutdown = self.shutdown.clone();
tokio::select! {
_ = shutdown.cancelled() => {
trace!("received cancellation")
}
_ = super::handle_connection(self) => {
debug!("finished connection handler for {remote}")
}
}
}
}
@@ -98,15 +98,6 @@ where
R: Rng + CryptoRng + Send,
S: AsyncRead + AsyncWrite + Unpin + Send,
{
// don't accept any new requests if we have already received shutdown
if handle.shutdown.is_shutdown_poll() {
debug!("stopping the handle as we have received a shutdown");
return;
}
// If the connection handler abruptly stops, we shouldn't signal global shutdown
handle.shutdown.disarm();
match tokio::time::timeout(
WEBSOCKET_HANDSHAKE_TIMEOUT,
handle.perform_websocket_handshake(),
@@ -126,13 +117,8 @@ where
trace!("managed to perform websocket handshake!");
let mut shutdown = handle.shutdown.clone();
if let Some(auth_handle) = handle
.handle_until_authenticated_or_failure(&mut shutdown)
.await
{
auth_handle.listen_for_requests(shutdown).await
if let Some(auth_handle) = handle.handle_until_authenticated_or_failure().await {
auth_handle.listen_for_requests().await
}
trace!("the handler is done!");
@@ -3,19 +3,18 @@
use crate::node::client_handling::websocket::common_state::CommonHandlerState;
use crate::node::client_handling::websocket::connection_handler::FreshHandler;
use nym_task::TaskClient;
use nym_task::ShutdownTracker;
use rand::rngs::OsRng;
use std::net::SocketAddr;
use std::{io, process};
use tokio::net::TcpStream;
use tokio::task::JoinHandle;
use tracing::*;
pub struct Listener {
address: SocketAddr,
maximum_open_connections: usize,
shared_state: CommonHandlerState,
shutdown: TaskClient,
shutdown: ShutdownTracker,
}
impl Listener {
@@ -23,7 +22,7 @@ impl Listener {
address: SocketAddr,
maximum_open_connections: usize,
shared_state: CommonHandlerState,
shutdown: TaskClient,
shutdown: ShutdownTracker,
) -> Self {
Listener {
address,
@@ -45,15 +44,12 @@ impl Listener {
socket: TcpStream,
remote_address: SocketAddr,
) -> FreshHandler<OsRng, TcpStream> {
let shutdown = self
.shutdown
.fork(format!("websocket_handler_{remote_address}"));
FreshHandler::new(
OsRng,
socket,
self.shared_state.clone(),
remote_address,
shutdown,
self.shutdown.clone_shutdown_token(),
)
}
@@ -88,16 +84,19 @@ impl Listener {
.new_ingress_websocket_client();
// 4. spawn the task handling the client connection
tokio::spawn(async move {
// TODO: refactor it similarly to the mixnet listener on the nym-node
let metrics_ref = handle.shared_state.metrics.clone();
self.shutdown.try_spawn_named(
async move {
// TODO: refactor it similarly to the mixnet listener on the nym-node
let metrics_ref = handle.shared_state.metrics.clone();
// 4.1. handle all client requests until connection gets terminated
handle.start_handling().await;
// 4.1. handle all client requests until connection gets terminated
handle.start_handling().await;
// 4.2. decrement the connection counter
metrics_ref.network.disconnected_ingress_websocket_client();
});
// 4.2. decrement the connection counter
metrics_ref.network.disconnected_ingress_websocket_client();
},
&format!("Websocket::{remote_address}"),
);
}
Err(err) => warn!("failed to accept client connection: {err}"),
}
@@ -105,7 +104,7 @@ impl Listener {
// TODO: change the signature to pub(crate) async fn run(&self, handler: Handler)
pub(crate) async fn run(&mut self) {
pub async fn run(&mut self) {
info!("Starting websocket listener at {}", self.address);
let tcp_listener = match tokio::net::TcpListener::bind(self.address).await {
Ok(listener) => listener,
@@ -115,21 +114,18 @@ impl Listener {
}
};
while !self.shutdown.is_shutdown() {
let shutdown_token = self.shutdown.clone_shutdown_token();
loop {
tokio::select! {
biased;
_ = self.shutdown.recv() => {
_ = shutdown_token.cancelled() => {
trace!("client_handling::Listener: received shutdown");
break
}
connection = tcp_listener.accept() => {
self.try_handle_accepted_connection(connection)
}
}
}
}
pub fn start(mut self) -> JoinHandle<()> {
tokio::spawn(async move { self.run().await })
}
}
@@ -3,7 +3,7 @@
use nym_client_core::{config::disk_persistence::CommonClientPaths, TopologyProvider};
use nym_sdk::{GatewayTransceiver, NymNetworkDetails};
use nym_task::TaskClient;
use nym_task::ShutdownTracker;
use crate::node::internal_service_providers::authenticator::{
config::BaseClientConfig, error::AuthenticatorError,
@@ -15,7 +15,7 @@ use crate::node::internal_service_providers::authenticator::{
// TODO: refactor this function and its arguments
pub async fn create_mixnet_client(
config: &BaseClientConfig,
shutdown: TaskClient,
shutdown: ShutdownTracker,
custom_transceiver: Option<Box<dyn GatewayTransceiver + Send + Sync>>,
custom_topology_provider: Option<Box<dyn TopologyProvider + Send + Sync>>,
wait_for_gateway: bool,
@@ -39,7 +39,7 @@ use nym_sdk::mixnet::{
};
use nym_service_provider_requests_common::{Protocol, ServiceProviderType};
use nym_sphinx::receiver::ReconstructedMessage;
use nym_task::TaskHandle;
use nym_task::ShutdownToken;
use nym_wireguard::WireguardGatewayData;
use nym_wireguard_types::PeerPublicKey;
use rand::{prelude::IteratorRandom, thread_rng};
@@ -70,9 +70,6 @@ pub(crate) struct MixnetListener {
// The mixnet client that we use to send and receive packets from the mixnet
pub(crate) mixnet_client: nym_sdk::mixnet::MixnetClient,
// The task handle for the main loop
pub(crate) task_handle: TaskHandle,
// Registrations awaiting confirmation
pub(crate) registred_and_free: RwLock<RegistredAndFree>,
@@ -91,7 +88,6 @@ impl MixnetListener {
free_private_network_ips: PrivateIPs,
wireguard_gateway_data: WireguardGatewayData,
mixnet_client: nym_sdk::mixnet::MixnetClient,
task_handle: TaskHandle,
ecash_verifier: Arc<dyn EcashManager + Send + Sync>,
) -> Self {
let timeout_check_interval =
@@ -99,7 +95,6 @@ impl MixnetListener {
MixnetListener {
config,
mixnet_client,
task_handle,
registred_and_free: RwLock::new(RegistredAndFree::new(free_private_network_ips)),
peer_manager: PeerManager::new(wireguard_gateway_data),
ecash_verifier,
@@ -812,14 +807,18 @@ impl MixnetListener {
})
}
pub(crate) async fn run(mut self) -> Result<(), AuthenticatorError> {
pub(crate) async fn run(
mut self,
shutdown_token: ShutdownToken,
) -> Result<(), AuthenticatorError> {
tracing::info!("Using authenticator version {CURRENT_VERSION}");
let mut task_client = self.task_handle.fork("main_loop");
while !task_client.is_shutdown() {
loop {
tokio::select! {
_ = task_client.recv() => {
biased;
_ = shutdown_token.cancelled() => {
tracing::debug!("Authenticator [main loop]: received shutdown");
break;
},
_ = self.timeout_check_interval.next() => {
if let Err(e) = self.remove_stale_registrations().await {
@@ -7,7 +7,7 @@ use ipnetwork::IpNetwork;
use nym_client_core::{HardcodedTopologyProvider, TopologyProvider};
use nym_credential_verification::ecash::EcashManager;
use nym_sdk::{mixnet::Recipient, GatewayTransceiver};
use nym_task::{TaskClient, TaskHandle};
use nym_task::ShutdownTracker;
use nym_wireguard::WireguardGatewayData;
use std::{net::IpAddr, path::Path, sync::Arc, time::SystemTime};
@@ -40,7 +40,7 @@ pub struct Authenticator {
wireguard_gateway_data: WireguardGatewayData,
ecash_verifier: Arc<EcashManager>,
used_private_network_ips: Vec<IpAddr>,
shutdown: Option<TaskClient>,
shutdown: ShutdownTracker,
on_start: Option<oneshot::Sender<OnStartData>>,
}
@@ -50,6 +50,7 @@ impl Authenticator {
wireguard_gateway_data: WireguardGatewayData,
used_private_network_ips: Vec<IpAddr>,
ecash_verifier: Arc<EcashManager>,
shutdown: ShutdownTracker,
) -> Self {
Self {
config,
@@ -59,18 +60,11 @@ impl Authenticator {
ecash_verifier,
wireguard_gateway_data,
used_private_network_ips,
shutdown: None,
shutdown,
on_start: None,
}
}
#[must_use]
#[allow(unused)]
pub fn with_shutdown(mut self, shutdown: TaskClient) -> Self {
self.shutdown = Some(shutdown);
self
}
#[must_use]
#[allow(unused)]
pub fn with_wait_for_gateway(mut self, wait_for_gateway: bool) -> Self {
@@ -123,14 +117,10 @@ impl Authenticator {
pub async fn run_service_provider(self) -> Result<(), AuthenticatorError> {
// Used to notify tasks to shutdown. Not all tasks fully supports this (yet).
let task_handle: TaskHandle = self.shutdown.map(Into::into).unwrap_or_default();
// Connect to the mixnet
let mixnet_client = crate::node::internal_service_providers::authenticator::mixnet_client::create_mixnet_client(
&self.config.base,
task_handle
.get_handle()
.named("nym_sdk::MixnetClient[AUTH]"),
self.shutdown.clone(),
self.custom_gateway_transceiver,
self.custom_topology_provider,
self.wait_for_gateway,
@@ -162,7 +152,6 @@ impl Authenticator {
free_private_network_ips,
self.wireguard_gateway_data,
mixnet_client,
task_handle,
self.ecash_verifier,
);
@@ -176,6 +165,8 @@ impl Authenticator {
}
}
mixnet_listener.run().await
mixnet_listener
.run(self.shutdown.clone_shutdown_token())
.await
}
}
@@ -0,0 +1,52 @@
use nym_bin_common::logging::LoggingSettings;
use nym_network_defaults::mainnet;
use url::Url;
mod persistence;
pub use crate::service_providers::ip_packet_router::config::persistence::IpPacketRouterPaths;
pub use nym_client_core::config::Config as BaseClientConfig;
/// Top-level configuration of the ip packet router service provider.
///
/// Groups the base client configuration, the router-specific settings and the storage paths.
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
/// Configuration of the underlying client (`nym_client_core::config::Config`).
pub base: BaseClientConfig,
/// Settings specific to the ip packet router itself.
pub ip_packet_router: IpPacketRouter,
/// Storage paths used by the ip packet router; see [`IpPacketRouterPaths`].
pub storage_paths: IpPacketRouterPaths,
}
impl Config {
/// Returns whether this configuration is valid.
///
/// Only the base client configuration is checked; the result of its `validate` is returned as-is.
pub fn validate(&self) -> bool {
// no other sections have explicit requirements (yet)
self.base.validate()
}
// Hidden from the generated documentation; forwards directly to the base client
// configuration's `set_no_poisson_process`.
#[doc(hidden)]
pub fn set_no_poisson_process(&mut self) {
self.base.set_no_poisson_process()
}
}
/// Settings specific to the ip packet router.
///
/// The `Default` implementation disables the Poisson sending rate and points the exit policy
/// at the mainnet default URL.
#[derive(Debug, Clone, PartialEq)]
pub struct IpPacketRouter {
/// Disable Poisson sending rate.
pub disable_poisson_rate: bool,
/// Specifies the url for an upstream source of the exit policy used by this node.
// NOTE(review): the meaning of `None` is not visible here — presumably no upstream policy is
// fetched; confirm against the consumer of this field.
pub upstream_exit_policy_url: Option<Url>,
}
impl Default for IpPacketRouter {
fn default() -> Self {
IpPacketRouter {
disable_poisson_rate: true,
#[allow(clippy::expect_used)]
upstream_exit_policy_url: Some(
mainnet::EXIT_POLICY_URL
.parse()
.expect("invalid default exit policy URL"),
),
}
}
}
@@ -17,9 +17,9 @@ use nym_network_requester::error::NetworkRequesterError;
use nym_network_requester::NRServiceProviderBuilder;
use nym_sdk::mixnet::Recipient;
use nym_sdk::{GatewayTransceiver, LocalGateway, PacketRouter};
use nym_task::TaskClient;
use nym_task::ShutdownTracker;
use std::fmt::Display;
use tokio::task::JoinHandle;
use std::marker::PhantomData;
use tracing::error;
pub mod authenticator;
@@ -91,12 +91,11 @@ impl RunnableServiceProvider for Authenticator {
pub struct ServiceProviderBeingBuilt<T: RunnableServiceProvider> {
on_start_rx: oneshot::Receiver<T::OnStartData>,
sp_builder: T,
sp_message_router_builder: SpMessageRouterBuilder,
sp_message_router_builder: SpMessageRouterBuilder<T>,
shutdown_tracker: ShutdownTracker,
}
pub struct StartedServiceProvider<T: RunnableServiceProvider> {
pub sp_join_handle: JoinHandle<()>,
pub message_router_join_handle: JoinHandle<()>,
pub on_start_data: T::OnStartData,
pub handle: LocalEmbeddedClientHandle,
}
@@ -109,26 +108,31 @@ where
pub(crate) fn new(
on_start_rx: oneshot::Receiver<T::OnStartData>,
sp_builder: T,
sp_message_router_builder: SpMessageRouterBuilder,
sp_message_router_builder: SpMessageRouterBuilder<T>,
shutdown_tracker: ShutdownTracker,
) -> Self {
ServiceProviderBeingBuilt {
on_start_rx,
sp_builder,
sp_message_router_builder,
shutdown_tracker,
}
}
pub async fn start_service_provider(
mut self,
) -> Result<StartedServiceProvider<T>, GatewayError> {
let sp_join_handle = tokio::task::spawn(async move {
if let Err(err) = self.sp_builder.run_service_provider().await {
error!(
"the {} service provider encountered an error: {err}",
T::NAME
)
}
});
self.shutdown_tracker.try_spawn_named(
async move {
if let Err(err) = self.sp_builder.run_service_provider().await {
error!(
"the {} service provider encountered an error: {err}",
T::NAME
)
}
},
&format!("{}::Provider", T::NAME),
);
// TODO: if something is blocking during SP startup, the below will wait forever
// we need to introduce additional timeouts here.
@@ -145,13 +149,10 @@ where
};
let mix_sender = self.sp_message_router_builder.mix_sender();
let message_router_join_handle = self
.sp_message_router_builder
.start_message_router(packet_router);
self.sp_message_router_builder
.start_message_router(packet_router, &self.shutdown_tracker);
Ok(StartedServiceProvider {
sp_join_handle,
message_router_join_handle,
handle: LocalEmbeddedClientHandle::new(on_start_data.address(), mix_sender),
on_start_data,
})
@@ -180,19 +181,19 @@ impl ExitServiceProviders {
}
}
pub struct SpMessageRouterBuilder {
pub struct SpMessageRouterBuilder<T> {
mix_sender: Option<MixMessageSender>,
mix_receiver: MixMessageReceiver,
router_receiver: oneshot::Receiver<PacketRouter>,
gateway_transceiver: Option<LocalGateway>,
shutdown: TaskClient,
_typ: PhantomData<T>,
}
impl SpMessageRouterBuilder {
impl<T> SpMessageRouterBuilder<T> {
pub(crate) fn new(
node_identity: ed25519::PublicKey,
forwarding_channel: MixForwardingSender,
shutdown: TaskClient,
) -> Self {
let (mix_sender, mix_receiver) = mpsc::unbounded();
let (router_tx, router_rx) = oneshot::channel();
@@ -204,7 +205,7 @@ impl SpMessageRouterBuilder {
mix_receiver,
router_receiver: router_rx,
gateway_transceiver: Some(transceiver),
shutdown,
_typ: Default::default(),
}
}
@@ -224,7 +225,17 @@ impl SpMessageRouterBuilder {
.expect("attempting to use the same mix sender twice")
}
fn start_message_router(self, packet_router: PacketRouter) -> JoinHandle<()> {
MessageRouter::new(self.mix_receiver, packet_router).start_with_shutdown(self.shutdown)
fn start_message_router(self, packet_router: PacketRouter, shutdown_tracker: &ShutdownTracker)
where
T: RunnableServiceProvider,
{
let shutdown_token = shutdown_tracker.clone_shutdown_token();
let message_router = MessageRouter::new(self.mix_receiver, packet_router);
shutdown_tracker.try_spawn_named(
async move {
message_router.run_with_shutdown(shutdown_token).await;
},
&format!("{}::MessageRouter", T::NAME),
);
}
}
+45 -58
View File
@@ -1,7 +1,6 @@
// Copyright 2020-2024 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: GPL-3.0-only
use crate::config::Config;
use crate::error::GatewayError;
use crate::node::client_handling::websocket;
use crate::node::internal_service_providers::{
@@ -19,7 +18,7 @@ use nym_network_defaults::NymNetworkDetails;
use nym_network_requester::NRServiceProviderBuilder;
use nym_node_metrics::events::MetricEventsSender;
use nym_node_metrics::NymNodeMetrics;
use nym_task::{ShutdownToken, TaskClient};
use nym_task::ShutdownTracker;
use nym_topology::TopologyProvider;
use nym_validator_client::nyxd::{Coin, CosmWasmClient};
use nym_validator_client::{nyxd, DirectSigningHttpRpcNyxdClient};
@@ -35,6 +34,7 @@ pub(crate) mod client_handling;
pub(crate) mod internal_service_providers;
mod stale_data_cleaner;
use crate::config::Config;
use crate::node::internal_service_providers::authenticator::Authenticator;
pub use client_handling::active_clients::ActiveClientsStore;
pub use nym_gateway_stats_storage::PersistentStatsStorage;
@@ -91,9 +91,7 @@ pub struct GatewayTasksBuilder {
mnemonic: Arc<Zeroizing<bip39::Mnemonic>>,
legacy_task_client: TaskClient,
shutdown_token: ShutdownToken,
shutdown_tracker: ShutdownTracker,
// populated and cached as necessary
ecash_manager: Option<Arc<EcashManager>>,
@@ -103,14 +101,6 @@ pub struct GatewayTasksBuilder {
wireguard_networks: Option<Vec<IpAddr>>,
}
impl Drop for GatewayTasksBuilder {
fn drop(&mut self) {
// disarm the shutdown as it was already used to construct relevant tasks and we don't want the builder
// to cause shutdown
self.legacy_task_client.disarm();
}
}
impl GatewayTasksBuilder {
#[allow(clippy::too_many_arguments)]
pub fn new(
@@ -121,8 +111,7 @@ impl GatewayTasksBuilder {
metrics_sender: MetricEventsSender,
metrics: NymNodeMetrics,
mnemonic: Arc<Zeroizing<bip39::Mnemonic>>,
legacy_task_client: TaskClient,
shutdown_token: ShutdownToken,
shutdown_tracker: ShutdownTracker,
) -> GatewayTasksBuilder {
GatewayTasksBuilder {
config,
@@ -136,8 +125,7 @@ impl GatewayTasksBuilder {
metrics_sender,
metrics,
mnemonic,
legacy_task_client,
shutdown_token,
shutdown_tracker,
ecash_manager: None,
wireguard_peers: None,
wireguard_networks: None,
@@ -227,17 +215,22 @@ impl GatewayTasksBuilder {
};
let nyxd_client = self.build_nyxd_signing_client().await?;
let ecash_manager = Arc::new(
EcashManager::new(
handler_config,
nyxd_client,
self.identity_keypair.public_key().to_bytes(),
self.legacy_task_client.fork("ecash_manager"),
self.storage.clone(),
)
.await?,
let (ecash_manager, credential_handler) = EcashManager::new(
handler_config,
nyxd_client,
self.identity_keypair.public_key().to_bytes(),
self.storage.clone(),
)
.await?;
let shutdown_token = self.shutdown_tracker.clone_shutdown_token();
self.shutdown_tracker.try_spawn_named(
async move { credential_handler.run(shutdown_token).await },
"EcashCredentialHandler",
);
Ok(ecash_manager)
Ok(Arc::new(ecash_manager))
}
async fn ecash_manager(&mut self) -> Result<Arc<EcashManager>, GatewayError> {
@@ -274,7 +267,7 @@ impl GatewayTasksBuilder {
self.config.gateway.websocket_bind_address,
self.config.debug.maximum_open_connections,
shared_state,
self.legacy_task_client.fork("websocket"),
self.shutdown_tracker.clone(),
))
}
@@ -290,19 +283,17 @@ impl GatewayTasksBuilder {
let mut message_router_builder = SpMessageRouterBuilder::new(
*self.identity_keypair.public_key(),
self.mix_packet_sender.clone(),
self.legacy_task_client
.fork("network_requester_message_router"),
);
let transceiver = message_router_builder.gateway_transceiver();
let (on_start_tx, on_start_rx) = oneshot::channel();
let mut nr_builder = NRServiceProviderBuilder::new(nr_opts.config.clone())
.with_shutdown(self.legacy_task_client.fork("network_requester_sp"))
.with_custom_gateway_transceiver(transceiver)
.with_wait_for_gateway(true)
.with_minimum_gateway_performance(0)
.with_custom_topology_provider(topology_provider)
.with_on_start(on_start_tx);
let mut nr_builder =
NRServiceProviderBuilder::new(nr_opts.config.clone(), self.shutdown_tracker.clone())
.with_custom_gateway_transceiver(transceiver)
.with_wait_for_gateway(true)
.with_minimum_gateway_performance(0)
.with_custom_topology_provider(topology_provider)
.with_on_start(on_start_tx);
if let Some(custom_mixnet) = &nr_opts.custom_mixnet_path {
nr_builder = nr_builder.with_stored_topology(custom_mixnet)?
@@ -312,6 +303,7 @@ impl GatewayTasksBuilder {
on_start_rx,
nr_builder,
message_router_builder,
self.shutdown_tracker.clone(),
))
}
@@ -326,18 +318,17 @@ impl GatewayTasksBuilder {
let mut message_router_builder = SpMessageRouterBuilder::new(
*self.identity_keypair.public_key(),
self.mix_packet_sender.clone(),
self.legacy_task_client.fork("ipr_message_router"),
);
let transceiver = message_router_builder.gateway_transceiver();
let (on_start_tx, on_start_rx) = oneshot::channel();
let mut ip_packet_router = IpPacketRouter::new(ip_opts.config.clone())
.with_shutdown(self.legacy_task_client.fork("ipr_sp"))
.with_custom_gateway_transceiver(Box::new(transceiver))
.with_wait_for_gateway(true)
.with_minimum_gateway_performance(0)
.with_custom_topology_provider(topology_provider)
.with_on_start(on_start_tx);
let mut ip_packet_router =
IpPacketRouter::new(ip_opts.config.clone(), self.shutdown_tracker.clone())
.with_custom_gateway_transceiver(Box::new(transceiver))
.with_wait_for_gateway(true)
.with_minimum_gateway_performance(0)
.with_custom_topology_provider(topology_provider)
.with_on_start(on_start_tx);
if let Some(custom_mixnet) = &ip_opts.custom_mixnet_path {
ip_packet_router = ip_packet_router.with_stored_topology(custom_mixnet)?
@@ -347,6 +338,7 @@ impl GatewayTasksBuilder {
on_start_rx,
ip_packet_router,
message_router_builder,
self.shutdown_tracker.clone(),
))
}
@@ -432,7 +424,6 @@ impl GatewayTasksBuilder {
let mut message_router_builder = SpMessageRouterBuilder::new(
*self.identity_keypair.public_key(),
self.mix_packet_sender.clone(),
self.legacy_task_client.fork("authenticator_message_router"),
);
let transceiver = message_router_builder.gateway_transceiver();
@@ -443,9 +434,9 @@ impl GatewayTasksBuilder {
wireguard_data.inner.clone(),
used_private_network_ips,
ecash_manager,
self.shutdown_tracker.clone(),
)
.with_custom_gateway_transceiver(transceiver)
.with_shutdown(self.legacy_task_client.fork("authenticator_sp"))
.with_wait_for_gateway(true)
.with_minimum_gateway_performance(0)
.with_custom_topology_provider(topology_provider)
@@ -459,13 +450,13 @@ impl GatewayTasksBuilder {
on_start_rx,
authenticator_server,
message_router_builder,
self.shutdown_tracker.clone(),
))
}
pub fn build_stale_messages_cleaner(&self) -> StaleMessagesCleaner {
StaleMessagesCleaner::new(
&self.storage,
self.legacy_task_client.fork("stale_messages_cleaner"),
self.config.debug.stale_messages_max_age,
self.config.debug.stale_messages_cleaner_run_interval,
)
@@ -476,7 +467,7 @@ impl GatewayTasksBuilder {
&mut self,
) -> Result<Arc<nym_wireguard::WgApiWrapper>, Box<dyn std::error::Error + Send + Sync>> {
let _ = self.metrics.clone();
let _ = self.shutdown_token.clone();
let _ = self.shutdown_tracker.clone();
unimplemented!("wireguard is not supported on this platform")
}
@@ -517,26 +508,22 @@ impl GatewayTasksBuilder {
ecash_manager,
self.metrics.clone(),
all_peers,
self.legacy_task_client.fork("wireguard"),
self.shutdown_tracker.clone_shutdown_token(),
wireguard_data,
)
.await?;
let server = router.build_server(&bind_address).await?;
let cancel_token: tokio_util::sync::CancellationToken = (*self.shutdown_token).clone();
let axum_shutdown_receiver = cancel_token.clone().cancelled_owned();
let cancel_token = self.shutdown_tracker.clone_shutdown_token();
let server_handle = tokio::spawn(async move {
{
info!("Started Wireguard Axum HTTP V2 server on {bind_address}");
server.run(axum_shutdown_receiver).await
server.run(cancel_token.cancelled_owned()).await
}
});
let shutdown_handles = nym_wireguard_private_metadata_server::ShutdownHandles::new(
server_handle,
wg_handle,
cancel_token,
);
let shutdown_handles =
nym_wireguard_private_metadata_server::ShutdownHandles::new(server_handle, wg_handle);
Ok(shutdown_handles)
}
+5 -12
View File
@@ -2,16 +2,14 @@
// SPDX-License-Identifier: GPL-3.0-only
use nym_gateway_storage::{GatewayStorage, InboxManager};
use nym_task::TaskClient;
use nym_task::ShutdownToken;
use std::error::Error;
use std::time::Duration;
use time::OffsetDateTime;
use tokio::task::JoinHandle;
use tracing::{debug, trace, warn};
pub struct StaleMessagesCleaner {
inbox_manager: InboxManager,
task_client: TaskClient,
max_message_age: Duration,
run_interval: Duration,
}
@@ -19,13 +17,11 @@ pub struct StaleMessagesCleaner {
impl StaleMessagesCleaner {
pub(crate) fn new(
storage: &GatewayStorage,
task_client: TaskClient,
max_message_age: Duration,
run_interval: Duration,
) -> Self {
StaleMessagesCleaner {
inbox_manager: storage.inbox_manager().clone(),
task_client,
max_message_age,
run_interval,
}
@@ -36,13 +32,14 @@ impl StaleMessagesCleaner {
self.inbox_manager.remove_stale(cutoff).await
}
async fn run(&mut self) {
pub async fn run(&mut self, shutdown_token: ShutdownToken) {
let mut interval = tokio::time::interval(self.run_interval);
while !self.task_client.is_shutdown() {
loop {
tokio::select! {
biased;
_ = self.task_client.recv() => {
_ = shutdown_token.cancelled() => {
trace!("StaleMessagesCleaner: received shutdown");
break;
}
_ = interval.tick() => {
if let Err(err) = self.clean_up_stale_messages().await {
@@ -53,8 +50,4 @@ impl StaleMessagesCleaner {
}
debug!("StaleMessagesCleaner: Exiting");
}
pub fn start(mut self) -> JoinHandle<()> {
tokio::spawn(async move { self.run().await })
}
}
+7 -5
View File
@@ -282,8 +282,13 @@ impl<R: RngCore + CryptoRng + Clone> DkgController<R> {
let mut last_polled = OffsetDateTime::now_utc();
let mut last_tick_duration = Default::default();
while !shutdown.is_cancelled() {
loop {
tokio::select! {
biased;
_ = shutdown.cancelled() => {
trace!("DkgController: Received shutdown");
break;
}
_ = interval.tick() => {
let now = OffsetDateTime::now_utc();
let tick_duration = now - last_polled;
@@ -300,9 +305,6 @@ impl<R: RngCore + CryptoRng + Clone> DkgController<R> {
error!("failed to update the DKG state: {err}")
}
}
_ = shutdown.cancelled() => {
trace!("DkgController: Received shutdown");
}
}
}
}
@@ -319,7 +321,7 @@ impl<R: RngCore + CryptoRng + Clone> DkgController<R> {
where
R: Sync + Send + 'static,
{
let shutdown_listener = shutdown_manager.clone_token("DKG controller");
let shutdown_listener = shutdown_manager.clone_shutdown_token();
let dkg_controller = DkgController::new(
config,
nyxd_client,
+1 -1
View File
@@ -138,7 +138,7 @@ impl EcashState {
EcashBackgroundStateCleaner::new(
global_config,
storage.clone(),
shutdown_manager.clone_token("ecash-state-data-cleaner"),
shutdown_manager.clone_shutdown_token(),
),
),
global: GlobalEcachState::new(contract_address),
+1 -1
View File
@@ -266,7 +266,7 @@ impl EpochAdvancer {
described_cache,
storage.to_owned(),
);
let shutdown_listener = shutdown_manager.clone_token("epoch-advancer");
let shutdown_listener = shutdown_manager.clone_shutdown_token();
tokio::spawn(async move { epoch_advancer.run(shutdown_listener).await });
}
}
+2 -1
View File
@@ -114,11 +114,12 @@ impl KeyRotationController {
self.contract_cache.naive_wait_for_initial_values().await;
self.handle_contract_cache_update().await;
while !shutdown_token.is_cancelled() {
loop {
tokio::select! {
biased;
_ = shutdown_token.cancelled() => {
trace!("KeyRotationController: Received shutdown");
break;
}
_ = self.contract_cache_watcher.changed() => {
self.handle_contract_cache_update().await
+9 -4
View File
@@ -25,7 +25,7 @@ use nym_crypto::asymmetric::{ed25519, x25519};
use nym_sphinx::acknowledgements::AckKey;
use nym_sphinx::params::PacketType;
use nym_sphinx::receiver::MessageReceiver;
use nym_task::ShutdownManager;
use nym_task::{ShutdownManager, ShutdownToken};
use std::sync::Arc;
use tracing::info;
@@ -84,6 +84,7 @@ impl<'a> NetworkMonitorBuilder<'a> {
pub(crate) async fn build<R: MessageReceiver + Send + Sync + 'static>(
self,
shutdown_token: ShutdownToken,
) -> NetworkMonitorRunnables<R> {
// TODO: those keys change constant throughout the whole execution of the monitor.
// and on top of that, they are used with ALL the gateways -> presumably this should change
@@ -127,6 +128,7 @@ impl<'a> NetworkMonitorBuilder<'a> {
gateway_status_update_sender,
Arc::clone(&identity_keypair),
bandwidth_controller,
shutdown_token,
);
let received_processor = new_received_processor(
@@ -170,9 +172,9 @@ impl<R: MessageReceiver + Send + Sync + 'static> NetworkMonitorRunnables<R> {
pub(crate) fn spawn_tasks(self, shutdown: &ShutdownManager) {
let mut packet_receiver = self.packet_receiver;
let mut monitor = self.monitor;
let shutdown_listener = shutdown.clone_token("NM-packet-receiver");
let shutdown_listener = shutdown.clone_shutdown_token();
tokio::spawn(async move { packet_receiver.run(shutdown_listener).await });
let shutdown_listener = shutdown.clone_token("NM-main");
let shutdown_listener = shutdown.clone_shutdown_token();
tokio::spawn(async move { monitor.run(shutdown_listener).await });
}
}
@@ -202,12 +204,14 @@ fn new_packet_sender(
gateways_status_updater: GatewayClientUpdateSender,
local_identity: Arc<ed25519::KeyPair>,
bandwidth_controller: BandwidthController<nyxd::Client, PersistentStorage>,
shutdown_token: ShutdownToken,
) -> PacketSender {
PacketSender::new(
config,
gateways_status_updater,
local_identity,
bandwidth_controller,
shutdown_token,
)
}
@@ -252,6 +256,7 @@ pub(crate) async fn start<R: MessageReceiver + Send + Sync + 'static>(
nyxd_client,
);
info!("Starting network monitor...");
let runnables: NetworkMonitorRunnables<R> = monitor_builder.build().await;
let runnables: NetworkMonitorRunnables<R> =
monitor_builder.build(shutdown.clone_shutdown_token()).await;
runnables.spawn_tasks(shutdown);
}
+8 -4
View File
@@ -334,20 +334,24 @@ impl<R: MessageReceiver + Send + Sync> Monitor<R> {
.await;
let mut run_interval = tokio::time::interval(self.run_interval);
while !shutdown_token.is_cancelled() {
loop {
tokio::select! {
biased;
_ = shutdown_token.cancelled() => {
trace!("UpdateHandler: Received shutdown");
break;
}
_ = run_interval.tick() => {
tokio::select! {
biased;
_ = shutdown_token.cancelled() => {
trace!("UpdateHandler: Received shutdown");
break;
}
_ = self.test_run() => (),
}
}
_ = shutdown_token.cancelled() => {
trace!("UpdateHandler: Received shutdown");
}
}
}
}
@@ -58,11 +58,12 @@ impl PacketReceiver {
}
pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) {
while !shutdown_token.is_cancelled() {
loop {
tokio::select! {
biased;
_ = shutdown_token.cancelled() => {
trace!("UpdateHandler: Received shutdown");
break;
}
// unwrap here is fine as it can only return a `None` if the PacketSender has died
// and if that was the case, then the entire monitor is already in an undefined state
@@ -20,6 +20,7 @@ use nym_gateway_client::{
AcknowledgementReceiver, GatewayClient, MixnetMessageReceiver, PacketRouter, SharedGatewayKey,
};
use nym_sphinx::forwarding::packet::MixPacket;
use nym_task::ShutdownToken;
use pin_project::pin_project;
use sqlx::__rt::timeout;
use std::mem;
@@ -91,6 +92,7 @@ impl GatewayPackets {
struct FreshGatewayClientData {
gateways_status_updater: GatewayClientUpdateSender,
local_identity: Arc<ed25519::KeyPair>,
shutdown_token: ShutdownToken,
gateway_response_timeout: Duration,
bandwidth_controller: BandwidthController<nyxd::Client, PersistentStorage>,
disabled_credentials_mode: bool,
@@ -127,11 +129,13 @@ impl PacketSender {
gateways_status_updater: GatewayClientUpdateSender,
local_identity: Arc<ed25519::KeyPair>,
bandwidth_controller: BandwidthController<nyxd::Client, PersistentStorage>,
shutdown_token: ShutdownToken,
) -> Self {
PacketSender {
fresh_gateway_client_data: Arc::new(FreshGatewayClientData {
gateways_status_updater,
local_identity,
shutdown_token,
gateway_response_timeout: config.network_monitor.debug.gateway_response_timeout,
bandwidth_controller,
disabled_credentials_mode: config.network_monitor.debug.disabled_credentials_mode,
@@ -154,10 +158,6 @@ impl PacketSender {
GatewayClientHandle,
(MixnetMessageReceiver, AcknowledgementReceiver),
) {
// I think the proper one should be passed around instead...
let task_client =
nym_task::TaskClient::dummy().named(format!("gateway-{}", config.gateway_identity));
let (message_sender, message_receiver) = mpsc::unbounded();
// currently we do not care about acks at all, but we must keep the channel alive
@@ -167,7 +167,7 @@ impl PacketSender {
let gateway_packet_router = PacketRouter::new(
ack_sender,
message_sender,
task_client.fork("packet_router"),
fresh_gateway_client_data.shutdown_token.clone(),
);
let shared_keys = fresh_gateway_client_data
@@ -186,11 +186,11 @@ impl PacketSender {
Some(fresh_gateway_client_data.bandwidth_controller.clone()),
nym_statistics_common::clients::ClientStatsSender::new(
None,
task_client.fork("client_stats_sender"),
fresh_gateway_client_data.shutdown_token.clone(),
),
#[cfg(unix)]
None,
task_client,
fresh_gateway_client_data.shutdown_token.clone(),
);
(
@@ -45,7 +45,7 @@ pub(crate) async fn start_cache_refresher(
.with_update_fn(move |main_cache, update| {
refresher_update_fn(main_cache, update, values_to_retain)
})
.start(shutdown_manager.clone_token("performance-contract-cache-refresher"));
.start(shutdown_manager.clone_shutdown_token());
Ok(warmed_up_cache)
}
+5 -1
View File
@@ -69,11 +69,12 @@ impl NodeStatusCacheRefresher {
pub async fn run(&mut self, shutdown_token: ShutdownToken) {
let mut last_update = OffsetDateTime::now_utc();
let mut fallback_interval = time::interval(self.fallback_caching_interval);
while !shutdown_token.is_cancelled() {
loop {
tokio::select! {
biased;
_ = shutdown_token.cancelled() => {
trace!("NodeStatusCacheRefresher: Received shutdown");
break;
}
// Update node status cache when the contract cache / describe cache is updated
Ok(_) = self.mixnet_contract_cache_listener.changed() => {
@@ -81,6 +82,7 @@ impl NodeStatusCacheRefresher {
_ = self.maybe_refresh(&mut fallback_interval, &mut last_update) => (),
_ = shutdown_token.cancelled() => {
trace!("NodeStatusCacheRefresher: Received shutdown");
break;
}
}
}
@@ -89,6 +91,7 @@ impl NodeStatusCacheRefresher {
_ = self.maybe_refresh(&mut fallback_interval, &mut last_update) => (),
_ = shutdown_token.cancelled() => {
trace!("NodeStatusCacheRefresher: Received shutdown");
break;
}
}
}
@@ -99,6 +102,7 @@ impl NodeStatusCacheRefresher {
_ = self.maybe_refresh(&mut fallback_interval, &mut last_update) => (),
_ = shutdown_token.cancelled() => {
trace!("NodeStatusCacheRefresher: Received shutdown");
break;
}
}
}
+1 -1
View File
@@ -50,6 +50,6 @@ pub(crate) fn start_cache_refresh(
described_cache_cache_listener,
performance_provider,
);
let shutdown_listener = shutdown_manager.clone_token("node-status-refresher");
let shutdown_listener = shutdown_manager.clone_shutdown_token();
tokio::spawn(async move { nym_api_cache_refresher.run(shutdown_listener).await });
}
@@ -98,11 +98,12 @@ impl HistoricalUptimeUpdater {
let start = Instant::now() + time_left;
let mut interval = interval_at(start, ONE_DAY);
while !shutdown_token.is_cancelled() {
loop {
tokio::select! {
biased;
_ = shutdown_token.cancelled() => {
trace!("UpdateHandler: Received shutdown");
break;
}
_ = interval.tick() => {
info!("updating historical uptimes of nodes");
@@ -118,7 +119,7 @@ impl HistoricalUptimeUpdater {
pub(crate) fn start(storage: NymApiStorage, shutdown: &ShutdownManager) {
let uptime_updater = HistoricalUptimeUpdater::new(storage);
let shutdown_listener = shutdown.child_token("uptime-updater");
let shutdown_listener = shutdown.child_shutdown_token();
tokio::spawn(async move { uptime_updater.run(shutdown_listener).await });
}
}
+1 -1
View File
@@ -23,7 +23,7 @@ pub(crate) fn start_refresher(
.named("signers-cache-refresher");
let shared_cache = refresher.get_shared_cache();
refresher.start_with_delay(
shutdown_manager.clone_token("signers-cache-refresher"),
shutdown_manager.clone_shutdown_token(),
config.debug.refresher_start_delay,
);
shared_cache
+3 -2
View File
@@ -239,11 +239,12 @@ where
self.provider.wait_until_ready().await;
let mut refresh_interval = interval(self.refreshing_interval);
while !shutdown_token.is_cancelled() {
loop {
tokio::select! {
biased;
_ = shutdown_token.cancelled() => {
trace!("{}: Received shutdown", self.name)
trace!("{}: Received shutdown", self.name);
break
}
_ = refresh_interval.tick() => self.refresh(&shutdown_token).await,
// note: `Notify` is not cancellation safe, HOWEVER, there's only one listener,
+9 -9
View File
@@ -120,7 +120,7 @@ pub(crate) struct Args {
}
async fn start_nym_api_tasks(config: &Config) -> anyhow::Result<ShutdownManager> {
let shutdown_manager = ShutdownManager::new("nym-api")
let shutdown_manager = ShutdownManager::build_new_default()?
.with_shutdown_duration(Duration::from_secs(TASK_MANAGER_TIMEOUT_S));
let nyxd_client = nyxd::Client::new(config)?;
@@ -256,8 +256,8 @@ async fn start_nym_api_tasks(config: &Config) -> anyhow::Result<ShutdownManager>
let describe_cache_refresh_requester = describe_cache_refresher.refresh_requester();
let describe_cache_watcher = describe_cache_refresher
.start_with_watcher(shutdown_manager.clone_token("node-self-described-data-refresher"));
let describe_cache_watcher =
describe_cache_refresher.start_with_watcher(shutdown_manager.clone_shutdown_token());
let performance_provider = if config.performance_provider.use_performance_contract_data {
if network_details
@@ -289,8 +289,8 @@ async fn start_nym_api_tasks(config: &Config) -> anyhow::Result<ShutdownManager>
};
// start all the caches first
let contract_cache_watcher = mixnet_contract_cache_refresher
.start_with_watcher(shutdown_manager.clone_token("contracts-data-refresher"));
let contract_cache_watcher =
mixnet_contract_cache_refresher.start_with_watcher(shutdown_manager.clone_shutdown_token());
node_status_api::start_cache_refresh(
&config.node_status_api,
@@ -359,12 +359,12 @@ async fn start_nym_api_tasks(config: &Config) -> anyhow::Result<ShutdownManager>
contract_cache_watcher,
mixnet_contract_cache_state,
)
.start(shutdown_manager.clone_token("KeyRotationController"));
.start(shutdown_manager.clone_shutdown_token());
let bind_address = config.base.bind_address.to_owned();
let server = router.build_server(&bind_address).await?;
let http_shutdown = shutdown_manager.clone_token("axum-http");
let http_shutdown = shutdown_manager.clone_shutdown_token();
tokio::spawn(async move {
{
info!("Started Axum HTTP V2 server on {bind_address}");
@@ -372,7 +372,7 @@ async fn start_nym_api_tasks(config: &Config) -> anyhow::Result<ShutdownManager>
}
});
shutdown_manager.close();
shutdown_manager.close_tracker();
Ok(shutdown_manager)
}
@@ -385,7 +385,7 @@ pub(crate) async fn execute(args: Args) -> anyhow::Result<()> {
config.validate()?;
let shutdown_manager = start_nym_api_tasks(&config).await?;
let mut shutdown_manager = start_nym_api_tasks(&config).await?;
shutdown_manager.run_until_shutdown().await;
Ok(())
+2 -1
View File
@@ -40,6 +40,7 @@ tracing-indicatif = { workspace = true }
tracing-subscriber.workspace = true
tokio = { workspace = true, features = ["macros", "sync", "rt-multi-thread"] }
tokio-util = { workspace = true, features = ["codec"] }
tokio-stream = { workspace = true }
toml = { workspace = true }
url = { workspace = true, features = ["serde"] }
zeroize = { workspace = true, features = ["zeroize_derive"] }
@@ -129,7 +130,7 @@ criterion = { workspace = true, features = ["async_tokio"] }
rand_chacha = { workspace = true }
[features]
tokio-console = ["console-subscriber"]
tokio-console = ["console-subscriber", "nym-task/tokio-tracing"]
[lints]
workspace = true
-1
View File
@@ -159,7 +159,6 @@ pub fn gateway_tasks_config(config: &Config) -> GatewayTasksConfig {
.to_common_client_paths(),
ip_packet_router_description: Default::default(),
},
logging: config.logging,
},
custom_mixnet_path: None,
+1 -1
View File
@@ -187,7 +187,7 @@ pub(crate) async fn get_current_rotation_id(
nym_apis: &[Url],
fallback_nyxd: &[Url],
) -> Result<u32, NymNodeError> {
let apis_client = NymApisClient::new(nym_apis, ShutdownToken::ephemeral())?;
let apis_client = NymApisClient::new(nym_apis, ShutdownToken::default())?;
if let Ok(rotation_info) = apis_client.get_key_rotation_info().await.map(|r| r.details) {
if rotation_info.is_epoch_stuck() {
return Err(NymNodeError::StuckEpoch);
+1 -1
View File
@@ -349,7 +349,7 @@ impl KeyRotationController {
let state_update_future = sleep(next_action.until_deadline());
pin_mut!(state_update_future);
while !self.shutdown_token.is_cancelled() {
loop {
tokio::select! {
biased;
_ = self.shutdown_token.cancelled() => {
+3 -10
View File
@@ -12,7 +12,6 @@ use std::any::TypeId;
use std::collections::HashMap;
use std::ops::DerefMut;
use std::time::Duration;
use tokio::task::JoinHandle;
use tokio::time::{interval_at, Instant};
use tracing::{debug, error, trace, warn};
@@ -25,11 +24,10 @@ pub(crate) struct MetricsAggregator {
// registered_handlers: HashMap<TypeId, Box<dyn Any + Send + Sync + 'static>>,
event_sender: MetricEventsSender,
event_receiver: MetricEventsReceiver,
shutdown: ShutdownToken,
}
impl MetricsAggregator {
pub fn new(handlers_update_interval: Duration, shutdown: ShutdownToken) -> Self {
pub fn new(handlers_update_interval: Duration) -> Self {
let (event_sender, event_receiver) = events_channels();
MetricsAggregator {
@@ -37,7 +35,6 @@ impl MetricsAggregator {
registered_handlers: Default::default(),
event_sender,
event_receiver,
shutdown,
}
}
@@ -106,7 +103,7 @@ impl MetricsAggregator {
}
}
pub async fn run(&mut self) {
pub async fn run(&mut self, shutdown_token: ShutdownToken) {
self.on_start().await;
let start = Instant::now() + self.handlers_update_interval;
@@ -117,7 +114,7 @@ impl MetricsAggregator {
loop {
tokio::select! {
biased;
_ = self.shutdown.cancelled() => {
_ = shutdown_token.cancelled() => {
debug!("MetricsAggregator: Received shutdown");
break;
}
@@ -144,8 +141,4 @@ impl MetricsAggregator {
}
trace!("MetricsAggregator: Exiting");
}
pub fn start(mut self) -> JoinHandle<()> {
tokio::spawn(async move { self.run().await })
}
}
+16 -26
View File
@@ -1,15 +1,15 @@
// Copyright 2024 - Nym Technologies SA <contact@nymtech.net>
// SPDX-License-Identifier: GPL-3.0-only
use futures::StreamExt;
use human_repr::HumanCount;
use human_repr::HumanThroughput;
use nym_node_metrics::NymNodeMetrics;
use nym_task::ShutdownToken;
use std::time::Duration;
use time::OffsetDateTime;
use tokio::task::JoinHandle;
use tokio::time::{interval_at, Instant};
use tracing::{info, trace};
use tokio_stream::wrappers::IntervalStream;
use tracing::{error, info, trace};
struct AtLastUpdate {
time: OffsetDateTime,
@@ -49,20 +49,14 @@ pub(crate) struct ConsoleLogger {
logging_delay: Duration,
at_last_update: AtLastUpdate,
metrics: NymNodeMetrics,
shutdown: ShutdownToken,
}
impl ConsoleLogger {
pub(crate) fn new(
logging_delay: Duration,
metrics: NymNodeMetrics,
shutdown: ShutdownToken,
) -> Self {
pub(crate) fn new(logging_delay: Duration, metrics: NymNodeMetrics) -> Self {
ConsoleLogger {
logging_delay,
at_last_update: AtLastUpdate::new(),
metrics,
shutdown,
}
}
@@ -123,23 +117,19 @@ impl ConsoleLogger {
// TODO: add websocket-client traffic
}
async fn run(&mut self) {
pub(crate) async fn run(&mut self) {
trace!("Starting ConsoleLogger");
let mut interval = interval_at(Instant::now() + self.logging_delay, self.logging_delay);
loop {
tokio::select! {
biased;
_ = self.shutdown.cancelled() => {
trace!("ConsoleLogger: Received shutdown");
break
}
_ = interval.tick() => self.log_running_stats().await,
};
}
trace!("ConsoleLogger: Exiting");
}
pub(crate) fn start(mut self) -> JoinHandle<()> {
tokio::spawn(async move { self.run().await })
let mut stream = IntervalStream::new(interval_at(
Instant::now() + self.logging_delay,
self.logging_delay,
));
while stream.next().await.is_some() {
self.log_running_stats().await
}
// this should never get triggered
error!("console logger interval has been exhausted!")
}
}
+4 -5
View File
@@ -91,7 +91,6 @@ impl Drop for ConnectionHandler {
impl ConnectionHandler {
pub(crate) fn new(shared: &SharedData, remote_address: SocketAddr) -> Self {
let shutdown = shared.shutdown.child_token(remote_address.to_string());
shared.metrics.network.new_active_ingress_mixnet_client();
ConnectionHandler {
@@ -103,7 +102,7 @@ impl ConnectionHandler {
final_hop: shared.final_hop.clone(),
noise_config: shared.noise_config.clone(),
metrics: shared.metrics.clone(),
shutdown,
shutdown_token: shared.shutdown_token.child_token(),
},
remote_address,
pending_packets: PendingReplayCheckPackets::new(),
@@ -369,7 +368,7 @@ impl ConnectionHandler {
Some(Err(_)) => {
// our mutex got poisoned - we have to shut down
error!("CRITICAL FAILURE: replay bloomfilter mutex poisoning!");
self.shared.shutdown.cancel();
self.shared.shutdown_token.cancel();
return false;
}
};
@@ -394,7 +393,7 @@ impl ConnectionHandler {
else {
// our mutex got poisoned - we have to shut down
error!("CRITICAL FAILURE: replay bloomfilter mutex poisoning!");
self.shared.shutdown.cancel();
self.shared.shutdown_token.cancel();
return;
};
@@ -489,7 +488,7 @@ impl ConnectionHandler {
loop {
tokio::select! {
biased;
_ = self.shared.shutdown.cancelled() => {
_ = self.shared.shutdown_token.cancelled() => {
trace!("connection handler: received shutdown");
break
}
+3 -10
View File
@@ -4,12 +4,10 @@
use crate::node::mixnet::SharedData;
use nym_task::ShutdownToken;
use std::net::SocketAddr;
use tokio::task::JoinHandle;
use tracing::{debug, error, info, trace};
pub(crate) struct Listener {
bind_address: SocketAddr,
shutdown: ShutdownToken,
shared_data: SharedData,
}
@@ -17,19 +15,18 @@ impl Listener {
pub(crate) fn new(bind_address: SocketAddr, shared_data: SharedData) -> Self {
Listener {
bind_address,
shutdown: shared_data.shutdown.clone_with_suffix("socket-listener"),
shared_data,
}
}
pub(crate) async fn run(&mut self) {
pub(crate) async fn run(&mut self, shutdown: ShutdownToken) {
info!("attempting to run mixnet listener on {}", self.bind_address);
let tcp_listener = match tokio::net::TcpListener::bind(self.bind_address).await {
Ok(listener) => listener,
Err(err) => {
error!("Failed to bind to {}: {err}. Are you sure nothing else is running on the specified port and your user has sufficient permission to bind to the requested address?", self.bind_address);
self.shutdown.cancel();
shutdown.cancel();
return;
}
};
@@ -37,7 +34,7 @@ impl Listener {
loop {
tokio::select! {
biased;
_ = self.shutdown.cancelled() => {
_ = shutdown.cancelled() => {
trace!("mixnet listener: received shutdown");
break
}
@@ -48,8 +45,4 @@ impl Listener {
}
debug!("mixnet socket listener: Exiting");
}
pub(crate) fn start(mut self) -> JoinHandle<()> {
tokio::spawn(async move { self.run().await })
}
}
@@ -26,16 +26,10 @@ pub struct PacketForwarder<C, F> {
packet_sender: MixForwardingSender,
packet_receiver: MixForwardingReceiver,
shutdown: ShutdownToken,
}
impl<C, F> PacketForwarder<C, F> {
pub fn new(
client: C,
routing_filter: F,
metrics: NymNodeMetrics,
shutdown: ShutdownToken,
) -> Self {
pub fn new(client: C, routing_filter: F, metrics: NymNodeMetrics) -> Self {
let (packet_sender, packet_receiver) = mix_forwarding_channels();
PacketForwarder {
@@ -45,7 +39,6 @@ impl<C, F> PacketForwarder<C, F> {
routing_filter,
packet_sender,
packet_receiver,
shutdown,
}
}
@@ -127,7 +120,7 @@ impl<C, F> PacketForwarder<C, F> {
.update_packet_forwarder_queue_size(channel_size)
}
pub async fn run(&mut self)
pub async fn run(&mut self, shutdown_token: ShutdownToken)
where
C: SendWithoutResponse,
F: RoutingFilter,
@@ -137,7 +130,7 @@ impl<C, F> PacketForwarder<C, F> {
loop {
tokio::select! {
biased;
_ = self.shutdown.cancelled() => {
_ = shutdown_token.cancelled() => {
debug!("PacketForwarder: Received shutdown");
break;
}
+7 -6
View File
@@ -63,7 +63,7 @@ impl ProcessingConfig {
}
}
// explicitly do NOT derive clone as we want to manually apply relevant suffixes to the task clients
// explicitly do NOT derive clone as we want the children to use CHILD shutdown tokens
pub(crate) struct SharedData {
pub(super) processing_config: ProcessingConfig,
pub(super) sphinx_keys: ActiveSphinxKeys,
@@ -79,7 +79,8 @@ pub(crate) struct SharedData {
pub(super) noise_config: NoiseConfig,
pub(super) metrics: NymNodeMetrics,
pub(super) shutdown: ShutdownToken,
pub(super) shutdown_token: ShutdownToken,
}
fn convert_to_metrics_version(processed: MixPacketVersion) -> PacketKind {
@@ -99,7 +100,7 @@ impl SharedData {
final_hop: SharedFinalHopData,
noise_config: NoiseConfig,
metrics: NymNodeMetrics,
shutdown: ShutdownToken,
shutdown_token: ShutdownToken,
) -> Self {
SharedData {
processing_config,
@@ -109,7 +110,7 @@ impl SharedData {
final_hop,
noise_config,
metrics,
shutdown,
shutdown_token,
}
}
@@ -188,10 +189,10 @@ impl SharedData {
.mixnet_forwarder
.forward_packet(PacketToForward::new(packet, delay_until))
.is_err()
&& !self.shutdown.is_cancelled()
&& !self.shutdown_token.is_cancelled()
{
error!("failed to forward sphinx packet on the channel while the process is not going through the shutdown!");
self.shutdown.cancel();
self.shutdown_token.cancel();
}
}
+72 -58
View File
@@ -54,7 +54,7 @@ use nym_noise::config::{NoiseConfig, NoiseNetworkView};
use nym_noise_keys::VersionedNoiseKey;
use nym_sphinx_acknowledgements::AckKey;
use nym_sphinx_addressing::Recipient;
use nym_task::{ShutdownManager, ShutdownToken, TaskClient};
use nym_task::{ShutdownManager, ShutdownToken, ShutdownTracker};
use nym_validator_client::UserAgent;
use nym_verloc::measurements::SharedVerlocStats;
use nym_verloc::{self, measurements::VerlocMeasurer};
@@ -465,19 +465,21 @@ impl NymNode {
wireguard: Some(wireguard_data),
config,
accepted_operator_terms_and_conditions: false,
shutdown_manager: ShutdownManager::new("NymNode")
.with_legacy_task_manager()
.with_default_shutdown_signals()
shutdown_manager: ShutdownManager::build_new_default()
.map_err(|source| NymNodeError::ShutdownSignalFailure { source })?,
})
}
pub(crate) fn config(&self) -> &Config {
&self.config
pub(crate) fn shutdown_tracker(&self) -> &ShutdownTracker {
self.shutdown_manager.shutdown_tracker()
}
pub(crate) fn shutdown_token<S: Into<String>>(&self, child_suffix: S) -> ShutdownToken {
self.shutdown_manager.clone_token(child_suffix)
pub(crate) fn shutdown_token(&self) -> ShutdownToken {
self.shutdown_manager.clone_shutdown_token()
}
pub(crate) fn config(&self) -> &Config {
&self.config
}
pub(crate) fn with_accepted_operator_terms_and_conditions(
@@ -561,7 +563,7 @@ impl NymNode {
self.config.mixnet.nym_api_urls.clone(),
self.config.debug.topology_cache_ttl,
self.config.debug.routing_nodes_check_interval,
self.shutdown_manager.clone_token("network-refresher"),
self.shutdown_manager.clone_shutdown_token(),
)
.await
}
@@ -605,8 +607,6 @@ impl NymNode {
metrics_sender: MetricEventsSender,
active_clients_store: ActiveClientsStore,
mix_packet_sender: MixForwardingSender,
legacy_task_client: TaskClient,
shutdown_token: ShutdownToken,
) -> Result<(), NymNodeError> {
let config = gateway_tasks_config(&self.config);
@@ -624,8 +624,7 @@ impl NymNode {
metrics_sender,
self.metrics.clone(),
self.entry_gateway.mnemonic.clone(),
legacy_task_client,
shutdown_token,
self.shutdown_tracker().clone(),
);
// if we're running in entry mode, start the websocket
@@ -634,10 +633,11 @@ impl NymNode {
"starting the clients websocket... on {}",
self.config.gateway_tasks.ws_bind_address
);
let websocket = gateway_tasks_builder
let mut websocket = gateway_tasks_builder
.build_websocket_listener(active_clients_store.clone())
.await?;
websocket.start();
self.shutdown_tracker()
.try_spawn_named(async move { websocket.run().await }, "EntryWebsocket");
} else {
info!("node not running in entry mode: the websocket will remain closed");
}
@@ -697,8 +697,12 @@ impl NymNode {
}
// start task for removing stale and un-retrieved client messages
let stale_messages_cleaner = gateway_tasks_builder.build_stale_messages_cleaner();
stale_messages_cleaner.start();
let mut stale_messages_cleaner = gateway_tasks_builder.build_stale_messages_cleaner();
let shutdown_token = self.shutdown_token();
self.shutdown_tracker().try_spawn_named(
async move { stale_messages_cleaner.run(shutdown_token).await },
"StaleMessagesCleaner",
);
Ok(())
}
@@ -875,25 +879,23 @@ impl NymNode {
let mut verloc_measurer = VerlocMeasurer::new(
config,
self.ed25519_identity_keys.clone(),
self.shutdown_manager.clone_token("verloc"),
self.shutdown_manager.clone_shutdown_token(),
);
verloc_measurer.set_shared_state(self.verloc_stats.clone());
tokio::spawn(async move { verloc_measurer.run().await });
self.shutdown_manager
.try_spawn_named(async move { verloc_measurer.run().await }, "VerlocMeasurer");
}
pub(crate) fn setup_metrics_backend(
&self,
active_clients_store: ActiveClientsStore,
active_egress_mixnet_connections: ActiveConnections,
shutdown: ShutdownToken,
) -> MetricEventsSender {
info!("setting up node metrics...");
// aggregator (to listen for any metrics events)
let mut metrics_aggregator = MetricsAggregator::new(
self.config.metrics.debug.aggregator_update_rate,
shutdown.clone_with_suffix("aggregator"),
);
let mut metrics_aggregator =
MetricsAggregator::new(self.config.metrics.debug.aggregator_update_rate);
// >>>> START: register all relevant handlers for custom events
@@ -950,18 +952,25 @@ impl NymNode {
// console logger to preserve old mixnode functionalities
if self.config.metrics.debug.log_stats_to_console {
ConsoleLogger::new(
let mut console_logger = ConsoleLogger::new(
self.config.metrics.debug.console_logging_update_interval,
self.metrics.clone(),
shutdown.clone_with_suffix("metrics-console-logger"),
)
.start();
);
self.shutdown_tracker().try_spawn_named_with_shutdown(
async move { console_logger.run().await },
"ConsoleLogger",
);
}
let events_sender = metrics_aggregator.sender();
// spawn the aggregator task
metrics_aggregator.start();
let shutdown_token = self.shutdown_token();
self.shutdown_tracker().try_spawn_named(
async move { metrics_aggregator.run(shutdown_token).await },
"MetricsAggregator",
);
events_sender
}
@@ -983,14 +992,15 @@ impl NymNode {
sphinx_keys.keys.primary_key_rotation_id(),
sphinx_keys.keys.secondary_key_rotation_id(),
self.metrics.clone(),
self.shutdown_manager
.clone_token("replay-detection-background-flush"),
self.shutdown_manager.clone_shutdown_token(),
)
.await?;
let bloomfilters_manager = replay_detection_background.bloomfilters_manager();
self.shutdown_manager
.spawn(async move { replay_detection_background.run().await });
self.shutdown_manager.try_spawn_named(
async move { replay_detection_background.run().await },
"ReplayDetection",
);
Ok(bloomfilters_manager)
}
@@ -998,7 +1008,7 @@ impl NymNode {
fn setup_nym_apis_client(&self) -> Result<NymApisClient, NymNodeError> {
NymApisClient::new(
&self.config.mixnet.nym_api_urls,
self.shutdown_manager.clone_token("nym-apis-client"),
self.shutdown_manager.clone_shutdown_token(),
)
}
@@ -1029,7 +1039,7 @@ impl NymNode {
nym_apis_client,
replay_protection_manager,
managed_keys,
self.shutdown_manager.clone_token("key-rotation-controller"),
self.shutdown_manager.clone_shutdown_token(),
);
rotation_controller.start();
@@ -1042,7 +1052,6 @@ impl NymNode {
replay_protection_bloomfilter: ReplayProtectionBloomfilters,
routing_filter: F,
noise_config: NoiseConfig,
shutdown: ShutdownToken,
) -> Result<(MixForwardingSender, ActiveConnections), NymNodeError>
where
F: RoutingFilter + Send + Sync + 'static,
@@ -1073,14 +1082,16 @@ impl NymNode {
);
let active_connections = mixnet_client.active_connections();
let mut packet_forwarder = PacketForwarder::new(
mixnet_client,
routing_filter,
self.metrics.clone(),
shutdown.clone_with_suffix("mix-packet-forwarder"),
);
let mut packet_forwarder =
PacketForwarder::new(mixnet_client, routing_filter, self.metrics.clone());
let mix_packet_sender = packet_forwarder.sender();
tokio::spawn(async move { packet_forwarder.run().await });
let shutdown_token = self.shutdown_token();
self.shutdown_tracker().try_spawn_named(
async move { packet_forwarder.run(shutdown_token).await },
"PacketForwarder",
);
let final_hop_data = SharedFinalHopData::new(
active_clients_store.clone(),
@@ -1095,14 +1106,21 @@ impl NymNode {
final_hop_data,
noise_config,
self.metrics.clone(),
shutdown,
self.shutdown_token(),
);
let mut mixnet_listener = mixnet::Listener::new(self.config.mixnet.bind_address, shared);
let shutdown_token = self.shutdown_token();
self.shutdown_tracker().try_spawn_named(
async move { mixnet_listener.run(shutdown_token).await },
"MixnetListener",
);
mixnet::Listener::new(self.config.mixnet.bind_address, shared).start();
Ok((mix_packet_sender, active_connections))
}
pub(crate) async fn run_minimal_mixnet_processing(self) -> Result<(), NymNodeError> {
pub(crate) async fn run_minimal_mixnet_processing(mut self) -> Result<(), NymNodeError> {
let noise_config = nym_noise::config::NoiseConfig::new(
self.x25519_noise_keys.clone(),
NoiseNetworkView::new_empty(),
@@ -1115,11 +1133,10 @@ impl NymNode {
ReplayProtectionBloomfilters::new_disabled(),
OpenFilter,
noise_config,
self.shutdown_manager.clone_token("mixnet-traffic"),
)
.await?;
self.shutdown_manager.close();
self.shutdown_manager.close_tracker();
self.shutdown_manager.run_until_shutdown().await;
Ok(())
@@ -1137,16 +1154,17 @@ impl NymNode {
let http_server = self.build_http_server().await?;
let bind_address = self.config.http.bind_address;
let server_shutdown = self.shutdown_manager.clone_token("http-server");
let server_shutdown = self.shutdown_manager.clone_shutdown_token();
self.shutdown_manager.spawn(async move {
{
self.shutdown_manager.try_spawn_named(
async move {
info!("starting NymNodeHTTPServer on {bind_address}");
http_server
.with_graceful_shutdown(async move { server_shutdown.cancelled().await })
.await
}
});
},
"HttpApi",
);
let nym_apis_client = self.setup_nym_apis_client()?;
@@ -1172,14 +1190,12 @@ impl NymNode {
bloomfilters_manager.bloomfilters(),
network_refresher.routing_filter(),
noise_config,
self.shutdown_manager.clone_token("mixnet-traffic"),
)
.await?;
let metrics_sender = self.setup_metrics_backend(
active_clients_store.clone(),
active_egress_mixnet_connections,
self.shutdown_manager.clone_token("metrics"),
);
self.start_gateway_tasks(
@@ -1187,8 +1203,6 @@ impl NymNode {
metrics_sender,
active_clients_store,
mix_packet_sender,
self.shutdown_manager.subscribe_legacy("gateway-tasks"),
self.shutdown_manager.child_token("gateway-tasks"),
)
.await?;
@@ -1196,7 +1210,7 @@ impl NymNode {
.await?;
network_refresher.start();
self.shutdown_manager.close();
self.shutdown_manager.close_tracker();
Ok(self.shutdown_manager)
}

Some files were not shown because too many files have changed in this diff Show More