mirror of
https://github.com/tun2proxy/tun2proxy.git
synced 2025-04-20 13:59:10 +00:00
read code
This commit is contained in:
parent
b2482ab411
commit
efea708ca1
1 changed files with 131 additions and 124 deletions
|
@ -1,21 +1,20 @@
|
|||
use std::mem;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::SocketAddrV4;
|
||||
use std::net::ToSocketAddrs;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::tcp::ReadHalf;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tun2proxy::udpgw::*;
|
||||
use tun2proxy::ArgVerbosity;
|
||||
use tun2proxy::Result;
|
||||
use std::{
|
||||
net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::{
|
||||
tcp::{ReadHalf, WriteHalf},
|
||||
UdpSocket,
|
||||
},
|
||||
sync::mpsc::{self, Receiver, Sender},
|
||||
};
|
||||
use tun2proxy::{udpgw::*, ArgVerbosity, Result};
|
||||
|
||||
pub(crate) const CLIENT_DISCONNECT_TIMEOUT: tokio::time::Duration = std::time::Duration::from_secs(60);
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct UdpRequest {
|
||||
flags: u8,
|
||||
server_addr: SocketAddr,
|
||||
|
@ -23,14 +22,23 @@ struct UdpRequest {
|
|||
data: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Client {
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
addr: SocketAddr,
|
||||
buf: Vec<u8>,
|
||||
last_activity: std::time::Instant,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn new(addr: SocketAddr) -> Self {
|
||||
Self {
|
||||
addr,
|
||||
buf: vec![],
|
||||
last_activity: std::time::Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, clap::Parser)]
|
||||
pub struct UdpGwArgs {
|
||||
/// UDP mtu
|
||||
|
@ -83,10 +91,10 @@ async fn send_keepalive_response(tx: Sender<Vec<u8>>, conid: u16) {
|
|||
}
|
||||
|
||||
pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u8, u16, SocketAddr)> {
|
||||
if data_len < mem::size_of::<UdpgwHeader>() {
|
||||
if data_len < std::mem::size_of::<UdpgwHeader>() {
|
||||
return Err("Invalid udpgw data".into());
|
||||
}
|
||||
let header_bytes = &data[..mem::size_of::<UdpgwHeader>()];
|
||||
let header_bytes = &data[..std::mem::size_of::<UdpgwHeader>()];
|
||||
let header = UdpgwHeader {
|
||||
flags: header_bytes[0],
|
||||
conid: u16::from_le_bytes([header_bytes[1], header_bytes[2]]),
|
||||
|
@ -100,10 +108,10 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
|
|||
return Ok((data, flags, conid, SocketAddrV4::new(Ipv4Addr::from(0), 0).into()));
|
||||
}
|
||||
|
||||
let ip_data = &data[mem::size_of::<UdpgwHeader>()..];
|
||||
let mut data_len = data_len - mem::size_of::<UdpgwHeader>();
|
||||
let ip_data = &data[std::mem::size_of::<UdpgwHeader>()..];
|
||||
let mut data_len = data_len - std::mem::size_of::<UdpgwHeader>();
|
||||
// port_len + min(ipv4/ipv6/(domain_len + 1))
|
||||
if data_len < mem::size_of::<u16>() + 2 {
|
||||
if data_len < std::mem::size_of::<u16>() + 2 {
|
||||
return Err("Invalid udpgw data".into());
|
||||
}
|
||||
if flags & UDPGW_FLAG_DOMAIN != 0 {
|
||||
|
@ -134,42 +142,42 @@ pub fn parse_udp(udp_mtu: u16, data_len: usize, data: &[u8]) -> Result<(&[u8], u
|
|||
Err("missing domain name".into())
|
||||
}
|
||||
} else if flags & UDPGW_FLAG_IPV6 != 0 {
|
||||
if data_len < mem::size_of::<UdpgwAddrIpv6>() {
|
||||
if data_len < std::mem::size_of::<UdpgwAddrIpv6>() {
|
||||
return Err("Ipv6 Invalid UDP data".into());
|
||||
}
|
||||
let addr_ipv6_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv6>()];
|
||||
let addr_ipv6_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv6>()];
|
||||
let addr_ipv6 = UdpgwAddrIpv6 {
|
||||
addr_ip: addr_ipv6_bytes[..16].try_into().map_err(|_| "Failed to convert slice to array")?,
|
||||
addr_port: u16::from_be_bytes([addr_ipv6_bytes[16], addr_ipv6_bytes[17]]),
|
||||
};
|
||||
data_len -= mem::size_of::<UdpgwAddrIpv6>();
|
||||
data_len -= std::mem::size_of::<UdpgwAddrIpv6>();
|
||||
|
||||
if data_len > udp_mtu as usize {
|
||||
return Err("too much data".into());
|
||||
}
|
||||
return Ok((
|
||||
&ip_data[mem::size_of::<UdpgwAddrIpv6>()..(data_len + mem::size_of::<UdpgwAddrIpv6>())],
|
||||
&ip_data[std::mem::size_of::<UdpgwAddrIpv6>()..(data_len + std::mem::size_of::<UdpgwAddrIpv6>())],
|
||||
flags,
|
||||
conid,
|
||||
UdpgwAddr::IPV6(addr_ipv6).into(),
|
||||
));
|
||||
} else {
|
||||
if data_len < mem::size_of::<UdpgwAddrIpv4>() {
|
||||
if data_len < std::mem::size_of::<UdpgwAddrIpv4>() {
|
||||
return Err("Ipv4 Invalid UDP data".into());
|
||||
}
|
||||
let addr_ipv4_bytes = &ip_data[..mem::size_of::<UdpgwAddrIpv4>()];
|
||||
let addr_ipv4_bytes = &ip_data[..std::mem::size_of::<UdpgwAddrIpv4>()];
|
||||
let addr_ipv4 = UdpgwAddrIpv4 {
|
||||
addr_ip: u32::from_be_bytes([addr_ipv4_bytes[0], addr_ipv4_bytes[1], addr_ipv4_bytes[2], addr_ipv4_bytes[3]]),
|
||||
addr_port: u16::from_be_bytes([addr_ipv4_bytes[4], addr_ipv4_bytes[5]]),
|
||||
};
|
||||
data_len -= mem::size_of::<UdpgwAddrIpv4>();
|
||||
data_len -= std::mem::size_of::<UdpgwAddrIpv4>();
|
||||
|
||||
if data_len > udp_mtu as usize {
|
||||
return Err("too much data".into());
|
||||
}
|
||||
|
||||
return Ok((
|
||||
&ip_data[mem::size_of::<UdpgwAddrIpv4>()..(data_len + mem::size_of::<UdpgwAddrIpv4>())],
|
||||
&ip_data[std::mem::size_of::<UdpgwAddrIpv4>()..(data_len + std::mem::size_of::<UdpgwAddrIpv4>())],
|
||||
flags,
|
||||
conid,
|
||||
UdpgwAddr::IPV4(addr_ipv4).into(),
|
||||
|
@ -193,10 +201,10 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
|
|||
Ok(ret) => {
|
||||
let (len, _addr) = ret?;
|
||||
let mut packet = vec![];
|
||||
let mut pack_len = mem::size_of::<UdpgwHeader>() + len;
|
||||
let mut pack_len = std::mem::size_of::<UdpgwHeader>() + len;
|
||||
match con.server_addr.into() {
|
||||
UdpgwAddr::IPV4(addr_ipv4) => {
|
||||
pack_len += mem::size_of::<UdpgwAddrIpv4>();
|
||||
pack_len += std::mem::size_of::<UdpgwAddrIpv4>();
|
||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||
packet.extend_from_slice(&[con.flags]);
|
||||
packet.extend_from_slice(&con.conid.to_le_bytes());
|
||||
|
@ -205,7 +213,7 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
|
|||
packet.extend_from_slice(&con.data[..len]);
|
||||
}
|
||||
UdpgwAddr::IPV6(addr_ipv6) => {
|
||||
pack_len += mem::size_of::<UdpgwAddrIpv6>();
|
||||
pack_len += std::mem::size_of::<UdpgwAddrIpv6>();
|
||||
packet.extend_from_slice(&(pack_len as u16).to_le_bytes());
|
||||
packet.extend_from_slice(&[con.flags]);
|
||||
packet.extend_from_slice(&con.conid.to_le_bytes());
|
||||
|
@ -225,102 +233,113 @@ async fn process_udp(addr: SocketAddr, udp_timeout: u64, tx: Sender<Vec<u8>>, co
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_client_udp_req<'a>(args: &UdpGwArgs, tx: Sender<Vec<u8>>, mut client: Client, mut tcp_read_stream: ReadHalf<'a>) {
|
||||
async fn process_client_udp_req(args: &UdpGwArgs, tx: Sender<Vec<u8>>, client: Client, mut reader: ReadHalf<'_>) -> std::io::Result<()> {
|
||||
let mut client = client;
|
||||
let mut buf = vec![0; args.udp_mtu as usize];
|
||||
let mut len_buf = [0; mem::size_of::<PackLenHeader>()];
|
||||
let mut len_buf = [0; std::mem::size_of::<PackLenHeader>()];
|
||||
let udp_mtu = args.udp_mtu;
|
||||
let udp_timeout = args.udp_timeout;
|
||||
|
||||
'out: loop {
|
||||
let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), tcp_read_stream.read(&mut len_buf)).await {
|
||||
let result = match tokio::time::timeout(tokio::time::Duration::from_secs(2), reader.read(&mut len_buf)).await {
|
||||
Ok(ret) => ret,
|
||||
Err(_e) => {
|
||||
if client.last_activity.elapsed() >= CLIENT_DISCONNECT_TIMEOUT {
|
||||
log::debug!("client {} last_activity elapsed", client.addr);
|
||||
return;
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match result {
|
||||
Ok(0) => break, // Connection closed
|
||||
Ok(n) => {
|
||||
if n < mem::size_of::<PackLenHeader>() {
|
||||
log::error!("client {} received PackLenHeader error", client.addr);
|
||||
break;
|
||||
}
|
||||
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
|
||||
if packet_len > udp_mtu {
|
||||
log::error!("client {} received packet too long", client.addr);
|
||||
break;
|
||||
}
|
||||
log::debug!("client {} recvied packet len {}", client.addr, packet_len);
|
||||
client.buf.clear();
|
||||
let mut left_len: usize = packet_len as usize;
|
||||
while left_len > 0 {
|
||||
if let Ok(len) = tcp_read_stream.read(&mut buf[..left_len]).await {
|
||||
if len == 0 {
|
||||
break 'out;
|
||||
}
|
||||
client.buf.extend_from_slice(&buf[..len]);
|
||||
left_len -= len;
|
||||
} else {
|
||||
break 'out;
|
||||
}
|
||||
}
|
||||
client.last_activity = std::time::Instant::now();
|
||||
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
|
||||
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
|
||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||
log::debug!("client {} send keepalive", client.addr);
|
||||
send_keepalive_response(tx.clone(), conid).await;
|
||||
continue;
|
||||
}
|
||||
log::debug!(
|
||||
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
|
||||
client.addr,
|
||||
flags,
|
||||
conid,
|
||||
reqaddr,
|
||||
udpdata.len()
|
||||
);
|
||||
let mut req = UdpRequest {
|
||||
server_addr: reqaddr,
|
||||
conid,
|
||||
flags,
|
||||
data: udpdata.to_vec(),
|
||||
};
|
||||
let tx1 = tx.clone();
|
||||
let tx2 = tx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
|
||||
send_error(tx2, &mut req).await;
|
||||
log::error!("client {} process_udp {}", client.addr, e);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
|
||||
}
|
||||
let n = result?;
|
||||
if n == 0 {
|
||||
// Connection closed
|
||||
break;
|
||||
}
|
||||
if n < std::mem::size_of::<PackLenHeader>() {
|
||||
log::error!("client {} received PackLenHeader error", client.addr);
|
||||
break;
|
||||
}
|
||||
let packet_len = u16::from_le_bytes([len_buf[0], len_buf[1]]);
|
||||
if packet_len > udp_mtu {
|
||||
log::error!("client {} received packet too long", client.addr);
|
||||
break;
|
||||
}
|
||||
log::debug!("client {} recvied packet len {}", client.addr, packet_len);
|
||||
client.buf.clear();
|
||||
let mut left_len: usize = packet_len as usize;
|
||||
while left_len > 0 {
|
||||
let len = reader.read(&mut buf[..left_len]).await?;
|
||||
if len == 0 {
|
||||
break 'out;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("client {} tcp_read_stream error", client.addr);
|
||||
break;
|
||||
client.buf.extend_from_slice(&buf[..len]);
|
||||
left_len -= len;
|
||||
}
|
||||
client.last_activity = std::time::Instant::now();
|
||||
let ret = parse_udp(udp_mtu, client.buf.len(), &client.buf);
|
||||
if let Ok((udpdata, flags, conid, reqaddr)) = ret {
|
||||
if flags & UDPGW_FLAG_KEEPALIVE != 0 {
|
||||
log::debug!("client {} send keepalive", client.addr);
|
||||
send_keepalive_response(tx.clone(), conid).await;
|
||||
continue;
|
||||
}
|
||||
log::debug!(
|
||||
"client {} received udp data,flags:{},conid:{},addr:{:?},data len:{}",
|
||||
client.addr,
|
||||
flags,
|
||||
conid,
|
||||
reqaddr,
|
||||
udpdata.len()
|
||||
);
|
||||
let mut req = UdpRequest {
|
||||
server_addr: reqaddr,
|
||||
conid,
|
||||
flags,
|
||||
data: udpdata.to_vec(),
|
||||
};
|
||||
let tx1 = tx.clone();
|
||||
let tx2 = tx.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = process_udp(client.addr, udp_timeout, tx1, &mut req).await {
|
||||
send_error(tx2, &mut req).await;
|
||||
log::error!("client {} process_udp {}", client.addr, e);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
log::error!("client {} parse_udp_data {:?}", client.addr, ret.err());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn write_to_client(addr: SocketAddr, mut writer: WriteHalf<'_>, mut rx: Receiver<Vec<u8>>) -> std::io::Result<()> {
|
||||
loop {
|
||||
let Some(udp_response) = rx.recv().await else {
|
||||
log::trace!("client {} channel closed", addr);
|
||||
break;
|
||||
};
|
||||
if udp_response.is_empty() {
|
||||
log::trace!("client {} channel recv 0", addr);
|
||||
break;
|
||||
}
|
||||
log::trace!("send response to client {} len {}", addr, udp_response.len());
|
||||
let _r = writer.write(&udp_response).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let args = Arc::new(UdpGwArgs::parse_args());
|
||||
|
||||
let tcp_listener = TcpListener::bind(args.listen_addr).await?;
|
||||
let tcp_listener = tokio::net::TcpListener::bind(args.listen_addr).await?;
|
||||
|
||||
let default = format!("{:?}", args.verbosity);
|
||||
|
||||
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
|
||||
|
||||
log::info!("UDP GW Server started");
|
||||
log::info!("UDP Gateway Server running at {}", args.listen_addr);
|
||||
|
||||
#[cfg(unix)]
|
||||
if args.daemonize {
|
||||
|
@ -339,29 +358,17 @@ async fn main() -> Result<()> {
|
|||
|
||||
loop {
|
||||
let (mut tcp_stream, addr) = tcp_listener.accept().await?;
|
||||
let client = Client {
|
||||
addr,
|
||||
buf: vec![],
|
||||
last_activity: std::time::Instant::now(),
|
||||
};
|
||||
let client = Client::new(addr);
|
||||
log::info!("client {} connected", addr);
|
||||
let params = Arc::clone(&args);
|
||||
let params = args.clone();
|
||||
tokio::spawn(async move {
|
||||
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(100);
|
||||
let (tcp_read_stream, mut tcp_write_stream) = tcp_stream.split();
|
||||
tokio::select! {
|
||||
_ = process_client_udp_req(¶ms, tx, client, tcp_read_stream) =>{}
|
||||
_ = async {
|
||||
loop
|
||||
{
|
||||
if let Some(udp_response) = rx.recv().await {
|
||||
log::debug!("send udp_response len {}",udp_response.len());
|
||||
let _ = tcp_write_stream.write(&udp_response).await;
|
||||
}
|
||||
}
|
||||
} => {}
|
||||
}
|
||||
log::info!("client {} disconnected", addr);
|
||||
let (tx, rx) = mpsc::channel::<Vec<u8>>(100);
|
||||
let (tcp_read_stream, tcp_write_stream) = tcp_stream.split();
|
||||
let res = tokio::select! {
|
||||
v = process_client_udp_req(¶ms, tx, client, tcp_read_stream) => v,
|
||||
v = write_to_client(addr, tcp_write_stream, rx) => v,
|
||||
};
|
||||
log::info!("client {} disconnected with {:?}", addr, res);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue