summaryrefslogtreecommitdiff
path: root/chain/src/lib.rs
diff options
context:
space:
mode:
authorJSDurand <mmemmew@gmail.com>2023-02-28 15:46:50 +0800
committerJSDurand <mmemmew@gmail.com>2023-02-28 15:46:50 +0800
commitb63a9c05d7f86320b6cfbb4569b1880f2b804eb9 (patch)
tree13b799bc430d70ee668c8aef4b1ef499864e84bf /chain/src/lib.rs
parentf7be4c8eb8de9f525584b7fa4e53978ad233d5e4 (diff)
parentb22a5b38161fbc4a0bb9e472d42f78311b73741e (diff)
Merge from master
Diffstat (limited to 'chain/src/lib.rs')
-rw-r--r--chain/src/lib.rs20
1 files changed, 14 insertions, 6 deletions
diff --git a/chain/src/lib.rs b/chain/src/lib.rs
index d7fc519..9de1df7 100644
--- a/chain/src/lib.rs
+++ b/chain/src/lib.rs
@@ -258,11 +258,9 @@ pub trait Chain: LabelExtGraph<Edge> {
/// An iterator that iterates all layers that need to be merged.
type DerIter: Iterator<Item = TwoLayers>;
- // FIXME: Add a parameter to control whether to manipulate the
- // forests or not.
-
/// Take the derivative by a terminal `t` at position `pos`.
- fn derive(&mut self, t: usize, pos: usize) -> Result<Self::DerIter, Self::Error>;
+ fn derive(&mut self, t: usize, pos: usize, no_item: bool)
+ -> Result<Self::DerIter, Self::Error>;
/// Take the union of all derivatives.
fn union(&mut self, der_iter: Self::DerIter) -> Result<Vec<(Roi, usize)>, Self::Error> {
@@ -279,8 +277,18 @@ pub trait Chain: LabelExtGraph<Edge> {
/// Use chain rule to compute the derivative with respect to a
/// terminal.
- fn chain(&mut self, t: usize, pos: usize) -> Result<(), Self::Error> {
- let der_iter = self.derive(t, pos)?;
+ ///
+ /// # Arguments
+ ///
+ /// The argument `t` is the terminal we computet the derivative
+ /// with.
+ ///
+ /// The argument `pos` is the position within the input.
+ ///
+ /// The argument `no_item` determines whether we want the item
+ /// derivation forest as well.
+ fn chain(&mut self, t: usize, pos: usize, no_item: bool) -> Result<(), Self::Error> {
+ let der_iter = self.derive(t, pos, no_item)?;
let edges = self.union(der_iter)?;