mirror of
https://github.com/tun2proxy/tun2proxy.git
synced 2025-04-21 22:39:08 +00:00
refactor core logic
This commit is contained in:
parent
2155ea55c0
commit
d4afc8f655
3 changed files with 270 additions and 630 deletions
|
@ -1,41 +1,30 @@
|
||||||
use std::{
|
use socks5_impl::protocol::{AddressType, AsyncStreamOperation};
|
||||||
net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs},
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncReadExt, AsyncWriteExt},
|
io::AsyncWriteExt,
|
||||||
net::{
|
net::{
|
||||||
tcp::{ReadHalf, WriteHalf},
|
tcp::{ReadHalf, WriteHalf},
|
||||||
UdpSocket,
|
UdpSocket,
|
||||||
},
|
},
|
||||||
sync::mpsc::{self, Receiver, Sender},
|
sync::mpsc::{self, Receiver, Sender},
|
||||||
};
|
};
|
||||||
use tun2proxy::{udpgw::*, ArgVerbosity, Result};
|
use tun2proxy::{
|
||||||
|
udpgw::{Packet, UDPGW_FLAG_KEEPALIVE},
|
||||||
|
ArgVerbosity, Result,
|
||||||
|
};
|
||||||
|
|
||||||
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60);
|
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60);
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct UdpRequest {
|
|
||||||
flags: u8,
|
|
||||||
server_addr: SocketAddr,
|
|
||||||
conn_id: u16,
|
|
||||||
data: Vec<u8>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
buf: Vec<u8>,
|
|
||||||
last_activity: std::time::Instant,
|
last_activity: std::time::Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Client {
|
impl Client {
|
||||||
pub fn new(addr: SocketAddr) -> Self {
|
pub fn new(addr: SocketAddr) -> Self {
|
||||||
Self {
|
let last_activity = std::time::Instant::now();
|
||||||
addr,
|
Self { addr, last_activity }
|
||||||
buf: vec![],
|
|
||||||
last_activity: std::time::Instant::now(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,112 +60,28 @@ impl UdpGwArgs {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_error(tx: Sender<Vec<u8>>, con: &mut UdpRequest) {
|
async fn send_error(tx: Sender<Packet>, conn_id: u16) {
|
||||||
let error_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, con.conn_id), vec![]).into();
|
let error_packet = Packet::build_error_packet(conn_id);
|
||||||
if let Err(e) = tx.send(error_packet).await {
|
if let Err(e) = tx.send(error_packet).await {
|
||||||
log::error!("send error response error {:?}", e);
|
log::error!("send error response error {:?}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_keepalive_response(tx: Sender<Vec<u8>>, conn_id: u16) {
|
async fn send_keepalive_response(tx: Sender<Packet>, conn_id: u16) {
|
||||||
let keepalive_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), vec![]).into();
|
let keepalive_packet = Packet::build_keepalive_packet(conn_id);
|
||||||
if let Err(e) = tx.send(keepalive_packet).await {
|
if let Err(e) = tx.send(keepalive_packet).await {
|
||||||
log::error!("send keepalive response error {:?}", e);
|
log::error!("send keepalive response error {:?}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
|
/// Send data field of packet from client to destination server and receive response,
|
||||||
let header_len = UdpgwHeader::static_len();
|
/// then wrap response data to the packet's data field and send packet back to client.
|
||||||
if data_len < header_len {
|
async fn process_udp(client: SocketAddr, udp_mtu: u16, udp_timeout: u64, tx: Sender<Packet>, mut packet: Packet) -> Result<()> {
|
||||||
return Err("Invalid udpgw data".into());
|
let Some(dst_addr) = &packet.address else {
|
||||||
}
|
log::error!("client {} udp request address is None", client);
|
||||||
let header_bytes = &data[..header_len];
|
return Ok(());
|
||||||
let header = UdpgwHeader {
|
|
||||||
flags: header_bytes[0],
|
|
||||||
conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
|
|
||||||
};
|
};
|
||||||
|
let std_sock = if dst_addr.get_type() == AddressType::IPv6 {
|
||||||
let flags = header.flags;
|
|
||||||
let conn_id = header.conn_id;
|
|
||||||
|
|
||||||
// keepalive
|
|
||||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
|
||||||
return Ok((data, flags, conn_id, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let ip_data = &data[header_len..];
|
|
||||||
let mut data_len = data_len - header_len;
|
|
||||||
// port_len + min(ipv4/ipv6/(domain_len + 1))
|
|
||||||
if data_len < UDPGW_LENGTH_FIELD_SIZE + 2 {
|
|
||||||
return Err("Invalid udpgw data".into());
|
|
||||||
}
|
|
||||||
if flags & UDPGW_FLAG_DOMAIN != 0 {
|
|
||||||
let addr_port = u16::from_be_bytes([ip_data[0], ip_data[1]]);
|
|
||||||
data_len -= 2;
|
|
||||||
if let Some(end) = ip_data.iter().skip(2).position(|&x| x == 0) {
|
|
||||||
let domain_slice = &ip_data[2..end + 2];
|
|
||||||
match std::str::from_utf8(domain_slice) {
|
|
||||||
Ok(domain) => {
|
|
||||||
let target_str = format!("{}:{}", domain, addr_port);
|
|
||||||
let target = target_str
|
|
||||||
.to_socket_addrs()?
|
|
||||||
.next()
|
|
||||||
.ok_or(format!("Invalid address {}", target_str))?;
|
|
||||||
if data_len < 2 + domain.len() {
|
|
||||||
return Err("Invalid udpgw data".into());
|
|
||||||
}
|
|
||||||
data_len -= domain.len() + 1;
|
|
||||||
if data_len > udp_mtu as usize {
|
|
||||||
return Err("too much data".into());
|
|
||||||
}
|
|
||||||
let udpdata = &ip_data[(2 + domain.len() + 1)..];
|
|
||||||
Ok((udpdata, flags, conn_id, target))
|
|
||||||
}
|
|
||||||
Err(_) => Err("Invalid UTF-8 sequence in domain".into()),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err("missing domain name".into())
|
|
||||||
}
|
|
||||||
} else if flags & UDPGW_FLAG_IPV6 != 0 {
|
|
||||||
let addr_ipv6_len = BinSocketAddr::static_len(true);
|
|
||||||
if data_len < addr_ipv6_len {
|
|
||||||
return Err("Ipv6 Invalid UDP data".into());
|
|
||||||
}
|
|
||||||
let addr_ipv6 = BinSocketAddr::try_from(&ip_data[..addr_ipv6_len])?;
|
|
||||||
data_len -= addr_ipv6_len;
|
|
||||||
|
|
||||||
if data_len > udp_mtu as usize {
|
|
||||||
return Err("too much data".into());
|
|
||||||
}
|
|
||||||
return Ok((
|
|
||||||
&ip_data[addr_ipv6_len..(data_len + addr_ipv6_len)],
|
|
||||||
flags,
|
|
||||||
conn_id,
|
|
||||||
addr_ipv6.into(),
|
|
||||||
));
|
|
||||||
} else {
|
|
||||||
let addr_ipv4_len = BinSocketAddr::static_len(false);
|
|
||||||
if data_len < addr_ipv4_len {
|
|
||||||
return Err("Ipv4 Invalid UDP data".into());
|
|
||||||
}
|
|
||||||
let addr_ipv4 = BinSocketAddr::try_from(&ip_data[..addr_ipv4_len])?;
|
|
||||||
data_len -= addr_ipv4_len;
|
|
||||||
|
|
||||||
if data_len > udp_mtu as usize {
|
|
||||||
return Err("too much data".into());
|
|
||||||
}
|
|
||||||
|
|
||||||
return Ok((
|
|
||||||
&ip_data[addr_ipv4_len..(data_len + addr_ipv4_len)],
|
|
||||||
flags,
|
|
||||||
conn_id,
|
|
||||||
addr_ipv4.into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, con: &mut UdpRequest) -> Result<()> {
|
|
||||||
let std_sock = if con.flags & UDPGW_FLAG_IPV6 != 0 {
|
|
||||||
std::net::UdpSocket::bind("[::]:0")?
|
std::net::UdpSocket::bind("[::]:0")?
|
||||||
} else {
|
} else {
|
||||||
std::net::UdpSocket::bind("0.0.0.0:0")?
|
std::net::UdpSocket::bind("0.0.0.0:0")?
|
||||||
|
@ -185,163 +90,78 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?;
|
nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?;
|
||||||
let socket = UdpSocket::from_std(std_sock)?;
|
let socket = UdpSocket::from_std(std_sock)?;
|
||||||
socket.send_to(&con.data, &con.server_addr).await?;
|
use std::net::ToSocketAddrs;
|
||||||
con.data.resize(2048, 0);
|
let Some(dst_addr) = dst_addr.to_socket_addrs()?.next() else {
|
||||||
match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data)).await {
|
log::error!("client {} udp request address to_socket_addrs", client);
|
||||||
Ok(ret) => {
|
return Ok(());
|
||||||
let (len, _addr) = ret?;
|
};
|
||||||
let mut packet = vec![];
|
// 1. send udp data to destination server
|
||||||
let mut pack_len = UdpgwHeader::static_len() + len;
|
socket.send_to(&packet.data, &dst_addr).await?;
|
||||||
match con.server_addr {
|
packet.data.resize(udp_mtu as usize, 0);
|
||||||
SocketAddr::V4(_) => {
|
// 2. receive response from destination server
|
||||||
let addr_ipv4 = BinSocketAddr::from(con.server_addr);
|
let (len, _addr) = tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut packet.data))
|
||||||
pack_len += addr_ipv4.len();
|
.await
|
||||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
.map_err(std::io::Error::from)??;
|
||||||
packet.extend_from_slice(&[con.flags]);
|
packet.data.truncate(len);
|
||||||
packet.extend_from_slice(&con.conn_id.to_le_bytes());
|
// 3. send response back to client
|
||||||
let addr_ipv4_bin: Vec<u8> = addr_ipv4.into();
|
use std::io::{Error, ErrorKind::BrokenPipe};
|
||||||
packet.extend_from_slice(&addr_ipv4_bin);
|
tx.send(packet).await.map_err(|e| Error::new(BrokenPipe, e))?;
|
||||||
packet.extend_from_slice(&con.data[..len]);
|
|
||||||
}
|
|
||||||
SocketAddr::V6(_) => {
|
|
||||||
let addr_ipv6 = BinSocketAddr::from(con.server_addr);
|
|
||||||
pack_len += addr_ipv6.len();
|
|
||||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
|
||||||
packet.extend_from_slice(&[con.flags]);
|
|
||||||
packet.extend_from_slice(&con.conn_id.to_le_bytes());
|
|
||||||
let addr_ipv6_bin: Vec<u8> = addr_ipv6.into();
|
|
||||||
packet.extend_from_slice(&addr_ipv6_bin);
|
|
||||||
packet.extend_from_slice(&con.data[..len]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if let Err(e) = tx.send(packet).await {
|
|
||||||
log::error!("client {} send udp response {}", addr, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
log::warn!("client {} udp recv_from {}", addr, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> {
|
async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Packet>, mut client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> {
|
||||||
let mut client = client;
|
|
||||||
let mut buf = vec![0; args.udp_mtu as usize];
|
|
||||||
let mut len_buf = [0; UDPGW_LENGTH_FIELD_SIZE];
|
|
||||||
let udp_mtu = args.udp_mtu;
|
|
||||||
let udp_timeout = args.udp_timeout;
|
let udp_timeout = args.udp_timeout;
|
||||||
|
let udp_mtu = args.udp_mtu;
|
||||||
|
|
||||||
'out: loop {
|
loop {
|
||||||
/*
|
// 1. read udpgw packet from client
|
||||||
use socks5_impl::protocol::AsyncStreamOperation;
|
|
||||||
let res = tokio::time::timeout(tokio::time::Duration::from_secs(2), Packet::retrieve_from_async_stream(&mut reader)).await;
|
let res = tokio::time::timeout(tokio::time::Duration::from_secs(2), Packet::retrieve_from_async_stream(&mut reader)).await;
|
||||||
let packet = match res {
|
let packet = match res {
|
||||||
Ok(Ok(packet)) => packet,
|
Ok(Ok(packet)) => packet,
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
log::error!("client {} retrieve_from_async_stream {}", client.addr, e);
|
log::error!("client {} retrieve_from_async_stream \"{}\"", client.addr, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(e) => {
|
||||||
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
|
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
|
||||||
log::debug!("client {} last_activity elapsed", client.addr);
|
log::debug!("client {} last_activity elapsed {e}", client.addr);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
client.buf.clear();
|
|
||||||
client.buf.extend_from_slice(&packet.data);
|
|
||||||
*/
|
|
||||||
|
|
||||||
//*
|
|
||||||
let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), reader.read(&mut len_buf)).await {
|
|
||||||
Ok(ret) => ret,
|
|
||||||
Err(_e) => {
|
|
||||||
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
|
|
||||||
log::debug!("client {} last_activity elapsed", client.addr);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let n = result?;
|
|
||||||
if n == 0 {
|
|
||||||
// Connection closed
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if n < UDPGW_LENGTH_FIELD_SIZE {
|
|
||||||
log::error!("client {} received Packet Length field error", client.addr);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
|
|
||||||
if packet_len > udp_mtu {
|
|
||||||
log::error!("client {} received packet too long", client.addr);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
log::trace!("client {} recvied packet len {}", client.addr, packet_len);
|
|
||||||
client.buf.clear();
|
|
||||||
let mut left_len: usize = packet_len as usize;
|
|
||||||
while left_len > 0 {
|
|
||||||
let len = reader.read(&mut buf[..left_len]).await?;
|
|
||||||
if len == 0 {
|
|
||||||
break 'out;
|
|
||||||
}
|
|
||||||
client.buf.extend_from_slice(&buf[..len]);
|
|
||||||
left_len -= len;
|
|
||||||
}
|
|
||||||
// */
|
|
||||||
client.last_activity = std::time::Instant::now();
|
client.last_activity = std::time::Instant::now();
|
||||||
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
|
|
||||||
if let Ok((udpdata, flags, conn_id, reqaddr)) = ret {
|
let flags = packet.header.flags;
|
||||||
|
let conn_id = packet.header.conn_id;
|
||||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||||
log::debug!("client {} send keepalive", client.addr);
|
log::trace!("client {} send keepalive", client.addr);
|
||||||
|
// 2. if keepalive packet, do nothing, send keepalive response to client
|
||||||
send_keepalive_response(tx.clone(), conn_id).await;
|
send_keepalive_response(tx.clone(), conn_id).await;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
log::debug!(
|
log::trace!("client {} received udp data {}", client.addr, packet);
|
||||||
"client {} received udp data,flags:{},conn_id:{},addr:{:?},data len:{}",
|
|
||||||
client.addr,
|
// 3. process client udpgw packet in a new task
|
||||||
flags,
|
let tx = tx.clone();
|
||||||
conn_id,
|
|
||||||
reqaddr,
|
|
||||||
udpdata.len()
|
|
||||||
);
|
|
||||||
let mut req = UdpRequest {
|
|
||||||
server_addr: reqaddr,
|
|
||||||
conn_id,
|
|
||||||
flags,
|
|
||||||
data: udpdata.to_vec(),
|
|
||||||
};
|
|
||||||
let tx1 = tx.clone();
|
|
||||||
let tx2 = tx.clone();
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
|
if let Err(e) = process_udp(client.addr, udp_mtu, udp_timeout, tx.clone(), packet).await {
|
||||||
send_error(tx2, &mut req).await;
|
send_error(tx, conn_id).await;
|
||||||
log::error!("client {} process_udp {}", client.addr, e);
|
log::error!("client {} process udp function {}", client.addr, e);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
} else {
|
|
||||||
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn write_to_client(addr: SocketAddr, mut writer: WriteHalf<'_>, mut rx: Receiver<Vec<u8>>) -> std::io::Result<()> {
|
async fn write_to_client(addr: SocketAddr, mut writer: WriteHalf<'_>, mut rx: Receiver<Packet>) -> std::io::Result<()> {
|
||||||
loop {
|
loop {
|
||||||
let Some(udp_response) = rx.recv().await else {
|
use std::io::{Error, ErrorKind::BrokenPipe};
|
||||||
log::trace!("client {} channel closed", addr);
|
let packet = rx.recv().await.ok_or(Error::new(BrokenPipe, "recv error"))?;
|
||||||
break;
|
log::trace!("send response to client {} with {}", addr, packet);
|
||||||
};
|
let data: Vec<u8> = packet.into();
|
||||||
if udp_response.is_empty() {
|
let _r = writer.write(&data).await?;
|
||||||
log::trace!("client {} channel recv 0", addr);
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
log::trace!("send response to client {} len {}", addr, udp_response.len());
|
|
||||||
let _r = writer.write(&udp_response).await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
@ -354,6 +174,7 @@ async fn main() -> Result<()> {
|
||||||
|
|
||||||
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
|
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
|
||||||
|
|
||||||
|
log::info!("{} {} starting...", module_path!(), env!("CARGO_PKG_VERSION"));
|
||||||
log::info!("UDP Gateway Server running at {}", args.listen_addr);
|
log::info!("UDP Gateway Server running at {}", args.listen_addr);
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
|
@ -377,7 +198,7 @@ async fn main() -> Result<()> {
|
||||||
log::info!("client {} connected", addr);
|
log::info!("client {} connected", addr);
|
||||||
let params = args.clone();
|
let params = args.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let (tx, rx) = mpsc::channel::<Vec<u8>>(100);
|
let (tx, rx) = mpsc::channel::<Packet>(100);
|
||||||
let (tcp_read_stream, tcp_write_stream) = tcp_stream.split();
|
let (tcp_read_stream, tcp_write_stream) = tcp_stream.split();
|
||||||
let res = tokio::select! {
|
let res = tokio::select! {
|
||||||
v = process_client_udp_req(¶ms, tx, client, tcp_read_stream) => v,
|
v = process_client_udp_req(¶ms, tx, client, tcp_read_stream) => v,
|
||||||
|
|
34
src/lib.rs
34
src/lib.rs
|
@ -28,7 +28,7 @@ pub use tokio_util::sync::CancellationToken;
|
||||||
use tproxy_config::is_private_ip;
|
use tproxy_config::is_private_ip;
|
||||||
use udp_stream::UdpStream;
|
use udp_stream::UdpStream;
|
||||||
#[cfg(feature = "udpgw")]
|
#[cfg(feature = "udpgw")]
|
||||||
use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME};
|
use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS};
|
||||||
|
|
||||||
pub use {
|
pub use {
|
||||||
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
|
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
|
||||||
|
@ -244,7 +244,7 @@ where
|
||||||
log::info!("UDPGW enabled");
|
log::info!("UDPGW enabled");
|
||||||
let client = Arc::new(UdpGwClient::new(
|
let client = Arc::new(UdpGwClient::new(
|
||||||
mtu,
|
mtu,
|
||||||
args.udpgw_max_connections.unwrap_or(100),
|
args.udpgw_max_connections.unwrap_or(UDPGW_MAX_CONNECTIONS),
|
||||||
UDPGW_KEEPALIVE_TIME,
|
UDPGW_KEEPALIVE_TIME,
|
||||||
args.udp_timeout,
|
args.udp_timeout,
|
||||||
*addr,
|
*addr,
|
||||||
|
@ -502,7 +502,7 @@ async fn handle_udp_gateway_session(
|
||||||
if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
|
if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
|
||||||
return Err(format!("udpgw connection error: {}", e).into());
|
return Err(format!("udpgw connection error: {}", e).into());
|
||||||
}
|
}
|
||||||
UdpGwClientStream::new(udp_mtu, tcp_server_stream)
|
UdpGwClientStream::new(tcp_server_stream)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -521,26 +521,29 @@ async fn handle_udp_gateway_session(
|
||||||
return Err("get writer failed".into());
|
return Err("get writer failed".into());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut tmp_buf = vec![0; udp_mtu.into()];
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut writer) => {
|
len = udp_stack.read(&mut tmp_buf) => {
|
||||||
let read_len;
|
let read_len = match len {
|
||||||
match len {
|
Ok(0) => {
|
||||||
Ok(n) => {
|
|
||||||
if n == 0 {
|
|
||||||
log::info!("[UdpGw] Ending {} <> {}", &tcp_local_addr, udp_dst);
|
log::info!("[UdpGw] Ending {} <> {}", &tcp_local_addr, udp_dst);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
read_len = n;
|
Ok(n) => n,
|
||||||
crate::traffic_status::traffic_status_update(n, 0)?;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
log::info!("[UdpGw] Ending {} <> {} with recv_udp_packet {}", &tcp_local_addr, udp_dst, e);
|
log::info!("[UdpGw] Ending {} <> {} with recv_udp_packet {}", &tcp_local_addr, udp_dst, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
crate::traffic_status::traffic_status_update(read_len, 0)?;
|
||||||
let new_id = stream.new_id();
|
let new_id = stream.new_id();
|
||||||
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_dst, domain_name.as_ref(), new_id, &mut writer).await {
|
let remote_addr = match domain_name {
|
||||||
|
Some(ref d) => socks5_impl::protocol::Address::from((d.clone(), udp_dst.port())),
|
||||||
|
None => udp_dst.into(),
|
||||||
|
};
|
||||||
|
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], &remote_addr, new_id, &mut writer).await {
|
||||||
log::info!("[UdpGw] Ending {} <> {} with send_udpgw_packet {}", &tcp_local_addr, udp_dst, e);
|
log::info!("[UdpGw] Ending {} <> {} with send_udpgw_packet {}", &tcp_local_addr, udp_dst, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -568,9 +571,10 @@ async fn handle_udp_gateway_session(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
UdpGwResponse::Data(data) => {
|
UdpGwResponse::Data(data) => {
|
||||||
|
use socks5_impl::protocol::StreamOperation;
|
||||||
let len = data.len();
|
let len = data.len();
|
||||||
log::debug!("[UdpGw] {} <- {} receive len {}", &tcp_local_addr, udp_dst, len);
|
log::debug!("[UdpGw] {} <- {} receive len {}", &tcp_local_addr, udp_dst, len);
|
||||||
if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await {
|
if let Err(e) = udp_stack.write_all(&data.data).await {
|
||||||
log::error!("[UdpGw] Ending {} <> {} with send_udp_packet {}", &tcp_local_addr, udp_dst, e);
|
log::error!("[UdpGw] Ending {} <> {} with send_udp_packet {}", &tcp_local_addr, udp_dst, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -588,7 +592,7 @@ async fn handle_udp_gateway_session(
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stream.is_closed() {
|
if !stream.is_closed() {
|
||||||
udpgw_client.release_server_connection_with_stream(stream, reader, writer).await;
|
udpgw_client.release_server_connection_full(stream, reader, writer).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
529
src/udpgw.rs
529
src/udpgw.rs
|
@ -1,12 +1,6 @@
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use ipstack::stream::IpStackUdpStream;
|
use socks5_impl::protocol::{Address, AsyncStreamOperation, BufMut, StreamOperation};
|
||||||
use socks5_impl::protocol::{AsyncStreamOperation, BufMut, StreamOperation};
|
use std::{collections::VecDeque, hash::Hash, net::SocketAddr, sync::atomic::Ordering::Relaxed};
|
||||||
use std::{
|
|
||||||
collections::VecDeque,
|
|
||||||
hash::Hash,
|
|
||||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
|
|
||||||
sync::atomic::Ordering::Relaxed,
|
|
||||||
};
|
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncReadExt, AsyncWriteExt},
|
io::{AsyncReadExt, AsyncWriteExt},
|
||||||
net::{
|
net::{
|
||||||
|
@ -17,25 +11,65 @@ use tokio::{
|
||||||
time::{sleep, Duration},
|
time::{sleep, Duration},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const UDPGW_MAX_CONNECTIONS: usize = 100;
|
pub(crate) const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::<u16>();
|
||||||
pub const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
|
pub(crate) const UDPGW_MAX_CONNECTIONS: u16 = 100;
|
||||||
pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01;
|
pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
|
||||||
pub const UDPGW_FLAG_IPV4: u8 = 0x00;
|
|
||||||
pub const UDPGW_FLAG_IPV6: u8 = 0x08;
|
|
||||||
pub const UDPGW_FLAG_DOMAIN: u8 = 0x10;
|
|
||||||
pub const UDPGW_FLAG_ERR: u8 = 0x20;
|
|
||||||
|
|
||||||
pub const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::<u16>();
|
pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01;
|
||||||
|
pub const UDPGW_FLAG_ERR: u8 = 0x20;
|
||||||
|
pub const UDPGW_FLAG_DATA: u8 = 0x02;
|
||||||
|
|
||||||
static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
|
static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
|
||||||
|
|
||||||
|
/// UDP Gateway Packet Format
|
||||||
|
///
|
||||||
|
/// The format is referenced from SOCKS5 packet format, with additional flags and connection ID fields.
|
||||||
|
///
|
||||||
|
/// `LEN`: This field is indicated the length of the packet, not including the length field itself.
|
||||||
|
///
|
||||||
|
/// `FLAGS`: This field is used to indicate the packet type. The flags are defined as follows:
|
||||||
|
/// - `0x01`: Keepalive packet without address and data
|
||||||
|
/// - `0x20`: Error packet without address and data
|
||||||
|
/// - `0x02`: Data packet with address and data
|
||||||
|
///
|
||||||
|
/// `CONN_ID`: This field is used to indicate the unique connection ID for the packet.
|
||||||
|
///
|
||||||
|
/// `ATYP` & `DST.ADDR` & `DST.PORT`: This fields are used to indicate the remote address and port.
|
||||||
|
/// It can be either an IPv4 address, an IPv6 address, or a domain name, depending on the `ATYP` field.
|
||||||
|
/// The address format directly uses the address format of the [SOCKS5](https://datatracker.ietf.org/doc/html/rfc1928#section-4) protocol.
|
||||||
|
/// - `ATYP`: Address Type, 1 byte, indicating the type of address ( 0x01-IPv4, 0x04-IPv6, or 0x03-domain name )
|
||||||
|
/// - `DST.ADDR`: Destination Address. If `ATYP` is 0x01 or 0x04, it is 4 or 16 bytes of IP address;
|
||||||
|
/// If `ATYP` is 0x03, it is a domain name, `DST.ADDR` is a variable length field,
|
||||||
|
/// it begins with a 1-byte length field and then the domain name without null-termination,
|
||||||
|
/// since the length field is 1 byte, the maximum length of the domain name is 255 bytes.
|
||||||
|
/// - `DST.PORT`: Destination Port, 2 bytes, the port number of the destination address.
|
||||||
|
///
|
||||||
|
/// `DATA`: The data field, a variable length field, the length is determined by the `LEN` field.
|
||||||
|
///
|
||||||
|
/// All the digits fields are in big-endian byte order.
|
||||||
|
///
|
||||||
|
/// ```plain
|
||||||
|
/// +-----+ +-------+---------+ +------+----------+----------+ +----------+
|
||||||
|
/// | LEN | | FLAGS | CONN_ID | | ATYP | DST.ADDR | DST.PORT | | DATA |
|
||||||
|
/// +-----+ +-------+---------+ +------+----------+----------+ +----------+
|
||||||
|
/// | 2 | | 1 | 2 | | 1 | Variable | 2 | | Variable |
|
||||||
|
/// +-----+ +-------+---------+ +------+----------+----------+ +----------+
|
||||||
|
/// ```
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
pub struct Packet {
|
pub struct Packet {
|
||||||
pub length: u16,
|
|
||||||
pub header: UdpgwHeader,
|
pub header: UdpgwHeader,
|
||||||
|
pub address: Option<Address>,
|
||||||
pub data: Vec<u8>,
|
pub data: Vec<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Packet {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let addr = self.address.as_ref().map_or("None".to_string(), |addr| addr.to_string());
|
||||||
|
let len = self.data.len();
|
||||||
|
write!(f, "Packet {{ {}, address: {}, payload length: {} }}", self.header, addr, len)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<Packet> for Vec<u8> {
|
impl From<Packet> for Vec<u8> {
|
||||||
fn from(packet: Packet) -> Vec<u8> {
|
fn from(packet: Packet) -> Vec<u8> {
|
||||||
(&packet).into()
|
(&packet).into()
|
||||||
|
@ -57,20 +91,56 @@ impl TryFrom<&[u8]> for Packet {
|
||||||
if value.len() < UDPGW_LENGTH_FIELD_SIZE {
|
if value.len() < UDPGW_LENGTH_FIELD_SIZE {
|
||||||
return Err(std::io::ErrorKind::InvalidData.into());
|
return Err(std::io::ErrorKind::InvalidData.into());
|
||||||
}
|
}
|
||||||
let length = u16::from_le_bytes([value[0], value[1]]);
|
let mut iter = std::io::Cursor::new(value);
|
||||||
|
use tokio_util::bytes::Buf;
|
||||||
|
let length = iter.get_u16();
|
||||||
if value.len() < length as usize + UDPGW_LENGTH_FIELD_SIZE {
|
if value.len() < length as usize + UDPGW_LENGTH_FIELD_SIZE {
|
||||||
return Err(std::io::ErrorKind::InvalidData.into());
|
return Err(std::io::ErrorKind::InvalidData.into());
|
||||||
}
|
}
|
||||||
let header = UdpgwHeader::try_from(&value[UDPGW_LENGTH_FIELD_SIZE..])?;
|
let header = UdpgwHeader::retrieve_from_stream(&mut iter)?;
|
||||||
let data = value[UDPGW_LENGTH_FIELD_SIZE + header.len()..].to_vec();
|
let address = if header.flags & UDPGW_FLAG_DATA != 0 {
|
||||||
Ok(Packet::new(header, data))
|
Some(Address::retrieve_from_stream(&mut iter)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
Ok(Packet::new(header, address, iter.chunk()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Packet {
|
impl Packet {
|
||||||
pub fn new(header: UdpgwHeader, data: Vec<u8>) -> Self {
|
pub fn new(header: UdpgwHeader, address: Option<Address>, data: &[u8]) -> Self {
|
||||||
let length = (header.len() + data.len()) as u16;
|
let data = data.to_vec();
|
||||||
Packet { length, header, data }
|
Packet { header, address, data }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_keepalive_packet(conn_id: u16) -> Self {
|
||||||
|
Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), None, &[])
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_error_packet(conn_id: u16) -> Self {
|
||||||
|
Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, conn_id), None, &[])
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_packet_from_address(conn_id: u16, remote_addr: &Address, data: &[u8]) -> std::io::Result<Self> {
|
||||||
|
use socks5_impl::protocol::Address::{DomainAddress, SocketAddress};
|
||||||
|
let packet = match remote_addr {
|
||||||
|
SocketAddress(addr) => Packet::build_ip_packet(conn_id, *addr, data),
|
||||||
|
DomainAddress(domain, port) => Packet::build_domain_packet(conn_id, *port, domain, data)?,
|
||||||
|
};
|
||||||
|
Ok(packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_ip_packet(conn_id: u16, remote_addr: SocketAddr, data: &[u8]) -> Self {
|
||||||
|
let addr: Address = remote_addr.into();
|
||||||
|
Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_domain_packet(conn_id: u16, port: u16, domain: &str, data: &[u8]) -> std::io::Result<Self> {
|
||||||
|
if domain.len() > 255 {
|
||||||
|
return Err(std::io::ErrorKind::InvalidInput.into());
|
||||||
|
}
|
||||||
|
let addr = Address::from((domain, port));
|
||||||
|
Ok(Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,23 +152,30 @@ impl StreamOperation for Packet {
|
||||||
{
|
{
|
||||||
let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE];
|
let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE];
|
||||||
stream.read_exact(&mut buf)?;
|
stream.read_exact(&mut buf)?;
|
||||||
let length = u16::from_le_bytes(buf);
|
let length = u16::from_be_bytes(buf) as usize;
|
||||||
let mut buf = [0; UdpgwHeader::static_len()];
|
let header = UdpgwHeader::retrieve_from_stream(stream)?;
|
||||||
stream.read_exact(&mut buf)?;
|
let address = if header.flags & UDPGW_FLAG_DATA != 0 {
|
||||||
let header = UdpgwHeader::try_from(&buf[..])?;
|
Some(Address::retrieve_from_stream(stream)?)
|
||||||
let mut data = vec![0; length as usize - header.len()];
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut data = vec![0; length - header.len() - address.as_ref().map_or(0, |addr| addr.len())];
|
||||||
stream.read_exact(&mut data)?;
|
stream.read_exact(&mut data)?;
|
||||||
Ok(Packet::new(header, data))
|
Ok(Packet::new(header, address, &data))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
|
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
|
||||||
buf.put_u16_le(self.length);
|
let len = self.len() - UDPGW_LENGTH_FIELD_SIZE;
|
||||||
|
buf.put_u16(len as u16);
|
||||||
self.header.write_to_buf(buf);
|
self.header.write_to_buf(buf);
|
||||||
|
if let Some(addr) = &self.address {
|
||||||
|
addr.write_to_buf(buf);
|
||||||
|
}
|
||||||
buf.put_slice(&self.data);
|
buf.put_slice(&self.data);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn len(&self) -> usize {
|
fn len(&self) -> usize {
|
||||||
UDPGW_LENGTH_FIELD_SIZE + self.header.len() + self.data.len()
|
UDPGW_LENGTH_FIELD_SIZE + self.header.len() + self.address.as_ref().map_or(0, |addr| addr.len()) + self.data.len()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,22 +188,32 @@ impl AsyncStreamOperation for Packet {
|
||||||
{
|
{
|
||||||
let mut buf = [0; 2];
|
let mut buf = [0; 2];
|
||||||
r.read_exact(&mut buf).await?;
|
r.read_exact(&mut buf).await?;
|
||||||
let length = u16::from_le_bytes(buf);
|
let length = u16::from_be_bytes(buf) as usize;
|
||||||
let header = UdpgwHeader::retrieve_from_async_stream(r).await?;
|
let header = UdpgwHeader::retrieve_from_async_stream(r).await?;
|
||||||
let mut data = vec![0; length as usize - header.len()];
|
let address = if header.flags & UDPGW_FLAG_DATA != 0 {
|
||||||
|
Some(Address::retrieve_from_async_stream(r).await?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut data = vec![0; length - header.len() - address.as_ref().map_or(0, |addr| addr.len())];
|
||||||
r.read_exact(&mut data).await?;
|
r.read_exact(&mut data).await?;
|
||||||
Ok(Packet::new(header, data))
|
Ok(Packet::new(header, address, &data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
#[repr(C)]
|
|
||||||
#[repr(packed(1))]
|
|
||||||
pub struct UdpgwHeader {
|
pub struct UdpgwHeader {
|
||||||
pub flags: u8,
|
pub flags: u8,
|
||||||
pub conn_id: u16,
|
pub conn_id: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for UdpgwHeader {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let id = self.conn_id;
|
||||||
|
write!(f, "flags: 0x{:02x}, conn_id: {}", self.flags, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl StreamOperation for UdpgwHeader {
|
impl StreamOperation for UdpgwHeader {
|
||||||
fn retrieve_from_stream<R>(stream: &mut R) -> std::io::Result<Self>
|
fn retrieve_from_stream<R>(stream: &mut R) -> std::io::Result<Self>
|
||||||
where
|
where
|
||||||
|
@ -167,7 +254,7 @@ impl UdpgwHeader {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const fn static_len() -> usize {
|
pub const fn static_len() -> usize {
|
||||||
std::mem::size_of::<UdpgwHeader>()
|
std::mem::size_of::<u8>() + std::mem::size_of::<u16>()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,10 +265,8 @@ impl TryFrom<&[u8]> for UdpgwHeader {
|
||||||
if value.len() < UdpgwHeader::static_len() {
|
if value.len() < UdpgwHeader::static_len() {
|
||||||
return Err(std::io::ErrorKind::InvalidData.into());
|
return Err(std::io::ErrorKind::InvalidData.into());
|
||||||
}
|
}
|
||||||
Ok(UdpgwHeader {
|
let conn_id = u16::from_be_bytes([value[1], value[2]]);
|
||||||
flags: value[0],
|
Ok(UdpgwHeader { flags: value[0], conn_id })
|
||||||
conn_id: u16::from_le_bytes([value[1], value[2]]),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -189,137 +274,25 @@ impl From<&UdpgwHeader> for Vec<u8> {
|
||||||
fn from(header: &UdpgwHeader) -> Vec<u8> {
|
fn from(header: &UdpgwHeader) -> Vec<u8> {
|
||||||
let mut bytes = vec![0; header.len()];
|
let mut bytes = vec![0; header.len()];
|
||||||
bytes[0] = header.flags;
|
bytes[0] = header.flags;
|
||||||
bytes[1..3].copy_from_slice(&header.conn_id.to_le_bytes());
|
bytes[1..3].copy_from_slice(&header.conn_id.to_be_bytes());
|
||||||
bytes
|
bytes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::len_without_is_empty)]
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
|
|
||||||
pub struct BinSocketAddr(SocketAddr);
|
|
||||||
|
|
||||||
impl BinSocketAddr {
|
|
||||||
pub fn len(&self) -> usize {
|
|
||||||
match self.0 {
|
|
||||||
SocketAddr::V4(_) => Self::static_len(false),
|
|
||||||
SocketAddr::V6(_) => Self::static_len(true),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn static_len(is_ipv6: bool) -> usize {
|
|
||||||
if is_ipv6 {
|
|
||||||
std::mem::size_of::<Ipv6Addr>() + std::mem::size_of::<u16>()
|
|
||||||
} else {
|
|
||||||
std::mem::size_of::<Ipv4Addr>() + std::mem::size_of::<u16>()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<&BinSocketAddr> for Vec<u8> {
|
|
||||||
fn from(addr: &BinSocketAddr) -> Vec<u8> {
|
|
||||||
socket_addr_to_binary(&addr.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<BinSocketAddr> for Vec<u8> {
|
|
||||||
fn from(addr: BinSocketAddr) -> Vec<u8> {
|
|
||||||
socket_addr_to_binary(&addr.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<&[u8]> for BinSocketAddr {
|
|
||||||
type Error = std::io::Error;
|
|
||||||
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
|
|
||||||
Ok(BinSocketAddr(binary_to_socket_addr(value)?))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<SocketAddr> for BinSocketAddr {
|
|
||||||
fn from(addr: SocketAddr) -> Self {
|
|
||||||
BinSocketAddr(addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<BinSocketAddr> for SocketAddr {
|
|
||||||
fn from(addr: BinSocketAddr) -> Self {
|
|
||||||
addr.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn socket_addr_to_binary(addr: &SocketAddr) -> Vec<u8> {
|
|
||||||
match addr {
|
|
||||||
SocketAddr::V4(addr_v4) => {
|
|
||||||
let mut bytes = vec![0; std::mem::size_of::<SocketAddrV4>()];
|
|
||||||
bytes[0..4].copy_from_slice(&addr_v4.ip().octets());
|
|
||||||
bytes[4..6].copy_from_slice(&addr_v4.port().to_be_bytes());
|
|
||||||
bytes
|
|
||||||
}
|
|
||||||
SocketAddr::V6(addr_v6) => {
|
|
||||||
let mut bytes = vec![0; std::mem::size_of::<Ipv6Addr>() + std::mem::size_of::<u16>()];
|
|
||||||
bytes[0..16].copy_from_slice(&addr_v6.ip().octets());
|
|
||||||
bytes[16..18].copy_from_slice(&addr_v6.port().to_be_bytes());
|
|
||||||
bytes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn binary_to_socket_addr(bytes: &[u8]) -> std::io::Result<SocketAddr> {
|
|
||||||
if bytes.len() == std::mem::size_of::<SocketAddrV4>() {
|
|
||||||
let ip = Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]);
|
|
||||||
let port = u16::from_be_bytes([bytes[4], bytes[5]]);
|
|
||||||
Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
|
|
||||||
} else if bytes.len() == std::mem::size_of::<Ipv6Addr>() + std::mem::size_of::<u16>() {
|
|
||||||
let mut ip = [0; 16];
|
|
||||||
ip.copy_from_slice(&bytes[0..16]);
|
|
||||||
let port = u16::from_be_bytes([bytes[16], bytes[17]]);
|
|
||||||
Ok(SocketAddr::V6(SocketAddrV6::new(ip.into(), port, 0, 0)))
|
|
||||||
} else {
|
|
||||||
Err(std::io::ErrorKind::InvalidData.into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct UdpGwData<'a> {
|
pub(crate) enum UdpGwResponse {
|
||||||
flags: u8,
|
|
||||||
conn_id: u16,
|
|
||||||
remote_addr: SocketAddr,
|
|
||||||
udpdata: &'a [u8],
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> UdpGwData<'a> {
|
|
||||||
pub fn len(&self) -> usize {
|
|
||||||
self.udpdata.len()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub(crate) enum UdpGwResponse<'a> {
|
|
||||||
KeepAlive,
|
KeepAlive,
|
||||||
Error,
|
Error,
|
||||||
TcpClose,
|
TcpClose,
|
||||||
Data(UdpGwData<'a>),
|
Data(Packet),
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub(crate) struct UdpGwClientStreamWriter {
|
|
||||||
inner: OwnedWriteHalf,
|
|
||||||
tmp_buf: Vec<u8>,
|
|
||||||
send_buf: Vec<u8>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub(crate) struct UdpGwClientStreamReader {
|
|
||||||
inner: OwnedReadHalf,
|
|
||||||
recv_buf: Vec<u8>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct UdpGwClientStream {
|
pub(crate) struct UdpGwClientStream {
|
||||||
local_addr: String,
|
local_addr: String,
|
||||||
writer: Option<UdpGwClientStreamWriter>,
|
writer: Option<OwnedWriteHalf>,
|
||||||
reader: Option<UdpGwClientStreamReader>,
|
reader: Option<OwnedReadHalf>,
|
||||||
conn_id: u16,
|
conn_id: u16,
|
||||||
closed: bool,
|
closed: bool,
|
||||||
last_activity: std::time::Instant,
|
last_activity: std::time::Instant,
|
||||||
|
@ -335,19 +308,20 @@ impl UdpGwClientStream {
|
||||||
pub fn close(&mut self) {
|
pub fn close(&mut self) {
|
||||||
self.closed = true;
|
self.closed = true;
|
||||||
}
|
}
|
||||||
pub fn get_reader(&mut self) -> Option<UdpGwClientStreamReader> {
|
|
||||||
|
pub fn get_reader(&mut self) -> Option<OwnedReadHalf> {
|
||||||
self.reader.take()
|
self.reader.take()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_reader(&mut self, mut reader: Option<UdpGwClientStreamReader>) {
|
pub fn set_reader(&mut self, reader: Option<OwnedReadHalf>) {
|
||||||
self.reader = reader.take();
|
self.reader = reader;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_writer(&mut self, mut writer: Option<UdpGwClientStreamWriter>) {
|
pub fn set_writer(&mut self, writer: Option<OwnedWriteHalf>) {
|
||||||
self.writer = writer.take();
|
self.writer = writer;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_writer(&mut self) -> Option<UdpGwClientStreamWriter> {
|
pub fn get_writer(&mut self) -> Option<OwnedWriteHalf> {
|
||||||
self.writer.take()
|
self.writer.take()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -371,21 +345,13 @@ impl UdpGwClientStream {
|
||||||
self.conn_id += 1;
|
self.conn_id += 1;
|
||||||
self.conn_id
|
self.conn_id
|
||||||
}
|
}
|
||||||
pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self {
|
|
||||||
let local_addr = tcp_server_stream
|
pub fn new(tcp_server_stream: TcpStream) -> Self {
|
||||||
.local_addr()
|
let default = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
|
||||||
.unwrap_or_else(|_| "0.0.0.0:0".parse::<SocketAddr>().unwrap())
|
let local_addr = tcp_server_stream.local_addr().unwrap_or(default).to_string();
|
||||||
.to_string();
|
|
||||||
let (rx, tx) = tcp_server_stream.into_split();
|
let (rx, tx) = tcp_server_stream.into_split();
|
||||||
let writer = UdpGwClientStreamWriter {
|
let writer = tx;
|
||||||
inner: tx,
|
let reader = rx;
|
||||||
tmp_buf: vec![0; udp_mtu.into()],
|
|
||||||
send_buf: vec![0; udp_mtu.into()],
|
|
||||||
};
|
|
||||||
let reader = UdpGwClientStreamReader {
|
|
||||||
inner: rx,
|
|
||||||
recv_buf: vec![0; udp_mtu.into()],
|
|
||||||
};
|
|
||||||
TCP_COUNTER.fetch_add(1, Relaxed);
|
TCP_COUNTER.fetch_add(1, Relaxed);
|
||||||
UdpGwClientStream {
|
UdpGwClientStream {
|
||||||
local_addr,
|
local_addr,
|
||||||
|
@ -405,13 +371,11 @@ pub(crate) struct UdpGwClient {
|
||||||
udp_timeout: u64,
|
udp_timeout: u64,
|
||||||
keepalive_time: Duration,
|
keepalive_time: Duration,
|
||||||
server_addr: SocketAddr,
|
server_addr: SocketAddr,
|
||||||
keepalive_packet: Vec<u8>,
|
|
||||||
server_connections: Mutex<VecDeque<UdpGwClientStream>>,
|
server_connections: Mutex<VecDeque<UdpGwClientStream>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UdpGwClient {
|
impl UdpGwClient {
|
||||||
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self {
|
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self {
|
||||||
let keepalive_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0), vec![]).into();
|
|
||||||
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
|
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
|
||||||
UdpGwClient {
|
UdpGwClient {
|
||||||
udp_mtu,
|
udp_mtu,
|
||||||
|
@ -419,7 +383,6 @@ impl UdpGwClient {
|
||||||
udp_timeout,
|
udp_timeout,
|
||||||
server_addr,
|
server_addr,
|
||||||
keepalive_time,
|
keepalive_time,
|
||||||
keepalive_packet,
|
|
||||||
server_connections,
|
server_connections,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -446,11 +409,11 @@ impl UdpGwClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn release_server_connection_with_stream(
|
pub(crate) async fn release_server_connection_full(
|
||||||
&self,
|
&self,
|
||||||
mut stream: UdpGwClientStream,
|
mut stream: UdpGwClientStream,
|
||||||
reader: UdpGwClientStreamReader,
|
reader: OwnedReadHalf,
|
||||||
writer: UdpGwClientStreamWriter,
|
writer: OwnedWriteHalf,
|
||||||
) {
|
) {
|
||||||
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
if self.server_connections.lock().await.len() < self.max_connections as usize {
|
||||||
stream.set_reader(Some(reader));
|
stream.set_reader(Some(reader));
|
||||||
|
@ -480,16 +443,17 @@ impl UdpGwClient {
|
||||||
let Some(mut stream_writer) = stream.get_writer() else {
|
let Some(mut stream_writer) = stream.get_writer() else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
let local_addr = stream_writer.inner.local_addr();
|
let local_addr = stream_writer.local_addr();
|
||||||
log::debug!("{:?}:{} send keepalive", local_addr, stream.id());
|
log::debug!("{:?}:{} send keepalive", local_addr, stream.id());
|
||||||
if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await {
|
let keepalive_packet: Vec<u8> = Packet::build_keepalive_packet(stream.id()).into();
|
||||||
|
if let Err(e) = stream_writer.write_all(&keepalive_packet).await {
|
||||||
log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e);
|
log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e);
|
||||||
} else {
|
continue;
|
||||||
|
}
|
||||||
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
|
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
|
||||||
Ok(UdpGwResponse::KeepAlive) => {
|
Ok(UdpGwResponse::KeepAlive) => {
|
||||||
stream.update_activity();
|
stream.update_activity();
|
||||||
self.release_server_connection_with_stream(stream, stream_reader, stream_writer)
|
self.release_server_connection_full(stream, stream_reader, stream_writer).await;
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v),
|
Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v),
|
||||||
Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e),
|
Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e),
|
||||||
|
@ -497,84 +461,21 @@ impl UdpGwClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Parses the UDP response data.
|
/// Parses the UDP response data.
|
||||||
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
|
pub(crate) fn parse_udp_response(udp_mtu: u16, data: &[u8]) -> Result<UdpGwResponse> {
|
||||||
let data = &stream.recv_buf;
|
let packet = Packet::try_from(data)?;
|
||||||
let header_len = UdpgwHeader::static_len();
|
let flags = packet.header.flags;
|
||||||
if data_len < header_len {
|
|
||||||
return Err("Invalid udpgw data".into());
|
|
||||||
}
|
|
||||||
let header_bytes = &data[..header_len];
|
|
||||||
let header = UdpgwHeader {
|
|
||||||
flags: header_bytes[0],
|
|
||||||
conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
|
|
||||||
};
|
|
||||||
|
|
||||||
let flags = header.flags;
|
|
||||||
let conn_id = header.conn_id;
|
|
||||||
|
|
||||||
let ip_data = &data[header_len..];
|
|
||||||
let mut data_len = data_len - header_len;
|
|
||||||
|
|
||||||
if flags & UDPGW_FLAG_ERR != 0 {
|
if flags & UDPGW_FLAG_ERR != 0 {
|
||||||
return Ok(UdpGwResponse::Error);
|
return Ok(UdpGwResponse::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||||
return Ok(UdpGwResponse::KeepAlive);
|
return Ok(UdpGwResponse::KeepAlive);
|
||||||
}
|
}
|
||||||
|
if packet.data.len() > udp_mtu as usize {
|
||||||
if flags & UDPGW_FLAG_IPV6 != 0 {
|
|
||||||
let ipv6_addr_len = BinSocketAddr::static_len(true);
|
|
||||||
if data_len < ipv6_addr_len {
|
|
||||||
return Err("ipv6 Invalid UDP data".into());
|
|
||||||
}
|
|
||||||
let addr_ipv6 = BinSocketAddr::try_from(&ip_data[..ipv6_addr_len])?;
|
|
||||||
data_len -= ipv6_addr_len;
|
|
||||||
|
|
||||||
if data_len > udp_mtu as usize {
|
|
||||||
return Err("too much data".into());
|
return Err("too much data".into());
|
||||||
}
|
}
|
||||||
return Ok(UdpGwResponse::Data(UdpGwData {
|
Ok(UdpGwResponse::Data(packet))
|
||||||
flags,
|
|
||||||
conn_id,
|
|
||||||
remote_addr: addr_ipv6.into(),
|
|
||||||
udpdata: &ip_data[ipv6_addr_len..(data_len + ipv6_addr_len)],
|
|
||||||
}));
|
|
||||||
} else {
|
|
||||||
let ipv4_addr_len = BinSocketAddr::static_len(false);
|
|
||||||
if data_len < ipv4_addr_len {
|
|
||||||
return Err("ipv4 Invalid UDP data".into());
|
|
||||||
}
|
|
||||||
let addr_ipv4 = BinSocketAddr::try_from(&ip_data[..ipv4_addr_len])?;
|
|
||||||
data_len -= ipv4_addr_len;
|
|
||||||
|
|
||||||
if data_len > udp_mtu as usize {
|
|
||||||
return Err("too much data".into());
|
|
||||||
}
|
|
||||||
return Ok(UdpGwResponse::Data(UdpGwData {
|
|
||||||
flags,
|
|
||||||
conn_id,
|
|
||||||
remote_addr: addr_ipv4.into(),
|
|
||||||
udpdata: &ip_data[ipv4_addr_len..(data_len + ipv4_addr_len)],
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn recv_udp_packet(
|
|
||||||
udp_stack: &mut IpStackUdpStream,
|
|
||||||
stream: &mut UdpGwClientStreamWriter,
|
|
||||||
) -> std::result::Result<usize, std::io::Error> {
|
|
||||||
udp_stack.read(&mut stream.tmp_buf).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn send_udp_packet<'a>(
|
|
||||||
packet: UdpGwData<'a>,
|
|
||||||
udp_stack: &mut IpStackUdpStream,
|
|
||||||
) -> std::result::Result<(), std::io::Error> {
|
|
||||||
udp_stack.write_all(packet.udpdata).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Receives a UDP gateway packet.
|
/// Receives a UDP gateway packet.
|
||||||
|
@ -588,37 +489,15 @@ impl UdpGwClient {
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// - `Result<UdpGwResponse>`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs.
|
/// - `Result<UdpGwResponse>`: Returns a result type containing the parsed UDP gateway response, or an error if one occurs.
|
||||||
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
|
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, udp_timeout: u64, stream: &mut OwnedReadHalf) -> Result<UdpGwResponse> {
|
||||||
let result = tokio::time::timeout(
|
let mut data = vec![0; udp_mtu.into()];
|
||||||
tokio::time::Duration::from_secs(udp_timeout + 2),
|
let data_len = tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout + 2), stream.read(&mut data))
|
||||||
stream.inner.read(&mut stream.recv_buf[..2]),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map_err(std::io::Error::from)?;
|
.map_err(std::io::Error::from)??;
|
||||||
let n = result?;
|
if data_len == 0 {
|
||||||
if n == 0 {
|
|
||||||
return Ok(UdpGwResponse::TcpClose);
|
return Ok(UdpGwResponse::TcpClose);
|
||||||
}
|
}
|
||||||
if n < UDPGW_LENGTH_FIELD_SIZE {
|
UdpGwClient::parse_udp_response(udp_mtu, &data[..data_len])
|
||||||
return Err("received Packet Length field error".into());
|
|
||||||
}
|
|
||||||
let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]);
|
|
||||||
if packet_len > udp_mtu {
|
|
||||||
return Err("packet too long".into());
|
|
||||||
}
|
|
||||||
let mut left_len: usize = packet_len as usize;
|
|
||||||
let mut recv_len = 0;
|
|
||||||
while left_len > 0 {
|
|
||||||
let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await else {
|
|
||||||
return Ok(UdpGwResponse::TcpClose);
|
|
||||||
};
|
|
||||||
if len == 0 {
|
|
||||||
return Ok(UdpGwResponse::TcpClose);
|
|
||||||
}
|
|
||||||
recv_len += len;
|
|
||||||
left_len -= len;
|
|
||||||
}
|
|
||||||
UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, stream)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sends a UDP gateway packet.
|
/// Sends a UDP gateway packet.
|
||||||
|
@ -629,9 +508,8 @@ impl UdpGwClient {
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `ipv6_enabled` - Whether IPv6 is enabled
|
/// * `ipv6_enabled` - Whether IPv6 is enabled
|
||||||
/// * `len` - Length of the data packet
|
/// * `data` - The data packet
|
||||||
/// * `remote_addr` - Remote address
|
/// * `remote_addr` - Remote address
|
||||||
/// * `domain` - Target domain (optional)
|
|
||||||
/// * `conn_id` - Connection ID
|
/// * `conn_id` - Connection ID
|
||||||
/// * `stream` - UDP gateway client writer stream
|
/// * `stream` - UDP gateway client writer stream
|
||||||
///
|
///
|
||||||
|
@ -640,80 +518,17 @@ impl UdpGwClient {
|
||||||
/// Returns `Ok(())` if the packet is sent successfully, otherwise returns an error.
|
/// Returns `Ok(())` if the packet is sent successfully, otherwise returns an error.
|
||||||
pub(crate) async fn send_udpgw_packet(
|
pub(crate) async fn send_udpgw_packet(
|
||||||
ipv6_enabled: bool,
|
ipv6_enabled: bool,
|
||||||
len: usize,
|
data: &[u8],
|
||||||
remote_addr: SocketAddr,
|
remote_addr: &socks5_impl::protocol::Address,
|
||||||
domain: Option<&String>,
|
|
||||||
conn_id: u16,
|
conn_id: u16,
|
||||||
stream: &mut UdpGwClientStreamWriter,
|
stream: &mut OwnedWriteHalf,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
stream.send_buf.clear();
|
if !ipv6_enabled && remote_addr.get_type() == socks5_impl::protocol::AddressType::IPv6 {
|
||||||
let data = &stream.tmp_buf;
|
|
||||||
let mut pack_len = UdpgwHeader::static_len() + len;
|
|
||||||
let packet = &mut stream.send_buf;
|
|
||||||
match domain {
|
|
||||||
Some(domain) => {
|
|
||||||
let addr_port = remote_addr.port();
|
|
||||||
let domain_len = domain.len();
|
|
||||||
if domain_len > 255 {
|
|
||||||
return Err("InvalidDomain".into());
|
|
||||||
}
|
|
||||||
pack_len += UDPGW_LENGTH_FIELD_SIZE;
|
|
||||||
pack_len += domain_len + 1;
|
|
||||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
|
||||||
packet.extend_from_slice(&[UDPGW_FLAG_DOMAIN]);
|
|
||||||
packet.extend_from_slice(&conn_id.to_le_bytes());
|
|
||||||
packet.extend_from_slice(&addr_port.to_be_bytes());
|
|
||||||
packet.extend_from_slice(domain.as_bytes());
|
|
||||||
packet.push(0);
|
|
||||||
packet.extend_from_slice(&data[..len]);
|
|
||||||
}
|
|
||||||
None => match remote_addr {
|
|
||||||
SocketAddr::V4(_) => {
|
|
||||||
let addr_ipv4 = BinSocketAddr::from(remote_addr);
|
|
||||||
pack_len += addr_ipv4.len();
|
|
||||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
|
||||||
packet.extend_from_slice(&[UDPGW_FLAG_IPV4]);
|
|
||||||
packet.extend_from_slice(&conn_id.to_le_bytes());
|
|
||||||
let addr_ipv4_bin: Vec<u8> = addr_ipv4.into();
|
|
||||||
packet.extend_from_slice(&addr_ipv4_bin);
|
|
||||||
packet.extend_from_slice(&data[..len]);
|
|
||||||
}
|
|
||||||
SocketAddr::V6(_) => {
|
|
||||||
if !ipv6_enabled {
|
|
||||||
return Err("ipv6 not support".into());
|
return Err("ipv6 not support".into());
|
||||||
}
|
}
|
||||||
let addr_ipv6 = BinSocketAddr::from(remote_addr);
|
let out_data: Vec<u8> = Packet::build_packet_from_address(conn_id, remote_addr, data)?.into();
|
||||||
pack_len += addr_ipv6.len();
|
stream.write_all(&out_data).await?;
|
||||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
|
||||||
packet.extend_from_slice(&[UDPGW_FLAG_IPV6]);
|
|
||||||
packet.extend_from_slice(&conn_id.to_le_bytes());
|
|
||||||
let addr_ipv6_bin: Vec<u8> = addr_ipv6.into();
|
|
||||||
packet.extend_from_slice(&addr_ipv6_bin);
|
|
||||||
packet.extend_from_slice(&data[..len]);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.inner.write_all(packet).await?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::{Packet, UdpgwHeader};
|
|
||||||
use socks5_impl::protocol::StreamOperation;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_udpgw_header() {
|
|
||||||
let header = UdpgwHeader::new(0x01, 0x1234);
|
|
||||||
let mut bytes: Vec<u8> = vec![];
|
|
||||||
let packet = Packet::new(header, vec![]);
|
|
||||||
packet.write_to_buf(&mut bytes);
|
|
||||||
|
|
||||||
let header2 = Packet::retrieve_from_stream(&mut bytes.as_slice()).unwrap().header;
|
|
||||||
|
|
||||||
assert_eq!(header, header2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue