diff options
Diffstat (limited to 'src/syntax/ast.rs')
| -rw-r--r-- | src/syntax/ast.rs | 147 |
1 files changed, 110 insertions, 37 deletions
diff --git a/src/syntax/ast.rs b/src/syntax/ast.rs index 780c6164..94114958 100644 --- a/src/syntax/ast.rs +++ b/src/syntax/ast.rs @@ -1533,7 +1533,10 @@ impl Closure { /// /// This only exists if you use the function syntax sugar: `let f(x) = y`. pub fn name(&self) -> Option<Ident> { - self.0.children().next()?.cast() + match self.0.cast_first_match::<Pattern>()?.kind() { + PatternKind::Ident(ident) => Some(ident), + _ => Option::None, + } } /// The parameter bindings. @@ -1590,28 +1593,121 @@ impl AstNode for Param { } node! { + /// A destructuring pattern: `x` or `(x, _, ..y)`. + Pattern +} + +/// The kind of a pattern. +#[derive(Debug, Clone, Hash)] +pub enum PatternKind { + /// A single identifier: `x`. + Ident(Ident), + /// A destructuring pattern: `(x, _, ..y)`. + Destructure(Vec<DestructuringKind>), +} + +/// The kind of an element in a destructuring pattern. +#[derive(Debug, Clone, Hash)] +pub enum DestructuringKind { + /// An identifier: `x`. + Ident(Ident), + /// An argument sink: `..y`. + Sink(Option<Ident>), + /// Named arguments: `x: 1`. + Named(Ident, Ident), +} + +impl Pattern { + /// The kind of the pattern. + pub fn kind(&self) -> PatternKind { + if self.0.children().len() <= 1 { + return PatternKind::Ident(self.0.cast_first_match().unwrap_or_default()); + } + + let mut bindings = Vec::new(); + for child in self.0.children() { + match child.kind() { + SyntaxKind::Ident => { + bindings + .push(DestructuringKind::Ident(child.cast().unwrap_or_default())); + } + SyntaxKind::Spread => { + bindings.push(DestructuringKind::Sink(child.cast_first_match())); + } + SyntaxKind::Named => { + let mut filtered = child.children().filter_map(SyntaxNode::cast); + let key = filtered.next().unwrap_or_default(); + let ident = filtered.next().unwrap_or_default(); + bindings.push(DestructuringKind::Named(key, ident)); + } + _ => (), + } + } + + PatternKind::Destructure(bindings) + } + + // Returns a list of all identifiers in the pattern. + pub fn idents(&self) -> Vec<Ident> { + match self.kind() { + PatternKind::Ident(ident) => vec![ident], + PatternKind::Destructure(bindings) => bindings + .into_iter() + .filter_map(|binding| match binding { + DestructuringKind::Ident(ident) => Some(ident), + DestructuringKind::Sink(ident) => ident, + DestructuringKind::Named(_, ident) => Some(ident), + }) + .collect(), + } + } +} + +node! { /// A let binding: `let x = 1`. LetBinding } +pub enum LetBindingKind { + /// A normal binding: `let x = 1`. + Normal(Pattern), + /// A closure binding: `let f(x) = 1`. + Closure(Ident), +} + +impl LetBindingKind { + // Returns a list of all identifiers in the pattern. + pub fn idents(&self) -> Vec<Ident> { + match self { + LetBindingKind::Normal(pattern) => pattern.idents(), + LetBindingKind::Closure(ident) => { + vec![ident.clone()] + } + } + } +} + impl LetBinding { - /// The binding to assign to. - pub fn binding(&self) -> Ident { - match self.0.cast_first_match() { - Some(Expr::Ident(binding)) => binding, - Some(Expr::Closure(closure)) => closure.name().unwrap_or_default(), - _ => Ident::default(), + /// The kind of the let binding. + pub fn kind(&self) -> LetBindingKind { + if let Some(pattern) = self.0.cast_first_match::<Pattern>() { + LetBindingKind::Normal(pattern) + } else { + LetBindingKind::Closure( + self.0 + .cast_first_match::<Closure>() + .unwrap_or_default() + .name() + .unwrap_or_default(), + ) } } /// The expression the binding is initialized with. pub fn init(&self) -> Option<Expr> { - if self.0.cast_first_match::<Ident>().is_some() { - // This is a normal binding like `let x = 1`. - self.0.children().filter_map(SyntaxNode::cast).nth(1) - } else { - // This is a closure binding like `let f(x) = 1`. - self.0.cast_first_match() + match self.kind() { + LetBindingKind::Normal(_) => self.0.cast_last_match(), + LetBindingKind::Closure(_) => self.0.cast_first_match(), } } } @@ -1712,7 +1808,7 @@ node! { impl ForLoop { /// The pattern to assign to. - pub fn pattern(&self) -> ForPattern { + pub fn pattern(&self) -> Pattern { self.0.cast_first_match().unwrap_or_default() } @@ -1728,29 +1824,6 @@ impl ForLoop { } node! { - /// A for loop's destructuring pattern: `x` or `x, y`. - ForPattern -} - -impl ForPattern { - /// The key part of the pattern: index for arrays, name for dictionaries. - pub fn key(&self) -> Option<Ident> { - let mut children = self.0.children().filter_map(SyntaxNode::cast); - let key = children.next(); - if children.next().is_some() { - key - } else { - Option::None - } - } - - /// The value part of the pattern. - pub fn value(&self) -> Ident { - self.0.cast_last_match().unwrap_or_default() - } -} - -node! { /// A module import: `import "utils.typ": a, b, c`. ModuleImport } |
