From 0aad0d17098af91173dc7a8472de05601c6531d0 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Wed, 30 Oct 2024 19:00:28 +0800 Subject: [PATCH] refactor udpgw --- src/bin/udpgw_server.rs | 8 +- src/lib.rs | 65 ++++++++------- src/udpgw.rs | 174 ++++++++++++++++++++++++---------------- 3 files changed, 149 insertions(+), 98 deletions(-) diff --git a/src/bin/udpgw_server.rs b/src/bin/udpgw_server.rs index ac0a052..a6cb365 100644 --- a/src/bin/udpgw_server.rs +++ b/src/bin/udpgw_server.rs @@ -6,10 +6,10 @@ use tokio::{ tcp::{ReadHalf, WriteHalf}, UdpSocket, }, - sync::mpsc::{self, Receiver, Sender}, + sync::mpsc::{Receiver, Sender}, }; use tun2proxy::{ - udpgw::{Packet, UDPGW_FLAG_KEEPALIVE}, + udpgw::{Packet, UdpFlag}, ArgVerbosity, BoxError, Error, Result, }; @@ -134,7 +134,7 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender, mut client let flags = packet.header.flags; let conn_id = packet.header.conn_id; - if flags & UDPGW_FLAG_KEEPALIVE != 0 { + if flags & UdpFlag::KEEPALIVE == UdpFlag::KEEPALIVE { log::trace!("client {} send keepalive", client.addr); // 2. if keepalive packet, do nothing, send keepalive response to client send_keepalive_response(tx.clone(), conn_id).await; @@ -227,7 +227,7 @@ pub async fn run(args: UdpGwArgs, shutdown_token: tokio_util::sync::Cancellation log::info!("client {} connected", addr); let params = args.clone(); tokio::spawn(async move { - let (tx, rx) = mpsc::channel::(100); + let (tx, rx) = tokio::sync::mpsc::channel::(100); let (tcp_read_stream, tcp_write_stream) = tcp_stream.split(); let res = tokio::select! { v = process_client_udp_req(¶ms, tx, client, tcp_read_stream) => v, diff --git a/src/lib.rs b/src/lib.rs index 4d855b6..5d69388 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -251,7 +251,7 @@ where )); let client_keepalive = client.clone(); tokio::spawn(async move { - client_keepalive.heartbeat_task().await; + let _ = client_keepalive.heartbeat_task().await; }); client }); @@ -349,7 +349,7 @@ where SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), SocketAddr::V6(_) => SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)), }; - let tcpinfo = SessionInfo::new(tcp_src, udpgw.get_server_addr(), IpProtocol::Tcp); + let tcpinfo = SessionInfo::new(tcp_src, udpgw.get_udpgw_server_addr(), IpProtocol::Tcp); let proxy_handler = mgr.new_proxy_handler(tcpinfo, None, false).await?; let queue = socket_queue.clone(); tokio::spawn(async move { @@ -495,23 +495,33 @@ async fn handle_udp_gateway_session( let proxy_server_addr = { proxy_handler.lock().await.get_server_addr() }; let udp_mtu = udpgw_client.get_udp_mtu(); let udp_timeout = udpgw_client.get_udp_timeout(); - let mut stream = match udpgw_client.get_server_connection().await { - Some(server) => server, - None => { - if udpgw_client.is_full() { - return Err("max udpgw connection limit reached".into()); + + let mut stream = loop { + match udpgw_client.pop_server_connection_from_queue().await { + Some(stream) => { + if stream.is_closed() { + continue; + } else { + break stream; + } } - let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?; - if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await { - return Err(format!("udpgw connection error: {}", e).into()); + None => { + if udpgw_client.is_full() { + return Err("max udpgw connection limit reached".into()); + } + let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?; + if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await { + return Err(format!("udpgw connection error: {}", e).into()); + } + break UdpGwClientStream::new(tcp_server_stream); } - UdpGwClientStream::new(tcp_server_stream) } }; let tcp_local_addr = stream.local_addr().clone(); + let sn = stream.serial_number(); - log::info!("[UdpGw] Beginning {} -> {}", &tcp_local_addr, udp_dst); + log::info!("[UdpGw] Beginning stream {} {} -> {}", sn, &tcp_local_addr, udp_dst); let Some(mut reader) = stream.get_reader() else { return Err("get reader failed".into()); @@ -528,58 +538,59 @@ async fn handle_udp_gateway_session( len = udp_stack.read(&mut tmp_buf) => { let read_len = match len { Ok(0) => { - log::info!("[UdpGw] Ending {} <> {}", &tcp_local_addr, udp_dst); + log::info!("[UdpGw] Ending stream {} {} <> {}", sn, &tcp_local_addr, udp_dst); break; } Ok(n) => n, Err(e) => { - log::info!("[UdpGw] Ending {} <> {} with recv_udp_packet {}", &tcp_local_addr, udp_dst, e); + log::info!("[UdpGw] Ending stream {} {} <> {} with recv_udp_packet {}", sn, &tcp_local_addr, udp_dst, e); break; } }; crate::traffic_status::traffic_status_update(read_len, 0)?; - let new_id = stream.new_id(); + let new_id = stream.new_packet_id(); if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], udp_dst, new_id, &mut writer).await { - log::info!("[UdpGw] Ending {} <> {} with send_udpgw_packet {}", &tcp_local_addr, udp_dst, e); + log::info!("[UdpGw] Ending stream {} {} <> {} with send_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e); break; } - log::debug!("[UdpGw] {} -> {} send len {}", &tcp_local_addr, udp_dst, read_len); + log::debug!("[UdpGw] stream {} {} -> {} send len {}", sn, &tcp_local_addr, udp_dst, read_len); stream.update_activity(); } ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut reader) => { match ret { + Err(e) => { + log::warn!("[UdpGw] Ending stream {} {} <> {} with recv_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e); + stream.close(); + break; + } Ok(packet) => match packet { //should not received keepalive UdpGwResponse::KeepAlive => { - log::error!("[UdpGw] Ending {} <> {} with recv keepalive", &tcp_local_addr, udp_dst); + log::error!("[UdpGw] Ending stream {} {} <> {} with recv keepalive", sn, &tcp_local_addr, udp_dst); stream.close(); break; } //server udp may be timeout,can continue to receive udp data? UdpGwResponse::Error => { - log::info!("[UdpGw] Ending {} <> {} with recv udp error", &tcp_local_addr, udp_dst); + log::info!("[UdpGw] Ending stream {} {} <> {} with recv udp error", sn, &tcp_local_addr, udp_dst); stream.update_activity(); continue; } UdpGwResponse::TcpClose => { - log::error!("[UdpGw] Ending {} <> {} with tcp closed", &tcp_local_addr, udp_dst); + log::error!("[UdpGw] Ending stream {} {} <> {} with tcp closed", sn, &tcp_local_addr, udp_dst); stream.close(); break; } UdpGwResponse::Data(data) => { use socks5_impl::protocol::StreamOperation; let len = data.len(); - log::debug!("[UdpGw] {} <- {} receive len {}", &tcp_local_addr, udp_dst, len); + log::debug!("[UdpGw] stream {} {} <- {} receive len {}", sn, &tcp_local_addr, udp_dst, len); if let Err(e) = udp_stack.write_all(&data.data).await { - log::error!("[UdpGw] Ending {} <> {} with send_udp_packet {}", &tcp_local_addr, udp_dst, e); + log::error!("[UdpGw] Ending stream {} {} <> {} with send_udp_packet {}", sn, &tcp_local_addr, udp_dst, e); break; } crate::traffic_status::traffic_status_update(0, len)?; } - }, - Err(e) => { - log::warn!("[UdpGw] Ending {} <> {} with recv_udpgw_packet {}", &tcp_local_addr, udp_dst, e); - break; } } stream.update_activity(); @@ -588,7 +599,7 @@ async fn handle_udp_gateway_session( } if !stream.is_closed() { - udpgw_client.release_server_connection_full(stream, reader, writer).await; + udpgw_client.store_server_connection_full(stream, reader, writer).await; } Ok(()) diff --git a/src/udpgw.rs b/src/udpgw.rs index 4d7f018..295be25 100644 --- a/src/udpgw.rs +++ b/src/udpgw.rs @@ -15,9 +15,42 @@ pub(crate) const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::(); pub(crate) const UDPGW_MAX_CONNECTIONS: u16 = 100; pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10); -pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01; -pub const UDPGW_FLAG_ERR: u8 = 0x20; -pub const UDPGW_FLAG_DATA: u8 = 0x02; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct UdpFlag(pub u8); + +impl UdpFlag { + pub const ZERO: UdpFlag = UdpFlag(0x00); + pub const KEEPALIVE: UdpFlag = UdpFlag(0x01); + pub const ERR: UdpFlag = UdpFlag(0x20); + pub const DATA: UdpFlag = UdpFlag(0x02); +} + +impl std::fmt::Display for UdpFlag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let flag = match self.0 { + 0x00 => "ZERO", + 0x01 => "KEEPALIVE", + 0x20 => "ERR", + 0x02 => "DATA", + n => return write!(f, "Unknown UdpFlag(0x{:02X})", n), + }; + write!(f, "UdpFlag({})", flag) + } +} + +impl std::ops::BitAnd for UdpFlag { + type Output = Self; + fn bitand(self, rhs: Self) -> Self::Output { + UdpFlag(self.0 & rhs.0) + } +} + +impl std::ops::BitOr for UdpFlag { + type Output = Self; + fn bitor(self, rhs: Self) -> Self::Output { + UdpFlag(self.0 | rhs.0) + } +} static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); @@ -98,7 +131,7 @@ impl TryFrom<&[u8]> for Packet { return Err(std::io::ErrorKind::InvalidData.into()); } let header = UdpgwHeader::retrieve_from_stream(&mut iter)?; - let address = if header.flags & UDPGW_FLAG_DATA != 0 { + let address = if header.flags & UdpFlag::DATA != UdpFlag::ZERO { Some(Address::retrieve_from_stream(&mut iter)?) } else { None @@ -114,11 +147,11 @@ impl Packet { } pub fn build_keepalive_packet(conn_id: u16) -> Self { - Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), None, &[]) + Packet::new(UdpgwHeader::new(UdpFlag::KEEPALIVE, conn_id), None, &[]) } pub fn build_error_packet(conn_id: u16) -> Self { - Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, conn_id), None, &[]) + Packet::new(UdpgwHeader::new(UdpFlag::ERR, conn_id), None, &[]) } pub fn build_packet_from_address(conn_id: u16, remote_addr: &Address, data: &[u8]) -> std::io::Result { @@ -132,7 +165,7 @@ impl Packet { pub fn build_ip_packet(conn_id: u16, remote_addr: SocketAddr, data: &[u8]) -> Self { let addr: Address = remote_addr.into(); - Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data) + Packet::new(UdpgwHeader::new(UdpFlag::DATA, conn_id), Some(addr), data) } pub fn build_domain_packet(conn_id: u16, port: u16, domain: &str, data: &[u8]) -> std::io::Result { @@ -140,7 +173,7 @@ impl Packet { return Err(std::io::ErrorKind::InvalidInput.into()); } let addr = Address::from((domain, port)); - Ok(Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data)) + Ok(Packet::new(UdpgwHeader::new(UdpFlag::DATA, conn_id), Some(addr), data)) } } @@ -154,7 +187,7 @@ impl StreamOperation for Packet { stream.read_exact(&mut buf)?; let length = u16::from_be_bytes(buf) as usize; let header = UdpgwHeader::retrieve_from_stream(stream)?; - let address = if header.flags & UDPGW_FLAG_DATA != 0 { + let address = if header.flags & UdpFlag::DATA == UdpFlag::DATA { Some(Address::retrieve_from_stream(stream)?) } else { None @@ -194,7 +227,7 @@ impl AsyncStreamOperation for Packet { r.read_exact(&mut buf).await?; let length = u16::from_be_bytes(buf) as usize; let header = UdpgwHeader::retrieve_from_async_stream(r).await?; - let address = if header.flags & UDPGW_FLAG_DATA != 0 { + let address = if header.flags & UdpFlag::DATA == UdpFlag::DATA { Some(Address::retrieve_from_async_stream(r).await?) } else { None @@ -211,14 +244,14 @@ impl AsyncStreamOperation for Packet { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct UdpgwHeader { - pub flags: u8, + pub flags: UdpFlag, pub conn_id: u16, } impl std::fmt::Display for UdpgwHeader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let id = self.conn_id; - write!(f, "flags: 0x{:02x}, conn_id: {}", self.flags, id) + write!(f, "flags: {}, conn_id: {}", self.flags, id) } } @@ -257,7 +290,7 @@ impl AsyncStreamOperation for UdpgwHeader { } impl UdpgwHeader { - pub fn new(flags: u8, conn_id: u16) -> Self { + pub fn new(flags: UdpFlag, conn_id: u16) -> Self { UdpgwHeader { flags, conn_id } } @@ -274,14 +307,14 @@ impl TryFrom<&[u8]> for UdpgwHeader { return Err(std::io::ErrorKind::InvalidData.into()); } let conn_id = u16::from_be_bytes([value[1], value[2]]); - Ok(UdpgwHeader { flags: value[0], conn_id }) + Ok(UdpgwHeader::new(UdpFlag(value[0]), conn_id)) } } impl From<&UdpgwHeader> for Vec { fn from(header: &UdpgwHeader) -> Vec { let mut bytes = vec![0; header.len()]; - bytes[0] = header.flags; + bytes[0] = header.flags.0; bytes[1..3].copy_from_slice(&header.conn_id.to_be_bytes()); bytes } @@ -296,14 +329,17 @@ pub(crate) enum UdpGwResponse { Data(Packet), } +static SERIAL_NUMBER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1); + #[derive(Debug)] pub(crate) struct UdpGwClientStream { - local_addr: String, + local_addr: SocketAddr, writer: Option, reader: Option, conn_id: u16, closed: bool, last_activity: std::time::Instant, + serial_number: u64, } impl Drop for UdpGwClientStream { @@ -333,34 +369,33 @@ impl UdpGwClientStream { self.writer.take() } - pub fn local_addr(&self) -> &String { - &self.local_addr + pub fn local_addr(&self) -> SocketAddr { + self.local_addr } pub fn update_activity(&mut self) { self.last_activity = std::time::Instant::now(); } - pub fn is_closed(&mut self) -> bool { + pub fn is_closed(&self) -> bool { self.closed } - pub fn id(&mut self) -> u16 { - self.conn_id + pub fn serial_number(&self) -> u64 { + self.serial_number } - pub fn new_id(&mut self) -> u16 { + pub fn new_packet_id(&mut self) -> u16 { self.conn_id += 1; self.conn_id } pub fn new(tcp_server_stream: TcpStream) -> Self { let default = "0.0.0.0:0".parse::().unwrap(); - let local_addr = tcp_server_stream.local_addr().unwrap_or(default).to_string(); - let (rx, tx) = tcp_server_stream.into_split(); - let writer = tx; - let reader = rx; + let local_addr = tcp_server_stream.local_addr().unwrap_or(default); + let (reader, writer) = tcp_server_stream.into_split(); TCP_COUNTER.fetch_add(1, Relaxed); + let serial_number = SERIAL_NUMBER.fetch_add(1, Relaxed); UdpGwClientStream { local_addr, reader: Some(reader), @@ -368,6 +403,7 @@ impl UdpGwClientStream { last_activity: std::time::Instant::now(), closed: false, conn_id: 0, + serial_number, } } } @@ -378,18 +414,18 @@ pub(crate) struct UdpGwClient { max_connections: u16, udp_timeout: u64, keepalive_time: Duration, - server_addr: SocketAddr, + udpgw_server: SocketAddr, server_connections: Mutex>, } impl UdpGwClient { - pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self { + pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, udpgw_server: SocketAddr) -> Self { let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize)); UdpGwClient { udp_mtu, max_connections, udp_timeout, - server_addr, + udpgw_server, keepalive_time, server_connections, } @@ -407,22 +443,17 @@ impl UdpGwClient { TCP_COUNTER.load(Relaxed) >= self.max_connections as u32 } - pub(crate) async fn get_server_connection(&self) -> Option { + pub(crate) async fn pop_server_connection_from_queue(&self) -> Option { self.server_connections.lock().await.pop_front() } - pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) { + pub(crate) async fn store_server_connection(&self, stream: UdpGwClientStream) { if self.server_connections.lock().await.len() < self.max_connections as usize { self.server_connections.lock().await.push_back(stream); } } - pub(crate) async fn release_server_connection_full( - &self, - mut stream: UdpGwClientStream, - reader: OwnedReadHalf, - writer: OwnedWriteHalf, - ) { + pub(crate) async fn store_server_connection_full(&self, mut stream: UdpGwClientStream, reader: OwnedReadHalf, writer: OwnedWriteHalf) { if self.server_connections.lock().await.len() < self.max_connections as usize { stream.set_reader(Some(reader)); stream.set_writer(Some(writer)); @@ -430,42 +461,51 @@ impl UdpGwClient { } } - pub(crate) fn get_server_addr(&self) -> SocketAddr { - self.server_addr + pub(crate) fn get_udpgw_server_addr(&self) -> SocketAddr { + self.udpgw_server } /// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection. - pub(crate) async fn heartbeat_task(&self) { + pub(crate) async fn heartbeat_task(&self) -> std::io::Result<()> { loop { sleep(self.keepalive_time).await; - if let Some(mut stream) = self.get_server_connection().await { - if stream.last_activity.elapsed() < self.keepalive_time { - self.release_server_connection(stream).await; - continue; - } + let Some(mut stream) = self.pop_server_connection_from_queue().await else { + continue; + }; - let Some(mut stream_reader) = stream.get_reader() else { - continue; - }; + if stream.is_closed() { + // This stream will be dropped + continue; + } - let Some(mut stream_writer) = stream.get_writer() else { - continue; - }; - let local_addr = stream_writer.local_addr(); - log::debug!("{:?}:{} send keepalive", local_addr, stream.id()); - let keepalive_packet: Vec = Packet::build_keepalive_packet(stream.id()).into(); - if let Err(e) = stream_writer.write_all(&keepalive_packet).await { - log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e); - continue; - } - match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { - Ok(UdpGwResponse::KeepAlive) => { - stream.update_activity(); - self.release_server_connection_full(stream, stream_reader, stream_writer).await; - } - Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v), - Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e), + if stream.last_activity.elapsed() < self.keepalive_time { + self.store_server_connection(stream).await; + continue; + } + + let Some(mut stream_reader) = stream.get_reader() else { + continue; + }; + + let Some(mut stream_writer) = stream.get_writer() else { + continue; + }; + let local_addr = stream_writer.local_addr()?; + let sn = stream.serial_number(); + log::trace!("stream {} {:?} send keepalive", sn, local_addr); + let keepalive_packet: Vec = Packet::build_keepalive_packet(stream.new_packet_id()).into(); + if let Err(e) = stream_writer.write_all(&keepalive_packet).await { + log::warn!("stream {} {:?} send keepalive failed: {}", sn, local_addr, e); + continue; + } + match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { + Ok(UdpGwResponse::KeepAlive) => { + stream.update_activity(); + self.store_server_connection_full(stream, stream_reader, stream_writer).await; + log::trace!("stream {} {:?} keepalive success", sn, local_addr); } + Ok(v) => log::warn!("stream {} {:?} keepalive unexpected response: {:?}", sn, local_addr, v), + Err(e) => log::warn!("stream {} {:?} keepalive no response, error \"{}\"", sn, local_addr, e), } } } @@ -474,10 +514,10 @@ impl UdpGwClient { pub(crate) fn parse_udp_response(udp_mtu: u16, data: &[u8]) -> Result { let packet = Packet::try_from(data)?; let flags = packet.header.flags; - if flags & UDPGW_FLAG_ERR != 0 { + if flags & UdpFlag::ERR == UdpFlag::ERR { return Ok(UdpGwResponse::Error); } - if flags & UDPGW_FLAG_KEEPALIVE != 0 { + if flags & UdpFlag::KEEPALIVE == UdpFlag::KEEPALIVE { return Ok(UdpGwResponse::KeepAlive); } if packet.data.len() > udp_mtu as usize {