#![no_std]

use core::cell::Cell;
use core::marker::PhantomPinned;
use core::mem::{self, MaybeUninit};
use core::ops::{Deref, DerefMut};
use core::pin::Pin;
use core::ptr::{self, NonNull};

/// A value that can live in a [`Root`] and is notified when the last [`Xrc`]
/// referring to it is dropped.
///
/// The value is moved out of the root's storage and passed in by value,
/// together with an [`Unclaimed`] handle to the now-empty root. The handle may
/// be kept for reuse (as the `gc_alloc` test does) or discarded.
pub trait Reclaim: Sized {
    /// Called once the shared count of the owning root reaches zero.
    fn reclaim(self, slot: Unclaimed<Self>);
}

/// Type-erased trampoline stored in a `Root`'s `reclaim_shim` field.
///
/// `Xrc::drop` calls it when the last handle to a root goes away. It marks the
/// root as unclaimed, moves the value out and hands both to
/// [`Reclaim::reclaim`].
///
/// # Safety
///
/// `erased_ptr` must point to a live `Root<T>` (for exactly this `T`). That
/// root's target must be initialized and its state must be
/// `STATE_SHARED_INIT`.
unsafe fn reclaim_shim<T: Reclaim>(erased_ptr: NonNull<Root<Erased>>) {
    let typed = erased_ptr.cast::<Root<T>>();
    let root = typed.as_ptr();

    // SAFETY: per the contract above, `root` is a live `Root<T>` in the shared
    // state with an initialized target. After the read, the target's value is
    // logically uninitialized again, which the state change records.
    let value = unsafe {
        debug_assert_eq!((*root).state.get(), STATE_SHARED_INIT);
        (*root).state.set(STATE_UNCLAIMED);
        (*root).target.assume_init_read().value
    };

    value.reclaim(Unclaimed { ptr: typed });
}

/// Wrapper that opts a value out of reclamation.
///
/// When the last [`Xrc`] goes away, the inner value is simply dropped. The
/// `Unclaimed` handle for the root is discarded, so the root is not handed out
/// again.
pub struct NoReclaim<T: ?Sized>(pub T);

impl<T> Reclaim for NoReclaim<T> {
    fn reclaim(self, _unclaimed: Unclaimed<Self>) {
        // Consuming `self` drops the wrapped value; the slot handle is ignored.
        drop(self);
    }
}

/// Placeholder type parameter used to refer to a `Root<T>` without knowing `T`.
///
/// Only the `#[repr(C)]` prefix of `Root` (`state` and `reclaim_shim`) is read
/// through a `Root<Erased>` pointer.
struct Erased;

// Encoding of `Root::state`:
// - `STATE_UNPINNED`: freshly constructed; `Root::unclaimed` not called yet.
// - `STATE_UNCLAIMED`: handed out, but no value is currently stored.
// - `0..=STATE_SHARED_MAX`: a value is stored and `state + 1` `Xrc`s are alive.

/// Root has not been handed out via `Root::unclaimed` yet.
const STATE_UNPINNED: usize = usize::MAX;
/// Root has been handed out but currently stores no value.
const STATE_UNCLAIMED: usize = STATE_UNPINNED - 1;
/// Largest shared state; everything above it is a sentinel state.
const STATE_SHARED_MAX: usize = usize::MAX >> 1;
/// Shared state with exactly one live `Xrc`.
const STATE_SHARED_INIT: usize = 0;

/// Pinned storage slot for a single value that is shared through [`Xrc`]
/// handles.
///
/// Lifecycle:
/// 1. A `Root` starts out empty.
/// 2. Once pinned, [`Root::unclaimed`] yields an [`Unclaimed`] handle.
/// 3. [`Unclaimed::claim`] stores a value and returns the first `Xrc`.
/// 4. When the last `Xrc` is dropped, the value is moved out again and passed
///    to [`Reclaim::reclaim`].
///
/// `#[repr(C)]` keeps `state` and `reclaim_shim` at the same offsets for every
/// `T`, which is what allows them to be read through a type-erased
/// `Root<Erased>` pointer. Do not reorder the fields.
#[repr(C)]
pub struct Root<T> {
    /// Lifecycle / shared-count encoding; see the `STATE_*` constants.
    state: Cell<usize>,
    /// `reclaim_shim::<T>`, called when the shared count reaches zero.
    reclaim_shim: unsafe fn(NonNull<Root<Erased>>),
    /// Makes the root `!Unpin`: handles point into it, so it must not move.
    _pin: PhantomPinned,
    /// Target storage. Its `root` field is written by `unclaimed` and its
    /// `value` by `claim`; the value is only live in the shared states.
    target: MaybeUninit<ProjectTarget<T>>,
}

impl<T> Drop for Root<T> {
    /// Refuses to let the root go away while any `Xrc` still refers to it.
    ///
    /// `Xrc`s hold raw pointers into the root, so they would dangle. A single
    /// panic could be caught by the caller, so a second panic is raised while
    /// unwinding, which is meant to abort the process.
    fn drop(&mut self) {
        if self.state.get() > STATE_SHARED_MAX {
            // Unpinned or unclaimed: no handle can point at the value.
            return;
        }

        // Panicking in this guard's destructor during the unwind started by
        // `panic!("root leak")` escalates to an abort.
        struct AbortOnUnwind;
        impl Drop for AbortOnUnwind {
            fn drop(&mut self) {
                panic!("abort");
            }
        }

        let _guard = AbortOnUnwind;
        panic!("root leak");
    }
}

impl<T> Root<T> {
    /// Creates an empty root that has not been handed out yet.
    ///
    /// `T: Reclaim` is required here because the matching `reclaim_shim::<T>`
    /// is captured at construction. The function is `const`, so it can be used
    /// to initialize arrays of roots (see the `gc_alloc` test).
    pub const fn new() -> Self
    where
        T: Reclaim,
    {
        Self {
            state: Cell::new(STATE_UNPINNED),
            reclaim_shim: reclaim_shim::<T>,
            _pin: PhantomPinned,
            target: MaybeUninit::uninit(),
        }
    }

    /// Turns a pinned, fresh root into an [`Unclaimed`] handle.
    ///
    /// Also stores the root's own (type-erased) address in the target's `root`
    /// field; `Xrc` uses it to find the shared count.
    ///
    /// # Panics
    ///
    /// Panics with "unclaimed called multiple times" if this root has already
    /// been handed out.
    pub fn unclaimed(self: Pin<&mut Self>) -> Unclaimed<T> {
        // The check deliberately goes through `Pin`'s shared `Deref`. If the
        // root is already in use, live `Xrc`s hold shared reborrows into it and
        // must not be invalidated by creating a `&mut` (see the `unpin_hack`
        // test). A unique borrow is only taken after this check passes.
        if self.state.get() != STATE_UNPINNED {
            panic!("unclaimed called multiple times");
        }
        self.state.set(STATE_UNCLAIMED);

        // SAFETY: the pointer is only used to write fields in place; the root
        // is never moved out of its pinned location.
        let ptr = unsafe { NonNull::from(self.get_unchecked_mut()) };
        let ptr_erased = ptr.cast::<Root<Erased>>();

        // SAFETY: `ptr` is valid for writes. `addr_of_mut!` initializes only the
        // `root` field of the still-uninitialized target, without forming a
        // reference to the uninitialized `value`.
        unsafe {
            let target_root = ptr::addr_of_mut!((*(*ptr.as_ptr()).target.as_mut_ptr()).root);
            target_root.write(Cell::new(Some(ptr_erased)));
        }

        Unclaimed { ptr }
    }
}

impl<T: Reclaim> Default for Root<T> {
    fn default() -> Self {
        Self::new()
    }
}

/// A value that an [`Xrc`] can point at directly.
///
/// Besides the value, it records which [`Root`] it belongs to, so any `Xrc`
/// pointing at it can find the shared count. The record is filled in by:
/// - `Root::unclaimed`, for the root's own target;
/// - [`Xrc::project`], lazily, for targets nested inside the value.
pub struct ProjectTarget<T: ?Sized> {
    /// Owning root, or `None` if nothing has been projected to this target yet.
    root: Cell<Option<NonNull<Root<Erased>>>>,
    /// Must remain the last field so that `T` may be unsized.
    value: T,
}

impl<T> ProjectTarget<T> {
    /// Wraps `value` in a target that is not yet associated with any root.
    pub const fn new(value: T) -> Self {
        let root = Cell::new(None);
        Self { value, root }
    }
}

impl<T: ?Sized> Deref for ProjectTarget<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        &self.value
    }
}

impl<T: ?Sized> DerefMut for ProjectTarget<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        &mut self.value
    }
}

#[repr(transparent)]
pub struct Unclaimed<T> {
    ptr: NonNull<Root<T>>,
}

impl<T> Unclaimed<T> {
    pub fn claim(self, value: T) -> Xrc<T> {
        unsafe {
            debug_assert_eq!((*self.ptr.as_ptr()).state.get(), STATE_UNCLAIMED);
            (*self.ptr.as_ptr()).state.set(STATE_SHARED_INIT);
        }

        unsafe {
            let target_value = ptr::addr_of_mut!((*(*self.ptr.as_ptr()).target.as_mut_ptr()).value);
            target_value.write(value);
        }

        // shared read-only reborrow
        let target = unsafe { (*self.ptr.as_ptr()).target.assume_init_ref() };
        let ptr = NonNull::from(target);
        Xrc { ptr }
    }
}

/// Reference-counted handle to a value stored in a pinned [`Root`].
///
/// The handle points directly at a [`ProjectTarget`]: either the root's own
/// target or one nested inside it via [`Xrc::project`]. The count lives in the
/// root, which is found through the target's `root` field.
#[repr(transparent)]
pub struct Xrc<T: ?Sized> {
    ptr: NonNull<ProjectTarget<T>>,
}

impl<T: ?Sized> Drop for Xrc<T> {
    /// Releases this handle.
    ///
    /// Dropping the last handle of a root runs the root's reclaim shim, which
    /// moves the value out and calls [`Reclaim::reclaim`].
    fn drop(&mut self) {
        let root = match Self::inner(self).root.get() {
            Some(root) => root,
            None => unreachable!("missing root"),
        };

        // SAFETY: the root cannot be dropped while handles exist (`Root::drop`
        // aborts in that case). `state` sits in the `#[repr(C)]` prefix that is
        // shared by every `Root<_>`.
        let counter = unsafe { &(*root.as_ptr()).state };
        let remaining = counter.get();
        debug_assert!(remaining <= STATE_SHARED_MAX);

        if remaining == STATE_SHARED_INIT {
            // Last handle: hand the value back through the type-erased shim.
            // SAFETY: the root is in the shared-init state with an initialized
            // target, which is exactly the shim's precondition.
            unsafe {
                let shim = (*root.as_ptr()).reclaim_shim;
                shim(root);
            }
        } else {
            counter.set(remaining - 1);
        }
    }
}

// Generalized to `T: ?Sized` to match `Drop`, `Deref`, `inner` and `project`.
// Unsized handles such as `Xrc<NoReclaim<dyn Display>>` (produced by
// `Xrc::project`) can now be cloned too. Nothing in the body requires `Sized`.
impl<T: ?Sized> Clone for Xrc<T> {
    /// Creates another handle to the same target by incrementing the shared
    /// count stored in the owning [`Root`].
    ///
    /// # Panics
    ///
    /// Panics with "ref count overflow" if the count is already at
    /// `STATE_SHARED_MAX`.
    fn clone(&self) -> Self {
        let Some(root) = Self::inner(self).root.get() else {
            unreachable!("missing root");
        };

        // SAFETY: the root outlives every handle (`Root::drop` aborts
        // otherwise), and `state` lives in the `#[repr(C)]` prefix shared by
        // every `Root<_>`, so reading it through `Root<Erased>` is valid.
        let state_ref = unsafe { &(*root.as_ptr()).state };
        let state = state_ref.get();
        debug_assert!(state <= STATE_SHARED_MAX);
        if state == STATE_SHARED_MAX {
            panic!("ref count overflow");
        }
        state_ref.set(state + 1);

        // `NonNull<ProjectTarget<T>>` is `Copy` even for unsized `T`.
        Self { ptr: self.ptr }
    }
}

impl<T: ?Sized> Deref for Xrc<T> {
    type Target = T;

    /// Borrows the shared value this handle points at.
    fn deref(&self) -> &T {
        let target = Self::inner(self);
        &target.value
    }
}

impl<T: ?Sized> Xrc<T> {
    /// Turns a handle into one that points at a [`ProjectTarget`] reachable
    /// from the current target, sharing the same root and count.
    ///
    /// The count held by `this` is transferred, not incremented; clone first
    /// to keep the original handle. If the returned target has no root yet,
    /// it is assigned this handle's root.
    ///
    /// # Panics
    ///
    /// Panics with "projected from different roots" if the returned target is
    /// already associated with another root. In that case, or if `f` panics,
    /// `this` is dropped normally.
    pub fn project<U, F>(this: Self, f: F) -> Xrc<U>
    where
        U: ?Sized,
        F: FnOnce(&ProjectTarget<T>) -> &ProjectTarget<U>,
    {
        let source = Self::inner(&this);
        let source_root = match source.root.get() {
            Some(root) => root,
            None => unreachable!("missing root"),
        };

        let target = f(source);
        match target.root.get() {
            None => target.root.set(Some(source_root)),
            Some(other) if other != source_root => panic!("projected from different roots"),
            Some(_) => {}
        }

        let projected = NonNull::from(target);
        // The count held by `this` now belongs to the projected handle.
        mem::forget(this);
        Xrc { ptr: projected }
    }

    /// Borrows the target this handle points at.
    fn inner(this: &Self) -> &ProjectTarget<T> {
        // SAFETY: the target lives inside a root that cannot be dropped while
        // this handle exists (`Root::drop` aborts in that case); `Xrc` only
        // hands out shared references to it.
        unsafe { this.ptr.as_ref() }
    }
}

// Unit tests. They pull in `std` for `Rc`, `RefCell`, `catch_unwind` and
// formatting, since the crate itself is `no_std`.
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use std::prelude::rust_2021::*;

    use std::cell::RefCell;
    use std::fmt::Display;
    use std::panic::{self, AssertUnwindSafe};
    use std::pin::pin;
    use std::rc::Rc;

    // Turns a pinned slice into an iterator of pinned element references.
    fn iter_pinned_mut<T>(slice: Pin<&mut [T]>) -> impl Iterator<Item = Pin<&mut T>> {
        // SAFETY: relies on slice elements being structurally pinned. The
        // mutable references are immediately re-wrapped in `Pin` and never
        // used to move an element.
        unsafe {
            slice
                .get_unchecked_mut()
                .iter_mut()
                .map(|elem| Pin::new_unchecked(elem))
        }
    }

    // Claim a value and read it back through the handle.
    #[test]
    fn basic() {
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(42));
        assert_eq!(xrc.0, 42);
    }

    // A second `unclaimed` call on a root in use must panic before it takes a
    // unique borrow, so the live `Xrc` remains usable afterwards.
    #[test]
    fn unpin_hack() {
        let mut root = pin!(Root::new());
        let xrc = root.as_mut().unclaimed().claim(NoReclaim(42));

        let res = panic::catch_unwind(AssertUnwindSafe(|| {
            // This must not create an unique borrow ...
            let _ = root.unclaimed();
        }));
        assert!(res.is_err());

        // ... to not invalidate this shared read-only borrow.
        assert_eq!(xrc.0, 42);
    }

    // `_xrc` is declared before `root`, so `root` is dropped first while a
    // handle is still alive. `Root::drop` then escalates to an abort, which is
    // why the test is ignored.
    #[test]
    #[ignore = "will abort"]
    fn abort() {
        #[allow(clippy::needless_late_init)]
        let _xrc;
        let root = pin!(Root::new());
        _xrc = root.unclaimed().claim(NoReclaim(()));
    }

    // A root can only be handed out once.
    #[test]
    #[should_panic = "unclaimed called multiple times"]
    fn multiple_unclaimed() {
        let mut root = pin!(Root::<NoReclaim<i32>>::new());
        let _ = root.as_mut().unclaimed();
        let _ = root.as_mut().unclaimed();
    }

    // The value is reclaimed (and then dropped) exactly once, and only after
    // the last of several cloned handles is dropped.
    #[test]
    fn clone_reclaim_drop() {
        struct Check {
            reclaim_count: Rc<Cell<usize>>,
            drop_count: Rc<Cell<usize>>,
        }

        impl Reclaim for Check {
            fn reclaim(self, _unclaimed: Unclaimed<Self>) {
                self.reclaim_count.set(self.reclaim_count.get() + 1);
            }
        }

        impl Drop for Check {
            fn drop(&mut self) {
                self.drop_count.set(self.drop_count.get() + 1);
            }
        }

        let drop_count = Rc::new(Cell::new(0));
        let reclaim_count = Rc::new(Cell::new(0));
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(Check {
            drop_count: drop_count.clone(),
            reclaim_count: reclaim_count.clone(),
        });
        assert_eq!(xrc.drop_count.get(), 0);
        assert_eq!(xrc.reclaim_count.get(), 0);
        let xrc2 = xrc.clone();
        assert_eq!(xrc.drop_count.get(), 0);
        assert_eq!(xrc.reclaim_count.get(), 0);
        drop(xrc);
        assert_eq!(xrc2.drop_count.get(), 0);
        assert_eq!(xrc2.reclaim_count.get(), 0);
        drop(xrc2);
        assert_eq!(drop_count.get(), 1);
        assert_eq!(reclaim_count.get(), 1);
    }

    // Projecting into a nested target leaves the original (cloned-from)
    // handle usable.
    #[test]
    fn project() {
        struct Value {
            a: i32,
            b: ProjectTarget<i32>,
        }

        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(Value {
            a: 1,
            b: ProjectTarget::new(2),
        }));
        assert_eq!(xrc.0.a, 1);
        assert_eq!(*xrc.0.b, 2);
        let xrc2 = Xrc::project(xrc.clone(), |xrc| &xrc.0.b);
        assert_eq!(xrc.0.a, 1);
        assert_eq!(*xrc.0.b, 2);
        assert_eq!(*xrc2, 2);
    }

    // The projection closure can perform an unsized coercion of the target.
    #[test]
    fn project_unsized() {
        let root = pin!(Root::new());
        let xrc = root.unclaimed().claim(NoReclaim(42));
        let xrc = Xrc::project::<NoReclaim<dyn Display>, _>(xrc, |xrc| xrc);
        assert_eq!(std::format!("{}", &xrc.0), "42");
    }

    // The same `ProjectTarget` is reachable from two roots through a shared
    // `Rc`. The first projection assigns it to `root1`; the second, from
    // `root2`, must panic.
    #[test]
    #[should_panic = "projected from different roots"]
    fn project_multiple() {
        let target = Rc::new(ProjectTarget::new(0));

        let root1 = pin!(Root::new());
        let xrc1 = root1.unclaimed().claim(NoReclaim(target.clone()));

        let root2 = pin!(Root::new());
        let xrc2 = root2.unclaimed().claim(NoReclaim(target));

        let _ = Xrc::project(xrc1, |xrc| &*xrc.0);
        let _ = Xrc::project(xrc2, |xrc| &*xrc.0);
    }

    // Uses an array of roots as a fixed-size pool. `reclaim` pushes the slot
    // back onto the free list, so a slot only becomes available again after
    // every handle to it, including projections, has been dropped.
    #[test]
    fn gc_alloc() {
        struct Gc<T> {
            free: Rc<RefCell<Vec<Unclaimed<Gc<T>>>>>,
            value: T,
        }

        impl<T> Deref for Gc<T> {
            type Target = T;

            fn deref(&self) -> &T {
                &self.value
            }
        }

        impl<T> Reclaim for Gc<T> {
            fn reclaim(self, unclaimed: Unclaimed<Self>) {
                self.free.borrow_mut().push(unclaimed)
            }
        }

        struct Alloc<T> {
            free: Rc<RefCell<Vec<Unclaimed<Gc<T>>>>>,
        }

        impl<T> Alloc<T> {
            // Claims a free slot for `value`, or gives `value` back when the
            // pool is exhausted.
            fn try_alloc(&self, value: T) -> Result<Xrc<Gc<T>>, T> {
                let unclaimed = self.free.borrow_mut().pop();
                // drop RefMut
                if let Some(unclaimed) = unclaimed {
                    Ok(unclaimed.claim(Gc {
                        free: self.free.clone(),
                        value,
                    }))
                } else {
                    Err(value)
                }
            }
        }

        struct Value {
            number: ProjectTarget<i32>,
            string: ProjectTarget<String>,
        }

        // `Root` is not `Copy`, so the array is built from a `const` item;
        // each use of `INIT` is a fresh root, which is what the lint warns about.
        #[allow(clippy::declare_interior_mutable_const)]
        const INIT: Root<Gc<Value>> = Root::new();
        let mut storage = pin!([INIT; 8]);
        let free = iter_pinned_mut(storage.as_mut())
            .map(|root| root.unclaimed())
            .collect::<Vec<_>>();
        let free = Rc::new(RefCell::new(free));
        let alloc = Alloc { free };

        let mut values = Vec::new();
        for n in 0..8 {
            let value = Value {
                number: ProjectTarget::new(n),
                string: ProjectTarget::new(n.to_string()),
            };
            let xrc = alloc.try_alloc(value).ok().unwrap();
            values.push(xrc);
        }

        let mut value_nums = values
            .iter()
            .map(|xrc| Xrc::project(xrc.clone(), |value| &value.number))
            .collect::<Vec<_>>();

        let mut value_strs = values
            .iter()
            .map(|xrc| Xrc::project(xrc.clone(), |value| &value.string))
            .collect::<Vec<_>>();

        assert!(value_nums.iter().map(|xrc| **xrc).eq(0..8));
        assert!(value_strs
            .iter()
            .map(|xrc| xrc.parse::<i32>().unwrap())
            .eq(0..8));

        let value = Value {
            number: ProjectTarget::new(0),
            string: ProjectTarget::new(String::new()),
        };

        // The pool stays exhausted until the last group of handles is gone.
        let value = alloc.try_alloc(value).err().unwrap();
        values.clear();
        let value = alloc.try_alloc(value).err().unwrap();
        value_nums.clear();
        let value = alloc.try_alloc(value).err().unwrap();
        value_strs.clear();
        // `_xrc` is declared last, so it is dropped before `alloc` and
        // `storage`; its reclaim therefore still sees a live root and free list.
        let _xrc = alloc.try_alloc(value).ok().unwrap();
    }
}