diff options
author | JSDurand <mmemmew@gmail.com> | 2023-02-28 15:46:50 +0800 |
---|---|---|
committer | JSDurand <mmemmew@gmail.com> | 2023-02-28 15:46:50 +0800 |
commit | b63a9c05d7f86320b6cfbb4569b1880f2b804eb9 (patch) | |
tree | 13b799bc430d70ee668c8aef4b1ef499864e84bf /chain/src/lib.rs | |
parent | f7be4c8eb8de9f525584b7fa4e53978ad233d5e4 (diff) | |
parent | b22a5b38161fbc4a0bb9e472d42f78311b73741e (diff) |
Merge from master
Diffstat (limited to 'chain/src/lib.rs')
-rw-r--r-- | chain/src/lib.rs | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/chain/src/lib.rs b/chain/src/lib.rs index d7fc519..9de1df7 100644 --- a/chain/src/lib.rs +++ b/chain/src/lib.rs @@ -258,11 +258,9 @@ pub trait Chain: LabelExtGraph<Edge> { /// An iterator that iterates all layers that need to be merged. type DerIter: Iterator<Item = TwoLayers>; - // FIXME: Add a parameter to control whether to manipulate the - // forests or not. - /// Take the derivative by a terminal `t` at position `pos`. - fn derive(&mut self, t: usize, pos: usize) -> Result<Self::DerIter, Self::Error>; + fn derive(&mut self, t: usize, pos: usize, no_item: bool) + -> Result<Self::DerIter, Self::Error>; /// Take the union of all derivatives. fn union(&mut self, der_iter: Self::DerIter) -> Result<Vec<(Roi, usize)>, Self::Error> { @@ -279,8 +277,18 @@ pub trait Chain: LabelExtGraph<Edge> { /// Use chain rule to compute the derivative with respect to a /// terminal. - fn chain(&mut self, t: usize, pos: usize) -> Result<(), Self::Error> { - let der_iter = self.derive(t, pos)?; + /// + /// # Arguments + /// + /// The argument `t` is the terminal we computet the derivative + /// with. + /// + /// The argument `pos` is the position within the input. + /// + /// The argument `no_item` determines whether we want the item + /// derivation forest as well. + fn chain(&mut self, t: usize, pos: usize, no_item: bool) -> Result<(), Self::Error> { + let der_iter = self.derive(t, pos, no_item)?; let edges = self.union(der_iter)?; |