Introduce cancellation token and reduce amount of code (#88)

Test passed on Android. Thanks a lot.
This commit is contained in:
Andrej Mihajlov 2024-02-10 17:36:54 +01:00 committed by GitHub
parent 2434c62524
commit 2a9775ce2e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 59 additions and 126 deletions

View file

@ -29,6 +29,7 @@ log = { version = "0.4", features = ["std"] }
socks5-impl = { version = "0.5" } socks5-impl = { version = "0.5" }
thiserror = "1.0" thiserror = "1.0"
tokio = { version = "1.36", features = ["full"] } tokio = { version = "1.36", features = ["full"] }
tokio-util = "0.7"
tproxy-config = { version = "0.1", features = ["log"] } tproxy-config = { version = "0.1", features = ["log"] }
trust-dns-proto = "0.23" trust-dns-proto = "0.23"
tun2 = { version = "1.0", features = ["async"] } tun2 = { version = "1.0", features = ["async"] }

View file

@ -12,7 +12,7 @@
+ (void)startWithConfig:(NSString *)proxy_url + (void)startWithConfig:(NSString *)proxy_url
tun_fd:(int)tun_fd tun_fd:(int)tun_fd
tun_mtu:(uint32_t)tun_mtu tun_mtu:(uint16_t)tun_mtu
dns_over_tcp:(bool)dns_over_tcp dns_over_tcp:(bool)dns_over_tcp
verbose:(bool)verbose; verbose:(bool)verbose;
+ (void) shutdown; + (void) shutdown;

View file

@ -14,7 +14,7 @@
+ (void)startWithConfig:(NSString *)proxy_url + (void)startWithConfig:(NSString *)proxy_url
tun_fd:(int)tun_fd tun_fd:(int)tun_fd
tun_mtu:(uint32_t)tun_mtu tun_mtu:(uint16_t)tun_mtu
dns_over_tcp:(bool)dns_over_tcp dns_over_tcp:(bool)dns_over_tcp
verbose:(bool)verbose { verbose:(bool)verbose {
ArgDns dns_strategy = dns_over_tcp ? OverTcp : Direct; ArgDns dns_strategy = dns_over_tcp ? OverTcp : Direct;

View file

@ -7,7 +7,7 @@ use crate::{
}; };
use jni::{ use jni::{
objects::{JClass, JString}, objects::{JClass, JString},
sys::jint, sys::{jchar, jint},
JNIEnv, JNIEnv,
}; };
@ -20,7 +20,7 @@ pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_run(
_clazz: JClass, _clazz: JClass,
proxy_url: JString, proxy_url: JString,
tun_fd: jint, tun_fd: jint,
tun_mtu: jint, tun_mtu: jchar,
verbosity: jint, verbosity: jint,
dns_strategy: jint, dns_strategy: jint,
) -> jint { ) -> jint {
@ -38,7 +38,7 @@ pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_run(
let proxy = ArgProxy::from_url(proxy_url).unwrap(); let proxy = ArgProxy::from_url(proxy_url).unwrap();
let args = Args::new(Some(tun_fd), proxy, dns, verbosity); let args = Args::new(Some(tun_fd), proxy, dns, verbosity);
crate::api::tun2proxy_internal_run(args, tun_mtu as _) crate::api::tun2proxy_internal_run(args, tun_mtu)
} }
/// # Safety /// # Safety

View file

@ -1,16 +1,23 @@
#![cfg(any(target_os = "ios", target_os = "android"))] #![cfg(any(target_os = "ios", target_os = "android"))]
use crate::{Args, Builder, Quit}; use crate::Args;
use std::{os::raw::c_int, sync::Arc}; use std::{os::raw::c_int, sync::Mutex};
use tokio_util::sync::CancellationToken;
static mut TUN_QUIT: Option<Arc<Quit>> = None; static TUN_QUIT: Mutex<Option<CancellationToken>> = Mutex::new(None);
pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: usize) -> c_int { pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: u16) -> c_int {
if unsafe { TUN_QUIT.is_some() } { let mut lock = TUN_QUIT.lock().unwrap();
if lock.is_some() {
log::error!("tun2proxy already started"); log::error!("tun2proxy already started");
return -1; return -1;
} }
let shutdown_token = CancellationToken::new();
*lock = Some(shutdown_token.clone());
// explicit drop to avoid holding mutex lock while running proxy.
drop(lock);
let block = async move { let block = async move {
log::info!("Proxy {} server: {}", args.proxy.proxy_type, args.proxy.addr); log::info!("Proxy {} server: {}", args.proxy.proxy_type, args.proxy.addr);
@ -18,53 +25,40 @@ pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: usize) -> c_int {
config.raw_fd(args.tun_fd.ok_or(crate::Error::from("tun_fd"))?); config.raw_fd(args.tun_fd.ok_or(crate::Error::from("tun_fd"))?);
let device = tun2::create_as_async(&config).map_err(std::io::Error::from)?; let device = tun2::create_as_async(&config).map_err(std::io::Error::from)?;
let join_handle = tokio::spawn(crate::run(device, tun_mtu, args, shutdown_token));
#[cfg(target_os = "android")] join_handle.await.map_err(std::io::Error::from)?
let tun2proxy = Builder::new(device, args).mtu(tun_mtu).build();
#[cfg(target_os = "ios")]
let tun2proxy = Builder::new(device, args).mtu(tun_mtu).build();
let (join_handle, quit) = tun2proxy.start();
unsafe { TUN_QUIT = Some(Arc::new(quit)) };
join_handle.await
}; };
match tokio::runtime::Builder::new_multi_thread().enable_all().build() { let exit_code = match tokio::runtime::Builder::new_multi_thread().enable_all().build() {
Err(_err) => { Err(e) => {
log::error!("failed to create tokio runtime with error: {:?}", _err); log::error!("failed to create tokio runtime with error: {:?}", e);
-1 -1
} }
Ok(rt) => match rt.block_on(block) { Ok(rt) => match rt.block_on(block) {
Ok(_) => 0, Ok(_) => 0,
Err(_err) => { Err(e) => {
log::error!("failed to run tun2proxy with error: {:?}", _err); log::error!("failed to run tun2proxy with error: {:?}", e);
-2 -2
} }
}, },
} };
// release shutdown token before exit.
let mut lock = TUN_QUIT.lock().unwrap();
let _ = lock.take();
exit_code
} }
pub(crate) fn tun2proxy_internal_stop() -> c_int { pub(crate) fn tun2proxy_internal_stop() -> c_int {
let res = match unsafe { &TUN_QUIT } { let lock = TUN_QUIT.lock().unwrap();
None => {
log::error!("tun2proxy not started"); if let Some(shutdown_token) = lock.as_ref() {
-1 shutdown_token.cancel();
} 0
Some(tun_quit) => match tokio::runtime::Builder::new_multi_thread().enable_all().build() { } else {
Err(_err) => { log::error!("tun2proxy not started");
log::error!("failed to create tokio runtime with error: {:?}", _err); -1
-2 }
}
Ok(rt) => match rt.block_on(async move { tun_quit.trigger().await }) {
Ok(_) => 0,
Err(_err) => {
log::error!("failed to stop tun2proxy with error: {:?}", _err);
-3
}
},
},
};
unsafe { TUN_QUIT = None };
res
} }

View file

@ -1,6 +1,7 @@
use tokio_util::sync::CancellationToken;
use tproxy_config::{TproxyArgs, TUN_GATEWAY, TUN_IPV4, TUN_NETMASK}; use tproxy_config::{TproxyArgs, TUN_GATEWAY, TUN_IPV4, TUN_NETMASK};
use tun2::DEFAULT_MTU as MTU; use tun2::DEFAULT_MTU as MTU;
use tun2proxy::{Args, Builder}; use tun2proxy::{self, Args};
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -62,11 +63,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tproxy_config::tproxy_setup(&tproxy_args)?; tproxy_config::tproxy_setup(&tproxy_args)?;
} }
let tun2proxy = Builder::new(device, args).mtu(MTU as _).build(); let shutdown_token = CancellationToken::new();
let (join_handle, quit) = tun2proxy.start(); let cloned_token = shutdown_token.clone();
let join_handle = tokio::spawn(tun2proxy::run(device, MTU, args, cloned_token));
ctrlc2::set_async_handler(async move { ctrlc2::set_async_handler(async move {
quit.trigger().await.expect("quit error"); log::info!("Ctrl-C received, exiting...");
shutdown_token.cancel();
}) })
.await; .await;

View file

@ -4,7 +4,7 @@ use crate::{
args::{ArgDns, ArgProxy}, args::{ArgDns, ArgProxy},
ArgVerbosity, Args, ArgVerbosity, Args,
}; };
use std::os::raw::{c_char, c_int, c_uint}; use std::os::raw::{c_char, c_int, c_ushort};
/// # Safety /// # Safety
/// ///
@ -13,7 +13,7 @@ use std::os::raw::{c_char, c_int, c_uint};
pub unsafe extern "C" fn tun2proxy_run( pub unsafe extern "C" fn tun2proxy_run(
proxy_url: *const c_char, proxy_url: *const c_char,
tun_fd: c_int, tun_fd: c_int,
tun_mtu: c_uint, tun_mtu: c_ushort,
dns_strategy: ArgDns, dns_strategy: ArgDns,
verbosity: ArgVerbosity, verbosity: ArgVerbosity,
) -> c_int { ) -> c_int {
@ -25,7 +25,7 @@ pub unsafe extern "C" fn tun2proxy_run(
let args = Args::new(Some(tun_fd), proxy, dns_strategy, verbosity); let args = Args::new(Some(tun_fd), proxy, dns_strategy, verbosity);
crate::api::tun2proxy_internal_run(args, tun_mtu as _) crate::api::tun2proxy_internal_run(args, tun_mtu)
} }
/// # Safety /// # Safety

View file

@ -8,17 +8,16 @@ use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream};
use proxy_handler::{ProxyHandler, ProxyHandlerManager}; use proxy_handler::{ProxyHandler, ProxyHandlerManager};
use socks::SocksProxyManager; use socks::SocksProxyManager;
pub use socks5_impl::protocol::UserKey; pub use socks5_impl::protocol::UserKey;
use std::{collections::VecDeque, future::Future, net::SocketAddr, pin::Pin, sync::Arc}; use std::{collections::VecDeque, net::SocketAddr, sync::Arc};
use tokio::{ use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream, net::TcpStream,
sync::{ sync::Mutex,
mpsc::{error::SendError, Receiver, Sender},
Mutex,
},
}; };
use tokio_util::sync::CancellationToken;
use tproxy_config::is_private_ip; use tproxy_config::is_private_ip;
use udp_stream::UdpStream; use udp_stream::UdpStream;
pub use { pub use {
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType}, args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
error::{Error, Result}, error::{Error, Result},
@ -45,70 +44,7 @@ const MAX_SESSIONS: u64 = 200;
static TASK_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); static TASK_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
use std::sync::atomic::Ordering::Relaxed; use std::sync::atomic::Ordering::Relaxed;
pub struct Builder<D> { pub async fn run<D>(device: D, mtu: u16, args: Args, shutdown_token: CancellationToken) -> crate::Result<()>
device: D,
mtu: Option<usize>,
args: Args,
}
impl<D: AsyncRead + AsyncWrite + Unpin + Send + 'static> Builder<D> {
pub fn new(device: D, args: Args) -> Self {
Builder { device, args, mtu: None }
}
pub fn mtu(mut self, mtu: usize) -> Self {
self.mtu = Some(mtu);
self
}
pub fn build(self) -> Tun2Socks5<impl Future<Output = crate::Result<()>> + Send + 'static> {
let (tx, rx) = tokio::sync::mpsc::channel::<()>(1);
Tun2Socks5(run(self.device, self.mtu.unwrap_or(1500), self.args, rx), tx)
}
}
pub struct Tun2Socks5<F: Future>(F, Sender<()>);
impl<F: Future + Send + 'static> Tun2Socks5<F>
where
F::Output: Send,
{
pub fn start(self) -> (JoinHandle<F::Output>, Quit) {
let r = tokio::spawn(self.0);
(JoinHandle(r), Quit(self.1))
}
}
pub struct Quit(Sender<()>);
impl Quit {
pub async fn trigger(&self) -> Result<(), SendError<()>> {
self.0.send(()).await
}
}
#[repr(transparent)]
struct TokioJoinError(tokio::task::JoinError);
impl From<TokioJoinError> for crate::Result<()> {
fn from(value: TokioJoinError) -> Self {
Err(crate::Error::Io(value.0.into()))
}
}
pub struct JoinHandle<R>(tokio::task::JoinHandle<R>);
impl<R: From<TokioJoinError>> Future for JoinHandle<R> {
type Output = R;
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
match std::task::ready!(Pin::new(&mut self.0).poll(cx)) {
Ok(r) => std::task::Poll::Ready(r),
Err(e) => std::task::Poll::Ready(TokioJoinError(e).into()),
}
}
}
pub async fn run<D>(device: D, mtu: usize, args: Args, mut quit: Receiver<()>) -> crate::Result<()>
where where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static, D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{ {
@ -130,7 +66,7 @@ where
}; };
let mut ipstack_config = ipstack::IpStackConfig::default(); let mut ipstack_config = ipstack::IpStackConfig::default();
ipstack_config.mtu(mtu as _); ipstack_config.mtu(mtu);
ipstack_config.tcp_timeout(std::time::Duration::from_secs(600)); // 10 minutes ipstack_config.tcp_timeout(std::time::Duration::from_secs(600)); // 10 minutes
ipstack_config.udp_timeout(std::time::Duration::from_secs(10)); // 10 seconds ipstack_config.udp_timeout(std::time::Duration::from_secs(10)); // 10 seconds
@ -139,9 +75,8 @@ where
loop { loop {
let virtual_dns = virtual_dns.clone(); let virtual_dns = virtual_dns.clone();
let ip_stack_stream = tokio::select! { let ip_stack_stream = tokio::select! {
_ = quit.recv() => { _ = shutdown_token.cancelled() => {
log::info!(""); log::info!("Shutdown received");
log::info!("Ctrl-C recieved, exiting...");
break; break;
} }
ip_stack_stream = ip_stack.accept() => { ip_stack_stream = ip_stack.accept() => {