diff options
author | JSDurand <mmemmew@gmail.com> | 2023-06-02 14:51:25 +0800 |
---|---|---|
committer | JSDurand <mmemmew@gmail.com> | 2023-06-02 14:51:25 +0800 |
commit | 1455da10f943e2aa1bdf26fb2697dafccc61e073 (patch) | |
tree | 7d6c51a1040fc6b05bf3a19386e05016f8c0ab0f /viz/src/strong_component.rs | |
parent | 83c66eb77c6affaa9ac4fabd808556613c5bf973 (diff) |
viz: finished decycle algorithm
Diffstat (limited to 'viz/src/strong_component.rs')
-rw-r--r-- | viz/src/strong_component.rs | 257 |
1 files changed, 257 insertions, 0 deletions
diff --git a/viz/src/strong_component.rs b/viz/src/strong_component.rs new file mode 100644 index 0000000..37dcd39 --- /dev/null +++ b/viz/src/strong_component.rs @@ -0,0 +1,257 @@ +//! This module implements the Trajan's algorithm for finding strongly +//! connected components of a directed graph with only one depth-first +//! traversal. + +use graph::{error::Error, Graph}; +use std::borrow::Borrow; + +/// This function accepts a graph and returns a list of strongly +/// connected components, represented as a list of nodes. +pub fn tarjan<B, G>(g: B) -> Result<Vec<Vec<usize>>, Error> +where + B: Borrow<G>, + G: Graph, +{ + let g = g.borrow(); + + // List of components + let mut components: Vec<Vec<usize>> = Vec::new(); + + // List of depth levels of nodes + let mut indices: Vec<usize> = vec![0; g.nodes_len()]; + + indices.shrink_to_fit(); + + // List of low link numbers of nodes + let mut lowlinks: Vec<usize> = indices.clone(); + + // The stack used in Trajan's algorithm + let mut tarjan_stack: Vec<usize> = Vec::new(); + + // The list of booleans to indicate whether a node is waiting on + // the stack + let mut waiting: Vec<bool> = vec![false; g.nodes_len()]; + + waiting.shrink_to_fit(); + + // a struct to simplify recursing + + #[derive(Debug)] + enum StackElement { + Seen(usize, Vec<usize>), + Unseen(usize), + } + + use StackElement::{Seen, Unseen}; + + // convenient macros + + macro_rules! index { + ($num: ident) => { + indices.get($num).copied().unwrap() + }; + } + + macro_rules! lowlink { + ($num: ident) => { + lowlinks.get($num).copied().unwrap() + }; + } + + // The stack used to replace recursive function calls + let mut recursive_stack: Vec<StackElement> = Vec::new(); + + // The next index to assign + let mut next_index: usize = 1; + + for node in g.nodes() { + if indices.get(node).copied() == Some(0) { + recursive_stack.push(Unseen(node)); + + 'recursion: while let Some(stack_element) = recursive_stack.pop() { + let stack_node: usize; + + match stack_element { + Seen(node, children) => { + stack_node = node; + + for child in children { + *lowlinks.get_mut(node).unwrap() = + std::cmp::min(lowlink!(node), lowlink!(child)); + } + } + + Unseen(node) => { + stack_node = node; + + tarjan_stack.push(node); + + // It is safe to unwrap here since the + // condition of the if clause already serves + // as a guard. + *indices.get_mut(node).unwrap() = next_index; + *lowlinks.get_mut(node).unwrap() = next_index; + *waiting.get_mut(node).unwrap() = true; + + next_index += 1; + + let mut node_index: Option<usize> = None; + + for child in g.children_of(node)? { + // Ignore self-loops + if node == child { + continue; + } + + match indices.get(child).copied() { + Some(0) => { + match node_index { + Some(index) => match recursive_stack.get_mut(index) { + Some(Seen(_, children)) => { + children.push(child); + } + Some(_) => { + unreachable!("wrong index: {index}"); + } + None => { + unreachable!("index {index} out of bounds"); + } + }, + None => { + node_index = Some(recursive_stack.len()); + + let mut children = Vec::with_capacity(g.degree(node)?); + children.push(child); + + recursive_stack.push(Seen(node, children)); + } + } + + recursive_stack.push(Unseen(child)); + } + Some(_) if waiting.get(child).copied().unwrap() => { + *lowlinks.get_mut(node).unwrap() = + std::cmp::min(lowlink!(node), index!(child)); + } + None => { + return Err(Error::IndexOutOfBounds(child, g.nodes_len())); + } + _ => { + // crossing edges are ignored + } + } + } + + if node_index.is_some() { + continue 'recursion; + } + } + } + + if lowlink!(stack_node) == index!(stack_node) { + let mut component: Vec<usize> = Vec::new(); + + while let Some(top) = tarjan_stack.pop() { + *waiting.get_mut(top).unwrap() = false; + + component.push(top); + + if top == stack_node { + components.push(component); + + break; + } + } + } + } + } + } + + Ok(components) +} + +#[cfg(test)] +mod tests { + use super::*; + use graph::adlist::{ALGBuilder, ALGraph}; + use graph::builder::Builder; + + use std::collections::BTreeSet as Set; + + fn make_cycle(n: usize) -> Result<ALGraph, graph::error::Error> { + let mut builder = ALGBuilder::default(); + + builder.add_vertices(n); + + for i in 0..(n - 1) { + builder.add_edge(i, i + 1, ())?; + } + + builder.add_edge(n - 1, 0, ())?; + + Ok(builder.build()) + } + + fn make_two_cycles(n: usize) -> Result<ALGraph, graph::error::Error> { + let mut builder = ALGBuilder::default(); + + builder.add_vertices(2 * n); + + for i in 0..(2 * n - 1) { + builder.add_edge(i, i + 1, ())?; + } + + builder.add_edge(n - 1, 0, ())?; + builder.add_edge(n - 2, 0, ())?; // random noise + builder.add_edge(0, n - 1, ())?; // random noise + builder.add_edge(0, 2 * n - 1, ())?; // random noise + builder.add_edge(2 * n - 1, n, ())?; + + Ok(builder.build()) + } + + #[test] + fn test_cycle() -> Result<(), Box<dyn std::error::Error>> { + let length = 10; + + let cycle = make_cycle(length)?; + + let components = tarjan::<_, ALGraph>(&cycle)?; + + println!("components = {components:?}"); + + assert_eq!(components.len(), 1); + + let set: Set<usize> = components.first().unwrap().into_iter().copied().collect(); + + let answer: Set<usize> = (0..length).collect(); + + assert_eq!(set, answer); + + Ok(()) + } + + #[test] + fn test_two_components() -> Result<(), Box<dyn std::error::Error>> { + let half_length = 10; + + let graph = make_two_cycles(half_length)?; + + let components = tarjan::<_, ALGraph>(graph)?; + + println!("components = {components:?}"); + + assert_eq!(components.len(), 2); + + let first_set: Set<usize> = components.get(0).unwrap().into_iter().copied().collect(); + let first_answer: Set<usize> = (half_length..(2 * half_length)).collect(); + + let second_set: Set<usize> = components.get(1).unwrap().into_iter().copied().collect(); + let second_answer: Set<usize> = (0..half_length).collect(); + + assert_eq!(first_set, first_answer); + assert_eq!(second_set, second_answer); + + Ok(()) + } +} |