BinSocketAddr struct

This commit is contained in:
ssrlive 2024-10-24 13:26:20 +08:00
parent 1a99bb9f23
commit 2155ea55c0
2 changed files with 145 additions and 121 deletions

View file

@ -86,10 +86,11 @@ async fn send_keepalive_response(tx: Sender<Vec<u8>>, conn_id: u16) {
}
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
if data_len < UdpgwHeader::len() {
let header_len = UdpgwHeader::static_len();
if data_len < header_len {
return Err("Invalid udpgw data".into());
}
let header_bytes = &data[..UdpgwHeader::len()];
let header_bytes = &data[..header_len];
let header = UdpgwHeader {
flags: header_bytes[0],
conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
@ -103,8 +104,8 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
return Ok((data, flags, conn_id, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
}
let ip_data = &data[UdpgwHeader::len()..];
let mut data_len = data_len - UdpgwHeader::len();
let ip_data = &data[header_len..];
let mut data_len = data_len - header_len;
// port_len + min(ipv4/ipv6/(domain_len + 1))
if data_len < UDPGW_LENGTH_FIELD_SIZE + 2 {
return Err("Invalid udpgw data".into());
@ -137,45 +138,39 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
Err("missing domain name".into())
}
} else if flags & UDPGW_FLAG_IPV6 != 0 {
if data_len < std::mem::size_of::<UdpgwAddrIpv6>() {
let addr_ipv6_len = BinSocketAddr::static_len(true);
if data_len < addr_ipv6_len {
return Err("Ipv6 Invalid UDP data".into());
}
let addr_ipv6_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv6>()];
let addr_ipv6 = UdpgwAddrIpv6 {
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
};
data_len -= std::mem::size_of::<UdpgwAddrIpv6>();
let addr_ipv6 = BinSocketAddr::try_from(&ip_data[..addr_ipv6_len])?;
data_len -= addr_ipv6_len;
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
return Ok((
&ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
&ip_data[addr_ipv6_len..(data_len + addr_ipv6_len)],
flags,
conn_id,
UdpgwAddr::IPV6(addr_ipv6).into(),
addr_ipv6.into(),
));
} else {
if data_len < std::mem::size_of::<UdpgwAddrIpv4>() {
let addr_ipv4_len = BinSocketAddr::static_len(false);
if data_len < addr_ipv4_len {
return Err("Ipv4 Invalid UDP data".into());
}
let addr_ipv4_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv4>()];
let addr_ipv4 = UdpgwAddrIpv4 {
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
};
data_len -= std::mem::size_of::<UdpgwAddrIpv4>();
let addr_ipv4 = BinSocketAddr::try_from(&ip_data[..addr_ipv4_len])?;
data_len -= addr_ipv4_len;
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
return Ok((
&ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
&ip_data[addr_ipv4_len..(data_len + addr_ipv4_len)],
flags,
conn_id,
UdpgwAddr::IPV4(addr_ipv4).into(),
addr_ipv4.into(),
));
}
}
@ -196,24 +191,26 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
Ok(ret) => {
let (len, _addr) = ret?;
let mut packet = vec![];
let mut pack_len = UdpgwHeader::len() + len;
match con.server_addr.into() {
UdpgwAddr::IPV4(addr_ipv4) => {
pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
let mut pack_len = UdpgwHeader::static_len() + len;
match con.server_addr {
SocketAddr::V4(_) => {
let addr_ipv4 = BinSocketAddr::from(con.server_addr);
pack_len += addr_ipv4.len();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
let addr_ipv4_bin: Vec<u8> = addr_ipv4.into();
packet.extend_from_slice(&addr_ipv4_bin);
packet.extend_from_slice(&con.data[..len]);
}
UdpgwAddr::IPV6(addr_ipv6) => {
pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
SocketAddr::V6(_) => {
let addr_ipv6 = BinSocketAddr::from(con.server_addr);
pack_len += addr_ipv6.len();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv6.addr_ip);
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
let addr_ipv6_bin: Vec<u8> = addr_ipv6.into();
packet.extend_from_slice(&addr_ipv6_bin);
packet.extend_from_slice(&con.data[..len]);
}
}

View file

@ -1,10 +1,12 @@
use crate::error::Result;
use ipstack::stream::IpStackUdpStream;
use socks5_impl::protocol::{AsyncStreamOperation, BufMut, StreamOperation};
use std::collections::VecDeque;
use std::hash::Hash;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::atomic::Ordering::Relaxed;
use std::{
collections::VecDeque,
hash::Hash,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
sync::atomic::Ordering::Relaxed,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
@ -81,7 +83,7 @@ impl StreamOperation for Packet {
let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE];
stream.read_exact(&mut buf)?;
let length = u16::from_le_bytes(buf);
let mut buf = [0; UdpgwHeader::len()];
let mut buf = [0; UdpgwHeader::static_len()];
stream.read_exact(&mut buf)?;
let header = UdpgwHeader::try_from(&buf[..])?;
let mut data = vec![0; length as usize - header.len()];
@ -131,7 +133,7 @@ impl StreamOperation for UdpgwHeader {
R: std::io::Read,
Self: Sized,
{
let mut buf = [0; UdpgwHeader::len()];
let mut buf = [0; UdpgwHeader::static_len()];
stream.read_exact(&mut buf)?;
UdpgwHeader::try_from(&buf[..])
}
@ -142,7 +144,7 @@ impl StreamOperation for UdpgwHeader {
}
fn len(&self) -> usize {
Self::len()
Self::static_len()
}
}
@ -153,7 +155,7 @@ impl AsyncStreamOperation for UdpgwHeader {
R: tokio::io::AsyncRead + Unpin + Send,
Self: Sized,
{
let mut buf = [0; UdpgwHeader::len()];
let mut buf = [0; UdpgwHeader::static_len()];
r.read_exact(&mut buf).await?;
UdpgwHeader::try_from(&buf[..])
}
@ -164,7 +166,7 @@ impl UdpgwHeader {
UdpgwHeader { flags, conn_id }
}
pub const fn len() -> usize {
pub const fn static_len() -> usize {
std::mem::size_of::<UdpgwHeader>()
}
}
@ -173,7 +175,7 @@ impl TryFrom<&[u8]> for UdpgwHeader {
type Error = std::io::Error;
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
if value.len() < UdpgwHeader::len() {
if value.len() < UdpgwHeader::static_len() {
return Err(std::io::ErrorKind::InvalidData.into());
}
Ok(UdpgwHeader {
@ -192,56 +194,87 @@ impl From<&UdpgwHeader> for Vec<u8> {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[repr(C)]
#[repr(packed(1))]
pub struct UdpgwAddrIpv4 {
pub addr_ip: u32,
pub addr_port: u16,
#[allow(clippy::len_without_is_empty)]
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
pub struct BinSocketAddr(SocketAddr);
impl BinSocketAddr {
pub fn len(&self) -> usize {
match self.0 {
SocketAddr::V4(_) => Self::static_len(false),
SocketAddr::V6(_) => Self::static_len(true),
}
}
pub fn static_len(is_ipv6: bool) -> usize {
if is_ipv6 {
std::mem::size_of::<Ipv6Addr>() + std::mem::size_of::<u16>()
} else {
std::mem::size_of::<Ipv4Addr>() + std::mem::size_of::<u16>()
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[repr(C)]
#[repr(packed(1))]
pub struct UdpgwAddrIpv6 {
pub addr_ip: [u8; 16],
pub addr_port: u16,
impl From<&BinSocketAddr> for Vec<u8> {
fn from(addr: &BinSocketAddr) -> Vec<u8> {
socket_addr_to_binary(&addr.0)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UdpgwAddr {
IPV4(UdpgwAddrIpv4),
IPV6(UdpgwAddrIpv6),
impl From<BinSocketAddr> for Vec<u8> {
fn from(addr: BinSocketAddr) -> Vec<u8> {
socket_addr_to_binary(&addr.0)
}
}
impl From<SocketAddr> for UdpgwAddr {
impl TryFrom<&[u8]> for BinSocketAddr {
type Error = std::io::Error;
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
Ok(BinSocketAddr(binary_to_socket_addr(value)?))
}
}
impl From<SocketAddr> for BinSocketAddr {
fn from(addr: SocketAddr) -> Self {
BinSocketAddr(addr)
}
}
impl From<BinSocketAddr> for SocketAddr {
fn from(addr: BinSocketAddr) -> Self {
addr.0
}
}
fn socket_addr_to_binary(addr: &SocketAddr) -> Vec<u8> {
match addr {
SocketAddr::V4(addr_v4) => {
let ipv4_addr = addr_v4.ip().octets();
let addr_ip = u32::from_be_bytes(ipv4_addr);
UdpgwAddr::IPV4(UdpgwAddrIpv4 {
addr_ip,
addr_port: addr_v4.port(),
})
let mut bytes = vec![0; std::mem::size_of::<SocketAddrV4>()];
bytes[0..4].copy_from_slice(&addr_v4.ip().octets());
bytes[4..6].copy_from_slice(&addr_v4.port().to_be_bytes());
bytes
}
SocketAddr::V6(addr_v6) => {
let ipv6_addr = addr_v6.ip().octets();
UdpgwAddr::IPV6(UdpgwAddrIpv6 {
addr_ip: ipv6_addr,
addr_port: addr_v6.port(),
})
}
let mut bytes = vec![0; std::mem::size_of::<Ipv6Addr>() + std::mem::size_of::<u16>()];
bytes[0..16].copy_from_slice(&addr_v6.ip().octets());
bytes[16..18].copy_from_slice(&addr_v6.port().to_be_bytes());
bytes
}
}
}
impl From<UdpgwAddr> for SocketAddr {
fn from(addr: UdpgwAddr) -> Self {
match addr {
UdpgwAddr::IPV4(addr_ipv4) => SocketAddrV4::new(Ipv4Addr::from(addr_ipv4.addr_ip), addr_ipv4.addr_port).into(),
UdpgwAddr::IPV6(addr_ipv6) => SocketAddrV6::new(Ipv6Addr::from(addr_ipv6.addr_ip), addr_ipv6.addr_port, 0, 0).into(),
}
fn binary_to_socket_addr(bytes: &[u8]) -> std::io::Result<SocketAddr> {
if bytes.len() == std::mem::size_of::<SocketAddrV4>() {
let ip = Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3]);
let port = u16::from_be_bytes([bytes[4], bytes[5]]);
Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
} else if bytes.len() == std::mem::size_of::<Ipv6Addr>() + std::mem::size_of::<u16>() {
let mut ip = [0; 16];
ip.copy_from_slice(&bytes[0..16]);
let port = u16::from_be_bytes([bytes[16], bytes[17]]);
Ok(SocketAddr::V6(SocketAddrV6::new(ip.into(), port, 0, 0)))
} else {
Err(std::io::ErrorKind::InvalidData.into())
}
}
@ -469,10 +502,11 @@ impl UdpGwClient {
/// Parses the UDP response data.
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
let data = &stream.recv_buf;
if data_len < UdpgwHeader::len() {
let header_len = UdpgwHeader::static_len();
if data_len < header_len {
return Err("Invalid udpgw data".into());
}
let header_bytes = &data[..UdpgwHeader::len()];
let header_bytes = &data[..header_len];
let header = UdpgwHeader {
flags: header_bytes[0],
conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
@ -481,8 +515,8 @@ impl UdpGwClient {
let flags = header.flags;
let conn_id = header.conn_id;
let ip_data = &data[UdpgwHeader::len()..];
let mut data_len = data_len - UdpgwHeader::len();
let ip_data = &data[header_len..];
let mut data_len = data_len - header_len;
if flags & UDPGW_FLAG_ERR != 0 {
return Ok(UdpGwResponse::Error);
@ -493,15 +527,12 @@ impl UdpGwClient {
}
if flags & UDPGW_FLAG_IPV6 != 0 {
if data_len < std::mem::size_of::<UdpgwAddrIpv6>() {
let ipv6_addr_len = BinSocketAddr::static_len(true);
if data_len < ipv6_addr_len {
return Err("ipv6 Invalid UDP data".into());
}
let addr_ipv6_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv6>()];
let addr_ipv6 = UdpgwAddrIpv6 {
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
};
data_len -= std::mem::size_of::<UdpgwAddrIpv6>();
let addr_ipv6 = BinSocketAddr::try_from(&ip_data[..ipv6_addr_len])?;
data_len -= ipv6_addr_len;
if data_len > udp_mtu as usize {
return Err("too much data".into());
@ -509,19 +540,16 @@ impl UdpGwClient {
return Ok(UdpGwResponse::Data(UdpGwData {
flags,
conn_id,
remote_addr: UdpgwAddr::IPV6(addr_ipv6).into(),
udpdata: &ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
remote_addr: addr_ipv6.into(),
udpdata: &ip_data[ipv6_addr_len..(data_len + ipv6_addr_len)],
}));
} else {
if data_len < std::mem::size_of::<UdpgwAddrIpv4>() {
let ipv4_addr_len = BinSocketAddr::static_len(false);
if data_len < ipv4_addr_len {
return Err("ipv4 Invalid UDP data".into());
}
let addr_ipv4_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv4>()];
let addr_ipv4 = UdpgwAddrIpv4 {
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
};
data_len -= std::mem::size_of::<UdpgwAddrIpv4>();
let addr_ipv4 = BinSocketAddr::try_from(&ip_data[..ipv4_addr_len])?;
data_len -= ipv4_addr_len;
if data_len > udp_mtu as usize {
return Err("too much data".into());
@ -529,8 +557,8 @@ impl UdpGwClient {
return Ok(UdpGwResponse::Data(UdpGwData {
flags,
conn_id,
remote_addr: UdpgwAddr::IPV4(addr_ipv4).into(),
udpdata: &ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
remote_addr: addr_ipv4.into(),
udpdata: &ip_data[ipv4_addr_len..(data_len + ipv4_addr_len)],
}));
}
}
@ -620,14 +648,11 @@ impl UdpGwClient {
) -> Result<()> {
stream.send_buf.clear();
let data = &stream.tmp_buf;
let mut pack_len = UdpgwHeader::len() + len;
let mut pack_len = UdpgwHeader::static_len() + len;
let packet = &mut stream.send_buf;
match domain {
Some(domain) => {
let addr_port = match remote_addr.into() {
UdpgwAddr::IPV4(addr_ipv4) => addr_ipv4.addr_port,
UdpgwAddr::IPV6(addr_ipv6) => addr_ipv6.addr_port,
};
let addr_port = remote_addr.port();
let domain_len = domain.len();
if domain_len > 255 {
return Err("InvalidDomain".into());
@ -642,26 +667,28 @@ impl UdpGwClient {
packet.push(0);
packet.extend_from_slice(&data[..len]);
}
None => match remote_addr.into() {
UdpgwAddr::IPV4(addr_ipv4) => {
pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
None => match remote_addr {
SocketAddr::V4(_) => {
let addr_ipv4 = BinSocketAddr::from(remote_addr);
pack_len += addr_ipv4.len();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[UDPGW_FLAG_IPV4]);
packet.extend_from_slice(&conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
let addr_ipv4_bin: Vec<u8> = addr_ipv4.into();
packet.extend_from_slice(&addr_ipv4_bin);
packet.extend_from_slice(&data[..len]);
}
UdpgwAddr::IPV6(addr_ipv6) => {
SocketAddr::V6(_) => {
if !ipv6_enabled {
return Err("ipv6 not support".into());
}
pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
let addr_ipv6 = BinSocketAddr::from(remote_addr);
pack_len += addr_ipv6.len();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[UDPGW_FLAG_IPV6]);
packet.extend_from_slice(&conn_id.to_le_bytes());
packet.extend_from_slice(&addr_ipv6.addr_ip);
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
let addr_ipv6_bin: Vec<u8> = addr_ipv6.into();
packet.extend_from_slice(&addr_ipv6_bin);
packet.extend_from_slice(&data[..len]);
}
},