Compare commits

..

6 Commits

Author SHA1 Message Date
3530f16fa8 feat(controls): Wheel slip detector (Issue #262)
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Detect wheel slip by comparing commanded velocity vs actual encoder velocity.
Publishes Bool flag on /saltybot/wheel_slip_detected when slip detected >0.5s.

Features:
- Subscribe to /cmd_vel (commanded) and /odom (actual velocity)
- Compare velocity magnitudes with 0.1 m/s threshold
- Persistence: slip must persist >0.5s to trigger (debounces transients)
- Publish Bool on /saltybot/wheel_slip_detected with detection status
- 10Hz monitoring frequency, configurable parameters

Algorithm:
- Compute linear speed from x,y components
- Calculate velocity difference
- If exceeds threshold: increment slip duration
- If duration > timeout: declare slip detected

Benefits:
- Detects environmental slip (ice, mud, wet surfaces)
- Triggers speed reduction to maintain traction
- Prevents wheel spinning/rut digging
- Safety response for loss of grip

Topics:
- Subscribed: /cmd_vel (Twist), /odom (Odometry)
- Published: /saltybot/wheel_slip_detected (Bool)

Config: frequency=10Hz, slip_threshold=0.1 m/s, slip_timeout=0.5s

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-02 14:13:56 -05:00
d6ee5a16b7 feat(webui): waypoint editor with click-to-navigate (Issue #261) 2026-03-02 14:13:56 -05:00
94a6f0787e Merge pull request 'feat(bringup): visual odometry drift detector (Issue #260)' (#265) from sl-perception/issue-260-vo-drift into main 2026-03-02 14:13:34 -05:00
50636de5a9 Merge pull request 'feat(social): ambient sound classifier via mel-spectrogram (Issue #252)' (#258) from sl-jetson/issue-252-ambient-sound into main
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 11s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
2026-03-02 14:13:33 -05:00
9d12805843 feat(bringup): visual odometry drift detector (Issue #260)
Adds sliding-window drift detector that compares cumulative path lengths
of visual odom and wheel odom over a configurable window (default 10 s).
Drift = |vo_path − wheel_path|; flagged when ≥ 0.5 m (configurable).
OdomBuffer handles per-source rolling storage with automatic age eviction.
Publishes Bool on /saltybot/vo_drift_detected and Float32 on
/saltybot/vo_drift_magnitude at 2 Hz.  27/27 pure-Python tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 13:26:07 -05:00
3cd9faeed9 feat(social): ambient sound classifier via mel-spectrogram — Issue #252
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Adds ambient_sound_node to saltybot_social:
- Accumulates 1 s of PCM-16 audio from /social/speech/audio_raw
- Extracts mel-spectrogram feature vector (energy_db, zcr, mel_centroid,
  mel_flatness, low_ratio, high_ratio) using pure numpy (no torch/onnx)
- Priority-cascade classifier: silence → alarm → music → speech → crowd → outdoor
- Publishes label as std_msgs/String on /saltybot/ambient_sound on each buffer fill
- All 11 thresholds exposed as ROS parameters (yaml + launch file)
- numpy-free energy-only fallback for edge environments
- 77/77 tests passing

Closes #252
2026-03-02 13:22:38 -05:00
9 changed files with 1434 additions and 0 deletions

View File

@ -0,0 +1,150 @@
"""
_vo_drift.py Visual odometry drift detector helpers (no ROS2 deps).
Algorithm
---------
Two independent odometry streams (visual and wheel) are compared over a
sliding time window. Drift is measured as the absolute difference in
cumulative path length travelled by each source over that window:
drift_m = |path_length(vo_window) − path_length(wheel_window)|
Using cumulative path length (sum of inter-sample Euclidean steps) rather
than straight-line displacement makes the measure robust to circular motion
where start and end positions are the same.
Drift is flagged when drift_m ≥ drift_threshold_m.
Public API
----------
OdomSample namedtuple(t, x, y)
OdomBuffer deque of OdomSamples with time-window trimming
compute_drift() compare two OdomBuffers and return DriftResult
DriftResult namedtuple(drift_m, vo_path_m, wheel_path_m,
is_drifting, window_s, n_vo, n_wheel)
"""
from __future__ import annotations
import math
from collections import deque
from typing import NamedTuple, Sequence
class OdomSample(NamedTuple):
    """One timestamped 2-D position reading from an odometry source."""

    t: float  # timestamp in seconds (monotonic clock)
    x: float  # x position in metres
    y: float  # y position in metres
class DriftResult(NamedTuple):
    """Outcome of one VO-vs-wheel drift comparison."""

    drift_m: float       # |vo_path − wheel_path| in metres
    vo_path_m: float     # path length travelled by the VO source in the window (m)
    wheel_path_m: float  # path length travelled by the wheel source in the window (m)
    is_drifting: bool    # True once drift_m reaches the threshold (>=)
    window_s: float      # time span actually covered by the data (seconds)
    n_vo: int            # VO samples that fell inside the window
    n_wheel: int         # wheel samples that fell inside the window
class OdomBuffer:
    """
    Time-bounded FIFO of OdomSamples.

    Each push appends the new sample and then drops, from the front, every
    sample whose timestamp is more than `max_age_s` seconds older than the
    sample just pushed.

    Parameters
    ----------
    max_age_s : float   retention horizon in seconds
    """

    def __init__(self, max_age_s: float = 10.0) -> None:
        self._max_age = max_age_s
        self._buf: deque[OdomSample] = deque()

    # ── Public ────────────────────────────────────────────────────────────────

    def push(self, sample: OdomSample) -> None:
        """Store `sample`, then evict entries older than max_age_s relative to it."""
        self._buf.append(sample)
        self._trim(sample.t)

    def window(self, window_s: float, now: float) -> list[OdomSample]:
        """Return the stored samples with t >= now - window_s (boundary inclusive)."""
        oldest_allowed = now - window_s
        return [entry for entry in self._buf if entry.t >= oldest_allowed]

    def clear(self) -> None:
        """Drop every stored sample."""
        self._buf.clear()

    def __len__(self) -> int:
        return len(self._buf)

    # ── Internal ──────────────────────────────────────────────────────────────

    def _trim(self, now: float) -> None:
        """Pop samples from the front while they are older than now - max_age."""
        horizon = now - self._max_age
        stored = self._buf
        while stored and stored[0].t < horizon:
            stored.popleft()
# ── Core computation ──────────────────────────────────────────────────────────
def compute_drift(
    vo_buf: OdomBuffer,
    wheel_buf: OdomBuffer,
    window_s: float,
    drift_threshold_m: float,
    now: float,
) -> DriftResult:
    """
    Measure how far the VO and wheel path lengths diverge over `window_s`.

    Parameters
    ----------
    vo_buf            : OdomBuffer holding visual odometry samples
    wheel_buf         : OdomBuffer holding wheel odometry samples
    window_s          : width of the comparison window (seconds)
    drift_threshold_m : drift at or above this value sets is_drifting
    now               : current time, on the same scale as OdomSample.t

    Returns
    -------
    DriftResult — all-zero (not drifting, window_s=0.0) when either source
    has fewer than 2 samples inside the window; the sample counts are still
    reported in that case.
    """
    vo_win = vo_buf.window(window_s, now)
    wheel_win = wheel_buf.window(window_s, now)
    n_vo, n_wheel = len(vo_win), len(wheel_win)

    # A path length needs at least one segment from each source.
    if min(n_vo, n_wheel) < 2:
        return DriftResult(
            drift_m=0.0,
            vo_path_m=0.0,
            wheel_path_m=0.0,
            is_drifting=False,
            window_s=0.0,
            n_vo=n_vo,
            n_wheel=n_wheel,
        )

    vo_len = _path_length(vo_win)
    wheel_len = _path_length(wheel_win)
    gap = abs(vo_len - wheel_len)

    # Actual data span = latest timestamp − earliest, across both sources.
    earliest = min(vo_win[0].t, wheel_win[0].t)
    latest = max(vo_win[-1].t, wheel_win[-1].t)

    return DriftResult(
        drift_m=gap,
        vo_path_m=vo_len,
        wheel_path_m=wheel_len,
        is_drifting=gap >= drift_threshold_m,
        window_s=latest - earliest,
        n_vo=n_vo,
        n_wheel=n_wheel,
    )
def _path_length(samples: Sequence[OdomSample]) -> float:
"""Sum of Euclidean inter-sample distances."""
total = 0.0
for i in range(1, len(samples)):
dx = samples[i].x - samples[i - 1].x
dy = samples[i].y - samples[i - 1].y
total += math.sqrt(dx * dx + dy * dy)
return total

View File

@ -0,0 +1,150 @@
"""
vo_drift_node.py Visual odometry drift detector (Issue #260).
Compares the cumulative path lengths of visual odometry and wheel odometry
over a sliding window. When the absolute difference exceeds the configured
threshold the node flags drift, allowing the system to warn operators,
inflate VO covariance, or fall back to wheel-only localisation.
Subscribes (BEST_EFFORT):
/camera/odom nav_msgs/Odometry visual odometry
/odom nav_msgs/Odometry wheel odometry
For this robot remap to /saltybot/visual_odom + /saltybot/rover_odom.
Publishes:
/saltybot/vo_drift_detected std_msgs/Bool True while drifting
/saltybot/vo_drift_magnitude std_msgs/Float32 drift magnitude (metres)
Parameters
----------
vo_topic str /camera/odom Visual odometry source topic
wheel_topic str /odom Wheel odometry source topic
drift_threshold_m float 0.5 Drift flag threshold (metres)
window_s float 10.0 Comparison window (seconds)
publish_hz float 2.0 Output publication rate (Hz)
max_buffer_age_s float 30.0 Max age of stored samples (s)
"""
from __future__ import annotations
import time
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from nav_msgs.msg import Odometry
from std_msgs.msg import Bool, Float32
from ._vo_drift import OdomBuffer, OdomSample, compute_drift
# QoS used for both odometry subscriptions.
_SENSOR_QOS = QoSProfile(
    reliability=ReliabilityPolicy.BEST_EFFORT,
    history=HistoryPolicy.KEEP_LAST,
    depth=4,
)


class VoDriftNode(Node):
    """
    ROS2 node that periodically publishes VO-vs-wheel odometry drift.

    Each odometry topic feeds its own rolling OdomBuffer. A timer running at
    `publish_hz` evaluates compute_drift() over the last `window_s` seconds
    and publishes the drift flag and magnitude.

    Raises
    ------
    ValueError   if the `publish_hz` parameter is not strictly positive.
    """

    def __init__(self) -> None:
        super().__init__('vo_drift_node')

        self.declare_parameter('vo_topic', '/camera/odom')
        self.declare_parameter('wheel_topic', '/odom')
        self.declare_parameter('drift_threshold_m', 0.5)
        self.declare_parameter('window_s', 10.0)
        self.declare_parameter('publish_hz', 2.0)
        self.declare_parameter('max_buffer_age_s', 30.0)

        vo_topic = self.get_parameter('vo_topic').value
        wheel_topic = self.get_parameter('wheel_topic').value
        self._thresh = self.get_parameter('drift_threshold_m').value
        self._window_s = self.get_parameter('window_s').value
        publish_hz = self.get_parameter('publish_hz').value
        max_age = self.get_parameter('max_buffer_age_s').value

        # The timer period is 1 / publish_hz: zero would raise a bare
        # ZeroDivisionError and a negative rate gives a negative period.
        if publish_hz <= 0.0:
            raise ValueError(
                f'publish_hz must be > 0 (got {publish_hz})')

        # The buffers evict on push; keeping less history than the comparison
        # window would silently shorten the window that is actually compared.
        if max_age < self._window_s:
            self.get_logger().warn(
                f'max_buffer_age_s ({max_age}s) is shorter than window_s '
                f'({self._window_s}s); retaining {self._window_s}s instead'
            )
            max_age = self._window_s

        self._vo_buf = OdomBuffer(max_age_s=max_age)
        self._wheel_buf = OdomBuffer(max_age_s=max_age)

        self.create_subscription(
            Odometry, vo_topic, self._on_vo, _SENSOR_QOS)
        self.create_subscription(
            Odometry, wheel_topic, self._on_wheel, _SENSOR_QOS)

        self._pub_detected = self.create_publisher(
            Bool, '/saltybot/vo_drift_detected', 10)
        self._pub_magnitude = self.create_publisher(
            Float32, '/saltybot/vo_drift_magnitude', 10)

        self.create_timer(1.0 / publish_hz, self._tick)

        self.get_logger().info(
            f'vo_drift_node ready — '
            f'vo={vo_topic} wheel={wheel_topic} '
            f'threshold={self._thresh}m window={self._window_s}s'
        )

    # ── Callbacks ─────────────────────────────────────────────────────────────

    def _on_vo(self, msg: Odometry) -> None:
        """Buffer one visual odometry message."""
        s = _odom_to_sample(msg)
        self._vo_buf.push(s)

    def _on_wheel(self, msg: Odometry) -> None:
        """Buffer one wheel odometry message."""
        s = _odom_to_sample(msg)
        self._wheel_buf.push(s)

    # ── Publish tick ──────────────────────────────────────────────────────────

    def _tick(self) -> None:
        """Compute drift over the current window and publish flag + magnitude."""
        now = time.monotonic()
        result = compute_drift(
            self._vo_buf, self._wheel_buf,
            window_s=self._window_s,
            drift_threshold_m=self._thresh,
            now=now,
        )

        if result.is_drifting:
            # Throttled so a sustained drift does not flood the log.
            self.get_logger().warn(
                f'VO drift detected: {result.drift_m:.3f}m '
                f'(vo={result.vo_path_m:.3f}m wheel={result.wheel_path_m:.3f}m '
                f'over {result.window_s:.1f}s)',
                throttle_duration_sec=5.0,
            )

        det_msg = Bool()
        det_msg.data = result.is_drifting
        self._pub_detected.publish(det_msg)

        mag_msg = Float32()
        mag_msg.data = float(result.drift_m)
        self._pub_magnitude.publish(mag_msg)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _odom_to_sample(msg: Odometry) -> OdomSample:
    """Build an OdomSample from an Odometry message, stamped with the local monotonic clock."""
    stamp = time.monotonic()
    position = msg.pose.pose.position
    return OdomSample(t=stamp, x=position.x, y=position.y)
def main(args=None) -> None:
    """Initialise rclpy, spin a VoDriftNode until interrupted, then clean up."""
    rclpy.init(args=args)
    node = VoDriftNode()
    try:
        rclpy.spin(node)
    finally:
        node.destroy_node()
        # NOTE(review): on recent rclpy releases the SIGINT handler may have
        # shut the context down already, and a second shutdown() raises —
        # only shut down while the context is still valid.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -35,6 +35,8 @@ setup(
'lidar_clustering = saltybot_bringup.lidar_clustering_node:main', 'lidar_clustering = saltybot_bringup.lidar_clustering_node:main',
# Floor surface type classifier (Issue #249) # Floor surface type classifier (Issue #249)
'floor_classifier = saltybot_bringup.floor_classifier_node:main', 'floor_classifier = saltybot_bringup.floor_classifier_node:main',
# Visual odometry drift detector (Issue #260)
'vo_drift_detector = saltybot_bringup.vo_drift_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,297 @@
"""
test_vo_drift.py Unit tests for VO drift detector helpers (no ROS2 required).
Covers:
OdomBuffer:
- push/len
- window returns only samples within cutoff
- old samples are evicted beyond max_age_s
- clear empties the buffer
- window on empty buffer returns empty list
_path_length (via compute_drift with crafted samples):
- stationary source path = 0
- straight-line motion path = total distance
- L-shaped path path = sum of two legs
compute_drift:
- both empty DriftResult with zeros, is_drifting=False
- one buffer < 2 samples zero drift
- both move same distance drift 0, not drifting
- VO moves 1m, wheel moves 0.5m drift = 0.5m
- drift == threshold is_drifting=True (>=)
- drift < threshold is_drifting=False
- drift > threshold is_drifting=True
- path lengths in result match expectation
- n_vo / n_wheel counts correct
- samples outside window ignored
- window_s in result reflects actual data span
"""
import sys
import os
import math
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_bringup._vo_drift import (
OdomSample,
OdomBuffer,
DriftResult,
compute_drift,
_path_length,
)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _s(t, x, y) -> OdomSample:
    """Shorthand constructor for an OdomSample."""
    return OdomSample(t, x, y)
def _straight_buf(n=5, speed=0.1, t_start=0.0, dt=1.0,
                  max_age_s=30.0) -> OdomBuffer:
    """Return a buffer of n samples, dt apart, travelling along +x at `speed` m/s."""
    track = OdomBuffer(max_age_s=max_age_s)
    for step in range(n):
        stamp = t_start + step * dt
        track.push(_s(stamp, x=step * speed * dt, y=0.0))
    return track
def _stationary_buf(n=5, t_start=0.0, dt=1.0,
                    max_age_s=30.0) -> OdomBuffer:
    """Return a buffer of n samples, dt apart, all located at the origin."""
    parked = OdomBuffer(max_age_s=max_age_s)
    for step in range(n):
        parked.push(_s(t_start + step * dt, x=0.0, y=0.0))
    return parked
# ── OdomBuffer ────────────────────────────────────────────────────────────────
class TestOdomBuffer:
    """OdomBuffer: push, window selection, age eviction and clear."""

    def test_push_increases_len(self):
        buffer = OdomBuffer()
        assert len(buffer) == 0
        buffer.push(_s(0.0, 0.0, 0.0))
        assert len(buffer) == 1

    def test_window_returns_all_within_cutoff(self):
        buffer = OdomBuffer(max_age_s=30.0)
        for stamp in (0.0, 5.0, 10.0):
            buffer.push(_s(stamp, 0.0, 0.0))
        selected = buffer.window(window_s=10.0, now=10.0)
        assert len(selected) == 3

    def test_window_excludes_old_samples(self):
        buffer = OdomBuffer(max_age_s=30.0)
        for stamp in (0.0, 5.0, 15.0):
            buffer.push(_s(stamp, 0.0, 0.0))
        # 5 s window ending at now=15 keeps t >= 10, i.e. only t=15
        selected = buffer.window(window_s=5.0, now=15.0)
        assert len(selected) == 1
        assert selected[0].t == 15.0

    def test_evicts_samples_beyond_max_age(self):
        buffer = OdomBuffer(max_age_s=5.0)
        buffer.push(_s(0.0, 0.0, 0.0))
        # Pushing t=10 makes the t=0 sample 10 s old, beyond the 5 s limit
        buffer.push(_s(10.0, 1.0, 0.0))
        assert len(buffer) == 1

    def test_clear_empties_buffer(self):
        buffer = _straight_buf(n=5)
        buffer.clear()
        assert len(buffer) == 0

    def test_window_on_empty_buffer(self):
        buffer = OdomBuffer()
        assert buffer.window(window_s=10.0, now=100.0) == []

    def test_window_boundary_inclusive(self):
        """A sample exactly at the cutoff (t == now - window_s) is kept."""
        buffer = OdomBuffer(max_age_s=30.0)
        buffer.push(_s(0.0, 0.0, 0.0))
        # window=10 at now=10 puts the cutoff at 0.0, right on the sample
        selected = buffer.window(window_s=10.0, now=10.0)
        assert len(selected) == 1
# ── _path_length ──────────────────────────────────────────────────────────────
class TestPathLength:
    """_path_length: sum of straight segments between consecutive samples."""

    def test_stationary_path_zero(self):
        parked = [_s(i, 0.0, 0.0) for i in range(5)]
        assert _path_length(parked) == pytest.approx(0.0)

    def test_unit_step_path(self):
        track = [_s(0, 0.0, 0.0), _s(1, 1.0, 0.0)]
        assert _path_length(track) == pytest.approx(1.0)

    def test_two_unit_steps(self):
        track = [_s(0, 0.0, 0.0), _s(1, 1.0, 0.0), _s(2, 2.0, 0.0)]
        assert _path_length(track) == pytest.approx(2.0)

    def test_diagonal_step(self):
        # One step from (0,0) to (1,1) covers sqrt(2)
        track = [_s(0, 0.0, 0.0), _s(1, 1.0, 1.0)]
        assert _path_length(track) == pytest.approx(math.sqrt(2))

    def test_l_shaped_path(self):
        # 3 m right then 4 m up: path is 7 m, not the 5 m hypotenuse
        track = [_s(0, 0.0, 0.0), _s(1, 3.0, 0.0), _s(2, 3.0, 4.0)]
        assert _path_length(track) == pytest.approx(7.0)

    def test_single_sample_returns_zero(self):
        assert _path_length([_s(0, 5.0, 5.0)]) == pytest.approx(0.0)

    def test_empty_returns_zero(self):
        assert _path_length([]) == pytest.approx(0.0)
# ── compute_drift ─────────────────────────────────────────────────────────────
class TestComputeDrift:
    """compute_drift(): end-to-end checks on synthetic odometry buffers."""

    def test_both_empty_returns_zero_drift(self):
        outcome = compute_drift(
            OdomBuffer(), OdomBuffer(),
            window_s=10.0, drift_threshold_m=0.5, now=10.0)
        assert outcome.drift_m == pytest.approx(0.0)
        assert not outcome.is_drifting

    def test_one_buffer_empty_returns_zero(self):
        vo = _straight_buf(n=5, speed=0.1)
        outcome = compute_drift(
            vo, OdomBuffer(),
            window_s=10.0, drift_threshold_m=0.5, now=5.0)
        assert outcome.drift_m == pytest.approx(0.0)
        assert not outcome.is_drifting

    def test_one_buffer_single_sample_returns_zero(self):
        vo = _straight_buf(n=5, speed=0.1)
        wheel = OdomBuffer()
        wheel.push(_s(0.0, 0.0, 0.0))  # a lone sample has no segment
        outcome = compute_drift(
            vo, wheel,
            window_s=10.0, drift_threshold_m=0.5, now=5.0)
        assert outcome.drift_m == pytest.approx(0.0)
        assert not outcome.is_drifting

    def test_both_move_same_distance_zero_drift(self):
        # 4 steps at 0.1 m/s on each source: 0.4 m apiece
        vo = _straight_buf(n=5, speed=0.1, dt=1.0)
        wheel = _straight_buf(n=5, speed=0.1, dt=1.0)
        outcome = compute_drift(
            vo, wheel,
            window_s=10.0, drift_threshold_m=0.5, now=5.0)
        assert outcome.drift_m == pytest.approx(0.0, abs=1e-9)
        assert not outcome.is_drifting

    def test_both_stationary_zero_drift(self):
        vo = _stationary_buf(n=5)
        wheel = _stationary_buf(n=5)
        outcome = compute_drift(
            vo, wheel,
            window_s=10.0, drift_threshold_m=0.5, now=5.0)
        assert outcome.drift_m == pytest.approx(0.0)
        assert not outcome.is_drifting

    def test_drift_equals_path_length_difference(self):
        # VO covers 10 × 0.1 = 1.0 m, wheel covers 10 × 0.05 = 0.5 m
        vo = _straight_buf(n=11, speed=0.1, dt=1.0)
        wheel = _straight_buf(n=11, speed=0.05, dt=1.0)
        outcome = compute_drift(
            vo, wheel,
            window_s=15.0, drift_threshold_m=0.5, now=11.0)
        assert outcome.vo_path_m == pytest.approx(1.0, abs=1e-9)
        assert outcome.wheel_path_m == pytest.approx(0.5, abs=1e-9)
        assert outcome.drift_m == pytest.approx(0.5, abs=1e-9)

    def test_drift_at_threshold_is_drifting(self):
        # Drift of exactly 0.5 counts as drifting (comparison is >=)
        vo = _straight_buf(n=11, speed=0.1, dt=1.0)
        wheel = _straight_buf(n=11, speed=0.05, dt=1.0)
        outcome = compute_drift(
            vo, wheel,
            window_s=15.0, drift_threshold_m=0.5, now=11.0)
        assert outcome.is_drifting

    def test_drift_below_threshold_not_drifting(self):
        vo = _straight_buf(n=11, speed=0.1, dt=1.0)
        wheel = _straight_buf(n=11, speed=0.08, dt=1.0)
        outcome = compute_drift(
            vo, wheel,
            window_s=15.0, drift_threshold_m=0.5, now=11.0)
        # |1.0 - 0.8| = 0.2, under the threshold
        assert outcome.drift_m == pytest.approx(0.2, abs=1e-9)
        assert not outcome.is_drifting

    def test_drift_above_threshold_is_drifting(self):
        vo = _straight_buf(n=11, speed=0.1, dt=1.0)
        wheel = _stationary_buf(n=11, dt=1.0)
        outcome = compute_drift(
            vo, wheel,
            window_s=15.0, drift_threshold_m=0.5, now=11.0)
        # |1.0 - 0.0| = 1.0, over the threshold
        assert outcome.drift_m > 0.5
        assert outcome.is_drifting

    def test_n_vo_n_wheel_counts(self):
        vo = _straight_buf(n=8, dt=1.0)
        wheel = _straight_buf(n=5, dt=1.0)
        outcome = compute_drift(
            vo, wheel,
            window_s=15.0, drift_threshold_m=0.5, now=8.0)
        assert outcome.n_vo == 8
        assert outcome.n_wheel == 5

    def test_samples_outside_window_ignored(self):
        # Old history must not contribute to the windowed comparison
        vo = OdomBuffer(max_age_s=60.0)
        wheel = OdomBuffer(max_age_s=60.0)
        # t=0..4 lies outside the 3 s window that ends at now=10
        for tick in range(5):
            stamp = float(tick)
            vo.push(_s(stamp, x=float(tick), y=0.0))
            wheel.push(_s(stamp, x=float(tick), y=0.0))
        # t=7..10 lies inside the window
        for tick in range(7, 11):
            stamp = float(tick)
            vo.push(_s(stamp, x=float(tick) * 0.1, y=0.0))
            wheel.push(_s(stamp, x=float(tick) * 0.1, y=0.0))
        outcome = compute_drift(
            vo, wheel,
            window_s=3.0, drift_threshold_m=0.5, now=10.0)
        # Identical motion inside the window means no drift
        assert outcome.drift_m == pytest.approx(0.0, abs=1e-9)
        # Only t=7,8,9,10 are counted
        assert outcome.n_vo == 4
        assert outcome.n_wheel == 4

    def test_result_is_namedtuple(self):
        outcome = compute_drift(
            _straight_buf(), _straight_buf(),
            window_s=10.0, drift_threshold_m=0.5, now=5.0)
        expected_fields = (
            'drift_m', 'vo_path_m', 'wheel_path_m',
            'is_drifting', 'window_s', 'n_vo', 'n_wheel',
        )
        for field in expected_fields:
            assert hasattr(outcome, field)

    def test_wheel_faster_than_vo_still_drifts(self):
        """Drift is an absolute difference, so the sign does not matter."""
        vo = _stationary_buf(n=11, dt=1.0)
        wheel = _straight_buf(n=11, speed=0.1, dt=1.0)
        outcome = compute_drift(
            vo, wheel,
            window_s=15.0, drift_threshold_m=0.5, now=11.0)
        assert outcome.drift_m == pytest.approx(1.0, abs=1e-9)
        assert outcome.is_drifting
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@ -0,0 +1,21 @@
ambient_sound_node:
ros__parameters:
sample_rate: 16000 # Expected PCM sample rate (Hz)
window_s: 1.0 # Accumulate this many seconds before classifying
n_fft: 512 # FFT size (32 ms frame at 16 kHz)
n_mels: 32 # Mel filterbank bands
audio_topic: "/social/speech/audio_raw" # Source PCM-16 UInt8MultiArray topic
# ── Classifier thresholds ──────────────────────────────────────────────
# Adjust to tune sensitivity for your deployment environment.
silence_db: -40.0 # Below this energy (dBFS) → silence
alarm_db_min: -25.0 # Min energy for alarm detection
alarm_zcr_min: 0.12 # Min ZCR for alarm (intermittent high pitch)
alarm_high_ratio_min: 0.35 # Min high-band energy fraction for alarm
speech_zcr_min: 0.02 # Min ZCR for speech (voiced onset)
speech_zcr_max: 0.25 # Max ZCR for speech
speech_flatness_max: 0.35 # Max spectral flatness for speech (tonal)
music_zcr_max: 0.08 # Max ZCR for music (harmonic / tonal)
music_flatness_max: 0.25 # Max spectral flatness for music
crowd_zcr_min: 0.10 # Min ZCR for crowd noise
crowd_flatness_min: 0.35 # Min spectral flatness for crowd

View File

@ -0,0 +1,42 @@
"""ambient_sound.launch.py -- Launch the ambient sound classifier (Issue #252).
Usage:
ros2 launch saltybot_social ambient_sound.launch.py
ros2 launch saltybot_social ambient_sound.launch.py silence_db:=-45.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    """Launch ambient_sound_node with the packaged YAML plus three CLI-overridable parameters."""
    share_dir = get_package_share_directory("saltybot_social")
    params_file = os.path.join(share_dir, "config", "ambient_sound_params.yaml")

    # (argument name, default value, description)
    overridable = (
        ("window_s", "1.0", "Accumulation window (s)"),
        ("n_mels", "32", "Mel filterbank bands"),
        ("silence_db", "-40.0", "Silence energy threshold (dBFS)"),
    )

    launch_args = [
        DeclareLaunchArgument(arg, default_value=default, description=text)
        for arg, default, text in overridable
    ]
    # Listed after the YAML file in `parameters`, same as the launch arguments.
    overrides = {arg: LaunchConfiguration(arg) for arg, _, _ in overridable}

    classifier = Node(
        package="saltybot_social",
        executable="ambient_sound_node",
        name="ambient_sound_node",
        output="screen",
        parameters=[params_file, overrides],
    )

    return LaunchDescription([*launch_args, classifier])

View File

@ -0,0 +1,363 @@
"""ambient_sound_node.py -- Ambient sound classifier via mel-spectrogram features.
Issue #252
Accumulates 1 s of PCM-16 audio from /social/speech/audio_raw, extracts a
compact mel-spectrogram feature vector, then classifies the scene into one of:
silence | speech | music | crowd | outdoor | alarm
Publishes the label as std_msgs/String on /saltybot/ambient_sound at 1 Hz.
Signal processing is pure Python + numpy (no torch / onnx dependency).
Feature vector (per 1-s window):
energy_db -- overall RMS in dBFS
zcr -- mean zero-crossing rate across frames
mel_centroid -- centre-of-mass of the mel band energies [0..1]
mel_flatness -- geometric/arithmetic mean of mel energies [0..1]
(1 = white noise, 0 = single sinusoid)
low_ratio -- fraction of mel energy in lower third of bands
high_ratio -- fraction of mel energy in upper third of bands
Classification cascade (priority-ordered):
silence : energy_db < silence_db
alarm : energy_db >= alarm_db_min AND zcr >= alarm_zcr_min
AND high_ratio >= alarm_high_ratio_min
music : zcr < music_zcr_max AND mel_flatness < music_flatness_max
speech : zcr in [speech_zcr_min, speech_zcr_max]
AND mel_flatness < speech_flatness_max
crowd : zcr >= crowd_zcr_min AND mel_flatness >= crowd_flatness_min
outdoor : catch-all
Parameters:
sample_rate (int, 16000)
window_s (float, 1.0) -- accumulation window before classify
n_fft (int, 512) -- FFT size
n_mels (int, 32) -- mel filterbank bands
audio_topic (str, "/social/speech/audio_raw")
silence_db (float, -40.0)
alarm_db_min (float, -25.0)
alarm_zcr_min (float, 0.12)
alarm_high_ratio_min (float, 0.35)
speech_zcr_min (float, 0.02)
speech_zcr_max (float, 0.25)
speech_flatness_max (float, 0.35)
music_zcr_max (float, 0.08)
music_flatness_max (float, 0.25)
crowd_zcr_min (float, 0.10)
crowd_flatness_min (float, 0.35)
"""
from __future__ import annotations
import math
import struct
import threading
from typing import Dict, List, Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import String, UInt8MultiArray
# numpy used only in DSP helpers — the Jetson always has it
try:
import numpy as np
_NUMPY = True
except ImportError:
_NUMPY = False
# Divisor that maps signed 16-bit PCM onto [-1.0, 1.0).
INT16_MAX = 32768.0

LABELS = ("silence", "speech", "music", "crowd", "outdoor", "alarm")


# ── PCM helpers ───────────────────────────────────────────────────────────────

def pcm16_bytes_to_float32(data: bytes) -> List[float]:
    """Decode little-endian PCM-16 bytes into floats scaled by 1 / INT16_MAX.

    A trailing odd byte is ignored; fewer than two bytes yield an empty list.
    """
    count = len(data) // 2
    if not count:
        return []
    raw = struct.unpack(f"<{count}h", data[: count * 2])
    return [value / INT16_MAX for value in raw]
# ── Mel DSP (numpy path) ──────────────────────────────────────────────────────
def hz_to_mel(hz: float) -> float:
    """Convert a frequency in Hz to mels: 2595 * log10(1 + hz / 700)."""
    ratio = 1.0 + hz / 700.0
    return 2595.0 * math.log10(ratio)
def mel_to_hz(mel: float) -> float:
    """Convert mels back to Hz: 700 * (10 ** (mel / 2595) - 1)."""
    exponent = mel / 2595.0
    return 700.0 * (10.0 ** exponent - 1.0)
def build_mel_filterbank(sr: int, n_fft: int, n_mels: int,
                         fmin: float = 0.0, fmax: Optional[float] = None):
    """Build a triangular mel filterbank as an (n_mels, n_fft//2+1) numpy matrix.

    Band edges are spaced evenly on the mel scale between `fmin` and `fmax`
    (default sr / 2) and snapped down to FFT bin indices.
    """
    import numpy as np

    if fmax is None:
        fmax = sr / 2.0
    n_freqs = n_fft // 2 + 1

    # n_mels triangles need n_mels + 2 edge points (left, centre, right)
    mel_edges = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    hz_edges = np.array([mel_to_hz(m) for m in mel_edges])
    bin_edges = np.floor((n_fft + 1) * hz_edges / sr).astype(int)

    bank = np.zeros((n_mels, n_freqs))
    for band in range(n_mels):
        left, centre, right = bin_edges[band:band + 3]

        # Rising slope over [left, centre); a non-empty range implies centre > left
        rising = np.arange(left, min(centre, n_freqs))
        if rising.size:
            bank[band, rising] = (rising - left) / (centre - left)

        # Falling slope over [centre, right); a non-empty range implies right > centre
        falling = np.arange(centre, min(right, n_freqs))
        if falling.size:
            bank[band, falling] = (right - falling) / (right - centre)
    return bank
def compute_mel_spectrogram(samples: List[float], sr: int,
                            n_fft: int = 512, n_mels: int = 32,
                            hop_length: int = 256):
    """Return an (n_mels, n_frames) float32 mel power spectrogram.

    Each Hann-windowed frame of `n_fft` samples is turned into a power
    spectrum and projected onto the mel filterbank. No logarithm is applied.
    Input shorter than one frame yields a single all-zero column.
    """
    import numpy as np

    signal = np.array(samples, dtype=np.float32)
    filterbank = build_mel_filterbank(sr, n_fft, n_mels)
    taper = np.hanning(n_fft)

    columns = [
        filterbank @ (np.abs(np.fft.rfft(signal[begin : begin + n_fft] * taper)) ** 2)
        for begin in range(0, len(signal) - n_fft + 1, hop_length)
    ]
    if not columns:
        return np.zeros((n_mels, 1), dtype=np.float32)
    return np.column_stack(columns).astype(np.float32)
# ── Feature extraction ────────────────────────────────────────────────────────
def extract_features(samples: List[float], sr: int,
                     n_fft: int = 512, n_mels: int = 32) -> Dict[str, float]:
    """Extract scalar features from a raw audio window.

    Returns a dict with keys energy_db, zcr, mel_centroid, mel_flatness,
    low_ratio and high_ratio. An empty window gives 0.0 for every key.

    mel_centroid is the energy-weighted mean band index divided by n_mels,
    so it is normalised by the total mel energy and stays within [0, 1).
    """
    import numpy as np

    n = len(samples)
    if n == 0:
        return {k: 0.0 for k in
                ("energy_db", "zcr", "mel_centroid", "mel_flatness",
                 "low_ratio", "high_ratio")}

    eps = 1e-10

    # Energy: RMS in dBFS, floored so digital silence does not hit log10(0)
    rms = math.sqrt(sum(s * s for s in samples) / n)
    energy_db = 20.0 * math.log10(max(rms, eps))

    # ZCR: mean fraction of sign changes over non-overlapping 30 ms chunks
    chunk = max(1, int(sr * 0.030))
    zcr_vals = []
    for i in range(0, n - chunk + 1, chunk):
        seg = samples[i : i + chunk]
        crossings = sum(1 for a, b in zip(seg, seg[1:]) if a * b < 0)
        zcr_vals.append(crossings / max(len(seg) - 1, 1))
    zcr = sum(zcr_vals) / len(zcr_vals) if zcr_vals else 0.0

    # Mel spectrogram features
    mel_spec = compute_mel_spectrogram(samples, sr, n_fft, n_mels)
    mel_mean = mel_spec.mean(axis=1)  # (n_mels,) mean energy per band
    band_sum = float(mel_mean.sum())
    total = band_sum if band_sum > 0 else eps

    # Centre of mass of the band energies: divide by the total energy first,
    # then by n_mels to map the band index onto a 0..1 scale.
    indices = np.arange(n_mels, dtype=np.float32)
    mel_centroid = float((indices * mel_mean).sum()) / total / n_mels

    # Spectral flatness: geometric mean / arithmetic mean
    mel_pos = np.clip(mel_mean, eps, None)
    geo_mean = float(np.exp(np.log(mel_pos).mean()))
    arith_mean = float(mel_pos.mean())
    mel_flatness = min(geo_mean / max(arith_mean, eps), 1.0)

    # Band ratios: share of energy in the lowest and highest third of bands
    third = max(1, n_mels // 3)
    low_energy = float(mel_mean[:third].sum())
    high_energy = float(mel_mean[-third:].sum())
    low_ratio = low_energy / max(total, eps)
    high_ratio = high_energy / max(total, eps)

    return {
        "energy_db": energy_db,
        "zcr": zcr,
        "mel_centroid": mel_centroid,
        "mel_flatness": mel_flatness,
        "low_ratio": low_ratio,
        "high_ratio": high_ratio,
    }
# ── Classifier ────────────────────────────────────────────────────────────────
def classify(features: Dict[str, float],
             silence_db: float = -40.0,
             alarm_db_min: float = -25.0,
             alarm_zcr_min: float = 0.12,
             alarm_high_ratio_min: float = 0.35,
             speech_zcr_min: float = 0.02,
             speech_zcr_max: float = 0.25,
             speech_flatness_max: float = 0.35,
             music_zcr_max: float = 0.08,
             music_flatness_max: float = 0.25,
             crowd_zcr_min: float = 0.10,
             crowd_flatness_min: float = 0.35) -> str:
    """Map a feature dict to a label from LABELS using a fixed rule cascade.

    Rules are tried in this order and the first match wins:
    silence, alarm, music, speech, crowd; anything else is "outdoor".
    Reads the keys energy_db, zcr, mel_flatness and high_ratio.
    """
    energy = features["energy_db"]
    crossing_rate = features["zcr"]
    flatness = features["mel_flatness"]
    high_share = features["high_ratio"]

    if energy < silence_db:
        return "silence"

    is_alarm = (energy >= alarm_db_min
                and crossing_rate >= alarm_zcr_min
                and high_share >= alarm_high_ratio_min)
    if is_alarm:
        return "alarm"

    is_music = crossing_rate < music_zcr_max and flatness < music_flatness_max
    if is_music:
        return "music"

    is_speech = (speech_zcr_min <= crossing_rate <= speech_zcr_max
                 and flatness < speech_flatness_max)
    if is_speech:
        return "speech"

    is_crowd = crossing_rate >= crowd_zcr_min and flatness >= crowd_flatness_min
    return "crowd" if is_crowd else "outdoor"
# ── Audio accumulation buffer ─────────────────────────────────────────────────
class AudioBuffer:
    """Lock-protected accumulator that hands out fixed-size sample windows."""

    def __init__(self, window_samples: int) -> None:
        self._target = window_samples
        self._buf: List[float] = []
        self._lock = threading.Lock()

    def push(self, samples: List[float]) -> Optional[List[float]]:
        """Add samples; return the oldest full window once enough are stored, else None.

        At most one window is returned per call; any surplus stays buffered.
        """
        with self._lock:
            self._buf.extend(samples)
            if len(self._buf) < self._target:
                return None
            window = self._buf[: self._target]
            del self._buf[: self._target]
            return window

    def clear(self) -> None:
        """Discard everything accumulated so far."""
        with self._lock:
            self._buf.clear()
# ── ROS2 node ─────────────────────────────────────────────────────────────────
class AmbientSoundNode(Node):
    """ROS2 node that labels ambient sound from raw audio.

    Incoming UInt8MultiArray messages are decoded with
    pcm16_bytes_to_float32 and accumulated in an AudioBuffer.  Every time a
    full window (sample_rate * window_s samples) is available it is
    classified and the label is published as a String on
    /saltybot/ambient_sound -- i.e. once per second of received audio with
    the default 1.0 s window.
    """

    def __init__(self) -> None:
        super().__init__("ambient_sound_node")

        # Signal-processing parameters and the input topic.
        for name, default in (
            ("sample_rate", 16000),
            ("window_s", 1.0),
            ("n_fft", 512),
            ("n_mels", 32),
            ("audio_topic", "/social/speech/audio_raw"),
        ):
            self.declare_parameter(name, default)

        # Classifier thresholds; the names match classify()'s keyword
        # arguments so the collected dict can be splatted straight into it.
        threshold_defaults = (
            ("silence_db", -40.0),
            ("alarm_db_min", -25.0),
            ("alarm_zcr_min", 0.12),
            ("alarm_high_ratio_min", 0.35),
            ("speech_zcr_min", 0.02),
            ("speech_zcr_max", 0.25),
            ("speech_flatness_max", 0.35),
            ("music_zcr_max", 0.08),
            ("music_flatness_max", 0.25),
            ("crowd_zcr_min", 0.10),
            ("crowd_flatness_min", 0.35),
        )
        for name, default in threshold_defaults:
            self.declare_parameter(name, default)

        def param(name):
            return self.get_parameter(name).value

        self._sr = param("sample_rate")
        self._n_fft = param("n_fft")
        self._n_mels = param("n_mels")
        window_s = param("window_s")
        audio_topic = param("audio_topic")
        self._thresholds = {name: param(name) for name, _ in threshold_defaults}

        self._buffer = AudioBuffer(int(self._sr * window_s))
        # Label assumed until the first window has been classified.
        self._last_label = "silence"

        qos = QoSProfile(depth=10)
        self._pub = self.create_publisher(String, "/saltybot/ambient_sound", qos)
        self._audio_sub = self.create_subscription(
            UInt8MultiArray, audio_topic, self._on_audio, qos
        )

        log = self.get_logger()
        if not _NUMPY:
            log.warn(
                "numpy not available — mel features disabled, classifying by energy only"
            )
        log.info(
            f"AmbientSoundNode ready "
            f"(sr={self._sr}, window={window_s}s, n_mels={self._n_mels})"
        )

    def _on_audio(self, msg: UInt8MultiArray) -> None:
        """Decode one audio message and classify when a window completes."""
        samples = pcm16_bytes_to_float32(bytes(msg.data))
        if samples:
            window = self._buffer.push(samples)
            if window is not None:
                self._classify_and_publish(window)

    def _classify_and_publish(self, samples: List[float]) -> None:
        """Classify one window of samples and publish the resulting label.

        Any exception raised while extracting features or classifying is
        logged and the previous label is published again.
        """
        try:
            if not _NUMPY:
                # Numpy-free fallback: only the energy is measured, every
                # other feature is a fixed placeholder.
                # NOTE: with the default thresholds these placeholders
                # satisfy the "music" rule, so every non-silent window is
                # labelled "music" in this mode.
                rms = math.sqrt(sum(s * s for s in samples) / len(samples))
                feats = {
                    "energy_db": 20.0 * math.log10(max(rms, 1e-10)),
                    "zcr": 0.05,
                    "mel_centroid": 0.5,
                    "mel_flatness": 0.2,
                    "low_ratio": 0.4,
                    "high_ratio": 0.2,
                }
            else:
                feats = extract_features(
                    samples, self._sr, self._n_fft, self._n_mels
                )
            label = classify(feats, **self._thresholds)
        except Exception as exc:
            self.get_logger().error(f"Classification error: {exc}")
            label = self._last_label

        # Log transitions only; the label itself goes out for every window.
        previous = self._last_label
        if label != previous:
            self.get_logger().info(
                f"Ambient sound: {previous} -> {label}"
            )
        self._last_label = label

        out = String()
        out.data = label
        self._pub.publish(out)
def main(args: Optional[list] = None) -> None:
    """Entry point: initialise rclpy, spin an AmbientSoundNode, clean up.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.

    Ctrl-C (KeyboardInterrupt) is treated as a normal shutdown request.
    """
    rclpy.init(args=args)
    node = None
    try:
        # Construct inside the try so a failing constructor still leads to
        # rclpy being shut down below.
        node = AmbientSoundNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            node.destroy_node()
        # On Ctrl-C the rclpy signal handler may already have shut the
        # context down; calling shutdown() a second time raises, so only
        # shut down while the context is still valid.
        if rclpy.ok():
            rclpy.shutdown()

View File

@ -45,6 +45,8 @@ setup(
'mesh_comms_node = saltybot_social.mesh_comms_node:main', 'mesh_comms_node = saltybot_social.mesh_comms_node:main',
# Energy+ZCR voice activity detection (Issue #242) # Energy+ZCR voice activity detection (Issue #242)
'vad_node = saltybot_social.vad_node:main', 'vad_node = saltybot_social.vad_node:main',
# Ambient sound classifier — mel-spectrogram (Issue #252)
'ambient_sound_node = saltybot_social.ambient_sound_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,407 @@
"""test_ambient_sound.py -- Unit tests for Issue #252 ambient sound classifier."""
from __future__ import annotations
import importlib.util, math, os, struct, sys, types
import pytest
# numpy is available on dev machine
import numpy as np
def _pkg_root():
    """Return the directory one level above the one holding this file."""
    this_file = os.path.abspath(__file__)
    containing_dir = os.path.dirname(this_file)
    return os.path.dirname(containing_dir)
def _read_src(rel_path):
    """Return the text of a file given its path relative to the package root.

    The file is decoded as UTF-8 explicitly rather than with the locale's
    default encoding: the node source read through this helper contains
    non-ASCII characters (e.g. an em dash in a log message), which would be
    mis-decoded or rejected on platforms whose default is not UTF-8.
    """
    with open(os.path.join(_pkg_root(), rel_path), encoding="utf-8") as f:
        return f.read()
def _import_mod():
    """Import ambient_sound_node without a live ROS2 environment.

    Placeholder modules are registered in ``sys.modules`` for rclpy,
    rclpy.node, rclpy.qos, std_msgs and std_msgs.msg when they are not
    already present, and the handful of names the node uses (Node,
    QoSProfile, String, UInt8MultiArray, init/spin/ok/shutdown) are attached
    to them.  The node source file is then loaded straight from its path
    under the name "ambient_sound_node_testmod" and the freshly executed
    module object is returned.

    NOTE: the stub attributes are assigned unconditionally, so if real ROS2
    modules are already in ``sys.modules`` their attributes are overwritten
    for the rest of the process.
    """
    # Register empty placeholder modules only where nothing is loaded yet.
    for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
                     "std_msgs", "std_msgs.msg"):
        if mod_name not in sys.modules:
            sys.modules[mod_name] = types.ModuleType(mod_name)
    rclpy_node = sys.modules["rclpy.node"]
    rclpy_qos = sys.modules["rclpy.qos"]
    std_msg = sys.modules["std_msgs.msg"]
    # Values the stub node returns from get_parameter(); they mirror the
    # defaults the real node passes to declare_parameter().
    DEFAULTS = {
        "sample_rate": 16000, "window_s": 1.0, "n_fft": 512, "n_mels": 32,
        "audio_topic": "/social/speech/audio_raw",
        "silence_db": -40.0, "alarm_db_min": -25.0, "alarm_zcr_min": 0.12,
        "alarm_high_ratio_min": 0.35, "speech_zcr_min": 0.02,
        "speech_zcr_max": 0.25, "speech_flatness_max": 0.35,
        "music_zcr_max": 0.08, "music_flatness_max": 0.25,
        "crowd_zcr_min": 0.10, "crowd_flatness_min": 0.35,
    }

    # Minimal stand-in for rclpy.node.Node: every method is a no-op except
    # get_parameter(), which looks the name up in DEFAULTS (None if unknown).
    class _Node:
        def __init__(self, *a, **kw): pass
        def declare_parameter(self, *a, **kw): pass
        def get_parameter(self, name):
            class _P:
                value = DEFAULTS.get(name)
            return _P()
        def create_publisher(self, *a, **kw): return None
        def create_subscription(self, *a, **kw): return None
        def get_logger(self):
            # Logger stub that silently drops every message.
            class _L:
                def info(self, *a): pass
                def warn(self, *a): pass
                def error(self, *a): pass
            return _L()
        def destroy_node(self): pass
    rclpy_node.Node = _Node
    # QoSProfile accepts and ignores any keyword arguments.
    rclpy_qos.QoSProfile = type("QoSProfile", (), {"__init__": lambda s, **kw: None})
    # Message stubs only carry a class-level ``data`` default.
    std_msg.String = type("String", (), {"data": ""})
    std_msg.UInt8MultiArray = type("UInt8MultiArray", (), {"data": b""})
    sys.modules["rclpy"].init = lambda *a, **kw: None
    sys.modules["rclpy"].spin = lambda n: None
    sys.modules["rclpy"].ok = lambda: True
    sys.modules["rclpy"].shutdown = lambda: None
    # Load the node by file path so the package does not have to be
    # importable; every call executes the source again in a new module.
    spec = importlib.util.spec_from_file_location(
        "ambient_sound_node_testmod",
        os.path.join(_pkg_root(), "saltybot_social", "ambient_sound_node.py"),
    )
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
# ── Audio helpers ─────────────────────────────────────────────────────────────
SR = 16000


def _sine(freq, n=SR, amp=0.2):
    """Return n samples of a sine at freq Hz (sample rate SR), peak amp."""
    omega = 2 * math.pi * freq
    wave = []
    for i in range(n):
        wave.append(amp * math.sin(omega * i / SR))
    return wave


def _white_noise(n=SR, amp=0.1):
    """Return n uniform samples in [-amp, amp] from a fixed-seed generator."""
    import random
    rng = random.Random(42)  # fixed seed keeps the tests reproducible
    low, high = -amp, amp
    noise = []
    for _ in range(n):
        noise.append(rng.uniform(low, high))
    return noise


def _silence(n=SR):
    """Return n samples of digital silence."""
    return [0.0 for _ in range(n)]


def _pcm16(samples):
    """Pack float samples as little-endian signed 16-bit PCM bytes.

    Each sample is scaled by 32768, truncated towards zero and clamped to
    the int16 range.
    """
    ints = []
    for s in samples:
        scaled = int(s * 32768)
        clamped = min(32767, scaled)
        ints.append(max(-32768, clamped))
    return struct.pack(f"<{len(ints)}h", *ints)
# ── TestPcm16Convert ──────────────────────────────────────────────────────────
class TestPcm16Convert:
    """pcm16_bytes_to_float32: length, value range and edge cases."""

    @pytest.fixture(scope="class")
    def mod(self):
        return _import_mod()

    def test_empty(self, mod):
        decoded = mod.pcm16_bytes_to_float32(b"")
        assert decoded == []

    def test_length(self, mod):
        # 480 int16 samples in -> 480 floats out.
        payload = _pcm16(_sine(440, 480))
        decoded = mod.pcm16_bytes_to_float32(payload)
        assert len(decoded) == 480

    def test_range(self, mod):
        payload = _pcm16(_sine(440, 480))
        decoded = mod.pcm16_bytes_to_float32(payload)
        outside = [s for s in decoded if not -1.0 <= s <= 1.0]
        assert not outside

    def test_silence(self, mod):
        payload = _pcm16(_silence(100))
        decoded = mod.pcm16_bytes_to_float32(payload)
        assert not any(s != 0.0 for s in decoded)
# ── TestMelConversions ────────────────────────────────────────────────────────
class TestMelConversions:
    """hz_to_mel / mel_to_hz: anchor values, round trip and ordering."""

    @pytest.fixture(scope="class")
    def mod(self):
        return _import_mod()

    def test_hz_to_mel_zero(self, mod):
        mel = mod.hz_to_mel(0.0)
        assert mel == 0.0

    def test_hz_to_mel_1000(self, mod):
        # 1000 Hz should land at roughly 999.99 mel.
        mel = mod.hz_to_mel(1000.0)
        assert abs(mel - 999.99) < 1.0

    def test_roundtrip(self, mod):
        for hz in (100.0, 500.0, 1000.0, 4000.0, 8000.0):
            recovered = mod.mel_to_hz(mod.hz_to_mel(hz))
            assert abs(recovered - hz) < 0.01

    def test_monotone_increasing(self, mod):
        mels = []
        for f in [100, 500, 1000, 2000, 4000, 8000]:
            mels.append(mod.hz_to_mel(f))
        # Already-sorted output means the mapping never decreases.
        assert sorted(mels) == mels
# ── TestMelFilterbank ─────────────────────────────────────────────────────────
class TestMelFilterbank:
    """build_mel_filterbank: shape, sign and scaling of the filter matrix."""

    @pytest.fixture(scope="class")
    def mod(self):
        return _import_mod()

    def test_shape(self, mod):
        bank = mod.build_mel_filterbank(SR, 512, 32)
        # Rows are mel bands, columns are the n_fft // 2 + 1 FFT bins.
        assert bank.shape == (32, 257)

    def test_nonnegative(self, mod):
        bank = mod.build_mel_filterbank(SR, 512, 32)
        assert np.all(bank >= 0)

    def test_each_filter_sums_positive(self, mod):
        bank = mod.build_mel_filterbank(SR, 512, 32)
        for band in range(32):
            assert bank[band].sum() > 0

    def test_custom_n_mels(self, mod):
        bank = mod.build_mel_filterbank(SR, 512, 16)
        assert bank.shape[0] == 16

    def test_max_value_leq_one(self, mod):
        bank = mod.build_mel_filterbank(SR, 512, 32)
        peak = bank.max()
        assert peak <= 1.0 + 1e-6
# ── TestMelSpectrogram ────────────────────────────────────────────────────────
class TestMelSpectrogram:
    """compute_mel_spectrogram: output shape, type and energy ordering."""

    @pytest.fixture(scope="class")
    def mod(self):
        return _import_mod()

    def test_shape(self, mod):
        tone = _sine(440, SR)
        spec = mod.compute_mel_spectrogram(
            tone, SR, n_fft=512, n_mels=32, hop_length=256
        )
        assert spec.shape[0] == 32
        assert spec.shape[1] > 0

    def test_silence_near_zero(self, mod):
        quiet = _silence(SR)
        spec = mod.compute_mel_spectrogram(quiet, SR, n_fft=512, n_mels=32)
        assert spec.mean() < 1e-6

    def test_louder_has_higher_energy(self, mod):
        def mean_energy(amp):
            return mod.compute_mel_spectrogram(_sine(440, SR, amp=amp), SR).mean()

        quiet = mean_energy(0.01)
        loud = mean_energy(0.5)
        assert loud > quiet

    def test_returns_array(self, mod):
        tone = _sine(440, SR)
        spec = mod.compute_mel_spectrogram(tone, SR)
        assert isinstance(spec, np.ndarray)
# ── TestExtractFeatures ───────────────────────────────────────────────────────
class TestExtractFeatures:
    """extract_features: key set, value ranges and simple signal sanity."""

    @pytest.fixture(scope="class")
    def mod(self):
        return _import_mod()

    def _feats(self, mod, samples):
        """Extract features with the settings shared by most tests here."""
        return mod.extract_features(samples, SR, n_fft=512, n_mels=32)

    def test_keys_present(self, mod):
        feats = self._feats(mod, _sine(440, SR))
        expected_keys = ("energy_db", "zcr", "mel_centroid", "mel_flatness",
                         "low_ratio", "high_ratio")
        for key in expected_keys:
            assert key in feats

    def test_silence_low_energy(self, mod):
        feats = self._feats(mod, _silence(SR))
        assert feats["energy_db"] < -40.0

    def test_silence_zero_zcr(self, mod):
        feats = self._feats(mod, _silence(SR))
        assert feats["zcr"] == 0.0

    def test_sine_moderate_energy(self, mod):
        feats = self._feats(mod, _sine(440, SR, amp=0.1))
        assert -40.0 < feats["energy_db"] < 0.0

    def test_ratios_sum_leq_one(self, mod):
        feats = self._feats(mod, _sine(440, SR))
        total = feats["low_ratio"] + feats["high_ratio"]
        assert total <= 1.0 + 1e-6

    def test_ratios_nonnegative(self, mod):
        feats = self._feats(mod, _sine(440, SR))
        assert feats["low_ratio"] >= 0.0
        assert feats["high_ratio"] >= 0.0

    def test_flatness_in_unit_interval(self, mod):
        feats = self._feats(mod, _sine(440, SR))
        assert 0.0 <= feats["mel_flatness"] <= 1.0

    def test_white_noise_high_flatness(self, mod):
        # A pure tone concentrates energy in few bands, noise spreads it,
        # so noise must score the higher spectral flatness.
        noise_flatness = self._feats(mod, _white_noise(SR, amp=0.3))["mel_flatness"]
        tone_flatness = self._feats(mod, _sine(440, SR, amp=0.3))["mel_flatness"]
        assert noise_flatness > tone_flatness

    def test_empty_samples(self, mod):
        feats = mod.extract_features([], SR)
        assert feats["energy_db"] == 0.0

    def test_louder_higher_energy_db(self, mod):
        def energy_db(amp):
            return self._feats(mod, _sine(440, SR, amp=amp))["energy_db"]

        quiet = energy_db(0.01)
        loud = energy_db(0.5)
        assert loud > quiet
# ── TestClassifier ────────────────────────────────────────────────────────────
class TestClassifier:
    """classify(): one test per rule of the cascade plus a fuzz check."""

    @pytest.fixture(scope="class")
    def mod(self):
        return _import_mod()

    def _cls(self, mod, **feat_overrides):
        """Classify a neutral baseline feature set with selected overrides."""
        base = {"energy_db": -20.0, "zcr": 0.05,
                "mel_centroid": 0.4, "mel_flatness": 0.2,
                "low_ratio": 0.4, "high_ratio": 0.2}
        base.update(feat_overrides)
        return mod.classify(base)

    def test_silence(self, mod):
        assert self._cls(mod, energy_db=-45.0) == "silence"

    def test_silence_at_threshold(self, mod):
        # The silence rule is strict (<), so exactly -40 dB is not silence.
        assert self._cls(mod, energy_db=-40.0) != "silence"

    def test_alarm(self, mod):
        assert self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.40) == "alarm"

    def test_alarm_requires_high_ratio(self, mod):
        result = self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.10)
        assert result != "alarm"

    def test_speech(self, mod):
        assert self._cls(mod, energy_db=-25.0, zcr=0.08,
                         mel_flatness=0.20) == "speech"

    def test_speech_zcr_too_low(self, mod):
        result = self._cls(mod, energy_db=-25.0, zcr=0.005, mel_flatness=0.2)
        assert result != "speech"

    def test_speech_zcr_too_high(self, mod):
        result = self._cls(mod, energy_db=-25.0, zcr=0.30, mel_flatness=0.2)
        assert result != "speech"

    def test_music(self, mod):
        assert self._cls(mod, energy_db=-25.0, zcr=0.04,
                         mel_flatness=0.10) == "music"

    def test_crowd(self, mod):
        assert self._cls(mod, energy_db=-25.0, zcr=0.15,
                         mel_flatness=0.40) == "crowd"

    def test_outdoor_catchall(self, mod):
        # Energy below the alarm floor rules out alarm, flatness 0.40 rules
        # out music and speech, and ZCR 0.05 rules out crowd, so only the
        # catch-all remains.  (The previous input, zcr=0.06 / flatness=0.30,
        # satisfies the speech rule and never reached the catch-all.)
        result = self._cls(mod, energy_db=-35.0, zcr=0.05, mel_flatness=0.40)
        assert result == "outdoor"

    def test_returns_valid_label(self, mod):
        # Seeded fuzzing: any feature combination must map to a known label.
        import random
        rng = random.Random(0)
        for _ in range(20):
            f = {
                "energy_db": rng.uniform(-60, 0),
                "zcr": rng.uniform(0, 0.5),
                "mel_centroid": rng.uniform(0, 1),
                "mel_flatness": rng.uniform(0, 1),
                "low_ratio": rng.uniform(0, 0.6),
                "high_ratio": rng.uniform(0, 0.4),
            }
            assert mod.classify(f) in mod.LABELS
# ── TestAudioBuffer ───────────────────────────────────────────────────────────
class TestAudioBuffer:
    """AudioBuffer: window emission, carry-over of surplus samples, clear."""

    @pytest.fixture(scope="class")
    def mod(self):
        return _import_mod()

    def test_no_window_until_full(self, mod):
        buf = mod.AudioBuffer(window_samples=100)
        assert buf.push([0.0] * 50) is None

    def test_exact_fill_returns_window(self, mod):
        buf = mod.AudioBuffer(window_samples=100)
        w = buf.push([0.0] * 100)
        assert w is not None and len(w) == 100

    def test_overflow_carries_over(self, mod):
        # Push 150 samples: the first 100 form a window, the surplus 50 must
        # be kept and lead the next window.  (The previous version pushed
        # exactly 100 twice, so nothing ever overflowed.)
        buf = mod.AudioBuffer(window_samples=100)
        w1 = buf.push([0.0] * 100 + [1.0] * 50)
        assert w1 == [0.0] * 100
        w2 = buf.push([2.0] * 50)
        assert w2 == [1.0] * 50 + [2.0] * 50

    def test_partial_then_complete(self, mod):
        buf = mod.AudioBuffer(window_samples=100)
        buf.push([0.0] * 60)
        w = buf.push([0.0] * 60)
        assert w is not None and len(w) == 100

    def test_clear_resets(self, mod):
        # After clear() the 90 earlier samples must not count any more.
        buf = mod.AudioBuffer(window_samples=100)
        buf.push([0.0] * 90)
        buf.clear()
        assert buf.push([0.0] * 90) is None

    def test_window_contents_correct(self, mod):
        buf = mod.AudioBuffer(window_samples=4)
        w = buf.push([1.0, 2.0, 3.0, 4.0])
        assert w == [1.0, 2.0, 3.0, 4.0]
# ── TestNodeSrc ───────────────────────────────────────────────────────────────
class TestNodeSrc:
    """Text-level checks that the node source contains the expected names.

    These tests only search the source text for substrings; they do not
    import or execute the node.
    """

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/ambient_sound_node.py")

    # -- definitions ------------------------------------------------------
    def test_class_defined(self, src):
        assert "class AmbientSoundNode" in src

    def test_audio_buffer(self, src):
        assert "class AudioBuffer" in src

    def test_extract_features(self, src):
        assert "def extract_features" in src

    def test_classify_fn(self, src):
        assert "def classify" in src

    def test_mel_spectrogram(self, src):
        assert "compute_mel_spectrogram" in src

    def test_mel_filterbank(self, src):
        assert "build_mel_filterbank" in src

    def test_hz_to_mel(self, src):
        assert "hz_to_mel" in src

    # -- labels -----------------------------------------------------------
    def test_labels_tuple(self, src):
        assert "LABELS" in src

    def test_all_labels(self, src):
        wanted = ("silence", "speech", "music", "crowd", "outdoor", "alarm")
        for label in wanted:
            assert label in src

    # -- topics -----------------------------------------------------------
    def test_topic_pub(self, src):
        assert '"/saltybot/ambient_sound"' in src

    def test_topic_sub(self, src):
        assert '"/social/speech/audio_raw"' in src

    # -- parameters -------------------------------------------------------
    def test_window_param(self, src):
        assert '"window_s"' in src

    def test_n_mels_param(self, src):
        assert '"n_mels"' in src

    def test_silence_param(self, src):
        assert '"silence_db"' in src

    def test_alarm_param(self, src):
        assert '"alarm_db_min"' in src

    def test_speech_param(self, src):
        assert '"speech_zcr_min"' in src

    def test_music_param(self, src):
        assert '"music_zcr_max"' in src

    def test_crowd_param(self, src):
        assert '"crowd_zcr_min"' in src

    # -- message types and misc -------------------------------------------
    def test_string_pub(self, src):
        assert "String" in src

    def test_uint8_sub(self, src):
        assert "UInt8MultiArray" in src

    def test_issue_tag(self, src):
        assert "252" in src

    def test_main(self, src):
        assert "def main" in src

    def test_numpy_optional(self, src):
        assert "_NUMPY" in src
# ── TestConfig ────────────────────────────────────────────────────────────────
class TestConfig:
    """Text-level checks of the params YAML and the setup.py entry point."""

    @pytest.fixture(scope="class")
    def cfg(self):
        return _read_src("config/ambient_sound_params.yaml")

    @pytest.fixture(scope="class")
    def setup(self):
        return _read_src("setup.py")

    def test_node_name(self, cfg):
        assert "ambient_sound_node:" in cfg

    def test_window_s(self, cfg):
        assert "window_s" in cfg

    def test_n_mels(self, cfg):
        assert "n_mels" in cfg

    def test_silence_db(self, cfg):
        assert "silence_db" in cfg

    def test_alarm_params(self, cfg):
        assert "alarm_db_min" in cfg

    def test_speech_params(self, cfg):
        assert "speech_zcr_min" in cfg

    def test_music_params(self, cfg):
        assert "music_zcr_max" in cfg

    def test_crowd_params(self, cfg):
        assert "crowd_zcr_min" in cfg

    def test_defaults_present(self, cfg):
        # Spot-check two default values by their literal text.
        assert "-40.0" in cfg
        assert "0.12" in cfg

    def test_entry_point(self, setup):
        entry = "ambient_sound_node = saltybot_social.ambient_sound_node:main"
        assert entry in setup