diff --git a/Cargo.toml b/Cargo.toml index a71366f..266d8f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tun2proxy" -version = "0.5.1" +version = "0.5.4" edition = "2021" license = "MIT" repository = "https://github.com/tun2proxy/tun2proxy" diff --git a/src/args.rs b/src/args.rs index 16ed744..ec839fb 100644 --- a/src/args.rs +++ b/src/args.rs @@ -107,6 +107,14 @@ pub struct Args { /// Daemonize for unix family or run as Windows service #[arg(long)] pub daemonize: bool, + + /// Exit immediately when fatal error occurs, useful for running as a service + #[arg(long)] + pub exit_on_fatal_error: bool, + + /// Maximum number of sessions to be handled concurrently + #[arg(long, value_name = "number", default_value = "200")] + pub max_sessions: usize, } fn validate_tun(p: &str) -> Result { @@ -149,6 +157,8 @@ impl Default for Args { verbosity: ArgVerbosity::Info, virtual_dns_pool: IpCidr::from_str("198.18.0.0/15").unwrap(), daemonize: false, + exit_on_fatal_error: false, + max_sessions: 200, } } } diff --git a/src/bin/main.rs b/src/bin/main.rs index 4b93d37..e39b7b4 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -1,7 +1,6 @@ use tun2proxy::{Args, BoxError}; -#[tokio::main] -async fn main() -> Result<(), BoxError> { +fn main() -> Result<(), BoxError> { dotenvy::dotenv().ok(); let args = Args::parse_args(); @@ -24,11 +23,16 @@ async fn main() -> Result<(), BoxError> { return Ok(()); } + let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; + rt.block_on(main_async(args)) +} + +async fn main_async(args: Args) -> Result<(), BoxError> { let default = format!("{:?},hickory_proto=warn", args.verbosity); env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init(); let shutdown_token = tokio_util::sync::CancellationToken::new(); - let join_handle = tokio::spawn({ + let main_loop_handle = tokio::spawn({ let shutdown_token = shutdown_token.clone(); async move { #[cfg(target_os = "linux")] @@ -51,14 +55,20 @@ async fn main() -> Result<(), BoxError> { } }); - ctrlc2::set_async_handler(async move { + let ctrlc_fired = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let ctrlc_fired_clone = ctrlc_fired.clone(); + let ctrlc_handel = ctrlc2::set_async_handler(async move { log::info!("Ctrl-C received, exiting..."); + ctrlc_fired_clone.store(true, std::sync::atomic::Ordering::SeqCst); shutdown_token.cancel(); }) .await; - if let Err(err) = join_handle.await { - log::error!("main_entry error {}", err); + main_loop_handle.await?; + + if ctrlc_fired.load(std::sync::atomic::Ordering::SeqCst) { + log::info!("Ctrl-C fired, waiting the handler to finish..."); + ctrlc_handel.await.map_err(|err| err.to_string())?; } Ok(()) diff --git a/src/lib.rs b/src/lib.rs index bdaf83c..2659896 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,8 +68,6 @@ pub mod win_svc; const DNS_PORT: u16 = 53; -const MAX_SESSIONS: u64 = 200; - static TASK_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); use std::sync::atomic::Ordering::Relaxed; @@ -260,10 +258,15 @@ where ip_stack_stream? } }; + let max_sessions = args.max_sessions as u64; match ip_stack_stream { IpStackStream::Tcp(tcp) => { - if TASK_COUNT.load(Relaxed) > MAX_SESSIONS { - log::warn!("Too many sessions that over {MAX_SESSIONS}, dropping new session"); + if TASK_COUNT.load(Relaxed) > max_sessions { + if args.exit_on_fatal_error { + log::info!("Too many sessions that over {max_sessions}, exiting..."); + break; + } + log::warn!("Too many sessions that over {max_sessions}, dropping new session"); continue; } log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed) + 1); @@ -286,8 +289,12 @@ where }); } IpStackStream::Udp(udp) => { - if TASK_COUNT.load(Relaxed) > MAX_SESSIONS { - log::warn!("Too many sessions that over {MAX_SESSIONS}, dropping new session"); + if TASK_COUNT.load(Relaxed) > max_sessions { + if args.exit_on_fatal_error { + log::info!("Too many sessions that over {max_sessions}, exiting..."); + break; + } + log::warn!("Too many sessions that over {max_sessions}, dropping new session"); continue; } log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed) + 1);