elemaudio_rs/
engine.rs

1//! Generic DSP engine that pairs a [`DspGraph`] implementation with a
2//! [`Runtime`], handling graph mounting, parameter diffing, and audio
3//! processing.
4//!
5//! # Usage
6//!
7//! 1. Implement [`DspGraph`] for your graph — a pure function that
8//!    builds `Vec<Node>` from parameters.
9//! 2. Create an [`Engine`] at activation time.
10//! 3. Call [`Engine::set_params`] on parameter changes — the engine
11//!    auto-discovers keyed consts and native node props from the graph
12//!    and emits only the minimal update batches.
13//! 4. Call [`Engine::process`] on every audio block.
14//!
15//! # Example
16//!
17//! ```ignore
18//! use elemaudio_rs::{el, extra, Node};
19//! use elemaudio_rs::engine::{DspGraph, Engine};
20//! use serde_json::json;
21//!
22//! struct MyDelay;
23//!
24//! impl DspGraph for MyDelay {
25//!     type Params = f64; // just delay_ms for this example
26//!
27//!     fn build(params: &f64) -> Vec<Node> {
28//!         let delay = el::const_with_key("delay", *params);
29//!         let fb = el::const_(0.3);
30//!         let input = el::r#in(json!({"channel": 0}), None);
31//!         vec![extra::stride_delay(json!({"maxDelayMs": 1500}), delay, fb, input)]
32//!     }
33//! }
34//!
35//! let engine = Engine::<MyDelay>::new(44100.0, 512, &250.0).unwrap();
36//! ```
37
38use std::collections::HashMap;
39use std::marker::PhantomData;
40
41use crate::graph::{Graph, MountedGraph, Node};
42use crate::runtime::{Instruction, InstructionBatch, Runtime};
43
44// Re-export legacy types for backward compatibility during migration.
45pub use legacy::{KeyedConst, NativeProp};
46
/// A pure graph-building function.
///
/// Implement this trait for each DSP topology. The engine handles
/// mounting, parameter diffing, and runtime delegation. Keyed consts
/// and native node props are discovered automatically by walking the
/// nodes returned from [`DspGraph::build`] — no manual declarations needed.
///
/// `build` is an associated function without a `self` receiver, so an
/// implementor carries no per-instance state: everything that shapes the
/// graph has to come in through [`DspGraph::Params`].
pub trait DspGraph {
    /// Parameter snapshot type (e.g., a struct with delay_ms, feedback, etc.).
    ///
    /// Must be `Clone` because the engine keeps its own copy of the most
    /// recently applied parameters.
    type Params: Clone;

    /// Build the graph from parameters. Returns output root nodes.
    ///
    /// Called once at activation to mount the graph, and again on each
    /// `set_params` call to diff keyed consts and native props.
    fn build(params: &Self::Params) -> Vec<Node>;
}
63
/// Generic DSP engine that owns a [`Runtime`] and a mounted graph.
///
/// Besides the runtime and the mounted graph, the engine keeps snapshots of
/// the keyed const values and native node props seen in the last call to
/// [`DspGraph::build`], so that a later build can be compared against them.
pub struct Engine<G: DspGraph> {
    /// Runtime that receives instruction batches and processes audio.
    runtime: Runtime,
    /// Graph mounted at construction time.
    mounted: MountedGraph,
    /// Keyed const values snapshot from the last build, for change detection.
    keyed_consts: HashMap<String, f64>,
    /// Native node props snapshot from the last build, for change detection.
    /// Key: (node_kind, prop_name) → value. Used with node ID index for SetProperty.
    native_props: HashMap<(String, String), serde_json::Value>,
    /// Native node IDs grouped by kind, for SetProperty targeting.
    /// Filled once from the mounted graph when the engine is created.
    native_node_ids: HashMap<String, Vec<i32>>,
    /// Current parameters.
    current: G::Params,
    /// Marker tying the engine to its graph type `G` without storing a value of it.
    _phantom: PhantomData<G>,
}
79
80impl<G: DspGraph> Engine<G> {
81    /// Create the engine, build and mount the graph, apply it to the runtime.
82    pub fn new(sample_rate: f64, buffer_size: usize, params: &G::Params) -> Result<Self, String> {
83        let runtime = Runtime::new()
84            .sample_rate(sample_rate)
85            .buffer_size(buffer_size)
86            .call()
87            .map_err(|e| format!("failed to create runtime: {e}"))?;
88
89        let roots = G::build(params);
90        let keyed_consts = collect_keyed_consts(&roots);
91        let native_props = collect_native_props(&roots);
92
93        let graph = Graph::new().render(roots);
94        let mounted = graph
95            .mount()
96            .map_err(|e| format!("graph mount failed: {e}"))?;
97
98        runtime
99            .apply_instructions(mounted.batch())
100            .map_err(|e| format!("failed to apply initial graph: {e}"))?;
101
102        let mut native_node_ids: HashMap<String, Vec<i32>> = HashMap::new();
103        for (_, node) in mounted.all_nodes() {
104            native_node_ids
105                .entry(node.kind().to_string())
106                .or_default()
107                .push(node.id());
108        }
109
110        log::info!(
111            "Engine<{}> created: {} keyed consts, {} native prop groups",
112            std::any::type_name::<G>(),
113            keyed_consts.len(),
114            native_props.len(),
115        );
116
117        Ok(Self {
118            runtime,
119            mounted,
120            keyed_consts,
121            native_props,
122            native_node_ids,
123            current: params.clone(),
124            _phantom: PhantomData,
125        })
126    }
127
128    /// Update parameters. Rebuilds the graph declaratively, diffs keyed
129    /// consts and native node props against the previous build, and emits
130    /// minimal instruction batches for any changes.
131    pub fn set_params(&mut self, params: &G::Params) {
132        let new_roots = G::build(params);
133        let new_keyed_consts = collect_keyed_consts(&new_roots);
134        let new_native_props = collect_native_props(&new_roots);
135
136        // Diff keyed consts.
137        for (key, &new_val) in &new_keyed_consts {
138            let changed = self
139                .keyed_consts
140                .get(key)
141                .map_or(true, |&old| (old - new_val).abs() > f64::EPSILON);
142
143            if changed {
144                if let Some(batch) = self.mounted.set_const_value(key, new_val) {
145                    let _ = self.runtime.apply_instructions(&batch);
146                }
147            }
148        }
149
150        // Diff native props.
151        for ((kind, prop), new_val) in &new_native_props {
152            let changed = self
153                .native_props
154                .get(&(kind.clone(), prop.clone()))
155                .map_or(true, |old| old != new_val);
156
157            if changed {
158                if let Some(ids) = self.native_node_ids.get(kind) {
159                    let mut batch = InstructionBatch::new();
160                    for &id in ids {
161                        batch.push(Instruction::SetProperty {
162                            node_id: id,
163                            property: prop.clone(),
164                            value: new_val.clone(),
165                        });
166                    }
167                    batch.push(Instruction::CommitUpdates);
168                    let _ = self.runtime.apply_instructions(&batch);
169                }
170            }
171        }
172
173        self.keyed_consts = new_keyed_consts;
174        self.native_props = new_native_props;
175        self.current = params.clone();
176    }
177
178    /// Process a block of audio.
179    pub fn process(
180        &self,
181        num_samples: usize,
182        inputs: &[&[f64]],
183        outputs: &mut [&mut [f64]],
184    ) -> crate::Result<()> {
185        self.runtime.process(num_samples, inputs, outputs)
186    }
187
188    /// Returns a reference to the mounted graph.
189    pub fn mounted(&self) -> &MountedGraph {
190        &self.mounted
191    }
192
193    /// Returns a reference to the underlying runtime.
194    pub fn runtime(&self) -> &Runtime {
195        &self.runtime
196    }
197
198    /// Returns the current parameters.
199    pub fn params(&self) -> &G::Params {
200        &self.current
201    }
202}
203
204// ---- Graph tree walkers -----------------------------------------------
205
206/// Walk the node tree and collect all keyed `const` nodes with their values.
207fn collect_keyed_consts(roots: &[Node]) -> HashMap<String, f64> {
208    let mut result = HashMap::new();
209    for root in roots {
210        walk_keyed_consts(root, &mut result);
211    }
212    result
213}
214
215fn walk_keyed_consts(node: &Node, result: &mut HashMap<String, f64>) {
216    if node.kind() == "const" {
217        if let serde_json::Value::Object(props) = node.props() {
218            if let (Some(key), Some(value)) = (
219                props.get("key").and_then(|v| v.as_str()),
220                props.get("value").and_then(|v| v.as_f64()),
221            ) {
222                result.insert(key.to_string(), value);
223            }
224        }
225    }
226    for child in node.children() {
227        walk_keyed_consts(child, result);
228    }
229}
230
231/// Walk the node tree and collect all non-const, non-structural node props.
232/// Returns a map of (node_kind, prop_name) → prop_value.
233///
234/// Excludes: `const`, `root`, `in`, `add`, `sub`, `mul`, `div` (structural
235/// nodes whose props are either empty or don't change between builds).
236/// Also excludes the `key` prop itself (it's identity, not a parameter).
237fn collect_native_props(roots: &[Node]) -> HashMap<(String, String), serde_json::Value> {
238    let mut result = HashMap::new();
239    for root in roots {
240        walk_native_props(root, &mut result);
241    }
242    result
243}
244
/// Node kinds that `walk_native_props` ignores when collecting props.
///
/// Any kind not listed here is treated as a native node whose props are
/// tracked for change detection.
#[rustfmt::skip]
const STRUCTURAL_KINDS: &[&str] = &[
    "const", "root", "in",
    "add", "sub", "mul", "div", "mod",
    "sin", "cos", "tan", "tanh",
    "exp", "log", "log2", "sqrt",
    "abs", "ceil", "floor", "round",
    "le", "ge", "eq", "and", "or",
    "pow", "min", "max",
    "pole", "z", "prewarp", "phasor", "sphasor",
];
250
251fn walk_native_props(node: &Node, result: &mut HashMap<(String, String), serde_json::Value>) {
252    let kind = node.kind();
253    if !STRUCTURAL_KINDS.contains(&kind) {
254        if let serde_json::Value::Object(props) = node.props() {
255            for (prop_name, prop_value) in props {
256                if prop_name == "key" {
257                    continue;
258                }
259                result.insert((kind.to_string(), prop_name.clone()), prop_value.clone());
260            }
261        }
262    }
263    for child in node.children() {
264        walk_native_props(child, result);
265    }
266}
267
268// ---- Legacy types (backward compatibility) ----------------------------
269
/// Legacy declaration types, retained so existing callers keep compiling
/// during migration. New code should not use these — the engine
/// auto-discovers keyed consts and native props from the graph tree.
mod legacy {
    /// A keyed const declaration (legacy — auto-discovered by engine).
    #[derive(Debug, Clone)]
    pub struct KeyedConst {
        /// Key of the const node.
        pub key: String,
        /// Value of the const node.
        pub value: f64,
    }

    impl KeyedConst {
        /// Build a declaration from anything convertible into a `String` key
        /// and a value.
        pub fn new(key: impl Into<String>, value: f64) -> Self {
            let key = key.into();
            Self { key, value }
        }
    }

    /// A native node property declaration (legacy — auto-discovered by engine).
    #[derive(Debug, Clone)]
    pub struct NativeProp {
        /// Kind of the node the property belongs to.
        pub node_kind: String,
        /// Name of the property.
        pub property: String,
        /// Value of the property.
        pub value: f64,
    }

    impl NativeProp {
        /// Build a declaration from a node kind, a property name (both
        /// convertible into `String`) and a value.
        pub fn new(node_kind: impl Into<String>, property: impl Into<String>, value: f64) -> Self {
            let node_kind = node_kind.into();
            let property = property.into();
            Self {
                node_kind,
                property,
                value,
            }
        }
    }
}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authoring::{el, extra};
    use serde_json::json;

    /// Parameter snapshot consumed by [`TestDelayGraph`].
    #[derive(Debug, Clone)]
    struct TestParams {
        delay_ms: f64,
        fb: f64,
        transition_ms: f64,
    }

    /// Shorthand for building a [`TestParams`] value.
    fn test_params(delay_ms: f64, fb: f64, transition_ms: f64) -> TestParams {
        TestParams {
            delay_ms,
            fb,
            transition_ms,
        }
    }

    /// Owned lookup key for the engine's native prop snapshot.
    fn prop_key(kind: &str, prop: &str) -> (String, String) {
        (kind.to_string(), prop.to_string())
    }

    /// Assert that two floats differ by less than `f64::EPSILON`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < f64::EPSILON);
    }

    /// Single stride delay fed from input channel 0, with keyed delay and
    /// feedback consts and a `transitionMs` prop taken from the params.
    struct TestDelayGraph;

    impl DspGraph for TestDelayGraph {
        type Params = TestParams;

        fn build(p: &TestParams) -> Vec<Node> {
            let props = json!({ "maxDelayMs": 500, "transitionMs": p.transition_ms });
            let delay = el::const_with_key("delay", p.delay_ms);
            let fb = el::const_with_key("fb", p.fb);
            let input = el::r#in(json!({"channel": 0}), None);
            vec![extra::stride_delay(props, delay, fb, input)]
        }
    }

    #[test]
    fn engine_creates_and_processes() {
        const BLOCK: usize = 512;

        let params = test_params(50.0, 0.0, 10.0);
        let engine =
            Engine::<TestDelayGraph>::new(44100.0, BLOCK, &params).expect("engine creation");

        // Feed one block that starts with a unit impulse.
        let mut impulse = vec![0.0_f64; BLOCK];
        impulse[0] = 1.0;
        let mut first_out = vec![0.0_f64; BLOCK];
        {
            let inputs = [impulse.as_slice()];
            let mut outputs = [first_out.as_mut_slice()];
            engine
                .process(BLOCK, &inputs, &mut outputs)
                .expect("process");
        }

        // Then feed silence until the delayed impulse shows up, giving up
        // after 20 blocks.
        let silence = vec![0.0_f64; BLOCK];
        let found = (0..20).any(|_| {
            let inputs = [silence.as_slice()];
            let mut block = vec![0.0_f64; BLOCK];
            let mut outputs = [block.as_mut_slice()];
            engine
                .process(BLOCK, &inputs, &mut outputs)
                .expect("process");
            outputs[0].iter().any(|&s| s.abs() > 1e-10)
        });
        assert!(found, "engine should produce delayed output");
    }

    #[test]
    fn engine_auto_discovers_keyed_consts() {
        let params = test_params(50.0, 0.3, 10.0);
        let engine = Engine::<TestDelayGraph>::new(44100.0, 64, &params).expect("engine creation");

        assert_eq!(engine.keyed_consts.len(), 2);
        assert_close(engine.keyed_consts["delay"], 50.0);
        assert_close(engine.keyed_consts["fb"], 0.3);
    }

    #[test]
    fn engine_auto_discovers_native_props() {
        let params = test_params(50.0, 0.0, 10.0);
        let engine = Engine::<TestDelayGraph>::new(44100.0, 64, &params).expect("engine creation");

        // stridedelay has props: maxDelayMs, transitionMs, bigLeapMode
        for prop in ["transitionMs", "maxDelayMs"] {
            assert!(engine
                .native_props
                .contains_key(&prop_key("stridedelay", prop)));
        }
    }

    #[test]
    fn engine_set_params_diffs_automatically() {
        let initial = test_params(50.0, 0.0, 10.0);
        let mut engine =
            Engine::<TestDelayGraph>::new(44100.0, 64, &initial).expect("engine creation");

        let updated = test_params(100.0, 0.3, 20.0);
        engine.set_params(&updated);

        assert_close(engine.keyed_consts["delay"], 100.0);
        assert_close(engine.keyed_consts["fb"], 0.3);
        assert_eq!(
            engine.native_props[&prop_key("stridedelay", "transitionMs")],
            json!(20.0)
        );
    }

    #[test]
    fn engine_no_update_when_params_unchanged() {
        let params = test_params(50.0, 0.0, 10.0);
        let mut engine =
            Engine::<TestDelayGraph>::new(44100.0, 64, &params).expect("engine creation");

        engine.set_params(&params);
        assert_close(engine.keyed_consts["delay"], 50.0);
    }

    #[test]
    fn engine_rejects_duplicate_keys() {
        /// Graph that reuses the key `dup` for two different consts.
        struct BadGraph;

        impl DspGraph for BadGraph {
            type Params = ();

            fn build(_: &()) -> Vec<Node> {
                let a = el::const_with_key("dup", 1.0);
                let b = el::const_with_key("dup", 2.0);
                vec![el::add((a, b))]
            }
        }

        // `Engine` has no `Debug` impl, so the error is extracted by hand
        // instead of through `unwrap_err`.
        let err = match Engine::<BadGraph>::new(44100.0, 64, &()) {
            Ok(_) => panic!("engine creation should fail for duplicate keys"),
            Err(message) => message,
        };
        assert!(err.contains("duplicate"), "got: {err}");
    }
}
465        assert!(err.contains("duplicate"), "got: {err}");
466    }
467}