read code

This commit is contained in:
ssrlive 2024-10-21 00:57:34 +08:00
parent b2482ab411
commit efea708ca1

View file

@ -1,21 +1,20 @@
use std::mem;
use std::net::Ipv4Addr;
use std::net::SocketAddr;
use std::net::SocketAddrV4;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::ReadHalf;
use tokio::net::TcpListener;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tun2proxy::udpgw::*;
use tun2proxy::ArgVerbosity;
use tun2proxy::Result;
use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs},
sync::Arc,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
tcp::{ReadHalf, WriteHalf},
UdpSocket,
},
sync::mpsc::{self, Receiver, Sender},
};
use tun2proxy::{udpgw::*, ArgVerbosity, Result};
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60);
#[derive(Debug)]
#[derive(Debug, Clone)]
struct UdpRequest {
flags: u8,
server_addr: SocketAddr,
@ -23,14 +22,23 @@ struct UdpRequest {
data: Vec<u8>,
}
#[derive(Debug)]
struct Client {
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct Client {
addr: SocketAddr,
buf: Vec<u8>,
last_activity: std::time::Instant,
}
impl Client {
pub fn new(addr: SocketAddr) -> Self {
Self {
addr,
buf: vec![],
last_activity: std::time::Instant::now(),
}
}
}
#[derive(Debug, Clone, clap::Parser)]
pub struct UdpGwArgs {
/// UDP mtu
@ -83,10 +91,10 @@ async fn send_keepalive_response(tx: Sender<Vec<u8>>, conid: u16) {
}
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
if data_len < mem::size_of::<UdpgwHeader>() {
if data_len < std::mem::size_of::<UdpgwHeader>() {
return Err("Invalid udpgw data".into());
}
let header_bytes = &data[..mem::size_of::<UdpgwHeader>()];
let header_bytes = &data[..std::mem::size_of::<UdpgwHeader>()];
let header = UdpgwHeader {
flags: header_bytes[0],
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
@ -100,10 +108,10 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
}
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
let ip_data = &data[std::mem::size_of::<UdpgwHeader>()..];
let mut data_len = data_len - std::mem::size_of::<UdpgwHeader>();
// port_len + min(ipv4/ipv6/(domain_len + 1))
if data_len < mem::size_of::<u16>() + 2 {
if data_len < std::mem::size_of::<u16>() + 2 {
return Err("Invalid udpgw data".into());
}
if flags & UDPGW_FLAG_DOMAIN != 0 {
@ -134,42 +142,42 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
Err("missing domain name".into())
}
} else if flags & UDPGW_FLAG_IPV6 != 0 {
if data_len < mem::size_of::<UdpgwAddrIpv6>() {
if data_len < std::mem::size_of::<UdpgwAddrIpv6>() {
return Err("Ipv6 Invalid UDP data".into());
}
let addr_ipv6_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv6>()];
let addr_ipv6_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv6>()];
let addr_ipv6 = UdpgwAddrIpv6 {
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
};
data_len -= mem::size_of::<UdpgwAddrIpv6>();
data_len -= std::mem::size_of::<UdpgwAddrIpv6>();
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
return Ok((
&ip_data[mem::size_of::<UdpgwAddrIpv6>()..(data_len + mem::size_of::<UdpgwAddrIpv6>())],
&ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
flags,
conid,
UdpgwAddr::IPV6(addr_ipv6).into(),
));
} else {
if data_len < mem::size_of::<UdpgwAddrIpv4>() {
if data_len < std::mem::size_of::<UdpgwAddrIpv4>() {
return Err("Ipv4 Invalid UDP data".into());
}
let addr_ipv4_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv4>()];
let addr_ipv4_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv4>()];
let addr_ipv4 = UdpgwAddrIpv4 {
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
};
data_len -= mem::size_of::<UdpgwAddrIpv4>();
data_len -= std::mem::size_of::<UdpgwAddrIpv4>();
if data_len > udp_mtu as usize {
return Err("too much data".into());
}
return Ok((
&ip_data[mem::size_of::<UdpgwAddrIpv4>()..(data_len + mem::size_of::<UdpgwAddrIpv4>())],
&ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
flags,
conid,
UdpgwAddr::IPV4(addr_ipv4).into(),
@ -193,10 +201,10 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
Ok(ret) => {
let (len, _addr) = ret?;
let mut packet = vec![];
let mut pack_len = mem::size_of::<UdpgwHeader>() + len;
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
match con.server_addr.into() {
UdpgwAddr::IPV4(addr_ipv4) => {
pack_len += mem::size_of::<UdpgwAddrIpv4>();
pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conid.to_le_bytes());
@ -205,7 +213,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
packet.extend_from_slice(&con.data[..len]);
}
UdpgwAddr::IPV6(addr_ipv6) => {
pack_len += mem::size_of::<UdpgwAddrIpv6>();
pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
packet.extend_from_slice(&[con.flags]);
packet.extend_from_slice(&con.conid.to_le_bytes());
@ -225,102 +233,113 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
Ok(())
}
async fn process_client_udp_req<'a>(args: &UdpGwArgs, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) {
async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> {
let mut client = client;
let mut buf = vec![0; args.udp_mtu as usize];
let mut len_buf = [0; mem::size_of::<PackLenHeader>()];
let mut len_buf = [0; std::mem::size_of::<PackLenHeader>()];
let udp_mtu = args.udp_mtu;
let udp_timeout = args.udp_timeout;
'out: loop {
let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), tcp_read_stream.read(&mut len_buf)).await {
let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), reader.read(&mut len_buf)).await {
Ok(ret) => ret,
Err(_e) => {
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
log::debug!("client {} last_activity elapsed", client.addr);
return;
break;
}
continue;
}
};
match result {
Ok(0) => break, // Connection closed
Ok(n) => {
if n < mem::size_of::<PackLenHeader>() {
log::error!("client {} received PackLenHeader error", client.addr);
break;
}
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
if packet_len > udp_mtu {
log::error!("client {} received packet too long", client.addr);
break;
}
log::debug!("client {} recvied packet len {}", client.addr, packet_len);
client.buf.clear();
let mut left_len: usize = packet_len as usize;
while left_len > 0 {
if let Ok(len) = tcp_read_stream.read(&mut buf[..left_len]).await {
if len == 0 {
break 'out;
}
client.buf.extend_from_slice(&buf[..len]);
left_len -= len;
} else {
break 'out;
}
}
client.last_activity = std::time::Instant::now();
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
log::debug!("client {} send keepalive", client.addr);
send_keepalive_response(tx.clone(), conid).await;
continue;
}
log::debug!(
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
client.addr,
flags,
conid,
reqaddr,
udpdata.len()
);
let mut req = UdpRequest {
server_addr: reqaddr,
conid,
flags,
data: udpdata.to_vec(),
};
let tx1 = tx.clone();
let tx2 = tx.clone();
tokio::spawn(async move {
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
send_error(tx2, &mut req).await;
log::error!("client {} process_udp {}", client.addr, e);
}
});
} else {
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
}
let n = result?;
if n == 0 {
// Connection closed
break;
}
if n < std::mem::size_of::<PackLenHeader>() {
log::error!("client {} received PackLenHeader error", client.addr);
break;
}
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
if packet_len > udp_mtu {
log::error!("client {} received packet too long", client.addr);
break;
}
log::debug!("client {} recvied packet len {}", client.addr, packet_len);
client.buf.clear();
let mut left_len: usize = packet_len as usize;
while left_len > 0 {
let len = reader.read(&mut buf[..left_len]).await?;
if len == 0 {
break 'out;
}
Err(_) => {
log::error!("client {} tcp_read_stream error", client.addr);
break;
client.buf.extend_from_slice(&buf[..len]);
left_len -= len;
}
client.last_activity = std::time::Instant::now();
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
log::debug!("client {} send keepalive", client.addr);
send_keepalive_response(tx.clone(), conid).await;
continue;
}
log::debug!(
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
client.addr,
flags,
conid,
reqaddr,
udpdata.len()
);
let mut req = UdpRequest {
server_addr: reqaddr,
conid,
flags,
data: udpdata.to_vec(),
};
let tx1 = tx.clone();
let tx2 = tx.clone();
tokio::spawn(async move {
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
send_error(tx2, &mut req).await;
log::error!("client {} process_udp {}", client.addr, e);
}
});
} else {
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
}
}
Ok(())
}
async fn write_to_client(addr: SocketAddr, mut writer: WriteHalf<'_>, mut rx: Receiver<Vec<u8>>) -> std::io::Result<()> {
loop {
let Some(udp_response) = rx.recv().await else {
log::trace!("client {} channel closed", addr);
break;
};
if udp_response.is_empty() {
log::trace!("client {} channel recv 0", addr);
break;
}
log::trace!("send response to client {} len {}", addr, udp_response.len());
let _r = writer.write(&udp_response).await?;
}
Ok(())
}
#[tokio::main]
async fn main() -> Result<()> {
let args = Arc::new(UdpGwArgs::parse_args());
let tcp_listener = TcpListener::bind(args.listen_addr).await?;
let tcp_listener = tokio::net::TcpListener::bind(args.listen_addr).await?;
let default = format!("{:?}", args.verbosity);
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
log::info!("UDP GW Server started");
log::info!("UDP Gateway Server running at {}", args.listen_addr);
#[cfg(unix)]
if args.daemonize {
@ -339,29 +358,17 @@ async fn main() -> Result<()> {
loop {
let (mut tcp_stream, addr) = tcp_listener.accept().await?;
let client = Client {
addr,
buf: vec![],
last_activity: std::time::Instant::now(),
};
let client = Client::new(addr);
log::info!("client {} connected", addr);
let params = Arc::clone(&args);
let params = args.clone();
tokio::spawn(async move {
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split();
tokio::select! {
_ = process_client_udp_req(&params, tx, client, tcp_read_stream) =>{}
_ = async {
loop
{
if let Some(udp_response) = rx.recv().await {
log::debug!("send udp_response len {}",udp_response.len());
let _ = tcp_write_stream.write(&udp_response).await;
}
}
} => {}
}
log::info!("client {} disconnected", addr);
let (tx, rx) = mpsc::channel::<Vec<u8>>(100);
let (tcp_read_stream, tcp_write_stream) = tcp_stream.split();
let res = tokio::select! {
v = process_client_udp_req(&params, tx, client, tcp_read_stream) => v,
v = write_to_client(addr, tcp_write_stream, rx) => v,
};
log::info!("client {} disconnected with {:?}", addr, res);
});
}
}