diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index f71e93f..f418f6a 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -192,6 +192,7 @@ struct ConnectionState { udp_data_cache: LinkedList>, dns_over_tcp_expiry: Option<::std::time::Instant>, is_tcp_closed: bool, + continue_read: bool, } pub(crate) trait ProxyHandler { @@ -267,16 +268,17 @@ impl<'a> TunToProxy<'a> { let poll = Poll::new()?; + let interests = Interest::READABLE | Interest::WRITABLE; + #[cfg(target_family = "unix")] poll.registry() - .register(&mut SourceFd(&tun.as_raw_fd()), TUN_TOKEN, Interest::READABLE)?; + .register(&mut SourceFd(&tun.as_raw_fd()), TUN_TOKEN, interests)?; #[cfg(target_os = "windows")] { - let interest = Interest::READABLE | Interest::WRITABLE; - poll.registry().register(&mut tun, TUN_TOKEN, interest)?; + poll.registry().register(&mut tun, TUN_TOKEN, interests)?; let mut pipe = NamedPipeSource(tun.pipe_client()); - poll.registry().register(&mut pipe, PIPE_TOKEN, interest)?; + poll.registry().register(&mut pipe, PIPE_TOKEN, interests)?; } #[cfg(target_family = "unix")] @@ -845,6 +847,7 @@ impl<'a> TunToProxy<'a> { udp_data_cache: LinkedList::new(), dns_over_tcp_expiry: None, is_tcp_closed: false, + continue_read: false, }; Ok(state) } @@ -968,6 +971,32 @@ impl<'a> TunToProxy<'a> { rx_token.consume(|frame| self.receive_tun(frame))?; } } + + #[cfg(unix)] + if event.is_writable() { + log::info!("Tun writable, sessions count: {}", self.connection_map.len()); + + let item = self.connection_map.iter().find(|(_, state)| state.continue_read); + if let Some((conn_info, _)) = item { + let conn_info = conn_info.clone(); + let (success, len) = self.read_server_n_write_proxy_handler(&conn_info)?; + if !success { + return Ok(()); + } + let e = "connection state not found"; + let state = self.connection_map.get_mut(&conn_info).ok_or(e)?; + + if len == 0 || event.is_read_closed() { + state.wait_read = false; + state.close_state |= SERVER_WRITE_CLOSED; + Self::update_mio_socket_interest(&mut self.poll, state)?; + self.check_change_close_state(&conn_info)?; + self.expect_smoltcp_send()?; + } + self.write_to_client(&conn_info)?; + } + } + #[cfg(target_os = "windows")] if event.is_writable() { // log::trace!("Tun writable"); @@ -1051,15 +1080,22 @@ impl<'a> TunToProxy<'a> { fn read_server_n_write_proxy_handler(&mut self, conn_info: &ConnectionInfo) -> Result<(bool, usize), Error> { let e = "connection state not found"; let state = self.connection_map.get_mut(conn_info).ok_or(e)?; + state.continue_read = false; let mut vecbuf = vec![]; + use std::io::{Error, ErrorKind}; let r = Self::read_data_from_tcp_stream(&mut state.mio_stream, &mut state.is_tcp_closed, |data| { vecbuf.extend_from_slice(data); + if vecbuf.len() >= IP_PACKAGE_MAX_SIZE { + return Err(Error::new(ErrorKind::OutOfMemory, "IP_PACKAGE_MAX_SIZE exceeded")); + } Ok(()) }); let len = vecbuf.len(); if let Err(error) = r { - { + if error.kind() == ErrorKind::OutOfMemory { + state.continue_read = true; + } else { log::error!("{}", error); self.remove_connection(conn_info)?; return Ok((false, len));