Skip to main content

dfir_lang/graph/
mod.rs

1//! Graph representation stages for DFIR graphs.
2
3use std::borrow::Cow;
4use std::hash::Hash;
5
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::{Expr, ExprPath, GenericArgument, Token, Type};
12
13use self::ops::{OperatorConstraints, Persistence};
14use crate::diagnostic::{Diagnostic, Diagnostics, Level};
15use crate::parse::{DfirCode, IndexInt, Operator, PortIndex, Ported};
16use crate::pretty_span::PrettySpan;
17
18mod di_mul_graph;
19mod eliminate_extra_unions_tees;
20mod flat_graph_builder;
21mod flat_to_partitioned;
22mod graph_write;
23mod meta_graph;
24mod meta_graph_debugging;
25
26use std::fmt::Display;
27
28pub use di_mul_graph::DiMulGraph;
29pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
30pub use flat_graph_builder::{FlatGraphBuilder, FlatGraphBuilderOutput};
31pub use flat_to_partitioned::partition_graph;
32pub use meta_graph::{DfirGraph, WriteConfig, WriteGraphType};
33
34pub use crate::graph_ids::{GraphEdgeId, GraphLoopId, GraphNodeId, GraphSubgraphId};
35
36pub mod graph_algorithms;
37pub mod ops;
38
39impl GraphSubgraphId {
40    /// Generate a deterministic `Ident` for the given subgraph ID.
41    pub fn as_ident(self, span: Span) -> Ident {
42        use slotmap::Key;
43        Ident::new(&format!("sgid_{:?}", self.data()), span)
44    }
45}
46
47impl GraphLoopId {
48    /// Generate a deterministic `Ident` for the given loop ID.
49    pub fn as_ident(self, span: Span) -> Ident {
50        use slotmap::Key;
51        Ident::new(&format!("loop_{:?}", self.data()), span)
52    }
53}
54
/// Context identifier as a string.
const CONTEXT: &str = "context";
/// Runnable DFIR graph object identifier as a string.
const GRAPH: &str = "df";

/// Label returned for [`GraphNode::Handoff`] nodes by the `GraphNode` string helpers.
const HANDOFF_NODE_STR: &str = "handoff";
/// Label returned for [`GraphNode::ModuleBoundary`] nodes by the `GraphNode` string helpers.
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
62
63mod serde_syn {
64    use serde::{Deserialize, Deserializer, Serializer};
65
66    pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
67    where
68        S: Serializer,
69        T: quote::ToTokens,
70    {
71        serializer.serialize_str(&value.to_token_stream().to_string())
72    }
73
74    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
75    where
76        D: Deserializer<'de>,
77        T: syn::parse::Parse,
78    {
79        let s = String::deserialize(deserializer)?;
80        syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
81    }
82}
83
/// A variable name assigned to a pipeline in DFIR syntax.
///
/// Fundamentally a serializable/deserializable wrapper around [`syn::Ident`].
/// The identifier is (de)serialized through `serde_syn`, i.e. as the string form of its tokens.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct Varname(#[serde(with = "serde_syn")] pub Ident);
89
/// A node, corresponding to an operator, a handoff, or a module boundary.
///
/// All `Span` fields are skipped by serde; when deserializing they are filled in with
/// `Span::call_site()`.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator. (De)serialized through `serde_syn`, i.e. as the string form of its tokens.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// A handoff point, used between subgraphs (or within a subgraph to break a cycle).
    Handoff {
        /// The span of the input into the handoff.
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// The span of the output out of the handoff.
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },

    /// Module Boundary, used for importing modules. Only exists prior to partitioning.
    ModuleBoundary {
        /// If this module is an input or output boundary.
        input: bool,

        /// The span of the `import!()` expression that imported this module.
        /// While the `ModuleBoundary` node is still inside the module, this span is
        /// `Span::call_site()`.
        /// TODO: This could one day reference into the module file itself?
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
117impl GraphNode {
118    /// Return the node as a human-readable string.
119    pub fn to_pretty_string(&self) -> Cow<'static, str> {
120        match self {
121            GraphNode::Operator(op) => op.to_pretty_string().into(),
122            GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
123            GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
124        }
125    }
126
127    /// Return the name of the node as a string, excluding parenthesis and op source code.
128    pub fn to_name_string(&self) -> Cow<'static, str> {
129        match self {
130            GraphNode::Operator(op) => op.name_string().into(),
131            GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
132            GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
133        }
134    }
135
136    /// Return the source code span of the node (for operators) or input/otput spans for handoffs.
137    pub fn span(&self) -> Span {
138        match self {
139            Self::Operator(op) => op.span(),
140            &Self::Handoff {
141                src_span, dst_span, ..
142            } => src_span.join(dst_span).unwrap_or(src_span),
143            Self::ModuleBoundary { import_expr, .. } => *import_expr,
144        }
145    }
146}
147impl std::fmt::Debug for GraphNode {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        match self {
150            Self::Operator(operator) => {
151                write!(f, "Node::Operator({} span)", PrettySpan(operator.span()))
152            }
153            Self::Handoff { .. } => write!(f, "Node::Handoff"),
154            Self::ModuleBoundary { input, .. } => {
155                write!(f, "Node::ModuleBoundary{{input: {}}}", input)
156            }
157        }
158    }
159}
160
/// Meta-data relating to operators which may be useful throughout the compilation process.
///
/// This data can be generated from the graph, but it is useful to have it readily available
/// pre-computed as many algorithms use the same info. Stuff like port names, arguments, and the
/// [`OperatorConstraints`] for the operator.
///
/// Because it is derived from the graph itself, there can be "cache invalidation"-esque issues
/// if this data is not kept in sync with the graph.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// The static [`OperatorConstraints`] for this operator (the operator's name will match
    /// [`OperatorConstraints::name`]).
    pub op_constraints: &'static OperatorConstraints,
    /// Port values used as this operator's input.
    pub input_ports: Vec<PortIndexValue>,
    /// Port values used as this operator's output.
    pub output_ports: Vec<PortIndexValue>,
    /// Singleton references within the operator arguments.
    pub singletons_referenced: Vec<Ident>,

    /// Generic arguments.
    pub generics: OpInstGenerics,
    /// Arguments provided by the user into the operator as arguments.
    /// I.e. the `a, b, c` in `-> my_op(a, b, c) -> `.
    ///
    /// These arguments do not include singleton postprocessing codegen. Instead use
    /// [`ops::WriteContextArgs::arguments`].
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// Unparsed arguments, for singleton parsing.
    pub arguments_raw: TokenStream,
}
191
/// Operator generic arguments, split into specific categories.
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// Operator generic (type or lifetime) arguments.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// Lifetime persistence arguments. Corresponds to a prefix of [`Self::generic_args`].
    pub persistence_args: Vec<Persistence>,
    /// Type arguments. Corresponds to the arguments of [`Self::generic_args`] that follow the
    /// persistence prefix.
    pub type_args: Vec<Type>,
}
202
203impl OpInstGenerics {
204    /// Helper to join a sequence of spans into a single span, if possible.
205    ///
206    /// Returns `None` if there are no spans or if any `Span::join` call fails
207    /// (for example, when spans are not contiguous).
208    fn join_spans<I>(mut spans: I) -> Option<Span>
209    where
210        I: Iterator<Item = Span>,
211    {
212        let mut span = spans.next()?;
213        for s in spans {
214            span = span.join(s)?;
215        }
216        Some(span)
217    }
218
219    /// Returns a [`Span`] containing all persistence (lifetime) args if possible.
220    pub fn persistence_args_span(&self) -> Option<Span> {
221        self.generic_args.as_ref().and_then(|args| {
222            Self::join_spans(
223                args.iter()
224                    .filter(|a| matches!(a, GenericArgument::Lifetime(_)))
225                    .map(|a| a.span()),
226            )
227        })
228    }
229
230    /// Returns a [`Span`] containing all type args if possible.
231    pub fn type_args_span(&self) -> Option<Span> {
232        self.generic_args.as_ref().and_then(|args| {
233            Self::join_spans(
234                args.iter()
235                    .filter(|a| matches!(a, GenericArgument::Type(_)))
236                    .map(|a| a.span()),
237            )
238        })
239    }
240}
241
242/// Gets the generic arguments for the operator.
243///
244/// This helper method is useful due to the special handling of persistence lifetimes (`'static`,
245/// `'tick`, `'mutable`) which must come before other generic parameters.
246pub fn get_operator_generics(diagnostics: &mut Diagnostics, operator: &Operator) -> OpInstGenerics {
247    // Generic arguments.
248    let generic_args = operator.type_arguments().cloned();
249    let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
250            GenericArgument::Lifetime(lifetime) => {
251                match &*lifetime.ident.to_string() {
252                    "none" => Some(Persistence::None),
253                    "loop" => Some(Persistence::Loop),
254                    "tick" => Some(Persistence::Tick),
255                    "static" => Some(Persistence::Static),
256                    "mutable" => Some(Persistence::Mutable),
257                    _ => {
258                        diagnostics.push(Diagnostic::spanned(
259                            generic_arg.span(),
260                            Level::Error,
261                            format!("Unknown lifetime generic argument `'{}`, expected `'none`, `'loop`, `'tick`, `'static`, or `'mutable`.", lifetime.ident),
262                        ));
263                        // TODO(mingwei): should really keep going and not short circuit?
264                        None
265                    }
266                }
267            },
268            _ => None,
269        }).collect::<Vec<_>>();
270    let type_args = generic_args
271        .iter()
272        .flatten()
273        .skip(persistence_args.len())
274        .map_while(|generic_arg| match generic_arg {
275            GenericArgument::Type(typ) => Some(typ),
276            _ => None,
277        })
278        .cloned()
279        .collect::<Vec<_>>();
280
281    OpInstGenerics {
282        generic_args,
283        persistence_args,
284        type_args,
285    }
286}
287
/// Push, Pull, Comp, or Hoff polarity.
///
/// The derived ordering follows declaration order: `Pull < Push < Comp < Hoff`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// Pull (green)
    Pull,
    /// Push (blue)
    Push,
    /// Computation (yellow)
    Comp,
    /// Handoff (grey) -- not a color for operators, inserted between subgraphs.
    Hoff,
}
300
/// Helper struct for [`PortIndex`] which keeps span information for elided ports.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// An integer value: `[0]`, `[1]`, etc. Can be negative although we don't use that (2023-08-16).
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// A name or path. `[pos]`, `[neg]`, etc. Can use `::` separators but we don't use that (2023-08-16).
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// Elided, unspecified port. We have this variant, rather than wrapping in `Option`, in order
    /// to preserve the `Span` information. The span is skipped by serde.
    Elided(#[serde(skip)] Option<Span>),
}
312impl PortIndexValue {
313    /// For a [`Ported`] value like `[port_in]name[port_out]`, get the `port_in` and `port_out` as
314    /// [`PortIndexValue`]s.
315    pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
316    where
317        Inner: Spanned,
318    {
319        let ported_span = Some(ported.inner.span());
320        let port_inn = ported
321            .inn
322            .map(|idx| idx.index.into())
323            .unwrap_or_else(|| Self::Elided(ported_span));
324        let inner = ported.inner;
325        let port_out = ported
326            .out
327            .map(|idx| idx.index.into())
328            .unwrap_or_else(|| Self::Elided(ported_span));
329        (port_inn, inner, port_out)
330    }
331
332    /// Returns `true` if `self` is not [`PortIndexValue::Elided`].
333    pub fn is_specified(&self) -> bool {
334        !matches!(self, Self::Elided(_))
335    }
336
337    /// Returns whichever of the two ports are specified.
338    /// If both are [`Self::Elided`], returns [`Self::Elided`].
339    /// If both are specified, returns `Err(self)`.
340    #[allow(clippy::allow_attributes, reason = "Only triggered on nightly.")]
341    #[allow(
342        clippy::result_large_err,
343        reason = "variants are same size, error isn't to be propagated."
344    )]
345    pub fn combine(self, other: Self) -> Result<Self, Self> {
346        match (self.is_specified(), other.is_specified()) {
347            (false, _other) => Ok(other),
348            (true, false) => Ok(self),
349            (true, true) => Err(self),
350        }
351    }
352
353    /// Formats self as a human-readable string for error messages.
354    pub fn as_error_message_string(&self) -> String {
355        match self {
356            PortIndexValue::Int(n) => format!("`{}`", n.value),
357            PortIndexValue::Path(path) => format!("`{}`", path.to_token_stream()),
358            PortIndexValue::Elided(_) => "<elided>".to_owned(),
359        }
360    }
361
362    /// Returns the span of this port value.
363    pub fn span(&self) -> Span {
364        match self {
365            PortIndexValue::Int(x) => x.span(),
366            PortIndexValue::Path(x) => x.span(),
367            PortIndexValue::Elided(span) => span.unwrap_or_else(Span::call_site),
368        }
369    }
370}
371impl From<PortIndex> for PortIndexValue {
372    fn from(value: PortIndex) -> Self {
373        match value {
374            PortIndex::Int(x) => Self::Int(x),
375            PortIndex::Path(x) => Self::Path(x),
376        }
377    }
378}
379impl PartialEq for PortIndexValue {
380    fn eq(&self, other: &Self) -> bool {
381        match (self, other) {
382            (Self::Int(l0), Self::Int(r0)) => l0 == r0,
383            (Self::Path(l0), Self::Path(r0)) => l0 == r0,
384            (Self::Elided(_), Self::Elided(_)) => true,
385            _else => false,
386        }
387    }
388}
389impl Eq for PortIndexValue {}
390impl PartialOrd for PortIndexValue {
391    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
392        Some(self.cmp(other))
393    }
394}
395impl Ord for PortIndexValue {
396    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
397        match (self, other) {
398            (Self::Int(s), Self::Int(o)) => s.cmp(o),
399            (Self::Path(s), Self::Path(o)) => s
400                .to_token_stream()
401                .to_string()
402                .cmp(&o.to_token_stream().to_string()),
403            (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
404            (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
405            (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
406            (_, Self::Elided(_)) => std::cmp::Ordering::Less,
407            (Self::Elided(_), _) => std::cmp::Ordering::Greater,
408        }
409    }
410}
411
412impl Display for PortIndexValue {
413    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414        match self {
415            PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
416            PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
417            PortIndexValue::Elided(_) => write!(f, "[]"),
418        }
419    }
420}
421
/// Output of [`build_dfir_code`].
pub struct BuildDfirCodeOutput {
    /// The now-partitioned graph.
    pub partitioned_graph: DfirGraph,
    /// The Rust source code tokens for the DFIR, as generated from the partitioned graph.
    pub code: TokenStream,
    /// Any (non-error) diagnostics emitted.
    pub diagnostics: Diagnostics,
}
431
432/// Compiles a [`DfirCode`] AST into inline source code that runs the dataflow.
433pub fn build_dfir_code(
434    dfir_code: DfirCode,
435    root: &TokenStream,
436) -> Result<BuildDfirCodeOutput, Diagnostics> {
437    let flat_graph_builder = FlatGraphBuilder::from_dfir(dfir_code);
438
439    let FlatGraphBuilderOutput {
440        mut flat_graph,
441        uses,
442        mut diagnostics,
443    } = flat_graph_builder.build()?;
444
445    let () = match flat_graph.merge_modules() {
446        Ok(()) => (),
447        Err(d) => {
448            diagnostics.push(d);
449            return Err(diagnostics);
450        }
451    };
452
453    eliminate_extra_unions_tees(&mut flat_graph);
454
455    // Reject `loop { }` blocks (not yet supported in inline codegen).
456    // TODO(cleanup): find a better home for this check — ideally inside `partition_graph` once
457    // it supports returning multiple diagnostics.
458    for (_loop_id, nodes) in flat_graph.loops() {
459        let span = nodes
460            .first()
461            .map_or_else(Span::call_site, |&n| flat_graph.node(n).span());
462        diagnostics.push(Diagnostic::spanned(
463            span,
464            Level::Error,
465            "`loop { }` blocks are not (yet) supported in `dfir_syntax!`.",
466        ));
467    }
468    if diagnostics.has_error() {
469        return Err(diagnostics);
470    }
471
472    let partitioned_graph = match partition_graph(flat_graph) {
473        Ok(partitioned_graph) => partitioned_graph,
474        Err(d) => {
475            diagnostics.push(d);
476            return Err(diagnostics);
477        }
478    };
479
480    let code =
481        partitioned_graph.as_code(root, true, quote::quote! { #( #uses )* }, &mut diagnostics)?;
482
483    Ok(BuildDfirCodeOutput {
484        partitioned_graph,
485        code,
486        diagnostics,
487    })
488}
489
490/// Changes all of token's spans to `span`, recursing into groups.
491fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
492    use proc_macro2::{Group, TokenTree};
493    tokens
494        .into_iter()
495        .map(|token| match token {
496            TokenTree::Group(mut group) => {
497                group.set_span(span);
498                TokenTree::Group(Group::new(
499                    group.delimiter(),
500                    change_spans(group.stream(), span),
501                ))
502            }
503            TokenTree::Ident(mut ident) => {
504                ident.set_span(span.resolved_at(ident.span()));
505                TokenTree::Ident(ident)
506            }
507            TokenTree::Punct(mut punct) => {
508                punct.set_span(span);
509                TokenTree::Punct(punct)
510            }
511            TokenTree::Literal(mut literal) => {
512                literal.set_span(span);
513                TokenTree::Literal(literal)
514            }
515        })
516        .collect()
517}