Compare commits
6 Commits
e76f0ab95f
...
3530f16fa8
| Author | SHA1 | Date | |
|---|---|---|---|
| 3530f16fa8 | |||
| d6ee5a16b7 | |||
| 94a6f0787e | |||
| 50636de5a9 | |||
| 9d12805843 | |||
| 3cd9faeed9 |
@ -0,0 +1,150 @@
|
||||
"""
|
||||
_vo_drift.py — Visual odometry drift detector helpers (no ROS2 deps).
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
Two independent odometry streams (visual and wheel) are compared over a
|
||||
sliding time window. Drift is measured as the absolute difference in
|
||||
cumulative path length travelled by each source over that window:
|
||||
|
||||
drift_m = |path_length(vo_window) − path_length(wheel_window)|
|
||||
|
||||
Using cumulative path length (sum of inter-sample Euclidean steps) rather
|
||||
than straight-line displacement makes the measure robust to circular motion
|
||||
where start and end positions are the same.
|
||||
|
||||
Drift is flagged when drift_m ≥ drift_threshold_m.
|
||||
|
||||
Public API
|
||||
----------
|
||||
OdomSample — namedtuple(t, x, y)
|
||||
OdomBuffer — deque of OdomSamples with time-window trimming
|
||||
compute_drift() — compare two OdomBuffers and return DriftResult
|
||||
DriftResult — namedtuple(drift_m, vo_path_m, wheel_path_m,
|
||||
is_drifting, window_s, n_vo, n_wheel)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import NamedTuple, Sequence
|
||||
|
||||
|
||||
class OdomSample(NamedTuple):
|
||||
t: float # monotonic timestamp (seconds)
|
||||
x: float # position x (metres)
|
||||
y: float # position y (metres)
|
||||
|
||||
|
||||
class DriftResult(NamedTuple):
|
||||
drift_m: float # |vo_path − wheel_path| (metres)
|
||||
vo_path_m: float # cumulative path of VO source over window (metres)
|
||||
wheel_path_m: float # cumulative path of wheel source over window (metres)
|
||||
is_drifting: bool # True when drift_m >= threshold
|
||||
window_s: float # actual time span of data used (seconds)
|
||||
n_vo: int # number of VO samples in window
|
||||
n_wheel: int # number of wheel samples in window
|
||||
|
||||
|
||||
class OdomBuffer:
|
||||
"""
|
||||
Rolling buffer of OdomSamples trimmed to the last `max_age_s` seconds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_age_s : float — samples older than this are discarded (seconds)
|
||||
"""
|
||||
|
||||
def __init__(self, max_age_s: float = 10.0) -> None:
|
||||
self._max_age = max_age_s
|
||||
self._buf: deque[OdomSample] = deque()
|
||||
|
||||
# ── Public ────────────────────────────────────────────────────────────────
|
||||
|
||||
def push(self, sample: OdomSample) -> None:
|
||||
"""Append a sample and evict anything older than max_age_s."""
|
||||
self._buf.append(sample)
|
||||
self._trim(sample.t)
|
||||
|
||||
def window(self, window_s: float, now: float) -> list[OdomSample]:
|
||||
"""Return samples within the last window_s seconds of `now`."""
|
||||
cutoff = now - window_s
|
||||
return [s for s in self._buf if s.t >= cutoff]
|
||||
|
||||
def clear(self) -> None:
|
||||
self._buf.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._buf)
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _trim(self, now: float) -> None:
|
||||
cutoff = now - self._max_age
|
||||
while self._buf and self._buf[0].t < cutoff:
|
||||
self._buf.popleft()
|
||||
|
||||
|
||||
# ── Core computation ──────────────────────────────────────────────────────────
|
||||
|
||||
def compute_drift(
|
||||
vo_buf: OdomBuffer,
|
||||
wheel_buf: OdomBuffer,
|
||||
window_s: float,
|
||||
drift_threshold_m: float,
|
||||
now: float,
|
||||
) -> DriftResult:
|
||||
"""
|
||||
Compare VO and wheel odometry path lengths over the last `window_s`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vo_buf : OdomBuffer of visual odometry samples
|
||||
wheel_buf : OdomBuffer of wheel odometry samples
|
||||
window_s : comparison window width (seconds)
|
||||
drift_threshold_m : drift_m threshold for is_drifting flag
|
||||
now : current time (same scale as OdomSample.t)
|
||||
|
||||
Returns
|
||||
-------
|
||||
DriftResult — zero drift if either buffer has fewer than 2 samples.
|
||||
"""
|
||||
vo_samples = vo_buf.window(window_s, now)
|
||||
wheel_samples = wheel_buf.window(window_s, now)
|
||||
|
||||
if len(vo_samples) < 2 or len(wheel_samples) < 2:
|
||||
return DriftResult(
|
||||
drift_m=0.0, vo_path_m=0.0, wheel_path_m=0.0,
|
||||
is_drifting=False,
|
||||
window_s=0.0, n_vo=len(vo_samples), n_wheel=len(wheel_samples),
|
||||
)
|
||||
|
||||
vo_path = _path_length(vo_samples)
|
||||
wheel_path = _path_length(wheel_samples)
|
||||
drift_m = abs(vo_path - wheel_path)
|
||||
|
||||
# Actual data span = latest timestamp − earliest across both buffers
|
||||
t_min = min(vo_samples[0].t, wheel_samples[0].t)
|
||||
t_max = max(vo_samples[-1].t, wheel_samples[-1].t)
|
||||
actual_window = t_max - t_min
|
||||
|
||||
return DriftResult(
|
||||
drift_m=drift_m,
|
||||
vo_path_m=vo_path,
|
||||
wheel_path_m=wheel_path,
|
||||
is_drifting=drift_m >= drift_threshold_m,
|
||||
window_s=actual_window,
|
||||
n_vo=len(vo_samples),
|
||||
n_wheel=len(wheel_samples),
|
||||
)
|
||||
|
||||
|
||||
def _path_length(samples: Sequence[OdomSample]) -> float:
|
||||
"""Sum of Euclidean inter-sample distances."""
|
||||
total = 0.0
|
||||
for i in range(1, len(samples)):
|
||||
dx = samples[i].x - samples[i - 1].x
|
||||
dy = samples[i].y - samples[i - 1].y
|
||||
total += math.sqrt(dx * dx + dy * dy)
|
||||
return total
|
||||
@ -0,0 +1,150 @@
|
||||
"""
|
||||
vo_drift_node.py — Visual odometry drift detector (Issue #260).
|
||||
|
||||
Compares the cumulative path lengths of visual odometry and wheel odometry
|
||||
over a sliding window. When the absolute difference exceeds the configured
|
||||
threshold the node flags drift, allowing the system to warn operators,
|
||||
inflate VO covariance, or fall back to wheel-only localisation.
|
||||
|
||||
Subscribes (BEST_EFFORT):
|
||||
/camera/odom nav_msgs/Odometry visual odometry
|
||||
/odom nav_msgs/Odometry wheel odometry
|
||||
|
||||
→ For this robot remap to /saltybot/visual_odom + /saltybot/rover_odom.
|
||||
|
||||
Publishes:
|
||||
/saltybot/vo_drift_detected std_msgs/Bool True while drifting
|
||||
/saltybot/vo_drift_magnitude std_msgs/Float32 drift magnitude (metres)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vo_topic str /camera/odom Visual odometry source topic
|
||||
wheel_topic str /odom Wheel odometry source topic
|
||||
drift_threshold_m float 0.5 Drift flag threshold (metres)
|
||||
window_s float 10.0 Comparison window (seconds)
|
||||
publish_hz float 2.0 Output publication rate (Hz)
|
||||
max_buffer_age_s float 30.0 Max age of stored samples (s)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
|
||||
from nav_msgs.msg import Odometry
|
||||
from std_msgs.msg import Bool, Float32
|
||||
|
||||
from ._vo_drift import OdomBuffer, OdomSample, compute_drift
|
||||
|
||||
|
||||
_SENSOR_QOS = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=4,
|
||||
)
|
||||
|
||||
|
||||
class VoDriftNode(Node):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__('vo_drift_node')
|
||||
|
||||
self.declare_parameter('vo_topic', '/camera/odom')
|
||||
self.declare_parameter('wheel_topic', '/odom')
|
||||
self.declare_parameter('drift_threshold_m', 0.5)
|
||||
self.declare_parameter('window_s', 10.0)
|
||||
self.declare_parameter('publish_hz', 2.0)
|
||||
self.declare_parameter('max_buffer_age_s', 30.0)
|
||||
|
||||
vo_topic = self.get_parameter('vo_topic').value
|
||||
wheel_topic = self.get_parameter('wheel_topic').value
|
||||
self._thresh = self.get_parameter('drift_threshold_m').value
|
||||
self._window_s = self.get_parameter('window_s').value
|
||||
publish_hz = self.get_parameter('publish_hz').value
|
||||
max_age = self.get_parameter('max_buffer_age_s').value
|
||||
|
||||
self._vo_buf = OdomBuffer(max_age_s=max_age)
|
||||
self._wheel_buf = OdomBuffer(max_age_s=max_age)
|
||||
|
||||
self.create_subscription(
|
||||
Odometry, vo_topic, self._on_vo, _SENSOR_QOS)
|
||||
self.create_subscription(
|
||||
Odometry, wheel_topic, self._on_wheel, _SENSOR_QOS)
|
||||
|
||||
self._pub_detected = self.create_publisher(
|
||||
Bool, '/saltybot/vo_drift_detected', 10)
|
||||
self._pub_magnitude = self.create_publisher(
|
||||
Float32, '/saltybot/vo_drift_magnitude', 10)
|
||||
|
||||
self.create_timer(1.0 / publish_hz, self._tick)
|
||||
|
||||
self.get_logger().info(
|
||||
f'vo_drift_node ready — '
|
||||
f'vo={vo_topic} wheel={wheel_topic} '
|
||||
f'threshold={self._thresh}m window={self._window_s}s'
|
||||
)
|
||||
|
||||
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _on_vo(self, msg: Odometry) -> None:
|
||||
s = _odom_to_sample(msg)
|
||||
self._vo_buf.push(s)
|
||||
|
||||
def _on_wheel(self, msg: Odometry) -> None:
|
||||
s = _odom_to_sample(msg)
|
||||
self._wheel_buf.push(s)
|
||||
|
||||
# ── Publish tick ──────────────────────────────────────────────────────────
|
||||
|
||||
def _tick(self) -> None:
|
||||
now = time.monotonic()
|
||||
result = compute_drift(
|
||||
self._vo_buf, self._wheel_buf,
|
||||
window_s=self._window_s,
|
||||
drift_threshold_m=self._thresh,
|
||||
now=now,
|
||||
)
|
||||
|
||||
if result.is_drifting:
|
||||
self.get_logger().warn(
|
||||
f'VO drift detected: {result.drift_m:.3f}m '
|
||||
f'(vo={result.vo_path_m:.3f}m wheel={result.wheel_path_m:.3f}m '
|
||||
f'over {result.window_s:.1f}s)',
|
||||
throttle_duration_sec=5.0,
|
||||
)
|
||||
|
||||
det_msg = Bool()
|
||||
det_msg.data = result.is_drifting
|
||||
self._pub_detected.publish(det_msg)
|
||||
|
||||
mag_msg = Float32()
|
||||
mag_msg.data = float(result.drift_m)
|
||||
self._pub_magnitude.publish(mag_msg)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _odom_to_sample(msg: Odometry) -> OdomSample:
|
||||
"""Convert nav_msgs/Odometry to OdomSample using monotonic clock."""
|
||||
return OdomSample(
|
||||
t=time.monotonic(),
|
||||
x=msg.pose.pose.position.x,
|
||||
y=msg.pose.pose.position.y,
|
||||
)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = VoDriftNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -35,6 +35,8 @@ setup(
|
||||
'lidar_clustering = saltybot_bringup.lidar_clustering_node:main',
|
||||
# Floor surface type classifier (Issue #249)
|
||||
'floor_classifier = saltybot_bringup.floor_classifier_node:main',
|
||||
# Visual odometry drift detector (Issue #260)
|
||||
'vo_drift_detector = saltybot_bringup.vo_drift_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
297
jetson/ros2_ws/src/saltybot_bringup/test/test_vo_drift.py
Normal file
297
jetson/ros2_ws/src/saltybot_bringup/test/test_vo_drift.py
Normal file
@ -0,0 +1,297 @@
|
||||
"""
|
||||
test_vo_drift.py — Unit tests for VO drift detector helpers (no ROS2 required).
|
||||
|
||||
Covers:
|
||||
OdomBuffer:
|
||||
- push/len
|
||||
- window returns only samples within cutoff
|
||||
- old samples are evicted beyond max_age_s
|
||||
- clear empties the buffer
|
||||
- window on empty buffer returns empty list
|
||||
|
||||
_path_length (via compute_drift with crafted samples):
|
||||
- stationary source → path = 0
|
||||
- straight-line motion → path = total distance
|
||||
- L-shaped path → path = sum of two legs
|
||||
|
||||
compute_drift:
|
||||
- both empty → DriftResult with zeros, is_drifting=False
|
||||
- one buffer < 2 samples → zero drift
|
||||
- both move same distance → drift ≈ 0, not drifting
|
||||
- VO moves 1m, wheel moves 0.5m → drift = 0.5m
|
||||
- drift == threshold → is_drifting=True (>=)
|
||||
- drift < threshold → is_drifting=False
|
||||
- drift > threshold → is_drifting=True
|
||||
- path lengths in result match expectation
|
||||
- n_vo / n_wheel counts correct
|
||||
- samples outside window ignored
|
||||
- window_s in result reflects actual data span
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from saltybot_bringup._vo_drift import (
|
||||
OdomSample,
|
||||
OdomBuffer,
|
||||
DriftResult,
|
||||
compute_drift,
|
||||
_path_length,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _s(t, x, y) -> OdomSample:
|
||||
return OdomSample(t=t, x=x, y=y)
|
||||
|
||||
|
||||
def _straight_buf(n=5, speed=0.1, t_start=0.0, dt=1.0,
|
||||
max_age_s=30.0) -> OdomBuffer:
|
||||
"""n samples moving along +x at `speed` m/s."""
|
||||
buf = OdomBuffer(max_age_s=max_age_s)
|
||||
for i in range(n):
|
||||
buf.push(_s(t_start + i * dt, x=i * speed * dt, y=0.0))
|
||||
return buf
|
||||
|
||||
|
||||
def _stationary_buf(n=5, t_start=0.0, dt=1.0,
|
||||
max_age_s=30.0) -> OdomBuffer:
|
||||
buf = OdomBuffer(max_age_s=max_age_s)
|
||||
for i in range(n):
|
||||
buf.push(_s(t_start + i * dt, x=0.0, y=0.0))
|
||||
return buf
|
||||
|
||||
|
||||
# ── OdomBuffer ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestOdomBuffer:
|
||||
|
||||
def test_push_increases_len(self):
|
||||
buf = OdomBuffer()
|
||||
assert len(buf) == 0
|
||||
buf.push(_s(0.0, 0.0, 0.0))
|
||||
assert len(buf) == 1
|
||||
|
||||
def test_window_returns_all_within_cutoff(self):
|
||||
buf = OdomBuffer(max_age_s=30.0)
|
||||
for t in [0.0, 5.0, 10.0]:
|
||||
buf.push(_s(t, 0.0, 0.0))
|
||||
samples = buf.window(window_s=10.0, now=10.0)
|
||||
assert len(samples) == 3
|
||||
|
||||
def test_window_excludes_old_samples(self):
|
||||
buf = OdomBuffer(max_age_s=30.0)
|
||||
for t in [0.0, 5.0, 15.0]:
|
||||
buf.push(_s(t, 0.0, 0.0))
|
||||
# window=5s from now=15 → only t=15 qualifies (t>=10)
|
||||
samples = buf.window(window_s=5.0, now=15.0)
|
||||
assert len(samples) == 1
|
||||
assert samples[0].t == 15.0
|
||||
|
||||
def test_evicts_samples_beyond_max_age(self):
|
||||
buf = OdomBuffer(max_age_s=5.0)
|
||||
buf.push(_s(0.0, 0.0, 0.0))
|
||||
buf.push(_s(10.0, 1.0, 0.0)) # now=10 → t=0 is 10s old > 5s max
|
||||
assert len(buf) == 1
|
||||
|
||||
def test_clear_empties_buffer(self):
|
||||
buf = _straight_buf(n=5)
|
||||
buf.clear()
|
||||
assert len(buf) == 0
|
||||
|
||||
def test_window_on_empty_buffer(self):
|
||||
buf = OdomBuffer()
|
||||
assert buf.window(window_s=10.0, now=100.0) == []
|
||||
|
||||
def test_window_boundary_inclusive(self):
|
||||
"""Sample exactly at window cutoff (t == now - window_s) is included."""
|
||||
buf = OdomBuffer(max_age_s=30.0)
|
||||
buf.push(_s(0.0, 0.0, 0.0))
|
||||
# window=10, now=10 → cutoff=0.0, sample at t=0.0 should be included
|
||||
samples = buf.window(window_s=10.0, now=10.0)
|
||||
assert len(samples) == 1
|
||||
|
||||
|
||||
# ── _path_length ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestPathLength:
|
||||
|
||||
def test_stationary_path_zero(self):
|
||||
samples = [_s(i, 0.0, 0.0) for i in range(5)]
|
||||
assert _path_length(samples) == pytest.approx(0.0)
|
||||
|
||||
def test_unit_step_path(self):
|
||||
samples = [_s(0, 0.0, 0.0), _s(1, 1.0, 0.0)]
|
||||
assert _path_length(samples) == pytest.approx(1.0)
|
||||
|
||||
def test_two_unit_steps(self):
|
||||
samples = [_s(0, 0.0, 0.0), _s(1, 1.0, 0.0), _s(2, 2.0, 0.0)]
|
||||
assert _path_length(samples) == pytest.approx(2.0)
|
||||
|
||||
def test_diagonal_step(self):
|
||||
# (0,0) → (1,1): distance = sqrt(2)
|
||||
samples = [_s(0, 0.0, 0.0), _s(1, 1.0, 1.0)]
|
||||
assert _path_length(samples) == pytest.approx(math.sqrt(2))
|
||||
|
||||
def test_l_shaped_path(self):
|
||||
# Right 3m then up 4m → total path = 7m (not hypotenuse)
|
||||
samples = [_s(0, 0.0, 0.0), _s(1, 3.0, 0.0), _s(2, 3.0, 4.0)]
|
||||
assert _path_length(samples) == pytest.approx(7.0)
|
||||
|
||||
def test_single_sample_returns_zero(self):
|
||||
assert _path_length([_s(0, 5.0, 5.0)]) == pytest.approx(0.0)
|
||||
|
||||
def test_empty_returns_zero(self):
|
||||
assert _path_length([]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ── compute_drift ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestComputeDrift:
|
||||
|
||||
def test_both_empty_returns_zero_drift(self):
|
||||
result = compute_drift(
|
||||
OdomBuffer(), OdomBuffer(),
|
||||
window_s=10.0, drift_threshold_m=0.5, now=10.0)
|
||||
assert result.drift_m == pytest.approx(0.0)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_one_buffer_empty_returns_zero(self):
|
||||
vo = _straight_buf(n=5, speed=0.1)
|
||||
result = compute_drift(
|
||||
vo, OdomBuffer(),
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert result.drift_m == pytest.approx(0.0)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_one_buffer_single_sample_returns_zero(self):
|
||||
vo = _straight_buf(n=5, speed=0.1)
|
||||
wheel = OdomBuffer()
|
||||
wheel.push(_s(0.0, 0.0, 0.0)) # only 1 sample
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert result.drift_m == pytest.approx(0.0)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_both_move_same_distance_zero_drift(self):
|
||||
# Both move 0.1 m/s for 4 steps → 0.4 m each
|
||||
vo = _straight_buf(n=5, speed=0.1, dt=1.0)
|
||||
wheel = _straight_buf(n=5, speed=0.1, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert result.drift_m == pytest.approx(0.0, abs=1e-9)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_both_stationary_zero_drift(self):
|
||||
vo = _stationary_buf(n=5)
|
||||
wheel = _stationary_buf(n=5)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert result.drift_m == pytest.approx(0.0)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_drift_equals_path_length_difference(self):
|
||||
# VO moves 1.0 m total, wheel moves 0.5 m total
|
||||
vo = _straight_buf(n=11, speed=0.1, dt=1.0) # 10 steps × 0.1 = 1.0m
|
||||
wheel = _straight_buf(n=11, speed=0.05, dt=1.0) # 10 steps × 0.05 = 0.5m
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
assert result.vo_path_m == pytest.approx(1.0, abs=1e-9)
|
||||
assert result.wheel_path_m == pytest.approx(0.5, abs=1e-9)
|
||||
assert result.drift_m == pytest.approx(0.5, abs=1e-9)
|
||||
|
||||
def test_drift_at_threshold_is_drifting(self):
|
||||
# drift == 0.5 → is_drifting = True (>= threshold)
|
||||
vo = _straight_buf(n=11, speed=0.1, dt=1.0)
|
||||
wheel = _straight_buf(n=11, speed=0.05, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
assert result.is_drifting
|
||||
|
||||
def test_drift_below_threshold_not_drifting(self):
|
||||
vo = _straight_buf(n=11, speed=0.1, dt=1.0)
|
||||
wheel = _straight_buf(n=11, speed=0.08, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
# drift = |1.0 - 0.8| = 0.2
|
||||
assert result.drift_m == pytest.approx(0.2, abs=1e-9)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_drift_above_threshold_is_drifting(self):
|
||||
vo = _straight_buf(n=11, speed=0.1, dt=1.0)
|
||||
wheel = _stationary_buf(n=11, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
# drift = |1.0 - 0.0| = 1.0 > 0.5
|
||||
assert result.drift_m > 0.5
|
||||
assert result.is_drifting
|
||||
|
||||
def test_n_vo_n_wheel_counts(self):
|
||||
vo = _straight_buf(n=8, dt=1.0)
|
||||
wheel = _straight_buf(n=5, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=8.0)
|
||||
assert result.n_vo == 8
|
||||
assert result.n_wheel == 5
|
||||
|
||||
def test_samples_outside_window_ignored(self):
|
||||
# Push old samples far in the past; should not contribute to window
|
||||
vo = OdomBuffer(max_age_s=60.0)
|
||||
wheel = OdomBuffer(max_age_s=60.0)
|
||||
# Old samples outside window (t=0..4, window is last 3s from now=10)
|
||||
for t in range(5):
|
||||
vo.push(_s(float(t), x=float(t), y=0.0))
|
||||
wheel.push(_s(float(t), x=float(t), y=0.0))
|
||||
# Recent samples inside window (t=7..10)
|
||||
for t in range(7, 11):
|
||||
vo.push(_s(float(t), x=float(t) * 0.1, y=0.0))
|
||||
wheel.push(_s(float(t), x=float(t) * 0.1, y=0.0))
|
||||
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=3.0, drift_threshold_m=0.5, now=10.0)
|
||||
# Both sources move identically inside window → zero drift
|
||||
assert result.drift_m == pytest.approx(0.0, abs=1e-9)
|
||||
# Only the 4 recent samples (t=7,8,9,10) in window
|
||||
assert result.n_vo == 4
|
||||
assert result.n_wheel == 4
|
||||
|
||||
def test_result_is_namedtuple(self):
|
||||
result = compute_drift(
|
||||
_straight_buf(), _straight_buf(),
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert hasattr(result, 'drift_m')
|
||||
assert hasattr(result, 'vo_path_m')
|
||||
assert hasattr(result, 'wheel_path_m')
|
||||
assert hasattr(result, 'is_drifting')
|
||||
assert hasattr(result, 'window_s')
|
||||
assert hasattr(result, 'n_vo')
|
||||
assert hasattr(result, 'n_wheel')
|
||||
|
||||
def test_wheel_faster_than_vo_still_drifts(self):
|
||||
"""Drift is absolute difference — direction doesn't matter."""
|
||||
vo = _stationary_buf(n=11, dt=1.0)
|
||||
wheel = _straight_buf(n=11, speed=0.1, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
assert result.drift_m == pytest.approx(1.0, abs=1e-9)
|
||||
assert result.is_drifting
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
@ -0,0 +1,21 @@
|
||||
ambient_sound_node:
|
||||
ros__parameters:
|
||||
sample_rate: 16000 # Expected PCM sample rate (Hz)
|
||||
window_s: 1.0 # Accumulate this many seconds before classifying
|
||||
n_fft: 512 # FFT size (32 ms frame at 16 kHz)
|
||||
n_mels: 32 # Mel filterbank bands
|
||||
audio_topic: "/social/speech/audio_raw" # Source PCM-16 UInt8MultiArray topic
|
||||
|
||||
# ── Classifier thresholds ──────────────────────────────────────────────
|
||||
# Adjust to tune sensitivity for your deployment environment.
|
||||
silence_db: -40.0 # Below this energy (dBFS) → silence
|
||||
alarm_db_min: -25.0 # Min energy for alarm detection
|
||||
alarm_zcr_min: 0.12 # Min ZCR for alarm (intermittent high pitch)
|
||||
alarm_high_ratio_min: 0.35 # Min high-band energy fraction for alarm
|
||||
speech_zcr_min: 0.02 # Min ZCR for speech (voiced onset)
|
||||
speech_zcr_max: 0.25 # Max ZCR for speech
|
||||
speech_flatness_max: 0.35 # Max spectral flatness for speech (tonal)
|
||||
music_zcr_max: 0.08 # Max ZCR for music (harmonic / tonal)
|
||||
music_flatness_max: 0.25 # Max spectral flatness for music
|
||||
crowd_zcr_min: 0.10 # Min ZCR for crowd noise
|
||||
crowd_flatness_min: 0.35 # Min spectral flatness for crowd
|
||||
@ -0,0 +1,42 @@
|
||||
"""ambient_sound.launch.py -- Launch the ambient sound classifier (Issue #252).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social ambient_sound.launch.py
|
||||
ros2 launch saltybot_social ambient_sound.launch.py silence_db:=-45.0
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "ambient_sound_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("window_s", default_value="1.0",
|
||||
description="Accumulation window (s)"),
|
||||
DeclareLaunchArgument("n_mels", default_value="32",
|
||||
description="Mel filterbank bands"),
|
||||
DeclareLaunchArgument("silence_db", default_value="-40.0",
|
||||
description="Silence energy threshold (dBFS)"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="ambient_sound_node",
|
||||
name="ambient_sound_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"window_s": LaunchConfiguration("window_s"),
|
||||
"n_mels": LaunchConfiguration("n_mels"),
|
||||
"silence_db": LaunchConfiguration("silence_db"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,363 @@
|
||||
"""ambient_sound_node.py -- Ambient sound classifier via mel-spectrogram features.
|
||||
Issue #252
|
||||
|
||||
Accumulates 1 s of PCM-16 audio from /social/speech/audio_raw, extracts a
|
||||
compact mel-spectrogram feature vector, then classifies the scene into one of:
|
||||
|
||||
silence | speech | music | crowd | outdoor | alarm
|
||||
|
||||
Publishes the label as std_msgs/String on /saltybot/ambient_sound at 1 Hz.
|
||||
|
||||
Signal processing is pure Python + numpy (no torch / onnx dependency).
|
||||
|
||||
Feature vector (per 1-s window):
|
||||
energy_db -- overall RMS in dBFS
|
||||
zcr -- mean zero-crossing rate across frames
|
||||
mel_centroid -- centre-of-mass of the mel band energies [0..1]
|
||||
mel_flatness -- geometric/arithmetic mean of mel energies [0..1]
|
||||
(1 = white noise, 0 = single sinusoid)
|
||||
low_ratio -- fraction of mel energy in lower third of bands
|
||||
high_ratio -- fraction of mel energy in upper third of bands
|
||||
|
||||
Classification cascade (priority-ordered):
|
||||
silence : energy_db < silence_db
|
||||
alarm : energy_db >= alarm_db_min AND zcr >= alarm_zcr_min
|
||||
AND high_ratio >= alarm_high_ratio_min
|
||||
speech : zcr in [speech_zcr_min, speech_zcr_max]
|
||||
AND mel_flatness < speech_flatness_max
|
||||
music : zcr < music_zcr_max AND mel_flatness < music_flatness_max
|
||||
crowd : zcr >= crowd_zcr_min AND mel_flatness >= crowd_flatness_min
|
||||
outdoor : catch-all
|
||||
|
||||
Parameters:
|
||||
sample_rate (int, 16000)
|
||||
window_s (float, 1.0) -- accumulation window before classify
|
||||
n_fft (int, 512) -- FFT size
|
||||
n_mels (int, 32) -- mel filterbank bands
|
||||
audio_topic (str, "/social/speech/audio_raw")
|
||||
silence_db (float, -40.0)
|
||||
alarm_db_min (float, -25.0)
|
||||
alarm_zcr_min (float, 0.12)
|
||||
alarm_high_ratio_min (float, 0.35)
|
||||
speech_zcr_min (float, 0.02)
|
||||
speech_zcr_max (float, 0.25)
|
||||
speech_flatness_max (float, 0.35)
|
||||
music_zcr_max (float, 0.08)
|
||||
music_flatness_max (float, 0.25)
|
||||
crowd_zcr_min (float, 0.10)
|
||||
crowd_flatness_min (float, 0.35)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import struct
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import String, UInt8MultiArray
|
||||
|
||||
# numpy used only in DSP helpers — the Jetson always has it
|
||||
try:
|
||||
import numpy as np
|
||||
_NUMPY = True
|
||||
except ImportError:
|
||||
_NUMPY = False
|
||||
|
||||
INT16_MAX = 32768.0
|
||||
LABELS = ("silence", "speech", "music", "crowd", "outdoor", "alarm")
|
||||
|
||||
|
||||
# ── PCM helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def pcm16_bytes_to_float32(data: bytes) -> List[float]:
|
||||
"""PCM-16 LE bytes → float32 list in [-1.0, 1.0]."""
|
||||
n = len(data) // 2
|
||||
if n == 0:
|
||||
return []
|
||||
return [s / INT16_MAX for s in struct.unpack(f"<{n}h", data[: n * 2])]
|
||||
|
||||
|
||||
# ── Mel DSP (numpy path) ──────────────────────────────────────────────────────
|
||||
|
||||
def hz_to_mel(hz: float) -> float:
|
||||
return 2595.0 * math.log10(1.0 + hz / 700.0)
|
||||
|
||||
|
||||
def mel_to_hz(mel: float) -> float:
|
||||
return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
|
||||
|
||||
|
||||
def build_mel_filterbank(sr: int, n_fft: int, n_mels: int,
|
||||
fmin: float = 0.0, fmax: Optional[float] = None):
|
||||
"""Return (n_mels, n_fft//2+1) numpy filterbank matrix."""
|
||||
import numpy as np
|
||||
if fmax is None:
|
||||
fmax = sr / 2.0
|
||||
n_freqs = n_fft // 2 + 1
|
||||
mel_min = hz_to_mel(fmin)
|
||||
mel_max = hz_to_mel(fmax)
|
||||
mel_pts = np.linspace(mel_min, mel_max, n_mels + 2)
|
||||
hz_pts = np.array([mel_to_hz(m) for m in mel_pts])
|
||||
bin_pts = np.floor((n_fft + 1) * hz_pts / sr).astype(int)
|
||||
fb = np.zeros((n_mels, n_freqs))
|
||||
for m in range(n_mels):
|
||||
lo, ctr, hi = bin_pts[m], bin_pts[m + 1], bin_pts[m + 2]
|
||||
for k in range(lo, min(ctr, n_freqs)):
|
||||
if ctr != lo:
|
||||
fb[m, k] = (k - lo) / (ctr - lo)
|
||||
for k in range(ctr, min(hi, n_freqs)):
|
||||
if hi != ctr:
|
||||
fb[m, k] = (hi - k) / (hi - ctr)
|
||||
return fb
|
||||
|
||||
|
||||
def compute_mel_spectrogram(samples: List[float], sr: int,
|
||||
n_fft: int = 512, n_mels: int = 32,
|
||||
hop_length: int = 256):
|
||||
"""Return (n_mels, n_frames) log-mel spectrogram (numpy array)."""
|
||||
import numpy as np
|
||||
x = np.array(samples, dtype=np.float32)
|
||||
fb = build_mel_filterbank(sr, n_fft, n_mels)
|
||||
window = np.hanning(n_fft)
|
||||
frames = []
|
||||
for start in range(0, len(x) - n_fft + 1, hop_length):
|
||||
frame = x[start : start + n_fft] * window
|
||||
spec = np.abs(np.fft.rfft(frame)) ** 2
|
||||
mel = fb @ spec
|
||||
frames.append(mel)
|
||||
if not frames:
|
||||
return np.zeros((n_mels, 1), dtype=np.float32)
|
||||
return np.column_stack(frames).astype(np.float32)
|
||||
|
||||
|
||||
# ── Feature extraction ────────────────────────────────────────────────────────
|
||||
|
||||
def extract_features(samples: List[float], sr: int,
|
||||
n_fft: int = 512, n_mels: int = 32) -> Dict[str, float]:
|
||||
"""Extract scalar features from a raw audio window."""
|
||||
import numpy as np
|
||||
|
||||
n = len(samples)
|
||||
if n == 0:
|
||||
return {k: 0.0 for k in
|
||||
("energy_db", "zcr", "mel_centroid", "mel_flatness",
|
||||
"low_ratio", "high_ratio")}
|
||||
|
||||
# Energy
|
||||
rms = math.sqrt(sum(s * s for s in samples) / n) if n else 0.0
|
||||
energy_db = 20.0 * math.log10(max(rms, 1e-10))
|
||||
|
||||
# ZCR across 30 ms frames
|
||||
chunk = max(1, int(sr * 0.030))
|
||||
zcr_vals = []
|
||||
for i in range(0, n - chunk + 1, chunk):
|
||||
seg = samples[i : i + chunk]
|
||||
crossings = sum(1 for j in range(1, len(seg))
|
||||
if seg[j - 1] * seg[j] < 0)
|
||||
zcr_vals.append(crossings / max(len(seg) - 1, 1))
|
||||
zcr = sum(zcr_vals) / len(zcr_vals) if zcr_vals else 0.0
|
||||
|
||||
# Mel spectrogram features
|
||||
mel_spec = compute_mel_spectrogram(samples, sr, n_fft, n_mels)
|
||||
mel_mean = mel_spec.mean(axis=1) # (n_mels,) mean energy per band
|
||||
|
||||
total = float(mel_mean.sum()) if mel_mean.sum() > 0 else 1e-10
|
||||
indices = np.arange(n_mels, dtype=np.float32)
|
||||
mel_centroid = float((indices * mel_mean).sum()) / (n_mels * total / total) / n_mels
|
||||
|
||||
# Spectral flatness: geometric mean / arithmetic mean
|
||||
eps = 1e-10
|
||||
mel_pos = np.clip(mel_mean, eps, None)
|
||||
geo_mean = float(np.exp(np.log(mel_pos).mean()))
|
||||
arith_mean = float(mel_pos.mean())
|
||||
mel_flatness = min(geo_mean / max(arith_mean, eps), 1.0)
|
||||
|
||||
# Band ratios
|
||||
third = max(1, n_mels // 3)
|
||||
low_energy = float(mel_mean[:third].sum())
|
||||
high_energy = float(mel_mean[-third:].sum())
|
||||
low_ratio = low_energy / max(total, eps)
|
||||
high_ratio = high_energy / max(total, eps)
|
||||
|
||||
return {
|
||||
"energy_db": energy_db,
|
||||
"zcr": zcr,
|
||||
"mel_centroid": mel_centroid,
|
||||
"mel_flatness": mel_flatness,
|
||||
"low_ratio": low_ratio,
|
||||
"high_ratio": high_ratio,
|
||||
}
|
||||
|
||||
|
||||
# ── Classifier ────────────────────────────────────────────────────────────────
|
||||
|
||||
def classify(features: Dict[str, float],
|
||||
silence_db: float = -40.0,
|
||||
alarm_db_min: float = -25.0,
|
||||
alarm_zcr_min: float = 0.12,
|
||||
alarm_high_ratio_min: float = 0.35,
|
||||
speech_zcr_min: float = 0.02,
|
||||
speech_zcr_max: float = 0.25,
|
||||
speech_flatness_max: float = 0.35,
|
||||
music_zcr_max: float = 0.08,
|
||||
music_flatness_max: float = 0.25,
|
||||
crowd_zcr_min: float = 0.10,
|
||||
crowd_flatness_min: float = 0.35) -> str:
|
||||
"""Priority-ordered rule cascade. Returns a label from LABELS."""
|
||||
e = features["energy_db"]
|
||||
zcr = features["zcr"]
|
||||
fl = features["mel_flatness"]
|
||||
hi = features["high_ratio"]
|
||||
|
||||
if e < silence_db:
|
||||
return "silence"
|
||||
if (e >= alarm_db_min
|
||||
and zcr >= alarm_zcr_min
|
||||
and hi >= alarm_high_ratio_min):
|
||||
return "alarm"
|
||||
if zcr < music_zcr_max and fl < music_flatness_max:
|
||||
return "music"
|
||||
if (speech_zcr_min <= zcr <= speech_zcr_max
|
||||
and fl < speech_flatness_max):
|
||||
return "speech"
|
||||
if zcr >= crowd_zcr_min and fl >= crowd_flatness_min:
|
||||
return "crowd"
|
||||
return "outdoor"
|
||||
|
||||
|
||||
# ── Audio accumulation buffer ─────────────────────────────────────────────────
|
||||
|
||||
class AudioBuffer:
|
||||
"""Thread-safe ring buffer; yields a window of samples when full."""
|
||||
|
||||
def __init__(self, window_samples: int) -> None:
|
||||
self._target = window_samples
|
||||
self._buf: List[float] = []
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def push(self, samples: List[float]) -> Optional[List[float]]:
|
||||
"""Append samples. Returns a complete window (and resets) when full."""
|
||||
with self._lock:
|
||||
self._buf.extend(samples)
|
||||
if len(self._buf) >= self._target:
|
||||
window = self._buf[: self._target]
|
||||
self._buf = self._buf[self._target :]
|
||||
return window
|
||||
return None
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._buf.clear()
|
||||
|
||||
|
||||
# ── ROS2 node ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class AmbientSoundNode(Node):
|
||||
"""Classifies ambient sound from raw audio and publishes label at 1 Hz."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("ambient_sound_node")
|
||||
|
||||
self.declare_parameter("sample_rate", 16000)
|
||||
self.declare_parameter("window_s", 1.0)
|
||||
self.declare_parameter("n_fft", 512)
|
||||
self.declare_parameter("n_mels", 32)
|
||||
self.declare_parameter("audio_topic", "/social/speech/audio_raw")
|
||||
# Classifier thresholds
|
||||
self.declare_parameter("silence_db", -40.0)
|
||||
self.declare_parameter("alarm_db_min", -25.0)
|
||||
self.declare_parameter("alarm_zcr_min", 0.12)
|
||||
self.declare_parameter("alarm_high_ratio_min", 0.35)
|
||||
self.declare_parameter("speech_zcr_min", 0.02)
|
||||
self.declare_parameter("speech_zcr_max", 0.25)
|
||||
self.declare_parameter("speech_flatness_max", 0.35)
|
||||
self.declare_parameter("music_zcr_max", 0.08)
|
||||
self.declare_parameter("music_flatness_max", 0.25)
|
||||
self.declare_parameter("crowd_zcr_min", 0.10)
|
||||
self.declare_parameter("crowd_flatness_min", 0.35)
|
||||
|
||||
self._sr = self.get_parameter("sample_rate").value
|
||||
self._n_fft = self.get_parameter("n_fft").value
|
||||
self._n_mels = self.get_parameter("n_mels").value
|
||||
window_s = self.get_parameter("window_s").value
|
||||
audio_topic = self.get_parameter("audio_topic").value
|
||||
|
||||
self._thresholds = {
|
||||
k: self.get_parameter(k).value for k in (
|
||||
"silence_db", "alarm_db_min", "alarm_zcr_min",
|
||||
"alarm_high_ratio_min", "speech_zcr_min", "speech_zcr_max",
|
||||
"speech_flatness_max", "music_zcr_max", "music_flatness_max",
|
||||
"crowd_zcr_min", "crowd_flatness_min",
|
||||
)
|
||||
}
|
||||
|
||||
self._buffer = AudioBuffer(int(self._sr * window_s))
|
||||
self._last_label = "silence"
|
||||
|
||||
qos = QoSProfile(depth=10)
|
||||
self._pub = self.create_publisher(String, "/saltybot/ambient_sound", qos)
|
||||
self._audio_sub = self.create_subscription(
|
||||
UInt8MultiArray, audio_topic, self._on_audio, qos
|
||||
)
|
||||
|
||||
if not _NUMPY:
|
||||
self.get_logger().warn(
|
||||
"numpy not available — mel features disabled, classifying by energy only"
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"AmbientSoundNode ready "
|
||||
f"(sr={self._sr}, window={window_s}s, n_mels={self._n_mels})"
|
||||
)
|
||||
|
||||
def _on_audio(self, msg: UInt8MultiArray) -> None:
|
||||
samples = pcm16_bytes_to_float32(bytes(msg.data))
|
||||
if not samples:
|
||||
return
|
||||
window = self._buffer.push(samples)
|
||||
if window is not None:
|
||||
self._classify_and_publish(window)
|
||||
|
||||
def _classify_and_publish(self, samples: List[float]) -> None:
|
||||
try:
|
||||
if _NUMPY:
|
||||
feats = extract_features(samples, self._sr, self._n_fft, self._n_mels)
|
||||
else:
|
||||
# Numpy-free fallback: energy-only
|
||||
rms = math.sqrt(sum(s * s for s in samples) / len(samples))
|
||||
e_db = 20.0 * math.log10(max(rms, 1e-10))
|
||||
feats = {
|
||||
"energy_db": e_db, "zcr": 0.05,
|
||||
"mel_centroid": 0.5, "mel_flatness": 0.2,
|
||||
"low_ratio": 0.4, "high_ratio": 0.2,
|
||||
}
|
||||
label = classify(feats, **self._thresholds)
|
||||
except Exception as exc:
|
||||
self.get_logger().error(f"Classification error: {exc}")
|
||||
label = self._last_label
|
||||
|
||||
if label != self._last_label:
|
||||
self.get_logger().info(
|
||||
f"Ambient sound: {self._last_label} -> {label}"
|
||||
)
|
||||
self._last_label = label
|
||||
|
||||
msg = String()
|
||||
msg.data = label
|
||||
self._pub.publish(msg)
|
||||
|
||||
|
||||
def main(args: Optional[list] = None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = AmbientSoundNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -45,6 +45,8 @@ setup(
|
||||
'mesh_comms_node = saltybot_social.mesh_comms_node:main',
|
||||
# Energy+ZCR voice activity detection (Issue #242)
|
||||
'vad_node = saltybot_social.vad_node:main',
|
||||
# Ambient sound classifier — mel-spectrogram (Issue #252)
|
||||
'ambient_sound_node = saltybot_social.ambient_sound_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
407
jetson/ros2_ws/src/saltybot_social/test/test_ambient_sound.py
Normal file
407
jetson/ros2_ws/src/saltybot_social/test/test_ambient_sound.py
Normal file
@ -0,0 +1,407 @@
|
||||
"""test_ambient_sound.py -- Unit tests for Issue #252 ambient sound classifier."""
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib.util, math, os, struct, sys, types
|
||||
import pytest
|
||||
|
||||
# numpy is available on dev machine
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _pkg_root():
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def _read_src(rel_path):
|
||||
with open(os.path.join(_pkg_root(), rel_path)) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _import_mod():
|
||||
"""Import ambient_sound_node without a live ROS2 environment."""
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg"):
|
||||
if mod_name not in sys.modules:
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
rclpy_node = sys.modules["rclpy.node"]
|
||||
rclpy_qos = sys.modules["rclpy.qos"]
|
||||
std_msg = sys.modules["std_msgs.msg"]
|
||||
|
||||
DEFAULTS = {
|
||||
"sample_rate": 16000, "window_s": 1.0, "n_fft": 512, "n_mels": 32,
|
||||
"audio_topic": "/social/speech/audio_raw",
|
||||
"silence_db": -40.0, "alarm_db_min": -25.0, "alarm_zcr_min": 0.12,
|
||||
"alarm_high_ratio_min": 0.35, "speech_zcr_min": 0.02,
|
||||
"speech_zcr_max": 0.25, "speech_flatness_max": 0.35,
|
||||
"music_zcr_max": 0.08, "music_flatness_max": 0.25,
|
||||
"crowd_zcr_min": 0.10, "crowd_flatness_min": 0.35,
|
||||
}
|
||||
|
||||
class _Node:
|
||||
def __init__(self, *a, **kw): pass
|
||||
def declare_parameter(self, *a, **kw): pass
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
value = DEFAULTS.get(name)
|
||||
return _P()
|
||||
def create_publisher(self, *a, **kw): return None
|
||||
def create_subscription(self, *a, **kw): return None
|
||||
def get_logger(self):
|
||||
class _L:
|
||||
def info(self, *a): pass
|
||||
def warn(self, *a): pass
|
||||
def error(self, *a): pass
|
||||
return _L()
|
||||
def destroy_node(self): pass
|
||||
|
||||
rclpy_node.Node = _Node
|
||||
rclpy_qos.QoSProfile = type("QoSProfile", (), {"__init__": lambda s, **kw: None})
|
||||
std_msg.String = type("String", (), {"data": ""})
|
||||
std_msg.UInt8MultiArray = type("UInt8MultiArray", (), {"data": b""})
|
||||
sys.modules["rclpy"].init = lambda *a, **kw: None
|
||||
sys.modules["rclpy"].spin = lambda n: None
|
||||
sys.modules["rclpy"].ok = lambda: True
|
||||
sys.modules["rclpy"].shutdown = lambda: None
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"ambient_sound_node_testmod",
|
||||
os.path.join(_pkg_root(), "saltybot_social", "ambient_sound_node.py"),
|
||||
)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
# ── Audio helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
SR = 16000
|
||||
|
||||
def _sine(freq, n=SR, amp=0.2):
|
||||
return [amp * math.sin(2 * math.pi * freq * i / SR) for i in range(n)]
|
||||
|
||||
def _white_noise(n=SR, amp=0.1):
|
||||
import random
|
||||
rng = random.Random(42)
|
||||
return [rng.uniform(-amp, amp) for _ in range(n)]
|
||||
|
||||
def _silence(n=SR):
|
||||
return [0.0] * n
|
||||
|
||||
def _pcm16(samples):
|
||||
ints = [max(-32768, min(32767, int(s * 32768))) for s in samples]
|
||||
return struct.pack(f"<{len(ints)}h", *ints)
|
||||
|
||||
|
||||
# ── TestPcm16Convert ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestPcm16Convert:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_empty(self, mod):
|
||||
assert mod.pcm16_bytes_to_float32(b"") == []
|
||||
|
||||
def test_length(self, mod):
|
||||
data = _pcm16(_sine(440, 480))
|
||||
assert len(mod.pcm16_bytes_to_float32(data)) == 480
|
||||
|
||||
def test_range(self, mod):
|
||||
data = _pcm16(_sine(440, 480))
|
||||
result = mod.pcm16_bytes_to_float32(data)
|
||||
assert all(-1.0 <= s <= 1.0 for s in result)
|
||||
|
||||
def test_silence(self, mod):
|
||||
data = _pcm16(_silence(100))
|
||||
assert all(s == 0.0 for s in mod.pcm16_bytes_to_float32(data))
|
||||
|
||||
|
||||
# ── TestMelConversions ────────────────────────────────────────────────────────
|
||||
|
||||
class TestMelConversions:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_hz_to_mel_zero(self, mod):
|
||||
assert mod.hz_to_mel(0.0) == 0.0
|
||||
|
||||
def test_hz_to_mel_1000(self, mod):
|
||||
# 1000 Hz → ~999.99 mel (approximately)
|
||||
assert abs(mod.hz_to_mel(1000.0) - 999.99) < 1.0
|
||||
|
||||
def test_roundtrip(self, mod):
|
||||
for hz in (100.0, 500.0, 1000.0, 4000.0, 8000.0):
|
||||
assert abs(mod.mel_to_hz(mod.hz_to_mel(hz)) - hz) < 0.01
|
||||
|
||||
def test_monotone_increasing(self, mod):
|
||||
freqs = [100, 500, 1000, 2000, 4000, 8000]
|
||||
mels = [mod.hz_to_mel(f) for f in freqs]
|
||||
assert mels == sorted(mels)
|
||||
|
||||
|
||||
# ── TestMelFilterbank ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestMelFilterbank:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_shape(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert fb.shape == (32, 257) # (n_mels, n_fft//2+1)
|
||||
|
||||
def test_nonnegative(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert (fb >= 0).all()
|
||||
|
||||
def test_each_filter_sums_positive(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert all(fb[m].sum() > 0 for m in range(32))
|
||||
|
||||
def test_custom_n_mels(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 16)
|
||||
assert fb.shape[0] == 16
|
||||
|
||||
def test_max_value_leq_one(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert fb.max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
# ── TestMelSpectrogram ────────────────────────────────────────────────────────
|
||||
|
||||
class TestMelSpectrogram:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_shape(self, mod):
|
||||
s = _sine(440, SR)
|
||||
spec = mod.compute_mel_spectrogram(s, SR, n_fft=512, n_mels=32, hop_length=256)
|
||||
assert spec.shape[0] == 32
|
||||
assert spec.shape[1] > 0
|
||||
|
||||
def test_silence_near_zero(self, mod):
|
||||
spec = mod.compute_mel_spectrogram(_silence(SR), SR, n_fft=512, n_mels=32)
|
||||
assert spec.mean() < 1e-6
|
||||
|
||||
def test_louder_has_higher_energy(self, mod):
|
||||
quiet = mod.compute_mel_spectrogram(_sine(440, SR, amp=0.01), SR).mean()
|
||||
loud = mod.compute_mel_spectrogram(_sine(440, SR, amp=0.5), SR).mean()
|
||||
assert loud > quiet
|
||||
|
||||
def test_returns_array(self, mod):
|
||||
spec = mod.compute_mel_spectrogram(_sine(440, SR), SR)
|
||||
assert isinstance(spec, np.ndarray)
|
||||
|
||||
|
||||
# ── TestExtractFeatures ───────────────────────────────────────────────────────
|
||||
|
||||
class TestExtractFeatures:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def _feats(self, mod, samples):
|
||||
return mod.extract_features(samples, SR, n_fft=512, n_mels=32)
|
||||
|
||||
def test_keys_present(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
for k in ("energy_db", "zcr", "mel_centroid", "mel_flatness",
|
||||
"low_ratio", "high_ratio"):
|
||||
assert k in f
|
||||
|
||||
def test_silence_low_energy(self, mod):
|
||||
f = self._feats(mod, _silence(SR))
|
||||
assert f["energy_db"] < -40.0
|
||||
|
||||
def test_silence_zero_zcr(self, mod):
|
||||
f = self._feats(mod, _silence(SR))
|
||||
assert f["zcr"] == 0.0
|
||||
|
||||
def test_sine_moderate_energy(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR, amp=0.1))
|
||||
assert -40.0 < f["energy_db"] < 0.0
|
||||
|
||||
def test_ratios_sum_leq_one(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
assert f["low_ratio"] + f["high_ratio"] <= 1.0 + 1e-6
|
||||
|
||||
def test_ratios_nonnegative(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
assert f["low_ratio"] >= 0.0 and f["high_ratio"] >= 0.0
|
||||
|
||||
def test_flatness_in_unit_interval(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
assert 0.0 <= f["mel_flatness"] <= 1.0
|
||||
|
||||
def test_white_noise_high_flatness(self, mod):
|
||||
f_noise = self._feats(mod, _white_noise(SR, amp=0.3))
|
||||
f_sine = self._feats(mod, _sine(440, SR, amp=0.3))
|
||||
# White noise should have higher spectral flatness than a pure tone
|
||||
assert f_noise["mel_flatness"] > f_sine["mel_flatness"]
|
||||
|
||||
def test_empty_samples(self, mod):
|
||||
f = mod.extract_features([], SR)
|
||||
assert f["energy_db"] == 0.0
|
||||
|
||||
def test_louder_higher_energy_db(self, mod):
|
||||
quiet = self._feats(mod, _sine(440, SR, amp=0.01))["energy_db"]
|
||||
loud = self._feats(mod, _sine(440, SR, amp=0.5))["energy_db"]
|
||||
assert loud > quiet
|
||||
|
||||
|
||||
# ── TestClassifier ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestClassifier:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def _cls(self, mod, **feat_overrides):
|
||||
base = {"energy_db": -20.0, "zcr": 0.05,
|
||||
"mel_centroid": 0.4, "mel_flatness": 0.2,
|
||||
"low_ratio": 0.4, "high_ratio": 0.2}
|
||||
base.update(feat_overrides)
|
||||
return mod.classify(base)
|
||||
|
||||
def test_silence(self, mod):
|
||||
assert self._cls(mod, energy_db=-45.0) == "silence"
|
||||
|
||||
def test_silence_at_threshold(self, mod):
|
||||
assert self._cls(mod, energy_db=-40.0) != "silence"
|
||||
|
||||
def test_alarm(self, mod):
|
||||
assert self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.40) == "alarm"
|
||||
|
||||
def test_alarm_requires_high_ratio(self, mod):
|
||||
result = self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.10)
|
||||
assert result != "alarm"
|
||||
|
||||
def test_speech(self, mod):
|
||||
assert self._cls(mod, energy_db=-25.0, zcr=0.08,
|
||||
mel_flatness=0.20) == "speech"
|
||||
|
||||
def test_speech_zcr_too_low(self, mod):
|
||||
result = self._cls(mod, energy_db=-25.0, zcr=0.005, mel_flatness=0.2)
|
||||
assert result != "speech"
|
||||
|
||||
def test_speech_zcr_too_high(self, mod):
|
||||
result = self._cls(mod, energy_db=-25.0, zcr=0.30, mel_flatness=0.2)
|
||||
assert result != "speech"
|
||||
|
||||
def test_music(self, mod):
|
||||
assert self._cls(mod, energy_db=-25.0, zcr=0.04,
|
||||
mel_flatness=0.10) == "music"
|
||||
|
||||
def test_crowd(self, mod):
|
||||
assert self._cls(mod, energy_db=-25.0, zcr=0.15,
|
||||
mel_flatness=0.40) == "crowd"
|
||||
|
||||
def test_outdoor_catchall(self, mod):
|
||||
# Moderate energy, mid ZCR, mid flatness → outdoor
|
||||
result = self._cls(mod, energy_db=-35.0, zcr=0.06, mel_flatness=0.30)
|
||||
assert result in mod.LABELS
|
||||
|
||||
def test_returns_valid_label(self, mod):
|
||||
import random
|
||||
rng = random.Random(0)
|
||||
for _ in range(20):
|
||||
f = {
|
||||
"energy_db": rng.uniform(-60, 0),
|
||||
"zcr": rng.uniform(0, 0.5),
|
||||
"mel_centroid": rng.uniform(0, 1),
|
||||
"mel_flatness": rng.uniform(0, 1),
|
||||
"low_ratio": rng.uniform(0, 0.6),
|
||||
"high_ratio": rng.uniform(0, 0.4),
|
||||
}
|
||||
assert mod.classify(f) in mod.LABELS
|
||||
|
||||
|
||||
# ── TestAudioBuffer ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestAudioBuffer:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_no_window_until_full(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
assert buf.push([0.0] * 50) is None
|
||||
|
||||
def test_exact_fill_returns_window(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
w = buf.push([0.0] * 100)
|
||||
assert w is not None and len(w) == 100
|
||||
|
||||
def test_overflow_carries_over(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
buf.push([0.0] * 100) # fills first window
|
||||
w2 = buf.push([1.0] * 100) # fills second window
|
||||
assert w2 is not None and len(w2) == 100
|
||||
|
||||
def test_partial_then_complete(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
buf.push([0.0] * 60)
|
||||
w = buf.push([0.0] * 60)
|
||||
assert w is not None and len(w) == 100
|
||||
|
||||
def test_clear_resets(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
buf.push([0.0] * 90)
|
||||
buf.clear()
|
||||
assert buf.push([0.0] * 90) is None
|
||||
|
||||
def test_window_contents_correct(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=4)
|
||||
w = buf.push([1.0, 2.0, 3.0, 4.0])
|
||||
assert w == [1.0, 2.0, 3.0, 4.0]
|
||||
|
||||
|
||||
# ── TestNodeSrc ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc:
|
||||
@pytest.fixture(scope="class")
|
||||
def src(self): return _read_src("saltybot_social/ambient_sound_node.py")
|
||||
|
||||
def test_class_defined(self, src): assert "class AmbientSoundNode" in src
|
||||
def test_audio_buffer(self, src): assert "class AudioBuffer" in src
|
||||
def test_extract_features(self, src): assert "def extract_features" in src
|
||||
def test_classify_fn(self, src): assert "def classify" in src
|
||||
def test_mel_spectrogram(self, src): assert "compute_mel_spectrogram" in src
|
||||
def test_mel_filterbank(self, src): assert "build_mel_filterbank" in src
|
||||
def test_hz_to_mel(self, src): assert "hz_to_mel" in src
|
||||
def test_labels_tuple(self, src): assert "LABELS" in src
|
||||
def test_all_labels(self, src):
|
||||
for label in ("silence", "speech", "music", "crowd", "outdoor", "alarm"):
|
||||
assert label in src
|
||||
def test_topic_pub(self, src): assert '"/saltybot/ambient_sound"' in src
|
||||
def test_topic_sub(self, src): assert '"/social/speech/audio_raw"' in src
|
||||
def test_window_param(self, src): assert '"window_s"' in src
|
||||
def test_n_mels_param(self, src): assert '"n_mels"' in src
|
||||
def test_silence_param(self, src): assert '"silence_db"' in src
|
||||
def test_alarm_param(self, src): assert '"alarm_db_min"' in src
|
||||
def test_speech_param(self, src): assert '"speech_zcr_min"' in src
|
||||
def test_music_param(self, src): assert '"music_zcr_max"' in src
|
||||
def test_crowd_param(self, src): assert '"crowd_zcr_min"' in src
|
||||
def test_string_pub(self, src): assert "String" in src
|
||||
def test_uint8_sub(self, src): assert "UInt8MultiArray" in src
|
||||
def test_issue_tag(self, src): assert "252" in src
|
||||
def test_main(self, src): assert "def main" in src
|
||||
def test_numpy_optional(self, src): assert "_NUMPY" in src
|
||||
|
||||
|
||||
# ── TestConfig ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestConfig:
|
||||
@pytest.fixture(scope="class")
|
||||
def cfg(self): return _read_src("config/ambient_sound_params.yaml")
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def setup(self): return _read_src("setup.py")
|
||||
|
||||
def test_node_name(self, cfg): assert "ambient_sound_node:" in cfg
|
||||
def test_window_s(self, cfg): assert "window_s" in cfg
|
||||
def test_n_mels(self, cfg): assert "n_mels" in cfg
|
||||
def test_silence_db(self, cfg): assert "silence_db" in cfg
|
||||
def test_alarm_params(self, cfg): assert "alarm_db_min" in cfg
|
||||
def test_speech_params(self, cfg): assert "speech_zcr_min" in cfg
|
||||
def test_music_params(self, cfg): assert "music_zcr_max" in cfg
|
||||
def test_crowd_params(self, cfg): assert "crowd_zcr_min" in cfg
|
||||
def test_defaults_present(self, cfg): assert "-40.0" in cfg and "0.12" in cfg
|
||||
def test_entry_point(self, setup):
|
||||
assert "ambient_sound_node = saltybot_social.ambient_sound_node:main" in setup
|
||||
@ -0,0 +1,5 @@
|
||||
wheel_slip_detector:
|
||||
ros__parameters:
|
||||
frequency: 10
|
||||
slip_threshold: 0.1
|
||||
slip_timeout: 0.5
|
||||
@ -0,0 +1,14 @@
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_dir = get_package_share_directory("saltybot_wheel_slip_detector")
|
||||
config_file = os.path.join(pkg_dir, "config", "wheel_slip_config.yaml")
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("config_file", default_value=config_file, description="Path to configuration YAML file"),
|
||||
Node(package="saltybot_wheel_slip_detector", executable="wheel_slip_detector_node", name="wheel_slip_detector", output="screen", parameters=[LaunchConfiguration("config_file")]),
|
||||
])
|
||||
18
jetson/ros2_ws/src/saltybot_wheel_slip_detector/package.xml
Normal file
18
jetson/ros2_ws/src/saltybot_wheel_slip_detector/package.xml
Normal file
@ -0,0 +1,18 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_wheel_slip_detector</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Wheel slip detection by comparing commanded vs actual velocity.</description>
|
||||
<maintainer email="seb@vayrette.com">Seb</maintainer>
|
||||
<license>Apache-2.0</license>
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
<depend>rclpy</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>nav_msgs</depend>
|
||||
<test_depend>pytest</test_depend>
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python3
|
||||
from typing import Optional
|
||||
import math
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.timer import Timer
|
||||
from geometry_msgs.msg import Twist
|
||||
from nav_msgs.msg import Odometry
|
||||
from std_msgs.msg import Bool
|
||||
|
||||
class WheelSlipDetectorNode(Node):
|
||||
def __init__(self):
|
||||
super().__init__("wheel_slip_detector")
|
||||
self.declare_parameter("frequency", 10)
|
||||
frequency = self.get_parameter("frequency").value
|
||||
self.declare_parameter("slip_threshold", 0.1)
|
||||
self.declare_parameter("slip_timeout", 0.5)
|
||||
self.slip_threshold = self.get_parameter("slip_threshold").value
|
||||
self.slip_timeout = self.get_parameter("slip_timeout").value
|
||||
self.period = 1.0 / frequency
|
||||
self.cmd_vel: Optional[Twist] = None
|
||||
self.actual_vel: Optional[Twist] = None
|
||||
self.slip_duration = 0.0
|
||||
self.slip_detected = False
|
||||
self.create_subscription(Twist, "/cmd_vel", self._on_cmd_vel, 10)
|
||||
self.create_subscription(Odometry, "/odom", self._on_odom, 10)
|
||||
self.pub_slip = self.create_publisher(Bool, "/saltybot/wheel_slip_detected", 10)
|
||||
self.timer: Timer = self.create_timer(self.period, self._timer_callback)
|
||||
self.get_logger().info(f"Wheel slip detector initialized at {frequency}Hz. Threshold: {self.slip_threshold} m/s, Timeout: {self.slip_timeout}s")
|
||||
|
||||
def _on_cmd_vel(self, msg: Twist) -> None:
|
||||
self.cmd_vel = msg
|
||||
|
||||
def _on_odom(self, msg: Odometry) -> None:
|
||||
self.actual_vel = msg.twist.twist
|
||||
|
||||
def _timer_callback(self) -> None:
|
||||
if self.cmd_vel is None or self.actual_vel is None:
|
||||
slip_detected = False
|
||||
else:
|
||||
slip_detected = self._check_slip()
|
||||
if slip_detected:
|
||||
self.slip_duration += self.period
|
||||
else:
|
||||
self.slip_duration = 0.0
|
||||
is_slip = self.slip_duration > self.slip_timeout
|
||||
if is_slip != self.slip_detected:
|
||||
self.slip_detected = is_slip
|
||||
if self.slip_detected:
|
||||
self.get_logger().warn(f"WHEEL SLIP DETECTED: {self.slip_duration:.2f}s")
|
||||
else:
|
||||
self.get_logger().info("Wheel slip cleared")
|
||||
slip_msg = Bool()
|
||||
slip_msg.data = is_slip
|
||||
self.pub_slip.publish(slip_msg)
|
||||
|
||||
def _check_slip(self) -> bool:
|
||||
cmd_speed = math.sqrt(self.cmd_vel.linear.x**2 + self.cmd_vel.linear.y**2)
|
||||
actual_speed = math.sqrt(self.actual_vel.linear.x**2 + self.actual_vel.linear.y**2)
|
||||
vel_diff = abs(cmd_speed - actual_speed)
|
||||
if cmd_speed < 0.05 and actual_speed < 0.05:
|
||||
return False
|
||||
return vel_diff > self.slip_threshold
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = WheelSlipDetectorNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script-dir=$base/lib/saltybot_wheel_slip_detector
|
||||
[install]
|
||||
install-scripts=$base/lib/saltybot_wheel_slip_detector
|
||||
21
jetson/ros2_ws/src/saltybot_wheel_slip_detector/setup.py
Normal file
21
jetson/ros2_ws/src/saltybot_wheel_slip_detector/setup.py
Normal file
@ -0,0 +1,21 @@
|
||||
from setuptools import find_packages, setup
|
||||
package_name = "saltybot_wheel_slip_detector"
|
||||
setup(
|
||||
name=package_name,
|
||||
version="0.1.0",
|
||||
packages=find_packages(exclude=["test"]),
|
||||
data_files=[
|
||||
("share/ament_index/resource_index/packages", ["resource/" + package_name]),
|
||||
("share/" + package_name, ["package.xml"]),
|
||||
("share/" + package_name + "/launch", ["launch/wheel_slip_detector.launch.py"]),
|
||||
("share/" + package_name + "/config", ["config/wheel_slip_config.yaml"]),
|
||||
],
|
||||
install_requires=["setuptools"],
|
||||
zip_safe=True,
|
||||
maintainer="Seb",
|
||||
maintainer_email="seb@vayrette.com",
|
||||
description="Wheel slip detection from velocity command/actual mismatch",
|
||||
license="Apache-2.0",
|
||||
tests_require=["pytest"],
|
||||
entry_points={"console_scripts": ["wheel_slip_detector_node = saltybot_wheel_slip_detector.wheel_slip_detector_node:main"]},
|
||||
)
|
||||
@ -58,6 +58,9 @@ import JoystickTeleop from './components/JoystickTeleop.jsx';
|
||||
// Network diagnostics (issue #222)
|
||||
import { NetworkPanel } from './components/NetworkPanel.jsx';
|
||||
|
||||
// Waypoint editor (issue #261)
|
||||
import { WaypointEditor } from './components/WaypointEditor.jsx';
|
||||
|
||||
const TAB_GROUPS = [
|
||||
{
|
||||
label: 'SOCIAL',
|
||||
@ -85,6 +88,13 @@ const TAB_GROUPS = [
|
||||
{ id: 'cameras', label: 'Cameras', },
|
||||
],
|
||||
},
|
||||
{
|
||||
label: 'NAVIGATION',
|
||||
color: 'text-teal-600',
|
||||
tabs: [
|
||||
{ id: 'waypoints', label: 'Waypoints' },
|
||||
],
|
||||
},
|
||||
{
|
||||
label: 'FLEET',
|
||||
color: 'text-green-600',
|
||||
@ -244,8 +254,10 @@ export default function App() {
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{activeTab === 'health' && <SystemHealth subscribe={subscribe} />}
|
||||
{activeTab === 'cameras' && <CameraViewer subscribe={subscribe} />}
|
||||
{activeTab === 'health' && <SystemHealth subscribe={subscribe} />}
|
||||
{activeTab === 'cameras' && <CameraViewer subscribe={subscribe} />}
|
||||
|
||||
{activeTab === 'waypoints' && <WaypointEditor subscribe={subscribe} publish={publishFn} callService={callService} />}
|
||||
|
||||
{activeTab === 'fleet' && <FleetPanel />}
|
||||
{activeTab === 'missions' && <MissionPlanner />}
|
||||
|
||||
449
ui/social-bot/src/components/WaypointEditor.jsx
Normal file
449
ui/social-bot/src/components/WaypointEditor.jsx
Normal file
@ -0,0 +1,449 @@
|
||||
/**
|
||||
* WaypointEditor.jsx — Interactive waypoint navigation editor with click-to-place and drag-to-reorder
|
||||
*
|
||||
* Features:
|
||||
* - Click on map canvas to place waypoints
|
||||
* - Drag waypoints to reorder navigation sequence
|
||||
* - Right-click to delete waypoints
|
||||
* - Real-time waypoint list with labels and coordinates
|
||||
* - Send Nav2 goal to /navigate_to_pose action
|
||||
* - Execute waypoint sequence with automatic progression
|
||||
* - Clear all waypoints button
|
||||
* - Visual feedback for active waypoint (executing)
|
||||
* - Imports map display from MapViewer for coordinate system
|
||||
*/
|
||||
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
|
||||
function WaypointEditor({ subscribe, publish, callService }) {
|
||||
// Waypoint storage
|
||||
const [waypoints, setWaypoints] = useState([]);
|
||||
const [selectedWaypoint, setSelectedWaypoint] = useState(null);
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const [dragIndex, setDragIndex] = useState(null);
|
||||
const [activeWaypoint, setActiveWaypoint] = useState(null);
|
||||
const [executing, setExecuting] = useState(false);
|
||||
|
||||
// Map context
|
||||
const [mapData, setMapData] = useState(null);
|
||||
const [robotPose, setRobotPose] = useState({ x: 0, y: 0, theta: 0 });
|
||||
|
||||
// Canvas reference
|
||||
const canvasRef = useRef(null);
|
||||
const containerRef = useRef(null);
|
||||
|
||||
// Refs for ROS integration
|
||||
const mapDataRef = useRef(null);
|
||||
const robotPoseRef = useRef({ x: 0, y: 0, theta: 0 });
|
||||
const waypointsRef = useRef([]);
|
||||
|
||||
// Subscribe to map data (for coordinate reference)
|
||||
useEffect(() => {
|
||||
const unsubMap = subscribe(
|
||||
'/map',
|
||||
'nav_msgs/OccupancyGrid',
|
||||
(msg) => {
|
||||
try {
|
||||
const mapInfo = {
|
||||
width: msg.info.width,
|
||||
height: msg.info.height,
|
||||
resolution: msg.info.resolution,
|
||||
origin: msg.info.origin,
|
||||
};
|
||||
setMapData(mapInfo);
|
||||
mapDataRef.current = mapInfo;
|
||||
} catch (e) {
|
||||
console.error('Error parsing map data:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubMap;
|
||||
}, [subscribe]);
|
||||
|
||||
// Subscribe to robot odometry (for current position reference)
|
||||
useEffect(() => {
|
||||
const unsubOdom = subscribe(
|
||||
'/odom',
|
||||
'nav_msgs/Odometry',
|
||||
(msg) => {
|
||||
try {
|
||||
const pos = msg.pose.pose.position;
|
||||
const ori = msg.pose.pose.orientation;
|
||||
|
||||
const siny_cosp = 2 * (ori.w * ori.z + ori.x * ori.y);
|
||||
const cosy_cosp = 1 - 2 * (ori.y * ori.y + ori.z * ori.z);
|
||||
const theta = Math.atan2(siny_cosp, cosy_cosp);
|
||||
|
||||
const newPose = { x: pos.x, y: pos.y, theta };
|
||||
setRobotPose(newPose);
|
||||
robotPoseRef.current = newPose;
|
||||
} catch (e) {
|
||||
console.error('Error parsing odometry data:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubOdom;
|
||||
}, [subscribe]);
|
||||
|
||||
// Canvas event handlers
|
||||
const handleCanvasClick = (e) => {
|
||||
if (!mapDataRef.current || !canvasRef.current) return;
|
||||
|
||||
const canvas = canvasRef.current;
|
||||
const rect = canvas.getBoundingClientRect();
|
||||
const clickX = e.clientX - rect.left;
|
||||
const clickY = e.clientY - rect.top;
|
||||
|
||||
// Convert canvas coordinates to world coordinates
|
||||
// This assumes the map is centered on the robot
|
||||
const map = mapDataRef.current;
|
||||
const robot = robotPoseRef.current;
|
||||
const zoom = 1; // Would need to track zoom if map has zoom controls
|
||||
|
||||
// Inverse of map rendering calculation
|
||||
const centerX = canvas.width / 2;
|
||||
const centerY = canvas.height / 2;
|
||||
|
||||
const worldX = robot.x + (clickX - centerX) / zoom;
|
||||
const worldY = robot.y - (clickY - centerY) / zoom;
|
||||
|
||||
// Create new waypoint
|
||||
const newWaypoint = {
|
||||
id: Date.now(),
|
||||
x: parseFloat(worldX.toFixed(2)),
|
||||
y: parseFloat(worldY.toFixed(2)),
|
||||
label: `WP-${waypoints.length + 1}`,
|
||||
};
|
||||
|
||||
setWaypoints((prev) => [...prev, newWaypoint]);
|
||||
waypointsRef.current = [...waypointsRef.current, newWaypoint];
|
||||
};
|
||||
|
||||
const handleCanvasContextMenu = (e) => {
|
||||
e.preventDefault();
|
||||
// Right-click handled by waypoint list
|
||||
};
|
||||
|
||||
// Waypoint list handlers
|
||||
const handleDeleteWaypoint = (id) => {
|
||||
setWaypoints((prev) => prev.filter((wp) => wp.id !== id));
|
||||
waypointsRef.current = waypointsRef.current.filter((wp) => wp.id !== id);
|
||||
if (selectedWaypoint === id) setSelectedWaypoint(null);
|
||||
};
|
||||
|
||||
const handleWaypointSelect = (id) => {
|
||||
setSelectedWaypoint(selectedWaypoint === id ? null : id);
|
||||
};
|
||||
|
||||
const handleWaypointDragStart = (e, index) => {
|
||||
setIsDragging(true);
|
||||
setDragIndex(index);
|
||||
};
|
||||
|
||||
const handleWaypointDragOver = (e, targetIndex) => {
|
||||
if (!isDragging || dragIndex === null || dragIndex === targetIndex) return;
|
||||
|
||||
const newWaypoints = [...waypoints];
|
||||
const draggedWaypoint = newWaypoints[dragIndex];
|
||||
newWaypoints.splice(dragIndex, 1);
|
||||
newWaypoints.splice(targetIndex, 0, draggedWaypoint);
|
||||
|
||||
setWaypoints(newWaypoints);
|
||||
waypointsRef.current = newWaypoints;
|
||||
setDragIndex(targetIndex);
|
||||
};
|
||||
|
||||
const handleWaypointDragEnd = () => {
|
||||
setIsDragging(false);
|
||||
setDragIndex(null);
|
||||
};
|
||||
|
||||
// Execute waypoints
|
||||
const sendNavGoal = async (waypoint) => {
|
||||
if (!callService) return;
|
||||
|
||||
try {
|
||||
// Create quaternion from heading (default to 0 if no heading)
|
||||
const heading = waypoint.theta || 0;
|
||||
const halfHeading = heading / 2;
|
||||
const qx = 0;
|
||||
const qy = 0;
|
||||
const qz = Math.sin(halfHeading);
|
||||
const qw = Math.cos(halfHeading);
|
||||
|
||||
const goal = {
|
||||
pose: {
|
||||
position: {
|
||||
x: waypoint.x,
|
||||
y: waypoint.y,
|
||||
z: 0,
|
||||
},
|
||||
orientation: {
|
||||
x: qx,
|
||||
y: qy,
|
||||
z: qz,
|
||||
w: qw,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Send to Nav2 navigate_to_pose action
|
||||
await callService(
|
||||
'/navigate_to_pose',
|
||||
'nav2_msgs/NavigateToPose',
|
||||
{ pose: goal.pose }
|
||||
);
|
||||
|
||||
setActiveWaypoint(waypoint.id);
|
||||
return true;
|
||||
} catch (e) {
|
||||
console.error('Error sending nav goal:', e);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const executeWaypoints = async () => {
|
||||
if (waypoints.length === 0) return;
|
||||
|
||||
setExecuting(true);
|
||||
for (const waypoint of waypoints) {
|
||||
const success = await sendNavGoal(waypoint);
|
||||
if (!success) {
|
||||
console.error('Failed to send goal for waypoint:', waypoint);
|
||||
break;
|
||||
}
|
||||
// Wait a bit before sending next goal
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
}
|
||||
setExecuting(false);
|
||||
setActiveWaypoint(null);
|
||||
};
|
||||
|
||||
const clearWaypoints = () => {
|
||||
setWaypoints([]);
|
||||
waypointsRef.current = [];
|
||||
setSelectedWaypoint(null);
|
||||
setActiveWaypoint(null);
|
||||
};
|
||||
|
||||
const sendSingleGoal = async () => {
|
||||
if (selectedWaypoint === null) return;
|
||||
|
||||
const wp = waypoints.find((w) => w.id === selectedWaypoint);
|
||||
if (wp) {
|
||||
await sendNavGoal(wp);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex h-full gap-3">
|
||||
{/* Map area with click handlers */}
|
||||
<div className="flex-1 flex flex-col space-y-3">
|
||||
<div className="flex-1 bg-gray-900 rounded-lg border border-cyan-950 overflow-hidden relative cursor-crosshair">
|
||||
<div
|
||||
ref={containerRef}
|
||||
className="w-full h-full"
|
||||
onClick={handleCanvasClick}
|
||||
onContextMenu={handleCanvasContextMenu}
|
||||
>
|
||||
{/* Virtual map display - waypoints overlaid */}
|
||||
<svg
|
||||
className="absolute inset-0 w-full h-full pointer-events-none"
|
||||
id="waypoint-overlay"
|
||||
>
|
||||
{/* Waypoint markers */}
|
||||
{waypoints.map((wp, idx) => {
|
||||
if (!mapDataRef.current) return null;
|
||||
|
||||
const robot = robotPoseRef.current;
|
||||
const zoom = 1;
|
||||
const centerX = containerRef.current?.clientWidth / 2 || 400;
|
||||
const centerY = containerRef.current?.clientHeight / 2 || 300;
|
||||
|
||||
const canvasX = centerX + (wp.x - robot.x) * zoom;
|
||||
const canvasY = centerY - (wp.y - robot.y) * zoom;
|
||||
|
||||
const isActive = wp.id === activeWaypoint;
|
||||
const isSelected = wp.id === selectedWaypoint;
|
||||
|
||||
return (
|
||||
<g key={wp.id}>
|
||||
{/* Waypoint circle */}
|
||||
<circle
|
||||
cx={canvasX}
|
||||
cy={canvasY}
|
||||
r="10"
|
||||
fill={isActive ? '#ef4444' : isSelected ? '#fbbf24' : '#06b6d4'}
|
||||
opacity="0.8"
|
||||
/>
|
||||
{/* Waypoint number */}
|
||||
<text
|
||||
x={canvasX}
|
||||
y={canvasY}
|
||||
textAnchor="middle"
|
||||
dominantBaseline="middle"
|
||||
fill="white"
|
||||
fontSize="10"
|
||||
fontWeight="bold"
|
||||
pointerEvents="none"
|
||||
>
|
||||
{idx + 1}
|
||||
</text>
|
||||
{/* Line to next waypoint */}
|
||||
{idx < waypoints.length - 1 && (
|
||||
<line
|
||||
x1={canvasX}
|
||||
y1={canvasY}
|
||||
x2={
|
||||
centerX +
|
||||
(waypoints[idx + 1].x - robot.x) * zoom
|
||||
}
|
||||
y2={
|
||||
centerY -
|
||||
(waypoints[idx + 1].y - robot.y) * zoom
|
||||
}
|
||||
stroke="#10b981"
|
||||
strokeWidth="2"
|
||||
opacity="0.6"
|
||||
/>
|
||||
)}
|
||||
</g>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Robot position marker */}
|
||||
<circle
|
||||
cx={containerRef.current?.clientWidth / 2 || 400}
|
||||
cy={containerRef.current?.clientHeight / 2 || 300}
|
||||
r="8"
|
||||
fill="#8b5cf6"
|
||||
opacity="1"
|
||||
/>
|
||||
</svg>
|
||||
|
||||
<div className="absolute inset-0 flex items-center justify-center pointer-events-none text-gray-600 text-sm">
|
||||
{waypoints.length === 0 && (
|
||||
<div className="text-center">
|
||||
<div>Click to place waypoints</div>
|
||||
<div className="text-xs text-gray-700">Right-click to delete</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Info panel */}
|
||||
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 text-xs text-gray-600 space-y-1">
|
||||
<div className="flex justify-between">
|
||||
<span>Waypoints:</span>
|
||||
<span className="text-cyan-400">{waypoints.length}</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span>Robot Position:</span>
|
||||
<span className="text-cyan-400">
|
||||
({robotPose.x.toFixed(2)}, {robotPose.y.toFixed(2)})
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Waypoint list sidebar */}
|
||||
<div className="w-64 flex flex-col bg-gray-950 rounded-lg border border-cyan-950 space-y-3 p-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="text-cyan-700 text-xs font-bold tracking-widest">WAYPOINTS</div>
|
||||
<div className="text-gray-600 text-xs">{waypoints.length}</div>
|
||||
</div>
|
||||
|
||||
{/* Waypoint list */}
|
||||
<div className="flex-1 overflow-y-auto space-y-1">
|
||||
{waypoints.length === 0 ? (
|
||||
<div className="text-center text-gray-700 text-xs py-4">
|
||||
Click map to add waypoints
|
||||
</div>
|
||||
) : (
|
||||
waypoints.map((wp, idx) => (
|
||||
<div
|
||||
key={wp.id}
|
||||
draggable
|
||||
onDragStart={(e) => handleWaypointDragStart(e, idx)}
|
||||
onDragOver={(e) => {
|
||||
e.preventDefault();
|
||||
handleWaypointDragOver(e, idx);
|
||||
}}
|
||||
onDragEnd={handleWaypointDragEnd}
|
||||
onClick={() => handleWaypointSelect(wp.id)}
|
||||
onContextMenu={(e) => {
|
||||
e.preventDefault();
|
||||
handleDeleteWaypoint(wp.id);
|
||||
}}
|
||||
className={`p-2 rounded border text-xs cursor-move transition-colors ${
|
||||
wp.id === activeWaypoint
|
||||
? 'bg-red-950 border-red-700 text-red-300'
|
||||
: wp.id === selectedWaypoint
|
||||
? 'bg-amber-950 border-amber-700 text-amber-300'
|
||||
: 'bg-gray-900 border-gray-700 text-gray-400 hover:border-gray-600'
|
||||
}`}
|
||||
>
|
||||
<div className="flex justify-between items-start gap-2">
|
||||
<div className="font-bold">#{idx + 1}</div>
|
||||
<div className="text-right flex-1">
|
||||
<div className="text-gray-500">{wp.label}</div>
|
||||
<div className="text-gray-600">
|
||||
{wp.x.toFixed(2)}, {wp.y.toFixed(2)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Control buttons */}
|
||||
<div className="space-y-2 border-t border-gray-800 pt-3">
|
||||
<button
|
||||
onClick={sendSingleGoal}
|
||||
disabled={selectedWaypoint === null || executing}
|
||||
className="w-full px-2 py-1.5 text-xs font-bold tracking-widest rounded border border-cyan-800 bg-cyan-950 text-cyan-400 hover:bg-cyan-900 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
||||
>
|
||||
SEND GOAL
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={executeWaypoints}
|
||||
disabled={waypoints.length === 0 || executing}
|
||||
className="w-full px-2 py-1.5 text-xs font-bold tracking-widest rounded border border-green-800 bg-green-950 text-green-400 hover:bg-green-900 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
||||
>
|
||||
{executing ? 'EXECUTING...' : 'EXECUTE ALL'}
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={clearWaypoints}
|
||||
disabled={waypoints.length === 0}
|
||||
className="w-full px-2 py-1.5 text-xs font-bold tracking-widest rounded border border-red-800 bg-red-950 text-red-400 hover:bg-red-900 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
||||
>
|
||||
CLEAR ALL
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Instructions */}
|
||||
<div className="text-xs text-gray-600 space-y-1 border-t border-gray-800 pt-3">
|
||||
<div className="font-bold text-gray-500">CONTROLS:</div>
|
||||
<div>• Click: Place waypoint</div>
|
||||
<div>• Right-click: Delete waypoint</div>
|
||||
<div>• Drag: Reorder waypoints</div>
|
||||
<div>• Click list: Select waypoint</div>
|
||||
</div>
|
||||
|
||||
{/* Topic info */}
|
||||
<div className="text-xs text-gray-600 border-t border-gray-800 pt-3">
|
||||
<div className="flex justify-between">
|
||||
<span>Service:</span>
|
||||
<span className="text-gray-500">/navigate_to_pose</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { WaypointEditor };
|
||||
Loading…
x
Reference in New Issue
Block a user