use std::{ops::ControlFlow, rc::Rc};
use indexical::ToIndex;
use polonius_engine::AllFacts;
use rustc_borrowck::consumers::{BodyWithBorrowckFacts, PoloniusInput};
use rustc_hir::def_id::DefId;
use rustc_middle::{
mir::*,
ty::{
Region, RegionKind, RegionVid, Ty, TyCtxt, TyKind, TypeSuperVisitable, TypeVisitor,
},
};
use rustc_utils::{
block_timer,
cache::{Cache, CopyCache},
mir::{
body,
location_or_arg::{
index::{LocationOrArgDomain, LocationOrArgIndex},
LocationOrArg,
},
place::UNKNOWN_REGION,
},
BodyExt, MutabilityExt, PlaceExt,
};
use super::{aliases::Aliases, utils::PlaceSet, FlowistryInput};
use crate::extensions::{is_extension_active, MutabilityMode};
/// Per-body place information: children, conflicts, aliases, and values
/// reachable from a place, with each query memoized in its own cache.
pub struct PlaceInfo<'tcx> {
  /// Type context used for every place and type query.
  pub tcx: TyCtxt<'tcx>,
  /// The MIR body being described.
  pub body: &'tcx Body<'tcx>,
  /// `DefId` of the item whose body is described.
  pub def_id: DefId,

  /// Alias information built by `Aliases::build`; consulted for per-place
  /// aliases and for the loans associated with each region.
  aliases: Aliases<'tcx>,
  /// Index domain over every MIR location and argument local of `body`.
  location_domain: Rc<LocationOrArgDomain>,

  /// Memoized results of `normalize`.
  normalized_cache: CopyCache<Place<'tcx>, Place<'tcx>>,
  /// Memoized results of `aliases`, keyed by normalized place.
  aliases_cache: Cache<Place<'tcx>, PlaceSet<'tcx>>,
  /// Memoized results of `conflicts`.
  conflicts_cache: Cache<Place<'tcx>, PlaceSet<'tcx>>,
  /// Memoized results of `reachable_values`, keyed by place and mutability.
  reachable_cache: Cache<(Place<'tcx>, Mutability), PlaceSet<'tcx>>,
}
impl<'tcx> PlaceInfo<'tcx> {
fn build_location_arg_domain(body: &Body) -> Rc<LocationOrArgDomain> {
let all_locations = body.all_locations().map(LocationOrArg::Location);
let all_locals = body.args_iter().map(LocationOrArg::Arg);
let domain = all_locations.chain(all_locals).collect();
Rc::new(LocationOrArgDomain::new(domain))
}
pub fn build<'a>(
tcx: TyCtxt<'tcx>,
def_id: DefId,
input: impl FlowistryInput<'tcx, 'a>,
) -> Self {
Self::build_from_input_facts(tcx, def_id, input)
}
pub fn build_from_input_facts<'a>(
tcx: TyCtxt<'tcx>,
def_id: DefId,
input: impl FlowistryInput<'tcx, 'a>,
) -> Self {
block_timer!("aliases");
let body = input.body();
let location_domain = Self::build_location_arg_domain(body);
let aliases = Aliases::build(tcx, def_id, input);
PlaceInfo {
aliases,
tcx,
body,
def_id,
location_domain,
aliases_cache: Cache::default(),
normalized_cache: CopyCache::default(),
conflicts_cache: Cache::default(),
reachable_cache: Cache::default(),
}
}
pub fn body(&self) -> &'tcx Body<'tcx> {
self.body
}
pub fn normalize(&self, place: Place<'tcx>) -> Place<'tcx> {
self
.normalized_cache
.get(&place, |place| place.normalize(self.tcx, self.def_id))
}
pub fn aliases(&self, place: Place<'tcx>) -> &PlaceSet<'tcx> {
self
.aliases_cache
.get(&self.normalize(place), move |_| self.aliases.aliases(place))
}
pub fn children(&self, place: Place<'tcx>) -> PlaceSet<'tcx> {
PlaceSet::from_iter(place.interior_places(self.tcx, self.body, self.def_id))
}
pub fn conflicts(&self, place: Place<'tcx>) -> &PlaceSet<'tcx> {
self.conflicts_cache.get(&place, |place| {
let children = self.children(place);
let parents = place
.projection
.iter()
.enumerate()
.map(|(i, elem)| {
let place = PlaceRef {
local: place.local,
projection: &place.projection[.. i],
};
(place, elem)
})
.take_while(|(place, elem)| {
place.ty(self.body.local_decls(), self.tcx).ty.is_box()
|| !matches!(elem, PlaceElem::Deref)
})
.map(|(place_ref, _)| Place::from_ref(place_ref, self.tcx));
children.into_iter().chain(parents).collect()
})
}
pub fn reachable_values(
&self,
place: Place<'tcx>,
mutability: Mutability,
) -> &PlaceSet<'tcx> {
self.reachable_cache.get(&(place, mutability), |_| {
let ty = place.ty(self.body.local_decls(), self.tcx).ty;
let loans = self.collect_loans(ty, mutability);
loans
.into_iter()
.chain([place])
.filter(|place| {
if let Some((place, _)) = place.refs_in_projection(&self.body, self.tcx).last()
{
let ty = place.ty(self.body.local_decls(), self.tcx).ty;
if ty.is_box() || ty.is_unsafe_ptr() {
return true;
}
}
place.is_direct(self.body, self.tcx)
})
.collect()
})
}
fn collect_loans(&self, ty: Ty<'tcx>, mutability: Mutability) -> PlaceSet<'tcx> {
let mut collector = LoanCollector {
aliases: &self.aliases,
unknown_region: Region::new_var(self.tcx, UNKNOWN_REGION),
target_mutability: mutability,
stack: vec![],
loans: PlaceSet::default(),
};
collector.visit_ty(ty);
collector.loans
}
pub fn all_args(&self) -> impl Iterator<Item = (Place<'tcx>, LocationOrArgIndex)> + '_ {
self.body.args_iter().flat_map(|local| {
let location = local.to_index(&self.location_domain);
let place = Place::from_local(local, self.tcx);
let ptrs = place
.interior_pointers(self.tcx, self.body, self.def_id)
.into_values()
.flat_map(|ptrs| {
ptrs
.into_iter()
.filter(|(ptr, _)| ptr.projection.len() <= 2)
.map(|(ptr, _)| self.tcx.mk_place_deref(ptr))
});
ptrs
.chain([place])
.flat_map(|place| place.interior_places(self.tcx, self.body, self.def_id))
.map(move |place| (place, location))
})
}
pub fn location_domain(&self) -> &Rc<LocationOrArgDomain> {
&self.location_domain
}
}
/// Type visitor that gathers the places loaned out through the regions that
/// appear in a type. Only loans whose mutability is compatible with
/// `target_mutability` are kept.
struct LoanCollector<'a, 'tcx> {
  /// Alias information supplying the loans recorded for each region.
  aliases: &'a Aliases<'tcx>,
  /// Mutability the caller is asking about; loans are filtered against it.
  target_mutability: Mutability,
  /// Region visited in place of the (unnamed) region of boxes and raw pointers.
  unknown_region: Region<'tcx>,
  /// Mutabilities of the references currently being traversed, outermost first.
  stack: Vec<Mutability>,
  /// Places collected so far.
  loans: PlaceSet<'tcx>,
}
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for LoanCollector<'_, 'tcx> {
  type Result = ControlFlow<()>;

  /// Walks `ty` and visits every region inside it.
  ///
  /// The mutability of each reference is pushed onto `stack` while its contents
  /// are visited. Boxes and raw pointers are treated as carrying
  /// `unknown_region`.
  fn visit_ty(&mut self, ty: Ty<'tcx>) -> Self::Result {
    match ty.kind() {
      TyKind::Ref(_, _, mutability) => {
        self.stack.push(*mutability);
        let _ = ty.super_visit_with(self);
        self.stack.pop();
        // The reference is fully handled above, so skip the generic traversal
        // at the end of this function. Return `Continue`, not `Break`: parent
        // types (tuples, ADT generic args, arrays, ...) visit their components
        // with `try_visit!`. A `Break` would therefore silently drop every
        // component after this reference, e.g. the `&mut b` in `(&a, &mut b)`.
        return ControlFlow::Continue(());
      }
      _ if ty.is_box() || ty.is_unsafe_ptr() => {
        let _ = self.visit_region(self.unknown_region);
      }
      _ => {}
    }

    let _ = ty.super_visit_with(self);
    ControlFlow::Continue(())
  }

  /// Adds the loans recorded for `region` whose mutability is compatible with
  /// `target_mutability` to `self.loans`.
  ///
  /// Any loan reached while a shared (`Mutability::Not`) reference is on
  /// `stack` is treated as immutable. When the `IgnoreMut` mutability mode is
  /// active, every loan is kept.
  ///
  /// # Panics
  ///
  /// Panics on region kinds other than `ReVar`, `ReStatic`, `ReErased`, and
  /// `ReBound`.
  fn visit_region(&mut self, region: Region<'tcx>) -> Self::Result {
    let vid = match region.kind() {
      RegionKind::ReVar(vid) => vid,
      // NOTE(review): maps `'static` to region variable 0, presumably the
      // borrow checker's universal `'static` region — confirm.
      RegionKind::ReStatic => RegionVid::from_usize(0),
      RegionKind::ReErased | RegionKind::ReBound(..) => {
        return ControlFlow::Continue(());
      }
      _ => unreachable!("{region:?}"),
    };

    let Some(loans) = self.aliases.loans.get(&vid) else {
      return ControlFlow::Continue(());
    };

    let under_immut_ref = self.stack.iter().any(|m| *m == Mutability::Not);
    let ignore_mut =
      is_extension_active(|mode| mode.mutability_mode == MutabilityMode::IgnoreMut);
    self
      .loans
      .extend(loans.iter().filter_map(|(place, mutability)| {
        if ignore_mut {
          return Some(place);
        }
        let loan_mutability = if under_immut_ref {
          Mutability::Not
        } else {
          *mutability
        };
        self
          .target_mutability
          .is_permissive_as(loan_mutability)
          .then_some(place)
      }));

    ControlFlow::Continue(())
  }
}
#[cfg(test)]
mod test {
  use rustc_utils::{
    hashset,
    test_utils::{compare_sets, Placer},
  };

  use super::*;
  use crate::test_utils;

  /// Compiles `input`, builds a [`PlaceInfo`] for the compiled body, and passes
  /// the type context, the MIR body, and that `PlaceInfo` to `f`.
  fn placeinfo_harness(
    input: &str,
    f: impl for<'tcx> FnOnce(TyCtxt<'tcx>, &Body<'tcx>, PlaceInfo<'tcx>) + Send,
  ) {
    test_utils::compile_body(input, |tcx, body_id, body_with_facts| {
      let mir = &body_with_facts.body;
      let owner = tcx.hir().body_owner_def_id(body_id).to_def_id();
      f(tcx, mir, PlaceInfo::build(tcx, owner, body_with_facts))
    });
  }

  #[test]
  fn test_placeinfo_basic() {
    let input = r#"
fn main() {
let a = 0;
let mut b = 1;
let c = ((0, &a), &mut b);
let d = 0;
let e = &d;
let f = &e;
}
"#;
    placeinfo_harness(input, |tcx, body, place_info| {
      let p = Placer::new(tcx, body);
      let c = p.local("c");

      // `children` covers `c` and every field nested inside it.
      compare_sets(place_info.children(c.mk()), hashset! {
        c.mk(),
        c.field(0).mk(),
        c.field(0).field(0).mk(),
        c.field(0).field(1).mk(),
        c.field(1).mk(),
      });

      // `c.0` conflicts with its own children and with its parent `c`.
      compare_sets(place_info.conflicts(c.field(0).mk()), &hashset! {
        c.mk(),
        c.field(0).mk(),
        c.field(0).field(0).mk(),
        c.field(0).field(1).mk(),
      });

      let local_a = p.local("a");
      let local_b = p.local("b");

      // Read access through `c` reaches both borrowed locals.
      compare_sets(
        place_info.reachable_values(c.mk(), Mutability::Not),
        &hashset! {
          c.mk(),
          local_a.mk(),
          local_b.mk()
        },
      );

      // Write access through `c` only reaches the mutably borrowed `b`.
      compare_sets(
        place_info.reachable_values(c.mk(), Mutability::Mut),
        &hashset! {
          c.mk(),
          p.local("b").mk()
        },
      );

      // Nested references are followed transitively: `f -> e -> d`.
      let local_f = p.local("f");
      compare_sets(
        place_info.reachable_values(local_f.mk(), Mutability::Not),
        &hashset! {
          p.local("f").mk(),
          p.local("e").mk(),
          p.local("d").mk()
        },
      )
    });
  }
}