From 6ab7760822ccd24b4ef126d4737d41f1be15fe19 Mon Sep 17 00:00:00 2001 From: Laurenz Date: Wed, 1 Mar 2023 16:30:58 +0100 Subject: Split up `model` module --- src/eval/func.rs | 574 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 574 insertions(+) create mode 100644 src/eval/func.rs (limited to 'src/eval/func.rs') diff --git a/src/eval/func.rs b/src/eval/func.rs new file mode 100644 index 00000000..e5280932 --- /dev/null +++ b/src/eval/func.rs @@ -0,0 +1,574 @@ +use std::fmt::{self, Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use comemo::{Prehashed, Track, Tracked, TrackedMut}; +use ecow::EcoString; + +use super::{Args, CastInfo, Dict, Eval, Flow, Route, Scope, Scopes, Tracer, Value, Vm}; +use crate::diag::{bail, SourceResult, StrResult}; +use crate::model::{Node, NodeId, Selector, StyleMap}; +use crate::syntax::ast::{self, AstNode, Expr}; +use crate::syntax::{SourceId, Span, SyntaxNode}; +use crate::util::hash128; +use crate::World; + +/// An evaluatable function. +#[derive(Clone, Hash)] +pub struct Func(Arc>, Span); + +/// The different kinds of function representations. +#[derive(Hash)] +enum Repr { + /// A native rust function. + Native(Native), + /// A user-defined closure. + Closure(Closure), + /// A nested function with pre-applied arguments. + With(Func, Args), +} + +impl Func { + /// Create a new function from a type that can be turned into a function. + pub fn from_type(name: &'static str) -> Self { + T::create_func(name) + } + + /// Create a new function from a native rust function. + pub fn from_fn( + func: fn(&Vm, &mut Args) -> SourceResult, + info: FuncInfo, + ) -> Self { + Self( + Arc::new(Prehashed::new(Repr::Native(Native { + func, + set: None, + node: None, + info, + }))), + Span::detached(), + ) + } + + /// Create a new function from a native rust node. + pub fn from_node(mut info: FuncInfo) -> Self { + info.params.extend(T::properties()); + Self( + Arc::new(Prehashed::new(Repr::Native(Native { + func: |ctx, args| { + let styles = T::set(args, true)?; + let content = T::construct(ctx, args)?; + Ok(Value::Content(content.styled_with_map(styles.scoped()))) + }, + set: Some(|args| T::set(args, false)), + node: Some(NodeId::of::()), + info, + }))), + Span::detached(), + ) + } + + /// Create a new function from a closure. + pub(super) fn from_closure(closure: Closure, span: Span) -> Self { + Self(Arc::new(Prehashed::new(Repr::Closure(closure))), span) + } + + /// The name of the function. + pub fn name(&self) -> Option<&str> { + match &**self.0 { + Repr::Native(native) => Some(native.info.name), + Repr::Closure(closure) => closure.name.as_deref(), + Repr::With(func, _) => func.name(), + } + } + + /// Extract details the function. + pub fn info(&self) -> Option<&FuncInfo> { + match &**self.0 { + Repr::Native(native) => Some(&native.info), + Repr::With(func, _) => func.info(), + _ => None, + } + } + + /// The function's span. + pub fn span(&self) -> Span { + self.1 + } + + /// Attach a span to the function. + pub fn spanned(mut self, span: Span) -> Self { + self.1 = span; + self + } + + /// The number of positional arguments this function takes, if known. + pub fn argc(&self) -> Option { + match &**self.0 { + Repr::Closure(closure) => closure.argc(), + Repr::With(wrapped, applied) => Some(wrapped.argc()?.saturating_sub( + applied.items.iter().filter(|arg| arg.name.is_none()).count(), + )), + _ => None, + } + } + + /// Call the function with the given arguments. + pub fn call(&self, vm: &mut Vm, mut args: Args) -> SourceResult { + match &**self.0 { + Repr::Native(native) => { + let value = (native.func)(vm, &mut args)?; + args.finish()?; + Ok(value) + } + Repr::Closure(closure) => { + // Determine the route inside the closure. + let fresh = Route::new(closure.location); + let route = + if vm.location.is_detached() { fresh.track() } else { vm.route }; + + Closure::call( + self, + vm.world, + route, + TrackedMut::reborrow_mut(&mut vm.tracer), + vm.depth + 1, + args, + ) + } + Repr::With(wrapped, applied) => { + args.items = applied.items.iter().cloned().chain(args.items).collect(); + return wrapped.call(vm, args); + } + } + } + + /// Call the function without an existing virtual machine. + pub fn call_detached( + &self, + world: Tracked, + args: Args, + ) -> SourceResult { + let route = Route::default(); + let id = SourceId::detached(); + let scopes = Scopes::new(None); + let mut tracer = Tracer::default(); + let mut vm = Vm::new(world, route.track(), tracer.track_mut(), id, scopes, 0); + self.call(&mut vm, args) + } + + /// Apply the given arguments to the function. + pub fn with(self, args: Args) -> Self { + let span = self.1; + Self(Arc::new(Prehashed::new(Repr::With(self, args))), span) + } + + /// Create a selector for this function's node type, filtering by node's + /// whose [fields](super::Content::field) match the given arguments. + pub fn where_(self, args: &mut Args) -> StrResult { + let fields = args.to_named(); + args.items.retain(|arg| arg.name.is_none()); + self.select(Some(fields)) + } + + /// Execute the function's set rule and return the resulting style map. + pub fn set(&self, mut args: Args) -> SourceResult { + Ok(match &**self.0 { + Repr::Native(Native { set: Some(set), .. }) => { + let styles = set(&mut args)?; + args.finish()?; + styles + } + _ => StyleMap::new(), + }) + } + + /// Create a selector for this function's node type. + pub fn select(&self, fields: Option) -> StrResult { + match **self.0 { + Repr::Native(Native { node: Some(id), .. }) => { + if id == item!(text_id) { + Err("to select text, please use a string or regex instead")?; + } + + Ok(Selector::Node(id, fields)) + } + _ => Err("this function is not selectable")?, + } + } +} + +impl Debug for Func { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self.name() { + Some(name) => write!(f, ""), + None => f.write_str(""), + } + } +} + +impl PartialEq for Func { + fn eq(&self, other: &Self) -> bool { + hash128(&self.0) == hash128(&other.0) + } +} + +/// Types that can be turned into functions. +pub trait FuncType { + /// Create a function with the given name from this type. + fn create_func(name: &'static str) -> Func; +} + +/// A function defined by a native rust function or node. +struct Native { + /// The function pointer. + func: fn(&Vm, &mut Args) -> SourceResult, + /// The set rule. + set: Option SourceResult>, + /// The id of the node to customize with this function's show rule. + node: Option, + /// Documentation of the function. + info: FuncInfo, +} + +impl Hash for Native { + fn hash(&self, state: &mut H) { + (self.func as usize).hash(state); + self.set.map(|set| set as usize).hash(state); + self.node.hash(state); + } +} + +/// Details about a function. +#[derive(Debug, Clone)] +pub struct FuncInfo { + /// The function's name. + pub name: &'static str, + /// The display name of the function. + pub display: &'static str, + /// Documentation for the function. + pub docs: &'static str, + /// Details about the function's parameters. + pub params: Vec, + /// Valid types for the return value. + pub returns: Vec<&'static str>, + /// Which category the function is part of. + pub category: &'static str, +} + +impl FuncInfo { + /// Get the parameter info for a parameter with the given name + pub fn param(&self, name: &str) -> Option<&ParamInfo> { + self.params.iter().find(|param| param.name == name) + } +} + +/// Describes a named parameter. +#[derive(Debug, Clone)] +pub struct ParamInfo { + /// The parameter's name. + pub name: &'static str, + /// Documentation for the parameter. + pub docs: &'static str, + /// Valid values for the parameter. + pub cast: CastInfo, + /// Is the parameter positional? + pub positional: bool, + /// Is the parameter named? + /// + /// Can be true even if `positional` is true if the parameter can be given + /// in both variants. + pub named: bool, + /// Is the parameter required? + pub required: bool, + /// Can the parameter be given any number of times? + pub variadic: bool, + /// Is the parameter settable with a set rule? + pub settable: bool, +} + +/// A user-defined closure. +#[derive(Hash)] +pub(super) struct Closure { + /// The source file where the closure was defined. + pub location: SourceId, + /// The name of the closure. + pub name: Option, + /// Captured values from outer scopes. + pub captured: Scope, + /// The parameter names and default values. Parameters with default value + /// are named parameters. + pub params: Vec<(EcoString, Option)>, + /// The name of an argument sink where remaining arguments are placed. + pub sink: Option, + /// The expression the closure should evaluate to. + pub body: Expr, +} + +impl Closure { + /// Call the function in the context with the arguments. + #[comemo::memoize] + fn call( + this: &Func, + world: Tracked, + route: Tracked, + tracer: TrackedMut, + depth: usize, + mut args: Args, + ) -> SourceResult { + let closure = match &**this.0 { + Repr::Closure(closure) => closure, + _ => panic!("`this` must be a closure"), + }; + + // Don't leak the scopes from the call site. Instead, we use the scope + // of captured variables we collected earlier. + let mut scopes = Scopes::new(None); + scopes.top = closure.captured.clone(); + + // Provide the closure itself for recursive calls. + if let Some(name) = &closure.name { + scopes.top.define(name.clone(), Value::Func(this.clone())); + } + + // Parse the arguments according to the parameter list. + for (param, default) in &closure.params { + scopes.top.define( + param.clone(), + match default { + Some(default) => { + args.named::(param)?.unwrap_or_else(|| default.clone()) + } + None => args.expect::(param)?, + }, + ); + } + + // Put the remaining arguments into the sink. + if let Some(sink) = &closure.sink { + scopes.top.define(sink.clone(), args.take()); + } + + // Ensure all arguments have been used. + args.finish()?; + + // Evaluate the body. + let mut sub = Vm::new(world, route, tracer, closure.location, scopes, depth); + let result = closure.body.eval(&mut sub); + + // Handle control flow. + match sub.flow { + Some(Flow::Return(_, Some(explicit))) => return Ok(explicit), + Some(Flow::Return(_, None)) => {} + Some(flow) => bail!(flow.forbidden()), + None => {} + } + + result + } + + /// The number of positional arguments this function takes, if known. + fn argc(&self) -> Option { + if self.sink.is_some() { + return None; + } + + Some(self.params.iter().filter(|(_, default)| default.is_none()).count()) + } +} + +/// A visitor that determines which variables to capture for a closure. +pub(super) struct CapturesVisitor<'a> { + external: &'a Scopes<'a>, + internal: Scopes<'a>, + captures: Scope, +} + +impl<'a> CapturesVisitor<'a> { + /// Create a new visitor for the given external scopes. + pub fn new(external: &'a Scopes) -> Self { + Self { + external, + internal: Scopes::new(None), + captures: Scope::new(), + } + } + + /// Return the scope of captured variables. + pub fn finish(self) -> Scope { + self.captures + } + + /// Visit any node and collect all captured variables. + pub fn visit(&mut self, node: &SyntaxNode) { + match node.cast() { + // Every identifier is a potential variable that we need to capture. + // Identifiers that shouldn't count as captures because they + // actually bind a new name are handled below (individually through + // the expressions that contain them). + Some(ast::Expr::Ident(ident)) => self.capture(ident), + Some(ast::Expr::MathIdent(ident)) => self.capture_in_math(ident), + + // Code and content blocks create a scope. + Some(ast::Expr::Code(_) | ast::Expr::Content(_)) => { + self.internal.enter(); + for child in node.children() { + self.visit(child); + } + self.internal.exit(); + } + + // A closure contains parameter bindings, which are bound before the + // body is evaluated. Care must be taken so that the default values + // of named parameters cannot access previous parameter bindings. + Some(ast::Expr::Closure(expr)) => { + for param in expr.params() { + if let ast::Param::Named(named) = param { + self.visit(named.expr().as_untyped()); + } + } + + self.internal.enter(); + if let Some(name) = expr.name() { + self.bind(name); + } + + for param in expr.params() { + match param { + ast::Param::Pos(ident) => self.bind(ident), + ast::Param::Named(named) => self.bind(named.name()), + ast::Param::Sink(ident) => self.bind(ident), + } + } + + self.visit(expr.body().as_untyped()); + self.internal.exit(); + } + + // A let expression contains a binding, but that binding is only + // active after the body is evaluated. + Some(ast::Expr::Let(expr)) => { + if let Some(init) = expr.init() { + self.visit(init.as_untyped()); + } + self.bind(expr.binding()); + } + + // A for loop contains one or two bindings in its pattern. These are + // active after the iterable is evaluated but before the body is + // evaluated. + Some(ast::Expr::For(expr)) => { + self.visit(expr.iter().as_untyped()); + self.internal.enter(); + let pattern = expr.pattern(); + if let Some(key) = pattern.key() { + self.bind(key); + } + self.bind(pattern.value()); + self.visit(expr.body().as_untyped()); + self.internal.exit(); + } + + // An import contains items, but these are active only after the + // path is evaluated. + Some(ast::Expr::Import(expr)) => { + self.visit(expr.source().as_untyped()); + if let Some(ast::Imports::Items(items)) = expr.imports() { + for item in items { + self.bind(item); + } + } + } + + // Everything else is traversed from left to right. + _ => { + for child in node.children() { + self.visit(child); + } + } + } + } + + /// Bind a new internal variable. + fn bind(&mut self, ident: ast::Ident) { + self.internal.top.define(ident.take(), Value::None); + } + + /// Capture a variable if it isn't internal. + fn capture(&mut self, ident: ast::Ident) { + if self.internal.get(&ident).is_err() { + if let Ok(value) = self.external.get(&ident) { + self.captures.define_captured(ident.take(), value.clone()); + } + } + } + + /// Capture a variable in math mode if it isn't internal. + fn capture_in_math(&mut self, ident: ast::MathIdent) { + if self.internal.get(&ident).is_err() { + if let Ok(value) = self.external.get_in_math(&ident) { + self.captures.define_captured(ident.take(), value.clone()); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::syntax::parse; + + #[track_caller] + fn test(text: &str, result: &[&str]) { + let mut scopes = Scopes::new(None); + scopes.top.define("f", 0); + scopes.top.define("x", 0); + scopes.top.define("y", 0); + scopes.top.define("z", 0); + + let mut visitor = CapturesVisitor::new(&scopes); + let root = parse(text); + visitor.visit(&root); + + let captures = visitor.finish(); + let mut names: Vec<_> = captures.iter().map(|(k, _)| k).collect(); + names.sort(); + + assert_eq!(names, result); + } + + #[test] + fn test_captures() { + // Let binding and function definition. + test("#let x = x", &["x"]); + test("#let x; #(x + y)", &["y"]); + test("#let f(x, y) = x + y", &[]); + test("#let f(x, y) = f", &[]); + test("#let f = (x, y) => f", &["f"]); + + // Closure with different kinds of params. + test("#((x, y) => x + z)", &["z"]); + test("#((x: y, z) => x + z)", &["y"]); + test("#((..x) => x + y)", &["y"]); + test("#((x, y: x + z) => x + y)", &["x", "z"]); + test("#{x => x; x}", &["x"]); + + // Show rule. + test("#show y: x => x", &["y"]); + test("#show y: x => x + z", &["y", "z"]); + test("#show x: x => x", &["x"]); + + // For loop. + test("#for x in y { x + z }", &["y", "z"]); + test("#for x, y in y { x + y }", &["y"]); + test("#for x in y {} #x", &["x", "y"]); + + // Import. + test("#import z: x, y", &["z"]); + test("#import x + y: x, y, z", &["x", "y"]); + + // Blocks. + test("#{ let x = 1; { let y = 2; y }; x + y }", &["y"]); + test("#[#let x = 1]#x", &["x"]); + } +} -- cgit v1.2.3