Merge pull request #58 from blechschmidt/v8

DNS over TCP
This commit is contained in:
ssrlive 2023-08-23 09:50:30 +08:00 committed by GitHub
commit e518355756
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 208 additions and 38 deletions

View file

@ -9,6 +9,9 @@ pub enum Error {
#[error("std::io::Error {0}")]
Io(#[from] std::io::Error),
#[error("TryFromIntError {0:?}")]
TryFromInt(#[from] std::num::TryFromIntError),
#[error("std::net::AddrParseError {0}")]
AddrParse(#[from] std::net::AddrParseError),

View file

@ -98,6 +98,7 @@ impl std::fmt::Display for ProxyType {
pub struct Options {
virtual_dns: Option<virtdns::VirtualDns>,
mtu: Option<usize>,
dns_over_tcp: bool,
}
impl Options {
@ -107,6 +108,13 @@ impl Options {
pub fn with_virtual_dns(mut self) -> Self {
self.virtual_dns = Some(virtdns::VirtualDns::new());
self.dns_over_tcp = false;
self
}
pub fn with_dns_over_tcp(mut self) -> Self {
self.dns_over_tcp = true;
self.virtual_dns = None;
self
}

View file

@ -40,6 +40,10 @@ struct Args {
/// Verbosity level
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
verbosity: ArgVerbosity,
/// Enable DNS over TCP
#[arg(long)]
dns_over_tcp: bool,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
@ -79,6 +83,10 @@ fn main() -> ExitCode {
options = options.with_virtual_dns();
}
if args.dns_over_tcp {
options = options.with_dns_over_tcp();
}
let interface = match args.tun_fd {
None => NetworkInterface::Named(args.tun.clone()),
Some(fd) => {

View file

@ -170,7 +170,7 @@ const CLIENT_WRITE_CLOSED: u8 = 2;
const UDP_ASSO_TIMEOUT: u64 = 10; // seconds
const DNS_PORT: u16 = 53;
struct TcpConnectState {
struct ConnectionState {
smoltcp_handle: Option<SocketHandle>,
mio_stream: TcpStream,
token: Token,
@ -183,6 +183,8 @@ struct TcpConnectState {
udp_token: Option<Token>,
udp_origin_dst: Option<SocketAddr>,
udp_data_cache: LinkedList<Vec<u8>>,
udp_over_tcp_expiry: Option<::std::time::Instant>,
is_tcp_dns: bool,
}
pub(crate) trait TcpProxy {
@ -210,7 +212,7 @@ pub struct TunToProxy<'a> {
tun: TunTapInterface,
poll: Poll,
iface: Interface,
connection_map: HashMap<ConnectionInfo, TcpConnectState>,
connection_map: HashMap<ConnectionInfo, ConnectionState>,
connection_manager: Option<Rc<dyn ConnectionManager>>,
next_token: usize,
sockets: SocketSet<'a>,
@ -237,7 +239,7 @@ impl<'a> TunToProxy<'a> {
#[rustfmt::skip]
let config = match tun.capabilities().medium {
Medium::Ethernet => Config::new(smoltcp::wire::EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()),
Medium::Ethernet => Config::new(smoltcp::wire::EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()),
Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip),
Medium::Ieee802154 => todo!(),
};
@ -434,7 +436,7 @@ impl<'a> TunToProxy<'a> {
Ok(())
}
fn update_mio_socket_interest(poll: &mut Poll, state: &mut TcpConnectState) -> Result<()> {
fn update_mio_socket_interest(poll: &mut Poll, state: &mut ConnectionState) -> Result<()> {
// Maybe we did not listen for any events before. Therefore, just swallow the error.
if let Err(err) = poll.registry().deregister(&mut state.mio_stream) {
log::trace!("{}", err);
@ -459,15 +461,7 @@ impl<'a> TunToProxy<'a> {
fn preprocess_origin_connection_info(&mut self, info: ConnectionInfo) -> Result<ConnectionInfo> {
let origin_dst = SocketAddr::try_from(&info.dst)?;
let connection_info = match &mut self.options.virtual_dns {
None => {
let mut info = info;
let port = origin_dst.port();
if port == DNS_PORT && info.protocol == IpProtocol::Udp && dns::addr_is_private(&origin_dst) {
let dns_addr: SocketAddr = "8.8.8.8:53".parse()?; // TODO: Configurable
info.dst = Address::from(dns_addr);
}
info
}
None => info,
Some(virtual_dns) => {
let dst_ip = origin_dst.ip();
virtual_dns.touch_ip(&dst_ip);
@ -480,6 +474,147 @@ impl<'a> TunToProxy<'a> {
Ok(connection_info)
}
fn process_incoming_dns_over_tcp_packets(
&mut self,
manager: &Rc<dyn ConnectionManager>,
original_info: &ConnectionInfo,
origin_dst: SocketAddr,
payload: &[u8],
) -> Result<()> {
_ = dns::parse_data_to_dns_message(payload, false)?;
let mut new_info = original_info.clone();
let dns_addr: SocketAddr = "8.8.8.8:53".parse()?;
new_info.dst = Address::from(dns_addr);
let info = &new_info;
if !self.connection_map.contains_key(info) {
log::info!("DNS over TCP {} ({})", info, origin_dst);
let tcp_proxy_handler = manager.new_tcp_proxy(info, false)?;
let server_addr = manager.get_server_addr();
let mut state = self.create_new_tcp_connection_state(server_addr, origin_dst, tcp_proxy_handler, false)?;
state.is_tcp_dns = true;
state.udp_origin_dst = Some(SocketAddr::try_from(original_info.dst.clone())?);
self.connection_map.insert(info.clone(), state);
// TODO: Move this 3 lines to the function end?
self.expect_smoltcp_send()?;
self.tunsocket_read_and_forward(info)?;
self.write_to_server(info)?;
} else {
log::trace!("DNS over TCP subsequent packet {} ({})", info, origin_dst);
}
// Insert the DNS message length in front of the payload
let len = u16::try_from(payload.len())?;
let mut buf = Vec::with_capacity(2 + usize::from(len));
buf.extend_from_slice(&len.to_be_bytes());
buf.extend_from_slice(payload);
let err = "udp over tcp state not find";
let state = self.connection_map.get_mut(info).ok_or(err)?;
state.udp_over_tcp_expiry = Some(Self::common_udp_life_timeout());
let data_event = IncomingDataEvent {
direction: IncomingDirection::FromClient,
buffer: &buf,
};
state.tcp_proxy_handler.push_data(data_event)?;
Ok(())
}
fn receive_dns_over_tcp_packet_and_write_to_client(&mut self, info: &ConnectionInfo) -> Result<()> {
let err = "udp connection state not found";
let state = self.connection_map.get_mut(info).ok_or(err)?;
assert!(state.udp_over_tcp_expiry.is_some());
state.udp_over_tcp_expiry = Some(Self::common_udp_life_timeout());
// Code similar to the code in parent function. TODO: Cleanup.
let mut vecbuf = Vec::<u8>::new();
let read_result = state.mio_stream.read_to_end(&mut vecbuf);
let read = match read_result {
Ok(read_result) => read_result,
Err(error) => {
if error.kind() != std::io::ErrorKind::WouldBlock {
log::error!("{} Read from proxy: {}", info.dst, error);
}
vecbuf.len()
}
};
let data = vecbuf.as_slice();
let data_event = IncomingDataEvent {
direction: IncomingDirection::FromServer,
buffer: &data[0..read],
};
if let Err(error) = state.tcp_proxy_handler.push_data(data_event) {
log::error!("{}", error);
self.remove_connection(&info.clone())?;
return Ok(());
}
let dns_event = state.tcp_proxy_handler.peek_data(OutgoingDirection::ToClient);
let mut buf = dns_event.buffer.to_vec();
let mut to_send: LinkedList<Vec<u8>> = LinkedList::new();
loop {
if buf.len() < 2 {
break;
}
let len = u16::from_be_bytes([buf[0], buf[1]]) as usize;
if buf.len() < len + 2 {
break;
}
let data = buf[2..len + 2].to_vec();
let mut message = dns::parse_data_to_dns_message(&data, false)?;
let name = dns::extract_domain_from_dns_message(&message)?;
let ip = dns::extract_ipaddr_from_dns_message(&message);
log::info!("DNS over TCP query result: {} -> {:?}", name, ip);
state
.tcp_proxy_handler
.consume_data(OutgoingDirection::ToClient, len + 2);
dns::remove_ipv6_entries(&mut message); // TODO: Configurable
to_send.push_back(message.to_vec()?);
if len + 2 == buf.len() {
break;
}
buf = buf[len + 2..].to_vec();
}
// Write to client
let src = state.udp_origin_dst.ok_or("Expected UDP addr")?;
while let Some(packet) = to_send.pop_front() {
self.send_udp_packet_to_client(src, info.src, &packet)?;
}
Ok(())
}
fn udp_over_tcp_timeout_expired(&self, info: &ConnectionInfo) -> bool {
if let Some(state) = self.connection_map.get(info) {
if let Some(expiry) = state.udp_over_tcp_expiry {
return expiry < ::std::time::Instant::now();
}
}
false
}
fn clearup_expired_dns_over_tcp(&mut self) -> Result<()> {
let keys = self.connection_map.keys().cloned().collect::<Vec<_>>();
for key in keys {
if self.udp_over_tcp_timeout_expired(&key) {
log::trace!("UDP over TCP timeout: {}", key);
self.remove_connection(&key)?;
}
}
Ok(())
}
fn process_incoming_udp_packets(
&mut self,
manager: &Rc<dyn ConnectionManager>,
@ -505,7 +640,7 @@ impl<'a> TunToProxy<'a> {
let err = "udp associate state not find";
let state = self.connection_map.get_mut(info).ok_or(err)?;
assert!(state.udp_acco_expiry.is_some());
state.udp_acco_expiry = Some(Self::udp_associate_timeout());
state.udp_acco_expiry = Some(Self::common_udp_life_timeout());
// Add SOCKS5 UDP header to the incoming data
let mut s5_udp_data = Vec::<u8>::new();
@ -535,23 +670,23 @@ impl<'a> TunToProxy<'a> {
}
let (info, _first_packet, payload_offset, payload_size) = result?;
let origin_dst = SocketAddr::try_from(&info.dst)?;
let connection_info = self.preprocess_origin_connection_info(info)?;
let info = self.preprocess_origin_connection_info(info)?;
let manager = self.get_connection_manager().ok_or("get connection manager")?;
if connection_info.protocol == IpProtocol::Tcp {
if info.protocol == IpProtocol::Tcp {
if _first_packet {
let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info, false)?;
let tcp_proxy_handler = manager.new_tcp_proxy(&info, false)?;
let server = manager.get_server_addr();
let state = self.create_new_tcp_connection_state(server, origin_dst, tcp_proxy_handler, false)?;
self.connection_map.insert(connection_info.clone(), state);
self.connection_map.insert(info.clone(), state);
log::info!("Connect done {} ({})", connection_info, origin_dst);
} else if !self.connection_map.contains_key(&connection_info) {
// log::debug!("Drop middle session {} ({})", connection_info, origin_dst);
log::info!("Connect done {} ({})", info, origin_dst);
} else if !self.connection_map.contains_key(&info) {
// log::debug!("Drop middle session {} ({})", info, origin_dst);
return Ok(());
} else {
// log::trace!("Subsequent packet {} ({})", connection_info, origin_dst);
// log::trace!("Subsequent packet {} ({})", info, origin_dst);
}
// Inject the packet to advance the remote proxy server smoltcp socket state
@ -563,24 +698,28 @@ impl<'a> TunToProxy<'a> {
self.expect_smoltcp_send()?;
// Read from the smoltcp socket and push the data to the connection handler.
self.tunsocket_read_and_forward(&connection_info)?;
self.tunsocket_read_and_forward(&info)?;
// The connection handler builds up the connection or encapsulates the data.
// Therefore, we now expect it to write data to the server.
self.write_to_server(&connection_info)?;
} else if connection_info.protocol == IpProtocol::Udp {
let port = connection_info.dst.port();
self.write_to_server(&info)?;
} else if info.protocol == IpProtocol::Udp {
let port = info.dst.port();
let payload = &frame[payload_offset..payload_offset + payload_size];
if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == DNS_PORT) {
log::info!("DNS query via virtual DNS {} ({})", connection_info, origin_dst);
log::info!("DNS query via virtual DNS {} ({})", info, origin_dst);
let response = virtual_dns.receive_query(payload)?;
self.send_udp_packet_to_client(origin_dst, connection_info.src, response.as_slice())?;
self.send_udp_packet_to_client(origin_dst, info.src, response.as_slice())?;
} else {
// Another UDP packet
self.process_incoming_udp_packets(&manager, &connection_info, origin_dst, payload)?;
if self.options.dns_over_tcp && origin_dst.port() == DNS_PORT {
self.process_incoming_dns_over_tcp_packets(&manager, &info, origin_dst, payload)?;
} else {
self.process_incoming_udp_packets(&manager, &info, origin_dst, payload)?;
}
}
} else {
log::warn!("Unsupported protocol: {} ({})", connection_info, origin_dst);
log::warn!("Unsupported protocol: {} ({})", info, origin_dst);
}
Ok::<(), Error>(())
};
@ -596,7 +735,7 @@ impl<'a> TunToProxy<'a> {
dst: SocketAddr,
tcp_proxy_handler: Box<dyn TcpProxy>,
udp_associate: bool,
) -> Result<TcpConnectState> {
) -> Result<ConnectionState> {
let mut socket = tcp::Socket::new(
tcp::SocketBuffer::new(vec![0; 1024 * 128]),
tcp::SocketBuffer::new(vec![0; 1024 * 128]),
@ -611,7 +750,7 @@ impl<'a> TunToProxy<'a> {
self.poll.registry().register(&mut client, token, i)?;
let expiry = if udp_associate {
Some(Self::udp_associate_timeout())
Some(Self::common_udp_life_timeout())
} else {
None
};
@ -625,7 +764,7 @@ impl<'a> TunToProxy<'a> {
} else {
(None, None)
};
let state = TcpConnectState {
let state = ConnectionState {
smoltcp_handle: Some(handle),
mio_stream: client,
token,
@ -638,11 +777,13 @@ impl<'a> TunToProxy<'a> {
udp_token,
udp_origin_dst: None,
udp_data_cache: LinkedList::new(),
udp_over_tcp_expiry: None,
is_tcp_dns: false,
};
Ok(state)
}
fn udp_associate_timeout() -> ::std::time::Instant {
fn common_udp_life_timeout() -> ::std::time::Instant {
::std::time::Instant::now() + ::std::time::Duration::from_secs(UDP_ASSO_TIMEOUT)
}
@ -777,7 +918,7 @@ impl<'a> TunToProxy<'a> {
let err = "udp connection state not found";
let state = self.connection_map.get_mut(info).ok_or(err)?;
assert!(state.udp_acco_expiry.is_some());
state.udp_acco_expiry = Some(Self::udp_associate_timeout());
state.udp_acco_expiry = Some(Self::common_udp_life_timeout());
let mut to_send: LinkedList<Vec<u8>> = LinkedList::new();
if let Some(udp_socket) = state.udp_socket.as_ref() {
let mut buf = [0; 1 << 16];
@ -807,7 +948,7 @@ impl<'a> TunToProxy<'a> {
Ok(())
}
fn comsume_cached_udp_packets(&mut self, info: &ConnectionInfo) -> Result<()> {
fn consume_cached_udp_packets(&mut self, info: &ConnectionInfo) -> Result<()> {
// Try to send the first UDP packets to remote SOCKS5 server for UDP associate session
if let Some(state) = self.connection_map.get_mut(info) {
if let Some(udp_socket) = state.udp_socket.as_ref() {
@ -843,7 +984,16 @@ impl<'a> TunToProxy<'a> {
let mut block = || -> Result<(), Error> {
if event.is_readable() || event.is_read_closed() {
{
let established = self
.connection_map
.get(&conn_info)
.ok_or("")?
.tcp_proxy_handler
.connection_established();
if self.options.dns_over_tcp && conn_info.dst.port() == DNS_PORT && established {
self.receive_dns_over_tcp_packet_and_write_to_client(&conn_info)?;
return Ok(());
} else {
let e = "connection state not found";
let state = self.connection_map.get_mut(&conn_info).ok_or(e)?;
@ -906,7 +1056,7 @@ impl<'a> TunToProxy<'a> {
// server.
self.write_to_server(&conn_info)?;
self.comsume_cached_udp_packets(&conn_info)?;
self.consume_cached_udp_packets(&conn_info)?;
}
if event.is_writable() {
@ -943,6 +1093,7 @@ impl<'a> TunToProxy<'a> {
}
self.send_to_smoltcp()?;
self.clearup_expired_udp_associate()?;
self.clearup_expired_dns_over_tcp()?;
}
}