use std::fmt; use std::fmt::Formatter; use std::sync::atomic::{AtomicU32, Ordering}; use crate::runtime::fault::Fault; use crate::runtime::run_thread::ThreadToken; use crate::runtime::span::MemorySpan; /// Records memory claims and protects from illegal access #[derive(Debug, Default)] struct MemoryGuard { claims: Vec, counter: AtomicU32, } #[derive(Clone, Copy, Eq, PartialEq, Debug, Ord, PartialOrd)] pub struct ClaimId(pub u32); impl fmt::Display for ClaimId { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } #[derive(Debug, Clone)] struct Claim { owner: ThreadToken, span: MemorySpan, id: ClaimId, } impl MemoryGuard { pub fn new() -> Self { Default::default() } /// Claim a memory area pub fn claim(&mut self, owner: ThreadToken, span: MemorySpan) -> Result { // naive for claim in &self.claims { if claim.span.intersects(span) && claim.owner != owner { return Err(Fault::MemoryLocked { area: claim.span, owner: claim.owner, }); } } let id = self.next_id(); self.claims.push(Claim { id, owner, span, }); Ok(id) } /// Get a unique claim ID and increment the counter pub fn next_id(&self) -> ClaimId { ClaimId(self.counter.fetch_and(1, Ordering::Relaxed)) } /// Get the next claim ID (ID is incremented after calling "next"). /// May be used for release_owned_after() pub fn epoch(&self) -> ClaimId { ClaimId(self.counter.load(Ordering::Relaxed)) } /// Release a claim by claim ID pub fn release(&mut self, owner: ThreadToken, claim: ClaimId) -> Result<(), Fault> { match self.claims.iter().position(|c| c.id == claim && c.owner == owner) { Some(pos) => { self.claims.swap_remove(pos); Ok(()) } None => Err(Fault::ClaimNotExist { claim, owner }), } } /// Release all owned by a thread (thread ends) pub fn release_owned(&mut self, owner: ThreadToken) { self.claims.retain(|c| c.owner != owner); } /// Release all owned by a thread, with claim ID >= a given value /// (return from a subroutine) pub fn release_owned_after(&mut self, owner: ThreadToken, epoch: ClaimId) { self.claims.retain(|c| c.owner != owner || c.id >= epoch); } }