use std::{fmt::Display, rc::Rc};
use flowistry_pdg_construction::CallInfo;
use paralegal_spdg::{utils::write_sep, Identifier};
use rustc_hir::def_id::{CrateNum, DefId};
use rustc_middle::ty::{
AssocKind, BoundVariableKind, Clause, ClauseKind, Instance, ProjectionPredicate,
TraitPredicate, TypingEnv,
};
use rustc_span::Span;
use rustc_type_ir::{PredicatePolarity, TyKind};
use crate::{
ana::Print,
args::{InliningDepth, Stub},
MarkerCtx, Pctx,
};
/// Key type carried by [`CallInfo`] and compared against the configured
/// inlining depth in [`InlineJudge::should_inline`]. Presumably this is the
/// current inlining depth of the call (TODO confirm against the PDG
/// construction code that produces it).
pub type K = u32;

/// Decides, per call site, whether PDG construction should inline the
/// callee, substitute a stub, or approximate the call.
pub struct InlineJudge<'tcx> {
    /// Predicate selecting the crates whose items may be inlined; calls into
    /// crates for which it returns `false` are approximated instead.
    included_crates: Rc<dyn Fn(CrateNum) -> bool>,
    /// Shared analysis context (options, type context, marker information).
    ctx: Pctx<'tcx>,
}
/// Outcome of [`InlineJudge::should_inline`] for a single call site.
#[derive(strum::AsRefStr)]
pub enum InlineJudgement {
    /// Inline the callee.
    ///
    /// The flag is `true` exactly when the decision was bounded by a finite
    /// depth limit (the call was still below it) and `false` for
    /// unconstrained inlining or adaptive inlining of marker-reaching callees.
    /// NOTE(review): presumably tells the caller whether to advance the depth
    /// key — confirm with the consumer of this value.
    Inline(bool),
    /// Replace the call with the given stub model.
    UseStub(&'static Stub),
    /// Do not inline; approximate the call instead. The payload is a
    /// human-readable reason that appears in `Display` output and in the
    /// approximation-safety diagnostics.
    AbstractViaType(&'static str),
}

impl Display for InlineJudgement {
    /// Writes the variant name, followed by the parenthesized reason for
    /// `AbstractViaType` (for example `AbstractViaType(marked)`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant: &str = self.as_ref();
        match self {
            Self::AbstractViaType(reason) => write!(f, "{variant}({reason})"),
            _ => f.write_str(variant),
        }
    }
}
impl<'tcx> InlineJudge<'tcx> {
    /// Creates a judge for `ctx`, using the crate-inclusion predicate from the
    /// analysis-control options.
    pub fn new(ctx: Pctx<'tcx>) -> Self {
        let predicate = ctx.opts().anactrl().inclusion_predicate(ctx.tcx());
        let included_crates = Rc::new(predicate);
        Self {
            included_crates,
            ctx,
        }
    }

    /// Returns whether crate `c` is selected for inlining by the configured
    /// inclusion predicate.
    pub fn is_included(&self, c: CrateNum) -> bool {
        let predicate: &dyn Fn(CrateNum) -> bool = &*self.included_crates;
        predicate(c)
    }

    /// Decides how the call described by `info` should be handled.
    ///
    /// A stub on the target takes precedence over everything else. Otherwise
    /// the call is approximated if the target's crate is excluded, or the
    /// target is a foreign item, a constructor, or marked. Failing all of
    /// those, the configured [`InliningDepth`] decides.
    ///
    /// Whenever the result is [`InlineJudgement::AbstractViaType`], except for
    /// the "async parent of stub" early return, the call is also checked with
    /// [`Self::ensure_is_safe_to_approximate`]. Problems are reported as
    /// errors unless the target is marked or relaxed mode is on, in which
    /// case they are warnings.
    ///
    /// # Panics
    ///
    /// Panics if, under a finite depth limit, the call's depth key is already
    /// greater than the limit.
    pub fn should_inline(&self, info: &CallInfo<'tcx, '_, K>) -> InlineJudgement {
        // Stub and marker lookups use the `async fn` parent when the call has
        // one, and the callee itself otherwise.
        let target = info.async_parent.unwrap_or(info.callee);
        let target_id = target.def_id();
        let markers = self.marker_ctx();

        if let Some(stub) = markers.has_stub(target_id) {
            // A stub registered on the async parent is not applied to this
            // call; the call is approximated without a safety check.
            if info.async_parent.is_some() {
                return InlineJudgement::AbstractViaType("async parent of stub");
            }
            return InlineJudgement::UseStub(stub);
        }

        let is_marked = markers.is_marked(target_id);
        let tcx = self.ctx.tcx();
        let depth = self.ctx.opts().anactrl().inlining_depth();

        let judgement = if !self.is_included(target_id.krate) {
            InlineJudgement::AbstractViaType("inlining for crate disabled")
        } else if tcx.is_foreign_item(target_id) {
            InlineJudgement::AbstractViaType("cannot inline foreign item")
        } else if tcx.is_constructor(target_id) {
            InlineJudgement::AbstractViaType("is constructor")
        } else if is_marked {
            InlineJudgement::AbstractViaType("marked")
        } else {
            match depth {
                InliningDepth::Unconstrained => InlineJudgement::Inline(false),
                // Adaptive mode inlines any callee from which markers are
                // reachable, regardless of the depth limit.
                InliningDepth::Adaptive(_)
                    if markers.has_transitive_reachable_markers(target) =>
                {
                    InlineJudgement::Inline(false)
                }
                InliningDepth::Adaptive(k) if *k == 0 => {
                    InlineJudgement::AbstractViaType("adaptive inlining")
                }
                InliningDepth::Adaptive(k) if info.cache_key == k => {
                    InlineJudgement::AbstractViaType("adaptive inlining, k-depth reached")
                }
                InliningDepth::K(k) if *k == 0 => {
                    InlineJudgement::AbstractViaType("shallow inlining configured")
                }
                InliningDepth::K(k) if info.cache_key == k => {
                    InlineJudgement::AbstractViaType("k-depth reached")
                }
                // Below the limit under either bounded mode: inline.
                InliningDepth::Adaptive(k) | InliningDepth::K(k) => {
                    assert!(
                        info.cache_key < k,
                        "cache key {} is greater than k {k}",
                        info.cache_key,
                    );
                    InlineJudgement::Inline(true)
                }
            }
        };

        if let InlineJudgement::AbstractViaType(reason) = judgement {
            // Marked targets and relaxed mode downgrade problems to warnings.
            let emit_err = !is_marked && !self.ctx.opts().relaxed();
            // The safety check inspects the callee itself, not the async parent.
            self.ensure_is_safe_to_approximate(
                info.param_env,
                info.callee,
                info.span,
                emit_err,
                reason,
            );
        }
        judgement
    }

    /// Shorthand for the marker context of the analysis.
    fn marker_ctx(&self) -> &MarkerCtx<'tcx> {
        self.ctx.marker_ctx()
    }

    /// Reports (as errors if `emit_err`, else as warnings) every reason why
    /// approximating the call to `resolved` at `call_span` could hide marked
    /// behavior. `reason` explains why the approximation was chosen and is
    /// quoted in each diagnostic.
    pub fn ensure_is_safe_to_approximate(
        &self,
        typing_env: TypingEnv<'tcx>,
        resolved: Instance<'tcx>,
        call_span: Span,
        emit_err: bool,
        reason: &'static str,
    ) {
        let checker = SafetyChecker {
            ctx: self.ctx.clone(),
            call_span,
            emit_err,
            reason,
            resolved,
            typing_env,
        };
        checker.check();
    }
}
struct SafetyChecker<'tcx> {
ctx: Pctx<'tcx>,
emit_err: bool,
typing_env: TypingEnv<'tcx>,
resolved: Instance<'tcx>,
call_span: Span,
reason: &'static str,
}
impl<'tcx> SafetyChecker<'tcx> {
fn err(&self, s: &str, span: Span) {
let sess = self.ctx.tcx().dcx();
let msg = format!(
"the call to {:?} is not safe to abstract as demanded by '{}', because of: {s}",
self.resolved, self.reason
);
if self.emit_err {
let mut diagnostic = sess.struct_span_err(span, msg);
diagnostic.span_note(self.call_span, "Called from here");
diagnostic.emit();
} else {
let mut diagnostic = sess.struct_span_warn(span, msg);
diagnostic.span_note(self.call_span, "Called from here");
diagnostic.emit();
}
}
fn err_markers(&self, s: &str, markers: &[Identifier], span: Span) {
if !markers.is_empty() {
self.err(
&format!(
"{s}: found marker(s) {}",
Print(|fmt| write_sep(fmt, ", ", markers, |elem, fmt| write!(fmt, "'{elem}'")))
),
span,
);
}
}
fn check_projection_predicate(&self, predicate: &ProjectionPredicate<'tcx>, span: Span) {
if let Some(t) = predicate.term.as_type() {
let t = self.ctx.tcx().normalize_erasing_regions(self.typing_env, t);
let markers = self.ctx.marker_ctx().deep_type_markers(t);
if !markers.is_empty() {
let markers = markers.iter().map(|t| t.1).collect::<Box<_>>();
self.err_markers(
&format!("type {t:?} is not approximation safe"),
&markers,
span,
);
}
}
}
fn check_trait_predicate(&self, predicate: &TraitPredicate<'tcx>, span: Span) {
let tcx = self.ctx.tcx();
let TraitPredicate {
polarity: PredicatePolarity::Positive,
trait_ref,
} = predicate
else {
return;
};
if tcx.trait_is_auto(trait_ref.def_id) {
return;
}
let Some(self_ty) = trait_ref.args[0].as_type() else {
self.err("expected self type to be type, got {ref_1:?}", span);
return;
};
if tcx.is_fn_trait(trait_ref.def_id) {
let instance = match self_ty.kind() {
TyKind::Closure(id, args) | TyKind::FnDef(id, args) => {
Instance::expect_resolve(tcx, TypingEnv::fully_monomorphized(), *id, args, span)
}
TyKind::FnPtr(..) => {
self.err(&format!("unresolvable function pointer {self_ty:?}"), span);
return;
}
_ => {
self.err(
&format!(
"fn-trait instance for {self_ty:?} not being a function or closure"
),
span,
);
return;
}
};
let markers = self.ctx.marker_ctx().get_reachable_markers(instance);
if !markers.is_empty() {
self.err_markers(
&format!("closure {instance:?} is not approximation safe"),
markers,
span,
);
}
} else {
tcx.for_each_relevant_impl(trait_ref.def_id, self_ty, |r#impl| {
self.check_impl(r#impl, span)
})
}
}
fn check_impl(&self, r#impl: DefId, span: Span) {
for item in self
.ctx
.tcx()
.associated_items(r#impl)
.in_definition_order()
{
match item.kind {
AssocKind::Fn => {
let method = item.def_id;
let markers = self.ctx.marker_ctx().get_reachable_markers(method);
if !markers.is_empty() {
self.err_markers(&self.ctx.tcx().def_path_str(method), markers, span)
}
}
AssocKind::Const | AssocKind::Type => (),
}
}
}
fn check_predicate(&self, clause: Clause<'tcx>, span: Span) {
let kind = clause.kind();
for bound in kind.bound_vars() {
match bound {
BoundVariableKind::Ty(t) => self.err(&format!("bound type {t:?}"), span),
BoundVariableKind::Const | BoundVariableKind::Region(_) => (),
}
}
match &kind.skip_binder() {
ClauseKind::TypeOutlives(_)
| ClauseKind::WellFormed(_)
| ClauseKind::ConstArgHasType(..)
| ClauseKind::ConstEvaluatable(_)
| ClauseKind::HostEffect(_)
| ClauseKind::RegionOutlives(_) => {
}
ClauseKind::Projection(predicate) => self.check_projection_predicate(predicate, span),
ClauseKind::Trait(predicate) => self.check_trait_predicate(predicate, span),
}
}
fn check(&self) {
let tcx = self.ctx.tcx();
tcx.predicates_of(self.resolved.def_id())
.instantiate(tcx, self.resolved.args)
.into_iter()
.for_each(|(clause, span)| self.check_predicate(clause, span));
}
}