From 0be39345a8856a93e2a9f381b8c9c544d7101b29 Mon Sep 17 00:00:00 2001 From: "B. Blechschmidt" Date: Mon, 3 Apr 2023 20:31:31 +0200 Subject: [PATCH] Improve handling of half-open connections --- src/tun2proxy.rs | 202 +++++++++++++++++++++++++++-------------------- 1 file changed, 115 insertions(+), 87 deletions(-) diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index e6453c8..a2c9ebb 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -8,6 +8,7 @@ use mio::unix::SourceFd; use mio::{Events, Interest, Poll, Token}; use smoltcp::iface::{Config, Interface, SocketHandle, SocketSet}; use smoltcp::phy::{Device, Medium, RxToken, TunTapInterface, TxToken}; +use smoltcp::socket::tcp::State; use smoltcp::socket::{tcp, udp}; use smoltcp::time::Instant; use smoltcp::wire::{IpCidr, IpProtocol, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket}; @@ -15,7 +16,7 @@ use std::collections::{HashMap, HashSet}; use std::convert::{From, TryFrom}; use std::io::{Read, Write}; use std::net::Shutdown::Both; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr}; use std::os::unix::io::AsRawFd; use std::rc::Rc; use std::str::FromStr; @@ -204,14 +205,15 @@ fn connection_tuple(frame: &[u8]) -> Option<(Connection, bool, usize, usize)> { } } -const WRITE_CLOSED: u8 = 1; +const SERVER_WRITE_CLOSED: u8 = 1; +const CLIENT_WRITE_CLOSED: u8 = 2; struct ConnectionState { smoltcp_handle: SocketHandle, mio_stream: TcpStream, token: Token, handler: Box, - smoltcp_socket_state: u8, + close_state: u8, } pub(crate) trait TcpProxy { @@ -330,42 +332,75 @@ impl<'a> TunToProxy<'a> { None } - fn tunsocket_read_and_forward(&mut self, connection: &Connection) -> Result<(), Error> { - if let Some(state) = self.connections.get_mut(connection) { - let closed = { + fn check_change_close_state(&mut self, connection: &Connection) -> Result<(), Error> { + let state = self + .connections + .get_mut(connection) + .ok_or("connection does not exist")?; + let mut closed_ends = 0; + if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED { + //info!("Server write closed"); + let event = state.handler.peek_data(OutgoingDirection::ToClient); + if event.buffer.is_empty() { + //info!("Server write closed and consumed"); let socket = self.sockets.get_mut::(state.smoltcp_handle); - let mut error = Ok(()); - while socket.can_recv() && error.is_ok() { - socket.recv(|data| { - let event = IncomingDataEvent { - direction: IncomingDirection::FromClient, - buffer: data, - }; - error = state.handler.push_data(event); + socket.close(); + closed_ends += 1; + } + } - (data.len(), ()) - })?; - } + if (state.close_state & CLIENT_WRITE_CLOSED) == CLIENT_WRITE_CLOSED { + //info!("Client write closed"); + let event = state.handler.peek_data(OutgoingDirection::ToServer); + if event.buffer.is_empty() { + //info!("Client write closed and consumed"); + _ = state.mio_stream.shutdown(Shutdown::Write); + closed_ends += 1; + } + } - match error { - Ok(_) => socket.state() == tcp::State::CloseWait, - Err(e) => { - log::error!("{e}"); - true - } - } - }; + if closed_ends == 2 { + self.remove_connection(connection)?; + } + Ok(()) + } + + fn tunsocket_read_and_forward(&mut self, connection: &Connection) -> Result<(), Error> { + // Scope for mutable borrow of self. + { + let state = self + .connections + .get_mut(connection) + .ok_or("connection does not exist")?; + let socket = self.sockets.get_mut::(state.smoltcp_handle); + let mut error = Ok(()); + while socket.can_recv() && error.is_ok() { + socket.recv(|data| { + let event = IncomingDataEvent { + direction: IncomingDirection::FromClient, + buffer: data, + }; + error = state.handler.push_data(event); + (data.len(), ()) + })?; + } + + if !socket.may_recv() + && socket.state() != State::Listen + && socket.state() != State::SynSent + && socket.state() != State::SynReceived + { + // We cannot yet close the write end of the mio stream here because we may still + // need to send data. + state.close_state |= CLIENT_WRITE_CLOSED; + } // Expect ACKs etc. from smoltcp sockets. self.expect_smoltcp_send()?; - - if closed { - let e = "connection not exist"; - let connection_state = self.connections.get_mut(connection).ok_or(e)?; - connection_state.mio_stream.shutdown(Both)?; - self.remove_connection(connection)?; - } } + + self.check_change_close_state(connection)?; + Ok(()) } @@ -417,7 +452,7 @@ impl<'a> TunToProxy<'a> { mio_stream: client, token, handler, - smoltcp_socket_state: 0, + close_state: 0, }; self.token_to_connection @@ -491,6 +526,7 @@ impl<'a> TunToProxy<'a> { if let Some(state) = self.connections.get_mut(connection) { let event = state.handler.peek_data(OutgoingDirection::ToServer); if event.buffer.is_empty() { + self.check_change_close_state(connection)?; return Ok(()); } let result = state.mio_stream.write(event.buffer); @@ -510,51 +546,52 @@ impl<'a> TunToProxy<'a> { } fn write_to_client(&mut self, token: Token, connection: &Connection) -> Result<(), Error> { - loop { - if let Some(state) = self.connections.get_mut(connection) { - let socket_state = state.smoltcp_socket_state; - let socket_handle = state.smoltcp_handle; - let event = state.handler.peek_data(OutgoingDirection::ToClient); - let buflen = event.buffer.len(); - let consumed; - { - let socket = self.sockets.get_mut::(socket_handle); - if socket.may_send() { - if let Some(virtdns) = &mut self.options.virtdns { - // Unwrapping is fine because every smoltcp socket is bound to an. - virtdns.touch_ip(&IpAddr::from(socket.local_endpoint().unwrap().addr)); - } - consumed = socket.send_slice(event.buffer)?; - state - .handler - .consume_data(OutgoingDirection::ToClient, consumed); - self.expect_smoltcp_send()?; - if consumed < buflen { - self.write_sockets.insert(token); - break; - } else { - self.write_sockets.remove(&token); - if consumed == 0 { - break; - } - } - } else { - break; - } - } + while let Some(state) = self.connections.get_mut(connection) { + let socket_handle = state.smoltcp_handle; + let event = state.handler.peek_data(OutgoingDirection::ToClient); + let buflen = event.buffer.len(); + let consumed; + { let socket = self.sockets.get_mut::(socket_handle); - // Closing and removing the connection here may work in practice but is actually not - // correct. Only the write end was closed but we could still read from it! - // TODO: Fix and test half-open connection scenarios as mentioned in the README. - // TODO: Investigate how half-closed connections from the other end are handled. - if socket_state & WRITE_CLOSED != 0 && consumed == buflen { - socket.close(); + if socket.may_send() { + if let Some(virtdns) = &mut self.options.virtdns { + // Unwrapping is fine because every smoltcp socket is bound to an. + virtdns.touch_ip(&IpAddr::from(socket.local_endpoint().unwrap().addr)); + } + consumed = socket.send_slice(event.buffer)?; + state + .handler + .consume_data(OutgoingDirection::ToClient, consumed); self.expect_smoltcp_send()?; - self.write_sockets.remove(&token); - self.remove_connection(connection)?; + if consumed < buflen { + self.write_sockets.insert(token); + break; + } else { + self.write_sockets.remove(&token); + if consumed == 0 { + break; + } + } + } else { break; } } + + self.check_change_close_state(connection)?; + + /*let socket = self.sockets.get_mut::(socket_handle); + // Closing and removing the connection here may work in practice but is actually not + // correct. Only the write end was closed but we could still read from it! + // TODO: Fix and test half-open connection scenarios as mentioned in the README. + // TODO: Investigate how half-closed connections from the other end are handled. + if socket_state & SERVER_WRITE_CLOSED != 0 && consumed == buflen { + info!("WRCL"); + socket.close(); + self.expect_smoltcp_send()?; + self.write_sockets.remove(&token); + self.remove_connection(connection)?; + break; + }*/ } Ok(()) } @@ -612,18 +649,6 @@ impl<'a> TunToProxy<'a> { } }; - if read == 0 { - { - let socket = self.sockets.get_mut::( - self.connections.get(&connection).ok_or(e)?.smoltcp_handle, - ); - socket.close(); - } - self.expect_smoltcp_send()?; - self.remove_connection(&connection.clone())?; - return Ok(()); - } - let data = vecbuf.as_slice(); let data_event = IncomingDataEvent { direction: IncomingDirection::FromServer, @@ -642,8 +667,11 @@ impl<'a> TunToProxy<'a> { self.remove_connection(&connection.clone())?; return Ok(()); } - if event.is_read_closed() { - state.smoltcp_socket_state |= WRITE_CLOSED; + + if read == 0 || event.is_read_closed() { + state.close_state |= SERVER_WRITE_CLOSED; + self.check_change_close_state(&connection)?; + self.expect_smoltcp_send()?; } }