#![no_std]
//! Reference-counted handles ([`Xrc`]) to values stored in pinned,
//! caller-provided [`Root`] slots, without heap allocation.
//!
//! A pinned [`Root`] hands out a single [`Unclaimed`] handle, which stores a
//! value via [`Unclaimed::claim`] and returns the first [`Xrc`].
//! [`Xrc::project`] narrows a handle to a [`ProjectTarget`] reachable from the
//! value while sharing the root's reference count. When the last [`Xrc`] for a
//! root is dropped, the value is moved out and passed to [`Reclaim::reclaim`]
//! together with a new [`Unclaimed`] handle for the same root.
//!
//! Handles hold raw pointers into their root, so dropping a [`Root`] while any
//! [`Xrc`] or [`Unclaimed`] still refers to it aborts the process.

use core::cell::Cell;
use core::marker::{PhantomData, PhantomPinned};
use core::mem::{self, MaybeUninit};
use core::ops::{Deref, DerefMut};
use core::pin::Pin;
use core::ptr::{self, NonNull};

/// Types that are handed back when the last [`Xrc`] referring to them drops.
pub trait Reclaim: Sized {
    /// Receives the value moved out of its root, plus an [`Unclaimed`] handle
    /// to that now-empty root which may be used to claim it again.
    fn reclaim(self, unclaimed: Unclaimed<Self>);
}

/// Type-erased hook stored in every [`Root`] so that an `Xrc<U>` of any
/// (possibly projected) type can hand the root's value to `T::reclaim`.
///
/// # Safety
///
/// `erased_ptr` must point to a live `Root<T>` for this exact `T` whose value
/// is initialized and whose last [`Xrc`] is being dropped (state
/// `STATE_SHARED_INIT`).
unsafe fn reclaim_shim<T: Reclaim>(erased_ptr: NonNull<Root<Erased>>) {
    let ptr = erased_ptr.cast::<Root<T>>();
    // SAFETY: the caller guarantees `ptr` points to a live `Root<T>`; `state`
    // is a `Cell`, so it may be updated through a shared access.
    unsafe {
        debug_assert_eq!((*ptr.as_ptr()).state.get(), STATE_SHARED_INIT);
        (*ptr.as_ptr()).state.set(STATE_UNCLAIMED);
    }
    // SAFETY: the value was initialized by `Unclaimed::claim` and no `Xrc` is
    // left to observe it. The root is now `STATE_UNCLAIMED`, so the slot is
    // treated as uninitialized until it is claimed again.
    let value = unsafe { (*ptr.as_ptr()).target.assume_init_read().value };
    value.reclaim(Unclaimed::from_root(ptr));
}

/// Wrapper whose [`Reclaim`] implementation drops both the value and the
/// [`Unclaimed`] handle, which detaches the root: it can then be dropped
/// safely but never claimed again.
pub struct NoReclaim<T: ?Sized>(pub T);

impl<T> Reclaim for NoReclaim<T> {
    fn reclaim(self, _unclaimed: Unclaimed<Self>) {}
}

/// Placeholder value type used to refer to a `Root<T>` without naming `T`.
struct Erased;

// Root states. `0..=STATE_SHARED_MAX` means the value is initialized and
// shared by `state + 1` `Xrc`s; the remaining values are listed below.

/// Created by `Root::new`; `Root::unclaimed` has not been called yet.
const STATE_UNPINNED: usize = usize::MAX;
/// Value uninitialized; exactly one `Unclaimed` handle refers to the root.
const STATE_UNCLAIMED: usize = usize::MAX - 1;
/// Value uninitialized; the `Unclaimed` handle was dropped, so nothing refers
/// to the root any more and it can never be claimed again.
const STATE_DETACHED: usize = usize::MAX - 2;
/// Highest shared state, i.e. `STATE_SHARED_MAX + 1` live `Xrc`s.
const STATE_SHARED_MAX: usize = isize::MAX as usize;
/// Exactly one live `Xrc`.
const STATE_SHARED_INIT: usize = 0;

/// Pinned storage slot for a single value of type `T`.
///
/// A root starts out empty; pinning it and calling [`Root::unclaimed`] yields
/// the one [`Unclaimed`] handle through which it is filled.
///
/// # Aborts
///
/// Dropping a root aborts the process while any [`Xrc`] or [`Unclaimed`]
/// handle still refers to it, because those handles hold raw pointers into
/// the root.
#[repr(C)]
pub struct Root<T> {
    // `repr(C)` keeps `state` and `reclaim_shim` at the same offsets for every
    // `T`, so they can be accessed through a type-erased `Root<Erased>`.
    state: Cell<usize>,
    reclaim_shim: unsafe fn(NonNull<Root<Erased>>),
    _pin: PhantomPinned,
    target: MaybeUninit<ProjectTarget<T>>,
}

impl<T> Drop for Root<T> {
    fn drop(&mut self) {
        let state = self.state.get();
        // Outstanding `Xrc`s (shared states) or an outstanding `Unclaimed`
        // handle would dangle once this storage goes away.
        if state <= STATE_SHARED_MAX || state == STATE_UNCLAIMED {
            // Panicking from `Abort::drop` while already unwinding from
            // "root leak" turns the panic into an abort, so the dangling
            // handles can never be used afterwards.
            struct Abort;
            impl Drop for Abort {
                fn drop(&mut self) {
                    panic!("abort");
                }
            }
            let _abort = Abort;
            panic!("root leak");
        }
    }
}

impl<T> Root<T> {
    /// Creates an empty root that has not been pinned or claimed yet.
    pub const fn new() -> Self
    where
        T: Reclaim,
    {
        Self {
            state: Cell::new(STATE_UNPINNED),
            reclaim_shim: reclaim_shim::<T>,
            _pin: PhantomPinned,
            target: MaybeUninit::uninit(),
        }
    }

    /// Returns the single [`Unclaimed`] handle for this pinned root.
    ///
    /// # Panics
    ///
    /// Panics if called more than once on the same root.
    pub fn unclaimed(self: Pin<&mut Self>) -> Unclaimed<T> {
        // Checked before the `&mut` below is created, so a rejected call does
        // not invalidate shared borrows held by existing `Xrc`s.
        if self.state.get() != STATE_UNPINNED {
            panic!("unclaimed called multiple times");
        }
        self.state.set(STATE_UNCLAIMED);
        // SAFETY: the resulting pointer is only ever used to access the root
        // in place; nothing moves it out of its pinned location.
        let ptr = unsafe { NonNull::from(self.get_unchecked_mut()) };
        // SAFETY: `ptr` is valid for writes. Only the `root` field of the
        // still-uninitialized target is written, through a raw place, so no
        // reference to an uninitialized `ProjectTarget` is created.
        unsafe {
            let target_root = ptr::addr_of_mut!((*(*ptr.as_ptr()).target.as_mut_ptr()).root);
            target_root.write(Cell::new(Some(ptr.cast::<Root<Erased>>())));
        }
        Unclaimed::from_root(ptr)
    }
}

impl<T: Reclaim> Default for Root<T> {
    fn default() -> Self {
        Self::new()
    }
}

/// A value that an [`Xrc`] can point at: either a root's whole value, or a
/// `ProjectTarget` reachable from it via [`Xrc::project`].
pub struct ProjectTarget<T: ?Sized> {
    /// Root whose reference count governs this target. Set by
    /// `Root::unclaimed` for a root's own value, or by the first
    /// `Xrc::project` that reaches this target; `None` before that.
    root: Cell<Option<NonNull<Root<Erased>>>>,
    value: T,
}

impl<T> ProjectTarget<T> {
    /// Wraps `value`; the target is not associated with any root yet.
    pub const fn new(value: T) -> Self {
        Self {
            root: Cell::new(None),
            value,
        }
    }
}

impl<T: ?Sized> Deref for ProjectTarget<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.value
    }
}

impl<T: ?Sized> DerefMut for ProjectTarget<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.value
    }
}

/// Owns the right to claim a root that is in `STATE_UNCLAIMED`.
///
/// Dropping it detaches the root. It is deliberately not generic, so its
/// destructor adds no drop-check requirements on the root's value type.
#[repr(transparent)]
struct UnclaimedHandle(NonNull<Root<Erased>>);

impl Drop for UnclaimedHandle {
    fn drop(&mut self) {
        // SAFETY: `Root::drop` aborts while a handle is outstanding, so the
        // root is still live; `state` has the same offset for every value
        // type thanks to `repr(C)`.
        let state = unsafe { &(*self.0.as_ptr()).state };
        debug_assert_eq!(state.get(), STATE_UNCLAIMED);
        state.set(STATE_DETACHED);
    }
}

/// Exclusive handle to an empty, pinned [`Root`], used to store a value in it.
///
/// Dropping the handle without claiming detaches the root: it can then be
/// dropped, but never claimed again.
#[repr(transparent)]
pub struct Unclaimed<T> {
    handle: UnclaimedHandle,
    _root: PhantomData<NonNull<Root<T>>>,
}

impl<T> Unclaimed<T> {
    /// Wraps a pointer to a root that has just entered `STATE_UNCLAIMED`.
    fn from_root(ptr: NonNull<Root<T>>) -> Self {
        Self {
            handle: UnclaimedHandle(ptr.cast::<Root<Erased>>()),
            _root: PhantomData,
        }
    }

    /// Moves `value` into the root and returns the first [`Xrc`] to it.
    pub fn claim(self, value: T) -> Xrc<T> {
        let ptr = self.handle.0.cast::<Root<T>>();
        // Ownership of the root moves to the returned `Xrc`, so the handle's
        // destructor (which would detach the root) must not run.
        mem::forget(self);
        // SAFETY: an `Unclaimed` only exists while its root is live and in
        // `STATE_UNCLAIMED`, so the value slot is uninitialized and may be
        // written through a raw place.
        unsafe {
            debug_assert_eq!((*ptr.as_ptr()).state.get(), STATE_UNCLAIMED);
            (*ptr.as_ptr()).state.set(STATE_SHARED_INIT);
            let target_value = ptr::addr_of_mut!((*(*ptr.as_ptr()).target.as_mut_ptr()).value);
            target_value.write(value);
        }
        // SAFETY: `value` was just written and `root` was written by
        // `Root::unclaimed`, so the target is initialized. Only a shared
        // read-only reborrow is created here.
        let target = unsafe { (*ptr.as_ptr()).target.assume_init_ref() };
        Xrc {
            ptr: NonNull::from(target),
        }
    }
}

/// Shared, reference-counted pointer to a [`ProjectTarget`] inside a [`Root`].
///
/// All `Xrc`s referring to the same root, including projected ones, share one
/// count stored in the root. When the last one is dropped, the root's value is
/// moved out and passed to [`Reclaim::reclaim`].
#[repr(transparent)]
pub struct Xrc<T: ?Sized> {
    ptr: NonNull<ProjectTarget<T>>,
}

impl<T: ?Sized> Drop for Xrc<T> {
    fn drop(&mut self) {
        let Some(root) = Self::inner(self).root.get() else {
            unreachable!("missing root");
        };
        // SAFETY: dropping a root aborts while any `Xrc` refers to it, so the
        // root is live; `state` has the same offset for every value type.
        let state_ref = unsafe { &(*root.as_ptr()).state };
        let state = state_ref.get();
        debug_assert!(state <= STATE_SHARED_MAX);
        if state != STATE_SHARED_INIT {
            state_ref.set(state - 1);
            return;
        }
        // SAFETY: this is the last `Xrc` for the root, and the shim stored in
        // it was instantiated for the root's real value type by `Root::new`.
        let reclaim_shim = unsafe { (*root.as_ptr()).reclaim_shim };
        unsafe { reclaim_shim(root) }
    }
}

impl<T: ?Sized> Clone for Xrc<T> {
    /// Returns another handle to the same target, bumping the root's count.
    ///
    /// # Panics
    ///
    /// Panics if the root's reference count would overflow.
    fn clone(&self) -> Self {
        let Some(root) = Self::inner(self).root.get() else {
            unreachable!("missing root");
        };
        // SAFETY: `self` keeps the root alive (dropping it earlier aborts).
        let state_ref = unsafe { &(*root.as_ptr()).state };
        let state = state_ref.get();
        debug_assert!(state <= STATE_SHARED_MAX);
        if state == STATE_SHARED_MAX {
            panic!("ref count overflow");
        }
        state_ref.set(state + 1);
        Self { ptr: self.ptr }
    }
}

impl<T: ?Sized> Deref for Xrc<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &Self::inner(self).value
    }
}

impl<T: ?Sized> Xrc<T> {
    /// Converts `this` into an `Xrc` to a [`ProjectTarget`] selected by `f`
    /// (for example a field of the pointed-to value). The new handle takes
    /// over `this`'s share of the root's reference count.
    ///
    /// # Panics
    ///
    /// Panics if the selected target is already associated with a different
    /// root.
    pub fn project<U, F>(this: Self, f: F) -> Xrc<U>
    where
        U: ?Sized,
        F: FnOnce(&ProjectTarget<T>) -> &ProjectTarget<U>,
    {
        let source = Self::inner(&this);
        let Some(source_root) = source.root.get() else {
            unreachable!("missing root");
        };
        let target = f(source);
        match target.root.get() {
            // First projection onto this target: tie it to our root.
            None => target.root.set(Some(source_root)),
            Some(target_root) if target_root != source_root => {
                panic!("projected from different roots");
            }
            Some(_) => {}
        }
        let ptr = NonNull::from(target);
        // The reference held by `this` is transferred to the returned handle.
        mem::forget(this);
        Xrc { ptr }
    }

    /// Borrows the target this handle points at.
    fn inner(this: &Self) -> &ProjectTarget<T> {
        // SAFETY: the target is kept alive by the root count `this` holds, and
        // only shared references to it are ever created.
        unsafe { this.ptr.as_ref() }
    }
}

#[cfg(test)]
mod tests {
    extern crate std;

    use super::*;
    use std::prelude::rust_2021::*;

    use std::cell::RefCell;
    use std::fmt::Display;
    use std::panic::{self, AssertUnwindSafe};
    use std::pin::pin;
    use std::rc::Rc;

    /// Pin-projects a pinned slice onto its individual elements.
    fn iter_pinned_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<&mut T>> {
        // SAFETY: the elements are never moved; each one is re-pinned in place.
        unsafe {
            slice
                .get_unchecked_mut()
                .iter_mut()
                .map(|elem| Pin::new_unchecked(elem))
        }
    }

    #[test]
    fn basic() {
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(42));
        assert_eq!(xrc.0, 42);
    }

    #[test]
    fn unpin_hack() {
        let mut root = pin!(Root::new());
        let xrc = root.as_mut().unclaimed().claim(NoReclaim(42));
        let res = panic::catch_unwind(AssertUnwindSafe(|| {
            // This must not create a unique borrow ...
            let _ = root.unclaimed();
        }));
        assert!(res.is_err());
        // ... so that it does not invalidate this shared read-only borrow.
        assert_eq!(xrc.0, 42);
    }

    #[test]
    #[ignore = "will abort"]
    fn abort() {
        // Late init so `_xrc` outlives the root it points into.
        #[allow(clippy::needless_late_init)]
        let _xrc;
        let root = pin!(Root::new());
        _xrc = root.unclaimed().claim(NoReclaim(()));
    }

    #[test]
    #[ignore = "will abort"]
    fn abort_unclaimed() {
        // Late init so `_unclaimed` outlives the root it points into.
        #[allow(clippy::needless_late_init)]
        let _unclaimed;
        let root = pin!(Root::<NoReclaim<()>>::new());
        _unclaimed = root.unclaimed();
    }

    #[test]
    #[should_panic = "unclaimed called multiple times"]
    fn multiple_unclaimed() {
        let mut root = pin!(Root::<NoReclaim<i32>>::new());
        let _ = root.as_mut().unclaimed();
        let _ = root.as_mut().unclaimed();
    }

    #[test]
    fn dropped_unclaimed_detaches() {
        let mut root = pin!(Root::<NoReclaim<i32>>::new());
        drop(root.as_mut().unclaimed());
        // A detached root hands out no further handles ...
        let res = panic::catch_unwind(AssertUnwindSafe(|| {
            let _ = root.as_mut().unclaimed();
        }));
        assert!(res.is_err());
        // ... and is dropped at the end of this test without aborting.
    }

    #[test]
    fn clone_reclaim_drop() {
        struct Check {
            reclaim_count: Rc<Cell<usize>>,
            drop_count: Rc<Cell<usize>>,
        }

        impl Reclaim for Check {
            fn reclaim(self, _unclaimed: Unclaimed<Self>) {
                self.reclaim_count.set(self.reclaim_count.get() + 1);
            }
        }

        impl Drop for Check {
            fn drop(&mut self) {
                self.drop_count.set(self.drop_count.get() + 1);
            }
        }

        let drop_count = Rc::new(Cell::new(0));
        let reclaim_count = Rc::new(Cell::new(0));
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(Check {
            drop_count: drop_count.clone(),
            reclaim_count: reclaim_count.clone(),
        });
        assert_eq!(xrc.drop_count.get(), 0);
        assert_eq!(xrc.reclaim_count.get(), 0);
        let xrc2 = xrc.clone();
        assert_eq!(xrc.drop_count.get(), 0);
        assert_eq!(xrc.reclaim_count.get(), 0);
        drop(xrc);
        assert_eq!(xrc2.drop_count.get(), 0);
        assert_eq!(xrc2.reclaim_count.get(), 0);
        drop(xrc2);
        assert_eq!(drop_count.get(), 1);
        assert_eq!(reclaim_count.get(), 1);
    }

    #[test]
    fn project() {
        struct Value {
            a: i32,
            b: ProjectTarget<i32>,
        }

        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(Value {
            a: 1,
            b: ProjectTarget::new(2),
        }));
        assert_eq!(xrc.0.a, 1);
        assert_eq!(*xrc.0.b, 2);
        let xrc2 = Xrc::project(xrc.clone(), |xrc| &xrc.0.b);
        assert_eq!(xrc.0.a, 1);
        assert_eq!(*xrc.0.b, 2);
        assert_eq!(*xrc2, 2);
    }

    #[test]
    fn project_unsized() {
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(42));
        let xrc = Xrc::project::<NoReclaim<dyn Display>, _>(xrc, |xrc| xrc);
        assert_eq!(std::format!("{}", &xrc.0), "42");
    }

    #[test]
    fn clone_unsized() {
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(42));
        let xrc = Xrc::project::<NoReclaim<dyn Display>, _>(xrc, |xrc| xrc);
        let xrc2 = xrc.clone();
        drop(xrc);
        assert_eq!(std::format!("{}", &xrc2.0), "42");
    }

    #[test]
    #[should_panic = "projected from different roots"]
    fn project_multiple() {
        let target = Rc::new(ProjectTarget::new(0));
        let root1 = pin!(Root::new());
        let xrc1 = root1.unclaimed().claim(NoReclaim(target.clone()));
        let root2 = pin!(Root::new());
        let xrc2 = root2.unclaimed().claim(NoReclaim(target));
        let _ = Xrc::project(xrc1, |xrc| &*xrc.0);
        let _ = Xrc::project(xrc2, |xrc| &*xrc.0);
    }

    #[test]
    fn gc_alloc() {
        struct Gc<T> {
            free: Rc<RefCell<Vec<Unclaimed<Gc<T>>>>>,
            value: T,
        }

        impl<T> Deref for Gc<T> {
            type Target = T;

            fn deref(&self) -> &T {
                &self.value
            }
        }

        impl<T> Reclaim for Gc<T> {
            fn reclaim(self, unclaimed: Unclaimed<Self>) {
                self.free.borrow_mut().push(unclaimed)
            }
        }

        struct Alloc<T> {
            free: Rc<RefCell<Vec<Unclaimed<Gc<T>>>>>,
        }

        impl<T> Alloc<T> {
            fn try_alloc(&self, value: T) -> Result<Xrc<Gc<T>>, T> {
                // Bound separately so the `RefMut` is released before `claim`.
                let unclaimed = self.free.borrow_mut().pop();
                if let Some(unclaimed) = unclaimed {
                    Ok(unclaimed.claim(Gc {
                        free: self.free.clone(),
                        value,
                    }))
                } else {
                    Err(value)
                }
            }
        }

        struct Value {
            number: ProjectTarget<i32>,
            string: ProjectTarget<String>,
        }

        // Only used as an array-repeat initializer; each element is a fresh
        // copy of the const, so the interior `Cell`s are never shared.
        #[allow(clippy::declare_interior_mutable_const)]
        const INIT: Root<Gc<Value>> = Root::new();
        let mut storage = pin!([INIT; 8]);
        let free = iter_pinned_mut(storage.as_mut())
            .map(|root| root.unclaimed())
            .collect::<Vec<_>>();
        let free = Rc::new(RefCell::new(free));
        let alloc = Alloc { free };

        let mut values = Vec::new();
        for n in 0..8 {
            let value = Value {
                number: ProjectTarget::new(n),
                string: ProjectTarget::new(n.to_string()),
            };
            let xrc = alloc.try_alloc(value).ok().unwrap();
            values.push(xrc);
        }

        let mut value_nums = values
            .iter()
            .map(|xrc| Xrc::project(xrc.clone(), |value| &value.number))
            .collect::<Vec<_>>();
        let mut value_strs = values
            .iter()
            .map(|xrc| Xrc::project(xrc.clone(), |value| &value.string))
            .collect::<Vec<_>>();

        assert!(value_nums.iter().map(|xrc| **xrc).eq(0..8));
        assert!(value_strs
            .iter()
            .map(|xrc| xrc.parse::<i32>().unwrap())
            .eq(0..8));

        let value = Value {
            number: ProjectTarget::new(0),
            string: ProjectTarget::new(String::new()),
        };
        let value = alloc.try_alloc(value).err().unwrap();
        values.clear();
        let value = alloc.try_alloc(value).err().unwrap();
        value_nums.clear();
        let value = alloc.try_alloc(value).err().unwrap();
        value_strs.clear();
        let _xrc = alloc.try_alloc(value).ok().unwrap();
    }
}