diff --git a/src/socks.rs b/src/socks.rs index 3e4a7a9..4dbc091 100644 --- a/src/socks.rs +++ b/src/socks.rs @@ -142,7 +142,7 @@ impl SocksProxyImpl { let response = handshake::Response::retrieve_from_stream(&mut self.server_inbuf.clone()); if let Err(e) = &response { if e.kind() == std::io::ErrorKind::UnexpectedEof { - log::trace!("receive_server_hello_socks5 needs more data \"{}\"...", e); + // log::trace!("receive_server_hello_socks5 needs more data \"{}\"...", e); return Ok(()); } else { return Err(e.to_string().into()); @@ -212,7 +212,7 @@ impl SocksProxyImpl { let response = protocol::Response::retrieve_from_stream(&mut self.server_inbuf.clone()); if let Err(e) = &response { if e.kind() == std::io::ErrorKind::UnexpectedEof { - log::trace!("receive_connection_status needs more data \"{}\"...", e); + // log::trace!("receive_connection_status needs more data \"{}\"...", e); return Ok(()); } else { return Err(e.to_string().into()); @@ -226,7 +226,7 @@ impl SocksProxyImpl { if self.command == protocol::Command::UdpAssociate { self.udp_associate = Some(SocketAddr::try_from(&response.address)?); assert!(self.data_buf.is_empty()); - log::debug!("UDP associate: {}", response.address); + // log::debug!("UDP associate: {}", response.address); } self.server_outbuf.append(&mut self.data_buf); diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index df0f87d..a537a3b 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -1,5 +1,5 @@ use crate::{error::Error, error::Result, virtdevice::VirtualTunDevice, NetworkInterface, Options}; -use mio::{event::Event, net::TcpStream, unix::SourceFd, Events, Interest, Poll, Token}; +use mio::{event::Event, net::TcpStream, net::UdpSocket, unix::SourceFd, Events, Interest, Poll, Token}; use smoltcp::{ iface::{Config, Interface, SocketHandle, SocketSet}, phy::{Device, Medium, RxToken, TunTapInterface, TxToken}, @@ -166,7 +166,7 @@ fn connection_tuple(frame: &[u8]) -> Result<(ConnectionInfo, bool, usize, usize) const SERVER_WRITE_CLOSED: u8 = 1; const CLIENT_WRITE_CLOSED: u8 = 2; -const UDP_ASSO_TIMEOUT: u64 = 5; // seconds +const UDP_ASSO_TIMEOUT: u64 = 10; // seconds struct TcpConnectState { smoltcp_handle: Option, @@ -177,6 +177,10 @@ struct TcpConnectState { wait_read: bool, wait_write: bool, expiry: Option<::std::time::Instant>, + udp_socket: Option, + udp_token: Option, + udp_origin_dst: Option, + udp_data_cache: Option>, } pub(crate) trait TcpProxy { @@ -303,6 +307,17 @@ impl<'a> TunToProxy<'a> { .find_map(|(info, state)| if state.token == token { Some(info) } else { None }) } + fn find_info_by_udp_token(&self, token: Token) -> Option<&ConnectionInfo> { + self.connection_map.iter().find_map(|(info, state)| { + if let Some(udp_token) = state.udp_token { + if udp_token == token { + return Some(info); + } + } + None + }) + } + /// Destroy connection state machine fn remove_connection(&mut self, info: &ConnectionInfo) -> Result<(), Error> { if let Some(mut state) = self.connection_map.remove(info) { @@ -316,7 +331,16 @@ impl<'a> TunToProxy<'a> { // FIXME: Does this line should be moved up to the beginning of this function? self.expect_smoltcp_send()?; - _ = self.poll.registry().deregister(&mut state.mio_stream); + if let Err(e) = self.poll.registry().deregister(&mut state.mio_stream) { + // FIXME: The function `deregister` will frequently fail for unknown reasons. + log::debug!("{}", e); + } + + if let Some(mut udp_socket) = state.udp_socket { + if let Err(e) = self.poll.registry().deregister(&mut udp_socket) { + log::debug!("{}", e); + } + } log::info!("Close {}", info); } @@ -442,14 +466,14 @@ impl<'a> TunToProxy<'a> { return Ok(()); } let (info, _first_packet, payload_offset, payload_size) = result?; - let dst = SocketAddr::try_from(&info.dst)?; + let origin_dst = SocketAddr::try_from(&info.dst)?; let connection_info = match &mut self.options.virtual_dns { - None => info.clone(), + None => info, Some(virtual_dns) => { - let dst_ip = dst.ip(); + let dst_ip = origin_dst.ip(); virtual_dns.touch_ip(&dst_ip); match virtual_dns.resolve_ip(&dst_ip) { - None => info.clone(), + None => info, Some(name) => info.to_named(name.clone()), } } @@ -461,15 +485,16 @@ impl<'a> TunToProxy<'a> { if connection_info.protocol == IpProtocol::Tcp { if _first_packet { let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info, false)?; - let state = self.create_new_tcp_connection_state(server_addr, dst, tcp_proxy_handler, false)?; + #[rustfmt::skip] + let state = self.create_new_tcp_connection_state(server_addr, origin_dst, tcp_proxy_handler, false)?; self.connection_map.insert(connection_info.clone(), state); - log::info!("Connect done {} ({})", connection_info, dst); + log::info!("Connect done {} ({})", connection_info, origin_dst); } else if !self.connection_map.contains_key(&connection_info) { - log::debug!("Not found {} ({})", connection_info, dst); + // log::debug!("Drop middle session {} ({})", connection_info, origin_dst); return Ok(()); } else { - log::trace!("Subsequent packet {} ({})", connection_info, dst); + // log::trace!("Subsequent packet {} ({})", connection_info, origin_dst); } // Inject the packet to advance the remote proxy server smoltcp socket state @@ -487,46 +512,51 @@ impl<'a> TunToProxy<'a> { // Therefore, we now expect it to write data to the server. self.write_to_server(&connection_info)?; } else if connection_info.protocol == IpProtocol::Udp { - log::trace!("{} ({})", connection_info, dst); let port = connection_info.dst.port(); let payload = &frame[payload_offset..payload_offset + payload_size]; if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == 53) { + log::info!("DNS query via virtual DNS {} ({})", connection_info, origin_dst); let response = virtual_dns.receive_query(payload)?; - self.send_udp_packet(dst, connection_info.src, response.as_slice())?; + self.send_udp_packet_to_client(origin_dst, connection_info.src, response.as_slice())?; } else { // Another UDP packet if !self.connection_map.contains_key(&connection_info) { - log::trace!("New UDP connection {} ({})", connection_info, dst); + log::info!("UDP associate session {} ({})", connection_info, origin_dst); let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info, true)?; - let state = self.create_new_tcp_connection_state(server_addr, dst, tcp_proxy_handler, true)?; + #[rustfmt::skip] + let mut state = self.create_new_tcp_connection_state(server_addr, origin_dst, tcp_proxy_handler, true)?; + state.udp_origin_dst = Some(origin_dst); self.connection_map.insert(connection_info.clone(), state); + + self.expect_smoltcp_send()?; + self.tunsocket_read_and_forward(&connection_info)?; + self.write_to_server(&connection_info)?; + } else { + // log::trace!("Subsequent udp packet {} ({})", connection_info, origin_dst); } - self.expect_smoltcp_send()?; - self.tunsocket_read_and_forward(&connection_info)?; - self.write_to_server(&connection_info)?; + let err = "udp associate state not find"; + let state = self.connection_map.get_mut(&connection_info).ok_or(err)?; + assert!(state.expiry.is_some()); + state.expiry = Some(Self::udp_associate_timeout()); + // Add SOCKS5 UDP header to the incoming data let mut s5_udp_data = Vec::::new(); UdpHeader::new(0, connection_info.dst.clone()).write_to_stream(&mut s5_udp_data)?; s5_udp_data.extend_from_slice(payload); - let state = self - .connection_map - .get_mut(&connection_info) - .ok_or("udp associate state")?; - assert!(state.expiry.is_some()); - state.expiry = Some(Self::udp_associate_timeout()); - if let Some(udp_associate) = state.tcp_proxy_handler.get_udp_associate() { - log::debug!("UDP associate address: {}", udp_associate); - // Send packets via UDP associate... - // self.send_udp_packet(connection_info.src, udp_associate, &s5_udp_data)?; + // UDP associate session has been established, we can send packets directly... + if let Some(socket) = state.udp_socket.as_ref() { + socket.send_to(&s5_udp_data, udp_associate)?; + } } else { // UDP associate tunnel not ready yet, we must cache the packet... + state.udp_data_cache = Some(s5_udp_data); } } } else { - log::warn!("Unsupported protocol: {} ({})", connection_info, dst); + log::warn!("Unsupported protocol: {} ({})", connection_info, origin_dst); } Ok::<(), Error>(()) }; @@ -561,6 +591,16 @@ impl<'a> TunToProxy<'a> { } else { None }; + + let (udp_socket, udp_token) = if udp_associate { + let addr = (Ipv4Addr::UNSPECIFIED, 0).into(); + let mut socket = UdpSocket::bind(addr)?; + let token = self.new_token(); + self.poll.registry().register(&mut socket, token, Interest::READABLE)?; + (Some(socket), Some(token)) + } else { + (None, None) + }; let state = TcpConnectState { smoltcp_handle: Some(handle), mio_stream: client, @@ -570,6 +610,10 @@ impl<'a> TunToProxy<'a> { wait_read: true, wait_write: false, expiry, + udp_socket, + udp_token, + udp_origin_dst: None, + udp_data_cache: None, }; Ok(state) } @@ -598,7 +642,7 @@ impl<'a> TunToProxy<'a> { Ok(()) } - fn send_udp_packet(&mut self, src: SocketAddr, dst: SocketAddr, data: &[u8]) -> Result<()> { + fn send_udp_packet_to_client(&mut self, src: SocketAddr, dst: SocketAddr, data: &[u8]) -> Result<()> { let rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); let tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); let mut socket = udp::Socket::new(rx_buffer, tx_buffer); @@ -706,6 +750,27 @@ impl<'a> TunToProxy<'a> { } fn mio_socket_event(&mut self, event: &Event) -> Result<(), Error> { + if let Some(info) = self.find_info_by_udp_token(event.token()) { + let info = info.clone(); + let err = "udp connection state not found"; + let state = self.connection_map.get_mut(&info).ok_or(err)?; + state.expiry = Some(Self::udp_associate_timeout()); + if let Some(udp_socket) = state.udp_socket.as_ref() { + let mut buf = [0; 1 << 16]; + // Receive UDP packet from remote SOCKS5 server + let (packet_size, _svr_addr) = udp_socket.recv_from(&mut buf)?; + + let buf = buf[..packet_size].to_vec(); + let header = UdpHeader::retrieve_from_stream(&mut &buf[..])?; + + // Write to client + let src = state.udp_origin_dst.ok_or("udp address")?; + self.send_udp_packet_to_client(src, info.src, &buf[header.len()..])?; + } + + return Ok(()); + } + let conn_info = match self.find_info_by_token(event.token()) { Some(conn_info) => conn_info.clone(), None => { @@ -723,6 +788,7 @@ impl<'a> TunToProxy<'a> { let mut block = || -> Result<(), Error> { if event.is_readable() || event.is_read_closed() { { + let e = "connection state not found"; let state = self.connection_map.get_mut(&conn_info).ok_or(e)?; // TODO: Move this reading process to its own function. @@ -783,6 +849,18 @@ impl<'a> TunToProxy<'a> { // The connection handler could have produced data that is to be written to the // server. self.write_to_server(&conn_info)?; + + // Try to send the first UDP packet to remote SOCKS5 server for UDP associate session + if let Some(state) = self.connection_map.get_mut(&conn_info) { + if let Some(udp_socket) = state.udp_socket.as_ref() { + if let Some(addr) = state.tcp_proxy_handler.get_udp_associate() { + // Take ownership of udp_data_cache + if let Some(buf) = state.udp_data_cache.take() { + udp_socket.send_to(&buf, addr)?; + } + } + } + } } if event.is_writable() {