From f88ef45ee6e285df59c7aa5cec935de331b4b6e0 Mon Sep 17 00:00:00 2001 From: Pg Biel <9021226+PgBiel@users.noreply.github.com> Date: Wed, 3 May 2023 09:20:53 -0300 Subject: Function scopes (#1032) --- src/eval/func.rs | 29 +++++++++++++- src/eval/mod.rs | 112 ++++++++++++++++++++++++++++++++++++++++-------------- src/eval/value.rs | 1 + 3 files changed, 112 insertions(+), 30 deletions(-) (limited to 'src/eval') diff --git a/src/eval/func.rs b/src/eval/func.rs index a6e0de84..51eba564 100644 --- a/src/eval/func.rs +++ b/src/eval/func.rs @@ -5,12 +5,13 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use comemo::{Prehashed, Track, Tracked, TrackedMut}; +use ecow::eco_format; use once_cell::sync::Lazy; use super::{ cast_to_value, Args, CastInfo, Eval, Flow, Route, Scope, Scopes, Tracer, Value, Vm, }; -use crate::diag::{bail, SourceResult}; +use crate::diag::{bail, SourceResult, StrResult}; use crate::model::{ElemFunc, Introspector, StabilityProvider, Vt}; use crate::syntax::ast::{self, AstNode, Expr, Ident}; use crate::syntax::{SourceId, Span, SyntaxNode}; @@ -144,6 +145,30 @@ impl Func { _ => None, } } + + /// Get a field from this function's scope, if possible. + pub fn get(&self, field: &str) -> StrResult<&Value> { + match &self.repr { + Repr::Native(func) => func.info.scope.get(field).ok_or_else(|| { + eco_format!( + "function `{}` does not contain field `{}`", + func.info.name, + field + ) + }), + Repr::Elem(func) => func.info().scope.get(field).ok_or_else(|| { + eco_format!( + "function `{}` does not contain field `{}`", + func.name(), + field + ) + }), + Repr::Closure(_) => { + Err(eco_format!("cannot access fields on user-defined functions")) + } + Repr::With(arc) => arc.0.get(field), + } + } } impl Debug for Func { @@ -225,6 +250,8 @@ pub struct FuncInfo { pub returns: Vec<&'static str>, /// Which category the function is part of. pub category: &'static str, + /// The function's own scope of fields and sub-functions. + pub scope: Scope, } impl FuncInfo { diff --git a/src/eval/mod.rs b/src/eval/mod.rs index b430b400..a837c9e0 100644 --- a/src/eval/mod.rs +++ b/src/eval/mod.rs @@ -42,7 +42,7 @@ use std::mem; use std::path::{Path, PathBuf}; use comemo::{Track, Tracked, TrackedMut}; -use ecow::EcoVec; +use ecow::{EcoString, EcoVec}; use unicode_segmentation::UnicodeSegmentation; use crate::diag::{ @@ -1077,7 +1077,15 @@ impl Eval for ast::FuncCall { if methods::is_mutating(&field) { let args = args.eval(vm)?; let target = target.access(vm)?; - if !matches!(target, Value::Symbol(_) | Value::Module(_)) { + + // Prioritize a function's own methods (with, where) over its + // fields. This is fine as we define each field of a function, + // if it has any. + // ('methods_on' will be empty for Symbol and Module - their + // method calls always refer to their fields.) + if !matches!(target, Value::Symbol(_) | Value::Module(_) | Value::Func(_)) + || methods_on(target.type_name()).iter().any(|(m, _)| m == &field) + { return methods::call_mut(target, &field, args, span).trace( vm.world(), point, @@ -1088,7 +1096,10 @@ impl Eval for ast::FuncCall { } else { let target = target.eval(vm)?; let args = args.eval(vm)?; - if !matches!(target, Value::Symbol(_) | Value::Module(_)) { + + if !matches!(target, Value::Symbol(_) | Value::Module(_) | Value::Func(_)) + || methods_on(target.type_name()).iter().any(|(m, _)| m == &field) + { return methods::call(vm, target, &field, args, span).trace( vm.world(), point, @@ -1613,6 +1624,42 @@ impl Eval for ast::ForLoop { } } +/// Applies imports from `import` to the current scope. +fn apply_imports>( + imports: Option, + vm: &mut Vm, + source_value: V, + name: impl Fn(&V) -> EcoString, + scope: impl Fn(&V) -> &Scope, +) -> SourceResult<()> { + match imports { + None => { + vm.scopes.top.define(name(&source_value), source_value); + } + Some(ast::Imports::Wildcard) => { + for (var, value) in scope(&source_value).iter() { + vm.scopes.top.define(var.clone(), value.clone()); + } + } + Some(ast::Imports::Items(idents)) => { + let mut errors = vec![]; + let scope = scope(&source_value); + for ident in idents { + if let Some(value) = scope.get(&ident) { + vm.define(ident, value.clone()); + } else { + errors.push(error!(ident.span(), "unresolved import")); + } + } + if !errors.is_empty() { + return Err(Box::new(errors)); + } + } + } + + Ok(()) +} + impl Eval for ast::ModuleImport { type Output = Value; @@ -1620,30 +1667,26 @@ impl Eval for ast::ModuleImport { fn eval(&self, vm: &mut Vm) -> SourceResult { let span = self.source().span(); let source = self.source().eval(vm)?; - let module = import(vm, source, span)?; - - match self.imports() { - None => { - vm.scopes.top.define(module.name().clone(), module); - } - Some(ast::Imports::Wildcard) => { - for (var, value) in module.scope().iter() { - vm.scopes.top.define(var.clone(), value.clone()); - } - } - Some(ast::Imports::Items(idents)) => { - let mut errors = vec![]; - for ident in idents { - if let Some(value) = module.scope().get(&ident) { - vm.define(ident, value.clone()); - } else { - errors.push(error!(ident.span(), "unresolved import")); - } - } - if !errors.is_empty() { - return Err(Box::new(errors)); - } + if let Value::Func(func) = source { + if func.info().is_none() { + bail!(span, "cannot import from user-defined functions"); } + apply_imports( + self.imports(), + vm, + func, + |func| func.info().unwrap().name.into(), + |func| &func.info().unwrap().scope, + )?; + } else { + let module = import(vm, source, span, true)?; + apply_imports( + self.imports(), + vm, + module, + |module| module.name().clone(), + |module| module.scope(), + )?; } Ok(Value::None) @@ -1657,17 +1700,28 @@ impl Eval for ast::ModuleInclude { fn eval(&self, vm: &mut Vm) -> SourceResult { let span = self.source().span(); let source = self.source().eval(vm)?; - let module = import(vm, source, span)?; + let module = import(vm, source, span, false)?; Ok(module.content()) } } /// Process an import of a module relative to the current location. -fn import(vm: &mut Vm, source: Value, span: Span) -> SourceResult { +fn import( + vm: &mut Vm, + source: Value, + span: Span, + accept_functions: bool, +) -> SourceResult { let path = match source { Value::Str(path) => path, Value::Module(module) => return Ok(module), - v => bail!(span, "expected path or module, found {}", v.type_name()), + v => { + if accept_functions { + bail!(span, "expected path, module or function, found {}", v.type_name()) + } else { + bail!(span, "expected path or module, found {}", v.type_name()) + } + } }; // Load the source file. diff --git a/src/eval/value.rs b/src/eval/value.rs index 1bfad9c8..bd612cce 100644 --- a/src/eval/value.rs +++ b/src/eval/value.rs @@ -127,6 +127,7 @@ impl Value { Self::Dict(dict) => dict.at(field, None).cloned(), Self::Content(content) => content.at(field, None), Self::Module(module) => module.get(field).cloned(), + Self::Func(func) => func.get(field).cloned(), v => Err(eco_format!("cannot access fields on type {}", v.type_name())), } } -- cgit v1.2.3