feat(social): audio wake-word detector 'hey salty' (Issue #320) #317

Merged
sl-jetson merged 1 commits from sl-jetson/wake-word-detect into main 2026-03-03 00:41:23 -05:00
5 changed files with 1118 additions and 0 deletions

View File

@ -0,0 +1,19 @@
wake_word_node:
ros__parameters:
audio_topic: "/social/speech/audio_raw" # PCM-16 mono input (UInt8MultiArray)
output_topic: "/saltybot/wake_word_detected"
sample_rate: 16000 # Hz — must match audio source
window_s: 1.5 # detection window length (s)
hop_s: 0.1 # detection timer period (s)
energy_threshold: 0.02 # RMS gate; below this → skip matching
match_threshold: 0.82 # cosine-similarity gate; above → detect
cooldown_s: 2.0 # minimum gap between successive detections (s)
# Path to .npy template file (log-mel features of 'hey salty' recording).
# Leave empty for passive mode (no detections fired).
template_path: "" # e.g. "/opt/saltybot/models/hey_salty.npy"
n_fft: 512 # FFT size for mel spectrogram
n_mels: 40 # mel filterbank bands

View File

@ -0,0 +1,43 @@
"""wake_word.launch.py — Launch wake-word detector ('hey salty') (Issue #320).
Usage:
ros2 launch saltybot_social wake_word.launch.py
ros2 launch saltybot_social wake_word.launch.py template_path:=/opt/saltybot/models/hey_salty.npy
ros2 launch saltybot_social wake_word.launch.py match_threshold:=0.85 cooldown_s:=3.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    """Assemble the launch description for the 'hey salty' wake-word node.

    Declares three launch arguments (``template_path``, ``match_threshold``,
    ``cooldown_s``) and starts ``wake_word_node`` with two parameter sources:
    the packaged ``config/wake_word_params.yaml`` followed by a dict that maps
    each launch argument onto the node parameter of the same name.

    NOTE(review): the launch-argument dict is listed after the YAML file, so
    presumably the argument values (including their defaults) take precedence
    over the YAML values for these three keys — confirm against launch_ros.
    """
    share_dir = get_package_share_directory("saltybot_social")
    params_file = os.path.join(share_dir, "config", "wake_word_params.yaml")

    launch_args = [
        DeclareLaunchArgument(
            "template_path",
            default_value="",
            description="Path to .npy template file (log-mel of 'hey salty')",
        ),
        DeclareLaunchArgument(
            "match_threshold",
            default_value="0.82",
            description="Cosine-similarity detection threshold",
        ),
        DeclareLaunchArgument(
            "cooldown_s",
            default_value="2.0",
            description="Minimum seconds between detections",
        ),
    ]

    # Each launch argument feeds the node parameter with the same name.
    overrides = {
        arg_name: LaunchConfiguration(arg_name)
        for arg_name in ("template_path", "match_threshold", "cooldown_s")
    }

    detector_node = Node(
        package="saltybot_social",
        executable="wake_word_node",
        name="wake_word_node",
        output="screen",
        parameters=[params_file, overrides],
    )

    return LaunchDescription([*launch_args, detector_node])

View File

@ -0,0 +1,343 @@
"""wake_word_node.py — Audio wake-word detector ('hey salty').
Issue #320
Subscribes to raw PCM-16 audio on /social/speech/audio_raw, maintains a
sliding window buffer, and on each hop tick computes log-mel spectrogram
features of the most recent window. If energy is above the gate threshold
AND cosine similarity to the stored template is above match_threshold the
detection fires: Bool(True) is published on /saltybot/wake_word_detected.
Detection is one-shot (only True is published) and guarded by a cooldown
so rapid re-fires are suppressed. When no template is loaded (template_path
is empty) the node stays passive: energy gating is still applied but no match
is attempted.
Audio format expected
UInt8MultiArray, same feed as vad_node:
raw PCM-16 little-endian mono at ``sample_rate`` Hz.
Subscriptions
/social/speech/audio_raw std_msgs/UInt8MultiArray raw PCM-16 chunks
Publications
/saltybot/wake_word_detected std_msgs/Bool True on each detection event
Parameters
audio_topic (str, "/social/speech/audio_raw")
output_topic (str, "/saltybot/wake_word_detected")
sample_rate (int, 16000) sample rate of incoming audio (Hz)
window_s (float, 1.5) detection window duration (s)
hop_s (float, 0.1) detection timer period (s)
energy_threshold (float, 0.02) RMS gate; below this, matching is skipped
match_threshold (float, 0.82) cosine-similarity gate; at or above this, detect
cooldown_s (float, 2.0) minimum gap between successive detections
template_path (str, "") path to .npy template file; "" = passive
n_fft (int, 512) FFT size for mel spectrogram
n_mels (int, 40) number of mel filterbank bands
"""
from __future__ import annotations
import math
import struct
import threading
import time
from collections import deque
from typing import Dict, Optional, Tuple
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import Bool, UInt8MultiArray
try:
import numpy as np
_NP = True
except ImportError: # pragma: no cover
_NP = False
# Scale factor mapping int16 samples onto [-1.0, 1.0). Note this is 2**15
# (the magnitude of the most negative int16), not the largest positive value.
INT16_MAX = 32768.0


# ── Pure DSP helpers (no ROS, numpy only) ──────────────────────────────────────
def pcm16_to_float(data: bytes) -> "np.ndarray":
    """Decode PCM-16 little-endian bytes into a float32 ndarray.

    Args:
        data: Raw byte string of signed 16-bit little-endian samples. A
            trailing orphan byte (odd length) is ignored.

    Returns:
        1-D float32 array with one element per complete sample, each equal to
        the int16 value divided by ``INT16_MAX`` (so values lie in
        [-1.0, 1.0)). Empty when *data* holds no complete sample.
    """
    n = len(data) // 2
    if n == 0:
        return np.zeros(0, dtype=np.float32)
    # frombuffer reinterprets the bytes in place instead of building one
    # Python int per sample; "<i2" pins little-endian regardless of host.
    ints = np.frombuffer(data[:n * 2], dtype="<i2")
    return ints.astype(np.float32) / INT16_MAX
def mel_filterbank(sr: int, n_fft: int, n_mels: int,
                   fmin: float = 80.0, fmax: Optional[float] = None) -> "np.ndarray":
    """Build a triangular mel filterbank matrix [n_mels, n_fft//2+1].

    Args:
        sr: Sample rate in Hz.
        n_fft: FFT size; the matrix has ``n_fft // 2 + 1`` columns, one per
            ``np.fft.rfftfreq`` bin.
        n_mels: Number of triangular filters (rows).
        fmin: Lower edge of the first filter in Hz.
        fmax: Upper edge of the last filter in Hz; defaults to ``sr / 2``.

    Returns:
        float32 array of filter weights. Each row rises linearly from 0 at its
        lower edge to 1 at its centre and falls back to 0 at its upper edge;
        bins outside a filter's span are 0.
    """
    if fmax is None:
        fmax = sr / 2.0

    def hz_to_mel(hz: float) -> float:
        return 2595.0 * math.log10(1.0 + hz / 700.0)

    def mel_to_hz(mel: float) -> float:
        return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)

    # n_mels + 2 edge points, evenly spaced on the mel scale.
    mel_pts = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    hz_pts = np.array([mel_to_hz(m) for m in mel_pts])
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)

    # Column vectors of per-filter edges broadcast against the row of bin
    # frequencies, replacing the per-(filter, bin) Python loop.
    lo = hz_pts[:-2, np.newaxis]
    center = hz_pts[1:-1, np.newaxis]
    hi = hz_pts[2:, np.newaxis]
    f = freqs[np.newaxis, :]

    rising = (lo <= f) & (f < center) & (center > lo)
    falling = (center <= f) & (f <= hi) & (hi > center)

    # Degenerate (zero-width) slopes are masked out above; substitute a unit
    # denominator so the division itself never hits zero.
    rise_den = np.where(center > lo, center - lo, 1.0)
    fall_den = np.where(hi > center, hi - center, 1.0)

    weights = np.where(
        rising,
        (f - lo) / rise_den,
        np.where(falling, (hi - f) / fall_den, 0.0),
    )
    return weights.astype(np.float32)
def compute_log_mel(samples: "np.ndarray", sr: int,
                    n_fft: int = 512, n_mels: int = 40,
                    hop: int = 256) -> "np.ndarray":
    """Compute the log-mel spectrogram [n_mels, T] of *samples*.

    The signal is cut into Hann-windowed frames of ``n_fft`` samples, advanced
    by ``hop`` samples. A signal shorter than one frame still yields a single
    zero-padded frame. Mel energies are floored at 1e-10 before the natural
    log is taken.
    """
    total = len(samples)
    taper = np.hanning(n_fft).astype(np.float32)
    # Always visit at least offset 0, even for inputs shorter than n_fft.
    stop = max(total - n_fft + 1, 1)

    def _frame_power(offset: int) -> "np.ndarray":
        # Power spectrum of one windowed frame, right-padded with zeros if short.
        segment = samples[offset:offset + n_fft]
        shortfall = n_fft - len(segment)
        if shortfall > 0:
            segment = np.pad(segment, (0, shortfall))
        return np.abs(np.fft.rfft(segment * taper)) ** 2

    spectra = [_frame_power(offset) for offset in range(0, stop, hop)]
    power = np.array(spectra, dtype=np.float32).T  # [bins, T]
    mel_energy = mel_filterbank(sr, n_fft, n_mels) @ power  # [n_mels, T]
    floored = np.where(mel_energy > 1e-10, mel_energy, 1e-10)
    return np.log(floored)
def cosine_sim(a: "np.ndarray", b: "np.ndarray") -> float:
    """Return the cosine similarity of two arrays.

    Both inputs are flattened and converted to float64; the longer one is
    truncated to the length of the shorter. Returns 0.0 when either input is
    empty or when the product of the two norms is below 1e-12.
    """
    vec_a, vec_b = (arr.flatten().astype(np.float64) for arr in (a, b))
    common = min(len(vec_a), len(vec_b))
    if common == 0:
        return 0.0

    vec_a, vec_b = vec_a[:common], vec_b[:common]
    norm_product = float(np.linalg.norm(vec_a)) * float(np.linalg.norm(vec_b))
    if norm_product < 1e-12:
        return 0.0
    return float(np.dot(vec_a, vec_b) / norm_product)
def rms(samples: "np.ndarray") -> float:
    """Return the root-mean-square amplitude of *samples* (0.0 if empty)."""
    if not len(samples):
        return 0.0
    # Accumulate in float64 regardless of the input dtype.
    squared = np.square(samples.astype(np.float64))
    return float(np.sqrt(np.mean(squared)))
# ── Ring buffer ────────────────────────────────────────────────────────────────
class AudioRingBuffer:
    """Fixed-capacity sliding window over the most recent float audio samples.

    Samples are stored as float32 in a circular numpy array, so a push costs a
    slice copy and ``get_window`` copies only the requested samples.

    The class performs no synchronisation of its own; callers that share an
    instance between threads must serialise access themselves.

    Args:
        max_samples: Capacity in samples. Once full, each push evicts the
            oldest samples. Must be non-negative.
    """

    def __init__(self, max_samples: int) -> None:
        if max_samples < 0:
            raise ValueError("maxlen must be non-negative")
        self._cap = int(max_samples)
        # Storage is allocated on the first push so that constructing the
        # buffer never touches numpy.
        self._data = None
        self._head = 0  # index where the next sample will be written
        self._size = 0  # number of valid samples currently held

    def push(self, samples: "np.ndarray") -> None:
        """Append *samples*, evicting the oldest data once capacity is reached."""
        chunk = np.asarray(samples, dtype=np.float32).reshape(-1)
        cap = self._cap
        if cap == 0 or chunk.size == 0:
            return
        if self._data is None:
            self._data = np.zeros(cap, dtype=np.float32)

        if chunk.size >= cap:
            # The chunk alone fills the buffer: keep only its newest samples.
            self._data[:] = chunk[-cap:]
            self._head = 0
            self._size = cap
            return

        end = self._head + chunk.size
        if end <= cap:
            self._data[self._head:end] = chunk
        else:
            # Write wraps past the end of the storage array.
            split = cap - self._head
            self._data[self._head:] = chunk[:split]
            self._data[:end - cap] = chunk[split:]
        self._head = end % cap
        self._size = min(cap, self._size + chunk.size)

    def get_window(self, n_samples: int) -> Optional["np.ndarray"]:
        """Return the last *n_samples* as a new float32 array.

        Returns None when fewer than *n_samples* samples are buffered, and an
        empty array when *n_samples* is zero or negative.
        """
        if self._size < n_samples:
            return None
        if n_samples <= 0:
            return np.zeros(0, dtype=np.float32)

        cap = self._cap
        start = (self._head - n_samples) % cap
        stop = start + n_samples
        if stop <= cap:
            return self._data[start:stop].copy()
        # The window straddles the wrap point: stitch tail and head together.
        return np.concatenate((self._data[start:], self._data[:stop - cap]))

    def __len__(self) -> int:
        return self._size
# ── Detector ───────────────────────────────────────────────────────────────────
class WakeWordDetector:
    """Wake-word detector: an RMS energy gate followed by template matching.

    Args:
        template: Log-mel feature array of the wake word, or None. With no
            template the detector is passive and never reports a detection.
        energy_threshold: Minimum RMS required before features are compared.
        match_threshold: Minimum cosine similarity required to fire.
        sample_rate: Expected audio sample rate (Hz).
        n_fft: FFT size for mel computation.
        n_mels: Number of mel bands.
    """

    def __init__(self,
                 template: Optional["np.ndarray"],
                 energy_threshold: float = 0.02,
                 match_threshold: float = 0.82,
                 sample_rate: int = 16000,
                 n_fft: int = 512,
                 n_mels: int = 40) -> None:
        self._template = template
        self._energy_thr = energy_threshold
        self._match_thr = match_threshold
        self._sr = sample_rate
        self._n_fft = n_fft
        self._n_mels = n_mels

    @property
    def has_template(self) -> bool:
        """True when a wake-word template was supplied."""
        return self._template is not None

    # ------------------------------------------------------------------
    def detect(self, samples: "np.ndarray") -> Tuple[bool, float, float]:
        """Evaluate one window of float samples.

        Returns
        -------
        (detected, rms_value, similarity)
            ``similarity`` is 0.0 whenever matching was skipped, i.e. the
            window is below the energy gate or no template is loaded.
        """
        level = rms(samples)
        too_quiet = level < self._energy_thr
        if too_quiet or self._template is None:
            return False, level, 0.0

        # Frames overlap by half an FFT window (never less than one sample).
        frame_hop = max(1, self._n_fft // 2)
        features = compute_log_mel(samples, self._sr, self._n_fft,
                                   self._n_mels, frame_hop)
        score = cosine_sim(features, self._template)
        fired = score >= self._match_thr
        return fired, level, score
# ── ROS2 node ──────────────────────────────────────────────────────────────────
class WakeWordNode(Node):
    """ROS2 node: 'hey salty' wake-word detection via energy + template matching.

    Incoming PCM-16 chunks are decoded and appended to a ring buffer by the
    subscription callback. A periodic timer takes the most recent window from
    that buffer, runs the detector on it and publishes ``Bool(True)`` when the
    detector fires and the cooldown has elapsed. Buffer access from the two
    callbacks is serialised with a lock.
    """

    def __init__(self) -> None:
        super().__init__("wake_word_node")

        self.declare_parameter("audio_topic", "/social/speech/audio_raw")
        self.declare_parameter("output_topic", "/saltybot/wake_word_detected")
        self.declare_parameter("sample_rate", 16000)
        self.declare_parameter("window_s", 1.5)
        self.declare_parameter("hop_s", 0.1)
        self.declare_parameter("energy_threshold", 0.02)
        self.declare_parameter("match_threshold", 0.82)
        self.declare_parameter("cooldown_s", 2.0)
        self.declare_parameter("template_path", "")
        self.declare_parameter("n_fft", 512)
        self.declare_parameter("n_mels", 40)

        audio_topic = self.get_parameter("audio_topic").value
        output_topic = self.get_parameter("output_topic").value
        self._sr = int(self.get_parameter("sample_rate").value)
        self._win_s = float(self.get_parameter("window_s").value)
        hop_s = float(self.get_parameter("hop_s").value)
        energy_thr = float(self.get_parameter("energy_threshold").value)
        match_thr = float(self.get_parameter("match_threshold").value)
        self._cool_s = float(self.get_parameter("cooldown_s").value)
        tmpl_path = str(self.get_parameter("template_path").value)
        n_fft = int(self.get_parameter("n_fft").value)
        n_mels = int(self.get_parameter("n_mels").value)

        # ── Load template ──────────────────────────────────────────────
        template: Optional["np.ndarray"] = None
        if tmpl_path and not _NP:
            # Make the downgrade visible instead of silently ignoring the path.
            self.get_logger().warn(
                f"WakeWord: numpy unavailable, cannot load template "
                f"'{tmpl_path}' — passive mode"
            )
        elif tmpl_path:
            try:
                template = np.load(tmpl_path)
                self.get_logger().info(
                    f"WakeWord: loaded template {tmpl_path} shape={template.shape}"
                )
            except Exception as exc:
                # Best effort: a missing or unreadable template leaves the
                # node running in passive mode rather than crashing it.
                self.get_logger().warn(
                    f"WakeWord: could not load template '{tmpl_path}': {exc} — passive mode"
                )

        # ── Ring buffer ────────────────────────────────────────────────
        max_samples = int(self._sr * self._win_s * 4)  # 4× headroom
        self._win_n = int(self._sr * self._win_s)
        self._buf = AudioRingBuffer(max_samples)

        # ── Detector ───────────────────────────────────────────────────
        self._detector = WakeWordDetector(template, energy_thr, match_thr,
                                          self._sr, n_fft, n_mels)

        # ── State ──────────────────────────────────────────────────────
        # "Never detected" sentinel. time.monotonic() has no defined origin,
        # so a start value of 0.0 could suppress the first detection whenever
        # the clock reading is still below cooldown_s.
        self._last_det_t: float = -math.inf
        self._lock = threading.Lock()

        # ── ROS ────────────────────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._pub = self.create_publisher(Bool, output_topic, qos)
        self._sub = self.create_subscription(
            UInt8MultiArray, audio_topic, self._on_audio, qos
        )
        self._timer = self.create_timer(hop_s, self._detection_cb)

        tmpl_status = (f"template={tmpl_path}" if template is not None
                       else "passive (no template)")
        self.get_logger().info(
            f"WakeWordNode ready — {tmpl_status}, "
            f"window={self._win_s}s, hop={hop_s}s, "
            f"energy_thr={energy_thr}, match_thr={match_thr}"
        )

    # ── Subscription ───────────────────────────────────────────────────
    def _on_audio(self, msg) -> None:
        """Decode one raw PCM-16 message and append it to the ring buffer."""
        if not _NP:
            return
        try:
            raw = bytes(msg.data)
            samples = pcm16_to_float(raw)
        except Exception as exc:
            self.get_logger().warn(f"WakeWord: audio decode error: {exc}")
            return
        with self._lock:
            self._buf.push(samples)

    # ── Detection timer ────────────────────────────────────────────────
    def _detection_cb(self) -> None:
        """Run the detector on the latest window; publish True on a detection."""
        if not _NP:
            return
        now = time.monotonic()
        # Hold the lock only while copying the window out; detection itself
        # runs on the private copy.
        with self._lock:
            window = self._buf.get_window(self._win_n)
        if window is None:
            return

        detected, energy, sim = self._detector.detect(window)
        if detected and (now - self._last_det_t) >= self._cool_s:
            self._last_det_t = now
            self.get_logger().info(
                f"WakeWord: 'hey salty' detected "
                f"(rms={energy:.4f}, sim={sim:.3f})"
            )
            out = Bool()
            out.data = True
            self._pub.publish(out)
def main(args=None) -> None:
    """Entry point: initialise rclpy, spin the wake-word node, then clean up.

    Ctrl-C (KeyboardInterrupt) is treated as a normal shutdown request.
    """
    rclpy.init(args=args)
    node = WakeWordNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        # The context may already have been shut down (e.g. by rclpy's own
        # signal handling on SIGINT); shutting it down a second time raises,
        # so only do it while the context is still valid.
        if rclpy.ok():
            rclpy.shutdown()

View File

@ -57,6 +57,8 @@ setup(
'topic_memory_node = saltybot_social.topic_memory_node:main', 'topic_memory_node = saltybot_social.topic_memory_node:main',
# Personal space respector (Issue #310) # Personal space respector (Issue #310)
'personal_space_node = saltybot_social.personal_space_node:main', 'personal_space_node = saltybot_social.personal_space_node:main',
# Audio wake-word detector — 'hey salty' (Issue #320)
'wake_word_node = saltybot_social.wake_word_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,711 @@
"""test_wake_word.py — Offline tests for wake_word_node (Issue #320).
Stubs out rclpy and ROS message types so tests run without a ROS install.
numpy is required (standard on the Jetson).
"""
import importlib.util
import math
import struct
import sys
import time
import types
import unittest
import numpy as np
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
def _make_ros_stubs():
    """Register stand-in ``rclpy`` / ``std_msgs`` modules and return the fakes.

    Any of the required module names missing from ``sys.modules`` is replaced
    by an empty module object; the fake ``Node``, ``QoSProfile``, ``Bool`` and
    ``UInt8MultiArray`` types and no-op ``init`` / ``spin`` / ``shutdown``
    functions are then attached to those modules.

    Returns:
        ``(_Node, _FakePub, _Bool, _UInt8MultiArray)``.
    """
    for stub_name in ("rclpy", "rclpy.node", "rclpy.qos",
                      "std_msgs", "std_msgs.msg"):
        if stub_name not in sys.modules:
            sys.modules[stub_name] = types.ModuleType(stub_name)

    class _FakePub:
        """Publisher that records every published message in ``msgs``."""

        def __init__(self):
            self.msgs = []

        def publish(self, msg):
            self.msgs.append(msg)

    class _ParamValue:
        """Minimal parameter wrapper exposing ``.value``."""

        def __init__(self, value):
            self.value = value

    class _Logger:
        """Logger that appends ``(LEVEL, message)`` tuples to the node's log."""

        def __init__(self, owner):
            self._owner = owner

        def info(self, m):
            self._owner._logs.append(("INFO", m))

        def warn(self, m):
            self._owner._logs.append(("WARN", m))

        def error(self, m):
            self._owner._logs.append(("ERROR", m))

    class _Node:
        """Fake rclpy Node recording publishers, subscriptions, timers, logs."""

        def __init__(self, name="node"):
            self._name = name
            # Parameters assigned before __init__ runs are preserved.
            if not hasattr(self, "_params"):
                self._params = {}
            self._pubs = {}
            self._subs = {}
            self._logs = []
            self._timers = []

        def declare_parameter(self, name, default):
            # The first declaration wins; existing values are not overwritten.
            self._params.setdefault(name, default)

        def get_parameter(self, name):
            return _ParamValue(self._params.get(name))

        def create_publisher(self, msg_type, topic, qos):
            publisher = _FakePub()
            self._pubs[topic] = publisher
            return publisher

        def create_subscription(self, msg_type, topic, cb, qos):
            self._subs[topic] = cb
            return object()

        def create_timer(self, period, cb):
            self._timers.append(cb)
            return object()

        def get_logger(self):
            return _Logger(self)

        def destroy_node(self):
            pass

    class _QoSProfile:
        def __init__(self, depth=10):
            self.depth = depth

    class _Bool:
        def __init__(self):
            self.data = False

    class _UInt8MultiArray:
        def __init__(self):
            self.data = b""

    rclpy_stub = sys.modules["rclpy"]
    rclpy_stub.init = lambda args=None: None
    rclpy_stub.spin = lambda node: None
    rclpy_stub.shutdown = lambda: None

    sys.modules["rclpy.node"].Node = _Node
    sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
    msg_stub = sys.modules["std_msgs.msg"]
    msg_stub.Bool = _Bool
    msg_stub.UInt8MultiArray = _UInt8MultiArray

    return _Node, _FakePub, _Bool, _UInt8MultiArray


_Node, _FakePub, _Bool, _UInt8MultiArray = _make_ros_stubs()
# ── Module loader ─────────────────────────────────────────────────────────────
# Original developer-machine location, kept as a last-resort fallback.
_LEGACY_SRC = (
    "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
    "saltybot_social/saltybot_social/wake_word_node.py"
)


def _resolve_src() -> str:
    """Locate wake_word_node.py without relying on one developer's home dir.

    Search order: the ``WAKE_WORD_NODE_SRC`` environment variable, then a path
    relative to this test file, then the legacy absolute path.
    """
    import os

    override = os.environ.get("WAKE_WORD_NODE_SRC")
    if override:
        return override
    here = os.path.dirname(os.path.abspath(__file__))
    # assumes this test lives in <pkg>/test/ beside <pkg>/saltybot_social/ —
    # TODO confirm against the repository layout
    local = os.path.normpath(
        os.path.join(here, os.pardir, "saltybot_social", "wake_word_node.py")
    )
    return local if os.path.isfile(local) else _LEGACY_SRC


_SRC = _resolve_src()


def _load_mod():
    """Load a fresh copy of the node module from ``_SRC`` and return it."""
    spec = importlib.util.spec_from_file_location("wake_word_testmod", _SRC)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
def _make_node(mod, **kwargs):
    """Build a ``WakeWordNode`` from *mod* with test-controlled parameters.

    The instance is created without running ``__init__``, its ``_params`` dict
    is pre-seeded with the defaults below (overridden by *kwargs*), and only
    then is ``__init__`` invoked.
    """
    params = {
        "audio_topic": "/social/speech/audio_raw",
        "output_topic": "/saltybot/wake_word_detected",
        "sample_rate": 16000,
        "window_s": 1.5,
        "hop_s": 0.1,
        "energy_threshold": 0.02,
        "match_threshold": 0.82,
        "cooldown_s": 2.0,
        "template_path": "",
        "n_fft": 512,
        "n_mels": 40,
        **kwargs,
    }
    node_cls = mod.WakeWordNode
    node = node_cls.__new__(node_cls)
    node._params = params
    node_cls.__init__(node)
    return node
def _make_pcm_bytes(samples: np.ndarray) -> bytes:
    """Encode a float array in [-1, 1] as PCM-16 little-endian bytes.

    Values are scaled by 32768 and clipped to the int16 range, so +1.0 maps
    to 32767.
    """
    scaled = np.clip(samples * 32768.0, -32768, 32767)
    # "<i2" pins little-endian output; plain np.int16 is native-endian and
    # would contradict the documented format on a big-endian host.
    return scaled.astype("<i2").tobytes()
def _make_audio_msg(samples: np.ndarray) -> _UInt8MultiArray:
    """Wrap *samples*, encoded as PCM-16 bytes, in a fake UInt8MultiArray."""
    msg = _UInt8MultiArray()
    payload = _make_pcm_bytes(samples)
    msg.data = payload
    return msg
def _sine(freq: float, duration: float, sr: int = 16000,
          amp: float = 0.5) -> np.ndarray:
    """Return a float32 mono sine of *freq* Hz lasting *duration* seconds."""
    n_samples = int(sr * duration)
    timeline = np.arange(n_samples) / sr
    phase = 2 * math.pi * freq * timeline
    wave = amp * np.sin(phase)
    return wave.astype(np.float32)
def _silence(duration: float, sr: int = 16000) -> np.ndarray:
    """Return *duration* seconds of float32 zeros at sample rate *sr*."""
    n_samples = int(sr * duration)
    return np.zeros(n_samples, dtype=np.float32)
def _make_template(mod, sr: int = 16000, n_fft: int = 512,
                   n_mels: int = 40) -> np.ndarray:
    """Build a test template: log-mel features of a synthetic 'wake word'.

    The stand-in wake word is a 1.5 s, 300 Hz sine at amplitude 0.6, analysed
    with a hop of half the FFT size.
    """
    fake_wake_word = _sine(300, 1.5, sr, amp=0.6)
    half_frame = n_fft // 2
    return mod.compute_log_mel(fake_wake_word, sr, n_fft, n_mels, half_frame)
# ── Tests: pcm16_to_float ─────────────────────────────────────────────────────
class TestPcm16ToFloat(unittest.TestCase):
    """pcm16_to_float: decoding PCM-16 LE bytes into float samples."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()

    def _conv(self, samples):
        """Encode *samples* to PCM-16 bytes, then decode with the function under test."""
        encoded = _make_pcm_bytes(np.array(samples, dtype=np.float32))
        return self.mod.pcm16_to_float(encoded)

    def test_zeros(self):
        decoded = self._conv([0.0, 0.0, 0.0])
        np.testing.assert_allclose(decoded, [0.0, 0.0, 0.0], atol=1e-4)

    def test_positive(self):
        decoded = self._conv([0.5])
        self.assertAlmostEqual(float(decoded[0]), 0.5, places=3)

    def test_negative(self):
        decoded = self._conv([-0.5])
        self.assertAlmostEqual(float(decoded[0]), -0.5, places=3)

    def test_roundtrip_length(self):
        ramp = np.linspace(-1.0, 1.0, 100, dtype=np.float32)
        decoded = self.mod.pcm16_to_float(_make_pcm_bytes(ramp))
        self.assertEqual(len(decoded), 100)

    def test_empty_bytes_returns_empty(self):
        decoded = self.mod.pcm16_to_float(b"")
        self.assertEqual(len(decoded), 0)

    def test_odd_byte_ignored(self):
        # Three bytes hold one complete int16 plus one orphan byte.
        decoded = self.mod.pcm16_to_float(b"\x00\x40\xff")
        self.assertEqual(len(decoded), 1)
# ── Tests: rms ────────────────────────────────────────────────────────────────
class TestRms(unittest.TestCase):
    """rms: root-mean-square amplitude helper."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()

    def test_zero_signal(self):
        flat = np.zeros(100)
        self.assertAlmostEqual(self.mod.rms(flat), 0.0)

    def test_constant_signal(self):
        # A constant signal's RMS equals its level.
        level = self.mod.rms(np.full(100, 0.5))
        self.assertAlmostEqual(level, 0.5, places=5)

    def test_sine_rms(self):
        # A sine's RMS is its amplitude divided by sqrt(2).
        amp = 0.8
        tone = _sine(440, 1.0, amp=amp)
        self.assertAlmostEqual(self.mod.rms(tone), amp / math.sqrt(2), places=2)

    def test_empty_array(self):
        nothing = np.array([])
        self.assertEqual(self.mod.rms(nothing), 0.0)
# ── Tests: mel_filterbank ─────────────────────────────────────────────────────
class TestMelFilterbank(unittest.TestCase):
    """mel_filterbank: shape and basic sanity of the filter matrix."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()

    def _default_bank(self):
        """Filterbank for 16 kHz audio, 512-point FFT, 40 bands."""
        return self.mod.mel_filterbank(16000, 512, 40)

    def test_shape(self):
        self.assertEqual(self._default_bank().shape, (40, 257))

    def test_non_negative(self):
        bank = self._default_bank()
        self.assertTrue((bank >= 0).all())

    def test_rows_sum_positive(self):
        row_totals = self._default_bank().sum(axis=1)
        self.assertTrue((row_totals > 0).all())

    def test_custom_n_mels(self):
        bank = self.mod.mel_filterbank(16000, 256, 20)
        self.assertEqual(bank.shape[0], 20)
# ── Tests: compute_log_mel ────────────────────────────────────────────────────
class TestComputeLogMel(unittest.TestCase):
    """compute_log_mel: output shape, finiteness and edge cases."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()

    def _features(self, signal):
        """Log-mel features of *signal* with the settings used throughout."""
        return self.mod.compute_log_mel(signal, 16000, n_fft=512, n_mels=40, hop=256)

    def test_output_shape_rows(self):
        feats = self._features(_sine(440, 1.5))
        self.assertEqual(feats.shape[0], 40)

    def test_output_has_time_axis(self):
        feats = self._features(_sine(440, 1.5))
        self.assertGreater(feats.shape[1], 0)

    def test_output_finite(self):
        feats = self._features(_sine(440, 1.5))
        self.assertTrue(np.isfinite(feats).all())

    def test_silence_gives_low_values(self):
        feats = self._features(_silence(1.5))
        # Every value should sit near the floor, log(1e-10).
        self.assertTrue((feats < -20).all())

    def test_short_signal_no_crash(self):
        # Input shorter than a single FFT frame.
        stub = np.zeros(100, dtype=np.float32)
        feats = self._features(stub)
        self.assertEqual(feats.shape[0], 40)
# ── Tests: cosine_sim ─────────────────────────────────────────────────────────
class TestCosineSim(unittest.TestCase):
    """cosine_sim: similarity values for canonical vector pairs."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()

    def _sim(self, left, right):
        return self.mod.cosine_sim(left, right)

    def test_identical_vectors(self):
        vec = np.array([1.0, 2.0, 3.0])
        self.assertAlmostEqual(self._sim(vec, vec), 1.0, places=5)

    def test_orthogonal_vectors(self):
        x_axis = np.array([1.0, 0.0])
        y_axis = np.array([0.0, 1.0])
        self.assertAlmostEqual(self._sim(x_axis, y_axis), 0.0, places=5)

    def test_opposite_vectors(self):
        vec = np.array([1.0, 2.0, 3.0])
        self.assertAlmostEqual(self._sim(vec, -vec), -1.0, places=5)

    def test_zero_vector_returns_zero(self):
        self.assertEqual(self._sim(np.zeros(5), np.ones(5)), 0.0)

    def test_2d_arrays(self):
        left = np.ones((4, 10))
        right = np.ones((4, 10))
        self.assertAlmostEqual(self._sim(left, right), 1.0, places=5)

    def test_different_lengths_truncated(self):
        longer = np.array([1.0, 2.0, 3.0, 4.0])
        shorter = np.array([1.0, 2.0, 3.0])
        # Must not crash; the comparison uses the shorter length.
        score = self._sim(longer, shorter)
        self.assertTrue(-1.0 <= score <= 1.0)

    def test_range_is_bounded(self):
        rng = np.random.default_rng(42)
        left = rng.standard_normal(100)
        right = rng.standard_normal(100)
        score = self._sim(left, right)
        self.assertGreaterEqual(score, -1.0)
        self.assertLessEqual(score, 1.0)
# ── Tests: AudioRingBuffer ────────────────────────────────────────────────────
class TestAudioRingBuffer(unittest.TestCase):
    """AudioRingBuffer: length tracking, windowing and eviction."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()

    def _buf(self, max_samples=1000):
        return self.mod.AudioRingBuffer(max_samples)

    def test_empty_initially(self):
        self.assertEqual(len(self._buf()), 0)

    def test_push_increases_len(self):
        ring = self._buf()
        ring.push(np.ones(100, dtype=np.float32))
        self.assertEqual(len(ring), 100)

    def test_get_window_none_when_short(self):
        ring = self._buf()
        ring.push(np.ones(50, dtype=np.float32))
        self.assertIsNone(ring.get_window(100))

    def test_get_window_ok_when_full(self):
        ring = self._buf()
        ring.push(np.arange(200, dtype=np.float32))
        window = ring.get_window(100)
        self.assertIsNotNone(window)
        self.assertEqual(len(window), 100)

    def test_get_window_returns_latest(self):
        ring = self._buf()
        for block in (np.zeros(100, dtype=np.float32),
                      np.ones(100, dtype=np.float32)):
            ring.push(block)
        window = ring.get_window(100)
        np.testing.assert_allclose(window, np.ones(100))

    def test_maxlen_evicts_oldest(self):
        ring = self._buf(max_samples=100)
        ring.push(np.zeros(60, dtype=np.float32))
        ring.push(np.ones(60, dtype=np.float32))  # should evict 20 zeros
        self.assertEqual(len(ring), 100)
        window = ring.get_window(100)
        # The newest 40 samples must all be ones.
        np.testing.assert_allclose(window[-40:], np.ones(40))

    def test_exact_window_size(self):
        ring = self._buf(500)
        ramp = np.arange(300, dtype=np.float32)
        ring.push(ramp)
        np.testing.assert_allclose(ring.get_window(300), ramp)
# ── Tests: WakeWordDetector ───────────────────────────────────────────────────
class TestWakeWordDetector(unittest.TestCase):
    """WakeWordDetector: energy gate, template matching and thresholds."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()
        cls.sr = 16000
        cls.template = _make_template(cls.mod, cls.sr)

    def _det(self, template=None, energy_thr=0.02, match_thr=0.82):
        """Detector at the class sample rate with a 512-point FFT and 40 bands."""
        return self.mod.WakeWordDetector(template, energy_thr, match_thr,
                                         self.sr, 512, 40)

    def _wake_tone(self):
        """The same 300 Hz / 0.6 amplitude tone the template is built from."""
        return _sine(300, 1.5, self.sr, amp=0.6)

    def test_no_template_never_detects(self):
        detector = self._det(template=None)
        fired, _, _ = detector.detect(_sine(300, 1.5, self.sr, amp=0.8))
        self.assertFalse(fired)

    def test_has_template_false_when_no_template(self):
        self.assertFalse(self._det(template=None).has_template)

    def test_has_template_true_when_set(self):
        self.assertTrue(self._det(template=self.template).has_template)

    def test_silence_below_energy_gate(self):
        detector = self._det(template=self.template, energy_thr=0.02)
        fired, level, score = detector.detect(_silence(1.5, self.sr))
        self.assertFalse(fired)
        self.assertAlmostEqual(level, 0.0, places=4)
        self.assertAlmostEqual(score, 0.0, places=4)

    def test_returns_rms_value(self):
        detector = self._det(template=None)
        _, level, _ = detector.detect(_sine(300, 1.5, self.sr, amp=0.5))
        self.assertAlmostEqual(level, 0.5 / math.sqrt(2), places=2)

    def test_identical_signal_detects(self):
        """Signal identical to template should give sim ≈ 1.0 → detect."""
        detector = self._det(template=self.template, match_thr=0.99)
        fired, _, score = detector.detect(self._wake_tone())
        # A signal from the same source as the template must score very high.
        self.assertGreater(score, 0.99)
        self.assertTrue(fired)

    def test_different_signal_low_sim(self):
        """A very different signal (white noise) should have low similarity."""
        rng = np.random.default_rng(7)
        n_samples = int(self.sr * 1.5)
        noise = (rng.standard_normal(n_samples) * 0.4).astype(np.float32)
        detector = self._det(template=self.template, match_thr=0.82)
        _, _, score = detector.detect(noise)
        # White noise against a tonal template should score below 0.6.
        self.assertLess(score, 0.6)

    def test_threshold_boundary_low(self):
        """Setting match_thr=0.0 with a loud signal should fire if template set."""
        detector = self._det(template=self.template, match_thr=0.0)
        fired, _, _ = detector.detect(self._wake_tone())
        self.assertTrue(fired)

    def test_threshold_boundary_high(self):
        """Setting match_thr=1.1 (above max) should never fire."""
        detector = self._det(template=self.template, match_thr=1.1)
        fired, _, _ = detector.detect(self._wake_tone())
        self.assertFalse(fired)

    def test_energy_below_threshold_skips_matching(self):
        """Low energy → sim returned as 0.0 regardless of template."""
        whisper = _sine(300, 1.5, self.sr, amp=0.001)
        detector = self._det(template=self.template, energy_thr=0.1)
        fired, _, score = detector.detect(whisper)
        self.assertFalse(fired)
        self.assertAlmostEqual(score, 0.0, places=5)
# ── Tests: node init ──────────────────────────────────────────────────────────
class TestNodeInit(unittest.TestCase):
    """WakeWordNode construction: wiring of topics, timer, buffer, template."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()

    def _node(self, **params):
        return _make_node(self.mod, **params)

    def test_instantiates(self):
        self.assertIsNotNone(self._node())

    def test_pub_registered(self):
        self.assertIn("/saltybot/wake_word_detected", self._node()._pubs)

    def test_sub_registered(self):
        self.assertIn("/social/speech/audio_raw", self._node()._subs)

    def test_timer_registered(self):
        self.assertGreater(len(self._node()._timers), 0)

    def test_custom_topics(self):
        node = self._node(audio_topic="/my/audio", output_topic="/my/wake")
        self.assertIn("/my/audio", node._subs)
        self.assertIn("/my/wake", node._pubs)

    def test_no_template_passive(self):
        node = self._node(template_path="")
        self.assertFalse(node._detector.has_template)

    def test_bad_template_path_warns(self):
        node = self._node(template_path="/nonexistent/template.npy")
        warnings = [text for level, text in node._logs if level == "WARN"]
        mentions_failure = any(
            "template" in text.lower() or "passive" in text.lower()
            for text in warnings
        )
        self.assertTrue(mentions_failure)

    def test_ring_buffer_allocated(self):
        self.assertIsNotNone(self._node()._buf)

    def test_window_n_computed(self):
        node = self._node(sample_rate=16000, window_s=1.5)
        self.assertEqual(node._win_n, 24000)
# ── Tests: _on_audio callback ─────────────────────────────────────────────────
class TestOnAudio(unittest.TestCase):
    """_on_audio: incoming messages are decoded into the ring buffer."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()

    def setUp(self):
        self.node = _make_node(self.mod)

    def _deliver(self, msg):
        """Invoke the node's registered audio subscription callback."""
        self.node._subs["/social/speech/audio_raw"](msg)

    def _push(self, samples):
        self._deliver(_make_audio_msg(samples))

    def test_pushes_samples_to_buffer(self):
        self._push(_sine(440, 0.5))
        self.assertGreater(len(self.node._buf), 0)

    def test_buffer_grows_with_pushes(self):
        size_before = len(self.node._buf)
        self._push(_sine(440, 0.1))
        size_after = len(self.node._buf)
        self.assertGreater(size_after, size_before)

    def test_bad_data_no_crash(self):
        lone_byte = _UInt8MultiArray()
        lone_byte.data = b"\xff"  # 1 orphan byte — yields 0 samples, no crash
        self._deliver(lone_byte)

    def test_multiple_chunks_accumulate(self):
        for _ in range(5):
            self._push(_sine(440, 0.1))
        self.assertGreater(len(self.node._buf), 0)
# ── Tests: detection callback ─────────────────────────────────────────────────
class TestDetectionCallback(unittest.TestCase):
    """_detection_cb: publishing, cooldown and passive-mode behaviour."""

    @classmethod
    def setUpClass(cls):
        cls.mod = _load_mod()

    def _node_with_template(self, **kwargs):
        """Node whose detector has a template and fires on any loud window."""
        template = _make_template(self.mod)
        node = _make_node(self.mod, **kwargs)
        node._detector = self.mod.WakeWordDetector(
            template,
            energy_threshold=0.02,
            match_threshold=0.0,  # thr=0 → always fires when loud
            sample_rate=16000,
            n_fft=512,
            n_mels=40,
        )
        return node

    def _fill_buffer(self, node, signal):
        """Feed *signal* to the node through its audio subscription callback."""
        node._subs["/social/speech/audio_raw"](_make_audio_msg(signal))

    @staticmethod
    def _published(node):
        """Messages published so far on the detection topic."""
        return node._pubs["/saltybot/wake_word_detected"].msgs

    def test_no_data_no_publish(self):
        node = _make_node(self.mod)
        node._detection_cb()
        self.assertEqual(len(self._published(node)), 0)

    def test_insufficient_buffer_no_publish(self):
        node = self._node_with_template()
        # Only 0.1 s is buffered while the window needs 1.5 s.
        self._fill_buffer(node, _sine(300, 0.1))
        node._detection_cb()
        self.assertEqual(len(self._published(node)), 0)

    def test_detects_and_publishes_true(self):
        node = self._node_with_template()
        # A loud 300 Hz sine matches the template.
        self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
        node._detection_cb()
        sent = self._published(node)
        self.assertEqual(len(sent), 1)
        self.assertTrue(sent[0].data)

    def test_cooldown_suppresses_second_detection(self):
        node = self._node_with_template(cooldown_s=60.0)
        self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
        node._detection_cb()
        # The immediate second attempt falls inside the cooldown.
        self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
        node._detection_cb()
        self.assertEqual(len(self._published(node)), 1)

    def test_cooldown_expired_allows_second(self):
        node = self._node_with_template(cooldown_s=0.0)
        for _ in range(2):
            self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
            node._detection_cb()
        self.assertEqual(len(self._published(node)), 2)

    def test_no_template_never_publishes(self):
        node = _make_node(self.mod, template_path="")
        self._fill_buffer(node, _sine(300, 1.5, amp=0.8))
        node._detection_cb()
        self.assertEqual(len(self._published(node)), 0)

    def test_silence_no_publish(self):
        node = self._node_with_template()
        self._fill_buffer(node, _silence(1.5))
        node._detection_cb()
        self.assertEqual(len(self._published(node)), 0)

    def test_detection_logs_info(self):
        node = self._node_with_template()
        self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
        node._detection_cb()
        info_lines = [text for level, text in node._logs if level == "INFO"]
        announced = any(
            "detected" in text.lower() or "hey salty" in text.lower()
            for text in info_lines
        )
        self.assertTrue(announced)
# ── Tests: source content ─────────────────────────────────────────────────────
class TestNodeSrc(unittest.TestCase):
    """Smoke checks that the node source text mentions its key identifiers."""

    @classmethod
    def setUpClass(cls):
        # The node source contains non-ASCII characters, so decode explicitly
        # as UTF-8 instead of relying on the platform's default encoding.
        with open(_SRC, encoding="utf-8") as f:
            cls.src = f.read()

    def test_issue_tag(self):
        self.assertIn("#320", self.src)

    def test_audio_topic(self):
        self.assertIn("/social/speech/audio_raw", self.src)

    def test_output_topic(self):
        self.assertIn("/saltybot/wake_word_detected", self.src)

    def test_wake_word_name(self):
        self.assertIn("hey salty", self.src)

    def test_compute_log_mel(self):
        self.assertIn("compute_log_mel", self.src)

    def test_cosine_sim(self):
        self.assertIn("cosine_sim", self.src)

    def test_mel_filterbank(self):
        self.assertIn("mel_filterbank", self.src)

    def test_audio_ring_buffer(self):
        self.assertIn("AudioRingBuffer", self.src)

    def test_wake_word_detector(self):
        self.assertIn("WakeWordDetector", self.src)

    def test_energy_threshold(self):
        self.assertIn("energy_threshold", self.src)

    def test_match_threshold(self):
        self.assertIn("match_threshold", self.src)

    def test_cooldown(self):
        self.assertIn("cooldown_s", self.src)

    def test_template_path(self):
        self.assertIn("template_path", self.src)

    def test_pcm16_decode(self):
        self.assertIn("pcm16_to_float", self.src)

    def test_threading_lock(self):
        self.assertIn("threading.Lock", self.src)

    def test_numpy_used(self):
        self.assertIn("import numpy", self.src)

    def test_main_defined(self):
        self.assertIn("def main", self.src)

    def test_uint8_multiarray(self):
        self.assertIn("UInt8MultiArray", self.src)
# ── Tests: config / launch / setup ────────────────────────────────────────────
class TestConfig(unittest.TestCase):
    """Presence and key contents of the config, launch and setup files.

    Each file is looked up relative to this test file first; the original
    absolute developer-machine path is used only as a fallback.
    """

    _CONFIG = (
        "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
        "saltybot_social/config/wake_word_params.yaml"
    )
    _LAUNCH = (
        "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
        "saltybot_social/launch/wake_word.launch.py"
    )
    _SETUP = (
        "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
        "saltybot_social/setup.py"
    )

    @staticmethod
    def _locate(legacy_path, *rel_parts):
        """Return the package-relative path if it exists, else *legacy_path*."""
        import os

        here = os.path.dirname(os.path.abspath(__file__))
        # assumes this test lives one directory below the package root —
        # TODO confirm against the repository layout
        candidate = os.path.normpath(os.path.join(here, os.pardir, *rel_parts))
        return candidate if os.path.exists(candidate) else legacy_path

    @staticmethod
    def _read(path):
        # The config and launch files contain non-ASCII characters, so decode
        # explicitly rather than relying on the platform default encoding.
        with open(path, encoding="utf-8") as fh:
            return fh.read()

    def _config_path(self):
        return self._locate(self._CONFIG, "config", "wake_word_params.yaml")

    def _launch_path(self):
        return self._locate(self._LAUNCH, "launch", "wake_word.launch.py")

    def _setup_path(self):
        return self._locate(self._SETUP, "setup.py")

    def test_config_exists(self):
        import os
        self.assertTrue(os.path.exists(self._config_path()))

    def test_config_energy_threshold(self):
        self.assertIn("energy_threshold", self._read(self._config_path()))

    def test_config_match_threshold(self):
        self.assertIn("match_threshold", self._read(self._config_path()))

    def test_config_template_path(self):
        self.assertIn("template_path", self._read(self._config_path()))

    def test_config_cooldown(self):
        self.assertIn("cooldown_s", self._read(self._config_path()))

    def test_launch_exists(self):
        import os
        self.assertTrue(os.path.exists(self._launch_path()))

    def test_launch_has_template_arg(self):
        self.assertIn("template_path", self._read(self._launch_path()))

    def test_launch_has_threshold_arg(self):
        self.assertIn("match_threshold", self._read(self._launch_path()))

    def test_entry_point_in_setup(self):
        self.assertIn("wake_word_node", self._read(self._setup_path()))
if __name__ == "__main__":
unittest.main()