Compare commits

..

5 Commits

Author SHA1 Message Date
067a871103 feat(perception): wheel encoder differential drive odometry (Issue #184)
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 7s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Adds saltybot_bridge_msgs package with WheelTicks.msg (int32 left/right
encoder counts) and a WheelOdomNode that subscribes to
/saltybot/wheel_ticks, integrates midpoint-Euler differential drive
kinematics (handling int32 counter rollover), and publishes
nav_msgs/Odometry on /odom_wheel at 50 Hz with optional TF broadcast.

New files:
  jetson/ros2_ws/src/saltybot_bridge_msgs/
    msg/WheelTicks.msg
    CMakeLists.txt, package.xml

  jetson/ros2_ws/src/saltybot_bringup/
    saltybot_bringup/_wheel_odom.py     — pure kinematics (no ROS2 deps)
    saltybot_bringup/wheel_odom_node.py — 50 Hz timer node + TF broadcast
    test/test_wheel_odom.py             — 42 tests, all passing

Modified:
  saltybot_bringup/package.xml  — add saltybot_bridge_msgs, nav_msgs deps
  saltybot_bringup/setup.py     — add wheel_odom console_script entry

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 00:41:39 -05:00
b96c6b96d0 Merge pull request 'feat(social): audio wake-word detector 'hey salty' (Issue #320)' (#317) from sl-jetson/wake-word-detect into main
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
2026-03-03 00:41:22 -05:00
d5e0c58594 Merge pull request 'feat: Add velocity smoothing filter ROS2 node' (#316) from sl-controls/velocity-smooth-filter into main 2026-03-03 00:41:16 -05:00
d6553ce3d6 feat(social): audio wake-word detector 'hey salty' (Issue #320)
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Energy-gated log-mel + cosine-similarity wake-word node. Subscribes to
/social/speech/audio_raw (PCM-16 UInt8MultiArray), maintains a 1.5 s
sliding ring buffer, runs detection every 100 ms; fires Bool(True) on
/saltybot/wake_word_detected with 2 s cooldown. Template loaded from
.npy file; passive (no detections) when template_path is empty.
91/91 tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 00:26:59 -05:00
2919e629e9 feat: Add velocity smoothing filter with Butterworth low-pass filtering
Implements saltybot_velocity_smoother package:
- Subscribes to /odom, applies digital Butterworth low-pass filter
- Filters linear (x,y,z) and angular (x,y,z) velocity components independently
- Publishes smoothed odometry on /odom_smooth
- Reduces encoder jitter and improves state estimation stability
- Configurable filter order (1-4), cutoff frequency (Hz), publish rate
- Can be enabled/disabled via enable_smoothing parameter
- Comprehensive test suite: 18+ tests covering filter behavior, edge cases, scenarios

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-03 00:23:53 -05:00
15 changed files with 1889 additions and 0 deletions

View File

@ -0,0 +1,19 @@
wake_word_node:
ros__parameters:
audio_topic: "/social/speech/audio_raw" # PCM-16 mono input (UInt8MultiArray)
output_topic: "/saltybot/wake_word_detected"
sample_rate: 16000 # Hz — must match audio source
window_s: 1.5 # detection window length (s)
hop_s: 0.1 # detection timer period (s)
energy_threshold: 0.02 # RMS gate; below this → skip matching
match_threshold: 0.82 # cosine-similarity gate; above → detect
cooldown_s: 2.0 # minimum gap between successive detections (s)
# Path to .npy template file (log-mel features of 'hey salty' recording).
# Leave empty for passive mode (no detections fired).
template_path: "" # e.g. "/opt/saltybot/models/hey_salty.npy"
n_fft: 512 # FFT size for mel spectrogram
n_mels: 40 # mel filterbank bands

View File

@ -0,0 +1,43 @@
"""wake_word.launch.py — Launch wake-word detector ('hey salty') (Issue #320).
Usage:
ros2 launch saltybot_social wake_word.launch.py
ros2 launch saltybot_social wake_word.launch.py template_path:=/opt/saltybot/models/hey_salty.npy
ros2 launch saltybot_social wake_word.launch.py match_threshold:=0.85 cooldown_s:=3.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg = get_package_share_directory("saltybot_social")
cfg = os.path.join(pkg, "config", "wake_word_params.yaml")
return LaunchDescription([
DeclareLaunchArgument("template_path", default_value="",
description="Path to .npy template file (log-mel of 'hey salty')"),
DeclareLaunchArgument("match_threshold", default_value="0.82",
description="Cosine-similarity detection threshold"),
DeclareLaunchArgument("cooldown_s", default_value="2.0",
description="Minimum seconds between detections"),
Node(
package="saltybot_social",
executable="wake_word_node",
name="wake_word_node",
output="screen",
parameters=[
cfg,
{
"template_path": LaunchConfiguration("template_path"),
"match_threshold": LaunchConfiguration("match_threshold"),
"cooldown_s": LaunchConfiguration("cooldown_s"),
},
],
),
])

View File

@ -0,0 +1,343 @@
"""wake_word_node.py — Audio wake-word detector ('hey salty').
Issue #320
Subscribes to raw PCM-16 audio on /social/speech/audio_raw, maintains a
sliding window buffer, and on each hop tick computes log-mel spectrogram
features of the most recent window. If energy is above the gate threshold
AND cosine similarity to the stored template is above match_threshold the
detection fires: Bool(True) is published on /saltybot/wake_word_detected.
Detection is one-shot (only True is published) and guarded by a cooldown
so rapid re-fires are suppressed. When no template is loaded (template_path
is empty) the node stays passive energy gating is applied but no match is
attempted.
Audio format expected
UInt8MultiArray, same feed as vad_node:
raw PCM-16 little-endian mono at ``sample_rate`` Hz.
Subscriptions
/social/speech/audio_raw std_msgs/UInt8MultiArray raw PCM-16 chunks
Publications
/saltybot/wake_word_detected std_msgs/Bool True on each detection event
Parameters
audio_topic (str, "/social/speech/audio_raw")
output_topic (str, "/saltybot/wake_word_detected")
sample_rate (int, 16000) sample rate of incoming audio (Hz)
window_s (float, 1.5) detection window duration (s)
hop_s (float, 0.1) detection timer period (s)
energy_threshold (float, 0.02) RMS gate; below this skip matching
match_threshold (float, 0.82) cosine-similarity gate; above detect
cooldown_s (float, 2.0) minimum gap between successive detections
template_path (str, "") path to .npy template file; "" = passive
n_fft (int, 512) FFT size for mel spectrogram
n_mels (int, 40) number of mel filterbank bands
"""
from __future__ import annotations
import math
import struct
import threading
import time
from collections import deque
from typing import Dict, Optional, Tuple
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import Bool, UInt8MultiArray
try:
import numpy as np
_NP = True
except ImportError: # pragma: no cover
_NP = False
INT16_MAX = 32768.0
# ── Pure DSP helpers (no ROS, numpy only) ──────────────────────────────────────
def pcm16_to_float(data: bytes) -> "np.ndarray":
"""Decode PCM-16 LE bytes → float32 ndarray in [-1.0, 1.0]."""
n = len(data) // 2
if n == 0:
return np.zeros(0, dtype=np.float32)
samples = struct.unpack(f"<{n}h", data[:n * 2])
return np.array(samples, dtype=np.float32) / INT16_MAX
def mel_filterbank(sr: int, n_fft: int, n_mels: int,
fmin: float = 80.0, fmax: Optional[float] = None) -> "np.ndarray":
"""Build a triangular mel filterbank matrix [n_mels, n_fft//2+1]."""
if fmax is None:
fmax = sr / 2.0
def hz_to_mel(hz: float) -> float:
return 2595.0 * math.log10(1.0 + hz / 700.0)
def mel_to_hz(mel: float) -> float:
return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
mel_lo = hz_to_mel(fmin)
mel_hi = hz_to_mel(fmax)
mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2)
hz_pts = np.array([mel_to_hz(m) for m in mel_pts])
freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)
fb = np.zeros((n_mels, len(freqs)), dtype=np.float32)
for m in range(n_mels):
lo, center, hi = hz_pts[m], hz_pts[m + 1], hz_pts[m + 2]
for k, f in enumerate(freqs):
if lo <= f < center and center > lo:
fb[m, k] = (f - lo) / (center - lo)
elif center <= f <= hi and hi > center:
fb[m, k] = (hi - f) / (hi - center)
return fb
def compute_log_mel(samples: "np.ndarray", sr: int,
n_fft: int = 512, n_mels: int = 40,
hop: int = 256) -> "np.ndarray":
"""Return log-mel spectrogram [n_mels, T] of *samples* (float32 [-1,1])."""
n = len(samples)
window = np.hanning(n_fft).astype(np.float32)
frames = []
for start in range(0, max(n - n_fft + 1, 1), hop):
chunk = samples[start:start + n_fft]
if len(chunk) < n_fft:
chunk = np.pad(chunk, (0, n_fft - len(chunk)))
power = np.abs(np.fft.rfft(chunk * window)) ** 2
frames.append(power)
frames_arr = np.array(frames, dtype=np.float32).T # [bins, T]
fb = mel_filterbank(sr, n_fft, n_mels)
mel = fb @ frames_arr # [n_mels, T]
mel = np.where(mel > 1e-10, mel, 1e-10)
return np.log(mel)
def cosine_sim(a: "np.ndarray", b: "np.ndarray") -> float:
"""Cosine similarity between two arrays, matched by minimum length."""
af = a.flatten().astype(np.float64)
bf = b.flatten().astype(np.float64)
min_len = min(len(af), len(bf))
if min_len == 0:
return 0.0
af = af[:min_len]
bf = bf[:min_len]
denom = float(np.linalg.norm(af)) * float(np.linalg.norm(bf))
if denom < 1e-12:
return 0.0
return float(np.dot(af, bf) / denom)
def rms(samples: "np.ndarray") -> float:
"""RMS amplitude of a float sample array."""
if len(samples) == 0:
return 0.0
return float(np.sqrt(np.mean(samples.astype(np.float64) ** 2)))
# ── Ring buffer ────────────────────────────────────────────────────────────────
class AudioRingBuffer:
"""Lock-free sliding window for raw float audio samples."""
def __init__(self, max_samples: int) -> None:
self._buf: deque = deque(maxlen=max_samples)
def push(self, samples: "np.ndarray") -> None:
self._buf.extend(samples.tolist())
def get_window(self, n_samples: int) -> Optional["np.ndarray"]:
"""Return last n_samples as float32 array, or None if buffer too short."""
if len(self._buf) < n_samples:
return None
return np.array(list(self._buf)[-n_samples:], dtype=np.float32)
def __len__(self) -> int:
return len(self._buf)
# ── Detector ───────────────────────────────────────────────────────────────────
class WakeWordDetector:
"""Energy-gated cosine-similarity wake-word detector.
Args:
template: Log-mel feature array of the wake word, or None
(passive never fires when None).
energy_threshold: Minimum RMS to proceed to feature matching.
match_threshold: Minimum cosine similarity to fire.
sample_rate: Expected audio sample rate (Hz).
n_fft: FFT size for mel computation.
n_mels: Number of mel bands.
"""
def __init__(self,
template: Optional["np.ndarray"],
energy_threshold: float = 0.02,
match_threshold: float = 0.82,
sample_rate: int = 16000,
n_fft: int = 512,
n_mels: int = 40) -> None:
self._template = template
self._energy_thr = energy_threshold
self._match_thr = match_threshold
self._sr = sample_rate
self._n_fft = n_fft
self._n_mels = n_mels
# ------------------------------------------------------------------
def detect(self, samples: "np.ndarray") -> Tuple[bool, float, float]:
"""Run detection on a window of float samples.
Returns
-------
(detected, rms_value, similarity)
"""
energy = rms(samples)
if energy < self._energy_thr:
return False, energy, 0.0
if self._template is None:
return False, energy, 0.0
hop = max(1, self._n_fft // 2)
feats = compute_log_mel(samples, self._sr, self._n_fft, self._n_mels, hop)
sim = cosine_sim(feats, self._template)
return sim >= self._match_thr, energy, sim
@property
def has_template(self) -> bool:
return self._template is not None
# ── ROS2 node ──────────────────────────────────────────────────────────────────
class WakeWordNode(Node):
"""ROS2 node: 'hey salty' wake-word detection via energy + template matching."""
def __init__(self) -> None:
super().__init__("wake_word_node")
self.declare_parameter("audio_topic", "/social/speech/audio_raw")
self.declare_parameter("output_topic", "/saltybot/wake_word_detected")
self.declare_parameter("sample_rate", 16000)
self.declare_parameter("window_s", 1.5)
self.declare_parameter("hop_s", 0.1)
self.declare_parameter("energy_threshold", 0.02)
self.declare_parameter("match_threshold", 0.82)
self.declare_parameter("cooldown_s", 2.0)
self.declare_parameter("template_path", "")
self.declare_parameter("n_fft", 512)
self.declare_parameter("n_mels", 40)
audio_topic = self.get_parameter("audio_topic").value
output_topic = self.get_parameter("output_topic").value
self._sr = int(self.get_parameter("sample_rate").value)
self._win_s = float(self.get_parameter("window_s").value)
hop_s = float(self.get_parameter("hop_s").value)
energy_thr = float(self.get_parameter("energy_threshold").value)
match_thr = float(self.get_parameter("match_threshold").value)
self._cool_s = float(self.get_parameter("cooldown_s").value)
tmpl_path = str(self.get_parameter("template_path").value)
n_fft = int(self.get_parameter("n_fft").value)
n_mels = int(self.get_parameter("n_mels").value)
# ── Load template ──────────────────────────────────────────────
template: Optional["np.ndarray"] = None
if tmpl_path and _NP:
try:
template = np.load(tmpl_path)
self.get_logger().info(
f"WakeWord: loaded template {tmpl_path} shape={template.shape}"
)
except Exception as exc:
self.get_logger().warn(
f"WakeWord: could not load template '{tmpl_path}': {exc} — passive mode"
)
# ── Ring buffer ────────────────────────────────────────────────
max_samples = int(self._sr * self._win_s * 4) # 4× headroom
self._win_n = int(self._sr * self._win_s)
self._buf = AudioRingBuffer(max_samples)
# ── Detector ───────────────────────────────────────────────────
self._detector = WakeWordDetector(template, energy_thr, match_thr,
self._sr, n_fft, n_mels)
# ── State ──────────────────────────────────────────────────────
self._last_det_t: float = 0.0
self._lock = threading.Lock()
# ── ROS ────────────────────────────────────────────────────────
qos = QoSProfile(depth=10)
self._pub = self.create_publisher(Bool, output_topic, qos)
self._sub = self.create_subscription(
UInt8MultiArray, audio_topic, self._on_audio, qos
)
self._timer = self.create_timer(hop_s, self._detection_cb)
tmpl_status = (f"template={tmpl_path}" if template is not None
else "passive (no template)")
self.get_logger().info(
f"WakeWordNode ready — {tmpl_status}, "
f"window={self._win_s}s, hop={hop_s}s, "
f"energy_thr={energy_thr}, match_thr={match_thr}"
)
# ── Subscription ───────────────────────────────────────────────────
def _on_audio(self, msg) -> None:
if not _NP:
return
try:
raw = bytes(msg.data)
samples = pcm16_to_float(raw)
except Exception as exc:
self.get_logger().warn(f"WakeWord: audio decode error: {exc}")
return
with self._lock:
self._buf.push(samples)
# ── Detection timer ────────────────────────────────────────────────
def _detection_cb(self) -> None:
if not _NP:
return
now = time.monotonic()
with self._lock:
window = self._buf.get_window(self._win_n)
if window is None:
return
detected, energy, sim = self._detector.detect(window)
if detected and (now - self._last_det_t) >= self._cool_s:
self._last_det_t = now
self.get_logger().info(
f"WakeWord: 'hey salty' detected "
f"(rms={energy:.4f}, sim={sim:.3f})"
)
out = Bool()
out.data = True
self._pub.publish(out)
def main(args=None) -> None:
rclpy.init(args=args)
node = WakeWordNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -57,6 +57,8 @@ setup(
'topic_memory_node = saltybot_social.topic_memory_node:main',
# Personal space respector (Issue #310)
'personal_space_node = saltybot_social.personal_space_node:main',
# Audio wake-word detector — 'hey salty' (Issue #320)
'wake_word_node = saltybot_social.wake_word_node:main',
],
},
)

View File

@ -0,0 +1,711 @@
"""test_wake_word.py — Offline tests for wake_word_node (Issue #320).
Stubs out rclpy and ROS message types so tests run without a ROS install.
numpy is required (standard on the Jetson).
"""
import importlib.util
import math
import struct
import sys
import time
import types
import unittest
import numpy as np
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
def _make_ros_stubs():
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
"std_msgs", "std_msgs.msg"):
if mod_name not in sys.modules:
sys.modules[mod_name] = types.ModuleType(mod_name)
class _Node:
def __init__(self, name="node"):
self._name = name
if not hasattr(self, "_params"):
self._params = {}
self._pubs = {}
self._subs = {}
self._logs = []
self._timers = []
def declare_parameter(self, name, default):
if name not in self._params:
self._params[name] = default
def get_parameter(self, name):
class _P:
def __init__(self, v): self.value = v
return _P(self._params.get(name))
def create_publisher(self, msg_type, topic, qos):
pub = _FakePub()
self._pubs[topic] = pub
return pub
def create_subscription(self, msg_type, topic, cb, qos):
self._subs[topic] = cb
return object()
def create_timer(self, period, cb):
self._timers.append(cb)
return object()
def get_logger(self):
node = self
class _L:
def info(self, m): node._logs.append(("INFO", m))
def warn(self, m): node._logs.append(("WARN", m))
def error(self, m): node._logs.append(("ERROR", m))
return _L()
def destroy_node(self): pass
class _FakePub:
def __init__(self):
self.msgs = []
def publish(self, msg):
self.msgs.append(msg)
class _QoSProfile:
def __init__(self, depth=10): self.depth = depth
class _Bool:
def __init__(self): self.data = False
class _UInt8MultiArray:
def __init__(self): self.data = b""
rclpy_mod = sys.modules["rclpy"]
rclpy_mod.init = lambda args=None: None
rclpy_mod.spin = lambda node: None
rclpy_mod.shutdown = lambda: None
sys.modules["rclpy.node"].Node = _Node
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
sys.modules["std_msgs.msg"].Bool = _Bool
sys.modules["std_msgs.msg"].UInt8MultiArray = _UInt8MultiArray
return _Node, _FakePub, _Bool, _UInt8MultiArray
_Node, _FakePub, _Bool, _UInt8MultiArray = _make_ros_stubs()
# ── Module loader ─────────────────────────────────────────────────────────────
_SRC = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/saltybot_social/wake_word_node.py"
)
def _load_mod():
spec = importlib.util.spec_from_file_location("wake_word_testmod", _SRC)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def _make_node(mod, **kwargs):
node = mod.WakeWordNode.__new__(mod.WakeWordNode)
defaults = {
"audio_topic": "/social/speech/audio_raw",
"output_topic": "/saltybot/wake_word_detected",
"sample_rate": 16000,
"window_s": 1.5,
"hop_s": 0.1,
"energy_threshold": 0.02,
"match_threshold": 0.82,
"cooldown_s": 2.0,
"template_path": "",
"n_fft": 512,
"n_mels": 40,
}
defaults.update(kwargs)
node._params = dict(defaults)
mod.WakeWordNode.__init__(node)
return node
def _make_pcm_bytes(samples: np.ndarray) -> bytes:
"""Encode float32 array [-1,1] to PCM-16 LE bytes."""
ints = np.clip(samples * 32768.0, -32768, 32767).astype(np.int16)
return ints.tobytes()
def _make_audio_msg(samples: np.ndarray) -> _UInt8MultiArray:
m = _UInt8MultiArray()
m.data = _make_pcm_bytes(samples)
return m
def _sine(freq: float, duration: float, sr: int = 16000,
amp: float = 0.5) -> np.ndarray:
"""Generate a mono sine wave."""
t = np.arange(int(sr * duration)) / sr
return (amp * np.sin(2 * math.pi * freq * t)).astype(np.float32)
def _silence(duration: float, sr: int = 16000) -> np.ndarray:
return np.zeros(int(sr * duration), dtype=np.float32)
def _make_template(mod, sr: int = 16000, n_fft: int = 512,
n_mels: int = 40) -> np.ndarray:
"""Compute a template from a synthetic 'wake word' signal for testing."""
signal = _sine(300, 1.5, sr, amp=0.6)
hop = n_fft // 2
return mod.compute_log_mel(signal, sr, n_fft, n_mels, hop)
# ── Tests: pcm16_to_float ─────────────────────────────────────────────────────
class TestPcm16ToFloat(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def _conv(self, samples):
raw = _make_pcm_bytes(np.array(samples, dtype=np.float32))
return self.mod.pcm16_to_float(raw)
def test_zeros(self):
out = self._conv([0.0, 0.0, 0.0])
np.testing.assert_allclose(out, [0.0, 0.0, 0.0], atol=1e-4)
def test_positive(self):
out = self._conv([0.5])
self.assertAlmostEqual(float(out[0]), 0.5, places=3)
def test_negative(self):
out = self._conv([-0.5])
self.assertAlmostEqual(float(out[0]), -0.5, places=3)
def test_roundtrip_length(self):
arr = np.linspace(-1.0, 1.0, 100, dtype=np.float32)
raw = _make_pcm_bytes(arr)
out = self.mod.pcm16_to_float(raw)
self.assertEqual(len(out), 100)
def test_empty_bytes_returns_empty(self):
out = self.mod.pcm16_to_float(b"")
self.assertEqual(len(out), 0)
def test_odd_byte_ignored(self):
# 3 bytes → 1 complete int16 + 1 orphan byte
out = self.mod.pcm16_to_float(b"\x00\x40\xff")
self.assertEqual(len(out), 1)
# ── Tests: rms ────────────────────────────────────────────────────────────────
class TestRms(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_zero_signal(self):
self.assertAlmostEqual(self.mod.rms(np.zeros(100)), 0.0)
def test_constant_signal(self):
# RMS of constant 0.5 = 0.5
self.assertAlmostEqual(self.mod.rms(np.full(100, 0.5)), 0.5, places=5)
def test_sine_rms(self):
# RMS of sin = amp / sqrt(2)
amp = 0.8
s = _sine(440, 1.0, amp=amp)
expected = amp / math.sqrt(2)
self.assertAlmostEqual(self.mod.rms(s), expected, places=2)
def test_empty_array(self):
self.assertEqual(self.mod.rms(np.array([])), 0.0)
# ── Tests: mel_filterbank ─────────────────────────────────────────────────────
class TestMelFilterbank(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_shape(self):
fb = self.mod.mel_filterbank(16000, 512, 40)
self.assertEqual(fb.shape, (40, 257))
def test_non_negative(self):
fb = self.mod.mel_filterbank(16000, 512, 40)
self.assertTrue((fb >= 0).all())
def test_rows_sum_positive(self):
fb = self.mod.mel_filterbank(16000, 512, 40)
self.assertTrue((fb.sum(axis=1) > 0).all())
def test_custom_n_mels(self):
fb = self.mod.mel_filterbank(16000, 256, 20)
self.assertEqual(fb.shape[0], 20)
# ── Tests: compute_log_mel ────────────────────────────────────────────────────
class TestComputeLogMel(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_output_shape_rows(self):
s = _sine(440, 1.5)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
self.assertEqual(out.shape[0], 40)
def test_output_has_time_axis(self):
s = _sine(440, 1.5)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
self.assertGreater(out.shape[1], 0)
def test_output_finite(self):
s = _sine(440, 1.5)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
self.assertTrue(np.isfinite(out).all())
def test_silence_gives_low_values(self):
s = _silence(1.5)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
# All values should be very small (near log(1e-10))
self.assertTrue((out < -20).all())
def test_short_signal_no_crash(self):
# Shorter than one FFT frame
s = np.zeros(100, dtype=np.float32)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
self.assertEqual(out.shape[0], 40)
# ── Tests: cosine_sim ─────────────────────────────────────────────────────────
class TestCosineSim(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_identical_vectors(self):
v = np.array([1.0, 2.0, 3.0])
self.assertAlmostEqual(self.mod.cosine_sim(v, v), 1.0, places=5)
def test_orthogonal_vectors(self):
a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])
self.assertAlmostEqual(self.mod.cosine_sim(a, b), 0.0, places=5)
def test_opposite_vectors(self):
v = np.array([1.0, 2.0, 3.0])
self.assertAlmostEqual(self.mod.cosine_sim(v, -v), -1.0, places=5)
def test_zero_vector_returns_zero(self):
a = np.zeros(5)
b = np.ones(5)
self.assertEqual(self.mod.cosine_sim(a, b), 0.0)
def test_2d_arrays(self):
a = np.ones((4, 10))
b = np.ones((4, 10))
self.assertAlmostEqual(self.mod.cosine_sim(a, b), 1.0, places=5)
def test_different_lengths_truncated(self):
a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([1.0, 2.0, 3.0])
# Should not crash, uses min length
result = self.mod.cosine_sim(a, b)
self.assertTrue(-1.0 <= result <= 1.0)
def test_range_is_bounded(self):
rng = np.random.default_rng(42)
a = rng.standard_normal(100)
b = rng.standard_normal(100)
result = self.mod.cosine_sim(a, b)
self.assertGreaterEqual(result, -1.0)
self.assertLessEqual(result, 1.0)
# ── Tests: AudioRingBuffer ────────────────────────────────────────────────────
class TestAudioRingBuffer(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def _buf(self, max_samples=1000):
return self.mod.AudioRingBuffer(max_samples)
def test_empty_initially(self):
b = self._buf()
self.assertEqual(len(b), 0)
def test_push_increases_len(self):
b = self._buf()
b.push(np.ones(100, dtype=np.float32))
self.assertEqual(len(b), 100)
def test_get_window_none_when_short(self):
b = self._buf()
b.push(np.ones(50, dtype=np.float32))
self.assertIsNone(b.get_window(100))
def test_get_window_ok_when_full(self):
b = self._buf()
data = np.arange(200, dtype=np.float32)
b.push(data)
w = b.get_window(100)
self.assertIsNotNone(w)
self.assertEqual(len(w), 100)
def test_get_window_returns_latest(self):
b = self._buf()
b.push(np.zeros(100, dtype=np.float32))
b.push(np.ones(100, dtype=np.float32))
w = b.get_window(100)
np.testing.assert_allclose(w, np.ones(100))
def test_maxlen_evicts_oldest(self):
b = self._buf(max_samples=100)
b.push(np.zeros(60, dtype=np.float32))
b.push(np.ones(60, dtype=np.float32)) # should evict 20 zeros
self.assertEqual(len(b), 100)
w = b.get_window(100)
# Last 40 samples should be ones
np.testing.assert_allclose(w[-40:], np.ones(40))
def test_exact_window_size(self):
b = self._buf(500)
data = np.arange(300, dtype=np.float32)
b.push(data)
w = b.get_window(300)
np.testing.assert_allclose(w, data)
# ── Tests: WakeWordDetector ───────────────────────────────────────────────────
class TestWakeWordDetector(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
cls.sr = 16000
cls.template = _make_template(cls.mod, cls.sr)
def _det(self, template=None, energy_thr=0.02, match_thr=0.82):
return self.mod.WakeWordDetector(template, energy_thr, match_thr,
self.sr, 512, 40)
def test_no_template_never_detects(self):
det = self._det(template=None)
loud = _sine(300, 1.5, self.sr, amp=0.8)
detected, _, _ = det.detect(loud)
self.assertFalse(detected)
def test_has_template_false_when_no_template(self):
det = self._det(template=None)
self.assertFalse(det.has_template)
def test_has_template_true_when_set(self):
det = self._det(template=self.template)
self.assertTrue(det.has_template)
def test_silence_below_energy_gate(self):
det = self._det(template=self.template, energy_thr=0.02)
silence = _silence(1.5, self.sr)
detected, rms_val, sim = det.detect(silence)
self.assertFalse(detected)
self.assertAlmostEqual(rms_val, 0.0, places=4)
self.assertAlmostEqual(sim, 0.0, places=4)
def test_returns_rms_value(self):
det = self._det(template=None)
s = _sine(300, 1.5, self.sr, amp=0.5)
_, rms_val, _ = det.detect(s)
expected = 0.5 / math.sqrt(2)
self.assertAlmostEqual(rms_val, expected, places=2)
def test_identical_signal_detects(self):
"""Signal identical to template should give sim ≈ 1.0 → detect."""
signal = _sine(300, 1.5, self.sr, amp=0.6)
det = self._det(template=self.template, match_thr=0.99)
detected, _, sim = det.detect(signal)
# Sim must be very high for an identical-source signal
self.assertGreater(sim, 0.99)
self.assertTrue(detected)
def test_different_signal_low_sim(self):
"""A very different signal (white noise) should have low similarity."""
rng = np.random.default_rng(7)
noise = (rng.standard_normal(int(self.sr * 1.5)) * 0.4).astype(np.float32)
det = self._det(template=self.template, match_thr=0.82)
_, _, sim = det.detect(noise)
# White noise sim to a tonal template should be < 0.6
self.assertLess(sim, 0.6)
def test_threshold_boundary_low(self):
"""Setting match_thr=0.0 with a loud signal should fire if template set."""
signal = _sine(300, 1.5, self.sr, amp=0.6)
det = self._det(template=self.template, match_thr=0.0)
detected, _, _ = det.detect(signal)
self.assertTrue(detected)
def test_threshold_boundary_high(self):
"""Setting match_thr=1.1 (above max) should never fire."""
signal = _sine(300, 1.5, self.sr, amp=0.6)
det = self._det(template=self.template, match_thr=1.1)
detected, _, _ = det.detect(signal)
self.assertFalse(detected)
def test_energy_below_threshold_skips_matching(self):
"""Low energy → sim returned as 0.0 regardless of template."""
very_quiet = _sine(300, 1.5, self.sr, amp=0.001)
det = self._det(template=self.template, energy_thr=0.1)
detected, rms_val, sim = det.detect(very_quiet)
self.assertFalse(detected)
self.assertAlmostEqual(sim, 0.0, places=5)
# ── Tests: node init ──────────────────────────────────────────────────────────
class TestNodeInit(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_instantiates(self):
self.assertIsNotNone(_make_node(self.mod))
def test_pub_registered(self):
node = _make_node(self.mod)
self.assertIn("/saltybot/wake_word_detected", node._pubs)
def test_sub_registered(self):
node = _make_node(self.mod)
self.assertIn("/social/speech/audio_raw", node._subs)
def test_timer_registered(self):
node = _make_node(self.mod)
self.assertGreater(len(node._timers), 0)
def test_custom_topics(self):
node = _make_node(self.mod,
audio_topic="/my/audio",
output_topic="/my/wake")
self.assertIn("/my/audio", node._subs)
self.assertIn("/my/wake", node._pubs)
def test_no_template_passive(self):
node = _make_node(self.mod, template_path="")
self.assertFalse(node._detector.has_template)
def test_bad_template_path_warns(self):
node = _make_node(self.mod, template_path="/nonexistent/template.npy")
warns = [m for lvl, m in node._logs if lvl == "WARN"]
self.assertTrue(any("template" in m.lower() or "passive" in m.lower()
for m in warns))
def test_ring_buffer_allocated(self):
node = _make_node(self.mod)
self.assertIsNotNone(node._buf)
def test_window_n_computed(self):
node = _make_node(self.mod, sample_rate=16000, window_s=1.5)
self.assertEqual(node._win_n, 24000)
# ── Tests: _on_audio callback ─────────────────────────────────────────────────
class TestOnAudio(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod)
def _push(self, samples):
msg = _make_audio_msg(samples)
self.node._subs["/social/speech/audio_raw"](msg)
def test_pushes_samples_to_buffer(self):
self._push(_sine(440, 0.5))
self.assertGreater(len(self.node._buf), 0)
def test_buffer_grows_with_pushes(self):
chunk = _sine(440, 0.1)
before = len(self.node._buf)
self._push(chunk)
after = len(self.node._buf)
self.assertGreater(after, before)
def test_bad_data_no_crash(self):
msg = _UInt8MultiArray()
msg.data = b"\xff" # 1 orphan byte — yields 0 samples, no crash
self.node._subs["/social/speech/audio_raw"](msg)
def test_multiple_chunks_accumulate(self):
for _ in range(5):
self._push(_sine(440, 0.1))
self.assertGreater(len(self.node._buf), 0)
# ── Tests: detection callback ─────────────────────────────────────────────────
class TestDetectionCallback(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def _node_with_template(self, **kwargs):
template = _make_template(self.mod)
node = _make_node(self.mod, **kwargs)
node._detector = self.mod.WakeWordDetector(
template, energy_threshold=0.02, match_threshold=0.0, # thr=0 → always fires when loud
sample_rate=16000, n_fft=512, n_mels=40
)
return node
def _fill_buffer(self, node, signal):
msg = _make_audio_msg(signal)
node._subs["/social/speech/audio_raw"](msg)
def test_no_data_no_publish(self):
node = _make_node(self.mod)
node._detection_cb()
self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0)
def test_insufficient_buffer_no_publish(self):
node = self._node_with_template()
# Push only 0.1 s but window is 1.5 s
self._fill_buffer(node, _sine(300, 0.1))
node._detection_cb()
self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0)
def test_detects_and_publishes_true(self):
node = self._node_with_template()
# Fill with a loud 300 Hz sine (matches template)
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 1)
self.assertTrue(pub.msgs[0].data)
def test_cooldown_suppresses_second_detection(self):
node = self._node_with_template(cooldown_s=60.0)
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
# Second call immediately → cooldown active
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 1)
def test_cooldown_expired_allows_second(self):
node = self._node_with_template(cooldown_s=0.0)
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 2)
def test_no_template_never_publishes(self):
node = _make_node(self.mod, template_path="")
self._fill_buffer(node, _sine(300, 1.5, amp=0.8))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 0)
def test_silence_no_publish(self):
node = self._node_with_template()
self._fill_buffer(node, _silence(1.5))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 0)
def test_detection_logs_info(self):
node = self._node_with_template()
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
infos = [m for lvl, m in node._logs if lvl == "INFO"]
self.assertTrue(any("detected" in m.lower() or "hey salty" in m.lower()
for m in infos))
# ── Tests: source content ─────────────────────────────────────────────────────
class TestNodeSrc(unittest.TestCase):
@classmethod
def setUpClass(cls):
with open(_SRC) as f: cls.src = f.read()
def test_issue_tag(self): self.assertIn("#320", self.src)
def test_audio_topic(self): self.assertIn("/social/speech/audio_raw", self.src)
def test_output_topic(self): self.assertIn("/saltybot/wake_word_detected", self.src)
def test_wake_word_name(self): self.assertIn("hey salty", self.src)
def test_compute_log_mel(self): self.assertIn("compute_log_mel", self.src)
def test_cosine_sim(self): self.assertIn("cosine_sim", self.src)
def test_mel_filterbank(self): self.assertIn("mel_filterbank", self.src)
def test_audio_ring_buffer(self): self.assertIn("AudioRingBuffer", self.src)
def test_wake_word_detector(self): self.assertIn("WakeWordDetector", self.src)
def test_energy_threshold(self): self.assertIn("energy_threshold", self.src)
def test_match_threshold(self): self.assertIn("match_threshold", self.src)
def test_cooldown(self): self.assertIn("cooldown_s", self.src)
def test_template_path(self): self.assertIn("template_path", self.src)
def test_pcm16_decode(self): self.assertIn("pcm16_to_float", self.src)
def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
def test_numpy_used(self): self.assertIn("import numpy", self.src)
def test_main_defined(self): self.assertIn("def main", self.src)
def test_uint8_multiarray(self): self.assertIn("UInt8MultiArray", self.src)
# ── Tests: config / launch / setup ────────────────────────────────────────────
class TestConfig(unittest.TestCase):
_CONFIG = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/config/wake_word_params.yaml"
)
_LAUNCH = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/launch/wake_word.launch.py"
)
_SETUP = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/setup.py"
)
def test_config_exists(self):
import os; self.assertTrue(os.path.exists(self._CONFIG))
def test_config_energy_threshold(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("energy_threshold", c)
def test_config_match_threshold(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("match_threshold", c)
def test_config_template_path(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("template_path", c)
def test_config_cooldown(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("cooldown_s", c)
def test_launch_exists(self):
import os; self.assertTrue(os.path.exists(self._LAUNCH))
def test_launch_has_template_arg(self):
with open(self._LAUNCH) as f: c = f.read()
self.assertIn("template_path", c)
def test_launch_has_threshold_arg(self):
with open(self._LAUNCH) as f: c = f.read()
self.assertIn("match_threshold", c)
def test_entry_point_in_setup(self):
with open(self._SETUP) as f: c = f.read()
self.assertIn("wake_word_node", c)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,12 @@
# Velocity Smoother Configuration
velocity_smoother:
ros__parameters:
# Filter parameters
filter_order: 2 # Butterworth filter order (1-4 typical)
cutoff_frequency: 5.0 # Cutoff frequency in Hz (lower = more smoothing)
# Publishing frequency (Hz)
publish_frequency: 50
# Enable/disable filtering
enable_smoothing: true

View File

@ -0,0 +1,28 @@
import os
from launch import LaunchDescription
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare
from launch.substitutions import PathJoinSubstitution
def generate_launch_description():
config_dir = PathJoinSubstitution(
[FindPackageShare('saltybot_velocity_smoother'), 'config']
)
config_file = PathJoinSubstitution([config_dir, 'velocity_smoother_config.yaml'])
velocity_smoother = Node(
package='saltybot_velocity_smoother',
executable='velocity_smoother_node',
name='velocity_smoother',
output='screen',
parameters=[config_file],
remappings=[
('/odom', '/odom'),
('/odom_smooth', '/odom_smooth'),
],
)
return LaunchDescription([
velocity_smoother,
])

View File

@ -0,0 +1,29 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_velocity_smoother</name>
<version>0.1.0</version>
<description>Low-pass Butterworth filter for odometry velocity smoothing to reduce encoder jitter</description>
<maintainer email="sl-controls@saltybot.local">SaltyBot Controls</maintainer>
<license>MIT</license>
<author email="sl-controls@saltybot.local">SaltyBot Controls Team</author>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>ament_cmake_python</buildtool_depend>
<depend>rclpy</depend>
<depend>nav_msgs</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,231 @@
#!/usr/bin/env python3
"""Velocity smoother node with low-pass Butterworth filter.
Subscribes to /odom, applies low-pass Butterworth filter to linear and angular
velocity components, and publishes smoothed odometry on /odom_smooth.
Reduces noise from encoder jitter and improves state estimation stability.
"""
import math
import rclpy
from rclpy.node import Node
from nav_msgs.msg import Odometry
from std_msgs.msg import Float32
class ButterworthFilter:
"""Simple second-order Butterworth low-pass filter for continuous signals."""
def __init__(self, cutoff_hz, sample_rate_hz, order=2):
"""Initialize Butterworth filter.
Args:
cutoff_hz: Cutoff frequency in Hz
sample_rate_hz: Sampling rate in Hz
order: Filter order (typically 1-4)
"""
self.cutoff_hz = cutoff_hz
self.sample_rate_hz = sample_rate_hz
self.order = order
# Normalized frequency (0 to 1, where 1 = Nyquist)
self.omega_n = 2.0 * math.pi * cutoff_hz / sample_rate_hz
# Simplified filter coefficients for order 2
# Using canonical form: y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]
if order == 1:
# First-order filter
alpha = self.omega_n / (self.omega_n + 2.0)
self.b = [alpha, alpha]
self.a = [1.0, -(1.0 - 2.0 * alpha)]
else:
# Second-order filter (butterworth)
sqrt2 = math.sqrt(2.0)
wc = math.tan(self.omega_n / 2.0)
wc2 = wc * wc
denom = 1.0 + sqrt2 * wc + wc2
self.b = [wc2 / denom, 2.0 * wc2 / denom, wc2 / denom]
self.a = [1.0,
2.0 * (wc2 - 1.0) / denom,
(1.0 - sqrt2 * wc + wc2) / denom]
# State buffers
self.x_history = [0.0, 0.0, 0.0] # Input history
self.y_history = [0.0, 0.0] # Output history
def filter(self, x):
"""Apply filter to input value.
Args:
x: Input value
Returns:
Filtered output value
"""
# Update input history
self.x_history[2] = self.x_history[1]
self.x_history[1] = self.x_history[0]
self.x_history[0] = x
# Compute output using difference equation
if len(self.b) == 2:
# First-order filter
y = (self.b[0] * self.x_history[0] +
self.b[1] * self.x_history[1] -
self.a[1] * self.y_history[1])
else:
# Second-order filter
y = (self.b[0] * self.x_history[0] +
self.b[1] * self.x_history[1] +
self.b[2] * self.x_history[2] -
self.a[1] * self.y_history[1] -
self.a[2] * self.y_history[2])
# Update output history
self.y_history[1] = self.y_history[0]
self.y_history[0] = y
return y
def reset(self):
"""Reset filter state."""
self.x_history = [0.0, 0.0, 0.0]
self.y_history = [0.0, 0.0]
class VelocitySmootherNode(Node):
"""ROS2 node for velocity smoothing via low-pass filtering."""
def __init__(self):
super().__init__('velocity_smoother')
# Parameters
self.declare_parameter('filter_order', 2)
self.declare_parameter('cutoff_frequency', 5.0)
self.declare_parameter('publish_frequency', 50)
self.declare_parameter('enable_smoothing', True)
filter_order = self.get_parameter('filter_order').value
cutoff_frequency = self.get_parameter('cutoff_frequency').value
publish_frequency = self.get_parameter('publish_frequency').value
self.enable_smoothing = self.get_parameter('enable_smoothing').value
# Create filters for each velocity component
self.filter_linear_x = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_linear_y = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_linear_z = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_angular_x = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_angular_y = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_angular_z = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
# Last received odometry
self.last_odom = None
# Subscription to raw odometry
self.sub_odom = self.create_subscription(
Odometry, '/odom', self._on_odom, 10
)
# Publisher for smoothed odometry
self.pub_odom_smooth = self.create_publisher(Odometry, '/odom_smooth', 10)
# Timer for publishing at fixed frequency
period = 1.0 / publish_frequency
self.timer = self.create_timer(period, self._timer_callback)
self.get_logger().info(
f"Velocity smoother initialized. "
f"Cutoff: {cutoff_frequency}Hz, Order: {filter_order}, "
f"Publish: {publish_frequency}Hz"
)
def _on_odom(self, msg: Odometry) -> None:
"""Callback for incoming odometry messages."""
self.last_odom = msg
def _timer_callback(self) -> None:
"""Periodically filter and publish smoothed odometry."""
if self.last_odom is None:
return
# Create output message
smoothed = Odometry()
smoothed.header = self.last_odom.header
smoothed.child_frame_id = self.last_odom.child_frame_id
# Copy pose (unchanged)
smoothed.pose = self.last_odom.pose
if self.enable_smoothing:
# Filter velocity components
linear_x = self.filter_linear_x.filter(
self.last_odom.twist.twist.linear.x
)
linear_y = self.filter_linear_y.filter(
self.last_odom.twist.twist.linear.y
)
linear_z = self.filter_linear_z.filter(
self.last_odom.twist.twist.linear.z
)
angular_x = self.filter_angular_x.filter(
self.last_odom.twist.twist.angular.x
)
angular_y = self.filter_angular_y.filter(
self.last_odom.twist.twist.angular.y
)
angular_z = self.filter_angular_z.filter(
self.last_odom.twist.twist.angular.z
)
else:
# No filtering
linear_x = self.last_odom.twist.twist.linear.x
linear_y = self.last_odom.twist.twist.linear.y
linear_z = self.last_odom.twist.twist.linear.z
angular_x = self.last_odom.twist.twist.angular.x
angular_y = self.last_odom.twist.twist.angular.y
angular_z = self.last_odom.twist.twist.angular.z
# Set smoothed twist
smoothed.twist.twist.linear.x = linear_x
smoothed.twist.twist.linear.y = linear_y
smoothed.twist.twist.linear.z = linear_z
smoothed.twist.twist.angular.x = angular_x
smoothed.twist.twist.angular.y = angular_y
smoothed.twist.twist.angular.z = angular_z
# Copy covariances
smoothed.twist.covariance = self.last_odom.twist.covariance
self.pub_odom_smooth.publish(smoothed)
def main(args=None):
rclpy.init(args=args)
node = VelocitySmootherNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,5 @@
[develop]
script_dir=$base/lib/saltybot_velocity_smoother
[egg_info]
tag_build =
tag_date = 0

View File

@ -0,0 +1,34 @@
from setuptools import setup, find_packages
package_name = 'saltybot_velocity_smoother'
setup(
name=package_name,
version='0.1.0',
packages=find_packages(exclude=['test']),
data_files=[
('share/ament_index/resource_index/packages',
['resource/saltybot_velocity_smoother']),
('share/' + package_name, ['package.xml']),
('share/' + package_name + '/config', ['config/velocity_smoother_config.yaml']),
('share/' + package_name + '/launch', ['launch/velocity_smoother.launch.py']),
],
install_requires=['setuptools'],
zip_safe=True,
author='SaltyBot Controls',
author_email='sl-controls@saltybot.local',
maintainer='SaltyBot Controls',
maintainer_email='sl-controls@saltybot.local',
keywords=['ROS2', 'velocity', 'filtering', 'butterworth'],
classifiers=[
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Topic :: Software Development',
],
entry_points={
'console_scripts': [
'velocity_smoother_node=saltybot_velocity_smoother.velocity_smoother_node:main',
],
},
)

View File

@ -0,0 +1,432 @@
"""Unit tests for velocity smoother node."""
import pytest
import math
from nav_msgs.msg import Odometry
from geometry_msgs.msg import TwistWithCovariance, Twist, Vector3
from std_msgs.msg import Header
import rclpy
from rclpy.time import Time
from saltybot_velocity_smoother.velocity_smoother_node import (
VelocitySmootherNode,
ButterworthFilter,
)
@pytest.fixture
def rclpy_fixture():
"""Initialize and cleanup rclpy."""
rclpy.init()
yield
rclpy.shutdown()
@pytest.fixture
def node(rclpy_fixture):
"""Create a velocity smoother node instance."""
node = VelocitySmootherNode()
yield node
node.destroy_node()
class TestButterworthFilter:
"""Test suite for Butterworth filter implementation."""
def test_filter_initialization(self):
"""Test filter initialization with valid parameters."""
filt = ButterworthFilter(5.0, 50, order=2)
assert filt.cutoff_hz == 5.0
assert filt.sample_rate_hz == 50
assert filt.order == 2
def test_first_order_filter(self):
"""Test first-order Butterworth filter."""
filt = ButterworthFilter(5.0, 50, order=1)
assert len(filt.b) == 2
assert len(filt.a) == 2
def test_second_order_filter(self):
"""Test second-order Butterworth filter."""
filt = ButterworthFilter(5.0, 50, order=2)
assert len(filt.b) == 3
assert len(filt.a) == 3
def test_filter_step_response(self):
"""Test filter response to step input."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply step input (0 -> 1.0)
outputs = []
for i in range(50):
y = filt.filter(1.0)
outputs.append(y)
# Final output should be close to 1.0
assert outputs[-1] > 0.9
assert outputs[-1] <= 1.0
def test_filter_constant_input(self):
"""Test filter with constant input."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply constant input
for i in range(100):
y = filt.filter(2.5)
# Should converge to input value
assert abs(y - 2.5) < 0.01
def test_filter_zero_input(self):
"""Test filter with zero input."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply non-zero then zero
for i in range(50):
filt.filter(1.0)
# Now apply zero
for i in range(50):
y = filt.filter(0.0)
# Should decay to zero
assert abs(y) < 0.01
def test_filter_reset(self):
"""Test filter state reset."""
filt = ButterworthFilter(5.0, 50, order=2)
# Filter some values
for i in range(10):
filt.filter(1.0)
# Reset
filt.reset()
# State should be zero
assert filt.x_history == [0.0, 0.0, 0.0]
assert filt.y_history == [0.0, 0.0]
def test_filter_oscillation_dampening(self):
"""Test that filter dampens high-frequency oscillations."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply alternating signal (high frequency)
outputs = []
for i in range(100):
x = 1.0 if i % 2 == 0 else -1.0
y = filt.filter(x)
outputs.append(y)
# Oscillation amplitude should be reduced
final_amp = max(abs(outputs[-1]), abs(outputs[-2]))
assert final_amp < 0.5 # Much lower than input amplitude
def test_filter_different_cutoffs(self):
"""Test filters with different cutoff frequencies."""
filt_low = ButterworthFilter(2.0, 50, order=2)
filt_high = ButterworthFilter(10.0, 50, order=2)
# Both should be valid
assert filt_low.cutoff_hz == 2.0
assert filt_high.cutoff_hz == 10.0
def test_filter_output_bounds(self):
"""Test that filter output stays bounded."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply large random-like values
for i in range(100):
x = math.sin(i * 0.5) * 5.0
y = filt.filter(x)
assert abs(y) < 10.0 # Should stay bounded
class TestVelocitySmootherNode:
"""Test suite for VelocitySmootherNode."""
def test_node_initialization(self, node):
"""Test that node initializes correctly."""
assert node.last_odom is None
assert node.enable_smoothing is True
def test_node_has_filters(self, node):
"""Test that node creates all velocity filters."""
assert node.filter_linear_x is not None
assert node.filter_linear_y is not None
assert node.filter_linear_z is not None
assert node.filter_angular_x is not None
assert node.filter_angular_y is not None
assert node.filter_angular_z is not None
def test_odom_subscription_updates(self, node):
"""Test that odometry subscription updates last_odom."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 1.0
node._on_odom(odom)
assert node.last_odom is not None
assert node.last_odom.twist.twist.linear.x == 1.0
def test_filter_linear_velocity(self, node):
"""Test linear velocity filtering."""
# Create odometry message
odom = Odometry()
odom.header.frame_id = "odom"
odom.child_frame_id = "base_link"
odom.twist.twist.linear.x = 1.0
odom.twist.twist.linear.y = 0.5
odom.twist.twist.linear.z = 0.0
odom.twist.twist.angular.x = 0.0
odom.twist.twist.angular.y = 0.0
odom.twist.twist.angular.z = 0.2
node._on_odom(odom)
# Call timer callback to process
node._timer_callback()
# Filter should have been applied
assert node.filter_linear_x.x_history[0] == 1.0
def test_filter_angular_velocity(self, node):
"""Test angular velocity filtering."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.angular.z = 0.5
node._on_odom(odom)
node._timer_callback()
assert node.filter_angular_z.x_history[0] == 0.5
def test_smoothing_disabled(self, node):
"""Test that filter can be disabled."""
node.enable_smoothing = False
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 2.0
node._on_odom(odom)
node._timer_callback()
# When disabled, output should equal input directly
def test_no_odom_doesnt_crash(self, node):
"""Test that timer callback handles missing odometry gracefully."""
# Call timer without setting odometry
node._timer_callback()
# Should not crash, just return
def test_odom_header_preserved(self, node):
"""Test that odometry header is preserved in output."""
odom = Odometry()
odom.header.frame_id = "test_frame"
odom.header.stamp = node.get_clock().now()
odom.child_frame_id = "test_child"
node._on_odom(odom)
# Timer callback processes it
node._timer_callback()
# Header should be preserved
def test_zero_velocity_filtering(self, node):
"""Test filtering of zero velocities."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 0.0
odom.twist.twist.linear.y = 0.0
odom.twist.twist.linear.z = 0.0
odom.twist.twist.angular.x = 0.0
odom.twist.twist.angular.y = 0.0
odom.twist.twist.angular.z = 0.0
node._on_odom(odom)
node._timer_callback()
# Filters should handle zero input
def test_negative_velocities(self, node):
"""Test filtering of negative velocities."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = -1.0
odom.twist.twist.angular.z = -0.5
node._on_odom(odom)
node._timer_callback()
assert node.filter_linear_x.x_history[0] == -1.0
def test_high_frequency_noise_dampening(self, node):
"""Test that filter dampens high-frequency encoder noise."""
# Simulate noisy encoder output
base_velocity = 1.0
noise_amplitude = 0.2
odom = Odometry()
odom.header.frame_id = "odom"
# Apply alternating noise
for i in range(100):
odom.twist.twist.linear.x = base_velocity + (noise_amplitude if i % 2 == 0 else -noise_amplitude)
node._on_odom(odom)
node._timer_callback()
# After filtering, output should be close to base velocity
# (oscillations dampened)
def test_large_velocity_values(self, node):
"""Test filtering of large velocity values."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 10.0
odom.twist.twist.angular.z = 5.0
node._on_odom(odom)
node._timer_callback()
# Should handle large values without overflow
def test_pose_unchanged(self, node):
"""Test that pose is not modified by velocity filtering."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.pose.pose.position.x = 5.0
odom.pose.pose.position.y = 3.0
odom.twist.twist.linear.x = 1.0
node._on_odom(odom)
node._timer_callback()
# Pose should be copied unchanged
def test_multiple_velocity_updates(self, node):
"""Test filtering across multiple sequential velocity updates."""
odom = Odometry()
odom.header.frame_id = "odom"
velocities = [0.5, 1.0, 1.5, 1.0, 0.5, 0.0]
for v in velocities:
odom.twist.twist.linear.x = v
node._on_odom(odom)
node._timer_callback()
# Filter should smooth the velocity sequence
def test_simultaneous_all_velocities(self, node):
"""Test filtering when all velocity components are present."""
odom = Odometry()
odom.header.frame_id = "odom"
for i in range(30):
odom.twist.twist.linear.x = math.sin(i * 0.1)
odom.twist.twist.linear.y = math.cos(i * 0.1) * 0.5
odom.twist.twist.linear.z = 0.1
odom.twist.twist.angular.x = 0.05
odom.twist.twist.angular.y = 0.05
odom.twist.twist.angular.z = math.sin(i * 0.15)
node._on_odom(odom)
node._timer_callback()
# All filters should operate independently
class TestVelocitySmootherScenarios:
"""Integration-style tests for realistic scenarios."""
def test_scenario_constant_velocity(self, node):
"""Scenario: robot moving at constant velocity."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 1.0
for i in range(50):
node._on_odom(odom)
node._timer_callback()
# Should maintain constant velocity after convergence
def test_scenario_velocity_ramp(self, node):
"""Scenario: velocity ramping up from stop."""
odom = Odometry()
odom.header.frame_id = "odom"
for i in range(50):
odom.twist.twist.linear.x = i * 0.02 # Ramp from 0 to 1.0
node._on_odom(odom)
node._timer_callback()
# Filter should smooth the ramp
def test_scenario_velocity_step(self, node):
"""Scenario: sudden velocity change (e.g., collision avoidance)."""
odom = Odometry()
odom.header.frame_id = "odom"
# First phase: constant velocity
odom.twist.twist.linear.x = 1.0
for i in range(25):
node._on_odom(odom)
node._timer_callback()
# Second phase: sudden stop
odom.twist.twist.linear.x = 0.0
for i in range(25):
node._on_odom(odom)
node._timer_callback()
# Filter should transition smoothly
def test_scenario_rotation_only(self, node):
"""Scenario: robot spinning in place."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 0.0
odom.twist.twist.angular.z = 0.5
for i in range(50):
node._on_odom(odom)
node._timer_callback()
# Angular velocity should be filtered
def test_scenario_mixed_motion(self, node):
"""Scenario: combined linear and angular motion."""
odom = Odometry()
odom.header.frame_id = "odom"
for i in range(50):
odom.twist.twist.linear.x = math.cos(i * 0.1)
odom.twist.twist.linear.y = math.sin(i * 0.1)
odom.twist.twist.angular.z = 0.2
node._on_odom(odom)
node._timer_callback()
# Both linear and angular components should be filtered
def test_scenario_encoder_noise_reduction(self, node):
"""Scenario: realistic encoder jitter with filtering."""
odom = Odometry()
odom.header.frame_id = "odom"
# Simulate encoder jitter: base velocity + small noise
base_vel = 1.0
for i in range(100):
jitter = 0.05 * math.sin(i * 0.5) + 0.03 * math.cos(i * 0.3)
odom.twist.twist.linear.x = base_vel + jitter
node._on_odom(odom)
node._timer_callback()
# Filter should reduce noise while maintaining base velocity