summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJSDurand <mmemmew@gmail.com>2023-08-04 10:07:14 +0800
committerJSDurand <mmemmew@gmail.com>2023-08-04 10:07:14 +0800
commite64172f7909c71dc609099d3f5c4666d063653eb (patch)
tree1af3ed103f164cc572940afa25e4eced032231a4
parent81854107bcf0b4480cfb11e8af7fec6894240c0c (diff)
chain/default: Add funtion `print_current`
* chain/src/default.rs: This is useful for debugging the chain-rule machine.
-rw-r--r--chain/src/default.rs76
1 files changed, 74 insertions, 2 deletions
diff --git a/chain/src/default.rs b/chain/src/default.rs
index 618e560..1e0eba2 100644
--- a/chain/src/default.rs
+++ b/chain/src/default.rs
@@ -255,6 +255,66 @@ impl DefaultChain {
Ok(())
}
+
+ /// Print the graph that is relevant to the current node.
+ pub fn print_current(&self, filename: &str) -> Result<(), Box<dyn std::error::Error>> {
+ let filename = format!("output/{filename}");
+
+ let preamble = "digraph nfa {
+ fontname=\"Helvetica,Arial,sans-serif\"
+ node [fontname=\"Helvetica,Arial,sans-serif\", ordering=out]
+ edge [fontname=\"Helvetica,Arial,sans-serif\"]
+ rankdir=LR;\n";
+
+ let nodes_len = self.graph.nodes_len();
+
+ let mut relevant_nodes: HashSet<usize> = HashSet::with_capacity(nodes_len);
+
+ let mut stack: Vec<usize> = Vec::with_capacity(nodes_len);
+
+ stack.push(self.current);
+
+ while let Some(top) = stack.pop() {
+ if relevant_nodes.contains(&top) {
+ continue;
+ }
+
+ relevant_nodes.insert(top);
+
+ stack.extend(self.graph.children_of(top)?);
+ }
+
+ let relevant_nodes = relevant_nodes;
+
+ let mut post = String::new();
+
+ for (source, target) in self.edges() {
+ if !relevant_nodes.contains(&source) || !relevant_nodes.contains(&target) {
+ continue;
+ }
+
+ for label in self.edge_label(source, target).unwrap() {
+ post.push_str(&format!(" {source} -> {target} [label = \"{label}\"]\n"));
+ }
+ }
+
+ post.push_str("}\n");
+
+ let result = format!("{preamble}{post}");
+
+ if std::fs::metadata(&filename).is_ok() {
+ std::fs::remove_file(&filename)?;
+ }
+
+ let mut file = std::fs::File::options()
+ .write(true)
+ .create(true)
+ .open(&filename)?;
+
+ use std::io::Write;
+
+ file.write_all(result.as_bytes()).map_err(Into::into)
+ }
}
impl LabelGraph<Edge> for DefaultChain {
@@ -674,6 +734,10 @@ impl Chain for DefaultChain {
.is_accepting(virtual_node)
.map_err(index_out_of_bounds_conversion)?;
+ // if pos == 9 {
+ // dbg!(label, atom_label, self.current, &self.history);
+ // }
+
let first_segment_pavi: PaVi;
let virtual_pavi: PaVi;
@@ -843,7 +907,7 @@ impl Chain for DefaultChain {
.vertex_label(root)?
.ok_or(Error::NodeNoLabel(root))?;
- dbg!(root_degree, root_label);
+ // dbg!(root_degree, root_label);
// First perform reduction.
@@ -1182,11 +1246,15 @@ mod test_chain {
let grammar: Grammar = grammar_str.parse()?;
let atom = DefaultAtom::from_grammar(grammar)?;
+
let mut chain = DefaultChain::unit(atom)?;
let no_item = false;
- let input: &[usize] = &[3, 0, 2, 1, 1, 0, 1, 4, 0, 2, 1];
+ let input: &[usize] = &[
+ 3, 0, 2, 1, 5, 0, 6, 1, 3, 0, 2,
+ 1, // 5, 0, 6, 1, 3, 0, 2, 1, 5, 0, 6, 1, 4, 0, 2, 1,
+ ];
let input_len = input.len();
@@ -1196,11 +1264,15 @@ mod test_chain {
for (pos, t) in input.iter().copied().enumerate().take(input_len) {
chain.chain(t, pos, no_item)?;
if to_print {
+ chain.print_current(&format!("chain {pos}.gv"))?;
chain.forest().print_viz(&format!("forest {pos}.gv"))?;
}
dbg!(pos, t);
}
+ chain.print_viz("chain.gv")?;
+ chain.atom.print_viz("nfa.gv")?;
+
// let _ = chain.forest.print_viz("forest before extraction.gv");
let extracted = chain.end_of_input(input_len, input[input_len - 1])?;