diff options
| author | Laurenz <laurmaedje@gmail.com> | 2023-01-23 13:24:39 +0100 |
|---|---|---|
| committer | Laurenz <laurmaedje@gmail.com> | 2023-01-23 13:26:10 +0100 |
| commit | 2b8426b1b3a19d46a94abaece674525948c146af (patch) | |
| tree | 35eae5b09cda399224c58543f58a985cfcd00e79 /src/model/eval.rs | |
| parent | 6ca240508eed7288fcc317b9e167f6470a2f952c (diff) | |
Interpret methods on modules as functions in modules
Diffstat (limited to 'src/model/eval.rs')
| -rw-r--r-- | src/model/eval.rs | 76 |
1 files changed, 54 insertions, 22 deletions
diff --git a/src/model/eval.rs b/src/model/eval.rs index 66ff2cd5..67c733ce 100644 --- a/src/model/eval.rs +++ b/src/model/eval.rs @@ -877,6 +877,18 @@ impl ast::Binary { op: fn(Value, Value) -> StrResult<Value>, ) -> SourceResult<Value> { let rhs = self.rhs().eval(vm)?; + let lhs = self.lhs(); + + // An assignment to a dictionary field is different from a normal access + // since it can create the field instead of just modifying it. + if self.op() == ast::BinOp::Assign { + if let ast::Expr::FieldAccess(access) = &lhs { + let dict = access.access_dict(vm)?; + dict.insert(access.field().take().into(), rhs); + return Ok(Value::None); + } + } + let location = self.lhs().access(vm)?; let lhs = std::mem::take(&mut *location); *location = op(lhs, rhs).at(self.span())?; @@ -898,14 +910,7 @@ impl Eval for ast::FieldAccess { .field(&field) .ok_or_else(|| format!("unknown field `{field}`")) .at(span)?, - Value::Module(module) => module - .scope() - .get(&field) - .cloned() - .ok_or_else(|| { - format!("module `{}` does not contain `{field}`", module.name()) - }) - .at(span)?, + Value::Module(module) => module.get(&field).cloned().at(span)?, v => bail!( self.target().span(), "expected dictionary or content, found {}", @@ -921,7 +926,8 @@ impl Eval for ast::FuncCall { fn eval(&self, vm: &mut Vm) -> SourceResult<Self::Output> { let callee = self.callee(); let callee = callee.eval(vm)?.cast::<Func>().at(callee.span())?; - self.eval_with_callee(vm, callee) + let args = self.args().eval(vm)?; + Self::eval_call(vm, &callee, args, self.span()) } } @@ -929,7 +935,8 @@ impl ast::FuncCall { fn eval_in_math(&self, vm: &mut Vm) -> SourceResult<Content> { let callee = self.callee().eval(vm)?; if let Value::Func(callee) = callee { - Ok(self.eval_with_callee(vm, callee)?.display_in_math()) + let args = self.args().eval(vm)?; + Ok(Self::eval_call(vm, &callee, args, self.span())?.display_in_math()) } else { let mut body = (vm.items.math_atom)('('.into()); let mut args = self.args().eval(vm)?; @@ -944,14 +951,18 @@ impl ast::FuncCall { } } - fn eval_with_callee(&self, vm: &mut Vm, callee: Func) -> SourceResult<Value> { + fn eval_call( + vm: &mut Vm, + callee: &Func, + args: Args, + span: Span, + ) -> SourceResult<Value> { if vm.depth >= MAX_CALL_DEPTH { - bail!(self.span(), "maximum function call depth exceeded"); + bail!(span, "maximum function call depth exceeded"); } - let args = self.args().eval(vm)?; let point = || Tracepoint::Call(callee.name().map(Into::into)); - callee.call(vm, args).trace(vm.world, point, self.span()) + callee.call(vm, args).trace(vm.world, point, span) } } @@ -960,18 +971,35 @@ impl Eval for ast::MethodCall { fn eval(&self, vm: &mut Vm) -> SourceResult<Self::Output> { let span = self.span(); - let method = self.method().take(); + let method = self.method(); let result = if methods::is_mutating(&method) { let args = self.args().eval(vm)?; let value = self.target().access(vm)?; + + if let Value::Module(module) = &value { + if let Value::Func(callee) = + module.get(&method).cloned().at(method.span())? + { + return ast::FuncCall::eval_call(vm, &callee, args, self.span()); + } + } + methods::call_mut(value, &method, args, span) } else { let value = self.target().eval(vm)?; let args = self.args().eval(vm)?; + + if let Value::Module(module) = &value { + if let Value::Func(callee) = module.get(&method).at(method.span())? { + return ast::FuncCall::eval_call(vm, callee, args, self.span()); + } + } + methods::call(vm, value, &method, args, span) }; + let method = method.take(); let point = || Tracepoint::Call(Some(method.clone())); result.trace(vm.world, point, span) } @@ -1423,16 +1451,20 @@ impl Access for ast::Parenthesized { impl Access for ast::FieldAccess { fn access<'a>(&self, vm: &'a mut Vm) -> SourceResult<&'a mut Value> { - let value = self.target().access(vm)?; - let Value::Dict(dict) = value else { - bail!( + self.access_dict(vm)?.at_mut(&self.field().take()).at(self.span()) + } +} + +impl ast::FieldAccess { + fn access_dict<'a>(&self, vm: &'a mut Vm) -> SourceResult<&'a mut Dict> { + match self.target().access(vm)? { + Value::Dict(dict) => Ok(dict), + value => bail!( self.target().span(), "expected dictionary, found {}", value.type_name(), - ); - }; - - Ok(dict.at_mut(self.field().take().into())) + ), + } } } |
