extern crate smallvec;
use flowistry_pdg::RichLocation;
use flowistry_pdg_construction::utils::type_as_fn;
use thiserror::Error;
use crate::{desc::Identifier, rustc_span::ErrorGuaranteed, Either, Symbol, TyCtxt};
pub use flowistry_pdg_construction::utils::is_virtual;
pub use paralegal_spdg::{ShortHash, TinyBitSet};
use rustc_ast as ast;
use rustc_data_structures::intern::Interned;
use rustc_hir::{
self as hir,
def::Res,
def_id::{DefId, LocalDefId},
hir_id::HirId,
BodyId,
};
use rustc_middle::{
mir::{self, ConstOperand, Location, Place, ProjectionElem, Terminator},
ty::{self, GenericArgsRef, Instance, Ty, TypingEnv},
};
use rustc_span::{symbol::Ident, Span as RustSpan, Span};
use rustc_target::spec::abi::Abi;
use std::cmp::Ordering;
mod print;
pub mod resolve;
pub use print::*;
/// Computes a span covering the source code of `body`.
///
/// Every statement and terminator span is first mapped to its outermost macro
/// call site (`source_callsite`). If each of those spans either stems from a
/// macro expansion or lies within `body.span`, then `body.span` is returned
/// unchanged. Otherwise the result is the union (`Span::to`) of all such spans
/// whose start lies in the same source file as the start of `body.span`,
/// falling back to `body.span` when no span qualifies.
pub fn body_span<'tcx>(tcx: TyCtxt<'tcx>, body: &mir::Body<'tcx>) -> RustSpan {
    let declared = body.span;
    // Re-creatable iterator over the call-site spans of every statement and
    // terminator in the body (it is consumed twice below).
    let callsite_spans = || {
        body.basic_blocks
            .iter()
            .flat_map(|block| {
                let terminator_span = block.terminator().source_info.span;
                block
                    .statements
                    .iter()
                    .map(|stmt| stmt.source_info.span)
                    .chain(std::iter::once(terminator_span))
            })
            .map(|sp| sp.source_callsite())
            // NOTE(review): if (as in rustc) a dummy span is always empty, the
            // `||` below only ever drops dummy spans; if empty spans were meant
            // to be dropped too, this should be `&&` — confirm intent.
            .filter(|sp| !sp.is_dummy() || !sp.is_empty())
    };

    let fits_declared_span =
        callsite_spans().all(|sp| sp.from_expansion() || declared.contains(sp));
    if fits_declared_span {
        return declared;
    }

    let source_map = tcx.sess.source_map();
    let home_file = source_map.lookup_source_file_idx(declared.data().lo);
    callsite_spans()
        .filter(|sp| source_map.lookup_source_file_idx(sp.data().lo) == home_file)
        .reduce(|merged, sp| merged.to(sp))
        .unwrap_or(declared)
}
/// Matching of attributes against a fully spelled-out path such as
/// `[paralegal, marker]`.
pub trait MetaItemMatch {
    /// If `self` matches `path`, runs `parse` on its arguments and returns the
    /// result; otherwise returns `None`.
    fn match_extract<A, F: Fn(&ast::AttrArgs) -> A>(&self, path: &[Symbol], parse: F) -> Option<A> {
        let args = self.match_get_ref(path)?;
        Some(parse(args))
    }

    /// Returns `true` when `self` matches `path`.
    fn matches_path(&self, path: &[Symbol]) -> bool {
        self.match_get_ref(path).is_some()
    }

    /// Returns the arguments of `self` if it matches `path`.
    fn match_get_ref(&self, path: &[Symbol]) -> Option<&ast::AttrArgs>;
}

impl MetaItemMatch for ast::Attribute {
    /// Only "normal" attributes can match. Their path must have exactly as many
    /// segments as `path`, and every segment name must equal the corresponding
    /// symbol.
    fn match_get_ref(&self, path: &[Symbol]) -> Option<&ast::AttrArgs> {
        let ast::AttrKind::Normal(normal) = &self.kind else {
            return None;
        };
        let item = &normal.item;
        let segments = &item.path.segments;
        let same_path = segments.len() == path.len()
            && segments
                .iter()
                .zip(path)
                .all(|(segment, expected)| segment.ident.name == *expected);
        same_path.then_some(&item.args)
    }
}
/// Access to the [`DefId`] underlying certain kinds of types.
pub trait TyExt: Sized {
    /// Owned counterpart of [`TyExt::defid_ref`].
    fn defid(self) -> Option<DefId> {
        let did = self.defid_ref()?;
        Some(*did)
    }

    /// Returns a reference to the [`DefId`] of the type, if it has one.
    fn defid_ref(&self) -> Option<&DefId>;
}

impl TyExt for ty::Ty<'_> {
    /// Yields the definition id for ADTs, foreign types, function items,
    /// closures and coroutines; every other kind of type yields `None`.
    fn defid_ref(&self) -> Option<&DefId> {
        let did = match self.kind() {
            ty::TyKind::Adt(ty::AdtDef(Interned(ty::AdtDefData { did, .. }, _)), _) => did,
            ty::TyKind::Foreign(did)
            | ty::TyKind::FnDef(did, _)
            | ty::TyKind::Closure(did, _)
            | ty::TyKind::Coroutine(did, _) => did,
            _ => return None,
        };
        Some(did)
    }
}
/// Convenience accessor for generic arguments.
pub trait GenericArgExt<'tcx> {
    /// Returns the argument as a type, or `None` if it is a lifetime or const.
    fn as_type(&self) -> Option<ty::Ty<'tcx>>;
}

impl<'tcx> GenericArgExt<'tcx> for ty::GenericArg<'tcx> {
    fn as_type(&self) -> Option<ty::Ty<'tcx>> {
        if let ty::GenericArgKind::Type(t) = self.unpack() {
            Some(t)
        } else {
            None
        }
    }
}
pub trait DfppBodyExt<'tcx> {
fn stmt_at_better_err(
&self,
l: mir::Location,
) -> Either<&mir::Statement<'tcx>, &mir::Terminator<'tcx>> {
self.maybe_stmt_at(l).unwrap()
}
fn maybe_stmt_at(
&self,
l: mir::Location,
) -> Result<Either<&mir::Statement<'tcx>, &mir::Terminator<'tcx>>, StmtAtErr<'_, 'tcx>>;
}
#[derive(Debug)]
pub enum StmtAtErr<'a, 'tcx> {
BasicBlockOutOfBound(mir::BasicBlock, &'a mir::Body<'tcx>),
StatementIndexOutOfBounds(usize, &'a mir::BasicBlockData<'tcx>),
}
impl<'tcx> DfppBodyExt<'tcx> for mir::Body<'tcx> {
fn maybe_stmt_at(
&self,
l: mir::Location,
) -> Result<Either<&mir::Statement<'tcx>, &mir::Terminator<'tcx>>, StmtAtErr<'_, 'tcx>> {
let Location {
block,
statement_index,
} = l;
let block_data = self
.basic_blocks
.get(block)
.ok_or(StmtAtErr::BasicBlockOutOfBound(block, self))?;
if statement_index == block_data.statements.len() {
Ok(Either::Right(block_data.terminator()))
} else if let Some(stmt) = block_data.statements.get(statement_index) {
Ok(Either::Left(stmt))
} else {
Err(StmtAtErr::StatementIndexOutOfBounds(
statement_index,
block_data,
))
}
}
}
/// Signature computation for resolved instances.
pub trait InstanceExt<'tcx> {
    /// Returns the normalized, late-bound-region-erased signature of `self`.
    ///
    /// # Errors
    /// Fails (with a diagnostic already emitted) when the instance's definition
    /// is not a function, closure or coroutine.
    fn sig(self, tcx: TyCtxt<'tcx>) -> Result<ty::FnSig<'tcx>, ErrorGuaranteed>;
}

impl<'tcx> InstanceExt<'tcx> for Instance<'tcx> {
    fn sig(self, tcx: TyCtxt<'tcx>) -> Result<ty::FnSig<'tcx>, ErrorGuaranteed> {
        let typing_env = TypingEnv::fully_monomorphized();
        let late_bound_sig = match FunctionKind::for_def_id(tcx, self.def_id())? {
            FunctionKind::Plain => self.ty(tcx, typing_env).fn_sig(tcx),
            FunctionKind::Closure => self.args.as_closure().sig(),
            FunctionKind::Generator => {
                // Coroutines have no fn signature of their own; model them as
                // `fn(resume_ty) -> return_ty`.
                let coroutine = self.args.as_coroutine();
                let resume_and_return = [coroutine.resume_ty(), coroutine.return_ty()];
                ty::Binder::dummy(ty::FnSig {
                    inputs_and_output: tcx.mk_type_list(&resume_and_return),
                    c_variadic: false,
                    abi: Abi::Rust,
                    safety: hir::Safety::Safe,
                })
            }
        };
        Ok(tcx.normalize_erasing_late_bound_regions(typing_env, late_bound_sig))
    }
}
/// Coarse classification of callable definitions.
#[derive(Clone, Copy, Eq, PartialEq)]
pub enum FunctionKind {
    /// Closure-like (per `is_closure_like`) but not a coroutine.
    Closure,
    /// Any definition for which `coroutine_kind` is `Some`.
    Generator,
    /// Any other definition whose `DefKind` is fn-like.
    Plain,
}

impl FunctionKind {
    /// Classifies `def_id`, checking for coroutines first, then closures, then
    /// other fn-like definitions.
    ///
    /// # Errors
    /// Emits an error at the item's span and returns it if `def_id` is none of
    /// the above.
    pub fn for_def_id(tcx: TyCtxt, def_id: DefId) -> Result<Self, ErrorGuaranteed> {
        if tcx.coroutine_kind(def_id).is_some() {
            return Ok(Self::Generator);
        }
        if tcx.is_closure_like(def_id) {
            return Ok(Self::Closure);
        }
        if tcx.def_kind(def_id).is_fn_like() {
            return Ok(Self::Plain);
        }
        let item_span = tcx.def_span(def_id);
        Err(tcx
            .dcx()
            .span_err(item_span, "Expected this item to be a function."))
    }
}
/// Returns the callee's `DefId` and generic arguments when `terminator` is a
/// `Call` whose function operand is a constant that `type_as_fn` accepts.
/// Any other terminator, or a non-constant callee, yields `None`.
pub fn func_of_term<'tcx>(
    tcx: TyCtxt<'tcx>,
    terminator: &Terminator<'tcx>,
) -> Option<(DefId, GenericArgsRef<'tcx>)> {
    match &terminator.kind {
        mir::TerminatorKind::Call { func, .. } => {
            let callee = func.constant()?;
            type_as_fn(tcx, ty_of_const(callee))
        }
        _ => None,
    }
}
/// Call arguments reduced to their places; non-place operands become `None`.
pub type SimplifiedArguments<'tcx> = Vec<Option<Place<'tcx>>>;

/// Decomposition of a call into callee, arguments and destination.
pub trait AsFnAndArgs<'tcx> {
    /// Like [`AsFnAndArgs::as_instance_and_args`], but reports only the
    /// resolved instance's `DefId`.
    fn as_fn_and_args(
        &self,
        tcx: TyCtxt<'tcx>,
    ) -> Result<(DefId, SimplifiedArguments<'tcx>, mir::Place<'tcx>), AsFnAndArgsErr<'tcx>> {
        let (instance, args, destination) = self.as_instance_and_args(tcx)?;
        Ok((instance.def_id(), args, destination))
    }

    /// Resolves the callee instance and returns it with the simplified
    /// arguments and the return place.
    fn as_instance_and_args(
        &self,
        tcx: TyCtxt<'tcx>,
    ) -> Result<(Instance<'tcx>, SimplifiedArguments<'tcx>, mir::Place<'tcx>), AsFnAndArgsErr<'tcx>>;
}
/// Reasons why [`AsFnAndArgs`] could not decompose a terminator into a call.
#[derive(Debug, Error)]
pub enum AsFnAndArgsErr<'tcx> {
    /// The called function operand is not a constant.
    #[error("not a constant")]
    NotAConstant,
    /// The callee constant's type is not recognized as a function by `type_as_fn`.
    #[error("is not a function type: {0:?}")]
    NotFunctionType(ty::TyKind<'tcx>),
    /// NOTE(review): not constructed anywhere in this file; presumably used by
    /// callers elsewhere — confirm before relying on it.
    #[error("is not a `Val` constant: {0}")]
    NotValueLevelConstant(ty::Const<'tcx>),
    /// The terminator is not a `TerminatorKind::Call`.
    #[error("terminator is not a `Call`")]
    NotAFunctionCall,
    /// `Instance::try_resolve` returned an error.
    #[error("function instance resolution errored")]
    InstanceResolutionErr,
    /// Normalizing the callee's generic arguments failed; holds the
    /// `Debug`-formatted normalization error.
    #[error("could not normalize generics {0}")]
    NormalizationError(String),
    /// `Instance::try_resolve` succeeded but could not pick a concrete instance.
    #[error("instance too unspecific")]
    InstanceTooUnspecific,
}
pub fn ty_of_const<'tcx>(c: &ConstOperand<'tcx>) -> Ty<'tcx> {
match c.const_ {
mir::Const::Val(_, ty) => ty,
mir::Const::Ty(cst, _) => cst,
mir::Const::Unevaluated { .. } => unreachable!(),
}
}
impl<'tcx> AsFnAndArgs<'tcx> for mir::Terminator<'tcx> {
    /// Resolves a `Call` terminator with a constant callee to a concrete
    /// instance in a fully monomorphized typing environment.
    ///
    /// # Errors
    /// Produces the specific [`AsFnAndArgsErr`] for whichever step fails: not a
    /// call, non-constant callee, non-function type, normalization failure,
    /// resolution error, or an instance that cannot be resolved concretely.
    fn as_instance_and_args(
        &self,
        tcx: TyCtxt<'tcx>,
    ) -> Result<(Instance<'tcx>, SimplifiedArguments<'tcx>, mir::Place<'tcx>), AsFnAndArgsErr<'tcx>>
    {
        let (func, args, destination) = match &self.kind {
            mir::TerminatorKind::Call {
                func,
                args,
                destination,
                ..
            } => (func, args, destination),
            _ => return Err(AsFnAndArgsErr::NotAFunctionCall),
        };

        let Some(callee) = func.constant() else {
            return Err(AsFnAndArgsErr::NotAConstant);
        };
        let ty = ty_of_const(callee);
        let (def_id, gargs) =
            type_as_fn(tcx, ty).ok_or_else(|| AsFnAndArgsErr::NotFunctionType(*ty.kind()))?;

        // Reject generics that cannot be normalized before attempting resolution.
        if let Err(e) = test_generics_normalization(tcx, gargs) {
            return Err(AsFnAndArgsErr::NormalizationError(format!("{e:?}")));
        }

        let resolved =
            ty::Instance::try_resolve(tcx, TypingEnv::fully_monomorphized(), def_id, gargs)
                .map_err(|_| AsFnAndArgsErr::InstanceResolutionErr)?;
        let Some(instance) = resolved else {
            return Err(AsFnAndArgsErr::InstanceTooUnspecific);
        };

        let simplified: SimplifiedArguments<'tcx> =
            args.iter().map(|arg| arg.node.place()).collect();
        Ok((instance, simplified, *destination))
    }
}
/// Checks that `args` can be normalized in a fully monomorphized typing
/// environment; the normalized result itself is discarded.
fn test_generics_normalization<'tcx>(
    tcx: TyCtxt<'tcx>,
    args: &'tcx ty::List<ty::GenericArg<'tcx>>,
) -> Result<(), ty::normalize_erasing_regions::NormalizationError<'tcx>> {
    let typing_env = TypingEnv::fully_monomorphized();
    tcx.try_normalize_erasing_regions(typing_env, args)?;
    Ok(())
}
/// MIR visitor that hands every place passed to `visit_place` to the wrapped
/// callback.
///
/// `super_place` is not invoked, so the visitor does not descend into the
/// place's components.
pub struct PlaceVisitor<F>(pub F);

impl<'tcx, F: FnMut(&mir::Place<'tcx>)> mir::visit::Visitor<'tcx> for PlaceVisitor<F> {
    fn visit_place(
        &mut self,
        place: &mir::Place<'tcx>,
        _context: mir::visit::PlaceContext,
        _location: mir::Location,
    ) {
        let Self(callback) = self;
        callback(place);
    }
}
/// Syntactic relationship between two places, as computed by
/// [`PlaceExt::simple_overlaps`].
pub enum Overlap<'tcx> {
    /// Same local and identical projections.
    Equal,
    /// Different locals, or the projections disagree on their common prefix.
    Independent,
    /// `self` is a prefix of `other`; holds `other`'s extra projections.
    Parent(&'tcx [mir::PlaceElem<'tcx>]),
    /// `other` is a prefix of `self`; holds `self`'s extra projections.
    Child(&'tcx [mir::PlaceElem<'tcx>]),
}

impl Overlap<'_> {
    /// `true` for [`Overlap::Equal`] and [`Overlap::Parent`].
    pub fn contains_other(self) -> bool {
        !matches!(self, Overlap::Independent | Overlap::Child(_))
    }
}

/// Purely syntactic place comparisons.
pub trait PlaceExt<'tcx> {
    /// Compares the locals and projection lists of `self` and `other`.
    fn simple_overlaps(self, other: Place<'tcx>) -> Overlap<'tcx>;
}

impl<'tcx> PlaceExt<'tcx> for Place<'tcx> {
    fn simple_overlaps(self, other: Place<'tcx>) -> Overlap<'tcx> {
        let ours = self.projection;
        let theirs = other.projection;
        let same_root = self.local == other.local;
        // Only the common prefix is compared; the longer place may extend it.
        if !same_root || ours.iter().zip(theirs).any(|(a, b)| a != b) {
            return Overlap::Independent;
        }
        match ours.len().cmp(&theirs.len()) {
            Ordering::Equal => Overlap::Equal,
            Ordering::Greater => Overlap::Child(&ours[theirs.len()..]),
            Ordering::Less => Overlap::Parent(&theirs[ours.len()..]),
        }
    }
}
/// Extraction of function-like information from HIR nodes.
pub trait NodeExt<'hir> {
    /// Returns the name, owner `LocalDefId` and body of a function item, a
    /// function impl item, or a closure expression (named `"closure"`).
    fn as_fn(&self, tcx: TyCtxt) -> Option<(Ident, hir::def_id::LocalDefId, BodyId)>;
}

impl<'hir> NodeExt<'hir> for hir::Node<'hir> {
    fn as_fn(&self, tcx: TyCtxt) -> Option<(Ident, hir::def_id::LocalDefId, BodyId)> {
        let (name, def_id, body) = match self {
            hir::Node::Item(hir::Item {
                ident,
                owner_id,
                kind: hir::ItemKind::Fn(_, _, body_id),
                ..
            })
            | hir::Node::ImplItem(hir::ImplItem {
                ident,
                owner_id,
                kind: hir::ImplItemKind::Fn(_, body_id),
                ..
            }) => (*ident, owner_id.def_id, *body_id),
            // Closures carry no ident; use a fixed placeholder name and the
            // owner of the closure body.
            hir::Node::Expr(hir::Expr {
                kind: hir::ExprKind::Closure(hir::Closure { body: body_id, .. }),
                ..
            }) => (
                Ident::from_str("closure"),
                tcx.hir().body_owner_def_id(*body_id),
                *body_id,
            ),
            _ => return None,
        };
        Some((name, def_id, body))
    }
}
/// Conversion of various HIR identifiers into a [`LocalDefId`].
pub trait IntoLocalDefId {
    fn into_local_def_id(self, tcx: TyCtxt) -> LocalDefId;
}

impl IntoLocalDefId for LocalDefId {
    /// Identity conversion.
    #[inline]
    fn into_local_def_id(self, _tcx: TyCtxt) -> LocalDefId {
        self
    }
}

impl IntoLocalDefId for BodyId {
    /// The definition that owns the body.
    #[inline]
    fn into_local_def_id(self, tcx: TyCtxt) -> LocalDefId {
        let hir = tcx.hir();
        hir.body_owner_def_id(self)
    }
}

impl IntoLocalDefId for HirId {
    /// Panics (via `expect_owner`) if the id does not denote an owner.
    #[inline]
    fn into_local_def_id(self, _: TyCtxt) -> LocalDefId {
        let owner = self.expect_owner();
        owner.def_id
    }
}

impl<D: Copy + IntoLocalDefId> IntoLocalDefId for &'_ D {
    /// Copies the referenced value and converts it.
    #[inline]
    fn into_local_def_id(self, tcx: TyCtxt) -> LocalDefId {
        D::into_local_def_id(*self, tcx)
    }
}
pub trait ProjectionElemExt {
fn may_be_indirect(self) -> bool;
}
impl<V, T> ProjectionElemExt for ProjectionElem<V, T> {
fn may_be_indirect(self) -> bool {
matches!(
self,
ProjectionElem::Field(..)
| ProjectionElem::Index(..)
| ProjectionElem::ConstantIndex { .. }
| ProjectionElem::Subslice { .. }
)
}
}
pub trait IntoDefId {
fn into_def_id(self, tcx: TyCtxt) -> DefId;
}
impl IntoDefId for DefId {
#[inline]
fn into_def_id(self, _: TyCtxt) -> DefId {
self
}
}
impl IntoDefId for LocalDefId {
#[inline]
fn into_def_id(self, _: TyCtxt) -> DefId {
self.to_def_id()
}
}
impl IntoDefId for HirId {
#[inline]
fn into_def_id(self, tcx: TyCtxt) -> DefId {
self.into_local_def_id(tcx).to_def_id()
}
}
impl<D: Copy + IntoDefId> IntoDefId for &'_ D {
#[inline]
fn into_def_id(self, tcx: TyCtxt) -> DefId {
(*self).into_def_id(tcx)
}
}
impl IntoDefId for BodyId {
#[inline]
fn into_def_id(self, tcx: TyCtxt) -> DefId {
tcx.hir().body_owner_def_id(self).into_def_id(tcx)
}
}
impl IntoDefId for Res {
#[inline]
fn into_def_id(self, _: TyCtxt) -> DefId {
match self {
Res::Def(_, did) => did,
_ => panic!("turning non-def res into DefId; res is: {:?}", self),
}
}
}
/// Builds an interned [`Identifier`] naming the item `did`.
///
/// The item's own name is used when it has one. Otherwise:
/// - opaque types are named `"Opaque"`;
/// - closures are named `<parent>_closure`, or `<parent>_coroutine` when
///   `is_coroutine` holds, where `<parent>` is computed recursively.
///
/// # Panics
/// Panics for unnamed items of any other kind.
pub fn identifier_for_item(tcx: TyCtxt, did: DefId) -> Identifier {
    use hir::def::DefKind;

    let name = if let Some(sym) = tcx.opt_item_name(did) {
        sym.to_string()
    } else {
        match tcx.def_kind(did) {
            DefKind::OpaqueTy => "Opaque".to_string(),
            DefKind::Closure => {
                let suffix = if tcx.is_coroutine(did) {
                    "coroutine"
                } else {
                    "closure"
                };
                let parent = identifier_for_item(tcx, tcx.parent(did));
                format!("{}_{}", parent, suffix)
            }
            _ => panic!(
                "Could not name {} {:?}",
                tcx.def_path_debug_str(did),
                tcx.def_kind(did)
            ),
        }
    };
    Identifier::new_intern(&name)
}
/// Builds a `Vec` of interned `rustc_span::Symbol`s from string expressions,
/// e.g. `sym_vec!["paralegal", "marker"]`.
///
/// A trailing comma is accepted so that multi-line invocations format cleanly.
#[macro_export]
macro_rules! sym_vec {
    ($($e:expr),* $(,)?) => {
        vec![$(rustc_span::Symbol::intern($e)),*]
    };
}
/// Runs `f` with the global `log` maximum level set to `filter`, then restores
/// the previous maximum level and returns `f`'s result.
///
/// Restoration happens in a drop guard. The previous level is therefore reset
/// even if `f` panics, for example when the panic is caught further up with
/// `catch_unwind`. Before this change the temporary level leaked after a panic.
pub fn with_temporary_logging_level<R, F: FnOnce() -> R>(filter: log::LevelFilter, f: F) -> R {
    /// Puts the saved maximum level back when dropped, including during unwinding.
    struct ResetLevel(log::LevelFilter);

    impl Drop for ResetLevel {
        fn drop(&mut self) {
            log::set_max_level(self.0);
        }
    }

    let _reset = ResetLevel(log::max_level());
    log::set_max_level(filter);
    f()
}
/// Runs `f` and returns its result, logging at info level before it starts and
/// again with the elapsed wall-clock time (formatted by `humantime`).
pub fn time<R, F: FnOnce() -> R>(msg: &str, f: F) -> R {
    info!("Starting {msg}");
    let start = std::time::Instant::now();
    let output = f();
    // `elapsed()` stays inside the macro so it is only evaluated when the
    // info level is enabled.
    info!("{msg} took {}", humantime::format_duration(start.elapsed()));
    output
}
/// Anything whose source location can be reported as a [`Span`].
pub trait Spanned<'tcx> {
    /// Returns the span of `self`. `tcx` is available to implementations that
    /// need to query it.
    fn span(&self, tcx: TyCtxt<'tcx>) -> Span;
}

impl<'tcx> Spanned<'tcx> for mir::Terminator<'tcx> {
    fn span(&self, _tcx: TyCtxt<'tcx>) -> Span {
        let mir::SourceInfo { span, .. } = self.source_info;
        span
    }
}

impl<'tcx> Spanned<'tcx> for mir::Statement<'tcx> {
    fn span(&self, _tcx: TyCtxt<'tcx>) -> Span {
        let mir::SourceInfo { span, .. } = self.source_info;
        span
    }
}

impl<'tcx> Spanned<'tcx> for (&mir::Body<'tcx>, mir::Location) {
    /// Span of the statement or terminator at the location (via `Body::stmt_at`).
    fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
        let (body, location) = *self;
        body.stmt_at(location)
            .either(|stmt| stmt.span(tcx), |term| term.span(tcx))
    }
}

impl<'tcx> Spanned<'tcx> for (&mir::Body<'tcx>, RichLocation) {
    /// A concrete location uses its statement/terminator span; the synthetic
    /// start and end locations use the span of the whole body.
    fn span(&self, tcx: TyCtxt<'tcx>) -> RustSpan {
        let body = self.0;
        match &self.1 {
            RichLocation::Start | RichLocation::End => body.span,
            RichLocation::Location(location) => (body, *location).span(tcx),
        }
    }
}

impl<'tcx> Spanned<'tcx> for DefId {
    fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
        let def_id = *self;
        tcx.def_span(def_id)
    }
}
/// Applies `f` to a `Left` value or `g` to a `Right` value, preserving which
/// side the value is on.
pub fn map_either<A, B, C, D>(
    either: Either<A, B>,
    f: impl FnOnce(A) -> C,
    g: impl FnOnce(B) -> D,
) -> Either<C, D> {
    match either {
        Either::Right(right) => Either::Right(g(right)),
        Either::Left(left) => Either::Left(f(left)),
    }
}