use std::{borrow::Cow, collections::VecDeque};
use log::{trace, warn};
use rustc_data_structures::fx::{FxHashMap as HashMap, FxHashSet as HashSet};
use rustc_hir::def_id::DefId;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_middle::{
  mir::{
    visit::{PlaceContext, Visitor},
    Body, HasLocalDecls, Local, Location, Mutability, Place, PlaceElem, PlaceRef,
    ProjectionElem, VarDebugInfo, VarDebugInfoContents, RETURN_PLACE,
  },
  traits::ObligationCause,
  ty::{self, AdtKind, Region, RegionKind, RegionVid, Ty, TyCtxt, TyKind, TypeVisitor},
};
use rustc_target::abi::{FieldIdx, VariantIdx};
use rustc_trait_selection::traits::NormalizeExt;
use rustc_type_ir::TypingMode;
use crate::{AdtDefExt, SpanExt};
/// MIR visitor that records every [`Place`] it is shown, in visitation order.
///
/// The place context and location are ignored; duplicates are kept.
#[derive(Default)]
pub struct PlaceCollector<'tcx>(pub Vec<Place<'tcx>>);

impl<'tcx> Visitor<'tcx> for PlaceCollector<'tcx> {
  fn visit_place(&mut self, place: &Place<'tcx>, _context: PlaceContext, _location: Location) {
    let PlaceCollector(seen) = self;
    seen.push(*place);
  }
}
/// Extension methods for MIR [`Place`]s.
pub trait PlaceExt<'tcx> {
  /// Builds a place from `local` and `projection`, interning the projection
  /// through `tcx`.
  fn make(local: Local, projection: &[PlaceElem<'tcx>], tcx: TyCtxt<'tcx>) -> Self;
  /// Converts a borrowed [`PlaceRef`] into an owned, interned place.
  fn from_ref(place: PlaceRef<'tcx>, tcx: TyCtxt<'tcx>) -> Self;
  /// Builds the place consisting of just `local`, with an empty projection.
  fn from_local(local: Local, tcx: TyCtxt<'tcx>) -> Self;
  /// Returns true if the place's local is one of `body`'s arguments
  /// (locals `1..=body.arg_count`), regardless of its projection.
  fn is_arg(&self, body: &Body<'tcx>) -> bool;
  /// Returns true if the place contains no dereference at all, contains only
  /// `Box` dereferences, or is rooted in an argument local.
  fn is_direct(&self, body: &Body<'tcx>, tcx: TyCtxt<'tcx>) -> bool;
  /// Yields one item per dereference of a non-`Box` pointer in the projection:
  /// the place of the pointer being dereferenced, paired with the projection
  /// elements that follow that dereference.
  fn refs_in_projection(
    self,
    body: &Body<'tcx>,
    tcx: TyCtxt<'tcx>,
  ) -> impl Iterator<Item = (PlaceRef<'tcx>, &'tcx [PlaceElem<'tcx>])>;
  /// Collects the pointers reachable inside this place's type, following
  /// boxes, references and raw pointers, grouped by region.
  ///
  /// Each entry is the place holding the pointer and the mutability it is
  /// reached with: `Not` if the pointer itself or any enclosing type on the
  /// path is a shared reference, `Mut` otherwise. Raw pointers are grouped
  /// under `UNKNOWN_REGION` and `'static` under region 0. Erased and
  /// late-bound-param regions are skipped.
  fn interior_pointers(
    &self,
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    def_id: DefId,
  ) -> HashMap<RegionVid, Vec<(Place<'tcx>, Mutability)>>;
  /// Collects the places reachable inside this place's type without looking
  /// through references. The place itself and any reference-typed sub-places
  /// are included, but not their pointees. Boxes and raw pointers are still
  /// dereferenced. Only struct fields accessible from `def_id` are explored.
  fn interior_places(
    &self,
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    def_id: DefId,
  ) -> HashSet<Place<'tcx>>;
  /// Like `interior_places`, but also follows references into their pointees.
  fn interior_paths(
    &self,
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    def_id: DefId,
  ) -> HashSet<Place<'tcx>>;
  /// Renders the place as source-like text, e.g. `(*p).x` or `y@Some.0`.
  ///
  /// Returns `None` if the local has no usable debug-info name, or if the
  /// projection contains an element this function cannot render.
  fn to_string(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> Option<String>;
  /// Returns a copy of the place with regions erased and types normalized in
  /// `def_id`'s param env. Every `Index`/`ConstantIndex` elem is replaced by
  /// `Index(_0)`, and `Subslice` elems are dropped.
  fn normalize(&self, tcx: TyCtxt<'tcx>, def_id: DefId) -> Place<'tcx>;
  /// Returns true if the place's local is a user variable that comes neither
  /// from compiler desugaring nor from a macro expansion.
  fn is_source_visible(&self, tcx: TyCtxt, body: &Body) -> bool;
}
impl<'tcx> PlaceExt<'tcx> for Place<'tcx> {
  fn make(local: Local, projection: &[PlaceElem<'tcx>], tcx: TyCtxt<'tcx>) -> Self {
    Place {
      local,
      projection: tcx.mk_place_elems(projection),
    }
  }

  fn from_ref(place: PlaceRef<'tcx>, tcx: TyCtxt<'tcx>) -> Self {
    Self::make(place.local, place.projection, tcx)
  }

  fn from_local(local: Local, tcx: TyCtxt<'tcx>) -> Self {
    Place::make(local, &[], tcx)
  }

  fn is_arg(&self, body: &Body<'tcx>) -> bool {
    // Local 0 is the return place; arguments occupy locals 1..=arg_count.
    let i = self.local.as_usize();
    i > 0 && i - 1 < body.arg_count
  }

  fn is_direct(&self, body: &Body<'tcx>, tcx: TyCtxt<'tcx>) -> bool {
    !self.is_indirect()
      || self.is_arg(body)
      || self.refs_in_projection(body, tcx).next().is_none()
  }

  fn refs_in_projection(
    self,
    body: &Body<'tcx>,
    tcx: TyCtxt<'tcx>,
  ) -> impl Iterator<Item = (PlaceRef<'tcx>, &'tcx [PlaceElem<'tcx>])> {
    self
      .projection
      .iter()
      .enumerate()
      // Thread the type of the place *before* each elem through the scan, so
      // every Deref can be checked against the type it dereferences.
      .scan(
        Place::from(self.local).ty(body, tcx),
        move |ty, (i, elem)| {
          let old_ty = *ty;
          *ty = ty.projection_ty(tcx, elem);
          Some((i, elem, old_ty))
        },
      )
      .filter_map(move |(i, elem, ty)| match elem {
        ProjectionElem::Deref => {
          let ptr = PlaceRef {
            local: self.local,
            projection: &self.projection[.. i],
          };
          let after = &self.projection[i + 1 ..];
          // Box dereferences do not count as going through a reference.
          (!ty.ty.is_box()).then_some((ptr, after))
        }
        _ => None,
      })
  }

  fn interior_pointers(
    &self,
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    def_id: DefId,
  ) -> HashMap<RegionVid, Vec<(Place<'tcx>, Mutability)>> {
    let ty = self.ty(body.local_decls(), tcx).ty;
    // NOTE(review): this is a hard-coded `false`, so the visitor always uses
    // `StoppingCondition::None`. The branch is kept because it is the only
    // construction site of `StoppingCondition::AfterRefs` in this file;
    // deleting it would leave that variant unconstructed.
    let stop_at = if false {
      StoppingCondition::AfterRefs
    } else {
      StoppingCondition::None
    };
    let mut region_collector =
      RegionVisitor::<RegionMemberCollector>::new(tcx, def_id, *self, stop_at);
    region_collector.visit_ty(ty);
    region_collector.into_inner().0
  }

  fn interior_places(
    &self,
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    def_id: DefId,
  ) -> HashSet<Place<'tcx>> {
    let ty = self.ty(body.local_decls(), tcx).ty;
    let mut region_collector = RegionVisitor::<VisitedPlacesCollector>::new(
      tcx,
      def_id,
      *self,
      StoppingCondition::BeforeRefs,
    );
    region_collector.visit_ty(ty);
    region_collector.into_inner().0
  }

  fn interior_paths(
    &self,
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    def_id: DefId,
  ) -> HashSet<Place<'tcx>> {
    let ty = self.ty(body.local_decls(), tcx).ty;
    let mut region_collector = RegionVisitor::<VisitedPlacesCollector>::new(
      tcx,
      def_id,
      *self,
      StoppingCondition::None,
    );
    region_collector.visit_ty(ty);
    region_collector.into_inner().0
  }

  fn to_string(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> Option<String> {
    let local_name = if self.local == RETURN_PLACE {
      Cow::Borrowed("RETURN")
    } else {
      // Use the debug-info name of this local, provided its span can be
      // mapped into the body's span via `as_local`.
      let get_local_name = |info: &VarDebugInfo<'tcx>| match info.value {
        VarDebugInfoContents::Place(place) if place.local == self.local => info
          .source_info
          .span
          .as_local(body.span)
          .map(|_| info.name.to_string()),
        _ => None,
      };
      let local_name = body.var_debug_info.iter().find_map(get_local_name)?;
      Cow::Owned(local_name)
    };

    /// Whether a rendered projection elem is written before (`*`) or after
    /// (`.f`, `@V`, `[_]`) the text rendered so far.
    #[derive(Copy, Clone)]
    enum ElemPosition {
      Prefix,
      Suffix,
    }

    // `place` is the base the elem is applied to, i.e. everything before it.
    let elem_to_string = |(place, elem): (PlaceRef<'tcx>, PlaceElem<'tcx>)|
     -> Option<(ElemPosition, Cow<'static, str>)> {
      Some(match elem {
        ProjectionElem::Deref => (ElemPosition::Prefix, "*".into()),
        ProjectionElem::Field(field, _) => {
          let ty = place.ty(&body.local_decls, tcx).ty;
          let field_name = match ty.kind() {
            TyKind::Adt(def, _substs) => {
              let fields = match def.adt_kind() {
                AdtKind::Struct => &def.non_enum_variant().fields,
                AdtKind::Enum => {
                  // An enum field is expected to be projected right after a
                  // downcast to its variant, i.e. the last elem of `place`.
                  // Anything else (including a field at the very start of the
                  // projection) cannot be rendered, so report it like every
                  // other unsupported case instead of panicking.
                  let Some(PlaceElem::Downcast(_, variant_idx)) = place.projection.last()
                  else {
                    log::debug!("place::to_string: enum field without a preceding downcast");
                    return None;
                  };
                  &def.variant(*variant_idx).fields
                }
                kind => {
                  log::debug!("place::to_string unimplemented for {kind:?}");
                  return None;
                }
              };
              fields[field].ident(tcx).to_string()
            }
            TyKind::Tuple(_) => field.as_usize().to_string(),
            TyKind::Closure(def_id, _substs) => match def_id.as_local() {
              Some(local_def_id) => {
                let captures = tcx.closure_captures(local_def_id);
                captures[field.as_usize()].var_ident.to_string()
              }
              None => field.as_usize().to_string(),
            },
            kind => {
              log::debug!("place::to_string unimplemented for {kind:?}");
              return None;
            }
          };
          (ElemPosition::Suffix, format!(".{field_name}").into())
        }
        ProjectionElem::Downcast(sym, _) => {
          let variant = sym.map(|s| s.to_string()).unwrap_or_else(|| "??".into());
          (ElemPosition::Suffix, format!("@{variant}").into())
        }
        ProjectionElem::Index(_) => (ElemPosition::Suffix, "[_]".into()),
        kind => {
          log::debug!("place::to_string unimplemented for {kind:?}");
          return None;
        }
      })
    };

    let elems = self
      .iter_projections()
      .map(elem_to_string)
      .collect::<Option<Vec<_>>>()?;

    // Prefix elems wrap everything rendered so far. When a suffix elem
    // follows a prefix, parenthesize so that `(*p).x` is not read as `*(p.x)`.
    let mut parts = VecDeque::from([local_name]);
    let mut elems = elems.into_iter().peekable();
    while let Some((position, string)) = elems.next() {
      match position {
        ElemPosition::Prefix => {
          parts.push_front(string);
          if matches!(elems.peek(), Some((ElemPosition::Suffix, _))) {
            parts.push_front(Cow::Borrowed("("));
            parts.push_back(Cow::Borrowed(")"));
          }
        }
        ElemPosition::Suffix => parts.push_back(string),
      }
    }
    Some(parts.make_contiguous().join(""))
  }

  fn normalize(&self, tcx: TyCtxt<'tcx>, def_id: DefId) -> Place<'tcx> {
    let param_env = tcx.param_env(def_id);
    let place = tcx.erase_regions(*self);
    let typing_mode = TypingMode::PostAnalysis;
    let infcx = tcx.infer_ctxt().build(typing_mode);
    let place = infcx
      .at(&ObligationCause::dummy(), param_env)
      .normalize(place)
      .value;
    // Collapse all indexing onto one canonical index and drop subslices.
    let projection = place
      .projection
      .into_iter()
      .filter_map(|elem| match elem {
        ProjectionElem::Index(_) | ProjectionElem::ConstantIndex { .. } => {
          Some(ProjectionElem::Index(Local::from_usize(0)))
        }
        ProjectionElem::Subslice { .. } => None,
        _ => Some(elem),
      })
      .collect::<Vec<_>>();
    Place::make(place.local, &projection, tcx)
  }

  fn is_source_visible(&self, _tcx: TyCtxt, body: &Body) -> bool {
    let local = self.local;
    let local_info = &body.local_decls[local];
    let is_loc = local_info.is_user_variable();
    let from_desugaring = local_info.from_compiler_desugaring();
    let from_expansion = local_info.source_info.span.from_expansion();
    is_loc && !from_desugaring && !from_expansion
  }
}
/// Controls how [`RegionVisitor`] treats reference types (`&T` / `&mut T`).
///
/// Boxes and raw pointers are always followed, whatever this setting says.
#[derive(Clone, Copy)]
enum StoppingCondition {
  /// Record the reference's region, then keep exploring its pointee.
  None,
  /// Skip references entirely: neither the region nor the pointee is visited.
  BeforeRefs,
  /// Record the reference's region, but do not descend into the pointee.
  AfterRefs,
}
/// Event sink for `RegionVisitor`. Every callback defaults to a no-op, so a
/// collector overrides only the events it cares about.
trait RegionVisitorDispatcher<'tcx> {
  /// Called with each place the visitor finishes exploring, and with the
  /// dereference of each place whose region is visited.
  fn on_visit_place(&mut self, _: Place<'tcx>) {}
  /// Called with each type the visitor finishes exploring.
  fn on_visit_type(&mut self, _: Ty<'tcx>) {}
  /// Called for each visited region: its id, the place holding the pointer,
  /// and the mutability the pointer is reached with.
  fn on_visit_region_member(&mut self, _: RegionVid, _: Place<'tcx>, _: Mutability) {}
}
/// Dispatcher that gathers the set of every place reported to it.
#[derive(Default)]
struct VisitedPlacesCollector<'tcx>(HashSet<Place<'tcx>>);

impl<'tcx> RegionVisitorDispatcher<'tcx> for VisitedPlacesCollector<'tcx> {
  fn on_visit_place(&mut self, place: Place<'tcx>) {
    let Self(places) = self;
    places.insert(place);
  }
}
/// Dispatcher that gathers the set of every type reported to it.
#[derive(Default)]
struct VisitedTypesCollector<'tcx>(HashSet<Ty<'tcx>>);

impl<'tcx> RegionVisitorDispatcher<'tcx> for VisitedTypesCollector<'tcx> {
  fn on_visit_type(&mut self, ty: Ty<'tcx>) {
    let Self(types) = self;
    types.insert(ty);
  }
}
/// Dispatcher that groups every reported `(place, mutability)` pair by region.
#[derive(Default)]
struct RegionMemberCollector<'tcx>(HashMap<RegionVid, Vec<(Place<'tcx>, Mutability)>>);

impl<'tcx> RegionVisitorDispatcher<'tcx> for RegionMemberCollector<'tcx> {
  fn on_visit_region_member(&mut self, key: RegionVid, place: Place<'tcx>, mutbl: Mutability) {
    let members = self.0.entry(key).or_insert_with(Vec::new);
    members.push((place, mutbl));
  }
}
/// Type visitor that walks a place's type, building up the projection of each
/// sub-place it reaches and reporting places, types and regions to
/// `dispatcher`.
struct RegionVisitor<'tcx, Dispatcher> {
  tcx: TyCtxt<'tcx>,
  // Item whose perspective decides which struct fields are visible.
  def_id: DefId,
  // Root local of every place the visitor constructs.
  local: Local,
  // Projection of the place currently being explored; pushed and popped as
  // the visitor descends into and returns from sub-types.
  place_stack: Vec<PlaceElem<'tcx>>,
  // Types on the current exploration path. Used to avoid re-entering a type
  // (cycles) and to decide whether a region is reached through a shared ref.
  ty_stack: Vec<Ty<'tcx>>,
  dispatcher: Dispatcher,
  // How reference types are treated.
  stop_at: StoppingCondition,
}
impl<'tcx, Dispatcher: Default> RegionVisitor<'tcx, Dispatcher> {
  fn new(
    tcx: TyCtxt<'tcx>,
    def_id: DefId,
    place: Place<'tcx>,
    stop_at: StoppingCondition,
  ) -> Self {
    Self {
      tcx,
      def_id,
      local: place.local,
      place_stack: place.projection.to_vec(),
      ty_stack: Vec::new(),
      dispatcher: Default::default(),
      stop_at,
    }
  }
  fn into_inner(self) -> Dispatcher {
    self.dispatcher
  }
}
/// Placeholder region id under which raw pointers are reported, since a raw
/// pointer carries no region of its own.
pub const UNKNOWN_REGION: RegionVid = RegionVid::MAX;
impl<'tcx, Dispatcher: RegionVisitorDispatcher<'tcx>> TypeVisitor<TyCtxt<'tcx>>
  for RegionVisitor<'tcx, Dispatcher>
{
  fn visit_ty(&mut self, ty: Ty<'tcx>) {
    let tcx = self.tcx;
    if self.ty_stack.iter().any(|visited_ty| ty == *visited_ty) {
      return;
    }
    trace!(
      "exploring {:?} with {ty:?}",
      Place::make(self.local, &self.place_stack, tcx)
    );
    self.ty_stack.push(ty);
    match ty.kind() {
      _ if ty.is_box() => {
        self.place_stack.push(ProjectionElem::Deref);
        self.visit_ty(ty.boxed_ty().expect("Cannot unbox boxed type??"));
        self.place_stack.pop();
      }
      TyKind::Tuple(fields) => {
        for (i, field) in fields.iter().enumerate() {
          self
            .place_stack
            .push(ProjectionElem::Field(FieldIdx::from_usize(i), field));
          self.visit_ty(field);
          self.place_stack.pop();
        }
      }
      TyKind::Adt(adt_def, subst) => match adt_def.adt_kind() {
        ty::AdtKind::Struct => {
          for (i, field) in adt_def.all_visible_fields(self.def_id, tcx).enumerate() {
            let ty = field.ty(tcx, subst);
            self
              .place_stack
              .push(ProjectionElem::Field(FieldIdx::from_usize(i), ty));
            self.visit_ty(ty);
            self.place_stack.pop();
          }
        }
        ty::AdtKind::Union => {
          }
        ty::AdtKind::Enum => {
          for (i, variant) in adt_def.variants().iter().enumerate() {
            let variant_index = VariantIdx::from_usize(i);
            let cast = PlaceElem::Downcast(
              Some(adt_def.variant(variant_index).ident(tcx).name),
              variant_index,
            );
            self.place_stack.push(cast);
            for (j, field) in variant.fields.iter().enumerate() {
              let ty = field.ty(tcx, subst);
              let field = ProjectionElem::Field(FieldIdx::from_usize(j), ty);
              self.place_stack.push(field);
              self.visit_ty(ty);
              self.place_stack.pop();
            }
            self.place_stack.pop();
          }
        }
      },
      TyKind::Array(elem_ty, _) | TyKind::Slice(elem_ty) => {
        self
          .place_stack
          .push(ProjectionElem::Index(Local::from_usize(0)));
        self.visit_ty(*elem_ty);
        self.place_stack.pop();
      }
      TyKind::Ref(region, elem_ty, _) => match self.stop_at {
        StoppingCondition::None => {
          self.visit_region(*region);
          self.place_stack.push(ProjectionElem::Deref);
          self.visit_ty(*elem_ty);
          self.place_stack.pop();
        }
        StoppingCondition::AfterRefs => {
          self.visit_region(*region);
        }
        StoppingCondition::BeforeRefs => {}
      },
      TyKind::Closure(_, substs) | TyKind::Coroutine(_, substs) => {
        self.visit_ty(substs.as_closure().tupled_upvars_ty());
      }
      TyKind::RawPtr(ty, _) => {
        self.visit_region(Region::new_var(tcx, UNKNOWN_REGION));
        self.place_stack.push(ProjectionElem::Deref);
        self.visit_ty(*ty);
        self.place_stack.pop();
      }
      TyKind::FnDef(..)
      | TyKind::FnPtr(..)
      | TyKind::Foreign(..)
      | TyKind::Dynamic(..)
      | TyKind::Param(..)
      | TyKind::Never => {}
      _ if ty.is_primitive_ty() => {}
      _ => warn!("unimplemented {ty:?} ({:?})", ty.kind()),
    };
    self
      .dispatcher
      .on_visit_place(Place::make(self.local, &self.place_stack, tcx));
    self.dispatcher.on_visit_type(ty);
    self.ty_stack.pop();
  }
  fn visit_region(&mut self, region: ty::Region<'tcx>) -> Self::Result {
    trace!("visiting region {region:?}");
    let region = match region.kind() {
      RegionKind::ReVar(region) => region,
      RegionKind::ReStatic => RegionVid::from_usize(0),
      RegionKind::ReErased | RegionKind::ReLateParam(_) => {
        return;
      }
      _ => unreachable!("{:?}: {:?}", self.ty_stack.first().unwrap(), region),
    };
    let mutability = if self
      .ty_stack
      .iter()
      .any(|ty| ty.is_ref() && ty.ref_mutability().unwrap() == Mutability::Not)
    {
      Mutability::Not
    } else {
      Mutability::Mut
    };
    let place = Place::make(self.local, &self.place_stack, self.tcx);
    self
      .dispatcher
      .on_visit_region_member(region, place, mutability);
    self
      .dispatcher
      .on_visit_place(self.tcx.mk_place_deref(place));
  }
}
#[cfg(test)]
mod test {
  use rustc_borrowck::consumers::BodyWithBorrowckFacts;
  use rustc_hir::BodyId;
  use rustc_middle::{
    mir::{Place, PlaceElem},
    ty::TyCtxt,
  };
  use super::PlaceExt;
  use crate::{
    mir::body::BodyExt,
    test_utils::{self, compare_sets, CompileResult, Placer},
  };
  #[test]
  fn test_place_arg_direct() {
    let input = r"
fn foobar(x: &i32) {
  let y = 1;
  let z = &y;
  let k = Box::new(*x);
  let ref_k = &k;
  let box_ref = Box::new(x);
}
";
    test_utils::compile_body(input, |tcx, _, body_with_facts| {
      let body = &body_with_facts.body;
      let name_map = body.debug_info_name_map();
      let x = Place::from_local(name_map["x"], tcx);
      assert!(x.is_arg(body));
      assert!(x.is_direct(body, tcx));
      assert!(Place::make(x.local, &[PlaceElem::Deref], tcx).is_direct(body, tcx));
      let y = Place::from_local(name_map["y"], tcx);
      assert!(!y.is_arg(body));
      assert!(y.is_direct(body, tcx));
      let z = Place::from_local(name_map["z"], tcx);
      assert!(!z.is_arg(body));
      assert!(z.is_direct(body, tcx));
      assert!(!Place::make(z.local, &[PlaceElem::Deref], tcx).is_direct(body, tcx));
      let k = Place::from_local(name_map["k"], tcx);
      assert!(!k.is_arg(body));
      assert!(k.is_direct(body, tcx));
      assert!(Place::make(k.local, &[PlaceElem::Deref], tcx).is_direct(body, tcx));
      let deref_k = Place::make(k.local, &[PlaceElem::Deref], tcx);
      assert!(deref_k.is_direct(body, tcx));
      assert!(!deref_k.is_arg(body));
      assert_eq!(deref_k.refs_in_projection(body, tcx).count(), 0);
      let ref_k = Place::from_local(name_map["ref_k"], tcx);
      assert!(!ref_k.is_arg(body));
      // The reference local itself (no projection) is direct.
      assert!(ref_k.is_direct(body, tcx));
      let deref_ref_k =
        Place::make(ref_k.local, &[PlaceElem::Deref, PlaceElem::Deref], tcx);
      assert!(deref_ref_k.is_direct(body, tcx));
      assert_eq!(deref_ref_k.refs_in_projection(body, tcx).count(), 1);
      let box_ref = Place::from_local(name_map["box_ref"], tcx);
      assert!(!box_ref.is_arg(body));
      assert!(!box_ref.is_indirect());
      let box_ref_deref = Place::make(box_ref.local, &[PlaceElem::Deref], tcx);
      assert_eq!(box_ref_deref.refs_in_projection(body, tcx).count(), 0);
      assert!(box_ref_deref.is_direct(body, tcx));
      let box_ref_deref_deref = box_ref_deref.project_deeper(&[PlaceElem::Deref], tcx);
      assert_eq!(box_ref_deref_deref.refs_in_projection(body, tcx).count(), 1);
      assert!(!box_ref_deref_deref.is_direct(body, tcx));
    });
  }
  #[test]
  fn test_place_to_string() {
    let input = r"
struct Point { x: usize, y: usize }
fn main() {
  let x = (0, 0);
  let y = Some(1);
  let z = &[Some((0, 1))];
  let w = (&y,);
  let p = &Point { x: 0, y: 0 };
}";
    test_utils::compile_body(input, |tcx, _, body_with_facts| {
      let body = &body_with_facts.body;
      let p = Placer::new(tcx, body);
      let x = p.local("x").mk();
      let x_1 = p.local("x").field(1).mk();
      let y_some_0 = p.local("y").downcast(1).field(0).mk();
      let z_deref_some_0_1 = p
        .local("z")
        .deref()
        .index(0)
        .downcast(1)
        .field(0)
        .field(1)
        .mk();
      let w_0_deref = p.local("w").field(0).deref().mk();
      let w_0_deref_some = p.local("w").field(0).deref().downcast(1).mk();
      let p_deref_x = p.local("p").deref().field(0).mk();
      let tests = [
        (x, "x"),
        (x_1, "x.1"),
        (y_some_0, "y@Some.0"),
        (z_deref_some_0_1, "(*z)[_]@Some.0.1"),
        (w_0_deref, "*w.0"),
        (w_0_deref_some, "(*w.0)@Some"),
        (p_deref_x, "(*p).x"),
      ];
      for (place, expected) in tests {
        assert_eq!(place.to_string(tcx, body).unwrap(), expected);
      }
    });
  }
  #[test]
  fn test_place_visitors() {
    let input = r"
fn main() {
  let x = 0;
  let y = (0, &x);
}
";
    fn callback<'tcx>(
      tcx: TyCtxt<'tcx>,
      body_id: BodyId,
      body_with_facts: &BodyWithBorrowckFacts<'tcx>,
    ) {
      let body = &body_with_facts.body;
      let def_id = tcx.hir().body_owner_def_id(body_id).to_def_id();
      let p = Placer::new(tcx, body);
      let y = p.local("y").mk();
      let y0 = p.local("y").field(0).mk();
      let y1 = p.local("y").field(1).mk();
      let y1_deref = p.local("y").field(1).deref().mk();
      compare_sets(y.interior_paths(tcx, body, def_id), [y, y0, y1, y1_deref]);
      compare_sets(y.interior_places(tcx, body, def_id), [y, y0, y1]);
      compare_sets(
        y.interior_pointers(tcx, body, def_id)
          .into_values()
          .flat_map(|vs| vs.into_iter().map(|(p, _)| p)),
        [y1],
      );
    }
    test_utils::compile_body(input, callback);
  }
}