From d4afc8f655bce545154b7484e55f664e2d718402 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Fri, 25 Oct 2024 01:01:46 +0800 Subject: [PATCH] refactor core logic --- src/bin/udpgw_server.rs | 317 +++++------------------ src/lib.rs | 38 +-- src/udpgw.rs | 545 +++++++++++++--------------------------- 3 files changed, 270 insertions(+), 630 deletions(-) diff --git a/src/bin/udpgw_server.rs b/src/bin/udpgw_server.rs index 38b36df..f0932c7 100644 --- a/src/bin/udpgw_server.rs +++ b/src/bin/udpgw_server.rs @@ -1,41 +1,30 @@ -use std::{ - net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs}, - sync::Arc, -}; +use socks5_impl::protocol::{AddressType, AsyncStreamOperation}; +use std::{net::SocketAddr, sync::Arc}; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, + io::AsyncWriteExt, net::{ tcp::{ReadHalf, WriteHalf}, UdpSocket, }, sync::mpsc::{self, Receiver, Sender}, }; -use tun2proxy::{udpgw::*, ArgVerbosity, Result}; +use tun2proxy::{ + udpgw::{Packet, UDPGW_FLAG_KEEPALIVE}, + ArgVerbosity, Result, +}; pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60); -#[derive(Debug, Clone)] -struct UdpRequest { - flags: u8, - server_addr: SocketAddr, - conn_id: u16, - data: Vec, -} - #[derive(Debug, Clone)] pub struct Client { addr: SocketAddr, - buf: Vec, last_activity: std::time::Instant, } impl Client { pub fn new(addr: SocketAddr) -> Self { - Self { - addr, - buf: vec![], - last_activity: std::time::Instant::now(), - } + let last_activity = std::time::Instant::now(); + Self { addr, last_activity } } } @@ -71,112 +60,28 @@ impl UdpGwArgs { } } -async fn send_error(tx: Sender>, con: &mut UdpRequest) { - let error_packet: Vec = Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, con.conn_id), vec![]).into(); +async fn send_error(tx: Sender, conn_id: u16) { + let error_packet = Packet::build_error_packet(conn_id); if let Err(e) = tx.send(error_packet).await { log::error!("send error response error {:?}", e); } } -async fn send_keepalive_response(tx: Sender>, conn_id: u16) { - let keepalive_packet: Vec = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), vec![]).into(); +async fn send_keepalive_response(tx: Sender, conn_id: u16) { + let keepalive_packet = Packet::build_keepalive_packet(conn_id); if let Err(e) = tx.send(keepalive_packet).await { log::error!("send keepalive response error {:?}", e); } } -pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> { - let header_len = UdpgwHeader::static_len(); - if data_len < header_len { - return Err("Invalid udpgw data".into()); - } - let header_bytes = &data[..header_len]; - let header = UdpgwHeader { - flags: header_bytes[0], - conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), +/// Send data field of packet from client to destination server and receive response, +/// then wrap response data to the packet's data field and send packet back to client. +async fn process_udp(client: SocketAddr, udp_mtu: u16, udp_timeout: u64, tx: Sender, mut packet: Packet) -> Result<()> { + let Some(dst_addr) = &packet.address else { + log::error!("client {} udp request address is None", client); + return Ok(()); }; - - let flags = header.flags; - let conn_id = header.conn_id; - - // keepalive - if flags & UDPGW_FLAG_KEEPALIVE != 0 { - return Ok((data, flags, conn_id, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); - } - - let ip_data = &data[header_len..]; - let mut data_len = data_len - header_len; - // port_len + min(ipv4/ipv6/(domain_len + 1)) - if data_len < UDPGW_LENGTH_FIELD_SIZE + 2 { - return Err("Invalid udpgw data".into()); - } - if flags & UDPGW_FLAG_DOMAIN != 0 { - let addr_port = u16::from_be_bytes([ip_data[0], ip_data[1]]); - data_len -= 2; - if let Some(end) = ip_data.iter().skip(2).position(|&x| x == 0) { - let domain_slice = &ip_data[2..end + 2]; - match std::str::from_utf8(domain_slice) { - Ok(domain) => { - let target_str = format!("{}:{}", domain, addr_port); - let target = target_str - .to_socket_addrs()? - .next() - .ok_or(format!("Invalid address {}", target_str))?; - if data_len < 2 + domain.len() { - return Err("Invalid udpgw data".into()); - } - data_len -= domain.len() + 1; - if data_len > udp_mtu as usize { - return Err("too much data".into()); - } - let udpdata = &ip_data[(2 + domain.len() + 1)..]; - Ok((udpdata, flags, conn_id, target)) - } - Err(_) => Err("Invalid UTF-8 sequence in domain".into()), - } - } else { - Err("missing domain name".into()) - } - } else if flags & UDPGW_FLAG_IPV6 != 0 { - let addr_ipv6_len = BinSocketAddr::static_len(true); - if data_len < addr_ipv6_len { - return Err("Ipv6 Invalid UDP data".into()); - } - let addr_ipv6 = BinSocketAddr::try_from(&ip_data[..addr_ipv6_len])?; - data_len -= addr_ipv6_len; - - if data_len > udp_mtu as usize { - return Err("too much data".into()); - } - return Ok(( - &ip_data[addr_ipv6_len..(data_len + addr_ipv6_len)], - flags, - conn_id, - addr_ipv6.into(), - )); - } else { - let addr_ipv4_len = BinSocketAddr::static_len(false); - if data_len < addr_ipv4_len { - return Err("Ipv4 Invalid UDP data".into()); - } - let addr_ipv4 = BinSocketAddr::try_from(&ip_data[..addr_ipv4_len])?; - data_len -= addr_ipv4_len; - - if data_len > udp_mtu as usize { - return Err("too much data".into()); - } - - return Ok(( - &ip_data[addr_ipv4_len..(data_len + addr_ipv4_len)], - flags, - conn_id, - addr_ipv4.into(), - )); - } -} - -async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, con: &mut UdpRequest) -> Result<()> { - let std_sock = if con.flags & UDPGW_FLAG_IPV6 != 0 { + let std_sock = if dst_addr.get_type() == AddressType::IPv6 { std::net::UdpSocket::bind("[::]:0")? } else { std::net::UdpSocket::bind("0.0.0.0:0")? @@ -185,163 +90,78 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, co #[cfg(unix)] nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?; let socket = UdpSocket::from_std(std_sock)?; - socket.send_to(&con.data, &con.server_addr).await?; - con.data.resize(2048, 0); - match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data)).await { - Ok(ret) => { - let (len, _addr) = ret?; - let mut packet = vec![]; - let mut pack_len = UdpgwHeader::static_len() + len; - match con.server_addr { - SocketAddr::V4(_) => { - let addr_ipv4 = BinSocketAddr::from(con.server_addr); - pack_len += addr_ipv4.len(); - packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); - packet.extend_from_slice(&[con.flags]); - packet.extend_from_slice(&con.conn_id.to_le_bytes()); - let addr_ipv4_bin: Vec = addr_ipv4.into(); - packet.extend_from_slice(&addr_ipv4_bin); - packet.extend_from_slice(&con.data[..len]); - } - SocketAddr::V6(_) => { - let addr_ipv6 = BinSocketAddr::from(con.server_addr); - pack_len += addr_ipv6.len(); - packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); - packet.extend_from_slice(&[con.flags]); - packet.extend_from_slice(&con.conn_id.to_le_bytes()); - let addr_ipv6_bin: Vec = addr_ipv6.into(); - packet.extend_from_slice(&addr_ipv6_bin); - packet.extend_from_slice(&con.data[..len]); - } - } - if let Err(e) = tx.send(packet).await { - log::error!("client {} send udp response {}", addr, e); - } - } - Err(e) => { - log::warn!("client {} udp recv_from {}", addr, e); - } - } + use std::net::ToSocketAddrs; + let Some(dst_addr) = dst_addr.to_socket_addrs()?.next() else { + log::error!("client {} udp request address to_socket_addrs", client); + return Ok(()); + }; + // 1. send udp data to destination server + socket.send_to(&packet.data, &dst_addr).await?; + packet.data.resize(udp_mtu as usize, 0); + // 2. receive response from destination server + let (len, _addr) = tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut packet.data)) + .await + .map_err(std::io::Error::from)??; + packet.data.truncate(len); + // 3. send response back to client + use std::io::{Error, ErrorKind::BrokenPipe}; + tx.send(packet).await.map_err(|e| Error::new(BrokenPipe, e))?; Ok(()) } -async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> { - let mut client = client; - let mut buf = vec![0; args.udp_mtu as usize]; - let mut len_buf = [0; UDPGW_LENGTH_FIELD_SIZE]; - let udp_mtu = args.udp_mtu; +async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender, mut client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> { let udp_timeout = args.udp_timeout; + let udp_mtu = args.udp_mtu; - 'out: loop { - /* - use socks5_impl::protocol::AsyncStreamOperation; + loop { + // 1. read udpgw packet from client let res = tokio::time::timeout(tokio::time::Duration::from_secs(2), Packet::retrieve_from_async_stream(&mut reader)).await; let packet = match res { Ok(Ok(packet)) => packet, Ok(Err(e)) => { - log::error!("client {} retrieve_from_async_stream {}", client.addr, e); + log::error!("client {} retrieve_from_async_stream \"{}\"", client.addr, e); break; } - Err(_) => { + Err(e) => { if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT { - log::debug!("client {} last_activity elapsed", client.addr); + log::debug!("client {} last_activity elapsed {e}", client.addr); break; } continue; } }; - client.buf.clear(); - client.buf.extend_from_slice(&packet.data); - */ - - //* - let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), reader.read(&mut len_buf)).await { - Ok(ret) => ret, - Err(_e) => { - if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT { - log::debug!("client {} last_activity elapsed", client.addr); - break; - } - continue; - } - }; - let n = result?; - if n == 0 { - // Connection closed - break; - } - if n < UDPGW_LENGTH_FIELD_SIZE { - log::error!("client {} received Packet Length field error", client.addr); - break; - } - let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]); - if packet_len > udp_mtu { - log::error!("client {} received packet too long", client.addr); - break; - } - log::trace!("client {} recvied packet len {}", client.addr, packet_len); - client.buf.clear(); - let mut left_len: usize = packet_len as usize; - while left_len > 0 { - let len = reader.read(&mut buf[..left_len]).await?; - if len == 0 { - break 'out; - } - client.buf.extend_from_slice(&buf[..len]); - left_len -= len; - } - // */ client.last_activity = std::time::Instant::now(); - let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf); - if let Ok((udpdata, flags, conn_id, reqaddr)) = ret { - if flags & UDPGW_FLAG_KEEPALIVE != 0 { - log::debug!("client {} send keepalive", client.addr); - send_keepalive_response(tx.clone(), conn_id).await; - continue; - } - log::debug!( - "client {} received udp data,flags:{},conn_id:{},addr:{:?},data len:{}", - client.addr, - flags, - conn_id, - reqaddr, - udpdata.len() - ); - let mut req = UdpRequest { - server_addr: reqaddr, - conn_id, - flags, - data: udpdata.to_vec(), - }; - let tx1 = tx.clone(); - let tx2 = tx.clone(); - tokio::spawn(async move { - if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await { - send_error(tx2, &mut req).await; - log::error!("client {} process_udp {}", client.addr, e); - } - }); - } else { - log::error!("client {} parse_udp_data {:?}", client.addr, ret.err()); + + let flags = packet.header.flags; + let conn_id = packet.header.conn_id; + if flags & UDPGW_FLAG_KEEPALIVE != 0 { + log::trace!("client {} send keepalive", client.addr); + // 2. if keepalive packet, do nothing, send keepalive response to client + send_keepalive_response(tx.clone(), conn_id).await; + continue; } + log::trace!("client {} received udp data {}", client.addr, packet); + + // 3. process client udpgw packet in a new task + let tx = tx.clone(); + tokio::spawn(async move { + if let Err(e) = process_udp(client.addr, udp_mtu, udp_timeout, tx.clone(), packet).await { + send_error(tx, conn_id).await; + log::error!("client {} process udp function {}", client.addr, e); + } + }); } Ok(()) } -async fn write_to_client(addr: SocketAddr, mut writer: WriteHalf<'_>, mut rx: Receiver>) -> std::io::Result<()> { +async fn write_to_client(addr: SocketAddr, mut writer: WriteHalf<'_>, mut rx: Receiver) -> std::io::Result<()> { loop { - let Some(udp_response) = rx.recv().await else { - log::trace!("client {} channel closed", addr); - break; - }; - if udp_response.is_empty() { - log::trace!("client {} channel recv 0", addr); - break; - } - log::trace!("send response to client {} len {}", addr, udp_response.len()); - let _r = writer.write(&udp_response).await?; + use std::io::{Error, ErrorKind::BrokenPipe}; + let packet = rx.recv().await.ok_or(Error::new(BrokenPipe, "recv error"))?; + log::trace!("send response to client {} with {}", addr, packet); + let data: Vec = packet.into(); + let _r = writer.write(&data).await?; } - Ok(()) } #[tokio::main] @@ -354,6 +174,7 @@ async fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init(); + log::info!("{} {} starting...", module_path!(), env!("CARGO_PKG_VERSION")); log::info!("UDP Gateway Server running at {}", args.listen_addr); #[cfg(unix)] @@ -377,7 +198,7 @@ async fn main() -> Result<()> { log::info!("client {} connected", addr); let params = args.clone(); tokio::spawn(async move { - let (tx, rx) = mpsc::channel::>(100); + let (tx, rx) = mpsc::channel::(100); let (tcp_read_stream, tcp_write_stream) = tcp_stream.split(); let res = tokio::select! { v = process_client_udp_req(¶ms, tx, client, tcp_read_stream) => v, diff --git a/src/lib.rs b/src/lib.rs index 36c5298..242ba15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ pub use tokio_util::sync::CancellationToken; use tproxy_config::is_private_ip; use udp_stream::UdpStream; #[cfg(feature = "udpgw")] -use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME}; +use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS}; pub use { args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType}, @@ -244,7 +244,7 @@ where log::info!("UDPGW enabled"); let client = Arc::new(UdpGwClient::new( mtu, - args.udpgw_max_connections.unwrap_or(100), + args.udpgw_max_connections.unwrap_or(UDPGW_MAX_CONNECTIONS), UDPGW_KEEPALIVE_TIME, args.udp_timeout, *addr, @@ -502,7 +502,7 @@ async fn handle_udp_gateway_session( if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await { return Err(format!("udpgw connection error: {}", e).into()); } - UdpGwClientStream::new(udp_mtu, tcp_server_stream) + UdpGwClientStream::new(tcp_server_stream) } }; @@ -521,26 +521,29 @@ async fn handle_udp_gateway_session( return Err("get writer failed".into()); }; + let mut tmp_buf = vec![0; udp_mtu.into()]; + loop { tokio::select! { - len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut writer) => { - let read_len; - match len { - Ok(n) => { - if n == 0 { - log::info!("[UdpGw] Ending {} <> {}", &tcp_local_addr, udp_dst); - break; - } - read_len = n; - crate::traffic_status::traffic_status_update(n, 0)?; + len = udp_stack.read(&mut tmp_buf) => { + let read_len = match len { + Ok(0) => { + log::info!("[UdpGw] Ending {} <> {}", &tcp_local_addr, udp_dst); + break; } + Ok(n) => n, Err(e) => { log::info!("[UdpGw] Ending {} <> {} with recv_udp_packet {}", &tcp_local_addr, udp_dst, e); break; } - } + }; + crate::traffic_status::traffic_status_update(read_len, 0)?; let new_id = stream.new_id(); - if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_dst, domain_name.as_ref(), new_id, &mut writer).await { + let remote_addr = match domain_name { + Some(ref d) => socks5_impl::protocol::Address::from((d.clone(), udp_dst.port())), + None => udp_dst.into(), + }; + if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], &remote_addr, new_id, &mut writer).await { log::info!("[UdpGw] Ending {} <> {} with send_udpgw_packet {}", &tcp_local_addr, udp_dst, e); break; } @@ -568,9 +571,10 @@ async fn handle_udp_gateway_session( break; } UdpGwResponse::Data(data) => { + use socks5_impl::protocol::StreamOperation; let len = data.len(); log::debug!("[UdpGw] {} <- {} receive len {}", &tcp_local_addr, udp_dst, len); - if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await { + if let Err(e) = udp_stack.write_all(&data.data).await { log::error!("[UdpGw] Ending {} <> {} with send_udp_packet {}", &tcp_local_addr, udp_dst, e); break; } @@ -588,7 +592,7 @@ async fn handle_udp_gateway_session( } if !stream.is_closed() { - udpgw_client.release_server_connection_with_stream(stream, reader, writer).await; + udpgw_client.release_server_connection_full(stream, reader, writer).await; } Ok(()) diff --git a/src/udpgw.rs b/src/udpgw.rs index 4523c6c..04a2787 100644 --- a/src/udpgw.rs +++ b/src/udpgw.rs @@ -1,12 +1,6 @@ use crate::error::Result; -use ipstack::stream::IpStackUdpStream; -use socks5_impl::protocol::{AsyncStreamOperation, BufMut, StreamOperation}; -use std::{ - collections::VecDeque, - hash::Hash, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, - sync::atomic::Ordering::Relaxed, -}; +use socks5_impl::protocol::{Address, AsyncStreamOperation, BufMut, StreamOperation}; +use std::{collections::VecDeque, hash::Hash, net::SocketAddr, sync::atomic::Ordering::Relaxed}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{ @@ -17,25 +11,65 @@ use tokio::{ time::{sleep, Duration}, }; -pub const UDPGW_MAX_CONNECTIONS: usize = 100; -pub const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10); -pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01; -pub const UDPGW_FLAG_IPV4: u8 = 0x00; -pub const UDPGW_FLAG_IPV6: u8 = 0x08; -pub const UDPGW_FLAG_DOMAIN: u8 = 0x10; -pub const UDPGW_FLAG_ERR: u8 = 0x20; +pub(crate) const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::(); +pub(crate) const UDPGW_MAX_CONNECTIONS: u16 = 100; +pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10); -pub const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::(); +pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01; +pub const UDPGW_FLAG_ERR: u8 = 0x20; +pub const UDPGW_FLAG_DATA: u8 = 0x02; static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); +/// UDP Gateway Packet Format +/// +/// The format is referenced from SOCKS5 packet format, with additional flags and connection ID fields. +/// +/// `LEN`: This field is indicated the length of the packet, not including the length field itself. +/// +/// `FLAGS`: This field is used to indicate the packet type. The flags are defined as follows: +/// - `0x01`: Keepalive packet without address and data +/// - `0x20`: Error packet without address and data +/// - `0x02`: Data packet with address and data +/// +/// `CONN_ID`: This field is used to indicate the unique connection ID for the packet. +/// +/// `ATYP` & `DST.ADDR` & `DST.PORT`: This fields are used to indicate the remote address and port. +/// It can be either an IPv4 address, an IPv6 address, or a domain name, depending on the `ATYP` field. +/// The address format directly uses the address format of the [SOCKS5](https://datatracker.ietf.org/doc/html/rfc1928#section-4) protocol. +/// - `ATYP`: Address Type, 1 byte, indicating the type of address ( 0x01-IPv4, 0x04-IPv6, or 0x03-domain name ) +/// - `DST.ADDR`: Destination Address. If `ATYP` is 0x01 or 0x04, it is 4 or 16 bytes of IP address; +/// If `ATYP` is 0x03, it is a domain name, `DST.ADDR` is a variable length field, +/// it begins with a 1-byte length field and then the domain name without null-termination, +/// since the length field is 1 byte, the maximum length of the domain name is 255 bytes. +/// - `DST.PORT`: Destination Port, 2 bytes, the port number of the destination address. +/// +/// `DATA`: The data field, a variable length field, the length is determined by the `LEN` field. +/// +/// All the digits fields are in big-endian byte order. +/// +/// ```plain +/// +-----+ +-------+---------+ +------+----------+----------+ +----------+ +/// | LEN | | FLAGS | CONN_ID | | ATYP | DST.ADDR | DST.PORT | | DATA | +/// +-----+ +-------+---------+ +------+----------+----------+ +----------+ +/// | 2 | | 1 | 2 | | 1 | Variable | 2 | | Variable | +/// +-----+ +-------+---------+ +------+----------+----------+ +----------+ +/// ``` #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Packet { - pub length: u16, pub header: UdpgwHeader, + pub address: Option
, pub data: Vec, } +impl std::fmt::Display for Packet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let addr = self.address.as_ref().map_or("None".to_string(), |addr| addr.to_string()); + let len = self.data.len(); + write!(f, "Packet {{ {}, address: {}, payload length: {} }}", self.header, addr, len) + } +} + impl From for Vec { fn from(packet: Packet) -> Vec { (&packet).into() @@ -57,20 +91,56 @@ impl TryFrom<&[u8]> for Packet { if value.len() < UDPGW_LENGTH_FIELD_SIZE { return Err(std::io::ErrorKind::InvalidData.into()); } - let length = u16::from_le_bytes([value[0], value[1]]); + let mut iter = std::io::Cursor::new(value); + use tokio_util::bytes::Buf; + let length = iter.get_u16(); if value.len() < length as usize + UDPGW_LENGTH_FIELD_SIZE { return Err(std::io::ErrorKind::InvalidData.into()); } - let header = UdpgwHeader::try_from(&value[UDPGW_LENGTH_FIELD_SIZE..])?; - let data = value[UDPGW_LENGTH_FIELD_SIZE + header.len()..].to_vec(); - Ok(Packet::new(header, data)) + let header = UdpgwHeader::retrieve_from_stream(&mut iter)?; + let address = if header.flags & UDPGW_FLAG_DATA != 0 { + Some(Address::retrieve_from_stream(&mut iter)?) + } else { + None + }; + Ok(Packet::new(header, address, iter.chunk())) } } impl Packet { - pub fn new(header: UdpgwHeader, data: Vec) -> Self { - let length = (header.len() + data.len()) as u16; - Packet { length, header, data } + pub fn new(header: UdpgwHeader, address: Option
, data: &[u8]) -> Self { + let data = data.to_vec(); + Packet { header, address, data } + } + + pub fn build_keepalive_packet(conn_id: u16) -> Self { + Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), None, &[]) + } + + pub fn build_error_packet(conn_id: u16) -> Self { + Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, conn_id), None, &[]) + } + + pub fn build_packet_from_address(conn_id: u16, remote_addr: &Address, data: &[u8]) -> std::io::Result { + use socks5_impl::protocol::Address::{DomainAddress, SocketAddress}; + let packet = match remote_addr { + SocketAddress(addr) => Packet::build_ip_packet(conn_id, *addr, data), + DomainAddress(domain, port) => Packet::build_domain_packet(conn_id, *port, domain, data)?, + }; + Ok(packet) + } + + pub fn build_ip_packet(conn_id: u16, remote_addr: SocketAddr, data: &[u8]) -> Self { + let addr: Address = remote_addr.into(); + Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data) + } + + pub fn build_domain_packet(conn_id: u16, port: u16, domain: &str, data: &[u8]) -> std::io::Result { + if domain.len() > 255 { + return Err(std::io::ErrorKind::InvalidInput.into()); + } + let addr = Address::from((domain, port)); + Ok(Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data)) } } @@ -82,23 +152,30 @@ impl StreamOperation for Packet { { let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE]; stream.read_exact(&mut buf)?; - let length = u16::from_le_bytes(buf); - let mut buf = [0; UdpgwHeader::static_len()]; - stream.read_exact(&mut buf)?; - let header = UdpgwHeader::try_from(&buf[..])?; - let mut data = vec![0; length as usize - header.len()]; + let length = u16::from_be_bytes(buf) as usize; + let header = UdpgwHeader::retrieve_from_stream(stream)?; + let address = if header.flags & UDPGW_FLAG_DATA != 0 { + Some(Address::retrieve_from_stream(stream)?) + } else { + None + }; + let mut data = vec![0; length - header.len() - address.as_ref().map_or(0, |addr| addr.len())]; stream.read_exact(&mut data)?; - Ok(Packet::new(header, data)) + Ok(Packet::new(header, address, &data)) } fn write_to_buf(&self, buf: &mut B) { - buf.put_u16_le(self.length); + let len = self.len() - UDPGW_LENGTH_FIELD_SIZE; + buf.put_u16(len as u16); self.header.write_to_buf(buf); + if let Some(addr) = &self.address { + addr.write_to_buf(buf); + } buf.put_slice(&self.data); } fn len(&self) -> usize { - UDPGW_LENGTH_FIELD_SIZE + self.header.len() + self.data.len() + UDPGW_LENGTH_FIELD_SIZE + self.header.len() + self.address.as_ref().map_or(0, |addr| addr.len()) + self.data.len() } } @@ -111,22 +188,32 @@ impl AsyncStreamOperation for Packet { { let mut buf = [0; 2]; r.read_exact(&mut buf).await?; - let length = u16::from_le_bytes(buf); + let length = u16::from_be_bytes(buf) as usize; let header = UdpgwHeader::retrieve_from_async_stream(r).await?; - let mut data = vec![0; length as usize - header.len()]; + let address = if header.flags & UDPGW_FLAG_DATA != 0 { + Some(Address::retrieve_from_async_stream(r).await?) + } else { + None + }; + let mut data = vec![0; length - header.len() - address.as_ref().map_or(0, |addr| addr.len())]; r.read_exact(&mut data).await?; - Ok(Packet::new(header, data)) + Ok(Packet::new(header, address, &data)) } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[repr(C)] -#[repr(packed(1))] pub struct UdpgwHeader { pub flags: u8, pub conn_id: u16, } +impl std::fmt::Display for UdpgwHeader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let id = self.conn_id; + write!(f, "flags: 0x{:02x}, conn_id: {}", self.flags, id) + } +} + impl StreamOperation for UdpgwHeader { fn retrieve_from_stream(stream: &mut R) -> std::io::Result where @@ -167,7 +254,7 @@ impl UdpgwHeader { } pub const fn static_len() -> usize { - std::mem::size_of::() + std::mem::size_of::() + std::mem::size_of::() } } @@ -178,10 +265,8 @@ impl TryFrom<&[u8]> for UdpgwHeader { if value.len() < UdpgwHeader::static_len() { return Err(std::io::ErrorKind::InvalidData.into()); } - Ok(UdpgwHeader { - flags: value[0], - conn_id: u16::from_le_bytes([value[1], value[2]]), - }) + let conn_id = u16::from_be_bytes([value[1], value[2]]); + Ok(UdpgwHeader { flags: value[0], conn_id }) } } @@ -189,137 +274,25 @@ impl From<&UdpgwHeader> for Vec { fn from(header: &UdpgwHeader) -> Vec { let mut bytes = vec![0; header.len()]; bytes[0] = header.flags; - bytes[1..3].copy_from_slice(&header.conn_id.to_le_bytes()); + bytes[1..3].copy_from_slice(&header.conn_id.to_be_bytes()); bytes } } -#[allow(clippy::len_without_is_empty)] -#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)] -pub struct BinSocketAddr(SocketAddr); - -impl BinSocketAddr { - pub fn len(&self) -> usize { - match self.0 { - SocketAddr::V4(_) => Self::static_len(false), - SocketAddr::V6(_) => Self::static_len(true), - } - } - - pub fn static_len(is_ipv6: bool) -> usize { - if is_ipv6 { - std::mem::size_of::() + std::mem::size_of::() - } else { - std::mem::size_of::() + std::mem::size_of::() - } - } -} - -impl From<&BinSocketAddr> for Vec { - fn from(addr: &BinSocketAddr) -> Vec { - socket_addr_to_binary(&addr.0) - } -} - -impl From for Vec { - fn from(addr: BinSocketAddr) -> Vec { - socket_addr_to_binary(&addr.0) - } -} - -impl TryFrom<&[u8]> for BinSocketAddr { - type Error = std::io::Error; - fn try_from(value: &[u8]) -> std::result::Result { - Ok(BinSocketAddr(binary_to_socket_addr(value)?)) - } -} - -impl From for BinSocketAddr { - fn from(addr: SocketAddr) -> Self { - BinSocketAddr(addr) - } -} - -impl From for SocketAddr { - fn from(addr: BinSocketAddr) -> Self { - addr.0 - } -} - -fn socket_addr_to_binary(addr: &SocketAddr) -> Vec { - match addr { - SocketAddr::V4(addr_v4) => { - let mut bytes = vec![0; std::mem::size_of::()]; - bytes[0..4].copy_from_slice(&addr_v4.ip().octets()); - bytes[4..6].copy_from_slice(&addr_v4.port().to_be_bytes()); - bytes - } - SocketAddr::V6(addr_v6) => { - let mut bytes = vec![0; std::mem::size_of::() + std::mem::size_of::()]; - bytes[0..16].copy_from_slice(&addr_v6.ip().octets()); - bytes[16..18].copy_from_slice(&addr_v6.port().to_be_bytes()); - bytes - } - } -} - -fn binary_to_socket_addr(bytes: &[u8]) -> std::io::Result { - if bytes.len() == std::mem::size_of::() { - let ip = Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]); - let port = u16::from_be_bytes([bytes[4], bytes[5]]); - Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))) - } else if bytes.len() == std::mem::size_of::() + std::mem::size_of::() { - let mut ip = [0; 16]; - ip.copy_from_slice(&bytes[0..16]); - let port = u16::from_be_bytes([bytes[16], bytes[17]]); - Ok(SocketAddr::V6(SocketAddrV6::new(ip.into(), port, 0, 0))) - } else { - Err(std::io::ErrorKind::InvalidData.into()) - } -} - #[allow(dead_code)] #[derive(Debug)] -pub(crate) struct UdpGwData<'a> { - flags: u8, - conn_id: u16, - remote_addr: SocketAddr, - udpdata: &'a [u8], -} - -impl<'a> UdpGwData<'a> { - pub fn len(&self) -> usize { - self.udpdata.len() - } -} - -#[allow(dead_code)] -#[derive(Debug)] -pub(crate) enum UdpGwResponse<'a> { +pub(crate) enum UdpGwResponse { KeepAlive, Error, TcpClose, - Data(UdpGwData<'a>), -} - -#[derive(Debug)] -pub(crate) struct UdpGwClientStreamWriter { - inner: OwnedWriteHalf, - tmp_buf: Vec, - send_buf: Vec, -} - -#[derive(Debug)] -pub(crate) struct UdpGwClientStreamReader { - inner: OwnedReadHalf, - recv_buf: Vec, + Data(Packet), } #[derive(Debug)] pub(crate) struct UdpGwClientStream { local_addr: String, - writer: Option, - reader: Option, + writer: Option, + reader: Option, conn_id: u16, closed: bool, last_activity: std::time::Instant, @@ -335,19 +308,20 @@ impl UdpGwClientStream { pub fn close(&mut self) { self.closed = true; } - pub fn get_reader(&mut self) -> Option { + + pub fn get_reader(&mut self) -> Option { self.reader.take() } - pub fn set_reader(&mut self, mut reader: Option) { - self.reader = reader.take(); + pub fn set_reader(&mut self, reader: Option) { + self.reader = reader; } - pub fn set_writer(&mut self, mut writer: Option) { - self.writer = writer.take(); + pub fn set_writer(&mut self, writer: Option) { + self.writer = writer; } - pub fn get_writer(&mut self) -> Option { + pub fn get_writer(&mut self) -> Option { self.writer.take() } @@ -371,21 +345,13 @@ impl UdpGwClientStream { self.conn_id += 1; self.conn_id } - pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self { - let local_addr = tcp_server_stream - .local_addr() - .unwrap_or_else(|_| "0.0.0.0:0".parse::().unwrap()) - .to_string(); + + pub fn new(tcp_server_stream: TcpStream) -> Self { + let default = "0.0.0.0:0".parse::().unwrap(); + let local_addr = tcp_server_stream.local_addr().unwrap_or(default).to_string(); let (rx, tx) = tcp_server_stream.into_split(); - let writer = UdpGwClientStreamWriter { - inner: tx, - tmp_buf: vec![0; udp_mtu.into()], - send_buf: vec![0; udp_mtu.into()], - }; - let reader = UdpGwClientStreamReader { - inner: rx, - recv_buf: vec![0; udp_mtu.into()], - }; + let writer = tx; + let reader = rx; TCP_COUNTER.fetch_add(1, Relaxed); UdpGwClientStream { local_addr, @@ -405,13 +371,11 @@ pub(crate) struct UdpGwClient { udp_timeout: u64, keepalive_time: Duration, server_addr: SocketAddr, - keepalive_packet: Vec, server_connections: Mutex>, } impl UdpGwClient { pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self { - let keepalive_packet: Vec = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0), vec![]).into(); let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize)); UdpGwClient { udp_mtu, @@ -419,7 +383,6 @@ impl UdpGwClient { udp_timeout, server_addr, keepalive_time, - keepalive_packet, server_connections, } } @@ -446,11 +409,11 @@ impl UdpGwClient { } } - pub(crate) async fn release_server_connection_with_stream( + pub(crate) async fn release_server_connection_full( &self, mut stream: UdpGwClientStream, - reader: UdpGwClientStreamReader, - writer: UdpGwClientStreamWriter, + reader: OwnedReadHalf, + writer: OwnedWriteHalf, ) { if self.server_connections.lock().await.len() < self.max_connections as usize { stream.set_reader(Some(reader)); @@ -480,101 +443,39 @@ impl UdpGwClient { let Some(mut stream_writer) = stream.get_writer() else { continue; }; - let local_addr = stream_writer.inner.local_addr(); + let local_addr = stream_writer.local_addr(); log::debug!("{:?}:{} send keepalive", local_addr, stream.id()); - if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await { + let keepalive_packet: Vec = Packet::build_keepalive_packet(stream.id()).into(); + if let Err(e) = stream_writer.write_all(&keepalive_packet).await { log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e); - } else { - match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { - Ok(UdpGwResponse::KeepAlive) => { - stream.update_activity(); - self.release_server_connection_with_stream(stream, stream_reader, stream_writer) - .await; - } - Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v), - Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e), + continue; + } + match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { + Ok(UdpGwResponse::KeepAlive) => { + stream.update_activity(); + self.release_server_connection_full(stream, stream_reader, stream_writer).await; } + Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v), + Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e), } } } } /// Parses the UDP response data. - pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result { - let data = &stream.recv_buf; - let header_len = UdpgwHeader::static_len(); - if data_len < header_len { - return Err("Invalid udpgw data".into()); - } - let header_bytes = &data[..header_len]; - let header = UdpgwHeader { - flags: header_bytes[0], - conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), - }; - - let flags = header.flags; - let conn_id = header.conn_id; - - let ip_data = &data[header_len..]; - let mut data_len = data_len - header_len; - + pub(crate) fn parse_udp_response(udp_mtu: u16, data: &[u8]) -> Result { + let packet = Packet::try_from(data)?; + let flags = packet.header.flags; if flags & UDPGW_FLAG_ERR != 0 { return Ok(UdpGwResponse::Error); } - if flags & UDPGW_FLAG_KEEPALIVE != 0 { return Ok(UdpGwResponse::KeepAlive); } - - if flags & UDPGW_FLAG_IPV6 != 0 { - let ipv6_addr_len = BinSocketAddr::static_len(true); - if data_len < ipv6_addr_len { - return Err("ipv6 Invalid UDP data".into()); - } - let addr_ipv6 = BinSocketAddr::try_from(&ip_data[..ipv6_addr_len])?; - data_len -= ipv6_addr_len; - - if data_len > udp_mtu as usize { - return Err("too much data".into()); - } - return Ok(UdpGwResponse::Data(UdpGwData { - flags, - conn_id, - remote_addr: addr_ipv6.into(), - udpdata: &ip_data[ipv6_addr_len..(data_len + ipv6_addr_len)], - })); - } else { - let ipv4_addr_len = BinSocketAddr::static_len(false); - if data_len < ipv4_addr_len { - return Err("ipv4 Invalid UDP data".into()); - } - let addr_ipv4 = BinSocketAddr::try_from(&ip_data[..ipv4_addr_len])?; - data_len -= ipv4_addr_len; - - if data_len > udp_mtu as usize { - return Err("too much data".into()); - } - return Ok(UdpGwResponse::Data(UdpGwData { - flags, - conn_id, - remote_addr: addr_ipv4.into(), - udpdata: &ip_data[ipv4_addr_len..(data_len + ipv4_addr_len)], - })); + if packet.data.len() > udp_mtu as usize { + return Err("too much data".into()); } - } - - pub(crate) async fn recv_udp_packet( - udp_stack: &mut IpStackUdpStream, - stream: &mut UdpGwClientStreamWriter, - ) -> std::result::Result { - udp_stack.read(&mut stream.tmp_buf).await - } - - pub(crate) async fn send_udp_packet<'a>( - packet: UdpGwData<'a>, - udp_stack: &mut IpStackUdpStream, - ) -> std::result::Result<(), std::io::Error> { - udp_stack.write_all(packet.udpdata).await + Ok(UdpGwResponse::Data(packet)) } /// Receives a UDP gateway packet. @@ -588,37 +489,15 @@ impl UdpGwClient { /// /// # Returns /// - `Result`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs. - pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut UdpGwClientStreamReader) -> Result { - let result = tokio::time::timeout( - tokio::time::Duration::from_secs(udp_timeout + 2), - stream.inner.read(&mut stream.recv_buf[..2]), - ) - .await - .map_err(std::io::Error::from)?; - let n = result?; - if n == 0 { + pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut OwnedReadHalf) -> Result { + let mut data = vec![0; udp_mtu.into()]; + let data_len = tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout + 2), stream.read(&mut data)) + .await + .map_err(std::io::Error::from)??; + if data_len == 0 { return Ok(UdpGwResponse::TcpClose); } - if n < UDPGW_LENGTH_FIELD_SIZE { - return Err("received Packet Length field error".into()); - } - let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]); - if packet_len > udp_mtu { - return Err("packet too long".into()); - } - let mut left_len: usize = packet_len as usize; - let mut recv_len = 0; - while left_len > 0 { - let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await else { - return Ok(UdpGwResponse::TcpClose); - }; - if len == 0 { - return Ok(UdpGwResponse::TcpClose); - } - recv_len += len; - left_len -= len; - } - UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, stream) + UdpGwClient::parse_udp_response(udp_mtu, &data[..data_len]) } /// Sends a UDP gateway packet. @@ -629,9 +508,8 @@ impl UdpGwClient { /// # Arguments /// /// * `ipv6_enabled` - Whether IPv6 is enabled - /// * `len` - Length of the data packet + /// * `data` - The data packet /// * `remote_addr` - Remote address - /// * `domain` - Target domain (optional) /// * `conn_id` - Connection ID /// * `stream` - UDP gateway client writer stream /// @@ -640,80 +518,17 @@ impl UdpGwClient { /// Returns `Ok(())` if the packet is sent successfully, otherwise returns an error. pub(crate) async fn send_udpgw_packet( ipv6_enabled: bool, - len: usize, - remote_addr: SocketAddr, - domain: Option<&String>, + data: &[u8], + remote_addr: &socks5_impl::protocol::Address, conn_id: u16, - stream: &mut UdpGwClientStreamWriter, + stream: &mut OwnedWriteHalf, ) -> Result<()> { - stream.send_buf.clear(); - let data = &stream.tmp_buf; - let mut pack_len = UdpgwHeader::static_len() + len; - let packet = &mut stream.send_buf; - match domain { - Some(domain) => { - let addr_port = remote_addr.port(); - let domain_len = domain.len(); - if domain_len > 255 { - return Err("InvalidDomain".into()); - } - pack_len += UDPGW_LENGTH_FIELD_SIZE; - pack_len += domain_len + 1; - packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); - packet.extend_from_slice(&[UDPGW_FLAG_DOMAIN]); - packet.extend_from_slice(&conn_id.to_le_bytes()); - packet.extend_from_slice(&addr_port.to_be_bytes()); - packet.extend_from_slice(domain.as_bytes()); - packet.push(0); - packet.extend_from_slice(&data[..len]); - } - None => match remote_addr { - SocketAddr::V4(_) => { - let addr_ipv4 = BinSocketAddr::from(remote_addr); - pack_len += addr_ipv4.len(); - packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); - packet.extend_from_slice(&[UDPGW_FLAG_IPV4]); - packet.extend_from_slice(&conn_id.to_le_bytes()); - let addr_ipv4_bin: Vec = addr_ipv4.into(); - packet.extend_from_slice(&addr_ipv4_bin); - packet.extend_from_slice(&data[..len]); - } - SocketAddr::V6(_) => { - if !ipv6_enabled { - return Err("ipv6 not support".into()); - } - let addr_ipv6 = BinSocketAddr::from(remote_addr); - pack_len += addr_ipv6.len(); - packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); - packet.extend_from_slice(&[UDPGW_FLAG_IPV6]); - packet.extend_from_slice(&conn_id.to_le_bytes()); - let addr_ipv6_bin: Vec = addr_ipv6.into(); - packet.extend_from_slice(&addr_ipv6_bin); - packet.extend_from_slice(&data[..len]); - } - }, + if !ipv6_enabled && remote_addr.get_type() == socks5_impl::protocol::AddressType::IPv6 { + return Err("ipv6 not support".into()); } - - stream.inner.write_all(packet).await?; + let out_data: Vec = Packet::build_packet_from_address(conn_id, remote_addr, data)?.into(); + stream.write_all(&out_data).await?; Ok(()) } } - -#[cfg(test)] -mod tests { - use super::{Packet, UdpgwHeader}; - use socks5_impl::protocol::StreamOperation; - - #[test] - fn test_udpgw_header() { - let header = UdpgwHeader::new(0x01, 0x1234); - let mut bytes: Vec = vec![]; - let packet = Packet::new(header, vec![]); - packet.write_to_buf(&mut bytes); - - let header2 = Packet::retrieve_from_stream(&mut bytes.as_slice()).unwrap().header; - - assert_eq!(header, header2); - } -}