Merge pull request 'feat(social): energy+ZCR voice activity detection node (Issue #242)' (#247) from sl-jetson/issue-242-vad into main
Some checks failed
Some checks failed
This commit is contained in:
commit
82e836ec3f
@ -0,0 +1,9 @@
|
|||||||
|
vad_node:
|
||||||
|
ros__parameters:
|
||||||
|
sample_rate: 16000 # Expected sample rate of incoming PCM-16 audio (Hz)
|
||||||
|
rms_threshold_db: -35.0 # Energy gate (dBFS); frames below this are silent
|
||||||
|
zcr_min: 0.01 # ZCR lower bound — rejects DC/low-freq rumble
|
||||||
|
zcr_max: 0.40 # ZCR upper bound — rejects high-freq noise
|
||||||
|
onset_frames: 2 # Consecutive active frames before is_speaking=true
|
||||||
|
offset_frames: 8 # Consecutive silent frames before is_speaking=false
|
||||||
|
audio_topic: "/social/speech/audio_raw" # Source PCM-16 UInt8MultiArray topic
|
||||||
48
jetson/ros2_ws/src/saltybot_social/launch/vad.launch.py
Normal file
48
jetson/ros2_ws/src/saltybot_social/launch/vad.launch.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
"""vad.launch.py — Launch the energy+ZCR VAD node (Issue #242).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ros2 launch saltybot_social vad.launch.py
|
||||||
|
ros2 launch saltybot_social vad.launch.py rms_threshold_db:=-40.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Build the launch description for the energy+ZCR VAD node.

    Declares one launch argument per tunable threshold and starts
    ``vad_node`` with the packaged ``config/vad_params.yaml`` followed by a
    dict that maps each of those parameters onto its launch argument.
    """
    share_dir = get_package_share_directory("saltybot_social")
    params_file = os.path.join(share_dir, "config", "vad_params.yaml")

    # (argument name, default value, description). The argument name is also
    # used as the node parameter name.
    tunables = (
        ("rms_threshold_db", "-35.0", "Energy gate in dBFS"),
        ("zcr_min", "0.01", "ZCR lower bound"),
        ("zcr_max", "0.40", "ZCR upper bound"),
        ("onset_frames", "2", "Onset hysteresis frames"),
        ("offset_frames", "8", "Offset hysteresis frames"),
    )

    declared_args = [
        DeclareLaunchArgument(arg, default_value=default, description=text)
        for arg, default, text in tunables
    ]

    # NOTE(review): this dict is listed after the YAML file, so the launch
    # argument values (including their defaults) presumably take precedence
    # over the same keys in vad_params.yaml — confirm against launch_ros.
    overrides = {arg: LaunchConfiguration(arg) for arg, _, _ in tunables}

    vad = Node(
        package="saltybot_social",
        executable="vad_node",
        name="vad_node",
        output="screen",
        parameters=[params_file, overrides],
    )
    return LaunchDescription([*declared_args, vad])
|
||||||
192
jetson/ros2_ws/src/saltybot_social/saltybot_social/vad_node.py
Normal file
192
jetson/ros2_ws/src/saltybot_social/saltybot_social/vad_node.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""vad_node.py — Energy + zero-crossing rate voice activity detection.
|
||||||
|
Issue #242
|
||||||
|
|
||||||
|
Subscribes to raw PCM-16 audio on /social/speech/audio_raw, computes
|
||||||
|
per-chunk RMS energy (dBFS) and zero-crossing rate (ZCR), applies
|
||||||
|
onset/offset hysteresis, and publishes:
|
||||||
|
|
||||||
|
/social/speech/is_speaking (std_msgs/Bool) — VAD decision
|
||||||
|
/social/speech/energy (std_msgs/Float32) — linear RMS [0..1]
|
||||||
|
|
||||||
|
The combined decision rule:
|
||||||
|
active = energy_db >= rms_threshold_db AND zcr in [zcr_min, zcr_max]
|
||||||
|
|
||||||
|
ZCR bands for 16 kHz audio (typical):
|
||||||
|
Silence / low-freq rumble : ZCR < 0.01
|
||||||
|
Voiced speech : ZCR 0.01–0.20
|
||||||
|
Unvoiced / sibilants : ZCR 0.20–0.40
|
||||||
|
High-freq noise : ZCR > 0.40
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
sample_rate (int, 16000) — expected sample rate of incoming audio
|
||||||
|
rms_threshold_db (float, -35.0) — energy gate (dBFS); below = silence
|
||||||
|
zcr_min (float, 0.01) — ZCR lower bound; below = rumble/DC
|
||||||
|
zcr_max (float, 0.40) — ZCR upper bound; above = noise
|
||||||
|
onset_frames (int, 2) — consecutive active frames to set is_speaking
|
||||||
|
offset_frames (int, 8) — consecutive silent frames to clear is_speaking
|
||||||
|
audio_topic (str, "/social/speech/audio_raw") — input PCM-16 topic
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import struct
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile
|
||||||
|
from std_msgs.msg import Bool, Float32, UInt8MultiArray
|
||||||
|
|
||||||
|
# Normalisation divisor for PCM-16 samples (2**15), so -32768 maps to -1.0.
INT16_MAX = 32768.0


# ── Pure signal helpers (no ROS) ──────────────────────────────────────────────

def pcm16_bytes_to_float32(data: bytes) -> list:
    """Decode little-endian signed 16-bit PCM bytes into a list of floats.

    Each sample is divided by ``INT16_MAX`` (32768.0), so values lie in
    [-1.0, 32767/32768]. A trailing odd byte is ignored, and input shorter
    than one sample yields an empty list.
    """
    sample_count = len(data) // 2
    if not sample_count:
        return []
    # Drop a dangling odd byte so the buffer length matches the format.
    whole_samples = data[: 2 * sample_count]
    decoded = struct.unpack("<%dh" % sample_count, whole_samples)
    return [value / INT16_MAX for value in decoded]
|
||||||
|
|
||||||
|
|
||||||
|
def rms_linear(samples: list) -> float:
    """Return the root-mean-square amplitude of *samples*.

    Yields 0.0 for an empty sequence or when the mean square is not
    positive. For samples normalised to [-1.0, 1.0] the result lies in
    [0.0, 1.0].
    """
    count = len(samples)
    if count == 0:
        return 0.0
    power = sum(s * s for s in samples) / count
    if power > 0.0:
        return math.sqrt(power)
    return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def rms_db(samples: list) -> float:
    """Return the RMS energy of *samples* in dBFS, floored at -96.0.

    The level is ``20 * log10(rms_linear(samples))``. Silence, empty input
    and anything quieter than the floor all return exactly -96.0, which is
    roughly the dynamic range of 16-bit PCM.

    The previous implementation clamped the linear RMS at 1e-10 and so
    returned -200.0 for silence, contradicting the documented -96.0.
    """
    floor_db = -96.0
    rms = rms_linear(samples)
    if rms <= 0.0:
        # log10 is undefined at zero; report the floor for digital silence.
        return floor_db
    return max(20.0 * math.log10(rms), floor_db)
|
||||||
|
|
||||||
|
|
||||||
|
def zero_crossing_rate(samples: list) -> float:
    """Return the zero-crossing rate of *samples* in [0.0, 1.0].

    A crossing is counted whenever two successive non-zero samples have
    opposite signs; exact-zero samples in between are skipped. The count is
    divided by ``len(samples) - 1`` (the number of adjacent pairs).

    Skipping zeros matters for quantised PCM: a waveform that lands exactly
    on 0 while changing sign (e.g. ``[a, 0, -a]``) still counts as one
    crossing, whereas a product test ``x[i-1] * x[i] < 0`` would miss it.
    Comparing signs also avoids products of tiny values underflowing to 0.

    Silence, constant-sign input and sequences shorter than two samples
    return 0.0.
    """
    n = len(samples)
    if n < 2:
        return 0.0

    crossings = 0
    prev_positive = None  # sign of the most recent non-zero sample, if any
    for sample in samples:
        if sample == 0:
            continue
        positive = sample > 0
        if prev_positive is not None and positive != prev_positive:
            crossings += 1
        prev_positive = positive
    return crossings / (n - 1)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Hysteresis state machine ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class VadStateMachine:
    """Debounce a per-frame active/inactive flag with onset/offset hysteresis.

    The output turns on once ``onset_frames`` consecutive active frames have
    been seen and turns off once ``offset_frames`` consecutive inactive
    frames have been seen; in between it holds its previous value.
    """

    def __init__(self, onset_frames: int = 2, offset_frames: int = 8) -> None:
        self.onset_frames = onset_frames
        self.offset_frames = offset_frames
        # Lengths of the current run of active / inactive frames.
        self._above = self._below = 0
        self._active = False

    def update(self, raw_active: bool) -> bool:
        """Feed one frame's raw decision and return the debounced state."""
        if not raw_active:
            # An inactive frame breaks any active run.
            self._above, self._below = 0, self._below + 1
            if self._below >= self.offset_frames:
                self._active = False
            return self._active

        # An active frame breaks any inactive run.
        self._above, self._below = self._above + 1, 0
        if self._above >= self.onset_frames:
            self._active = True
        return self._active

    def reset(self) -> None:
        """Clear both run counters and force the output to inactive."""
        self._above = self._below = 0
        self._active = False

    @property
    def is_active(self) -> bool:
        """Current debounced state, without consuming a frame."""
        return self._active
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 node ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class VadNode(Node):
    """Energy + ZCR voice activity detector fed by a raw PCM-16 audio topic.

    Publishes the debounced speaking flag on ``/social/speech/is_speaking``
    and the linear RMS of each chunk on ``/social/speech/energy``.
    """

    def __init__(self) -> None:
        super().__init__("vad_node")

        # Parameter name -> default value, declared in this order.
        param_defaults = (
            ("sample_rate", 16000),
            ("rms_threshold_db", -35.0),
            ("zcr_min", 0.01),
            ("zcr_max", 0.40),
            ("onset_frames", 2),
            ("offset_frames", 8),
            ("audio_topic", "/social/speech/audio_raw"),
        )
        for param_name, default in param_defaults:
            self.declare_parameter(param_name, default)

        def value_of(param_name):
            return self.get_parameter(param_name).value

        # sample_rate is stored but not otherwise used within this class.
        self._sample_rate = value_of("sample_rate")
        self._rms_thresh = value_of("rms_threshold_db")
        self._zcr_min = value_of("zcr_min")
        self._zcr_max = value_of("zcr_max")
        audio_topic = value_of("audio_topic")

        self._sm = VadStateMachine(
            onset_frames=value_of("onset_frames"),
            offset_frames=value_of("offset_frames"),
        )

        # One QoS profile (queue depth 10) is shared by both publishers and
        # the subscription.
        qos = QoSProfile(depth=10)
        self._speaking_pub = self.create_publisher(
            Bool, "/social/speech/is_speaking", qos
        )
        self._energy_pub = self.create_publisher(
            Float32, "/social/speech/energy", qos
        )
        self._audio_sub = self.create_subscription(
            UInt8MultiArray, audio_topic, self._on_audio, qos
        )

        self.get_logger().info(
            f"VadNode ready (rms_thresh={self._rms_thresh} dBFS, "
            f"zcr=[{self._zcr_min},{self._zcr_max}], "
            f"topic={audio_topic})"
        )

    def _on_audio(self, msg: UInt8MultiArray) -> None:
        """Classify one audio chunk and publish the speaking flag and energy.

        Chunks that decode to no samples are dropped without publishing.
        """
        samples = pcm16_bytes_to_float32(bytes(msg.data))
        if not samples:
            return

        energy_lin = rms_linear(samples)
        energy_db = rms_db(samples)
        zcr = zero_crossing_rate(samples)

        # A frame is raw-active only when it passes both the energy gate and
        # the ZCR band check.
        loud_enough = energy_db >= self._rms_thresh
        zcr_in_band = self._zcr_min <= zcr <= self._zcr_max
        is_speaking = self._sm.update(loud_enough and zcr_in_band)

        speaking_msg = Bool()
        speaking_msg.data = is_speaking
        self._speaking_pub.publish(speaking_msg)

        level_msg = Float32()
        level_msg.data = float(energy_lin)
        self._energy_pub.publish(level_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: Optional[list] = None) -> None:
    """Entry point: initialise rclpy, spin a ``VadNode`` and clean up.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = VadNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the node.
        pass
    finally:
        node.destroy_node()
        # Only shut down a context that is still valid. If something else
        # (presumably rclpy's own signal handling on some distros — verify)
        # has already shut it down, a second shutdown would raise.
        if rclpy.ok():
            rclpy.shutdown()
|
||||||
@ -43,6 +43,8 @@ setup(
|
|||||||
'emotion_node = saltybot_social.emotion_node:main',
|
'emotion_node = saltybot_social.emotion_node:main',
|
||||||
# Robot mesh communication (Issue #171)
|
# Robot mesh communication (Issue #171)
|
||||||
'mesh_comms_node = saltybot_social.mesh_comms_node:main',
|
'mesh_comms_node = saltybot_social.mesh_comms_node:main',
|
||||||
|
# Energy+ZCR voice activity detection (Issue #242)
|
||||||
|
'vad_node = saltybot_social.vad_node:main',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
300
jetson/ros2_ws/src/saltybot_social/test/test_vad_node.py
Normal file
300
jetson/ros2_ws/src/saltybot_social/test/test_vad_node.py
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
"""test_vad_node.py -- Unit tests for Issue #242 energy+ZCR VAD node."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import importlib.util, math, os, struct, sys, types
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _pkg_root():
    """Return the absolute path of the directory that contains this file's directory."""
    this_file = os.path.abspath(__file__)
    containing_dir = os.path.dirname(this_file)
    return os.path.dirname(containing_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_src(rel_path):
    """Return the text of the file at *rel_path* under the package root.

    The file is decoded as UTF-8 explicitly, so sources containing non-ASCII
    characters read the same regardless of the platform's locale encoding.
    """
    with open(os.path.join(_pkg_root(), rel_path), encoding="utf-8") as f:
        return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def _import_vad():
    """Load ``saltybot_social/vad_node.py`` with stubbed ROS 2 modules.

    Any of ``rclpy``, ``rclpy.node``, ``rclpy.qos``, ``std_msgs`` and
    ``std_msgs.msg`` missing from ``sys.modules`` is replaced by an empty
    module. The names ``vad_node`` needs are then assigned onto those
    modules unconditionally, and the source file is executed under the
    module name ``vad_node_testmod``.
    """
    stub_names = ("rclpy", "rclpy.node", "rclpy.qos", "std_msgs", "std_msgs.msg")
    for stub_name in stub_names:
        sys.modules.setdefault(stub_name, types.ModuleType(stub_name))

    rclpy_mod = sys.modules["rclpy"]
    rclpy_node = sys.modules["rclpy.node"]
    rclpy_qos = sys.modules["rclpy.qos"]
    std_msg = sys.modules["std_msgs.msg"]

    # Values handed back by the fake get_parameter(); unknown names give None.
    defaults = {
        "sample_rate": 16000, "rms_threshold_db": -35.0,
        "zcr_min": 0.01, "zcr_max": 0.40,
        "onset_frames": 2, "offset_frames": 8,
        "audio_topic": "/social/speech/audio_raw",
    }

    class _Logger:
        def info(self, *a):
            pass

        def warn(self, *a):
            pass

        def error(self, *a):
            pass

    class _Node:
        def __init__(self, *a, **kw):
            pass

        def declare_parameter(self, *a, **kw):
            pass

        def get_parameter(self, name):
            # Only the ``.value`` attribute is provided on the result.
            return types.SimpleNamespace(value=defaults.get(name))

        def create_publisher(self, *a, **kw):
            return None

        def create_subscription(self, *a, **kw):
            return None

        def get_logger(self):
            return _Logger()

        def destroy_node(self):
            pass

    rclpy_node.Node = _Node
    rclpy_qos.QoSProfile = type("QoSProfile", (), {"__init__": lambda s, **kw: None})

    # Minimal message stand-ins: each class only carries a default ``data``.
    for msg_name, default_data in (
        ("Bool", False),
        ("Float32", 0.0),
        ("UInt8MultiArray", b""),
    ):
        setattr(std_msg, msg_name, type(msg_name, (), {"data": default_data}))

    rclpy_mod.init = lambda *a, **kw: None
    rclpy_mod.spin = lambda n: None
    rclpy_mod.ok = lambda: True
    rclpy_mod.shutdown = lambda: None

    source_path = os.path.join(_pkg_root(), "saltybot_social", "vad_node.py")
    spec = importlib.util.spec_from_file_location("vad_node_testmod", source_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _sine(freq_hz, sample_rate, n_samples, amplitude=0.3):
    """Return *n_samples* of a sine wave at *freq_hz*, sampled at *sample_rate*.

    The phase starts at zero and the peak value is *amplitude*.
    """
    angular = 2 * math.pi * freq_hz
    wave = []
    for index in range(n_samples):
        wave.append(amplitude * math.sin(angular * index / sample_rate))
    return wave
|
||||||
|
|
||||||
|
|
||||||
|
def _silence(n_samples):
    """Return a list of *n_samples* zero-valued float samples."""
    return [0.0 for _ in range(n_samples)]
|
||||||
|
|
||||||
|
|
||||||
|
def _pcm16_bytes(samples):
    """Pack float samples into little-endian signed 16-bit PCM bytes.

    Each sample is scaled by 32768, truncated toward zero by ``int()`` and
    clamped to [-32768, 32767] before packing.
    """
    quantised = []
    for sample in samples:
        scaled = int(sample * 32768)
        if scaled > 32767:
            scaled = 32767
        elif scaled < -32768:
            scaled = -32768
        quantised.append(scaled)
    return struct.pack("<%dh" % len(quantised), *quantised)
|
||||||
|
|
||||||
|
|
||||||
|
# ── pcm16_bytes_to_float32 ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPcm16Convert:
    """Checks for ``pcm16_bytes_to_float32``."""

    @pytest.fixture(scope="class")
    def mod(self):
        """Module under test, loaded once per class with stubbed ROS imports."""
        return _import_vad()

    def test_empty(self, mod):
        decoded = mod.pcm16_bytes_to_float32(b"")
        assert decoded == []

    def test_odd_byte_ignored(self, mod):
        # One full sample followed by a dangling byte.
        payload = struct.pack("<h", 16384) + b"\x00"
        decoded = mod.pcm16_bytes_to_float32(payload)
        assert len(decoded) == 1

    def test_positive_full_scale(self, mod):
        payload = struct.pack("<h", 32767)
        first = mod.pcm16_bytes_to_float32(payload)[0]
        assert abs(first - 32767 / 32768.0) < 1e-4

    def test_negative(self, mod):
        payload = struct.pack("<h", -16384)
        first = mod.pcm16_bytes_to_float32(payload)[0]
        assert first < 0

    def test_silence(self, mod):
        payload = struct.pack("<4h", 0, 0, 0, 0)
        decoded = mod.pcm16_bytes_to_float32(payload)
        assert all(sample == 0.0 for sample in decoded)

    def test_roundtrip_length(self, mod):
        sample_count = 480
        payload = _pcm16_bytes(_sine(440, 16000, sample_count, 0.5))
        decoded = mod.pcm16_bytes_to_float32(payload)
        assert len(decoded) == sample_count
|
||||||
|
|
||||||
|
|
||||||
|
# ── rms_linear ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRmsLinear:
    """Checks for ``rms_linear``."""

    @pytest.fixture(scope="class")
    def mod(self):
        """Module under test, loaded once per class with stubbed ROS imports."""
        return _import_vad()

    def test_empty(self, mod):
        assert mod.rms_linear([]) == 0.0

    def test_silence(self, mod):
        assert mod.rms_linear([0.0] * 100) == 0.0

    def test_full_scale(self, mod):
        level = mod.rms_linear([1.0] * 100)
        assert abs(level - 1.0) < 1e-6

    def test_half_amplitude(self, mod):
        level = mod.rms_linear([0.5] * 100)
        assert abs(level - 0.5) < 1e-6

    def test_sine_rms(self, mod):
        # A unit-amplitude sine has RMS 1/sqrt(2).
        wave = _sine(440, 16000, 1600, amplitude=1.0)
        level = mod.rms_linear(wave)
        assert abs(level - 1.0 / math.sqrt(2)) < 0.01

    def test_nonnegative(self, mod):
        level = mod.rms_linear(_sine(200, 16000, 480, 0.3))
        assert level >= 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── rms_db ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRmsDb:
    """Checks for ``rms_db``."""

    @pytest.fixture(scope="class")
    def mod(self):
        """Module under test, loaded once per class with stubbed ROS imports."""
        return _import_vad()

    def test_silence_returns_low(self, mod):
        level_db = mod.rms_db([0.0] * 100)
        assert level_db < -90.0

    def test_full_scale_near_zero(self, mod):
        level_db = mod.rms_db([1.0] * 100)
        assert abs(level_db) < 1.0

    def test_half_amplitude(self, mod):
        # Half amplitude is about -6 dB relative to full scale.
        level_db = mod.rms_db([0.5] * 100)
        assert abs(level_db - (-6.0)) < 1.0

    def test_louder_is_higher(self, mod):
        loud_db = mod.rms_db([0.5] * 100)
        quiet_db = mod.rms_db([0.01] * 100)
        assert loud_db > quiet_db

    def test_below_threshold(self, mod):
        level_db = mod.rms_db([0.001] * 480)
        assert level_db < -35.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── zero_crossing_rate ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestZeroCrossingRate:
    """Checks for ``zero_crossing_rate``."""

    @pytest.fixture(scope="class")
    def mod(self):
        """Module under test, loaded once per class with stubbed ROS imports."""
        return _import_vad()

    def test_empty(self, mod):
        assert mod.zero_crossing_rate([]) == 0.0

    def test_single(self, mod):
        assert mod.zero_crossing_rate([0.5]) == 0.0

    def test_silence(self, mod):
        assert mod.zero_crossing_rate([0.0] * 100) == 0.0

    def test_constant_pos(self, mod):
        assert mod.zero_crossing_rate([0.5] * 100) == 0.0

    def test_alternating_full(self, mod):
        # +1, -1, +1, ... changes sign on every adjacent pair.
        alternating = [(-1.0) ** i for i in range(100)]
        rate = mod.zero_crossing_rate(alternating)
        assert abs(rate - 1.0) < 1e-6

    def test_sine_in_range(self, mod):
        rate = mod.zero_crossing_rate(_sine(200, 16000, 1600))
        assert 0.01 < rate < 0.10

    def test_high_freq_higher_zcr(self, mod):
        low_rate = mod.zero_crossing_rate(_sine(100, 16000, 1600))
        high_rate = mod.zero_crossing_rate(_sine(4000, 16000, 1600))
        assert high_rate > low_rate

    def test_in_unit_interval(self, mod):
        rate = mod.zero_crossing_rate(_sine(440, 16000, 480))
        assert 0.0 <= rate <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── VadStateMachine ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestVadStateMachine:
    """Checks for the onset/offset hysteresis in ``VadStateMachine``."""

    @pytest.fixture(scope="class")
    def mod(self):
        """Module under test, loaded once per class with stubbed ROS imports."""
        return _import_vad()

    def test_initial_inactive(self, mod):
        machine = mod.VadStateMachine(onset_frames=2, offset_frames=3)
        assert not machine.is_active

    def test_onset_requires_n_frames(self, mod):
        machine = mod.VadStateMachine(onset_frames=3, offset_frames=5)
        # The first two active frames are not enough; the third switches on.
        assert not machine.update(True)
        assert not machine.update(True)
        assert machine.update(True)

    def test_offset_requires_n_frames(self, mod):
        machine = mod.VadStateMachine(onset_frames=1, offset_frames=3)
        machine.update(True)
        # Output holds for two silent frames and drops on the third.
        assert machine.update(False)
        assert machine.update(False)
        assert not machine.update(False)

    def test_reset(self, mod):
        machine = mod.VadStateMachine(onset_frames=1, offset_frames=1)
        machine.update(True)
        machine.reset()
        assert not machine.is_active

    def test_stays_active_with_speech(self, mod):
        machine = mod.VadStateMachine(onset_frames=1, offset_frames=10)
        machine.update(True)
        for _ in range(20):
            assert machine.update(True)

    def test_onset1_offset1(self, mod):
        machine = mod.VadStateMachine(onset_frames=1, offset_frames=1)
        assert machine.update(True)
        assert not machine.update(False)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Combined decision logic ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestCombinedDecision:
    """Checks for the combined energy-gate AND ZCR-band decision rule.

    ``_decide`` is a local copy of the rule; only the two integration tests
    call into the module under test.
    """

    @pytest.fixture(scope="class")
    def mod(self):
        """Module under test, loaded once per class with stubbed ROS imports."""
        return _import_vad()

    def _decide(self, energy_db, zcr, rms_thresh=-35.0, zcr_min=0.01, zcr_max=0.40):
        """Return True when energy is at/above the gate and ZCR is in band."""
        loud_enough = energy_db >= rms_thresh
        in_band = zcr_min <= zcr <= zcr_max
        return loud_enough and in_band

    def test_normal_speech(self):
        assert self._decide(-20.0, 0.08)

    def test_below_energy_threshold(self):
        assert not self._decide(-40.0, 0.08)

    def test_zcr_too_low(self):
        assert not self._decide(-20.0, 0.005)

    def test_zcr_too_high(self):
        assert not self._decide(-20.0, 0.50)

    def test_energy_at_threshold(self):
        assert self._decide(-35.0, 0.08)

    def test_zcr_at_min_boundary(self):
        assert self._decide(-20.0, 0.01)

    def test_zcr_at_max_boundary(self):
        assert self._decide(-20.0, 0.40)

    def test_loud_noise_high_zcr(self):
        assert not self._decide(-10.0, 0.50)

    def test_integration_voiced_speech(self, mod):
        # A 300 Hz tone at amplitude 0.1 should pass both checks.
        tone = _sine(300, 16000, 480, amplitude=0.1)
        energy_db = mod.rms_db(tone)
        zcr = mod.zero_crossing_rate(tone)
        assert (energy_db >= -35.0) and (0.01 <= zcr <= 0.40)

    def test_integration_silence(self, mod):
        energy_db = mod.rms_db(_silence(480))
        assert energy_db < -35.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Node source checks ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestVadNodeSrc:
    """Text-level checks that ``vad_node.py`` mentions the expected names."""

    @pytest.fixture(scope="class")
    def src(self):
        """Source text of the node module, read once per class."""
        return _read_src("saltybot_social/vad_node.py")

    def test_class_defined(self, src):
        assert "class VadNode" in src

    def test_state_machine(self, src):
        assert "class VadStateMachine" in src

    def test_rms_threshold_param(self, src):
        assert '"rms_threshold_db"' in src

    def test_zcr_min_param(self, src):
        assert '"zcr_min"' in src

    def test_zcr_max_param(self, src):
        assert '"zcr_max"' in src

    def test_onset_frames_param(self, src):
        assert '"onset_frames"' in src

    def test_offset_frames_param(self, src):
        assert '"offset_frames"' in src

    def test_audio_topic_param(self, src):
        assert '"audio_topic"' in src

    def test_is_speaking_topic(self, src):
        assert '"/social/speech/is_speaking"' in src

    def test_energy_topic(self, src):
        assert '"/social/speech/energy"' in src

    def test_audio_raw_default(self, src):
        assert '"/social/speech/audio_raw"' in src

    def test_bool_pub(self, src):
        assert "Bool" in src

    def test_float32_pub(self, src):
        assert "Float32" in src

    def test_uint8_sub(self, src):
        assert "UInt8MultiArray" in src

    def test_rms_fn(self, src):
        assert "rms_db" in src

    def test_zcr_fn(self, src):
        assert "zero_crossing_rate" in src

    def test_pcm_convert(self, src):
        assert "pcm16_bytes_to_float32" in src

    def test_hysteresis(self, src):
        assert "onset_frames" in src and "offset_frames" in src

    def test_issue_tag(self, src):
        assert "242" in src

    def test_main(self, src):
        assert "def main" in src
|
||||||
|
|
||||||
|
|
||||||
|
# ── Config + setup.py ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestConfig:
    """Text-level checks on ``config/vad_params.yaml`` and ``setup.py``."""

    @pytest.fixture(scope="class")
    def cfg(self):
        """Text of the VAD parameter file, read once per class."""
        return _read_src("config/vad_params.yaml")

    @pytest.fixture(scope="class")
    def setup(self):
        """Text of the package's setup.py, read once per class."""
        return _read_src("setup.py")

    def test_node_name(self, cfg):
        assert "vad_node:" in cfg

    def test_rms_param(self, cfg):
        assert "rms_threshold_db" in cfg

    def test_zcr_min(self, cfg):
        assert "zcr_min" in cfg

    def test_zcr_max(self, cfg):
        assert "zcr_max" in cfg

    def test_onset(self, cfg):
        assert "onset_frames" in cfg

    def test_offset(self, cfg):
        assert "offset_frames" in cfg

    def test_defaults(self, cfg):
        assert "-35.0" in cfg and "0.01" in cfg

    def test_entry_point(self, setup):
        assert "vad_node = saltybot_social.vad_node:main" in setup
|
||||||
Loading…
x
Reference in New Issue
Block a user