Compare commits

...

2 Commits

  1. 2
      Cargo.toml
  2. 40
      README.md
  3. 77
      examples/dog_list/main.rs
  4. 74
      examples/visit_counter/main.rs
  5. 319
      src/lib.rs

@ -1,6 +1,6 @@
[package] [package]
name = "rocket_session" name = "rocket_session"
version = "0.1.1" version = "0.2.0"
authors = ["Ondřej Hruška <ondra@ondrovo.com>"] authors = ["Ondřej Hruška <ondra@ondrovo.com>"]
edition = "2018" edition = "2018"
license = "MIT" license = "MIT"

@ -2,17 +2,38 @@
Adding cookie-based sessions to a rocket application is extremely simple with this crate. Adding cookie-based sessions to a rocket application is extremely simple with this crate.
Sessions are used to share data between related requests, such as user authentication, shopping basket,
form values that failed validation for re-filling, etc.
## Configuration
The implementation is generic to support any type as session data: a custom struct, `String`, The implementation is generic to support any type as session data: a custom struct, `String`,
`HashMap`, or perhaps `serde_json::Value`. You're free to choose. `HashMap`, or perhaps `serde_json::Value`. You're free to choose.
The session expiry time is configurable through the Fairing. When a session expires, The session lifetime, cookie name, and other parameters can be configured by calling chained
the data associated with it is dropped. All expired sessions may be cleared by calling `.remove_expired()` methods on the fairing. When a session expires, the data associated with it is dropped.
on the `SessionStore`, which is be obtained in routes as `State<SessionStore>`, or from a
session instance by calling `.get_store()`. ## Usage
To use session in a route, first make sure you have the fairing attached by calling
`rocket.attach(Session::fairing())` at start-up, and then add something like `session : Session`
to the parameter list of your route(s). Everything else--session init, expiration, cookie
management--is done for you behind the scenes.
The session cookie is currently hardcoded to "SESSID" and contains 16 random characters. Session data is accessed in a closure run in the session context, using the `session.tap()`
method. This closure runs inside a per-session mutex, avoiding simultaneous mutation
from different requests. Try to *avoid lengthy operations inside the closure*,
as it effectively blocks any other request to session-enabled routes by the client.
## Basic Example Every request to a session-enabled route extends the session's lifetime to the full
configured time (defaults to 1 hour). Automatic clean-up removes expired sessions to make sure
the session list does not waste memory.
## Examples
(More examples are in the examples folder)
### Basic Example
This simple example uses u64 as the session variable; note that it can be a struct, map, or anything else, This simple example uses u64 as the session variable; note that it can be a struct, map, or anything else,
it just needs to implement `Send + Sync + Default`. it just needs to implement `Send + Sync + Default`.
@ -28,7 +49,7 @@ pub type Session<'a> = rocket_session::Session<'a, u64>;
fn main() { fn main() {
rocket::ignite() rocket::ignite()
.attach(Session::fairing(Duration::from_secs(3600))) .attach(Session::fairing())
.mount("/", routes![index]) .mount("/", routes![index])
.launch(); .launch();
} }
@ -53,7 +74,10 @@ fn index(session: Session) -> String {
The `.tap()` method is powerful, but sometimes you may wish for something more convenient. The `.tap()` method is powerful, but sometimes you may wish for something more convenient.
Here is an example of using a custom trait and the `json_dotpath` crate to implement Here is an example of using a custom trait and the `json_dotpath` crate to implement
a polymorphic store based on serde serialization: a polymorphic store based on serde serialization.
Note that this approach is prone to data races, since every method contains its own `.tap()`.
It may be safer to simply call the `.dot_*()` methods manually in one shared closure.
```rust ```rust
use serde_json::Value; use serde_json::Value;

@ -0,0 +1,77 @@
#![feature(proc_macro_hygiene, decl_macro)]
#[macro_use]
extern crate rocket;
use rocket::request::Form;
use rocket::response::content::Html;
use rocket::response::Redirect;
type Session<'a> = rocket_session::Session<'a, Vec<String>>;
fn main() {
rocket::ignite()
.attach(Session::fairing())
.mount("/", routes![index, add, remove])
.launch();
}
#[get("/")]
fn index(session: Session) -> Html<String> {
let mut page = String::new();
page.push_str(
r#"
<!DOCTYPE html>
<h1>My Dogs</h1>
<form method="POST" action="/add">
Add Dog: <input type="text" name="name"> <input type="submit" value="Add">
</form>
<ul>
"#,
);
session.tap(|sess| {
for (n, dog) in sess.iter().enumerate() {
page.push_str(&format!(
r#"
<li>&#x1F436; {} <a href="/remove/{}">Remove</a></li>
"#,
dog, n
));
}
});
page.push_str(
r#"
</ul>
"#,
);
Html(page)
}
#[derive(FromForm)]
struct AddForm {
name: String,
}
#[post("/add", data = "<dog>")]
fn add(session: Session, dog: Form<AddForm>) -> Redirect {
session.tap(move |sess| {
sess.push(dog.into_inner().name);
});
Redirect::found("/")
}
#[get("/remove/<dog>")]
fn remove(session: Session, dog: usize) -> Redirect {
session.tap(|sess| {
if dog < sess.len() {
sess.remove(dog);
}
});
Redirect::found("/")
}

@ -0,0 +1,74 @@
//! This demo is a page visit counter, with a custom cookie name, length, and expiry time.
//!
//! The expiry time is set to 10 seconds to illustrate how a session is cleared if inactive.
#![feature(proc_macro_hygiene, decl_macro)]
#[macro_use]
extern crate rocket;
use rocket::response::content::Html;
use std::time::Duration;
#[derive(Default, Clone)]
struct SessionData {
visits1: usize,
visits2: usize,
}
// It's convenient to define a type alias:
type Session<'a> = rocket_session::Session<'a, SessionData>;
fn main() {
rocket::ignite()
.attach(
Session::fairing()
// 10 seconds of inactivity until session expires
// (wait 10s and refresh, the numbers will reset)
.with_lifetime(Duration::from_secs(10))
// custom cookie name and length
.with_cookie_name("my_cookie")
.with_cookie_len(20),
)
.mount("/", routes![index, about])
.launch();
}
#[get("/")]
fn index(session: Session) -> Html<String> {
// Here we build the entire response inside the 'tap' closure.
// While inside, the session is locked to parallel changes, e.g.
// from a different browser tab.
session.tap(|sess| {
sess.visits1 += 1;
Html(format!(
r##"
<!DOCTYPE html>
<h1>Home</h1>
<a href="/">Refresh</a> &bull; <a href="/about/">go to About</a>
<p>Visits: home {}, about {}</p>
"##,
sess.visits1, sess.visits2
))
})
}
#[get("/about")]
fn about(session: Session) -> Html<String> {
// Here we return a value from the tap function and use it below
let count = session.tap(|sess| {
sess.visits2 += 1;
sess.visits2
});
Html(format!(
r##"
<!DOCTYPE html>
<h1>About</h1>
<a href="/about">Refresh</a> &bull; <a href="/">go home</a>
<p>Page visits: {}</p>
"##,
count
))
}

@ -1,4 +1,4 @@
use parking_lot::RwLock; use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard};
use rand::Rng; use rand::Rng;
use rocket::{ use rocket::{
@ -8,68 +8,110 @@ use rocket::{
Outcome, Request, Response, Rocket, State, Outcome, Request, Response, Rocket, State,
}; };
use std::borrow::Cow;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ops::Add; use std::ops::Add;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
const SESSION_COOKIE: &str = "SESSID"; /// Session store (shared state)
const SESSION_ID_LEN: usize = 16;
/// Session, as stored in the sessions store
#[derive(Debug)] #[derive(Debug)]
struct SessionInstance<D> pub struct SessionStore<D>
where where
D: 'static + Sync + Send + Default, D: 'static + Sync + Send + Default,
{ {
/// Data object /// The internally mutable map of sessions
data: D, inner: RwLock<StoreInner<D>>,
/// Expiry // Session config
expires: Instant, config: SessionConfig,
} }
/// Session store (shared state) /// Session config object
#[derive(Default, Debug)] #[derive(Debug, Clone)]
pub struct SessionStore<D> struct SessionConfig {
where
D: 'static + Sync + Send + Default,
{
/// The internaly mutable map of sessions
inner: RwLock<HashMap<String, SessionInstance<D>>>,
/// Sessions lifespan /// Sessions lifespan
lifespan: Duration, lifespan: Duration,
/// Session cookie name
cookie_name: Cow<'static, str>,
/// Session cookie path
cookie_path: Cow<'static, str>,
/// Session ID character length
cookie_len: usize,
} }
impl<D> SessionStore<D> impl Default for SessionConfig {
where fn default() -> Self {
D: 'static + Sync + Send + Default, Self {
lifespan: Duration::from_secs(3600),
cookie_name: "rocket_session".into(),
cookie_path: "/".into(),
cookie_len: 16,
}
}
}
/// Mutable object stored inside SessionStore behind a RwLock
#[derive(Debug)]
struct StoreInner<D>
where
D: 'static + Sync + Send + Default,
{ {
/// Remove all expired sessions sessions: HashMap<String, Mutex<SessionInstance<D>>>,
pub fn remove_expired(&self) { last_expiry_sweep: Instant,
let now = Instant::now(); }
self.inner.write().retain(|_k, v| v.expires > now);
impl<D> Default for StoreInner<D>
where
D: 'static + Sync + Send + Default,
{
fn default() -> Self {
Self {
sessions: Default::default(),
// the first expiry sweep is scheduled one lifetime from start-up
last_expiry_sweep: Instant::now(),
}
} }
} }
/// Session, as stored in the sessions store
#[derive(Debug)]
struct SessionInstance<D>
where
D: 'static + Sync + Send + Default,
{
/// Data object
data: D,
/// Expiry
expires: Instant,
}
/// Session ID newtype for rocket's "local_cache" /// Session ID newtype for rocket's "local_cache"
#[derive(PartialEq, Hash, Clone, Debug)] #[derive(Clone, Debug)]
struct SessionID(String); struct SessionID(String);
impl SessionID { impl SessionID {
fn as_str(&self) -> &str { fn as_str(&self) -> &str {
self.0.as_str() self.0.as_str()
} }
}
fn to_string(&self) -> String { impl Display for SessionID {
self.0.clone() fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str(&self.0)
} }
} }
/// Session instance /// Session instance
///
/// To access the active session, simply add it as an argument to a route function.
///
/// Sessions are started, restored, or expired in the `FromRequest::from_request()` method
/// when a `Session` is prepared for one of the route functions.
#[derive(Debug)] #[derive(Debug)]
pub struct Session<'a, D> pub struct Session<'a, D>
where where
D: 'static + Sync + Send + Default, D: 'static + Sync + Send + Default,
{ {
/// The shared state reference /// The shared state reference
store: State<'a, SessionStore<D>>, store: State<'a, SessionStore<D>>,
@ -78,51 +120,82 @@ pub struct Session<'a, D>
} }
impl<'a, 'r, D> FromRequest<'a, 'r> for Session<'a, D> impl<'a, 'r, D> FromRequest<'a, 'r> for Session<'a, D>
where where
D: 'static + Sync + Send + Default, D: 'static + Sync + Send + Default,
{ {
type Error = (); type Error = ();
fn from_request(request: &'a Request<'r>) -> Outcome<Self, (Status, Self::Error), ()> { fn from_request(request: &'a Request<'r>) -> Outcome<Self, (Status, Self::Error), ()> {
let store : State<SessionStore<D>> = request.guard().unwrap(); let store: State<SessionStore<D>> = request.guard().unwrap();
Outcome::Success(Session { Outcome::Success(Session {
id: request.local_cache(|| { id: request.local_cache(|| {
let store_ug = store.inner.upgradable_read();
// Resolve session ID // Resolve session ID
let id = if let Some(cookie) = request.cookies().get(SESSION_COOKIE) { let id = if let Some(cookie) = request.cookies().get(&store.config.cookie_name) {
SessionID(cookie.value().to_string()) Some(SessionID(cookie.value().to_string()))
} else { } else {
SessionID( None
rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(SESSION_ID_LEN)
.collect(),
)
}; };
let new_expiration = Instant::now().add(store.lifespan); let expires = Instant::now().add(store.config.lifespan);
let mut wg = store.inner.write();
match wg.get_mut(id.as_str()) { if let Some(m) = id
Some(ses) => { .as_ref()
// Check expiration .and_then(|token| store_ug.sessions.get(token.as_str()))
if ses.expires <= Instant::now() { {
ses.data = D::default(); // --- ID obtained from a cookie && session found in the store ---
}
// Update expiry timestamp let mut inner = m.lock();
ses.expires = new_expiration; if inner.expires <= Instant::now() {
}, // Session expired, reuse the ID but drop data.
None => { inner.data = D::default();
// New session
wg.insert(
id.to_string(),
SessionInstance {
data: D::default(),
expires: new_expiration,
}
);
} }
};
id // Session is extended by making a request with valid ID
inner.expires = expires;
id.unwrap()
} else {
// --- ID missing or session not found ---
// Get exclusive write access to the map
let mut store_wg = RwLockUpgradableReadGuard::upgrade(store_ug);
// This branch runs less often, and we already have write access,
// let's check if any sessions expired. We don't want to hog memory
// forever by abandoned sessions (e.g. when a client lost their cookie)
// Throttle by lifespan - e.g. sweep every hour
if store_wg.last_expiry_sweep.elapsed() > store.config.lifespan {
let now = Instant::now();
store_wg.sessions.retain(|_k, v| v.lock().expires > now);
store_wg.last_expiry_sweep = now;
}
// Find a new unique ID - we are still safely inside the write guard
let new_id = SessionID(loop {
let token: String = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(store.config.cookie_len)
.collect();
if !store_wg.sessions.contains_key(&token) {
break token;
}
});
store_wg.sessions.insert(
new_id.to_string(),
Mutex::new(SessionInstance {
data: Default::default(),
expires,
}),
);
new_id
}
}), }),
store, store,
}) })
@ -130,54 +203,100 @@ impl<'a, 'r, D> FromRequest<'a, 'r> for Session<'a, D>
} }
impl<'a, D> Session<'a, D> impl<'a, D> Session<'a, D>
where where
D: 'static + Sync + Send + Default, D: 'static + Sync + Send + Default,
{ {
/// Get the fairing object /// Create the session fairing.
pub fn fairing(lifespan: Duration) -> impl Fairing { ///
SessionFairing::<D> { /// You can configure the session store by calling chained methods on the returned value
lifespan, /// before passing it to `rocket.attach()`
_phantom: PhantomData, pub fn fairing() -> SessionFairing<D> {
} SessionFairing::<D>::new()
} }
/// Access the session store /// Clear session data (replace the value with default)
pub fn get_store(&self) -> &SessionStore<D> { pub fn clear(&self) {
&self.store self.tap(|m| {
}
/// Set the session object to its default state
pub fn reset(&self) {
self.tap_mut(|m| {
*m = D::default(); *m = D::default();
}) })
} }
pub fn tap<T>(&self, func: impl FnOnce(&D) -> T) -> T { /// Access the session's data using a closure.
let rg = self.store.inner.read(); ///
let instance = rg.get(self.id.as_str()).unwrap(); /// The closure is called with the data value as a mutable argument,
func(&instance.data) /// and can return any value to be is passed up to the caller.
} pub fn tap<T>(&self, func: impl FnOnce(&mut D) -> T) -> T {
// Use a read guard, so other already active sessions are not blocked
// from accessing the store. New incoming clients may be blocked until
// the tap() call finishes
let store_rg = self.store.inner.read();
// Unlock the session's mutex.
// Expiry was checked and prolonged at the beginning of the request
let mut instance = store_rg
.sessions
.get(self.id.as_str())
.expect("Session data unexpectedly missing")
.lock();
pub fn tap_mut<T>(&self, func: impl FnOnce(&mut D) -> T) -> T {
let mut wg = self.store.inner.write();
let instance = wg.get_mut(self.id.as_str()).unwrap();
func(&mut instance.data) func(&mut instance.data)
} }
} }
/// Fairing struct /// Fairing struct
struct SessionFairing<D> #[derive(Default)]
where pub struct SessionFairing<D>
D: 'static + Sync + Send + Default, where
D: 'static + Sync + Send + Default,
{ {
lifespan: Duration, config: SessionConfig,
_phantom: PhantomData<D>, phantom: PhantomData<D>,
}
impl<D> SessionFairing<D>
where
D: 'static + Sync + Send + Default,
{
fn new() -> Self {
Self::default()
}
/// Set session lifetime (expiration time).
///
/// Call on the fairing before passing it to `rocket.attach()`
pub fn with_lifetime(mut self, time: Duration) -> Self {
self.config.lifespan = time;
self
}
/// Set session cookie name and length
///
/// Call on the fairing before passing it to `rocket.attach()`
pub fn with_cookie_name(mut self, name: impl Into<Cow<'static, str>>) -> Self {
self.config.cookie_name = name.into();
self
}
/// Set session cookie name and length
///
/// Call on the fairing before passing it to `rocket.attach()`
pub fn with_cookie_len(mut self, length: usize) -> Self {
self.config.cookie_len = length;
self
}
/// Set session cookie name and length
///
/// Call on the fairing before passing it to `rocket.attach()`
pub fn with_cookie_path(mut self, path: impl Into<Cow<'static, str>>) -> Self {
self.config.cookie_path = path.into();
self
}
} }
impl<D> Fairing for SessionFairing<D> impl<D> Fairing for SessionFairing<D>
where where
D: 'static + Sync + Send + Default, D: 'static + Sync + Send + Default,
{ {
fn info(&self) -> Info { fn info(&self) -> Info {
Info { Info {
@ -190,7 +309,7 @@ impl<D> Fairing for SessionFairing<D>
// install the store singleton // install the store singleton
Ok(rocket.manage(SessionStore::<D> { Ok(rocket.manage(SessionStore::<D> {
inner: Default::default(), inner: Default::default(),
lifespan: self.lifespan, config: self.config.clone(),
})) }))
} }
@ -199,7 +318,11 @@ impl<D> Fairing for SessionFairing<D>
let session = request.local_cache(|| SessionID("".to_string())); let session = request.local_cache(|| SessionID("".to_string()));
if !session.0.is_empty() { if !session.0.is_empty() {
response.adjoin_header(Cookie::build(SESSION_COOKIE, session.0.clone()).finish()); response.adjoin_header(
Cookie::build(self.config.cookie_name.clone(), session.to_string())
.path("/")
.finish(),
);
} }
} }
} }

Loading…
Cancel
Save