From 4307c685579bf511b458e7ec3e0be902d0d93c9c Mon Sep 17 00:00:00 2001
From: Lukas Markeffsky <@>
Date: Thu, 1 Feb 2024 10:21:09 +0000
Subject: [PATCH] fix soundness

---
 src/lib.rs | 38 ++++++++++++++++++++++++++++++++------
 1 file changed, 32 insertions(+), 6 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 16b985a..5ec943a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -19,7 +19,7 @@ unsafe fn reclaim_shim<T: Reclaim>(erased_ptr: NonNull<Root<Erased>>) {
         (*ptr.as_ptr()).state.set(STATE_UNCLAIMED);
     };
 
-    let value = unsafe { (*ptr.as_ptr()).target.assume_init_read().value };
+    let value = unsafe { ptr::addr_of!((*(*ptr.as_ptr()).target.as_ptr()).value).read() };
 
     let unclaimed = Unclaimed { ptr };
     value.reclaim(unclaimed);
@@ -48,7 +48,7 @@ pub struct Root<T> {
 impl<T> Drop for Root<T> {
     fn drop(&mut self) {
         let state = self.state.get();
-        if state <= STATE_SHARED_MAX {
+        if state != STATE_UNPINNED {
             struct Abort;
             impl Drop for Abort {
                 fn drop(&mut self) {
@@ -132,6 +132,15 @@ pub struct Unclaimed<T> {
     ptr: NonNull<Root<T>>,
 }
 
+impl<T> Drop for Unclaimed<T> {
+    fn drop(&mut self) {
+        unsafe {
+            debug_assert_eq!((*self.ptr.as_ptr()).state.get(), STATE_UNCLAIMED);
+            (*self.ptr.as_ptr()).state.set(STATE_UNPINNED);
+        }
+    }
+}
+
 impl<T> Unclaimed<T> {
     pub fn claim(self, value: T) -> Xrc<T> {
         unsafe {
@@ -147,6 +156,7 @@ impl<T> Unclaimed<T> {
         // shared read-only reborrow
         let target = unsafe { (*self.ptr.as_ptr()).target.assume_init_ref() };
         let ptr = NonNull::from(target);
+        mem::forget(self);
         Xrc { ptr }
     }
 }
@@ -280,7 +290,16 @@ mod tests {
 
     #[test]
     #[ignore = "will abort"]
-    fn abort() {
+    fn abort_unclaimed() {
+        #[allow(clippy::needless_late_init)]
+        let _unclaimed;
+        let root = pin!(Root::<NoReclaim<()>>::new());
+        _unclaimed = root.unclaimed();
+    }
+
+    #[test]
+    #[ignore = "will abort"]
+    fn abort_xrc() {
         #[allow(clippy::needless_late_init)]
         let _xrc;
         let root = pin!(Root::new());
@@ -288,10 +307,17 @@ mod tests {
     }
 
     #[test]
-    #[should_panic = "unclaimed called multiple times"]
-    fn multiple_unclaimed() {
-        let mut root = pin!(Root::<NoReclaim<i32>>::new());
+    fn multiple_unclaimed_separate() {
+        let mut root = pin!(Root::<NoReclaim<()>>::new());
+        let _ = root.as_mut().unclaimed();
         let _ = root.as_mut().unclaimed();
+    }
+
+    #[test]
+    #[should_panic = "unclaimed called multiple times"]
+    fn multiple_unclaimed_concurrent() {
+        let mut root = pin!(Root::<NoReclaim<()>>::new());
+        let _unclaimed = root.as_mut().unclaimed();
         let _ = root.as_mut().unclaimed();
     }
 
-- 
GitLab