From efea708ca183c146ed95f20a125da2b8a7204c0f Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Mon, 21 Oct 2024 00:57:34 +0800 Subject: [PATCH] read code --- src/bin/udpgw_server.rs | 255 +++++++++++++++++++++------------------- 1 file changed, 131 insertions(+), 124 deletions(-) diff --git a/src/bin/udpgw_server.rs b/src/bin/udpgw_server.rs index 386976f..47e9d83 100644 --- a/src/bin/udpgw_server.rs +++ b/src/bin/udpgw_server.rs @@ -1,21 +1,20 @@ -use std::mem; -use std::net::Ipv4Addr; -use std::net::SocketAddr; -use std::net::SocketAddrV4; -use std::net::ToSocketAddrs; -use std::sync::Arc; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::tcp::ReadHalf; -use tokio::net::TcpListener; -use tokio::net::UdpSocket; -use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; -use tun2proxy::udpgw::*; -use tun2proxy::ArgVerbosity; -use tun2proxy::Result; +use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs}, + sync::Arc, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{ + tcp::{ReadHalf, WriteHalf}, + UdpSocket, + }, + sync::mpsc::{self, Receiver, Sender}, +}; +use tun2proxy::{udpgw::*, ArgVerbosity, Result}; + pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60); -#[derive(Debug)] +#[derive(Debug, Clone)] struct UdpRequest { flags: u8, server_addr: SocketAddr, @@ -23,14 +22,23 @@ struct UdpRequest { data: Vec, } -#[derive(Debug)] -struct Client { - #[allow(dead_code)] +#[derive(Debug, Clone)] +pub struct Client { addr: SocketAddr, buf: Vec, last_activity: std::time::Instant, } +impl Client { + pub fn new(addr: SocketAddr) -> Self { + Self { + addr, + buf: vec![], + last_activity: std::time::Instant::now(), + } + } +} + #[derive(Debug, Clone, clap::Parser)] pub struct UdpGwArgs { /// UDP mtu @@ -83,10 +91,10 @@ async fn send_keepalive_response(tx: Sender>, conid: u16) { } pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> { - if data_len < mem::size_of::() { + if data_len < std::mem::size_of::() { return Err("Invalid udpgw data".into()); } - let header_bytes = &data[..mem::size_of::()]; + let header_bytes = &data[..std::mem::size_of::()]; let header = UdpgwHeader { flags: header_bytes[0], conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]), @@ -100,10 +108,10 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into())); } - let ip_data = &data[mem::size_of::()..]; - let mut data_len = data_len - mem::size_of::(); + let ip_data = &data[std::mem::size_of::()..]; + let mut data_len = data_len - std::mem::size_of::(); // port_len + min(ipv4/ipv6/(domain_len + 1)) - if data_len < mem::size_of::() + 2 { + if data_len < std::mem::size_of::() + 2 { return Err("Invalid udpgw data".into()); } if flags & UDPGW_FLAG_DOMAIN != 0 { @@ -134,42 +142,42 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u Err("missing domain name".into()) } } else if flags & UDPGW_FLAG_IPV6 != 0 { - if data_len < mem::size_of::() { + if data_len < std::mem::size_of::() { return Err("Ipv6 Invalid UDP data".into()); } - let addr_ipv6_bytes = &ip_data[..mem::size_of::()]; + let addr_ipv6_bytes = &ip_data[..std::mem::size_of::()]; let addr_ipv6 = UdpgwAddrIpv6 { addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?, addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]), }; - data_len -= mem::size_of::(); + data_len -= std::mem::size_of::(); if data_len > udp_mtu as usize { return Err("too much data".into()); } return Ok(( - &ip_data[mem::size_of::()..(data_len + mem::size_of::())], + &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], flags, conid, UdpgwAddr::IPV6(addr_ipv6).into(), )); } else { - if data_len < mem::size_of::() { + if data_len < std::mem::size_of::() { return Err("Ipv4 Invalid UDP data".into()); } - let addr_ipv4_bytes = &ip_data[..mem::size_of::()]; + let addr_ipv4_bytes = &ip_data[..std::mem::size_of::()]; let addr_ipv4 = UdpgwAddrIpv4 { addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]), addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]), }; - data_len -= mem::size_of::(); + data_len -= std::mem::size_of::(); if data_len > udp_mtu as usize { return Err("too much data".into()); } return Ok(( - &ip_data[mem::size_of::()..(data_len + mem::size_of::())], + &ip_data[std::mem::size_of::()..(data_len + std::mem::size_of::())], flags, conid, UdpgwAddr::IPV4(addr_ipv4).into(), @@ -193,10 +201,10 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, co Ok(ret) => { let (len, _addr) = ret?; let mut packet = vec![]; - let mut pack_len = mem::size_of::() + len; + let mut pack_len = std::mem::size_of::() + len; match con.server_addr.into() { UdpgwAddr::IPV4(addr_ipv4) => { - pack_len += mem::size_of::(); + pack_len += std::mem::size_of::(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&[con.flags]); packet.extend_from_slice(&con.conid.to_le_bytes()); @@ -205,7 +213,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, co packet.extend_from_slice(&con.data[..len]); } UdpgwAddr::IPV6(addr_ipv6) => { - pack_len += mem::size_of::(); + pack_len += std::mem::size_of::(); packet.extend_from_slice(&(pack_len as u16).to_le_bytes()); packet.extend_from_slice(&[con.flags]); packet.extend_from_slice(&con.conid.to_le_bytes()); @@ -225,102 +233,113 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender>, co Ok(()) } -async fn process_client_udp_req<'a>(args: &UdpGwArgs, tx: Sender>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) { +async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> { + let mut client = client; let mut buf = vec![0; args.udp_mtu as usize]; - let mut len_buf = [0; mem::size_of::()]; + let mut len_buf = [0; std::mem::size_of::()]; let udp_mtu = args.udp_mtu; let udp_timeout = args.udp_timeout; 'out: loop { - let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), tcp_read_stream.read(&mut len_buf)).await { + let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), reader.read(&mut len_buf)).await { Ok(ret) => ret, Err(_e) => { if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT { log::debug!("client {} last_activity elapsed", client.addr); - return; + break; } continue; } }; - match result { - Ok(0) => break, // Connection closed - Ok(n) => { - if n < mem::size_of::() { - log::error!("client {} received PackLenHeader error", client.addr); - break; - } - let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]); - if packet_len > udp_mtu { - log::error!("client {} received packet too long", client.addr); - break; - } - log::debug!("client {} recvied packet len {}", client.addr, packet_len); - client.buf.clear(); - let mut left_len: usize = packet_len as usize; - while left_len > 0 { - if let Ok(len) = tcp_read_stream.read(&mut buf[..left_len]).await { - if len == 0 { - break 'out; - } - client.buf.extend_from_slice(&buf[..len]); - left_len -= len; - } else { - break 'out; - } - } - client.last_activity = std::time::Instant::now(); - let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf); - if let Ok((udpdata, flags, conid, reqaddr)) = ret { - if flags & UDPGW_FLAG_KEEPALIVE != 0 { - log::debug!("client {} send keepalive", client.addr); - send_keepalive_response(tx.clone(), conid).await; - continue; - } - log::debug!( - "client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}", - client.addr, - flags, - conid, - reqaddr, - udpdata.len() - ); - let mut req = UdpRequest { - server_addr: reqaddr, - conid, - flags, - data: udpdata.to_vec(), - }; - let tx1 = tx.clone(); - let tx2 = tx.clone(); - tokio::spawn(async move { - if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await { - send_error(tx2, &mut req).await; - log::error!("client {} process_udp {}", client.addr, e); - } - }); - } else { - log::error!("client {} parse_udp_data {:?}", client.addr, ret.err()); - } + let n = result?; + if n == 0 { + // Connection closed + break; + } + if n < std::mem::size_of::() { + log::error!("client {} received PackLenHeader error", client.addr); + break; + } + let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]); + if packet_len > udp_mtu { + log::error!("client {} received packet too long", client.addr); + break; + } + log::debug!("client {} recvied packet len {}", client.addr, packet_len); + client.buf.clear(); + let mut left_len: usize = packet_len as usize; + while left_len > 0 { + let len = reader.read(&mut buf[..left_len]).await?; + if len == 0 { + break 'out; } - Err(_) => { - log::error!("client {} tcp_read_stream error", client.addr); - break; + client.buf.extend_from_slice(&buf[..len]); + left_len -= len; + } + client.last_activity = std::time::Instant::now(); + let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf); + if let Ok((udpdata, flags, conid, reqaddr)) = ret { + if flags & UDPGW_FLAG_KEEPALIVE != 0 { + log::debug!("client {} send keepalive", client.addr); + send_keepalive_response(tx.clone(), conid).await; + continue; } + log::debug!( + "client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}", + client.addr, + flags, + conid, + reqaddr, + udpdata.len() + ); + let mut req = UdpRequest { + server_addr: reqaddr, + conid, + flags, + data: udpdata.to_vec(), + }; + let tx1 = tx.clone(); + let tx2 = tx.clone(); + tokio::spawn(async move { + if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await { + send_error(tx2, &mut req).await; + log::error!("client {} process_udp {}", client.addr, e); + } + }); + } else { + log::error!("client {} parse_udp_data {:?}", client.addr, ret.err()); } } + Ok(()) +} + +async fn write_to_client(addr: SocketAddr, mut writer: WriteHalf<'_>, mut rx: Receiver>) -> std::io::Result<()> { + loop { + let Some(udp_response) = rx.recv().await else { + log::trace!("client {} channel closed", addr); + break; + }; + if udp_response.is_empty() { + log::trace!("client {} channel recv 0", addr); + break; + } + log::trace!("send response to client {} len {}", addr, udp_response.len()); + let _r = writer.write(&udp_response).await?; + } + Ok(()) } #[tokio::main] async fn main() -> Result<()> { let args = Arc::new(UdpGwArgs::parse_args()); - let tcp_listener = TcpListener::bind(args.listen_addr).await?; + let tcp_listener = tokio::net::TcpListener::bind(args.listen_addr).await?; let default = format!("{:?}", args.verbosity); env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init(); - log::info!("UDP GW Server started"); + log::info!("UDP Gateway Server running at {}", args.listen_addr); #[cfg(unix)] if args.daemonize { @@ -339,29 +358,17 @@ async fn main() -> Result<()> { loop { let (mut tcp_stream, addr) = tcp_listener.accept().await?; - let client = Client { - addr, - buf: vec![], - last_activity: std::time::Instant::now(), - }; + let client = Client::new(addr); log::info!("client {} connected", addr); - let params = Arc::clone(&args); + let params = args.clone(); tokio::spawn(async move { - let (tx, mut rx) = mpsc::channel::>(100); - let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split(); - tokio::select! { - _ = process_client_udp_req(¶ms, tx, client, tcp_read_stream) =>{} - _ = async { - loop - { - if let Some(udp_response) = rx.recv().await { - log::debug!("send udp_response len {}",udp_response.len()); - let _ = tcp_write_stream.write(&udp_response).await; - } - } - } => {} - } - log::info!("client {} disconnected", addr); + let (tx, rx) = mpsc::channel::>(100); + let (tcp_read_stream, tcp_write_stream) = tcp_stream.split(); + let res = tokio::select! { + v = process_client_udp_req(¶ms, tx, client, tcp_read_stream) => v, + v = write_to_client(addr, tcp_write_stream, rx) => v, + }; + log::info!("client {} disconnected with {:?}", addr, res); }); } }