Compare commits

...

2 Commits

Author SHA1 Message Date
Jon Häggblad 8cd9b99d72 wip 2023-11-16 11:22:05 +00:00
Jon Häggblad f7093cdc5a Rework error handling in tun device 2023-11-15 12:33:45 +00:00
5 changed files with 184 additions and 105 deletions
Generated
+1
View File
@@ -7593,6 +7593,7 @@ dependencies = [
"tap",
"thiserror",
"tokio",
"tokio-stream",
"tokio-tun",
]
+1
View File
@@ -32,6 +32,7 @@ serde = { workspace = true, features = ["derive"] }
tap.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["rt-multi-thread", "net", "io-util"] }
tokio-stream = { version = "0.1.11" }
[target.'cfg(target_os = "linux")'.dependencies]
tokio-tun = "0.9.0"
+24 -4
View File
@@ -1,11 +1,14 @@
use std::net::SocketAddr;
use std::{net::SocketAddr, time::Duration};
use boringtun::x25519;
use dashmap::{
mapref::one::{Ref, RefMut},
DashMap,
};
use tokio::sync::mpsc::{self};
use tokio::{
sync::mpsc::{self},
time::{error::Elapsed, timeout},
};
use crate::event::Event;
@@ -14,9 +17,26 @@ use crate::event::Event;
pub struct PeerEventSender(mpsc::Sender<Event>);
pub(crate) struct PeerEventReceiver(mpsc::Receiver<Event>);
#[derive(thiserror::Error, Debug)]
pub enum PeerEventSenderError {
#[error("timeout")]
Timeout {
#[from]
source: Elapsed,
},
#[error("send failed: {source}")]
SendError {
#[from]
source: mpsc::error::SendError<Event>,
},
}
impl PeerEventSender {
pub(crate) async fn send(&self, event: Event) -> Result<(), mpsc::error::SendError<Event>> {
self.0.send(event).await
pub(crate) async fn send(&self, event: Event) -> Result<(), PeerEventSenderError> {
timeout(Duration::from_millis(1000), self.0.send(event))
.await?
.map_err(|err| err.into())
}
}
+136 -94
View File
@@ -2,21 +2,61 @@ use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr},
sync::Arc,
time::Duration,
};
use etherparse::{InternetSlice, SlicedPacket};
use tap::TapFallible;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
time::timeout,
};
use crate::{
active_peers::PeerEventSenderError,
event::Event,
tun_task_channel::{
tun_task_channel, tun_task_response_channel, TunTaskPayload, TunTaskResponseRx,
TunTaskResponseTx, TunTaskRx, TunTaskTx,
TunTaskResponseSendError, TunTaskResponseTx, TunTaskRx, TunTaskTx,
},
udp_listener::PeersByIp,
};
#[derive(thiserror::Error, Debug)]
pub enum TunDeviceError {
#[error("iface: timeout writing to tun device, dropping packet")]
TunWriteTimeout,
#[error("iface: failed forwarding packet to peer: {source}")]
ForwardToPeerFailed {
#[from]
source: PeerEventSenderError,
},
#[error("iface: failed to forward responding packet with tag: {source}")]
ForwardNatResponseFailed {
#[from]
source: TunTaskResponseSendError,
},
#[error("iface: error writing to tun device: {source}")]
TunWriteError { source: std::io::Error },
#[error("unable to parse destination address from packet")]
UnableToParseDstAdddress,
#[error("unable to parse source address from packet")]
UnableToParseSrcAddress {
#[from]
source: etherparse::ReadError,
},
#[error("unable to parse source address from packet: ip header missing")]
UnableToParseSrcAddressIpHeaderMissing,
#[error("unable to lock peer mutex")]
FailedToLockPeer,
}
fn setup_tokio_tun_device(name: &str, address: Ipv4Addr, netmask: Ipv4Addr) -> tokio_tun::Tun {
log::info!("Creating TUN device with: address={address}, netmask={netmask}");
// Read MTU size from env variable NYM_MTU_SIZE, else default to 1420.
@@ -38,10 +78,10 @@ fn setup_tokio_tun_device(name: &str, address: Ipv4Addr, netmask: Ipv4Addr) -> t
pub struct TunDevice {
// The TUN device that we read/write to, to send/receive packets
tun: tokio_tun::Tun,
tun: Option<tokio_tun::Tun>,
// Incoming data that we should send
tun_task_rx: TunTaskRx,
tun_task_rx: Option<TunTaskRx>,
// And when we get replies, this is where we should send it
tun_task_response_tx: TunTaskResponseTx,
@@ -74,6 +114,14 @@ pub struct AllowedIpsInner {
peers_by_ip: Arc<tokio::sync::Mutex<PeersByIp>>,
}
impl AllowedIpsInner {
async fn lock(&self) -> Result<tokio::sync::MutexGuard<PeersByIp>, TunDeviceError> {
timeout(Duration::from_millis(200), self.peers_by_ip.as_ref().lock())
.await
.map_err(|_| TunDeviceError::FailedToLockPeer)
}
}
pub struct NatInner {
nat_table: HashMap<IpAddr, u64>,
}
@@ -104,9 +152,9 @@ impl TunDevice {
let (tun_task_response_tx, tun_task_response_rx) = tun_task_response_channel();
let tun_device = TunDevice {
tun_task_rx,
tun_task_rx: Some(tun_task_rx),
tun_task_response_tx,
tun,
tun: Some(tun),
routing_mode,
};
@@ -114,47 +162,35 @@ impl TunDevice {
}
// Send outbound packets out on the wild internet
async fn handle_tun_write(&mut self, data: TunTaskPayload) {
let (tag, packet) = data;
let Some(dst_addr) = boringtun::noise::Tunn::dst_address(&packet) else {
log::error!("Unable to parse dst_address in packet that was supposed to be written to tun device");
return;
};
let Some(src_addr) = parse_src_address(&packet) else {
log::error!("Unable to parse src_address in packet that was supposed to be written to tun device");
return;
};
log::info!(
"iface: write Packet({src_addr} -> {dst_addr}, {} bytes)",
packet.len()
);
async fn handle_tun_write(&mut self, data: TunTaskPayload) -> Result<(), TunDeviceError> {
{
let (tag, ref packet) = data;
let dst_addr = boringtun::noise::Tunn::dst_address(packet)
.ok_or_else(|| TunDeviceError::UnableToParseDstAdddress)?;
// TODO: expire old entries
if let RoutingMode::Nat(nat_table) = &mut self.routing_mode {
nat_table.nat_table.insert(src_addr, tag);
let src_addr = parse_src_address(packet)?;
log::info!(
"iface: write Packet({src_addr} -> {dst_addr}, {} bytes)",
packet.len()
);
// TODO: expire old entries
if let RoutingMode::Nat(nat_table) = &mut self.routing_mode {
nat_table.nat_table.insert(src_addr, tag);
}
}
tokio::time::timeout(
std::time::Duration::from_millis(1000),
self.tun.write_all(&packet),
)
.await
.tap_err(|err| {
log::error!("iface: write error: {err}");
})
.ok();
// timeout(Duration::from_millis(1000), self.tun.write_all(&data.1))
// .await
// .map_err(|_| TunDeviceError::TunWriteTimeout)?
// .map_err(|err| TunDeviceError::TunWriteError { source: err })
}
// Receive reponse packets from the wild internet
async fn handle_tun_read(&self, packet: &[u8]) {
let Some(dst_addr) = boringtun::noise::Tunn::dst_address(packet) else {
log::error!("Unable to parse dst_address in packet that was read from tun device");
return;
};
let Some(src_addr) = parse_src_address(packet) else {
log::error!("Unable to parse src_address in packet that was read from tun device");
return;
};
async fn handle_tun_read(&self, packet: &[u8]) -> Result<(), TunDeviceError> {
let dst_addr = boringtun::noise::Tunn::dst_address(packet)
.ok_or(TunDeviceError::UnableToParseDstAdddress)?;
let src_addr = parse_src_address(packet)?;
log::info!(
"iface: read Packet({src_addr} -> {dst_addr}, {} bytes)",
packet.len(),
@@ -165,64 +201,72 @@ impl TunDevice {
match self.routing_mode {
// This is how wireguard does it, by consulting the AllowedIPs table.
RoutingMode::AllowedIps(ref peers_by_ip) => {
let Ok(peers) = tokio::time::timeout(
std::time::Duration::from_millis(1000),
peers_by_ip.peers_by_ip.as_ref().lock(),
)
.await
else {
log::error!("Failed to lock peer");
return;
};
let peers = peers_by_ip.lock().await?;
if let Some(peer_tx) = peers.longest_match(dst_addr).map(|(_, tx)| tx) {
log::info!("Forward packet to wg tunnel");
tokio::time::timeout(
std::time::Duration::from_millis(1000),
peer_tx.send(Event::Ip(packet.to_vec().into())),
)
.await
.tap_err(|err| log::error!("Failed to forward packet to wg tunnel: {err}"))
.ok();
return;
return peer_tx
.send(Event::Ip(packet.to_vec().into()))
.await
.map_err(|err| err.into());
}
}
// But we can also do it by consulting the NAT table.
RoutingMode::Nat(ref nat_table) => {
if let Some(tag) = nat_table.nat_table.get(&dst_addr) {
log::info!("Forward packet with tag: {tag}");
tokio::time::timeout(
std::time::Duration::from_millis(1000),
self.tun_task_response_tx.send((*tag, packet.to_vec())),
)
.await
.tap_err(|err| log::error!("Failed to foward packet with tag: {err}"))
.ok();
return;
log::info!("Forward packet with NAT tag: {tag}");
return self
.tun_task_response_tx
.send((*tag, packet.to_vec()))
.await
.map_err(|err| err.into());
}
}
}
log::info!("No peer found, packet dropped");
Ok(())
}
pub async fn run(mut self) {
let mut buf = [0u8; 65535];
let tun_task_rx_stream =
tokio_stream::wrappers::ReceiverStream::new(self.tun_task_rx.take().unwrap().0);
use futures::StreamExt;
let tun_task_rx_stream = tun_task_rx_stream.map(|data| {
//{
// let (tag, ref packet) = data;
// let dst_addr = boringtun::noise::Tunn::dst_address(packet).unwrap();
// // .ok_or_else(|| TunDeviceError::UnableToParseDstAdddress)?;
// let src_addr = parse_src_address(packet).unwrap();
// log::info!(
// "iface: write Packet({src_addr} -> {dst_addr}, {} bytes)",
// packet.len()
// );
// // TODO: expire old entries
// // if let RoutingMode::Nat(nat_table) = &mut self.routing_mode {
// // nat_table.nat_table.insert(src_addr, tag);
// // }
//}
// data.1
4
});
let (mut tun_read, tun_write) = tokio::io::split(self.tun);
loop {
tokio::select! {
// Reading from the TUN device
len = self.tun.read(&mut buf) => match len {
// len = self.tun.read(&mut buf) => match len {
len = tun_read.read(&mut buf) => match len {
Ok(len) => {
let packet = &buf[..len];
tokio::time::timeout(
std::time::Duration::from_millis(1000),
self.handle_tun_read(packet)
)
.await
.tap_err(|_err| log::error!("Failed: handle_tun_read timeout"))
.ok();
if let Err(err) = self.handle_tun_read(packet).await {
log::error!("iface: handle_tun_read failed: {err}")
}
},
Err(err) => {
log::info!("iface: read error: {err}");
@@ -230,15 +274,14 @@ impl TunDevice {
}
},
// Writing to the TUN device
Some(data) = self.tun_task_rx.recv() => {
tokio::time::timeout(
std::time::Duration::from_millis(1000),
self.handle_tun_write(data)
)
.await
.tap_err(|_err| log::error!("Failed: handle_tun_write timeout"))
.ok();
}
//Some(data) = self.tun_task_rx.recv() => {
// if let Err(err) = self.handle_tun_write(data).await {
// log::error!("ifcae: handle_tun_write failed: {err}");
// }
//}
// res = self.tun.send_all(&mut tun_task_rx_stream) => {
// log::error!("finished");
// }
}
}
// log::info!("TUN device shutting down");
@@ -249,12 +292,11 @@ impl TunDevice {
}
}
fn parse_src_address(packet: &[u8]) -> Option<IpAddr> {
let headers = SlicedPacket::from_ip(packet)
.tap_err(|err| log::error!("Unable to parse IP packet: {err:?}"))
.ok()?;
Some(match headers.ip? {
InternetSlice::Ipv4(ip, _) => ip.source_addr().into(),
InternetSlice::Ipv6(ip, _) => ip.source_addr().into(),
})
fn parse_src_address(packet: &[u8]) -> Result<IpAddr, TunDeviceError> {
let headers = SlicedPacket::from_ip(packet)?;
match headers.ip {
Some(InternetSlice::Ipv4(ip, _)) => Ok(ip.source_addr().into()),
Some(InternetSlice::Ipv6(ip, _)) => Ok(ip.source_addr().into()),
None => Err(TunDeviceError::UnableToParseSrcAddressIpHeaderMissing),
}
}
+22 -7
View File
@@ -1,10 +1,17 @@
use tokio::sync::mpsc;
use std::time::Duration;
use tokio::{
sync::mpsc::{self, error::SendError},
time::{error::Elapsed, timeout},
};
pub(crate) type TunTaskPayload = (u64, Vec<u8>);
#[derive(Clone)]
pub struct TunTaskTx(mpsc::Sender<TunTaskPayload>);
pub(crate) struct TunTaskRx(mpsc::Receiver<TunTaskPayload>);
pub(crate) struct TunTaskRx(pub(crate) mpsc::Receiver<TunTaskPayload>);
pub(crate) struct TunTaskRxStream(pub(crate) tokio_stream::wrappers::ReceiverStream<TunTaskPayload>);
impl TunTaskTx {
pub async fn send(
@@ -30,12 +37,20 @@ pub(crate) fn tun_task_channel() -> (TunTaskTx, TunTaskRx) {
pub(crate) struct TunTaskResponseTx(mpsc::Sender<TunTaskPayload>);
pub struct TunTaskResponseRx(mpsc::Receiver<TunTaskPayload>);
#[derive(thiserror::Error, Debug)]
pub enum TunTaskResponseSendError {
#[error("failed to send: timeout")]
Timeout(#[from] Elapsed),
#[error("failed to send: {0}")]
SendError(#[from] SendError<TunTaskPayload>),
}
impl TunTaskResponseTx {
pub(crate) async fn send(
&self,
data: TunTaskPayload,
) -> Result<(), tokio::sync::mpsc::error::SendError<TunTaskPayload>> {
self.0.send(data).await
pub(crate) async fn send(&self, data: TunTaskPayload) -> Result<(), TunTaskResponseSendError> {
timeout(Duration::from_millis(1000), self.0.send(data))
.await?
.map_err(|err| err.into())
}
}