Merge branch 'master' into feature/fixed-navbar

This commit is contained in:
Spike 2023-01-16 19:43:54 -08:00 committed by GitHub
commit bb5f2674d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
41 changed files with 2834 additions and 787 deletions

View file

@ -1,12 +1,55 @@
use cached::proc_macro::cached;
use futures_lite::{future::Boxed, FutureExt};
use hyper::{body::Buf, client, Body, Request, Response, Uri};
use hyper::{body, body::Buf, client, header, Body, Method, Request, Response, Uri};
use libflate::gzip;
use percent_encoding::{percent_encode, CONTROLS};
use serde_json::Value;
use std::result::Result;
use std::{io, result::Result};
use crate::dbg_msg;
use crate::server::RequestExt;
const REDDIT_URL_BASE: &str = "https://www.reddit.com";
/// Gets the canonical path for a resource on Reddit. This is accomplished by
/// making a `HEAD` request to Reddit at the path given in `path`.
///
/// This function returns `Ok(Some(path))`, where `path`'s value is identical
/// to that of the value of the argument `path`, if Reddit responds to our
/// `HEAD` request with a 2xx-family HTTP code. It will also return an
/// `Ok(Some(String))` if Reddit responds to our `HEAD` request with a
/// `Location` header in the response, and the HTTP code is in the 3xx-family;
/// the `String` will contain the path as reported in `Location`. The return
/// value is `Ok(None)` if Reddit responded with a 3xx, but did not provide a
/// `Location` header. An `Err(String)` is returned if Reddit responds with a
/// 429, or if we were unable to decode the value in the `Location` header.
#[cached(size = 1024, time = 600, result = true)]
pub async fn canonical_path(path: String) -> Result<Option<String>, String> {
let res = reddit_head(path.clone(), true).await?;
if res.status() == 429 {
return Err("Too many requests.".to_string());
};
// If Reddit responds with a 2xx, then the path is already canonical.
if res.status().to_string().starts_with('2') {
return Ok(Some(path));
}
// If Reddit responds with anything other than 3xx (except for the 2xx as
// above), return a None.
if !res.status().to_string().starts_with('3') {
return Ok(None);
}
Ok(
res
.headers()
.get(header::LOCATION)
.map(|val| percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string()),
)
}
pub async fn proxy(req: Request<Body>, format: &str) -> Result<Response<Body>, String> {
let mut url = format!("{}?{}", format, req.uri().query().unwrap_or_default());
@ -62,20 +105,39 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
.map_err(|e| e.to_string())
}
fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
/// 3xx codes Reddit returns and will automatically redirect.
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
request(&Method::GET, path, true, quarantine)
}
/// Makes a HEAD request to Reddit at `path`. This will not follow redirects.
fn reddit_head(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
request(&Method::HEAD, path, false, quarantine)
}
/// Makes a request to Reddit. If `redirect` is `true`, request_with_redirect
/// will recurse on the URL that Reddit provides in the Location HTTP header
/// in its response.
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
// Build Reddit URL from path.
let url = format!("{}{}", REDDIT_URL_BASE, path);
// Prepare the HTTPS connector.
let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build();
// Construct the hyper client from the HTTPS connector.
let client: client::Client<_, hyper::Body> = client::Client::builder().build(https);
// Build request
// Build request to Reddit. When making a GET, request gzip compression.
// (Reddit doesn't do brotli yet.)
let builder = Request::builder()
.method("GET")
.method(method)
.uri(&url)
.header("User-Agent", format!("web:libreddit:{}", env!("CARGO_PKG_VERSION")))
.header("Host", "www.reddit.com")
.header("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8")
.header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" })
.header("Accept-Language", "en-US,en;q=0.5")
.header("Connection", "keep-alive")
.header("Cookie", if quarantine { "_options=%7B%22pref_quarantine_optin%22%3A%20true%7D" } else { "" })
@ -84,26 +146,94 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
async move {
match builder {
Ok(req) => match client.request(req).await {
Ok(response) => {
Ok(mut response) => {
// Reddit may respond with a 3xx. Decide whether or not to
// redirect based on caller params.
if response.status().to_string().starts_with('3') {
request(
if !redirect {
return Ok(response);
};
return request(
method,
response
.headers()
.get("Location")
.get(header::LOCATION)
.map(|val| {
let new_url = percent_encode(val.as_bytes(), CONTROLS).to_string();
format!("{}{}raw_json=1", new_url, if new_url.contains('?') { "&" } else { "?" })
// We need to make adjustments to the URI
// we get back from Reddit. Namely, we
// must:
//
// 1. Remove the authority (e.g.
// https://www.reddit.com) that may be
// present, so that we recurse on the
// path (and query parameters) as
// required.
//
// 2. Percent-encode the path.
let new_path = percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string();
format!("{}{}raw_json=1", new_path, if new_path.contains('?') { "&" } else { "?" })
})
.unwrap_or_default()
.to_string(),
true,
quarantine,
)
.await
} else {
Ok(response)
.await;
};
match response.headers().get(header::CONTENT_ENCODING) {
// Content not compressed.
None => Ok(response),
// Content encoded (hopefully with gzip).
Some(hdr) => {
match hdr.to_str() {
Ok(val) => match val {
"gzip" => {}
"identity" => return Ok(response),
_ => return Err("Reddit response was encoded with an unsupported compressor".to_string()),
},
Err(_) => return Err("Reddit response was invalid".to_string()),
}
// We get here if the body is gzip-compressed.
// The body must be something that implements
// std::io::Read, hence the conversion to
// bytes::buf::Buf and then transformation into a
// Reader.
let mut decompressed: Vec<u8>;
{
let mut aggregated_body = match body::aggregate(response.body_mut()).await {
Ok(b) => b.reader(),
Err(e) => return Err(e.to_string()),
};
let mut decoder = match gzip::Decoder::new(&mut aggregated_body) {
Ok(decoder) => decoder,
Err(e) => return Err(e.to_string()),
};
decompressed = Vec::<u8>::new();
if let Err(e) = io::copy(&mut decoder, &mut decompressed) {
return Err(e.to_string());
};
}
response.headers_mut().remove(header::CONTENT_ENCODING);
response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into());
*(response.body_mut()) = Body::from(decompressed);
Ok(response)
}
}
}
Err(e) => Err(e.to_string()),
Err(e) => {
dbg_msg!("{} {}: {}", method, path, e);
Err(e.to_string())
}
},
Err(_) => Err("Post url contains non-ASCII characters".to_string()),
}
@ -114,9 +244,6 @@ fn request(url: String, quarantine: bool) -> Boxed<Result<Response<Body>, String
// Make a request to a Reddit API and parse the JSON response
#[cached(size = 100, time = 30, result = true)]
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
// Build Reddit url from path
let url = format!("https://www.reddit.com{}", path);
// Closure to quickly build errors
let err = |msg: &str, e: String| -> Result<Value, String> {
// eprintln!("{} - {}: {}", url, msg, e);
@ -124,7 +251,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
};
// Fetch the url...
match request(url.clone(), quarantine).await {
match reddit_get(path.clone(), quarantine).await {
Ok(response) => {
let status = response.status();
@ -142,7 +269,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
.as_str()
.unwrap_or_else(|| {
json["message"].as_str().unwrap_or_else(|| {
eprintln!("{} - Error parsing reddit error", url);
eprintln!("{}{} - Error parsing reddit error", REDDIT_URL_BASE, path);
"Error parsing reddit error"
})
})

130
src/config.rs Normal file
View file

@ -0,0 +1,130 @@
use once_cell::sync::Lazy;
use std::{env::var, fs::read_to_string};
// Waiting for https://github.com/rust-lang/rust/issues/74465 to land, so we
// can reduce reliance on once_cell.
//
// This is the local static that is initialized at runtime (technically at
// first request) and contains the instance settings.
static CONFIG: Lazy<Config> = Lazy::new(Config::load);
/// Stores the configuration parsed from the environment variables and the
/// config file. `Config::Default()` contains None for each setting.
#[derive(Default, serde::Deserialize)]
pub struct Config {
#[serde(rename = "LIBREDDIT_SFW_ONLY")]
sfw_only: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_THEME")]
default_theme: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_FRONT_PAGE")]
default_front_page: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_LAYOUT")]
default_layout: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_WIDE")]
default_wide: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_COMMENT_SORT")]
default_comment_sort: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_POST_SORT")]
default_post_sort: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_SHOW_NSFW")]
default_show_nsfw: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_BLUR_NSFW")]
default_blur_nsfw: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_USE_HLS")]
default_use_hls: Option<String>,
#[serde(rename = "LIBREDDIT_DEFAULT_HIDE_HLS_NOTIFICATION")]
default_hide_hls_notification: Option<String>,
}
impl Config {
/// Load the configuration from the environment variables and the config file.
/// In the case that there are no environment variables set and there is no
/// config file, this function returns a Config that contains all None values.
pub fn load() -> Self {
// Read from libreddit.toml config file. If for any reason, it fails, the
// default `Config` is used (all None values)
let config: Config = toml::from_str(&read_to_string("libreddit.toml").unwrap_or_default()).unwrap_or_default();
// This function defines the order of preference - first check for
// environment variables with "LIBREDDIT", then check the config, then if
// both are `None`, return a `None` via the `map_or_else` function
let parse = |key: &str| -> Option<String> { var(key).ok().map_or_else(|| get_setting_from_config(key, &config), Some) };
Self {
sfw_only: parse("LIBREDDIT_SFW_ONLY"),
default_theme: parse("LIBREDDIT_DEFAULT_THEME"),
default_front_page: parse("LIBREDDIT_DEFAULT_FRONT_PAGE"),
default_layout: parse("LIBREDDIT_DEFAULT_LAYOUT"),
default_post_sort: parse("LIBREDDIT_DEFAULT_POST_SORT"),
default_wide: parse("LIBREDDIT_DEFAULT_WIDE"),
default_comment_sort: parse("LIBREDDIT_DEFAULT_COMMENT_SORT"),
default_show_nsfw: parse("LIBREDDIT_DEFAULT_SHOW_NSFW"),
default_blur_nsfw: parse("LIBREDDIT_DEFAULT_BLUR_NSFW"),
default_use_hls: parse("LIBREDDIT_DEFAULT_USE_HLS"),
default_hide_hls_notification: parse("LIBREDDIT_DEFAULT_HIDE_HLS"),
}
}
}
fn get_setting_from_config(name: &str, config: &Config) -> Option<String> {
match name {
"LIBREDDIT_SFW_ONLY" => config.sfw_only.clone(),
"LIBREDDIT_DEFAULT_THEME" => config.default_theme.clone(),
"LIBREDDIT_DEFAULT_FRONT_PAGE" => config.default_front_page.clone(),
"LIBREDDIT_DEFAULT_LAYOUT" => config.default_layout.clone(),
"LIBREDDIT_DEFAULT_COMMENT_SORT" => config.default_comment_sort.clone(),
"LIBREDDIT_DEFAULT_POST_SORT" => config.default_post_sort.clone(),
"LIBREDDIT_DEFAULT_SHOW_NSFW" => config.default_show_nsfw.clone(),
"LIBREDDIT_DEFAULT_BLUR_NSFW" => config.default_blur_nsfw.clone(),
"LIBREDDIT_DEFAULT_USE_HLS" => config.default_use_hls.clone(),
"LIBREDDIT_DEFAULT_HIDE_HLS_NOTIFICATION" => config.default_hide_hls_notification.clone(),
"LIBREDDIT_DEFAULT_WIDE" => config.default_wide.clone(),
_ => None,
}
}
/// Retrieves setting from environment variable or config file.
pub(crate) fn get_setting(name: &str) -> Option<String> {
get_setting_from_config(name, &CONFIG)
}
#[cfg(test)]
use {sealed_test::prelude::*, std::fs::write};
#[test]
#[sealed_test(env = [("LIBREDDIT_SFW_ONLY", "on")])]
fn test_env_var() {
assert!(crate::utils::sfw_only())
}
#[test]
#[sealed_test]
fn test_config() {
let config_to_write = r#"LIBREDDIT_DEFAULT_COMMENT_SORT = "best""#;
write("libreddit.toml", config_to_write).unwrap();
assert_eq!(get_setting("LIBREDDIT_DEFAULT_COMMENT_SORT"), Some("best".into()));
}
#[test]
#[sealed_test(env = [("LIBREDDIT_DEFAULT_COMMENT_SORT", "top")])]
fn test_env_config_precedence() {
let config_to_write = r#"LIBREDDIT_DEFAULT_COMMENT_SORT = "best""#;
write("libreddit.toml", config_to_write).unwrap();
assert_eq!(get_setting("LIBREDDIT_DEFAULT_COMMENT_SORT"), Some("top".into()))
}
#[test]
#[sealed_test(env = [("LIBREDDIT_DEFAULT_COMMENT_SORT", "top")])]
fn test_alt_env_config_precedence() {
let config_to_write = r#"LIBREDDIT_DEFAULT_COMMENT_SORT = "best""#;
write("libreddit.toml", config_to_write).unwrap();
assert_eq!(get_setting("LIBREDDIT_DEFAULT_COMMENT_SORT"), Some("top".into()))
}

236
src/duplicates.rs Normal file
View file

@ -0,0 +1,236 @@
// Handler for post duplicates.
use crate::client::json;
use crate::server::RequestExt;
use crate::subreddit::{can_access_quarantine, quarantine};
use crate::utils::{error, filter_posts, get_filters, nsfw_landing, parse_post, setting, template, Post, Preferences};
use askama::Template;
use hyper::{Body, Request, Response};
use serde_json::Value;
use std::borrow::ToOwned;
use std::collections::HashSet;
use std::vec::Vec;
/// DuplicatesParams contains the parameters in the URL.
struct DuplicatesParams {
before: String,
after: String,
sort: String,
}
/// DuplicatesTemplate defines an Askama template for rendering duplicate
/// posts.
#[derive(Template)]
#[template(path = "duplicates.html")]
struct DuplicatesTemplate {
/// params contains the relevant request parameters.
params: DuplicatesParams,
/// post is the post whose ID is specified in the reqeust URL. Note that
/// this is not necessarily the "original" post.
post: Post,
/// duplicates is the list of posts that, per Reddit, are duplicates of
/// Post above.
duplicates: Vec<Post>,
/// prefs are the user preferences.
prefs: Preferences,
/// url is the request URL.
url: String,
/// num_posts_filtered counts how many posts were filtered from the
/// duplicates list.
num_posts_filtered: u64,
/// all_posts_filtered is true if every duplicate was filtered. This is an
/// edge case but can still happen.
all_posts_filtered: bool,
}
/// Make the GET request to Reddit. It assumes `req` is the appropriate Reddit
/// REST endpoint for enumerating post duplicates.
pub async fn item(req: Request<Body>) -> Result<Response<Body>, String> {
let path: String = format!("{}.json?{}&raw_json=1", req.uri().path(), req.uri().query().unwrap_or_default());
let sub = req.param("sub").unwrap_or_default();
let quarantined = can_access_quarantine(&req, &sub);
// Log the request in debugging mode
#[cfg(debug_assertions)]
dbg!(req.param("id").unwrap_or_default());
// Send the GET, and await JSON.
match json(path, quarantined).await {
// Process response JSON.
Ok(response) => {
let post = parse_post(&response[0]["data"]["children"][0]).await;
// Return landing page if this post if this Reddit deems this post
// NSFW, but we have also disabled the display of NSFW content
// or if the instance is SFW-only.
if post.nsfw && (setting(&req, "show_nsfw") != "on" || crate::utils::sfw_only()) {
return Ok(nsfw_landing(req).await.unwrap_or_default());
}
let filters = get_filters(&req);
let (duplicates, num_posts_filtered, all_posts_filtered) = parse_duplicates(&response[1], &filters).await;
// These are the values for the "before=", "after=", and "sort="
// query params, respectively.
let mut before: String = String::new();
let mut after: String = String::new();
let mut sort: String = String::new();
// FIXME: We have to perform a kludge to work around a Reddit API
// bug.
//
// The JSON object in "data" will never contain a "before" value so
// it is impossible to use it to determine our position in a
// listing. We'll make do by getting the ID of the first post in
// the listing, setting that as our "before" value, and ask Reddit
// to give us a batch of duplicate posts up to that post.
//
// Likewise, if we provide a "before" request in the GET, the
// result won't have an "after" in the JSON, in addition to missing
// the "before." So we will have to use the final post in the list
// of duplicates.
//
// That being said, we'll also need to capture the value of the
// "sort=" parameter as well, so we will need to inspect the
// query key-value pairs anyway.
let l = duplicates.len();
if l > 0 {
// This gets set to true if "before=" is one of the GET params.
let mut have_before: bool = false;
// This gets set to true if "after=" is one of the GET params.
let mut have_after: bool = false;
// Inspect the query key-value pairs. We will need to record
// the value of "sort=", along with checking to see if either
// one of "before=" or "after=" are given.
//
// If we're in the middle of the batch (evidenced by the
// presence of a "before=" or "after=" parameter in the GET),
// then use the first post as the "before" reference.
//
// We'll do this iteratively. Better than with .map_or()
// since a closure will continue to operate on remaining
// elements even after we've determined one of "before=" or
// "after=" (or both) are in the GET request.
//
// In practice, here should only ever be one of "before=" or
// "after=" and never both.
let query_str = req.uri().query().unwrap_or_default().to_string();
if !query_str.is_empty() {
for param in query_str.split('&') {
let kv: Vec<&str> = param.split('=').collect();
if kv.len() < 2 {
// Reject invalid query parameter.
continue;
}
let key: &str = kv[0];
match key {
"before" => have_before = true,
"after" => have_after = true,
"sort" => {
let val: &str = kv[1];
match val {
"new" | "num_comments" => sort = val.to_string(),
_ => {}
}
}
_ => {}
}
}
}
if have_after {
before = "t3_".to_owned();
before.push_str(&duplicates[0].id);
}
// Address potentially missing "after". If "before=" is in the
// GET, then "after" will be null in the JSON (see FIXME
// above).
if have_before {
// The next batch will need to start from one after the
// last post in the current batch.
after = "t3_".to_owned();
after.push_str(&duplicates[l - 1].id);
// Here is where things get terrible. Notice that we
// haven't set `before`. In order to do so, we will
// need to know if there is a batch that exists before
// this one, and doing so requires actually fetching the
// previous batch. In other words, we have to do yet one
// more GET to Reddit. There is no other way to determine
// whether or not to define `before`.
//
// We'll mitigate that by requesting at most one duplicate.
let new_path: String = format!(
"{}.json?before=t3_{}&sort={}&limit=1&raw_json=1",
req.uri().path(),
&duplicates[0].id,
if sort.is_empty() { "num_comments".to_string() } else { sort.clone() }
);
match json(new_path, true).await {
Ok(response) => {
if !response[1]["data"]["children"].as_array().unwrap_or(&Vec::new()).is_empty() {
before = "t3_".to_owned();
before.push_str(&duplicates[0].id);
}
}
Err(msg) => {
// Abort entirely if we couldn't get the previous
// batch.
return error(req, msg).await;
}
}
} else {
after = response[1]["data"]["after"].as_str().unwrap_or_default().to_string();
}
}
let url = req.uri().to_string();
template(DuplicatesTemplate {
params: DuplicatesParams { before, after, sort },
post,
duplicates,
prefs: Preferences::new(&req),
url,
num_posts_filtered,
all_posts_filtered,
})
}
// Process error.
Err(msg) => {
if msg == "quarantined" {
let sub = req.param("sub").unwrap_or_default();
quarantine(req, sub)
} else {
error(req, msg).await
}
}
}
}
// DUPLICATES
async fn parse_duplicates(json: &serde_json::Value, filters: &HashSet<String>) -> (Vec<Post>, u64, bool) {
let post_duplicates: &Vec<Value> = &json["data"]["children"].as_array().map_or(Vec::new(), ToOwned::to_owned);
let mut duplicates: Vec<Post> = Vec::new();
// Process each post and place them in the Vec<Post>.
for val in post_duplicates.iter() {
let post: Post = parse_post(val).await;
duplicates.push(post);
}
let (num_posts_filtered, all_posts_filtered) = filter_posts(&mut duplicates, filters);
(duplicates, num_posts_filtered, all_posts_filtered)
}

View file

@ -3,6 +3,8 @@
#![allow(clippy::cmp_owned)]
// Reference local files
mod config;
mod duplicates;
mod post;
mod search;
mod settings;
@ -11,13 +13,13 @@ mod user;
mod utils;
// Import Crates
use clap::{Arg, Command};
use clap::{Arg, ArgAction, Command};
use futures_lite::FutureExt;
use hyper::{header::HeaderValue, Body, Request, Response};
mod client;
use client::proxy;
use client::{canonical_path, proxy};
use server::RequestExt;
use utils::{error, redirect, ThemeAssets};
@ -112,7 +114,7 @@ async fn main() {
.short('r')
.long("redirect-https")
.help("Redirect all HTTP requests to HTTPS (no longer functional)")
.takes_value(false),
.num_args(0),
)
.arg(
Arg::new("address")
@ -121,16 +123,18 @@ async fn main() {
.value_name("ADDRESS")
.help("Sets address to listen on")
.default_value("0.0.0.0")
.takes_value(true),
.num_args(1),
)
.arg(
Arg::new("port")
.short('p')
.long("port")
.value_name("PORT")
.env("PORT")
.help("Port to listen on")
.default_value("8080")
.takes_value(true),
.action(ArgAction::Set)
.num_args(1),
)
.arg(
Arg::new("hsts")
@ -139,15 +143,15 @@ async fn main() {
.value_name("EXPIRE_TIME")
.help("HSTS header to tell browsers that this site should only be accessed over HTTPS")
.default_value("604800")
.takes_value(true),
.num_args(1),
)
.get_matches();
let address = matches.value_of("address").unwrap_or("0.0.0.0");
let port = std::env::var("PORT").unwrap_or_else(|_| matches.value_of("port").unwrap_or("8080").to_string());
let hsts = matches.value_of("hsts");
let address = matches.get_one::<String>("address").unwrap();
let port = matches.get_one::<String>("port").unwrap();
let hsts = matches.get_one("hsts").map(|m: &String| m.as_str());
let listener = [address, ":", &port].concat();
let listener = [address, ":", port].concat();
println!("Starting Libreddit...");
@ -238,6 +242,16 @@ async fn main() {
app.at("/r/:sub/comments/:id").get(|r| post::item(r).boxed());
app.at("/r/:sub/comments/:id/:title").get(|r| post::item(r).boxed());
app.at("/r/:sub/comments/:id/:title/:comment_id").get(|r| post::item(r).boxed());
app.at("/comments/:id").get(|r| post::item(r).boxed());
app.at("/comments/:id/comments").get(|r| post::item(r).boxed());
app.at("/comments/:id/comments/:comment_id").get(|r| post::item(r).boxed());
app.at("/comments/:id/:title").get(|r| post::item(r).boxed());
app.at("/comments/:id/:title/:comment_id").get(|r| post::item(r).boxed());
app.at("/r/:sub/duplicates/:id").get(|r| duplicates::item(r).boxed());
app.at("/r/:sub/duplicates/:id/:title").get(|r| duplicates::item(r).boxed());
app.at("/duplicates/:id").get(|r| duplicates::item(r).boxed());
app.at("/duplicates/:id/:title").get(|r| duplicates::item(r).boxed());
app.at("/r/:sub/search").get(|r| search::find(r).boxed());
@ -254,9 +268,6 @@ async fn main() {
app.at("/r/:sub/:sort").get(|r| subreddit::community(r).boxed());
// Comments handler
app.at("/comments/:id").get(|r| post::item(r).boxed());
// Front page
app.at("/").get(|r| subreddit::community(r).boxed());
@ -274,13 +285,25 @@ async fn main() {
// Handle about pages
app.at("/about").get(|req| error(req, "About pages aren't added yet".to_string()).boxed());
app.at("/:id").get(|req: Request<Body>| match req.param("id").as_deref() {
// Sort front page
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).boxed(),
// Short link for post
Some(id) if id.len() > 4 && id.len() < 7 => post::item(req).boxed(),
// Error message for unknown pages
_ => error(req, "Nothing here".to_string()).boxed(),
app.at("/:id").get(|req: Request<Body>| {
Box::pin(async move {
match req.param("id").as_deref() {
// Sort front page
Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await,
// Short link for post
Some(id) if (5..8).contains(&id.len()) => match canonical_path(format!("/{}", id)).await {
Ok(path_opt) => match path_opt {
Some(path) => Ok(redirect(path)),
None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await,
},
Err(e) => error(req, e).await,
},
// Error message for unknown pages
_ => error(req, "Nothing here".to_string()).await,
}
})
});
// Default service in case no routes match

View file

@ -3,7 +3,7 @@ use crate::client::json;
use crate::server::RequestExt;
use crate::subreddit::{can_access_quarantine, quarantine};
use crate::utils::{
error, format_num, format_url, get_filters, param, rewrite_urls, setting, template, time, val, Author, Awards, Comment, Flags, Flair, FlairPart, Media, Post, Preferences,
error, format_num, get_filters, nsfw_landing, param, parse_post, rewrite_urls, setting, template, time, val, Author, Awards, Comment, Flair, FlairPart, Post, Preferences,
};
use hyper::{Body, Request, Response};
@ -54,8 +54,16 @@ pub async fn item(req: Request<Body>) -> Result<Response<Body>, String> {
// Otherwise, grab the JSON output from the request
Ok(response) => {
// Parse the JSON into Post and Comment structs
let post = parse_post(&response[0]).await;
let comments = parse_comments(&response[1], &post.permalink, &post.author.name, highlighted_comment, &get_filters(&req));
let post = parse_post(&response[0]["data"]["children"][0]).await;
// Return landing page if this post if this Reddit deems this post
// NSFW, but we have also disabled the display of NSFW content
// or if the instance is SFW-only.
if post.nsfw && (setting(&req, "show_nsfw") != "on" || crate::utils::sfw_only()) {
return Ok(nsfw_landing(req).await.unwrap_or_default());
}
let comments = parse_comments(&response[1], &post.permalink, &post.author.name, highlighted_comment, &get_filters(&req), &req);
let url = req.uri().to_string();
// Use the Post and Comment structs to generate a website to show users
@ -63,7 +71,7 @@ pub async fn item(req: Request<Body>) -> Result<Response<Body>, String> {
comments,
post,
sort,
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
single_thread,
url,
})
@ -80,94 +88,8 @@ pub async fn item(req: Request<Body>) -> Result<Response<Body>, String> {
}
}
// POSTS
async fn parse_post(json: &serde_json::Value) -> Post {
// Retrieve post (as opposed to comments) from JSON
let post: &serde_json::Value = &json["data"]["children"][0];
// Grab UTC time as unix timestamp
let (rel_time, created) = time(post["data"]["created_utc"].as_f64().unwrap_or_default());
// Parse post score and upvote ratio
let score = post["data"]["score"].as_i64().unwrap_or_default();
let ratio: f64 = post["data"]["upvote_ratio"].as_f64().unwrap_or(1.0) * 100.0;
// Determine the type of media along with the media URL
let (post_type, media, gallery) = Media::parse(&post["data"]).await;
let awards: Awards = Awards::parse(&post["data"]["all_awardings"]);
let permalink = val(post, "permalink");
let body = if val(post, "removed_by_category") == "moderator" {
format!(
"<div class=\"md\"><p>[removed] — <a href=\"https://www.reveddit.com{}\">view removed post</a></p></div>",
permalink
)
} else {
rewrite_urls(&val(post, "selftext_html"))
};
// Build a post using data parsed from Reddit post API
Post {
id: val(post, "id"),
title: val(post, "title"),
community: val(post, "subreddit"),
body,
author: Author {
name: val(post, "author"),
flair: Flair {
flair_parts: FlairPart::parse(
post["data"]["author_flair_type"].as_str().unwrap_or_default(),
post["data"]["author_flair_richtext"].as_array(),
post["data"]["author_flair_text"].as_str(),
),
text: val(post, "link_flair_text"),
background_color: val(post, "author_flair_background_color"),
foreground_color: val(post, "author_flair_text_color"),
},
distinguished: val(post, "distinguished"),
},
permalink,
score: format_num(score),
upvote_ratio: ratio as i64,
post_type,
media,
thumbnail: Media {
url: format_url(val(post, "thumbnail").as_str()),
alt_url: String::new(),
width: post["data"]["thumbnail_width"].as_i64().unwrap_or_default(),
height: post["data"]["thumbnail_height"].as_i64().unwrap_or_default(),
poster: "".to_string(),
},
flair: Flair {
flair_parts: FlairPart::parse(
post["data"]["link_flair_type"].as_str().unwrap_or_default(),
post["data"]["link_flair_richtext"].as_array(),
post["data"]["link_flair_text"].as_str(),
),
text: val(post, "link_flair_text"),
background_color: val(post, "link_flair_background_color"),
foreground_color: if val(post, "link_flair_text_color") == "dark" {
"black".to_string()
} else {
"white".to_string()
},
},
flags: Flags {
nsfw: post["data"]["over_18"].as_bool().unwrap_or(false),
stickied: post["data"]["stickied"].as_bool().unwrap_or(false),
},
domain: val(post, "domain"),
rel_time,
created,
comments: format_num(post["data"]["num_comments"].as_i64().unwrap_or_default()),
gallery,
awards,
}
}
// COMMENTS
fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str, highlighted_comment: &str, filters: &HashSet<String>) -> Vec<Comment> {
fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str, highlighted_comment: &str, filters: &HashSet<String>, req: &Request<Body>) -> Vec<Comment> {
// Parse the comment JSON into a Vector of Comments
let comments = json["data"]["children"].as_array().map_or(Vec::new(), std::borrow::ToOwned::to_owned);
@ -187,7 +109,7 @@ fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str,
// If this comment contains replies, handle those too
let replies: Vec<Comment> = if data["replies"].is_object() {
parse_comments(&data["replies"], post_link, post_author, highlighted_comment, filters)
parse_comments(&data["replies"], post_link, post_author, highlighted_comment, filters, req)
} else {
Vec::new()
};
@ -200,9 +122,9 @@ fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str,
let id = val(&comment, "id");
let highlighted = id == highlighted_comment;
let body = if val(&comment, "author") == "[deleted]" && val(&comment, "body") == "[removed]" {
let body = if (val(&comment, "author") == "[deleted]" && val(&comment, "body") == "[removed]") || val(&comment, "body") == "[ Removed by Reddit ]" {
format!(
"<div class=\"md\"><p>[removed] — <a href=\"https://www.reveddit.com{}{}\">view removed comment</a></p></div>",
"<div class=\"md\"><p>[removed] — <a href=\"https://www.unddit.com{}{}\">view removed comment</a></p></div>",
post_link, id
)
} else {
@ -255,6 +177,7 @@ fn parse_comments(json: &serde_json::Value, post_link: &str, post_author: &str,
awards,
collapsed,
is_filtered,
prefs: Preferences::new(req),
}
})
.collect()

View file

@ -1,5 +1,5 @@
// CRATES
use crate::utils::{catch_random, error, filter_posts, format_num, format_url, get_filters, param, redirect, setting, template, val, Post, Preferences};
use crate::utils::{self, catch_random, error, filter_posts, format_num, format_url, get_filters, param, redirect, setting, template, val, Post, Preferences};
use crate::{
client::json,
subreddit::{can_access_quarantine, quarantine},
@ -7,6 +7,8 @@ use crate::{
};
use askama::Template;
use hyper::{Body, Request, Response};
use once_cell::sync::Lazy;
use regex::Regex;
// STRUCTS
struct SearchParams {
@ -42,13 +44,25 @@ struct SearchTemplate {
/// Whether all fetched posts are filtered (to differentiate between no posts fetched in the first place,
/// and all fetched posts being filtered).
all_posts_filtered: bool,
/// Whether all posts were hidden because they are NSFW (and user has disabled show NSFW)
all_posts_hidden_nsfw: bool,
no_posts: bool,
}
// Regex matched against search queries to determine if they are reddit urls.
static REDDIT_URL_MATCH: Lazy<Regex> = Lazy::new(|| Regex::new(r"^https?://([^\./]+\.)*reddit.com/").unwrap());
// SERVICES
pub async fn find(req: Request<Body>) -> Result<Response<Body>, String> {
let nsfw_results = if setting(&req, "show_nsfw") == "on" { "&include_over_18=on" } else { "" };
// This ensures that during a search, no NSFW posts are fetched at all
let nsfw_results = if setting(&req, "show_nsfw") == "on" && !utils::sfw_only() {
"&include_over_18=on"
} else {
""
};
let path = format!("{}.json?{}{}&raw_json=1", req.uri().path(), req.uri().query().unwrap_or_default(), nsfw_results);
let query = param(&path, "q").unwrap_or_default();
let mut query = param(&path, "q").unwrap_or_default();
query = REDDIT_URL_MATCH.replace(&query, "").to_string();
if query.is_empty() {
return Ok(redirect("/".to_string()));
@ -96,16 +110,19 @@ pub async fn find(req: Request<Body>) -> Result<Response<Body>, String> {
restrict_sr: param(&path, "restrict_sr").unwrap_or_default(),
typed,
},
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
url,
is_filtered: true,
all_posts_filtered: false,
all_posts_hidden_nsfw: false,
no_posts: false,
})
} else {
match Post::fetch(&path, quarantined).await {
Ok((mut posts, after)) => {
let all_posts_filtered = filter_posts(&mut posts, &filters);
let (_, all_posts_filtered) = filter_posts(&mut posts, &filters);
let no_posts = posts.is_empty();
let all_posts_hidden_nsfw = !no_posts && (posts.iter().all(|p| p.flags.nsfw) && setting(&req, "show_nsfw") != "on");
template(SearchTemplate {
posts,
subreddits,
@ -119,10 +136,12 @@ pub async fn find(req: Request<Body>) -> Result<Response<Body>, String> {
restrict_sr: param(&path, "restrict_sr").unwrap_or_default(),
typed,
},
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
url,
is_filtered: false,
all_posts_filtered,
all_posts_hidden_nsfw,
no_posts,
})
}
Err(msg) => {

View file

@ -1,17 +1,80 @@
use brotli::enc::{BrotliCompress, BrotliEncoderParams};
use cached::proc_macro::cached;
use cookie::Cookie;
use core::f64;
use futures_lite::{future::Boxed, Future, FutureExt};
use hyper::{
header::HeaderValue,
body,
body::HttpBody,
header,
service::{make_service_fn, service_fn},
HeaderMap,
};
use hyper::{Body, Method, Request, Response, Server as HyperServer};
use libflate::gzip;
use route_recognizer::{Params, Router};
use std::{pin::Pin, result::Result};
use std::{
cmp::Ordering,
io,
pin::Pin,
result::Result,
str::{from_utf8, Split},
string::ToString,
};
use time::Duration;
use crate::dbg_msg;
type BoxResponse = Pin<Box<dyn Future<Output = Result<Response<Body>, String>> + Send>>;
/// Compressors for the response Body, in ascending order of preference.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
enum CompressionType {
Passthrough,
Gzip,
Brotli,
}
/// All browsers support gzip, so if we are given `Accept-Encoding: *`, deliver
/// gzipped-content.
///
/// Brotli would be nice universally, but Safari (iOS, iPhone, macOS) reportedly
/// doesn't support it yet.
const DEFAULT_COMPRESSOR: CompressionType = CompressionType::Gzip;
impl CompressionType {
/// Returns a `CompressionType` given a content coding
/// in [RFC 7231](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.4)
/// format.
fn parse(s: &str) -> Option<CompressionType> {
let c = match s {
// Compressors we support.
"gzip" => CompressionType::Gzip,
"br" => CompressionType::Brotli,
// The wildcard means that we can choose whatever
// compression we prefer. In this case, use the
// default.
"*" => DEFAULT_COMPRESSOR,
// Compressor not supported.
_ => return None,
};
Some(c)
}
}
impl ToString for CompressionType {
fn to_string(&self) -> String {
match self {
CompressionType::Gzip => "gzip".to_string(),
CompressionType::Brotli => "br".to_string(),
_ => String::new(),
}
}
}
pub struct Route<'a> {
router: &'a mut Router<fn(Request<Body>) -> BoxResponse>,
path: String,
@ -97,7 +160,7 @@ impl ResponseExt for Response<Body> {
}
fn insert_cookie(&mut self, cookie: Cookie) {
if let Ok(val) = HeaderValue::from_str(&cookie.to_string()) {
if let Ok(val) = header::HeaderValue::from_str(&cookie.to_string()) {
self.headers_mut().append("Set-Cookie", val);
}
}
@ -106,7 +169,7 @@ impl ResponseExt for Response<Body> {
let mut cookie = Cookie::named(name);
cookie.set_path("/");
cookie.set_max_age(Duration::seconds(1));
if let Ok(val) = HeaderValue::from_str(&cookie.to_string()) {
if let Ok(val) = header::HeaderValue::from_str(&cookie.to_string()) {
self.headers_mut().append("Set-Cookie", val);
}
}
@ -156,10 +219,11 @@ impl Server {
// let shared_router = router.clone();
async move {
Ok::<_, String>(service_fn(move |req: Request<Body>| {
let headers = default_headers.clone();
let req_headers = req.headers().clone();
let def_headers = default_headers.clone();
// Remove double slashes and decode encoded slashes
let mut path = req.uri().path().replace("//", "/").replace("%2F","/");
let mut path = req.uri().path().replace("//", "/").replace("%2F", "/");
// Remove trailing slashes
if path != "/" && path.ends_with('/') {
@ -176,26 +240,20 @@ impl Server {
// Run the route's function
let func = (found.handler().to_owned().to_owned())(parammed);
async move {
let res: Result<Response<Body>, String> = func.await;
// Add default headers to response
res.map(|mut response| {
response.headers_mut().extend(headers);
response
})
match func.await {
Ok(mut res) => {
res.headers_mut().extend(def_headers);
let _ = compress_response(&req_headers, &mut res).await;
Ok(res)
}
Err(msg) => new_boilerplate(def_headers, req_headers, 500, Body::from(msg)).await,
}
}
.boxed()
}
// If there was a routing error
Err(e) => async move {
// Return a 404 error
let res: Result<Response<Body>, String> = Ok(Response::builder().status(404).body(e.into()).unwrap_or_default());
// Add default headers to response
res.map(|mut response| {
response.headers_mut().extend(headers);
response
})
}
.boxed(),
Err(e) => async move { new_boilerplate(def_headers, req_headers, 404, e.into()).await }.boxed(),
}
}))
}
@ -213,3 +271,472 @@ impl Server {
server.boxed()
}
}
/// Create a boilerplate Response for error conditions. This response will be
/// compressed if requested by client.
async fn new_boilerplate(
default_headers: HeaderMap<header::HeaderValue>,
req_headers: HeaderMap<header::HeaderValue>,
status: u16,
body: Body,
) -> Result<Response<Body>, String> {
match Response::builder().status(status).body(body) {
Ok(mut res) => {
let _ = compress_response(&req_headers, &mut res).await;
res.headers_mut().extend(default_headers.clone());
Ok(res)
}
Err(msg) => Err(msg.to_string()),
}
}
/// Determines the desired compressor based on the Accept-Encoding header.
///
/// This function will honor the [q-value](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values)
/// for each compressor. The q-value is an optional parameter, a decimal value
/// on \[0..1\], to order the compressors by preference. An Accept-Encoding value
/// with no q-values is also accepted.
///
/// Here are [examples](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding#examples)
/// of valid Accept-Encoding headers.
///
/// ```http
/// Accept-Encoding: gzip
/// Accept-Encoding: gzip, compress, br
/// Accept-Encoding: br;q=1.0, gzip;q=0.8, *;q=0.1
/// ```
#[cached]
fn determine_compressor(accept_encoding: String) -> Option<CompressionType> {
if accept_encoding.is_empty() {
return None;
};
// Keep track of the compressor candidate based on both the client's
// preference and our own. Concrete examples:
//
// 1. "Accept-Encoding: gzip, br" => assuming we like brotli more than
// gzip, and the browser supports brotli, we choose brotli
//
// 2. "Accept-Encoding: gzip;q=0.8, br;q=0.3" => the client has stated a
// preference for gzip over brotli, so we choose gzip
//
// To do this, we need to define a struct which contains the requested
// requested compressor (abstracted as a CompressionType enum) and the
// q-value. If no q-value is defined for the compressor, we assume one of
// 1.0. We first compare compressor candidates by comparing q-values, and
// then CompressionTypes. We keep track of whatever is the greatest per our
// ordering.
struct CompressorCandidate {
alg: CompressionType,
q: f64,
}
impl Ord for CompressorCandidate {
fn cmp(&self, other: &Self) -> Ordering {
// Compare q-values. Break ties with the
// CompressionType values.
match self.q.total_cmp(&other.q) {
Ordering::Equal => self.alg.cmp(&other.alg),
ord => ord,
}
}
}
impl PartialOrd for CompressorCandidate {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
// Guard against NAN, both on our end and on the other.
if self.q.is_nan() || other.q.is_nan() {
return None;
};
// f64 and CompressionType are ordered, except in the case
// where the f64 is NAN (which we checked against), so we
// can safely return a Some here.
Some(self.cmp(other))
}
}
impl PartialEq for CompressorCandidate {
fn eq(&self, other: &Self) -> bool {
(self.q == other.q) && (self.alg == other.alg)
}
}
impl Eq for CompressorCandidate {}
// This is the current candidate.
//
// Assmume no candidate so far. We do this by assigning the sentinel value
// of negative infinity to the q-value. If this value is negative infinity,
// that means there was no viable compressor candidate.
let mut cur_candidate = CompressorCandidate {
alg: CompressionType::Passthrough,
q: f64::NEG_INFINITY,
};
// This loop reads the requested compressors and keeps track of whichever
// one has the highest priority per our heuristic.
for val in accept_encoding.to_string().split(',') {
let mut q: f64 = 1.0;
// The compressor and q-value (if the latter is defined)
// will be delimited by semicolons.
let mut spl: Split<char> = val.split(';');
// Get the compressor. For example, in
// gzip;q=0.8
// this grabs "gzip" in the string. It
// will further validate the compressor against the
// list of those we support. If it is not supported,
// we move onto the next one.
let compressor: CompressionType = match spl.next() {
// CompressionType::parse will return the appropriate enum given
// a string. For example, it will return CompressionType::Gzip
// when given "gzip".
Some(s) => match CompressionType::parse(s.trim()) {
Some(candidate) => candidate,
// We don't support the requested compression algorithm.
None => continue,
},
// We should never get here, but I'm paranoid.
None => continue,
};
// Get the q-value. This might not be defined, in which case assume
// 1.0.
if let Some(s) = spl.next() {
if !(s.len() > 2 && s.starts_with("q=")) {
// If the q-value is malformed, the header is malformed, so
// abort.
return None;
}
match s[2..].parse::<f64>() {
Ok(val) => {
if (0.0..=1.0).contains(&val) {
q = val;
} else {
// If the value is outside [0..1], header is malformed.
// Abort.
return None;
};
}
Err(_) => {
// If this isn't a f64, then assume a malformed header
// value and abort.
return None;
}
}
};
// If new_candidate > cur_candidate, make new_candidate the new
// cur_candidate. But do this safely! It is very possible that
// someone gave us the string "NAN", which (&str).parse::<f64>
// will happily translate to f64::NAN.
let new_candidate = CompressorCandidate { alg: compressor, q };
if let Some(ord) = new_candidate.partial_cmp(&cur_candidate) {
if ord == Ordering::Greater {
cur_candidate = new_candidate;
}
};
}
if cur_candidate.q != f64::NEG_INFINITY {
Some(cur_candidate.alg)
} else {
None
}
}
/// Compress the response body, if possible or desirable. The Body will be
/// compressed in place, and a new header Content-Encoding will be set
/// indicating the compression algorithm.
///
/// This function deems Body eligible compression if and only if the following
/// conditions are met:
///
/// 1. the HTTP client requests a compression encoding in the Content-Encoding
/// header (hence the need for the req_headers);
///
/// 2. the content encoding corresponds to a compression algorithm we support;
///
/// 3. the Media type in the Content-Type response header is text with any
/// subtype (e.g. text/plain) or application/json.
///
/// compress_response returns Ok on successful compression, or if not all three
/// conditions above are met. It returns Err if there was a problem decoding
/// any header in either req_headers or res, but res will remain intact.
///
/// This function logs errors to stderr, but only in debug mode. No information
/// is logged in release builds.
async fn compress_response(req_headers: &HeaderMap<header::HeaderValue>, res: &mut Response<Body>) -> Result<(), String> {
// Check if the data is eligible for compression.
if let Some(hdr) = res.headers().get(header::CONTENT_TYPE) {
match from_utf8(hdr.as_bytes()) {
Ok(val) => {
let s = val.to_string();
// TODO: better determination of what is eligible for compression
if !(s.starts_with("text/") || s.starts_with("application/json")) {
return Ok(());
};
}
Err(e) => {
dbg_msg!(e);
return Err(e.to_string());
}
};
} else {
// Response declares no Content-Type. Assume for simplicity that it
// cannot be compressed.
return Ok(());
};
// Don't bother if the size of the size of the response body will fit
// within an IP frame (less the bytes that make up the TCP/IP and HTTP
// headers).
if res.body().size_hint().lower() < 1452 {
return Ok(());
};
// Check to see which compressor is requested, and if we can use it.
let accept_encoding: String = match req_headers.get(header::ACCEPT_ENCODING) {
None => return Ok(()), // Client requested no compression.
Some(hdr) => match String::from_utf8(hdr.as_bytes().into()) {
Ok(val) => val,
#[cfg(debug_assertions)]
Err(e) => {
dbg_msg!(e);
return Ok(());
}
#[cfg(not(debug_assertions))]
Err(_) => return Ok(()),
},
};
let compressor: CompressionType = match determine_compressor(accept_encoding) {
Some(c) => c,
None => return Ok(()),
};
// Get the body from the response.
let body_bytes: Vec<u8> = match body::to_bytes(res.body_mut()).await {
Ok(b) => b.to_vec(),
Err(e) => {
dbg_msg!(e);
return Err(e.to_string());
}
};
// Compress!
match compress_body(compressor, body_bytes) {
Ok(compressed) => {
// We get here iff the compression was successful. Replace the body
// with the compressed payload, and add the appropriate
// Content-Encoding header in the response.
res.headers_mut().insert(header::CONTENT_ENCODING, compressor.to_string().parse().unwrap());
*(res.body_mut()) = Body::from(compressed);
}
Err(e) => return Err(e),
}
Ok(())
}
/// Compresses a `Vec<u8>` given a [`CompressionType`].
///
/// This is a helper function for [`compress_response`] and should not be
/// called directly.
// I've chosen a TTL of 600 (== 10 minutes) since compression is
// computationally expensive and we don't want to be doing it often. This is
// larger than client::json's TTL, but that's okay, because if client::json
// returns a new serde_json::Value, body_bytes changes, so this function will
// execute again.
#[cached(size = 100, time = 600, result = true)]
fn compress_body(compressor: CompressionType, body_bytes: Vec<u8>) -> Result<Vec<u8>, String> {
// io::Cursor implements io::Read, required for our encoders.
let mut reader = io::Cursor::new(body_bytes);
let compressed: Vec<u8> = match compressor {
CompressionType::Gzip => {
let mut gz: gzip::Encoder<Vec<u8>> = match gzip::Encoder::new(Vec::new()) {
Ok(gz) => gz,
Err(e) => {
dbg_msg!(e);
return Err(e.to_string());
}
};
match io::copy(&mut reader, &mut gz) {
Ok(_) => match gz.finish().into_result() {
Ok(compressed) => compressed,
Err(e) => {
dbg_msg!(e);
return Err(e.to_string());
}
},
Err(e) => {
dbg_msg!(e);
return Err(e.to_string());
}
}
}
CompressionType::Brotli => {
// We may want to make the compression parameters configurable
// in the future. For now, the defaults are sufficient.
let brotli_params = BrotliEncoderParams::default();
let mut compressed = Vec::<u8>::new();
match BrotliCompress(&mut reader, &mut compressed, &brotli_params) {
Ok(_) => compressed,
Err(e) => {
dbg_msg!(e);
return Err(e.to_string());
}
}
}
// This arm is for any requested compressor for which we don't yet
// have an implementation.
_ => {
let msg = "unsupported compressor".to_string();
return Err(msg);
}
};
Ok(compressed)
}
#[cfg(test)]
mod tests {
use super::*;
use brotli::Decompressor as BrotliDecompressor;
use futures_lite::future::block_on;
use lipsum::lipsum;
use std::{boxed::Box, io};
#[test]
fn test_determine_compressor() {
// Single compressor given.
assert_eq!(determine_compressor("unsupported".to_string()), None);
assert_eq!(determine_compressor("gzip".to_string()), Some(CompressionType::Gzip));
assert_eq!(determine_compressor("*".to_string()), Some(DEFAULT_COMPRESSOR));
// Multiple compressors.
assert_eq!(determine_compressor("gzip, br".to_string()), Some(CompressionType::Brotli));
assert_eq!(determine_compressor("gzip;q=0.8, br;q=0.3".to_string()), Some(CompressionType::Gzip));
assert_eq!(determine_compressor("br, gzip".to_string()), Some(CompressionType::Brotli));
assert_eq!(determine_compressor("br;q=0.3, gzip;q=0.4".to_string()), Some(CompressionType::Gzip));
// Invalid q-values.
assert_eq!(determine_compressor("gzip;q=NAN".to_string()), None);
}
#[test]
fn test_compress_response() {
// This macro generates an Accept-Encoding header value given any number of
// compressors.
macro_rules! ae_gen {
($x:expr) => {
$x.to_string().as_str()
};
($x:expr, $($y:expr),+) => {
format!("{}, {}", $x.to_string(), ae_gen!($($y),+)).as_str()
};
}
for accept_encoding in [
"*",
ae_gen!(CompressionType::Gzip),
ae_gen!(CompressionType::Brotli, CompressionType::Gzip),
ae_gen!(CompressionType::Brotli),
] {
// Determine what the expected encoding should be based on both the
// specific encodings we accept.
let expected_encoding: CompressionType = match determine_compressor(accept_encoding.to_string()) {
Some(s) => s,
None => panic!("determine_compressor(accept_encoding.to_string()) => None"),
};
// Build headers with our Accept-Encoding.
let mut req_headers = HeaderMap::new();
req_headers.insert(header::ACCEPT_ENCODING, header::HeaderValue::from_str(accept_encoding).unwrap());
// Build test response.
let lorem_ipsum: String = lipsum(10000);
let expected_lorem_ipsum = Vec::<u8>::from(lorem_ipsum.as_str());
let mut res = Response::builder()
.status(200)
.header(header::CONTENT_TYPE, "text/plain")
.body(Body::from(lorem_ipsum))
.unwrap();
// Perform the compression.
if let Err(e) = block_on(compress_response(&req_headers, &mut res)) {
panic!("compress_response(&req_headers, &mut res) => Err(\"{}\")", e);
};
// If the content was compressed, we expect the Content-Encoding
// header to be modified.
assert_eq!(
res
.headers()
.get(header::CONTENT_ENCODING)
.unwrap_or_else(|| panic!("missing content-encoding header"))
.to_str()
.unwrap_or_else(|_| panic!("failed to convert Content-Encoding header::HeaderValue to String")),
expected_encoding.to_string()
);
// Decompress body and make sure it's equal to what we started
// with.
//
// In the case of no compression, just make sure the "new" body in
// the Response is the same as what with which we start.
let body_vec = match block_on(body::to_bytes(res.body_mut())) {
Ok(b) => b.to_vec(),
Err(e) => panic!("{}", e),
};
if expected_encoding == CompressionType::Passthrough {
assert!(body_vec.eq(&expected_lorem_ipsum));
continue;
}
// This provides an io::Read for the underlying body.
let mut body_cursor: io::Cursor<Vec<u8>> = io::Cursor::new(body_vec);
// Match the appropriate decompresor for the given
// expected_encoding.
let mut decoder: Box<dyn io::Read> = match expected_encoding {
CompressionType::Gzip => match gzip::Decoder::new(&mut body_cursor) {
Ok(dgz) => Box::new(dgz),
Err(e) => panic!("{}", e),
},
CompressionType::Brotli => Box::new(BrotliDecompressor::new(body_cursor, expected_lorem_ipsum.len())),
_ => panic!("no decompressor for {}", expected_encoding.to_string()),
};
let mut decompressed = Vec::<u8>::new();
if let Err(e) = io::copy(&mut decoder, &mut decompressed) {
panic!("{}", e);
};
assert!(decompressed.eq(&expected_lorem_ipsum));
}
}
}

View file

@ -19,7 +19,7 @@ struct SettingsTemplate {
// CONSTANTS
const PREFS: [&str; 11] = [
const PREFS: [&str; 13] = [
"theme",
"front_page",
"layout",
@ -27,10 +27,12 @@ const PREFS: [&str; 11] = [
"comment_sort",
"post_sort",
"show_nsfw",
"blur_nsfw",
"use_hls",
"hide_hls_notification",
"autoplay_videos",
"fixed_navbar",
"hide_awards",
];
// FUNCTIONS
@ -39,7 +41,7 @@ const PREFS: [&str; 11] = [
pub async fn get(req: Request<Body>) -> Result<Response<Body>, String> {
let url = req.uri().to_string();
template(SettingsTemplate {
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
url,
})
}

View file

@ -1,6 +1,6 @@
// CRATES
use crate::utils::{
catch_random, error, filter_posts, format_num, format_url, get_filters, param, redirect, rewrite_urls, setting, template, val, Post, Preferences, Subreddit,
catch_random, error, filter_posts, format_num, format_url, get_filters, nsfw_landing, param, redirect, rewrite_urls, setting, template, val, Post, Preferences, Subreddit,
};
use crate::{client::json, server::ResponseExt, RequestExt};
use askama::Template;
@ -24,6 +24,9 @@ struct SubredditTemplate {
/// Whether all fetched posts are filtered (to differentiate between no posts fetched in the first place,
/// and all fetched posts being filtered).
all_posts_filtered: bool,
/// Whether all posts were hidden because they are NSFW (and user has disabled show NSFW)
all_posts_hidden_nsfw: bool,
no_posts: bool,
}
#[derive(Template)]
@ -94,6 +97,12 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
}
};
// Return landing page if this post if this is NSFW community but the user
// has disabled the display of NSFW content or if the instance is SFW-only.
if sub.nsfw && (setting(&req, "show_nsfw") != "on" || crate::utils::sfw_only()) {
return Ok(nsfw_landing(req).await.unwrap_or_default());
}
let path = format!("/r/{}/{}.json?{}&raw_json=1", sub_name.clone(), sort, req.uri().query().unwrap_or_default());
let url = String::from(req.uri().path_and_query().map_or("", |val| val.as_str()));
let redirect_url = url[1..].replace('?', "%3F").replace('&', "%26").replace('+', "%2B");
@ -106,27 +115,32 @@ pub async fn community(req: Request<Body>) -> Result<Response<Body>, String> {
posts: Vec::new(),
sort: (sort, param(&path, "t").unwrap_or_default()),
ends: (param(&path, "after").unwrap_or_default(), "".to_string()),
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
url,
redirect_url,
is_filtered: true,
all_posts_filtered: false,
all_posts_hidden_nsfw: false,
no_posts: false,
})
} else {
match Post::fetch(&path, quarantined).await {
Ok((mut posts, after)) => {
let all_posts_filtered = filter_posts(&mut posts, &filters);
let (_, all_posts_filtered) = filter_posts(&mut posts, &filters);
let no_posts = posts.is_empty();
let all_posts_hidden_nsfw = !no_posts && (posts.iter().all(|p| p.flags.nsfw) && setting(&req, "show_nsfw") != "on");
template(SubredditTemplate {
sub,
posts,
sort: (sort, param(&path, "t").unwrap_or_default()),
ends: (param(&path, "after").unwrap_or_default(), after),
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
url,
redirect_url,
is_filtered: false,
all_posts_filtered,
all_posts_hidden_nsfw,
no_posts,
})
}
Err(msg) => match msg.as_str() {
@ -145,7 +159,7 @@ pub fn quarantine(req: Request<Body>, sub: String) -> Result<Response<Body>, Str
msg: "Please click the button below to continue to this subreddit.".to_string(),
url: req.uri().to_string(),
sub,
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
};
Ok(
@ -192,7 +206,7 @@ pub async fn subscriptions_filters(req: Request<Body>) -> Result<Response<Body>,
let query = req.uri().query().unwrap_or_default().to_string();
let preferences = Preferences::new(req);
let preferences = Preferences::new(&req);
let mut sub_list = preferences.subscriptions;
let mut filters = preferences.filters;
@ -305,7 +319,7 @@ pub async fn wiki(req: Request<Body>) -> Result<Response<Body>, String> {
sub,
wiki: rewrite_urls(response["data"]["content_html"].as_str().unwrap_or("<h3>Wiki not found</h3>")),
page,
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
url,
}),
Err(msg) => {
@ -343,7 +357,7 @@ pub async fn sidebar(req: Request<Body>) -> Result<Response<Body>, String> {
// ),
sub,
page: "Sidebar".to_string(),
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
url,
}),
Err(msg) => {
@ -416,5 +430,6 @@ async fn subreddit(sub: &str, quarantined: bool) -> Result<Subreddit, String> {
members: format_num(members),
active: format_num(active),
wiki: res["data"]["wiki_enabled"].as_bool().unwrap_or_default(),
nsfw: res["data"]["over18"].as_bool().unwrap_or_default(),
})
}

View file

@ -1,7 +1,7 @@
// CRATES
use crate::client::json;
use crate::server::RequestExt;
use crate::utils::{error, filter_posts, format_url, get_filters, param, template, Post, Preferences, User};
use crate::utils::{error, filter_posts, format_url, get_filters, nsfw_landing, param, setting, template, Post, Preferences, User};
use askama::Template;
use hyper::{Body, Request, Response};
use time::{macros::format_description, OffsetDateTime};
@ -24,6 +24,9 @@ struct UserTemplate {
/// Whether all fetched posts are filtered (to differentiate between no posts fetched in the first place,
/// and all fetched posts being filtered).
all_posts_filtered: bool,
/// Whether all posts were hidden because they are NSFW (and user has disabled show NSFW)
all_posts_hidden_nsfw: bool,
no_posts: bool,
}
// FUNCTIONS
@ -43,8 +46,17 @@ pub async fn profile(req: Request<Body>) -> Result<Response<Body>, String> {
// Retrieve other variables from Libreddit request
let sort = param(&path, "sort").unwrap_or_default();
let username = req.param("name").unwrap_or_default();
// Retrieve info from user about page.
let user = user(&username).await.unwrap_or_default();
// Return landing page if this post if this Reddit deems this user NSFW,
// but we have also disabled the display of NSFW content or if the instance
// is SFW-only.
if user.nsfw && (setting(&req, "show_nsfw") != "on" || crate::utils::sfw_only()) {
return Ok(nsfw_landing(req).await.unwrap_or_default());
}
let filters = get_filters(&req);
if filters.contains(&["u_", &username].concat()) {
template(UserTemplate {
@ -53,29 +65,34 @@ pub async fn profile(req: Request<Body>) -> Result<Response<Body>, String> {
sort: (sort, param(&path, "t").unwrap_or_default()),
ends: (param(&path, "after").unwrap_or_default(), "".to_string()),
listing,
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
url,
redirect_url,
is_filtered: true,
all_posts_filtered: false,
all_posts_hidden_nsfw: false,
no_posts: false,
})
} else {
// Request user posts/comments from Reddit
match Post::fetch(&path, false).await {
Ok((mut posts, after)) => {
let all_posts_filtered = filter_posts(&mut posts, &filters);
let (_, all_posts_filtered) = filter_posts(&mut posts, &filters);
let no_posts = posts.is_empty();
let all_posts_hidden_nsfw = !no_posts && (posts.iter().all(|p| p.flags.nsfw) && setting(&req, "show_nsfw") != "on");
template(UserTemplate {
user,
posts,
sort: (sort, param(&path, "t").unwrap_or_default()),
ends: (param(&path, "after").unwrap_or_default(), after),
listing,
prefs: Preferences::new(req),
prefs: Preferences::new(&req),
url,
redirect_url,
is_filtered: false,
all_posts_filtered,
all_posts_hidden_nsfw,
no_posts,
})
}
// If there is an error show error page
@ -107,6 +124,7 @@ async fn user(name: &str) -> Result<User, String> {
created: created.format(format_description!("[month repr:short] [day] '[year repr:last_two]")).unwrap_or_default(),
banner: about("banner_img"),
description: about("public_description"),
nsfw: res["data"]["subreddit"]["over_18"].as_bool().unwrap_or_default(),
}
})
}

View file

@ -9,10 +9,36 @@ use regex::Regex;
use rust_embed::RustEmbed;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::env;
use std::str::FromStr;
use time::{macros::format_description, Duration, OffsetDateTime};
use url::Url;
/// Write a message to stderr on debug mode. This function is a no-op on
/// release code.
#[macro_export]
macro_rules! dbg_msg {
($x:expr) => {
#[cfg(debug_assertions)]
eprintln!("{}:{}: {}", file!(), line!(), $x.to_string())
};
($($x:expr),+) => {
#[cfg(debug_assertions)]
dbg_msg!(format!($($x),+))
};
}
/// Identifies whether or not the page is a subreddit, a user page, or a post.
/// This is used by the NSFW landing template to determine the mesage to convey
/// to the user.
#[derive(PartialEq, Eq)]
pub enum ResourceType {
Subreddit,
User,
Post,
}
// Post flair with content, background color and foreground color
pub struct Flair {
pub flair_parts: Vec<FlairPart>,
@ -210,9 +236,11 @@ pub struct Post {
pub domain: String,
pub rel_time: String,
pub created: String,
pub num_duplicates: u64,
pub comments: (String, String),
pub gallery: Vec<GalleryMedia>,
pub awards: Awards,
pub nsfw: bool,
}
impl Post {
@ -304,14 +332,16 @@ impl Post {
},
flags: Flags {
nsfw: data["over_18"].as_bool().unwrap_or_default(),
stickied: data["stickied"].as_bool().unwrap_or_default(),
stickied: data["stickied"].as_bool().unwrap_or_default() || data["pinned"].as_bool().unwrap_or_default(),
},
permalink: val(post, "permalink"),
rel_time,
created,
num_duplicates: post["data"]["num_duplicates"].as_u64().unwrap_or(0),
comments: format_num(data["num_comments"].as_i64().unwrap_or_default()),
gallery,
awards,
nsfw: post["data"]["over_18"].as_bool().unwrap_or_default(),
});
}
@ -340,6 +370,7 @@ pub struct Comment {
pub awards: Awards,
pub collapsed: bool,
pub is_filtered: bool,
pub prefs: Preferences,
}
#[derive(Default, Clone)]
@ -403,6 +434,27 @@ pub struct ErrorTemplate {
pub url: String,
}
/// Template for NSFW landing page. The landing page is displayed when a page's
/// content is wholly NSFW, but a user has not enabled the option to view NSFW
/// posts.
#[derive(Template)]
#[template(path = "nsfwlanding.html")]
pub struct NSFWLandingTemplate {
/// Identifier for the resource. This is either a subreddit name or a
/// username. (In the case of the latter, set is_user to true.)
pub res: String,
/// Identifies whether or not the resource is a subreddit, a user page,
/// or a post.
pub res_type: ResourceType,
/// User preferences.
pub prefs: Preferences,
/// Request URL.
pub url: String,
}
#[derive(Default)]
// User struct containing metadata about user
pub struct User {
@ -413,6 +465,7 @@ pub struct User {
pub created: String,
pub banner: String,
pub description: String,
pub nsfw: bool,
}
#[derive(Default)]
@ -427,6 +480,7 @@ pub struct Subreddit {
pub members: (String, String),
pub active: (String, String),
pub wiki: bool,
pub nsfw: bool,
}
// Parser for query params, used in sorting (eg. /r/rust/?sort=hot)
@ -447,6 +501,7 @@ pub struct Preferences {
pub layout: String,
pub wide: String,
pub show_nsfw: String,
pub blur_nsfw: String,
pub hide_hls_notification: String,
pub use_hls: String,
pub autoplay_videos: String,
@ -455,6 +510,7 @@ pub struct Preferences {
pub post_sort: String,
pub subscriptions: Vec<String>,
pub filters: Vec<String>,
pub hide_awards: String,
}
#[derive(RustEmbed)]
@ -464,7 +520,7 @@ pub struct ThemeAssets;
impl Preferences {
// Build preferences from cookies
pub fn new(req: Request<Body>) -> Self {
pub fn new(req: &Request<Body>) -> Self {
// Read available theme names from embedded css files.
// Always make the default "system" theme available.
let mut themes = vec!["system".to_string()];
@ -479,6 +535,7 @@ impl Preferences {
layout: setting(&req, "layout"),
wide: setting(&req, "wide"),
show_nsfw: setting(&req, "show_nsfw"),
blur_nsfw: setting(&req, "blur_nsfw"),
use_hls: setting(&req, "use_hls"),
hide_hls_notification: setting(&req, "hide_hls_notification"),
autoplay_videos: setting(&req, "autoplay_videos"),
@ -487,6 +544,7 @@ impl Preferences {
post_sort: setting(&req, "post_sort"),
subscriptions: setting(&req, "subscriptions").split('+').map(String::from).filter(|s| !s.is_empty()).collect(),
filters: setting(&req, "filters").split('+').map(String::from).filter(|s| !s.is_empty()).collect(),
hide_awards: setting(&req, "hide_awards"),
}
}
}
@ -496,15 +554,111 @@ pub fn get_filters(req: &Request<Body>) -> HashSet<String> {
setting(req, "filters").split('+').map(String::from).filter(|s| !s.is_empty()).collect::<HashSet<String>>()
}
/// Filters a `Vec<Post>` by the given `HashSet` of filters (each filter being a subreddit name or a user name). If a
/// `Post`'s subreddit or author is found in the filters, it is removed. Returns `true` if _all_ posts were filtered
/// out, or `false` otherwise.
pub fn filter_posts(posts: &mut Vec<Post>, filters: &HashSet<String>) -> bool {
/// Filters a `Vec<Post>` by the given `HashSet` of filters (each filter being
/// a subreddit name or a user name). If a `Post`'s subreddit or author is
/// found in the filters, it is removed.
///
/// The first value of the return tuple is the number of posts filtered. The
/// second return value is `true` if all posts were filtered.
pub fn filter_posts(posts: &mut Vec<Post>, filters: &HashSet<String>) -> (u64, bool) {
// This is the length of the Vec<Post> prior to applying the filter.
let lb: u64 = posts.len().try_into().unwrap_or(0);
if posts.is_empty() {
false
(0, false)
} else {
posts.retain(|p| !filters.contains(&p.community) && !filters.contains(&["u_", &p.author.name].concat()));
posts.is_empty()
posts.retain(|p| !(filters.contains(&p.community) || filters.contains(&["u_", &p.author.name].concat())));
// Get the length of the Vec<Post> after applying the filter.
// If lb > la, then at least one post was removed.
let la: u64 = posts.len().try_into().unwrap_or(0);
(lb - la, posts.is_empty())
}
}
/// Creates a [`Post`] from a provided JSON.
pub async fn parse_post(post: &serde_json::Value) -> Post {
// Grab UTC time as unix timestamp
let (rel_time, created) = time(post["data"]["created_utc"].as_f64().unwrap_or_default());
// Parse post score and upvote ratio
let score = post["data"]["score"].as_i64().unwrap_or_default();
let ratio: f64 = post["data"]["upvote_ratio"].as_f64().unwrap_or(1.0) * 100.0;
// Determine the type of media along with the media URL
let (post_type, media, gallery) = Media::parse(&post["data"]).await;
let awards: Awards = Awards::parse(&post["data"]["all_awardings"]);
let permalink = val(post, "permalink");
let body = if val(post, "removed_by_category") == "moderator" {
format!(
"<div class=\"md\"><p>[removed] — <a href=\"https://www.unddit.com{}\">view removed post</a></p></div>",
permalink
)
} else {
rewrite_urls(&val(post, "selftext_html"))
};
// Build a post using data parsed from Reddit post API
Post {
id: val(post, "id"),
title: val(post, "title"),
community: val(post, "subreddit"),
body,
author: Author {
name: val(post, "author"),
flair: Flair {
flair_parts: FlairPart::parse(
post["data"]["author_flair_type"].as_str().unwrap_or_default(),
post["data"]["author_flair_richtext"].as_array(),
post["data"]["author_flair_text"].as_str(),
),
text: val(post, "link_flair_text"),
background_color: val(post, "author_flair_background_color"),
foreground_color: val(post, "author_flair_text_color"),
},
distinguished: val(post, "distinguished"),
},
permalink,
score: format_num(score),
upvote_ratio: ratio as i64,
post_type,
media,
thumbnail: Media {
url: format_url(val(post, "thumbnail").as_str()),
alt_url: String::new(),
width: post["data"]["thumbnail_width"].as_i64().unwrap_or_default(),
height: post["data"]["thumbnail_height"].as_i64().unwrap_or_default(),
poster: String::new(),
},
flair: Flair {
flair_parts: FlairPart::parse(
post["data"]["link_flair_type"].as_str().unwrap_or_default(),
post["data"]["link_flair_richtext"].as_array(),
post["data"]["link_flair_text"].as_str(),
),
text: val(post, "link_flair_text"),
background_color: val(post, "link_flair_background_color"),
foreground_color: if val(post, "link_flair_text_color") == "dark" {
"black".to_string()
} else {
"white".to_string()
},
},
flags: Flags {
nsfw: post["data"]["over_18"].as_bool().unwrap_or_default(),
stickied: post["data"]["stickied"].as_bool().unwrap_or_default() || post["data"]["pinned"].as_bool().unwrap_or(false),
},
domain: val(post, "domain"),
rel_time,
created,
num_duplicates: post["data"]["num_duplicates"].as_u64().unwrap_or(0),
comments: format_num(post["data"]["num_comments"].as_i64().unwrap_or_default()),
gallery,
awards,
nsfw: post["data"]["over_18"].as_bool().unwrap_or_default(),
}
}
@ -531,8 +685,8 @@ pub fn setting(req: &Request<Body>, name: &str) -> String {
req
.cookie(name)
.unwrap_or_else(|| {
// If there is no cookie for this setting, try receiving a default from an environment variable
if let Ok(default) = std::env::var(format!("LIBREDDIT_DEFAULT_{}", name.to_uppercase())) {
// If there is no cookie for this setting, try receiving a default from the config
if let Some(default) = crate::config::get_setting(&format!("LIBREDDIT_DEFAULT_{}", name.to_uppercase())) {
Cookie::new(name, default)
} else {
Cookie::named(name)
@ -713,11 +867,12 @@ pub fn redirect(path: String) -> Response<Body> {
.unwrap_or_default()
}
pub async fn error(req: Request<Body>, msg: String) -> Result<Response<Body>, String> {
/// Renders a generic error landing page.
pub async fn error(req: Request<Body>, msg: impl ToString) -> Result<Response<Body>, String> {
let url = req.uri().to_string();
let body = ErrorTemplate {
msg,
prefs: Preferences::new(req),
msg: msg.to_string(),
prefs: Preferences::new(&req),
url,
}
.render()
@ -726,10 +881,54 @@ pub async fn error(req: Request<Body>, msg: String) -> Result<Response<Body>, St
Ok(Response::builder().status(404).header("content-type", "text/html").body(body.into()).unwrap_or_default())
}
/// Returns true if the config/env variable `LIBREDDIT_SFW_ONLY` carries the
/// value `on`.
///
/// If this variable is set as such, the instance will operate in SFW-only
/// mode; all NSFW content will be filtered. Attempts to access NSFW
/// subreddits or posts or userpages for users Reddit has deemed NSFW will
/// be denied.
pub fn sfw_only() -> bool {
match crate::config::get_setting("LIBREDDIT_SFW_ONLY") {
Some(val) => val == "on",
None => false,
}
}
/// Renders the landing page for NSFW content when the user has not enabled
/// "show NSFW posts" in settings.
pub async fn nsfw_landing(req: Request<Body>) -> Result<Response<Body>, String> {
let res_type: ResourceType;
let url = req.uri().to_string();
// Determine from the request URL if the resource is a subreddit, a user
// page, or a post.
let res: String = if !req.param("name").unwrap_or_default().is_empty() {
res_type = ResourceType::User;
req.param("name").unwrap_or_default()
} else if !req.param("id").unwrap_or_default().is_empty() {
res_type = ResourceType::Post;
req.param("id").unwrap_or_default()
} else {
res_type = ResourceType::Subreddit;
req.param("sub").unwrap_or_default()
};
let body = NSFWLandingTemplate {
res,
res_type,
prefs: Preferences::new(&req),
url,
}
.render()
.unwrap_or_default();
Ok(Response::builder().status(403).header("content-type", "text/html").body(body.into()).unwrap_or_default())
}
#[cfg(test)]
mod tests {
use super::format_num;
use super::rewrite_urls;
use super::{format_num, format_url, rewrite_urls};
#[test]
fn format_num_works() {
@ -749,4 +948,33 @@ mod tests {
r#"<a href="https://www.reddit.com/r/linux_gaming/comments/x/just_a_test/">https://www.reddit.com/r/linux_gaming/comments/x/just_a_test/</a>"#
)
}
#[test]
fn test_format_url() {
assert_eq!(format_url("https://a.thumbs.redditmedia.com/XYZ.jpg"), "/thumb/a/XYZ.jpg");
assert_eq!(format_url("https://emoji.redditmedia.com/a/b"), "/emoji/a/b");
assert_eq!(
format_url("https://external-preview.redd.it/foo.jpg?auto=webp&s=bar"),
"/preview/external-pre/foo.jpg?auto=webp&s=bar"
);
assert_eq!(format_url("https://i.redd.it/foobar.jpg"), "/img/foobar.jpg");
assert_eq!(
format_url("https://preview.redd.it/qwerty.jpg?auto=webp&s=asdf"),
"/preview/pre/qwerty.jpg?auto=webp&s=asdf"
);
assert_eq!(format_url("https://v.redd.it/foo/DASH_360.mp4?source=fallback"), "/vid/foo/360.mp4");
assert_eq!(
format_url("https://v.redd.it/foo/HLSPlaylist.m3u8?a=bar&v=1&f=sd"),
"/hls/foo/HLSPlaylist.m3u8?a=bar&v=1&f=sd"
);
assert_eq!(format_url("https://www.redditstatic.com/gold/awards/icon/icon.png"), "/static/gold/awards/icon/icon.png");
assert_eq!(format_url(""), "");
assert_eq!(format_url("self"), "");
assert_eq!(format_url("default"), "");
assert_eq!(format_url("nsfw"), "");
assert_eq!(format_url("spoiler"), "");
}
}