Merge pull request #58 from blechschmidt/v8

DNS over TCP
This commit is contained in:
ssrlive 2023-08-23 09:50:30 +08:00 committed by GitHub
commit e518355756
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 208 additions and 38 deletions

View file

@ -9,6 +9,9 @@ pub enum Error {
#[error("std::io::Error {0}")] #[error("std::io::Error {0}")]
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
#[error("TryFromIntError {0:?}")]
TryFromInt(#[from] std::num::TryFromIntError),
#[error("std::net::AddrParseError {0}")] #[error("std::net::AddrParseError {0}")]
AddrParse(#[from] std::net::AddrParseError), AddrParse(#[from] std::net::AddrParseError),

View file

@ -98,6 +98,7 @@ impl std::fmt::Display for ProxyType {
pub struct Options { pub struct Options {
virtual_dns: Option<virtdns::VirtualDns>, virtual_dns: Option<virtdns::VirtualDns>,
mtu: Option<usize>, mtu: Option<usize>,
dns_over_tcp: bool,
} }
impl Options { impl Options {
@ -107,6 +108,13 @@ impl Options {
pub fn with_virtual_dns(mut self) -> Self { pub fn with_virtual_dns(mut self) -> Self {
self.virtual_dns = Some(virtdns::VirtualDns::new()); self.virtual_dns = Some(virtdns::VirtualDns::new());
self.dns_over_tcp = false;
self
}
pub fn with_dns_over_tcp(mut self) -> Self {
self.dns_over_tcp = true;
self.virtual_dns = None;
self self
} }

View file

@ -40,6 +40,10 @@ struct Args {
/// Verbosity level /// Verbosity level
#[arg(short, long, value_name = "level", value_enum, default_value = "info")] #[arg(short, long, value_name = "level", value_enum, default_value = "info")]
verbosity: ArgVerbosity, verbosity: ArgVerbosity,
/// Enable DNS over TCP
#[arg(long)]
dns_over_tcp: bool,
} }
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)] #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
@ -79,6 +83,10 @@ fn main() -> ExitCode {
options = options.with_virtual_dns(); options = options.with_virtual_dns();
} }
if args.dns_over_tcp {
options = options.with_dns_over_tcp();
}
let interface = match args.tun_fd { let interface = match args.tun_fd {
None => NetworkInterface::Named(args.tun.clone()), None => NetworkInterface::Named(args.tun.clone()),
Some(fd) => { Some(fd) => {

View file

@ -170,7 +170,7 @@ const CLIENT_WRITE_CLOSED: u8 = 2;
const UDP_ASSO_TIMEOUT: u64 = 10; // seconds const UDP_ASSO_TIMEOUT: u64 = 10; // seconds
const DNS_PORT: u16 = 53; const DNS_PORT: u16 = 53;
struct TcpConnectState { struct ConnectionState {
smoltcp_handle: Option<SocketHandle>, smoltcp_handle: Option<SocketHandle>,
mio_stream: TcpStream, mio_stream: TcpStream,
token: Token, token: Token,
@ -183,6 +183,8 @@ struct TcpConnectState {
udp_token: Option<Token>, udp_token: Option<Token>,
udp_origin_dst: Option<SocketAddr>, udp_origin_dst: Option<SocketAddr>,
udp_data_cache: LinkedList<Vec<u8>>, udp_data_cache: LinkedList<Vec<u8>>,
udp_over_tcp_expiry: Option<::std::time::Instant>,
is_tcp_dns: bool,
} }
pub(crate) trait TcpProxy { pub(crate) trait TcpProxy {
@ -210,7 +212,7 @@ pub struct TunToProxy<'a> {
tun: TunTapInterface, tun: TunTapInterface,
poll: Poll, poll: Poll,
iface: Interface, iface: Interface,
connection_map: HashMap<ConnectionInfo, TcpConnectState>, connection_map: HashMap<ConnectionInfo, ConnectionState>,
connection_manager: Option<Rc<dyn ConnectionManager>>, connection_manager: Option<Rc<dyn ConnectionManager>>,
next_token: usize, next_token: usize,
sockets: SocketSet<'a>, sockets: SocketSet<'a>,
@ -434,7 +436,7 @@ impl<'a> TunToProxy<'a> {
Ok(()) Ok(())
} }
fn update_mio_socket_interest(poll: &mut Poll, state: &mut TcpConnectState) -> Result<()> { fn update_mio_socket_interest(poll: &mut Poll, state: &mut ConnectionState) -> Result<()> {
// Maybe we did not listen for any events before. Therefore, just swallow the error. // Maybe we did not listen for any events before. Therefore, just swallow the error.
if let Err(err) = poll.registry().deregister(&mut state.mio_stream) { if let Err(err) = poll.registry().deregister(&mut state.mio_stream) {
log::trace!("{}", err); log::trace!("{}", err);
@ -459,15 +461,7 @@ impl<'a> TunToProxy<'a> {
fn preprocess_origin_connection_info(&mut self, info: ConnectionInfo) -> Result<ConnectionInfo> { fn preprocess_origin_connection_info(&mut self, info: ConnectionInfo) -> Result<ConnectionInfo> {
let origin_dst = SocketAddr::try_from(&info.dst)?; let origin_dst = SocketAddr::try_from(&info.dst)?;
let connection_info = match &mut self.options.virtual_dns { let connection_info = match &mut self.options.virtual_dns {
None => { None => info,
let mut info = info;
let port = origin_dst.port();
if port == DNS_PORT && info.protocol == IpProtocol::Udp && dns::addr_is_private(&origin_dst) {
let dns_addr: SocketAddr = "8.8.8.8:53".parse()?; // TODO: Configurable
info.dst = Address::from(dns_addr);
}
info
}
Some(virtual_dns) => { Some(virtual_dns) => {
let dst_ip = origin_dst.ip(); let dst_ip = origin_dst.ip();
virtual_dns.touch_ip(&dst_ip); virtual_dns.touch_ip(&dst_ip);
@ -480,6 +474,147 @@ impl<'a> TunToProxy<'a> {
Ok(connection_info) Ok(connection_info)
} }
fn process_incoming_dns_over_tcp_packets(
&mut self,
manager: &Rc<dyn ConnectionManager>,
original_info: &ConnectionInfo,
origin_dst: SocketAddr,
payload: &[u8],
) -> Result<()> {
_ = dns::parse_data_to_dns_message(payload, false)?;
let mut new_info = original_info.clone();
let dns_addr: SocketAddr = "8.8.8.8:53".parse()?;
new_info.dst = Address::from(dns_addr);
let info = &new_info;
if !self.connection_map.contains_key(info) {
log::info!("DNS over TCP {} ({})", info, origin_dst);
let tcp_proxy_handler = manager.new_tcp_proxy(info, false)?;
let server_addr = manager.get_server_addr();
let mut state = self.create_new_tcp_connection_state(server_addr, origin_dst, tcp_proxy_handler, false)?;
state.is_tcp_dns = true;
state.udp_origin_dst = Some(SocketAddr::try_from(original_info.dst.clone())?);
self.connection_map.insert(info.clone(), state);
// TODO: Move this 3 lines to the function end?
self.expect_smoltcp_send()?;
self.tunsocket_read_and_forward(info)?;
self.write_to_server(info)?;
} else {
log::trace!("DNS over TCP subsequent packet {} ({})", info, origin_dst);
}
// Insert the DNS message length in front of the payload
let len = u16::try_from(payload.len())?;
let mut buf = Vec::with_capacity(2 + usize::from(len));
buf.extend_from_slice(&len.to_be_bytes());
buf.extend_from_slice(payload);
let err = "udp over tcp state not find";
let state = self.connection_map.get_mut(info).ok_or(err)?;
state.udp_over_tcp_expiry = Some(Self::common_udp_life_timeout());
let data_event = IncomingDataEvent {
direction: IncomingDirection::FromClient,
buffer: &buf,
};
state.tcp_proxy_handler.push_data(data_event)?;
Ok(())
}
fn receive_dns_over_tcp_packet_and_write_to_client(&mut self, info: &ConnectionInfo) -> Result<()> {
let err = "udp connection state not found";
let state = self.connection_map.get_mut(info).ok_or(err)?;
assert!(state.udp_over_tcp_expiry.is_some());
state.udp_over_tcp_expiry = Some(Self::common_udp_life_timeout());
// Code similar to the code in parent function. TODO: Cleanup.
let mut vecbuf = Vec::<u8>::new();
let read_result = state.mio_stream.read_to_end(&mut vecbuf);
let read = match read_result {
Ok(read_result) => read_result,
Err(error) => {
if error.kind() != std::io::ErrorKind::WouldBlock {
log::error!("{} Read from proxy: {}", info.dst, error);
}
vecbuf.len()
}
};
let data = vecbuf.as_slice();
let data_event = IncomingDataEvent {
direction: IncomingDirection::FromServer,
buffer: &data[0..read],
};
if let Err(error) = state.tcp_proxy_handler.push_data(data_event) {
log::error!("{}", error);
self.remove_connection(&info.clone())?;
return Ok(());
}
let dns_event = state.tcp_proxy_handler.peek_data(OutgoingDirection::ToClient);
let mut buf = dns_event.buffer.to_vec();
let mut to_send: LinkedList<Vec<u8>> = LinkedList::new();
loop {
if buf.len() < 2 {
break;
}
let len = u16::from_be_bytes([buf[0], buf[1]]) as usize;
if buf.len() < len + 2 {
break;
}
let data = buf[2..len + 2].to_vec();
let mut message = dns::parse_data_to_dns_message(&data, false)?;
let name = dns::extract_domain_from_dns_message(&message)?;
let ip = dns::extract_ipaddr_from_dns_message(&message);
log::info!("DNS over TCP query result: {} -> {:?}", name, ip);
state
.tcp_proxy_handler
.consume_data(OutgoingDirection::ToClient, len + 2);
dns::remove_ipv6_entries(&mut message); // TODO: Configurable
to_send.push_back(message.to_vec()?);
if len + 2 == buf.len() {
break;
}
buf = buf[len + 2..].to_vec();
}
// Write to client
let src = state.udp_origin_dst.ok_or("Expected UDP addr")?;
while let Some(packet) = to_send.pop_front() {
self.send_udp_packet_to_client(src, info.src, &packet)?;
}
Ok(())
}
fn udp_over_tcp_timeout_expired(&self, info: &ConnectionInfo) -> bool {
if let Some(state) = self.connection_map.get(info) {
if let Some(expiry) = state.udp_over_tcp_expiry {
return expiry < ::std::time::Instant::now();
}
}
false
}
fn clearup_expired_dns_over_tcp(&mut self) -> Result<()> {
let keys = self.connection_map.keys().cloned().collect::<Vec<_>>();
for key in keys {
if self.udp_over_tcp_timeout_expired(&key) {
log::trace!("UDP over TCP timeout: {}", key);
self.remove_connection(&key)?;
}
}
Ok(())
}
fn process_incoming_udp_packets( fn process_incoming_udp_packets(
&mut self, &mut self,
manager: &Rc<dyn ConnectionManager>, manager: &Rc<dyn ConnectionManager>,
@ -505,7 +640,7 @@ impl<'a> TunToProxy<'a> {
let err = "udp associate state not find"; let err = "udp associate state not find";
let state = self.connection_map.get_mut(info).ok_or(err)?; let state = self.connection_map.get_mut(info).ok_or(err)?;
assert!(state.udp_acco_expiry.is_some()); assert!(state.udp_acco_expiry.is_some());
state.udp_acco_expiry = Some(Self::udp_associate_timeout()); state.udp_acco_expiry = Some(Self::common_udp_life_timeout());
// Add SOCKS5 UDP header to the incoming data // Add SOCKS5 UDP header to the incoming data
let mut s5_udp_data = Vec::<u8>::new(); let mut s5_udp_data = Vec::<u8>::new();
@ -535,23 +670,23 @@ impl<'a> TunToProxy<'a> {
} }
let (info, _first_packet, payload_offset, payload_size) = result?; let (info, _first_packet, payload_offset, payload_size) = result?;
let origin_dst = SocketAddr::try_from(&info.dst)?; let origin_dst = SocketAddr::try_from(&info.dst)?;
let connection_info = self.preprocess_origin_connection_info(info)?; let info = self.preprocess_origin_connection_info(info)?;
let manager = self.get_connection_manager().ok_or("get connection manager")?; let manager = self.get_connection_manager().ok_or("get connection manager")?;
if connection_info.protocol == IpProtocol::Tcp { if info.protocol == IpProtocol::Tcp {
if _first_packet { if _first_packet {
let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info, false)?; let tcp_proxy_handler = manager.new_tcp_proxy(&info, false)?;
let server = manager.get_server_addr(); let server = manager.get_server_addr();
let state = self.create_new_tcp_connection_state(server, origin_dst, tcp_proxy_handler, false)?; let state = self.create_new_tcp_connection_state(server, origin_dst, tcp_proxy_handler, false)?;
self.connection_map.insert(connection_info.clone(), state); self.connection_map.insert(info.clone(), state);
log::info!("Connect done {} ({})", connection_info, origin_dst); log::info!("Connect done {} ({})", info, origin_dst);
} else if !self.connection_map.contains_key(&connection_info) { } else if !self.connection_map.contains_key(&info) {
// log::debug!("Drop middle session {} ({})", connection_info, origin_dst); // log::debug!("Drop middle session {} ({})", info, origin_dst);
return Ok(()); return Ok(());
} else { } else {
// log::trace!("Subsequent packet {} ({})", connection_info, origin_dst); // log::trace!("Subsequent packet {} ({})", info, origin_dst);
} }
// Inject the packet to advance the remote proxy server smoltcp socket state // Inject the packet to advance the remote proxy server smoltcp socket state
@ -563,24 +698,28 @@ impl<'a> TunToProxy<'a> {
self.expect_smoltcp_send()?; self.expect_smoltcp_send()?;
// Read from the smoltcp socket and push the data to the connection handler. // Read from the smoltcp socket and push the data to the connection handler.
self.tunsocket_read_and_forward(&connection_info)?; self.tunsocket_read_and_forward(&info)?;
// The connection handler builds up the connection or encapsulates the data. // The connection handler builds up the connection or encapsulates the data.
// Therefore, we now expect it to write data to the server. // Therefore, we now expect it to write data to the server.
self.write_to_server(&connection_info)?; self.write_to_server(&info)?;
} else if connection_info.protocol == IpProtocol::Udp { } else if info.protocol == IpProtocol::Udp {
let port = connection_info.dst.port(); let port = info.dst.port();
let payload = &frame[payload_offset..payload_offset + payload_size]; let payload = &frame[payload_offset..payload_offset + payload_size];
if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == DNS_PORT) { if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == DNS_PORT) {
log::info!("DNS query via virtual DNS {} ({})", connection_info, origin_dst); log::info!("DNS query via virtual DNS {} ({})", info, origin_dst);
let response = virtual_dns.receive_query(payload)?; let response = virtual_dns.receive_query(payload)?;
self.send_udp_packet_to_client(origin_dst, connection_info.src, response.as_slice())?; self.send_udp_packet_to_client(origin_dst, info.src, response.as_slice())?;
} else { } else {
// Another UDP packet // Another UDP packet
self.process_incoming_udp_packets(&manager, &connection_info, origin_dst, payload)?; if self.options.dns_over_tcp && origin_dst.port() == DNS_PORT {
self.process_incoming_dns_over_tcp_packets(&manager, &info, origin_dst, payload)?;
} else {
self.process_incoming_udp_packets(&manager, &info, origin_dst, payload)?;
}
} }
} else { } else {
log::warn!("Unsupported protocol: {} ({})", connection_info, origin_dst); log::warn!("Unsupported protocol: {} ({})", info, origin_dst);
} }
Ok::<(), Error>(()) Ok::<(), Error>(())
}; };
@ -596,7 +735,7 @@ impl<'a> TunToProxy<'a> {
dst: SocketAddr, dst: SocketAddr,
tcp_proxy_handler: Box<dyn TcpProxy>, tcp_proxy_handler: Box<dyn TcpProxy>,
udp_associate: bool, udp_associate: bool,
) -> Result<TcpConnectState> { ) -> Result<ConnectionState> {
let mut socket = tcp::Socket::new( let mut socket = tcp::Socket::new(
tcp::SocketBuffer::new(vec![0; 1024 * 128]), tcp::SocketBuffer::new(vec![0; 1024 * 128]),
tcp::SocketBuffer::new(vec![0; 1024 * 128]), tcp::SocketBuffer::new(vec![0; 1024 * 128]),
@ -611,7 +750,7 @@ impl<'a> TunToProxy<'a> {
self.poll.registry().register(&mut client, token, i)?; self.poll.registry().register(&mut client, token, i)?;
let expiry = if udp_associate { let expiry = if udp_associate {
Some(Self::udp_associate_timeout()) Some(Self::common_udp_life_timeout())
} else { } else {
None None
}; };
@ -625,7 +764,7 @@ impl<'a> TunToProxy<'a> {
} else { } else {
(None, None) (None, None)
}; };
let state = TcpConnectState { let state = ConnectionState {
smoltcp_handle: Some(handle), smoltcp_handle: Some(handle),
mio_stream: client, mio_stream: client,
token, token,
@ -638,11 +777,13 @@ impl<'a> TunToProxy<'a> {
udp_token, udp_token,
udp_origin_dst: None, udp_origin_dst: None,
udp_data_cache: LinkedList::new(), udp_data_cache: LinkedList::new(),
udp_over_tcp_expiry: None,
is_tcp_dns: false,
}; };
Ok(state) Ok(state)
} }
fn udp_associate_timeout() -> ::std::time::Instant { fn common_udp_life_timeout() -> ::std::time::Instant {
::std::time::Instant::now() + ::std::time::Duration::from_secs(UDP_ASSO_TIMEOUT) ::std::time::Instant::now() + ::std::time::Duration::from_secs(UDP_ASSO_TIMEOUT)
} }
@ -777,7 +918,7 @@ impl<'a> TunToProxy<'a> {
let err = "udp connection state not found"; let err = "udp connection state not found";
let state = self.connection_map.get_mut(info).ok_or(err)?; let state = self.connection_map.get_mut(info).ok_or(err)?;
assert!(state.udp_acco_expiry.is_some()); assert!(state.udp_acco_expiry.is_some());
state.udp_acco_expiry = Some(Self::udp_associate_timeout()); state.udp_acco_expiry = Some(Self::common_udp_life_timeout());
let mut to_send: LinkedList<Vec<u8>> = LinkedList::new(); let mut to_send: LinkedList<Vec<u8>> = LinkedList::new();
if let Some(udp_socket) = state.udp_socket.as_ref() { if let Some(udp_socket) = state.udp_socket.as_ref() {
let mut buf = [0; 1 << 16]; let mut buf = [0; 1 << 16];
@ -807,7 +948,7 @@ impl<'a> TunToProxy<'a> {
Ok(()) Ok(())
} }
fn comsume_cached_udp_packets(&mut self, info: &ConnectionInfo) -> Result<()> { fn consume_cached_udp_packets(&mut self, info: &ConnectionInfo) -> Result<()> {
// Try to send the first UDP packets to remote SOCKS5 server for UDP associate session // Try to send the first UDP packets to remote SOCKS5 server for UDP associate session
if let Some(state) = self.connection_map.get_mut(info) { if let Some(state) = self.connection_map.get_mut(info) {
if let Some(udp_socket) = state.udp_socket.as_ref() { if let Some(udp_socket) = state.udp_socket.as_ref() {
@ -843,7 +984,16 @@ impl<'a> TunToProxy<'a> {
let mut block = || -> Result<(), Error> { let mut block = || -> Result<(), Error> {
if event.is_readable() || event.is_read_closed() { if event.is_readable() || event.is_read_closed() {
{ let established = self
.connection_map
.get(&conn_info)
.ok_or("")?
.tcp_proxy_handler
.connection_established();
if self.options.dns_over_tcp && conn_info.dst.port() == DNS_PORT && established {
self.receive_dns_over_tcp_packet_and_write_to_client(&conn_info)?;
return Ok(());
} else {
let e = "connection state not found"; let e = "connection state not found";
let state = self.connection_map.get_mut(&conn_info).ok_or(e)?; let state = self.connection_map.get_mut(&conn_info).ok_or(e)?;
@ -906,7 +1056,7 @@ impl<'a> TunToProxy<'a> {
// server. // server.
self.write_to_server(&conn_info)?; self.write_to_server(&conn_info)?;
self.comsume_cached_udp_packets(&conn_info)?; self.consume_cached_udp_packets(&conn_info)?;
} }
if event.is_writable() { if event.is_writable() {
@ -943,6 +1093,7 @@ impl<'a> TunToProxy<'a> {
} }
self.send_to_smoltcp()?; self.send_to_smoltcp()?;
self.clearup_expired_udp_associate()?; self.clearup_expired_udp_associate()?;
self.clearup_expired_dns_over_tcp()?;
} }
} }