diff --git a/Cargo.toml b/Cargo.toml index d237122..3ad63ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tun2proxy" -version = "0.5.4" +version = "0.6.0" edition = "2021" license = "MIT" repository = "https://github.com/tun2proxy/tun2proxy" @@ -13,6 +13,10 @@ rust-version = "1.80" [lib] crate-type = ["staticlib", "cdylib", "lib"] +[features] +default = ["udpgw"] +udpgw = [] + [dependencies] async-trait = "0.1" base64 = { version = "0.22" } @@ -42,11 +46,6 @@ url = "2" [target.'cfg(target_os="linux")'.dependencies] serde = { version = "1", features = ["derive"] } bincode = "1" -nix = { version = "0.29", default-features = false, features = [ - "fs", - "socket", - "uio", -] } [target.'cfg(target_os="android")'.dependencies] android_logger = "0.14" @@ -54,6 +53,11 @@ jni = { version = "0.21", default-features = false } [target.'cfg(unix)'.dependencies] daemonize = "0.5" +nix = { version = "0.29", default-features = false, features = [ + "fs", + "socket", + "uio", +] } [target.'cfg(target_os = "windows")'.dependencies] windows-service = "0.7" @@ -65,5 +69,10 @@ serde_json = "1" name = "tun2proxy-bin" path = "src/bin/main.rs" +[[bin]] +name = "udpgw-server" +path = "src/bin/udpgw_server.rs" +required-features = ["udpgw"] + [profile.release] strip = "symbols" diff --git a/src/args.rs b/src/args.rs index 352cb57..65d0998 100644 --- a/src/args.rs +++ b/src/args.rs @@ -111,6 +111,16 @@ pub struct Args { /// Maximum number of sessions to be handled concurrently #[arg(long, value_name = "number", default_value = "200")] pub max_sessions: usize, + + /// UDP gateway server address, similar to badvpn-udpgw + #[cfg(feature = "udpgw")] + #[arg(long, value_name = "IP:PORT")] + pub udpgw_server: Option, + + /// Max udpgw connections, default value is 100 + #[cfg(feature = "udpgw")] + #[arg(long, value_name = "number", requires = "udpgw_server")] + pub udpgw_max_connections: Option, } fn validate_tun(p: &str) -> Result { @@ -154,6 +164,10 @@ impl Default for Args { daemonize: false, exit_on_fatal_error: false, max_sessions: 200, + #[cfg(feature = "udpgw")] + udpgw_server: None, + #[cfg(feature = "udpgw")] + udpgw_max_connections: None, } } } @@ -181,6 +195,18 @@ impl Args { self } + #[cfg(feature = "udpgw")] + pub fn udpgw_server(&mut self, udpgw: SocketAddr) -> &mut Self { + self.udpgw_server = Some(udpgw); + self + } + + #[cfg(feature = "udpgw")] + pub fn udpgw_max_connections(&mut self, udpgw_max_connections: u16) -> &mut Self { + self.udpgw_max_connections = Some(udpgw_max_connections); + self + } + #[cfg(unix)] pub fn tun_fd(&mut self, tun_fd: Option) -> &mut Self { self.tun_fd = tun_fd; diff --git a/src/bin/udpgw_server.rs b/src/bin/udpgw_server.rs new file mode 100644 index 0000000..6f6d080 --- /dev/null +++ b/src/bin/udpgw_server.rs @@ -0,0 +1,210 @@ +use socks5_impl::protocol::{AddressType, AsyncStreamOperation}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::{ + io::AsyncWriteExt, + net::{ + tcp::{ReadHalf, WriteHalf}, + UdpSocket, + }, + sync::mpsc::{self, Receiver, Sender}, +}; +use tun2proxy::{ + udpgw::{Packet, UDPGW_FLAG_KEEPALIVE}, + ArgVerbosity, Result, +}; + +pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60); + +#[derive(Debug, Clone)] +pub struct Client { + addr: SocketAddr, + last_activity: std::time::Instant, +} + +impl Client { + pub fn new(addr: SocketAddr) -> Self { + let last_activity = std::time::Instant::now(); + Self { addr, last_activity } + } +} + +#[derive(Debug, Clone, clap::Parser)] +pub struct UdpGwArgs { + /// UDP gateway listen address + #[arg(short, long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")] + pub listen_addr: SocketAddr, + + /// UDP mtu + #[arg(short = 'm', long, value_name = "udp mtu", default_value = "10240")] + pub udp_mtu: u16, + + /// UDP timeout in seconds + #[arg(short = 't', long, value_name = "seconds", default_value = "3")] + pub udp_timeout: u64, + + /// Daemonize for unix family or run as Windows service + #[cfg(unix)] + #[arg(long)] + pub daemonize: bool, + + /// Verbosity level + #[arg(short, long, value_name = "level", value_enum, default_value = "info")] + pub verbosity: ArgVerbosity, +} + +impl UdpGwArgs { + #[allow(clippy::let_and_return)] + pub fn parse_args() -> Self { + use clap::Parser; + Self::parse() + } +} + +async fn send_error(tx: Sender, conn_id: u16) { + let error_packet = Packet::build_error_packet(conn_id); + if let Err(e) = tx.send(error_packet).await { + log::error!("send error response error {:?}", e); + } +} + +async fn send_keepalive_response(tx: Sender, conn_id: u16) { + let keepalive_packet = Packet::build_keepalive_packet(conn_id); + if let Err(e) = tx.send(keepalive_packet).await { + log::error!("send keepalive response error {:?}", e); + } +} + +/// Send data field of packet from client to destination server and receive response, +/// then wrap response data to the packet's data field and send packet back to client. +async fn process_udp(client: SocketAddr, udp_mtu: u16, udp_timeout: u64, tx: Sender, mut packet: Packet) -> Result<()> { + let Some(dst_addr) = &packet.address else { + log::error!("client {} udp request address is None", client); + return Ok(()); + }; + let std_sock = if dst_addr.get_type() == AddressType::IPv6 { + std::net::UdpSocket::bind("[::]:0")? + } else { + std::net::UdpSocket::bind("0.0.0.0:0")? + }; + std_sock.set_nonblocking(true)?; + #[cfg(unix)] + nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?; + let socket = UdpSocket::from_std(std_sock)?; + use std::net::ToSocketAddrs; + let Some(dst_addr) = dst_addr.to_socket_addrs()?.next() else { + log::error!("client {} udp request address to_socket_addrs", client); + return Ok(()); + }; + // 1. send udp data to destination server + socket.send_to(&packet.data, &dst_addr).await?; + packet.data.resize(udp_mtu as usize, 0); + // 2. receive response from destination server + let (len, _addr) = tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut packet.data)) + .await + .map_err(std::io::Error::from)??; + packet.data.truncate(len); + // 3. send response back to client + use std::io::{Error, ErrorKind::BrokenPipe}; + tx.send(packet).await.map_err(|e| Error::new(BrokenPipe, e))?; + Ok(()) +} + +async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender, mut client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> { + let udp_timeout = args.udp_timeout; + let udp_mtu = args.udp_mtu; + + loop { + // 1. read udpgw packet from client + let res = tokio::time::timeout(tokio::time::Duration::from_secs(2), Packet::retrieve_from_async_stream(&mut reader)).await; + let packet = match res { + Ok(Ok(packet)) => packet, + Ok(Err(e)) => { + log::error!("client {} retrieve_from_async_stream \"{}\"", client.addr, e); + break; + } + Err(e) => { + if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT { + log::debug!("client {} last_activity elapsed \"{e}\"", client.addr); + break; + } + continue; + } + }; + client.last_activity = std::time::Instant::now(); + + let flags = packet.header.flags; + let conn_id = packet.header.conn_id; + if flags & UDPGW_FLAG_KEEPALIVE != 0 { + log::trace!("client {} send keepalive", client.addr); + // 2. if keepalive packet, do nothing, send keepalive response to client + send_keepalive_response(tx.clone(), conn_id).await; + continue; + } + log::trace!("client {} received udp data {}", client.addr, packet); + + // 3. process client udpgw packet in a new task + let tx = tx.clone(); + tokio::spawn(async move { + if let Err(e) = process_udp(client.addr, udp_mtu, udp_timeout, tx.clone(), packet).await { + send_error(tx, conn_id).await; + log::error!("client {} process udp function {}", client.addr, e); + } + }); + } + Ok(()) +} + +async fn write_to_client(addr: SocketAddr, mut writer: WriteHalf<'_>, mut rx: Receiver) -> std::io::Result<()> { + loop { + use std::io::{Error, ErrorKind::BrokenPipe}; + let packet = rx.recv().await.ok_or(Error::new(BrokenPipe, "recv error"))?; + log::trace!("send response to client {} with {}", addr, packet); + let data: Vec = packet.into(); + let _r = writer.write(&data).await?; + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Arc::new(UdpGwArgs::parse_args()); + + let tcp_listener = tokio::net::TcpListener::bind(args.listen_addr).await?; + + let default = format!("{:?}", args.verbosity); + + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init(); + + log::info!("{} {} starting...", module_path!(), env!("CARGO_PKG_VERSION")); + log::info!("UDP Gateway Server running at {}", args.listen_addr); + + #[cfg(unix)] + if args.daemonize { + let stdout = std::fs::File::create("/tmp/udpgw.out")?; + let stderr = std::fs::File::create("/tmp/udpgw.err")?; + let daemonize = daemonize::Daemonize::new() + .working_directory("/tmp") + .umask(0o777) + .stdout(stdout) + .stderr(stderr) + .privileged_action(|| "Executed before drop privileges"); + let _ = daemonize + .start() + .map_err(|e| format!("Failed to daemonize process, error:{:?}", e))?; + } + + loop { + let (mut tcp_stream, addr) = tcp_listener.accept().await?; + let client = Client::new(addr); + log::info!("client {} connected", addr); + let params = args.clone(); + tokio::spawn(async move { + let (tx, rx) = mpsc::channel::(100); + let (tcp_read_stream, tcp_write_stream) = tcp_stream.split(); + let res = tokio::select! { + v = process_client_udp_req(¶ms, tx, client, tcp_read_stream) => v, + v = write_to_client(addr, tcp_write_stream, rx) => v, + }; + log::info!("client {} disconnected with {:?}", addr, res); + }); + } +} diff --git a/src/error.rs b/src/error.rs index f460b62..755ee0c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,7 +6,7 @@ pub enum Error { #[error(transparent)] Io(#[from] std::io::Error), - #[cfg(target_os = "linux")] + #[cfg(unix)] #[error("nix::errno::Errno {0:?}")] NixErrno(#[from] nix::errno::Errno), diff --git a/src/lib.rs b/src/lib.rs index 554fe14..242ba15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "udpgw")] +use crate::udpgw::UdpGwClient; use crate::{ directions::{IncomingDataEvent, IncomingDirection, OutgoingDirection}, http::HttpManager, @@ -9,6 +11,8 @@ use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream}; use proxy_handler::{ProxyHandler, ProxyHandlerManager}; use socks::SocksProxyManager; pub use socks5_impl::protocol::UserKey; +#[cfg(feature = "udpgw")] +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; use std::{ collections::VecDeque, io::ErrorKind, @@ -23,6 +27,8 @@ use tokio::{ pub use tokio_util::sync::CancellationToken; use tproxy_config::is_private_ip; use udp_stream::UdpStream; +#[cfg(feature = "udpgw")] +use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS}; pub use { args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType}, @@ -59,6 +65,8 @@ mod session_info; pub mod socket_transfer; mod socks; mod traffic_status; +#[cfg(feature = "udpgw")] +pub mod udpgw; mod virtual_dns; #[doc(hidden)] pub mod win_svc; @@ -231,6 +239,23 @@ where let mut ip_stack = ipstack::IpStack::new(ipstack_config, device); + #[cfg(feature = "udpgw")] + let udpgw_client = args.udpgw_server.as_ref().map(|addr| { + log::info!("UDPGW enabled"); + let client = Arc::new(UdpGwClient::new( + mtu, + args.udpgw_max_connections.unwrap_or(UDPGW_MAX_CONNECTIONS), + UDPGW_KEEPALIVE_TIME, + args.udp_timeout, + *addr, + )); + let client_keepalive = client.clone(); + tokio::spawn(async move { + client_keepalive.heartbeat_task().await; + }); + client + }); + loop { let virtual_dns = virtual_dns.clone(); let ip_stack_stream = tokio::select! { @@ -318,6 +343,24 @@ where } else { None }; + #[cfg(feature = "udpgw")] + if let Some(udpgw) = udpgw_client.clone() { + let tcp_src = match udp.peer_addr() { + SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), + SocketAddr::V6(_) => SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)), + }; + let tcpinfo = SessionInfo::new(tcp_src, udpgw.get_server_addr(), IpProtocol::Tcp); + let proxy_handler = mgr.new_proxy_handler(tcpinfo, None, false).await?; + let queue = socket_queue.clone(); + tokio::spawn(async move { + let dst = info.dst; // real UDP destination address + if let Err(e) = handle_udp_gateway_session(udp, udpgw, dst, domain_name, proxy_handler, queue, ipv6_enabled).await { + log::info!("Ending {} with \"{}\"", info, e); + } + log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); + }); + continue; + } match mgr.new_proxy_handler(info, domain_name, true).await { Ok(proxy_handler) => { let socket_queue = socket_queue.clone(); @@ -436,6 +479,125 @@ async fn handle_tcp_session( Ok(()) } +#[cfg(feature = "udpgw")] +async fn handle_udp_gateway_session( + mut udp_stack: IpStackUdpStream, + udpgw_client: Arc, + udp_dst: SocketAddr, + domain_name: Option, + proxy_handler: Arc>, + socket_queue: Option>, + ipv6_enabled: bool, +) -> crate::Result<()> { + let proxy_server_addr = { proxy_handler.lock().await.get_server_addr() }; + let udp_mtu = udpgw_client.get_udp_mtu(); + let udp_timeout = udpgw_client.get_udp_timeout(); + let mut stream = match udpgw_client.get_server_connection().await { + Some(server) => server, + None => { + if udpgw_client.is_full() { + return Err("max udpgw connection limit reached".into()); + } + let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?; + if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await { + return Err(format!("udpgw connection error: {}", e).into()); + } + UdpGwClientStream::new(tcp_server_stream) + } + }; + + let tcp_local_addr = stream.local_addr().clone(); + + match domain_name { + Some(ref d) => log::info!("[UdpGw] Beginning {} -> {}, domain:{}", &tcp_local_addr, udp_dst, d), + None => log::info!("[UdpGw] Beginning {} -> {}", &tcp_local_addr, udp_dst), + } + + let Some(mut reader) = stream.get_reader() else { + return Err("get reader failed".into()); + }; + + let Some(mut writer) = stream.get_writer() else { + return Err("get writer failed".into()); + }; + + let mut tmp_buf = vec![0; udp_mtu.into()]; + + loop { + tokio::select! { + len = udp_stack.read(&mut tmp_buf) => { + let read_len = match len { + Ok(0) => { + log::info!("[UdpGw] Ending {} <> {}", &tcp_local_addr, udp_dst); + break; + } + Ok(n) => n, + Err(e) => { + log::info!("[UdpGw] Ending {} <> {} with recv_udp_packet {}", &tcp_local_addr, udp_dst, e); + break; + } + }; + crate::traffic_status::traffic_status_update(read_len, 0)?; + let new_id = stream.new_id(); + let remote_addr = match domain_name { + Some(ref d) => socks5_impl::protocol::Address::from((d.clone(), udp_dst.port())), + None => udp_dst.into(), + }; + if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], &remote_addr, new_id, &mut writer).await { + log::info!("[UdpGw] Ending {} <> {} with send_udpgw_packet {}", &tcp_local_addr, udp_dst, e); + break; + } + log::debug!("[UdpGw] {} -> {} send len {}", &tcp_local_addr, udp_dst, read_len); + stream.update_activity(); + } + ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut reader) => { + match ret { + Ok(packet) => match packet { + //should not received keepalive + UdpGwResponse::KeepAlive => { + log::error!("[UdpGw] Ending {} <> {} with recv keepalive", &tcp_local_addr, udp_dst); + stream.close(); + break; + } + //server udp may be timeout,can continue to receive udp data? + UdpGwResponse::Error => { + log::info!("[UdpGw] Ending {} <> {} with recv udp error", &tcp_local_addr, udp_dst); + stream.update_activity(); + continue; + } + UdpGwResponse::TcpClose => { + log::error!("[UdpGw] Ending {} <> {} with tcp closed", &tcp_local_addr, udp_dst); + stream.close(); + break; + } + UdpGwResponse::Data(data) => { + use socks5_impl::protocol::StreamOperation; + let len = data.len(); + log::debug!("[UdpGw] {} <- {} receive len {}", &tcp_local_addr, udp_dst, len); + if let Err(e) = udp_stack.write_all(&data.data).await { + log::error!("[UdpGw] Ending {} <> {} with send_udp_packet {}", &tcp_local_addr, udp_dst, e); + break; + } + crate::traffic_status::traffic_status_update(0, len)?; + } + }, + Err(e) => { + log::warn!("[UdpGw] Ending {} <> {} with recv_udpgw_packet {}", &tcp_local_addr, udp_dst, e); + break; + } + } + stream.update_activity(); + } + } + } + + if !stream.is_closed() { + udpgw_client.release_server_connection_full(stream, reader, writer).await; + } + + Ok(()) +} + async fn handle_udp_associate_session( mut udp_stack: IpStackUdpStream, proxy_type: ProxyType, diff --git a/src/udpgw.rs b/src/udpgw.rs new file mode 100644 index 0000000..4d7f018 --- /dev/null +++ b/src/udpgw.rs @@ -0,0 +1,542 @@ +use crate::error::Result; +use socks5_impl::protocol::{Address, AsyncStreamOperation, BufMut, StreamOperation}; +use std::{collections::VecDeque, hash::Hash, net::SocketAddr, sync::atomic::Ordering::Relaxed}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpStream, + }, + sync::Mutex, + time::{sleep, Duration}, +}; + +pub(crate) const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::(); +pub(crate) const UDPGW_MAX_CONNECTIONS: u16 = 100; +pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10); + +pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01; +pub const UDPGW_FLAG_ERR: u8 = 0x20; +pub const UDPGW_FLAG_DATA: u8 = 0x02; + +static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); + +/// UDP Gateway Packet Format +/// +/// The format is referenced from SOCKS5 packet format, with additional flags and connection ID fields. +/// +/// `LEN`: This field is indicated the length of the packet, not including the length field itself. +/// +/// `FLAGS`: This field is used to indicate the packet type. The flags are defined as follows: +/// - `0x01`: Keepalive packet without address and data +/// - `0x20`: Error packet without address and data +/// - `0x02`: Data packet with address and data +/// +/// `CONN_ID`: This field is used to indicate the unique connection ID for the packet. +/// +/// `ATYP` & `DST.ADDR` & `DST.PORT`: This fields are used to indicate the remote address and port. +/// It can be either an IPv4 address, an IPv6 address, or a domain name, depending on the `ATYP` field. +/// The address format directly uses the address format of the [SOCKS5](https://datatracker.ietf.org/doc/html/rfc1928#section-4) protocol. +/// - `ATYP`: Address Type, 1 byte, indicating the type of address ( 0x01-IPv4, 0x04-IPv6, or 0x03-domain name ) +/// - `DST.ADDR`: Destination Address. If `ATYP` is 0x01 or 0x04, it is 4 or 16 bytes of IP address; +/// If `ATYP` is 0x03, it is a domain name, `DST.ADDR` is a variable length field, +/// it begins with a 1-byte length field and then the domain name without null-termination, +/// since the length field is 1 byte, the maximum length of the domain name is 255 bytes. +/// - `DST.PORT`: Destination Port, 2 bytes, the port number of the destination address. +/// +/// `DATA`: The data field, a variable length field, the length is determined by the `LEN` field. +/// +/// All the digits fields are in big-endian byte order. +/// +/// ```plain +/// +-----+ +-------+---------+ +------+----------+----------+ +----------+ +/// | LEN | | FLAGS | CONN_ID | | ATYP | DST.ADDR | DST.PORT | | DATA | +/// +-----+ +-------+---------+ +------+----------+----------+ +----------+ +/// | 2 | | 1 | 2 | | 1 | Variable | 2 | | Variable | +/// +-----+ +-------+---------+ +------+----------+----------+ +----------+ +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Packet { + pub header: UdpgwHeader, + pub address: Option
, + pub data: Vec, +} + +impl std::fmt::Display for Packet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let addr = self.address.as_ref().map_or("None".to_string(), |addr| addr.to_string()); + let len = self.data.len(); + write!(f, "Packet {{ {}, address: {}, payload length: {} }}", self.header, addr, len) + } +} + +impl From for Vec { + fn from(packet: Packet) -> Vec { + (&packet).into() + } +} + +impl From<&Packet> for Vec { + fn from(packet: &Packet) -> Vec { + let mut bytes: Vec = vec![]; + packet.write_to_buf(&mut bytes); + bytes + } +} + +impl TryFrom<&[u8]> for Packet { + type Error = std::io::Error; + + fn try_from(value: &[u8]) -> std::result::Result { + if value.len() < UDPGW_LENGTH_FIELD_SIZE { + return Err(std::io::ErrorKind::InvalidData.into()); + } + let mut iter = std::io::Cursor::new(value); + use tokio_util::bytes::Buf; + let length = iter.get_u16(); + if value.len() < length as usize + UDPGW_LENGTH_FIELD_SIZE { + return Err(std::io::ErrorKind::InvalidData.into()); + } + let header = UdpgwHeader::retrieve_from_stream(&mut iter)?; + let address = if header.flags & UDPGW_FLAG_DATA != 0 { + Some(Address::retrieve_from_stream(&mut iter)?) + } else { + None + }; + Ok(Packet::new(header, address, iter.chunk())) + } +} + +impl Packet { + pub fn new(header: UdpgwHeader, address: Option
, data: &[u8]) -> Self { + let data = data.to_vec(); + Packet { header, address, data } + } + + pub fn build_keepalive_packet(conn_id: u16) -> Self { + Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), None, &[]) + } + + pub fn build_error_packet(conn_id: u16) -> Self { + Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, conn_id), None, &[]) + } + + pub fn build_packet_from_address(conn_id: u16, remote_addr: &Address, data: &[u8]) -> std::io::Result { + use socks5_impl::protocol::Address::{DomainAddress, SocketAddress}; + let packet = match remote_addr { + SocketAddress(addr) => Packet::build_ip_packet(conn_id, *addr, data), + DomainAddress(domain, port) => Packet::build_domain_packet(conn_id, *port, domain, data)?, + }; + Ok(packet) + } + + pub fn build_ip_packet(conn_id: u16, remote_addr: SocketAddr, data: &[u8]) -> Self { + let addr: Address = remote_addr.into(); + Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data) + } + + pub fn build_domain_packet(conn_id: u16, port: u16, domain: &str, data: &[u8]) -> std::io::Result { + if domain.len() > 255 { + return Err(std::io::ErrorKind::InvalidInput.into()); + } + let addr = Address::from((domain, port)); + Ok(Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data)) + } +} + +impl StreamOperation for Packet { + fn retrieve_from_stream(stream: &mut R) -> std::io::Result + where + R: std::io::Read, + Self: Sized, + { + let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE]; + stream.read_exact(&mut buf)?; + let length = u16::from_be_bytes(buf) as usize; + let header = UdpgwHeader::retrieve_from_stream(stream)?; + let address = if header.flags & UDPGW_FLAG_DATA != 0 { + Some(Address::retrieve_from_stream(stream)?) + } else { + None + }; + let read_len = header.len() + address.as_ref().map_or(0, |addr| addr.len()); + if length < read_len { + return Err(std::io::ErrorKind::InvalidData.into()); + } + let mut data = vec![0; length - read_len]; + stream.read_exact(&mut data)?; + Ok(Packet::new(header, address, &data)) + } + + fn write_to_buf(&self, buf: &mut B) { + let len = self.len() - UDPGW_LENGTH_FIELD_SIZE; + buf.put_u16(len as u16); + self.header.write_to_buf(buf); + if let Some(addr) = &self.address { + addr.write_to_buf(buf); + } + buf.put_slice(&self.data); + } + + fn len(&self) -> usize { + UDPGW_LENGTH_FIELD_SIZE + self.header.len() + self.address.as_ref().map_or(0, |addr| addr.len()) + self.data.len() + } +} + +#[async_trait::async_trait] +impl AsyncStreamOperation for Packet { + async fn retrieve_from_async_stream(r: &mut R) -> std::io::Result + where + R: tokio::io::AsyncRead + Unpin + Send, + Self: Sized, + { + let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE]; + r.read_exact(&mut buf).await?; + let length = u16::from_be_bytes(buf) as usize; + let header = UdpgwHeader::retrieve_from_async_stream(r).await?; + let address = if header.flags & UDPGW_FLAG_DATA != 0 { + Some(Address::retrieve_from_async_stream(r).await?) + } else { + None + }; + let read_len = header.len() + address.as_ref().map_or(0, |addr| addr.len()); + if length < read_len { + return Err(std::io::ErrorKind::InvalidData.into()); + } + let mut data = vec![0; length - read_len]; + r.read_exact(&mut data).await?; + Ok(Packet::new(header, address, &data)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct UdpgwHeader { + pub flags: u8, + pub conn_id: u16, +} + +impl std::fmt::Display for UdpgwHeader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let id = self.conn_id; + write!(f, "flags: 0x{:02x}, conn_id: {}", self.flags, id) + } +} + +impl StreamOperation for UdpgwHeader { + fn retrieve_from_stream(stream: &mut R) -> std::io::Result + where + R: std::io::Read, + Self: Sized, + { + let mut buf = [0; UdpgwHeader::static_len()]; + stream.read_exact(&mut buf)?; + UdpgwHeader::try_from(&buf[..]) + } + + fn write_to_buf(&self, buf: &mut B) { + let bytes: Vec = self.into(); + buf.put_slice(&bytes); + } + + fn len(&self) -> usize { + Self::static_len() + } +} + +#[async_trait::async_trait] +impl AsyncStreamOperation for UdpgwHeader { + async fn retrieve_from_async_stream(r: &mut R) -> std::io::Result + where + R: tokio::io::AsyncRead + Unpin + Send, + Self: Sized, + { + let mut buf = [0; UdpgwHeader::static_len()]; + r.read_exact(&mut buf).await?; + UdpgwHeader::try_from(&buf[..]) + } +} + +impl UdpgwHeader { + pub fn new(flags: u8, conn_id: u16) -> Self { + UdpgwHeader { flags, conn_id } + } + + pub const fn static_len() -> usize { + std::mem::size_of::() + std::mem::size_of::() + } +} + +impl TryFrom<&[u8]> for UdpgwHeader { + type Error = std::io::Error; + + fn try_from(value: &[u8]) -> std::result::Result { + if value.len() < UdpgwHeader::static_len() { + return Err(std::io::ErrorKind::InvalidData.into()); + } + let conn_id = u16::from_be_bytes([value[1], value[2]]); + Ok(UdpgwHeader { flags: value[0], conn_id }) + } +} + +impl From<&UdpgwHeader> for Vec { + fn from(header: &UdpgwHeader) -> Vec { + let mut bytes = vec![0; header.len()]; + bytes[0] = header.flags; + bytes[1..3].copy_from_slice(&header.conn_id.to_be_bytes()); + bytes + } +} + +#[allow(dead_code)] +#[derive(Debug)] +pub(crate) enum UdpGwResponse { + KeepAlive, + Error, + TcpClose, + Data(Packet), +} + +#[derive(Debug)] +pub(crate) struct UdpGwClientStream { + local_addr: String, + writer: Option, + reader: Option, + conn_id: u16, + closed: bool, + last_activity: std::time::Instant, +} + +impl Drop for UdpGwClientStream { + fn drop(&mut self) { + TCP_COUNTER.fetch_sub(1, Relaxed); + } +} + +impl UdpGwClientStream { + pub fn close(&mut self) { + self.closed = true; + } + + pub fn get_reader(&mut self) -> Option { + self.reader.take() + } + + pub fn set_reader(&mut self, reader: Option) { + self.reader = reader; + } + + pub fn set_writer(&mut self, writer: Option) { + self.writer = writer; + } + + pub fn get_writer(&mut self) -> Option { + self.writer.take() + } + + pub fn local_addr(&self) -> &String { + &self.local_addr + } + + pub fn update_activity(&mut self) { + self.last_activity = std::time::Instant::now(); + } + + pub fn is_closed(&mut self) -> bool { + self.closed + } + + pub fn id(&mut self) -> u16 { + self.conn_id + } + + pub fn new_id(&mut self) -> u16 { + self.conn_id += 1; + self.conn_id + } + + pub fn new(tcp_server_stream: TcpStream) -> Self { + let default = "0.0.0.0:0".parse::().unwrap(); + let local_addr = tcp_server_stream.local_addr().unwrap_or(default).to_string(); + let (rx, tx) = tcp_server_stream.into_split(); + let writer = tx; + let reader = rx; + TCP_COUNTER.fetch_add(1, Relaxed); + UdpGwClientStream { + local_addr, + reader: Some(reader), + writer: Some(writer), + last_activity: std::time::Instant::now(), + closed: false, + conn_id: 0, + } + } +} + +#[derive(Debug)] +pub(crate) struct UdpGwClient { + udp_mtu: u16, + max_connections: u16, + udp_timeout: u64, + keepalive_time: Duration, + server_addr: SocketAddr, + server_connections: Mutex>, +} + +impl UdpGwClient { + pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self { + let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize)); + UdpGwClient { + udp_mtu, + max_connections, + udp_timeout, + server_addr, + keepalive_time, + server_connections, + } + } + + pub(crate) fn get_udp_mtu(&self) -> u16 { + self.udp_mtu + } + + pub(crate) fn get_udp_timeout(&self) -> u64 { + self.udp_timeout + } + + pub(crate) fn is_full(&self) -> bool { + TCP_COUNTER.load(Relaxed) >= self.max_connections as u32 + } + + pub(crate) async fn get_server_connection(&self) -> Option { + self.server_connections.lock().await.pop_front() + } + + pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) { + if self.server_connections.lock().await.len() < self.max_connections as usize { + self.server_connections.lock().await.push_back(stream); + } + } + + pub(crate) async fn release_server_connection_full( + &self, + mut stream: UdpGwClientStream, + reader: OwnedReadHalf, + writer: OwnedWriteHalf, + ) { + if self.server_connections.lock().await.len() < self.max_connections as usize { + stream.set_reader(Some(reader)); + stream.set_writer(Some(writer)); + self.server_connections.lock().await.push_back(stream); + } + } + + pub(crate) fn get_server_addr(&self) -> SocketAddr { + self.server_addr + } + + /// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection. + pub(crate) async fn heartbeat_task(&self) { + loop { + sleep(self.keepalive_time).await; + if let Some(mut stream) = self.get_server_connection().await { + if stream.last_activity.elapsed() < self.keepalive_time { + self.release_server_connection(stream).await; + continue; + } + + let Some(mut stream_reader) = stream.get_reader() else { + continue; + }; + + let Some(mut stream_writer) = stream.get_writer() else { + continue; + }; + let local_addr = stream_writer.local_addr(); + log::debug!("{:?}:{} send keepalive", local_addr, stream.id()); + let keepalive_packet: Vec = Packet::build_keepalive_packet(stream.id()).into(); + if let Err(e) = stream_writer.write_all(&keepalive_packet).await { + log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e); + continue; + } + match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { + Ok(UdpGwResponse::KeepAlive) => { + stream.update_activity(); + self.release_server_connection_full(stream, stream_reader, stream_writer).await; + } + Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v), + Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e), + } + } + } + } + + /// Parses the UDP response data. + pub(crate) fn parse_udp_response(udp_mtu: u16, data: &[u8]) -> Result { + let packet = Packet::try_from(data)?; + let flags = packet.header.flags; + if flags & UDPGW_FLAG_ERR != 0 { + return Ok(UdpGwResponse::Error); + } + if flags & UDPGW_FLAG_KEEPALIVE != 0 { + return Ok(UdpGwResponse::KeepAlive); + } + if packet.data.len() > udp_mtu as usize { + return Err("too much data".into()); + } + Ok(UdpGwResponse::Data(packet)) + } + + /// Receives a UDP gateway packet. + /// + /// This function is responsible for receiving packets from the UDP gateway + /// + /// # Arguments + /// - `udp_mtu`: The maximum transmission unit size for UDP packets. + /// - `udp_timeout`: The timeout in seconds for receiving UDP packets. + /// - `stream`: A mutable reference to the UDP gateway client stream reader. + /// + /// # Returns + /// - `Result`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs. + pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut OwnedReadHalf) -> Result { + let mut data = vec![0; udp_mtu.into()]; + let data_len = tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout + 2), stream.read(&mut data)) + .await + .map_err(std::io::Error::from)??; + if data_len == 0 { + return Ok(UdpGwResponse::TcpClose); + } + UdpGwClient::parse_udp_response(udp_mtu, &data[..data_len]) + } + + /// Sends a UDP gateway packet. + /// + /// This function constructs and sends a UDP gateway packet based on the IPv6 enabled status, data length, + /// remote address, domain (if any), connection ID, and the UDP gateway client writer stream. + /// + /// # Arguments + /// + /// * `ipv6_enabled` - Whether IPv6 is enabled + /// * `data` - The data packet + /// * `remote_addr` - Remote address + /// * `conn_id` - Connection ID + /// * `stream` - UDP gateway client writer stream + /// + /// # Returns + /// + /// Returns `Ok(())` if the packet is sent successfully, otherwise returns an error. + pub(crate) async fn send_udpgw_packet( + ipv6_enabled: bool, + data: &[u8], + remote_addr: &socks5_impl::protocol::Address, + conn_id: u16, + stream: &mut OwnedWriteHalf, + ) -> Result<()> { + if !ipv6_enabled && remote_addr.get_type() == socks5_impl::protocol::AddressType::IPv6 { + return Err("ipv6 not support".into()); + } + let out_data: Vec = Packet::build_packet_from_address(conn_id, remote_addr, data)?.into(); + stream.write_all(&out_data).await?; + + Ok(()) + } +}