refactor udpgw with Packet struct

This commit is contained in:
ssrlive 2024-10-23 03:50:00 +08:00
parent f049e4a998
commit ec79d54a9e
4 changed files with 241 additions and 127 deletions

View file

@ -117,7 +117,7 @@ pub struct Args {
#[arg(long, value_name = "IP:PORT")]
pub udpgw_server: Option<SocketAddr>,
/// Max udpgw connections
/// Max udpgw connections, default value is 100
#[cfg(feature = "udpgw")]
#[arg(long, value_name = "number")]
pub udpgw_max_connections: Option<u16>,

View file

@ -18,7 +18,7 @@ pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::D
struct UdpRequest {
flags: u8,
server_addr: SocketAddr,
conid: u16,
conn_id: u16,
data: Vec<u8>,
}
@ -41,26 +41,26 @@ impl Client {
#[derive(Debug, Clone, clap::Parser)]
pub struct UdpGwArgs {
/// UDP gateway listen address
#[arg(short, long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")]
pub listen_addr: SocketAddr,
/// UDP mtu
#[arg(long, value_name = "udp mtu", default_value = "10240")]
#[arg(short = 'm', long, value_name = "udp mtu", default_value = "10240")]
pub udp_mtu: u16,
/// Verbosity level
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
pub verbosity: ArgVerbosity,
/// UDP timeout in seconds
#[arg(short = 't', long, value_name = "seconds", default_value = "3")]
pub udp_timeout: u64,
/// Daemonize for unix family or run as Windows service
#[cfg(unix)]
#[arg(long)]
pub daemonize: bool,
/// UDP timeout in seconds
#[arg(long, value_name = "seconds", default_value = "3")]
pub udp_timeout: u64,
/// UDP gateway listen address
#[arg(long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")]
pub listen_addr: SocketAddr,
/// Verbosity level
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
pub verbosity: ArgVerbosity,
}
impl UdpGwArgs {
@ -70,42 +70,43 @@ impl UdpGwArgs {
Self::parse()
}
}
async fn send_error(tx: Sender<Vec<u8>>, con: &mut UdpRequest) {
let error_packet = UdpgwHeader::new(UDPGW_FLAG_ERR, con.conid).into();
let error_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, con.conn_id), vec![]).into();
if let Err(e) = tx.send(error_packet).await {
log::error!("send error response error {:?}", e);
}
}
async fn send_keepalive_response(tx: Sender<Vec<u8>>, conid: u16) {
let keepalive_packet = UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conid).into();
async fn send_keepalive_response(tx: Sender<Vec<u8>>, conn_id: u16) {
let keepalive_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), vec![]).into();
if let Err(e) = tx.send(keepalive_packet).await {
log::error!("send keepalive response error {:?}", e);
}
}
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
if data_len < std::mem::size_of::<UdpgwHeader>() {
if data_len < UdpgwHeader::len() {
return Err("Invalid udpgw data".into());
}
let header_bytes = &data[..std::mem::size_of::<UdpgwHeader>()];
let header_bytes = &data[..UdpgwHeader::len()];
let header = UdpgwHeader {
flags: header_bytes[0],
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
};
let flags = header.flags;
let conid = header.conid;
let conn_id = header.conn_id;
// keepalive
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
return Ok((data, flags, conn_id, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
}
let ip_data = &data[std::mem::size_of::<UdpgwHeader>()..];
let mut data_len = data_len - std::mem::size_of::<UdpgwHeader>();
let ip_data = &data[UdpgwHeader::len()..];
let mut data_len = data_len - UdpgwHeader::len();
// port_len + min(ipv4/ipv6/(domain_len + 1))
if data_len < std::mem::size_of::<u16>() + 2 {
if data_len < UDPGW_LENGTH_FIELD_SIZE + 2 {
return Err("Invalid udpgw data".into());
}
if flags & UDPGW_FLAG_DOMAIN != 0 {
@ -128,7 +129,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
return Err("too much data".into());
}
let udpdata = &ip_data[(2 + domain.len() + 1)..];
Ok((udpdata, flags, conid, target))
Ok((udpdata, flags, conn_id, target))
}
Err(_) => Err("Invalid UTF-8 sequence in domain".into()),
}
@ -152,7 +153,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
return Ok((
&ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
flags,
conid,
conn_id,
UdpgwAddr::IPV6(addr_ipv6).into(),
));
} else {
@ -173,7 +174,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
return Ok((
&ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
flags,
conid,
conn_id,
UdpgwAddr::IPV4(addr_ipv4).into(),
));
}
@ -195,13 +196,13 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
Ok(ret) => {
let (len, _addr) = ret?;
let mut packet = vec![];
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
let mut pack_len = UdpgwHeader::len() + len;
match con.server_addr.into() {
UdpgwAddr::IPV4(addr_ipv4) => {
pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conid.to_le_bytes());
packet.extend_from_slice(&con.conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
packet.extend_from_slice(&con.data[..len]);
@ -210,7 +211,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conid.to_le_bytes());
packet.extend_from_slice(&con.conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv6.addr_ip);
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
packet.extend_from_slice(&con.data[..len]);
@ -230,7 +231,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> {
let mut client = client;
let mut buf = vec![0; args.udp_mtu as usize];
let mut len_buf = [0; std::mem::size_of::<PackLenHeader>()];
let mut len_buf = [0; UDPGW_LENGTH_FIELD_SIZE];
let udp_mtu = args.udp_mtu;
let udp_timeout = args.udp_timeout;
@ -250,8 +251,8 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: C
// Connection closed
break;
}
if n < std::mem::size_of::<PackLenHeader>() {
log::error!("client {} received PackLenHeader error", client.addr);
if n < UDPGW_LENGTH_FIELD_SIZE {
log::error!("client {} received Packet Length field error", client.addr);
break;
}
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
@ -272,23 +273,23 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: C
}
client.last_activity = std::time::Instant::now();
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
if let Ok((udpdata, flags, conn_id, reqaddr)) = ret {
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
log::debug!("client {} send keepalive", client.addr);
send_keepalive_response(tx.clone(), conid).await;
send_keepalive_response(tx.clone(), conn_id).await;
continue;
}
log::debug!(
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
"client {} received udp data,flags:{},conn_id:{},addr:{:?},data len:{}",
client.addr,
flags,
conid,
conn_id,
reqaddr,
udpdata.len()
);
let mut req = UdpRequest {
server_addr: reqaddr,
conid,
conn_id,
flags,
data: udpdata.to_vec(),
};

View file

@ -513,12 +513,8 @@ async fn handle_udp_gateway_session(
let tcp_local_addr = server_stream.local_addr().clone();
match domain_name {
Some(ref d) => {
log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d);
}
None => {
log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr);
}
Some(ref d) => log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d),
None => log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr),
}
let Some(mut stream_reader) = server_stream.get_reader() else {
@ -547,8 +543,8 @@ async fn handle_udp_gateway_session(
break;
}
}
let newid = server_stream.newid();
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), newid, &mut stream_writer).await {
let new_id = server_stream.new_id();
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), new_id, &mut stream_writer).await {
log::info!("Ending {} <- {} with send_udpgw_packet {}", udpinfo, &tcp_local_addr, e);
break;
}

View file

@ -1,47 +1,171 @@
use crate::error::Result;
use ipstack::stream::IpStackUdpStream;
use socks5_impl::protocol::{AsyncStreamOperation, BufMut, StreamOperation};
use std::collections::VecDeque;
use std::hash::Hash;
use std::mem;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::atomic::Ordering::Relaxed;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::{sleep, Duration};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
},
sync::Mutex,
time::{sleep, Duration},
};
pub const UDPGW_MAX_CONNECTIONS: usize = 100;
pub const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01;
pub const UDPGW_FLAG_IPV4: u8 = 0x00;
pub const UDPGW_FLAG_IPV6: u8 = 0x08;
pub const UDPGW_FLAG_DOMAIN: u8 = 0x10;
pub const UDPGW_FLAG_ERR: u8 = 0x20;
pub const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::<u16>();
static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[repr(C)]
#[repr(packed(1))]
pub struct PackLenHeader {
packet_len: u16,
pub struct Packet {
pub length: u16,
pub header: UdpgwHeader,
pub data: Vec<u8>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
impl From<Packet> for Vec<u8> {
fn from(packet: Packet) -> Vec<u8> {
(&packet).into()
}
}
impl From<&Packet> for Vec<u8> {
fn from(packet: &Packet) -> Vec<u8> {
let mut bytes = vec![0; packet.len()];
packet.write_to_buf(&mut bytes);
bytes
}
}
impl TryFrom<&[u8]> for Packet {
type Error = std::io::Error;
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
if value.len() < UDPGW_LENGTH_FIELD_SIZE {
return Err(std::io::ErrorKind::InvalidData.into());
}
let length = u16::from_le_bytes([value[0], value[1]]);
if value.len() < length as usize + UDPGW_LENGTH_FIELD_SIZE {
return Err(std::io::ErrorKind::InvalidData.into());
}
let header = UdpgwHeader::try_from(&value[UDPGW_LENGTH_FIELD_SIZE..])?;
let data = value[UDPGW_LENGTH_FIELD_SIZE + header.len()..].to_vec();
Ok(Packet::new(header, data))
}
}
impl Packet {
pub fn new(header: UdpgwHeader, data: Vec<u8>) -> Self {
let length = (header.len() + data.len()) as u16;
Packet { length, header, data }
}
}
impl StreamOperation for Packet {
fn retrieve_from_stream<R>(stream: &mut R) -> std::io::Result<Self>
where
R: std::io::Read,
Self: Sized,
{
let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE];
stream.read_exact(&mut buf)?;
let length = u16::from_le_bytes(buf);
let mut buf = [0; UdpgwHeader::len()];
stream.read_exact(&mut buf)?;
let header = UdpgwHeader::try_from(&buf[..])?;
let mut data = vec![0; length as usize - header.len()];
stream.read_exact(&mut data)?;
Ok(Packet::new(header, data))
}
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
buf.put_u16_le(self.length);
self.header.write_to_buf(buf);
buf.put_slice(&self.data);
}
fn len(&self) -> usize {
UDPGW_LENGTH_FIELD_SIZE + self.header.len() + self.data.len()
}
}
#[async_trait::async_trait]
impl AsyncStreamOperation for Packet {
async fn retrieve_from_async_stream<R>(r: &mut R) -> std::io::Result<Self>
where
R: tokio::io::AsyncRead + Unpin + Send,
Self: Sized,
{
let mut buf = [0; 2];
r.read_exact(&mut buf).await?;
let length = u16::from_le_bytes(buf);
let header = UdpgwHeader::retrieve_from_async_stream(r).await?;
let mut data = vec![0; length as usize - header.len()];
r.read_exact(&mut data).await?;
Ok(Packet::new(header, data))
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C)]
#[repr(packed(1))]
pub struct UdpgwHeader {
pub flags: u8,
pub conid: u16,
pub conn_id: u16,
}
impl StreamOperation for UdpgwHeader {
fn retrieve_from_stream<R>(stream: &mut R) -> std::io::Result<Self>
where
R: std::io::Read,
Self: Sized,
{
let mut buf = [0; UdpgwHeader::len()];
stream.read_exact(&mut buf)?;
UdpgwHeader::try_from(&buf[..])
}
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
let bytes: Vec<u8> = self.into();
buf.put_slice(&bytes);
}
fn len(&self) -> usize {
Self::len()
}
}
#[async_trait::async_trait]
impl AsyncStreamOperation for UdpgwHeader {
async fn retrieve_from_async_stream<R>(r: &mut R) -> std::io::Result<Self>
where
R: tokio::io::AsyncRead + Unpin + Send,
Self: Sized,
{
let mut buf = [0; UdpgwHeader::len()];
r.read_exact(&mut buf).await?;
UdpgwHeader::try_from(&buf[..])
}
}
impl UdpgwHeader {
pub fn new(flags: u8, conid: u16) -> Self {
UdpgwHeader { flags, conid }
pub fn new(flags: u8, conn_id: u16) -> Self {
UdpgwHeader { flags, conn_id }
}
pub const fn len() -> usize {
std::mem::size_of::<u16>() + std::mem::size_of::<UdpgwHeader>()
std::mem::size_of::<UdpgwHeader>()
}
}
@ -50,25 +174,20 @@ impl TryFrom<&[u8]> for UdpgwHeader {
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
if value.len() < UdpgwHeader::len() {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UdpgwHeader"));
}
let len = u16::from_le_bytes([value[0], value[1]]);
if len != std::mem::size_of::<UdpgwHeader>() as u16 {
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UdpgwHeader"));
return Err(std::io::ErrorKind::InvalidData.into());
}
Ok(UdpgwHeader {
flags: value[2],
conid: u16::from_le_bytes([value[3], value[4]]),
flags: value[0],
conn_id: u16::from_le_bytes([value[1], value[2]]),
})
}
}
impl From<UdpgwHeader> for Vec<u8> {
fn from(header: UdpgwHeader) -> Vec<u8> {
let mut bytes = vec![0; UdpgwHeader::len()];
bytes[0..2].copy_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
bytes[2] = header.flags;
bytes[3..5].copy_from_slice(&header.conid.to_le_bytes());
impl From<&UdpgwHeader> for Vec<u8> {
fn from(header: &UdpgwHeader) -> Vec<u8> {
let mut bytes = vec![0; header.len()];
bytes[0] = header.flags;
bytes[1..3].copy_from_slice(&header.conn_id.to_le_bytes());
bytes
}
}
@ -127,9 +246,10 @@ impl From<UdpgwAddr> for SocketAddr {
}
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) struct UdpGwData<'a> {
flags: u8,
conid: u16,
conn_id: u16,
remote_addr: SocketAddr,
udpdata: &'a [u8],
}
@ -141,6 +261,7 @@ impl<'a> UdpGwData<'a> {
}
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) enum UdpGwResponse<'a> {
KeepAlive,
Error,
@ -166,7 +287,7 @@ pub(crate) struct UdpGwClientStream {
local_addr: String,
writer: Option<UdpGwClientStreamWriter>,
reader: Option<UdpGwClientStreamReader>,
conid: u16,
conn_id: u16,
closed: bool,
last_activity: std::time::Instant,
}
@ -210,12 +331,12 @@ impl UdpGwClientStream {
}
pub fn id(&mut self) -> u16 {
self.conid
self.conn_id
}
pub fn newid(&mut self) -> u16 {
self.conid += 1;
self.conid
pub fn new_id(&mut self) -> u16 {
self.conn_id += 1;
self.conn_id
}
pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self {
let local_addr = tcp_server_stream
@ -239,7 +360,7 @@ impl UdpGwClientStream {
writer: Some(writer),
last_activity: std::time::Instant::now(),
closed: false,
conid: 0,
conn_id: 0,
}
}
}
@ -257,7 +378,7 @@ pub(crate) struct UdpGwClient {
impl UdpGwClient {
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self {
let keepalive_packet = UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0).into();
let keepalive_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0), vec![]).into();
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
UdpGwClient {
udp_mtu,
@ -326,14 +447,10 @@ impl UdpGwClient {
let Some(mut stream_writer) = stream.get_writer() else {
continue;
};
log::debug!("{:?}:{} send keepalive", stream_writer.inner.local_addr(), stream.id());
let local_addr = stream_writer.inner.local_addr();
log::debug!("{:?}:{} send keepalive", local_addr, stream.id());
if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await {
log::warn!(
"{:?}:{} send keepalive failed: {}",
stream_writer.inner.local_addr(),
stream.id(),
e
);
log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e);
} else {
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
Ok(UdpGwResponse::KeepAlive) => {
@ -341,10 +458,8 @@ impl UdpGwClient {
self.release_server_connection_with_stream(stream, stream_reader, stream_writer)
.await;
}
//shoud not receive other type
_ => {
log::warn!("{:?}:{} keepalive no response", stream_writer.inner.local_addr(), stream.id());
}
Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v),
Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e),
}
}
}
@ -354,20 +469,20 @@ impl UdpGwClient {
/// Parses the UDP response data.
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
let data = &stream.recv_buf;
if data_len < mem::size_of::<UdpgwHeader>() {
if data_len < UdpgwHeader::len() {
return Err("Invalid udpgw data".into());
}
let header_bytes = &data[..mem::size_of::<UdpgwHeader>()];
let header_bytes = &data[..UdpgwHeader::len()];
let header = UdpgwHeader {
flags: header_bytes[0],
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
};
let flags = header.flags;
let conid = header.conid;
let conn_id = header.conn_id;
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
let ip_data = &data[UdpgwHeader::len()..];
let mut data_len = data_len - UdpgwHeader::len();
if flags & UDPGW_FLAG_ERR != 0 {
return Ok(UdpGwResponse::Error);
@ -378,44 +493,44 @@ impl UdpGwClient {
}
if flags & UDPGW_FLAG_IPV6 != 0 {
if data_len < mem::size_of::<UdpgwAddrIpv6>() {
if data_len < std::mem::size_of::<UdpgwAddrIpv6>() {
return Err("ipv6 Invalid UDP data".into());
}
let addr_ipv6_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv6>()];
let addr_ipv6_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv6>()];
let addr_ipv6 = UdpgwAddrIpv6 {
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
};
data_len -= mem::size_of::<UdpgwAddrIpv6>();
data_len -= std::mem::size_of::<UdpgwAddrIpv6>();
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
return Ok(UdpGwResponse::Data(UdpGwData {
flags,
conid,
conn_id,
remote_addr: UdpgwAddr::IPV6(addr_ipv6).into(),
udpdata: &ip_data[mem::size_of::<UdpgwAddrIpv6>()..(data_len + mem::size_of::<UdpgwAddrIpv6>())],
udpdata: &ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
}));
} else {
if data_len < mem::size_of::<UdpgwAddrIpv4>() {
if data_len < std::mem::size_of::<UdpgwAddrIpv4>() {
return Err("ipv4 Invalid UDP data".into());
}
let addr_ipv4_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv4>()];
let addr_ipv4_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv4>()];
let addr_ipv4 = UdpgwAddrIpv4 {
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
};
data_len -= mem::size_of::<UdpgwAddrIpv4>();
data_len -= std::mem::size_of::<UdpgwAddrIpv4>();
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
return Ok(UdpGwResponse::Data(UdpGwData {
flags,
conid,
conn_id,
remote_addr: UdpgwAddr::IPV4(addr_ipv4).into(),
udpdata: &ip_data[mem::size_of::<UdpgwAddrIpv4>()..(data_len + mem::size_of::<UdpgwAddrIpv4>())],
udpdata: &ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
}));
}
}
@ -456,8 +571,8 @@ impl UdpGwClient {
if n == 0 {
return Ok(UdpGwResponse::TcpClose);
}
if n < std::mem::size_of::<PackLenHeader>() {
return Err("received PackLenHeader error".into());
if n < UDPGW_LENGTH_FIELD_SIZE {
return Err("received Packet Length field error".into());
}
let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]);
if packet_len > udp_mtu {
@ -489,7 +604,7 @@ impl UdpGwClient {
/// * `len` - Length of the data packet
/// * `remote_addr` - Remote address
/// * `domain` - Target domain (optional)
/// * `conid` - Connection ID
/// * `conn_id` - Connection ID
/// * `stream` - UDP gateway client writer stream
///
/// # Returns
@ -500,30 +615,28 @@ impl UdpGwClient {
len: usize,
remote_addr: SocketAddr,
domain: Option<&String>,
conid: u16,
conn_id: u16,
stream: &mut UdpGwClientStreamWriter,
) -> Result<()> {
stream.send_buf.clear();
let data = &stream.tmp_buf;
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
let mut pack_len = UdpgwHeader::len() + len;
let packet = &mut stream.send_buf;
let mut flags = 0;
match domain {
Some(domain) => {
let addr_port = match remote_addr.into() {
UdpgwAddr::IPV4(addr_ipv4) => addr_ipv4.addr_port,
UdpgwAddr::IPV6(addr_ipv6) => addr_ipv6.addr_port,
};
pack_len += std::mem::size_of::<u16>();
let domain_len = domain.len();
if domain_len > 255 {
return Err("InvalidDomain".into());
}
pack_len += UDPGW_LENGTH_FIELD_SIZE;
pack_len += domain_len + 1;
flags = UDPGW_FLAG_DOMAIN;
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[flags]);
packet.extend_from_slice(&conid.to_le_bytes());
packet.extend_from_slice(&[UDPGW_FLAG_DOMAIN]);
packet.extend_from_slice(&conn_id.to_le_bytes());
packet.extend_from_slice(&addr_port.to_be_bytes());
packet.extend_from_slice(domain.as_bytes());
packet.push(0);
@ -533,8 +646,8 @@ impl UdpGwClient {
UdpgwAddr::IPV4(addr_ipv4) => {
pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[flags]);
packet.extend_from_slice(&conid.to_le_bytes());
packet.extend_from_slice(&[UDPGW_FLAG_IPV4]);
packet.extend_from_slice(&conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
packet.extend_from_slice(&data[..len]);
@ -543,11 +656,10 @@ impl UdpGwClient {
if !ipv6_enabled {
return Err("ipv6 not support".into());
}
flags = UDPGW_FLAG_IPV6;
pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[flags]);
packet.extend_from_slice(&conid.to_le_bytes());
packet.extend_from_slice(&[UDPGW_FLAG_IPV6]);
packet.extend_from_slice(&conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv6.addr_ip);
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
packet.extend_from_slice(&data[..len]);
@ -563,13 +675,18 @@ impl UdpGwClient {
#[cfg(test)]
mod tests {
use super::UdpgwHeader;
use super::{Packet, UdpgwHeader};
use socks5_impl::protocol::StreamOperation;
#[test]
fn test_udpgw_header() {
let header = UdpgwHeader::new(0x01, 0x1234);
let bytes = Vec::from(header.clone());
let header2 = UdpgwHeader::try_from(&bytes[..]).unwrap();
let mut bytes: Vec<u8> = vec![];
let packet = Packet::new(header, vec![]);
packet.write_to_buf(&mut bytes);
let header2 = Packet::retrieve_from_stream(&mut bytes.as_slice()).unwrap().header;
assert_eq!(header, header2);
}
}