diff --git a/Cargo.toml b/Cargo.toml index 5bb8714..4593f71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ log = { version = "0.4", features = ["std"] } socks5-impl = { version = "0.5" } thiserror = "1.0" tokio = { version = "1.36", features = ["full"] } +tokio-util = "0.7" tproxy-config = { version = "0.1", features = ["log"] } trust-dns-proto = "0.23" tun2 = { version = "1.0", features = ["async"] } diff --git a/apple/tun2proxy/Tun2proxyWrapper.h b/apple/tun2proxy/Tun2proxyWrapper.h index 70badfd..517192e 100644 --- a/apple/tun2proxy/Tun2proxyWrapper.h +++ b/apple/tun2proxy/Tun2proxyWrapper.h @@ -12,7 +12,7 @@ + (void)startWithConfig:(NSString *)proxy_url tun_fd:(int)tun_fd - tun_mtu:(uint32_t)tun_mtu + tun_mtu:(uint16_t)tun_mtu dns_over_tcp:(bool)dns_over_tcp verbose:(bool)verbose; + (void) shutdown; diff --git a/apple/tun2proxy/Tun2proxyWrapper.m b/apple/tun2proxy/Tun2proxyWrapper.m index 3620390..d1f8705 100644 --- a/apple/tun2proxy/Tun2proxyWrapper.m +++ b/apple/tun2proxy/Tun2proxyWrapper.m @@ -14,7 +14,7 @@ + (void)startWithConfig:(NSString *)proxy_url tun_fd:(int)tun_fd - tun_mtu:(uint32_t)tun_mtu + tun_mtu:(uint16_t)tun_mtu dns_over_tcp:(bool)dns_over_tcp verbose:(bool)verbose { ArgDns dns_strategy = dns_over_tcp ? OverTcp : Direct; diff --git a/src/android.rs b/src/android.rs index 8d0e836..b9798d0 100644 --- a/src/android.rs +++ b/src/android.rs @@ -7,7 +7,7 @@ use crate::{ }; use jni::{ objects::{JClass, JString}, - sys::jint, + sys::{jchar, jint}, JNIEnv, }; @@ -20,7 +20,7 @@ pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_run( _clazz: JClass, proxy_url: JString, tun_fd: jint, - tun_mtu: jint, + tun_mtu: jchar, verbosity: jint, dns_strategy: jint, ) -> jint { @@ -38,7 +38,7 @@ pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_run( let proxy = ArgProxy::from_url(proxy_url).unwrap(); let args = Args::new(Some(tun_fd), proxy, dns, verbosity); - crate::api::tun2proxy_internal_run(args, tun_mtu as _) + crate::api::tun2proxy_internal_run(args, tun_mtu) } /// # Safety diff --git a/src/api.rs b/src/api.rs index 4f7d3cf..003121b 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,16 +1,23 @@ #![cfg(any(target_os = "ios", target_os = "android"))] -use crate::{Args, Builder, Quit}; -use std::{os::raw::c_int, sync::Arc}; +use crate::Args; +use std::{os::raw::c_int, sync::Mutex}; +use tokio_util::sync::CancellationToken; -static mut TUN_QUIT: Option> = None; +static TUN_QUIT: Mutex> = Mutex::new(None); -pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: usize) -> c_int { - if unsafe { TUN_QUIT.is_some() } { +pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: u16) -> c_int { + let mut lock = TUN_QUIT.lock().unwrap(); + if lock.is_some() { log::error!("tun2proxy already started"); return -1; } + let shutdown_token = CancellationToken::new(); + *lock = Some(shutdown_token.clone()); + // explicit drop to avoid holding mutex lock while running proxy. + drop(lock); + let block = async move { log::info!("Proxy {} server: {}", args.proxy.proxy_type, args.proxy.addr); @@ -18,53 +25,40 @@ pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: usize) -> c_int { config.raw_fd(args.tun_fd.ok_or(crate::Error::from("tun_fd"))?); let device = tun2::create_as_async(&config).map_err(std::io::Error::from)?; + let join_handle = tokio::spawn(crate::run(device, tun_mtu, args, shutdown_token)); - #[cfg(target_os = "android")] - let tun2proxy = Builder::new(device, args).mtu(tun_mtu).build(); - #[cfg(target_os = "ios")] - let tun2proxy = Builder::new(device, args).mtu(tun_mtu).build(); - let (join_handle, quit) = tun2proxy.start(); - - unsafe { TUN_QUIT = Some(Arc::new(quit)) }; - - join_handle.await + join_handle.await.map_err(std::io::Error::from)? }; - match tokio::runtime::Builder::new_multi_thread().enable_all().build() { - Err(_err) => { - log::error!("failed to create tokio runtime with error: {:?}", _err); + let exit_code = match tokio::runtime::Builder::new_multi_thread().enable_all().build() { + Err(e) => { + log::error!("failed to create tokio runtime with error: {:?}", e); -1 } Ok(rt) => match rt.block_on(block) { Ok(_) => 0, - Err(_err) => { - log::error!("failed to run tun2proxy with error: {:?}", _err); + Err(e) => { + log::error!("failed to run tun2proxy with error: {:?}", e); -2 } }, - } + }; + + // release shutdown token before exit. + let mut lock = TUN_QUIT.lock().unwrap(); + let _ = lock.take(); + + exit_code } pub(crate) fn tun2proxy_internal_stop() -> c_int { - let res = match unsafe { &TUN_QUIT } { - None => { - log::error!("tun2proxy not started"); - -1 - } - Some(tun_quit) => match tokio::runtime::Builder::new_multi_thread().enable_all().build() { - Err(_err) => { - log::error!("failed to create tokio runtime with error: {:?}", _err); - -2 - } - Ok(rt) => match rt.block_on(async move { tun_quit.trigger().await }) { - Ok(_) => 0, - Err(_err) => { - log::error!("failed to stop tun2proxy with error: {:?}", _err); - -3 - } - }, - }, - }; - unsafe { TUN_QUIT = None }; - res + let lock = TUN_QUIT.lock().unwrap(); + + if let Some(shutdown_token) = lock.as_ref() { + shutdown_token.cancel(); + 0 + } else { + log::error!("tun2proxy not started"); + -1 + } } diff --git a/src/bin/main.rs b/src/bin/main.rs index 0730354..f6e5d25 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -1,6 +1,7 @@ +use tokio_util::sync::CancellationToken; use tproxy_config::{TproxyArgs, TUN_GATEWAY, TUN_IPV4, TUN_NETMASK}; use tun2::DEFAULT_MTU as MTU; -use tun2proxy::{Args, Builder}; +use tun2proxy::{self, Args}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -62,11 +63,13 @@ async fn main() -> Result<(), Box> { tproxy_config::tproxy_setup(&tproxy_args)?; } - let tun2proxy = Builder::new(device, args).mtu(MTU as _).build(); - let (join_handle, quit) = tun2proxy.start(); + let shutdown_token = CancellationToken::new(); + let cloned_token = shutdown_token.clone(); + let join_handle = tokio::spawn(tun2proxy::run(device, MTU, args, cloned_token)); ctrlc2::set_async_handler(async move { - quit.trigger().await.expect("quit error"); + log::info!("Ctrl-C received, exiting..."); + shutdown_token.cancel(); }) .await; diff --git a/src/ios.rs b/src/ios.rs index 034939d..8322066 100644 --- a/src/ios.rs +++ b/src/ios.rs @@ -4,7 +4,7 @@ use crate::{ args::{ArgDns, ArgProxy}, ArgVerbosity, Args, }; -use std::os::raw::{c_char, c_int, c_uint}; +use std::os::raw::{c_char, c_int, c_ushort}; /// # Safety /// @@ -13,7 +13,7 @@ use std::os::raw::{c_char, c_int, c_uint}; pub unsafe extern "C" fn tun2proxy_run( proxy_url: *const c_char, tun_fd: c_int, - tun_mtu: c_uint, + tun_mtu: c_ushort, dns_strategy: ArgDns, verbosity: ArgVerbosity, ) -> c_int { @@ -25,7 +25,7 @@ pub unsafe extern "C" fn tun2proxy_run( let args = Args::new(Some(tun_fd), proxy, dns_strategy, verbosity); - crate::api::tun2proxy_internal_run(args, tun_mtu as _) + crate::api::tun2proxy_internal_run(args, tun_mtu) } /// # Safety diff --git a/src/lib.rs b/src/lib.rs index d1fde87..7175970 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,17 +8,16 @@ use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream}; use proxy_handler::{ProxyHandler, ProxyHandlerManager}; use socks::SocksProxyManager; pub use socks5_impl::protocol::UserKey; -use std::{collections::VecDeque, future::Future, net::SocketAddr, pin::Pin, sync::Arc}; +use std::{collections::VecDeque, net::SocketAddr, sync::Arc}; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::TcpStream, - sync::{ - mpsc::{error::SendError, Receiver, Sender}, - Mutex, - }, + sync::Mutex, }; +use tokio_util::sync::CancellationToken; use tproxy_config::is_private_ip; use udp_stream::UdpStream; + pub use { args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType}, error::{Error, Result}, @@ -45,70 +44,7 @@ const MAX_SESSIONS: u64 = 200; static TASK_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); use std::sync::atomic::Ordering::Relaxed; -pub struct Builder { - device: D, - mtu: Option, - args: Args, -} - -impl Builder { - pub fn new(device: D, args: Args) -> Self { - Builder { device, args, mtu: None } - } - pub fn mtu(mut self, mtu: usize) -> Self { - self.mtu = Some(mtu); - self - } - pub fn build(self) -> Tun2Socks5> + Send + 'static> { - let (tx, rx) = tokio::sync::mpsc::channel::<()>(1); - - Tun2Socks5(run(self.device, self.mtu.unwrap_or(1500), self.args, rx), tx) - } -} - -pub struct Tun2Socks5(F, Sender<()>); - -impl Tun2Socks5 -where - F::Output: Send, -{ - pub fn start(self) -> (JoinHandle, Quit) { - let r = tokio::spawn(self.0); - (JoinHandle(r), Quit(self.1)) - } -} - -pub struct Quit(Sender<()>); - -impl Quit { - pub async fn trigger(&self) -> Result<(), SendError<()>> { - self.0.send(()).await - } -} - -#[repr(transparent)] -struct TokioJoinError(tokio::task::JoinError); - -impl From for crate::Result<()> { - fn from(value: TokioJoinError) -> Self { - Err(crate::Error::Io(value.0.into())) - } -} - -pub struct JoinHandle(tokio::task::JoinHandle); - -impl> Future for JoinHandle { - type Output = R; - - fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - match std::task::ready!(Pin::new(&mut self.0).poll(cx)) { - Ok(r) => std::task::Poll::Ready(r), - Err(e) => std::task::Poll::Ready(TokioJoinError(e).into()), - } - } -} - -pub async fn run(device: D, mtu: usize, args: Args, mut quit: Receiver<()>) -> crate::Result<()> +pub async fn run(device: D, mtu: u16, args: Args, shutdown_token: CancellationToken) -> crate::Result<()> where D: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -130,7 +66,7 @@ where }; let mut ipstack_config = ipstack::IpStackConfig::default(); - ipstack_config.mtu(mtu as _); + ipstack_config.mtu(mtu); ipstack_config.tcp_timeout(std::time::Duration::from_secs(600)); // 10 minutes ipstack_config.udp_timeout(std::time::Duration::from_secs(10)); // 10 seconds @@ -139,9 +75,8 @@ where loop { let virtual_dns = virtual_dns.clone(); let ip_stack_stream = tokio::select! { - _ = quit.recv() => { - log::info!(""); - log::info!("Ctrl-C recieved, exiting..."); + _ = shutdown_token.cancelled() => { + log::info!("Shutdown received"); break; } ip_stack_stream = ip_stack.accept() => {