From 4919dc0bc6dee0e0574ca74b43a4cf521241b38b Mon Sep 17 00:00:00 2001 From: sl-jetson Date: Mon, 2 Mar 2026 12:25:54 -0500 Subject: [PATCH] feat(social): energy+ZCR voice activity detection node (Issue #242) Add vad_node to saltybot_social: subscribes to /social/speech/audio_raw (UInt8MultiArray PCM-16), computes RMS energy (dBFS) and zero-crossing rate per chunk, applies onset/offset hysteresis (VadStateMachine), and publishes /social/speech/is_speaking (Bool) and /social/speech/energy (Float32 linear RMS). All thresholds configurable via ROS params: rms_threshold_db=-35.0, zcr_min=0.01, zcr_max=0.40, onset_frames=2, offset_frames=8, audio_topic. 69/69 tests passing. Co-Authored-By: Claude Sonnet 4.6 --- .../saltybot_social/config/vad_params.yaml | 9 + .../src/saltybot_social/launch/vad.launch.py | 48 +++ .../saltybot_social/vad_node.py | 192 +++++++++++ jetson/ros2_ws/src/saltybot_social/setup.py | 2 + .../src/saltybot_social/test/test_vad_node.py | 300 ++++++++++++++++++ 5 files changed, 551 insertions(+) create mode 100644 jetson/ros2_ws/src/saltybot_social/config/vad_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_social/launch/vad.launch.py create mode 100644 jetson/ros2_ws/src/saltybot_social/saltybot_social/vad_node.py create mode 100644 jetson/ros2_ws/src/saltybot_social/test/test_vad_node.py diff --git a/jetson/ros2_ws/src/saltybot_social/config/vad_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/vad_params.yaml new file mode 100644 index 0000000..41001bc --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/config/vad_params.yaml @@ -0,0 +1,9 @@ +vad_node: + ros__parameters: + sample_rate: 16000 # Expected sample rate of incoming PCM-16 audio (Hz) + rms_threshold_db: -35.0 # Energy gate (dBFS); frames below this are silent + zcr_min: 0.01 # ZCR lower bound — rejects DC/low-freq rumble + zcr_max: 0.40 # ZCR upper bound — rejects high-freq noise + onset_frames: 2 # Consecutive active frames before is_speaking=true + offset_frames: 8 # 
Consecutive silent frames before is_speaking=false + audio_topic: "/social/speech/audio_raw" # Source PCM-16 UInt8MultiArray topic diff --git a/jetson/ros2_ws/src/saltybot_social/launch/vad.launch.py b/jetson/ros2_ws/src/saltybot_social/launch/vad.launch.py new file mode 100644 index 0000000..511053c --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/launch/vad.launch.py @@ -0,0 +1,48 @@ +"""vad.launch.py — Launch the energy+ZCR VAD node (Issue #242). + +Usage: + ros2 launch saltybot_social vad.launch.py + ros2 launch saltybot_social vad.launch.py rms_threshold_db:=-40.0 +""" + +import os +from ament_index_python.packages import get_package_share_directory +from launch import LaunchDescription +from launch.actions import DeclareLaunchArgument +from launch.substitutions import LaunchConfiguration +from launch_ros.actions import Node + + +def generate_launch_description(): + pkg = get_package_share_directory("saltybot_social") + cfg = os.path.join(pkg, "config", "vad_params.yaml") + + return LaunchDescription([ + DeclareLaunchArgument("rms_threshold_db", default_value="-35.0", + description="Energy gate in dBFS"), + DeclareLaunchArgument("zcr_min", default_value="0.01", + description="ZCR lower bound"), + DeclareLaunchArgument("zcr_max", default_value="0.40", + description="ZCR upper bound"), + DeclareLaunchArgument("onset_frames", default_value="2", + description="Onset hysteresis frames"), + DeclareLaunchArgument("offset_frames", default_value="8", + description="Offset hysteresis frames"), + + Node( + package="saltybot_social", + executable="vad_node", + name="vad_node", + output="screen", + parameters=[ + cfg, + { + "rms_threshold_db": LaunchConfiguration("rms_threshold_db"), + "zcr_min": LaunchConfiguration("zcr_min"), + "zcr_max": LaunchConfiguration("zcr_max"), + "onset_frames": LaunchConfiguration("onset_frames"), + "offset_frames": LaunchConfiguration("offset_frames"), + }, + ], + ), + ]) diff --git 
"""vad_node.py — Energy + zero-crossing rate voice activity detection.
Issue #242

Subscribes to raw PCM-16 audio on /social/speech/audio_raw, computes
per-chunk RMS energy (dBFS) and zero-crossing rate (ZCR), applies
onset/offset hysteresis, and publishes:

  /social/speech/is_speaking  (std_msgs/Bool)    — VAD decision
  /social/speech/energy       (std_msgs/Float32) — linear RMS [0..1]

The combined decision rule:
  active = energy_db >= rms_threshold_db AND zcr in [zcr_min, zcr_max]

ZCR bands for 16 kHz audio (typical):
  Silence / low-freq rumble  : ZCR < 0.01
  Voiced speech              : ZCR 0.01–0.20
  Unvoiced / sibilants       : ZCR 0.20–0.40
  High-freq noise            : ZCR > 0.40

Parameters:
  sample_rate      (int,   16000)  — expected sample rate of incoming audio
                                     (recorded only; no computation uses it)
  rms_threshold_db (float, -35.0)  — energy gate (dBFS); below = silence
  zcr_min          (float, 0.01)   — ZCR lower bound; below = rumble/DC
  zcr_max          (float, 0.40)   — ZCR upper bound; above = noise
  onset_frames     (int,   2)      — consecutive active frames to set is_speaking
  offset_frames    (int,   8)      — consecutive silent frames to clear is_speaking
  audio_topic      (str,   "/social/speech/audio_raw") — input PCM-16 topic
"""

from __future__ import annotations

import math
import struct
from typing import Optional

import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import Bool, Float32, UInt8MultiArray

# Full-scale magnitude of a signed 16-bit sample. Dividing by 32768 (not 32767)
# maps the most negative sample to exactly -1.0.
INT16_MAX = 32768.0

# Amplitude floor applied before taking log10 so digital silence yields a
# finite value (20 * log10(1e-10) is about -200 dBFS) instead of raising.
_MIN_AMPLITUDE = 1e-10


# ── Pure signal helpers (no ROS) ──────────────────────────────────────────────

def pcm16_bytes_to_float32(data: bytes) -> list[float]:
    """Convert raw little-endian PCM-16 bytes to floats in [-1.0, 1.0).

    Args:
        data: Interleaved signed 16-bit little-endian samples. A trailing odd
            byte (an incomplete sample) is ignored.

    Returns:
        One float per complete sample; an empty list when ``data`` holds fewer
        than two bytes.
    """
    n = len(data) // 2
    if n == 0:
        return []
    # unpack_from reads exactly n samples from the start of the buffer, so a
    # stray trailing byte needs no slicing (and no copy) to be skipped.
    samples = struct.unpack_from(f"<{n}h", data)
    return [s / INT16_MAX for s in samples]


def rms_linear(samples: list) -> float:
    """Return the RMS amplitude of ``samples``.

    Returns 0.0 for empty or all-zero input. For samples normalised to
    [-1.0, 1.0] the result lies in [0.0, 1.0].
    """
    if not samples:
        return 0.0
    mean_sq = sum(s * s for s in samples) / len(samples)
    return math.sqrt(mean_sq) if mean_sq > 0.0 else 0.0


def _rms_to_db(rms: float) -> float:
    """Convert a linear RMS amplitude to dBFS, clamped at the silence floor."""
    return 20.0 * math.log10(max(rms, _MIN_AMPLITUDE))


def rms_db(samples: list) -> float:
    """Return the RMS energy of ``samples`` in dBFS.

    The amplitude is clamped to 1e-10 before the logarithm, so empty or
    all-zero input returns the floor value of about -200 dBFS.
    """
    return _rms_to_db(rms_linear(samples))


def zero_crossing_rate(samples: list) -> float:
    """Return the fraction of consecutive sample pairs where the sign flips.

    Exact-zero samples carry no sign of their own: each non-zero sample is
    compared with the most recent non-zero sample. This keeps a crossing that
    passes *through* a zero sample (``+a, 0, -a``) from being lost, which
    happens routinely with integer PCM. A signal that merely touches zero
    (``-a, 0, -a``) is not counted.

    Returns:
        A value in [0.0, 1.0]; 0.0 for silence or fewer than two samples.
    """
    n = len(samples)
    if n < 2:
        return 0.0

    crossings = 0
    prev_sign = 0  # sign of the last non-zero sample; 0 until one is seen
    for s in samples:
        if s > 0:
            sign = 1
        elif s < 0:
            sign = -1
        else:
            continue  # zero (or NaN): keep the previous sign
        if prev_sign and sign != prev_sign:
            crossings += 1
        prev_sign = sign
    return crossings / (n - 1)


# ── Hysteresis state machine ──────────────────────────────────────────────────

class VadStateMachine:
    """Onset/offset hysteresis on a per-frame boolean signal.

    The output switches to active after ``onset_frames`` consecutive active
    frames and back to inactive after ``offset_frames`` consecutive inactive
    frames. Any frame of the opposite kind resets the run being counted.
    """

    def __init__(self, onset_frames: int = 2, offset_frames: int = 8) -> None:
        self.onset_frames = onset_frames
        self.offset_frames = offset_frames
        self._above = 0        # length of the current run of active frames
        self._below = 0        # length of the current run of inactive frames
        self._active = False   # debounced output

    def update(self, raw_active: bool) -> bool:
        """Feed one frame decision and return the debounced state."""
        if raw_active:
            self._above += 1
            self._below = 0
            if self._above >= self.onset_frames:
                self._active = True
        else:
            self._below += 1
            self._above = 0
            if self._below >= self.offset_frames:
                self._active = False
        return self._active

    def reset(self) -> None:
        """Clear both run counters and force the output to inactive."""
        self._above = 0
        self._below = 0
        self._active = False

    @property
    def is_active(self) -> bool:
        """Current debounced state, without consuming a frame."""
        return self._active


# ── ROS2 node ─────────────────────────────────────────────────────────────────

class VadNode(Node):
    """Energy + ZCR voice activity detector — subscribes to raw audio."""

    def __init__(self) -> None:
        super().__init__("vad_node")

        self.declare_parameter("sample_rate", 16000)
        self.declare_parameter("rms_threshold_db", -35.0)
        self.declare_parameter("zcr_min", 0.01)
        self.declare_parameter("zcr_max", 0.40)
        self.declare_parameter("onset_frames", 2)
        self.declare_parameter("offset_frames", 8)
        self.declare_parameter("audio_topic", "/social/speech/audio_raw")

        # Stored for reference only; ZCR and RMS are per-sample ratios and do
        # not depend on the sample rate.
        self._sample_rate = self.get_parameter("sample_rate").value
        self._rms_thresh = self.get_parameter("rms_threshold_db").value
        self._zcr_min = self.get_parameter("zcr_min").value
        self._zcr_max = self.get_parameter("zcr_max").value
        audio_topic = self.get_parameter("audio_topic").value

        if self._zcr_min > self._zcr_max:
            # An inverted band is an empty band: the detector could never fire.
            self.get_logger().warn(
                f"zcr_min ({self._zcr_min}) > zcr_max ({self._zcr_max}): "
                "no frame can pass the ZCR gate, is_speaking will stay false"
            )

        self._sm = VadStateMachine(
            onset_frames=self.get_parameter("onset_frames").value,
            offset_frames=self.get_parameter("offset_frames").value,
        )

        qos = QoSProfile(depth=10)
        self._speaking_pub = self.create_publisher(Bool, "/social/speech/is_speaking", qos)
        self._energy_pub = self.create_publisher(Float32, "/social/speech/energy", qos)

        self._audio_sub = self.create_subscription(
            UInt8MultiArray, audio_topic, self._on_audio, qos
        )

        self.get_logger().info(
            f"VadNode ready (rms_thresh={self._rms_thresh} dBFS, "
            f"zcr=[{self._zcr_min},{self._zcr_max}], "
            f"topic={audio_topic})"
        )

    def _on_audio(self, msg: UInt8MultiArray) -> None:
        """Classify one audio chunk and publish the decision and its energy.

        Chunks that contain no complete sample are dropped without publishing
        and without advancing the hysteresis state.
        """
        samples = pcm16_bytes_to_float32(bytes(msg.data))
        if not samples:
            return

        energy_lin = rms_linear(samples)
        # Derive dBFS from the RMS already computed rather than summing the
        # squares of every sample a second time.
        energy_db = _rms_to_db(energy_lin)
        zcr = zero_crossing_rate(samples)

        raw_active = (
            energy_db >= self._rms_thresh
            and self._zcr_min <= zcr <= self._zcr_max
        )
        is_speaking = self._sm.update(raw_active)

        bool_msg = Bool()
        bool_msg.data = is_speaking
        self._speaking_pub.publish(bool_msg)

        energy_msg = Float32()
        energy_msg.data = float(energy_lin)
        self._energy_pub.publish(energy_msg)


def main(args: Optional[list] = None) -> None:
    """Entry point: spin a VadNode until interrupted, then clean up."""
    rclpy.init(args=args)
    node = None
    try:
        # Constructed inside the try so a failing constructor still reaches
        # the shutdown below.
        node = VadNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            node.destroy_node()
        # The context may already have been shut down by the signal handler;
        # shutting it down a second time raises.
        if rclpy.ok():
            rclpy.shutdown()
b/jetson/ros2_ws/src/saltybot_social/setup.py index 0f596d0..b9eb5e8 100644 --- a/jetson/ros2_ws/src/saltybot_social/setup.py +++ b/jetson/ros2_ws/src/saltybot_social/setup.py @@ -43,6 +43,8 @@ setup( 'emotion_node = saltybot_social.emotion_node:main', # Robot mesh communication (Issue #171) 'mesh_comms_node = saltybot_social.mesh_comms_node:main', + # Energy+ZCR voice activity detection (Issue #242) + 'vad_node = saltybot_social.vad_node:main', ], }, ) diff --git a/jetson/ros2_ws/src/saltybot_social/test/test_vad_node.py b/jetson/ros2_ws/src/saltybot_social/test/test_vad_node.py new file mode 100644 index 0000000..a60dc8b --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/test/test_vad_node.py @@ -0,0 +1,300 @@ +"""test_vad_node.py -- Unit tests for Issue #242 energy+ZCR VAD node.""" + +from __future__ import annotations +import importlib.util, math, os, struct, sys, types +import pytest + + +def _pkg_root(): + return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +def _read_src(rel_path): + with open(os.path.join(_pkg_root(), rel_path)) as f: + return f.read() + + +def _import_vad(): + """Import vad_node without a live ROS2 environment.""" + for mod_name in ("rclpy", "rclpy.node", "rclpy.qos", + "std_msgs", "std_msgs.msg"): + if mod_name not in sys.modules: + sys.modules[mod_name] = types.ModuleType(mod_name) + + rclpy_node = sys.modules["rclpy.node"] + rclpy_qos = sys.modules["rclpy.qos"] + std_msg = sys.modules["std_msgs.msg"] + + class _Node: + def __init__(self, *a, **kw): pass + def declare_parameter(self, *a, **kw): pass + def get_parameter(self, name): + defaults = { + "sample_rate": 16000, "rms_threshold_db": -35.0, + "zcr_min": 0.01, "zcr_max": 0.40, + "onset_frames": 2, "offset_frames": 8, + "audio_topic": "/social/speech/audio_raw", + } + class _P: + value = defaults.get(name) + return _P() + def create_publisher(self, *a, **kw): return None + def create_subscription(self, *a, **kw): return None + def get_logger(self): + class _L: 
+ def info(self, *a): pass + def warn(self, *a): pass + def error(self, *a): pass + return _L() + def destroy_node(self): pass + + rclpy_node.Node = _Node + rclpy_qos.QoSProfile = type("QoSProfile", (), {"__init__": lambda s, **kw: None}) + std_msg.Bool = type("Bool", (), {"data": False}) + std_msg.Float32 = type("Float32", (), {"data": 0.0}) + std_msg.UInt8MultiArray = type("UInt8MultiArray", (), {"data": b""}) + sys.modules["rclpy"].init = lambda *a, **kw: None + sys.modules["rclpy"].spin = lambda n: None + sys.modules["rclpy"].ok = lambda: True + sys.modules["rclpy"].shutdown = lambda: None + + spec = importlib.util.spec_from_file_location( + "vad_node_testmod", + os.path.join(_pkg_root(), "saltybot_social", "vad_node.py"), + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _sine(freq_hz, sample_rate, n_samples, amplitude=0.3): + return [amplitude * math.sin(2 * math.pi * freq_hz * i / sample_rate) + for i in range(n_samples)] + + +def _silence(n_samples): + return [0.0] * n_samples + + +def _pcm16_bytes(samples): + ints = [max(-32768, min(32767, int(s * 32768))) for s in samples] + return struct.pack(f"<{len(ints)}h", *ints) + + +# ── pcm16_bytes_to_float32 ──────────────────────────────────────────────────── + +class TestPcm16Convert: + @pytest.fixture(scope="class") + def mod(self): return _import_vad() + + def test_empty(self, mod): + assert mod.pcm16_bytes_to_float32(b"") == [] + + def test_odd_byte_ignored(self, mod): + data = struct.pack("= 0.0 + + +# ── rms_db ──────────────────────────────────────────────────────────────────── + +class TestRmsDb: + @pytest.fixture(scope="class") + def mod(self): return _import_vad() + + def test_silence_returns_low(self, mod): + assert mod.rms_db([0.0] * 100) < -90.0 + + def test_full_scale_near_zero(self, mod): + assert abs(mod.rms_db([1.0] * 100)) < 1.0 + + def 
test_half_amplitude(self, mod): + assert abs(mod.rms_db([0.5] * 100) - (-6.0)) < 1.0 + + def test_louder_is_higher(self, mod): + assert mod.rms_db([0.5] * 100) > mod.rms_db([0.01] * 100) + + def test_below_threshold(self, mod): + assert mod.rms_db([0.001] * 480) < -35.0 + + +# ── zero_crossing_rate ──────────────────────────────────────────────────────── + +class TestZeroCrossingRate: + @pytest.fixture(scope="class") + def mod(self): return _import_vad() + + def test_empty(self, mod): assert mod.zero_crossing_rate([]) == 0.0 + def test_single(self, mod): assert mod.zero_crossing_rate([0.5]) == 0.0 + def test_silence(self, mod): assert mod.zero_crossing_rate([0.0] * 100) == 0.0 + def test_constant_pos(self, mod): assert mod.zero_crossing_rate([0.5] * 100) == 0.0 + + def test_alternating_full(self, mod): + s = [(-1.0) ** i for i in range(100)] + assert abs(mod.zero_crossing_rate(s) - 1.0) < 1e-6 + + def test_sine_in_range(self, mod): + zcr = mod.zero_crossing_rate(_sine(200, 16000, 1600)) + assert 0.01 < zcr < 0.10 + + def test_high_freq_higher_zcr(self, mod): + lo = mod.zero_crossing_rate(_sine(100, 16000, 1600)) + hi = mod.zero_crossing_rate(_sine(4000, 16000, 1600)) + assert hi > lo + + def test_in_unit_interval(self, mod): + zcr = mod.zero_crossing_rate(_sine(440, 16000, 480)) + assert 0.0 <= zcr <= 1.0 + + +# ── VadStateMachine ─────────────────────────────────────────────────────────── + +class TestVadStateMachine: + @pytest.fixture(scope="class") + def mod(self): return _import_vad() + + def test_initial_inactive(self, mod): + assert not mod.VadStateMachine(onset_frames=2, offset_frames=3).is_active + + def test_onset_requires_n_frames(self, mod): + sm = mod.VadStateMachine(onset_frames=3, offset_frames=5) + assert not sm.update(True) + assert not sm.update(True) + assert sm.update(True) + + def test_offset_requires_n_frames(self, mod): + sm = mod.VadStateMachine(onset_frames=1, offset_frames=3) + sm.update(True) + assert sm.update(False) + assert 
sm.update(False) + assert not sm.update(False) + + def test_reset(self, mod): + sm = mod.VadStateMachine(onset_frames=1, offset_frames=1) + sm.update(True) + sm.reset() + assert not sm.is_active + + def test_stays_active_with_speech(self, mod): + sm = mod.VadStateMachine(onset_frames=1, offset_frames=10) + sm.update(True) + for _ in range(20): + assert sm.update(True) + + def test_onset1_offset1(self, mod): + sm = mod.VadStateMachine(onset_frames=1, offset_frames=1) + assert sm.update(True) + assert not sm.update(False) + + +# ── Combined decision logic ─────────────────────────────────────────────────── + +class TestCombinedDecision: + @pytest.fixture(scope="class") + def mod(self): return _import_vad() + + def _decide(self, energy_db, zcr, rms_thresh=-35.0, zcr_min=0.01, zcr_max=0.40): + return (energy_db >= rms_thresh) and (zcr_min <= zcr <= zcr_max) + + def test_normal_speech(self): assert self._decide(-20.0, 0.08) + def test_below_energy_threshold(self): assert not self._decide(-40.0, 0.08) + def test_zcr_too_low(self): assert not self._decide(-20.0, 0.005) + def test_zcr_too_high(self): assert not self._decide(-20.0, 0.50) + def test_energy_at_threshold(self): assert self._decide(-35.0, 0.08) + def test_zcr_at_min_boundary(self): assert self._decide(-20.0, 0.01) + def test_zcr_at_max_boundary(self): assert self._decide(-20.0, 0.40) + def test_loud_noise_high_zcr(self): assert not self._decide(-10.0, 0.50) + + def test_integration_voiced_speech(self, mod): + samples = _sine(300, 16000, 480, amplitude=0.1) + energy_db = mod.rms_db(samples) + zcr = mod.zero_crossing_rate(samples) + assert (energy_db >= -35.0) and (0.01 <= zcr <= 0.40) + + def test_integration_silence(self, mod): + assert mod.rms_db(_silence(480)) < -35.0 + + +# ── Node source checks ──────────────────────────────────────────────────────── + +class TestVadNodeSrc: + @pytest.fixture(scope="class") + def src(self): return _read_src("saltybot_social/vad_node.py") + + def test_class_defined(self, 
src): assert "class VadNode" in src + def test_state_machine(self, src): assert "class VadStateMachine" in src + def test_rms_threshold_param(self, src): assert '"rms_threshold_db"' in src + def test_zcr_min_param(self, src): assert '"zcr_min"' in src + def test_zcr_max_param(self, src): assert '"zcr_max"' in src + def test_onset_frames_param(self, src): assert '"onset_frames"' in src + def test_offset_frames_param(self, src): assert '"offset_frames"' in src + def test_audio_topic_param(self, src): assert '"audio_topic"' in src + def test_is_speaking_topic(self, src): assert '"/social/speech/is_speaking"' in src + def test_energy_topic(self, src): assert '"/social/speech/energy"' in src + def test_audio_raw_default(self, src): assert '"/social/speech/audio_raw"' in src + def test_bool_pub(self, src): assert "Bool" in src + def test_float32_pub(self, src): assert "Float32" in src + def test_uint8_sub(self, src): assert "UInt8MultiArray" in src + def test_rms_fn(self, src): assert "rms_db" in src + def test_zcr_fn(self, src): assert "zero_crossing_rate" in src + def test_pcm_convert(self, src): assert "pcm16_bytes_to_float32" in src + def test_hysteresis(self, src): assert "onset_frames" in src and "offset_frames" in src + def test_issue_tag(self, src): assert "242" in src + def test_main(self, src): assert "def main" in src + + +# ── Config + setup.py ───────────────────────────────────────────────────────── + +class TestConfig: + @pytest.fixture(scope="class") + def cfg(self): return _read_src("config/vad_params.yaml") + + @pytest.fixture(scope="class") + def setup(self): return _read_src("setup.py") + + def test_node_name(self, cfg): assert "vad_node:" in cfg + def test_rms_param(self, cfg): assert "rms_threshold_db" in cfg + def test_zcr_min(self, cfg): assert "zcr_min" in cfg + def test_zcr_max(self, cfg): assert "zcr_max" in cfg + def test_onset(self, cfg): assert "onset_frames" in cfg + def test_offset(self, cfg): assert "offset_frames" in cfg + def 
test_defaults(self, cfg): assert "-35.0" in cfg and "0.01" in cfg + def test_entry_point(self, setup): assert "vad_node = saltybot_social.vad_node:main" in setup -- 2.47.2