//! This file provides a default implementation of the //! [`Chain`][crate::Chain] trait. //! //! The reason for using a trait is that I might want to experiment //! with different implementation ideas in the future, and this //! modular design makes that easy. use super::*; use crate::atom::{Atom, DefaultAtom}; use core::fmt::Display; use forest::{default::DefaultForest, Forest}; use grammar::{Error as GrammarError, GrammarLabel, GrammarLabelType, TNT}; #[allow(unused_imports)] use graph::{ labelled::DLGBuilder, Builder, DLGraph, Graph, LabelExtGraph, LabelGraph, ParentsGraph, }; use std::collections::{HashMap as Map, TryReserveError}; /// The errors related to taking derivatives by chain rule. #[non_exhaustive] #[derive(Debug)] pub enum Error { /// General error for indices out of bounds. IndexOutOfBounds(usize, usize), /// The forest encounters a duplicate node, for some reason. DuplicateNode(usize), /// The chain rule machine encounters a duplicate edge, for some /// reason. DuplicateEdge(usize, usize), /// A node has no labels while it is required to have one. NodeNoLabel(usize), /// Reserving memory fails. CannotReserve(TryReserveError), /// An invalid situation happens. Invalid, } impl Display for Error { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { Self::IndexOutOfBounds(index, bound) => write!(f, "index {index} out of bound {bound}"), Self::DuplicateNode(n) => write!(f, "the forest has a node {n} with a duplicate label"), Self::DuplicateEdge(source, target) => write!( f, "the forest has a duplicate edge from {source} to {target}" ), Self::NodeNoLabel(n) => write!(f, "node {n} has no labels while it should have one"), Self::CannotReserve(e) => write!(f, "cannot reserve memory: {e}"), Self::Invalid => write!(f, "invalid"), } } } impl std::error::Error for Error {} impl From for Error { fn from(value: GError) -> Self { match value { GError::IndexOutOfBounds(index, bound) => Self::IndexOutOfBounds(index, bound), GError::DuplicatedNode(n) => Self::DuplicateNode(n), GError::DuplicatedEdge(source, target) => Self::DuplicateEdge(source, target), _ => Self::Invalid, } } } impl From for Error { fn from(e: ForestError) -> Self { match e { ForestError::IndexOutOfBounds(index, bound) => Error::IndexOutOfBounds(index, bound), ForestError::DuplicatedNode(n) => Error::DuplicateNode(n), ForestError::InvalidGraphError(ge) => ge.into(), ForestError::NodeNoLabel(n) => Error::NodeNoLabel(n), } } } impl From for Error { fn from(value: TryReserveError) -> Self { Self::CannotReserve(value) } } /// The type of an index into an element in [`DerIter`]. #[derive(Debug, Copy, Clone)] enum DerIterIndex { Single(usize), Map(usize), } impl Default for DerIterIndex { fn default() -> Self { Self::Map(0) } } /// A complex type used for storing values of edges with two layers. type SecondTypeValue = (Parent, bool, Vec<(Edge, usize)>); /// An iterator of TwoLayers. #[derive(Debug, Default)] pub struct DerIter { /// Stores edges of only one layer. singles: Vec<(Edge, usize)>, /// Stores edges with two layers. They are grouped by their /// labels of the second layer. /// /// The values are tuples (forest_source, accepting, edges), where /// the edges are the grouped edges of the first layer and the /// destination. seconds: Map, /// We want to iterate the elements of the map, for which purpose /// we need an array. Since hashmaps provide no arrays, we keep /// an array of keys for iteration purposes. second_array: Vec, /// The index of the current element, either in `second_array` or /// in `singles` . index: DerIterIndex, } impl DerIter { fn add_second_layer( &mut self, label: usize, forest_source: Parent, accepting: bool, edges: Vec<(Edge, usize)>, ) { if let Some((_, _, vec)) = self.seconds.get_mut(&label) { vec.extend(edges); } else { self.seconds .insert(label, (forest_source, accepting, edges)); self.second_array.push(label); } } } impl Iterator for DerIter { type Item = TwoLayers; fn next(&mut self) -> Option { // We iterate through two layered edges first. match self.index { DerIterIndex::Map(index) => { if let Some(key) = self.second_array.get(index) { if let Some((forest_source, accepting, edges)) = self.seconds.remove(key) { self.index = DerIterIndex::Map(index + 1); Some(TwoLayers::Two(*key, forest_source, accepting, edges)) } else { // this should not happen println!("a key does not exist in the hashmap: something is wrong when taking derivatives"); None } } else { self.index = DerIterIndex::Single(0); if let Some((edge, to)) = self.singles.first() { self.index = DerIterIndex::Single(1); Some(TwoLayers::One(*edge, *to)) } else { None } } } DerIterIndex::Single(index) => { if let Some((edge, to)) = self.singles.get(index) { self.index = DerIterIndex::Single(index + 1); Some(TwoLayers::One(*edge, *to)) } else { None } } } } } /// A default implementation for the [`Chain`] trait. #[derive(Debug, Clone, Default)] pub struct DefaultChain { graph: DLGraph, atom: DefaultAtom, current: usize, history: Vec, forest: DefaultForest, accepting_vec: Vec, } impl DefaultChain { /// Return the current node. #[inline] pub fn current(&self) -> usize { self.current } /// Return the complete slice of histories. #[inline] pub fn history(&self) -> &[usize] { self.history.as_ref() } /// Return a reference to the associated forest. #[inline] pub fn forest(&self) -> &DefaultForest { &self.forest } /// Print the rule positions of the labels. pub fn print_rule_positions(&self) -> Result<(), Box> { let mut labels = std::collections::HashSet::::default(); for node in 0..self.graph.nodes_len() { labels.extend(self.graph.labels_of(node)?.map(|(label, _)| label.label)); } for label in labels.into_iter() { println!("{}", self.atom.rule_pos_string(label)?); } Ok(()) } } impl Graph for DefaultChain { type Iter<'a> = as Graph>::Iter<'a> where Self: 'a; #[inline] fn is_empty(&self) -> bool { self.graph.is_empty() } #[inline] fn nodes_len(&self) -> usize { self.graph.nodes_len() } #[inline] fn edges_len(&self) -> Option { self.graph.edges_len() } #[inline] fn children_of(&self, node_id: usize) -> Result, GError> { self.graph.children_of(node_id) } #[inline] fn degree(&self, node_id: usize) -> Result { self.graph.degree(node_id) } #[inline] fn is_empty_node(&self, node_id: usize) -> Result { self.graph.is_empty_node(node_id) } #[inline] fn has_edge(&self, source: usize, target: usize) -> Result { self.graph.has_edge(source, target) } fn replace_by_builder(&mut self, _builder: impl graph::Builder) { unimplemented!("I shall refactor this") } } impl LabelGraph for DefaultChain { type Iter<'a> = as LabelGraph>::Iter<'a> where Self: 'a; type LabelIter<'a> = as LabelGraph>::LabelIter<'a> where Self: 'a, Edge: 'a; type EdgeLabelIter<'a> = as LabelGraph>::EdgeLabelIter<'a> where Self: 'a, Edge: 'a; #[inline] fn edge_label(&self, source: usize, target: usize) -> Result, GError> { self.graph.edge_label(source, target) } #[inline] fn find_children_with_label( &self, node_id: usize, label: &Edge, ) -> Result<>::Iter<'_>, GError> { self.graph.find_children_with_label(node_id, label) } #[inline] fn labels_of(&self, node_id: usize) -> Result, GError> { self.graph.labels_of(node_id) } #[inline] fn has_edge_label(&self, node_id: usize, label: &Edge, target: usize) -> Result { self.graph.has_edge_label(node_id, label, target) } } impl LabelExtGraph for DefaultChain { #[inline] fn extend(&mut self, edges: impl IntoIterator) -> Result { let new = self.graph.extend(edges)?; let accepting_len = self.accepting_vec.len(); if self.accepting_vec.get(new).is_none() { // assert it can only grow by one node at a time. #[cfg(debug_assertions)] assert_eq!(new, accepting_len); let mut updated = false; for (label, child_iter) in self.graph.labels_of(new)? { let old_accepting = { let mut result = false; for child in child_iter { if *self .accepting_vec .get(child) .ok_or(GError::IndexOutOfBounds(child, accepting_len))? { result = true; break; } } result }; if !old_accepting { self.accepting_vec.push(false); updated = true; break; } if label.is_accepting() { self.accepting_vec.push(true); updated = true; break; } } if !updated { self.accepting_vec.push(false); } } Ok(new) } } impl Chain for DefaultChain { type Error = Error; type Atom = DefaultAtom; fn unit(atom: Self::Atom) -> Result { let mut builder: DLGBuilder = Default::default(); let root = builder.add_vertex(); let first = builder.add_vertex(); let empty_state = atom.empty(); let initial_nullable = atom .is_nullable(0) .map_err(|_| Error::IndexOutOfBounds(0, atom.non_num()))?; builder.add_edge( first, root, Edge::new(empty_state, Parent::new(0, 0), initial_nullable), )?; let graph = builder.build(); let forest = DefaultForest::new_leaf(GrammarLabel::new(GrammarLabelType::TNT(TNT::Non(0)), 0)); #[cfg(debug_assertions)] assert_eq!(forest.root(), Some(0)); let current = 1; let history = Vec::new(); let accepting_vec = vec![true, initial_nullable]; Ok(Self { graph, atom, current, history, forest, accepting_vec, }) } fn epsilon(&self) -> Result { self.accepting_vec .get(self.current) .copied() .ok_or(Error::IndexOutOfBounds( self.current, self.accepting_vec.len(), )) } fn update_history(&mut self, new: usize) { debug_assert!(new < self.graph.nodes_len()); self.history.push(self.current); self.current = new; } type DerIter = DerIter; fn derive(&mut self, t: usize, _pos: usize) -> Result { use TNT::*; /// Convert an error telling us that an index is out of bounds. /// /// # Panics /// /// The function panics if the error is not of the expected /// kind. fn index_out_of_bounds_conversion(ge: GrammarError) -> Error { match ge { GrammarError::IndexOutOfBounds(index, bound) => { Error::IndexOutOfBounds(index, bound) } _ => panic!("wrong error kind"), } } /// A helper function to generate edges to join. /// /// It first checks if the base edge is accepting. If yes, /// then pull in the children of the target. /// /// Then check if the label of the base edge has children. If /// no, then do not add this base edge itself. /// /// The generated edges will be pushed to `output` directly, /// to save some allocations. // TODO: Handle forests as well. fn generate_edges( chain: &DefaultChain, child_iter: impl Iterator + ExactSizeIterator + Clone, atom_child_iter: impl Iterator + Clone, mut output: impl AsMut>, ) -> Result<(), Error> { // First check the values from iterators are all valid. let graph_len = chain.graph.nodes_len(); let atom_len = chain.atom.nodes_len(); for child in child_iter.clone() { if !chain.graph.has_node(child) { return Err(Error::IndexOutOfBounds(child, graph_len)); } } for atom_child in atom_child_iter.clone() { if !chain.atom.has_node(atom_child) { return Err(Error::IndexOutOfBounds(atom_child, atom_len)); } } // From now on the nodes are all valid, so we can just // call `unwrap`. // Then calculate the number of edges to append, to avoid // repeated allocations let mut num = 0usize; let child_iter_total_degree = child_iter .clone() .map(|child| chain.graph.degree(child).unwrap()) .sum::(); for atom_child in atom_child_iter.clone() { let atom_child_accepting = chain.atom.is_accepting(atom_child).unwrap(); let atom_child_empty_node = chain.atom.is_empty_node(atom_child).unwrap(); if !atom_child_empty_node { num += child_iter.len(); } if atom_child_accepting { num += child_iter_total_degree; } } let num = num; let output = output.as_mut(); output.try_reserve(num)?; // now push into output let parent = Parent::new(0, 0); for atom_child in atom_child_iter { let atom_child_accepting = chain.atom.is_accepting(atom_child).unwrap(); let atom_child_empty_node = chain.atom.is_empty_node(atom_child).unwrap(); let edge = Edge::new(atom_child, parent, atom_child_accepting); if !atom_child_empty_node { output.extend(child_iter.clone().map(|child| (edge, child))); } if atom_child_accepting { for child in child_iter.clone() { for (child_label, child_child) in chain.graph.labels_of(child).unwrap() { output.extend(child_child.map(|target| (*child_label, target))); } } } } Ok(()) } let mut der_iter = DerIter::default(); for (label, child_iter) in self.graph.labels_of(self.current)? { for (atom_label, atom_child_iter) in self.atom.labels_of(label.label())? { if atom_label.is_left_p() { // We do not consider left-linearly expanded // children in the first layer. continue; } match *atom_label.get_value() { Some(Ter(ter)) if ter == t => { generate_edges( self, child_iter.clone(), atom_child_iter.clone(), &mut der_iter.singles, )?; } Some(Non(non)) => { let virtual_node = self .atom .atom(non, t) .map_err(index_out_of_bounds_conversion)?; if let Some(virtual_node) = virtual_node { let accepting = self .atom .is_accepting(virtual_node) .map_err(index_out_of_bounds_conversion)?; let mut new_edges = Vec::new(); generate_edges( self, child_iter.clone(), atom_child_iter.clone(), &mut new_edges, )?; if accepting { der_iter.singles.extend(new_edges.clone()); } let parent = Parent::new(0, 0); if !self.atom.is_empty_node(virtual_node).unwrap() { der_iter.add_second_layer( virtual_node, parent, accepting, new_edges, ); // account for atom_children without // children. for atom_child in atom_child_iter { // this has been checked in // `generate_edges` if self.atom.is_empty_node(atom_child).unwrap() { der_iter.singles.extend(child_iter.clone().map(|child| { (Edge::new(virtual_node, parent, accepting), child) })); } } } else { for atom_child in atom_child_iter { // this has been checked in // `generate_edges` if self.atom.is_empty_node(atom_child).unwrap() { // flat flat map, hmm... der_iter.singles.extend(child_iter.clone().flat_map( |child| { self.graph.labels_of(child).unwrap().flat_map( |(child_label, child_child_iter)| { child_child_iter.map(|child_child| { (*child_label, child_child) }) }, ) }, )); } } } } } _ => { continue; } } } } Ok(der_iter) } } #[cfg(test)] mod test_der_iter { use super::*; #[test] fn test() -> Result<(), Box> { let mut der_iter = DerIter::default(); let parent = Parent::new(0, 0); der_iter.singles.push((Edge::new(0, parent, true), 0)); der_iter.singles.push((Edge::new(1, parent, true), 0)); der_iter.singles.push((Edge::new(2, parent, true), 0)); der_iter.add_second_layer(3, parent, true, vec![(Edge::new(4, parent, true), 1)]); der_iter.add_second_layer(6, parent, true, vec![(Edge::new(5, parent, true), 1)]); // add an entry with a repeated label der_iter.add_second_layer(3, parent, true, vec![(Edge::new(7, parent, true), 2)]); assert_eq!( der_iter.next(), Some(TwoLayers::Two( 3, parent, true, vec![ (Edge::new(4, parent, true), 1), (Edge::new(7, parent, true), 2) ] )) ); assert_eq!( der_iter.next(), Some(TwoLayers::Two( 6, parent, true, vec![(Edge::new(5, parent, true), 1)] )) ); assert_eq!( der_iter.next(), Some(TwoLayers::One(Edge::new(0, parent, true), 0)) ); assert_eq!( der_iter.next(), Some(TwoLayers::One(Edge::new(1, parent, true), 0)) ); assert_eq!( der_iter.next(), Some(TwoLayers::One(Edge::new(2, parent, true), 0)) ); assert_eq!(der_iter.next(), None); assert_eq!(der_iter.next(), None); Ok(()) } } #[cfg(test)] mod test_chain { use super::*; use grammar::test_grammar_helper::*; #[test] fn base_test() -> Result<(), Box> { let grammar = new_notes_grammar()?; let atom = DefaultAtom::from_grammar(grammar)?; let mut chain = DefaultChain::unit(atom)?; chain.chain(3, 00)?; chain.chain(1, 01)?; chain.chain(2, 02)?; chain.chain(2, 03)?; chain.chain(2, 04)?; chain.chain(0, 05)?; chain.chain(5, 06)?; chain.chain(1, 07)?; chain.chain(6, 08)?; chain.chain(6, 09)?; chain.chain(6, 10)?; chain.chain(0, 11)?; chain.chain(0, 12)?; assert!(matches!(chain.epsilon(), Ok(true))); #[cfg(feature = "test-print-viz")] { chain.graph.print_viz("chain.gv")?; chain.atom.print_nfa("nfa.gv")?; } Ok(()) } #[test] fn test_speed() -> Result<(), Box> { let grammar = new_notes_grammar_no_regexp()?; println!("grammar: {grammar}"); let atom = DefaultAtom::from_grammar(grammar)?; let mut chain = DefaultChain::unit(atom)?; let input_template = vec![3, 1, 2, 2, 2, 0, 5, 1, 6, 6, 6, 0, 0]; let repeat_times = { let mut result = 1; for arg in std::env::args() { let parse_as_digit: Result = arg.parse(); // just use the first number in the arguments if let Ok(parse_result) = parse_as_digit { result = parse_result; break; } } result }; println!("repeating {repeat_times} times"); let input = { let mut result = Vec::with_capacity(input_template.len() * repeat_times); for _ in 0..repeat_times { result.extend(input_template.iter().copied()); } result }; let start = std::time::Instant::now(); for (index, t) in input.iter().copied().enumerate() { chain.chain(t, index)?; } let elapsed = start.elapsed(); // assert!(matches!(chain.epsilon(), Ok(true))); dbg!(elapsed); dbg!(chain.current()); assert_eq!(input.len(), chain.history().len()); if std::fs::metadata("output/history").is_ok() { std::fs::remove_file("output/history")?; } let mut history_file = std::fs::OpenOptions::new() .create(true) .write(true) .open("output/history")?; use std::fmt::Write; use std::io::Write as IOWrite; let mut log_string = String::new(); writeln!(&mut log_string, "index: terminal, history")?; for (index, t) in input.iter().copied().enumerate().take(input.len() - 1) { writeln!( &mut log_string, "{index}: {t}, {}", chain.history().get(index).unwrap() )?; } println!("Successfully logged to output/history"); history_file.write_all(log_string.as_bytes())?; #[cfg(feature = "test-print-viz")] { chain.graph.print_viz("chain.gv")?; chain.atom.print_nfa("nfa.gv")?; } Ok(()) } }