refactor udpgw with Packet struct

This commit is contained in:
ssrlive 2024-10-23 03:50:00 +08:00
parent f049e4a998
commit ec79d54a9e
4 changed files with 241 additions and 127 deletions

View file

@ -117,7 +117,7 @@ pub struct Args {
#[arg(long, value_name = "IP:PORT")] #[arg(long, value_name = "IP:PORT")]
pub udpgw_server: Option<SocketAddr>, pub udpgw_server: Option<SocketAddr>,
/// Max udpgw connections /// Max udpgw connections, default value is 100
#[cfg(feature = "udpgw")] #[cfg(feature = "udpgw")]
#[arg(long, value_name = "number")] #[arg(long, value_name = "number")]
pub udpgw_max_connections: Option<u16>, pub udpgw_max_connections: Option<u16>,

View file

@ -18,7 +18,7 @@ pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::D
struct UdpRequest { struct UdpRequest {
flags: u8, flags: u8,
server_addr: SocketAddr, server_addr: SocketAddr,
conid: u16, conn_id: u16,
data: Vec<u8>, data: Vec<u8>,
} }
@ -41,26 +41,26 @@ impl Client {
#[derive(Debug, Clone, clap::Parser)] #[derive(Debug, Clone, clap::Parser)]
pub struct UdpGwArgs { pub struct UdpGwArgs {
/// UDP gateway listen address
#[arg(short, long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")]
pub listen_addr: SocketAddr,
/// UDP mtu /// UDP mtu
#[arg(long, value_name = "udp mtu", default_value = "10240")] #[arg(short = 'm', long, value_name = "udp mtu", default_value = "10240")]
pub udp_mtu: u16, pub udp_mtu: u16,
/// Verbosity level /// UDP timeout in seconds
#[arg(short, long, value_name = "level", value_enum, default_value = "info")] #[arg(short = 't', long, value_name = "seconds", default_value = "3")]
pub verbosity: ArgVerbosity, pub udp_timeout: u64,
/// Daemonize for unix family or run as Windows service /// Daemonize for unix family or run as Windows service
#[cfg(unix)] #[cfg(unix)]
#[arg(long)] #[arg(long)]
pub daemonize: bool, pub daemonize: bool,
/// UDP timeout in seconds /// Verbosity level
#[arg(long, value_name = "seconds", default_value = "3")] #[arg(short, long, value_name = "level", value_enum, default_value = "info")]
pub udp_timeout: u64, pub verbosity: ArgVerbosity,
/// UDP gateway listen address
#[arg(long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")]
pub listen_addr: SocketAddr,
} }
impl UdpGwArgs { impl UdpGwArgs {
@ -70,42 +70,43 @@ impl UdpGwArgs {
Self::parse() Self::parse()
} }
} }
async fn send_error(tx: Sender<Vec<u8>>, con: &mut UdpRequest) { async fn send_error(tx: Sender<Vec<u8>>, con: &mut UdpRequest) {
let error_packet = UdpgwHeader::new(UDPGW_FLAG_ERR, con.conid).into(); let error_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, con.conn_id), vec![]).into();
if let Err(e) = tx.send(error_packet).await { if let Err(e) = tx.send(error_packet).await {
log::error!("send error response error {:?}", e); log::error!("send error response error {:?}", e);
} }
} }
async fn send_keepalive_response(tx: Sender<Vec<u8>>, conid: u16) { async fn send_keepalive_response(tx: Sender<Vec<u8>>, conn_id: u16) {
let keepalive_packet = UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conid).into(); let keepalive_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), vec![]).into();
if let Err(e) = tx.send(keepalive_packet).await { if let Err(e) = tx.send(keepalive_packet).await {
log::error!("send keepalive response error {:?}", e); log::error!("send keepalive response error {:?}", e);
} }
} }
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> { pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
if data_len < std::mem::size_of::<UdpgwHeader>() { if data_len < UdpgwHeader::len() {
return Err("Invalid udpgw data".into()); return Err("Invalid udpgw data".into());
} }
let header_bytes = &data[..std::mem::size_of::<UdpgwHeader>()]; let header_bytes = &data[..UdpgwHeader::len()];
let header = UdpgwHeader { let header = UdpgwHeader {
flags: header_bytes[0], flags: header_bytes[0],
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
}; };
let flags = header.flags; let flags = header.flags;
let conid = header.conid; let conn_id = header.conn_id;
// keepalive // keepalive
if flags & UDPGW_FLAG_KEEPALIVE != 0 { if flags & UDPGW_FLAG_KEEPALIVE != 0 {
return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); return Ok((data, flags, conn_id, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
} }
let ip_data = &data[std::mem::size_of::<UdpgwHeader>()..]; let ip_data = &data[UdpgwHeader::len()..];
let mut data_len = data_len - std::mem::size_of::<UdpgwHeader>(); let mut data_len = data_len - UdpgwHeader::len();
// port_len + min(ipv4/ipv6/(domain_len + 1)) // port_len + min(ipv4/ipv6/(domain_len + 1))
if data_len < std::mem::size_of::<u16>() + 2 { if data_len < UDPGW_LENGTH_FIELD_SIZE + 2 {
return Err("Invalid udpgw data".into()); return Err("Invalid udpgw data".into());
} }
if flags & UDPGW_FLAG_DOMAIN != 0 { if flags & UDPGW_FLAG_DOMAIN != 0 {
@ -128,7 +129,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
return Err("too much data".into()); return Err("too much data".into());
} }
let udpdata = &ip_data[(2 + domain.len() + 1)..]; let udpdata = &ip_data[(2 + domain.len() + 1)..];
Ok((udpdata, flags, conid, target)) Ok((udpdata, flags, conn_id, target))
} }
Err(_) => Err("Invalid UTF-8 sequence in domain".into()), Err(_) => Err("Invalid UTF-8 sequence in domain".into()),
} }
@ -152,7 +153,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
return Ok(( return Ok((
&ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())], &ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
flags, flags,
conid, conn_id,
UdpgwAddr::IPV6(addr_ipv6).into(), UdpgwAddr::IPV6(addr_ipv6).into(),
)); ));
} else { } else {
@ -173,7 +174,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
return Ok(( return Ok((
&ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())], &ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
flags, flags,
conid, conn_id,
UdpgwAddr::IPV4(addr_ipv4).into(), UdpgwAddr::IPV4(addr_ipv4).into(),
)); ));
} }
@ -195,13 +196,13 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
Ok(ret) => { Ok(ret) => {
let (len, _addr) = ret?; let (len, _addr) = ret?;
let mut packet = vec![]; let mut packet = vec![];
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len; let mut pack_len = UdpgwHeader::len() + len;
match con.server_addr.into() { match con.server_addr.into() {
UdpgwAddr::IPV4(addr_ipv4) => { UdpgwAddr::IPV4(addr_ipv4) => {
pack_len += std::mem::size_of::<UdpgwAddrIpv4>(); pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]); packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conid.to_le_bytes()); packet.extend_from_slice(&con.conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes()); packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes()); packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
packet.extend_from_slice(&con.data[..len]); packet.extend_from_slice(&con.data[..len]);
@ -210,7 +211,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
pack_len += std::mem::size_of::<UdpgwAddrIpv6>(); pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]); packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conid.to_le_bytes()); packet.extend_from_slice(&con.conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv6.addr_ip); packet.extend_from_slice(&addr_ipv6.addr_ip);
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes()); packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
packet.extend_from_slice(&con.data[..len]); packet.extend_from_slice(&con.data[..len]);
@ -230,7 +231,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> { async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> {
let mut client = client; let mut client = client;
let mut buf = vec![0; args.udp_mtu as usize]; let mut buf = vec![0; args.udp_mtu as usize];
let mut len_buf = [0; std::mem::size_of::<PackLenHeader>()]; let mut len_buf = [0; UDPGW_LENGTH_FIELD_SIZE];
let udp_mtu = args.udp_mtu; let udp_mtu = args.udp_mtu;
let udp_timeout = args.udp_timeout; let udp_timeout = args.udp_timeout;
@ -250,8 +251,8 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: C
// Connection closed // Connection closed
break; break;
} }
if n < std::mem::size_of::<PackLenHeader>() { if n < UDPGW_LENGTH_FIELD_SIZE {
log::error!("client {} received PackLenHeader error", client.addr); log::error!("client {} received Packet Length field error", client.addr);
break; break;
} }
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]); let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
@ -272,23 +273,23 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: C
} }
client.last_activity = std::time::Instant::now(); client.last_activity = std::time::Instant::now();
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf); let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
if let Ok((udpdata, flags, conid, reqaddr)) = ret { if let Ok((udpdata, flags, conn_id, reqaddr)) = ret {
if flags & UDPGW_FLAG_KEEPALIVE != 0 { if flags & UDPGW_FLAG_KEEPALIVE != 0 {
log::debug!("client {} send keepalive", client.addr); log::debug!("client {} send keepalive", client.addr);
send_keepalive_response(tx.clone(), conid).await; send_keepalive_response(tx.clone(), conn_id).await;
continue; continue;
} }
log::debug!( log::debug!(
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}", "client {} received udp data,flags:{},conn_id:{},addr:{:?},data len:{}",
client.addr, client.addr,
flags, flags,
conid, conn_id,
reqaddr, reqaddr,
udpdata.len() udpdata.len()
); );
let mut req = UdpRequest { let mut req = UdpRequest {
server_addr: reqaddr, server_addr: reqaddr,
conid, conn_id,
flags, flags,
data: udpdata.to_vec(), data: udpdata.to_vec(),
}; };

View file

@ -513,12 +513,8 @@ async fn handle_udp_gateway_session(
let tcp_local_addr = server_stream.local_addr().clone(); let tcp_local_addr = server_stream.local_addr().clone();
match domain_name { match domain_name {
Some(ref d) => { Some(ref d) => log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d),
log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d); None => log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr),
}
None => {
log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr);
}
} }
let Some(mut stream_reader) = server_stream.get_reader() else { let Some(mut stream_reader) = server_stream.get_reader() else {
@ -547,8 +543,8 @@ async fn handle_udp_gateway_session(
break; break;
} }
} }
let newid = server_stream.newid(); let new_id = server_stream.new_id();
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), newid, &mut stream_writer).await { if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), new_id, &mut stream_writer).await {
log::info!("Ending {} <- {} with send_udpgw_packet {}", udpinfo, &tcp_local_addr, e); log::info!("Ending {} <- {} with send_udpgw_packet {}", udpinfo, &tcp_local_addr, e);
break; break;
} }

View file

@ -1,47 +1,171 @@
use crate::error::Result; use crate::error::Result;
use ipstack::stream::IpStackUdpStream; use ipstack::stream::IpStackUdpStream;
use socks5_impl::protocol::{AsyncStreamOperation, BufMut, StreamOperation};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::hash::Hash; use std::hash::Hash;
use std::mem;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::atomic::Ordering::Relaxed; use std::sync::atomic::Ordering::Relaxed;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::{
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; io::{AsyncReadExt, AsyncWriteExt},
use tokio::net::TcpStream; net::{
use tokio::sync::Mutex; tcp::{OwnedReadHalf, OwnedWriteHalf},
use tokio::time::{sleep, Duration}; TcpStream,
},
sync::Mutex,
time::{sleep, Duration},
};
pub const UDPGW_MAX_CONNECTIONS: usize = 100; pub const UDPGW_MAX_CONNECTIONS: usize = 100;
pub const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10); pub const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01; pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01;
pub const UDPGW_FLAG_IPV4: u8 = 0x00;
pub const UDPGW_FLAG_IPV6: u8 = 0x08; pub const UDPGW_FLAG_IPV6: u8 = 0x08;
pub const UDPGW_FLAG_DOMAIN: u8 = 0x10; pub const UDPGW_FLAG_DOMAIN: u8 = 0x10;
pub const UDPGW_FLAG_ERR: u8 = 0x20; pub const UDPGW_FLAG_ERR: u8 = 0x20;
pub const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::<u16>();
static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[repr(C)] pub struct Packet {
#[repr(packed(1))] pub length: u16,
pub struct PackLenHeader { pub header: UdpgwHeader,
packet_len: u16, pub data: Vec<u8>,
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] impl From<Packet> for Vec<u8> {
fn from(packet: Packet) -> Vec<u8> {
(&packet).into()
}
}
impl From<&Packet> for Vec<u8> {
fn from(packet: &Packet) -> Vec<u8> {
let mut bytes = vec![0; packet.len()];
packet.write_to_buf(&mut bytes);
bytes
}
}
impl TryFrom<&[u8]> for Packet {
type Error = std::io::Error;
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
if value.len() < UDPGW_LENGTH_FIELD_SIZE {
return Err(std::io::ErrorKind::InvalidData.into());
}
let length = u16::from_le_bytes([value[0], value[1]]);
if value.len() < length as usize + UDPGW_LENGTH_FIELD_SIZE {
return Err(std::io::ErrorKind::InvalidData.into());
}
let header = UdpgwHeader::try_from(&value[UDPGW_LENGTH_FIELD_SIZE..])?;
let data = value[UDPGW_LENGTH_FIELD_SIZE + header.len()..].to_vec();
Ok(Packet::new(header, data))
}
}
impl Packet {
pub fn new(header: UdpgwHeader, data: Vec<u8>) -> Self {
let length = (header.len() + data.len()) as u16;
Packet { length, header, data }
}
}
impl StreamOperation for Packet {
fn retrieve_from_stream<R>(stream: &mut R) -> std::io::Result<Self>
where
R: std::io::Read,
Self: Sized,
{
let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE];
stream.read_exact(&mut buf)?;
let length = u16::from_le_bytes(buf);
let mut buf = [0; UdpgwHeader::len()];
stream.read_exact(&mut buf)?;
let header = UdpgwHeader::try_from(&buf[..])?;
let mut data = vec![0; length as usize - header.len()];
stream.read_exact(&mut data)?;
Ok(Packet::new(header, data))
}
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
buf.put_u16_le(self.length);
self.header.write_to_buf(buf);
buf.put_slice(&self.data);
}
fn len(&self) -> usize {
UDPGW_LENGTH_FIELD_SIZE + self.header.len() + self.data.len()
}
}
#[async_trait::async_trait]
impl AsyncStreamOperation for Packet {
async fn retrieve_from_async_stream<R>(r: &mut R) -> std::io::Result<Self>
where
R: tokio::io::AsyncRead + Unpin + Send,
Self: Sized,
{
let mut buf = [0; 2];
r.read_exact(&mut buf).await?;
let length = u16::from_le_bytes(buf);
let header = UdpgwHeader::retrieve_from_async_stream(r).await?;
let mut data = vec![0; length as usize - header.len()];
r.read_exact(&mut data).await?;
Ok(Packet::new(header, data))
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C)] #[repr(C)]
#[repr(packed(1))] #[repr(packed(1))]
pub struct UdpgwHeader { pub struct UdpgwHeader {
pub flags: u8, pub flags: u8,
pub conid: u16, pub conn_id: u16,
}
impl StreamOperation for UdpgwHeader {
fn retrieve_from_stream<R>(stream: &mut R) -> std::io::Result<Self>
where
R: std::io::Read,
Self: Sized,
{
let mut buf = [0; UdpgwHeader::len()];
stream.read_exact(&mut buf)?;
UdpgwHeader::try_from(&buf[..])
}
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
let bytes: Vec<u8> = self.into();
buf.put_slice(&bytes);
}
fn len(&self) -> usize {
Self::len()
}
}
#[async_trait::async_trait]
impl AsyncStreamOperation for UdpgwHeader {
async fn retrieve_from_async_stream<R>(r: &mut R) -> std::io::Result<Self>
where
R: tokio::io::AsyncRead + Unpin + Send,
Self: Sized,
{
let mut buf = [0; UdpgwHeader::len()];
r.read_exact(&mut buf).await?;
UdpgwHeader::try_from(&buf[..])
}
} }
impl UdpgwHeader { impl UdpgwHeader {
pub fn new(flags: u8, conid: u16) -> Self { pub fn new(flags: u8, conn_id: u16) -> Self {
UdpgwHeader { flags, conid } UdpgwHeader { flags, conn_id }
} }
pub const fn len() -> usize { pub const fn len() -> usize {
std::mem::size_of::<u16>() + std::mem::size_of::<UdpgwHeader>() std::mem::size_of::<UdpgwHeader>()
} }
} }
@ -50,25 +174,20 @@ impl TryFrom<&[u8]> for UdpgwHeader {
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> { fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
if value.len() < UdpgwHeader::len() { if value.len() < UdpgwHeader::len() {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UdpgwHeader")); return Err(std::io::ErrorKind::InvalidData.into());
}
let len = u16::from_le_bytes([value[0], value[1]]);
if len != std::mem::size_of::<UdpgwHeader>() as u16 {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UdpgwHeader"));
} }
Ok(UdpgwHeader { Ok(UdpgwHeader {
flags: value[2], flags: value[0],
conid: u16::from_le_bytes([value[3], value[4]]), conn_id: u16::from_le_bytes([value[1], value[2]]),
}) })
} }
} }
impl From<UdpgwHeader> for Vec<u8> { impl From<&UdpgwHeader> for Vec<u8> {
fn from(header: UdpgwHeader) -> Vec<u8> { fn from(header: &UdpgwHeader) -> Vec<u8> {
let mut bytes = vec![0; UdpgwHeader::len()]; let mut bytes = vec![0; header.len()];
bytes[0..2].copy_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes()); bytes[0] = header.flags;
bytes[2] = header.flags; bytes[1..3].copy_from_slice(&header.conn_id.to_le_bytes());
bytes[3..5].copy_from_slice(&header.conid.to_le_bytes());
bytes bytes
} }
} }
@ -127,9 +246,10 @@ impl From<UdpgwAddr> for SocketAddr {
} }
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug)]
pub(crate) struct UdpGwData<'a> { pub(crate) struct UdpGwData<'a> {
flags: u8, flags: u8,
conid: u16, conn_id: u16,
remote_addr: SocketAddr, remote_addr: SocketAddr,
udpdata: &'a [u8], udpdata: &'a [u8],
} }
@ -141,6 +261,7 @@ impl<'a> UdpGwData<'a> {
} }
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug)]
pub(crate) enum UdpGwResponse<'a> { pub(crate) enum UdpGwResponse<'a> {
KeepAlive, KeepAlive,
Error, Error,
@ -166,7 +287,7 @@ pub(crate) struct UdpGwClientStream {
local_addr: String, local_addr: String,
writer: Option<UdpGwClientStreamWriter>, writer: Option<UdpGwClientStreamWriter>,
reader: Option<UdpGwClientStreamReader>, reader: Option<UdpGwClientStreamReader>,
conid: u16, conn_id: u16,
closed: bool, closed: bool,
last_activity: std::time::Instant, last_activity: std::time::Instant,
} }
@ -210,12 +331,12 @@ impl UdpGwClientStream {
} }
pub fn id(&mut self) -> u16 { pub fn id(&mut self) -> u16 {
self.conid self.conn_id
} }
pub fn newid(&mut self) -> u16 { pub fn new_id(&mut self) -> u16 {
self.conid += 1; self.conn_id += 1;
self.conid self.conn_id
} }
pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self { pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self {
let local_addr = tcp_server_stream let local_addr = tcp_server_stream
@ -239,7 +360,7 @@ impl UdpGwClientStream {
writer: Some(writer), writer: Some(writer),
last_activity: std::time::Instant::now(), last_activity: std::time::Instant::now(),
closed: false, closed: false,
conid: 0, conn_id: 0,
} }
} }
} }
@ -257,7 +378,7 @@ pub(crate) struct UdpGwClient {
impl UdpGwClient { impl UdpGwClient {
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self { pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self {
let keepalive_packet = UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0).into(); let keepalive_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0), vec![]).into();
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize)); let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
UdpGwClient { UdpGwClient {
udp_mtu, udp_mtu,
@ -326,14 +447,10 @@ impl UdpGwClient {
let Some(mut stream_writer) = stream.get_writer() else { let Some(mut stream_writer) = stream.get_writer() else {
continue; continue;
}; };
log::debug!("{:?}:{} send keepalive", stream_writer.inner.local_addr(), stream.id()); let local_addr = stream_writer.inner.local_addr();
log::debug!("{:?}:{} send keepalive", local_addr, stream.id());
if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await { if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await {
log::warn!( log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e);
"{:?}:{} send keepalive failed: {}",
stream_writer.inner.local_addr(),
stream.id(),
e
);
} else { } else {
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
Ok(UdpGwResponse::KeepAlive) => { Ok(UdpGwResponse::KeepAlive) => {
@ -341,10 +458,8 @@ impl UdpGwClient {
self.release_server_connection_with_stream(stream, stream_reader, stream_writer) self.release_server_connection_with_stream(stream, stream_reader, stream_writer)
.await; .await;
} }
//shoud not receive other type Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v),
_ => { Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e),
log::warn!("{:?}:{} keepalive no response", stream_writer.inner.local_addr(), stream.id());
}
} }
} }
} }
@ -354,20 +469,20 @@ impl UdpGwClient {
/// Parses the UDP response data. /// Parses the UDP response data.
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> { pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
let data = &stream.recv_buf; let data = &stream.recv_buf;
if data_len < mem::size_of::<UdpgwHeader>() { if data_len < UdpgwHeader::len() {
return Err("Invalid udpgw data".into()); return Err("Invalid udpgw data".into());
} }
let header_bytes = &data[..mem::size_of::<UdpgwHeader>()]; let header_bytes = &data[..UdpgwHeader::len()];
let header = UdpgwHeader { let header = UdpgwHeader {
flags: header_bytes[0], flags: header_bytes[0],
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
}; };
let flags = header.flags; let flags = header.flags;
let conid = header.conid; let conn_id = header.conn_id;
let ip_data = &data[mem::size_of::<UdpgwHeader>()..]; let ip_data = &data[UdpgwHeader::len()..];
let mut data_len = data_len - mem::size_of::<UdpgwHeader>(); let mut data_len = data_len - UdpgwHeader::len();
if flags & UDPGW_FLAG_ERR != 0 { if flags & UDPGW_FLAG_ERR != 0 {
return Ok(UdpGwResponse::Error); return Ok(UdpGwResponse::Error);
@ -378,44 +493,44 @@ impl UdpGwClient {
} }
if flags & UDPGW_FLAG_IPV6 != 0 { if flags & UDPGW_FLAG_IPV6 != 0 {
if data_len < mem::size_of::<UdpgwAddrIpv6>() { if data_len < std::mem::size_of::<UdpgwAddrIpv6>() {
return Err("ipv6 Invalid UDP data".into()); return Err("ipv6 Invalid UDP data".into());
} }
let addr_ipv6_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv6>()]; let addr_ipv6_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv6>()];
let addr_ipv6 = UdpgwAddrIpv6 { let addr_ipv6 = UdpgwAddrIpv6 {
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?, addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]), addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
}; };
data_len -= mem::size_of::<UdpgwAddrIpv6>(); data_len -= std::mem::size_of::<UdpgwAddrIpv6>();
if data_len > udp_mtu as usize { if data_len > udp_mtu as usize {
return Err("too much data".into()); return Err("too much data".into());
} }
return Ok(UdpGwResponse::Data(UdpGwData { return Ok(UdpGwResponse::Data(UdpGwData {
flags, flags,
conid, conn_id,
remote_addr: UdpgwAddr::IPV6(addr_ipv6).into(), remote_addr: UdpgwAddr::IPV6(addr_ipv6).into(),
udpdata: &ip_data[mem::size_of::<UdpgwAddrIpv6>()..(data_len + mem::size_of::<UdpgwAddrIpv6>())], udpdata: &ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
})); }));
} else { } else {
if data_len < mem::size_of::<UdpgwAddrIpv4>() { if data_len < std::mem::size_of::<UdpgwAddrIpv4>() {
return Err("ipv4 Invalid UDP data".into()); return Err("ipv4 Invalid UDP data".into());
} }
let addr_ipv4_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv4>()]; let addr_ipv4_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv4>()];
let addr_ipv4 = UdpgwAddrIpv4 { let addr_ipv4 = UdpgwAddrIpv4 {
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]), addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]), addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
}; };
data_len -= mem::size_of::<UdpgwAddrIpv4>(); data_len -= std::mem::size_of::<UdpgwAddrIpv4>();
if data_len > udp_mtu as usize { if data_len > udp_mtu as usize {
return Err("too much data".into()); return Err("too much data".into());
} }
return Ok(UdpGwResponse::Data(UdpGwData { return Ok(UdpGwResponse::Data(UdpGwData {
flags, flags,
conid, conn_id,
remote_addr: UdpgwAddr::IPV4(addr_ipv4).into(), remote_addr: UdpgwAddr::IPV4(addr_ipv4).into(),
udpdata: &ip_data[mem::size_of::<UdpgwAddrIpv4>()..(data_len + mem::size_of::<UdpgwAddrIpv4>())], udpdata: &ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
})); }));
} }
} }
@ -456,8 +571,8 @@ impl UdpGwClient {
if n == 0 { if n == 0 {
return Ok(UdpGwResponse::TcpClose); return Ok(UdpGwResponse::TcpClose);
} }
if n < std::mem::size_of::<PackLenHeader>() { if n < UDPGW_LENGTH_FIELD_SIZE {
return Err("received PackLenHeader error".into()); return Err("received Packet Length field error".into());
} }
let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]); let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]);
if packet_len > udp_mtu { if packet_len > udp_mtu {
@ -489,7 +604,7 @@ impl UdpGwClient {
/// * `len` - Length of the data packet /// * `len` - Length of the data packet
/// * `remote_addr` - Remote address /// * `remote_addr` - Remote address
/// * `domain` - Target domain (optional) /// * `domain` - Target domain (optional)
/// * `conid` - Connection ID /// * `conn_id` - Connection ID
/// * `stream` - UDP gateway client writer stream /// * `stream` - UDP gateway client writer stream
/// ///
/// # Returns /// # Returns
@ -500,30 +615,28 @@ impl UdpGwClient {
len: usize, len: usize,
remote_addr: SocketAddr, remote_addr: SocketAddr,
domain: Option<&String>, domain: Option<&String>,
conid: u16, conn_id: u16,
stream: &mut UdpGwClientStreamWriter, stream: &mut UdpGwClientStreamWriter,
) -> Result<()> { ) -> Result<()> {
stream.send_buf.clear(); stream.send_buf.clear();
let data = &stream.tmp_buf; let data = &stream.tmp_buf;
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len; let mut pack_len = UdpgwHeader::len() + len;
let packet = &mut stream.send_buf; let packet = &mut stream.send_buf;
let mut flags = 0;
match domain { match domain {
Some(domain) => { Some(domain) => {
let addr_port = match remote_addr.into() { let addr_port = match remote_addr.into() {
UdpgwAddr::IPV4(addr_ipv4) => addr_ipv4.addr_port, UdpgwAddr::IPV4(addr_ipv4) => addr_ipv4.addr_port,
UdpgwAddr::IPV6(addr_ipv6) => addr_ipv6.addr_port, UdpgwAddr::IPV6(addr_ipv6) => addr_ipv6.addr_port,
}; };
pack_len += std::mem::size_of::<u16>();
let domain_len = domain.len(); let domain_len = domain.len();
if domain_len > 255 { if domain_len > 255 {
return Err("InvalidDomain".into()); return Err("InvalidDomain".into());
} }
pack_len += UDPGW_LENGTH_FIELD_SIZE;
pack_len += domain_len + 1; pack_len += domain_len + 1;
flags = UDPGW_FLAG_DOMAIN;
packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[flags]); packet.extend_from_slice(&[UDPGW_FLAG_DOMAIN]);
packet.extend_from_slice(&conid.to_le_bytes()); packet.extend_from_slice(&conn_id.to_le_bytes());
packet.extend_from_slice(&addr_port.to_be_bytes()); packet.extend_from_slice(&addr_port.to_be_bytes());
packet.extend_from_slice(domain.as_bytes()); packet.extend_from_slice(domain.as_bytes());
packet.push(0); packet.push(0);
@ -533,8 +646,8 @@ impl UdpGwClient {
UdpgwAddr::IPV4(addr_ipv4) => { UdpgwAddr::IPV4(addr_ipv4) => {
pack_len += std::mem::size_of::<UdpgwAddrIpv4>(); pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[flags]); packet.extend_from_slice(&[UDPGW_FLAG_IPV4]);
packet.extend_from_slice(&conid.to_le_bytes()); packet.extend_from_slice(&conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes()); packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes()); packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
packet.extend_from_slice(&data[..len]); packet.extend_from_slice(&data[..len]);
@ -543,11 +656,10 @@ impl UdpGwClient {
if !ipv6_enabled { if !ipv6_enabled {
return Err("ipv6 not support".into()); return Err("ipv6 not support".into());
} }
flags = UDPGW_FLAG_IPV6;
pack_len += std::mem::size_of::<UdpgwAddrIpv6>(); pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[flags]); packet.extend_from_slice(&[UDPGW_FLAG_IPV6]);
packet.extend_from_slice(&conid.to_le_bytes()); packet.extend_from_slice(&conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv6.addr_ip); packet.extend_from_slice(&addr_ipv6.addr_ip);
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes()); packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
packet.extend_from_slice(&data[..len]); packet.extend_from_slice(&data[..len]);
@ -563,13 +675,18 @@ impl UdpGwClient {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::UdpgwHeader; use super::{Packet, UdpgwHeader};
use socks5_impl::protocol::StreamOperation;
#[test] #[test]
fn test_udpgw_header() { fn test_udpgw_header() {
let header = UdpgwHeader::new(0x01, 0x1234); let header = UdpgwHeader::new(0x01, 0x1234);
let bytes = Vec::from(header.clone()); let mut bytes: Vec<u8> = vec![];
let header2 = UdpgwHeader::try_from(&bytes[..]).unwrap(); let packet = Packet::new(header, vec![]);
packet.write_to_buf(&mut bytes);
let header2 = Packet::retrieve_from_stream(&mut bytes.as_slice()).unwrap().header;
assert_eq!(header, header2); assert_eq!(header, header2);
} }
} }