beginning async version (#84)

This commit is contained in:
ssrlive 2024-02-01 19:15:32 +08:00 committed by GitHub
parent 337619169e
commit 9c4fa4260a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 2022 additions and 3286 deletions

View file

@ -1,22 +1,20 @@
use crate::{
error::Error,
tun2proxy::{
ConnectionInfo, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection,
ProxyHandler,
},
directions::{IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection},
error::{Error, Result},
proxy_handler::{ProxyHandler, ProxyHandlerManager},
session_info::{IpProtocol, SessionInfo},
};
use base64::Engine;
use httparse::Response;
use smoltcp::wire::IpProtocol;
use socks5_impl::protocol::UserKey;
use std::{
cell::RefCell,
collections::{hash_map::RandomState, HashMap, VecDeque},
iter::FromIterator,
net::SocketAddr,
rc::Rc,
str,
sync::Arc,
};
use tokio::sync::Mutex;
use unicase::UniCase;
#[derive(Eq, PartialEq, Debug)]
@ -48,10 +46,11 @@ pub struct HttpConnection {
crlf_state: u8,
counter: usize,
skip: usize,
digest_state: Rc<RefCell<Option<DigestState>>>,
digest_state: Arc<Mutex<Option<DigestState>>>,
before: bool,
credentials: Option<UserKey>,
info: ConnectionInfo,
info: SessionInfo,
domain_name: Option<String>,
}
static PROXY_AUTHENTICATE: &str = "Proxy-Authenticate";
@ -61,7 +60,12 @@ static TRANSFER_ENCODING: &str = "Transfer-Encoding";
static CONTENT_LENGTH: &str = "Content-Length";
impl HttpConnection {
fn new(info: &ConnectionInfo, credentials: Option<UserKey>, digest_state: Rc<RefCell<Option<DigestState>>>) -> Result<Self, Error> {
async fn new(
info: SessionInfo,
domain_name: Option<String>,
credentials: Option<UserKey>,
digest_state: Arc<Mutex<Option<DigestState>>>,
) -> Result<Self> {
let mut res = Self {
state: HttpState::ExpectResponseHeaders,
client_inbuf: VecDeque::default(),
@ -74,38 +78,50 @@ impl HttpConnection {
digest_state,
before: false,
credentials,
info: info.clone(),
info,
domain_name,
};
res.send_tunnel_request()?;
res.send_tunnel_request().await?;
Ok(res)
}
fn send_tunnel_request(&mut self) -> Result<(), Error> {
async fn send_tunnel_request(&mut self) -> Result<(), Error> {
let host = if let Some(domain_name) = &self.domain_name {
format!("{}:{}", domain_name, self.info.dst.port())
} else {
self.info.dst.to_string()
};
self.server_outbuf.extend(b"CONNECT ");
self.server_outbuf.extend(self.info.dst.to_string().as_bytes());
self.server_outbuf.extend(host.as_bytes());
self.server_outbuf.extend(b" HTTP/1.1\r\nHost: ");
self.server_outbuf.extend(self.info.dst.to_string().as_bytes());
self.server_outbuf.extend(host.as_bytes());
self.server_outbuf.extend(b"\r\n");
self.send_auth_data(if self.digest_state.borrow().is_none() {
let scheme = if self.digest_state.lock().await.is_none() {
AuthenticationScheme::Basic
} else {
AuthenticationScheme::Digest
})?;
};
self.send_auth_data(scheme).await?;
self.server_outbuf.extend(b"\r\n");
Ok(())
}
fn send_auth_data(&mut self, scheme: AuthenticationScheme) -> Result<(), Error> {
async fn send_auth_data(&mut self, scheme: AuthenticationScheme) -> Result<()> {
let Some(credentials) = &self.credentials else {
return Ok(());
};
match scheme {
AuthenticationScheme::Digest => {
let uri = self.info.dst.to_string();
let uri = if let Some(domain_name) = &self.domain_name {
format!("{}:{}", domain_name, self.info.dst.port())
} else {
self.info.dst.to_string()
};
let context = digest_auth::AuthContext::new_with_method(
&credentials.username,
@ -115,8 +131,8 @@ impl HttpConnection {
digest_auth::HttpMethod::CONNECT,
);
let mut state = self.digest_state.borrow_mut();
let response = state.as_mut().unwrap().respond(&context)?;
let mut state = self.digest_state.lock().await;
let response = state.as_mut().unwrap().respond(&context).unwrap();
self.server_outbuf
.extend(format!("{}: {}\r\n", PROXY_AUTHORIZATION, response.to_header_string()).as_bytes());
@ -133,7 +149,8 @@ impl HttpConnection {
Ok(())
}
fn state_change(&mut self) -> Result<(), Error> {
#[async_recursion::async_recursion]
async fn state_change(&mut self) -> Result<()> {
match self.state {
HttpState::ExpectResponseHeaders => {
while self.counter < self.server_inbuf.len() {
@ -176,7 +193,7 @@ impl HttpConnection {
// Connection successful
self.state = HttpState::Established;
self.server_inbuf.clear();
return self.state_change();
return self.state_change().await;
}
if status_code != 407 {
@ -209,7 +226,7 @@ impl HttpConnection {
}
// Update the digest state
self.digest_state.replace(Some(state));
self.digest_state.lock().await.replace(state);
self.before = true;
let closed = match headers_map.get(&UniCase::new(CONNECTION)) {
@ -222,7 +239,7 @@ impl HttpConnection {
// Reset all the buffers
self.server_inbuf.clear();
self.server_outbuf.clear();
self.send_tunnel_request()?;
self.send_tunnel_request().await?;
self.state = HttpState::Reset;
return Ok(());
@ -260,7 +277,7 @@ impl HttpConnection {
// Close the connection by information miss
self.server_inbuf.clear();
self.server_outbuf.clear();
self.send_tunnel_request()?;
self.send_tunnel_request().await?;
self.state = HttpState::Reset;
return Ok(());
@ -271,7 +288,7 @@ impl HttpConnection {
self.state = HttpState::ExpectResponse;
self.skip = content_length + len;
return self.state_change();
return self.state_change().await;
}
HttpState::ExpectResponse => {
if self.skip > 0 {
@ -285,10 +302,10 @@ impl HttpConnection {
// self.server_outbuf.append(&mut self.data_buf);
// self.data_buf.clear();
self.send_tunnel_request()?;
self.send_tunnel_request().await?;
self.state = HttpState::ExpectResponseHeaders;
return self.state_change();
return self.state_change().await;
}
}
HttpState::Established => {
@ -299,7 +316,7 @@ impl HttpConnection {
}
HttpState::Reset => {
self.state = HttpState::ExpectResponseHeaders;
return self.state_change();
return self.state_change().await;
}
_ => {}
}
@ -307,12 +324,17 @@ impl HttpConnection {
}
}
#[async_trait::async_trait]
impl ProxyHandler for HttpConnection {
fn get_connection_info(&self) -> &ConnectionInfo {
&self.info
fn get_session_info(&self) -> SessionInfo {
self.info
}
fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), Error> {
fn get_domain_name(&self) -> Option<String> {
self.domain_name.clone()
}
async fn push_data(&mut self, event: IncomingDataEvent<'_>) -> std::io::Result<()> {
let direction = event.direction;
let buffer = event.buffer;
match direction {
@ -324,7 +346,8 @@ impl ProxyHandler for HttpConnection {
}
}
self.state_change()
self.state_change().await?;
Ok(())
}
fn consume_data(&mut self, dir: OutgoingDirection, size: usize) {
@ -352,16 +375,10 @@ impl ProxyHandler for HttpConnection {
self.state == HttpState::Established
}
fn data_len(&self, dir: Direction) -> usize {
fn data_len(&self, dir: OutgoingDirection) -> usize {
match dir {
Direction::Incoming(incoming) => match incoming {
IncomingDirection::FromServer => self.server_inbuf.len(),
IncomingDirection::FromClient => self.client_inbuf.len(),
},
Direction::Outgoing(outgoing) => match outgoing {
OutgoingDirection::ToServer => self.server_outbuf.len(),
OutgoingDirection::ToClient => self.client_outbuf.len(),
},
OutgoingDirection::ToServer => self.server_outbuf.len(),
OutgoingDirection::ToClient => self.client_outbuf.len(),
}
}
@ -377,19 +394,23 @@ impl ProxyHandler for HttpConnection {
pub(crate) struct HttpManager {
server: SocketAddr,
credentials: Option<UserKey>,
digest_state: Rc<RefCell<Option<DigestState>>>,
digest_state: Arc<Mutex<Option<DigestState>>>,
}
impl ConnectionManager for HttpManager {
fn new_proxy_handler(&self, info: &ConnectionInfo, _: bool) -> Result<Box<dyn ProxyHandler>, Error> {
#[async_trait::async_trait]
impl ProxyHandlerManager for HttpManager {
async fn new_proxy_handler(
&self,
info: SessionInfo,
domain_name: Option<String>,
_udp_associate: bool,
) -> std::io::Result<Arc<Mutex<dyn ProxyHandler>>> {
if info.protocol != IpProtocol::Tcp {
return Err("Invalid protocol".into());
return Err(Error::from("Invalid protocol").into());
}
Ok(Box::new(HttpConnection::new(
info,
self.credentials.clone(),
self.digest_state.clone(),
)?))
Ok(Arc::new(Mutex::new(
HttpConnection::new(info, domain_name, self.credentials.clone(), self.digest_state.clone()).await?,
)))
}
fn get_server_addr(&self) -> SocketAddr {
@ -402,7 +423,7 @@ impl HttpManager {
Self {
server,
credentials,
digest_state: Rc::new(RefCell::new(None)),
digest_state: Arc::new(Mutex::new(None)),
}
}
}