wasm/validation/mod.rs

1use alloc::collections::btree_set::{self, BTreeSet};
2use alloc::vec::Vec;
3
4use crate::core::indices::{FuncIdx, TypeIdx};
5use crate::core::reader::section_header::{SectionHeader, SectionTy};
6use crate::core::reader::span::Span;
7use crate::core::reader::types::data::DataSegment;
8use crate::core::reader::types::element::ElemType;
9use crate::core::reader::types::export::Export;
10use crate::core::reader::types::global::{Global, GlobalType};
11use crate::core::reader::types::import::{Import, ImportDesc};
12use crate::core::reader::types::{FuncType, MemType, TableType};
13use crate::core::reader::{WasmReadable, WasmReader};
14use crate::core::sidetable::Sidetable;
15use crate::{Error, ExportDesc, Result};
16
17pub(crate) mod code;
18pub(crate) mod data;
19pub(crate) mod globals;
20pub(crate) mod read_constant_expression;
21pub(crate) mod validation_stack;
22
/// Number of imported entities of each kind, counted from the import section.
///
/// Imported entities occupy the first indices of their respective index space, so each
/// count is also the index of the first locally defined entity of that kind.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub(crate) struct ImportsLength {
    /// Number of imported functions.
    pub imported_functions: usize,
    /// Number of imported globals.
    pub imported_globals: usize,
    /// Number of imported memories.
    pub imported_memories: usize,
    /// Number of imported tables.
    pub imported_tables: usize,
}
30
31/// Information collected from validating a module.
32/// This can be used to create a [crate::RuntimeInstance].
33#[derive(Clone, Debug)]
34pub struct ValidationInfo<'bytecode> {
35    pub(crate) wasm: &'bytecode [u8],
36    pub(crate) types: Vec<FuncType>,
37    pub(crate) imports: Vec<Import>,
38    pub(crate) functions: Vec<TypeIdx>,
39    pub(crate) tables: Vec<TableType>,
40    pub(crate) memories: Vec<MemType>,
41    pub(crate) globals: Vec<Global>,
42    #[allow(dead_code)]
43    pub(crate) exports: Vec<Export>,
44    /// Each block contains the validated code section and the stp corresponding to
45    /// the beginning of that code section
46    pub(crate) func_blocks_stps: Vec<(Span, usize)>,
47    pub(crate) sidetable: Sidetable,
48    pub(crate) data: Vec<DataSegment>,
49    /// The start function which is automatically executed during instantiation
50    pub(crate) start: Option<FuncIdx>,
51    pub(crate) elements: Vec<ElemType>,
52    pub(crate) imports_length: ImportsLength,
53    // pub(crate) exports_length: Exported,
54}
55
56fn validate_exports(validation_info: &ValidationInfo) -> Result<()> {
57    let mut found_export_names: btree_set::BTreeSet<&str> = btree_set::BTreeSet::new();
58    use crate::core::reader::types::export::ExportDesc::*;
59    for export in &validation_info.exports {
60        if found_export_names.contains(export.name.as_str()) {
61            return Err(Error::DuplicateExportName);
62        }
63        found_export_names.insert(export.name.as_str());
64        match export.desc {
65            FuncIdx(func_idx) => {
66                if validation_info.functions.len()
67                    + validation_info.imports_length.imported_functions
68                    <= func_idx
69                {
70                    return Err(Error::UnknownFunction);
71                }
72            }
73            TableIdx(table_idx) => {
74                if validation_info.tables.len() + validation_info.imports_length.imported_tables
75                    <= table_idx
76                {
77                    return Err(Error::UnknownTable);
78                }
79            }
80            MemIdx(mem_idx) => {
81                if validation_info.memories.len() + validation_info.imports_length.imported_memories
82                    <= mem_idx
83                {
84                    return Err(Error::UnknownMemory);
85                }
86            }
87            GlobalIdx(global_idx) => {
88                if validation_info.globals.len() + validation_info.imports_length.imported_globals
89                    <= global_idx
90                {
91                    return Err(Error::UnknownGlobal);
92                }
93            }
94        }
95    }
96    Ok(())
97}
98
99fn get_imports_length(imports: &Vec<Import>) -> ImportsLength {
100    let mut imports_length = ImportsLength {
101        imported_functions: 0,
102        imported_globals: 0,
103        imported_memories: 0,
104        imported_tables: 0,
105    };
106
107    for import in imports {
108        match import.desc {
109            ImportDesc::Func(_) => imports_length.imported_functions += 1,
110            ImportDesc::Global(_) => imports_length.imported_globals += 1,
111            ImportDesc::Mem(_) => imports_length.imported_memories += 1,
112            ImportDesc::Table(_) => imports_length.imported_tables += 1,
113        }
114    }
115
116    imports_length
117}
118
119pub fn validate(wasm: &[u8]) -> Result<ValidationInfo> {
120    let mut wasm = WasmReader::new(wasm);
121
122    // represents C.refs in https://webassembly.github.io/spec/core/valid/conventions.html#context
123    // A func.ref instruction is onlv valid if it has an immediate that is a member of C.refs.
124    // this list holds all the func_idx's occurring in the module, except in its functions or start function.
125    // I make an exception here by not including func_idx's occuring within data segments in C.refs as well, so that single pass validation is possible.
126    // If there is a func_idx within the data segment, this would ultimately mean that data segment cannot be validated,
127    // therefore this hack is acceptable.
128    // https://webassembly.github.io/spec/core/valid/modules.html#data-segments
129    // https://webassembly.github.io/spec/core/valid/modules.html#valid-module
130
131    let mut validation_context_refs: BTreeSet<FuncIdx> = BTreeSet::new();
132
133    trace!("Starting validation of bytecode");
134
135    trace!("Validating magic value");
136    let [0x00, 0x61, 0x73, 0x6d] = wasm.strip_bytes::<4>()? else {
137        return Err(Error::InvalidMagic);
138    };
139
140    trace!("Validating version number");
141    let [0x01, 0x00, 0x00, 0x00] = wasm.strip_bytes::<4>()? else {
142        return Err(Error::InvalidVersion);
143    };
144    debug!("Header ok");
145
146    let mut header = None;
147    read_next_header(&mut wasm, &mut header)?;
148
149    let skip_section = |wasm: &mut WasmReader, section_header: &mut Option<SectionHeader>| {
150        handle_section(wasm, section_header, SectionTy::Custom, |wasm, h| {
151            use alloc::string::*;
152            // customsec ::= section_0(custom)
153            // custom ::= name byte*
154            // name ::= b*:vec(byte) => name (if utf8(name) = b*)
155            // vec(B) ::= n:u32 (x:B)^n => x^n
156            let _name = wasm.read_name()?;
157
158            let remaining_bytes = match h
159                .contents
160                .from()
161                .checked_add(h.contents.len())
162                .and_then(|res| res.checked_sub(wasm.pc))
163            {
164                None => Err(Error::InvalidSection(
165                    SectionTy::Custom,
166                    "Remaining bytes less than 0 after reading name!".to_string(),
167                )),
168                Some(remaining_bytes) => Ok(remaining_bytes),
169            }?;
170
171            // TODO: maybe do something with these remaining bytes?
172            let mut _bytes = Vec::new();
173            for _ in 0..remaining_bytes {
174                _bytes.push(wasm.read_u8()?)
175            }
176            Ok(())
177        })
178    };
179
180    while (skip_section(&mut wasm, &mut header)?).is_some() {}
181
182    let types = handle_section(&mut wasm, &mut header, SectionTy::Type, |wasm, _| {
183        wasm.read_vec(FuncType::read)
184    })?
185    .unwrap_or_default();
186
187    while (skip_section(&mut wasm, &mut header)?).is_some() {}
188
189    let imports = handle_section(&mut wasm, &mut header, SectionTy::Import, |wasm, _| {
190        wasm.read_vec(Import::read)
191    })?
192    .unwrap_or_default();
193    let imports_length = get_imports_length(&imports);
194
195    while (skip_section(&mut wasm, &mut header)?).is_some() {}
196
197    // The `Function` section only covers module-level (or "local") functions.
198    // Imported functions have their types known in the `import` section. Both
199    // local and imported functions share the same index space.
200    //
201    // Imported functions are given priority and have the first indicies, and
202    // only after that do the local functions get assigned their indices.
203    let local_functions =
204        handle_section(&mut wasm, &mut header, SectionTy::Function, |wasm, _| {
205            wasm.read_vec(|wasm| wasm.read_var_u32().map(|u| u as usize))
206        })?
207        .unwrap_or_default();
208
209    let imported_functions = imports.iter().filter_map(|import| match &import.desc {
210        ImportDesc::Func(type_idx) => Some(*type_idx),
211        _ => None,
212    });
213
214    let all_functions = imported_functions
215        .clone()
216        .chain(local_functions.iter().cloned())
217        .collect::<Vec<TypeIdx>>();
218
219    while (skip_section(&mut wasm, &mut header)?).is_some() {}
220
221    let imported_tables = imports
222        .iter()
223        .filter_map(|m| match m.desc {
224            ImportDesc::Table(table) => Some(table),
225            _ => None,
226        })
227        .collect::<Vec<TableType>>();
228    let tables = handle_section(&mut wasm, &mut header, SectionTy::Table, |wasm, _| {
229        wasm.read_vec(TableType::read)
230    })?
231    .unwrap_or_default();
232
233    let all_tables = {
234        let mut temp = imported_tables;
235        temp.extend(tables.clone());
236        temp
237    };
238
239    while (skip_section(&mut wasm, &mut header)?).is_some() {}
240
241    let imported_memories = imports
242        .iter()
243        .filter_map(|m| match m.desc {
244            ImportDesc::Mem(mem) => Some(mem),
245            _ => None,
246        })
247        .collect::<Vec<MemType>>();
248    // let imported_memories_length = imported_memories.len();
249    let memories = handle_section(&mut wasm, &mut header, SectionTy::Memory, |wasm, _| {
250        wasm.read_vec(MemType::read)
251    })?
252    .unwrap_or_default();
253
254    let all_memories = {
255        let mut temp = imported_memories;
256        temp.extend(memories.clone());
257        temp
258    };
259    if all_memories.len() > 1 {
260        return Err(Error::MoreThanOneMemory);
261    }
262
263    while (skip_section(&mut wasm, &mut header)?).is_some() {}
264
265    // we start off with the imported globals
266    let /* mut */ imported_global_types = imports
267        .iter()
268        .filter_map(|m| match m.desc {
269            ImportDesc::Global(global) => Some(global),
270            _ => None,
271        })
272        .collect::<Vec<GlobalType>>();
273    let imported_global_types_len = imported_global_types.len();
274    let globals = handle_section(&mut wasm, &mut header, SectionTy::Global, |wasm, h| {
275        globals::validate_global_section(
276            wasm,
277            h,
278            &imported_global_types,
279            &mut validation_context_refs,
280            all_functions.len(),
281        )
282    })?
283    .unwrap_or_default();
284    let mut all_globals = Vec::new();
285    for item in imported_global_types.iter().take(imported_global_types_len) {
286        all_globals.push(Global {
287            init_expr: Span::new(usize::MAX, 0),
288            ty: *item,
289        })
290    }
291    for item in &globals {
292        all_globals.push(*item)
293    }
294
295    while (skip_section(&mut wasm, &mut header)?).is_some() {}
296
297    let exports = handle_section(&mut wasm, &mut header, SectionTy::Export, |wasm, _| {
298        wasm.read_vec(Export::read)
299    })?
300    .unwrap_or_default();
301    validation_context_refs.extend(exports.iter().filter_map(
302        |Export { name: _, desc }| match *desc {
303            ExportDesc::FuncIdx(func_idx) => Some(func_idx),
304            _ => None,
305        },
306    ));
307
308    while (skip_section(&mut wasm, &mut header)?).is_some() {}
309
310    let start = handle_section(&mut wasm, &mut header, SectionTy::Start, |wasm, _| {
311        wasm.read_var_u32().map(|idx| idx as FuncIdx)
312    })?;
313
314    while (skip_section(&mut wasm, &mut header)?).is_some() {}
315
316    let elements: Vec<ElemType> =
317        handle_section(&mut wasm, &mut header, SectionTy::Element, |wasm, _| {
318            ElemType::read_from_wasm(
319                wasm,
320                &all_functions,
321                &mut validation_context_refs,
322                &all_tables,
323                &imported_global_types,
324            )
325        })?
326        .unwrap_or_default();
327
328    while (skip_section(&mut wasm, &mut header)?).is_some() {}
329
330    // https://webassembly.github.io/spec/core/binary/modules.html#data-count-section
331    // As per the official documentation:
332    //
333    // The data count section is used to simplify single-pass validation. Since the data section occurs after the code section, the `memory.init` and `data.drop` and instructions would not be able to check whether the data segment index is valid until the data section is read. The data count section occurs before the code section, so a single-pass validator can use this count instead of deferring validation.
334    let data_count: Option<u32> =
335        handle_section(&mut wasm, &mut header, SectionTy::DataCount, |wasm, _| {
336            wasm.read_var_u32()
337        })?;
338    if data_count.is_some() {
339        trace!("data count: {}", data_count.unwrap());
340    }
341
342    while (skip_section(&mut wasm, &mut header)?).is_some() {}
343
344    let mut sidetable = Sidetable::new();
345    let func_blocks_stps = handle_section(&mut wasm, &mut header, SectionTy::Code, |wasm, h| {
346        code::validate_code_section(
347            wasm,
348            h,
349            &types,
350            &all_functions,
351            imported_functions.count(),
352            &all_globals,
353            &all_memories,
354            &data_count,
355            &all_tables,
356            &elements,
357            &validation_context_refs,
358            &mut sidetable,
359        )
360    })?
361    .unwrap_or_default();
362
363    assert_eq!(
364        func_blocks_stps.len(),
365        local_functions.len(),
366        "these should be equal"
367    ); // TODO check if this is in the spec
368
369    while (skip_section(&mut wasm, &mut header)?).is_some() {}
370
371    let data_section = handle_section(&mut wasm, &mut header, SectionTy::Data, |wasm, h| {
372        // wasm.read_vec(DataSegment::read)
373        data::validate_data_section(
374            wasm,
375            h,
376            &imported_global_types,
377            all_memories.len(),
378            all_functions.len(),
379        )
380    })?
381    .unwrap_or_default();
382
383    // https://webassembly.github.io/spec/core/binary/modules.html#data-count-section
384    if data_count.is_some() {
385        assert_eq!(data_count.unwrap() as usize, data_section.len());
386    }
387
388    while (skip_section(&mut wasm, &mut header)?).is_some() {}
389
390    // All sections should have been handled
391    if let Some(header) = header {
392        return Err(Error::SectionOutOfOrder(header.ty));
393    }
394
395    debug!("Validation was successful");
396    let validation_info = ValidationInfo {
397        wasm: wasm.into_inner(),
398        types,
399        imports,
400        functions: local_functions,
401        tables,
402        memories,
403        globals,
404        exports,
405        func_blocks_stps,
406        sidetable,
407        data: data_section,
408        start,
409        elements,
410        imports_length,
411    };
412    validate_exports(&validation_info)?;
413
414    Ok(validation_info)
415}
416
417fn read_next_header(wasm: &mut WasmReader, header: &mut Option<SectionHeader>) -> Result<()> {
418    if header.is_none() && !wasm.remaining_bytes().is_empty() {
419        *header = Some(SectionHeader::read(wasm)?);
420    }
421    Ok(())
422}
423
424#[inline(always)]
425fn handle_section<T, F: FnOnce(&mut WasmReader, SectionHeader) -> Result<T>>(
426    wasm: &mut WasmReader,
427    header: &mut Option<SectionHeader>,
428    section_ty: SectionTy,
429    handler: F,
430) -> Result<Option<T>> {
431    match &header {
432        Some(SectionHeader { ty, .. }) if *ty == section_ty => {
433            let h = header.take().unwrap();
434            trace!("Handling section {:?}", h.ty);
435            let ret = handler(wasm, h)?;
436            read_next_header(wasm, header)?;
437            Ok(Some(ret))
438        }
439        _ => Ok(None),
440    }
441}