From be45bc8a22585c0b32b4ea48ce05445f3010429d Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 21 Oct 2024 17:26:35 +0800 Subject: [PATCH] read code --- src/lib.rs | 33 ++++++++++++++---------------- src/udpgw.rs | 58 ++++++++++++++++++++++------------------------------ 2 files changed, 40 insertions(+), 51 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7206478..b8bdab8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -240,24 +240,21 @@ where let mut ip_stack = ipstack::IpStack::new(ipstack_config, device); #[cfg(feature = "udpgw")] - let udpgw_client = match args.udpgw_server { - None => None, - Some(addr) => { - log::info!("UDPGW enabled"); - let client = Arc::new(UdpGwClient::new( - mtu, - args.udpgw_max_connections.unwrap_or(100), - UDPGW_KEEPALIVE_TIME, - args.udp_timeout, - addr, - )); - let client_keepalive = client.clone(); - tokio::spawn(async move { - client_keepalive.heartbeat_task().await; - }); - Some(client) - } - }; + let udpgw_client = args.udpgw_server.as_ref().map(|addr| { + log::info!("UDPGW enabled"); + let client = Arc::new(UdpGwClient::new( + mtu, + args.udpgw_max_connections.unwrap_or(100), + UDPGW_KEEPALIVE_TIME, + args.udp_timeout, + *addr, + )); + let client_keepalive = client.clone(); + tokio::spawn(async move { + client_keepalive.heartbeat_task().await; + }); + client + }); loop { let virtual_dns = virtual_dns.clone(); diff --git a/src/udpgw.rs b/src/udpgw.rs index bf546e1..45a980d 100644 --- a/src/udpgw.rs +++ b/src/udpgw.rs @@ -446,44 +446,36 @@ impl UdpGwClient { /// # Returns /// - `Result`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs. pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut UdpGwClientStreamReader) -> Result { - let result = match tokio::time::timeout( + let result = tokio::time::timeout( tokio::time::Duration::from_secs(udp_timeout + 2), stream.inner.read(&mut stream.recv_buf[..2]), ) .await - { - Ok(ret) => ret, - Err(_e) => { - return Err(("wait tcp data timeout").into()); - } - }; - match result { - Ok(0) => Ok(UdpGwResponse::TcpClose), - Ok(n) => { - if n < std::mem::size_of::() { - return Err("received PackLenHeader error".into()); - } - let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]); - if packet_len > udp_mtu { - return Err("packet too long".into()); - } - let mut left_len: usize = packet_len as usize; - let mut recv_len = 0; - while left_len > 0 { - if let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await { - if len == 0 { - return Ok(UdpGwResponse::TcpClose); - } - recv_len += len; - left_len -= len; - } else { - return Ok(UdpGwResponse::TcpClose); - } - } - return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, stream); - } - Err(_) => Err("tcp read error".into()), + .map_err(std::io::Error::from)?; + let n = result?; + if n == 0 { + return Ok(UdpGwResponse::TcpClose); } + if n < std::mem::size_of::() { + return Err("received PackLenHeader error".into()); + } + let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]); + if packet_len > udp_mtu { + return Err("packet too long".into()); + } + let mut left_len: usize = packet_len as usize; + let mut recv_len = 0; + while left_len > 0 { + let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await else { + return Ok(UdpGwResponse::TcpClose); + }; + if len == 0 { + return Ok(UdpGwResponse::TcpClose); + } + recv_len += len; + left_len -= len; + } + UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, stream) } /// Sends a UDP gateway packet.