mirror of
https://github.com/tun2proxy/tun2proxy.git
synced 2025-06-20 16:10:52 +00:00
support udp gateway mode
This commit is contained in:
parent
87c2b666ab
commit
0a833d69a6
5 changed files with 309 additions and 202 deletions
13
src/args.rs
13
src/args.rs
|
@ -70,14 +70,18 @@ pub struct Args {
|
||||||
#[arg(short, long, default_value = if cfg!(target_os = "linux") { "false" } else { "true" })]
|
#[arg(short, long, default_value = if cfg!(target_os = "linux") { "false" } else { "true" })]
|
||||||
pub setup: bool,
|
pub setup: bool,
|
||||||
|
|
||||||
/// DNS handling strategy
|
|
||||||
#[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")]
|
|
||||||
pub dns: ArgDns,
|
|
||||||
|
|
||||||
/// UDP gateway address
|
/// UDP gateway address
|
||||||
#[arg(long, value_name = "IP:PORT")]
|
#[arg(long, value_name = "IP:PORT")]
|
||||||
pub udpgw_bind_addr: Option<SocketAddr>,
|
pub udpgw_bind_addr: Option<SocketAddr>,
|
||||||
|
|
||||||
|
/// Max udpgw connections
|
||||||
|
#[arg(long, value_name = "number", default_value = "100")]
|
||||||
|
pub max_udpgw_connections: u16,
|
||||||
|
|
||||||
|
/// DNS handling strategy
|
||||||
|
#[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")]
|
||||||
|
pub dns: ArgDns,
|
||||||
|
|
||||||
/// DNS resolver address
|
/// DNS resolver address
|
||||||
#[arg(long, value_name = "IP", default_value = "8.8.8.8")]
|
#[arg(long, value_name = "IP", default_value = "8.8.8.8")]
|
||||||
pub dns_addr: IpAddr,
|
pub dns_addr: IpAddr,
|
||||||
|
@ -149,6 +153,7 @@ impl Default for Args {
|
||||||
ipv6_enabled: false,
|
ipv6_enabled: false,
|
||||||
setup,
|
setup,
|
||||||
udpgw_bind_addr: None,
|
udpgw_bind_addr: None,
|
||||||
|
max_udpgw_connections: 100,
|
||||||
dns: ArgDns::default(),
|
dns: ArgDns::default(),
|
||||||
dns_addr: "8.8.8.8".parse().unwrap(),
|
dns_addr: "8.8.8.8".parse().unwrap(),
|
||||||
bypass: vec![],
|
bypass: vec![],
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::mem;
|
use std::mem;
|
||||||
use std::net::Ipv4Addr;
|
use std::net::Ipv4Addr;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
@ -11,25 +10,24 @@ use tokio::net::TcpListener;
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::sync::mpsc::Sender;
|
use tokio::sync::mpsc::Sender;
|
||||||
use tokio::sync::Mutex;
|
|
||||||
pub use tun2proxy::udpgw::*;
|
pub use tun2proxy::udpgw::*;
|
||||||
use tun2proxy::ArgVerbosity;
|
use tun2proxy::ArgVerbosity;
|
||||||
use tun2proxy::Result;
|
use tun2proxy::Result;
|
||||||
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60);
|
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(30);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Connection {
|
struct UdpRequest {
|
||||||
flags: u8,
|
flags: u8,
|
||||||
server_addr: SocketAddr,
|
server_addr: SocketAddr,
|
||||||
conid: u16,
|
conid: u16,
|
||||||
data: Vec<u8>,
|
data: Vec<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
struct Client {
|
struct Client {
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
buf: Vec<u8>,
|
buf: Vec<u8>,
|
||||||
connections: Arc<Mutex<HashMap<u16, Connection>>>,
|
|
||||||
last_activity: std::time::Instant,
|
last_activity: std::time::Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,6 +41,10 @@ pub struct UdpGwArgs {
|
||||||
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
|
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
|
||||||
pub verbosity: ArgVerbosity,
|
pub verbosity: ArgVerbosity,
|
||||||
|
|
||||||
|
/// Daemonize for unix family or run as Windows service
|
||||||
|
#[arg(long)]
|
||||||
|
pub daemonize: bool,
|
||||||
|
|
||||||
/// UDP timeout in seconds
|
/// UDP timeout in seconds
|
||||||
#[arg(long, value_name = "seconds", default_value = "3")]
|
#[arg(long, value_name = "seconds", default_value = "3")]
|
||||||
pub udp_timeout: u64,
|
pub udp_timeout: u64,
|
||||||
|
@ -56,11 +58,10 @@ impl UdpGwArgs {
|
||||||
#[allow(clippy::let_and_return)]
|
#[allow(clippy::let_and_return)]
|
||||||
pub fn parse_args() -> Self {
|
pub fn parse_args() -> Self {
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
let args = Self::parse();
|
Self::parse()
|
||||||
args
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn send_error_response(tx: Sender<Vec<u8>>, con: &mut Connection) {
|
async fn send_error(tx: Sender<Vec<u8>>, con: &mut UdpRequest) {
|
||||||
let mut error_packet = vec![];
|
let mut error_packet = vec![];
|
||||||
error_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
|
error_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
|
||||||
error_packet.extend_from_slice(&[UDPGW_FLAG_ERR]);
|
error_packet.extend_from_slice(&[UDPGW_FLAG_ERR]);
|
||||||
|
@ -70,7 +71,13 @@ async fn send_error_response(tx: Sender<Vec<u8>>, con: &mut Connection) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
|
async fn send_keepalive_response(tx: Sender<Vec<u8>>, keepalive_packet: &[u8]) {
|
||||||
|
if let Err(e) = tx.send(keepalive_packet.to_vec()).await {
|
||||||
|
log::error!("send keepalive response error {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
|
||||||
if data_len < mem::size_of::<UdpgwHeader>() {
|
if data_len < mem::size_of::<UdpgwHeader>() {
|
||||||
return Err("Invalid udpgw data".into());
|
return Err("Invalid udpgw data".into());
|
||||||
}
|
}
|
||||||
|
@ -85,10 +92,9 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
||||||
|
|
||||||
// keepalive
|
// keepalive
|
||||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||||
return Ok((data, UDPGW_FLAG_KEEPALIVE, 0, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
|
return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse address
|
|
||||||
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
||||||
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
||||||
// port_len + min(ipv4/ipv6/(domain_len + 1))
|
// port_len + min(ipv4/ipv6/(domain_len + 1))
|
||||||
|
@ -107,7 +113,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
||||||
.to_socket_addrs()?
|
.to_socket_addrs()?
|
||||||
.next()
|
.next()
|
||||||
.ok_or(format!("Invalid address {}", target_str))?;
|
.ok_or(format!("Invalid address {}", target_str))?;
|
||||||
// check payload length
|
|
||||||
if data_len < 2 + domain.len() {
|
if data_len < 2 + domain.len() {
|
||||||
return Err("Invalid udpgw data".into());
|
return Err("Invalid udpgw data".into());
|
||||||
}
|
}
|
||||||
|
@ -136,7 +141,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
||||||
};
|
};
|
||||||
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
||||||
|
|
||||||
// check payload length
|
|
||||||
if data_len > udp_mtu as usize {
|
if data_len > udp_mtu as usize {
|
||||||
return Err("too much data".into());
|
return Err("too much data".into());
|
||||||
}
|
}
|
||||||
|
@ -157,7 +161,6 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
||||||
};
|
};
|
||||||
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
||||||
|
|
||||||
// check payload length
|
|
||||||
if data_len > udp_mtu as usize {
|
if data_len > udp_mtu as usize {
|
||||||
return Err("too much data".into());
|
return Err("too much data".into());
|
||||||
}
|
}
|
||||||
|
@ -171,15 +174,16 @@ pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, con: &mut Connection) -> Result<()> {
|
async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, con: &mut UdpRequest) -> Result<()> {
|
||||||
let std_sock = std::net::UdpSocket::bind("0.0.0.0:0")?;
|
let std_sock = std::net::UdpSocket::bind("0.0.0.0:0")?;
|
||||||
std_sock.set_nonblocking(true)?;
|
std_sock.set_nonblocking(true)?;
|
||||||
nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?;
|
nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?;
|
||||||
let socket = UdpSocket::from_std(std_sock)?;
|
let socket = UdpSocket::from_std(std_sock)?;
|
||||||
socket.send_to(&con.data, &con.server_addr).await?;
|
socket.send_to(&con.data, &con.server_addr).await?;
|
||||||
con.data.resize(2048, 0);
|
con.data.resize(2048, 0);
|
||||||
match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data[..])).await? {
|
match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data)).await {
|
||||||
Ok((len, _addr)) => {
|
Ok(ret) => {
|
||||||
|
let (len, _addr) = ret?;
|
||||||
let mut packet = vec![];
|
let mut packet = vec![];
|
||||||
let mut pack_len = mem::size_of::<UdpgwHeader>() + len;
|
let mut pack_len = mem::size_of::<UdpgwHeader>() + len;
|
||||||
match con.server_addr.into() {
|
match con.server_addr.into() {
|
||||||
|
@ -203,17 +207,17 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Err(e) = tx.send(packet).await {
|
if let Err(e) = tx.send(packet).await {
|
||||||
log::error!("client {} send udp response error {:?}", addr, e);
|
log::error!("client {} send udp response {}", addr, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::error!("client {} udp recv_from error: {:?}", addr, e);
|
log::warn!("client {} udp recv_from {}", addr, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) {
|
async fn process_client_udp_req<'a>(args: &UdpGwArgs, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) {
|
||||||
let mut buf = vec![0; args.udp_mtu as usize];
|
let mut buf = vec![0; args.udp_mtu as usize];
|
||||||
let mut len_buf = [0; mem::size_of::<PackLenHeader>()];
|
let mut len_buf = [0; mem::size_of::<PackLenHeader>()];
|
||||||
let udp_mtu = args.udp_mtu;
|
let udp_mtu = args.udp_mtu;
|
||||||
|
@ -226,7 +230,7 @@ async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, m
|
||||||
}
|
}
|
||||||
Err(_e) => {
|
Err(_e) => {
|
||||||
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
|
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
|
||||||
log::warn!("client {} last_activity elapsed", client.addr);
|
log::debug!("client {} last_activity elapsed", client.addr);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
|
@ -244,8 +248,7 @@ async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, m
|
||||||
log::error!("client {} received packet too long", client.addr);
|
log::error!("client {} received packet too long", client.addr);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
log::info!("client {} recvied packet len {}", client.addr, packet_len);
|
log::debug!("client {} recvied packet len {}", client.addr, packet_len);
|
||||||
buf.resize(packet_len as usize, 0);
|
|
||||||
client.buf.clear();
|
client.buf.clear();
|
||||||
let mut left_len: usize = packet_len as usize;
|
let mut left_len: usize = packet_len as usize;
|
||||||
while left_len > 0 {
|
while left_len > 0 {
|
||||||
|
@ -260,48 +263,37 @@ async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client.last_activity = std::time::Instant::now();
|
client.last_activity = std::time::Instant::now();
|
||||||
let ret = parse_udp_req_data(udp_mtu, client.buf.len(), &client.buf);
|
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
|
||||||
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
|
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
|
||||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||||
log::debug!("client {} recvied keepalive packet", client.addr);
|
log::debug!("client {} send keepalive", client.addr);
|
||||||
|
send_keepalive_response(tx.clone(), udpdata).await;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
log::debug!(
|
log::debug!(
|
||||||
"client {} recvied udp data,flags:{},conid:{},addr:{:?},data len:{}",
|
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
|
||||||
client.addr,
|
client.addr,
|
||||||
flags,
|
flags,
|
||||||
conid,
|
conid,
|
||||||
reqaddr,
|
reqaddr,
|
||||||
udpdata.len()
|
udpdata.len()
|
||||||
);
|
);
|
||||||
let mut con_lock = client.connections.lock().await;
|
let mut req = UdpRequest {
|
||||||
let con = con_lock.get_mut(&conid);
|
|
||||||
if let Some(conn) = con {
|
|
||||||
conn.data.clear();
|
|
||||||
conn.data.extend_from_slice(udpdata);
|
|
||||||
if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), conn).await {
|
|
||||||
log::error!("client {} process_udp error: {:?}", client.addr, e);
|
|
||||||
send_error_response(tx.clone(), conn).await;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
drop(con_lock);
|
|
||||||
let mut conn = Connection {
|
|
||||||
server_addr: reqaddr,
|
server_addr: reqaddr,
|
||||||
conid,
|
conid,
|
||||||
flags,
|
flags,
|
||||||
data: udpdata.to_vec(),
|
data: udpdata.to_vec(),
|
||||||
};
|
};
|
||||||
if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), &mut conn).await {
|
let tx1 = tx.clone();
|
||||||
send_error_response(tx.clone(), &mut conn).await;
|
let tx2 = tx.clone();
|
||||||
log::error!("client {} process_udp error: {:?}", client.addr, e);
|
tokio::spawn(async move {
|
||||||
continue;
|
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
|
||||||
}
|
send_error(tx2, &mut req).await;
|
||||||
client.connections.lock().await.insert(conid, conn);
|
log::error!("client {} process_udp {}", client.addr, e);
|
||||||
}
|
}
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
|
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
@ -318,32 +310,52 @@ async fn main() -> Result<()> {
|
||||||
|
|
||||||
let tcp_listener = TcpListener::bind(args.listen_addr).await?;
|
let tcp_listener = TcpListener::bind(args.listen_addr).await?;
|
||||||
|
|
||||||
log::info!("UDP GW Server started");
|
let default = format!("{:?}", args.verbosity);
|
||||||
|
|
||||||
let default = format!("{:?},hickory_proto=warn", args.verbosity);
|
|
||||||
|
|
||||||
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
|
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
|
||||||
|
|
||||||
|
log::info!("UDP GW Server started");
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
if args.daemonize {
|
||||||
|
let stdout = std::fs::File::create("/tmp/udpgw.out")?;
|
||||||
|
let stderr = std::fs::File::create("/tmp/udpgw.err")?;
|
||||||
|
let daemonize = daemonize::Daemonize::new()
|
||||||
|
.working_directory("/tmp")
|
||||||
|
.umask(0o777)
|
||||||
|
.stdout(stdout)
|
||||||
|
.stderr(stderr)
|
||||||
|
.privileged_action(|| "Executed before drop privileges");
|
||||||
|
let _ = daemonize
|
||||||
|
.start()
|
||||||
|
.map_err(|e| format!("Failed to daemonize process, error:{:?}", e))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
if args.daemonize {
|
||||||
|
tun2proxy::win_svc::start_service()?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let (mut tcp_stream, addr) = tcp_listener.accept().await?;
|
let (mut tcp_stream, addr) = tcp_listener.accept().await?;
|
||||||
let client = Client {
|
let client = Client {
|
||||||
addr,
|
addr,
|
||||||
buf: vec![],
|
buf: vec![],
|
||||||
connections: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
last_activity: std::time::Instant::now(),
|
last_activity: std::time::Instant::now(),
|
||||||
};
|
};
|
||||||
log::info!("client {} connected", addr);
|
log::info!("client {} connected", addr);
|
||||||
let params = args.clone();
|
let params = Arc::clone(&args);
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
|
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
|
||||||
let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split();
|
let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split();
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = process_client_udp_req(params, tx, client, tcp_read_stream) =>{}
|
_ = process_client_udp_req(¶ms, tx, client, tcp_read_stream) =>{}
|
||||||
_ = async {
|
_ = async {
|
||||||
loop
|
loop
|
||||||
{
|
{
|
||||||
if let Some(udp_response) = rx.recv().await {
|
if let Some(udp_response) = rx.recv().await {
|
||||||
log::info!("client {} send udp data len:{}", addr, udp_response.len(),);
|
log::debug!("send udp_response len {}",udp_response.len());
|
||||||
let _ = tcp_write_stream.write(&udp_response).await;
|
let _ = tcp_write_stream.write(&udp_response).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,9 +47,6 @@ pub enum Error {
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
#[error("bincode::Error {0:?}")]
|
#[error("bincode::Error {0:?}")]
|
||||||
BincodeError(#[from] bincode::Error),
|
BincodeError(#[from] bincode::Error),
|
||||||
|
|
||||||
#[error("tokio::time::error::Elapsed")]
|
|
||||||
Timeout(#[from] tokio::time::error::Elapsed),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<&str> for Error {
|
impl From<&str> for Error {
|
||||||
|
|
67
src/lib.rs
67
src/lib.rs
|
@ -24,7 +24,7 @@ use tokio::{
|
||||||
pub use tokio_util::sync::CancellationToken;
|
pub use tokio_util::sync::CancellationToken;
|
||||||
use tproxy_config::is_private_ip;
|
use tproxy_config::is_private_ip;
|
||||||
use udp_stream::UdpStream;
|
use udp_stream::UdpStream;
|
||||||
use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS};
|
use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME};
|
||||||
|
|
||||||
pub use {
|
pub use {
|
||||||
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
|
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
|
||||||
|
@ -238,7 +238,7 @@ where
|
||||||
None => None,
|
None => None,
|
||||||
Some(addr) => {
|
Some(addr) => {
|
||||||
log::info!("UDPGW enabled");
|
log::info!("UDPGW enabled");
|
||||||
let client = Arc::new(UdpGwClient::new(mtu, UDPGW_MAX_CONNECTIONS, UDPGW_KEEPALIVE_TIME, addr));
|
let client = Arc::new(UdpGwClient::new(mtu, args.max_udpgw_connections, UDPGW_KEEPALIVE_TIME, args.udp_timeout, addr));
|
||||||
let client_keepalive = client.clone();
|
let client_keepalive = client.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
client_keepalive.heartbeat_task().await;
|
client_keepalive.heartbeat_task().await;
|
||||||
|
@ -485,6 +485,7 @@ async fn handle_udp_gateway_session(
|
||||||
};
|
};
|
||||||
let udpinfo = SessionInfo::new(udp_stack.local_addr(), udp_stack.peer_addr(), IpProtocol::Udp);
|
let udpinfo = SessionInfo::new(udp_stack.local_addr(), udp_stack.peer_addr(), IpProtocol::Udp);
|
||||||
let udp_mtu = udpgw_client.get_udp_mtu();
|
let udp_mtu = udpgw_client.get_udp_mtu();
|
||||||
|
let udp_timeout = udpgw_client.get_udp_timeout();
|
||||||
let mut server_stream: UdpGwClientStream;
|
let mut server_stream: UdpGwClientStream;
|
||||||
let server = udpgw_client.get_server_connection().await;
|
let server = udpgw_client.get_server_connection().await;
|
||||||
match server {
|
match server {
|
||||||
|
@ -492,10 +493,12 @@ async fn handle_udp_gateway_session(
|
||||||
server_stream = server;
|
server_stream = server;
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
log::info!("Beginning {}", session_info);
|
if udpgw_client.is_full().await {
|
||||||
|
return Err("max udpgw connection limit reached".into());
|
||||||
|
}
|
||||||
let mut tcp_server_stream = create_tcp_stream(&socket_queue, server_addr).await?;
|
let mut tcp_server_stream = create_tcp_stream(&socket_queue, server_addr).await?;
|
||||||
if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
|
if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
|
||||||
return Err(e);
|
return Err(format!("udpgw connection error: {}",e).into());
|
||||||
}
|
}
|
||||||
server_stream = UdpGwClientStream::new(udp_mtu, tcp_server_stream);
|
server_stream = UdpGwClientStream::new(udp_mtu, tcp_server_stream);
|
||||||
}
|
}
|
||||||
|
@ -503,77 +506,93 @@ async fn handle_udp_gateway_session(
|
||||||
|
|
||||||
let udp_server_addr = udp_stack.peer_addr();
|
let udp_server_addr = udp_stack.peer_addr();
|
||||||
|
|
||||||
|
let tcp_local_addr = server_stream.local_addr().clone();
|
||||||
|
|
||||||
match domain_name {
|
match domain_name {
|
||||||
Some(ref d) => {
|
Some(ref d) => {
|
||||||
log::info!("Beginning {}, domain:{}", udpinfo, d);
|
log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d);
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
log::info!("Beginning {}", udpinfo);
|
log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log::info!("Beginning {}", udpinfo);
|
let Some(mut stream_reader) = server_stream.get_reader() else {
|
||||||
|
return Err("get reader failed".into());
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(mut stream_writer) = server_stream.get_writer() else {
|
||||||
|
return Err("get writer failed".into());
|
||||||
|
};
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut server_stream).await;
|
tokio::select! {
|
||||||
|
len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut stream_writer) => {
|
||||||
let read_len;
|
let read_len;
|
||||||
match len {
|
match len {
|
||||||
Ok(n) => {
|
Ok(n) => {
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
log::info!("Ending {}", udpinfo);
|
log::info!("Ending {} <- {}",udpinfo, &tcp_local_addr);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
read_len = n;
|
read_len = n;
|
||||||
crate::traffic_status::traffic_status_update(n, 0)?;
|
crate::traffic_status::traffic_status_update(n, 0)?;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::info!("Ending {} with recv_udp_packet error: {}", udpinfo, e);
|
log::info!("Ending {} <- {} with recv_udp_packet {}", udpinfo, &tcp_local_addr, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let newid = server_stream.newid();
|
||||||
if let Err(e) =
|
if let Err(e) =
|
||||||
UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), &mut server_stream).await
|
UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(),newid,&mut stream_writer).await
|
||||||
{
|
{
|
||||||
log::info!(
|
log::info!(
|
||||||
"{:?},Ending {} with send_udpgw_packet error: {}",
|
"Ending {} <- {} with send_udpgw_packet {}",
|
||||||
server_stream.local_addr(),
|
|
||||||
udpinfo,
|
udpinfo,
|
||||||
|
&tcp_local_addr,
|
||||||
e
|
e
|
||||||
);
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
server_stream.update_activity();
|
||||||
match UdpGwClient::recv_udpgw_packet(udp_mtu, &mut server_stream).await {
|
}
|
||||||
|
ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut stream_reader) => {
|
||||||
|
match ret {
|
||||||
Ok(packet) => match packet {
|
Ok(packet) => match packet {
|
||||||
//should not received keepalive
|
//should not received keepalive
|
||||||
UdpGwResponse::KeepAlive => {
|
UdpGwResponse::KeepAlive => {
|
||||||
log::error!("Ending {} with recv keepalive", udpinfo);
|
log::error!("Ending {} <- {} with recv keepalive", udpinfo, &tcp_local_addr);
|
||||||
let _ = server_stream.close().await;
|
server_stream.close();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
//server udp may be timeout,can continue to receive udp data?
|
||||||
UdpGwResponse::Error => {
|
UdpGwResponse::Error => {
|
||||||
log::info!("Ending {} with recv udp error", udpinfo);
|
log::info!("Ending {} <- {} with recv udp error", udpinfo, &tcp_local_addr);
|
||||||
|
server_stream.update_activity();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
UdpGwResponse::Data(data) => {
|
UdpGwResponse::Data(data) => {
|
||||||
crate::traffic_status::traffic_status_update(0, data.len())?;
|
let len = data.len();
|
||||||
|
|
||||||
if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await {
|
if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await {
|
||||||
log::info!("Ending {} with send_udp_packet error: {}", udpinfo, e);
|
log::error!("Ending {} <- {} with send_udp_packet {}", udpinfo, &tcp_local_addr, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
crate::traffic_status::traffic_status_update(0, len)?;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::info!("Ending {} with recv_udpgw_packet error: {}", udpinfo, e);
|
log::warn!("Ending {} <- {} with recv_udpgw_packet {}", udpinfo, &tcp_local_addr, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
server_stream.update_activity();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !server_stream.is_closed() {
|
if !server_stream.is_closed() {
|
||||||
udpgw_client.release_server_connection(server_stream).await;
|
udpgw_client.release_server_connection_with_stream(server_stream,stream_reader,stream_writer).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
230
src/udpgw.rs
230
src/udpgw.rs
|
@ -4,9 +4,8 @@ use std::collections::VecDeque;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||||
use std::pin::Pin;
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
use std::task::{Context, Poll};
|
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio::time::{sleep, Duration};
|
use tokio::time::{sleep, Duration};
|
||||||
|
@ -108,45 +107,56 @@ pub(crate) enum UdpGwResponse<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct UdpGwClientStream {
|
pub(crate) struct UdpGwClientStreamWriter {
|
||||||
inner: TcpStream,
|
inner: OwnedWriteHalf,
|
||||||
conid: u16,
|
|
||||||
tmp_buf: Vec<u8>,
|
tmp_buf: Vec<u8>,
|
||||||
send_buf: Vec<u8>,
|
send_buf: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct UdpGwClientStreamReader {
|
||||||
|
inner: OwnedReadHalf,
|
||||||
recv_buf: Vec<u8>,
|
recv_buf: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct UdpGwClientStream {
|
||||||
|
local_addr: String,
|
||||||
|
writer: Option<UdpGwClientStreamWriter>,
|
||||||
|
reader: Option<UdpGwClientStreamReader>,
|
||||||
|
conid: u16,
|
||||||
closed: bool,
|
closed: bool,
|
||||||
last_activity: std::time::Instant,
|
last_activity: std::time::Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for UdpGwClientStream {
|
|
||||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<tokio::io::Result<usize>> {
|
|
||||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
||||||
Pin::new(&mut self.inner).poll_flush(cx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
||||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncRead for UdpGwClientStream {
|
|
||||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
|
||||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UdpGwClientStream {
|
impl UdpGwClientStream {
|
||||||
pub async fn close(&mut self) -> Result<()> {
|
pub fn close(&mut self) {
|
||||||
self.inner.shutdown().await?;
|
|
||||||
self.closed = true;
|
self.closed = true;
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
pub fn local_addr(&self) -> Result<SocketAddr> {
|
pub fn get_reader(&mut self) -> Option<UdpGwClientStreamReader> {
|
||||||
Ok(self.inner.local_addr()?)
|
self.reader.take()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_reader(&mut self, mut reader: Option<UdpGwClientStreamReader>) {
|
||||||
|
self.reader = reader.take();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_writer(&mut self, mut writer: Option<UdpGwClientStreamWriter>) {
|
||||||
|
self.writer = writer.take();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_writer(&mut self) -> Option<UdpGwClientStreamWriter> {
|
||||||
|
self.writer.take()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn local_addr(&self) -> &String {
|
||||||
|
&self.local_addr
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_activity(&mut self) {
|
||||||
|
self.last_activity = std::time::Instant::now();
|
||||||
|
}
|
||||||
|
|
||||||
pub fn is_closed(&mut self) -> bool {
|
pub fn is_closed(&mut self) -> bool {
|
||||||
self.closed
|
self.closed
|
||||||
}
|
}
|
||||||
|
@ -156,16 +166,28 @@ impl UdpGwClientStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn newid(&mut self) -> u16 {
|
pub fn newid(&mut self) -> u16 {
|
||||||
let next = self.conid;
|
|
||||||
self.conid += 1;
|
self.conid += 1;
|
||||||
return next;
|
self.conid
|
||||||
}
|
}
|
||||||
pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self {
|
pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self {
|
||||||
UdpGwClientStream {
|
let local_addr = tcp_server_stream
|
||||||
inner: tcp_server_stream,
|
.local_addr()
|
||||||
|
.unwrap_or_else(|_| "0.0.0.0:0".parse::<SocketAddr>().unwrap())
|
||||||
|
.to_string();
|
||||||
|
let (rx, tx) = tcp_server_stream.into_split();
|
||||||
|
let writer = UdpGwClientStreamWriter {
|
||||||
|
inner: tx,
|
||||||
tmp_buf: vec![0; udp_mtu.into()],
|
tmp_buf: vec![0; udp_mtu.into()],
|
||||||
send_buf: vec![0; udp_mtu.into()],
|
send_buf: vec![0; udp_mtu.into()],
|
||||||
|
};
|
||||||
|
let reader = UdpGwClientStreamReader {
|
||||||
|
inner: rx,
|
||||||
recv_buf: vec![0; udp_mtu.into()],
|
recv_buf: vec![0; udp_mtu.into()],
|
||||||
|
};
|
||||||
|
UdpGwClientStream {
|
||||||
|
local_addr,
|
||||||
|
reader: Some(reader),
|
||||||
|
writer: Some(writer),
|
||||||
last_activity: std::time::Instant::now(),
|
last_activity: std::time::Instant::now(),
|
||||||
closed: false,
|
closed: false,
|
||||||
conid: 0,
|
conid: 0,
|
||||||
|
@ -176,7 +198,8 @@ impl UdpGwClientStream {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct UdpGwClient {
|
pub(crate) struct UdpGwClient {
|
||||||
udp_mtu: u16,
|
udp_mtu: u16,
|
||||||
max_connections: usize,
|
max_connections: u16,
|
||||||
|
udp_timeout: u64,
|
||||||
keepalive_time: Duration,
|
keepalive_time: Duration,
|
||||||
udpgw_bind_addr: SocketAddr,
|
udpgw_bind_addr: SocketAddr,
|
||||||
keepalive_packet: Vec<u8>,
|
keepalive_packet: Vec<u8>,
|
||||||
|
@ -184,18 +207,19 @@ pub(crate) struct UdpGwClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UdpGwClient {
|
impl UdpGwClient {
|
||||||
pub fn new(udp_mtu: u16, max_connections: usize, keepalive_time: Duration, udpgw_bind_addr: SocketAddr) -> Self {
|
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, udpgw_bind_addr: SocketAddr) -> Self {
|
||||||
let mut keepalive_packet = vec![];
|
let mut keepalive_packet = vec![];
|
||||||
keepalive_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
|
keepalive_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
|
||||||
keepalive_packet.extend_from_slice(&[UDPGW_FLAG_KEEPALIVE, 0, 0]);
|
keepalive_packet.extend_from_slice(&[UDPGW_FLAG_KEEPALIVE, 0, 0]);
|
||||||
let server_connections = Mutex::new(VecDeque::new());
|
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
|
||||||
return UdpGwClient {
|
return UdpGwClient {
|
||||||
udp_mtu,
|
udp_mtu,
|
||||||
max_connections,
|
max_connections,
|
||||||
|
udp_timeout,
|
||||||
udpgw_bind_addr,
|
udpgw_bind_addr,
|
||||||
keepalive_time,
|
keepalive_time,
|
||||||
keepalive_packet,
|
keepalive_packet,
|
||||||
server_connections: server_connections,
|
server_connections,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,12 +227,33 @@ impl UdpGwClient {
|
||||||
self.udp_mtu
|
self.udp_mtu
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_udp_timeout(&self) -> u64 {
|
||||||
|
self.udp_timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn is_full(&self) -> bool {
|
||||||
|
self.server_connections.lock().await.len() >= self.max_connections as usize
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) async fn get_server_connection(&self) -> Option<UdpGwClientStream> {
|
pub(crate) async fn get_server_connection(&self) -> Option<UdpGwClientStream> {
|
||||||
self.server_connections.lock().await.pop_front()
|
self.server_connections.lock().await.pop_front()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) {
|
pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) {
|
||||||
if self.server_connections.lock().await.len() < self.max_connections {
|
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
||||||
|
self.server_connections.lock().await.push_back(stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn release_server_connection_with_stream(
|
||||||
|
&self,
|
||||||
|
mut stream: UdpGwClientStream,
|
||||||
|
reader: UdpGwClientStreamReader,
|
||||||
|
writer: UdpGwClientStreamWriter,
|
||||||
|
) {
|
||||||
|
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
||||||
|
stream.set_reader(Some(reader));
|
||||||
|
stream.set_writer(Some(writer));
|
||||||
self.server_connections.lock().await.push_back(stream);
|
self.server_connections.lock().await.push_back(stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -217,6 +262,7 @@ impl UdpGwClient {
|
||||||
return self.udpgw_bind_addr;
|
return self.udpgw_bind_addr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection.
|
||||||
pub(crate) async fn heartbeat_task(&self) {
|
pub(crate) async fn heartbeat_task(&self) {
|
||||||
loop {
|
loop {
|
||||||
sleep(self.keepalive_time).await;
|
sleep(self.keepalive_time).await;
|
||||||
|
@ -225,28 +271,35 @@ impl UdpGwClient {
|
||||||
self.release_server_connection(stream).await;
|
self.release_server_connection(stream).await;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
log::debug!("{:?}:{} send keepalive", stream.local_addr(), stream.id());
|
|
||||||
if let Err(e) = stream.write_all(&self.keepalive_packet).await {
|
let Some(mut stream_reader) = stream.get_reader() else {
|
||||||
let _ = stream.close().await;
|
continue;
|
||||||
log::warn!("{:?}:{} Heartbeat failed: {}", stream.local_addr(), stream.id(), e);
|
};
|
||||||
|
|
||||||
|
let Some(mut stream_writer) = stream.get_writer() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
log::debug!("{:?}:{} send keepalive", stream_writer.inner.local_addr(), stream.id());
|
||||||
|
if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await {
|
||||||
|
log::warn!("{:?}:{} Heartbeat failed: {}", stream_writer.inner.local_addr(), stream.id(), e);
|
||||||
} else {
|
} else {
|
||||||
stream.last_activity = std::time::Instant::now();
|
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
|
||||||
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, &mut stream).await {
|
|
||||||
Ok(UdpGwResponse::KeepAlive) => {
|
Ok(UdpGwResponse::KeepAlive) => {
|
||||||
self.release_server_connection(stream).await;
|
stream.last_activity = std::time::Instant::now();
|
||||||
continue;
|
self.release_server_connection_with_stream(stream, stream_reader, stream_writer)
|
||||||
}
|
.await;
|
||||||
//shoud not receive other
|
|
||||||
_ => {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
//shoud not receive other type
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<UdpGwResponse> {
|
/// Parses the UDP response data.
|
||||||
|
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
|
||||||
|
let data = &stream.recv_buf;
|
||||||
if data_len < mem::size_of::<UdpgwHeader>() {
|
if data_len < mem::size_of::<UdpgwHeader>() {
|
||||||
return Err("Invalid udpgw data".into());
|
return Err("Invalid udpgw data".into());
|
||||||
}
|
}
|
||||||
|
@ -259,7 +312,6 @@ impl UdpGwClient {
|
||||||
let flags = header.flags;
|
let flags = header.flags;
|
||||||
let conid = header.conid;
|
let conid = header.conid;
|
||||||
|
|
||||||
// parse address
|
|
||||||
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
||||||
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
||||||
|
|
||||||
|
@ -267,7 +319,7 @@ impl UdpGwClient {
|
||||||
return Ok(UdpGwResponse::Error);
|
return Ok(UdpGwResponse::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
if flags & UDPGW_FLAG_ERR != 0 {
|
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||||
return Ok(UdpGwResponse::KeepAlive);
|
return Ok(UdpGwResponse::KeepAlive);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -281,7 +333,7 @@ impl UdpGwClient {
|
||||||
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
|
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
|
||||||
};
|
};
|
||||||
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
||||||
// check payload length
|
|
||||||
if data_len > udp_mtu as usize {
|
if data_len > udp_mtu as usize {
|
||||||
return Err("too much data".into());
|
return Err("too much data".into());
|
||||||
}
|
}
|
||||||
|
@ -302,7 +354,6 @@ impl UdpGwClient {
|
||||||
};
|
};
|
||||||
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
||||||
|
|
||||||
// check payload length
|
|
||||||
if data_len > udp_mtu as usize {
|
if data_len > udp_mtu as usize {
|
||||||
return Err("too much data".into());
|
return Err("too much data".into());
|
||||||
}
|
}
|
||||||
|
@ -317,7 +368,7 @@ impl UdpGwClient {
|
||||||
|
|
||||||
pub(crate) async fn recv_udp_packet(
|
pub(crate) async fn recv_udp_packet(
|
||||||
udp_stack: &mut IpStackUdpStream,
|
udp_stack: &mut IpStackUdpStream,
|
||||||
stream: &mut UdpGwClientStream,
|
stream: &mut UdpGwClientStreamWriter,
|
||||||
) -> std::result::Result<usize, std::io::Error> {
|
) -> std::result::Result<usize, std::io::Error> {
|
||||||
return udp_stack.read(&mut stream.tmp_buf).await;
|
return udp_stack.read(&mut stream.tmp_buf).await;
|
||||||
}
|
}
|
||||||
|
@ -329,22 +380,35 @@ impl UdpGwClient {
|
||||||
return udp_stack.write_all(&packet.udpdata).await;
|
return udp_stack.write_all(&packet.udpdata).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, stream: &mut UdpGwClientStream) -> Result<UdpGwResponse> {
|
/// Receives a UDP gateway packet.
|
||||||
stream.recv_buf.resize(2, 0);
|
///
|
||||||
|
/// This function is responsible for receiving packets from the UDP gateway
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// - `udp_mtu`: The maximum transmission unit size for UDP packets.
|
||||||
|
/// - `udp_timeout`: The timeout in seconds for receiving UDP packets.
|
||||||
|
/// - `stream`: A mutable reference to the UDP gateway client stream reader.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// - `Result<UdpGwResponse>`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs.
|
||||||
|
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
|
||||||
let result;
|
let result;
|
||||||
match tokio::time::timeout(tokio::time::Duration::from_secs(10), stream.inner.read(&mut stream.recv_buf)).await {
|
match tokio::time::timeout(
|
||||||
|
tokio::time::Duration::from_secs(udp_timeout + 2),
|
||||||
|
stream.inner.read(&mut stream.recv_buf[..2]),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(ret) => {
|
Ok(ret) => {
|
||||||
result = ret;
|
result = ret;
|
||||||
}
|
}
|
||||||
Err(_e) => {
|
Err(_e) => {
|
||||||
let _ = stream.close().await;
|
return Err(format!("wait tcp data timeout").into());
|
||||||
return Err(format!("{:?} wait tcp data timeout", stream.local_addr()).into());
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
match result {
|
match result {
|
||||||
Ok(0) => {
|
Ok(0) => {
|
||||||
let _ = stream.close().await;
|
return Err(format!("tcp connection closed").into());
|
||||||
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
|
|
||||||
}
|
}
|
||||||
Ok(n) => {
|
Ok(n) => {
|
||||||
if n < std::mem::size_of::<PackLenHeader>() {
|
if n < std::mem::size_of::<PackLenHeader>() {
|
||||||
|
@ -354,41 +418,53 @@ impl UdpGwClient {
|
||||||
if packet_len > udp_mtu {
|
if packet_len > udp_mtu {
|
||||||
return Err("packet too long".into());
|
return Err("packet too long".into());
|
||||||
}
|
}
|
||||||
stream.recv_buf.resize(udp_mtu as usize, 0);
|
|
||||||
let mut left_len: usize = packet_len as usize;
|
let mut left_len: usize = packet_len as usize;
|
||||||
let mut recv_len = 0;
|
let mut recv_len = 0;
|
||||||
while left_len > 0 {
|
while left_len > 0 {
|
||||||
if let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await {
|
if let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await {
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
let _ = stream.close().await;
|
return Err("tcp connection closed".into());
|
||||||
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
|
|
||||||
}
|
}
|
||||||
recv_len += len;
|
recv_len += len;
|
||||||
left_len -= len;
|
left_len -= len;
|
||||||
} else {
|
} else {
|
||||||
let _ = stream.close().await;
|
return Err("tcp connection closed".into());
|
||||||
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stream.last_activity = std::time::Instant::now();
|
return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, stream);
|
||||||
return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, &stream.recv_buf);
|
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
let _ = stream.close().await;
|
return Err("tcp read error".into());
|
||||||
return Err(format!("{:?} tcp read error", stream.local_addr()).into());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sends a UDP gateway packet.
|
||||||
|
///
|
||||||
|
/// This function constructs and sends a UDP gateway packet based on the IPv6 enabled status, data length,
|
||||||
|
/// remote address, domain (if any), connection ID, and the UDP gateway client writer stream.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `ipv6_enabled` - Whether IPv6 is enabled
|
||||||
|
/// * `len` - Length of the data packet
|
||||||
|
/// * `remote_addr` - Remote address
|
||||||
|
/// * `domain` - Target domain (optional)
|
||||||
|
/// * `conid` - Connection ID
|
||||||
|
/// * `stream` - UDP gateway client writer stream
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// Returns `Ok(())` if the packet is sent successfully, otherwise returns an error.
|
||||||
pub(crate) async fn send_udpgw_packet(
|
pub(crate) async fn send_udpgw_packet(
|
||||||
ipv6_enabled: bool,
|
ipv6_enabled: bool,
|
||||||
len: usize,
|
len: usize,
|
||||||
remote_addr: SocketAddr,
|
remote_addr: SocketAddr,
|
||||||
domain: Option<&String>,
|
domain: Option<&String>,
|
||||||
stream: &mut UdpGwClientStream,
|
conid: u16,
|
||||||
|
stream: &mut UdpGwClientStreamWriter,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
stream.send_buf.clear();
|
stream.send_buf.clear();
|
||||||
let conid = stream.newid();
|
|
||||||
let data = &stream.tmp_buf;
|
let data = &stream.tmp_buf;
|
||||||
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
|
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
|
||||||
let packet = &mut stream.send_buf;
|
let packet = &mut stream.send_buf;
|
||||||
|
@ -442,8 +518,6 @@ impl UdpGwClient {
|
||||||
|
|
||||||
stream.inner.write_all(&packet).await?;
|
stream.inner.write_all(&packet).await?;
|
||||||
|
|
||||||
stream.last_activity = std::time::Instant::now();
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue