diff --git a/src/lib.rs b/src/lib.rs index dee238203e5441c20a5208b80261464cfcef0ef1..6d9c4c26cd6661e7793049f60897eaf304bfc008 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ #![no_std] +#![deny(unsafe_op_in_unsafe_fn)] use core::cell::Cell; -use core::hint; use core::marker::PhantomPinned; use core::mem::{self, MaybeUninit}; use core::ops::{Deref, DerefMut}; @@ -13,11 +13,13 @@ pub trait Reclaim: Sized { } unsafe fn reclaim_shim<T: Reclaim>(erased_ptr: NonNull<Root<Erased>>) { - let ptr = erased_ptr.cast::<Root<T>>(); + let root_ptr = erased_ptr.cast::<Root<T>>(); - let value = unsafe { ptr::addr_of!((*(*ptr.as_ptr()).target.as_ptr()).value).read() }; + // SAFETY: The caller guarantees that this value is initialized. + let value = unsafe { ptr::addr_of!((*(*root_ptr.as_ptr()).anchor.as_ptr()).value).read() }; - let unclaimed = Unclaimed { ptr }; + // SAFETY: The caller guarantees the state is "unclaimed". + let unclaimed = Unclaimed { root_ptr }; value.reclaim(unclaimed); } @@ -33,17 +35,18 @@ const STATE_UNCLAIMED: usize = usize::MAX - 1; const STATE_SHARED_MAX: usize = isize::MAX as usize; const STATE_SHARED_INIT: usize = 0; +// Fixed layout is required so that the header is in the same place for all T. #[repr(C)] pub struct Root<T> { state: Cell<usize>, reclaim_shim: unsafe fn(NonNull<Root<Erased>>), _pin: PhantomPinned, - target: MaybeUninit<ProjectTarget<T>>, + anchor: MaybeUninit<ProjectAnchor<T>>, } impl<T> Drop for Root<T> { fn drop(&mut self) { - let state = self.state.get(); + let state = *self.state.get_mut(); if state != STATE_UNPINNED { struct Abort; impl Drop for Abort { @@ -67,25 +70,33 @@ impl<T> Root<T> { state: Cell::new(STATE_UNPINNED), reclaim_shim: reclaim_shim::<T>, _pin: PhantomPinned, - target: MaybeUninit::uninit(), + anchor: MaybeUninit::uninit(), } } pub fn unclaimed(self: Pin<&mut Self>) -> Unclaimed<T> { - if self.state.get() != STATE_UNPINNED { + // SAFETY: We don't move out of this reference. + let root = unsafe { self.get_unchecked_mut() }; + + let state = root.state.get_mut(); + if *state != STATE_UNPINNED { panic!("unclaimed called multiple times"); } - self.state.set(STATE_UNCLAIMED); + *state = STATE_UNCLAIMED; - let ptr = unsafe { NonNull::from(self.get_unchecked_mut()) }; - let ptr_erased = ptr.cast::<Root<Erased>>(); + let root_ptr = NonNull::from(root); + let root_erased = root_ptr.cast::<Root<Erased>>(); + // SAFETY: This is a partial write into a `MaybeUninit` behind a mutable reference. unsafe { - let target_root = ptr::addr_of_mut!((*(*ptr.as_ptr()).target.as_mut_ptr()).root); - target_root.write(Cell::new(Some(ptr_erased))); + let anchor_root = + ptr::addr_of_mut!((*(*root_ptr.as_ptr()).anchor.as_mut_ptr()).root_ptr); + anchor_root.write(Cell::new(Some(root_erased))); } - Unclaimed { ptr } + // SAFETY: The pointer is a pinned shared read-write borrow and we set the + // state to "unclaimed" above. + Unclaimed { root_ptr } } } @@ -95,21 +106,22 @@ impl<T: Reclaim> Default for Root<T> { } } -pub struct ProjectTarget<T: ?Sized> { - root: Cell<Option<NonNull<Root<Erased>>>>, +pub struct ProjectAnchor<T: ?Sized> { + // pinned shared read-write borrow + root_ptr: Cell<Option<NonNull<Root<Erased>>>>, value: T, } -impl<T> ProjectTarget<T> { +impl<T> ProjectAnchor<T> { pub const fn new(value: T) -> Self { Self { - root: Cell::new(None), + root_ptr: Cell::new(None), value, } } } -impl<T: ?Sized> Deref for ProjectTarget<T> { +impl<T: ?Sized> Deref for ProjectAnchor<T> { type Target = T; fn deref(&self) -> &T { @@ -117,7 +129,7 @@ impl<T: ?Sized> Deref for ProjectTarget<T> { } } -impl<T: ?Sized> DerefMut for ProjectTarget<T> { +impl<T: ?Sized> DerefMut for ProjectAnchor<T> { fn deref_mut(&mut self) -> &mut T { &mut self.value } @@ -125,87 +137,106 @@ impl<T: ?Sized> DerefMut for ProjectTarget<T> { #[repr(transparent)] pub struct Unclaimed<T> { - ptr: NonNull<Root<T>>, + // pinned shared read-write borrow + root_ptr: NonNull<Root<T>>, } impl<T> Drop for Unclaimed<T> { fn drop(&mut self) { - unsafe { - debug_assert_eq!((*self.ptr.as_ptr()).state.get(), STATE_UNCLAIMED); - (*self.ptr.as_ptr()).state.set(STATE_UNPINNED); - } + // SAFETY: The pointer is a pinned shared read-write borrow and we don't + // move out of it. + let root = unsafe { self.root_ptr.as_mut() }; + let state = root.state.get_mut(); + debug_assert_eq!(*state, STATE_UNCLAIMED); + *state = STATE_UNPINNED; } } impl<T> Unclaimed<T> { - pub fn claim(self, value: T) -> Xrc<T> { - unsafe { - debug_assert_eq!((*self.ptr.as_ptr()).state.get(), STATE_UNCLAIMED); - (*self.ptr.as_ptr()).state.set(STATE_SHARED_INIT); - } + pub fn claim(mut self, value: T) -> Xrc<T> { + // SAFETY: The pointer is a pinned shared read-write borrow and we don't + // move out of it. + let root = unsafe { self.root_ptr.as_mut() }; + + let state = root.state.get_mut(); + debug_assert_eq!(*state, STATE_UNCLAIMED); + *state = STATE_SHARED_INIT; + // SAFETY: This is a partial write into a `MaybeUninit` behind a mutable reference. unsafe { - let target_value = ptr::addr_of_mut!((*(*self.ptr.as_ptr()).target.as_mut_ptr()).value); - target_value.write(value); + let anchor_value = ptr::addr_of_mut!((*root.anchor.as_mut_ptr()).value); + anchor_value.write(value); } - // shared read-only reborrow - let target = unsafe { (*self.ptr.as_ptr()).target.assume_init_ref() }; - let ptr = NonNull::from(target); + // SAFETY: `root.anchor` is fully initialized, because we initialized + // `.root_ptr` in `Root::unclaimed` and `.value` above. + // Note that we intentionally create a shared read-only borrow here, + // because `Xrc` must work with read-only borrows. + let anchor = unsafe { root.anchor.assume_init_ref() }; + let anchor_ptr = NonNull::from(anchor); mem::forget(self); - Xrc { ptr } + + // SAFETY: + // * The pointer is a shared read-only borrow. + // * We set the state to "shared" above. + // * The `anchor` is derived from `root`, `root` is derived from + // `self.root_ptr` and `anchor.root_ptr` is `Some` and equal to + // `self.root_ptr`, because we set it up that way in `Root::unclaimed`. + // Therefore, `anchor` is derived from `anchor.root_ptr`. + // Because we don't do any (non-interior) mutable accesses of + // the root after this point, the anchor will remain valid. + Xrc { anchor_ptr } } } #[repr(transparent)] pub struct Xrc<T: ?Sized> { - ptr: NonNull<ProjectTarget<T>>, + // shared read-only borrow that outlives `anchor.root_ptr`, which must be `Some` + // Note that `anchor_ptr` may be derived from `anchor.root_ptr` with an equal + // lifetime, which means that we can only write to interior mutable parts of the + // root, because otherwise the anchor may get invalidated. + anchor_ptr: NonNull<ProjectAnchor<T>>, } impl<T: ?Sized> Drop for Xrc<T> { fn drop(&mut self) { - let Some(root) = Self::inner(self).root.get() else { - debug_assert!(false, "missing root"); - unsafe { - hint::unreachable_unchecked(); - } - }; + let root_header = Self::root_header(self); - let state_ref = unsafe { &(*root.as_ptr()).state }; - let state = state_ref.get(); + let state = root_header.state.get(); debug_assert!(state <= STATE_SHARED_MAX); if state != STATE_SHARED_INIT { - state_ref.set(state - 1); + root_header.state.set(state - 1); return; } - state_ref.set(STATE_UNCLAIMED); + root_header.state.set(STATE_UNCLAIMED); - let reclaim_shim = unsafe { (*root.as_ptr()).reclaim_shim }; + let reclaim_shim = root_header.reclaim_shim; + let root_erased = Self::root_ptr(self); + + // SAFETY: + // * We checked that there are no other `Xrc`s pointing to the same data. + // * We set the state to "unclaimed" and the old value is still in place. unsafe { - reclaim_shim(root); + reclaim_shim(root_erased); } } } impl<T> Clone for Xrc<T> { fn clone(&self) -> Self { - let root = Self::inner(self).root.get(); - let Some(root) = root else { - debug_assert!(false, "missing root"); - unsafe { - hint::unreachable_unchecked(); - } - }; + let root = Self::root_header(self); - let state_ref = unsafe { &(*root.as_ptr()).state }; - let state = state_ref.get(); + let state = root.state.get(); debug_assert!(state <= STATE_SHARED_MAX); if state == STATE_SHARED_MAX { panic!("ref count overflow"); } - state_ref.set(state + 1); + root.state.set(state + 1); - Self { ptr: self.ptr } + // SAFETY: We increased the reference count above. + Self { + anchor_ptr: self.anchor_ptr, + } } } @@ -213,7 +244,7 @@ impl<T: ?Sized> Deref for Xrc<T> { type Target = T; fn deref(&self) -> &T { - &Self::inner(self).value + &Self::anchor(self).value } } @@ -221,33 +252,56 @@ impl<T: ?Sized> Xrc<T> { pub fn project<U, F>(this: Self, f: F) -> Xrc<U> where U: ?Sized, - F: FnOnce(&ProjectTarget<T>) -> &ProjectTarget<U>, + F: FnOnce(&ProjectAnchor<T>) -> &ProjectAnchor<U>, { - let source = Self::inner(&this); - let Some(source_root) = source.root.get() else { - debug_assert!(false, "missing root"); - unsafe { - hint::unreachable_unchecked(); - } - }; + let source = Self::anchor(&this); let target = f(source); - if let Some(target_root) = target.root.get() { + let source_root = Self::root_ptr(&this); + + if let Some(target_root) = target.root_ptr.get() { if target_root != source_root { panic!("projected from different roots"); } } else { - target.root.set(Some(source_root)); + target.root_ptr.set(Some(source_root)); } - let ptr = NonNull::from(target); + let anchor_ptr = NonNull::from(target); mem::forget(this); - Xrc { ptr } + + // SAFETY: + // * We ensure that `target.root_ptr == source.root_ptr` above. + // * Because `source.root_ptr` is `Some`, `target.root_ptr` is also `Some`. + // * The higher-ranked lifetime bound on the closure ensures that `target` + // outlives `source`, which outlives `source.root_ptr`. + // Therefore, `target` also outlives `target.root_ptr`. + // * We forget `self` above, so the reference count is unchanged. + Xrc { anchor_ptr } + } + + fn anchor(this: &Self) -> &ProjectAnchor<T> { + // SAFETY: The pointer is a shared read-write borrow. + unsafe { this.anchor_ptr.as_ref() } + } + + /// Returns a read-write pointer to the root with full provenance. + fn root_ptr(this: &Self) -> NonNull<Root<Erased>> { + let root_ptr = Self::anchor(this).root_ptr.get(); + // SAFETY: The root pointer must be set to `Some` before this + // `Xrc` is constructed. + unsafe { root_ptr.unwrap_unchecked() } } - fn inner(this: &Self) -> &ProjectTarget<T> { - unsafe { this.ptr.as_ref() } + /// Returns a read-only reborrow of the erased root. + fn root_header(this: &Self) -> &Root<Erased> { + // SAFETY: We can derive a new shared read-only borrow without invalidating + // other derived shared read-only borrows, including the anchor pointer. + // Because `Root<T>` has a fixed layout, `T` is at the end, and `Erased` + // is a zero-sized type with alignment 1, we can reinterpret `Root<T>` + // as `Root<Erased>`. + unsafe { Self::root_ptr(this).as_ref() } } } @@ -264,6 +318,7 @@ mod tests { use std::rc::Rc; fn iter_pinned_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<&mut T>> { + // SAFETY: The elements of a slice are structurally pinned. unsafe { slice .get_unchecked_mut() @@ -370,13 +425,13 @@ mod tests { fn project() { struct Value { a: i32, - b: ProjectTarget<i32>, + b: ProjectAnchor<i32>, } let root = pin!(Root::new()); let xrc = root.unclaimed().claim(NoReclaim(Value { a: 1, - b: ProjectTarget::new(2), + b: ProjectAnchor::new(2), })); assert_eq!(xrc.0.a, 1); assert_eq!(*xrc.0.b, 2); @@ -397,13 +452,13 @@ mod tests { #[test] #[should_panic = "projected from different roots"] fn project_multiple() { - let target = Rc::new(ProjectTarget::new(0)); + let anchor = Rc::new(ProjectAnchor::new(0)); let root1 = pin!(Root::new()); - let xrc1 = root1.unclaimed().claim(NoReclaim(target.clone())); + let xrc1 = root1.unclaimed().claim(NoReclaim(anchor.clone())); let root2 = pin!(Root::new()); - let xrc2 = root2.unclaimed().claim(NoReclaim(target)); + let xrc2 = root2.unclaimed().claim(NoReclaim(anchor)); let _ = Xrc::project(xrc1, |xrc| &*xrc.0); let _ = Xrc::project(xrc2, |xrc| &*xrc.0); @@ -450,8 +505,8 @@ mod tests { } struct Value { - number: ProjectTarget<i32>, - string: ProjectTarget<String>, + number: ProjectAnchor<i32>, + string: ProjectAnchor<String>, } #[allow(clippy::declare_interior_mutable_const)] @@ -466,8 +521,8 @@ mod tests { let mut values = Vec::new(); for n in 0..8 { let value = Value { - number: ProjectTarget::new(n), - string: ProjectTarget::new(n.to_string()), + number: ProjectAnchor::new(n), + string: ProjectAnchor::new(n.to_string()), }; let xrc = alloc.try_alloc(value).ok().unwrap(); values.push(xrc); @@ -490,8 +545,8 @@ mod tests { .eq(0..8)); let value = Value { - number: ProjectTarget::new(0), - string: ProjectTarget::new(String::new()), + number: ProjectAnchor::new(0), + string: ProjectAnchor::new(String::new()), }; let value = alloc.try_alloc(value).err().unwrap();