diff --git a/Cargo.toml b/Cargo.toml index ca23b27..65e1f74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,11 @@ udp-stream = { version = "0.0", default-features = false } unicase = "2.7" url = "2.5" +[target.'cfg(target_os="linux")'.dependencies] +serde = { version = "1", features = [ "derive" ] } +bincode = "1" +nix = { version = "0", default-features = false, features = ["fs", "socket", "uio"] } + [target.'cfg(target_os="android")'.dependencies] android_logger = "0.13" jni = { version = "0.21", default-features = false } diff --git a/src/args.rs b/src/args.rs index e7df38c..bd9d413 100644 --- a/src/args.rs +++ b/src/args.rs @@ -1,6 +1,9 @@ use crate::{Error, Result}; use socks5_impl::protocol::UserKey; -use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; +use std::{ + ffi::OsString, + net::{IpAddr, SocketAddr, ToSocketAddrs}, +}; #[derive(Debug, Clone, clap::Parser)] #[command(author, version, about = "Tunnel interface to proxy.", long_about = None)] @@ -20,13 +23,29 @@ pub struct Args { #[arg(long, value_name = "fd", conflicts_with = "tun")] pub tun_fd: Option, + /// Create a tun interface in a newly created unprivileged namespace + /// while maintaining proxy connectivity via the global network namespace. + #[arg(long)] + pub unshare: bool, + + /// File descriptor for UNIX datagram socket meant to transfer + /// network sockets from global namespace to the new one. + /// See `unshare(1)`, `namespaces(7)`, `sendmsg(2)`, `unix(7)`. + #[arg(long)] + pub socket_transfer_fd: Option, + + /// Specify a command to run with root-like capabilities in the new namespace. + /// This could be useful to start additional daemons, e.g. `openvpn` instance. + #[arg(requires = "unshare")] + pub admin_command: Vec, + /// IPv6 enabled #[arg(short = '6', long)] pub ipv6_enabled: bool, #[arg(short, long)] /// Routing and system setup, which decides whether to setup the routing and system configuration. - /// This option is only available on Linux and requires root privileges. + /// This option is only available on Linux and requires root-like privileges. See `capabilities(7)`. pub setup: bool, /// DNS handling strategy @@ -60,6 +79,9 @@ impl Default for Args { proxy: ArgProxy::default(), tun: None, tun_fd: None, + unshare: false, + socket_transfer_fd: None, + admin_command: Vec::new(), ipv6_enabled: false, setup: false, dns: ArgDns::default(), diff --git a/src/bin/main.rs b/src/bin/main.rs index 7be56ab..2fdad5e 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -13,8 +13,17 @@ async fn main() -> Result<(), BoxError> { let join_handle = tokio::spawn({ let shutdown_token = shutdown_token.clone(); async move { - if let Err(err) = tun2proxy::desktop_run_async(args, shutdown_token).await { - log::error!("main loop error: {}", err); + if args.unshare && args.socket_transfer_fd.is_none() { + #[cfg(target_os = "linux")] + if let Err(err) = namespace_proxy_main(args, shutdown_token).await { + log::error!("namespace proxy error: {}", err); + } + #[cfg(not(target_os = "linux"))] + log::error!("Your platform doesn't support unprivileged namespaces"); + } else { + if let Err(err) = tun2proxy::desktop_run_async(args, shutdown_token).await { + log::error!("main loop error: {}", err); + } } } }); @@ -31,3 +40,48 @@ async fn main() -> Result<(), BoxError> { Ok(()) } + +#[cfg(target_os = "linux")] +async fn namespace_proxy_main( + _args: Args, + _shutdown_token: tokio_util::sync::CancellationToken, +) -> Result { + use std::os::fd::AsRawFd; + + let (socket, remote_fd) = tun2proxy::socket_transfer::create_transfer_socket_pair().await?; + + let child = tokio::process::Command::new("unshare") + .args("--user --map-current-user --net --mount --keep-caps --kill-child --fork".split(' ')) + .arg(std::env::current_exe()?) + .arg("--socket-transfer-fd") + .arg(remote_fd.as_raw_fd().to_string()) + .args(std::env::args().skip(1)) + .kill_on_drop(true) + .spawn(); + + let mut child = match child { + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + log::error!("`unshare(1)` executable wasn't located in PATH."); + log::error!("Consider installing linux utils package: `apt install util-linux`"); + log::error!("Or similar for your distribution."); + return Err(err.into()); + } + child => child?, + }; + + log::info!("The tun proxy is running in unprivileged mode. See `namespaces(7)`."); + log::info!(""); + log::info!("If you need to run a process that relies on root-like capabilities (e.g. `openvpn`)"); + log::info!("Use `tun2proxy --unshare --setup [...] -- openvpn --config [...]`"); + log::info!(""); + log::info!("To run a new process in the created namespace (e.g. a flatpak app)"); + log::info!( + "Use `nsenter --preserve-credentials --user --net --mount --target {} /bin/sh`", + child.id().unwrap_or(0) + ); + log::info!(""); + + tokio::spawn(async move { tun2proxy::socket_transfer::process_socket_requests(&socket).await }); + + Ok(child.wait().await?) +} diff --git a/src/desktop_api.rs b/src/desktop_api.rs index 3d5723e..71f67dc 100644 --- a/src/desktop_api.rs +++ b/src/desktop_api.rs @@ -131,11 +131,69 @@ pub async fn desktop_run_async(args: Args, shutdown_token: tokio_util::sync::Can restore = Some(tproxy_config::tproxy_setup(&tproxy_args)?); } + #[cfg(target_os = "linux")] + { + let run_ip_util = |args: String| { + tokio::process::Command::new("ip") + .args(args.split(' ')) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .spawn() + .ok(); + }; + + if setup && !args.ipv6_enabled { + // Remove ipv6 connectivity if not explicitly required + // TODO: remove this when upstream will get updated + run_ip_util(format!("-6 route delete ::/1 dev {}", tproxy_args.tun_name)); + run_ip_util(format!("-6 route delete 80::/1 dev {}", tproxy_args.tun_name)); + } + + if setup && args.unshare { + // New namespace doesn't have any other routing device by default + // So our `tun` device should act as such to make space for other proxies. + run_ip_util(format!("route delete 0.0.0.0/1 dev {}", tproxy_args.tun_name)); + run_ip_util(format!("route delete 128.0.0.0/1 dev {}", tproxy_args.tun_name)); + + run_ip_util(format!("route add 0.0.0.0/0 dev {}", tproxy_args.tun_name)); + + if args.ipv6_enabled { + run_ip_util(format!("-6 route delete ::/1 dev {}", tproxy_args.tun_name)); + run_ip_util(format!("-6 route delete 80::/1 dev {}", tproxy_args.tun_name)); + + run_ip_util(format!("-6 route add ::/0 dev {}", tproxy_args.tun_name)); + } + } + } + + let mut admin_command_args = args.admin_command.iter(); + if let Some(command) = admin_command_args.next() { + let child = tokio::process::Command::new(command) + .args(admin_command_args) + .kill_on_drop(true) + .spawn(); + + match child { + Err(err) => { + log::warn!("Failed to start admin process: {err}"); + } + Ok(mut child) => { + tokio::spawn(async move { + if let Err(err) = child.wait().await { + log::warn!("Admin process terminated: {err}"); + } + }); + } + }; + } + let join_handle = tokio::spawn(crate::run(device, MTU, args, shutdown_token)); join_handle.await.map_err(std::io::Error::from)??; #[cfg(any(target_os = "linux", target_os = "windows", target_os = "macos"))] if setup { + // TODO: This probably should be handled by a destructor + // since otherwise removal is not guaranteed if anything above returns early. tproxy_config::tproxy_remove(restore)?; } diff --git a/src/error.rs b/src/error.rs index 96b9732..2afd19b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,10 @@ pub enum Error { #[error(transparent)] Io(#[from] std::io::Error), + #[cfg(target_os = "linux")] + #[error("nix::errno::Errno {0:?}")] + NixErrno(#[from] nix::errno::Errno), + #[error("TryFromIntError {0:?}")] TryFromInt(#[from] std::num::TryFromIntError), @@ -39,6 +43,10 @@ pub enum Error { #[error("std::num::ParseIntError {0:?}")] IntParseError(#[from] std::num::ParseIntError), + + #[cfg(target_os = "linux")] + #[error("bincode::Error {0:?}")] + BincodeError(#[from] bincode::Error), } impl From<&str> for Error { diff --git a/src/http.rs b/src/http.rs index 73cf9b6..7aa2569 100644 --- a/src/http.rs +++ b/src/http.rs @@ -38,6 +38,7 @@ enum HttpState { pub(crate) type DigestState = digest_auth::WwwAuthenticateHeader; pub struct HttpConnection { + server_addr: SocketAddr, state: HttpState, client_inbuf: VecDeque, server_inbuf: VecDeque, @@ -61,12 +62,14 @@ static CONTENT_LENGTH: &str = "Content-Length"; impl HttpConnection { async fn new( + server_addr: SocketAddr, info: SessionInfo, domain_name: Option, credentials: Option, digest_state: Arc>>, ) -> Result { let mut res = Self { + server_addr, state: HttpState::ExpectResponseHeaders, client_inbuf: VecDeque::default(), server_inbuf: VecDeque::default(), @@ -330,6 +333,10 @@ impl HttpConnection { #[async_trait::async_trait] impl ProxyHandler for HttpConnection { + fn get_server_addr(&self) -> SocketAddr { + self.server_addr + } + fn get_session_info(&self) -> SessionInfo { self.info } @@ -413,7 +420,7 @@ impl ProxyHandlerManager for HttpManager { return Err(Error::from("Protocol not supported by HTTP proxy").into()); } Ok(Arc::new(Mutex::new( - HttpConnection::new(info, domain_name, self.credentials.clone(), self.digest_state.clone()).await?, + HttpConnection::new(self.server, info, domain_name, self.credentials.clone(), self.digest_state.clone()).await?, ))) } diff --git a/src/lib.rs b/src/lib.rs index e2701fe..183f37d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,11 +9,16 @@ use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream}; use proxy_handler::{ProxyHandler, ProxyHandlerManager}; use socks::SocksProxyManager; pub use socks5_impl::protocol::UserKey; -use std::{collections::VecDeque, net::SocketAddr, sync::Arc}; +use std::{ + collections::VecDeque, + io::ErrorKind, + net::{IpAddr, SocketAddr}, + sync::Arc, +}; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, - net::TcpStream, - sync::Mutex, + net::{TcpSocket, TcpStream, UdpSocket}, + sync::{mpsc::Receiver, Mutex}, }; pub use tokio_util::sync::CancellationToken; use tproxy_config::is_private_ip; @@ -46,6 +51,7 @@ mod mobile_api; mod no_proxy; mod proxy_handler; mod session_info; +pub mod socket_transfer; mod socks; mod virtual_dns; @@ -56,6 +62,81 @@ const MAX_SESSIONS: u64 = 200; static TASK_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); use std::sync::atomic::Ordering::Relaxed; +#[allow(unused)] +#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(target_os = "linux", derive(serde::Serialize, serde::Deserialize))] +pub enum SocketProtocol { + Tcp, + Udp, +} + +#[allow(unused)] +#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(target_os = "linux", derive(serde::Serialize, serde::Deserialize))] +pub enum SocketDomain { + IpV4, + IpV6, +} + +impl From for SocketDomain { + fn from(value: IpAddr) -> Self { + match value { + IpAddr::V4(_) => Self::IpV4, + IpAddr::V6(_) => Self::IpV6, + } + } +} + +struct SocketQueue { + tcp_v4: Mutex>, + tcp_v6: Mutex>, + udp_v4: Mutex>, + udp_v6: Mutex>, +} + +impl SocketQueue { + async fn recv_tcp(&self, domain: SocketDomain) -> Result { + match domain { + SocketDomain::IpV4 => &self.tcp_v4, + SocketDomain::IpV6 => &self.tcp_v6, + } + .lock() + .await + .recv() + .await + .ok_or(ErrorKind::Other.into()) + } + async fn recv_udp(&self, domain: SocketDomain) -> Result { + match domain { + SocketDomain::IpV4 => &self.udp_v4, + SocketDomain::IpV6 => &self.udp_v6, + } + .lock() + .await + .recv() + .await + .ok_or(ErrorKind::Other.into()) + } +} + +async fn create_tcp_stream(socket_queue: &Option>, peer: SocketAddr) -> std::io::Result { + match &socket_queue { + None => TcpStream::connect(peer).await, + Some(queue) => queue.recv_tcp(peer.ip().into()).await?.connect(peer).await, + } +} + +async fn create_udp_stream(socket_queue: &Option>, peer: SocketAddr) -> std::io::Result { + match &socket_queue { + None => UdpStream::connect(peer).await, + Some(queue) => { + let socket = queue.recv_udp(peer.ip().into()).await?; + socket.connect(peer).await?; + UdpStream::from_tokio(socket).await + } + } +} + /// Run the proxy server /// # Arguments /// * `device` - The network device to use @@ -78,6 +159,56 @@ where None }; + #[cfg(target_os = "linux")] + let socket_queue = match args.socket_transfer_fd { + None => None, + Some(fd) => { + use crate::socket_transfer::{reconstruct_socket, reconstruct_transfer_socket, request_sockets}; + use tokio::sync::mpsc::channel; + + let fd = reconstruct_socket(fd)?; + let socket = reconstruct_transfer_socket(fd)?; + let socket = Arc::new(Mutex::new(socket)); + + macro_rules! create_socket_queue { + ($domain:ident) => {{ + const SOCKETS_PER_REQUEST: usize = 64; + + let socket = socket.clone(); + let (tx, rx) = channel(SOCKETS_PER_REQUEST); + tokio::spawn(async move { + loop { + let sockets = + match request_sockets(socket.lock().await, SocketDomain::$domain, SOCKETS_PER_REQUEST as u32).await { + Ok(sockets) => sockets, + Err(err) => { + log::warn!("Socket allocation request failed: {err}"); + continue; + } + }; + for s in sockets { + if let Err(_) = tx.send(s).await { + return; + } + } + } + }); + Mutex::new(rx) + }}; + } + + Some(Arc::new(SocketQueue { + tcp_v4: create_socket_queue!(IpV4), + tcp_v6: create_socket_queue!(IpV6), + udp_v4: create_socket_queue!(IpV4), + udp_v6: create_socket_queue!(IpV6), + })) + } + }; + + #[cfg(not(target_os = "linux"))] + let socket_queue = None; + use socks5_impl::protocol::Version::{V4, V5}; let mgr = match args.proxy.proxy_type { ProxyType::Socks5 => Arc::new(SocksProxyManager::new(server_addr, V5, key)) as Arc, @@ -120,8 +251,9 @@ where None }; let proxy_handler = mgr.new_proxy_handler(info, domain_name, false).await?; + let socket_queue = socket_queue.clone(); tokio::spawn(async move { - if let Err(err) = handle_tcp_session(tcp, server_addr, proxy_handler).await { + if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await { log::error!("{} error \"{}\"", info, err); } log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); @@ -140,8 +272,9 @@ where } if args.dns == ArgDns::OverTcp { let proxy_handler = mgr.new_proxy_handler(info, None, false).await?; + let socket_queue = socket_queue.clone(); tokio::spawn(async move { - if let Err(err) = handle_dns_over_tcp_session(udp, server_addr, proxy_handler, ipv6_enabled).await { + if let Err(err) = handle_dns_over_tcp_session(udp, proxy_handler, socket_queue, ipv6_enabled).await { log::error!("{} error \"{}\"", info, err); } log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); @@ -170,8 +303,11 @@ where }; match mgr.new_proxy_handler(info, domain_name, true).await { Ok(proxy_handler) => { + let socket_queue = socket_queue.clone(); tokio::spawn(async move { - if let Err(err) = handle_udp_associate_session(udp, server_addr, proxy_handler, ipv6_enabled).await { + if let Err(err) = + handle_udp_associate_session(udp, args.proxy.proxy_type, proxy_handler, socket_queue, ipv6_enabled).await + { log::info!("Ending {} with \"{}\"", info, err); } log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); @@ -207,12 +343,17 @@ async fn handle_virtual_dns_session(mut udp: IpStackUdpStream, dns: Arc>, + socket_queue: Option>, ) -> crate::Result<()> { - let mut server = TcpStream::connect(server_addr).await?; + let (session_info, server_addr) = { + let handler = proxy_handler.lock().await; + + (handler.get_session_info(), handler.get_server_addr()) + }; + + let mut server = create_tcp_stream(&socket_queue, server_addr).await?; - let session_info = proxy_handler.lock().await.get_session_info(); log::info!("Beginning {}", session_info); if let Err(e) = handle_proxy_session(&mut server, proxy_handler).await { @@ -246,20 +387,36 @@ async fn handle_tcp_session( async fn handle_udp_associate_session( mut udp_stack: IpStackUdpStream, - server_addr: SocketAddr, + proxy_type: ProxyType, proxy_handler: Arc>, + socket_queue: Option>, ipv6_enabled: bool, ) -> crate::Result<()> { use socks5_impl::protocol::{Address, StreamOperation, UdpHeader}; - let mut server = TcpStream::connect(server_addr).await?; - let session_info = proxy_handler.lock().await.get_session_info(); - let domain_name = proxy_handler.lock().await.get_domain_name(); + + let (session_info, server_addr, domain_name, udp_addr) = { + let handler = proxy_handler.lock().await; + ( + handler.get_session_info(), + handler.get_server_addr(), + handler.get_domain_name(), + handler.get_udp_associate(), + ) + }; + log::info!("Beginning {}", session_info); - let udp_addr = handle_proxy_session(&mut server, proxy_handler).await?; - let udp_addr = udp_addr.ok_or("udp associate failed")?; + let udp_addr = match udp_addr { + Some(udp_addr) => udp_addr, + None => { + let mut server = create_tcp_stream(&socket_queue, server_addr).await?; - let mut udp_server = UdpStream::connect(udp_addr).await?; + let udp_addr = handle_proxy_session(&mut server, proxy_handler).await?; + udp_addr.ok_or("udp associate failed")? + } + }; + + let mut udp_server = create_udp_stream(&socket_queue, udp_addr).await?; let mut buf1 = [0_u8; 4096]; let mut buf2 = [0_u8; 4096]; @@ -272,18 +429,22 @@ async fn handle_udp_associate_session( } let buf1 = &buf1[..len]; - let s5addr = if let Some(domain_name) = &domain_name { - Address::DomainAddress(domain_name.clone(), session_info.dst.port()) + if let ProxyType::Socks4 | ProxyType::Socks5 = proxy_type { + let s5addr = if let Some(domain_name) = &domain_name { + Address::DomainAddress(domain_name.clone(), session_info.dst.port()) + } else { + session_info.dst.into() + }; + + // Add SOCKS5 UDP header to the incoming data + let mut s5_udp_data = Vec::::new(); + UdpHeader::new(0, s5addr).write_to_stream(&mut s5_udp_data)?; + s5_udp_data.extend_from_slice(buf1); + + udp_server.write_all(&s5_udp_data).await?; } else { - session_info.dst.into() - }; - - // Add SOCKS5 UDP header to the incoming data - let mut s5_udp_data = Vec::::new(); - UdpHeader::new(0, s5addr).write_to_stream(&mut s5_udp_data)?; - s5_udp_data.extend_from_slice(buf1); - - udp_server.write_all(&s5_udp_data).await?; + udp_server.write_all(buf1).await?; + } } len = udp_server.read(&mut buf2) => { let len = len?; @@ -292,21 +453,25 @@ async fn handle_udp_associate_session( } let buf2 = &buf2[..len]; - // Remove SOCKS5 UDP header from the server data - let header = UdpHeader::retrieve_from_stream(&mut &buf2[..])?; - let data = &buf2[header.len()..]; + if let ProxyType::Socks4 | ProxyType::Socks5 = proxy_type { + // Remove SOCKS5 UDP header from the server data + let header = UdpHeader::retrieve_from_stream(&mut &buf2[..])?; + let data = &buf2[header.len()..]; - let buf = if session_info.dst.port() == DNS_PORT { - let mut message = dns::parse_data_to_dns_message(data, false)?; - if !ipv6_enabled { - dns::remove_ipv6_entries(&mut message); - } - message.to_vec()? + let buf = if session_info.dst.port() == DNS_PORT { + let mut message = dns::parse_data_to_dns_message(data, false)?; + if !ipv6_enabled { + dns::remove_ipv6_entries(&mut message); + } + message.to_vec()? + } else { + data.to_vec() + }; + + udp_stack.write_all(&buf).await?; } else { - data.to_vec() - }; - - udp_stack.write_all(&buf).await?; + udp_stack.write_all(buf2).await?; + } } } } @@ -318,13 +483,18 @@ async fn handle_udp_associate_session( async fn handle_dns_over_tcp_session( mut udp_stack: IpStackUdpStream, - server_addr: SocketAddr, proxy_handler: Arc>, + socket_queue: Option>, ipv6_enabled: bool, ) -> crate::Result<()> { - let mut server = TcpStream::connect(server_addr).await?; + let (session_info, server_addr) = { + let handler = proxy_handler.lock().await; + + (handler.get_session_info(), handler.get_server_addr()) + }; + + let mut server = create_tcp_stream(&socket_queue, server_addr).await?; - let session_info = proxy_handler.lock().await.get_session_info(); log::info!("Beginning {}", session_info); let _ = handle_proxy_session(&mut server, proxy_handler).await?; diff --git a/src/no_proxy.rs b/src/no_proxy.rs index 99c4dbf..83edadf 100644 --- a/src/no_proxy.rs +++ b/src/no_proxy.rs @@ -16,6 +16,10 @@ struct NoProxyHandler { #[async_trait::async_trait] impl ProxyHandler for NoProxyHandler { + fn get_server_addr(&self) -> SocketAddr { + self.info.dst + } + fn get_session_info(&self) -> SessionInfo { self.info } diff --git a/src/proxy_handler.rs b/src/proxy_handler.rs index 5621347..94406a6 100644 --- a/src/proxy_handler.rs +++ b/src/proxy_handler.rs @@ -7,6 +7,7 @@ use tokio::sync::Mutex; #[async_trait::async_trait] pub(crate) trait ProxyHandler: Send + Sync { + fn get_server_addr(&self) -> SocketAddr; fn get_session_info(&self) -> SessionInfo; fn get_domain_name(&self) -> Option; async fn push_data(&mut self, event: IncomingDataEvent<'_>) -> std::io::Result<()>; diff --git a/src/socket_transfer.rs b/src/socket_transfer.rs new file mode 100644 index 0000000..194c6f7 --- /dev/null +++ b/src/socket_transfer.rs @@ -0,0 +1,230 @@ +#![cfg(target_os = "linux")] + +use crate::{error, SocketDomain, SocketProtocol}; +use nix::{ + errno::Errno, + fcntl::{self, FdFlag}, + sys::socket::{cmsg_space, getsockopt, recvmsg, sendmsg, sockopt, ControlMessage, ControlMessageOwned, MsgFlags, SockType}, +}; +use serde::{Deserialize, Serialize}; +use std::{ + io::{ErrorKind, IoSlice, IoSliceMut, Result}, + ops::DerefMut, + os::fd::{AsFd, AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}, +}; +use tokio::net::{TcpSocket, UdpSocket, UnixDatagram}; + +const REQUEST_BUFFER_SIZE: usize = 64; + +#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] +struct Request { + protocol: SocketProtocol, + domain: SocketDomain, + number: u32, +} + +#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] +enum Response { + Ok, +} + +/// Reconstruct socket from raw `fd` +pub fn reconstruct_socket(fd: RawFd) -> Result { + // Check if `fd` is valid + let fd_flags = fcntl::fcntl(fd, fcntl::F_GETFD)?; + + // `fd` is confirmed to be valid so it should be closed + let socket = unsafe { OwnedFd::from_raw_fd(fd) }; + + // Insert CLOEXEC flag to the `fd` to prevent further propagation across `execve(2)` calls + let mut fd_flags = FdFlag::from_bits(fd_flags).ok_or(ErrorKind::Unsupported)?; + if !fd_flags.contains(FdFlag::FD_CLOEXEC) { + fd_flags.insert(FdFlag::FD_CLOEXEC); + fcntl::fcntl(fd, fcntl::F_SETFD(fd_flags))?; + } + + Ok(socket) +} + +/// Reconstruct transfer socket from `fd` +/// +/// Panics if called outside of tokio runtime +pub fn reconstruct_transfer_socket(fd: OwnedFd) -> Result { + // Check if socket of type DATAGRAM + let sock_type = getsockopt(&fd, sockopt::SockType)?; + if !matches!(sock_type, SockType::Datagram) { + return Err(ErrorKind::InvalidInput.into()); + } + + let std_socket: std::os::unix::net::UnixDatagram = fd.into(); + std_socket.set_nonblocking(true)?; + + // Fails if tokio context is absent + Ok(UnixDatagram::from_std(std_socket).unwrap()) +} + +/// Create pair of interconnected sockets one of which is set to stay open across `execve(2)` calls. +pub async fn create_transfer_socket_pair() -> std::io::Result<(UnixDatagram, OwnedFd)> { + let (local, remote) = tokio::net::UnixDatagram::pair()?; + + let remote_fd: OwnedFd = remote.into_std().unwrap().into(); + + // Get `remote_fd` flags + let fd_flags = fcntl::fcntl(remote_fd.as_raw_fd(), fcntl::F_GETFD)?; + + // Remove CLOEXEC flag from the `remote_fd` to allow propagating across `execve(2)` + let mut fd_flags = FdFlag::from_bits(fd_flags).ok_or(ErrorKind::Unsupported)?; + fd_flags.remove(FdFlag::FD_CLOEXEC); + fcntl::fcntl(remote_fd.as_raw_fd(), fcntl::F_SETFD(fd_flags))?; + + Ok((local, remote_fd)) +} + +pub trait TransferableSocket: Sized { + fn from_fd(fd: OwnedFd) -> Result; + fn domain() -> SocketProtocol; +} + +impl TransferableSocket for TcpSocket { + fn from_fd(fd: OwnedFd) -> Result { + // Check if socket is of type STREAM + let sock_type = getsockopt(&fd, sockopt::SockType)?; + if !matches!(sock_type, SockType::Stream) { + return Err(ErrorKind::InvalidInput.into()); + } + + let std_stream: std::net::TcpStream = fd.into(); + std_stream.set_nonblocking(true)?; + + Ok(TcpSocket::from_std_stream(std_stream)) + } + + fn domain() -> SocketProtocol { + SocketProtocol::Tcp + } +} + +impl TransferableSocket for UdpSocket { + /// Panics if called outside of tokio runtime + fn from_fd(fd: OwnedFd) -> Result { + // Check if socket is of type DATAGRAM + let sock_type = getsockopt(&fd, sockopt::SockType)?; + if !matches!(sock_type, SockType::Datagram) { + return Err(ErrorKind::InvalidInput.into()); + } + + let std_socket: std::net::UdpSocket = fd.into(); + std_socket.set_nonblocking(true)?; + + Ok(UdpSocket::try_from(std_socket).unwrap()) + } + + fn domain() -> SocketProtocol { + SocketProtocol::Udp + } +} + +/// Send [`Request`] to `socket` and return received [`TransferableSocket`]s +/// +/// Panics if called outside of tokio runtime +pub async fn request_sockets(mut socket: S, domain: SocketDomain, number: u32) -> error::Result> +where + S: DerefMut, + T: TransferableSocket, +{ + // Borrow socket as mut to prevent multiple simultaneous requests + let socket = socket.deref_mut(); + + // Send request + let request = bincode::serialize(&Request { + protocol: T::domain(), + domain, + number, + })?; + + socket.send(&request[..]).await?; + + // Receive response + loop { + socket.readable().await?; + + let mut buf = [0_u8; REQUEST_BUFFER_SIZE]; + let mut iov = [IoSliceMut::new(&mut buf[..])]; + let mut cmsg = Vec::with_capacity(cmsg_space::() * number as usize); + + let msg = recvmsg::<()>(socket.as_fd().as_raw_fd(), &mut iov, Some(&mut cmsg), MsgFlags::empty()); + + let msg = match msg { + Err(Errno::EAGAIN) => continue, + msg => msg?, + }; + + // Parse response + let response = &msg.iovs().next().unwrap()[..msg.bytes]; + let response: Response = bincode::deserialize(response)?; + if !matches!(response, Response::Ok) { + return Err("Request for new sockets failed".into()); + } + + // Process received file descriptors + let mut sockets = Vec::::with_capacity(number as usize); + for cmsg in msg.cmsgs() { + if let ControlMessageOwned::ScmRights(fds) = cmsg { + for fd in fds { + if fd < 0 { + return Err("Received socket is invalid".into()); + } + + let owned_fd = reconstruct_socket(fd)?; + sockets.push(T::from_fd(owned_fd)?); + } + } + } + + return Ok(sockets); + } +} + +/// Process [`Request`]s received from `socket` +/// +/// Panics if called outside of tokio runtime +pub async fn process_socket_requests(socket: &UnixDatagram) -> error::Result<()> { + loop { + let mut buf = [0_u8; REQUEST_BUFFER_SIZE]; + + let len = socket.recv(&mut buf[..]).await?; + + let request: Request = bincode::deserialize(&buf[..len])?; + + let response = Response::Ok; + let buf = bincode::serialize(&response)?; + + let mut owned_fd_buf: Vec = Vec::with_capacity(request.number as usize); + for _ in 0..request.number { + let fd = match request.protocol { + SocketProtocol::Tcp => match request.domain { + SocketDomain::IpV4 => tokio::net::TcpSocket::new_v4(), + SocketDomain::IpV6 => tokio::net::TcpSocket::new_v6(), + } + .map(|s| unsafe { OwnedFd::from_raw_fd(s.into_raw_fd()) }), + SocketProtocol::Udp => match request.domain { + SocketDomain::IpV4 => tokio::net::UdpSocket::bind("0.0.0.0:0").await, + SocketDomain::IpV6 => tokio::net::UdpSocket::bind("[::]:0").await, + } + .map(|s| s.into_std().unwrap().into()), + }; + match fd { + Err(err) => log::warn!("Failed to allocate socket: {err}"), + Ok(fd) => owned_fd_buf.push(fd), + }; + } + + socket.writable().await?; + + let raw_fd_buf: Vec = owned_fd_buf.iter().map(|fd| fd.as_raw_fd()).collect(); + let cmsg = ControlMessage::ScmRights(&raw_fd_buf[..]); + let iov = [IoSlice::new(&buf[..])]; + + sendmsg::<()>(socket.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None)?; + } +} diff --git a/src/socks.rs b/src/socks.rs index 90a6bcd..a7c0b0b 100644 --- a/src/socks.rs +++ b/src/socks.rs @@ -20,6 +20,7 @@ enum SocksState { } struct SocksProxyImpl { + server_addr: SocketAddr, info: SessionInfo, domain_name: Option, state: SocksState, @@ -35,6 +36,7 @@ struct SocksProxyImpl { impl SocksProxyImpl { fn new( + server_addr: SocketAddr, info: SessionInfo, domain_name: Option, credentials: Option, @@ -42,6 +44,7 @@ impl SocksProxyImpl { command: protocol::Command, ) -> Result { let mut result = Self { + server_addr, info, domain_name, state: SocksState::ClientHello, @@ -260,6 +263,10 @@ impl SocksProxyImpl { #[async_trait::async_trait] impl ProxyHandler for SocksProxyImpl { + fn get_server_addr(&self) -> SocketAddr { + self.server_addr + } + fn get_session_info(&self) -> SessionInfo { self.info } @@ -339,6 +346,7 @@ impl ProxyHandlerManager for SocksProxyManager { let command = if udp_associate { UdpAssociate } else { Connect }; let credentials = self.credentials.clone(); Ok(Arc::new(Mutex::new(SocksProxyImpl::new( + self.server, info, domain_name, credentials,