use rustc_data_structures::fx::FxHashSet as HashSet;
use rustc_hir::def_id::DefId;
use rustc_middle::{
mir::*,
ty::{GenericArgKind, RegionKind, RegionVid, Ty, TyCtxt},
};
use rustc_span::source_map::Spanned;
use rustc_utils::{BodyExt, OperandExt, PlaceExt};
use crate::extensions::{is_extension_active, MutabilityMode};
/// A set of MIR [`Place`]s, backed by `FxHashSet` (imported here as `HashSet`).
pub type PlaceSet<'tcx> = HashSet<Place<'tcx>>;
/// Collects the dereferenced interior pointers of each argument place.
///
/// For every `(index, place)` in `args`, each pointer reported by
/// `place.interior_pointers(tcx, body, def_id)` is dereferenced with
/// `tcx.mk_place_deref` and returned tagged with the same `index`.
///
/// Pointers with `Mutability::Mut` are always kept. Pointers with
/// `Mutability::Not` are kept only when the `IgnoreMut` mutability-mode
/// extension is active.
pub fn arg_mut_ptrs<'tcx>(
  args: &[(usize, Place<'tcx>)],
  tcx: TyCtxt<'tcx>,
  body: &Body<'tcx>,
  def_id: DefId,
) -> Vec<(usize, Place<'tcx>)> {
  // Read the extension flag once, up front, rather than per pointer.
  let ignore_mut =
    is_extension_active(|mode| mode.mutability_mode == MutabilityMode::IgnoreMut);

  let mut derefs = Vec::new();
  for &(arg_index, arg_place) in args {
    for (_, pointers) in arg_place.interior_pointers(tcx, body, def_id) {
      for (pointer, mutability) in pointers {
        let keep = match mutability {
          Mutability::Mut => true,
          // Shared pointers only count when mutability is being ignored.
          Mutability::Not => ignore_mut,
        };
        if keep {
          derefs.push((arg_index, tcx.mk_place_deref(pointer)));
        }
      }
    }
  }
  derefs
}
/// Pairs every argument operand that is a place with its position in `args`.
///
/// Operands for which `OperandExt::as_place` returns `None` are skipped. The
/// surviving entries keep their original index and relative order.
pub fn arg_places<'tcx>(args: &[Spanned<Operand<'tcx>>]) -> Vec<(usize, Place<'tcx>)> {
  let mut places = Vec::new();
  for (position, operand) in args.iter().enumerate() {
    if let Some(place) = operand.node.as_place() {
      places.push((position, place));
    }
  }
  places
}
/// Helper for special-casing the async context type of a MIR body.
///
/// It holds the type returned by `body.async_context(tcx, def_id)` (see
/// `AsyncHack::new`). When that type is `None`, `ignore_regions` returns an
/// empty set and `ignore_place` always returns `false`.
pub struct AsyncHack<'a, 'tcx> {
  /// Result of `BodyExt::async_context` for the body.
  /// NOTE(review): presumably the task `Context` type of an async body — confirm.
  context_ty: Option<Ty<'tcx>>,
  /// Type context used to type places and erase regions.
  tcx: TyCtxt<'tcx>,
  /// Body whose `local_decls` are used to compute place types.
  body: &'a Body<'tcx>,
}
impl<'a, 'tcx> AsyncHack<'a, 'tcx> {
  /// Creates the helper for `body`.
  ///
  /// Stores whatever `body.async_context(tcx, def_id)` returns, which may be
  /// `None`.
  pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, def_id: DefId) -> Self {
    Self {
      context_ty: body.async_context(tcx, def_id),
      tcx,
      body,
    }
  }

  /// Returns every region inference variable (`ReVar`) found while walking the
  /// async context type.
  ///
  /// Returns an empty set when there is no context type.
  pub fn ignore_regions(&self) -> HashSet<RegionVid> {
    let mut regions = HashSet::default();
    if let Some(context_ty) = self.context_ty {
      for part in context_ty.walk() {
        // Only lifetime generic args can carry region variables.
        if let GenericArgKind::Lifetime(region) = part.unpack() {
          if let RegionKind::ReVar(vid) = region.kind() {
            regions.insert(vid);
          }
        }
      }
    }
    regions
  }

  /// Returns whether `place` has the async context type.
  ///
  /// Regions are erased on both sides before comparing. Always returns `false`
  /// when there is no context type.
  pub fn ignore_place(&self, place: Place<'tcx>) -> bool {
    if let Some(context_ty) = self.context_ty {
      let place_ty = place.ty(&self.body.local_decls, self.tcx).ty;
      self.tcx.erase_regions(place_ty) == self.tcx.erase_regions(context_ty)
    } else {
      false
    }
  }
}