use std::{
	borrow::Borrow,
	collections::{hash_map::Entry, HashMap},
	hash::Hash,
};
/// A container whose modifications are applied as a batch that can later be made permanent
/// ([`Transactional::commit`]) or discarded ([`Transactional::rollback`]).
pub trait Transactional {
	/// Discards every modification made since the last `commit` or `rollback`, restoring the
	/// state as it was at that point.
	fn rollback(&mut self);
	/// Accepts every modification made since the last `commit` or `rollback`; a later rollback
	/// will no longer undo them.
	fn commit(&mut self);
}
/// A [`HashMap`] wrapper whose modifications can be committed or rolled back as a batch.
///
/// The live contents are kept in `inner`. The first time a key is modified after the last
/// commit/rollback, the value it held before (`None` if it was absent) is saved in
/// `overridden_originals`; rolling back writes those saved values back, committing forgets them.
///
/// Note: the derived `PartialEq`/`Eq` compare the pending rollback log as well as the contents.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct TransactionalHashMap<K, V>
where
	K: Eq + Hash + Clone,
	V: Clone,
{
	/// Current contents, including uncommitted changes.
	inner: HashMap<K, V>,
	/// For each key modified since the last commit/rollback: its value before the first such
	/// modification (`None` = the key was absent).
	overridden_originals: HashMap<K, Option<V>>,
}

impl<K, V> TransactionalHashMap<K, V>
where
	K: Eq + Hash + Clone,
	V: Clone,
{
	/// Creates an empty map with no pending changes.
	pub fn new() -> Self {
		Self { inner: HashMap::new(), overridden_originals: HashMap::new() }
	}

	/// Creates an empty map with room for at least `capacity` entries (both in the contents and
	/// in the rollback log) before reallocating.
	#[inline]
	pub fn with_capacity(capacity: usize) -> Self {
		Self {
			inner: HashMap::with_capacity(capacity),
			overridden_originals: HashMap::with_capacity(capacity),
		}
	}

	/// Returns a read-only view of the current (possibly uncommitted) contents.
	pub fn inner(&self) -> &HashMap<K, V> {
		&self.inner
	}

	/// Saves the current value of `key` (or `None` if absent) in the rollback log, unless the key
	/// is already logged: the first value seen since the last commit/rollback is the one a
	/// rollback must restore, so later modifications must not overwrite it.
	fn record_original(&mut self, key: &K) {
		if !self.overridden_originals.contains_key(key) {
			let original = self.inner.get(key).cloned();
			self.overridden_originals.insert(key.clone(), original);
		}
	}

	/// Returns the [`Entry`] for `key` in the current contents.
	///
	/// The key's current state is logged first, so any modification made through the returned
	/// entry is undone by a rollback. The key is logged even if the caller does not modify the
	/// entry; rolling back then just restores the unchanged state.
	#[inline]
	pub fn entry(&mut self, key: K) -> Entry<'_, K, V> {
		self.record_original(&key);
		self.inner.entry(key)
	}

	/// Inserts `v` under `k`, returning the value previously stored under `k` (if any).
	///
	/// The key's state before its first modification in this transaction is logged for rollback.
	#[inline]
	pub fn insert(&mut self, k: K, v: V) -> Option<V> {
		// The original code removed the old value first, which made `inner.insert` always return
		// `None`; log a copy instead so the previous value can be returned to the caller.
		self.record_original(&k);
		self.inner.insert(k, v)
	}

	/// Removes `k`, returning its value if it was present.
	///
	/// Only an actual removal is logged; removing an absent key leaves the rollback log untouched.
	#[inline]
	pub fn remove(&mut self, k: &K) -> Option<V> {
		let removed = self.inner.remove(k)?;
		if !self.overridden_originals.contains_key(k) {
			self.overridden_originals.insert(k.clone(), Some(removed.clone()));
		}
		Some(removed)
	}

	/// Removes every entry, logging each key's value (unless already logged) so a rollback
	/// restores them.
	#[inline]
	pub fn clear(&mut self) {
		for (key, value) in self.inner.drain() {
			// `or_insert` keeps an already-logged original; the drained value is then dropped.
			self.overridden_originals.entry(key).or_insert(Some(value));
		}
	}

	/// Number of entries in the current contents.
	#[inline]
	pub fn len(&self) -> usize {
		self.inner.len()
	}

	/// Returns `true` if the current contents hold no entries.
	#[inline]
	pub fn is_empty(&self) -> bool {
		self.inner.is_empty()
	}

	/// Returns a reference to the current value stored under `k`.
	#[inline]
	pub fn get<Q: ?Sized>(&self, k: &Q) -> Option<&V>
	where
		K: Borrow<Q>,
		Q: Hash + Eq,
	{
		self.inner.get(k)
	}

	/// Returns a mutable reference to the current value stored under `k`.
	///
	/// If the key is present, its value is logged first (unless already logged), so in-place
	/// mutations through the returned reference are undone by a rollback. The key is logged even
	/// if the caller ends up not mutating the value.
	#[inline]
	pub fn get_mut<Q: ?Sized>(&mut self, k: &Q) -> Option<&mut V>
	where
		K: Borrow<Q>,
		Q: Hash + Eq,
	{
		if !self.overridden_originals.contains_key(k) {
			// Only a borrowed `Q` is available; fetch the stored `K` to own a log key.
			if let Some((key, value)) = self.inner.get_key_value(k) {
				self.overridden_originals.insert(key.clone(), Some(value.clone()));
			}
		}
		self.inner.get_mut(k)
	}
}

impl<K, V> Default for TransactionalHashMap<K, V>
where
	K: Eq + Hash + Clone,
	V: Clone,
{
	/// Equivalent to [`TransactionalHashMap::new`]: empty, with no pending changes.
	fn default() -> Self {
		Self::new()
	}
}
impl<K, V> FromIterator<(K, V)> for TransactionalHashMap<K, V>
where
	K: Eq + Hash + Clone,
	V: Clone,
{
	fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> TransactionalHashMap<K, V> {
		let mut map = TransactionalHashMap::new();
		map.extend(iter);
		map
	}
}
impl<K, V> Extend<(K, V)> for TransactionalHashMap<K, V>
where
	K: Eq + Hash + Clone,
	V: Clone,
{
	#[inline]
	fn extend<T: IntoIterator<Item = (K, V)>>(&mut self, iter: T) {
		self.inner.extend(iter)
	}
}
/// Commit/rollback driven by the `overridden_originals` log.
impl<K, V> Transactional for TransactionalHashMap<K, V>
where
	K: Eq + Hash + Clone,
	V: Clone,
{
	/// Makes the current contents the new baseline by forgetting all logged originals.
	fn commit(&mut self) {
		// `clear` keeps the log's allocation for the next transaction, matching `rollback`, which
		// drains the log in place.
		self.overridden_originals.clear();
	}

	/// Restores every logged key to its original state: keys that held a value get it back,
	/// keys that were absent are removed. The log is emptied afterwards.
	fn rollback(&mut self) {
		for (key, original) in self.overridden_originals.drain() {
			match original {
				Some(value) => {
					self.inner.insert(key, value);
				},
				None => {
					self.inner.remove(&key);
				},
			}
		}
	}
}
#[cfg(test)]
mod tests {
	use super::*;
	use std::collections::BTreeMap;
	// Scenario: committed baseline, new keys added then rolled back, then a remove + insert rolled
	// back again. `HashMap` iteration order is unspecified, so contents are collected into a
	// `BTreeMap` before comparing.
	#[test]
	fn transactional_hashmap_should_revert_the_state_as_before() {
		let arr = [(1, 9), (2, 8), (3, 7), (4, 6), (5, 5), (6, 4), (7, 3), (8, 2), (9, 1)];
		let mut transactional = TransactionalHashMap::new();
		for (k, v) in arr {
			transactional.insert(k, v);
		}
		// `arr` becomes the baseline that later rollbacks return to.
		transactional.commit();
		let arr2 = vec![(10, 100), (11, 200), (12, 300), (13, 400)];
		for (k, v) in arr2 {
			transactional.insert(k, v);
		}
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		let expected = BTreeMap::from([
			(1, 9),
			(2, 8),
			(3, 7),
			(4, 6),
			(5, 5),
			(6, 4),
			(7, 3),
			(8, 2),
			(9, 1),
			(10, 100),
			(11, 200),
			(12, 300),
			(13, 400),
		]);
		assert_eq!(inner_sorted, expected);
		// Keys 10..=13 were absent at commit time, so rollback removes them.
		transactional.rollback();
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		let expected = BTreeMap::from(arr);
		assert_eq!(inner_sorted, expected);
		// Second transaction: remove an existing key and add a new one.
		transactional.remove(&5);
		transactional.insert(89, 98);
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		let expected_2 = BTreeMap::from([
			(1, 9),
			(2, 8),
			(3, 7),
			(4, 6),
			(89, 98),
			(6, 4),
			(7, 3),
			(8, 2),
			(9, 1),
		]);
		assert_eq!(inner_sorted, expected_2);
		// Rollback restores key 5 and drops key 89, and leaves the log empty.
		transactional.rollback();
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		assert_eq!(inner_sorted, expected);
		assert_eq!(transactional.overridden_originals.len(), 0);
	}
	// Modifications made through `entry` must be undone by rollback.
	#[test]
	fn transactional_hashmap_entry_should_behave_as_expected() {
		let mut transactional = TransactionalHashMap::new();
		// Key absent: `or_insert(200)` applies.
		transactional.entry(5).and_modify(|n| *n = 100).or_insert(200);
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		assert_eq!(inner_sorted, BTreeMap::from([(5, 200)]));
		transactional.commit();
		// Key present: `and_modify` applies.
		transactional.entry(5).and_modify(|n| *n = 100).or_insert(200);
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		assert_eq!(inner_sorted, BTreeMap::from([(5, 100)]));
		transactional.rollback();
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		assert_eq!(inner_sorted, BTreeMap::from([(5, 200)]));
	}
	// In-place mutation of a non-`Copy` value through `entry` must be rolled back, and repeated
	// `entry` calls on one key log only a single original.
	#[test]
	fn transactional_hashmap_vec() {
		let mut transactional: TransactionalHashMap<u32, Vec<u32>> = TransactionalHashMap::new();
		for i in 1..5 {
			transactional.entry(1).or_default().push(i);
		}
		for i in 10..15 {
			transactional.entry(2).or_default().push(i);
		}
		// One log entry per key, despite several `entry` calls on each.
		assert_eq!(transactional.overridden_originals.len(), 2);
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		assert_eq!(
			inner_sorted,
			BTreeMap::from([(1, vec![1, 2, 3, 4]), (2, vec![10, 11, 12, 13, 14])])
		);
		transactional.commit();
		// Remove indices 2, 1, 0 (highest first so earlier indices stay valid), leaving `[4]`.
		for i in (0..3).rev() {
			transactional.entry(1).or_default().remove(i);
		}
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		assert_eq!(inner_sorted, BTreeMap::from([(1, vec![4]), (2, vec![10, 11, 12, 13, 14])]));
		transactional.rollback();
		let inner_sorted: BTreeMap<_, _> = transactional.inner.clone().into_iter().collect();
		assert_eq!(
			inner_sorted,
			BTreeMap::from([(1, vec![1, 2, 3, 4]), (2, vec![10, 11, 12, 13, 14])])
		);
		assert_eq!(transactional.overridden_originals.len(), 0);
	}
}