diff --git a/Cargo.toml b/Cargo.toml index 25a01d5..7752eee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ prctl = "1.0" smoltcp = { version = "0.10.0", features = ["std", "phy-tuntap_interface"] } socks5-impl = { version = "0.5", default-features = false } thiserror = "1.0" +trust-dns-proto = "0.22" unicase = "2.6.0" url = "2.4" diff --git a/src/dns.rs b/src/dns.rs new file mode 100644 index 0000000..f4b2404 --- /dev/null +++ b/src/dns.rs @@ -0,0 +1,104 @@ +#![allow(dead_code)] + +use std::{net::IpAddr, str::FromStr}; +use trust_dns_proto::{ + op::{Message, ResponseCode}, + rr::{record_type::RecordType, Name, RData, Record}, +}; + +#[cfg(feature = "use-rand")] +pub fn build_dns_request( + domain: &str, + query_type: RecordType, + used_by_tcp: bool, +) -> Result, String> { + // [dependencies] + // rand = "0.8" + use rand::{rngs::StdRng, Rng, SeedableRng}; + use trust_dns_proto::op::{header::MessageType, op_code::OpCode, query::Query}; + let name = Name::from_str(domain).map_err(|e| e.to_string())?; + let query = Query::query(name, query_type); + let mut msg = Message::new(); + msg.add_query(query) + .set_id(StdRng::from_entropy().gen()) + .set_op_code(OpCode::Query) + .set_message_type(MessageType::Query) + .set_recursion_desired(true); + let mut msg_buf = msg.to_vec().map_err(|e| e.to_string())?; + if used_by_tcp { + let mut buf = (msg_buf.len() as u16).to_be_bytes().to_vec(); + buf.append(&mut msg_buf); + Ok(buf) + } else { + Ok(msg_buf) + } +} + +pub fn build_dns_response( + mut request: Message, + domain: &str, + ip: IpAddr, + ttl: u32, +) -> Result { + let record = match ip { + IpAddr::V4(ip) => { + let mut record = Record::with(Name::from_str(domain)?, RecordType::A, ttl); + record.set_data(Some(RData::A(ip))); + record + } + IpAddr::V6(ip) => { + let mut record = Record::with(Name::from_str(domain)?, RecordType::AAAA, ttl); + record.set_data(Some(RData::AAAA(ip))); + record + } + }; + request.add_answer(record); + Ok(request) +} + +pub fn extract_ipaddr_from_dns_message(message: &Message) -> Result { + if message.response_code() != ResponseCode::NoError { + return Err(format!("{:?}", message.response_code())); + } + let mut cname = None; + for answer in message.answers() { + match answer + .data() + .ok_or("DNS response not contains answer data")? + { + RData::A(addr) => { + return Ok(IpAddr::V4(*addr)); + } + RData::AAAA(addr) => { + return Ok(IpAddr::V6(*addr)); + } + RData::CNAME(name) => { + cname = Some(name.to_utf8()); + } + _ => {} + } + } + if let Some(cname) = cname { + return Err(cname); + } + Err(format!("{:?}", message.answers())) +} + +pub fn extract_domain_from_dns_message(message: &Message) -> Result { + let query = message.queries().get(0).ok_or("DnsRequest no query body")?; + let name = query.name().to_string(); + Ok(name) +} + +pub fn parse_data_to_dns_message(data: &[u8], used_by_tcp: bool) -> Result { + if used_by_tcp { + if data.len() < 2 { + return Err("invalid dns data".into()); + } + let len = u16::from_be_bytes([data[0], data[1]]) as usize; + let data = data.get(2..len + 2).ok_or("invalid dns data")?; + return parse_data_to_dns_message(data, false); + } + let message = Message::from_vec(data).map_err(|e| e.to_string())?; + Ok(message) +} diff --git a/src/error.rs b/src/error.rs index 1a57783..b0c916d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -33,6 +33,12 @@ pub enum Error { #[error("std::str::Utf8Error {0:?}")] Utf8(#[from] std::str::Utf8Error), + #[error("TryFromSliceError {0:?}")] + TryFromSlice(#[from] std::array::TryFromSliceError), + + #[error("ProtoError {0:?}")] + ProtoError(#[from] trust_dns_proto::error::ProtoError), + #[cfg(target_os = "android")] #[error("jni::errors::Error {0:?}")] Jni(#[from] jni::errors::Error), diff --git a/src/lib.rs b/src/lib.rs index 98f69c6..9588a4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,17 @@ -use crate::{error::Error, http::HttpManager, socks::SocksProxyManager, tun2proxy::TunToProxy}; +use crate::{ + error::Error, + http::HttpManager, + socks::SocksProxyManager, + tun2proxy::{ConnectionManager, TunToProxy}, +}; use socks5_impl::protocol::{UserKey, Version}; use std::{ net::{SocketAddr, ToSocketAddrs}, rc::Rc, }; -use tun2proxy::ConnectionManager; mod android; +mod dns; pub mod error; mod http; pub mod setup; diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index 29675e2..8e99aff 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -508,7 +508,8 @@ impl<'a> TunToProxy<'a> { let port = connection_info.dst.port(); if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == 53) { let payload = &frame[payload_offset..payload_offset + payload_size]; - if let Some(response) = virtual_dns.receive_query(payload) { + let response = virtual_dns.receive_query(payload)?; + { let rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); let tx_buffer = diff --git a/src/virtdns.rs b/src/virtdns.rs index 1da5c54..9e13e05 100644 --- a/src/virtdns.rs +++ b/src/virtdns.rs @@ -1,3 +1,4 @@ +use crate::error::Result; use hashlink::{linked_hash_map::RawEntryMut, LruCache}; use smoltcp::wire::Ipv4Cidr; use std::{ @@ -8,22 +9,8 @@ use std::{ time::{Duration, Instant}, }; -const DNS_TTL: u8 = 30; // TTL in DNS replies in seconds const MAPPING_TIMEOUT: u64 = 60; // Mapping timeout in seconds -#[derive(Eq, PartialEq, Debug)] -#[allow(dead_code, clippy::upper_case_acronyms)] -enum DnsRecordType { - A = 1, - AAAA = 28, -} - -#[derive(Eq, PartialEq, Debug)] -#[allow(dead_code)] -enum DnsClass { - IN = 1, -} - struct NameCacheEntry { name: String, expiry: Instant, @@ -57,89 +44,16 @@ impl VirtualDns { VirtualDns::default() } - pub fn receive_query(&mut self, data: &[u8]) -> Option> { - if data.len() < 17 { - return None; - } - // bit 1: Message is a query (0) - // bits 2 - 5: Standard query opcode (0) - // bit 6: Unused - // bit 7: Message is not truncated (0) - // bit 8: Recursion desired (1) - let is_supported_query = (data[2] & 0b11111011) == 0b00000001; - let num_queries = u16::from_be_bytes(data[4..6].try_into().ok()?); - if !is_supported_query || num_queries != 1 { - return None; - } - - let (qname, offset) = VirtualDns::parse_qname(data, 12)?; - if offset + 3 >= data.len() { - return None; - } - let qtype = u16::from_be_bytes(data[offset..offset + 2].try_into().ok()?); - let qclass = u16::from_be_bytes(data[offset + 2..offset + 4].try_into().ok()?); - - if qtype != DnsRecordType::A as u16 && qtype != DnsRecordType::AAAA as u16 - || qclass != DnsClass::IN as u16 - { - return None; - } - - if qtype == DnsRecordType::A as u16 { - log::info!("DNS query: {}", qname); - } - - let mut response = Vec::::new(); - response.extend(&data[0..offset + 4]); - response[2] |= 0x80; // Message is a response - response[3] |= 0x80; // Recursion available - - // Record count of the answer section: - // We only send an answer record for A queries, assuming that IPv4 is supported everywhere. - // This way, we do not have to handle two IP spaces for the virtual DNS feature. - response[6] = 0; - response[7] = if qtype == DnsRecordType::A as u16 { - 1 - } else { - 0 - }; - - // Zero count of other sections: - // authority section - response[8] = 0; - response[9] = 0; - - // additional section - response[10] = 0; - response[11] = 0; - if qtype == DnsRecordType::A as u16 { - if let Some(ip) = self.allocate_ip(qname) { - response.extend(&[ - 0xc0, 0x0c, // Question name pointer - 0, 1, // Record type: A - 0, 1, // Class: IN - 0, 0, 0, DNS_TTL, // TTL - 0, 4, // Data length: 4 bytes - ]); - match ip { - IpAddr::V4(ip) => response.extend(ip.octets().as_ref()), - IpAddr::V6(ip) => response.extend(ip.octets().as_ref()), - }; - } else { - log::error!("Virtual IP space for DNS exhausted"); - response[7] = 0; // No answers - - // Set rcode to SERVFAIL - response[3] &= 0xf0; - response[3] |= 2; - } - } else { - response[7] = 0; // No answers - } - Some(response) + pub fn receive_query(&mut self, data: &[u8]) -> Result> { + use crate::dns; + let message = dns::parse_data_to_dns_message(data, false)?; + let qname = dns::extract_domain_from_dns_message(&message)?; + let ip = self.allocate_ip(qname.clone())?; + let message = dns::build_dns_response(message, &qname, ip, 5)?; + Ok(message.to_vec()?) } - fn increment_ip(addr: IpAddr) -> Option { + fn increment_ip(addr: IpAddr) -> Result { let mut ip_bytes = match addr as IpAddr { IpAddr::V4(ip) => Vec::::from(ip.octets()), IpAddr::V6(ip) => Vec::::from(ip.octets()), @@ -158,36 +72,29 @@ impl VirtualDns { } } let addr = if addr.is_ipv4() { - let bytes: [u8; 4] = ip_bytes.as_slice().try_into().ok()?; + let bytes: [u8; 4] = ip_bytes.as_slice().try_into()?; IpAddr::V4(Ipv4Addr::from(bytes)) } else { - let bytes: [u8; 16] = ip_bytes.as_slice().try_into().ok()?; + let bytes: [u8; 16] = ip_bytes.as_slice().try_into()?; IpAddr::V6(Ipv6Addr::from(bytes)) }; - Some(addr) + Ok(addr) } // This is to be called whenever we receive or send a packet on the socket // which connects the tun interface to the client, so existing IP address to name // mappings to not expire as long as the connection is active. - pub fn touch_ip(&mut self, addr: &IpAddr) -> bool { - match self.lru_cache.get_mut(addr) { - None => false, - Some(entry) => { - entry.expiry = Instant::now() + Duration::from_secs(MAPPING_TIMEOUT); - true - } - } + pub fn touch_ip(&mut self, addr: &IpAddr) { + _ = self.lru_cache.get_mut(addr).map(|entry| { + entry.expiry = Instant::now() + Duration::from_secs(MAPPING_TIMEOUT); + }); } pub fn resolve_ip(&mut self, addr: &IpAddr) -> Option<&String> { - match self.lru_cache.get(addr) { - None => None, - Some(entry) => Some(&entry.name), - } + self.lru_cache.get(addr).map(|entry| &entry.name) } - fn allocate_ip(&mut self, name: String) -> Option { + fn allocate_ip(&mut self, name: String) -> Result { let now = Instant::now(); loop { @@ -205,9 +112,9 @@ impl VirtualDns { } if let Some(ip) = self.name_to_ip.get(&name) { - let result = Some(*ip); - self.touch_ip(&ip.clone()); - return result; + let ip = *ip; + self.touch_ip(&ip); + return Ok(ip); } let started_at = self.next_addr; @@ -217,16 +124,10 @@ impl VirtualDns { self.lru_cache.raw_entry_mut().from_key(&self.next_addr) { let expiry = Instant::now() + Duration::from_secs(MAPPING_TIMEOUT); - vacant.insert( - self.next_addr, - NameCacheEntry { - name: name.clone(), - expiry, - }, - ); - // e.insert(name.clone()); - self.name_to_ip.insert(name, self.next_addr); - return Some(self.next_addr); + let name0 = name.clone(); + vacant.insert(self.next_addr, NameCacheEntry { name, expiry }); + self.name_to_ip.insert(name0, self.next_addr); + return Ok(self.next_addr); } self.next_addr = Self::increment_ip(self.next_addr)?; if self.next_addr == self.broadcast_addr { @@ -234,47 +135,8 @@ impl VirtualDns { self.next_addr = self.network_addr; } if self.next_addr == started_at { - return None; + return Err("Virtual IP space for DNS exhausted".into()); } } } - - /// Parse a non-root DNS qname at a specific offset and return the name along with its size. - /// DNS packet parsing should be continued after the name. - fn parse_qname(data: &[u8], mut offset: usize) -> Option<(String, usize)> { - // Since we only parse qnames and qnames can't point anywhere, - // we do not support pointers. (0xC0 is a bitmask for pointer detection.) - let label_type = data[offset] & 0xC0; - if label_type != 0x00 { - return None; - } - - let mut qname = String::from(""); - loop { - if offset >= data.len() { - return None; - } - let label_len = data[offset]; - if label_len == 0 { - if qname.is_empty() { - return None; - } - offset += 1; - break; - } - if !qname.is_empty() { - qname.push('.'); - } - for _ in 0..label_len { - offset += 1; - if offset >= data.len() { - return None; - } - qname.push(data[offset] as char); - } - offset += 1; - } - - Some((qname, offset)) - } }