summaryrefslogtreecommitdiff
path: root/crates/typst-library/src/foundations/plugin.rs
blob: 31107dc34f2a26b52ef31409c0f125a878489289 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
use std::fmt::{self, Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};

use ecow::{eco_format, EcoString};
use typst_syntax::Spanned;
use wasmi::{AsContext, AsContextMut};

use crate::diag::{bail, At, SourceResult, StrResult};
use crate::engine::Engine;
use crate::foundations::{func, repr, scope, ty, Bytes};
use crate::World;

/// A WebAssembly plugin.
///
/// Typst is capable of interfacing with plugins compiled to WebAssembly. Plugin
/// functions may accept multiple [byte buffers]($bytes) as arguments and return
/// a single byte buffer. They should typically be wrapped in idiomatic Typst
/// functions that perform the necessary conversions between native Typst types
/// and bytes.
///
/// Plugins run in isolation from your system, which means that printing,
/// reading files, or anything like that will not be supported for security
/// reasons. To run as a plugin, a program needs to be compiled to a 32-bit
/// shared WebAssembly library. Many compilers will use the
/// [WASI ABI](https://wasi.dev/) by default or as their only option (e.g.
/// emscripten), which allows printing, reading files, etc. This ABI will not
/// directly work with Typst. You will either need to compile to a different
/// target or [stub all functions](https://github.com/astrale-sharp/wasm-minimal-protocol/blob/master/wasi-stub).
///
/// # Plugins and Packages
/// Plugins are distributed as packages. A package can make use of a plugin
/// simply by including a WebAssembly file and loading it. Because the
/// byte-based plugin interface is quite low-level, plugins are typically
/// exposed through wrapper functions that also live in the same package.
///
/// # Purity
/// Plugin functions must be pure: Given the same arguments, they must always
/// return the same value. The reason for this is that Typst functions must be
/// pure (which is quite fundamental to the language design) and, since Typst
/// functions can call plugin functions, this requirement is inherited. In
/// particular, if a plugin function is called twice with the same arguments,
/// Typst might cache the results and call your function only once.
///
/// # Example
/// ```example
/// #let myplugin = plugin("hello.wasm")
/// #let concat(a, b) = str(
///   myplugin.concatenate(
///     bytes(a),
///     bytes(b),
///   )
/// )
///
/// #concat("hello", "world")
/// ```
///
/// # Protocol
/// To be used as a plugin, a WebAssembly module must conform to the following
/// protocol:
///
/// ## Exports
/// A plugin module can export functions to make them callable from Typst. To
/// conform to the protocol, an exported function should:
///
/// - Take `n` 32-bit integer arguments `a_1`, `a_2`, ..., `a_n` (interpreted as
///   lengths, so `usize/size_t` may be preferable), and return one 32-bit
///   integer.
///
/// - The function should first allocate a buffer `buf` of length
///   `a_1 + a_2 + ... + a_n`, and then call
///   `wasm_minimal_protocol_write_args_to_buffer(buf.ptr)`.
///
/// - The first `a_1` bytes of the buffer now constitute the first argument, the
///   `a_2` next bytes the second argument, and so on.
///
/// - The function can now do its job with the arguments and produce an output
///   buffer. Before returning, it should call
///   `wasm_minimal_protocol_send_result_to_host` to send its result back to the
///   host.
///
/// - To signal success, the function should return `0`.
///
/// - To signal an error, the function should return `1`. The buffer sent to
///   the host is then interpreted as a UTF-8 encoded error message.
///
/// ## Imports
/// Plugin modules need to import two functions that are provided by the runtime.
/// (Types and functions are described using WAT syntax.)
///
/// - `(import "typst_env" "wasm_minimal_protocol_write_args_to_buffer" (func (param i32)))`
///
///   Writes the arguments for the current function into a plugin-allocated
///   buffer. When a plugin function is called, it
///   [receives the lengths](#exports) of its input buffers as arguments. It
///   should then allocate a buffer whose capacity is at least the sum of these
///   lengths. It should then call this function with a `ptr` to the buffer to
///   fill it with the arguments, one after another.
///
/// - `(import "typst_env" "wasm_minimal_protocol_send_result_to_host" (func (param i32 i32)))`
///
///   Sends the output of the current function to the host (Typst). The first
///   parameter shall be a pointer to a buffer (`ptr`), while the second is the
///   length of that buffer (`len`). The memory pointed at by `ptr` can be freed
///   immediately after this function returns. If the message should be
///   interpreted as an error message, it should be encoded as UTF-8.
///
/// # Resources
/// For more resources, check out the
/// [wasm-minimal-protocol repository](https://github.com/astrale-sharp/wasm-minimal-protocol).
/// It contains:
///
/// - A list of example plugin implementations and a test runner for these
///   examples
/// - Wrappers to help you write your plugin in Rust (Zig wrapper in
///   development)
/// - A stubber for WASI
#[ty(scope, cast)]
#[derive(Clone)]
// All state lives behind an `Arc`, so cloning a plugin only bumps a reference
// count and every clone shares the same module, functions, and store.
pub struct Plugin(Arc<Repr>);

/// The internal representation of a plugin.
struct Repr {
    /// The raw WebAssembly bytes. Equality and hashing of plugins are based on
    /// these bytes alone.
    bytes: Bytes,
    /// The functions exported by the WebAssembly module, each paired with its
    /// export name.
    functions: Vec<(EcoString, wasmi::Func)>,
    /// Owns all data associated with the WebAssembly module. Wrapped in a
    /// mutex so that calls made through a shared `&Plugin` can mutate it.
    store: Mutex<Store>,
}

/// Owns all data associated with the WebAssembly module.
type Store = wasmi::Store<StoreData>;

/// An out-of-bounds memory access performed through one of the host functions,
/// kept so that a descriptive error message can be reported after the call.
struct MemoryError {
    /// The offset into the plugin's memory at which the access started.
    offset: u32,
    /// The number of bytes that were supposed to be read or written.
    length: u32,
    /// Whether the failed access was a write (`true`) or a read (`false`).
    write: bool,
}

/// The persistent store data used for communication between store and host.
#[derive(Default)]
struct StoreData {
    /// The arguments of the current call, waiting to be copied into the
    /// plugin's memory by `wasm_minimal_protocol_write_args_to_buffer`.
    args: Vec<Bytes>,
    /// The result sent back by the plugin via
    /// `wasm_minimal_protocol_send_result_to_host`.
    output: Vec<u8>,
    /// An out-of-bounds memory access recorded by a host function, if any.
    memory_error: Option<MemoryError>,
}

#[scope]
impl Plugin {
    /// Creates a new plugin from a WebAssembly file.
    #[func(constructor)]
    pub fn construct(
        /// The engine.
        engine: &mut Engine,
        /// Path to a WebAssembly file.
        ///
        /// For more details, see the [Paths section]($syntax/#paths).
        path: Spanned<EcoString>,
    ) -> SourceResult<Plugin> {
        // Every failure below is attributed to the span of the path argument.
        let span = path.span;
        let file_id = span.resolve_path(&path.v).at(span)?;
        let wasm = engine.world.file(file_id).at(span)?;
        Plugin::new(wasm).at(span)
    }
}

impl Plugin {
    /// Create a new plugin from raw WebAssembly bytes.
    #[comemo::memoize]
    #[typst_macros::time(name = "load plugin")]
    pub fn new(bytes: Bytes) -> StrResult<Plugin> {
        let engine = wasmi::Engine::default();
        let module = wasmi::Module::new(&engine, bytes.as_slice())
            .map_err(|err| format!("failed to load WebAssembly module ({err})"))?;

        let mut linker = wasmi::Linker::new(&engine);
        linker
            .func_wrap(
                "typst_env",
                "wasm_minimal_protocol_send_result_to_host",
                wasm_minimal_protocol_send_result_to_host,
            )
            .unwrap();
        linker
            .func_wrap(
                "typst_env",
                "wasm_minimal_protocol_write_args_to_buffer",
                wasm_minimal_protocol_write_args_to_buffer,
            )
            .unwrap();

        let mut store = Store::new(&engine, StoreData::default());
        let instance = linker
            .instantiate(&mut store, &module)
            .and_then(|pre_instance| pre_instance.start(&mut store))
            .map_err(|e| eco_format!("{e}"))?;

        // Ensure that the plugin exports its memory.
        if !matches!(
            instance.get_export(&store, "memory"),
            Some(wasmi::Extern::Memory(_))
        ) {
            bail!("plugin does not export its memory");
        }

        // Collect exported functions.
        let functions = instance
            .exports(&store)
            .filter_map(|export| {
                let name = export.name().into();
                export.into_func().map(|func| (name, func))
            })
            .collect();

        Ok(Plugin(Arc::new(Repr { bytes, functions, store: Mutex::new(store) })))
    }

    /// Call the plugin function with the given `name`.
    #[comemo::memoize]
    #[typst_macros::time(name = "call plugin")]
    pub fn call(&self, name: &str, args: Vec<Bytes>) -> StrResult<Bytes> {
        // Find the function with the given name.
        let func = self
            .0
            .functions
            .iter()
            .find(|(v, _)| v == name)
            .map(|&(_, func)| func)
            .ok_or_else(|| {
                eco_format!("plugin does not contain a function called {name}")
            })?;

        let mut store = self.0.store.lock().unwrap();
        let ty = func.ty(store.as_context());

        // Check function signature.
        if ty.params().iter().any(|&v| v != wasmi::core::ValType::I32) {
            bail!(
                "plugin function `{name}` has a parameter that is not a 32-bit integer"
            );
        }
        if ty.results() != [wasmi::core::ValType::I32] {
            bail!("plugin function `{name}` does not return exactly one 32-bit integer");
        }

        // Check inputs.
        let expected = ty.params().len();
        let given = args.len();
        if expected != given {
            bail!(
                "plugin function takes {expected} argument{}, but {given} {} given",
                if expected == 1 { "" } else { "s" },
                if given == 1 { "was" } else { "were" },
            );
        }

        // Collect the lengths of the argument buffers.
        let lengths = args
            .iter()
            .map(|a| wasmi::Val::I32(a.len() as i32))
            .collect::<Vec<_>>();

        // Store the input data.
        store.data_mut().args = args;

        // Call the function.
        let mut code = wasmi::Val::I32(-1);
        func.call(store.as_context_mut(), &lengths, std::slice::from_mut(&mut code))
            .map_err(|err| eco_format!("plugin panicked: {err}"))?;
        if let Some(MemoryError { offset, length, write }) =
            store.data_mut().memory_error.take()
        {
            return Err(eco_format!(
                "plugin tried to {kind} out of bounds: pointer {offset:#x} is out of bounds for {kind} of length {length}",
                kind = if write { "write" } else { "read" }
            ));
        }

        // Extract the returned data.
        let output = std::mem::take(&mut store.data_mut().output);

        // Parse the functions return value.
        match code {
            wasmi::Val::I32(0) => {}
            wasmi::Val::I32(1) => match std::str::from_utf8(&output) {
                Ok(message) => bail!("plugin errored with: {message}"),
                Err(_) => {
                    bail!("plugin errored, but did not return a valid error message")
                }
            },
            _ => bail!("plugin did not respect the protocol"),
        };

        Ok(output.into())
    }

    /// An iterator over all the function names defined by the plugin.
    pub fn iter(&self) -> impl Iterator<Item = &EcoString> {
        self.0.functions.as_slice().iter().map(|(func_name, _)| func_name)
    }
}

impl Debug for Plugin {
    /// Formats the plugin as a fixed placeholder; the underlying bytes and
    /// store are never printed. Width, fill, and alignment flags are honored.
    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
        const PLACEHOLDER: &str = "Plugin(..)";
        Formatter::pad(formatter, PLACEHOLDER)
    }
}

impl repr::Repr for Plugin {
    fn repr(&self) -> EcoString {
        "plugin(..)".into()
    }
}

impl PartialEq for Plugin {
    /// Two plugins are equal exactly when their raw WebAssembly bytes match.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0.bytes, &other.0.bytes);
        lhs == rhs
    }
}

impl Hash for Plugin {
    /// Hashes only the raw WebAssembly bytes, in line with `PartialEq`, which
    /// also compares nothing else.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.0.bytes, state);
    }
}

/// Copies the pending arguments into the plugin's exported memory, back to back,
/// starting at `ptr`.
///
/// The arguments are taken out of the store, so a second invocation during the
/// same call writes nothing. If a write falls out of bounds, the failing offset
/// and length are recorded in the store's `memory_error`, and the remaining
/// arguments are skipped.
fn wasm_minimal_protocol_write_args_to_buffer(
    mut caller: wasmi::Caller<StoreData>,
    ptr: u32,
) {
    // `Plugin::new` rejects modules that do not export their memory.
    let memory = caller
        .get_export("memory")
        .and_then(wasmi::Extern::into_memory)
        .unwrap();
    let pending = std::mem::take(&mut caller.data_mut().args);

    let mut cursor = ptr as usize;
    for arg in &pending {
        let len = arg.len();
        match memory.write(&mut caller, cursor, arg.as_slice()) {
            Ok(()) => cursor += len,
            Err(_) => {
                caller.data_mut().memory_error = Some(MemoryError {
                    offset: cursor as u32,
                    length: len as u32,
                    write: true,
                });
                return;
            }
        }
    }
}

/// Copies `len` bytes starting at `ptr` out of the plugin's exported memory and
/// stores them as the output of the current call.
///
/// If the read falls out of bounds, the pointer and length are recorded in the
/// store's `memory_error`, and the output is left empty.
fn wasm_minimal_protocol_send_result_to_host(
    mut caller: wasmi::Caller<StoreData>,
    ptr: u32,
    len: u32,
) {
    // `Plugin::new` rejects modules that do not export their memory.
    let memory = caller
        .get_export("memory")
        .and_then(wasmi::Extern::into_memory)
        .unwrap();

    // Reuse the allocation of the existing output buffer, if there is one.
    let mut result = std::mem::take(&mut caller.data_mut().output);
    result.resize(len as usize, 0);

    match memory.read(&caller, ptr as usize, &mut result) {
        Ok(()) => caller.data_mut().output = result,
        Err(_) => {
            caller.data_mut().memory_error =
                Some(MemoryError { offset: ptr, length: len, write: false });
        }
    }
}