mirror of
https://github.com/tun2proxy/tun2proxy.git
synced 2025-04-19 21:39:09 +00:00
refactor udpgw with Packet struct
This commit is contained in:
parent
f049e4a998
commit
ec79d54a9e
4 changed files with 241 additions and 127 deletions
|
@ -117,7 +117,7 @@ pub struct Args {
|
|||
#[arg(long, value_name = "IP:PORT")]
|
||||
pub udpgw_server: Option<SocketAddr>,
|
||||
|
||||
/// Max udpgw connections
|
||||
/// Max udpgw connections, default value is 100
|
||||
#[cfg(feature = "udpgw")]
|
||||
#[arg(long, value_name = "number")]
|
||||
pub udpgw_max_connections: Option<u16>,
|
||||
|
|
|
@ -18,7 +18,7 @@ pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::D
|
|||
struct UdpRequest {
|
||||
flags: u8,
|
||||
server_addr: SocketAddr,
|
||||
conid: u16,
|
||||
conn_id: u16,
|
||||
data: Vec<u8>,
|
||||
}
|
||||
|
||||
|
@ -41,26 +41,26 @@ impl Client {
|
|||
|
||||
#[derive(Debug, Clone, clap::Parser)]
|
||||
pub struct UdpGwArgs {
|
||||
/// UDP gateway listen address
|
||||
#[arg(short, long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")]
|
||||
pub listen_addr: SocketAddr,
|
||||
|
||||
/// UDP mtu
|
||||
#[arg(long, value_name = "udp mtu", default_value = "10240")]
|
||||
#[arg(short = 'm', long, value_name = "udp mtu", default_value = "10240")]
|
||||
pub udp_mtu: u16,
|
||||
|
||||
/// Verbosity level
|
||||
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
|
||||
pub verbosity: ArgVerbosity,
|
||||
/// UDP timeout in seconds
|
||||
#[arg(short = 't', long, value_name = "seconds", default_value = "3")]
|
||||
pub udp_timeout: u64,
|
||||
|
||||
/// Daemonize for unix family or run as Windows service
|
||||
#[cfg(unix)]
|
||||
#[arg(long)]
|
||||
pub daemonize: bool,
|
||||
|
||||
/// UDP timeout in seconds
|
||||
#[arg(long, value_name = "seconds", default_value = "3")]
|
||||
pub udp_timeout: u64,
|
||||
|
||||
/// UDP gateway listen address
|
||||
#[arg(long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")]
|
||||
pub listen_addr: SocketAddr,
|
||||
/// Verbosity level
|
||||
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
|
||||
pub verbosity: ArgVerbosity,
|
||||
}
|
||||
|
||||
impl UdpGwArgs {
|
||||
|
@ -70,42 +70,43 @@ impl UdpGwArgs {
|
|||
Self::parse()
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_error(tx: Sender<Vec<u8>>, con: &mut UdpRequest) {
|
||||
let error_packet = UdpgwHeader::new(UDPGW_FLAG_ERR, con.conid).into();
|
||||
let error_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, con.conn_id), vec![]).into();
|
||||
if let Err(e) = tx.send(error_packet).await {
|
||||
log::error!("send error response error {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_keepalive_response(tx: Sender<Vec<u8>>, conid: u16) {
|
||||
let keepalive_packet = UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conid).into();
|
||||
async fn send_keepalive_response(tx: Sender<Vec<u8>>, conn_id: u16) {
|
||||
let keepalive_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), vec![]).into();
|
||||
if let Err(e) = tx.send(keepalive_packet).await {
|
||||
log::error!("send keepalive response error {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
|
||||
if data_len < std::mem::size_of::<UdpgwHeader>() {
|
||||
if data_len < UdpgwHeader::len() {
|
||||
return Err("Invalid udpgw data".into());
|
||||
}
|
||||
let header_bytes = &data[..std::mem::size_of::<UdpgwHeader>()];
|
||||
let header_bytes = &data[..UdpgwHeader::len()];
|
||||
let header = UdpgwHeader {
|
||||
flags: header_bytes[0],
|
||||
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
|
||||
conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
|
||||
};
|
||||
|
||||
let flags = header.flags;
|
||||
let conid = header.conid;
|
||||
let conn_id = header.conn_id;
|
||||
|
||||
// keepalive
|
||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||
return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
|
||||
return Ok((data, flags, conn_id, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
|
||||
}
|
||||
|
||||
let ip_data = &data[std::mem::size_of::<UdpgwHeader>()..];
|
||||
let mut data_len = data_len - std::mem::size_of::<UdpgwHeader>();
|
||||
let ip_data = &data[UdpgwHeader::len()..];
|
||||
let mut data_len = data_len - UdpgwHeader::len();
|
||||
// port_len + min(ipv4/ipv6/(domain_len + 1))
|
||||
if data_len < std::mem::size_of::<u16>() + 2 {
|
||||
if data_len < UDPGW_LENGTH_FIELD_SIZE + 2 {
|
||||
return Err("Invalid udpgw data".into());
|
||||
}
|
||||
if flags & UDPGW_FLAG_DOMAIN != 0 {
|
||||
|
@ -128,7 +129,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
|
|||
return Err("too much data".into());
|
||||
}
|
||||
let udpdata = &ip_data[(2 + domain.len() + 1)..];
|
||||
Ok((udpdata, flags, conid, target))
|
||||
Ok((udpdata, flags, conn_id, target))
|
||||
}
|
||||
Err(_) => Err("Invalid UTF-8 sequence in domain".into()),
|
||||
}
|
||||
|
@ -152,7 +153,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
|
|||
return Ok((
|
||||
&ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
|
||||
flags,
|
||||
conid,
|
||||
conn_id,
|
||||
UdpgwAddr::IPV6(addr_ipv6).into(),
|
||||
));
|
||||
} else {
|
||||
|
@ -173,7 +174,7 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
|
|||
return Ok((
|
||||
&ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
|
||||
flags,
|
||||
conid,
|
||||
conn_id,
|
||||
UdpgwAddr::IPV4(addr_ipv4).into(),
|
||||
));
|
||||
}
|
||||
|
@ -195,13 +196,13 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
|
|||
Ok(ret) => {
|
||||
let (len, _addr) = ret?;
|
||||
let mut packet = vec![];
|
||||
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
|
||||
let mut pack_len = UdpgwHeader::len() + len;
|
||||
match con.server_addr.into() {
|
||||
UdpgwAddr::IPV4(addr_ipv4) => {
|
||||
pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
|
||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||
packet.extend_from_slice(&[con.flags]);
|
||||
packet.extend_from_slice(&con.conid.to_le_bytes());
|
||||
packet.extend_from_slice(&con.conn_id.to_le_bytes());
|
||||
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
|
||||
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
|
||||
packet.extend_from_slice(&con.data[..len]);
|
||||
|
@ -210,7 +211,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
|
|||
pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
|
||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||
packet.extend_from_slice(&[con.flags]);
|
||||
packet.extend_from_slice(&con.conid.to_le_bytes());
|
||||
packet.extend_from_slice(&con.conn_id.to_le_bytes());
|
||||
packet.extend_from_slice(&addr_ipv6.addr_ip);
|
||||
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
|
||||
packet.extend_from_slice(&con.data[..len]);
|
||||
|
@ -230,7 +231,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
|
|||
async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> {
|
||||
let mut client = client;
|
||||
let mut buf = vec![0; args.udp_mtu as usize];
|
||||
let mut len_buf = [0; std::mem::size_of::<PackLenHeader>()];
|
||||
let mut len_buf = [0; UDPGW_LENGTH_FIELD_SIZE];
|
||||
let udp_mtu = args.udp_mtu;
|
||||
let udp_timeout = args.udp_timeout;
|
||||
|
||||
|
@ -250,8 +251,8 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: C
|
|||
// Connection closed
|
||||
break;
|
||||
}
|
||||
if n < std::mem::size_of::<PackLenHeader>() {
|
||||
log::error!("client {} received PackLenHeader error", client.addr);
|
||||
if n < UDPGW_LENGTH_FIELD_SIZE {
|
||||
log::error!("client {} received Packet Length field error", client.addr);
|
||||
break;
|
||||
}
|
||||
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
|
||||
|
@ -272,23 +273,23 @@ async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: C
|
|||
}
|
||||
client.last_activity = std::time::Instant::now();
|
||||
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
|
||||
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
|
||||
if let Ok((udpdata, flags, conn_id, reqaddr)) = ret {
|
||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||
log::debug!("client {} send keepalive", client.addr);
|
||||
send_keepalive_response(tx.clone(), conid).await;
|
||||
send_keepalive_response(tx.clone(), conn_id).await;
|
||||
continue;
|
||||
}
|
||||
log::debug!(
|
||||
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
|
||||
"client {} received udp data,flags:{},conn_id:{},addr:{:?},data len:{}",
|
||||
client.addr,
|
||||
flags,
|
||||
conid,
|
||||
conn_id,
|
||||
reqaddr,
|
||||
udpdata.len()
|
||||
);
|
||||
let mut req = UdpRequest {
|
||||
server_addr: reqaddr,
|
||||
conid,
|
||||
conn_id,
|
||||
flags,
|
||||
data: udpdata.to_vec(),
|
||||
};
|
||||
|
|
12
src/lib.rs
12
src/lib.rs
|
@ -513,12 +513,8 @@ async fn handle_udp_gateway_session(
|
|||
let tcp_local_addr = server_stream.local_addr().clone();
|
||||
|
||||
match domain_name {
|
||||
Some(ref d) => {
|
||||
log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d);
|
||||
}
|
||||
None => {
|
||||
log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr);
|
||||
}
|
||||
Some(ref d) => log::info!("Beginning {} <- {}, domain:{}", udpinfo, &tcp_local_addr, d),
|
||||
None => log::info!("Beginning {} <- {}", udpinfo, &tcp_local_addr),
|
||||
}
|
||||
|
||||
let Some(mut stream_reader) = server_stream.get_reader() else {
|
||||
|
@ -547,8 +543,8 @@ async fn handle_udp_gateway_session(
|
|||
break;
|
||||
}
|
||||
}
|
||||
let newid = server_stream.newid();
|
||||
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), newid, &mut stream_writer).await {
|
||||
let new_id = server_stream.new_id();
|
||||
if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), new_id, &mut stream_writer).await {
|
||||
log::info!("Ending {} <- {} with send_udpgw_packet {}", udpinfo, &tcp_local_addr, e);
|
||||
break;
|
||||
}
|
||||
|
|
279
src/udpgw.rs
279
src/udpgw.rs
|
@ -1,47 +1,171 @@
|
|||
use crate::error::Result;
|
||||
use ipstack::stream::IpStackUdpStream;
|
||||
use socks5_impl::protocol::{AsyncStreamOperation, BufMut, StreamOperation};
|
||||
use std::collections::VecDeque;
|
||||
use std::hash::Hash;
|
||||
use std::mem;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::sync::atomic::Ordering::Relaxed;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
},
|
||||
sync::Mutex,
|
||||
time::{sleep, Duration},
|
||||
};
|
||||
|
||||
pub const UDPGW_MAX_CONNECTIONS: usize = 100;
|
||||
pub const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
|
||||
pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01;
|
||||
pub const UDPGW_FLAG_IPV4: u8 = 0x00;
|
||||
pub const UDPGW_FLAG_IPV6: u8 = 0x08;
|
||||
pub const UDPGW_FLAG_DOMAIN: u8 = 0x10;
|
||||
pub const UDPGW_FLAG_ERR: u8 = 0x20;
|
||||
|
||||
pub const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::<u16>();
|
||||
|
||||
static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[repr(C)]
|
||||
#[repr(packed(1))]
|
||||
pub struct PackLenHeader {
|
||||
packet_len: u16,
|
||||
pub struct Packet {
|
||||
pub length: u16,
|
||||
pub header: UdpgwHeader,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
impl From<Packet> for Vec<u8> {
|
||||
fn from(packet: Packet) -> Vec<u8> {
|
||||
(&packet).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Packet> for Vec<u8> {
|
||||
fn from(packet: &Packet) -> Vec<u8> {
|
||||
let mut bytes = vec![0; packet.len()];
|
||||
packet.write_to_buf(&mut bytes);
|
||||
bytes
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for Packet {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
|
||||
if value.len() < UDPGW_LENGTH_FIELD_SIZE {
|
||||
return Err(std::io::ErrorKind::InvalidData.into());
|
||||
}
|
||||
let length = u16::from_le_bytes([value[0], value[1]]);
|
||||
if value.len() < length as usize + UDPGW_LENGTH_FIELD_SIZE {
|
||||
return Err(std::io::ErrorKind::InvalidData.into());
|
||||
}
|
||||
let header = UdpgwHeader::try_from(&value[UDPGW_LENGTH_FIELD_SIZE..])?;
|
||||
let data = value[UDPGW_LENGTH_FIELD_SIZE + header.len()..].to_vec();
|
||||
Ok(Packet::new(header, data))
|
||||
}
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn new(header: UdpgwHeader, data: Vec<u8>) -> Self {
|
||||
let length = (header.len() + data.len()) as u16;
|
||||
Packet { length, header, data }
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamOperation for Packet {
|
||||
fn retrieve_from_stream<R>(stream: &mut R) -> std::io::Result<Self>
|
||||
where
|
||||
R: std::io::Read,
|
||||
Self: Sized,
|
||||
{
|
||||
let mut buf = [0; UDPGW_LENGTH_FIELD_SIZE];
|
||||
stream.read_exact(&mut buf)?;
|
||||
let length = u16::from_le_bytes(buf);
|
||||
let mut buf = [0; UdpgwHeader::len()];
|
||||
stream.read_exact(&mut buf)?;
|
||||
let header = UdpgwHeader::try_from(&buf[..])?;
|
||||
let mut data = vec![0; length as usize - header.len()];
|
||||
stream.read_exact(&mut data)?;
|
||||
Ok(Packet::new(header, data))
|
||||
}
|
||||
|
||||
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
|
||||
buf.put_u16_le(self.length);
|
||||
self.header.write_to_buf(buf);
|
||||
buf.put_slice(&self.data);
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
UDPGW_LENGTH_FIELD_SIZE + self.header.len() + self.data.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncStreamOperation for Packet {
|
||||
async fn retrieve_from_async_stream<R>(r: &mut R) -> std::io::Result<Self>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send,
|
||||
Self: Sized,
|
||||
{
|
||||
let mut buf = [0; 2];
|
||||
r.read_exact(&mut buf).await?;
|
||||
let length = u16::from_le_bytes(buf);
|
||||
let header = UdpgwHeader::retrieve_from_async_stream(r).await?;
|
||||
let mut data = vec![0; length as usize - header.len()];
|
||||
r.read_exact(&mut data).await?;
|
||||
Ok(Packet::new(header, data))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[repr(C)]
|
||||
#[repr(packed(1))]
|
||||
pub struct UdpgwHeader {
|
||||
pub flags: u8,
|
||||
pub conid: u16,
|
||||
pub conn_id: u16,
|
||||
}
|
||||
|
||||
impl StreamOperation for UdpgwHeader {
|
||||
fn retrieve_from_stream<R>(stream: &mut R) -> std::io::Result<Self>
|
||||
where
|
||||
R: std::io::Read,
|
||||
Self: Sized,
|
||||
{
|
||||
let mut buf = [0; UdpgwHeader::len()];
|
||||
stream.read_exact(&mut buf)?;
|
||||
UdpgwHeader::try_from(&buf[..])
|
||||
}
|
||||
|
||||
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
|
||||
let bytes: Vec<u8> = self.into();
|
||||
buf.put_slice(&bytes);
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
Self::len()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncStreamOperation for UdpgwHeader {
|
||||
async fn retrieve_from_async_stream<R>(r: &mut R) -> std::io::Result<Self>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send,
|
||||
Self: Sized,
|
||||
{
|
||||
let mut buf = [0; UdpgwHeader::len()];
|
||||
r.read_exact(&mut buf).await?;
|
||||
UdpgwHeader::try_from(&buf[..])
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpgwHeader {
|
||||
pub fn new(flags: u8, conid: u16) -> Self {
|
||||
UdpgwHeader { flags, conid }
|
||||
pub fn new(flags: u8, conn_id: u16) -> Self {
|
||||
UdpgwHeader { flags, conn_id }
|
||||
}
|
||||
|
||||
pub const fn len() -> usize {
|
||||
std::mem::size_of::<u16>() + std::mem::size_of::<UdpgwHeader>()
|
||||
std::mem::size_of::<UdpgwHeader>()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,25 +174,20 @@ impl TryFrom<&[u8]> for UdpgwHeader {
|
|||
|
||||
fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
|
||||
if value.len() < UdpgwHeader::len() {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UdpgwHeader"));
|
||||
}
|
||||
let len = u16::from_le_bytes([value[0], value[1]]);
|
||||
if len != std::mem::size_of::<UdpgwHeader>() as u16 {
|
||||
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UdpgwHeader"));
|
||||
return Err(std::io::ErrorKind::InvalidData.into());
|
||||
}
|
||||
Ok(UdpgwHeader {
|
||||
flags: value[2],
|
||||
conid: u16::from_le_bytes([value[3], value[4]]),
|
||||
flags: value[0],
|
||||
conn_id: u16::from_le_bytes([value[1], value[2]]),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UdpgwHeader> for Vec<u8> {
|
||||
fn from(header: UdpgwHeader) -> Vec<u8> {
|
||||
let mut bytes = vec![0; UdpgwHeader::len()];
|
||||
bytes[0..2].copy_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
|
||||
bytes[2] = header.flags;
|
||||
bytes[3..5].copy_from_slice(&header.conid.to_le_bytes());
|
||||
impl From<&UdpgwHeader> for Vec<u8> {
|
||||
fn from(header: &UdpgwHeader) -> Vec<u8> {
|
||||
let mut bytes = vec![0; header.len()];
|
||||
bytes[0] = header.flags;
|
||||
bytes[1..3].copy_from_slice(&header.conn_id.to_le_bytes());
|
||||
bytes
|
||||
}
|
||||
}
|
||||
|
@ -127,9 +246,10 @@ impl From<UdpgwAddr> for SocketAddr {
|
|||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UdpGwData<'a> {
|
||||
flags: u8,
|
||||
conid: u16,
|
||||
conn_id: u16,
|
||||
remote_addr: SocketAddr,
|
||||
udpdata: &'a [u8],
|
||||
}
|
||||
|
@ -141,6 +261,7 @@ impl<'a> UdpGwData<'a> {
|
|||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum UdpGwResponse<'a> {
|
||||
KeepAlive,
|
||||
Error,
|
||||
|
@ -166,7 +287,7 @@ pub(crate) struct UdpGwClientStream {
|
|||
local_addr: String,
|
||||
writer: Option<UdpGwClientStreamWriter>,
|
||||
reader: Option<UdpGwClientStreamReader>,
|
||||
conid: u16,
|
||||
conn_id: u16,
|
||||
closed: bool,
|
||||
last_activity: std::time::Instant,
|
||||
}
|
||||
|
@ -210,12 +331,12 @@ impl UdpGwClientStream {
|
|||
}
|
||||
|
||||
pub fn id(&mut self) -> u16 {
|
||||
self.conid
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
pub fn newid(&mut self) -> u16 {
|
||||
self.conid += 1;
|
||||
self.conid
|
||||
pub fn new_id(&mut self) -> u16 {
|
||||
self.conn_id += 1;
|
||||
self.conn_id
|
||||
}
|
||||
pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self {
|
||||
let local_addr = tcp_server_stream
|
||||
|
@ -239,7 +360,7 @@ impl UdpGwClientStream {
|
|||
writer: Some(writer),
|
||||
last_activity: std::time::Instant::now(),
|
||||
closed: false,
|
||||
conid: 0,
|
||||
conn_id: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -257,7 +378,7 @@ pub(crate) struct UdpGwClient {
|
|||
|
||||
impl UdpGwClient {
|
||||
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self {
|
||||
let keepalive_packet = UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0).into();
|
||||
let keepalive_packet: Vec<u8> = Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, 0), vec![]).into();
|
||||
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize));
|
||||
UdpGwClient {
|
||||
udp_mtu,
|
||||
|
@ -326,14 +447,10 @@ impl UdpGwClient {
|
|||
let Some(mut stream_writer) = stream.get_writer() else {
|
||||
continue;
|
||||
};
|
||||
log::debug!("{:?}:{} send keepalive", stream_writer.inner.local_addr(), stream.id());
|
||||
let local_addr = stream_writer.inner.local_addr();
|
||||
log::debug!("{:?}:{} send keepalive", local_addr, stream.id());
|
||||
if let Err(e) = stream_writer.inner.write_all(&self.keepalive_packet).await {
|
||||
log::warn!(
|
||||
"{:?}:{} send keepalive failed: {}",
|
||||
stream_writer.inner.local_addr(),
|
||||
stream.id(),
|
||||
e
|
||||
);
|
||||
log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e);
|
||||
} else {
|
||||
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await {
|
||||
Ok(UdpGwResponse::KeepAlive) => {
|
||||
|
@ -341,10 +458,8 @@ impl UdpGwClient {
|
|||
self.release_server_connection_with_stream(stream, stream_reader, stream_writer)
|
||||
.await;
|
||||
}
|
||||
//shoud not receive other type
|
||||
_ => {
|
||||
log::warn!("{:?}:{} keepalive no response", stream_writer.inner.local_addr(), stream.id());
|
||||
}
|
||||
Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v),
|
||||
Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -354,20 +469,20 @@ impl UdpGwClient {
|
|||
/// Parses the UDP response data.
|
||||
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, stream: &mut UdpGwClientStreamReader) -> Result<UdpGwResponse> {
|
||||
let data = &stream.recv_buf;
|
||||
if data_len < mem::size_of::<UdpgwHeader>() {
|
||||
if data_len < UdpgwHeader::len() {
|
||||
return Err("Invalid udpgw data".into());
|
||||
}
|
||||
let header_bytes = &data[..mem::size_of::<UdpgwHeader>()];
|
||||
let header_bytes = &data[..UdpgwHeader::len()];
|
||||
let header = UdpgwHeader {
|
||||
flags: header_bytes[0],
|
||||
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
|
||||
conn_id: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
|
||||
};
|
||||
|
||||
let flags = header.flags;
|
||||
let conid = header.conid;
|
||||
let conn_id = header.conn_id;
|
||||
|
||||
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
||||
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
||||
let ip_data = &data[UdpgwHeader::len()..];
|
||||
let mut data_len = data_len - UdpgwHeader::len();
|
||||
|
||||
if flags & UDPGW_FLAG_ERR != 0 {
|
||||
return Ok(UdpGwResponse::Error);
|
||||
|
@ -378,44 +493,44 @@ impl UdpGwClient {
|
|||
}
|
||||
|
||||
if flags & UDPGW_FLAG_IPV6 != 0 {
|
||||
if data_len < mem::size_of::<UdpgwAddrIpv6>() {
|
||||
if data_len < std::mem::size_of::<UdpgwAddrIpv6>() {
|
||||
return Err("ipv6 Invalid UDP data".into());
|
||||
}
|
||||
let addr_ipv6_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv6>()];
|
||||
let addr_ipv6_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv6>()];
|
||||
let addr_ipv6 = UdpgwAddrIpv6 {
|
||||
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
|
||||
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
|
||||
};
|
||||
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
||||
data_len -= std::mem::size_of::<UdpgwAddrIpv6>();
|
||||
|
||||
if data_len > udp_mtu as usize {
|
||||
return Err("too much data".into());
|
||||
}
|
||||
return Ok(UdpGwResponse::Data(UdpGwData {
|
||||
flags,
|
||||
conid,
|
||||
conn_id,
|
||||
remote_addr: UdpgwAddr::IPV6(addr_ipv6).into(),
|
||||
udpdata: &ip_data[mem::size_of::<UdpgwAddrIpv6>()..(data_len + mem::size_of::<UdpgwAddrIpv6>())],
|
||||
udpdata: &ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
|
||||
}));
|
||||
} else {
|
||||
if data_len < mem::size_of::<UdpgwAddrIpv4>() {
|
||||
if data_len < std::mem::size_of::<UdpgwAddrIpv4>() {
|
||||
return Err("ipv4 Invalid UDP data".into());
|
||||
}
|
||||
let addr_ipv4_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv4>()];
|
||||
let addr_ipv4_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv4>()];
|
||||
let addr_ipv4 = UdpgwAddrIpv4 {
|
||||
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
|
||||
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
|
||||
};
|
||||
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
||||
data_len -= std::mem::size_of::<UdpgwAddrIpv4>();
|
||||
|
||||
if data_len > udp_mtu as usize {
|
||||
return Err("too much data".into());
|
||||
}
|
||||
return Ok(UdpGwResponse::Data(UdpGwData {
|
||||
flags,
|
||||
conid,
|
||||
conn_id,
|
||||
remote_addr: UdpgwAddr::IPV4(addr_ipv4).into(),
|
||||
udpdata: &ip_data[mem::size_of::<UdpgwAddrIpv4>()..(data_len + mem::size_of::<UdpgwAddrIpv4>())],
|
||||
udpdata: &ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
@ -456,8 +571,8 @@ impl UdpGwClient {
|
|||
if n == 0 {
|
||||
return Ok(UdpGwResponse::TcpClose);
|
||||
}
|
||||
if n < std::mem::size_of::<PackLenHeader>() {
|
||||
return Err("received PackLenHeader error".into());
|
||||
if n < UDPGW_LENGTH_FIELD_SIZE {
|
||||
return Err("received Packet Length field error".into());
|
||||
}
|
||||
let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]);
|
||||
if packet_len > udp_mtu {
|
||||
|
@ -489,7 +604,7 @@ impl UdpGwClient {
|
|||
/// * `len` - Length of the data packet
|
||||
/// * `remote_addr` - Remote address
|
||||
/// * `domain` - Target domain (optional)
|
||||
/// * `conid` - Connection ID
|
||||
/// * `conn_id` - Connection ID
|
||||
/// * `stream` - UDP gateway client writer stream
|
||||
///
|
||||
/// # Returns
|
||||
|
@ -500,30 +615,28 @@ impl UdpGwClient {
|
|||
len: usize,
|
||||
remote_addr: SocketAddr,
|
||||
domain: Option<&String>,
|
||||
conid: u16,
|
||||
conn_id: u16,
|
||||
stream: &mut UdpGwClientStreamWriter,
|
||||
) -> Result<()> {
|
||||
stream.send_buf.clear();
|
||||
let data = &stream.tmp_buf;
|
||||
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
|
||||
let mut pack_len = UdpgwHeader::len() + len;
|
||||
let packet = &mut stream.send_buf;
|
||||
let mut flags = 0;
|
||||
match domain {
|
||||
Some(domain) => {
|
||||
let addr_port = match remote_addr.into() {
|
||||
UdpgwAddr::IPV4(addr_ipv4) => addr_ipv4.addr_port,
|
||||
UdpgwAddr::IPV6(addr_ipv6) => addr_ipv6.addr_port,
|
||||
};
|
||||
pack_len += std::mem::size_of::<u16>();
|
||||
let domain_len = domain.len();
|
||||
if domain_len > 255 {
|
||||
return Err("InvalidDomain".into());
|
||||
}
|
||||
pack_len += UDPGW_LENGTH_FIELD_SIZE;
|
||||
pack_len += domain_len + 1;
|
||||
flags = UDPGW_FLAG_DOMAIN;
|
||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||
packet.extend_from_slice(&[flags]);
|
||||
packet.extend_from_slice(&conid.to_le_bytes());
|
||||
packet.extend_from_slice(&[UDPGW_FLAG_DOMAIN]);
|
||||
packet.extend_from_slice(&conn_id.to_le_bytes());
|
||||
packet.extend_from_slice(&addr_port.to_be_bytes());
|
||||
packet.extend_from_slice(domain.as_bytes());
|
||||
packet.push(0);
|
||||
|
@ -533,8 +646,8 @@ impl UdpGwClient {
|
|||
UdpgwAddr::IPV4(addr_ipv4) => {
|
||||
pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
|
||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||
packet.extend_from_slice(&[flags]);
|
||||
packet.extend_from_slice(&conid.to_le_bytes());
|
||||
packet.extend_from_slice(&[UDPGW_FLAG_IPV4]);
|
||||
packet.extend_from_slice(&conn_id.to_le_bytes());
|
||||
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
|
||||
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
|
||||
packet.extend_from_slice(&data[..len]);
|
||||
|
@ -543,11 +656,10 @@ impl UdpGwClient {
|
|||
if !ipv6_enabled {
|
||||
return Err("ipv6 not support".into());
|
||||
}
|
||||
flags = UDPGW_FLAG_IPV6;
|
||||
pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
|
||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||
packet.extend_from_slice(&[flags]);
|
||||
packet.extend_from_slice(&conid.to_le_bytes());
|
||||
packet.extend_from_slice(&[UDPGW_FLAG_IPV6]);
|
||||
packet.extend_from_slice(&conn_id.to_le_bytes());
|
||||
packet.extend_from_slice(&addr_ipv6.addr_ip);
|
||||
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
|
||||
packet.extend_from_slice(&data[..len]);
|
||||
|
@ -563,13 +675,18 @@ impl UdpGwClient {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::UdpgwHeader;
|
||||
use super::{Packet, UdpgwHeader};
|
||||
use socks5_impl::protocol::StreamOperation;
|
||||
|
||||
#[test]
|
||||
fn test_udpgw_header() {
|
||||
let header = UdpgwHeader::new(0x01, 0x1234);
|
||||
let bytes = Vec::from(header.clone());
|
||||
let header2 = UdpgwHeader::try_from(&bytes[..]).unwrap();
|
||||
let mut bytes: Vec<u8> = vec![];
|
||||
let packet = Packet::new(header, vec![]);
|
||||
packet.write_to_buf(&mut bytes);
|
||||
|
||||
let header2 = Packet::retrieve_from_stream(&mut bytes.as_slice()).unwrap().header;
|
||||
|
||||
assert_eq!(header, header2);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue