Use LRU cache for virtual DNS

This commit introduces an LRU cache for virtual DNS and fixes a bug
where the virtual DNS feature would cause connection mappings to time
out.
This commit is contained in:
B. Blechschmidt 2023-03-24 13:26:31 +01:00
parent 7dec7f59f1
commit 3c8005e6b7
3 changed files with 62 additions and 26 deletions

View file

@ -253,7 +253,7 @@ pub(crate) trait ConnectionManager {
fn get_credentials(&self) -> &Option<Credentials>;
}
#[derive(Default, Clone, Debug)]
#[derive(Default)]
pub struct Options {
virtdns: Option<VirtualDns>,
}
@ -423,11 +423,12 @@ impl<'a> TunToProxy<'a> {
if let Some((connection, first_packet, _payload_offset, _payload_size)) =
connection_tuple(frame)
{
let resolved_conn = match &self.options.virtdns {
let resolved_conn = match &mut self.options.virtdns {
None => connection.clone(),
Some(virt_dns) => {
let ip = SocketAddr::try_from(connection.dst.clone()).unwrap().ip();
match virt_dns.ip_to_name(&ip) {
virt_dns.touch_ip(&ip);
match virt_dns.resolve_ip(&ip) {
None => connection.clone(),
Some(name) => connection.to_named(name.clone()),
}
@ -564,6 +565,10 @@ impl<'a> TunToProxy<'a> {
{
let socket = self.sockets.get_mut::<tcp::Socket>(socket_handle);
if socket.may_send() {
if let Some(virtdns) = &mut self.options.virtdns {
// Unwrapping is fine because every smoltcp socket is bound to an.
virtdns.touch_ip(&IpAddr::from(socket.local_endpoint().unwrap().addr));
}
consumed = socket.send_slice(event.buffer).unwrap();
state
.handler

View file

@ -1,3 +1,5 @@
use hashlink::linked_hash_map::RawEntryMut;
use hashlink::LruCache;
use smoltcp::wire::Ipv4Cidr;
use std::collections::{HashMap, LinkedList};
use std::convert::{TryFrom, TryInto};
@ -5,8 +7,8 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
use std::time::{Duration, Instant};
const DNS_TTL: u8 = 30; // TTL in DNS replies
const MAPPING_TIMEOUT: u64 = 60; // Mapping timeout
const DNS_TTL: u8 = 30; // TTL in DNS replies in seconds
const MAPPING_TIMEOUT: u64 = 60; // Mapping timeout in seconds
#[derive(Eq, PartialEq, Debug)]
#[allow(dead_code, clippy::upper_case_acronyms)]
@ -21,10 +23,13 @@ enum DnsClass {
IN = 1,
}
#[derive(Clone, Debug)]
struct NameCacheEntry {
name: String,
expiry: Instant,
}
pub struct VirtualDns {
ip_to_name: HashMap<IpAddr, String>,
expiry: LinkedList<(IpAddr, Instant)>,
lru_cache: LruCache<IpAddr, NameCacheEntry>,
name_to_ip: HashMap<String, IpAddr>,
network_addr: IpAddr,
broadcast_addr: IpAddr,
@ -38,11 +43,10 @@ impl Default for VirtualDns {
Self {
next_addr: start_addr.into(),
ip_to_name: Default::default(),
name_to_ip: Default::default(),
expiry: Default::default(),
network_addr: IpAddr::try_from(cidr.network().address().into_address()).unwrap(),
broadcast_addr: IpAddr::try_from(cidr.broadcast().unwrap().into_address()).unwrap(),
lru_cache: LruCache::new_unbounded(),
}
}
}
@ -160,21 +164,43 @@ impl VirtualDns {
}
}
pub fn ip_to_name(&self, addr: &IpAddr) -> Option<&String> {
self.ip_to_name.get(addr)
// This is to be called whenever we receive or send a packet on the socket
// which connects the tun interface to the client, so existing IP address to name
// mappings to not expire as long as the connection is active.
pub fn touch_ip(&mut self, addr: &IpAddr) -> bool {
match self.lru_cache.get_mut(addr) {
None => false,
Some(entry) => {
entry.expiry = Instant::now() + Duration::from_secs(MAPPING_TIMEOUT);
true
}
}
}
pub fn resolve_ip(&mut self, addr: &IpAddr) -> Option<&String> {
match self.lru_cache.get(addr) {
None => None,
Some(entry) => Some(&entry.name),
}
}
fn allocate_ip(&mut self, name: String) -> Option<IpAddr> {
let now = Instant::now();
while let Some((ip, expiry)) = self.expiry.front() {
if now > *expiry {
let name = self.ip_to_name.remove(ip).unwrap();
self.name_to_ip.remove(&name);
self.expiry.pop_front();
} else {
break;
// Building the to_remove list seems to be a bit clunky.
// But removing inside the loop does not immediately work due to borrow rules.
// TODO: Is there a better solution?
let mut to_remove = LinkedList::<IpAddr>::new();
for (ip, entry) in self.lru_cache.iter() {
if now > entry.expiry {
to_remove.push_back(*ip);
self.name_to_ip.remove(&entry.name);
continue;
}
}
for ip in to_remove {
self.lru_cache.remove(&ip);
}
if let Some(ip) = self.name_to_ip.get(&name) {
return Some(*ip);
@ -183,15 +209,19 @@ impl VirtualDns {
let started_at = self.next_addr;
loop {
if let std::collections::hash_map::Entry::Vacant(e) =
self.ip_to_name.entry(self.next_addr)
if let RawEntryMut::Vacant(vacant) =
self.lru_cache.raw_entry_mut().from_key(&self.next_addr)
{
e.insert(name.clone());
self.name_to_ip.insert(name, self.next_addr);
self.expiry.push_back((
let expiry = Instant::now() + Duration::from_secs(MAPPING_TIMEOUT);
vacant.insert(
self.next_addr,
Instant::now() + Duration::from_secs(MAPPING_TIMEOUT),
));
NameCacheEntry {
name: name.clone(),
expiry,
},
);
// e.insert(name.clone());
self.name_to_ip.insert(name, self.next_addr);
return Some(self.next_addr);
}
self.next_addr = Self::increment_ip(self.next_addr);