wasm/execution/linear_memory.rs

use core::{cell::UnsafeCell, mem};

use alloc::vec::Vec;

use crate::{
    core::{indices::MemIdx, little_endian::LittleEndianBytes},
    rw_spinlock::RwSpinLock,
    RuntimeError,
};

/// Implementation of the linear memory suitable for concurrent access
///
/// Implements the base for the instructions described in
/// <https://webassembly.github.io/spec/core/exec/instructions.html#memory-instructions>.
///
/// This linear memory implementation internally relies on a `Vec<UnsafeCell<u8>>`. Thus, the atomic
/// unit of information for it is a byte (`u8`). All access to the linear memory internally occurs
/// through pointers, completely avoiding the creation of shared and mutable references to the
/// internal data. This avoids undefined behavior, except for the race condition inherent to
/// concurrent writes. Because of this, the [`LinearMemory::store`] function does not require
/// `&mut self` -- `&self` suffices.
///
/// # Notes on overflowing
///
/// All operations that rely on accessing `n` bytes starting at `index` in the linear memory have to
/// perform bounds checking. Thus they always have to ensure that `n + index <= linear_memory.len()`
/// holds true (i.e. `n + index - 1` must be a valid index into `linear_memory`). However,
/// writing that check as-is bears the danger of an overflow: assuming that `n`, `index` and
/// `linear_memory.len()` are of the same integer type, `n + index` can overflow, resulting in
/// the check passing despite the access being out of bounds!
///
/// To avoid this, the bounds checks are carefully ordered to avoid any overflows:
///
/// - First we check that `n <= linear_memory.len()` holds true, ensuring that the amount of bytes
///   to be accessed is indeed smaller than or equal to the linear memory's size. If this does not
///   hold true, continuing the operation would yield an out-of-bounds access in any case.
/// - Then, as a second check, we verify that `index <= linear_memory.len() - n`. This way we
///   avoid the overflow, as there is no addition. The subtraction on the right-hand side can not
///   underflow, due to the previous check (which asserts that `n` is smaller than or equal to
///   `linear_memory.len()`).
///
/// Combined in the given order, these two checks enable bounds checking without risking any
/// overflow or underflow, provided that `n`, `index` and `linear_memory.len()` are of the same
/// integer type.
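///
/// For illustration, the pattern looks roughly like this (a sketch with all operands as `usize`,
/// not the exact code below):
///
/// ```ignore
/// // 1. The amount of bytes itself must not exceed the memory's size.
/// if n > linear_memory.len() {
///     return Err(RuntimeError::MemoryAccessOutOfBounds);
/// }
/// // 2. Overflow-free: `linear_memory.len() - n` can not underflow due to check 1.
/// if index > linear_memory.len() - n {
///     return Err(RuntimeError::MemoryAccessOutOfBounds);
/// }
/// // At this point `index + n <= linear_memory.len()` is guaranteed.
/// ```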
///
/// # Notes on locking
///
/// The internal data vector of the [`LinearMemory`] is wrapped in a [`RwSpinLock`]. Despite the
/// name, writes to the linear memory do not require the acquisition of a write lock. Writes are
/// implemented through a shared ref to the internal vector, with an `UnsafeCell` to achieve
/// interior mutability.
///
/// However, linear memory can grow. As the linear memory is implemented via a [`Vec`], a grow can
/// result in the vector's internal data buffer being copied over to a bigger, fresh allocation.
/// The old buffer is then freed. Combined with concurrent mutable access, this can cause a
/// use-after-free. To avoid this, a grow operation of the linear memory acquires a write lock,
/// blocking all reads and writes to the linear memory while it runs.
///
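/// The resulting lock discipline, as a sketch (assuming the default 64 KiB page size):
///
/// ```ignore
/// let mem = LinearMemory::<65536>::new_with_initial_pages(1);
/// mem.store(0, 42_u8)?;     // read lock suffices: mutation goes through `UnsafeCell`
/// let _: u8 = mem.load(0)?; // read lock
/// mem.grow(1);              // write lock: the `Vec` may reallocate
/// ```
///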
/// # Unsafe Note
///
/// Raw pointer access is required, because concurrent mutation of the linear memory might happen
/// (consider the threading proposal for WASM, where multiple WASM threads access the same linear
/// memory at the same time). The inherent race condition results in UB w.r.t. the state of the
/// `u8`s in the inner data. However, this is tolerable, as avoiding race conditions on the state
/// of the linear memory can not be the task of the interpreter, but has to be fulfilled by the
/// interpreted bytecode itself.
// TODO if a memmap-like operation is available, the linear memory implementation can be optimized brutally. Out-of-bounds access can be mapped to userspace-handled page faults, i.e. the MMU takes over the responsibility of catching out-of-bounds accesses. Grow can happen without copying of data, by mapping new pages consecutively after the current final page of the linear memory.
pub struct LinearMemory<const PAGE_SIZE: usize = { crate::Limits::MEM_PAGE_SIZE as usize }> {
    inner_data: RwSpinLock<Vec<UnsafeCell<u8>>>,
}

/// Type to express the page count
pub type PageCountTy = u16;

impl<const PAGE_SIZE: usize> LinearMemory<PAGE_SIZE> {
    /// Size of a page in the linear memory, measured in bytes
    ///
    /// The WASM specification demands a page size of 64 KiB, that is `65536` bytes:
    /// <https://webassembly.github.io/spec/core/exec/runtime.html?highlight=page#memory-instances>
    const PAGE_SIZE: usize = PAGE_SIZE;

    /// Create a new, empty [`LinearMemory`]
    pub fn new() -> Self {
        Self {
            inner_data: RwSpinLock::new(Vec::new()),
        }
    }

    /// Create a new [`LinearMemory`] with the given number of initial pages, zero-initialized
    pub fn new_with_initial_pages(pages: PageCountTy) -> Self {
        let size_bytes = Self::PAGE_SIZE * pages as usize;
        let mut data = Vec::with_capacity(size_bytes);
        data.resize_with(size_bytes, || UnsafeCell::new(0));

        Self {
            inner_data: RwSpinLock::new(data),
        }
    }

    /// Grow the [`LinearMemory`] by a number of pages
    pub fn grow(&self, pages_to_add: PageCountTy) {
        let mut lock_guard = self.inner_data.write();
        let prior_length_bytes = lock_guard.len();
        let new_length_bytes = prior_length_bytes + Self::PAGE_SIZE * pages_to_add as usize;
        lock_guard.resize_with(new_length_bytes, || UnsafeCell::new(0));
    }

    /// Get the number of pages currently allocated to this [`LinearMemory`]
    pub fn pages(&self) -> PageCountTy {
        PageCountTy::try_from(self.inner_data.read().len() / PAGE_SIZE).unwrap()
    }

    /// Get the length in bytes currently allocated to this [`LinearMemory`]
    // TODO remove this op
    pub fn len(&self) -> usize {
        self.inner_data.read().len()
    }

    /// At a given index, store a datum in the [`LinearMemory`]
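    ///
    /// For example, a sketch (assuming `MemIdx` is an alias for `usize`, as the tests below use
    /// it):
    ///
    /// ```ignore
    /// let mem = LinearMemory::<65536>::new_with_initial_pages(1);
    /// // Note: `&self` suffices, the bytes are written through `UnsafeCell`.
    /// mem.store(0, 0x1234_u16)?; // stored in little-endian byte order
    /// ```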
    pub fn store<const N: usize, T: LittleEndianBytes<N>>(
        &self,
        index: MemIdx,
        value: T,
    ) -> Result<(), RuntimeError> {
        let value_size = mem::size_of::<T>();

        // Unless someone implements something wrong like `impl LittleEndianBytes<3> for f64`, this
        // check is already guaranteed at the type level. Therefore only a debug_assert.
        debug_assert_eq!(value_size, N, "value size must match const generic N");

        let lock_guard = self.inner_data.read();

        // A value must fit into the linear memory
        if value_size > lock_guard.len() {
            error!("value does not fit into linear memory");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        // The following statement must be true:
        // `index + value_size <= lock_guard.len()`
        // This check verifies it, while avoiding the possible overflow. The subtraction can not
        // underflow because of the previous check.
        if index > lock_guard.len() - value_size {
            error!("value write would extend beyond the end of the linear memory");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        // TODO this unwrap can not fail, maybe use unwrap_unchecked?
        let ptr = lock_guard.get(index).unwrap().get();
        let bytes = value.to_le_bytes();

        // Safety argument:
        //
        // - nonoverlapping is guaranteed, because `src` is a pointer to a stack allocated array,
        //   while the destination is the heap allocated Vec
        // - the first check above guarantees that `src` fits into the destination
        // - the second check above guarantees that even with the offset in `index`, `src` does not
        //   extend beyond the destination's last `UnsafeCell<u8>`
        // - the use of `UnsafeCell` avoids any `&` or `&mut` from ever being created to any of the
        //   `u8`s contained in the `UnsafeCell`s, so no UB is created through the existence of
        //   unsound references
        unsafe { ptr.copy_from_nonoverlapping(bytes.as_ref().as_ptr(), value_size) }

        Ok(())
    }

    /// From a given index, load a datum from the [`LinearMemory`]
    pub fn load<const N: usize, T: LittleEndianBytes<N>>(
        &self,
        index: MemIdx,
    ) -> Result<T, RuntimeError> {
        let value_size = mem::size_of::<T>();

        // Unless someone implements something wrong like `impl LittleEndianBytes<3> for i8`, this
        // check is already guaranteed at the type level. Therefore only a debug_assert.
        debug_assert_eq!(value_size, N, "value size must match const generic N");

        let lock_guard = self.inner_data.read();

        // A value must fit into the linear memory
        if value_size > lock_guard.len() {
            error!("value does not fit into linear memory");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        // The following statement must be true:
        // `index + value_size <= lock_guard.len()`
        // This check verifies it, while avoiding the possible overflow. The subtraction can not
        // underflow because of the previous check.
        if index > lock_guard.len() - value_size {
            error!("value read would extend beyond the end of the linear memory");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        let ptr = lock_guard.get(index).unwrap().get();
        let mut bytes = [0; N];

        // Safety argument:
        //
        // - nonoverlapping is guaranteed, because `dest` is a pointer to a stack allocated array,
        //   while the source is the heap allocated Vec
        // - the first check above guarantees that the source is at least as big as `dest`
        // - the second check above guarantees that even with the offset in `index`, `dest` does
        //   not read beyond the source's last `UnsafeCell<u8>`
        // - the use of `UnsafeCell` avoids any `&` or `&mut` from ever being created to any of the
        //   `u8`s contained in the `UnsafeCell`s, so no UB is created through the existence of
        //   unsound references
        unsafe { ptr.copy_to_nonoverlapping(bytes.as_mut_ptr(), bytes.len()) };
        Ok(T::from_le_bytes(bytes))
    }

    /// Implementation of the behavior described in
    /// <https://webassembly.github.io/spec/core/exec/instructions.html#xref-syntax-instructions-syntax-instr-memory-mathsf-memory-fill>.
    ///
    /// Note that the WASM spec defines the behavior by recursion, while our implementation uses
    /// the memset-like [`core::ptr::write_bytes`].
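    ///
    /// A usage sketch (assuming `MemIdx` is an alias for `usize`):
    ///
    /// ```ignore
    /// let mem = LinearMemory::<65536>::new_with_initial_pages(1);
    /// mem.fill(0, 0xFF, 512)?; // memset-like: write 512 bytes of 0xFF starting at index 0
    /// ```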
    pub fn fill(&self, index: MemIdx, data_byte: u8, count: MemIdx) -> Result<(), RuntimeError> {
        let lock_guard = self.inner_data.read();

        /* check destination for out of bounds access */
        // Specification step 12.
        if count > lock_guard.len() {
            error!("fill count is bigger than the linear memory");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        // Specification step 12.
        if index > lock_guard.len() - count {
            error!("fill extends beyond the linear memory's end");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        /* check if there is anything to be done */
        // Specification step 13.
        if count == 0 {
            return Ok(());
        }

        let ptr = lock_guard[index].get();
        unsafe {
            // Specification steps 14-21.
            ptr.write_bytes(data_byte, count);
        }

        Ok(())
    }

    /// Copy `count` bytes from one region in the linear memory to another region in the same or a
    /// different linear memory
    ///
    /// - Both regions may overlap
    /// - Copies the `count` bytes starting from `source_index`, overwriting the `count` bytes
    ///   starting from `destination_index`
    ///
    /// <https://webassembly.github.io/spec/core/exec/instructions.html#xref-syntax-instructions-syntax-instr-memory-mathsf-memory-copy>
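    ///
    /// Overlap is handled like a memmove ([`core::ptr::copy`]); a usage sketch:
    ///
    /// ```ignore
    /// let mem = LinearMemory::<65536>::new_with_initial_pages(1);
    /// // Copy 16 bytes within the same memory; the regions may overlap.
    /// mem.copy(8, &mem, 0, 16)?;
    /// ```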
    pub fn copy(
        &self,
        destination_index: MemIdx,
        source_mem: &Self,
        source_index: MemIdx,
        count: MemIdx,
    ) -> Result<(), RuntimeError> {
        // self is the destination
        let lock_guard_self = self.inner_data.read();

        // other is the source
        let lock_guard_other = source_mem.inner_data.read();

        /* check destination for out of bounds access */
        // Specification step 12.
        if count > lock_guard_self.len() {
            error!("copy count is bigger than the destination linear memory");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        // Specification step 12.
        if destination_index > lock_guard_self.len() - count {
            error!("copy destination extends beyond the linear memory's end");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        /* check source for out of bounds access */
        // Specification step 12.
        if count > lock_guard_other.len() {
            error!("copy count is bigger than the source linear memory");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        // Specification step 12.
        if source_index > lock_guard_other.len() - count {
            error!("copy source extends beyond the linear memory's end");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        /* check if there is anything to be done */
        // Specification step 13.
        if count == 0 {
            return Ok(());
        }

        // acquire pointers
        let destination_ptr = lock_guard_self[destination_index].get();
        let source_ptr = lock_guard_other[source_index].get();

        // copy the data
        unsafe {
            // TODO investigate whether it is worth using a conditional `copy_from_nonoverlapping`
            // if the non-overlapping can be confirmed (and the count is bigger than a certain
            // threshold).

            // Specification steps 14-15.
            destination_ptr.copy_from(source_ptr, count);
        }

        Ok(())
    }

    /// Initialize a region of the [`LinearMemory`] from the given data segment
    ///
    /// <https://webassembly.github.io/spec/core/exec/instructions.html#xref-syntax-instructions-syntax-instr-memory-mathsf-memory-init-x>
    // Rationale behind having `source_index` and `count` when the call site could also just create
    // a subslice for `source_data`: this keeps all the index error checks in one place.
    pub fn init(
        &self,
        destination_index: MemIdx,
        source_data: &[u8],
        source_index: MemIdx,
        count: MemIdx,
    ) -> Result<(), RuntimeError> {
        // self is the destination
        let lock_guard_self = self.inner_data.read();
        let data_len = source_data.len();

        /* check source for out of bounds access */
        // Specification step 16.
        if count > data_len {
            error!("init count is bigger than the data instance");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        // Specification step 16.
        if source_index > data_len - count {
            error!("init source extends beyond the data instance's end");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        /* check destination for out of bounds access */
        // Specification step 16.
        if count > lock_guard_self.len() {
            error!("init count is bigger than the linear memory");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        // Specification step 16.
        if destination_index > lock_guard_self.len() - count {
            error!("init extends beyond the linear memory's end");
            return Err(RuntimeError::MemoryAccessOutOfBounds);
        }

        /* check if there is anything to be done */
        // Specification step 17.
        if count == 0 {
            return Ok(());
        }

        // acquire pointers
        let destination_ptr = lock_guard_self[destination_index].get();
        let source_ptr = &source_data[source_index];

        // copy the data
        unsafe {
            // Specification steps 18-27.
            destination_ptr.copy_from_nonoverlapping(source_ptr, count);
        }

        Ok(())
    }
}

impl<const PAGE_SIZE: usize> core::fmt::Debug for LinearMemory<PAGE_SIZE> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "LinearMemory {{ inner_data: [ ")?;
        let lock_guard = self.inner_data.read();
        let mut iter = lock_guard.iter();

        if let Some(first_byte_uc) = iter.next() {
            write!(f, "{}", unsafe { *first_byte_uc.get() })?;
        }

        for uc in iter {
            // Safety argument:
            //
            // Only a single byte is read through the raw pointer, without ever creating a `&` or
            // `&mut` to the underlying `u8`. A concurrent write may race with this read; at worst
            // an arbitrary byte value is observed, which is the tolerated race described in the
            // type-level documentation.
            let byte = unsafe { *uc.get() };

            write!(f, ", {byte}")?;
        }
        write!(f, " ] }}")
    }
}

impl<const PAGE_SIZE: usize> Default for LinearMemory<PAGE_SIZE> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod test {
    use alloc::format;

    use super::*;

    const PAGE_SIZE: usize = 1 << 8;
    const PAGES: PageCountTy = 2;

    #[test]
    fn new_constructor() {
        let lin_mem = LinearMemory::<PAGE_SIZE>::new();
        assert_eq!(lin_mem.pages(), 0);
    }

    #[test]
    fn new_grow() {
        let lin_mem = LinearMemory::<PAGE_SIZE>::new();
        lin_mem.grow(1);
        assert_eq!(lin_mem.pages(), 1);
    }

    #[test]
    fn debug_print() {
        let lin_mem = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(1);
        assert_eq!(lin_mem.pages(), 1);

        let expected_length = "LinearMemory { inner_data: [  ] }".len() + PAGE_SIZE * "0, ".len();
        let tol = 2;

        let debug_repr = format!("{lin_mem:?}");
        let lower_bound = expected_length - tol;
        let upper_bound = expected_length + tol;
        assert!((lower_bound..upper_bound).contains(&debug_repr.len()));
    }

    #[test]
    fn roundtrip_normal_range_i8_neg127() {
        let x: i8 = -127;
        let highest_legal_offset = PAGE_SIZE - mem::size_of::<i8>();
        for offset in 0..MemIdx::try_from(highest_legal_offset).unwrap() {
            let lin_mem = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(PAGES);

            lin_mem.store(offset, x).unwrap();

            assert_eq!(
                lin_mem
                    .load::<{ core::mem::size_of::<i8>() }, i8>(offset)
                    .unwrap(),
                x,
                "load store roundtrip for {x:?} failed!"
            );
        }
    }

    #[test]
    fn roundtrip_normal_range_f32_13() {
        let x: f32 = 13.0;
        let highest_legal_offset = PAGE_SIZE - mem::size_of::<f32>();
        for offset in 0..MemIdx::try_from(highest_legal_offset).unwrap() {
            let lin_mem = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(PAGES);

            lin_mem.store(offset, x).unwrap();

            assert_eq!(
                lin_mem
                    .load::<{ core::mem::size_of::<f32>() }, f32>(offset)
                    .unwrap(),
                x,
                "load store roundtrip for {x:?} failed!"
            );
        }
    }

    #[test]
    fn roundtrip_normal_range_f64_min() {
        let x: f64 = f64::MIN;
        let highest_legal_offset = PAGE_SIZE - mem::size_of::<f64>();
        for offset in 0..MemIdx::try_from(highest_legal_offset).unwrap() {
            let lin_mem = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(PAGES);

            lin_mem.store(offset, x).unwrap();

            assert_eq!(
                lin_mem
                    .load::<{ core::mem::size_of::<f64>() }, f64>(offset)
                    .unwrap(),
                x,
                "load store roundtrip for {x:?} failed!"
            );
        }
    }

    #[test]
    fn roundtrip_normal_range_f64_nan() {
        let x: f64 = f64::NAN;
        let highest_legal_offset = PAGE_SIZE - mem::size_of::<f64>();
        for offset in 0..MemIdx::try_from(highest_legal_offset).unwrap() {
            let lin_mem = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(PAGES);

            lin_mem.store(offset, x).unwrap();

            assert!(
                lin_mem
                    .load::<{ core::mem::size_of::<f64>() }, f64>(offset)
                    .unwrap()
                    .is_nan(),
                "load store roundtrip for {x:?} failed!"
            );
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: MemoryAccessOutOfBounds"
    )]
    fn store_out_of_range_u128_max() {
        let x: u128 = u128::MAX;
        let pages = 1;
        let lowest_illegal_offset = PAGE_SIZE - mem::size_of::<u128>() + 1;
        let lowest_illegal_offset = MemIdx::try_from(lowest_illegal_offset).unwrap();
        let lin_mem = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(pages);

        lin_mem.store(lowest_illegal_offset, x).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: MemoryAccessOutOfBounds"
    )]
    fn store_empty_linear_memory_u8() {
        let x: u8 = u8::MAX;
        let pages = 0;
        let lowest_illegal_offset = PAGE_SIZE - mem::size_of::<u8>() + 1;
        let lowest_illegal_offset = MemIdx::try_from(lowest_illegal_offset).unwrap();
        let lin_mem = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(pages);

        lin_mem.store(lowest_illegal_offset, x).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: MemoryAccessOutOfBounds"
    )]
    fn load_out_of_range_u128_max() {
        let pages = 1;
        let lowest_illegal_offset = PAGE_SIZE - mem::size_of::<u128>() + 1;
        let lowest_illegal_offset = MemIdx::try_from(lowest_illegal_offset).unwrap();
        let lin_mem = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(pages);

        let _x: u128 = lin_mem.load(lowest_illegal_offset).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: MemoryAccessOutOfBounds"
    )]
    fn load_empty_linear_memory_u8() {
        let pages = 0;
        let lowest_illegal_offset = PAGE_SIZE - mem::size_of::<u8>() + 1;
        let lowest_illegal_offset = MemIdx::try_from(lowest_illegal_offset).unwrap();
        let lin_mem = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(pages);

        let _x: u8 = lin_mem.load(lowest_illegal_offset).unwrap();
    }

    #[test]
    #[should_panic]
    fn copy_out_of_bounds() {
        let lin_mem_0 = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(2);
        let lin_mem_1 = LinearMemory::<PAGE_SIZE>::new_with_initial_pages(1);
        lin_mem_0.copy(0, &lin_mem_1, 0, PAGE_SIZE + 1).unwrap();
    }
}