From e64172f7909c71dc609099d3f5c4666d063653eb Mon Sep 17 00:00:00 2001 From: JSDurand Date: Fri, 4 Aug 2023 10:07:14 +0800 Subject: chain/default: Add funtion `print_current` * chain/src/default.rs: This is useful for debugging the chain-rule machine. --- chain/src/default.rs | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 2 deletions(-) (limited to 'chain') diff --git a/chain/src/default.rs b/chain/src/default.rs index 618e560..1e0eba2 100644 --- a/chain/src/default.rs +++ b/chain/src/default.rs @@ -255,6 +255,66 @@ impl DefaultChain { Ok(()) } + + /// Print the graph that is relevant to the current node. + pub fn print_current(&self, filename: &str) -> Result<(), Box> { + let filename = format!("output/{filename}"); + + let preamble = "digraph nfa { + fontname=\"Helvetica,Arial,sans-serif\" + node [fontname=\"Helvetica,Arial,sans-serif\", ordering=out] + edge [fontname=\"Helvetica,Arial,sans-serif\"] + rankdir=LR;\n"; + + let nodes_len = self.graph.nodes_len(); + + let mut relevant_nodes: HashSet = HashSet::with_capacity(nodes_len); + + let mut stack: Vec = Vec::with_capacity(nodes_len); + + stack.push(self.current); + + while let Some(top) = stack.pop() { + if relevant_nodes.contains(&top) { + continue; + } + + relevant_nodes.insert(top); + + stack.extend(self.graph.children_of(top)?); + } + + let relevant_nodes = relevant_nodes; + + let mut post = String::new(); + + for (source, target) in self.edges() { + if !relevant_nodes.contains(&source) || !relevant_nodes.contains(&target) { + continue; + } + + for label in self.edge_label(source, target).unwrap() { + post.push_str(&format!(" {source} -> {target} [label = \"{label}\"]\n")); + } + } + + post.push_str("}\n"); + + let result = format!("{preamble}{post}"); + + if std::fs::metadata(&filename).is_ok() { + std::fs::remove_file(&filename)?; + } + + let mut file = std::fs::File::options() + .write(true) + .create(true) + .open(&filename)?; + + use std::io::Write; + + file.write_all(result.as_bytes()).map_err(Into::into) + } } impl LabelGraph for DefaultChain { @@ -674,6 +734,10 @@ impl Chain for DefaultChain { .is_accepting(virtual_node) .map_err(index_out_of_bounds_conversion)?; + // if pos == 9 { + // dbg!(label, atom_label, self.current, &self.history); + // } + let first_segment_pavi: PaVi; let virtual_pavi: PaVi; @@ -843,7 +907,7 @@ impl Chain for DefaultChain { .vertex_label(root)? .ok_or(Error::NodeNoLabel(root))?; - dbg!(root_degree, root_label); + // dbg!(root_degree, root_label); // First perform reduction. @@ -1182,11 +1246,15 @@ mod test_chain { let grammar: Grammar = grammar_str.parse()?; let atom = DefaultAtom::from_grammar(grammar)?; + let mut chain = DefaultChain::unit(atom)?; let no_item = false; - let input: &[usize] = &[3, 0, 2, 1, 1, 0, 1, 4, 0, 2, 1]; + let input: &[usize] = &[ + 3, 0, 2, 1, 5, 0, 6, 1, 3, 0, 2, + 1, // 5, 0, 6, 1, 3, 0, 2, 1, 5, 0, 6, 1, 4, 0, 2, 1, + ]; let input_len = input.len(); @@ -1196,11 +1264,15 @@ mod test_chain { for (pos, t) in input.iter().copied().enumerate().take(input_len) { chain.chain(t, pos, no_item)?; if to_print { + chain.print_current(&format!("chain {pos}.gv"))?; chain.forest().print_viz(&format!("forest {pos}.gv"))?; } dbg!(pos, t); } + chain.print_viz("chain.gv")?; + chain.atom.print_viz("nfa.gv")?; + // let _ = chain.forest.print_viz("forest before extraction.gv"); let extracted = chain.end_of_input(input_len, input[input_len - 1])?; -- cgit v1.2.3-18-g5258