artifactview/src/util.rs

417 lines
12 KiB
Rust

use std::{
io::SeekFrom,
net::{IpAddr, SocketAddr},
str::FromStr,
};
use async_zip::error::ZipError;
use axum::{
extract::{ConnectInfo, Request},
http::HeaderMap,
response::{IntoResponse, Response},
};
use headers::{Header, HeaderMapExt};
use http::{header, StatusCode};
use mime_guess::Mime;
use serde::Serialize;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
use tokio_util::bytes::{BufMut, BytesMut};
use crate::error::{Error, Result};
/// HTTP response builder extensions
pub trait ResponseBuilderExt {
/// Inserts a typed header to this response.
fn typed_header<T: Header>(self, header: T) -> Self;
/// Adds a `Cache-Control` header for ordinary cacheable responses.
///
/// The implementation in this file sends `max-age=1800,public`.
fn cache(self) -> Self;
/// Adds a `Cache-Control` header for responses whose content never changes.
///
/// The implementation in this file sends `max-age=31536000,public,immutable`.
fn cache_immutable(self) -> Self;
/// Consumes this builder, using the provided json-serializable `val` to return a constructed [`Response`]
///
/// The implementation in this file answers with status 500 and the
/// serializer's error text as the body when `val` cannot be serialized.
fn json<T: Serialize>(self, val: &T) -> core::result::Result<Response, http::Error>;
}
impl ResponseBuilderExt for axum::http::response::Builder {
    /// Insert `header` into the builder's header map.
    ///
    /// Nothing is inserted when `headers_mut` yields `None`; the builder is
    /// handed back unchanged in that case.
    fn typed_header<T: Header>(mut self, header: T) -> Self {
        match self.headers_mut() {
            Some(map) => map.typed_insert(header),
            None => {}
        }
        self
    }

    /// Set `Cache-Control: max-age=1800,public`.
    fn cache(self) -> Self {
        const POLICY: &str = "max-age=1800,public";
        self.header(
            http::header::CACHE_CONTROL,
            http::HeaderValue::from_static(POLICY),
        )
    }

    /// Set `Cache-Control: max-age=31536000,public,immutable`.
    fn cache_immutable(self) -> Self {
        const POLICY: &str = "max-age=31536000,public,immutable";
        self.header(
            http::header::CACHE_CONTROL,
            http::HeaderValue::from_static(POLICY),
        )
    }

    /// Serialize `val` as JSON and finish the response with it as the body.
    ///
    /// On success the content type is set to JSON. If serialization fails the
    /// response instead gets status 500, a plain-text content type and the
    /// serializer's error message as the body.
    fn json<T: Serialize>(self, val: &T) -> core::result::Result<Response, http::Error> {
        // Mirrors axum's own `Json::into_response`: start with a small
        // 128-byte buffer, the same initial capacity serde_json::to_vec uses.
        // https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189
        let mut sink = BytesMut::with_capacity(128).writer();

        if let Err(err) = serde_json::to_writer(&mut sink, val) {
            return self
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .typed_header(headers::ContentType::text())
                .body(err.to_string().into());
        }

        let payload = sink.into_inner().freeze();
        self.typed_header(headers::ContentType::json())
            .body(payload.into())
    }
}
/// Report whether the request's `Accept-Encoding` header permits a gzip response.
///
/// Returns `false` when the header is missing or is not valid visible ASCII.
/// Otherwise the header is treated as a comma-separated list of codings, each
/// optionally followed by `;`-separated parameters:
///
/// * an explicit `gzip` entry (case-insensitive) decides the outcome;
/// * without one, a `*` wildcard entry decides;
/// * an entry carrying a weight of zero (`q=0`) means "not acceptable", so
///   `gzip;q=0` or `gzip;q=0, *` do **not** allow gzip.
///
/// When a coding is listed more than once, the first occurrence wins.
pub fn accepts_gzip(headers: &HeaderMap) -> bool {
    headers
        .get(header::ACCEPT_ENCODING)
        .and_then(|h| h.to_str().ok())
        .map(accept_encoding_allows_gzip)
        .unwrap_or_default()
}

/// Evaluate a raw `Accept-Encoding` field value for gzip support.
fn accept_encoding_allows_gzip(value: &str) -> bool {
    let mut gzip_allowed: Option<bool> = None;
    let mut wildcard_allowed: Option<bool> = None;

    for item in value.split(',') {
        let mut fields = item.split(';');
        let coding = fields.next().unwrap_or_default().trim();

        // Only an explicit zero weight rejects a coding; a missing or
        // unparsable weight is treated like the default weight.
        let allowed = fields
            .filter_map(|param| param.split_once('='))
            .find(|(key, _)| key.trim().eq_ignore_ascii_case("q"))
            .and_then(|(_, weight)| weight.trim().parse::<f32>().ok())
            .map_or(true, |q| q != 0.0);

        if coding.eq_ignore_ascii_case("gzip") {
            gzip_allowed = gzip_allowed.or(Some(allowed));
        } else if coding == "*" {
            wildcard_allowed = wildcard_allowed.or(Some(allowed));
        }
    }

    // An explicit gzip entry takes precedence over the wildcard.
    gzip_allowed.or(wildcard_allowed).unwrap_or(false)
}
/// Position `reader` at the start of an entry's data inside a zip file.
///
/// `header_offset` is the absolute offset of the entry's local file header.
/// The function seeks there, verifies the 4-byte little-endian signature,
/// reads the 26 header bytes that follow it, and then skips over the file
/// name and extra field whose lengths are stored in that header.
///
/// # Errors
///
/// Returns [`ZipError::UnexpectedHeaderError`] when the signature does not
/// match, and propagates any seek/read failure from `reader`.
pub async fn seek_to_data_offset<R: AsyncRead + AsyncSeek + Unpin>(
    reader: &mut R,
    header_offset: u64,
) -> core::result::Result<(), ZipError> {
    const LFH_SIGNATURE: u32 = 0x4034b50;

    // Jump to the local file header.
    reader.seek(SeekFrom::Start(header_offset)).await?;

    // Verify that what we landed on really is a local file header.
    let mut signature_bytes = [0u8; 4];
    reader.read_exact(&mut signature_bytes).await?;
    let found = u32::from_le_bytes(signature_bytes);
    if found != LFH_SIGNATURE {
        return Err(ZipError::UnexpectedHeaderError(found, LFH_SIGNATURE));
    }

    // Read the rest of the fixed-size header; its last two u16 fields hold
    // the lengths of the variable-size name and extra field that follow.
    let mut fixed_part = [0u8; 26];
    reader.read_exact(&mut fixed_part).await?;
    let name_len = u16::from_le_bytes([fixed_part[22], fixed_part[23]]);
    let extra_len = u16::from_le_bytes([fixed_part[24], fixed_part[25]]);

    // Skip both variable-size fields in one relative seek.
    let skip = i64::from(name_len) + i64::from(extra_len);
    reader.seek(SeekFrom::Current(skip)).await?;
    Ok(())
}
/// Return the file extension of a website path.
///
/// The extension is everything after the last `.` in `path`. `None` is
/// returned when the path contains no `.` at all, or when the text after the
/// last `.` contains a `/` (the dot then belongs to a directory name, not to
/// the final path segment). A path ending in `.` yields `Some("")`.
pub fn site_path_ext(path: &str) -> Option<&str> {
    let (_, candidate) = path.rsplit_once('.')?;
    if candidate.contains('/') {
        None
    } else {
        Some(candidate)
    }
}
/// Guess the MIME type of a website path from its file extension.
///
/// Returns `None` when [`site_path_ext`] finds no extension or when
/// `mime_guess` has no entry for it; otherwise the first guess is returned.
pub fn path_mime(path: &str) -> Option<Mime> {
    let ext = site_path_ext(path)?;
    mime_guess::from_ext(ext).first()
}
/// Build a URL string for `request` from its host and URI.
///
/// When a host can be determined (see `host_from_request`) the result is the
/// host immediately followed by the URI's path component. Without a host the
/// request URI is rendered as-is.
pub fn full_url_from_request(request: &Request) -> String {
    let uri = request.uri();
    match host_from_request(request) {
        Some(host) => format!("{}{}", host, uri.path()),
        None => uri.to_string(),
    }
}
/// Determine the host a request was addressed to.
///
/// Sources are consulted in this order, and the first one that yields a
/// value wins:
/// 1. the `host` field of the `Forwarded` header,
/// 2. the `X-Forwarded-Host` header,
/// 3. the `Host` header.
///
/// A header whose value cannot be read as a string is treated as absent.
fn host_from_request(request: &Request) -> Option<&str> {
    let headers = request.headers();

    if let Some(host) = parse_forwarded(headers) {
        return Some(host);
    }

    let forwarded_host = headers
        .get("X-Forwarded-Host")
        .and_then(|value| value.to_str().ok());
    if let Some(host) = forwarded_host {
        return Some(host);
    }

    headers
        .get(http::header::HOST)
        .and_then(|value| value.to_str().ok())
}
/// Extract the `host` field from the `Forwarded` header, if present.
///
/// Only the first `Forwarded` header and, within it, only the first
/// comma-separated element are inspected. The value is returned trimmed and
/// with surrounding double quotes removed.
fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
    // `HeaderMap::get` yields the first header when several are present.
    let raw = headers.get(header::FORWARDED)?.to_str().ok()?;

    // Restrict the search to the first element of the list.
    let first_element = raw.split(',').next()?;

    // Walk the `key=value` pairs; pairs without `=` are skipped.
    first_element
        .split(';')
        .filter_map(|pair| pair.split_once('='))
        .find(|(key, _)| key.trim().eq_ignore_ascii_case("host"))
        .map(|(_, value)| value.trim().trim_matches('"'))
}
pub fn get_subdomain<'a>(host: &'a str, root_domain: &str) -> Result<&'a str> {
let stripped = host.strip_suffix(root_domain).ok_or(Error::BadRequest(
"host does not end with configured ROOT_DOMAIN".into(),
))?;
Ok(stripped.trim_end_matches('.'))
}
pub fn get_ip_address(request: &Request, real_ip_header: Option<&str>) -> Result<IpAddr> {
match real_ip_header.and_then(|header| {
request
.headers()
.get(header)
.and_then(|val| val.to_str().ok())
.and_then(|val| IpAddr::from_str(val).ok())
}) {
Some(from_header) => Ok(from_header),
None => {
let socket_addr = request
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.ok_or(Error::Other("could get request ip address".into()))?
.0;
Ok(socket_addr.ip())
}
}
}
/// Extension for I/O results where a missing file is not an error.
pub trait IgnoreFileNotFound {
    /// Discard the success value and treat a `NotFound` error as success.
    ///
    /// Every other I/O error is returned unchanged.
    fn ignore_file_not_found(self) -> core::result::Result<(), std::io::Error>;
}

impl<T> IgnoreFileNotFound for core::result::Result<T, std::io::Error> {
    fn ignore_file_not_found(self) -> core::result::Result<(), std::io::Error> {
        match self {
            Ok(_) => Ok(()),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            // Any other failure is real and must reach the caller instead of
            // panicking.
            Err(err) => Err(err),
        }
    }
}
pub fn parse_url(input: &str) -> Result<(&str, std::str::Split<char>)> {
let s = input.trim();
let s = s.strip_prefix("http://").unwrap_or(s);
let s = s.strip_prefix("https://").unwrap_or(s);
let s = s
.split(['?', '#'])
.next()
.ok_or(Error::BadRequest("empty URL".into()))?;
let mut parts = s.split('/');
let host = parts.next().ok_or(Error::BadRequest("empty URL".into()))?;
if host.is_empty() {
return Err(Error::BadRequest("empty URL".into()));
}
if !host
.chars()
.all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '.'))
{
return Err(Error::BadRequest(
"bad URL: domain contains invalid characters".into(),
));
}
Ok((host, parts))
}
/// Scale `time` by 1000 and convert it to an integer, discarding the fraction.
///
/// The conversion uses an `as` cast, so negative values and NaN become `0`
/// and values too large for `u64` saturate at `u64::MAX`.
pub fn time_to_ms(time: f64) -> u64 {
    let millis = time * 1000.0;
    millis as u64
}
/// Get the extension from a filename for selecting a viewer.
///
/// Normally this is the text after the last `.`; a name without any `.` is
/// returned whole. A name that starts with `.` and has nothing between that
/// leading dot (or a run of dots) and its last `.` is treated as a dotfile
/// without extension and is also returned whole (e.g. `.bashrc`).
pub fn filename_ext(filename: &str) -> &str {
    match filename.rsplit_once('.') {
        // Dotfile without extension (e.g. .bashrc): the piece directly in
        // front of the last dot is empty.
        Some((stem, _))
            if filename.starts_with('.') && (stem.is_empty() || stem.ends_with('.')) =>
        {
            filename
        }
        Some((_, ext)) => ext,
        None => filename,
    }
}
/// JSON-serializable status message.
///
/// Despite the name it is also used for success messages (see `ErrorJson::ok`).
/// When turned into a response, `status` is used as the HTTP status code and
/// the struct itself is serialized as the JSON body.
#[derive(Serialize)]
pub struct ErrorJson {
// Numeric HTTP status code.
status: u16,
// Human-readable message.
msg: String,
}
impl ErrorJson {
    /// Create a success message: status `200` together with `msg`.
    pub fn ok<S: Into<String>>(msg: S) -> Self {
        let msg = msg.into();
        Self { status: 200, msg }
    }
}
impl From<Error> for ErrorJson {
    /// Capture an error's status code (as a number) and its display text.
    fn from(value: Error) -> Self {
        let status = value.status().as_u16();
        let msg = value.to_string();
        Self { status, msg }
    }
}
impl From<http::Error> for ErrorJson {
fn from(value: http::Error) -> Self {
Self::from(Error::from(value))
}
}
impl IntoResponse for ErrorJson {
    /// Respond with `self.status` as the HTTP status and `self` as JSON body.
    ///
    /// Panics if the response cannot be built.
    fn into_response(self) -> Response {
        let builder = Response::builder().status(self.status);
        builder.json(&self).unwrap()
    }
}
/// Return the trimmed text between the first `start` marker in `s` and the
/// first `end` marker that follows it.
///
/// Yields `None` when `start` does not occur in `s`, or when no `end` occurs
/// after it.
pub fn extract_delim<'a>(s: &'a str, start: &str, end: &str) -> Option<&'a str> {
    let content_from = s.find(start)? + start.len();
    let tail = &s[content_from..];
    let content_len = tail.find(end)?;
    Some(tail[..content_len].trim())
}
/// Split a leading icon (emoji/symbol run) from the rest of a string.
///
/// The string is scanned for its first character that is ASCII or
/// alphanumeric. If that character is a space and is not the very first
/// character, everything up to and including the space is returned as the
/// prefix and the remainder as the second element. In all other cases the
/// prefix is empty and the whole input is returned as the second element.
pub fn split_icon_prefix(s: &str) -> (&str, &str) {
    let first_plain = s
        .char_indices()
        .find(|&(_, ch)| ch.is_ascii() || ch.is_alphanumeric());

    match first_plain {
        // The space is a single byte, so `pos + 1` is a valid split point.
        Some((pos, ' ')) if pos > 0 => s.split_at(pos + 1),
        _ => ("", s),
    }
}
// Unit tests for the helpers in this file, plus shared test fixtures
// (`TESTFILES`, `setup_cache_dir`) that are visible to the rest of the crate.
#[cfg(test)]
pub(crate) mod tests {
use std::path::{Path, PathBuf};
use http::{header, HeaderMap};
use once_cell::sync::Lazy;
use path_macro::path;
use rstest::rstest;
/// Directory `tests/testfiles` below the crate's manifest directory.
pub static TESTFILES: Lazy<PathBuf> =
Lazy::new(|| path!(env!("CARGO_MANIFEST_DIR") / "tests" / "testfiles"));
/// Directory `sites_data` below `TESTFILES`.
///
/// If it does not exist on first access, `sites/make_zip.sh` is run.
// NOTE(review): presumably the script creates `sites_data`; only a failure to
// launch it is checked, its exit status is not — confirm.
static SITEDIR: Lazy<PathBuf> = Lazy::new(|| {
let sitedir = path!(*TESTFILES / "sites_data");
if !sitedir.is_dir() {
std::process::Command::new(path!(*TESTFILES / "sites" / "make_zip.sh"))
.output()
.unwrap();
}
sitedir
});
/// Copy every regular file found directly in `SITEDIR` into `dir`.
///
/// Panics on any I/O error.
pub fn setup_cache_dir(dir: &Path) {
for entry in std::fs::read_dir(SITEDIR.as_path()).unwrap() {
let entry = entry.unwrap();
if entry.file_type().unwrap().is_file() {
std::fs::copy(entry.path(), path!(dir / entry.file_name())).unwrap();
}
}
}
// `accepts_gzip`: header value -> expected result.
#[rstest]
#[case("", false)]
#[case("br", false)]
#[case("gzip", true)]
#[case("GZIP", true)]
#[case("*", true)]
#[case("deflate, gzip;q=1.0, *;q=0.5", true)]
fn accepts_gzip(#[case] val: &str, #[case] expect: bool) {
let mut hdrs = HeaderMap::new();
hdrs.insert(header::ACCEPT_ENCODING, val.try_into().unwrap());
assert_eq!(super::accepts_gzip(&hdrs), expect);
}
// `get_subdomain` with root domain "localhost"; `None` means an error is expected.
#[rstest]
#[case("localhost", Some(""))]
#[case("test.localhost", Some("test"))]
#[case("example.com", None)]
fn get_subdomain(#[case] host: &str, #[case] expect: Option<&str>) {
assert_eq!(super::get_subdomain(host, "localhost").ok(), expect);
}
// `parse_url`: `host` is `None` when parsing must fail; `parts` is the
// expected path segments, each followed by "/".
#[rstest]
#[case("https://github.com", Some("github.com"), "")]
#[case(
"https://github.com/Theta-Dev/example.project",
Some("github.com"),
"Theta-Dev/example.project/"
)]
#[case(
"https://github.com/Theta-Dev/example.project?key=val#to",
Some("github.com"),
"Theta-Dev/example.project/"
)]
#[case("", None, "")]
fn parse_url(#[case] input: &str, #[case] host: Option<&str>, #[case] parts: &str) {
match super::parse_url(input) {
Ok((h, p)) => {
assert_eq!(Some(h), host);
let parts_joined = p.fold(String::new(), |acc, p| acc + p + "/");
assert_eq!(parts_joined, parts);
}
Err(_) => {
assert!(host.is_none());
}
}
}
// `filename_ext`: filename -> expected extension.
#[rstest]
#[case("hello.txt", "txt")]
#[case(".bashrc", ".bashrc")]
#[case("Makefile", "Makefile")]
#[case("", "")]
fn filename_ext(#[case] filename: &str, #[case] expect: &str) {
let res = super::filename_ext(filename);
assert_eq!(res, expect);
}
// `split_icon_prefix`: input -> expected (icon prefix, remainder).
#[rstest]
#[case("🧪 Test", ("🧪 ", "Test"))]
#[case("🧪👨‍👩‍👦 Test", ("🧪👨‍👩‍👦 ", "Test"))]
#[case("🧪 👨‍👩‍👦 Test", ("🧪 ", "👨‍👩‍👦 Test"))]
#[case("", ("", ""))]
#[case("Test", ("", "Test"))]
#[case("運命 Test", ("", "運命 Test"))]
fn split_icon_prefix(#[case] s: &str, #[case] expect: (&str, &str)) {
let res = super::split_icon_prefix(s);
assert_eq!(res, expect);
}
}