summaryrefslogtreecommitdiff
path: root/src/model/eval.rs
diff options
context:
space:
mode:
authorLaurenz <laurmaedje@gmail.com>2023-01-23 13:24:39 +0100
committerLaurenz <laurmaedje@gmail.com>2023-01-23 13:26:10 +0100
commit2b8426b1b3a19d46a94abaece674525948c146af (patch)
tree35eae5b09cda399224c58543f58a985cfcd00e79 /src/model/eval.rs
parent6ca240508eed7288fcc317b9e167f6470a2f952c (diff)
Interpret methods on modules as functions in modules
Diffstat (limited to 'src/model/eval.rs')
-rw-r--r--src/model/eval.rs76
1 files changed, 54 insertions, 22 deletions
diff --git a/src/model/eval.rs b/src/model/eval.rs
index 66ff2cd5..67c733ce 100644
--- a/src/model/eval.rs
+++ b/src/model/eval.rs
@@ -877,6 +877,18 @@ impl ast::Binary {
op: fn(Value, Value) -> StrResult<Value>,
) -> SourceResult<Value> {
let rhs = self.rhs().eval(vm)?;
+ let lhs = self.lhs();
+
+ // An assignment to a dictionary field is different from a normal access
+ // since it can create the field instead of just modifying it.
+ if self.op() == ast::BinOp::Assign {
+ if let ast::Expr::FieldAccess(access) = &lhs {
+ let dict = access.access_dict(vm)?;
+ dict.insert(access.field().take().into(), rhs);
+ return Ok(Value::None);
+ }
+ }
+
let location = self.lhs().access(vm)?;
let lhs = std::mem::take(&mut *location);
*location = op(lhs, rhs).at(self.span())?;
@@ -898,14 +910,7 @@ impl Eval for ast::FieldAccess {
.field(&field)
.ok_or_else(|| format!("unknown field `{field}`"))
.at(span)?,
- Value::Module(module) => module
- .scope()
- .get(&field)
- .cloned()
- .ok_or_else(|| {
- format!("module `{}` does not contain `{field}`", module.name())
- })
- .at(span)?,
+ Value::Module(module) => module.get(&field).cloned().at(span)?,
v => bail!(
self.target().span(),
"expected dictionary or content, found {}",
@@ -921,7 +926,8 @@ impl Eval for ast::FuncCall {
fn eval(&self, vm: &mut Vm) -> SourceResult<Self::Output> {
let callee = self.callee();
let callee = callee.eval(vm)?.cast::<Func>().at(callee.span())?;
- self.eval_with_callee(vm, callee)
+ let args = self.args().eval(vm)?;
+ Self::eval_call(vm, &callee, args, self.span())
}
}
@@ -929,7 +935,8 @@ impl ast::FuncCall {
fn eval_in_math(&self, vm: &mut Vm) -> SourceResult<Content> {
let callee = self.callee().eval(vm)?;
if let Value::Func(callee) = callee {
- Ok(self.eval_with_callee(vm, callee)?.display_in_math())
+ let args = self.args().eval(vm)?;
+ Ok(Self::eval_call(vm, &callee, args, self.span())?.display_in_math())
} else {
let mut body = (vm.items.math_atom)('('.into());
let mut args = self.args().eval(vm)?;
@@ -944,14 +951,18 @@ impl ast::FuncCall {
}
}
- fn eval_with_callee(&self, vm: &mut Vm, callee: Func) -> SourceResult<Value> {
+ fn eval_call(
+ vm: &mut Vm,
+ callee: &Func,
+ args: Args,
+ span: Span,
+ ) -> SourceResult<Value> {
if vm.depth >= MAX_CALL_DEPTH {
- bail!(self.span(), "maximum function call depth exceeded");
+ bail!(span, "maximum function call depth exceeded");
}
- let args = self.args().eval(vm)?;
let point = || Tracepoint::Call(callee.name().map(Into::into));
- callee.call(vm, args).trace(vm.world, point, self.span())
+ callee.call(vm, args).trace(vm.world, point, span)
}
}
@@ -960,18 +971,35 @@ impl Eval for ast::MethodCall {
fn eval(&self, vm: &mut Vm) -> SourceResult<Self::Output> {
let span = self.span();
- let method = self.method().take();
+ let method = self.method();
let result = if methods::is_mutating(&method) {
let args = self.args().eval(vm)?;
let value = self.target().access(vm)?;
+
+ if let Value::Module(module) = &value {
+ if let Value::Func(callee) =
+ module.get(&method).cloned().at(method.span())?
+ {
+ return ast::FuncCall::eval_call(vm, &callee, args, self.span());
+ }
+ }
+
methods::call_mut(value, &method, args, span)
} else {
let value = self.target().eval(vm)?;
let args = self.args().eval(vm)?;
+
+ if let Value::Module(module) = &value {
+ if let Value::Func(callee) = module.get(&method).at(method.span())? {
+ return ast::FuncCall::eval_call(vm, callee, args, self.span());
+ }
+ }
+
methods::call(vm, value, &method, args, span)
};
+ let method = method.take();
let point = || Tracepoint::Call(Some(method.clone()));
result.trace(vm.world, point, span)
}
@@ -1423,16 +1451,20 @@ impl Access for ast::Parenthesized {
impl Access for ast::FieldAccess {
fn access<'a>(&self, vm: &'a mut Vm) -> SourceResult<&'a mut Value> {
- let value = self.target().access(vm)?;
- let Value::Dict(dict) = value else {
- bail!(
+ self.access_dict(vm)?.at_mut(&self.field().take()).at(self.span())
+ }
+}
+
+impl ast::FieldAccess {
+ fn access_dict<'a>(&self, vm: &'a mut Vm) -> SourceResult<&'a mut Dict> {
+ match self.target().access(vm)? {
+ Value::Dict(dict) => Ok(dict),
+ value => bail!(
self.target().span(),
"expected dictionary, found {}",
value.type_name(),
- );
- };
-
- Ok(dict.at_mut(self.field().take().into()))
+ ),
+ }
}
}