use std::{hash::Hash, time::Instant};
use log::{debug, info};
use polonius_engine::AllFacts;
use rustc_borrowck::consumers::{BodyWithBorrowckFacts, PoloniusInput};
use rustc_data_structures::{
fx::{FxHashMap as HashMap, FxHashSet as HashSet},
graph::{iterate::reverse_post_order, scc::Sccs, vec_graph::VecGraph},
intern::Interned,
};
use rustc_hir::def_id::DefId;
use rustc_index::{
bit_set::{ChunkedBitSet, SparseBitMatrix},
IndexVec,
};
use rustc_middle::{
mir::{visit::Visitor, *},
ty::{Region, RegionKind, RegionVid, Ty, TyCtxt, TyKind},
};
use rustc_utils::{mir::place::UNKNOWN_REGION, timer::elapsed, PlaceExt};
use super::FlowistryInput;
use crate::{
extensions::{is_extension_active, PointerMode},
mir::utils::{AsyncHack, PlaceSet},
};
/// MIR visitor that collects the borrows (`&place` rvalues) appearing in a body.
///
/// Only borrows whose region is an inference variable (`ReVar`) are kept; see the
/// `Visitor` impl for this type.
#[derive(Default)]
struct GatherBorrows<'tcx> {
    /// One entry per collected borrow: its region, its borrow kind, and the
    /// place being borrowed.
    borrows: Vec<(RegionVid, BorrowKind, Place<'tcx>)>,
}
/// Pattern that matches a `Region` whose kind is `ReVar`, binding the inner
/// `RegionVid` to the identifier supplied by the caller.
macro_rules! region_pat {
    ($vid:ident) => (
        Region(Interned(RegionKind::ReVar($vid), _))
    );
}
impl<'tcx> Visitor<'tcx> for GatherBorrows<'tcx> {
fn visit_assign(
&mut self,
_place: &Place<'tcx>,
rvalue: &Rvalue<'tcx>,
_location: Location,
) {
if let Rvalue::Ref(region_pat!(region), kind, borrowed_place) = rvalue {
self.borrows.push((*region, *kind, *borrowed_place));
}
}
}
/// A set of places, each paired with the `Mutability` it was recorded with.
type LoanSet<'tcx> = HashSet<(Place<'tcx>, Mutability)>;
/// Maps each region variable to the set of places a pointer with that region
/// may refer to.
type LoanMap<'tcx> = HashMap<RegionVid, LoanSet<'tcx>>;
/// Result of the alias analysis for one MIR body.
///
/// Built by `Aliases::build`; queried with `Aliases::aliases`.
pub struct Aliases<'tcx> {
    /// For every region, the places (with mutability) that a pointer carrying
    /// that region may point to.
    pub(super) loans: LoanMap<'tcx>,
    /// The body the analysis was run on.
    body: &'tcx Body<'tcx>,
    /// Type context used for type and place computations during queries.
    tcx: TyCtxt<'tcx>,
}
// Index type for the strongly-connected components of the region subset
// graph; used as the SCC index in `Sccs::<RegionVid, RegionSccIndex>`.
// Values print as `rs<N>` in debug output.
rustc_index::newtype_index! {
#[orderable]
#[debug_format = "rs{}"]
struct RegionSccIndex {}
}
impl<'tcx> Aliases<'tcx> {
pub fn build<'a>(
tcx: TyCtxt<'tcx>,
def_id: DefId,
input: impl FlowistryInput<'tcx, 'a>,
) -> Self {
let loans = Self::compute_loans(tcx, def_id, input);
Aliases {
tcx,
body: input.body(),
loans,
}
}
fn compute_loans<'a>(
tcx: TyCtxt<'tcx>,
def_id: DefId,
input: impl FlowistryInput<'tcx, 'a>,
) -> LoanMap<'tcx> {
let start = Instant::now();
let body = input.body();
let static_region = RegionVid::from_usize(0);
let all_pointers = body
.local_decls()
.indices()
.flat_map(|local| {
Place::from_local(local, tcx).interior_pointers(tcx, body, def_id)
})
.collect::<Vec<_>>();
let max_region = all_pointers
.iter()
.map(|(region, _)| *region)
.chain(
input
.input_facts_subset_base()
.flat_map(|(r1, r2)| [r1, r2]),
)
.filter(|r| *r != UNKNOWN_REGION)
.max()
.unwrap_or(static_region);
let num_regions = max_region.as_usize() + 1;
let all_regions = (0 .. num_regions).map(RegionVid::from_usize);
let mut subset = SparseBitMatrix::new(num_regions);
let async_hack = AsyncHack::new(tcx, body, def_id);
let ignore_regions = async_hack.ignore_regions();
for (a, b) in input.input_facts_subset_base() {
if ignore_regions.contains(&a) || ignore_regions.contains(&b) {
continue;
}
subset.insert(a, b);
}
for a in all_regions.clone() {
subset.insert(static_region, a);
}
if is_extension_active(|mode| mode.pointer_mode == PointerMode::Conservative) {
let mut region_to_pointers: HashMap<_, Vec<_>> = HashMap::default();
for (region, places) in &all_pointers {
if *region != UNKNOWN_REGION {
region_to_pointers
.entry(*region)
.or_default()
.extend(places);
}
}
let constraints = generate_conservative_constraints(tcx, body, ®ion_to_pointers);
for (a, b) in constraints {
subset.insert(a, b);
}
}
let mut contains: LoanMap<'tcx> = HashMap::default();
let mut definite: HashMap<RegionVid, (Ty<'tcx>, Vec<PlaceElem<'tcx>>)> =
HashMap::default();
let mut gather_borrows = GatherBorrows::default();
gather_borrows.visit_body(body);
for (region, kind, place) in gather_borrows.borrows {
if place.is_direct(body, tcx) {
contains
.entry(region)
.or_default()
.insert((place, kind.to_mutbl_lossy()));
}
let def = match place.refs_in_projection(body, tcx).next() {
Some((ptr, proj)) => {
let ptr_ty = ptr.ty(body.local_decls(), tcx).ty;
(ptr_ty.builtin_deref(true).unwrap(), proj.to_vec())
}
None => (
body.local_decls()[place.local].ty,
place.projection.to_vec(),
),
};
definite.insert(region, def);
}
for arg in body.args_iter() {
for (region, places) in
Place::from_local(arg, tcx).interior_pointers(tcx, body, def_id)
{
let region_contains = contains.entry(region).or_default();
for (place, mutability) in places {
if place.projection.len() <= 2 {
region_contains.insert((tcx.mk_place_deref(place), mutability));
}
}
}
}
let unk_contains = contains.entry(UNKNOWN_REGION).or_default();
for (region, places) in &all_pointers {
if *region == UNKNOWN_REGION {
for (place, _) in places {
unk_contains.insert((tcx.mk_place_deref(*place), Mutability::Mut));
}
}
}
info!(
"Initial places in loan set: {}, total regions {}, definite regions: {}",
contains.values().map(|set| set.len()).sum::<usize>(),
contains.len(),
definite.len()
);
debug!("Initial contains: {contains:#?}");
debug!("Definite: {definite:#?}");
let edge_pairs = subset
.rows()
.flat_map(|r1| subset.iter(r1).map(move |r2| (r1, r2)))
.collect::<Vec<_>>();
let subset_graph = VecGraph::<_, false>::new(num_regions, edge_pairs);
let subset_sccs = Sccs::<RegionVid, RegionSccIndex>::new(&subset_graph);
let mut scc_to_regions = IndexVec::from_elem_n(
ChunkedBitSet::new_empty(num_regions),
subset_sccs.num_sccs(),
);
for r in all_regions.clone() {
let scc = subset_sccs.scc(r);
scc_to_regions[scc].insert(r);
}
let scc_order = reverse_post_order(&subset_sccs, subset_sccs.scc(static_region));
elapsed("relation construction", start);
let start = Instant::now();
for r in all_regions {
contains.entry(r).or_default();
}
for scc_idx in scc_order {
loop {
let mut changed = false;
let scc = &scc_to_regions[scc_idx];
for a in scc.iter() {
for b in subset.iter(a) {
if a == b {
continue;
}
let a_contains =
unsafe { &*(contains.get(&a).unwrap() as *const LoanSet<'tcx>) };
let b_contains =
unsafe { &mut *(contains.get_mut(&b).unwrap() as *mut LoanSet<'tcx>) };
let cyclic = scc.contains(b);
match definite.get(&b) {
Some((ty, proj)) if !cyclic => {
for (p, mutability) in a_contains.iter() {
let p_ty = p.ty(body.local_decls(), tcx).ty;
let p_proj = if *ty == p_ty {
let mut full_proj = p.projection.to_vec();
full_proj.extend(proj);
Place::make(p.local, tcx.mk_place_elems(&full_proj), tcx)
} else {
*p
};
changed |= b_contains.insert((p_proj, *mutability));
}
}
_ => {
let orig_len = b_contains.len();
b_contains.extend(a_contains);
changed |= b_contains.len() != orig_len;
}
}
}
}
if !changed {
break;
}
}
}
elapsed("fixpoint", start);
info!(
"Final places in loan set: {}",
contains.values().map(|set| set.len()).sum::<usize>()
);
log::trace!("contains: {contains:#?}");
contains
}
pub fn aliases(&self, place: Place<'tcx>) -> PlaceSet<'tcx> {
let mut aliases = HashSet::default();
if place.is_arg(self.body) {
aliases.insert(place);
return aliases;
}
let Some((ptr, after)) = place.refs_in_projection(&self.body, self.tcx).last() else {
aliases.insert(place);
return aliases;
};
let ptr_ty = ptr.ty(self.body.local_decls(), self.tcx).ty;
let (region, orig_ty) = match ptr_ty.kind() {
_ if ptr_ty.is_box() => (
UNKNOWN_REGION,
ptr_ty.boxed_ty().expect("Could not unbox boxed type??"),
),
TyKind::RawPtr(ty, _) => (UNKNOWN_REGION, *ty),
TyKind::Ref(Region(Interned(RegionKind::ReVar(region), _)), ty, _) => {
(*region, *ty)
}
_ => return aliases,
};
let region_loans = self
.loans
.get(®ion)
.map(|loans| loans.iter())
.into_iter()
.flatten();
let region_aliases = region_loans.map(|(loan, _)| {
let loan_ty = loan.ty(self.body.local_decls(), self.tcx).ty;
if orig_ty == loan_ty {
let mut projection = loan.projection.to_vec();
projection.extend(after.iter().copied());
Place::make(loan.local, &projection, self.tcx)
} else {
*loan
}
});
aliases.extend(region_aliases);
log::trace!("Aliases for place {place:?} are {aliases:?}");
aliases
}
}
/// Builds bidirectional subset edges between regions whose pointers share a
/// pointee type.
///
/// Two distinct regions are linked when at least one pointer of the first and
/// at least one pointer of the second dereference to the same type. Each linked
/// pair is emitted as both `(r1, r2)` and `(r2, r1)`. The outer loop also visits
/// the pair in the opposite order, so every edge appears twice in the result.
fn generate_conservative_constraints<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    region_to_pointers: &HashMap<RegionVid, Vec<(Place<'tcx>, Mutability)>>,
) -> Vec<(RegionVid, RegionVid)> {
    // Type of the value that a pointer place points to.
    let pointee_ty = |ptr: Place<'tcx>| tcx.mk_place_deref(ptr).ty(body.local_decls(), tcx).ty;

    let mut edges = Vec::new();
    for (lhs_region, lhs_ptrs) in region_to_pointers {
        for (rhs_region, rhs_ptrs) in region_to_pointers {
            if lhs_region == rhs_region {
                continue;
            }
            let shares_pointee = lhs_ptrs.iter().any(|&(lhs_ptr, _)| {
                rhs_ptrs
                    .iter()
                    .any(|&(rhs_ptr, _)| pointee_ty(lhs_ptr) == pointee_ty(rhs_ptr))
            });
            if shares_pointee {
                edges.push((*lhs_region, *rhs_region));
                edges.push((*rhs_region, *lhs_region));
            }
        }
    }
    edges
}
#[cfg(test)]
mod test {
    use rustc_utils::{
        hashset,
        test_utils::{compare_sets, Placer},
    };

    use super::*;
    use crate::test_utils;

    /// Compiles `input`, runs the alias analysis on the resulting body, and
    /// passes the type context, the body and the analysis result to `f`.
    fn alias_harness(
        input: &str,
        f: impl for<'tcx> FnOnce(TyCtxt<'tcx>, &Body<'tcx>, Aliases<'tcx>) + Send,
    ) {
        test_utils::compile_body(input, |tcx, body_id, body_with_facts| {
            let mir = &body_with_facts.body;
            let owner = tcx.hir().body_owner_def_id(body_id).to_def_id();
            let analysis = Aliases::build(tcx, owner, body_with_facts);
            f(tcx, mir, analysis)
        });
    }

    /// A returned reference aliases the argument it is tied to by lifetime.
    #[test]
    fn test_aliases_basic() {
        let input = r#"
fn main() {
fn foo<'a, 'b>(x: &'a i32, y: &'b i32) -> &'a i32 { x }
let a = 1;
let b = 2;
let c = &a;
let d = &b;
let e = foo(c, d);
}
"#;
        alias_harness(input, |tcx, body, aliases| {
            let placer = Placer::new(tcx, body);
            let star_e = placer.local("e").deref().mk();
            let star_d = placer.local("d").deref().mk();

            let expected_for_e = hashset! { placer.local("a").mk()};
            compare_sets(aliases.aliases(star_e), expected_for_e);

            let expected_for_d = hashset! { placer.local("b").mk() };
            compare_sets(aliases.aliases(star_d), expected_for_d);
        });
    }

    /// Borrows through method calls and field projections resolve to the
    /// right base place / field.
    #[test]
    fn test_aliases_projection() {
        let input = r#"
fn main() {
let a = vec![0];
let b = a.get(0).unwrap();
let c = (0, 0);
let d = &c.1;
}
"#;
        alias_harness(input, |tcx, body, aliases| {
            let placer = Placer::new(tcx, body);
            let star_b = placer.local("b").deref().mk();
            let star_d = placer.local("d").deref().mk();

            let expected_for_b = hashset! { placer.local("a").mk() };
            compare_sets(aliases.aliases(star_b), expected_for_b);

            let expected_for_d = hashset! { placer.local("c").field(1).mk() };
            compare_sets(aliases.aliases(star_d), expected_for_d);
        });
    }
}