diff --git a/src/http.rs b/src/http.rs index c3aec4d..8f6e6b4 100644 --- a/src/http.rs +++ b/src/http.rs @@ -4,8 +4,10 @@ use crate::tun2proxy::{ OutgoingDataEvent, OutgoingDirection, TcpProxy, }; use base64::Engine; +use smoltcp::wire::IpProtocol; use std::collections::VecDeque; use std::net::SocketAddr; +use std::rc::Rc; #[derive(Eq, PartialEq, Debug)] #[allow(dead_code)] @@ -27,7 +29,7 @@ pub struct HttpConnection { } impl HttpConnection { - fn new(connection: &Connection, manager: std::rc::Rc) -> Self { + fn new(connection: &Connection, manager: Rc) -> Self { let mut server_outbuf: VecDeque = VecDeque::new(); { let credentials = manager.get_credentials(); @@ -163,26 +165,24 @@ impl TcpProxy for HttpConnection { } pub struct HttpManager { - server: std::net::SocketAddr, + server: SocketAddr, credentials: Option, } impl ConnectionManager for HttpManager { fn handles_connection(&self, connection: &Connection) -> bool { - connection.proto == smoltcp::wire::IpProtocol::Tcp.into() + connection.proto == IpProtocol::Tcp.into() } fn new_connection( &self, connection: &Connection, - manager: std::rc::Rc, - ) -> Option> { - if connection.proto != smoltcp::wire::IpProtocol::Tcp.into() { + manager: Rc, + ) -> Option> { + if connection.proto != IpProtocol::Tcp.into() { return None; } - Some(std::boxed::Box::new(HttpConnection::new( - connection, manager, - ))) + Some(Box::new(HttpConnection::new(connection, manager))) } fn close_connection(&self, _: &Connection) {} @@ -197,8 +197,8 @@ impl ConnectionManager for HttpManager { } impl HttpManager { - pub fn new(server: SocketAddr, credentials: Option) -> std::rc::Rc { - std::rc::Rc::new(Self { + pub fn new(server: SocketAddr, credentials: Option) -> Rc { + Rc::new(Self { server, credentials, }) diff --git a/src/lib.rs b/src/lib.rs index df2910d..e3f869d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -86,5 +86,7 @@ pub fn main_entry(tun: &str, proxy: Proxy, options: Options) { ttp.add_connection_manager(HttpManager::new(proxy.addr, proxy.credentials)); } } - ttp.run(); + if let Err(e) = ttp.run() { + log::error!("{e}"); + } } diff --git a/src/socks5.rs b/src/socks5.rs index e2f27ce..ee8cb37 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -3,8 +3,10 @@ use crate::tun2proxy::{ Connection, ConnectionManager, Credentials, DestinationHost, IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection, TcpProxy, }; +use smoltcp::wire::IpProtocol; use std::collections::VecDeque; use std::net::{IpAddr, SocketAddr}; +use std::rc::Rc; #[derive(Eq, PartialEq, Debug)] #[allow(dead_code)] @@ -60,11 +62,11 @@ pub(crate) struct SocksConnection { client_outbuf: VecDeque, server_outbuf: VecDeque, data_buf: VecDeque, - manager: std::rc::Rc, + manager: Rc, } impl SocksConnection { - pub fn new(connection: &Connection, manager: std::rc::Rc) -> Self { + pub fn new(connection: &Connection, manager: Rc) -> Self { let mut result = Self { connection: connection.clone(), state: SocksState::ServerHello, @@ -291,26 +293,24 @@ impl TcpProxy for SocksConnection { } pub struct Socks5Manager { - server: std::net::SocketAddr, + server: SocketAddr, credentials: Option, } impl ConnectionManager for Socks5Manager { fn handles_connection(&self, connection: &Connection) -> bool { - connection.proto == smoltcp::wire::IpProtocol::Tcp.into() + connection.proto == IpProtocol::Tcp.into() } fn new_connection( &self, connection: &Connection, - manager: std::rc::Rc, - ) -> Option> { - if connection.proto != smoltcp::wire::IpProtocol::Tcp.into() { + manager: Rc, + ) -> Option> { + if connection.proto != IpProtocol::Tcp.into() { return None; } - Some(std::boxed::Box::new(SocksConnection::new( - connection, manager, - ))) + Some(Box::new(SocksConnection::new(connection, manager))) } fn close_connection(&self, _: &Connection) {} @@ -325,8 +325,8 @@ impl ConnectionManager for Socks5Manager { } impl Socks5Manager { - pub fn new(server: SocketAddr, credentials: Option) -> std::rc::Rc { - std::rc::Rc::new(Self { + pub fn new(server: SocketAddr, credentials: Option) -> Rc { + Rc::new(Self { server, credentials, }) diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index 94ac443..5bf3834 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -12,15 +12,17 @@ use smoltcp::phy::{Device, Medium, RxToken, TunTapInterface, TxToken}; use smoltcp::socket::tcp; use smoltcp::time::Instant; use smoltcp::wire::{ - IpAddress, IpCidr, Ipv4Address, Ipv4Packet, Ipv6Address, Ipv6Packet, TcpPacket, UdpPacket, + IpAddress, IpCidr, IpProtocol, Ipv4Address, Ipv4Packet, Ipv6Address, Ipv6Packet, TcpPacket, + UdpPacket, }; use std::collections::{HashMap, HashSet}; -use std::convert::From; +use std::convert::{From, TryFrom}; use std::fmt::{Display, Formatter}; use std::io::{Read, Write}; use std::net::Shutdown::Both; use std::net::{IpAddr, Shutdown, SocketAddr}; use std::os::unix::io::AsRawFd; +use std::rc::Rc; #[derive(Hash, Clone, Eq, PartialEq)] pub enum DestinationHost { @@ -43,17 +45,18 @@ pub(crate) struct Destination { pub(crate) port: u16, } -impl From for SocketAddr { - fn from(value: Destination) -> Self { - SocketAddr::new( +impl TryFrom for SocketAddr { + type Error = Error; + fn try_from(value: Destination) -> Result { + Ok(SocketAddr::new( match value.host { DestinationHost::Address(addr) => addr, - DestinationHost::Hostname(_) => { - panic!("Failed to convert hostname destination into socket address") + DestinationHost::Hostname(e) => { + return Err(e.into()); } }, value.port, - ) + )) } } @@ -94,8 +97,8 @@ impl Display for Destination { } #[derive(Hash, Clone, Eq, PartialEq)] -pub struct Connection { - pub(crate) src: std::net::SocketAddr, +pub(crate) struct Connection { + pub(crate) src: SocketAddr, pub(crate) dst: Destination, pub(crate) proto: u8, } @@ -145,7 +148,7 @@ fn get_transport_info( transport_offset: usize, packet: &[u8], ) -> Option<((u16, u16), bool, usize, usize)> { - if proto == smoltcp::wire::IpProtocol::Udp.into() { + if proto == IpProtocol::Udp.into() { match UdpPacket::new_checked(packet) { Ok(result) => Some(( (result.src_port(), result.dst_port()), @@ -155,7 +158,7 @@ fn get_transport_info( )), Err(_) => None, } - } else if proto == smoltcp::wire::IpProtocol::Tcp.into() { + } else if proto == IpProtocol::Tcp.into() { match TcpPacket::new_checked(packet) { Ok(result) => Some(( (result.src_port(), result.dst_port()), @@ -230,7 +233,7 @@ struct ConnectionState { smoltcp_handle: SocketHandle, mio_stream: TcpStream, token: Token, - handler: std::boxed::Box, + handler: Box, smoltcp_socket_state: u8, } @@ -261,8 +264,8 @@ pub(crate) trait ConnectionManager { fn new_connection( &self, connection: &Connection, - manager: std::rc::Rc, - ) -> Option>; + manager: Rc, + ) -> Option>; fn close_connection(&self, connection: &Connection); fn get_server(&self) -> SocketAddr; fn get_credentials(&self) -> &Option; @@ -283,15 +286,15 @@ impl Options { self } } +const TCP_TOKEN: Token = Token(0); +const UDP_TOKEN: Token = Token(1); pub(crate) struct TunToProxy<'a> { tun: TunTapInterface, poll: Poll, - tun_token: Token, - udp_token: Token, iface: Interface, connections: HashMap, - connection_managers: Vec>, + connection_managers: Vec>, next_token: usize, token_to_connection: HashMap, sockets: SocketSet<'a>, @@ -302,13 +305,12 @@ pub(crate) struct TunToProxy<'a> { impl<'a> TunToProxy<'a> { pub(crate) fn new(interface: &str, options: Options) -> Self { - let tun_token = Token(0); let tun = TunTapInterface::new(interface, Medium::Ip).unwrap(); let poll = Poll::new().unwrap(); poll.registry() .register( &mut SourceFd(&tun.as_raw_fd()), - tun_token, + TCP_TOKEN, Interest::READABLE, ) .unwrap(); @@ -337,8 +339,6 @@ impl<'a> TunToProxy<'a> { Self { tun, poll, - tun_token, - udp_token: Token(1), iface, connections: Default::default(), next_token: 2, @@ -351,7 +351,7 @@ impl<'a> TunToProxy<'a> { } } - pub(crate) fn add_connection_manager(&mut self, manager: std::rc::Rc) { + pub(crate) fn add_connection_manager(&mut self, manager: Rc) { self.connection_managers.push(manager); } @@ -383,10 +383,7 @@ impl<'a> TunToProxy<'a> { info!("CLOSE {}", connection); } - fn get_connection_manager( - &self, - connection: &Connection, - ) -> Option> { + fn get_connection_manager(&self, connection: &Connection) -> Option> { for manager in self.connection_managers.iter() { if manager.handles_connection(connection) { return Some(manager.clone()); @@ -450,7 +447,7 @@ impl<'a> TunToProxy<'a> { } } }; - if resolved_conn.proto == smoltcp::wire::IpProtocol::Tcp.into() { + if resolved_conn.proto == IpProtocol::Tcp.into() { let cm = self.get_connection_manager(&resolved_conn); if cm.is_none() { return; @@ -466,10 +463,8 @@ impl<'a> TunToProxy<'a> { smoltcp::socket::tcp::SocketBuffer::new(vec![0; 4096]), ); socket.set_ack_delay(None); - let dst = connection.dst; - socket - .listen(>::into(dst)) - .unwrap(); + let dst = SocketAddr::try_from(connection.dst).unwrap(); + socket.listen(dst).unwrap(); let handle = self.sockets.add(socket); let client = TcpStream::connect(server).unwrap(); @@ -520,7 +515,7 @@ impl<'a> TunToProxy<'a> { // The connection handler builds up the connection or encapsulates the data. // Therefore, we now expect it to write data to the server. self.write_to_server(&resolved_conn); - } else if resolved_conn.proto == smoltcp::wire::IpProtocol::Udp.into() { + } else if resolved_conn.proto == IpProtocol::Udp.into() { if let Some(virtual_dns) = &mut self.options.virtdns { let payload = &frame[_payload_offset.._payload_offset + _payload_size]; if let Some(response) = virtual_dns.receive_query(payload) { @@ -533,10 +528,8 @@ impl<'a> TunToProxy<'a> { vec![0; 4096], ); let mut socket = smoltcp::socket::udp::Socket::new(rx_buffer, tx_buffer); - let dst = resolved_conn.dst.clone(); - socket - .bind(>::into(dst)) - .unwrap(); + let dst = SocketAddr::try_from(connection.dst).unwrap(); + socket.bind(dst).unwrap(); socket .send_slice(response.as_slice(), resolved_conn.src.into()) .expect("failed to send DNS response"); @@ -700,18 +693,16 @@ impl<'a> TunToProxy<'a> { fn udp_event(&mut self, _event: &Event) {} - pub(crate) fn run(&mut self) { + pub(crate) fn run(&mut self) -> Result<(), Error> { let mut events = Events::with_capacity(1024); loop { - self.poll.poll(&mut events, None).unwrap(); + self.poll.poll(&mut events, None)?; for event in events.iter() { - if event.token() == self.tun_token { - self.tun_event(event); - } else if event.token() == self.udp_token { - self.udp_event(event); - } else { - self.mio_socket_event(event); + match event.token() { + TCP_TOKEN => self.tun_event(event), + UDP_TOKEN => self.udp_event(event), + _ => self.mio_socket_event(event), } } self.send_to_smoltcp();