diff --git a/src/http.rs b/src/http.rs index 0a5278a..0a5b320 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,7 +1,7 @@ use crate::error::Error; use crate::tun2proxy::{ - Connection, ConnectionManager, IncomingDataEvent, IncomingDirection, OutgoingDataEvent, - OutgoingDirection, TcpProxy, + Connection, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, + OutgoingDataEvent, OutgoingDirection, TcpProxy, }; use crate::Credentials; use base64::Engine; @@ -160,6 +160,21 @@ impl TcpProxy for HttpConnection { fn connection_established(&self) -> bool { self.state == HttpState::Established } + + fn have_data(&mut self, dir: Direction) -> bool { + match dir { + Direction::Incoming(incoming) => match incoming { + IncomingDirection::FromServer => self.server_inbuf.len() > 0, + IncomingDirection::FromClient => { + self.client_inbuf.len() > 0 || self.data_buf.len() > 0 + } + }, + Direction::Outgoing(outgoing) => match outgoing { + OutgoingDirection::ToServer => self.server_outbuf.len() > 0, + OutgoingDirection::ToClient => self.client_outbuf.len() > 0, + }, + } + } } pub(crate) struct HttpManager { diff --git a/src/socks5.rs b/src/socks5.rs index a4296af..c6d8c0a 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -1,7 +1,7 @@ use crate::error::Error; use crate::tun2proxy::{ - Connection, ConnectionManager, DestinationHost, IncomingDataEvent, IncomingDirection, - OutgoingDataEvent, OutgoingDirection, TcpProxy, + Connection, ConnectionManager, DestinationHost, Direction, IncomingDataEvent, + IncomingDirection, OutgoingDataEvent, OutgoingDirection, TcpProxy, }; use crate::Credentials; use smoltcp::wire::IpProtocol; @@ -368,6 +368,21 @@ impl TcpProxy for SocksConnection { fn connection_established(&self) -> bool { self.state == SocksState::Established } + + fn have_data(&mut self, dir: Direction) -> bool { + match dir { + Direction::Incoming(incoming) => match incoming { + IncomingDirection::FromServer => self.server_inbuf.len() > 0, + IncomingDirection::FromClient => { + self.client_inbuf.len() > 0 || self.data_buf.len() > 0 + } + }, + Direction::Outgoing(outgoing) => match outgoing { + OutgoingDirection::ToServer => self.server_outbuf.len() > 0, + OutgoingDirection::ToClient => self.client_outbuf.len() > 0, + }, + } + } } pub struct SocksManager { diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index a2c9ebb..e0ea424 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -21,7 +21,7 @@ use std::os::unix::io::AsRawFd; use std::rc::Rc; use std::str::FromStr; -#[derive(Hash, Clone, Eq, PartialEq)] +#[derive(Hash, Clone, Eq, PartialEq, Debug)] pub(crate) enum DestinationHost { Address(IpAddr), Hostname(String), @@ -36,7 +36,7 @@ impl std::fmt::Display for DestinationHost { } } -#[derive(Hash, Clone, Eq, PartialEq)] +#[derive(Hash, Clone, Eq, PartialEq, Debug)] pub(crate) struct Destination { pub(crate) host: DestinationHost, pub(crate) port: u16, @@ -74,7 +74,7 @@ impl std::fmt::Display for Destination { } } -#[derive(Hash, Clone, Eq, PartialEq)] +#[derive(Hash, Clone, Eq, PartialEq, Debug)] pub(crate) struct Connection { pub(crate) src: SocketAddr, pub(crate) dst: Destination, @@ -107,6 +107,12 @@ pub(crate) enum OutgoingDirection { ToClient, } +#[derive(Eq, PartialEq, Debug)] +pub(crate) enum Direction { + Incoming(IncomingDirection), + Outgoing(OutgoingDirection), +} + #[allow(dead_code)] pub(crate) enum ConnectionEvent<'a> { NewConnection(&'a Connection), @@ -214,6 +220,8 @@ struct ConnectionState { token: Token, handler: Box, close_state: u8, + wait_read: bool, + wait_write: bool, } pub(crate) trait TcpProxy { @@ -221,6 +229,7 @@ pub(crate) trait TcpProxy { fn consume_data(&mut self, dir: OutgoingDirection, size: usize); fn peek_data(&mut self, dir: OutgoingDirection) -> OutgoingDataEvent; fn connection_established(&self) -> bool; + fn have_data(&mut self, dir: Direction) -> bool; } pub(crate) trait ConnectionManager { @@ -314,12 +323,12 @@ impl<'a> TunToProxy<'a> { } fn remove_connection(&mut self, connection: &Connection) -> Result<(), Error> { - let e = "connection not exist"; - let mut conn = self.connections.remove(connection).ok_or(e)?; - let token = &conn.token; - self.token_to_connection.remove(token); - self.poll.registry().deregister(&mut conn.mio_stream)?; - info!("CLOSE {}", connection); + if let Some(mut conn) = self.connections.remove(connection) { + let token = &conn.token; + self.token_to_connection.remove(token); + _ = self.poll.registry().deregister(&mut conn.mio_stream); + info!("CLOSE {}", connection); + } Ok(()) } @@ -333,30 +342,35 @@ impl<'a> TunToProxy<'a> { } fn check_change_close_state(&mut self, connection: &Connection) -> Result<(), Error> { - let state = self - .connections - .get_mut(connection) - .ok_or("connection does not exist")?; + let state = self.connections.get_mut(connection); + if state.is_none() { + return Ok(()); + } + let state = state.unwrap(); let mut closed_ends = 0; - if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED { - //info!("Server write closed"); - let event = state.handler.peek_data(OutgoingDirection::ToClient); - if event.buffer.is_empty() { - //info!("Server write closed and consumed"); - let socket = self.sockets.get_mut::(state.smoltcp_handle); - socket.close(); - closed_ends += 1; - } + if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED + && !state + .handler + .have_data(Direction::Incoming(IncomingDirection::FromServer)) + && !state + .handler + .have_data(Direction::Outgoing(OutgoingDirection::ToClient)) + { + let socket = self.sockets.get_mut::(state.smoltcp_handle); + socket.close(); + closed_ends += 1; } - if (state.close_state & CLIENT_WRITE_CLOSED) == CLIENT_WRITE_CLOSED { - //info!("Client write closed"); - let event = state.handler.peek_data(OutgoingDirection::ToServer); - if event.buffer.is_empty() { - //info!("Client write closed and consumed"); - _ = state.mio_stream.shutdown(Shutdown::Write); - closed_ends += 1; - } + if (state.close_state & CLIENT_WRITE_CLOSED) == CLIENT_WRITE_CLOSED + && !state + .handler + .have_data(Direction::Incoming(IncomingDirection::FromClient)) + && !state + .handler + .have_data(Direction::Outgoing(OutgoingDirection::ToServer)) + { + _ = state.mio_stream.shutdown(Shutdown::Write); + closed_ends += 1; } if closed_ends == 2 { @@ -368,10 +382,11 @@ impl<'a> TunToProxy<'a> { fn tunsocket_read_and_forward(&mut self, connection: &Connection) -> Result<(), Error> { // Scope for mutable borrow of self. { - let state = self - .connections - .get_mut(connection) - .ok_or("connection does not exist")?; + let state = self.connections.get_mut(connection); + if state.is_none() { + return Ok(()); + } + let state = state.unwrap(); let socket = self.sockets.get_mut::(state.smoltcp_handle); let mut error = Ok(()); while socket.can_recv() && error.is_ok() { @@ -404,6 +419,38 @@ impl<'a> TunToProxy<'a> { Ok(()) } + // Update the poll registry depending on the connection's event interests. + fn update_mio_socket_interest(&mut self, connection: &Connection) -> Result<(), Error> { + let state = self + .connections + .get_mut(connection) + .ok_or("connection not found")?; + + // Maybe we did not listen for any events before. Therefore, just swallow the error. + _ = self.poll.registry().deregister(&mut state.mio_stream); + + // If we do not wait for read or write events, we do not need to register them. + if !state.wait_read && !state.wait_write { + return Ok(()); + } + + // This ugliness is due to the way Interest is implemented (as a NonZeroU8 wrapper). + let interest; + if state.wait_read && !state.wait_write { + interest = Interest::READABLE; + } else if state.wait_write && !state.wait_read { + interest = Interest::WRITABLE; + } else { + interest = Interest::READABLE | Interest::WRITABLE; + } + + self.poll + .registry() + .register(&mut state.mio_stream, state.token, interest)?; + Ok(()) + } + + // A raw packet was received on the tunnel interface. fn receive_tun(&mut self, frame: &mut [u8]) -> Result<(), Error> { if let Some((connection, first_packet, _payload_offset, _payload_size)) = connection_tuple(frame) @@ -453,6 +500,8 @@ impl<'a> TunToProxy<'a> { token, handler, close_state: 0, + wait_read: true, + wait_write: false, }; self.token_to_connection @@ -460,7 +509,7 @@ impl<'a> TunToProxy<'a> { self.poll.registry().register( &mut state.mio_stream, token, - Interest::READABLE | Interest::WRITABLE, + Interest::READABLE, )?; self.connections.insert(resolved_conn.clone(), state); @@ -525,23 +574,33 @@ impl<'a> TunToProxy<'a> { fn write_to_server(&mut self, connection: &Connection) -> Result<(), Error> { if let Some(state) = self.connections.get_mut(connection) { let event = state.handler.peek_data(OutgoingDirection::ToServer); - if event.buffer.is_empty() { + let buffer_size = event.buffer.len(); + if buffer_size == 0 { + state.wait_write = false; + self.update_mio_socket_interest(connection)?; self.check_change_close_state(connection)?; return Ok(()); } let result = state.mio_stream.write(event.buffer); match result { - Ok(consumed) => { + Ok(written) => { state .handler - .consume_data(OutgoingDirection::ToServer, consumed); + .consume_data(OutgoingDirection::ToServer, written); + state.wait_write = written < buffer_size; + self.update_mio_socket_interest(connection)?; } Err(error) if error.kind() != std::io::ErrorKind::WouldBlock => { return Err(error.into()); } - _ => {} + _ => { + // WOULDBLOCK case + state.wait_write = true; + self.update_mio_socket_interest(connection)?; + } } } + self.check_change_close_state(connection)?; Ok(()) } @@ -578,20 +637,6 @@ impl<'a> TunToProxy<'a> { } self.check_change_close_state(connection)?; - - /*let socket = self.sockets.get_mut::(socket_handle); - // Closing and removing the connection here may work in practice but is actually not - // correct. Only the write end was closed but we could still read from it! - // TODO: Fix and test half-open connection scenarios as mentioned in the README. - // TODO: Investigate how half-closed connections from the other end are handled. - if socket_state & SERVER_WRITE_CLOSED != 0 && consumed == buflen { - info!("WRCL"); - socket.close(); - self.expect_smoltcp_send()?; - self.write_sockets.remove(&token); - self.remove_connection(connection)?; - break; - }*/ } Ok(()) } @@ -669,7 +714,9 @@ impl<'a> TunToProxy<'a> { } if read == 0 || event.is_read_closed() { + state.wait_read = false; state.close_state |= SERVER_WRITE_CLOSED; + self.update_mio_socket_interest(&connection)?; self.check_change_close_state(&connection)?; self.expect_smoltcp_send()?; } @@ -678,15 +725,21 @@ impl<'a> TunToProxy<'a> { // We have read from the proxy server and pushed the data to the connection handler. // Thus, expect data to be processed (e.g. decapsulated) and forwarded to the client. self.write_to_client(event.token(), &connection)?; + + // The connection handler could have produced data that is to be written to the + // server. + self.write_to_server(&connection)?; } + if event.is_writable() { self.write_to_server(&connection)?; } + Ok(()) })() .or_else(|error| { - self.remove_connection(&connection)?; log::error! {"{error}"} + self.remove_connection(&connection)?; Ok(()) }) } @@ -695,7 +748,6 @@ impl<'a> TunToProxy<'a> { pub(crate) fn run(&mut self) -> Result<(), Error> { let mut events = Events::with_capacity(1024); - loop { match self.poll.poll(&mut events, None) { Ok(()) => { @@ -711,6 +763,8 @@ impl<'a> TunToProxy<'a> { Err(e) => { if e.kind() != std::io::ErrorKind::Interrupted { return Err(e.into()); + } else { + log::warn!("Poll interrupted") } } }