From 2155ea55c099d30d86ff31f67dfbe09f76e6c8df Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:26:20 +0800 Subject: [PATCH] BinSocketAddr struct --- src/bin/udpgw_server.rs | 61 ++++++------ src/udpgw.rs | 205 +++++++++++++++++++++++----------------- 2 files changed, 145 insertions(+), 121 deletions(-) diff --git a/src/bin/udpgw_server.rs b/src/bin/udpgw_server.rs index d3add78..38b36df 100644 --- a/src/bin/udpgw_server.rs +++ b/src/bin/udpgw_server.rs @@ -86,10 +86,11 @@ async fn send_keepalive_response(tx: Sender>, conn_id: u16) { } pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> { - if data_len < UdpgwHeader::len() { + let header_len = UdpgwHeader::static_len(); + if data_len < header_len { return Err("Invalid udpgw data".into()); } - let header_bytes = &data[..UdpgwHeader::len()]; + let header_bytes = &data[..header_len]; let header = UdpgwHeader { flags: header_bytes[0], conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), @@ -103,8 +104,8 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u return Ok((data, flags, conn_id, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); } - let ip_data = &data[UdpgwHeader::len()..]; - let mut data_len = data_len - UdpgwHeader::len(); + let ip_data = &data[header_len..]; + let mut data_len = data_len - header_len; // port_len + min(ipv4/ipv6/(domain_len + 1)) if data_len < UDPGW_LENGTH_FIELD_SIZE + 2 { return Err("Invalid udpgw data".into()); @@ -137,45 +138,39 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u Err("missing domain name".into()) } } else if flags & UDPGW_FLAG_IPV6 != 0 { - if data_len < std::mem::size_of::() { + let addr_ipv6_len = BinSocketAddr::static_len(true); + if data_len < addr_ipv6_len { return Err("Ipv6 Invalid UDP data".into()); } - let addr_ipv6_bytes = &ip_data[..std::mem::size_of::()]; - let addr_ipv6 = UdpgwAddrIpv6 { - addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?, - addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]), - }; - data_len -= std::mem::size_of::(); + let addr_ipv6 = BinSocketAddr::try_from(&ip_data[..addr_ipv6_len])?; + data_len -= addr_ipv6_len; if data_len > udp_mtu as usize { return Err("too much data".into()); } return Ok(( - &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], + &ip_data[addr_ipv6_len..(data_len + addr_ipv6_len)], flags, conn_id, - UdpgwAddr::IPV6(addr_ipv6).into(), + addr_ipv6.into(), )); } else { - if data_len < std::mem::size_of::() { + let addr_ipv4_len = BinSocketAddr::static_len(false); + if data_len < addr_ipv4_len { return Err("Ipv4 Invalid UDP data".into()); } - let addr_ipv4_bytes = &ip_data[..std::mem::size_of::()]; - let addr_ipv4 = UdpgwAddrIpv4 { - addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]), - addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]), - }; - data_len -= std::mem::size_of::(); + let addr_ipv4 = BinSocketAddr::try_from(&ip_data[..addr_ipv4_len])?; + data_len -= addr_ipv4_len; if data_len > udp_mtu as usize { return Err("too much data".into()); } return Ok(( - &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], + &ip_data[addr_ipv4_len..(data_len + addr_ipv4_len)], flags, conn_id, - UdpgwAddr::IPV4(addr_ipv4).into(), + addr_ipv4.into(), )); } } @@ -196,24 +191,26 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, co Ok(ret) => { let (len, _addr) = ret?; let mut packet = vec![]; - let mut pack_len = UdpgwHeader::len() + len; - match con.server_addr.into() { - UdpgwAddr::IPV4(addr_ipv4) => { - pack_len += std::mem::size_of::(); + let mut pack_len = UdpgwHeader::static_len() + len; + match con.server_addr { + SocketAddr::V4(_) => { + let addr_ipv4 = BinSocketAddr::from(con.server_addr); + pack_len += addr_ipv4.len(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&[con.flags]); packet.extend_from_slice(&con.conn_id.to_le_bytes()); - packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes()); - packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes()); + let addr_ipv4_bin: Vec = addr_ipv4.into(); + packet.extend_from_slice(&addr_ipv4_bin); packet.extend_from_slice(&con.data[..len]); } - UdpgwAddr::IPV6(addr_ipv6) => { - pack_len += std::mem::size_of::(); + SocketAddr::V6(_) => { + let addr_ipv6 = BinSocketAddr::from(con.server_addr); + pack_len += addr_ipv6.len(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&[con.flags]); packet.extend_from_slice(&con.conn_id.to_le_bytes()); - packet.extend_from_slice(&addr_ipv6.addr_ip); - packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes()); + let addr_ipv6_bin: Vec = addr_ipv6.into(); + packet.extend_from_slice(&addr_ipv6_bin); packet.extend_from_slice(&con.data[..len]); } } diff --git a/src/udpgw.rs b/src/udpgw.rs index 44946fd..4523c6c 100644 --- a/src/udpgw.rs +++ b/src/udpgw.rs @@ -1,10 +1,12 @@ use crate::error::Result; use ipstack::stream::IpStackUdpStream; use socks5_impl::protocol::{AsyncStreamOperation, BufMut, StreamOperation}; -use std::collections::VecDeque; -use std::hash::Hash; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::sync::atomic::Ordering::Relaxed; +use std::{ + collections::VecDeque, + hash::Hash, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + sync::atomic::Ordering::Relaxed, +}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{ @@ -81,7 +83,7 @@ impl StreamOperation for Packet { let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE]; stream.read_exact(&mut buf)?; let length = u16::from_le_bytes(buf); - let mut buf = [0; UdpgwHeader::len()]; + let mut buf = [0; UdpgwHeader::static_len()]; stream.read_exact(&mut buf)?; let header = UdpgwHeader::try_from(&buf[..])?; let mut data = vec![0; length as usize - header.len()]; @@ -131,7 +133,7 @@ impl StreamOperation for UdpgwHeader { R: std::io::Read, Self: Sized, { - let mut buf = [0; UdpgwHeader::len()]; + let mut buf = [0; UdpgwHeader::static_len()]; stream.read_exact(&mut buf)?; UdpgwHeader::try_from(&buf[..]) } @@ -142,7 +144,7 @@ impl StreamOperation for UdpgwHeader { } fn len(&self) -> usize { - Self::len() + Self::static_len() } } @@ -153,7 +155,7 @@ impl AsyncStreamOperation for UdpgwHeader { R: tokio::io::AsyncRead + Unpin + Send, Self: Sized, { - let mut buf = [0; UdpgwHeader::len()]; + let mut buf = [0; UdpgwHeader::static_len()]; r.read_exact(&mut buf).await?; UdpgwHeader::try_from(&buf[..]) } @@ -164,7 +166,7 @@ impl UdpgwHeader { UdpgwHeader { flags, conn_id } } - pub const fn len() -> usize { + pub const fn static_len() -> usize { std::mem::size_of::() } } @@ -173,7 +175,7 @@ impl TryFrom<&[u8]> for UdpgwHeader { type Error = std::io::Error; fn try_from(value: &[u8]) -> std::result::Result { - if value.len() < UdpgwHeader::len() { + if value.len() < UdpgwHeader::static_len() { return Err(std::io::ErrorKind::InvalidData.into()); } Ok(UdpgwHeader { @@ -192,56 +194,87 @@ impl From<&UdpgwHeader> for Vec { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[repr(C)] -#[repr(packed(1))] -pub struct UdpgwAddrIpv4 { - pub addr_ip: u32, - pub addr_port: u16, -} +#[allow(clippy::len_without_is_empty)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)] +pub struct BinSocketAddr(SocketAddr); -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[repr(C)] -#[repr(packed(1))] -pub struct UdpgwAddrIpv6 { - pub addr_ip: [u8; 16], - pub addr_port: u16, -} +impl BinSocketAddr { + pub fn len(&self) -> usize { + match self.0 { + SocketAddr::V4(_) => Self::static_len(false), + SocketAddr::V6(_) => Self::static_len(true), + } + } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum UdpgwAddr { - IPV4(UdpgwAddrIpv4), - IPV6(UdpgwAddrIpv6), -} - -impl From for UdpgwAddr { - fn from(addr: SocketAddr) -> Self { - match addr { - SocketAddr::V4(addr_v4) => { - let ipv4_addr = addr_v4.ip().octets(); - let addr_ip = u32::from_be_bytes(ipv4_addr); - UdpgwAddr::IPV4(UdpgwAddrIpv4 { - addr_ip, - addr_port: addr_v4.port(), - }) - } - SocketAddr::V6(addr_v6) => { - let ipv6_addr = addr_v6.ip().octets(); - UdpgwAddr::IPV6(UdpgwAddrIpv6 { - addr_ip: ipv6_addr, - addr_port: addr_v6.port(), - }) - } + pub fn static_len(is_ipv6: bool) -> usize { + if is_ipv6 { + std::mem::size_of::() + std::mem::size_of::() + } else { + std::mem::size_of::() + std::mem::size_of::() } } } -impl From for SocketAddr { - fn from(addr: UdpgwAddr) -> Self { - match addr { - UdpgwAddr::IPV4(addr_ipv4) => SocketAddrV4::new(Ipv4Addr::from(addr_ipv4.addr_ip), addr_ipv4.addr_port).into(), - UdpgwAddr::IPV6(addr_ipv6) => SocketAddrV6::new(Ipv6Addr::from(addr_ipv6.addr_ip), addr_ipv6.addr_port, 0, 0).into(), +impl From<&BinSocketAddr> for Vec { + fn from(addr: &BinSocketAddr) -> Vec { + socket_addr_to_binary(&addr.0) + } +} + +impl From for Vec { + fn from(addr: BinSocketAddr) -> Vec { + socket_addr_to_binary(&addr.0) + } +} + +impl TryFrom<&[u8]> for BinSocketAddr { + type Error = std::io::Error; + fn try_from(value: &[u8]) -> std::result::Result { + Ok(BinSocketAddr(binary_to_socket_addr(value)?)) + } +} + +impl From for BinSocketAddr { + fn from(addr: SocketAddr) -> Self { + BinSocketAddr(addr) + } +} + +impl From for SocketAddr { + fn from(addr: BinSocketAddr) -> Self { + addr.0 + } +} + +fn socket_addr_to_binary(addr: &SocketAddr) -> Vec { + match addr { + SocketAddr::V4(addr_v4) => { + let mut bytes = vec![0; std::mem::size_of::()]; + bytes[0..4].copy_from_slice(&addr_v4.ip().octets()); + bytes[4..6].copy_from_slice(&addr_v4.port().to_be_bytes()); + bytes } + SocketAddr::V6(addr_v6) => { + let mut bytes = vec![0; std::mem::size_of::() + std::mem::size_of::()]; + bytes[0..16].copy_from_slice(&addr_v6.ip().octets()); + bytes[16..18].copy_from_slice(&addr_v6.port().to_be_bytes()); + bytes + } + } +} + +fn binary_to_socket_addr(bytes: &[u8]) -> std::io::Result { + if bytes.len() == std::mem::size_of::() { + let ip = Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]); + let port = u16::from_be_bytes([bytes[4], bytes[5]]); + Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))) + } else if bytes.len() == std::mem::size_of::() + std::mem::size_of::() { + let mut ip = [0; 16]; + ip.copy_from_slice(&bytes[0..16]); + let port = u16::from_be_bytes([bytes[16], bytes[17]]); + Ok(SocketAddr::V6(SocketAddrV6::new(ip.into(), port, 0, 0))) + } else { + Err(std::io::ErrorKind::InvalidData.into()) } } @@ -469,10 +502,11 @@ impl UdpGwClient { /// Parses the UDP response data. pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result { let data = &stream.recv_buf; - if data_len < UdpgwHeader::len() { + let header_len = UdpgwHeader::static_len(); + if data_len < header_len { return Err("Invalid udpgw data".into()); } - let header_bytes = &data[..UdpgwHeader::len()]; + let header_bytes = &data[..header_len]; let header = UdpgwHeader { flags: header_bytes[0], conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), @@ -481,8 +515,8 @@ impl UdpGwClient { let flags = header.flags; let conn_id = header.conn_id; - let ip_data = &data[UdpgwHeader::len()..]; - let mut data_len = data_len - UdpgwHeader::len(); + let ip_data = &data[header_len..]; + let mut data_len = data_len - header_len; if flags & UDPGW_FLAG_ERR != 0 { return Ok(UdpGwResponse::Error); @@ -493,15 +527,12 @@ impl UdpGwClient { } if flags & UDPGW_FLAG_IPV6 != 0 { - if data_len < std::mem::size_of::() { + let ipv6_addr_len = BinSocketAddr::static_len(true); + if data_len < ipv6_addr_len { return Err("ipv6 Invalid UDP data".into()); } - let addr_ipv6_bytes = &ip_data[..std::mem::size_of::()]; - let addr_ipv6 = UdpgwAddrIpv6 { - addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?, - addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]), - }; - data_len -= std::mem::size_of::(); + let addr_ipv6 = BinSocketAddr::try_from(&ip_data[..ipv6_addr_len])?; + data_len -= ipv6_addr_len; if data_len > udp_mtu as usize { return Err("too much data".into()); @@ -509,19 +540,16 @@ impl UdpGwClient { return Ok(UdpGwResponse::Data(UdpGwData { flags, conn_id, - remote_addr: UdpgwAddr::IPV6(addr_ipv6).into(), - udpdata: &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], + remote_addr: addr_ipv6.into(), + udpdata: &ip_data[ipv6_addr_len..(data_len + ipv6_addr_len)], })); } else { - if data_len < std::mem::size_of::() { + let ipv4_addr_len = BinSocketAddr::static_len(false); + if data_len < ipv4_addr_len { return Err("ipv4 Invalid UDP data".into()); } - let addr_ipv4_bytes = &ip_data[..std::mem::size_of::()]; - let addr_ipv4 = UdpgwAddrIpv4 { - addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]), - addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]), - }; - data_len -= std::mem::size_of::(); + let addr_ipv4 = BinSocketAddr::try_from(&ip_data[..ipv4_addr_len])?; + data_len -= ipv4_addr_len; if data_len > udp_mtu as usize { return Err("too much data".into()); @@ -529,8 +557,8 @@ impl UdpGwClient { return Ok(UdpGwResponse::Data(UdpGwData { flags, conn_id, - remote_addr: UdpgwAddr::IPV4(addr_ipv4).into(), - udpdata: &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], + remote_addr: addr_ipv4.into(), + udpdata: &ip_data[ipv4_addr_len..(data_len + ipv4_addr_len)], })); } } @@ -620,14 +648,11 @@ impl UdpGwClient { ) -> Result<()> { stream.send_buf.clear(); let data = &stream.tmp_buf; - let mut pack_len = UdpgwHeader::len() + len; + let mut pack_len = UdpgwHeader::static_len() + len; let packet = &mut stream.send_buf; match domain { Some(domain) => { - let addr_port = match remote_addr.into() { - UdpgwAddr::IPV4(addr_ipv4) => addr_ipv4.addr_port, - UdpgwAddr::IPV6(addr_ipv6) => addr_ipv6.addr_port, - }; + let addr_port = remote_addr.port(); let domain_len = domain.len(); if domain_len > 255 { return Err("InvalidDomain".into()); @@ -642,26 +667,28 @@ impl UdpGwClient { packet.push(0); packet.extend_from_slice(&data[..len]); } - None => match remote_addr.into() { - UdpgwAddr::IPV4(addr_ipv4) => { - pack_len += std::mem::size_of::(); + None => match remote_addr { + SocketAddr::V4(_) => { + let addr_ipv4 = BinSocketAddr::from(remote_addr); + pack_len += addr_ipv4.len(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&[UDPGW_FLAG_IPV4]); packet.extend_from_slice(&conn_id.to_le_bytes()); - packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes()); - packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes()); + let addr_ipv4_bin: Vec = addr_ipv4.into(); + packet.extend_from_slice(&addr_ipv4_bin); packet.extend_from_slice(&data[..len]); } - UdpgwAddr::IPV6(addr_ipv6) => { + SocketAddr::V6(_) => { if !ipv6_enabled { return Err("ipv6 not support".into()); } - pack_len += std::mem::size_of::(); + let addr_ipv6 = BinSocketAddr::from(remote_addr); + pack_len += addr_ipv6.len(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&[UDPGW_FLAG_IPV6]); packet.extend_from_slice(&conn_id.to_le_bytes()); - packet.extend_from_slice(&addr_ipv6.addr_ip); - packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes()); + let addr_ipv6_bin: Vec = addr_ipv6.into(); + packet.extend_from_slice(&addr_ipv6_bin); packet.extend_from_slice(&data[..len]); } },