use ahash::AHashMap;
use splitmut::SplitMut;
use std::{fmt, hash::Hash};
use crate::{
bitset::BitSet, pointer::PointerFamily, Captures, IndexSet, IndexedDomain, IndexedValue,
ToIndex,
};
/// A sparse relation from row keys of type `R` to sets of column values of type `C`.
///
/// Each stored row is an [`IndexSet`] over one shared column domain. Rows live in a hash
/// map and only exist once something creates them; the read accessors treat a row that is
/// not stored as empty.
pub struct IndexMatrix<'a, R, C, S, P>
where
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Stored rows, keyed by row.
    pub(crate) matrix: AHashMap<R, IndexSet<'a, C, S, P>>,
    /// An empty set over the column domain: handed out for rows that are not stored and
    /// cloned to initialize newly created rows.
    empty_set: IndexSet<'a, C, S, P>,
    /// The domain that every row's column set indexes into.
    col_domain: P::Pointer<IndexedDomain<C>>,
}
impl<'a, R, C, S, P> IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone,
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Creates a matrix with no rows whose columns are drawn from `col_domain`.
    ///
    /// Rows are not allocated up front; they are created on first write.
    pub fn new(col_domain: &P::Pointer<IndexedDomain<C>>) -> Self {
        IndexMatrix {
            matrix: AHashMap::default(),
            empty_set: IndexSet::new(col_domain),
            col_domain: col_domain.clone(),
        }
    }

    /// Returns a mutable reference to `row`'s set, first inserting a copy of the empty set
    /// if the row is not stored yet.
    pub(crate) fn ensure_row(&mut self, row: R) -> &mut IndexSet<'a, C, S, P> {
        self.matrix
            .entry(row)
            .or_insert_with(|| self.empty_set.clone())
    }

    /// Adds `col` to `row`, creating the row if needed.
    ///
    /// Returns the result of [`IndexSet::insert`] on that row (presumably `true` when `col`
    /// was not already present).
    pub fn insert<M>(&mut self, row: R, col: impl ToIndex<C, M>) -> bool {
        let col = col.to_index(&self.col_domain);
        self.ensure_row(row).insert(col)
    }

    /// Unions `from` into row `into`, creating `into` if needed.
    ///
    /// Returns whether the row changed, as reported by `union_changed`.
    pub fn union_into_row(&mut self, into: R, from: &IndexSet<'a, C, S, P>) -> bool {
        self.ensure_row(into).union_changed(from)
    }

    /// Unions row `from` into row `to`, creating either row if it is not stored yet.
    ///
    /// Returns `false` without touching the matrix when `from == to`; otherwise returns
    /// whether row `to` changed.
    ///
    /// # Panics
    ///
    /// Only if `R`'s `Eq`/`Hash` impls are inconsistent, so that two keys that compare
    /// unequal resolve to the same stored row.
    pub fn union_rows(&mut self, from: R, to: R) -> bool {
        if from == to {
            return false;
        }

        self.ensure_row(from.clone());
        self.ensure_row(to.clone());

        // `get2_mut` verifies that the two lookups yield distinct entries before handing out
        // two `&mut` borrows. The previous `get2_unchecked_mut` relied solely on `R`'s safe
        // `PartialEq` impl being lawful, which unsafe code must not trust.
        match self.matrix.get2_mut(&from, &to) {
            (Ok(src), Ok(dst)) => dst.union_changed(src),
            _ => unreachable!(
                "union_rows: both rows were just ensured and `from != to`; \
                 the row key's Eq/Hash impls must be inconsistent"
            ),
        }
    }

    /// Iterates over the column values of `row`; yields nothing if the row is not stored.
    pub fn row(&self, row: &R) -> impl Iterator<Item = &C> + Captures<'a> + '_ {
        self.matrix.get(row).into_iter().flat_map(|set| set.iter())
    }

    /// Iterates over every stored `(row, set)` pair in unspecified (hash-map) order.
    ///
    /// Rows that were created but never filled appear here with empty sets.
    pub fn rows(&self) -> impl Iterator<Item = (&R, &IndexSet<'a, C, S, P>)> + Captures<'a> + '_ {
        self.matrix.iter()
    }

    /// Returns `row`'s set, or an empty set over the column domain if it is not stored.
    pub fn row_set(&self, row: &R) -> &IndexSet<'a, C, S, P> {
        self.matrix.get(row).unwrap_or(&self.empty_set)
    }

    /// Removes `row` from the matrix; afterwards it reads as empty.
    pub fn clear_row(&mut self, row: &R) {
        self.matrix.remove(row);
    }

    /// Returns the column domain shared by every row.
    pub fn col_domain(&self) -> &P::Pointer<IndexedDomain<C>> {
        &self.col_domain
    }
}
impl<'a, R, C, S, P> PartialEq for IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone,
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Compares the stored rows only: both matrices must hold the same row keys, and each
    /// key must map to equal column sets.
    ///
    /// Note that a row that is stored but empty is therefore *not* equal to a row that is
    /// absent, even though `row`/`row_set` read both as empty.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.matrix, &other.matrix);
        PartialEq::eq(lhs, rhs)
    }
}
/// Equality on stored rows is a full equivalence relation whenever `R: Eq`.
impl<'a, R: Eq + Hash + Clone, C: IndexedValue + 'a, S: BitSet, P: PointerFamily<'a>> Eq
    for IndexMatrix<'a, R, C, S, P>
{
}
impl<'a, R, C, S, P> Clone for IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone,
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Deep-copies every stored row, the empty template set, and the domain pointer.
    fn clone(&self) -> Self {
        Self {
            matrix: self.matrix.clone(),
            empty_set: self.empty_set.clone(),
            col_domain: self.col_domain.clone(),
        }
    }

    /// Overwrites `self` with a copy of `source`, reusing the storage of rows that exist in
    /// both.
    ///
    /// Afterwards `self` stores exactly the rows of `source`, so the result is equivalent to
    /// `*self = source.clone()`, as the `Clone::clone_from` contract requires.
    fn clone_from(&mut self, source: &Self) {
        // Remove (rather than just clear) rows that `source` lacks. Cleared-but-present rows
        // would remain visible to `rows()`, `Debug`, and `PartialEq`, making the copy differ
        // from `source`.
        self.matrix.retain(|row, _| source.matrix.contains_key(row));

        for (row, set) in source.matrix.iter() {
            match self.matrix.get_mut(row) {
                // Row already present: copy into its existing allocation.
                Some(existing) => existing.clone_from(set),
                // New row: clone directly instead of cloning an empty set and overwriting it.
                None => {
                    self.matrix.insert(row.clone(), set.clone());
                }
            }
        }

        self.empty_set.clone_from(&source.empty_set);
        self.col_domain.clone_from(&source.col_domain);
    }
}
impl<'a, R, C, S, P> fmt::Debug for IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone + fmt::Debug,
    C: IndexedValue + fmt::Debug + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Formats the matrix as a map from each stored row key to its column set, in the
    /// map's iteration order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_map();
        for (row, set) in self.matrix.iter() {
            out.entry(row, set);
        }
        out.finish()
    }
}
#[cfg(test)]
mod test {
    use crate::{test_utils::TestIndexMatrix, IndexedDomain};
    use std::rc::Rc;

    /// Builds an owned `String` column value from a literal.
    fn mk(s: &str) -> String {
        String::from(s)
    }

    #[test]
    fn test_indexmatrix() {
        let domain = Rc::new(IndexedDomain::from_iter(["a", "b", "c"].map(mk)));
        let mut matrix = TestIndexMatrix::new(&domain);

        matrix.insert(0, mk("b"));
        matrix.insert(1, mk("c"));

        let first: Vec<_> = matrix.row(&0).collect();
        let second: Vec<_> = matrix.row(&1).collect();
        assert_eq!(first, vec!["b"]);
        assert_eq!(second, vec!["c"]);

        // Merging row 0 into row 1 adds "b", so the call must report a change.
        assert!(matrix.union_rows(0, 1));
        let merged: Vec<_> = matrix.row(&1).collect();
        assert_eq!(merged, vec!["b", "c"]);
    }
}