use std::{cell::RefCell, rc::Rc};
use indexical::impls::RustcIndexMatrix as IndexMatrix;
use log::{debug, trace};
use rustc_data_structures::fx::FxHashMap as HashMap;
use rustc_hir::{def_id::DefId, BodyId};
use rustc_middle::{
mir::{visit::Visitor, *},
ty::TyCtxt,
};
use rustc_mir_dataflow::Analysis;
use rustc_utils::{
mir::{
control_dependencies::ControlDependencies,
location_or_arg::{
index::{LocationOrArgDomain, LocationOrArgSet},
LocationOrArg,
},
},
BodyExt, OperandExt, PlaceExt,
};
use smallvec::SmallVec;
use super::{
mutation::{ModularMutationVisitor, Mutation, MutationStatus},
FlowResults,
};
use crate::{
extensions::{is_extension_active, ContextMode, MutabilityMode},
mir::placeinfo::PlaceInfo,
};
/// Flow state of the analysis: a matrix with one row per `Place` and one
/// column per `LocationOrArg`.
///
/// Row `p` holds the locations/arguments recorded as dependencies of `p`.
/// All reads and writes in this file index rows by
/// `place_info.normalize(place)` rather than by the raw place.
pub type FlowDomain<'tcx> = IndexMatrix<Place<'tcx>, LocationOrArg>;
/// Dataflow analysis that records, for every place of `body`, the set of
/// MIR locations and function arguments that place depends on (see
/// [`FlowDomain`]).
pub struct FlowAnalysis<'tcx> {
    /// Compiler type context, used for place type queries and projections.
    pub tcx: TyCtxt<'tcx>,
    /// `DefId` supplied alongside `body` at construction.
    /// NOTE(review): presumably the item that owns `body` — confirm at call sites.
    pub def_id: DefId,
    /// The MIR body being analyzed.
    pub body: &'tcx Body<'tcx>,
    /// Alias / conflict / children information about the places of `body`;
    /// also provides place normalization and the location domain.
    pub place_info: PlaceInfo<'tcx>,
    /// Control dependencies between the basic blocks of `body`, computed once
    /// in `FlowAnalysis::new`.
    pub(crate) control_dependencies: ControlDependencies<BasicBlock>,
    /// Cache of per-body flow results keyed by `BodyId`; starts empty.
    /// NOTE(review): not read or written in this file — presumably filled by
    /// `recurse_into_call` when the `Recurse` context mode is active; confirm.
    pub(crate) recurse_cache: RefCell<HashMap<BodyId, FlowResults<'tcx>>>,
}
/// Inherent helpers of [`FlowAnalysis`]: construction, alias-aware dependency
/// lookups, and the transfer function that applies a batch of mutations to
/// the flow state.
impl<'tcx> FlowAnalysis<'tcx> {
    /// Creates the analysis for `body`.
    ///
    /// The body's control dependencies are computed here (and logged at debug
    /// level); `recurse_cache` starts out empty.
    pub fn new(
        tcx: TyCtxt<'tcx>,
        def_id: DefId,
        body: &'tcx Body<'tcx>,
        place_info: PlaceInfo<'tcx>,
    ) -> Self {
        let recurse_cache = RefCell::new(HashMap::default());
        let control_dependencies = body.control_dependencies();
        debug!("Control dependencies: {control_dependencies:?}");
        FlowAnalysis {
            tcx,
            def_id,
            body,
            place_info,
            control_dependencies,
            recurse_cache,
        }
    }

    /// Returns the location/argument domain from which every dependency set and
    /// flow-state matrix in this analysis is built (delegates to `place_info`).
    pub fn location_domain(&self) -> &Rc<LocationOrArgDomain> {
        self.place_info.location_domain()
    }

    /// Returns the aliases of every place ref yielded by
    /// `place.refs_in_projection(..)`.
    ///
    /// NOTE(review): presumably those refs are the references `place` is
    /// reached through (e.g. `x` for `*x`) — confirm against `PlaceExt`.
    /// The result is not deduplicated.
    fn provenance(&self, place: Place<'tcx>) -> SmallVec<[Place<'tcx>; 8]> {
        place
            .refs_in_projection(self.body, self.tcx)
            .flat_map(|(place_ref, _)| {
                self
                    .place_info
                    .aliases(Place::from_ref(place_ref, self.tcx))
            })
            .copied()
            .collect()
    }

    /// Returns the places whose flow-state rows are consulted when `place` is
    /// read: the conflicts of every alias of `place` together with the
    /// conflicts of every place in `provenance(place)`.
    /// The result is not deduplicated.
    fn influences(&self, place: Place<'tcx>) -> SmallVec<[Place<'tcx>; 8]> {
        // Despite its name, this iterator yields the *aliases* of `place`;
        // conflicts are only expanded by the `flat_map` below.
        let conflicts = self.place_info.aliases(place).iter().copied();
        conflicts
            .chain(self.provenance(place))
            .flat_map(|alias| self.place_info.conflicts(alias))
            .copied()
            .collect()
    }

    /// Returns every location/argument that `place` depends on in `state`.
    ///
    /// Each value reachable from `place` through immutable access
    /// (`reachable_values(place, Mutability::Not)`) is expanded through
    /// `influences`, and the rows of all resulting (normalized) places are
    /// unioned into the returned set.
    pub fn deps_for(
        &self,
        state: &FlowDomain<'tcx>,
        place: Place<'tcx>,
    ) -> LocationOrArgSet {
        let mut deps = LocationOrArgSet::new(self.location_domain());
        for subplace in self
            .place_info
            .reachable_values(place, Mutability::Not)
            .iter()
            .flat_map(|place| self.influences(*place))
        {
            deps.union(state.row_set(&self.place_info.normalize(subplace)));
        }
        deps
    }

    /// Applies `mutations`, all occurring at `location`, to `state`.
    ///
    /// For each mutation a dependency set is assembled from:
    /// 1. `location` itself;
    /// 2. the dependencies (via `influences`) of each of the mutation's `inputs`;
    /// 3. the terminator location of every block that `location.block` is
    ///    control-dependent on, plus the dependencies of that terminator's
    ///    `SwitchInt` discriminant when the discriminant is a place;
    /// 4. the rows of the conflicts of each place in `provenance(mt.mutated)`.
    ///
    /// The set is then unioned into the row of every alias of the mutated
    /// place, skipping aliases whose projection passes through a shared
    /// (`Mutability::Not`) reference unless the `IgnoreMut` mutability mode is
    /// active. Before step 4, a `Definitely` mutation of a place that has
    /// exactly one alias clears the rows of the mutated place's children.
    pub(crate) fn transfer_function(
        &self,
        state: &mut FlowDomain<'tcx>,
        mutations: Vec<Mutation<'tcx>>,
        location: Location,
    ) {
        debug!(" Applying mutations {mutations:?}");
        let location_domain = self.location_domain();

        // One dependency set per mutation (paired by index with `mutations`),
        // each seeded with the current location.
        let mut all_deps = {
            let mut deps = LocationOrArgSet::new(location_domain);
            deps.insert(location);
            vec![deps; mutations.len()]
        };

        // Unions the rows of every place in `influences(input)` into
        // `target_deps`. `state` is passed in explicitly (rather than captured)
        // so the closure does not hold a borrow of it.
        let add_deps = |state: &FlowDomain<'tcx>,
                        input,
                        target_deps: &mut LocationOrArgSet| {
            for relevant in self.influences(input) {
                let relevant_deps = state.row_set(&self.place_info.normalize(relevant));
                trace!(" For relevant {relevant:?} for input {input:?} adding deps {relevant_deps:?}");
                target_deps.union(relevant_deps);
            }
        };

        // Explicit data inputs of each mutation.
        for (mt, deps) in mutations.iter().zip(&mut all_deps) {
            for input in &mt.inputs {
                add_deps(state, *input, deps);
            }
        }

        // Control dependencies apply to every mutation at this location.
        // `dependent_on` may yield nothing, hence the `into_iter().flat_map`.
        let controlled_by = self.control_dependencies.dependent_on(location.block);
        let body = self.body;
        for block in controlled_by.into_iter().flat_map(|set| set.iter()) {
            for deps in &mut all_deps {
                deps.insert(body.terminator_loc(block));
            }

            // Also depend on whatever the branch's discriminant depends on.
            let terminator = body.basic_blocks[block].terminator();
            if let TerminatorKind::SwitchInt { discr, .. } = &terminator.kind {
                if let Some(discr_place) = discr.as_place() {
                    for deps in &mut all_deps {
                        add_deps(state, discr_place, deps);
                    }
                }
            }
        }

        let ignore_mut =
            is_extension_active(|mode| mode.mutability_mode == MutabilityMode::IgnoreMut);
        for (mt, deps) in mutations.iter().zip(&mut all_deps) {
            // Strong update: only for a `Definitely` mutation whose place has a
            // single alias. This runs before the provenance rows are read
            // below, so any row cleared here contributes nothing to `deps`.
            if matches!(mt.status, MutationStatus::Definitely)
                && self.place_info.aliases(mt.mutated).len() == 1
            {
                for sub in self.place_info.children(mt.mutated).iter() {
                    state.clear_row(&self.place_info.normalize(*sub));
                }
            }

            // Depend on the places the mutated place is reached through.
            for place in self.provenance(mt.mutated) {
                for conflict in self.place_info.conflicts(place) {
                    deps.union(state.row_set(&self.place_info.normalize(*conflict)));
                }
            }

            // Keep only aliases that are not reached through a shared
            // reference (`ref_mutability() == Some(Mutability::Not)`), unless
            // the `IgnoreMut` mode asks to keep all of them.
            let mutable_aliases = self
                .place_info
                .aliases(mt.mutated)
                .iter()
                .filter(|alias| {
                    let has_immut = alias.iter_projections().any(|(sub_place, _)| {
                        let ty = sub_place.ty(body.local_decls(), self.tcx).ty;
                        matches!(ty.ref_mutability(), Some(Mutability::Not))
                    });
                    !has_immut || ignore_mut
                })
                .collect::<SmallVec<[_; 8]>>();

            debug!(" Mutated places: {mutable_aliases:?}");
            debug!(" with deps {deps:?}");

            // Union (rather than overwrite) the dependencies into each alias.
            for alias in mutable_aliases {
                state.union_into_row(self.place_info.normalize(*alias), deps);
            }
        }
    }
}
/// Hooks `FlowAnalysis` into the `rustc_mir_dataflow` forward-analysis
/// framework: the state is a [`FlowDomain`], arguments seed the entry state,
/// and every statement/terminator is translated into mutations that are fed
/// to `FlowAnalysis::transfer_function`.
impl<'tcx> Analysis<'tcx> for FlowAnalysis<'tcx> {
    type Domain = FlowDomain<'tcx>;

    const NAME: &'static str = "FlowAnalysis";

    /// Bottom element: an empty matrix over this body's location domain, so
    /// no place has any recorded dependency yet.
    fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain {
        let domain = self.location_domain();
        FlowDomain::new(domain)
    }

    /// Seeds the entry state: for every argument reported by
    /// `place_info.all_args()`, each place conflicting with that argument gets
    /// the argument's own location-or-arg entry in its (normalized) row.
    fn initialize_start_block(&self, _body: &Body<'tcx>, state: &mut Self::Domain) {
        let place_info = &self.place_info;
        for (arg, loc) in place_info.all_args() {
            place_info.conflicts(arg).into_iter().for_each(|place| {
                debug!(
                    "arg={arg:?} / place={place:?} / loc={:?}",
                    self.location_domain().value(loc)
                );
                state.insert(place_info.normalize(*place), loc);
            });
        }
    }

    /// Runs the mutation visitor over `statement` and applies each batch of
    /// mutations it reports at `location`.
    fn apply_primary_statement_effect(
        &mut self,
        state: &mut Self::Domain,
        statement: &Statement<'tcx>,
        location: Location,
    ) {
        let this = &*self;
        let mut visitor = ModularMutationVisitor::new(&this.place_info, |_, mutations| {
            this.transfer_function(state, mutations, location)
        });
        visitor.visit_statement(statement, location);
    }

    /// Applies the effect of `terminator` and returns its outgoing edges.
    ///
    /// When the terminator is a `Call` and the `Recurse` context mode is
    /// active, `recurse_into_call` is tried first; if it reports success the
    /// mutation visitor is skipped. Otherwise the terminator's mutations are
    /// applied through `transfer_function` like any statement.
    fn apply_primary_terminator_effect<'mir>(
        &mut self,
        state: &mut Self::Domain,
        terminator: &'mir Terminator<'tcx>,
        location: Location,
    ) -> TerminatorEdges<'mir, 'tcx> {
        let handled_by_recursion = matches!(terminator.kind, TerminatorKind::Call { .. })
            && is_extension_active(|mode| mode.context_mode == ContextMode::Recurse)
            && self.recurse_into_call(state, &terminator.kind, location);

        if !handled_by_recursion {
            let this = &*self;
            ModularMutationVisitor::new(&this.place_info, |_, mutations| {
                this.transfer_function(state, mutations, location)
            })
            .visit_terminator(terminator, location);
        }

        terminator.edges()
    }

    /// Intentionally does nothing. This analysis applies no separate effect
    /// on the call-return edge; all of a call's effects are applied in
    /// `apply_primary_terminator_effect`.
    /// NOTE(review): presumably the mutation visitor there already records the
    /// write to the call's destination — confirm.
    fn apply_call_return_effect(
        &mut self,
        _state: &mut Self::Domain,
        _block: BasicBlock,
        _return_places: CallReturnPlaces<'_, 'tcx>,
    ) {
    }
}