mirror of
https://github.com/tun2proxy/tun2proxy.git
synced 2025-04-20 13:59:10 +00:00
support udp gateway mode
This commit is contained in:
parent
87c2b666ab
commit
0a833d69a6
5 changed files with 309 additions and 202 deletions
13
src/args.rs
13
src/args.rs
|
@ -70,14 +70,18 @@ pub struct Args {
|
|||
#[arg(short, long, default_value = if cfg!(target_os = "linux") { "false" } else { "true" })]
|
||||
pub setup: bool,
|
||||
|
||||
/// DNS handling strategy
|
||||
#[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")]
|
||||
pub dns: ArgDns,
|
||||
|
||||
/// UDP gateway address
|
||||
#[arg(long, value_name = "IP:PORT")]
|
||||
pub udpgw_bind_addr: Option<SocketAddr>,
|
||||
|
||||
/// Max udpgw connections
|
||||
#[arg(long, value_name = "number", default_value = "100")]
|
||||
pub max_udpgw_connections: u16,
|
||||
|
||||
/// DNS handling strategy
|
||||
#[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")]
|
||||
pub dns: ArgDns,
|
||||
|
||||
/// DNS resolver address
|
||||
#[arg(long, value_name = "IP", default_value = "8.8.8.8")]
|
||||
pub dns_addr: IpAddr,
|
||||
|
@ -149,6 +153,7 @@ impl Default for Args {
|
|||
ipv6_enabled: false,
|
||||
setup,
|
||||
udpgw_bind_addr: None,
|
||||
max_udpgw_connections: 100,
|
||||
dns: ArgDns::default(),
|
||||
dns_addr: "8.8.8.8".parse().unwrap(),
|
||||
bypass: vec![],
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use std::collections::HashMap;
|
||||
use std::mem;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::net::SocketAddr;
|
||||
|
@ -11,25 +10,24 @@ use tokio::net::TcpListener;
|
|||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::Mutex;
|
||||
pub use tun2proxy::udpgw::*;
|
||||
use tun2proxy::ArgVerbosity;
|
||||
use tun2proxy::Result;
|
||||
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60);
|
||||
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(30);
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Connection {
|
||||
struct UdpRequest {
|
||||
flags: u8,
|
||||
server_addr: SocketAddr,
|
||||
conid: u16,
|
||||
data: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Client {
|
||||
#[allow(dead_code)]
|
||||
addr: SocketAddr,
|
||||
buf: Vec<u8>,
|
||||
connections: Arc<Mutex<HashMap<u16, Connection>>>,
|
||||
last_activity: std::time::Instant,
|
||||
}
|
||||
|
||||
|
@ -43,6 +41,10 @@ pub struct UdpGwArgs {
|
|||
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
|
||||
pub verbosity: ArgVerbosity,
|
||||
|
||||
/// Daemonize for unix family or run as Windows service
|
||||
#[arg(long)]
|
||||
pub daemonize: bool,
|
||||
|
||||
/// UDP timeout in seconds
|
||||
#[arg(long, value_name = "seconds", default_value = "3")]
|
||||
pub udp_timeout: u64,
|
||||
|
@ -56,11 +58,10 @@ impl UdpGwArgs {
|
|||
#[allow(clippy::let_and_return)]
|
||||
pub fn parse_args() -> Self {
|
||||
use clap::Parser;
|
||||
let args = Self::parse();
|
||||
args
|
||||
Self::parse()
|
||||
}
|
||||
}
|
||||
async fn send_error_response(tx: Sender<Vec<u8>>, con: &mut Connection) {
|
||||
async fn send_error(tx: Sender<Vec<u8>>, con: &mut UdpRequest) {
|
||||
let mut error_packet = vec![];
|
||||
error_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
|
||||
error_packet.extend_from_slice(&[UDPGW_FLAG_ERR]);
|
||||
|
@ -70,7 +71,13 @@ async fn send_error_response(tx: Sender<Vec<u8>>, con: &mut Connection) {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
|
||||
async fn send_keepalive_response(tx: Sender<Vec<u8>>, keepalive_packet: &[u8]) {
|
||||
if let Err(e) = tx.send(keepalive_packet.to_vec()).await {
|
||||
log::error!("send keepalive response error {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
|
||||
if data_len < mem::size_of::<UdpgwHeader>() {
|
||||
return Err("Invalid udpgw data".into());
|
||||
}
|
||||
|
@ -85,10 +92,9 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
|||
|
||||
// keepalive
|
||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||
return Ok((data, UDPGW_FLAG_KEEPALIVE, 0, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
|
||||
return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
|
||||
}
|
||||
|
||||
// parse address
|
||||
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
||||
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
||||
// port_len + min(ipv4/ipv6/(domain_len + 1))
|
||||
|
@ -107,7 +113,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
|||
.to_socket_addrs()?
|
||||
.next()
|
||||
.ok_or(format!("Invalid address {}", target_str))?;
|
||||
// check payload length
|
||||
if data_len < 2 + domain.len() {
|
||||
return Err("Invalid udpgw data".into());
|
||||
}
|
||||
|
@ -136,7 +141,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
|||
};
|
||||
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
||||
|
||||
// check payload length
|
||||
if data_len > udp_mtu as usize {
|
||||
return Err("too much data".into());
|
||||
}
|
||||
|
@ -157,7 +161,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
|||
};
|
||||
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
||||
|
||||
// check payload length
|
||||
if data_len > udp_mtu as usize {
|
||||
return Err("too much data".into());
|
||||
}
|
||||
|
@ -171,15 +174,16 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
|||
}
|
||||
}
|
||||
|
||||
async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, con: &mut Connection) -> Result<()> {
|
||||
async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, con: &mut UdpRequest) -> Result<()> {
|
||||
let std_sock = std::net::UdpSocket::bind("0.0.0.0:0")?;
|
||||
std_sock.set_nonblocking(true)?;
|
||||
nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?;
|
||||
let socket = UdpSocket::from_std(std_sock)?;
|
||||
socket.send_to(&con.data, &con.server_addr).await?;
|
||||
con.data.resize(2048, 0);
|
||||
match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data[..])).await? {
|
||||
Ok((len, _addr)) => {
|
||||
match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data)).await {
|
||||
Ok(ret) => {
|
||||
let (len, _addr) = ret?;
|
||||
let mut packet = vec![];
|
||||
let mut pack_len = mem::size_of::<UdpgwHeader>() + len;
|
||||
match con.server_addr.into() {
|
||||
|
@ -203,17 +207,17 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
|
|||
}
|
||||
}
|
||||
if let Err(e) = tx.send(packet).await {
|
||||
log::error!("client {} send udp response error {:?}", addr, e);
|
||||
log::error!("client {} send udp response {}", addr, e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("client {} udp recv_from error: {:?}", addr, e);
|
||||
log::warn!("client {} udp recv_from {}", addr, e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) {
|
||||
async fn process_client_udp_req<'a>(args: &UdpGwArgs, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) {
|
||||
let mut buf = vec![0; args.udp_mtu as usize];
|
||||
let mut len_buf = [0; mem::size_of::<PackLenHeader>()];
|
||||
let udp_mtu = args.udp_mtu;
|
||||
|
@ -226,7 +230,7 @@ async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, m
|
|||
}
|
||||
Err(_e) => {
|
||||
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
|
||||
log::warn!("client {} last_activity elapsed", client.addr);
|
||||
log::debug!("client {} last_activity elapsed", client.addr);
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
|
@ -244,8 +248,7 @@ async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, m
|
|||
log::error!("client {} received packet too long", client.addr);
|
||||
break;
|
||||
}
|
||||
log::info!("client {} recvied packet len {}", client.addr, packet_len);
|
||||
buf.resize(packet_len as usize, 0);
|
||||
log::debug!("client {} recvied packet len {}", client.addr, packet_len);
|
||||
client.buf.clear();
|
||||
let mut left_len: usize = packet_len as usize;
|
||||
while left_len > 0 {
|
||||
|
@ -260,48 +263,37 @@ async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, m
|
|||
}
|
||||
}
|
||||
client.last_activity = std::time::Instant::now();
|
||||
let ret = parse_udp_req_data(udp_mtu, client.buf.len(), &client.buf);
|
||||
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
|
||||
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
|
||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||
log::debug!("client {} recvied keepalive packet", client.addr);
|
||||
log::debug!("client {} send keepalive", client.addr);
|
||||
send_keepalive_response(tx.clone(), udpdata).await;
|
||||
continue;
|
||||
}
|
||||
log::debug!(
|
||||
"client {} recvied udp data,flags:{},conid:{},addr:{:?},data len:{}",
|
||||
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
|
||||
client.addr,
|
||||
flags,
|
||||
conid,
|
||||
reqaddr,
|
||||
udpdata.len()
|
||||
);
|
||||
let mut con_lock = client.connections.lock().await;
|
||||
let con = con_lock.get_mut(&conid);
|
||||
if let Some(conn) = con {
|
||||
conn.data.clear();
|
||||
conn.data.extend_from_slice(udpdata);
|
||||
if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), conn).await {
|
||||
log::error!("client {} process_udp error: {:?}", client.addr, e);
|
||||
send_error_response(tx.clone(), conn).await;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
drop(con_lock);
|
||||
let mut conn = Connection {
|
||||
let mut req = UdpRequest {
|
||||
server_addr: reqaddr,
|
||||
conid,
|
||||
flags,
|
||||
data: udpdata.to_vec(),
|
||||
};
|
||||
if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), &mut conn).await {
|
||||
send_error_response(tx.clone(), &mut conn).await;
|
||||
log::error!("client {} process_udp error: {:?}", client.addr, e);
|
||||
continue;
|
||||
}
|
||||
client.connections.lock().await.insert(conid, conn);
|
||||
let tx1 = tx.clone();
|
||||
let tx2 = tx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
|
||||
send_error(tx2, &mut req).await;
|
||||
log::error!("client {} process_udp {}", client.addr, e);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
|
@ -318,32 +310,52 @@ async fn main() -> Result<()> {
|
|||
|
||||
let tcp_listener = TcpListener::bind(args.listen_addr).await?;
|
||||
|
||||
log::info!("UDP GW Server started");
|
||||
|
||||
let default = format!("{:?},hickory_proto=warn", args.verbosity);
|
||||
let default = format!("{:?}", args.verbosity);
|
||||
|
||||
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
|
||||
|
||||
log::info!("UDP GW Server started");
|
||||
|
||||
#[cfg(unix)]
|
||||
if args.daemonize {
|
||||
let stdout = std::fs::File::create("/tmp/udpgw.out")?;
|
||||
let stderr = std::fs::File::create("/tmp/udpgw.err")?;
|
||||
let daemonize = daemonize::Daemonize::new()
|
||||
.working_directory("/tmp")
|
||||
.umask(0o777)
|
||||
.stdout(stdout)
|
||||
.stderr(stderr)
|
||||
.privileged_action(|| "Executed before drop privileges");
|
||||
let _ = daemonize
|
||||
.start()
|
||||
.map_err(|e| format!("Failed to daemonize process, error:{:?}", e))?;
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
if args.daemonize {
|
||||
tun2proxy::win_svc::start_service()?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
loop {
|
||||
let (mut tcp_stream, addr) = tcp_listener.accept().await?;
|
||||
let client = Client {
|
||||
addr,
|
||||
buf: vec![],
|
||||
connections: Arc::new(Mutex::new(HashMap::new())),
|
||||
last_activity: std::time::Instant::now(),
|
||||
};
|
||||
log::info!("client {} connected", addr);
|
||||
let params = args.clone();
|
||||
let params = Arc::clone(&args);
|
||||
tokio::spawn(async move {
|
||||
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
|
||||
let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split();
|
||||
tokio::select! {
|
||||
_ = process_client_udp_req(params, tx, client, tcp_read_stream) =>{}
|
||||
_ = process_client_udp_req(¶ms, tx, client, tcp_read_stream) =>{}
|
||||
_ = async {
|
||||
loop
|
||||
{
|
||||
if let Some(udp_response) = rx.recv().await {
|
||||
log::info!("client {} send udp data len:{}", addr, udp_response.len(),);
|
||||
log::debug!("send udp_response len {}",udp_response.len());
|
||||
let _ = tcp_write_stream.write(&udp_response).await;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,9 +47,6 @@ pub enum Error {
|
|||
#[cfg(target_os = "linux")]
|
||||
#[error("bincode::Error {0:?}")]
|
||||
BincodeError(#[from] bincode::Error),
|
||||
|
||||
#[error("tokio::time::error::Elapsed")]
|
||||
Timeout(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
impl From<&str> for Error {
|
||||
|
|
67
src/lib.rs
67
src/lib.rs
|
@ -24,7 +24,7 @@ use tokio::{
|
|||
pub use tokio_util::sync::CancellationToken;
|
||||
use tproxy_config::is_private_ip;
|
||||
use udp_stream::UdpStream;
|
||||
use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS};
|
||||
use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME};
|
||||
|
||||
pub use {
|
||||
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
|
||||
|
@ -238,7 +238,7 @@ where
|
|||
None => None,
|
||||
Some(addr) => {
|
||||
log::info!("UDPGW enabled");
|
||||
let client = Arc::new(UdpGwClient::new(mtu, UDPGW_MAX_CONNECTIONS, UDPGW_KEEPALIVE_TIME, addr));
|
||||
let client = Arc::new(UdpGwClient::new(mtu, args.max_udpgw_connections, UDPGW_KEEPALIVE_TIME, args.udp_timeout, addr));
|
||||
let client_keepalive = client.clone();
|
||||
tokio::spawn(async move {
|
||||
client_keepalive.heartbeat_task().await;
|
||||
|
@ -485,6 +485,7 @@ async fn handle_udp_gateway_session(
|
|||
};
|
||||
let udpinfo = SessionInfo::new(udp_stack.local_addr(), udp_stack.peer_addr(), IpProtocol::Udp);
|
||||
let udp_mtu = udpgw_client.get_udp_mtu();
|
||||
let udp_timeout = udpgw_client.get_udp_timeout();
|
||||
let mut server_stream: UdpGwClientStream;
|
||||
let server = udpgw_client.get_server_connection().await;
|
||||
match server {
|
||||
|
@ -492,10 +493,12 @@ async fn handle_udp_gateway_session(
|
|||
server_stream = server;
|
||||
}
|
||||
None => {
|
||||
log::info!("Beginning {}", session_info);
|
||||
if udpgw_client.is_full().await {
|
||||
return Err("max udpgw connection limit reached".into());
|
||||
}
|
||||
let mut tcp_server_stream = create_tcp_stream(&socket_queue, server_addr).await?;
|
||||
if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
|
||||
return Err(e);
|
||||
return Err(format!("udpgw connection error: {}",e).into());
|
||||
}
|
||||
server_stream = UdpGwClientStream::new(udp_mtu, tcp_server_stream);
|
||||
}
|
||||
|
@ -503,77 +506,93 @@ async fn handle_udp_gateway_session(
|
|||
|
||||
let udp_server_addr = udp_stack.peer_addr();
|
||||
|
||||
let tcp_local_addr = server_stream.local_addr().clone();
|
||||
|
||||
match domain_name {
|
||||
Some(ref d) => {
|
||||
log::info!("Beginning {}, domain:{}", udpinfo, d);
|
||||
log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d);
|
||||
}
|
||||
None => {
|
||||
log::info!("Beginning {}", udpinfo);
|
||||
log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr);
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Beginning {}", udpinfo);
|
||||
let Some(mut stream_reader) = server_stream.get_reader() else {
|
||||
return Err("get reader failed".into());
|
||||
};
|
||||
|
||||
let Some(mut stream_writer) = server_stream.get_writer() else {
|
||||
return Err("get writer failed".into());
|
||||
};
|
||||
|
||||
loop {
|
||||
let len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut server_stream).await;
|
||||
tokio::select! {
|
||||
len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut stream_writer) => {
|
||||
let read_len;
|
||||
match len {
|
||||
Ok(n) => {
|
||||
if n == 0 {
|
||||
log::info!("Ending {}", udpinfo);
|
||||
log::info!("Ending {} <- {}",udpinfo, &tcp_local_addr);
|
||||
break;
|
||||
}
|
||||
read_len = n;
|
||||
crate::traffic_status::traffic_status_update(n, 0)?;
|
||||
}
|
||||
Err(e) => {
|
||||
log::info!("Ending {} with recv_udp_packet error: {}", udpinfo, e);
|
||||
log::info!("Ending {} <- {} with recv_udp_packet {}", udpinfo, &tcp_local_addr, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let newid = server_stream.newid();
|
||||
if let Err(e) =
|
||||
UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), &mut server_stream).await
|
||||
UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(),newid,&mut stream_writer).await
|
||||
{
|
||||
log::info!(
|
||||
"{:?},Ending {} with send_udpgw_packet error: {}",
|
||||
server_stream.local_addr(),
|
||||
"Ending {} <- {} with send_udpgw_packet {}",
|
||||
udpinfo,
|
||||
&tcp_local_addr,
|
||||
e
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
match UdpGwClient::recv_udpgw_packet(udp_mtu, &mut server_stream).await {
|
||||
server_stream.update_activity();
|
||||
}
|
||||
ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut stream_reader) => {
|
||||
match ret {
|
||||
Ok(packet) => match packet {
|
||||
//should not received keepalive
|
||||
UdpGwResponse::KeepAlive => {
|
||||
log::error!("Ending {} with recv keepalive", udpinfo);
|
||||
let _ = server_stream.close().await;
|
||||
log::error!("Ending {} <- {} with recv keepalive", udpinfo, &tcp_local_addr);
|
||||
server_stream.close();
|
||||
break;
|
||||
}
|
||||
//server udp may be timeout,can continue to receive udp data?
|
||||
UdpGwResponse::Error => {
|
||||
log::info!("Ending {} with recv udp error", udpinfo);
|
||||
log::info!("Ending {} <- {} with recv udp error", udpinfo, &tcp_local_addr);
|
||||
server_stream.update_activity();
|
||||
continue;
|
||||
}
|
||||
UdpGwResponse::Data(data) => {
|
||||
crate::traffic_status::traffic_status_update(0, data.len())?;
|
||||
|
||||
let len = data.len();
|
||||
if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await {
|
||||
log::info!("Ending {} with send_udp_packet error: {}", udpinfo, e);
|
||||
log::error!("Ending {} <- {} with send_udp_packet {}", udpinfo, &tcp_local_addr, e);
|
||||
break;
|
||||
}
|
||||
crate::traffic_status::traffic_status_update(0, len)?;
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
log::info!("Ending {} with recv_udpgw_packet error: {}", udpinfo, e);
|
||||
log::warn!("Ending {} <- {} with recv_udpgw_packet {}", udpinfo, &tcp_local_addr, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
server_stream.update_activity();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !server_stream.is_closed() {
|
||||
udpgw_client.release_server_connection(server_stream).await;
|
||||
udpgw_client.release_server_connection_with_stream(server_stream,stream_reader,stream_writer).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
230
src/udpgw.rs
230
src/udpgw.rs
|
@ -4,9 +4,8 @@ use std::collections::VecDeque;
|
|||
use std::hash::Hash;
|
||||
use std::mem;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
@ -108,45 +107,56 @@ pub(crate) enum UdpGwResponse<'a> {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UdpGwClientStream {
|
||||
inner: TcpStream,
|
||||
conid: u16,
|
||||
pub(crate) struct UdpGwClientStreamWriter {
|
||||
inner: OwnedWriteHalf,
|
||||
tmp_buf: Vec<u8>,
|
||||
send_buf: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UdpGwClientStreamReader {
|
||||
inner: OwnedReadHalf,
|
||||
recv_buf: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UdpGwClientStream {
|
||||
local_addr: String,
|
||||
writer: Option<UdpGwClientStreamWriter>,
|
||||
reader: Option<UdpGwClientStreamReader>,
|
||||
conid: u16,
|
||||
closed: bool,
|
||||
last_activity: std::time::Instant,
|
||||
}
|
||||
|
||||
impl AsyncWrite for UdpGwClientStream {
|
||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<tokio::io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for UdpGwClientStream {
|
||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpGwClientStream {
|
||||
pub async fn close(&mut self) -> Result<()> {
|
||||
self.inner.shutdown().await?;
|
||||
pub fn close(&mut self) {
|
||||
self.closed = true;
|
||||
Ok(())
|
||||
}
|
||||
pub fn local_addr(&self) -> Result<SocketAddr> {
|
||||
Ok(self.inner.local_addr()?)
|
||||
pub fn get_reader(&mut self) -> Option<UdpGwClientStreamReader> {
|
||||
self.reader.take()
|
||||
}
|
||||
|
||||
pub fn set_reader(&mut self, mut reader: Option<UdpGwClientStreamReader>) {
|
||||
self.reader = reader.take();
|
||||
}
|
||||
|
||||
pub fn set_writer(&mut self, mut writer: Option<UdpGwClientStreamWriter>) {
|
||||
self.writer = writer.take();
|
||||
}
|
||||
|
||||
pub fn get_writer(&mut self) -> Option<UdpGwClientStreamWriter> {
|
||||
self.writer.take()
|
||||
}
|
||||
|
||||
pub fn local_addr(&self) -> &String {
|
||||
&self.local_addr
|
||||
}
|
||||
|
||||
pub fn update_activity(&mut self) {
|
||||
self.last_activity = std::time::Instant::now();
|
||||
}
|
||||
|
||||
pub fn is_closed(&mut self) -> bool {
|
||||
self.closed
|
||||
}
|
||||
|
@ -156,16 +166,28 @@ impl UdpGwClientStream {
|
|||
}
|
||||
|
||||
pub fn newid(&mut self) -> u16 {
|
||||
let next = self.conid;
|
||||
self.conid += 1;
|
||||
return next;
|
||||
self.conid
|
||||
}
|
||||
pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self {
|
||||
UdpGwClientStream {
|
||||
inner: tcp_server_stream,
|
||||
let local_addr = tcp_server_stream
|
||||
.local_addr()
|
||||
.unwrap_or_else(|_| "0.0.0.0:0".parse::<SocketAddr>().unwrap())
|
||||
.to_string();
|
||||
let (rx, tx) = tcp_server_stream.into_split();
|
||||
let writer = UdpGwClientStreamWriter {
|
||||
inner: tx,
|
||||
tmp_buf: vec![0; udp_mtu.into()],
|
||||
send_buf: vec![0; udp_mtu.into()],
|
||||
};
|
||||
let reader = UdpGwClientStreamReader {
|
||||
inner: rx,
|
||||
recv_buf: vec![0; udp_mtu.into()],
|
||||
};
|
||||
UdpGwClientStream {
|
||||
local_addr,
|
||||
reader: Some(reader),
|
||||
writer: Some(writer),
|
||||
last_activity: std::time::Instant::now(),
|
||||
closed: false,
|
||||
conid: 0,
|
||||
|
@ -176,7 +198,8 @@ impl UdpGwClientStream {
|
|||
#[derive(Debug)]
|
||||
pub(crate) struct UdpGwClient {
|
||||
udp_mtu: u16,
|
||||
max_connections: usize,
|
||||
max_connections: u16,
|
||||
udp_timeout: u64,
|
||||
keepalive_time: Duration,
|
||||
udpgw_bind_addr: SocketAddr,
|
||||
keepalive_packet: Vec<u8>,
|
||||
|
@ -184,18 +207,19 @@ pub(crate) struct UdpGwClient {
|
|||
}
|
||||
|
||||
impl UdpGwClient {
|
||||
pub fn new(udp_mtu: u16, max_connections: usize, keepalive_time: Duration, udpgw_bind_addr: SocketAddr) -> Self {
|
||||
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, udpgw_bind_addr: SocketAddr) -> Self {
|
||||
let mut keepalive_packet = vec![];
|
||||
keepalive_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
|
||||
keepalive_packet.extend_from_slice(&[UDPGW_FLAG_KEEPALIVE, 0, 0]);
|
||||
let server_connections = Mutex::new(VecDeque::new());
|
||||
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
|
||||
return UdpGwClient {
|
||||
udp_mtu,
|
||||
max_connections,
|
||||
udp_timeout,
|
||||
udpgw_bind_addr,
|
||||
keepalive_time,
|
||||
keepalive_packet,
|
||||
server_connections: server_connections,
|
||||
server_connections,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -203,12 +227,33 @@ impl UdpGwClient {
|
|||
self.udp_mtu
|
||||
}
|
||||
|
||||
pub(crate) fn get_udp_timeout(&self) -> u64 {
|
||||
self.udp_timeout
|
||||
}
|
||||
|
||||
pub(crate) async fn is_full(&self) -> bool {
|
||||
self.server_connections.lock().await.len() >= self.max_connections as usize
|
||||
}
|
||||
|
||||
pub(crate) async fn get_server_connection(&self) -> Option<UdpGwClientStream> {
|
||||
self.server_connections.lock().await.pop_front()
|
||||
}
|
||||
|
||||
pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) {
|
||||
if self.server_connections.lock().await.len() < self.max_connections {
|
||||
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
||||
self.server_connections.lock().await.push_back(stream);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn release_server_connection_with_stream(
|
||||
&self,
|
||||
mut stream: UdpGwClientStream,
|
||||
reader: UdpGwClientStreamReader,
|
||||
writer: UdpGwClientStreamWriter,
|
||||
) {
|
||||
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
||||
stream.set_reader(Some(reader));
|
||||
stream.set_writer(Some(writer));
|
||||
self.server_connections.lock().await.push_back(stream);
|
||||
}
|
||||
}
|
||||
|
@ -217,6 +262,7 @@ impl UdpGwClient {
|
|||
return self.udpgw_bind_addr;
|
||||
}
|
||||
|
||||
/// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection.
|
||||
pub(crate) async fn heartbeat_task(&self) {
|
||||
loop {
|
||||
sleep(self.keepalive_time).await;
|
||||
|
@ -225,28 +271,35 @@ impl UdpGwClient {
|
|||
self.release_server_connection(stream).await;
|
||||
continue;
|
||||
}
|
||||
log::debug!("{:?}:{} send keepalive", stream.local_addr(), stream.id());
|
||||
if let Err(e) = stream.write_all(&self.keepalive_packet).await {
|
||||
let _ = stream.close().await;
|
||||
log::warn!("{:?}:{} Heartbeat failed: {}", stream.local_addr(), stream.id(), e);
|
||||
|
||||
let Some(mut stream_reader) = stream.get_reader() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(mut stream_writer) = stream.get_writer() else {
|
||||
continue;
|
||||
};
|
||||
log::debug!("{:?}:{} send keepalive", stream_writer.inner.local_addr(), stream.id());
|
||||
if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await {
|
||||
log::warn!("{:?}:{} Heartbeat failed: {}", stream_writer.inner.local_addr(), stream.id(), e);
|
||||
} else {
|
||||
stream.last_activity = std::time::Instant::now();
|
||||
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, &mut stream).await {
|
||||
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
|
||||
Ok(UdpGwResponse::KeepAlive) => {
|
||||
self.release_server_connection(stream).await;
|
||||
continue;
|
||||
}
|
||||
//shoud not receive other
|
||||
_ => {
|
||||
continue;
|
||||
stream.last_activity = std::time::Instant::now();
|
||||
self.release_server_connection_with_stream(stream, stream_reader, stream_writer)
|
||||
.await;
|
||||
}
|
||||
//shoud not receive other type
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<UdpGwResponse> {
|
||||
/// Parses the UDP response data.
|
||||
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
|
||||
let data = &stream.recv_buf;
|
||||
if data_len < mem::size_of::<UdpgwHeader>() {
|
||||
return Err("Invalid udpgw data".into());
|
||||
}
|
||||
|
@ -259,7 +312,6 @@ impl UdpGwClient {
|
|||
let flags = header.flags;
|
||||
let conid = header.conid;
|
||||
|
||||
// parse address
|
||||
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
||||
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
||||
|
||||
|
@ -267,7 +319,7 @@ impl UdpGwClient {
|
|||
return Ok(UdpGwResponse::Error);
|
||||
}
|
||||
|
||||
if flags & UDPGW_FLAG_ERR != 0 {
|
||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||
return Ok(UdpGwResponse::KeepAlive);
|
||||
}
|
||||
|
||||
|
@ -281,7 +333,7 @@ impl UdpGwClient {
|
|||
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
|
||||
};
|
||||
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
||||
// check payload length
|
||||
|
||||
if data_len > udp_mtu as usize {
|
||||
return Err("too much data".into());
|
||||
}
|
||||
|
@ -302,7 +354,6 @@ impl UdpGwClient {
|
|||
};
|
||||
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
||||
|
||||
// check payload length
|
||||
if data_len > udp_mtu as usize {
|
||||
return Err("too much data".into());
|
||||
}
|
||||
|
@ -317,7 +368,7 @@ impl UdpGwClient {
|
|||
|
||||
pub(crate) async fn recv_udp_packet(
|
||||
udp_stack: &mut IpStackUdpStream,
|
||||
stream: &mut UdpGwClientStream,
|
||||
stream: &mut UdpGwClientStreamWriter,
|
||||
) -> std::result::Result<usize, std::io::Error> {
|
||||
return udp_stack.read(&mut stream.tmp_buf).await;
|
||||
}
|
||||
|
@ -329,22 +380,35 @@ impl UdpGwClient {
|
|||
return udp_stack.write_all(&packet.udpdata).await;
|
||||
}
|
||||
|
||||
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, stream: &mut UdpGwClientStream) -> Result<UdpGwResponse> {
|
||||
stream.recv_buf.resize(2, 0);
|
||||
/// Receives a UDP gateway packet.
|
||||
///
|
||||
/// This function is responsible for receiving packets from the UDP gateway
|
||||
///
|
||||
/// # Arguments
|
||||
/// - `udp_mtu`: The maximum transmission unit size for UDP packets.
|
||||
/// - `udp_timeout`: The timeout in seconds for receiving UDP packets.
|
||||
/// - `stream`: A mutable reference to the UDP gateway client stream reader.
|
||||
///
|
||||
/// # Returns
|
||||
/// - `Result<UdpGwResponse>`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs.
|
||||
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
|
||||
let result;
|
||||
match tokio::time::timeout(tokio::time::Duration::from_secs(10), stream.inner.read(&mut stream.recv_buf)).await {
|
||||
match tokio::time::timeout(
|
||||
tokio::time::Duration::from_secs(udp_timeout + 2),
|
||||
stream.inner.read(&mut stream.recv_buf[..2]),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(ret) => {
|
||||
result = ret;
|
||||
}
|
||||
Err(_e) => {
|
||||
let _ = stream.close().await;
|
||||
return Err(format!("{:?} wait tcp data timeout", stream.local_addr()).into());
|
||||
return Err(format!("wait tcp data timeout").into());
|
||||
}
|
||||
};
|
||||
match result {
|
||||
Ok(0) => {
|
||||
let _ = stream.close().await;
|
||||
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
|
||||
return Err(format!("tcp connection closed").into());
|
||||
}
|
||||
Ok(n) => {
|
||||
if n < std::mem::size_of::<PackLenHeader>() {
|
||||
|
@ -354,41 +418,53 @@ impl UdpGwClient {
|
|||
if packet_len > udp_mtu {
|
||||
return Err("packet too long".into());
|
||||
}
|
||||
stream.recv_buf.resize(udp_mtu as usize, 0);
|
||||
let mut left_len: usize = packet_len as usize;
|
||||
let mut recv_len = 0;
|
||||
while left_len > 0 {
|
||||
if let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await {
|
||||
if len == 0 {
|
||||
let _ = stream.close().await;
|
||||
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
|
||||
return Err("tcp connection closed".into());
|
||||
}
|
||||
recv_len += len;
|
||||
left_len -= len;
|
||||
} else {
|
||||
let _ = stream.close().await;
|
||||
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
|
||||
return Err("tcp connection closed".into());
|
||||
}
|
||||
}
|
||||
stream.last_activity = std::time::Instant::now();
|
||||
return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, &stream.recv_buf);
|
||||
return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, stream);
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = stream.close().await;
|
||||
return Err(format!("{:?} tcp read error", stream.local_addr()).into());
|
||||
return Err("tcp read error".into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a UDP gateway packet.
|
||||
///
|
||||
/// This function constructs and sends a UDP gateway packet based on the IPv6 enabled status, data length,
|
||||
/// remote address, domain (if any), connection ID, and the UDP gateway client writer stream.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `ipv6_enabled` - Whether IPv6 is enabled
|
||||
/// * `len` - Length of the data packet
|
||||
/// * `remote_addr` - Remote address
|
||||
/// * `domain` - Target domain (optional)
|
||||
/// * `conid` - Connection ID
|
||||
/// * `stream` - UDP gateway client writer stream
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns `Ok(())` if the packet is sent successfully, otherwise returns an error.
|
||||
pub(crate) async fn send_udpgw_packet(
|
||||
ipv6_enabled: bool,
|
||||
len: usize,
|
||||
remote_addr: SocketAddr,
|
||||
domain: Option<&String>,
|
||||
stream: &mut UdpGwClientStream,
|
||||
conid: u16,
|
||||
stream: &mut UdpGwClientStreamWriter,
|
||||
) -> Result<()> {
|
||||
stream.send_buf.clear();
|
||||
let conid = stream.newid();
|
||||
let data = &stream.tmp_buf;
|
||||
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
|
||||
let packet = &mut stream.send_buf;
|
||||
|
@ -442,8 +518,6 @@ impl UdpGwClient {
|
|||
|
||||
stream.inner.write_all(&packet).await?;
|
||||
|
||||
stream.last_activity = std::time::Instant::now();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue