#![no_std]
#![deny(unsafe_op_in_unsafe_fn)]

use core::cell::Cell;
use core::marker::{PhantomData, PhantomPinned};
use core::mem::{self, MaybeUninit};
use core::ops::{Deref, DerefMut};
use core::pin::Pin;
use core::ptr::{self, NonNull};

/// Decides what happens to a value after the last [`Xrc`] pointing at it is
/// dropped.
///
/// The value is moved out of its [`Root`] and passed to `reclaim` together
/// with an [`Unclaimed`] handle for that root. The implementation may keep the
/// handle to reuse the root later, or drop it to return the root to its
/// unpinned state.
pub trait Reclaim
where
    Self: Sized,
{
    /// Receives the reclaimed value and the handle for its now-empty root.
    fn reclaim(self, unclaimed: Unclaimed<Self>);
}

unsafe fn reclaim_shim<T: Reclaim>(erased_ptr: NonNull<Root<Erased>>) {
    let root_ptr = erased_ptr.cast::<Root<T>>();

    // SAFETY: The caller guarantees that this value is initialized.
    let value = unsafe { ptr::addr_of!((*(*root_ptr.as_ptr()).anchor.as_ptr()).value).read() };

    // SAFETY: The caller guarantees the state is "unclaimed".
    let unclaimed = Unclaimed { root_ptr };
    value.reclaim(unclaimed);
}

/// [`Reclaim`] wrapper that simply discards the value once the last [`Xrc`]
/// to it is dropped.
pub struct NoReclaim<T: ?Sized>(pub T);

impl<T> Reclaim for NoReclaim<T> {
    fn reclaim(self, unclaimed: Unclaimed<Self>) {
        // Dropping the handle returns the root to the unpinned state; `self`
        // (and the wrapped value) is dropped afterwards at the end of the
        // function, the same order in which the parameters would drop
        // implicitly.
        drop(unclaimed);
    }
}

/// Stand-in type parameter for referring to a `Root<T>` without knowing `T`.
///
/// In this file a `Root<Erased>` is only used to read the `repr(C)` header
/// (`state`, `reclaim_shim`) or is cast back to `Root<T>` by `reclaim_shim`.
/// The layout argument in `Xrc::root_header` relies on this being a
/// zero-sized type with alignment 1.
struct Erased;

/// No handle refers to the root; this is the state of a freshly created `Root`.
const STATE_UNPINNED: usize = usize::MAX;
/// Set while an `Unclaimed` handle for the root exists (and briefly while the
/// last `Xrc` hands the value to `reclaim_shim`).
const STATE_UNCLAIMED: usize = STATE_UNPINNED - 1;
/// Largest "shared" state. A shared state stores the number of live `Xrc`s
/// minus one, so this also bounds how far `Xrc::clone` may count.
const STATE_SHARED_MAX: usize = usize::MAX >> 1;
/// "Shared" state with exactly one live `Xrc`; each clone adds one.
const STATE_SHARED_INIT: usize = 0;

// Fixed layout is required so that the header is in the same place for all T.
#[repr(C)]
pub struct Root<T> {
    state: Cell<usize>,
    reclaim_shim: unsafe fn(NonNull<Root<Erased>>),
    _pin: PhantomPinned,
    anchor: MaybeUninit<ProjectAnchor<T>>,
}

impl<T> Drop for Root<T> {
    fn drop(&mut self) {
        let state = *self.state.get_mut();
        if state != STATE_UNPINNED {
            struct Abort;
            impl Drop for Abort {
                fn drop(&mut self) {
                    panic!("abort");
                }
            }

            let _abort = Abort;
            panic!("root leak");
        }
    }
}

impl<T> Root<T> {
    pub const fn new() -> Self
    where
        T: Reclaim,
    {
        Self {
            state: Cell::new(STATE_UNPINNED),
            reclaim_shim: reclaim_shim::<T>,
            _pin: PhantomPinned,
            anchor: MaybeUninit::uninit(),
        }
    }

    pub fn unclaimed(self: Pin<&mut Self>) -> Unclaimed<T> {
        // SAFETY: We don't move out of this reference.
        let root = unsafe { self.get_unchecked_mut() };

        let state = root.state.get_mut();
        if *state != STATE_UNPINNED {
            panic!("unclaimed called multiple times");
        }
        *state = STATE_UNCLAIMED;

        let root_ptr = NonNull::from(root);
        let root_erased = root_ptr.cast::<Root<Erased>>();

        // SAFETY: This is a partial write into a `MaybeUninit` behind a mutable reference.
        unsafe {
            let anchor_root =
                ptr::addr_of_mut!((*(*root_ptr.as_ptr()).anchor.as_mut_ptr()).root_ptr);
            anchor_root.write(Cell::new(Some(root_erased)));
        }

        // SAFETY: The pointer is a pinned shared read-write borrow and we set the
        // state to "unclaimed" above.
        Unclaimed { root_ptr }
    }
}

impl<T: Reclaim> Default for Root<T> {
    fn default() -> Self {
        Self::new()
    }
}

/// A value paired with an optional pointer to the [`Root`] it belongs to.
///
/// [`Xrc`]s point at anchors rather than at bare values. [`Xrc::project`]
/// records the owning root in an anchor nested inside a root's value the first
/// time something is projected onto it.
pub struct ProjectAnchor<T: ?Sized> {
    // pinned shared read-write borrow
    root_ptr: Cell<Option<NonNull<Root<Erased>>>>,
    value: T,
}

impl<T> ProjectAnchor<T> {
    /// Wraps `value` in an anchor that is not yet associated with any root.
    pub const fn new(value: T) -> Self {
        let root_ptr = Cell::new(None);
        Self { value, root_ptr }
    }
}

impl<T: ?Sized> Deref for ProjectAnchor<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl<T: ?Sized> DerefMut for ProjectAnchor<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.value
    }
}

/// Exclusive handle to a pinned [`Root`] in the "unclaimed" state.
///
/// [`Unclaimed::claim`] stores a value and turns the handle into an [`Xrc`].
/// Dropping the handle instead puts the root back into the "unpinned" state.
#[repr(transparent)]
pub struct Unclaimed<T> {
    // pinned shared read-write borrow
    root_ptr: NonNull<Root<T>>,
}

impl<T> Drop for Unclaimed<T> {
    fn drop(&mut self) {
        // SAFETY: `root_ptr` is a pinned shared read-write borrow; we only
        // update the state field through it and never move the root.
        let root: &mut Root<T> = unsafe { self.root_ptr.as_mut() };
        let state_slot: &mut usize = root.state.get_mut();

        debug_assert_eq!(*state_slot, STATE_UNCLAIMED);
        *state_slot = STATE_UNPINNED;
    }
}

impl<T> Unclaimed<T> {
    /// Stores `value` in the root and returns the first [`Xrc`] to it.
    ///
    /// Moves the root from "unclaimed" to "shared" (exactly one `Xrc`) and
    /// consumes the handle without running its `Drop` impl, so the state is
    /// not reset to "unpinned".
    pub fn claim(mut self, value: T) -> Xrc<T> {
        // SAFETY: The pointer is a pinned shared read-write borrow and we don't
        // move out of it.
        let root = unsafe { self.root_ptr.as_mut() };

        let state = root.state.get_mut();
        debug_assert_eq!(*state, STATE_UNCLAIMED);
        *state = STATE_SHARED_INIT;

        // SAFETY: This is a partial write into a `MaybeUninit` behind a mutable reference.
        unsafe {
            let anchor_value = ptr::addr_of_mut!((*root.anchor.as_mut_ptr()).value);
            anchor_value.write(value);
        }

        // SAFETY: `root.anchor` is fully initialized, because we initialized
        // `.root_ptr` in `Root::unclaimed` and `.value` above.
        // Note that we intentionally create a shared read-only borrow here,
        // because `Xrc` must work with read-only borrows.
        let anchor = unsafe { root.anchor.assume_init_ref() };
        let anchor_ptr = NonNull::from(anchor);
        // Ownership of the "shared" state passes to the returned `Xrc`; skip
        // `Unclaimed::drop`, which would reset the state to "unpinned".
        mem::forget(self);

        // SAFETY:
        // * The pointer is a shared read-only borrow.
        // * We set the state to "shared" above.
        // * The `anchor` is derived from `root`, `root` is derived from
        //   `self.root_ptr` and `anchor.root_ptr` is `Some` and equal to
        //   `self.root_ptr`, because we set it up that way in `Root::unclaimed`.
        //   Therefore, `anchor` is derived from `anchor.root_ptr`.
        //   Because we don't do any (non-interior) mutable accesses of
        //   the root after this point, the anchor will remain valid.
        Xrc { anchor_ptr }
    }
}

/// Reference-counted, read-only pointer to a [`ProjectAnchor`] whose value
/// lives in, or was projected out of, a pinned [`Root`].
///
/// The count is kept in the root's `state`. When the last `Xrc` for a root is
/// dropped, the root's `reclaim_shim` takes the value back out.
#[repr(transparent)]
pub struct Xrc<T: ?Sized> {
    // Shared read-only borrow that outlives `anchor.root_ptr`, which must be
    // `Some`. `anchor_ptr` may be derived from `anchor.root_ptr` with an equal
    // lifetime, so only interior-mutable parts of the root may be written
    // while it exists; any other write could invalidate the anchor.
    anchor_ptr: NonNull<ProjectAnchor<T>>,
}

impl<T: ?Sized> Drop for Xrc<T> {
    fn drop(&mut self) {
        let header = Self::root_header(self);
        let count = header.state.get();
        debug_assert!(count <= STATE_SHARED_MAX);

        if count == STATE_SHARED_INIT {
            // This is the last `Xrc` for the root: hand the value back.
            header.state.set(STATE_UNCLAIMED);
            let shim = header.reclaim_shim;
            let erased_root = Self::root_ptr(self);

            // SAFETY:
            // * The count was `STATE_SHARED_INIT`, so no other `Xrc` points to
            //   this root's data.
            // * The state is now "unclaimed" and the value is still in place.
            unsafe { shim(erased_root) }
        } else {
            // Other `Xrc`s remain; only decrement the count.
            header.state.set(count - 1);
        }
    }
}

impl<T: ?Sized> Clone for Xrc<T> {
    /// Creates another `Xrc` to the same value by incrementing the root's
    /// reference count.
    ///
    /// Only the type-erased root header is touched, so this also works for
    /// unsized `T` (for example an `Xrc` projected to `dyn Trait`), matching
    /// the `T: ?Sized` bounds of `Drop`, `Deref` and `Xrc::project`.
    ///
    /// # Panics
    ///
    /// Panics with `"ref count overflow"` if the count is already at
    /// `STATE_SHARED_MAX`.
    fn clone(&self) -> Self {
        let root = Self::root_header(self);

        let state = root.state.get();
        debug_assert!(state <= STATE_SHARED_MAX);
        if state == STATE_SHARED_MAX {
            panic!("ref count overflow");
        }
        root.state.set(state + 1);

        // SAFETY: We increased the reference count above. `NonNull` is `Copy`
        // even for unsized pointees, so the (possibly fat) pointer is copied.
        Self {
            anchor_ptr: self.anchor_ptr,
        }
    }
}

impl<T: ?Sized> Deref for Xrc<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &Self::anchor(self).value
    }
}

impl<T: ?Sized> Xrc<T> {
    /// Converts `this` into an `Xrc` to a [`ProjectAnchor`] reachable from it.
    /// The root and the reference count stay the same.
    ///
    /// If the anchor returned by `f` has no root yet, it is associated with
    /// `this`'s root. `f` may also return `this`'s own anchor, for example to
    /// unsize it.
    ///
    /// # Panics
    ///
    /// Panics with `"projected from different roots"` if the anchor returned
    /// by `f` is already associated with a different root. `this` is then
    /// dropped normally during unwinding.
    ///
    /// NOTE(review): the returned `Xrc<U>` carries no lifetime that ties it to
    /// `T`. If `T` borrows the target anchor (e.g. `T` holds a
    /// `&'a ProjectAnchor<U>` that `f` reborrows), the result can outlive
    /// `'a`. Soundness then appears to rely on the root being dropped (which
    /// aborts while `Xrc`s exist) before `'a` ends, which a leaked root would
    /// not do. TODO confirm whether `project` needs an additional bound.
    pub fn project<U, F>(this: Self, f: F) -> Xrc<U>
    where
        U: ?Sized,
        F: FnOnce(&ProjectAnchor<T>) -> &ProjectAnchor<U>,
    {
        let source = Self::anchor(&this);

        let target = f(source);

        let source_root = Self::root_ptr(&this);

        // Either the target already belongs to this root, or we claim it for
        // this root now.
        if let Some(target_root) = target.root_ptr.get() {
            if target_root != source_root {
                panic!("projected from different roots");
            }
        } else {
            target.root_ptr.set(Some(source_root));
        }

        let anchor_ptr = NonNull::from(target);
        // The reference held by `this` is transferred to the new `Xrc`.
        mem::forget(this);

        // SAFETY:
        // * We ensure that `target.root_ptr == source.root_ptr` above.
        // * Because `source.root_ptr` is `Some`, `target.root_ptr` is also `Some`.
        // * The higher-ranked lifetime bound on the closure ensures that `target`
        //   outlives `source`, which outlives `source.root_ptr`.
        //   Therefore, `target` also outlives `target.root_ptr`.
        // * We forget `this` above, so the reference count is unchanged.
        Xrc { anchor_ptr }
    }

    /// Borrows the anchor this `Xrc` points to.
    fn anchor(this: &Self) -> &ProjectAnchor<T> {
        // SAFETY: The pointer is a shared read-only borrow (see the
        // `anchor_ptr` field) that stays valid while this `Xrc` exists.
        unsafe { this.anchor_ptr.as_ref() }
    }

    /// Returns a read-write pointer to the root with full provenance.
    fn root_ptr(this: &Self) -> NonNull<Root<Erased>> {
        let root_ptr = Self::anchor(this).root_ptr.get();
        // SAFETY: The root pointer must be set to `Some` before this
        // `Xrc` is constructed.
        unsafe { root_ptr.unwrap_unchecked() }
    }

    /// Returns a read-only reborrow of the erased root.
    ///
    /// Callers in this file only read the header fields (`state`,
    /// `reclaim_shim`) through it.
    fn root_header(this: &Self) -> &Root<Erased> {
        // SAFETY: We can derive a new shared read-only borrow without invalidating
        // other derived shared read-only borrows, including the anchor pointer.
        // Because `Root<T>` has a fixed layout, `T` is at the end, and `Erased`
        // is a zero-sized type with alignment 1, we can reinterpret `Root<T>`
        // as `Root<Erased>`.
        unsafe { Self::root_ptr(this).as_ref() }
    }
}

/// Unit tests. Tests marked `#[ignore = "will abort"]` deliberately trigger the
/// abort in `Root::drop`, which takes down the whole test process when run.
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use std::prelude::rust_2021::*;

    use std::cell::RefCell;
    use std::fmt::Display;
    use std::panic::{self, AssertUnwindSafe};
    use std::pin::pin;
    use std::rc::Rc;

    /// Turns a pinned slice into pinned references to each of its elements.
    fn iter_pinned_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<&mut T>> {
        // SAFETY: The elements of a slice are structurally pinned.
        unsafe {
            slice
                .get_unchecked_mut()
                .iter_mut()
                .map(|elem| Pin::new_unchecked(elem))
        }
    }

    /// Claiming a fresh root yields an `Xrc` that derefs to the stored value.
    #[test]
    fn basic() {
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(42));
        assert_eq!(xrc.0, 42);
    }

    /// A rejected second `unclaimed` call must leave existing `Xrc`s usable.
    #[test]
    fn unpin_hack() {
        let mut root = pin!(Root::new());
        let xrc = root.as_mut().unclaimed().claim(NoReclaim(42));

        let res = panic::catch_unwind(AssertUnwindSafe(|| {
            // This must not create a unique borrow ...
            let _ = root.unclaimed();
        }));
        assert!(res.is_err());

        // ... to not invalidate this shared read-only borrow.
        assert_eq!(xrc.0, 42);
    }

    /// Dropping a root while its `Unclaimed` handle is alive aborts.
    #[test]
    #[ignore = "will abort"]
    fn abort_unclaimed() {
        // Declared before `root`, so it is dropped after `root`.
        #[allow(clippy::needless_late_init)]
        let _unclaimed;
        let root = pin!(Root::<NoReclaim<()>>::new());
        _unclaimed = root.unclaimed();
    }

    /// Dropping a root while an `Xrc` to it is alive aborts.
    #[test]
    #[ignore = "will abort"]
    fn abort_xrc() {
        // Declared before `root`, so it is dropped after `root`.
        #[allow(clippy::needless_late_init)]
        let _xrc;
        let root = pin!(Root::new());
        _xrc = root.unclaimed().claim(NoReclaim(()));
    }

    /// Dropping an `Unclaimed` returns the root to "unpinned", so it can be
    /// unclaimed again.
    #[test]
    fn multiple_unclaimed_separate() {
        let mut root = pin!(Root::<NoReclaim<()>>::new());
        let _ = root.as_mut().unclaimed();
        let _ = root.as_mut().unclaimed();
    }

    /// Two simultaneous `Unclaimed` handles for one root are rejected.
    #[test]
    #[should_panic = "unclaimed called multiple times"]
    fn multiple_unclaimed_concurrent() {
        let mut root = pin!(Root::<NoReclaim<()>>::new());
        let _unclaimed = root.as_mut().unclaimed();
        let _ = root.as_mut().unclaimed();
    }

    /// The value is reclaimed (and then dropped) exactly once, only after the
    /// last clone is gone.
    #[test]
    fn clone_reclaim_drop() {
        struct Check {
            reclaim_count: Rc<Cell<usize>>,
            drop_count: Rc<Cell<usize>>,
        }

        impl Reclaim for Check {
            fn reclaim(self, _unclaimed: Unclaimed<Self>) {
                self.reclaim_count.set(self.reclaim_count.get() + 1);
            }
        }

        impl Drop for Check {
            fn drop(&mut self) {
                self.drop_count.set(self.drop_count.get() + 1);
            }
        }

        let drop_count = Rc::new(Cell::new(0));
        let reclaim_count = Rc::new(Cell::new(0));
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(Check {
            drop_count: drop_count.clone(),
            reclaim_count: reclaim_count.clone(),
        });
        assert_eq!(xrc.drop_count.get(), 0);
        assert_eq!(xrc.reclaim_count.get(), 0);
        let xrc2 = xrc.clone();
        assert_eq!(xrc.drop_count.get(), 0);
        assert_eq!(xrc.reclaim_count.get(), 0);
        drop(xrc);
        assert_eq!(xrc2.drop_count.get(), 0);
        assert_eq!(xrc2.reclaim_count.get(), 0);
        drop(xrc2);
        assert_eq!(drop_count.get(), 1);
        assert_eq!(reclaim_count.get(), 1);
    }

    /// Projecting onto a nested anchor leaves the original `Xrc` intact.
    #[test]
    fn project() {
        struct Value {
            a: i32,
            b: ProjectAnchor<i32>,
        }

        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(Value {
            a: 1,
            b: ProjectAnchor::new(2),
        }));
        assert_eq!(xrc.0.a, 1);
        assert_eq!(*xrc.0.b, 2);
        let xrc2 = Xrc::project(xrc.clone(), |xrc| &xrc.0.b);
        assert_eq!(xrc.0.a, 1);
        assert_eq!(*xrc.0.b, 2);
        assert_eq!(*xrc2, 2);
    }

    /// Projecting onto the same anchor can unsize the pointee.
    #[test]
    fn project_unsized() {
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(42));
        let xrc = Xrc::project::<NoReclaim<dyn Display>, _>(xrc, |xrc| xrc);
        assert_eq!(std::format!("{}", &xrc.0), "42");
    }

    /// One anchor shared by two roots: the second projection must panic.
    #[test]
    #[should_panic = "projected from different roots"]
    fn project_multiple() {
        let anchor = Rc::new(ProjectAnchor::new(0));

        let root1 = pin!(Root::new());
        let xrc1 = root1.unclaimed().claim(NoReclaim(anchor.clone()));

        let root2 = pin!(Root::new());
        let xrc2 = root2.unclaimed().claim(NoReclaim(anchor));

        let _ = Xrc::project(xrc1, |xrc| &*xrc.0);
        let _ = Xrc::project(xrc2, |xrc| &*xrc.0);
    }

    /// Uses `Reclaim` to build a fixed-capacity allocator. Reclaimed roots are
    /// pushed back onto a shared free list and reused.
    #[test]
    fn gc_alloc() {
        struct Gc<T> {
            free: Rc<RefCell<Vec<Unclaimed<Gc<T>>>>>,
            value: T,
        }

        impl<T> Deref for Gc<T> {
            type Target = T;

            fn deref(&self) -> &T {
                &self.value
            }
        }

        impl<T> Reclaim for Gc<T> {
            fn reclaim(self, unclaimed: Unclaimed<Self>) {
                self.free.borrow_mut().push(unclaimed)
            }
        }

        struct Alloc<T> {
            free: Rc<RefCell<Vec<Unclaimed<Gc<T>>>>>,
        }

        impl<T> Alloc<T> {
            fn try_alloc(&self, value: T) -> Result<Xrc<Gc<T>>, T> {
                let unclaimed = self.free.borrow_mut().pop();
                // drop RefMut
                if let Some(unclaimed) = unclaimed {
                    Ok(unclaimed.claim(Gc {
                        free: self.free.clone(),
                        value,
                    }))
                } else {
                    Err(value)
                }
            }
        }

        struct Value {
            number: ProjectAnchor<i32>,
            string: ProjectAnchor<String>,
        }

        #[allow(clippy::declare_interior_mutable_const)]
        const INIT: Root<Gc<Value>> = Root::new();
        let mut storage = pin!([INIT; 8]);
        let free = iter_pinned_mut(storage.as_mut())
            .map(|root| root.unclaimed())
            .collect::<Vec<_>>();
        let free = Rc::new(RefCell::new(free));
        let alloc = Alloc { free };

        // Fill all 8 slots.
        let mut values = Vec::new();
        for n in 0..8 {
            let value = Value {
                number: ProjectAnchor::new(n),
                string: ProjectAnchor::new(n.to_string()),
            };
            let xrc = alloc.try_alloc(value).ok().unwrap();
            values.push(xrc);
        }

        // Each slot now has three `Xrc`s: the value plus two projections.
        let mut value_nums = values
            .iter()
            .map(|xrc| Xrc::project(xrc.clone(), |value| &value.number))
            .collect::<Vec<_>>();

        let mut value_strs = values
            .iter()
            .map(|xrc| Xrc::project(xrc.clone(), |value| &value.string))
            .collect::<Vec<_>>();

        assert!(value_nums.iter().map(|xrc| **xrc).eq(0..8));
        assert!(value_strs
            .iter()
            .map(|xrc| xrc.parse::<i32>().unwrap())
            .eq(0..8));

        let value = Value {
            number: ProjectAnchor::new(0),
            string: ProjectAnchor::new(String::new()),
        };

        // Allocation keeps failing until every `Xrc` to some slot (including
        // the projections) has been dropped.
        let value = alloc.try_alloc(value).err().unwrap();
        values.clear();
        let value = alloc.try_alloc(value).err().unwrap();
        value_nums.clear();
        let value = alloc.try_alloc(value).err().unwrap();
        value_strs.clear();
        let _xrc = alloc.try_alloc(value).ok().unwrap();
    }
}