use super::{path_for_item, src_loc_for_span, SPDGGenerator};
use crate::{
ann::MarkerAnnotation, desc::*, discover::FnToAnalyze, utils::*, HashMap, HashSet, MarkerCtx,
Pctx,
};
use flowistry_pdg::{rustc_portable::Location, SourceUse};
use flowistry_pdg_construction::{
call_tree_visit::{VisitDriver, Visitor},
determine_async,
graph::{DepEdge, DepEdgeKind, DepNode, OneHopLocation, PartialGraph},
utils::{handle_shims, try_monomorphize, try_resolve_function, type_as_fn, ShimResult},
};
use paralegal_spdg::Node;
use rustc_hash::FxHashSet;
use rustc_hir::def_id::DefId;
use rustc_middle::{
mir,
ty::{Instance, TyCtxt, TypingEnv},
};
use either::Either;
use flowistry::mir::FlowistryInput;
use petgraph::visit::{IntoNodeReferences, NodeRef};
/// Translates the edge kind used during PDG construction ([`DepEdgeKind`])
/// into the [`EdgeKind`] variant of the same name.
fn dep_edge_kind_to_edge_kind(kind: DepEdgeKind) -> EdgeKind {
    match kind {
        DepEdgeKind::Data => EdgeKind::Data,
        DepEdgeKind::Control => EdgeKind::Control,
    }
}
/// Debug-only consistency check between the location recorded on an edge
/// (`at`) and the location of the node the edge belongs to (`location`).
///
/// The check passes when
/// * both locations are equal, or
/// * the leaf of `at` is a MIR location in the same function as the leaf of
///   `location` and the statement at that location is a `SwitchInt`
///   terminator.
///
/// Otherwise a fatal diagnostic is emitted at the span of `at`, with a note
/// at the span of `location`.
///
/// `leaf` maps a location of type `Loc` to the [`GlobalLocation`] used for
/// the function comparison and for span lookup in `body`.
#[cfg(debug_assertions)]
fn assert_edge_location_invariant<'tcx, Loc: Eq + std::fmt::Display>(
    tcx: TyCtxt<'tcx>,
    at: Loc,
    body: &mir::Body<'tcx>,
    location: Loc,
    leaf: impl Fn(&Loc) -> GlobalLocation,
) {
    // The edge is recorded exactly where the node lives.
    if location == at {
        return;
    }
    let at_leaf = leaf(&at);
    let location_leaf = leaf(&location);
    // The edge is recorded at a `SwitchInt` terminator of the same function.
    if let RichLocation::Location(loc) = at_leaf.location {
        let at_switch = matches!(
            body.stmt_at(loc),
            Either::Right(mir::Terminator {
                kind: mir::TerminatorKind::SwitchInt { .. },
                ..
            })
        );
        if at_leaf.function == location_leaf.function && at_switch {
            return;
        }
    }
    let mut msg = tcx.dcx().struct_span_fatal(
        (body, at_leaf.location).span(tcx),
        format!(
            "This operation is performed in a different location: {}",
            at
        ),
    );
    // The note describes the *expected* location, so it must print
    // `location` (the original printed `at` a second time).
    msg.span_note(
        (body, location_leaf.location).span(tcx),
        format!("Expected to originate here: {}", location),
    );
    msg.emit()
}
/// Mutable state used while turning the partial graphs visited by a
/// `VisitDriver` into a single `SPDGImpl`.
struct GraphAssembler<'tcx, 'a> {
/// Source of the type context, the marker context and the PDG constructor
/// (including its body cache).
generator: &'a SPDGGenerator<'tcx>,
/// Result of `entrypoint_is_async` for the entry point, computed once in `new`.
is_async: bool,
/// Stack of translation tables, one per partial graph currently being
/// visited. A table is indexed by the node index local to that partial graph
/// and holds the corresponding node in `graph`; `GNode(Node::end())` marks a
/// slot that has not been assigned yet.
nodes: Vec<Vec<GNode>>,
/// The graph being assembled.
graph: SPDGImpl,
/// Markers registered for each node. Entries are only created for nodes that
/// received at least one marker.
marker_assignments: HashMap<GNode, HashSet<Identifier>>,
/// Marked types recorded for each node.
types: HashMap<GNode, Vec<DefId>>,
/// Caller-owned set that is extended with the def ids (functions, parents,
/// marked types) encountered during assembly.
known_def_ids: &'a mut FxHashSet<DefId>,
}
pub fn assemble_pdg<'a>(
generator: &'a SPDGGenerator<'_>,
known_def_ids: &'a mut FxHashSet<DefId>,
target: &'a FnToAnalyze,
) -> SPDG {
let tcx = generator.tcx();
let base_body_def_id = target.def_id.to_def_id();
let base_body = generator
.pdg_constructor
.body_cache()
.try_get(base_body_def_id)
.unwrap_or_else(|| {
panic!("INVARIANT VIOLATED: body for local function {base_body_def_id:?} cannot be loaded.",)
})
.body();
let async_state = determine_async(tcx, base_body_def_id, base_body);
let possibly_generator_id =
async_state.map_or(base_body_def_id, |(generator, ..)| generator.def_id());
let (possible_generator_instance, k) = generator
.pdg_constructor
.create_root_key(possibly_generator_id.expect_local());
let mut driver = VisitDriver::new(&generator.pdg_constructor, possible_generator_instance, k);
let mut assembler = GraphAssembler::new(generator, known_def_ids, target.def_id.to_def_id());
if let Some((_, loc, ..)) = async_state {
driver.with_pushed_stack(
GlobalLocation {
function: base_body_def_id,
location: RichLocation::Location(loc),
},
|driver| {
driver.start(&mut assembler);
},
);
let base_instance = generator
.pdg_constructor
.create_root_key(base_body_def_id.expect_local())
.0;
assembler.fix_async_args(base_instance, loc, &mut driver);
} else {
driver.start(&mut assembler);
}
let return_ = assembler.determine_return();
let arguments = assembler.determine_arguments();
let graph = assembler.graph;
SPDG {
name: Identifier::new(target.name()),
path: path_for_item(target.def_id.to_def_id(), tcx),
id: target.def_id.to_def_id(),
graph,
markers: assembler
.marker_assignments
.into_iter()
.map(|(GNode(node), markers)| (node, markers.into_iter().collect()))
.collect(),
arguments,
return_,
type_assigns: assembler
.types
.into_iter()
.map(|(GNode(k), v)| (k, Types(v.into())))
.collect(),
statistics: Default::default(),
}
}
impl<'tcx, 'a> GraphAssembler<'tcx, 'a> {
/// Creates an assembler with an empty graph, no translation tables and no
/// recorded markers or types.
///
/// `def_id` is the entry point; it is only used to determine (once) whether
/// the entry point is async.
fn new(
    generator: &'a SPDGGenerator<'tcx>,
    known_def_ids: &'a mut FxHashSet<DefId>,
    def_id: DefId,
) -> Self {
    let body_cache = generator.pdg_constructor.body_cache();
    Self {
        generator,
        is_async: entrypoint_is_async(body_cache, generator.tcx(), def_id),
        nodes: Vec::new(),
        graph: SPDGImpl::new(),
        marker_assignments: Default::default(),
        types: Default::default(),
        known_def_ids,
    }
}
/// Adds `markers` to the marker set of `node`.
///
/// No entry is created for `node` when `markers` yields nothing.
fn register_markers(&mut self, node: GNode, markers: impl IntoIterator<Item = Identifier>) {
    let mut remaining = markers.into_iter();
    // Pull the first marker before touching the map so that an empty
    // iterator leaves `marker_assignments` untouched.
    if let Some(first) = remaining.next() {
        let assigned = self.marker_assignments.entry(node).or_default();
        assigned.insert(first);
        assigned.extend(remaining);
    }
}
/// Returns the global node for the local node `node` of the partial graph
/// currently being visited, creating it in the output graph if the
/// innermost translation table has no entry for it yet.
///
/// # Panics
/// Panics if there is no translation table or `node` is out of bounds for it.
fn add_node<K: Clone>(
    &mut self,
    node: Node,
    vis: &mut VisitDriver<'tcx, '_, K>,
    weight: &DepNode<'tcx, OneHopLocation>,
) -> GNode {
    let info = globalize_node(vis, weight, self.tcx());
    let unassigned = GNode(Node::end());
    let slot = &mut self.nodes.last_mut().unwrap()[node.index()];
    if *slot != unassigned {
        // Already translated (e.g. aliased to a node of the caller).
        return *slot;
    }
    let created = GNode(self.graph.add_node(info));
    *slot = created;
    created
}
/// Adds a node for `place` at `at` directly to the output graph, without
/// recording it in any translation table.
fn add_untranslatable_node(
    &mut self,
    place: mir::Place,
    at: CallString,
    span: rustc_span::Span,
) -> GNode {
    let info = NodeInfo {
        at,
        description: format!("{place:?}"),
        span: src_loc_for_span(span, self.tcx()),
        local: place.local.as_u32(),
    };
    let index = self.graph.add_node(info);
    GNode(index)
}
/// The `Pctx` stored in the generator.
fn ctx(&self) -> &Pctx<'tcx> {
&self.generator.ctx
}
/// The compiler type context, taken from the generator's context.
fn tcx(&self) -> TyCtxt<'tcx> {
    self.generator.ctx.tcx()
}
/// The marker context provided by the generator.
fn marker_ctx(&self) -> &MarkerCtx<'tcx> {
self.generator.marker_ctx()
}
/// Registers on `node` every marker of `function` accepted by `filter`.
///
/// The candidates are the combined markers of `function` followed by the
/// combined markers of whatever `get_parent` returns for it. That parent,
/// if any, is also added to `known_def_ids`.
fn register_annotations_for_function(
    &mut self,
    node: GNode,
    function: DefId,
    mut filter: impl FnMut(&MarkerAnnotation) -> bool,
) {
    let parent = get_parent(self.tcx(), function);
    // Work on a clone so the marker iterators borrow a local value while
    // `register_markers` holds `self` mutably.
    let marker_ctx = self.marker_ctx().clone();
    let own = marker_ctx.combined_markers(function);
    let inherited = parent
        .into_iter()
        .flat_map(|parent| marker_ctx.combined_markers(parent));
    let selected = own
        .chain(inherited)
        .filter(|ann| filter(ann))
        .map(|ann| ann.marker);
    self.register_markers(node, selected);
    self.known_def_ids.extend(parent);
}
/// Records the marked types reachable through `place` for `node`.
///
/// The type of the base place is looked up in the body of the function the
/// driver is currently visiting, monomorphized, and then handed to
/// `handle_node_types_helper` together with the remaining projections.
///
/// Special case: for an async entry point, a place on local `1` whose
/// location is not in a child uses the place including its *first*
/// projection as the base; if such a place has no projection at all,
/// nothing is recorded.
///
/// # Panics
/// Panics if monomorphization of the base type fails.
fn handle_node_types<K: Clone>(
    &mut self,
    node: GNode,
    at: &OneHopLocation,
    place: mir::Place<'tcx>,
    span: rustc_span::Span,
    vis: &VisitDriver<'tcx, '_, K>,
) {
    trace!("Checking types for node {node:?} ({:?})", place);
    let tcx = self.tcx();
    let function = vis.current_function();

    let skips_first_projection =
        self.entrypoint_is_async() && place.local.as_u32() == 1 && at.in_child.is_none();
    let (base_place, projections) = if skips_first_projection {
        let Some((base_project, rest)) = place.projection.split_first() else {
            return;
        };
        let with_first_projection = mir::Place {
            local: place.local,
            projection: tcx.mk_place_elems(&[*base_project]),
        };
        (with_first_projection, rest)
    } else {
        (place.local.into(), place.projection.as_slice())
    };
    trace!("Using base place {base_place:?} with projections {projections:?}");

    let body = self
        .generator
        .pdg_constructor
        .body_cache()
        .get(function.def_id())
        .body();
    let raw_ty = base_place.ty(body, tcx);
    let base_ty = try_monomorphize(
        function,
        tcx,
        TypingEnv::fully_monomorphized(),
        &raw_ty,
        span,
    )
    .unwrap();
    self.handle_node_types_helper(node, base_ty, projections);
}
/// Collects the marked types along `base_ty` and its `projections` and
/// records them for `node`.
///
/// Every type before the final projection is checked with the shallow
/// lookup, the final type with the deep lookup. All hits go into
/// `known_def_ids`; coroutines and function-like definitions are excluded
/// from the per-node `types` entry.
fn handle_node_types_helper(
    &mut self,
    node: GNode,
    mut base_ty: mir::tcx::PlaceTy<'tcx>,
    projections: &[mir::PlaceElem<'tcx>],
) {
    trace!("Has place type {base_ty:?}");
    let tcx = self.tcx();
    let mut node_types = HashSet::new();
    for &proj in projections {
        node_types.extend(self.type_is_marked(base_ty, false));
        base_ty = base_ty.projection_ty(tcx, proj);
    }
    node_types.extend(self.type_is_marked(base_ty, true));
    self.known_def_ids.extend(node_types.iter().copied());
    if !node_types.is_empty() {
        let recorded = self.types.entry(node).or_default();
        recorded.extend(
            node_types
                .iter()
                .copied()
                .filter(|&t| !tcx.is_coroutine(t) && !tcx.def_kind(t).is_fn_like()),
        );
    }
    trace!("Found marked node types {node_types:?}",);
}
/// Yields the type ids the marker context associates with `typ`, using the
/// deep lookup when `deep` is set and the shallow lookup otherwise. The
/// second component of each lookup result is dropped.
fn type_is_marked(
    &'a self,
    typ: mir::tcx::PlaceTy<'tcx>,
    deep: bool,
) -> impl Iterator<Item = TypeId> + 'a {
    let marker_ctx = self.marker_ctx();
    let marked = if deep {
        Either::Left(marker_ctx.deep_type_markers(typ.ty).iter().copied())
    } else {
        Either::Right(marker_ctx.shallow_type_markers(typ.ty))
    };
    marked.map(|(type_id, _)| type_id)
}
/// Assigns markers to `node` depending on where it is located in the
/// function currently being visited:
///
/// * at the start, on an argument local: argument markers of the function,
/// * at the end, on the return place: return markers of the function,
/// * at a MIR location: delegated to
///   `handle_node_annotations_for_regular_location`,
/// * anything else: no markers.
fn node_annotations<K: Clone>(
    &mut self,
    local_node: Node,
    node: GNode,
    weight: &DepNode<'tcx, OneHopLocation>,
    vis: &VisitDriver<'tcx, '_, K>,
) {
    let function_id = vis.current_function().def_id();
    let body = self
        .generator
        .pdg_constructor
        .body_cache()
        .get(function_id)
        .body();
    let local = weight.place.local;
    match weight.at.location {
        RichLocation::Location(loc) => self.handle_node_annotations_for_regular_location(
            local_node, node, weight, body, loc, vis,
        ),
        RichLocation::Start if matches!(body.local_kind(local), mir::LocalKind::Arg) => {
            // Local 0 is the return place, so argument numbering is the
            // local index shifted down by one.
            let arg_num = local.as_u32() - 1;
            self.known_def_ids.extend([function_id]);
            self.register_annotations_for_argument(node, arg_num, function_id);
        }
        RichLocation::End if local == mir::RETURN_PLACE => {
            self.known_def_ids.extend([function_id]);
            self.register_annotations_for_return(node, function_id);
        }
        _ => (),
    }
}
fn register_annotations_for_argument(&mut self, node: GNode, arg_num: u32, function_id: DefId) {
self.register_annotations_for_function(node, function_id, |ann| {
ann.refinement.on_argument().contains(arg_num).unwrap()
});
}
fn register_annotations_for_return(&mut self, node: GNode, function_id: DefId) {
self.register_annotations_for_function(node, function_id, |ann| ann.refinement.on_return());
}
/// Assigns markers to `node`, which sits at MIR location `loc` of `body`.
///
/// Only `Call` terminators are handled; for any other statement or
/// terminator the function returns without doing anything. The callee is
/// monomorphized and resolved (including shim handling), its def id is added
/// to `known_def_ids`, and then:
///
/// * for every outgoing edge of `local_node` that uses the node as argument
///   `arg`, the callee's markers for that argument are registered;
/// * the callee's return markers are registered if `local_node` has no
///   incoming data edge at all, or has an incoming data edge at the node's
///   own location whose target use is a return.
///
/// If the callee is not a constant or is a shim that cannot be handled, an
/// error is reported through the context and no markers are assigned.
fn handle_node_annotations_for_regular_location<K: Clone>(
&mut self,
local_node: Node,
node: GNode,
weight: &DepNode<'tcx, OneHopLocation>,
body: &mir::Body<'tcx>,
loc: Location,
vis: &VisitDriver<'tcx, '_, K>,
) {
let function = vis.current_function();
let function_id = function.def_id();
// Anything that is not a call terminator carries no function markers.
let crate::Either::Right(
term @ mir::Terminator {
kind: mir::TerminatorKind::Call { func, .. },
..
},
) = body.stmt_at(loc)
else {
return;
};
debug!("Assigning markers to {:?}", term.kind);
let param_env = TypingEnv::post_analysis(self.tcx(), function.def_id());
let func =
try_monomorphize(function, self.tcx(), param_env, func, term.source_info.span).unwrap();
// Markers can only be looked up for a statically known callee.
let Some(funcc) = func.constant() else {
self.generator.ctx.maybe_span_err(
weight.span,
"SOUNDNESS: Cannot determine markers for function call",
);
return;
};
let (inst, args) = type_as_fn(self.tcx(), ty_of_const(funcc)).unwrap();
// `f` is the def id markers are looked up for: the resolved instance (or
// the instance a handled shim stands for) when resolution succeeds, the
// unresolved def id otherwise.
let f = if let Some(inst) = try_resolve_function(
self.tcx(),
inst,
TypingEnv::post_analysis(self.tcx(), function_id),
args,
) {
match handle_shims(inst, self.tcx(), param_env, weight.span) {
ShimResult::IsHandledShim { instance, .. } => instance,
ShimResult::IsNotShim => inst,
ShimResult::IsNonHandleableShim => {
self.ctx().maybe_span_err(
weight.span,
"SOUNDNESS: Cannot determine markers for shim usage",
);
return;
}
}
.def_id()
} else {
debug!("Could not resolve {inst:?} properly during marker assignment");
inst
};
self.known_def_ids.extend(Some(f));
let graph = vis.current_graph();
let mut is_return_use = false;
let mut has_no_data_edges = true;
// Outgoing edges: the node is used as an argument of the call.
for eref in graph.raw().edges_directed(local_node, petgraph::Outgoing) {
let SourceUse::Argument(arg) = eref.weight().source_use else {
continue;
};
self.register_annotations_for_function(node, f, |ann| {
ann.refinement.on_argument().contains(arg as u32).unwrap()
});
}
// Incoming data edges: determine whether the node receives the call's
// return value.
for eref in graph.raw().edges_directed(local_node, petgraph::Incoming) {
if eref.weight().kind == DepEdgeKind::Data {
has_no_data_edges = false;
let at = eref.weight().at.clone();
#[cfg(debug_assertions)]
assert_edge_location_invariant(
self.tcx(),
at.clone(),
body,
weight.at.clone(),
|at| GlobalLocation {
function: function.def_id(),
location: at.location,
},
);
if weight.at == at && eref.weight().target_use.is_return() {
is_return_use = true;
}
}
}
let needs_return_markers = has_no_data_edges | is_return_use;
if needs_return_markers {
self.register_annotations_for_function(node, f, |ann| ann.refinement.on_return());
}
}
/// Returns the nodes located at the start of the root function, ordered by
/// the index of the local they refer to.
fn determine_arguments(&self) -> Box<[Node]> {
    let mut at_entry = Vec::new();
    for candidate in self.graph.node_references() {
        let at = candidate.weight().at;
        if matches!(self.try_as_root(at), Some(l) if l.location == RichLocation::Start) {
            at_entry.push(candidate);
        }
    }
    at_entry.sort_by_key(|(_, info)| info.local);
    at_entry.into_iter().map(|(index, _)| index).collect()
}
/// Returns the nodes located at the end of the root function.
fn determine_return(&self) -> Box<[Node]> {
    self.graph
        .node_references()
        .filter_map(|candidate| {
            let root = self.try_as_root(candidate.weight().at)?;
            (root.location == RichLocation::End).then(|| candidate.id())
        })
        .collect()
}
/// Returns the leaf of `at` if the call string is at the root, `None`
/// otherwise.
fn try_as_root(&self, at: CallString) -> Option<GlobalLocation> {
    if at.is_at_root() {
        Some(at.leaf())
    } else {
        None
    }
}
/// Whether the entry point was determined to be async when the assembler
/// was created.
fn entrypoint_is_async(&self) -> bool {
self.is_async
}
/// Adds argument and return nodes for `instance` and connects them to the
/// nodes of the graph the driver is currently positioned on.
///
/// For every argument of `instance`'s body a node is created at the start
/// location, and one node for the return place at the end location. Each
/// receives the matching argument/return markers and its marked types. Then,
/// for every node of the driver's current graph:
///
/// * a node on local `1` at the start location must begin with a field
///   projection; the argument node with that field's index gets a data edge
///   to it (presumably local `1` is the generator state capturing the
///   arguments -- TODO confirm);
/// * a node on the return place gets a data edge to the new return node.
///
/// All added edges are located at `loc` inside `instance`.
///
/// Nodes of the current graph are looked up with `translate_node`, so they
/// must already be present in the innermost translation table.
///
/// # Panics
/// Panics if the body of `instance` cannot be loaded, if a type fails to
/// monomorphize, or if a field index exceeds the number of arguments.
fn fix_async_args<K: Clone + std::hash::Hash + Eq>(
&mut self,
instance: Instance<'tcx>,
loc: Location,
driver: &mut VisitDriver<'tcx, 'a, K>,
) {
let def_id = instance.def_id();
self.known_def_ids.extend([def_id]);
let tcx = self.generator.tcx();
let base_body = self
.generator
.pdg_constructor
.body_cache()
.try_get(def_id)
.unwrap_or_else(|| {
panic!("INVARIANT VIOLATED: body for local function {def_id:?} cannot be loaded.",)
})
.body();
let pgraph = driver.current_graph_as_rc();
// One fresh node per argument, in argument order, so that the vector
// index equals the argument number.
let args_as_nodes = base_body
.args_iter()
.map(|arg| {
self.add_untranslatable_node(
arg.into(),
driver.globalize_location(&RichLocation::Start.into()),
base_body.local_decls[arg].source_info.span,
)
})
.collect::<Vec<_>>();
let return_node = self.add_untranslatable_node(
mir::RETURN_PLACE.into(),
driver.globalize_location(&RichLocation::End.into()),
base_body.local_decls[mir::RETURN_PLACE].source_info.span,
);
// Monomorphized declared type of a local of `base_body`.
let mono_ty = |local| {
let decl = &base_body.local_decls[local];
mir::tcx::PlaceTy::from_ty(
try_monomorphize(
instance,
tcx,
TypingEnv::fully_monomorphized(),
&decl.ty,
decl.source_info.span,
)
.unwrap(),
)
};
for (arg_num, a) in args_as_nodes.iter().enumerate() {
self.register_annotations_for_argument(*a, arg_num as u32, def_id);
// Argument locals start at 1; local 0 is the return place.
let local = mir::Local::from_usize(arg_num + 1);
self.handle_node_types_helper(*a, mono_ty(local), &[]);
}
self.register_annotations_for_return(return_node, def_id);
let local = mir::RETURN_PLACE;
self.handle_node_types_helper(return_node, mono_ty(local), &[]);
let generator_loc = RichLocation::Location(loc);
// Location attached to every edge added below.
let transition_at = CallString::new(&[GlobalLocation {
location: generator_loc,
function: def_id,
}]);
for (nidx, n) in pgraph.iter_nodes() {
if n.place.local.as_u32() == 1 && n.at.location == RichLocation::Start {
let ridx = self.translate_node(nidx);
let Some(mir::ProjectionElem::Field(id, _)) = n.place.projection.first() else {
tcx.dcx().span_err(
n.span,
format!("Expected field projection on async generator in {def_id:?}, found {:?}", n.place),
);
continue;
};
// The field index selects the argument node to connect.
let arg = args_as_nodes[id.as_usize()];
self.graph.add_edge(
arg.to_index(),
ridx.to_index(),
EdgeInfo {
kind: EdgeKind::Data,
at: transition_at,
source_use: SourceUse::Argument(id.as_u32() as u8),
target_use: TargetUse::Assign,
},
);
} else if n.place.local == mir::RETURN_PLACE {
let ridx = self.translate_node(nidx);
self.graph.add_edge(
ridx.to_index(),
return_node.to_index(),
EdgeInfo {
kind: EdgeKind::Data,
at: transition_at,
source_use: SourceUse::Operand,
target_use: TargetUse::Return,
},
);
}
}
}
/// Looks up `node` in the innermost translation table.
///
/// # Panics
/// Panics if `node` has no entry in that table.
fn translate_node(&self, node: Node) -> GNode {
    let innermost = self.nodes.len() - 1;
    self.translate_node_in(node, innermost)
}
/// Looks up `node` in the translation table at stack position `index`.
///
/// # Panics
/// Panics if `index` or `node` is out of bounds, or if the slot for `node`
/// has not been assigned yet.
fn translate_node_in(&self, node: Node, index: usize) -> GNode {
    let translated = self.nodes[index][node.index()];
    assert_ne!(translated.to_index(), Node::end(), "Node {node:?} is unknown");
    translated
}
/// Runs `f` with a fresh translation table of `size` unassigned slots
/// pushed onto the table stack.
///
/// The table is popped again afterwards unless it is the only one on the
/// stack; the outermost table therefore stays available after the visit.
///
/// # Panics
/// Panics if the popped table does not have `size` entries.
fn with_new_translation_table<R>(&mut self, size: usize, f: impl FnOnce(&mut Self) -> R) -> R {
    let fresh_table = vec![GNode(Node::end()); size];
    self.nodes.push(fresh_table);
    let outcome = f(self);
    let only_table_left = self.nodes.len() == 1;
    if !only_table_left {
        let finished = self.nodes.pop().unwrap();
        assert_eq!(finished.len(), size);
    }
    outcome
}
}
/// Builds the [`NodeInfo`] stored in the output graph for `node`, with its
/// location globalized by the driver.
fn globalize_node<'tcx, K: Clone>(
    vis: &mut VisitDriver<'tcx, '_, K>,
    node: &DepNode<'tcx, OneHopLocation>,
    tcx: TyCtxt<'tcx>,
) -> NodeInfo {
    let place = node.place;
    NodeInfo {
        at: vis.globalize_location(&node.at),
        description: format!("{:?}", place),
        span: src_loc_for_span(node.span, tcx),
        local: place.local.as_u32(),
    }
}
/// Builds the [`EdgeInfo`] stored in the output graph for `edge`, with its
/// location globalized by the driver and its kind translated.
fn globalize_edge<K: Clone>(
    vis: &mut VisitDriver<'_, '_, K>,
    edge: &DepEdge<OneHopLocation>,
) -> EdgeInfo {
    EdgeInfo {
        at: vis.globalize_location(&edge.at),
        kind: dep_edge_kind_to_edge_kind(edge.kind),
        source_use: edge.source_use,
        target_use: edge.target_use,
    }
}
/// Index of a node in the assembled output graph. The wrapper keeps these
/// indices apart from the `Node` indices that are local to a partial graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GNode(petgraph::graph::NodeIndex);
impl GNode {
fn to_index(self) -> petgraph::graph::NodeIndex {
self.0
}
}
/// Call-tree visitor that copies every visited partial graph into the single
/// output graph, translating local node indices through the table stack.
impl<'tcx, K: std::hash::Hash + Eq + Clone> Visitor<'tcx, K> for GraphAssembler<'tcx, '_> {
    /// Makes `in_this` (a node of the graph being entered) share the global
    /// node already assigned to `in_caller` in the table directly below.
    fn visit_parent_connection(
        &mut self,
        _vis: &mut VisitDriver<'tcx, '_, K>,
        in_caller: Node,
        in_this: Node,
        _is_at_start: bool,
    ) {
        let [caller_table, callee_table] = self.nodes.last_chunk_mut().unwrap();
        callee_table[in_this.index()] = caller_table[in_caller.index()];
    }

    /// Adds `node` to the output graph. Markers and types are only assigned
    /// when the node's location has no `in_child` component.
    fn visit_node(
        &mut self,
        vis: &mut VisitDriver<'tcx, '_, K>,
        k: Node,
        node: &DepNode<'tcx, OneHopLocation>,
    ) {
        let global = self.add_node(k, vis, node);
        if node.at.in_child.is_some() {
            return;
        }
        self.node_annotations(k, global, node, vis);
        self.handle_node_types(global, &node.at, node.place, node.span, vis);
    }

    /// Adds an edge between the translations of `src` and `dst` in the
    /// innermost table, with the edge location globalized.
    fn visit_edge(
        &mut self,
        vis: &mut VisitDriver<'tcx, '_, K>,
        src: Node,
        dst: Node,
        kind: &DepEdge<OneHopLocation>,
    ) {
        let from = self.translate_node(src);
        let to = self.translate_node(dst);
        let info = globalize_edge(vis, kind);
        self.graph.add_edge(from.to_index(), to.to_index(), info);
    }

    /// Visits `graph` with a translation table sized to its node count.
    fn visit_partial_graph(
        &mut self,
        vis: &mut VisitDriver<'tcx, '_, K>,
        graph: &PartialGraph<'tcx, K>,
    ) {
        let node_count = graph.node_count();
        self.with_new_translation_table(node_count, |slf: &mut Self| {
            trace!(
                "Visiting partial graph {:?}",
                slf.tcx().def_path_str(graph.def_id())
            );
            vis.visit_partial_graph(slf, graph);
        })
    }

    /// Adds a control edge whose source lives in the table at stack position
    /// `index` and whose target lives in the innermost table. The edge
    /// already carries a global location, so it is copied as is.
    fn visit_ctrl_edge(
        &mut self,
        _vis: &mut VisitDriver<'tcx, '_, K>,
        index: usize,
        src: Node,
        dst: Node,
        edge: &DepEdge<CallString>,
    ) {
        let from = self.translate_node_in(src, index);
        let to = self.translate_node(dst);
        let info = EdgeInfo {
            kind: dep_edge_kind_to_edge_kind(edge.kind),
            at: edge.at,
            source_use: edge.source_use,
            target_use: edge.target_use,
        };
        self.graph.add_edge(from.to_index(), to.to_index(), info);
    }
}