From 8d3533e327f58e50c7896f91efc225ae204868f2 Mon Sep 17 00:00:00 2001 From: suchao Date: Sun, 20 Oct 2024 17:54:21 +0800 Subject: [PATCH] optimize and fix --- src/bin/udpgw_server.rs | 6 +++++- src/lib.rs | 5 +++++ src/udpgw.rs | 19 +++++++++++++++---- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/bin/udpgw_server.rs b/src/bin/udpgw_server.rs index 97f0643..666c8a8 100644 --- a/src/bin/udpgw_server.rs +++ b/src/bin/udpgw_server.rs @@ -178,7 +178,11 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u } async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, con: &mut UdpRequest) -> Result<()> { - let std_sock = std::net::UdpSocket::bind("0.0.0.0:0")?; + let std_sock = if con.flags & UDPGW_FLAG_IPV6 != 0 { + std::net::UdpSocket::bind("[::]:0")? + } else { + std::net::UdpSocket::bind("0.0.0.0:0")? + }; std_sock.set_nonblocking(true)?; #[cfg(target_os = "linux")] nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?; diff --git a/src/lib.rs b/src/lib.rs index 7ee936d..0ea57e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -579,6 +579,11 @@ async fn handle_udp_gateway_session( server_stream.update_activity(); continue; } + UdpGwResponse::TcpClose => { + log::error!("Ending {} <- {} with tcp closed", udpinfo, &tcp_local_addr); + server_stream.close(); + break; + } UdpGwResponse::Data(data) => { let len = data.len(); log::debug!("{} <- {} receive udpgw len {}", udpinfo, &tcp_local_addr,len); diff --git a/src/udpgw.rs b/src/udpgw.rs index 26d47a5..eab915b 100644 --- a/src/udpgw.rs +++ b/src/udpgw.rs @@ -4,6 +4,7 @@ use std::collections::VecDeque; use std::hash::Hash; use std::mem; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::sync::atomic::Ordering::Relaxed; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; @@ -17,6 +18,8 @@ pub const UDPGW_FLAG_IPV6: u8 = 0x08; pub const UDPGW_FLAG_DOMAIN: u8 = 0x10; pub const UDPGW_FLAG_ERR: u8 = 0x20; +static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); + #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[repr(C)] #[repr(packed(1))] @@ -103,6 +106,7 @@ impl<'a> UdpGwData<'a> { pub(crate) enum UdpGwResponse<'a> { KeepAlive, Error, + TcpClose, Data(UdpGwData<'a>), } @@ -129,6 +133,12 @@ pub(crate) struct UdpGwClientStream { last_activity: std::time::Instant, } +impl Drop for UdpGwClientStream { + fn drop(&mut self) { + TCP_COUNTER.fetch_sub(1, Relaxed); + } +} + impl UdpGwClientStream { pub fn close(&mut self) { self.closed = true; @@ -184,6 +194,7 @@ impl UdpGwClientStream { inner: rx, recv_buf: vec![0; udp_mtu.into()], }; + TCP_COUNTER.fetch_add(1, Relaxed); UdpGwClientStream { local_addr, reader: Some(reader), @@ -232,7 +243,7 @@ impl UdpGwClient { } pub(crate) async fn is_full(&self) -> bool { - self.server_connections.lock().await.len() >= self.max_connections as usize + TCP_COUNTER.load(Relaxed) >= self.max_connections as u32 } pub(crate) async fn get_server_connection(&self) -> Option { @@ -411,7 +422,7 @@ impl UdpGwClient { } }; match result { - Ok(0) => Err(("tcp connection closed").into()), + Ok(0) => Ok(UdpGwResponse::TcpClose), Ok(n) => { if n < std::mem::size_of::() { return Err("received PackLenHeader error".into()); @@ -425,12 +436,12 @@ impl UdpGwClient { while left_len > 0 { if let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await { if len == 0 { - return Err("tcp connection closed".into()); + return Ok(UdpGwResponse::TcpClose); } recv_len += len; left_len -= len; } else { - return Err("tcp connection closed".into()); + return Ok(UdpGwResponse::TcpClose); } } return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, stream);