Skip to main content

dfir_lang/graph/
graph_algorithms.rs

//! General graph algorithm utility functions.
2
3use std::collections::BTreeMap;
4
/// Topologically sorts a set of nodes. Returns a list where the order of `Id`s will agree with
/// the order of any path through the graph.
///
/// `preds_fn(id)` yields the predecessors of `id`; every predecessor is placed before `id` in the
/// returned order. Predecessors are visited (and included in the output) even if they are not
/// listed in `node_ids`. Duplicate entries in `node_ids` are ignored, and `preds_fn` is called at
/// most once per node.
///
/// This succeeds if the input is a directed acyclic graph (DAG).
///
/// # Errors
///
/// If the input has a cycle, an `Err` will be returned containing the cycle. Each node in the
/// cycle will be listed exactly once, in edge order: each node is a predecessor of the node that
/// follows it, and the last node is a predecessor of the first.
///
/// <https://en.wikipedia.org/wiki/Topological_sorting>
pub fn topo_sort<Id, NodeIds, PredsFn, PredsIter>(
    node_ids: NodeIds,
    mut preds_fn: PredsFn,
) -> Result<Vec<Id>, Vec<Id>>
where
    Id: Copy + Eq + Ord,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    PredsIter: IntoIterator<Item = Id>,
{
    // `false` => temporary (node is on the DFS stack), `true` => permanent (node is in `order`).
    let mut marked: BTreeMap<Id, bool> = BTreeMap::new();
    let mut order: Vec<Id> = Vec::new();
    // Explicit DFS stack of (node, its not-yet-visited predecessors). Using a heap-allocated
    // stack instead of recursion means long dependency chains cannot overflow the call stack.
    // Invariant: entry `i + 1` was pushed as a predecessor of entry `i`.
    let mut stack: Vec<(Id, PredsIter::IntoIter)> = Vec::new();

    for root in node_ids {
        if marked.contains_key(&root) {
            // Already emitted by an earlier traversal (or a duplicate in `node_ids`).
            continue;
        }
        marked.insert(root, false);
        stack.push((root, preds_fn(root).into_iter()));

        while let Some((node_id, preds)) = stack.last_mut() {
            let node_id = *node_id;
            let next_pred = preds.next();
            match next_pred {
                None => {
                    // All predecessors have been emitted: emit this node (post-order).
                    stack.pop();
                    marked.insert(node_id, true);
                    order.push(node_id);
                }
                Some(pred) => match marked.get(&pred).copied() {
                    Some(_permanent @ true) => {}
                    Some(_temporary @ false) => {
                        // Cycle found! `pred` is on the stack, and the top entry (`node_id`)
                        // has `pred` as a predecessor. Walking the stack from the top down to
                        // just above `pred`, then `pred` itself, lists the cycle in edge order.
                        let start = stack
                            .iter()
                            .position(|&(n, _)| n == pred)
                            .expect("temporarily-marked node must be on the DFS stack");
                        let cycle: Vec<Id> = stack[start + 1..]
                            .iter()
                            .rev()
                            .map(|&(n, _)| n)
                            .chain(std::iter::once(pred))
                            .collect();
                        return Err(cycle);
                    }
                    None => {
                        marked.insert(pred, false);
                        stack.push((pred, preds_fn(pred).into_iter()));
                    }
                },
            }
        }
    }

    Ok(order)
}
73
#[cfg(test)]
mod test {
    use std::collections::{BTreeMap, BTreeSet};

    use itertools::Itertools;

    use super::*;

    #[test]
    fn test_toposort() {
        // https://commons.wikimedia.org/wiki/File:Directed_acyclic_graph_2.svg
        let edges = [
            (5, 11),
            (11, 2),
            (11, 9),
            (11, 10),
            (7, 11),
            (7, 8),
            (8, 9),
            (3, 8),
            (3, 10),
        ];

        let sort = topo_sort([2, 3, 5, 7, 8, 9, 10, 11], |v| {
            edges
                .iter()
                .filter(move |&&(_, dst)| v == dst)
                .map(|&(src, _)| src)
        })
        .unwrap_or_else(|cycle| panic!("Did not expect cycle: {:?}", cycle));

        // Each of the 8 nodes must appear exactly once...
        let position: BTreeMap<_, _> = sort.iter().enumerate().map(|(i, &x)| (x, i)).collect();
        assert_eq!(8, sort.len(), "sort: {:?}", sort);
        assert_eq!(
            sort.len(),
            position.len(),
            "duplicate nodes in sort: {:?}",
            sort
        );
        // ...and every edge must point forward in the sorted order.
        for (src, dst) in edges.iter() {
            assert!(
                position[src] < position[dst],
                "{} must come before {} in {:?}",
                src,
                dst,
                sort
            );
        }
    }

    #[test]
    fn test_toposort_cycle() {
        // https://commons.wikimedia.org/wiki/File:Directed_graph,_cyclic.svg
        //          ┌────►C──────┐
        //          │            │
        //          │            ▼
        // A───────►B            E ─────►F
        //          ▲            │
        //          │            │
        //          └─────D◄─────┘
        let edges = [
            ('A', 'B'),
            ('B', 'C'),
            ('C', 'E'),
            ('D', 'B'),
            ('E', 'F'),
            ('E', 'D'),
        ];
        let ids = edges
            .iter()
            .flat_map(|&(a, b)| [a, b])
            .collect::<BTreeSet<_>>();
        let cycle_rotations = BTreeSet::from_iter([
            ['B', 'C', 'E', 'D'],
            ['C', 'E', 'D', 'B'],
            ['E', 'D', 'B', 'C'],
            ['D', 'B', 'C', 'E'],
        ]);

        // The same cycle (up to rotation) must be reported regardless of visiting order.
        for permutation in ids.iter().copied().permutations(ids.len()) {
            let result = topo_sort(permutation.iter().copied(), |v| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            });
            let cycle = match result {
                Ok(sort) => panic!(
                    "expected a cycle, got sort: {:?}, vertex order: {:?}",
                    sort, permutation
                ),
                Err(cycle) => cycle,
            };
            assert!(
                cycle_rotations.contains(&*cycle),
                "cycle: {:?}, vertex order: {:?}",
                cycle,
                permutation
            );
        }
    }

    #[test]
    fn test_toposort_self_loop() {
        // A node that is its own predecessor is a cycle of length one.
        assert_eq!(Err(vec![1]), topo_sort([1], |v| Some(v)));
    }

    #[test]
    fn test_toposort_empty() {
        assert_eq!(
            Ok(Vec::<u32>::new()),
            topo_sort(std::iter::empty::<u32>(), |_| None::<u32>)
        );
    }

    #[test]
    fn test_toposort_includes_unlisted_preds() {
        // Only node 3 is listed (twice); its predecessor chain 0 -> 1 -> 2 is still emitted, and
        // the duplicate entry is ignored.
        assert_eq!(
            Ok(vec![0, 1, 2, 3]),
            topo_sort([3, 3], |v: u32| v.checked_sub(1))
        );
    }
}