Use LRU cache for virtual DNS

This commit introduces an LRU cache for virtual DNS and fixes a bug
where the virtual DNS feature would cause connection mappings to time
out.
This commit is contained in:
B. Blechschmidt 2023-03-24 13:26:31 +01:00
parent 7dec7f59f1
commit 3c8005e6b7
3 changed files with 62 additions and 26 deletions

View file

@ -9,6 +9,7 @@ base64 = { version = "0.21" }
clap = { version = "4.1", features = ["derive"] } clap = { version = "4.1", features = ["derive"] }
dotenvy = "0.15" dotenvy = "0.15"
env_logger = "0.10" env_logger = "0.10"
hashlink = "0.8"
log = "0.4" log = "0.4"
mio = { version = "0.8", features = ["os-poll", "net", "os-ext"] } mio = { version = "0.8", features = ["os-poll", "net", "os-ext"] }
smoltcp = { version = "0.9", features = ["std"] } smoltcp = { version = "0.9", features = ["std"] }

View file

@ -253,7 +253,7 @@ pub(crate) trait ConnectionManager {
fn get_credentials(&self) -> &Option<Credentials>; fn get_credentials(&self) -> &Option<Credentials>;
} }
#[derive(Default, Clone, Debug)] #[derive(Default)]
pub struct Options { pub struct Options {
virtdns: Option<VirtualDns>, virtdns: Option<VirtualDns>,
} }
@ -423,11 +423,12 @@ impl<'a> TunToProxy<'a> {
if let Some((connection, first_packet, _payload_offset, _payload_size)) = if let Some((connection, first_packet, _payload_offset, _payload_size)) =
connection_tuple(frame) connection_tuple(frame)
{ {
let resolved_conn = match &self.options.virtdns { let resolved_conn = match &mut self.options.virtdns {
None => connection.clone(), None => connection.clone(),
Some(virt_dns) => { Some(virt_dns) => {
let ip = SocketAddr::try_from(connection.dst.clone()).unwrap().ip(); let ip = SocketAddr::try_from(connection.dst.clone()).unwrap().ip();
match virt_dns.ip_to_name(&ip) { virt_dns.touch_ip(&ip);
match virt_dns.resolve_ip(&ip) {
None => connection.clone(), None => connection.clone(),
Some(name) => connection.to_named(name.clone()), Some(name) => connection.to_named(name.clone()),
} }
@ -564,6 +565,10 @@ impl<'a> TunToProxy<'a> {
{ {
let socket = self.sockets.get_mut::<tcp::Socket>(socket_handle); let socket = self.sockets.get_mut::<tcp::Socket>(socket_handle);
if socket.may_send() { if socket.may_send() {
if let Some(virtdns) = &mut self.options.virtdns {
// Unwrapping is fine because every smoltcp socket is bound to an.
virtdns.touch_ip(&IpAddr::from(socket.local_endpoint().unwrap().addr));
}
consumed = socket.send_slice(event.buffer).unwrap(); consumed = socket.send_slice(event.buffer).unwrap();
state state
.handler .handler

View file

@ -1,3 +1,5 @@
use hashlink::linked_hash_map::RawEntryMut;
use hashlink::LruCache;
use smoltcp::wire::Ipv4Cidr; use smoltcp::wire::Ipv4Cidr;
use std::collections::{HashMap, LinkedList}; use std::collections::{HashMap, LinkedList};
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
@ -5,8 +7,8 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::str::FromStr; use std::str::FromStr;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
const DNS_TTL: u8 = 30; // TTL in DNS replies const DNS_TTL: u8 = 30; // TTL in DNS replies in seconds
const MAPPING_TIMEOUT: u64 = 60; // Mapping timeout const MAPPING_TIMEOUT: u64 = 60; // Mapping timeout in seconds
#[derive(Eq, PartialEq, Debug)] #[derive(Eq, PartialEq, Debug)]
#[allow(dead_code, clippy::upper_case_acronyms)] #[allow(dead_code, clippy::upper_case_acronyms)]
@ -21,10 +23,13 @@ enum DnsClass {
IN = 1, IN = 1,
} }
#[derive(Clone, Debug)] struct NameCacheEntry {
name: String,
expiry: Instant,
}
pub struct VirtualDns { pub struct VirtualDns {
ip_to_name: HashMap<IpAddr, String>, lru_cache: LruCache<IpAddr, NameCacheEntry>,
expiry: LinkedList<(IpAddr, Instant)>,
name_to_ip: HashMap<String, IpAddr>, name_to_ip: HashMap<String, IpAddr>,
network_addr: IpAddr, network_addr: IpAddr,
broadcast_addr: IpAddr, broadcast_addr: IpAddr,
@ -38,11 +43,10 @@ impl Default for VirtualDns {
Self { Self {
next_addr: start_addr.into(), next_addr: start_addr.into(),
ip_to_name: Default::default(),
name_to_ip: Default::default(), name_to_ip: Default::default(),
expiry: Default::default(),
network_addr: IpAddr::try_from(cidr.network().address().into_address()).unwrap(), network_addr: IpAddr::try_from(cidr.network().address().into_address()).unwrap(),
broadcast_addr: IpAddr::try_from(cidr.broadcast().unwrap().into_address()).unwrap(), broadcast_addr: IpAddr::try_from(cidr.broadcast().unwrap().into_address()).unwrap(),
lru_cache: LruCache::new_unbounded(),
} }
} }
} }
@ -160,21 +164,43 @@ impl VirtualDns {
} }
} }
pub fn ip_to_name(&self, addr: &IpAddr) -> Option<&String> { // This is to be called whenever we receive or send a packet on the socket
self.ip_to_name.get(addr) // which connects the tun interface to the client, so existing IP address to name
// mappings to not expire as long as the connection is active.
pub fn touch_ip(&mut self, addr: &IpAddr) -> bool {
match self.lru_cache.get_mut(addr) {
None => false,
Some(entry) => {
entry.expiry = Instant::now() + Duration::from_secs(MAPPING_TIMEOUT);
true
}
}
}
pub fn resolve_ip(&mut self, addr: &IpAddr) -> Option<&String> {
match self.lru_cache.get(addr) {
None => None,
Some(entry) => Some(&entry.name),
}
} }
fn allocate_ip(&mut self, name: String) -> Option<IpAddr> { fn allocate_ip(&mut self, name: String) -> Option<IpAddr> {
let now = Instant::now(); let now = Instant::now();
while let Some((ip, expiry)) = self.expiry.front() {
if now > *expiry { // Building the to_remove list seems to be a bit clunky.
let name = self.ip_to_name.remove(ip).unwrap(); // But removing inside the loop does not immediately work due to borrow rules.
self.name_to_ip.remove(&name); // TODO: Is there a better solution?
self.expiry.pop_front(); let mut to_remove = LinkedList::<IpAddr>::new();
} else { for (ip, entry) in self.lru_cache.iter() {
break; if now > entry.expiry {
to_remove.push_back(*ip);
self.name_to_ip.remove(&entry.name);
continue;
} }
} }
for ip in to_remove {
self.lru_cache.remove(&ip);
}
if let Some(ip) = self.name_to_ip.get(&name) { if let Some(ip) = self.name_to_ip.get(&name) {
return Some(*ip); return Some(*ip);
@ -183,15 +209,19 @@ impl VirtualDns {
let started_at = self.next_addr; let started_at = self.next_addr;
loop { loop {
if let std::collections::hash_map::Entry::Vacant(e) = if let RawEntryMut::Vacant(vacant) =
self.ip_to_name.entry(self.next_addr) self.lru_cache.raw_entry_mut().from_key(&self.next_addr)
{ {
e.insert(name.clone()); let expiry = Instant::now() + Duration::from_secs(MAPPING_TIMEOUT);
self.name_to_ip.insert(name, self.next_addr); vacant.insert(
self.expiry.push_back((
self.next_addr, self.next_addr,
Instant::now() + Duration::from_secs(MAPPING_TIMEOUT), NameCacheEntry {
)); name: name.clone(),
expiry,
},
);
// e.insert(name.clone());
self.name_to_ip.insert(name, self.next_addr);
return Some(self.next_addr); return Some(self.next_addr);
} }
self.next_addr = Self::increment_ip(self.next_addr); self.next_addr = Self::increment_ip(self.next_addr);