diff options
| author | Laurenz <laurmaedje@gmail.com> | 2022-01-31 17:57:20 +0100 |
|---|---|---|
| committer | Laurenz <laurmaedje@gmail.com> | 2022-02-01 12:26:13 +0100 |
| commit | 6a6753cb69f7c29e857fd465eecf66a02ff76aa3 (patch) | |
| tree | e157752f30f5c493ee045d98039cfd3a94cdff22 /src/eval/func.rs | |
| parent | 20b1a38414101f842a6d9201133a5aaaa45a7cec (diff) | |
Better function representation
Diffstat (limited to 'src/eval/func.rs')
| -rw-r--r-- | src/eval/func.rs | 304 |
1 files changed, 304 insertions, 0 deletions
diff --git a/src/eval/func.rs b/src/eval/func.rs new file mode 100644 index 00000000..ccd0932f --- /dev/null +++ b/src/eval/func.rs @@ -0,0 +1,304 @@ +use std::fmt::{self, Debug, Formatter, Write}; +use std::sync::Arc; + +use super::{Cast, Eval, EvalContext, Scope, Value}; +use crate::diag::{At, TypResult}; +use crate::syntax::ast::Expr; +use crate::syntax::{Span, Spanned}; +use crate::util::EcoString; + +/// An evaluatable function. +#[derive(Clone)] +pub struct Func(Arc<Repr>); + +/// The different kinds of function representations. +enum Repr { + /// A native rust function. + Native(Native), + /// A user-defined closure. + Closure(Closure), + /// A nested function with pre-applied arguments. + With(Func, Args), +} + +impl Func { + /// Create a new function from a native rust function. + pub fn native( + name: &'static str, + func: fn(&mut EvalContext, &mut Args) -> TypResult<Value>, + ) -> Self { + Self(Arc::new(Repr::Native(Native { name, func }))) + } + + /// Create a new function from a closure. + pub fn closure(closure: Closure) -> Self { + Self(Arc::new(Repr::Closure(closure))) + } + + /// The name of the function. + pub fn name(&self) -> Option<&str> { + match self.0.as_ref() { + Repr::Native(native) => Some(native.name), + Repr::Closure(closure) => closure.name.as_deref(), + Repr::With(func, _) => func.name(), + } + } + + /// Call the function in the context with the arguments. + pub fn call(&self, ctx: &mut EvalContext, args: &mut Args) -> TypResult<Value> { + match self.0.as_ref() { + Repr::Native(native) => (native.func)(ctx, args), + Repr::Closure(closure) => closure.call(ctx, args), + Repr::With(wrapped, applied) => { + args.items.splice(.. 0, applied.items.iter().cloned()); + wrapped.call(ctx, args) + } + } + } + + /// Apply the given arguments to the function. + pub fn with(self, args: Args) -> Self { + Self(Arc::new(Repr::With(self, args))) + } +} + +impl Debug for Func { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.write_str("<function")?; + if let Some(name) = self.name() { + f.write_char(' ')?; + f.write_str(name)?; + } + f.write_char('>') + } +} + +impl PartialEq for Func { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +/// A native rust function. +struct Native { + /// The name of the function. + pub name: &'static str, + /// The function pointer. + pub func: fn(&mut EvalContext, &mut Args) -> TypResult<Value>, +} + +/// A user-defined closure. +pub struct Closure { + /// The name of the closure. + pub name: Option<EcoString>, + /// Captured values from outer scopes. + pub captured: Scope, + /// The parameter names and default values. Parameters with default value + /// are named parameters. + pub params: Vec<(EcoString, Option<Value>)>, + /// The name of an argument sink where remaining arguments are placed. + pub sink: Option<EcoString>, + /// The expression the closure should evaluate to. + pub body: Expr, +} + +impl Closure { + /// Call the function in the context with the arguments. + pub fn call(&self, ctx: &mut EvalContext, args: &mut Args) -> TypResult<Value> { + // Don't leak the scopes from the call site. Instead, we use the + // scope of captured variables we collected earlier. + let prev_scopes = std::mem::take(&mut ctx.scopes); + ctx.scopes.top = self.captured.clone(); + + // Parse the arguments according to the parameter list. + for (param, default) in &self.params { + ctx.scopes.def_mut(param, match default { + None => args.expect::<Value>(param)?, + Some(default) => { + args.named::<Value>(param)?.unwrap_or_else(|| default.clone()) + } + }); + } + + // Put the remaining arguments into the sink. + if let Some(sink) = &self.sink { + ctx.scopes.def_mut(sink, args.take()); + } + + // Evaluate the body. + let value = self.body.eval(ctx)?; + + // Restore the call site scopes. + ctx.scopes = prev_scopes; + + Ok(value) + } +} + +/// Evaluated arguments to a function. +#[derive(Clone, PartialEq)] +pub struct Args { + /// The span of the whole argument list. + pub span: Span, + /// The positional and named arguments. + pub items: Vec<Arg>, +} + +/// An argument to a function call: `12` or `draw: false`. +#[derive(Clone, PartialEq)] +pub struct Arg { + /// The span of the whole argument. + pub span: Span, + /// The name of the argument (`None` for positional arguments). + pub name: Option<EcoString>, + /// The value of the argument. + pub value: Spanned<Value>, +} + +impl Args { + /// Consume and cast the first positional argument. + /// + /// Returns a `missing argument: {what}` error if no positional argument is + /// left. + pub fn expect<T>(&mut self, what: &str) -> TypResult<T> + where + T: Cast<Spanned<Value>>, + { + match self.eat()? { + Some(v) => Ok(v), + None => bail!(self.span, "missing argument: {}", what), + } + } + + /// Consume and cast the first positional argument if there is one. + pub fn eat<T>(&mut self) -> TypResult<Option<T>> + where + T: Cast<Spanned<Value>>, + { + for (i, slot) in self.items.iter().enumerate() { + if slot.name.is_none() { + let value = self.items.remove(i).value; + let span = value.span; + return T::cast(value).at(span).map(Some); + } + } + Ok(None) + } + + /// Find and consume the first castable positional argument. + pub fn find<T>(&mut self) -> Option<T> + where + T: Cast<Spanned<Value>>, + { + for (i, slot) in self.items.iter().enumerate() { + if slot.name.is_none() && T::is(&slot.value) { + let value = self.items.remove(i).value; + return T::cast(value).ok(); + } + } + None + } + + /// Find and consume all castable positional arguments. + pub fn all<T>(&mut self) -> impl Iterator<Item = T> + '_ + where + T: Cast<Spanned<Value>>, + { + std::iter::from_fn(move || self.find()) + } + + /// Cast and remove the value for the given named argument, returning an + /// error if the conversion fails. + pub fn named<T>(&mut self, name: &str) -> TypResult<Option<T>> + where + T: Cast<Spanned<Value>>, + { + // We don't quit once we have a match because when multiple matches + // exist, we want to remove all of them and use the last one. + let mut i = 0; + let mut found = None; + while i < self.items.len() { + if self.items[i].name.as_deref() == Some(name) { + let value = self.items.remove(i).value; + let span = value.span; + found = Some(T::cast(value).at(span)?); + } else { + i += 1; + } + } + Ok(found) + } + + /// Take out all arguments into a new instance. + pub fn take(&mut self) -> Self { + Self { + span: self.span, + items: std::mem::take(&mut self.items), + } + } + + /// Return an "unexpected argument" error if there is any remaining + /// argument. + pub fn finish(self) -> TypResult<()> { + if let Some(arg) = self.items.first() { + bail!(arg.span, "unexpected argument"); + } + Ok(()) + } + + /// Reinterpret these arguments as actually being an array index. + pub fn into_index(self) -> TypResult<i64> { + self.into_castable("index") + } + + /// Reinterpret these arguments as actually being a dictionary key. + pub fn into_key(self) -> TypResult<EcoString> { + self.into_castable("key") + } + + /// Reinterpret these arguments as actually being a single castable thing. + fn into_castable<T>(self, what: &str) -> TypResult<T> + where + T: Cast<Value>, + { + let mut iter = self.items.into_iter(); + let value = match iter.next() { + Some(Arg { name: None, value, .. }) => value.v.cast().at(value.span)?, + None => { + bail!(self.span, "missing {}", what); + } + Some(Arg { name: Some(_), span, .. }) => { + bail!(span, "named pair is not allowed here"); + } + }; + + if let Some(arg) = iter.next() { + bail!(arg.span, "only one {} is allowed", what); + } + + Ok(value) + } +} + +impl Debug for Args { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.write_char('(')?; + for (i, arg) in self.items.iter().enumerate() { + arg.fmt(f)?; + if i + 1 < self.items.len() { + f.write_str(", ")?; + } + } + f.write_char(')') + } +} + +impl Debug for Arg { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + if let Some(name) = &self.name { + f.write_str(name)?; + f.write_str(": ")?; + } + Debug::fmt(&self.value.v, f) + } +} |
