feat(social): energy+ZCR voice activity detection node (Issue #242) #247

Merged
sl-jetson merged 1 commits from sl-jetson/issue-242-vad into main 2026-03-02 12:46:26 -05:00
5 changed files with 551 additions and 0 deletions

View File

@ -0,0 +1,9 @@
vad_node:
ros__parameters:
sample_rate: 16000 # Expected sample rate of incoming PCM-16 audio (Hz)
rms_threshold_db: -35.0 # Energy gate (dBFS); frames below this are silent
zcr_min: 0.01 # ZCR lower bound — rejects DC/low-freq rumble
zcr_max: 0.40 # ZCR upper bound — rejects high-freq noise
onset_frames: 2 # Consecutive active frames before is_speaking=true
offset_frames: 8 # Consecutive silent frames before is_speaking=false
audio_topic: "/social/speech/audio_raw" # Source PCM-16 UInt8MultiArray topic

View File

@ -0,0 +1,48 @@
"""vad.launch.py — Launch the energy+ZCR VAD node (Issue #242).
Usage:
ros2 launch saltybot_social vad.launch.py
ros2 launch saltybot_social vad.launch.py rms_threshold_db:=-40.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg = get_package_share_directory("saltybot_social")
cfg = os.path.join(pkg, "config", "vad_params.yaml")
return LaunchDescription([
DeclareLaunchArgument("rms_threshold_db", default_value="-35.0",
description="Energy gate in dBFS"),
DeclareLaunchArgument("zcr_min", default_value="0.01",
description="ZCR lower bound"),
DeclareLaunchArgument("zcr_max", default_value="0.40",
description="ZCR upper bound"),
DeclareLaunchArgument("onset_frames", default_value="2",
description="Onset hysteresis frames"),
DeclareLaunchArgument("offset_frames", default_value="8",
description="Offset hysteresis frames"),
Node(
package="saltybot_social",
executable="vad_node",
name="vad_node",
output="screen",
parameters=[
cfg,
{
"rms_threshold_db": LaunchConfiguration("rms_threshold_db"),
"zcr_min": LaunchConfiguration("zcr_min"),
"zcr_max": LaunchConfiguration("zcr_max"),
"onset_frames": LaunchConfiguration("onset_frames"),
"offset_frames": LaunchConfiguration("offset_frames"),
},
],
),
])

View File

@ -0,0 +1,192 @@
"""vad_node.py — Energy + zero-crossing rate voice activity detection.
Issue #242
Subscribes to raw PCM-16 audio on /social/speech/audio_raw, computes
per-chunk RMS energy (dBFS) and zero-crossing rate (ZCR), applies
onset/offset hysteresis, and publishes:
/social/speech/is_speaking (std_msgs/Bool) VAD decision
/social/speech/energy (std_msgs/Float32) linear RMS [0..1]
The combined decision rule:
active = energy_db >= rms_threshold_db AND zcr in [zcr_min, zcr_max]
ZCR bands for 16 kHz audio (typical):
Silence / low-freq rumble : ZCR < 0.01
Voiced speech : ZCR 0.010.20
Unvoiced / sibilants : ZCR 0.200.40
High-freq noise : ZCR > 0.40
Parameters:
sample_rate (int, 16000) expected sample rate of incoming audio
rms_threshold_db (float, -35.0) energy gate (dBFS); below = silence
zcr_min (float, 0.01) ZCR lower bound; below = rumble/DC
zcr_max (float, 0.40) ZCR upper bound; above = noise
onset_frames (int, 2) consecutive active frames to set is_speaking
offset_frames (int, 8) consecutive silent frames to clear is_speaking
audio_topic (str, "/social/speech/audio_raw") input PCM-16 topic
"""
from __future__ import annotations
import math
import struct
from typing import Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import Bool, Float32, UInt8MultiArray
INT16_MAX = 32768.0
# ── Pure signal helpers (no ROS) ──────────────────────────────────────────────
def pcm16_bytes_to_float32(data: bytes) -> list:
"""Convert raw PCM-16 LE bytes → float32 list in [-1.0, 1.0]."""
n = len(data) // 2
if n == 0:
return []
samples = struct.unpack(f"<{n}h", data[: n * 2])
return [s / INT16_MAX for s in samples]
def rms_linear(samples: list) -> float:
"""RMS amplitude in [0.0, 1.0]. Returns 0.0 for empty / silent input."""
if not samples:
return 0.0
mean_sq = sum(s * s for s in samples) / len(samples)
return math.sqrt(mean_sq) if mean_sq > 0.0 else 0.0
def rms_db(samples: list) -> float:
"""RMS energy in dBFS. Returns -96.0 for silence."""
rms = rms_linear(samples)
return 20.0 * math.log10(max(rms, 1e-10))
def zero_crossing_rate(samples: list) -> float:
"""Zero-crossing rate: fraction of consecutive pairs with opposite signs.
Returns a value in [0.0, 1.0]. Silence returns 0.0.
"""
n = len(samples)
if n < 2:
return 0.0
crossings = sum(
1 for i in range(1, n) if samples[i - 1] * samples[i] < 0
)
return crossings / (n - 1)
# ── Hysteresis state machine ──────────────────────────────────────────────────
class VadStateMachine:
"""Onset/offset hysteresis on a per-frame boolean signal."""
def __init__(self, onset_frames: int = 2, offset_frames: int = 8) -> None:
self.onset_frames = onset_frames
self.offset_frames = offset_frames
self._above = 0
self._below = 0
self._active = False
def update(self, raw_active: bool) -> bool:
if raw_active:
self._above += 1
self._below = 0
if self._above >= self.onset_frames:
self._active = True
else:
self._below += 1
self._above = 0
if self._below >= self.offset_frames:
self._active = False
return self._active
def reset(self) -> None:
self._above = 0
self._below = 0
self._active = False
@property
def is_active(self) -> bool:
return self._active
# ── ROS2 node ─────────────────────────────────────────────────────────────────
class VadNode(Node):
"""Energy + ZCR voice activity detector — subscribes to raw audio."""
def __init__(self) -> None:
super().__init__("vad_node")
self.declare_parameter("sample_rate", 16000)
self.declare_parameter("rms_threshold_db", -35.0)
self.declare_parameter("zcr_min", 0.01)
self.declare_parameter("zcr_max", 0.40)
self.declare_parameter("onset_frames", 2)
self.declare_parameter("offset_frames", 8)
self.declare_parameter("audio_topic", "/social/speech/audio_raw")
self._sample_rate = self.get_parameter("sample_rate").value
self._rms_thresh = self.get_parameter("rms_threshold_db").value
self._zcr_min = self.get_parameter("zcr_min").value
self._zcr_max = self.get_parameter("zcr_max").value
audio_topic = self.get_parameter("audio_topic").value
self._sm = VadStateMachine(
onset_frames = self.get_parameter("onset_frames").value,
offset_frames = self.get_parameter("offset_frames").value,
)
qos = QoSProfile(depth=10)
self._speaking_pub = self.create_publisher(Bool, "/social/speech/is_speaking", qos)
self._energy_pub = self.create_publisher(Float32, "/social/speech/energy", qos)
self._audio_sub = self.create_subscription(
UInt8MultiArray, audio_topic, self._on_audio, qos
)
self.get_logger().info(
f"VadNode ready (rms_thresh={self._rms_thresh} dBFS, "
f"zcr=[{self._zcr_min},{self._zcr_max}], "
f"topic={audio_topic})"
)
def _on_audio(self, msg: UInt8MultiArray) -> None:
samples = pcm16_bytes_to_float32(bytes(msg.data))
if not samples:
return
energy_lin = rms_linear(samples)
energy_db = rms_db(samples)
zcr = zero_crossing_rate(samples)
raw_active = (
energy_db >= self._rms_thresh
and self._zcr_min <= zcr <= self._zcr_max
)
is_speaking = self._sm.update(raw_active)
bool_msg = Bool()
bool_msg.data = is_speaking
self._speaking_pub.publish(bool_msg)
energy_msg = Float32()
energy_msg.data = float(energy_lin)
self._energy_pub.publish(energy_msg)
def main(args: Optional[list] = None) -> None:
rclpy.init(args=args)
node = VadNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -43,6 +43,8 @@ setup(
'emotion_node = saltybot_social.emotion_node:main',
# Robot mesh communication (Issue #171)
'mesh_comms_node = saltybot_social.mesh_comms_node:main',
# Energy+ZCR voice activity detection (Issue #242)
'vad_node = saltybot_social.vad_node:main',
],
},
)

View File

@ -0,0 +1,300 @@
"""test_vad_node.py -- Unit tests for Issue #242 energy+ZCR VAD node."""
from __future__ import annotations
import importlib.util, math, os, struct, sys, types
import pytest
def _pkg_root():
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def _read_src(rel_path):
with open(os.path.join(_pkg_root(), rel_path)) as f:
return f.read()
def _import_vad():
"""Import vad_node without a live ROS2 environment."""
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
"std_msgs", "std_msgs.msg"):
if mod_name not in sys.modules:
sys.modules[mod_name] = types.ModuleType(mod_name)
rclpy_node = sys.modules["rclpy.node"]
rclpy_qos = sys.modules["rclpy.qos"]
std_msg = sys.modules["std_msgs.msg"]
class _Node:
def __init__(self, *a, **kw): pass
def declare_parameter(self, *a, **kw): pass
def get_parameter(self, name):
defaults = {
"sample_rate": 16000, "rms_threshold_db": -35.0,
"zcr_min": 0.01, "zcr_max": 0.40,
"onset_frames": 2, "offset_frames": 8,
"audio_topic": "/social/speech/audio_raw",
}
class _P:
value = defaults.get(name)
return _P()
def create_publisher(self, *a, **kw): return None
def create_subscription(self, *a, **kw): return None
def get_logger(self):
class _L:
def info(self, *a): pass
def warn(self, *a): pass
def error(self, *a): pass
return _L()
def destroy_node(self): pass
rclpy_node.Node = _Node
rclpy_qos.QoSProfile = type("QoSProfile", (), {"__init__": lambda s, **kw: None})
std_msg.Bool = type("Bool", (), {"data": False})
std_msg.Float32 = type("Float32", (), {"data": 0.0})
std_msg.UInt8MultiArray = type("UInt8MultiArray", (), {"data": b""})
sys.modules["rclpy"].init = lambda *a, **kw: None
sys.modules["rclpy"].spin = lambda n: None
sys.modules["rclpy"].ok = lambda: True
sys.modules["rclpy"].shutdown = lambda: None
spec = importlib.util.spec_from_file_location(
"vad_node_testmod",
os.path.join(_pkg_root(), "saltybot_social", "vad_node.py"),
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
# ── helpers ───────────────────────────────────────────────────────────────────
def _sine(freq_hz, sample_rate, n_samples, amplitude=0.3):
return [amplitude * math.sin(2 * math.pi * freq_hz * i / sample_rate)
for i in range(n_samples)]
def _silence(n_samples):
return [0.0] * n_samples
def _pcm16_bytes(samples):
ints = [max(-32768, min(32767, int(s * 32768))) for s in samples]
return struct.pack(f"<{len(ints)}h", *ints)
# ── pcm16_bytes_to_float32 ────────────────────────────────────────────────────
class TestPcm16Convert:
@pytest.fixture(scope="class")
def mod(self): return _import_vad()
def test_empty(self, mod):
assert mod.pcm16_bytes_to_float32(b"") == []
def test_odd_byte_ignored(self, mod):
data = struct.pack("<h", 16384) + b"\x00"
assert len(mod.pcm16_bytes_to_float32(data)) == 1
def test_positive_full_scale(self, mod):
data = struct.pack("<h", 32767)
result = mod.pcm16_bytes_to_float32(data)
assert abs(result[0] - 32767 / 32768.0) < 1e-4
def test_negative(self, mod):
data = struct.pack("<h", -16384)
assert mod.pcm16_bytes_to_float32(data)[0] < 0
def test_silence(self, mod):
data = struct.pack("<4h", 0, 0, 0, 0)
assert all(s == 0.0 for s in mod.pcm16_bytes_to_float32(data))
def test_roundtrip_length(self, mod):
n = 480
data = _pcm16_bytes(_sine(440, 16000, n, 0.5))
assert len(mod.pcm16_bytes_to_float32(data)) == n
# ── rms_linear ────────────────────────────────────────────────────────────────
class TestRmsLinear:
@pytest.fixture(scope="class")
def mod(self): return _import_vad()
def test_empty(self, mod): assert mod.rms_linear([]) == 0.0
def test_silence(self, mod): assert mod.rms_linear([0.0] * 100) == 0.0
def test_full_scale(self, mod): assert abs(mod.rms_linear([1.0] * 100) - 1.0) < 1e-6
def test_half_amplitude(self, mod): assert abs(mod.rms_linear([0.5] * 100) - 0.5) < 1e-6
def test_sine_rms(self, mod):
s = _sine(440, 16000, 1600, amplitude=1.0)
assert abs(mod.rms_linear(s) - 1.0 / math.sqrt(2)) < 0.01
def test_nonnegative(self, mod):
assert mod.rms_linear(_sine(200, 16000, 480, 0.3)) >= 0.0
# ── rms_db ────────────────────────────────────────────────────────────────────
class TestRmsDb:
@pytest.fixture(scope="class")
def mod(self): return _import_vad()
def test_silence_returns_low(self, mod):
assert mod.rms_db([0.0] * 100) < -90.0
def test_full_scale_near_zero(self, mod):
assert abs(mod.rms_db([1.0] * 100)) < 1.0
def test_half_amplitude(self, mod):
assert abs(mod.rms_db([0.5] * 100) - (-6.0)) < 1.0
def test_louder_is_higher(self, mod):
assert mod.rms_db([0.5] * 100) > mod.rms_db([0.01] * 100)
def test_below_threshold(self, mod):
assert mod.rms_db([0.001] * 480) < -35.0
# ── zero_crossing_rate ────────────────────────────────────────────────────────
class TestZeroCrossingRate:
@pytest.fixture(scope="class")
def mod(self): return _import_vad()
def test_empty(self, mod): assert mod.zero_crossing_rate([]) == 0.0
def test_single(self, mod): assert mod.zero_crossing_rate([0.5]) == 0.0
def test_silence(self, mod): assert mod.zero_crossing_rate([0.0] * 100) == 0.0
def test_constant_pos(self, mod): assert mod.zero_crossing_rate([0.5] * 100) == 0.0
def test_alternating_full(self, mod):
s = [(-1.0) ** i for i in range(100)]
assert abs(mod.zero_crossing_rate(s) - 1.0) < 1e-6
def test_sine_in_range(self, mod):
zcr = mod.zero_crossing_rate(_sine(200, 16000, 1600))
assert 0.01 < zcr < 0.10
def test_high_freq_higher_zcr(self, mod):
lo = mod.zero_crossing_rate(_sine(100, 16000, 1600))
hi = mod.zero_crossing_rate(_sine(4000, 16000, 1600))
assert hi > lo
def test_in_unit_interval(self, mod):
zcr = mod.zero_crossing_rate(_sine(440, 16000, 480))
assert 0.0 <= zcr <= 1.0
# ── VadStateMachine ───────────────────────────────────────────────────────────
class TestVadStateMachine:
@pytest.fixture(scope="class")
def mod(self): return _import_vad()
def test_initial_inactive(self, mod):
assert not mod.VadStateMachine(onset_frames=2, offset_frames=3).is_active
def test_onset_requires_n_frames(self, mod):
sm = mod.VadStateMachine(onset_frames=3, offset_frames=5)
assert not sm.update(True)
assert not sm.update(True)
assert sm.update(True)
def test_offset_requires_n_frames(self, mod):
sm = mod.VadStateMachine(onset_frames=1, offset_frames=3)
sm.update(True)
assert sm.update(False)
assert sm.update(False)
assert not sm.update(False)
def test_reset(self, mod):
sm = mod.VadStateMachine(onset_frames=1, offset_frames=1)
sm.update(True)
sm.reset()
assert not sm.is_active
def test_stays_active_with_speech(self, mod):
sm = mod.VadStateMachine(onset_frames=1, offset_frames=10)
sm.update(True)
for _ in range(20):
assert sm.update(True)
def test_onset1_offset1(self, mod):
sm = mod.VadStateMachine(onset_frames=1, offset_frames=1)
assert sm.update(True)
assert not sm.update(False)
# ── Combined decision logic ───────────────────────────────────────────────────
class TestCombinedDecision:
@pytest.fixture(scope="class")
def mod(self): return _import_vad()
def _decide(self, energy_db, zcr, rms_thresh=-35.0, zcr_min=0.01, zcr_max=0.40):
return (energy_db >= rms_thresh) and (zcr_min <= zcr <= zcr_max)
def test_normal_speech(self): assert self._decide(-20.0, 0.08)
def test_below_energy_threshold(self): assert not self._decide(-40.0, 0.08)
def test_zcr_too_low(self): assert not self._decide(-20.0, 0.005)
def test_zcr_too_high(self): assert not self._decide(-20.0, 0.50)
def test_energy_at_threshold(self): assert self._decide(-35.0, 0.08)
def test_zcr_at_min_boundary(self): assert self._decide(-20.0, 0.01)
def test_zcr_at_max_boundary(self): assert self._decide(-20.0, 0.40)
def test_loud_noise_high_zcr(self): assert not self._decide(-10.0, 0.50)
def test_integration_voiced_speech(self, mod):
samples = _sine(300, 16000, 480, amplitude=0.1)
energy_db = mod.rms_db(samples)
zcr = mod.zero_crossing_rate(samples)
assert (energy_db >= -35.0) and (0.01 <= zcr <= 0.40)
def test_integration_silence(self, mod):
assert mod.rms_db(_silence(480)) < -35.0
# ── Node source checks ────────────────────────────────────────────────────────
class TestVadNodeSrc:
@pytest.fixture(scope="class")
def src(self): return _read_src("saltybot_social/vad_node.py")
def test_class_defined(self, src): assert "class VadNode" in src
def test_state_machine(self, src): assert "class VadStateMachine" in src
def test_rms_threshold_param(self, src): assert '"rms_threshold_db"' in src
def test_zcr_min_param(self, src): assert '"zcr_min"' in src
def test_zcr_max_param(self, src): assert '"zcr_max"' in src
def test_onset_frames_param(self, src): assert '"onset_frames"' in src
def test_offset_frames_param(self, src): assert '"offset_frames"' in src
def test_audio_topic_param(self, src): assert '"audio_topic"' in src
def test_is_speaking_topic(self, src): assert '"/social/speech/is_speaking"' in src
def test_energy_topic(self, src): assert '"/social/speech/energy"' in src
def test_audio_raw_default(self, src): assert '"/social/speech/audio_raw"' in src
def test_bool_pub(self, src): assert "Bool" in src
def test_float32_pub(self, src): assert "Float32" in src
def test_uint8_sub(self, src): assert "UInt8MultiArray" in src
def test_rms_fn(self, src): assert "rms_db" in src
def test_zcr_fn(self, src): assert "zero_crossing_rate" in src
def test_pcm_convert(self, src): assert "pcm16_bytes_to_float32" in src
def test_hysteresis(self, src): assert "onset_frames" in src and "offset_frames" in src
def test_issue_tag(self, src): assert "242" in src
def test_main(self, src): assert "def main" in src
# ── Config + setup.py ─────────────────────────────────────────────────────────
class TestConfig:
@pytest.fixture(scope="class")
def cfg(self): return _read_src("config/vad_params.yaml")
@pytest.fixture(scope="class")
def setup(self): return _read_src("setup.py")
def test_node_name(self, cfg): assert "vad_node:" in cfg
def test_rms_param(self, cfg): assert "rms_threshold_db" in cfg
def test_zcr_min(self, cfg): assert "zcr_min" in cfg
def test_zcr_max(self, cfg): assert "zcr_max" in cfg
def test_onset(self, cfg): assert "onset_frames" in cfg
def test_offset(self, cfg): assert "offset_frames" in cfg
def test_defaults(self, cfg): assert "-35.0" in cfg and "0.01" in cfg
def test_entry_point(self, setup): assert "vad_node = saltybot_social.vad_node:main" in setup