Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8cd9b99d72 | |||
| f7093cdc5a |
Generated
+1
@@ -7593,6 +7593,7 @@ dependencies = [
|
||||
"tap",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-tun",
|
||||
]
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ serde = { workspace = true, features = ["derive"] }
|
||||
tap.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "net", "io-util"] }
|
||||
tokio-stream = { version = "0.1.11" }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
tokio-tun = "0.9.0"
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::{net::SocketAddr, time::Duration};
|
||||
|
||||
use boringtun::x25519;
|
||||
use dashmap::{
|
||||
mapref::one::{Ref, RefMut},
|
||||
DashMap,
|
||||
};
|
||||
use tokio::sync::mpsc::{self};
|
||||
use tokio::{
|
||||
sync::mpsc::{self},
|
||||
time::{error::Elapsed, timeout},
|
||||
};
|
||||
|
||||
use crate::event::Event;
|
||||
|
||||
@@ -14,9 +17,26 @@ use crate::event::Event;
|
||||
pub struct PeerEventSender(mpsc::Sender<Event>);
|
||||
pub(crate) struct PeerEventReceiver(mpsc::Receiver<Event>);
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum PeerEventSenderError {
|
||||
#[error("timeout")]
|
||||
Timeout {
|
||||
#[from]
|
||||
source: Elapsed,
|
||||
},
|
||||
|
||||
#[error("send failed: {source}")]
|
||||
SendError {
|
||||
#[from]
|
||||
source: mpsc::error::SendError<Event>,
|
||||
},
|
||||
}
|
||||
|
||||
impl PeerEventSender {
|
||||
pub(crate) async fn send(&self, event: Event) -> Result<(), mpsc::error::SendError<Event>> {
|
||||
self.0.send(event).await
|
||||
pub(crate) async fn send(&self, event: Event) -> Result<(), PeerEventSenderError> {
|
||||
timeout(Duration::from_millis(1000), self.0.send(event))
|
||||
.await?
|
||||
.map_err(|err| err.into())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,21 +2,61 @@ use std::{
|
||||
collections::HashMap,
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use etherparse::{InternetSlice, SlicedPacket};
|
||||
use tap::TapFallible;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
time::timeout,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
active_peers::PeerEventSenderError,
|
||||
event::Event,
|
||||
tun_task_channel::{
|
||||
tun_task_channel, tun_task_response_channel, TunTaskPayload, TunTaskResponseRx,
|
||||
TunTaskResponseTx, TunTaskRx, TunTaskTx,
|
||||
TunTaskResponseSendError, TunTaskResponseTx, TunTaskRx, TunTaskTx,
|
||||
},
|
||||
udp_listener::PeersByIp,
|
||||
};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum TunDeviceError {
|
||||
#[error("iface: timeout writing to tun device, dropping packet")]
|
||||
TunWriteTimeout,
|
||||
|
||||
#[error("iface: failed forwarding packet to peer: {source}")]
|
||||
ForwardToPeerFailed {
|
||||
#[from]
|
||||
source: PeerEventSenderError,
|
||||
},
|
||||
|
||||
#[error("iface: failed to forward responding packet with tag: {source}")]
|
||||
ForwardNatResponseFailed {
|
||||
#[from]
|
||||
source: TunTaskResponseSendError,
|
||||
},
|
||||
|
||||
#[error("iface: error writing to tun device: {source}")]
|
||||
TunWriteError { source: std::io::Error },
|
||||
|
||||
#[error("unable to parse destination address from packet")]
|
||||
UnableToParseDstAdddress,
|
||||
|
||||
#[error("unable to parse source address from packet")]
|
||||
UnableToParseSrcAddress {
|
||||
#[from]
|
||||
source: etherparse::ReadError,
|
||||
},
|
||||
|
||||
#[error("unable to parse source address from packet: ip header missing")]
|
||||
UnableToParseSrcAddressIpHeaderMissing,
|
||||
|
||||
#[error("unable to lock peer mutex")]
|
||||
FailedToLockPeer,
|
||||
}
|
||||
|
||||
fn setup_tokio_tun_device(name: &str, address: Ipv4Addr, netmask: Ipv4Addr) -> tokio_tun::Tun {
|
||||
log::info!("Creating TUN device with: address={address}, netmask={netmask}");
|
||||
// Read MTU size from env variable NYM_MTU_SIZE, else default to 1420.
|
||||
@@ -38,10 +78,10 @@ fn setup_tokio_tun_device(name: &str, address: Ipv4Addr, netmask: Ipv4Addr) -> t
|
||||
|
||||
pub struct TunDevice {
|
||||
// The TUN device that we read/write to, to send/receive packets
|
||||
tun: tokio_tun::Tun,
|
||||
tun: Option<tokio_tun::Tun>,
|
||||
|
||||
// Incoming data that we should send
|
||||
tun_task_rx: TunTaskRx,
|
||||
tun_task_rx: Option<TunTaskRx>,
|
||||
|
||||
// And when we get replies, this is where we should send it
|
||||
tun_task_response_tx: TunTaskResponseTx,
|
||||
@@ -74,6 +114,14 @@ pub struct AllowedIpsInner {
|
||||
peers_by_ip: Arc<tokio::sync::Mutex<PeersByIp>>,
|
||||
}
|
||||
|
||||
impl AllowedIpsInner {
|
||||
async fn lock(&self) -> Result<tokio::sync::MutexGuard<PeersByIp>, TunDeviceError> {
|
||||
timeout(Duration::from_millis(200), self.peers_by_ip.as_ref().lock())
|
||||
.await
|
||||
.map_err(|_| TunDeviceError::FailedToLockPeer)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NatInner {
|
||||
nat_table: HashMap<IpAddr, u64>,
|
||||
}
|
||||
@@ -104,9 +152,9 @@ impl TunDevice {
|
||||
let (tun_task_response_tx, tun_task_response_rx) = tun_task_response_channel();
|
||||
|
||||
let tun_device = TunDevice {
|
||||
tun_task_rx,
|
||||
tun_task_rx: Some(tun_task_rx),
|
||||
tun_task_response_tx,
|
||||
tun,
|
||||
tun: Some(tun),
|
||||
routing_mode,
|
||||
};
|
||||
|
||||
@@ -114,47 +162,35 @@ impl TunDevice {
|
||||
}
|
||||
|
||||
// Send outbound packets out on the wild internet
|
||||
async fn handle_tun_write(&mut self, data: TunTaskPayload) {
|
||||
let (tag, packet) = data;
|
||||
let Some(dst_addr) = boringtun::noise::Tunn::dst_address(&packet) else {
|
||||
log::error!("Unable to parse dst_address in packet that was supposed to be written to tun device");
|
||||
return;
|
||||
};
|
||||
let Some(src_addr) = parse_src_address(&packet) else {
|
||||
log::error!("Unable to parse src_address in packet that was supposed to be written to tun device");
|
||||
return;
|
||||
};
|
||||
log::info!(
|
||||
"iface: write Packet({src_addr} -> {dst_addr}, {} bytes)",
|
||||
packet.len()
|
||||
);
|
||||
async fn handle_tun_write(&mut self, data: TunTaskPayload) -> Result<(), TunDeviceError> {
|
||||
{
|
||||
let (tag, ref packet) = data;
|
||||
let dst_addr = boringtun::noise::Tunn::dst_address(packet)
|
||||
.ok_or_else(|| TunDeviceError::UnableToParseDstAdddress)?;
|
||||
|
||||
// TODO: expire old entries
|
||||
if let RoutingMode::Nat(nat_table) = &mut self.routing_mode {
|
||||
nat_table.nat_table.insert(src_addr, tag);
|
||||
let src_addr = parse_src_address(packet)?;
|
||||
log::info!(
|
||||
"iface: write Packet({src_addr} -> {dst_addr}, {} bytes)",
|
||||
packet.len()
|
||||
);
|
||||
|
||||
// TODO: expire old entries
|
||||
if let RoutingMode::Nat(nat_table) = &mut self.routing_mode {
|
||||
nat_table.nat_table.insert(src_addr, tag);
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
self.tun.write_all(&packet),
|
||||
)
|
||||
.await
|
||||
.tap_err(|err| {
|
||||
log::error!("iface: write error: {err}");
|
||||
})
|
||||
.ok();
|
||||
// timeout(Duration::from_millis(1000), self.tun.write_all(&data.1))
|
||||
// .await
|
||||
// .map_err(|_| TunDeviceError::TunWriteTimeout)?
|
||||
// .map_err(|err| TunDeviceError::TunWriteError { source: err })
|
||||
}
|
||||
|
||||
// Receive reponse packets from the wild internet
|
||||
async fn handle_tun_read(&self, packet: &[u8]) {
|
||||
let Some(dst_addr) = boringtun::noise::Tunn::dst_address(packet) else {
|
||||
log::error!("Unable to parse dst_address in packet that was read from tun device");
|
||||
return;
|
||||
};
|
||||
let Some(src_addr) = parse_src_address(packet) else {
|
||||
log::error!("Unable to parse src_address in packet that was read from tun device");
|
||||
return;
|
||||
};
|
||||
async fn handle_tun_read(&self, packet: &[u8]) -> Result<(), TunDeviceError> {
|
||||
let dst_addr = boringtun::noise::Tunn::dst_address(packet)
|
||||
.ok_or(TunDeviceError::UnableToParseDstAdddress)?;
|
||||
let src_addr = parse_src_address(packet)?;
|
||||
log::info!(
|
||||
"iface: read Packet({src_addr} -> {dst_addr}, {} bytes)",
|
||||
packet.len(),
|
||||
@@ -165,64 +201,72 @@ impl TunDevice {
|
||||
match self.routing_mode {
|
||||
// This is how wireguard does it, by consulting the AllowedIPs table.
|
||||
RoutingMode::AllowedIps(ref peers_by_ip) => {
|
||||
let Ok(peers) = tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
peers_by_ip.peers_by_ip.as_ref().lock(),
|
||||
)
|
||||
.await
|
||||
else {
|
||||
log::error!("Failed to lock peer");
|
||||
return;
|
||||
};
|
||||
|
||||
let peers = peers_by_ip.lock().await?;
|
||||
if let Some(peer_tx) = peers.longest_match(dst_addr).map(|(_, tx)| tx) {
|
||||
log::info!("Forward packet to wg tunnel");
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
peer_tx.send(Event::Ip(packet.to_vec().into())),
|
||||
)
|
||||
.await
|
||||
.tap_err(|err| log::error!("Failed to forward packet to wg tunnel: {err}"))
|
||||
.ok();
|
||||
return;
|
||||
return peer_tx
|
||||
.send(Event::Ip(packet.to_vec().into()))
|
||||
.await
|
||||
.map_err(|err| err.into());
|
||||
}
|
||||
}
|
||||
|
||||
// But we can also do it by consulting the NAT table.
|
||||
RoutingMode::Nat(ref nat_table) => {
|
||||
if let Some(tag) = nat_table.nat_table.get(&dst_addr) {
|
||||
log::info!("Forward packet with tag: {tag}");
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
self.tun_task_response_tx.send((*tag, packet.to_vec())),
|
||||
)
|
||||
.await
|
||||
.tap_err(|err| log::error!("Failed to foward packet with tag: {err}"))
|
||||
.ok();
|
||||
return;
|
||||
log::info!("Forward packet with NAT tag: {tag}");
|
||||
return self
|
||||
.tun_task_response_tx
|
||||
.send((*tag, packet.to_vec()))
|
||||
.await
|
||||
.map_err(|err| err.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("No peer found, packet dropped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(mut self) {
|
||||
let mut buf = [0u8; 65535];
|
||||
|
||||
let tun_task_rx_stream =
|
||||
tokio_stream::wrappers::ReceiverStream::new(self.tun_task_rx.take().unwrap().0);
|
||||
use futures::StreamExt;
|
||||
let tun_task_rx_stream = tun_task_rx_stream.map(|data| {
|
||||
//{
|
||||
// let (tag, ref packet) = data;
|
||||
// let dst_addr = boringtun::noise::Tunn::dst_address(packet).unwrap();
|
||||
// // .ok_or_else(|| TunDeviceError::UnableToParseDstAdddress)?;
|
||||
|
||||
// let src_addr = parse_src_address(packet).unwrap();
|
||||
// log::info!(
|
||||
// "iface: write Packet({src_addr} -> {dst_addr}, {} bytes)",
|
||||
// packet.len()
|
||||
// );
|
||||
|
||||
// // TODO: expire old entries
|
||||
// // if let RoutingMode::Nat(nat_table) = &mut self.routing_mode {
|
||||
// // nat_table.nat_table.insert(src_addr, tag);
|
||||
// // }
|
||||
//}
|
||||
// data.1
|
||||
4
|
||||
});
|
||||
|
||||
let (mut tun_read, tun_write) = tokio::io::split(self.tun);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Reading from the TUN device
|
||||
len = self.tun.read(&mut buf) => match len {
|
||||
// len = self.tun.read(&mut buf) => match len {
|
||||
len = tun_read.read(&mut buf) => match len {
|
||||
Ok(len) => {
|
||||
let packet = &buf[..len];
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
self.handle_tun_read(packet)
|
||||
)
|
||||
.await
|
||||
.tap_err(|_err| log::error!("Failed: handle_tun_read timeout"))
|
||||
.ok();
|
||||
if let Err(err) = self.handle_tun_read(packet).await {
|
||||
log::error!("iface: handle_tun_read failed: {err}")
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
log::info!("iface: read error: {err}");
|
||||
@@ -230,15 +274,14 @@ impl TunDevice {
|
||||
}
|
||||
},
|
||||
// Writing to the TUN device
|
||||
Some(data) = self.tun_task_rx.recv() => {
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
self.handle_tun_write(data)
|
||||
)
|
||||
.await
|
||||
.tap_err(|_err| log::error!("Failed: handle_tun_write timeout"))
|
||||
.ok();
|
||||
}
|
||||
//Some(data) = self.tun_task_rx.recv() => {
|
||||
// if let Err(err) = self.handle_tun_write(data).await {
|
||||
// log::error!("ifcae: handle_tun_write failed: {err}");
|
||||
// }
|
||||
//}
|
||||
// res = self.tun.send_all(&mut tun_task_rx_stream) => {
|
||||
// log::error!("finished");
|
||||
// }
|
||||
}
|
||||
}
|
||||
// log::info!("TUN device shutting down");
|
||||
@@ -249,12 +292,11 @@ impl TunDevice {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_src_address(packet: &[u8]) -> Option<IpAddr> {
|
||||
let headers = SlicedPacket::from_ip(packet)
|
||||
.tap_err(|err| log::error!("Unable to parse IP packet: {err:?}"))
|
||||
.ok()?;
|
||||
Some(match headers.ip? {
|
||||
InternetSlice::Ipv4(ip, _) => ip.source_addr().into(),
|
||||
InternetSlice::Ipv6(ip, _) => ip.source_addr().into(),
|
||||
})
|
||||
fn parse_src_address(packet: &[u8]) -> Result<IpAddr, TunDeviceError> {
|
||||
let headers = SlicedPacket::from_ip(packet)?;
|
||||
match headers.ip {
|
||||
Some(InternetSlice::Ipv4(ip, _)) => Ok(ip.source_addr().into()),
|
||||
Some(InternetSlice::Ipv6(ip, _)) => Ok(ip.source_addr().into()),
|
||||
None => Err(TunDeviceError::UnableToParseSrcAddressIpHeaderMissing),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
use tokio::sync::mpsc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::{
|
||||
sync::mpsc::{self, error::SendError},
|
||||
time::{error::Elapsed, timeout},
|
||||
};
|
||||
|
||||
pub(crate) type TunTaskPayload = (u64, Vec<u8>);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TunTaskTx(mpsc::Sender<TunTaskPayload>);
|
||||
pub(crate) struct TunTaskRx(mpsc::Receiver<TunTaskPayload>);
|
||||
pub(crate) struct TunTaskRx(pub(crate) mpsc::Receiver<TunTaskPayload>);
|
||||
|
||||
pub(crate) struct TunTaskRxStream(pub(crate) tokio_stream::wrappers::ReceiverStream<TunTaskPayload>);
|
||||
|
||||
impl TunTaskTx {
|
||||
pub async fn send(
|
||||
@@ -30,12 +37,20 @@ pub(crate) fn tun_task_channel() -> (TunTaskTx, TunTaskRx) {
|
||||
pub(crate) struct TunTaskResponseTx(mpsc::Sender<TunTaskPayload>);
|
||||
pub struct TunTaskResponseRx(mpsc::Receiver<TunTaskPayload>);
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum TunTaskResponseSendError {
|
||||
#[error("failed to send: timeout")]
|
||||
Timeout(#[from] Elapsed),
|
||||
|
||||
#[error("failed to send: {0}")]
|
||||
SendError(#[from] SendError<TunTaskPayload>),
|
||||
}
|
||||
|
||||
impl TunTaskResponseTx {
|
||||
pub(crate) async fn send(
|
||||
&self,
|
||||
data: TunTaskPayload,
|
||||
) -> Result<(), tokio::sync::mpsc::error::SendError<TunTaskPayload>> {
|
||||
self.0.send(data).await
|
||||
pub(crate) async fn send(&self, data: TunTaskPayload) -> Result<(), TunTaskResponseSendError> {
|
||||
timeout(Duration::from_millis(1000), self.0.send(data))
|
||||
.await?
|
||||
.map_err(|err| err.into())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user