beginning async version (#84)

This commit is contained in:
ssrlive 2024-02-01 19:15:32 +08:00 committed by GitHub
parent 337619169e
commit 9c4fa4260a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 2022 additions and 3286 deletions

View file

@ -1,14 +1,16 @@
#![cfg(target_os = "android")]
use crate::{error::Error, tun2proxy::TunToProxy, tun_to_proxy, NetworkInterface, Options, Proxy};
use crate::{
args::{ArgDns, ArgProxy},
error::{Error, Result},
ArgVerbosity, Args,
};
use jni::{
objects::{JClass, JString},
sys::{jboolean, jint},
JNIEnv,
};
static mut TUN_TO_PROXY: Option<TunToProxy> = None;
/// # Safety
///
/// Running tun2proxy
@ -22,8 +24,9 @@ pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_run(
verbose: jboolean,
dns_over_tcp: jboolean,
) -> jint {
let log_level = if verbose != 0 { "trace" } else { "info" };
let filter_str = &format!("off,tun2proxy={log_level}");
let dns = if dns_over_tcp != 0 { ArgDns::OverTcp } else { ArgDns::Direct };
let verbosity = if verbose != 0 { ArgVerbosity::Trace } else { ArgVerbosity::Info };
let filter_str = &format!("off,tun2proxy={verbosity}");
let filter = android_logger::FilterBuilder::new().parse(filter_str).build();
android_logger::init_once(
android_logger::Config::default()
@ -31,31 +34,11 @@ pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_run(
.with_max_level(log::LevelFilter::Trace)
.with_filter(filter),
);
let proxy_url = get_java_string(&mut env, &proxy_url).unwrap();
let proxy = ArgProxy::from_url(proxy_url).unwrap();
let mut block = || -> Result<(), Error> {
let proxy_url = get_java_string(&mut env, &proxy_url)?;
let proxy = Proxy::from_url(proxy_url)?;
let addr = proxy.addr;
let proxy_type = proxy.proxy_type;
log::info!("Proxy {proxy_type} server: {addr}");
let dns_addr = "8.8.8.8".parse::<std::net::IpAddr>().unwrap();
let options = Options::new().with_dns_addr(Some(dns_addr)).with_mtu(tun_mtu as usize);
let options = if dns_over_tcp != 0 { options.with_dns_over_tcp() } else { options };
let interface = NetworkInterface::Fd(tun_fd);
let tun2proxy = tun_to_proxy(&interface, &proxy, options)?;
TUN_TO_PROXY = Some(tun2proxy);
if let Some(tun2proxy) = &mut TUN_TO_PROXY {
tun2proxy.run()?;
}
Ok::<(), Error>(())
};
if let Err(error) = block() {
log::error!("failed to run tun2proxy with error: {:?}", error);
}
0
let args = Args::new(Some(tun_fd), proxy, dns, verbosity);
crate::api::tun2proxy_internal_run(args, tun_mtu as _)
}
/// # Safety
@ -63,20 +46,7 @@ pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_run(
/// Shutdown tun2proxy
#[no_mangle]
pub unsafe extern "C" fn Java_com_github_shadowsocks_bg_Tun2proxy_stop(_env: JNIEnv, _: JClass) -> jint {
match &mut TUN_TO_PROXY {
None => {
log::error!("tun2proxy not started");
1
}
Some(tun2proxy) => {
if let Err(e) = tun2proxy.shutdown() {
log::error!("failed to shutdown tun2proxy with error: {:?}", e);
1
} else {
0
}
}
}
crate::api::tun2proxy_internal_stop()
}
unsafe fn get_java_string<'a>(env: &'a mut JNIEnv, string: &'a JString) -> Result<&'a str, Error> {

70
src/api.rs Normal file
View file

@ -0,0 +1,70 @@
#![cfg(any(target_os = "ios", target_os = "android"))]
use crate::{Args, Builder, Quit};
use std::{os::raw::c_int, sync::Arc};
static mut TUN_QUIT: Option<Arc<Quit>> = None;
pub(crate) fn tun2proxy_internal_run(args: Args, tun_mtu: usize) -> c_int {
if unsafe { TUN_QUIT.is_some() } {
log::error!("tun2proxy already started");
return -1;
}
let block = async move {
log::info!("Proxy {} server: {}", args.proxy.proxy_type, args.proxy.addr);
let mut config = tun2::Configuration::default();
config.raw_fd(args.tun_fd.ok_or(crate::Error::from("tun_fd"))?);
let device = tun2::create_as_async(&config).map_err(std::io::Error::from)?;
#[cfg(target_os = "android")]
let tun2proxy = Builder::new(device, args).mtu(tun_mtu).build();
#[cfg(target_os = "ios")]
let tun2proxy = Builder::new(device, args).mtu(tun_mtu).build();
let (join_handle, quit) = tun2proxy.start();
unsafe { TUN_QUIT = Some(Arc::new(quit)) };
join_handle.await
};
match tokio::runtime::Builder::new_multi_thread().enable_all().build() {
Err(_err) => {
log::error!("failed to create tokio runtime with error: {:?}", _err);
-1
}
Ok(rt) => match rt.block_on(block) {
Ok(_) => 0,
Err(_err) => {
log::error!("failed to run tun2proxy with error: {:?}", _err);
-2
}
},
}
}
pub(crate) fn tun2proxy_internal_stop() -> c_int {
let res = match unsafe { &TUN_QUIT } {
None => {
log::error!("tun2proxy not started");
-1
}
Some(tun_quit) => match tokio::runtime::Builder::new_multi_thread().enable_all().build() {
Err(_err) => {
log::error!("failed to create tokio runtime with error: {:?}", _err);
-2
}
Ok(rt) => match rt.block_on(async move { tun_quit.trigger().await }) {
Ok(_) => 0,
Err(_err) => {
log::error!("failed to stop tun2proxy with error: {:?}", _err);
-3
}
},
},
};
unsafe { TUN_QUIT = None };
res
}

198
src/args.rs Normal file
View file

@ -0,0 +1,198 @@
use crate::{Error, Result};
use socks5_impl::protocol::UserKey;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use tproxy_config::TUN_NAME;
#[derive(Debug, Clone, clap::Parser)]
#[command(author, version, about = "tun2proxy application.", long_about = None)]
pub struct Args {
/// Proxy URL in the form proto://[username[:password]@]host:port,
/// where proto is one of socks4, socks5, http. For example:
/// socks5://myname:password@127.0.0.1:1080
#[arg(short, long, value_parser = ArgProxy::from_url, value_name = "URL")]
pub proxy: ArgProxy,
/// Name of the tun interface
#[arg(short, long, value_name = "name", conflicts_with = "tun_fd", default_value = TUN_NAME)]
pub tun: String,
/// File descriptor of the tun interface
#[arg(long, value_name = "fd", conflicts_with = "tun")]
pub tun_fd: Option<i32>,
/// IPv6 enabled
#[arg(short = '6', long)]
pub ipv6_enabled: bool,
#[cfg(target_os = "linux")]
#[arg(short, long)]
/// Routing and system setup, which decides whether to setup the routing and system configuration,
/// this option requires root privileges
pub setup: bool,
/// DNS handling strategy
#[arg(short, long, value_name = "strategy", value_enum, default_value = "direct")]
pub dns: ArgDns,
/// DNS resolver address
#[arg(long, value_name = "IP", default_value = "8.8.8.8")]
pub dns_addr: IpAddr,
/// IPs used in routing setup which should bypass the tunnel
#[arg(short, long, value_name = "IP")]
pub bypass: Vec<IpAddr>,
/// Verbosity level
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
pub verbosity: ArgVerbosity,
}
impl Default for Args {
fn default() -> Self {
Args {
proxy: ArgProxy::default(),
tun: TUN_NAME.to_string(),
tun_fd: None,
ipv6_enabled: false,
#[cfg(target_os = "linux")]
setup: false,
dns: ArgDns::default(),
dns_addr: "8.8.8.8".parse().unwrap(),
bypass: vec![],
verbosity: ArgVerbosity::Info,
}
}
}
impl Args {
pub fn parse_args() -> Self {
use clap::Parser;
Self::parse()
}
pub fn new(tun_fd: Option<i32>, proxy: ArgProxy, dns: ArgDns, verbosity: ArgVerbosity) -> Self {
Args {
proxy,
tun_fd,
dns,
verbosity,
..Args::default()
}
}
}
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
pub enum ArgVerbosity {
Off,
Error,
Warn,
#[default]
Info,
Debug,
Trace,
}
impl std::fmt::Display for ArgVerbosity {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
ArgVerbosity::Off => write!(f, "off"),
ArgVerbosity::Error => write!(f, "error"),
ArgVerbosity::Warn => write!(f, "warn"),
ArgVerbosity::Info => write!(f, "info"),
ArgVerbosity::Debug => write!(f, "debug"),
ArgVerbosity::Trace => write!(f, "trace"),
}
}
}
/// DNS query handling strategy
/// - Virtual: Use a virtual DNS server to handle DNS queries, also known as Fake-IP mode
/// - OverTcp: Use TCP to send DNS queries to the DNS server
/// - Direct: Do not handle DNS by relying on DNS server bypassing
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
pub enum ArgDns {
Virtual,
OverTcp,
#[default]
Direct,
}
#[derive(Clone, Debug)]
pub struct ArgProxy {
pub proxy_type: ProxyType,
pub addr: SocketAddr,
pub credentials: Option<UserKey>,
}
impl Default for ArgProxy {
fn default() -> Self {
ArgProxy {
proxy_type: ProxyType::Socks5,
addr: "127.0.0.1:1080".parse().unwrap(),
credentials: None,
}
}
}
impl ArgProxy {
pub fn from_url(s: &str) -> Result<ArgProxy> {
let e = format!("`{s}` is not a valid proxy URL");
let url = url::Url::parse(s).map_err(|_| Error::from(&e))?;
let e = format!("`{s}` does not contain a host");
let host = url.host_str().ok_or(Error::from(e))?;
let mut url_host = String::from(host);
let e = format!("`{s}` does not contain a port");
let port = url.port().ok_or(Error::from(&e))?;
url_host.push(':');
url_host.push_str(port.to_string().as_str());
let e = format!("`{host}` could not be resolved");
let mut addr_iter = url_host.to_socket_addrs().map_err(|_| Error::from(&e))?;
let e = format!("`{host}` does not resolve to a usable IP address");
let addr = addr_iter.next().ok_or(Error::from(&e))?;
let credentials = if url.username() == "" && url.password().is_none() {
None
} else {
let username = String::from(url.username());
let password = String::from(url.password().unwrap_or(""));
Some(UserKey::new(username, password))
};
let scheme = url.scheme();
let proxy_type = match url.scheme().to_ascii_lowercase().as_str() {
"socks4" => Some(ProxyType::Socks4),
"socks5" => Some(ProxyType::Socks5),
"http" => Some(ProxyType::Http),
_ => None,
}
.ok_or(Error::from(&format!("`{scheme}` is an invalid proxy type")))?;
Ok(ArgProxy {
proxy_type,
addr,
credentials,
})
}
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Default)]
pub enum ProxyType {
Socks4,
#[default]
Socks5,
Http,
}
impl std::fmt::Display for ProxyType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProxyType::Socks4 => write!(f, "socks4"),
ProxyType::Socks5 => write!(f, "socks5"),
ProxyType::Http => write!(f, "http"),
}
}
}

83
src/bin/main.rs Normal file
View file

@ -0,0 +1,83 @@
use tproxy_config::{TproxyArgs, TUN_GATEWAY, TUN_IPV4, TUN_NETMASK};
use tun2::DEFAULT_MTU as MTU;
use tun2proxy::{Args, Builder};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
dotenvy::dotenv().ok();
let args = Args::parse_args();
let bypass_ips = args.bypass.clone();
// let default = format!("{}={:?}", module_path!(), args.verbosity);
let default = format!("{:?}", args.verbosity);
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
let mut config = tun2::Configuration::default();
config.address(TUN_IPV4).netmask(TUN_NETMASK).mtu(MTU).up();
config.destination(TUN_GATEWAY);
if let Some(tun_fd) = args.tun_fd {
config.raw_fd(tun_fd);
} else {
config.name(&args.tun);
}
#[cfg(target_os = "linux")]
config.platform_config(|config| {
#[allow(deprecated)]
config.packet_information(true);
config.ensure_root_privileges(args.setup);
});
#[cfg(target_os = "windows")]
config.platform_config(|config| {
config.device_guid(Some(12324323423423434234_u128));
});
#[allow(unused_variables)]
let mut tproxy_args = TproxyArgs::new()
.tun_dns(args.dns_addr)
.proxy_addr(args.proxy.addr)
.bypass_ips(&bypass_ips);
#[allow(unused_assignments)]
if args.tun_fd.is_none() {
tproxy_args = tproxy_args.tun_name(&args.tun);
}
#[allow(unused_mut, unused_assignments, unused_variables)]
let mut setup = true;
#[cfg(target_os = "linux")]
{
setup = args.setup;
if setup {
tproxy_config::tproxy_setup(&tproxy_args)?;
}
}
let device = tun2::create_as_async(&config)?;
#[cfg(any(target_os = "windows", target_os = "macos"))]
if setup {
tproxy_config::tproxy_setup(&tproxy_args)?;
}
let tun2proxy = Builder::new(device, args).mtu(MTU).build();
let (join_handle, quit) = tun2proxy.start();
ctrlc2::set_async_handler(async move {
quit.trigger().await.expect("quit error");
})
.await;
if let Err(err) = join_handle.await {
log::trace!("main_entry error {}", err);
}
#[cfg(any(target_os = "linux", target_os = "windows", target_os = "macos"))]
if setup {
tproxy_config::tproxy_remove(&tproxy_args)?;
}
Ok(())
}

28
src/directions.rs Normal file
View file

@ -0,0 +1,28 @@
#![allow(dead_code)]
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub(crate) enum IncomingDirection {
FromServer,
FromClient,
}
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub(crate) enum OutgoingDirection {
ToServer,
ToClient,
}
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub(crate) enum Direction {
Incoming(IncomingDirection),
Outgoing(OutgoingDirection),
}
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) struct DataEvent<'a, T> {
pub(crate) direction: T,
pub(crate) buffer: &'a [u8],
}
pub(crate) type IncomingDataEvent<'a> = DataEvent<'a, IncomingDirection>;
pub(crate) type OutgoingDataEvent<'a> = DataEvent<'a, OutgoingDirection>;

View file

@ -1,39 +1,10 @@
#![allow(dead_code)]
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
str::FromStr,
};
use std::{net::IpAddr, str::FromStr};
use trust_dns_proto::op::MessageType;
use trust_dns_proto::{
op::{Message, ResponseCode},
rr::{record_type::RecordType, Name, RData, Record},
};
#[cfg(feature = "use-rand")]
pub fn build_dns_request(domain: &str, query_type: RecordType, used_by_tcp: bool) -> Result<Vec<u8>, String> {
// [dependencies]
// rand = "0.8"
use rand::{rngs::StdRng, Rng, SeedableRng};
use trust_dns_proto::op::{header::MessageType, op_code::OpCode, query::Query};
let name = Name::from_str(domain).map_err(|e| e.to_string())?;
let query = Query::query(name, query_type);
let mut msg = Message::new();
msg.add_query(query)
.set_id(StdRng::from_entropy().gen())
.set_op_code(OpCode::Query)
.set_message_type(MessageType::Query)
.set_recursion_desired(true);
let mut msg_buf = msg.to_vec().map_err(|e| e.to_string())?;
if used_by_tcp {
let mut buf = (msg_buf.len() as u16).to_be_bytes().to_vec();
buf.append(&mut msg_buf);
Ok(buf)
} else {
Ok(msg_buf)
}
}
pub fn build_dns_response(mut request: Message, domain: &str, ip: IpAddr, ttl: u32) -> Result<Message, String> {
let record = match ip {
IpAddr::V4(ip) => {
@ -105,17 +76,3 @@ pub fn parse_data_to_dns_message(data: &[u8], used_by_tcp: bool) -> Result<Messa
let message = Message::from_vec(data).map_err(|e| e.to_string())?;
Ok(message)
}
// FIXME: use IpAddr::is_global() instead when it's stable
pub fn addr_is_private(addr: &SocketAddr) -> bool {
fn is_benchmarking(addr: &Ipv4Addr) -> bool {
addr.octets()[0] == 198 && (addr.octets()[1] & 0xfe) == 18
}
fn addr_v4_is_private(addr: &Ipv4Addr) -> bool {
is_benchmarking(addr) || addr.is_private() || addr.is_loopback() || addr.is_link_local()
}
match addr {
SocketAddr::V4(addr) => addr_v4_is_private(addr.ip()),
SocketAddr::V6(_) => false,
}
}

71
src/dump_logger.rs Normal file
View file

@ -0,0 +1,71 @@
use std::{
os::raw::{c_char, c_int, c_void},
sync::Mutex,
};
pub(crate) static DUMP_CALLBACK: Mutex<Option<DumpCallback>> = Mutex::new(None);
/// # Safety
///
/// set dump log info callback.
#[no_mangle]
pub unsafe extern "C" fn tun2proxy_set_log_callback(
callback: Option<unsafe extern "C" fn(c_int, *const c_char, *mut c_void)>,
ctx: *mut c_void,
) {
*DUMP_CALLBACK.lock().unwrap() = Some(DumpCallback(callback, ctx));
}
#[derive(Clone)]
pub struct DumpCallback(Option<unsafe extern "C" fn(c_int, *const c_char, *mut c_void)>, *mut c_void);
impl DumpCallback {
unsafe fn call(self, dump_level: c_int, info: *const c_char) {
if let Some(cb) = self.0 {
cb(dump_level, info, self.1);
}
}
}
unsafe impl Send for DumpCallback {}
unsafe impl Sync for DumpCallback {}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct DumpLogger {}
impl log::Log for DumpLogger {
fn enabled(&self, metadata: &log::Metadata) -> bool {
metadata.level() <= log::Level::Trace
}
fn log(&self, record: &log::Record) {
if self.enabled(record.metadata()) {
let current_crate_name = env!("CARGO_CRATE_NAME");
if record.module_path().unwrap_or("").starts_with(current_crate_name) {
self.do_dump_log(record);
}
}
}
fn flush(&self) {}
}
impl DumpLogger {
fn do_dump_log(&self, record: &log::Record) {
let timestamp: chrono::DateTime<chrono::Local> = chrono::Local::now();
let msg = format!(
"[{} {:<5} {}] - {}",
timestamp.format("%Y-%m-%d %H:%M:%S"),
record.level(),
record.module_path().unwrap_or(""),
record.args()
);
let c_msg = std::ffi::CString::new(msg).unwrap();
let ptr = c_msg.as_ptr();
if let Some(cb) = DUMP_CALLBACK.lock().unwrap().clone() {
unsafe {
cb.call(record.level() as c_int, ptr);
}
}
}
}

View file

@ -3,9 +3,6 @@ pub enum Error {
#[error("std::ffi::NulError {0:?}")]
Nul(#[from] std::ffi::NulError),
#[error("ctrlc2::Error {0:?}")]
InterruptHandler(#[from] ctrlc2::Error),
#[error(transparent)]
Io(#[from] std::io::Error),
@ -15,35 +12,23 @@ pub enum Error {
#[error("std::net::AddrParseError {0}")]
AddrParse(#[from] std::net::AddrParseError),
#[error("smoltcp::iface::RouteTableFull {0:?}")]
RouteTableFull(#[from] smoltcp::iface::RouteTableFull),
#[error("smoltcp::socket::tcp::RecvError {0:?}")]
Recv(#[from] smoltcp::socket::tcp::RecvError),
#[error("smoltcp::socket::tcp::ListenError {0:?}")]
Listen(#[from] smoltcp::socket::tcp::ListenError),
#[error("smoltcp::socket::udp::BindError {0:?}")]
Bind(#[from] smoltcp::socket::udp::BindError),
#[error("smoltcp::socket::tcp::SendError {0:?}")]
Send(#[from] smoltcp::socket::tcp::SendError),
#[error("smoltcp::socket::udp::SendError {0:?}")]
UdpSend(#[from] smoltcp::socket::udp::SendError),
#[error("smoltcp::wire::Error {0:?}")]
Wire(#[from] smoltcp::wire::Error),
#[error("std::str::Utf8Error {0:?}")]
Utf8(#[from] std::str::Utf8Error),
#[error("TryFromSliceError {0:?}")]
TryFromSlice(#[from] std::array::TryFromSliceError),
#[error("ProtoError {0:?}")]
ProtoError(#[from] trust_dns_proto::error::ProtoError),
#[error("IpStackError {0:?}")]
IpStack(#[from] ipstack::IpStackError),
#[error("DnsProtoError {0:?}")]
DnsProto(#[from] trust_dns_proto::error::ProtoError),
#[error("httparse::Error {0:?}")]
Httparse(#[from] httparse::Error),
#[error("digest_auth::Error {0:?}")]
DigestAuth(#[from] digest_auth::Error),
#[cfg(target_os = "android")]
#[error("jni::errors::Error {0:?}")]
@ -52,18 +37,8 @@ pub enum Error {
#[error("{0}")]
String(String),
#[cfg(target_family = "unix")]
#[error("nix::errno::Errno {0:?}")]
OSError(#[from] nix::errno::Errno),
#[error("std::num::ParseIntError {0:?}")]
IntParseError(#[from] std::num::ParseIntError),
#[error("httparse::Error {0}")]
HttpError(#[from] httparse::Error),
#[error("digest_auth::Error {0}")]
DigestAuthError(#[from] digest_auth::Error),
}
impl From<&str> for Error {

View file

@ -1,22 +1,20 @@
use crate::{
error::Error,
tun2proxy::{
ConnectionInfo, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection,
ProxyHandler,
},
directions::{IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection},
error::{Error, Result},
proxy_handler::{ProxyHandler, ProxyHandlerManager},
session_info::{IpProtocol, SessionInfo},
};
use base64::Engine;
use httparse::Response;
use smoltcp::wire::IpProtocol;
use socks5_impl::protocol::UserKey;
use std::{
cell::RefCell,
collections::{hash_map::RandomState, HashMap, VecDeque},
iter::FromIterator,
net::SocketAddr,
rc::Rc,
str,
sync::Arc,
};
use tokio::sync::Mutex;
use unicase::UniCase;
#[derive(Eq, PartialEq, Debug)]
@ -48,10 +46,11 @@ pub struct HttpConnection {
crlf_state: u8,
counter: usize,
skip: usize,
digest_state: Rc<RefCell<Option<DigestState>>>,
digest_state: Arc<Mutex<Option<DigestState>>>,
before: bool,
credentials: Option<UserKey>,
info: ConnectionInfo,
info: SessionInfo,
domain_name: Option<String>,
}
static PROXY_AUTHENTICATE: &str = "Proxy-Authenticate";
@ -61,7 +60,12 @@ static TRANSFER_ENCODING: &str = "Transfer-Encoding";
static CONTENT_LENGTH: &str = "Content-Length";
impl HttpConnection {
fn new(info: &ConnectionInfo, credentials: Option<UserKey>, digest_state: Rc<RefCell<Option<DigestState>>>) -> Result<Self, Error> {
async fn new(
info: SessionInfo,
domain_name: Option<String>,
credentials: Option<UserKey>,
digest_state: Arc<Mutex<Option<DigestState>>>,
) -> Result<Self> {
let mut res = Self {
state: HttpState::ExpectResponseHeaders,
client_inbuf: VecDeque::default(),
@ -74,38 +78,50 @@ impl HttpConnection {
digest_state,
before: false,
credentials,
info: info.clone(),
info,
domain_name,
};
res.send_tunnel_request()?;
res.send_tunnel_request().await?;
Ok(res)
}
fn send_tunnel_request(&mut self) -> Result<(), Error> {
async fn send_tunnel_request(&mut self) -> Result<(), Error> {
let host = if let Some(domain_name) = &self.domain_name {
format!("{}:{}", domain_name, self.info.dst.port())
} else {
self.info.dst.to_string()
};
self.server_outbuf.extend(b"CONNECT ");
self.server_outbuf.extend(self.info.dst.to_string().as_bytes());
self.server_outbuf.extend(host.as_bytes());
self.server_outbuf.extend(b" HTTP/1.1\r\nHost: ");
self.server_outbuf.extend(self.info.dst.to_string().as_bytes());
self.server_outbuf.extend(host.as_bytes());
self.server_outbuf.extend(b"\r\n");
self.send_auth_data(if self.digest_state.borrow().is_none() {
let scheme = if self.digest_state.lock().await.is_none() {
AuthenticationScheme::Basic
} else {
AuthenticationScheme::Digest
})?;
};
self.send_auth_data(scheme).await?;
self.server_outbuf.extend(b"\r\n");
Ok(())
}
fn send_auth_data(&mut self, scheme: AuthenticationScheme) -> Result<(), Error> {
async fn send_auth_data(&mut self, scheme: AuthenticationScheme) -> Result<()> {
let Some(credentials) = &self.credentials else {
return Ok(());
};
match scheme {
AuthenticationScheme::Digest => {
let uri = self.info.dst.to_string();
let uri = if let Some(domain_name) = &self.domain_name {
format!("{}:{}", domain_name, self.info.dst.port())
} else {
self.info.dst.to_string()
};
let context = digest_auth::AuthContext::new_with_method(
&credentials.username,
@ -115,8 +131,8 @@ impl HttpConnection {
digest_auth::HttpMethod::CONNECT,
);
let mut state = self.digest_state.borrow_mut();
let response = state.as_mut().unwrap().respond(&context)?;
let mut state = self.digest_state.lock().await;
let response = state.as_mut().unwrap().respond(&context).unwrap();
self.server_outbuf
.extend(format!("{}: {}\r\n", PROXY_AUTHORIZATION, response.to_header_string()).as_bytes());
@ -133,7 +149,8 @@ impl HttpConnection {
Ok(())
}
fn state_change(&mut self) -> Result<(), Error> {
#[async_recursion::async_recursion]
async fn state_change(&mut self) -> Result<()> {
match self.state {
HttpState::ExpectResponseHeaders => {
while self.counter < self.server_inbuf.len() {
@ -176,7 +193,7 @@ impl HttpConnection {
// Connection successful
self.state = HttpState::Established;
self.server_inbuf.clear();
return self.state_change();
return self.state_change().await;
}
if status_code != 407 {
@ -209,7 +226,7 @@ impl HttpConnection {
}
// Update the digest state
self.digest_state.replace(Some(state));
self.digest_state.lock().await.replace(state);
self.before = true;
let closed = match headers_map.get(&UniCase::new(CONNECTION)) {
@ -222,7 +239,7 @@ impl HttpConnection {
// Reset all the buffers
self.server_inbuf.clear();
self.server_outbuf.clear();
self.send_tunnel_request()?;
self.send_tunnel_request().await?;
self.state = HttpState::Reset;
return Ok(());
@ -260,7 +277,7 @@ impl HttpConnection {
// Close the connection by information miss
self.server_inbuf.clear();
self.server_outbuf.clear();
self.send_tunnel_request()?;
self.send_tunnel_request().await?;
self.state = HttpState::Reset;
return Ok(());
@ -271,7 +288,7 @@ impl HttpConnection {
self.state = HttpState::ExpectResponse;
self.skip = content_length + len;
return self.state_change();
return self.state_change().await;
}
HttpState::ExpectResponse => {
if self.skip > 0 {
@ -285,10 +302,10 @@ impl HttpConnection {
// self.server_outbuf.append(&mut self.data_buf);
// self.data_buf.clear();
self.send_tunnel_request()?;
self.send_tunnel_request().await?;
self.state = HttpState::ExpectResponseHeaders;
return self.state_change();
return self.state_change().await;
}
}
HttpState::Established => {
@ -299,7 +316,7 @@ impl HttpConnection {
}
HttpState::Reset => {
self.state = HttpState::ExpectResponseHeaders;
return self.state_change();
return self.state_change().await;
}
_ => {}
}
@ -307,12 +324,17 @@ impl HttpConnection {
}
}
#[async_trait::async_trait]
impl ProxyHandler for HttpConnection {
fn get_connection_info(&self) -> &ConnectionInfo {
&self.info
fn get_session_info(&self) -> SessionInfo {
self.info
}
fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), Error> {
fn get_domain_name(&self) -> Option<String> {
self.domain_name.clone()
}
async fn push_data(&mut self, event: IncomingDataEvent<'_>) -> std::io::Result<()> {
let direction = event.direction;
let buffer = event.buffer;
match direction {
@ -324,7 +346,8 @@ impl ProxyHandler for HttpConnection {
}
}
self.state_change()
self.state_change().await?;
Ok(())
}
fn consume_data(&mut self, dir: OutgoingDirection, size: usize) {
@ -352,16 +375,10 @@ impl ProxyHandler for HttpConnection {
self.state == HttpState::Established
}
fn data_len(&self, dir: Direction) -> usize {
fn data_len(&self, dir: OutgoingDirection) -> usize {
match dir {
Direction::Incoming(incoming) => match incoming {
IncomingDirection::FromServer => self.server_inbuf.len(),
IncomingDirection::FromClient => self.client_inbuf.len(),
},
Direction::Outgoing(outgoing) => match outgoing {
OutgoingDirection::ToServer => self.server_outbuf.len(),
OutgoingDirection::ToClient => self.client_outbuf.len(),
},
OutgoingDirection::ToServer => self.server_outbuf.len(),
OutgoingDirection::ToClient => self.client_outbuf.len(),
}
}
@ -377,19 +394,23 @@ impl ProxyHandler for HttpConnection {
pub(crate) struct HttpManager {
server: SocketAddr,
credentials: Option<UserKey>,
digest_state: Rc<RefCell<Option<DigestState>>>,
digest_state: Arc<Mutex<Option<DigestState>>>,
}
impl ConnectionManager for HttpManager {
fn new_proxy_handler(&self, info: &ConnectionInfo, _: bool) -> Result<Box<dyn ProxyHandler>, Error> {
#[async_trait::async_trait]
impl ProxyHandlerManager for HttpManager {
async fn new_proxy_handler(
&self,
info: SessionInfo,
domain_name: Option<String>,
_udp_associate: bool,
) -> std::io::Result<Arc<Mutex<dyn ProxyHandler>>> {
if info.protocol != IpProtocol::Tcp {
return Err("Invalid protocol".into());
return Err(Error::from("Invalid protocol").into());
}
Ok(Box::new(HttpConnection::new(
info,
self.credentials.clone(),
self.digest_state.clone(),
)?))
Ok(Arc::new(Mutex::new(
HttpConnection::new(info, domain_name, self.credentials.clone(), self.digest_state.clone()).await?,
)))
}
fn get_server_addr(&self) -> SocketAddr {
@ -402,7 +423,7 @@ impl HttpManager {
Self {
server,
credentials,
digest_state: Rc::new(RefCell::new(None)),
digest_state: Arc::new(Mutex::new(None)),
}
}
}

41
src/ios.rs Normal file
View file

@ -0,0 +1,41 @@
#![cfg(target_os = "ios")]
use crate::{
args::{ArgDns, ArgProxy},
ArgVerbosity, Args,
};
use std::os::raw::{c_char, c_int, c_uint};
/// # Safety
///
/// Run the tun2proxy component with some arguments.
#[no_mangle]
pub unsafe extern "C" fn tun2proxy_run(
proxy_url: *const c_char,
tun_fd: c_int,
tun_mtu: c_uint,
dns_over_tcp: c_char,
verbose: c_char,
) -> c_int {
use log::LevelFilter;
let log_level = if verbose != 0 { LevelFilter::Trace } else { LevelFilter::Info };
log::set_max_level(log_level);
log::set_boxed_logger(Box::<crate::dump_logger::DumpLogger>::default()).unwrap();
let dns = if dns_over_tcp != 0 { ArgDns::OverTcp } else { ArgDns::Direct };
let verbosity = if verbose != 0 { ArgVerbosity::Trace } else { ArgVerbosity::Info };
let proxy_url = std::ffi::CStr::from_ptr(proxy_url).to_str().unwrap();
let proxy = ArgProxy::from_url(proxy_url).unwrap();
let args = Args::new(Some(tun_fd), proxy, dns, verbosity);
crate::api::tun2proxy_internal_run(args, tun_mtu as _)
}
/// # Safety
///
/// Shutdown the tun2proxy component.
#[no_mangle]
pub unsafe extern "C" fn tun2proxy_stop() -> c_int {
crate::api::tun2proxy_internal_stop()
}

View file

@ -1,171 +1,472 @@
use crate::{
error::Error,
args::ProxyType,
directions::{IncomingDataEvent, IncomingDirection, OutgoingDirection},
http::HttpManager,
socks::SocksProxyManager,
tun2proxy::{ConnectionManager, TunToProxy},
session_info::{IpProtocol, SessionInfo},
virtual_dns::VirtualDns,
};
use smoltcp::wire::IpCidr;
use socks5_impl::protocol::UserKey;
use std::{
net::{SocketAddr, ToSocketAddrs},
rc::Rc,
pub use clap;
use ipstack::stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream};
use proxy_handler::{ProxyHandler, ProxyHandlerManager};
use socks::SocksProxyManager;
use std::{collections::VecDeque, future::Future, net::SocketAddr, pin::Pin, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
sync::{
mpsc::{error::SendError, Receiver, Sender},
Mutex,
},
};
use tproxy_config::is_private_ip;
use udp_stream::UdpStream;
pub use {
args::{ArgVerbosity, Args},
error::{Error, Result},
};
mod android;
mod api;
mod args;
mod directions;
mod dns;
pub mod error;
mod dump_logger;
mod error;
mod http;
pub mod setup;
mod ios;
mod proxy_handler;
mod session_info;
mod socks;
mod tun2proxy;
pub mod util;
mod virtdevice;
mod virtdns;
#[cfg(target_os = "windows")]
mod wintuninterface;
mod virtual_dns;
#[derive(Clone, Debug)]
pub struct Proxy {
pub proxy_type: ProxyType,
pub addr: SocketAddr,
pub credentials: Option<UserKey>,
}
const DNS_PORT: u16 = 53;
pub enum NetworkInterface {
Named(String),
#[cfg(target_family = "unix")]
Fd(std::os::fd::RawFd),
}
const MAX_SESSIONS: u64 = 200;
impl Proxy {
pub fn from_url(s: &str) -> Result<Proxy, Error> {
let e = format!("`{s}` is not a valid proxy URL");
let url = url::Url::parse(s).map_err(|_| Error::from(&e))?;
let e = format!("`{s}` does not contain a host");
let host = url.host_str().ok_or(Error::from(e))?;
static TASK_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
use std::sync::atomic::Ordering::Relaxed;
let mut url_host = String::from(host);
let e = format!("`{s}` does not contain a port");
let port = url.port().ok_or(Error::from(&e))?;
url_host.push(':');
url_host.push_str(port.to_string().as_str());
let e = format!("`{host}` could not be resolved");
let mut addr_iter = url_host.to_socket_addrs().map_err(|_| Error::from(&e))?;
let e = format!("`{host}` does not resolve to a usable IP address");
let addr = addr_iter.next().ok_or(Error::from(&e))?;
let credentials = if url.username() == "" && url.password().is_none() {
None
} else {
let username = String::from(url.username());
let password = String::from(url.password().unwrap_or(""));
Some(UserKey::new(username, password))
};
let scheme = url.scheme();
let proxy_type = match url.scheme().to_ascii_lowercase().as_str() {
"socks4" => Some(ProxyType::Socks4),
"socks5" => Some(ProxyType::Socks5),
"http" => Some(ProxyType::Http),
_ => None,
}
.ok_or(Error::from(&format!("`{scheme}` is an invalid proxy type")))?;
Ok(Proxy {
proxy_type,
addr,
credentials,
})
}
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum ProxyType {
Socks4,
Socks5,
Http,
}
impl std::fmt::Display for ProxyType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProxyType::Socks4 => write!(f, "socks4"),
ProxyType::Socks5 => write!(f, "socks5"),
ProxyType::Http => write!(f, "http"),
}
}
}
#[derive(Default)]
pub struct Options {
virtual_dns: Option<virtdns::VirtualDns>,
pub struct Builder<D> {
device: D,
mtu: Option<usize>,
dns_over_tcp: bool,
dns_addr: Option<std::net::IpAddr>,
ipv6_enabled: bool,
pub setup: bool,
bypass: Vec<IpCidr>,
args: Args,
}
impl Options {
pub fn new() -> Self {
Options::default()
impl<D: AsyncRead + AsyncWrite + Unpin + Send + 'static> Builder<D> {
pub fn new(device: D, args: Args) -> Self {
Builder { device, args, mtu: None }
}
pub fn with_virtual_dns(mut self) -> Self {
self.virtual_dns = Some(virtdns::VirtualDns::new());
self.dns_over_tcp = false;
self
}
pub fn with_dns_over_tcp(mut self) -> Self {
self.dns_over_tcp = true;
self.virtual_dns = None;
self
}
pub fn with_dns_addr(mut self, addr: Option<std::net::IpAddr>) -> Self {
self.dns_addr = addr;
self
}
pub fn with_ipv6_enabled(mut self) -> Self {
self.ipv6_enabled = true;
self
}
pub fn with_mtu(mut self, mtu: usize) -> Self {
pub fn mtu(mut self, mtu: usize) -> Self {
self.mtu = Some(mtu);
self
}
pub fn build(self) -> Tun2Socks5<impl Future<Output = crate::Result<()>> + Send + 'static> {
let (tx, rx) = tokio::sync::mpsc::channel::<()>(1);
pub fn with_bypass_ips<'a>(mut self, bypass_ips: impl IntoIterator<Item = &'a IpCidr>) -> Self {
for bypass_ip in bypass_ips {
self.bypass.push(*bypass_ip);
}
self
Tun2Socks5(run(self.device, self.mtu.unwrap_or(1500), self.args, rx), tx)
}
}
pub fn tun_to_proxy<'a>(interface: &NetworkInterface, proxy: &Proxy, options: Options) -> Result<TunToProxy<'a>, Error> {
let mut ttp = TunToProxy::new(interface, options)?;
let credentials = proxy.credentials.clone();
let server = proxy.addr;
use socks5_impl::protocol::Version::{V4, V5};
let mgr = match proxy.proxy_type {
ProxyType::Socks4 => Rc::new(SocksProxyManager::new(server, V4, credentials)) as Rc<dyn ConnectionManager>,
ProxyType::Socks5 => Rc::new(SocksProxyManager::new(server, V5, credentials)) as Rc<dyn ConnectionManager>,
ProxyType::Http => Rc::new(HttpManager::new(server, credentials)) as Rc<dyn ConnectionManager>,
};
ttp.set_connection_manager(Some(mgr));
Ok(ttp)
pub struct Tun2Socks5<F: Future>(F, Sender<()>);
impl<F: Future + Send + 'static> Tun2Socks5<F>
where
F::Output: Send,
{
pub fn start(self) -> (JoinHandle<F::Output>, Quit) {
let r = tokio::spawn(self.0);
(JoinHandle(r), Quit(self.1))
}
}
pub fn main_entry(interface: &NetworkInterface, proxy: &Proxy, options: Options) -> Result<(), Error> {
let mut ttp = tun_to_proxy(interface, proxy, options)?;
ttp.run()?;
pub struct Quit(Sender<()>);
impl Quit {
pub async fn trigger(&self) -> Result<(), SendError<()>> {
self.0.send(()).await
}
}
#[repr(transparent)]
struct TokioJoinError(tokio::task::JoinError);
impl From<TokioJoinError> for crate::Result<()> {
fn from(value: TokioJoinError) -> Self {
Err(crate::Error::Io(value.0.into()))
}
}
pub struct JoinHandle<R>(tokio::task::JoinHandle<R>);
impl<R: From<TokioJoinError>> Future for JoinHandle<R> {
type Output = R;
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
match std::task::ready!(Pin::new(&mut self.0).poll(cx)) {
Ok(r) => std::task::Poll::Ready(r),
Err(e) => std::task::Poll::Ready(TokioJoinError(e).into()),
}
}
}
pub async fn run<D>(device: D, mtu: usize, args: Args, mut quit: Receiver<()>) -> crate::Result<()>
where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let server_addr = args.proxy.addr;
let key = args.proxy.credentials.clone();
let dns_addr = args.dns_addr;
let ipv6_enabled = args.ipv6_enabled;
let virtual_dns = if args.dns == args::ArgDns::Virtual {
Some(Arc::new(Mutex::new(VirtualDns::new())))
} else {
None
};
use socks5_impl::protocol::Version::{V4, V5};
let mgr = match args.proxy.proxy_type {
ProxyType::Socks5 => Arc::new(SocksProxyManager::new(server_addr, V5, key)) as Arc<dyn ProxyHandlerManager>,
ProxyType::Socks4 => Arc::new(SocksProxyManager::new(server_addr, V4, key)) as Arc<dyn ProxyHandlerManager>,
ProxyType::Http => Arc::new(HttpManager::new(server_addr, key)) as Arc<dyn ProxyHandlerManager>,
};
let mut ipstack_config = ipstack::IpStackConfig::default();
ipstack_config.mtu(mtu as _);
ipstack_config.tcp_timeout(std::time::Duration::from_secs(600)); // 10 minutes
ipstack_config.udp_timeout(std::time::Duration::from_secs(10)); // 10 seconds
let mut ip_stack = ipstack::IpStack::new(ipstack_config, device);
loop {
let virtual_dns = virtual_dns.clone();
let ip_stack_stream = tokio::select! {
_ = quit.recv() => {
log::info!("");
log::info!("Ctrl-C recieved, exiting...");
break;
}
ip_stack_stream = ip_stack.accept() => {
ip_stack_stream?
}
};
match ip_stack_stream {
IpStackStream::Tcp(tcp) => {
if TASK_COUNT.load(Relaxed) > MAX_SESSIONS {
log::warn!("Too many sessions that over {MAX_SESSIONS}, dropping new session");
continue;
}
log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed) + 1);
let info = SessionInfo::new(tcp.local_addr(), tcp.peer_addr(), IpProtocol::Tcp);
let domain_name = if let Some(virtual_dns) = &virtual_dns {
let mut virtual_dns = virtual_dns.lock().await;
virtual_dns.touch_ip(&tcp.peer_addr().ip());
virtual_dns.resolve_ip(&tcp.peer_addr().ip()).cloned()
} else {
None
};
let proxy_handler = mgr.new_proxy_handler(info, domain_name, false).await?;
tokio::spawn(async move {
if let Err(err) = handle_tcp_session(tcp, server_addr, proxy_handler).await {
log::error!("{} error \"{}\"", info, err);
}
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
});
}
IpStackStream::Udp(udp) => {
if TASK_COUNT.load(Relaxed) > MAX_SESSIONS {
log::warn!("Too many sessions that over {MAX_SESSIONS}, dropping new session");
continue;
}
log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed) + 1);
let mut info = SessionInfo::new(udp.local_addr(), udp.peer_addr(), IpProtocol::Udp);
if info.dst.port() == DNS_PORT {
if is_private_ip(info.dst.ip()) {
info.dst.set_ip(dns_addr);
}
if args.dns == args::ArgDns::OverTcp {
let proxy_handler = mgr.new_proxy_handler(info, None, false).await?;
tokio::spawn(async move {
if let Err(err) = handle_dns_over_tcp_session(udp, server_addr, proxy_handler, ipv6_enabled).await {
log::error!("{} error \"{}\"", info, err);
}
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
});
continue;
}
if args.dns == args::ArgDns::Virtual {
tokio::spawn(async move {
if let Some(virtual_dns) = virtual_dns {
if let Err(err) = handle_virtual_dns_session(udp, virtual_dns).await {
log::error!("{} error \"{}\"", info, err);
}
}
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
});
continue;
}
assert_eq!(args.dns, args::ArgDns::Direct);
}
let domain_name = if let Some(virtual_dns) = &virtual_dns {
let mut virtual_dns = virtual_dns.lock().await;
virtual_dns.touch_ip(&udp.peer_addr().ip());
virtual_dns.resolve_ip(&udp.peer_addr().ip()).cloned()
} else {
None
};
let proxy_handler = mgr.new_proxy_handler(info, domain_name, true).await?;
tokio::spawn(async move {
if let Err(err) = handle_udp_associate_session(udp, server_addr, proxy_handler, ipv6_enabled).await {
log::error!("{} error \"{}\"", info, err);
}
log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1);
});
}
_ => {
log::trace!("Unknown transport");
continue;
}
}
}
Ok(())
}
async fn handle_virtual_dns_session(mut udp: IpStackUdpStream, dns: Arc<Mutex<VirtualDns>>) -> crate::Result<()> {
let mut buf = [0_u8; 4096];
loop {
let len = udp.read(&mut buf).await?;
if len == 0 {
break;
}
let (msg, qname, ip) = dns.lock().await.generate_query(&buf[..len])?;
udp.write_all(&msg).await?;
log::debug!("Virtual DNS query: {} -> {}", qname, ip);
}
Ok(())
}
async fn handle_tcp_session(
tcp_stack: IpStackTcpStream,
server_addr: SocketAddr,
proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
) -> crate::Result<()> {
let mut server = TcpStream::connect(server_addr).await?;
let session_info = proxy_handler.lock().await.get_session_info();
log::info!("Beginning {}", session_info);
let _ = handle_proxy_session(&mut server, proxy_handler).await?;
let (mut t_rx, mut t_tx) = tokio::io::split(tcp_stack);
let (mut s_rx, mut s_tx) = tokio::io::split(server);
let result = tokio::join! {
tokio::io::copy(&mut t_rx, &mut s_tx),
tokio::io::copy(&mut s_rx, &mut t_tx),
};
let result = match result {
(Ok(t), Ok(s)) => Ok((t, s)),
(Err(e), _) | (_, Err(e)) => Err(e),
};
log::info!("Ending {} with {:?}", session_info, result);
Ok(())
}
async fn handle_udp_associate_session(
mut udp_stack: IpStackUdpStream,
server_addr: SocketAddr,
proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
ipv6_enabled: bool,
) -> crate::Result<()> {
use socks5_impl::protocol::{Address, StreamOperation, UdpHeader};
let mut server = TcpStream::connect(server_addr).await?;
let session_info = proxy_handler.lock().await.get_session_info();
let domain_name = proxy_handler.lock().await.get_domain_name();
log::info!("Beginning {}", session_info);
let udp_addr = handle_proxy_session(&mut server, proxy_handler).await?;
let udp_addr = udp_addr.ok_or("udp associate failed")?;
let mut udp_server = UdpStream::connect(udp_addr).await?;
let mut buf1 = [0_u8; 4096];
let mut buf2 = [0_u8; 4096];
loop {
tokio::select! {
len = udp_stack.read(&mut buf1) => {
let len = len?;
if len == 0 {
break;
}
let buf1 = &buf1[..len];
let s5addr = if let Some(domain_name) = &domain_name {
Address::DomainAddress(domain_name.clone(), session_info.dst.port())
} else {
session_info.dst.into()
};
// Add SOCKS5 UDP header to the incoming data
let mut s5_udp_data = Vec::<u8>::new();
UdpHeader::new(0, s5addr).write_to_stream(&mut s5_udp_data)?;
s5_udp_data.extend_from_slice(buf1);
udp_server.write_all(&s5_udp_data).await?;
}
len = udp_server.read(&mut buf2) => {
let len = len?;
if len == 0 {
break;
}
let buf2 = &buf2[..len];
// Remove SOCKS5 UDP header from the server data
let header = UdpHeader::retrieve_from_stream(&mut &buf2[..])?;
let data = &buf2[header.len()..];
let buf = if session_info.dst.port() == DNS_PORT {
let mut message = dns::parse_data_to_dns_message(data, false)?;
if !ipv6_enabled {
dns::remove_ipv6_entries(&mut message);
}
message.to_vec()?
} else {
data.to_vec()
};
udp_stack.write_all(&buf).await?;
}
}
}
log::info!("Ending {}", session_info);
Ok(())
}
async fn handle_dns_over_tcp_session(
mut udp_stack: IpStackUdpStream,
server_addr: SocketAddr,
proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
ipv6_enabled: bool,
) -> crate::Result<()> {
let mut server = TcpStream::connect(server_addr).await?;
let session_info = proxy_handler.lock().await.get_session_info();
log::info!("Beginning {}", session_info);
let _ = handle_proxy_session(&mut server, proxy_handler).await?;
let mut buf1 = [0_u8; 4096];
let mut buf2 = [0_u8; 4096];
loop {
tokio::select! {
len = udp_stack.read(&mut buf1) => {
let len = len?;
if len == 0 {
break;
}
let buf1 = &buf1[..len];
_ = dns::parse_data_to_dns_message(buf1, false)?;
// Insert the DNS message length in front of the payload
let len = u16::try_from(buf1.len())?;
let mut buf = Vec::with_capacity(std::mem::size_of::<u16>() + usize::from(len));
buf.extend_from_slice(&len.to_be_bytes());
buf.extend_from_slice(buf1);
server.write_all(&buf).await?;
}
len = server.read(&mut buf2) => {
let len = len?;
if len == 0 {
break;
}
let mut buf = buf2[..len].to_vec();
let mut to_send: VecDeque<Vec<u8>> = VecDeque::new();
loop {
if buf.len() < 2 {
break;
}
let len = u16::from_be_bytes([buf[0], buf[1]]) as usize;
if buf.len() < len + 2 {
break;
}
// remove the length field
let data = buf[2..len + 2].to_vec();
let mut message = dns::parse_data_to_dns_message(&data, false)?;
let name = dns::extract_domain_from_dns_message(&message)?;
let ip = dns::extract_ipaddr_from_dns_message(&message);
log::trace!("DNS over TCP query result: {} -> {:?}", name, ip);
if !ipv6_enabled {
dns::remove_ipv6_entries(&mut message);
}
to_send.push_back(message.to_vec()?);
if len + 2 == buf.len() {
break;
}
buf = buf[len + 2..].to_vec();
}
while let Some(packet) = to_send.pop_front() {
udp_stack.write_all(&packet).await?;
}
}
}
}
log::info!("Ending {}", session_info);
Ok(())
}
async fn handle_proxy_session(server: &mut TcpStream, proxy_handler: Arc<Mutex<dyn ProxyHandler>>) -> crate::Result<Option<SocketAddr>> {
let mut launched = false;
let mut proxy_handler = proxy_handler.lock().await;
let dir = OutgoingDirection::ToServer;
loop {
if proxy_handler.connection_established() {
break;
}
if !launched {
let data = proxy_handler.peek_data(dir).buffer;
let len = data.len();
if len == 0 {
return Err("proxy_handler launched went wrong".into());
}
server.write_all(data).await?;
proxy_handler.consume_data(dir, len);
launched = true;
}
let mut buf = [0_u8; 4096];
let len = server.read(&mut buf).await?;
if len == 0 {
return Err("server closed accidentially".into());
}
let event = IncomingDataEvent {
direction: IncomingDirection::FromServer,
buffer: &buf[..len],
};
proxy_handler.push_data(event).await?;
let data = proxy_handler.peek_data(dir).buffer;
let len = data.len();
if len > 0 {
server.write_all(data).await?;
proxy_handler.consume_data(dir, len);
}
}
Ok(proxy_handler.get_udp_associate())
}

View file

@ -1,156 +0,0 @@
use clap::Parser;
use smoltcp::wire::IpCidr;
use std::{net::IpAddr, process::ExitCode};
use tun2proxy::util::str_to_cidr;
use tun2proxy::{error::Error, main_entry, NetworkInterface, Options, Proxy};
#[cfg(target_os = "linux")]
use tun2proxy::setup::{get_default_cidrs, Setup};
/// Tunnel interface to proxy
#[derive(Parser)]
#[command(author, version, about = "Tunnel interface to proxy.", long_about = None)]
struct Args {
/// Name of the tun interface
#[arg(short, long, value_name = "name", default_value = "tun0")]
tun: String,
/// File descriptor of the tun interface
#[arg(long, value_name = "fd")]
tun_fd: Option<i32>,
/// MTU of the tun interface (only with tunnel file descriptor)
#[arg(long, value_name = "mtu", default_value = "1500")]
tun_mtu: usize,
/// Proxy URL in the form proto://[username[:password]@]host:port
#[arg(short, long, value_parser = Proxy::from_url, value_name = "URL")]
proxy: Proxy,
/// DNS handling strategy
#[arg(short, long, value_name = "strategy", value_enum, default_value = "virtual")]
dns: ArgDns,
/// DNS resolver address
#[arg(long, value_name = "IP", default_value = "8.8.8.8")]
dns_addr: IpAddr,
/// IPv6 enabled
#[arg(short = '6', long)]
ipv6_enabled: bool,
/// Routing and system setup
#[arg(short, long, value_name = "method", value_enum, default_value = if cfg!(target_os = "linux") { "none" } else { "auto" })]
setup: Option<ArgSetup>,
/// IPs used in routing setup which should bypass the tunnel
#[arg(short, long, value_name = "IP|CIDR")]
bypass: Vec<String>,
/// Verbosity level
#[arg(short, long, value_name = "level", value_enum, default_value = "info")]
verbosity: ArgVerbosity,
}
/// DNS query handling strategy
/// - Virtual: Intercept DNS queries and resolve them locally with a fake IP address
/// - OverTcp: Use TCP to send DNS queries to the DNS server
/// - Direct: Do not handle DNS by relying on DNS server bypassing
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
enum ArgDns {
Virtual,
OverTcp,
Direct,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
enum ArgSetup {
None,
Auto,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
enum ArgVerbosity {
Off,
Error,
Warn,
Info,
Debug,
Trace,
}
fn main() -> ExitCode {
dotenvy::dotenv().ok();
let args = Args::parse();
let default = format!("{}={:?}", module_path!(), args.verbosity);
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(default)).init();
let addr = args.proxy.addr;
let proxy_type = args.proxy.proxy_type;
log::info!("Proxy {proxy_type} server: {addr}");
let mut options = Options::new();
match args.dns {
ArgDns::Virtual => {
options = options.with_virtual_dns();
}
ArgDns::OverTcp => {
options = options.with_dns_over_tcp();
}
_ => {}
}
options = options.with_dns_addr(Some(args.dns_addr));
if args.ipv6_enabled {
options = options.with_ipv6_enabled();
}
#[allow(unused_assignments)]
let interface = match args.tun_fd {
None => NetworkInterface::Named(args.tun.clone()),
Some(_fd) => {
options = options.with_mtu(args.tun_mtu);
#[cfg(not(target_family = "unix"))]
panic!("Not supported file descriptor");
#[cfg(target_family = "unix")]
NetworkInterface::Fd(_fd)
}
};
options.setup = args.setup.map(|s| s == ArgSetup::Auto).unwrap_or(false);
let block = || -> Result<(), Error> {
let mut bypass_ips = Vec::<IpCidr>::new();
for cidr_str in args.bypass {
bypass_ips.push(str_to_cidr(&cidr_str)?);
}
if bypass_ips.is_empty() {
let prefix_len = if args.proxy.addr.ip().is_ipv6() { 128 } else { 32 };
bypass_ips.push(IpCidr::new(args.proxy.addr.ip().into(), prefix_len))
}
options = options.with_bypass_ips(&bypass_ips);
#[cfg(target_os = "linux")]
{
let mut setup: Setup;
if options.setup {
setup = Setup::new(&args.tun, bypass_ips, get_default_cidrs());
setup.configure()?;
setup.drop_privileges()?;
}
}
main_entry(&interface, &args.proxy, options)?;
Ok(())
};
if let Err(e) = block() {
log::error!("{e}");
return ExitCode::FAILURE;
}
ExitCode::SUCCESS
}

30
src/proxy_handler.rs Normal file
View file

@ -0,0 +1,30 @@
use crate::{
directions::{IncomingDataEvent, OutgoingDataEvent, OutgoingDirection},
session_info::SessionInfo,
};
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
#[async_trait::async_trait]
pub(crate) trait ProxyHandler: Send + Sync {
fn get_session_info(&self) -> SessionInfo;
fn get_domain_name(&self) -> Option<String>;
async fn push_data(&mut self, event: IncomingDataEvent<'_>) -> std::io::Result<()>;
fn consume_data(&mut self, dir: OutgoingDirection, size: usize);
fn peek_data(&mut self, dir: OutgoingDirection) -> OutgoingDataEvent;
fn connection_established(&self) -> bool;
fn data_len(&self, dir: OutgoingDirection) -> usize;
fn reset_connection(&self) -> bool;
fn get_udp_associate(&self) -> Option<SocketAddr>;
}
#[async_trait::async_trait]
pub(crate) trait ProxyHandlerManager: Send + Sync {
async fn new_proxy_handler(
&self,
info: SessionInfo,
domain_name: Option<String>,
udp_associate: bool,
) -> std::io::Result<Arc<Mutex<dyn ProxyHandler>>>;
fn get_server_addr(&self) -> SocketAddr;
}

53
src/session_info.rs Normal file
View file

@ -0,0 +1,53 @@
use std::net::{Ipv4Addr, SocketAddr};
#[allow(dead_code)]
#[derive(Hash, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Debug, Default)]
pub(crate) enum IpProtocol {
#[default]
Tcp,
Udp,
Icmp,
Other(u8),
}
impl std::fmt::Display for IpProtocol {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
IpProtocol::Tcp => write!(f, "TCP"),
IpProtocol::Udp => write!(f, "UDP"),
IpProtocol::Icmp => write!(f, "ICMP"),
IpProtocol::Other(v) => write!(f, "Other({})", v),
}
}
}
#[derive(Hash, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Debug)]
pub(crate) struct SessionInfo {
pub(crate) src: SocketAddr,
pub(crate) dst: SocketAddr,
pub(crate) protocol: IpProtocol,
id: u64,
}
impl Default for SessionInfo {
fn default() -> Self {
let src = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
let dst = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
Self::new(src, dst, IpProtocol::Tcp)
}
}
static SESSION_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
impl SessionInfo {
pub fn new(src: SocketAddr, dst: SocketAddr, protocol: IpProtocol) -> Self {
let id = SESSION_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Self { src, dst, protocol, id }
}
}
impl std::fmt::Display for SessionInfo {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "#{} {} {} -> {}", self.id, self.protocol, self.src, self.dst)
}
}

View file

@ -1,337 +0,0 @@
#![cfg(target_os = "linux")]
use crate::error::Error;
use fork::Fork;
use smoltcp::wire::IpCidr;
use std::{
convert::TryFrom,
ffi::OsStr,
fs,
io::BufRead,
net::{Ipv4Addr, Ipv6Addr},
os::unix::io::RawFd,
process::{Command, Output},
str::FromStr,
};
#[derive(Clone)]
pub struct Setup {
routes: Vec<IpCidr>,
tunnel_bypass_addrs: Vec<IpCidr>,
tun: String,
set_up: bool,
delete_proxy_routes: Vec<IpCidr>,
child: libc::pid_t,
unmount_resolvconf: bool,
restore_resolvconf_data: Option<Vec<u8>>,
}
pub fn get_default_cidrs() -> [IpCidr; 4] {
[
IpCidr::new(Ipv4Addr::from_str("0.0.0.0").unwrap().into(), 1),
IpCidr::new(Ipv4Addr::from_str("128.0.0.0").unwrap().into(), 1),
IpCidr::new(Ipv6Addr::from_str("::").unwrap().into(), 1),
IpCidr::new(Ipv6Addr::from_str("8000::").unwrap().into(), 1),
]
}
fn run_iproute<I, S>(args: I, error: &str, require_success: bool) -> Result<Output, Error>
where
I: IntoIterator<Item = S>,
S: AsRef<OsStr>,
{
let mut command = Command::new("");
for (i, arg) in args.into_iter().enumerate() {
if i == 0 {
command = Command::new(arg);
} else {
command.arg(arg);
}
}
let e = Error::from(error);
let output = command.output().map_err(|_| e)?;
if !require_success || output.status.success() {
Ok(output)
} else {
let mut args: Vec<&str> = command.get_args().map(|x| x.to_str().unwrap()).collect();
let program = command.get_program().to_str().unwrap();
let mut cmdline = Vec::<&str>::new();
cmdline.push(program);
cmdline.append(&mut args);
let command = cmdline.as_slice().join(" ");
match String::from_utf8(output.stderr.clone()) {
Ok(output) => Err(format!("[{}] Command `{}` failed: {}", nix::unistd::getpid(), command, output).into()),
Err(_) => Err(format!("Command `{:?}` failed with exit code {}", command, output.status.code().unwrap()).into()),
}
}
}
impl Setup {
pub fn new(
tun: impl Into<String>,
tunnel_bypass_addrs: impl IntoIterator<Item = IpCidr>,
routes: impl IntoIterator<Item = IpCidr>,
) -> Self {
let routes_cidr = routes.into_iter().collect();
let bypass_cidrs = tunnel_bypass_addrs.into_iter().collect();
Self {
tun: tun.into(),
tunnel_bypass_addrs: bypass_cidrs,
routes: routes_cidr,
set_up: false,
delete_proxy_routes: Vec::<IpCidr>::new(),
child: 0,
unmount_resolvconf: false,
restore_resolvconf_data: None,
}
}
fn bypass_cidr(cidr: &IpCidr) -> Result<bool, Error> {
let is_ipv6 = match cidr {
IpCidr::Ipv4(_) => false,
IpCidr::Ipv6(_) => true,
};
let route_show_args = if is_ipv6 {
["ip", "-6", "route", "show"]
} else {
["ip", "-4", "route", "show"]
};
let routes = run_iproute(route_show_args, "failed to get routing table through the ip command", true)?;
let mut route_info = Vec::<(IpCidr, Vec<String>)>::new();
for line in routes.stdout.lines() {
if line.is_err() {
break;
}
let line = line.unwrap();
if line.starts_with([' ', '\t']) {
continue;
}
let mut split = line.split_whitespace();
let mut dst_str = split.next().unwrap();
if dst_str == "default" {
dst_str = if is_ipv6 { "::/0" } else { "0.0.0.0/0" }
}
let (addr_str, prefix_len_str) = match dst_str.split_once(['/']) {
None => (dst_str, if is_ipv6 { "128" } else { "32" }),
Some((addr_str, prefix_len_str)) => (addr_str, prefix_len_str),
};
let cidr: IpCidr = IpCidr::new(
std::net::IpAddr::from_str(addr_str).unwrap().into(),
u8::from_str(prefix_len_str).unwrap(),
);
let route_components: Vec<String> = split.map(String::from).collect();
route_info.push((cidr, route_components))
}
// Sort routes by prefix length, the most specific route comes first.
route_info.sort_by(|entry1, entry2| entry2.0.prefix_len().cmp(&entry1.0.prefix_len()));
for (route_cidr, route_components) in route_info {
if !route_cidr.contains_subnet(cidr) {
continue;
}
// The IP address is routed through a more specific route than the default route.
// In this case, there is nothing to do.
if route_cidr.prefix_len() != 0 {
break;
}
let mut proxy_route = vec!["ip".into(), "route".into(), "add".into()];
proxy_route.push(cidr.to_string());
proxy_route.extend(route_components.into_iter());
run_iproute(proxy_route, "failed to clone route for proxy", false)?;
return Ok(true);
}
Ok(false)
}
fn write_buffer_to_fd(fd: RawFd, data: &[u8]) -> Result<(), Error> {
let mut written = 0;
loop {
if written >= data.len() {
break;
}
written += nix::unistd::write(fd, &data[written..])?;
}
Ok(())
}
fn write_nameserver(fd: RawFd) -> Result<(), Error> {
let data = "nameserver 198.18.0.1\n".as_bytes();
Self::write_buffer_to_fd(fd, data)?;
nix::sys::stat::fchmod(fd, nix::sys::stat::Mode::from_bits(0o444).unwrap())?;
Ok(())
}
fn setup_resolv_conf(&mut self) -> Result<(), Error> {
let mut fd = nix::fcntl::open(
"/tmp/tun2proxy-resolv.conf",
nix::fcntl::OFlag::O_RDWR | nix::fcntl::OFlag::O_CLOEXEC | nix::fcntl::OFlag::O_CREAT,
nix::sys::stat::Mode::from_bits(0o644).unwrap(),
)?;
Self::write_nameserver(fd)?;
let source = format!("/proc/self/fd/{}", fd);
if Ok(())
!= nix::mount::mount(
source.as_str().into(),
"/etc/resolv.conf",
"".into(),
nix::mount::MsFlags::MS_BIND,
"".into(),
)
{
log::warn!("failed to bind mount custom resolv.conf onto /etc/resolv.conf, resorting to direct write");
nix::unistd::close(fd)?;
self.restore_resolvconf_data = Some(fs::read("/etc/resolv.conf")?);
fd = nix::fcntl::open(
"/etc/resolv.conf",
nix::fcntl::OFlag::O_WRONLY | nix::fcntl::OFlag::O_CLOEXEC | nix::fcntl::OFlag::O_TRUNC,
nix::sys::stat::Mode::from_bits(0o644).unwrap(),
)?;
Self::write_nameserver(fd)?;
} else {
self.unmount_resolvconf = true;
}
nix::unistd::close(fd)?;
Ok(())
}
fn add_tunnel_routes(&self) -> Result<(), Error> {
for route in &self.routes {
run_iproute(
["ip", "route", "add", route.to_string().as_str(), "dev", self.tun.as_str()],
"failed to add route",
true,
)?;
}
Ok(())
}
fn shutdown(&mut self) -> Result<(), Error> {
self.set_up = false;
log::info!("[{}] Restoring network configuration", nix::unistd::getpid());
let _ = Command::new("ip").args(["link", "del", self.tun.as_str()]).output();
for cidr in &self.delete_proxy_routes {
let _ = Command::new("ip").args(["route", "del", cidr.to_string().as_str()]).output();
}
if self.unmount_resolvconf {
nix::mount::umount("/etc/resolv.conf")?;
}
if let Some(data) = &self.restore_resolvconf_data {
fs::write("/etc/resolv.conf", data)?;
}
Ok(())
}
fn setup_and_handle_signals(&mut self, read_from_child: RawFd, write_to_parent: RawFd) {
if let Err(e) = (|| -> Result<(), Error> {
nix::unistd::close(read_from_child)?;
run_iproute(
["ip", "tuntap", "add", "name", self.tun.as_str(), "mode", "tun"],
"failed to create tunnel device",
true,
)?;
self.set_up = true;
run_iproute(
["ip", "link", "set", self.tun.as_str(), "up"],
"failed to bring up tunnel device",
true,
)?;
let mut delete_proxy_route = Vec::<IpCidr>::new();
for cidr in &self.tunnel_bypass_addrs {
if Self::bypass_cidr(cidr)? {
delete_proxy_route.push(*cidr);
}
}
self.delete_proxy_routes = delete_proxy_route;
self.setup_resolv_conf()?;
self.add_tunnel_routes()?;
// Signal to child that we are done setting up everything.
if nix::unistd::write(write_to_parent, &[1])? != 1 {
return Err("Failed to write to pipe".into());
}
nix::unistd::close(write_to_parent)?;
// Now wait for the termination signals.
let mut mask = nix::sys::signal::SigSet::empty();
mask.add(nix::sys::signal::SIGINT);
mask.add(nix::sys::signal::SIGTERM);
mask.add(nix::sys::signal::SIGQUIT);
mask.thread_block().unwrap();
let mut fd = nix::sys::signalfd::SignalFd::new(&mask).unwrap();
loop {
let res = fd.read_signal().unwrap().unwrap();
let signo = nix::sys::signal::Signal::try_from(res.ssi_signo as i32).unwrap();
if signo == nix::sys::signal::SIGINT || signo == nix::sys::signal::SIGTERM || signo == nix::sys::signal::SIGQUIT {
break;
}
}
self.shutdown()?;
Ok(())
})() {
log::error!("{e}");
self.shutdown().unwrap();
};
}
pub fn drop_privileges(&self) -> Result<(), Error> {
// 65534 is usually the nobody user. Even in cases it is not, it is safer to use this ID
// than running with UID and GID 0.
nix::unistd::setgid(nix::unistd::Gid::from_raw(65534))?;
nix::unistd::setuid(nix::unistd::Uid::from_raw(65534))?;
Ok(())
}
pub fn configure(&mut self) -> Result<(), Error> {
log::info!("[{}] Setting up network configuration", nix::unistd::getpid());
if nix::unistd::getuid() != 0.into() {
return Err("Automatic setup requires root privileges".into());
}
let (read_from_child, write_to_parent) = nix::unistd::pipe()?;
match fork::fork() {
Ok(Fork::Child) => {
prctl::set_death_signal(nix::sys::signal::SIGINT as isize).unwrap();
self.setup_and_handle_signals(read_from_child, write_to_parent);
std::process::exit(0);
}
Ok(Fork::Parent(child)) => {
self.child = child;
nix::unistd::close(write_to_parent)?;
let mut buf = [0];
if nix::unistd::read(read_from_child, &mut buf)? != 1 {
return Err("Failed to read from pipe".into());
}
nix::unistd::close(read_from_child)?;
Ok(())
}
_ => Err("Failed to fork".into()),
}
}
pub fn restore(&mut self) -> Result<(), Error> {
nix::sys::signal::kill(nix::unistd::Pid::from_raw(self.child), nix::sys::signal::SIGINT)?;
nix::sys::wait::waitpid(nix::unistd::Pid::from_raw(self.child), None)?;
Ok(())
}
}

View file

@ -1,15 +1,14 @@
use crate::{
directions::{IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection},
error::{Error, Result},
tun2proxy::{
ConnectionInfo, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection,
ProxyHandler,
},
proxy_handler::{ProxyHandler, ProxyHandlerManager},
session_info::SessionInfo,
};
use socks5_impl::protocol::{self, handshake, password_method, Address, AuthMethod, StreamOperation, UserKey, Version};
use std::{collections::VecDeque, convert::TryFrom, net::SocketAddr};
use std::{collections::VecDeque, net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
#[derive(Eq, PartialEq, Debug)]
#[allow(dead_code)]
enum SocksState {
ClientHello,
ServerHello,
@ -21,7 +20,8 @@ enum SocksState {
}
struct SocksProxyImpl {
info: ConnectionInfo,
info: SessionInfo,
domain_name: Option<String>,
state: SocksState,
client_inbuf: VecDeque<u8>,
server_inbuf: VecDeque<u8>,
@ -34,10 +34,17 @@ struct SocksProxyImpl {
}
impl SocksProxyImpl {
fn new(info: &ConnectionInfo, credentials: Option<UserKey>, version: Version, command: protocol::Command) -> Result<Self> {
fn new(
info: SessionInfo,
domain_name: Option<String>,
credentials: Option<UserKey>,
version: Version,
command: protocol::Command,
) -> Result<Self> {
let mut result = Self {
info: info.clone(),
state: SocksState::ServerHello,
info,
domain_name,
state: SocksState::ClientHello,
client_inbuf: VecDeque::default(),
server_inbuf: VecDeque::default(),
client_outbuf: VecDeque::default(),
@ -58,16 +65,17 @@ impl SocksProxyImpl {
let mut ip_vec = Vec::<u8>::new();
let mut name_vec = Vec::<u8>::new();
match &self.info.dst {
Address::SocketAddress(SocketAddr::V4(addr)) => {
ip_vec.extend(addr.ip().octets().as_ref());
SocketAddr::V4(addr) => {
if let Some(host) = &self.domain_name {
ip_vec.extend(&[0, 0, 0, host.len() as u8]);
name_vec.extend(host.as_bytes());
name_vec.push(0);
} else {
ip_vec.extend(addr.ip().octets().as_ref());
}
}
Address::SocketAddress(SocketAddr::V6(_)) => {
return Err("SOCKS4 does not support IPv6".into());
}
Address::DomainAddress(host, _) => {
ip_vec.extend(&[0, 0, 0, host.len() as u8]);
name_vec.extend(host.as_bytes());
name_vec.push(0);
SocketAddr::V6(addr) => {
return Err(format!("SOCKS4 does not support IPv6: {}", addr).into());
}
}
self.server_outbuf.extend(ip_vec);
@ -85,14 +93,7 @@ impl SocksProxyImpl {
fn send_client_hello_socks5(&mut self) -> Result<(), Error> {
let credentials = &self.credentials;
// Providing unassigned methods is supposed to bypass China's GFW.
// For details, refer to https://github.com/blechschmidt/tun2proxy/issues/35.
#[rustfmt::skip]
let mut methods = vec![
AuthMethod::NoAuth,
AuthMethod::from(4_u8),
AuthMethod::from(100_u8),
];
let mut methods = vec![AuthMethod::NoAuth, AuthMethod::from(4_u8), AuthMethod::from(100_u8)];
if credentials.is_some() {
methods.push(AuthMethod::UserPass);
}
@ -113,29 +114,29 @@ impl SocksProxyImpl {
Ok(())
}
fn receive_server_hello_socks4(&mut self) -> Result<(), Error> {
fn receive_server_hello_socks4(&mut self) -> std::io::Result<()> {
if self.server_inbuf.len() < 8 {
return Ok(());
}
if self.server_inbuf[1] != 0x5a {
return Err("SOCKS4 server replied with an unexpected reply code.".into());
return Err(crate::Error::from("SOCKS4 server replied with an unexpected reply code.").into());
}
self.server_inbuf.drain(0..8);
self.state = SocksState::Established;
self.state_change()
Ok(())
}
fn receive_server_hello_socks5(&mut self) -> Result<(), Error> {
fn receive_server_hello_socks5(&mut self) -> std::io::Result<()> {
let response = handshake::Response::retrieve_from_stream(&mut self.server_inbuf.clone());
if let Err(e) = &response {
if let Err(e) = response {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
log::trace!("receive_server_hello_socks5 needs more data \"{}\"...", e);
return Ok(());
} else {
return Err(e.to_string().into());
return Err(e);
}
}
let respones = response?;
@ -145,7 +146,7 @@ impl SocksProxyImpl {
if auth_method != AuthMethod::NoAuth && self.credentials.is_none()
|| (auth_method != AuthMethod::NoAuth && auth_method != AuthMethod::UserPass) && self.credentials.is_some()
{
return Err("SOCKS5 server requires an unsupported authentication method.".into());
return Err(crate::Error::from("SOCKS5 server requires an unsupported authentication method.").into());
}
self.state = if auth_method == AuthMethod::UserPass {
@ -156,75 +157,77 @@ impl SocksProxyImpl {
self.state_change()
}
fn receive_server_hello(&mut self) -> Result<(), Error> {
fn receive_server_hello(&mut self) -> std::io::Result<()> {
match self.version {
Version::V4 => self.receive_server_hello_socks4(),
Version::V5 => self.receive_server_hello_socks5(),
}
}
fn send_auth_data(&mut self) -> Result<(), Error> {
fn send_auth_data(&mut self) -> std::io::Result<()> {
let tmp = UserKey::default();
let credentials = self.credentials.as_ref().unwrap_or(&tmp);
let request = password_method::Request::new(&credentials.username, &credentials.password);
request.write_to_stream(&mut self.server_outbuf)?;
self.state = SocksState::ReceiveAuthResponse;
self.state_change()
Ok(())
}
fn receive_auth_data(&mut self) -> Result<(), Error> {
fn receive_auth_data(&mut self) -> std::io::Result<()> {
use password_method::Response;
let response = Response::retrieve_from_stream(&mut self.server_inbuf.clone());
if let Err(e) = &response {
if let Err(e) = response {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
log::trace!("receive_auth_data needs more data \"{}\"...", e);
return Ok(());
} else {
return Err(e.to_string().into());
return Err(e);
}
}
let response = response?;
self.server_inbuf.drain(0..response.len());
if response.status != password_method::Status::Succeeded {
return Err(format!("SOCKS authentication failed: {:?}", response.status).into());
return Err(crate::Error::from(format!("SOCKS authentication failed: {:?}", response.status)).into());
}
self.state = SocksState::SendRequest;
self.state_change()
}
fn send_request_socks5(&mut self) -> Result<(), Error> {
fn send_request_socks5(&mut self) -> std::io::Result<()> {
let addr = if self.command == protocol::Command::UdpAssociate {
Address::unspecified()
} else if let Some(domain_name) = &self.domain_name {
Address::DomainAddress(domain_name.clone(), self.info.dst.port())
} else {
self.info.dst.clone()
self.info.dst.into()
};
protocol::Request::new(self.command, addr).write_to_stream(&mut self.server_outbuf)?;
self.state = SocksState::ReceiveResponse;
self.state_change()
Ok(())
}
fn receive_connection_status(&mut self) -> Result<(), Error> {
fn receive_connection_status(&mut self) -> std::io::Result<()> {
let response = protocol::Response::retrieve_from_stream(&mut self.server_inbuf.clone());
if let Err(e) = &response {
if let Err(e) = response {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
log::trace!("receive_connection_status needs more data \"{}\"...", e);
return Ok(());
} else {
return Err(e.to_string().into());
return Err(e);
}
}
let response = response?;
self.server_inbuf.drain(0..response.len());
if response.reply != protocol::Reply::Succeeded {
return Err(format!("SOCKS connection failed: {}", response.reply).into());
return Err(crate::Error::from(format!("SOCKS connection failed: {}", response.reply)).into());
}
if self.command == protocol::Command::UdpAssociate {
self.udp_associate = Some(SocketAddr::try_from(&response.address)?);
log::trace!("UDP associate recieved address {}", response.address);
// log::trace!("UDP associate recieved address {}", response.address);
}
self.state = SocksState::Established;
self.state_change()
Ok(())
}
fn relay_traffic(&mut self) -> Result<(), Error> {
@ -235,31 +238,37 @@ impl SocksProxyImpl {
Ok(())
}
fn state_change(&mut self) -> Result<(), Error> {
fn state_change(&mut self) -> std::io::Result<()> {
match self.state {
SocksState::ServerHello => self.receive_server_hello(),
SocksState::ServerHello => self.receive_server_hello()?,
SocksState::SendAuthData => self.send_auth_data(),
SocksState::SendAuthData => self.send_auth_data()?,
SocksState::ReceiveAuthResponse => self.receive_auth_data(),
SocksState::ReceiveAuthResponse => self.receive_auth_data()?,
SocksState::SendRequest => self.send_request_socks5(),
SocksState::SendRequest => self.send_request_socks5()?,
SocksState::ReceiveResponse => self.receive_connection_status(),
SocksState::ReceiveResponse => self.receive_connection_status()?,
SocksState::Established => self.relay_traffic(),
SocksState::Established => self.relay_traffic()?,
_ => Ok(()),
_ => {}
}
Ok(())
}
}
#[async_trait::async_trait]
impl ProxyHandler for SocksProxyImpl {
fn get_connection_info(&self) -> &ConnectionInfo {
&self.info
fn get_session_info(&self) -> SessionInfo {
self.info
}
fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), Error> {
fn get_domain_name(&self) -> Option<String> {
self.domain_name.clone()
}
async fn push_data(&mut self, event: IncomingDataEvent<'_>) -> std::io::Result<()> {
let IncomingDataEvent { direction, buffer } = event;
match direction {
IncomingDirection::FromServer => {
@ -296,16 +305,10 @@ impl ProxyHandler for SocksProxyImpl {
self.state == SocksState::Established
}
fn data_len(&self, dir: Direction) -> usize {
fn data_len(&self, dir: OutgoingDirection) -> usize {
match dir {
Direction::Incoming(incoming) => match incoming {
IncomingDirection::FromServer => self.server_inbuf.len(),
IncomingDirection::FromClient => self.client_inbuf.len(),
},
Direction::Outgoing(outgoing) => match outgoing {
OutgoingDirection::ToServer => self.server_outbuf.len(),
OutgoingDirection::ToClient => self.client_outbuf.len(),
},
OutgoingDirection::ToServer => self.server_outbuf.len(),
OutgoingDirection::ToClient => self.client_outbuf.len(),
}
}
@ -324,12 +327,24 @@ pub(crate) struct SocksProxyManager {
version: Version,
}
impl ConnectionManager for SocksProxyManager {
fn new_proxy_handler(&self, info: &ConnectionInfo, udp_associate: bool) -> Result<Box<dyn ProxyHandler>> {
#[async_trait::async_trait]
impl ProxyHandlerManager for SocksProxyManager {
async fn new_proxy_handler(
&self,
info: SessionInfo,
domain_name: Option<String>,
udp_associate: bool,
) -> std::io::Result<Arc<Mutex<dyn ProxyHandler>>> {
use socks5_impl::protocol::Command::{Connect, UdpAssociate};
let command = if udp_associate { UdpAssociate } else { Connect };
let credentials = self.credentials.clone();
Ok(Box::new(SocksProxyImpl::new(info, credentials, self.version, command)?))
Ok(Arc::new(Mutex::new(SocksProxyImpl::new(
info,
domain_name,
credentials,
self.version,
command,
)?)))
}
fn get_server_addr(&self) -> SocketAddr {

File diff suppressed because it is too large Load diff

View file

@ -1,22 +0,0 @@
use crate::error::Error;
use smoltcp::wire::IpCidr;
use std::net::IpAddr;
use std::str::FromStr;
pub fn str_to_cidr(s: &str) -> Result<IpCidr, Error> {
// IpCidr's FromString implementation requires the netmask to be specified.
// Try to parse as IP address without netmask before falling back.
match IpAddr::from_str(s) {
Err(_) => (),
Ok(cidr) => {
let prefix_len = if cidr.is_ipv4() { 32 } else { 128 };
return Ok(IpCidr::new(cidr.into(), prefix_len));
}
};
let cidr = IpCidr::from_str(s);
match cidr {
Err(()) => Err("Invalid CIDR: ".into()),
Ok(cidr) => Ok(cidr),
}
}

View file

@ -1,80 +0,0 @@
use smoltcp::{
phy::{self, Device, DeviceCapabilities},
time::Instant,
};
/// Virtual device representing the remote proxy server.
#[derive(Default)]
pub struct VirtualTunDevice {
capabilities: DeviceCapabilities,
inbuf: Vec<Vec<u8>>,
outbuf: Vec<Vec<u8>>,
}
impl VirtualTunDevice {
pub fn inject_packet(&mut self, buffer: &[u8]) {
self.inbuf.push(buffer.to_vec());
}
pub fn exfiltrate_packet(&mut self) -> Option<Vec<u8>> {
self.outbuf.pop()
}
}
pub struct VirtRxToken {
buffer: Vec<u8>,
}
impl phy::RxToken for VirtRxToken {
fn consume<R, F>(mut self, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
f(&mut self.buffer[..])
}
}
pub struct VirtTxToken<'a>(&'a mut VirtualTunDevice);
impl<'a> phy::TxToken for VirtTxToken<'a> {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
let mut buffer = vec![0; len];
let result = f(&mut buffer);
self.0.outbuf.push(buffer);
result
}
}
impl Device for VirtualTunDevice {
type RxToken<'a> = VirtRxToken;
type TxToken<'a> = VirtTxToken<'a>;
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
if let Some(buffer) = self.inbuf.pop() {
let rx = Self::RxToken { buffer };
let tx = VirtTxToken(self);
return Some((rx, tx));
}
None
}
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
return Some(VirtTxToken(self));
}
fn capabilities(&self) -> DeviceCapabilities {
self.capabilities.clone()
}
}
impl VirtualTunDevice {
pub fn new(capabilities: DeviceCapabilities) -> Self {
Self {
capabilities,
..VirtualTunDevice::default()
}
}
}

View file

@ -1,8 +1,5 @@
#![allow(dead_code)]
use crate::error::Result;
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
use smoltcp::wire::Ipv4Cidr;
use std::{
collections::HashMap,
convert::TryInto,
@ -18,6 +15,9 @@ struct NameCacheEntry {
expiry: Instant,
}
/// A virtual DNS server which allocates IP addresses to clients.
/// The IP addresses are in the range of private IP addresses.
/// The DNS server is implemented as a LRU cache.
pub struct VirtualDns {
lru_cache: LruCache<IpAddr, NameCacheEntry>,
name_to_ip: HashMap<String, IpAddr>,
@ -29,13 +29,16 @@ pub struct VirtualDns {
impl Default for VirtualDns {
fn default() -> Self {
let start_addr = Ipv4Addr::from_str("198.18.0.0").unwrap();
let cidr = Ipv4Cidr::new(start_addr.into(), 15);
let prefix_len = 15;
let network_addr = calculate_network_addr(start_addr, prefix_len);
let broadcast_addr = calculate_broadcast_addr(start_addr, prefix_len);
Self {
next_addr: start_addr.into(),
name_to_ip: HashMap::default(),
network_addr: IpAddr::from(cidr.network().address().into_address()),
broadcast_addr: IpAddr::from(cidr.broadcast().unwrap().into_address()),
network_addr: IpAddr::from(network_addr),
broadcast_addr: IpAddr::from(broadcast_addr),
lru_cache: LruCache::new_unbounded(),
}
}
@ -46,13 +49,14 @@ impl VirtualDns {
VirtualDns::default()
}
pub fn receive_query(&mut self, data: &[u8]) -> Result<Vec<u8>> {
/// Returns the DNS response to send back to the client.
pub fn generate_query(&mut self, data: &[u8]) -> Result<(Vec<u8>, String, IpAddr)> {
use crate::dns;
let message = dns::parse_data_to_dns_message(data, false)?;
let qname = dns::extract_domain_from_dns_message(&message)?;
let ip = self.allocate_ip(qname.clone())?;
let message = dns::build_dns_response(message, &qname, ip, 5)?;
Ok(message.to_vec()?)
Ok((message.to_vec()?, qname, ip))
}
fn increment_ip(addr: IpAddr) -> Result<IpAddr> {
@ -140,3 +144,30 @@ impl VirtualDns {
}
}
}
fn calculate_network_addr(ip: std::net::Ipv4Addr, prefix_len: u8) -> std::net::Ipv4Addr {
let mask = (!0u32) << (32 - prefix_len);
let ip_u32 = u32::from_be_bytes(ip.octets());
std::net::Ipv4Addr::from((ip_u32 & mask).to_be_bytes())
}
fn calculate_broadcast_addr(ip: std::net::Ipv4Addr, prefix_len: u8) -> std::net::Ipv4Addr {
let mask = (!0u32) >> prefix_len;
let ip_u32 = u32::from_be_bytes(ip.octets());
std::net::Ipv4Addr::from((ip_u32 | mask).to_be_bytes())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cidr_addr() {
let start_addr = Ipv4Addr::from_str("198.18.0.0").unwrap();
let prefix_len = 15;
let network_addr = calculate_network_addr(start_addr, prefix_len);
let broadcast_addr = calculate_broadcast_addr(start_addr, prefix_len);
assert_eq!(network_addr, Ipv4Addr::from_str("198.18.0.0").unwrap());
assert_eq!(broadcast_addr, Ipv4Addr::from_str("198.19.255.255").unwrap());
}
}

View file

@ -1,546 +0,0 @@
use mio::{event, windows::NamedPipe, Interest, Registry, Token};
use smoltcp::wire::IpCidr;
use smoltcp::{
phy::{self, Device, DeviceCapabilities, Medium},
time::Instant,
};
use std::{
cell::RefCell,
fs::OpenOptions,
io::{self, Read, Write},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
os::windows::prelude::{FromRawHandle, IntoRawHandle, OpenOptionsExt},
rc::Rc,
sync::{Arc, Mutex},
thread::JoinHandle,
vec::Vec,
};
use windows::{
core::{GUID, PWSTR},
Win32::{
Foundation::{ERROR_BUFFER_OVERFLOW, WIN32_ERROR},
NetworkManagement::{
IpHelper::{
GetAdaptersAddresses, SetInterfaceDnsSettings, DNS_INTERFACE_SETTINGS, DNS_INTERFACE_SETTINGS_VERSION1,
DNS_SETTING_NAMESERVER, GAA_FLAG_INCLUDE_GATEWAYS, GAA_FLAG_INCLUDE_PREFIX, IF_TYPE_ETHERNET_CSMACD, IF_TYPE_IEEE80211,
IP_ADAPTER_ADDRESSES_LH,
},
Ndis::IfOperStatusUp,
},
Networking::WinSock::{AF_INET, AF_INET6, AF_UNSPEC, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6},
Storage::FileSystem::FILE_FLAG_OVERLAPPED,
},
};
fn server() -> io::Result<(NamedPipe, String)> {
use rand::Rng;
let num: u64 = rand::thread_rng().gen();
let name = format!(r"\\.\pipe\my-pipe-{}", num);
let pipe = NamedPipe::new(&name)?;
Ok((pipe, name))
}
fn client(name: &str) -> io::Result<NamedPipe> {
let mut opts = OpenOptions::new();
opts.read(true).write(true).custom_flags(FILE_FLAG_OVERLAPPED.0);
let file = opts.open(name)?;
unsafe { Ok(NamedPipe::from_raw_handle(file.into_raw_handle())) }
}
pub(crate) fn pipe() -> io::Result<(NamedPipe, NamedPipe)> {
let (pipe, name) = server()?;
Ok((pipe, client(&name)?))
}
/// A virtual TUN (IP) interface.
pub struct WinTunInterface {
wintun_session: Arc<wintun::Session>,
mtu: usize,
medium: Medium,
pipe_server: Rc<RefCell<NamedPipe>>,
pipe_server_cache: Rc<RefCell<Vec<u8>>>,
pipe_client: Arc<Mutex<NamedPipe>>,
pipe_client_cache: Arc<Mutex<Vec<u8>>>,
wintun_reader_thread: Option<JoinHandle<()>>,
old_gateway: Option<IpAddr>,
}
impl event::Source for WinTunInterface {
fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> {
self.pipe_server.borrow_mut().register(registry, token, interests)?;
Ok(())
}
fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> {
self.pipe_server.borrow_mut().reregister(registry, token, interests)?;
Ok(())
}
fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
self.pipe_server.borrow_mut().deregister(registry)?;
Ok(())
}
}
impl WinTunInterface {
pub fn new(tun_name: &str, medium: Medium) -> io::Result<WinTunInterface> {
let wintun = unsafe { wintun::load() }.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let guid = 324435345345345345_u128;
let adapter = match wintun::Adapter::open(&wintun, tun_name) {
Ok(a) => a,
Err(_) => {
wintun::Adapter::create(&wintun, tun_name, tun_name, Some(guid)).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
}
};
let session = adapter
.start_session(wintun::MAX_RING_CAPACITY)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let wintun_session = Arc::new(session);
let (pipe_server, pipe_client) = pipe()?;
let pipe_client = Arc::new(Mutex::new(pipe_client));
let pipe_client_cache = Arc::new(Mutex::new(Vec::new()));
let mtu = adapter.get_mtu().map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let reader_session = wintun_session.clone();
let pipe_client_clone = pipe_client.clone();
let pipe_client_cache_clone = pipe_client_cache.clone();
let reader_thread = std::thread::spawn(move || {
let block = || -> Result<(), Box<dyn std::error::Error>> {
loop {
// Take the old data from pipe_client_cache and append the new data
let cached_data = pipe_client_cache_clone.lock()?.drain(..).collect::<Vec<u8>>();
let bytes = if cached_data.len() >= mtu {
// if the cached data is greater than mtu, then sleep 1ms and return the data
std::thread::sleep(std::time::Duration::from_millis(1));
cached_data
} else {
// read data from tunnel interface
let packet = reader_session.receive_blocking()?;
let bytes = packet.bytes().to_vec();
// and append to the end of cached data
cached_data.into_iter().chain(bytes).collect::<Vec<u8>>()
};
if bytes.is_empty() {
continue;
}
let len = bytes.len();
// write data to named pipe_server
let result = { pipe_client_clone.lock()?.write(&bytes) };
match result {
Ok(n) => {
if n < len {
log::trace!("Wintun pipe_client write data {} less than buffer {}", n, len);
pipe_client_cache_clone.lock()?.extend_from_slice(&bytes[n..]);
}
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
log::trace!("Wintun pipe_client write WouldBlock (1) len {}", len);
pipe_client_cache_clone.lock()?.extend_from_slice(&bytes);
}
Err(err) => log::error!("Wintun pipe_client write data len {} error \"{}\"", len, err),
}
}
};
if let Err(err) = block() {
log::trace!("Reader {}", err);
}
});
Ok(WinTunInterface {
wintun_session,
mtu,
medium,
pipe_server: Rc::new(RefCell::new(pipe_server)),
pipe_server_cache: Rc::new(RefCell::new(Vec::new())),
pipe_client,
pipe_client_cache,
wintun_reader_thread: Some(reader_thread),
old_gateway: None,
})
}
pub fn pipe_client(&self) -> Arc<Mutex<NamedPipe>> {
self.pipe_client.clone()
}
pub fn pipe_client_event(&self, event: &event::Event) -> Result<(), io::Error> {
if event.is_readable() {
self.pipe_client_event_readable()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
} else if event.is_writable() {
self.pipe_client_event_writable()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
}
Ok(())
}
fn pipe_client_event_readable(&self) -> Result<(), Box<dyn std::error::Error + '_>> {
let mut reader = self.pipe_client.lock()?;
let mut buffer = vec![0; self.mtu];
loop {
// some data arieved to pipe_client from pipe_server
match reader.read(&mut buffer[..]) {
Ok(len) => match self.wintun_session.allocate_send_packet(len as u16) {
Ok(mut write_pack) => {
write_pack.bytes_mut().copy_from_slice(&buffer[..len]);
// write data to tunnel interface
self.wintun_session.send_packet(write_pack);
}
Err(err) => {
log::error!("Wintun: failed to allocate send packet: {}", err);
}
},
Err(err) if err.kind() == io::ErrorKind::WouldBlock => break,
Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
Err(err) => return Err(err.into()),
}
}
Ok(())
}
fn pipe_client_event_writable(&self) -> Result<(), Box<dyn std::error::Error + '_>> {
let cache = self.pipe_client_cache.lock()?.drain(..).collect::<Vec<u8>>();
if cache.is_empty() {
return Ok(());
}
let len = cache.len();
let result = self.pipe_client.lock()?.write(&cache[..]);
match result {
Ok(n) => {
if n < len {
log::trace!("Wintun pipe_client write data {} less than buffer {}", n, len);
self.pipe_client_cache.lock()?.extend_from_slice(&cache[n..]);
}
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
log::trace!("Wintun pipe_client write WouldBlock (2) len {}", len);
self.pipe_client_cache.lock()?.extend_from_slice(&cache);
}
Err(err) => log::error!("Wintun pipe_client write data len {} error \"{}\"", len, err),
}
Ok(())
}
pub fn setup_config<'a>(
&mut self,
bypass_ips: impl IntoIterator<Item = &'a IpCidr>,
dns_addr: Option<IpAddr>,
) -> Result<(), io::Error> {
let adapter = self.wintun_session.get_adapter();
// Setup the adapter's address/mask/gateway
let address = "10.1.0.33".parse::<IpAddr>().unwrap();
let mask = "255.255.255.0".parse::<IpAddr>().unwrap();
let gateway = "10.1.0.1".parse::<IpAddr>().unwrap();
adapter
.set_network_addresses_tuple(address, mask, Some(gateway))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
// 1. Setup the adapter's DNS
let interface = GUID::from(adapter.get_guid());
let dns = dns_addr.unwrap_or("8.8.8.8".parse::<IpAddr>().unwrap());
let dns2 = "8.8.4.4".parse::<IpAddr>().unwrap();
set_interface_dns_settings(interface, &[dns, dns2])?;
// 2. Route all traffic to the adapter, here the destination is adapter's gateway
// command: `route add 0.0.0.0 mask 0.0.0.0 10.1.0.1 metric 6`
let unspecified = Ipv4Addr::UNSPECIFIED.to_string();
let gateway = gateway.to_string();
let args = &["add", &unspecified, "mask", &unspecified, &gateway, "metric", "6"];
run_command("route", args)?;
log::info!("route {:?}", args);
let old_gateways = get_active_network_interface_gateways()?;
// find ipv4 gateway address, or error return
let old_gateway = old_gateways
.iter()
.find(|addr| addr.is_ipv4())
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "No ipv4 gateway found"))?;
let old_gateway = old_gateway.ip();
self.old_gateway = Some(old_gateway);
// 3. route the bypass ip to the old gateway
// command: `route add bypass_ip old_gateway metric 1`
for bypass_ip in bypass_ips {
let args = &["add", &bypass_ip.to_string(), &old_gateway.to_string(), "metric", "1"];
run_command("route", args)?;
log::info!("route {:?}", args);
}
Ok(())
}
pub fn restore_config(&mut self) -> Result<(), io::Error> {
if self.old_gateway.is_none() {
return Ok(());
}
let unspecified = Ipv4Addr::UNSPECIFIED.to_string();
// 1. Remove current adapter's route
// command: `route delete 0.0.0.0 mask 0.0.0.0`
let args = &["delete", &unspecified, "mask", &unspecified];
run_command("route", args)?;
// 2. Add back the old gateway route
// command: `route add 0.0.0.0 mask 0.0.0.0 old_gateway metric 200`
let old_gateway = self.old_gateway.take().unwrap().to_string();
let args = &["add", &unspecified, "mask", &unspecified, &old_gateway, "metric", "200"];
run_command("route", args)?;
Ok(())
}
}
impl Drop for WinTunInterface {
fn drop(&mut self) {
if let Err(e) = self.restore_config() {
log::error!("Faild to unsetup config: {}", e);
}
if let Err(e) = self.wintun_session.shutdown() {
log::error!("phy: failed to shutdown interface: {}", e);
}
if let Some(thread) = self.wintun_reader_thread.take() {
if let Err(e) = thread.join() {
log::error!("phy: failed to join reader thread: {:?}", e);
}
}
}
}
impl Device for WinTunInterface {
type RxToken<'a> = RxToken;
type TxToken<'a> = TxToken;
fn capabilities(&self) -> DeviceCapabilities {
let mut v = DeviceCapabilities::default();
v.max_transmission_unit = self.mtu;
v.medium = self.medium;
v
}
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
let mut buffer = vec![0; self.mtu];
match self.pipe_server.borrow_mut().read(&mut buffer[..]) {
Ok(size) => {
buffer.resize(size, 0);
let rx = RxToken { buffer };
let tx = TxToken {
pipe_server: self.pipe_server.clone(),
pipe_server_cache: self.pipe_server_cache.clone(),
};
Some((rx, tx))
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => None,
Err(err) => panic!("{}", err),
}
}
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
Some(TxToken {
pipe_server: self.pipe_server.clone(),
pipe_server_cache: self.pipe_server_cache.clone(),
})
}
}
#[doc(hidden)]
pub struct RxToken {
buffer: Vec<u8>,
}
impl phy::RxToken for RxToken {
fn consume<R, F>(mut self, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
f(&mut self.buffer[..])
}
}
#[doc(hidden)]
pub struct TxToken {
pipe_server: Rc<RefCell<NamedPipe>>,
pipe_server_cache: Rc<RefCell<Vec<u8>>>,
}
impl phy::TxToken for TxToken {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
let mut buffer = vec![0; len];
let result = f(&mut buffer);
let buffer = self.pipe_server_cache.borrow_mut().drain(..).chain(buffer).collect::<Vec<_>>();
if buffer.is_empty() {
// log::trace!("Wintun TxToken (pipe_server) is empty");
return result;
}
let len = buffer.len();
match self.pipe_server.borrow_mut().write(&buffer[..]) {
Ok(n) => {
if n < len {
log::trace!("Wintun TxToken (pipe_server) sent {} less than buffer len {}", n, len);
self.pipe_server_cache.borrow_mut().extend_from_slice(&buffer[n..]);
}
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
self.pipe_server_cache.borrow_mut().extend_from_slice(&buffer[..]);
log::trace!("Wintun TxToken (pipe_server) WouldBlock data len: {}", len)
}
Err(err) => log::error!("Wintun TxToken (pipe_server) len {} error \"{}\"", len, err),
}
result
}
}
pub struct NamedPipeSource(pub Arc<Mutex<NamedPipe>>);
impl event::Source for NamedPipeSource {
fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> {
self.0
.lock()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?
.register(registry, token, interests)
}
fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> {
self.0
.lock()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?
.reregister(registry, token, interests)
}
fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
self.0
.lock()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?
.deregister(registry)
}
}
pub(crate) fn run_command(command: &str, args: &[&str]) -> io::Result<()> {
let out = std::process::Command::new(command).args(args).output()?;
if !out.status.success() {
let err = String::from_utf8_lossy(if out.stderr.is_empty() { &out.stdout } else { &out.stderr });
let info = format!("{} failed with: \"{}\"", command, err);
return Err(std::io::Error::new(std::io::ErrorKind::Other, info));
}
Ok(())
}
pub(crate) fn set_interface_dns_settings(interface: GUID, dns: &[IpAddr]) -> io::Result<()> {
// format L"1.1.1.1 8.8.8.8", or L"1.1.1.1,8.8.8.8".
let dns = dns.iter().map(|ip| ip.to_string()).collect::<Vec<_>>().join(",");
let dns = dns.encode_utf16().chain(std::iter::once(0)).collect::<Vec<_>>();
let settings = DNS_INTERFACE_SETTINGS {
Version: DNS_INTERFACE_SETTINGS_VERSION1,
Flags: DNS_SETTING_NAMESERVER as _,
NameServer: PWSTR(dns.as_ptr() as _),
..DNS_INTERFACE_SETTINGS::default()
};
unsafe { SetInterfaceDnsSettings(interface, &settings as *const _)? };
Ok(())
}
pub(crate) fn get_active_network_interface_gateways() -> io::Result<Vec<SocketAddr>> {
let mut addrs = vec![];
get_adapters_addresses(|adapter| {
if adapter.OperStatus == IfOperStatusUp && [IF_TYPE_ETHERNET_CSMACD, IF_TYPE_IEEE80211].contains(&adapter.IfType) {
let mut current_gateway = adapter.FirstGatewayAddress;
while !current_gateway.is_null() {
let gateway = unsafe { &*current_gateway };
{
let sockaddr_ptr = gateway.Address.lpSockaddr;
let sockaddr = unsafe { &*(sockaddr_ptr as *const SOCKADDR) };
let a = unsafe { sockaddr_to_socket_addr(sockaddr) }?;
addrs.push(a);
}
current_gateway = gateway.Next;
}
}
Ok(())
})?;
Ok(addrs)
}
pub(crate) fn get_adapters_addresses<F>(mut callback: F) -> io::Result<()>
where
F: FnMut(IP_ADAPTER_ADDRESSES_LH) -> io::Result<()>,
{
let mut size = 0;
let flags = GAA_FLAG_INCLUDE_PREFIX | GAA_FLAG_INCLUDE_GATEWAYS;
let family = AF_UNSPEC.0 as u32;
// Make an initial call to GetAdaptersAddresses to get the
// size needed into the size variable
let result = unsafe { GetAdaptersAddresses(family, flags, None, None, &mut size) };
if WIN32_ERROR(result) != ERROR_BUFFER_OVERFLOW {
WIN32_ERROR(result).ok()?;
}
// Allocate memory for the buffer
let mut addresses: Vec<u8> = vec![0; (size + 4) as usize];
// Make a second call to GetAdaptersAddresses to get the actual data we want
let result = unsafe {
let addr = Some(addresses.as_mut_ptr() as *mut IP_ADAPTER_ADDRESSES_LH);
GetAdaptersAddresses(family, flags, None, addr, &mut size)
};
WIN32_ERROR(result).ok()?;
// If successful, output some information from the data we received
let mut current_addresses = addresses.as_ptr() as *const IP_ADAPTER_ADDRESSES_LH;
while !current_addresses.is_null() {
unsafe {
callback(*current_addresses)?;
current_addresses = (*current_addresses).Next;
}
}
Ok(())
}
pub(crate) unsafe fn sockaddr_to_socket_addr(sock_addr: *const SOCKADDR) -> io::Result<SocketAddr> {
let address = match (*sock_addr).sa_family {
AF_INET => sockaddr_in_to_socket_addr(&*(sock_addr as *const SOCKADDR_IN)),
AF_INET6 => sockaddr_in6_to_socket_addr(&*(sock_addr as *const SOCKADDR_IN6)),
_ => return Err(io::Error::new(io::ErrorKind::Other, "Unsupported address type")),
};
Ok(address)
}
pub(crate) unsafe fn sockaddr_in_to_socket_addr(sockaddr_in: &SOCKADDR_IN) -> SocketAddr {
let ip = Ipv4Addr::new(
sockaddr_in.sin_addr.S_un.S_un_b.s_b1,
sockaddr_in.sin_addr.S_un.S_un_b.s_b2,
sockaddr_in.sin_addr.S_un.S_un_b.s_b3,
sockaddr_in.sin_addr.S_un.S_un_b.s_b4,
);
let port = u16::from_be(sockaddr_in.sin_port);
SocketAddr::new(ip.into(), port)
}
pub(crate) unsafe fn sockaddr_in6_to_socket_addr(sockaddr_in6: &SOCKADDR_IN6) -> SocketAddr {
let ip = IpAddr::V6(Ipv6Addr::new(
u16::from_be(sockaddr_in6.sin6_addr.u.Word[0]),
u16::from_be(sockaddr_in6.sin6_addr.u.Word[1]),
u16::from_be(sockaddr_in6.sin6_addr.u.Word[2]),
u16::from_be(sockaddr_in6.sin6_addr.u.Word[3]),
u16::from_be(sockaddr_in6.sin6_addr.u.Word[4]),
u16::from_be(sockaddr_in6.sin6_addr.u.Word[5]),
u16::from_be(sockaddr_in6.sin6_addr.u.Word[6]),
u16::from_be(sockaddr_in6.sin6_addr.u.Word[7]),
));
let port = u16::from_be(sockaddr_in6.sin6_port);
SocketAddr::new(ip, port)
}