read code

This commit is contained in:
ssrlive 2024-10-21 00:57:34 +08:00
parent b2482ab411
commit efea708ca1

View file

@ -1,21 +1,20 @@
use std::mem; use std::{
use std::net::Ipv4Addr; net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs},
use std::net::SocketAddr; sync::Arc,
use std::net::SocketAddrV4; };
use std::net::ToSocketAddrs; use tokio::{
use std::sync::Arc; io::{AsyncReadExt, AsyncWriteExt},
use tokio::io::{AsyncReadExt, AsyncWriteExt}; net::{
use tokio::net::tcp::ReadHalf; tcp::{ReadHalf, WriteHalf},
use tokio::net::TcpListener; UdpSocket,
use tokio::net::UdpSocket; },
use tokio::sync::mpsc; sync::mpsc::{self, Receiver, Sender},
use tokio::sync::mpsc::Sender; };
use tun2proxy::udpgw::*; use tun2proxy::{udpgw::*, ArgVerbosity, Result};
use tun2proxy::ArgVerbosity;
use tun2proxy::Result;
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60); pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60);
#[derive(Debug)] #[derive(Debug, Clone)]
struct UdpRequest { struct UdpRequest {
flags: u8, flags: u8,
server_addr: SocketAddr, server_addr: SocketAddr,
@ -23,14 +22,23 @@ struct UdpRequest {
data: Vec<u8>, data: Vec<u8>,
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct Client { pub struct Client {
#[allow(dead_code)]
addr: SocketAddr, addr: SocketAddr,
buf: Vec<u8>, buf: Vec<u8>,
last_activity: std::time::Instant, last_activity: std::time::Instant,
} }
impl Client {
pub fn new(addr: SocketAddr) -> Self {
Self {
addr,
buf: vec![],
last_activity: std::time::Instant::now(),
}
}
}
#[derive(Debug, Clone, clap::Parser)] #[derive(Debug, Clone, clap::Parser)]
pub struct UdpGwArgs { pub struct UdpGwArgs {
/// UDP mtu /// UDP mtu
@ -83,10 +91,10 @@ async fn send_keepalive_response(tx: Sender<Vec<u8>>, conid: u16) {
} }
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> { pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
if data_len < mem::size_of::<UdpgwHeader>() { if data_len < std::mem::size_of::<UdpgwHeader>() {
return Err("Invalid udpgw data".into()); return Err("Invalid udpgw data".into());
} }
let header_bytes = &data[..mem::size_of::<UdpgwHeader>()]; let header_bytes = &data[..std::mem::size_of::<UdpgwHeader>()];
let header = UdpgwHeader { let header = UdpgwHeader {
flags: header_bytes[0], flags: header_bytes[0],
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
@ -100,10 +108,10 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
} }
let ip_data = &data[mem::size_of::<UdpgwHeader>()..]; let ip_data = &data[std::mem::size_of::<UdpgwHeader>()..];
let mut data_len = data_len - mem::size_of::<UdpgwHeader>(); let mut data_len = data_len - std::mem::size_of::<UdpgwHeader>();
// port_len + min(ipv4/ipv6/(domain_len + 1)) // port_len + min(ipv4/ipv6/(domain_len + 1))
if data_len < mem::size_of::<u16>() + 2 { if data_len < std::mem::size_of::<u16>() + 2 {
return Err("Invalid udpgw data".into()); return Err("Invalid udpgw data".into());
} }
if flags & UDPGW_FLAG_DOMAIN != 0 { if flags & UDPGW_FLAG_DOMAIN != 0 {
@ -134,42 +142,42 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
Err("missing domain name".into()) Err("missing domain name".into())
} }
} else if flags & UDPGW_FLAG_IPV6 != 0 { } else if flags & UDPGW_FLAG_IPV6 != 0 {
if data_len < mem::size_of::<UdpgwAddrIpv6>() { if data_len < std::mem::size_of::<UdpgwAddrIpv6>() {
return Err("Ipv6 Invalid UDP data".into()); return Err("Ipv6 Invalid UDP data".into());
} }
let addr_ipv6_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv6>()]; let addr_ipv6_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv6>()];
let addr_ipv6 = UdpgwAddrIpv6 { let addr_ipv6 = UdpgwAddrIpv6 {
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?, addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]), addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
}; };
data_len -= mem::size_of::<UdpgwAddrIpv6>(); data_len -= std::mem::size_of::<UdpgwAddrIpv6>();
if data_len > udp_mtu as usize { if data_len > udp_mtu as usize {
return Err("too much data".into()); return Err("too much data".into());
} }
return Ok(( return Ok((
&ip_data[mem::size_of::<UdpgwAddrIpv6>()..(data_len + mem::size_of::<UdpgwAddrIpv6>())], &ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
flags, flags,
conid, conid,
UdpgwAddr::IPV6(addr_ipv6).into(), UdpgwAddr::IPV6(addr_ipv6).into(),
)); ));
} else { } else {
if data_len < mem::size_of::<UdpgwAddrIpv4>() { if data_len < std::mem::size_of::<UdpgwAddrIpv4>() {
return Err("Ipv4 Invalid UDP data".into()); return Err("Ipv4 Invalid UDP data".into());
} }
let addr_ipv4_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv4>()]; let addr_ipv4_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv4>()];
let addr_ipv4 = UdpgwAddrIpv4 { let addr_ipv4 = UdpgwAddrIpv4 {
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]), addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]), addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
}; };
data_len -= mem::size_of::<UdpgwAddrIpv4>(); data_len -= std::mem::size_of::<UdpgwAddrIpv4>();
if data_len > udp_mtu as usize { if data_len > udp_mtu as usize {
return Err("too much data".into()); return Err("too much data".into());
} }
return Ok(( return Ok((
&ip_data[mem::size_of::<UdpgwAddrIpv4>()..(data_len + mem::size_of::<UdpgwAddrIpv4>())], &ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
flags, flags,
conid, conid,
UdpgwAddr::IPV4(addr_ipv4).into(), UdpgwAddr::IPV4(addr_ipv4).into(),
@ -193,10 +201,10 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
Ok(ret) => { Ok(ret) => {
let (len, _addr) = ret?; let (len, _addr) = ret?;
let mut packet = vec![]; let mut packet = vec![];
let mut pack_len = mem::size_of::<UdpgwHeader>() + len; let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
match con.server_addr.into() { match con.server_addr.into() {
UdpgwAddr::IPV4(addr_ipv4) => { UdpgwAddr::IPV4(addr_ipv4) => {
pack_len += mem::size_of::<UdpgwAddrIpv4>(); pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]); packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conid.to_le_bytes()); packet.extend_from_slice(&con.conid.to_le_bytes());
@ -205,7 +213,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
packet.extend_from_slice(&con.data[..len]); packet.extend_from_slice(&con.data[..len]);
} }
UdpgwAddr::IPV6(addr_ipv6) => { UdpgwAddr::IPV6(addr_ipv6) => {
pack_len += mem::size_of::<UdpgwAddrIpv6>(); pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]); packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conid.to_le_bytes()); packet.extend_from_slice(&con.conid.to_le_bytes());
@ -225,102 +233,113 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
Ok(()) Ok(())
} }
async fn process_client_udp_req<'a>(args: &UdpGwArgs, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) { async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> {
let mut client = client;
let mut buf = vec![0; args.udp_mtu as usize]; let mut buf = vec![0; args.udp_mtu as usize];
let mut len_buf = [0; mem::size_of::<PackLenHeader>()]; let mut len_buf = [0; std::mem::size_of::<PackLenHeader>()];
let udp_mtu = args.udp_mtu; let udp_mtu = args.udp_mtu;
let udp_timeout = args.udp_timeout; let udp_timeout = args.udp_timeout;
'out: loop { 'out: loop {
let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), tcp_read_stream.read(&mut len_buf)).await { let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), reader.read(&mut len_buf)).await {
Ok(ret) => ret, Ok(ret) => ret,
Err(_e) => { Err(_e) => {
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT { if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
log::debug!("client {} last_activity elapsed", client.addr); log::debug!("client {} last_activity elapsed", client.addr);
return; break;
} }
continue; continue;
} }
}; };
match result { let n = result?;
Ok(0) => break, // Connection closed if n == 0 {
Ok(n) => { // Connection closed
if n < mem::size_of::<PackLenHeader>() { break;
log::error!("client {} received PackLenHeader error", client.addr); }
break; if n < std::mem::size_of::<PackLenHeader>() {
} log::error!("client {} received PackLenHeader error", client.addr);
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]); break;
if packet_len > udp_mtu { }
log::error!("client {} received packet too long", client.addr); let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
break; if packet_len > udp_mtu {
} log::error!("client {} received packet too long", client.addr);
log::debug!("client {} recvied packet len {}", client.addr, packet_len); break;
client.buf.clear(); }
let mut left_len: usize = packet_len as usize; log::debug!("client {} recvied packet len {}", client.addr, packet_len);
while left_len > 0 { client.buf.clear();
if let Ok(len) = tcp_read_stream.read(&mut buf[..left_len]).await { let mut left_len: usize = packet_len as usize;
if len == 0 { while left_len > 0 {
break 'out; let len = reader.read(&mut buf[..left_len]).await?;
} if len == 0 {
client.buf.extend_from_slice(&buf[..len]); break 'out;
left_len -= len;
} else {
break 'out;
}
}
client.last_activity = std::time::Instant::now();
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
log::debug!("client {} send keepalive", client.addr);
send_keepalive_response(tx.clone(), conid).await;
continue;
}
log::debug!(
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
client.addr,
flags,
conid,
reqaddr,
udpdata.len()
);
let mut req = UdpRequest {
server_addr: reqaddr,
conid,
flags,
data: udpdata.to_vec(),
};
let tx1 = tx.clone();
let tx2 = tx.clone();
tokio::spawn(async move {
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
send_error(tx2, &mut req).await;
log::error!("client {} process_udp {}", client.addr, e);
}
});
} else {
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
}
} }
Err(_) => { client.buf.extend_from_slice(&buf[..len]);
log::error!("client {} tcp_read_stream error", client.addr); left_len -= len;
break; }
client.last_activity = std::time::Instant::now();
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
log::debug!("client {} send keepalive", client.addr);
send_keepalive_response(tx.clone(), conid).await;
continue;
} }
log::debug!(
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
client.addr,
flags,
conid,
reqaddr,
udpdata.len()
);
let mut req = UdpRequest {
server_addr: reqaddr,
conid,
flags,
data: udpdata.to_vec(),
};
let tx1 = tx.clone();
let tx2 = tx.clone();
tokio::spawn(async move {
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
send_error(tx2, &mut req).await;
log::error!("client {} process_udp {}", client.addr, e);
}
});
} else {
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
} }
} }
Ok(())
}
async fn write_to_client(addr: SocketAddr, mut writer: WriteHalf<'_>, mut rx: Receiver<Vec<u8>>) -> std::io::Result<()> {
loop {
let Some(udp_response) = rx.recv().await else {
log::trace!("client {} channel closed", addr);
break;
};
if udp_response.is_empty() {
log::trace!("client {} channel recv 0", addr);
break;
}
log::trace!("send response to client {} len {}", addr, udp_response.len());
let _r = writer.write(&udp_response).await?;
}
Ok(())
} }
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
let args = Arc::new(UdpGwArgs::parse_args()); let args = Arc::new(UdpGwArgs::parse_args());
let tcp_listener = TcpListener::bind(args.listen_addr).await?; let tcp_listener = tokio::net::TcpListener::bind(args.listen_addr).await?;
let default = format!("{:?}", args.verbosity); let default = format!("{:?}", args.verbosity);
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init(); env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
log::info!("UDP GW Server started"); log::info!("UDP Gateway Server running at {}", args.listen_addr);
#[cfg(unix)] #[cfg(unix)]
if args.daemonize { if args.daemonize {
@ -339,29 +358,17 @@ async fn main() -> Result<()> {
loop { loop {
let (mut tcp_stream, addr) = tcp_listener.accept().await?; let (mut tcp_stream, addr) = tcp_listener.accept().await?;
let client = Client { let client = Client::new(addr);
addr,
buf: vec![],
last_activity: std::time::Instant::now(),
};
log::info!("client {} connected", addr); log::info!("client {} connected", addr);
let params = Arc::clone(&args); let params = args.clone();
tokio::spawn(async move { tokio::spawn(async move {
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100); let (tx, rx) = mpsc::channel::<Vec<u8>>(100);
let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split(); let (tcp_read_stream, tcp_write_stream) = tcp_stream.split();
tokio::select! { let res = tokio::select! {
_ = process_client_udp_req(&params, tx, client, tcp_read_stream) =>{} v = process_client_udp_req(&params, tx, client, tcp_read_stream) => v,
_ = async { v = write_to_client(addr, tcp_write_stream, rx) => v,
loop };
{ log::info!("client {} disconnected with {:?}", addr, res);
if let Some(udp_response) = rx.recv().await {
log::debug!("send udp_response len {}",udp_response.len());
let _ = tcp_write_stream.write(&udp_response).await;
}
}
} => {}
}
log::info!("client {} disconnected", addr);
}); });
} }
} }