summaryrefslogtreecommitdiff
path: root/viz/src/strong_component.rs
diff options
context:
space:
mode:
authorJSDurand <mmemmew@gmail.com>2023-06-02 14:51:25 +0800
committerJSDurand <mmemmew@gmail.com>2023-06-02 14:51:25 +0800
commit1455da10f943e2aa1bdf26fb2697dafccc61e073 (patch)
tree7d6c51a1040fc6b05bf3a19386e05016f8c0ab0f /viz/src/strong_component.rs
parent83c66eb77c6affaa9ac4fabd808556613c5bf973 (diff)
viz: finished decycle algorithm
Diffstat (limited to 'viz/src/strong_component.rs')
-rw-r--r--viz/src/strong_component.rs257
1 files changed, 257 insertions, 0 deletions
diff --git a/viz/src/strong_component.rs b/viz/src/strong_component.rs
new file mode 100644
index 0000000..37dcd39
--- /dev/null
+++ b/viz/src/strong_component.rs
@@ -0,0 +1,257 @@
+//! This module implements the Trajan's algorithm for finding strongly
+//! connected components of a directed graph with only one depth-first
+//! traversal.
+
+use graph::{error::Error, Graph};
+use std::borrow::Borrow;
+
+/// This function accepts a graph and returns a list of strongly
+/// connected components, represented as a list of nodes.
+pub fn tarjan<B, G>(g: B) -> Result<Vec<Vec<usize>>, Error>
+where
+ B: Borrow<G>,
+ G: Graph,
+{
+ let g = g.borrow();
+
+ // List of components
+ let mut components: Vec<Vec<usize>> = Vec::new();
+
+ // List of depth levels of nodes
+ let mut indices: Vec<usize> = vec![0; g.nodes_len()];
+
+ indices.shrink_to_fit();
+
+ // List of low link numbers of nodes
+ let mut lowlinks: Vec<usize> = indices.clone();
+
+ // The stack used in Trajan's algorithm
+ let mut tarjan_stack: Vec<usize> = Vec::new();
+
+ // The list of booleans to indicate whether a node is waiting on
+ // the stack
+ let mut waiting: Vec<bool> = vec![false; g.nodes_len()];
+
+ waiting.shrink_to_fit();
+
+ // a struct to simplify recursing
+
+ #[derive(Debug)]
+ enum StackElement {
+ Seen(usize, Vec<usize>),
+ Unseen(usize),
+ }
+
+ use StackElement::{Seen, Unseen};
+
+ // convenient macros
+
+ macro_rules! index {
+ ($num: ident) => {
+ indices.get($num).copied().unwrap()
+ };
+ }
+
+ macro_rules! lowlink {
+ ($num: ident) => {
+ lowlinks.get($num).copied().unwrap()
+ };
+ }
+
+ // The stack used to replace recursive function calls
+ let mut recursive_stack: Vec<StackElement> = Vec::new();
+
+ // The next index to assign
+ let mut next_index: usize = 1;
+
+ for node in g.nodes() {
+ if indices.get(node).copied() == Some(0) {
+ recursive_stack.push(Unseen(node));
+
+ 'recursion: while let Some(stack_element) = recursive_stack.pop() {
+ let stack_node: usize;
+
+ match stack_element {
+ Seen(node, children) => {
+ stack_node = node;
+
+ for child in children {
+ *lowlinks.get_mut(node).unwrap() =
+ std::cmp::min(lowlink!(node), lowlink!(child));
+ }
+ }
+
+ Unseen(node) => {
+ stack_node = node;
+
+ tarjan_stack.push(node);
+
+ // It is safe to unwrap here since the
+ // condition of the if clause already serves
+ // as a guard.
+ *indices.get_mut(node).unwrap() = next_index;
+ *lowlinks.get_mut(node).unwrap() = next_index;
+ *waiting.get_mut(node).unwrap() = true;
+
+ next_index += 1;
+
+ let mut node_index: Option<usize> = None;
+
+ for child in g.children_of(node)? {
+ // Ignore self-loops
+ if node == child {
+ continue;
+ }
+
+ match indices.get(child).copied() {
+ Some(0) => {
+ match node_index {
+ Some(index) => match recursive_stack.get_mut(index) {
+ Some(Seen(_, children)) => {
+ children.push(child);
+ }
+ Some(_) => {
+ unreachable!("wrong index: {index}");
+ }
+ None => {
+ unreachable!("index {index} out of bounds");
+ }
+ },
+ None => {
+ node_index = Some(recursive_stack.len());
+
+ let mut children = Vec::with_capacity(g.degree(node)?);
+ children.push(child);
+
+ recursive_stack.push(Seen(node, children));
+ }
+ }
+
+ recursive_stack.push(Unseen(child));
+ }
+ Some(_) if waiting.get(child).copied().unwrap() => {
+ *lowlinks.get_mut(node).unwrap() =
+ std::cmp::min(lowlink!(node), index!(child));
+ }
+ None => {
+ return Err(Error::IndexOutOfBounds(child, g.nodes_len()));
+ }
+ _ => {
+ // crossing edges are ignored
+ }
+ }
+ }
+
+ if node_index.is_some() {
+ continue 'recursion;
+ }
+ }
+ }
+
+ if lowlink!(stack_node) == index!(stack_node) {
+ let mut component: Vec<usize> = Vec::new();
+
+ while let Some(top) = tarjan_stack.pop() {
+ *waiting.get_mut(top).unwrap() = false;
+
+ component.push(top);
+
+ if top == stack_node {
+ components.push(component);
+
+ break;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ Ok(components)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use graph::adlist::{ALGBuilder, ALGraph};
+ use graph::builder::Builder;
+
+ use std::collections::BTreeSet as Set;
+
+ fn make_cycle(n: usize) -> Result<ALGraph, graph::error::Error> {
+ let mut builder = ALGBuilder::default();
+
+ builder.add_vertices(n);
+
+ for i in 0..(n - 1) {
+ builder.add_edge(i, i + 1, ())?;
+ }
+
+ builder.add_edge(n - 1, 0, ())?;
+
+ Ok(builder.build())
+ }
+
+ fn make_two_cycles(n: usize) -> Result<ALGraph, graph::error::Error> {
+ let mut builder = ALGBuilder::default();
+
+ builder.add_vertices(2 * n);
+
+ for i in 0..(2 * n - 1) {
+ builder.add_edge(i, i + 1, ())?;
+ }
+
+ builder.add_edge(n - 1, 0, ())?;
+ builder.add_edge(n - 2, 0, ())?; // random noise
+ builder.add_edge(0, n - 1, ())?; // random noise
+ builder.add_edge(0, 2 * n - 1, ())?; // random noise
+ builder.add_edge(2 * n - 1, n, ())?;
+
+ Ok(builder.build())
+ }
+
+ #[test]
+ fn test_cycle() -> Result<(), Box<dyn std::error::Error>> {
+ let length = 10;
+
+ let cycle = make_cycle(length)?;
+
+ let components = tarjan::<_, ALGraph>(&cycle)?;
+
+ println!("components = {components:?}");
+
+ assert_eq!(components.len(), 1);
+
+ let set: Set<usize> = components.first().unwrap().into_iter().copied().collect();
+
+ let answer: Set<usize> = (0..length).collect();
+
+ assert_eq!(set, answer);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_two_components() -> Result<(), Box<dyn std::error::Error>> {
+ let half_length = 10;
+
+ let graph = make_two_cycles(half_length)?;
+
+ let components = tarjan::<_, ALGraph>(graph)?;
+
+ println!("components = {components:?}");
+
+ assert_eq!(components.len(), 2);
+
+ let first_set: Set<usize> = components.get(0).unwrap().into_iter().copied().collect();
+ let first_answer: Set<usize> = (half_length..(2 * half_length)).collect();
+
+ let second_set: Set<usize> = components.get(1).unwrap().into_iter().copied().collect();
+ let second_answer: Set<usize> = (0..half_length).collect();
+
+ assert_eq!(first_set, first_answer);
+ assert_eq!(second_set, second_answer);
+
+ Ok(())
+ }
+}