diff --git a/src/lib.rs b/src/lib.rs index 3c34d9f..f214f08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -353,7 +353,8 @@ where let proxy_handler = mgr.new_proxy_handler(tcpinfo, None, false).await?; let queue = socket_queue.clone(); tokio::spawn(async move { - if let Err(e) = handle_udp_gateway_session(udp, udpgw, domain_name, proxy_handler, queue, ipv6_enabled).await { + let dst = info.dst; // real UDP destination address + if let Err(e) = handle_udp_gateway_session(udp, udpgw, dst, domain_name, proxy_handler, queue, ipv6_enabled).await { log::info!("Ending {} with \"{}\"", info, e); } log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); @@ -482,16 +483,14 @@ async fn handle_tcp_session( async fn handle_udp_gateway_session( mut udp_stack: IpStackUdpStream, udpgw_client: Arc, + udp_dst: SocketAddr, domain_name: Option, proxy_handler: Arc>, socket_queue: Option>, ipv6_enabled: bool, ) -> crate::Result<()> { - let (_session_info, server_addr) = { - let handler = proxy_handler.lock().await; - (handler.get_session_info(), handler.get_server_addr()) - }; - let udpinfo = SessionInfo::new(udp_stack.local_addr(), udp_stack.peer_addr(), IpProtocol::Udp); + let proxy_server_addr = { proxy_handler.lock().await.get_server_addr() }; + let udpinfo = SessionInfo::new(udp_stack.local_addr(), udp_dst, IpProtocol::Udp); let udp_mtu = udpgw_client.get_udp_mtu(); let udp_timeout = udpgw_client.get_udp_timeout(); let mut server_stream = match udpgw_client.get_server_connection().await { @@ -500,7 +499,7 @@ async fn handle_udp_gateway_session( if udpgw_client.is_full() { return Err("max udpgw connection limit reached".into()); } - let mut tcp_server_stream = create_tcp_stream(&socket_queue, server_addr).await?; + let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?; if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await { return Err(format!("udpgw connection error: {}", e).into()); } @@ -508,8 +507,6 @@ async fn handle_udp_gateway_session( } }; - let udp_server_addr = udp_stack.peer_addr(); - let tcp_local_addr = server_stream.local_addr().clone(); match domain_name { @@ -544,7 +541,7 @@ async fn handle_udp_gateway_session( } } let new_id = server_stream.new_id(); - if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_server_addr, domain_name.as_ref(), new_id, &mut stream_writer).await { + if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, read_len, udp_dst, domain_name.as_ref(), new_id, &mut stream_writer).await { log::info!("Ending {} <- {} with send_udpgw_packet {}", udpinfo, &tcp_local_addr, e); break; }