add support for unprivileged namespaces

This commit is contained in:
Remy D. Farley 2024-04-03 14:26:46 +00:00 committed by B. Blechschmidt
parent 5e99c9f874
commit d351b5031c
11 changed files with 615 additions and 48 deletions

View file

@ -37,6 +37,11 @@ udp-stream = { version = "0.0", default-features = false }
unicase = "2.7"
url = "2.5"
[target.'cfg(target_os="linux")'.dependencies]
serde = { version = "1", features = [ "derive" ] }
bincode = "1"
nix = { version = "0", default-features = false, features = ["fs", "socket", "uio"] }
[target.'cfg(target_os="android")'.dependencies]
android_logger = "0.13"
jni = { version = "0.21", default-features = false }

View file

@ -1,6 +1,9 @@
use crate::{Error, Result};
use socks5_impl::protocol::UserKey;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::{
ffi::OsString,
net::{IpAddr, SocketAddr, ToSocketAddrs},
};
#[derive(Debug, Clone, clap::Parser)]
#[command(author, version, about = "Tunnel interface to proxy.", long_about = None)]
@ -20,13 +23,29 @@ pub struct Args {
#[arg(long, value_name = "fd", conflicts_with = "tun")]
pub tun_fd: Option<i32>,
/// Create a tun interface in a newly created unprivileged namespace
/// while maintaining proxy connectivity via the global network namespace.
#[arg(long)]
pub unshare: bool,
/// File descriptor for UNIX datagram socket meant to transfer
/// network sockets from global namespace to the new one.
/// See `unshare(1)`, `namespaces(7)`, `sendmsg(2)`, `unix(7)`.
#[arg(long)]
pub socket_transfer_fd: Option<i32>,
/// Specify a command to run with root-like capabilities in the new namespace.
/// This could be useful to start additional daemons, e.g. `openvpn` instance.
#[arg(requires = "unshare")]
pub admin_command: Vec<OsString>,
/// IPv6 enabled
#[arg(short = '6', long)]
pub ipv6_enabled: bool,
#[arg(short, long)]
/// Routing and system setup, which decides whether to setup the routing and system configuration.
/// This option is only available on Linux and requires root privileges.
/// This option is only available on Linux and requires root-like privileges. See `capabilities(7)`.
pub setup: bool,
/// DNS handling strategy
@ -60,6 +79,9 @@ impl Default for Args {
proxy: ArgProxy::default(),
tun: None,
tun_fd: None,
unshare: false,
socket_transfer_fd: None,
admin_command: Vec::new(),
ipv6_enabled: false,
setup: false,
dns: ArgDns::default(),

View file

@ -13,10 +13,19 @@ async fn main() -> Result<(), BoxError> {
let join_handle = tokio::spawn({
let shutdown_token = shutdown_token.clone();
async move {
if args.unshare && args.socket_transfer_fd.is_none() {
#[cfg(target_os = "linux")]
if let Err(err) = namespace_proxy_main(args, shutdown_token).await {
log::error!("namespace proxy error: {}", err);
}
#[cfg(not(target_os = "linux"))]
log::error!("Your platform doesn't support unprivileged namespaces");
} else {
if let Err(err) = tun2proxy::desktop_run_async(args, shutdown_token).await {
log::error!("main loop error: {}", err);
}
}
}
});
ctrlc2::set_async_handler(async move {
@ -31,3 +40,48 @@ async fn main() -> Result<(), BoxError> {
Ok(())
}
#[cfg(target_os = "linux")]
async fn namespace_proxy_main(
_args: Args,
_shutdown_token: tokio_util::sync::CancellationToken,
) -> Result<std::process::ExitStatus, tun2proxy::Error> {
use std::os::fd::AsRawFd;
let (socket, remote_fd) = tun2proxy::socket_transfer::create_transfer_socket_pair().await?;
let child = tokio::process::Command::new("unshare")
.args("--user --map-current-user --net --mount --keep-caps --kill-child --fork".split(' '))
.arg(std::env::current_exe()?)
.arg("--socket-transfer-fd")
.arg(remote_fd.as_raw_fd().to_string())
.args(std::env::args().skip(1))
.kill_on_drop(true)
.spawn();
let mut child = match child {
Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
log::error!("`unshare(1)` executable wasn't located in PATH.");
log::error!("Consider installing linux utils package: `apt install util-linux`");
log::error!("Or similar for your distribution.");
return Err(err.into());
}
child => child?,
};
log::info!("The tun proxy is running in unprivileged mode. See `namespaces(7)`.");
log::info!("");
log::info!("If you need to run a process that relies on root-like capabilities (e.g. `openvpn`)");
log::info!("Use `tun2proxy --unshare --setup [...] -- openvpn --config [...]`");
log::info!("");
log::info!("To run a new process in the created namespace (e.g. a flatpak app)");
log::info!(
"Use `nsenter --preserve-credentials --user --net --mount --target {} /bin/sh`",
child.id().unwrap_or(0)
);
log::info!("");
tokio::spawn(async move { tun2proxy::socket_transfer::process_socket_requests(&socket).await });
Ok(child.wait().await?)
}

View file

@ -131,11 +131,69 @@ pub async fn desktop_run_async(args: Args, shutdown_token: tokio_util::sync::Can
restore = Some(tproxy_config::tproxy_setup(&tproxy_args)?);
}
#[cfg(target_os = "linux")]
{
let run_ip_util = |args: String| {
tokio::process::Command::new("ip")
.args(args.split(' '))
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.spawn()
.ok();
};
if setup && !args.ipv6_enabled {
// Remove ipv6 connectivity if not explicitly required
// TODO: remove this when upstream will get updated
run_ip_util(format!("-6 route delete ::/1 dev {}", tproxy_args.tun_name));
run_ip_util(format!("-6 route delete 80::/1 dev {}", tproxy_args.tun_name));
}
if setup && args.unshare {
// New namespace doesn't have any other routing device by default
// So our `tun` device should act as such to make space for other proxies.
run_ip_util(format!("route delete 0.0.0.0/1 dev {}", tproxy_args.tun_name));
run_ip_util(format!("route delete 128.0.0.0/1 dev {}", tproxy_args.tun_name));
run_ip_util(format!("route add 0.0.0.0/0 dev {}", tproxy_args.tun_name));
if args.ipv6_enabled {
run_ip_util(format!("-6 route delete ::/1 dev {}", tproxy_args.tun_name));
run_ip_util(format!("-6 route delete 80::/1 dev {}", tproxy_args.tun_name));
run_ip_util(format!("-6 route add ::/0 dev {}", tproxy_args.tun_name));
}
}
}
let mut admin_command_args = args.admin_command.iter();
if let Some(command) = admin_command_args.next() {
let child = tokio::process::Command::new(command)
.args(admin_command_args)
.kill_on_drop(true)
.spawn();
match child {
Err(err) => {
log::warn!("Failed to start admin process: {err}");
}
Ok(mut child) => {
tokio::spawn(async move {
if let Err(err) = child.wait().await {
log::warn!("Admin process terminated: {err}");
}
});
}
};
}
let join_handle = tokio::spawn(crate::run(device, MTU, args, shutdown_token));
join_handle.await.map_err(std::io::Error::from)??;
#[cfg(any(target_os = "linux", target_os = "windows", target_os = "macos"))]
if setup {
// TODO: This probably should be handled by a destructor
// since otherwise removal is not guaranteed if anything above returns early.
tproxy_config::tproxy_remove(restore)?;
}

View file

@ -6,6 +6,10 @@ pub enum Error {
#[error(transparent)]
Io(#[from] std::io::Error),
#[cfg(target_os = "linux")]
#[error("nix::errno::Errno {0:?}")]
NixErrno(#[from] nix::errno::Errno),
#[error("TryFromIntError {0:?}")]
TryFromInt(#[from] std::num::TryFromIntError),
@ -39,6 +43,10 @@ pub enum Error {
#[error("std::num::ParseIntError {0:?}")]
IntParseError(#[from] std::num::ParseIntError),
#[cfg(target_os = "linux")]
#[error("bincode::Error {0:?}")]
BincodeError(#[from] bincode::Error),
}
impl From<&str> for Error {

View file

@ -38,6 +38,7 @@ enum HttpState {
pub(crate) type DigestState = digest_auth::WwwAuthenticateHeader;
pub struct HttpConnection {
server_addr: SocketAddr,
state: HttpState,
client_inbuf: VecDeque<u8>,
server_inbuf: VecDeque<u8>,
@ -61,12 +62,14 @@ static CONTENT_LENGTH: &str = "Content-Length";
impl HttpConnection {
async fn new(
server_addr: SocketAddr,
info: SessionInfo,
domain_name: Option<String>,
credentials: Option<UserKey>,
digest_state: Arc<Mutex<Option<DigestState>>>,
) -> Result<Self> {
let mut res = Self {
server_addr,
state: HttpState::ExpectResponseHeaders,
client_inbuf: VecDeque::default(),
server_inbuf: VecDeque::default(),
@ -330,6 +333,10 @@ impl HttpConnection {
#[async_trait::async_trait]
impl ProxyHandler for HttpConnection {
fn get_server_addr(&self) -> SocketAddr {
self.server_addr
}
fn get_session_info(&self) -> SessionInfo {
self.info
}
@ -413,7 +420,7 @@ impl ProxyHandlerManager for HttpManager {
return Err(Error::from("Protocol not supported by HTTP proxy").into());
}
Ok(Arc::new(Mutex::new(
HttpConnection::new(info, domain_name, self.credentials.clone(), self.digest_state.clone()).await?,
HttpConnection::new(self.server, info, domain_name, self.credentials.clone(), self.digest_state.clone()).await?,
)))
}

View file

@ -9,11 +9,16 @@ use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream};
use proxy_handler::{ProxyHandler, ProxyHandlerManager};
use socks::SocksProxyManager;
pub use socks5_impl::protocol::UserKey;
use std::{collections::VecDeque, net::SocketAddr, sync::Arc};
use std::{
collections::VecDeque,
io::ErrorKind,
net::{IpAddr, SocketAddr},
sync::Arc,
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
sync::Mutex,
net::{TcpSocket, TcpStream, UdpSocket},
sync::{mpsc::Receiver, Mutex},
};
pub use tokio_util::sync::CancellationToken;
use tproxy_config::is_private_ip;
@ -46,6 +51,7 @@ mod mobile_api;
mod no_proxy;
mod proxy_handler;
mod session_info;
pub mod socket_transfer;
mod socks;
mod virtual_dns;
@ -56,6 +62,81 @@ const MAX_SESSIONS: u64 = 200;
static TASK_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
use std::sync::atomic::Ordering::Relaxed;
#[allow(unused)]
#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)]
#[cfg_attr(target_os = "linux", derive(serde::Serialize, serde::Deserialize))]
pub enum SocketProtocol {
Tcp,
Udp,
}
#[allow(unused)]
#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)]
#[cfg_attr(target_os = "linux", derive(serde::Serialize, serde::Deserialize))]
pub enum SocketDomain {
IpV4,
IpV6,
}
impl From<IpAddr> for SocketDomain {
fn from(value: IpAddr) -> Self {
match value {
IpAddr::V4(_) => Self::IpV4,
IpAddr::V6(_) => Self::IpV6,
}
}
}
struct SocketQueue {
tcp_v4: Mutex<Receiver<TcpSocket>>,
tcp_v6: Mutex<Receiver<TcpSocket>>,
udp_v4: Mutex<Receiver<UdpSocket>>,
udp_v6: Mutex<Receiver<UdpSocket>>,
}
impl SocketQueue {
async fn recv_tcp(&self, domain: SocketDomain) -> Result<TcpSocket, std::io::Error> {
match domain {
SocketDomain::IpV4 => &self.tcp_v4,
SocketDomain::IpV6 => &self.tcp_v6,
}
.lock()
.await
.recv()
.await
.ok_or(ErrorKind::Other.into())
}
async fn recv_udp(&self, domain: SocketDomain) -> Result<UdpSocket, std::io::Error> {
match domain {
SocketDomain::IpV4 => &self.udp_v4,
SocketDomain::IpV6 => &self.udp_v6,
}
.lock()
.await
.recv()
.await
.ok_or(ErrorKind::Other.into())
}
}
async fn create_tcp_stream(socket_queue: &Option<Arc<SocketQueue>>, peer: SocketAddr) -> std::io::Result<TcpStream> {
match &socket_queue {
None => TcpStream::connect(peer).await,
Some(queue) => queue.recv_tcp(peer.ip().into()).await?.connect(peer).await,
}
}
async fn create_udp_stream(socket_queue: &Option<Arc<SocketQueue>>, peer: SocketAddr) -> std::io::Result<UdpStream> {
match &socket_queue {
None => UdpStream::connect(peer).await,
Some(queue) => {
let socket = queue.recv_udp(peer.ip().into()).await?;
socket.connect(peer).await?;
UdpStream::from_tokio(socket).await
}
}
}
/// Run the proxy server
/// # Arguments
/// * `device` - The network device to use
@ -78,6 +159,56 @@ where
None
};
#[cfg(target_os = "linux")]
let socket_queue = match args.socket_transfer_fd {
None => None,
Some(fd) => {
use crate::socket_transfer::{reconstruct_socket, reconstruct_transfer_socket, request_sockets};
use tokio::sync::mpsc::channel;
let fd = reconstruct_socket(fd)?;
let socket = reconstruct_transfer_socket(fd)?;
let socket = Arc::new(Mutex::new(socket));
macro_rules! create_socket_queue {
($domain:ident) => {{
const SOCKETS_PER_REQUEST: usize = 64;
let socket = socket.clone();
let (tx, rx) = channel(SOCKETS_PER_REQUEST);
tokio::spawn(async move {
loop {
let sockets =
match request_sockets(socket.lock().await, SocketDomain::$domain, SOCKETS_PER_REQUEST as u32).await {
Ok(sockets) => sockets,
Err(err) => {
log::warn!("Socket allocation request failed: {err}");
continue;
}
};
for s in sockets {
if let Err(_) = tx.send(s).await {
return;
}
}
}
});
Mutex::new(rx)
}};
}
Some(Arc::new(SocketQueue {
tcp_v4: create_socket_queue!(IpV4),
tcp_v6: create_socket_queue!(IpV6),
udp_v4: create_socket_queue!(IpV4),
udp_v6: create_socket_queue!(IpV6),
}))
}
};
#[cfg(not(target_os = "linux"))]
let socket_queue = None;
use socks5_impl::protocol::Version::{V4, V5};
let mgr = match args.proxy.proxy_type {
ProxyType::Socks5 => Arc::new(SocksProxyManager::new(server_addr, V5, key)) as Arc<dyn ProxyHandlerManager>,
@ -120,8 +251,9 @@ where
None
};
let proxy_handler = mgr.new_proxy_handler(info, domain_name, false).await?;
let socket_queue = socket_queue.clone();
tokio::spawn(async move {
if let Err(err) = handle_tcp_session(tcp, server_addr, proxy_handler).await {
if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await {
log::error!("{} error \"{}\"", info, err);
}
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
@ -140,8 +272,9 @@ where
}
if args.dns == ArgDns::OverTcp {
let proxy_handler = mgr.new_proxy_handler(info, None, false).await?;
let socket_queue = socket_queue.clone();
tokio::spawn(async move {
if let Err(err) = handle_dns_over_tcp_session(udp, server_addr, proxy_handler, ipv6_enabled).await {
if let Err(err) = handle_dns_over_tcp_session(udp, proxy_handler, socket_queue, ipv6_enabled).await {
log::error!("{} error \"{}\"", info, err);
}
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
@ -170,8 +303,11 @@ where
};
match mgr.new_proxy_handler(info, domain_name, true).await {
Ok(proxy_handler) => {
let socket_queue = socket_queue.clone();
tokio::spawn(async move {
if let Err(err) = handle_udp_associate_session(udp, server_addr, proxy_handler, ipv6_enabled).await {
if let Err(err) =
handle_udp_associate_session(udp, args.proxy.proxy_type, proxy_handler, socket_queue, ipv6_enabled).await
{
log::info!("Ending {} with \"{}\"", info, err);
}
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
@ -207,12 +343,17 @@ async fn handle_virtual_dns_session(mut udp: IpStackUdpStream, dns: Arc<Mutex<Vi
async fn handle_tcp_session(
mut tcp_stack: IpStackTcpStream,
server_addr: SocketAddr,
proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
socket_queue: Option<Arc<SocketQueue>>,
) -> crate::Result<()> {
let mut server = TcpStream::connect(server_addr).await?;
let (session_info, server_addr) = {
let handler = proxy_handler.lock().await;
(handler.get_session_info(), handler.get_server_addr())
};
let mut server = create_tcp_stream(&socket_queue, server_addr).await?;
let session_info = proxy_handler.lock().await.get_session_info();
log::info!("Beginning {}", session_info);
if let Err(e) = handle_proxy_session(&mut server, proxy_handler).await {
@ -246,20 +387,36 @@ async fn handle_tcp_session(
async fn handle_udp_associate_session(
mut udp_stack: IpStackUdpStream,
server_addr: SocketAddr,
proxy_type: ProxyType,
proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
socket_queue: Option<Arc<SocketQueue>>,
ipv6_enabled: bool,
) -> crate::Result<()> {
use socks5_impl::protocol::{Address, StreamOperation, UdpHeader};
let mut server = TcpStream::connect(server_addr).await?;
let session_info = proxy_handler.lock().await.get_session_info();
let domain_name = proxy_handler.lock().await.get_domain_name();
let (session_info, server_addr, domain_name, udp_addr) = {
let handler = proxy_handler.lock().await;
(
handler.get_session_info(),
handler.get_server_addr(),
handler.get_domain_name(),
handler.get_udp_associate(),
)
};
log::info!("Beginning {}", session_info);
let udp_addr = handle_proxy_session(&mut server, proxy_handler).await?;
let udp_addr = udp_addr.ok_or("udp associate failed")?;
let udp_addr = match udp_addr {
Some(udp_addr) => udp_addr,
None => {
let mut server = create_tcp_stream(&socket_queue, server_addr).await?;
let mut udp_server = UdpStream::connect(udp_addr).await?;
let udp_addr = handle_proxy_session(&mut server, proxy_handler).await?;
udp_addr.ok_or("udp associate failed")?
}
};
let mut udp_server = create_udp_stream(&socket_queue, udp_addr).await?;
let mut buf1 = [0_u8; 4096];
let mut buf2 = [0_u8; 4096];
@ -272,6 +429,7 @@ async fn handle_udp_associate_session(
}
let buf1 = &buf1[..len];
if let ProxyType::Socks4 | ProxyType::Socks5 = proxy_type {
let s5addr = if let Some(domain_name) = &domain_name {
Address::DomainAddress(domain_name.clone(), session_info.dst.port())
} else {
@ -284,6 +442,9 @@ async fn handle_udp_associate_session(
s5_udp_data.extend_from_slice(buf1);
udp_server.write_all(&s5_udp_data).await?;
} else {
udp_server.write_all(buf1).await?;
}
}
len = udp_server.read(&mut buf2) => {
let len = len?;
@ -292,6 +453,7 @@ async fn handle_udp_associate_session(
}
let buf2 = &buf2[..len];
if let ProxyType::Socks4 | ProxyType::Socks5 = proxy_type {
// Remove SOCKS5 UDP header from the server data
let header = UdpHeader::retrieve_from_stream(&mut &buf2[..])?;
let data = &buf2[header.len()..];
@ -307,6 +469,9 @@ async fn handle_udp_associate_session(
};
udp_stack.write_all(&buf).await?;
} else {
udp_stack.write_all(buf2).await?;
}
}
}
}
@ -318,13 +483,18 @@ async fn handle_udp_associate_session(
async fn handle_dns_over_tcp_session(
mut udp_stack: IpStackUdpStream,
server_addr: SocketAddr,
proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
socket_queue: Option<Arc<SocketQueue>>,
ipv6_enabled: bool,
) -> crate::Result<()> {
let mut server = TcpStream::connect(server_addr).await?;
let (session_info, server_addr) = {
let handler = proxy_handler.lock().await;
(handler.get_session_info(), handler.get_server_addr())
};
let mut server = create_tcp_stream(&socket_queue, server_addr).await?;
let session_info = proxy_handler.lock().await.get_session_info();
log::info!("Beginning {}", session_info);
let _ = handle_proxy_session(&mut server, proxy_handler).await?;

View file

@ -16,6 +16,10 @@ struct NoProxyHandler {
#[async_trait::async_trait]
impl ProxyHandler for NoProxyHandler {
fn get_server_addr(&self) -> SocketAddr {
self.info.dst
}
fn get_session_info(&self) -> SessionInfo {
self.info
}

View file

@ -7,6 +7,7 @@ use tokio::sync::Mutex;
#[async_trait::async_trait]
pub(crate) trait ProxyHandler: Send + Sync {
fn get_server_addr(&self) -> SocketAddr;
fn get_session_info(&self) -> SessionInfo;
fn get_domain_name(&self) -> Option<String>;
async fn push_data(&mut self, event: IncomingDataEvent<'_>) -> std::io::Result<()>;

230
src/socket_transfer.rs Normal file
View file

@ -0,0 +1,230 @@
#![cfg(target_os = "linux")]
use crate::{error, SocketDomain, SocketProtocol};
use nix::{
errno::Errno,
fcntl::{self, FdFlag},
sys::socket::{cmsg_space, getsockopt, recvmsg, sendmsg, sockopt, ControlMessage, ControlMessageOwned, MsgFlags, SockType},
};
use serde::{Deserialize, Serialize};
use std::{
io::{ErrorKind, IoSlice, IoSliceMut, Result},
ops::DerefMut,
os::fd::{AsFd, AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd},
};
use tokio::net::{TcpSocket, UdpSocket, UnixDatagram};
const REQUEST_BUFFER_SIZE: usize = 64;
#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)]
struct Request {
protocol: SocketProtocol,
domain: SocketDomain,
number: u32,
}
#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)]
enum Response {
Ok,
}
/// Reconstruct socket from raw `fd`
pub fn reconstruct_socket(fd: RawFd) -> Result<OwnedFd> {
// Check if `fd` is valid
let fd_flags = fcntl::fcntl(fd, fcntl::F_GETFD)?;
// `fd` is confirmed to be valid so it should be closed
let socket = unsafe { OwnedFd::from_raw_fd(fd) };
// Insert CLOEXEC flag to the `fd` to prevent further propagation across `execve(2)` calls
let mut fd_flags = FdFlag::from_bits(fd_flags).ok_or(ErrorKind::Unsupported)?;
if !fd_flags.contains(FdFlag::FD_CLOEXEC) {
fd_flags.insert(FdFlag::FD_CLOEXEC);
fcntl::fcntl(fd, fcntl::F_SETFD(fd_flags))?;
}
Ok(socket)
}
/// Reconstruct transfer socket from `fd`
///
/// Panics if called outside of tokio runtime
pub fn reconstruct_transfer_socket(fd: OwnedFd) -> Result<UnixDatagram> {
// Check if socket of type DATAGRAM
let sock_type = getsockopt(&fd, sockopt::SockType)?;
if !matches!(sock_type, SockType::Datagram) {
return Err(ErrorKind::InvalidInput.into());
}
let std_socket: std::os::unix::net::UnixDatagram = fd.into();
std_socket.set_nonblocking(true)?;
// Fails if tokio context is absent
Ok(UnixDatagram::from_std(std_socket).unwrap())
}
/// Create pair of interconnected sockets one of which is set to stay open across `execve(2)` calls.
pub async fn create_transfer_socket_pair() -> std::io::Result<(UnixDatagram, OwnedFd)> {
let (local, remote) = tokio::net::UnixDatagram::pair()?;
let remote_fd: OwnedFd = remote.into_std().unwrap().into();
// Get `remote_fd` flags
let fd_flags = fcntl::fcntl(remote_fd.as_raw_fd(), fcntl::F_GETFD)?;
// Remove CLOEXEC flag from the `remote_fd` to allow propagating across `execve(2)`
let mut fd_flags = FdFlag::from_bits(fd_flags).ok_or(ErrorKind::Unsupported)?;
fd_flags.remove(FdFlag::FD_CLOEXEC);
fcntl::fcntl(remote_fd.as_raw_fd(), fcntl::F_SETFD(fd_flags))?;
Ok((local, remote_fd))
}
pub trait TransferableSocket: Sized {
fn from_fd(fd: OwnedFd) -> Result<Self>;
fn domain() -> SocketProtocol;
}
impl TransferableSocket for TcpSocket {
fn from_fd(fd: OwnedFd) -> Result<Self> {
// Check if socket is of type STREAM
let sock_type = getsockopt(&fd, sockopt::SockType)?;
if !matches!(sock_type, SockType::Stream) {
return Err(ErrorKind::InvalidInput.into());
}
let std_stream: std::net::TcpStream = fd.into();
std_stream.set_nonblocking(true)?;
Ok(TcpSocket::from_std_stream(std_stream))
}
fn domain() -> SocketProtocol {
SocketProtocol::Tcp
}
}
impl TransferableSocket for UdpSocket {
/// Panics if called outside of tokio runtime
fn from_fd(fd: OwnedFd) -> Result<Self> {
// Check if socket is of type DATAGRAM
let sock_type = getsockopt(&fd, sockopt::SockType)?;
if !matches!(sock_type, SockType::Datagram) {
return Err(ErrorKind::InvalidInput.into());
}
let std_socket: std::net::UdpSocket = fd.into();
std_socket.set_nonblocking(true)?;
Ok(UdpSocket::try_from(std_socket).unwrap())
}
fn domain() -> SocketProtocol {
SocketProtocol::Udp
}
}
/// Send [`Request`] to `socket` and return received [`TransferableSocket`]s
///
/// Panics if called outside of tokio runtime
pub async fn request_sockets<S, T>(mut socket: S, domain: SocketDomain, number: u32) -> error::Result<Vec<T>>
where
S: DerefMut<Target = UnixDatagram>,
T: TransferableSocket,
{
// Borrow socket as mut to prevent multiple simultaneous requests
let socket = socket.deref_mut();
// Send request
let request = bincode::serialize(&Request {
protocol: T::domain(),
domain,
number,
})?;
socket.send(&request[..]).await?;
// Receive response
loop {
socket.readable().await?;
let mut buf = [0_u8; REQUEST_BUFFER_SIZE];
let mut iov = [IoSliceMut::new(&mut buf[..])];
let mut cmsg = Vec::with_capacity(cmsg_space::<RawFd>() * number as usize);
let msg = recvmsg::<()>(socket.as_fd().as_raw_fd(), &mut iov, Some(&mut cmsg), MsgFlags::empty());
let msg = match msg {
Err(Errno::EAGAIN) => continue,
msg => msg?,
};
// Parse response
let response = &msg.iovs().next().unwrap()[..msg.bytes];
let response: Response = bincode::deserialize(response)?;
if !matches!(response, Response::Ok) {
return Err("Request for new sockets failed".into());
}
// Process received file descriptors
let mut sockets = Vec::<T>::with_capacity(number as usize);
for cmsg in msg.cmsgs() {
if let ControlMessageOwned::ScmRights(fds) = cmsg {
for fd in fds {
if fd < 0 {
return Err("Received socket is invalid".into());
}
let owned_fd = reconstruct_socket(fd)?;
sockets.push(T::from_fd(owned_fd)?);
}
}
}
return Ok(sockets);
}
}
/// Process [`Request`]s received from `socket`
///
/// Panics if called outside of tokio runtime
pub async fn process_socket_requests(socket: &UnixDatagram) -> error::Result<()> {
loop {
let mut buf = [0_u8; REQUEST_BUFFER_SIZE];
let len = socket.recv(&mut buf[..]).await?;
let request: Request = bincode::deserialize(&buf[..len])?;
let response = Response::Ok;
let buf = bincode::serialize(&response)?;
let mut owned_fd_buf: Vec<OwnedFd> = Vec::with_capacity(request.number as usize);
for _ in 0..request.number {
let fd = match request.protocol {
SocketProtocol::Tcp => match request.domain {
SocketDomain::IpV4 => tokio::net::TcpSocket::new_v4(),
SocketDomain::IpV6 => tokio::net::TcpSocket::new_v6(),
}
.map(|s| unsafe { OwnedFd::from_raw_fd(s.into_raw_fd()) }),
SocketProtocol::Udp => match request.domain {
SocketDomain::IpV4 => tokio::net::UdpSocket::bind("0.0.0.0:0").await,
SocketDomain::IpV6 => tokio::net::UdpSocket::bind("[::]:0").await,
}
.map(|s| s.into_std().unwrap().into()),
};
match fd {
Err(err) => log::warn!("Failed to allocate socket: {err}"),
Ok(fd) => owned_fd_buf.push(fd),
};
}
socket.writable().await?;
let raw_fd_buf: Vec<RawFd> = owned_fd_buf.iter().map(|fd| fd.as_raw_fd()).collect();
let cmsg = ControlMessage::ScmRights(&raw_fd_buf[..]);
let iov = [IoSlice::new(&buf[..])];
sendmsg::<()>(socket.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None)?;
}
}

View file

@ -20,6 +20,7 @@ enum SocksState {
}
struct SocksProxyImpl {
server_addr: SocketAddr,
info: SessionInfo,
domain_name: Option<String>,
state: SocksState,
@ -35,6 +36,7 @@ struct SocksProxyImpl {
impl SocksProxyImpl {
fn new(
server_addr: SocketAddr,
info: SessionInfo,
domain_name: Option<String>,
credentials: Option<UserKey>,
@ -42,6 +44,7 @@ impl SocksProxyImpl {
command: protocol::Command,
) -> Result<Self> {
let mut result = Self {
server_addr,
info,
domain_name,
state: SocksState::ClientHello,
@ -260,6 +263,10 @@ impl SocksProxyImpl {
#[async_trait::async_trait]
impl ProxyHandler for SocksProxyImpl {
fn get_server_addr(&self) -> SocketAddr {
self.server_addr
}
fn get_session_info(&self) -> SessionInfo {
self.info
}
@ -339,6 +346,7 @@ impl ProxyHandlerManager for SocksProxyManager {
let command = if udp_associate { UdpAssociate } else { Connect };
let credentials = self.credentials.clone();
Ok(Arc::new(Mutex::new(SocksProxyImpl::new(
self.server,
info,
domain_name,
credentials,