use itertools::Itertools;
use log::debug;
use rustc_middle::{
mir::{visit::Visitor, *},
ty::{AdtKind, TyKind},
};
use rustc_target::abi::FieldIdx;
use rustc_utils::{mir::place::PlaceCollector, AdtDefExt, OperandExt, PlaceExt};
use crate::mir::{
placeinfo::PlaceInfo,
utils::{self, AsyncHack},
};
/// How certain the analysis is that a [`Mutation`] actually happens.
///
/// This is a plain fieldless tag, so it derives the usual value-type traits so that callers
/// can copy it, compare it, and use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MutationStatus {
    /// The place is certainly written, e.g. the destination of an assignment or of a call.
    Definitely,
    /// The place may be written, e.g. a place reachable with mutable access from a call
    /// argument.
    Possibly,
}
/// Why a place is reported as mutated.
///
/// Like [`MutationStatus`], this is a small value type and derives the usual comparison and
/// hashing traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Reason {
    /// The place is reachable with mutable access from a call argument. The payload is the
    /// index of that argument among the place arguments the visitor considered.
    /// NOTE(review): this is not necessarily the argument's position in the call's `args`.
    Argument(u8),
    /// The place is the destination of an assignment or of a call.
    AssignTarget,
}
/// One place written at a MIR location, together with the places whose values may flow
/// into it.
#[derive(Debug)]
pub struct Mutation<'tcx> {
    /// The place being written.
    pub mutated: Place<'tcx>,
    /// Why `mutated` is written: it is an assignment/call destination, or it is reachable
    /// from a call argument.
    pub reason: Reason,
    /// Places whose values may flow into the new value of `mutated`.
    pub inputs: Vec<Place<'tcx>>,
    /// Whether the write certainly happens or only possibly happens.
    pub status: MutationStatus,
}
/// MIR visitor that reports, for each assignment and each function call, which places are
/// written and which places flow into them.
///
/// Nothing is accumulated internally. Every batch of [`Mutation`]s is handed to the callback
/// `f`, together with the [`Location`] of the statement or terminator that produced it.
pub struct ModularMutationVisitor<'a, 'tcx, F: FnMut(Location, Vec<Mutation<'tcx>>)> {
    /// Body-level context: the MIR body, `TyCtxt`, the body's `DefId`, and the
    /// `reachable_values` query the visitor relies on.
    place_info: &'a PlaceInfo<'tcx>,
    /// Callback that receives the mutations found at each location.
    f: F,
}
impl<'a, 'tcx, F> ModularMutationVisitor<'a, 'tcx, F>
where
F: FnMut(Location, Vec<Mutation<'tcx>>),
{
pub fn new(place_info: &'a PlaceInfo<'tcx>, f: F) -> Self {
ModularMutationVisitor { place_info, f }
}
}
impl<'tcx, F> Visitor<'tcx> for ModularMutationVisitor<'_, 'tcx, F>
where
F: FnMut(Location, Vec<Mutation<'tcx>>),
{
fn visit_assign(
&mut self,
mutated: &Place<'tcx>,
rvalue: &Rvalue<'tcx>,
location: Location,
) {
debug!("Checking {location:?}: {mutated:?} = {rvalue:?}");
let body = self.place_info.body;
let tcx = self.place_info.tcx;
match rvalue {
Rvalue::Aggregate(agg_kind, ops) => {
let info = match &**agg_kind {
AggregateKind::Adt(def_id, idx, substs, _, _) => {
let adt_def = tcx.adt_def(*def_id);
let variant = adt_def.variant(*idx);
let mutated = match adt_def.adt_kind() {
AdtKind::Enum => mutated.project_deeper(
&[ProjectionElem::Downcast(Some(variant.name), *idx)],
tcx,
),
AdtKind::Struct | AdtKind::Union => *mutated,
};
let fields = variant.fields.iter();
let tys = fields
.map(|field| field.ty(tcx, substs))
.collect::<Vec<_>>();
Some((mutated, tys))
}
AggregateKind::Tuple => {
let ty = rvalue.ty(body.local_decls(), tcx);
Some((*mutated, ty.tuple_fields().to_vec()))
}
AggregateKind::Closure(_, args) => {
let ty = args.as_closure().upvar_tys();
Some((*mutated, ty.to_vec()))
}
_ => None,
};
if let Some((mutated, tys)) = info {
if tys.len() > 0 {
let fields =
tys
.into_iter()
.enumerate()
.zip(ops.iter())
.map(|((i, ty), input_op)| {
let field = PlaceElem::Field(FieldIdx::from_usize(i), ty);
let input_place = input_op.as_place();
(mutated.project_deeper(&[field], tcx), input_place)
});
let mutations = fields
.map(|(mutated, input)| Mutation {
mutated,
reason: Reason::AssignTarget,
inputs: input.into_iter().collect::<Vec<_>>(),
status: MutationStatus::Definitely,
})
.collect::<Vec<_>>();
(self.f)(location, mutations);
return;
}
}
}
Rvalue::Use(Operand::Move(place) | Operand::Copy(place)) => {
let place_ty = place.ty(&body.local_decls, tcx).ty;
if let TyKind::Adt(adt_def, substs) = place_ty.kind() {
if adt_def.is_struct() {
let fields = adt_def
.all_visible_fields(self.place_info.def_id, self.place_info.tcx)
.enumerate()
.map(|(i, field_def)| {
PlaceElem::Field(FieldIdx::from_usize(i), field_def.ty(tcx, substs))
});
let mut mutations = fields
.map(|field| {
let mutated_field = mutated.project_deeper(&[field], tcx);
let input_field = place.project_deeper(&[field], tcx);
Mutation {
mutated: mutated_field,
reason: Reason::AssignTarget,
inputs: vec![input_field],
status: MutationStatus::Definitely,
}
})
.collect::<Vec<_>>();
if mutations.is_empty() {
mutations.push(Mutation {
mutated: *mutated,
reason: Reason::AssignTarget,
inputs: vec![*place],
status: MutationStatus::Definitely,
});
}
(self.f)(location, mutations);
return;
}
}
}
Rvalue::Ref(_, _, place) => {
let inputs = place
.refs_in_projection(self.place_info.body, self.place_info.tcx)
.map(|(place_ref, _)| Place::from_ref(place_ref, tcx))
.collect::<Vec<_>>();
(self.f)(location, vec![Mutation {
mutated: *mutated,
reason: Reason::AssignTarget,
inputs,
status: MutationStatus::Definitely,
}]);
return;
}
_ => {}
}
let mut collector = PlaceCollector::default();
collector.visit_rvalue(rvalue, location);
(self.f)(location, vec![Mutation {
mutated: *mutated,
reason: Reason::AssignTarget,
inputs: collector.0,
status: MutationStatus::Definitely,
}]);
}
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
debug!("Checking {location:?}: {:?}", terminator.kind);
let tcx = self.place_info.tcx;
match &terminator.kind {
TerminatorKind::Call {
args,
destination,
..
} => {
let async_hack = AsyncHack::new(
self.place_info.tcx,
self.place_info.body,
self.place_info.def_id,
);
let arg_places = utils::arg_places(args)
.into_iter()
.map(|(_, place)| place)
.filter(|place| !async_hack.ignore_place(*place))
.collect::<Vec<_>>();
let arg_inputs = arg_places
.iter()
.flat_map(|arg| self.place_info.reachable_values(*arg, Mutability::Not))
.copied()
.collect_vec();
let ret_is_unit = destination
.ty(self.place_info.body.local_decls(), tcx)
.ty
.is_unit();
let dest_inputs = if ret_is_unit {
Vec::new()
} else {
arg_inputs.clone()
};
let mut mutations = vec![Mutation {
mutated: *destination,
inputs: dest_inputs,
reason: Reason::AssignTarget,
status: MutationStatus::Definitely,
}];
for (num, arg) in arg_places.into_iter().enumerate() {
for arg_mut in self.place_info.reachable_values(arg, Mutability::Mut) {
if *arg_mut != arg {
mutations.push(Mutation {
mutated: *arg_mut,
reason: Reason::Argument(num as u8),
inputs: arg_inputs.clone(),
status: MutationStatus::Possibly,
});
}
}
}
(self.f)(location, mutations);
}
_ => {}
}
}
}