make TASK_COUNT as local task_count variable
Some checks failed
Push or PR / build_n_test (macos-latest) (push) Has been cancelled
Push or PR / build_n_test (ubuntu-latest) (push) Has been cancelled
Push or PR / build_n_test (windows-latest) (push) Has been cancelled
Push or PR / build_n_test_android (push) Has been cancelled
Push or PR / build_n_test_ios (push) Has been cancelled
Push or PR / Check semver (push) Has been cancelled
Integration Tests / Proxy Tests (push) Has been cancelled

This commit is contained in:
ssrlive 2025-04-20 19:56:36 +08:00
parent 7121a80300
commit 88423039c6

View file

@ -64,9 +64,6 @@ pub mod win_svc;
const DNS_PORT: u16 = 53; const DNS_PORT: u16 = 53;
static TASK_COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
use std::sync::atomic::Ordering::Relaxed;
#[allow(unused)] #[allow(unused)]
#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)] #[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)]
#[cfg_attr( #[cfg_attr(
@ -224,11 +221,11 @@ where
let socket_queue = None; let socket_queue = None;
use socks5_impl::protocol::Version::{V4, V5}; use socks5_impl::protocol::Version::{V4, V5};
let mgr = match args.proxy.proxy_type { let mgr: Arc<dyn ProxyHandlerManager> = match args.proxy.proxy_type {
ProxyType::Socks5 => Arc::new(SocksProxyManager::new(server_addr, V5, key)) as Arc<dyn ProxyHandlerManager>, ProxyType::Socks5 => Arc::new(SocksProxyManager::new(server_addr, V5, key)),
ProxyType::Socks4 => Arc::new(SocksProxyManager::new(server_addr, V4, key)) as Arc<dyn ProxyHandlerManager>, ProxyType::Socks4 => Arc::new(SocksProxyManager::new(server_addr, V4, key)),
ProxyType::Http => Arc::new(HttpManager::new(server_addr, key)) as Arc<dyn ProxyHandlerManager>, ProxyType::Http => Arc::new(HttpManager::new(server_addr, key)),
ProxyType::None => Arc::new(NoProxyManager::new()) as Arc<dyn ProxyHandlerManager>, ProxyType::None => Arc::new(NoProxyManager::new()),
}; };
let mut ipstack_config = ipstack::IpStackConfig::default(); let mut ipstack_config = ipstack::IpStackConfig::default();
@ -256,7 +253,11 @@ where
client client
}); });
let task_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
use std::sync::atomic::Ordering::Relaxed;
loop { loop {
let task_count = task_count.clone();
let virtual_dns = virtual_dns.clone(); let virtual_dns = virtual_dns.clone();
let ip_stack_stream = tokio::select! { let ip_stack_stream = tokio::select! {
_ = shutdown_token.cancelled() => { _ = shutdown_token.cancelled() => {
@ -270,7 +271,7 @@ where
let max_sessions = args.max_sessions; let max_sessions = args.max_sessions;
match ip_stack_stream { match ip_stack_stream {
IpStackStream::Tcp(tcp) => { IpStackStream::Tcp(tcp) => {
if TASK_COUNT.load(Relaxed) >= max_sessions { if task_count.load(Relaxed) >= max_sessions {
if args.exit_on_fatal_error { if args.exit_on_fatal_error {
log::info!("Too many sessions that over {max_sessions}, exiting..."); log::info!("Too many sessions that over {max_sessions}, exiting...");
break; break;
@ -278,7 +279,7 @@ where
log::warn!("Too many sessions that over {max_sessions}, dropping new session"); log::warn!("Too many sessions that over {max_sessions}, dropping new session");
continue; continue;
} }
log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed).saturating_add(1)); log::trace!("Session count {}", task_count.fetch_add(1, Relaxed).saturating_add(1));
let info = SessionInfo::new(tcp.local_addr(), tcp.peer_addr(), IpProtocol::Tcp); let info = SessionInfo::new(tcp.local_addr(), tcp.peer_addr(), IpProtocol::Tcp);
let domain_name = if let Some(virtual_dns) = &virtual_dns { let domain_name = if let Some(virtual_dns) = &virtual_dns {
let mut virtual_dns = virtual_dns.lock().await; let mut virtual_dns = virtual_dns.lock().await;
@ -293,11 +294,11 @@ where
if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await { if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await {
log::error!("{} error \"{}\"", info, err); log::error!("{} error \"{}\"", info, err);
} }
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
}); });
} }
IpStackStream::Udp(udp) => { IpStackStream::Udp(udp) => {
if TASK_COUNT.load(Relaxed) >= max_sessions { if task_count.load(Relaxed) >= max_sessions {
if args.exit_on_fatal_error { if args.exit_on_fatal_error {
log::info!("Too many sessions that over {max_sessions}, exiting..."); log::info!("Too many sessions that over {max_sessions}, exiting...");
break; break;
@ -305,11 +306,11 @@ where
log::warn!("Too many sessions that over {max_sessions}, dropping new session"); log::warn!("Too many sessions that over {max_sessions}, dropping new session");
continue; continue;
} }
log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed).saturating_add(1)); log::trace!("Session count {}", task_count.fetch_add(1, Relaxed).saturating_add(1));
let mut info = SessionInfo::new(udp.local_addr(), udp.peer_addr(), IpProtocol::Udp); let mut info = SessionInfo::new(udp.local_addr(), udp.peer_addr(), IpProtocol::Udp);
if info.dst.port() == DNS_PORT { if info.dst.port() == DNS_PORT {
if is_private_ip(info.dst.ip()) { if is_private_ip(info.dst.ip()) {
info.dst.set_ip(dns_addr); info.dst.set_ip(dns_addr); // !!! Here we change the destination address to remote DNS server!!!
} }
if args.dns == ArgDns::OverTcp { if args.dns == ArgDns::OverTcp {
info.protocol = IpProtocol::Tcp; info.protocol = IpProtocol::Tcp;
@ -319,7 +320,7 @@ where
if let Err(err) = handle_dns_over_tcp_session(udp, proxy_handler, socket_queue, ipv6_enabled).await { if let Err(err) = handle_dns_over_tcp_session(udp, proxy_handler, socket_queue, ipv6_enabled).await {
log::error!("{} error \"{}\"", info, err); log::error!("{} error \"{}\"", info, err);
} }
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
}); });
continue; continue;
} }
@ -330,7 +331,7 @@ where
log::error!("{} error \"{}\"", info, err); log::error!("{} error \"{}\"", info, err);
} }
} }
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
}); });
continue; continue;
} }
@ -361,7 +362,7 @@ where
if let Err(e) = handle_udp_gateway_session(udp, udpgw, &dst_addr, proxy_handler, queue, ipv6_enabled).await { if let Err(e) = handle_udp_gateway_session(udp, udpgw, &dst_addr, proxy_handler, queue, ipv6_enabled).await {
log::info!("Ending {} with \"{}\"", info, e); log::info!("Ending {} with \"{}\"", info, e);
} }
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
}); });
continue; continue;
} }
@ -373,7 +374,7 @@ where
if let Err(err) = handle_udp_associate_session(udp, ty, proxy_handler, socket_queue, ipv6_enabled).await { if let Err(err) = handle_udp_associate_session(udp, ty, proxy_handler, socket_queue, ipv6_enabled).await {
log::info!("Ending {} with \"{}\"", info, err); log::info!("Ending {} with \"{}\"", info, err);
} }
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
}); });
} }
Err(e) => { Err(e) => {
@ -392,7 +393,7 @@ where
} }
} }
} }
Ok(TASK_COUNT.load(Relaxed)) Ok(task_count.load(Relaxed))
} }
async fn handle_virtual_dns_session(mut udp: IpStackUdpStream, dns: Arc<Mutex<VirtualDns>>) -> crate::Result<()> { async fn handle_virtual_dns_session(mut udp: IpStackUdpStream, dns: Arc<Mutex<VirtualDns>>) -> crate::Result<()> {