use std::{rc::Rc, time::Instant};
use self::call_string_resolver::CallStringResolver;
use super::{default_index, path_for_item, src_loc_for_span, SPDGGenerator};
use crate::{
ann::MarkerAnnotation, desc::*, discover::FnToAnalyze, stats::TimedStat, utils::*, HashMap,
HashSet, MarkerCtx,
};
use flowistry_pdg::SourceUse;
use flowistry_pdg_construction::{
body_cache::BodyCache,
graph::{DepEdge, DepEdgeKind, DepGraph, DepNode},
is_async_trait_fn, match_async_trait_assign,
utils::{handle_shims, try_monomorphize, try_resolve_function, type_as_fn},
};
use paralegal_spdg::{Node, SPDGStats};
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_middle::{
mir,
ty::{self, TyCtxt, TypingEnv},
};
use anyhow::Result;
use either::Either;
use flowistry::mir::FlowistryInput;
use petgraph::{
visit::{IntoNodeReferences, NodeIndexable, NodeRef},
Direction,
};
pub struct GraphConverter<'tcx, 'a, C> {
generator: &'a SPDGGenerator<'tcx>,
target: &'a FnToAnalyze,
dep_graph: Rc<DepGraph<'tcx>>,
def_id: DefId,
known_def_ids: &'a mut C,
types: HashMap<Node, Vec<DefId>>,
index_map: Box<[Node]>,
spdg: SPDGImpl,
marker_assignments: HashMap<Node, HashSet<Identifier>>,
call_string_resolver: call_string_resolver::CallStringResolver<'tcx, 'a>,
stats: SPDGStats,
}
impl<'a, 'tcx, C: Extend<DefId>> GraphConverter<'tcx, 'a, C> {
pub fn new_with_flowistry(
generator: &'a SPDGGenerator<'tcx>,
known_def_ids: &'a mut C,
target: &'a FnToAnalyze,
) -> Result<Self> {
let local_def_id = target.def_id;
let (dep_graph, stats) = generator.stats.measure(TimedStat::Flowistry, || {
Self::create_flowistry_graph(generator, local_def_id)
})?;
if generator.opts.dbg().dump_flowistry_pdg() {
dep_graph.generate_graphviz(format!(
"{}.flowistry-pdg.pdf",
generator.tcx.def_path_str(target.def_id)
))?
}
let def_id = local_def_id.to_def_id();
Ok(Self {
generator,
known_def_ids,
target,
index_map: vec![default_index(); dep_graph.graph.node_bound()].into(),
dep_graph: dep_graph.into(),
def_id,
types: Default::default(),
spdg: Default::default(),
marker_assignments: Default::default(),
call_string_resolver: CallStringResolver::new(
generator.tcx,
def_id,
generator.pdg_constructor.body_cache(),
generator.marker_ctx().clone(),
),
stats,
})
}
/// The type context of the generator this converter belongs to.
fn tcx(&self) -> TyCtxt<'tcx> {
    let generator = self.generator;
    generator.tcx
}
/// The marker context of the generator this converter belongs to.
fn marker_ctx(&self) -> &MarkerCtx<'tcx> {
    let generator = self.generator;
    generator.marker_ctx()
}
/// Whether the function this converter was created for is asynchronous, as
/// decided by the free function of the same name.
fn entrypoint_is_async(&self) -> bool {
    let (body_cache, tcx) = (self.body_cache(), self.tcx());
    entrypoint_is_async(body_cache, tcx, self.def_id)
}
/// Add `new` to the SPDG and remember it as the counterpart of the input
/// graph node `old`. Returns the freshly created SPDG node.
///
/// # Panics
/// Panics if `old` already has a counterpart.
fn register_node(&mut self, old: Node, new: NodeInfo) -> Node {
    let position = old.index();
    let fresh = self.spdg.add_node(new);
    assert_eq!(self.index_map[position], default_index());
    self.index_map[position] = fresh;
    fresh
}
/// Look up the SPDG node that was registered for the input graph node `old`.
///
/// # Panics
/// Panics if `old` has not been registered yet.
fn new_node_for(&self, old: Node) -> Node {
    let mapped = self.index_map[old.index()];
    assert_ne!(mapped, default_index());
    mapped
}
/// Attach `markers` to `node`.
///
/// No entry is created for `node` when `markers` yields nothing.
fn register_markers(&mut self, node: Node, markers: impl IntoIterator<Item = Identifier>) {
    let mut pending = markers.into_iter();
    // Pull the first element up front so an empty input leaves the map alone.
    let Some(first) = pending.next() else {
        return;
    };
    self.marker_assignments
        .entry(node)
        .or_default()
        .extend(std::iter::once(first).chain(pending));
}
/// Assign markers to the SPDG node registered for `old_node`.
///
/// Where the markers come from depends on the leaf location of the node:
/// * an argument local at the start of a function receives that function's
///   argument annotations for the matching argument index,
/// * the return place at the end of a function receives that function's
///   return annotations,
/// * a node located at a `Call` terminator receives annotations of the called
///   function: return annotations if the node has no incoming data edge or
///   has one created at this very location that targets the return value,
///   and argument annotations for every outgoing edge whose source use is an
///   argument.
///
/// Every function whose annotations are consulted is also added to
/// `known_def_ids`. Nodes at any other location are left untouched.
///
/// # Panics
/// Panics if `old_node` has not been registered, or if the callee of a call
/// terminator cannot be monomorphized into a function constant.
fn node_annotations(&mut self, old_node: Node, weight: &DepNode<'tcx>) {
    let leaf_loc = weight.at.leaf();
    let node = self.new_node_for(old_node);
    let body = self.body_cache().get(leaf_loc.function).body();
    // Keep a separate handle on the graph so its edges can be walked while
    // `self` is borrowed mutably for marker registration.
    let graph = self.dep_graph.clone();
    match leaf_loc.location {
        RichLocation::Start
            if matches!(body.local_kind(weight.place.local), mir::LocalKind::Arg) =>
        {
            let function_id = leaf_loc.function;
            // Shift by one to turn the local's index into the argument index
            // the annotations are matched against.
            let arg_num = weight.place.local.as_u32() - 1;
            self.known_def_ids.extend(Some(function_id));
            self.register_annotations_for_function(node, function_id, |ann| {
                ann.refinement.on_argument().contains(arg_num).unwrap()
            });
        }
        RichLocation::End if weight.place.local == mir::RETURN_PLACE => {
            let function_id = leaf_loc.function;
            self.known_def_ids.extend(Some(function_id));
            self.register_annotations_for_function(node, function_id, |ann| {
                ann.refinement.on_return()
            });
        }
        RichLocation::Location(loc) => {
            // Only call terminators carry markers; anything else is skipped.
            let crate::Either::Right(
                term @ mir::Terminator {
                    kind: mir::TerminatorKind::Call { func, .. },
                    source_info,
                },
            ) = body.stmt_at(loc)
            else {
                return;
            };
            debug!("Assigning markers to {:?}", term.kind);
            let res = self.call_string_resolver.resolve(weight.at);
            let param_env = TypingEnv::post_analysis(self.tcx(), res.def_id());
            let func =
                try_monomorphize(res, self.tcx(), param_env, func, source_info.span).unwrap();
            let (inst, args) =
                type_as_fn(self.tcx(), ty_of_const(func.constant().unwrap())).unwrap();
            let mres = try_resolve_function(
                self.tcx(),
                inst,
                TypingEnv::post_analysis(self.tcx(), leaf_loc.function),
                args,
            )
            .map(|inst| {
                handle_shims(inst, self.tcx(), param_env, source_info.span)
                    .map_or(inst, |t| t.0)
            });
            match &mres {
                None => {
                    debug!("Could not resolve {inst:?} properly during marker assignment");
                }
                Some(resolved) => {
                    debug!("Function monomorphized to {:?}", resolved.def_id());
                }
            }
            // Fall back to the unresolved callee when resolution failed.
            let f = mres.map_or(inst, |i| i.def_id());
            self.known_def_ids.extend(Some(f));
            let mut in_edges = graph
                .graph
                .edges_directed(old_node, Direction::Incoming)
                .filter(|e| e.weight().kind == DepEdgeKind::Data);
            let needs_return_markers = in_edges.clone().next().is_none()
                || in_edges.any(|e| {
                    let at = e.weight().at;
                    #[cfg(debug_assertions)]
                    assert_edge_location_invariant(self.tcx(), at, body, weight.at);
                    weight.at == at && e.weight().target_use.is_return()
                });
            if needs_return_markers {
                self.register_annotations_for_function(node, f, |ann| {
                    ann.refinement.on_return()
                });
            }
            for e in graph.graph.edges_directed(old_node, Direction::Outgoing) {
                let SourceUse::Argument(arg) = e.weight().source_use else {
                    continue;
                };
                self.register_annotations_for_function(node, f, |ann| {
                    ann.refinement.on_argument().contains(arg as u32).unwrap()
                });
            }
        }
        _ => (),
    }
}
/// Compute the monomorphized type of (the base of) `place` as seen at the
/// call string `at`.
///
/// Only the base local of `place` is typed, with one exception: for an async
/// entrypoint, local `_1` with exactly one remaining location in the call
/// string keeps its first projection. In that case a `place` without any
/// projection yields `None`.
///
/// # Panics
/// Panics if `at` is empty, if an async entrypoint's call string does not
/// start with a statement, or if monomorphization fails.
fn determine_place_type(
    &self,
    at: CallString,
    place: mir::PlaceRef<'tcx>,
    span: rustc_span::Span,
) -> Option<mir::tcx::PlaceTy<'tcx>> {
    let tcx = self.tcx();
    let is_async = self.entrypoint_is_async();
    let locations: Vec<_> = at.iter_from_root().collect();
    let (last, before_last) = locations.split_last().unwrap();
    // For an async entrypoint the root location must be a statement; it is
    // skipped for the length check below.
    let rest = if is_async {
        let (first, tail) = before_last.split_first().unwrap();
        assert!(expect_stmt_at(self.body_cache(), *first).is_left());
        tail
    } else {
        before_last
    };
    let keep_first_projection = is_async && place.local.as_u32() == 1 && rest.len() == 1;
    let place = if keep_first_projection {
        if place.projection.is_empty() {
            return None;
        }
        mir::Place {
            local: place.local,
            projection: tcx.mk_place_elems(&place.projection[..1]),
        }
    } else {
        mir::Place::from(place.local)
    };
    let resolution = self.call_string_resolver.resolve(at);
    let body = self.body_cache().get(last.function);
    let raw_ty = place.ty(body.body(), tcx);
    let monomorphized = try_monomorphize(
        resolution,
        tcx,
        TypingEnv::fully_monomorphized(),
        &raw_ty,
        span,
    )
    .unwrap();
    Some(monomorphized)
}
/// Attach to `node` every marker of `function` that passes `filter`.
///
/// Markers of the parent found by `get_parent` (if any) are considered as
/// well, and that parent is added to `known_def_ids`.
fn register_annotations_for_function(
    &mut self,
    node: Node,
    function: DefId,
    mut filter: impl FnMut(&MarkerAnnotation) -> bool,
) {
    let parent = get_parent(self.tcx(), function);
    // Work on a clone so the marker iterator does not keep `self` borrowed
    // while `register_markers` needs it mutably.
    let marker_ctx = self.marker_ctx().clone();
    let own = marker_ctx.combined_markers(function);
    let inherited = parent
        .into_iter()
        .flat_map(|parent| marker_ctx.combined_markers(parent));
    let selected = own
        .chain(inherited)
        .filter(|ann| filter(ann))
        .map(|ann| ann.marker);
    self.register_markers(node, selected);
    self.known_def_ids.extend(parent);
}
/// Record the marked types that apply to the node registered for `old_node`.
///
/// The type of the node's place is searched deeply; each projection prefix of
/// the place contributes shallow markers on top. All ids found are reported
/// to `known_def_ids`; coroutines and function-like items are left out of the
/// types stored for the node.
fn handle_node_types(&mut self, old_node: Node, weight: &DepNode<'tcx>) {
    let i = self.new_node_for(old_node);
    let base_ty = match self.determine_place_type(weight.at, weight.place.as_ref(), weight.span)
    {
        Some(ty) => ty,
        None => return,
    };
    trace!("Node {:?} has place type {:?}", weight.place, base_ty);
    let mut node_types: HashSet<_> = self.type_is_marked(base_ty, true).collect();
    for (prefix, _) in weight.place.iter_projections() {
        let Some(prefix_ty) = self.determine_place_type(weight.at, prefix, weight.span) else {
            continue;
        };
        node_types.extend(self.type_is_marked(prefix_ty, false));
    }
    self.known_def_ids.extend(node_types.iter().copied());
    let tcx = self.tcx();
    if !node_types.is_empty() {
        let storable = node_types
            .iter()
            .filter(|t| !(tcx.is_coroutine(**t) || tcx.def_kind(*t).is_fn_like()));
        self.types.entry(i).or_default().extend(storable);
    }
    trace!(
        "For node {:?} found marked node types {node_types:?}",
        weight.place
    );
}
fn create_flowistry_graph(
generator: &SPDGGenerator<'tcx>,
local_def_id: LocalDefId,
) -> Result<(DepGraph<'tcx>, SPDGStats)> {
let pdg = generator.pdg_constructor.construct_graph(local_def_id);
Ok((pdg, Default::default()))
}
/// Run the conversion and assemble the finished `SPDG`.
///
/// The time spent converting is recorded both in the generator's statistics
/// (as `TimedStat::Conversion`) and in the statistics of the returned graph.
pub fn make_spdg(mut self) -> SPDG {
    let start = Instant::now();
    self.make_spdg_impl();
    let arguments = self.determine_arguments();
    let return_ = self.determine_return();
    self.generator
        .stats
        .record_timed(TimedStat::Conversion, start.elapsed());
    self.stats.conversion_time = start.elapsed();
    // Compute everything that needs `self` as a whole before its fields are
    // moved into the result.
    let path = path_for_item(self.def_id, self.tcx());
    let name = Identifier::new(self.target.name());
    SPDG {
        path,
        graph: self.spdg,
        id: self.def_id,
        name,
        arguments,
        markers: self
            .marker_assignments
            .into_iter()
            .map(|(node, markers)| (node, markers.into_iter().collect()))
            .collect(),
        return_,
        type_assigns: self
            .types
            .into_iter()
            .map(|(node, type_ids)| (node, Types(type_ids.into())))
            .collect(),
        statistics: self.stats,
    }
}
/// The body cache of the generator's PDG constructor.
fn body_cache(&self) -> &BodyCache<'tcx> {
    let constructor = &self.generator.pdg_constructor;
    constructor.body_cache()
}
/// Copy the dependency graph into `self.spdg`.
///
/// Each input node is registered first and then gets its markers and types
/// assigned; afterwards every input edge is recreated between the
/// corresponding SPDG nodes.
fn make_spdg_impl(&mut self) {
    use petgraph::prelude::*;
    // Separate handle so the graph can be iterated while `self` is mutated.
    let dep_graph = self.dep_graph.clone();
    let input = &dep_graph.graph;
    let tcx = self.tcx();
    for (old_node, weight) in input.node_references() {
        let function = weight.at.leaf().function;
        let body = self.body_cache().get(function).body();
        // The node's span is the declaration site of its base local.
        let decl_span = body.local_decls[weight.place.local].source_info.span;
        let info = NodeInfo {
            at: weight.at,
            description: format!("{:?}", weight.place),
            span: src_loc_for_span(decl_span, tcx),
        };
        self.register_node(old_node, info);
        self.node_annotations(old_node, weight);
        self.handle_node_types(old_node, weight);
    }
    for edge in input.edge_references() {
        let weight = edge.weight();
        let kind = match weight.kind {
            DepEdgeKind::Control => EdgeKind::Control,
            DepEdgeKind::Data => EdgeKind::Data,
        };
        let source = self.new_node_for(edge.source());
        let target = self.new_node_for(edge.target());
        self.spdg.add_edge(
            source,
            target,
            EdgeInfo {
                at: weight.at,
                kind,
                source_use: weight.source_use,
                target_use: weight.target_use,
            },
        );
    }
}
/// Iterate over the ids of marked types associated with `typ`.
///
/// With `deep` set the marker context's deep type markers are used,
/// otherwise its shallow type markers.
fn type_is_marked(
    &'a self,
    typ: mir::tcx::PlaceTy<'tcx>,
    deep: bool,
) -> impl Iterator<Item = TypeId> + 'a {
    let markers = if deep {
        Either::Left(self.marker_ctx().deep_type_markers(typ.ty).iter().copied())
    } else {
        Either::Right(self.marker_ctx().shallow_type_markers(typ.ty))
    };
    // Only the type id of each entry is of interest here.
    markers.map(|(type_id, _)| type_id)
}
/// Return the location `at` denotes in the entrypoint, if it denotes one.
///
/// For an async entrypoint a call string of length two counts and its second
/// location is returned. Otherwise only call strings at the root qualify and
/// their leaf is returned.
fn try_as_root(&self, at: CallString) -> Option<GlobalLocation> {
    let inside_async_entrypoint = self.entrypoint_is_async() && at.len() == 2;
    if inside_async_entrypoint {
        return at.iter_from_root().nth(1);
    }
    at.is_at_root().then(|| at.leaf())
}
/// Collect the SPDG nodes located at the end location of the entrypoint.
fn determine_return(&self) -> Box<[Node]> {
    self.spdg
        .node_references()
        .filter_map(|node| {
            let root = self.try_as_root(node.weight().at)?;
            (root.location == RichLocation::End).then(|| node.id())
        })
        .collect()
}
/// Collect the SPDG nodes located at the start location of the entrypoint,
/// ordered by the local of their place.
fn determine_arguments(&self) -> Box<[Node]> {
    let mut start_nodes: Vec<_> = self
        .dep_graph
        .graph
        .node_references()
        .filter(|(_, weight)| {
            self.try_as_root(weight.at)
                .is_some_and(|l| l.location == RichLocation::Start)
        })
        .collect();
    start_nodes.sort_by_key(|(_, weight)| weight.place.local);
    start_nodes
        .into_iter()
        .map(|(old_node, _)| self.new_node_for(old_node))
        .collect()
}
}
/// Debug-only check that an edge recorded at `at` is consistent with the
/// node location `location` it leads to.
///
/// Two situations are accepted:
/// * `at` and `location` are the same call string, or
/// * `at` points at a `SwitchInt` terminator in the same function as the leaf
///   of `location`.
///
/// Anything else is reported as a fatal compiler diagnostic: the primary
/// message is placed at `at`, and a note at `location` names the call string
/// that was expected.
#[cfg(debug_assertions)]
fn assert_edge_location_invariant<'tcx>(
    tcx: TyCtxt<'tcx>,
    at: CallString,
    body: &mir::Body<'tcx>,
    location: CallString,
) {
    if location == at {
        return;
    }
    if let RichLocation::Location(loc) = at.leaf().location {
        if at.leaf().function == location.leaf().function
            && matches!(
                body.stmt_at(loc),
                Either::Right(mir::Terminator {
                    kind: mir::TerminatorKind::SwitchInt { .. },
                    ..
                })
            )
        {
            return;
        }
    }
    let mut msg = tcx.dcx().struct_span_fatal(
        (body, at.leaf().location).span(tcx),
        format!(
            "This operation is performed in a different location: {}",
            at
        ),
    );
    // The note is anchored at `location`, so it must also print `location`;
    // printing `at` again would show the same call string twice.
    msg.span_note(
        (body, location.leaf().location).span(tcx),
        format!("Expected to originate here: {}", location),
    );
    msg.emit()
}
/// Fetch the statement or terminator that `loc` points at.
///
/// # Panics
/// Panics (via `unreachable!`) if `loc` is not a concrete MIR location.
fn expect_stmt_at<'tcx>(
    body_cache: &BodyCache<'tcx>,
    loc: GlobalLocation,
) -> Either<&'tcx mir::Statement<'tcx>, &'tcx mir::Terminator<'tcx>> {
    let body = body_cache.get(loc.function).body();
    match loc.location {
        RichLocation::Location(mir_loc) => body.stmt_at(mir_loc),
        _ => unreachable!(),
    }
}
/// Find the trait item that the function `did` implements.
///
/// Returns `None` unless `did` is a named, function-like item inside a trait
/// impl and the trait declares a function with the same name.
fn get_parent(tcx: TyCtxt, did: DefId) -> Option<DefId> {
    let ident = tcx.opt_item_ident(did)?;
    if !tcx.def_kind(did).is_fn_like() {
        return None;
    }
    let trait_id = tcx
        .impl_of_method(did)
        .and_then(|impl_id| tcx.trait_id_of_impl(impl_id))?;
    tcx.associated_items(trait_id)
        .find_by_name_and_kind(tcx, ident, ty::AssocKind::Fn, trait_id)
        .map(|item| item.def_id)
}
/// Whether `def_id` is asynchronous: either the compiler reports it as
/// `async`, or `is_async_trait_fn` recognizes its body.
fn entrypoint_is_async<'tcx>(
    body_cache: &BodyCache<'tcx>,
    tcx: TyCtxt<'tcx>,
    def_id: DefId,
) -> bool {
    if tcx.asyncness(def_id).is_async() {
        return true;
    }
    // The body is only fetched when the cheap check above did not succeed.
    is_async_trait_fn(tcx, def_id, body_cache.get(def_id).body())
}
mod call_string_resolver {
use std::cell::OnceCell;
use flowistry_pdg::CallString;
use flowistry_pdg_construction::{
body_cache::BodyCache,
utils::{manufacture_substs_for, try_monomorphize, try_resolve_function},
};
use paralegal_spdg::Endpoint;
use rustc_middle::{
mir::TerminatorKind,
ty::{Instance, TypingEnv},
};
use rustc_utils::cache::Cache;
use crate::{Either, MarkerCtx, TyCtxt};
use super::{func_of_term, map_either, match_async_trait_assign, AsFnAndArgs};
/// Resolves a call string to the `Instance` of the function it is located in.
pub struct CallStringResolver<'tcx, 'a> {
/// Memoized results of `resolve_internal`, keyed by call string.
cache: Cache<CallString, Instance<'tcx>>,
tcx: TyCtxt<'tcx>,
/// Whether the entrypoint handed to `new` is asynchronous.
entrypoint_is_async: bool,
/// Source of the MIR bodies inspected during resolution.
body_cache: &'a BodyCache<'tcx>,
/// Consulted for stubs of called functions.
marker_context: MarkerCtx<'tcx>,
/// Instance used when there is nothing further to resolve; computed on
/// first use.
base: OnceCell<Instance<'tcx>>,
}
impl<'tcx, 'a> CallStringResolver<'tcx, 'a> {
/// Resolve the instance of the function that the leaf of `cs` is located in.
///
/// If `cs` has a prefix, that prefix is resolved via `resolve_internal`,
/// except for an async entrypoint whose prefix has length one. In all other
/// cases the base instance is returned, which is computed from the leaf's
/// function on first use and reused afterwards.
///
/// # Panics
/// Panics if the base instance cannot be constructed.
pub fn resolve(&self, cs: CallString) -> Instance<'tcx> {
    let (this, opt_prior_loc) = cs.pop();
    match opt_prior_loc {
        Some(prior_loc) if !(prior_loc.len() == 1 && self.entrypoint_is_async) => {
            self.resolve_internal(prior_loc)
        }
        _ => {
            let def_id = this.function;
            let base = self.base.get_or_init(|| {
                try_resolve_function(
                    self.tcx,
                    def_id,
                    TypingEnv::post_analysis(self.tcx, def_id),
                    manufacture_substs_for(self.tcx, def_id).unwrap(),
                )
                .unwrap()
            });
            *base
        }
    }
}
pub fn new(
tcx: TyCtxt<'tcx>,
entrypoint: Endpoint,
body_cache: &'a BodyCache<'tcx>,
marker_context: MarkerCtx<'tcx>,
) -> Self {
Self {
cache: Default::default(),
tcx,
entrypoint_is_async: super::entrypoint_is_async(body_cache, tcx, entrypoint),
body_cache,
marker_context,
base: Default::default(),
}
}
/// Resolve the instance that the statement or terminator at the leaf of `cs`
/// refers to. Results are memoized in `self.cache`.
///
/// Panics if monomorphization or resolution of the leaf fails.
fn resolve_internal(&self, cs: CallString) -> Instance<'tcx> {
*self.cache.get(&cs, |_| {
let this = cs.leaf();
// `resolve` pops the leaf off `cs`, so `prior` is the instance of the
// function that contains the leaf location.
let prior = self.resolve(cs);
let tcx = self.tcx;
let base_stmt = super::expect_stmt_at(self.body_cache, this);
let param_env = TypingEnv::post_analysis(tcx, prior.def_id())
.with_post_analysis_normalized(tcx);
// Monomorphize whichever of statement/terminator was found, using the
// generics of the containing instance.
let normalized = map_either(
base_stmt,
|stmt| {
try_monomorphize(prior, tcx, param_env, stmt, stmt.source_info.span)
.unwrap()
},
|term| {
try_monomorphize(prior, tcx, param_env, term, term.source_info.span)
.unwrap()
},
);
let res = match normalized {
Either::Right(term) => {
let (def_id, args) = func_of_term(tcx, &term).unwrap();
// NOTE(review): this resolution runs even when no stub applies and
// its result is then unused — confirm whether that is intended.
let instance = Instance::expect_resolve(
tcx,
param_env,
def_id,
args,
term.source_info.span,
);
// A stub registered for the callee decides the resulting instance.
if let Some(model) = self.marker_context.has_stub(def_id) {
let TerminatorKind::Call { args, .. } = &term.kind else {
unreachable!()
};
model
.apply(tcx, instance, param_env, args, term.source_info.span)
.unwrap()
.0
} else {
term.as_instance_and_args(tcx).unwrap().0
}
}
Either::Left(stmt) => {
// Presumably the assignment emitted for an async-trait function —
// verify against `match_async_trait_assign`.
let (def_id, generics) = match_async_trait_assign(&stmt).unwrap();
try_resolve_function(tcx, def_id, param_env, generics).unwrap()
}
};
res
})
}
}
}