diff --git a/.gitignore b/.gitignore index 4ac1fec..bc020c7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +examples/ build/ tmp/ .* diff --git a/src/error.rs b/src/error.rs index 86596f4..1a57783 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,6 +27,9 @@ pub enum Error { #[error("smoltcp::socket::tcp::SendError {0:?}")] Send(#[from] smoltcp::socket::tcp::SendError), + #[error("smoltcp::wire::Error {0:?}")] + Wire(#[from] smoltcp::wire::Error), + #[error("std::str::Utf8Error {0:?}")] Utf8(#[from] std::str::Utf8Error), diff --git a/src/http.rs b/src/http.rs index d5ff54f..5c5d76b 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,7 +1,7 @@ use crate::{ error::Error, tun2proxy::{ - Connection, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, + ConnectionInfo, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection, TcpProxy, }, }; @@ -63,8 +63,8 @@ static CONTENT_LENGTH: &str = "Content-Length"; impl HttpConnection { fn new( - connection: &Connection, - manager: Rc, + info: &ConnectionInfo, + credentials: Option, digest_state: Rc>>, ) -> Result { let mut res = Self { @@ -79,8 +79,8 @@ impl HttpConnection { crlf_state: 0, digest_state, before: false, - credentials: manager.get_credentials().clone(), - destination: connection.dst.clone(), + credentials, + destination: info.dst.clone(), }; res.send_tunnel_request()?; @@ -394,28 +394,24 @@ pub(crate) struct HttpManager { } impl ConnectionManager for HttpManager { - fn handles_connection(&self, connection: &Connection) -> bool { - connection.proto == IpProtocol::Tcp + fn handles_connection(&self, info: &ConnectionInfo) -> bool { + info.protocol == IpProtocol::Tcp } - fn new_connection( - &self, - connection: &Connection, - manager: Rc, - ) -> Result>, Error> { - if connection.proto != IpProtocol::Tcp { - return Ok(None); + fn new_tcp_proxy(&self, info: &ConnectionInfo) -> Result, Error> { + if info.protocol != IpProtocol::Tcp { + return Err("Invalid protocol".into()); } - Ok(Some(Box::new(HttpConnection::new( - connection, - manager, + Ok(Box::new(HttpConnection::new( + info, + self.credentials.clone(), self.digest_state.clone(), - )?))) + )?)) } - fn close_connection(&self, _: &Connection) {} + fn close_connection(&self, _: &ConnectionInfo) {} - fn get_server(&self) -> SocketAddr { + fn get_server_addr(&self) -> SocketAddr { self.server } @@ -425,11 +421,11 @@ impl ConnectionManager for HttpManager { } impl HttpManager { - pub fn new(server: SocketAddr, credentials: Option) -> Rc { - Rc::new(Self { + pub fn new(server: SocketAddr, credentials: Option) -> Self { + Self { server, credentials, digest_state: Rc::new(RefCell::new(None)), - }) + } } } diff --git a/src/lib.rs b/src/lib.rs index 2c743c0..98f69c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,10 @@ -use crate::{ - error::Error, http::HttpManager, socks::SocksManager, socks::SocksVersion, - tun2proxy::TunToProxy, +use crate::{error::Error, http::HttpManager, socks::SocksProxyManager, tun2proxy::TunToProxy}; +use socks5_impl::protocol::{UserKey, Version}; +use std::{ + net::{SocketAddr, ToSocketAddrs}, + rc::Rc, }; -use socks5_impl::protocol::UserKey; -use std::net::{SocketAddr, ToSocketAddrs}; +use tun2proxy::ConnectionManager; mod android; pub mod error; @@ -90,7 +91,7 @@ impl std::fmt::Display for ProxyType { #[derive(Default)] pub struct Options { - virtdns: Option, + virtual_dns: Option, mtu: Option, } @@ -100,7 +101,7 @@ impl Options { } pub fn with_virtual_dns(mut self) -> Self { - self.virtdns = Some(virtdns::VirtualDns::new()); + self.virtual_dns = Some(virtdns::VirtualDns::new()); self } @@ -116,25 +117,18 @@ pub fn tun_to_proxy<'a>( options: Options, ) -> Result, Error> { let mut ttp = TunToProxy::new(interface, options)?; - match proxy.proxy_type { - ProxyType::Socks4 => { - ttp.add_connection_manager(SocksManager::new( - proxy.addr, - SocksVersion::V4, - proxy.credentials.clone(), - )); - } - ProxyType::Socks5 => { - ttp.add_connection_manager(SocksManager::new( - proxy.addr, - SocksVersion::V5, - proxy.credentials.clone(), - )); - } + let credentials = proxy.credentials.clone(); + let server = proxy.addr; + let mgr = match proxy.proxy_type { + ProxyType::Socks4 => Rc::new(SocksProxyManager::new(server, Version::V4, credentials)) + as Rc, + ProxyType::Socks5 => Rc::new(SocksProxyManager::new(server, Version::V5, credentials)) + as Rc, ProxyType::Http => { - ttp.add_connection_manager(HttpManager::new(proxy.addr, proxy.credentials.clone())); + Rc::new(HttpManager::new(server, credentials)) as Rc } - } + }; + ttp.add_connection_manager(mgr); Ok(ttp) } @@ -143,6 +137,7 @@ pub fn main_entry( proxy: &Proxy, options: Options, ) -> Result<(), Error> { - let ttp = tun_to_proxy(interface, proxy, options); - ttp?.run() + let mut ttp = tun_to_proxy(interface, proxy, options)?; + ttp.run()?; + Ok(()) } diff --git a/src/socks.rs b/src/socks.rs index 6c57f26..d5fd6ad 100644 --- a/src/socks.rs +++ b/src/socks.rs @@ -1,15 +1,15 @@ use crate::{ error::Error, tun2proxy::{ - Connection, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, + ConnectionInfo, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection, TcpProxy, }, }; use smoltcp::wire::IpProtocol; use socks5_impl::protocol::{ - self, handshake, password_method, Address, AuthMethod, StreamOperation, UserKey, + self, handshake, password_method, Address, AuthMethod, StreamOperation, UserKey, Version, }; -use std::{collections::VecDeque, net::SocketAddr, rc::Rc}; +use std::{collections::VecDeque, net::SocketAddr}; #[derive(Eq, PartialEq, Debug)] #[allow(dead_code)] @@ -23,33 +23,28 @@ enum SocksState { Established, } -#[repr(u8)] -#[derive(Copy, Clone, PartialEq, Debug)] -pub enum SocksVersion { - V4 = 4, - V5 = 5, -} - -pub(crate) struct SocksConnection { - connection: Connection, +struct SocksProxyImpl { + info: ConnectionInfo, state: SocksState, client_inbuf: VecDeque, server_inbuf: VecDeque, client_outbuf: VecDeque, server_outbuf: VecDeque, data_buf: VecDeque, - version: SocksVersion, + version: Version, credentials: Option, + command: protocol::Command, + udp_relay_addr: Option
, } -impl SocksConnection { +impl SocksProxyImpl { pub fn new( - connection: &Connection, - manager: Rc, - version: SocksVersion, + info: &ConnectionInfo, + credentials: Option, + version: Version, ) -> Result { let mut result = Self { - connection: connection.clone(), + info: info.clone(), state: SocksState::ServerHello, client_inbuf: VecDeque::default(), server_inbuf: VecDeque::default(), @@ -57,59 +52,71 @@ impl SocksConnection { server_outbuf: VecDeque::default(), data_buf: VecDeque::default(), version, - credentials: manager.get_credentials().clone(), + credentials, + command: protocol::Command::Connect, + udp_relay_addr: None, }; result.send_client_hello()?; Ok(result) } - fn send_client_hello(&mut self) -> Result<(), Error> { + fn send_client_hello_socks4(&mut self) -> Result<(), Error> { let credentials = &self.credentials; - match self.version { - SocksVersion::V4 => { - self.server_outbuf - .extend(&[self.version as u8, protocol::Command::Connect.into()]); - self.server_outbuf - .extend(self.connection.dst.port().to_be_bytes()); - let mut ip_vec = Vec::::new(); - let mut name_vec = Vec::::new(); - match &self.connection.dst { - Address::SocketAddress(SocketAddr::V4(addr)) => { - ip_vec.extend(addr.ip().octets().as_ref()); - } - Address::SocketAddress(SocketAddr::V6(_)) => { - return Err("SOCKS4 does not support IPv6".into()); - } - Address::DomainAddress(host, _) => { - ip_vec.extend(&[0, 0, 0, host.len() as u8]); - name_vec.extend(host.as_bytes()); - name_vec.push(0); - } - } - self.server_outbuf.extend(ip_vec); - if let Some(credentials) = credentials { - self.server_outbuf.extend(credentials.username.as_bytes()); - if !credentials.password.is_empty() { - self.server_outbuf.push_back(b':'); - self.server_outbuf.extend(credentials.password.as_bytes()); - } - } - self.server_outbuf.push_back(0); - self.server_outbuf.extend(name_vec); + self.server_outbuf + .extend(&[self.version as u8, protocol::Command::Connect.into()]); + self.server_outbuf + .extend(self.info.dst.port().to_be_bytes()); + let mut ip_vec = Vec::::new(); + let mut name_vec = Vec::::new(); + match &self.info.dst { + Address::SocketAddress(SocketAddr::V4(addr)) => { + ip_vec.extend(addr.ip().octets().as_ref()); } + Address::SocketAddress(SocketAddr::V6(_)) => { + return Err("SOCKS4 does not support IPv6".into()); + } + Address::DomainAddress(host, _) => { + ip_vec.extend(&[0, 0, 0, host.len() as u8]); + name_vec.extend(host.as_bytes()); + name_vec.push(0); + } + } + self.server_outbuf.extend(ip_vec); + if let Some(credentials) = credentials { + self.server_outbuf.extend(credentials.username.as_bytes()); + if !credentials.password.is_empty() { + self.server_outbuf.push_back(b':'); + self.server_outbuf.extend(credentials.password.as_bytes()); + } + } + self.server_outbuf.push_back(0); + self.server_outbuf.extend(name_vec); + Ok(()) + } - SocksVersion::V5 => { - // Providing unassigned methods is supposed to bypass China's GFW. - // For details, refer to https://github.com/blechschmidt/tun2proxy/issues/35. - let mut methods = vec![ - AuthMethod::NoAuth, - AuthMethod::from(4_u8), - AuthMethod::from(100_u8), - ]; - if credentials.is_some() { - methods.push(AuthMethod::UserPass); - } - handshake::Request::new(methods).write_to_stream(&mut self.server_outbuf)?; + fn send_client_hello_socks5(&mut self) -> Result<(), Error> { + let credentials = &self.credentials; + // Providing unassigned methods is supposed to bypass China's GFW. + // For details, refer to https://github.com/blechschmidt/tun2proxy/issues/35. + let mut methods = vec![ + AuthMethod::NoAuth, + AuthMethod::from(4_u8), + AuthMethod::from(100_u8), + ]; + if credentials.is_some() { + methods.push(AuthMethod::UserPass); + } + handshake::Request::new(methods).write_to_stream(&mut self.server_outbuf)?; + Ok(()) + } + + fn send_client_hello(&mut self) -> Result<(), Error> { + match self.version { + Version::V4 => { + self.send_client_hello_socks4()?; + } + Version::V5 => { + self.send_client_hello_socks5()?; } } self.state = SocksState::ServerHello; @@ -164,8 +171,8 @@ impl SocksConnection { fn receive_server_hello(&mut self) -> Result<(), Error> { match self.version { - SocksVersion::V4 => self.receive_server_hello_socks4(), - SocksVersion::V5 => self.receive_server_hello_socks5(), + Version::V4 => self.receive_server_hello_socks4(), + Version::V5 => self.receive_server_hello_socks5(), } } @@ -213,6 +220,12 @@ impl SocksConnection { if response.reply != protocol::Reply::Succeeded { return Err(format!("SOCKS connection failed: {}", response.reply).into()); } + + if self.command == protocol::Command::UdpAssociate { + log::info!("UDP packet destination: {}", response.address); + self.udp_relay_addr = Some(response.address); + } + self.server_outbuf.append(&mut self.data_buf); self.data_buf.clear(); @@ -220,8 +233,9 @@ impl SocksConnection { self.state_change() } - fn send_request(&mut self) -> Result<(), Error> { - protocol::Request::new(protocol::Command::Connect, self.connection.dst.clone()) + fn send_request_socks5(&mut self) -> Result<(), Error> { + // self.server_outbuf.extend(&[self.version as u8, self.command as u8, 0]); + protocol::Request::new(protocol::Command::Connect, self.info.dst.clone()) .write_to_stream(&mut self.server_outbuf)?; self.state = SocksState::ReceiveResponse; self.state_change() @@ -243,7 +257,7 @@ impl SocksConnection { SocksState::ReceiveAuthResponse => self.receive_auth_data(), - SocksState::SendRequest => self.send_request(), + SocksState::SendRequest => self.send_request_socks5(), SocksState::ReceiveResponse => self.receive_connection_status(), @@ -254,7 +268,7 @@ impl SocksConnection { } } -impl TcpProxy for SocksConnection { +impl TcpProxy for SocksProxyImpl { fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), Error> { let direction = event.direction; let buffer = event.buffer; @@ -319,35 +333,31 @@ impl TcpProxy for SocksConnection { } } -pub struct SocksManager { +pub(crate) struct SocksProxyManager { server: SocketAddr, credentials: Option, - version: SocksVersion, + version: Version, } -impl ConnectionManager for SocksManager { - fn handles_connection(&self, connection: &Connection) -> bool { - connection.proto == IpProtocol::Tcp +impl ConnectionManager for SocksProxyManager { + fn handles_connection(&self, info: &ConnectionInfo) -> bool { + info.protocol == IpProtocol::Tcp } - fn new_connection( - &self, - connection: &Connection, - manager: Rc, - ) -> Result>, Error> { - if connection.proto != IpProtocol::Tcp { - return Ok(None); + fn new_tcp_proxy(&self, info: &ConnectionInfo) -> Result, Error> { + if info.protocol != IpProtocol::Tcp { + return Err("Invalid protocol".into()); } - Ok(Some(Box::new(SocksConnection::new( - connection, - manager, + Ok(Box::new(SocksProxyImpl::new( + info, + self.credentials.clone(), self.version, - )?))) + )?)) } - fn close_connection(&self, _: &Connection) {} + fn close_connection(&self, _: &ConnectionInfo) {} - fn get_server(&self) -> SocketAddr { + fn get_server_addr(&self) -> SocketAddr { self.server } @@ -356,16 +366,12 @@ impl ConnectionManager for SocksManager { } } -impl SocksManager { - pub fn new( - server: SocketAddr, - version: SocksVersion, - credentials: Option, - ) -> Rc { - Rc::new(Self { +impl SocksProxyManager { + pub(crate) fn new(server: SocketAddr, version: Version, credentials: Option) -> Self { + Self { server, credentials, version, - }) + } } } diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index d5af74a..29675e2 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -1,11 +1,11 @@ -use crate::{error::Error, virtdevice::VirtualTunDevice, NetworkInterface, Options}; +use crate::{error::Error, error::Result, virtdevice::VirtualTunDevice, NetworkInterface, Options}; use mio::{event::Event, net::TcpStream, unix::SourceFd, Events, Interest, Poll, Token}; use smoltcp::{ iface::{Config, Interface, SocketHandle, SocketSet}, phy::{Device, Medium, RxToken, TunTapInterface, TxToken}, socket::{tcp, tcp::State, udp, udp::UdpMetadata}, time::Instant, - wire::{IpCidr, IpProtocol, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket}, + wire::{IpCidr, IpProtocol, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket, UDP_HEADER_LEN}, }; use socks5_impl::protocol::{Address, UserKey}; use std::{ @@ -19,24 +19,40 @@ use std::{ }; #[derive(Hash, Clone, Eq, PartialEq, Debug)] -pub(crate) struct Connection { +pub(crate) struct ConnectionInfo { pub(crate) src: SocketAddr, pub(crate) dst: Address, - pub(crate) proto: IpProtocol, + pub(crate) protocol: IpProtocol, } -impl Connection { +impl Default for ConnectionInfo { + fn default() -> Self { + Self { + src: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), + dst: Address::unspecified(), + protocol: IpProtocol::Tcp, + } + } +} + +impl ConnectionInfo { + #[allow(dead_code)] + pub fn new(src: SocketAddr, dst: Address, protocol: IpProtocol) -> Self { + Self { src, dst, protocol } + } + fn to_named(&self, name: String) -> Self { let mut result = self.clone(); result.dst = Address::from((name, result.dst.port())); - log::trace!("Replace dst \"{}\" -> \"{}\"", self.dst, result.dst); + // let p = self.protocol; + // log::trace!("{p} replace dst \"{}\" -> \"{}\"", self.dst, result.dst); result } } -impl std::fmt::Display for Connection { +impl std::fmt::Display for ConnectionInfo { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{} -> {}", self.src, self.dst) + write!(f, "{} {} -> {}", self.protocol, self.src, self.dst) } } @@ -60,10 +76,11 @@ pub(crate) enum Direction { #[allow(dead_code)] pub(crate) enum ConnectionEvent<'a> { - NewConnection(&'a Connection), - ConnectionClosed(&'a Connection), + NewConnection(&'a ConnectionInfo), + ConnectionClosed(&'a ConnectionInfo), } +#[derive(Debug)] pub(crate) struct DataEvent<'a, T> { pub(crate) direction: T, pub(crate) buffer: &'a [u8], @@ -73,95 +90,87 @@ pub(crate) type IncomingDataEvent<'a> = DataEvent<'a, IncomingDirection>; pub(crate) type OutgoingDataEvent<'a> = DataEvent<'a, OutgoingDirection>; fn get_transport_info( - proto: IpProtocol, + protocol: IpProtocol, transport_offset: usize, packet: &[u8], -) -> Option<((u16, u16), bool, usize, usize)> { - match proto { - IpProtocol::Udp => match UdpPacket::new_checked(packet) { - Ok(result) => Some(( - (result.src_port(), result.dst_port()), - false, - transport_offset + 8, - packet.len() - 8, - )), - Err(_) => None, - }, - IpProtocol::Tcp => match TcpPacket::new_checked(packet) { - Ok(result) => Some(( - (result.src_port(), result.dst_port()), - result.syn() && !result.ack(), - transport_offset + result.header_len() as usize, - packet.len(), - )), - Err(_) => None, - }, - _ => None, +) -> Result<((u16, u16), bool, usize, usize)> { + match protocol { + IpProtocol::Udp => UdpPacket::new_checked(packet) + .map(|result| { + ( + (result.src_port(), result.dst_port()), + false, + transport_offset + UDP_HEADER_LEN, + packet.len() - UDP_HEADER_LEN, + ) + }) + .map_err(|e| e.into()), + IpProtocol::Tcp => TcpPacket::new_checked(packet) + .map(|result| { + ( + (result.src_port(), result.dst_port()), + result.syn() && !result.ack(), + transport_offset + result.header_len() as usize, + packet.len(), + ) + }) + .map_err(|e| e.into()), + _ => Err(format!("Unsupported protocol {protocol}").into()), } } -fn connection_tuple(frame: &[u8]) -> Option<(Connection, bool, usize, usize)> { +fn connection_tuple(frame: &[u8]) -> Result<(ConnectionInfo, bool, usize, usize)> { if let Ok(packet) = Ipv4Packet::new_checked(frame) { - let proto = packet.next_header(); + let protocol = packet.next_header(); let mut a = [0_u8; 4]; a.copy_from_slice(packet.src_addr().as_bytes()); let src_addr = IpAddr::from(a); a.copy_from_slice(packet.dst_addr().as_bytes()); let dst_addr = IpAddr::from(a); + let header_len = packet.header_len().into(); - return if let Some((ports, first_packet, payload_offset, payload_size)) = get_transport_info( - proto, - packet.header_len().into(), - &frame[packet.header_len().into()..], - ) { - let connection = Connection { - src: SocketAddr::new(src_addr, ports.0), - dst: SocketAddr::new(dst_addr, ports.1).into(), - proto, - }; - Some((connection, first_packet, payload_offset, payload_size)) - } else { - None + let (ports, first_packet, payload_offset, payload_size) = + get_transport_info(protocol, header_len, &frame[header_len..])?; + let info = ConnectionInfo { + src: SocketAddr::new(src_addr, ports.0), + dst: SocketAddr::new(dst_addr, ports.1).into(), + protocol, }; + return Ok((info, first_packet, payload_offset, payload_size)); } - match Ipv6Packet::new_checked(frame) { - Ok(packet) => { - // TODO: Support extension headers. - let proto = packet.next_header(); + if let Ok(packet) = Ipv6Packet::new_checked(frame) { + // TODO: Support extension headers. + let protocol = packet.next_header(); - let mut a = [0_u8; 16]; - a.copy_from_slice(packet.src_addr().as_bytes()); - let src_addr = IpAddr::from(a); - a.copy_from_slice(packet.dst_addr().as_bytes()); - let dst_addr = IpAddr::from(a); + let mut a = [0_u8; 16]; + a.copy_from_slice(packet.src_addr().as_bytes()); + let src_addr = IpAddr::from(a); + a.copy_from_slice(packet.dst_addr().as_bytes()); + let dst_addr = IpAddr::from(a); + let header_len = packet.header_len(); - if let Some((ports, first_packet, payload_offset, payload_size)) = - get_transport_info(proto, packet.header_len(), &frame[packet.header_len()..]) - { - let connection = Connection { - src: SocketAddr::new(src_addr, ports.0), - dst: SocketAddr::new(dst_addr, ports.1).into(), - proto, - }; - Some((connection, first_packet, payload_offset, payload_size)) - } else { - None - } - } - _ => None, + let (ports, first_packet, payload_offset, payload_size) = + get_transport_info(protocol, header_len, &frame[header_len..])?; + let info = ConnectionInfo { + src: SocketAddr::new(src_addr, ports.0), + dst: SocketAddr::new(dst_addr, ports.1).into(), + protocol, + }; + return Ok((info, first_packet, payload_offset, payload_size)); } + Err("Neither IPv6 nor IPv4 packet".into()) } const SERVER_WRITE_CLOSED: u8 = 1; const CLIENT_WRITE_CLOSED: u8 = 2; -struct ConnectionState { - smoltcp_handle: SocketHandle, +struct TcpConnectState { + smoltcp_handle: Option, mio_stream: TcpStream, token: Token, - handler: Box, + tcp_proxy_handler: Box, close_state: u8, wait_read: bool, wait_write: bool, @@ -176,30 +185,30 @@ pub(crate) trait TcpProxy { fn reset_connection(&self) -> bool; } +pub(crate) trait UdpProxy { + fn send_frame(&mut self, destination: &Address, frame: &[u8]) -> Result<(), Error>; + fn receive_frame(&mut self, source: &SocketAddr, frame: &[u8]) -> Result<(), Error>; +} + pub(crate) trait ConnectionManager { - fn handles_connection(&self, connection: &Connection) -> bool; - fn new_connection( - &self, - connection: &Connection, - manager: Rc, - ) -> Result>, Error>; - fn close_connection(&self, connection: &Connection); - fn get_server(&self) -> SocketAddr; + fn handles_connection(&self, info: &ConnectionInfo) -> bool; + fn new_tcp_proxy(&self, info: &ConnectionInfo) -> Result, Error>; + fn close_connection(&self, info: &ConnectionInfo); + fn get_server_addr(&self) -> SocketAddr; fn get_credentials(&self) -> &Option; } const TUN_TOKEN: Token = Token(0); -const UDP_TOKEN: Token = Token(1); const EXIT_TOKEN: Token = Token(2); pub struct TunToProxy<'a> { tun: TunTapInterface, poll: Poll, iface: Interface, - connections: HashMap, + connection_map: HashMap, connection_managers: Vec>, next_token: usize, - token_to_connection: HashMap, + token_to_info: HashMap, sockets: SocketSet<'a>, device: VirtualTunDevice, options: Options, @@ -234,10 +243,10 @@ impl<'a> TunToProxy<'a> { Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), Medium::Ieee802154 => todo!(), }; - let mut virt = VirtualTunDevice::new(tun.capabilities()); + let mut device = VirtualTunDevice::new(tun.capabilities()); let gateway4: Ipv4Addr = Ipv4Addr::from_str("0.0.0.1")?; let gateway6: Ipv6Addr = Ipv6Addr::from_str("::1")?; - let mut iface = Interface::new(config, &mut virt, Instant::now()); + let mut iface = Interface::new(config, &mut device, Instant::now()); iface.update_ip_addrs(|ip_addrs| { ip_addrs.push(IpCidr::new(gateway4.into(), 0)).unwrap(); ip_addrs.push(IpCidr::new(gateway6.into(), 0)).unwrap() @@ -250,12 +259,12 @@ impl<'a> TunToProxy<'a> { tun, poll, iface, - connections: HashMap::default(), + connection_map: HashMap::default(), next_token: usize::from(EXIT_TOKEN) + 1, - token_to_connection: HashMap::default(), + token_to_info: HashMap::default(), connection_managers: Vec::default(), sockets: SocketSet::new([]), - device: virt, + device, options, write_sockets: HashSet::default(), _exit_receiver: exit_receiver, @@ -292,28 +301,34 @@ impl<'a> TunToProxy<'a> { Ok(()) } - fn remove_connection(&mut self, connection: &Connection) -> Result<(), Error> { - if let Some(mut conn) = self.connections.remove(connection) { + fn remove_connection(&mut self, info: &ConnectionInfo) -> Result<(), Error> { + if let Some(mut conn) = self.connection_map.remove(info) { + _ = conn.mio_stream.shutdown(Both); + if let Some(handle) = conn.smoltcp_handle { + let socket = self.sockets.get_mut::(handle); + socket.close(); + self.sockets.remove(handle); + } + self.expect_smoltcp_send()?; let token = &conn.token; - self.token_to_connection.remove(token); - self.sockets.remove(conn.smoltcp_handle); + self.token_to_info.remove(token); _ = self.poll.registry().deregister(&mut conn.mio_stream); - log::info!("CLOSE {}", connection); + log::info!("CLOSE {}", info); } Ok(()) } - fn get_connection_manager(&self, connection: &Connection) -> Option> { + fn get_connection_manager(&self, info: &ConnectionInfo) -> Option> { for manager in self.connection_managers.iter() { - if manager.handles_connection(connection) { + if manager.handles_connection(info) { return Some(manager.clone()); } } None } - fn check_change_close_state(&mut self, connection: &Connection) -> Result<(), Error> { - let state = self.connections.get_mut(connection); + fn check_change_close_state(&mut self, info: &ConnectionInfo) -> Result<(), Error> { + let state = self.connection_map.get_mut(info); if state.is_none() { return Ok(()); } @@ -321,23 +336,25 @@ impl<'a> TunToProxy<'a> { let mut closed_ends = 0; if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED && !state - .handler + .tcp_proxy_handler .have_data(Direction::Incoming(IncomingDirection::FromServer)) && !state - .handler + .tcp_proxy_handler .have_data(Direction::Outgoing(OutgoingDirection::ToClient)) { - let socket = self.sockets.get_mut::(state.smoltcp_handle); - socket.close(); + if let Some(socket_handle) = state.smoltcp_handle { + let socket = self.sockets.get_mut::(socket_handle); + socket.close(); + } closed_ends += 1; } if (state.close_state & CLIENT_WRITE_CLOSED) == CLIENT_WRITE_CLOSED && !state - .handler + .tcp_proxy_handler .have_data(Direction::Incoming(IncomingDirection::FromClient)) && !state - .handler + .tcp_proxy_handler .have_data(Direction::Outgoing(OutgoingDirection::ToServer)) { _ = state.mio_stream.shutdown(Shutdown::Write); @@ -345,20 +362,22 @@ impl<'a> TunToProxy<'a> { } if closed_ends == 2 { - self.remove_connection(connection)?; + self.remove_connection(info)?; } Ok(()) } - fn tunsocket_read_and_forward(&mut self, connection: &Connection) -> Result<(), Error> { + fn tunsocket_read_and_forward(&mut self, info: &ConnectionInfo) -> Result<(), Error> { // Scope for mutable borrow of self. { - let state = self.connections.get_mut(connection); - if state.is_none() { - return Ok(()); - } - let state = state.unwrap(); - let socket = self.sockets.get_mut::(state.smoltcp_handle); + let state = match self.connection_map.get_mut(info) { + Some(state) => state, + None => return Ok(()), + }; + let socket = match state.smoltcp_handle { + Some(handle) => self.sockets.get_mut::(handle), + None => return Ok(()), + }; let mut error = Ok(()); while socket.can_recv() && error.is_ok() { socket.recv(|data| { @@ -366,7 +385,7 @@ impl<'a> TunToProxy<'a> { direction: IncomingDirection::FromClient, buffer: data, }; - error = state.handler.push_data(event); + error = state.tcp_proxy_handler.push_data(event); (data.len(), ()) })?; } @@ -385,20 +404,14 @@ impl<'a> TunToProxy<'a> { self.expect_smoltcp_send()?; } - self.check_change_close_state(connection)?; + self.check_change_close_state(info)?; Ok(()) } - // Update the poll registry depending on the connection's event interests. - fn update_mio_socket_interest(&mut self, connection: &Connection) -> Result<(), Error> { - let state = self - .connections - .get_mut(connection) - .ok_or("connection not found")?; - + fn update_mio_socket_interest(poll: &mut Poll, state: &mut TcpConnectState) -> Result<()> { // Maybe we did not listen for any events before. Therefore, just swallow the error. - _ = self.poll.registry().deregister(&mut state.mio_stream); + _ = poll.registry().deregister(&mut state.mio_stream); // If we do not wait for read or write events, we do not need to register them. if !state.wait_read && !state.wait_write { @@ -415,150 +428,131 @@ impl<'a> TunToProxy<'a> { interest = Interest::READABLE | Interest::WRITABLE; } - self.poll - .registry() + poll.registry() .register(&mut state.mio_stream, state.token, interest)?; Ok(()) } // A raw packet was received on the tunnel interface. fn receive_tun(&mut self, frame: &mut [u8]) -> Result<(), Error> { - if let Some((connection, first_packet, offset, size)) = connection_tuple(frame) { - let resolved_conn = match &mut self.options.virtdns { - None => connection.clone(), - Some(virt_dns) => { - let ip = SocketAddr::try_from(connection.dst.clone())?.ip(); - virt_dns.touch_ip(&ip); - match virt_dns.resolve_ip(&ip) { - None => connection.clone(), - Some(name) => connection.to_named(name.clone()), + let mut handler = || -> Result<(), Error> { + let (info, first_packet, payload_offset, payload_size) = connection_tuple(frame)?; + let dst = SocketAddr::try_from(&info.dst)?; + let connection_info = match &mut self.options.virtual_dns { + None => info.clone(), + Some(virtual_dns) => { + let dst_ip = dst.ip(); + virtual_dns.touch_ip(&dst_ip); + match virtual_dns.resolve_ip(&dst_ip) { + None => info.clone(), + Some(name) => info.to_named(name.clone()), } } }; - let dst = connection.dst; - let handler = || -> Result<(), Error> { - if resolved_conn.proto == IpProtocol::Tcp { - let cm = self.get_connection_manager(&resolved_conn); - if cm.is_none() { - log::trace!("no connect manager"); - return Ok(()); + log::trace!("{} ({})", connection_info, dst); + if connection_info.protocol == IpProtocol::Tcp { + let server_addr = self + .get_connection_manager(&connection_info) + .ok_or("get_connection_manager")? + .get_server_addr(); + if first_packet { + if let Some(manager) = self.connection_managers.iter_mut().next() { + let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info)?; + let mut socket = tcp::Socket::new( + tcp::SocketBuffer::new(vec![0; 1024 * 128]), + tcp::SocketBuffer::new(vec![0; 1024 * 128]), + ); + socket.set_ack_delay(None); + socket.listen(dst)?; + let handle = self.sockets.add(socket); + + let mut client = TcpStream::connect(server_addr)?; + let token = self.new_token(); + let i = Interest::READABLE; + self.poll.registry().register(&mut client, token, i)?; + + let state = TcpConnectState { + smoltcp_handle: Some(handle), + mio_stream: client, + token, + tcp_proxy_handler, + close_state: 0, + wait_read: true, + wait_write: false, + }; + self.connection_map.insert(connection_info.clone(), state); + + self.token_to_info.insert(token, connection_info.clone()); + + // log::info!("CONNECT {} ({})", connection_info, dst); } - let server = cm.unwrap().get_server(); - if first_packet { - for manager in self.connection_managers.iter_mut() { - if let Some(handler) = - manager.new_connection(&resolved_conn, manager.clone())? - { - let mut socket = tcp::Socket::new( - tcp::SocketBuffer::new(vec![0; 1024 * 128]), - tcp::SocketBuffer::new(vec![0; 1024 * 128]), - ); - socket.set_ack_delay(None); - let dst = SocketAddr::try_from(dst)?; - socket.listen(dst)?; - let handle = self.sockets.add(socket); - - let client = TcpStream::connect(server)?; - - let token = self.new_token(); - - let mut state = ConnectionState { - smoltcp_handle: handle, - mio_stream: client, - token, - handler, - close_state: 0, - wait_read: true, - wait_write: false, - }; - - self.token_to_connection - .insert(token, resolved_conn.clone()); - self.poll.registry().register( - &mut state.mio_stream, - token, - Interest::READABLE, - )?; - - self.connections.insert(resolved_conn.clone(), state); - - log::info!("CONNECT {}", resolved_conn,); - break; - } - } - } else if !self.connections.contains_key(&resolved_conn) { - return Ok(()); - } - - // Inject the packet to advance the smoltcp socket state - self.device.inject_packet(frame); - - // Having advanced the socket state, we expect the socket to ACK - // Exfiltrate the response packets generated by the socket and inject them - // into the tunnel interface. - self.expect_smoltcp_send()?; - - // Read from the smoltcp socket and push the data to the connection handler. - self.tunsocket_read_and_forward(&resolved_conn)?; - - // The connection handler builds up the connection or encapsulates the data. - // Therefore, we now expect it to write data to the server. - self.write_to_server(&resolved_conn)?; - } else if resolved_conn.proto == IpProtocol::Udp && resolved_conn.dst.port() == 53 { - if let Some(virtual_dns) = &mut self.options.virtdns { - let payload = &frame[offset..offset + size]; - if let Some(response) = virtual_dns.receive_query(payload) { - let rx_buffer = udp::PacketBuffer::new( - vec![udp::PacketMetadata::EMPTY], - vec![0; 4096], - ); - let tx_buffer = udp::PacketBuffer::new( - vec![udp::PacketMetadata::EMPTY], - vec![0; 4096], - ); - let mut socket = udp::Socket::new(rx_buffer, tx_buffer); - let dst = SocketAddr::try_from(dst)?; - socket.bind(dst)?; - socket - .send_slice( - response.as_slice(), - UdpMetadata::from(resolved_conn.src), - ) - .expect("failed to send DNS response"); - let handle = self.sockets.add(socket); - self.expect_smoltcp_send()?; - self.sockets.remove(handle); - } - } - // Otherwise, UDP is not yet supported. + } else if !self.connection_map.contains_key(&connection_info) { + return Ok(()); } - Ok::<(), Error>(()) - }; - if let Err(error) = handler() { - log::error!("{}", error); + + // Inject the packet to advance the smoltcp socket state + self.device.inject_packet(frame); + + // Having advanced the socket state, we expect the socket to ACK + // Exfiltrate the response packets generated by the socket and inject them + // into the tunnel interface. + self.expect_smoltcp_send()?; + + // Read from the smoltcp socket and push the data to the connection handler. + self.tunsocket_read_and_forward(&connection_info)?; + + // The connection handler builds up the connection or encapsulates the data. + // Therefore, we now expect it to write data to the server. + self.write_to_server(&connection_info)?; + } else if connection_info.protocol == IpProtocol::Udp { + let port = connection_info.dst.port(); + if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == 53) { + let payload = &frame[payload_offset..payload_offset + payload_size]; + if let Some(response) = virtual_dns.receive_query(payload) { + let rx_buffer = + udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); + let tx_buffer = + udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); + let mut socket = udp::Socket::new(rx_buffer, tx_buffer); + socket.bind(dst)?; + socket + .send_slice(response.as_slice(), UdpMetadata::from(connection_info.src)) + .expect("failed to send DNS response"); + let handle = self.sockets.add(socket); + self.expect_smoltcp_send()?; + self.sockets.remove(handle); + } + } + // Otherwise, UDP is not yet supported. } + Ok::<(), Error>(()) + }; + if let Err(error) = handler() { + log::error!("{}", error); } Ok(()) } - fn write_to_server(&mut self, connection: &Connection) -> Result<(), Error> { - if let Some(state) = self.connections.get_mut(connection) { - let event = state.handler.peek_data(OutgoingDirection::ToServer); + fn write_to_server(&mut self, info: &ConnectionInfo) -> Result<(), Error> { + if let Some(state) = self.connection_map.get_mut(info) { + let event = state + .tcp_proxy_handler + .peek_data(OutgoingDirection::ToServer); let buffer_size = event.buffer.len(); if buffer_size == 0 { state.wait_write = false; - self.update_mio_socket_interest(connection)?; - self.check_change_close_state(connection)?; + Self::update_mio_socket_interest(&mut self.poll, state)?; + self.check_change_close_state(info)?; return Ok(()); } let result = state.mio_stream.write(event.buffer); match result { Ok(written) => { state - .handler + .tcp_proxy_handler .consume_data(OutgoingDirection::ToServer, written); state.wait_write = written < buffer_size; - self.update_mio_socket_interest(connection)?; + Self::update_mio_socket_interest(&mut self.poll, state)?; } Err(error) if error.kind() != std::io::ErrorKind::WouldBlock => { return Err(error.into()); @@ -566,30 +560,35 @@ impl<'a> TunToProxy<'a> { _ => { // WOULDBLOCK case state.wait_write = true; - self.update_mio_socket_interest(connection)?; + Self::update_mio_socket_interest(&mut self.poll, state)?; } } } - self.check_change_close_state(connection)?; + self.check_change_close_state(info)?; Ok(()) } - fn write_to_client(&mut self, token: Token, connection: &Connection) -> Result<(), Error> { - while let Some(state) = self.connections.get_mut(connection) { - let socket_handle = state.smoltcp_handle; - let event = state.handler.peek_data(OutgoingDirection::ToClient); + fn write_to_client(&mut self, token: Token, info: &ConnectionInfo) -> Result<(), Error> { + while let Some(state) = self.connection_map.get_mut(info) { + let socket_handle = match state.smoltcp_handle { + Some(handle) => handle, + None => break, + }; + let event = state + .tcp_proxy_handler + .peek_data(OutgoingDirection::ToClient); let buflen = event.buffer.len(); let consumed; { let socket = self.sockets.get_mut::(socket_handle); if socket.may_send() { - if let Some(virtdns) = &mut self.options.virtdns { + if let Some(virtual_dns) = &mut self.options.virtual_dns { // Unwrapping is fine because every smoltcp socket is bound to an. - virtdns.touch_ip(&IpAddr::from(socket.local_endpoint().unwrap().addr)); + virtual_dns.touch_ip(&IpAddr::from(socket.local_endpoint().unwrap().addr)); } consumed = socket.send_slice(event.buffer)?; state - .handler + .tcp_proxy_handler .consume_data(OutgoingDirection::ToClient, consumed); self.expect_smoltcp_send()?; if consumed < buflen { @@ -606,7 +605,7 @@ impl<'a> TunToProxy<'a> { } } - self.check_change_close_state(connection)?; + self.check_change_close_state(info)?; } Ok(()) } @@ -623,7 +622,7 @@ impl<'a> TunToProxy<'a> { fn send_to_smoltcp(&mut self) -> Result<(), Error> { let cloned = self.write_sockets.clone(); for token in cloned.iter() { - if let Some(connection) = self.token_to_connection.get(token) { + if let Some(connection) = self.token_to_info.get(token) { let connection = connection.clone(); if let Err(error) = self.write_to_client(*token, &connection) { self.remove_connection(&connection)?; @@ -636,24 +635,26 @@ impl<'a> TunToProxy<'a> { fn mio_socket_event(&mut self, event: &Event) -> Result<(), Error> { let e = "connection not found"; - let conn_ref = self.token_to_connection.get(&event.token()); - // We may have closed the connection in an earlier iteration over the poll - // events, e.g. because an event through the tunnel interface indicated that the connection - // should be closed. - if conn_ref.is_none() { - log::trace!("{e}"); - return Ok(()); - } - let connection = conn_ref.unwrap().clone(); + let conn_info = match self.token_to_info.get(&event.token()) { + Some(conn_info) => conn_info.clone(), + None => { + // We may have closed the connection in an earlier iteration over the poll events, + // e.g. because an event through the tunnel interface indicated that the connection + // should be closed. + log::trace!("{e}"); + return Ok(()); + } + }; + let server = self - .get_connection_manager(&connection) - .unwrap() - .get_server(); + .get_connection_manager(&conn_info) + .ok_or(e)? + .get_server_addr(); let mut block = || -> Result<(), Error> { if event.is_readable() || event.is_read_closed() { { - let state = self.connections.get_mut(&connection).ok_or(e)?; + let state = self.connection_map.get_mut(&conn_info).ok_or(e)?; // TODO: Move this reading process to its own function. let mut vecbuf = Vec::::new(); @@ -673,34 +674,26 @@ impl<'a> TunToProxy<'a> { direction: IncomingDirection::FromServer, buffer: &data[0..read], }; - if let Err(error) = state.handler.push_data(data_event) { - state.mio_stream.shutdown(Both)?; - { - let socket = self.sockets.get_mut::( - self.connections.get(&connection).ok_or(e)?.smoltcp_handle, - ); - socket.close(); - } - self.expect_smoltcp_send()?; - log::error! {"{error}"} - self.remove_connection(&connection.clone())?; + if let Err(error) = state.tcp_proxy_handler.push_data(data_event) { + log::error!("{}", error); + self.remove_connection(&conn_info.clone())?; return Ok(()); } // The handler request for reset the server connection - if state.handler.reset_connection() { + if state.tcp_proxy_handler.reset_connection() { _ = self.poll.registry().deregister(&mut state.mio_stream); // Closes the connection with the proxy state.mio_stream.shutdown(Both)?; - log::info!("RESET {}", connection); + log::info!("RESET {}", conn_info); state.mio_stream = TcpStream::connect(server)?; state.wait_read = true; state.wait_write = true; - self.update_mio_socket_interest(&connection)?; + Self::update_mio_socket_interest(&mut self.poll, state)?; return Ok(()); } @@ -708,61 +701,54 @@ impl<'a> TunToProxy<'a> { if read == 0 || event.is_read_closed() { state.wait_read = false; state.close_state |= SERVER_WRITE_CLOSED; - self.update_mio_socket_interest(&connection)?; - self.check_change_close_state(&connection)?; + Self::update_mio_socket_interest(&mut self.poll, state)?; + self.check_change_close_state(&conn_info)?; self.expect_smoltcp_send()?; } } // We have read from the proxy server and pushed the data to the connection handler. // Thus, expect data to be processed (e.g. decapsulated) and forwarded to the client. - self.write_to_client(event.token(), &connection)?; + self.write_to_client(event.token(), &conn_info)?; // The connection handler could have produced data that is to be written to the // server. - self.write_to_server(&connection)?; + self.write_to_server(&conn_info)?; } if event.is_writable() { - self.write_to_server(&connection)?; + self.write_to_server(&conn_info)?; } Ok::<(), Error>(()) }; if let Err(error) = block() { log::error!("{}", error); - self.remove_connection(&connection)?; + self.remove_connection(&conn_info)?; } Ok(()) } - fn udp_event(&mut self, _event: &Event) {} - pub fn run(&mut self) -> Result<(), Error> { let mut events = Events::with_capacity(1024); loop { - match self.poll.poll(&mut events, None) { - Ok(()) => { - for event in events.iter() { - match event.token() { - EXIT_TOKEN => { - log::info!("exiting..."); - return Ok(()); - } - TUN_TOKEN => self.tun_event(event)?, - UDP_TOKEN => self.udp_event(event), - _ => self.mio_socket_event(event)?, - } - } - self.send_to_smoltcp()?; + if let Err(err) = self.poll.poll(&mut events, None) { + if err.kind() == std::io::ErrorKind::Interrupted { + log::warn!("Poll interrupted: \"{err}\", ignored, continue polling"); + continue; } - Err(e) => { - if e.kind() == std::io::ErrorKind::Interrupted { - log::warn!("Poll interrupted: \"{e}\", ignored, continue polling"); - } else { - return Err(e.into()); + return Err(err.into()); + } + for event in events.iter() { + match event.token() { + EXIT_TOKEN => { + log::info!("Exiting tun2proxy..."); + return Ok(()); } + TUN_TOKEN => self.tun_event(event)?, + _ => self.mio_socket_event(event)?, } } + self.send_to_smoltcp()?; } }