Browse Source

add examples, automatic expired removal, better configurability

Ondřej Hruška 9 months ago
parent
commit
cde08fe788
Signed by: Ondřej Hruška <ondra@ondrovo.com> GPG key ID: 2C5FD5035250423D
3 changed files with 343 additions and 85 deletions
  1. 70 0
      examples/dog_list/main.rs
  2. 72 0
      examples/visit_counter/main.rs
  3. 201 85
      src/lib.rs

+ 70 - 0
examples/dog_list/main.rs View File

@@ -0,0 +1,70 @@
1
+#![feature(proc_macro_hygiene, decl_macro)]
2
+#[macro_use]
3
+extern crate rocket;
4
+
5
+use rocket::response::content::Html;
6
+use rocket::response::Redirect;
7
+use rocket::request::Form;
8
+
9
+type Session<'a> = rocket_session::Session<'a, Vec<String>>;
10
+
11
+fn main() {
12
+    rocket::ignite()
13
+        .attach(Session::fairing())
14
+        .mount("/", routes![index, add, remove])
15
+        .launch();
16
+}
17
+
18
+#[get("/")]
19
+fn index(session: Session) -> Html<String> {
20
+    let mut page = String::new();
21
+    page.push_str(r#"
22
+            <!DOCTYPE html>
23
+            <h1>My Dogs</h1>
24
+
25
+            <form method="POST" action="/add">
26
+            Add Dog: <input type="text" name="name"> <input type="submit" value="Add">
27
+            </form>
28
+
29
+            <ul>
30
+        "#);
31
+
32
+    session.tap(|sess| {
33
+        for (n, dog) in sess.iter().enumerate() {
34
+            page.push_str(&format!(r#"
35
+                <li>&#x1F436; {} <a href="/remove/{}">Remove</a></li>
36
+            "#, dog, n));
37
+        }
38
+    });
39
+
40
+    page.push_str(r#"
41
+            </ul>
42
+        "#);
43
+
44
+    Html(page)
45
+}
46
+
47
+#[derive(FromForm)]
48
+struct AddForm {
49
+    name: String,
50
+}
51
+
52
+#[post("/add", data="<dog>")]
53
+fn add(session: Session, dog : Form<AddForm>) -> Redirect {
54
+    session.tap(move |sess| {
55
+        sess.push(dog.into_inner().name);
56
+    });
57
+
58
+    Redirect::found("/")
59
+}
60
+
61
+#[get("/remove/<dog>")]
62
+fn remove(session: Session, dog : usize) -> Redirect {
63
+    session.tap(|sess| {
64
+        if dog < sess.len() {
65
+            sess.remove(dog);
66
+        }
67
+    });
68
+
69
+    Redirect::found("/")
70
+}

+ 72 - 0
examples/visit_counter/main.rs View File

@@ -0,0 +1,72 @@
1
+//! This demo is a page visit counter, with a custom cookie name, length, and expiry time.
2
+//!
3
+//! The expiry time is set to 10 seconds to illustrate how a session is cleared if inactive.
4
+
5
+#![feature(proc_macro_hygiene, decl_macro)]
6
+#[macro_use]
7
+extern crate rocket;
8
+
9
+use std::time::Duration;
10
+use rocket::response::content::Html;
11
+
12
+#[derive(Default, Clone)]
13
+struct SessionData {
14
+    visits1: usize,
15
+    visits2: usize,
16
+}
17
+
18
+// It's convenient to define a type alias:
19
+type Session<'a> = rocket_session::Session<'a, SessionData>;
20
+
21
+fn main() {
22
+    rocket::ignite()
23
+        .attach(Session::fairing()
24
+            // 10 seconds of inactivity until session expires
25
+            // (wait 10s and refresh, the numbers will reset)
26
+            .with_lifetime(Duration::from_secs(10))
27
+            // custom cookie name and length
28
+            .with_cookie_name("my_cookie")
29
+            .with_cookie_len(20)
30
+        )
31
+        .mount("/", routes![index, about])
32
+        .launch();
33
+}
34
+
35
+#[get("/")]
36
+fn index(session: Session) -> Html<String> {
37
+    // Here we build the entire response inside the 'tap' closure.
38
+
39
+    // While inside, the session is locked to parallel changes, e.g.
40
+    // from a different browser tab.
41
+    session.tap(|sess| {
42
+        sess.visits1 += 1;
43
+
44
+        Html(format!(r##"
45
+                <!DOCTYPE html>
46
+                <h1>Home</h1>
47
+                <a href="/">Refresh</a> &bull; <a href="/about/">go to About</a>
48
+                <p>Visits: home {}, about {}</p>
49
+            "##,
50
+            sess.visits1,
51
+            sess.visits2
52
+        ))
53
+    })
54
+}
55
+
56
+#[get("/about")]
57
+fn about(session: Session) -> Html<String> {
58
+    // Here we return a value from the tap function and use it below
59
+    let count = session.tap(|sess| {
60
+        sess.visits2 += 1;
61
+        sess.visits2
62
+    });
63
+
64
+    Html(format!(r##"
65
+            <!DOCTYPE html>
66
+            <h1>About</h1>
67
+            <a href="/about">Refresh</a> &bull; <a href="/">go home</a>
68
+            <p>Page visits: {}</p>
69
+        "##,
70
+        count
71
+    ))
72
+}

+ 201 - 85
src/lib.rs View File

@@ -1,4 +1,4 @@
1
-use parking_lot::RwLock;
1
+use parking_lot::{RwLock, RwLockUpgradableReadGuard, Mutex};
2 2
 use rand::Rng;
3 3
 
4 4
 use rocket::{
@@ -12,60 +12,100 @@ use std::collections::HashMap;
12 12
 use std::marker::PhantomData;
13 13
 use std::ops::Add;
14 14
 use std::time::{Duration, Instant};
15
+use std::borrow::Cow;
16
+use std::fmt::{Display, Formatter, self};
15 17
 
16
-const SESSION_COOKIE: &str = "SESSID";
17
-const SESSION_ID_LEN: usize = 16;
18
-
19
-/// Session, as stored in the sessions store
18
+/// Session store (shared state)
20 19
 #[derive(Debug)]
21
-struct SessionInstance<D>
20
+pub struct SessionStore<D>
22 21
     where
23 22
         D: 'static + Sync + Send + Default,
24 23
 {
25
-    /// Data object
26
-    data: D,
27
-    /// Expiry
28
-    expires: Instant,
24
+    /// The internally mutable map of sessions
25
+    inner: RwLock<StoreInner<D>>,
26
+    // Session config
27
+    config: SessionConfig,
29 28
 }
30 29
 
31
-/// Session store (shared state)
32
-#[derive(Default, Debug)]
33
-pub struct SessionStore<D>
34
-    where
35
-        D: 'static + Sync + Send + Default,
36
-{
37
-    /// The internaly mutable map of sessions
38
-    inner: RwLock<HashMap<String, SessionInstance<D>>>,
30
+/// Session config object
31
+#[derive(Debug, Clone)]
32
+struct SessionConfig {
39 33
     /// Sessions lifespan
40 34
     lifespan: Duration,
35
+    /// Session cookie name
36
+    cookie_name: Cow<'static, str>,
37
+    /// Session cookie path
38
+    cookie_path: Cow<'static, str>,
39
+    /// Session ID character length
40
+    cookie_len: usize,
41 41
 }
42 42
 
43
-impl<D> SessionStore<D>
43
+impl Default for SessionConfig {
44
+    fn default() -> Self {
45
+        Self {
46
+            lifespan: Duration::from_secs(3600),
47
+            cookie_name: "rocket_session".into(),
48
+            cookie_path: "/".into(),
49
+            cookie_len: 16,
50
+        }
51
+    }
52
+}
53
+
54
+/// Mutable object stored inside SessionStore behind a RwLock
55
+#[derive(Debug)]
56
+struct StoreInner<D>
57
+    where
58
+        D: 'static + Sync + Send + Default {
59
+    sessions: HashMap<String, Mutex<SessionInstance<D>>>,
60
+    last_expiry_sweep: Instant,
61
+}
62
+
63
+impl<D> Default for StoreInner<D>
64
+    where
65
+        D: 'static + Sync + Send + Default {
66
+    fn default() -> Self {
67
+        Self {
68
+            sessions: Default::default(),
69
+            // the first expiry sweep is scheduled one lifetime from start-up
70
+            last_expiry_sweep: Instant::now(),
71
+        }
72
+    }
73
+}
74
+
75
+/// Session, as stored in the sessions store
76
+#[derive(Debug)]
77
+struct SessionInstance<D>
44 78
     where
45 79
         D: 'static + Sync + Send + Default,
46 80
 {
47
-    /// Remove all expired sessions
48
-    pub fn remove_expired(&self) {
49
-        let now = Instant::now();
50
-        self.inner.write().retain(|_k, v| v.expires > now);
51
-    }
81
+    /// Data object
82
+    data: D,
83
+    /// Expiry
84
+    expires: Instant,
52 85
 }
53 86
 
54 87
 /// Session ID newtype for rocket's "local_cache"
55
-#[derive(PartialEq, Hash, Clone, Debug)]
88
+#[derive(Clone, Debug)]
56 89
 struct SessionID(String);
57 90
 
58 91
 impl SessionID {
59 92
     fn as_str(&self) -> &str {
60 93
         self.0.as_str()
61 94
     }
95
+}
62 96
 
63
-    fn to_string(&self) -> String {
64
-        self.0.clone()
97
+impl Display for SessionID {
98
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
99
+        f.write_str(&self.0)
65 100
     }
66 101
 }
67 102
 
68 103
 /// Session instance
104
+///
105
+/// To access the active session, simply add it as an argument to a route function.
106
+///
107
+/// Sessions are started, restored, or expired in the `FromRequest::from_request()` method
108
+/// when a `Session` is prepared for one of the route functions.
69 109
 #[derive(Debug)]
70 110
 pub struct Session<'a, D>
71 111
     where
@@ -84,45 +124,76 @@ impl<'a, 'r, D> FromRequest<'a, 'r> for Session<'a, D>
84 124
     type Error = ();
85 125
 
86 126
     fn from_request(request: &'a Request<'r>) -> Outcome<Self, (Status, Self::Error), ()> {
87
-        let store : State<SessionStore<D>> = request.guard().unwrap();
127
+        let store: State<SessionStore<D>> = request.guard().unwrap();
88 128
         Outcome::Success(Session {
89 129
             id: request.local_cache(|| {
130
+                let store_ug = store.inner.upgradable_read();
131
+
90 132
                 // Resolve session ID
91
-                let id = if let Some(cookie) = request.cookies().get(SESSION_COOKIE) {
92
-                    SessionID(cookie.value().to_string())
133
+                let id = if let Some(cookie) = request.cookies().get(&store.config.cookie_name) {
134
+                    Some(SessionID(cookie.value().to_string()))
93 135
                 } else {
94
-                    SessionID(
95
-                        rand::thread_rng()
96
-                            .sample_iter(&rand::distributions::Alphanumeric)
97
-                            .take(SESSION_ID_LEN)
98
-                            .collect(),
99
-                    )
136
+                    None
100 137
                 };
101 138
 
102
-                let new_expiration = Instant::now().add(store.lifespan);
103
-                let mut wg = store.inner.write();
104
-                match wg.get_mut(id.as_str()) {
105
-                    Some(ses) => {
106
-                        // Check expiration
107
-                        if ses.expires <= Instant::now() {
108
-                            ses.data = D::default();
109
-                        }
110
-                        // Update expiry timestamp
111
-                        ses.expires = new_expiration;
112
-                    },
113
-                    None => {
114
-                        // New session
115
-                        wg.insert(
116
-                            id.to_string(),
117
-                            SessionInstance {
118
-                                data: D::default(),
119
-                                expires: new_expiration,
120
-                            }
121
-                        );
139
+                let expires = Instant::now().add(store.config.lifespan);
140
+
141
+                if let Some(m) = id.as_ref()
142
+                    .and_then(|token| store_ug.sessions.get(token.as_str()))
143
+                {
144
+                    // --- ID obtained from a cookie && session found in the store ---
145
+
146
+                    let mut inner = m.lock();
147
+                    if inner.expires <= Instant::now() {
148
+                        // Session expired, reuse the ID but drop data.
149
+                        inner.data = D::default();
122 150
                     }
123
-                };
124 151
 
125
-                id
152
+                    // Session is extended by making a request with valid ID
153
+                    inner.expires = expires;
154
+
155
+                    id.unwrap()
156
+                } else {
157
+                    // --- ID missing or session not found ---
158
+
159
+                    // Get exclusive write access to the map
160
+                    let mut store_wg = RwLockUpgradableReadGuard::upgrade(store_ug);
161
+
162
+                    // This branch runs less often, and we already have write access,
163
+                    // let's check if any sessions expired. We don't want to hog memory
164
+                    // forever by abandoned sessions (e.g. when a client lost their cookie)
165
+
166
+                    // Throttle by lifespan - e.g. sweep every hour
167
+                    if store_wg.last_expiry_sweep.elapsed() > store.config.lifespan {
168
+                        let now = Instant::now();
169
+                        store_wg.sessions
170
+                            .retain(|_k, v| v.lock().expires > now);
171
+
172
+                        store_wg.last_expiry_sweep = now;
173
+                    }
174
+
175
+                    // Find a new unique ID - we are still safely inside the write guard
176
+                    let new_id = SessionID(loop {
177
+                        let token: String = rand::thread_rng()
178
+                            .sample_iter(&rand::distributions::Alphanumeric)
179
+                            .take(store.config.cookie_len)
180
+                            .collect();
181
+
182
+                        if !store_wg.sessions.contains_key(&token) {
183
+                            break token;
184
+                        }
185
+                    });
186
+
187
+                    store_wg.sessions.insert(
188
+                        new_id.to_string(),
189
+                        Mutex::new(SessionInstance {
190
+                            data: Default::default(),
191
+                            expires,
192
+                        }),
193
+                    );
194
+
195
+                    new_id
196
+                }
126 197
             }),
127 198
             store,
128 199
         })
@@ -133,46 +204,90 @@ impl<'a, D> Session<'a, D>
133 204
     where
134 205
         D: 'static + Sync + Send + Default,
135 206
 {
136
-    /// Get the fairing object
137
-    pub fn fairing(lifespan: Duration) -> impl Fairing {
138
-        SessionFairing::<D> {
139
-            lifespan,
140
-            _phantom: PhantomData,
141
-        }
207
+    /// Create the session fairing.
208
+    ///
209
+    /// You can configure the session store by calling chained methods on the returned value
210
+    /// before passing it to `rocket.attach()`
211
+    pub fn fairing() -> SessionFairing<D> {
212
+        SessionFairing::<D>::new()
142 213
     }
143 214
 
144
-    /// Access the session store
145
-    pub fn get_store(&self) -> &SessionStore<D> {
146
-        &self.store
147
-    }
148
-
149
-    /// Set the session object to its default state
150
-    pub fn reset(&self) {
151
-        self.tap_mut(|m| {
215
+    /// Clear session data (replace the value with default)
216
+    pub fn clear(&self) {
217
+        self.tap(|m| {
152 218
             *m = D::default();
153 219
         })
154 220
     }
155 221
 
156
-    pub fn tap<T>(&self, func: impl FnOnce(&D) -> T) -> T {
157
-        let rg = self.store.inner.read();
158
-        let instance = rg.get(self.id.as_str()).unwrap();
159
-        func(&instance.data)
160
-    }
222
+    /// Access the session's data using a closure.
223
+    ///
224
+    /// The closure is called with the data value as a mutable argument,
225
+    /// and can return any value to be is passed up to the caller.
226
+    pub fn tap<T>(&self, func: impl FnOnce(&mut D) -> T) -> T {
227
+        // Use a read guard, so other already active sessions are not blocked
228
+        // from accessing the store. New incoming clients may be blocked until
229
+        // the tap() call finishes
230
+        let store_rg = self.store.inner.read();
231
+
232
+        // Unlock the session's mutex.
233
+        // Expiry was checked and prolonged at the beginning of the request
234
+        let mut instance = store_rg.sessions.get(self.id.as_str())
235
+            .expect("Session data unexpectedly missing")
236
+            .lock();
161 237
 
162
-    pub fn tap_mut<T>(&self, func: impl FnOnce(&mut D) -> T) -> T {
163
-        let mut wg = self.store.inner.write();
164
-        let instance = wg.get_mut(self.id.as_str()).unwrap();
165 238
         func(&mut instance.data)
166 239
     }
167 240
 }
168 241
 
169 242
 /// Fairing struct
170
-struct SessionFairing<D>
243
+#[derive(Default)]
244
+pub struct SessionFairing<D>
171 245
     where
172 246
         D: 'static + Sync + Send + Default,
173 247
 {
174
-    lifespan: Duration,
175
-    _phantom: PhantomData<D>,
248
+    config: SessionConfig,
249
+    phantom: PhantomData<D>,
250
+}
251
+
252
+impl<D> SessionFairing<D>
253
+    where
254
+        D: 'static + Sync + Send + Default
255
+{
256
+    fn new() -> Self {
257
+        Self::default()
258
+    }
259
+
260
+    /// Set session lifetime (expiration time).
261
+    ///
262
+    /// Call on the fairing before passing it to `rocket.attach()`
263
+    pub fn with_lifetime(mut self, time: Duration) -> Self {
264
+        self.config.lifespan = time;
265
+        self
266
+    }
267
+
268
+    /// Set session cookie name and length
269
+    ///
270
+    /// Call on the fairing before passing it to `rocket.attach()`
271
+    pub fn with_cookie_name(mut self, name: impl Into<Cow<'static, str>>) -> Self {
272
+        self.config.cookie_name = name.into();
273
+        self
274
+    }
275
+
276
+    /// Set session cookie name and length
277
+    ///
278
+    /// Call on the fairing before passing it to `rocket.attach()`
279
+    pub fn with_cookie_len(mut self, length: usize) -> Self {
280
+        self.config.cookie_len = length;
281
+        self
282
+    }
283
+
284
+    /// Set session cookie name and length
285
+    ///
286
+    /// Call on the fairing before passing it to `rocket.attach()`
287
+    pub fn with_cookie_path(mut self, path: impl Into<Cow<'static, str>>) -> Self {
288
+        self.config.cookie_path = path.into();
289
+        self
290
+    }
176 291
 }
177 292
 
178 293
 impl<D> Fairing for SessionFairing<D>
@@ -190,7 +305,7 @@ impl<D> Fairing for SessionFairing<D>
190 305
         // install the store singleton
191 306
         Ok(rocket.manage(SessionStore::<D> {
192 307
             inner: Default::default(),
193
-            lifespan: self.lifespan,
308
+            config: self.config.clone(),
194 309
         }))
195 310
     }
196 311
 
@@ -199,7 +314,8 @@ impl<D> Fairing for SessionFairing<D>
199 314
         let session = request.local_cache(|| SessionID("".to_string()));
200 315
 
201 316
         if !session.0.is_empty() {
202
-            response.adjoin_header(Cookie::build(SESSION_COOKIE, session.0.clone()).finish());
317
+            response.adjoin_header(Cookie::build(self.config.cookie_name.clone(), session.to_string())
318
+                .path("/").finish());
203 319
         }
204 320
     }
205 321
 }