Compare commits
No commits in common. 'master' and 'jsons' have entirely different histories.
@ -1,13 +0,0 @@ |
||||
# [0.3.0] |
||||
|
||||
- Update dependencies |
||||
- Added new example |
||||
- Port to rocket `0.5.0-rc.2` |
||||
|
||||
# [0.2.2] |
||||
|
||||
- Update dependencies |
||||
|
||||
# [0.2.1] |
||||
|
||||
- change from `thread_rng` to `OsRng` for better session ID entropy |
@ -1,21 +1,15 @@ |
||||
[package] |
||||
name = "rocket_session" |
||||
version = "0.3.0" |
||||
version = "0.1.0" |
||||
authors = ["Ondřej Hruška <ondra@ondrovo.com>"] |
||||
edition = "2021" |
||||
license = "MIT" |
||||
description = "Rocket.rs plug-in for cookie-based sessions holding arbitrary data" |
||||
repository = "https://git.ondrovo.com/packages/rocket_session" |
||||
readme = "README.md" |
||||
keywords = ["rocket", "rocket-rs", "session", "cookie"] |
||||
categories = [ |
||||
"web-programming", |
||||
"web-programming::http-server" |
||||
] |
||||
edition = "2018" |
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html |
||||
|
||||
[dependencies] |
||||
rand = "0.8" |
||||
rocket = "0.5.0-rc.2" |
||||
parking_lot = "0.12" |
||||
serde = { version = "1.0", features = ["derive"] } |
||||
serde_json = { version="1.0", features= ["preserve_order"] } |
||||
json_dotpath = "0.1.2" |
||||
rand = "0.7.2" |
||||
rocket = { version="0.4.2", default-features = false} |
||||
parking_lot = "0.10.0" |
||||
|
@ -1,124 +1,33 @@ |
||||
# Sessions for Rocket.rs |
||||
|
||||
Adding cookie-based sessions to a rocket application is extremely simple with this crate. |
||||
|
||||
Sessions are used to share data between related requests, such as user authentication, shopping basket, |
||||
form values that failed validation for re-filling, etc. |
||||
|
||||
## Configuration |
||||
|
||||
The implementation is generic to support any type as session data: a custom struct, `String`, |
||||
`HashMap`, or perhaps `serde_json::Value`. You're free to choose. |
||||
|
||||
The session lifetime, cookie name, and other parameters can be configured by calling chained |
||||
methods on the fairing. When a session expires, the data associated with it is dropped. |
||||
|
||||
Example: `Session::fairing().with_lifetime(Duration::from_secs(15))` |
||||
|
||||
## Usage |
||||
|
||||
To use session in a route, first make sure you have the fairing attached by calling |
||||
`rocket.attach(Session::fairing())` at start-up, and then add something like `session : Session` |
||||
to the parameter list of your route(s). Everything else--session init, expiration, cookie |
||||
management--is done for you behind the scenes. |
||||
|
||||
Session data is accessed in a closure run in the session context, using the `session.tap()` |
||||
method. This closure runs inside a per-session mutex, avoiding simultaneous mutation |
||||
from different requests. Try to *avoid lengthy operations inside the closure*, |
||||
as it effectively blocks any other request to session-enabled routes by the client. |
||||
|
||||
Every request to a session-enabled route extends the session's lifetime to the full |
||||
configured time (defaults to 1 hour). Automatic clean-up removes expired sessions to make sure |
||||
the session list does not waste memory. |
||||
|
||||
## Examples |
||||
|
||||
More examples are in the "examples" folder - run with `cargo run --example=NAME` |
||||
|
||||
### Basic Example |
||||
|
||||
This simple example uses u64 as the session variable; note that it can be a struct, map, or anything else, |
||||
it just needs to implement `Send + Sync + Default`. |
||||
Adding cookie-based sessions to a rocket application is extremely simple: |
||||
|
||||
```rust |
||||
#[macro_use] |
||||
extern crate rocket; |
||||
#![feature(proc_macro_hygiene, decl_macro)] |
||||
#[macro_use] extern crate rocket; |
||||
|
||||
type Session<'a> = rocket_session::Session<'a, u64>; |
||||
use rocket_session::Session; |
||||
use std::time::Duration; |
||||
|
||||
#[launch] |
||||
fn rocket() -> _ { |
||||
rocket::build() |
||||
.attach(Session::fairing()) |
||||
fn main() { |
||||
rocket::ignite() |
||||
.attach(Session::fairing(Duration::from_secs(3600))) |
||||
.mount("/", routes![index]) |
||||
.launch(); |
||||
} |
||||
|
||||
#[get("/")] |
||||
fn index(session: Session) -> String { |
||||
let count = session.tap(|n| { |
||||
// Change the stored value (it is &mut) |
||||
*n += 1; |
||||
|
||||
// Return something to the caller. |
||||
// This can be any type, 'tap' is generic. |
||||
*n |
||||
}); |
||||
let mut count: usize = session.get_or_default("count"); |
||||
count += 1; |
||||
session.set("count", count); |
||||
|
||||
format!("{} visits", count) |
||||
} |
||||
|
||||
``` |
||||
|
||||
## Extending Session by a Trait |
||||
|
||||
The `.tap()` method is powerful, but sometimes you may wish for something more convenient. |
||||
|
||||
Here is an example of using a custom trait and the `json_dotpath` crate to implement |
||||
a polymorphic store based on serde serialization. |
||||
Anything serializable can be stored in the session, just make sure to unpack it to the right type. |
||||
|
||||
Note that this approach is prone to data races if you're accessing the session object multiple times per request, |
||||
since every method contains its own `.tap()`. It may be safer to simply call the `.dot_*()` methods manually in one shared closure. |
||||
|
||||
```rust |
||||
use serde_json::Value; |
||||
use serde::de::DeserializeOwned; |
||||
use serde::Serialize; |
||||
use json_dotpath::DotPaths; |
||||
|
||||
pub type Session<'a> = rocket_session::Session<'a, serde_json::Map<String, Value>>; |
||||
|
||||
pub trait SessionAccess { |
||||
fn get<T: DeserializeOwned>(&self, path: &str) -> Option<T>; |
||||
|
||||
fn take<T: DeserializeOwned>(&self, path: &str) -> Option<T>; |
||||
|
||||
fn replace<O: DeserializeOwned, N: Serialize>(&self, path: &str, new: N) -> Option<O>; |
||||
|
||||
fn set<T: Serialize>(&self, path: &str, value: T); |
||||
|
||||
fn remove(&self, path: &str) -> bool; |
||||
} |
||||
|
||||
impl<'a> SessionAccess for Session<'a> { |
||||
fn get<T: DeserializeOwned>(&self, path: &str) -> Option<T> { |
||||
self.tap(|data| data.dot_get(path)) |
||||
} |
||||
|
||||
fn take<T: DeserializeOwned>(&self, path: &str) -> Option<T> { |
||||
self.tap(|data| data.dot_take(path)) |
||||
} |
||||
|
||||
fn replace<O: DeserializeOwned, N: Serialize>(&self, path: &str, new: N) -> Option<O> { |
||||
self.tap(|data| data.dot_replace(path, new)) |
||||
} |
||||
|
||||
fn set<T: Serialize>(&self, path: &str, value: T) { |
||||
self.tap(|data| data.dot_set(path, value)); |
||||
} |
||||
|
||||
fn remove(&self, path: &str) -> bool { |
||||
self.tap(|data| data.dot_remove(path)) |
||||
} |
||||
} |
||||
``` |
||||
The session driver internally uses `serde_json::Value` and the `json_dotpath` crate. |
||||
Therefore, it's possible to use dotted paths and store the session data in a more structured way. |
||||
|
||||
|
@ -1,59 +0,0 @@ |
||||
#[macro_use] |
||||
extern crate rocket; |
||||
|
||||
use rocket::response::content::RawHtml; |
||||
use rocket::response::Redirect; |
||||
|
||||
type Session<'a> = rocket_session::Session<'a, Vec<String>>; |
||||
|
||||
#[launch] |
||||
fn rocket() -> _ { |
||||
rocket::build() |
||||
.attach(Session::fairing()) |
||||
.mount("/", routes![index, add, remove]) |
||||
} |
||||
|
||||
#[get("/")] |
||||
fn index(session: Session) -> RawHtml<String> { |
||||
let mut page = String::new(); |
||||
page.push_str( |
||||
r#" |
||||
<!DOCTYPE html> |
||||
<h1>My Dogs</h1> |
||||
|
||||
<form method="POST" action="/add"> |
||||
Add Dog: <input type="text" name="name"> <input type="submit" value="Add"> |
||||
</form> |
||||
|
||||
<ul> |
||||
"#, |
||||
); |
||||
session.tap(|sess| { |
||||
for (n, dog) in sess.iter().enumerate() { |
||||
page.push_str(&format!( |
||||
r#"<li>🐶 {} <a href="/remove/{}">Remove</a></li>"#, |
||||
dog, n |
||||
)); |
||||
} |
||||
}); |
||||
page.push_str("</ul>"); |
||||
RawHtml(page) |
||||
} |
||||
|
||||
#[post("/add", data = "<dog>")] |
||||
fn add(session: Session, dog: String) -> Redirect { |
||||
session.tap(move |sess| { |
||||
sess.push(dog); |
||||
}); |
||||
Redirect::found("/") |
||||
} |
||||
|
||||
#[get("/remove/<dog>")] |
||||
fn remove(session: Session, dog: usize) -> Redirect { |
||||
session.tap(|sess| { |
||||
if dog < sess.len() { |
||||
sess.remove(dog); |
||||
} |
||||
}); |
||||
Redirect::found("/") |
||||
} |
@ -1,28 +0,0 @@ |
||||
#[macro_use] |
||||
extern crate rocket; |
||||
|
||||
use std::time::Duration; |
||||
|
||||
type Session<'a> = rocket_session::Session<'a, u64>; |
||||
|
||||
#[launch] |
||||
fn rocket() -> _ { |
||||
// This session expires in 15 seconds as a demonstration of session configuration
|
||||
rocket::build() |
||||
.attach(Session::fairing().with_lifetime(Duration::from_secs(15))) |
||||
.mount("/", routes![index]) |
||||
} |
||||
|
||||
#[get("/")] |
||||
fn index(session: Session) -> String { |
||||
let count = session.tap(|n| { |
||||
// Change the stored value (it is &mut)
|
||||
*n += 1; |
||||
|
||||
// Return something to the caller.
|
||||
// This can be any type, 'tap' is generic.
|
||||
*n |
||||
}); |
||||
|
||||
format!("{} visits", count) |
||||
} |
@ -1,73 +0,0 @@ |
||||
//! This demo is a page visit counter, with a custom cookie name, length, and expiry time.
|
||||
//!
|
||||
//! The expiry time is set to 10 seconds to illustrate how a session is cleared if inactive.
|
||||
|
||||
#[macro_use] |
||||
extern crate rocket; |
||||
|
||||
use rocket::response::content::RawHtml; |
||||
use std::time::Duration; |
||||
|
||||
#[derive(Default, Clone)] |
||||
struct SessionData { |
||||
visits1: usize, |
||||
visits2: usize, |
||||
} |
||||
|
||||
// It's convenient to define a type alias:
|
||||
type Session<'a> = rocket_session::Session<'a, SessionData>; |
||||
|
||||
#[launch] |
||||
fn rocket() -> _ { |
||||
rocket::build() |
||||
.attach( |
||||
Session::fairing() |
||||
// 10 seconds of inactivity until session expires
|
||||
// (wait 10s and refresh, the numbers will reset)
|
||||
.with_lifetime(Duration::from_secs(10)) |
||||
// custom cookie name and length
|
||||
.with_cookie_name("my_cookie") |
||||
.with_cookie_len(20), |
||||
) |
||||
.mount("/", routes![index, about]) |
||||
} |
||||
|
||||
#[get("/")] |
||||
fn index(session: Session) -> RawHtml<String> { |
||||
// Here we build the entire response inside the 'tap' closure.
|
||||
|
||||
// While inside, the session is locked to parallel changes, e.g.
|
||||
// from a different browser tab.
|
||||
session.tap(|sess| { |
||||
sess.visits1 += 1; |
||||
|
||||
RawHtml(format!( |
||||
r##" |
||||
<!DOCTYPE html> |
||||
<h1>Home</h1> |
||||
<a href="/">Refresh</a> • <a href="/about/">go to About</a> |
||||
<p>Visits: home {}, about {}</p> |
||||
"##, |
||||
sess.visits1, sess.visits2 |
||||
)) |
||||
}) |
||||
} |
||||
|
||||
#[get("/about")] |
||||
fn about(session: Session) -> RawHtml<String> { |
||||
// Here we return a value from the tap function and use it below
|
||||
let count = session.tap(|sess| { |
||||
sess.visits2 += 1; |
||||
sess.visits2 |
||||
}); |
||||
|
||||
RawHtml(format!( |
||||
r##" |
||||
<!DOCTYPE html> |
||||
<h1>About</h1> |
||||
<a href="/about">Refresh</a> • <a href="/">go home</a> |
||||
<p>Page visits: {}</p> |
||||
"##, |
||||
count |
||||
)) |
||||
} |
@ -1,330 +1,2 @@ |
||||
use std::borrow::Cow; |
||||
use std::collections::HashMap; |
||||
use std::fmt::{self, Display, Formatter}; |
||||
use std::marker::PhantomData; |
||||
use std::ops::Add; |
||||
use std::time::{Duration, Instant}; |
||||
|
||||
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard}; |
||||
use rand::{rngs::OsRng, Rng}; |
||||
use rocket::{ |
||||
fairing::{self, Fairing, Info}, |
||||
http::{Cookie, Status}, |
||||
outcome::Outcome, |
||||
request::FromRequest, |
||||
Build, Request, Response, Rocket, State, |
||||
}; |
||||
|
||||
/// Session store (shared state)
|
||||
#[derive(Debug)] |
||||
pub struct SessionStore<D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
/// The internally mutable map of sessions
|
||||
inner: RwLock<StoreInner<D>>, |
||||
// Session config
|
||||
config: SessionConfig, |
||||
} |
||||
|
||||
/// Session config object
|
||||
#[derive(Debug, Clone)] |
||||
struct SessionConfig { |
||||
/// Sessions lifespan
|
||||
lifespan: Duration, |
||||
/// Session cookie name
|
||||
cookie_name: Cow<'static, str>, |
||||
/// Session cookie path
|
||||
cookie_path: Cow<'static, str>, |
||||
/// Session ID character length
|
||||
cookie_len: usize, |
||||
} |
||||
|
||||
impl Default for SessionConfig { |
||||
fn default() -> Self { |
||||
Self { |
||||
lifespan: Duration::from_secs(3600), |
||||
cookie_name: "rocket_session".into(), |
||||
cookie_path: "/".into(), |
||||
cookie_len: 16, |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Mutable object stored inside SessionStore behind a RwLock
|
||||
#[derive(Debug)] |
||||
struct StoreInner<D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
sessions: HashMap<String, Mutex<SessionInstance<D>>>, |
||||
last_expiry_sweep: Instant, |
||||
} |
||||
|
||||
impl<D> Default for StoreInner<D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
fn default() -> Self { |
||||
Self { |
||||
sessions: Default::default(), |
||||
// the first expiry sweep is scheduled one lifetime from start-up
|
||||
last_expiry_sweep: Instant::now(), |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Session, as stored in the sessions store
|
||||
#[derive(Debug)] |
||||
struct SessionInstance<D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
/// Data object
|
||||
data: D, |
||||
/// Expiry
|
||||
expires: Instant, |
||||
} |
||||
|
||||
/// Session ID newtype for rocket's "local_cache"
|
||||
#[derive(Clone, Debug)] |
||||
struct SessionID(String); |
||||
|
||||
impl SessionID { |
||||
fn as_str(&self) -> &str { |
||||
self.0.as_str() |
||||
} |
||||
} |
||||
|
||||
impl Display for SessionID { |
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { |
||||
f.write_str(&self.0) |
||||
} |
||||
} |
||||
|
||||
/// Session instance
|
||||
///
|
||||
/// To access the active session, simply add it as an argument to a route function.
|
||||
///
|
||||
/// Sessions are started, restored, or expired in the `FromRequest::from_request()` method
|
||||
/// when a `Session` is prepared for one of the route functions.
|
||||
#[derive(Debug)] |
||||
pub struct Session<'a, D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
/// The shared state reference
|
||||
store: &'a State<SessionStore<D>>, |
||||
/// Session ID
|
||||
id: &'a SessionID, |
||||
} |
||||
|
||||
#[rocket::async_trait] |
||||
impl<'r, D> FromRequest<'r> for Session<'r, D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
type Error = (); |
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, (Status, Self::Error), ()> { |
||||
let store = request.guard::<&State<SessionStore<D>>>().await.unwrap(); |
||||
Outcome::Success(Session { |
||||
id: request.local_cache(|| { |
||||
let store_ug = store.inner.upgradable_read(); |
||||
|
||||
// Resolve session ID
|
||||
let id = request |
||||
.cookies() |
||||
.get(&store.config.cookie_name) |
||||
.map(|cookie| SessionID(cookie.value().to_string())); |
||||
|
||||
let expires = Instant::now().add(store.config.lifespan); |
||||
|
||||
if let Some(m) = id |
||||
.as_ref() |
||||
.and_then(|token| store_ug.sessions.get(token.as_str())) |
||||
{ |
||||
// --- ID obtained from a cookie && session found in the store ---
|
||||
|
||||
let mut inner = m.lock(); |
||||
if inner.expires <= Instant::now() { |
||||
// Session expired, reuse the ID but drop data.
|
||||
inner.data = D::default(); |
||||
} |
||||
|
||||
// Session is extended by making a request with valid ID
|
||||
inner.expires = expires; |
||||
|
||||
id.unwrap() |
||||
} else { |
||||
// --- ID missing or session not found ---
|
||||
|
||||
// Get exclusive write access to the map
|
||||
let mut store_wg = RwLockUpgradableReadGuard::upgrade(store_ug); |
||||
|
||||
// This branch runs less often, and we already have write access,
|
||||
// let's check if any sessions expired. We don't want to hog memory
|
||||
// forever by abandoned sessions (e.g. when a client lost their cookie)
|
||||
|
||||
// Throttle by lifespan - e.g. sweep every hour
|
||||
if store_wg.last_expiry_sweep.elapsed() > store.config.lifespan { |
||||
let now = Instant::now(); |
||||
store_wg.sessions.retain(|_k, v| v.lock().expires > now); |
||||
|
||||
store_wg.last_expiry_sweep = now; |
||||
} |
||||
|
||||
// Find a new unique ID - we are still safely inside the write guard
|
||||
let new_id = SessionID(loop { |
||||
let token: String = OsRng |
||||
.sample_iter(&rand::distributions::Alphanumeric) |
||||
.take(store.config.cookie_len) |
||||
.map(char::from) |
||||
.collect(); |
||||
|
||||
if !store_wg.sessions.contains_key(&token) { |
||||
break token; |
||||
} |
||||
}); |
||||
|
||||
store_wg.sessions.insert( |
||||
new_id.to_string(), |
||||
Mutex::new(SessionInstance { |
||||
data: Default::default(), |
||||
expires, |
||||
}), |
||||
); |
||||
|
||||
new_id |
||||
} |
||||
}), |
||||
store, |
||||
}) |
||||
} |
||||
} |
||||
|
||||
impl<'a, D> Session<'a, D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
/// Create the session fairing.
|
||||
///
|
||||
/// You can configure the session store by calling chained methods on the returned value
|
||||
/// before passing it to `rocket.attach()`
|
||||
pub fn fairing() -> SessionFairing<D> { |
||||
SessionFairing::<D>::new() |
||||
} |
||||
|
||||
/// Clear session data (replace the value with default)
|
||||
pub fn clear(&self) { |
||||
self.tap(|m| { |
||||
*m = D::default(); |
||||
}) |
||||
} |
||||
|
||||
/// Access the session's data using a closure.
|
||||
///
|
||||
/// The closure is called with the data value as a mutable argument,
|
||||
/// and can return any value to be is passed up to the caller.
|
||||
pub fn tap<T>(&self, func: impl FnOnce(&mut D) -> T) -> T { |
||||
// Use a read guard, so other already active sessions are not blocked
|
||||
// from accessing the store. New incoming clients may be blocked until
|
||||
// the tap() call finishes
|
||||
let store_rg = self.store.inner.read(); |
||||
|
||||
// Unlock the session's mutex.
|
||||
// Expiry was checked and prolonged at the beginning of the request
|
||||
let mut instance = store_rg |
||||
.sessions |
||||
.get(self.id.as_str()) |
||||
.expect("Session data unexpectedly missing") |
||||
.lock(); |
||||
|
||||
func(&mut instance.data) |
||||
} |
||||
} |
||||
|
||||
/// Fairing struct
|
||||
#[derive(Default)] |
||||
pub struct SessionFairing<D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
config: SessionConfig, |
||||
phantom: PhantomData<D>, |
||||
} |
||||
|
||||
impl<D> SessionFairing<D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
fn new() -> Self { |
||||
Self::default() |
||||
} |
||||
|
||||
/// Set session lifetime (expiration time).
|
||||
///
|
||||
/// Call on the fairing before passing it to `rocket.attach()`
|
||||
pub fn with_lifetime(mut self, time: Duration) -> Self { |
||||
self.config.lifespan = time; |
||||
self |
||||
} |
||||
|
||||
/// Set session cookie name and length
|
||||
///
|
||||
/// Call on the fairing before passing it to `rocket.attach()`
|
||||
pub fn with_cookie_name(mut self, name: impl Into<Cow<'static, str>>) -> Self { |
||||
self.config.cookie_name = name.into(); |
||||
self |
||||
} |
||||
|
||||
/// Set session cookie name and length
|
||||
///
|
||||
/// Call on the fairing before passing it to `rocket.attach()`
|
||||
pub fn with_cookie_len(mut self, length: usize) -> Self { |
||||
self.config.cookie_len = length; |
||||
self |
||||
} |
||||
|
||||
/// Set session cookie name and length
|
||||
///
|
||||
/// Call on the fairing before passing it to `rocket.attach()`
|
||||
pub fn with_cookie_path(mut self, path: impl Into<Cow<'static, str>>) -> Self { |
||||
self.config.cookie_path = path.into(); |
||||
self |
||||
} |
||||
} |
||||
|
||||
#[rocket::async_trait] |
||||
impl<D> Fairing for SessionFairing<D> |
||||
where |
||||
D: 'static + Sync + Send + Default, |
||||
{ |
||||
fn info(&self) -> Info { |
||||
Info { |
||||
name: "Session", |
||||
kind: fairing::Kind::Ignite | fairing::Kind::Response, |
||||
} |
||||
} |
||||
|
||||
async fn on_ignite(&self, rocket: Rocket<Build>) -> Result<Rocket<Build>, Rocket<Build>> { |
||||
// install the store singleton
|
||||
Ok(rocket.manage(SessionStore::<D> { |
||||
inner: Default::default(), |
||||
config: self.config.clone(), |
||||
})) |
||||
} |
||||
|
||||
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response) { |
||||
// send the session cookie, if session started
|
||||
let session = request.local_cache(|| SessionID("".to_string())); |
||||
|
||||
if !session.0.is_empty() { |
||||
response.adjoin_header( |
||||
Cookie::build(self.config.cookie_name.clone(), session.to_string()) |
||||
.path("/") |
||||
.finish(), |
||||
); |
||||
} |
||||
} |
||||
} |
||||
mod session; |
||||
pub use session::Session; |
||||
|
@ -0,0 +1,177 @@ |
||||
use json_dotpath::DotPaths; |
||||
use parking_lot::RwLock; |
||||
use rand::Rng; |
||||
use rocket::fairing::{self, Fairing, Info}; |
||||
use rocket::request::FromRequest; |
||||
|
||||
use rocket::{ |
||||
http::{Cookie, Status}, |
||||
Outcome, Request, Response, Rocket, State, |
||||
}; |
||||
use serde::de::DeserializeOwned; |
||||
use serde::Serialize; |
||||
use serde_json::{Map, Value}; |
||||
|
||||
use std::collections::HashMap; |
||||
use std::ops::Add; |
||||
use std::time::{Duration, Instant}; |
||||
|
||||
const SESSION_ID: &str = "SESSID"; |
||||
|
||||
type SessionsMap = HashMap<String, SessionInstance>; |
||||
|
||||
#[derive(Debug)] |
||||
struct SessionInstance { |
||||
data: serde_json::Map<String, Value>, |
||||
expires: Instant, |
||||
} |
||||
|
||||
#[derive(Default, Debug)] |
||||
struct SessionStore { |
||||
inner: RwLock<SessionsMap>, |
||||
lifespan: Duration, |
||||
} |
||||
|
||||
#[derive(PartialEq, Hash, Clone, Debug)] |
||||
struct SessionID(String); |
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for &'a SessionID { |
||||
type Error = (); |
||||
|
||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, (Status, Self::Error), ()> { |
||||
Outcome::Success(request.local_cache(|| { |
||||
if let Some(cookie) = request.cookies().get(SESSION_ID) { |
||||
SessionID(cookie.value().to_string()) // FIXME avoid cloning (cow?)
|
||||
} else { |
||||
SessionID( |
||||
rand::thread_rng() |
||||
.sample_iter(&rand::distributions::Alphanumeric) |
||||
.take(16) |
||||
.collect(), |
||||
) |
||||
} |
||||
})) |
||||
} |
||||
} |
||||
|
||||
#[derive(Debug)] |
||||
pub struct Session<'a> { |
||||
store: State<'a, SessionStore>, |
||||
id: &'a SessionID, |
||||
} |
||||
|
||||
impl<'a, 'r> FromRequest<'a, 'r> for Session<'a> { |
||||
type Error = (); |
||||
|
||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, (Status, Self::Error), ()> { |
||||
Outcome::Success(Session { |
||||
id: request.local_cache(|| { |
||||
if let Some(cookie) = request.cookies().get(SESSION_ID) { |
||||
SessionID(cookie.value().to_string()) |
||||
} else { |
||||
SessionID( |
||||
rand::thread_rng() |
||||
.sample_iter(&rand::distributions::Alphanumeric) |
||||
.take(16) |
||||
.collect(), |
||||
) |
||||
} |
||||
}), |
||||
store: request.guard().unwrap(), |
||||
}) |
||||
} |
||||
} |
||||
|
||||
impl<'a> Session<'a> { |
||||
pub fn fairing(lifespan: Duration) -> impl Fairing { |
||||
SessionFairing { lifespan } |
||||
} |
||||
|
||||
pub fn tap<T>(&self, func: impl FnOnce(&mut serde_json::Map<String, Value>) -> T) -> T { |
||||
let mut wg = self.store.inner.write(); |
||||
if let Some(instance) = wg.get_mut(&self.id.0) { |
||||
instance.expires = Instant::now().add(self.store.lifespan); |
||||
func(&mut instance.data) |
||||
} else { |
||||
let mut data = Map::new(); |
||||
let rv = func(&mut data); |
||||
wg.insert( |
||||
self.id.0.clone(), |
||||
SessionInstance { |
||||
data, |
||||
expires: Instant::now().add(self.store.lifespan), |
||||
}, |
||||
); |
||||
rv |
||||
} |
||||
} |
||||
|
||||
pub fn renew(&self) { |
||||
self.tap(|_| ()) |
||||
} |
||||
|
||||
pub fn reset(&self) { |
||||
self.tap(|data| data.clear()) |
||||
} |
||||
|
||||
pub fn get<T: DeserializeOwned>(&self, path: &str) -> Option<T> { |
||||
self.tap(|data| data.dot_get(path)) |
||||
} |
||||
|
||||
pub fn get_or<T: DeserializeOwned>(&self, path: &str, def: T) -> T { |
||||
self.get(path).unwrap_or(def) |
||||
} |
||||
|
||||
pub fn get_or_else<T: DeserializeOwned, F: FnOnce() -> T>(&self, path: &str, def: F) -> T { |
||||
self.get(path).unwrap_or_else(def) |
||||
} |
||||
|
||||
pub fn get_or_default<T: DeserializeOwned + Default>(&self, path: &str) -> T { |
||||
self.get(path).unwrap_or_default() |
||||
} |
||||
|
||||
pub fn take<T: DeserializeOwned>(&self, path: &str) -> Option<T> { |
||||
self.tap(|data| data.dot_take(path)) |
||||
} |
||||
|
||||
pub fn replace<O: DeserializeOwned, N: Serialize>(&self, path: &str, new: N) -> Option<O> { |
||||
self.tap(|data| data.dot_replace(path, new)) |
||||
} |
||||
|
||||
pub fn set<T: Serialize>(&self, path: &str, value: T) { |
||||
self.tap(|data| data.dot_set(path, value)); |
||||
} |
||||
|
||||
pub fn remove(&self, path: &str) -> bool { |
||||
self.tap(|data| data.dot_remove(path)) |
||||
} |
||||
} |
||||
|
||||
/// Fairing struct
|
||||
struct SessionFairing { |
||||
lifespan: Duration |
||||
} |
||||
|
||||
impl Fairing for SessionFairing { |
||||
fn info(&self) -> Info { |
||||
Info { |
||||
name: "Session", |
||||
kind: fairing::Kind::Attach | fairing::Kind::Response, |
||||
} |
||||
} |
||||
|
||||
fn on_attach(&self, rocket: Rocket) -> Result<Rocket, Rocket> { |
||||
Ok(rocket.manage(SessionStore { |
||||
inner: Default::default(), |
||||
lifespan: self.lifespan, |
||||
})) |
||||
} |
||||
|
||||
fn on_response<'r>(&self, request: &'r Request, response: &mut Response) { |
||||
let session = request.local_cache(|| SessionID("".to_string())); |
||||
|
||||
if !session.0.is_empty() { |
||||
response.adjoin_header(Cookie::build(SESSION_ID, session.0.clone()).finish()); |
||||
} |
||||
} |
||||
} |
Loading…
Reference in new issue