mirror of
https://github.com/redlib-org/redlib.git
synced 2025-04-28 01:36:09 +00:00
feat: Improve OAuth error handling with custom AuthError type and better timeout management
Some checks failed
Release Build / Rust project - latest (push) Has been cancelled
Rust Build & Publish / build (push) Has been cancelled
Pull Request / cargo test (push) Has been cancelled
Pull Request / cargo fmt --all -- --check (push) Has been cancelled
Pull Request / cargo clippy -- -D warnings (push) Has been cancelled
Some checks failed
Release Build / Rust project - latest (push) Has been cancelled
Rust Build & Publish / build (push) Has been cancelled
Pull Request / cargo test (push) Has been cancelled
Pull Request / cargo fmt --all -- --check (push) Has been cancelled
Pull Request / cargo clippy -- -D warnings (push) Has been cancelled
This commit is contained in:
parent
ddeefb5917
commit
dcb507d567
1 changed files with 53 additions and 17 deletions
70
src/oauth.rs
70
src/oauth.rs
|
@ -15,6 +15,8 @@ const REDDIT_ANDROID_OAUTH_CLIENT_ID: &str = "ohXpoqrZYub1kg";
|
|||
|
||||
const AUTH_ENDPOINT: &str = "https://www.reddit.com";
|
||||
|
||||
const OAUTH_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
// Spoofed client for Android devices
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Oauth {
|
||||
|
@ -32,24 +34,30 @@ impl Oauth {
|
|||
loop {
|
||||
let attempt = Self::new_with_timeout().await;
|
||||
match attempt {
|
||||
Ok(Some(oauth)) => {
|
||||
Ok(Ok(oauth)) => {
|
||||
info!("[✅] Successfully created OAuth client");
|
||||
return oauth;
|
||||
}
|
||||
Ok(None) => {
|
||||
error!("Failed to create OAuth client. Retrying in 5 seconds...");
|
||||
Ok(Err(e)) => {
|
||||
error!("Failed to create OAuth client: {}. Retrying in 5 seconds...", {
|
||||
match e {
|
||||
AuthError::Hyper(error) => error.to_string(),
|
||||
AuthError::SerdeDeserialize(error) => error.to_string(),
|
||||
AuthError::Field((value, error)) => format!("{error}\n{value}"),
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(duration) => {
|
||||
error!("Failed to create OAuth client in {duration:?}. Retrying in 5 seconds...");
|
||||
Err(_) => {
|
||||
error!("Failed to create OAuth client before timeout. Retrying in 5 seconds...");
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
tokio::time::sleep(OAUTH_TIMEOUT).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn new_with_timeout() -> Result<Option<Self>, Elapsed> {
|
||||
async fn new_with_timeout() -> Result<Result<Self, AuthError>, Elapsed> {
|
||||
let mut oauth = Self::default();
|
||||
timeout(Duration::from_secs(5), oauth.login()).await.map(|result| result.map(|_| oauth))
|
||||
timeout(OAUTH_TIMEOUT, oauth.login()).await.map(|result: Result<(), AuthError>| result.map(|_| oauth))
|
||||
}
|
||||
|
||||
pub(crate) fn default() -> Self {
|
||||
|
@ -66,7 +74,7 @@ impl Oauth {
|
|||
device,
|
||||
}
|
||||
}
|
||||
async fn login(&mut self) -> Option<()> {
|
||||
async fn login(&mut self) -> Result<(), AuthError> {
|
||||
// Construct URL for OAuth token
|
||||
let url = format!("{AUTH_ENDPOINT}/auth/v2/oauth/access-token/loid");
|
||||
let mut builder = Request::builder().method(Method::POST).uri(&url);
|
||||
|
@ -95,7 +103,7 @@ impl Oauth {
|
|||
|
||||
// Send request
|
||||
let client: &once_cell::sync::Lazy<client::Client<_, Body>> = &CLIENT;
|
||||
let resp = client.request(request).await.ok()?;
|
||||
let resp = client.request(request).await?;
|
||||
|
||||
trace!("Received response with status {} and length {:?}", resp.status(), resp.headers().get("content-length"));
|
||||
trace!("OAuth headers: {:#?}", resp.headers());
|
||||
|
@ -106,30 +114,58 @@ impl Oauth {
|
|||
// Not worried about the privacy implications, since this is randomly changed
|
||||
// and really only as privacy-concerning as the OAuth token itself.
|
||||
if let Some(header) = resp.headers().get("x-reddit-loid") {
|
||||
self.headers_map.insert("x-reddit-loid".to_owned(), header.to_str().ok()?.to_string());
|
||||
self.headers_map.insert("x-reddit-loid".to_owned(), header.to_str().unwrap().to_string());
|
||||
}
|
||||
|
||||
// Same with x-reddit-session
|
||||
if let Some(header) = resp.headers().get("x-reddit-session") {
|
||||
self.headers_map.insert("x-reddit-session".to_owned(), header.to_str().ok()?.to_string());
|
||||
self.headers_map.insert("x-reddit-session".to_owned(), header.to_str().unwrap().to_string());
|
||||
}
|
||||
|
||||
trace!("Serializing response...");
|
||||
|
||||
// Serialize response
|
||||
let body_bytes = hyper::body::to_bytes(resp.into_body()).await.ok()?;
|
||||
let json: serde_json::Value = serde_json::from_slice(&body_bytes).ok()?;
|
||||
let body_bytes = hyper::body::to_bytes(resp.into_body()).await?;
|
||||
let json: serde_json::Value = serde_json::from_slice(&body_bytes)?;
|
||||
|
||||
trace!("Accessing relevant fields...");
|
||||
|
||||
// Save token and expiry
|
||||
self.token = json.get("access_token")?.as_str()?.to_string();
|
||||
self.expires_in = json.get("expires_in")?.as_u64()?;
|
||||
self.token = json
|
||||
.get("access_token")
|
||||
.ok_or_else(|| AuthError::Field((json.clone(), "access_token")))?
|
||||
.as_str()
|
||||
.ok_or_else(|| AuthError::Field((json.clone(), "access_token: as_str")))?
|
||||
.to_string();
|
||||
self.expires_in = json
|
||||
.get("expires_in")
|
||||
.ok_or_else(|| AuthError::Field((json.clone(), "expires_in")))?
|
||||
.as_u64()
|
||||
.ok_or_else(|| AuthError::Field((json.clone(), "expires_in: as_u64")))?;
|
||||
self.headers_map.insert("Authorization".to_owned(), format!("Bearer {}", self.token));
|
||||
|
||||
info!("[✅] Success - Retrieved token \"{}...\", expires in {}", &self.token[..32], self.expires_in);
|
||||
|
||||
Some(())
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum AuthError {
|
||||
Hyper(hyper::Error),
|
||||
SerdeDeserialize(serde_json::Error),
|
||||
Field((serde_json::Value, &'static str)),
|
||||
}
|
||||
|
||||
impl From<hyper::Error> for AuthError {
|
||||
fn from(err: hyper::Error) -> Self {
|
||||
AuthError::Hyper(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for AuthError {
|
||||
fn from(err: serde_json::Error) -> Self {
|
||||
AuthError::SerdeDeserialize(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue