From aee8e14a22b7eb070a22222b9f218efcd131c72e Mon Sep 17 00:00:00 2001 From: suchao Date: Thu, 17 Oct 2024 13:59:46 +0800 Subject: [PATCH] support udp gateway mode --- Cargo.toml | 4 + src/args.rs | 10 + src/bin/udpgw_server.rs | 355 +++++++++++++++++++++++++++++++ src/error.rs | 3 + src/lib.rs | 145 ++++++++++++- src/udpgw.rs | 449 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 965 insertions(+), 1 deletion(-) create mode 100644 src/bin/udpgw_server.rs create mode 100644 src/udpgw.rs diff --git a/Cargo.toml b/Cargo.toml index 0f9d112..a71366f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,5 +65,9 @@ serde_json = "1" name = "tun2proxy-bin" path = "src/bin/main.rs" +[[bin]] +name = "udpgwserver" +path = "src/bin/udpgw_server.rs" + [profile.release] strip = "symbols" diff --git a/src/args.rs b/src/args.rs index 6233350..16ed744 100644 --- a/src/args.rs +++ b/src/args.rs @@ -74,6 +74,10 @@ pub struct Args { #[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")] pub dns: ArgDns, + /// UDP gateway address + #[arg(long, value_name = "IP:PORT")] + pub udpgw_bind_addr: Option, + /// DNS resolver address #[arg(long, value_name = "IP", default_value = "8.8.8.8")] pub dns_addr: IpAddr, @@ -136,6 +140,7 @@ impl Default for Args { admin_command: Vec::new(), ipv6_enabled: false, setup, + udpgw_bind_addr: None, dns: ArgDns::default(), dns_addr: "8.8.8.8".parse().unwrap(), bypass: vec![], @@ -171,6 +176,11 @@ impl Args { self } + pub fn udpgw(&mut self, udpgw: SocketAddr) -> &mut Self { + self.udpgw_bind_addr = Some(udpgw); + self + } + #[cfg(unix)] pub fn tun_fd(&mut self, tun_fd: Option) -> &mut Self { self.tun_fd = tun_fd; diff --git a/src/bin/udpgw_server.rs b/src/bin/udpgw_server.rs new file mode 100644 index 0000000..45671ec --- /dev/null +++ b/src/bin/udpgw_server.rs @@ -0,0 +1,355 @@ +use std::collections::HashMap; +use std::mem; +use std::net::Ipv4Addr; +use std::net::SocketAddr; +use std::net::SocketAddrV4; +use std::net::ToSocketAddrs; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::tcp::ReadHalf; +use tokio::net::TcpListener; +use tokio::net::UdpSocket; +use tokio::sync::mpsc; +use tokio::sync::mpsc::Sender; +use tokio::sync::Mutex; +pub use tun2proxy::udpgw::*; +use tun2proxy::ArgVerbosity; +use tun2proxy::Result; +pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60); + +#[derive(Debug)] +struct Connection { + flags: u8, + server_addr: SocketAddr, + conid: u16, + data: Vec, +} + +struct Client { + #[allow(dead_code)] + addr: SocketAddr, + buf: Vec, + connections: Arc>>, + last_activity: std::time::Instant, +} + +#[derive(Debug, Clone, clap::Parser)] +pub struct UdpGwArgs { + /// UDP mtu + #[arg(long, value_name = "udp mtu", default_value = "10240")] + pub udp_mtu: u16, + + /// Verbosity level + #[arg(short, long, value_name = "level", value_enum, default_value = "info")] + pub verbosity: ArgVerbosity, + + /// UDP timeout in seconds + #[arg(long, value_name = "seconds", default_value = "3")] + pub udp_timeout: u64, + + /// UDP gateway listen address + #[arg(long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")] + pub listen_addr: SocketAddr, +} + +impl UdpGwArgs { + #[allow(clippy::let_and_return)] + pub fn parse_args() -> Self { + use clap::Parser; + let args = Self::parse(); + args + } +} +async fn send_error_response(tx: Sender>, con: &mut Connection) { + let mut error_packet = vec![]; + error_packet.extend_from_slice(&(std::mem::size_of::() as u16).to_le_bytes()); + error_packet.extend_from_slice(&[UDPGW_FLAG_ERR]); + error_packet.extend_from_slice(&con.conid.to_le_bytes()); + if let Err(e) = tx.send(error_packet).await { + log::error!("send error response error {:?}", e); + } +} + +pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> { + if data_len < mem::size_of::() { + return Err("Invalid udpgw data".into()); + } + let header_bytes = &data[..mem::size_of::()]; + let header = UdpgwHeader { + flags: header_bytes[0], + conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), + }; + + let flags = header.flags; + let conid = header.conid; + + // keepalive + if flags & UDPGW_FLAG_KEEPALIVE != 0 { + return Ok((data, UDPGW_FLAG_KEEPALIVE, 0, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); + } + + // parse address + let ip_data = &data[mem::size_of::()..]; + let mut data_len = data_len - mem::size_of::(); + // port_len + min(ipv4/ipv6/(domain_len + 1)) + if data_len < mem::size_of::() + 2 { + return Err("Invalid udpgw data".into()); + } + if flags & UDPGW_FLAG_DOMAIN != 0 { + let addr_port = u16::from_be_bytes([ip_data[0], ip_data[1]]); + data_len -= 2; + if let Some(end) = ip_data.iter().skip(2).position(|&x| x == 0) { + let domain_slice = &ip_data[2..end + 2]; + match std::str::from_utf8(domain_slice) { + Ok(domain) => { + let target_str = format!("{}:{}", domain, addr_port); + let target = target_str + .to_socket_addrs()? + .next() + .ok_or(format!("Invalid address {}", target_str))?; + // check payload length + if data_len < 2 + domain.len() { + return Err("Invalid udpgw data".into()); + } + data_len -= domain.len() + 1; + if data_len > udp_mtu as usize { + return Err("too much data".into()); + } + let udpdata = &ip_data[(2 + domain.len() + 1)..]; + return Ok((udpdata, flags, conid, target)); + } + Err(_) => { + return Err("Invalid UTF-8 sequence in domain".into()); + } + } + } else { + return Err("missing domain name".into()); + } + } else if flags & UDPGW_FLAG_IPV6 != 0 { + if data_len < mem::size_of::() { + return Err("Ipv6 Invalid UDP data".into()); + } + let addr_ipv6_bytes = &ip_data[..mem::size_of::()]; + let addr_ipv6 = UdpgwAddrIpv6 { + addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?, + addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]), + }; + data_len -= mem::size_of::(); + + // check payload length + if data_len > udp_mtu as usize { + return Err("too much data".into()); + } + return Ok(( + &ip_data[mem::size_of::()..(data_len + mem::size_of::())], + flags, + conid, + UdpgwAddr::IPV6(addr_ipv6).into(), + )); + } else { + if data_len < mem::size_of::() { + return Err("Ipv4 Invalid UDP data".into()); + } + let addr_ipv4_bytes = &ip_data[..mem::size_of::()]; + let addr_ipv4 = UdpgwAddrIpv4 { + addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]), + addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]), + }; + data_len -= mem::size_of::(); + + // check payload length + if data_len > udp_mtu as usize { + return Err("too much data".into()); + } + + return Ok(( + &ip_data[mem::size_of::()..(data_len + mem::size_of::())], + flags, + conid, + UdpgwAddr::IPV4(addr_ipv4).into(), + )); + } +} + +async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, con: &mut Connection) -> Result<()> { + let std_sock = std::net::UdpSocket::bind("0.0.0.0:0")?; + std_sock.set_nonblocking(true)?; + nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?; + let socket = UdpSocket::from_std(std_sock)?; + socket.send_to(&con.data, &con.server_addr).await?; + con.data.resize(2048, 0); + match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data[..])).await? { + Ok((len, _addr)) => { + let mut packet = vec![]; + let mut pack_len = mem::size_of::() + len; + match con.server_addr.into() { + UdpgwAddr::IPV4(addr_ipv4) => { + pack_len += mem::size_of::(); + packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); + packet.extend_from_slice(&[con.flags]); + packet.extend_from_slice(&con.conid.to_le_bytes()); + packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes()); + packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes()); + packet.extend_from_slice(&con.data[..len]); + } + UdpgwAddr::IPV6(addr_ipv6) => { + pack_len += mem::size_of::(); + packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); + packet.extend_from_slice(&[con.flags]); + packet.extend_from_slice(&con.conid.to_le_bytes()); + packet.extend_from_slice(&addr_ipv6.addr_ip); + packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes()); + packet.extend_from_slice(&con.data[..len]); + } + } + if let Err(e) = tx.send(packet).await { + log::error!("client {} send udp response error {:?}", addr, e); + } + } + Err(e) => { + log::error!("client {} udp recv_from error: {:?}", addr, e); + } + } + Ok(()) +} + +async fn process_client_udp_req<'a>(args: Arc, tx: Sender>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) { + let mut buf = vec![0; args.udp_mtu as usize]; + let mut len_buf = [0; mem::size_of::()]; + let udp_mtu = args.udp_mtu; + let udp_timeout = args.udp_timeout; + 'out: loop { + let result; + match tokio::time::timeout(tokio::time::Duration::from_secs(2), tcp_read_stream.read(&mut len_buf)).await { + Ok(ret) => { + result = ret; + } + Err(_e) => { + if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT { + log::warn!("client {} last_activity elapsed", client.addr); + return; + } + continue; + } + }; + match result { + Ok(0) => break, // Connection closed + Ok(n) => { + if n < mem::size_of::() { + log::error!("client {} received PackLenHeader error", client.addr); + break; + } + let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]); + if packet_len > udp_mtu { + log::error!("client {} received packet too long", client.addr); + break; + } + log::info!("client {} recvied packet len {}", client.addr, packet_len); + buf.resize(packet_len as usize, 0); + client.buf.clear(); + let mut left_len: usize = packet_len as usize; + while left_len > 0 { + if let Ok(len) = tcp_read_stream.read(&mut buf[..left_len]).await { + if len == 0 { + break 'out; + } + client.buf.extend_from_slice(&mut buf[..len]); + left_len -= len; + } else { + break 'out; + } + } + client.last_activity = std::time::Instant::now(); + let ret = parse_udp_req_data(udp_mtu, client.buf.len(), &client.buf); + if let Ok((udpdata, flags, conid, reqaddr)) = ret { + if flags & UDPGW_FLAG_KEEPALIVE != 0 { + log::debug!("client {} recvied keepalive packet", client.addr); + continue; + } + log::debug!( + "client {} recvied udp data,flags:{},conid:{},addr:{:?},data len:{}", + client.addr, + flags, + conid, + reqaddr, + udpdata.len() + ); + let mut con_lock = client.connections.lock().await; + let con = con_lock.get_mut(&conid); + if let Some(conn) = con { + conn.data.clear(); + conn.data.extend_from_slice(udpdata); + if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), conn).await { + log::error!("client {} process_udp error: {:?}", client.addr, e); + send_error_response(tx.clone(), conn).await; + continue; + } + } else { + drop(con_lock); + let mut conn = Connection { + server_addr: reqaddr, + conid, + flags, + data: udpdata.to_vec(), + }; + if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), &mut conn).await { + send_error_response(tx.clone(), &mut conn).await; + log::error!("client {} process_udp error: {:?}", client.addr, e); + continue; + } + client.connections.lock().await.insert(conid, conn); + } + } else { + log::error!("client {} parse_udp_data {:?}", client.addr, ret.err()); + continue; + } + } + Err(_) => { + log::error!("client {} tcp_read_stream error", client.addr); + break; + } + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Arc::new(UdpGwArgs::parse_args()); + + let tcp_listener = TcpListener::bind(args.listen_addr).await?; + + log::info!("UDP GW Server started"); + + let default = format!("{:?},hickory_proto=warn", args.verbosity); + + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init(); + + loop { + let (mut tcp_stream, addr) = tcp_listener.accept().await?; + let client = Client { + addr, + buf: vec![], + connections: Arc::new(Mutex::new(HashMap::new())), + last_activity: std::time::Instant::now(), + }; + log::info!("client {} connected", addr); + let params = args.clone(); + tokio::spawn(async move { + let (tx, mut rx) = mpsc::channel::>(100); + let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split(); + tokio::select! { + _ = process_client_udp_req(params, tx, client, tcp_read_stream) =>{} + _ = async { + loop + { + if let Some(udp_response) = rx.recv().await { + log::info!("client {} send udp data len:{}", addr, udp_response.len(),); + let _ = tcp_write_stream.write(&udp_response).await; + } + } + } => {} + } + log::info!("client {} disconnected", addr); + }); + } +} diff --git a/src/error.rs b/src/error.rs index f460b62..a26c74f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -47,6 +47,9 @@ pub enum Error { #[cfg(target_os = "linux")] #[error("bincode::Error {0:?}")] BincodeError(#[from] bincode::Error), + + #[error("tokio::time::error::Elapsed")] + Timeout(#[from] tokio::time::error::Elapsed), } impl From<&str> for Error { diff --git a/src/lib.rs b/src/lib.rs index 46cdec9..bdaf83c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use crate::{ http::HttpManager, no_proxy::NoProxyManager, session_info::{IpProtocol, SessionInfo}, + udpgw::UdpGwClient, virtual_dns::VirtualDns, }; use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream}; @@ -12,7 +13,7 @@ pub use socks5_impl::protocol::UserKey; use std::{ collections::VecDeque, io::ErrorKind, - net::{IpAddr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, sync::Arc, }; use tokio::{ @@ -23,6 +24,7 @@ use tokio::{ pub use tokio_util::sync::CancellationToken; use tproxy_config::is_private_ip; use udp_stream::UdpStream; +use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS}; pub use { args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType}, @@ -59,6 +61,7 @@ mod session_info; pub mod socket_transfer; mod socks; mod traffic_status; +pub mod udpgw; mod virtual_dns; #[doc(hidden)] pub mod win_svc; @@ -233,6 +236,19 @@ where let mut ip_stack = ipstack::IpStack::new(ipstack_config, device); + let udpgw_client = match args.udpgw_bind_addr { + None => None, + Some(addr) => { + log::info!("UDPGW enabled"); + let client = Arc::new(UdpGwClient::new(mtu, UDPGW_MAX_CONNECTIONS, UDPGW_KEEPALIVE_TIME, addr)); + let client_keepalive = client.clone(); + tokio::spawn(async move { + client_keepalive.heartbeat_task().await; + }); + Some(client) + } + }; + loop { let virtual_dns = virtual_dns.clone(); let ip_stack_stream = tokio::select! { @@ -265,6 +281,7 @@ where if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await { log::error!("{} error \"{}\"", info, err); } + log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); }); } @@ -311,6 +328,24 @@ where } else { None }; + if let Some(udpgw) = udpgw_client.clone() { + let tcp_src = match udp.peer_addr() { + SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), + SocketAddr::V6(_) => SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), 0, 0, 0)), + }; + let tcpinfo = SessionInfo::new(tcp_src, udpgw.get_udpgw_bind_addr(), IpProtocol::Tcp); + let proxy_handler = mgr.new_proxy_handler(tcpinfo, None, false).await?; + let socket_queue = socket_queue.clone(); + tokio::spawn(async move { + if let Err(err) = + handle_udp_gateway_session(udp, udpgw, domain_name, proxy_handler, socket_queue, ipv6_enabled).await + { + log::info!("Ending {} with \"{}\"", info, err); + } + log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); + }); + continue; + } match mgr.new_proxy_handler(info, domain_name, true).await { Ok(proxy_handler) => { let socket_queue = socket_queue.clone(); @@ -429,6 +464,114 @@ async fn handle_tcp_session( Ok(()) } +async fn handle_udp_gateway_session( + mut udp_stack: IpStackUdpStream, + udpgw_client: Arc, + domain_name: Option, + proxy_handler: Arc>, + socket_queue: Option>, + ipv6_enabled: bool, +) -> crate::Result<()> { + let (session_info, server_addr) = { + let handler = proxy_handler.lock().await; + (handler.get_session_info(), handler.get_server_addr()) + }; + let udpinfo = SessionInfo::new(udp_stack.local_addr(), udp_stack.peer_addr(), IpProtocol::Udp); + let udp_mtu = udpgw_client.get_udp_mtu(); + let mut server_stream: UdpGwClientStream; + let server = udpgw_client.get_server_connection().await; + match server { + Some(server) => { + server_stream = server; + } + None => { + log::info!("Beginning {}", session_info); + let mut tcp_server_stream = create_tcp_stream(&socket_queue, server_addr).await?; + if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await { + return Err(e); + } + server_stream = UdpGwClientStream::new(udp_mtu, tcp_server_stream); + } + }; + + let udp_server_addr = udp_stack.peer_addr(); + + match domain_name { + Some(ref d) => { + log::info!("Beginning {}, domain:{}", udpinfo, d); + } + None => { + log::info!("Beginning {}", udpinfo); + } + } + + log::info!("Beginning {}", udpinfo); + + loop { + let len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut server_stream).await; + let read_len; + match len { + Ok(n) => { + if n == 0 { + log::info!("Ending {}", udpinfo); + break; + } + read_len = n; + crate::traffic_status::traffic_status_update(n, 0)?; + } + Err(e) => { + log::info!("Ending {} with recv_udp_packet error: {}", udpinfo, e); + break; + } + } + + if let Err(e) = + UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), &mut server_stream).await + { + log::info!( + "{:?},Ending {} with send_udpgw_packet error: {}", + server_stream.local_addr(), + udpinfo, + e + ); + break; + } + + match UdpGwClient::recv_udpgw_packet(udp_mtu, &mut server_stream).await { + Ok(packet) => match packet { + //should not received keepalive + UdpGwResponse::KeepAlive => { + log::error!("Ending {} with recv keepalive", udpinfo); + let _ = server_stream.close().await; + break; + } + UdpGwResponse::Error => { + log::info!("Ending {} with recv udp error", udpinfo); + continue; + } + UdpGwResponse::Data(data) => { + crate::traffic_status::traffic_status_update(0, data.len())?; + + if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await { + log::info!("Ending {} with send_udp_packet error: {}", udpinfo, e); + break; + } + } + }, + Err(e) => { + log::info!("Ending {} with recv_udpgw_packet error: {}", udpinfo, e); + break; + } + } + } + + if !server_stream.is_closed() { + udpgw_client.release_server_connection(server_stream).await; + } + + Ok(()) +} + async fn handle_udp_associate_session( mut udp_stack: IpStackUdpStream, proxy_type: ProxyType, diff --git a/src/udpgw.rs b/src/udpgw.rs new file mode 100644 index 0000000..3b12f2f --- /dev/null +++ b/src/udpgw.rs @@ -0,0 +1,449 @@ +use crate::error::Result; +use ipstack::stream::IpStackUdpStream; +use std::collections::VecDeque; +use std::hash::Hash; +use std::mem; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tokio::time::{sleep, Duration}; + +pub const UDPGW_MAX_CONNECTIONS: usize = 100; +pub const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10); +pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01; +pub const UDPGW_FLAG_IPV6: u8 = 0x08; +pub const UDPGW_FLAG_DOMAIN: u8 = 0x10; +pub const UDPGW_FLAG_ERR: u8 = 0x20; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[repr(C)] +#[repr(packed(1))] +pub struct PackLenHeader { + packet_len: u16, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[repr(C)] +#[repr(packed(1))] +pub struct UdpgwHeader { + pub flags: u8, + pub conid: u16, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[repr(C)] +#[repr(packed(1))] +pub struct UdpgwAddrIpv4 { + pub addr_ip: u32, + pub addr_port: u16, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[repr(C)] +#[repr(packed(1))] +pub struct UdpgwAddrIpv6 { + pub addr_ip: [u8; 16], + pub addr_port: u16, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum UdpgwAddr { + IPV4(UdpgwAddrIpv4), + IPV6(UdpgwAddrIpv6), +} + +impl From for UdpgwAddr { + fn from(addr: SocketAddr) -> Self { + match addr { + SocketAddr::V4(addr_v4) => { + let ipv4_addr = addr_v4.ip().octets(); + let addr_ip = u32::from_be_bytes(ipv4_addr); + UdpgwAddr::IPV4(UdpgwAddrIpv4 { + addr_ip, + addr_port: addr_v4.port(), + }) + } + SocketAddr::V6(addr_v6) => { + let ipv6_addr = addr_v6.ip().octets(); + UdpgwAddr::IPV6(UdpgwAddrIpv6 { + addr_ip: ipv6_addr, + addr_port: addr_v6.port(), + }) + } + } + } +} + +impl From for SocketAddr { + fn from(addr: UdpgwAddr) -> Self { + match addr { + UdpgwAddr::IPV4(addr_ipv4) => SocketAddrV4::new(Ipv4Addr::from(addr_ipv4.addr_ip), addr_ipv4.addr_port).into(), + UdpgwAddr::IPV6(addr_ipv6) => SocketAddrV6::new(Ipv6Addr::from(addr_ipv6.addr_ip), addr_ipv6.addr_port, 0, 0).into(), + } + } +} + +#[allow(dead_code)] +pub(crate) struct UdpGwData<'a> { + flags: u8, + conid: u16, + remote_addr: SocketAddr, + udpdata: &'a [u8], +} + +impl<'a> UdpGwData<'a> { + pub fn len(&self) -> usize { + return self.udpdata.len(); + } +} + +#[allow(dead_code)] +pub(crate) enum UdpGwResponse<'a> { + KeepAlive, + Error, + Data(UdpGwData<'a>), +} + +#[derive(Debug)] +pub(crate) struct UdpGwClientStream { + inner: TcpStream, + conid: u16, + tmp_buf: Vec, + send_buf: Vec, + recv_buf: Vec, + closed: bool, + last_activity: std::time::Instant, +} + +impl AsyncWrite for UdpGwClientStream { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +impl AsyncRead for UdpGwClientStream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl UdpGwClientStream { + pub async fn close(&mut self) -> Result<()> { + self.inner.shutdown().await?; + self.closed = true; + Ok(()) + } + pub fn local_addr(&self) -> Result { + Ok(self.inner.local_addr()?) + } + pub fn is_closed(&mut self) -> bool { + self.closed + } + + pub fn id(&mut self) -> u16 { + self.conid + } + + pub fn newid(&mut self) -> u16 { + let next = self.conid; + self.conid += 1; + return next; + } + pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self { + UdpGwClientStream { + inner: tcp_server_stream, + tmp_buf: vec![0; udp_mtu.into()], + send_buf: vec![0; udp_mtu.into()], + recv_buf: vec![0; udp_mtu.into()], + last_activity: std::time::Instant::now(), + closed: false, + conid: 0, + } + } +} + +#[derive(Debug)] +pub(crate) struct UdpGwClient { + udp_mtu: u16, + max_connections: usize, + keepalive_time: Duration, + udpgw_bind_addr: SocketAddr, + keepalive_packet: Vec, + server_connections: Mutex>, +} + +impl UdpGwClient { + pub fn new(udp_mtu: u16, max_connections: usize, keepalive_time: Duration, udpgw_bind_addr: SocketAddr) -> Self { + let mut keepalive_packet = vec![]; + keepalive_packet.extend_from_slice(&(std::mem::size_of::() as u16).to_le_bytes()); + keepalive_packet.extend_from_slice(&[UDPGW_FLAG_KEEPALIVE, 0, 0]); + let server_connections = Mutex::new(VecDeque::new()); + return UdpGwClient { + udp_mtu, + max_connections, + udpgw_bind_addr, + keepalive_time, + keepalive_packet, + server_connections: server_connections, + }; + } + + pub(crate) fn get_udp_mtu(&self) -> u16 { + self.udp_mtu + } + + pub(crate) async fn get_server_connection(&self) -> Option { + self.server_connections.lock().await.pop_front() + } + + pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) { + if self.server_connections.lock().await.len() < self.max_connections { + self.server_connections.lock().await.push_back(stream); + } + } + + pub(crate) fn get_udpgw_bind_addr(&self) -> SocketAddr { + return self.udpgw_bind_addr; + } + + pub(crate) async fn heartbeat_task(&self) { + loop { + sleep(self.keepalive_time).await; + if let Some(mut stream) = self.get_server_connection().await { + if stream.last_activity.elapsed() < self.keepalive_time { + self.release_server_connection(stream).await; + continue; + } + log::debug!("{:?}:{} send keepalive", stream.local_addr(), stream.id()); + if let Err(e) = stream.write_all(&self.keepalive_packet).await { + let _ = stream.close().await; + log::warn!("{:?}:{} Heartbeat failed: {}", stream.local_addr(), stream.id(), e); + } else { + stream.last_activity = std::time::Instant::now(); + match UdpGwClient::recv_udpgw_packet(self.udp_mtu, &mut stream).await { + Ok(UdpGwResponse::KeepAlive) => { + self.release_server_connection(stream).await; + continue; + } + //shoud not receive other + _ => { + continue; + } + } + } + } + } + } + + pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result { + if data_len < mem::size_of::() { + return Err("Invalid udpgw data".into()); + } + let header_bytes = &data[..mem::size_of::()]; + let header = UdpgwHeader { + flags: header_bytes[0], + conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), + }; + + let flags = header.flags; + let conid = header.conid; + + // parse address + let ip_data = &data[mem::size_of::()..]; + let mut data_len = data_len - mem::size_of::(); + + if flags & UDPGW_FLAG_ERR != 0 { + return Ok(UdpGwResponse::Error); + } + + if flags & UDPGW_FLAG_ERR != 0 { + return Ok(UdpGwResponse::KeepAlive); + } + + if flags & UDPGW_FLAG_IPV6 != 0 { + if data_len < mem::size_of::() { + return Err("ipv6 Invalid UDP data".into()); + } + let addr_ipv6_bytes = &ip_data[..mem::size_of::()]; + let addr_ipv6 = UdpgwAddrIpv6 { + addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?, + addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]), + }; + data_len -= mem::size_of::(); + // check payload length + if data_len > udp_mtu as usize { + return Err("too much data".into()); + } + return Ok(UdpGwResponse::Data(UdpGwData { + flags, + conid, + remote_addr: UdpgwAddr::IPV6(addr_ipv6).into(), + udpdata: &ip_data[mem::size_of::()..(data_len + mem::size_of::())], + })); + } else { + if data_len < mem::size_of::() { + return Err("ipv4 Invalid UDP data".into()); + } + let addr_ipv4_bytes = &ip_data[..mem::size_of::()]; + let addr_ipv4 = UdpgwAddrIpv4 { + addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]), + addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]), + }; + data_len -= mem::size_of::(); + + // check payload length + if data_len > udp_mtu as usize { + return Err("too much data".into()); + } + return Ok(UdpGwResponse::Data(UdpGwData { + flags, + conid, + remote_addr: UdpgwAddr::IPV4(addr_ipv4).into(), + udpdata: &ip_data[mem::size_of::()..(data_len + mem::size_of::())], + })); + } + } + + pub(crate) async fn recv_udp_packet( + udp_stack: &mut IpStackUdpStream, + stream: &mut UdpGwClientStream, + ) -> std::result::Result { + return udp_stack.read(&mut stream.tmp_buf).await; + } + + pub(crate) async fn send_udp_packet<'a>( + packet: UdpGwData<'a>, + udp_stack: &mut IpStackUdpStream, + ) -> std::result::Result<(), std::io::Error> { + return udp_stack.write_all(&packet.udpdata).await; + } + + pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, stream: &mut UdpGwClientStream) -> Result { + stream.recv_buf.resize(2, 0); + let result; + match tokio::time::timeout(tokio::time::Duration::from_secs(10), stream.inner.read(&mut stream.recv_buf)).await { + Ok(ret) => { + result = ret; + } + Err(_e) => { + let _ = stream.close().await; + return Err(format!("{:?} wait tcp data timeout", stream.local_addr()).into()); + } + }; + match result { + Ok(0) => { + let _ = stream.close().await; + return Err(format!("{:?} tcp connection closed", stream.local_addr()).into()); + } + Ok(n) => { + if n < std::mem::size_of::() { + return Err("received PackLenHeader error".into()); + } + let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]); + if packet_len > udp_mtu { + return Err("packet too long".into()); + } + stream.recv_buf.resize(udp_mtu as usize, 0); + let mut left_len: usize = packet_len as usize; + let mut recv_len = 0; + while left_len > 0 { + if let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await { + if len == 0 { + let _ = stream.close().await; + return Err(format!("{:?} tcp connection closed", stream.local_addr()).into()); + } + recv_len += len; + left_len -= len; + } else { + let _ = stream.close().await; + return Err(format!("{:?} tcp connection closed", stream.local_addr()).into()); + } + } + stream.last_activity = std::time::Instant::now(); + return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, &stream.recv_buf); + } + Err(_) => { + let _ = stream.close().await; + return Err(format!("{:?} tcp read error", stream.local_addr()).into()); + } + } + } + + pub(crate) async fn send_udpgw_packet( + ipv6_enabled: bool, + len: usize, + remote_addr: SocketAddr, + domain: Option<&String>, + stream: &mut UdpGwClientStream, + ) -> Result<()> { + stream.send_buf.clear(); + let conid = stream.newid(); + let data = &stream.tmp_buf; + let mut pack_len = std::mem::size_of::() + len; + let packet = &mut stream.send_buf; + let mut flags = 0; + match domain { + Some(domain) => { + let addr_port = match remote_addr.into() { + UdpgwAddr::IPV4(addr_ipv4) => addr_ipv4.addr_port, + UdpgwAddr::IPV6(addr_ipv6) => addr_ipv6.addr_port, + }; + pack_len += std::mem::size_of::(); + let domain_len = domain.len(); + if domain_len > 255 { + return Err("InvalidDomain".into()); + } + pack_len += domain_len + 1; + flags = UDPGW_FLAG_DOMAIN; + packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); + packet.extend_from_slice(&[flags]); + packet.extend_from_slice(&conid.to_le_bytes()); + packet.extend_from_slice(&addr_port.to_be_bytes()); + packet.extend_from_slice(domain.as_bytes()); + packet.push(0); + packet.extend_from_slice(&data[..len]); + } + None => match remote_addr.into() { + UdpgwAddr::IPV4(addr_ipv4) => { + pack_len += std::mem::size_of::(); + packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); + packet.extend_from_slice(&[flags]); + packet.extend_from_slice(&conid.to_le_bytes()); + packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes()); + packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes()); + packet.extend_from_slice(&data[..len]); + } + UdpgwAddr::IPV6(addr_ipv6) => { + if !ipv6_enabled { + return Err("ipv6 not support".into()); + } + flags = UDPGW_FLAG_IPV6; + pack_len += std::mem::size_of::(); + packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); + packet.extend_from_slice(&[flags]); + packet.extend_from_slice(&conid.to_le_bytes()); + packet.extend_from_slice(&addr_ipv6.addr_ip); + packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes()); + packet.extend_from_slice(&data[..len]); + } + }, + } + + stream.inner.write_all(&packet).await?; + + stream.last_activity = std::time::Instant::now(); + + Ok(()) + } +}