Fix CPU spikes due to always-writable event and improve half-open connection handling

This commit is contained in:
B. Blechschmidt 2023-04-04 00:18:50 +02:00
parent 0be39345a8
commit 10a674d1c9
3 changed files with 143 additions and 59 deletions

View file

@ -1,7 +1,7 @@
use crate::error::Error; use crate::error::Error;
use crate::tun2proxy::{ use crate::tun2proxy::{
Connection, ConnectionManager, IncomingDataEvent, IncomingDirection, OutgoingDataEvent, Connection, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection,
OutgoingDirection, TcpProxy, OutgoingDataEvent, OutgoingDirection, TcpProxy,
}; };
use crate::Credentials; use crate::Credentials;
use base64::Engine; use base64::Engine;
@ -160,6 +160,21 @@ impl TcpProxy for HttpConnection {
fn connection_established(&self) -> bool { fn connection_established(&self) -> bool {
self.state == HttpState::Established self.state == HttpState::Established
} }
fn have_data(&mut self, dir: Direction) -> bool {
match dir {
Direction::Incoming(incoming) => match incoming {
IncomingDirection::FromServer => self.server_inbuf.len() > 0,
IncomingDirection::FromClient => {
self.client_inbuf.len() > 0 || self.data_buf.len() > 0
}
},
Direction::Outgoing(outgoing) => match outgoing {
OutgoingDirection::ToServer => self.server_outbuf.len() > 0,
OutgoingDirection::ToClient => self.client_outbuf.len() > 0,
},
}
}
} }
pub(crate) struct HttpManager { pub(crate) struct HttpManager {

View file

@ -1,7 +1,7 @@
use crate::error::Error; use crate::error::Error;
use crate::tun2proxy::{ use crate::tun2proxy::{
Connection, ConnectionManager, DestinationHost, IncomingDataEvent, IncomingDirection, Connection, ConnectionManager, DestinationHost, Direction, IncomingDataEvent,
OutgoingDataEvent, OutgoingDirection, TcpProxy, IncomingDirection, OutgoingDataEvent, OutgoingDirection, TcpProxy,
}; };
use crate::Credentials; use crate::Credentials;
use smoltcp::wire::IpProtocol; use smoltcp::wire::IpProtocol;
@ -368,6 +368,21 @@ impl TcpProxy for SocksConnection {
fn connection_established(&self) -> bool { fn connection_established(&self) -> bool {
self.state == SocksState::Established self.state == SocksState::Established
} }
fn have_data(&mut self, dir: Direction) -> bool {
match dir {
Direction::Incoming(incoming) => match incoming {
IncomingDirection::FromServer => self.server_inbuf.len() > 0,
IncomingDirection::FromClient => {
self.client_inbuf.len() > 0 || self.data_buf.len() > 0
}
},
Direction::Outgoing(outgoing) => match outgoing {
OutgoingDirection::ToServer => self.server_outbuf.len() > 0,
OutgoingDirection::ToClient => self.client_outbuf.len() > 0,
},
}
}
} }
pub struct SocksManager { pub struct SocksManager {

View file

@ -21,7 +21,7 @@ use std::os::unix::io::AsRawFd;
use std::rc::Rc; use std::rc::Rc;
use std::str::FromStr; use std::str::FromStr;
#[derive(Hash, Clone, Eq, PartialEq)] #[derive(Hash, Clone, Eq, PartialEq, Debug)]
pub(crate) enum DestinationHost { pub(crate) enum DestinationHost {
Address(IpAddr), Address(IpAddr),
Hostname(String), Hostname(String),
@ -36,7 +36,7 @@ impl std::fmt::Display for DestinationHost {
} }
} }
#[derive(Hash, Clone, Eq, PartialEq)] #[derive(Hash, Clone, Eq, PartialEq, Debug)]
pub(crate) struct Destination { pub(crate) struct Destination {
pub(crate) host: DestinationHost, pub(crate) host: DestinationHost,
pub(crate) port: u16, pub(crate) port: u16,
@ -74,7 +74,7 @@ impl std::fmt::Display for Destination {
} }
} }
#[derive(Hash, Clone, Eq, PartialEq)] #[derive(Hash, Clone, Eq, PartialEq, Debug)]
pub(crate) struct Connection { pub(crate) struct Connection {
pub(crate) src: SocketAddr, pub(crate) src: SocketAddr,
pub(crate) dst: Destination, pub(crate) dst: Destination,
@ -107,6 +107,12 @@ pub(crate) enum OutgoingDirection {
ToClient, ToClient,
} }
#[derive(Eq, PartialEq, Debug)]
pub(crate) enum Direction {
Incoming(IncomingDirection),
Outgoing(OutgoingDirection),
}
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) enum ConnectionEvent<'a> { pub(crate) enum ConnectionEvent<'a> {
NewConnection(&'a Connection), NewConnection(&'a Connection),
@ -214,6 +220,8 @@ struct ConnectionState {
token: Token, token: Token,
handler: Box<dyn TcpProxy>, handler: Box<dyn TcpProxy>,
close_state: u8, close_state: u8,
wait_read: bool,
wait_write: bool,
} }
pub(crate) trait TcpProxy { pub(crate) trait TcpProxy {
@ -221,6 +229,7 @@ pub(crate) trait TcpProxy {
fn consume_data(&mut self, dir: OutgoingDirection, size: usize); fn consume_data(&mut self, dir: OutgoingDirection, size: usize);
fn peek_data(&mut self, dir: OutgoingDirection) -> OutgoingDataEvent; fn peek_data(&mut self, dir: OutgoingDirection) -> OutgoingDataEvent;
fn connection_established(&self) -> bool; fn connection_established(&self) -> bool;
fn have_data(&mut self, dir: Direction) -> bool;
} }
pub(crate) trait ConnectionManager { pub(crate) trait ConnectionManager {
@ -314,12 +323,12 @@ impl<'a> TunToProxy<'a> {
} }
fn remove_connection(&mut self, connection: &Connection) -> Result<(), Error> { fn remove_connection(&mut self, connection: &Connection) -> Result<(), Error> {
let e = "connection not exist"; if let Some(mut conn) = self.connections.remove(connection) {
let mut conn = self.connections.remove(connection).ok_or(e)?;
let token = &conn.token; let token = &conn.token;
self.token_to_connection.remove(token); self.token_to_connection.remove(token);
self.poll.registry().deregister(&mut conn.mio_stream)?; _ = self.poll.registry().deregister(&mut conn.mio_stream);
info!("CLOSE {}", connection); info!("CLOSE {}", connection);
}
Ok(()) Ok(())
} }
@ -333,31 +342,36 @@ impl<'a> TunToProxy<'a> {
} }
fn check_change_close_state(&mut self, connection: &Connection) -> Result<(), Error> { fn check_change_close_state(&mut self, connection: &Connection) -> Result<(), Error> {
let state = self let state = self.connections.get_mut(connection);
.connections if state.is_none() {
.get_mut(connection) return Ok(());
.ok_or("connection does not exist")?; }
let state = state.unwrap();
let mut closed_ends = 0; let mut closed_ends = 0;
if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED { if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED
//info!("Server write closed"); && !state
let event = state.handler.peek_data(OutgoingDirection::ToClient); .handler
if event.buffer.is_empty() { .have_data(Direction::Incoming(IncomingDirection::FromServer))
//info!("Server write closed and consumed"); && !state
.handler
.have_data(Direction::Outgoing(OutgoingDirection::ToClient))
{
let socket = self.sockets.get_mut::<tcp::Socket>(state.smoltcp_handle); let socket = self.sockets.get_mut::<tcp::Socket>(state.smoltcp_handle);
socket.close(); socket.close();
closed_ends += 1; closed_ends += 1;
} }
}
if (state.close_state & CLIENT_WRITE_CLOSED) == CLIENT_WRITE_CLOSED { if (state.close_state & CLIENT_WRITE_CLOSED) == CLIENT_WRITE_CLOSED
//info!("Client write closed"); && !state
let event = state.handler.peek_data(OutgoingDirection::ToServer); .handler
if event.buffer.is_empty() { .have_data(Direction::Incoming(IncomingDirection::FromClient))
//info!("Client write closed and consumed"); && !state
.handler
.have_data(Direction::Outgoing(OutgoingDirection::ToServer))
{
_ = state.mio_stream.shutdown(Shutdown::Write); _ = state.mio_stream.shutdown(Shutdown::Write);
closed_ends += 1; closed_ends += 1;
} }
}
if closed_ends == 2 { if closed_ends == 2 {
self.remove_connection(connection)?; self.remove_connection(connection)?;
@ -368,10 +382,11 @@ impl<'a> TunToProxy<'a> {
fn tunsocket_read_and_forward(&mut self, connection: &Connection) -> Result<(), Error> { fn tunsocket_read_and_forward(&mut self, connection: &Connection) -> Result<(), Error> {
// Scope for mutable borrow of self. // Scope for mutable borrow of self.
{ {
let state = self let state = self.connections.get_mut(connection);
.connections if state.is_none() {
.get_mut(connection) return Ok(());
.ok_or("connection does not exist")?; }
let state = state.unwrap();
let socket = self.sockets.get_mut::<tcp::Socket>(state.smoltcp_handle); let socket = self.sockets.get_mut::<tcp::Socket>(state.smoltcp_handle);
let mut error = Ok(()); let mut error = Ok(());
while socket.can_recv() && error.is_ok() { while socket.can_recv() && error.is_ok() {
@ -404,6 +419,38 @@ impl<'a> TunToProxy<'a> {
Ok(()) Ok(())
} }
// Update the poll registry depending on the connection's event interests.
fn update_mio_socket_interest(&mut self, connection: &Connection) -> Result<(), Error> {
let state = self
.connections
.get_mut(connection)
.ok_or("connection not found")?;
// Maybe we did not listen for any events before. Therefore, just swallow the error.
_ = self.poll.registry().deregister(&mut state.mio_stream);
// If we do not wait for read or write events, we do not need to register them.
if !state.wait_read && !state.wait_write {
return Ok(());
}
// This ugliness is due to the way Interest is implemented (as a NonZeroU8 wrapper).
let interest;
if state.wait_read && !state.wait_write {
interest = Interest::READABLE;
} else if state.wait_write && !state.wait_read {
interest = Interest::WRITABLE;
} else {
interest = Interest::READABLE | Interest::WRITABLE;
}
self.poll
.registry()
.register(&mut state.mio_stream, state.token, interest)?;
Ok(())
}
// A raw packet was received on the tunnel interface.
fn receive_tun(&mut self, frame: &mut [u8]) -> Result<(), Error> { fn receive_tun(&mut self, frame: &mut [u8]) -> Result<(), Error> {
if let Some((connection, first_packet, _payload_offset, _payload_size)) = if let Some((connection, first_packet, _payload_offset, _payload_size)) =
connection_tuple(frame) connection_tuple(frame)
@ -453,6 +500,8 @@ impl<'a> TunToProxy<'a> {
token, token,
handler, handler,
close_state: 0, close_state: 0,
wait_read: true,
wait_write: false,
}; };
self.token_to_connection self.token_to_connection
@ -460,7 +509,7 @@ impl<'a> TunToProxy<'a> {
self.poll.registry().register( self.poll.registry().register(
&mut state.mio_stream, &mut state.mio_stream,
token, token,
Interest::READABLE | Interest::WRITABLE, Interest::READABLE,
)?; )?;
self.connections.insert(resolved_conn.clone(), state); self.connections.insert(resolved_conn.clone(), state);
@ -525,23 +574,33 @@ impl<'a> TunToProxy<'a> {
fn write_to_server(&mut self, connection: &Connection) -> Result<(), Error> { fn write_to_server(&mut self, connection: &Connection) -> Result<(), Error> {
if let Some(state) = self.connections.get_mut(connection) { if let Some(state) = self.connections.get_mut(connection) {
let event = state.handler.peek_data(OutgoingDirection::ToServer); let event = state.handler.peek_data(OutgoingDirection::ToServer);
if event.buffer.is_empty() { let buffer_size = event.buffer.len();
if buffer_size == 0 {
state.wait_write = false;
self.update_mio_socket_interest(connection)?;
self.check_change_close_state(connection)?; self.check_change_close_state(connection)?;
return Ok(()); return Ok(());
} }
let result = state.mio_stream.write(event.buffer); let result = state.mio_stream.write(event.buffer);
match result { match result {
Ok(consumed) => { Ok(written) => {
state state
.handler .handler
.consume_data(OutgoingDirection::ToServer, consumed); .consume_data(OutgoingDirection::ToServer, written);
state.wait_write = written < buffer_size;
self.update_mio_socket_interest(connection)?;
} }
Err(error) if error.kind() != std::io::ErrorKind::WouldBlock => { Err(error) if error.kind() != std::io::ErrorKind::WouldBlock => {
return Err(error.into()); return Err(error.into());
} }
_ => {} _ => {
// WOULDBLOCK case
state.wait_write = true;
self.update_mio_socket_interest(connection)?;
} }
} }
}
self.check_change_close_state(connection)?;
Ok(()) Ok(())
} }
@ -578,20 +637,6 @@ impl<'a> TunToProxy<'a> {
} }
self.check_change_close_state(connection)?; self.check_change_close_state(connection)?;
/*let socket = self.sockets.get_mut::<tcp::Socket>(socket_handle);
// Closing and removing the connection here may work in practice but is actually not
// correct. Only the write end was closed but we could still read from it!
// TODO: Fix and test half-open connection scenarios as mentioned in the README.
// TODO: Investigate how half-closed connections from the other end are handled.
if socket_state & SERVER_WRITE_CLOSED != 0 && consumed == buflen {
info!("WRCL");
socket.close();
self.expect_smoltcp_send()?;
self.write_sockets.remove(&token);
self.remove_connection(connection)?;
break;
}*/
} }
Ok(()) Ok(())
} }
@ -669,7 +714,9 @@ impl<'a> TunToProxy<'a> {
} }
if read == 0 || event.is_read_closed() { if read == 0 || event.is_read_closed() {
state.wait_read = false;
state.close_state |= SERVER_WRITE_CLOSED; state.close_state |= SERVER_WRITE_CLOSED;
self.update_mio_socket_interest(&connection)?;
self.check_change_close_state(&connection)?; self.check_change_close_state(&connection)?;
self.expect_smoltcp_send()?; self.expect_smoltcp_send()?;
} }
@ -678,15 +725,21 @@ impl<'a> TunToProxy<'a> {
// We have read from the proxy server and pushed the data to the connection handler. // We have read from the proxy server and pushed the data to the connection handler.
// Thus, expect data to be processed (e.g. decapsulated) and forwarded to the client. // Thus, expect data to be processed (e.g. decapsulated) and forwarded to the client.
self.write_to_client(event.token(), &connection)?; self.write_to_client(event.token(), &connection)?;
// The connection handler could have produced data that is to be written to the
// server.
self.write_to_server(&connection)?;
} }
if event.is_writable() { if event.is_writable() {
self.write_to_server(&connection)?; self.write_to_server(&connection)?;
} }
Ok(()) Ok(())
})() })()
.or_else(|error| { .or_else(|error| {
self.remove_connection(&connection)?;
log::error! {"{error}"} log::error! {"{error}"}
self.remove_connection(&connection)?;
Ok(()) Ok(())
}) })
} }
@ -695,7 +748,6 @@ impl<'a> TunToProxy<'a> {
pub(crate) fn run(&mut self) -> Result<(), Error> { pub(crate) fn run(&mut self) -> Result<(), Error> {
let mut events = Events::with_capacity(1024); let mut events = Events::with_capacity(1024);
loop { loop {
match self.poll.poll(&mut events, None) { match self.poll.poll(&mut events, None) {
Ok(()) => { Ok(()) => {
@ -711,6 +763,8 @@ impl<'a> TunToProxy<'a> {
Err(e) => { Err(e) => {
if e.kind() != std::io::ErrorKind::Interrupted { if e.kind() != std::io::ErrorKind::Interrupted {
return Err(e.into()); return Err(e.into());
} else {
log::warn!("Poll interrupted")
} }
} }
} }