Remove redundant HashMap

This commit is contained in:
B. Blechschmidt 2021-09-02 22:36:47 +02:00
parent 5b6ac8b206
commit 93f0444028

View file

@ -159,7 +159,8 @@ fn connection_tuple(frame: &[u8]) -> Option<(Connection, bool, usize, usize)> {
struct ConnectionState { struct ConnectionState {
smoltcp_handle: SocketHandle, smoltcp_handle: SocketHandle,
mio_stream: TcpStream, mio_stream: TcpStream,
token: Token token: Token,
handler: std::boxed::Box<dyn TcpProxy>
} }
pub(crate) trait TcpProxy { pub(crate) trait TcpProxy {
@ -183,7 +184,6 @@ pub(crate) struct TunToProxy<'a> {
udp_token: Token, udp_token: Token,
iface: Interface<'a, VirtualTunDevice>, iface: Interface<'a, VirtualTunDevice>,
connections: HashMap<Connection, ConnectionState>, connections: HashMap<Connection, ConnectionState>,
managers: HashMap<Connection, std::boxed::Box<dyn TcpProxy>>,
connection_managers: Vec<std::boxed::Box<dyn ConnectionManager>>, connection_managers: Vec<std::boxed::Box<dyn ConnectionManager>>,
next_token: usize, next_token: usize,
token_to_connection: HashMap<Token, Connection>, token_to_connection: HashMap<Token, Connection>,
@ -222,7 +222,6 @@ impl<'a> TunToProxy<'a> {
next_token: 2, next_token: 2,
token_to_connection: Default::default(), token_to_connection: Default::default(),
socketset: SocketSet::new([]), socketset: SocketSet::new([]),
managers: Default::default(),
connection_managers: Default::default() connection_managers: Default::default()
} }
} }
@ -246,7 +245,6 @@ impl<'a> TunToProxy<'a> {
} }
fn remove_connection(&mut self, connection: &Connection) { fn remove_connection(&mut self, connection: &Connection) {
self.managers.remove(connection);
let mut connection_state = self.connections.remove(connection).unwrap(); let mut connection_state = self.connections.remove(connection).unwrap();
self.token_to_connection.remove(&connection_state.token); self.token_to_connection.remove(&connection_state.token);
self.poll.registry().deregister(&mut connection_state.mio_stream).unwrap(); self.poll.registry().deregister(&mut connection_state.mio_stream).unwrap();
@ -267,10 +265,9 @@ impl<'a> TunToProxy<'a> {
} }
fn tunsocket_read_and_forward(&mut self, connection: &Connection) { fn tunsocket_read_and_forward(&mut self, connection: &Connection) {
if let Some(handler) = self.managers.get_mut(&connection) { if let Some(state) = self.connections.get_mut(&connection) {
let closed = { let closed = {
let conn_info = self.connections.get_mut(&connection).unwrap(); let mut socket = self.socketset.get::<TcpSocket>(state.smoltcp_handle);
let mut socket = self.socketset.get::<TcpSocket>(conn_info.smoltcp_handle);
let mut error = Ok(()); let mut error = Ok(());
while socket.can_recv() && error.is_ok() { while socket.can_recv() && error.is_ok() {
socket.recv(|data| { socket.recv(|data| {
@ -279,7 +276,7 @@ impl<'a> TunToProxy<'a> {
buffer: data, buffer: data,
}; };
error = handler.push_data(event); error = state.handler.push_data(event);
(data.len(), ()) (data.len(), ())
}).unwrap(); }).unwrap();
@ -312,41 +309,44 @@ impl<'a> TunToProxy<'a> {
} }
let server = cm.unwrap().get_server(); let server = cm.unwrap().get_server();
if first_packet { if first_packet {
let mut socket = TcpSocket::new(TcpSocketBuffer::new(vec![0; 4096]), TcpSocketBuffer::new(vec![0; 4096]));
socket.set_ack_delay(None);
socket.listen(connection.dst).unwrap();
let handle = self.socketset.add(socket);
let socket = if server.is_ipv4() {
MioTcp::new_v4().unwrap()
} else {
MioTcp::new_v6().unwrap()
};
let client = socket.connect(server).unwrap();
let token = Token(self.next_token);
self.next_token += 1;
let mut conn = ConnectionState {
smoltcp_handle: handle,
mio_stream: client,
token
};
self.token_to_connection.insert(token, connection);
self.poll.registry().register(&mut conn.mio_stream, token, Interest::READABLE | Interest::WRITABLE).unwrap();
self.connections.insert(connection, conn);
for manager in self.connection_managers.iter_mut() { for manager in self.connection_managers.iter_mut() {
if let Some(handler) = manager.new_connection(&connection) { if let Some(handler) = manager.new_connection(&connection) {
self.managers.insert(connection, handler); let mut socket = TcpSocket::new(
TcpSocketBuffer::new(vec![0; 4096]),
TcpSocketBuffer::new(vec![0; 4096]));
socket.set_ack_delay(None);
socket.listen(connection.dst).unwrap();
let handle = self.socketset.add(socket);
let socket = if server.is_ipv4() {
MioTcp::new_v4().unwrap()
} else {
MioTcp::new_v6().unwrap()
};
let client = socket.connect(server).unwrap();
let token = Token(self.next_token);
self.next_token += 1;
let mut state = ConnectionState {
smoltcp_handle: handle,
mio_stream: client,
token,
handler
};
self.token_to_connection.insert(token, connection);
self.poll.registry().register(&mut state.mio_stream, token, Interest::READABLE | Interest::WRITABLE).unwrap();
self.connections.insert(connection, state);
println!("[{:?}] CONNECT {}", chrono::offset::Local::now(), connection);
break; break;
} }
} }
println!("[{:?}] CONNECT {}", chrono::offset::Local::now(), connection);
} else if !self.connections.contains_key(&connection) { } else if !self.connections.contains_key(&connection) {
return; return;
} }
@ -377,16 +377,15 @@ impl<'a> TunToProxy<'a> {
} }
fn write_to_server(&mut self, connection: &Connection) { fn write_to_server(&mut self, connection: &Connection) {
if let Some(handler) = self.managers.get_mut(&connection) { if let Some(state) = self.connections.get_mut(&connection) {
let event = handler.peek_data(OutgoingDirection::ToServer); let event = state.handler.peek_data(OutgoingDirection::ToServer);
if event.buffer.len() == 0 { if event.buffer.len() == 0 {
return; return;
} }
let connection_state = self.connections.get_mut(&connection).unwrap(); let result = state.mio_stream.write(event.buffer);
let result = connection_state.mio_stream.write(event.buffer);
match result { match result {
Ok(consumed) => { Ok(consumed) => {
handler.consume_data(OutgoingDirection::ToServer, consumed); state.handler.consume_data(OutgoingDirection::ToServer, consumed);
} }
Err(error) if error.kind() != std::io::ErrorKind::WouldBlock => { Err(error) if error.kind() != std::io::ErrorKind::WouldBlock => {
panic!("Error: {:?}", error); panic!("Error: {:?}", error);
@ -399,12 +398,12 @@ impl<'a> TunToProxy<'a> {
} }
fn write_to_client(&mut self, connection: &Connection) { fn write_to_client(&mut self, connection: &Connection) {
if let Some(handler) = self.managers.get_mut(&connection) { if let Some(state) = self.connections.get_mut(&connection) {
let event = handler.peek_data(OutgoingDirection::ToClient); let event = state.handler.peek_data(OutgoingDirection::ToClient);
let socket = &mut self.socketset.get::<TcpSocket>(self.connections.get(&connection).unwrap().smoltcp_handle); let socket = &mut self.socketset.get::<TcpSocket>(state.smoltcp_handle);
if socket.may_send() { if socket.may_send() {
let consumed = socket.send_slice(event.buffer).unwrap(); let consumed = socket.send_slice(event.buffer).unwrap();
handler.consume_data(OutgoingDirection::ToClient, consumed); state.handler.consume_data(OutgoingDirection::ToClient, consumed);
} }
} }
} }
@ -427,7 +426,6 @@ impl<'a> TunToProxy<'a> {
if event.is_readable() { if event.is_readable() {
{ {
let conn = self.managers.get_mut(&connection).unwrap();
let state = self.connections.get_mut(&connection).unwrap(); let state = self.connections.get_mut(&connection).unwrap();
let mut buf = [0u8; 4096]; let mut buf = [0u8; 4096];
@ -448,7 +446,7 @@ impl<'a> TunToProxy<'a> {
buffer: &buf[0..read], buffer: &buf[0..read],
}; };
if let Err(error) = conn.push_data(event) { if let Err(error) = state.handler.push_data(event) {
state.mio_stream.shutdown(Both).unwrap(); state.mio_stream.shutdown(Both).unwrap();
{ {
let mut socket = self.socketset.get::<TcpSocket>(self.connections.get(&connection).unwrap().smoltcp_handle); let mut socket = self.socketset.get::<TcpSocket>(self.connections.get(&connection).unwrap().smoltcp_handle);