Compare commits

...

2 Commits

Author SHA1 Message Date
Jędrzej Stuczyński b2d77aedd3 added additional logs 2024-08-22 11:07:49 +01:00
Jędrzej Stuczyński 58dcf2171c make sure to always create available_bandwidth row for new clients 2024-08-22 10:47:22 +01:00
3 changed files with 38 additions and 18 deletions
+5 -2
View File
@@ -37,10 +37,13 @@ impl BandwidthManager {
}
/// Creates a new bandwidth entry for the particular client.
pub(crate) async fn insert_new_client(&self, client_id: i64) -> Result<(), sqlx::Error> {
pub(crate) async fn insert_new_client_if_doesnt_exist(
&self,
client_id: i64,
) -> Result<(), sqlx::Error> {
// FIXME: hack; we need to change api slightly
sqlx::query!(
"INSERT INTO available_bandwidth(client_id, available, expiration) VALUES (?, 0, ?)",
"INSERT OR IGNORE INTO available_bandwidth(client_id, available, expiration) VALUES (?, 0, ?)",
client_id,
OffsetDateTime::UNIX_EPOCH,
)
+23 -1
View File
@@ -35,6 +35,13 @@ pub trait Storage: Send + Sync {
client_address: DestinationAddressBytes,
) -> Result<i64, StorageError>;
/// Creates all relevant database entries for the newly registered client
async fn insert_new_client(
&self,
client_address: DestinationAddressBytes,
shared_keys: &SharedKeys,
) -> Result<i64, StorageError>;
/// Inserts provided derived shared keys into the database.
/// If keys previously existed for the provided client, they are overwritten with the new data.
///
@@ -322,6 +329,19 @@ impl Storage for PersistentStorage {
.await?)
}
async fn insert_new_client(
&self,
client_address: DestinationAddressBytes,
shared_keys: &SharedKeys,
) -> Result<i64, StorageError> {
let id = self.insert_shared_keys(client_address, shared_keys).await?;
self.bandwidth_manager
.insert_new_client_if_doesnt_exist(id)
.await?;
Ok(id)
}
async fn insert_shared_keys(
&self,
client_address: DestinationAddressBytes,
@@ -390,7 +410,9 @@ impl Storage for PersistentStorage {
}
async fn create_bandwidth_entry(&self, client_id: i64) -> Result<(), StorageError> {
self.bandwidth_manager.insert_new_client(client_id).await?;
self.bandwidth_manager
.insert_new_client_if_doesnt_exist(client_id)
.await?;
Ok(())
}
@@ -26,6 +26,7 @@ use std::time::Duration;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::tungstenite::{protocol::Message, Error as WsError};
use tracing::field::debug;
use tracing::*;
use crate::node::client_handling::websocket::common_state::CommonHandlerState;
@@ -298,6 +299,7 @@ where
///
/// * `client_address`: address of the client that is going to receive the messages.
/// * `shared_keys`: shared keys derived between the client and the gateway used to encrypt and tag the messages.
#[instrument(skip_all)]
async fn push_stored_messages_to_client(
&mut self,
client_address: DestinationAddressBytes,
@@ -306,6 +308,7 @@ where
where
S: AsyncRead + AsyncWrite + Unpin,
{
debug!("attempting to push stored messages to client");
let mut start_next_after = None;
loop {
// retrieve some messages
@@ -521,6 +524,7 @@ where
/// * `client_address`: address of the client wishing to authenticate.
/// * `encrypted_address`: ciphertext of the address of the client wishing to authenticate.
/// * `iv`: fresh IV received with the request.
#[instrument(skip_all)]
async fn handle_authenticate(
&mut self,
client_protocol_version: Option<u8>,
@@ -531,6 +535,7 @@ where
where
S: AsyncRead + AsyncWrite + Unpin,
{
debug("handling client authentication");
let negotiated_protocol = self.negotiate_client_protocol(client_protocol_version)?;
// populate the negotiated protocol for future uses
self.negotiated_protocol = Some(negotiated_protocol);
@@ -610,23 +615,9 @@ where
let client_id = self
.shared_state
.storage
.insert_shared_keys(client_address, client_shared_keys)
.insert_new_client(client_address, client_shared_keys)
.await?;
// see if we have bandwidth entry for the client already, if not, create one with zero value
if self
.shared_state
.storage
.get_available_bandwidth(client_id)
.await?
.is_none()
{
self.shared_state
.storage
.create_bandwidth_entry(client_id)
.await?;
}
self.push_stored_messages_to_client(client_address, client_shared_keys)
.await?;
@@ -639,6 +630,7 @@ where
/// # Arguments
///
/// * `init_data`: init payload of the registration handshake.
#[instrument(skip_all)]
async fn handle_register(
&mut self,
client_protocol_version: Option<u8>,
@@ -647,6 +639,7 @@ where
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
debug!("handling client registration");
let negotiated_protocol = self.negotiate_client_protocol(client_protocol_version)?;
// populate the negotiated protocol for future uses
self.negotiated_protocol = Some(negotiated_protocol);
@@ -659,7 +652,9 @@ where
}
let shared_keys = self.perform_registration_handshake(init_data).await?;
debug!("managed to derived shared keys");
let client_id = self.register_client(remote_address, &shared_keys).await?;
event!(Level::DEBUG, client_id, protocol = negotiated_protocol);
let client_details = ClientDetails::new(client_id, remote_address, shared_keys);