diff --git a/src/error.rs b/src/error.rs index 00a73cc..b7ecc3b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,33 +9,49 @@ pub fn s2e(s: &str) -> Error { impl From for Error { fn from(err: std::io::Error) -> Self { - Self { - message: err.to_string(), - } + From::::from(err.to_string()) } } impl From for Error { fn from(err: std::net::AddrParseError) -> Self { - Self { - message: err.to_string(), - } + From::::from(err.to_string()) } } impl From for Error { fn from(err: smoltcp::iface::RouteTableFull) -> Self { - Self { - message: format!("{err:?}"), - } + From::::from(format!("{err:?}")) + } +} + +impl From for Error { + fn from(err: smoltcp::socket::tcp::RecvError) -> Self { + From::::from(format!("{err:?}")) + } +} + +impl From for Error { + fn from(err: smoltcp::socket::tcp::ListenError) -> Self { + From::::from(format!("{err:?}")) + } +} + +impl From for Error { + fn from(err: smoltcp::socket::udp::BindError) -> Self { + From::::from(format!("{err:?}")) + } +} + +impl From for Error { + fn from(err: smoltcp::socket::tcp::SendError) -> Self { + From::::from(format!("{err:?}")) } } impl From<&str> for Error { fn from(err: &str) -> Self { - Self { - message: err.to_string(), - } + From::::from(err.to_string()) } } @@ -47,9 +63,7 @@ impl From for Error { impl From<&String> for Error { fn from(err: &String) -> Self { - Self { - message: err.to_string(), - } + From::::from(err.to_string()) } } diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index 4dcf7e5..a67643d 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -9,7 +9,7 @@ use mio::unix::SourceFd; use mio::{Events, Interest, Poll, Token}; use smoltcp::iface::{Config, Interface, SocketHandle, SocketSet}; use smoltcp::phy::{Device, Medium, RxToken, TunTapInterface, TxToken}; -use smoltcp::socket::tcp; +use smoltcp::socket::{tcp, udp}; use smoltcp::time::Instant; use smoltcp::wire::{IpCidr, IpProtocol, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket}; use std::collections::{HashMap, HashSet}; @@ -312,14 +312,14 @@ impl<'a> TunToProxy<'a> { tun, poll, iface, - connections: Default::default(), + connections: HashMap::default(), next_token: 2, - token_to_connection: Default::default(), - connection_managers: Default::default(), + token_to_connection: HashMap::default(), + connection_managers: Vec::default(), sockets: SocketSet::new([]), device: virt, options, - write_sockets: Default::default(), + write_sockets: HashSet::default(), }; Ok(tun) } @@ -328,7 +328,7 @@ impl<'a> TunToProxy<'a> { self.connection_managers.push(manager); } - fn expect_smoltcp_send(&mut self) { + fn expect_smoltcp_send(&mut self) -> Result<(), Error> { self.iface .poll(Instant::now(), &mut self.device, &mut self.sockets); @@ -338,22 +338,22 @@ impl<'a> TunToProxy<'a> { // TODO: Actual write. Replace. self.tun .transmit(Instant::now()) - .unwrap() + .ok_or("tx token not available")? .consume(slice.len(), |buf| { buf[..].clone_from_slice(slice); }); } + Ok(()) } - fn remove_connection(&mut self, connection: &Connection) { - let mut connection_state = self.connections.remove(connection).unwrap(); - let token = &connection_state.token; + fn remove_connection(&mut self, connection: &Connection) -> Result<(), Error> { + let e = "connection not exist"; + let mut conn = self.connections.remove(connection).ok_or(e)?; + let token = &conn.token; self.token_to_connection.remove(token); - self.poll - .registry() - .deregister(&mut connection_state.mio_stream) - .unwrap(); + self.poll.registry().deregister(&mut conn.mio_stream)?; info!("CLOSE {}", connection); + Ok(()) } fn get_connection_manager(&self, connection: &Connection) -> Option> { @@ -365,23 +365,21 @@ impl<'a> TunToProxy<'a> { None } - fn tunsocket_read_and_forward(&mut self, connection: &Connection) { + fn tunsocket_read_and_forward(&mut self, connection: &Connection) -> Result<(), Error> { if let Some(state) = self.connections.get_mut(connection) { let closed = { let socket = self.sockets.get_mut::(state.smoltcp_handle); let mut error = Ok(()); while socket.can_recv() && error.is_ok() { - socket - .recv(|data| { - let event = IncomingDataEvent { - direction: IncomingDirection::FromClient, - buffer: data, - }; - error = state.handler.push_data(event); + socket.recv(|data| { + let event = IncomingDataEvent { + direction: IncomingDirection::FromClient, + buffer: data, + }; + error = state.handler.push_data(event); - (data.len(), ()) - }) - .unwrap(); + (data.len(), ()) + })?; } match error { @@ -394,24 +392,26 @@ impl<'a> TunToProxy<'a> { }; // Expect ACKs etc. from smoltcp sockets. - self.expect_smoltcp_send(); + self.expect_smoltcp_send()?; if closed { - let connection_state = self.connections.get_mut(connection).unwrap(); - connection_state.mio_stream.shutdown(Both).unwrap(); - self.remove_connection(connection); + let e = "connection not exist"; + let connection_state = self.connections.get_mut(connection).ok_or(e)?; + connection_state.mio_stream.shutdown(Both)?; + self.remove_connection(connection)?; } } + Ok(()) } - fn receive_tun(&mut self, frame: &mut [u8]) { + fn receive_tun(&mut self, frame: &mut [u8]) -> Result<(), Error> { if let Some((connection, first_packet, _payload_offset, _payload_size)) = connection_tuple(frame) { let resolved_conn = match &self.options.virtdns { None => connection.clone(), Some(virt_dns) => { - let ip = SocketAddr::try_from(connection.dst.clone()).unwrap().ip(); + let ip = SocketAddr::try_from(connection.dst.clone())?.ip(); match virt_dns.ip_to_name(&ip) { None => connection.clone(), Some(name) => connection.to_named(name.clone()), @@ -421,9 +421,9 @@ impl<'a> TunToProxy<'a> { if resolved_conn.proto == IpProtocol::Tcp.into() { let cm = self.get_connection_manager(&resolved_conn); if cm.is_none() { - return; + return Ok(()); } - let server = cm.unwrap().get_server(); + let server = cm.ok_or("no connect manager")?.get_server(); if first_packet { for manager in self.connection_managers.iter_mut() { if let Some(handler) = @@ -434,11 +434,11 @@ impl<'a> TunToProxy<'a> { tcp::SocketBuffer::new(vec![0; 4096]), ); socket.set_ack_delay(None); - let dst = SocketAddr::try_from(connection.dst).unwrap(); - socket.listen(dst).unwrap(); + let dst = SocketAddr::try_from(connection.dst)?; + socket.listen(dst)?; let handle = self.sockets.add(socket); - let client = TcpStream::connect(server).unwrap(); + let client = TcpStream::connect(server)?; let token = Token(self.next_token); self.next_token += 1; @@ -453,14 +453,11 @@ impl<'a> TunToProxy<'a> { self.token_to_connection .insert(token, resolved_conn.clone()); - self.poll - .registry() - .register( - &mut state.mio_stream, - token, - Interest::READABLE | Interest::WRITABLE, - ) - .unwrap(); + self.poll.registry().register( + &mut state.mio_stream, + token, + Interest::READABLE | Interest::WRITABLE, + )?; self.connections.insert(resolved_conn.clone(), state); @@ -469,7 +466,7 @@ impl<'a> TunToProxy<'a> { } } } else if !self.connections.contains_key(&resolved_conn) { - return; + return Ok(()); } // Inject the packet to advance the smoltcp socket state @@ -478,10 +475,10 @@ impl<'a> TunToProxy<'a> { // Having advanced the socket state, we expect the socket to ACK // Exfiltrate the response packets generated by the socket and inject them // into the tunnel interface. - self.expect_smoltcp_send(); + self.expect_smoltcp_send()?; // Read from the smoltcp socket and push the data to the connection handler. - self.tunsocket_read_and_forward(&resolved_conn); + self.tunsocket_read_and_forward(&resolved_conn)?; // The connection handler builds up the connection or encapsulates the data. // Therefore, we now expect it to write data to the server. @@ -491,28 +488,25 @@ impl<'a> TunToProxy<'a> { if let Some(virtual_dns) = &mut self.options.virtdns { let payload = &frame[_payload_offset.._payload_offset + _payload_size]; if let Some(response) = virtual_dns.receive_query(payload) { - let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( - vec![smoltcp::socket::udp::PacketMetadata::EMPTY], - vec![0; 4096], - ); - let tx_buffer = smoltcp::socket::udp::PacketBuffer::new( - vec![smoltcp::socket::udp::PacketMetadata::EMPTY], - vec![0; 4096], - ); - let mut socket = smoltcp::socket::udp::Socket::new(rx_buffer, tx_buffer); - let dst = SocketAddr::try_from(connection.dst).unwrap(); - socket.bind(dst).unwrap(); + let rx_buffer = + udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); + let tx_buffer = + udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); + let mut socket = udp::Socket::new(rx_buffer, tx_buffer); + let dst = SocketAddr::try_from(connection.dst)?; + socket.bind(dst)?; socket .send_slice(response.as_slice(), resolved_conn.src.into()) .expect("failed to send DNS response"); let handle = self.sockets.add(socket); - self.expect_smoltcp_send(); + self.expect_smoltcp_send()?; self.sockets.remove(handle); } } // Otherwise, UDP is not yet supported. } } + Ok(()) } fn write_to_server(&mut self, connection: &Connection) { @@ -538,7 +532,7 @@ impl<'a> TunToProxy<'a> { } } - fn write_to_client(&mut self, token: Token, connection: &Connection) { + fn write_to_client(&mut self, token: Token, connection: &Connection) -> Result<(), Error> { loop { if let Some(state) = self.connections.get_mut(connection) { let socket_state = state.smoltcp_socket_state; @@ -549,11 +543,11 @@ impl<'a> TunToProxy<'a> { { let socket = self.sockets.get_mut::(socket_handle); if socket.may_send() { - consumed = socket.send_slice(event.buffer).unwrap(); + consumed = socket.send_slice(event.buffer)?; state .handler .consume_data(OutgoingDirection::ToClient, consumed); - self.expect_smoltcp_send(); + self.expect_smoltcp_send()?; if consumed < buflen { self.write_sockets.insert(token); break; @@ -570,97 +564,102 @@ impl<'a> TunToProxy<'a> { let socket = self.sockets.get_mut::(socket_handle); if socket_state & WRITE_CLOSED != 0 && consumed == buflen { socket.close(); - self.expect_smoltcp_send(); + self.expect_smoltcp_send()?; self.write_sockets.remove(&token); - self.remove_connection(connection); + self.remove_connection(connection)?; break; } } } + Ok(()) } - fn tun_event(&mut self, event: &Event) { + fn tun_event(&mut self, event: &Event) -> Result<(), Error> { if event.is_readable() { while let Some((rx_token, _)) = self.tun.receive(Instant::now()) { - rx_token.consume(|frame| { - self.receive_tun(frame); - }); + rx_token.consume(|frame| self.receive_tun(frame))?; } } + Ok(()) } - fn send_to_smoltcp(&mut self) { + fn send_to_smoltcp(&mut self) -> Result<(), Error> { let cloned = self.write_sockets.clone(); for token in cloned.iter() { if let Some(connection) = self.token_to_connection.get(token) { - self.write_to_client(*token, &connection.clone()); + self.write_to_client(*token, &connection.clone())?; } } + Ok(()) } - fn mio_socket_event(&mut self, event: &Event) { - if let Some(conn_ref) = self.token_to_connection.get(&event.token()) { - let connection = conn_ref.clone(); - if event.is_readable() || event.is_read_closed() { - { - let state = self.connections.get_mut(&connection).unwrap(); + fn mio_socket_event(&mut self, event: &Event) -> Result<(), Error> { + let e = "connection not found"; + let conn_ref = self.token_to_connection.get(&event.token()); + if conn_ref.is_none() { + return Ok(()); + } + let connection = conn_ref.ok_or(e)?.clone(); + if event.is_readable() || event.is_read_closed() { + { + let state = self.connections.get_mut(&connection).ok_or(e)?; - // TODO: Move this reading process to its own function. - let mut vecbuf = Vec::::new(); - let read_result = state.mio_stream.read_to_end(&mut vecbuf); - let read = match read_result { - Ok(read_result) => read_result, - Err(error) => { - if error.kind() != std::io::ErrorKind::WouldBlock { - error!("READ from proxy: {}", error); - } - vecbuf.len() + // TODO: Move this reading process to its own function. + let mut vecbuf = Vec::::new(); + let read_result = state.mio_stream.read_to_end(&mut vecbuf); + let read = match read_result { + Ok(read_result) => read_result, + Err(error) => { + if error.kind() != std::io::ErrorKind::WouldBlock { + error!("READ from proxy: {}", error); } - }; + vecbuf.len() + } + }; - if read == 0 { - { - let socket = self.sockets.get_mut::( - self.connections.get(&connection).unwrap().smoltcp_handle, - ); - socket.close(); - } - self.expect_smoltcp_send(); - self.remove_connection(&connection.clone()); - return; - } - - let data = vecbuf.as_slice(); - let data_event = IncomingDataEvent { - direction: IncomingDirection::FromServer, - buffer: &data[0..read], - }; - if let Err(error) = state.handler.push_data(data_event) { - state.mio_stream.shutdown(Both).unwrap(); - { - let socket = self.sockets.get_mut::( - self.connections.get(&connection).unwrap().smoltcp_handle, - ); - socket.close(); - } - self.expect_smoltcp_send(); - log::error! {"{error}"} - self.remove_connection(&connection.clone()); - return; - } - if event.is_read_closed() { - state.smoltcp_socket_state |= WRITE_CLOSED; + if read == 0 { + { + let socket = self.sockets.get_mut::( + self.connections.get(&connection).ok_or(e)?.smoltcp_handle, + ); + socket.close(); } + self.expect_smoltcp_send()?; + self.remove_connection(&connection.clone())?; + return Ok(()); } - // We have read from the proxy server and pushed the data to the connection handler. - // Thus, expect data to be processed (e.g. decapsulated) and forwarded to the client. - self.write_to_client(event.token(), &connection); - } - if event.is_writable() { - self.write_to_server(&connection); + let data = vecbuf.as_slice(); + let data_event = IncomingDataEvent { + direction: IncomingDirection::FromServer, + buffer: &data[0..read], + }; + if let Err(error) = state.handler.push_data(data_event) { + state.mio_stream.shutdown(Both)?; + { + let socket = self.sockets.get_mut::( + self.connections.get(&connection).ok_or(e)?.smoltcp_handle, + ); + socket.close(); + } + self.expect_smoltcp_send()?; + log::error! {"{error}"} + self.remove_connection(&connection.clone())?; + return Ok(()); + } + if event.is_read_closed() { + state.smoltcp_socket_state |= WRITE_CLOSED; + } } + + // We have read from the proxy server and pushed the data to the connection handler. + // Thus, expect data to be processed (e.g. decapsulated) and forwarded to the client. + self.write_to_client(event.token(), &connection)?; } + if event.is_writable() { + self.write_to_server(&connection); + } + Ok(()) } fn udp_event(&mut self, _event: &Event) {} @@ -672,12 +671,12 @@ impl<'a> TunToProxy<'a> { self.poll.poll(&mut events, None)?; for event in events.iter() { match event.token() { - TCP_TOKEN => self.tun_event(event), + TCP_TOKEN => self.tun_event(event)?, UDP_TOKEN => self.udp_event(event), - _ => self.mio_socket_event(event), + _ => self.mio_socket_event(event)?, } } - self.send_to_smoltcp(); + self.send_to_smoltcp()?; } } } diff --git a/src/virtdns.rs b/src/virtdns.rs index 83eb94c..9132090 100644 --- a/src/virtdns.rs +++ b/src/virtdns.rs @@ -133,7 +133,7 @@ impl VirtualDns { Some(response) } - fn increment_ip(addr: IpAddr) -> IpAddr { + fn increment_ip(addr: IpAddr) -> Option { let mut ip_bytes = match addr as IpAddr { IpAddr::V4(ip) => Vec::::from(ip.octets()), IpAddr::V6(ip) => Vec::::from(ip.octets()), @@ -151,13 +151,14 @@ impl VirtualDns { ip_bytes[i] = 0; } } - if addr.is_ipv4() { - let bytes: [u8; 4] = ip_bytes.as_slice().try_into().unwrap(); + let addr = if addr.is_ipv4() { + let bytes: [u8; 4] = ip_bytes.as_slice().try_into().ok()?; IpAddr::V4(Ipv4Addr::from(bytes)) } else { - let bytes: [u8; 16] = ip_bytes.as_slice().try_into().unwrap(); + let bytes: [u8; 16] = ip_bytes.as_slice().try_into().ok()?; IpAddr::V6(Ipv6Addr::from(bytes)) - } + }; + Some(addr) } pub fn ip_to_name(&self, addr: &IpAddr) -> Option<&String> { @@ -168,7 +169,7 @@ impl VirtualDns { let now = Instant::now(); while let Some((ip, expiry)) = self.expiry.front() { if now > *expiry { - let name = self.ip_to_name.remove(ip).unwrap(); + let name = self.ip_to_name.remove(ip)?; self.name_to_ip.remove(&name); self.expiry.pop_front(); } else { @@ -194,7 +195,7 @@ impl VirtualDns { )); return Some(self.next_addr); } - self.next_addr = Self::increment_ip(self.next_addr); + self.next_addr = Self::increment_ip(self.next_addr)?; if self.next_addr == self.broadcast_addr { // Wrap around. self.next_addr = self.network_addr;