summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorLaurenz <laurmaedje@gmail.com>2021-03-03 17:53:40 +0100
committerLaurenz <laurmaedje@gmail.com>2021-03-03 17:53:40 +0100
commitc94a18833f23d2b57de1b87971458fd54b56d088 (patch)
tree9e1ed55cfca15aef6d39ced50a3a5b14d2800aae /src
parent4d90a066f197264341eff6bf67e8c06cae434eb4 (diff)
Closures and function definitions 🚀
Supports: - Closure syntax: `(x, y) => z` - Shorthand for a single argument: `x => y` - Function syntax: `let f(x) = y` - Capturing of variables from the environment - Error messages for too few / many passed arguments Does not support: - Named arguments - Variadic arguments with `..`
Diffstat (limited to 'src')
-rw-r--r--src/eval/capture.rs43
-rw-r--r--src/eval/mod.rs45
-rw-r--r--src/eval/scope.rs17
-rw-r--r--src/eval/value.rs10
-rw-r--r--src/library/mod.rs2
-rw-r--r--src/parse/mod.rs92
-rw-r--r--src/pretty.rs31
-rw-r--r--src/syntax/expr.rs20
-rw-r--r--src/syntax/visit.rs53
9 files changed, 231 insertions, 82 deletions
diff --git a/src/eval/capture.rs b/src/eval/capture.rs
index 163aa24e..05760594 100644
--- a/src/eval/capture.rs
+++ b/src/eval/capture.rs
@@ -25,16 +25,11 @@ impl<'a> CapturesVisitor<'a> {
pub fn finish(self) -> Scope {
self.captures
}
-
- /// Define an internal variable.
- fn define(&mut self, ident: &Ident) {
- self.internal.def_mut(ident.as_str(), Value::None);
- }
}
impl<'ast> Visit<'ast> for CapturesVisitor<'_> {
- fn visit_expr(&mut self, item: &'ast Expr) {
- match item {
+ fn visit_expr(&mut self, node: &'ast Expr) {
+ match node {
Expr::Ident(ident) => {
// Find out whether the identifier is not locally defined, but
// captured, and if so, replace it with its value.
@@ -48,37 +43,15 @@ impl<'ast> Visit<'ast> for CapturesVisitor<'_> {
}
}
- fn visit_block(&mut self, item: &'ast ExprBlock) {
- // Blocks create a scope except if directly in a template.
- if item.scoping {
- self.internal.push();
- }
- visit_block(self, item);
- if item.scoping {
- self.internal.pop();
- }
- }
-
- fn visit_template(&mut self, item: &'ast ExprTemplate) {
- // Templates always create a scope.
- self.internal.push();
- visit_template(self, item);
- self.internal.pop();
+ fn visit_binding(&mut self, id: &'ast Ident) {
+ self.internal.def_mut(id.as_str(), Value::None);
}
- fn visit_let(&mut self, item: &'ast ExprLet) {
- self.define(&item.binding);
- visit_let(self, item);
+ fn visit_enter(&mut self) {
+ self.internal.enter();
}
- fn visit_for(&mut self, item: &'ast ExprFor) {
- match &item.pattern {
- ForPattern::Value(value) => self.define(value),
- ForPattern::KeyValue(key, value) => {
- self.define(key);
- self.define(value);
- }
- }
- visit_for(self, item);
+ fn visit_exit(&mut self) {
+ self.internal.exit();
}
}
diff --git a/src/eval/mod.rs b/src/eval/mod.rs
index c66f2ad2..f30ee7a7 100644
--- a/src/eval/mod.rs
+++ b/src/eval/mod.rs
@@ -114,6 +114,7 @@ impl Eval for Expr {
Self::Group(v) => v.eval(ctx),
Self::Block(v) => v.eval(ctx),
Self::Call(v) => v.eval(ctx),
+ Self::Closure(v) => v.eval(ctx),
Self::Unary(v) => v.eval(ctx),
Self::Binary(v) => v.eval(ctx),
Self::Let(v) => v.eval(ctx),
@@ -184,7 +185,7 @@ impl Eval for ExprBlock {
fn eval(&self, ctx: &mut EvalContext) -> Self::Output {
if self.scoping {
- ctx.scopes.push();
+ ctx.scopes.enter();
}
let mut output = Value::None;
@@ -193,7 +194,7 @@ impl Eval for ExprBlock {
}
if self.scoping {
- ctx.scopes.pop();
+ ctx.scopes.exit();
}
output
@@ -386,6 +387,40 @@ impl Eval for ExprArg {
}
}
+impl Eval for ExprClosure {
+ type Output = Value;
+
+ fn eval(&self, ctx: &mut EvalContext) -> Self::Output {
+ let params = Rc::clone(&self.params);
+ let body = Rc::clone(&self.body);
+
+ // Collect the captured variables.
+ let captured = {
+ let mut visitor = CapturesVisitor::new(&ctx.scopes);
+ visitor.visit_closure(self);
+ visitor.finish()
+ };
+
+ Value::Func(ValueFunc::new(None, move |ctx, args| {
+ // Don't leak the scopes from the call site. Instead, we use the
+ // scope of captured variables we collected earlier.
+ let prev = std::mem::take(&mut ctx.scopes);
+ ctx.scopes.top = captured.clone();
+
+ for param in params.iter() {
+ // Set the parameter to `none` if the argument is missing.
+ let value =
+ args.require::<Value>(ctx, param.as_str()).unwrap_or_default();
+ ctx.scopes.def_mut(param.as_str(), value);
+ }
+
+ let value = body.eval(ctx);
+ ctx.scopes = prev;
+ value
+ }))
+ }
+}
+
impl Eval for ExprLet {
type Output = Value;
@@ -464,7 +499,7 @@ impl Eval for ExprFor {
macro_rules! iter {
(for ($($binding:ident => $value:ident),*) in $iter:expr) => {{
let mut output = vec![];
- ctx.scopes.push();
+ ctx.scopes.enter();
#[allow(unused_parens)]
for ($($value),*) in $iter {
@@ -474,14 +509,14 @@ impl Eval for ExprFor {
Value::Template(v) => output.extend(v),
Value::Str(v) => output.push(TemplateNode::Str(v)),
Value::Error => {
- ctx.scopes.pop();
+ ctx.scopes.exit();
return Value::Error;
}
_ => {}
}
}
- ctx.scopes.pop();
+ ctx.scopes.exit();
Value::Template(output)
}};
}
diff --git a/src/eval/scope.rs b/src/eval/scope.rs
index 0991564f..c0926c0c 100644
--- a/src/eval/scope.rs
+++ b/src/eval/scope.rs
@@ -13,11 +13,11 @@ pub type Slot = Rc<RefCell<Value>>;
#[derive(Debug, Default, Clone, PartialEq)]
pub struct Scopes<'a> {
/// The active scope.
- top: Scope,
+ pub top: Scope,
/// The stack of lower scopes.
- scopes: Vec<Scope>,
+ pub scopes: Vec<Scope>,
/// The base scope.
- base: Option<&'a Scope>,
+ pub base: Option<&'a Scope>,
}
impl<'a> Scopes<'a> {
@@ -39,16 +39,16 @@ impl<'a> Scopes<'a> {
}
}
- /// Push a new scope.
- pub fn push(&mut self) {
+ /// Enter a new scope.
+ pub fn enter(&mut self) {
self.scopes.push(std::mem::take(&mut self.top));
}
- /// Pop the topmost scope.
+ /// Exit the topmost scope.
///
/// # Panics
- /// Panics if no scope was pushed.
- pub fn pop(&mut self) {
+ /// Panics if no scope was entered.
+ pub fn exit(&mut self) {
self.top = self.scopes.pop().expect("no pushed scope");
}
@@ -74,6 +74,7 @@ impl<'a> Scopes<'a> {
/// A map from variable names to variable slots.
#[derive(Default, Clone, PartialEq)]
pub struct Scope {
+ /// The mapping from names to slots.
values: HashMap<String, Slot>,
}
diff --git a/src/eval/value.rs b/src/eval/value.rs
index d910155a..7f31ea13 100644
--- a/src/eval/value.rs
+++ b/src/eval/value.rs
@@ -172,22 +172,22 @@ impl Debug for TemplateFunc {
/// A wrapper around a reference-counted executable function.
#[derive(Clone)]
pub struct ValueFunc {
- name: String,
+ name: Option<String>,
f: Rc<dyn Fn(&mut EvalContext, &mut ValueArgs) -> Value>,
}
impl ValueFunc {
/// Create a new function value from a rust function or closure.
- pub fn new<F>(name: impl Into<String>, f: F) -> Self
+ pub fn new<F>(name: Option<String>, f: F) -> Self
where
F: Fn(&mut EvalContext, &mut ValueArgs) -> Value + 'static,
{
- Self { name: name.into(), f: Rc::new(f) }
+ Self { name, f: Rc::new(f) }
}
/// The name of the function.
- pub fn name(&self) -> &str {
- &self.name
+ pub fn name(&self) -> Option<&str> {
+ self.name.as_deref()
}
}
diff --git a/src/library/mod.rs b/src/library/mod.rs
index 59198846..d34b338c 100644
--- a/src/library/mod.rs
+++ b/src/library/mod.rs
@@ -23,7 +23,7 @@ pub fn new() -> Scope {
let mut std = Scope::new();
macro_rules! set {
(func: $name:expr, $func:expr) => {
- std.def_const($name, ValueFunc::new($name, $func))
+ std.def_const($name, ValueFunc::new(Some($name.into()), $func))
};
(any: $var:expr, $any:expr) => {
std.def_const($var, ValueAny::new($any))
diff --git a/src/parse/mod.rs b/src/parse/mod.rs
index 327a99f3..29801527 100644
--- a/src/parse/mod.rs
+++ b/src/parse/mod.rs
@@ -173,21 +173,37 @@ fn primary(p: &mut Parser) -> Option<Expr> {
}
match p.peek() {
- // Function or identifier.
+ // Things that start with an identifier.
Some(Token::Ident(string)) => {
let ident = Ident {
span: p.eat_span(),
string: string.into(),
};
- match p.peek_direct() {
- Some(Token::LeftParen) | Some(Token::LeftBracket) => Some(call(p, ident)),
- _ => Some(Expr::Ident(ident)),
+ // Parenthesis or bracket means this is a function call.
+ if matches!(
+ p.peek_direct(),
+ Some(Token::LeftParen) | Some(Token::LeftBracket),
+ ) {
+ return Some(call(p, ident));
}
+
+ // Arrow means this is closure's lone parameter.
+ if p.eat_if(Token::Arrow) {
+ return expr(p).map(|body| {
+ Expr::Closure(ExprClosure {
+ span: ident.span.join(body.span()),
+ params: Rc::new(vec![ident]),
+ body: Rc::new(body),
+ })
+ });
+ }
+
+ Some(Expr::Ident(ident))
}
// Structures.
- Some(Token::LeftParen) => Some(parenthesized(p)),
+ Some(Token::LeftParen) => parenthesized(p),
Some(Token::LeftBracket) => Some(template(p)),
Some(Token::LeftBrace) => Some(block(p, true)),
@@ -228,23 +244,36 @@ fn literal(p: &mut Parser) -> Option<Expr> {
Some(Expr::Lit(Lit { span: p.eat_span(), kind }))
}
-/// Parse a parenthesized expression, which can be either of:
+/// Parse something that starts with a parenthesis, which can be either of:
/// - Array literal
/// - Dictionary literal
/// - Parenthesized expression
-pub fn parenthesized(p: &mut Parser) -> Expr {
+/// - Parameter list of closure expression
+pub fn parenthesized(p: &mut Parser) -> Option<Expr> {
p.start_group(Group::Paren, TokenMode::Code);
let colon = p.eat_if(Token::Colon);
let (items, has_comma) = collection(p);
let span = p.end_group();
+ // Leading colon makes this a dictionary.
if colon {
- // Leading colon makes this a dictionary.
- return dict(p, items, span);
+ return Some(dict(p, items, span));
+ }
+
+ // Arrow means this is closure's parameter list.
+ if p.eat_if(Token::Arrow) {
+ let params = params(p, items);
+ return expr(p).map(|body| {
+ Expr::Closure(ExprClosure {
+ span: span.join(body.span()),
+ params: Rc::new(params),
+ body: Rc::new(body),
+ })
+ });
}
// Find out which kind of collection this is.
- match items.as_slice() {
+ Some(match items.as_slice() {
[] => array(p, items, span),
[ExprArg::Pos(_)] if !has_comma => match items.into_iter().next() {
Some(ExprArg::Pos(expr)) => {
@@ -254,7 +283,7 @@ pub fn parenthesized(p: &mut Parser) -> Expr {
},
[ExprArg::Pos(_), ..] => array(p, items, span),
[ExprArg::Named(_), ..] => dict(p, items, span),
- }
+ })
}
/// Parse a collection.
@@ -331,6 +360,19 @@ fn dict(p: &mut Parser, items: Vec<ExprArg>, span: Span) -> Expr {
Expr::Dict(ExprDict { span, items: items.collect() })
}
+/// Convert a collection into a parameter list, producing errors for anything
+/// other than identifiers.
+fn params(p: &mut Parser, items: Vec<ExprArg>) -> Vec<Ident> {
+ let items = items.into_iter().filter_map(|item| match item {
+ ExprArg::Pos(Expr::Ident(id)) => Some(id),
+ _ => {
+ p.diag(error!(item.span(), "expected identifier"));
+ None
+ }
+ });
+ items.collect()
+}
+
// Parse a template value: `[...]`.
fn template(p: &mut Parser) -> Expr {
p.start_group(Group::Bracket, TokenMode::Markup);
@@ -340,7 +382,7 @@ fn template(p: &mut Parser) -> Expr {
}
/// Parse a block expression: `{...}`.
-fn block(p: &mut Parser, scopes: bool) -> Expr {
+fn block(p: &mut Parser, scoping: bool) -> Expr {
p.start_group(Group::Brace, TokenMode::Code);
let mut exprs = vec![];
while !p.eof() {
@@ -355,7 +397,7 @@ fn block(p: &mut Parser, scopes: bool) -> Expr {
p.skip_white();
}
let span = p.end_group();
- Expr::Block(ExprBlock { span, exprs, scoping: scopes })
+ Expr::Block(ExprBlock { span, exprs, scoping })
}
/// Parse an expression.
@@ -445,16 +487,38 @@ fn expr_let(p: &mut Parser) -> Option<Expr> {
let mut expr_let = None;
if let Some(binding) = ident(p) {
+ // If a parenthesis follows, this is a function definition.
+ let mut parameters = None;
+ if p.peek_direct() == Some(Token::LeftParen) {
+ p.start_group(Group::Paren, TokenMode::Code);
+ let items = collection(p).0;
+ parameters = Some(params(p, items));
+ p.end_group();
+ }
+
let mut init = None;
if p.eat_if(Token::Eq) {
init = expr(p);
+ } else if parameters.is_some() {
+ // Function definitions must have a body.
+ p.expected_at("body", p.end());
+ }
+
+ // Rewrite into a closure expression if it's a function definition.
+ if let Some(params) = parameters {
+ let body = init?;
+ init = Some(Expr::Closure(ExprClosure {
+ span: binding.span.join(body.span()),
+ params: Rc::new(params),
+ body: Rc::new(body),
+ }));
}
expr_let = Some(Expr::Let(ExprLet {
span: p.span(start),
binding,
init: init.map(Box::new),
- }))
+ }));
}
expr_let
diff --git a/src/pretty.rs b/src/pretty.rs
index 86919ac8..3f420548 100644
--- a/src/pretty.rs
+++ b/src/pretty.rs
@@ -219,6 +219,7 @@ impl Pretty for Expr {
Self::Unary(v) => v.pretty(p),
Self::Binary(v) => v.pretty(p),
Self::Call(v) => v.pretty(p),
+ Self::Closure(v) => v.pretty(p),
Self::Let(v) => v.pretty(p),
Self::If(v) => v.pretty(p),
Self::While(v) => v.pretty(p),
@@ -383,6 +384,15 @@ impl Pretty for ExprArg {
}
}
+impl Pretty for ExprClosure {
+ fn pretty(&self, p: &mut Printer) {
+ p.push('(');
+ p.join(self.params.iter(), ", ", |item, p| item.pretty(p));
+ p.push_str(") => ");
+ self.body.pretty(p);
+ }
+}
+
impl Pretty for ExprLet {
fn pretty(&self, p: &mut Printer) {
p.push_str("let ");
@@ -529,8 +539,11 @@ impl Pretty for TemplateFunc {
impl Pretty for ValueFunc {
fn pretty(&self, p: &mut Printer) {
- p.push_str("<function ");
- p.push_str(self.name());
+ p.push_str("<function");
+ if let Some(name) = self.name() {
+ p.push(' ');
+ p.push_str(name);
+ }
p.push('>');
}
}
@@ -720,8 +733,12 @@ mod tests {
roundtrip("#v(1, 2)[*Ok*]");
roundtrip("#v(1, f[2])");
+ // Closures.
+ roundtrip("{(a, b) => a + b}");
+
// Keywords.
roundtrip("#let x = 1 + 2");
+ test_parse("#let f(x) = y", "#let f = (x) => y");
test_parse("#if x [y] #else [z]", "#if x [y] else [z]");
roundtrip("#while x {y}");
roundtrip("#for x in y {z}");
@@ -777,8 +794,14 @@ mod tests {
"[*<node example>]",
);
- // Function and arguments.
- test_value(ValueFunc::new("nil", |_, _| Value::None), "<function nil>");
+ // Function.
+ test_value(ValueFunc::new(None, |_, _| Value::None), "<function>");
+ test_value(
+ ValueFunc::new(Some("nil".into()), |_, _| Value::None),
+ "<function nil>",
+ );
+
+ // Arguments.
test_value(
ValueArgs {
span: Span::ZERO,
diff --git a/src/syntax/expr.rs b/src/syntax/expr.rs
index 638d9dd3..d76ada69 100644
--- a/src/syntax/expr.rs
+++ b/src/syntax/expr.rs
@@ -25,8 +25,10 @@ pub enum Expr {
Unary(ExprUnary),
/// A binary operation: `a + b`.
Binary(ExprBinary),
- /// An invocation of a function: `foo(...)`.
+ /// An invocation of a function: `f(x, y)`.
Call(ExprCall),
+ /// A closure expression: `(x, y) => { z }`.
+ Closure(ExprClosure),
/// A let expression: `let x = 1`.
Let(ExprLet),
/// An if expression: `if x { y } else { z }`.
@@ -51,6 +53,7 @@ impl Expr {
Self::Unary(v) => v.span,
Self::Binary(v) => v.span,
Self::Call(v) => v.span,
+ Self::Closure(v) => v.span,
Self::Let(v) => v.span,
Self::If(v) => v.span,
Self::While(v) => v.span,
@@ -58,7 +61,7 @@ impl Expr {
}
}
- /// Whether the expression can be shorten in markup with a hashtag.
+ /// Whether the expression can be shortened in markup with a hashtag.
pub fn has_short_form(&self) -> bool {
matches!(self,
Expr::Ident(_)
@@ -411,6 +414,17 @@ impl ExprArg {
}
}
+/// A closure expression: `(x, y) => { z }`.
+#[derive(Debug, Clone, PartialEq)]
+pub struct ExprClosure {
+ /// The source code location.
+ pub span: Span,
+ /// The parameter bindings.
+ pub params: Rc<Vec<Ident>>,
+ /// The body of the closure.
+ pub body: Rc<Expr>,
+}
+
/// A let expression: `let x = 1`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExprLet {
@@ -418,7 +432,7 @@ pub struct ExprLet {
pub span: Span,
/// The binding to assign to.
pub binding: Ident,
- /// The expression the pattern is initialized with.
+ /// The expression the binding is initialized with.
pub init: Option<Box<Expr>>,
}
diff --git a/src/syntax/visit.rs b/src/syntax/visit.rs
index 1bf260c7..15613233 100644
--- a/src/syntax/visit.rs
+++ b/src/syntax/visit.rs
@@ -3,27 +3,42 @@
use super::*;
macro_rules! visit {
- ($(fn $name:ident($v:ident, $node:ident: &$ty:ty) $body:block)*) => {
+ ($(fn $name:ident($v:ident $(, $node:ident: &$ty:ty)?) $body:block)*) => {
/// Traverses the syntax tree.
pub trait Visit<'ast> {
- $(fn $name(&mut self, $node: &'ast $ty) {
- $name(self, $node);
+ $(fn $name(&mut self $(, $node: &'ast $ty)?) {
+ $name(self, $($node)?);
})*
+
+ /// Visit a definition of a binding.
+ ///
+ /// Bindings are, for example, left-hand side of let expressions,
+ /// and key/value patterns in for loops.
+ fn visit_binding(&mut self, _: &'ast Ident) {}
+
+ /// Visit the entry into a scope.
+ fn visit_enter(&mut self) {}
+
+ /// Visit the exit from a scope.
+ fn visit_exit(&mut self) {}
}
$(visit! {
- @concat!("Walk a node of type [`", stringify!($ty), "`]."),
- pub fn $name<'ast, V>($v: &mut V, $node: &'ast $ty)
+ @$(concat!("Walk a node of type [`", stringify!($ty), "`]."), )?
+ pub fn $name<'ast, V>(
+ #[allow(unused)] $v: &mut V
+ $(, #[allow(unused)] $node: &'ast $ty)?
+ )
where
V: Visit<'ast> + ?Sized
$body
})*
};
+
(@$doc:expr, $($tts:tt)*) => {
#[doc = $doc]
$($tts)*
- }
-
+ };
}
visit! {
@@ -59,6 +74,7 @@ visit! {
Expr::Unary(e) => v.visit_unary(e),
Expr::Binary(e) => v.visit_binary(e),
Expr::Call(e) => v.visit_call(e),
+ Expr::Closure(e) => v.visit_closure(e),
Expr::Let(e) => v.visit_let(e),
Expr::If(e) => v.visit_if(e),
Expr::While(e) => v.visit_while(e),
@@ -79,7 +95,9 @@ visit! {
}
fn visit_template(v, node: &ExprTemplate) {
+ v.visit_enter();
v.visit_tree(&node.tree);
+ v.visit_exit();
}
fn visit_group(v, node: &ExprGroup) {
@@ -87,9 +105,15 @@ visit! {
}
fn visit_block(v, node: &ExprBlock) {
+ if node.scoping {
+ v.visit_enter();
+ }
for expr in &node.exprs {
v.visit_expr(&expr);
}
+ if node.scoping {
+ v.visit_exit();
+ }
}
fn visit_binary(v, node: &ExprBinary) {
@@ -106,6 +130,13 @@ visit! {
v.visit_args(&node.args);
}
+ fn visit_closure(v, node: &ExprClosure) {
+ for param in node.params.iter() {
+ v.visit_binding(param);
+ }
+ v.visit_expr(&node.body);
+ }
+
fn visit_args(v, node: &ExprArgs) {
for arg in &node.items {
v.visit_arg(arg);
@@ -120,6 +151,7 @@ visit! {
}
fn visit_let(v, node: &ExprLet) {
+ v.visit_binding(&node.binding);
if let Some(init) = &node.init {
v.visit_expr(&init);
}
@@ -139,6 +171,13 @@ visit! {
}
fn visit_for(v, node: &ExprFor) {
+ match &node.pattern {
+ ForPattern::Value(value) => v.visit_binding(value),
+ ForPattern::KeyValue(key, value) => {
+ v.visit_binding(key);
+ v.visit_binding(value);
+ }
+ }
v.visit_expr(&node.iter);
v.visit_expr(&node.body);
}