From 52d814ce79ddd1ae2f36d7a1b1a4a2a74ebb1481 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Thu, 31 Oct 2024 10:28:12 +0800 Subject: [PATCH] refine udpgw --- src/args.rs | 6 +-- src/lib.rs | 11 ++-- src/udpgw.rs | 150 +++++++++++++++++++++++++-------------------------- 3 files changed, 84 insertions(+), 83 deletions(-) diff --git a/src/args.rs b/src/args.rs index 4e71166..919c673 100644 --- a/src/args.rs +++ b/src/args.rs @@ -116,10 +116,10 @@ pub struct Args { #[arg(long, value_name = "IP:PORT")] pub udpgw_server: Option, - /// Max udpgw connections, default value is 100 + /// Max udpgw connections, default value is 5 #[cfg(feature = "udpgw")] #[arg(long, value_name = "number", requires = "udpgw_server")] - pub udpgw_max_connections: Option, + pub udpgw_max_connections: Option, } fn validate_tun(p: &str) -> Result { @@ -201,7 +201,7 @@ impl Args { } #[cfg(feature = "udpgw")] - pub fn udpgw_max_connections(&mut self, udpgw_max_connections: u16) -> &mut Self { + pub fn udpgw_max_connections(&mut self, udpgw_max_connections: usize) -> &mut Self { self.udpgw_max_connections = Some(udpgw_max_connections); self } diff --git a/src/lib.rs b/src/lib.rs index ae4621c..41b901a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -506,7 +506,7 @@ async fn handle_udp_gateway_session( } } None => { - if udpgw_client.is_full() { + if !udpgw_client.is_in_heartbeat_progress() && udpgw_client.is_full().await { return Err("max udpgw connection limit reached".into()); } let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?; @@ -543,13 +543,13 @@ async fn handle_udp_gateway_session( } Ok(n) => n, Err(e) => { - log::info!("[UdpGw] Ending stream {} {} <> {} with recv_udp_packet {}", sn, &tcp_local_addr, udp_dst, e); + log::info!("[UdpGw] Ending stream {} {} <> {} with udp stack \"{}\"", sn, &tcp_local_addr, udp_dst, e); break; } }; crate::traffic_status::traffic_status_update(read_len, 0)?; - let new_id = stream.new_packet_id(); - if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], udp_dst, new_id, &mut writer).await { + let sn = stream.serial_number(); + if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], udp_dst, sn, &mut writer).await { log::info!("[UdpGw] Ending stream {} {} <> {} with send_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e); break; } @@ -584,7 +584,8 @@ async fn handle_udp_gateway_session( UdpGwResponse::Data(data) => { use socks5_impl::protocol::StreamOperation; let len = data.len(); - log::debug!("[UdpGw] stream {} {} <- {} receive len {}", sn, &tcp_local_addr, udp_dst, len); + let f = data.header.flags; + log::debug!("[UdpGw] stream {sn} {} <- {} receive {f} len {len}", &tcp_local_addr, udp_dst); if let Err(e) = udp_stack.write_all(&data.data).await { log::error!("[UdpGw] Ending stream {} {} <> {} with send_udp_packet {}", sn, &tcp_local_addr, udp_dst, e); break; diff --git a/src/udpgw.rs b/src/udpgw.rs index 295be25..9a2f5eb 100644 --- a/src/udpgw.rs +++ b/src/udpgw.rs @@ -12,7 +12,7 @@ use tokio::{ }; pub(crate) const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::(); -pub(crate) const UDPGW_MAX_CONNECTIONS: u16 = 100; +pub(crate) const UDPGW_MAX_CONNECTIONS: usize = 5; pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -34,7 +34,7 @@ impl std::fmt::Display for UdpFlag { 0x02 => "DATA", n => return write!(f, "Unknown UdpFlag(0x{:02X})", n), }; - write!(f, "UdpFlag({})", flag) + write!(f, "{}", flag) } } @@ -52,8 +52,6 @@ impl std::ops::BitOr for UdpFlag { } } -static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); - /// UDP Gateway Packet Format /// /// The format is referenced from SOCKS5 packet format, with additional flags and connection ID fields. @@ -250,8 +248,7 @@ pub struct UdpgwHeader { impl std::fmt::Display for UdpgwHeader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let id = self.conn_id; - write!(f, "flags: {}, conn_id: {}", self.flags, id) + write!(f, "{} conn_id: {}", self.flags, self.conn_id) } } @@ -329,23 +326,27 @@ pub(crate) enum UdpGwResponse { Data(Packet), } -static SERIAL_NUMBER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1); +impl std::fmt::Display for UdpGwResponse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UdpGwResponse::KeepAlive => write!(f, "KeepAlive"), + UdpGwResponse::Error => write!(f, "Error"), + UdpGwResponse::TcpClose => write!(f, "TcpClose"), + UdpGwResponse::Data(packet) => write!(f, "Data({})", packet), + } + } +} + +static SERIAL_NUMBER: std::sync::atomic::AtomicU16 = std::sync::atomic::AtomicU16::new(1); #[derive(Debug)] pub(crate) struct UdpGwClientStream { local_addr: SocketAddr, writer: Option, reader: Option, - conn_id: u16, closed: bool, last_activity: std::time::Instant, - serial_number: u64, -} - -impl Drop for UdpGwClientStream { - fn drop(&mut self) { - TCP_COUNTER.fetch_sub(1, Relaxed); - } + serial_number: u16, } impl UdpGwClientStream { @@ -381,20 +382,14 @@ impl UdpGwClientStream { self.closed } - pub fn serial_number(&self) -> u64 { + pub fn serial_number(&self) -> u16 { self.serial_number } - pub fn new_packet_id(&mut self) -> u16 { - self.conn_id += 1; - self.conn_id - } - pub fn new(tcp_server_stream: TcpStream) -> Self { let default = "0.0.0.0:0".parse::().unwrap(); let local_addr = tcp_server_stream.local_addr().unwrap_or(default); let (reader, writer) = tcp_server_stream.into_split(); - TCP_COUNTER.fetch_add(1, Relaxed); let serial_number = SERIAL_NUMBER.fetch_add(1, Relaxed); UdpGwClientStream { local_addr, @@ -402,7 +397,6 @@ impl UdpGwClientStream { writer: Some(writer), last_activity: std::time::Instant::now(), closed: false, - conn_id: 0, serial_number, } } @@ -411,16 +405,17 @@ impl UdpGwClientStream { #[derive(Debug)] pub(crate) struct UdpGwClient { udp_mtu: u16, - max_connections: u16, + max_connections: usize, udp_timeout: u64, keepalive_time: Duration, udpgw_server: SocketAddr, server_connections: Mutex>, + is_in_heartbeat: std::sync::atomic::AtomicBool, } impl UdpGwClient { - pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, udpgw_server: SocketAddr) -> Self { - let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize)); + pub fn new(udp_mtu: u16, max_connections: usize, keepalive_time: Duration, udp_timeout: u64, udpgw_server: SocketAddr) -> Self { + let server_connections = Mutex::new(VecDeque::with_capacity(max_connections)); UdpGwClient { udp_mtu, max_connections, @@ -428,6 +423,7 @@ impl UdpGwClient { udpgw_server, keepalive_time, server_connections, + is_in_heartbeat: std::sync::atomic::AtomicBool::new(false), } } @@ -439,8 +435,8 @@ impl UdpGwClient { self.udp_timeout } - pub(crate) fn is_full(&self) -> bool { - TCP_COUNTER.load(Relaxed) >= self.max_connections as u32 + pub(crate) async fn is_full(&self) -> bool { + self.server_connections.lock().await.len() >= self.max_connections } pub(crate) async fn pop_server_connection_from_queue(&self) -> Option { @@ -448,13 +444,13 @@ impl UdpGwClient { } pub(crate) async fn store_server_connection(&self, stream: UdpGwClientStream) { - if self.server_connections.lock().await.len() < self.max_connections as usize { + if self.server_connections.lock().await.len() < self.max_connections { self.server_connections.lock().await.push_back(stream); } } pub(crate) async fn store_server_connection_full(&self, mut stream: UdpGwClientStream, reader: OwnedReadHalf, writer: OwnedWriteHalf) { - if self.server_connections.lock().await.len() < self.max_connections as usize { + if self.server_connections.lock().await.len() < self.max_connections { stream.set_reader(Some(reader)); stream.set_writer(Some(writer)); self.server_connections.lock().await.push_back(stream); @@ -465,54 +461,59 @@ impl UdpGwClient { self.udpgw_server } + pub(crate) fn is_in_heartbeat_progress(&self) -> bool { + self.is_in_heartbeat.load(Relaxed) + } + /// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection. pub(crate) async fn heartbeat_task(&self) -> std::io::Result<()> { loop { + self.is_in_heartbeat.store(false, Relaxed); sleep(self.keepalive_time).await; - let Some(mut stream) = self.pop_server_connection_from_queue().await else { - continue; - }; + self.is_in_heartbeat.store(true, Relaxed); + let mut streams = Vec::new(); - if stream.is_closed() { - // This stream will be dropped - continue; - } - - if stream.last_activity.elapsed() < self.keepalive_time { - self.store_server_connection(stream).await; - continue; - } - - let Some(mut stream_reader) = stream.get_reader() else { - continue; - }; - - let Some(mut stream_writer) = stream.get_writer() else { - continue; - }; - let local_addr = stream_writer.local_addr()?; - let sn = stream.serial_number(); - log::trace!("stream {} {:?} send keepalive", sn, local_addr); - let keepalive_packet: Vec = Packet::build_keepalive_packet(stream.new_packet_id()).into(); - if let Err(e) = stream_writer.write_all(&keepalive_packet).await { - log::warn!("stream {} {:?} send keepalive failed: {}", sn, local_addr, e); - continue; - } - match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { - Ok(UdpGwResponse::KeepAlive) => { - stream.update_activity(); - self.store_server_connection_full(stream, stream_reader, stream_writer).await; - log::trace!("stream {} {:?} keepalive success", sn, local_addr); + while let Some(stream) = self.pop_server_connection_from_queue().await { + if !stream.is_closed() { + streams.push(stream); + } + } + + for mut stream in streams { + if stream.last_activity.elapsed() < self.keepalive_time { + self.store_server_connection(stream).await; + continue; + } + + let Some(mut stream_reader) = stream.get_reader() else { + continue; + }; + + let Some(mut stream_writer) = stream.get_writer() else { + continue; + }; + let local_addr = stream_writer.local_addr()?; + let sn = stream.serial_number(); + let keepalive_packet: Vec = Packet::build_keepalive_packet(sn).into(); + if let Err(e) = stream_writer.write_all(&keepalive_packet).await { + log::warn!("stream {} {:?} send keepalive failed: {}", sn, local_addr, e); + continue; + } + match UdpGwClient::recv_udpgw_packet(self.udp_mtu, self.udp_timeout, &mut stream_reader).await { + Ok(UdpGwResponse::KeepAlive) => { + stream.update_activity(); + self.store_server_connection_full(stream, stream_reader, stream_writer).await; + log::trace!("stream {sn} {:?} send keepalive and recieve it successfully", local_addr); + } + Ok(v) => log::debug!("stream {sn} {:?} keepalive unexpected response: {v}", local_addr), + Err(e) => log::debug!("stream {sn} {:?} keepalive no response, error \"{e}\"", local_addr), } - Ok(v) => log::warn!("stream {} {:?} keepalive unexpected response: {:?}", sn, local_addr, v), - Err(e) => log::warn!("stream {} {:?} keepalive no response, error \"{}\"", sn, local_addr, e), } } } /// Parses the UDP response data. - pub(crate) fn parse_udp_response(udp_mtu: u16, data: &[u8]) -> Result { - let packet = Packet::try_from(data)?; + pub(crate) fn parse_udp_response(udp_mtu: u16, packet: Packet) -> Result { let flags = packet.header.flags; if flags & UdpFlag::ERR == UdpFlag::ERR { return Ok(UdpGwResponse::Error); @@ -538,14 +539,13 @@ impl UdpGwClient { /// # Returns /// - `Result`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs. pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut OwnedReadHalf) -> Result { - let mut data = vec![0; udp_mtu.into()]; - let data_len = tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout + 2), stream.read(&mut data)) - .await - .map_err(std::io::Error::from)??; - if data_len == 0 { - return Ok(UdpGwResponse::TcpClose); - } - UdpGwClient::parse_udp_response(udp_mtu, &data[..data_len]) + let packet = tokio::time::timeout( + tokio::time::Duration::from_secs(udp_timeout + 2), + Packet::retrieve_from_async_stream(stream), + ) + .await + .map_err(std::io::Error::from)??; + UdpGwClient::parse_udp_response(udp_mtu, packet) } /// Sends a UDP gateway packet.