diff options
author | JSDurand <mmemmew@gmail.com> | 2023-02-28 15:46:50 +0800 |
---|---|---|
committer | JSDurand <mmemmew@gmail.com> | 2023-02-28 15:46:50 +0800 |
commit | b63a9c05d7f86320b6cfbb4569b1880f2b804eb9 (patch) | |
tree | 13b799bc430d70ee668c8aef4b1ef499864e84bf /chain | |
parent | f7be4c8eb8de9f525584b7fa4e53978ad233d5e4 (diff) | |
parent | b22a5b38161fbc4a0bb9e472d42f78311b73741e (diff) |
Merge from master
Diffstat (limited to 'chain')
-rw-r--r-- | chain/src/default.rs | 226 | ||||
-rw-r--r-- | chain/src/lib.rs | 20 |
2 files changed, 151 insertions, 95 deletions
diff --git a/chain/src/default.rs b/chain/src/default.rs index cd0a898..da665bf 100644 --- a/chain/src/default.rs +++ b/chain/src/default.rs @@ -498,7 +498,12 @@ impl Chain for DefaultChain { // Of course we can use an optional vector to prevent allocating // too much memory for edges whose corresponding vector is empty. - fn derive(&mut self, t: usize, pos: usize) -> Result<Self::DerIter, Self::Error> { + fn derive( + &mut self, + t: usize, + pos: usize, + no_item: bool, + ) -> Result<Self::DerIter, Self::Error> { use TNT::*; /// A helper function to generate edges to join. @@ -623,35 +628,41 @@ impl Chain for DefaultChain { match *atom_label.get_value() { Some(Ter(ter)) if ter == t => { - // prepare forest fragment - - let fragment = - generate_fragment([atom_moved.into(), Ter(ter).into()], pos)?; - - if pos == 4 { - dbg!(atom_moved, label); - self.forest - .print_viz(&format!( - "pos4tb - {atom_moved}-{:?}.gv", - label.true_source() - )) - .unwrap(); - } + let new_pavi: PaVi; - let new_pavi = self.forest.insert_item( - *label, - fragment, - atom_child_iter.clone(), - &self.atom, - )?; + if !no_item { + // prepare forest fragment + + let fragment = + generate_fragment([atom_moved.into(), Ter(ter).into()], pos)?; + + if pos == 4 { + dbg!(atom_moved, label); + self.forest + .print_viz(&format!( + "pos4tb - {atom_moved}-{:?}.gv", + label.true_source() + )) + .unwrap(); + } + + new_pavi = self.forest.insert_item( + *label, + fragment, + atom_child_iter.clone(), + &self.atom, + )?; - if pos == 4 { - self.forest - .print_viz(&format!( - "pos4ta - {atom_moved}-{:?}.gv", - label.true_source() - )) - .unwrap(); + if pos == 4 { + self.forest + .print_viz(&format!( + "pos4ta - {atom_moved}-{:?}.gv", + label.true_source() + )) + .unwrap(); + } + } else { + new_pavi = PaVi::default(); } let accepting = generate_edges( @@ -682,52 +693,60 @@ impl Chain for DefaultChain { .map_err(index_out_of_bounds_conversion)?; if let Some(virtual_node) = virtual_node { - let first_fragment = - generate_fragment([atom_moved.into(), Non(non).into()], pos)?; - - if pos == 4 { - dbg!(atom_moved, label); - self.forest - .print_viz(&format!("pos4nb - {atom_moved}-{label:?}.gv")) - .unwrap(); - } - - let first_segment_pavi = self.forest.insert_item( - *label, - first_fragment, - atom_child_iter.clone(), - &self.atom, - )?; - - if pos == 4 { - self.forest - .print_viz(&format!("pos4na - {atom_moved}-{label:?}.gv")) - .unwrap(); - } - let accepting = self .atom .is_accepting(virtual_node) .map_err(index_out_of_bounds_conversion)?; - let virtual_fragment = - DefaultForest::new_leaf(GrammarLabel::new(Ter(t), pos)); - - // NOTE: We only need the PaVi from the - // first segment, so we pass an empty - // iterator, in which case the passed - // label is only used for the PaVi. - let virtual_pavi = self.forest.insert_item( - Edge::new(0, first_segment_pavi, accepting), - virtual_fragment, - std::iter::empty(), - &self.atom, - )?; + let first_segment_pavi: PaVi; + let virtual_pavi: PaVi; - if pos == 4 { - self.forest - .print_viz(&format!("pos4va - {atom_moved}-{:?}.gv", label)) - .unwrap(); + if !no_item { + let first_fragment = + generate_fragment([atom_moved.into(), Non(non).into()], pos)?; + + if pos == 4 { + dbg!(atom_moved, label); + self.forest + .print_viz(&format!("pos4nb - {atom_moved}-{:?}.gv", label)) + .unwrap(); + } + + first_segment_pavi = self.forest.insert_item( + *label, + first_fragment, + atom_child_iter.clone(), + &self.atom, + )?; + + if pos == 4 { + self.forest + .print_viz(&format!("pos4na - {atom_moved}-{:?}.gv", label)) + .unwrap(); + } + + let virtual_fragment = + DefaultForest::new_leaf(GrammarLabel::new(Ter(t), pos)); + + // NOTE: We only need the PaVi from the + // first segment, so we pass an empty + // iterator, in which case the passed + // label is only used for the PaVi. + virtual_pavi = self.forest.insert_item( + Edge::new(0, first_segment_pavi, accepting), + virtual_fragment, + std::iter::empty(), + &self.atom, + )?; + + if pos == 4 { + self.forest + .print_viz(&format!("pos4va - {atom_moved}-{:?}.gv", label)) + .unwrap(); + } + } else { + first_segment_pavi = PaVi::default(); + virtual_pavi = PaVi::default(); } let mut new_edges = Vec::new(); @@ -1037,26 +1056,37 @@ mod test_chain { #[test] fn base_test() -> Result<(), Box<dyn std::error::Error>> { + let mut no_item = false; + + for arg in std::env::args() { + if arg == "no_item" { + no_item = true; + break; + } + } + let grammar = new_notes_grammar()?; let atom = DefaultAtom::from_grammar(grammar)?; let mut chain = DefaultChain::unit(atom)?; - chain.chain(3, 00)?; - chain.chain(1, 01)?; - chain.chain(2, 02)?; - chain.chain(2, 03)?; - chain.chain(2, 04)?; - chain.chain(0, 05)?; - chain.chain(5, 06)?; - chain.chain(1, 07)?; - chain.chain(6, 08)?; - chain.chain(6, 09)?; - chain.chain(6, 10)?; - chain.chain(0, 11)?; - chain.chain(0, 12)?; - - chain.end_of_input()?; + chain.chain(3, 00, no_item)?; + chain.chain(1, 01, no_item)?; + chain.chain(2, 02, no_item)?; + chain.chain(2, 03, no_item)?; + chain.chain(2, 04, no_item)?; + chain.chain(0, 05, no_item)?; + chain.chain(5, 06, no_item)?; + chain.chain(1, 07, no_item)?; + chain.chain(6, 08, no_item)?; + chain.chain(6, 09, no_item)?; + chain.chain(6, 10, no_item)?; + chain.chain(0, 11, no_item)?; + chain.chain(0, 12, no_item)?; + + if !no_item { + chain.end_of_input()?; + } for label in chain.labels_of(chain.current())?.map(|(label, _)| label) { dbg!(label); @@ -1079,20 +1109,29 @@ mod test_chain { #[test] fn test_ambiguity() -> Result<(), Box<dyn std::error::Error>> { + let mut no_item = false; + + for arg in std::env::args() { + if arg == "no_item" { + no_item = true; + break; + } + } + let grammar = new_paren_grammar()?; let atom = DefaultAtom::from_grammar(grammar)?; let mut chain = DefaultChain::unit(atom)?; - chain.chain(0, 0)?; + chain.chain(0, 0, no_item)?; chain.forest.print_viz("forest0.gv")?; - chain.chain(2, 1)?; + chain.chain(2, 1, no_item)?; chain.forest.print_viz("forest1.gv")?; - chain.chain(2, 2)?; + chain.chain(2, 2, no_item)?; chain.forest.print_viz("forest2.gv")?; - chain.chain(2, 3)?; + chain.chain(2, 3, no_item)?; chain.forest.print_viz("forest3.gv")?; - chain.chain(1, 4)?; + chain.chain(1, 4, no_item)?; chain.forest.print_viz("forest4.gv")?; chain.end_of_input()?; chain.forest.print_viz("forest.gv")?; @@ -1127,6 +1166,15 @@ mod test_chain { #[test] fn test_speed() -> Result<(), Box<dyn std::error::Error>> { + let mut no_item = false; + + for arg in std::env::args() { + if arg == "no_item" { + no_item = true; + break; + } + } + let grammar = new_notes_grammar_no_regexp()?; println!("grammar: {grammar}"); @@ -1161,7 +1209,7 @@ mod test_chain { let start = std::time::Instant::now(); for (index, t) in input.iter().copied().enumerate() { - chain.chain(t, index)?; + chain.chain(t, index, no_item)?; } let elapsed = start.elapsed(); diff --git a/chain/src/lib.rs b/chain/src/lib.rs index d7fc519..9de1df7 100644 --- a/chain/src/lib.rs +++ b/chain/src/lib.rs @@ -258,11 +258,9 @@ pub trait Chain: LabelExtGraph<Edge> { /// An iterator that iterates all layers that need to be merged. type DerIter: Iterator<Item = TwoLayers>; - // FIXME: Add a parameter to control whether to manipulate the - // forests or not. - /// Take the derivative by a terminal `t` at position `pos`. - fn derive(&mut self, t: usize, pos: usize) -> Result<Self::DerIter, Self::Error>; + fn derive(&mut self, t: usize, pos: usize, no_item: bool) + -> Result<Self::DerIter, Self::Error>; /// Take the union of all derivatives. fn union(&mut self, der_iter: Self::DerIter) -> Result<Vec<(Roi, usize)>, Self::Error> { @@ -279,8 +277,18 @@ pub trait Chain: LabelExtGraph<Edge> { /// Use chain rule to compute the derivative with respect to a /// terminal. - fn chain(&mut self, t: usize, pos: usize) -> Result<(), Self::Error> { - let der_iter = self.derive(t, pos)?; + /// + /// # Arguments + /// + /// The argument `t` is the terminal we computet the derivative + /// with. + /// + /// The argument `pos` is the position within the input. + /// + /// The argument `no_item` determines whether we want the item + /// derivation forest as well. + fn chain(&mut self, t: usize, pos: usize, no_item: bool) -> Result<(), Self::Error> { + let der_iter = self.derive(t, pos, no_item)?; let edges = self.union(der_iter)?; |