417 lines
12 KiB
Rust
417 lines
12 KiB
Rust
use std::{
|
|
io::SeekFrom,
|
|
net::{IpAddr, SocketAddr},
|
|
str::FromStr,
|
|
};
|
|
|
|
use async_zip::error::ZipError;
|
|
use axum::{
|
|
extract::{ConnectInfo, Request},
|
|
http::HeaderMap,
|
|
response::{IntoResponse, Response},
|
|
};
|
|
use headers::{Header, HeaderMapExt};
|
|
use http::{header, StatusCode};
|
|
use mime_guess::Mime;
|
|
use serde::Serialize;
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
|
|
use tokio_util::bytes::{BufMut, BytesMut};
|
|
|
|
use crate::error::{Error, Result};
|
|
|
|
/// HTTP response builder extensions
|
|
pub trait ResponseBuilderExt {
|
|
/// Inserts a typed header to this response.
|
|
fn typed_header<T: Header>(self, header: T) -> Self;
|
|
fn cache(self) -> Self;
|
|
fn cache_immutable(self) -> Self;
|
|
/// Consumes this builder, using the provided json-serializable `val` to return a constructed [`Response`]
|
|
fn json<T: Serialize>(self, val: &T) -> core::result::Result<Response, http::Error>;
|
|
}
|
|
|
|
impl ResponseBuilderExt for axum::http::response::Builder {
|
|
fn typed_header<T: Header>(mut self, header: T) -> Self {
|
|
if let Some(headers) = self.headers_mut() {
|
|
headers.typed_insert(header);
|
|
}
|
|
self
|
|
}
|
|
|
|
fn cache(self) -> Self {
|
|
self.header(
|
|
http::header::CACHE_CONTROL,
|
|
http::HeaderValue::from_static("max-age=1800,public"),
|
|
)
|
|
}
|
|
|
|
fn cache_immutable(self) -> Self {
|
|
self.header(
|
|
http::header::CACHE_CONTROL,
|
|
http::HeaderValue::from_static("max-age=31536000,public,immutable"),
|
|
)
|
|
}
|
|
|
|
fn json<T: Serialize>(self, val: &T) -> core::result::Result<Response, http::Error> {
|
|
// copied from axum::json::into_response
|
|
// Use a small initial capacity of 128 bytes like serde_json::to_vec
|
|
// https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
|
|
let mut buf = BytesMut::with_capacity(128).writer();
|
|
match serde_json::to_writer(&mut buf, val) {
|
|
Ok(()) => self
|
|
.typed_header(headers::ContentType::json())
|
|
.body(buf.into_inner().freeze().into()),
|
|
Err(err) => self
|
|
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
|
.typed_header(headers::ContentType::text())
|
|
.body(err.to_string().into()),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn accepts_gzip(headers: &HeaderMap) -> bool {
|
|
headers
|
|
.get(header::ACCEPT_ENCODING)
|
|
.and_then(|h| h.to_str().ok())
|
|
.map(|h| {
|
|
h.split(',').any(|val| {
|
|
val.split(';')
|
|
.next()
|
|
.map(|v| {
|
|
let vt = v.trim();
|
|
vt.eq_ignore_ascii_case("gzip") || vt == "*"
|
|
})
|
|
.unwrap_or_default()
|
|
})
|
|
})
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Seek to the contained compressed data within a zip file
|
|
pub async fn seek_to_data_offset<R: AsyncRead + AsyncSeek + Unpin>(
|
|
reader: &mut R,
|
|
header_offset: u64,
|
|
) -> core::result::Result<(), ZipError> {
|
|
const LFH_SIGNATURE: u32 = 0x4034b50;
|
|
|
|
// Seek to the header
|
|
reader.seek(SeekFrom::Start(header_offset)).await?;
|
|
|
|
// Check the signature
|
|
let signature = {
|
|
let mut buffer = [0; 4];
|
|
reader.read_exact(&mut buffer).await?;
|
|
u32::from_le_bytes(buffer)
|
|
};
|
|
|
|
match signature {
|
|
LFH_SIGNATURE => (),
|
|
actual => return Err(ZipError::UnexpectedHeaderError(actual, LFH_SIGNATURE)),
|
|
};
|
|
|
|
// Skip the local file header and trailing data
|
|
let mut header_data: [u8; 26] = [0; 26];
|
|
reader.read_exact(&mut header_data).await?;
|
|
let file_name_length = u16::from_le_bytes(header_data[22..24].try_into().unwrap());
|
|
let extra_field_length = u16::from_le_bytes(header_data[24..26].try_into().unwrap());
|
|
|
|
let trailing_size = (file_name_length as i64) + (extra_field_length as i64);
|
|
reader.seek(SeekFrom::Current(trailing_size)).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Return the file extension of a website path
///
/// The extension is the text after the last `.` of `path`. `None` is returned if the path
/// contains no `.` or if that text contains a `/` (the dot belongs to a directory name).
pub fn site_path_ext(path: &str) -> Option<&str> {
    let (_, ext) = path.rsplit_once('.')?;
    if ext.contains('/') {
        None
    } else {
        Some(ext)
    }
}
|
|
|
|
/// Get the file extension of a website path
|
|
pub fn path_mime(path: &str) -> Option<Mime> {
|
|
site_path_ext(path).and_then(|ext| mime_guess::from_ext(ext).first())
|
|
}
|
|
|
|
pub fn full_url_from_request(request: &Request) -> String {
|
|
let uri = request.uri();
|
|
if let Some(host) = host_from_request(request) {
|
|
format!("{}{}", host, uri.path())
|
|
} else {
|
|
uri.to_string()
|
|
}
|
|
}
|
|
|
|
fn host_from_request(request: &Request) -> Option<&str> {
|
|
parse_forwarded(request.headers())
|
|
.or_else(|| {
|
|
request
|
|
.headers()
|
|
.get("X-Forwarded-Host")
|
|
.and_then(|host| host.to_str().ok())
|
|
})
|
|
.or_else(|| {
|
|
request
|
|
.headers()
|
|
.get(http::header::HOST)
|
|
.and_then(|host| host.to_str().ok())
|
|
})
|
|
}
|
|
|
|
fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
|
|
// if there are multiple `Forwarded` `HeaderMap::get` will return the first one
|
|
let forwarded_values = headers.get(header::FORWARDED)?.to_str().ok()?;
|
|
|
|
// get the first set of values
|
|
let first_value = forwarded_values.split(',').next()?;
|
|
|
|
// find the value of the `host` field
|
|
first_value.split(';').find_map(|pair| {
|
|
let (key, value) = pair.split_once('=')?;
|
|
key.trim()
|
|
.eq_ignore_ascii_case("host")
|
|
.then(|| value.trim().trim_matches('"'))
|
|
})
|
|
}
|
|
|
|
pub fn get_subdomain<'a>(host: &'a str, root_domain: &str) -> Result<&'a str> {
|
|
let stripped = host.strip_suffix(root_domain).ok_or(Error::BadRequest(
|
|
"host does not end with configured ROOT_DOMAIN".into(),
|
|
))?;
|
|
Ok(stripped.trim_end_matches('.'))
|
|
}
|
|
|
|
pub fn get_ip_address(request: &Request, real_ip_header: Option<&str>) -> Result<IpAddr> {
|
|
match real_ip_header.and_then(|header| {
|
|
request
|
|
.headers()
|
|
.get(header)
|
|
.and_then(|val| val.to_str().ok())
|
|
.and_then(|val| IpAddr::from_str(val).ok())
|
|
}) {
|
|
Some(from_header) => Ok(from_header),
|
|
None => {
|
|
let socket_addr = request
|
|
.extensions()
|
|
.get::<ConnectInfo<SocketAddr>>()
|
|
.ok_or(Error::Other("could get request ip address".into()))?
|
|
.0;
|
|
Ok(socket_addr.ip())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Extension for I/O results that treats a missing file as success
pub trait IgnoreFileNotFound {
    /// Discards the success value and maps a [`std::io::ErrorKind::NotFound`] error to
    /// `Ok(())`. Every other error is returned unchanged.
    fn ignore_file_not_found(self) -> core::result::Result<(), std::io::Error>;
}

impl<T> IgnoreFileNotFound for core::result::Result<T, std::io::Error> {
    fn ignore_file_not_found(self) -> core::result::Result<(), std::io::Error> {
        match self {
            Ok(_) => Ok(()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            // Any other failure is a real error and has to reach the caller
            // instead of aborting the program.
            Err(e) => Err(e),
        }
    }
}
|
|
|
|
pub fn parse_url(input: &str) -> Result<(&str, std::str::Split<char>)> {
|
|
let s = input.trim();
|
|
let s = s.strip_prefix("http://").unwrap_or(s);
|
|
let s = s.strip_prefix("https://").unwrap_or(s);
|
|
let s = s
|
|
.split(['?', '#'])
|
|
.next()
|
|
.ok_or(Error::BadRequest("empty URL".into()))?;
|
|
let mut parts = s.split('/');
|
|
let host = parts.next().ok_or(Error::BadRequest("empty URL".into()))?;
|
|
if host.is_empty() {
|
|
return Err(Error::BadRequest("empty URL".into()));
|
|
}
|
|
if !host
|
|
.chars()
|
|
.all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '.'))
|
|
{
|
|
return Err(Error::BadRequest(
|
|
"bad URL: domain contains invalid characters".into(),
|
|
));
|
|
}
|
|
Ok((host, parts))
|
|
}
|
|
|
|
/// Scale `time` by 1000 and convert it to an unsigned integer
/// (going by the name: seconds to milliseconds).
///
/// The conversion uses an `as` cast: the fractional part is dropped, negative values and
/// NaN become 0, and values beyond the `u64` range saturate at `u64::MAX`.
pub fn time_to_ms(time: f64) -> u64 {
    const FACTOR: f64 = 1000.0;
    let scaled = time * FACTOR;
    scaled as u64
}
|
|
|
|
/// Get the extension from a filename for selecting a viewer
///
/// Returns the text after the last `.`. A filename without any `.` is returned whole
/// (e.g. `Makefile`), and so is a dotfile without extension (e.g. `.bashrc`).
pub fn filename_ext(filename: &str) -> &str {
    match filename.rsplit_once('.') {
        // No dot at all: the whole name is the result
        None => filename,
        Some((stem, ext)) => {
            // Dotfile without extension (e.g. .bashrc): the name starts with a dot and
            // nothing sits between the last dot and the dot before it (or the start).
            let bare_dotfile =
                filename.starts_with('.') && (stem.is_empty() || stem.ends_with('.'));
            if bare_dotfile {
                filename
            } else {
                ext
            }
        }
    }
}
|
|
|
|
#[derive(Serialize)]
|
|
pub struct ErrorJson {
|
|
status: u16,
|
|
msg: String,
|
|
}
|
|
|
|
impl ErrorJson {
|
|
pub fn ok<S: Into<String>>(msg: S) -> Self {
|
|
Self {
|
|
status: 200,
|
|
msg: msg.into(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<Error> for ErrorJson {
|
|
fn from(value: Error) -> Self {
|
|
Self {
|
|
status: value.status().as_u16(),
|
|
msg: value.to_string(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<http::Error> for ErrorJson {
|
|
fn from(value: http::Error) -> Self {
|
|
Self::from(Error::from(value))
|
|
}
|
|
}
|
|
|
|
impl IntoResponse for ErrorJson {
|
|
fn into_response(self) -> Response {
|
|
Response::builder().status(self.status).json(&self).unwrap()
|
|
}
|
|
}
|
|
|
|
/// Extract the text between two delimiters
///
/// Finds the first occurrence of `start` in `s` and the first occurrence of `end` after it,
/// and returns the text in between with surrounding whitespace removed. Returns `None`
/// if either delimiter is missing.
pub fn extract_delim<'a>(s: &'a str, start: &str, end: &str) -> Option<&'a str> {
    let (_, after_start) = s.split_once(start)?;
    let (inner, _) = after_start.split_once(end)?;
    Some(inner.trim())
}
|
|
|
|
/// Split a leading icon (e.g. an emoji) from a string
///
/// The prefix is the run of characters at the start of `s` that are neither ASCII nor
/// alphanumeric, provided that this run is non-empty and directly followed by a space.
/// In that case the result is the prefix including the space, and the rest of the string.
/// Otherwise the prefix is empty and the second element is `s` unchanged.
pub fn split_icon_prefix(s: &str) -> (&str, &str) {
    // First character that can not be part of an icon
    let boundary = s
        .char_indices()
        .find(|&(_, ch)| ch.is_ascii() || ch.is_alphanumeric());

    match boundary {
        // A space is one byte long, so `idx + 1` is a valid split position
        Some((idx, ' ')) if idx > 0 => s.split_at(idx + 1),
        _ => ("", s),
    }
}
|
|
|
|
#[cfg(test)]
pub(crate) mod tests {
    use std::path::{Path, PathBuf};

    use http::{header, HeaderMap};
    use once_cell::sync::Lazy;
    use path_macro::path;
    use rstest::rstest;

    /// Directory `tests/testfiles` below the crate's manifest directory
    pub static TESTFILES: Lazy<PathBuf> =
        Lazy::new(|| path!(env!("CARGO_MANIFEST_DIR") / "tests" / "testfiles"));

    /// Directory `sites_data` inside [`TESTFILES`].
    ///
    /// On first access, `sites/make_zip.sh` is run if the directory does not exist.
    static SITEDIR: Lazy<PathBuf> = Lazy::new(|| {
        let sitedir = path!(*TESTFILES / "sites_data");
        if !sitedir.is_dir() {
            let script = path!(*TESTFILES / "sites" / "make_zip.sh");
            std::process::Command::new(script).output().unwrap();
        }
        sitedir
    });

    /// Copy all regular files from [`SITEDIR`] into `dir`
    pub fn setup_cache_dir(dir: &Path) {
        let files = std::fs::read_dir(SITEDIR.as_path())
            .unwrap()
            .map(|entry| entry.unwrap())
            .filter(|entry| entry.file_type().unwrap().is_file());
        for file in files {
            std::fs::copy(file.path(), path!(dir / file.file_name())).unwrap();
        }
    }

    #[rstest]
    #[case("", false)]
    #[case("br", false)]
    #[case("gzip", true)]
    #[case("GZIP", true)]
    #[case("*", true)]
    #[case("deflate, gzip;q=1.0, *;q=0.5", true)]
    fn accepts_gzip(#[case] val: &str, #[case] expect: bool) {
        let mut request_headers = HeaderMap::new();
        let value = val.try_into().unwrap();
        request_headers.insert(header::ACCEPT_ENCODING, value);

        assert_eq!(super::accepts_gzip(&request_headers), expect);
    }

    #[rstest]
    #[case("localhost", Some(""))]
    #[case("test.localhost", Some("test"))]
    #[case("example.com", None)]
    fn get_subdomain(#[case] host: &str, #[case] expect: Option<&str>) {
        let subdomain = super::get_subdomain(host, "localhost").ok();
        assert_eq!(subdomain, expect);
    }

    #[rstest]
    #[case("https://github.com", Some("github.com"), "")]
    #[case(
        "https://github.com/Theta-Dev/example.project",
        Some("github.com"),
        "Theta-Dev/example.project/"
    )]
    #[case(
        "https://github.com/Theta-Dev/example.project?key=val#to",
        Some("github.com"),
        "Theta-Dev/example.project/"
    )]
    #[case("", None, "")]
    fn parse_url(#[case] input: &str, #[case] host: Option<&str>, #[case] parts: &str) {
        match super::parse_url(input) {
            // A failed parse is only expected when no host is expected
            Err(_) => {
                assert!(host.is_none());
            }
            Ok((parsed_host, segments)) => {
                assert_eq!(Some(parsed_host), host);
                // Join the path segments, each followed by a slash
                let parts_joined: String =
                    segments.map(|segment| format!("{}/", segment)).collect();
                assert_eq!(parts_joined, parts);
            }
        }
    }

    #[rstest]
    #[case("hello.txt", "txt")]
    #[case(".bashrc", ".bashrc")]
    #[case("Makefile", "Makefile")]
    #[case("", "")]
    fn filename_ext(#[case] filename: &str, #[case] expect: &str) {
        assert_eq!(super::filename_ext(filename), expect);
    }

    #[rstest]
    #[case("🧪 Test", ("🧪 ", "Test"))]
    #[case("🧪👨👩👦 Test", ("🧪👨👩👦 ", "Test"))]
    #[case("🧪 👨👩👦 Test", ("🧪 ", "👨👩👦 Test"))]
    #[case("", ("", ""))]
    #[case("Test", ("", "Test"))]
    #[case("運命 Test", ("", "運命 Test"))]
    fn split_icon_prefix(#[case] s: &str, #[case] expect: (&str, &str)) {
        assert_eq!(super::split_icon_prefix(s), expect);
    }
}
|