diff --git a/src/error.rs b/src/error.rs index bae52ba..d45f0ee 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,6 +9,9 @@ pub enum Error { #[error("std::io::Error {0}")] Io(#[from] std::io::Error), + #[error("TryFromIntError {0:?}")] + TryFromInt(#[from] std::num::TryFromIntError), + #[error("std::net::AddrParseError {0}")] AddrParse(#[from] std::net::AddrParseError), diff --git a/src/lib.rs b/src/lib.rs index bdbf7a0..d21d99b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,6 +98,7 @@ impl std::fmt::Display for ProxyType { pub struct Options { virtual_dns: Option, mtu: Option, + dns_over_tcp: bool, } impl Options { @@ -107,6 +108,13 @@ impl Options { pub fn with_virtual_dns(mut self) -> Self { self.virtual_dns = Some(virtdns::VirtualDns::new()); + self.dns_over_tcp = false; + self + } + + pub fn with_dns_over_tcp(mut self) -> Self { + self.dns_over_tcp = true; + self.virtual_dns = None; self } diff --git a/src/main.rs b/src/main.rs index 1502599..dc749c4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,6 +40,10 @@ struct Args { /// Verbosity level #[arg(short, long, value_name = "level", value_enum, default_value = "info")] verbosity: ArgVerbosity, + + /// Enable DNS over TCP + #[arg(long)] + dns_over_tcp: bool, } #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)] @@ -79,6 +83,10 @@ fn main() -> ExitCode { options = options.with_virtual_dns(); } + if args.dns_over_tcp { + options = options.with_dns_over_tcp(); + } + let interface = match args.tun_fd { None => NetworkInterface::Named(args.tun.clone()), Some(fd) => { diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index d71944d..2c0c22a 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -170,7 +170,7 @@ const CLIENT_WRITE_CLOSED: u8 = 2; const UDP_ASSO_TIMEOUT: u64 = 10; // seconds const DNS_PORT: u16 = 53; -struct TcpConnectState { +struct ConnectionState { smoltcp_handle: Option, mio_stream: TcpStream, token: Token, @@ -183,6 +183,8 @@ struct TcpConnectState { udp_token: Option, udp_origin_dst: Option, udp_data_cache: LinkedList>, + udp_over_tcp_expiry: Option<::std::time::Instant>, + is_tcp_dns: bool, } pub(crate) trait TcpProxy { @@ -210,7 +212,7 @@ pub struct TunToProxy<'a> { tun: TunTapInterface, poll: Poll, iface: Interface, - connection_map: HashMap, + connection_map: HashMap, connection_manager: Option>, next_token: usize, sockets: SocketSet<'a>, @@ -237,7 +239,7 @@ impl<'a> TunToProxy<'a> { #[rustfmt::skip] let config = match tun.capabilities().medium { - Medium::Ethernet => Config::new(smoltcp::wire::EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()), + Medium::Ethernet => Config::new(smoltcp::wire::EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()), Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), Medium::Ieee802154 => todo!(), }; @@ -434,7 +436,7 @@ impl<'a> TunToProxy<'a> { Ok(()) } - fn update_mio_socket_interest(poll: &mut Poll, state: &mut TcpConnectState) -> Result<()> { + fn update_mio_socket_interest(poll: &mut Poll, state: &mut ConnectionState) -> Result<()> { // Maybe we did not listen for any events before. Therefore, just swallow the error. if let Err(err) = poll.registry().deregister(&mut state.mio_stream) { log::trace!("{}", err); @@ -459,15 +461,7 @@ impl<'a> TunToProxy<'a> { fn preprocess_origin_connection_info(&mut self, info: ConnectionInfo) -> Result { let origin_dst = SocketAddr::try_from(&info.dst)?; let connection_info = match &mut self.options.virtual_dns { - None => { - let mut info = info; - let port = origin_dst.port(); - if port == DNS_PORT && info.protocol == IpProtocol::Udp && dns::addr_is_private(&origin_dst) { - let dns_addr: SocketAddr = "8.8.8.8:53".parse()?; // TODO: Configurable - info.dst = Address::from(dns_addr); - } - info - } + None => info, Some(virtual_dns) => { let dst_ip = origin_dst.ip(); virtual_dns.touch_ip(&dst_ip); @@ -480,6 +474,147 @@ impl<'a> TunToProxy<'a> { Ok(connection_info) } + fn process_incoming_dns_over_tcp_packets( + &mut self, + manager: &Rc, + original_info: &ConnectionInfo, + origin_dst: SocketAddr, + payload: &[u8], + ) -> Result<()> { + _ = dns::parse_data_to_dns_message(payload, false)?; + let mut new_info = original_info.clone(); + let dns_addr: SocketAddr = "8.8.8.8:53".parse()?; + new_info.dst = Address::from(dns_addr); + + let info = &new_info; + + if !self.connection_map.contains_key(info) { + log::info!("DNS over TCP {} ({})", info, origin_dst); + + let tcp_proxy_handler = manager.new_tcp_proxy(info, false)?; + let server_addr = manager.get_server_addr(); + let mut state = self.create_new_tcp_connection_state(server_addr, origin_dst, tcp_proxy_handler, false)?; + state.is_tcp_dns = true; + state.udp_origin_dst = Some(SocketAddr::try_from(original_info.dst.clone())?); + self.connection_map.insert(info.clone(), state); + + // TODO: Move this 3 lines to the function end? + self.expect_smoltcp_send()?; + self.tunsocket_read_and_forward(info)?; + self.write_to_server(info)?; + } else { + log::trace!("DNS over TCP subsequent packet {} ({})", info, origin_dst); + } + + // Insert the DNS message length in front of the payload + let len = u16::try_from(payload.len())?; + let mut buf = Vec::with_capacity(2 + usize::from(len)); + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(payload); + + let err = "udp over tcp state not find"; + let state = self.connection_map.get_mut(info).ok_or(err)?; + state.udp_over_tcp_expiry = Some(Self::common_udp_life_timeout()); + + let data_event = IncomingDataEvent { + direction: IncomingDirection::FromClient, + buffer: &buf, + }; + state.tcp_proxy_handler.push_data(data_event)?; + Ok(()) + } + + fn receive_dns_over_tcp_packet_and_write_to_client(&mut self, info: &ConnectionInfo) -> Result<()> { + let err = "udp connection state not found"; + let state = self.connection_map.get_mut(info).ok_or(err)?; + assert!(state.udp_over_tcp_expiry.is_some()); + state.udp_over_tcp_expiry = Some(Self::common_udp_life_timeout()); + + // Code similar to the code in parent function. TODO: Cleanup. + let mut vecbuf = Vec::::new(); + let read_result = state.mio_stream.read_to_end(&mut vecbuf); + let read = match read_result { + Ok(read_result) => read_result, + Err(error) => { + if error.kind() != std::io::ErrorKind::WouldBlock { + log::error!("{} Read from proxy: {}", info.dst, error); + } + vecbuf.len() + } + }; + + let data = vecbuf.as_slice(); + let data_event = IncomingDataEvent { + direction: IncomingDirection::FromServer, + buffer: &data[0..read], + }; + if let Err(error) = state.tcp_proxy_handler.push_data(data_event) { + log::error!("{}", error); + self.remove_connection(&info.clone())?; + return Ok(()); + } + + let dns_event = state.tcp_proxy_handler.peek_data(OutgoingDirection::ToClient); + + let mut buf = dns_event.buffer.to_vec(); + let mut to_send: LinkedList> = LinkedList::new(); + loop { + if buf.len() < 2 { + break; + } + let len = u16::from_be_bytes([buf[0], buf[1]]) as usize; + if buf.len() < len + 2 { + break; + } + let data = buf[2..len + 2].to_vec(); + + let mut message = dns::parse_data_to_dns_message(&data, false)?; + + let name = dns::extract_domain_from_dns_message(&message)?; + let ip = dns::extract_ipaddr_from_dns_message(&message); + log::info!("DNS over TCP query result: {} -> {:?}", name, ip); + + state + .tcp_proxy_handler + .consume_data(OutgoingDirection::ToClient, len + 2); + + dns::remove_ipv6_entries(&mut message); // TODO: Configurable + + to_send.push_back(message.to_vec()?); + if len + 2 == buf.len() { + break; + } + buf = buf[len + 2..].to_vec(); + } + + // Write to client + let src = state.udp_origin_dst.ok_or("Expected UDP addr")?; + while let Some(packet) = to_send.pop_front() { + self.send_udp_packet_to_client(src, info.src, &packet)?; + } + Ok(()) + } + + fn udp_over_tcp_timeout_expired(&self, info: &ConnectionInfo) -> bool { + if let Some(state) = self.connection_map.get(info) { + if let Some(expiry) = state.udp_over_tcp_expiry { + return expiry < ::std::time::Instant::now(); + } + } + false + } + + fn clearup_expired_dns_over_tcp(&mut self) -> Result<()> { + let keys = self.connection_map.keys().cloned().collect::>(); + for key in keys { + if self.udp_over_tcp_timeout_expired(&key) { + log::trace!("UDP over TCP timeout: {}", key); + self.remove_connection(&key)?; + } + } + Ok(()) + } + fn process_incoming_udp_packets( &mut self, manager: &Rc, @@ -505,7 +640,7 @@ impl<'a> TunToProxy<'a> { let err = "udp associate state not find"; let state = self.connection_map.get_mut(info).ok_or(err)?; assert!(state.udp_acco_expiry.is_some()); - state.udp_acco_expiry = Some(Self::udp_associate_timeout()); + state.udp_acco_expiry = Some(Self::common_udp_life_timeout()); // Add SOCKS5 UDP header to the incoming data let mut s5_udp_data = Vec::::new(); @@ -535,23 +670,23 @@ impl<'a> TunToProxy<'a> { } let (info, _first_packet, payload_offset, payload_size) = result?; let origin_dst = SocketAddr::try_from(&info.dst)?; - let connection_info = self.preprocess_origin_connection_info(info)?; + let info = self.preprocess_origin_connection_info(info)?; let manager = self.get_connection_manager().ok_or("get connection manager")?; - if connection_info.protocol == IpProtocol::Tcp { + if info.protocol == IpProtocol::Tcp { if _first_packet { - let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info, false)?; + let tcp_proxy_handler = manager.new_tcp_proxy(&info, false)?; let server = manager.get_server_addr(); let state = self.create_new_tcp_connection_state(server, origin_dst, tcp_proxy_handler, false)?; - self.connection_map.insert(connection_info.clone(), state); + self.connection_map.insert(info.clone(), state); - log::info!("Connect done {} ({})", connection_info, origin_dst); - } else if !self.connection_map.contains_key(&connection_info) { - // log::debug!("Drop middle session {} ({})", connection_info, origin_dst); + log::info!("Connect done {} ({})", info, origin_dst); + } else if !self.connection_map.contains_key(&info) { + // log::debug!("Drop middle session {} ({})", info, origin_dst); return Ok(()); } else { - // log::trace!("Subsequent packet {} ({})", connection_info, origin_dst); + // log::trace!("Subsequent packet {} ({})", info, origin_dst); } // Inject the packet to advance the remote proxy server smoltcp socket state @@ -563,24 +698,28 @@ impl<'a> TunToProxy<'a> { self.expect_smoltcp_send()?; // Read from the smoltcp socket and push the data to the connection handler. - self.tunsocket_read_and_forward(&connection_info)?; + self.tunsocket_read_and_forward(&info)?; // The connection handler builds up the connection or encapsulates the data. // Therefore, we now expect it to write data to the server. - self.write_to_server(&connection_info)?; - } else if connection_info.protocol == IpProtocol::Udp { - let port = connection_info.dst.port(); + self.write_to_server(&info)?; + } else if info.protocol == IpProtocol::Udp { + let port = info.dst.port(); let payload = &frame[payload_offset..payload_offset + payload_size]; if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == DNS_PORT) { - log::info!("DNS query via virtual DNS {} ({})", connection_info, origin_dst); + log::info!("DNS query via virtual DNS {} ({})", info, origin_dst); let response = virtual_dns.receive_query(payload)?; - self.send_udp_packet_to_client(origin_dst, connection_info.src, response.as_slice())?; + self.send_udp_packet_to_client(origin_dst, info.src, response.as_slice())?; } else { // Another UDP packet - self.process_incoming_udp_packets(&manager, &connection_info, origin_dst, payload)?; + if self.options.dns_over_tcp && origin_dst.port() == DNS_PORT { + self.process_incoming_dns_over_tcp_packets(&manager, &info, origin_dst, payload)?; + } else { + self.process_incoming_udp_packets(&manager, &info, origin_dst, payload)?; + } } } else { - log::warn!("Unsupported protocol: {} ({})", connection_info, origin_dst); + log::warn!("Unsupported protocol: {} ({})", info, origin_dst); } Ok::<(), Error>(()) }; @@ -596,7 +735,7 @@ impl<'a> TunToProxy<'a> { dst: SocketAddr, tcp_proxy_handler: Box, udp_associate: bool, - ) -> Result { + ) -> Result { let mut socket = tcp::Socket::new( tcp::SocketBuffer::new(vec![0; 1024 * 128]), tcp::SocketBuffer::new(vec![0; 1024 * 128]), @@ -611,7 +750,7 @@ impl<'a> TunToProxy<'a> { self.poll.registry().register(&mut client, token, i)?; let expiry = if udp_associate { - Some(Self::udp_associate_timeout()) + Some(Self::common_udp_life_timeout()) } else { None }; @@ -625,7 +764,7 @@ impl<'a> TunToProxy<'a> { } else { (None, None) }; - let state = TcpConnectState { + let state = ConnectionState { smoltcp_handle: Some(handle), mio_stream: client, token, @@ -638,11 +777,13 @@ impl<'a> TunToProxy<'a> { udp_token, udp_origin_dst: None, udp_data_cache: LinkedList::new(), + udp_over_tcp_expiry: None, + is_tcp_dns: false, }; Ok(state) } - fn udp_associate_timeout() -> ::std::time::Instant { + fn common_udp_life_timeout() -> ::std::time::Instant { ::std::time::Instant::now() + ::std::time::Duration::from_secs(UDP_ASSO_TIMEOUT) } @@ -777,7 +918,7 @@ impl<'a> TunToProxy<'a> { let err = "udp connection state not found"; let state = self.connection_map.get_mut(info).ok_or(err)?; assert!(state.udp_acco_expiry.is_some()); - state.udp_acco_expiry = Some(Self::udp_associate_timeout()); + state.udp_acco_expiry = Some(Self::common_udp_life_timeout()); let mut to_send: LinkedList> = LinkedList::new(); if let Some(udp_socket) = state.udp_socket.as_ref() { let mut buf = [0; 1 << 16]; @@ -807,7 +948,7 @@ impl<'a> TunToProxy<'a> { Ok(()) } - fn comsume_cached_udp_packets(&mut self, info: &ConnectionInfo) -> Result<()> { + fn consume_cached_udp_packets(&mut self, info: &ConnectionInfo) -> Result<()> { // Try to send the first UDP packets to remote SOCKS5 server for UDP associate session if let Some(state) = self.connection_map.get_mut(info) { if let Some(udp_socket) = state.udp_socket.as_ref() { @@ -843,7 +984,16 @@ impl<'a> TunToProxy<'a> { let mut block = || -> Result<(), Error> { if event.is_readable() || event.is_read_closed() { - { + let established = self + .connection_map + .get(&conn_info) + .ok_or("")? + .tcp_proxy_handler + .connection_established(); + if self.options.dns_over_tcp && conn_info.dst.port() == DNS_PORT && established { + self.receive_dns_over_tcp_packet_and_write_to_client(&conn_info)?; + return Ok(()); + } else { let e = "connection state not found"; let state = self.connection_map.get_mut(&conn_info).ok_or(e)?; @@ -906,7 +1056,7 @@ impl<'a> TunToProxy<'a> { // server. self.write_to_server(&conn_info)?; - self.comsume_cached_udp_packets(&conn_info)?; + self.consume_cached_udp_packets(&conn_info)?; } if event.is_writable() { @@ -943,6 +1093,7 @@ impl<'a> TunToProxy<'a> { } self.send_to_smoltcp()?; self.clearup_expired_udp_associate()?; + self.clearup_expired_dns_over_tcp()?; } }