Compare commits
5 Commits
b4181183e7
...
067a871103
| Author | SHA1 | Date | |
|---|---|---|---|
| 067a871103 | |||
| b96c6b96d0 | |||
| d5e0c58594 | |||
| d6553ce3d6 | |||
| 2919e629e9 |
@ -0,0 +1,19 @@
|
||||
wake_word_node:
|
||||
ros__parameters:
|
||||
audio_topic: "/social/speech/audio_raw" # PCM-16 mono input (UInt8MultiArray)
|
||||
output_topic: "/saltybot/wake_word_detected"
|
||||
|
||||
sample_rate: 16000 # Hz — must match audio source
|
||||
window_s: 1.5 # detection window length (s)
|
||||
hop_s: 0.1 # detection timer period (s)
|
||||
|
||||
energy_threshold: 0.02 # RMS gate; below this → skip matching
|
||||
match_threshold: 0.82 # cosine-similarity gate; above → detect
|
||||
cooldown_s: 2.0 # minimum gap between successive detections (s)
|
||||
|
||||
# Path to .npy template file (log-mel features of 'hey salty' recording).
|
||||
# Leave empty for passive mode (no detections fired).
|
||||
template_path: "" # e.g. "/opt/saltybot/models/hey_salty.npy"
|
||||
|
||||
n_fft: 512 # FFT size for mel spectrogram
|
||||
n_mels: 40 # mel filterbank bands
|
||||
@ -0,0 +1,43 @@
|
||||
"""wake_word.launch.py — Launch wake-word detector ('hey salty') (Issue #320).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social wake_word.launch.py
|
||||
ros2 launch saltybot_social wake_word.launch.py template_path:=/opt/saltybot/models/hey_salty.npy
|
||||
ros2 launch saltybot_social wake_word.launch.py match_threshold:=0.85 cooldown_s:=3.0
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "wake_word_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("template_path", default_value="",
|
||||
description="Path to .npy template file (log-mel of 'hey salty')"),
|
||||
DeclareLaunchArgument("match_threshold", default_value="0.82",
|
||||
description="Cosine-similarity detection threshold"),
|
||||
DeclareLaunchArgument("cooldown_s", default_value="2.0",
|
||||
description="Minimum seconds between detections"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="wake_word_node",
|
||||
name="wake_word_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"template_path": LaunchConfiguration("template_path"),
|
||||
"match_threshold": LaunchConfiguration("match_threshold"),
|
||||
"cooldown_s": LaunchConfiguration("cooldown_s"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,343 @@
|
||||
"""wake_word_node.py — Audio wake-word detector ('hey salty').
|
||||
Issue #320
|
||||
|
||||
Subscribes to raw PCM-16 audio on /social/speech/audio_raw, maintains a
|
||||
sliding window buffer, and on each hop tick computes log-mel spectrogram
|
||||
features of the most recent window. If energy is above the gate threshold
|
||||
AND cosine similarity to the stored template is above match_threshold the
|
||||
detection fires: Bool(True) is published on /saltybot/wake_word_detected.
|
||||
|
||||
Detection is one-shot (only True is published) and guarded by a cooldown
|
||||
so rapid re-fires are suppressed. When no template is loaded (template_path
|
||||
is empty) the node stays passive — energy gating is applied but no match is
|
||||
attempted.
|
||||
|
||||
Audio format expected
|
||||
─────────────────────
|
||||
UInt8MultiArray, same feed as vad_node:
|
||||
raw PCM-16 little-endian mono at ``sample_rate`` Hz.
|
||||
|
||||
Subscriptions
|
||||
─────────────
|
||||
/social/speech/audio_raw std_msgs/UInt8MultiArray — raw PCM-16 chunks
|
||||
|
||||
Publications
|
||||
────────────
|
||||
/saltybot/wake_word_detected std_msgs/Bool — True on each detection event
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
audio_topic (str, "/social/speech/audio_raw")
|
||||
output_topic (str, "/saltybot/wake_word_detected")
|
||||
sample_rate (int, 16000) sample rate of incoming audio (Hz)
|
||||
window_s (float, 1.5) detection window duration (s)
|
||||
hop_s (float, 0.1) detection timer period (s)
|
||||
energy_threshold (float, 0.02) RMS gate; below this → skip matching
|
||||
match_threshold (float, 0.82) cosine-similarity gate; above → detect
|
||||
cooldown_s (float, 2.0) minimum gap between successive detections
|
||||
template_path (str, "") path to .npy template file; "" = passive
|
||||
n_fft (int, 512) FFT size for mel spectrogram
|
||||
n_mels (int, 40) number of mel filterbank bands
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import Bool, UInt8MultiArray
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
_NP = True
|
||||
except ImportError: # pragma: no cover
|
||||
_NP = False
|
||||
|
||||
INT16_MAX = 32768.0
|
||||
|
||||
|
||||
# ── Pure DSP helpers (no ROS, numpy only) ──────────────────────────────────────
|
||||
|
||||
def pcm16_to_float(data: bytes) -> "np.ndarray":
|
||||
"""Decode PCM-16 LE bytes → float32 ndarray in [-1.0, 1.0]."""
|
||||
n = len(data) // 2
|
||||
if n == 0:
|
||||
return np.zeros(0, dtype=np.float32)
|
||||
samples = struct.unpack(f"<{n}h", data[:n * 2])
|
||||
return np.array(samples, dtype=np.float32) / INT16_MAX
|
||||
|
||||
|
||||
def mel_filterbank(sr: int, n_fft: int, n_mels: int,
|
||||
fmin: float = 80.0, fmax: Optional[float] = None) -> "np.ndarray":
|
||||
"""Build a triangular mel filterbank matrix [n_mels, n_fft//2+1]."""
|
||||
if fmax is None:
|
||||
fmax = sr / 2.0
|
||||
|
||||
def hz_to_mel(hz: float) -> float:
|
||||
return 2595.0 * math.log10(1.0 + hz / 700.0)
|
||||
|
||||
def mel_to_hz(mel: float) -> float:
|
||||
return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
|
||||
|
||||
mel_lo = hz_to_mel(fmin)
|
||||
mel_hi = hz_to_mel(fmax)
|
||||
mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2)
|
||||
hz_pts = np.array([mel_to_hz(m) for m in mel_pts])
|
||||
freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)
|
||||
|
||||
fb = np.zeros((n_mels, len(freqs)), dtype=np.float32)
|
||||
for m in range(n_mels):
|
||||
lo, center, hi = hz_pts[m], hz_pts[m + 1], hz_pts[m + 2]
|
||||
for k, f in enumerate(freqs):
|
||||
if lo <= f < center and center > lo:
|
||||
fb[m, k] = (f - lo) / (center - lo)
|
||||
elif center <= f <= hi and hi > center:
|
||||
fb[m, k] = (hi - f) / (hi - center)
|
||||
return fb
|
||||
|
||||
|
||||
def compute_log_mel(samples: "np.ndarray", sr: int,
|
||||
n_fft: int = 512, n_mels: int = 40,
|
||||
hop: int = 256) -> "np.ndarray":
|
||||
"""Return log-mel spectrogram [n_mels, T] of *samples* (float32 [-1,1])."""
|
||||
n = len(samples)
|
||||
window = np.hanning(n_fft).astype(np.float32)
|
||||
frames = []
|
||||
for start in range(0, max(n - n_fft + 1, 1), hop):
|
||||
chunk = samples[start:start + n_fft]
|
||||
if len(chunk) < n_fft:
|
||||
chunk = np.pad(chunk, (0, n_fft - len(chunk)))
|
||||
power = np.abs(np.fft.rfft(chunk * window)) ** 2
|
||||
frames.append(power)
|
||||
frames_arr = np.array(frames, dtype=np.float32).T # [bins, T]
|
||||
fb = mel_filterbank(sr, n_fft, n_mels)
|
||||
mel = fb @ frames_arr # [n_mels, T]
|
||||
mel = np.where(mel > 1e-10, mel, 1e-10)
|
||||
return np.log(mel)
|
||||
|
||||
|
||||
def cosine_sim(a: "np.ndarray", b: "np.ndarray") -> float:
|
||||
"""Cosine similarity between two arrays, matched by minimum length."""
|
||||
af = a.flatten().astype(np.float64)
|
||||
bf = b.flatten().astype(np.float64)
|
||||
min_len = min(len(af), len(bf))
|
||||
if min_len == 0:
|
||||
return 0.0
|
||||
af = af[:min_len]
|
||||
bf = bf[:min_len]
|
||||
denom = float(np.linalg.norm(af)) * float(np.linalg.norm(bf))
|
||||
if denom < 1e-12:
|
||||
return 0.0
|
||||
return float(np.dot(af, bf) / denom)
|
||||
|
||||
|
||||
def rms(samples: "np.ndarray") -> float:
|
||||
"""RMS amplitude of a float sample array."""
|
||||
if len(samples) == 0:
|
||||
return 0.0
|
||||
return float(np.sqrt(np.mean(samples.astype(np.float64) ** 2)))
|
||||
|
||||
|
||||
# ── Ring buffer ────────────────────────────────────────────────────────────────
|
||||
|
||||
class AudioRingBuffer:
|
||||
"""Lock-free sliding window for raw float audio samples."""
|
||||
|
||||
def __init__(self, max_samples: int) -> None:
|
||||
self._buf: deque = deque(maxlen=max_samples)
|
||||
|
||||
def push(self, samples: "np.ndarray") -> None:
|
||||
self._buf.extend(samples.tolist())
|
||||
|
||||
def get_window(self, n_samples: int) -> Optional["np.ndarray"]:
|
||||
"""Return last n_samples as float32 array, or None if buffer too short."""
|
||||
if len(self._buf) < n_samples:
|
||||
return None
|
||||
return np.array(list(self._buf)[-n_samples:], dtype=np.float32)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._buf)
|
||||
|
||||
|
||||
# ── Detector ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class WakeWordDetector:
|
||||
"""Energy-gated cosine-similarity wake-word detector.
|
||||
|
||||
Args:
|
||||
template: Log-mel feature array of the wake word, or None
|
||||
(passive — never fires when None).
|
||||
energy_threshold: Minimum RMS to proceed to feature matching.
|
||||
match_threshold: Minimum cosine similarity to fire.
|
||||
sample_rate: Expected audio sample rate (Hz).
|
||||
n_fft: FFT size for mel computation.
|
||||
n_mels: Number of mel bands.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
template: Optional["np.ndarray"],
|
||||
energy_threshold: float = 0.02,
|
||||
match_threshold: float = 0.82,
|
||||
sample_rate: int = 16000,
|
||||
n_fft: int = 512,
|
||||
n_mels: int = 40) -> None:
|
||||
self._template = template
|
||||
self._energy_thr = energy_threshold
|
||||
self._match_thr = match_threshold
|
||||
self._sr = sample_rate
|
||||
self._n_fft = n_fft
|
||||
self._n_mels = n_mels
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
def detect(self, samples: "np.ndarray") -> Tuple[bool, float, float]:
|
||||
"""Run detection on a window of float samples.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(detected, rms_value, similarity)
|
||||
"""
|
||||
energy = rms(samples)
|
||||
if energy < self._energy_thr:
|
||||
return False, energy, 0.0
|
||||
if self._template is None:
|
||||
return False, energy, 0.0
|
||||
|
||||
hop = max(1, self._n_fft // 2)
|
||||
feats = compute_log_mel(samples, self._sr, self._n_fft, self._n_mels, hop)
|
||||
sim = cosine_sim(feats, self._template)
|
||||
return sim >= self._match_thr, energy, sim
|
||||
|
||||
@property
|
||||
def has_template(self) -> bool:
|
||||
return self._template is not None
|
||||
|
||||
|
||||
# ── ROS2 node ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class WakeWordNode(Node):
|
||||
"""ROS2 node: 'hey salty' wake-word detection via energy + template matching."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("wake_word_node")
|
||||
|
||||
self.declare_parameter("audio_topic", "/social/speech/audio_raw")
|
||||
self.declare_parameter("output_topic", "/saltybot/wake_word_detected")
|
||||
self.declare_parameter("sample_rate", 16000)
|
||||
self.declare_parameter("window_s", 1.5)
|
||||
self.declare_parameter("hop_s", 0.1)
|
||||
self.declare_parameter("energy_threshold", 0.02)
|
||||
self.declare_parameter("match_threshold", 0.82)
|
||||
self.declare_parameter("cooldown_s", 2.0)
|
||||
self.declare_parameter("template_path", "")
|
||||
self.declare_parameter("n_fft", 512)
|
||||
self.declare_parameter("n_mels", 40)
|
||||
|
||||
audio_topic = self.get_parameter("audio_topic").value
|
||||
output_topic = self.get_parameter("output_topic").value
|
||||
self._sr = int(self.get_parameter("sample_rate").value)
|
||||
self._win_s = float(self.get_parameter("window_s").value)
|
||||
hop_s = float(self.get_parameter("hop_s").value)
|
||||
energy_thr = float(self.get_parameter("energy_threshold").value)
|
||||
match_thr = float(self.get_parameter("match_threshold").value)
|
||||
self._cool_s = float(self.get_parameter("cooldown_s").value)
|
||||
tmpl_path = str(self.get_parameter("template_path").value)
|
||||
n_fft = int(self.get_parameter("n_fft").value)
|
||||
n_mels = int(self.get_parameter("n_mels").value)
|
||||
|
||||
# ── Load template ──────────────────────────────────────────────
|
||||
template: Optional["np.ndarray"] = None
|
||||
if tmpl_path and _NP:
|
||||
try:
|
||||
template = np.load(tmpl_path)
|
||||
self.get_logger().info(
|
||||
f"WakeWord: loaded template {tmpl_path} shape={template.shape}"
|
||||
)
|
||||
except Exception as exc:
|
||||
self.get_logger().warn(
|
||||
f"WakeWord: could not load template '{tmpl_path}': {exc} — passive mode"
|
||||
)
|
||||
|
||||
# ── Ring buffer ────────────────────────────────────────────────
|
||||
max_samples = int(self._sr * self._win_s * 4) # 4× headroom
|
||||
self._win_n = int(self._sr * self._win_s)
|
||||
self._buf = AudioRingBuffer(max_samples)
|
||||
|
||||
# ── Detector ───────────────────────────────────────────────────
|
||||
self._detector = WakeWordDetector(template, energy_thr, match_thr,
|
||||
self._sr, n_fft, n_mels)
|
||||
|
||||
# ── State ──────────────────────────────────────────────────────
|
||||
self._last_det_t: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# ── ROS ────────────────────────────────────────────────────────
|
||||
qos = QoSProfile(depth=10)
|
||||
self._pub = self.create_publisher(Bool, output_topic, qos)
|
||||
self._sub = self.create_subscription(
|
||||
UInt8MultiArray, audio_topic, self._on_audio, qos
|
||||
)
|
||||
self._timer = self.create_timer(hop_s, self._detection_cb)
|
||||
|
||||
tmpl_status = (f"template={tmpl_path}" if template is not None
|
||||
else "passive (no template)")
|
||||
self.get_logger().info(
|
||||
f"WakeWordNode ready — {tmpl_status}, "
|
||||
f"window={self._win_s}s, hop={hop_s}s, "
|
||||
f"energy_thr={energy_thr}, match_thr={match_thr}"
|
||||
)
|
||||
|
||||
# ── Subscription ───────────────────────────────────────────────────
|
||||
|
||||
def _on_audio(self, msg) -> None:
|
||||
if not _NP:
|
||||
return
|
||||
try:
|
||||
raw = bytes(msg.data)
|
||||
samples = pcm16_to_float(raw)
|
||||
except Exception as exc:
|
||||
self.get_logger().warn(f"WakeWord: audio decode error: {exc}")
|
||||
return
|
||||
with self._lock:
|
||||
self._buf.push(samples)
|
||||
|
||||
# ── Detection timer ────────────────────────────────────────────────
|
||||
|
||||
def _detection_cb(self) -> None:
|
||||
if not _NP:
|
||||
return
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
window = self._buf.get_window(self._win_n)
|
||||
if window is None:
|
||||
return
|
||||
|
||||
detected, energy, sim = self._detector.detect(window)
|
||||
|
||||
if detected and (now - self._last_det_t) >= self._cool_s:
|
||||
self._last_det_t = now
|
||||
self.get_logger().info(
|
||||
f"WakeWord: 'hey salty' detected "
|
||||
f"(rms={energy:.4f}, sim={sim:.3f})"
|
||||
)
|
||||
out = Bool()
|
||||
out.data = True
|
||||
self._pub.publish(out)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = WakeWordNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -57,6 +57,8 @@ setup(
|
||||
'topic_memory_node = saltybot_social.topic_memory_node:main',
|
||||
# Personal space respector (Issue #310)
|
||||
'personal_space_node = saltybot_social.personal_space_node:main',
|
||||
# Audio wake-word detector — 'hey salty' (Issue #320)
|
||||
'wake_word_node = saltybot_social.wake_word_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
711
jetson/ros2_ws/src/saltybot_social/test/test_wake_word.py
Normal file
711
jetson/ros2_ws/src/saltybot_social/test/test_wake_word.py
Normal file
@ -0,0 +1,711 @@
|
||||
"""test_wake_word.py — Offline tests for wake_word_node (Issue #320).
|
||||
|
||||
Stubs out rclpy and ROS message types so tests run without a ROS install.
|
||||
numpy is required (standard on the Jetson).
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import math
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_ros_stubs():
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg"):
|
||||
if mod_name not in sys.modules:
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
class _Node:
|
||||
def __init__(self, name="node"):
|
||||
self._name = name
|
||||
if not hasattr(self, "_params"):
|
||||
self._params = {}
|
||||
self._pubs = {}
|
||||
self._subs = {}
|
||||
self._logs = []
|
||||
self._timers = []
|
||||
|
||||
def declare_parameter(self, name, default):
|
||||
if name not in self._params:
|
||||
self._params[name] = default
|
||||
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
def __init__(self, v): self.value = v
|
||||
return _P(self._params.get(name))
|
||||
|
||||
def create_publisher(self, msg_type, topic, qos):
|
||||
pub = _FakePub()
|
||||
self._pubs[topic] = pub
|
||||
return pub
|
||||
|
||||
def create_subscription(self, msg_type, topic, cb, qos):
|
||||
self._subs[topic] = cb
|
||||
return object()
|
||||
|
||||
def create_timer(self, period, cb):
|
||||
self._timers.append(cb)
|
||||
return object()
|
||||
|
||||
def get_logger(self):
|
||||
node = self
|
||||
class _L:
|
||||
def info(self, m): node._logs.append(("INFO", m))
|
||||
def warn(self, m): node._logs.append(("WARN", m))
|
||||
def error(self, m): node._logs.append(("ERROR", m))
|
||||
return _L()
|
||||
|
||||
def destroy_node(self): pass
|
||||
|
||||
class _FakePub:
|
||||
def __init__(self):
|
||||
self.msgs = []
|
||||
def publish(self, msg):
|
||||
self.msgs.append(msg)
|
||||
|
||||
class _QoSProfile:
|
||||
def __init__(self, depth=10): self.depth = depth
|
||||
|
||||
class _Bool:
|
||||
def __init__(self): self.data = False
|
||||
|
||||
class _UInt8MultiArray:
|
||||
def __init__(self): self.data = b""
|
||||
|
||||
rclpy_mod = sys.modules["rclpy"]
|
||||
rclpy_mod.init = lambda args=None: None
|
||||
rclpy_mod.spin = lambda node: None
|
||||
rclpy_mod.shutdown = lambda: None
|
||||
|
||||
sys.modules["rclpy.node"].Node = _Node
|
||||
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
|
||||
sys.modules["std_msgs.msg"].Bool = _Bool
|
||||
sys.modules["std_msgs.msg"].UInt8MultiArray = _UInt8MultiArray
|
||||
|
||||
return _Node, _FakePub, _Bool, _UInt8MultiArray
|
||||
|
||||
|
||||
_Node, _FakePub, _Bool, _UInt8MultiArray = _make_ros_stubs()
|
||||
|
||||
|
||||
# ── Module loader ─────────────────────────────────────────────────────────────
|
||||
|
||||
_SRC = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/saltybot_social/wake_word_node.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_mod():
|
||||
spec = importlib.util.spec_from_file_location("wake_word_testmod", _SRC)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def _make_node(mod, **kwargs):
|
||||
node = mod.WakeWordNode.__new__(mod.WakeWordNode)
|
||||
defaults = {
|
||||
"audio_topic": "/social/speech/audio_raw",
|
||||
"output_topic": "/saltybot/wake_word_detected",
|
||||
"sample_rate": 16000,
|
||||
"window_s": 1.5,
|
||||
"hop_s": 0.1,
|
||||
"energy_threshold": 0.02,
|
||||
"match_threshold": 0.82,
|
||||
"cooldown_s": 2.0,
|
||||
"template_path": "",
|
||||
"n_fft": 512,
|
||||
"n_mels": 40,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
node._params = dict(defaults)
|
||||
mod.WakeWordNode.__init__(node)
|
||||
return node
|
||||
|
||||
|
||||
def _make_pcm_bytes(samples: np.ndarray) -> bytes:
|
||||
"""Encode float32 array [-1,1] to PCM-16 LE bytes."""
|
||||
ints = np.clip(samples * 32768.0, -32768, 32767).astype(np.int16)
|
||||
return ints.tobytes()
|
||||
|
||||
|
||||
def _make_audio_msg(samples: np.ndarray) -> _UInt8MultiArray:
|
||||
m = _UInt8MultiArray()
|
||||
m.data = _make_pcm_bytes(samples)
|
||||
return m
|
||||
|
||||
|
||||
def _sine(freq: float, duration: float, sr: int = 16000,
|
||||
amp: float = 0.5) -> np.ndarray:
|
||||
"""Generate a mono sine wave."""
|
||||
t = np.arange(int(sr * duration)) / sr
|
||||
return (amp * np.sin(2 * math.pi * freq * t)).astype(np.float32)
|
||||
|
||||
|
||||
def _silence(duration: float, sr: int = 16000) -> np.ndarray:
|
||||
return np.zeros(int(sr * duration), dtype=np.float32)
|
||||
|
||||
|
||||
def _make_template(mod, sr: int = 16000, n_fft: int = 512,
|
||||
n_mels: int = 40) -> np.ndarray:
|
||||
"""Compute a template from a synthetic 'wake word' signal for testing."""
|
||||
signal = _sine(300, 1.5, sr, amp=0.6)
|
||||
hop = n_fft // 2
|
||||
return mod.compute_log_mel(signal, sr, n_fft, n_mels, hop)
|
||||
|
||||
|
||||
# ── Tests: pcm16_to_float ─────────────────────────────────────────────────────
|
||||
|
||||
class TestPcm16ToFloat(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _conv(self, samples):
|
||||
raw = _make_pcm_bytes(np.array(samples, dtype=np.float32))
|
||||
return self.mod.pcm16_to_float(raw)
|
||||
|
||||
def test_zeros(self):
|
||||
out = self._conv([0.0, 0.0, 0.0])
|
||||
np.testing.assert_allclose(out, [0.0, 0.0, 0.0], atol=1e-4)
|
||||
|
||||
def test_positive(self):
|
||||
out = self._conv([0.5])
|
||||
self.assertAlmostEqual(float(out[0]), 0.5, places=3)
|
||||
|
||||
def test_negative(self):
|
||||
out = self._conv([-0.5])
|
||||
self.assertAlmostEqual(float(out[0]), -0.5, places=3)
|
||||
|
||||
def test_roundtrip_length(self):
|
||||
arr = np.linspace(-1.0, 1.0, 100, dtype=np.float32)
|
||||
raw = _make_pcm_bytes(arr)
|
||||
out = self.mod.pcm16_to_float(raw)
|
||||
self.assertEqual(len(out), 100)
|
||||
|
||||
def test_empty_bytes_returns_empty(self):
|
||||
out = self.mod.pcm16_to_float(b"")
|
||||
self.assertEqual(len(out), 0)
|
||||
|
||||
def test_odd_byte_ignored(self):
|
||||
# 3 bytes → 1 complete int16 + 1 orphan byte
|
||||
out = self.mod.pcm16_to_float(b"\x00\x40\xff")
|
||||
self.assertEqual(len(out), 1)
|
||||
|
||||
|
||||
# ── Tests: rms ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRms(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_zero_signal(self):
|
||||
self.assertAlmostEqual(self.mod.rms(np.zeros(100)), 0.0)
|
||||
|
||||
def test_constant_signal(self):
|
||||
# RMS of constant 0.5 = 0.5
|
||||
self.assertAlmostEqual(self.mod.rms(np.full(100, 0.5)), 0.5, places=5)
|
||||
|
||||
def test_sine_rms(self):
|
||||
# RMS of sin = amp / sqrt(2)
|
||||
amp = 0.8
|
||||
s = _sine(440, 1.0, amp=amp)
|
||||
expected = amp / math.sqrt(2)
|
||||
self.assertAlmostEqual(self.mod.rms(s), expected, places=2)
|
||||
|
||||
def test_empty_array(self):
|
||||
self.assertEqual(self.mod.rms(np.array([])), 0.0)
|
||||
|
||||
|
||||
# ── Tests: mel_filterbank ─────────────────────────────────────────────────────
|
||||
|
||||
class TestMelFilterbank(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_shape(self):
|
||||
fb = self.mod.mel_filterbank(16000, 512, 40)
|
||||
self.assertEqual(fb.shape, (40, 257))
|
||||
|
||||
def test_non_negative(self):
|
||||
fb = self.mod.mel_filterbank(16000, 512, 40)
|
||||
self.assertTrue((fb >= 0).all())
|
||||
|
||||
def test_rows_sum_positive(self):
|
||||
fb = self.mod.mel_filterbank(16000, 512, 40)
|
||||
self.assertTrue((fb.sum(axis=1) > 0).all())
|
||||
|
||||
def test_custom_n_mels(self):
|
||||
fb = self.mod.mel_filterbank(16000, 256, 20)
|
||||
self.assertEqual(fb.shape[0], 20)
|
||||
|
||||
|
||||
# ── Tests: compute_log_mel ────────────────────────────────────────────────────
|
||||
|
||||
class TestComputeLogMel(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_output_shape_rows(self):
|
||||
s = _sine(440, 1.5)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
self.assertEqual(out.shape[0], 40)
|
||||
|
||||
def test_output_has_time_axis(self):
|
||||
s = _sine(440, 1.5)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
self.assertGreater(out.shape[1], 0)
|
||||
|
||||
def test_output_finite(self):
|
||||
s = _sine(440, 1.5)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
self.assertTrue(np.isfinite(out).all())
|
||||
|
||||
def test_silence_gives_low_values(self):
|
||||
s = _silence(1.5)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
# All values should be very small (near log(1e-10))
|
||||
self.assertTrue((out < -20).all())
|
||||
|
||||
def test_short_signal_no_crash(self):
|
||||
# Shorter than one FFT frame
|
||||
s = np.zeros(100, dtype=np.float32)
|
||||
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
|
||||
self.assertEqual(out.shape[0], 40)
|
||||
|
||||
|
||||
# ── Tests: cosine_sim ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestCosineSim(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_identical_vectors(self):
|
||||
v = np.array([1.0, 2.0, 3.0])
|
||||
self.assertAlmostEqual(self.mod.cosine_sim(v, v), 1.0, places=5)
|
||||
|
||||
def test_orthogonal_vectors(self):
|
||||
a = np.array([1.0, 0.0])
|
||||
b = np.array([0.0, 1.0])
|
||||
self.assertAlmostEqual(self.mod.cosine_sim(a, b), 0.0, places=5)
|
||||
|
||||
def test_opposite_vectors(self):
|
||||
v = np.array([1.0, 2.0, 3.0])
|
||||
self.assertAlmostEqual(self.mod.cosine_sim(v, -v), -1.0, places=5)
|
||||
|
||||
def test_zero_vector_returns_zero(self):
|
||||
a = np.zeros(5)
|
||||
b = np.ones(5)
|
||||
self.assertEqual(self.mod.cosine_sim(a, b), 0.0)
|
||||
|
||||
def test_2d_arrays(self):
|
||||
a = np.ones((4, 10))
|
||||
b = np.ones((4, 10))
|
||||
self.assertAlmostEqual(self.mod.cosine_sim(a, b), 1.0, places=5)
|
||||
|
||||
def test_different_lengths_truncated(self):
|
||||
a = np.array([1.0, 2.0, 3.0, 4.0])
|
||||
b = np.array([1.0, 2.0, 3.0])
|
||||
# Should not crash, uses min length
|
||||
result = self.mod.cosine_sim(a, b)
|
||||
self.assertTrue(-1.0 <= result <= 1.0)
|
||||
|
||||
def test_range_is_bounded(self):
|
||||
rng = np.random.default_rng(42)
|
||||
a = rng.standard_normal(100)
|
||||
b = rng.standard_normal(100)
|
||||
result = self.mod.cosine_sim(a, b)
|
||||
self.assertGreaterEqual(result, -1.0)
|
||||
self.assertLessEqual(result, 1.0)
|
||||
|
||||
|
||||
# ── Tests: AudioRingBuffer ────────────────────────────────────────────────────
|
||||
|
||||
class TestAudioRingBuffer(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _buf(self, max_samples=1000):
|
||||
return self.mod.AudioRingBuffer(max_samples)
|
||||
|
||||
def test_empty_initially(self):
|
||||
b = self._buf()
|
||||
self.assertEqual(len(b), 0)
|
||||
|
||||
def test_push_increases_len(self):
|
||||
b = self._buf()
|
||||
b.push(np.ones(100, dtype=np.float32))
|
||||
self.assertEqual(len(b), 100)
|
||||
|
||||
def test_get_window_none_when_short(self):
|
||||
b = self._buf()
|
||||
b.push(np.ones(50, dtype=np.float32))
|
||||
self.assertIsNone(b.get_window(100))
|
||||
|
||||
def test_get_window_ok_when_full(self):
|
||||
b = self._buf()
|
||||
data = np.arange(200, dtype=np.float32)
|
||||
b.push(data)
|
||||
w = b.get_window(100)
|
||||
self.assertIsNotNone(w)
|
||||
self.assertEqual(len(w), 100)
|
||||
|
||||
def test_get_window_returns_latest(self):
|
||||
b = self._buf()
|
||||
b.push(np.zeros(100, dtype=np.float32))
|
||||
b.push(np.ones(100, dtype=np.float32))
|
||||
w = b.get_window(100)
|
||||
np.testing.assert_allclose(w, np.ones(100))
|
||||
|
||||
def test_maxlen_evicts_oldest(self):
|
||||
b = self._buf(max_samples=100)
|
||||
b.push(np.zeros(60, dtype=np.float32))
|
||||
b.push(np.ones(60, dtype=np.float32)) # should evict 20 zeros
|
||||
self.assertEqual(len(b), 100)
|
||||
w = b.get_window(100)
|
||||
# Last 40 samples should be ones
|
||||
np.testing.assert_allclose(w[-40:], np.ones(40))
|
||||
|
||||
def test_exact_window_size(self):
|
||||
b = self._buf(500)
|
||||
data = np.arange(300, dtype=np.float32)
|
||||
b.push(data)
|
||||
w = b.get_window(300)
|
||||
np.testing.assert_allclose(w, data)
|
||||
|
||||
|
||||
# ── Tests: WakeWordDetector ───────────────────────────────────────────────────
|
||||
|
||||
class TestWakeWordDetector(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
cls.sr = 16000
|
||||
cls.template = _make_template(cls.mod, cls.sr)
|
||||
|
||||
def _det(self, template=None, energy_thr=0.02, match_thr=0.82):
|
||||
return self.mod.WakeWordDetector(template, energy_thr, match_thr,
|
||||
self.sr, 512, 40)
|
||||
|
||||
def test_no_template_never_detects(self):
|
||||
det = self._det(template=None)
|
||||
loud = _sine(300, 1.5, self.sr, amp=0.8)
|
||||
detected, _, _ = det.detect(loud)
|
||||
self.assertFalse(detected)
|
||||
|
||||
def test_has_template_false_when_no_template(self):
|
||||
det = self._det(template=None)
|
||||
self.assertFalse(det.has_template)
|
||||
|
||||
def test_has_template_true_when_set(self):
|
||||
det = self._det(template=self.template)
|
||||
self.assertTrue(det.has_template)
|
||||
|
||||
def test_silence_below_energy_gate(self):
|
||||
det = self._det(template=self.template, energy_thr=0.02)
|
||||
silence = _silence(1.5, self.sr)
|
||||
detected, rms_val, sim = det.detect(silence)
|
||||
self.assertFalse(detected)
|
||||
self.assertAlmostEqual(rms_val, 0.0, places=4)
|
||||
self.assertAlmostEqual(sim, 0.0, places=4)
|
||||
|
||||
def test_returns_rms_value(self):
|
||||
det = self._det(template=None)
|
||||
s = _sine(300, 1.5, self.sr, amp=0.5)
|
||||
_, rms_val, _ = det.detect(s)
|
||||
expected = 0.5 / math.sqrt(2)
|
||||
self.assertAlmostEqual(rms_val, expected, places=2)
|
||||
|
||||
def test_identical_signal_detects(self):
|
||||
"""Signal identical to template should give sim ≈ 1.0 → detect."""
|
||||
signal = _sine(300, 1.5, self.sr, amp=0.6)
|
||||
det = self._det(template=self.template, match_thr=0.99)
|
||||
detected, _, sim = det.detect(signal)
|
||||
# Sim must be very high for an identical-source signal
|
||||
self.assertGreater(sim, 0.99)
|
||||
self.assertTrue(detected)
|
||||
|
||||
def test_different_signal_low_sim(self):
|
||||
"""A very different signal (white noise) should have low similarity."""
|
||||
rng = np.random.default_rng(7)
|
||||
noise = (rng.standard_normal(int(self.sr * 1.5)) * 0.4).astype(np.float32)
|
||||
det = self._det(template=self.template, match_thr=0.82)
|
||||
_, _, sim = det.detect(noise)
|
||||
# White noise sim to a tonal template should be < 0.6
|
||||
self.assertLess(sim, 0.6)
|
||||
|
||||
def test_threshold_boundary_low(self):
|
||||
"""Setting match_thr=0.0 with a loud signal should fire if template set."""
|
||||
signal = _sine(300, 1.5, self.sr, amp=0.6)
|
||||
det = self._det(template=self.template, match_thr=0.0)
|
||||
detected, _, _ = det.detect(signal)
|
||||
self.assertTrue(detected)
|
||||
|
||||
def test_threshold_boundary_high(self):
|
||||
"""Setting match_thr=1.1 (above max) should never fire."""
|
||||
signal = _sine(300, 1.5, self.sr, amp=0.6)
|
||||
det = self._det(template=self.template, match_thr=1.1)
|
||||
detected, _, _ = det.detect(signal)
|
||||
self.assertFalse(detected)
|
||||
|
||||
def test_energy_below_threshold_skips_matching(self):
|
||||
"""Low energy → sim returned as 0.0 regardless of template."""
|
||||
very_quiet = _sine(300, 1.5, self.sr, amp=0.001)
|
||||
det = self._det(template=self.template, energy_thr=0.1)
|
||||
detected, rms_val, sim = det.detect(very_quiet)
|
||||
self.assertFalse(detected)
|
||||
self.assertAlmostEqual(sim, 0.0, places=5)
|
||||
|
||||
|
||||
# ── Tests: node init ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeInit(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_instantiates(self):
|
||||
self.assertIsNotNone(_make_node(self.mod))
|
||||
|
||||
def test_pub_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/saltybot/wake_word_detected", node._pubs)
|
||||
|
||||
def test_sub_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/social/speech/audio_raw", node._subs)
|
||||
|
||||
def test_timer_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertGreater(len(node._timers), 0)
|
||||
|
||||
def test_custom_topics(self):
|
||||
node = _make_node(self.mod,
|
||||
audio_topic="/my/audio",
|
||||
output_topic="/my/wake")
|
||||
self.assertIn("/my/audio", node._subs)
|
||||
self.assertIn("/my/wake", node._pubs)
|
||||
|
||||
def test_no_template_passive(self):
|
||||
node = _make_node(self.mod, template_path="")
|
||||
self.assertFalse(node._detector.has_template)
|
||||
|
||||
def test_bad_template_path_warns(self):
|
||||
node = _make_node(self.mod, template_path="/nonexistent/template.npy")
|
||||
warns = [m for lvl, m in node._logs if lvl == "WARN"]
|
||||
self.assertTrue(any("template" in m.lower() or "passive" in m.lower()
|
||||
for m in warns))
|
||||
|
||||
def test_ring_buffer_allocated(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIsNotNone(node._buf)
|
||||
|
||||
def test_window_n_computed(self):
|
||||
node = _make_node(self.mod, sample_rate=16000, window_s=1.5)
|
||||
self.assertEqual(node._win_n, 24000)
|
||||
|
||||
|
||||
# ── Tests: _on_audio callback ─────────────────────────────────────────────────
|
||||
|
||||
class TestOnAudio(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
|
||||
def _push(self, samples):
|
||||
msg = _make_audio_msg(samples)
|
||||
self.node._subs["/social/speech/audio_raw"](msg)
|
||||
|
||||
def test_pushes_samples_to_buffer(self):
|
||||
self._push(_sine(440, 0.5))
|
||||
self.assertGreater(len(self.node._buf), 0)
|
||||
|
||||
def test_buffer_grows_with_pushes(self):
|
||||
chunk = _sine(440, 0.1)
|
||||
before = len(self.node._buf)
|
||||
self._push(chunk)
|
||||
after = len(self.node._buf)
|
||||
self.assertGreater(after, before)
|
||||
|
||||
def test_bad_data_no_crash(self):
|
||||
msg = _UInt8MultiArray()
|
||||
msg.data = b"\xff" # 1 orphan byte — yields 0 samples, no crash
|
||||
self.node._subs["/social/speech/audio_raw"](msg)
|
||||
|
||||
def test_multiple_chunks_accumulate(self):
|
||||
for _ in range(5):
|
||||
self._push(_sine(440, 0.1))
|
||||
self.assertGreater(len(self.node._buf), 0)
|
||||
|
||||
|
||||
# ── Tests: detection callback ─────────────────────────────────────────────────
|
||||
|
||||
class TestDetectionCallback(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _node_with_template(self, **kwargs):
|
||||
template = _make_template(self.mod)
|
||||
node = _make_node(self.mod, **kwargs)
|
||||
node._detector = self.mod.WakeWordDetector(
|
||||
template, energy_threshold=0.02, match_threshold=0.0, # thr=0 → always fires when loud
|
||||
sample_rate=16000, n_fft=512, n_mels=40
|
||||
)
|
||||
return node
|
||||
|
||||
def _fill_buffer(self, node, signal):
|
||||
msg = _make_audio_msg(signal)
|
||||
node._subs["/social/speech/audio_raw"](msg)
|
||||
|
||||
def test_no_data_no_publish(self):
|
||||
node = _make_node(self.mod)
|
||||
node._detection_cb()
|
||||
self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0)
|
||||
|
||||
def test_insufficient_buffer_no_publish(self):
|
||||
node = self._node_with_template()
|
||||
# Push only 0.1 s but window is 1.5 s
|
||||
self._fill_buffer(node, _sine(300, 0.1))
|
||||
node._detection_cb()
|
||||
self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0)
|
||||
|
||||
def test_detects_and_publishes_true(self):
|
||||
node = self._node_with_template()
|
||||
# Fill with a loud 300 Hz sine (matches template)
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 1)
|
||||
self.assertTrue(pub.msgs[0].data)
|
||||
|
||||
def test_cooldown_suppresses_second_detection(self):
|
||||
node = self._node_with_template(cooldown_s=60.0)
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
# Second call immediately → cooldown active
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 1)
|
||||
|
||||
def test_cooldown_expired_allows_second(self):
|
||||
node = self._node_with_template(cooldown_s=0.0)
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 2)
|
||||
|
||||
def test_no_template_never_publishes(self):
|
||||
node = _make_node(self.mod, template_path="")
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.8))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 0)
|
||||
|
||||
def test_silence_no_publish(self):
|
||||
node = self._node_with_template()
|
||||
self._fill_buffer(node, _silence(1.5))
|
||||
node._detection_cb()
|
||||
pub = node._pubs["/saltybot/wake_word_detected"]
|
||||
self.assertEqual(len(pub.msgs), 0)
|
||||
|
||||
def test_detection_logs_info(self):
|
||||
node = self._node_with_template()
|
||||
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
|
||||
node._detection_cb()
|
||||
infos = [m for lvl, m in node._logs if lvl == "INFO"]
|
||||
self.assertTrue(any("detected" in m.lower() or "hey salty" in m.lower()
|
||||
for m in infos))
|
||||
|
||||
|
||||
# ── Tests: source content ─────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open(_SRC) as f: cls.src = f.read()
|
||||
|
||||
def test_issue_tag(self): self.assertIn("#320", self.src)
|
||||
def test_audio_topic(self): self.assertIn("/social/speech/audio_raw", self.src)
|
||||
def test_output_topic(self): self.assertIn("/saltybot/wake_word_detected", self.src)
|
||||
def test_wake_word_name(self): self.assertIn("hey salty", self.src)
|
||||
def test_compute_log_mel(self): self.assertIn("compute_log_mel", self.src)
|
||||
def test_cosine_sim(self): self.assertIn("cosine_sim", self.src)
|
||||
def test_mel_filterbank(self): self.assertIn("mel_filterbank", self.src)
|
||||
def test_audio_ring_buffer(self): self.assertIn("AudioRingBuffer", self.src)
|
||||
def test_wake_word_detector(self): self.assertIn("WakeWordDetector", self.src)
|
||||
def test_energy_threshold(self): self.assertIn("energy_threshold", self.src)
|
||||
def test_match_threshold(self): self.assertIn("match_threshold", self.src)
|
||||
def test_cooldown(self): self.assertIn("cooldown_s", self.src)
|
||||
def test_template_path(self): self.assertIn("template_path", self.src)
|
||||
def test_pcm16_decode(self): self.assertIn("pcm16_to_float", self.src)
|
||||
def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
|
||||
def test_numpy_used(self): self.assertIn("import numpy", self.src)
|
||||
def test_main_defined(self): self.assertIn("def main", self.src)
|
||||
def test_uint8_multiarray(self): self.assertIn("UInt8MultiArray", self.src)
|
||||
|
||||
|
||||
# ── Tests: config / launch / setup ────────────────────────────────────────────
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
_CONFIG = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/config/wake_word_params.yaml"
|
||||
)
|
||||
_LAUNCH = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/launch/wake_word.launch.py"
|
||||
)
|
||||
_SETUP = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/setup.py"
|
||||
)
|
||||
|
||||
def test_config_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._CONFIG))
|
||||
|
||||
def test_config_energy_threshold(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("energy_threshold", c)
|
||||
|
||||
def test_config_match_threshold(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("match_threshold", c)
|
||||
|
||||
def test_config_template_path(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("template_path", c)
|
||||
|
||||
def test_config_cooldown(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("cooldown_s", c)
|
||||
|
||||
def test_launch_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._LAUNCH))
|
||||
|
||||
def test_launch_has_template_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("template_path", c)
|
||||
|
||||
def test_launch_has_threshold_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("match_threshold", c)
|
||||
|
||||
def test_entry_point_in_setup(self):
|
||||
with open(self._SETUP) as f: c = f.read()
|
||||
self.assertIn("wake_word_node", c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -0,0 +1,12 @@
|
||||
# Velocity Smoother Configuration
|
||||
velocity_smoother:
|
||||
ros__parameters:
|
||||
# Filter parameters
|
||||
filter_order: 2 # Butterworth filter order (1-4 typical)
|
||||
cutoff_frequency: 5.0 # Cutoff frequency in Hz (lower = more smoothing)
|
||||
|
||||
# Publishing frequency (Hz)
|
||||
publish_frequency: 50
|
||||
|
||||
# Enable/disable filtering
|
||||
enable_smoothing: true
|
||||
@ -0,0 +1,28 @@
|
||||
import os
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch_ros.substitutions import FindPackageShare
|
||||
from launch.substitutions import PathJoinSubstitution
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
config_dir = PathJoinSubstitution(
|
||||
[FindPackageShare('saltybot_velocity_smoother'), 'config']
|
||||
)
|
||||
config_file = PathJoinSubstitution([config_dir, 'velocity_smoother_config.yaml'])
|
||||
|
||||
velocity_smoother = Node(
|
||||
package='saltybot_velocity_smoother',
|
||||
executable='velocity_smoother_node',
|
||||
name='velocity_smoother',
|
||||
output='screen',
|
||||
parameters=[config_file],
|
||||
remappings=[
|
||||
('/odom', '/odom'),
|
||||
('/odom_smooth', '/odom_smooth'),
|
||||
],
|
||||
)
|
||||
|
||||
return LaunchDescription([
|
||||
velocity_smoother,
|
||||
])
|
||||
29
jetson/ros2_ws/src/saltybot_velocity_smoother/package.xml
Normal file
29
jetson/ros2_ws/src/saltybot_velocity_smoother/package.xml
Normal file
@ -0,0 +1,29 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_velocity_smoother</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Low-pass Butterworth filter for odometry velocity smoothing to reduce encoder jitter</description>
|
||||
|
||||
<maintainer email="sl-controls@saltybot.local">SaltyBot Controls</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<author email="sl-controls@saltybot.local">SaltyBot Controls Team</author>
|
||||
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
<buildtool_depend>ament_cmake_python</buildtool_depend>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>nav_msgs</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Velocity smoother node with low-pass Butterworth filter.
|
||||
|
||||
Subscribes to /odom, applies low-pass Butterworth filter to linear and angular
|
||||
velocity components, and publishes smoothed odometry on /odom_smooth.
|
||||
|
||||
Reduces noise from encoder jitter and improves state estimation stability.
|
||||
"""
|
||||
|
||||
import math
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from nav_msgs.msg import Odometry
|
||||
from std_msgs.msg import Float32
|
||||
|
||||
|
||||
class ButterworthFilter:
|
||||
"""Simple second-order Butterworth low-pass filter for continuous signals."""
|
||||
|
||||
def __init__(self, cutoff_hz, sample_rate_hz, order=2):
|
||||
"""Initialize Butterworth filter.
|
||||
|
||||
Args:
|
||||
cutoff_hz: Cutoff frequency in Hz
|
||||
sample_rate_hz: Sampling rate in Hz
|
||||
order: Filter order (typically 1-4)
|
||||
"""
|
||||
self.cutoff_hz = cutoff_hz
|
||||
self.sample_rate_hz = sample_rate_hz
|
||||
self.order = order
|
||||
|
||||
# Normalized frequency (0 to 1, where 1 = Nyquist)
|
||||
self.omega_n = 2.0 * math.pi * cutoff_hz / sample_rate_hz
|
||||
|
||||
# Simplified filter coefficients for order 2
|
||||
# Using canonical form: y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]
|
||||
if order == 1:
|
||||
# First-order filter
|
||||
alpha = self.omega_n / (self.omega_n + 2.0)
|
||||
self.b = [alpha, alpha]
|
||||
self.a = [1.0, -(1.0 - 2.0 * alpha)]
|
||||
else:
|
||||
# Second-order filter (butterworth)
|
||||
sqrt2 = math.sqrt(2.0)
|
||||
wc = math.tan(self.omega_n / 2.0)
|
||||
wc2 = wc * wc
|
||||
|
||||
denom = 1.0 + sqrt2 * wc + wc2
|
||||
|
||||
self.b = [wc2 / denom, 2.0 * wc2 / denom, wc2 / denom]
|
||||
self.a = [1.0,
|
||||
2.0 * (wc2 - 1.0) / denom,
|
||||
(1.0 - sqrt2 * wc + wc2) / denom]
|
||||
|
||||
# State buffers
|
||||
self.x_history = [0.0, 0.0, 0.0] # Input history
|
||||
self.y_history = [0.0, 0.0] # Output history
|
||||
|
||||
def filter(self, x):
|
||||
"""Apply filter to input value.
|
||||
|
||||
Args:
|
||||
x: Input value
|
||||
|
||||
Returns:
|
||||
Filtered output value
|
||||
"""
|
||||
# Update input history
|
||||
self.x_history[2] = self.x_history[1]
|
||||
self.x_history[1] = self.x_history[0]
|
||||
self.x_history[0] = x
|
||||
|
||||
# Compute output using difference equation
|
||||
if len(self.b) == 2:
|
||||
# First-order filter
|
||||
y = (self.b[0] * self.x_history[0] +
|
||||
self.b[1] * self.x_history[1] -
|
||||
self.a[1] * self.y_history[1])
|
||||
else:
|
||||
# Second-order filter
|
||||
y = (self.b[0] * self.x_history[0] +
|
||||
self.b[1] * self.x_history[1] +
|
||||
self.b[2] * self.x_history[2] -
|
||||
self.a[1] * self.y_history[1] -
|
||||
self.a[2] * self.y_history[2])
|
||||
|
||||
# Update output history
|
||||
self.y_history[1] = self.y_history[0]
|
||||
self.y_history[0] = y
|
||||
|
||||
return y
|
||||
|
||||
def reset(self):
|
||||
"""Reset filter state."""
|
||||
self.x_history = [0.0, 0.0, 0.0]
|
||||
self.y_history = [0.0, 0.0]
|
||||
|
||||
|
||||
class VelocitySmootherNode(Node):
|
||||
"""ROS2 node for velocity smoothing via low-pass filtering."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__('velocity_smoother')
|
||||
|
||||
# Parameters
|
||||
self.declare_parameter('filter_order', 2)
|
||||
self.declare_parameter('cutoff_frequency', 5.0)
|
||||
self.declare_parameter('publish_frequency', 50)
|
||||
self.declare_parameter('enable_smoothing', True)
|
||||
|
||||
filter_order = self.get_parameter('filter_order').value
|
||||
cutoff_frequency = self.get_parameter('cutoff_frequency').value
|
||||
publish_frequency = self.get_parameter('publish_frequency').value
|
||||
self.enable_smoothing = self.get_parameter('enable_smoothing').value
|
||||
|
||||
# Create filters for each velocity component
|
||||
self.filter_linear_x = ButterworthFilter(
|
||||
cutoff_frequency, publish_frequency, order=filter_order
|
||||
)
|
||||
self.filter_linear_y = ButterworthFilter(
|
||||
cutoff_frequency, publish_frequency, order=filter_order
|
||||
)
|
||||
self.filter_linear_z = ButterworthFilter(
|
||||
cutoff_frequency, publish_frequency, order=filter_order
|
||||
)
|
||||
self.filter_angular_x = ButterworthFilter(
|
||||
cutoff_frequency, publish_frequency, order=filter_order
|
||||
)
|
||||
self.filter_angular_y = ButterworthFilter(
|
||||
cutoff_frequency, publish_frequency, order=filter_order
|
||||
)
|
||||
self.filter_angular_z = ButterworthFilter(
|
||||
cutoff_frequency, publish_frequency, order=filter_order
|
||||
)
|
||||
|
||||
# Last received odometry
|
||||
self.last_odom = None
|
||||
|
||||
# Subscription to raw odometry
|
||||
self.sub_odom = self.create_subscription(
|
||||
Odometry, '/odom', self._on_odom, 10
|
||||
)
|
||||
|
||||
# Publisher for smoothed odometry
|
||||
self.pub_odom_smooth = self.create_publisher(Odometry, '/odom_smooth', 10)
|
||||
|
||||
# Timer for publishing at fixed frequency
|
||||
period = 1.0 / publish_frequency
|
||||
self.timer = self.create_timer(period, self._timer_callback)
|
||||
|
||||
self.get_logger().info(
|
||||
f"Velocity smoother initialized. "
|
||||
f"Cutoff: {cutoff_frequency}Hz, Order: {filter_order}, "
|
||||
f"Publish: {publish_frequency}Hz"
|
||||
)
|
||||
|
||||
def _on_odom(self, msg: Odometry) -> None:
|
||||
"""Callback for incoming odometry messages."""
|
||||
self.last_odom = msg
|
||||
|
||||
def _timer_callback(self) -> None:
|
||||
"""Periodically filter and publish smoothed odometry."""
|
||||
if self.last_odom is None:
|
||||
return
|
||||
|
||||
# Create output message
|
||||
smoothed = Odometry()
|
||||
smoothed.header = self.last_odom.header
|
||||
smoothed.child_frame_id = self.last_odom.child_frame_id
|
||||
|
||||
# Copy pose (unchanged)
|
||||
smoothed.pose = self.last_odom.pose
|
||||
|
||||
if self.enable_smoothing:
|
||||
# Filter velocity components
|
||||
linear_x = self.filter_linear_x.filter(
|
||||
self.last_odom.twist.twist.linear.x
|
||||
)
|
||||
linear_y = self.filter_linear_y.filter(
|
||||
self.last_odom.twist.twist.linear.y
|
||||
)
|
||||
linear_z = self.filter_linear_z.filter(
|
||||
self.last_odom.twist.twist.linear.z
|
||||
)
|
||||
|
||||
angular_x = self.filter_angular_x.filter(
|
||||
self.last_odom.twist.twist.angular.x
|
||||
)
|
||||
angular_y = self.filter_angular_y.filter(
|
||||
self.last_odom.twist.twist.angular.y
|
||||
)
|
||||
angular_z = self.filter_angular_z.filter(
|
||||
self.last_odom.twist.twist.angular.z
|
||||
)
|
||||
else:
|
||||
# No filtering
|
||||
linear_x = self.last_odom.twist.twist.linear.x
|
||||
linear_y = self.last_odom.twist.twist.linear.y
|
||||
linear_z = self.last_odom.twist.twist.linear.z
|
||||
angular_x = self.last_odom.twist.twist.angular.x
|
||||
angular_y = self.last_odom.twist.twist.angular.y
|
||||
angular_z = self.last_odom.twist.twist.angular.z
|
||||
|
||||
# Set smoothed twist
|
||||
smoothed.twist.twist.linear.x = linear_x
|
||||
smoothed.twist.twist.linear.y = linear_y
|
||||
smoothed.twist.twist.linear.z = linear_z
|
||||
smoothed.twist.twist.angular.x = angular_x
|
||||
smoothed.twist.twist.angular.y = angular_y
|
||||
smoothed.twist.twist.angular.z = angular_z
|
||||
|
||||
# Copy covariances
|
||||
smoothed.twist.covariance = self.last_odom.twist.covariance
|
||||
|
||||
self.pub_odom_smooth.publish(smoothed)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = VelocitySmootherNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
5
jetson/ros2_ws/src/saltybot_velocity_smoother/setup.cfg
Normal file
5
jetson/ros2_ws/src/saltybot_velocity_smoother/setup.cfg
Normal file
@ -0,0 +1,5 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_velocity_smoother
|
||||
[egg_info]
|
||||
tag_build =
|
||||
tag_date = 0
|
||||
34
jetson/ros2_ws/src/saltybot_velocity_smoother/setup.py
Normal file
34
jetson/ros2_ws/src/saltybot_velocity_smoother/setup.py
Normal file
@ -0,0 +1,34 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
package_name = 'saltybot_velocity_smoother'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.1.0',
|
||||
packages=find_packages(exclude=['test']),
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages',
|
||||
['resource/saltybot_velocity_smoother']),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
('share/' + package_name + '/config', ['config/velocity_smoother_config.yaml']),
|
||||
('share/' + package_name + '/launch', ['launch/velocity_smoother.launch.py']),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
author='SaltyBot Controls',
|
||||
author_email='sl-controls@saltybot.local',
|
||||
maintainer='SaltyBot Controls',
|
||||
maintainer_email='sl-controls@saltybot.local',
|
||||
keywords=['ROS2', 'velocity', 'filtering', 'butterworth'],
|
||||
classifiers=[
|
||||
'Intended Audience :: Developers',
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Topic :: Software Development',
|
||||
],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'velocity_smoother_node=saltybot_velocity_smoother.velocity_smoother_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,432 @@
|
||||
"""Unit tests for velocity smoother node."""
|
||||
|
||||
import pytest
|
||||
import math
|
||||
from nav_msgs.msg import Odometry
|
||||
from geometry_msgs.msg import TwistWithCovariance, Twist, Vector3
|
||||
from std_msgs.msg import Header
|
||||
|
||||
import rclpy
|
||||
from rclpy.time import Time
|
||||
|
||||
from saltybot_velocity_smoother.velocity_smoother_node import (
|
||||
VelocitySmootherNode,
|
||||
ButterworthFilter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rclpy_fixture():
|
||||
"""Initialize and cleanup rclpy."""
|
||||
rclpy.init()
|
||||
yield
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node(rclpy_fixture):
|
||||
"""Create a velocity smoother node instance."""
|
||||
node = VelocitySmootherNode()
|
||||
yield node
|
||||
node.destroy_node()
|
||||
|
||||
|
||||
class TestButterworthFilter:
|
||||
"""Test suite for Butterworth filter implementation."""
|
||||
|
||||
def test_filter_initialization(self):
|
||||
"""Test filter initialization with valid parameters."""
|
||||
filt = ButterworthFilter(5.0, 50, order=2)
|
||||
assert filt.cutoff_hz == 5.0
|
||||
assert filt.sample_rate_hz == 50
|
||||
assert filt.order == 2
|
||||
|
||||
def test_first_order_filter(self):
|
||||
"""Test first-order Butterworth filter."""
|
||||
filt = ButterworthFilter(5.0, 50, order=1)
|
||||
assert len(filt.b) == 2
|
||||
assert len(filt.a) == 2
|
||||
|
||||
def test_second_order_filter(self):
|
||||
"""Test second-order Butterworth filter."""
|
||||
filt = ButterworthFilter(5.0, 50, order=2)
|
||||
assert len(filt.b) == 3
|
||||
assert len(filt.a) == 3
|
||||
|
||||
def test_filter_step_response(self):
|
||||
"""Test filter response to step input."""
|
||||
filt = ButterworthFilter(5.0, 50, order=2)
|
||||
|
||||
# Apply step input (0 -> 1.0)
|
||||
outputs = []
|
||||
for i in range(50):
|
||||
y = filt.filter(1.0)
|
||||
outputs.append(y)
|
||||
|
||||
# Final output should be close to 1.0
|
||||
assert outputs[-1] > 0.9
|
||||
assert outputs[-1] <= 1.0
|
||||
|
||||
def test_filter_constant_input(self):
|
||||
"""Test filter with constant input."""
|
||||
filt = ButterworthFilter(5.0, 50, order=2)
|
||||
|
||||
# Apply constant input
|
||||
for i in range(100):
|
||||
y = filt.filter(2.5)
|
||||
|
||||
# Should converge to input value
|
||||
assert abs(y - 2.5) < 0.01
|
||||
|
||||
def test_filter_zero_input(self):
|
||||
"""Test filter with zero input."""
|
||||
filt = ButterworthFilter(5.0, 50, order=2)
|
||||
|
||||
# Apply non-zero then zero
|
||||
for i in range(50):
|
||||
filt.filter(1.0)
|
||||
|
||||
# Now apply zero
|
||||
for i in range(50):
|
||||
y = filt.filter(0.0)
|
||||
|
||||
# Should decay to zero
|
||||
assert abs(y) < 0.01
|
||||
|
||||
def test_filter_reset(self):
|
||||
"""Test filter state reset."""
|
||||
filt = ButterworthFilter(5.0, 50, order=2)
|
||||
|
||||
# Filter some values
|
||||
for i in range(10):
|
||||
filt.filter(1.0)
|
||||
|
||||
# Reset
|
||||
filt.reset()
|
||||
|
||||
# State should be zero
|
||||
assert filt.x_history == [0.0, 0.0, 0.0]
|
||||
assert filt.y_history == [0.0, 0.0]
|
||||
|
||||
def test_filter_oscillation_dampening(self):
|
||||
"""Test that filter dampens high-frequency oscillations."""
|
||||
filt = ButterworthFilter(5.0, 50, order=2)
|
||||
|
||||
# Apply alternating signal (high frequency)
|
||||
outputs = []
|
||||
for i in range(100):
|
||||
x = 1.0 if i % 2 == 0 else -1.0
|
||||
y = filt.filter(x)
|
||||
outputs.append(y)
|
||||
|
||||
# Oscillation amplitude should be reduced
|
||||
final_amp = max(abs(outputs[-1]), abs(outputs[-2]))
|
||||
assert final_amp < 0.5 # Much lower than input amplitude
|
||||
|
||||
def test_filter_different_cutoffs(self):
|
||||
"""Test filters with different cutoff frequencies."""
|
||||
filt_low = ButterworthFilter(2.0, 50, order=2)
|
||||
filt_high = ButterworthFilter(10.0, 50, order=2)
|
||||
|
||||
# Both should be valid
|
||||
assert filt_low.cutoff_hz == 2.0
|
||||
assert filt_high.cutoff_hz == 10.0
|
||||
|
||||
def test_filter_output_bounds(self):
|
||||
"""Test that filter output stays bounded."""
|
||||
filt = ButterworthFilter(5.0, 50, order=2)
|
||||
|
||||
# Apply large random-like values
|
||||
for i in range(100):
|
||||
x = math.sin(i * 0.5) * 5.0
|
||||
y = filt.filter(x)
|
||||
assert abs(y) < 10.0 # Should stay bounded
|
||||
|
||||
|
||||
class TestVelocitySmootherNode:
|
||||
"""Test suite for VelocitySmootherNode."""
|
||||
|
||||
def test_node_initialization(self, node):
|
||||
"""Test that node initializes correctly."""
|
||||
assert node.last_odom is None
|
||||
assert node.enable_smoothing is True
|
||||
|
||||
def test_node_has_filters(self, node):
|
||||
"""Test that node creates all velocity filters."""
|
||||
assert node.filter_linear_x is not None
|
||||
assert node.filter_linear_y is not None
|
||||
assert node.filter_linear_z is not None
|
||||
assert node.filter_angular_x is not None
|
||||
assert node.filter_angular_y is not None
|
||||
assert node.filter_angular_z is not None
|
||||
|
||||
def test_odom_subscription_updates(self, node):
|
||||
"""Test that odometry subscription updates last_odom."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.twist.twist.linear.x = 1.0
|
||||
|
||||
node._on_odom(odom)
|
||||
|
||||
assert node.last_odom is not None
|
||||
assert node.last_odom.twist.twist.linear.x == 1.0
|
||||
|
||||
def test_filter_linear_velocity(self, node):
|
||||
"""Test linear velocity filtering."""
|
||||
# Create odometry message
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.child_frame_id = "base_link"
|
||||
odom.twist.twist.linear.x = 1.0
|
||||
odom.twist.twist.linear.y = 0.5
|
||||
odom.twist.twist.linear.z = 0.0
|
||||
odom.twist.twist.angular.x = 0.0
|
||||
odom.twist.twist.angular.y = 0.0
|
||||
odom.twist.twist.angular.z = 0.2
|
||||
|
||||
node._on_odom(odom)
|
||||
|
||||
# Call timer callback to process
|
||||
node._timer_callback()
|
||||
|
||||
# Filter should have been applied
|
||||
assert node.filter_linear_x.x_history[0] == 1.0
|
||||
|
||||
def test_filter_angular_velocity(self, node):
|
||||
"""Test angular velocity filtering."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.twist.twist.angular.z = 0.5
|
||||
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
assert node.filter_angular_z.x_history[0] == 0.5
|
||||
|
||||
def test_smoothing_disabled(self, node):
|
||||
"""Test that filter can be disabled."""
|
||||
node.enable_smoothing = False
|
||||
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.twist.twist.linear.x = 2.0
|
||||
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# When disabled, output should equal input directly
|
||||
|
||||
def test_no_odom_doesnt_crash(self, node):
|
||||
"""Test that timer callback handles missing odometry gracefully."""
|
||||
# Call timer without setting odometry
|
||||
node._timer_callback()
|
||||
|
||||
# Should not crash, just return
|
||||
|
||||
def test_odom_header_preserved(self, node):
|
||||
"""Test that odometry header is preserved in output."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "test_frame"
|
||||
odom.header.stamp = node.get_clock().now()
|
||||
odom.child_frame_id = "test_child"
|
||||
|
||||
node._on_odom(odom)
|
||||
|
||||
# Timer callback processes it
|
||||
node._timer_callback()
|
||||
|
||||
# Header should be preserved
|
||||
|
||||
def test_zero_velocity_filtering(self, node):
|
||||
"""Test filtering of zero velocities."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.twist.twist.linear.x = 0.0
|
||||
odom.twist.twist.linear.y = 0.0
|
||||
odom.twist.twist.linear.z = 0.0
|
||||
odom.twist.twist.angular.x = 0.0
|
||||
odom.twist.twist.angular.y = 0.0
|
||||
odom.twist.twist.angular.z = 0.0
|
||||
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Filters should handle zero input
|
||||
|
||||
def test_negative_velocities(self, node):
|
||||
"""Test filtering of negative velocities."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.twist.twist.linear.x = -1.0
|
||||
odom.twist.twist.angular.z = -0.5
|
||||
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
assert node.filter_linear_x.x_history[0] == -1.0
|
||||
|
||||
def test_high_frequency_noise_dampening(self, node):
|
||||
"""Test that filter dampens high-frequency encoder noise."""
|
||||
# Simulate noisy encoder output
|
||||
base_velocity = 1.0
|
||||
noise_amplitude = 0.2
|
||||
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
|
||||
# Apply alternating noise
|
||||
for i in range(100):
|
||||
odom.twist.twist.linear.x = base_velocity + (noise_amplitude if i % 2 == 0 else -noise_amplitude)
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# After filtering, output should be close to base velocity
|
||||
# (oscillations dampened)
|
||||
|
||||
def test_large_velocity_values(self, node):
|
||||
"""Test filtering of large velocity values."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.twist.twist.linear.x = 10.0
|
||||
odom.twist.twist.angular.z = 5.0
|
||||
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Should handle large values without overflow
|
||||
|
||||
def test_pose_unchanged(self, node):
|
||||
"""Test that pose is not modified by velocity filtering."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.pose.pose.position.x = 5.0
|
||||
odom.pose.pose.position.y = 3.0
|
||||
odom.twist.twist.linear.x = 1.0
|
||||
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Pose should be copied unchanged
|
||||
|
||||
def test_multiple_velocity_updates(self, node):
|
||||
"""Test filtering across multiple sequential velocity updates."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
|
||||
velocities = [0.5, 1.0, 1.5, 1.0, 0.5, 0.0]
|
||||
|
||||
for v in velocities:
|
||||
odom.twist.twist.linear.x = v
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Filter should smooth the velocity sequence
|
||||
|
||||
def test_simultaneous_all_velocities(self, node):
|
||||
"""Test filtering when all velocity components are present."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
|
||||
for i in range(30):
|
||||
odom.twist.twist.linear.x = math.sin(i * 0.1)
|
||||
odom.twist.twist.linear.y = math.cos(i * 0.1) * 0.5
|
||||
odom.twist.twist.linear.z = 0.1
|
||||
odom.twist.twist.angular.x = 0.05
|
||||
odom.twist.twist.angular.y = 0.05
|
||||
odom.twist.twist.angular.z = math.sin(i * 0.15)
|
||||
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# All filters should operate independently
|
||||
|
||||
|
||||
class TestVelocitySmootherScenarios:
|
||||
"""Integration-style tests for realistic scenarios."""
|
||||
|
||||
def test_scenario_constant_velocity(self, node):
|
||||
"""Scenario: robot moving at constant velocity."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.twist.twist.linear.x = 1.0
|
||||
|
||||
for i in range(50):
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Should maintain constant velocity after convergence
|
||||
|
||||
def test_scenario_velocity_ramp(self, node):
|
||||
"""Scenario: velocity ramping up from stop."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
|
||||
for i in range(50):
|
||||
odom.twist.twist.linear.x = i * 0.02 # Ramp from 0 to 1.0
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Filter should smooth the ramp
|
||||
|
||||
def test_scenario_velocity_step(self, node):
|
||||
"""Scenario: sudden velocity change (e.g., collision avoidance)."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
|
||||
# First phase: constant velocity
|
||||
odom.twist.twist.linear.x = 1.0
|
||||
for i in range(25):
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Second phase: sudden stop
|
||||
odom.twist.twist.linear.x = 0.0
|
||||
for i in range(25):
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Filter should transition smoothly
|
||||
|
||||
def test_scenario_rotation_only(self, node):
|
||||
"""Scenario: robot spinning in place."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
odom.twist.twist.linear.x = 0.0
|
||||
odom.twist.twist.angular.z = 0.5
|
||||
|
||||
for i in range(50):
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Angular velocity should be filtered
|
||||
|
||||
def test_scenario_mixed_motion(self, node):
|
||||
"""Scenario: combined linear and angular motion."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
|
||||
for i in range(50):
|
||||
odom.twist.twist.linear.x = math.cos(i * 0.1)
|
||||
odom.twist.twist.linear.y = math.sin(i * 0.1)
|
||||
odom.twist.twist.angular.z = 0.2
|
||||
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Both linear and angular components should be filtered
|
||||
|
||||
def test_scenario_encoder_noise_reduction(self, node):
|
||||
"""Scenario: realistic encoder jitter with filtering."""
|
||||
odom = Odometry()
|
||||
odom.header.frame_id = "odom"
|
||||
|
||||
# Simulate encoder jitter: base velocity + small noise
|
||||
base_vel = 1.0
|
||||
for i in range(100):
|
||||
jitter = 0.05 * math.sin(i * 0.5) + 0.03 * math.cos(i * 0.3)
|
||||
odom.twist.twist.linear.x = base_vel + jitter
|
||||
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
|
||||
# Filter should reduce noise while maintaining base velocity
|
||||
Loading…
x
Reference in New Issue
Block a user