File sharing server for small files https://postit.piggo.space
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
postit/src/main.rs

691 lines
22 KiB

#[macro_use]
extern crate serde_derive;
#[macro_use]
extern crate log;
use crate::config::Config;
use crate::well_known_mime::Mime;
use chrono::{DateTime, Utc};
use clappconfig::{anyhow, AppConfig};
use parking_lot::Mutex;
use rand::rngs::OsRng;
use rand::Rng;
use rouille::{Request, Response, ResponseBody};
use std::borrow::Cow;
use std::collections::HashMap;
use std::fs::OpenOptions;
use std::hash::{Hash, Hasher};
use std::io::Read;
use std::time::Duration;
use siphasher::sip::SipHasher;
mod config;
mod well_known_mime;
/// Header to set expiry (seconds)
const HDR_EXPIRY: &str = "X-Expire";
/// Header to pass secret token for update/delete
const HDR_SECRET: &str = "X-Secret";
/// GET param to set expiry in seconds (as a substitute for the header)
const GET_EXPIRY: &str = "expire";
/// GET param to pass secret token (as a substitute for the header)
const GET_SECRET: &str = "secret";
/// Favicon bytes embedded at compile time, served at `/favicon.ico`
const FAVICON: &[u8] = include_bytes!("embed/favicon.ico");
/// robots.txt bytes embedded at compile time, served at `/robots.txt`
const ROBOTS: &[u8] = include_bytes!("embed/robots.txt");
/// Landing page HTML embedded at compile time, served at `/`
const LANDING_PAGE: &[u8] = include_bytes!("embed/home.html");
/// Post ID (represented as a 16-digit hex string)
type PostId = u64;
/// Write token (represented as a 16-digit hex string)
type Secret = u64;
/// Hash of a data record
type DataHash = u64;
/// Post stored in the repository.
///
/// A post holds only metadata; the uploaded bytes are kept in the repository's
/// data map and are looked up through `hash`.
#[derive(Debug, Serialize, Deserialize)]
struct Post {
/// Content-Type sent back when the post is served
mime: Mime,
/// Hash of the post's data, used as the key into the repository's data map
hash: DataHash,
/// Secret key for editing or deleting
secret: Secret,
/// Expiration timestamp (serialized as a unix timestamp in whole seconds)
#[serde(with = "serde_chrono_datetime_as_unix")]
expires: DateTime<Utc>,
}
impl Post {
    /// Tell whether the expiration timestamp already lies in the past.
    pub fn is_expired(&self) -> bool {
        Utc::now() > self.expires
    }

    /// Remaining lifetime in whole seconds; zero once the post has expired.
    pub fn time_remains(&self) -> Duration {
        let secs_left = self
            .expires
            .signed_duration_since(Utc::now())
            .num_seconds();
        // A negative value means the post is already expired; clamp it so the
        // cast to an unsigned integer cannot wrap around.
        Duration::from_secs(secs_left.max(0) as u64)
    }
}
/// Program entry point.
///
/// Loads the configuration, builds the repository (restoring it from the
/// persistence file when persistence is enabled), drops posts that expired
/// in the meantime and then starts the HTTP server.
fn main() -> anyhow::Result<()> {
    let config = Config::init("postit", "postit.json", env!("CARGO_PKG_VERSION"))?;
    let serve_at = format!("{}:{}", config.host, config.port);
    let store = Mutex::new({
        let mut store = Repository::new(config);
        if store.config.persistence {
            if let Err(e) = store.load() {
                error!("Load failed: {}", e);
            }
        }
        store.gc_expired_posts();
        store
    });
    rouille::start_server(serve_at, move |req| {
        let method = req.method();
        info!("{} {}", method, req.raw_url());

        // Embedded static assets do not touch the repository, so they are
        // answered before the store lock is taken and never wait on it.
        if method == "GET" || method == "HEAD" {
            match req.url().as_str() {
                "/" => {
                    return decorate_response(
                        Response::from_data("text/html", LANDING_PAGE).with_public_cache(86400),
                    );
                }
                "/favicon.ico" => {
                    return decorate_response(
                        Response::from_data("image/vnd.microsoft.icon", FAVICON)
                            .with_public_cache(86400 * 7),
                    );
                }
                "/robots.txt" => {
                    return decorate_response(
                        Response::from_data("text/plain", ROBOTS).with_public_cache(86400),
                    );
                }
                _ => {}
            }
        }

        // Everything below reads or modifies the repository.
        let mut store_w = store.lock();
        store_w.gc_expired_posts_if_needed();
        let resp = match method {
            "POST" | "PUT" => store_w.serve_post_put(req),
            "GET" | "HEAD" => store_w.serve_get(req),
            "DELETE" => store_w.serve_delete(req),
            _ => rouille::Response::empty_400(),
        };
        if store_w.config.persistence {
            if let Err(e) = store_w.persist_if_needed() {
                error!("Store failed: {}", e);
            }
        }
        if resp.is_error() {
            warn!("Error resp: {}", resp.status_code);
        }
        decorate_response(resp)
    });
}
/// Attach the headers every response carries: the server name (replacing any
/// `Server` header already present), a permissive CORS header and the crate version.
fn decorate_response(resp: Response) -> Response {
    const EXTRA_HEADERS: [(&str, &str); 3] = [
        ("Server", "postit.rs"),
        ("Access-Control-Allow-Origin", "*"),
        ("X-Version", env!("CARGO_PKG_VERSION")),
    ];
    EXTRA_HEADERS
        .iter()
        .fold(resp.without_header("Server"), |acc, &(name, value)| {
            acc.with_additional_header(name, value)
        })
}
/// Post metadata keyed by post ID
type PostsMap = HashMap<PostId, Post>;
/// Stored data keyed by data hash; the value is `(use_count, data)`
type DataMap = HashMap<DataHash, (usize, Vec<u8>)>;
/// In-memory store of all posts and their (deduplicated) data.
///
/// `posts` and `data` are serialized together with `last_gc_time`;
/// `config` and `dirty` are skipped by serde.
#[derive(Debug, Serialize, Deserialize)]
struct Repository {
/// Server configuration (not serialized)
#[serde(skip)]
config: Config,
/// Flag that the repository needs saving
#[serde(skip)]
dirty: bool,
/// Stored posts
posts: PostsMap,
/// Post data - (use_count, data)
data: DataMap,
/// Time of last expired posts GC
#[serde(with = "serde_chrono_datetime_as_unix")]
last_gc_time: DateTime<Utc>,
}
impl Repository {
/// New instance
fn new(config: Config) -> Self {
Repository {
config,
dirty: false,
posts: Default::default(),
data: Default::default(),
last_gc_time: Utc::now(),
}
}
/// Write the repository to disk, but only when it was modified since the last save.
fn persist_if_needed(&mut self) -> anyhow::Result<()> {
    if !self.dirty {
        return Ok(());
    }
    self.persist()
}
/// Store to a file
///
/// Truncates (or creates) `config.persist_file` and writes the repository
/// into it with bincode, deflate-compressed when `config.compression` is set.
///
/// The `dirty` flag is cleared only after everything was written out, so a
/// failed attempt is retried by the next `persist_if_needed` call instead of
/// being silently forgotten.
///
/// # Errors
/// Returns an error if the file cannot be opened, if serialization fails, or
/// if the final flush / end of the compressed stream cannot be written.
fn persist(&mut self) -> anyhow::Result<()> {
    debug!("Persist to file: {}", self.config.persist_file);
    let file = OpenOptions::new()
        .truncate(true)
        .write(true)
        .create(true)
        .open(&self.config.persist_file)?;
    if self.config.compression {
        let mut flate = flate2::write::DeflateEncoder::new(file, flate2::Compression::best());
        bincode::serialize_into(&mut flate, &*self)?;
        // Finish explicitly: merely dropping the encoder would hide any error
        // that occurs while the tail of the compressed stream is written.
        flate.finish()?;
    } else {
        // Buffer the raw file so serialization does not write to it piecemeal.
        let mut writer = std::io::BufWriter::new(file);
        bincode::serialize_into(&mut writer, &*self)?;
        std::io::Write::flush(&mut writer)?;
    }
    self.dirty = false;
    Ok(())
}
/// Load from a file
///
/// Replaces the posts, data and last GC time of this repository with the
/// contents of `config.persist_file` (deflate-compressed when
/// `config.compression` is set). The current configuration is kept, and the
/// repository is marked clean afterwards.
///
/// # Errors
/// Returns an error if the file cannot be opened or its contents cannot be
/// deserialized; the repository is left untouched in that case.
fn load(&mut self) -> anyhow::Result<()> {
    debug!("Load from file: {}", self.config.persist_file);
    let file = OpenOptions::new()
        .read(true)
        .open(&self.config.persist_file)?;
    let mut loaded: Repository = if self.config.compression {
        let flate = flate2::read::DeflateDecoder::new(file);
        bincode::deserialize_from(flate)?
    } else {
        // Buffer the raw file so deserialization does not read it piecemeal.
        bincode::deserialize_from(std::io::BufReader::new(file))?
    };
    // `config` is skipped by serde, so the loaded value carries a placeholder;
    // move the live configuration over instead of cloning it.
    std::mem::swap(&mut loaded.config, &mut self.config);
    loaded.dirty = false;
    *self = loaded;
    Ok(())
}
/// Handle a DELETE request: the URL must name an existing post and the
/// request must carry its valid secret token.
fn serve_delete(&mut self, req: &Request) -> Response {
    match self.request_to_post_id(req, true) {
        Err(resp) => resp,
        Ok(None) => error_with_text(400, "File ID required."),
        Ok(Some(post_id)) => {
            self.delete_post(post_id);
            Response::text("Deleted.")
        }
    }
}
/// Serve a POST or PUT request
///
/// POST inserts a new record
/// PUT updates a record
fn serve_post_put(&mut self, req: &Request) -> Response {
let is_post = req.method() == "POST";
let is_put = req.method() == "PUT";
// Post ID is empty for POST, set for PUT
let post_id = match self.request_to_post_id(req, true) {
Ok(pid) => {
if is_put && pid.is_none() {
warn!("PUT without ID!");
return error_with_text(400, "PUT requires a file ID!");
} else if is_post && pid.is_some() {
warn!("POST with ID!");
return error_with_text(400, "Use PUT to update a file!");
}
pid
}
Err(resp) => return resp,
};
debug!("Submit new data, post ID: {:?}", post_id);
let mut data = vec![];
if let Some(body) = req.data() {
// Read up to 1 byte past the limit to catch too large uploads.
// We can't reply on the "Length" field, which is not present with chunked encoding.
body.take(self.config.max_file_size as u64 + 1)
.read_to_end(&mut data)
.unwrap();
if is_post && data.len() == 0 {
warn!("Empty body!");
return error_with_text(400, "Empty body!");
} else if data.len() > self.config.max_file_size {
warn!("Upload too large!");
return empty_error(413);
}
} else {
// Should not be possible
panic!("Req data None!");
}
// Convert "application/x-www-form-urlencoded" to text/plain (CURL uses this)
// NOTE: rouille does NOT parse urlencoded, we will serve the encoded format back if really used.
let mime = match req.header("Content-Type") {
None => None,
Some("application/x-www-form-urlencoded") => None,
Some(v) => Some(v),
};
let expiry = req.get_param(GET_EXPIRY);
let mut expiry_s = expiry.as_ref().map(|s| s.as_str());
if expiry_s.is_none() {
expiry_s = req.header(HDR_EXPIRY);
}
let expiry = match expiry_s {
Some(text) => match text.parse() {
Ok(v) => {
let dur = Duration::from_secs(v);
if dur > self.config.max_expiry {
return error_with_text(
400,
format!(
"Expiration time {} out of allowed range 0-{} s",
v,
self.config.max_expiry.as_secs()
),
);
}
Some(dur)
}
Err(_) => {
return error_with_text(
400,
"Malformed expiration, use relative time in seconds.",
);
}
},
None => None,
};
let the_id;
let resp = if let Some(id) = post_id {
// UPDATE
self.update(id, data, mime, expiry);
the_id = id;
Response::text("Updated OK.")
} else {
// INSERT
let (id, secret) = self.insert(data, mime, expiry.unwrap_or(self.config.default_expiry));
the_id = id;
Response::text(format!("{:016x}", the_id))
.with_additional_header(HDR_SECRET, format!("{:016x}", secret))
};
// add the X-Expires header to the response
let post = self.posts.get(&the_id).unwrap();
resp.with_additional_header(HDR_EXPIRY, post.time_remains().as_secs().to_string())
}
/// Serve a GET request
fn serve_get(&mut self, req: &Request) -> Response {
let post_id = match self.request_to_post_id(req, false) {
Ok(Some(pid)) => pid,
Ok(None) => return error_with_text(400, "File ID required."),
Err(resp) => return resp,
};
if let Some(post) = self.posts.get(&post_id) {
if post.is_expired() {
warn!("GET of expired post!");
Response::empty_404()
} else {
let data = match self.data.get(&post.hash) {
Some((_uses, data)) => data,
None => {
error!("No matching data!");
return error_with_text(500, "File data lost.");
}
};
let seconds_remains = post.expires.signed_duration_since(Utc::now())
.num_seconds();
let headers = vec![
(
"Content-Type".into(),
format!("{}; charset=utf8", post.mime).into(),
),
(
"Cache-Control".into(),
format!("public, max-age={}", seconds_remains).into()
),
(
HDR_EXPIRY.into(),
seconds_remains.to_string().into()
)
];
Response {
status_code: 200,
headers,
data: ResponseBody::from_data(data.clone()),
upgrade: None,
}
}
} else {
warn!("No such post!");
Response::empty_404()
}
}
/// Extract post ID from a request.
///
/// The ID is the URL path with surrounding slashes removed and must consist
/// of exactly 16 hex digits. An empty path yields `Ok(None)`.
///
/// if `check_secret` is true, ensure the request carries a valid write token
/// for the post ID, either in the `secret` GET parameter or in the `X-Secret`
/// header (the GET parameter wins when both are present). In that mode the
/// post must also exist and must not be expired.
///
/// On failure, the `Err` variant holds the response to send to the client.
fn request_to_post_id(
    &self,
    req: &Request,
    check_secret: bool,
) -> Result<Option<PostId>, Response> {
    let url = req.url();
    let stripped = url.trim_matches('/');
    if stripped.is_empty() {
        debug!("No ID given");
        return Ok(None);
    }
    if stripped.len() != 16 {
        warn!("Bad ID len!");
        return Err(Response::empty_404());
    }
    let id = match u64::from_str_radix(stripped, 16) {
        // `from_str_radix` also accepts a leading '+', which would give every
        // post extra alias URLs; insist on hex digits only.
        Ok(parsed) if stripped.bytes().all(|b| b.is_ascii_hexdigit()) => parsed,
        _ => {
            warn!("ID parsing error: {}", stripped);
            return Err(Response::empty_404());
        }
    };
    if check_secret {
        // Check the write token
        match self.posts.get(&id) {
            None => {
                // IDs are logged in the same hex form that appears in URLs.
                warn!("ID {:016x} does not exist!", id);
                return Err(error_with_text(404, "No file with this ID!"));
            }
            Some(post) => {
                if post.is_expired() {
                    warn!("Access of expired file {:016x}!", id);
                    return Err(error_with_text(404, "No file with this ID!"));
                }
                let secret = req.get_param(GET_SECRET);
                let mut secret_str = secret.as_ref().map(|s| s.as_str());
                if secret_str.is_none() {
                    secret_str = req.header(HDR_SECRET);
                }
                let secret: u64 = match secret_str.map(|v| u64::from_str_radix(v, 16)) {
                    Some(Ok(token)) => token,
                    None => {
                        warn!("Missing secret token!");
                        return Err(error_with_text(400, "Secret token required!"));
                    }
                    Some(Err(e)) => {
                        warn!("Token parse error: {:?}", e);
                        return Err(error_with_text(400, "Bad secret token format!"));
                    }
                };
                if post.secret != secret {
                    warn!("Secret token mismatch");
                    return Err(error_with_text(401, "Invalid secret token!"));
                }
            }
        }
    }
    // secret is now validated (if requested) and we got an ID
    Ok(Some(id))
}
/// Run the expired-post GC when more than `config.expired_gc_interval` has
/// passed since the previous run, and restart the interval afterwards.
fn gc_expired_posts_if_needed(&mut self) {
    // A negative elapsed time cannot be converted and counts as zero.
    let since_last_gc = Utc::now()
        .signed_duration_since(self.last_gc_time)
        .to_std()
        .unwrap_or_default();
    if since_last_gc <= self.config.expired_gc_interval {
        return;
    }
    self.gc_expired_posts();
    self.last_gc_time = Utc::now();
}
/// Drop expired posts
fn gc_expired_posts(&mut self) {
debug!("GC expired uploads");
let mut to_rm = vec![];
for post in &self.posts {
if post.1.is_expired() {
to_rm.push(*post.0);
}
}
if !to_rm.is_empty() {
self.dirty = true;
}
for id in to_rm {
debug!("Drop post ID {:016x}", id);
if let Some(post) = self.posts.remove(&id) {
Self::drop_data_or_decrement_rc(&mut self.data, post.hash);
}
}
}
/// Get hash of a byte buffer (for deduplication)
///
/// Takes a slice so callers are not forced to hold a `Vec`; `&Vec<u8>`
/// arguments still coerce, and a `Vec<u8>` hashes exactly like its slice,
/// so the produced values are the same as before.
fn hash_data(data: &[u8]) -> DataHash {
    let mut hasher = SipHasher::new();
    data.hash(&mut hasher);
    hasher.finish()
}
/// Store a data buffer under a given hash.
/// If the buffer is already present in the repository, increment its use count.
fn store_data_or_increment_rc(data_map: &mut DataMap, hash: u64, data: Vec<u8>) {
match data_map.get_mut(&hash) {
None => {
debug!("Store new data hash #{:016x}", hash);
data_map.insert(hash, (1, data));
}
Some(entry) => {
debug!("Link new use of data hash #{:016x}", hash);
entry.0 += 1; // increment use counter
}
}
}
/// Drop a data record with the given hash, or decrement its use count if there are other uses
fn drop_data_or_decrement_rc(data_map: &mut DataMap, hash: u64) {
if let Some(old_data) = data_map.get_mut(&hash) {
if old_data.0 > 1 {
old_data.0 -= 1;
debug!(
"Unlink use of data hash #{:016x} ({} remain)",
hash, old_data.0
);
} else {
debug!("Drop data hash #{:016x}", hash);
data_map.remove(&hash);
}
}
}
/// Create a new post from `data`.
///
/// When `mime` is `None` the content type is detected from the data itself.
/// The post expires `expires` from now. Returns the randomly generated post
/// ID together with the secret token needed to update or delete the post.
fn insert(&mut self, data: Vec<u8>, mime: Option<&str>, expires: Duration) -> (PostId, Secret) {
    info!(
        "Insert post with data of len {} bytes, mime {}, expiry {:?}",
        data.len(),
        mime.unwrap_or("unspecified"),
        expires
    );

    let hash = Self::hash_data(&data);
    let mime = if let Some(explicit) = mime {
        Mime::from(explicit)
    } else {
        Mime::from(tree_magic::from_u8(&data))
    };
    Self::store_data_or_increment_rc(&mut self.data, hash, data);

    // Draw random IDs until one is found that is not taken yet.
    let mut post_id: PostId = OsRng.gen();
    while self.posts.contains_key(&post_id) {
        post_id = OsRng.gen();
    }
    let secret: Secret = OsRng.gen();

    debug!("File ID = #{:016x} (http://{}:{}/{:016x})", post_id, self.config.host, self.config.port, post_id);
    debug!("Data hash = #{:016x}, mime {}", hash, mime);
    debug!("Secret = #{:016x}", secret);

    // this is safe unless mis-configured
    let lifetime = chrono::Duration::from_std(expires).unwrap();
    let post = Post {
        mime,
        hash,
        secret,
        expires: Utc::now() + lifetime,
    };
    self.posts.insert(post_id, post);
    self.dirty = true;
    (post_id, secret)
}
/// Modify the post stored under `id`; the post must exist.
///
/// Empty `data` keeps the current content, and `None` for `mime` or
/// `expires` keeps the respective value. The repository is marked dirty when
/// the content or the content type changed, or when an expiry was given.
fn update(&mut self, id: PostId, data: Vec<u8>, mime: Option<&str>, expires: Option<Duration>) {
    info!(
        "Update post id #{:016x} with data of len {} bytes, mime {}, expiry {}",
        id,
        data.len(),
        mime.unwrap_or("unchanged"),
        match expires {
            Some(v) => Cow::Owned(format!("{:?}", v)),
            None => Cow::Borrowed("unchanged"),
        }
    );

    // post existence was checked before
    let post = self.posts.get_mut(&id).unwrap();

    if !data.is_empty() {
        let hash = Self::hash_data(&data);
        if hash == post.hash {
            debug!("Data hash = #{:016x} (no change)", hash);
        } else {
            debug!("Data hash = #{:016x} (content changed)", hash);
            // Release the old buffer before linking the new one.
            Self::drop_data_or_decrement_rc(&mut self.data, post.hash);
            Self::store_data_or_increment_rc(&mut self.data, hash, data);
            post.hash = hash;
            self.dirty = true;
        }
    }

    if let Some(requested) = mime {
        let new_mime = Mime::from(requested);
        if post.mime != new_mime {
            debug!("Content type changed to {}", requested);
            post.mime = new_mime;
            self.dirty = true;
        }
    }

    if let Some(exp) = expires {
        debug!("Expiration changed to {:?} from now", exp);
        // this is safe unless mis-configured
        let lifetime = chrono::Duration::from_std(exp).unwrap();
        post.expires = Utc::now() + lifetime;
        self.dirty = true;
    }
}
/// Remove the post stored under `id` and release its data; the post must exist.
fn delete_post(&mut self, id: PostId) {
    info!("Delete post id #{:016x}", id);
    // post existence was checked before
    let removed = self.posts.remove(&id).unwrap();
    self.dirty = true;
    Self::drop_data_or_decrement_rc(&mut self.data, removed.hash);
}
}
/// Serialize chrono unix timestamp as seconds
///
/// Intended for `#[serde(with = "...")]` on `DateTime<Utc>` fields.
/// Sub-second precision is dropped on serialization.
mod serde_chrono_datetime_as_unix {
    use chrono::{DateTime, NaiveDateTime, Utc};
    use serde::{self, Deserialize, Deserializer, Serializer};

    /// Write the timestamp as an `i64` count of seconds since the unix epoch.
    pub fn serialize<S>(value: &DateTime<Utc>, se: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        se.serialize_i64(value.timestamp())
    }

    /// Read an `i64` count of seconds since the unix epoch.
    ///
    /// A value that cannot be represented as a date (e.g. from a corrupted
    /// file) is reported as a deserialization error instead of panicking.
    pub fn deserialize<'de, D>(de: D) -> Result<DateTime<Utc>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let ts: i64 = i64::deserialize(de)?;
        let naive = NaiveDateTime::from_timestamp_opt(ts, 0).ok_or_else(|| {
            <D::Error as serde::de::Error>::custom(format!(
                "unix timestamp {} is out of range",
                ts
            ))
        })?;
        Ok(DateTime::from_utc(naive, Utc))
    }
}
fn error_with_text(code: u16, text: impl Into<String>) -> Response {
Response {
status_code: code,
headers: vec![("Content-Type".into(), "text/plain; charset=utf8".into())],
data: rouille::ResponseBody::from_string(text),
upgrade: None,
}
}
fn empty_error(code: u16) -> Response {
Response {
status_code: code,
headers: vec![("Content-Type".into(), "text/plain; charset=utf8".into())],
data: rouille::ResponseBody::empty(),
upgrade: None,
}
}