diff --git a/Cargo.toml b/Cargo.toml index aeea038..e886e6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ base64 = { version = "0.21" } clap = { version = "4.1", features = ["derive"] } dotenvy = "0.15" env_logger = "0.10" +hashlink = "0.8" log = "0.4" mio = { version = "0.8", features = ["os-poll", "net", "os-ext"] } smoltcp = { version = "0.9", features = ["std"] } diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index 64732cd..2e44f2d 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -253,7 +253,7 @@ pub(crate) trait ConnectionManager { fn get_credentials(&self) -> &Option; } -#[derive(Default, Clone, Debug)] +#[derive(Default)] pub struct Options { virtdns: Option, } @@ -423,11 +423,12 @@ impl<'a> TunToProxy<'a> { if let Some((connection, first_packet, _payload_offset, _payload_size)) = connection_tuple(frame) { - let resolved_conn = match &self.options.virtdns { + let resolved_conn = match &mut self.options.virtdns { None => connection.clone(), Some(virt_dns) => { let ip = SocketAddr::try_from(connection.dst.clone()).unwrap().ip(); - match virt_dns.ip_to_name(&ip) { + virt_dns.touch_ip(&ip); + match virt_dns.resolve_ip(&ip) { None => connection.clone(), Some(name) => connection.to_named(name.clone()), } @@ -564,6 +565,10 @@ impl<'a> TunToProxy<'a> { { let socket = self.sockets.get_mut::(socket_handle); if socket.may_send() { + if let Some(virtdns) = &mut self.options.virtdns { + // Unwrapping is fine because every smoltcp socket is bound to an. + virtdns.touch_ip(&IpAddr::from(socket.local_endpoint().unwrap().addr)); + } consumed = socket.send_slice(event.buffer).unwrap(); state .handler diff --git a/src/virtdns.rs b/src/virtdns.rs index 83eb94c..e34c64e 100644 --- a/src/virtdns.rs +++ b/src/virtdns.rs @@ -1,3 +1,5 @@ +use hashlink::linked_hash_map::RawEntryMut; +use hashlink::LruCache; use smoltcp::wire::Ipv4Cidr; use std::collections::{HashMap, LinkedList}; use std::convert::{TryFrom, TryInto}; @@ -5,8 +7,8 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::str::FromStr; use std::time::{Duration, Instant}; -const DNS_TTL: u8 = 30; // TTL in DNS replies -const MAPPING_TIMEOUT: u64 = 60; // Mapping timeout +const DNS_TTL: u8 = 30; // TTL in DNS replies in seconds +const MAPPING_TIMEOUT: u64 = 60; // Mapping timeout in seconds #[derive(Eq, PartialEq, Debug)] #[allow(dead_code, clippy::upper_case_acronyms)] @@ -21,10 +23,13 @@ enum DnsClass { IN = 1, } -#[derive(Clone, Debug)] +struct NameCacheEntry { + name: String, + expiry: Instant, +} + pub struct VirtualDns { - ip_to_name: HashMap, - expiry: LinkedList<(IpAddr, Instant)>, + lru_cache: LruCache, name_to_ip: HashMap, network_addr: IpAddr, broadcast_addr: IpAddr, @@ -38,11 +43,10 @@ impl Default for VirtualDns { Self { next_addr: start_addr.into(), - ip_to_name: Default::default(), name_to_ip: Default::default(), - expiry: Default::default(), network_addr: IpAddr::try_from(cidr.network().address().into_address()).unwrap(), broadcast_addr: IpAddr::try_from(cidr.broadcast().unwrap().into_address()).unwrap(), + lru_cache: LruCache::new_unbounded(), } } } @@ -160,21 +164,43 @@ impl VirtualDns { } } - pub fn ip_to_name(&self, addr: &IpAddr) -> Option<&String> { - self.ip_to_name.get(addr) + // This is to be called whenever we receive or send a packet on the socket + // which connects the tun interface to the client, so existing IP address to name + // mappings to not expire as long as the connection is active. + pub fn touch_ip(&mut self, addr: &IpAddr) -> bool { + match self.lru_cache.get_mut(addr) { + None => false, + Some(entry) => { + entry.expiry = Instant::now() + Duration::from_secs(MAPPING_TIMEOUT); + true + } + } + } + + pub fn resolve_ip(&mut self, addr: &IpAddr) -> Option<&String> { + match self.lru_cache.get(addr) { + None => None, + Some(entry) => Some(&entry.name), + } } fn allocate_ip(&mut self, name: String) -> Option { let now = Instant::now(); - while let Some((ip, expiry)) = self.expiry.front() { - if now > *expiry { - let name = self.ip_to_name.remove(ip).unwrap(); - self.name_to_ip.remove(&name); - self.expiry.pop_front(); - } else { - break; + + // Building the to_remove list seems to be a bit clunky. + // But removing inside the loop does not immediately work due to borrow rules. + // TODO: Is there a better solution? + let mut to_remove = LinkedList::::new(); + for (ip, entry) in self.lru_cache.iter() { + if now > entry.expiry { + to_remove.push_back(*ip); + self.name_to_ip.remove(&entry.name); + continue; } } + for ip in to_remove { + self.lru_cache.remove(&ip); + } if let Some(ip) = self.name_to_ip.get(&name) { return Some(*ip); @@ -183,15 +209,19 @@ impl VirtualDns { let started_at = self.next_addr; loop { - if let std::collections::hash_map::Entry::Vacant(e) = - self.ip_to_name.entry(self.next_addr) + if let RawEntryMut::Vacant(vacant) = + self.lru_cache.raw_entry_mut().from_key(&self.next_addr) { - e.insert(name.clone()); - self.name_to_ip.insert(name, self.next_addr); - self.expiry.push_back(( + let expiry = Instant::now() + Duration::from_secs(MAPPING_TIMEOUT); + vacant.insert( self.next_addr, - Instant::now() + Duration::from_secs(MAPPING_TIMEOUT), - )); + NameCacheEntry { + name: name.clone(), + expiry, + }, + ); + // e.insert(name.clone()); + self.name_to_ip.insert(name, self.next_addr); return Some(self.next_addr); } self.next_addr = Self::increment_ip(self.next_addr);