summaryrefslogtreecommitdiff
path: root/chain/src/default.rs
diff options
context:
space:
mode:
Diffstat (limited to 'chain/src/default.rs')
-rw-r--r--chain/src/default.rs784
1 files changed, 776 insertions, 8 deletions
diff --git a/chain/src/default.rs b/chain/src/default.rs
index e04be9f..697b997 100644
--- a/chain/src/default.rs
+++ b/chain/src/default.rs
@@ -6,18 +6,47 @@
//! modular design makes that easy.
use super::*;
+use crate::atom::{Atom, DefaultAtom};
use core::fmt::Display;
+use forest::{default::DefaultForest, Forest};
+use grammar::{Error as GrammarError, GrammarLabel, GrammarLabelType, TNT};
+#[allow(unused_imports)]
+use graph::{
+ labelled::DLGBuilder, Builder, DLGraph, Graph, LabelExtGraph, LabelGraph, ParentsGraph,
+};
+
+use std::collections::{HashMap as Map, TryReserveError};
/// The errors related to taking derivatives by chain rule.
+#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
+ /// General error for indices out of bounds.
+ ///
+ /// The payload is `(index, bound)`: the offending index followed
+ /// by the bound it violated.
+ IndexOutOfBounds(usize, usize),
+ /// The forest encounters a duplicate node, for some reason.
+ ///
+ /// The payload is the node in question.
+ DuplicateNode(usize),
+ /// The chain rule machine encounters a duplicate edge, for some
+ /// reason.
+ ///
+ /// The payload is `(source, target)` of the edge.
+ DuplicateEdge(usize, usize),
+ /// A node has no labels while it is required to have one.
+ ///
+ /// The payload is the node in question.
+ NodeNoLabel(usize),
+ /// Reserving memory fails.
+ CannotReserve(TryReserveError),
/// An invalid situation happens.
Invalid,
}
impl Display for Error {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ /// Write a short, human-readable description of the error,
+ /// including the payload of the variant, if any.
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
+ Self::IndexOutOfBounds(index, bound) => write!(f, "index {index} out of bound {bound}"),
+ Self::DuplicateNode(n) => write!(f, "the forest has a node {n} with a duplicate label"),
+ Self::DuplicateEdge(source, target) => write!(
+ f,
+ "the forest has a duplicate edge from {source} to {target}"
+ ),
+ Self::NodeNoLabel(n) => write!(f, "node {n} has no labels while it should have one"),
+ Self::CannotReserve(e) => write!(f, "cannot reserve memory: {e}"),
Self::Invalid => write!(f, "invalid"),
}
}
@@ -25,22 +54,761 @@ impl Display for Error {
+// `Debug` is derived and `Display` is implemented for `Error`, so the
+// provided methods of the standard error trait are used as they are.
impl std::error::Error for Error {}
+impl From<GError> for Error {
+    /// Translate a graph error into the matching chain-rule error.
+    ///
+    /// Graph errors without a dedicated counterpart here collapse
+    /// into [`Error::Invalid`].
+    fn from(graph_error: GError) -> Self {
+        match graph_error {
+            GError::IndexOutOfBounds(index, bound) => Error::IndexOutOfBounds(index, bound),
+            GError::DuplicatedNode(node) => Error::DuplicateNode(node),
+            GError::DuplicatedEdge(from, to) => Error::DuplicateEdge(from, to),
+            _ => Error::Invalid,
+        }
+    }
+}
+
+impl From<ForestError> for Error {
+    /// Translate a forest error into the matching chain-rule error.
+    ///
+    /// A graph error wrapped by the forest is converted through the
+    /// conversion from graph errors.
+    fn from(forest_error: ForestError) -> Self {
+        match forest_error {
+            ForestError::IndexOutOfBounds(index, bound) => Self::IndexOutOfBounds(index, bound),
+            ForestError::DuplicatedNode(node) => Self::DuplicateNode(node),
+            ForestError::InvalidGraphError(graph_error) => Self::from(graph_error),
+            ForestError::NodeNoLabel(node) => Self::NodeNoLabel(node),
+        }
+    }
+}
+
+impl From<TryReserveError> for Error {
+    /// Wrap a failed memory reservation, so that `?` can be used on
+    /// fallible reservations.
+    fn from(reserve_error: TryReserveError) -> Self {
+        Error::CannotReserve(reserve_error)
+    }
+}
+
+/// The type of an index into an element in [`DerIter`].
+#[derive(Debug, Copy, Clone)]
+enum DerIterIndex {
+ /// An index into the `singles` array of [`DerIter`].
+ Single(usize),
+ /// An index into the `second_array` array of [`DerIter`], whose
+ /// elements are keys of the map of two-layered edges.
+ Map(usize),
+}
+
+impl Default for DerIterIndex {
+    /// Start at the first key of the two-layered edges, since those
+    /// are iterated before the single-layered ones.
+    fn default() -> Self {
+        DerIterIndex::Map(0)
+    }
+}
+
+/// A complex type used for storing values of edges with two layers.
+///
+/// The components are, in order, the forest source, whether the
+/// entry is accepting, and the grouped edges with their
+/// destinations.
+type SecondTypeValue = (Parent, bool, Vec<(Edge, usize)>);
+
+/// An iterator of [`TwoLayers`].
+///
+/// The two-layered edges are yielded first, in the order in which
+/// their labels were first added, followed by the single-layered
+/// edges.
+#[derive(Debug, Default)]
+pub struct DerIter {
+ /// Stores edges of only one layer.
+ singles: Vec<(Edge, usize)>,
+ /// Stores edges with two layers. They are grouped by their
+ /// labels of the second layer.
+ ///
+ /// The values are tuples (forest_source, accepting, edges), where
+ /// the edges are the grouped edges of the first layer and the
+ /// destination.
+ seconds: Map<usize, SecondTypeValue>,
+ /// We want to iterate the elements of the map, for which purpose
+ /// we need an array. Since hashmaps provide no arrays, we keep
+ /// an array of keys for iteration purposes.
+ second_array: Vec<usize>,
+ /// The index of the current element, either in `second_array` or
+ /// in `singles`.
+ index: DerIterIndex,
+}
+
+impl DerIter {
+    /// Record two-layered edges under the second-layer `label`.
+    ///
+    /// If `label` is already present, only `edges` are appended to
+    /// the existing group; the `forest_source` and `accepting` stored
+    /// by the first insertion are kept.  Otherwise a new group is
+    /// created and `label` is appended to the array of keys, so that
+    /// the groups are later iterated in insertion order.
+    fn add_second_layer(
+        &mut self,
+        label: usize,
+        forest_source: Parent,
+        accepting: bool,
+        edges: Vec<(Edge, usize)>,
+    ) {
+        match self.seconds.entry(label) {
+            std::collections::hash_map::Entry::Occupied(mut group) => {
+                group.get_mut().2.extend(edges);
+            }
+            std::collections::hash_map::Entry::Vacant(slot) => {
+                slot.insert((forest_source, accepting, edges));
+
+                self.second_array.push(label);
+            }
+        }
+    }
+}
+
+impl Iterator for DerIter {
+    type Item = TwoLayers;
+
+    /// Yield the next element.
+    ///
+    /// The two-layered edges are yielded first, one group per key in
+    /// `second_array`, and each group is removed from the map as it
+    /// is yielded.  Once the keys are exhausted, the single-layered
+    /// edges are yielded in order.
+    ///
+    /// If a key has no group in the map, which indicates an internal
+    /// inconsistency, a diagnostic is written to the standard error
+    /// and `None` is returned.
+    fn next(&mut self) -> Option<Self::Item> {
+        // We iterate through two layered edges first.
+        let single_index = match self.index {
+            DerIterIndex::Map(index) => {
+                if let Some(key) = self.second_array.get(index).copied() {
+                    return match self.seconds.remove(&key) {
+                        Some((forest_source, accepting, edges)) => {
+                            self.index = DerIterIndex::Map(index + 1);
+
+                            Some(TwoLayers::Two(key, forest_source, accepting, edges))
+                        }
+                        None => {
+                            // This should not happen.  Report on the
+                            // standard error, so that the diagnostic
+                            // does not end up in the regular output of
+                            // the program.
+                            eprintln!("a key does not exist in the hashmap: something is wrong when taking derivatives");
+
+                            None
+                        }
+                    };
+                }
+
+                // The keys are exhausted: switch to the single-layered
+                // edges, starting from the first one.
+                self.index = DerIterIndex::Single(0);
+
+                0
+            }
+            DerIterIndex::Single(index) => index,
+        };
+
+        let (edge, to) = *self.singles.get(single_index)?;
+
+        self.index = DerIterIndex::Single(single_index + 1);
+
+        Some(TwoLayers::One(edge, to))
+    }
+}
+
/// A default implementation for the [`Chain`] trait.
#[derive(Debug, Clone, Default)]
-pub struct DefaultChain {}
+pub struct DefaultChain {
+ /// The underlying labelled graph. The graph traits implemented
+ /// for this type delegate to it.
+ graph: DLGraph<Edge>,
+ /// The atom whose labels are consulted when taking derivatives.
+ atom: DefaultAtom,
+ /// The current node.
+ current: usize,
+ /// The former values of `current`, oldest first.
+ history: Vec<usize>,
+ /// The forest associated with the chain.
+ forest: DefaultForest<GrammarLabel>,
+ /// Whether each node is accepting, indexed by node.
+ accepting_vec: Vec<bool>,
+}
+
+impl DefaultChain {
+    /// The node the chain is currently at.
+    #[inline]
+    pub fn current(&self) -> usize {
+        self.current
+    }
+
+    /// The nodes that were current before, oldest first.
+    #[inline]
+    pub fn history(&self) -> &[usize] {
+        &self.history
+    }
+
+    /// Borrow the forest associated with the chain.
+    #[inline]
+    pub fn forest(&self) -> &DefaultForest<GrammarLabel> {
+        &self.forest
+    }
+
+    /// Print the rule positions of the labels.
+    ///
+    /// Every distinct label found on the edges of the graph is
+    /// printed once, on its own line.
+    pub fn print_rule_positions(&self) -> Result<(), Box<dyn std::error::Error>> {
+        let mut seen = std::collections::HashSet::<usize>::default();
+
+        for node in 0..self.graph.nodes_len() {
+            for (edge, _) in self.graph.labels_of(node)? {
+                seen.insert(edge.label);
+            }
+        }
+
+        for label in seen {
+            println!("{}", self.atom.rule_pos_string(label)?);
+        }
+
+        Ok(())
+    }
+}
+
+// Every query below, except for `replace_by_builder`, is forwarded
+// unchanged to the `graph` field.
+impl Graph for DefaultChain {
+ type Iter<'a> = <DLGraph<Edge> as Graph>::Iter<'a>
+ where
+ Self: 'a;
+
+ #[inline]
+ fn is_empty(&self) -> bool {
+ self.graph.is_empty()
+ }
+
+ #[inline]
+ fn nodes_len(&self) -> usize {
+ self.graph.nodes_len()
+ }
+
+ #[inline]
+ fn edges_len(&self) -> Option<usize> {
+ self.graph.edges_len()
+ }
+
+ #[inline]
+ fn children_of(&self, node_id: usize) -> Result<Self::Iter<'_>, GError> {
+ self.graph.children_of(node_id)
+ }
+
+ #[inline]
+ fn degree(&self, node_id: usize) -> Result<usize, GError> {
+ self.graph.degree(node_id)
+ }
+
+ #[inline]
+ fn is_empty_node(&self, node_id: usize) -> Result<bool, GError> {
+ self.graph.is_empty_node(node_id)
+ }
+
+ #[inline]
+ fn has_edge(&self, source: usize, target: usize) -> Result<bool, GError> {
+ self.graph.has_edge(source, target)
+ }
+
+ // Not supported yet: calling this always panics.
+ fn replace_by_builder(&mut self, _builder: impl graph::Builder<Result = Self>) {
+ unimplemented!("I shall refactor this")
+ }
+}
+
+// The associated iterator types are those of the underlying
+// `DLGraph<Edge>`, and every query is forwarded unchanged to the
+// `graph` field.
+impl LabelGraph<Edge> for DefaultChain {
+ type Iter<'a> = <DLGraph<Edge> as LabelGraph<Edge>>::Iter<'a>
+ where
+ Self: 'a;
+
+ type LabelIter<'a> = <DLGraph<Edge> as LabelGraph<Edge>>::LabelIter<'a>
+ where
+ Self: 'a,
+ Edge: 'a;
+
+ type EdgeLabelIter<'a> = <DLGraph<Edge> as LabelGraph<Edge>>::EdgeLabelIter<'a>
+ where
+ Self: 'a,
+ Edge: 'a;
+
+ #[inline]
+ fn edge_label(&self, source: usize, target: usize) -> Result<Self::EdgeLabelIter<'_>, GError> {
+ self.graph.edge_label(source, target)
+ }
+
+ #[inline]
+ fn find_children_with_label(
+ &self,
+ node_id: usize,
+ label: &Edge,
+ ) -> Result<<Self as LabelGraph<Edge>>::Iter<'_>, GError> {
+ self.graph.find_children_with_label(node_id, label)
+ }
+
+ #[inline]
+ fn labels_of(&self, node_id: usize) -> Result<Self::LabelIter<'_>, GError> {
+ self.graph.labels_of(node_id)
+ }
+
+ #[inline]
+ fn has_edge_label(&self, node_id: usize, label: &Edge, target: usize) -> Result<bool, GError> {
+ self.graph.has_edge_label(node_id, label, target)
+ }
+}
+
+impl LabelExtGraph<Edge> for DefaultChain {
+ /// Extend the underlying graph by `edges`, and return the node
+ /// reported by the underlying graph.
+ ///
+ /// If that node has no entry in `accepting_vec` yet, exactly one
+ /// entry is pushed for it, computed from the labels of the node
+ /// and the accepting statuses of its children.
+ ///
+ /// # Errors
+ ///
+ /// Errors of the underlying graph are propagated. An index out
+ /// of bounds error is returned if a child of the node has no
+ /// entry in `accepting_vec`.
+ #[inline]
+ fn extend(&mut self, edges: impl IntoIterator<Item = (Edge, usize)>) -> Result<usize, GError> {
+ let new = self.graph.extend(edges)?;
+ let accepting_len = self.accepting_vec.len();
+
+ if self.accepting_vec.get(new).is_none() {
+ // assert it can only grow by one node at a time.
+ #[cfg(debug_assertions)]
+ assert_eq!(new, accepting_len);
+
+ // Whether an entry has been pushed for the new node.
+ let mut updated = false;
+
+ for (label, child_iter) in self.graph.labels_of(new)? {
+ // Whether some child reached through this label is
+ // accepting.
+ let old_accepting = {
+ let mut result = false;
+ for child in child_iter {
+ if *self
+ .accepting_vec
+ .get(child)
+ .ok_or(GError::IndexOutOfBounds(child, accepting_len))?
+ {
+ result = true;
+ break;
+ }
+ }
+
+ result
+ };
+
+ // NOTE(review): the scan stops with `false` at the first
+ // label none of whose children is accepting, without
+ // looking at the remaining labels -- confirm that this
+ // is intended.
+ if !old_accepting {
+ self.accepting_vec.push(false);
+ updated = true;
+
+ break;
+ }
+
+ if label.is_accepting() {
+ self.accepting_vec.push(true);
+ updated = true;
+
+ break;
+ }
+ }
+
+ // No label settled the status: the node is not accepting.
+ if !updated {
+ self.accepting_vec.push(false);
+ }
+ }
+
+ Ok(new)
+ }
+}
impl Chain for DefaultChain {
type Error = Error;
- fn unit() -> Self {
- todo!()
+ type Atom = DefaultAtom;
+
+    /// Build the initial chain for `atom`.
+    ///
+    /// The graph consists of two vertices joined by one edge, whose
+    /// label is made of the empty state of the atom and of whether
+    /// the non-terminal 0 is nullable.  The forest is a single leaf
+    /// labelled by the non-terminal 0.
+    ///
+    /// # Errors
+    ///
+    /// An index out of bounds error is returned if the nullability
+    /// of the non-terminal 0 cannot be determined.
+    fn unit(atom: Self::Atom) -> Result<Self, Self::Error> {
+        let mut builder = DLGBuilder::<Edge>::default();
+
+        let root = builder.add_vertex();
+        let first = builder.add_vertex();
+
+        let empty_state = atom.empty();
+
+        let initial_nullable = match atom.is_nullable(0) {
+            Ok(nullable) => nullable,
+            Err(_) => return Err(Error::IndexOutOfBounds(0, atom.non_num())),
+        };
+
+        let initial_edge = Edge::new(empty_state, Parent::new(0, 0), initial_nullable);
+
+        builder.add_edge(first, root, initial_edge)?;
+
+        let graph = builder.build();
+
+        let root_label = GrammarLabel::new(GrammarLabelType::TNT(TNT::Non(0)), 0);
+
+        let forest = DefaultForest::new_leaf(root_label);
+
+        #[cfg(debug_assertions)]
+        assert_eq!(forest.root(), Some(0));
+
+        Ok(Self {
+            graph,
+            atom,
+            current: 1,
+            history: Vec::new(),
+            forest,
+            accepting_vec: vec![true, initial_nullable],
+        })
+    }
+
+    /// Return whether the current node is accepting.
+    ///
+    /// # Errors
+    ///
+    /// An index out of bounds error is returned if the current node
+    /// has no entry in `accepting_vec`.
+    fn epsilon(&self) -> Result<bool, Self::Error> {
+        match self.accepting_vec.get(self.current) {
+            Some(&accepting) => Ok(accepting),
+            None => Err(Error::IndexOutOfBounds(
+                self.current,
+                self.accepting_vec.len(),
+            )),
+        }
+    }
+
+    /// Make `new` the current node, and append the previous current
+    /// node to the history.
+    ///
+    /// In debug builds, `new` is asserted to be a node of the graph.
+    fn update_history(&mut self, new: usize) {
+        debug_assert!(new < self.graph.nodes_len());
+
+        let previous = self.current;
+
+        self.current = new;
+
+        self.history.push(previous);
+    }
+
+ type DerIter = DerIter;
+
+ /// Collect the edges obtained by taking the derivative of the
+ /// current node with respect to the terminal `t`.
+ ///
+ /// For every label on the edges out of the current node, the
+ /// labels of the corresponding atom node are examined. Labels
+ /// that are left-linearly expanded are skipped. A terminal
+ /// label equal to `t` contributes single-layered edges. A
+ /// non-terminal label contributes edges only if the atom has a
+ /// virtual node for that non-terminal and `t`.
+ ///
+ /// NOTE(review): `pos` is not read in this body -- confirm
+ /// whether it is reserved for the handling of forests.
+ ///
+ /// # Errors
+ ///
+ /// Errors of the graph and of the atom are propagated, as well
+ /// as failures to reserve memory.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the atom reports an error other than an index out
+ /// of bounds when looking up the virtual node or its accepting
+ /// status.
+ fn derive(&mut self, t: usize, pos: usize) -> Result<Self::DerIter, Self::Error> {
+ use TNT::*;
+
+ /// Convert an error telling us that an index is out of bounds.
+ ///
+ /// # Panics
+ ///
+ /// The function panics if the error is not of the expected
+ /// kind.
+ fn index_out_of_bounds_conversion(ge: GrammarError) -> Error {
+ match ge {
+ GrammarError::IndexOutOfBounds(index, bound) => {
+ Error::IndexOutOfBounds(index, bound)
+ }
+ // NOTE(review): this panics in library code; returning
+ // `Error::Invalid` may be preferable -- confirm.
+ _ => panic!("wrong error kind"),
+ }
+ }
+
+ /// A helper function to generate edges to join.
+ ///
+ /// It first checks if the base edge is accepting. If yes,
+ /// then pull in the children of the target.
+ ///
+ /// Then check if the label of the base edge has children. If
+ /// no, then do not add this base edge itself.
+ ///
+ /// The generated edges will be pushed to `output` directly,
+ /// to save some allocations.
+ // TODO: Handle forests as well.
+ fn generate_edges(
+ chain: &DefaultChain,
+ child_iter: impl Iterator<Item = usize> + ExactSizeIterator + Clone,
+ atom_child_iter: impl Iterator<Item = usize> + Clone,
+ mut output: impl AsMut<Vec<(Edge, usize)>>,
+ ) -> Result<(), Error> {
+ // First check the values from iterators are all valid.
+ let graph_len = chain.graph.nodes_len();
+ let atom_len = chain.atom.nodes_len();
+
+ for child in child_iter.clone() {
+ if !chain.graph.has_node(child) {
+ return Err(Error::IndexOutOfBounds(child, graph_len));
+ }
+ }
+
+ for atom_child in atom_child_iter.clone() {
+ if !chain.atom.has_node(atom_child) {
+ return Err(Error::IndexOutOfBounds(atom_child, atom_len));
+ }
+ }
+
+ // From now on the nodes are all valid, so we can just
+ // call `unwrap`.
+
+ // Then calculate the number of edges to append, to avoid
+ // repeated allocations
+ let mut num = 0usize;
+
+ let child_iter_total_degree = child_iter
+ .clone()
+ .map(|child| chain.graph.degree(child).unwrap())
+ .sum::<usize>();
+
+ for atom_child in atom_child_iter.clone() {
+ let atom_child_accepting = chain.atom.is_accepting(atom_child).unwrap();
+ let atom_child_empty_node = chain.atom.is_empty_node(atom_child).unwrap();
+
+ if !atom_child_empty_node {
+ num += child_iter.len();
+ }
+
+ if atom_child_accepting {
+ num += child_iter_total_degree;
+ }
+ }
+
+ let num = num;
+
+ let output = output.as_mut();
+
+ output.try_reserve(num)?;
+
+ // now push into output
+
+ let parent = Parent::new(0, 0);
+
+ for atom_child in atom_child_iter.clone() {
+ let atom_child_accepting = chain.atom.is_accepting(atom_child).unwrap();
+ let atom_child_empty_node = chain.atom.is_empty_node(atom_child).unwrap();
+
+ let edge = Edge::new(atom_child, parent, atom_child_accepting);
+
+ if !atom_child_empty_node {
+ output.extend(child_iter.clone().map(|child| (edge, child)));
+ }
+
+ if atom_child_accepting {
+ for child in child_iter.clone() {
+ for (child_label, child_child) in chain.graph.labels_of(child).unwrap() {
+ output.extend(child_child.map(|target| (*child_label, target)));
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ let mut der_iter = DerIter::default();
+
+ for (label, child_iter) in self.graph.labels_of(self.current)? {
+ for (atom_label, atom_child_iter) in self.atom.labels_of(label.label())? {
+ if atom_label.is_left_p() {
+ // We do not consider left-linearly expanded
+ // children in the first layer.
+ continue;
+ }
+
+ match *atom_label.get_value() {
+ Some(Ter(ter)) if ter == t => {
+ generate_edges(
+ self,
+ child_iter.clone(),
+ atom_child_iter.clone(),
+ &mut der_iter.singles,
+ )?;
+ }
+ Some(Non(non)) => {
+ let virtual_node = self
+ .atom
+ .atom(non, t)
+ .map_err(index_out_of_bounds_conversion)?;
+
+ if let Some(virtual_node) = virtual_node {
+ let accepting = self
+ .atom
+ .is_accepting(virtual_node)
+ .map_err(index_out_of_bounds_conversion)?;
+
+ let mut new_edges = Vec::new();
+
+ generate_edges(
+ self,
+ child_iter.clone(),
+ atom_child_iter.clone(),
+ &mut new_edges,
+ )?;
+
+ // NOTE(review): `new_edges` is only used again if
+ // the virtual node is not empty, so this clone
+ // could be avoided in the other case.
+ if accepting {
+ der_iter.singles.extend(new_edges.clone());
+ }
+
+ let parent = Parent::new(0, 0);
+
+ // The accepting status of `virtual_node` was
+ // obtained without errors above.
+ if !self.atom.is_empty_node(virtual_node).unwrap() {
+ der_iter.add_second_layer(
+ virtual_node,
+ parent,
+ accepting,
+ new_edges,
+ );
+
+ // account for atom_children without
+ // children.
+
+ for atom_child in atom_child_iter {
+ // this has been checked in
+ // `generate_edges`
+ if self.atom.is_empty_node(atom_child).unwrap() {
+ der_iter.singles.extend(child_iter.clone().map(|child| {
+ (Edge::new(virtual_node, parent, accepting), child)
+ }));
+ }
+ }
+ } else {
+ for atom_child in atom_child_iter {
+ // this has been checked in
+ // `generate_edges`
+ if self.atom.is_empty_node(atom_child).unwrap() {
+ // Pull in the labelled edges of every
+ // child, flattened into one sequence.
+ der_iter.singles.extend(child_iter.clone().flat_map(
+ |child| {
+ self.graph.labels_of(child).unwrap().flat_map(
+ |(child_label, child_child_iter)| {
+ child_child_iter.map(|child_child| {
+ (*child_label, child_child)
+ })
+ },
+ )
+ },
+ ));
+ }
+ }
+ }
+ }
+ }
+ _ => {
+ continue;
+ }
+ }
+ }
+ }
+
+ Ok(der_iter)
+ }
+}
+
+#[cfg(test)]
+mod test_der_iter {
+ use super::*;
+
+ // Check the order of iteration: the two-layered groups come first,
+ // in the order in which their labels were first added, followed by
+ // the single-layered edges.
+ #[test]
+ fn test() -> Result<(), Box<dyn std::error::Error>> {
+ let mut der_iter = DerIter::default();
+
+ let parent = Parent::new(0, 0);
+
+ der_iter.singles.push((Edge::new(0, parent, true), 0));
+
+ der_iter.singles.push((Edge::new(1, parent, true), 0));
+
+ der_iter.singles.push((Edge::new(2, parent, true), 0));
+
+ der_iter.add_second_layer(3, parent, true, vec![(Edge::new(4, parent, true), 1)]);
+
+ der_iter.add_second_layer(6, parent, true, vec![(Edge::new(5, parent, true), 1)]);
+
+ // add an entry with a repeated label
+ der_iter.add_second_layer(3, parent, true, vec![(Edge::new(7, parent, true), 2)]);
+
+ // The edges added under the repeated label 3 are merged into
+ // one group.
+ assert_eq!(
+ der_iter.next(),
+ Some(TwoLayers::Two(
+ 3,
+ parent,
+ true,
+ vec![
+ (Edge::new(4, parent, true), 1),
+ (Edge::new(7, parent, true), 2)
+ ]
+ ))
+ );
+
+ assert_eq!(
+ der_iter.next(),
+ Some(TwoLayers::Two(
+ 6,
+ parent,
+ true,
+ vec![(Edge::new(5, parent, true), 1)]
+ ))
+ );
+
+ assert_eq!(
+ der_iter.next(),
+ Some(TwoLayers::One(Edge::new(0, parent, true), 0))
+ );
+
+ assert_eq!(
+ der_iter.next(),
+ Some(TwoLayers::One(Edge::new(1, parent, true), 0))
+ );
+
+ assert_eq!(
+ der_iter.next(),
+ Some(TwoLayers::One(Edge::new(2, parent, true), 0))
+ );
+
+ // Once exhausted, the iterator keeps returning `None`.
+ assert_eq!(der_iter.next(), None);
+ assert_eq!(der_iter.next(), None);
+
+ Ok(())
}
+}
+
+#[cfg(test)]
+mod test_chain {
+ use super::*;
+ use grammar::test_grammar_helper::*;
+
+ // Feed a fixed sequence of terminals to the chain, one per
+ // position, and check that `epsilon` reports `true` afterwards.
+ #[test]
+ fn base_test() -> Result<(), Box<dyn std::error::Error>> {
+ let grammar = new_notes_grammar()?;
+
+ let atom = DefaultAtom::from_grammar(grammar)?;
+
+ let mut chain = DefaultChain::unit(atom)?;
+
+ // The arguments are the terminal and its position.
+ chain.chain(3, 00)?;
+ chain.chain(1, 01)?;
+ chain.chain(2, 02)?;
+ chain.chain(2, 03)?;
+ chain.chain(2, 04)?;
+ chain.chain(0, 05)?;
+ chain.chain(5, 06)?;
+ chain.chain(1, 07)?;
+ chain.chain(6, 08)?;
+ chain.chain(6, 09)?;
+ chain.chain(6, 10)?;
+ chain.chain(0, 11)?;
+ chain.chain(0, 12)?;
+
+ assert!(matches!(chain.epsilon(), Ok(true)));
+
+ // Only with the feature enabled: write the chain graph and the
+ // automaton of the atom to files for visual inspection.
+ #[cfg(feature = "test-print-viz")]
+ {
+ chain.graph.print_viz("chain.gv")?;
+ chain.atom.print_nfa("nfa.gv")?;
+ }
- fn chain(&mut self, _t: usize) {
- todo!()
+ Ok(())
}
- fn epsilon(&self) -> bool {
- todo!()
+    // Time how long it takes to feed a repeated input to the chain.
+    //
+    // The number of repetitions is the first command line argument
+    // that parses as an unsigned integer, and defaults to 1.
+    #[test]
+    fn test_speed() -> Result<(), Box<dyn std::error::Error>> {
+        let grammar = new_notes_grammar_no_regexp()?;
+
+        println!("grammar: {grammar}");
+
+        let atom = DefaultAtom::from_grammar(grammar)?;
+
+        let mut chain = DefaultChain::unit(atom)?;
+
+        let input_template = [3, 1, 2, 2, 2, 0, 5, 1, 6, 6, 6, 0, 0];
+
+        let repeat_times = {
+            let mut result = 1;
+
+            for arg in std::env::args() {
+                let parse_as_digit: Result<usize, _> = arg.parse();
+
+                if let Ok(parse_result) = parse_as_digit {
+                    result = parse_result;
+
+                    break;
+                }
+            }
+
+            result
+        };
+
+        println!("repeating {repeat_times} times");
+
+        let input = {
+            let mut result = Vec::with_capacity(input_template.len() * repeat_times);
+
+            for _ in 0..repeat_times {
+                result.extend(input_template.iter().copied());
+            }
+
+            result
+        };
+
+        let start = std::time::Instant::now();
+
+        for (index, t) in input.iter().copied().enumerate() {
+            chain.chain(t, index)?;
+        }
+
+        let elapsed = start.elapsed();
+
+        // assert!(matches!(chain.epsilon(), Ok(true)));
+
+        dbg!(elapsed);
+        dbg!(chain.current());
+
+        println!("index: terminal, history");
+
+        // A repeat count of 0 gives an empty input, for which a plain
+        // `input.len() - 1` would underflow.
+        let printed = input.len().saturating_sub(1);
+
+        for (index, t) in input.iter().copied().enumerate().take(printed) {
+            println!("{index}: {t}, {}", chain.history().get(index).unwrap());
+        }
+
+        #[cfg(feature = "test-print-viz")]
+        {
+            chain.graph.print_viz("chain.gv")?;
+            chain.atom.print_nfa("nfa.gv")?;
+        }
+
+        Ok(())
+    }
}