diff options
author | JSDurand <mmemmew@gmail.com> | 2023-05-04 13:02:39 +0800 |
---|---|---|
committer | JSDurand <mmemmew@gmail.com> | 2023-05-04 13:02:39 +0800 |
commit | 662817e6367a865a2d86a99581172cc45f585807 (patch) | |
tree | e5c7d1a0a52ce9d057d9c27ac4c7549b77198efb /graph_macro/src/lib.rs | |
parent | 57d600f261cca5d9076239e548c6e00646f774b6 (diff) |
Completed the procedural macro for deriving Graphs.
The macro `graph_derive` can automatically write the boiler-plate
codes for wrapper types one of whose sub-fields implements the `Graph`
trait. The generated implementation will delegate the `Graph`
operations to the sub-field which implements the `Graph` trait.
I plan to add more macros, corresponding to various other
graph-related traits, so that no such boiler-plate codes are needed,
at least for my use-cases.
Diffstat (limited to 'graph_macro/src/lib.rs')
-rw-r--r-- | graph_macro/src/lib.rs | 663 |
1 files changed, 654 insertions, 9 deletions
diff --git a/graph_macro/src/lib.rs b/graph_macro/src/lib.rs index a55efbb..0f56f57 100644 --- a/graph_macro/src/lib.rs +++ b/graph_macro/src/lib.rs @@ -1,19 +1,664 @@ -#![allow(unused_imports)] //! This file provides a macro to delegate the implementation of a //! type that wraps a graph. More precisely, the macro helps the //! wrapper type implement the various Graph-related traits. +//! +//! See the macro [`graph_derive`]. -use proc_macro::TokenStream; +use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree}; -use core::iter::Peekable; +use std::{collections::HashMap, iter::Peekable}; -#[proc_macro_derive(Testing)] -pub fn test_derive(input: TokenStream) -> TokenStream { - let input = input.into_iter().peekable(); +/// A visibility modifer is either "pub", "pub(restriction)", or +/// "Private". +#[derive(Debug, Default)] +enum VisibilityMod { + /// Simple pub visibility. + Pub, + /// The pub visibility with restrictions. + /// + /// Since the restrictions can be given in a variety of ways, and + /// since it does not matter to us, we just keep them as a + /// tokentree. + PubRes(TokenTree), + #[default] + Private, +} + +#[derive(Debug)] +struct Generic { + name: Ident, + type_stream: TokenStream, +} + +impl Generic { + fn new(name: Ident, type_stream: TokenStream) -> Self { + Self { name, type_stream } + } +} + +/// The content of the struct +#[derive(Debug)] +struct GraphInfo { + name: TokenStream, + generics: Vec<Generic>, + field_name: String, + field_type: TokenStream, +} + +impl GraphInfo { + fn new( + name: TokenStream, + generics: Vec<Generic>, + field_name: String, + field_type: TokenStream, + ) -> Self { + Self { + name, + generics, + field_name, + field_type, + } + } +} + +/// For custom compiler errors +#[derive(Debug)] +struct CompileError { + mes: String, + span: Span, +} + +impl CompileError { + fn new(mes: impl ToString, span: Span) -> Self { + Self { + mes: mes.to_string(), + span, + } + } + + /// Produce a token stream that will trigger compiler errors. + /// + /// The point is that the span of the token trees will be set to + /// the desired locations, and hence produce the effect we want. + fn emit(self) -> TokenStream { + // I cannot use the parse method of the type TokenStream, as + // that will not set the spans properly. + + let compile_error_ident = TokenTree::Ident(Ident::new("compile_error", self.span.clone())); + let mut exclamation_punct = TokenTree::Punct(Punct::new('!', Spacing::Alone)); + + exclamation_punct.set_span(self.span.clone()); + + let mut arg_mes_literal = TokenTree::Literal(Literal::string(&self.mes)); + + arg_mes_literal.set_span(self.span.clone()); + + let arg_mes_stream = [arg_mes_literal].into_iter().collect(); + + let mut arg_group = TokenTree::Group(Group::new(Delimiter::Parenthesis, arg_mes_stream)); + + arg_group.set_span(self.span.clone()); + + let mut semi_colon_punct = TokenTree::Punct(Punct::new(';', Spacing::Alone)); + + semi_colon_punct.set_span(self.span); + + [ + compile_error_ident, + exclamation_punct, + arg_group, + semi_colon_punct, + ] + .into_iter() + .collect() + } +} + +macro_rules! unwrap_or_emit { + ($value:ident) => { + if let Err(err) = $value { + return err.emit(); + } + + let $value = $value.unwrap(); + }; +} + +#[allow(unused)] +macro_rules! mytodo { + () => { + Err(CompileError::new("not umplemented yet", Span::mixed_site())) + }; +} + +macro_rules! myerror { + ($mes:tt, $input:ident) => { + Err(CompileError::new($mes, $input.peek().unwrap().span())) + }; +} + +#[proc_macro_derive(Graph, attributes(graph))] +pub fn graph_derive(input: TokenStream) -> TokenStream { + let mut input = input.into_iter().peekable(); + + while move_attributes(&mut input).is_ok() {} + + let _ = get_visibility_mod(&mut input); + + let info = get_info(&mut input); + + unwrap_or_emit!(info); + + let name = info.name; + let generics = info.generics; + let field_name = info.field_name; + let field_type = info.field_type; + + let generics_impl_string = { + let mut result = String::new(); + + if !generics.is_empty() { + result.push_str("<"); + } + + for generic in generics.iter() { + result.push_str(&format!("{}:{},", generic.name, generic.type_stream)); + } + + if !generics.is_empty() { + result.push_str(">"); + } + + result + }; + + let generic_struct_string = { + let mut result = String::new(); + + if !generics.is_empty() { + result.push_str("<"); + } + + for generic in generics.iter() { + result.push_str(&format!("{},", generic.name)); + } + + if !generics.is_empty() { + result.push_str(">"); + } + + result + }; + + let generic_where_string = { + let mut result = String::new(); + + if !generics.is_empty() { + result.push_str("\nwhere\n"); + } + + for generic in generics.iter() { + result.push_str(&format!("{}: 'a,\n", generic.name)); + } + + result + }; + + let result = format!( + " +#[automatically_derived] +impl{generics_impl_string} graph::Graph for {name}{generic_struct_string} {{ +type Iter<'a> = <{field_type} as graph::Graph>::Iter<'a>{generic_where_string}; + +#[inline] +fn is_empty(&self) -> bool {{self.{field_name}.is_empty()}} + +#[inline] +fn nodes_len(&self) -> usize {{ self.{field_name}.nodes_len() }} + +#[inline] +fn children_of(&self, node: usize) -> Result<<Self as graph::Graph>::Iter<'_>, graph::error::Error> {{ +self.{field_name}.children_of(node) +}} + +#[inline] +fn degree(&self, node: usize) -> Result<usize, graph::error::Error> {{ self.{field_name}.degree(node) }} + +#[inline] +fn is_empty_node(&self, node: usize) -> Result<bool, graph::error::Error> {{ self.{field_name}.is_empty_node(node) }} + +#[inline] +fn has_edge(&self, nodea: usize, nodeb: usize) -> Result<bool, graph::error::Error> {{ +self.{field_name}.has_edge(nodea, nodeb) +}} + +#[inline] +fn replace_by_builder(&mut self, _builder: impl graph::builder::Builder<Result = Self>) {{ +unimplemented!() +}} +}}" + ); + + result.parse().unwrap() +} + +fn get_visibility_mod( + input: &mut Peekable<impl Iterator<Item = TokenTree>>, +) -> Result<VisibilityMod, CompileError> { + if let Some(TokenTree::Ident(id)) = input.peek() { + if id.to_string() == "pub" { + input.next(); + + if let Some(TokenTree::Group(g)) = input.peek() { + let delimiter = g.delimiter(); + + if delimiter != Delimiter::Parenthesis { + return Err(CompileError::new("invalid bracket", g.span())); + } + + let tree = TokenTree::Group(g.clone()); + + input.next(); + + return Ok(VisibilityMod::PubRes(tree)); + } else { + return Ok(VisibilityMod::Pub); + } + } + } + + Ok(Default::default()) +} + +fn get_info( + input: &mut Peekable<impl Iterator<Item = TokenTree>>, +) -> Result<GraphInfo, CompileError> { + if let Some(TokenTree::Ident(id)) = input.peek() { + if id.to_string() != "struct" { + let span = input.next().unwrap().span(); + return Err(CompileError::new("Only struct is supported", span)); + } + + input.next(); + + let struct_name = get_until( + input, + |tree| match tree { + TokenTree::Group(g) + if g.delimiter() == Delimiter::Brace + || g.delimiter() == Delimiter::Parenthesis => + { + true + } + TokenTree::Punct(p) if p.as_char() == ';' || p.as_char() == '<' => true, + _ => false, + }, + false, + false, + )?; + + // We are not at the end as guaranteed by `get_until`. + let peek = input.peek().unwrap(); + + if matches!(peek, TokenTree::Punct(p) if p.as_char() == ';') { + // unit struct + + return myerror!("Cannot derive graph for unit structs", input); + } + + let mut generics = Vec::new(); + + if matches!(peek, TokenTree::Punct(p) if p.as_char() == '<') { + // Extract generics + + input.next(); + + generics = get_generics(input)?; + + // println!("generics: {generics:?}"); + } + + let (graph_field_name, graph_field_type) = get_graph_field_name_type(input)?; + + Ok(GraphInfo::new( + struct_name, + generics, + graph_field_name, + graph_field_type, + )) + } else { + myerror!("unexpected token", input) + } +} + +fn get_graph_field_name_type( + input: &mut Peekable<impl Iterator<Item = TokenTree>>, +) -> Result<(String, TokenStream), CompileError> { + let end_of_input = Err(CompileError::new( + "unexpected end of input", + Span::mixed_site(), + )); + + match input.peek() { + Some(TokenTree::Group(g)) => { + let delimiter = g.delimiter(); + + if delimiter != Delimiter::Parenthesis && delimiter != Delimiter::Brace { + return myerror!( + "expected a group enclosed in curly brackets or parentheses", + input + ); + } + + if delimiter == Delimiter::Parenthesis { + // a named tuple + + // in this case only a tuple with one field is + // supported + + let mut body = g.stream().into_iter().peekable(); + + while let Some(tree) = body.peek() { + if matches!(tree, TokenTree::Punct(p) if p.as_char() == ',') { + return myerror!( + "Among named tuples, only those with one field are supported", + input + ); + } + + body.next(); + } + + let mut body = g.stream().into_iter().peekable(); + + let _vis = get_visibility_mod(&mut body)?; + + return Ok(("0".to_string(), body.collect())); + } + + let mut body = g.stream().into_iter().peekable(); + + if body.peek().is_none() { + return end_of_input; + } + + get_visibility_mod(&mut body)?; + + let mut cloned_body = body.clone(); + + let mut result_name = get_first_ident(&mut cloned_body, g.span())?.to_string(); + + move_punct(&mut cloned_body, ':')?; + + if cloned_body.peek().is_none() { + return end_of_input; + } + + let mut result_type = get_until( + &mut cloned_body, + |tree| matches!(tree, TokenTree::Punct(p) if p.as_char() == ','), + true, + true, + )?; + + let mut result = (result_name, result_type); + + 'search: while let Some(tree) = body.next() { + match tree { + TokenTree::Punct(p) if p.as_char() == '#' => match body.peek() { + Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { + let mut attr_stream = g.stream().into_iter().peekable(); + + let mut found_attribute = false; + + while let Some(TokenTree::Ident(id)) = attr_stream.next() { + if id.to_string() == "graph" { + found_attribute = true; + + break; + } + + let _ = move_punct(&mut attr_stream, ','); + } + + body.next(); + + if found_attribute { + get_visibility_mod(&mut body)?; + + result_name = get_ident(&mut body)?.to_string(); + + move_punct(&mut body, ':')?; + + result_type = get_until( + &mut body, + |tree| matches!(tree, TokenTree::Punct(p) if p.as_char() == ','), + true, + true, + )?; + + result = (result_name, result_type); + + break 'search; + } + } + Some(_) => { + return myerror!("expected a group enclosed in square brackets", body); + } + None => { + return end_of_input; + } + }, + _ => {} + } + } + + Ok(result) + } + Some(_) => myerror!("expected a group or a semicolon", input), + + _ => end_of_input, + } +} + +fn get_generics( + input: &mut Peekable<impl Iterator<Item = TokenTree>>, +) -> Result<Vec<Generic>, CompileError> { + // Assume the starting '<' punct has already been consumed. + + let mut generic_map: HashMap<String, TokenStream> = Default::default(); + + let mut result: Vec<Generic> = Vec::new(); + + while !matches!(input.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '>') { + let ident = get_ident(input)?.to_string(); + + let mut stream = TokenStream::new(); + + if matches!(input.peek(), Some(TokenTree::Punct(p)) if p.as_char() == ':') { + input.next(); + stream = get_until( + input, + |tree| matches!(tree, TokenTree::Punct(p) if p.as_char() == ',' || p.as_char() == '>'), + false, + false, + )?; + } + + generic_map.insert(ident, stream); - for tree in input { - println!("tree = {tree:?}"); + let _ = move_punct(input, ','); } - TokenStream::new() + assert!(matches!(input.next(), Some(TokenTree::Punct(p)) if p.as_char() == '>')); + + if matches!(input.peek(), Some(TokenTree::Ident(id)) if id.to_string() == "where") { + // where clauses + + input.next(); + + while !matches!(input.peek(), Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace) + { + let ident = get_ident(input)?.to_string(); + + move_punct(input, ':')?; + + let stream = get_until( + input, + |tree| match tree { + TokenTree::Punct(p) if p.as_char() == ',' => true, + TokenTree::Group(g) if g.delimiter() == Delimiter::Brace => true, + _ => false, + }, + false, + false, + )?; + + generic_map.insert(ident, stream); + + let _ = move_punct(input, ','); + } + } + + if generic_map.is_empty() { + // this is not valid syntax anyways + return myerror!("empty generics section?", input); + } + + for (k, v) in generic_map { + result.push(Generic::new(Ident::new(k.as_str(), Span::mixed_site()), v)); + } + + Ok(result) +} + +fn get_ident(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Ident, CompileError> { + match input.peek() { + Some(TokenTree::Ident(id)) => { + let id = id.clone(); + + input.next(); + + Ok(id) + } + Some(next_tree) => { + let span = next_tree.span(); + + Err(CompileError::new("expected an identifier", span)) + } + _ => Err(CompileError::new( + "unexpected end of input", + Span::mixed_site(), + )), + } +} + +fn get_first_ident( + input: &mut Peekable<impl Iterator<Item = TokenTree>>, + span: Span, +) -> Result<Ident, CompileError> { + while let Some(tree) = input.next() { + match tree { + TokenTree::Ident(id) => { + return Ok(id); + } + _ => {} + } + } + + Err(CompileError::new("expect at least one identifier", span)) +} + +fn get_until( + input: &mut Peekable<impl Iterator<Item = TokenTree>>, + predicate: impl Fn(&TokenTree) -> bool, + accept_end: bool, + consume_boundary: bool, +) -> Result<TokenStream, CompileError> { + let mut trees: Vec<TokenTree> = Vec::new(); + + let mut found_predicate = false; + + if consume_boundary { + while let Some(tree) = input.next() { + if predicate(&tree) { + found_predicate = true; + break; + } + + trees.push(tree); + } + } else { + while let Some(tree) = input.peek() { + if predicate(tree) { + found_predicate = true; + break; + } + + let tree = input.next().unwrap(); + + trees.push(tree); + } + } + + if found_predicate || accept_end { + Ok(trees.into_iter().collect()) + } else { + Err(CompileError::new( + "unexpected end of input", + Span::mixed_site(), + )) + } +} + +fn move_punct( + input: &mut Peekable<impl Iterator<Item = TokenTree>>, + punct: char, +) -> Result<(), CompileError> { + let error_mes = format!("expect punctuation {punct}"); + + match input.peek() { + Some(TokenTree::Punct(p)) if p.as_char() == punct => { + input.next(); + + Ok(()) + } + Some(_) => myerror!(error_mes, input), + _ => Err(CompileError::new( + "unexpected end of input", + Span::mixed_site(), + )), + } +} + +fn move_attributes( + input: &mut Peekable<impl Iterator<Item = TokenTree>>, +) -> Result<(), CompileError> { + let error_mes = "expect attributes"; + let end_of_input = Err(CompileError::new( + "unexpected end of input", + Span::mixed_site(), + )); + + match input.peek() { + Some(TokenTree::Punct(p)) if p.as_char() == '#' => { + input.next(); + + match input.peek() { + Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { + input.next(); + } + Some(_) => { + return myerror!("expected a group in square brackets", input); + } + _ => { + return end_of_input; + } + } + + Ok(()) + } + Some(_) => myerror!(error_mes, input), + _ => end_of_input, + } } |