mirror of
https://github.com/tun2proxy/tun2proxy.git
synced 2025-04-19 13:29:09 +00:00
support udp gateway mode
This commit is contained in:
parent
fe32a65291
commit
aee8e14a22
6 changed files with 965 additions and 1 deletions
|
@ -65,5 +65,9 @@ serde_json = "1"
|
||||||
name = "tun2proxy-bin"
|
name = "tun2proxy-bin"
|
||||||
path = "src/bin/main.rs"
|
path = "src/bin/main.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "udpgwserver"
|
||||||
|
path = "src/bin/udpgw_server.rs"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
strip = "symbols"
|
strip = "symbols"
|
||||||
|
|
10
src/args.rs
10
src/args.rs
|
@ -74,6 +74,10 @@ pub struct Args {
|
||||||
#[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")]
|
#[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")]
|
||||||
pub dns: ArgDns,
|
pub dns: ArgDns,
|
||||||
|
|
||||||
|
/// UDP gateway address
|
||||||
|
#[arg(long, value_name = "IP:PORT")]
|
||||||
|
pub udpgw_bind_addr: Option<SocketAddr>,
|
||||||
|
|
||||||
/// DNS resolver address
|
/// DNS resolver address
|
||||||
#[arg(long, value_name = "IP", default_value = "8.8.8.8")]
|
#[arg(long, value_name = "IP", default_value = "8.8.8.8")]
|
||||||
pub dns_addr: IpAddr,
|
pub dns_addr: IpAddr,
|
||||||
|
@ -136,6 +140,7 @@ impl Default for Args {
|
||||||
admin_command: Vec::new(),
|
admin_command: Vec::new(),
|
||||||
ipv6_enabled: false,
|
ipv6_enabled: false,
|
||||||
setup,
|
setup,
|
||||||
|
udpgw_bind_addr: None,
|
||||||
dns: ArgDns::default(),
|
dns: ArgDns::default(),
|
||||||
dns_addr: "8.8.8.8".parse().unwrap(),
|
dns_addr: "8.8.8.8".parse().unwrap(),
|
||||||
bypass: vec![],
|
bypass: vec![],
|
||||||
|
@ -171,6 +176,11 @@ impl Args {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn udpgw(&mut self, udpgw: SocketAddr) -> &mut Self {
|
||||||
|
self.udpgw_bind_addr = Some(udpgw);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
pub fn tun_fd(&mut self, tun_fd: Option<i32>) -> &mut Self {
|
pub fn tun_fd(&mut self, tun_fd: Option<i32>) -> &mut Self {
|
||||||
self.tun_fd = tun_fd;
|
self.tun_fd = tun_fd;
|
||||||
|
|
355
src/bin/udpgw_server.rs
Normal file
355
src/bin/udpgw_server.rs
Normal file
|
@ -0,0 +1,355 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::mem;
|
||||||
|
use std::net::Ipv4Addr;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::net::SocketAddrV4;
|
||||||
|
use std::net::ToSocketAddrs;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::tcp::ReadHalf;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::net::UdpSocket;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::sync::mpsc::Sender;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
pub use tun2proxy::udpgw::*;
|
||||||
|
use tun2proxy::ArgVerbosity;
|
||||||
|
use tun2proxy::Result;
|
||||||
|
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60);
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Connection {
|
||||||
|
flags: u8,
|
||||||
|
server_addr: SocketAddr,
|
||||||
|
conid: u16,
|
||||||
|
data: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Client {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
addr: SocketAddr,
|
||||||
|
buf: Vec<u8>,
|
||||||
|
connections: Arc<Mutex<HashMap<u16, Connection>>>,
|
||||||
|
last_activity: std::time::Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, clap::Parser)]
|
||||||
|
pub struct UdpGwArgs {
|
||||||
|
/// UDP mtu
|
||||||
|
#[arg(long, value_name = "udp mtu", default_value = "10240")]
|
||||||
|
pub udp_mtu: u16,
|
||||||
|
|
||||||
|
/// Verbosity level
|
||||||
|
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
|
||||||
|
pub verbosity: ArgVerbosity,
|
||||||
|
|
||||||
|
/// UDP timeout in seconds
|
||||||
|
#[arg(long, value_name = "seconds", default_value = "3")]
|
||||||
|
pub udp_timeout: u64,
|
||||||
|
|
||||||
|
/// UDP gateway listen address
|
||||||
|
#[arg(long, value_name = "IP:PORT", default_value = "127.0.0.1:7300")]
|
||||||
|
pub listen_addr: SocketAddr,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpGwArgs {
|
||||||
|
#[allow(clippy::let_and_return)]
|
||||||
|
pub fn parse_args() -> Self {
|
||||||
|
use clap::Parser;
|
||||||
|
let args = Self::parse();
|
||||||
|
args
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async fn send_error_response(tx: Sender<Vec<u8>>, con: &mut Connection) {
|
||||||
|
let mut error_packet = vec![];
|
||||||
|
error_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
|
||||||
|
error_packet.extend_from_slice(&[UDPGW_FLAG_ERR]);
|
||||||
|
error_packet.extend_from_slice(&con.conid.to_le_bytes());
|
||||||
|
if let Err(e) = tx.send(error_packet).await {
|
||||||
|
log::error!("send error response error {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse_udp_req_data(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
|
||||||
|
if data_len < mem::size_of::<UdpgwHeader>() {
|
||||||
|
return Err("Invalid udpgw data".into());
|
||||||
|
}
|
||||||
|
let header_bytes = &data[..mem::size_of::<UdpgwHeader>()];
|
||||||
|
let header = UdpgwHeader {
|
||||||
|
flags: header_bytes[0],
|
||||||
|
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
|
||||||
|
};
|
||||||
|
|
||||||
|
let flags = header.flags;
|
||||||
|
let conid = header.conid;
|
||||||
|
|
||||||
|
// keepalive
|
||||||
|
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||||
|
return Ok((data, UDPGW_FLAG_KEEPALIVE, 0, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse address
|
||||||
|
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
||||||
|
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
||||||
|
// port_len + min(ipv4/ipv6/(domain_len + 1))
|
||||||
|
if data_len < mem::size_of::<u16>() + 2 {
|
||||||
|
return Err("Invalid udpgw data".into());
|
||||||
|
}
|
||||||
|
if flags & UDPGW_FLAG_DOMAIN != 0 {
|
||||||
|
let addr_port = u16::from_be_bytes([ip_data[0], ip_data[1]]);
|
||||||
|
data_len -= 2;
|
||||||
|
if let Some(end) = ip_data.iter().skip(2).position(|&x| x == 0) {
|
||||||
|
let domain_slice = &ip_data[2..end + 2];
|
||||||
|
match std::str::from_utf8(domain_slice) {
|
||||||
|
Ok(domain) => {
|
||||||
|
let target_str = format!("{}:{}", domain, addr_port);
|
||||||
|
let target = target_str
|
||||||
|
.to_socket_addrs()?
|
||||||
|
.next()
|
||||||
|
.ok_or(format!("Invalid address {}", target_str))?;
|
||||||
|
// check payload length
|
||||||
|
if data_len < 2 + domain.len() {
|
||||||
|
return Err("Invalid udpgw data".into());
|
||||||
|
}
|
||||||
|
data_len -= domain.len() + 1;
|
||||||
|
if data_len > udp_mtu as usize {
|
||||||
|
return Err("too much data".into());
|
||||||
|
}
|
||||||
|
let udpdata = &ip_data[(2 + domain.len() + 1)..];
|
||||||
|
return Ok((udpdata, flags, conid, target));
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
return Err("Invalid UTF-8 sequence in domain".into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err("missing domain name".into());
|
||||||
|
}
|
||||||
|
} else if flags & UDPGW_FLAG_IPV6 != 0 {
|
||||||
|
if data_len < mem::size_of::<UdpgwAddrIpv6>() {
|
||||||
|
return Err("Ipv6 Invalid UDP data".into());
|
||||||
|
}
|
||||||
|
let addr_ipv6_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv6>()];
|
||||||
|
let addr_ipv6 = UdpgwAddrIpv6 {
|
||||||
|
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
|
||||||
|
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
|
||||||
|
};
|
||||||
|
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
||||||
|
|
||||||
|
// check payload length
|
||||||
|
if data_len > udp_mtu as usize {
|
||||||
|
return Err("too much data".into());
|
||||||
|
}
|
||||||
|
return Ok((
|
||||||
|
&ip_data[mem::size_of::<UdpgwAddrIpv6>()..(data_len + mem::size_of::<UdpgwAddrIpv6>())],
|
||||||
|
flags,
|
||||||
|
conid,
|
||||||
|
UdpgwAddr::IPV6(addr_ipv6).into(),
|
||||||
|
));
|
||||||
|
} else {
|
||||||
|
if data_len < mem::size_of::<UdpgwAddrIpv4>() {
|
||||||
|
return Err("Ipv4 Invalid UDP data".into());
|
||||||
|
}
|
||||||
|
let addr_ipv4_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv4>()];
|
||||||
|
let addr_ipv4 = UdpgwAddrIpv4 {
|
||||||
|
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
|
||||||
|
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
|
||||||
|
};
|
||||||
|
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
||||||
|
|
||||||
|
// check payload length
|
||||||
|
if data_len > udp_mtu as usize {
|
||||||
|
return Err("too much data".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok((
|
||||||
|
&ip_data[mem::size_of::<UdpgwAddrIpv4>()..(data_len + mem::size_of::<UdpgwAddrIpv4>())],
|
||||||
|
flags,
|
||||||
|
conid,
|
||||||
|
UdpgwAddr::IPV4(addr_ipv4).into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, con: &mut Connection) -> Result<()> {
|
||||||
|
let std_sock = std::net::UdpSocket::bind("0.0.0.0:0")?;
|
||||||
|
std_sock.set_nonblocking(true)?;
|
||||||
|
nix::sys::socket::setsockopt(&std_sock, nix::sys::socket::sockopt::ReuseAddr, &true)?;
|
||||||
|
let socket = UdpSocket::from_std(std_sock)?;
|
||||||
|
socket.send_to(&con.data, &con.server_addr).await?;
|
||||||
|
con.data.resize(2048, 0);
|
||||||
|
match tokio::time::timeout(tokio::time::Duration::from_secs(udp_timeout), socket.recv_from(&mut con.data[..])).await? {
|
||||||
|
Ok((len, _addr)) => {
|
||||||
|
let mut packet = vec![];
|
||||||
|
let mut pack_len = mem::size_of::<UdpgwHeader>() + len;
|
||||||
|
match con.server_addr.into() {
|
||||||
|
UdpgwAddr::IPV4(addr_ipv4) => {
|
||||||
|
pack_len += mem::size_of::<UdpgwAddrIpv4>();
|
||||||
|
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||||
|
packet.extend_from_slice(&[con.flags]);
|
||||||
|
packet.extend_from_slice(&con.conid.to_le_bytes());
|
||||||
|
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
|
||||||
|
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
|
||||||
|
packet.extend_from_slice(&con.data[..len]);
|
||||||
|
}
|
||||||
|
UdpgwAddr::IPV6(addr_ipv6) => {
|
||||||
|
pack_len += mem::size_of::<UdpgwAddrIpv6>();
|
||||||
|
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||||
|
packet.extend_from_slice(&[con.flags]);
|
||||||
|
packet.extend_from_slice(&con.conid.to_le_bytes());
|
||||||
|
packet.extend_from_slice(&addr_ipv6.addr_ip);
|
||||||
|
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
|
||||||
|
packet.extend_from_slice(&con.data[..len]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Err(e) = tx.send(packet).await {
|
||||||
|
log::error!("client {} send udp response error {:?}", addr, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("client {} udp recv_from error: {:?}", addr, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn process_client_udp_req<'a>(args: Arc<UdpGwArgs>, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) {
|
||||||
|
let mut buf = vec![0; args.udp_mtu as usize];
|
||||||
|
let mut len_buf = [0; mem::size_of::<PackLenHeader>()];
|
||||||
|
let udp_mtu = args.udp_mtu;
|
||||||
|
let udp_timeout = args.udp_timeout;
|
||||||
|
'out: loop {
|
||||||
|
let result;
|
||||||
|
match tokio::time::timeout(tokio::time::Duration::from_secs(2), tcp_read_stream.read(&mut len_buf)).await {
|
||||||
|
Ok(ret) => {
|
||||||
|
result = ret;
|
||||||
|
}
|
||||||
|
Err(_e) => {
|
||||||
|
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
|
||||||
|
log::warn!("client {} last_activity elapsed", client.addr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match result {
|
||||||
|
Ok(0) => break, // Connection closed
|
||||||
|
Ok(n) => {
|
||||||
|
if n < mem::size_of::<PackLenHeader>() {
|
||||||
|
log::error!("client {} received PackLenHeader error", client.addr);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
|
||||||
|
if packet_len > udp_mtu {
|
||||||
|
log::error!("client {} received packet too long", client.addr);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
log::info!("client {} recvied packet len {}", client.addr, packet_len);
|
||||||
|
buf.resize(packet_len as usize, 0);
|
||||||
|
client.buf.clear();
|
||||||
|
let mut left_len: usize = packet_len as usize;
|
||||||
|
while left_len > 0 {
|
||||||
|
if let Ok(len) = tcp_read_stream.read(&mut buf[..left_len]).await {
|
||||||
|
if len == 0 {
|
||||||
|
break 'out;
|
||||||
|
}
|
||||||
|
client.buf.extend_from_slice(&mut buf[..len]);
|
||||||
|
left_len -= len;
|
||||||
|
} else {
|
||||||
|
break 'out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
client.last_activity = std::time::Instant::now();
|
||||||
|
let ret = parse_udp_req_data(udp_mtu, client.buf.len(), &client.buf);
|
||||||
|
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
|
||||||
|
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||||
|
log::debug!("client {} recvied keepalive packet", client.addr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
log::debug!(
|
||||||
|
"client {} recvied udp data,flags:{},conid:{},addr:{:?},data len:{}",
|
||||||
|
client.addr,
|
||||||
|
flags,
|
||||||
|
conid,
|
||||||
|
reqaddr,
|
||||||
|
udpdata.len()
|
||||||
|
);
|
||||||
|
let mut con_lock = client.connections.lock().await;
|
||||||
|
let con = con_lock.get_mut(&conid);
|
||||||
|
if let Some(conn) = con {
|
||||||
|
conn.data.clear();
|
||||||
|
conn.data.extend_from_slice(udpdata);
|
||||||
|
if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), conn).await {
|
||||||
|
log::error!("client {} process_udp error: {:?}", client.addr, e);
|
||||||
|
send_error_response(tx.clone(), conn).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
drop(con_lock);
|
||||||
|
let mut conn = Connection {
|
||||||
|
server_addr: reqaddr,
|
||||||
|
conid,
|
||||||
|
flags,
|
||||||
|
data: udpdata.to_vec(),
|
||||||
|
};
|
||||||
|
if let Err(e) = process_udp(client.addr, udp_timeout, tx.clone(), &mut conn).await {
|
||||||
|
send_error_response(tx.clone(), &mut conn).await;
|
||||||
|
log::error!("client {} process_udp error: {:?}", client.addr, e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
client.connections.lock().await.insert(conid, conn);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
log::error!("client {} tcp_read_stream error", client.addr);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
let args = Arc::new(UdpGwArgs::parse_args());
|
||||||
|
|
||||||
|
let tcp_listener = TcpListener::bind(args.listen_addr).await?;
|
||||||
|
|
||||||
|
log::info!("UDP GW Server started");
|
||||||
|
|
||||||
|
let default = format!("{:?},hickory_proto=warn", args.verbosity);
|
||||||
|
|
||||||
|
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let (mut tcp_stream, addr) = tcp_listener.accept().await?;
|
||||||
|
let client = Client {
|
||||||
|
addr,
|
||||||
|
buf: vec![],
|
||||||
|
connections: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
last_activity: std::time::Instant::now(),
|
||||||
|
};
|
||||||
|
log::info!("client {} connected", addr);
|
||||||
|
let params = args.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
|
||||||
|
let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split();
|
||||||
|
tokio::select! {
|
||||||
|
_ = process_client_udp_req(params, tx, client, tcp_read_stream) =>{}
|
||||||
|
_ = async {
|
||||||
|
loop
|
||||||
|
{
|
||||||
|
if let Some(udp_response) = rx.recv().await {
|
||||||
|
log::info!("client {} send udp data len:{}", addr, udp_response.len(),);
|
||||||
|
let _ = tcp_write_stream.write(&udp_response).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} => {}
|
||||||
|
}
|
||||||
|
log::info!("client {} disconnected", addr);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
|
@ -47,6 +47,9 @@ pub enum Error {
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
#[error("bincode::Error {0:?}")]
|
#[error("bincode::Error {0:?}")]
|
||||||
BincodeError(#[from] bincode::Error),
|
BincodeError(#[from] bincode::Error),
|
||||||
|
|
||||||
|
#[error("tokio::time::error::Elapsed")]
|
||||||
|
Timeout(#[from] tokio::time::error::Elapsed),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<&str> for Error {
|
impl From<&str> for Error {
|
||||||
|
|
145
src/lib.rs
145
src/lib.rs
|
@ -3,6 +3,7 @@ use crate::{
|
||||||
http::HttpManager,
|
http::HttpManager,
|
||||||
no_proxy::NoProxyManager,
|
no_proxy::NoProxyManager,
|
||||||
session_info::{IpProtocol, SessionInfo},
|
session_info::{IpProtocol, SessionInfo},
|
||||||
|
udpgw::UdpGwClient,
|
||||||
virtual_dns::VirtualDns,
|
virtual_dns::VirtualDns,
|
||||||
};
|
};
|
||||||
use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream};
|
use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream};
|
||||||
|
@ -12,7 +13,7 @@ pub use socks5_impl::protocol::UserKey;
|
||||||
use std::{
|
use std::{
|
||||||
collections::VecDeque,
|
collections::VecDeque,
|
||||||
io::ErrorKind,
|
io::ErrorKind,
|
||||||
net::{IpAddr, SocketAddr},
|
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
|
@ -23,6 +24,7 @@ use tokio::{
|
||||||
pub use tokio_util::sync::CancellationToken;
|
pub use tokio_util::sync::CancellationToken;
|
||||||
use tproxy_config::is_private_ip;
|
use tproxy_config::is_private_ip;
|
||||||
use udp_stream::UdpStream;
|
use udp_stream::UdpStream;
|
||||||
|
use udpgw::{UdpGwClientStream, UdpGwResponse, UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS};
|
||||||
|
|
||||||
pub use {
|
pub use {
|
||||||
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
|
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
|
||||||
|
@ -59,6 +61,7 @@ mod session_info;
|
||||||
pub mod socket_transfer;
|
pub mod socket_transfer;
|
||||||
mod socks;
|
mod socks;
|
||||||
mod traffic_status;
|
mod traffic_status;
|
||||||
|
pub mod udpgw;
|
||||||
mod virtual_dns;
|
mod virtual_dns;
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub mod win_svc;
|
pub mod win_svc;
|
||||||
|
@ -233,6 +236,19 @@ where
|
||||||
|
|
||||||
let mut ip_stack = ipstack::IpStack::new(ipstack_config, device);
|
let mut ip_stack = ipstack::IpStack::new(ipstack_config, device);
|
||||||
|
|
||||||
|
let udpgw_client = match args.udpgw_bind_addr {
|
||||||
|
None => None,
|
||||||
|
Some(addr) => {
|
||||||
|
log::info!("UDPGW enabled");
|
||||||
|
let client = Arc::new(UdpGwClient::new(mtu, UDPGW_MAX_CONNECTIONS, UDPGW_KEEPALIVE_TIME, addr));
|
||||||
|
let client_keepalive = client.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
client_keepalive.heartbeat_task().await;
|
||||||
|
});
|
||||||
|
Some(client)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let virtual_dns = virtual_dns.clone();
|
let virtual_dns = virtual_dns.clone();
|
||||||
let ip_stack_stream = tokio::select! {
|
let ip_stack_stream = tokio::select! {
|
||||||
|
@ -265,6 +281,7 @@ where
|
||||||
if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await {
|
if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await {
|
||||||
log::error!("{} error \"{}\"", info, err);
|
log::error!("{} error \"{}\"", info, err);
|
||||||
}
|
}
|
||||||
|
|
||||||
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
|
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -311,6 +328,24 @@ where
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
if let Some(udpgw) = udpgw_client.clone() {
|
||||||
|
let tcp_src = match udp.peer_addr() {
|
||||||
|
SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
|
||||||
|
SocketAddr::V6(_) => SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), 0, 0, 0)),
|
||||||
|
};
|
||||||
|
let tcpinfo = SessionInfo::new(tcp_src, udpgw.get_udpgw_bind_addr(), IpProtocol::Tcp);
|
||||||
|
let proxy_handler = mgr.new_proxy_handler(tcpinfo, None, false).await?;
|
||||||
|
let socket_queue = socket_queue.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(err) =
|
||||||
|
handle_udp_gateway_session(udp, udpgw, domain_name, proxy_handler, socket_queue, ipv6_enabled).await
|
||||||
|
{
|
||||||
|
log::info!("Ending {} with \"{}\"", info, err);
|
||||||
|
}
|
||||||
|
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
match mgr.new_proxy_handler(info, domain_name, true).await {
|
match mgr.new_proxy_handler(info, domain_name, true).await {
|
||||||
Ok(proxy_handler) => {
|
Ok(proxy_handler) => {
|
||||||
let socket_queue = socket_queue.clone();
|
let socket_queue = socket_queue.clone();
|
||||||
|
@ -429,6 +464,114 @@ async fn handle_tcp_session(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_udp_gateway_session(
|
||||||
|
mut udp_stack: IpStackUdpStream,
|
||||||
|
udpgw_client: Arc<UdpGwClient>,
|
||||||
|
domain_name: Option<String>,
|
||||||
|
proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
|
||||||
|
socket_queue: Option<Arc<SocketQueue>>,
|
||||||
|
ipv6_enabled: bool,
|
||||||
|
) -> crate::Result<()> {
|
||||||
|
let (session_info, server_addr) = {
|
||||||
|
let handler = proxy_handler.lock().await;
|
||||||
|
(handler.get_session_info(), handler.get_server_addr())
|
||||||
|
};
|
||||||
|
let udpinfo = SessionInfo::new(udp_stack.local_addr(), udp_stack.peer_addr(), IpProtocol::Udp);
|
||||||
|
let udp_mtu = udpgw_client.get_udp_mtu();
|
||||||
|
let mut server_stream: UdpGwClientStream;
|
||||||
|
let server = udpgw_client.get_server_connection().await;
|
||||||
|
match server {
|
||||||
|
Some(server) => {
|
||||||
|
server_stream = server;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
log::info!("Beginning {}", session_info);
|
||||||
|
let mut tcp_server_stream = create_tcp_stream(&socket_queue, server_addr).await?;
|
||||||
|
if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
server_stream = UdpGwClientStream::new(udp_mtu, tcp_server_stream);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let udp_server_addr = udp_stack.peer_addr();
|
||||||
|
|
||||||
|
match domain_name {
|
||||||
|
Some(ref d) => {
|
||||||
|
log::info!("Beginning {}, domain:{}", udpinfo, d);
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
log::info!("Beginning {}", udpinfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log::info!("Beginning {}", udpinfo);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let len = UdpGwClient::recv_udp_packet(&mut udp_stack, &mut server_stream).await;
|
||||||
|
let read_len;
|
||||||
|
match len {
|
||||||
|
Ok(n) => {
|
||||||
|
if n == 0 {
|
||||||
|
log::info!("Ending {}", udpinfo);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
read_len = n;
|
||||||
|
crate::traffic_status::traffic_status_update(n, 0)?;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::info!("Ending {} with recv_udp_packet error: {}", udpinfo, e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) =
|
||||||
|
UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), &mut server_stream).await
|
||||||
|
{
|
||||||
|
log::info!(
|
||||||
|
"{:?},Ending {} with send_udpgw_packet error: {}",
|
||||||
|
server_stream.local_addr(),
|
||||||
|
udpinfo,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
match UdpGwClient::recv_udpgw_packet(udp_mtu, &mut server_stream).await {
|
||||||
|
Ok(packet) => match packet {
|
||||||
|
//should not received keepalive
|
||||||
|
UdpGwResponse::KeepAlive => {
|
||||||
|
log::error!("Ending {} with recv keepalive", udpinfo);
|
||||||
|
let _ = server_stream.close().await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
UdpGwResponse::Error => {
|
||||||
|
log::info!("Ending {} with recv udp error", udpinfo);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
UdpGwResponse::Data(data) => {
|
||||||
|
crate::traffic_status::traffic_status_update(0, data.len())?;
|
||||||
|
|
||||||
|
if let Err(e) = UdpGwClient::send_udp_packet(data, &mut udp_stack).await {
|
||||||
|
log::info!("Ending {} with send_udp_packet error: {}", udpinfo, e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
log::info!("Ending {} with recv_udpgw_packet error: {}", udpinfo, e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !server_stream.is_closed() {
|
||||||
|
udpgw_client.release_server_connection(server_stream).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle_udp_associate_session(
|
async fn handle_udp_associate_session(
|
||||||
mut udp_stack: IpStackUdpStream,
|
mut udp_stack: IpStackUdpStream,
|
||||||
proxy_type: ProxyType,
|
proxy_type: ProxyType,
|
||||||
|
|
449
src/udpgw.rs
Normal file
449
src/udpgw.rs
Normal file
|
@ -0,0 +1,449 @@
|
||||||
|
use crate::error::Result;
|
||||||
|
use ipstack::stream::IpStackUdpStream;
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::hash::Hash;
|
||||||
|
use std::mem;
|
||||||
|
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tokio::time::{sleep, Duration};
|
||||||
|
|
||||||
|
pub const UDPGW_MAX_CONNECTIONS: usize = 100;
|
||||||
|
pub const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10);
|
||||||
|
pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01;
|
||||||
|
pub const UDPGW_FLAG_IPV6: u8 = 0x08;
|
||||||
|
pub const UDPGW_FLAG_DOMAIN: u8 = 0x10;
|
||||||
|
pub const UDPGW_FLAG_ERR: u8 = 0x20;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
#[repr(C)]
|
||||||
|
#[repr(packed(1))]
|
||||||
|
pub struct PackLenHeader {
|
||||||
|
packet_len: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
#[repr(C)]
|
||||||
|
#[repr(packed(1))]
|
||||||
|
pub struct UdpgwHeader {
|
||||||
|
pub flags: u8,
|
||||||
|
pub conid: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
#[repr(C)]
|
||||||
|
#[repr(packed(1))]
|
||||||
|
pub struct UdpgwAddrIpv4 {
|
||||||
|
pub addr_ip: u32,
|
||||||
|
pub addr_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
#[repr(C)]
|
||||||
|
#[repr(packed(1))]
|
||||||
|
pub struct UdpgwAddrIpv6 {
|
||||||
|
pub addr_ip: [u8; 16],
|
||||||
|
pub addr_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum UdpgwAddr {
|
||||||
|
IPV4(UdpgwAddrIpv4),
|
||||||
|
IPV6(UdpgwAddrIpv6),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SocketAddr> for UdpgwAddr {
|
||||||
|
fn from(addr: SocketAddr) -> Self {
|
||||||
|
match addr {
|
||||||
|
SocketAddr::V4(addr_v4) => {
|
||||||
|
let ipv4_addr = addr_v4.ip().octets();
|
||||||
|
let addr_ip = u32::from_be_bytes(ipv4_addr);
|
||||||
|
UdpgwAddr::IPV4(UdpgwAddrIpv4 {
|
||||||
|
addr_ip,
|
||||||
|
addr_port: addr_v4.port(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
SocketAddr::V6(addr_v6) => {
|
||||||
|
let ipv6_addr = addr_v6.ip().octets();
|
||||||
|
UdpgwAddr::IPV6(UdpgwAddrIpv6 {
|
||||||
|
addr_ip: ipv6_addr,
|
||||||
|
addr_port: addr_v6.port(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<UdpgwAddr> for SocketAddr {
|
||||||
|
fn from(addr: UdpgwAddr) -> Self {
|
||||||
|
match addr {
|
||||||
|
UdpgwAddr::IPV4(addr_ipv4) => SocketAddrV4::new(Ipv4Addr::from(addr_ipv4.addr_ip), addr_ipv4.addr_port).into(),
|
||||||
|
UdpgwAddr::IPV6(addr_ipv6) => SocketAddrV6::new(Ipv6Addr::from(addr_ipv6.addr_ip), addr_ipv6.addr_port, 0, 0).into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) struct UdpGwData<'a> {
|
||||||
|
flags: u8,
|
||||||
|
conid: u16,
|
||||||
|
remote_addr: SocketAddr,
|
||||||
|
udpdata: &'a [u8],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> UdpGwData<'a> {
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
return self.udpdata.len();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) enum UdpGwResponse<'a> {
|
||||||
|
KeepAlive,
|
||||||
|
Error,
|
||||||
|
Data(UdpGwData<'a>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct UdpGwClientStream {
|
||||||
|
inner: TcpStream,
|
||||||
|
conid: u16,
|
||||||
|
tmp_buf: Vec<u8>,
|
||||||
|
send_buf: Vec<u8>,
|
||||||
|
recv_buf: Vec<u8>,
|
||||||
|
closed: bool,
|
||||||
|
last_activity: std::time::Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for UdpGwClientStream {
|
||||||
|
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<tokio::io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Pin::new(&mut self.inner).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for UdpGwClientStream {
|
||||||
|
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpGwClientStream {
|
||||||
|
pub async fn close(&mut self) -> Result<()> {
|
||||||
|
self.inner.shutdown().await?;
|
||||||
|
self.closed = true;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
pub fn local_addr(&self) -> Result<SocketAddr> {
|
||||||
|
Ok(self.inner.local_addr()?)
|
||||||
|
}
|
||||||
|
pub fn is_closed(&mut self) -> bool {
|
||||||
|
self.closed
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn id(&mut self) -> u16 {
|
||||||
|
self.conid
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn newid(&mut self) -> u16 {
|
||||||
|
let next = self.conid;
|
||||||
|
self.conid += 1;
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
pub fn new(udp_mtu: u16, tcp_server_stream: TcpStream) -> Self {
|
||||||
|
UdpGwClientStream {
|
||||||
|
inner: tcp_server_stream,
|
||||||
|
tmp_buf: vec![0; udp_mtu.into()],
|
||||||
|
send_buf: vec![0; udp_mtu.into()],
|
||||||
|
recv_buf: vec![0; udp_mtu.into()],
|
||||||
|
last_activity: std::time::Instant::now(),
|
||||||
|
closed: false,
|
||||||
|
conid: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct UdpGwClient {
|
||||||
|
udp_mtu: u16,
|
||||||
|
max_connections: usize,
|
||||||
|
keepalive_time: Duration,
|
||||||
|
udpgw_bind_addr: SocketAddr,
|
||||||
|
keepalive_packet: Vec<u8>,
|
||||||
|
server_connections: Mutex<VecDeque<UdpGwClientStream>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpGwClient {
|
||||||
|
pub fn new(udp_mtu: u16, max_connections: usize, keepalive_time: Duration, udpgw_bind_addr: SocketAddr) -> Self {
|
||||||
|
let mut keepalive_packet = vec![];
|
||||||
|
keepalive_packet.extend_from_slice(&(std::mem::size_of::<UdpgwHeader>() as u16).to_le_bytes());
|
||||||
|
keepalive_packet.extend_from_slice(&[UDPGW_FLAG_KEEPALIVE, 0, 0]);
|
||||||
|
let server_connections = Mutex::new(VecDeque::new());
|
||||||
|
return UdpGwClient {
|
||||||
|
udp_mtu,
|
||||||
|
max_connections,
|
||||||
|
udpgw_bind_addr,
|
||||||
|
keepalive_time,
|
||||||
|
keepalive_packet,
|
||||||
|
server_connections: server_connections,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_udp_mtu(&self) -> u16 {
|
||||||
|
self.udp_mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn get_server_connection(&self) -> Option<UdpGwClientStream> {
|
||||||
|
self.server_connections.lock().await.pop_front()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) {
|
||||||
|
if self.server_connections.lock().await.len() < self.max_connections {
|
||||||
|
self.server_connections.lock().await.push_back(stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_udpgw_bind_addr(&self) -> SocketAddr {
|
||||||
|
return self.udpgw_bind_addr;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn heartbeat_task(&self) {
|
||||||
|
loop {
|
||||||
|
sleep(self.keepalive_time).await;
|
||||||
|
if let Some(mut stream) = self.get_server_connection().await {
|
||||||
|
if stream.last_activity.elapsed() < self.keepalive_time {
|
||||||
|
self.release_server_connection(stream).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
log::debug!("{:?}:{} send keepalive", stream.local_addr(), stream.id());
|
||||||
|
if let Err(e) = stream.write_all(&self.keepalive_packet).await {
|
||||||
|
let _ = stream.close().await;
|
||||||
|
log::warn!("{:?}:{} Heartbeat failed: {}", stream.local_addr(), stream.id(), e);
|
||||||
|
} else {
|
||||||
|
stream.last_activity = std::time::Instant::now();
|
||||||
|
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, &mut stream).await {
|
||||||
|
Ok(UdpGwResponse::KeepAlive) => {
|
||||||
|
self.release_server_connection(stream).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
//shoud not receive other
|
||||||
|
_ => {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn parse_udp_response(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<UdpGwResponse> {
|
||||||
|
if data_len < mem::size_of::<UdpgwHeader>() {
|
||||||
|
return Err("Invalid udpgw data".into());
|
||||||
|
}
|
||||||
|
let header_bytes = &data[..mem::size_of::<UdpgwHeader>()];
|
||||||
|
let header = UdpgwHeader {
|
||||||
|
flags: header_bytes[0],
|
||||||
|
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
|
||||||
|
};
|
||||||
|
|
||||||
|
let flags = header.flags;
|
||||||
|
let conid = header.conid;
|
||||||
|
|
||||||
|
// parse address
|
||||||
|
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
||||||
|
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
||||||
|
|
||||||
|
if flags & UDPGW_FLAG_ERR != 0 {
|
||||||
|
return Ok(UdpGwResponse::Error);
|
||||||
|
}
|
||||||
|
|
||||||
|
if flags & UDPGW_FLAG_ERR != 0 {
|
||||||
|
return Ok(UdpGwResponse::KeepAlive);
|
||||||
|
}
|
||||||
|
|
||||||
|
if flags & UDPGW_FLAG_IPV6 != 0 {
|
||||||
|
if data_len < mem::size_of::<UdpgwAddrIpv6>() {
|
||||||
|
return Err("ipv6 Invalid UDP data".into());
|
||||||
|
}
|
||||||
|
let addr_ipv6_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv6>()];
|
||||||
|
let addr_ipv6 = UdpgwAddrIpv6 {
|
||||||
|
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
|
||||||
|
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
|
||||||
|
};
|
||||||
|
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
||||||
|
// check payload length
|
||||||
|
if data_len > udp_mtu as usize {
|
||||||
|
return Err("too much data".into());
|
||||||
|
}
|
||||||
|
return Ok(UdpGwResponse::Data(UdpGwData {
|
||||||
|
flags,
|
||||||
|
conid,
|
||||||
|
remote_addr: UdpgwAddr::IPV6(addr_ipv6).into(),
|
||||||
|
udpdata: &ip_data[mem::size_of::<UdpgwAddrIpv6>()..(data_len + mem::size_of::<UdpgwAddrIpv6>())],
|
||||||
|
}));
|
||||||
|
} else {
|
||||||
|
if data_len < mem::size_of::<UdpgwAddrIpv4>() {
|
||||||
|
return Err("ipv4 Invalid UDP data".into());
|
||||||
|
}
|
||||||
|
let addr_ipv4_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv4>()];
|
||||||
|
let addr_ipv4 = UdpgwAddrIpv4 {
|
||||||
|
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
|
||||||
|
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
|
||||||
|
};
|
||||||
|
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
||||||
|
|
||||||
|
// check payload length
|
||||||
|
if data_len > udp_mtu as usize {
|
||||||
|
return Err("too much data".into());
|
||||||
|
}
|
||||||
|
return Ok(UdpGwResponse::Data(UdpGwData {
|
||||||
|
flags,
|
||||||
|
conid,
|
||||||
|
remote_addr: UdpgwAddr::IPV4(addr_ipv4).into(),
|
||||||
|
udpdata: &ip_data[mem::size_of::<UdpgwAddrIpv4>()..(data_len + mem::size_of::<UdpgwAddrIpv4>())],
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn recv_udp_packet(
|
||||||
|
udp_stack: &mut IpStackUdpStream,
|
||||||
|
stream: &mut UdpGwClientStream,
|
||||||
|
) -> std::result::Result<usize, std::io::Error> {
|
||||||
|
return udp_stack.read(&mut stream.tmp_buf).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn send_udp_packet<'a>(
|
||||||
|
packet: UdpGwData<'a>,
|
||||||
|
udp_stack: &mut IpStackUdpStream,
|
||||||
|
) -> std::result::Result<(), std::io::Error> {
|
||||||
|
return udp_stack.write_all(&packet.udpdata).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn recv_udpgw_packet(udp_mtu: u16, stream: &mut UdpGwClientStream) -> Result<UdpGwResponse> {
|
||||||
|
stream.recv_buf.resize(2, 0);
|
||||||
|
let result;
|
||||||
|
match tokio::time::timeout(tokio::time::Duration::from_secs(10), stream.inner.read(&mut stream.recv_buf)).await {
|
||||||
|
Ok(ret) => {
|
||||||
|
result = ret;
|
||||||
|
}
|
||||||
|
Err(_e) => {
|
||||||
|
let _ = stream.close().await;
|
||||||
|
return Err(format!("{:?} wait tcp data timeout", stream.local_addr()).into());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match result {
|
||||||
|
Ok(0) => {
|
||||||
|
let _ = stream.close().await;
|
||||||
|
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
|
||||||
|
}
|
||||||
|
Ok(n) => {
|
||||||
|
if n < std::mem::size_of::<PackLenHeader>() {
|
||||||
|
return Err("received PackLenHeader error".into());
|
||||||
|
}
|
||||||
|
let packet_len = u16::from_le_bytes([stream.recv_buf[0], stream.recv_buf[1]]);
|
||||||
|
if packet_len > udp_mtu {
|
||||||
|
return Err("packet too long".into());
|
||||||
|
}
|
||||||
|
stream.recv_buf.resize(udp_mtu as usize, 0);
|
||||||
|
let mut left_len: usize = packet_len as usize;
|
||||||
|
let mut recv_len = 0;
|
||||||
|
while left_len > 0 {
|
||||||
|
if let Ok(len) = stream.inner.read(&mut stream.recv_buf[recv_len..left_len]).await {
|
||||||
|
if len == 0 {
|
||||||
|
let _ = stream.close().await;
|
||||||
|
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
|
||||||
|
}
|
||||||
|
recv_len += len;
|
||||||
|
left_len -= len;
|
||||||
|
} else {
|
||||||
|
let _ = stream.close().await;
|
||||||
|
return Err(format!("{:?} tcp connection closed", stream.local_addr()).into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stream.last_activity = std::time::Instant::now();
|
||||||
|
return UdpGwClient::parse_udp_response(udp_mtu, packet_len as usize, &stream.recv_buf);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
let _ = stream.close().await;
|
||||||
|
return Err(format!("{:?} tcp read error", stream.local_addr()).into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn send_udpgw_packet(
|
||||||
|
ipv6_enabled: bool,
|
||||||
|
len: usize,
|
||||||
|
remote_addr: SocketAddr,
|
||||||
|
domain: Option<&String>,
|
||||||
|
stream: &mut UdpGwClientStream,
|
||||||
|
) -> Result<()> {
|
||||||
|
stream.send_buf.clear();
|
||||||
|
let conid = stream.newid();
|
||||||
|
let data = &stream.tmp_buf;
|
||||||
|
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
|
||||||
|
let packet = &mut stream.send_buf;
|
||||||
|
let mut flags = 0;
|
||||||
|
match domain {
|
||||||
|
Some(domain) => {
|
||||||
|
let addr_port = match remote_addr.into() {
|
||||||
|
UdpgwAddr::IPV4(addr_ipv4) => addr_ipv4.addr_port,
|
||||||
|
UdpgwAddr::IPV6(addr_ipv6) => addr_ipv6.addr_port,
|
||||||
|
};
|
||||||
|
pack_len += std::mem::size_of::<u16>();
|
||||||
|
let domain_len = domain.len();
|
||||||
|
if domain_len > 255 {
|
||||||
|
return Err("InvalidDomain".into());
|
||||||
|
}
|
||||||
|
pack_len += domain_len + 1;
|
||||||
|
flags = UDPGW_FLAG_DOMAIN;
|
||||||
|
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||||
|
packet.extend_from_slice(&[flags]);
|
||||||
|
packet.extend_from_slice(&conid.to_le_bytes());
|
||||||
|
packet.extend_from_slice(&addr_port.to_be_bytes());
|
||||||
|
packet.extend_from_slice(domain.as_bytes());
|
||||||
|
packet.push(0);
|
||||||
|
packet.extend_from_slice(&data[..len]);
|
||||||
|
}
|
||||||
|
None => match remote_addr.into() {
|
||||||
|
UdpgwAddr::IPV4(addr_ipv4) => {
|
||||||
|
pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
|
||||||
|
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||||
|
packet.extend_from_slice(&[flags]);
|
||||||
|
packet.extend_from_slice(&conid.to_le_bytes());
|
||||||
|
packet.extend_from_slice(&addr_ipv4.addr_ip.to_be_bytes());
|
||||||
|
packet.extend_from_slice(&addr_ipv4.addr_port.to_be_bytes());
|
||||||
|
packet.extend_from_slice(&data[..len]);
|
||||||
|
}
|
||||||
|
UdpgwAddr::IPV6(addr_ipv6) => {
|
||||||
|
if !ipv6_enabled {
|
||||||
|
return Err("ipv6 not support".into());
|
||||||
|
}
|
||||||
|
flags = UDPGW_FLAG_IPV6;
|
||||||
|
pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
|
||||||
|
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||||
|
packet.extend_from_slice(&[flags]);
|
||||||
|
packet.extend_from_slice(&conid.to_le_bytes());
|
||||||
|
packet.extend_from_slice(&addr_ipv6.addr_ip);
|
||||||
|
packet.extend_from_slice(&addr_ipv6.addr_port.to_be_bytes());
|
||||||
|
packet.extend_from_slice(&data[..len]);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.inner.write_all(&packet).await?;
|
||||||
|
|
||||||
|
stream.last_activity = std::time::Instant::now();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue