feat(social): audio wake-word detector 'hey salty' (Issue #320) #317
@ -0,0 +1,19 @@
|
||||
wake_word_node:
|
||||
ros__parameters:
|
||||
audio_topic: "/social/speech/audio_raw" # PCM-16 mono input (UInt8MultiArray)
|
||||
output_topic: "/saltybot/wake_word_detected"
|
||||
|
||||
sample_rate: 16000 # Hz — must match audio source
|
||||
window_s: 1.5 # detection window length (s)
|
||||
hop_s: 0.1 # detection timer period (s)
|
||||
|
||||
energy_threshold: 0.02 # RMS gate; below this → skip matching
|
||||
match_threshold: 0.82 # cosine-similarity gate; above → detect
|
||||
cooldown_s: 2.0 # minimum gap between successive detections (s)
|
||||
|
||||
# Path to .npy template file (log-mel features of 'hey salty' recording).
|
||||
# Leave empty for passive mode (no detections fired).
|
||||
template_path: "" # e.g. "/opt/saltybot/models/hey_salty.npy"
|
||||
|
||||
n_fft: 512 # FFT size for mel spectrogram
|
||||
n_mels: 40 # mel filterbank bands
|
||||
@ -0,0 +1,43 @@
|
||||
"""wake_word.launch.py — Launch wake-word detector ('hey salty') (Issue #320).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social wake_word.launch.py
|
||||
ros2 launch saltybot_social wake_word.launch.py template_path:=/opt/saltybot/models/hey_salty.npy
|
||||
ros2 launch saltybot_social wake_word.launch.py match_threshold:=0.85 cooldown_s:=3.0
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "wake_word_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("template_path", default_value="",
|
||||
description="Path to .npy template file (log-mel of 'hey salty')"),
|
||||
DeclareLaunchArgument("match_threshold", default_value="0.82",
|
||||
description="Cosine-similarity detection threshold"),
|
||||
DeclareLaunchArgument("cooldown_s", default_value="2.0",
|
||||
description="Minimum seconds between detections"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="wake_word_node",
|
||||
name="wake_word_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"template_path": LaunchConfiguration("template_path"),
|
||||
"match_threshold": LaunchConfiguration("match_threshold"),
|
||||
"cooldown_s": LaunchConfiguration("cooldown_s"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,343 @@
|
||||
"""wake_word_node.py — Audio wake-word detector ('hey salty').
|
||||
Issue #320
|
||||
|
||||
Subscribes to raw PCM-16 audio on /social/speech/audio_raw, maintains a
|
||||
sliding window buffer, and on each hop tick computes log-mel spectrogram
|
||||
features of the most recent window. If energy is above the gate threshold
|
||||
AND cosine similarity to the stored template is above match_threshold the
|
||||
detection fires: Bool(True) is published on /saltybot/wake_word_detected.
|
||||
|
||||
Detection is one-shot (only True is published) and guarded by a cooldown
|
||||
so rapid re-fires are suppressed. When no template is loaded (template_path
|
||||
is empty) the node stays passive — energy gating is applied but no match is
|
||||
attempted.
|
||||
|
||||
Audio format expected
|
||||
─────────────────────
|
||||
UInt8MultiArray, same feed as vad_node:
|
||||
raw PCM-16 little-endian mono at ``sample_rate`` Hz.
|
||||
|
||||
Subscriptions
|
||||
─────────────
|
||||
/social/speech/audio_raw std_msgs/UInt8MultiArray — raw PCM-16 chunks
|
||||
|
||||
Publications
|
||||
────────────
|
||||
/saltybot/wake_word_detected std_msgs/Bool — True on each detection event
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
audio_topic (str, "/social/speech/audio_raw")
|
||||
output_topic (str, "/saltybot/wake_word_detected")
|
||||
sample_rate (int, 16000) sample rate of incoming audio (Hz)
|
||||
window_s (float, 1.5) detection window duration (s)
|
||||
hop_s (float, 0.1) detection timer period (s)
|
||||
energy_threshold (float, 0.02) RMS gate; below this → skip matching
|
||||
match_threshold (float, 0.82) cosine-similarity gate; above → detect
|
||||
cooldown_s (float, 2.0) minimum gap between successive detections
|
||||
template_path (str, "") path to .npy template file; "" = passive
|
||||
n_fft (int, 512) FFT size for mel spectrogram
|
||||
n_mels (int, 40) number of mel filterbank bands
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import Bool, UInt8MultiArray
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
_NP = True
|
||||
except ImportError: # pragma: no cover
|
||||
_NP = False
|
||||
|
||||
INT16_MAX = 32768.0
|
||||
|
||||
|
||||
# ── Pure DSP helpers (no ROS, numpy only) ──────────────────────────────────────
|
||||
|
||||
def pcm16_to_float(data: bytes) -> "np.ndarray":
|
||||
"""Decode PCM-16 LE bytes → float32 ndarray in [-1.0, 1.0]."""
|
||||
n = len(data) // 2
|
||||
if n == 0:
|
||||
return np.zeros(0, dtype=np.float32)
|
||||
samples = struct.unpack(f"<{n}h", data[:n * 2])
|
||||
return np.array(samples, dtype=np.float32) / INT16_MAX
|
||||
|
||||
|
||||
def mel_filterbank(sr: int, n_fft: int, n_mels: int,
|
||||
fmin: float = 80.0, fmax: Optional[float] = None) -> "np.ndarray":
|
||||
"""Build a triangular mel filterbank matrix [n_mels, n_fft//2+1]."""
|
||||
if fmax is None:
|
||||
fmax = sr / 2.0
|
||||
|
||||
def hz_to_mel(hz: float) -> float:
|
||||
return 2595.0 * math.log10(1.0 + hz / 700.0)
|
||||
|
||||
def mel_to_hz(mel: float) -> float:
|
||||
return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
|
||||
|
||||
mel_lo = hz_to_mel(fmin)
|
||||
mel_hi = hz_to_mel(fmax)
|
||||
mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2)
|
||||
hz_pts = np.array([mel_to_hz(m) for m in mel_pts])
|
||||
freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)
|
||||
|
||||
fb = np.zeros((n_mels, len(freqs)), dtype=np.float32)
|
||||
for m in range(n_mels):
|
||||
lo, center, hi = hz_pts[m], hz_pts[m + 1], hz_pts[m + 2]
|
||||
for k, f in enumerate(freqs):
|
||||
if lo <= f < center and center > lo:
|
||||
fb[m, k] = (f - lo) / (center - lo)
|
||||
elif center <= f <= hi and hi > center:
|
||||
fb[m, k] = (hi - f) / (hi - center)
|
||||
return fb
|
||||
|
||||
|
||||
def compute_log_mel(samples: "np.ndarray", sr: int,
|
||||
n_fft: int = 512, n_mels: int = 40,
|
||||
hop: int = 256) -> "np.ndarray":
|
||||
"""Return log-mel spectrogram [n_mels, T] of *samples* (float32 [-1,1])."""
|
||||
n = len(samples)
|
||||
window = np.hanning(n_fft).astype(np.float32)
|
||||
frames = []
|
||||
for start in range(0, max(n - n_fft + 1, 1), hop):
|
||||
chunk = samples[start:start + n_fft]
|
||||
if len(chunk) < n_fft:
|
||||
chunk = np.pad(chunk, (0, n_fft - len(chunk)))
|
||||
power = np.abs(np.fft.rfft(chunk * window)) ** 2
|
||||
frames.append(power)
|
||||
frames_arr = np.array(frames, dtype=np.float32).T # [bins, T]
|
||||
fb = mel_filterbank(sr, n_fft, n_mels)
|
||||
mel = fb @ frames_arr # [n_mels, T]
|
||||
mel = np.where(mel > 1e-10, mel, 1e-10)
|
||||
return np.log(mel)
|
||||
|
||||
|
||||
def cosine_sim(a: "np.ndarray", b: "np.ndarray") -> float:
|
||||
"""Cosine similarity between two arrays, matched by minimum length."""
|
||||
af = a.flatten().astype(np.float64)
|
||||
bf = b.flatten().astype(np.float64)
|
||||
min_len = min(len(af), len(bf))
|
||||
if min_len == 0:
|
||||
return 0.0
|
||||
af = af[:min_len]
|
||||
bf = bf[:min_len]
|
||||
denom = float(np.linalg.norm(af)) * float(np.linalg.norm(bf))
|
||||
if denom < 1e-12:
|
||||
return 0.0
|
||||
return float(np.dot(af, bf) / denom)
|
||||
|
||||
|
||||
def rms(samples: "np.ndarray") -> float:
|
||||
"""RMS amplitude of a float sample array."""
|
||||
if len(samples) == 0:
|
||||
return 0.0
|
||||
return float(np.sqrt(np.mean(samples.astype(np.float64) ** 2)))
|
||||
|
||||
|
||||
# ── Ring buffer ────────────────────────────────────────────────────────────────
|
||||
|
||||
class AudioRingBuffer:
|
||||
"""Lock-free sliding window for raw float audio samples."""
|
||||
|
||||
def __init__(self, max_samples: int) -> None:
|
||||
self._buf: deque = deque(maxlen=max_samples)
|
||||
|
||||
def push(self, samples: "np.ndarray") -> None:
|
||||
self._buf.extend(samples.tolist())
|
||||
|
||||
def get_window(self, n_samples: int) -> Optional["np.ndarray"]:
|
||||
"""Return last n_samples as float32 array, or None if buffer too short."""
|
||||
if len(self._buf) < n_samples:
|
||||
return None
|
||||
return np.array(list(self._buf)[-n_samples:], dtype=np.float32)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._buf)
|
||||
|
||||
|
||||
# ── Detector ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class WakeWordDetector:
|
||||
"""Energy-gated cosine-similarity wake-word detector.
|
||||
|
||||
Args:
|
||||
template: Log-mel feature array of the wake word, or None
|
||||
(passive — never fires when None).
|
||||
energy_threshold: Minimum RMS to proceed to feature matching.
|
||||
match_threshold: Minimum cosine similarity to fire.
|
||||
sample_rate: Expected audio sample rate (Hz).
|
||||
n_fft: FFT size for mel computation.
|
||||
n_mels: Number of mel bands.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
template: Optional["np.ndarray"],
|
||||
energy_threshold: float = 0.02,
|
||||
match_threshold: float = 0.82,
|
||||
sample_rate: int = 16000,
|
||||
n_fft: int = 512,
|
||||
n_mels: int = 40) -> None:
|
||||
self._template = template
|
||||
self._energy_thr = energy_threshold
|
||||
self._match_thr = match_threshold
|
||||
self._sr = sample_rate
|
||||
self._n_fft = n_fft
|
||||
self._n_mels = n_mels
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
def detect(self, samples: "np.ndarray") -> Tuple[bool, float, float]:
|
||||
"""Run detection on a window of float samples.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(detected, rms_value, similarity)
|
||||
"""
|
||||
energy = rms(samples)
|
||||
if energy < self._energy_thr:
|
||||
return False, energy, 0.0
|
||||
if self._template is None:
|
||||
return False, energy, 0.0
|
||||
|
||||
hop = max(1, self._n_fft // 2)
|
||||
feats = compute_log_mel(samples, self._sr, self._n_fft, self._n_mels, hop)
|
||||
sim = cosine_sim(feats, self._template)
|
||||
return sim >= self._match_thr, energy, sim
|
||||
|
||||
@property
|
||||
def has_template(self) -> bool:
|
||||
return self._template is not None
|
||||
|
||||
|
||||
# ── ROS2 node ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class WakeWordNode(Node):
|
||||
"""ROS2 node: 'hey salty' wake-word detection via energy + template matching."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("wake_word_node")
|
||||
|
||||
self.declare_parameter("audio_topic", "/social/speech/audio_raw")
|
||||
self.declare_parameter("output_topic", "/saltybot/wake_word_detected")
|
||||
self.declare_parameter("sample_rate", 16000)
|
||||
self.declare_parameter("window_s", 1.5)
|
||||
self.declare_parameter("hop_s", 0.1)
|
||||
self.declare_parameter("energy_threshold", 0.02)
|
||||
self.declare_parameter("match_threshold", 0.82)
|
||||
self.declare_parameter("cooldown_s", 2.0)
|
||||
self.declare_parameter("template_path", "")
|
||||
self.declare_parameter("n_fft", 512)
|
||||
self.declare_parameter("n_mels", 40)
|
||||
|
||||
audio_topic = self.get_parameter("audio_topic").value
|
||||
output_topic = self.get_parameter("output_topic").value
|
||||
self._sr = int(self.get_parameter("sample_rate").value)
|
||||
self._win_s = float(self.get_parameter("window_s").value)
|
||||
hop_s = float(self.get_parameter("hop_s").value)
|
||||
energy_thr = float(self.get_parameter("energy_threshold").value)
|
||||
match_thr = float(self.get_parameter("match_threshold").value)
|
||||
self._cool_s = float(self.get_parameter("cooldown_s").value)
|
||||
tmpl_path = str(self.get_parameter("template_path").value)
|
||||
n_fft = int(self.get_parameter("n_fft").value)
|
||||
n_mels = int(self.get_parameter("n_mels").value)
|
||||
|
||||
# ── Load template ──────────────────────────────────────────────
|
||||
template: Optional["np.ndarray"] = None
|
||||
if tmpl_path and _NP:
|
||||
try:
|
||||
template = np.load(tmpl_path)
|
||||
self.get_logger().info(
|
||||
f"WakeWord: loaded template {tmpl_path} shape={template.shape}"
|
||||
)
|
||||
except Exception as exc:
|
||||
self.get_logger().warn(
|
||||
f"WakeWord: could not load template '{tmpl_path}': {exc} — passive mode"
|
||||
)
|
||||
|
||||
# ── Ring buffer ────────────────────────────────────────────────
|
||||
max_samples = int(self._sr * self._win_s * 4) # 4× headroom
|
||||
self._win_n = int(self._sr * self._win_s)
|
||||
self._buf = AudioRingBuffer(max_samples)
|
||||
|
||||
# ── Detector ───────────────────────────────────────────────────
|
||||
self._detector = WakeWordDetector(template, energy_thr, match_thr,
|
||||
self._sr, n_fft, n_mels)
|
||||
|
||||
# ── State ──────────────────────────────────────────────────────
|
||||
self._last_det_t: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# ── ROS ────────────────────────────────────────────────────────
|
||||
qos = QoSProfile(depth=10)
|
||||
self._pub = self.create_publisher(Bool, output_topic, qos)
|
||||
self._sub = self.create_subscription(
|
||||
UInt8MultiArray, audio_topic, self._on_audio, qos
|
||||
)
|
||||
self._timer = self.create_timer(hop_s, self._detection_cb)
|
||||
|
||||
tmpl_status = (f"template={tmpl_path}" if template is not None
|
||||
else "passive (no template)")
|
||||
self.get_logger().info(
|
||||
f"WakeWordNode ready — {tmpl_status}, "
|
||||
f"window={self._win_s}s, hop={hop_s}s, "
|
||||
f"energy_thr={energy_thr}, match_thr={match_thr}"
|
||||
)
|
||||
|
||||
# ── Subscription ───────────────────────────────────────────────────
|
||||
|
||||
def _on_audio(self, msg) -> None:
|
||||
if not _NP:
|
||||
return
|
||||
try:
|
||||
raw = bytes(msg.data)
|
||||
samples = pcm16_to_float(raw)
|
||||
except Exception as exc:
|
||||
self.get_logger().warn(f"WakeWord: audio decode error: {exc}")
|
||||
return
|
||||
with self._lock:
|
||||
self._buf.push(samples)
|
||||
|
||||
# ── Detection timer ────────────────────────────────────────────────
|
||||
|
||||
def _detection_cb(self) -> None:
|
||||
if not _NP:
|
||||
return
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
window = self._buf.get_window(self._win_n)
|
||||
if window is None:
|
||||
return
|
||||
|
||||
detected, energy, sim = self._detector.detect(window)
|
||||
|
||||
if detected and (now - self._last_det_t) >= self._cool_s:
|
||||
self._last_det_t = now
|
||||
self.get_logger().info(
|
||||
f"WakeWord: 'hey salty' detected "
|
||||
f"(rms={energy:.4f}, sim={sim:.3f})"
|
||||
)
|
||||
out = Bool()
|
||||
out.data = True
|
||||
self._pub.publish(out)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = WakeWordNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -57,6 +57,8 @@ setup(
|
||||
'topic_memory_node = saltybot_social.topic_memory_node:main',
|
||||
# Personal space respector (Issue #310)
|
||||
'personal_space_node = saltybot_social.personal_space_node:main',
|
||||
# Audio wake-word detector — 'hey salty' (Issue #320)
|
||||
'wake_word_node = saltybot_social.wake_word_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
711
jetson/ros2_ws/src/saltybot_social/test/test_wake_word.py
Normal file
711
jetson/ros2_ws/src/saltybot_social/test/test_wake_word.py
Normal file
@ -0,0 +1,711 @@
|
||||
"""test_wake_word.py — Offline tests for wake_word_node (Issue #320).
|
||||
|
||||
Stubs out rclpy and ROS message types so tests run without a ROS install.
|
||||
numpy is required (standard on the Jetson).
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import math
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_ros_stubs():
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg"):
|
||||
if mod_name not in sys.modules:
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
class _Node:
|
||||
def __init__(self, name="node"):
|
||||
self._name = name
|
||||
if not hasattr(self, "_params"):
|
||||
self._params = {}
|
||||
self._pubs = {}
|
||||
self._subs = {}
|
||||
self._logs = []
|
||||
self._timers = []
|
||||
|
||||
def declare_parameter(self, name, default):
|
||||
if name not in self._params:
|
||||
self._params[name] = default
|
||||
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
def __init__(self, v): self.value = v
|
||||
return _P(self._params.get(name))
|
||||
|
||||
def create_publisher(self, msg_type, topic, qos):
|
||||
pub = _FakePub()
|
||||
self._pubs[topic] = pub
|
||||
return pub
|
||||
|
||||
def create_subscription(self, msg_type, topic, cb, qos):
|
||||
self._subs[topic] = cb
|
||||
return object()
|
||||
|
||||
def create_timer(self, period, cb):
|
||||
self._timers.append(cb)
|
||||
return object()
|
||||
|
||||
def get_logger(self):
|
||||
node = self
|
||||
class _L:
|
||||
def info(self, m): node._logs.append(("INFO", m))
|
||||
def warn(self, m): node._logs.append(("WARN", m))
|
||||
def error(self, m): node._logs.append(("ERROR", m))
|
||||
return _L()
|
||||
|
||||
def destroy_node(self): pass
|
||||
|
||||
class _FakePub:
|
||||
def __init__(self):
|
||||
self.msgs = []
|
||||
def publish(self, msg):
|
||||
self.msgs.append(msg)
|
||||
|
||||
class _QoSProfile:
|
||||
def __init__(self, depth=10): self.depth = depth
|
||||
|
||||
class _Bool:
|
||||
def __init__(self): self.data = False
|
||||
|
||||
class _UInt8MultiArray:
|
||||
def __init__(self): self.data = b""
|
||||
|
||||
rclpy_mod = sys.modules["rclpy"]
|
||||
rclpy_mod.init = lambda args=None: None
|
||||
rclpy_mod.spin = lambda node: None
|
||||
rclpy_mod.shutdown = lambda: None
|
||||
|
||||
sys.modules["rclpy.node"].Node = _Node
|
||||
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
|
||||
sys.modules["std_msgs.msg"].Bool = _Bool
|
||||
sys.modules["std_msgs.msg"].UInt8MultiArray = _UInt8MultiArray
|
||||
|
||||
return _Node, _FakePub, _Bool, _UInt8MultiArray
|
||||
|
||||
|
||||
_Node, _FakePub, _Bool, _UInt8MultiArray = _make_ros_stubs()
|
||||
|
||||
|
||||
# ── Module loader ─────────────────────────────────────────────────────────────
|
||||
|
||||
_SRC = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/saltybot_social/wake_word_node.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_mod():
|
||||
spec = importlib.util.spec_from_file_location("wake_word_testmod", _SRC)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def _make_node(mod, **kwargs):
|
||||
node = mod.WakeWordNode.__new__(mod.WakeWordNode)
|
||||
defaults = {
|
||||
"audio_topic": "/social/speech/audio_raw",
|
||||
"output_topic": "/saltybot/wake_word_detected",
|
||||
"sample_rate": 16000,
|
||||
"window_s": 1.5,
|
||||
"hop_s": 0.1,
|
||||
"energy_threshold": 0.02,
|
||||
"match_threshold": 0.82,
|
||||
"cooldown_s": 2.0,
|
||||
"template_path": "",
|
||||
"n_fft": 512,
|
||||
"n_mels": 40,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
node._params = dict(defaults)
|
||||
mod.WakeWordNode.__init__(node)
|
||||
return node
|
||||
|
||||
|
||||
def _make_pcm_bytes(samples: np.ndarray) -> bytes:
|
||||
"""Encode float32 array [-1,1] to PCM-16 LE bytes."""
|
||||
ints = np.clip(samples * 32768.0, -32768, 32767).astype(np.int16)
|
||||
return ints.tobytes()
|
||||
|
||||
|
||||
def _make_audio_msg(samples: np.ndarray) -> _UInt8MultiArray:
|
||||
m = _UInt8MultiArray()
|
||||
m.data = _make_pcm_bytes(samples)
|
||||
return m
|
||||
|
||||
|
||||
def _sine(freq: float, duration: float, sr: int = 16000,
|
||||
amp: float = 0.5) -> np.ndarray:
|
||||
"""Generate a mono sine wave."""
|
||||
t = np.arange(int(sr * duration)) / sr
|
||||
return (amp * np.sin(2 * math.pi * freq * t)).astype(np.float32)
|
||||
|
||||
|
||||
def _silence(duration: float, sr: int = 16000) -> np.ndarray:
|
||||
return np.zeros(int(sr * duration), dtype=np.float32)
|
||||
|
||||
|
||||
def _make_template(mod, sr: int = 16000, n_fft: int = 512,
|
||||
n_mels: int = 40) -> np.ndarray:
|
||||
"""Compute a template from a synthetic 'wake word' signal for testing."""
|
||||
signal = _sine(300, 1.5, sr, amp=0.6)
|
||||
hop = n_fft // 2
|
||||
return mod.compute_log_mel(signal, sr, n_fft, n_mels, hop)
|
||||
|
||||
|
||||
# ── Tests: pcm16_to_float ─────────────────────────────────────────────────────
|
||||
|
||||
class TestPcm16ToFloat(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _conv(self, samples):
|
||||
raw = _make_pcm_bytes(np.array(samples, dtype=np.float32))
|
||||
return self.mod.pcm16_to_float(raw)
|
||||
|
||||
def test_zeros(self):
|
||||
out = self._conv([0.0, 0.0, 0.0])
|
||||
np.testing.assert_allclose(out, [0.0, 0.0, 0.0], atol=1e-4)
|
||||
|
||||
def test_positive(self):
|
||||
out = self._conv([0.5])
|
||||
self.assertAlmostEqual(float(out[0]), 0.5, places=3)
|
||||
|
||||
def test_negative(self):
|
||||
out = self._conv([-0.5])
|
||||
self.assertAlmostEqual(float(out[0]), -0.5, places=3)
|
||||
|
||||
def test_roundtrip_length(self):
|
||||
arr = np.linspace(-1.0, 1.0, 100, dtype=np.float32)
|
||||
raw = _make_pcm_bytes(arr)
|
||||
out = self.mod.pcm16_to_float(raw)
|
||||
self.assertEqual(len(out), 100)
|
||||
|
||||
def test_empty_bytes_returns_empty(self):
|
||||
out = self.mod.pcm16_to_float(b"")
|
||||
self.assertEqual(len(out), 0)
|
||||
|
||||
def test_odd_byte_ignored(self):
|
||||
# 3 bytes → 1 complete int16 + 1 orphan byte
|
||||
out = self.mod.pcm16_to_float(b"\x00\x40\xff")
|
||||
self.assertEqual(len(out), 1)
|
||||
|
||||
|
||||
# ── Tests: rms ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRms(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_zero_signal(self):
|
||||
self.assertAlmostEqual(self.mod.rms(np.zeros(100)), 0.0)
|
||||
|
||||
def test_constant_signal(self):
|
||||
# RMS of constant 0.5 = 0.5
|
||||
self.assertAlmostEqual(self.mod.rms(np.full(100, 0.5)), 0.5, places=5)
|
||||
|
||||
def test_sine_rms(self):
|
||||
# RMS of sin = amp / sqrt(2)
|
||||
amp = 0.8
|
||||
s = _sine(440, 1.0, amp=amp)
|
||||
expected = amp / math.sqrt(2)
|
||||
self.assertAlmostEqual(self.mod.rms(s), expected, places=2)
|
||||
|
||||
def test_empty_array(self):
|
||||
self.assertEqual(self.mod.rms(np.array([])), 0.0)
|
||||
|
||||
|
||||
# ── Tests: mel_filterbank ─────────────────────────────────────────────────────
|
||||
|
||||
class TestMelFilterbank(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_shape(self):
|
||||
fb = self.mod.mel_filterbank(16000, 512, 40)
|
||||
self.assertEqual(fb.shape, (40, 257))
|
||||
|
||||
def test_non_negative(self):
|
||||
fb = self.mod.mel_filterbank(16000, 512, 40)
|
||||
self.assertTrue((fb >= 0).all())
|
||||
|
||||
def test_rows_sum_positive(self):
|
||||
fb = self.mod.mel_filterbank(16000, 512, 40)
|
||||
self.assertTrue((fb.sum(axis=1) > 0).all())
|
||||
|
||||
def test_custom_n_mels(self):
|
||||
fb = self.mod.mel_filterbank(16000, 256, 20)
|
||||
self.assertEqual(fb.shape[0], 20)
|
||||
|
||||
|
||||
# ── Tests: compute_log_mel ────────────────────────────────────────────────────
|
||||
|
||||
class TestComputeLogMel(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_output_shape_rows(self):
|
||||
s = _sine(440, 1.5)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
self.assertEqual(out.shape[0], 40)
|
||||
|
||||
def test_output_has_time_axis(self):
|
||||
s = _sine(440, 1.5)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
self.assertGreater(out.shape[1], 0)
|
||||
|
||||
def test_output_finite(self):
|
||||
s = _sine(440, 1.5)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
self.assertTrue(np.isfinite(out).all())
|
||||
|
||||
def test_silence_gives_low_values(self):
|
||||
s = _silence(1.5)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
# All values should be very small (near log(1e-10))
|
||||
self.assertTrue((out < -20).all())
|
||||
|
||||
def test_short_signal_no_crash(self):
|
||||
# Shorter than one FFT frame
|
||||
s = np.zeros(100, dtype=np.float32)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
self.assertEqual(out.shape[0], 40)
|
||||
|
||||
|
||||
# ── Tests: cosine_sim ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestCosineSim(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_identical_vectors(self):
|
||||
v = np.array([1.0, 2.0, 3.0])
|
||||
self.assertAlmostEqual(self.mod.cosine_sim(v, v), 1.0, places=5)
|
||||
|
||||
def test_orthogonal_vectors(self):
|
||||
a = np.array([1.0, 0.0])
|
||||
b = np.array([0.0, 1.0])
|
||||
self.assertAlmostEqual(self.mod.cosine_sim(a, b), 0.0, places=5)
|
||||
|
||||
def test_opposite_vectors(self):
|
||||
v = np.array([1.0, 2.0, 3.0])
|
||||
self.assertAlmostEqual(self.mod.cosine_sim(v, -v), -1.0, places=5)
|
||||
|
||||
def test_zero_vector_returns_zero(self):
|
||||
a = np.zeros(5)
|
||||
b = np.ones(5)
|
||||
self.assertEqual(self.mod.cosine_sim(a, b), 0.0)
|
||||
|
||||
def test_2d_arrays(self):
|
||||
a = np.ones((4, 10))
|
||||
b = np.ones((4, 10))
|
||||
self.assertAlmostEqual(self.mod.cosine_sim(a, b), 1.0, places=5)
|
||||
|
||||
def test_different_lengths_truncated(self):
|
||||
a = np.array([1.0, 2.0, 3.0, 4.0])
|
||||
b = np.array([1.0, 2.0, 3.0])
|
||||
# Should not crash, uses min length
|
||||
result = self.mod.cosine_sim(a, b)
|
||||
self.assertTrue(-1.0 <= result <= 1.0)
|
||||
|
||||
def test_range_is_bounded(self):
|
||||
rng = np.random.default_rng(42)
|
||||
a = rng.standard_normal(100)
|
||||
b = rng.standard_normal(100)
|
||||
result = self.mod.cosine_sim(a, b)
|
||||
self.assertGreaterEqual(result, -1.0)
|
||||
self.assertLessEqual(result, 1.0)
|
||||
|
||||
|
||||
# ── Tests: AudioRingBuffer ────────────────────────────────────────────────────
|
||||
|
||||
class TestAudioRingBuffer(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _buf(self, max_samples=1000):
|
||||
return self.mod.AudioRingBuffer(max_samples)
|
||||
|
||||
def test_empty_initially(self):
|
||||
b = self._buf()
|
||||
self.assertEqual(len(b), 0)
|
||||
|
||||
def test_push_increases_len(self):
|
||||
b = self._buf()
|
||||
b.push(np.ones(100, dtype=np.float32))
|
||||
self.assertEqual(len(b), 100)
|
||||
|
||||
def test_get_window_none_when_short(self):
|
||||
b = self._buf()
|
||||
b.push(np.ones(50, dtype=np.float32))
|
||||
self.assertIsNone(b.get_window(100))
|
||||
|
||||
def test_get_window_ok_when_full(self):
|
||||
b = self._buf()
|
||||
data = np.arange(200, dtype=np.float32)
|
||||
b.push(data)
|
||||
w = b.get_window(100)
|
||||
self.assertIsNotNone(w)
|
||||
self.assertEqual(len(w), 100)
|
||||
|
||||
def test_get_window_returns_latest(self):
|
||||
b = self._buf()
|
||||
b.push(np.zeros(100, dtype=np.float32))
|
||||
b.push(np.ones(100, dtype=np.float32))
|
||||
w = b.get_window(100)
|
||||
np.testing.assert_allclose(w, np.ones(100))
|
||||
|
||||
def test_maxlen_evicts_oldest(self):
|
||||
b = self._buf(max_samples=100)
|
||||
b.push(np.zeros(60, dtype=np.float32))
|
||||
b.push(np.ones(60, dtype=np.float32)) # should evict 20 zeros
|
||||
self.assertEqual(len(b), 100)
|
||||
w = b.get_window(100)
|
||||
# Last 40 samples should be ones
|
||||
np.testing.assert_allclose(w[-40:], np.ones(40))
|
||||
|
||||
def test_exact_window_size(self):
|
||||
b = self._buf(500)
|
||||
data = np.arange(300, dtype=np.float32)
|
||||
b.push(data)
|
||||
w = b.get_window(300)
|
||||
np.testing.assert_allclose(w, data)
|
||||
|
||||
|
||||
# ── Tests: WakeWordDetector ───────────────────────────────────────────────────
|
||||
|
||||
class TestWakeWordDetector(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
cls.sr = 16000
|
||||
cls.template = _make_template(cls.mod, cls.sr)
|
||||
|
||||
def _det(self, template=None, energy_thr=0.02, match_thr=0.82):
|
||||
return self.mod.WakeWordDetector(template, energy_thr, match_thr,
|
||||
self.sr, 512, 40)
|
||||
|
||||
def test_no_template_never_detects(self):
|
||||
det = self._det(template=None)
|
||||
loud = _sine(300, 1.5, self.sr, amp=0.8)
|
||||
detected, _, _ = det.detect(loud)
|
||||
self.assertFalse(detected)
|
||||
|
||||
def test_has_template_false_when_no_template(self):
|
||||
det = self._det(template=None)
|
||||
self.assertFalse(det.has_template)
|
||||
|
||||
def test_has_template_true_when_set(self):
|
||||
det = self._det(template=self.template)
|
||||
self.assertTrue(det.has_template)
|
||||
|
||||
def test_silence_below_energy_gate(self):
|
||||
det = self._det(template=self.template, energy_thr=0.02)
|
||||
silence = _silence(1.5, self.sr)
|
||||
detected, rms_val, sim = det.detect(silence)
|
||||
self.assertFalse(detected)
|
||||
self.assertAlmostEqual(rms_val, 0.0, places=4)
|
||||
self.assertAlmostEqual(sim, 0.0, places=4)
|
||||
|
||||
def test_returns_rms_value(self):
|
||||
det = self._det(template=None)
|
||||
s = _sine(300, 1.5, self.sr, amp=0.5)
|
||||
_, rms_val, _ = det.detect(s)
|
||||
expected = 0.5 / math.sqrt(2)
|
||||
self.assertAlmostEqual(rms_val, expected, places=2)
|
||||
|
||||
def test_identical_signal_detects(self):
|
||||
"""Signal identical to template should give sim ≈ 1.0 → detect."""
|
||||
signal = _sine(300, 1.5, self.sr, amp=0.6)
|
||||
det = self._det(template=self.template, match_thr=0.99)
|
||||
detected, _, sim = det.detect(signal)
|
||||
# Sim must be very high for an identical-source signal
|
||||
self.assertGreater(sim, 0.99)
|
||||
self.assertTrue(detected)
|
||||
|
||||
def test_different_signal_low_sim(self):
|
||||
"""A very different signal (white noise) should have low similarity."""
|
||||
rng = np.random.default_rng(7)
|
||||
noise = (rng.standard_normal(int(self.sr * 1.5)) * 0.4).astype(np.float32)
|
||||
det = self._det(template=self.template, match_thr=0.82)
|
||||
_, _, sim = det.detect(noise)
|
||||
# White noise sim to a tonal template should be < 0.6
|
||||
self.assertLess(sim, 0.6)
|
||||
|
||||
def test_threshold_boundary_low(self):
|
||||
"""Setting match_thr=0.0 with a loud signal should fire if template set."""
|
||||
signal = _sine(300, 1.5, self.sr, amp=0.6)
|
||||
det = self._det(template=self.template, match_thr=0.0)
|
||||
detected, _, _ = det.detect(signal)
|
||||
self.assertTrue(detected)
|
||||
|
||||
def test_threshold_boundary_high(self):
|
||||
"""Setting match_thr=1.1 (above max) should never fire."""
|
||||
signal = _sine(300, 1.5, self.sr, amp=0.6)
|
||||
det = self._det(template=self.template, match_thr=1.1)
|
||||
detected, _, _ = det.detect(signal)
|
||||
self.assertFalse(detected)
|
||||
|
||||
def test_energy_below_threshold_skips_matching(self):
|
||||
"""Low energy → sim returned as 0.0 regardless of template."""
|
||||
very_quiet = _sine(300, 1.5, self.sr, amp=0.001)
|
||||
det = self._det(template=self.template, energy_thr=0.1)
|
||||
detected, rms_val, sim = det.detect(very_quiet)
|
||||
self.assertFalse(detected)
|
||||
self.assertAlmostEqual(sim, 0.0, places=5)
|
||||
|
||||
|
||||
# ── Tests: node init ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeInit(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_instantiates(self):
|
||||
self.assertIsNotNone(_make_node(self.mod))
|
||||
|
||||
def test_pub_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/saltybot/wake_word_detected", node._pubs)
|
||||
|
||||
def test_sub_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/social/speech/audio_raw", node._subs)
|
||||
|
||||
def test_timer_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertGreater(len(node._timers), 0)
|
||||
|
||||
def test_custom_topics(self):
|
||||
node = _make_node(self.mod,
|
||||
audio_topic="/my/audio",
|
||||
output_topic="/my/wake")
|
||||
self.assertIn("/my/audio", node._subs)
|
||||
self.assertIn("/my/wake", node._pubs)
|
||||
|
||||
def test_no_template_passive(self):
|
||||
node = _make_node(self.mod, template_path="")
|
||||
self.assertFalse(node._detector.has_template)
|
||||
|
||||
def test_bad_template_path_warns(self):
|
||||
node = _make_node(self.mod, template_path="/nonexistent/template.npy")
|
||||
warns = [m for lvl, m in node._logs if lvl == "WARN"]
|
||||
self.assertTrue(any("template" in m.lower() or "passive" in m.lower()
|
||||
for m in warns))
|
||||
|
||||
def test_ring_buffer_allocated(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIsNotNone(node._buf)
|
||||
|
||||
def test_window_n_computed(self):
|
||||
node = _make_node(self.mod, sample_rate=16000, window_s=1.5)
|
||||
self.assertEqual(node._win_n, 24000)
|
||||
|
||||
|
||||
# ── Tests: _on_audio callback ─────────────────────────────────────────────────
|
||||
|
||||
class TestOnAudio(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
|
||||
def _push(self, samples):
|
||||
msg = _make_audio_msg(samples)
|
||||
self.node._subs["/social/speech/audio_raw"](msg)
|
||||
|
||||
def test_pushes_samples_to_buffer(self):
|
||||
self._push(_sine(440, 0.5))
|
||||
self.assertGreater(len(self.node._buf), 0)
|
||||
|
||||
def test_buffer_grows_with_pushes(self):
|
||||
chunk = _sine(440, 0.1)
|
||||
before = len(self.node._buf)
|
||||
self._push(chunk)
|
||||
after = len(self.node._buf)
|
||||
self.assertGreater(after, before)
|
||||
|
||||
def test_bad_data_no_crash(self):
|
||||
msg = _UInt8MultiArray()
|
||||
msg.data = b"\xff" # 1 orphan byte — yields 0 samples, no crash
|
||||
self.node._subs["/social/speech/audio_raw"](msg)
|
||||
|
||||
def test_multiple_chunks_accumulate(self):
|
||||
for _ in range(5):
|
||||
self._push(_sine(440, 0.1))
|
||||
self.assertGreater(len(self.node._buf), 0)
|
||||
|
||||
|
||||
# ── Tests: detection callback ─────────────────────────────────────────────────
|
||||
|
||||
class TestDetectionCallback(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _node_with_template(self, **kwargs):
|
||||
template = _make_template(self.mod)
|
||||
node = _make_node(self.mod, **kwargs)
|
||||
node._detector = self.mod.WakeWordDetector(
|
||||
template, energy_threshold=0.02, match_threshold=0.0, # thr=0 → always fires when loud
|
||||
sample_rate=16000, n_fft=512, n_mels=40
|
||||
)
|
||||
return node
|
||||
|
||||
def _fill_buffer(self, node, signal):
|
||||
msg = _make_audio_msg(signal)
|
||||
node._subs["/social/speech/audio_raw"](msg)
|
||||
|
||||
def test_no_data_no_publish(self):
|
||||
node = _make_node(self.mod)
|
||||
node._detection_cb()
|
||||
self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0)
|
||||
|
||||
def test_insufficient_buffer_no_publish(self):
|
||||
node = self._node_with_template()
|
||||
# Push only 0.1 s but window is 1.5 s
|
||||
self._fill_buffer(node, _sine(300, 0.1))
|
||||
node._detection_cb()
|
||||
self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0)
|
||||
|
||||
def test_detects_and_publishes_true(self):
|
||||
node = self._node_with_template()
|
||||
# Fill with a loud 300 Hz sine (matches template)
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 1)
|
||||
self.assertTrue(pub.msgs[0].data)
|
||||
|
||||
def test_cooldown_suppresses_second_detection(self):
|
||||
node = self._node_with_template(cooldown_s=60.0)
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
# Second call immediately → cooldown active
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 1)
|
||||
|
||||
def test_cooldown_expired_allows_second(self):
|
||||
node = self._node_with_template(cooldown_s=0.0)
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 2)
|
||||
|
||||
def test_no_template_never_publishes(self):
|
||||
node = _make_node(self.mod, template_path="")
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.8))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 0)
|
||||
|
||||
def test_silence_no_publish(self):
|
||||
node = self._node_with_template()
|
||||
self._fill_buffer(node, _silence(1.5))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 0)
|
||||
|
||||
def test_detection_logs_info(self):
|
||||
node = self._node_with_template()
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
infos = [m for lvl, m in node._logs if lvl == "INFO"]
|
||||
self.assertTrue(any("detected" in m.lower() or "hey salty" in m.lower()
|
||||
for m in infos))
|
||||
|
||||
|
||||
# ── Tests: source content ─────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open(_SRC) as f: cls.src = f.read()
|
||||
|
||||
def test_issue_tag(self): self.assertIn("#320", self.src)
|
||||
def test_audio_topic(self): self.assertIn("/social/speech/audio_raw", self.src)
|
||||
def test_output_topic(self): self.assertIn("/saltybot/wake_word_detected", self.src)
|
||||
def test_wake_word_name(self): self.assertIn("hey salty", self.src)
|
||||
def test_compute_log_mel(self): self.assertIn("compute_log_mel", self.src)
|
||||
def test_cosine_sim(self): self.assertIn("cosine_sim", self.src)
|
||||
def test_mel_filterbank(self): self.assertIn("mel_filterbank", self.src)
|
||||
def test_audio_ring_buffer(self): self.assertIn("AudioRingBuffer", self.src)
|
||||
def test_wake_word_detector(self): self.assertIn("WakeWordDetector", self.src)
|
||||
def test_energy_threshold(self): self.assertIn("energy_threshold", self.src)
|
||||
def test_match_threshold(self): self.assertIn("match_threshold", self.src)
|
||||
def test_cooldown(self): self.assertIn("cooldown_s", self.src)
|
||||
def test_template_path(self): self.assertIn("template_path", self.src)
|
||||
def test_pcm16_decode(self): self.assertIn("pcm16_to_float", self.src)
|
||||
def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
|
||||
def test_numpy_used(self): self.assertIn("import numpy", self.src)
|
||||
def test_main_defined(self): self.assertIn("def main", self.src)
|
||||
def test_uint8_multiarray(self): self.assertIn("UInt8MultiArray", self.src)
|
||||
|
||||
|
||||
# ── Tests: config / launch / setup ────────────────────────────────────────────
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
_CONFIG = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/config/wake_word_params.yaml"
|
||||
)
|
||||
_LAUNCH = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/launch/wake_word.launch.py"
|
||||
)
|
||||
_SETUP = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/setup.py"
|
||||
)
|
||||
|
||||
def test_config_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._CONFIG))
|
||||
|
||||
def test_config_energy_threshold(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("energy_threshold", c)
|
||||
|
||||
def test_config_match_threshold(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("match_threshold", c)
|
||||
|
||||
def test_config_template_path(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("template_path", c)
|
||||
|
||||
def test_config_cooldown(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("cooldown_s", c)
|
||||
|
||||
def test_launch_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._LAUNCH))
|
||||
|
||||
def test_launch_has_template_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("template_path", c)
|
||||
|
||||
def test_launch_has_threshold_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("match_threshold", c)
|
||||
|
||||
def test_entry_point_in_setup(self):
|
||||
with open(self._SETUP) as f: c = f.read()
|
||||
self.assertIn("wake_word_node", c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user