//! This file provides a macro to delegate the implementation of a //! type that wraps a graph. More precisely, the macro helps the //! wrapper type implement the various Graph-related traits. //! //! See the macro [`graph_derive`]. use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree}; use std::{collections::HashMap, iter::Peekable}; /// A visibility modifer is either "pub", "pub(restriction)", or /// "Private". #[derive(Debug, Default)] enum VisibilityMod { /// Simple pub visibility. Pub, /// The pub visibility with restrictions. /// /// Since the restrictions can be given in a variety of ways, and /// since it does not matter to us, we just keep them as a /// tokentree. PubRes(TokenTree), #[default] Private, } #[derive(Debug)] struct Generic { name: Ident, type_stream: TokenStream, } impl Generic { fn new(name: Ident, type_stream: TokenStream) -> Self { Self { name, type_stream } } } /// The content of the struct #[derive(Debug)] struct GraphInfo { name: TokenStream, generics: Vec, field_name: String, field_type: TokenStream, } impl GraphInfo { fn new( name: TokenStream, generics: Vec, field_name: String, field_type: TokenStream, ) -> Self { Self { name, generics, field_name, field_type, } } } /// For custom compiler errors #[derive(Debug)] struct CompileError { mes: String, span: Span, } impl CompileError { fn new(mes: impl ToString, span: Span) -> Self { Self { mes: mes.to_string(), span, } } /// Produce a token stream that will trigger compiler errors. /// /// The point is that the span of the token trees will be set to /// the desired locations, and hence produce the effect we want. fn emit(self) -> TokenStream { // I cannot use the parse method of the type TokenStream, as // that will not set the spans properly. let compile_error_ident = TokenTree::Ident(Ident::new("compile_error", self.span.clone())); let mut exclamation_punct = TokenTree::Punct(Punct::new('!', Spacing::Alone)); exclamation_punct.set_span(self.span.clone()); let mut arg_mes_literal = TokenTree::Literal(Literal::string(&self.mes)); arg_mes_literal.set_span(self.span.clone()); let arg_mes_stream = [arg_mes_literal].into_iter().collect(); let mut arg_group = TokenTree::Group(Group::new(Delimiter::Parenthesis, arg_mes_stream)); arg_group.set_span(self.span.clone()); let mut semi_colon_punct = TokenTree::Punct(Punct::new(';', Spacing::Alone)); semi_colon_punct.set_span(self.span); [ compile_error_ident, exclamation_punct, arg_group, semi_colon_punct, ] .into_iter() .collect() } } macro_rules! unwrap_or_emit { ($value:ident) => { if let Err(err) = $value { return err.emit(); } let $value = $value.unwrap(); }; } #[allow(unused)] macro_rules! mytodo { () => { Err(CompileError::new("not umplemented yet", Span::mixed_site())) }; } macro_rules! myerror { ($mes:tt, $input:ident) => { Err(CompileError::new($mes, $input.peek().unwrap().span())) }; } #[proc_macro_derive(Graph, attributes(graph))] pub fn graph_derive(input: TokenStream) -> TokenStream { let mut input = input.into_iter().peekable(); while move_attributes(&mut input).is_ok() {} let _ = get_visibility_mod(&mut input); let info = get_info(&mut input); unwrap_or_emit!(info); let name = info.name; let generics = info.generics; let field_name = info.field_name; let field_type = info.field_type; let generics_impl_string = { let mut result = String::new(); if !generics.is_empty() { result.push_str("<"); } for generic in generics.iter() { result.push_str(&format!("{}:{},", generic.name, generic.type_stream)); } if !generics.is_empty() { result.push_str(">"); } result }; let generic_struct_string = { let mut result = String::new(); if !generics.is_empty() { result.push_str("<"); } for generic in generics.iter() { result.push_str(&format!("{},", generic.name)); } if !generics.is_empty() { result.push_str(">"); } result }; let generic_where_string = { let mut result = String::new(); if !generics.is_empty() { result.push_str("\nwhere\n"); } for generic in generics.iter() { result.push_str(&format!("{}: 'a,\n", generic.name)); } result }; let result = format!( " #[automatically_derived] impl{generics_impl_string} graph::Graph for {name}{generic_struct_string} {{ type Iter<'a> = <{field_type} as graph::Graph>::Iter<'a>{generic_where_string}; #[inline] fn is_empty(&self) -> bool {{self.{field_name}.is_empty()}} #[inline] fn nodes_len(&self) -> usize {{ self.{field_name}.nodes_len() }} #[inline] fn children_of(&self, node: usize) -> Result<::Iter<'_>, graph::error::Error> {{ self.{field_name}.children_of(node) }} #[inline] fn degree(&self, node: usize) -> Result {{ self.{field_name}.degree(node) }} #[inline] fn is_empty_node(&self, node: usize) -> Result {{ self.{field_name}.is_empty_node(node) }} #[inline] fn has_edge(&self, nodea: usize, nodeb: usize) -> Result {{ self.{field_name}.has_edge(nodea, nodeb) }} #[inline] fn replace_by_builder(&mut self, _builder: impl graph::builder::Builder) {{ unimplemented!() }} }}" ); result.parse().unwrap() } fn get_visibility_mod( input: &mut Peekable>, ) -> Result { if let Some(TokenTree::Ident(id)) = input.peek() { if id.to_string() == "pub" { input.next(); if let Some(TokenTree::Group(g)) = input.peek() { let delimiter = g.delimiter(); if delimiter != Delimiter::Parenthesis { return Err(CompileError::new("invalid bracket", g.span())); } let tree = TokenTree::Group(g.clone()); input.next(); return Ok(VisibilityMod::PubRes(tree)); } else { return Ok(VisibilityMod::Pub); } } } Ok(Default::default()) } fn get_info( input: &mut Peekable>, ) -> Result { if let Some(TokenTree::Ident(id)) = input.peek() { if id.to_string() != "struct" { let span = input.next().unwrap().span(); return Err(CompileError::new("Only struct is supported", span)); } input.next(); let struct_name = get_until( input, |tree| match tree { TokenTree::Group(g) if g.delimiter() == Delimiter::Brace || g.delimiter() == Delimiter::Parenthesis => { true } TokenTree::Punct(p) if p.as_char() == ';' || p.as_char() == '<' => true, _ => false, }, false, false, )?; // We are not at the end as guaranteed by `get_until`. let peek = input.peek().unwrap(); if matches!(peek, TokenTree::Punct(p) if p.as_char() == ';') { // unit struct return myerror!("Cannot derive graph for unit structs", input); } let mut generics = Vec::new(); if matches!(peek, TokenTree::Punct(p) if p.as_char() == '<') { // Extract generics input.next(); generics = get_generics(input)?; // println!("generics: {generics:?}"); } let (graph_field_name, graph_field_type) = get_graph_field_name_type(input)?; Ok(GraphInfo::new( struct_name, generics, graph_field_name, graph_field_type, )) } else { myerror!("unexpected token", input) } } fn get_graph_field_name_type( input: &mut Peekable>, ) -> Result<(String, TokenStream), CompileError> { let end_of_input = Err(CompileError::new( "unexpected end of input", Span::mixed_site(), )); match input.peek() { Some(TokenTree::Group(g)) => { let delimiter = g.delimiter(); if delimiter != Delimiter::Parenthesis && delimiter != Delimiter::Brace { return myerror!( "expected a group enclosed in curly brackets or parentheses", input ); } if delimiter == Delimiter::Parenthesis { // a named tuple // in this case only a tuple with one field is // supported let mut body = g.stream().into_iter().peekable(); while let Some(tree) = body.peek() { if matches!(tree, TokenTree::Punct(p) if p.as_char() == ',') { return myerror!( "Among named tuples, only those with one field are supported", input ); } body.next(); } let mut body = g.stream().into_iter().peekable(); let _vis = get_visibility_mod(&mut body)?; return Ok(("0".to_string(), body.collect())); } let mut body = g.stream().into_iter().peekable(); if body.peek().is_none() { return end_of_input; } get_visibility_mod(&mut body)?; let mut cloned_body = body.clone(); let mut result_name = get_first_ident(&mut cloned_body, g.span())?.to_string(); move_punct(&mut cloned_body, ':')?; if cloned_body.peek().is_none() { return end_of_input; } let mut result_type = get_until( &mut cloned_body, |tree| matches!(tree, TokenTree::Punct(p) if p.as_char() == ','), true, true, )?; let mut result = (result_name, result_type); 'search: while let Some(tree) = body.next() { match tree { TokenTree::Punct(p) if p.as_char() == '#' => match body.peek() { Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { let mut attr_stream = g.stream().into_iter().peekable(); let mut found_attribute = false; while let Some(TokenTree::Ident(id)) = attr_stream.next() { if id.to_string() == "graph" { found_attribute = true; break; } let _ = move_punct(&mut attr_stream, ','); } body.next(); if found_attribute { get_visibility_mod(&mut body)?; result_name = get_ident(&mut body)?.to_string(); move_punct(&mut body, ':')?; result_type = get_until( &mut body, |tree| matches!(tree, TokenTree::Punct(p) if p.as_char() == ','), true, true, )?; result = (result_name, result_type); break 'search; } } Some(_) => { return myerror!("expected a group enclosed in square brackets", body); } None => { return end_of_input; } }, _ => {} } } Ok(result) } Some(_) => myerror!("expected a group or a semicolon", input), _ => end_of_input, } } fn get_generics( input: &mut Peekable>, ) -> Result, CompileError> { // Assume the starting '<' punct has already been consumed. let mut generic_map: HashMap = Default::default(); let mut result: Vec = Vec::new(); while !matches!(input.peek(), Some(TokenTree::Punct(p)) if p.as_char() == '>') { let ident = get_ident(input)?.to_string(); let mut stream = TokenStream::new(); if matches!(input.peek(), Some(TokenTree::Punct(p)) if p.as_char() == ':') { input.next(); stream = get_until( input, |tree| matches!(tree, TokenTree::Punct(p) if p.as_char() == ',' || p.as_char() == '>'), false, false, )?; } generic_map.insert(ident, stream); let _ = move_punct(input, ','); } assert!(matches!(input.next(), Some(TokenTree::Punct(p)) if p.as_char() == '>')); if matches!(input.peek(), Some(TokenTree::Ident(id)) if id.to_string() == "where") { // where clauses input.next(); while !matches!(input.peek(), Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace) { let ident = get_ident(input)?.to_string(); move_punct(input, ':')?; let stream = get_until( input, |tree| match tree { TokenTree::Punct(p) if p.as_char() == ',' => true, TokenTree::Group(g) if g.delimiter() == Delimiter::Brace => true, _ => false, }, false, false, )?; generic_map.insert(ident, stream); let _ = move_punct(input, ','); } } if generic_map.is_empty() { // this is not valid syntax anyways return myerror!("empty generics section?", input); } for (k, v) in generic_map { result.push(Generic::new(Ident::new(k.as_str(), Span::mixed_site()), v)); } Ok(result) } fn get_ident(input: &mut Peekable>) -> Result { match input.peek() { Some(TokenTree::Ident(id)) => { let id = id.clone(); input.next(); Ok(id) } Some(next_tree) => { let span = next_tree.span(); Err(CompileError::new("expected an identifier", span)) } _ => Err(CompileError::new( "unexpected end of input", Span::mixed_site(), )), } } fn get_first_ident( input: &mut Peekable>, span: Span, ) -> Result { while let Some(tree) = input.next() { match tree { TokenTree::Ident(id) => { return Ok(id); } _ => {} } } Err(CompileError::new("expect at least one identifier", span)) } fn get_until( input: &mut Peekable>, predicate: impl Fn(&TokenTree) -> bool, accept_end: bool, consume_boundary: bool, ) -> Result { let mut trees: Vec = Vec::new(); let mut found_predicate = false; if consume_boundary { while let Some(tree) = input.next() { if predicate(&tree) { found_predicate = true; break; } trees.push(tree); } } else { while let Some(tree) = input.peek() { if predicate(tree) { found_predicate = true; break; } let tree = input.next().unwrap(); trees.push(tree); } } if found_predicate || accept_end { Ok(trees.into_iter().collect()) } else { Err(CompileError::new( "unexpected end of input", Span::mixed_site(), )) } } fn move_punct( input: &mut Peekable>, punct: char, ) -> Result<(), CompileError> { let error_mes = format!("expect punctuation {punct}"); match input.peek() { Some(TokenTree::Punct(p)) if p.as_char() == punct => { input.next(); Ok(()) } Some(_) => myerror!(error_mes, input), _ => Err(CompileError::new( "unexpected end of input", Span::mixed_site(), )), } } fn move_attributes( input: &mut Peekable>, ) -> Result<(), CompileError> { let error_mes = "expect attributes"; let end_of_input = Err(CompileError::new( "unexpected end of input", Span::mixed_site(), )); match input.peek() { Some(TokenTree::Punct(p)) if p.as_char() == '#' => { input.next(); match input.peek() { Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { input.next(); } Some(_) => { return myerror!("expected a group in square brackets", input); } _ => { return end_of_input; } } Ok(()) } Some(_) => myerror!(error_mes, input), _ => end_of_input, } }