Improve error handling

This commit is contained in:
B. Blechschmidt 2021-09-02 21:02:17 +02:00
parent 6607df83cd
commit 5b6ac8b206
4 changed files with 159 additions and 73 deletions

View file

@ -1,4 +1,4 @@
use crate::tun2proxy::{Connection, TcpProxy, IncomingDirection, OutgoingDirection, OutgoingDataEvent, IncomingDataEvent, ConnectionManager}; use crate::tun2proxy::{Connection, TcpProxy, IncomingDirection, OutgoingDirection, OutgoingDataEvent, IncomingDataEvent, ConnectionManager, ProxyError};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -6,6 +6,7 @@ use std::net::SocketAddr;
#[allow(dead_code)] #[allow(dead_code)]
enum HttpState { enum HttpState {
SendRequest, SendRequest,
ExpectStatusCode,
ExpectResponse, ExpectResponse,
Established Established
} }
@ -23,7 +24,7 @@ pub struct HttpConnection {
impl HttpConnection { impl HttpConnection {
fn new(connection: &Connection) -> Self { fn new(connection: &Connection) -> Self {
let mut result = Self { let mut result = Self {
state: HttpState::ExpectResponse, state: HttpState::ExpectStatusCode,
client_inbuf: Default::default(), client_inbuf: Default::default(),
server_inbuf: Default::default(), server_inbuf: Default::default(),
client_outbuf: Default::default(), client_outbuf: Default::default(),
@ -34,29 +35,27 @@ impl HttpConnection {
result.server_outbuf.extend(b"CONNECT ".iter()); result.server_outbuf.extend(b"CONNECT ".iter());
result.destination_to_server_outbuf(connection); result.server_outbuf.extend(connection.dst.to_string().as_bytes());
result.server_outbuf.extend(b" HTTP/1.1\r\nHost: ".iter()); result.server_outbuf.extend(b" HTTP/1.1\r\nHost: ".iter());
result.destination_to_server_outbuf(connection); result.server_outbuf.extend(connection.dst.to_string().as_bytes());
result.server_outbuf.extend(b"\r\n\r\n".iter()); result.server_outbuf.extend(b"\r\n\r\n".iter());
result result
} }
fn destination_to_server_outbuf(&mut self, connection: &Connection) { fn state_change(&mut self) -> Result<(), ProxyError> {
let ipv6 = connection.dst.is_ipv6();
if ipv6 {
self.server_outbuf.extend(b"[".iter());
}
self.server_outbuf.extend(connection.dst.ip().to_string().as_bytes());
if ipv6 {
self.server_outbuf.extend(b"]".iter());
}
self.server_outbuf.extend(b":".iter());
self.server_outbuf.extend(connection.dst.port().to_string().as_bytes());
}
fn state_change(&mut self) {
match self.state { match self.state {
HttpState::ExpectStatusCode if self.server_inbuf.len() >= "HTTP/1.1 200 ".len() => {
let status_line: Vec<u8> = self.server_inbuf.range(0.."HTTP/1.1 200 ".len()).map(|&x| x).collect();
let slice = &status_line.as_slice()[0.."HTTP/1.1 2".len()];
if slice != b"HTTP/1.1 2" && slice != b"HTTP/1.0 2"
|| self.server_inbuf["HTTP/1.1 200 ".len() - 1] != b' '{
let status_str = String::from_utf8_lossy(&status_line.as_slice()[0.."HTTP/1.1 200".len()]);
return Err(ProxyError::new("Expected success status code. Server replied with ".to_owned() + &*status_str + "."));
}
self.state = HttpState::ExpectResponse;
return self.state_change();
}
HttpState::ExpectResponse => { HttpState::ExpectResponse => {
let mut counter = 0usize; let mut counter = 0usize;
for b_ref in self.server_inbuf.iter() { for b_ref in self.server_inbuf.iter() {
@ -74,13 +73,8 @@ impl HttpConnection {
self.server_outbuf.append(&mut self.data_buf); self.server_outbuf.append(&mut self.data_buf);
self.data_buf.clear(); self.data_buf.clear();
self.client_outbuf.extend(self.server_inbuf.iter());
self.server_outbuf.extend(self.client_inbuf.iter());
self.server_inbuf.clear();
self.client_inbuf.clear();
self.state = HttpState::Established; self.state = HttpState::Established;
return; return self.state_change();
} }
} }
@ -93,15 +87,15 @@ impl HttpConnection {
self.client_inbuf.clear(); self.client_inbuf.clear();
} }
_ => { _ => {
unreachable!();
} }
} }
Ok(())
} }
} }
impl TcpProxy for HttpConnection { impl TcpProxy for HttpConnection {
fn push_data(&mut self, event: IncomingDataEvent<'_>) { fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), ProxyError> {
let direction = event.direction; let direction = event.direction;
let buffer = event.buffer; let buffer = event.buffer;
match direction { match direction {
@ -117,7 +111,7 @@ impl TcpProxy for HttpConnection {
} }
} }
self.state_change(); self.state_change()
} }
@ -143,6 +137,10 @@ impl TcpProxy for HttpConnection {
}; };
return event; return event;
} }
fn connection_established(&self) -> bool {
return self.state == HttpState::Established
}
} }
pub struct HttpManager { pub struct HttpManager {

View file

@ -1,4 +1,6 @@
#![feature(deque_make_contiguous)] #![feature(deque_make_contiguous)]
#![feature(deque_range)]
mod virtdevice; mod virtdevice;
mod socks5; mod socks5;
mod http; mod http;
@ -7,8 +9,7 @@ mod tun2proxy;
use socks5::*; use socks5::*;
use crate::http::HttpManager; use crate::http::HttpManager;
use crate::tun2proxy::TunToProxy; use crate::tun2proxy::TunToProxy;
use std::net::SocketAddr; use std::net::ToSocketAddrs;
use std::str::FromStr;
fn main() { fn main() {
let matches = clap::App::new(env!("CARGO_PKG_NAME")) let matches = clap::App::new(env!("CARGO_PKG_NAME"))
@ -44,8 +45,10 @@ fn main() {
let tun_name = matches.value_of("tun").unwrap(); let tun_name = matches.value_of("tun").unwrap();
let mut ttp = TunToProxy::new(tun_name); let mut ttp = TunToProxy::new(tun_name);
if let Some(addr) = matches.value_of("socks5_server") { if let Some(addr) = matches.value_of("socks5_server") {
if let Ok(server) = SocketAddr::from_str(addr) if let Ok(mut servers) = addr.to_socket_addrs()
{ {
let server = servers.next().unwrap();
println!("SOCKS5 server: {}", server);
ttp.add_connection_manager(Box::new(Socks5Manager::new(server))); ttp.add_connection_manager(Box::new(Socks5Manager::new(server)));
} else { } else {
eprintln!("Invalid server address."); eprintln!("Invalid server address.");
@ -54,8 +57,10 @@ fn main() {
} }
if let Some(addr) = matches.value_of("http_server") { if let Some(addr) = matches.value_of("http_server") {
if let Ok(server) = SocketAddr::from_str(addr) if let Ok(mut servers) = addr.to_socket_addrs()
{ {
let server = servers.next().unwrap();
println!("HTTP server: {}", server);
ttp.add_connection_manager(Box::new(HttpManager::new(server))); ttp.add_connection_manager(Box::new(HttpManager::new(server)));
} else { } else {
eprintln!("Invalid server address."); eprintln!("Invalid server address.");

View file

@ -1,4 +1,4 @@
use crate::tun2proxy::{Connection, OutgoingDirection, OutgoingDataEvent, IncomingDirection, IncomingDataEvent, ConnectionManager, TcpProxy}; use crate::tun2proxy::{Connection, OutgoingDirection, OutgoingDataEvent, IncomingDirection, IncomingDataEvent, ConnectionManager, TcpProxy, ProxyError};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
@ -28,6 +28,26 @@ enum SocksAuthentication {
Password = 2 Password = 2
} }
#[allow(dead_code)]
#[repr(u8)]
#[derive(Debug, Eq, PartialEq)]
enum SocksReplies {
Succeeded,
GeneralFailure,
ConnectionDisallowed,
NetworkUnreachable,
ConnectionRefused,
TtlExpired,
CommandUnsupported,
AddressUnsupported
}
impl std::fmt::Display for SocksReplies {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
pub struct SocksConnection { pub struct SocksConnection {
connection: Connection, connection: Connection,
state: SocksState, state: SocksState,
@ -54,20 +74,23 @@ impl SocksConnection {
result result
} }
fn forward_data(&mut self) { pub fn state_change(&mut self) -> Result<(), ProxyError> {
self.client_outbuf.extend(self.server_inbuf.iter());
self.server_outbuf.extend(self.client_inbuf.iter());
self.server_inbuf.clear();
self.client_inbuf.clear();
}
pub fn state_change(&mut self) {
let dst_ip = self.connection.dst.ip(); let dst_ip = self.connection.dst.ip();
match self.state { match self.state {
SocksState::ServerHello if self.server_inbuf.len() == 2 => { SocksState::ServerHello if self.server_inbuf.len() >= 2 => {
assert!(self.server_inbuf[0] == 5 && self.server_inbuf[1] == 0); if self.server_inbuf[0] != 5 {
return Err(ProxyError::new(
"SOCKS server replied with an unexpected version.".into()));
}
if self.server_inbuf[1] != 0 {
return Err(ProxyError::new(
"SOCKS server requires an unsupported authentication method.".into()));
}
self.server_inbuf.drain(0..2); self.server_inbuf.drain(0..2);
let cmd = if dst_ip.is_ipv4() { 1 } else { 4 }; let cmd = if dst_ip.is_ipv4() { 1 } else { 4 };
@ -82,31 +105,36 @@ impl SocksConnection {
]); ]);
self.state = SocksState::ReceiveResponse; self.state = SocksState::ReceiveResponse;
} return self.state_change();
SocksState::ServerHello if self.server_inbuf.len() > 2 => {
panic!("Socks protocol error!")
} }
SocksState::ReceiveResponse if self.server_inbuf.len() >= 4 => { SocksState::ReceiveResponse if self.server_inbuf.len() >= 4 => {
let _ver = self.server_inbuf[0]; let ver = self.server_inbuf[0];
let _rep = self.server_inbuf[1]; let rep = self.server_inbuf[1];
let _rsv = self.server_inbuf[2]; let _rsv = self.server_inbuf[2];
let atyp = self.server_inbuf[3]; let atyp = self.server_inbuf[3];
if ver != 5 {
return Err(ProxyError::new("SOCKS server replied with an unexpected version.".into()));
}
if rep != 0 {
return Err(ProxyError::new("SOCKS connection unsuccessful.".into()));
}
if atyp != SocksAddressType::Ipv4 as u8 if atyp != SocksAddressType::Ipv4 as u8
&& atyp != SocksAddressType::Ipv6 as u8 && atyp != SocksAddressType::Ipv6 as u8
&& atyp != SocksAddressType::DomainName as u8 { && atyp != SocksAddressType::DomainName as u8 {
panic!("Invalid address type"); return Err(ProxyError::new("SOCKS server replied with unrecognized address type.".into()));
} }
if atyp == SocksAddressType::DomainName as u8 && self.server_inbuf.len() < 5 { if atyp == SocksAddressType::DomainName as u8 && self.server_inbuf.len() < 5 {
return; return Ok(());
} }
if atyp == SocksAddressType::DomainName as u8 if atyp == SocksAddressType::DomainName as u8
&& self.server_inbuf.len() < 7 + (self.server_inbuf[4] as usize) { && self.server_inbuf.len() < 7 + (self.server_inbuf[4] as usize) {
return; return Ok(());
} }
let message_length = if atyp == SocksAddressType::Ipv4 as u8 { let message_length = if atyp == SocksAddressType::Ipv4 as u8 {
@ -121,21 +149,25 @@ impl SocksConnection {
self.server_outbuf.append(&mut self.data_buf); self.server_outbuf.append(&mut self.data_buf);
self.data_buf.clear(); self.data_buf.clear();
self.forward_data();
self.state = SocksState::Established; self.state = SocksState::Established;
return self.state_change();
} }
SocksState::Established => { SocksState::Established => {
self.forward_data(); self.client_outbuf.extend(self.server_inbuf.iter());
self.server_outbuf.extend(self.client_inbuf.iter());
self.server_inbuf.clear();
self.client_inbuf.clear();
} }
_ => {} _ => {}
} }
Ok(())
} }
} }
impl TcpProxy for SocksConnection { impl TcpProxy for SocksConnection {
fn push_data(&mut self, event: IncomingDataEvent<'_>) { fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), ProxyError> {
let direction = event.direction; let direction = event.direction;
let buffer = event.buffer; let buffer = event.buffer;
match direction { match direction {
@ -151,7 +183,7 @@ impl TcpProxy for SocksConnection {
} }
} }
self.state_change(); self.state_change()
} }
@ -177,6 +209,10 @@ impl TcpProxy for SocksConnection {
}; };
return event; return event;
} }
fn connection_established(&self) -> bool {
return self.state == SocksState::Established
}
} }
pub struct Socks5Manager { pub struct Socks5Manager {

View file

@ -14,7 +14,23 @@ use smoltcp::socket::{SocketHandle, SocketSet, TcpSocket, TcpSocketBuffer};
use smoltcp::time::Instant; use smoltcp::time::Instant;
use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address, Ipv4Packet, TcpPacket, UdpPacket, Ipv6Packet}; use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address, Ipv4Packet, TcpPacket, UdpPacket, Ipv6Packet};
use crate::virtdevice::VirtualTunDevice; use crate::virtdevice::VirtualTunDevice;
use std::net::Shutdown::Both;
pub struct ProxyError {
message: String
}
impl ProxyError {
pub fn new(message: String) -> Self {
Self {
message
}
}
pub fn message(&self) -> String {
self.message.clone()
}
}
#[derive(Hash, Clone, Copy)] #[derive(Hash, Clone, Copy)]
pub struct Connection { pub struct Connection {
@ -23,6 +39,12 @@ pub struct Connection {
pub proto: u8 pub proto: u8
} }
impl std::fmt::Display for Connection {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{} -> {}", self.src, self.dst)
}
}
impl Eq for Connection {} impl Eq for Connection {}
impl PartialEq<Self> for Connection { impl PartialEq<Self> for Connection {
@ -141,9 +163,10 @@ struct ConnectionState {
} }
pub(crate) trait TcpProxy { pub(crate) trait TcpProxy {
fn push_data(&mut self, event: IncomingDataEvent<'_>); fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), ProxyError>;
fn consume_data(&mut self, dir: OutgoingDirection, size: usize); fn consume_data(&mut self, dir: OutgoingDirection, size: usize);
fn peek_data(&mut self, dir: OutgoingDirection) -> OutgoingDataEvent; fn peek_data(&mut self, dir: OutgoingDirection) -> OutgoingDataEvent;
fn connection_established(&self) -> bool;
} }
pub(crate) trait ConnectionManager { pub(crate) trait ConnectionManager {
@ -227,7 +250,7 @@ impl<'a> TunToProxy<'a> {
let mut connection_state = self.connections.remove(connection).unwrap(); let mut connection_state = self.connections.remove(connection).unwrap();
self.token_to_connection.remove(&connection_state.token); self.token_to_connection.remove(&connection_state.token);
self.poll.registry().deregister(&mut connection_state.mio_stream).unwrap(); self.poll.registry().deregister(&mut connection_state.mio_stream).unwrap();
println!("[{:?}] CLOSE {} -> {} (TCP)", chrono::offset::Local::now(), connection.src, connection.dst); println!("[{:?}] CLOSE {}", chrono::offset::Local::now(), connection);
} }
fn get_connection_manager(&self, connection: &Connection) -> Option<&Box<dyn ConnectionManager>>{ fn get_connection_manager(&self, connection: &Connection) -> Option<&Box<dyn ConnectionManager>>{
@ -239,25 +262,35 @@ impl<'a> TunToProxy<'a> {
None None
} }
fn print_error(error: ProxyError) {
println!("Error: {}", error.message());
}
fn tunsocket_read_and_forward(&mut self, connection: &Connection) { fn tunsocket_read_and_forward(&mut self, connection: &Connection) {
if let Some(handler) = self.managers.get_mut(&connection) { if let Some(handler) = self.managers.get_mut(&connection) {
let closed = { let closed = {
let conn_info = self.connections.get_mut(&connection).unwrap(); let conn_info = self.connections.get_mut(&connection).unwrap();
let mut socket = self.socketset.get::<TcpSocket>(conn_info.smoltcp_handle); let mut socket = self.socketset.get::<TcpSocket>(conn_info.smoltcp_handle);
while socket.can_recv() { let mut error = Ok(());
while socket.can_recv() && error.is_ok() {
socket.recv(|data| { socket.recv(|data| {
let event = IncomingDataEvent { let event = IncomingDataEvent {
direction: IncomingDirection::FromClient, direction: IncomingDirection::FromClient,
buffer: data, buffer: data,
}; };
handler.push_data(event); error = handler.push_data(event);
(data.len(), ()) (data.len(), ())
}).unwrap(); }).unwrap();
} }
socket.state() == smoltcp::socket::TcpState::CloseWait if error.is_err() {
Self::print_error(error.unwrap_err());
true
} else {
socket.state() == smoltcp::socket::TcpState::CloseWait
}
}; };
if closed { if closed {
@ -284,7 +317,11 @@ impl<'a> TunToProxy<'a> {
socket.listen(connection.dst).unwrap(); socket.listen(connection.dst).unwrap();
let handle = self.socketset.add(socket); let handle = self.socketset.add(socket);
let socket = MioTcp::new_v4().unwrap(); let socket = if server.is_ipv4() {
MioTcp::new_v4().unwrap()
} else {
MioTcp::new_v6().unwrap()
};
let client = socket.connect(server).unwrap(); let client = socket.connect(server).unwrap();
let token = Token(self.next_token); let token = Token(self.next_token);
@ -309,7 +346,7 @@ impl<'a> TunToProxy<'a> {
} }
println!("[{:?}] CONNECT {} -> {} (TCP)", chrono::offset::Local::now(), connection.src, connection.dst); println!("[{:?}] CONNECT {}", chrono::offset::Local::now(), connection);
} else if !self.connections.contains_key(&connection) { } else if !self.connections.contains_key(&connection) {
return; return;
} }
@ -388,16 +425,6 @@ impl<'a> TunToProxy<'a> {
fn mio_socket_event(&mut self, event: &Event) { fn mio_socket_event(&mut self, event: &Event) {
let connection = *self.token_to_connection.get(&event.token()).unwrap(); let connection = *self.token_to_connection.get(&event.token()).unwrap();
if event.is_read_closed() {
{
let mut socket = self.socketset.get::<TcpSocket>(self.connections.get(&connection).unwrap().smoltcp_handle);
socket.close();
}
self.expect_smoltcp_send();
self.remove_connection(&connection.clone());
return;
}
if event.is_readable() { if event.is_readable() {
{ {
let conn = self.managers.get_mut(&connection).unwrap(); let conn = self.managers.get_mut(&connection).unwrap();
@ -406,12 +433,32 @@ impl<'a> TunToProxy<'a> {
let mut buf = [0u8; 4096]; let mut buf = [0u8; 4096];
let read = state.mio_stream.read(&mut buf).unwrap(); let read = state.mio_stream.read(&mut buf).unwrap();
if read == 0 {
{
let mut socket = self.socketset.get::<TcpSocket>(self.connections.get(&connection).unwrap().smoltcp_handle);
socket.close();
}
self.expect_smoltcp_send();
self.remove_connection(&connection.clone());
return;
}
let event = IncomingDataEvent { let event = IncomingDataEvent {
direction: IncomingDirection::FromServer, direction: IncomingDirection::FromServer,
buffer: &buf[0..read], buffer: &buf[0..read],
}; };
conn.push_data(event); if let Err(error) = conn.push_data(event) {
state.mio_stream.shutdown(Both).unwrap();
{
let mut socket = self.socketset.get::<TcpSocket>(self.connections.get(&connection).unwrap().smoltcp_handle);
socket.close();
}
self.expect_smoltcp_send();
Self::print_error(error);
self.remove_connection(&connection.clone());
return;
}
} }
// We have read from the proxy server and pushed the data to the connection handler. // We have read from the proxy server and pushed the data to the connection handler.