mirror of
https://github.com/tun2proxy/tun2proxy.git
synced 2025-04-21 14:29:10 +00:00
refine udpgw
This commit is contained in:
parent
b4142453fd
commit
52d814ce79
3 changed files with 84 additions and 83 deletions
|
@ -116,10 +116,10 @@ pub struct Args {
|
||||||
#[arg(long, value_name = "IP:PORT")]
|
#[arg(long, value_name = "IP:PORT")]
|
||||||
pub udpgw_server: Option<SocketAddr>,
|
pub udpgw_server: Option<SocketAddr>,
|
||||||
|
|
||||||
/// Max udpgw connections, default value is 100
|
/// Max udpgw connections, default value is 5
|
||||||
#[cfg(feature = "udpgw")]
|
#[cfg(feature = "udpgw")]
|
||||||
#[arg(long, value_name = "number", requires = "udpgw_server")]
|
#[arg(long, value_name = "number", requires = "udpgw_server")]
|
||||||
pub udpgw_max_connections: Option<u16>,
|
pub udpgw_max_connections: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_tun(p: &str) -> Result<String> {
|
fn validate_tun(p: &str) -> Result<String> {
|
||||||
|
@ -201,7 +201,7 @@ impl Args {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "udpgw")]
|
#[cfg(feature = "udpgw")]
|
||||||
pub fn udpgw_max_connections(&mut self, udpgw_max_connections: u16) -> &mut Self {
|
pub fn udpgw_max_connections(&mut self, udpgw_max_connections: usize) -> &mut Self {
|
||||||
self.udpgw_max_connections = Some(udpgw_max_connections);
|
self.udpgw_max_connections = Some(udpgw_max_connections);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
11
src/lib.rs
11
src/lib.rs
|
@ -506,7 +506,7 @@ async fn handle_udp_gateway_session(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
if udpgw_client.is_full() {
|
if !udpgw_client.is_in_heartbeat_progress() && udpgw_client.is_full().await {
|
||||||
return Err("max udpgw connection limit reached".into());
|
return Err("max udpgw connection limit reached".into());
|
||||||
}
|
}
|
||||||
let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?;
|
let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?;
|
||||||
|
@ -543,13 +543,13 @@ async fn handle_udp_gateway_session(
|
||||||
}
|
}
|
||||||
Ok(n) => n,
|
Ok(n) => n,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::info!("[UdpGw] Ending stream {} {} <> {} with recv_udp_packet {}", sn, &tcp_local_addr, udp_dst, e);
|
log::info!("[UdpGw] Ending stream {} {} <> {} with udp stack \"{}\"", sn, &tcp_local_addr, udp_dst, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
crate::traffic_status::traffic_status_update(read_len, 0)?;
|
crate::traffic_status::traffic_status_update(read_len, 0)?;
|
||||||
let new_id = stream.new_packet_id();
|
let sn = stream.serial_number();
|
||||||
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], udp_dst, new_id, &mut writer).await {
|
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], udp_dst, sn, &mut writer).await {
|
||||||
log::info!("[UdpGw] Ending stream {} {} <> {} with send_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e);
|
log::info!("[UdpGw] Ending stream {} {} <> {} with send_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -584,7 +584,8 @@ async fn handle_udp_gateway_session(
|
||||||
UdpGwResponse::Data(data) => {
|
UdpGwResponse::Data(data) => {
|
||||||
use socks5_impl::protocol::StreamOperation;
|
use socks5_impl::protocol::StreamOperation;
|
||||||
let len = data.len();
|
let len = data.len();
|
||||||
log::debug!("[UdpGw] stream {} {} <- {} receive len {}", sn, &tcp_local_addr, udp_dst, len);
|
let f = data.header.flags;
|
||||||
|
log::debug!("[UdpGw] stream {sn} {} <- {} receive {f} len {len}", &tcp_local_addr, udp_dst);
|
||||||
if let Err(e) = udp_stack.write_all(&data.data).await {
|
if let Err(e) = udp_stack.write_all(&data.data).await {
|
||||||
log::error!("[UdpGw] Ending stream {} {} <> {} with send_udp_packet {}", sn, &tcp_local_addr, udp_dst, e);
|
log::error!("[UdpGw] Ending stream {} {} <> {} with send_udp_packet {}", sn, &tcp_local_addr, udp_dst, e);
|
||||||
break;
|
break;
|
||||||
|
|
150
src/udpgw.rs
150
src/udpgw.rs
|
@ -12,7 +12,7 @@ use tokio::{
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::<u16>();
|
pub(crate) const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::<u16>();
|
||||||
pub(crate) const UDPGW_MAX_CONNECTIONS: u16 = 100;
|
pub(crate) const UDPGW_MAX_CONNECTIONS: usize = 5;
|
||||||
pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
|
pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
@ -34,7 +34,7 @@ impl std::fmt::Display for UdpFlag {
|
||||||
0x02 => "DATA",
|
0x02 => "DATA",
|
||||||
n => return write!(f, "Unknown UdpFlag(0x{:02X})", n),
|
n => return write!(f, "Unknown UdpFlag(0x{:02X})", n),
|
||||||
};
|
};
|
||||||
write!(f, "UdpFlag({})", flag)
|
write!(f, "{}", flag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,8 +52,6 @@ impl std::ops::BitOr for UdpFlag {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
|
|
||||||
|
|
||||||
/// UDP Gateway Packet Format
|
/// UDP Gateway Packet Format
|
||||||
///
|
///
|
||||||
/// The format is referenced from SOCKS5 packet format, with additional flags and connection ID fields.
|
/// The format is referenced from SOCKS5 packet format, with additional flags and connection ID fields.
|
||||||
|
@ -250,8 +248,7 @@ pub struct UdpgwHeader {
|
||||||
|
|
||||||
impl std::fmt::Display for UdpgwHeader {
|
impl std::fmt::Display for UdpgwHeader {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
let id = self.conn_id;
|
write!(f, "{} conn_id: {}", self.flags, self.conn_id)
|
||||||
write!(f, "flags: {}, conn_id: {}", self.flags, id)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -329,23 +326,27 @@ pub(crate) enum UdpGwResponse {
|
||||||
Data(Packet),
|
Data(Packet),
|
||||||
}
|
}
|
||||||
|
|
||||||
static SERIAL_NUMBER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
|
impl std::fmt::Display for UdpGwResponse {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
UdpGwResponse::KeepAlive => write!(f, "KeepAlive"),
|
||||||
|
UdpGwResponse::Error => write!(f, "Error"),
|
||||||
|
UdpGwResponse::TcpClose => write!(f, "TcpClose"),
|
||||||
|
UdpGwResponse::Data(packet) => write!(f, "Data({})", packet),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static SERIAL_NUMBER: std::sync::atomic::AtomicU16 = std::sync::atomic::AtomicU16::new(1);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct UdpGwClientStream {
|
pub(crate) struct UdpGwClientStream {
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
writer: Option<OwnedWriteHalf>,
|
writer: Option<OwnedWriteHalf>,
|
||||||
reader: Option<OwnedReadHalf>,
|
reader: Option<OwnedReadHalf>,
|
||||||
conn_id: u16,
|
|
||||||
closed: bool,
|
closed: bool,
|
||||||
last_activity: std::time::Instant,
|
last_activity: std::time::Instant,
|
||||||
serial_number: u64,
|
serial_number: u16,
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for UdpGwClientStream {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
TCP_COUNTER.fetch_sub(1, Relaxed);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UdpGwClientStream {
|
impl UdpGwClientStream {
|
||||||
|
@ -381,20 +382,14 @@ impl UdpGwClientStream {
|
||||||
self.closed
|
self.closed
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn serial_number(&self) -> u64 {
|
pub fn serial_number(&self) -> u16 {
|
||||||
self.serial_number
|
self.serial_number
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_packet_id(&mut self) -> u16 {
|
|
||||||
self.conn_id += 1;
|
|
||||||
self.conn_id
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new(tcp_server_stream: TcpStream) -> Self {
|
pub fn new(tcp_server_stream: TcpStream) -> Self {
|
||||||
let default = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
|
let default = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
|
||||||
let local_addr = tcp_server_stream.local_addr().unwrap_or(default);
|
let local_addr = tcp_server_stream.local_addr().unwrap_or(default);
|
||||||
let (reader, writer) = tcp_server_stream.into_split();
|
let (reader, writer) = tcp_server_stream.into_split();
|
||||||
TCP_COUNTER.fetch_add(1, Relaxed);
|
|
||||||
let serial_number = SERIAL_NUMBER.fetch_add(1, Relaxed);
|
let serial_number = SERIAL_NUMBER.fetch_add(1, Relaxed);
|
||||||
UdpGwClientStream {
|
UdpGwClientStream {
|
||||||
local_addr,
|
local_addr,
|
||||||
|
@ -402,7 +397,6 @@ impl UdpGwClientStream {
|
||||||
writer: Some(writer),
|
writer: Some(writer),
|
||||||
last_activity: std::time::Instant::now(),
|
last_activity: std::time::Instant::now(),
|
||||||
closed: false,
|
closed: false,
|
||||||
conn_id: 0,
|
|
||||||
serial_number,
|
serial_number,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -411,16 +405,17 @@ impl UdpGwClientStream {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct UdpGwClient {
|
pub(crate) struct UdpGwClient {
|
||||||
udp_mtu: u16,
|
udp_mtu: u16,
|
||||||
max_connections: u16,
|
max_connections: usize,
|
||||||
udp_timeout: u64,
|
udp_timeout: u64,
|
||||||
keepalive_time: Duration,
|
keepalive_time: Duration,
|
||||||
udpgw_server: SocketAddr,
|
udpgw_server: SocketAddr,
|
||||||
server_connections: Mutex<VecDeque<UdpGwClientStream>>,
|
server_connections: Mutex<VecDeque<UdpGwClientStream>>,
|
||||||
|
is_in_heartbeat: std::sync::atomic::AtomicBool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UdpGwClient {
|
impl UdpGwClient {
|
||||||
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, udpgw_server: SocketAddr) -> Self {
|
pub fn new(udp_mtu: u16, max_connections: usize, keepalive_time: Duration, udp_timeout: u64, udpgw_server: SocketAddr) -> Self {
|
||||||
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
|
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections));
|
||||||
UdpGwClient {
|
UdpGwClient {
|
||||||
udp_mtu,
|
udp_mtu,
|
||||||
max_connections,
|
max_connections,
|
||||||
|
@ -428,6 +423,7 @@ impl UdpGwClient {
|
||||||
udpgw_server,
|
udpgw_server,
|
||||||
keepalive_time,
|
keepalive_time,
|
||||||
server_connections,
|
server_connections,
|
||||||
|
is_in_heartbeat: std::sync::atomic::AtomicBool::new(false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -439,8 +435,8 @@ impl UdpGwClient {
|
||||||
self.udp_timeout
|
self.udp_timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn is_full(&self) -> bool {
|
pub(crate) async fn is_full(&self) -> bool {
|
||||||
TCP_COUNTER.load(Relaxed) >= self.max_connections as u32
|
self.server_connections.lock().await.len() >= self.max_connections
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn pop_server_connection_from_queue(&self) -> Option<UdpGwClientStream> {
|
pub(crate) async fn pop_server_connection_from_queue(&self) -> Option<UdpGwClientStream> {
|
||||||
|
@ -448,13 +444,13 @@ impl UdpGwClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn store_server_connection(&self, stream: UdpGwClientStream) {
|
pub(crate) async fn store_server_connection(&self, stream: UdpGwClientStream) {
|
||||||
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
if self.server_connections.lock().await.len() < self.max_connections {
|
||||||
self.server_connections.lock().await.push_back(stream);
|
self.server_connections.lock().await.push_back(stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn store_server_connection_full(&self, mut stream: UdpGwClientStream, reader: OwnedReadHalf, writer: OwnedWriteHalf) {
|
pub(crate) async fn store_server_connection_full(&self, mut stream: UdpGwClientStream, reader: OwnedReadHalf, writer: OwnedWriteHalf) {
|
||||||
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
if self.server_connections.lock().await.len() < self.max_connections {
|
||||||
stream.set_reader(Some(reader));
|
stream.set_reader(Some(reader));
|
||||||
stream.set_writer(Some(writer));
|
stream.set_writer(Some(writer));
|
||||||
self.server_connections.lock().await.push_back(stream);
|
self.server_connections.lock().await.push_back(stream);
|
||||||
|
@ -465,54 +461,59 @@ impl UdpGwClient {
|
||||||
self.udpgw_server
|
self.udpgw_server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn is_in_heartbeat_progress(&self) -> bool {
|
||||||
|
self.is_in_heartbeat.load(Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
/// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection.
|
/// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection.
|
||||||
pub(crate) async fn heartbeat_task(&self) -> std::io::Result<()> {
|
pub(crate) async fn heartbeat_task(&self) -> std::io::Result<()> {
|
||||||
loop {
|
loop {
|
||||||
|
self.is_in_heartbeat.store(false, Relaxed);
|
||||||
sleep(self.keepalive_time).await;
|
sleep(self.keepalive_time).await;
|
||||||
let Some(mut stream) = self.pop_server_connection_from_queue().await else {
|
self.is_in_heartbeat.store(true, Relaxed);
|
||||||
continue;
|
let mut streams = Vec::new();
|
||||||
};
|
|
||||||
|
|
||||||
if stream.is_closed() {
|
while let Some(stream) = self.pop_server_connection_from_queue().await {
|
||||||
// This stream will be dropped
|
if !stream.is_closed() {
|
||||||
continue;
|
streams.push(stream);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if stream.last_activity.elapsed() < self.keepalive_time {
|
|
||||||
self.store_server_connection(stream).await;
|
for mut stream in streams {
|
||||||
continue;
|
if stream.last_activity.elapsed() < self.keepalive_time {
|
||||||
}
|
self.store_server_connection(stream).await;
|
||||||
|
continue;
|
||||||
let Some(mut stream_reader) = stream.get_reader() else {
|
}
|
||||||
continue;
|
|
||||||
};
|
let Some(mut stream_reader) = stream.get_reader() else {
|
||||||
|
continue;
|
||||||
let Some(mut stream_writer) = stream.get_writer() else {
|
};
|
||||||
continue;
|
|
||||||
};
|
let Some(mut stream_writer) = stream.get_writer() else {
|
||||||
let local_addr = stream_writer.local_addr()?;
|
continue;
|
||||||
let sn = stream.serial_number();
|
};
|
||||||
log::trace!("stream {} {:?} send keepalive", sn, local_addr);
|
let local_addr = stream_writer.local_addr()?;
|
||||||
let keepalive_packet: Vec<u8> = Packet::build_keepalive_packet(stream.new_packet_id()).into();
|
let sn = stream.serial_number();
|
||||||
if let Err(e) = stream_writer.write_all(&keepalive_packet).await {
|
let keepalive_packet: Vec<u8> = Packet::build_keepalive_packet(sn).into();
|
||||||
log::warn!("stream {} {:?} send keepalive failed: {}", sn, local_addr, e);
|
if let Err(e) = stream_writer.write_all(&keepalive_packet).await {
|
||||||
continue;
|
log::warn!("stream {} {:?} send keepalive failed: {}", sn, local_addr, e);
|
||||||
}
|
continue;
|
||||||
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
|
}
|
||||||
Ok(UdpGwResponse::KeepAlive) => {
|
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, self.udp_timeout, &mut stream_reader).await {
|
||||||
stream.update_activity();
|
Ok(UdpGwResponse::KeepAlive) => {
|
||||||
self.store_server_connection_full(stream, stream_reader, stream_writer).await;
|
stream.update_activity();
|
||||||
log::trace!("stream {} {:?} keepalive success", sn, local_addr);
|
self.store_server_connection_full(stream, stream_reader, stream_writer).await;
|
||||||
|
log::trace!("stream {sn} {:?} send keepalive and recieve it successfully", local_addr);
|
||||||
|
}
|
||||||
|
Ok(v) => log::debug!("stream {sn} {:?} keepalive unexpected response: {v}", local_addr),
|
||||||
|
Err(e) => log::debug!("stream {sn} {:?} keepalive no response, error \"{e}\"", local_addr),
|
||||||
}
|
}
|
||||||
Ok(v) => log::warn!("stream {} {:?} keepalive unexpected response: {:?}", sn, local_addr, v),
|
|
||||||
Err(e) => log::warn!("stream {} {:?} keepalive no response, error \"{}\"", sn, local_addr, e),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parses the UDP response data.
|
/// Parses the UDP response data.
|
||||||
pub(crate) fn parse_udp_response(udp_mtu: u16, data: &[u8]) -> Result<UdpGwResponse> {
|
pub(crate) fn parse_udp_response(udp_mtu: u16, packet: Packet) -> Result<UdpGwResponse> {
|
||||||
let packet = Packet::try_from(data)?;
|
|
||||||
let flags = packet.header.flags;
|
let flags = packet.header.flags;
|
||||||
if flags & UdpFlag::ERR == UdpFlag::ERR {
|
if flags & UdpFlag::ERR == UdpFlag::ERR {
|
||||||
return Ok(UdpGwResponse::Error);
|
return Ok(UdpGwResponse::Error);
|
||||||
|
@ -538,14 +539,13 @@ impl UdpGwClient {
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// - `Result<UdpGwResponse>`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs.
|
/// - `Result<UdpGwResponse>`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs.
|
||||||
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut OwnedReadHalf) -> Result<UdpGwResponse> {
|
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut OwnedReadHalf) -> Result<UdpGwResponse> {
|
||||||
let mut data = vec![0; udp_mtu.into()];
|
let packet = tokio::time::timeout(
|
||||||
let data_len = tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout + 2), stream.read(&mut data))
|
tokio::time::Duration::from_secs(udp_timeout + 2),
|
||||||
.await
|
Packet::retrieve_from_async_stream(stream),
|
||||||
.map_err(std::io::Error::from)??;
|
)
|
||||||
if data_len == 0 {
|
.await
|
||||||
return Ok(UdpGwResponse::TcpClose);
|
.map_err(std::io::Error::from)??;
|
||||||
}
|
UdpGwClient::parse_udp_response(udp_mtu, packet)
|
||||||
UdpGwClient::parse_udp_response(udp_mtu, &data[..data_len])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sends a UDP gateway packet.
|
/// Sends a UDP gateway packet.
|
||||||
|
|
Loading…
Add table
Reference in a new issue