diff --git a/src/oauth.rs b/src/oauth.rs index 5627900..f5659da 100644 --- a/src/oauth.rs +++ b/src/oauth.rs @@ -15,6 +15,8 @@ const REDDIT_ANDROID_OAUTH_CLIENT_ID: &str = "ohXpoqrZYub1kg"; const AUTH_ENDPOINT: &str = "https://www.reddit.com"; +const OAUTH_TIMEOUT: Duration = Duration::from_secs(5); + // Spoofed client for Android devices #[derive(Debug, Clone, Default)] pub struct Oauth { @@ -32,24 +34,30 @@ impl Oauth { loop { let attempt = Self::new_with_timeout().await; match attempt { - Ok(Some(oauth)) => { + Ok(Ok(oauth)) => { info!("[✅] Successfully created OAuth client"); return oauth; } - Ok(None) => { - error!("Failed to create OAuth client. Retrying in 5 seconds..."); + Ok(Err(e)) => { + error!("Failed to create OAuth client: {}. Retrying in 5 seconds...", { + match e { + AuthError::Hyper(error) => error.to_string(), + AuthError::SerdeDeserialize(error) => error.to_string(), + AuthError::Field((value, error)) => format!("{error}\n{value}"), + } + }); } - Err(duration) => { - error!("Failed to create OAuth client in {duration:?}. Retrying in 5 seconds..."); + Err(_) => { + error!("Failed to create OAuth client before timeout. Retrying in 5 seconds..."); } } - tokio::time::sleep(Duration::from_secs(5)).await; + tokio::time::sleep(OAUTH_TIMEOUT).await; } } - async fn new_with_timeout() -> Result, Elapsed> { + async fn new_with_timeout() -> Result, Elapsed> { let mut oauth = Self::default(); - timeout(Duration::from_secs(5), oauth.login()).await.map(|result| result.map(|_| oauth)) + timeout(OAUTH_TIMEOUT, oauth.login()).await.map(|result: Result<(), AuthError>| result.map(|_| oauth)) } pub(crate) fn default() -> Self { @@ -66,7 +74,7 @@ impl Oauth { device, } } - async fn login(&mut self) -> Option<()> { + async fn login(&mut self) -> Result<(), AuthError> { // Construct URL for OAuth token let url = format!("{AUTH_ENDPOINT}/auth/v2/oauth/access-token/loid"); let mut builder = Request::builder().method(Method::POST).uri(&url); @@ -95,7 +103,7 @@ impl Oauth { // Send request let client: &once_cell::sync::Lazy> = &CLIENT; - let resp = client.request(request).await.ok()?; + let resp = client.request(request).await?; trace!("Received response with status {} and length {:?}", resp.status(), resp.headers().get("content-length")); trace!("OAuth headers: {:#?}", resp.headers()); @@ -106,30 +114,58 @@ impl Oauth { // Not worried about the privacy implications, since this is randomly changed // and really only as privacy-concerning as the OAuth token itself. if let Some(header) = resp.headers().get("x-reddit-loid") { - self.headers_map.insert("x-reddit-loid".to_owned(), header.to_str().ok()?.to_string()); + self.headers_map.insert("x-reddit-loid".to_owned(), header.to_str().unwrap().to_string()); } // Same with x-reddit-session if let Some(header) = resp.headers().get("x-reddit-session") { - self.headers_map.insert("x-reddit-session".to_owned(), header.to_str().ok()?.to_string()); + self.headers_map.insert("x-reddit-session".to_owned(), header.to_str().unwrap().to_string()); } trace!("Serializing response..."); // Serialize response - let body_bytes = hyper::body::to_bytes(resp.into_body()).await.ok()?; - let json: serde_json::Value = serde_json::from_slice(&body_bytes).ok()?; + let body_bytes = hyper::body::to_bytes(resp.into_body()).await?; + let json: serde_json::Value = serde_json::from_slice(&body_bytes)?; trace!("Accessing relevant fields..."); // Save token and expiry - self.token = json.get("access_token")?.as_str()?.to_string(); - self.expires_in = json.get("expires_in")?.as_u64()?; + self.token = json + .get("access_token") + .ok_or_else(|| AuthError::Field((json.clone(), "access_token")))? + .as_str() + .ok_or_else(|| AuthError::Field((json.clone(), "access_token: as_str")))? + .to_string(); + self.expires_in = json + .get("expires_in") + .ok_or_else(|| AuthError::Field((json.clone(), "expires_in")))? + .as_u64() + .ok_or_else(|| AuthError::Field((json.clone(), "expires_in: as_u64")))?; self.headers_map.insert("Authorization".to_owned(), format!("Bearer {}", self.token)); info!("[✅] Success - Retrieved token \"{}...\", expires in {}", &self.token[..32], self.expires_in); - Some(()) + Ok(()) + } +} + +#[derive(Debug)] +enum AuthError { + Hyper(hyper::Error), + SerdeDeserialize(serde_json::Error), + Field((serde_json::Value, &'static str)), +} + +impl From for AuthError { + fn from(err: hyper::Error) -> Self { + AuthError::Hyper(err) + } +} + +impl From for AuthError { + fn from(err: serde_json::Error) -> Self { + AuthError::SerdeDeserialize(err) } }