use itertools::Itertools;
use log::debug;
use rustc_middle::{
  mir::{visit::Visitor, *},
  ty::{AdtKind, TyKind},
};
use rustc_target::abi::FieldIdx;
use rustc_utils::{mir::place::PlaceCollector, AdtDefExt, OperandExt, PlaceExt};
use crate::mir::{
  placeinfo::PlaceInfo,
  utils::{self, AsyncHack},
};
/// How certain the visitor is that a reported mutation actually happens.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MutationStatus {
  /// The place is certainly written, e.g. the target of an assignment or
  /// the destination of a call.
  Definitely,
  /// The place may be written, e.g. a place reachable through a mutable
  /// reference passed as a call argument.
  Possibly,
}
/// Why a place is reported as mutated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Reason {
  /// The place is reachable through a call argument; the payload is the
  /// position of that argument in the list of argument places the visitor
  /// considered for the call.
  Argument(u8),
  /// The place is (part of) the left-hand side of an assignment or the
  /// destination of a call.
  AssignTarget,
}
/// A single mutation reported by [`ModularMutationVisitor`]: one place that
/// is written together with the places its new value is derived from.
#[derive(Debug)]
pub struct Mutation<'tcx> {
  /// The place being written.
  pub mutated: Place<'tcx>,
  /// Why the place is considered mutated (assignment target or call argument).
  pub reason: Reason,
  /// The places that flow into `mutated`. May be empty, e.g. when an
  /// aggregate field is initialized from a non-place operand.
  pub inputs: Vec<Place<'tcx>>,
  /// Whether the mutation is certain or only possible.
  pub status: MutationStatus,
}
/// MIR visitor that reports the mutations performed by every visited
/// assignment and call terminator to a user-supplied callback.
pub struct ModularMutationVisitor<'a, 'tcx, F>
where
  F: FnMut(Location, Vec<Mutation<'tcx>>),
{
  /// Callback invoked once per visited assignment / call with the location
  /// and the list of mutations found there.
  f: F,
  /// Source of the body, type context, owning definition and
  /// reachable-value queries used while visiting.
  place_info: &'a PlaceInfo<'tcx>,
}
impl<'a, 'tcx, F> ModularMutationVisitor<'a, 'tcx, F>
where
  F: FnMut(Location, Vec<Mutation<'tcx>>),
{
  pub fn new(place_info: &'a PlaceInfo<'tcx>, f: F) -> Self {
    ModularMutationVisitor { place_info, f }
  }
}
impl<'tcx, F> Visitor<'tcx> for ModularMutationVisitor<'_, 'tcx, F>
where
  F: FnMut(Location, Vec<Mutation<'tcx>>),
{
  fn visit_assign(
    &mut self,
    mutated: &Place<'tcx>,
    rvalue: &Rvalue<'tcx>,
    location: Location,
  ) {
    debug!("Checking {location:?}: {mutated:?} = {rvalue:?}");
    let body = self.place_info.body;
    let tcx = self.place_info.tcx;
    match rvalue {
      Rvalue::Aggregate(agg_kind, ops) => {
        let info = match &**agg_kind {
          AggregateKind::Adt(def_id, idx, substs, _, _) => {
            let adt_def = tcx.adt_def(*def_id);
            let variant = adt_def.variant(*idx);
            let mutated = match adt_def.adt_kind() {
              AdtKind::Enum => mutated.project_deeper(
                &[ProjectionElem::Downcast(Some(variant.name), *idx)],
                tcx,
              ),
              AdtKind::Struct | AdtKind::Union => *mutated,
            };
            let fields = variant.fields.iter();
            let tys = fields
              .map(|field| field.ty(tcx, substs))
              .collect::<Vec<_>>();
            Some((mutated, tys))
          }
          AggregateKind::Tuple => {
            let ty = rvalue.ty(body.local_decls(), tcx);
            Some((*mutated, ty.tuple_fields().to_vec()))
          }
          AggregateKind::Closure(_, args) => {
            let ty = args.as_closure().upvar_tys();
            Some((*mutated, ty.to_vec()))
          }
          _ => None,
        };
        if let Some((mutated, tys)) = info {
          if tys.len() > 0 {
            let fields =
              tys
                .into_iter()
                .enumerate()
                .zip(ops.iter())
                .map(|((i, ty), input_op)| {
                  let field = PlaceElem::Field(FieldIdx::from_usize(i), ty);
                  let input_place = input_op.as_place();
                  (mutated.project_deeper(&[field], tcx), input_place)
                });
            let mutations = fields
              .map(|(mutated, input)| Mutation {
                mutated,
                reason: Reason::AssignTarget,
                inputs: input.into_iter().collect::<Vec<_>>(),
                status: MutationStatus::Definitely,
              })
              .collect::<Vec<_>>();
            (self.f)(location, mutations);
            return;
          }
        }
      }
      Rvalue::Use(Operand::Move(place) | Operand::Copy(place)) => {
        let place_ty = place.ty(&body.local_decls, tcx).ty;
        if let TyKind::Adt(adt_def, substs) = place_ty.kind() {
          if adt_def.is_struct() {
            let fields = adt_def
              .all_visible_fields(self.place_info.def_id, self.place_info.tcx)
              .enumerate()
              .map(|(i, field_def)| {
                PlaceElem::Field(FieldIdx::from_usize(i), field_def.ty(tcx, substs))
              });
            let mut mutations = fields
              .map(|field| {
                let mutated_field = mutated.project_deeper(&[field], tcx);
                let input_field = place.project_deeper(&[field], tcx);
                Mutation {
                  mutated: mutated_field,
                  reason: Reason::AssignTarget,
                  inputs: vec![input_field],
                  status: MutationStatus::Definitely,
                }
              })
              .collect::<Vec<_>>();
            if mutations.is_empty() {
              mutations.push(Mutation {
                mutated: *mutated,
                reason: Reason::AssignTarget,
                inputs: vec![*place],
                status: MutationStatus::Definitely,
              });
            }
            (self.f)(location, mutations);
            return;
          }
        }
      }
      Rvalue::Ref(_, _, place) => {
        let inputs = place
          .refs_in_projection(self.place_info.body, self.place_info.tcx)
          .map(|(place_ref, _)| Place::from_ref(place_ref, tcx))
          .collect::<Vec<_>>();
        (self.f)(location, vec![Mutation {
          mutated: *mutated,
          reason: Reason::AssignTarget,
          inputs,
          status: MutationStatus::Definitely,
        }]);
        return;
      }
      _ => {}
    }
    let mut collector = PlaceCollector::default();
    collector.visit_rvalue(rvalue, location);
    (self.f)(location, vec![Mutation {
      mutated: *mutated,
      reason: Reason::AssignTarget,
      inputs: collector.0,
      status: MutationStatus::Definitely,
    }]);
  }
  fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
    debug!("Checking {location:?}: {:?}", terminator.kind);
    let tcx = self.place_info.tcx;
    match &terminator.kind {
      TerminatorKind::Call {
        args,
        destination,
        ..
      } => {
        let async_hack = AsyncHack::new(
          self.place_info.tcx,
          self.place_info.body,
          self.place_info.def_id,
        );
        let arg_places = utils::arg_places(args)
          .into_iter()
          .map(|(_, place)| place)
          .filter(|place| !async_hack.ignore_place(*place))
          .collect::<Vec<_>>();
        let arg_inputs = arg_places
          .iter()
          .flat_map(|arg| self.place_info.reachable_values(*arg, Mutability::Not))
          .copied()
          .collect_vec();
        let ret_is_unit = destination
          .ty(self.place_info.body.local_decls(), tcx)
          .ty
          .is_unit();
        let dest_inputs = if ret_is_unit {
          Vec::new()
        } else {
          arg_inputs.clone()
        };
        let mut mutations = vec![Mutation {
          mutated: *destination,
          inputs: dest_inputs,
          reason: Reason::AssignTarget,
          status: MutationStatus::Definitely,
        }];
        for (num, arg) in arg_places.into_iter().enumerate() {
          for arg_mut in self.place_info.reachable_values(arg, Mutability::Mut) {
            if *arg_mut != arg {
              mutations.push(Mutation {
                mutated: *arg_mut,
                reason: Reason::Argument(num as u8),
                inputs: arg_inputs.clone(),
                status: MutationStatus::Possibly,
              });
            }
          }
        }
        (self.f)(location, mutations);
      }
      _ => {}
    }
  }
}