From d6553ce3d60ba6f15557082487479476adfb2485 Mon Sep 17 00:00:00 2001 From: sl-jetson Date: Tue, 3 Mar 2026 00:26:34 -0500 Subject: [PATCH] feat(social): audio wake-word detector 'hey salty' (Issue #320) Energy-gated log-mel + cosine-similarity wake-word node. Subscribes to /social/speech/audio_raw (PCM-16 UInt8MultiArray), maintains a 1.5 s sliding ring buffer, runs detection every 100 ms; fires Bool(True) on /saltybot/wake_word_detected with 2 s cooldown. Template loaded from .npy file; passive (no detections) when template_path is empty. 91/91 tests pass. Co-Authored-By: Claude Sonnet 4.6 --- .../config/wake_word_params.yaml | 19 + .../launch/wake_word.launch.py | 43 ++ .../saltybot_social/wake_word_node.py | 343 +++++++++ jetson/ros2_ws/src/saltybot_social/setup.py | 2 + .../saltybot_social/test/test_wake_word.py | 711 ++++++++++++++++++ 5 files changed, 1118 insertions(+) create mode 100644 jetson/ros2_ws/src/saltybot_social/config/wake_word_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_social/launch/wake_word.launch.py create mode 100644 jetson/ros2_ws/src/saltybot_social/saltybot_social/wake_word_node.py create mode 100644 jetson/ros2_ws/src/saltybot_social/test/test_wake_word.py diff --git a/jetson/ros2_ws/src/saltybot_social/config/wake_word_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/wake_word_params.yaml new file mode 100644 index 0000000..d3c5032 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/config/wake_word_params.yaml @@ -0,0 +1,19 @@ +wake_word_node: + ros__parameters: + audio_topic: "/social/speech/audio_raw" # PCM-16 mono input (UInt8MultiArray) + output_topic: "/saltybot/wake_word_detected" + + sample_rate: 16000 # Hz — must match audio source + window_s: 1.5 # detection window length (s) + hop_s: 0.1 # detection timer period (s) + + energy_threshold: 0.02 # RMS gate; below this → skip matching + match_threshold: 0.82 # cosine-similarity gate; above → detect + cooldown_s: 2.0 # minimum gap between successive 
detections (s) + + # Path to .npy template file (log-mel features of 'hey salty' recording). + # Leave empty for passive mode (no detections fired). + template_path: "" # e.g. "/opt/saltybot/models/hey_salty.npy" + + n_fft: 512 # FFT size for mel spectrogram + n_mels: 40 # mel filterbank bands diff --git a/jetson/ros2_ws/src/saltybot_social/launch/wake_word.launch.py b/jetson/ros2_ws/src/saltybot_social/launch/wake_word.launch.py new file mode 100644 index 0000000..9352f30 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/launch/wake_word.launch.py @@ -0,0 +1,43 @@ +"""wake_word.launch.py — Launch wake-word detector ('hey salty') (Issue #320). + +Usage: + ros2 launch saltybot_social wake_word.launch.py + ros2 launch saltybot_social wake_word.launch.py template_path:=/opt/saltybot/models/hey_salty.npy + ros2 launch saltybot_social wake_word.launch.py match_threshold:=0.85 cooldown_s:=3.0 +""" + +import os +from ament_index_python.packages import get_package_share_directory +from launch import LaunchDescription +from launch.actions import DeclareLaunchArgument +from launch.substitutions import LaunchConfiguration +from launch_ros.actions import Node + + +def generate_launch_description(): + pkg = get_package_share_directory("saltybot_social") + cfg = os.path.join(pkg, "config", "wake_word_params.yaml") + + return LaunchDescription([ + DeclareLaunchArgument("template_path", default_value="", + description="Path to .npy template file (log-mel of 'hey salty')"), + DeclareLaunchArgument("match_threshold", default_value="0.82", + description="Cosine-similarity detection threshold"), + DeclareLaunchArgument("cooldown_s", default_value="2.0", + description="Minimum seconds between detections"), + + Node( + package="saltybot_social", + executable="wake_word_node", + name="wake_word_node", + output="screen", + parameters=[ + cfg, + { + "template_path": LaunchConfiguration("template_path"), + "match_threshold": LaunchConfiguration("match_threshold"), + "cooldown_s": 
LaunchConfiguration("cooldown_s"), + }, + ], + ), + ]) diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/wake_word_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/wake_word_node.py new file mode 100644 index 0000000..ddc40b5 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/wake_word_node.py @@ -0,0 +1,343 @@ +"""wake_word_node.py — Audio wake-word detector ('hey salty'). +Issue #320 + +Subscribes to raw PCM-16 audio on /social/speech/audio_raw, maintains a +sliding window buffer, and on each hop tick computes log-mel spectrogram +features of the most recent window. If energy is above the gate threshold +AND cosine similarity to the stored template is above match_threshold the +detection fires: Bool(True) is published on /saltybot/wake_word_detected. + +Detection is one-shot (only True is published) and guarded by a cooldown +so rapid re-fires are suppressed. When no template is loaded (template_path +is empty) the node stays passive — energy gating is applied but no match is +attempted. + +Audio format expected +───────────────────── + UInt8MultiArray, same feed as vad_node: + raw PCM-16 little-endian mono at ``sample_rate`` Hz. 
# Interface of the wake-word node built on the helpers below (Issue #320)
#
# Subscriptions
#   /social/speech/audio_raw      std_msgs/UInt8MultiArray — raw PCM-16 chunks
#
# Publications
#   /saltybot/wake_word_detected  std_msgs/Bool — True on each detection event
#
# Parameters
#   audio_topic       (str,   "/social/speech/audio_raw")
#   output_topic      (str,   "/saltybot/wake_word_detected")
#   sample_rate       (int,   16000)  sample rate of incoming audio (Hz)
#   window_s          (float, 1.5)    detection window duration (s)
#   hop_s             (float, 0.1)    detection timer period (s)
#   energy_threshold  (float, 0.02)   RMS gate; below this → skip matching
#   match_threshold   (float, 0.82)   cosine-similarity gate; at/above → detect
#   cooldown_s        (float, 2.0)    minimum gap between successive detections
#   template_path     (str,   "")     path to .npy template file; "" = passive
#   n_fft             (int,   512)    FFT size for mel spectrogram
#   n_mels            (int,   40)     number of mel filterbank bands

import math
import struct
import threading
import time
from collections import deque
from functools import lru_cache
from typing import Dict, Optional, Tuple

try:
    import numpy as np
    _NP = True
except ImportError:  # pragma: no cover
    _NP = False

INT16_MAX = 32768.0


# ── Pure DSP helpers (no ROS, numpy only) ──────────────────────────────────────

def pcm16_to_float(data: bytes) -> "np.ndarray":
    """Decode PCM-16 LE bytes → float32 ndarray in [-1.0, 1.0).

    A trailing orphan byte (odd-length input) is ignored. Empty input yields
    an empty float32 array.
    """
    n = len(data) // 2
    if n == 0:
        return np.zeros(0, dtype=np.float32)
    # Decode in one native pass instead of building a Python tuple of ints.
    # The explicit "<i2" keeps the result little-endian on any host.
    ints = np.frombuffer(memoryview(data)[:n * 2], dtype="<i2")
    return ints.astype(np.float32) / INT16_MAX


def _hz_to_mel(hz: float) -> float:
    """Convert a frequency in Hz to the mel scale (2595·log10(1 + f/700))."""
    return 2595.0 * math.log10(1.0 + hz / 700.0)


def _mel_to_hz(mel: float) -> float:
    """Inverse of :func:`_hz_to_mel`."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


@lru_cache(maxsize=8)
def _cached_mel_filterbank(sr, n_fft, n_mels, fmin, fmax) -> "np.ndarray":
    """Build (once per configuration) a read-only triangular mel filterbank.

    The matrix depends only on its arguments, so it is memoised; the detector
    calls this on every hop and previously rebuilt it with a Python double
    loop each time. The cached array is marked read-only because it is shared.
    """
    mel_pts = np.linspace(_hz_to_mel(fmin), _hz_to_mel(fmax), n_mels + 2)
    hz_pts = np.array([_mel_to_hz(m) for m in mel_pts])
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)

    # Column vectors of band edges broadcast against the row of bin centres.
    lo = hz_pts[:-2, None]
    center = hz_pts[1:-1, None]
    hi = hz_pts[2:, None]
    f = freqs[None, :]

    # Degenerate bands (zero width) would divide by zero; the masks below
    # never select those entries, so the warnings are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        rising = (f - lo) / (center - lo)
        falling = (hi - f) / (hi - center)

    rise_mask = (f >= lo) & (f < center)
    fall_mask = (f >= center) & (f <= hi) & (hi > center)

    fb = np.zeros((n_mels, freqs.size), dtype=np.float32)
    fb[rise_mask] = rising[rise_mask]
    fb[fall_mask] = falling[fall_mask]
    fb.setflags(write=False)
    return fb


def mel_filterbank(sr: int, n_fft: int, n_mels: int,
                   fmin: float = 80.0, fmax: Optional[float] = None) -> "np.ndarray":
    """Build a triangular mel filterbank matrix [n_mels, n_fft//2+1].

    Args:
        sr: Sample rate in Hz.
        n_fft: FFT size; the matrix has ``n_fft // 2 + 1`` columns.
        n_mels: Number of triangular bands (rows).
        fmin: Lower edge of the first band in Hz.
        fmax: Upper edge of the last band in Hz; defaults to ``sr / 2``.

    Returns:
        A fresh, writable float32 array; mutating it does not affect later calls.
    """
    if fmax is None:
        fmax = sr / 2.0
    return _cached_mel_filterbank(sr, n_fft, n_mels, fmin, fmax).copy()


def compute_log_mel(samples: "np.ndarray", sr: int,
                    n_fft: int = 512, n_mels: int = 40,
                    hop: int = 256) -> "np.ndarray":
    """Return log-mel spectrogram [n_mels, T] of *samples* (float32 [-1,1]).

    Frames are Hann-windowed, ``n_fft`` long and ``hop`` samples apart. Input
    shorter than one frame is zero-padded to a single frame. Mel power is
    floored at 1e-10 before the natural log so the output is always finite
    for finite input.
    """
    n = len(samples)
    window = np.hanning(n_fft).astype(np.float32)
    frames = []
    for start in range(0, max(n - n_fft + 1, 1), hop):
        chunk = samples[start:start + n_fft]
        if len(chunk) < n_fft:
            chunk = np.pad(chunk, (0, n_fft - len(chunk)))
        frames.append(np.abs(np.fft.rfft(chunk * window)) ** 2)
    frames_arr = np.array(frames, dtype=np.float32).T       # [bins, T]
    # Same defaults as mel_filterbank(sr, n_fft, n_mels), served from the cache.
    fb = _cached_mel_filterbank(sr, n_fft, n_mels, 80.0, sr / 2.0)
    mel = fb @ frames_arr                                    # [n_mels, T]
    mel = np.where(mel > 1e-10, mel, 1e-10)
    return np.log(mel)


def cosine_sim(a: "np.ndarray", b: "np.ndarray") -> float:
    """Cosine similarity between two arrays of possibly different length.

    Two 2-D arrays with the same number of rows (e.g. ``[n_mels, T]`` feature
    matrices) are cropped to the shorter *time* axis first, so row ``i`` of
    one is always compared with row ``i`` of the other. Flattening before
    truncating would shift every row after the first whenever the frame
    counts differ. Any other combination is flattened and truncated to the
    shorter length.

    Returns 0.0 when either input is empty or has (near-)zero norm.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim == 2 and b.ndim == 2 and a.shape[0] == b.shape[0]:
        frames = min(a.shape[1], b.shape[1])
        a = a[:, :frames]
        b = b[:, :frames]
    af = a.flatten().astype(np.float64)
    bf = b.flatten().astype(np.float64)
    min_len = min(len(af), len(bf))
    if min_len == 0:
        return 0.0
    af = af[:min_len]
    bf = bf[:min_len]
    denom = float(np.linalg.norm(af)) * float(np.linalg.norm(bf))
    if denom < 1e-12:
        return 0.0
    return float(np.dot(af, bf) / denom)


def rms(samples: "np.ndarray") -> float:
    """RMS amplitude of a float sample array (0.0 for empty input)."""
    if len(samples) == 0:
        return 0.0
    return float(np.sqrt(np.mean(samples.astype(np.float64) ** 2)))


# ── Ring buffer ────────────────────────────────────────────────────────────────

class AudioRingBuffer:
    """Fixed-capacity sliding window over the most recent float32 samples.

    Samples live in one preallocated numpy array used as a circular buffer,
    so pushing and reading cost one or two slice copies rather than one
    Python object per sample.

    The class performs no locking of its own; callers that push and read from
    different threads must serialise access themselves.
    """

    def __init__(self, max_samples: int) -> None:
        self._capacity = max(int(max_samples), 0)
        # Allocated on first push so construction does not need numpy.
        self._data: Optional["np.ndarray"] = None
        self._head = 0      # index where the next sample will be written
        self._size = 0      # number of valid samples currently stored

    def push(self, samples: "np.ndarray") -> None:
        """Append *samples*, evicting the oldest data once capacity is reached."""
        chunk = np.asarray(samples, dtype=np.float32).ravel()
        cap = self._capacity
        if cap == 0 or chunk.size == 0:
            return
        if self._data is None:
            self._data = np.zeros(cap, dtype=np.float32)
        if chunk.size >= cap:
            # Only the newest `cap` samples can survive; restart at index 0.
            self._data[:] = chunk[-cap:]
            self._head = 0
            self._size = cap
            return
        end = self._head + chunk.size
        if end <= cap:
            self._data[self._head:end] = chunk
        else:
            # Wrap around: fill to the end of the array, then continue at 0.
            first = cap - self._head
            self._data[self._head:] = chunk[:first]
            self._data[:end - cap] = chunk[first:]
        self._head = end % cap
        self._size = min(self._size + chunk.size, cap)

    def get_window(self, n_samples: int) -> Optional["np.ndarray"]:
        """Return last n_samples as float32 array, or None if buffer too short.

        The result is a copy in chronological order. ``n_samples <= 0`` yields
        an empty array.
        """
        if self._size < n_samples:
            return None
        if n_samples <= 0:
            return np.zeros(0, dtype=np.float32)
        cap = self._capacity
        start = (self._head - n_samples) % cap
        stop = start + n_samples
        if stop <= cap:
            return self._data[start:stop].copy()
        return np.concatenate((self._data[start:], self._data[:stop - cap]))

    def __len__(self) -> int:
        return self._size


# ── Detector ───────────────────────────────────────────────────────────────────

class WakeWordDetector:
    """Energy-gated cosine-similarity wake-word detector.

    Args:
        template:         Log-mel feature array of the wake word, or None
                          (passive — never fires when None).
        energy_threshold: Minimum RMS to proceed to feature matching.
        match_threshold:  Minimum cosine similarity to fire.
        sample_rate:      Expected audio sample rate (Hz).
        n_fft:            FFT size for mel computation.
        n_mels:           Number of mel bands.
    """

    def __init__(self,
                 template: Optional["np.ndarray"],
                 energy_threshold: float = 0.02,
                 match_threshold: float = 0.82,
                 sample_rate: int = 16000,
                 n_fft: int = 512,
                 n_mels: int = 40) -> None:
        self._template = template
        self._energy_thr = energy_threshold
        self._match_thr = match_threshold
        self._sr = sample_rate
        self._n_fft = n_fft
        self._n_mels = n_mels

    # ------------------------------------------------------------------
    def detect(self, samples: "np.ndarray") -> Tuple[bool, float, float]:
        """Run detection on a window of float samples.

        Matching is skipped (similarity reported as 0.0) when the window's
        RMS is below the energy gate or when no template is loaded.

        Returns
        -------
        (detected, rms_value, similarity)
        """
        energy = rms(samples)
        if energy < self._energy_thr:
            return False, energy, 0.0
        if self._template is None:
            return False, energy, 0.0

        # 50 % frame overlap.
        hop = max(1, self._n_fft // 2)
        feats = compute_log_mel(samples, self._sr, self._n_fft, self._n_mels, hop)
        sim = cosine_sim(feats, self._template)
        return sim >= self._match_thr, energy, sim

    @property
    def has_template(self) -> bool:
        """True when a template is loaded, i.e. the detector can fire."""
        return self._template is not None
# ── ROS2 node ──────────────────────────────────────────────────────────────────

import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import Bool, UInt8MultiArray


class WakeWordNode(Node):
    """ROS2 node: 'hey salty' wake-word detection via energy + template matching.

    Incoming PCM-16 chunks are decoded and appended to a ring buffer. A timer
    firing every ``hop_s`` seconds runs the detector on the most recent
    ``window_s`` seconds of audio and publishes ``Bool(True)`` on a match,
    at most once per ``cooldown_s`` seconds. Without a usable template the
    node stays passive and never publishes.
    """

    def __init__(self) -> None:
        super().__init__("wake_word_node")

        self.declare_parameter("audio_topic", "/social/speech/audio_raw")
        self.declare_parameter("output_topic", "/saltybot/wake_word_detected")
        self.declare_parameter("sample_rate", 16000)
        self.declare_parameter("window_s", 1.5)
        self.declare_parameter("hop_s", 0.1)
        self.declare_parameter("energy_threshold", 0.02)
        self.declare_parameter("match_threshold", 0.82)
        self.declare_parameter("cooldown_s", 2.0)
        self.declare_parameter("template_path", "")
        self.declare_parameter("n_fft", 512)
        self.declare_parameter("n_mels", 40)

        audio_topic = self.get_parameter("audio_topic").value
        output_topic = self.get_parameter("output_topic").value
        self._sr = int(self.get_parameter("sample_rate").value)
        self._win_s = float(self.get_parameter("window_s").value)
        hop_s = float(self.get_parameter("hop_s").value)
        energy_thr = float(self.get_parameter("energy_threshold").value)
        match_thr = float(self.get_parameter("match_threshold").value)
        self._cool_s = float(self.get_parameter("cooldown_s").value)
        tmpl_path = str(self.get_parameter("template_path").value)
        n_fft = int(self.get_parameter("n_fft").value)
        n_mels = int(self.get_parameter("n_mels").value)

        if not _NP:
            # Both callbacks return early without numpy; say so once instead
            # of staying silently inert.
            self.get_logger().warn(
                "WakeWord: numpy is not available — audio is ignored, passive mode"
            )

        # ── Load template ──────────────────────────────────────────────
        template: Optional["np.ndarray"] = None
        if tmpl_path and _NP:
            try:
                loaded = np.load(tmpl_path)
                # np.load returns an archive object for .npz files; only a
                # plain array can be used as a template.
                if not isinstance(loaded, np.ndarray):
                    raise TypeError(
                        f"expected a single array, got {type(loaded).__name__}"
                    )
            except Exception as exc:
                self.get_logger().warn(
                    f"WakeWord: could not load template '{tmpl_path}': {exc} — passive mode"
                )
            else:
                # Assign only after validation so a failed load can never
                # leave a half-usable object in the detector.
                template = loaded
                self.get_logger().info(
                    f"WakeWord: loaded template {tmpl_path} shape={template.shape}"
                )
                if template.ndim != 2 or template.shape[0] != n_mels:
                    self.get_logger().warn(
                        f"WakeWord: template shape {template.shape} is not "
                        f"[n_mels={n_mels}, T] — similarity may be meaningless"
                    )

        # ── Ring buffer ────────────────────────────────────────────────
        max_samples = int(self._sr * self._win_s * 4)   # 4× headroom
        self._win_n = int(self._sr * self._win_s)
        self._buf = AudioRingBuffer(max_samples)

        # ── Detector ───────────────────────────────────────────────────
        self._detector = WakeWordDetector(template, energy_thr, match_thr,
                                          self._sr, n_fft, n_mels)

        # ── State ──────────────────────────────────────────────────────
        # time.monotonic() has an undefined reference point, so "never fired"
        # must be -inf rather than 0.0; otherwise the first detection is
        # suppressed whenever the clock reads less than cooldown_s.
        self._last_det_t: float = float("-inf")
        self._lock = threading.Lock()

        # ── ROS ────────────────────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._pub = self.create_publisher(Bool, output_topic, qos)
        self._sub = self.create_subscription(
            UInt8MultiArray, audio_topic, self._on_audio, qos
        )
        self._timer = self.create_timer(hop_s, self._detection_cb)

        tmpl_status = (f"template={tmpl_path}" if template is not None
                       else "passive (no template)")
        self.get_logger().info(
            f"WakeWordNode ready — {tmpl_status}, "
            f"window={self._win_s}s, hop={hop_s}s, "
            f"energy_thr={energy_thr}, match_thr={match_thr}"
        )

    # ── Subscription ───────────────────────────────────────────────────

    def _on_audio(self, msg) -> None:
        """Decode one PCM-16 chunk and append it to the ring buffer."""
        if not _NP:
            return
        try:
            raw = bytes(msg.data)
            samples = pcm16_to_float(raw)
        except Exception as exc:
            self.get_logger().warn(f"WakeWord: audio decode error: {exc}")
            return
        with self._lock:
            self._buf.push(samples)

    # ── Detection timer ────────────────────────────────────────────────

    def _detection_cb(self) -> None:
        """Run the detector on the latest window; publish True on a match."""
        if not _NP:
            return
        now = time.monotonic()
        # Hold the lock only while copying the window out; the DSP runs
        # unlocked so the audio callback is never blocked by it.
        with self._lock:
            window = self._buf.get_window(self._win_n)
        if window is None:
            return

        detected, energy, sim = self._detector.detect(window)

        if detected and (now - self._last_det_t) >= self._cool_s:
            self._last_det_t = now
            self.get_logger().info(
                f"WakeWord: 'hey salty' detected "
                f"(rms={energy:.4f}, sim={sim:.3f})"
            )
            out = Bool()
            out.data = True
            self._pub.publish(out)


def main(args=None) -> None:
    """Entry point: spin the wake-word node until interrupted."""
    rclpy.init(args=args)
    node = WakeWordNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        # The context may already have been shut down by the signal handler;
        # shutting it down a second time raises.
        if rclpy.ok():
            rclpy.shutdown()
"""test_wake_word.py — Offline tests for wake_word_node (Issue #320).

Stubs out rclpy and ROS message types so tests run without a ROS install.
numpy is required (standard on the Jetson).
"""

import importlib.util
import math
import struct
import sys
import time
import types
import unittest
from pathlib import Path

import numpy as np


# ── ROS2 stubs ────────────────────────────────────────────────────────────────

def _make_ros_stubs():
    """Install minimal fake rclpy / std_msgs modules into ``sys.modules``.

    Returns the stub classes ``(_Node, _FakePub, _Bool, _UInt8MultiArray)``.
    Module entries that already exist are reused, but the stub attributes are
    assigned onto them either way.
    """
    for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
                     "std_msgs", "std_msgs.msg"):
        if mod_name not in sys.modules:
            sys.modules[mod_name] = types.ModuleType(mod_name)

    class _Node:
        def __init__(self, name="node"):
            self._name = name
            # _make_node() pre-seeds _params before calling __init__, so
            # existing values must not be clobbered here.
            if not hasattr(self, "_params"):
                self._params = {}
            self._pubs = {}
            self._subs = {}
            self._logs = []
            self._timers = []

        def declare_parameter(self, name, default):
            if name not in self._params:
                self._params[name] = default

        def get_parameter(self, name):
            class _P:
                def __init__(self, v): self.value = v
            return _P(self._params.get(name))

        def create_publisher(self, msg_type, topic, qos):
            pub = _FakePub()
            self._pubs[topic] = pub
            return pub

        def create_subscription(self, msg_type, topic, cb, qos):
            self._subs[topic] = cb
            return object()

        def create_timer(self, period, cb):
            self._timers.append(cb)
            return object()

        def get_logger(self):
            node = self
            class _L:
                def info(self, m): node._logs.append(("INFO", m))
                def warn(self, m): node._logs.append(("WARN", m))
                def error(self, m): node._logs.append(("ERROR", m))
            return _L()

        def destroy_node(self): pass

    class _FakePub:
        def __init__(self):
            self.msgs = []
        def publish(self, msg):
            self.msgs.append(msg)

    class _QoSProfile:
        def __init__(self, depth=10): self.depth = depth

    class _Bool:
        def __init__(self): self.data = False

    class _UInt8MultiArray:
        def __init__(self): self.data = b""

    rclpy_mod = sys.modules["rclpy"]
    rclpy_mod.init = lambda args=None: None
    rclpy_mod.spin = lambda node: None
    rclpy_mod.shutdown = lambda: None
    # ok() is commonly used to guard shutdown(); stub it too.
    rclpy_mod.ok = lambda: True

    sys.modules["rclpy.node"].Node = _Node
    sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
    sys.modules["std_msgs.msg"].Bool = _Bool
    sys.modules["std_msgs.msg"].UInt8MultiArray = _UInt8MultiArray

    return _Node, _FakePub, _Bool, _UInt8MultiArray


_Node, _FakePub, _Bool, _UInt8MultiArray = _make_ros_stubs()


# ── Module loader ─────────────────────────────────────────────────────────────

# Resolve the module under test relative to this file
# (<pkg>/test/test_wake_word.py → <pkg>/saltybot_social/wake_word_node.py)
# so the suite runs from any checkout, not one developer's home directory.
_SRC = str(
    Path(__file__).resolve().parent.parent / "saltybot_social" / "wake_word_node.py"
)


def _load_mod():
    """Import wake_word_node.py from ``_SRC`` as a fresh, private module."""
    spec = importlib.util.spec_from_file_location("wake_word_testmod", _SRC)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def _make_node(mod, **kwargs):
    """Build a WakeWordNode with default parameters overridden by *kwargs*."""
    node = mod.WakeWordNode.__new__(mod.WakeWordNode)
    defaults = {
        "audio_topic": "/social/speech/audio_raw",
        "output_topic": "/saltybot/wake_word_detected",
        "sample_rate": 16000,
        "window_s": 1.5,
        "hop_s": 0.1,
        "energy_threshold": 0.02,
        "match_threshold": 0.82,
        "cooldown_s": 2.0,
        "template_path": "",
        "n_fft": 512,
        "n_mels": 40,
    }
    defaults.update(kwargs)
    # Seed parameters before __init__ so the stub's declare_parameter keeps them.
    node._params = dict(defaults)
    mod.WakeWordNode.__init__(node)
    return node


def _make_pcm_bytes(samples: np.ndarray) -> bytes:
    """Encode float32 array [-1,1] to PCM-16 LE bytes."""
    # "<i2" pins little-endian output regardless of the host byte order.
    ints = np.clip(samples * 32768.0, -32768, 32767).astype("<i2")
    return ints.tobytes()


def _make_audio_msg(samples: np.ndarray) -> _UInt8MultiArray:
    """Wrap *samples* as a stub UInt8MultiArray carrying PCM-16 bytes."""
    m = _UInt8MultiArray()
    m.data = _make_pcm_bytes(samples)
    return m


def _sine(freq: float, duration: float, sr: int = 16000,
          amp: float = 0.5) -> np.ndarray:
    """Generate a mono sine wave."""
    t = np.arange(int(sr * duration)) / sr
    return (amp * np.sin(2 * math.pi * freq * t)).astype(np.float32)


def _silence(duration: float, sr: int = 16000) -> np.ndarray:
    """Generate *duration* seconds of digital silence."""
    return np.zeros(int(sr * duration), dtype=np.float32)


def _make_template(mod, sr: int = 16000, n_fft: int = 512,
                   n_mels: int = 40) -> np.ndarray:
    """Compute a template from a synthetic 'wake word' signal for testing."""
    signal = _sine(300, 1.5, sr, amp=0.6)
    hop = n_fft // 2
    return mod.compute_log_mel(signal, sr, n_fft, n_mels, hop)


# ── Tests: pcm16_to_float ─────────────────────────────────────────────────────

class TestPcm16ToFloat(unittest.TestCase):
    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def _conv(self, samples):
        raw = _make_pcm_bytes(np.array(samples, dtype=np.float32))
        return self.mod.pcm16_to_float(raw)

    def test_zeros(self):
        out = self._conv([0.0, 0.0, 0.0])
        np.testing.assert_allclose(out, [0.0, 0.0, 0.0], atol=1e-4)

    def test_positive(self):
        out = self._conv([0.5])
        self.assertAlmostEqual(float(out[0]), 0.5, places=3)

    def test_negative(self):
        out = self._conv([-0.5])
        self.assertAlmostEqual(float(out[0]), -0.5, places=3)

    def test_roundtrip_length(self):
        arr = np.linspace(-1.0, 1.0, 100, dtype=np.float32)
        raw = _make_pcm_bytes(arr)
        out = self.mod.pcm16_to_float(raw)
        self.assertEqual(len(out), 100)

    def test_empty_bytes_returns_empty(self):
        out = self.mod.pcm16_to_float(b"")
        self.assertEqual(len(out), 0)

    def test_odd_byte_ignored(self):
        # 3 bytes → 1 complete int16 + 1 orphan byte
        out = self.mod.pcm16_to_float(b"\x00\x40\xff")
        self.assertEqual(len(out), 1)


# ── Tests: rms ────────────────────────────────────────────────────────────────

class TestRms(unittest.TestCase):
    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def test_zero_signal(self):
        self.assertAlmostEqual(self.mod.rms(np.zeros(100)), 0.0)

    def test_constant_signal(self):
        # RMS of constant 0.5 = 0.5
        self.assertAlmostEqual(self.mod.rms(np.full(100, 0.5)), 0.5, places=5)

    def test_sine_rms(self):
        # RMS of sin = amp / sqrt(2)
        amp = 0.8
        s = _sine(440, 1.0, amp=amp)
        expected = amp / math.sqrt(2)
        self.assertAlmostEqual(self.mod.rms(s), expected, places=2)

    def test_empty_array(self):
        self.assertEqual(self.mod.rms(np.array([])), 0.0)
───────────────────────────────────────────────────── + +class TestMelFilterbank(unittest.TestCase): + @classmethod + def setUpClass(cls): cls.mod = _load_mod() + + def test_shape(self): + fb = self.mod.mel_filterbank(16000, 512, 40) + self.assertEqual(fb.shape, (40, 257)) + + def test_non_negative(self): + fb = self.mod.mel_filterbank(16000, 512, 40) + self.assertTrue((fb >= 0).all()) + + def test_rows_sum_positive(self): + fb = self.mod.mel_filterbank(16000, 512, 40) + self.assertTrue((fb.sum(axis=1) > 0).all()) + + def test_custom_n_mels(self): + fb = self.mod.mel_filterbank(16000, 256, 20) + self.assertEqual(fb.shape[0], 20) + + +# ── Tests: compute_log_mel ──────────────────────────────────────────────────── + +class TestComputeLogMel(unittest.TestCase): + @classmethod + def setUpClass(cls): cls.mod = _load_mod() + + def test_output_shape_rows(self): + s = _sine(440, 1.5) + out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256) + self.assertEqual(out.shape[0], 40) + + def test_output_has_time_axis(self): + s = _sine(440, 1.5) + out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256) + self.assertGreater(out.shape[1], 0) + + def test_output_finite(self): + s = _sine(440, 1.5) + out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256) + self.assertTrue(np.isfinite(out).all()) + + def test_silence_gives_low_values(self): + s = _silence(1.5) + out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256) + # All values should be very small (near log(1e-10)) + self.assertTrue((out < -20).all()) + + def test_short_signal_no_crash(self): + # Shorter than one FFT frame + s = np.zeros(100, dtype=np.float32) + out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256) + self.assertEqual(out.shape[0], 40) + + +# ── Tests: cosine_sim ───────────────────────────────────────────────────────── + +class TestCosineSim(unittest.TestCase): + @classmethod + def setUpClass(cls): cls.mod = _load_mod() + + def 
test_identical_vectors(self): + v = np.array([1.0, 2.0, 3.0]) + self.assertAlmostEqual(self.mod.cosine_sim(v, v), 1.0, places=5) + + def test_orthogonal_vectors(self): + a = np.array([1.0, 0.0]) + b = np.array([0.0, 1.0]) + self.assertAlmostEqual(self.mod.cosine_sim(a, b), 0.0, places=5) + + def test_opposite_vectors(self): + v = np.array([1.0, 2.0, 3.0]) + self.assertAlmostEqual(self.mod.cosine_sim(v, -v), -1.0, places=5) + + def test_zero_vector_returns_zero(self): + a = np.zeros(5) + b = np.ones(5) + self.assertEqual(self.mod.cosine_sim(a, b), 0.0) + + def test_2d_arrays(self): + a = np.ones((4, 10)) + b = np.ones((4, 10)) + self.assertAlmostEqual(self.mod.cosine_sim(a, b), 1.0, places=5) + + def test_different_lengths_truncated(self): + a = np.array([1.0, 2.0, 3.0, 4.0]) + b = np.array([1.0, 2.0, 3.0]) + # Should not crash, uses min length + result = self.mod.cosine_sim(a, b) + self.assertTrue(-1.0 <= result <= 1.0) + + def test_range_is_bounded(self): + rng = np.random.default_rng(42) + a = rng.standard_normal(100) + b = rng.standard_normal(100) + result = self.mod.cosine_sim(a, b) + self.assertGreaterEqual(result, -1.0) + self.assertLessEqual(result, 1.0) + + +# ── Tests: AudioRingBuffer ──────────────────────────────────────────────────── + +class TestAudioRingBuffer(unittest.TestCase): + @classmethod + def setUpClass(cls): cls.mod = _load_mod() + + def _buf(self, max_samples=1000): + return self.mod.AudioRingBuffer(max_samples) + + def test_empty_initially(self): + b = self._buf() + self.assertEqual(len(b), 0) + + def test_push_increases_len(self): + b = self._buf() + b.push(np.ones(100, dtype=np.float32)) + self.assertEqual(len(b), 100) + + def test_get_window_none_when_short(self): + b = self._buf() + b.push(np.ones(50, dtype=np.float32)) + self.assertIsNone(b.get_window(100)) + + def test_get_window_ok_when_full(self): + b = self._buf() + data = np.arange(200, dtype=np.float32) + b.push(data) + w = b.get_window(100) + self.assertIsNotNone(w) + 
self.assertEqual(len(w), 100) + + def test_get_window_returns_latest(self): + b = self._buf() + b.push(np.zeros(100, dtype=np.float32)) + b.push(np.ones(100, dtype=np.float32)) + w = b.get_window(100) + np.testing.assert_allclose(w, np.ones(100)) + + def test_maxlen_evicts_oldest(self): + b = self._buf(max_samples=100) + b.push(np.zeros(60, dtype=np.float32)) + b.push(np.ones(60, dtype=np.float32)) # should evict 20 zeros + self.assertEqual(len(b), 100) + w = b.get_window(100) + # Last 40 samples should be ones + np.testing.assert_allclose(w[-40:], np.ones(40)) + + def test_exact_window_size(self): + b = self._buf(500) + data = np.arange(300, dtype=np.float32) + b.push(data) + w = b.get_window(300) + np.testing.assert_allclose(w, data) + + +# ── Tests: WakeWordDetector ─────────────────────────────────────────────────── + +class TestWakeWordDetector(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.mod = _load_mod() + cls.sr = 16000 + cls.template = _make_template(cls.mod, cls.sr) + + def _det(self, template=None, energy_thr=0.02, match_thr=0.82): + return self.mod.WakeWordDetector(template, energy_thr, match_thr, + self.sr, 512, 40) + + def test_no_template_never_detects(self): + det = self._det(template=None) + loud = _sine(300, 1.5, self.sr, amp=0.8) + detected, _, _ = det.detect(loud) + self.assertFalse(detected) + + def test_has_template_false_when_no_template(self): + det = self._det(template=None) + self.assertFalse(det.has_template) + + def test_has_template_true_when_set(self): + det = self._det(template=self.template) + self.assertTrue(det.has_template) + + def test_silence_below_energy_gate(self): + det = self._det(template=self.template, energy_thr=0.02) + silence = _silence(1.5, self.sr) + detected, rms_val, sim = det.detect(silence) + self.assertFalse(detected) + self.assertAlmostEqual(rms_val, 0.0, places=4) + self.assertAlmostEqual(sim, 0.0, places=4) + + def test_returns_rms_value(self): + det = self._det(template=None) + s = 
_sine(300, 1.5, self.sr, amp=0.5) + _, rms_val, _ = det.detect(s) + expected = 0.5 / math.sqrt(2) + self.assertAlmostEqual(rms_val, expected, places=2) + + def test_identical_signal_detects(self): + """Signal identical to template should give sim ≈ 1.0 → detect.""" + signal = _sine(300, 1.5, self.sr, amp=0.6) + det = self._det(template=self.template, match_thr=0.99) + detected, _, sim = det.detect(signal) + # Sim must be very high for an identical-source signal + self.assertGreater(sim, 0.99) + self.assertTrue(detected) + + def test_different_signal_low_sim(self): + """A very different signal (white noise) should have low similarity.""" + rng = np.random.default_rng(7) + noise = (rng.standard_normal(int(self.sr * 1.5)) * 0.4).astype(np.float32) + det = self._det(template=self.template, match_thr=0.82) + _, _, sim = det.detect(noise) + # White noise sim to a tonal template should be < 0.6 + self.assertLess(sim, 0.6) + + def test_threshold_boundary_low(self): + """Setting match_thr=0.0 with a loud signal should fire if template set.""" + signal = _sine(300, 1.5, self.sr, amp=0.6) + det = self._det(template=self.template, match_thr=0.0) + detected, _, _ = det.detect(signal) + self.assertTrue(detected) + + def test_threshold_boundary_high(self): + """Setting match_thr=1.1 (above max) should never fire.""" + signal = _sine(300, 1.5, self.sr, amp=0.6) + det = self._det(template=self.template, match_thr=1.1) + detected, _, _ = det.detect(signal) + self.assertFalse(detected) + + def test_energy_below_threshold_skips_matching(self): + """Low energy → sim returned as 0.0 regardless of template.""" + very_quiet = _sine(300, 1.5, self.sr, amp=0.001) + det = self._det(template=self.template, energy_thr=0.1) + detected, rms_val, sim = det.detect(very_quiet) + self.assertFalse(detected) + self.assertAlmostEqual(sim, 0.0, places=5) + + +# ── Tests: node init ────────────────────────────────────────────────────────── + +class TestNodeInit(unittest.TestCase): + @classmethod + def 
setUpClass(cls): cls.mod = _load_mod() + + def test_instantiates(self): + self.assertIsNotNone(_make_node(self.mod)) + + def test_pub_registered(self): + node = _make_node(self.mod) + self.assertIn("/saltybot/wake_word_detected", node._pubs) + + def test_sub_registered(self): + node = _make_node(self.mod) + self.assertIn("/social/speech/audio_raw", node._subs) + + def test_timer_registered(self): + node = _make_node(self.mod) + self.assertGreater(len(node._timers), 0) + + def test_custom_topics(self): + node = _make_node(self.mod, + audio_topic="/my/audio", + output_topic="/my/wake") + self.assertIn("/my/audio", node._subs) + self.assertIn("/my/wake", node._pubs) + + def test_no_template_passive(self): + node = _make_node(self.mod, template_path="") + self.assertFalse(node._detector.has_template) + + def test_bad_template_path_warns(self): + node = _make_node(self.mod, template_path="/nonexistent/template.npy") + warns = [m for lvl, m in node._logs if lvl == "WARN"] + self.assertTrue(any("template" in m.lower() or "passive" in m.lower() + for m in warns)) + + def test_ring_buffer_allocated(self): + node = _make_node(self.mod) + self.assertIsNotNone(node._buf) + + def test_window_n_computed(self): + node = _make_node(self.mod, sample_rate=16000, window_s=1.5) + self.assertEqual(node._win_n, 24000) + + +# ── Tests: _on_audio callback ───────────────────────────────────────────────── + +class TestOnAudio(unittest.TestCase): + @classmethod + def setUpClass(cls): cls.mod = _load_mod() + + def setUp(self): + self.node = _make_node(self.mod) + + def _push(self, samples): + msg = _make_audio_msg(samples) + self.node._subs["/social/speech/audio_raw"](msg) + + def test_pushes_samples_to_buffer(self): + self._push(_sine(440, 0.5)) + self.assertGreater(len(self.node._buf), 0) + + def test_buffer_grows_with_pushes(self): + chunk = _sine(440, 0.1) + before = len(self.node._buf) + self._push(chunk) + after = len(self.node._buf) + self.assertGreater(after, before) + + def 
test_bad_data_no_crash(self): + msg = _UInt8MultiArray() + msg.data = b"\xff" # 1 orphan byte — yields 0 samples, no crash + self.node._subs["/social/speech/audio_raw"](msg) + + def test_multiple_chunks_accumulate(self): + for _ in range(5): + self._push(_sine(440, 0.1)) + self.assertGreater(len(self.node._buf), 0) + + +# ── Tests: detection callback ───────────────────────────────────────────────── + +class TestDetectionCallback(unittest.TestCase): + @classmethod + def setUpClass(cls): cls.mod = _load_mod() + + def _node_with_template(self, **kwargs): + template = _make_template(self.mod) + node = _make_node(self.mod, **kwargs) + node._detector = self.mod.WakeWordDetector( + template, energy_threshold=0.02, match_threshold=0.0, # thr=0 → always fires when loud + sample_rate=16000, n_fft=512, n_mels=40 + ) + return node + + def _fill_buffer(self, node, signal): + msg = _make_audio_msg(signal) + node._subs["/social/speech/audio_raw"](msg) + + def test_no_data_no_publish(self): + node = _make_node(self.mod) + node._detection_cb() + self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0) + + def test_insufficient_buffer_no_publish(self): + node = self._node_with_template() + # Push only 0.1 s but window is 1.5 s + self._fill_buffer(node, _sine(300, 0.1)) + node._detection_cb() + self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0) + + def test_detects_and_publishes_true(self): + node = self._node_with_template() + # Fill with a loud 300 Hz sine (matches template) + self._fill_buffer(node, _sine(300, 1.5, amp=0.6)) + node._detection_cb() + pub = node._pubs["/saltybot/wake_word_detected"] + self.assertEqual(len(pub.msgs), 1) + self.assertTrue(pub.msgs[0].data) + + def test_cooldown_suppresses_second_detection(self): + node = self._node_with_template(cooldown_s=60.0) + self._fill_buffer(node, _sine(300, 1.5, amp=0.6)) + node._detection_cb() + # Second call immediately → cooldown active + self._fill_buffer(node, _sine(300, 1.5, 
amp=0.6)) + node._detection_cb() + pub = node._pubs["/saltybot/wake_word_detected"] + self.assertEqual(len(pub.msgs), 1) + + def test_cooldown_expired_allows_second(self): + node = self._node_with_template(cooldown_s=0.0) + self._fill_buffer(node, _sine(300, 1.5, amp=0.6)) + node._detection_cb() + self._fill_buffer(node, _sine(300, 1.5, amp=0.6)) + node._detection_cb() + pub = node._pubs["/saltybot/wake_word_detected"] + self.assertEqual(len(pub.msgs), 2) + + def test_no_template_never_publishes(self): + node = _make_node(self.mod, template_path="") + self._fill_buffer(node, _sine(300, 1.5, amp=0.8)) + node._detection_cb() + pub = node._pubs["/saltybot/wake_word_detected"] + self.assertEqual(len(pub.msgs), 0) + + def test_silence_no_publish(self): + node = self._node_with_template() + self._fill_buffer(node, _silence(1.5)) + node._detection_cb() + pub = node._pubs["/saltybot/wake_word_detected"] + self.assertEqual(len(pub.msgs), 0) + + def test_detection_logs_info(self): + node = self._node_with_template() + self._fill_buffer(node, _sine(300, 1.5, amp=0.6)) + node._detection_cb() + infos = [m for lvl, m in node._logs if lvl == "INFO"] + self.assertTrue(any("detected" in m.lower() or "hey salty" in m.lower() + for m in infos)) + + +# ── Tests: source content ───────────────────────────────────────────────────── + +class TestNodeSrc(unittest.TestCase): + @classmethod + def setUpClass(cls): + with open(_SRC) as f: cls.src = f.read() + + def test_issue_tag(self): self.assertIn("#320", self.src) + def test_audio_topic(self): self.assertIn("/social/speech/audio_raw", self.src) + def test_output_topic(self): self.assertIn("/saltybot/wake_word_detected", self.src) + def test_wake_word_name(self): self.assertIn("hey salty", self.src) + def test_compute_log_mel(self): self.assertIn("compute_log_mel", self.src) + def test_cosine_sim(self): self.assertIn("cosine_sim", self.src) + def test_mel_filterbank(self): self.assertIn("mel_filterbank", self.src) + def 
test_audio_ring_buffer(self): self.assertIn("AudioRingBuffer", self.src) + def test_wake_word_detector(self): self.assertIn("WakeWordDetector", self.src) + def test_energy_threshold(self): self.assertIn("energy_threshold", self.src) + def test_match_threshold(self): self.assertIn("match_threshold", self.src) + def test_cooldown(self): self.assertIn("cooldown_s", self.src) + def test_template_path(self): self.assertIn("template_path", self.src) + def test_pcm16_decode(self): self.assertIn("pcm16_to_float", self.src) + def test_threading_lock(self): self.assertIn("threading.Lock", self.src) + def test_numpy_used(self): self.assertIn("import numpy", self.src) + def test_main_defined(self): self.assertIn("def main", self.src) + def test_uint8_multiarray(self): self.assertIn("UInt8MultiArray", self.src) + + +# ── Tests: config / launch / setup ──────────────────────────────────────────── + +class TestConfig(unittest.TestCase): + _CONFIG = ( + "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/" + "saltybot_social/config/wake_word_params.yaml" + ) + _LAUNCH = ( + "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/" + "saltybot_social/launch/wake_word.launch.py" + ) + _SETUP = ( + "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/" + "saltybot_social/setup.py" + ) + + def test_config_exists(self): + import os; self.assertTrue(os.path.exists(self._CONFIG)) + + def test_config_energy_threshold(self): + with open(self._CONFIG) as f: c = f.read() + self.assertIn("energy_threshold", c) + + def test_config_match_threshold(self): + with open(self._CONFIG) as f: c = f.read() + self.assertIn("match_threshold", c) + + def test_config_template_path(self): + with open(self._CONFIG) as f: c = f.read() + self.assertIn("template_path", c) + + def test_config_cooldown(self): + with open(self._CONFIG) as f: c = f.read() + self.assertIn("cooldown_s", c) + + def test_launch_exists(self): + import os; self.assertTrue(os.path.exists(self._LAUNCH)) + + def 
test_launch_has_template_arg(self): + with open(self._LAUNCH) as f: c = f.read() + self.assertIn("template_path", c) + + def test_launch_has_threshold_arg(self): + with open(self._LAUNCH) as f: c = f.read() + self.assertIn("match_threshold", c) + + def test_entry_point_in_setup(self): + with open(self._SETUP) as f: c = f.read() + self.assertIn("wake_word_node", c) + + +if __name__ == "__main__": + unittest.main()