Skip to main content

revmc_statetest/
compiled.rs

1// revmc-specific code: compilation, handler integration, and test orchestration.
2
3use crate::runner::{
4    check_evm_execution, execute_test_suite, skip_test, TestError, TestErrorKind, TestRunnerState,
5};
6use dashmap::DashMap;
7use revm::{
8    context::{block::BlockEnv, cfg::CfgEnv, tx::TxEnv},
9    context_interface::result::{EVMError, HaltReason, InvalidTransaction},
10    database::{self, bal::EvmDatabaseError},
11    database_interface::{DatabaseCommit, EmptyDB},
12    handler::{EvmTr, FrameResult, Handler, ItemOrResult},
13    primitives::{hardfork::SpecId, keccak256, B256, U256},
14    statetest_types::{SpecName, TestSuite, TestUnit},
15    Context, MainBuilder, MainContext, MainnetEvm,
16};
17use revmc::{EvmCompiler, EvmCompilerFn, EvmLlvmBackend, Linker, OptimizationLevel};
18use std::{
19    cell::RefCell,
20    collections::HashMap,
21    mem::ManuallyDrop,
22    path::{Path, PathBuf},
23    sync::{
24        atomic::{AtomicUsize, Ordering},
25        Arc, Mutex, OnceLock,
26    },
27    time::{Duration, Instant},
28};
29use thread_local::ThreadLocal;
30
31// ── Compile mode ────────────────────────────────────────────────────────────
32
/// How to compile and execute bytecodes in the test suite.
///
/// Defaults to [`CompileMode::Interpreter`]. Derives `PartialEq`/`Eq`/`Hash` so modes can be
/// compared directly or used as map keys.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum CompileMode {
    /// Standard interpreter execution (no compilation).
    #[default]
    Interpreter,
    /// JIT-compile bytecodes in-process: pre-state contracts before execution, any other
    /// bytecode the first time it is executed.
    Jit,
    /// AOT-compile bytecodes to shared libraries, then load and execute them. Pre-state
    /// contracts are compiled before execution, any other bytecode the first time it is executed.
    Aot,
}
44
45// ── Compiled contracts ──────────────────────────────────────────────────────
46
47/// Compiled contracts cache mapping bytecode hash to compiled function.
48#[derive(Default)]
49pub struct CompiledContracts {
50    functions: HashMap<B256, EvmCompilerFn>,
51}
52
53impl CompiledContracts {
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    pub fn get(&self, code_hash: &B256) -> Option<EvmCompilerFn> {
59        self.functions.get(code_hash).copied()
60    }
61
62    pub fn insert(&mut self, code_hash: B256, func: EvmCompilerFn) {
63        self.functions.insert(code_hash, func);
64    }
65}
66
67// ── Handler ─────────────────────────────────────────────────────────────────
68
69type StateTestEvm<'a> = MainnetEvm<revm::handler::MainnetContext<&'a mut database::State<EmptyDB>>>;
70type StateTestError = EVMError<EvmDatabaseError<std::convert::Infallible>, InvalidTransaction>;
71
72/// Custom handler that dispatches to compiled functions. All bytecodes —
73/// including runtime-created ones (CREATE/CREATE2) — are JIT-compiled before
74/// execution. Never falls back to the interpreter.
75pub struct CompiledHandler<'a> {
76    pub compiled: &'a CompiledContracts,
77    pub cache: &'a CompileCache,
78    pub spec_id: SpecId,
79}
80
81impl Handler for CompiledHandler<'_> {
82    type Evm = StateTestEvm<'static>;
83    type Error = StateTestError;
84    type HaltReason = HaltReason;
85
    /// Frame loop that executes every frame with compiled code instead of the interpreter.
    ///
    /// Initializes the first frame and returns at once if that already yields a result.
    /// Otherwise it repeatedly runs the top frame's compiled function, initializes any child
    /// frame it requests, and hands finished results back through `frame_return_result` until
    /// that returns the outermost result.
    ///
    /// # Panics
    ///
    /// Panics if bytecode that is not in `self.compiled` fails to compile.
86    fn run_exec_loop(
87        &mut self,
88        evm: &mut Self::Evm,
89        first_frame_input: revm::interpreter::interpreter_action::FrameInit,
90    ) -> Result<FrameResult, Self::Error> {
91        let res = evm.frame_init(first_frame_input)?;
92        if let ItemOrResult::Result(frame_result) = res {
93            return Ok(frame_result);
94        }
95        loop {
96            let call_or_result = {
97                let frame = evm.frame_stack.get();
98                let bytecode_hash = frame.interpreter.bytecode.get_or_calculate_hash();
                // Pre-state contracts were compiled up front and are looked up by hash.
99                let f = if let Some(f) = self.compiled.get(&bytecode_hash) {
100                    f
101                } else {
102                    // Bytecode outside the pre-state (e.g. CREATE/CREATE2 init code or code
                    // deployed during this test): compile it now through the shared cache.
                    // A compilation failure panics; there is no interpreter fallback.
103                    let code = frame.interpreter.bytecode.original_byte_slice();
104                    self.cache
105                        .compile_single(bytecode_hash, code, self.spec_id)
106                        .expect("compilation failed for runtime bytecode")
107                };
108                {
                    // `evm.ctx` is a different field from `evm.frame_stack` (borrowed by
                    // `frame`), so both mutable borrows can coexist.
109                    let ctx = &mut evm.ctx;
                    // NOTE(review): no SAFETY justification was given for this call; presumably
                    // it relies on `f` having been compiled from this frame's bytecode for
                    // `self.spec_id`. Confirm against revmc's `call_with_interpreter` contract.
110                    let action = unsafe { f.call_with_interpreter(&mut frame.interpreter, ctx) };
                    // Turn the action into either a child-frame request or a frame result. A
                    // result also marks the frame as finished.
111                    frame.process_next_action::<_, StateTestError>(ctx, action).inspect(|i| {
112                        if i.is_result() {
113                            frame.set_finished(true);
114                        }
115                    })?
116                }
117            };
            // A child frame that needs execution is run on the next iteration. One that
            // completes during init yields its result directly.
118            let result = match call_or_result {
119                ItemOrResult::Item(init) => match evm.frame_init(init)? {
120                    ItemOrResult::Item(_) => continue,
121                    ItemOrResult::Result(result) => result,
122                },
123                ItemOrResult::Result(result) => result,
124            };
            // Deliver the result to the parent frame; `Some` means the outermost frame is done.
125            if let Some(result) = evm.frame_return_result(result)? {
126                return Ok(result);
127            }
128        }
129    }
130}
131
132// ── Compilation cache ────────────────────────────────────────────────────────
133
134type ClaimedEntry<'a> = (B256, &'a [u8], String, Arc<OnceLock<EvmCompilerFn>>);
135
136/// Thread-safe compilation cache shared across workers.
137pub struct CompileCache {
138    mode: CompileMode,
139    functions: DashMap<(B256, SpecId), Arc<OnceLock<EvmCompilerFn>>>,
140    /// Keep AOT shared libraries alive. Unused for JIT mode.
141    libs: Mutex<Vec<(tempfile::TempDir, libloading::Library)>>,
142    /// `ManuallyDrop` because `ThreadLocal::drop` drops values on the calling thread, but each
143    /// `EvmLlvmBackend` holds a reference to a thread-local LLVM context that only lives on the
144    /// thread that created it. Dropping on a different thread would use-after-free.
145    compiler: ManuallyDrop<ThreadLocal<RefCell<EvmCompiler<EvmLlvmBackend>>>>,
146    n_hits: AtomicUsize,
147    n_misses: AtomicUsize,
148}
149
150impl CompileCache {
151    pub fn new(mode: CompileMode) -> Self {
152        Self {
153            mode,
154            functions: Default::default(),
155            libs: Default::default(),
156            compiler: Default::default(),
157            n_hits: Default::default(),
158            n_misses: Default::default(),
159        }
160    }
161
162    /// Partition a test unit's contracts into cached (already compiled) and
163    /// claimed (this thread must compile them). Uses `DashMap::entry()` to
164    /// atomically distinguish vacant (we compile) from occupied (we wait).
165    fn claim_missing<'a>(
166        &self,
167        unit: &'a TestUnit,
168        spec_id: SpecId,
169    ) -> (CompiledContracts, Vec<ClaimedEntry<'a>>) {
170        use dashmap::mapref::entry::Entry;
171
172        let mut compiled = CompiledContracts::new();
173        let mut claimed = Vec::new();
174
175        for info in unit.pre.values() {
176            if info.code.is_empty() {
177                continue;
178            }
179            let code_hash = keccak256(&info.code);
180            if compiled.get(&code_hash).is_some() {
181                continue;
182            }
183            if claimed.iter().any(|(h, _, _, _): &(B256, _, _, _)| h == &code_hash) {
184                continue;
185            }
186
187            match self.functions.entry((code_hash, spec_id)) {
188                Entry::Occupied(e) => {
189                    let lock = e.get().clone();
190                    drop(e);
191                    // Already compiled or being compiled by another thread.
192                    if let Some(f) = lock.get() {
193                        self.n_hits.fetch_add(1, Ordering::Relaxed);
194                        compiled.insert(code_hash, *f);
195                    }
196                    // Otherwise: another thread is compiling it, wait_for_all handles it.
197                }
198                Entry::Vacant(e) => {
199                    // We're first — claim it.
200                    let lock = Arc::new(OnceLock::new());
201                    e.insert(lock.clone());
202                    self.n_misses.fetch_add(1, Ordering::Relaxed);
203                    claimed.push((
204                        code_hash,
205                        &info.code[..],
206                        format!("contract_{code_hash:x}"),
207                        lock,
208                    ));
209                }
210            }
211        }
212
213        (compiled, claimed)
214    }
215
216    /// Wait for all contracts in a test unit to be compiled, returning the
217    /// fully populated `CompiledContracts`.
218    fn wait_for_all(&self, unit: &TestUnit, spec_id: SpecId) -> CompiledContracts {
219        let mut compiled = CompiledContracts::new();
220        for info in unit.pre.values() {
221            if info.code.is_empty() {
222                continue;
223            }
224            let code_hash = keccak256(&info.code);
225            if compiled.get(&code_hash).is_some() {
226                continue;
227            }
228            if let Some(entry) = self.functions.get(&(code_hash, spec_id)) {
229                let f = entry.value().wait();
230                compiled.insert(code_hash, *f);
231            }
232        }
233        compiled
234    }
235
236    /// Compile all contracts in a test unit, returning the compiled functions.
237    /// Uses the cache's mode (JIT or AOT) to determine compilation strategy.
238    pub fn compile(
239        &self,
240        unit: &TestUnit,
241        spec_id: SpecId,
242    ) -> Result<CompiledContracts, TestErrorKind> {
243        let (mut compiled, claimed) = self.claim_missing(unit, spec_id);
244        if claimed.is_empty() {
245            // Still need to wait for contracts another thread claimed.
246            let rest = self.wait_for_all(unit, spec_id);
247            for (hash, f) in rest.functions {
248                compiled.functions.entry(hash).or_insert(f);
249            }
250            return Ok(compiled);
251        }
252
253        match self.mode {
254            CompileMode::Jit => self.compile_jit_batch(&claimed, &mut compiled, spec_id)?,
255            CompileMode::Aot => self.compile_aot_batch(&claimed, &mut compiled, spec_id)?,
256            CompileMode::Interpreter => unreachable!(),
257        }
258
259        // Wait for contracts claimed by other threads.
260        let rest = self.wait_for_all(unit, spec_id);
261        for (hash, f) in rest.functions {
262            compiled.functions.entry(hash).or_insert(f);
263        }
264
265        Ok(compiled)
266    }
267
268    /// Compile a single bytecode (e.g. from CREATE/CREATE2 at runtime).
269    /// Claims the cache slot first, then compiles via the batch path.
270    fn compile_single(
271        &self,
272        code_hash: B256,
273        code: &[u8],
274        spec_id: SpecId,
275    ) -> Result<EvmCompilerFn, TestErrorKind> {
276        use dashmap::mapref::entry::Entry;
277
278        match self.functions.entry((code_hash, spec_id)) {
279            Entry::Occupied(e) => {
280                let lock = e.get().clone();
281                drop(e);
282                let f = lock.wait();
283                self.n_hits.fetch_add(1, Ordering::Relaxed);
284                Ok(*f)
285            }
286            Entry::Vacant(e) => {
287                let lock = Arc::new(OnceLock::new());
288                e.insert(lock.clone());
289                self.n_misses.fetch_add(1, Ordering::Relaxed);
290
291                let name = format!("runtime_{code_hash:x}");
292                let claimed = vec![(code_hash, code, name, lock)];
293                let mut compiled = CompiledContracts::new();
294
295                match self.mode {
296                    CompileMode::Jit => self.compile_jit_batch(&claimed, &mut compiled, spec_id)?,
297                    CompileMode::Aot => self.compile_aot_batch(&claimed, &mut compiled, spec_id)?,
298                    CompileMode::Interpreter => unreachable!(),
299                }
300
301                Ok(compiled.get(&code_hash).unwrap())
302            }
303        }
304    }
305
306    fn compile_jit_batch(
307        &self,
308        claimed: &[ClaimedEntry<'_>],
309        compiled: &mut CompiledContracts,
310        spec_id: SpecId,
311    ) -> Result<(), TestErrorKind> {
312        let mut compiler = self.compiler.get_or(|| make_compiler(false)).borrow_mut();
313
314        let mut func_ids = Vec::new();
315        for (code_hash, code, name, _) in claimed {
316            let func_id = compiler
317                .translate(name, *code, spec_id)
318                .map_err(|e| TestErrorKind::CompilationError(format!("translate {name}: {e}")))?;
319            func_ids.push((*code_hash, func_id));
320        }
321
322        for (i, (code_hash, func_id)) in func_ids.into_iter().enumerate() {
323            let func = unsafe { compiler.jit_function(func_id) }.map_err(|e| {
324                TestErrorKind::CompilationError(format!("jit {:x}: {e}", code_hash))
325            })?;
326            claimed[i].3.set(func).ok();
327            compiled.insert(code_hash, func);
328        }
329
330        let _ = compiler.clear_ir();
331
332        Ok(())
333    }
334
335    fn compile_aot_batch(
336        &self,
337        claimed: &[ClaimedEntry<'_>],
338        compiled: &mut CompiledContracts,
339        spec_id: SpecId,
340    ) -> Result<(), TestErrorKind> {
341        let mut compiler = self.compiler.get_or(|| make_compiler(true)).borrow_mut();
342
343        let mut names: Vec<(B256, String)> = Vec::new();
344        for (code_hash, code, name, _) in claimed {
345            compiler
346                .translate(name, *code, spec_id)
347                .map_err(|e| TestErrorKind::CompilationError(format!("translate {name}: {e}")))?;
348            names.push((*code_hash, name.clone()));
349        }
350
351        let tmp_dir = tempfile::tempdir()
352            .map_err(|e| TestErrorKind::CompilationError(format!("tempdir: {e}")))?;
353        let obj_path = tmp_dir.path().join("a.o");
354        let so_path = tmp_dir.path().join("a.so");
355
356        compiler
357            .write_object_to_file(&obj_path)
358            .map_err(|e| TestErrorKind::CompilationError(format!("write object: {e}")))?;
359
360        let linker = Linker::new();
361        linker
362            .link(&so_path, [obj_path.to_str().unwrap()])
363            .map_err(|e| TestErrorKind::CompilationError(format!("link: {e}")))?;
364
365        let lib = unsafe { libloading::Library::new(&so_path) }
366            .map_err(|e| TestErrorKind::CompilationError(format!("load: {e}")))?;
367
368        for (i, (code_hash, name)) in names.iter().enumerate() {
369            let f: libloading::Symbol<'_, EvmCompilerFn> = unsafe { lib.get(name.as_bytes()) }
370                .map_err(|e| TestErrorKind::CompilationError(format!("symbol {name}: {e}")))?;
371            claimed[i].3.set(*f).ok();
372            compiled.insert(*code_hash, *f);
373        }
374
375        self.libs.lock().unwrap().push((tmp_dir, lib));
376
377        let _ = compiler.clear_ir();
378
379        Ok(())
380    }
381
382    pub fn print_stats(&self) {
383        let hits = self.n_hits.load(Ordering::Relaxed);
384        let misses = self.n_misses.load(Ordering::Relaxed);
385        let total = hits + misses;
386        if total > 0 {
387            let label = match self.mode {
388                CompileMode::Jit => "JIT",
389                CompileMode::Aot => "AOT",
390                CompileMode::Interpreter => unreachable!(),
391            };
392            let rate = hits as f64 / total as f64 * 100.0;
393            let n_libs = self.libs.lock().unwrap().len();
394            if n_libs > 0 {
395                println!(
396                    "{label} cache: {total} lookups, {hits} hits, {misses} misses ({rate:.1}% hit rate), {} unique, {n_libs} shared libs",
397                    self.functions.len()
398                );
399            } else {
400                println!(
401                    "{label} cache: {total} lookups, {hits} hits, {misses} misses ({rate:.1}% hit rate), {} unique",
402                    self.functions.len()
403                );
404            }
405        }
406    }
407}
408
409fn make_compiler(aot: bool) -> RefCell<EvmCompiler<EvmLlvmBackend>> {
410    RefCell::new(EvmCompiler::new(EvmLlvmBackend::new(aot, OptimizationLevel::Aggressive).unwrap()))
411}
412
413// ── Compiled test execution ─────────────────────────────────────────────────
414
415pub struct CompiledTestContext<'a> {
416    pub compiled: &'a CompiledContracts,
417    pub cache: &'a CompileCache,
418    pub spec_id: SpecId,
419    pub test: &'a revm::statetest_types::Test,
420    pub unit: &'a TestUnit,
421    pub name: &'a str,
422    pub cfg: &'a CfgEnv,
423    pub block: &'a BlockEnv,
424    pub tx: &'a TxEnv,
425    pub cache_state: &'a database::CacheState,
426    pub elapsed: &'a Arc<Mutex<Duration>>,
427}
428
429/// Execute a single test using compiled functions via the custom handler.
430pub fn execute_single_test_compiled(ctx: CompiledTestContext<'_>) -> Result<(), TestErrorKind> {
431    let prestate = ctx.cache_state.clone();
432    let mut state =
433        database::State::builder().with_cached_prestate(prestate).with_bundle_update().build();
434
435    let timer = Instant::now();
436    // SAFETY: The handler and evm do not outlive `state`. The `'static` in
437    // `StateTestEvm<'static>` is required by the `Handler` trait but we
438    // guarantee the borrow is valid for the duration of `handler.run`.
439    let exec_result = unsafe {
440        let db_ref = &mut *(&mut state as *mut database::State<EmptyDB>);
441        let evm_context = Context::mainnet()
442            .with_block(ctx.block.clone())
443            .with_tx(ctx.tx.clone())
444            .with_cfg(ctx.cfg.clone())
445            .with_db(db_ref);
446        let mut handler =
447            CompiledHandler { compiled: ctx.compiled, cache: ctx.cache, spec_id: ctx.spec_id };
448        let mut evm = evm_context.build_mainnet();
449        let result = handler.run(&mut evm);
450        if result.is_ok() {
451            let s = evm.ctx.journaled_state.finalize();
452            DatabaseCommit::commit(&mut evm.ctx.journaled_state.database, s);
453        }
454        result
455    };
456    let db = &mut state;
457    *ctx.elapsed.lock().unwrap() += timer.elapsed();
458
459    check_evm_execution(
460        ctx.test,
461        ctx.unit.out.as_ref(),
462        ctx.name,
463        &exec_result,
464        db,
465        *ctx.cfg.spec(),
466        false,
467    )
468}
469
470// ── Suite-level execution (compiled) ─────────────────────────────────────────
471
472/// Execute a single test suite file, compiling all contracts before execution.
473fn execute_test_suite_compiled(
474    path: &Path,
475    elapsed: &Arc<Mutex<Duration>>,
476    cache: &CompileCache,
477) -> Result<(), TestError> {
478    if skip_test(path) {
479        return Ok(());
480    }
481
482    let s = std::fs::read_to_string(path).unwrap();
483    let path_str = path.to_string_lossy().into_owned();
484    let suite: TestSuite = serde_json::from_str(&s).map_err(|e| TestError {
485        name: "Unknown".to_string(),
486        path: path_str.clone(),
487        kind: e.into(),
488    })?;
489
490    for (name, unit) in suite.0 {
491        let cache_state = unit.state();
492
493        let mut cfg = CfgEnv::default();
494        cfg.chain_id = unit.env.current_chain_id.unwrap_or(U256::ONE).try_into().unwrap_or(1);
495
496        for (spec_name, tests) in &unit.post {
497            if *spec_name == SpecName::Constantinople {
498                continue;
499            }
500
501            let spec_id = spec_name.to_spec_id();
502            cfg.set_spec_and_mainnet_gas_params(spec_id);
503
504            if cfg.spec().is_enabled_in(SpecId::OSAKA) {
505                cfg.set_max_blobs_per_tx(6);
506            } else if cfg.spec().is_enabled_in(SpecId::PRAGUE) {
507                cfg.set_max_blobs_per_tx(9);
508            } else {
509                cfg.set_max_blobs_per_tx(6);
510            }
511
512            let block = unit.block_env(&mut cfg);
513
514            let compiled = cache.compile(&unit, spec_id).map_err(|e| TestError {
515                name: name.clone(),
516                path: path_str.clone(),
517                kind: e,
518            })?;
519
520            for test in tests.iter() {
521                let tx = match test.tx_env(&unit) {
522                    Ok(tx) => tx,
523                    Err(_) if test.expect_exception.is_some() => continue,
524                    Err(_) => {
525                        return Err(TestError {
526                            name,
527                            path: path_str,
528                            kind: TestErrorKind::UnknownPrivateKey(unit.transaction.secret_key),
529                        });
530                    }
531                };
532
533                let result = execute_single_test_compiled(CompiledTestContext {
534                    compiled: &compiled,
535                    cache,
536                    spec_id,
537                    test,
538                    unit: &unit,
539                    name: &name,
540                    cfg: &cfg,
541                    block: &block,
542                    tx: &tx,
543                    cache_state: &cache_state,
544                    elapsed,
545                });
546
547                if let Err(e) = result {
548                    return Err(TestError { name, path: path_str, kind: e });
549                }
550            }
551        }
552    }
553    Ok(())
554}
555
556// ── Top-level runner ────────────────────────────────────────────────────────
557
558fn run_test_worker(
559    state: TestRunnerState,
560    keep_going: bool,
561    mode: CompileMode,
562    cache: Option<&CompileCache>,
563) -> Result<(), TestError> {
564    loop {
565        if !keep_going && state.n_errors.load(Ordering::SeqCst) > 0 {
566            return Ok(());
567        }
568
569        let Some(test_path) = state.next_test() else {
570            return Ok(());
571        };
572
573        let result = match mode {
574            CompileMode::Interpreter => {
575                execute_test_suite(&test_path, &state.elapsed, false, false)
576            }
577            CompileMode::Jit | CompileMode::Aot => {
578                execute_test_suite_compiled(&test_path, &state.elapsed, cache.unwrap())
579            }
580        };
581
582        state.console_bar.inc(1);
583
584        if let Err(err) = result {
585            state.n_errors.fetch_add(1, Ordering::SeqCst);
586            if !keep_going {
587                return Err(err);
588            }
589        }
590    }
591}
592
593/// Run all test files.
594pub fn run(
595    test_files: Vec<PathBuf>,
596    single_thread: bool,
597    keep_going: bool,
598    mode: CompileMode,
599) -> Result<(), TestError> {
600    let _ = tracing_subscriber::fmt::try_init();
601
602    let n_files = test_files.len();
603    let state = TestRunnerState::new(test_files);
604
605    let cache = match mode {
606        CompileMode::Interpreter => None,
607        CompileMode::Jit | CompileMode::Aot => Some(Arc::new(CompileCache::new(mode))),
608    };
609
610    let num_threads = if single_thread {
611        1
612    } else {
613        match std::thread::available_parallelism() {
614            Ok(n) => n.get().min(n_files),
615            Err(_) => 1,
616        }
617    };
618
619    let mut handles = Vec::with_capacity(num_threads);
620    for i in 0..num_threads {
621        let state = state.clone();
622        let cache = cache.clone();
623
624        let thread = std::thread::Builder::new()
625            .name(format!("runner-{i}"))
626            .spawn(move || run_test_worker(state, keep_going, mode, cache.as_deref()))
627            .unwrap();
628
629        handles.push(thread);
630    }
631
632    let mut thread_errors = Vec::new();
633    for (i, handle) in handles.into_iter().enumerate() {
634        match handle.join() {
635            Ok(Ok(())) => {}
636            Ok(Err(e)) => thread_errors.push(e),
637            Err(_) => thread_errors.push(TestError {
638                name: format!("thread {i} panicked"),
639                path: String::new(),
640                kind: TestErrorKind::Panic,
641            }),
642        }
643    }
644
645    state.console_bar.finish();
646
647    println!(
648        "Finished execution. Total CPU time: {:.6}s",
649        state.elapsed.lock().unwrap().as_secs_f64()
650    );
651
652    if let Some(cache) = &cache {
653        cache.print_stats();
654    }
655
656    let n_errors = state.n_errors.load(Ordering::SeqCst);
657    let n_thread_errors = thread_errors.len();
658
659    if n_errors == 0 && n_thread_errors == 0 {
660        println!("All tests passed!");
661        Ok(())
662    } else {
663        println!("Encountered {n_errors} errors out of {n_files} total tests");
664
665        if n_thread_errors == 0 {
666            std::process::exit(1);
667        }
668
669        if n_thread_errors > 1 {
670            println!("{n_thread_errors} threads returned an error, out of {num_threads} total:");
671            for error in &thread_errors {
672                println!("{error}");
673            }
674        }
675        Err(thread_errors.swap_remove(0))
676    }
677}