paralegal_compiler/
initialization_typ.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
use std::collections::HashSet;

use crate::common::{ast::*, count_references_to_variable};

/// How a variable is initialized in the compiled policy.
///
/// For marker-based variable intros, the analyses in this module start from
/// [`InitializationType::NodeCluster`] and downgrade to
/// [`InitializationType::GlobalNodesIterator`] whenever a node cluster is not known to be safe.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum InitializationType {
    /// Initialize the variable as a node cluster.
    NodeCluster,
    /// Initialize the variable with a global nodes iterator. This is the conservative choice.
    /// It is always used for roots, all-nodes, `VariableSourceOf` and type-marked intros, and
    /// as the fallback whenever a node cluster is ruled out.
    GlobalNodesIterator,
}

/// Choose how the variable introduced by `var_intro` should be initialized.
///
/// * `body` – the clause body the variable is scoped over. It is inspected to decide whether a
///   node cluster is safe.
/// * `clause_intro_typ` – the kind of clause that introduced the variable.
/// * `var_intro` – the variable introduction being analysed.
///
/// Returns `None` when the intro refers to an already-initialized variable. Otherwise it returns
/// the chosen [`InitializationType`].
///
/// # Panics
///
/// Panics (via `todo!`) when a non-type `VariableMarked` intro is analysed for an
/// [`OgClauseIntroType::Conditional`] clause.
pub fn compute_initialization_typ(
    body: &ASTNode,
    clause_intro_typ: OgClauseIntroType,
    var_intro: &VariableIntro,
) -> Option<InitializationType> {
    let on_type = match &var_intro.intro {
        // Re-using a variable introduced earlier: there is nothing to initialize.
        VariableIntroType::Variable => return None,
        VariableIntroType::Roots
        | VariableIntroType::AllNodes
        | VariableIntroType::VariableSourceOf(..) => {
            return Some(InitializationType::GlobalNodesIterator)
        }
        VariableIntroType::VariableMarked { on_type, .. } => *on_type,
    };

    // `variable source of` doesn't play well with node clusters, so type-marked variables
    // never use them.
    if on_type {
        return Some(InitializationType::GlobalNodesIterator);
    }

    // Start optimistic; the clause-specific analysis only ever downgrades this.
    let mut chosen = InitializationType::NodeCluster;
    match clause_intro_typ {
        OgClauseIntroType::ThereIs => {
            there_is_initialization_typ(&var_intro.variable, body, &mut chosen)
        }
        OgClauseIntroType::ForEach => {
            for_each_initialization_typ(&var_intro.variable, body, &None, &mut chosen)
        }
        OgClauseIntroType::Conditional => todo!("i think this is unreachable???"),
    }
    Some(chosen)
}

/// Determine how a lifted definition should be initialized.
///
/// First, locate the `ForEach`/`ThereIs` clause in `body` whose intro variable is
/// `definition.variable`. Then run [`compute_initialization_typ`] on that clause's body with
/// the definition's declaration.
///
/// # Panics
///
/// * If no clause in `body` introduces `definition.variable`. Nested clause bodies and both
///   sides of joined nodes are searched.
/// * If [`compute_initialization_typ`] returns `None` for the declaration.
pub fn compute_lifted_def_initialization_typ(
    definition: &Definition,
    body: &ASTNode,
) -> InitializationType {
    /// Depth-first search for the clause that introduces `definition.variable`.
    ///
    /// A clause's own intro is checked before its body, and the `src` side of joined nodes
    /// before the `sink` side. The clause is borrowed from the AST rather than cloned.
    fn find_matching_clause<'a>(
        definition: &Definition,
        body: &'a ASTNode,
    ) -> Option<&'a Clause> {
        match body {
            ASTNode::Clause(clause) => {
                match &clause.intro {
                    ClauseIntro::ForEach(intro) | ClauseIntro::ThereIs(intro) => {
                        if intro.variable == definition.variable {
                            return Some(&**clause);
                        }
                    }
                    ClauseIntro::Conditional(_) => {}
                }
                find_matching_clause(definition, &clause.body)
            }
            ASTNode::JoinedNodes(obligation) => find_matching_clause(definition, &obligation.src)
                .or_else(|| find_matching_clause(definition, &obligation.sink)),
            // Relations hold no clauses; only-via nodes are not searched.
            ASTNode::Relation(_) | ASTNode::OnlyVia(..) => None,
        }
    }

    let clause = find_matching_clause(definition, body).unwrap_or_else(|| {
        panic!(
            "Should have found clause that introduced the lifted definition {:?}",
            definition
        )
    });
    compute_initialization_typ(
        &clause.body,
        (&clause.intro).into(),
        &definition.declaration,
    )
    .expect("Lifted definitions always need to be initialized")
}

// if it's a for each variable, then it can be a node cluster if:
//     - it's not a type
//     - not referenced in an associated call site relation
//     - not *first* introduced in an only via; if it's referenced earlier and then used in an only via later,
//       that should be ok bc any short circuting would have already happened.
//     - all of the references are sinks
//     - all of the references are srcs in a binary negation relation (except is_marked),
//       because that can be rewritten as "there does not exist"
//     - if it's in a conditional premise, then only if it's never referenced again after the premise.
fn for_each_initialization_typ(
    variable: &Variable,
    body: &ASTNode,
    conditional_premise_vars: &Option<&mut HashSet<Variable>>,
    initialization_typ: &mut InitializationType,
) {
    match body {
        ASTNode::Relation(relation) => match relation {
            Relation::Binary { left, right, typ } => {
                if variable == left
                    || (variable == right
                        && conditional_premise_vars
                            .as_ref()
                            .is_some_and(|cvp| cvp.contains(variable)))
                    || matches!(typ, Binop::AssociatedCallSite)
                {
                    *initialization_typ = InitializationType::GlobalNodesIterator;
                }
            }
            Relation::Negation(relation) => match *relation.clone() {
                Relation::Binary { left, right, typ } => {
                    if *variable == right
                        || (*variable == left
                            && conditional_premise_vars
                                .as_ref()
                                .is_some_and(|cvp| cvp.contains(variable)))
                        || matches!(typ, Binop::AssociatedCallSite)
                    {
                        *initialization_typ = InitializationType::GlobalNodesIterator;
                    }
                }
                Relation::Negation(_) => unreachable!("double negation doesn't parse"),
                Relation::IsMarked(var, _) => {
                    if *variable == var {
                        *initialization_typ = InitializationType::GlobalNodesIterator;
                    }
                }
            },
            Relation::IsMarked(..) => {}
        },
        ASTNode::Clause(clause) => match clause.intro.clone() {
            ClauseIntro::ForEach(..) | ClauseIntro::ThereIs(..) => {
                for_each_initialization_typ(
                    variable,
                    &clause.body,
                    conditional_premise_vars,
                    initialization_typ,
                );
            }
            ClauseIntro::Conditional(relation) => {
                let mut conditional_premise_vars: HashSet<Variable> = HashSet::new();
                match relation {
                    Relation::Binary { left, right, .. } => {
                        conditional_premise_vars.insert(left);
                        conditional_premise_vars.insert(right);
                    }
                    Relation::Negation(relation) => match *relation {
                        Relation::Binary { left, right, .. } => {
                            conditional_premise_vars.insert(left);
                            conditional_premise_vars.insert(right);
                        }
                        Relation::Negation(_) => unreachable!("double negation doesn't parse"),
                        Relation::IsMarked(var, _) => {
                            conditional_premise_vars.insert(var);
                        }
                    },
                    Relation::IsMarked(var, _) => {
                        conditional_premise_vars.insert(var);
                    }
                };
                for_each_initialization_typ(
                    variable,
                    &clause.body,
                    &Some(&mut conditional_premise_vars),
                    initialization_typ,
                );
            }
        },
        ASTNode::JoinedNodes(obligation) => {
            for_each_initialization_typ(
                variable,
                &obligation.src,
                conditional_premise_vars,
                initialization_typ,
            );
            for_each_initialization_typ(
                variable,
                &obligation.sink,
                conditional_premise_vars,
                initialization_typ,
            );
        }
        ASTNode::OnlyVia(..) => {
            // render_only_via doesn't call this function, so if we reach this point, we're evaluating a definition declaration that gets referenced in an only via.
            // Since the templates implement contains() for NodeClusters that just collect the nodes and iterate over them one at a time anyway,
            // there's no need to override here; it doesn't matter which typ we use.
        }
    }
}

// If it's a there is variable, then it can be a nodecluster if it gets referenced once in the clause
// as long as it's not used an associated call site relation, which need to reason about the same object across two graph queries
fn there_is_initialization_typ(
    variable: &Variable,
    body: &ASTNode,
    initialization_typ: &mut InitializationType,
) {
    // Associated call sites can't use node clusters
    fn var_in_associated_call_site_relation(variable: &Variable, body: &ASTNode) -> bool {
        match body {
            ASTNode::Relation(relation) => match relation {
                Relation::Binary { left, right, typ } => {
                    (variable == left || variable == right)
                        && matches!(typ, Binop::AssociatedCallSite)
                }
                Relation::Negation(relation) => match &**relation {
                    Relation::Binary { left, right, typ } => {
                        (variable == left || variable == right)
                            && matches!(typ, Binop::AssociatedCallSite)
                    }
                    Relation::Negation(_) => unreachable!("double negation doesn't parse"),
                    Relation::IsMarked(..) => false,
                },
                Relation::IsMarked(..) => false,
            },
            ASTNode::Clause(clause) => var_in_associated_call_site_relation(variable, &clause.body),
            ASTNode::OnlyVia(..) => false,
            ASTNode::JoinedNodes(obligation) => {
                var_in_associated_call_site_relation(variable, &obligation.src)
                    || var_in_associated_call_site_relation(variable, &obligation.sink)
            }
        }
    }

    let mut count = 0;
    count_references_to_variable(variable, body, &mut count);
    if count > 1 || var_in_associated_call_site_relation(variable, body) {
        *initialization_typ = InitializationType::GlobalNodesIterator;
    }
}