diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index de2167d..f305323 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -482,6 +482,50 @@ impl<'a> TunToProxy<'a> { Ok(connection_info) } + fn process_incoming_udp_packets_dns_over_tcp( + &mut self, + manager: &Rc, + info: &ConnectionInfo, + origin_dst: SocketAddr, + payload: &[u8], + ) -> Result<()> { + _ = dns::parse_data_to_dns_message(payload, false)?; + + if !self.connection_map.contains_key(info) { + log::info!("DNS over TCP {} ({})", info, origin_dst); + let tcp_proxy_handler = manager.new_tcp_proxy(info, false)?; + let server_addr = manager.get_server_addr(); + let state = self.create_new_tcp_connection_state(server_addr, origin_dst, tcp_proxy_handler, false)?; + self.connection_map.insert(info.clone(), state); + + self.expect_smoltcp_send()?; + self.tunsocket_read_and_forward(info)?; + self.write_to_server(info)?; + } else { + log::trace!("DNS over TCP subsequent packet {} ({})", info, origin_dst); + } + + // Insert the DNS message length in front of the payload + let len = u16::try_from(payload.len())?; + let mut buf = Vec::with_capacity(2 + usize::from(len)); + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(payload); + + let err = "udp over tcp state not find"; + let state = self.connection_map.get_mut(info).ok_or(err)?; + state.udp_over_tcp_expiry = Some(Self::common_udp_life_timeout()); + if state.tcp_proxy_handler.connection_established() { + _ = state.mio_stream.write(&buf)?; + } else { + // FIXME: Build an IP packet with TCP and inject it into the device, + // or cache them and send them when the connection is established? + // self.device.inject_packet(&buf); + state.udp_over_tcp_data_cache.push_back(buf); + } + + Ok(()) + } + fn process_incoming_udp_packets( &mut self, manager: &Rc, @@ -489,44 +533,6 @@ impl<'a> TunToProxy<'a> { origin_dst: SocketAddr, payload: &[u8], ) -> Result<()> { - if self.options.dns_over_tcp && origin_dst.port() == DNS_PORT { - _ = dns::parse_data_to_dns_message(payload, false)?; - - if !self.connection_map.contains_key(info) { - log::info!("DNS over TCP {} ({})", info, origin_dst); - let tcp_proxy_handler = manager.new_tcp_proxy(info, false)?; - let server_addr = manager.get_server_addr(); - let state = self.create_new_tcp_connection_state(server_addr, origin_dst, tcp_proxy_handler, false)?; - self.connection_map.insert(info.clone(), state); - - self.expect_smoltcp_send()?; - self.tunsocket_read_and_forward(info)?; - self.write_to_server(info)?; - } else { - log::trace!("DNS over TCP subsequent packet {} ({})", info, origin_dst); - } - - // Insert the DNS message length in front of the payload - let len = u16::try_from(payload.len())?; - let mut buf = Vec::with_capacity(2 + usize::from(len)); - buf.extend_from_slice(&len.to_be_bytes()); - buf.extend_from_slice(payload); - - let err = "udp over tcp state not find"; - let state = self.connection_map.get_mut(info).ok_or(err)?; - state.udp_over_tcp_expiry = Some(Self::common_udp_life_timeout()); - if state.tcp_proxy_handler.connection_established() { - _ = state.mio_stream.write(&buf)?; - } else { - // FIXME: Build an IP packet with TCP and inject it into the device, - // or cache them and send them when the connection is established? - // self.device.inject_packet(&buf); - state.udp_over_tcp_data_cache.push_back(buf); - } - - return Ok(()); - } - if !self.connection_map.contains_key(info) { log::info!("UDP associate session {} ({})", info, origin_dst); let tcp_proxy_handler = manager.new_tcp_proxy(info, true)?; @@ -617,7 +623,12 @@ impl<'a> TunToProxy<'a> { self.send_udp_packet_to_client(origin_dst, connection_info.src, response.as_slice())?; } else { // Another UDP packet - self.process_incoming_udp_packets(&manager, &connection_info, origin_dst, payload)?; + if self.options.dns_over_tcp && origin_dst.port() == DNS_PORT { + let info = &connection_info; + self.process_incoming_udp_packets_dns_over_tcp(&manager, info, origin_dst, payload)?; + } else { + self.process_incoming_udp_packets(&manager, &connection_info, origin_dst, payload)?; + } } } else { log::warn!("Unsupported protocol: {} ({})", connection_info, origin_dst);