use rustc_data_structures::fx::FxHashSet as HashSet;
use rustc_hir::def_id::DefId;
use rustc_middle::{
  mir::*,
  ty::{GenericArgKind, RegionKind, RegionVid, Ty, TyCtxt},
};
use rustc_span::source_map::Spanned;
use rustc_utils::{BodyExt, OperandExt, PlaceExt};
use crate::extensions::{is_extension_active, MutabilityMode};
/// A set of MIR [`Place`]s, backed by `FxHashSet` (imported here as `HashSet`).
pub type PlaceSet<'tcx> = HashSet<Place<'tcx>>;
/// For each `(index, place)` argument, collects the dereferenced interior pointers
/// reachable from `place` that may be written through.
///
/// Only `Mutability::Mut` pointers are kept, unless the active extension's
/// mutability mode is `MutabilityMode::IgnoreMut`, in which case shared
/// (`Mutability::Not`) pointers are kept as well. Each result is paired with the
/// index of the argument it was found in.
pub fn arg_mut_ptrs<'tcx>(
  args: &[(usize, Place<'tcx>)],
  tcx: TyCtxt<'tcx>,
  body: &Body<'tcx>,
  def_id: DefId,
) -> Vec<(usize, Place<'tcx>)> {
  // Checked once up front rather than per pointer.
  let treat_shared_as_mut =
    is_extension_active(|mode| mode.mutability_mode == MutabilityMode::IgnoreMut);

  let mut found = Vec::new();
  for &(arg_index, arg_place) in args {
    for (_, pointers) in arg_place.interior_pointers(tcx, body, def_id) {
      for (pointer, mutability) in pointers {
        let keep = match mutability {
          Mutability::Mut => true,
          Mutability::Not => treat_shared_as_mut,
        };
        if keep {
          // Report the pointee (`*pointer`), not the pointer itself.
          found.push((arg_index, tcx.mk_place_deref(pointer)));
        }
      }
    }
  }
  found
}
/// Pairs each call argument that is a place operand with its positional index.
///
/// Arguments for which `as_place()` returns `None` (e.g. constants) are skipped,
/// but the surviving entries keep their original argument index.
pub fn arg_places<'tcx>(args: &[Spanned<Operand<'tcx>>]) -> Vec<(usize, Place<'tcx>)> {
  let mut places = Vec::new();
  for (position, spanned_arg) in args.iter().enumerate() {
    if let Some(place) = spanned_arg.node.as_place() {
      places.push((position, place));
    }
  }
  places
}
/// Helper that remembers the type returned by `body.async_context(tcx, def_id)`
/// for a MIR body, so that places of that type and regions appearing in it can
/// be ignored.
// NOTE(review): presumably this is the async task `Context` type threaded
// through async bodies — confirm against `rustc_utils::BodyExt::async_context`.
pub struct AsyncHack<'a, 'tcx> {
  /// Type found by `async_context`, or `None` when the body has none.
  context_ty: Option<Ty<'tcx>>,
  /// Type context used to compute place types and erase regions.
  tcx: TyCtxt<'tcx>,
  /// The MIR body whose local declarations are used to type places.
  body: &'a Body<'tcx>,
}
impl<'a, 'tcx> AsyncHack<'a, 'tcx> {
  /// Creates the helper, looking up the body's async context type via
  /// `body.async_context(tcx, def_id)`.
  pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, def_id: DefId) -> Self {
    Self {
      context_ty: body.async_context(tcx, def_id),
      tcx,
      body,
    }
  }

  /// Returns every region variable (`ReVar`) that appears anywhere inside the
  /// context type; returns an empty set when there is no context type.
  pub fn ignore_regions(&self) -> HashSet<RegionVid> {
    let mut regions = HashSet::default();
    if let Some(context_ty) = self.context_ty {
      // `walk` visits the type and all of its nested generic arguments.
      for generic_arg in context_ty.walk() {
        if let GenericArgKind::Lifetime(region) = generic_arg.unpack() {
          if let RegionKind::ReVar(vid) = region.kind() {
            regions.insert(vid);
          }
        }
      }
    }
    regions
  }

  /// Returns `true` when `place` has the same type as the context type,
  /// comparing with regions erased; always `false` without a context type.
  pub fn ignore_place(&self, place: Place<'tcx>) -> bool {
    self.context_ty.is_some_and(|context_ty| {
      let place_ty = place.ty(&self.body.local_decls, self.tcx).ty;
      // Regions are erased so lifetime differences alone do not prevent a match.
      self.tcx.erase_regions(place_ty) == self.tcx.erase_regions(context_ty)
    })
  }
}