Introduce cancellation token and reduce amount of code (#88)

Test passed on Android. Thanks a lot.
This commit is contained in:
Andrej Mihajlov 2024-02-10 17:36:54 +01:00 committed by GitHub
parent 2434c62524
commit 2a9775ce2e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 59 additions and 126 deletions

View file

@ -29,6 +29,7 @@ log = { version = "0.4", features = ["std"] }
socks5-impl = { version = "0.5" }
thiserror = "1.0"
tokio = { version = "1.36", features = ["full"] }
tokio-util = "0.7"
tproxy-config = { version = "0.1", features = ["log"] }
trust-dns-proto = "0.23"
tun2 = { version = "1.0", features = ["async"] }

View file

@ -12,7 +12,7 @@
+ (void)startWithConfig:(NSString *)proxy_url
tun_fd:(int)tun_fd
tun_mtu:(uint32_t)tun_mtu
tun_mtu:(uint16_t)tun_mtu
dns_over_tcp:(bool)dns_over_tcp
verbose:(bool)verbose;
+ (void) shutdown;

View file

@ -14,7 +14,7 @@
+ (void)startWithConfig:(NSString *)proxy_url
tun_fd:(int)tun_fd
tun_mtu:(uint32_t)tun_mtu
tun_mtu:(uint16_t)tun_mtu
dns_over_tcp:(bool)dns_over_tcp
verbose:(bool)verbose {
ArgDns dns_strategy = dns_over_tcp ? OverTcp : Direct;

View file

@ -7,7 +7,7 @@ use crate::{
};
use jni::{
objects::{JClass, JString},
sys::jint,
sys::{jchar, jint},
JNIEnv,
};
@ -20,7 +20,7 @@ pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_run(
_clazz: JClass,
proxy_url: JString,
tun_fd: jint,
tun_mtu: jint,
tun_mtu: jchar,
verbosity: jint,
dns_strategy: jint,
) -> jint {
@ -38,7 +38,7 @@ pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_run(
let proxy = ArgProxy::from_url(proxy_url).unwrap();
let args = Args::new(Some(tun_fd), proxy, dns, verbosity);
crate::api::tun2proxy_internal_run(args, tun_mtu as _)
crate::api::tun2proxy_internal_run(args, tun_mtu)
}
/// # Safety

View file

@ -1,16 +1,23 @@
#![cfg(any(target_os = "ios", target_os = "android"))]
use crate::{Args, Builder, Quit};
use std::{os::raw::c_int, sync::Arc};
use crate::Args;
use std::{os::raw::c_int, sync::Mutex};
use tokio_util::sync::CancellationToken;
static mut TUN_QUIT: Option<Arc<Quit>> = None;
static TUN_QUIT: Mutex<Option<CancellationToken>> = Mutex::new(None);
pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: usize) -> c_int {
if unsafe { TUN_QUIT.is_some() } {
pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: u16) -> c_int {
let mut lock = TUN_QUIT.lock().unwrap();
if lock.is_some() {
log::error!("tun2proxy already started");
return -1;
}
let shutdown_token = CancellationToken::new();
*lock = Some(shutdown_token.clone());
// explicit drop to avoid holding mutex lock while running proxy.
drop(lock);
let block = async move {
log::info!("Proxy {} server: {}", args.proxy.proxy_type, args.proxy.addr);
@ -18,53 +25,40 @@ pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: usize) -> c_int {
config.raw_fd(args.tun_fd.ok_or(crate::Error::from("tun_fd"))?);
let device = tun2::create_as_async(&config).map_err(std::io::Error::from)?;
let join_handle = tokio::spawn(crate::run(device, tun_mtu, args, shutdown_token));
#[cfg(target_os = "android")]
let tun2proxy = Builder::new(device, args).mtu(tun_mtu).build();
#[cfg(target_os = "ios")]
let tun2proxy = Builder::new(device, args).mtu(tun_mtu).build();
let (join_handle, quit) = tun2proxy.start();
unsafe { TUN_QUIT = Some(Arc::new(quit)) };
join_handle.await
join_handle.await.map_err(std::io::Error::from)?
};
match tokio::runtime::Builder::new_multi_thread().enable_all().build() {
Err(_err) => {
log::error!("failed to create tokio runtime with error: {:?}", _err);
let exit_code = match tokio::runtime::Builder::new_multi_thread().enable_all().build() {
Err(e) => {
log::error!("failed to create tokio runtime with error: {:?}", e);
-1
}
Ok(rt) => match rt.block_on(block) {
Ok(_) => 0,
Err(_err) => {
log::error!("failed to run tun2proxy with error: {:?}", _err);
Err(e) => {
log::error!("failed to run tun2proxy with error: {:?}", e);
-2
}
},
}
};
// release shutdown token before exit.
let mut lock = TUN_QUIT.lock().unwrap();
let _ = lock.take();
exit_code
}
pub(crate) fn tun2proxy_internal_stop() -> c_int {
let res = match unsafe { &TUN_QUIT } {
None => {
log::error!("tun2proxy not started");
-1
}
Some(tun_quit) => match tokio::runtime::Builder::new_multi_thread().enable_all().build() {
Err(_err) => {
log::error!("failed to create tokio runtime with error: {:?}", _err);
-2
}
Ok(rt) => match rt.block_on(async move { tun_quit.trigger().await }) {
Ok(_) => 0,
Err(_err) => {
log::error!("failed to stop tun2proxy with error: {:?}", _err);
-3
}
},
},
};
unsafe { TUN_QUIT = None };
res
let lock = TUN_QUIT.lock().unwrap();
if let Some(shutdown_token) = lock.as_ref() {
shutdown_token.cancel();
0
} else {
log::error!("tun2proxy not started");
-1
}
}

View file

@ -1,6 +1,7 @@
use tokio_util::sync::CancellationToken;
use tproxy_config::{TproxyArgs, TUN_GATEWAY, TUN_IPV4, TUN_NETMASK};
use tun2::DEFAULT_MTU as MTU;
use tun2proxy::{Args, Builder};
use tun2proxy::{self, Args};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -62,11 +63,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tproxy_config::tproxy_setup(&tproxy_args)?;
}
let tun2proxy = Builder::new(device, args).mtu(MTU as _).build();
let (join_handle, quit) = tun2proxy.start();
let shutdown_token = CancellationToken::new();
let cloned_token = shutdown_token.clone();
let join_handle = tokio::spawn(tun2proxy::run(device, MTU, args, cloned_token));
ctrlc2::set_async_handler(async move {
quit.trigger().await.expect("quit error");
log::info!("Ctrl-C received, exiting...");
shutdown_token.cancel();
})
.await;

View file

@ -4,7 +4,7 @@ use crate::{
args::{ArgDns, ArgProxy},
ArgVerbosity, Args,
};
use std::os::raw::{c_char, c_int, c_uint};
use std::os::raw::{c_char, c_int, c_ushort};
/// # Safety
///
@ -13,7 +13,7 @@ use std::os::raw::{c_char, c_int, c_uint};
pub unsafe extern "C" fn tun2proxy_run(
proxy_url: *const c_char,
tun_fd: c_int,
tun_mtu: c_uint,
tun_mtu: c_ushort,
dns_strategy: ArgDns,
verbosity: ArgVerbosity,
) -> c_int {
@ -25,7 +25,7 @@ pub unsafe extern "C" fn tun2proxy_run(
let args = Args::new(Some(tun_fd), proxy, dns_strategy, verbosity);
crate::api::tun2proxy_internal_run(args, tun_mtu as _)
crate::api::tun2proxy_internal_run(args, tun_mtu)
}
/// # Safety

View file

@ -8,17 +8,16 @@ use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream};
use proxy_handler::{ProxyHandler, ProxyHandlerManager};
use socks::SocksProxyManager;
pub use socks5_impl::protocol::UserKey;
use std::{collections::VecDeque, future::Future, net::SocketAddr, pin::Pin, sync::Arc};
use std::{collections::VecDeque, net::SocketAddr, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
sync::{
mpsc::{error::SendError, Receiver, Sender},
Mutex,
},
sync::Mutex,
};
use tokio_util::sync::CancellationToken;
use tproxy_config::is_private_ip;
use udp_stream::UdpStream;
pub use {
args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
error::{Error, Result},
@ -45,70 +44,7 @@ const MAX_SESSIONS: u64 = 200;
static TASK_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
use std::sync::atomic::Ordering::Relaxed;
pub struct Builder<D> {
device: D,
mtu: Option<usize>,
args: Args,
}
impl<D: AsyncRead + AsyncWrite + Unpin + Send + 'static> Builder<D> {
pub fn new(device: D, args: Args) -> Self {
Builder { device, args, mtu: None }
}
pub fn mtu(mut self, mtu: usize) -> Self {
self.mtu = Some(mtu);
self
}
pub fn build(self) -> Tun2Socks5<impl Future<Output = crate::Result<()>> + Send + 'static> {
let (tx, rx) = tokio::sync::mpsc::channel::<()>(1);
Tun2Socks5(run(self.device, self.mtu.unwrap_or(1500), self.args, rx), tx)
}
}
pub struct Tun2Socks5<F: Future>(F, Sender<()>);
impl<F: Future + Send + 'static> Tun2Socks5<F>
where
F::Output: Send,
{
pub fn start(self) -> (JoinHandle<F::Output>, Quit) {
let r = tokio::spawn(self.0);
(JoinHandle(r), Quit(self.1))
}
}
pub struct Quit(Sender<()>);
impl Quit {
pub async fn trigger(&self) -> Result<(), SendError<()>> {
self.0.send(()).await
}
}
#[repr(transparent)]
struct TokioJoinError(tokio::task::JoinError);
impl From<TokioJoinError> for crate::Result<()> {
fn from(value: TokioJoinError) -> Self {
Err(crate::Error::Io(value.0.into()))
}
}
pub struct JoinHandle<R>(tokio::task::JoinHandle<R>);
impl<R: From<TokioJoinError>> Future for JoinHandle<R> {
type Output = R;
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
match std::task::ready!(Pin::new(&mut self.0).poll(cx)) {
Ok(r) => std::task::Poll::Ready(r),
Err(e) => std::task::Poll::Ready(TokioJoinError(e).into()),
}
}
}
pub async fn run<D>(device: D, mtu: usize, args: Args, mut quit: Receiver<()>) -> crate::Result<()>
pub async fn run<D>(device: D, mtu: u16, args: Args, shutdown_token: CancellationToken) -> crate::Result<()>
where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@ -130,7 +66,7 @@ where
};
let mut ipstack_config = ipstack::IpStackConfig::default();
ipstack_config.mtu(mtu as _);
ipstack_config.mtu(mtu);
ipstack_config.tcp_timeout(std::time::Duration::from_secs(600)); // 10 minutes
ipstack_config.udp_timeout(std::time::Duration::from_secs(10)); // 10 seconds
@ -139,9 +75,8 @@ where
loop {
let virtual_dns = virtual_dns.clone();
let ip_stack_stream = tokio::select! {
_ = quit.recv() => {
log::info!("");
log::info!("Ctrl-C recieved, exiting...");
_ = shutdown_token.cancelled() => {
log::info!("Shutdown received");
break;
}
ip_stack_stream = ip_stack.accept() => {