mirror of
https://github.com/tun2proxy/tun2proxy.git
synced 2025-04-19 21:39:09 +00:00
refactor udpgw
This commit is contained in:
parent
3fb02f0fc7
commit
0aad0d1709
3 changed files with 149 additions and 98 deletions
|
@ -6,10 +6,10 @@ use tokio::{
|
||||||
tcp::{ReadHalf, WriteHalf},
|
tcp::{ReadHalf, WriteHalf},
|
||||||
UdpSocket,
|
UdpSocket,
|
||||||
},
|
},
|
||||||
sync::mpsc::{self, Receiver, Sender},
|
sync::mpsc::{Receiver, Sender},
|
||||||
};
|
};
|
||||||
use tun2proxy::{
|
use tun2proxy::{
|
||||||
udpgw::{Packet, UDPGW_FLAG_KEEPALIVE},
|
udpgw::{Packet, UdpFlag},
|
||||||
ArgVerbosity, BoxError, Error, Result,
|
ArgVerbosity, BoxError, Error, Result,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Packet>, mut client
|
||||||
|
|
||||||
let flags = packet.header.flags;
|
let flags = packet.header.flags;
|
||||||
let conn_id = packet.header.conn_id;
|
let conn_id = packet.header.conn_id;
|
||||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
if flags & UdpFlag::KEEPALIVE == UdpFlag::KEEPALIVE {
|
||||||
log::trace!("client {} send keepalive", client.addr);
|
log::trace!("client {} send keepalive", client.addr);
|
||||||
// 2. if keepalive packet, do nothing, send keepalive response to client
|
// 2. if keepalive packet, do nothing, send keepalive response to client
|
||||||
send_keepalive_response(tx.clone(), conn_id).await;
|
send_keepalive_response(tx.clone(), conn_id).await;
|
||||||
|
@ -227,7 +227,7 @@ pub async fn run(args: UdpGwArgs, shutdown_token: tokio_util::sync::Cancellation
|
||||||
log::info!("client {} connected", addr);
|
log::info!("client {} connected", addr);
|
||||||
let params = args.clone();
|
let params = args.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let (tx, rx) = mpsc::channel::<Packet>(100);
|
let (tx, rx) = tokio::sync::mpsc::channel::<Packet>(100);
|
||||||
let (tcp_read_stream, tcp_write_stream) = tcp_stream.split();
|
let (tcp_read_stream, tcp_write_stream) = tcp_stream.split();
|
||||||
let res = tokio::select! {
|
let res = tokio::select! {
|
||||||
v = process_client_udp_req(¶ms, tx, client, tcp_read_stream) => v,
|
v = process_client_udp_req(¶ms, tx, client, tcp_read_stream) => v,
|
||||||
|
|
65
src/lib.rs
65
src/lib.rs
|
@ -251,7 +251,7 @@ where
|
||||||
));
|
));
|
||||||
let client_keepalive = client.clone();
|
let client_keepalive = client.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
client_keepalive.heartbeat_task().await;
|
let _ = client_keepalive.heartbeat_task().await;
|
||||||
});
|
});
|
||||||
client
|
client
|
||||||
});
|
});
|
||||||
|
@ -349,7 +349,7 @@ where
|
||||||
SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
|
||||||
SocketAddr::V6(_) => SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)),
|
SocketAddr::V6(_) => SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)),
|
||||||
};
|
};
|
||||||
let tcpinfo = SessionInfo::new(tcp_src, udpgw.get_server_addr(), IpProtocol::Tcp);
|
let tcpinfo = SessionInfo::new(tcp_src, udpgw.get_udpgw_server_addr(), IpProtocol::Tcp);
|
||||||
let proxy_handler = mgr.new_proxy_handler(tcpinfo, None, false).await?;
|
let proxy_handler = mgr.new_proxy_handler(tcpinfo, None, false).await?;
|
||||||
let queue = socket_queue.clone();
|
let queue = socket_queue.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
@ -495,23 +495,33 @@ async fn handle_udp_gateway_session(
|
||||||
let proxy_server_addr = { proxy_handler.lock().await.get_server_addr() };
|
let proxy_server_addr = { proxy_handler.lock().await.get_server_addr() };
|
||||||
let udp_mtu = udpgw_client.get_udp_mtu();
|
let udp_mtu = udpgw_client.get_udp_mtu();
|
||||||
let udp_timeout = udpgw_client.get_udp_timeout();
|
let udp_timeout = udpgw_client.get_udp_timeout();
|
||||||
let mut stream = match udpgw_client.get_server_connection().await {
|
|
||||||
Some(server) => server,
|
let mut stream = loop {
|
||||||
None => {
|
match udpgw_client.pop_server_connection_from_queue().await {
|
||||||
if udpgw_client.is_full() {
|
Some(stream) => {
|
||||||
return Err("max udpgw connection limit reached".into());
|
if stream.is_closed() {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
break stream;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?;
|
None => {
|
||||||
if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
|
if udpgw_client.is_full() {
|
||||||
return Err(format!("udpgw connection error: {}", e).into());
|
return Err("max udpgw connection limit reached".into());
|
||||||
|
}
|
||||||
|
let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?;
|
||||||
|
if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
|
||||||
|
return Err(format!("udpgw connection error: {}", e).into());
|
||||||
|
}
|
||||||
|
break UdpGwClientStream::new(tcp_server_stream);
|
||||||
}
|
}
|
||||||
UdpGwClientStream::new(tcp_server_stream)
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let tcp_local_addr = stream.local_addr().clone();
|
let tcp_local_addr = stream.local_addr().clone();
|
||||||
|
let sn = stream.serial_number();
|
||||||
|
|
||||||
log::info!("[UdpGw] Beginning {} -> {}", &tcp_local_addr, udp_dst);
|
log::info!("[UdpGw] Beginning stream {} {} -> {}", sn, &tcp_local_addr, udp_dst);
|
||||||
|
|
||||||
let Some(mut reader) = stream.get_reader() else {
|
let Some(mut reader) = stream.get_reader() else {
|
||||||
return Err("get reader failed".into());
|
return Err("get reader failed".into());
|
||||||
|
@ -528,58 +538,59 @@ async fn handle_udp_gateway_session(
|
||||||
len = udp_stack.read(&mut tmp_buf) => {
|
len = udp_stack.read(&mut tmp_buf) => {
|
||||||
let read_len = match len {
|
let read_len = match len {
|
||||||
Ok(0) => {
|
Ok(0) => {
|
||||||
log::info!("[UdpGw] Ending {} <> {}", &tcp_local_addr, udp_dst);
|
log::info!("[UdpGw] Ending stream {} {} <> {}", sn, &tcp_local_addr, udp_dst);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Ok(n) => n,
|
Ok(n) => n,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::info!("[UdpGw] Ending {} <> {} with recv_udp_packet {}", &tcp_local_addr, udp_dst, e);
|
log::info!("[UdpGw] Ending stream {} {} <> {} with recv_udp_packet {}", sn, &tcp_local_addr, udp_dst, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
crate::traffic_status::traffic_status_update(read_len, 0)?;
|
crate::traffic_status::traffic_status_update(read_len, 0)?;
|
||||||
let new_id = stream.new_id();
|
let new_id = stream.new_packet_id();
|
||||||
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], udp_dst, new_id, &mut writer).await {
|
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], udp_dst, new_id, &mut writer).await {
|
||||||
log::info!("[UdpGw] Ending {} <> {} with send_udpgw_packet {}", &tcp_local_addr, udp_dst, e);
|
log::info!("[UdpGw] Ending stream {} {} <> {} with send_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
log::debug!("[UdpGw] {} -> {} send len {}", &tcp_local_addr, udp_dst, read_len);
|
log::debug!("[UdpGw] stream {} {} -> {} send len {}", sn, &tcp_local_addr, udp_dst, read_len);
|
||||||
stream.update_activity();
|
stream.update_activity();
|
||||||
}
|
}
|
||||||
ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut reader) => {
|
ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut reader) => {
|
||||||
match ret {
|
match ret {
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("[UdpGw] Ending stream {} {} <> {} with recv_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e);
|
||||||
|
stream.close();
|
||||||
|
break;
|
||||||
|
}
|
||||||
Ok(packet) => match packet {
|
Ok(packet) => match packet {
|
||||||
//should not received keepalive
|
//should not received keepalive
|
||||||
UdpGwResponse::KeepAlive => {
|
UdpGwResponse::KeepAlive => {
|
||||||
log::error!("[UdpGw] Ending {} <> {} with recv keepalive", &tcp_local_addr, udp_dst);
|
log::error!("[UdpGw] Ending stream {} {} <> {} with recv keepalive", sn, &tcp_local_addr, udp_dst);
|
||||||
stream.close();
|
stream.close();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
//server udp may be timeout,can continue to receive udp data?
|
//server udp may be timeout,can continue to receive udp data?
|
||||||
UdpGwResponse::Error => {
|
UdpGwResponse::Error => {
|
||||||
log::info!("[UdpGw] Ending {} <> {} with recv udp error", &tcp_local_addr, udp_dst);
|
log::info!("[UdpGw] Ending stream {} {} <> {} with recv udp error", sn, &tcp_local_addr, udp_dst);
|
||||||
stream.update_activity();
|
stream.update_activity();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
UdpGwResponse::TcpClose => {
|
UdpGwResponse::TcpClose => {
|
||||||
log::error!("[UdpGw] Ending {} <> {} with tcp closed", &tcp_local_addr, udp_dst);
|
log::error!("[UdpGw] Ending stream {} {} <> {} with tcp closed", sn, &tcp_local_addr, udp_dst);
|
||||||
stream.close();
|
stream.close();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
UdpGwResponse::Data(data) => {
|
UdpGwResponse::Data(data) => {
|
||||||
use socks5_impl::protocol::StreamOperation;
|
use socks5_impl::protocol::StreamOperation;
|
||||||
let len = data.len();
|
let len = data.len();
|
||||||
log::debug!("[UdpGw] {} <- {} receive len {}", &tcp_local_addr, udp_dst, len);
|
log::debug!("[UdpGw] stream {} {} <- {} receive len {}", sn, &tcp_local_addr, udp_dst, len);
|
||||||
if let Err(e) = udp_stack.write_all(&data.data).await {
|
if let Err(e) = udp_stack.write_all(&data.data).await {
|
||||||
log::error!("[UdpGw] Ending {} <> {} with send_udp_packet {}", &tcp_local_addr, udp_dst, e);
|
log::error!("[UdpGw] Ending stream {} {} <> {} with send_udp_packet {}", sn, &tcp_local_addr, udp_dst, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
crate::traffic_status::traffic_status_update(0, len)?;
|
crate::traffic_status::traffic_status_update(0, len)?;
|
||||||
}
|
}
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
log::warn!("[UdpGw] Ending {} <> {} with recv_udpgw_packet {}", &tcp_local_addr, udp_dst, e);
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stream.update_activity();
|
stream.update_activity();
|
||||||
|
@ -588,7 +599,7 @@ async fn handle_udp_gateway_session(
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stream.is_closed() {
|
if !stream.is_closed() {
|
||||||
udpgw_client.release_server_connection_full(stream, reader, writer).await;
|
udpgw_client.store_server_connection_full(stream, reader, writer).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
174
src/udpgw.rs
174
src/udpgw.rs
|
@ -15,9 +15,42 @@ pub(crate) const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::<u16>();
|
||||||
pub(crate) const UDPGW_MAX_CONNECTIONS: u16 = 100;
|
pub(crate) const UDPGW_MAX_CONNECTIONS: u16 = 100;
|
||||||
pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
|
pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
|
||||||
|
|
||||||
pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01;
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub const UDPGW_FLAG_ERR: u8 = 0x20;
|
pub struct UdpFlag(pub u8);
|
||||||
pub const UDPGW_FLAG_DATA: u8 = 0x02;
|
|
||||||
|
impl UdpFlag {
|
||||||
|
pub const ZERO: UdpFlag = UdpFlag(0x00);
|
||||||
|
pub const KEEPALIVE: UdpFlag = UdpFlag(0x01);
|
||||||
|
pub const ERR: UdpFlag = UdpFlag(0x20);
|
||||||
|
pub const DATA: UdpFlag = UdpFlag(0x02);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for UdpFlag {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let flag = match self.0 {
|
||||||
|
0x00 => "ZERO",
|
||||||
|
0x01 => "KEEPALIVE",
|
||||||
|
0x20 => "ERR",
|
||||||
|
0x02 => "DATA",
|
||||||
|
n => return write!(f, "Unknown UdpFlag(0x{:02X})", n),
|
||||||
|
};
|
||||||
|
write!(f, "UdpFlag({})", flag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::BitAnd for UdpFlag {
|
||||||
|
type Output = Self;
|
||||||
|
fn bitand(self, rhs: Self) -> Self::Output {
|
||||||
|
UdpFlag(self.0 & rhs.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::BitOr for UdpFlag {
|
||||||
|
type Output = Self;
|
||||||
|
fn bitor(self, rhs: Self) -> Self::Output {
|
||||||
|
UdpFlag(self.0 | rhs.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
|
static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
|
||||||
|
|
||||||
|
@ -98,7 +131,7 @@ impl TryFrom<&[u8]> for Packet {
|
||||||
return Err(std::io::ErrorKind::InvalidData.into());
|
return Err(std::io::ErrorKind::InvalidData.into());
|
||||||
}
|
}
|
||||||
let header = UdpgwHeader::retrieve_from_stream(&mut iter)?;
|
let header = UdpgwHeader::retrieve_from_stream(&mut iter)?;
|
||||||
let address = if header.flags & UDPGW_FLAG_DATA != 0 {
|
let address = if header.flags & UdpFlag::DATA != UdpFlag::ZERO {
|
||||||
Some(Address::retrieve_from_stream(&mut iter)?)
|
Some(Address::retrieve_from_stream(&mut iter)?)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -114,11 +147,11 @@ impl Packet {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_keepalive_packet(conn_id: u16) -> Self {
|
pub fn build_keepalive_packet(conn_id: u16) -> Self {
|
||||||
Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), None, &[])
|
Packet::new(UdpgwHeader::new(UdpFlag::KEEPALIVE, conn_id), None, &[])
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_error_packet(conn_id: u16) -> Self {
|
pub fn build_error_packet(conn_id: u16) -> Self {
|
||||||
Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, conn_id), None, &[])
|
Packet::new(UdpgwHeader::new(UdpFlag::ERR, conn_id), None, &[])
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_packet_from_address(conn_id: u16, remote_addr: &Address, data: &[u8]) -> std::io::Result<Self> {
|
pub fn build_packet_from_address(conn_id: u16, remote_addr: &Address, data: &[u8]) -> std::io::Result<Self> {
|
||||||
|
@ -132,7 +165,7 @@ impl Packet {
|
||||||
|
|
||||||
pub fn build_ip_packet(conn_id: u16, remote_addr: SocketAddr, data: &[u8]) -> Self {
|
pub fn build_ip_packet(conn_id: u16, remote_addr: SocketAddr, data: &[u8]) -> Self {
|
||||||
let addr: Address = remote_addr.into();
|
let addr: Address = remote_addr.into();
|
||||||
Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data)
|
Packet::new(UdpgwHeader::new(UdpFlag::DATA, conn_id), Some(addr), data)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_domain_packet(conn_id: u16, port: u16, domain: &str, data: &[u8]) -> std::io::Result<Self> {
|
pub fn build_domain_packet(conn_id: u16, port: u16, domain: &str, data: &[u8]) -> std::io::Result<Self> {
|
||||||
|
@ -140,7 +173,7 @@ impl Packet {
|
||||||
return Err(std::io::ErrorKind::InvalidInput.into());
|
return Err(std::io::ErrorKind::InvalidInput.into());
|
||||||
}
|
}
|
||||||
let addr = Address::from((domain, port));
|
let addr = Address::from((domain, port));
|
||||||
Ok(Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data))
|
Ok(Packet::new(UdpgwHeader::new(UdpFlag::DATA, conn_id), Some(addr), data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +187,7 @@ impl StreamOperation for Packet {
|
||||||
stream.read_exact(&mut buf)?;
|
stream.read_exact(&mut buf)?;
|
||||||
let length = u16::from_be_bytes(buf) as usize;
|
let length = u16::from_be_bytes(buf) as usize;
|
||||||
let header = UdpgwHeader::retrieve_from_stream(stream)?;
|
let header = UdpgwHeader::retrieve_from_stream(stream)?;
|
||||||
let address = if header.flags & UDPGW_FLAG_DATA != 0 {
|
let address = if header.flags & UdpFlag::DATA == UdpFlag::DATA {
|
||||||
Some(Address::retrieve_from_stream(stream)?)
|
Some(Address::retrieve_from_stream(stream)?)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -194,7 +227,7 @@ impl AsyncStreamOperation for Packet {
|
||||||
r.read_exact(&mut buf).await?;
|
r.read_exact(&mut buf).await?;
|
||||||
let length = u16::from_be_bytes(buf) as usize;
|
let length = u16::from_be_bytes(buf) as usize;
|
||||||
let header = UdpgwHeader::retrieve_from_async_stream(r).await?;
|
let header = UdpgwHeader::retrieve_from_async_stream(r).await?;
|
||||||
let address = if header.flags & UDPGW_FLAG_DATA != 0 {
|
let address = if header.flags & UdpFlag::DATA == UdpFlag::DATA {
|
||||||
Some(Address::retrieve_from_async_stream(r).await?)
|
Some(Address::retrieve_from_async_stream(r).await?)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
@ -211,14 +244,14 @@ impl AsyncStreamOperation for Packet {
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub struct UdpgwHeader {
|
pub struct UdpgwHeader {
|
||||||
pub flags: u8,
|
pub flags: UdpFlag,
|
||||||
pub conn_id: u16,
|
pub conn_id: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for UdpgwHeader {
|
impl std::fmt::Display for UdpgwHeader {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
let id = self.conn_id;
|
let id = self.conn_id;
|
||||||
write!(f, "flags: 0x{:02x}, conn_id: {}", self.flags, id)
|
write!(f, "flags: {}, conn_id: {}", self.flags, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,7 +290,7 @@ impl AsyncStreamOperation for UdpgwHeader {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UdpgwHeader {
|
impl UdpgwHeader {
|
||||||
pub fn new(flags: u8, conn_id: u16) -> Self {
|
pub fn new(flags: UdpFlag, conn_id: u16) -> Self {
|
||||||
UdpgwHeader { flags, conn_id }
|
UdpgwHeader { flags, conn_id }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -274,14 +307,14 @@ impl TryFrom<&[u8]> for UdpgwHeader {
|
||||||
return Err(std::io::ErrorKind::InvalidData.into());
|
return Err(std::io::ErrorKind::InvalidData.into());
|
||||||
}
|
}
|
||||||
let conn_id = u16::from_be_bytes([value[1], value[2]]);
|
let conn_id = u16::from_be_bytes([value[1], value[2]]);
|
||||||
Ok(UdpgwHeader { flags: value[0], conn_id })
|
Ok(UdpgwHeader::new(UdpFlag(value[0]), conn_id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<&UdpgwHeader> for Vec<u8> {
|
impl From<&UdpgwHeader> for Vec<u8> {
|
||||||
fn from(header: &UdpgwHeader) -> Vec<u8> {
|
fn from(header: &UdpgwHeader) -> Vec<u8> {
|
||||||
let mut bytes = vec![0; header.len()];
|
let mut bytes = vec![0; header.len()];
|
||||||
bytes[0] = header.flags;
|
bytes[0] = header.flags.0;
|
||||||
bytes[1..3].copy_from_slice(&header.conn_id.to_be_bytes());
|
bytes[1..3].copy_from_slice(&header.conn_id.to_be_bytes());
|
||||||
bytes
|
bytes
|
||||||
}
|
}
|
||||||
|
@ -296,14 +329,17 @@ pub(crate) enum UdpGwResponse {
|
||||||
Data(Packet),
|
Data(Packet),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SERIAL_NUMBER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct UdpGwClientStream {
|
pub(crate) struct UdpGwClientStream {
|
||||||
local_addr: String,
|
local_addr: SocketAddr,
|
||||||
writer: Option<OwnedWriteHalf>,
|
writer: Option<OwnedWriteHalf>,
|
||||||
reader: Option<OwnedReadHalf>,
|
reader: Option<OwnedReadHalf>,
|
||||||
conn_id: u16,
|
conn_id: u16,
|
||||||
closed: bool,
|
closed: bool,
|
||||||
last_activity: std::time::Instant,
|
last_activity: std::time::Instant,
|
||||||
|
serial_number: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for UdpGwClientStream {
|
impl Drop for UdpGwClientStream {
|
||||||
|
@ -333,34 +369,33 @@ impl UdpGwClientStream {
|
||||||
self.writer.take()
|
self.writer.take()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn local_addr(&self) -> &String {
|
pub fn local_addr(&self) -> SocketAddr {
|
||||||
&self.local_addr
|
self.local_addr
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_activity(&mut self) {
|
pub fn update_activity(&mut self) {
|
||||||
self.last_activity = std::time::Instant::now();
|
self.last_activity = std::time::Instant::now();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_closed(&mut self) -> bool {
|
pub fn is_closed(&self) -> bool {
|
||||||
self.closed
|
self.closed
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn id(&mut self) -> u16 {
|
pub fn serial_number(&self) -> u64 {
|
||||||
self.conn_id
|
self.serial_number
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_id(&mut self) -> u16 {
|
pub fn new_packet_id(&mut self) -> u16 {
|
||||||
self.conn_id += 1;
|
self.conn_id += 1;
|
||||||
self.conn_id
|
self.conn_id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(tcp_server_stream: TcpStream) -> Self {
|
pub fn new(tcp_server_stream: TcpStream) -> Self {
|
||||||
let default = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
|
let default = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
|
||||||
let local_addr = tcp_server_stream.local_addr().unwrap_or(default).to_string();
|
let local_addr = tcp_server_stream.local_addr().unwrap_or(default);
|
||||||
let (rx, tx) = tcp_server_stream.into_split();
|
let (reader, writer) = tcp_server_stream.into_split();
|
||||||
let writer = tx;
|
|
||||||
let reader = rx;
|
|
||||||
TCP_COUNTER.fetch_add(1, Relaxed);
|
TCP_COUNTER.fetch_add(1, Relaxed);
|
||||||
|
let serial_number = SERIAL_NUMBER.fetch_add(1, Relaxed);
|
||||||
UdpGwClientStream {
|
UdpGwClientStream {
|
||||||
local_addr,
|
local_addr,
|
||||||
reader: Some(reader),
|
reader: Some(reader),
|
||||||
|
@ -368,6 +403,7 @@ impl UdpGwClientStream {
|
||||||
last_activity: std::time::Instant::now(),
|
last_activity: std::time::Instant::now(),
|
||||||
closed: false,
|
closed: false,
|
||||||
conn_id: 0,
|
conn_id: 0,
|
||||||
|
serial_number,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -378,18 +414,18 @@ pub(crate) struct UdpGwClient {
|
||||||
max_connections: u16,
|
max_connections: u16,
|
||||||
udp_timeout: u64,
|
udp_timeout: u64,
|
||||||
keepalive_time: Duration,
|
keepalive_time: Duration,
|
||||||
server_addr: SocketAddr,
|
udpgw_server: SocketAddr,
|
||||||
server_connections: Mutex<VecDeque<UdpGwClientStream>>,
|
server_connections: Mutex<VecDeque<UdpGwClientStream>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UdpGwClient {
|
impl UdpGwClient {
|
||||||
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self {
|
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, udpgw_server: SocketAddr) -> Self {
|
||||||
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
|
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
|
||||||
UdpGwClient {
|
UdpGwClient {
|
||||||
udp_mtu,
|
udp_mtu,
|
||||||
max_connections,
|
max_connections,
|
||||||
udp_timeout,
|
udp_timeout,
|
||||||
server_addr,
|
udpgw_server,
|
||||||
keepalive_time,
|
keepalive_time,
|
||||||
server_connections,
|
server_connections,
|
||||||
}
|
}
|
||||||
|
@ -407,22 +443,17 @@ impl UdpGwClient {
|
||||||
TCP_COUNTER.load(Relaxed) >= self.max_connections as u32
|
TCP_COUNTER.load(Relaxed) >= self.max_connections as u32
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get_server_connection(&self) -> Option<UdpGwClientStream> {
|
pub(crate) async fn pop_server_connection_from_queue(&self) -> Option<UdpGwClientStream> {
|
||||||
self.server_connections.lock().await.pop_front()
|
self.server_connections.lock().await.pop_front()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) {
|
pub(crate) async fn store_server_connection(&self, stream: UdpGwClientStream) {
|
||||||
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
||||||
self.server_connections.lock().await.push_back(stream);
|
self.server_connections.lock().await.push_back(stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn release_server_connection_full(
|
pub(crate) async fn store_server_connection_full(&self, mut stream: UdpGwClientStream, reader: OwnedReadHalf, writer: OwnedWriteHalf) {
|
||||||
&self,
|
|
||||||
mut stream: UdpGwClientStream,
|
|
||||||
reader: OwnedReadHalf,
|
|
||||||
writer: OwnedWriteHalf,
|
|
||||||
) {
|
|
||||||
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
||||||
stream.set_reader(Some(reader));
|
stream.set_reader(Some(reader));
|
||||||
stream.set_writer(Some(writer));
|
stream.set_writer(Some(writer));
|
||||||
|
@ -430,42 +461,51 @@ impl UdpGwClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn get_server_addr(&self) -> SocketAddr {
|
pub(crate) fn get_udpgw_server_addr(&self) -> SocketAddr {
|
||||||
self.server_addr
|
self.udpgw_server
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection.
|
/// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection.
|
||||||
pub(crate) async fn heartbeat_task(&self) {
|
pub(crate) async fn heartbeat_task(&self) -> std::io::Result<()> {
|
||||||
loop {
|
loop {
|
||||||
sleep(self.keepalive_time).await;
|
sleep(self.keepalive_time).await;
|
||||||
if let Some(mut stream) = self.get_server_connection().await {
|
let Some(mut stream) = self.pop_server_connection_from_queue().await else {
|
||||||
if stream.last_activity.elapsed() < self.keepalive_time {
|
continue;
|
||||||
self.release_server_connection(stream).await;
|
};
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let Some(mut stream_reader) = stream.get_reader() else {
|
if stream.is_closed() {
|
||||||
continue;
|
// This stream will be dropped
|
||||||
};
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let Some(mut stream_writer) = stream.get_writer() else {
|
if stream.last_activity.elapsed() < self.keepalive_time {
|
||||||
continue;
|
self.store_server_connection(stream).await;
|
||||||
};
|
continue;
|
||||||
let local_addr = stream_writer.local_addr();
|
}
|
||||||
log::debug!("{:?}:{} send keepalive", local_addr, stream.id());
|
|
||||||
let keepalive_packet: Vec<u8> = Packet::build_keepalive_packet(stream.id()).into();
|
let Some(mut stream_reader) = stream.get_reader() else {
|
||||||
if let Err(e) = stream_writer.write_all(&keepalive_packet).await {
|
continue;
|
||||||
log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e);
|
};
|
||||||
continue;
|
|
||||||
}
|
let Some(mut stream_writer) = stream.get_writer() else {
|
||||||
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
|
continue;
|
||||||
Ok(UdpGwResponse::KeepAlive) => {
|
};
|
||||||
stream.update_activity();
|
let local_addr = stream_writer.local_addr()?;
|
||||||
self.release_server_connection_full(stream, stream_reader, stream_writer).await;
|
let sn = stream.serial_number();
|
||||||
}
|
log::trace!("stream {} {:?} send keepalive", sn, local_addr);
|
||||||
Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v),
|
let keepalive_packet: Vec<u8> = Packet::build_keepalive_packet(stream.new_packet_id()).into();
|
||||||
Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e),
|
if let Err(e) = stream_writer.write_all(&keepalive_packet).await {
|
||||||
|
log::warn!("stream {} {:?} send keepalive failed: {}", sn, local_addr, e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
|
||||||
|
Ok(UdpGwResponse::KeepAlive) => {
|
||||||
|
stream.update_activity();
|
||||||
|
self.store_server_connection_full(stream, stream_reader, stream_writer).await;
|
||||||
|
log::trace!("stream {} {:?} keepalive success", sn, local_addr);
|
||||||
}
|
}
|
||||||
|
Ok(v) => log::warn!("stream {} {:?} keepalive unexpected response: {:?}", sn, local_addr, v),
|
||||||
|
Err(e) => log::warn!("stream {} {:?} keepalive no response, error \"{}\"", sn, local_addr, e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -474,10 +514,10 @@ impl UdpGwClient {
|
||||||
pub(crate) fn parse_udp_response(udp_mtu: u16, data: &[u8]) -> Result<UdpGwResponse> {
|
pub(crate) fn parse_udp_response(udp_mtu: u16, data: &[u8]) -> Result<UdpGwResponse> {
|
||||||
let packet = Packet::try_from(data)?;
|
let packet = Packet::try_from(data)?;
|
||||||
let flags = packet.header.flags;
|
let flags = packet.header.flags;
|
||||||
if flags & UDPGW_FLAG_ERR != 0 {
|
if flags & UdpFlag::ERR == UdpFlag::ERR {
|
||||||
return Ok(UdpGwResponse::Error);
|
return Ok(UdpGwResponse::Error);
|
||||||
}
|
}
|
||||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
if flags & UdpFlag::KEEPALIVE == UdpFlag::KEEPALIVE {
|
||||||
return Ok(UdpGwResponse::KeepAlive);
|
return Ok(UdpGwResponse::KeepAlive);
|
||||||
}
|
}
|
||||||
if packet.data.len() > udp_mtu as usize {
|
if packet.data.len() > udp_mtu as usize {
|
||||||
|
|
Loading…
Add table
Reference in a new issue