feat: Improve OAuth error handling with custom AuthError type and better timeout management
Some checks failed
Release Build / Rust project - latest (push) Has been cancelled
Rust Build & Publish / build (push) Has been cancelled
Pull Request / cargo test (push) Has been cancelled
Pull Request / cargo fmt --all -- --check (push) Has been cancelled
Pull Request / cargo clippy -- -D warnings (push) Has been cancelled

This commit is contained in:
Matthew Esposito 2025-04-21 14:17:27 -04:00
parent ddeefb5917
commit dcb507d567

View file

@ -15,6 +15,8 @@ const REDDIT_ANDROID_OAUTH_CLIENT_ID: &str = "ohXpoqrZYub1kg";
const AUTH_ENDPOINT: &str = "https://www.reddit.com";
const OAUTH_TIMEOUT: Duration = Duration::from_secs(5);
// Spoofed client for Android devices
#[derive(Debug, Clone, Default)]
pub struct Oauth {
@ -32,24 +34,30 @@ impl Oauth {
loop {
let attempt = Self::new_with_timeout().await;
match attempt {
Ok(Some(oauth)) => {
Ok(Ok(oauth)) => {
info!("[✅] Successfully created OAuth client");
return oauth;
}
Ok(None) => {
error!("Failed to create OAuth client. Retrying in 5 seconds...");
Ok(Err(e)) => {
error!("Failed to create OAuth client: {}. Retrying in 5 seconds...", {
match e {
AuthError::Hyper(error) => error.to_string(),
AuthError::SerdeDeserialize(error) => error.to_string(),
AuthError::Field((value, error)) => format!("{error}\n{value}"),
}
});
}
Err(duration) => {
error!("Failed to create OAuth client in {duration:?}. Retrying in 5 seconds...");
Err(_) => {
error!("Failed to create OAuth client before timeout. Retrying in 5 seconds...");
}
}
tokio::time::sleep(Duration::from_secs(5)).await;
tokio::time::sleep(OAUTH_TIMEOUT).await;
}
}
async fn new_with_timeout() -> Result<Option<Self>, Elapsed> {
async fn new_with_timeout() -> Result<Result<Self, AuthError>, Elapsed> {
let mut oauth = Self::default();
timeout(Duration::from_secs(5), oauth.login()).await.map(|result| result.map(|_| oauth))
timeout(OAUTH_TIMEOUT, oauth.login()).await.map(|result: Result<(), AuthError>| result.map(|_| oauth))
}
pub(crate) fn default() -> Self {
@ -66,7 +74,7 @@ impl Oauth {
device,
}
}
async fn login(&mut self) -> Option<()> {
async fn login(&mut self) -> Result<(), AuthError> {
// Construct URL for OAuth token
let url = format!("{AUTH_ENDPOINT}/auth/v2/oauth/access-token/loid");
let mut builder = Request::builder().method(Method::POST).uri(&url);
@ -95,7 +103,7 @@ impl Oauth {
// Send request
let client: &once_cell::sync::Lazy<client::Client<_, Body>> = &CLIENT;
let resp = client.request(request).await.ok()?;
let resp = client.request(request).await?;
trace!("Received response with status {} and length {:?}", resp.status(), resp.headers().get("content-length"));
trace!("OAuth headers: {:#?}", resp.headers());
@ -106,30 +114,58 @@ impl Oauth {
// Not worried about the privacy implications, since this is randomly changed
// and really only as privacy-concerning as the OAuth token itself.
if let Some(header) = resp.headers().get("x-reddit-loid") {
self.headers_map.insert("x-reddit-loid".to_owned(), header.to_str().ok()?.to_string());
self.headers_map.insert("x-reddit-loid".to_owned(), header.to_str().unwrap().to_string());
}
// Same with x-reddit-session
if let Some(header) = resp.headers().get("x-reddit-session") {
self.headers_map.insert("x-reddit-session".to_owned(), header.to_str().ok()?.to_string());
self.headers_map.insert("x-reddit-session".to_owned(), header.to_str().unwrap().to_string());
}
trace!("Serializing response...");
// Serialize response
let body_bytes = hyper::body::to_bytes(resp.into_body()).await.ok()?;
let json: serde_json::Value = serde_json::from_slice(&body_bytes).ok()?;
let body_bytes = hyper::body::to_bytes(resp.into_body()).await?;
let json: serde_json::Value = serde_json::from_slice(&body_bytes)?;
trace!("Accessing relevant fields...");
// Save token and expiry
self.token = json.get("access_token")?.as_str()?.to_string();
self.expires_in = json.get("expires_in")?.as_u64()?;
self.token = json
.get("access_token")
.ok_or_else(|| AuthError::Field((json.clone(), "access_token")))?
.as_str()
.ok_or_else(|| AuthError::Field((json.clone(), "access_token: as_str")))?
.to_string();
self.expires_in = json
.get("expires_in")
.ok_or_else(|| AuthError::Field((json.clone(), "expires_in")))?
.as_u64()
.ok_or_else(|| AuthError::Field((json.clone(), "expires_in: as_u64")))?;
self.headers_map.insert("Authorization".to_owned(), format!("Bearer {}", self.token));
info!("[✅] Success - Retrieved token \"{}...\", expires in {}", &self.token[..32], self.expires_in);
Some(())
Ok(())
}
}
#[derive(Debug)]
enum AuthError {
Hyper(hyper::Error),
SerdeDeserialize(serde_json::Error),
Field((serde_json::Value, &'static str)),
}
impl From<hyper::Error> for AuthError {
fn from(err: hyper::Error) -> Self {
AuthError::Hyper(err)
}
}
impl From<serde_json::Error> for AuthError {
fn from(err: serde_json::Error) -> Self {
AuthError::SerdeDeserialize(err)
}
}