Merge pull request 'feat(social): speech volume auto-adjuster (Issue #289)' (#293) from sl-jetson/issue-289-volume-adjust into main
Some checks failed
Some checks failed
This commit is contained in:
commit
b087af4b94
@ -0,0 +1,25 @@
|
||||
volume_adjust_node:
|
||||
ros__parameters:
|
||||
energy_topic: "/social/speech/energy" # Linear RMS [0..1] from vad_node
|
||||
volume_topic: "/saltybot/tts_volume" # TTS volume output [0..1]
|
||||
|
||||
# Volume range
|
||||
min_volume: 0.5 # Quietest TTS volume (silent environment)
|
||||
max_volume: 1.0 # Loudest TTS volume (very noisy environment)
|
||||
|
||||
# Noise mapping window
|
||||
noise_floor: 0.001 # Energy ≤ this → maps to min_volume (near-silence)
|
||||
noise_ceil: 0.10 # Energy ≥ this → maps to max_volume (loud crowd)
|
||||
|
||||
# Response curve
|
||||
# gamma < 1 → concave (aggressive at low noise, gentle at top)
|
||||
# gamma = 1 → linear
|
||||
# gamma > 1 → convex (gentle at low noise, aggressive at top)
|
||||
curve_gamma: 0.5
|
||||
|
||||
# Smoothing (EMA)
|
||||
smoothing_alpha: 0.1 # 0 = frozen, 1 = instant; 0.1 ≈ ~10-sample lag
|
||||
|
||||
# Timing
|
||||
publish_rate: 5.0 # Hz — how often volume command is published
|
||||
stale_timeout_s: 5.0 # If no energy received for this long, hold last value
|
||||
@ -0,0 +1,45 @@
|
||||
"""volume_adjust.launch.py — Launch speech volume auto-adjuster (Issue #289).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social volume_adjust.launch.py
|
||||
ros2 launch saltybot_social volume_adjust.launch.py min_volume:=0.3 max_volume:=0.9
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "volume_adjust_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("min_volume", default_value="0.5",
|
||||
description="Quietest TTS volume [0..1]"),
|
||||
DeclareLaunchArgument("max_volume", default_value="1.0",
|
||||
description="Loudest TTS volume [0..1]"),
|
||||
DeclareLaunchArgument("curve_gamma", default_value="0.5",
|
||||
description="Power-curve exponent (0.5=concave, 1=linear, 2=convex)"),
|
||||
DeclareLaunchArgument("smoothing_alpha", default_value="0.1",
|
||||
description="EMA smoothing factor (0=frozen, 1=instant)"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="volume_adjust_node",
|
||||
name="volume_adjust_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"min_volume": LaunchConfiguration("min_volume"),
|
||||
"max_volume": LaunchConfiguration("max_volume"),
|
||||
"curve_gamma": LaunchConfiguration("curve_gamma"),
|
||||
"smoothing_alpha": LaunchConfiguration("smoothing_alpha"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,186 @@
|
||||
"""volume_adjust_node.py — Speech volume auto-adjuster.
|
||||
Issue #289
|
||||
|
||||
Subscribes to /social/speech/energy (linear RMS [0..1] published by
|
||||
vad_node) and dynamically adjusts TTS output volume via
|
||||
/saltybot/tts_volume (std_msgs/Float32) so the robot speaks louder in
|
||||
noisy environments and quieter when ambient noise is low.
|
||||
|
||||
Mapping pipeline (runs at publish_rate Hz)
|
||||
──────────────────────────────────────────
|
||||
1. Normalise raw energy into [0, 1] using [noise_floor, noise_ceil]:
|
||||
t = clamp((energy - noise_floor) / (noise_ceil - noise_floor), 0, 1)
|
||||
2. Apply power curve:
|
||||
t_curved = t ^ curve_gamma
|
||||
gamma < 1 → concave (strong response at low noise, gentle at top)
|
||||
gamma = 1 → linear
|
||||
gamma > 1 → convex (gentle response at low noise, strong at top)
|
||||
3. Map to volume:
|
||||
vol_target = min_volume + (max_volume - min_volume) * t_curved
|
||||
4. Exponential moving-average smoothing:
|
||||
vol_smoothed = alpha * vol_target + (1-alpha) * vol_smoothed
|
||||
5. If no energy sample received for stale_timeout_s → hold last value.
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
energy_topic (str, "/social/speech/energy")
|
||||
volume_topic (str, "/saltybot/tts_volume")
|
||||
min_volume (float, 0.5) lowest TTS volume [0..1]
|
||||
max_volume (float, 1.0) highest TTS volume [0..1]
|
||||
noise_floor (float, 0.001) energy ≤ this maps to min_volume
|
||||
noise_ceil (float, 0.10) energy ≥ this maps to max_volume
|
||||
curve_gamma (float, 0.5) power-curve exponent
|
||||
smoothing_alpha(float, 0.1) EMA weight (0=frozen, 1=instant)
|
||||
publish_rate (float, 5.0) Hz
|
||||
stale_timeout_s(float, 5.0) seconds before energy considered stale
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import threading
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import Float32
|
||||
|
||||
|
||||
# ── Pure helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def clamp(v: float, lo: float, hi: float) -> float:
|
||||
return max(lo, min(hi, v))
|
||||
|
||||
|
||||
def normalize_energy(energy: float, noise_floor: float, noise_ceil: float) -> float:
|
||||
"""Map raw linear energy into [0, 1] within [noise_floor, noise_ceil]."""
|
||||
span = noise_ceil - noise_floor
|
||||
if span <= 0.0:
|
||||
return 1.0 if energy >= noise_ceil else 0.0
|
||||
return clamp((energy - noise_floor) / span, 0.0, 1.0)
|
||||
|
||||
|
||||
def apply_curve(t: float, gamma: float) -> float:
|
||||
"""Apply power-curve shaping. t ∈ [0,1], gamma > 0."""
|
||||
if gamma <= 0.0:
|
||||
return 1.0
|
||||
return t ** gamma
|
||||
|
||||
|
||||
def compute_target_volume(energy: float,
|
||||
noise_floor: float,
|
||||
noise_ceil: float,
|
||||
gamma: float,
|
||||
min_volume: float,
|
||||
max_volume: float) -> float:
|
||||
"""Return target volume [min_volume, max_volume] for a given energy."""
|
||||
t = normalize_energy(energy, noise_floor, noise_ceil)
|
||||
t_curved = apply_curve(t, gamma)
|
||||
return min_volume + (max_volume - min_volume) * t_curved
|
||||
|
||||
|
||||
def ema_update(current: float, target: float, alpha: float) -> float:
|
||||
"""Single-step exponential moving average."""
|
||||
alpha = clamp(alpha, 0.0, 1.0)
|
||||
return alpha * target + (1.0 - alpha) * current
|
||||
|
||||
|
||||
# ── ROS2 node ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class VolumeAdjustNode(Node):
|
||||
"""Dynamically adjusts TTS volume from ambient energy measurements."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("volume_adjust_node")
|
||||
|
||||
self.declare_parameter("energy_topic", "/social/speech/energy")
|
||||
self.declare_parameter("volume_topic", "/saltybot/tts_volume")
|
||||
self.declare_parameter("min_volume", 0.5)
|
||||
self.declare_parameter("max_volume", 1.0)
|
||||
self.declare_parameter("noise_floor", 0.001)
|
||||
self.declare_parameter("noise_ceil", 0.10)
|
||||
self.declare_parameter("curve_gamma", 0.5)
|
||||
self.declare_parameter("smoothing_alpha", 0.1)
|
||||
self.declare_parameter("publish_rate", 5.0)
|
||||
self.declare_parameter("stale_timeout_s", 5.0)
|
||||
|
||||
energy_topic = self.get_parameter("energy_topic").value
|
||||
volume_topic = self.get_parameter("volume_topic").value
|
||||
self._min_vol = self.get_parameter("min_volume").value
|
||||
self._max_vol = self.get_parameter("max_volume").value
|
||||
self._floor = self.get_parameter("noise_floor").value
|
||||
self._ceil = self.get_parameter("noise_ceil").value
|
||||
self._gamma = self.get_parameter("curve_gamma").value
|
||||
self._alpha = self.get_parameter("smoothing_alpha").value
|
||||
self._stale_t = self.get_parameter("stale_timeout_s").value
|
||||
publish_rate = self.get_parameter("publish_rate").value
|
||||
|
||||
# Start at min_volume; smoothed toward target as data arrives
|
||||
self._current_vol: float = self._min_vol
|
||||
self._latest_energy: float = 0.0
|
||||
self._last_energy_t: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
qos = QoSProfile(depth=10)
|
||||
self._vol_pub = self.create_publisher(Float32, volume_topic, qos)
|
||||
self._energy_sub = self.create_subscription(
|
||||
Float32, energy_topic, self._on_energy, qos
|
||||
)
|
||||
self._timer = self.create_timer(1.0 / publish_rate, self._publish_cb)
|
||||
|
||||
self.get_logger().info(
|
||||
f"VolumeAdjustNode ready "
|
||||
f"(vol=[{self._min_vol},{self._max_vol}], "
|
||||
f"floor={self._floor}, ceil={self._ceil}, "
|
||||
f"gamma={self._gamma}, alpha={self._alpha})"
|
||||
)
|
||||
|
||||
# ── Subscription ───────────────────────────────────────────────────────
|
||||
|
||||
def _on_energy(self, msg: Float32) -> None:
|
||||
with self._lock:
|
||||
self._latest_energy = float(msg.data)
|
||||
self._last_energy_t = time.monotonic()
|
||||
|
||||
# ── Timer / publish ────────────────────────────────────────────────────
|
||||
|
||||
def _publish_cb(self) -> None:
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
energy = self._latest_energy
|
||||
last_t = self._last_energy_t
|
||||
fresh = last_t > 0.0 and (now - last_t) < self._stale_t
|
||||
|
||||
if fresh:
|
||||
target = compute_target_volume(
|
||||
energy, self._floor, self._ceil,
|
||||
self._gamma, self._min_vol, self._max_vol,
|
||||
)
|
||||
with self._lock:
|
||||
self._current_vol = ema_update(
|
||||
self._current_vol, target, self._alpha
|
||||
)
|
||||
|
||||
msg = Float32()
|
||||
with self._lock:
|
||||
msg.data = float(self._current_vol)
|
||||
self._vol_pub.publish(msg)
|
||||
|
||||
# ── Property (for testing) ─────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def current_volume(self) -> float:
|
||||
with self._lock:
|
||||
return self._current_vol
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = VolumeAdjustNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -51,6 +51,8 @@ setup(
|
||||
'greeting_trigger_node = saltybot_social.greeting_trigger_node:main',
|
||||
# Face-tracking head servo controller (Issue #279)
|
||||
'face_track_servo_node = saltybot_social.face_track_servo_node:main',
|
||||
# Speech volume auto-adjuster (Issue #289)
|
||||
'volume_adjust_node = saltybot_social.volume_adjust_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
511
jetson/ros2_ws/src/saltybot_social/test/test_volume_adjust.py
Normal file
511
jetson/ros2_ws/src/saltybot_social/test/test_volume_adjust.py
Normal file
@ -0,0 +1,511 @@
|
||||
"""test_volume_adjust.py — Offline tests for volume_adjust_node (Issue #289).
|
||||
|
||||
Stubs out rclpy so tests run without a ROS install.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
|
||||
|
||||
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_ros_stubs():
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg"):
|
||||
if mod_name not in sys.modules:
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
class _Node:
|
||||
def __init__(self, name="node"):
|
||||
self._name = name
|
||||
if not hasattr(self, "_params"):
|
||||
self._params = {}
|
||||
self._pubs = {}
|
||||
self._subs = {}
|
||||
self._timers = []
|
||||
self._logs = []
|
||||
|
||||
def declare_parameter(self, name, default):
|
||||
if name not in self._params:
|
||||
self._params[name] = default
|
||||
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
def __init__(self, v): self.value = v
|
||||
return _P(self._params.get(name))
|
||||
|
||||
def create_publisher(self, msg_type, topic, qos):
|
||||
pub = _FakePub()
|
||||
self._pubs[topic] = pub
|
||||
return pub
|
||||
|
||||
def create_subscription(self, msg_type, topic, cb, qos):
|
||||
self._subs[topic] = cb
|
||||
return object()
|
||||
|
||||
def create_timer(self, period, cb):
|
||||
self._timers.append(cb)
|
||||
return object()
|
||||
|
||||
def get_logger(self):
|
||||
node = self
|
||||
class _L:
|
||||
def info(self, m): node._logs.append(("INFO", m))
|
||||
def warn(self, m): node._logs.append(("WARN", m))
|
||||
def error(self, m): node._logs.append(("ERROR", m))
|
||||
return _L()
|
||||
|
||||
def destroy_node(self): pass
|
||||
|
||||
class _FakePub:
|
||||
def __init__(self):
|
||||
self.msgs = []
|
||||
def publish(self, msg):
|
||||
self.msgs.append(msg)
|
||||
|
||||
class _QoSProfile:
|
||||
def __init__(self, depth=10): self.depth = depth
|
||||
|
||||
class _Float32:
|
||||
def __init__(self): self.data = 0.0
|
||||
|
||||
rclpy_mod = sys.modules["rclpy"]
|
||||
rclpy_mod.init = lambda args=None: None
|
||||
rclpy_mod.spin = lambda node: None
|
||||
rclpy_mod.shutdown = lambda: None
|
||||
|
||||
sys.modules["rclpy.node"].Node = _Node
|
||||
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
|
||||
sys.modules["std_msgs.msg"].Float32 = _Float32
|
||||
|
||||
return _Node, _FakePub, _Float32
|
||||
|
||||
|
||||
_Node, _FakePub, _Float32 = _make_ros_stubs()
|
||||
|
||||
# ── Module loader ─────────────────────────────────────────────────────────────
|
||||
|
||||
_SRC = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/saltybot_social/volume_adjust_node.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_mod():
|
||||
spec = importlib.util.spec_from_file_location("volume_adjust_testmod", _SRC)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def _make_node(mod, **kwargs):
|
||||
node = mod.VolumeAdjustNode.__new__(mod.VolumeAdjustNode)
|
||||
defaults = {
|
||||
"energy_topic": "/social/speech/energy",
|
||||
"volume_topic": "/saltybot/tts_volume",
|
||||
"min_volume": 0.5,
|
||||
"max_volume": 1.0,
|
||||
"noise_floor": 0.001,
|
||||
"noise_ceil": 0.10,
|
||||
"curve_gamma": 0.5,
|
||||
"smoothing_alpha": 0.1,
|
||||
"publish_rate": 5.0,
|
||||
"stale_timeout_s": 5.0,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
node._params = dict(defaults)
|
||||
mod.VolumeAdjustNode.__init__(node)
|
||||
return node
|
||||
|
||||
|
||||
def _energy_msg(val):
|
||||
m = _Float32(); m.data = val; return m
|
||||
|
||||
|
||||
# ── Tests: pure helpers ───────────────────────────────────────────────────────
|
||||
|
||||
class TestClamp(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_within(self): self.assertAlmostEqual(self.mod.clamp(0.5, 0.0, 1.0), 0.5)
|
||||
def test_low(self): self.assertAlmostEqual(self.mod.clamp(-1.0, 0.0, 1.0), 0.0)
|
||||
def test_high(self): self.assertAlmostEqual(self.mod.clamp(2.0, 0.0, 1.0), 1.0)
|
||||
def test_equal_lo(self): self.assertAlmostEqual(self.mod.clamp(0.0, 0.0, 1.0), 0.0)
|
||||
def test_equal_hi(self): self.assertAlmostEqual(self.mod.clamp(1.0, 0.0, 1.0), 1.0)
|
||||
|
||||
|
||||
class TestNormalizeEnergy(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _n(self, e, floor=0.001, ceil=0.1):
|
||||
return self.mod.normalize_energy(e, floor, ceil)
|
||||
|
||||
def test_at_floor(self):
|
||||
self.assertAlmostEqual(self._n(0.001), 0.0)
|
||||
|
||||
def test_at_ceil(self):
|
||||
self.assertAlmostEqual(self._n(0.1), 1.0)
|
||||
|
||||
def test_midpoint(self):
|
||||
mid = (0.001 + 0.1) / 2
|
||||
self.assertAlmostEqual(self._n(mid), 0.5, places=4)
|
||||
|
||||
def test_below_floor(self):
|
||||
self.assertAlmostEqual(self._n(0.0), 0.0)
|
||||
|
||||
def test_above_ceil(self):
|
||||
self.assertAlmostEqual(self._n(1.0), 1.0)
|
||||
|
||||
def test_zero_span_below(self):
|
||||
self.assertAlmostEqual(self.mod.normalize_energy(0.0, 0.5, 0.5), 0.0)
|
||||
|
||||
def test_zero_span_above(self):
|
||||
self.assertAlmostEqual(self.mod.normalize_energy(1.0, 0.5, 0.5), 1.0)
|
||||
|
||||
|
||||
class TestApplyCurve(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_linear_gamma1(self):
|
||||
self.assertAlmostEqual(self.mod.apply_curve(0.5, 1.0), 0.5)
|
||||
|
||||
def test_gamma_half(self):
|
||||
# 0.25 ^ 0.5 = 0.5
|
||||
self.assertAlmostEqual(self.mod.apply_curve(0.25, 0.5), 0.5)
|
||||
|
||||
def test_zero_input(self):
|
||||
self.assertAlmostEqual(self.mod.apply_curve(0.0, 0.5), 0.0)
|
||||
|
||||
def test_one_input(self):
|
||||
self.assertAlmostEqual(self.mod.apply_curve(1.0, 0.5), 1.0)
|
||||
|
||||
def test_concave_midpoint_above_linear(self):
|
||||
# gamma < 1 → curve above linear at midpoint
|
||||
self.assertGreater(self.mod.apply_curve(0.5, 0.5), 0.5)
|
||||
|
||||
def test_convex_midpoint_below_linear(self):
|
||||
# gamma > 1 → curve below linear at midpoint
|
||||
self.assertLess(self.mod.apply_curve(0.5, 2.0), 0.5)
|
||||
|
||||
def test_invalid_gamma_zero(self):
|
||||
# gamma <= 0 → returns 1.0 (max)
|
||||
self.assertAlmostEqual(self.mod.apply_curve(0.5, 0.0), 1.0)
|
||||
|
||||
|
||||
class TestComputeTargetVolume(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _vol(self, energy, floor=0.001, ceil=0.1, gamma=1.0,
|
||||
min_v=0.5, max_v=1.0):
|
||||
return self.mod.compute_target_volume(energy, floor, ceil, gamma, min_v, max_v)
|
||||
|
||||
def test_silence_gives_min(self):
|
||||
self.assertAlmostEqual(self._vol(0.0), 0.5)
|
||||
|
||||
def test_max_noise_gives_max(self):
|
||||
self.assertAlmostEqual(self._vol(1.0), 1.0)
|
||||
|
||||
def test_mid_linear(self):
|
||||
mid = (0.001 + 0.1) / 2
|
||||
vol = self._vol(mid)
|
||||
self.assertGreater(vol, 0.5)
|
||||
self.assertLess(vol, 1.0)
|
||||
|
||||
def test_monotonically_increasing(self):
|
||||
energies = [0.0, 0.01, 0.03, 0.05, 0.08, 0.1, 0.5]
|
||||
vols = [self._vol(e) for e in energies]
|
||||
for i in range(len(vols) - 1):
|
||||
self.assertLessEqual(vols[i], vols[i + 1])
|
||||
|
||||
def test_min_max_bounds(self):
|
||||
for e in [0.0, 0.001, 0.05, 0.1, 1.0]:
|
||||
v = self._vol(e)
|
||||
self.assertGreaterEqual(v, 0.5)
|
||||
self.assertLessEqual(v, 1.0)
|
||||
|
||||
def test_custom_range(self):
|
||||
v_min = self._vol(0.0, min_v=0.2, max_v=0.8)
|
||||
v_max = self._vol(1.0, min_v=0.2, max_v=0.8)
|
||||
self.assertAlmostEqual(v_min, 0.2)
|
||||
self.assertAlmostEqual(v_max, 0.8)
|
||||
|
||||
def test_gamma_affects_curve(self):
|
||||
e = (0.001 + 0.1) / 2
|
||||
v_concave = self._vol(e, gamma=0.5)
|
||||
v_linear = self._vol(e, gamma=1.0)
|
||||
v_convex = self._vol(e, gamma=2.0)
|
||||
self.assertGreater(v_concave, v_linear)
|
||||
self.assertGreater(v_linear, v_convex)
|
||||
|
||||
|
||||
class TestEmaUpdate(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_alpha_zero_frozen(self):
|
||||
result = self.mod.ema_update(0.5, 1.0, 0.0)
|
||||
self.assertAlmostEqual(result, 0.5)
|
||||
|
||||
def test_alpha_one_instant(self):
|
||||
result = self.mod.ema_update(0.5, 1.0, 1.0)
|
||||
self.assertAlmostEqual(result, 1.0)
|
||||
|
||||
def test_alpha_half(self):
|
||||
result = self.mod.ema_update(0.0, 1.0, 0.5)
|
||||
self.assertAlmostEqual(result, 0.5)
|
||||
|
||||
def test_converges_to_target(self):
|
||||
val = 0.0
|
||||
for _ in range(200):
|
||||
val = self.mod.ema_update(val, 1.0, 0.1)
|
||||
self.assertAlmostEqual(val, 1.0, places=2)
|
||||
|
||||
def test_alpha_clamped_above_one(self):
|
||||
result = self.mod.ema_update(0.5, 1.0, 2.0)
|
||||
self.assertAlmostEqual(result, 1.0)
|
||||
|
||||
def test_alpha_clamped_below_zero(self):
|
||||
result = self.mod.ema_update(0.5, 1.0, -1.0)
|
||||
self.assertAlmostEqual(result, 0.5)
|
||||
|
||||
|
||||
# ── Tests: node initialisation ────────────────────────────────────────────────
|
||||
|
||||
class TestNodeInit(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_instantiates(self):
|
||||
self.assertIsNotNone(_make_node(self.mod))
|
||||
|
||||
def test_pub_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/saltybot/tts_volume", node._pubs)
|
||||
|
||||
def test_sub_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/social/speech/energy", node._subs)
|
||||
|
||||
def test_timer_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertGreater(len(node._timers), 0)
|
||||
|
||||
def test_initial_volume_is_min(self):
|
||||
node = _make_node(self.mod, min_volume=0.4)
|
||||
self.assertAlmostEqual(node.current_volume, 0.4)
|
||||
|
||||
def test_custom_topics(self):
|
||||
node = _make_node(self.mod,
|
||||
energy_topic="/my/energy",
|
||||
volume_topic="/my/volume")
|
||||
self.assertIn("/my/energy", node._subs)
|
||||
self.assertIn("/my/volume", node._pubs)
|
||||
|
||||
|
||||
# ── Tests: energy subscription ────────────────────────────────────────────────
|
||||
|
||||
class TestEnergyCallback(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
|
||||
def test_stores_energy(self):
|
||||
self.node._on_energy(_energy_msg(0.05))
|
||||
self.assertAlmostEqual(self.node._latest_energy, 0.05)
|
||||
|
||||
def test_updates_timestamp(self):
|
||||
before = time.monotonic()
|
||||
self.node._on_energy(_energy_msg(0.03))
|
||||
self.assertGreaterEqual(self.node._last_energy_t, before)
|
||||
|
||||
def test_zero_energy(self):
|
||||
self.node._on_energy(_energy_msg(0.0))
|
||||
self.assertAlmostEqual(self.node._latest_energy, 0.0)
|
||||
|
||||
def test_high_energy(self):
|
||||
self.node._on_energy(_energy_msg(0.9))
|
||||
self.assertAlmostEqual(self.node._latest_energy, 0.9)
|
||||
|
||||
|
||||
# ── Tests: publish callback ───────────────────────────────────────────────────
|
||||
|
||||
class TestPublishCallback(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod,
|
||||
smoothing_alpha=1.0, # instant for test clarity
|
||||
min_volume=0.5, max_volume=1.0,
|
||||
noise_floor=0.001, noise_ceil=0.1,
|
||||
curve_gamma=1.0,
|
||||
stale_timeout_s=5.0)
|
||||
self.pub = self.node._pubs["/saltybot/tts_volume"]
|
||||
|
||||
def _inject(self, energy):
|
||||
self.node._on_energy(_energy_msg(energy))
|
||||
self.node._last_energy_t = time.monotonic()
|
||||
|
||||
def test_publishes_on_tick(self):
|
||||
self.node._publish_cb()
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_no_data_holds_min(self):
|
||||
self.node._publish_cb()
|
||||
self.assertAlmostEqual(self.pub.msgs[-1].data, 0.5, places=3)
|
||||
|
||||
def test_high_energy_raises_volume(self):
|
||||
self._inject(0.1) # at noise_ceil → max_volume
|
||||
self.node._publish_cb()
|
||||
self.assertAlmostEqual(self.pub.msgs[-1].data, 1.0, places=3)
|
||||
|
||||
def test_silence_gives_min_volume(self):
|
||||
self._inject(0.0) # at/below floor → min_volume
|
||||
self.node._publish_cb()
|
||||
self.assertAlmostEqual(self.pub.msgs[-1].data, 0.5, places=3)
|
||||
|
||||
def test_volume_between_min_max(self):
|
||||
self._inject(0.05)
|
||||
self.node._publish_cb()
|
||||
v = self.pub.msgs[-1].data
|
||||
self.assertGreaterEqual(v, 0.5)
|
||||
self.assertLessEqual(v, 1.0)
|
||||
|
||||
def test_stale_data_holds_last(self):
|
||||
self._inject(0.1)
|
||||
self.node._publish_cb()
|
||||
v_before = self.pub.msgs[-1].data
|
||||
# Expire the timestamp
|
||||
self.node._last_energy_t = time.monotonic() - 100.0
|
||||
self.node._publish_cb()
|
||||
self.assertAlmostEqual(self.pub.msgs[-1].data, v_before, places=5)
|
||||
|
||||
def test_smoothing_gradual(self):
|
||||
node = _make_node(self.mod, smoothing_alpha=0.1,
|
||||
min_volume=0.5, max_volume=1.0,
|
||||
noise_floor=0.0, noise_ceil=1.0,
|
||||
curve_gamma=1.0)
|
||||
pub = node._pubs["/saltybot/tts_volume"]
|
||||
node._on_energy(_energy_msg(1.0))
|
||||
node._last_energy_t = time.monotonic()
|
||||
node._publish_cb()
|
||||
# alpha=0.1: should move from 0.5 toward 1.0 by ~0.05
|
||||
v = pub.msgs[-1].data
|
||||
self.assertGreater(v, 0.5)
|
||||
self.assertLess(v, 1.0)
|
||||
|
||||
def test_multiple_ticks_converge(self):
|
||||
node = _make_node(self.mod, smoothing_alpha=0.2,
|
||||
min_volume=0.5, max_volume=1.0,
|
||||
noise_floor=0.0, noise_ceil=1.0,
|
||||
curve_gamma=1.0)
|
||||
pub = node._pubs["/saltybot/tts_volume"]
|
||||
for _ in range(50):
|
||||
node._on_energy(_energy_msg(1.0))
|
||||
node._last_energy_t = time.monotonic()
|
||||
node._publish_cb()
|
||||
self.assertAlmostEqual(pub.msgs[-1].data, 1.0, places=2)
|
||||
|
||||
def test_noise_raises_then_silence_lowers(self):
|
||||
node = _make_node(self.mod, smoothing_alpha=1.0,
|
||||
min_volume=0.5, max_volume=1.0,
|
||||
noise_floor=0.0, noise_ceil=1.0,
|
||||
curve_gamma=1.0)
|
||||
pub = node._pubs["/saltybot/tts_volume"]
|
||||
# High noise
|
||||
node._on_energy(_energy_msg(1.0))
|
||||
node._last_energy_t = time.monotonic()
|
||||
node._publish_cb()
|
||||
v_high = pub.msgs[-1].data
|
||||
# Silence
|
||||
node._on_energy(_energy_msg(0.0))
|
||||
node._last_energy_t = time.monotonic()
|
||||
node._publish_cb()
|
||||
v_low = pub.msgs[-1].data
|
||||
self.assertGreater(v_high, v_low)
|
||||
|
||||
|
||||
# ── Tests: source and config ──────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open(_SRC) as f: cls.src = f.read()
|
||||
|
||||
def test_issue_tag(self): self.assertIn("#289", self.src)
|
||||
def test_energy_topic(self): self.assertIn("/social/speech/energy", self.src)
|
||||
def test_volume_topic(self): self.assertIn("/saltybot/tts_volume", self.src)
|
||||
def test_min_volume_param(self): self.assertIn("min_volume", self.src)
|
||||
def test_max_volume_param(self): self.assertIn("max_volume", self.src)
|
||||
def test_noise_floor_param(self):self.assertIn("noise_floor", self.src)
|
||||
def test_noise_ceil_param(self): self.assertIn("noise_ceil", self.src)
|
||||
def test_gamma_param(self): self.assertIn("curve_gamma", self.src)
|
||||
def test_alpha_param(self): self.assertIn("smoothing_alpha", self.src)
|
||||
def test_ema_update(self): self.assertIn("ema_update", self.src)
|
||||
def test_compute_target(self): self.assertIn("compute_target_volume", self.src)
|
||||
def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
|
||||
def test_main_defined(self): self.assertIn("def main", self.src)
|
||||
def test_stale_timeout(self): self.assertIn("stale_timeout", self.src)
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
_CONFIG = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/config/volume_adjust_params.yaml"
|
||||
)
|
||||
_LAUNCH = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/launch/volume_adjust.launch.py"
|
||||
)
|
||||
_SETUP = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/setup.py"
|
||||
)
|
||||
|
||||
def test_config_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._CONFIG))
|
||||
|
||||
def test_config_min_volume(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("min_volume", c)
|
||||
|
||||
def test_config_max_volume(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("max_volume", c)
|
||||
|
||||
def test_config_gamma(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("curve_gamma", c)
|
||||
|
||||
def test_config_smoothing(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("smoothing_alpha", c)
|
||||
|
||||
def test_launch_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._LAUNCH))
|
||||
|
||||
def test_launch_min_volume_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("min_volume", c)
|
||||
|
||||
def test_launch_gamma_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("curve_gamma", c)
|
||||
|
||||
def test_entry_point(self):
|
||||
with open(self._SETUP) as f: c = f.read()
|
||||
self.assertIn("volume_adjust_node", c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user