diff --git a/src/args.rs b/src/args.rs index efcd2e7..4291189 100644 --- a/src/args.rs +++ b/src/args.rs @@ -117,7 +117,7 @@ pub struct Args { #[arg(long, value_name = "IP:PORT")] pub udpgw_server: Option, - /// Max udpgw connections + /// Max udpgw connections, default value is 100 #[cfg(feature = "udpgw")] #[arg(long, value_name = "number")] pub udpgw_max_connections: Option, diff --git a/src/bin/udpgw_server.rs b/src/bin/udpgw_server.rs index dc2da57..081e6be 100644 --- a/src/bin/udpgw_server.rs +++ b/src/bin/udpgw_server.rs @@ -18,7 +18,7 @@ pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::D struct UdpRequest { flags: u8, server_addr: SocketAddr, - conid: u16, + conn_id: u16, data: Vec, } @@ -41,26 +41,26 @@ impl Client { #[derive(Debug, Clone, clap::Parser)] pub struct UdpGwArgs { + /// UDP gateway listen address + #[arg(short, long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")] + pub listen_addr: SocketAddr, + /// UDP mtu - #[arg(long, value_name = "udp mtu", default_value = "10240")] + #[arg(short = 'm', long, value_name = "udp mtu", default_value = "10240")] pub udp_mtu: u16, - /// Verbosity level - #[arg(short, long, value_name = "level", value_enum, default_value = "info")] - pub verbosity: ArgVerbosity, + /// UDP timeout in seconds + #[arg(short = 't', long, value_name = "seconds", default_value = "3")] + pub udp_timeout: u64, /// Daemonize for unix family or run as Windows service #[cfg(unix)] #[arg(long)] pub daemonize: bool, - /// UDP timeout in seconds - #[arg(long, value_name = "seconds", default_value = "3")] - pub udp_timeout: u64, - - /// UDP gateway listen address - #[arg(long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")] - pub listen_addr: SocketAddr, + /// Verbosity level + #[arg(short, long, value_name = "level", value_enum, default_value = "info")] + pub verbosity: ArgVerbosity, } impl UdpGwArgs { @@ -70,42 +70,43 @@ impl UdpGwArgs { Self::parse() } } + async fn send_error(tx: Sender>, con: &mut UdpRequest) { - let error_packet = UdpgwHeader::new(UDPGW_FLAG_ERR, con.conid).into(); + let error_packet: Vec = Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, con.conn_id), vec![]).into(); if let Err(e) = tx.send(error_packet).await { log::error!("send error response error {:?}", e); } } -async fn send_keepalive_response(tx: Sender>, conid: u16) { - let keepalive_packet = UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conid).into(); +async fn send_keepalive_response(tx: Sender>, conn_id: u16) { + let keepalive_packet: Vec = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), vec![]).into(); if let Err(e) = tx.send(keepalive_packet).await { log::error!("send keepalive response error {:?}", e); } } pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> { - if data_len < std::mem::size_of::() { + if data_len < UdpgwHeader::len() { return Err("Invalid udpgw data".into()); } - let header_bytes = &data[..std::mem::size_of::()]; + let header_bytes = &data[..UdpgwHeader::len()]; let header = UdpgwHeader { flags: header_bytes[0], - conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), + conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), }; let flags = header.flags; - let conid = header.conid; + let conn_id = header.conn_id; // keepalive if flags & UDPGW_FLAG_KEEPALIVE != 0 { - return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); + return Ok((data, flags, conn_id, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); } - let ip_data = &data[std::mem::size_of::()..]; - let mut data_len = data_len - std::mem::size_of::(); + let ip_data = &data[UdpgwHeader::len()..]; + let mut data_len = data_len - UdpgwHeader::len(); // port_len + min(ipv4/ipv6/(domain_len + 1)) - if data_len < std::mem::size_of::() + 2 { + if data_len < UDPGW_LENGTH_FIELD_SIZE + 2 { return Err("Invalid udpgw data".into()); } if flags & UDPGW_FLAG_DOMAIN != 0 { @@ -128,7 +129,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u return Err("too much data".into()); } let udpdata = &ip_data[(2 + domain.len() + 1)..]; - Ok((udpdata, flags, conid, target)) + Ok((udpdata, flags, conn_id, target)) } Err(_) => Err("Invalid UTF-8 sequence in domain".into()), } @@ -152,7 +153,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u return Ok(( &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], flags, - conid, + conn_id, UdpgwAddr::IPV6(addr_ipv6).into(), )); } else { @@ -173,7 +174,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u return Ok(( &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], flags, - conid, + conn_id, UdpgwAddr::IPV4(addr_ipv4).into(), )); } @@ -195,13 +196,13 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, co Ok(ret) => { let (len, _addr) = ret?; let mut packet = vec![]; - let mut pack_len = std::mem::size_of::() + len; + let mut pack_len = UdpgwHeader::len() + len; match con.server_addr.into() { UdpgwAddr::IPV4(addr_ipv4) => { pack_len += std::mem::size_of::(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&[con.flags]); - packet.extend_from_slice(&con.conid.to_le_bytes()); + packet.extend_from_slice(&con.conn_id.to_le_bytes()); packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes()); packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes()); packet.extend_from_slice(&con.data[..len]); @@ -210,7 +211,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, co pack_len += std::mem::size_of::(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&[con.flags]); - packet.extend_from_slice(&con.conid.to_le_bytes()); + packet.extend_from_slice(&con.conn_id.to_le_bytes()); packet.extend_from_slice(&addr_ipv6.addr_ip); packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes()); packet.extend_from_slice(&con.data[..len]); @@ -230,7 +231,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, co async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> { let mut client = client; let mut buf = vec![0; args.udp_mtu as usize]; - let mut len_buf = [0; std::mem::size_of::()]; + let mut len_buf = [0; UDPGW_LENGTH_FIELD_SIZE]; let udp_mtu = args.udp_mtu; let udp_timeout = args.udp_timeout; @@ -250,8 +251,8 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender>, client: C // Connection closed break; } - if n < std::mem::size_of::() { - log::error!("client {} received PackLenHeader error", client.addr); + if n < UDPGW_LENGTH_FIELD_SIZE { + log::error!("client {} received Packet Length field error", client.addr); break; } let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]); @@ -272,23 +273,23 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender>, client: C } client.last_activity = std::time::Instant::now(); let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf); - if let Ok((udpdata, flags, conid, reqaddr)) = ret { + if let Ok((udpdata, flags, conn_id, reqaddr)) = ret { if flags & UDPGW_FLAG_KEEPALIVE != 0 { log::debug!("client {} send keepalive", client.addr); - send_keepalive_response(tx.clone(), conid).await; + send_keepalive_response(tx.clone(), conn_id).await; continue; } log::debug!( - "client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}", + "client {} received udp data,flags:{},conn_id:{},addr:{:?},data len:{}", client.addr, flags, - conid, + conn_id, reqaddr, udpdata.len() ); let mut req = UdpRequest { server_addr: reqaddr, - conid, + conn_id, flags, data: udpdata.to_vec(), }; diff --git a/src/lib.rs b/src/lib.rs index ff04568..3c34d9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -513,12 +513,8 @@ async fn handle_udp_gateway_session( let tcp_local_addr = server_stream.local_addr().clone(); match domain_name { - Some(ref d) => { - log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d); - } - None => { - log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr); - } + Some(ref d) => log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d), + None => log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr), } let Some(mut stream_reader) = server_stream.get_reader() else { @@ -547,8 +543,8 @@ async fn handle_udp_gateway_session( break; } } - let newid = server_stream.newid(); - if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), newid, &mut stream_writer).await { + let new_id = server_stream.new_id(); + if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), new_id, &mut stream_writer).await { log::info!("Ending {} <- {} with send_udpgw_packet {}", udpinfo, &tcp_local_addr, e); break; } diff --git a/src/udpgw.rs b/src/udpgw.rs index 9118af8..1b0798e 100644 --- a/src/udpgw.rs +++ b/src/udpgw.rs @@ -1,47 +1,171 @@ use crate::error::Result; use ipstack::stream::IpStackUdpStream; +use socks5_impl::protocol::{AsyncStreamOperation, BufMut, StreamOperation}; use std::collections::VecDeque; use std::hash::Hash; -use std::mem; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::sync::atomic::Ordering::Relaxed; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use tokio::net::TcpStream; -use tokio::sync::Mutex; -use tokio::time::{sleep, Duration}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpStream, + }, + sync::Mutex, + time::{sleep, Duration}, +}; pub const UDPGW_MAX_CONNECTIONS: usize = 100; pub const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10); pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01; +pub const UDPGW_FLAG_IPV4: u8 = 0x00; pub const UDPGW_FLAG_IPV6: u8 = 0x08; pub const UDPGW_FLAG_DOMAIN: u8 = 0x10; pub const UDPGW_FLAG_ERR: u8 = 0x20; +pub const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::(); + static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); #[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[repr(C)] -#[repr(packed(1))] -pub struct PackLenHeader { - packet_len: u16, +pub struct Packet { + pub length: u16, + pub header: UdpgwHeader, + pub data: Vec, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +impl From for Vec { + fn from(packet: Packet) -> Vec { + (&packet).into() + } +} + +impl From<&Packet> for Vec { + fn from(packet: &Packet) -> Vec { + let mut bytes = vec![0; packet.len()]; + packet.write_to_buf(&mut bytes); + bytes + } +} + +impl TryFrom<&[u8]> for Packet { + type Error = std::io::Error; + + fn try_from(value: &[u8]) -> std::result::Result { + if value.len() < UDPGW_LENGTH_FIELD_SIZE { + return Err(std::io::ErrorKind::InvalidData.into()); + } + let length = u16::from_le_bytes([value[0], value[1]]); + if value.len() < length as usize + UDPGW_LENGTH_FIELD_SIZE { + return Err(std::io::ErrorKind::InvalidData.into()); + } + let header = UdpgwHeader::try_from(&value[UDPGW_LENGTH_FIELD_SIZE..])?; + let data = value[UDPGW_LENGTH_FIELD_SIZE + header.len()..].to_vec(); + Ok(Packet::new(header, data)) + } +} + +impl Packet { + pub fn new(header: UdpgwHeader, data: Vec) -> Self { + let length = (header.len() + data.len()) as u16; + Packet { length, header, data } + } +} + +impl StreamOperation for Packet { + fn retrieve_from_stream(stream: &mut R) -> std::io::Result + where + R: std::io::Read, + Self: Sized, + { + let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE]; + stream.read_exact(&mut buf)?; + let length = u16::from_le_bytes(buf); + let mut buf = [0; UdpgwHeader::len()]; + stream.read_exact(&mut buf)?; + let header = UdpgwHeader::try_from(&buf[..])?; + let mut data = vec![0; length as usize - header.len()]; + stream.read_exact(&mut data)?; + Ok(Packet::new(header, data)) + } + + fn write_to_buf(&self, buf: &mut B) { + buf.put_u16_le(self.length); + self.header.write_to_buf(buf); + buf.put_slice(&self.data); + } + + fn len(&self) -> usize { + UDPGW_LENGTH_FIELD_SIZE + self.header.len() + self.data.len() + } +} + +#[async_trait::async_trait] +impl AsyncStreamOperation for Packet { + async fn retrieve_from_async_stream(r: &mut R) -> std::io::Result + where + R: tokio::io::AsyncRead + Unpin + Send, + Self: Sized, + { + let mut buf = [0; 2]; + r.read_exact(&mut buf).await?; + let length = u16::from_le_bytes(buf); + let header = UdpgwHeader::retrieve_from_async_stream(r).await?; + let mut data = vec![0; length as usize - header.len()]; + r.read_exact(&mut data).await?; + Ok(Packet::new(header, data)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(C)] #[repr(packed(1))] pub struct UdpgwHeader { pub flags: u8, - pub conid: u16, + pub conn_id: u16, +} + +impl StreamOperation for UdpgwHeader { + fn retrieve_from_stream(stream: &mut R) -> std::io::Result + where + R: std::io::Read, + Self: Sized, + { + let mut buf = [0; UdpgwHeader::len()]; + stream.read_exact(&mut buf)?; + UdpgwHeader::try_from(&buf[..]) + } + + fn write_to_buf(&self, buf: &mut B) { + let bytes: Vec = self.into(); + buf.put_slice(&bytes); + } + + fn len(&self) -> usize { + Self::len() + } +} + +#[async_trait::async_trait] +impl AsyncStreamOperation for UdpgwHeader { + async fn retrieve_from_async_stream(r: &mut R) -> std::io::Result + where + R: tokio::io::AsyncRead + Unpin + Send, + Self: Sized, + { + let mut buf = [0; UdpgwHeader::len()]; + r.read_exact(&mut buf).await?; + UdpgwHeader::try_from(&buf[..]) + } } impl UdpgwHeader { - pub fn new(flags: u8, conid: u16) -> Self { - UdpgwHeader { flags, conid } + pub fn new(flags: u8, conn_id: u16) -> Self { + UdpgwHeader { flags, conn_id } } pub const fn len() -> usize { - std::mem::size_of::() + std::mem::size_of::() + std::mem::size_of::() } } @@ -50,25 +174,20 @@ impl TryFrom<&[u8]> for UdpgwHeader { fn try_from(value: &[u8]) -> std::result::Result { if value.len() < UdpgwHeader::len() { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UdpgwHeader")); - } - let len = u16::from_le_bytes([value[0], value[1]]); - if len != std::mem::size_of::() as u16 { - return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UdpgwHeader")); + return Err(std::io::ErrorKind::InvalidData.into()); } Ok(UdpgwHeader { - flags: value[2], - conid: u16::from_le_bytes([value[3], value[4]]), + flags: value[0], + conn_id: u16::from_le_bytes([value[1], value[2]]), }) } } -impl From for Vec { - fn from(header: UdpgwHeader) -> Vec { - let mut bytes = vec![0; UdpgwHeader::len()]; - bytes[0..2].copy_from_slice(&(std::mem::size_of::() as u16).to_le_bytes()); - bytes[2] = header.flags; - bytes[3..5].copy_from_slice(&header.conid.to_le_bytes()); +impl From<&UdpgwHeader> for Vec { + fn from(header: &UdpgwHeader) -> Vec { + let mut bytes = vec![0; header.len()]; + bytes[0] = header.flags; + bytes[1..3].copy_from_slice(&header.conn_id.to_le_bytes()); bytes } } @@ -127,9 +246,10 @@ impl From for SocketAddr { } #[allow(dead_code)] +#[derive(Debug)] pub(crate) struct UdpGwData<'a> { flags: u8, - conid: u16, + conn_id: u16, remote_addr: SocketAddr, udpdata: &'a [u8], } @@ -141,6 +261,7 @@ impl<'a> UdpGwData<'a> { } #[allow(dead_code)] +#[derive(Debug)] pub(crate) enum UdpGwResponse<'a> { KeepAlive, Error, @@ -166,7 +287,7 @@ pub(crate) struct UdpGwClientStream { local_addr: String, writer: Option, reader: Option, - conid: u16, + conn_id: u16, closed: bool, last_activity: std::time::Instant, } @@ -210,12 +331,12 @@ impl UdpGwClientStream { } pub fn id(&mut self) -> u16 { - self.conid + self.conn_id } - pub fn newid(&mut self) -> u16 { - self.conid += 1; - self.conid + pub fn new_id(&mut self) -> u16 { + self.conn_id += 1; + self.conn_id } pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self { let local_addr = tcp_server_stream @@ -239,7 +360,7 @@ impl UdpGwClientStream { writer: Some(writer), last_activity: std::time::Instant::now(), closed: false, - conid: 0, + conn_id: 0, } } } @@ -257,7 +378,7 @@ pub(crate) struct UdpGwClient { impl UdpGwClient { pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self { - let keepalive_packet = UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0).into(); + let keepalive_packet: Vec = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0), vec![]).into(); let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize)); UdpGwClient { udp_mtu, @@ -326,14 +447,10 @@ impl UdpGwClient { let Some(mut stream_writer) = stream.get_writer() else { continue; }; - log::debug!("{:?}:{} send keepalive", stream_writer.inner.local_addr(), stream.id()); + let local_addr = stream_writer.inner.local_addr(); + log::debug!("{:?}:{} send keepalive", local_addr, stream.id()); if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await { - log::warn!( - "{:?}:{} send keepalive failed: {}", - stream_writer.inner.local_addr(), - stream.id(), - e - ); + log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e); } else { match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { Ok(UdpGwResponse::KeepAlive) => { @@ -341,10 +458,8 @@ impl UdpGwClient { self.release_server_connection_with_stream(stream, stream_reader, stream_writer) .await; } - //shoud not receive other type - _ => { - log::warn!("{:?}:{} keepalive no response", stream_writer.inner.local_addr(), stream.id()); - } + Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v), + Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e), } } } @@ -354,20 +469,20 @@ impl UdpGwClient { /// Parses the UDP response data. pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result { let data = &stream.recv_buf; - if data_len < mem::size_of::() { + if data_len < UdpgwHeader::len() { return Err("Invalid udpgw data".into()); } - let header_bytes = &data[..mem::size_of::()]; + let header_bytes = &data[..UdpgwHeader::len()]; let header = UdpgwHeader { flags: header_bytes[0], - conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), + conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), }; let flags = header.flags; - let conid = header.conid; + let conn_id = header.conn_id; - let ip_data = &data[mem::size_of::()..]; - let mut data_len = data_len - mem::size_of::(); + let ip_data = &data[UdpgwHeader::len()..]; + let mut data_len = data_len - UdpgwHeader::len(); if flags & UDPGW_FLAG_ERR != 0 { return Ok(UdpGwResponse::Error); @@ -378,44 +493,44 @@ impl UdpGwClient { } if flags & UDPGW_FLAG_IPV6 != 0 { - if data_len < mem::size_of::() { + if data_len < std::mem::size_of::() { return Err("ipv6 Invalid UDP data".into()); } - let addr_ipv6_bytes = &ip_data[..mem::size_of::()]; + let addr_ipv6_bytes = &ip_data[..std::mem::size_of::()]; let addr_ipv6 = UdpgwAddrIpv6 { addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?, addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]), }; - data_len -= mem::size_of::(); + data_len -= std::mem::size_of::(); if data_len > udp_mtu as usize { return Err("too much data".into()); } return Ok(UdpGwResponse::Data(UdpGwData { flags, - conid, + conn_id, remote_addr: UdpgwAddr::IPV6(addr_ipv6).into(), - udpdata: &ip_data[mem::size_of::()..(data_len + mem::size_of::())], + udpdata: &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], })); } else { - if data_len < mem::size_of::() { + if data_len < std::mem::size_of::() { return Err("ipv4 Invalid UDP data".into()); } - let addr_ipv4_bytes = &ip_data[..mem::size_of::()]; + let addr_ipv4_bytes = &ip_data[..std::mem::size_of::()]; let addr_ipv4 = UdpgwAddrIpv4 { addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]), addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]), }; - data_len -= mem::size_of::(); + data_len -= std::mem::size_of::(); if data_len > udp_mtu as usize { return Err("too much data".into()); } return Ok(UdpGwResponse::Data(UdpGwData { flags, - conid, + conn_id, remote_addr: UdpgwAddr::IPV4(addr_ipv4).into(), - udpdata: &ip_data[mem::size_of::()..(data_len + mem::size_of::())], + udpdata: &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], })); } } @@ -456,8 +571,8 @@ impl UdpGwClient { if n == 0 { return Ok(UdpGwResponse::TcpClose); } - if n < std::mem::size_of::() { - return Err("received PackLenHeader error".into()); + if n < UDPGW_LENGTH_FIELD_SIZE { + return Err("received Packet Length field error".into()); } let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]); if packet_len > udp_mtu { @@ -489,7 +604,7 @@ impl UdpGwClient { /// * `len` - Length of the data packet /// * `remote_addr` - Remote address /// * `domain` - Target domain (optional) - /// * `conid` - Connection ID + /// * `conn_id` - Connection ID /// * `stream` - UDP gateway client writer stream /// /// # Returns @@ -500,30 +615,28 @@ impl UdpGwClient { len: usize, remote_addr: SocketAddr, domain: Option<&String>, - conid: u16, + conn_id: u16, stream: &mut UdpGwClientStreamWriter, ) -> Result<()> { stream.send_buf.clear(); let data = &stream.tmp_buf; - let mut pack_len = std::mem::size_of::() + len; + let mut pack_len = UdpgwHeader::len() + len; let packet = &mut stream.send_buf; - let mut flags = 0; match domain { Some(domain) => { let addr_port = match remote_addr.into() { UdpgwAddr::IPV4(addr_ipv4) => addr_ipv4.addr_port, UdpgwAddr::IPV6(addr_ipv6) => addr_ipv6.addr_port, }; - pack_len += std::mem::size_of::(); let domain_len = domain.len(); if domain_len > 255 { return Err("InvalidDomain".into()); } + pack_len += UDPGW_LENGTH_FIELD_SIZE; pack_len += domain_len + 1; - flags = UDPGW_FLAG_DOMAIN; packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); - packet.extend_from_slice(&[flags]); - packet.extend_from_slice(&conid.to_le_bytes()); + packet.extend_from_slice(&[UDPGW_FLAG_DOMAIN]); + packet.extend_from_slice(&conn_id.to_le_bytes()); packet.extend_from_slice(&addr_port.to_be_bytes()); packet.extend_from_slice(domain.as_bytes()); packet.push(0); @@ -533,8 +646,8 @@ impl UdpGwClient { UdpgwAddr::IPV4(addr_ipv4) => { pack_len += std::mem::size_of::(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); - packet.extend_from_slice(&[flags]); - packet.extend_from_slice(&conid.to_le_bytes()); + packet.extend_from_slice(&[UDPGW_FLAG_IPV4]); + packet.extend_from_slice(&conn_id.to_le_bytes()); packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes()); packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes()); packet.extend_from_slice(&data[..len]); @@ -543,11 +656,10 @@ impl UdpGwClient { if !ipv6_enabled { return Err("ipv6 not support".into()); } - flags = UDPGW_FLAG_IPV6; pack_len += std::mem::size_of::(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); - packet.extend_from_slice(&[flags]); - packet.extend_from_slice(&conid.to_le_bytes()); + packet.extend_from_slice(&[UDPGW_FLAG_IPV6]); + packet.extend_from_slice(&conn_id.to_le_bytes()); packet.extend_from_slice(&addr_ipv6.addr_ip); packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes()); packet.extend_from_slice(&data[..len]); @@ -563,13 +675,18 @@ impl UdpGwClient { #[cfg(test)] mod tests { - use super::UdpgwHeader; + use super::{Packet, UdpgwHeader}; + use socks5_impl::protocol::StreamOperation; #[test] fn test_udpgw_header() { let header = UdpgwHeader::new(0x01, 0x1234); - let bytes = Vec::from(header.clone()); - let header2 = UdpgwHeader::try_from(&bytes[..]).unwrap(); + let mut bytes: Vec = vec![]; + let packet = Packet::new(header, vec![]); + packet.write_to_buf(&mut bytes); + + let header2 = Packet::retrieve_from_stream(&mut bytes.as_slice()).unwrap().header; + assert_eq!(header, header2); } }