summaryrefslogtreecommitdiff
path: root/graph_macro/src/lib.rs
diff options
context:
space:
mode:
authorJSDurand <mmemmew@gmail.com>2023-05-04 13:02:39 +0800
committerJSDurand <mmemmew@gmail.com>2023-05-04 13:02:39 +0800
commit662817e6367a865a2d86a99581172cc45f585807 (patch)
treee5c7d1a0a52ce9d057d9c27ac4c7549b77198efb /graph_macro/src/lib.rs
parent57d600f261cca5d9076239e548c6e00646f774b6 (diff)
Completed the procedural macro for deriving Graphs.
The macro `graph_derive` can automatically write the boiler-plate codes for wrapper types one of whose sub-fields implements the `Graph` trait. The generated implementation will delegate the `Graph` operations to the sub-field which implements the `Graph` trait. I plan to add more macros, corresponding to various other graph-related traits, so that no such boiler-plate codes are needed, at least for my use-cases.
Diffstat (limited to 'graph_macro/src/lib.rs')
-rw-r--r--graph_macro/src/lib.rs663
1 files changed, 654 insertions, 9 deletions
diff --git a/graph_macro/src/lib.rs b/graph_macro/src/lib.rs
index a55efbb..0f56f57 100644
--- a/graph_macro/src/lib.rs
+++ b/graph_macro/src/lib.rs
@@ -1,19 +1,664 @@
-#![allow(unused_imports)]
//! This file provides a macro to delegate the implementation of a
//! type that wraps a graph. More precisely, the macro helps the
//! wrapper type implement the various Graph-related traits.
+//!
+//! See the macro [`graph_derive`].
-use proc_macro::TokenStream;
+use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
-use core::iter::Peekable;
+use std::{collections::HashMap, iter::Peekable};
-#[proc_macro_derive(Testing)]
-pub fn test_derive(input: TokenStream) -> TokenStream {
- let input = input.into_iter().peekable();
+/// A visibility modifer is either "pub", "pub(restriction)", or
+/// "Private".
+#[derive(Debug, Default)]
+enum VisibilityMod {
+ /// Simple pub visibility.
+ Pub,
+ /// The pub visibility with restrictions.
+ ///
+ /// Since the restrictions can be given in a variety of ways, and
+ /// since it does not matter to us, we just keep them as a
+ /// tokentree.
+ PubRes(TokenTree),
+ #[default]
+ Private,
+}
+
+#[derive(Debug)]
+struct Generic {
+ name: Ident,
+ type_stream: TokenStream,
+}
+
+impl Generic {
+ fn new(name: Ident, type_stream: TokenStream) -> Self {
+ Self { name, type_stream }
+ }
+}
+
+/// The content of the struct
+#[derive(Debug)]
+struct GraphInfo {
+ name: TokenStream,
+ generics: Vec<Generic>,
+ field_name: String,
+ field_type: TokenStream,
+}
+
+impl GraphInfo {
+ fn new(
+ name: TokenStream,
+ generics: Vec<Generic>,
+ field_name: String,
+ field_type: TokenStream,
+ ) -> Self {
+ Self {
+ name,
+ generics,
+ field_name,
+ field_type,
+ }
+ }
+}
+
+/// For custom compiler errors
+#[derive(Debug)]
+struct CompileError {
+ mes: String,
+ span: Span,
+}
+
+impl CompileError {
+ fn new(mes: impl ToString, span: Span) -> Self {
+ Self {
+ mes: mes.to_string(),
+ span,
+ }
+ }
+
+ /// Produce a token stream that will trigger compiler errors.
+ ///
+ /// The point is that the span of the token trees will be set to
+ /// the desired locations, and hence produce the effect we want.
+ fn emit(self) -> TokenStream {
+ // I cannot use the parse method of the type TokenStream, as
+ // that will not set the spans properly.
+
+ let compile_error_ident = TokenTree::Ident(Ident::new("compile_error", self.span.clone()));
+ let mut exclamation_punct = TokenTree::Punct(Punct::new('!', Spacing::Alone));
+
+ exclamation_punct.set_span(self.span.clone());
+
+ let mut arg_mes_literal = TokenTree::Literal(Literal::string(&self.mes));
+
+ arg_mes_literal.set_span(self.span.clone());
+
+ let arg_mes_stream = [arg_mes_literal].into_iter().collect();
+
+ let mut arg_group = TokenTree::Group(Group::new(Delimiter::Parenthesis, arg_mes_stream));
+
+ arg_group.set_span(self.span.clone());
+
+ let mut semi_colon_punct = TokenTree::Punct(Punct::new(';', Spacing::Alone));
+
+ semi_colon_punct.set_span(self.span);
+
+ [
+ compile_error_ident,
+ exclamation_punct,
+ arg_group,
+ semi_colon_punct,
+ ]
+ .into_iter()
+ .collect()
+ }
+}
+
+macro_rules! unwrap_or_emit {
+ ($value:ident) => {
+ if let Err(err) = $value {
+ return err.emit();
+ }
+
+ let $value = $value.unwrap();
+ };
+}
+
+#[allow(unused)]
+macro_rules! mytodo {
+ () => {
+ Err(CompileError::new("not umplemented yet", Span::mixed_site()))
+ };
+}
+
+macro_rules! myerror {
+ ($mes:tt, $input:ident) => {
+ Err(CompileError::new($mes, $input.peek().unwrap().span()))
+ };
+}
+
+#[proc_macro_derive(Graph, attributes(graph))]
+pub fn graph_derive(input: TokenStream) -> TokenStream {
+ let mut input = input.into_iter().peekable();
+
+ while move_attributes(&mut input).is_ok() {}
+
+ let _ = get_visibility_mod(&mut input);
+
+ let info = get_info(&mut input);
+
+ unwrap_or_emit!(info);
+
+ let name = info.name;
+ let generics = info.generics;
+ let field_name = info.field_name;
+ let field_type = info.field_type;
+
+ let generics_impl_string = {
+ let mut result = String::new();
+
+ if !generics.is_empty() {
+ result.push_str("<");
+ }
+
+ for generic in generics.iter() {
+ result.push_str(&format!("{}:{},", generic.name, generic.type_stream));
+ }
+
+ if !generics.is_empty() {
+ result.push_str(">");
+ }
+
+ result
+ };
+
+ let generic_struct_string = {
+ let mut result = String::new();
+
+ if !generics.is_empty() {
+ result.push_str("<");
+ }
+
+ for generic in generics.iter() {
+ result.push_str(&format!("{},", generic.name));
+ }
+
+ if !generics.is_empty() {
+ result.push_str(">");
+ }
+
+ result
+ };
+
+ let generic_where_string = {
+ let mut result = String::new();
+
+ if !generics.is_empty() {
+ result.push_str("\nwhere\n");
+ }
+
+ for generic in generics.iter() {
+ result.push_str(&format!("{}: 'a,\n", generic.name));
+ }
+
+ result
+ };
+
+ let result = format!(
+ "
+#[automatically_derived]
+impl{generics_impl_string} graph::Graph for {name}{generic_struct_string} {{
+type Iter<'a> = <{field_type} as graph::Graph>::Iter<'a>{generic_where_string};
+
+#[inline]
+fn is_empty(&self) -> bool {{self.{field_name}.is_empty()}}
+
+#[inline]
+fn nodes_len(&self) -> usize {{ self.{field_name}.nodes_len() }}
+
+#[inline]
+fn children_of(&self, node: usize) -> Result<<Self as graph::Graph>::Iter<'_>, graph::error::Error> {{
+self.{field_name}.children_of(node)
+}}
+
+#[inline]
+fn degree(&self, node: usize) -> Result<usize, graph::error::Error> {{ self.{field_name}.degree(node) }}
+
+#[inline]
+fn is_empty_node(&self, node: usize) -> Result<bool, graph::error::Error> {{ self.{field_name}.is_empty_node(node) }}
+
+#[inline]
+fn has_edge(&self, nodea: usize, nodeb: usize) -> Result<bool, graph::error::Error> {{
+self.{field_name}.has_edge(nodea, nodeb)
+}}
+
+#[inline]
+fn replace_by_builder(&mut self, _builder: impl graph::builder::Builder<Result = Self>) {{
+unimplemented!()
+}}
+}}"
+ );
+
+ result.parse().unwrap()
+}
+
+fn get_visibility_mod(
+ input: &mut Peekable<impl Iterator<Item = TokenTree>>,
+) -> Result<VisibilityMod, CompileError> {
+ if let Some(TokenTree::Ident(id)) = input.peek() {
+ if id.to_string() == "pub" {
+ input.next();
+
+ if let Some(TokenTree::Group(g)) = input.peek() {
+ let delimiter = g.delimiter();
+
+ if delimiter != Delimiter::Parenthesis {
+ return Err(CompileError::new("invalid bracket", g.span()));
+ }
+
+ let tree = TokenTree::Group(g.clone());
+
+ input.next();
+
+ return Ok(VisibilityMod::PubRes(tree));
+ } else {
+ return Ok(VisibilityMod::Pub);
+ }
+ }
+ }
+
+ Ok(Default::default())
+}
+
+fn get_info(
+ input: &mut Peekable<impl Iterator<Item = TokenTree>>,
+) -> Result<GraphInfo, CompileError> {
+ if let Some(TokenTree::Ident(id)) = input.peek() {
+ if id.to_string() != "struct" {
+ let span = input.next().unwrap().span();
+ return Err(CompileError::new("Only struct is supported", span));
+ }
+
+ input.next();
+
+ let struct_name = get_until(
+ input,
+ |tree| match tree {
+ TokenTree::Group(g)
+ if g.delimiter() == Delimiter::Brace
+ || g.delimiter() == Delimiter::Parenthesis =>
+ {
+ true
+ }
+ TokenTree::Punct(p) if p.as_char() == ';' || p.as_char() == '<' => true,
+ _ => false,
+ },
+ false,
+ false,
+ )?;
+
+ // We are not at the end as guaranteed by `get_until`.
+ let peek = input.peek().unwrap();
+
+ if matches!(peek, TokenTree::Punct(p) if p.as_char() == ';') {
+ // unit struct
+
+ return myerror!("Cannot derive graph for unit structs", input);
+ }
+
+ let mut generics = Vec::new();
+
+ if matches!(peek, TokenTree::Punct(p) if p.as_char() == '<') {
+ // Extract generics
+
+ input.next();
+
+ generics = get_generics(input)?;
+
+ // println!("generics: {generics:?}");
+ }
+
+ let (graph_field_name, graph_field_type) = get_graph_field_name_type(input)?;
+
+ Ok(GraphInfo::new(
+ struct_name,
+ generics,
+ graph_field_name,
+ graph_field_type,
+ ))
+ } else {
+ myerror!("unexpected token", input)
+ }
+}
+
+fn get_graph_field_name_type(
+ input: &mut Peekable<impl Iterator<Item = TokenTree>>,
+) -> Result<(String, TokenStream), CompileError> {
+ let end_of_input = Err(CompileError::new(
+ "unexpected end of input",
+ Span::mixed_site(),
+ ));
+
+ match input.peek() {
+ Some(TokenTree::Group(g)) => {
+ let delimiter = g.delimiter();
+
+ if delimiter != Delimiter::Parenthesis && delimiter != Delimiter::Brace {
+ return myerror!(
+ "expected a group enclosed in curly brackets or parentheses",
+ input
+ );
+ }
+
+ if delimiter == Delimiter::Parenthesis {
+ // a named tuple
+
+ // in this case only a tuple with one field is
+ // supported
+
+ let mut body = g.stream().into_iter().peekable();
+
+ while let Some(tree) = body.peek() {
+ if matches!(tree, TokenTree::Punct(p) if p.as_char() == ',') {
+ return myerror!(
+ "Among named tuples, only those with one field are supported",
+ input
+ );
+ }
+
+ body.next();
+ }
+
+ let mut body = g.stream().into_iter().peekable();
+
+ let _vis = get_visibility_mod(&mut body)?;
+
+ return Ok(("0".to_string(), body.collect()));
+ }
+
+ let mut body = g.stream().into_iter().peekable();
+
+ if body.peek().is_none() {
+ return end_of_input;
+ }
+
+ get_visibility_mod(&mut body)?;
+
+ let mut cloned_body = body.clone();
+
+ let mut result_name = get_first_ident(&mut cloned_body, g.span())?.to_string();
+
+ move_punct(&mut cloned_body, ':')?;
+
+ if cloned_body.peek().is_none() {
+ return end_of_input;
+ }
+
+ let mut result_type = get_until(
+ &mut cloned_body,
+ |tree| matches!(tree, TokenTree::Punct(p) if p.as_char() == ','),
+ true,
+ true,
+ )?;
+
+ let mut result = (result_name, result_type);
+
+ 'search: while let Some(tree) = body.next() {
+ match tree {
+ TokenTree::Punct(p) if p.as_char() == '#' => match body.peek() {
+ Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
+ let mut attr_stream = g.stream().into_iter().peekable();
+
+ let mut found_attribute = false;
+
+ while let Some(TokenTree::Ident(id)) = attr_stream.next() {
+ if id.to_string() == "graph" {
+ found_attribute = true;
+
+ break;
+ }
+
+ let _ = move_punct(&mut attr_stream, ',');
+ }
+
+ body.next();
+
+ if found_attribute {
+ get_visibility_mod(&mut body)?;
+
+ result_name = get_ident(&mut body)?.to_string();
+
+ move_punct(&mut body, ':')?;
+
+ result_type = get_until(
+ &mut body,
+ |tree| matches!(tree, TokenTree::Punct(p) if p.as_char() == ','),
+ true,
+ true,
+ )?;
+
+ result = (result_name, result_type);
+
+ break 'search;
+ }
+ }
+ Some(_) => {
+ return myerror!("expected a group enclosed in square brackets", body);
+ }
+ None => {
+ return end_of_input;
+ }
+ },
+ _ => {}
+ }
+ }
+
+ Ok(result)
+ }
+ Some(_) => myerror!("expected a group or a semicolon", input),
+
+ _ => end_of_input,
+ }
+}
+
+fn get_generics(
+ input: &mut Peekable<impl Iterator<Item = TokenTree>>,
+) -> Result<Vec<Generic>, CompileError> {
+ // Assume the starting '<' punct has already been consumed.
+
+ let mut generic_map: HashMap<String, TokenStream> = Default::default();
+
+ let mut result: Vec<Generic> = Vec::new();
+
+ while !matches!(input.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '>') {
+ let ident = get_ident(input)?.to_string();
+
+ let mut stream = TokenStream::new();
+
+ if matches!(input.peek(), Some(TokenTree::Punct(p)) if p.as_char() == ':') {
+ input.next();
+ stream = get_until(
+ input,
+ |tree| matches!(tree, TokenTree::Punct(p) if p.as_char() == ',' || p.as_char() == '>'),
+ false,
+ false,
+ )?;
+ }
+
+ generic_map.insert(ident, stream);
- for tree in input {
- println!("tree = {tree:?}");
+ let _ = move_punct(input, ',');
}
- TokenStream::new()
+ assert!(matches!(input.next(), Some(TokenTree::Punct(p)) if p.as_char() == '>'));
+
+ if matches!(input.peek(), Some(TokenTree::Ident(id)) if id.to_string() == "where") {
+ // where clauses
+
+ input.next();
+
+ while !matches!(input.peek(), Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace)
+ {
+ let ident = get_ident(input)?.to_string();
+
+ move_punct(input, ':')?;
+
+ let stream = get_until(
+ input,
+ |tree| match tree {
+ TokenTree::Punct(p) if p.as_char() == ',' => true,
+ TokenTree::Group(g) if g.delimiter() == Delimiter::Brace => true,
+ _ => false,
+ },
+ false,
+ false,
+ )?;
+
+ generic_map.insert(ident, stream);
+
+ let _ = move_punct(input, ',');
+ }
+ }
+
+ if generic_map.is_empty() {
+ // this is not valid syntax anyways
+ return myerror!("empty generics section?", input);
+ }
+
+ for (k, v) in generic_map {
+ result.push(Generic::new(Ident::new(k.as_str(), Span::mixed_site()), v));
+ }
+
+ Ok(result)
+}
+
+fn get_ident(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Ident, CompileError> {
+ match input.peek() {
+ Some(TokenTree::Ident(id)) => {
+ let id = id.clone();
+
+ input.next();
+
+ Ok(id)
+ }
+ Some(next_tree) => {
+ let span = next_tree.span();
+
+ Err(CompileError::new("expected an identifier", span))
+ }
+ _ => Err(CompileError::new(
+ "unexpected end of input",
+ Span::mixed_site(),
+ )),
+ }
+}
+
+fn get_first_ident(
+ input: &mut Peekable<impl Iterator<Item = TokenTree>>,
+ span: Span,
+) -> Result<Ident, CompileError> {
+ while let Some(tree) = input.next() {
+ match tree {
+ TokenTree::Ident(id) => {
+ return Ok(id);
+ }
+ _ => {}
+ }
+ }
+
+ Err(CompileError::new("expect at least one identifier", span))
+}
+
+fn get_until(
+ input: &mut Peekable<impl Iterator<Item = TokenTree>>,
+ predicate: impl Fn(&TokenTree) -> bool,
+ accept_end: bool,
+ consume_boundary: bool,
+) -> Result<TokenStream, CompileError> {
+ let mut trees: Vec<TokenTree> = Vec::new();
+
+ let mut found_predicate = false;
+
+ if consume_boundary {
+ while let Some(tree) = input.next() {
+ if predicate(&tree) {
+ found_predicate = true;
+ break;
+ }
+
+ trees.push(tree);
+ }
+ } else {
+ while let Some(tree) = input.peek() {
+ if predicate(tree) {
+ found_predicate = true;
+ break;
+ }
+
+ let tree = input.next().unwrap();
+
+ trees.push(tree);
+ }
+ }
+
+ if found_predicate || accept_end {
+ Ok(trees.into_iter().collect())
+ } else {
+ Err(CompileError::new(
+ "unexpected end of input",
+ Span::mixed_site(),
+ ))
+ }
+}
+
+fn move_punct(
+ input: &mut Peekable<impl Iterator<Item = TokenTree>>,
+ punct: char,
+) -> Result<(), CompileError> {
+ let error_mes = format!("expect punctuation {punct}");
+
+ match input.peek() {
+ Some(TokenTree::Punct(p)) if p.as_char() == punct => {
+ input.next();
+
+ Ok(())
+ }
+ Some(_) => myerror!(error_mes, input),
+ _ => Err(CompileError::new(
+ "unexpected end of input",
+ Span::mixed_site(),
+ )),
+ }
+}
+
+fn move_attributes(
+ input: &mut Peekable<impl Iterator<Item = TokenTree>>,
+) -> Result<(), CompileError> {
+ let error_mes = "expect attributes";
+ let end_of_input = Err(CompileError::new(
+ "unexpected end of input",
+ Span::mixed_site(),
+ ));
+
+ match input.peek() {
+ Some(TokenTree::Punct(p)) if p.as_char() == '#' => {
+ input.next();
+
+ match input.peek() {
+ Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
+ input.next();
+ }
+ Some(_) => {
+ return myerror!("expected a group in square brackets", input);
+ }
+ _ => {
+ return end_of_input;
+ }
+ }
+
+ Ok(())
+ }
+ Some(_) => myerror!(error_mes, input),
+ _ => end_of_input,
+ }
}