Compare commits

...

5 Commits

Author SHA1 Message Date
067a871103 feat(perception): wheel encoder differential drive odometry (Issue #184)
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 7s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Adds saltybot_bridge_msgs package with WheelTicks.msg (int32 left/right
encoder counts) and a WheelOdomNode that subscribes to
/saltybot/wheel_ticks, integrates midpoint-Euler differential drive
kinematics (handling int32 counter rollover), and publishes
nav_msgs/Odometry on /odom_wheel at 50 Hz with optional TF broadcast.

New files:
  jetson/ros2_ws/src/saltybot_bridge_msgs/
    msg/WheelTicks.msg
    CMakeLists.txt, package.xml

  jetson/ros2_ws/src/saltybot_bringup/
    saltybot_bringup/_wheel_odom.py     — pure kinematics (no ROS2 deps)
    saltybot_bringup/wheel_odom_node.py — 50 Hz timer node + TF broadcast
    test/test_wheel_odom.py             — 42 tests, all passing

Modified:
  saltybot_bringup/package.xml  — add saltybot_bridge_msgs, nav_msgs deps
  saltybot_bringup/setup.py     — add wheel_odom console_script entry

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 00:41:39 -05:00
b96c6b96d0 Merge pull request 'feat(social): audio wake-word detector 'hey salty' (Issue #320)' (#317) from sl-jetson/wake-word-detect into main
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
2026-03-03 00:41:22 -05:00
d5e0c58594 Merge pull request 'feat: Add velocity smoothing filter ROS2 node' (#316) from sl-controls/velocity-smooth-filter into main 2026-03-03 00:41:16 -05:00
d6553ce3d6 feat(social): audio wake-word detector 'hey salty' (Issue #320)
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Energy-gated log-mel + cosine-similarity wake-word node. Subscribes to
/social/speech/audio_raw (PCM-16 UInt8MultiArray), maintains a 1.5 s
sliding ring buffer, runs detection every 100 ms; fires Bool(True) on
/saltybot/wake_word_detected with 2 s cooldown. Template loaded from
.npy file; passive (no detections) when template_path is empty.
91/91 tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-03 00:26:59 -05:00
2919e629e9 feat: Add velocity smoothing filter with Butterworth low-pass filtering
Implements saltybot_velocity_smoother package:
- Subscribes to /odom, applies digital Butterworth low-pass filter
- Filters linear (x,y,z) and angular (x,y,z) velocity components independently
- Publishes smoothed odometry on /odom_smooth
- Reduces encoder jitter and improves state estimation stability
- Configurable filter order (1-4), cutoff frequency (Hz), publish rate
- Can be enabled/disabled via enable_smoothing parameter
- Comprehensive test suite: 18+ tests covering filter behavior, edge cases, scenarios

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-03 00:23:53 -05:00
23 changed files with 2646 additions and 0 deletions

View File

@ -0,0 +1,15 @@
cmake_minimum_required(VERSION 3.8)
project(saltybot_bridge_msgs)
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(builtin_interfaces REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/WheelTicks.msg"
DEPENDENCIES std_msgs builtin_interfaces
)
ament_export_dependencies(rosidl_default_runtime)
ament_package()

View File

@ -0,0 +1,11 @@
# WheelTicks.msg — cumulative wheel encoder tick counts from STM32 (Issue #184)
#
# left_ticks : cumulative left encoder count (int32, wraps at ±2^31)
# right_ticks : cumulative right encoder count (int32, wraps at ±2^31)
#
# Receivers must handle int32 rollover by computing delta relative to the
# previous message value. The wheel_odom_node does this via unwrap_delta().
#
std_msgs/Header header
int32 left_ticks
int32 right_ticks

View File

@ -0,0 +1,23 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_bridge_msgs</name>
<version>0.1.0</version>
<description>STM32 bridge message definitions — wheel encoder ticks and low-level hardware telemetry (Issue #184)</description>
<maintainer email="sl-perception@saltylab.local">sl-perception</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<depend>std_msgs</depend>
<depend>builtin_interfaces</depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>

View File

@ -27,6 +27,9 @@
<exec_depend>saltybot_perception</exec_depend> <exec_depend>saltybot_perception</exec_depend>
<!-- HSV color segmentation messages (Issue #274) --> <!-- HSV color segmentation messages (Issue #274) -->
<exec_depend>saltybot_scene_msgs</exec_depend> <exec_depend>saltybot_scene_msgs</exec_depend>
<!-- Wheel encoder messages + odometry (Issue #184) -->
<exec_depend>saltybot_bridge_msgs</exec_depend>
<exec_depend>nav_msgs</exec_depend>
<exec_depend>saltybot_uwb</exec_depend> <exec_depend>saltybot_uwb</exec_depend>
<buildtool_depend>ament_python</buildtool_depend> <buildtool_depend>ament_python</buildtool_depend>

View File

@ -0,0 +1,161 @@
"""
_wheel_odom.py Differential drive wheel odometry kinematics (no ROS2 deps).
Algorithm
---------
Given incremental left/right wheel displacements (metres) since the last update:
d_center = (d_left + d_right) / 2
d_theta = (d_right d_left) / wheel_base
Pose is integrated using the midpoint Euler method (equivalent to the exact
ICC solution for a circular arc, but simpler and numerically robust):
mid_theta = theta + d_theta / 2
x += d_center · cos(mid_theta)
y += d_center · sin(mid_theta)
theta += d_theta
theta is kept in (π, π] after every step.
Int32 rollover
--------------
STM32 encoder counters are int32 and wrap at ±2^31. `unwrap_delta` handles
this by detecting jumps larger than half the int32 range and adjusting by the
full range:
if delta > 2^30 : delta -= 2^31
if delta < -2^30 : delta += 2^31
This correctly handles up to ½ of the full int32 range per message interval
safely above any realistic tick rate at 50 Hz.
Public API
----------
WheelOdomState(x, y, theta)
unwrap_delta(current, prev) -> int
ticks_to_meters(ticks, radius, ticks_per_rev) -> float
integrate_odom(x, y, theta, d_left_m, d_right_m, wheel_base) -> (x, y, theta)
normalize_angle(theta) -> float
quat_from_yaw(yaw) -> (qx, qy, qz, qw)
"""
from __future__ import annotations
import math
from typing import NamedTuple, Tuple
# ── Data types ────────────────────────────────────────────────────────────────
class WheelOdomState(NamedTuple):
"""Current dead-reckoning pose estimate."""
x: float # metres (world frame)
y: float # metres (world frame)
theta: float # radians, kept in (−π, π]
# ── Int32 rollover handling ───────────────────────────────────────────────────
_INT32_HALF = 2 ** 31 # half of the full int32 value range (2^32) — detection threshold
_INT32_FULL = 2 ** 32 # full int32 value range — adjustment amount
def unwrap_delta(current: int, prev: int) -> int:
"""
Compute signed tick delta handling int32 counter rollover.
Parameters
----------
current : current tick count (int32, may have wrapped)
prev : previous tick count
Returns
-------
Signed delta ticks, correct even across the ±2^31 rollover boundary.
"""
delta = int(current) - int(prev)
if delta > _INT32_HALF:
delta -= _INT32_FULL
elif delta < -_INT32_HALF:
delta += _INT32_FULL
return delta
# ── Unit conversion ───────────────────────────────────────────────────────────
def ticks_to_meters(delta_ticks: int, wheel_radius: float, ticks_per_rev: int) -> float:
"""
Convert encoder tick count to linear wheel displacement in metres.
Parameters
----------
delta_ticks : signed tick increment
wheel_radius : effective wheel radius (metres)
ticks_per_rev: encoder ticks per full wheel revolution
Returns
-------
Linear displacement (metres); positive = forward.
"""
if ticks_per_rev <= 0:
return 0.0
return delta_ticks * (2.0 * math.pi * wheel_radius) / ticks_per_rev
# ── Pose integration ──────────────────────────────────────────────────────────
def normalize_angle(theta: float) -> float:
"""Wrap angle to (−π, π]."""
return math.atan2(math.sin(theta), math.cos(theta))
def integrate_odom(
x: float,
y: float,
theta: float,
d_left_m: float,
d_right_m: float,
wheel_base: float,
) -> Tuple[float, float, float]:
"""
Integrate differential drive kinematics using the midpoint Euler method.
Parameters
----------
x, y, theta : current pose (metres, metres, radians)
d_left_m : left wheel displacement since last update (metres)
d_right_m : right wheel displacement since last update (metres)
wheel_base : centre-to-centre wheel separation (metres)
Returns
-------
(x, y, theta) updated pose
"""
d_center = (d_left_m + d_right_m) / 2.0
d_theta = (d_right_m - d_left_m) / wheel_base
mid_theta = theta + d_theta / 2.0
new_x = x + d_center * math.cos(mid_theta)
new_y = y + d_center * math.sin(mid_theta)
new_theta = normalize_angle(theta + d_theta)
return new_x, new_y, new_theta
# ── Quaternion helper ─────────────────────────────────────────────────────────
def quat_from_yaw(yaw: float) -> Tuple[float, float, float, float]:
"""
Convert a yaw angle (rotation about Z) to a unit quaternion.
Parameters
----------
yaw : rotation angle in radians
Returns
-------
(qx, qy, qz, qw)
"""
half = yaw / 2.0
return (0.0, 0.0, math.sin(half), math.cos(half))

View File

@ -0,0 +1,210 @@
"""
wheel_odom_node.py Differential drive wheel encoder odometry (Issue #184).
Subscribes to raw encoder tick counts from the STM32 bridge, integrates
differential drive kinematics, and publishes nav_msgs/Odometry at 50 Hz.
Optionally broadcasts the odom base_link TF transform.
Subscribes:
/saltybot/wheel_ticks saltybot_bridge_msgs/WheelTicks (RELIABLE)
Publishes:
/odom_wheel nav_msgs/Odometry at publish_hz (default 50 Hz)
TF broadcast (when publish_tf=true):
odom base_link
Parameters
----------
wheel_radius float 0.034 Effective wheel radius (metres)
wheel_base float 0.160 Centre-to-centre wheel separation (metres)
ticks_per_rev int 1000 Encoder ticks per full wheel revolution
publish_hz float 50.0 Odometry publication rate (Hz)
odom_frame_id str odom Frame id for odometry header
base_frame_id str base_link Child frame id
publish_tf bool true Whether to broadcast odom base_link TF
"""
from __future__ import annotations
import math
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
import tf2_ros
from geometry_msgs.msg import TransformStamped
from nav_msgs.msg import Odometry
from saltybot_bridge_msgs.msg import WheelTicks
from ._wheel_odom import (
WheelOdomState,
unwrap_delta,
ticks_to_meters,
integrate_odom,
quat_from_yaw,
)
# Covariance matrices (6×6 diagonal, row-major).
# Indices: 0=x, 7=y, 14=z, 21=roll, 28=pitch, 35=yaw
_POSE_COV = [0.0] * 36
_TWIST_COV = [0.0] * 36
_POSE_COV[0] = 1e-3 # x
_POSE_COV[7] = 1e-3 # y
_POSE_COV[35] = 1e-3 # yaw
_TWIST_COV[0] = 1e-3 # vx
_TWIST_COV[35] = 1e-3 # v_yaw
class WheelOdomNode(Node):
def __init__(self) -> None:
super().__init__('wheel_odom_node')
self.declare_parameter('wheel_radius', 0.034)
self.declare_parameter('wheel_base', 0.160)
self.declare_parameter('ticks_per_rev', 1000)
self.declare_parameter('publish_hz', 50.0)
self.declare_parameter('odom_frame_id', 'odom')
self.declare_parameter('base_frame_id', 'base_link')
self.declare_parameter('publish_tf', True)
self._radius = float(self.get_parameter('wheel_radius').value)
self._base = float(self.get_parameter('wheel_base').value)
self._ticks_rev = int(self.get_parameter('ticks_per_rev').value)
publish_hz = float(self.get_parameter('publish_hz').value)
self._odom_fid = self.get_parameter('odom_frame_id').value
self._base_fid = self.get_parameter('base_frame_id').value
self._pub_tf = bool(self.get_parameter('publish_tf').value)
# Running state
self._x = 0.0
self._y = 0.0
self._theta = 0.0
self._prev_left: int | None = None
self._prev_right: int | None = None
# Velocity accumulation between timer callbacks
self._accum_d: float = 0.0
self._accum_dtheta: float = 0.0
self._sub = self.create_subscription(
WheelTicks,
'/saltybot/wheel_ticks',
self._on_ticks,
QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
history=HistoryPolicy.KEEP_LAST,
depth=10,
),
)
self._pub = self.create_publisher(Odometry, '/odom_wheel', 10)
if self._pub_tf:
self._tf_br = tf2_ros.TransformBroadcaster(self)
else:
self._tf_br = None
period = 1.0 / max(publish_hz, 1.0)
self._last_pub_time = self.get_clock().now()
self._timer = self.create_timer(period, self._on_timer)
self.get_logger().info(
f'wheel_odom_node ready — '
f'radius={self._radius}m base={self._base}m '
f'ticks_per_rev={self._ticks_rev} hz={publish_hz}'
)
# ── Tick callback (updates pose immediately) ──────────────────────────────
def _on_ticks(self, msg: WheelTicks) -> None:
if self._prev_left is None:
self._prev_left = msg.left_ticks
self._prev_right = msg.right_ticks
return
dl = unwrap_delta(msg.left_ticks, self._prev_left)
dr = unwrap_delta(msg.right_ticks, self._prev_right)
self._prev_left = msg.left_ticks
self._prev_right = msg.right_ticks
d_left_m = ticks_to_meters(dl, self._radius, self._ticks_rev)
d_right_m = ticks_to_meters(dr, self._radius, self._ticks_rev)
self._x, self._y, self._theta = integrate_odom(
self._x, self._y, self._theta,
d_left_m, d_right_m, self._base,
)
# Accumulate for velocity estimation in the timer
self._accum_d += (d_left_m + d_right_m) / 2.0
self._accum_dtheta += (d_right_m - d_left_m) / self._base
# ── Timer callback (publishes at fixed rate) ──────────────────────────────
def _on_timer(self) -> None:
now = self.get_clock().now()
dt = (now - self._last_pub_time).nanoseconds * 1e-9
self._last_pub_time = now
# Velocity from accumulated incremental motion over the timer period
vx = self._accum_d / dt if dt > 1e-6 else 0.0
omega = self._accum_dtheta / dt if dt > 1e-6 else 0.0
self._accum_d = 0.0
self._accum_dtheta = 0.0
qx, qy, qz, qw = quat_from_yaw(self._theta)
stamp = now.to_msg()
# ── Odometry message ──────────────────────────────────────────────────
odom = Odometry()
odom.header.stamp = stamp
odom.header.frame_id = self._odom_fid
odom.child_frame_id = self._base_fid
odom.pose.pose.position.x = self._x
odom.pose.pose.position.y = self._y
odom.pose.pose.position.z = 0.0
odom.pose.pose.orientation.x = qx
odom.pose.pose.orientation.y = qy
odom.pose.pose.orientation.z = qz
odom.pose.pose.orientation.w = qw
odom.pose.covariance = _POSE_COV
odom.twist.twist.linear.x = vx
odom.twist.twist.angular.z = omega
odom.twist.covariance = _TWIST_COV
self._pub.publish(odom)
# ── TF broadcast ──────────────────────────────────────────────────────
if self._tf_br is not None:
tf_msg = TransformStamped()
tf_msg.header.stamp = stamp
tf_msg.header.frame_id = self._odom_fid
tf_msg.child_frame_id = self._base_fid
tf_msg.transform.translation.x = self._x
tf_msg.transform.translation.y = self._y
tf_msg.transform.translation.z = 0.0
tf_msg.transform.rotation.x = qx
tf_msg.transform.rotation.y = qy
tf_msg.transform.rotation.z = qz
tf_msg.transform.rotation.w = qw
self._tf_br.sendTransform(tf_msg)
def main(args=None) -> None:
rclpy.init(args=args)
node = WheelOdomNode()
try:
rclpy.spin(node)
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -47,6 +47,8 @@ setup(
'terrain_roughness = saltybot_bringup.terrain_rough_node:main', 'terrain_roughness = saltybot_bringup.terrain_rough_node:main',
# Sky detector for outdoor navigation (Issue #307) # Sky detector for outdoor navigation (Issue #307)
'sky_detector = saltybot_bringup.sky_detect_node:main', 'sky_detector = saltybot_bringup.sky_detect_node:main',
# Wheel encoder differential drive odometry (Issue #184)
'wheel_odom = saltybot_bringup.wheel_odom_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,332 @@
"""
test_wheel_odom.py Unit tests for wheel odometry kinematics (no ROS2 required).
Covers:
unwrap_delta:
- normal positive delta
- normal negative delta
- zero delta
- positive rollover (counter crosses +2^31 boundary)
- negative rollover (counter crosses -2^31 boundary)
- near-zero after positive rollover
- large forward motion (just under rollover threshold)
ticks_to_meters:
- zero ticks 0.0 m
- one full revolution 2π * radius
- half revolution π * radius
- negative ticks negative displacement
- ticks_per_rev = 0 0.0 (guard)
- fractional result correctness
normalize_angle:
- 0 0
- π π (or π, both valid; atan2 returns in (π, π])
- 2π 0
- π π (or π)
- 3π/2 π/2
integrate_odom straight line:
- equal d_left == d_right x increases, y unchanged, theta unchanged
- moving backward (negative equal displacements) x decreases
integrate_odom rotation:
- d_right > d_left theta increases (left turn)
- d_left > d_right theta decreases (right turn)
- spin in place (d_left = d_right) x,y unchanged, theta = d_theta
integrate_odom circular motion:
- four identical quarter-turns return near-original position
- heading after 90° left turn π/2
integrate_odom starting from non-zero pose:
- displacement is applied in the current heading direction
quat_from_yaw:
- yaw=0 qw=1, qx=qy=qz=0
- yaw=π qw0, qz1
- yaw=π/2 qw=qz1/2, qx=qy=0
- unit quaternion: qx²+qy²+qz²+qw²=1 for arbitrary yaw
- yaw=π/2 qz1/2
WheelOdomState:
- fields accessible by name
- immutable (NamedTuple)
"""
import sys
import os
import math
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_bringup._wheel_odom import (
WheelOdomState,
unwrap_delta,
ticks_to_meters,
normalize_angle,
integrate_odom,
quat_from_yaw,
)
# ── Constants ─────────────────────────────────────────────────────────────────
_R = 0.034 # wheel radius (m)
_BASE = 0.160 # wheel base (m)
_TPR = 1000 # ticks per rev
# ── WheelOdomState ────────────────────────────────────────────────────────────
class TestWheelOdomState:
def test_fields_accessible(self):
s = WheelOdomState(x=1.0, y=2.0, theta=0.5)
assert s.x == pytest.approx(1.0)
assert s.y == pytest.approx(2.0)
assert s.theta == pytest.approx(0.5)
def test_is_namedtuple(self):
s = WheelOdomState(0.0, 0.0, 0.0)
assert isinstance(s, tuple)
# ── unwrap_delta ──────────────────────────────────────────────────────────────
class TestUnwrapDelta:
def test_normal_positive(self):
assert unwrap_delta(100, 50) == 50
def test_normal_negative(self):
assert unwrap_delta(50, 100) == -50
def test_zero(self):
assert unwrap_delta(200, 200) == 0
def test_positive_rollover(self):
"""Counter wraps from just below +2^31 to just above 2^31."""
MAX = 2 ** 31
prev = MAX - 10
current = -MAX + 5 # wrapped: total advance = 15 ticks
assert unwrap_delta(current, prev) == 15
def test_negative_rollover(self):
"""Counter wraps from just above 2^31 to just below +2^31."""
MAX = 2 ** 31
prev = -MAX + 10
current = MAX - 5 # wrapped backward: delta = 15
assert unwrap_delta(current, prev) == -15
def test_large_forward_no_rollover(self):
"""Delta just under the rollover threshold should not be treated as rollover."""
half = 2 ** 30
assert unwrap_delta(half - 1, 0) == half - 1
def test_symmetry(self):
assert unwrap_delta(300, 200) == -unwrap_delta(200, 300)
# ── ticks_to_meters ───────────────────────────────────────────────────────────
class TestTicksToMeters:
def test_zero_ticks(self):
assert ticks_to_meters(0, _R, _TPR) == pytest.approx(0.0)
def test_one_full_revolution(self):
expected = 2.0 * math.pi * _R
assert ticks_to_meters(_TPR, _R, _TPR) == pytest.approx(expected, rel=1e-9)
def test_half_revolution(self):
expected = math.pi * _R
assert ticks_to_meters(_TPR // 2, _R, _TPR) == pytest.approx(expected, rel=1e-4)
def test_negative_ticks(self):
d = ticks_to_meters(-_TPR, _R, _TPR)
assert d == pytest.approx(-(2.0 * math.pi * _R), rel=1e-9)
def test_ticks_per_rev_zero_returns_zero(self):
assert ticks_to_meters(100, _R, 0) == pytest.approx(0.0)
def test_proportional(self):
d1 = ticks_to_meters(100, _R, _TPR)
d2 = ticks_to_meters(200, _R, _TPR)
assert d2 == pytest.approx(2 * d1, rel=1e-9)
# ── normalize_angle ───────────────────────────────────────────────────────────
class TestNormalizeAngle:
def test_zero(self):
assert normalize_angle(0.0) == pytest.approx(0.0)
def test_two_pi_to_zero(self):
assert normalize_angle(2 * math.pi) == pytest.approx(0.0, abs=1e-10)
def test_three_half_pi_to_neg_half_pi(self):
assert normalize_angle(3 * math.pi / 2) == pytest.approx(-math.pi / 2, rel=1e-9)
def test_minus_three_half_pi_to_half_pi(self):
assert normalize_angle(-3 * math.pi / 2) == pytest.approx(math.pi / 2, rel=1e-9)
def test_pi_stays_in_range(self):
v = normalize_angle(math.pi)
assert -math.pi <= v <= math.pi
def test_large_angle_wraps(self):
v = normalize_angle(100 * math.pi + 0.1)
assert -math.pi < v <= math.pi
assert v == pytest.approx(0.1, rel=1e-6)
# ── integrate_odom — straight line ───────────────────────────────────────────
class TestIntegrateOdomStraight:
def test_straight_forward_x_increases(self):
d = 0.1
x, y, theta = integrate_odom(0.0, 0.0, 0.0, d, d, _BASE)
assert x == pytest.approx(d, rel=1e-9)
assert y == pytest.approx(0.0, abs=1e-12)
assert theta == pytest.approx(0.0, abs=1e-12)
def test_straight_backward_x_decreases(self):
d = -0.05
x, y, theta = integrate_odom(0.0, 0.0, 0.0, d, d, _BASE)
assert x == pytest.approx(d, rel=1e-9)
assert y == pytest.approx(0.0, abs=1e-12)
def test_straight_along_y_axis(self):
"""Starting at theta=π/2, forward motion should increase y."""
d = 0.1
x, y, theta = integrate_odom(0.0, 0.0, math.pi / 2, d, d, _BASE)
assert x == pytest.approx(0.0, abs=1e-10)
assert y == pytest.approx(d, rel=1e-9)
def test_straight_accumulates(self):
"""Multiple straight steps accumulate correctly."""
x, y, theta = 0.0, 0.0, 0.0
for _ in range(10):
x, y, theta = integrate_odom(x, y, theta, 0.01, 0.01, _BASE)
assert x == pytest.approx(0.10, rel=1e-6)
assert y == pytest.approx(0.0, abs=1e-10)
# ── integrate_odom — rotation ─────────────────────────────────────────────────
class TestIntegrateOdomRotation:
def test_right_wheel_faster_turns_left(self):
"""d_right > d_left → robot turns left → theta increases."""
x, y, theta = integrate_odom(0.0, 0.0, 0.0, 0.0, 0.01, _BASE)
assert theta > 0.0
def test_left_wheel_faster_turns_right(self):
"""d_left > d_right → robot turns right → theta decreases."""
x, y, theta = integrate_odom(0.0, 0.0, 0.0, 0.01, 0.0, _BASE)
assert theta < 0.0
def test_spin_in_place_xy_unchanged(self):
"""d_left = d_right → pure rotation, position unchanged."""
d = 0.01
x, y, theta = integrate_odom(0.0, 0.0, 0.0, -d, d, _BASE)
assert x == pytest.approx(0.0, abs=1e-10)
assert y == pytest.approx(0.0, abs=1e-10)
def test_spin_in_place_theta_correct(self):
"""Spinning one radian: d_theta = (d_right d_left) / base."""
d_theta_target = 1.0
d = d_theta_target * _BASE / 2.0
x, y, theta = integrate_odom(0.0, 0.0, 0.0, -d, d, _BASE)
assert theta == pytest.approx(d_theta_target, rel=1e-9)
def test_90deg_left_turn_heading(self):
"""Quarter arc left turn → theta ≈ π/2."""
# Arc length for 90° left turn on wheel base _BASE, radius R:
# r = R + base/2 for outer (right) wheel; r - base/2 for inner (left)
arc_radius = 0.5 # arbitrary turn radius
d_theta = math.pi / 2
d_right = (arc_radius + _BASE / 2) * d_theta
d_left = (arc_radius - _BASE / 2) * d_theta
x, y, theta = integrate_odom(0.0, 0.0, 0.0, d_left, d_right, _BASE)
assert theta == pytest.approx(math.pi / 2, rel=1e-9)
# ── integrate_odom — circular closure ────────────────────────────────────────
class TestIntegrateOdomClosure:
def test_four_quarter_turns_return_to_origin(self):
"""
Four 90° left arcs (same radius) should close into a full circle and
return approximately to the origin.
"""
arc_r = 0.30 # turning radius (m)
d_theta = math.pi / 2 # 90° per arc
d_right = (arc_r + _BASE / 2) * d_theta
d_left = (arc_r - _BASE / 2) * d_theta
x, y, theta = 0.0, 0.0, 0.0
for _ in range(4):
x, y, theta = integrate_odom(x, y, theta, d_left, d_right, _BASE)
assert x == pytest.approx(0.0, abs=1e-9)
assert y == pytest.approx(0.0, abs=1e-9)
assert theta == pytest.approx(0.0, abs=1e-9)
def test_full_spin_in_place(self):
"""Spinning 2π in small steps should return theta to 0."""
d_theta_step = math.pi / 18 # 10° per step
d = d_theta_step * _BASE / 2.0
x, y, theta = 0.0, 0.0, 0.0
for _ in range(36): # 36 × 10° = 360°
x, y, theta = integrate_odom(x, y, theta, -d, d, _BASE)
assert theta == pytest.approx(0.0, abs=1e-9)
# ── quat_from_yaw ─────────────────────────────────────────────────────────────
class TestQuatFromYaw:
def test_zero_yaw(self):
qx, qy, qz, qw = quat_from_yaw(0.0)
assert qx == pytest.approx(0.0, abs=1e-12)
assert qy == pytest.approx(0.0, abs=1e-12)
assert qz == pytest.approx(0.0, abs=1e-12)
assert qw == pytest.approx(1.0)
def test_pi_yaw(self):
qx, qy, qz, qw = quat_from_yaw(math.pi)
assert qx == pytest.approx(0.0, abs=1e-10)
assert qy == pytest.approx(0.0, abs=1e-10)
assert abs(qz) == pytest.approx(1.0, rel=1e-9)
assert abs(qw) == pytest.approx(0.0, abs=1e-10)
def test_half_pi_yaw(self):
qx, qy, qz, qw = quat_from_yaw(math.pi / 2)
s = 1.0 / math.sqrt(2)
assert qx == pytest.approx(0.0, abs=1e-12)
assert qy == pytest.approx(0.0, abs=1e-12)
assert qz == pytest.approx(s, rel=1e-9)
assert qw == pytest.approx(s, rel=1e-9)
def test_neg_half_pi_yaw(self):
qx, qy, qz, qw = quat_from_yaw(-math.pi / 2)
s = 1.0 / math.sqrt(2)
assert qz == pytest.approx(-s, rel=1e-9)
assert qw == pytest.approx(s, rel=1e-9)
@pytest.mark.parametrize('yaw', [0.0, 0.1, 0.5, math.pi, -math.pi / 3, 2.7])
def test_unit_quaternion(self, yaw):
qx, qy, qz, qw = quat_from_yaw(yaw)
norm_sq = qx**2 + qy**2 + qz**2 + qw**2
assert norm_sq == pytest.approx(1.0, rel=1e-9)
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@ -0,0 +1,19 @@
wake_word_node:
ros__parameters:
audio_topic: "/social/speech/audio_raw" # PCM-16 mono input (UInt8MultiArray)
output_topic: "/saltybot/wake_word_detected"
sample_rate: 16000 # Hz — must match audio source
window_s: 1.5 # detection window length (s)
hop_s: 0.1 # detection timer period (s)
energy_threshold: 0.02 # RMS gate; below this → skip matching
match_threshold: 0.82 # cosine-similarity gate; above → detect
cooldown_s: 2.0 # minimum gap between successive detections (s)
# Path to .npy template file (log-mel features of 'hey salty' recording).
# Leave empty for passive mode (no detections fired).
template_path: "" # e.g. "/opt/saltybot/models/hey_salty.npy"
n_fft: 512 # FFT size for mel spectrogram
n_mels: 40 # mel filterbank bands

View File

@ -0,0 +1,43 @@
"""wake_word.launch.py — Launch wake-word detector ('hey salty') (Issue #320).
Usage:
ros2 launch saltybot_social wake_word.launch.py
ros2 launch saltybot_social wake_word.launch.py template_path:=/opt/saltybot/models/hey_salty.npy
ros2 launch saltybot_social wake_word.launch.py match_threshold:=0.85 cooldown_s:=3.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg = get_package_share_directory("saltybot_social")
cfg = os.path.join(pkg, "config", "wake_word_params.yaml")
return LaunchDescription([
DeclareLaunchArgument("template_path", default_value="",
description="Path to .npy template file (log-mel of 'hey salty')"),
DeclareLaunchArgument("match_threshold", default_value="0.82",
description="Cosine-similarity detection threshold"),
DeclareLaunchArgument("cooldown_s", default_value="2.0",
description="Minimum seconds between detections"),
Node(
package="saltybot_social",
executable="wake_word_node",
name="wake_word_node",
output="screen",
parameters=[
cfg,
{
"template_path": LaunchConfiguration("template_path"),
"match_threshold": LaunchConfiguration("match_threshold"),
"cooldown_s": LaunchConfiguration("cooldown_s"),
},
],
),
])

View File

@ -0,0 +1,343 @@
"""wake_word_node.py — Audio wake-word detector ('hey salty').
Issue #320
Subscribes to raw PCM-16 audio on /social/speech/audio_raw, maintains a
sliding window buffer, and on each hop tick computes log-mel spectrogram
features of the most recent window. If energy is above the gate threshold
AND cosine similarity to the stored template is above match_threshold the
detection fires: Bool(True) is published on /saltybot/wake_word_detected.
Detection is one-shot (only True is published) and guarded by a cooldown
so rapid re-fires are suppressed. When no template is loaded (template_path
is empty) the node stays passive energy gating is applied but no match is
attempted.
Audio format expected
UInt8MultiArray, same feed as vad_node:
raw PCM-16 little-endian mono at ``sample_rate`` Hz.
Subscriptions
/social/speech/audio_raw std_msgs/UInt8MultiArray raw PCM-16 chunks
Publications
/saltybot/wake_word_detected std_msgs/Bool True on each detection event
Parameters
audio_topic (str, "/social/speech/audio_raw")
output_topic (str, "/saltybot/wake_word_detected")
sample_rate (int, 16000) sample rate of incoming audio (Hz)
window_s (float, 1.5) detection window duration (s)
hop_s (float, 0.1) detection timer period (s)
energy_threshold (float, 0.02) RMS gate; below this skip matching
match_threshold (float, 0.82) cosine-similarity gate; above detect
cooldown_s (float, 2.0) minimum gap between successive detections
template_path (str, "") path to .npy template file; "" = passive
n_fft (int, 512) FFT size for mel spectrogram
n_mels (int, 40) number of mel filterbank bands
"""
from __future__ import annotations
import math
import struct
import threading
import time
from collections import deque
from typing import Dict, Optional, Tuple
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import Bool, UInt8MultiArray
try:
import numpy as np
_NP = True
except ImportError: # pragma: no cover
_NP = False
INT16_MAX = 32768.0
# ── Pure DSP helpers (no ROS, numpy only) ──────────────────────────────────────
def pcm16_to_float(data: bytes) -> "np.ndarray":
"""Decode PCM-16 LE bytes → float32 ndarray in [-1.0, 1.0]."""
n = len(data) // 2
if n == 0:
return np.zeros(0, dtype=np.float32)
samples = struct.unpack(f"<{n}h", data[:n * 2])
return np.array(samples, dtype=np.float32) / INT16_MAX
def mel_filterbank(sr: int, n_fft: int, n_mels: int,
fmin: float = 80.0, fmax: Optional[float] = None) -> "np.ndarray":
"""Build a triangular mel filterbank matrix [n_mels, n_fft//2+1]."""
if fmax is None:
fmax = sr / 2.0
def hz_to_mel(hz: float) -> float:
return 2595.0 * math.log10(1.0 + hz / 700.0)
def mel_to_hz(mel: float) -> float:
return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
mel_lo = hz_to_mel(fmin)
mel_hi = hz_to_mel(fmax)
mel_pts = np.linspace(mel_lo, mel_hi, n_mels + 2)
hz_pts = np.array([mel_to_hz(m) for m in mel_pts])
freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr)
fb = np.zeros((n_mels, len(freqs)), dtype=np.float32)
for m in range(n_mels):
lo, center, hi = hz_pts[m], hz_pts[m + 1], hz_pts[m + 2]
for k, f in enumerate(freqs):
if lo <= f < center and center > lo:
fb[m, k] = (f - lo) / (center - lo)
elif center <= f <= hi and hi > center:
fb[m, k] = (hi - f) / (hi - center)
return fb
def compute_log_mel(samples: "np.ndarray", sr: int,
n_fft: int = 512, n_mels: int = 40,
hop: int = 256) -> "np.ndarray":
"""Return log-mel spectrogram [n_mels, T] of *samples* (float32 [-1,1])."""
n = len(samples)
window = np.hanning(n_fft).astype(np.float32)
frames = []
for start in range(0, max(n - n_fft + 1, 1), hop):
chunk = samples[start:start + n_fft]
if len(chunk) < n_fft:
chunk = np.pad(chunk, (0, n_fft - len(chunk)))
power = np.abs(np.fft.rfft(chunk * window)) ** 2
frames.append(power)
frames_arr = np.array(frames, dtype=np.float32).T # [bins, T]
fb = mel_filterbank(sr, n_fft, n_mels)
mel = fb @ frames_arr # [n_mels, T]
mel = np.where(mel > 1e-10, mel, 1e-10)
return np.log(mel)
def cosine_sim(a: "np.ndarray", b: "np.ndarray") -> float:
"""Cosine similarity between two arrays, matched by minimum length."""
af = a.flatten().astype(np.float64)
bf = b.flatten().astype(np.float64)
min_len = min(len(af), len(bf))
if min_len == 0:
return 0.0
af = af[:min_len]
bf = bf[:min_len]
denom = float(np.linalg.norm(af)) * float(np.linalg.norm(bf))
if denom < 1e-12:
return 0.0
return float(np.dot(af, bf) / denom)
def rms(samples: "np.ndarray") -> float:
"""RMS amplitude of a float sample array."""
if len(samples) == 0:
return 0.0
return float(np.sqrt(np.mean(samples.astype(np.float64) ** 2)))
# ── Ring buffer ────────────────────────────────────────────────────────────────
class AudioRingBuffer:
"""Lock-free sliding window for raw float audio samples."""
def __init__(self, max_samples: int) -> None:
self._buf: deque = deque(maxlen=max_samples)
def push(self, samples: "np.ndarray") -> None:
self._buf.extend(samples.tolist())
def get_window(self, n_samples: int) -> Optional["np.ndarray"]:
"""Return last n_samples as float32 array, or None if buffer too short."""
if len(self._buf) < n_samples:
return None
return np.array(list(self._buf)[-n_samples:], dtype=np.float32)
def __len__(self) -> int:
return len(self._buf)
# ── Detector ───────────────────────────────────────────────────────────────────
class WakeWordDetector:
"""Energy-gated cosine-similarity wake-word detector.
Args:
template: Log-mel feature array of the wake word, or None
(passive never fires when None).
energy_threshold: Minimum RMS to proceed to feature matching.
match_threshold: Minimum cosine similarity to fire.
sample_rate: Expected audio sample rate (Hz).
n_fft: FFT size for mel computation.
n_mels: Number of mel bands.
"""
def __init__(self,
template: Optional["np.ndarray"],
energy_threshold: float = 0.02,
match_threshold: float = 0.82,
sample_rate: int = 16000,
n_fft: int = 512,
n_mels: int = 40) -> None:
self._template = template
self._energy_thr = energy_threshold
self._match_thr = match_threshold
self._sr = sample_rate
self._n_fft = n_fft
self._n_mels = n_mels
# ------------------------------------------------------------------
def detect(self, samples: "np.ndarray") -> Tuple[bool, float, float]:
"""Run detection on a window of float samples.
Returns
-------
(detected, rms_value, similarity)
"""
energy = rms(samples)
if energy < self._energy_thr:
return False, energy, 0.0
if self._template is None:
return False, energy, 0.0
hop = max(1, self._n_fft // 2)
feats = compute_log_mel(samples, self._sr, self._n_fft, self._n_mels, hop)
sim = cosine_sim(feats, self._template)
return sim >= self._match_thr, energy, sim
@property
def has_template(self) -> bool:
return self._template is not None
# ── ROS2 node ──────────────────────────────────────────────────────────────────
class WakeWordNode(Node):
"""ROS2 node: 'hey salty' wake-word detection via energy + template matching."""
def __init__(self) -> None:
super().__init__("wake_word_node")
self.declare_parameter("audio_topic", "/social/speech/audio_raw")
self.declare_parameter("output_topic", "/saltybot/wake_word_detected")
self.declare_parameter("sample_rate", 16000)
self.declare_parameter("window_s", 1.5)
self.declare_parameter("hop_s", 0.1)
self.declare_parameter("energy_threshold", 0.02)
self.declare_parameter("match_threshold", 0.82)
self.declare_parameter("cooldown_s", 2.0)
self.declare_parameter("template_path", "")
self.declare_parameter("n_fft", 512)
self.declare_parameter("n_mels", 40)
audio_topic = self.get_parameter("audio_topic").value
output_topic = self.get_parameter("output_topic").value
self._sr = int(self.get_parameter("sample_rate").value)
self._win_s = float(self.get_parameter("window_s").value)
hop_s = float(self.get_parameter("hop_s").value)
energy_thr = float(self.get_parameter("energy_threshold").value)
match_thr = float(self.get_parameter("match_threshold").value)
self._cool_s = float(self.get_parameter("cooldown_s").value)
tmpl_path = str(self.get_parameter("template_path").value)
n_fft = int(self.get_parameter("n_fft").value)
n_mels = int(self.get_parameter("n_mels").value)
# ── Load template ──────────────────────────────────────────────
template: Optional["np.ndarray"] = None
if tmpl_path and _NP:
try:
template = np.load(tmpl_path)
self.get_logger().info(
f"WakeWord: loaded template {tmpl_path} shape={template.shape}"
)
except Exception as exc:
self.get_logger().warn(
f"WakeWord: could not load template '{tmpl_path}': {exc} — passive mode"
)
# ── Ring buffer ────────────────────────────────────────────────
max_samples = int(self._sr * self._win_s * 4) # 4× headroom
self._win_n = int(self._sr * self._win_s)
self._buf = AudioRingBuffer(max_samples)
# ── Detector ───────────────────────────────────────────────────
self._detector = WakeWordDetector(template, energy_thr, match_thr,
self._sr, n_fft, n_mels)
# ── State ──────────────────────────────────────────────────────
self._last_det_t: float = 0.0
self._lock = threading.Lock()
# ── ROS ────────────────────────────────────────────────────────
qos = QoSProfile(depth=10)
self._pub = self.create_publisher(Bool, output_topic, qos)
self._sub = self.create_subscription(
UInt8MultiArray, audio_topic, self._on_audio, qos
)
self._timer = self.create_timer(hop_s, self._detection_cb)
tmpl_status = (f"template={tmpl_path}" if template is not None
else "passive (no template)")
self.get_logger().info(
f"WakeWordNode ready — {tmpl_status}, "
f"window={self._win_s}s, hop={hop_s}s, "
f"energy_thr={energy_thr}, match_thr={match_thr}"
)
# ── Subscription ───────────────────────────────────────────────────
def _on_audio(self, msg) -> None:
if not _NP:
return
try:
raw = bytes(msg.data)
samples = pcm16_to_float(raw)
except Exception as exc:
self.get_logger().warn(f"WakeWord: audio decode error: {exc}")
return
with self._lock:
self._buf.push(samples)
# ── Detection timer ────────────────────────────────────────────────
def _detection_cb(self) -> None:
if not _NP:
return
now = time.monotonic()
with self._lock:
window = self._buf.get_window(self._win_n)
if window is None:
return
detected, energy, sim = self._detector.detect(window)
if detected and (now - self._last_det_t) >= self._cool_s:
self._last_det_t = now
self.get_logger().info(
f"WakeWord: 'hey salty' detected "
f"(rms={energy:.4f}, sim={sim:.3f})"
)
out = Bool()
out.data = True
self._pub.publish(out)
def main(args=None) -> None:
rclpy.init(args=args)
node = WakeWordNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -57,6 +57,8 @@ setup(
'topic_memory_node = saltybot_social.topic_memory_node:main', 'topic_memory_node = saltybot_social.topic_memory_node:main',
# Personal space respector (Issue #310) # Personal space respector (Issue #310)
'personal_space_node = saltybot_social.personal_space_node:main', 'personal_space_node = saltybot_social.personal_space_node:main',
# Audio wake-word detector — 'hey salty' (Issue #320)
'wake_word_node = saltybot_social.wake_word_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,711 @@
"""test_wake_word.py — Offline tests for wake_word_node (Issue #320).
Stubs out rclpy and ROS message types so tests run without a ROS install.
numpy is required (standard on the Jetson).
"""
import importlib.util
import math
import struct
import sys
import time
import types
import unittest
import numpy as np
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
def _make_ros_stubs():
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
"std_msgs", "std_msgs.msg"):
if mod_name not in sys.modules:
sys.modules[mod_name] = types.ModuleType(mod_name)
class _Node:
def __init__(self, name="node"):
self._name = name
if not hasattr(self, "_params"):
self._params = {}
self._pubs = {}
self._subs = {}
self._logs = []
self._timers = []
def declare_parameter(self, name, default):
if name not in self._params:
self._params[name] = default
def get_parameter(self, name):
class _P:
def __init__(self, v): self.value = v
return _P(self._params.get(name))
def create_publisher(self, msg_type, topic, qos):
pub = _FakePub()
self._pubs[topic] = pub
return pub
def create_subscription(self, msg_type, topic, cb, qos):
self._subs[topic] = cb
return object()
def create_timer(self, period, cb):
self._timers.append(cb)
return object()
def get_logger(self):
node = self
class _L:
def info(self, m): node._logs.append(("INFO", m))
def warn(self, m): node._logs.append(("WARN", m))
def error(self, m): node._logs.append(("ERROR", m))
return _L()
def destroy_node(self): pass
class _FakePub:
def __init__(self):
self.msgs = []
def publish(self, msg):
self.msgs.append(msg)
class _QoSProfile:
def __init__(self, depth=10): self.depth = depth
class _Bool:
def __init__(self): self.data = False
class _UInt8MultiArray:
def __init__(self): self.data = b""
rclpy_mod = sys.modules["rclpy"]
rclpy_mod.init = lambda args=None: None
rclpy_mod.spin = lambda node: None
rclpy_mod.shutdown = lambda: None
sys.modules["rclpy.node"].Node = _Node
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
sys.modules["std_msgs.msg"].Bool = _Bool
sys.modules["std_msgs.msg"].UInt8MultiArray = _UInt8MultiArray
return _Node, _FakePub, _Bool, _UInt8MultiArray
_Node, _FakePub, _Bool, _UInt8MultiArray = _make_ros_stubs()
# ── Module loader ─────────────────────────────────────────────────────────────
_SRC = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/saltybot_social/wake_word_node.py"
)
def _load_mod():
spec = importlib.util.spec_from_file_location("wake_word_testmod", _SRC)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def _make_node(mod, **kwargs):
node = mod.WakeWordNode.__new__(mod.WakeWordNode)
defaults = {
"audio_topic": "/social/speech/audio_raw",
"output_topic": "/saltybot/wake_word_detected",
"sample_rate": 16000,
"window_s": 1.5,
"hop_s": 0.1,
"energy_threshold": 0.02,
"match_threshold": 0.82,
"cooldown_s": 2.0,
"template_path": "",
"n_fft": 512,
"n_mels": 40,
}
defaults.update(kwargs)
node._params = dict(defaults)
mod.WakeWordNode.__init__(node)
return node
def _make_pcm_bytes(samples: np.ndarray) -> bytes:
"""Encode float32 array [-1,1] to PCM-16 LE bytes."""
ints = np.clip(samples * 32768.0, -32768, 32767).astype(np.int16)
return ints.tobytes()
def _make_audio_msg(samples: np.ndarray) -> _UInt8MultiArray:
m = _UInt8MultiArray()
m.data = _make_pcm_bytes(samples)
return m
def _sine(freq: float, duration: float, sr: int = 16000,
amp: float = 0.5) -> np.ndarray:
"""Generate a mono sine wave."""
t = np.arange(int(sr * duration)) / sr
return (amp * np.sin(2 * math.pi * freq * t)).astype(np.float32)
def _silence(duration: float, sr: int = 16000) -> np.ndarray:
return np.zeros(int(sr * duration), dtype=np.float32)
def _make_template(mod, sr: int = 16000, n_fft: int = 512,
n_mels: int = 40) -> np.ndarray:
"""Compute a template from a synthetic 'wake word' signal for testing."""
signal = _sine(300, 1.5, sr, amp=0.6)
hop = n_fft // 2
return mod.compute_log_mel(signal, sr, n_fft, n_mels, hop)
# ── Tests: pcm16_to_float ─────────────────────────────────────────────────────
class TestPcm16ToFloat(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def _conv(self, samples):
raw = _make_pcm_bytes(np.array(samples, dtype=np.float32))
return self.mod.pcm16_to_float(raw)
def test_zeros(self):
out = self._conv([0.0, 0.0, 0.0])
np.testing.assert_allclose(out, [0.0, 0.0, 0.0], atol=1e-4)
def test_positive(self):
out = self._conv([0.5])
self.assertAlmostEqual(float(out[0]), 0.5, places=3)
def test_negative(self):
out = self._conv([-0.5])
self.assertAlmostEqual(float(out[0]), -0.5, places=3)
def test_roundtrip_length(self):
arr = np.linspace(-1.0, 1.0, 100, dtype=np.float32)
raw = _make_pcm_bytes(arr)
out = self.mod.pcm16_to_float(raw)
self.assertEqual(len(out), 100)
def test_empty_bytes_returns_empty(self):
out = self.mod.pcm16_to_float(b"")
self.assertEqual(len(out), 0)
def test_odd_byte_ignored(self):
# 3 bytes → 1 complete int16 + 1 orphan byte
out = self.mod.pcm16_to_float(b"\x00\x40\xff")
self.assertEqual(len(out), 1)
# ── Tests: rms ────────────────────────────────────────────────────────────────
class TestRms(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_zero_signal(self):
self.assertAlmostEqual(self.mod.rms(np.zeros(100)), 0.0)
def test_constant_signal(self):
# RMS of constant 0.5 = 0.5
self.assertAlmostEqual(self.mod.rms(np.full(100, 0.5)), 0.5, places=5)
def test_sine_rms(self):
# RMS of sin = amp / sqrt(2)
amp = 0.8
s = _sine(440, 1.0, amp=amp)
expected = amp / math.sqrt(2)
self.assertAlmostEqual(self.mod.rms(s), expected, places=2)
def test_empty_array(self):
self.assertEqual(self.mod.rms(np.array([])), 0.0)
# ── Tests: mel_filterbank ─────────────────────────────────────────────────────
class TestMelFilterbank(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_shape(self):
fb = self.mod.mel_filterbank(16000, 512, 40)
self.assertEqual(fb.shape, (40, 257))
def test_non_negative(self):
fb = self.mod.mel_filterbank(16000, 512, 40)
self.assertTrue((fb >= 0).all())
def test_rows_sum_positive(self):
fb = self.mod.mel_filterbank(16000, 512, 40)
self.assertTrue((fb.sum(axis=1) > 0).all())
def test_custom_n_mels(self):
fb = self.mod.mel_filterbank(16000, 256, 20)
self.assertEqual(fb.shape[0], 20)
# ── Tests: compute_log_mel ────────────────────────────────────────────────────
class TestComputeLogMel(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_output_shape_rows(self):
s = _sine(440, 1.5)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
self.assertEqual(out.shape[0], 40)
def test_output_has_time_axis(self):
s = _sine(440, 1.5)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
self.assertGreater(out.shape[1], 0)
def test_output_finite(self):
s = _sine(440, 1.5)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
self.assertTrue(np.isfinite(out).all())
def test_silence_gives_low_values(self):
s = _silence(1.5)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
# All values should be very small (near log(1e-10))
self.assertTrue((out < -20).all())
def test_short_signal_no_crash(self):
# Shorter than one FFT frame
s = np.zeros(100, dtype=np.float32)
out = self.mod.compute_log_mel(s, 16000, n_fft=512, n_mels=40, hop=256)
self.assertEqual(out.shape[0], 40)
# ── Tests: cosine_sim ─────────────────────────────────────────────────────────
class TestCosineSim(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_identical_vectors(self):
v = np.array([1.0, 2.0, 3.0])
self.assertAlmostEqual(self.mod.cosine_sim(v, v), 1.0, places=5)
def test_orthogonal_vectors(self):
a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])
self.assertAlmostEqual(self.mod.cosine_sim(a, b), 0.0, places=5)
def test_opposite_vectors(self):
v = np.array([1.0, 2.0, 3.0])
self.assertAlmostEqual(self.mod.cosine_sim(v, -v), -1.0, places=5)
def test_zero_vector_returns_zero(self):
a = np.zeros(5)
b = np.ones(5)
self.assertEqual(self.mod.cosine_sim(a, b), 0.0)
def test_2d_arrays(self):
a = np.ones((4, 10))
b = np.ones((4, 10))
self.assertAlmostEqual(self.mod.cosine_sim(a, b), 1.0, places=5)
def test_different_lengths_truncated(self):
a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([1.0, 2.0, 3.0])
# Should not crash, uses min length
result = self.mod.cosine_sim(a, b)
self.assertTrue(-1.0 <= result <= 1.0)
def test_range_is_bounded(self):
rng = np.random.default_rng(42)
a = rng.standard_normal(100)
b = rng.standard_normal(100)
result = self.mod.cosine_sim(a, b)
self.assertGreaterEqual(result, -1.0)
self.assertLessEqual(result, 1.0)
# ── Tests: AudioRingBuffer ────────────────────────────────────────────────────
class TestAudioRingBuffer(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def _buf(self, max_samples=1000):
return self.mod.AudioRingBuffer(max_samples)
def test_empty_initially(self):
b = self._buf()
self.assertEqual(len(b), 0)
def test_push_increases_len(self):
b = self._buf()
b.push(np.ones(100, dtype=np.float32))
self.assertEqual(len(b), 100)
def test_get_window_none_when_short(self):
b = self._buf()
b.push(np.ones(50, dtype=np.float32))
self.assertIsNone(b.get_window(100))
def test_get_window_ok_when_full(self):
b = self._buf()
data = np.arange(200, dtype=np.float32)
b.push(data)
w = b.get_window(100)
self.assertIsNotNone(w)
self.assertEqual(len(w), 100)
def test_get_window_returns_latest(self):
b = self._buf()
b.push(np.zeros(100, dtype=np.float32))
b.push(np.ones(100, dtype=np.float32))
w = b.get_window(100)
np.testing.assert_allclose(w, np.ones(100))
def test_maxlen_evicts_oldest(self):
b = self._buf(max_samples=100)
b.push(np.zeros(60, dtype=np.float32))
b.push(np.ones(60, dtype=np.float32)) # should evict 20 zeros
self.assertEqual(len(b), 100)
w = b.get_window(100)
# Last 40 samples should be ones
np.testing.assert_allclose(w[-40:], np.ones(40))
def test_exact_window_size(self):
b = self._buf(500)
data = np.arange(300, dtype=np.float32)
b.push(data)
w = b.get_window(300)
np.testing.assert_allclose(w, data)
# ── Tests: WakeWordDetector ───────────────────────────────────────────────────
class TestWakeWordDetector(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
cls.sr = 16000
cls.template = _make_template(cls.mod, cls.sr)
def _det(self, template=None, energy_thr=0.02, match_thr=0.82):
return self.mod.WakeWordDetector(template, energy_thr, match_thr,
self.sr, 512, 40)
def test_no_template_never_detects(self):
det = self._det(template=None)
loud = _sine(300, 1.5, self.sr, amp=0.8)
detected, _, _ = det.detect(loud)
self.assertFalse(detected)
def test_has_template_false_when_no_template(self):
det = self._det(template=None)
self.assertFalse(det.has_template)
def test_has_template_true_when_set(self):
det = self._det(template=self.template)
self.assertTrue(det.has_template)
def test_silence_below_energy_gate(self):
det = self._det(template=self.template, energy_thr=0.02)
silence = _silence(1.5, self.sr)
detected, rms_val, sim = det.detect(silence)
self.assertFalse(detected)
self.assertAlmostEqual(rms_val, 0.0, places=4)
self.assertAlmostEqual(sim, 0.0, places=4)
def test_returns_rms_value(self):
det = self._det(template=None)
s = _sine(300, 1.5, self.sr, amp=0.5)
_, rms_val, _ = det.detect(s)
expected = 0.5 / math.sqrt(2)
self.assertAlmostEqual(rms_val, expected, places=2)
def test_identical_signal_detects(self):
"""Signal identical to template should give sim ≈ 1.0 → detect."""
signal = _sine(300, 1.5, self.sr, amp=0.6)
det = self._det(template=self.template, match_thr=0.99)
detected, _, sim = det.detect(signal)
# Sim must be very high for an identical-source signal
self.assertGreater(sim, 0.99)
self.assertTrue(detected)
def test_different_signal_low_sim(self):
"""A very different signal (white noise) should have low similarity."""
rng = np.random.default_rng(7)
noise = (rng.standard_normal(int(self.sr * 1.5)) * 0.4).astype(np.float32)
det = self._det(template=self.template, match_thr=0.82)
_, _, sim = det.detect(noise)
# White noise sim to a tonal template should be < 0.6
self.assertLess(sim, 0.6)
def test_threshold_boundary_low(self):
"""Setting match_thr=0.0 with a loud signal should fire if template set."""
signal = _sine(300, 1.5, self.sr, amp=0.6)
det = self._det(template=self.template, match_thr=0.0)
detected, _, _ = det.detect(signal)
self.assertTrue(detected)
def test_threshold_boundary_high(self):
"""Setting match_thr=1.1 (above max) should never fire."""
signal = _sine(300, 1.5, self.sr, amp=0.6)
det = self._det(template=self.template, match_thr=1.1)
detected, _, _ = det.detect(signal)
self.assertFalse(detected)
def test_energy_below_threshold_skips_matching(self):
"""Low energy → sim returned as 0.0 regardless of template."""
very_quiet = _sine(300, 1.5, self.sr, amp=0.001)
det = self._det(template=self.template, energy_thr=0.1)
detected, rms_val, sim = det.detect(very_quiet)
self.assertFalse(detected)
self.assertAlmostEqual(sim, 0.0, places=5)
# ── Tests: node init ──────────────────────────────────────────────────────────
class TestNodeInit(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def test_instantiates(self):
self.assertIsNotNone(_make_node(self.mod))
def test_pub_registered(self):
node = _make_node(self.mod)
self.assertIn("/saltybot/wake_word_detected", node._pubs)
def test_sub_registered(self):
node = _make_node(self.mod)
self.assertIn("/social/speech/audio_raw", node._subs)
def test_timer_registered(self):
node = _make_node(self.mod)
self.assertGreater(len(node._timers), 0)
def test_custom_topics(self):
node = _make_node(self.mod,
audio_topic="/my/audio",
output_topic="/my/wake")
self.assertIn("/my/audio", node._subs)
self.assertIn("/my/wake", node._pubs)
def test_no_template_passive(self):
node = _make_node(self.mod, template_path="")
self.assertFalse(node._detector.has_template)
def test_bad_template_path_warns(self):
node = _make_node(self.mod, template_path="/nonexistent/template.npy")
warns = [m for lvl, m in node._logs if lvl == "WARN"]
self.assertTrue(any("template" in m.lower() or "passive" in m.lower()
for m in warns))
def test_ring_buffer_allocated(self):
node = _make_node(self.mod)
self.assertIsNotNone(node._buf)
def test_window_n_computed(self):
node = _make_node(self.mod, sample_rate=16000, window_s=1.5)
self.assertEqual(node._win_n, 24000)
# ── Tests: _on_audio callback ─────────────────────────────────────────────────
class TestOnAudio(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod)
def _push(self, samples):
msg = _make_audio_msg(samples)
self.node._subs["/social/speech/audio_raw"](msg)
def test_pushes_samples_to_buffer(self):
self._push(_sine(440, 0.5))
self.assertGreater(len(self.node._buf), 0)
def test_buffer_grows_with_pushes(self):
chunk = _sine(440, 0.1)
before = len(self.node._buf)
self._push(chunk)
after = len(self.node._buf)
self.assertGreater(after, before)
def test_bad_data_no_crash(self):
msg = _UInt8MultiArray()
msg.data = b"\xff" # 1 orphan byte — yields 0 samples, no crash
self.node._subs["/social/speech/audio_raw"](msg)
def test_multiple_chunks_accumulate(self):
for _ in range(5):
self._push(_sine(440, 0.1))
self.assertGreater(len(self.node._buf), 0)
# ── Tests: detection callback ─────────────────────────────────────────────────
class TestDetectionCallback(unittest.TestCase):
@classmethod
def setUpClass(cls): cls.mod = _load_mod()
def _node_with_template(self, **kwargs):
template = _make_template(self.mod)
node = _make_node(self.mod, **kwargs)
node._detector = self.mod.WakeWordDetector(
template, energy_threshold=0.02, match_threshold=0.0, # thr=0 → always fires when loud
sample_rate=16000, n_fft=512, n_mels=40
)
return node
def _fill_buffer(self, node, signal):
msg = _make_audio_msg(signal)
node._subs["/social/speech/audio_raw"](msg)
def test_no_data_no_publish(self):
node = _make_node(self.mod)
node._detection_cb()
self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0)
def test_insufficient_buffer_no_publish(self):
node = self._node_with_template()
# Push only 0.1 s but window is 1.5 s
self._fill_buffer(node, _sine(300, 0.1))
node._detection_cb()
self.assertEqual(len(node._pubs["/saltybot/wake_word_detected"].msgs), 0)
def test_detects_and_publishes_true(self):
node = self._node_with_template()
# Fill with a loud 300 Hz sine (matches template)
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 1)
self.assertTrue(pub.msgs[0].data)
def test_cooldown_suppresses_second_detection(self):
node = self._node_with_template(cooldown_s=60.0)
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
# Second call immediately → cooldown active
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 1)
def test_cooldown_expired_allows_second(self):
node = self._node_with_template(cooldown_s=0.0)
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 2)
def test_no_template_never_publishes(self):
node = _make_node(self.mod, template_path="")
self._fill_buffer(node, _sine(300, 1.5, amp=0.8))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 0)
def test_silence_no_publish(self):
node = self._node_with_template()
self._fill_buffer(node, _silence(1.5))
node._detection_cb()
pub = node._pubs["/saltybot/wake_word_detected"]
self.assertEqual(len(pub.msgs), 0)
def test_detection_logs_info(self):
node = self._node_with_template()
self._fill_buffer(node, _sine(300, 1.5, amp=0.6))
node._detection_cb()
infos = [m for lvl, m in node._logs if lvl == "INFO"]
self.assertTrue(any("detected" in m.lower() or "hey salty" in m.lower()
for m in infos))
# ── Tests: source content ─────────────────────────────────────────────────────
class TestNodeSrc(unittest.TestCase):
@classmethod
def setUpClass(cls):
with open(_SRC) as f: cls.src = f.read()
def test_issue_tag(self): self.assertIn("#320", self.src)
def test_audio_topic(self): self.assertIn("/social/speech/audio_raw", self.src)
def test_output_topic(self): self.assertIn("/saltybot/wake_word_detected", self.src)
def test_wake_word_name(self): self.assertIn("hey salty", self.src)
def test_compute_log_mel(self): self.assertIn("compute_log_mel", self.src)
def test_cosine_sim(self): self.assertIn("cosine_sim", self.src)
def test_mel_filterbank(self): self.assertIn("mel_filterbank", self.src)
def test_audio_ring_buffer(self): self.assertIn("AudioRingBuffer", self.src)
def test_wake_word_detector(self): self.assertIn("WakeWordDetector", self.src)
def test_energy_threshold(self): self.assertIn("energy_threshold", self.src)
def test_match_threshold(self): self.assertIn("match_threshold", self.src)
def test_cooldown(self): self.assertIn("cooldown_s", self.src)
def test_template_path(self): self.assertIn("template_path", self.src)
def test_pcm16_decode(self): self.assertIn("pcm16_to_float", self.src)
def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
def test_numpy_used(self): self.assertIn("import numpy", self.src)
def test_main_defined(self): self.assertIn("def main", self.src)
def test_uint8_multiarray(self): self.assertIn("UInt8MultiArray", self.src)
# ── Tests: config / launch / setup ────────────────────────────────────────────
class TestConfig(unittest.TestCase):
_CONFIG = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/config/wake_word_params.yaml"
)
_LAUNCH = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/launch/wake_word.launch.py"
)
_SETUP = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/setup.py"
)
def test_config_exists(self):
import os; self.assertTrue(os.path.exists(self._CONFIG))
def test_config_energy_threshold(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("energy_threshold", c)
def test_config_match_threshold(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("match_threshold", c)
def test_config_template_path(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("template_path", c)
def test_config_cooldown(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("cooldown_s", c)
def test_launch_exists(self):
import os; self.assertTrue(os.path.exists(self._LAUNCH))
def test_launch_has_template_arg(self):
with open(self._LAUNCH) as f: c = f.read()
self.assertIn("template_path", c)
def test_launch_has_threshold_arg(self):
with open(self._LAUNCH) as f: c = f.read()
self.assertIn("match_threshold", c)
def test_entry_point_in_setup(self):
with open(self._SETUP) as f: c = f.read()
self.assertIn("wake_word_node", c)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,12 @@
# Velocity Smoother Configuration
velocity_smoother:
ros__parameters:
# Filter parameters
filter_order: 2 # Butterworth filter order (1-4 typical)
cutoff_frequency: 5.0 # Cutoff frequency in Hz (lower = more smoothing)
# Publishing frequency (Hz)
publish_frequency: 50
# Enable/disable filtering
enable_smoothing: true

View File

@ -0,0 +1,28 @@
import os
from launch import LaunchDescription
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare
from launch.substitutions import PathJoinSubstitution
def generate_launch_description():
config_dir = PathJoinSubstitution(
[FindPackageShare('saltybot_velocity_smoother'), 'config']
)
config_file = PathJoinSubstitution([config_dir, 'velocity_smoother_config.yaml'])
velocity_smoother = Node(
package='saltybot_velocity_smoother',
executable='velocity_smoother_node',
name='velocity_smoother',
output='screen',
parameters=[config_file],
remappings=[
('/odom', '/odom'),
('/odom_smooth', '/odom_smooth'),
],
)
return LaunchDescription([
velocity_smoother,
])

View File

@ -0,0 +1,29 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_velocity_smoother</name>
<version>0.1.0</version>
<description>Low-pass Butterworth filter for odometry velocity smoothing to reduce encoder jitter</description>
<maintainer email="sl-controls@saltybot.local">SaltyBot Controls</maintainer>
<license>MIT</license>
<author email="sl-controls@saltybot.local">SaltyBot Controls Team</author>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>ament_cmake_python</buildtool_depend>
<depend>rclpy</depend>
<depend>nav_msgs</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,231 @@
#!/usr/bin/env python3
"""Velocity smoother node with low-pass Butterworth filter.
Subscribes to /odom, applies low-pass Butterworth filter to linear and angular
velocity components, and publishes smoothed odometry on /odom_smooth.
Reduces noise from encoder jitter and improves state estimation stability.
"""
import math
import rclpy
from rclpy.node import Node
from nav_msgs.msg import Odometry
from std_msgs.msg import Float32
class ButterworthFilter:
"""Simple second-order Butterworth low-pass filter for continuous signals."""
def __init__(self, cutoff_hz, sample_rate_hz, order=2):
"""Initialize Butterworth filter.
Args:
cutoff_hz: Cutoff frequency in Hz
sample_rate_hz: Sampling rate in Hz
order: Filter order (typically 1-4)
"""
self.cutoff_hz = cutoff_hz
self.sample_rate_hz = sample_rate_hz
self.order = order
# Normalized frequency (0 to 1, where 1 = Nyquist)
self.omega_n = 2.0 * math.pi * cutoff_hz / sample_rate_hz
# Simplified filter coefficients for order 2
# Using canonical form: y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]
if order == 1:
# First-order filter
alpha = self.omega_n / (self.omega_n + 2.0)
self.b = [alpha, alpha]
self.a = [1.0, -(1.0 - 2.0 * alpha)]
else:
# Second-order filter (butterworth)
sqrt2 = math.sqrt(2.0)
wc = math.tan(self.omega_n / 2.0)
wc2 = wc * wc
denom = 1.0 + sqrt2 * wc + wc2
self.b = [wc2 / denom, 2.0 * wc2 / denom, wc2 / denom]
self.a = [1.0,
2.0 * (wc2 - 1.0) / denom,
(1.0 - sqrt2 * wc + wc2) / denom]
# State buffers
self.x_history = [0.0, 0.0, 0.0] # Input history
self.y_history = [0.0, 0.0] # Output history
def filter(self, x):
"""Apply filter to input value.
Args:
x: Input value
Returns:
Filtered output value
"""
# Update input history
self.x_history[2] = self.x_history[1]
self.x_history[1] = self.x_history[0]
self.x_history[0] = x
# Compute output using difference equation
if len(self.b) == 2:
# First-order filter
y = (self.b[0] * self.x_history[0] +
self.b[1] * self.x_history[1] -
self.a[1] * self.y_history[1])
else:
# Second-order filter
y = (self.b[0] * self.x_history[0] +
self.b[1] * self.x_history[1] +
self.b[2] * self.x_history[2] -
self.a[1] * self.y_history[1] -
self.a[2] * self.y_history[2])
# Update output history
self.y_history[1] = self.y_history[0]
self.y_history[0] = y
return y
def reset(self):
"""Reset filter state."""
self.x_history = [0.0, 0.0, 0.0]
self.y_history = [0.0, 0.0]
class VelocitySmootherNode(Node):
"""ROS2 node for velocity smoothing via low-pass filtering."""
def __init__(self):
super().__init__('velocity_smoother')
# Parameters
self.declare_parameter('filter_order', 2)
self.declare_parameter('cutoff_frequency', 5.0)
self.declare_parameter('publish_frequency', 50)
self.declare_parameter('enable_smoothing', True)
filter_order = self.get_parameter('filter_order').value
cutoff_frequency = self.get_parameter('cutoff_frequency').value
publish_frequency = self.get_parameter('publish_frequency').value
self.enable_smoothing = self.get_parameter('enable_smoothing').value
# Create filters for each velocity component
self.filter_linear_x = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_linear_y = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_linear_z = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_angular_x = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_angular_y = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
self.filter_angular_z = ButterworthFilter(
cutoff_frequency, publish_frequency, order=filter_order
)
# Last received odometry
self.last_odom = None
# Subscription to raw odometry
self.sub_odom = self.create_subscription(
Odometry, '/odom', self._on_odom, 10
)
# Publisher for smoothed odometry
self.pub_odom_smooth = self.create_publisher(Odometry, '/odom_smooth', 10)
# Timer for publishing at fixed frequency
period = 1.0 / publish_frequency
self.timer = self.create_timer(period, self._timer_callback)
self.get_logger().info(
f"Velocity smoother initialized. "
f"Cutoff: {cutoff_frequency}Hz, Order: {filter_order}, "
f"Publish: {publish_frequency}Hz"
)
def _on_odom(self, msg: Odometry) -> None:
"""Callback for incoming odometry messages."""
self.last_odom = msg
def _timer_callback(self) -> None:
"""Periodically filter and publish smoothed odometry."""
if self.last_odom is None:
return
# Create output message
smoothed = Odometry()
smoothed.header = self.last_odom.header
smoothed.child_frame_id = self.last_odom.child_frame_id
# Copy pose (unchanged)
smoothed.pose = self.last_odom.pose
if self.enable_smoothing:
# Filter velocity components
linear_x = self.filter_linear_x.filter(
self.last_odom.twist.twist.linear.x
)
linear_y = self.filter_linear_y.filter(
self.last_odom.twist.twist.linear.y
)
linear_z = self.filter_linear_z.filter(
self.last_odom.twist.twist.linear.z
)
angular_x = self.filter_angular_x.filter(
self.last_odom.twist.twist.angular.x
)
angular_y = self.filter_angular_y.filter(
self.last_odom.twist.twist.angular.y
)
angular_z = self.filter_angular_z.filter(
self.last_odom.twist.twist.angular.z
)
else:
# No filtering
linear_x = self.last_odom.twist.twist.linear.x
linear_y = self.last_odom.twist.twist.linear.y
linear_z = self.last_odom.twist.twist.linear.z
angular_x = self.last_odom.twist.twist.angular.x
angular_y = self.last_odom.twist.twist.angular.y
angular_z = self.last_odom.twist.twist.angular.z
# Set smoothed twist
smoothed.twist.twist.linear.x = linear_x
smoothed.twist.twist.linear.y = linear_y
smoothed.twist.twist.linear.z = linear_z
smoothed.twist.twist.angular.x = angular_x
smoothed.twist.twist.angular.y = angular_y
smoothed.twist.twist.angular.z = angular_z
# Copy covariances
smoothed.twist.covariance = self.last_odom.twist.covariance
self.pub_odom_smooth.publish(smoothed)
def main(args=None):
rclpy.init(args=args)
node = VelocitySmootherNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,5 @@
[develop]
script_dir=$base/lib/saltybot_velocity_smoother
[egg_info]
tag_build =
tag_date = 0

View File

@ -0,0 +1,34 @@
from setuptools import setup, find_packages
package_name = 'saltybot_velocity_smoother'
setup(
name=package_name,
version='0.1.0',
packages=find_packages(exclude=['test']),
data_files=[
('share/ament_index/resource_index/packages',
['resource/saltybot_velocity_smoother']),
('share/' + package_name, ['package.xml']),
('share/' + package_name + '/config', ['config/velocity_smoother_config.yaml']),
('share/' + package_name + '/launch', ['launch/velocity_smoother.launch.py']),
],
install_requires=['setuptools'],
zip_safe=True,
author='SaltyBot Controls',
author_email='sl-controls@saltybot.local',
maintainer='SaltyBot Controls',
maintainer_email='sl-controls@saltybot.local',
keywords=['ROS2', 'velocity', 'filtering', 'butterworth'],
classifiers=[
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Topic :: Software Development',
],
entry_points={
'console_scripts': [
'velocity_smoother_node=saltybot_velocity_smoother.velocity_smoother_node:main',
],
},
)

View File

@ -0,0 +1,432 @@
"""Unit tests for velocity smoother node."""
import pytest
import math
from nav_msgs.msg import Odometry
from geometry_msgs.msg import TwistWithCovariance, Twist, Vector3
from std_msgs.msg import Header
import rclpy
from rclpy.time import Time
from saltybot_velocity_smoother.velocity_smoother_node import (
VelocitySmootherNode,
ButterworthFilter,
)
@pytest.fixture
def rclpy_fixture():
"""Initialize and cleanup rclpy."""
rclpy.init()
yield
rclpy.shutdown()
@pytest.fixture
def node(rclpy_fixture):
"""Create a velocity smoother node instance."""
node = VelocitySmootherNode()
yield node
node.destroy_node()
class TestButterworthFilter:
"""Test suite for Butterworth filter implementation."""
def test_filter_initialization(self):
"""Test filter initialization with valid parameters."""
filt = ButterworthFilter(5.0, 50, order=2)
assert filt.cutoff_hz == 5.0
assert filt.sample_rate_hz == 50
assert filt.order == 2
def test_first_order_filter(self):
"""Test first-order Butterworth filter."""
filt = ButterworthFilter(5.0, 50, order=1)
assert len(filt.b) == 2
assert len(filt.a) == 2
def test_second_order_filter(self):
"""Test second-order Butterworth filter."""
filt = ButterworthFilter(5.0, 50, order=2)
assert len(filt.b) == 3
assert len(filt.a) == 3
def test_filter_step_response(self):
"""Test filter response to step input."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply step input (0 -> 1.0)
outputs = []
for i in range(50):
y = filt.filter(1.0)
outputs.append(y)
# Final output should be close to 1.0
assert outputs[-1] > 0.9
assert outputs[-1] <= 1.0
def test_filter_constant_input(self):
"""Test filter with constant input."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply constant input
for i in range(100):
y = filt.filter(2.5)
# Should converge to input value
assert abs(y - 2.5) < 0.01
def test_filter_zero_input(self):
"""Test filter with zero input."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply non-zero then zero
for i in range(50):
filt.filter(1.0)
# Now apply zero
for i in range(50):
y = filt.filter(0.0)
# Should decay to zero
assert abs(y) < 0.01
def test_filter_reset(self):
"""Test filter state reset."""
filt = ButterworthFilter(5.0, 50, order=2)
# Filter some values
for i in range(10):
filt.filter(1.0)
# Reset
filt.reset()
# State should be zero
assert filt.x_history == [0.0, 0.0, 0.0]
assert filt.y_history == [0.0, 0.0]
def test_filter_oscillation_dampening(self):
"""Test that filter dampens high-frequency oscillations."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply alternating signal (high frequency)
outputs = []
for i in range(100):
x = 1.0 if i % 2 == 0 else -1.0
y = filt.filter(x)
outputs.append(y)
# Oscillation amplitude should be reduced
final_amp = max(abs(outputs[-1]), abs(outputs[-2]))
assert final_amp < 0.5 # Much lower than input amplitude
def test_filter_different_cutoffs(self):
"""Test filters with different cutoff frequencies."""
filt_low = ButterworthFilter(2.0, 50, order=2)
filt_high = ButterworthFilter(10.0, 50, order=2)
# Both should be valid
assert filt_low.cutoff_hz == 2.0
assert filt_high.cutoff_hz == 10.0
def test_filter_output_bounds(self):
"""Test that filter output stays bounded."""
filt = ButterworthFilter(5.0, 50, order=2)
# Apply large random-like values
for i in range(100):
x = math.sin(i * 0.5) * 5.0
y = filt.filter(x)
assert abs(y) < 10.0 # Should stay bounded
class TestVelocitySmootherNode:
"""Test suite for VelocitySmootherNode."""
def test_node_initialization(self, node):
"""Test that node initializes correctly."""
assert node.last_odom is None
assert node.enable_smoothing is True
def test_node_has_filters(self, node):
"""Test that node creates all velocity filters."""
assert node.filter_linear_x is not None
assert node.filter_linear_y is not None
assert node.filter_linear_z is not None
assert node.filter_angular_x is not None
assert node.filter_angular_y is not None
assert node.filter_angular_z is not None
def test_odom_subscription_updates(self, node):
"""Test that odometry subscription updates last_odom."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 1.0
node._on_odom(odom)
assert node.last_odom is not None
assert node.last_odom.twist.twist.linear.x == 1.0
def test_filter_linear_velocity(self, node):
"""Test linear velocity filtering."""
# Create odometry message
odom = Odometry()
odom.header.frame_id = "odom"
odom.child_frame_id = "base_link"
odom.twist.twist.linear.x = 1.0
odom.twist.twist.linear.y = 0.5
odom.twist.twist.linear.z = 0.0
odom.twist.twist.angular.x = 0.0
odom.twist.twist.angular.y = 0.0
odom.twist.twist.angular.z = 0.2
node._on_odom(odom)
# Call timer callback to process
node._timer_callback()
# Filter should have been applied
assert node.filter_linear_x.x_history[0] == 1.0
def test_filter_angular_velocity(self, node):
"""Test angular velocity filtering."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.angular.z = 0.5
node._on_odom(odom)
node._timer_callback()
assert node.filter_angular_z.x_history[0] == 0.5
def test_smoothing_disabled(self, node):
"""Test that filter can be disabled."""
node.enable_smoothing = False
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 2.0
node._on_odom(odom)
node._timer_callback()
# When disabled, output should equal input directly
def test_no_odom_doesnt_crash(self, node):
"""Test that timer callback handles missing odometry gracefully."""
# Call timer without setting odometry
node._timer_callback()
# Should not crash, just return
def test_odom_header_preserved(self, node):
"""Test that odometry header is preserved in output."""
odom = Odometry()
odom.header.frame_id = "test_frame"
odom.header.stamp = node.get_clock().now()
odom.child_frame_id = "test_child"
node._on_odom(odom)
# Timer callback processes it
node._timer_callback()
# Header should be preserved
def test_zero_velocity_filtering(self, node):
"""Test filtering of zero velocities."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 0.0
odom.twist.twist.linear.y = 0.0
odom.twist.twist.linear.z = 0.0
odom.twist.twist.angular.x = 0.0
odom.twist.twist.angular.y = 0.0
odom.twist.twist.angular.z = 0.0
node._on_odom(odom)
node._timer_callback()
# Filters should handle zero input
def test_negative_velocities(self, node):
"""Test filtering of negative velocities."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = -1.0
odom.twist.twist.angular.z = -0.5
node._on_odom(odom)
node._timer_callback()
assert node.filter_linear_x.x_history[0] == -1.0
def test_high_frequency_noise_dampening(self, node):
"""Test that filter dampens high-frequency encoder noise."""
# Simulate noisy encoder output
base_velocity = 1.0
noise_amplitude = 0.2
odom = Odometry()
odom.header.frame_id = "odom"
# Apply alternating noise
for i in range(100):
odom.twist.twist.linear.x = base_velocity + (noise_amplitude if i % 2 == 0 else -noise_amplitude)
node._on_odom(odom)
node._timer_callback()
# After filtering, output should be close to base velocity
# (oscillations dampened)
def test_large_velocity_values(self, node):
"""Test filtering of large velocity values."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 10.0
odom.twist.twist.angular.z = 5.0
node._on_odom(odom)
node._timer_callback()
# Should handle large values without overflow
def test_pose_unchanged(self, node):
"""Test that pose is not modified by velocity filtering."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.pose.pose.position.x = 5.0
odom.pose.pose.position.y = 3.0
odom.twist.twist.linear.x = 1.0
node._on_odom(odom)
node._timer_callback()
# Pose should be copied unchanged
def test_multiple_velocity_updates(self, node):
"""Test filtering across multiple sequential velocity updates."""
odom = Odometry()
odom.header.frame_id = "odom"
velocities = [0.5, 1.0, 1.5, 1.0, 0.5, 0.0]
for v in velocities:
odom.twist.twist.linear.x = v
node._on_odom(odom)
node._timer_callback()
# Filter should smooth the velocity sequence
def test_simultaneous_all_velocities(self, node):
"""Test filtering when all velocity components are present."""
odom = Odometry()
odom.header.frame_id = "odom"
for i in range(30):
odom.twist.twist.linear.x = math.sin(i * 0.1)
odom.twist.twist.linear.y = math.cos(i * 0.1) * 0.5
odom.twist.twist.linear.z = 0.1
odom.twist.twist.angular.x = 0.05
odom.twist.twist.angular.y = 0.05
odom.twist.twist.angular.z = math.sin(i * 0.15)
node._on_odom(odom)
node._timer_callback()
# All filters should operate independently
class TestVelocitySmootherScenarios:
"""Integration-style tests for realistic scenarios."""
def test_scenario_constant_velocity(self, node):
"""Scenario: robot moving at constant velocity."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 1.0
for i in range(50):
node._on_odom(odom)
node._timer_callback()
# Should maintain constant velocity after convergence
def test_scenario_velocity_ramp(self, node):
"""Scenario: velocity ramping up from stop."""
odom = Odometry()
odom.header.frame_id = "odom"
for i in range(50):
odom.twist.twist.linear.x = i * 0.02 # Ramp from 0 to 1.0
node._on_odom(odom)
node._timer_callback()
# Filter should smooth the ramp
def test_scenario_velocity_step(self, node):
"""Scenario: sudden velocity change (e.g., collision avoidance)."""
odom = Odometry()
odom.header.frame_id = "odom"
# First phase: constant velocity
odom.twist.twist.linear.x = 1.0
for i in range(25):
node._on_odom(odom)
node._timer_callback()
# Second phase: sudden stop
odom.twist.twist.linear.x = 0.0
for i in range(25):
node._on_odom(odom)
node._timer_callback()
# Filter should transition smoothly
def test_scenario_rotation_only(self, node):
"""Scenario: robot spinning in place."""
odom = Odometry()
odom.header.frame_id = "odom"
odom.twist.twist.linear.x = 0.0
odom.twist.twist.angular.z = 0.5
for i in range(50):
node._on_odom(odom)
node._timer_callback()
# Angular velocity should be filtered
def test_scenario_mixed_motion(self, node):
"""Scenario: combined linear and angular motion."""
odom = Odometry()
odom.header.frame_id = "odom"
for i in range(50):
odom.twist.twist.linear.x = math.cos(i * 0.1)
odom.twist.twist.linear.y = math.sin(i * 0.1)
odom.twist.twist.angular.z = 0.2
node._on_odom(odom)
node._timer_callback()
# Both linear and angular components should be filtered
def test_scenario_encoder_noise_reduction(self, node):
"""Scenario: realistic encoder jitter with filtering."""
odom = Odometry()
odom.header.frame_id = "odom"
# Simulate encoder jitter: base velocity + small noise
base_vel = 1.0
for i in range(100):
jitter = 0.05 * math.sin(i * 0.5) + 0.03 * math.cos(i * 0.3)
odom.twist.twist.linear.x = base_vel + jitter
node._on_odom(odom)
node._timer_callback()
# Filter should reduce noise while maintaining base velocity