diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index 8e295aa..c72b45c 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -159,7 +159,8 @@ fn connection_tuple(frame: &[u8]) -> Option<(Connection, bool, usize, usize)> { struct ConnectionState { smoltcp_handle: SocketHandle, mio_stream: TcpStream, - token: Token + token: Token, + handler: std::boxed::Box } pub(crate) trait TcpProxy { @@ -183,7 +184,6 @@ pub(crate) struct TunToProxy<'a> { udp_token: Token, iface: Interface<'a, VirtualTunDevice>, connections: HashMap, - managers: HashMap>, connection_managers: Vec>, next_token: usize, token_to_connection: HashMap, @@ -222,7 +222,6 @@ impl<'a> TunToProxy<'a> { next_token: 2, token_to_connection: Default::default(), socketset: SocketSet::new([]), - managers: Default::default(), connection_managers: Default::default() } } @@ -246,7 +245,6 @@ impl<'a> TunToProxy<'a> { } fn remove_connection(&mut self, connection: &Connection) { - self.managers.remove(connection); let mut connection_state = self.connections.remove(connection).unwrap(); self.token_to_connection.remove(&connection_state.token); self.poll.registry().deregister(&mut connection_state.mio_stream).unwrap(); @@ -267,10 +265,9 @@ impl<'a> TunToProxy<'a> { } fn tunsocket_read_and_forward(&mut self, connection: &Connection) { - if let Some(handler) = self.managers.get_mut(&connection) { + if let Some(state) = self.connections.get_mut(&connection) { let closed = { - let conn_info = self.connections.get_mut(&connection).unwrap(); - let mut socket = self.socketset.get::(conn_info.smoltcp_handle); + let mut socket = self.socketset.get::(state.smoltcp_handle); let mut error = Ok(()); while socket.can_recv() && error.is_ok() { socket.recv(|data| { @@ -279,7 +276,7 @@ impl<'a> TunToProxy<'a> { buffer: data, }; - error = handler.push_data(event); + error = state.handler.push_data(event); (data.len(), ()) }).unwrap(); @@ -312,41 +309,44 @@ impl<'a> TunToProxy<'a> { } let server = cm.unwrap().get_server(); if first_packet { - let mut socket = TcpSocket::new(TcpSocketBuffer::new(vec![0; 4096]), TcpSocketBuffer::new(vec![0; 4096])); - socket.set_ack_delay(None); - socket.listen(connection.dst).unwrap(); - let handle = self.socketset.add(socket); - - let socket = if server.is_ipv4() { - MioTcp::new_v4().unwrap() - } else { - MioTcp::new_v6().unwrap() - }; - let client = socket.connect(server).unwrap(); - - let token = Token(self.next_token); - self.next_token += 1; - - let mut conn = ConnectionState { - smoltcp_handle: handle, - mio_stream: client, - token - }; - - self.token_to_connection.insert(token, connection); - self.poll.registry().register(&mut conn.mio_stream, token, Interest::READABLE | Interest::WRITABLE).unwrap(); - - self.connections.insert(connection, conn); - for manager in self.connection_managers.iter_mut() { if let Some(handler) = manager.new_connection(&connection) { - self.managers.insert(connection, handler); + let mut socket = TcpSocket::new( + TcpSocketBuffer::new(vec![0; 4096]), + TcpSocketBuffer::new(vec![0; 4096])); + socket.set_ack_delay(None); + socket.listen(connection.dst).unwrap(); + let handle = self.socketset.add(socket); + + let socket = if server.is_ipv4() { + MioTcp::new_v4().unwrap() + } else { + MioTcp::new_v6().unwrap() + }; + let client = socket.connect(server).unwrap(); + + let token = Token(self.next_token); + self.next_token += 1; + + let mut state = ConnectionState { + smoltcp_handle: handle, + mio_stream: client, + token, + handler + }; + + self.token_to_connection.insert(token, connection); + self.poll.registry().register(&mut state.mio_stream, token, Interest::READABLE | Interest::WRITABLE).unwrap(); + + self.connections.insert(connection, state); + + + + + println!("[{:?}] CONNECT {}", chrono::offset::Local::now(), connection); break; } } - - - println!("[{:?}] CONNECT {}", chrono::offset::Local::now(), connection); } else if !self.connections.contains_key(&connection) { return; } @@ -377,16 +377,15 @@ impl<'a> TunToProxy<'a> { } fn write_to_server(&mut self, connection: &Connection) { - if let Some(handler) = self.managers.get_mut(&connection) { - let event = handler.peek_data(OutgoingDirection::ToServer); + if let Some(state) = self.connections.get_mut(&connection) { + let event = state.handler.peek_data(OutgoingDirection::ToServer); if event.buffer.len() == 0 { return; } - let connection_state = self.connections.get_mut(&connection).unwrap(); - let result = connection_state.mio_stream.write(event.buffer); + let result = state.mio_stream.write(event.buffer); match result { Ok(consumed) => { - handler.consume_data(OutgoingDirection::ToServer, consumed); + state.handler.consume_data(OutgoingDirection::ToServer, consumed); } Err(error) if error.kind() != std::io::ErrorKind::WouldBlock => { panic!("Error: {:?}", error); @@ -399,12 +398,12 @@ impl<'a> TunToProxy<'a> { } fn write_to_client(&mut self, connection: &Connection) { - if let Some(handler) = self.managers.get_mut(&connection) { - let event = handler.peek_data(OutgoingDirection::ToClient); - let socket = &mut self.socketset.get::(self.connections.get(&connection).unwrap().smoltcp_handle); + if let Some(state) = self.connections.get_mut(&connection) { + let event = state.handler.peek_data(OutgoingDirection::ToClient); + let socket = &mut self.socketset.get::(state.smoltcp_handle); if socket.may_send() { let consumed = socket.send_slice(event.buffer).unwrap(); - handler.consume_data(OutgoingDirection::ToClient, consumed); + state.handler.consume_data(OutgoingDirection::ToClient, consumed); } } } @@ -427,7 +426,6 @@ impl<'a> TunToProxy<'a> { if event.is_readable() { { - let conn = self.managers.get_mut(&connection).unwrap(); let state = self.connections.get_mut(&connection).unwrap(); let mut buf = [0u8; 4096]; @@ -448,7 +446,7 @@ impl<'a> TunToProxy<'a> { buffer: &buf[0..read], }; - if let Err(error) = conn.push_data(event) { + if let Err(error) = state.handler.push_data(event) { state.mio_stream.shutdown(Both).unwrap(); { let mut socket = self.socketset.get::(self.connections.get(&connection).unwrap().smoltcp_handle);