summaryrefslogtreecommitdiff
path: root/src/eval/func.rs
diff options
context:
space:
mode:
authorLaurenz <laurmaedje@gmail.com>2022-01-31 17:57:20 +0100
committerLaurenz <laurmaedje@gmail.com>2022-02-01 12:26:13 +0100
commit6a6753cb69f7c29e857fd465eecf66a02ff76aa3 (patch)
treee157752f30f5c493ee045d98039cfd3a94cdff22 /src/eval/func.rs
parent20b1a38414101f842a6d9201133a5aaaa45a7cec (diff)
Better function representation
Diffstat (limited to 'src/eval/func.rs')
-rw-r--r--src/eval/func.rs304
1 files changed, 304 insertions, 0 deletions
diff --git a/src/eval/func.rs b/src/eval/func.rs
new file mode 100644
index 00000000..ccd0932f
--- /dev/null
+++ b/src/eval/func.rs
@@ -0,0 +1,304 @@
+use std::fmt::{self, Debug, Formatter, Write};
+use std::sync::Arc;
+
+use super::{Cast, Eval, EvalContext, Scope, Value};
+use crate::diag::{At, TypResult};
+use crate::syntax::ast::Expr;
+use crate::syntax::{Span, Spanned};
+use crate::util::EcoString;
+
+/// An evaluatable function.
+#[derive(Clone)]
+pub struct Func(Arc<Repr>);
+
+/// The different kinds of function representations.
+enum Repr {
+ /// A native rust function.
+ Native(Native),
+ /// A user-defined closure.
+ Closure(Closure),
+ /// A nested function with pre-applied arguments.
+ With(Func, Args),
+}
+
+impl Func {
+ /// Create a new function from a native rust function.
+ pub fn native(
+ name: &'static str,
+ func: fn(&mut EvalContext, &mut Args) -> TypResult<Value>,
+ ) -> Self {
+ Self(Arc::new(Repr::Native(Native { name, func })))
+ }
+
+ /// Create a new function from a closure.
+ pub fn closure(closure: Closure) -> Self {
+ Self(Arc::new(Repr::Closure(closure)))
+ }
+
+ /// The name of the function.
+ pub fn name(&self) -> Option<&str> {
+ match self.0.as_ref() {
+ Repr::Native(native) => Some(native.name),
+ Repr::Closure(closure) => closure.name.as_deref(),
+ Repr::With(func, _) => func.name(),
+ }
+ }
+
+ /// Call the function in the context with the arguments.
+ pub fn call(&self, ctx: &mut EvalContext, args: &mut Args) -> TypResult<Value> {
+ match self.0.as_ref() {
+ Repr::Native(native) => (native.func)(ctx, args),
+ Repr::Closure(closure) => closure.call(ctx, args),
+ Repr::With(wrapped, applied) => {
+ args.items.splice(.. 0, applied.items.iter().cloned());
+ wrapped.call(ctx, args)
+ }
+ }
+ }
+
+ /// Apply the given arguments to the function.
+ pub fn with(self, args: Args) -> Self {
+ Self(Arc::new(Repr::With(self, args)))
+ }
+}
+
+impl Debug for Func {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ f.write_str("<function")?;
+ if let Some(name) = self.name() {
+ f.write_char(' ')?;
+ f.write_str(name)?;
+ }
+ f.write_char('>')
+ }
+}
+
+impl PartialEq for Func {
+ fn eq(&self, other: &Self) -> bool {
+ Arc::ptr_eq(&self.0, &other.0)
+ }
+}
+
+/// A native rust function.
+struct Native {
+ /// The name of the function.
+ pub name: &'static str,
+ /// The function pointer.
+ pub func: fn(&mut EvalContext, &mut Args) -> TypResult<Value>,
+}
+
+/// A user-defined closure.
+pub struct Closure {
+ /// The name of the closure.
+ pub name: Option<EcoString>,
+ /// Captured values from outer scopes.
+ pub captured: Scope,
+ /// The parameter names and default values. Parameters with default value
+ /// are named parameters.
+ pub params: Vec<(EcoString, Option<Value>)>,
+ /// The name of an argument sink where remaining arguments are placed.
+ pub sink: Option<EcoString>,
+ /// The expression the closure should evaluate to.
+ pub body: Expr,
+}
+
+impl Closure {
+ /// Call the function in the context with the arguments.
+ pub fn call(&self, ctx: &mut EvalContext, args: &mut Args) -> TypResult<Value> {
+ // Don't leak the scopes from the call site. Instead, we use the
+ // scope of captured variables we collected earlier.
+ let prev_scopes = std::mem::take(&mut ctx.scopes);
+ ctx.scopes.top = self.captured.clone();
+
+ // Parse the arguments according to the parameter list.
+ for (param, default) in &self.params {
+ ctx.scopes.def_mut(param, match default {
+ None => args.expect::<Value>(param)?,
+ Some(default) => {
+ args.named::<Value>(param)?.unwrap_or_else(|| default.clone())
+ }
+ });
+ }
+
+ // Put the remaining arguments into the sink.
+ if let Some(sink) = &self.sink {
+ ctx.scopes.def_mut(sink, args.take());
+ }
+
+ // Evaluate the body.
+ let value = self.body.eval(ctx)?;
+
+ // Restore the call site scopes.
+ ctx.scopes = prev_scopes;
+
+ Ok(value)
+ }
+}
+
+/// Evaluated arguments to a function.
+#[derive(Clone, PartialEq)]
+pub struct Args {
+ /// The span of the whole argument list.
+ pub span: Span,
+ /// The positional and named arguments.
+ pub items: Vec<Arg>,
+}
+
+/// An argument to a function call: `12` or `draw: false`.
+#[derive(Clone, PartialEq)]
+pub struct Arg {
+ /// The span of the whole argument.
+ pub span: Span,
+ /// The name of the argument (`None` for positional arguments).
+ pub name: Option<EcoString>,
+ /// The value of the argument.
+ pub value: Spanned<Value>,
+}
+
+impl Args {
+ /// Consume and cast the first positional argument.
+ ///
+ /// Returns a `missing argument: {what}` error if no positional argument is
+ /// left.
+ pub fn expect<T>(&mut self, what: &str) -> TypResult<T>
+ where
+ T: Cast<Spanned<Value>>,
+ {
+ match self.eat()? {
+ Some(v) => Ok(v),
+ None => bail!(self.span, "missing argument: {}", what),
+ }
+ }
+
+ /// Consume and cast the first positional argument if there is one.
+ pub fn eat<T>(&mut self) -> TypResult<Option<T>>
+ where
+ T: Cast<Spanned<Value>>,
+ {
+ for (i, slot) in self.items.iter().enumerate() {
+ if slot.name.is_none() {
+ let value = self.items.remove(i).value;
+ let span = value.span;
+ return T::cast(value).at(span).map(Some);
+ }
+ }
+ Ok(None)
+ }
+
+ /// Find and consume the first castable positional argument.
+ pub fn find<T>(&mut self) -> Option<T>
+ where
+ T: Cast<Spanned<Value>>,
+ {
+ for (i, slot) in self.items.iter().enumerate() {
+ if slot.name.is_none() && T::is(&slot.value) {
+ let value = self.items.remove(i).value;
+ return T::cast(value).ok();
+ }
+ }
+ None
+ }
+
+ /// Find and consume all castable positional arguments.
+ pub fn all<T>(&mut self) -> impl Iterator<Item = T> + '_
+ where
+ T: Cast<Spanned<Value>>,
+ {
+ std::iter::from_fn(move || self.find())
+ }
+
+ /// Cast and remove the value for the given named argument, returning an
+ /// error if the conversion fails.
+ pub fn named<T>(&mut self, name: &str) -> TypResult<Option<T>>
+ where
+ T: Cast<Spanned<Value>>,
+ {
+ // We don't quit once we have a match because when multiple matches
+ // exist, we want to remove all of them and use the last one.
+ let mut i = 0;
+ let mut found = None;
+ while i < self.items.len() {
+ if self.items[i].name.as_deref() == Some(name) {
+ let value = self.items.remove(i).value;
+ let span = value.span;
+ found = Some(T::cast(value).at(span)?);
+ } else {
+ i += 1;
+ }
+ }
+ Ok(found)
+ }
+
+ /// Take out all arguments into a new instance.
+ pub fn take(&mut self) -> Self {
+ Self {
+ span: self.span,
+ items: std::mem::take(&mut self.items),
+ }
+ }
+
+ /// Return an "unexpected argument" error if there is any remaining
+ /// argument.
+ pub fn finish(self) -> TypResult<()> {
+ if let Some(arg) = self.items.first() {
+ bail!(arg.span, "unexpected argument");
+ }
+ Ok(())
+ }
+
+ /// Reinterpret these arguments as actually being an array index.
+ pub fn into_index(self) -> TypResult<i64> {
+ self.into_castable("index")
+ }
+
+ /// Reinterpret these arguments as actually being a dictionary key.
+ pub fn into_key(self) -> TypResult<EcoString> {
+ self.into_castable("key")
+ }
+
+ /// Reinterpret these arguments as actually being a single castable thing.
+ fn into_castable<T>(self, what: &str) -> TypResult<T>
+ where
+ T: Cast<Value>,
+ {
+ let mut iter = self.items.into_iter();
+ let value = match iter.next() {
+ Some(Arg { name: None, value, .. }) => value.v.cast().at(value.span)?,
+ None => {
+ bail!(self.span, "missing {}", what);
+ }
+ Some(Arg { name: Some(_), span, .. }) => {
+ bail!(span, "named pair is not allowed here");
+ }
+ };
+
+ if let Some(arg) = iter.next() {
+ bail!(arg.span, "only one {} is allowed", what);
+ }
+
+ Ok(value)
+ }
+}
+
+impl Debug for Args {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ f.write_char('(')?;
+ for (i, arg) in self.items.iter().enumerate() {
+ arg.fmt(f)?;
+ if i + 1 < self.items.len() {
+ f.write_str(", ")?;
+ }
+ }
+ f.write_char(')')
+ }
+}
+
+impl Debug for Arg {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ if let Some(name) = &self.name {
+ f.write_str(name)?;
+ f.write_str(": ")?;
+ }
+ Debug::fmt(&self.value.v, f)
+ }
+}