mirror of
https://github.com/redlib-org/redlib.git
synced 2025-05-15 22:42:52 +00:00
Merge branch 'master' into feature/fixed-navbar
This commit is contained in:
commit
bb5f2674d1
41 changed files with 2834 additions and 787 deletions
165
src/client.rs
165
src/client.rs
|
@ -1,12 +1,55 @@
|
|||
use cached::proc_macro::cached;
|
||||
use futures_lite::{future::Boxed, FutureExt};
|
||||
use hyper::{body::Buf, client, Body, Request, Response, Uri};
|
||||
use hyper::{body, body::Buf, client, header, Body, Method, Request, Response, Uri};
|
||||
use libflate::gzip;
|
||||
use percent_encoding::{percent_encode, CONTROLS};
|
||||
use serde_json::Value;
|
||||
use std::result::Result;
|
||||
use std::{io, result::Result};
|
||||
|
||||
use crate::dbg_msg;
|
||||
use crate::server::RequestExt;
|
||||
|
||||
const REDDIT_URL_BASE: &str = "https://www.reddit.com";
|
||||
|
||||
/// Gets the canonical path for a resource on Reddit. This is accomplished by
|
||||
/// making a `HEAD` request to Reddit at the path given in `path`.
|
||||
///
|
||||
/// This function returns `Ok(Some(path))`, where `path`'s value is identical
|
||||
/// to that of the value of the argument `path`, if Reddit responds to our
|
||||
/// `HEAD` request with a 2xx-family HTTP code. It will also return an
|
||||
/// `Ok(Some(String))` if Reddit responds to our `HEAD` request with a
|
||||
/// `Location` header in the response, and the HTTP code is in the 3xx-family;
|
||||
/// the `String` will contain the path as reported in `Location`. The return
|
||||
/// value is `Ok(None)` if Reddit responded with a 3xx, but did not provide a
|
||||
/// `Location` header. An `Err(String)` is returned if Reddit responds with a
|
||||
/// 429, or if we were unable to decode the value in the `Location` header.
|
||||
#[cached(size = 1024, time = 600, result = true)]
|
||||
pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
|
||||
let res = reddit_head(path.clone(), true).await?;
|
||||
|
||||
if res.status() == 429 {
|
||||
return Err("Too many requests.".to_string());
|
||||
};
|
||||
|
||||
// If Reddit responds with a 2xx, then the path is already canonical.
|
||||
if res.status().to_string().starts_with('2') {
|
||||
return Ok(Some(path));
|
||||
}
|
||||
|
||||
// If Reddit responds with anything other than 3xx (except for the 2xx as
|
||||
// above), return a None.
|
||||
if !res.status().to_string().starts_with('3') {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(
|
||||
res
|
||||
.headers()
|
||||
.get(header::LOCATION)
|
||||
.map(|val| percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string()),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn proxy(req: Request<Body>, format: &str) -> Result<Response<Body>, String> {
|
||||
let mut url = format!("{}?{}", format, req.uri().query().unwrap_or_default());
|
||||
|
||||
|
@ -62,20 +105,39 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
|
|||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
|
||||
/// 3xx codes Reddit returns and will automatically redirect.
|
||||
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
request(&Method::GET, path, true, quarantine)
|
||||
}
|
||||
|
||||
/// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
|
||||
fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
request(&Method::HEAD, path, false, quarantine)
|
||||
}
|
||||
|
||||
/// Makes a request to Reddit. If `redirect` is `true`, request_with_redirect
|
||||
/// will recurse on the URL that Reddit provides in the Location HTTP header
|
||||
/// in its response.
|
||||
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
// Build Reddit URL from path.
|
||||
let url = format!("{}{}", REDDIT_URL_BASE, path);
|
||||
|
||||
// Prepare the HTTPS connector.
|
||||
let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build();
|
||||
|
||||
// Construct the hyper client from the HTTPS connector.
|
||||
let client: client::Client<_, hyper::Body> = client::Client::builder().build(https);
|
||||
|
||||
// Build request
|
||||
// Build request to Reddit. When making a GET, request gzip compression.
|
||||
// (Reddit doesn't do brotli yet.)
|
||||
let builder = Request::builder()
|
||||
.method("GET")
|
||||
.method(method)
|
||||
.uri(&url)
|
||||
.header("User-Agent", format!("web:libreddit:{}", env!("CARGO_PKG_VERSION")))
|
||||
.header("Host", "www.reddit.com")
|
||||
.header("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")
|
||||
.header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" })
|
||||
.header("Accept-Language", "en-US,en;q=0.5")
|
||||
.header("Connection", "keep-alive")
|
||||
.header("Cookie", if quarantine { "_options=%7B%22pref_quarantine_optin%22%3A%20true%7D" } else { "" })
|
||||
|
@ -84,26 +146,94 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
|
|||
async move {
|
||||
match builder {
|
||||
Ok(req) => match client.request(req).await {
|
||||
Ok(response) => {
|
||||
Ok(mut response) => {
|
||||
// Reddit may respond with a 3xx. Decide whether or not to
|
||||
// redirect based on caller params.
|
||||
if response.status().to_string().starts_with('3') {
|
||||
request(
|
||||
if !redirect {
|
||||
return Ok(response);
|
||||
};
|
||||
|
||||
return request(
|
||||
method,
|
||||
response
|
||||
.headers()
|
||||
.get("Location")
|
||||
.get(header::LOCATION)
|
||||
.map(|val| {
|
||||
let new_url = percent_encode(val.as_bytes(), CONTROLS).to_string();
|
||||
format!("{}{}raw_json=1", new_url, if new_url.contains('?') { "&" } else { "?" })
|
||||
// We need to make adjustments to the URI
|
||||
// we get back from Reddit. Namely, we
|
||||
// must:
|
||||
//
|
||||
// 1. Remove the authority (e.g.
|
||||
// https://www.reddit.com) that may be
|
||||
// present, so that we recurse on the
|
||||
// path (and query parameters) as
|
||||
// required.
|
||||
//
|
||||
// 2. Percent-encode the path.
|
||||
let new_path = percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string();
|
||||
format!("{}{}raw_json=1", new_path, if new_path.contains('?') { "&" } else { "?" })
|
||||
})
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
true,
|
||||
quarantine,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
Ok(response)
|
||||
.await;
|
||||
};
|
||||
|
||||
match response.headers().get(header::CONTENT_ENCODING) {
|
||||
// Content not compressed.
|
||||
None => Ok(response),
|
||||
|
||||
// Content encoded (hopefully with gzip).
|
||||
Some(hdr) => {
|
||||
match hdr.to_str() {
|
||||
Ok(val) => match val {
|
||||
"gzip" => {}
|
||||
"identity" => return Ok(response),
|
||||
_ => return Err("Reddit response was encoded with an unsupported compressor".to_string()),
|
||||
},
|
||||
Err(_) => return Err("Reddit response was invalid".to_string()),
|
||||
}
|
||||
|
||||
// We get here if the body is gzip-compressed.
|
||||
|
||||
// The body must be something that implements
|
||||
// std::io::Read, hence the conversion to
|
||||
// bytes::buf::Buf and then transformation into a
|
||||
// Reader.
|
||||
let mut decompressed: Vec<u8>;
|
||||
{
|
||||
let mut aggregated_body = match body::aggregate(response.body_mut()).await {
|
||||
Ok(b) => b.reader(),
|
||||
Err(e) => return Err(e.to_string()),
|
||||
};
|
||||
|
||||
let mut decoder = match gzip::Decoder::new(&mut aggregated_body) {
|
||||
Ok(decoder) => decoder,
|
||||
Err(e) => return Err(e.to_string()),
|
||||
};
|
||||
|
||||
decompressed = Vec::<u8>::new();
|
||||
if let Err(e) = io::copy(&mut decoder, &mut decompressed) {
|
||||
return Err(e.to_string());
|
||||
};
|
||||
}
|
||||
|
||||
response.headers_mut().remove(header::CONTENT_ENCODING);
|
||||
response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into());
|
||||
*(response.body_mut()) = Body::from(decompressed);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => Err(e.to_string()),
|
||||
Err(e) => {
|
||||
dbg_msg!("{} {}: {}", method, path, e);
|
||||
|
||||
Err(e.to_string())
|
||||
}
|
||||
},
|
||||
Err(_) => Err("Post url contains non-ASCII characters".to_string()),
|
||||
}
|
||||
|
@ -114,9 +244,6 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
|
|||
// Make a request to a Reddit API and parse the JSON response
|
||||
#[cached(size = 100, time = 30, result = true)]
|
||||
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
||||
// Build Reddit url from path
|
||||
let url = format!("https://www.reddit.com{}", path);
|
||||
|
||||
// Closure to quickly build errors
|
||||
let err = |msg: &str, e: String| -> Result<Value, String> {
|
||||
// eprintln!("{} - {}: {}", url, msg, e);
|
||||
|
@ -124,7 +251,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||
};
|
||||
|
||||
// Fetch the url...
|
||||
match request(url.clone(), quarantine).await {
|
||||
match reddit_get(path.clone(), quarantine).await {
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
|
||||
|
@ -142,7 +269,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
|||
.as_str()
|
||||
.unwrap_or_else(|| {
|
||||
json["message"].as_str().unwrap_or_else(|| {
|
||||
eprintln!("{} - Error parsing reddit error", url);
|
||||
eprintln!("{}{} - Error parsing reddit error", REDDIT_URL_BASE, path);
|
||||
"Error parsing reddit error"
|
||||
})
|
||||
})
|
||||
|
|
130
src/config.rs
Normal file
130
src/config.rs
Normal file
|
@ -0,0 +1,130 @@
|
|||
use once_cell::sync::Lazy;
|
||||
use std::{env::var, fs::read_to_string};
|
||||
|
||||
// Waiting for https://github.com/rust-lang/rust/issues/74465 to land, so we
|
||||
// can reduce reliance on once_cell.
|
||||
//
|
||||
// This is the local static that is initialized at runtime (technically at
|
||||
// first request) and contains the instance settings.
|
||||
static CONFIG: Lazy<Config> = Lazy::new(Config::load);
|
||||
|
||||
/// Stores the configuration parsed from the environment variables and the
|
||||
/// config file. `Config::Default()` contains None for each setting.
|
||||
#[derive(Default, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
#[serde(rename = "LIBREDDIT_SFW_ONLY")]
|
||||
sfw_only: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_THEME")]
|
||||
default_theme: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_FRONT_PAGE")]
|
||||
default_front_page: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_LAYOUT")]
|
||||
default_layout: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_WIDE")]
|
||||
default_wide: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_COMMENT_SORT")]
|
||||
default_comment_sort: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_POST_SORT")]
|
||||
default_post_sort: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_SHOW_NSFW")]
|
||||
default_show_nsfw: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_BLUR_NSFW")]
|
||||
default_blur_nsfw: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_USE_HLS")]
|
||||
default_use_hls: Option<String>,
|
||||
|
||||
#[serde(rename = "LIBREDDIT_DEFAULT_HIDE_HLS_NOTIFICATION")]
|
||||
default_hide_hls_notification: Option<String>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Load the configuration from the environment variables and the config file.
|
||||
/// In the case that there are no environment variables set and there is no
|
||||
/// config file, this function returns a Config that contains all None values.
|
||||
pub fn load() -> Self {
|
||||
// Read from libreddit.toml config file. If for any reason, it fails, the
|
||||
// default `Config` is used (all None values)
|
||||
let config: Config = toml::from_str(&read_to_string("libreddit.toml").unwrap_or_default()).unwrap_or_default();
|
||||
// This function defines the order of preference - first check for
|
||||
// environment variables with "LIBREDDIT", then check the config, then if
|
||||
// both are `None`, return a `None` via the `map_or_else` function
|
||||
let parse = |key: &str| -> Option<String> { var(key).ok().map_or_else(|| get_setting_from_config(key, &config), Some) };
|
||||
Self {
|
||||
sfw_only: parse("LIBREDDIT_SFW_ONLY"),
|
||||
default_theme: parse("LIBREDDIT_DEFAULT_THEME"),
|
||||
default_front_page: parse("LIBREDDIT_DEFAULT_FRONT_PAGE"),
|
||||
default_layout: parse("LIBREDDIT_DEFAULT_LAYOUT"),
|
||||
default_post_sort: parse("LIBREDDIT_DEFAULT_POST_SORT"),
|
||||
default_wide: parse("LIBREDDIT_DEFAULT_WIDE"),
|
||||
default_comment_sort: parse("LIBREDDIT_DEFAULT_COMMENT_SORT"),
|
||||
default_show_nsfw: parse("LIBREDDIT_DEFAULT_SHOW_NSFW"),
|
||||
default_blur_nsfw: parse("LIBREDDIT_DEFAULT_BLUR_NSFW"),
|
||||
default_use_hls: parse("LIBREDDIT_DEFAULT_USE_HLS"),
|
||||
default_hide_hls_notification: parse("LIBREDDIT_DEFAULT_HIDE_HLS"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_setting_from_config(name: &str, config: &Config) -> Option<String> {
|
||||
match name {
|
||||
"LIBREDDIT_SFW_ONLY" => config.sfw_only.clone(),
|
||||
"LIBREDDIT_DEFAULT_THEME" => config.default_theme.clone(),
|
||||
"LIBREDDIT_DEFAULT_FRONT_PAGE" => config.default_front_page.clone(),
|
||||
"LIBREDDIT_DEFAULT_LAYOUT" => config.default_layout.clone(),
|
||||
"LIBREDDIT_DEFAULT_COMMENT_SORT" => config.default_comment_sort.clone(),
|
||||
"LIBREDDIT_DEFAULT_POST_SORT" => config.default_post_sort.clone(),
|
||||
"LIBREDDIT_DEFAULT_SHOW_NSFW" => config.default_show_nsfw.clone(),
|
||||
"LIBREDDIT_DEFAULT_BLUR_NSFW" => config.default_blur_nsfw.clone(),
|
||||
"LIBREDDIT_DEFAULT_USE_HLS" => config.default_use_hls.clone(),
|
||||
"LIBREDDIT_DEFAULT_HIDE_HLS_NOTIFICATION" => config.default_hide_hls_notification.clone(),
|
||||
"LIBREDDIT_DEFAULT_WIDE" => config.default_wide.clone(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves setting from environment variable or config file.
|
||||
pub(crate) fn get_setting(name: &str) -> Option<String> {
|
||||
get_setting_from_config(name, &CONFIG)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
use {sealed_test::prelude::*, std::fs::write};
|
||||
|
||||
#[test]
|
||||
#[sealed_test(env = [("LIBREDDIT_SFW_ONLY", "on")])]
|
||||
fn test_env_var() {
|
||||
assert!(crate::utils::sfw_only())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[sealed_test]
|
||||
fn test_config() {
|
||||
let config_to_write = r#"LIBREDDIT_DEFAULT_COMMENT_SORT = "best""#;
|
||||
write("libreddit.toml", config_to_write).unwrap();
|
||||
assert_eq!(get_setting("LIBREDDIT_DEFAULT_COMMENT_SORT"), Some("best".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[sealed_test(env = [("LIBREDDIT_DEFAULT_COMMENT_SORT", "top")])]
|
||||
fn test_env_config_precedence() {
|
||||
let config_to_write = r#"LIBREDDIT_DEFAULT_COMMENT_SORT = "best""#;
|
||||
write("libreddit.toml", config_to_write).unwrap();
|
||||
assert_eq!(get_setting("LIBREDDIT_DEFAULT_COMMENT_SORT"), Some("top".into()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[sealed_test(env = [("LIBREDDIT_DEFAULT_COMMENT_SORT", "top")])]
|
||||
fn test_alt_env_config_precedence() {
|
||||
let config_to_write = r#"LIBREDDIT_DEFAULT_COMMENT_SORT = "best""#;
|
||||
write("libreddit.toml", config_to_write).unwrap();
|
||||
assert_eq!(get_setting("LIBREDDIT_DEFAULT_COMMENT_SORT"), Some("top".into()))
|
||||
}
|
236
src/duplicates.rs
Normal file
236
src/duplicates.rs
Normal file
|
@ -0,0 +1,236 @@
|
|||
// Handler for post duplicates.
|
||||
|
||||
use crate::client::json;
|
||||
use crate::server::RequestExt;
|
||||
use crate::subreddit::{can_access_quarantine, quarantine};
|
||||
use crate::utils::{error, filter_posts, get_filters, nsfw_landing, parse_post, setting, template, Post, Preferences};
|
||||
|
||||
use askama::Template;
|
||||
use hyper::{Body, Request, Response};
|
||||
use serde_json::Value;
|
||||
use std::borrow::ToOwned;
|
||||
use std::collections::HashSet;
|
||||
use std::vec::Vec;
|
||||
|
||||
/// DuplicatesParams contains the parameters in the URL.
|
||||
struct DuplicatesParams {
|
||||
before: String,
|
||||
after: String,
|
||||
sort: String,
|
||||
}
|
||||
|
||||
/// DuplicatesTemplate defines an Askama template for rendering duplicate
|
||||
/// posts.
|
||||
#[derive(Template)]
|
||||
#[template(path = "duplicates.html")]
|
||||
struct DuplicatesTemplate {
|
||||
/// params contains the relevant request parameters.
|
||||
params: DuplicatesParams,
|
||||
|
||||
/// post is the post whose ID is specified in the reqeust URL. Note that
|
||||
/// this is not necessarily the "original" post.
|
||||
post: Post,
|
||||
|
||||
/// duplicates is the list of posts that, per Reddit, are duplicates of
|
||||
/// Post above.
|
||||
duplicates: Vec<Post>,
|
||||
|
||||
/// prefs are the user preferences.
|
||||
prefs: Preferences,
|
||||
|
||||
/// url is the request URL.
|
||||
url: String,
|
||||
|
||||
/// num_posts_filtered counts how many posts were filtered from the
|
||||
/// duplicates list.
|
||||
num_posts_filtered: u64,
|
||||
|
||||
/// all_posts_filtered is true if every duplicate was filtered. This is an
|
||||
/// edge case but can still happen.
|
||||
all_posts_filtered: bool,
|
||||
}
|
||||
|
||||
/// Make the GET request to Reddit. It assumes `req` is the appropriate Reddit
|
||||
/// REST endpoint for enumerating post duplicates.
|
||||
pub async fn item(req: Request<Body>) -> Result<Response<Body>, String> {
|
||||
let path: String = format!("{}.json?{}&raw_json=1", req.uri().path(), req.uri().query().unwrap_or_default());
|
||||
let sub = req.param("sub").unwrap_or_default();
|
||||
let quarantined = can_access_quarantine(&req, &sub);
|
||||
|
||||
// Log the request in debugging mode
|
||||
#[cfg(debug_assertions)]
|
||||
dbg!(req.param("id").unwrap_or_default());
|
||||
|
||||
// Send the GET, and await JSON.
|
||||
match json(path, quarantined).await {
|
||||
// Process response JSON.
|
||||
Ok(response) => {
|
||||
let post = parse_post(&response[0]["data"]["children"][0]).await;
|
||||
|
||||
// Return landing page if this post if this Reddit deems this post
|
||||
// NSFW, but we have also disabled the display of NSFW content
|
||||
// or if the instance is SFW-only.
|
||||
if post.nsfw && (setting(&req, "show_nsfw") != "on" || crate::utils::sfw_only()) {
|
||||
return Ok(nsfw_landing(req).await.unwrap_or_default());
|
||||
}
|
||||
|
||||
let filters = get_filters(&req);
|
||||
let (duplicates, num_posts_filtered, all_posts_filtered) = parse_duplicates(&response[1], &filters).await;
|
||||
|
||||
// These are the values for the "before=", "after=", and "sort="
|
||||
// query params, respectively.
|
||||
let mut before: String = String::new();
|
||||
let mut after: String = String::new();
|
||||
let mut sort: String = String::new();
|
||||
|
||||
// FIXME: We have to perform a kludge to work around a Reddit API
|
||||
// bug.
|
||||
//
|
||||
// The JSON object in "data" will never contain a "before" value so
|
||||
// it is impossible to use it to determine our position in a
|
||||
// listing. We'll make do by getting the ID of the first post in
|
||||
// the listing, setting that as our "before" value, and ask Reddit
|
||||
// to give us a batch of duplicate posts up to that post.
|
||||
//
|
||||
// Likewise, if we provide a "before" request in the GET, the
|
||||
// result won't have an "after" in the JSON, in addition to missing
|
||||
// the "before." So we will have to use the final post in the list
|
||||
// of duplicates.
|
||||
//
|
||||
// That being said, we'll also need to capture the value of the
|
||||
// "sort=" parameter as well, so we will need to inspect the
|
||||
// query key-value pairs anyway.
|
||||
let l = duplicates.len();
|
||||
if l > 0 {
|
||||
// This gets set to true if "before=" is one of the GET params.
|
||||
let mut have_before: bool = false;
|
||||
|
||||
// This gets set to true if "after=" is one of the GET params.
|
||||
let mut have_after: bool = false;
|
||||
|
||||
// Inspect the query key-value pairs. We will need to record
|
||||
// the value of "sort=", along with checking to see if either
|
||||
// one of "before=" or "after=" are given.
|
||||
//
|
||||
// If we're in the middle of the batch (evidenced by the
|
||||
// presence of a "before=" or "after=" parameter in the GET),
|
||||
// then use the first post as the "before" reference.
|
||||
//
|
||||
// We'll do this iteratively. Better than with .map_or()
|
||||
// since a closure will continue to operate on remaining
|
||||
// elements even after we've determined one of "before=" or
|
||||
// "after=" (or both) are in the GET request.
|
||||
//
|
||||
// In practice, here should only ever be one of "before=" or
|
||||
// "after=" and never both.
|
||||
let query_str = req.uri().query().unwrap_or_default().to_string();
|
||||
|
||||
if !query_str.is_empty() {
|
||||
for param in query_str.split('&') {
|
||||
let kv: Vec<&str> = param.split('=').collect();
|
||||
if kv.len() < 2 {
|
||||
// Reject invalid query parameter.
|
||||
continue;
|
||||
}
|
||||
|
||||
let key: &str = kv[0];
|
||||
match key {
|
||||
"before" => have_before = true,
|
||||
"after" => have_after = true,
|
||||
"sort" => {
|
||||
let val: &str = kv[1];
|
||||
match val {
|
||||
"new" | "num_comments" => sort = val.to_string(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if have_after {
|
||||
before = "t3_".to_owned();
|
||||
before.push_str(&duplicates[0].id);
|
||||
}
|
||||
|
||||
// Address potentially missing "after". If "before=" is in the
|
||||
// GET, then "after" will be null in the JSON (see FIXME
|
||||
// above).
|
||||
if have_before {
|
||||
// The next batch will need to start from one after the
|
||||
// last post in the current batch.
|
||||
after = "t3_".to_owned();
|
||||
after.push_str(&duplicates[l - 1].id);
|
||||
|
||||
// Here is where things get terrible. Notice that we
|
||||
// haven't set `before`. In order to do so, we will
|
||||
// need to know if there is a batch that exists before
|
||||
// this one, and doing so requires actually fetching the
|
||||
// previous batch. In other words, we have to do yet one
|
||||
// more GET to Reddit. There is no other way to determine
|
||||
// whether or not to define `before`.
|
||||
//
|
||||
// We'll mitigate that by requesting at most one duplicate.
|
||||
let new_path: String = format!(
|
||||
"{}.json?before=t3_{}&sort={}&limit=1&raw_json=1",
|
||||
req.uri().path(),
|
||||
&duplicates[0].id,
|
||||
if sort.is_empty() { "num_comments".to_string() } else { sort.clone() }
|
||||
);
|
||||
match json(new_path, true).await {
|
||||
Ok(response) => {
|
||||
if !response[1]["data"]["children"].as_array().unwrap_or(&Vec::new()).is_empty() {
|
||||
before = "t3_".to_owned();
|
||||
before.push_str(&duplicates[0].id);
|
||||
}
|
||||
}
|
||||
Err(msg) => {
|
||||
// Abort entirely if we couldn't get the previous
|
||||
// batch.
|
||||
return error(req, msg).await;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
after = response[1]["data"]["after"].as_str().unwrap_or_default().to_string();
|
||||
}
|
||||
}
|
||||
let url = req.uri().to_string();
|
||||
|
||||
template(DuplicatesTemplate {
|
||||
params: DuplicatesParams { before, after, sort },
|
||||
post,
|
||||
duplicates,
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
num_posts_filtered,
|
||||
all_posts_filtered,
|
||||
})
|
||||
}
|
||||
|
||||
// Process error.
|
||||
Err(msg) => {
|
||||
if msg == "quarantined" {
|
||||
let sub = req.param("sub").unwrap_or_default();
|
||||
quarantine(req, sub)
|
||||
} else {
|
||||
error(req, msg).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DUPLICATES
|
||||
async fn parse_duplicates(json: &serde_json::Value, filters: &HashSet<String>) -> (Vec<Post>, u64, bool) {
|
||||
let post_duplicates: &Vec<Value> = &json["data"]["children"].as_array().map_or(Vec::new(), ToOwned::to_owned);
|
||||
let mut duplicates: Vec<Post> = Vec::new();
|
||||
|
||||
// Process each post and place them in the Vec<Post>.
|
||||
for val in post_duplicates.iter() {
|
||||
let post: Post = parse_post(val).await;
|
||||
duplicates.push(post);
|
||||
}
|
||||
|
||||
let (num_posts_filtered, all_posts_filtered) = filter_posts(&mut duplicates, filters);
|
||||
(duplicates, num_posts_filtered, all_posts_filtered)
|
||||
}
|
63
src/main.rs
63
src/main.rs
|
@ -3,6 +3,8 @@
|
|||
#![allow(clippy::cmp_owned)]
|
||||
|
||||
// Reference local files
|
||||
mod config;
|
||||
mod duplicates;
|
||||
mod post;
|
||||
mod search;
|
||||
mod settings;
|
||||
|
@ -11,13 +13,13 @@ mod user;
|
|||
mod utils;
|
||||
|
||||
// Import Crates
|
||||
use clap::{Arg, Command};
|
||||
use clap::{Arg, ArgAction, Command};
|
||||
|
||||
use futures_lite::FutureExt;
|
||||
use hyper::{header::HeaderValue, Body, Request, Response};
|
||||
|
||||
mod client;
|
||||
use client::proxy;
|
||||
use client::{canonical_path, proxy};
|
||||
use server::RequestExt;
|
||||
use utils::{error, redirect, ThemeAssets};
|
||||
|
||||
|
@ -112,7 +114,7 @@ async fn main() {
|
|||
.short('r')
|
||||
.long("redirect-https")
|
||||
.help("Redirect all HTTP requests to HTTPS (no longer functional)")
|
||||
.takes_value(false),
|
||||
.num_args(0),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("address")
|
||||
|
@ -121,16 +123,18 @@ async fn main() {
|
|||
.value_name("ADDRESS")
|
||||
.help("Sets address to listen on")
|
||||
.default_value("0.0.0.0")
|
||||
.takes_value(true),
|
||||
.num_args(1),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("port")
|
||||
.short('p')
|
||||
.long("port")
|
||||
.value_name("PORT")
|
||||
.env("PORT")
|
||||
.help("Port to listen on")
|
||||
.default_value("8080")
|
||||
.takes_value(true),
|
||||
.action(ArgAction::Set)
|
||||
.num_args(1),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("hsts")
|
||||
|
@ -139,15 +143,15 @@ async fn main() {
|
|||
.value_name("EXPIRE_TIME")
|
||||
.help("HSTS header to tell browsers that this site should only be accessed over HTTPS")
|
||||
.default_value("604800")
|
||||
.takes_value(true),
|
||||
.num_args(1),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
let address = matches.value_of("address").unwrap_or("0.0.0.0");
|
||||
let port = std::env::var("PORT").unwrap_or_else(|_| matches.value_of("port").unwrap_or("8080").to_string());
|
||||
let hsts = matches.value_of("hsts");
|
||||
let address = matches.get_one::<String>("address").unwrap();
|
||||
let port = matches.get_one::<String>("port").unwrap();
|
||||
let hsts = matches.get_one("hsts").map(|m: &String| m.as_str());
|
||||
|
||||
let listener = [address, ":", &port].concat();
|
||||
let listener = [address, ":", port].concat();
|
||||
|
||||
println!("Starting Libreddit...");
|
||||
|
||||
|
@ -238,6 +242,16 @@ async fn main() {
|
|||
app.at("/r/:sub/comments/:id").get(|r| post::item(r).boxed());
|
||||
app.at("/r/:sub/comments/:id/:title").get(|r| post::item(r).boxed());
|
||||
app.at("/r/:sub/comments/:id/:title/:comment_id").get(|r| post::item(r).boxed());
|
||||
app.at("/comments/:id").get(|r| post::item(r).boxed());
|
||||
app.at("/comments/:id/comments").get(|r| post::item(r).boxed());
|
||||
app.at("/comments/:id/comments/:comment_id").get(|r| post::item(r).boxed());
|
||||
app.at("/comments/:id/:title").get(|r| post::item(r).boxed());
|
||||
app.at("/comments/:id/:title/:comment_id").get(|r| post::item(r).boxed());
|
||||
|
||||
app.at("/r/:sub/duplicates/:id").get(|r| duplicates::item(r).boxed());
|
||||
app.at("/r/:sub/duplicates/:id/:title").get(|r| duplicates::item(r).boxed());
|
||||
app.at("/duplicates/:id").get(|r| duplicates::item(r).boxed());
|
||||
app.at("/duplicates/:id/:title").get(|r| duplicates::item(r).boxed());
|
||||
|
||||
app.at("/r/:sub/search").get(|r| search::find(r).boxed());
|
||||
|
||||
|
@ -254,9 +268,6 @@ async fn main() {
|
|||
|
||||
app.at("/r/:sub/:sort").get(|r| subreddit::community(r).boxed());
|
||||
|
||||
// Comments handler
|
||||
app.at("/comments/:id").get(|r| post::item(r).boxed());
|
||||
|
||||
// Front page
|
||||
app.at("/").get(|r| subreddit::community(r).boxed());
|
||||
|
||||
|
@ -274,13 +285,25 @@ async fn main() {
|
|||
// Handle about pages
|
||||
app.at("/about").get(|req| error(req, "About pages aren't added yet".to_string()).boxed());
|
||||
|
||||
app.at("/:id").get(|req: Request<Body>| match req.param("id").as_deref() {
|
||||
// Sort front page
|
||||
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).boxed(),
|
||||
// Short link for post
|
||||
Some(id) if id.len() > 4 && id.len() < 7 => post::item(req).boxed(),
|
||||
// Error message for unknown pages
|
||||
_ => error(req, "Nothing here".to_string()).boxed(),
|
||||
app.at("/:id").get(|req: Request<Body>| {
|
||||
Box::pin(async move {
|
||||
match req.param("id").as_deref() {
|
||||
// Sort front page
|
||||
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await,
|
||||
|
||||
// Short link for post
|
||||
Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{}", id)).await {
|
||||
Ok(path_opt) => match path_opt {
|
||||
Some(path) => Ok(redirect(path)),
|
||||
None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
|
||||
},
|
||||
Err(e) => error(req, e).await,
|
||||
},
|
||||
|
||||
// Error message for unknown pages
|
||||
_ => error(req, "Nothing here".to_string()).await,
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
// Default service in case no routes match
|
||||
|
|
111
src/post.rs
111
src/post.rs
|
@ -3,7 +3,7 @@ use crate::client::json;
|
|||
use crate::server::RequestExt;
|
||||
use crate::subreddit::{can_access_quarantine, quarantine};
|
||||
use crate::utils::{
|
||||
error, format_num, format_url, get_filters, param, rewrite_urls, setting, template, time, val, Author, Awards, Comment, Flags, Flair, FlairPart, Media, Post, Preferences,
|
||||
error, format_num, get_filters, nsfw_landing, param, parse_post, rewrite_urls, setting, template, time, val, Author, Awards, Comment, Flair, FlairPart, Post, Preferences,
|
||||
};
|
||||
use hyper::{Body, Request, Response};
|
||||
|
||||
|
@ -54,8 +54,16 @@ pub async fn item(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
// Otherwise, grab the JSON output from the request
|
||||
Ok(response) => {
|
||||
// Parse the JSON into Post and Comment structs
|
||||
let post = parse_post(&response[0]).await;
|
||||
let comments = parse_comments(&response[1], &post.permalink, &post.author.name, highlighted_comment, &get_filters(&req));
|
||||
let post = parse_post(&response[0]["data"]["children"][0]).await;
|
||||
|
||||
// Return landing page if this post if this Reddit deems this post
|
||||
// NSFW, but we have also disabled the display of NSFW content
|
||||
// or if the instance is SFW-only.
|
||||
if post.nsfw && (setting(&req, "show_nsfw") != "on" || crate::utils::sfw_only()) {
|
||||
return Ok(nsfw_landing(req).await.unwrap_or_default());
|
||||
}
|
||||
|
||||
let comments = parse_comments(&response[1], &post.permalink, &post.author.name, highlighted_comment, &get_filters(&req), &req);
|
||||
let url = req.uri().to_string();
|
||||
|
||||
// Use the Post and Comment structs to generate a website to show users
|
||||
|
@ -63,7 +71,7 @@ pub async fn item(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
comments,
|
||||
post,
|
||||
sort,
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
single_thread,
|
||||
url,
|
||||
})
|
||||
|
@ -80,94 +88,8 @@ pub async fn item(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
}
|
||||
}
|
||||
|
||||
// POSTS
|
||||
async fn parse_post(json: &serde_json::Value) -> Post {
|
||||
// Retrieve post (as opposed to comments) from JSON
|
||||
let post: &serde_json::Value = &json["data"]["children"][0];
|
||||
|
||||
// Grab UTC time as unix timestamp
|
||||
let (rel_time, created) = time(post["data"]["created_utc"].as_f64().unwrap_or_default());
|
||||
// Parse post score and upvote ratio
|
||||
let score = post["data"]["score"].as_i64().unwrap_or_default();
|
||||
let ratio: f64 = post["data"]["upvote_ratio"].as_f64().unwrap_or(1.0) * 100.0;
|
||||
|
||||
// Determine the type of media along with the media URL
|
||||
let (post_type, media, gallery) = Media::parse(&post["data"]).await;
|
||||
|
||||
let awards: Awards = Awards::parse(&post["data"]["all_awardings"]);
|
||||
|
||||
let permalink = val(post, "permalink");
|
||||
|
||||
let body = if val(post, "removed_by_category") == "moderator" {
|
||||
format!(
|
||||
"<div class=\"md\"><p>[removed] — <a href=\"https://www.reveddit.com{}\">view removed post</a></p></div>",
|
||||
permalink
|
||||
)
|
||||
} else {
|
||||
rewrite_urls(&val(post, "selftext_html"))
|
||||
};
|
||||
|
||||
// Build a post using data parsed from Reddit post API
|
||||
Post {
|
||||
id: val(post, "id"),
|
||||
title: val(post, "title"),
|
||||
community: val(post, "subreddit"),
|
||||
body,
|
||||
author: Author {
|
||||
name: val(post, "author"),
|
||||
flair: Flair {
|
||||
flair_parts: FlairPart::parse(
|
||||
post["data"]["author_flair_type"].as_str().unwrap_or_default(),
|
||||
post["data"]["author_flair_richtext"].as_array(),
|
||||
post["data"]["author_flair_text"].as_str(),
|
||||
),
|
||||
text: val(post, "link_flair_text"),
|
||||
background_color: val(post, "author_flair_background_color"),
|
||||
foreground_color: val(post, "author_flair_text_color"),
|
||||
},
|
||||
distinguished: val(post, "distinguished"),
|
||||
},
|
||||
permalink,
|
||||
score: format_num(score),
|
||||
upvote_ratio: ratio as i64,
|
||||
post_type,
|
||||
media,
|
||||
thumbnail: Media {
|
||||
url: format_url(val(post, "thumbnail").as_str()),
|
||||
alt_url: String::new(),
|
||||
width: post["data"]["thumbnail_width"].as_i64().unwrap_or_default(),
|
||||
height: post["data"]["thumbnail_height"].as_i64().unwrap_or_default(),
|
||||
poster: "".to_string(),
|
||||
},
|
||||
flair: Flair {
|
||||
flair_parts: FlairPart::parse(
|
||||
post["data"]["link_flair_type"].as_str().unwrap_or_default(),
|
||||
post["data"]["link_flair_richtext"].as_array(),
|
||||
post["data"]["link_flair_text"].as_str(),
|
||||
),
|
||||
text: val(post, "link_flair_text"),
|
||||
background_color: val(post, "link_flair_background_color"),
|
||||
foreground_color: if val(post, "link_flair_text_color") == "dark" {
|
||||
"black".to_string()
|
||||
} else {
|
||||
"white".to_string()
|
||||
},
|
||||
},
|
||||
flags: Flags {
|
||||
nsfw: post["data"]["over_18"].as_bool().unwrap_or(false),
|
||||
stickied: post["data"]["stickied"].as_bool().unwrap_or(false),
|
||||
},
|
||||
domain: val(post, "domain"),
|
||||
rel_time,
|
||||
created,
|
||||
comments: format_num(post["data"]["num_comments"].as_i64().unwrap_or_default()),
|
||||
gallery,
|
||||
awards,
|
||||
}
|
||||
}
|
||||
|
||||
// COMMENTS
|
||||
fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str, highlighted_comment: &str, filters: &HashSet<String>) -> Vec<Comment> {
|
||||
fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str, highlighted_comment: &str, filters: &HashSet<String>, req: &Request<Body>) -> Vec<Comment> {
|
||||
// Parse the comment JSON into a Vector of Comments
|
||||
let comments = json["data"]["children"].as_array().map_or(Vec::new(), std::borrow::ToOwned::to_owned);
|
||||
|
||||
|
@ -187,7 +109,7 @@ fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str,
|
|||
|
||||
// If this comment contains replies, handle those too
|
||||
let replies: Vec<Comment> = if data["replies"].is_object() {
|
||||
parse_comments(&data["replies"], post_link, post_author, highlighted_comment, filters)
|
||||
parse_comments(&data["replies"], post_link, post_author, highlighted_comment, filters, req)
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
@ -200,9 +122,9 @@ fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str,
|
|||
let id = val(&comment, "id");
|
||||
let highlighted = id == highlighted_comment;
|
||||
|
||||
let body = if val(&comment, "author") == "[deleted]" && val(&comment, "body") == "[removed]" {
|
||||
let body = if (val(&comment, "author") == "[deleted]" && val(&comment, "body") == "[removed]") || val(&comment, "body") == "[ Removed by Reddit ]" {
|
||||
format!(
|
||||
"<div class=\"md\"><p>[removed] — <a href=\"https://www.reveddit.com{}{}\">view removed comment</a></p></div>",
|
||||
"<div class=\"md\"><p>[removed] — <a href=\"https://www.unddit.com{}{}\">view removed comment</a></p></div>",
|
||||
post_link, id
|
||||
)
|
||||
} else {
|
||||
|
@ -255,6 +177,7 @@ fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str,
|
|||
awards,
|
||||
collapsed,
|
||||
is_filtered,
|
||||
prefs: Preferences::new(req),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// CRATES
|
||||
use crate::utils::{catch_random, error, filter_posts, format_num, format_url, get_filters, param, redirect, setting, template, val, Post, Preferences};
|
||||
use crate::utils::{self, catch_random, error, filter_posts, format_num, format_url, get_filters, param, redirect, setting, template, val, Post, Preferences};
|
||||
use crate::{
|
||||
client::json,
|
||||
subreddit::{can_access_quarantine, quarantine},
|
||||
|
@ -7,6 +7,8 @@ use crate::{
|
|||
};
|
||||
use askama::Template;
|
||||
use hyper::{Body, Request, Response};
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
|
||||
// STRUCTS
|
||||
struct SearchParams {
|
||||
|
@ -42,13 +44,25 @@ struct SearchTemplate {
|
|||
/// Whether all fetched posts are filtered (to differentiate between no posts fetched in the first place,
|
||||
/// and all fetched posts being filtered).
|
||||
all_posts_filtered: bool,
|
||||
/// Whether all posts were hidden because they are NSFW (and user has disabled show NSFW)
|
||||
all_posts_hidden_nsfw: bool,
|
||||
no_posts: bool,
|
||||
}
|
||||
|
||||
// Regex matched against search queries to determine if they are reddit urls.
|
||||
static REDDIT_URL_MATCH: Lazy<Regex> = Lazy::new(|| Regex::new(r"^https?://([^\./]+\.)*reddit.com/").unwrap());
|
||||
|
||||
// SERVICES
|
||||
pub async fn find(req: Request<Body>) -> Result<Response<Body>, String> {
|
||||
let nsfw_results = if setting(&req, "show_nsfw") == "on" { "&include_over_18=on" } else { "" };
|
||||
// This ensures that during a search, no NSFW posts are fetched at all
|
||||
let nsfw_results = if setting(&req, "show_nsfw") == "on" && !utils::sfw_only() {
|
||||
"&include_over_18=on"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let path = format!("{}.json?{}{}&raw_json=1", req.uri().path(), req.uri().query().unwrap_or_default(), nsfw_results);
|
||||
let query = param(&path, "q").unwrap_or_default();
|
||||
let mut query = param(&path, "q").unwrap_or_default();
|
||||
query = REDDIT_URL_MATCH.replace(&query, "").to_string();
|
||||
|
||||
if query.is_empty() {
|
||||
return Ok(redirect("/".to_string()));
|
||||
|
@ -96,16 +110,19 @@ pub async fn find(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
restrict_sr: param(&path, "restrict_sr").unwrap_or_default(),
|
||||
typed,
|
||||
},
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
is_filtered: true,
|
||||
all_posts_filtered: false,
|
||||
all_posts_hidden_nsfw: false,
|
||||
no_posts: false,
|
||||
})
|
||||
} else {
|
||||
match Post::fetch(&path, quarantined).await {
|
||||
Ok((mut posts, after)) => {
|
||||
let all_posts_filtered = filter_posts(&mut posts, &filters);
|
||||
|
||||
let (_, all_posts_filtered) = filter_posts(&mut posts, &filters);
|
||||
let no_posts = posts.is_empty();
|
||||
let all_posts_hidden_nsfw = !no_posts && (posts.iter().all(|p| p.flags.nsfw) && setting(&req, "show_nsfw") != "on");
|
||||
template(SearchTemplate {
|
||||
posts,
|
||||
subreddits,
|
||||
|
@ -119,10 +136,12 @@ pub async fn find(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
restrict_sr: param(&path, "restrict_sr").unwrap_or_default(),
|
||||
typed,
|
||||
},
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
is_filtered: false,
|
||||
all_posts_filtered,
|
||||
all_posts_hidden_nsfw,
|
||||
no_posts,
|
||||
})
|
||||
}
|
||||
Err(msg) => {
|
||||
|
|
571
src/server.rs
571
src/server.rs
|
@ -1,17 +1,80 @@
|
|||
use brotli::enc::{BrotliCompress, BrotliEncoderParams};
|
||||
use cached::proc_macro::cached;
|
||||
use cookie::Cookie;
|
||||
use core::f64;
|
||||
use futures_lite::{future::Boxed, Future, FutureExt};
|
||||
use hyper::{
|
||||
header::HeaderValue,
|
||||
body,
|
||||
body::HttpBody,
|
||||
header,
|
||||
service::{make_service_fn, service_fn},
|
||||
HeaderMap,
|
||||
};
|
||||
use hyper::{Body, Method, Request, Response, Server as HyperServer};
|
||||
use libflate::gzip;
|
||||
use route_recognizer::{Params, Router};
|
||||
use std::{pin::Pin, result::Result};
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
io,
|
||||
pin::Pin,
|
||||
result::Result,
|
||||
str::{from_utf8, Split},
|
||||
string::ToString,
|
||||
};
|
||||
use time::Duration;
|
||||
|
||||
use crate::dbg_msg;
|
||||
|
||||
type BoxResponse = Pin<Box<dyn Future<Output = Result<Response<Body>, String>> + Send>>;
|
||||
|
||||
/// Compressors for the response Body, in ascending order of preference.
|
||||
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
enum CompressionType {
|
||||
Passthrough,
|
||||
Gzip,
|
||||
Brotli,
|
||||
}
|
||||
|
||||
/// All browsers support gzip, so if we are given `Accept-Encoding: *`, deliver
|
||||
/// gzipped-content.
|
||||
///
|
||||
/// Brotli would be nice universally, but Safari (iOS, iPhone, macOS) reportedly
|
||||
/// doesn't support it yet.
|
||||
const DEFAULT_COMPRESSOR: CompressionType = CompressionType::Gzip;
|
||||
|
||||
impl CompressionType {
|
||||
/// Returns a `CompressionType` given a content coding
|
||||
/// in [RFC 7231](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.4)
|
||||
/// format.
|
||||
fn parse(s: &str) -> Option<CompressionType> {
|
||||
let c = match s {
|
||||
// Compressors we support.
|
||||
"gzip" => CompressionType::Gzip,
|
||||
"br" => CompressionType::Brotli,
|
||||
|
||||
// The wildcard means that we can choose whatever
|
||||
// compression we prefer. In this case, use the
|
||||
// default.
|
||||
"*" => DEFAULT_COMPRESSOR,
|
||||
|
||||
// Compressor not supported.
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
Some(c)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for CompressionType {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
CompressionType::Gzip => "gzip".to_string(),
|
||||
CompressionType::Brotli => "br".to_string(),
|
||||
_ => String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Route<'a> {
|
||||
router: &'a mut Router<fn(Request<Body>) -> BoxResponse>,
|
||||
path: String,
|
||||
|
@ -97,7 +160,7 @@ impl ResponseExt for Response<Body> {
|
|||
}
|
||||
|
||||
fn insert_cookie(&mut self, cookie: Cookie) {
|
||||
if let Ok(val) = HeaderValue::from_str(&cookie.to_string()) {
|
||||
if let Ok(val) = header::HeaderValue::from_str(&cookie.to_string()) {
|
||||
self.headers_mut().append("Set-Cookie", val);
|
||||
}
|
||||
}
|
||||
|
@ -106,7 +169,7 @@ impl ResponseExt for Response<Body> {
|
|||
let mut cookie = Cookie::named(name);
|
||||
cookie.set_path("/");
|
||||
cookie.set_max_age(Duration::seconds(1));
|
||||
if let Ok(val) = HeaderValue::from_str(&cookie.to_string()) {
|
||||
if let Ok(val) = header::HeaderValue::from_str(&cookie.to_string()) {
|
||||
self.headers_mut().append("Set-Cookie", val);
|
||||
}
|
||||
}
|
||||
|
@ -156,10 +219,11 @@ impl Server {
|
|||
// let shared_router = router.clone();
|
||||
async move {
|
||||
Ok::<_, String>(service_fn(move |req: Request<Body>| {
|
||||
let headers = default_headers.clone();
|
||||
let req_headers = req.headers().clone();
|
||||
let def_headers = default_headers.clone();
|
||||
|
||||
// Remove double slashes and decode encoded slashes
|
||||
let mut path = req.uri().path().replace("//", "/").replace("%2F","/");
|
||||
let mut path = req.uri().path().replace("//", "/").replace("%2F", "/");
|
||||
|
||||
// Remove trailing slashes
|
||||
if path != "/" && path.ends_with('/') {
|
||||
|
@ -176,26 +240,20 @@ impl Server {
|
|||
// Run the route's function
|
||||
let func = (found.handler().to_owned().to_owned())(parammed);
|
||||
async move {
|
||||
let res: Result<Response<Body>, String> = func.await;
|
||||
// Add default headers to response
|
||||
res.map(|mut response| {
|
||||
response.headers_mut().extend(headers);
|
||||
response
|
||||
})
|
||||
match func.await {
|
||||
Ok(mut res) => {
|
||||
res.headers_mut().extend(def_headers);
|
||||
let _ = compress_response(&req_headers, &mut res).await;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
Err(msg) => new_boilerplate(def_headers, req_headers, 500, Body::from(msg)).await,
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
// If there was a routing error
|
||||
Err(e) => async move {
|
||||
// Return a 404 error
|
||||
let res: Result<Response<Body>, String> = Ok(Response::builder().status(404).body(e.into()).unwrap_or_default());
|
||||
// Add default headers to response
|
||||
res.map(|mut response| {
|
||||
response.headers_mut().extend(headers);
|
||||
response
|
||||
})
|
||||
}
|
||||
.boxed(),
|
||||
Err(e) => async move { new_boilerplate(def_headers, req_headers, 404, e.into()).await }.boxed(),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
@ -213,3 +271,472 @@ impl Server {
|
|||
server.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a boilerplate Response for error conditions. This response will be
|
||||
/// compressed if requested by client.
|
||||
async fn new_boilerplate(
|
||||
default_headers: HeaderMap<header::HeaderValue>,
|
||||
req_headers: HeaderMap<header::HeaderValue>,
|
||||
status: u16,
|
||||
body: Body,
|
||||
) -> Result<Response<Body>, String> {
|
||||
match Response::builder().status(status).body(body) {
|
||||
Ok(mut res) => {
|
||||
let _ = compress_response(&req_headers, &mut res).await;
|
||||
|
||||
res.headers_mut().extend(default_headers.clone());
|
||||
Ok(res)
|
||||
}
|
||||
Err(msg) => Err(msg.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines the desired compressor based on the Accept-Encoding header.
|
||||
///
|
||||
/// This function will honor the [q-value](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values)
|
||||
/// for each compressor. The q-value is an optional parameter, a decimal value
|
||||
/// on \[0..1\], to order the compressors by preference. An Accept-Encoding value
|
||||
/// with no q-values is also accepted.
|
||||
///
|
||||
/// Here are [examples](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding#examples)
|
||||
/// of valid Accept-Encoding headers.
|
||||
///
|
||||
/// ```http
|
||||
/// Accept-Encoding: gzip
|
||||
/// Accept-Encoding: gzip, compress, br
|
||||
/// Accept-Encoding: br;q=1.0, gzip;q=0.8, *;q=0.1
|
||||
/// ```
|
||||
#[cached]
|
||||
fn determine_compressor(accept_encoding: String) -> Option<CompressionType> {
|
||||
if accept_encoding.is_empty() {
|
||||
return None;
|
||||
};
|
||||
|
||||
// Keep track of the compressor candidate based on both the client's
|
||||
// preference and our own. Concrete examples:
|
||||
//
|
||||
// 1. "Accept-Encoding: gzip, br" => assuming we like brotli more than
|
||||
// gzip, and the browser supports brotli, we choose brotli
|
||||
//
|
||||
// 2. "Accept-Encoding: gzip;q=0.8, br;q=0.3" => the client has stated a
|
||||
// preference for gzip over brotli, so we choose gzip
|
||||
//
|
||||
// To do this, we need to define a struct which contains the requested
|
||||
// requested compressor (abstracted as a CompressionType enum) and the
|
||||
// q-value. If no q-value is defined for the compressor, we assume one of
|
||||
// 1.0. We first compare compressor candidates by comparing q-values, and
|
||||
// then CompressionTypes. We keep track of whatever is the greatest per our
|
||||
// ordering.
|
||||
|
||||
struct CompressorCandidate {
|
||||
alg: CompressionType,
|
||||
q: f64,
|
||||
}
|
||||
|
||||
impl Ord for CompressorCandidate {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
// Compare q-values. Break ties with the
|
||||
// CompressionType values.
|
||||
|
||||
match self.q.total_cmp(&other.q) {
|
||||
Ordering::Equal => self.alg.cmp(&other.alg),
|
||||
ord => ord,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for CompressorCandidate {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
// Guard against NAN, both on our end and on the other.
|
||||
if self.q.is_nan() || other.q.is_nan() {
|
||||
return None;
|
||||
};
|
||||
|
||||
// f64 and CompressionType are ordered, except in the case
|
||||
// where the f64 is NAN (which we checked against), so we
|
||||
// can safely return a Some here.
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for CompressorCandidate {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
(self.q == other.q) && (self.alg == other.alg)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for CompressorCandidate {}
|
||||
|
||||
// This is the current candidate.
|
||||
//
|
||||
// Assmume no candidate so far. We do this by assigning the sentinel value
|
||||
// of negative infinity to the q-value. If this value is negative infinity,
|
||||
// that means there was no viable compressor candidate.
|
||||
let mut cur_candidate = CompressorCandidate {
|
||||
alg: CompressionType::Passthrough,
|
||||
q: f64::NEG_INFINITY,
|
||||
};
|
||||
|
||||
// This loop reads the requested compressors and keeps track of whichever
|
||||
// one has the highest priority per our heuristic.
|
||||
for val in accept_encoding.to_string().split(',') {
|
||||
let mut q: f64 = 1.0;
|
||||
|
||||
// The compressor and q-value (if the latter is defined)
|
||||
// will be delimited by semicolons.
|
||||
let mut spl: Split<char> = val.split(';');
|
||||
|
||||
// Get the compressor. For example, in
|
||||
// gzip;q=0.8
|
||||
// this grabs "gzip" in the string. It
|
||||
// will further validate the compressor against the
|
||||
// list of those we support. If it is not supported,
|
||||
// we move onto the next one.
|
||||
let compressor: CompressionType = match spl.next() {
|
||||
// CompressionType::parse will return the appropriate enum given
|
||||
// a string. For example, it will return CompressionType::Gzip
|
||||
// when given "gzip".
|
||||
Some(s) => match CompressionType::parse(s.trim()) {
|
||||
Some(candidate) => candidate,
|
||||
|
||||
// We don't support the requested compression algorithm.
|
||||
None => continue,
|
||||
},
|
||||
|
||||
// We should never get here, but I'm paranoid.
|
||||
None => continue,
|
||||
};
|
||||
|
||||
// Get the q-value. This might not be defined, in which case assume
|
||||
// 1.0.
|
||||
if let Some(s) = spl.next() {
|
||||
if !(s.len() > 2 && s.starts_with("q=")) {
|
||||
// If the q-value is malformed, the header is malformed, so
|
||||
// abort.
|
||||
return None;
|
||||
}
|
||||
|
||||
match s[2..].parse::<f64>() {
|
||||
Ok(val) => {
|
||||
if (0.0..=1.0).contains(&val) {
|
||||
q = val;
|
||||
} else {
|
||||
// If the value is outside [0..1], header is malformed.
|
||||
// Abort.
|
||||
return None;
|
||||
};
|
||||
}
|
||||
Err(_) => {
|
||||
// If this isn't a f64, then assume a malformed header
|
||||
// value and abort.
|
||||
return None;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// If new_candidate > cur_candidate, make new_candidate the new
|
||||
// cur_candidate. But do this safely! It is very possible that
|
||||
// someone gave us the string "NAN", which (&str).parse::<f64>
|
||||
// will happily translate to f64::NAN.
|
||||
let new_candidate = CompressorCandidate { alg: compressor, q };
|
||||
if let Some(ord) = new_candidate.partial_cmp(&cur_candidate) {
|
||||
if ord == Ordering::Greater {
|
||||
cur_candidate = new_candidate;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
if cur_candidate.q != f64::NEG_INFINITY {
|
||||
Some(cur_candidate.alg)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Compress the response body, if possible or desirable. The Body will be
|
||||
/// compressed in place, and a new header Content-Encoding will be set
|
||||
/// indicating the compression algorithm.
|
||||
///
|
||||
/// This function deems Body eligible compression if and only if the following
|
||||
/// conditions are met:
|
||||
///
|
||||
/// 1. the HTTP client requests a compression encoding in the Content-Encoding
|
||||
/// header (hence the need for the req_headers);
|
||||
///
|
||||
/// 2. the content encoding corresponds to a compression algorithm we support;
|
||||
///
|
||||
/// 3. the Media type in the Content-Type response header is text with any
|
||||
/// subtype (e.g. text/plain) or application/json.
|
||||
///
|
||||
/// compress_response returns Ok on successful compression, or if not all three
|
||||
/// conditions above are met. It returns Err if there was a problem decoding
|
||||
/// any header in either req_headers or res, but res will remain intact.
|
||||
///
|
||||
/// This function logs errors to stderr, but only in debug mode. No information
|
||||
/// is logged in release builds.
|
||||
async fn compress_response(req_headers: &HeaderMap<header::HeaderValue>, res: &mut Response<Body>) -> Result<(), String> {
|
||||
// Check if the data is eligible for compression.
|
||||
if let Some(hdr) = res.headers().get(header::CONTENT_TYPE) {
|
||||
match from_utf8(hdr.as_bytes()) {
|
||||
Ok(val) => {
|
||||
let s = val.to_string();
|
||||
|
||||
// TODO: better determination of what is eligible for compression
|
||||
if !(s.starts_with("text/") || s.starts_with("application/json")) {
|
||||
return Ok(());
|
||||
};
|
||||
}
|
||||
Err(e) => {
|
||||
dbg_msg!(e);
|
||||
return Err(e.to_string());
|
||||
}
|
||||
};
|
||||
} else {
|
||||
// Response declares no Content-Type. Assume for simplicity that it
|
||||
// cannot be compressed.
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// Don't bother if the size of the size of the response body will fit
|
||||
// within an IP frame (less the bytes that make up the TCP/IP and HTTP
|
||||
// headers).
|
||||
if res.body().size_hint().lower() < 1452 {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// Check to see which compressor is requested, and if we can use it.
|
||||
let accept_encoding: String = match req_headers.get(header::ACCEPT_ENCODING) {
|
||||
None => return Ok(()), // Client requested no compression.
|
||||
|
||||
Some(hdr) => match String::from_utf8(hdr.as_bytes().into()) {
|
||||
Ok(val) => val,
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
Err(e) => {
|
||||
dbg_msg!(e);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
Err(_) => return Ok(()),
|
||||
},
|
||||
};
|
||||
|
||||
let compressor: CompressionType = match determine_compressor(accept_encoding) {
|
||||
Some(c) => c,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
// Get the body from the response.
|
||||
let body_bytes: Vec<u8> = match body::to_bytes(res.body_mut()).await {
|
||||
Ok(b) => b.to_vec(),
|
||||
Err(e) => {
|
||||
dbg_msg!(e);
|
||||
return Err(e.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
// Compress!
|
||||
match compress_body(compressor, body_bytes) {
|
||||
Ok(compressed) => {
|
||||
// We get here iff the compression was successful. Replace the body
|
||||
// with the compressed payload, and add the appropriate
|
||||
// Content-Encoding header in the response.
|
||||
res.headers_mut().insert(header::CONTENT_ENCODING, compressor.to_string().parse().unwrap());
|
||||
*(res.body_mut()) = Body::from(compressed);
|
||||
}
|
||||
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compresses a `Vec<u8>` given a [`CompressionType`].
|
||||
///
|
||||
/// This is a helper function for [`compress_response`] and should not be
|
||||
/// called directly.
|
||||
|
||||
// I've chosen a TTL of 600 (== 10 minutes) since compression is
|
||||
// computationally expensive and we don't want to be doing it often. This is
|
||||
// larger than client::json's TTL, but that's okay, because if client::json
|
||||
// returns a new serde_json::Value, body_bytes changes, so this function will
|
||||
// execute again.
|
||||
#[cached(size = 100, time = 600, result = true)]
|
||||
fn compress_body(compressor: CompressionType, body_bytes: Vec<u8>) -> Result<Vec<u8>, String> {
|
||||
// io::Cursor implements io::Read, required for our encoders.
|
||||
let mut reader = io::Cursor::new(body_bytes);
|
||||
|
||||
let compressed: Vec<u8> = match compressor {
|
||||
CompressionType::Gzip => {
|
||||
let mut gz: gzip::Encoder<Vec<u8>> = match gzip::Encoder::new(Vec::new()) {
|
||||
Ok(gz) => gz,
|
||||
Err(e) => {
|
||||
dbg_msg!(e);
|
||||
return Err(e.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
match io::copy(&mut reader, &mut gz) {
|
||||
Ok(_) => match gz.finish().into_result() {
|
||||
Ok(compressed) => compressed,
|
||||
Err(e) => {
|
||||
dbg_msg!(e);
|
||||
return Err(e.to_string());
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
dbg_msg!(e);
|
||||
return Err(e.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CompressionType::Brotli => {
|
||||
// We may want to make the compression parameters configurable
|
||||
// in the future. For now, the defaults are sufficient.
|
||||
let brotli_params = BrotliEncoderParams::default();
|
||||
|
||||
let mut compressed = Vec::<u8>::new();
|
||||
match BrotliCompress(&mut reader, &mut compressed, &brotli_params) {
|
||||
Ok(_) => compressed,
|
||||
Err(e) => {
|
||||
dbg_msg!(e);
|
||||
return Err(e.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This arm is for any requested compressor for which we don't yet
|
||||
// have an implementation.
|
||||
_ => {
|
||||
let msg = "unsupported compressor".to_string();
|
||||
return Err(msg);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(compressed)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use brotli::Decompressor as BrotliDecompressor;
|
||||
use futures_lite::future::block_on;
|
||||
use lipsum::lipsum;
|
||||
use std::{boxed::Box, io};
|
||||
|
||||
#[test]
|
||||
fn test_determine_compressor() {
|
||||
// Single compressor given.
|
||||
assert_eq!(determine_compressor("unsupported".to_string()), None);
|
||||
assert_eq!(determine_compressor("gzip".to_string()), Some(CompressionType::Gzip));
|
||||
assert_eq!(determine_compressor("*".to_string()), Some(DEFAULT_COMPRESSOR));
|
||||
|
||||
// Multiple compressors.
|
||||
assert_eq!(determine_compressor("gzip, br".to_string()), Some(CompressionType::Brotli));
|
||||
assert_eq!(determine_compressor("gzip;q=0.8, br;q=0.3".to_string()), Some(CompressionType::Gzip));
|
||||
assert_eq!(determine_compressor("br, gzip".to_string()), Some(CompressionType::Brotli));
|
||||
assert_eq!(determine_compressor("br;q=0.3, gzip;q=0.4".to_string()), Some(CompressionType::Gzip));
|
||||
|
||||
// Invalid q-values.
|
||||
assert_eq!(determine_compressor("gzip;q=NAN".to_string()), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compress_response() {
|
||||
// This macro generates an Accept-Encoding header value given any number of
|
||||
// compressors.
|
||||
macro_rules! ae_gen {
|
||||
($x:expr) => {
|
||||
$x.to_string().as_str()
|
||||
};
|
||||
|
||||
($x:expr, $($y:expr),+) => {
|
||||
format!("{}, {}", $x.to_string(), ae_gen!($($y),+)).as_str()
|
||||
};
|
||||
}
|
||||
|
||||
for accept_encoding in [
|
||||
"*",
|
||||
ae_gen!(CompressionType::Gzip),
|
||||
ae_gen!(CompressionType::Brotli, CompressionType::Gzip),
|
||||
ae_gen!(CompressionType::Brotli),
|
||||
] {
|
||||
// Determine what the expected encoding should be based on both the
|
||||
// specific encodings we accept.
|
||||
let expected_encoding: CompressionType = match determine_compressor(accept_encoding.to_string()) {
|
||||
Some(s) => s,
|
||||
None => panic!("determine_compressor(accept_encoding.to_string()) => None"),
|
||||
};
|
||||
|
||||
// Build headers with our Accept-Encoding.
|
||||
let mut req_headers = HeaderMap::new();
|
||||
req_headers.insert(header::ACCEPT_ENCODING, header::HeaderValue::from_str(accept_encoding).unwrap());
|
||||
|
||||
// Build test response.
|
||||
let lorem_ipsum: String = lipsum(10000);
|
||||
let expected_lorem_ipsum = Vec::<u8>::from(lorem_ipsum.as_str());
|
||||
let mut res = Response::builder()
|
||||
.status(200)
|
||||
.header(header::CONTENT_TYPE, "text/plain")
|
||||
.body(Body::from(lorem_ipsum))
|
||||
.unwrap();
|
||||
|
||||
// Perform the compression.
|
||||
if let Err(e) = block_on(compress_response(&req_headers, &mut res)) {
|
||||
panic!("compress_response(&req_headers, &mut res) => Err(\"{}\")", e);
|
||||
};
|
||||
|
||||
// If the content was compressed, we expect the Content-Encoding
|
||||
// header to be modified.
|
||||
assert_eq!(
|
||||
res
|
||||
.headers()
|
||||
.get(header::CONTENT_ENCODING)
|
||||
.unwrap_or_else(|| panic!("missing content-encoding header"))
|
||||
.to_str()
|
||||
.unwrap_or_else(|_| panic!("failed to convert Content-Encoding header::HeaderValue to String")),
|
||||
expected_encoding.to_string()
|
||||
);
|
||||
|
||||
// Decompress body and make sure it's equal to what we started
|
||||
// with.
|
||||
//
|
||||
// In the case of no compression, just make sure the "new" body in
|
||||
// the Response is the same as what with which we start.
|
||||
let body_vec = match block_on(body::to_bytes(res.body_mut())) {
|
||||
Ok(b) => b.to_vec(),
|
||||
Err(e) => panic!("{}", e),
|
||||
};
|
||||
|
||||
if expected_encoding == CompressionType::Passthrough {
|
||||
assert!(body_vec.eq(&expected_lorem_ipsum));
|
||||
continue;
|
||||
}
|
||||
|
||||
// This provides an io::Read for the underlying body.
|
||||
let mut body_cursor: io::Cursor<Vec<u8>> = io::Cursor::new(body_vec);
|
||||
|
||||
// Match the appropriate decompresor for the given
|
||||
// expected_encoding.
|
||||
let mut decoder: Box<dyn io::Read> = match expected_encoding {
|
||||
CompressionType::Gzip => match gzip::Decoder::new(&mut body_cursor) {
|
||||
Ok(dgz) => Box::new(dgz),
|
||||
Err(e) => panic!("{}", e),
|
||||
},
|
||||
|
||||
CompressionType::Brotli => Box::new(BrotliDecompressor::new(body_cursor, expected_lorem_ipsum.len())),
|
||||
|
||||
_ => panic!("no decompressor for {}", expected_encoding.to_string()),
|
||||
};
|
||||
|
||||
let mut decompressed = Vec::<u8>::new();
|
||||
if let Err(e) = io::copy(&mut decoder, &mut decompressed) {
|
||||
panic!("{}", e);
|
||||
};
|
||||
|
||||
assert!(decompressed.eq(&expected_lorem_ipsum));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ struct SettingsTemplate {
|
|||
|
||||
// CONSTANTS
|
||||
|
||||
const PREFS: [&str; 11] = [
|
||||
const PREFS: [&str; 13] = [
|
||||
"theme",
|
||||
"front_page",
|
||||
"layout",
|
||||
|
@ -27,10 +27,12 @@ const PREFS: [&str; 11] = [
|
|||
"comment_sort",
|
||||
"post_sort",
|
||||
"show_nsfw",
|
||||
"blur_nsfw",
|
||||
"use_hls",
|
||||
"hide_hls_notification",
|
||||
"autoplay_videos",
|
||||
"fixed_navbar",
|
||||
"hide_awards",
|
||||
];
|
||||
|
||||
// FUNCTIONS
|
||||
|
@ -39,7 +41,7 @@ const PREFS: [&str; 11] = [
|
|||
pub async fn get(req: Request<Body>) -> Result<Response<Body>, String> {
|
||||
let url = req.uri().to_string();
|
||||
template(SettingsTemplate {
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// CRATES
|
||||
use crate::utils::{
|
||||
catch_random, error, filter_posts, format_num, format_url, get_filters, param, redirect, rewrite_urls, setting, template, val, Post, Preferences, Subreddit,
|
||||
catch_random, error, filter_posts, format_num, format_url, get_filters, nsfw_landing, param, redirect, rewrite_urls, setting, template, val, Post, Preferences, Subreddit,
|
||||
};
|
||||
use crate::{client::json, server::ResponseExt, RequestExt};
|
||||
use askama::Template;
|
||||
|
@ -24,6 +24,9 @@ struct SubredditTemplate {
|
|||
/// Whether all fetched posts are filtered (to differentiate between no posts fetched in the first place,
|
||||
/// and all fetched posts being filtered).
|
||||
all_posts_filtered: bool,
|
||||
/// Whether all posts were hidden because they are NSFW (and user has disabled show NSFW)
|
||||
all_posts_hidden_nsfw: bool,
|
||||
no_posts: bool,
|
||||
}
|
||||
|
||||
#[derive(Template)]
|
||||
|
@ -94,6 +97,12 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
}
|
||||
};
|
||||
|
||||
// Return landing page if this post if this is NSFW community but the user
|
||||
// has disabled the display of NSFW content or if the instance is SFW-only.
|
||||
if sub.nsfw && (setting(&req, "show_nsfw") != "on" || crate::utils::sfw_only()) {
|
||||
return Ok(nsfw_landing(req).await.unwrap_or_default());
|
||||
}
|
||||
|
||||
let path = format!("/r/{}/{}.json?{}&raw_json=1", sub_name.clone(), sort, req.uri().query().unwrap_or_default());
|
||||
let url = String::from(req.uri().path_and_query().map_or("", |val| val.as_str()));
|
||||
let redirect_url = url[1..].replace('?', "%3F").replace('&', "%26").replace('+', "%2B");
|
||||
|
@ -106,27 +115,32 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
posts: Vec::new(),
|
||||
sort: (sort, param(&path, "t").unwrap_or_default()),
|
||||
ends: (param(&path, "after").unwrap_or_default(), "".to_string()),
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
redirect_url,
|
||||
is_filtered: true,
|
||||
all_posts_filtered: false,
|
||||
all_posts_hidden_nsfw: false,
|
||||
no_posts: false,
|
||||
})
|
||||
} else {
|
||||
match Post::fetch(&path, quarantined).await {
|
||||
Ok((mut posts, after)) => {
|
||||
let all_posts_filtered = filter_posts(&mut posts, &filters);
|
||||
|
||||
let (_, all_posts_filtered) = filter_posts(&mut posts, &filters);
|
||||
let no_posts = posts.is_empty();
|
||||
let all_posts_hidden_nsfw = !no_posts && (posts.iter().all(|p| p.flags.nsfw) && setting(&req, "show_nsfw") != "on");
|
||||
template(SubredditTemplate {
|
||||
sub,
|
||||
posts,
|
||||
sort: (sort, param(&path, "t").unwrap_or_default()),
|
||||
ends: (param(&path, "after").unwrap_or_default(), after),
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
redirect_url,
|
||||
is_filtered: false,
|
||||
all_posts_filtered,
|
||||
all_posts_hidden_nsfw,
|
||||
no_posts,
|
||||
})
|
||||
}
|
||||
Err(msg) => match msg.as_str() {
|
||||
|
@ -145,7 +159,7 @@ pub fn quarantine(req: Request<Body>, sub: String) -> Result<Response<Body>, Str
|
|||
msg: "Please click the button below to continue to this subreddit.".to_string(),
|
||||
url: req.uri().to_string(),
|
||||
sub,
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
};
|
||||
|
||||
Ok(
|
||||
|
@ -192,7 +206,7 @@ pub async fn subscriptions_filters(req: Request<Body>) -> Result<Response<Body>,
|
|||
|
||||
let query = req.uri().query().unwrap_or_default().to_string();
|
||||
|
||||
let preferences = Preferences::new(req);
|
||||
let preferences = Preferences::new(&req);
|
||||
let mut sub_list = preferences.subscriptions;
|
||||
let mut filters = preferences.filters;
|
||||
|
||||
|
@ -305,7 +319,7 @@ pub async fn wiki(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
sub,
|
||||
wiki: rewrite_urls(response["data"]["content_html"].as_str().unwrap_or("<h3>Wiki not found</h3>")),
|
||||
page,
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
}),
|
||||
Err(msg) => {
|
||||
|
@ -343,7 +357,7 @@ pub async fn sidebar(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
// ),
|
||||
sub,
|
||||
page: "Sidebar".to_string(),
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
}),
|
||||
Err(msg) => {
|
||||
|
@ -416,5 +430,6 @@ async fn subreddit(sub: &str, quarantined: bool) -> Result<Subreddit, String> {
|
|||
members: format_num(members),
|
||||
active: format_num(active),
|
||||
wiki: res["data"]["wiki_enabled"].as_bool().unwrap_or_default(),
|
||||
nsfw: res["data"]["over18"].as_bool().unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
|
|
28
src/user.rs
28
src/user.rs
|
@ -1,7 +1,7 @@
|
|||
// CRATES
|
||||
use crate::client::json;
|
||||
use crate::server::RequestExt;
|
||||
use crate::utils::{error, filter_posts, format_url, get_filters, param, template, Post, Preferences, User};
|
||||
use crate::utils::{error, filter_posts, format_url, get_filters, nsfw_landing, param, setting, template, Post, Preferences, User};
|
||||
use askama::Template;
|
||||
use hyper::{Body, Request, Response};
|
||||
use time::{macros::format_description, OffsetDateTime};
|
||||
|
@ -24,6 +24,9 @@ struct UserTemplate {
|
|||
/// Whether all fetched posts are filtered (to differentiate between no posts fetched in the first place,
|
||||
/// and all fetched posts being filtered).
|
||||
all_posts_filtered: bool,
|
||||
/// Whether all posts were hidden because they are NSFW (and user has disabled show NSFW)
|
||||
all_posts_hidden_nsfw: bool,
|
||||
no_posts: bool,
|
||||
}
|
||||
|
||||
// FUNCTIONS
|
||||
|
@ -43,8 +46,17 @@ pub async fn profile(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
// Retrieve other variables from Libreddit request
|
||||
let sort = param(&path, "sort").unwrap_or_default();
|
||||
let username = req.param("name").unwrap_or_default();
|
||||
|
||||
// Retrieve info from user about page.
|
||||
let user = user(&username).await.unwrap_or_default();
|
||||
|
||||
// Return landing page if this post if this Reddit deems this user NSFW,
|
||||
// but we have also disabled the display of NSFW content or if the instance
|
||||
// is SFW-only.
|
||||
if user.nsfw && (setting(&req, "show_nsfw") != "on" || crate::utils::sfw_only()) {
|
||||
return Ok(nsfw_landing(req).await.unwrap_or_default());
|
||||
}
|
||||
|
||||
let filters = get_filters(&req);
|
||||
if filters.contains(&["u_", &username].concat()) {
|
||||
template(UserTemplate {
|
||||
|
@ -53,29 +65,34 @@ pub async fn profile(req: Request<Body>) -> Result<Response<Body>, String> {
|
|||
sort: (sort, param(&path, "t").unwrap_or_default()),
|
||||
ends: (param(&path, "after").unwrap_or_default(), "".to_string()),
|
||||
listing,
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
redirect_url,
|
||||
is_filtered: true,
|
||||
all_posts_filtered: false,
|
||||
all_posts_hidden_nsfw: false,
|
||||
no_posts: false,
|
||||
})
|
||||
} else {
|
||||
// Request user posts/comments from Reddit
|
||||
match Post::fetch(&path, false).await {
|
||||
Ok((mut posts, after)) => {
|
||||
let all_posts_filtered = filter_posts(&mut posts, &filters);
|
||||
|
||||
let (_, all_posts_filtered) = filter_posts(&mut posts, &filters);
|
||||
let no_posts = posts.is_empty();
|
||||
let all_posts_hidden_nsfw = !no_posts && (posts.iter().all(|p| p.flags.nsfw) && setting(&req, "show_nsfw") != "on");
|
||||
template(UserTemplate {
|
||||
user,
|
||||
posts,
|
||||
sort: (sort, param(&path, "t").unwrap_or_default()),
|
||||
ends: (param(&path, "after").unwrap_or_default(), after),
|
||||
listing,
|
||||
prefs: Preferences::new(req),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
redirect_url,
|
||||
is_filtered: false,
|
||||
all_posts_filtered,
|
||||
all_posts_hidden_nsfw,
|
||||
no_posts,
|
||||
})
|
||||
}
|
||||
// If there is an error show error page
|
||||
|
@ -107,6 +124,7 @@ async fn user(name: &str) -> Result<User, String> {
|
|||
created: created.format(format_description!("[month repr:short] [day] '[year repr:last_two]")).unwrap_or_default(),
|
||||
banner: about("banner_img"),
|
||||
description: about("public_description"),
|
||||
nsfw: res["data"]["subreddit"]["over_18"].as_bool().unwrap_or_default(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
260
src/utils.rs
260
src/utils.rs
|
@ -9,10 +9,36 @@ use regex::Regex;
|
|||
use rust_embed::RustEmbed;
|
||||
use serde_json::Value;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::env;
|
||||
use std::str::FromStr;
|
||||
use time::{macros::format_description, Duration, OffsetDateTime};
|
||||
use url::Url;
|
||||
|
||||
/// Write a message to stderr on debug mode. This function is a no-op on
|
||||
/// release code.
|
||||
#[macro_export]
|
||||
macro_rules! dbg_msg {
|
||||
($x:expr) => {
|
||||
#[cfg(debug_assertions)]
|
||||
eprintln!("{}:{}: {}", file!(), line!(), $x.to_string())
|
||||
};
|
||||
|
||||
($($x:expr),+) => {
|
||||
#[cfg(debug_assertions)]
|
||||
dbg_msg!(format!($($x),+))
|
||||
};
|
||||
}
|
||||
|
||||
/// Identifies whether or not the page is a subreddit, a user page, or a post.
|
||||
/// This is used by the NSFW landing template to determine the mesage to convey
|
||||
/// to the user.
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum ResourceType {
|
||||
Subreddit,
|
||||
User,
|
||||
Post,
|
||||
}
|
||||
|
||||
// Post flair with content, background color and foreground color
|
||||
pub struct Flair {
|
||||
pub flair_parts: Vec<FlairPart>,
|
||||
|
@ -210,9 +236,11 @@ pub struct Post {
|
|||
pub domain: String,
|
||||
pub rel_time: String,
|
||||
pub created: String,
|
||||
pub num_duplicates: u64,
|
||||
pub comments: (String, String),
|
||||
pub gallery: Vec<GalleryMedia>,
|
||||
pub awards: Awards,
|
||||
pub nsfw: bool,
|
||||
}
|
||||
|
||||
impl Post {
|
||||
|
@ -304,14 +332,16 @@ impl Post {
|
|||
},
|
||||
flags: Flags {
|
||||
nsfw: data["over_18"].as_bool().unwrap_or_default(),
|
||||
stickied: data["stickied"].as_bool().unwrap_or_default(),
|
||||
stickied: data["stickied"].as_bool().unwrap_or_default() || data["pinned"].as_bool().unwrap_or_default(),
|
||||
},
|
||||
permalink: val(post, "permalink"),
|
||||
rel_time,
|
||||
created,
|
||||
num_duplicates: post["data"]["num_duplicates"].as_u64().unwrap_or(0),
|
||||
comments: format_num(data["num_comments"].as_i64().unwrap_or_default()),
|
||||
gallery,
|
||||
awards,
|
||||
nsfw: post["data"]["over_18"].as_bool().unwrap_or_default(),
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -340,6 +370,7 @@ pub struct Comment {
|
|||
pub awards: Awards,
|
||||
pub collapsed: bool,
|
||||
pub is_filtered: bool,
|
||||
pub prefs: Preferences,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
|
@ -403,6 +434,27 @@ pub struct ErrorTemplate {
|
|||
pub url: String,
|
||||
}
|
||||
|
||||
/// Template for NSFW landing page. The landing page is displayed when a page's
|
||||
/// content is wholly NSFW, but a user has not enabled the option to view NSFW
|
||||
/// posts.
|
||||
#[derive(Template)]
|
||||
#[template(path = "nsfwlanding.html")]
|
||||
pub struct NSFWLandingTemplate {
|
||||
/// Identifier for the resource. This is either a subreddit name or a
|
||||
/// username. (In the case of the latter, set is_user to true.)
|
||||
pub res: String,
|
||||
|
||||
/// Identifies whether or not the resource is a subreddit, a user page,
|
||||
/// or a post.
|
||||
pub res_type: ResourceType,
|
||||
|
||||
/// User preferences.
|
||||
pub prefs: Preferences,
|
||||
|
||||
/// Request URL.
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
// User struct containing metadata about user
|
||||
pub struct User {
|
||||
|
@ -413,6 +465,7 @@ pub struct User {
|
|||
pub created: String,
|
||||
pub banner: String,
|
||||
pub description: String,
|
||||
pub nsfw: bool,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
@ -427,6 +480,7 @@ pub struct Subreddit {
|
|||
pub members: (String, String),
|
||||
pub active: (String, String),
|
||||
pub wiki: bool,
|
||||
pub nsfw: bool,
|
||||
}
|
||||
|
||||
// Parser for query params, used in sorting (eg. /r/rust/?sort=hot)
|
||||
|
@ -447,6 +501,7 @@ pub struct Preferences {
|
|||
pub layout: String,
|
||||
pub wide: String,
|
||||
pub show_nsfw: String,
|
||||
pub blur_nsfw: String,
|
||||
pub hide_hls_notification: String,
|
||||
pub use_hls: String,
|
||||
pub autoplay_videos: String,
|
||||
|
@ -455,6 +510,7 @@ pub struct Preferences {
|
|||
pub post_sort: String,
|
||||
pub subscriptions: Vec<String>,
|
||||
pub filters: Vec<String>,
|
||||
pub hide_awards: String,
|
||||
}
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
|
@ -464,7 +520,7 @@ pub struct ThemeAssets;
|
|||
|
||||
impl Preferences {
|
||||
// Build preferences from cookies
|
||||
pub fn new(req: Request<Body>) -> Self {
|
||||
pub fn new(req: &Request<Body>) -> Self {
|
||||
// Read available theme names from embedded css files.
|
||||
// Always make the default "system" theme available.
|
||||
let mut themes = vec!["system".to_string()];
|
||||
|
@ -479,6 +535,7 @@ impl Preferences {
|
|||
layout: setting(&req, "layout"),
|
||||
wide: setting(&req, "wide"),
|
||||
show_nsfw: setting(&req, "show_nsfw"),
|
||||
blur_nsfw: setting(&req, "blur_nsfw"),
|
||||
use_hls: setting(&req, "use_hls"),
|
||||
hide_hls_notification: setting(&req, "hide_hls_notification"),
|
||||
autoplay_videos: setting(&req, "autoplay_videos"),
|
||||
|
@ -487,6 +544,7 @@ impl Preferences {
|
|||
post_sort: setting(&req, "post_sort"),
|
||||
subscriptions: setting(&req, "subscriptions").split('+').map(String::from).filter(|s| !s.is_empty()).collect(),
|
||||
filters: setting(&req, "filters").split('+').map(String::from).filter(|s| !s.is_empty()).collect(),
|
||||
hide_awards: setting(&req, "hide_awards"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -496,15 +554,111 @@ pub fn get_filters(req: &Request<Body>) -> HashSet<String> {
|
|||
setting(req, "filters").split('+').map(String::from).filter(|s| !s.is_empty()).collect::<HashSet<String>>()
|
||||
}
|
||||
|
||||
/// Filters a `Vec<Post>` by the given `HashSet` of filters (each filter being a subreddit name or a user name). If a
|
||||
/// `Post`'s subreddit or author is found in the filters, it is removed. Returns `true` if _all_ posts were filtered
|
||||
/// out, or `false` otherwise.
|
||||
pub fn filter_posts(posts: &mut Vec<Post>, filters: &HashSet<String>) -> bool {
|
||||
/// Filters a `Vec<Post>` by the given `HashSet` of filters (each filter being
|
||||
/// a subreddit name or a user name). If a `Post`'s subreddit or author is
|
||||
/// found in the filters, it is removed.
|
||||
///
|
||||
/// The first value of the return tuple is the number of posts filtered. The
|
||||
/// second return value is `true` if all posts were filtered.
|
||||
pub fn filter_posts(posts: &mut Vec<Post>, filters: &HashSet<String>) -> (u64, bool) {
|
||||
// This is the length of the Vec<Post> prior to applying the filter.
|
||||
let lb: u64 = posts.len().try_into().unwrap_or(0);
|
||||
|
||||
if posts.is_empty() {
|
||||
false
|
||||
(0, false)
|
||||
} else {
|
||||
posts.retain(|p| !filters.contains(&p.community) && !filters.contains(&["u_", &p.author.name].concat()));
|
||||
posts.is_empty()
|
||||
posts.retain(|p| !(filters.contains(&p.community) || filters.contains(&["u_", &p.author.name].concat())));
|
||||
|
||||
// Get the length of the Vec<Post> after applying the filter.
|
||||
// If lb > la, then at least one post was removed.
|
||||
let la: u64 = posts.len().try_into().unwrap_or(0);
|
||||
|
||||
(lb - la, posts.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a [`Post`] from a provided JSON.
|
||||
pub async fn parse_post(post: &serde_json::Value) -> Post {
|
||||
// Grab UTC time as unix timestamp
|
||||
let (rel_time, created) = time(post["data"]["created_utc"].as_f64().unwrap_or_default());
|
||||
// Parse post score and upvote ratio
|
||||
let score = post["data"]["score"].as_i64().unwrap_or_default();
|
||||
let ratio: f64 = post["data"]["upvote_ratio"].as_f64().unwrap_or(1.0) * 100.0;
|
||||
|
||||
// Determine the type of media along with the media URL
|
||||
let (post_type, media, gallery) = Media::parse(&post["data"]).await;
|
||||
|
||||
let awards: Awards = Awards::parse(&post["data"]["all_awardings"]);
|
||||
|
||||
let permalink = val(post, "permalink");
|
||||
|
||||
let body = if val(post, "removed_by_category") == "moderator" {
|
||||
format!(
|
||||
"<div class=\"md\"><p>[removed] — <a href=\"https://www.unddit.com{}\">view removed post</a></p></div>",
|
||||
permalink
|
||||
)
|
||||
} else {
|
||||
rewrite_urls(&val(post, "selftext_html"))
|
||||
};
|
||||
|
||||
// Build a post using data parsed from Reddit post API
|
||||
Post {
|
||||
id: val(post, "id"),
|
||||
title: val(post, "title"),
|
||||
community: val(post, "subreddit"),
|
||||
body,
|
||||
author: Author {
|
||||
name: val(post, "author"),
|
||||
flair: Flair {
|
||||
flair_parts: FlairPart::parse(
|
||||
post["data"]["author_flair_type"].as_str().unwrap_or_default(),
|
||||
post["data"]["author_flair_richtext"].as_array(),
|
||||
post["data"]["author_flair_text"].as_str(),
|
||||
),
|
||||
text: val(post, "link_flair_text"),
|
||||
background_color: val(post, "author_flair_background_color"),
|
||||
foreground_color: val(post, "author_flair_text_color"),
|
||||
},
|
||||
distinguished: val(post, "distinguished"),
|
||||
},
|
||||
permalink,
|
||||
score: format_num(score),
|
||||
upvote_ratio: ratio as i64,
|
||||
post_type,
|
||||
media,
|
||||
thumbnail: Media {
|
||||
url: format_url(val(post, "thumbnail").as_str()),
|
||||
alt_url: String::new(),
|
||||
width: post["data"]["thumbnail_width"].as_i64().unwrap_or_default(),
|
||||
height: post["data"]["thumbnail_height"].as_i64().unwrap_or_default(),
|
||||
poster: String::new(),
|
||||
},
|
||||
flair: Flair {
|
||||
flair_parts: FlairPart::parse(
|
||||
post["data"]["link_flair_type"].as_str().unwrap_or_default(),
|
||||
post["data"]["link_flair_richtext"].as_array(),
|
||||
post["data"]["link_flair_text"].as_str(),
|
||||
),
|
||||
text: val(post, "link_flair_text"),
|
||||
background_color: val(post, "link_flair_background_color"),
|
||||
foreground_color: if val(post, "link_flair_text_color") == "dark" {
|
||||
"black".to_string()
|
||||
} else {
|
||||
"white".to_string()
|
||||
},
|
||||
},
|
||||
flags: Flags {
|
||||
nsfw: post["data"]["over_18"].as_bool().unwrap_or_default(),
|
||||
stickied: post["data"]["stickied"].as_bool().unwrap_or_default() || post["data"]["pinned"].as_bool().unwrap_or(false),
|
||||
},
|
||||
domain: val(post, "domain"),
|
||||
rel_time,
|
||||
created,
|
||||
num_duplicates: post["data"]["num_duplicates"].as_u64().unwrap_or(0),
|
||||
comments: format_num(post["data"]["num_comments"].as_i64().unwrap_or_default()),
|
||||
gallery,
|
||||
awards,
|
||||
nsfw: post["data"]["over_18"].as_bool().unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -531,8 +685,8 @@ pub fn setting(req: &Request<Body>, name: &str) -> String {
|
|||
req
|
||||
.cookie(name)
|
||||
.unwrap_or_else(|| {
|
||||
// If there is no cookie for this setting, try receiving a default from an environment variable
|
||||
if let Ok(default) = std::env::var(format!("LIBREDDIT_DEFAULT_{}", name.to_uppercase())) {
|
||||
// If there is no cookie for this setting, try receiving a default from the config
|
||||
if let Some(default) = crate::config::get_setting(&format!("LIBREDDIT_DEFAULT_{}", name.to_uppercase())) {
|
||||
Cookie::new(name, default)
|
||||
} else {
|
||||
Cookie::named(name)
|
||||
|
@ -713,11 +867,12 @@ pub fn redirect(path: String) -> Response<Body> {
|
|||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub async fn error(req: Request<Body>, msg: String) -> Result<Response<Body>, String> {
|
||||
/// Renders a generic error landing page.
|
||||
pub async fn error(req: Request<Body>, msg: impl ToString) -> Result<Response<Body>, String> {
|
||||
let url = req.uri().to_string();
|
||||
let body = ErrorTemplate {
|
||||
msg,
|
||||
prefs: Preferences::new(req),
|
||||
msg: msg.to_string(),
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
}
|
||||
.render()
|
||||
|
@ -726,10 +881,54 @@ pub async fn error(req: Request<Body>, msg: String) -> Result<Response<Body>, St
|
|||
Ok(Response::builder().status(404).header("content-type", "text/html").body(body.into()).unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Returns true if the config/env variable `LIBREDDIT_SFW_ONLY` carries the
|
||||
/// value `on`.
|
||||
///
|
||||
/// If this variable is set as such, the instance will operate in SFW-only
|
||||
/// mode; all NSFW content will be filtered. Attempts to access NSFW
|
||||
/// subreddits or posts or userpages for users Reddit has deemed NSFW will
|
||||
/// be denied.
|
||||
pub fn sfw_only() -> bool {
|
||||
match crate::config::get_setting("LIBREDDIT_SFW_ONLY") {
|
||||
Some(val) => val == "on",
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders the landing page for NSFW content when the user has not enabled
|
||||
/// "show NSFW posts" in settings.
|
||||
pub async fn nsfw_landing(req: Request<Body>) -> Result<Response<Body>, String> {
|
||||
let res_type: ResourceType;
|
||||
let url = req.uri().to_string();
|
||||
|
||||
// Determine from the request URL if the resource is a subreddit, a user
|
||||
// page, or a post.
|
||||
let res: String = if !req.param("name").unwrap_or_default().is_empty() {
|
||||
res_type = ResourceType::User;
|
||||
req.param("name").unwrap_or_default()
|
||||
} else if !req.param("id").unwrap_or_default().is_empty() {
|
||||
res_type = ResourceType::Post;
|
||||
req.param("id").unwrap_or_default()
|
||||
} else {
|
||||
res_type = ResourceType::Subreddit;
|
||||
req.param("sub").unwrap_or_default()
|
||||
};
|
||||
|
||||
let body = NSFWLandingTemplate {
|
||||
res,
|
||||
res_type,
|
||||
prefs: Preferences::new(&req),
|
||||
url,
|
||||
}
|
||||
.render()
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(Response::builder().status(403).header("content-type", "text/html").body(body.into()).unwrap_or_default())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::format_num;
|
||||
use super::rewrite_urls;
|
||||
use super::{format_num, format_url, rewrite_urls};
|
||||
|
||||
#[test]
|
||||
fn format_num_works() {
|
||||
|
@ -749,4 +948,33 @@ mod tests {
|
|||
r#"<a href="https://www.reddit.com/r/linux_gaming/comments/x/just_a_test/">https://www.reddit.com/r/linux_gaming/comments/x/just_a_test/</a>"#
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_url() {
|
||||
assert_eq!(format_url("https://a.thumbs.redditmedia.com/XYZ.jpg"), "/thumb/a/XYZ.jpg");
|
||||
assert_eq!(format_url("https://emoji.redditmedia.com/a/b"), "/emoji/a/b");
|
||||
|
||||
assert_eq!(
|
||||
format_url("https://external-preview.redd.it/foo.jpg?auto=webp&s=bar"),
|
||||
"/preview/external-pre/foo.jpg?auto=webp&s=bar"
|
||||
);
|
||||
|
||||
assert_eq!(format_url("https://i.redd.it/foobar.jpg"), "/img/foobar.jpg");
|
||||
assert_eq!(
|
||||
format_url("https://preview.redd.it/qwerty.jpg?auto=webp&s=asdf"),
|
||||
"/preview/pre/qwerty.jpg?auto=webp&s=asdf"
|
||||
);
|
||||
assert_eq!(format_url("https://v.redd.it/foo/DASH_360.mp4?source=fallback"), "/vid/foo/360.mp4");
|
||||
assert_eq!(
|
||||
format_url("https://v.redd.it/foo/HLSPlaylist.m3u8?a=bar&v=1&f=sd"),
|
||||
"/hls/foo/HLSPlaylist.m3u8?a=bar&v=1&f=sd"
|
||||
);
|
||||
assert_eq!(format_url("https://www.redditstatic.com/gold/awards/icon/icon.png"), "/static/gold/awards/icon/icon.png");
|
||||
|
||||
assert_eq!(format_url(""), "");
|
||||
assert_eq!(format_url("self"), "");
|
||||
assert_eq!(format_url("default"), "");
|
||||
assert_eq!(format_url("nsfw"), "");
|
||||
assert_eq!(format_url("spoiler"), "");
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue