From 43d87f69f6dd5874c6de32e43af1cf7373f05754 Mon Sep 17 00:00:00 2001
From: Lukas Markeffsky <@>
Date: Thu, 1 Feb 2024 14:14:46 +0000
Subject: [PATCH] add safety comments

---
 src/lib.rs | 227 +++++++++++++++++++++++++++++++++--------------------
 1 file changed, 141 insertions(+), 86 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index dee2382..6d9c4c2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,7 +1,7 @@
 #![no_std]
+#![deny(unsafe_op_in_unsafe_fn)]
 
 use core::cell::Cell;
-use core::hint;
 use core::marker::PhantomPinned;
 use core::mem::{self, MaybeUninit};
 use core::ops::{Deref, DerefMut};
@@ -13,11 +13,13 @@ pub trait Reclaim: Sized {
 }
 
 unsafe fn reclaim_shim<T: Reclaim>(erased_ptr: NonNull<Root<Erased>>) {
-    let ptr = erased_ptr.cast::<Root<T>>();
+    let root_ptr = erased_ptr.cast::<Root<T>>();
 
-    let value = unsafe { ptr::addr_of!((*(*ptr.as_ptr()).target.as_ptr()).value).read() };
+    // SAFETY: The caller guarantees that this value is initialized.
+    let value = unsafe { ptr::addr_of!((*(*root_ptr.as_ptr()).anchor.as_ptr()).value).read() };
 
-    let unclaimed = Unclaimed { ptr };
+    // SAFETY: The caller guarantees the state is "unclaimed".
+    let unclaimed = Unclaimed { root_ptr };
     value.reclaim(unclaimed);
 }
 
@@ -33,17 +35,18 @@ const STATE_UNCLAIMED: usize = usize::MAX - 1;
 const STATE_SHARED_MAX: usize = isize::MAX as usize;
 const STATE_SHARED_INIT: usize = 0;
 
+// Fixed layout is required so that the header is in the same place for all T.
 #[repr(C)]
 pub struct Root<T> {
     state: Cell<usize>,
     reclaim_shim: unsafe fn(NonNull<Root<Erased>>),
     _pin: PhantomPinned,
-    target: MaybeUninit<ProjectTarget<T>>,
+    anchor: MaybeUninit<ProjectAnchor<T>>,
 }
 
 impl<T> Drop for Root<T> {
     fn drop(&mut self) {
-        let state = self.state.get();
+        let state = *self.state.get_mut();
         if state != STATE_UNPINNED {
             struct Abort;
             impl Drop for Abort {
@@ -67,25 +70,33 @@ impl<T> Root<T> {
             state: Cell::new(STATE_UNPINNED),
             reclaim_shim: reclaim_shim::<T>,
             _pin: PhantomPinned,
-            target: MaybeUninit::uninit(),
+            anchor: MaybeUninit::uninit(),
         }
     }
 
     pub fn unclaimed(self: Pin<&mut Self>) -> Unclaimed<T> {
-        if self.state.get() != STATE_UNPINNED {
+        // SAFETY: We don't move out of this reference.
+        let root = unsafe { self.get_unchecked_mut() };
+
+        let state = root.state.get_mut();
+        if *state != STATE_UNPINNED {
             panic!("unclaimed called multiple times");
         }
-        self.state.set(STATE_UNCLAIMED);
+        *state = STATE_UNCLAIMED;
 
-        let ptr = unsafe { NonNull::from(self.get_unchecked_mut()) };
-        let ptr_erased = ptr.cast::<Root<Erased>>();
+        let root_ptr = NonNull::from(root);
+        let root_erased = root_ptr.cast::<Root<Erased>>();
 
+        // SAFETY: This is a partial write into a `MaybeUninit` behind a mutable reference.
         unsafe {
-            let target_root = ptr::addr_of_mut!((*(*ptr.as_ptr()).target.as_mut_ptr()).root);
-            target_root.write(Cell::new(Some(ptr_erased)));
+            let anchor_root =
+                ptr::addr_of_mut!((*(*root_ptr.as_ptr()).anchor.as_mut_ptr()).root_ptr);
+            anchor_root.write(Cell::new(Some(root_erased)));
         }
 
-        Unclaimed { ptr }
+        // SAFETY: The pointer is a pinned shared read-write borrow and we set the
+        // state to "unclaimed" above.
+        Unclaimed { root_ptr }
     }
 }
 
@@ -95,21 +106,22 @@ impl<T: Reclaim> Default for Root<T> {
     }
 }
 
-pub struct ProjectTarget<T: ?Sized> {
-    root: Cell<Option<NonNull<Root<Erased>>>>,
+pub struct ProjectAnchor<T: ?Sized> {
+    // pinned shared read-write borrow
+    root_ptr: Cell<Option<NonNull<Root<Erased>>>>,
     value: T,
 }
 
-impl<T> ProjectTarget<T> {
+impl<T> ProjectAnchor<T> {
     pub const fn new(value: T) -> Self {
         Self {
-            root: Cell::new(None),
+            root_ptr: Cell::new(None),
             value,
         }
     }
 }
 
-impl<T: ?Sized> Deref for ProjectTarget<T> {
+impl<T: ?Sized> Deref for ProjectAnchor<T> {
     type Target = T;
 
     fn deref(&self) -> &T {
@@ -117,7 +129,7 @@ impl<T: ?Sized> Deref for ProjectTarget<T> {
     }
 }
 
-impl<T: ?Sized> DerefMut for ProjectTarget<T> {
+impl<T: ?Sized> DerefMut for ProjectAnchor<T> {
     fn deref_mut(&mut self) -> &mut T {
         &mut self.value
     }
@@ -125,87 +137,106 @@ impl<T: ?Sized> DerefMut for ProjectTarget<T> {
 
 #[repr(transparent)]
 pub struct Unclaimed<T> {
-    ptr: NonNull<Root<T>>,
+    // pinned shared read-write borrow
+    root_ptr: NonNull<Root<T>>,
 }
 
 impl<T> Drop for Unclaimed<T> {
     fn drop(&mut self) {
-        unsafe {
-            debug_assert_eq!((*self.ptr.as_ptr()).state.get(), STATE_UNCLAIMED);
-            (*self.ptr.as_ptr()).state.set(STATE_UNPINNED);
-        }
+        // SAFETY: The pointer is a pinned shared read-write borrow and we don't
+        // move out of it.
+        let root = unsafe { self.root_ptr.as_mut() };
+        let state = root.state.get_mut();
+        debug_assert_eq!(*state, STATE_UNCLAIMED);
+        *state = STATE_UNPINNED;
     }
 }
 
 impl<T> Unclaimed<T> {
-    pub fn claim(self, value: T) -> Xrc<T> {
-        unsafe {
-            debug_assert_eq!((*self.ptr.as_ptr()).state.get(), STATE_UNCLAIMED);
-            (*self.ptr.as_ptr()).state.set(STATE_SHARED_INIT);
-        }
+    pub fn claim(mut self, value: T) -> Xrc<T> {
+        // SAFETY: The pointer is a pinned shared read-write borrow and we don't
+        // move out of it.
+        let root = unsafe { self.root_ptr.as_mut() };
+
+        let state = root.state.get_mut();
+        debug_assert_eq!(*state, STATE_UNCLAIMED);
+        *state = STATE_SHARED_INIT;
 
+        // SAFETY: This is a partial write into a `MaybeUninit` behind a mutable reference.
         unsafe {
-            let target_value = ptr::addr_of_mut!((*(*self.ptr.as_ptr()).target.as_mut_ptr()).value);
-            target_value.write(value);
+            let anchor_value = ptr::addr_of_mut!((*root.anchor.as_mut_ptr()).value);
+            anchor_value.write(value);
         }
 
-        // shared read-only reborrow
-        let target = unsafe { (*self.ptr.as_ptr()).target.assume_init_ref() };
-        let ptr = NonNull::from(target);
+        // SAFETY: `root.anchor` is fully initialized, because we initialized
+        // `.root_ptr` in `Root::unclaimed` and `.value` above.
+        // Note that we intentionally create a shared read-only borrow here,
+        // because `Xrc` must work with read-only borrows.
+        let anchor = unsafe { root.anchor.assume_init_ref() };
+        let anchor_ptr = NonNull::from(anchor);
         mem::forget(self);
-        Xrc { ptr }
+
+        // SAFETY:
+        // * The pointer is a shared read-only borrow.
+        // * We set the state to "shared" above.
+        // * The `anchor` is derived from `root`, `root` is derived from
+        //   `self.root_ptr` and `anchor.root_ptr` is `Some` and equal to
+        //   `self.root_ptr`, because we set it up that way in `Root::unclaimed`.
+        //   Therefore, `anchor` is derived from `anchor.root_ptr`.
+        //   Because we don't do any (non-interior) mutable accesses of
+        //   the root after this point, the anchor will remain valid.
+        Xrc { anchor_ptr }
     }
 }
 
 #[repr(transparent)]
 pub struct Xrc<T: ?Sized> {
-    ptr: NonNull<ProjectTarget<T>>,
+    // shared read-only borrow that outlives `anchor.root_ptr`, which must be `Some`
+    // Note that `anchor_ptr` may be derived from `anchor.root_ptr` with an equal
+    // lifetime, which means that we can only write to interior mutable parts of the
+    // root, because otherwise the anchor may get invalidated.
+    anchor_ptr: NonNull<ProjectAnchor<T>>,
 }
 
 impl<T: ?Sized> Drop for Xrc<T> {
     fn drop(&mut self) {
-        let Some(root) = Self::inner(self).root.get() else {
-            debug_assert!(false, "missing root");
-            unsafe {
-                hint::unreachable_unchecked();
-            }
-        };
+        let root_header = Self::root_header(self);
 
-        let state_ref = unsafe { &(*root.as_ptr()).state };
-        let state = state_ref.get();
+        let state = root_header.state.get();
         debug_assert!(state <= STATE_SHARED_MAX);
         if state != STATE_SHARED_INIT {
-            state_ref.set(state - 1);
+            root_header.state.set(state - 1);
             return;
         }
-        state_ref.set(STATE_UNCLAIMED);
+        root_header.state.set(STATE_UNCLAIMED);
 
-        let reclaim_shim = unsafe { (*root.as_ptr()).reclaim_shim };
+        let reclaim_shim = root_header.reclaim_shim;
+        let root_erased = Self::root_ptr(self);
+
+        // SAFETY:
+        // * We checked that there are no other `Xrc`s pointing to the same data.
+        // * We set the state to "unclaimed" and the old value is still in place.
         unsafe {
-            reclaim_shim(root);
+            reclaim_shim(root_erased);
         }
     }
 }
 
 impl<T> Clone for Xrc<T> {
     fn clone(&self) -> Self {
-        let root = Self::inner(self).root.get();
-        let Some(root) = root else {
-            debug_assert!(false, "missing root");
-            unsafe {
-                hint::unreachable_unchecked();
-            }
-        };
+        let root = Self::root_header(self);
 
-        let state_ref = unsafe { &(*root.as_ptr()).state };
-        let state = state_ref.get();
+        let state = root.state.get();
         debug_assert!(state <= STATE_SHARED_MAX);
         if state == STATE_SHARED_MAX {
             panic!("ref count overflow");
         }
-        state_ref.set(state + 1);
+        root.state.set(state + 1);
 
-        Self { ptr: self.ptr }
+        // SAFETY: We increased the reference count above.
+        Self {
+            anchor_ptr: self.anchor_ptr,
+        }
     }
 }
 
@@ -213,7 +244,7 @@ impl<T: ?Sized> Deref for Xrc<T> {
     type Target = T;
 
     fn deref(&self) -> &T {
-        &Self::inner(self).value
+        &Self::anchor(self).value
     }
 }
 
@@ -221,33 +252,56 @@ impl<T: ?Sized> Xrc<T> {
     pub fn project<U, F>(this: Self, f: F) -> Xrc<U>
     where
         U: ?Sized,
-        F: FnOnce(&ProjectTarget<T>) -> &ProjectTarget<U>,
+        F: FnOnce(&ProjectAnchor<T>) -> &ProjectAnchor<U>,
     {
-        let source = Self::inner(&this);
-        let Some(source_root) = source.root.get() else {
-            debug_assert!(false, "missing root");
-            unsafe {
-                hint::unreachable_unchecked();
-            }
-        };
+        let source = Self::anchor(&this);
 
         let target = f(source);
 
-        if let Some(target_root) = target.root.get() {
+        let source_root = Self::root_ptr(&this);
+
+        if let Some(target_root) = target.root_ptr.get() {
             if target_root != source_root {
                 panic!("projected from different roots");
             }
         } else {
-            target.root.set(Some(source_root));
+            target.root_ptr.set(Some(source_root));
         }
 
-        let ptr = NonNull::from(target);
+        let anchor_ptr = NonNull::from(target);
         mem::forget(this);
-        Xrc { ptr }
+
+        // SAFETY:
+        // * We ensure that `target.root_ptr == source.root_ptr` above.
+        // * Because `source.root_ptr` is `Some`, `target.root_ptr` is also `Some`.
+        // * The higher-ranked lifetime bound on the closure ensures that `target`
+        //   outlives `source`, which outlives `source.root_ptr`.
+        //   Therefore, `target` also outlives `target.root_ptr`.
+        // * We forget `self` above, so the reference count is unchanged.
+        Xrc { anchor_ptr }
+    }
+
+    fn anchor(this: &Self) -> &ProjectAnchor<T> {
+        // SAFETY: The pointer is a shared read-write borrow.
+        unsafe { this.anchor_ptr.as_ref() }
+    }
+
+    /// Returns a read-write pointer to the root with full provenance.
+    fn root_ptr(this: &Self) -> NonNull<Root<Erased>> {
+        let root_ptr = Self::anchor(this).root_ptr.get();
+        // SAFETY: The root pointer must be set to `Some` before this
+        // `Xrc` is constructed.
+        unsafe { root_ptr.unwrap_unchecked() }
     }
 
-    fn inner(this: &Self) -> &ProjectTarget<T> {
-        unsafe { this.ptr.as_ref() }
+    /// Returns a read-only reborrow of the erased root.
+    fn root_header(this: &Self) -> &Root<Erased> {
+        // SAFETY: We can derive a new shared read-only borrow without invalidating
+        // other derived shared read-only borrows, including the anchor pointer.
+        // Because `Root<T>` has a fixed layout, `T` is at the end, and `Erased`
+        // is a zero-sized type with alignment 1, we can reinterpret `Root<T>`
+        // as `Root<Erased>`.
+        unsafe { Self::root_ptr(this).as_ref() }
     }
 }
 
@@ -264,6 +318,7 @@ mod tests {
     use std::rc::Rc;
 
     fn iter_pinned_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<&mut T>> {
+        // SAFETY: The elements of a slice are structurally pinned.
         unsafe {
             slice
                 .get_unchecked_mut()
@@ -370,13 +425,13 @@ mod tests {
     fn project() {
         struct Value {
             a: i32,
-            b: ProjectTarget<i32>,
+            b: ProjectAnchor<i32>,
         }
 
         let root = pin!(Root::new());
         let xrc = root.unclaimed().claim(NoReclaim(Value {
             a: 1,
-            b: ProjectTarget::new(2),
+            b: ProjectAnchor::new(2),
         }));
         assert_eq!(xrc.0.a, 1);
         assert_eq!(*xrc.0.b, 2);
@@ -397,13 +452,13 @@ mod tests {
     #[test]
     #[should_panic = "projected from different roots"]
     fn project_multiple() {
-        let target = Rc::new(ProjectTarget::new(0));
+        let anchor = Rc::new(ProjectAnchor::new(0));
 
         let root1 = pin!(Root::new());
-        let xrc1 = root1.unclaimed().claim(NoReclaim(target.clone()));
+        let xrc1 = root1.unclaimed().claim(NoReclaim(anchor.clone()));
 
         let root2 = pin!(Root::new());
-        let xrc2 = root2.unclaimed().claim(NoReclaim(target));
+        let xrc2 = root2.unclaimed().claim(NoReclaim(anchor));
 
         let _ = Xrc::project(xrc1, |xrc| &*xrc.0);
         let _ = Xrc::project(xrc2, |xrc| &*xrc.0);
@@ -450,8 +505,8 @@ mod tests {
         }
 
         struct Value {
-            number: ProjectTarget<i32>,
-            string: ProjectTarget<String>,
+            number: ProjectAnchor<i32>,
+            string: ProjectAnchor<String>,
         }
 
         #[allow(clippy::declare_interior_mutable_const)]
@@ -466,8 +521,8 @@ mod tests {
         let mut values = Vec::new();
         for n in 0..8 {
             let value = Value {
-                number: ProjectTarget::new(n),
-                string: ProjectTarget::new(n.to_string()),
+                number: ProjectAnchor::new(n),
+                string: ProjectAnchor::new(n.to_string()),
             };
             let xrc = alloc.try_alloc(value).ok().unwrap();
             values.push(xrc);
@@ -490,8 +545,8 @@ mod tests {
             .eq(0..8));
 
         let value = Value {
-            number: ProjectTarget::new(0),
-            string: ProjectTarget::new(String::new()),
+            number: ProjectAnchor::new(0),
+            string: ProjectAnchor::new(String::new()),
         };
 
         let value = alloc.try_alloc(value).err().unwrap();
-- 
GitLab