Refine code logic

This commit is contained in:
ssrlive 2023-08-05 15:52:32 +08:00 committed by GitHub
parent 8d835dc96d
commit 1031f586f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 435 additions and 448 deletions

1
.gitignore vendored
View file

@ -1,3 +1,4 @@
examples/
build/ build/
tmp/ tmp/
.* .*

View file

@ -27,6 +27,9 @@ pub enum Error {
#[error("smoltcp::socket::tcp::SendError {0:?}")] #[error("smoltcp::socket::tcp::SendError {0:?}")]
Send(#[from] smoltcp::socket::tcp::SendError), Send(#[from] smoltcp::socket::tcp::SendError),
#[error("smoltcp::wire::Error {0:?}")]
Wire(#[from] smoltcp::wire::Error),
#[error("std::str::Utf8Error {0:?}")] #[error("std::str::Utf8Error {0:?}")]
Utf8(#[from] std::str::Utf8Error), Utf8(#[from] std::str::Utf8Error),

View file

@ -1,7 +1,7 @@
use crate::{ use crate::{
error::Error, error::Error,
tun2proxy::{ tun2proxy::{
Connection, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, ConnectionInfo, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection,
OutgoingDataEvent, OutgoingDirection, TcpProxy, OutgoingDataEvent, OutgoingDirection, TcpProxy,
}, },
}; };
@ -63,8 +63,8 @@ static CONTENT_LENGTH: &str = "Content-Length";
impl HttpConnection { impl HttpConnection {
fn new( fn new(
connection: &Connection, info: &ConnectionInfo,
manager: Rc<dyn ConnectionManager>, credentials: Option<UserKey>,
digest_state: Rc<RefCell<Option<DigestState>>>, digest_state: Rc<RefCell<Option<DigestState>>>,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let mut res = Self { let mut res = Self {
@ -79,8 +79,8 @@ impl HttpConnection {
crlf_state: 0, crlf_state: 0,
digest_state, digest_state,
before: false, before: false,
credentials: manager.get_credentials().clone(), credentials,
destination: connection.dst.clone(), destination: info.dst.clone(),
}; };
res.send_tunnel_request()?; res.send_tunnel_request()?;
@ -394,28 +394,24 @@ pub(crate) struct HttpManager {
} }
impl ConnectionManager for HttpManager { impl ConnectionManager for HttpManager {
fn handles_connection(&self, connection: &Connection) -> bool { fn handles_connection(&self, info: &ConnectionInfo) -> bool {
connection.proto == IpProtocol::Tcp info.protocol == IpProtocol::Tcp
} }
fn new_connection( fn new_tcp_proxy(&self, info: &ConnectionInfo) -> Result<Box<dyn TcpProxy>, Error> {
&self, if info.protocol != IpProtocol::Tcp {
connection: &Connection, return Err("Invalid protocol".into());
manager: Rc<dyn ConnectionManager>,
) -> Result<Option<Box<dyn TcpProxy>>, Error> {
if connection.proto != IpProtocol::Tcp {
return Ok(None);
} }
Ok(Some(Box::new(HttpConnection::new( Ok(Box::new(HttpConnection::new(
connection, info,
manager, self.credentials.clone(),
self.digest_state.clone(), self.digest_state.clone(),
)?))) )?))
} }
fn close_connection(&self, _: &Connection) {} fn close_connection(&self, _: &ConnectionInfo) {}
fn get_server(&self) -> SocketAddr { fn get_server_addr(&self) -> SocketAddr {
self.server self.server
} }
@ -425,11 +421,11 @@ impl ConnectionManager for HttpManager {
} }
impl HttpManager { impl HttpManager {
pub fn new(server: SocketAddr, credentials: Option<UserKey>) -> Rc<Self> { pub fn new(server: SocketAddr, credentials: Option<UserKey>) -> Self {
Rc::new(Self { Self {
server, server,
credentials, credentials,
digest_state: Rc::new(RefCell::new(None)), digest_state: Rc::new(RefCell::new(None)),
}) }
} }
} }

View file

@ -1,9 +1,10 @@
use crate::{ use crate::{error::Error, http::HttpManager, socks::SocksProxyManager, tun2proxy::TunToProxy};
error::Error, http::HttpManager, socks::SocksManager, socks::SocksVersion, use socks5_impl::protocol::{UserKey, Version};
tun2proxy::TunToProxy, use std::{
net::{SocketAddr, ToSocketAddrs},
rc::Rc,
}; };
use socks5_impl::protocol::UserKey; use tun2proxy::ConnectionManager;
use std::net::{SocketAddr, ToSocketAddrs};
mod android; mod android;
pub mod error; pub mod error;
@ -90,7 +91,7 @@ impl std::fmt::Display for ProxyType {
#[derive(Default)] #[derive(Default)]
pub struct Options { pub struct Options {
virtdns: Option<virtdns::VirtualDns>, virtual_dns: Option<virtdns::VirtualDns>,
mtu: Option<usize>, mtu: Option<usize>,
} }
@ -100,7 +101,7 @@ impl Options {
} }
pub fn with_virtual_dns(mut self) -> Self { pub fn with_virtual_dns(mut self) -> Self {
self.virtdns = Some(virtdns::VirtualDns::new()); self.virtual_dns = Some(virtdns::VirtualDns::new());
self self
} }
@ -116,25 +117,18 @@ pub fn tun_to_proxy<'a>(
options: Options, options: Options,
) -> Result<TunToProxy<'a>, Error> { ) -> Result<TunToProxy<'a>, Error> {
let mut ttp = TunToProxy::new(interface, options)?; let mut ttp = TunToProxy::new(interface, options)?;
match proxy.proxy_type { let credentials = proxy.credentials.clone();
ProxyType::Socks4 => { let server = proxy.addr;
ttp.add_connection_manager(SocksManager::new( let mgr = match proxy.proxy_type {
proxy.addr, ProxyType::Socks4 => Rc::new(SocksProxyManager::new(server, Version::V4, credentials))
SocksVersion::V4, as Rc<dyn ConnectionManager>,
proxy.credentials.clone(), ProxyType::Socks5 => Rc::new(SocksProxyManager::new(server, Version::V5, credentials))
)); as Rc<dyn ConnectionManager>,
}
ProxyType::Socks5 => {
ttp.add_connection_manager(SocksManager::new(
proxy.addr,
SocksVersion::V5,
proxy.credentials.clone(),
));
}
ProxyType::Http => { ProxyType::Http => {
ttp.add_connection_manager(HttpManager::new(proxy.addr, proxy.credentials.clone())); Rc::new(HttpManager::new(server, credentials)) as Rc<dyn ConnectionManager>
}
} }
};
ttp.add_connection_manager(mgr);
Ok(ttp) Ok(ttp)
} }
@ -143,6 +137,7 @@ pub fn main_entry(
proxy: &Proxy, proxy: &Proxy,
options: Options, options: Options,
) -> Result<(), Error> { ) -> Result<(), Error> {
let ttp = tun_to_proxy(interface, proxy, options); let mut ttp = tun_to_proxy(interface, proxy, options)?;
ttp?.run() ttp.run()?;
Ok(())
} }

View file

@ -1,15 +1,15 @@
use crate::{ use crate::{
error::Error, error::Error,
tun2proxy::{ tun2proxy::{
Connection, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, ConnectionInfo, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection,
OutgoingDataEvent, OutgoingDirection, TcpProxy, OutgoingDataEvent, OutgoingDirection, TcpProxy,
}, },
}; };
use smoltcp::wire::IpProtocol; use smoltcp::wire::IpProtocol;
use socks5_impl::protocol::{ use socks5_impl::protocol::{
self, handshake, password_method, Address, AuthMethod, StreamOperation, UserKey, self, handshake, password_method, Address, AuthMethod, StreamOperation, UserKey, Version,
}; };
use std::{collections::VecDeque, net::SocketAddr, rc::Rc}; use std::{collections::VecDeque, net::SocketAddr};
#[derive(Eq, PartialEq, Debug)] #[derive(Eq, PartialEq, Debug)]
#[allow(dead_code)] #[allow(dead_code)]
@ -23,33 +23,28 @@ enum SocksState {
Established, Established,
} }
#[repr(u8)] struct SocksProxyImpl {
#[derive(Copy, Clone, PartialEq, Debug)] info: ConnectionInfo,
pub enum SocksVersion {
V4 = 4,
V5 = 5,
}
pub(crate) struct SocksConnection {
connection: Connection,
state: SocksState, state: SocksState,
client_inbuf: VecDeque<u8>, client_inbuf: VecDeque<u8>,
server_inbuf: VecDeque<u8>, server_inbuf: VecDeque<u8>,
client_outbuf: VecDeque<u8>, client_outbuf: VecDeque<u8>,
server_outbuf: VecDeque<u8>, server_outbuf: VecDeque<u8>,
data_buf: VecDeque<u8>, data_buf: VecDeque<u8>,
version: SocksVersion, version: Version,
credentials: Option<UserKey>, credentials: Option<UserKey>,
command: protocol::Command,
udp_relay_addr: Option<Address>,
} }
impl SocksConnection { impl SocksProxyImpl {
pub fn new( pub fn new(
connection: &Connection, info: &ConnectionInfo,
manager: Rc<dyn ConnectionManager>, credentials: Option<UserKey>,
version: SocksVersion, version: Version,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let mut result = Self { let mut result = Self {
connection: connection.clone(), info: info.clone(),
state: SocksState::ServerHello, state: SocksState::ServerHello,
client_inbuf: VecDeque::default(), client_inbuf: VecDeque::default(),
server_inbuf: VecDeque::default(), server_inbuf: VecDeque::default(),
@ -57,23 +52,23 @@ impl SocksConnection {
server_outbuf: VecDeque::default(), server_outbuf: VecDeque::default(),
data_buf: VecDeque::default(), data_buf: VecDeque::default(),
version, version,
credentials: manager.get_credentials().clone(), credentials,
command: protocol::Command::Connect,
udp_relay_addr: None,
}; };
result.send_client_hello()?; result.send_client_hello()?;
Ok(result) Ok(result)
} }
fn send_client_hello(&mut self) -> Result<(), Error> { fn send_client_hello_socks4(&mut self) -> Result<(), Error> {
let credentials = &self.credentials; let credentials = &self.credentials;
match self.version {
SocksVersion::V4 => {
self.server_outbuf self.server_outbuf
.extend(&[self.version as u8, protocol::Command::Connect.into()]); .extend(&[self.version as u8, protocol::Command::Connect.into()]);
self.server_outbuf self.server_outbuf
.extend(self.connection.dst.port().to_be_bytes()); .extend(self.info.dst.port().to_be_bytes());
let mut ip_vec = Vec::<u8>::new(); let mut ip_vec = Vec::<u8>::new();
let mut name_vec = Vec::<u8>::new(); let mut name_vec = Vec::<u8>::new();
match &self.connection.dst { match &self.info.dst {
Address::SocketAddress(SocketAddr::V4(addr)) => { Address::SocketAddress(SocketAddr::V4(addr)) => {
ip_vec.extend(addr.ip().octets().as_ref()); ip_vec.extend(addr.ip().octets().as_ref());
} }
@ -96,9 +91,11 @@ impl SocksConnection {
} }
self.server_outbuf.push_back(0); self.server_outbuf.push_back(0);
self.server_outbuf.extend(name_vec); self.server_outbuf.extend(name_vec);
Ok(())
} }
SocksVersion::V5 => { fn send_client_hello_socks5(&mut self) -> Result<(), Error> {
let credentials = &self.credentials;
// Providing unassigned methods is supposed to bypass China's GFW. // Providing unassigned methods is supposed to bypass China's GFW.
// For details, refer to https://github.com/blechschmidt/tun2proxy/issues/35. // For details, refer to https://github.com/blechschmidt/tun2proxy/issues/35.
let mut methods = vec![ let mut methods = vec![
@ -110,6 +107,16 @@ impl SocksConnection {
methods.push(AuthMethod::UserPass); methods.push(AuthMethod::UserPass);
} }
handshake::Request::new(methods).write_to_stream(&mut self.server_outbuf)?; handshake::Request::new(methods).write_to_stream(&mut self.server_outbuf)?;
Ok(())
}
fn send_client_hello(&mut self) -> Result<(), Error> {
match self.version {
Version::V4 => {
self.send_client_hello_socks4()?;
}
Version::V5 => {
self.send_client_hello_socks5()?;
} }
} }
self.state = SocksState::ServerHello; self.state = SocksState::ServerHello;
@ -164,8 +171,8 @@ impl SocksConnection {
fn receive_server_hello(&mut self) -> Result<(), Error> { fn receive_server_hello(&mut self) -> Result<(), Error> {
match self.version { match self.version {
SocksVersion::V4 => self.receive_server_hello_socks4(), Version::V4 => self.receive_server_hello_socks4(),
SocksVersion::V5 => self.receive_server_hello_socks5(), Version::V5 => self.receive_server_hello_socks5(),
} }
} }
@ -213,6 +220,12 @@ impl SocksConnection {
if response.reply != protocol::Reply::Succeeded { if response.reply != protocol::Reply::Succeeded {
return Err(format!("SOCKS connection failed: {}", response.reply).into()); return Err(format!("SOCKS connection failed: {}", response.reply).into());
} }
if self.command == protocol::Command::UdpAssociate {
log::info!("UDP packet destination: {}", response.address);
self.udp_relay_addr = Some(response.address);
}
self.server_outbuf.append(&mut self.data_buf); self.server_outbuf.append(&mut self.data_buf);
self.data_buf.clear(); self.data_buf.clear();
@ -220,8 +233,9 @@ impl SocksConnection {
self.state_change() self.state_change()
} }
fn send_request(&mut self) -> Result<(), Error> { fn send_request_socks5(&mut self) -> Result<(), Error> {
protocol::Request::new(protocol::Command::Connect, self.connection.dst.clone()) // self.server_outbuf.extend(&[self.version as u8, self.command as u8, 0]);
protocol::Request::new(protocol::Command::Connect, self.info.dst.clone())
.write_to_stream(&mut self.server_outbuf)?; .write_to_stream(&mut self.server_outbuf)?;
self.state = SocksState::ReceiveResponse; self.state = SocksState::ReceiveResponse;
self.state_change() self.state_change()
@ -243,7 +257,7 @@ impl SocksConnection {
SocksState::ReceiveAuthResponse => self.receive_auth_data(), SocksState::ReceiveAuthResponse => self.receive_auth_data(),
SocksState::SendRequest => self.send_request(), SocksState::SendRequest => self.send_request_socks5(),
SocksState::ReceiveResponse => self.receive_connection_status(), SocksState::ReceiveResponse => self.receive_connection_status(),
@ -254,7 +268,7 @@ impl SocksConnection {
} }
} }
impl TcpProxy for SocksConnection { impl TcpProxy for SocksProxyImpl {
fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), Error> { fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), Error> {
let direction = event.direction; let direction = event.direction;
let buffer = event.buffer; let buffer = event.buffer;
@ -319,35 +333,31 @@ impl TcpProxy for SocksConnection {
} }
} }
pub struct SocksManager { pub(crate) struct SocksProxyManager {
server: SocketAddr, server: SocketAddr,
credentials: Option<UserKey>, credentials: Option<UserKey>,
version: SocksVersion, version: Version,
} }
impl ConnectionManager for SocksManager { impl ConnectionManager for SocksProxyManager {
fn handles_connection(&self, connection: &Connection) -> bool { fn handles_connection(&self, info: &ConnectionInfo) -> bool {
connection.proto == IpProtocol::Tcp info.protocol == IpProtocol::Tcp
} }
fn new_connection( fn new_tcp_proxy(&self, info: &ConnectionInfo) -> Result<Box<dyn TcpProxy>, Error> {
&self, if info.protocol != IpProtocol::Tcp {
connection: &Connection, return Err("Invalid protocol".into());
manager: Rc<dyn ConnectionManager>,
) -> Result<Option<Box<dyn TcpProxy>>, Error> {
if connection.proto != IpProtocol::Tcp {
return Ok(None);
} }
Ok(Some(Box::new(SocksConnection::new( Ok(Box::new(SocksProxyImpl::new(
connection, info,
manager, self.credentials.clone(),
self.version, self.version,
)?))) )?))
} }
fn close_connection(&self, _: &Connection) {} fn close_connection(&self, _: &ConnectionInfo) {}
fn get_server(&self) -> SocketAddr { fn get_server_addr(&self) -> SocketAddr {
self.server self.server
} }
@ -356,16 +366,12 @@ impl ConnectionManager for SocksManager {
} }
} }
impl SocksManager { impl SocksProxyManager {
pub fn new( pub(crate) fn new(server: SocketAddr, version: Version, credentials: Option<UserKey>) -> Self {
server: SocketAddr, Self {
version: SocksVersion,
credentials: Option<UserKey>,
) -> Rc<Self> {
Rc::new(Self {
server, server,
credentials, credentials,
version, version,
}) }
} }
} }

View file

@ -1,11 +1,11 @@
use crate::{error::Error, virtdevice::VirtualTunDevice, NetworkInterface, Options}; use crate::{error::Error, error::Result, virtdevice::VirtualTunDevice, NetworkInterface, Options};
use mio::{event::Event, net::TcpStream, unix::SourceFd, Events, Interest, Poll, Token}; use mio::{event::Event, net::TcpStream, unix::SourceFd, Events, Interest, Poll, Token};
use smoltcp::{ use smoltcp::{
iface::{Config, Interface, SocketHandle, SocketSet}, iface::{Config, Interface, SocketHandle, SocketSet},
phy::{Device, Medium, RxToken, TunTapInterface, TxToken}, phy::{Device, Medium, RxToken, TunTapInterface, TxToken},
socket::{tcp, tcp::State, udp, udp::UdpMetadata}, socket::{tcp, tcp::State, udp, udp::UdpMetadata},
time::Instant, time::Instant,
wire::{IpCidr, IpProtocol, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket}, wire::{IpCidr, IpProtocol, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket, UDP_HEADER_LEN},
}; };
use socks5_impl::protocol::{Address, UserKey}; use socks5_impl::protocol::{Address, UserKey};
use std::{ use std::{
@ -19,24 +19,40 @@ use std::{
}; };
#[derive(Hash, Clone, Eq, PartialEq, Debug)] #[derive(Hash, Clone, Eq, PartialEq, Debug)]
pub(crate) struct Connection { pub(crate) struct ConnectionInfo {
pub(crate) src: SocketAddr, pub(crate) src: SocketAddr,
pub(crate) dst: Address, pub(crate) dst: Address,
pub(crate) proto: IpProtocol, pub(crate) protocol: IpProtocol,
}
impl Default for ConnectionInfo {
fn default() -> Self {
Self {
src: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
dst: Address::unspecified(),
protocol: IpProtocol::Tcp,
}
}
}
impl ConnectionInfo {
#[allow(dead_code)]
pub fn new(src: SocketAddr, dst: Address, protocol: IpProtocol) -> Self {
Self { src, dst, protocol }
} }
impl Connection {
fn to_named(&self, name: String) -> Self { fn to_named(&self, name: String) -> Self {
let mut result = self.clone(); let mut result = self.clone();
result.dst = Address::from((name, result.dst.port())); result.dst = Address::from((name, result.dst.port()));
log::trace!("Replace dst \"{}\" -> \"{}\"", self.dst, result.dst); // let p = self.protocol;
// log::trace!("{p} replace dst \"{}\" -> \"{}\"", self.dst, result.dst);
result result
} }
} }
impl std::fmt::Display for Connection { impl std::fmt::Display for ConnectionInfo {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{} -> {}", self.src, self.dst) write!(f, "{} {} -> {}", self.protocol, self.src, self.dst)
} }
} }
@ -60,10 +76,11 @@ pub(crate) enum Direction {
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) enum ConnectionEvent<'a> { pub(crate) enum ConnectionEvent<'a> {
NewConnection(&'a Connection), NewConnection(&'a ConnectionInfo),
ConnectionClosed(&'a Connection), ConnectionClosed(&'a ConnectionInfo),
} }
#[derive(Debug)]
pub(crate) struct DataEvent<'a, T> { pub(crate) struct DataEvent<'a, T> {
pub(crate) direction: T, pub(crate) direction: T,
pub(crate) buffer: &'a [u8], pub(crate) buffer: &'a [u8],
@ -73,95 +90,87 @@ pub(crate) type IncomingDataEvent<'a> = DataEvent<'a, IncomingDirection>;
pub(crate) type OutgoingDataEvent<'a> = DataEvent<'a, OutgoingDirection>; pub(crate) type OutgoingDataEvent<'a> = DataEvent<'a, OutgoingDirection>;
fn get_transport_info( fn get_transport_info(
proto: IpProtocol, protocol: IpProtocol,
transport_offset: usize, transport_offset: usize,
packet: &[u8], packet: &[u8],
) -> Option<((u16, u16), bool, usize, usize)> { ) -> Result<((u16, u16), bool, usize, usize)> {
match proto { match protocol {
IpProtocol::Udp => match UdpPacket::new_checked(packet) { IpProtocol::Udp => UdpPacket::new_checked(packet)
Ok(result) => Some(( .map(|result| {
(
(result.src_port(), result.dst_port()), (result.src_port(), result.dst_port()),
false, false,
transport_offset + 8, transport_offset + UDP_HEADER_LEN,
packet.len() - 8, packet.len() - UDP_HEADER_LEN,
)), )
Err(_) => None, })
}, .map_err(|e| e.into()),
IpProtocol::Tcp => match TcpPacket::new_checked(packet) { IpProtocol::Tcp => TcpPacket::new_checked(packet)
Ok(result) => Some(( .map(|result| {
(
(result.src_port(), result.dst_port()), (result.src_port(), result.dst_port()),
result.syn() && !result.ack(), result.syn() && !result.ack(),
transport_offset + result.header_len() as usize, transport_offset + result.header_len() as usize,
packet.len(), packet.len(),
)), )
Err(_) => None, })
}, .map_err(|e| e.into()),
_ => None, _ => Err(format!("Unsupported protocol {protocol}").into()),
} }
} }
fn connection_tuple(frame: &[u8]) -> Option<(Connection, bool, usize, usize)> { fn connection_tuple(frame: &[u8]) -> Result<(ConnectionInfo, bool, usize, usize)> {
if let Ok(packet) = Ipv4Packet::new_checked(frame) { if let Ok(packet) = Ipv4Packet::new_checked(frame) {
let proto = packet.next_header(); let protocol = packet.next_header();
let mut a = [0_u8; 4]; let mut a = [0_u8; 4];
a.copy_from_slice(packet.src_addr().as_bytes()); a.copy_from_slice(packet.src_addr().as_bytes());
let src_addr = IpAddr::from(a); let src_addr = IpAddr::from(a);
a.copy_from_slice(packet.dst_addr().as_bytes()); a.copy_from_slice(packet.dst_addr().as_bytes());
let dst_addr = IpAddr::from(a); let dst_addr = IpAddr::from(a);
let header_len = packet.header_len().into();
return if let Some((ports, first_packet, payload_offset, payload_size)) = get_transport_info( let (ports, first_packet, payload_offset, payload_size) =
proto, get_transport_info(protocol, header_len, &frame[header_len..])?;
packet.header_len().into(), let info = ConnectionInfo {
&frame[packet.header_len().into()..],
) {
let connection = Connection {
src: SocketAddr::new(src_addr, ports.0), src: SocketAddr::new(src_addr, ports.0),
dst: SocketAddr::new(dst_addr, ports.1).into(), dst: SocketAddr::new(dst_addr, ports.1).into(),
proto, protocol,
};
Some((connection, first_packet, payload_offset, payload_size))
} else {
None
}; };
return Ok((info, first_packet, payload_offset, payload_size));
} }
match Ipv6Packet::new_checked(frame) { if let Ok(packet) = Ipv6Packet::new_checked(frame) {
Ok(packet) => {
// TODO: Support extension headers. // TODO: Support extension headers.
let proto = packet.next_header(); let protocol = packet.next_header();
let mut a = [0_u8; 16]; let mut a = [0_u8; 16];
a.copy_from_slice(packet.src_addr().as_bytes()); a.copy_from_slice(packet.src_addr().as_bytes());
let src_addr = IpAddr::from(a); let src_addr = IpAddr::from(a);
a.copy_from_slice(packet.dst_addr().as_bytes()); a.copy_from_slice(packet.dst_addr().as_bytes());
let dst_addr = IpAddr::from(a); let dst_addr = IpAddr::from(a);
let header_len = packet.header_len();
if let Some((ports, first_packet, payload_offset, payload_size)) = let (ports, first_packet, payload_offset, payload_size) =
get_transport_info(proto, packet.header_len(), &frame[packet.header_len()..]) get_transport_info(protocol, header_len, &frame[header_len..])?;
{ let info = ConnectionInfo {
let connection = Connection {
src: SocketAddr::new(src_addr, ports.0), src: SocketAddr::new(src_addr, ports.0),
dst: SocketAddr::new(dst_addr, ports.1).into(), dst: SocketAddr::new(dst_addr, ports.1).into(),
proto, protocol,
}; };
Some((connection, first_packet, payload_offset, payload_size)) return Ok((info, first_packet, payload_offset, payload_size));
} else {
None
}
}
_ => None,
} }
Err("Neither IPv6 nor IPv4 packet".into())
} }
const SERVER_WRITE_CLOSED: u8 = 1; const SERVER_WRITE_CLOSED: u8 = 1;
const CLIENT_WRITE_CLOSED: u8 = 2; const CLIENT_WRITE_CLOSED: u8 = 2;
struct ConnectionState { struct TcpConnectState {
smoltcp_handle: SocketHandle, smoltcp_handle: Option<SocketHandle>,
mio_stream: TcpStream, mio_stream: TcpStream,
token: Token, token: Token,
handler: Box<dyn TcpProxy>, tcp_proxy_handler: Box<dyn TcpProxy>,
close_state: u8, close_state: u8,
wait_read: bool, wait_read: bool,
wait_write: bool, wait_write: bool,
@ -176,30 +185,30 @@ pub(crate) trait TcpProxy {
fn reset_connection(&self) -> bool; fn reset_connection(&self) -> bool;
} }
pub(crate) trait UdpProxy {
fn send_frame(&mut self, destination: &Address, frame: &[u8]) -> Result<(), Error>;
fn receive_frame(&mut self, source: &SocketAddr, frame: &[u8]) -> Result<(), Error>;
}
pub(crate) trait ConnectionManager { pub(crate) trait ConnectionManager {
fn handles_connection(&self, connection: &Connection) -> bool; fn handles_connection(&self, info: &ConnectionInfo) -> bool;
fn new_connection( fn new_tcp_proxy(&self, info: &ConnectionInfo) -> Result<Box<dyn TcpProxy>, Error>;
&self, fn close_connection(&self, info: &ConnectionInfo);
connection: &Connection, fn get_server_addr(&self) -> SocketAddr;
manager: Rc<dyn ConnectionManager>,
) -> Result<Option<Box<dyn TcpProxy>>, Error>;
fn close_connection(&self, connection: &Connection);
fn get_server(&self) -> SocketAddr;
fn get_credentials(&self) -> &Option<UserKey>; fn get_credentials(&self) -> &Option<UserKey>;
} }
const TUN_TOKEN: Token = Token(0); const TUN_TOKEN: Token = Token(0);
const UDP_TOKEN: Token = Token(1);
const EXIT_TOKEN: Token = Token(2); const EXIT_TOKEN: Token = Token(2);
pub struct TunToProxy<'a> { pub struct TunToProxy<'a> {
tun: TunTapInterface, tun: TunTapInterface,
poll: Poll, poll: Poll,
iface: Interface, iface: Interface,
connections: HashMap<Connection, ConnectionState>, connection_map: HashMap<ConnectionInfo, TcpConnectState>,
connection_managers: Vec<Rc<dyn ConnectionManager>>, connection_managers: Vec<Rc<dyn ConnectionManager>>,
next_token: usize, next_token: usize,
token_to_connection: HashMap<Token, Connection>, token_to_info: HashMap<Token, ConnectionInfo>,
sockets: SocketSet<'a>, sockets: SocketSet<'a>,
device: VirtualTunDevice, device: VirtualTunDevice,
options: Options, options: Options,
@ -234,10 +243,10 @@ impl<'a> TunToProxy<'a> {
Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip),
Medium::Ieee802154 => todo!(), Medium::Ieee802154 => todo!(),
}; };
let mut virt = VirtualTunDevice::new(tun.capabilities()); let mut device = VirtualTunDevice::new(tun.capabilities());
let gateway4: Ipv4Addr = Ipv4Addr::from_str("0.0.0.1")?; let gateway4: Ipv4Addr = Ipv4Addr::from_str("0.0.0.1")?;
let gateway6: Ipv6Addr = Ipv6Addr::from_str("::1")?; let gateway6: Ipv6Addr = Ipv6Addr::from_str("::1")?;
let mut iface = Interface::new(config, &mut virt, Instant::now()); let mut iface = Interface::new(config, &mut device, Instant::now());
iface.update_ip_addrs(|ip_addrs| { iface.update_ip_addrs(|ip_addrs| {
ip_addrs.push(IpCidr::new(gateway4.into(), 0)).unwrap(); ip_addrs.push(IpCidr::new(gateway4.into(), 0)).unwrap();
ip_addrs.push(IpCidr::new(gateway6.into(), 0)).unwrap() ip_addrs.push(IpCidr::new(gateway6.into(), 0)).unwrap()
@ -250,12 +259,12 @@ impl<'a> TunToProxy<'a> {
tun, tun,
poll, poll,
iface, iface,
connections: HashMap::default(), connection_map: HashMap::default(),
next_token: usize::from(EXIT_TOKEN) + 1, next_token: usize::from(EXIT_TOKEN) + 1,
token_to_connection: HashMap::default(), token_to_info: HashMap::default(),
connection_managers: Vec::default(), connection_managers: Vec::default(),
sockets: SocketSet::new([]), sockets: SocketSet::new([]),
device: virt, device,
options, options,
write_sockets: HashSet::default(), write_sockets: HashSet::default(),
_exit_receiver: exit_receiver, _exit_receiver: exit_receiver,
@ -292,28 +301,34 @@ impl<'a> TunToProxy<'a> {
Ok(()) Ok(())
} }
fn remove_connection(&mut self, connection: &Connection) -> Result<(), Error> { fn remove_connection(&mut self, info: &ConnectionInfo) -> Result<(), Error> {
if let Some(mut conn) = self.connections.remove(connection) { if let Some(mut conn) = self.connection_map.remove(info) {
_ = conn.mio_stream.shutdown(Both);
if let Some(handle) = conn.smoltcp_handle {
let socket = self.sockets.get_mut::<tcp::Socket>(handle);
socket.close();
self.sockets.remove(handle);
}
self.expect_smoltcp_send()?;
let token = &conn.token; let token = &conn.token;
self.token_to_connection.remove(token); self.token_to_info.remove(token);
self.sockets.remove(conn.smoltcp_handle);
_ = self.poll.registry().deregister(&mut conn.mio_stream); _ = self.poll.registry().deregister(&mut conn.mio_stream);
log::info!("CLOSE {}", connection); log::info!("CLOSE {}", info);
} }
Ok(()) Ok(())
} }
fn get_connection_manager(&self, connection: &Connection) -> Option<Rc<dyn ConnectionManager>> { fn get_connection_manager(&self, info: &ConnectionInfo) -> Option<Rc<dyn ConnectionManager>> {
for manager in self.connection_managers.iter() { for manager in self.connection_managers.iter() {
if manager.handles_connection(connection) { if manager.handles_connection(info) {
return Some(manager.clone()); return Some(manager.clone());
} }
} }
None None
} }
fn check_change_close_state(&mut self, connection: &Connection) -> Result<(), Error> { fn check_change_close_state(&mut self, info: &ConnectionInfo) -> Result<(), Error> {
let state = self.connections.get_mut(connection); let state = self.connection_map.get_mut(info);
if state.is_none() { if state.is_none() {
return Ok(()); return Ok(());
} }
@ -321,23 +336,25 @@ impl<'a> TunToProxy<'a> {
let mut closed_ends = 0; let mut closed_ends = 0;
if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED
&& !state && !state
.handler .tcp_proxy_handler
.have_data(Direction::Incoming(IncomingDirection::FromServer)) .have_data(Direction::Incoming(IncomingDirection::FromServer))
&& !state && !state
.handler .tcp_proxy_handler
.have_data(Direction::Outgoing(OutgoingDirection::ToClient)) .have_data(Direction::Outgoing(OutgoingDirection::ToClient))
{ {
let socket = self.sockets.get_mut::<tcp::Socket>(state.smoltcp_handle); if let Some(socket_handle) = state.smoltcp_handle {
let socket = self.sockets.get_mut::<tcp::Socket>(socket_handle);
socket.close(); socket.close();
}
closed_ends += 1; closed_ends += 1;
} }
if (state.close_state & CLIENT_WRITE_CLOSED) == CLIENT_WRITE_CLOSED if (state.close_state & CLIENT_WRITE_CLOSED) == CLIENT_WRITE_CLOSED
&& !state && !state
.handler .tcp_proxy_handler
.have_data(Direction::Incoming(IncomingDirection::FromClient)) .have_data(Direction::Incoming(IncomingDirection::FromClient))
&& !state && !state
.handler .tcp_proxy_handler
.have_data(Direction::Outgoing(OutgoingDirection::ToServer)) .have_data(Direction::Outgoing(OutgoingDirection::ToServer))
{ {
_ = state.mio_stream.shutdown(Shutdown::Write); _ = state.mio_stream.shutdown(Shutdown::Write);
@ -345,20 +362,22 @@ impl<'a> TunToProxy<'a> {
} }
if closed_ends == 2 { if closed_ends == 2 {
self.remove_connection(connection)?; self.remove_connection(info)?;
} }
Ok(()) Ok(())
} }
fn tunsocket_read_and_forward(&mut self, connection: &Connection) -> Result<(), Error> { fn tunsocket_read_and_forward(&mut self, info: &ConnectionInfo) -> Result<(), Error> {
// Scope for mutable borrow of self. // Scope for mutable borrow of self.
{ {
let state = self.connections.get_mut(connection); let state = match self.connection_map.get_mut(info) {
if state.is_none() { Some(state) => state,
return Ok(()); None => return Ok(()),
} };
let state = state.unwrap(); let socket = match state.smoltcp_handle {
let socket = self.sockets.get_mut::<tcp::Socket>(state.smoltcp_handle); Some(handle) => self.sockets.get_mut::<tcp::Socket>(handle),
None => return Ok(()),
};
let mut error = Ok(()); let mut error = Ok(());
while socket.can_recv() && error.is_ok() { while socket.can_recv() && error.is_ok() {
socket.recv(|data| { socket.recv(|data| {
@ -366,7 +385,7 @@ impl<'a> TunToProxy<'a> {
direction: IncomingDirection::FromClient, direction: IncomingDirection::FromClient,
buffer: data, buffer: data,
}; };
error = state.handler.push_data(event); error = state.tcp_proxy_handler.push_data(event);
(data.len(), ()) (data.len(), ())
})?; })?;
} }
@ -385,20 +404,14 @@ impl<'a> TunToProxy<'a> {
self.expect_smoltcp_send()?; self.expect_smoltcp_send()?;
} }
self.check_change_close_state(connection)?; self.check_change_close_state(info)?;
Ok(()) Ok(())
} }
// Update the poll registry depending on the connection's event interests. fn update_mio_socket_interest(poll: &mut Poll, state: &mut TcpConnectState) -> Result<()> {
fn update_mio_socket_interest(&mut self, connection: &Connection) -> Result<(), Error> {
let state = self
.connections
.get_mut(connection)
.ok_or("connection not found")?;
// Maybe we did not listen for any events before. Therefore, just swallow the error. // Maybe we did not listen for any events before. Therefore, just swallow the error.
_ = self.poll.registry().deregister(&mut state.mio_stream); _ = poll.registry().deregister(&mut state.mio_stream);
// If we do not wait for read or write events, we do not need to register them. // If we do not wait for read or write events, we do not need to register them.
if !state.wait_read && !state.wait_write { if !state.wait_read && !state.wait_write {
@ -415,78 +428,65 @@ impl<'a> TunToProxy<'a> {
interest = Interest::READABLE | Interest::WRITABLE; interest = Interest::READABLE | Interest::WRITABLE;
} }
self.poll poll.registry()
.registry()
.register(&mut state.mio_stream, state.token, interest)?; .register(&mut state.mio_stream, state.token, interest)?;
Ok(()) Ok(())
} }
// A raw packet was received on the tunnel interface. // A raw packet was received on the tunnel interface.
fn receive_tun(&mut self, frame: &mut [u8]) -> Result<(), Error> { fn receive_tun(&mut self, frame: &mut [u8]) -> Result<(), Error> {
if let Some((connection, first_packet, offset, size)) = connection_tuple(frame) { let mut handler = || -> Result<(), Error> {
let resolved_conn = match &mut self.options.virtdns { let (info, first_packet, payload_offset, payload_size) = connection_tuple(frame)?;
None => connection.clone(), let dst = SocketAddr::try_from(&info.dst)?;
Some(virt_dns) => { let connection_info = match &mut self.options.virtual_dns {
let ip = SocketAddr::try_from(connection.dst.clone())?.ip(); None => info.clone(),
virt_dns.touch_ip(&ip); Some(virtual_dns) => {
match virt_dns.resolve_ip(&ip) { let dst_ip = dst.ip();
None => connection.clone(), virtual_dns.touch_ip(&dst_ip);
Some(name) => connection.to_named(name.clone()), match virtual_dns.resolve_ip(&dst_ip) {
None => info.clone(),
Some(name) => info.to_named(name.clone()),
} }
} }
}; };
let dst = connection.dst; log::trace!("{} ({})", connection_info, dst);
let handler = || -> Result<(), Error> { if connection_info.protocol == IpProtocol::Tcp {
if resolved_conn.proto == IpProtocol::Tcp { let server_addr = self
let cm = self.get_connection_manager(&resolved_conn); .get_connection_manager(&connection_info)
if cm.is_none() { .ok_or("get_connection_manager")?
log::trace!("no connect manager"); .get_server_addr();
return Ok(());
}
let server = cm.unwrap().get_server();
if first_packet { if first_packet {
for manager in self.connection_managers.iter_mut() { if let Some(manager) = self.connection_managers.iter_mut().next() {
if let Some(handler) = let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info)?;
manager.new_connection(&resolved_conn, manager.clone())?
{
let mut socket = tcp::Socket::new( let mut socket = tcp::Socket::new(
tcp::SocketBuffer::new(vec![0; 1024 * 128]), tcp::SocketBuffer::new(vec![0; 1024 * 128]),
tcp::SocketBuffer::new(vec![0; 1024 * 128]), tcp::SocketBuffer::new(vec![0; 1024 * 128]),
); );
socket.set_ack_delay(None); socket.set_ack_delay(None);
let dst = SocketAddr::try_from(dst)?;
socket.listen(dst)?; socket.listen(dst)?;
let handle = self.sockets.add(socket); let handle = self.sockets.add(socket);
let client = TcpStream::connect(server)?; let mut client = TcpStream::connect(server_addr)?;
let token = self.new_token(); let token = self.new_token();
let i = Interest::READABLE;
self.poll.registry().register(&mut client, token, i)?;
let mut state = ConnectionState { let state = TcpConnectState {
smoltcp_handle: handle, smoltcp_handle: Some(handle),
mio_stream: client, mio_stream: client,
token, token,
handler, tcp_proxy_handler,
close_state: 0, close_state: 0,
wait_read: true, wait_read: true,
wait_write: false, wait_write: false,
}; };
self.connection_map.insert(connection_info.clone(), state);
self.token_to_connection self.token_to_info.insert(token, connection_info.clone());
.insert(token, resolved_conn.clone());
self.poll.registry().register(
&mut state.mio_stream,
token,
Interest::READABLE,
)?;
self.connections.insert(resolved_conn.clone(), state); // log::info!("CONNECT {} ({})", connection_info, dst);
log::info!("CONNECT {}", resolved_conn,);
break;
} }
} } else if !self.connection_map.contains_key(&connection_info) {
} else if !self.connections.contains_key(&resolved_conn) {
return Ok(()); return Ok(());
} }
@ -499,31 +499,24 @@ impl<'a> TunToProxy<'a> {
self.expect_smoltcp_send()?; self.expect_smoltcp_send()?;
// Read from the smoltcp socket and push the data to the connection handler. // Read from the smoltcp socket and push the data to the connection handler.
self.tunsocket_read_and_forward(&resolved_conn)?; self.tunsocket_read_and_forward(&connection_info)?;
// The connection handler builds up the connection or encapsulates the data. // The connection handler builds up the connection or encapsulates the data.
// Therefore, we now expect it to write data to the server. // Therefore, we now expect it to write data to the server.
self.write_to_server(&resolved_conn)?; self.write_to_server(&connection_info)?;
} else if resolved_conn.proto == IpProtocol::Udp && resolved_conn.dst.port() == 53 { } else if connection_info.protocol == IpProtocol::Udp {
if let Some(virtual_dns) = &mut self.options.virtdns { let port = connection_info.dst.port();
let payload = &frame[offset..offset + size]; if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == 53) {
let payload = &frame[payload_offset..payload_offset + payload_size];
if let Some(response) = virtual_dns.receive_query(payload) { if let Some(response) = virtual_dns.receive_query(payload) {
let rx_buffer = udp::PacketBuffer::new( let rx_buffer =
vec![udp::PacketMetadata::EMPTY], udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]);
vec![0; 4096], let tx_buffer =
); udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]);
let tx_buffer = udp::PacketBuffer::new(
vec![udp::PacketMetadata::EMPTY],
vec![0; 4096],
);
let mut socket = udp::Socket::new(rx_buffer, tx_buffer); let mut socket = udp::Socket::new(rx_buffer, tx_buffer);
let dst = SocketAddr::try_from(dst)?;
socket.bind(dst)?; socket.bind(dst)?;
socket socket
.send_slice( .send_slice(response.as_slice(), UdpMetadata::from(connection_info.src))
response.as_slice(),
UdpMetadata::from(resolved_conn.src),
)
.expect("failed to send DNS response"); .expect("failed to send DNS response");
let handle = self.sockets.add(socket); let handle = self.sockets.add(socket);
self.expect_smoltcp_send()?; self.expect_smoltcp_send()?;
@ -537,28 +530,29 @@ impl<'a> TunToProxy<'a> {
if let Err(error) = handler() { if let Err(error) = handler() {
log::error!("{}", error); log::error!("{}", error);
} }
}
Ok(()) Ok(())
} }
fn write_to_server(&mut self, connection: &Connection) -> Result<(), Error> { fn write_to_server(&mut self, info: &ConnectionInfo) -> Result<(), Error> {
if let Some(state) = self.connections.get_mut(connection) { if let Some(state) = self.connection_map.get_mut(info) {
let event = state.handler.peek_data(OutgoingDirection::ToServer); let event = state
.tcp_proxy_handler
.peek_data(OutgoingDirection::ToServer);
let buffer_size = event.buffer.len(); let buffer_size = event.buffer.len();
if buffer_size == 0 { if buffer_size == 0 {
state.wait_write = false; state.wait_write = false;
self.update_mio_socket_interest(connection)?; Self::update_mio_socket_interest(&mut self.poll, state)?;
self.check_change_close_state(connection)?; self.check_change_close_state(info)?;
return Ok(()); return Ok(());
} }
let result = state.mio_stream.write(event.buffer); let result = state.mio_stream.write(event.buffer);
match result { match result {
Ok(written) => { Ok(written) => {
state state
.handler .tcp_proxy_handler
.consume_data(OutgoingDirection::ToServer, written); .consume_data(OutgoingDirection::ToServer, written);
state.wait_write = written < buffer_size; state.wait_write = written < buffer_size;
self.update_mio_socket_interest(connection)?; Self::update_mio_socket_interest(&mut self.poll, state)?;
} }
Err(error) if error.kind() != std::io::ErrorKind::WouldBlock => { Err(error) if error.kind() != std::io::ErrorKind::WouldBlock => {
return Err(error.into()); return Err(error.into());
@ -566,30 +560,35 @@ impl<'a> TunToProxy<'a> {
_ => { _ => {
// WOULDBLOCK case // WOULDBLOCK case
state.wait_write = true; state.wait_write = true;
self.update_mio_socket_interest(connection)?; Self::update_mio_socket_interest(&mut self.poll, state)?;
} }
} }
} }
self.check_change_close_state(connection)?; self.check_change_close_state(info)?;
Ok(()) Ok(())
} }
fn write_to_client(&mut self, token: Token, connection: &Connection) -> Result<(), Error> { fn write_to_client(&mut self, token: Token, info: &ConnectionInfo) -> Result<(), Error> {
while let Some(state) = self.connections.get_mut(connection) { while let Some(state) = self.connection_map.get_mut(info) {
let socket_handle = state.smoltcp_handle; let socket_handle = match state.smoltcp_handle {
let event = state.handler.peek_data(OutgoingDirection::ToClient); Some(handle) => handle,
None => break,
};
let event = state
.tcp_proxy_handler
.peek_data(OutgoingDirection::ToClient);
let buflen = event.buffer.len(); let buflen = event.buffer.len();
let consumed; let consumed;
{ {
let socket = self.sockets.get_mut::<tcp::Socket>(socket_handle); let socket = self.sockets.get_mut::<tcp::Socket>(socket_handle);
if socket.may_send() { if socket.may_send() {
if let Some(virtdns) = &mut self.options.virtdns { if let Some(virtual_dns) = &mut self.options.virtual_dns {
// Unwrapping is fine because every smoltcp socket is bound to an. // Unwrapping is fine because every smoltcp socket is bound to an.
virtdns.touch_ip(&IpAddr::from(socket.local_endpoint().unwrap().addr)); virtual_dns.touch_ip(&IpAddr::from(socket.local_endpoint().unwrap().addr));
} }
consumed = socket.send_slice(event.buffer)?; consumed = socket.send_slice(event.buffer)?;
state state
.handler .tcp_proxy_handler
.consume_data(OutgoingDirection::ToClient, consumed); .consume_data(OutgoingDirection::ToClient, consumed);
self.expect_smoltcp_send()?; self.expect_smoltcp_send()?;
if consumed < buflen { if consumed < buflen {
@ -606,7 +605,7 @@ impl<'a> TunToProxy<'a> {
} }
} }
self.check_change_close_state(connection)?; self.check_change_close_state(info)?;
} }
Ok(()) Ok(())
} }
@ -623,7 +622,7 @@ impl<'a> TunToProxy<'a> {
fn send_to_smoltcp(&mut self) -> Result<(), Error> { fn send_to_smoltcp(&mut self) -> Result<(), Error> {
let cloned = self.write_sockets.clone(); let cloned = self.write_sockets.clone();
for token in cloned.iter() { for token in cloned.iter() {
if let Some(connection) = self.token_to_connection.get(token) { if let Some(connection) = self.token_to_info.get(token) {
let connection = connection.clone(); let connection = connection.clone();
if let Err(error) = self.write_to_client(*token, &connection) { if let Err(error) = self.write_to_client(*token, &connection) {
self.remove_connection(&connection)?; self.remove_connection(&connection)?;
@ -636,24 +635,26 @@ impl<'a> TunToProxy<'a> {
fn mio_socket_event(&mut self, event: &Event) -> Result<(), Error> { fn mio_socket_event(&mut self, event: &Event) -> Result<(), Error> {
let e = "connection not found"; let e = "connection not found";
let conn_ref = self.token_to_connection.get(&event.token()); let conn_info = match self.token_to_info.get(&event.token()) {
// We may have closed the connection in an earlier iteration over the poll Some(conn_info) => conn_info.clone(),
// events, e.g. because an event through the tunnel interface indicated that the connection None => {
// We may have closed the connection in an earlier iteration over the poll events,
// e.g. because an event through the tunnel interface indicated that the connection
// should be closed. // should be closed.
if conn_ref.is_none() {
log::trace!("{e}"); log::trace!("{e}");
return Ok(()); return Ok(());
} }
let connection = conn_ref.unwrap().clone(); };
let server = self let server = self
.get_connection_manager(&connection) .get_connection_manager(&conn_info)
.unwrap() .ok_or(e)?
.get_server(); .get_server_addr();
let mut block = || -> Result<(), Error> { let mut block = || -> Result<(), Error> {
if event.is_readable() || event.is_read_closed() { if event.is_readable() || event.is_read_closed() {
{ {
let state = self.connections.get_mut(&connection).ok_or(e)?; let state = self.connection_map.get_mut(&conn_info).ok_or(e)?;
// TODO: Move this reading process to its own function. // TODO: Move this reading process to its own function.
let mut vecbuf = Vec::<u8>::new(); let mut vecbuf = Vec::<u8>::new();
@ -673,34 +674,26 @@ impl<'a> TunToProxy<'a> {
direction: IncomingDirection::FromServer, direction: IncomingDirection::FromServer,
buffer: &data[0..read], buffer: &data[0..read],
}; };
if let Err(error) = state.handler.push_data(data_event) { if let Err(error) = state.tcp_proxy_handler.push_data(data_event) {
state.mio_stream.shutdown(Both)?; log::error!("{}", error);
{ self.remove_connection(&conn_info.clone())?;
let socket = self.sockets.get_mut::<tcp::Socket>(
self.connections.get(&connection).ok_or(e)?.smoltcp_handle,
);
socket.close();
}
self.expect_smoltcp_send()?;
log::error! {"{error}"}
self.remove_connection(&connection.clone())?;
return Ok(()); return Ok(());
} }
// The handler request for reset the server connection // The handler request for reset the server connection
if state.handler.reset_connection() { if state.tcp_proxy_handler.reset_connection() {
_ = self.poll.registry().deregister(&mut state.mio_stream); _ = self.poll.registry().deregister(&mut state.mio_stream);
// Closes the connection with the proxy // Closes the connection with the proxy
state.mio_stream.shutdown(Both)?; state.mio_stream.shutdown(Both)?;
log::info!("RESET {}", connection); log::info!("RESET {}", conn_info);
state.mio_stream = TcpStream::connect(server)?; state.mio_stream = TcpStream::connect(server)?;
state.wait_read = true; state.wait_read = true;
state.wait_write = true; state.wait_write = true;
self.update_mio_socket_interest(&connection)?; Self::update_mio_socket_interest(&mut self.poll, state)?;
return Ok(()); return Ok(());
} }
@ -708,62 +701,55 @@ impl<'a> TunToProxy<'a> {
if read == 0 || event.is_read_closed() { if read == 0 || event.is_read_closed() {
state.wait_read = false; state.wait_read = false;
state.close_state |= SERVER_WRITE_CLOSED; state.close_state |= SERVER_WRITE_CLOSED;
self.update_mio_socket_interest(&connection)?; Self::update_mio_socket_interest(&mut self.poll, state)?;
self.check_change_close_state(&connection)?; self.check_change_close_state(&conn_info)?;
self.expect_smoltcp_send()?; self.expect_smoltcp_send()?;
} }
} }
// We have read from the proxy server and pushed the data to the connection handler. // We have read from the proxy server and pushed the data to the connection handler.
// Thus, expect data to be processed (e.g. decapsulated) and forwarded to the client. // Thus, expect data to be processed (e.g. decapsulated) and forwarded to the client.
self.write_to_client(event.token(), &connection)?; self.write_to_client(event.token(), &conn_info)?;
// The connection handler could have produced data that is to be written to the // The connection handler could have produced data that is to be written to the
// server. // server.
self.write_to_server(&connection)?; self.write_to_server(&conn_info)?;
} }
if event.is_writable() { if event.is_writable() {
self.write_to_server(&connection)?; self.write_to_server(&conn_info)?;
} }
Ok::<(), Error>(()) Ok::<(), Error>(())
}; };
if let Err(error) = block() { if let Err(error) = block() {
log::error!("{}", error); log::error!("{}", error);
self.remove_connection(&connection)?; self.remove_connection(&conn_info)?;
} }
Ok(()) Ok(())
} }
fn udp_event(&mut self, _event: &Event) {}
pub fn run(&mut self) -> Result<(), Error> { pub fn run(&mut self) -> Result<(), Error> {
let mut events = Events::with_capacity(1024); let mut events = Events::with_capacity(1024);
loop { loop {
match self.poll.poll(&mut events, None) { if let Err(err) = self.poll.poll(&mut events, None) {
Ok(()) => { if err.kind() == std::io::ErrorKind::Interrupted {
log::warn!("Poll interrupted: \"{err}\", ignored, continue polling");
continue;
}
return Err(err.into());
}
for event in events.iter() { for event in events.iter() {
match event.token() { match event.token() {
EXIT_TOKEN => { EXIT_TOKEN => {
log::info!("exiting..."); log::info!("Exiting tun2proxy...");
return Ok(()); return Ok(());
} }
TUN_TOKEN => self.tun_event(event)?, TUN_TOKEN => self.tun_event(event)?,
UDP_TOKEN => self.udp_event(event),
_ => self.mio_socket_event(event)?, _ => self.mio_socket_event(event)?,
} }
} }
self.send_to_smoltcp()?; self.send_to_smoltcp()?;
} }
Err(e) => {
if e.kind() == std::io::ErrorKind::Interrupted {
log::warn!("Poll interrupted: \"{e}\", ignored, continue polling");
} else {
return Err(e.into());
}
}
}
}
} }
pub fn shutdown(&mut self) -> Result<(), Error> { pub fn shutdown(&mut self) -> Result<(), Error> {