From 0a833d69a657d0f695a8d5d410781e60ebb201c7 Mon Sep 17 00:00:00 2001 From: suchao Date: Sat, 19 Oct 2024 13:40:35 +0800 Subject: [PATCH] support udp gateway mode --- src/args.rs | 13 ++- src/bin/udpgw_server.rs | 128 ++++++++++++---------- src/error.rs | 3 - src/lib.rs | 137 +++++++++++++----------- src/udpgw.rs | 230 ++++++++++++++++++++++++++-------------- 5 files changed, 309 insertions(+), 202 deletions(-) diff --git a/src/args.rs b/src/args.rs index ec839fb..ce4a4e7 100644 --- a/src/args.rs +++ b/src/args.rs @@ -70,14 +70,18 @@ pub struct Args { #[arg(short, long, default_value = if cfg!(target_os = "linux") { "false" } else { "true" })] pub setup: bool, - /// DNS handling strategy - #[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")] - pub dns: ArgDns, - /// UDP gateway address #[arg(long, value_name = "IP:PORT")] pub udpgw_bind_addr: Option, + /// Max udpgw connections + #[arg(long, value_name = "number", default_value = "100")] + pub max_udpgw_connections: u16, + + /// DNS handling strategy + #[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")] + pub dns: ArgDns, + /// DNS resolver address #[arg(long, value_name = "IP", default_value = "8.8.8.8")] pub dns_addr: IpAddr, @@ -149,6 +153,7 @@ impl Default for Args { ipv6_enabled: false, setup, udpgw_bind_addr: None, + max_udpgw_connections: 100, dns: ArgDns::default(), dns_addr: "8.8.8.8".parse().unwrap(), bypass: vec![], diff --git a/src/bin/udpgw_server.rs b/src/bin/udpgw_server.rs index 45671ec..7652209 100644 --- a/src/bin/udpgw_server.rs +++ b/src/bin/udpgw_server.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::mem; use std::net::Ipv4Addr; use std::net::SocketAddr; @@ -11,25 +10,24 @@ use tokio::net::TcpListener; use tokio::net::UdpSocket; use tokio::sync::mpsc; use tokio::sync::mpsc::Sender; -use tokio::sync::Mutex; pub use tun2proxy::udpgw::*; use tun2proxy::ArgVerbosity; use tun2proxy::Result; -pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60); +pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(30); #[derive(Debug)] -struct Connection { +struct UdpRequest { flags: u8, server_addr: SocketAddr, conid: u16, data: Vec, } +#[derive(Debug)] struct Client { #[allow(dead_code)] addr: SocketAddr, buf: Vec, - connections: Arc>>, last_activity: std::time::Instant, } @@ -43,6 +41,10 @@ pub struct UdpGwArgs { #[arg(short, long, value_name = "level", value_enum, default_value = "info")] pub verbosity: ArgVerbosity, + /// Daemonize for unix family or run as Windows service + #[arg(long)] + pub daemonize: bool, + /// UDP timeout in seconds #[arg(long, value_name = "seconds", default_value = "3")] pub udp_timeout: u64, @@ -56,11 +58,10 @@ impl UdpGwArgs { #[allow(clippy::let_and_return)] pub fn parse_args() -> Self { use clap::Parser; - let args = Self::parse(); - args + Self::parse() } } -async fn send_error_response(tx: Sender>, con: &mut Connection) { +async fn send_error(tx: Sender>, con: &mut UdpRequest) { let mut error_packet = vec![]; error_packet.extend_from_slice(&(std::mem::size_of::() as u16).to_le_bytes()); error_packet.extend_from_slice(&[UDPGW_FLAG_ERR]); @@ -70,7 +71,13 @@ async fn send_error_response(tx: Sender>, con: &mut Connection) { } } -pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> { +async fn send_keepalive_response(tx: Sender>, keepalive_packet: &[u8]) { + if let Err(e) = tx.send(keepalive_packet.to_vec()).await { + log::error!("send keepalive response error {:?}", e); + } +} + +pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> { if data_len < mem::size_of::() { return Err("Invalid udpgw data".into()); } @@ -85,10 +92,9 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result< // keepalive if flags & UDPGW_FLAG_KEEPALIVE != 0 { - return Ok((data, UDPGW_FLAG_KEEPALIVE, 0, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); + return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); } - // parse address let ip_data = &data[mem::size_of::()..]; let mut data_len = data_len - mem::size_of::(); // port_len + min(ipv4/ipv6/(domain_len + 1)) @@ -107,7 +113,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result< .to_socket_addrs()? .next() .ok_or(format!("Invalid address {}", target_str))?; - // check payload length if data_len < 2 + domain.len() { return Err("Invalid udpgw data".into()); } @@ -136,7 +141,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result< }; data_len -= mem::size_of::(); - // check payload length if data_len > udp_mtu as usize { return Err("too much data".into()); } @@ -157,7 +161,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result< }; data_len -= mem::size_of::(); - // check payload length if data_len > udp_mtu as usize { return Err("too much data".into()); } @@ -171,15 +174,16 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result< } } -async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, con: &mut Connection) -> Result<()> { +async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, con: &mut UdpRequest) -> Result<()> { let std_sock = std::net::UdpSocket::bind("0.0.0.0:0")?; std_sock.set_nonblocking(true)?; nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?; let socket = UdpSocket::from_std(std_sock)?; socket.send_to(&con.data, &con.server_addr).await?; con.data.resize(2048, 0); - match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data[..])).await? { - Ok((len, _addr)) => { + match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data)).await { + Ok(ret) => { + let (len, _addr) = ret?; let mut packet = vec![]; let mut pack_len = mem::size_of::() + len; match con.server_addr.into() { @@ -203,17 +207,17 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, co } } if let Err(e) = tx.send(packet).await { - log::error!("client {} send udp response error {:?}", addr, e); + log::error!("client {} send udp response {}", addr, e); } } Err(e) => { - log::error!("client {} udp recv_from error: {:?}", addr, e); + log::warn!("client {} udp recv_from {}", addr, e); } } Ok(()) } -async fn process_client_udp_req<'a>(args: Arc, tx: Sender>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) { +async fn process_client_udp_req<'a>(args: &UdpGwArgs, tx: Sender>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) { let mut buf = vec![0; args.udp_mtu as usize]; let mut len_buf = [0; mem::size_of::()]; let udp_mtu = args.udp_mtu; @@ -226,7 +230,7 @@ async fn process_client_udp_req<'a>(args: Arc, tx: Sender>, m } Err(_e) => { if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT { - log::warn!("client {} last_activity elapsed", client.addr); + log::debug!("client {} last_activity elapsed", client.addr); return; } continue; @@ -244,8 +248,7 @@ async fn process_client_udp_req<'a>(args: Arc, tx: Sender>, m log::error!("client {} received packet too long", client.addr); break; } - log::info!("client {} recvied packet len {}", client.addr, packet_len); - buf.resize(packet_len as usize, 0); + log::debug!("client {} recvied packet len {}", client.addr, packet_len); client.buf.clear(); let mut left_len: usize = packet_len as usize; while left_len > 0 { @@ -260,48 +263,37 @@ async fn process_client_udp_req<'a>(args: Arc, tx: Sender>, m } } client.last_activity = std::time::Instant::now(); - let ret = parse_udp_req_data(udp_mtu, client.buf.len(), &client.buf); + let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf); if let Ok((udpdata, flags, conid, reqaddr)) = ret { if flags & UDPGW_FLAG_KEEPALIVE != 0 { - log::debug!("client {} recvied keepalive packet", client.addr); + log::debug!("client {} send keepalive", client.addr); + send_keepalive_response(tx.clone(), udpdata).await; continue; } log::debug!( - "client {} recvied udp data,flags:{},conid:{},addr:{:?},data len:{}", + "client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}", client.addr, flags, conid, reqaddr, udpdata.len() ); - let mut con_lock = client.connections.lock().await; - let con = con_lock.get_mut(&conid); - if let Some(conn) = con { - conn.data.clear(); - conn.data.extend_from_slice(udpdata); - if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), conn).await { - log::error!("client {} process_udp error: {:?}", client.addr, e); - send_error_response(tx.clone(), conn).await; - continue; + let mut req = UdpRequest { + server_addr: reqaddr, + conid, + flags, + data: udpdata.to_vec(), + }; + let tx1 = tx.clone(); + let tx2 = tx.clone(); + tokio::spawn(async move { + if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await { + send_error(tx2, &mut req).await; + log::error!("client {} process_udp {}", client.addr, e); } - } else { - drop(con_lock); - let mut conn = Connection { - server_addr: reqaddr, - conid, - flags, - data: udpdata.to_vec(), - }; - if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), &mut conn).await { - send_error_response(tx.clone(), &mut conn).await; - log::error!("client {} process_udp error: {:?}", client.addr, e); - continue; - } - client.connections.lock().await.insert(conid, conn); - } + }); } else { log::error!("client {} parse_udp_data {:?}", client.addr, ret.err()); - continue; } } Err(_) => { @@ -318,32 +310,52 @@ async fn main() -> Result<()> { let tcp_listener = TcpListener::bind(args.listen_addr).await?; - log::info!("UDP GW Server started"); - - let default = format!("{:?},hickory_proto=warn", args.verbosity); + let default = format!("{:?}", args.verbosity); env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init(); + log::info!("UDP GW Server started"); + + #[cfg(unix)] + if args.daemonize { + let stdout = std::fs::File::create("/tmp/udpgw.out")?; + let stderr = std::fs::File::create("/tmp/udpgw.err")?; + let daemonize = daemonize::Daemonize::new() + .working_directory("/tmp") + .umask(0o777) + .stdout(stdout) + .stderr(stderr) + .privileged_action(|| "Executed before drop privileges"); + let _ = daemonize + .start() + .map_err(|e| format!("Failed to daemonize process, error:{:?}", e))?; + } + + #[cfg(target_os = "windows")] + if args.daemonize { + tun2proxy::win_svc::start_service()?; + return Ok(()); + } + loop { let (mut tcp_stream, addr) = tcp_listener.accept().await?; let client = Client { addr, buf: vec![], - connections: Arc::new(Mutex::new(HashMap::new())), last_activity: std::time::Instant::now(), }; log::info!("client {} connected", addr); - let params = args.clone(); + let params = Arc::clone(&args); tokio::spawn(async move { let (tx, mut rx) = mpsc::channel::>(100); let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split(); tokio::select! { - _ = process_client_udp_req(params, tx, client, tcp_read_stream) =>{} + _ = process_client_udp_req(¶ms, tx, client, tcp_read_stream) =>{} _ = async { loop { if let Some(udp_response) = rx.recv().await { - log::info!("client {} send udp data len:{}", addr, udp_response.len(),); + log::debug!("send udp_response len {}",udp_response.len()); let _ = tcp_write_stream.write(&udp_response).await; } } diff --git a/src/error.rs b/src/error.rs index a26c74f..f460b62 100644 --- a/src/error.rs +++ b/src/error.rs @@ -47,9 +47,6 @@ pub enum Error { #[cfg(target_os = "linux")] #[error("bincode::Error {0:?}")] BincodeError(#[from] bincode::Error), - - #[error("tokio::time::error::Elapsed")] - Timeout(#[from] tokio::time::error::Elapsed), } impl From<&str> for Error { diff --git a/src/lib.rs b/src/lib.rs index 2659896..e5c9998 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ use tokio::{ pub use tokio_util::sync::CancellationToken; use tproxy_config::is_private_ip; use udp_stream::UdpStream; -use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS}; +use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME}; pub use { args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType}, @@ -238,7 +238,7 @@ where None => None, Some(addr) => { log::info!("UDPGW enabled"); - let client = Arc::new(UdpGwClient::new(mtu, UDPGW_MAX_CONNECTIONS, UDPGW_KEEPALIVE_TIME, addr)); + let client = Arc::new(UdpGwClient::new(mtu, args.max_udpgw_connections, UDPGW_KEEPALIVE_TIME, args.udp_timeout, addr)); let client_keepalive = client.clone(); tokio::spawn(async move { client_keepalive.heartbeat_task().await; @@ -485,6 +485,7 @@ async fn handle_udp_gateway_session( }; let udpinfo = SessionInfo::new(udp_stack.local_addr(), udp_stack.peer_addr(), IpProtocol::Udp); let udp_mtu = udpgw_client.get_udp_mtu(); + let udp_timeout = udpgw_client.get_udp_timeout(); let mut server_stream: UdpGwClientStream; let server = udpgw_client.get_server_connection().await; match server { @@ -492,10 +493,12 @@ async fn handle_udp_gateway_session( server_stream = server; } None => { - log::info!("Beginning {}", session_info); + if udpgw_client.is_full().await { + return Err("max udpgw connection limit reached".into()); + } let mut tcp_server_stream = create_tcp_stream(&socket_queue, server_addr).await?; if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await { - return Err(e); + return Err(format!("udpgw connection error: {}",e).into()); } server_stream = UdpGwClientStream::new(udp_mtu, tcp_server_stream); } @@ -503,77 +506,93 @@ async fn handle_udp_gateway_session( let udp_server_addr = udp_stack.peer_addr(); + let tcp_local_addr = server_stream.local_addr().clone(); + match domain_name { Some(ref d) => { - log::info!("Beginning {}, domain:{}", udpinfo, d); + log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d); } None => { - log::info!("Beginning {}", udpinfo); + log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr); } } - log::info!("Beginning {}", udpinfo); + let Some(mut stream_reader) = server_stream.get_reader() else { + return Err("get reader failed".into()); + }; + + let Some(mut stream_writer) = server_stream.get_writer() else { + return Err("get writer failed".into()); + }; loop { - let len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut server_stream).await; - let read_len; - match len { - Ok(n) => { - if n == 0 { - log::info!("Ending {}", udpinfo); - break; - } - read_len = n; - crate::traffic_status::traffic_status_update(n, 0)?; - } - Err(e) => { - log::info!("Ending {} with recv_udp_packet error: {}", udpinfo, e); - break; - } - } - - if let Err(e) = - UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), &mut server_stream).await - { - log::info!( - "{:?},Ending {} with send_udpgw_packet error: {}", - server_stream.local_addr(), - udpinfo, - e - ); - break; - } - - match UdpGwClient::recv_udpgw_packet(udp_mtu, &mut server_stream).await { - Ok(packet) => match packet { - //should not received keepalive - UdpGwResponse::KeepAlive => { - log::error!("Ending {} with recv keepalive", udpinfo); - let _ = server_stream.close().await; - break; - } - UdpGwResponse::Error => { - log::info!("Ending {} with recv udp error", udpinfo); - continue; - } - UdpGwResponse::Data(data) => { - crate::traffic_status::traffic_status_update(0, data.len())?; - - if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await { - log::info!("Ending {} with send_udp_packet error: {}", udpinfo, e); + tokio::select! { + len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut stream_writer) => { + let read_len; + match len { + Ok(n) => { + if n == 0 { + log::info!("Ending {} <- {}",udpinfo, &tcp_local_addr); + break; + } + read_len = n; + crate::traffic_status::traffic_status_update(n, 0)?; + } + Err(e) => { + log::info!("Ending {} <- {} with recv_udp_packet {}", udpinfo, &tcp_local_addr, e); break; } } - }, - Err(e) => { - log::info!("Ending {} with recv_udpgw_packet error: {}", udpinfo, e); - break; + let newid = server_stream.newid(); + if let Err(e) = + UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(),newid,&mut stream_writer).await + { + log::info!( + "Ending {} <- {} with send_udpgw_packet {}", + udpinfo, + &tcp_local_addr, + e + ); + break; + } + server_stream.update_activity(); + } + ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut stream_reader) => { + match ret { + Ok(packet) => match packet { + //should not received keepalive + UdpGwResponse::KeepAlive => { + log::error!("Ending {} <- {} with recv keepalive", udpinfo, &tcp_local_addr); + server_stream.close(); + break; + } + //server udp may be timeout,can continue to receive udp data? + UdpGwResponse::Error => { + log::info!("Ending {} <- {} with recv udp error", udpinfo, &tcp_local_addr); + server_stream.update_activity(); + continue; + } + UdpGwResponse::Data(data) => { + let len = data.len(); + if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await { + log::error!("Ending {} <- {} with send_udp_packet {}", udpinfo, &tcp_local_addr, e); + break; + } + crate::traffic_status::traffic_status_update(0, len)?; + } + }, + Err(e) => { + log::warn!("Ending {} <- {} with recv_udpgw_packet {}", udpinfo, &tcp_local_addr, e); + break; + } + } + server_stream.update_activity(); } } } - if !server_stream.is_closed() { - udpgw_client.release_server_connection(server_stream).await; + if !server_stream.is_closed() { + udpgw_client.release_server_connection_with_stream(server_stream,stream_reader,stream_writer).await; } Ok(()) diff --git a/src/udpgw.rs b/src/udpgw.rs index 3b12f2f..d6989dd 100644 --- a/src/udpgw.rs +++ b/src/udpgw.rs @@ -4,9 +4,8 @@ use std::collections::VecDeque; use std::hash::Hash; use std::mem; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; use tokio::sync::Mutex; use tokio::time::{sleep, Duration}; @@ -108,45 +107,56 @@ pub(crate) enum UdpGwResponse<'a> { } #[derive(Debug)] -pub(crate) struct UdpGwClientStream { - inner: TcpStream, - conid: u16, +pub(crate) struct UdpGwClientStreamWriter { + inner: OwnedWriteHalf, tmp_buf: Vec, send_buf: Vec, +} + +#[derive(Debug)] +pub(crate) struct UdpGwClientStreamReader { + inner: OwnedReadHalf, recv_buf: Vec, +} + +#[derive(Debug)] +pub(crate) struct UdpGwClientStream { + local_addr: String, + writer: Option, + reader: Option, + conid: u16, closed: bool, last_activity: std::time::Instant, } -impl AsyncWrite for UdpGwClientStream { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - Pin::new(&mut self.inner).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) - } -} - -impl AsyncRead for UdpGwClientStream { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_read(cx, buf) - } -} - impl UdpGwClientStream { - pub async fn close(&mut self) -> Result<()> { - self.inner.shutdown().await?; + pub fn close(&mut self) { self.closed = true; - Ok(()) } - pub fn local_addr(&self) -> Result { - Ok(self.inner.local_addr()?) + pub fn get_reader(&mut self) -> Option { + self.reader.take() } + + pub fn set_reader(&mut self, mut reader: Option) { + self.reader = reader.take(); + } + + pub fn set_writer(&mut self, mut writer: Option) { + self.writer = writer.take(); + } + + pub fn get_writer(&mut self) -> Option { + self.writer.take() + } + + pub fn local_addr(&self) -> &String { + &self.local_addr + } + + pub fn update_activity(&mut self) { + self.last_activity = std::time::Instant::now(); + } + pub fn is_closed(&mut self) -> bool { self.closed } @@ -156,16 +166,28 @@ impl UdpGwClientStream { } pub fn newid(&mut self) -> u16 { - let next = self.conid; self.conid += 1; - return next; + self.conid } pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self { - UdpGwClientStream { - inner: tcp_server_stream, + let local_addr = tcp_server_stream + .local_addr() + .unwrap_or_else(|_| "0.0.0.0:0".parse::().unwrap()) + .to_string(); + let (rx, tx) = tcp_server_stream.into_split(); + let writer = UdpGwClientStreamWriter { + inner: tx, tmp_buf: vec![0; udp_mtu.into()], send_buf: vec![0; udp_mtu.into()], + }; + let reader = UdpGwClientStreamReader { + inner: rx, recv_buf: vec![0; udp_mtu.into()], + }; + UdpGwClientStream { + local_addr, + reader: Some(reader), + writer: Some(writer), last_activity: std::time::Instant::now(), closed: false, conid: 0, @@ -176,7 +198,8 @@ impl UdpGwClientStream { #[derive(Debug)] pub(crate) struct UdpGwClient { udp_mtu: u16, - max_connections: usize, + max_connections: u16, + udp_timeout: u64, keepalive_time: Duration, udpgw_bind_addr: SocketAddr, keepalive_packet: Vec, @@ -184,18 +207,19 @@ pub(crate) struct UdpGwClient { } impl UdpGwClient { - pub fn new(udp_mtu: u16, max_connections: usize, keepalive_time: Duration, udpgw_bind_addr: SocketAddr) -> Self { + pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, udpgw_bind_addr: SocketAddr) -> Self { let mut keepalive_packet = vec![]; keepalive_packet.extend_from_slice(&(std::mem::size_of::() as u16).to_le_bytes()); keepalive_packet.extend_from_slice(&[UDPGW_FLAG_KEEPALIVE, 0, 0]); - let server_connections = Mutex::new(VecDeque::new()); + let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize)); return UdpGwClient { udp_mtu, max_connections, + udp_timeout, udpgw_bind_addr, keepalive_time, keepalive_packet, - server_connections: server_connections, + server_connections, }; } @@ -203,12 +227,33 @@ impl UdpGwClient { self.udp_mtu } + pub(crate) fn get_udp_timeout(&self) -> u64 { + self.udp_timeout + } + + pub(crate) async fn is_full(&self) -> bool { + self.server_connections.lock().await.len() >= self.max_connections as usize + } + pub(crate) async fn get_server_connection(&self) -> Option { self.server_connections.lock().await.pop_front() } pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) { - if self.server_connections.lock().await.len() < self.max_connections { + if self.server_connections.lock().await.len() < self.max_connections as usize { + self.server_connections.lock().await.push_back(stream); + } + } + + pub(crate) async fn release_server_connection_with_stream( + &self, + mut stream: UdpGwClientStream, + reader: UdpGwClientStreamReader, + writer: UdpGwClientStreamWriter, + ) { + if self.server_connections.lock().await.len() < self.max_connections as usize { + stream.set_reader(Some(reader)); + stream.set_writer(Some(writer)); self.server_connections.lock().await.push_back(stream); } } @@ -217,6 +262,7 @@ impl UdpGwClient { return self.udpgw_bind_addr; } + /// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection. pub(crate) async fn heartbeat_task(&self) { loop { sleep(self.keepalive_time).await; @@ -225,28 +271,35 @@ impl UdpGwClient { self.release_server_connection(stream).await; continue; } - log::debug!("{:?}:{} send keepalive", stream.local_addr(), stream.id()); - if let Err(e) = stream.write_all(&self.keepalive_packet).await { - let _ = stream.close().await; - log::warn!("{:?}:{} Heartbeat failed: {}", stream.local_addr(), stream.id(), e); + + let Some(mut stream_reader) = stream.get_reader() else { + continue; + }; + + let Some(mut stream_writer) = stream.get_writer() else { + continue; + }; + log::debug!("{:?}:{} send keepalive", stream_writer.inner.local_addr(), stream.id()); + if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await { + log::warn!("{:?}:{} Heartbeat failed: {}", stream_writer.inner.local_addr(), stream.id(), e); } else { - stream.last_activity = std::time::Instant::now(); - match UdpGwClient::recv_udpgw_packet(self.udp_mtu, &mut stream).await { + match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { Ok(UdpGwResponse::KeepAlive) => { - self.release_server_connection(stream).await; - continue; - } - //shoud not receive other - _ => { - continue; + stream.last_activity = std::time::Instant::now(); + self.release_server_connection_with_stream(stream, stream_reader, stream_writer) + .await; } + //shoud not receive other type + _ => {} } } } } } - pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result { + /// Parses the UDP response data. + pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result { + let data = &stream.recv_buf; if data_len < mem::size_of::() { return Err("Invalid udpgw data".into()); } @@ -259,7 +312,6 @@ impl UdpGwClient { let flags = header.flags; let conid = header.conid; - // parse address let ip_data = &data[mem::size_of::()..]; let mut data_len = data_len - mem::size_of::(); @@ -267,7 +319,7 @@ impl UdpGwClient { return Ok(UdpGwResponse::Error); } - if flags & UDPGW_FLAG_ERR != 0 { + if flags & UDPGW_FLAG_KEEPALIVE != 0 { return Ok(UdpGwResponse::KeepAlive); } @@ -281,7 +333,7 @@ impl UdpGwClient { addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]), }; data_len -= mem::size_of::(); - // check payload length + if data_len > udp_mtu as usize { return Err("too much data".into()); } @@ -302,7 +354,6 @@ impl UdpGwClient { }; data_len -= mem::size_of::(); - // check payload length if data_len > udp_mtu as usize { return Err("too much data".into()); } @@ -317,7 +368,7 @@ impl UdpGwClient { pub(crate) async fn recv_udp_packet( udp_stack: &mut IpStackUdpStream, - stream: &mut UdpGwClientStream, + stream: &mut UdpGwClientStreamWriter, ) -> std::result::Result { return udp_stack.read(&mut stream.tmp_buf).await; } @@ -329,22 +380,35 @@ impl UdpGwClient { return udp_stack.write_all(&packet.udpdata).await; } - pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, stream: &mut UdpGwClientStream) -> Result { - stream.recv_buf.resize(2, 0); + /// Receives a UDP gateway packet. + /// + /// This function is responsible for receiving packets from the UDP gateway + /// + /// # Arguments + /// - `udp_mtu`: The maximum transmission unit size for UDP packets. + /// - `udp_timeout`: The timeout in seconds for receiving UDP packets. + /// - `stream`: A mutable reference to the UDP gateway client stream reader. + /// + /// # Returns + /// - `Result`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs. + pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut UdpGwClientStreamReader) -> Result { let result; - match tokio::time::timeout(tokio::time::Duration::from_secs(10), stream.inner.read(&mut stream.recv_buf)).await { + match tokio::time::timeout( + tokio::time::Duration::from_secs(udp_timeout + 2), + stream.inner.read(&mut stream.recv_buf[..2]), + ) + .await + { Ok(ret) => { result = ret; } Err(_e) => { - let _ = stream.close().await; - return Err(format!("{:?} wait tcp data timeout", stream.local_addr()).into()); + return Err(format!("wait tcp data timeout").into()); } }; match result { Ok(0) => { - let _ = stream.close().await; - return Err(format!("{:?} tcp connection closed", stream.local_addr()).into()); + return Err(format!("tcp connection closed").into()); } Ok(n) => { if n < std::mem::size_of::() { @@ -354,41 +418,53 @@ impl UdpGwClient { if packet_len > udp_mtu { return Err("packet too long".into()); } - stream.recv_buf.resize(udp_mtu as usize, 0); let mut left_len: usize = packet_len as usize; let mut recv_len = 0; while left_len > 0 { if let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await { if len == 0 { - let _ = stream.close().await; - return Err(format!("{:?} tcp connection closed", stream.local_addr()).into()); + return Err("tcp connection closed".into()); } recv_len += len; left_len -= len; } else { - let _ = stream.close().await; - return Err(format!("{:?} tcp connection closed", stream.local_addr()).into()); + return Err("tcp connection closed".into()); } } - stream.last_activity = std::time::Instant::now(); - return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, &stream.recv_buf); + return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, stream); } Err(_) => { - let _ = stream.close().await; - return Err(format!("{:?} tcp read error", stream.local_addr()).into()); + return Err("tcp read error".into()); } } } + /// Sends a UDP gateway packet. + /// + /// This function constructs and sends a UDP gateway packet based on the IPv6 enabled status, data length, + /// remote address, domain (if any), connection ID, and the UDP gateway client writer stream. + /// + /// # Arguments + /// + /// * `ipv6_enabled` - Whether IPv6 is enabled + /// * `len` - Length of the data packet + /// * `remote_addr` - Remote address + /// * `domain` - Target domain (optional) + /// * `conid` - Connection ID + /// * `stream` - UDP gateway client writer stream + /// + /// # Returns + /// + /// Returns `Ok(())` if the packet is sent successfully, otherwise returns an error. pub(crate) async fn send_udpgw_packet( ipv6_enabled: bool, len: usize, remote_addr: SocketAddr, domain: Option<&String>, - stream: &mut UdpGwClientStream, + conid: u16, + stream: &mut UdpGwClientStreamWriter, ) -> Result<()> { stream.send_buf.clear(); - let conid = stream.newid(); let data = &stream.tmp_buf; let mut pack_len = std::mem::size_of::() + len; let packet = &mut stream.send_buf; @@ -442,8 +518,6 @@ impl UdpGwClient { stream.inner.write_all(&packet).await?; - stream.last_activity = std::time::Instant::now(); - Ok(()) } }