use std::{ io::SeekFrom, net::{IpAddr, SocketAddr}, str::FromStr, }; use async_zip::error::ZipError; use axum::{ extract::{ConnectInfo, Request}, http::HeaderMap, response::{IntoResponse, Response}, }; use headers::{Header, HeaderMapExt}; use http::{header, StatusCode}; use mime_guess::Mime; use serde::Serialize; use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; use tokio_util::bytes::{BufMut, BytesMut}; use crate::error::{Error, Result}; /// HTTP response builder extensions pub trait ResponseBuilderExt { /// Inserts a typed header to this response. fn typed_header(self, header: T) -> Self; fn cache(self) -> Self; fn cache_immutable(self) -> Self; /// Consumes this builder, using the provided json-serializable `val` to return a constructed [`Response`] fn json(self, val: &T) -> core::result::Result; } impl ResponseBuilderExt for axum::http::response::Builder { fn typed_header(mut self, header: T) -> Self { if let Some(headers) = self.headers_mut() { headers.typed_insert(header); } self } fn cache(self) -> Self { self.header( http::header::CACHE_CONTROL, http::HeaderValue::from_static("max-age=1800,public"), ) } fn cache_immutable(self) -> Self { self.header( http::header::CACHE_CONTROL, http::HeaderValue::from_static("max-age=31536000,public,immutable"), ) } fn json(self, val: &T) -> core::result::Result { // copied from axum::json::into_response // Use a small initial capacity of 128 bytes like serde_json::to_vec // https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189 let mut buf = BytesMut::with_capacity(128).writer(); match serde_json::to_writer(&mut buf, val) { Ok(()) => self .typed_header(headers::ContentType::json()) .body(buf.into_inner().freeze().into()), Err(err) => self .status(StatusCode::INTERNAL_SERVER_ERROR) .typed_header(headers::ContentType::text()) .body(err.to_string().into()), } } } pub fn accepts_gzip(headers: &HeaderMap) -> bool { headers .get(header::ACCEPT_ENCODING) .and_then(|h| h.to_str().ok()) .map(|h| { h.split(',').any(|val| { val.split(';') .next() .map(|v| { let vt = v.trim(); vt.eq_ignore_ascii_case("gzip") || vt == "*" }) .unwrap_or_default() }) }) .unwrap_or_default() } /// Seek to the contained compressed data within a zip file pub async fn seek_to_data_offset( reader: &mut R, header_offset: u64, ) -> core::result::Result<(), ZipError> { const LFH_SIGNATURE: u32 = 0x4034b50; // Seek to the header reader.seek(SeekFrom::Start(header_offset)).await?; // Check the signature let signature = { let mut buffer = [0; 4]; reader.read_exact(&mut buffer).await?; u32::from_le_bytes(buffer) }; match signature { LFH_SIGNATURE => (), actual => return Err(ZipError::UnexpectedHeaderError(actual, LFH_SIGNATURE)), }; // Skip the local file header and trailing data let mut header_data: [u8; 26] = [0; 26]; reader.read_exact(&mut header_data).await?; let file_name_length = u16::from_le_bytes(header_data[22..24].try_into().unwrap()); let extra_field_length = u16::from_le_bytes(header_data[24..26].try_into().unwrap()); let trailing_size = (file_name_length as i64) + (extra_field_length as i64); reader.seek(SeekFrom::Current(trailing_size)).await?; Ok(()) } /// Return the file extension of a website path pub fn site_path_ext(path: &str) -> Option<&str> { let mut parts = path.split('.').rev(); parts .next() .filter(|ext| !ext.contains('/') && parts.next().is_some()) } /// Get the file extension of a website path pub fn path_mime(path: &str) -> Option { site_path_ext(path).and_then(|ext| mime_guess::from_ext(ext).first()) } pub fn full_url_from_request(request: &Request) -> String { let uri = request.uri(); if let Some(host) = host_from_request(request) { format!("{}{}", host, uri.path()) } else { uri.to_string() } } fn host_from_request(request: &Request) -> Option<&str> { parse_forwarded(request.headers()) .or_else(|| { request .headers() .get("X-Forwarded-Host") .and_then(|host| host.to_str().ok()) }) .or_else(|| { request .headers() .get(http::header::HOST) .and_then(|host| host.to_str().ok()) }) } fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { // if there are multiple `Forwarded` `HeaderMap::get` will return the first one let forwarded_values = headers.get(header::FORWARDED)?.to_str().ok()?; // get the first set of values let first_value = forwarded_values.split(',').next()?; // find the value of the `host` field first_value.split(';').find_map(|pair| { let (key, value) = pair.split_once('=')?; key.trim() .eq_ignore_ascii_case("host") .then(|| value.trim().trim_matches('"')) }) } pub fn get_subdomain<'a>(host: &'a str, root_domain: &str) -> Result<&'a str> { let stripped = host.strip_suffix(root_domain).ok_or(Error::BadRequest( "host does not end with configured ROOT_DOMAIN".into(), ))?; Ok(stripped.trim_end_matches('.')) } pub fn get_ip_address(request: &Request, real_ip_header: Option<&str>) -> Result { match real_ip_header.and_then(|header| { request .headers() .get(header) .and_then(|val| val.to_str().ok()) .and_then(|val| IpAddr::from_str(val).ok()) }) { Some(from_header) => Ok(from_header), None => { let socket_addr = request .extensions() .get::>() .ok_or(Error::Other("could get request ip address".into()))? .0; Ok(socket_addr.ip()) } } } pub trait IgnoreFileNotFound { fn ignore_file_not_found(self) -> core::result::Result<(), std::io::Error>; } impl IgnoreFileNotFound for core::result::Result { fn ignore_file_not_found(self) -> core::result::Result<(), std::io::Error> { match self { Ok(_) => Ok(()), Err(e) => match e.kind() { std::io::ErrorKind::NotFound => Ok(()), _ => todo!(), }, } } } pub fn parse_url(input: &str) -> Result<(&str, std::str::Split)> { let s = input.trim(); let s = s.strip_prefix("http://").unwrap_or(s); let s = s.strip_prefix("https://").unwrap_or(s); let s = s .split(['?', '#']) .next() .ok_or(Error::BadRequest("empty URL".into()))?; let mut parts = s.split('/'); let host = parts.next().ok_or(Error::BadRequest("empty URL".into()))?; if host.is_empty() { return Err(Error::BadRequest("empty URL".into())); } if !host .chars() .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '.')) { return Err(Error::BadRequest( "bad URL: domain contains invalid characters".into(), )); } Ok((host, parts)) } pub fn time_to_ms(time: f64) -> u64 { (time * 1000.0) as u64 } /// Get the extension from a filename for selecting a viewer pub fn filename_ext(filename: &str) -> &str { let mut rsplit = filename.rsplit('.'); let ext = rsplit.next().unwrap(); if filename.starts_with('.') && rsplit.next().map(str::is_empty).unwrap_or(true) { // Dotfile without extension (e.g. .bashrc) filename } else { ext } } #[derive(Serialize)] pub struct ErrorJson { status: u16, msg: String, } impl ErrorJson { pub fn ok>(msg: S) -> Self { Self { status: 200, msg: msg.into(), } } } impl From for ErrorJson { fn from(value: Error) -> Self { Self { status: value.status().as_u16(), msg: value.to_string(), } } } impl From for ErrorJson { fn from(value: http::Error) -> Self { Self::from(Error::from(value)) } } impl IntoResponse for ErrorJson { fn into_response(self) -> Response { Response::builder().status(self.status).json(&self).unwrap() } } pub fn extract_delim<'a>(s: &'a str, start: &str, end: &str) -> Option<&'a str> { if let Some(np) = s.find(start) { if let Some(np_end) = s[np + start.len()..].find(end) { return Some(s[np + start.len()..np + start.len() + np_end].trim()); } } None } pub fn split_icon_prefix(s: &str) -> (&str, &str) { if let Some((i, c)) = s .char_indices() .find(|(_, c)| c.is_ascii() || c.is_alphanumeric()) { if i > 0 && c == ' ' && s.get(i + 1..).is_some() { return (&s[..i + 1], &s[i + 1..]); } } ("", s) } #[cfg(test)] pub(crate) mod tests { use std::path::{Path, PathBuf}; use http::{header, HeaderMap}; use once_cell::sync::Lazy; use path_macro::path; use rstest::rstest; pub static TESTFILES: Lazy = Lazy::new(|| path!(env!("CARGO_MANIFEST_DIR") / "tests" / "testfiles")); static SITEDIR: Lazy = Lazy::new(|| { let sitedir = path!(*TESTFILES / "sites_data"); if !sitedir.is_dir() { std::process::Command::new(path!(*TESTFILES / "sites" / "make_zip.sh")) .output() .unwrap(); } sitedir }); pub fn setup_cache_dir(dir: &Path) { for entry in std::fs::read_dir(SITEDIR.as_path()).unwrap() { let entry = entry.unwrap(); if entry.file_type().unwrap().is_file() { std::fs::copy(entry.path(), path!(dir / entry.file_name())).unwrap(); } } } #[rstest] #[case("", false)] #[case("br", false)] #[case("gzip", true)] #[case("GZIP", true)] #[case("*", true)] #[case("deflate, gzip;q=1.0, *;q=0.5", true)] fn accepts_gzip(#[case] val: &str, #[case] expect: bool) { let mut hdrs = HeaderMap::new(); hdrs.insert(header::ACCEPT_ENCODING, val.try_into().unwrap()); assert_eq!(super::accepts_gzip(&hdrs), expect); } #[rstest] #[case("localhost", Some(""))] #[case("test.localhost", Some("test"))] #[case("example.com", None)] fn get_subdomain(#[case] host: &str, #[case] expect: Option<&str>) { assert_eq!(super::get_subdomain(host, "localhost").ok(), expect); } #[rstest] #[case("https://github.com", Some("github.com"), "")] #[case( "https://github.com/Theta-Dev/example.project", Some("github.com"), "Theta-Dev/example.project/" )] #[case( "https://github.com/Theta-Dev/example.project?key=val#to", Some("github.com"), "Theta-Dev/example.project/" )] #[case("", None, "")] fn parse_url(#[case] input: &str, #[case] host: Option<&str>, #[case] parts: &str) { match super::parse_url(input) { Ok((h, p)) => { assert_eq!(Some(h), host); let parts_joined = p.fold(String::new(), |acc, p| acc + p + "/"); assert_eq!(parts_joined, parts); } Err(_) => { assert!(host.is_none()); } } } #[rstest] #[case("hello.txt", "txt")] #[case(".bashrc", ".bashrc")] #[case("Makefile", "Makefile")] #[case("", "")] fn filename_ext(#[case] filename: &str, #[case] expect: &str) { let res = super::filename_ext(filename); assert_eq!(res, expect); } #[rstest] #[case("๐Ÿงช Test", ("๐Ÿงช ", "Test"))] #[case("๐Ÿงช๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘ฆ Test", ("๐Ÿงช๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘ฆ ", "Test"))] #[case("๐Ÿงช ๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘ฆ Test", ("๐Ÿงช ", "๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘ฆ Test"))] #[case("", ("", ""))] #[case("Test", ("", "Test"))] #[case("้‹ๅ‘ฝ Test", ("", "้‹ๅ‘ฝ Test"))] fn split_icon_prefix(#[case] s: &str, #[case] expect: (&str, &str)) { let res = super::split_icon_prefix(s); assert_eq!(res, expect); } }