addressing coderabbit comments

This commit is contained in:
benedettadavico
2026-06-01 11:59:43 +02:00
parent a52a8c3e81
commit a98a65c16d
8 changed files with 85 additions and 21 deletions
+9 -3
View File
@@ -658,18 +658,24 @@ func checkPorts(target string, ports []uint16, timeoutSec uint64, tnet *netstack
}
}
} else {
// All other targets can handle concurrent connections, probably
log.Printf("Port check: testing %d ports on %s concurrently (timeout %v each)",
len(ports), target, timeout)
// All other targets can handle concurrent connections, probably.
// A semaphore caps concurrent tnet.DialContext calls to avoid
// overwhelming the single userspace netstack instance.
const maxConcurrentDials = 64
log.Printf("Port check: testing %d ports on %s concurrently (max %d at a time, timeout %v each)",
len(ports), target, maxConcurrentDials, timeout)
var (
mu sync.Mutex
wg sync.WaitGroup
sem = make(chan struct{}, maxConcurrentDials)
)
for _, p := range ports {
wg.Add(1)
go func(port uint16) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
addr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", port))
ctx, cancel := context.WithTimeout(context.Background(), timeout)
c, err := tnet.DialContext(ctx, "tcp", addr)
+15 -1
View File
@@ -257,7 +257,21 @@ impl PortCheckResult {
pub fn closed_ports(&self) -> Vec<u16> {
self.ports
.iter()
.filter_map(|(k, &open)| if !open { k.parse().ok() } else { None })
.filter_map(|(k, &open)| {
if open {
return None;
}
match k.parse::<u16>() {
Ok(port) => Some(port),
Err(e) => {
tracing::warn!(
"Skipping port key {:?} that could not be parsed as u16: {e}",
k
);
None
}
}
})
.collect()
}
}
@@ -34,8 +34,10 @@ pub(super) fn parse_server_config(s: &str) -> Result<ServerConfig, String> {
let port = parts[1]
.parse::<u16>()
.map_err(|_| "Invalid port number".to_string())?;
let auth_key =
PrivateKey::from_base58_string(env::var("NODE_STATUS_AGENT_AUTH_KEY").unwrap()).unwrap();
let raw_key = env::var("NODE_STATUS_AGENT_AUTH_KEY")
.map_err(|_| "NODE_STATUS_AGENT_AUTH_KEY environment variable is not set".to_string())?;
let auth_key = PrivateKey::from_base58_string(raw_key)
.map_err(|e| format!("Failed to decode NODE_STATUS_AGENT_AUTH_KEY as base58: {e}"))?;
Ok(ServerConfig {
address,
@@ -104,11 +104,21 @@ impl Storage {
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await?;
let items: Vec<Gateway> = items
.into_iter()
.map(|item| item.try_into())
.collect::<anyhow::Result<Vec<_>>>()
.inspect_err(|e| error!("Conversion from DTO failed: {e}. Invalidly stored data?"))?;
let mut gateways: Vec<Gateway> = Vec::with_capacity(items.len());
let mut failed = 0usize;
for item in items {
match item.try_into() {
Ok(gw) => gateways.push(gw),
Err(e) => {
error!("Conversion from DTO failed: {e}. Invalidly stored data?");
failed += 1;
}
}
}
if failed > 0 {
tracing::warn!("{failed} gateway DTO(s) failed conversion and were skipped");
}
let items = gateways;
tracing::trace!("Fetched {} gateways from DB", items.len());
Ok(items)
}
@@ -24,6 +24,27 @@ pub(crate) async fn count_testruns_in_progress(
.map_err(anyhow::Error::from)
}
pub(crate) async fn count_testruns_in_progress_by_kind(
conn: &mut DbConnection,
kind: TestRunKind,
) -> anyhow::Result<Option<i64>> {
sqlx::query_scalar!(
r#"SELECT
COUNT(id) as "count: i64"
FROM testruns
WHERE
status = $1
AND
kind = $2
"#,
TestRunStatus::InProgress as i64,
kind as i16,
)
.fetch_one(conn.as_mut())
.await
.map_err(anyhow::Error::from)
}
pub(crate) async fn get_in_progress_testrun_by_id(
conn: &mut DbConnection,
testrun_id: i32,
@@ -142,13 +142,18 @@ async fn request_ports_check_testrun(
.await
.map_err(HttpError::internal_with_logging)?;
let active_testruns = db::queries::testruns::count_testruns_in_progress(&mut conn)
let active_ports_check_testruns = db::queries::testruns::count_testruns_in_progress_by_kind(
&mut conn,
TestRunKind::PortsCheck,
)
.await
.map_err(HttpError::internal_with_logging)?
.unwrap_or_default();
let max_count = state.agent_max_count();
if active_testruns >= max_count {
tracing::warn!("{active_testruns}/{max_count} testruns in progress, rejecting",);
if active_ports_check_testruns >= max_count {
tracing::warn!(
"{active_ports_check_testruns}/{max_count} ports-check testruns in progress, rejecting",
);
return Err(HttpError::no_testruns_available());
}
@@ -562,6 +567,10 @@ async fn process_ports_check_submission(
payload.port_check_result.can_register,
);
// probe_log is intentionally not persisted for ports-check submissions:
// the ports-check result is a lightweight JSONB record and does not warrant
// the storage overhead of a full probe log. Full logs are retained only for
// regular probe testruns via update_gateway_last_probe_log.
queries::testruns::update_gateway_ports_check_only(
conn,
gateway_id,
@@ -54,7 +54,8 @@ pub(crate) async fn start_http_api(
.unwrap_or(true);
if ports_check_scheduler_enabled {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(60 * 10));
let period = std::time::Duration::from_secs(60 * 10);
let mut interval = tokio::time::interval_at(tokio::time::Instant::now() + period, period);
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
let scheduler_shutdown = shutdown_tracker.clone_shutdown_token().cancelled_owned();
shutdown_tracker.spawn(async move {
@@ -134,7 +134,8 @@ impl AppState {
return Err(HttpError::unauthorized());
};
if request.verify_signature().is_err() {
if let Err(err) = request.verify_signature() {
tracing::debug!("Signature verification error: {:?}", err);
tracing::warn!("Signature verification failed, rejecting");
return Err(HttpError::unauthorized());
}