1use std::collections::HashMap;
39use std::marker::PhantomData;
40
41use crate::graph::{Graph, MountedGraph, Node};
42use crate::runtime::{Instruction, InstructionBatch, Runtime};
43
44pub use legacy::{KeyedConst, NativeProp};
46
/// A parameterised DSP graph description that [`Engine`] mounts and keeps in
/// sync with a parameter set.
pub trait DspGraph {
    /// Parameter set the graph is built from. `Engine` stores a clone of the
    /// most recently applied value.
    type Params: Clone;

    /// Builds the root nodes of the graph for `params`.
    ///
    /// `Engine` calls this once in `Engine::new` and again on every
    /// `Engine::set_params`. Between builds, `set_params` only pushes changed
    /// keyed-`const` values and changed props of non-structural nodes. It never
    /// re-renders or re-mounts, so the graph's structure should be the same for
    /// every `params`.
    fn build(params: &Self::Params) -> Vec<Node>;
}
63
/// Owns a [`Runtime`] and the graph mounted from a [`DspGraph`], applying
/// parameter changes by diffing successive builds of the graph.
pub struct Engine<G: DspGraph> {
    /// Runtime that receives instruction batches and processes audio.
    runtime: Runtime,
    /// Graph mounted at construction; translates keyed-const changes into
    /// instruction batches.
    mounted: MountedGraph,
    /// Keyed `const` values (key -> value) from the most recent build.
    keyed_consts: HashMap<String, f64>,
    /// Props of non-structural nodes from the most recent build, keyed by
    /// (node kind, prop name).
    native_props: HashMap<(String, String), serde_json::Value>,
    /// Ids of every mounted node grouped by node kind. Despite the name, this
    /// includes structural kinds as well.
    native_node_ids: HashMap<String, Vec<i32>>,
    /// Parameters most recently passed to `new` / `set_params`.
    current: G::Params,
    _phantom: PhantomData<G>,
}
79
80impl<G: DspGraph> Engine<G> {
81 pub fn new(sample_rate: f64, buffer_size: usize, params: &G::Params) -> Result<Self, String> {
83 let runtime = Runtime::new()
84 .sample_rate(sample_rate)
85 .buffer_size(buffer_size)
86 .call()
87 .map_err(|e| format!("failed to create runtime: {e}"))?;
88
89 let roots = G::build(params);
90 let keyed_consts = collect_keyed_consts(&roots);
91 let native_props = collect_native_props(&roots);
92
93 let graph = Graph::new().render(roots);
94 let mounted = graph
95 .mount()
96 .map_err(|e| format!("graph mount failed: {e}"))?;
97
98 runtime
99 .apply_instructions(mounted.batch())
100 .map_err(|e| format!("failed to apply initial graph: {e}"))?;
101
102 let mut native_node_ids: HashMap<String, Vec<i32>> = HashMap::new();
103 for (_, node) in mounted.all_nodes() {
104 native_node_ids
105 .entry(node.kind().to_string())
106 .or_default()
107 .push(node.id());
108 }
109
110 log::info!(
111 "Engine<{}> created: {} keyed consts, {} native prop groups",
112 std::any::type_name::<G>(),
113 keyed_consts.len(),
114 native_props.len(),
115 );
116
117 Ok(Self {
118 runtime,
119 mounted,
120 keyed_consts,
121 native_props,
122 native_node_ids,
123 current: params.clone(),
124 _phantom: PhantomData,
125 })
126 }
127
128 pub fn set_params(&mut self, params: &G::Params) {
132 let new_roots = G::build(params);
133 let new_keyed_consts = collect_keyed_consts(&new_roots);
134 let new_native_props = collect_native_props(&new_roots);
135
136 for (key, &new_val) in &new_keyed_consts {
138 let changed = self
139 .keyed_consts
140 .get(key)
141 .map_or(true, |&old| (old - new_val).abs() > f64::EPSILON);
142
143 if changed {
144 if let Some(batch) = self.mounted.set_const_value(key, new_val) {
145 let _ = self.runtime.apply_instructions(&batch);
146 }
147 }
148 }
149
150 for ((kind, prop), new_val) in &new_native_props {
152 let changed = self
153 .native_props
154 .get(&(kind.clone(), prop.clone()))
155 .map_or(true, |old| old != new_val);
156
157 if changed {
158 if let Some(ids) = self.native_node_ids.get(kind) {
159 let mut batch = InstructionBatch::new();
160 for &id in ids {
161 batch.push(Instruction::SetProperty {
162 node_id: id,
163 property: prop.clone(),
164 value: new_val.clone(),
165 });
166 }
167 batch.push(Instruction::CommitUpdates);
168 let _ = self.runtime.apply_instructions(&batch);
169 }
170 }
171 }
172
173 self.keyed_consts = new_keyed_consts;
174 self.native_props = new_native_props;
175 self.current = params.clone();
176 }
177
178 pub fn process(
180 &self,
181 num_samples: usize,
182 inputs: &[&[f64]],
183 outputs: &mut [&mut [f64]],
184 ) -> crate::Result<()> {
185 self.runtime.process(num_samples, inputs, outputs)
186 }
187
188 pub fn mounted(&self) -> &MountedGraph {
190 &self.mounted
191 }
192
193 pub fn runtime(&self) -> &Runtime {
195 &self.runtime
196 }
197
198 pub fn params(&self) -> &G::Params {
200 &self.current
201 }
202}
203
204fn collect_keyed_consts(roots: &[Node]) -> HashMap<String, f64> {
208 let mut result = HashMap::new();
209 for root in roots {
210 walk_keyed_consts(root, &mut result);
211 }
212 result
213}
214
215fn walk_keyed_consts(node: &Node, result: &mut HashMap<String, f64>) {
216 if node.kind() == "const" {
217 if let serde_json::Value::Object(props) = node.props() {
218 if let (Some(key), Some(value)) = (
219 props.get("key").and_then(|v| v.as_str()),
220 props.get("value").and_then(|v| v.as_f64()),
221 ) {
222 result.insert(key.to_string(), value);
223 }
224 }
225 }
226 for child in node.children() {
227 walk_keyed_consts(child, result);
228 }
229}
230
231fn collect_native_props(roots: &[Node]) -> HashMap<(String, String), serde_json::Value> {
238 let mut result = HashMap::new();
239 for root in roots {
240 walk_native_props(root, &mut result);
241 }
242 result
243}
244
/// Node kinds that `walk_native_props` skips: props on these nodes are never
/// collected as native props. Keyed `const` values are tracked separately by
/// `walk_keyed_consts`.
const STRUCTURAL_KINDS: &[&str] = &[
    "const", "root", "in", "add", "sub", "mul", "div", "mod", "sin", "cos", "tan", "tanh", "exp",
    "log", "log2", "sqrt", "abs", "ceil", "floor", "round", "le", "ge", "eq", "and", "or", "pow",
    "min", "max", "pole", "z", "prewarp", "phasor", "sphasor",
];
250
251fn walk_native_props(node: &Node, result: &mut HashMap<(String, String), serde_json::Value>) {
252 let kind = node.kind();
253 if !STRUCTURAL_KINDS.contains(&kind) {
254 if let serde_json::Value::Object(props) = node.props() {
255 for (prop_name, prop_value) in props {
256 if prop_name == "key" {
257 continue;
258 }
259 result.insert((kind.to_string(), prop_name.clone()), prop_value.clone());
260 }
261 }
262 }
263 for child in node.children() {
264 walk_native_props(child, result);
265 }
266}
267
mod legacy {
    //! Plain data types re-exported from the parent module. `Engine` does not
    //! reference them in this file.

    /// A named constant: `key` paired with its numeric `value`.
    #[derive(Debug, Clone)]
    pub struct KeyedConst {
        pub key: String,
        pub value: f64,
    }

    impl KeyedConst {
        /// Builds a `KeyedConst` from any string-like key and a value.
        pub fn new(key: impl Into<String>, value: f64) -> Self {
            let key: String = key.into();
            KeyedConst { key, value }
        }
    }

    /// A numeric property value addressed by node kind and property name.
    #[derive(Debug, Clone)]
    pub struct NativeProp {
        pub node_kind: String,
        pub property: String,
        pub value: f64,
    }

    impl NativeProp {
        /// Builds a `NativeProp` from string-like kind and property names.
        pub fn new(node_kind: impl Into<String>, property: impl Into<String>, value: f64) -> Self {
            let node_kind: String = node_kind.into();
            let property: String = property.into();
            NativeProp {
                node_kind,
                property,
                value,
            }
        }
    }
}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authoring::{el, extra};
    use serde_json::json;

    /// Parameters for the single-delay test graph.
    #[derive(Debug, Clone)]
    struct TestParams {
        delay_ms: f64,
        fb: f64,
        transition_ms: f64,
    }

    /// Input channel 0 into a `stridedelay`, driven by keyed `delay` and `fb` consts.
    struct TestDelayGraph;

    impl DspGraph for TestDelayGraph {
        type Params = TestParams;

        fn build(p: &TestParams) -> Vec<Node> {
            let delay = el::const_with_key("delay", p.delay_ms);
            let fb = el::const_with_key("fb", p.fb);
            let input = el::r#in(json!({"channel": 0}), None);
            vec![extra::stride_delay(
                json!({ "maxDelayMs": 500, "transitionMs": p.transition_ms }),
                delay,
                fb,
                input,
            )]
        }
    }

    /// An impulse fed into the engine should reappear in a later block.
    #[test]
    fn engine_creates_and_processes() {
        let params = TestParams {
            delay_ms: 50.0,
            fb: 0.0,
            transition_ms: 10.0,
        };

        let engine = Engine::<TestDelayGraph>::new(44100.0, 512, &params).expect("engine creation");

        let mut input_buf = vec![0.0_f64; 512];
        input_buf[0] = 1.0;
        let mut output_buf = vec![0.0_f64; 512];

        let inputs = [input_buf.as_slice()];
        let mut outputs = [output_buf.as_mut_slice()];
        engine.process(512, &inputs, &mut outputs).expect("process");

        let silence = vec![0.0_f64; 512];
        let mut found = false;
        for _ in 0..20 {
            let inputs = [silence.as_slice()];
            let mut out = vec![0.0_f64; 512];
            let mut outputs = [out.as_mut_slice()];
            engine.process(512, &inputs, &mut outputs).expect("process");
            if outputs[0].iter().any(|&s| s.abs() > 1e-10) {
                found = true;
                break;
            }
        }
        assert!(found, "engine should produce delayed output");
    }

    /// Both keyed consts in the built graph are recorded with their values.
    #[test]
    fn engine_auto_discovers_keyed_consts() {
        let params = TestParams {
            delay_ms: 50.0,
            fb: 0.3,
            transition_ms: 10.0,
        };

        let engine = Engine::<TestDelayGraph>::new(44100.0, 64, &params).expect("engine creation");

        assert_eq!(engine.keyed_consts.len(), 2);
        assert!((engine.keyed_consts["delay"] - 50.0).abs() < f64::EPSILON);
        assert!((engine.keyed_consts["fb"] - 0.3).abs() < f64::EPSILON);
    }

    /// Props of the non-structural `stridedelay` node are recorded.
    #[test]
    fn engine_auto_discovers_native_props() {
        let params = TestParams {
            delay_ms: 50.0,
            fb: 0.0,
            transition_ms: 10.0,
        };

        let engine = Engine::<TestDelayGraph>::new(44100.0, 64, &params).expect("engine creation");

        assert!(engine
            .native_props
            .contains_key(&("stridedelay".to_string(), "transitionMs".to_string())));
        assert!(engine
            .native_props
            .contains_key(&("stridedelay".to_string(), "maxDelayMs".to_string())));
    }

    /// `set_params` refreshes both the keyed-const and native-prop caches.
    #[test]
    fn engine_set_params_diffs_automatically() {
        let params = TestParams {
            delay_ms: 50.0,
            fb: 0.0,
            transition_ms: 10.0,
        };

        let mut engine =
            Engine::<TestDelayGraph>::new(44100.0, 64, &params).expect("engine creation");

        let new_params = TestParams {
            delay_ms: 100.0,
            fb: 0.3,
            transition_ms: 20.0,
        };
        engine.set_params(&new_params);

        assert!((engine.keyed_consts["delay"] - 100.0).abs() < f64::EPSILON);
        assert!((engine.keyed_consts["fb"] - 0.3).abs() < f64::EPSILON);
        assert_eq!(
            engine.native_props[&("stridedelay".to_string(), "transitionMs".to_string())],
            serde_json::json!(20.0)
        );
    }

    /// Re-applying identical params leaves the cached values untouched.
    #[test]
    fn engine_no_update_when_params_unchanged() {
        let params = TestParams {
            delay_ms: 50.0,
            fb: 0.0,
            transition_ms: 10.0,
        };

        let mut engine =
            Engine::<TestDelayGraph>::new(44100.0, 64, &params).expect("engine creation");

        engine.set_params(&params);
        assert!((engine.keyed_consts["delay"] - 50.0).abs() < f64::EPSILON);
    }

    /// Two consts sharing one key must make construction fail with a
    /// "duplicate" error.
    #[test]
    fn engine_rejects_duplicate_keys() {
        struct BadGraph;

        impl DspGraph for BadGraph {
            type Params = ();

            fn build(_: &()) -> Vec<Node> {
                let a = el::const_with_key("dup", 1.0);
                let b = el::const_with_key("dup", 2.0);
                vec![el::add((a, b))]
            }
        }

        let result = Engine::<BadGraph>::new(44100.0, 64, &());
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(err.contains("duplicate"), "got: {err}");
    }
}