From cde08fe7886fea81955b3e50cfe76f504788b83e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20Hru=C5=A1ka?= Date: Tue, 31 Dec 2019 20:31:38 +0100 Subject: [PATCH] add examples, automatic expired removal, better configurability --- examples/dog_list/main.rs | 70 ++++++++ examples/visit_counter/main.rs | 72 +++++++++ src/lib.rs | 286 +++++++++++++++++++++++---------- 3 files changed, 343 insertions(+), 85 deletions(-) create mode 100644 examples/dog_list/main.rs create mode 100644 examples/visit_counter/main.rs diff --git a/examples/dog_list/main.rs b/examples/dog_list/main.rs new file mode 100644 index 0000000..976ece7 --- /dev/null +++ b/examples/dog_list/main.rs @@ -0,0 +1,70 @@ +#![feature(proc_macro_hygiene, decl_macro)] +#[macro_use] +extern crate rocket; + +use rocket::response::content::Html; +use rocket::response::Redirect; +use rocket::request::Form; + +type Session<'a> = rocket_session::Session<'a, Vec>; + +fn main() { + rocket::ignite() + .attach(Session::fairing()) + .mount("/", routes![index, add, remove]) + .launch(); +} + +#[get("/")] +fn index(session: Session) -> Html { + let mut page = String::new(); + page.push_str(r#" + +

My Dogs

+ +
+ Add Dog: +
+ +
    + "#); + + session.tap(|sess| { + for (n, dog) in sess.iter().enumerate() { + page.push_str(&format!(r#" +
  • 🐶 {} Remove
  • + "#, dog, n)); + } + }); + + page.push_str(r#" +
+ "#); + + Html(page) +} + +#[derive(FromForm)] +struct AddForm { + name: String, +} + +#[post("/add", data="")] +fn add(session: Session, dog : Form) -> Redirect { + session.tap(move |sess| { + sess.push(dog.into_inner().name); + }); + + Redirect::found("/") +} + +#[get("/remove/")] +fn remove(session: Session, dog : usize) -> Redirect { + session.tap(|sess| { + if dog < sess.len() { + sess.remove(dog); + } + }); + + Redirect::found("/") +} diff --git a/examples/visit_counter/main.rs b/examples/visit_counter/main.rs new file mode 100644 index 0000000..b4e29ae --- /dev/null +++ b/examples/visit_counter/main.rs @@ -0,0 +1,72 @@ +//! This demo is a page visit counter, with a custom cookie name, length, and expiry time. +//! +//! The expiry time is set to 10 seconds to illustrate how a session is cleared if inactive. + +#![feature(proc_macro_hygiene, decl_macro)] +#[macro_use] +extern crate rocket; + +use std::time::Duration; +use rocket::response::content::Html; + +#[derive(Default, Clone)] +struct SessionData { + visits1: usize, + visits2: usize, +} + +// It's convenient to define a type alias: +type Session<'a> = rocket_session::Session<'a, SessionData>; + +fn main() { + rocket::ignite() + .attach(Session::fairing() + // 10 seconds of inactivity until session expires + // (wait 10s and refresh, the numbers will reset) + .with_lifetime(Duration::from_secs(10)) + // custom cookie name and length + .with_cookie_name("my_cookie") + .with_cookie_len(20) + ) + .mount("/", routes![index, about]) + .launch(); +} + +#[get("/")] +fn index(session: Session) -> Html { + // Here we build the entire response inside the 'tap' closure. + + // While inside, the session is locked to parallel changes, e.g. + // from a different browser tab. + session.tap(|sess| { + sess.visits1 += 1; + + Html(format!(r##" + +

Home

+ Refreshgo to About +

Visits: home {}, about {}

+ "##, + sess.visits1, + sess.visits2 + )) + }) +} + +#[get("/about")] +fn about(session: Session) -> Html { + // Here we return a value from the tap function and use it below + let count = session.tap(|sess| { + sess.visits2 += 1; + sess.visits2 + }); + + Html(format!(r##" + +

About

+ Refreshgo home +

Page visits: {}

+ "##, + count + )) +} diff --git a/src/lib.rs b/src/lib.rs index 87fe7fb..b23a2cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -use parking_lot::RwLock; +use parking_lot::{RwLock, RwLockUpgradableReadGuard, Mutex}; use rand::Rng; use rocket::{ @@ -12,60 +12,100 @@ use std::collections::HashMap; use std::marker::PhantomData; use std::ops::Add; use std::time::{Duration, Instant}; +use std::borrow::Cow; +use std::fmt::{Display, Formatter, self}; -const SESSION_COOKIE: &str = "SESSID"; -const SESSION_ID_LEN: usize = 16; - -/// Session, as stored in the sessions store +/// Session store (shared state) #[derive(Debug)] -struct SessionInstance +pub struct SessionStore where D: 'static + Sync + Send + Default, { - /// Data object - data: D, - /// Expiry - expires: Instant, + /// The internally mutable map of sessions + inner: RwLock>, + // Session config + config: SessionConfig, } -/// Session store (shared state) -#[derive(Default, Debug)] -pub struct SessionStore - where - D: 'static + Sync + Send + Default, -{ - /// The internaly mutable map of sessions - inner: RwLock>>, +/// Session config object +#[derive(Debug, Clone)] +struct SessionConfig { /// Sessions lifespan lifespan: Duration, + /// Session cookie name + cookie_name: Cow<'static, str>, + /// Session cookie path + cookie_path: Cow<'static, str>, + /// Session ID character length + cookie_len: usize, } -impl SessionStore +impl Default for SessionConfig { + fn default() -> Self { + Self { + lifespan: Duration::from_secs(3600), + cookie_name: "rocket_session".into(), + cookie_path: "/".into(), + cookie_len: 16, + } + } +} + +/// Mutable object stored inside SessionStore behind a RwLock +#[derive(Debug)] +struct StoreInner + where + D: 'static + Sync + Send + Default { + sessions: HashMap>>, + last_expiry_sweep: Instant, +} + +impl Default for StoreInner + where + D: 'static + Sync + Send + Default { + fn default() -> Self { + Self { + sessions: Default::default(), + // the first expiry sweep is scheduled one lifetime from start-up + last_expiry_sweep: Instant::now(), + } + } +} + +/// Session, as stored in the sessions store +#[derive(Debug)] +struct SessionInstance where D: 'static + Sync + Send + Default, { - /// Remove all expired sessions - pub fn remove_expired(&self) { - let now = Instant::now(); - self.inner.write().retain(|_k, v| v.expires > now); - } + /// Data object + data: D, + /// Expiry + expires: Instant, } /// Session ID newtype for rocket's "local_cache" -#[derive(PartialEq, Hash, Clone, Debug)] +#[derive(Clone, Debug)] struct SessionID(String); impl SessionID { fn as_str(&self) -> &str { self.0.as_str() } +} - fn to_string(&self) -> String { - self.0.clone() +impl Display for SessionID { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) } } /// Session instance +/// +/// To access the active session, simply add it as an argument to a route function. +/// +/// Sessions are started, restored, or expired in the `FromRequest::from_request()` method +/// when a `Session` is prepared for one of the route functions. #[derive(Debug)] pub struct Session<'a, D> where @@ -84,45 +124,76 @@ impl<'a, 'r, D> FromRequest<'a, 'r> for Session<'a, D> type Error = (); fn from_request(request: &'a Request<'r>) -> Outcome { - let store : State> = request.guard().unwrap(); + let store: State> = request.guard().unwrap(); Outcome::Success(Session { id: request.local_cache(|| { + let store_ug = store.inner.upgradable_read(); + // Resolve session ID - let id = if let Some(cookie) = request.cookies().get(SESSION_COOKIE) { - SessionID(cookie.value().to_string()) + let id = if let Some(cookie) = request.cookies().get(&store.config.cookie_name) { + Some(SessionID(cookie.value().to_string())) } else { - SessionID( - rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(SESSION_ID_LEN) - .collect(), - ) + None }; - let new_expiration = Instant::now().add(store.lifespan); - let mut wg = store.inner.write(); - match wg.get_mut(id.as_str()) { - Some(ses) => { - // Check expiration - if ses.expires <= Instant::now() { - ses.data = D::default(); - } - // Update expiry timestamp - ses.expires = new_expiration; - }, - None => { - // New session - wg.insert( - id.to_string(), - SessionInstance { - data: D::default(), - expires: new_expiration, - } - ); + let expires = Instant::now().add(store.config.lifespan); + + if let Some(m) = id.as_ref() + .and_then(|token| store_ug.sessions.get(token.as_str())) + { + // --- ID obtained from a cookie && session found in the store --- + + let mut inner = m.lock(); + if inner.expires <= Instant::now() { + // Session expired, reuse the ID but drop data. + inner.data = D::default(); } - }; - id + // Session is extended by making a request with valid ID + inner.expires = expires; + + id.unwrap() + } else { + // --- ID missing or session not found --- + + // Get exclusive write access to the map + let mut store_wg = RwLockUpgradableReadGuard::upgrade(store_ug); + + // This branch runs less often, and we already have write access, + // let's check if any sessions expired. We don't want to hog memory + // forever by abandoned sessions (e.g. when a client lost their cookie) + + // Throttle by lifespan - e.g. sweep every hour + if store_wg.last_expiry_sweep.elapsed() > store.config.lifespan { + let now = Instant::now(); + store_wg.sessions + .retain(|_k, v| v.lock().expires > now); + + store_wg.last_expiry_sweep = now; + } + + // Find a new unique ID - we are still safely inside the write guard + let new_id = SessionID(loop { + let token: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(store.config.cookie_len) + .collect(); + + if !store_wg.sessions.contains_key(&token) { + break token; + } + }); + + store_wg.sessions.insert( + new_id.to_string(), + Mutex::new(SessionInstance { + data: Default::default(), + expires, + }), + ); + + new_id + } }), store, }) @@ -133,46 +204,90 @@ impl<'a, D> Session<'a, D> where D: 'static + Sync + Send + Default, { - /// Get the fairing object - pub fn fairing(lifespan: Duration) -> impl Fairing { - SessionFairing:: { - lifespan, - _phantom: PhantomData, - } + /// Create the session fairing. + /// + /// You can configure the session store by calling chained methods on the returned value + /// before passing it to `rocket.attach()` + pub fn fairing() -> SessionFairing { + SessionFairing::::new() } - /// Access the session store - pub fn get_store(&self) -> &SessionStore { - &self.store - } - - /// Set the session object to its default state - pub fn reset(&self) { - self.tap_mut(|m| { + /// Clear session data (replace the value with default) + pub fn clear(&self) { + self.tap(|m| { *m = D::default(); }) } - pub fn tap(&self, func: impl FnOnce(&D) -> T) -> T { - let rg = self.store.inner.read(); - let instance = rg.get(self.id.as_str()).unwrap(); - func(&instance.data) - } + /// Access the session's data using a closure. + /// + /// The closure is called with the data value as a mutable argument, + /// and can return any value to be is passed up to the caller. + pub fn tap(&self, func: impl FnOnce(&mut D) -> T) -> T { + // Use a read guard, so other already active sessions are not blocked + // from accessing the store. New incoming clients may be blocked until + // the tap() call finishes + let store_rg = self.store.inner.read(); + + // Unlock the session's mutex. + // Expiry was checked and prolonged at the beginning of the request + let mut instance = store_rg.sessions.get(self.id.as_str()) + .expect("Session data unexpectedly missing") + .lock(); - pub fn tap_mut(&self, func: impl FnOnce(&mut D) -> T) -> T { - let mut wg = self.store.inner.write(); - let instance = wg.get_mut(self.id.as_str()).unwrap(); func(&mut instance.data) } } /// Fairing struct -struct SessionFairing +#[derive(Default)] +pub struct SessionFairing where D: 'static + Sync + Send + Default, { - lifespan: Duration, - _phantom: PhantomData, + config: SessionConfig, + phantom: PhantomData, +} + +impl SessionFairing + where + D: 'static + Sync + Send + Default +{ + fn new() -> Self { + Self::default() + } + + /// Set session lifetime (expiration time). + /// + /// Call on the fairing before passing it to `rocket.attach()` + pub fn with_lifetime(mut self, time: Duration) -> Self { + self.config.lifespan = time; + self + } + + /// Set session cookie name and length + /// + /// Call on the fairing before passing it to `rocket.attach()` + pub fn with_cookie_name(mut self, name: impl Into>) -> Self { + self.config.cookie_name = name.into(); + self + } + + /// Set session cookie name and length + /// + /// Call on the fairing before passing it to `rocket.attach()` + pub fn with_cookie_len(mut self, length: usize) -> Self { + self.config.cookie_len = length; + self + } + + /// Set session cookie name and length + /// + /// Call on the fairing before passing it to `rocket.attach()` + pub fn with_cookie_path(mut self, path: impl Into>) -> Self { + self.config.cookie_path = path.into(); + self + } } impl Fairing for SessionFairing @@ -190,7 +305,7 @@ impl Fairing for SessionFairing // install the store singleton Ok(rocket.manage(SessionStore:: { inner: Default::default(), - lifespan: self.lifespan, + config: self.config.clone(), })) } @@ -199,7 +314,8 @@ impl Fairing for SessionFairing let session = request.local_cache(|| SessionID("".to_string())); if !session.0.is_empty() { - response.adjoin_header(Cookie::build(SESSION_COOKIE, session.0.clone()).finish()); + response.adjoin_header(Cookie::build(self.config.cookie_name.clone(), session.to_string()) + .path("/").finish()); } } }