support udp gateway mode

This commit is contained in:
suchao 2024-10-19 13:40:35 +08:00
parent 87c2b666ab
commit 0a833d69a6
5 changed files with 309 additions and 202 deletions

View file

@ -70,14 +70,18 @@ pub struct Args {
#[arg(short, long, default_value = if cfg!(target_os = "linux") { "false" } else { "true" })]
pub setup: bool,
/// DNS handling strategy
#[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")]
pub dns: ArgDns,
/// UDP gateway address
#[arg(long, value_name = "IP:PORT")]
pub udpgw_bind_addr: Option<SocketAddr>,
/// Max udpgw connections
#[arg(long, value_name = "number", default_value = "100")]
pub max_udpgw_connections: u16,
/// DNS handling strategy
#[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")]
pub dns: ArgDns,
/// DNS resolver address
#[arg(long, value_name = "IP", default_value = "8.8.8.8")]
pub dns_addr: IpAddr,
@ -149,6 +153,7 @@ impl Default for Args {
ipv6_enabled: false,
setup,
udpgw_bind_addr: None,
max_udpgw_connections: 100,
dns: ArgDns::default(),
dns_addr: "8.8.8.8".parse().unwrap(),
bypass: vec![],

View file

@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::mem;
use std::net::Ipv4Addr;
use std::net::SocketAddr;
@ -11,25 +10,24 @@ use tokio::net::TcpListener;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio::sync::Mutex;
pub use tun2proxy::udpgw::*;
use tun2proxy::ArgVerbosity;
use tun2proxy::Result;
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60);
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(30);
#[derive(Debug)]
struct Connection {
struct UdpRequest {
flags: u8,
server_addr: SocketAddr,
conid: u16,
data: Vec<u8>,
}
#[derive(Debug)]
struct Client {
#[allow(dead_code)]
addr: SocketAddr,
buf: Vec<u8>,
connections: Arc<Mutex<HashMap<u16, Connection>>>,
last_activity: std::time::Instant,
}
@ -43,6 +41,10 @@ pub struct UdpGwArgs {
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
pub verbosity: ArgVerbosity,
/// Daemonize for unix family or run as Windows service
#[arg(long)]
pub daemonize: bool,
/// UDP timeout in seconds
#[arg(long, value_name = "seconds", default_value = "3")]
pub udp_timeout: u64,
@ -56,11 +58,10 @@ impl UdpGwArgs {
#[allow(clippy::let_and_return)]
pub fn parse_args() -> Self {
use clap::Parser;
let args = Self::parse();
args
Self::parse()
}
}
async fn send_error_response(tx: Sender<Vec<u8>>, con: &mut Connection) {
async fn send_error(tx: Sender<Vec<u8>>, con: &mut UdpRequest) {
let mut error_packet = vec![];
error_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
error_packet.extend_from_slice(&[UDPGW_FLAG_ERR]);
@ -70,7 +71,13 @@ async fn send_error_response(tx: Sender<Vec<u8>>, con: &mut Connection) {
}
}
pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
async fn send_keepalive_response(tx: Sender<Vec<u8>>, keepalive_packet: &[u8]) {
if let Err(e) = tx.send(keepalive_packet.to_vec()).await {
log::error!("send keepalive response error {:?}", e);
}
}
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
if data_len < mem::size_of::<UdpgwHeader>() {
return Err("Invalid udpgw data".into());
}
@ -85,10 +92,9 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
// keepalive
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
return Ok((data, UDPGW_FLAG_KEEPALIVE, 0, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
}
// parse address
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
// port_len + min(ipv4/ipv6/(domain_len + 1))
@ -107,7 +113,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
.to_socket_addrs()?
.next()
.ok_or(format!("Invalid address {}", target_str))?;
// check payload length
if data_len < 2 + domain.len() {
return Err("Invalid udpgw data".into());
}
@ -136,7 +141,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
};
data_len -= mem::size_of::<UdpgwAddrIpv6>();
// check payload length
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
@ -157,7 +161,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
};
data_len -= mem::size_of::<UdpgwAddrIpv4>();
// check payload length
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
@ -171,15 +174,16 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
}
}
async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, con: &mut Connection) -> Result<()> {
async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, con: &mut UdpRequest) -> Result<()> {
let std_sock = std::net::UdpSocket::bind("0.0.0.0:0")?;
std_sock.set_nonblocking(true)?;
nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?;
let socket = UdpSocket::from_std(std_sock)?;
socket.send_to(&con.data, &con.server_addr).await?;
con.data.resize(2048, 0);
match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data[..])).await? {
Ok((len, _addr)) => {
match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data)).await {
Ok(ret) => {
let (len, _addr) = ret?;
let mut packet = vec![];
let mut pack_len = mem::size_of::<UdpgwHeader>() + len;
match con.server_addr.into() {
@ -203,17 +207,17 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
}
}
if let Err(e) = tx.send(packet).await {
log::error!("client {} send udp response error {:?}", addr, e);
log::error!("client {} send udp response {}", addr, e);
}
}
Err(e) => {
log::error!("client {} udp recv_from error: {:?}", addr, e);
log::warn!("client {} udp recv_from {}", addr, e);
}
}
Ok(())
}
async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) {
async fn process_client_udp_req<'a>(args: &UdpGwArgs, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) {
let mut buf = vec![0; args.udp_mtu as usize];
let mut len_buf = [0; mem::size_of::<PackLenHeader>()];
let udp_mtu = args.udp_mtu;
@ -226,7 +230,7 @@ async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, m
}
Err(_e) => {
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
log::warn!("client {} last_activity elapsed", client.addr);
log::debug!("client {} last_activity elapsed", client.addr);
return;
}
continue;
@ -244,8 +248,7 @@ async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, m
log::error!("client {} received packet too long", client.addr);
break;
}
log::info!("client {} recvied packet len {}", client.addr, packet_len);
buf.resize(packet_len as usize, 0);
log::debug!("client {} recvied packet len {}", client.addr, packet_len);
client.buf.clear();
let mut left_len: usize = packet_len as usize;
while left_len > 0 {
@ -260,48 +263,37 @@ async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, m
}
}
client.last_activity = std::time::Instant::now();
let ret = parse_udp_req_data(udp_mtu, client.buf.len(), &client.buf);
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
log::debug!("client {} recvied keepalive packet", client.addr);
log::debug!("client {} send keepalive", client.addr);
send_keepalive_response(tx.clone(), udpdata).await;
continue;
}
log::debug!(
"client {} recvied udp data,flags:{},conid:{},addr:{:?},data len:{}",
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
client.addr,
flags,
conid,
reqaddr,
udpdata.len()
);
let mut con_lock = client.connections.lock().await;
let con = con_lock.get_mut(&conid);
if let Some(conn) = con {
conn.data.clear();
conn.data.extend_from_slice(udpdata);
if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), conn).await {
log::error!("client {} process_udp error: {:?}", client.addr, e);
send_error_response(tx.clone(), conn).await;
continue;
}
} else {
drop(con_lock);
let mut conn = Connection {
let mut req = UdpRequest {
server_addr: reqaddr,
conid,
flags,
data: udpdata.to_vec(),
};
if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), &mut conn).await {
send_error_response(tx.clone(), &mut conn).await;
log::error!("client {} process_udp error: {:?}", client.addr, e);
continue;
}
client.connections.lock().await.insert(conid, conn);
let tx1 = tx.clone();
let tx2 = tx.clone();
tokio::spawn(async move {
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
send_error(tx2, &mut req).await;
log::error!("client {} process_udp {}", client.addr, e);
}
});
} else {
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
continue;
}
}
Err(_) => {
@ -318,32 +310,52 @@ async fn main() -> Result<()> {
let tcp_listener = TcpListener::bind(args.listen_addr).await?;
log::info!("UDP GW Server started");
let default = format!("{:?},hickory_proto=warn", args.verbosity);
let default = format!("{:?}", args.verbosity);
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
log::info!("UDP GW Server started");
#[cfg(unix)]
if args.daemonize {
let stdout = std::fs::File::create("/tmp/udpgw.out")?;
let stderr = std::fs::File::create("/tmp/udpgw.err")?;
let daemonize = daemonize::Daemonize::new()
.working_directory("/tmp")
.umask(0o777)
.stdout(stdout)
.stderr(stderr)
.privileged_action(|| "Executed before drop privileges");
let _ = daemonize
.start()
.map_err(|e| format!("Failed to daemonize process, error:{:?}", e))?;
}
#[cfg(target_os = "windows")]
if args.daemonize {
tun2proxy::win_svc::start_service()?;
return Ok(());
}
loop {
let (mut tcp_stream, addr) = tcp_listener.accept().await?;
let client = Client {
addr,
buf: vec![],
connections: Arc::new(Mutex::new(HashMap::new())),
last_activity: std::time::Instant::now(),
};
log::info!("client {} connected", addr);
let params = args.clone();
let params = Arc::clone(&args);
tokio::spawn(async move {
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split();
tokio::select! {
_ = process_client_udp_req(params, tx, client, tcp_read_stream) =>{}
_ = process_client_udp_req(&params, tx, client, tcp_read_stream) =>{}
_ = async {
loop
{
if let Some(udp_response) = rx.recv().await {
log::info!("client {} send udp data len:{}", addr, udp_response.len(),);
log::debug!("send udp_response len {}",udp_response.len());
let _ = tcp_write_stream.write(&udp_response).await;
}
}

View file

@ -47,9 +47,6 @@ pub enum Error {
#[cfg(target_os = "linux")]
#[error("bincode::Error {0:?}")]
BincodeError(#[from] bincode::Error),
#[error("tokio::time::error::Elapsed")]
Timeout(#[from] tokio::time::error::Elapsed),
}
impl From<&str> for Error {

View file

@ -24,7 +24,7 @@ use tokio::{
pub use tokio_util::sync::CancellationToken;
use tproxy_config::is_private_ip;
use udp_stream::UdpStream;
use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS};
use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME};
pub use {
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
@ -238,7 +238,7 @@ where
None => None,
Some(addr) => {
log::info!("UDPGW enabled");
let client = Arc::new(UdpGwClient::new(mtu, UDPGW_MAX_CONNECTIONS, UDPGW_KEEPALIVE_TIME, addr));
let client = Arc::new(UdpGwClient::new(mtu, args.max_udpgw_connections, UDPGW_KEEPALIVE_TIME, args.udp_timeout, addr));
let client_keepalive = client.clone();
tokio::spawn(async move {
client_keepalive.heartbeat_task().await;
@ -485,6 +485,7 @@ async fn handle_udp_gateway_session(
};
let udpinfo = SessionInfo::new(udp_stack.local_addr(), udp_stack.peer_addr(), IpProtocol::Udp);
let udp_mtu = udpgw_client.get_udp_mtu();
let udp_timeout = udpgw_client.get_udp_timeout();
let mut server_stream: UdpGwClientStream;
let server = udpgw_client.get_server_connection().await;
match server {
@ -492,10 +493,12 @@ async fn handle_udp_gateway_session(
server_stream = server;
}
None => {
log::info!("Beginning {}", session_info);
if udpgw_client.is_full().await {
return Err("max udpgw connection limit reached".into());
}
let mut tcp_server_stream = create_tcp_stream(&socket_queue, server_addr).await?;
if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
return Err(e);
return Err(format!("udpgw connection error: {}",e).into());
}
server_stream = UdpGwClientStream::new(udp_mtu, tcp_server_stream);
}
@ -503,77 +506,93 @@ async fn handle_udp_gateway_session(
let udp_server_addr = udp_stack.peer_addr();
let tcp_local_addr = server_stream.local_addr().clone();
match domain_name {
Some(ref d) => {
log::info!("Beginning {}, domain:{}", udpinfo, d);
log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d);
}
None => {
log::info!("Beginning {}", udpinfo);
log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr);
}
}
log::info!("Beginning {}", udpinfo);
let Some(mut stream_reader) = server_stream.get_reader() else {
return Err("get reader failed".into());
};
let Some(mut stream_writer) = server_stream.get_writer() else {
return Err("get writer failed".into());
};
loop {
let len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut server_stream).await;
tokio::select! {
len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut stream_writer) => {
let read_len;
match len {
Ok(n) => {
if n == 0 {
log::info!("Ending {}", udpinfo);
log::info!("Ending {} <- {}",udpinfo, &tcp_local_addr);
break;
}
read_len = n;
crate::traffic_status::traffic_status_update(n, 0)?;
}
Err(e) => {
log::info!("Ending {} with recv_udp_packet error: {}", udpinfo, e);
log::info!("Ending {} <- {} with recv_udp_packet {}", udpinfo, &tcp_local_addr, e);
break;
}
}
let newid = server_stream.newid();
if let Err(e) =
UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), &mut server_stream).await
UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(),newid,&mut stream_writer).await
{
log::info!(
"{:?},Ending {} with send_udpgw_packet error: {}",
server_stream.local_addr(),
"Ending {} <- {} with send_udpgw_packet {}",
udpinfo,
&tcp_local_addr,
e
);
break;
}
match UdpGwClient::recv_udpgw_packet(udp_mtu, &mut server_stream).await {
server_stream.update_activity();
}
ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut stream_reader) => {
match ret {
Ok(packet) => match packet {
//should not received keepalive
UdpGwResponse::KeepAlive => {
log::error!("Ending {} with recv keepalive", udpinfo);
let _ = server_stream.close().await;
log::error!("Ending {} <- {} with recv keepalive", udpinfo, &tcp_local_addr);
server_stream.close();
break;
}
//server udp may be timeout,can continue to receive udp data?
UdpGwResponse::Error => {
log::info!("Ending {} with recv udp error", udpinfo);
log::info!("Ending {} <- {} with recv udp error", udpinfo, &tcp_local_addr);
server_stream.update_activity();
continue;
}
UdpGwResponse::Data(data) => {
crate::traffic_status::traffic_status_update(0, data.len())?;
let len = data.len();
if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await {
log::info!("Ending {} with send_udp_packet error: {}", udpinfo, e);
log::error!("Ending {} <- {} with send_udp_packet {}", udpinfo, &tcp_local_addr, e);
break;
}
crate::traffic_status::traffic_status_update(0, len)?;
}
},
Err(e) => {
log::info!("Ending {} with recv_udpgw_packet error: {}", udpinfo, e);
log::warn!("Ending {} <- {} with recv_udpgw_packet {}", udpinfo, &tcp_local_addr, e);
break;
}
}
server_stream.update_activity();
}
}
}
if !server_stream.is_closed() {
udpgw_client.release_server_connection(server_stream).await;
udpgw_client.release_server_connection_with_stream(server_stream,stream_reader,stream_writer).await;
}
Ok(())

View file

@ -4,9 +4,8 @@ use std::collections::VecDeque;
use std::hash::Hash;
use std::mem;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::{sleep, Duration};
@ -108,45 +107,56 @@ pub(crate) enum UdpGwResponse<'a> {
}
#[derive(Debug)]
pub(crate) struct UdpGwClientStream {
inner: TcpStream,
conid: u16,
pub(crate) struct UdpGwClientStreamWriter {
inner: OwnedWriteHalf,
tmp_buf: Vec<u8>,
send_buf: Vec<u8>,
}
#[derive(Debug)]
pub(crate) struct UdpGwClientStreamReader {
inner: OwnedReadHalf,
recv_buf: Vec<u8>,
}
#[derive(Debug)]
pub(crate) struct UdpGwClientStream {
local_addr: String,
writer: Option<UdpGwClientStreamWriter>,
reader: Option<UdpGwClientStreamReader>,
conid: u16,
closed: bool,
last_activity: std::time::Instant,
}
impl AsyncWrite for UdpGwClientStream {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<tokio::io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}
impl AsyncRead for UdpGwClientStream {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
impl UdpGwClientStream {
pub async fn close(&mut self) -> Result<()> {
self.inner.shutdown().await?;
pub fn close(&mut self) {
self.closed = true;
Ok(())
}
pub fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.inner.local_addr()?)
pub fn get_reader(&mut self) -> Option<UdpGwClientStreamReader> {
self.reader.take()
}
pub fn set_reader(&mut self, mut reader: Option<UdpGwClientStreamReader>) {
self.reader = reader.take();
}
pub fn set_writer(&mut self, mut writer: Option<UdpGwClientStreamWriter>) {
self.writer = writer.take();
}
pub fn get_writer(&mut self) -> Option<UdpGwClientStreamWriter> {
self.writer.take()
}
pub fn local_addr(&self) -> &String {
&self.local_addr
}
pub fn update_activity(&mut self) {
self.last_activity = std::time::Instant::now();
}
pub fn is_closed(&mut self) -> bool {
self.closed
}
@ -156,16 +166,28 @@ impl UdpGwClientStream {
}
pub fn newid(&mut self) -> u16 {
let next = self.conid;
self.conid += 1;
return next;
self.conid
}
pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self {
UdpGwClientStream {
inner: tcp_server_stream,
let local_addr = tcp_server_stream
.local_addr()
.unwrap_or_else(|_| "0.0.0.0:0".parse::<SocketAddr>().unwrap())
.to_string();
let (rx, tx) = tcp_server_stream.into_split();
let writer = UdpGwClientStreamWriter {
inner: tx,
tmp_buf: vec![0; udp_mtu.into()],
send_buf: vec![0; udp_mtu.into()],
};
let reader = UdpGwClientStreamReader {
inner: rx,
recv_buf: vec![0; udp_mtu.into()],
};
UdpGwClientStream {
local_addr,
reader: Some(reader),
writer: Some(writer),
last_activity: std::time::Instant::now(),
closed: false,
conid: 0,
@ -176,7 +198,8 @@ impl UdpGwClientStream {
#[derive(Debug)]
pub(crate) struct UdpGwClient {
udp_mtu: u16,
max_connections: usize,
max_connections: u16,
udp_timeout: u64,
keepalive_time: Duration,
udpgw_bind_addr: SocketAddr,
keepalive_packet: Vec<u8>,
@ -184,18 +207,19 @@ pub(crate) struct UdpGwClient {
}
impl UdpGwClient {
pub fn new(udp_mtu: u16, max_connections: usize, keepalive_time: Duration, udpgw_bind_addr: SocketAddr) -> Self {
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, udpgw_bind_addr: SocketAddr) -> Self {
let mut keepalive_packet = vec![];
keepalive_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
keepalive_packet.extend_from_slice(&[UDPGW_FLAG_KEEPALIVE, 0, 0]);
let server_connections = Mutex::new(VecDeque::new());
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
return UdpGwClient {
udp_mtu,
max_connections,
udp_timeout,
udpgw_bind_addr,
keepalive_time,
keepalive_packet,
server_connections: server_connections,
server_connections,
};
}
@ -203,12 +227,33 @@ impl UdpGwClient {
self.udp_mtu
}
pub(crate) fn get_udp_timeout(&self) -> u64 {
self.udp_timeout
}
pub(crate) async fn is_full(&self) -> bool {
self.server_connections.lock().await.len() >= self.max_connections as usize
}
pub(crate) async fn get_server_connection(&self) -> Option<UdpGwClientStream> {
self.server_connections.lock().await.pop_front()
}
pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) {
if self.server_connections.lock().await.len() < self.max_connections {
if self.server_connections.lock().await.len() < self.max_connections as usize {
self.server_connections.lock().await.push_back(stream);
}
}
pub(crate) async fn release_server_connection_with_stream(
&self,
mut stream: UdpGwClientStream,
reader: UdpGwClientStreamReader,
writer: UdpGwClientStreamWriter,
) {
if self.server_connections.lock().await.len() < self.max_connections as usize {
stream.set_reader(Some(reader));
stream.set_writer(Some(writer));
self.server_connections.lock().await.push_back(stream);
}
}
@ -217,6 +262,7 @@ impl UdpGwClient {
return self.udpgw_bind_addr;
}
/// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection.
pub(crate) async fn heartbeat_task(&self) {
loop {
sleep(self.keepalive_time).await;
@ -225,28 +271,35 @@ impl UdpGwClient {
self.release_server_connection(stream).await;
continue;
}
log::debug!("{:?}:{} send keepalive", stream.local_addr(), stream.id());
if let Err(e) = stream.write_all(&self.keepalive_packet).await {
let _ = stream.close().await;
log::warn!("{:?}:{} Heartbeat failed: {}", stream.local_addr(), stream.id(), e);
let Some(mut stream_reader) = stream.get_reader() else {
continue;
};
let Some(mut stream_writer) = stream.get_writer() else {
continue;
};
log::debug!("{:?}:{} send keepalive", stream_writer.inner.local_addr(), stream.id());
if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await {
log::warn!("{:?}:{} Heartbeat failed: {}", stream_writer.inner.local_addr(), stream.id(), e);
} else {
stream.last_activity = std::time::Instant::now();
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, &mut stream).await {
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
Ok(UdpGwResponse::KeepAlive) => {
self.release_server_connection(stream).await;
continue;
}
//shoud not receive other
_ => {
continue;
stream.last_activity = std::time::Instant::now();
self.release_server_connection_with_stream(stream, stream_reader, stream_writer)
.await;
}
//shoud not receive other type
_ => {}
}
}
}
}
}
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<UdpGwResponse> {
/// Parses the UDP response data.
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
let data = &stream.recv_buf;
if data_len < mem::size_of::<UdpgwHeader>() {
return Err("Invalid udpgw data".into());
}
@ -259,7 +312,6 @@ impl UdpGwClient {
let flags = header.flags;
let conid = header.conid;
// parse address
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
@ -267,7 +319,7 @@ impl UdpGwClient {
return Ok(UdpGwResponse::Error);
}
if flags & UDPGW_FLAG_ERR != 0 {
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
return Ok(UdpGwResponse::KeepAlive);
}
@ -281,7 +333,7 @@ impl UdpGwClient {
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
};
data_len -= mem::size_of::<UdpgwAddrIpv6>();
// check payload length
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
@ -302,7 +354,6 @@ impl UdpGwClient {
};
data_len -= mem::size_of::<UdpgwAddrIpv4>();
// check payload length
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
@ -317,7 +368,7 @@ impl UdpGwClient {
pub(crate) async fn recv_udp_packet(
udp_stack: &mut IpStackUdpStream,
stream: &mut UdpGwClientStream,
stream: &mut UdpGwClientStreamWriter,
) -> std::result::Result<usize, std::io::Error> {
return udp_stack.read(&mut stream.tmp_buf).await;
}
@ -329,22 +380,35 @@ impl UdpGwClient {
return udp_stack.write_all(&packet.udpdata).await;
}
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, stream: &mut UdpGwClientStream) -> Result<UdpGwResponse> {
stream.recv_buf.resize(2, 0);
/// Receives a UDP gateway packet.
///
/// This function is responsible for receiving packets from the UDP gateway
///
/// # Arguments
/// - `udp_mtu`: The maximum transmission unit size for UDP packets.
/// - `udp_timeout`: The timeout in seconds for receiving UDP packets.
/// - `stream`: A mutable reference to the UDP gateway client stream reader.
///
/// # Returns
/// - `Result<UdpGwResponse>`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs.
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
let result;
match tokio::time::timeout(tokio::time::Duration::from_secs(10), stream.inner.read(&mut stream.recv_buf)).await {
match tokio::time::timeout(
tokio::time::Duration::from_secs(udp_timeout + 2),
stream.inner.read(&mut stream.recv_buf[..2]),
)
.await
{
Ok(ret) => {
result = ret;
}
Err(_e) => {
let _ = stream.close().await;
return Err(format!("{:?} wait tcp data timeout", stream.local_addr()).into());
return Err(format!("wait tcp data timeout").into());
}
};
match result {
Ok(0) => {
let _ = stream.close().await;
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
return Err(format!("tcp connection closed").into());
}
Ok(n) => {
if n < std::mem::size_of::<PackLenHeader>() {
@ -354,41 +418,53 @@ impl UdpGwClient {
if packet_len > udp_mtu {
return Err("packet too long".into());
}
stream.recv_buf.resize(udp_mtu as usize, 0);
let mut left_len: usize = packet_len as usize;
let mut recv_len = 0;
while left_len > 0 {
if let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await {
if len == 0 {
let _ = stream.close().await;
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
return Err("tcp connection closed".into());
}
recv_len += len;
left_len -= len;
} else {
let _ = stream.close().await;
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
return Err("tcp connection closed".into());
}
}
stream.last_activity = std::time::Instant::now();
return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, &stream.recv_buf);
return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, stream);
}
Err(_) => {
let _ = stream.close().await;
return Err(format!("{:?} tcp read error", stream.local_addr()).into());
return Err("tcp read error".into());
}
}
}
/// Sends a UDP gateway packet.
///
/// This function constructs and sends a UDP gateway packet based on the IPv6 enabled status, data length,
/// remote address, domain (if any), connection ID, and the UDP gateway client writer stream.
///
/// # Arguments
///
/// * `ipv6_enabled` - Whether IPv6 is enabled
/// * `len` - Length of the data packet
/// * `remote_addr` - Remote address
/// * `domain` - Target domain (optional)
/// * `conid` - Connection ID
/// * `stream` - UDP gateway client writer stream
///
/// # Returns
///
/// Returns `Ok(())` if the packet is sent successfully, otherwise returns an error.
pub(crate) async fn send_udpgw_packet(
ipv6_enabled: bool,
len: usize,
remote_addr: SocketAddr,
domain: Option<&String>,
stream: &mut UdpGwClientStream,
conid: u16,
stream: &mut UdpGwClientStreamWriter,
) -> Result<()> {
stream.send_buf.clear();
let conid = stream.newid();
let data = &stream.tmp_buf;
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
let packet = &mut stream.send_buf;
@ -442,8 +518,6 @@ impl UdpGwClient {
stream.inner.write_all(&packet).await?;
stream.last_activity = std::time::Instant::now();
Ok(())
}
}