read_data_from_tcp_stream

This commit is contained in:
ssrlive 2023-11-10 13:40:38 +08:00
parent ca94d15597
commit 25257e9a27

View file

@ -585,10 +585,15 @@ impl<'a> TunToProxy<'a> {
state.dns_over_tcp_expiry = Some(Self::common_udp_life_timeout());
let mut vecbuf = vec![];
Self::read_data_from_tcp_stream(&mut state.mio_stream, &mut state.is_tcp_closed, |data| {
vecbuf.extend_from_slice(data);
Ok(())
})?;
Self::read_data_from_tcp_stream(
&mut state.mio_stream,
IP_PACKAGE_MAX_SIZE,
&mut state.is_tcp_closed,
|data| {
vecbuf.extend_from_slice(data);
Ok(())
},
)?;
let data_event = IncomingDataEvent {
direction: IncomingDirection::FromServer,
@ -1084,13 +1089,18 @@ impl<'a> TunToProxy<'a> {
let mut vecbuf = vec![];
use std::io::{Error, ErrorKind};
let r = Self::read_data_from_tcp_stream(&mut state.mio_stream, &mut state.is_tcp_closed, |data| {
vecbuf.extend_from_slice(data);
if vecbuf.len() >= IP_PACKAGE_MAX_SIZE {
return Err(Error::new(ErrorKind::OutOfMemory, "IP_PACKAGE_MAX_SIZE exceeded"));
}
Ok(())
});
let r = Self::read_data_from_tcp_stream(
&mut state.mio_stream,
IP_PACKAGE_MAX_SIZE,
&mut state.is_tcp_closed,
|data| {
vecbuf.extend_from_slice(data);
if vecbuf.len() >= IP_PACKAGE_MAX_SIZE {
return Err(Error::new(ErrorKind::OutOfMemory, "IP_PACKAGE_MAX_SIZE exceeded"));
}
Ok(())
},
);
let len = vecbuf.len();
if let Err(error) = r {
if error.kind() == ErrorKind::OutOfMemory {
@ -1207,11 +1217,17 @@ impl<'a> TunToProxy<'a> {
Ok(())
}
fn read_data_from_tcp_stream<F>(stream: &mut TcpStream, is_closed: &mut bool, mut cb: F) -> std::io::Result<()>
fn read_data_from_tcp_stream<F>(
stream: &mut dyn std::io::Read,
buffer_size: usize,
is_closed: &mut bool,
mut callback: F,
) -> std::io::Result<()>
where
F: FnMut(&mut [u8]) -> std::io::Result<()>,
{
let mut tmp = [0_u8; IP_PACKAGE_MAX_SIZE];
assert!(buffer_size > 0);
let mut tmp = vec![0_u8; buffer_size];
loop {
match stream.read(&mut tmp) {
Ok(0) => {
@ -1220,7 +1236,7 @@ impl<'a> TunToProxy<'a> {
break;
}
Ok(read_result) => {
cb(&mut tmp[0..read_result])?;
callback(&mut tmp[0..read_result])?;
}
Err(error) => {
if error.kind() == std::io::ErrorKind::WouldBlock {