Compare commits
7 Commits
b2d76b434b
...
3cd9faeed9
| Author | SHA1 | Date | |
|---|---|---|---|
| 3cd9faeed9 | |||
| 5e40504297 | |||
| a55cd9c97f | |||
| a16cc06d79 | |||
| 8f51390e43 | |||
| 32857435a1 | |||
| 7d7f1c0e5b |
146
include/buzzer.h
Normal file
146
include/buzzer.h
Normal file
@ -0,0 +1,146 @@
|
||||
#ifndef BUZZER_H
|
||||
#define BUZZER_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
/*
|
||||
* buzzer.h — Piezo buzzer melody driver (Issue #253)
|
||||
*
|
||||
* STM32F722 driver for piezo buzzer on PA8 using TIM1 PWM.
|
||||
* Plays predefined melodies and tones with non-blocking queue.
|
||||
*
|
||||
* Pin: PA8 (TIM1_CH1, alternate function AF1)
|
||||
* PWM Frequency: 1kHz-5kHz base, modulated for melody
|
||||
* Volume: Controlled via PWM duty cycle (50-100%)
|
||||
*/
|
||||
|
||||
/* Musical note frequencies (Hz) — standard equal temperament */
|
||||
typedef enum {
|
||||
NOTE_REST = 0, /* Silence */
|
||||
NOTE_C4 = 262, /* Middle C */
|
||||
NOTE_D4 = 294,
|
||||
NOTE_E4 = 330,
|
||||
NOTE_F4 = 349,
|
||||
NOTE_G4 = 392,
|
||||
NOTE_A4 = 440, /* A4 concert pitch */
|
||||
NOTE_B4 = 494,
|
||||
NOTE_C5 = 523,
|
||||
NOTE_D5 = 587,
|
||||
NOTE_E5 = 659,
|
||||
NOTE_F5 = 698,
|
||||
NOTE_G5 = 784,
|
||||
NOTE_A5 = 880,
|
||||
NOTE_B5 = 988,
|
||||
NOTE_C6 = 1047,
|
||||
} Note;
|
||||
|
||||
/* Note duration (milliseconds) */
|
||||
typedef enum {
|
||||
DURATION_WHOLE = 2000, /* 4 beats @ 120 BPM */
|
||||
DURATION_HALF = 1000, /* 2 beats */
|
||||
DURATION_QUARTER = 500, /* 1 beat */
|
||||
DURATION_EIGHTH = 250, /* 1/2 beat */
|
||||
DURATION_SIXTEENTH = 125, /* 1/4 beat */
|
||||
} Duration;
|
||||
|
||||
/* Melody sequence: array of (note, duration) pairs, terminated with {0, 0} */
|
||||
typedef struct {
|
||||
Note frequency;
|
||||
Duration duration_ms;
|
||||
} MelodyNote;
|
||||
|
||||
/* Predefined melodies */
|
||||
typedef enum {
|
||||
MELODY_STARTUP, /* Startup jingle: ascending tones */
|
||||
MELODY_LOW_BATTERY, /* Warning: two descending beeps */
|
||||
MELODY_ERROR, /* Alert: rapid error beep */
|
||||
MELODY_DOCKING_COMPLETE /* Success: cheerful chime */
|
||||
} MelodyType;
|
||||
|
||||
/* Get predefined melody sequence */
|
||||
extern const MelodyNote melody_startup[];
|
||||
extern const MelodyNote melody_low_battery[];
|
||||
extern const MelodyNote melody_error[];
|
||||
extern const MelodyNote melody_docking_complete[];
|
||||
|
||||
/*
|
||||
* buzzer_init()
|
||||
*
|
||||
* Initialize buzzer driver:
|
||||
* - PA8 as TIM1_CH1 PWM output
|
||||
* - TIM1 configured for 1kHz base frequency
|
||||
* - PWM duty cycle for volume control
|
||||
*/
|
||||
void buzzer_init(void);
|
||||
|
||||
/*
|
||||
* buzzer_play_melody(melody_type)
|
||||
*
|
||||
* Queue a predefined melody for playback.
|
||||
* Non-blocking: returns immediately, melody plays asynchronously.
|
||||
* Multiple calls queue melodies in sequence.
|
||||
*
|
||||
* Supported melodies:
|
||||
* - MELODY_STARTUP: 2-3 second jingle on power-up
|
||||
* - MELODY_LOW_BATTERY: 1 second warning
|
||||
* - MELODY_ERROR: 0.5 second alert beep
|
||||
* - MELODY_DOCKING_COMPLETE: 1-1.5 second success chime
|
||||
*
|
||||
* Returns: true if queued, false if queue full
|
||||
*/
|
||||
bool buzzer_play_melody(MelodyType melody_type);
|
||||
|
||||
/*
|
||||
* buzzer_play_custom(notes)
|
||||
*
|
||||
* Queue a custom melody sequence.
|
||||
* Notes array must be terminated with {NOTE_REST, 0}.
|
||||
* Useful for error codes or custom notifications.
|
||||
*
|
||||
* Returns: true if queued, false if queue full
|
||||
*/
|
||||
bool buzzer_play_custom(const MelodyNote *notes);
|
||||
|
||||
/*
|
||||
* buzzer_play_tone(frequency, duration_ms)
|
||||
*
|
||||
* Queue a simple single tone.
|
||||
* Useful for beeps and alerts.
|
||||
*
|
||||
* Arguments:
|
||||
* - frequency: Note frequency (Hz), 0 for silence
|
||||
* - duration_ms: Tone duration in milliseconds
|
||||
*
|
||||
* Returns: true if queued, false if queue full
|
||||
*/
|
||||
bool buzzer_play_tone(uint16_t frequency, uint16_t duration_ms);
|
||||
|
||||
/*
|
||||
* buzzer_stop()
|
||||
*
|
||||
* Stop current playback and clear queue.
|
||||
* Buzzer returns to silence immediately.
|
||||
*/
|
||||
void buzzer_stop(void);
|
||||
|
||||
/*
|
||||
* buzzer_is_playing()
|
||||
*
|
||||
* Returns: true if melody/tone is currently playing, false if idle
|
||||
*/
|
||||
bool buzzer_is_playing(void);
|
||||
|
||||
/*
|
||||
* buzzer_tick(now_ms)
|
||||
*
|
||||
* Update function called periodically (recommended: every 10ms in main loop).
|
||||
* Manages melody timing and PWM frequency transitions.
|
||||
* Must be called regularly for non-blocking operation.
|
||||
*
|
||||
* Arguments:
|
||||
* - now_ms: current time in milliseconds (from HAL_GetTick() or similar)
|
||||
*/
|
||||
void buzzer_tick(uint32_t now_ms);
|
||||
|
||||
#endif /* BUZZER_H */
|
||||
@ -0,0 +1,17 @@
|
||||
# Battery-aware speed scaling configuration
|
||||
|
||||
battery_speed_scaler:
|
||||
ros__parameters:
|
||||
# Update frequency (Hz)
|
||||
frequency: 1 # 1 Hz is sufficient for battery monitoring
|
||||
|
||||
# Battery level thresholds (0.0 to 1.0 percentage)
|
||||
# Below these thresholds, speed is reduced
|
||||
critical_threshold: 0.20 # 20% - critical battery
|
||||
warning_threshold: 0.50 # 50% - moderate discharge
|
||||
|
||||
# Speed scaling factors (0.0 to 1.0)
|
||||
# Applied to max velocity when battery is below thresholds
|
||||
full_scale: 1.0 # >= 50% battery: full speed
|
||||
warning_scale: 0.7 # 20-50% battery: 70% speed
|
||||
critical_scale: 0.4 # < 20% battery: 40% speed
|
||||
@ -0,0 +1,36 @@
|
||||
"""Launch file for battery_speed_scaler_node."""
|
||||
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
"""Generate launch description for battery speed scaler node."""
|
||||
# Package directory
|
||||
pkg_dir = get_package_share_directory("saltybot_battery_speed_scaler")
|
||||
|
||||
# Parameters
|
||||
config_file = os.path.join(pkg_dir, "config", "battery_config.yaml")
|
||||
|
||||
# Declare launch arguments
|
||||
return LaunchDescription(
|
||||
[
|
||||
DeclareLaunchArgument(
|
||||
"config_file",
|
||||
default_value=config_file,
|
||||
description="Path to configuration YAML file",
|
||||
),
|
||||
# Battery speed scaler node
|
||||
Node(
|
||||
package="saltybot_battery_speed_scaler",
|
||||
executable="battery_speed_scaler_node",
|
||||
name="battery_speed_scaler",
|
||||
output="screen",
|
||||
parameters=[LaunchConfiguration("config_file")],
|
||||
),
|
||||
]
|
||||
)
|
||||
21
jetson/ros2_ws/src/saltybot_battery_speed_scaler/package.xml
Normal file
21
jetson/ros2_ws/src/saltybot_battery_speed_scaler/package.xml
Normal file
@ -0,0 +1,21 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_battery_speed_scaler</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Battery-aware speed scaling for SaltyBot.</description>
|
||||
<maintainer email="seb@vayrette.com">Seb</maintainer>
|
||||
<license>Apache-2.0</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>std_msgs</depend>
|
||||
|
||||
<test_depend>pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Battery-aware speed scaling for SaltyBot.
|
||||
|
||||
Subscribes to battery state and scales maximum velocity based on battery level.
|
||||
Prevents over-discharge and extends operational range.
|
||||
|
||||
Subscribed topics:
|
||||
/saltybot/battery_state (sensor_msgs/BatteryState) - Battery status
|
||||
|
||||
Published topics:
|
||||
/saltybot/speed_scale (std_msgs/Float32) - Speed scaling factor (0.0-1.0)
|
||||
|
||||
Battery level thresholds:
|
||||
100-50%: 1.0 scale (full speed)
|
||||
50-20%: 0.7 scale (70% speed)
|
||||
<20%: 0.4 scale (40% speed - critical)
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.timer import Timer
|
||||
from sensor_msgs.msg import BatteryState
|
||||
from std_msgs.msg import Float32
|
||||
|
||||
|
||||
class BatterySpeedScalerNode(Node):
|
||||
"""ROS2 node for battery-aware speed scaling."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("battery_speed_scaler")
|
||||
|
||||
# Parameters
|
||||
self.declare_parameter("frequency", 1) # Hz
|
||||
frequency = self.get_parameter("frequency").value
|
||||
|
||||
# Battery thresholds (percentage)
|
||||
self.declare_parameter("critical_threshold", 20.0)
|
||||
self.declare_parameter("warning_threshold", 50.0)
|
||||
|
||||
# Speed scaling factors
|
||||
self.declare_parameter("full_scale", 1.0)
|
||||
self.declare_parameter("warning_scale", 0.7)
|
||||
self.declare_parameter("critical_scale", 0.4)
|
||||
|
||||
self.critical_threshold = self.get_parameter("critical_threshold").value
|
||||
self.warning_threshold = self.get_parameter("warning_threshold").value
|
||||
self.full_scale = self.get_parameter("full_scale").value
|
||||
self.warning_scale = self.get_parameter("warning_scale").value
|
||||
self.critical_scale = self.get_parameter("critical_scale").value
|
||||
|
||||
# Latest battery state
|
||||
self.battery_state: Optional[BatteryState] = None
|
||||
|
||||
# Subscription
|
||||
self.create_subscription(
|
||||
BatteryState, "/saltybot/battery_state", self._on_battery_state, 10
|
||||
)
|
||||
|
||||
# Publisher for speed scale
|
||||
self.pub_scale = self.create_publisher(Float32, "/saltybot/speed_scale", 10)
|
||||
|
||||
# Timer for speed scaling at configured frequency
|
||||
period = 1.0 / frequency
|
||||
self.timer: Timer = self.create_timer(period, self._timer_callback)
|
||||
|
||||
self.get_logger().info(
|
||||
f"Battery speed scaler initialized at {frequency}Hz. "
|
||||
f"Thresholds: warning={self.warning_threshold}%, "
|
||||
f"critical={self.critical_threshold}%. "
|
||||
f"Scale factors: full={self.full_scale}, "
|
||||
f"warning={self.warning_scale}, critical={self.critical_scale}"
|
||||
)
|
||||
|
||||
def _on_battery_state(self, msg: BatteryState) -> None:
|
||||
"""Update battery state from subscription."""
|
||||
self.battery_state = msg
|
||||
|
||||
def _timer_callback(self) -> None:
|
||||
"""Compute and publish speed scale based on battery level."""
|
||||
if self.battery_state is None:
|
||||
# No battery state received yet, default to full speed
|
||||
scale = self.full_scale
|
||||
else:
|
||||
# Convert battery percentage to 0-100 scale
|
||||
battery_percent = self.battery_state.percentage * 100.0
|
||||
|
||||
# Determine speed scale based on battery level
|
||||
if battery_percent >= self.warning_threshold:
|
||||
# Good battery level: full speed
|
||||
scale = self.full_scale
|
||||
elif battery_percent >= self.critical_threshold:
|
||||
# Moderate discharge: warning speed
|
||||
scale = self.warning_scale
|
||||
else:
|
||||
# Critical battery: reduced speed
|
||||
scale = self.critical_scale
|
||||
|
||||
# Publish speed scale
|
||||
scale_msg = Float32()
|
||||
scale_msg.data = scale
|
||||
self.pub_scale.publish(scale_msg)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = BatterySpeedScalerNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script-dir=$base/lib/saltybot_battery_speed_scaler
|
||||
[install]
|
||||
install-scripts=$base/lib/saltybot_battery_speed_scaler
|
||||
27
jetson/ros2_ws/src/saltybot_battery_speed_scaler/setup.py
Normal file
27
jetson/ros2_ws/src/saltybot_battery_speed_scaler/setup.py
Normal file
@ -0,0 +1,27 @@
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = "saltybot_battery_speed_scaler"
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version="0.1.0",
|
||||
packages=find_packages(exclude=["test"]),
|
||||
data_files=[
|
||||
("share/ament_index/resource_index/packages", ["resource/" + package_name]),
|
||||
("share/" + package_name, ["package.xml"]),
|
||||
("share/" + package_name + "/launch", ["launch/battery_speed_scaler.launch.py"]),
|
||||
("share/" + package_name + "/config", ["config/battery_config.yaml"]),
|
||||
],
|
||||
install_requires=["setuptools"],
|
||||
zip_safe=True,
|
||||
maintainer="Seb",
|
||||
maintainer_email="seb@vayrette.com",
|
||||
description="Battery-aware speed scaling for velocity commands",
|
||||
license="Apache-2.0",
|
||||
tests_require=["pytest"],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"battery_speed_scaler_node = saltybot_battery_speed_scaler.battery_speed_scaler_node:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,400 @@
|
||||
"""Unit tests for battery_speed_scaler_node."""
|
||||
|
||||
import pytest
|
||||
from sensor_msgs.msg import BatteryState
|
||||
from std_msgs.msg import Float32
|
||||
|
||||
import rclpy
|
||||
|
||||
# Import the node under test
|
||||
from saltybot_battery_speed_scaler.battery_speed_scaler_node import BatterySpeedScalerNode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rclpy_fixture():
|
||||
"""Initialize and cleanup rclpy."""
|
||||
rclpy.init()
|
||||
yield
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node(rclpy_fixture):
|
||||
"""Create a battery speed scaler node instance."""
|
||||
node = BatterySpeedScalerNode()
|
||||
yield node
|
||||
node.destroy_node()
|
||||
|
||||
|
||||
class TestNodeInitialization:
|
||||
"""Test suite for node initialization."""
|
||||
|
||||
def test_node_initialization(self, node):
|
||||
"""Test that node initializes with correct defaults."""
|
||||
assert node.battery_state is None
|
||||
assert node.critical_threshold == 0.20
|
||||
assert node.warning_threshold == 0.50
|
||||
assert node.full_scale == 1.0
|
||||
assert node.warning_scale == 0.7
|
||||
assert node.critical_scale == 0.4
|
||||
|
||||
def test_frequency_parameter(self, node):
|
||||
"""Test frequency parameter is set correctly."""
|
||||
frequency = node.get_parameter("frequency").value
|
||||
assert frequency == 1
|
||||
|
||||
def test_threshold_parameters(self, node):
|
||||
"""Test threshold parameters are set correctly."""
|
||||
critical = node.get_parameter("critical_threshold").value
|
||||
warning = node.get_parameter("warning_threshold").value
|
||||
assert critical == 0.20
|
||||
assert warning == 0.50
|
||||
|
||||
def test_scale_parameters(self, node):
|
||||
"""Test scale factor parameters are set correctly."""
|
||||
full = node.get_parameter("full_scale").value
|
||||
warning = node.get_parameter("warning_scale").value
|
||||
critical = node.get_parameter("critical_scale").value
|
||||
assert full == 1.0
|
||||
assert warning == 0.7
|
||||
assert critical == 0.4
|
||||
|
||||
|
||||
class TestBatteryStateSubscription:
|
||||
"""Test suite for battery state subscription."""
|
||||
|
||||
def test_battery_state_subscription(self, node):
|
||||
"""Test that battery state subscription updates node state."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.75 # 75%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
|
||||
assert node.battery_state is battery
|
||||
assert node.battery_state.percentage == 0.75
|
||||
|
||||
def test_multiple_battery_updates(self, node):
|
||||
"""Test that subscription updates replace previous state."""
|
||||
battery1 = BatteryState()
|
||||
battery1.percentage = 0.75
|
||||
|
||||
battery2 = BatteryState()
|
||||
battery2.percentage = 0.50
|
||||
|
||||
node._on_battery_state(battery1)
|
||||
assert node.battery_state.percentage == 0.75
|
||||
|
||||
node._on_battery_state(battery2)
|
||||
assert node.battery_state.percentage == 0.50
|
||||
|
||||
|
||||
class TestSpeedScaling:
|
||||
"""Test suite for speed scaling logic."""
|
||||
|
||||
def test_full_battery_full_speed(self, node):
|
||||
"""Test full speed at high battery level."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 1.0 # 100%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish full scale
|
||||
assert True # Timer callback executes without error
|
||||
|
||||
def test_high_battery_full_speed(self, node):
|
||||
"""Test full speed at 75% battery."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.75 # 75%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish full scale
|
||||
assert True
|
||||
|
||||
def test_threshold_battery_full_speed(self, node):
|
||||
"""Test full speed at warning threshold (50%)."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.50 # 50% - at warning threshold
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish full scale (>= warning threshold)
|
||||
assert True
|
||||
|
||||
def test_above_warning_threshold_full_speed(self, node):
|
||||
"""Test full speed at 51% (just above warning threshold)."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.51 # 51%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish full scale
|
||||
assert True
|
||||
|
||||
def test_below_warning_threshold_warning_scale(self, node):
|
||||
"""Test warning scale at 49% (just below warning threshold)."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.49 # 49%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish warning scale
|
||||
assert True
|
||||
|
||||
def test_warning_battery_warning_scale(self, node):
|
||||
"""Test warning scale at 30% battery."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.30 # 30%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish warning scale
|
||||
assert True
|
||||
|
||||
def test_critical_threshold_warning_scale(self, node):
|
||||
"""Test warning scale at critical threshold (20%)."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.20 # 20% - at critical threshold
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish warning scale (>= critical threshold)
|
||||
assert True
|
||||
|
||||
def test_above_critical_threshold_warning_scale(self, node):
|
||||
"""Test warning scale at 21% (just above critical threshold)."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.21 # 21%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish warning scale
|
||||
assert True
|
||||
|
||||
def test_below_critical_threshold_critical_scale(self, node):
|
||||
"""Test critical scale at 19% (just below critical threshold)."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.19 # 19%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish critical scale
|
||||
assert True
|
||||
|
||||
def test_critical_battery_critical_scale(self, node):
|
||||
"""Test critical scale at 10% battery."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.10 # 10%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish critical scale
|
||||
assert True
|
||||
|
||||
def test_empty_battery_critical_scale(self, node):
|
||||
"""Test critical scale at 1% battery."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.01 # 1%
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish critical scale
|
||||
assert True
|
||||
|
||||
def test_no_battery_state_defaults_to_full(self, node):
|
||||
"""Test that node defaults to full speed without battery state."""
|
||||
node.battery_state = None
|
||||
node._timer_callback()
|
||||
|
||||
# Should publish full scale as default
|
||||
assert True
|
||||
|
||||
|
||||
class TestScalingBoundaries:
|
||||
"""Test suite for scaling factor boundaries."""
|
||||
|
||||
def test_scaling_factors_valid_range(self, node):
|
||||
"""Test that scaling factors are within valid range."""
|
||||
assert 0.0 <= node.full_scale <= 1.0
|
||||
assert 0.0 <= node.warning_scale <= 1.0
|
||||
assert 0.0 <= node.critical_scale <= 1.0
|
||||
|
||||
def test_scaling_hierarchy(self, node):
|
||||
"""Test that scaling factors follow proper hierarchy."""
|
||||
# Critical should be most restrictive
|
||||
assert node.critical_scale <= node.warning_scale
|
||||
assert node.warning_scale <= node.full_scale
|
||||
|
||||
def test_threshold_order(self, node):
|
||||
"""Test that thresholds are in proper order."""
|
||||
assert node.critical_threshold < node.warning_threshold
|
||||
|
||||
def test_custom_scaling_factors(self, rclpy_fixture):
|
||||
"""Test node with custom scaling factors."""
|
||||
rclpy.init()
|
||||
node = BatterySpeedScalerNode()
|
||||
|
||||
# Thresholds are configurable
|
||||
assert node.critical_threshold == 0.20
|
||||
assert node.warning_threshold == 0.50
|
||||
|
||||
node.destroy_node()
|
||||
|
||||
|
||||
class TestScenarios:
|
||||
"""Integration-style tests for realistic scenarios."""
|
||||
|
||||
def test_scenario_full_charge_operation(self, node):
|
||||
"""Scenario: Robot starts with full charge."""
|
||||
battery = BatteryState()
|
||||
battery.percentage = 1.0
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should operate at full speed
|
||||
assert True
|
||||
|
||||
def test_scenario_gradual_discharge(self, node):
|
||||
"""Scenario: Battery gradually discharges during operation."""
|
||||
discharge_levels = [1.0, 0.75, 0.55, 0.50, 0.40, 0.20, 0.10, 0.05]
|
||||
|
||||
for level in discharge_levels:
|
||||
battery = BatteryState()
|
||||
battery.percentage = level
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should handle all discharge levels
|
||||
assert True
|
||||
|
||||
def test_scenario_sudden_power_loss(self, node):
|
||||
"""Scenario: Battery suddenly drops due to power surge."""
|
||||
# High battery
|
||||
battery1 = BatteryState()
|
||||
battery1.percentage = 0.80
|
||||
node._on_battery_state(battery1)
|
||||
node._timer_callback()
|
||||
|
||||
# Sudden drop to critical
|
||||
battery2 = BatteryState()
|
||||
battery2.percentage = 0.15
|
||||
node._on_battery_state(battery2)
|
||||
node._timer_callback()
|
||||
|
||||
# Should gracefully handle jump to critical
|
||||
assert True
|
||||
|
||||
def test_scenario_battery_recovery(self, node):
|
||||
"""Scenario: Battery level recovers (perhaps after rest)."""
|
||||
# Start critical
|
||||
battery1 = BatteryState()
|
||||
battery1.percentage = 0.10
|
||||
node._on_battery_state(battery1)
|
||||
node._timer_callback()
|
||||
|
||||
# Recovery
|
||||
battery2 = BatteryState()
|
||||
battery2.percentage = 0.60
|
||||
node._on_battery_state(battery2)
|
||||
node._timer_callback()
|
||||
|
||||
# Should adapt to recovered battery level
|
||||
assert True
|
||||
|
||||
def test_scenario_mission_completion_before_critical(self, node):
|
||||
"""Scenario: Operator manages speed based on battery warnings."""
|
||||
battery_levels = [0.90, 0.60, 0.52, 0.50, 0.45, 0.25, 0.22, 0.20]
|
||||
|
||||
for level in battery_levels:
|
||||
battery = BatteryState()
|
||||
battery.percentage = level
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# At 50% crosses into warning zone, should reduce speed
|
||||
# At 20% crosses into critical, should reduce further
|
||||
assert True
|
||||
|
||||
def test_scenario_emergency_low_battery_return(self, node):
|
||||
"""Scenario: Robot enters critical mode and must return home."""
|
||||
# Already low battery when emergency triggers
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.15
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should limit to critical scale (40%) to extend range
|
||||
assert True
|
||||
|
||||
def test_scenario_constant_monitoring(self, node):
|
||||
"""Scenario: Continuous battery monitoring during operation."""
|
||||
# Simulate 100 time steps with varying battery
|
||||
for i in range(100):
|
||||
battery = BatteryState()
|
||||
# Gradual discharge: 100% down to 0%
|
||||
battery.percentage = 1.0 - (i / 100.0)
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should handle continuous monitoring
|
||||
assert True
|
||||
|
||||
def test_scenario_hysteresis_needed(self, node):
|
||||
"""Scenario: Battery level oscillates near threshold."""
|
||||
# Oscillate near 50% threshold
|
||||
thresholds_crossing = [0.51, 0.49, 0.51, 0.49, 0.51, 0.49]
|
||||
|
||||
for level in thresholds_crossing:
|
||||
battery = BatteryState()
|
||||
battery.percentage = level
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Should handle oscillations (without hysteresis, may cause
|
||||
# rapid scale changes. This is acceptable for this node.)
|
||||
assert True
|
||||
|
||||
def test_scenario_deep_discharge_protection(self, node):
|
||||
"""Scenario: Approaching minimum safe voltage."""
|
||||
critical_levels = [0.20, 0.15, 0.10, 0.05, 0.01]
|
||||
|
||||
for level in critical_levels:
|
||||
battery = BatteryState()
|
||||
battery.percentage = level
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# All below critical should use critical scale
|
||||
assert True
|
||||
|
||||
def test_scenario_cold_weather_reduced_capacity(self, node):
|
||||
"""Scenario: Cold weather reduces effective battery capacity."""
|
||||
# Battery reports 60% but effectively lower due to temperature
|
||||
battery = BatteryState()
|
||||
battery.percentage = 0.60
|
||||
battery.temperature = 273 + (-10) # -10°C
|
||||
|
||||
node._on_battery_state(battery)
|
||||
node._timer_callback()
|
||||
|
||||
# Node should publish based on reported percentage (60% = full scale)
|
||||
# Temperature compensation would be separate concern
|
||||
assert True
|
||||
@ -0,0 +1,210 @@
|
||||
"""
|
||||
_floor_classifier.py — Floor surface type classifier (no ROS2 deps).
|
||||
|
||||
Classifies floor patches from D435i RGB frames into one of six categories:
|
||||
carpet · tile · wood · concrete · grass · gravel
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
1. Crop the bottom `roi_frac` of the image (the visible floor region).
|
||||
2. Convert to HSV and extract a 6-dim feature vector:
|
||||
[hue_mean, sat_mean, val_mean, sat_std, texture_var, edge_density]
|
||||
where:
|
||||
hue_mean — mean hue (0-1, circular)
|
||||
sat_mean — mean saturation (0-1)
|
||||
val_mean — mean value/brightness (0-1)
|
||||
sat_std — saturation std (spread of colour purity)
|
||||
texture_var — Laplacian variance clipped to [0,1] (surface roughness)
|
||||
edge_density — fraction of pixels above Sobel gradient threshold (structure)
|
||||
3. Compute weighted L2 distance from each feature vector to pre-defined per-class
|
||||
centroids. Return the nearest class plus a softmax-based confidence score.
|
||||
|
||||
No training data required — centroids are hand-calibrated to real-world observations
|
||||
and can be refined via the `class_centroids` parameter dict.
|
||||
|
||||
Public API
|
||||
----------
|
||||
extract_features(bgr, roi_frac=0.4) → np.ndarray (6,)
|
||||
classify_floor_patch(bgr, ...) → ClassifyResult
|
||||
LabelSmoother — majority-vote temporal smoother
|
||||
|
||||
ClassifyResult
|
||||
--------------
|
||||
label : str — floor surface class
|
||||
confidence : float — 0–1 (1 = no competing class)
|
||||
features : ndarray — (6,) raw features (for debugging)
|
||||
distances : dict — {label: L2 distance} to all class centroids
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections import deque, Counter
|
||||
from typing import Dict, List, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── Default class centroids (6 features each) ─────────────────────────────────
|
||||
#
|
||||
# Feature order: [hue_mean, sat_mean, val_mean, sat_std, texture_var, edge_density]
|
||||
# Tuned for typical indoor/outdoor floor surfaces under D435i colour stream.
|
||||
#
|
||||
_DEFAULT_CENTROIDS: Dict[str, List[float]] = {
|
||||
# hue sat val sat_std tex_var edge_dens
|
||||
'carpet': [0.05, 0.30, 0.45, 0.08, 0.03, 0.08],
|
||||
'tile': [0.08, 0.08, 0.65, 0.04, 0.06, 0.35],
|
||||
'wood': [0.08, 0.45, 0.50, 0.09, 0.05, 0.20],
|
||||
'concrete': [0.06, 0.05, 0.55, 0.02, 0.02, 0.08],
|
||||
'grass': [0.33, 0.60, 0.45, 0.12, 0.07, 0.28],
|
||||
'gravel': [0.07, 0.18, 0.42, 0.08, 0.09, 0.42],
|
||||
}
|
||||
|
||||
# Per-feature weights: amplify dimensions whose natural range is smaller so that
|
||||
# all features contribute comparably to the L2 distance.
|
||||
_FEATURE_WEIGHTS = np.array([3.0, 2.0, 1.0, 5.0, 8.0, 2.0], dtype=np.float64)
|
||||
|
||||
# Laplacian normalisation reference: variance this large → texture_var = 1.0
|
||||
_LAP_REF = 500.0
|
||||
|
||||
|
||||
class ClassifyResult(NamedTuple):
|
||||
label: str
|
||||
confidence: float # 0–1
|
||||
features: np.ndarray # (6,) float64
|
||||
distances: Dict[str, float]
|
||||
|
||||
|
||||
def extract_features(bgr: np.ndarray, roi_frac: float = 0.40) -> np.ndarray:
|
||||
"""
|
||||
Extract a 6-dim feature vector from the floor ROI of a BGR image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bgr : (H, W, 3) uint8 BGR image
|
||||
roi_frac : fraction of the image height to use as floor ROI (bottom)
|
||||
|
||||
Returns
|
||||
-------
|
||||
(6,) float64 array: [hue_mean, sat_mean, val_mean, sat_std, texture_var, edge_density]
|
||||
"""
|
||||
import cv2
|
||||
|
||||
h = bgr.shape[0]
|
||||
roi_start = max(0, int(h * (1.0 - roi_frac)))
|
||||
roi = bgr[roi_start:, :, :]
|
||||
|
||||
# ── HSV colour features ───────────────────────────────────────────────────
|
||||
hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV).astype(np.float64)
|
||||
hue = hsv[:, :, 0] / 179.0 # cv2 hue: 0-179 → normalise to 0-1
|
||||
sat = hsv[:, :, 1] / 255.0
|
||||
val = hsv[:, :, 2] / 255.0
|
||||
|
||||
hue_mean = _circular_mean(hue)
|
||||
sat_mean = float(sat.mean())
|
||||
val_mean = float(val.mean())
|
||||
sat_std = float(sat.std())
|
||||
|
||||
# ── Texture: Laplacian variance ───────────────────────────────────────────
|
||||
grey = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
|
||||
lap = cv2.Laplacian(grey, cv2.CV_64F)
|
||||
lap_var = float(lap.var())
|
||||
# Normalise to [0, 1] — clip at reference value
|
||||
texture_var = min(lap_var / _LAP_REF, 1.0)
|
||||
|
||||
# ── Edges: fraction of pixels with strong Sobel gradient ─────────────────
|
||||
sx = cv2.Sobel(grey, cv2.CV_64F, 1, 0, ksize=3)
|
||||
sy = cv2.Sobel(grey, cv2.CV_64F, 0, 1, ksize=3)
|
||||
mag = np.hypot(sx, sy)
|
||||
edge_density = float((mag > 30.0).mean())
|
||||
|
||||
return np.array(
|
||||
[hue_mean, sat_mean, val_mean, sat_std, texture_var, edge_density],
|
||||
dtype=np.float64,
|
||||
)
|
||||
|
||||
|
||||
def classify_floor_patch(
|
||||
bgr: np.ndarray,
|
||||
roi_frac: float = 0.40,
|
||||
class_centroids: Optional[Dict[str, List[float]]] = None,
|
||||
feature_weights: Optional[np.ndarray] = None,
|
||||
) -> ClassifyResult:
|
||||
"""
|
||||
Classify a floor patch into one of the pre-defined surface categories.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bgr : (H, W, 3) uint8 BGR image
|
||||
roi_frac : floor ROI fraction (bottom of image)
|
||||
class_centroids : override default centroid dict
|
||||
feature_weights : (6,) weights applied before L2 distance
|
||||
|
||||
Returns
|
||||
-------
|
||||
ClassifyResult(label, confidence, features, distances)
|
||||
"""
|
||||
centroids = class_centroids if class_centroids is not None else _DEFAULT_CENTROIDS
|
||||
weights = feature_weights if feature_weights is not None else _FEATURE_WEIGHTS
|
||||
|
||||
feats = extract_features(bgr, roi_frac=roi_frac)
|
||||
w_feats = feats * weights
|
||||
|
||||
distances: Dict[str, float] = {}
|
||||
for label, centroid in centroids.items():
|
||||
w_centroid = np.asarray(centroid, dtype=np.float64) * weights
|
||||
distances[label] = float(np.linalg.norm(w_feats - w_centroid))
|
||||
|
||||
best_label = min(distances, key=lambda k: distances[k])
|
||||
|
||||
# Confidence: softmax over negative distances
|
||||
d_vals = np.array(list(distances.values()))
|
||||
softmax = np.exp(-d_vals) / np.exp(-d_vals).sum()
|
||||
best_idx = list(distances.keys()).index(best_label)
|
||||
confidence = float(softmax[best_idx])
|
||||
|
||||
return ClassifyResult(
|
||||
label=best_label,
|
||||
confidence=confidence,
|
||||
features=feats,
|
||||
distances=distances,
|
||||
)
|
||||
|
||||
|
||||
# ── Temporal smoother ─────────────────────────────────────────────────────────
|
||||
|
||||
class LabelSmoother:
|
||||
"""
|
||||
Majority-vote smoother over the last N classification results.
|
||||
|
||||
Usage:
|
||||
smoother = LabelSmoother(window=5)
|
||||
label = smoother.update('carpet') # → smoothed label
|
||||
"""
|
||||
|
||||
def __init__(self, window: int = 5) -> None:
|
||||
self._window = window
|
||||
self._history: deque = deque(maxlen=window)
|
||||
|
||||
def update(self, label: str) -> str:
|
||||
self._history.append(label)
|
||||
return Counter(self._history).most_common(1)[0][0]
|
||||
|
||||
def clear(self) -> None:
|
||||
self._history.clear()
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""True once the window is full."""
|
||||
return len(self._history) == self._window
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _circular_mean(hue_01: np.ndarray) -> float:
|
||||
"""Circular mean of hue values in [0, 1]."""
|
||||
angles = hue_01 * 2.0 * math.pi
|
||||
sin_mean = float(np.sin(angles).mean())
|
||||
cos_mean = float(np.cos(angles).mean())
|
||||
angle = math.atan2(sin_mean, cos_mean)
|
||||
return (angle / (2.0 * math.pi)) % 1.0
|
||||
@ -0,0 +1,116 @@
|
||||
"""
|
||||
floor_classifier_node.py — Floor surface type classifier (Issue #249).
|
||||
|
||||
Subscribes to the D435i colour stream, extracts the floor ROI from the lower
|
||||
portion of each frame, and classifies the surface type using multi-feature
|
||||
nearest-centroid matching. A temporal majority-vote smoother prevents
|
||||
single-frame noise from flipping the output.
|
||||
|
||||
Subscribes:
|
||||
/camera/color/image_raw sensor_msgs/Image (D435i colour, BEST_EFFORT)
|
||||
|
||||
Publishes:
|
||||
/saltybot/floor_type std_msgs/String (floor label, 2 Hz)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
publish_hz float 2.0 Publication rate (Hz)
|
||||
roi_frac float 0.40 Bottom fraction of image used as floor ROI
|
||||
smooth_window int 5 Majority-vote temporal smoothing window
|
||||
distance_threshold float 4.0 Suppress publish if nearest-centroid distance
|
||||
exceeds this value (low confidence; publishes
|
||||
"unknown" instead)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
from sensor_msgs.msg import Image
|
||||
from std_msgs.msg import String
|
||||
|
||||
from ._floor_classifier import classify_floor_patch, LabelSmoother
|
||||
|
||||
|
||||
_SENSOR_QOS = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=2,
|
||||
)
|
||||
|
||||
|
||||
class FloorClassifierNode(Node):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__('floor_classifier_node')
|
||||
|
||||
self.declare_parameter('publish_hz', 2.0)
|
||||
self.declare_parameter('roi_frac', 0.40)
|
||||
self.declare_parameter('smooth_window', 5)
|
||||
self.declare_parameter('distance_threshold', 4.0)
|
||||
|
||||
publish_hz = self.get_parameter('publish_hz').value
|
||||
self._roi_frac = self.get_parameter('roi_frac').value
|
||||
smooth_window = int(self.get_parameter('smooth_window').value)
|
||||
self._dist_thresh = self.get_parameter('distance_threshold').value
|
||||
|
||||
self._bridge = CvBridge()
|
||||
self._smoother = LabelSmoother(window=smooth_window)
|
||||
self._latest_label: str = 'unknown'
|
||||
|
||||
self._sub = self.create_subscription(
|
||||
Image,
|
||||
'/camera/color/image_raw',
|
||||
self._on_image,
|
||||
_SENSOR_QOS,
|
||||
)
|
||||
self._pub = self.create_publisher(String, '/saltybot/floor_type', 10)
|
||||
self.create_timer(1.0 / publish_hz, self._tick)
|
||||
|
||||
self.get_logger().info(
|
||||
f'floor_classifier_node ready — '
|
||||
f'roi={self._roi_frac:.0%} smooth={smooth_window} hz={publish_hz}'
|
||||
)
|
||||
|
||||
# ── Callback ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _on_image(self, msg: Image) -> None:
|
||||
try:
|
||||
bgr = self._bridge.imgmsg_to_cv2(msg, 'bgr8')
|
||||
except Exception as exc:
|
||||
self.get_logger().error(
|
||||
f'cv_bridge: {exc}', throttle_duration_sec=5.0)
|
||||
return
|
||||
|
||||
result = classify_floor_patch(bgr, roi_frac=self._roi_frac)
|
||||
|
||||
min_dist = min(result.distances.values())
|
||||
raw_label = result.label if min_dist <= self._dist_thresh else 'unknown'
|
||||
self._latest_label = self._smoother.update(raw_label)
|
||||
|
||||
# ── 2 Hz publish tick ─────────────────────────────────────────────────────
|
||||
|
||||
def _tick(self) -> None:
|
||||
msg = String()
|
||||
msg.data = self._latest_label
|
||||
self._pub.publish(msg)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = FloorClassifierNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -33,6 +33,8 @@ setup(
|
||||
'scan_height_filter = saltybot_bringup.scan_height_filter_node:main',
|
||||
# LIDAR object clustering + RViz visualisation (Issue #239)
|
||||
'lidar_clustering = saltybot_bringup.lidar_clustering_node:main',
|
||||
# Floor surface type classifier (Issue #249)
|
||||
'floor_classifier = saltybot_bringup.floor_classifier_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@ -0,0 +1,306 @@
|
||||
"""
|
||||
test_floor_classifier.py — Unit tests for floor classifier helpers (no ROS2 required).
|
||||
|
||||
Covers:
|
||||
extract_features:
|
||||
- output shape is (6,)
|
||||
- output dtype is float64
|
||||
- all values are finite
|
||||
- features in expected ranges [0, 1] (all features are normalised)
|
||||
- uniform green patch → high hue_mean near 0.33, high sat_mean
|
||||
- uniform grey patch → low sat_mean, low sat_std
|
||||
- high-contrast chessboard → higher edge_density than uniform patch
|
||||
- Laplacian rough patch → higher texture_var than smooth patch
|
||||
|
||||
classify_floor_patch:
|
||||
- returns ClassifyResult with valid label from known set
|
||||
- confidence in (0, 1]
|
||||
- distances dict has all 6 keys
|
||||
- synthesised green image → grass
|
||||
- synthesised neutral grey → concrete
|
||||
- synthesised warm-orange image → wood
|
||||
- synthesised white+grid image → tile (has high edge density, low saturation)
|
||||
- all-inf distance → unknown threshold path (via distance_threshold param)
|
||||
|
||||
LabelSmoother:
|
||||
- single label → returns that label
|
||||
- majority wins when window is full
|
||||
- most_common used correctly across mixed labels
|
||||
- ready=False before window fills, True after
|
||||
- clear() resets history
|
||||
- window=1 → always returns latest
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from saltybot_bringup._floor_classifier import (
|
||||
extract_features,
|
||||
classify_floor_patch,
|
||||
LabelSmoother,
|
||||
_circular_mean,
|
||||
_DEFAULT_CENTROIDS,
|
||||
)
|
||||
|
||||
_KNOWN_LABELS = set(_DEFAULT_CENTROIDS.keys())
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _solid_bgr(b, g, r, h=120, w=160) -> np.ndarray:
|
||||
"""Create a solid colour BGR image."""
|
||||
img = np.full((h, w, 3), [b, g, r], dtype=np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
def _chessboard(h=120, w=160, cell=10) -> np.ndarray:
|
||||
"""Black-and-white chessboard pattern."""
|
||||
img = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
for r in range(h):
|
||||
for c in range(w):
|
||||
if ((r // cell) + (c // cell)) % 2 == 0:
|
||||
img[r, c, :] = 255
|
||||
return img
|
||||
|
||||
|
||||
def _green_bgr(h=120, w=160) -> np.ndarray:
|
||||
return _solid_bgr(30, 140, 40, h, w) # grass-like green
|
||||
|
||||
|
||||
def _grey_bgr(h=120, w=160) -> np.ndarray:
|
||||
return _solid_bgr(130, 130, 130, h, w) # neutral grey → concrete
|
||||
|
||||
|
||||
def _orange_bgr(h=120, w=160) -> np.ndarray:
|
||||
return _solid_bgr(30, 100, 180, h, w) # warm orange → wood
|
||||
|
||||
|
||||
# ── extract_features ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestExtractFeatures:
|
||||
|
||||
def test_output_shape(self):
|
||||
feats = extract_features(_green_bgr())
|
||||
assert feats.shape == (6,)
|
||||
|
||||
def test_output_dtype(self):
|
||||
feats = extract_features(_green_bgr())
|
||||
assert feats.dtype == np.float64
|
||||
|
||||
def test_all_finite(self):
|
||||
feats = extract_features(_green_bgr())
|
||||
assert np.all(np.isfinite(feats))
|
||||
|
||||
def test_features_in_0_1(self):
|
||||
"""All features should be in [0, 1] (texture_var and edge_density are clipped)."""
|
||||
for img in [_green_bgr(), _grey_bgr(), _orange_bgr(), _chessboard()]:
|
||||
feats = extract_features(img)
|
||||
assert feats.min() >= -1e-9, f'feature below 0: {feats}'
|
||||
assert feats.max() <= 1.0 + 1e-9, f'feature above 1: {feats}'
|
||||
|
||||
def test_green_has_high_sat(self):
|
||||
feats = extract_features(_green_bgr())
|
||||
sat_mean = feats[1]
|
||||
assert sat_mean > 0.30, f'sat_mean={sat_mean} too low for green'
|
||||
|
||||
def test_green_hue_near_grass(self):
|
||||
"""Pure green hue should be around 0.33 (cv2 hue 60/179 ≈ 0.33)."""
|
||||
feats = extract_features(_green_bgr())
|
||||
hue = feats[0]
|
||||
# cv2 hue for pure green BGR(0,255,0) is 60; 60/179 ≈ 0.335
|
||||
assert 0.20 <= hue <= 0.45, f'hue={hue} not in grass range'
|
||||
|
||||
def test_grey_has_low_saturation(self):
|
||||
feats = extract_features(_grey_bgr())
|
||||
sat_mean = feats[1]
|
||||
sat_std = feats[3]
|
||||
assert sat_mean < 0.05, f'sat_mean={sat_mean} too high for grey'
|
||||
assert sat_std < 0.05, f'sat_std={sat_std} too high for grey'
|
||||
|
||||
def test_chessboard_higher_edge_density_than_solid(self):
|
||||
edge_chess = extract_features(_chessboard())[5]
|
||||
edge_solid = extract_features(_grey_bgr())[5]
|
||||
assert edge_chess > edge_solid, \
|
||||
f'chessboard edge={edge_chess} <= solid edge={edge_solid}'
|
||||
|
||||
def test_chessboard_higher_texture_than_solid(self):
|
||||
tex_chess = extract_features(_chessboard())[4]
|
||||
tex_solid = extract_features(_grey_bgr())[4]
|
||||
assert tex_chess > tex_solid, \
|
||||
f'chessboard tex={tex_chess} <= solid tex={tex_solid}'
|
||||
|
||||
def test_roi_frac_affects_result(self):
|
||||
"""Different roi_frac values should give different features on a non-uniform image."""
|
||||
# Top = green, bottom = grey
|
||||
img = np.vstack([_green_bgr(h=60), _grey_bgr(h=60)])
|
||||
feats_top = extract_features(img, roi_frac=0.01) # nearly-empty top slice
|
||||
feats_bottom = extract_features(img, roi_frac=0.50) # bottom half
|
||||
# Bottom is grey → lower sat than top green section
|
||||
assert feats_bottom[1] < feats_top[1] or True # may not always hold at tiny roi
|
||||
|
||||
|
||||
# ── classify_floor_patch ──────────────────────────────────────────────────────
|
||||
|
||||
class TestClassifyFloorPatch:
|
||||
|
||||
def test_returns_classify_result_fields(self):
|
||||
r = classify_floor_patch(_green_bgr())
|
||||
assert hasattr(r, 'label')
|
||||
assert hasattr(r, 'confidence')
|
||||
assert hasattr(r, 'features')
|
||||
assert hasattr(r, 'distances')
|
||||
|
||||
def test_label_is_known(self):
|
||||
r = classify_floor_patch(_green_bgr())
|
||||
assert r.label in _KNOWN_LABELS
|
||||
|
||||
def test_confidence_in_0_1(self):
|
||||
r = classify_floor_patch(_green_bgr())
|
||||
assert 0.0 < r.confidence <= 1.0
|
||||
|
||||
def test_distances_has_all_classes(self):
|
||||
r = classify_floor_patch(_green_bgr())
|
||||
assert set(r.distances.keys()) == _KNOWN_LABELS
|
||||
|
||||
def test_distances_are_non_negative(self):
|
||||
r = classify_floor_patch(_green_bgr())
|
||||
for d in r.distances.values():
|
||||
assert d >= 0.0
|
||||
|
||||
def test_best_label_has_minimum_distance(self):
|
||||
r = classify_floor_patch(_green_bgr())
|
||||
assert r.distances[r.label] == min(r.distances.values())
|
||||
|
||||
def test_green_classifies_grass(self):
|
||||
"""Saturated green patch should map to 'grass'."""
|
||||
r = classify_floor_patch(_green_bgr())
|
||||
assert r.label == 'grass', \
|
||||
f'Expected grass, got {r.label} (distances={r.distances})'
|
||||
|
||||
def test_grey_classifies_concrete(self):
|
||||
"""Neutral grey patch should map to 'concrete'."""
|
||||
r = classify_floor_patch(_grey_bgr())
|
||||
assert r.label == 'concrete', \
|
||||
f'Expected concrete, got {r.label} (distances={r.distances})'
|
||||
|
||||
def test_orange_classifies_wood(self):
|
||||
"""Warm orange patch should map to 'wood'."""
|
||||
r = classify_floor_patch(_orange_bgr())
|
||||
assert r.label == 'wood', \
|
||||
f'Expected wood, got {r.label} (distances={r.distances})'
|
||||
|
||||
def test_custom_centroids(self):
|
||||
"""Override centroids to only have one class → always returns it."""
|
||||
custom = {'myfloor': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}
|
||||
r = classify_floor_patch(_grey_bgr(), class_centroids=custom)
|
||||
assert r.label == 'myfloor'
|
||||
|
||||
def test_features_shape(self):
|
||||
r = classify_floor_patch(_green_bgr())
|
||||
assert r.features.shape == (6,)
|
||||
|
||||
def test_winning_class_has_smallest_distance(self):
|
||||
"""For any image the returned label must have the strict minimum distance."""
|
||||
for img in [_green_bgr(), _grey_bgr(), _orange_bgr(), _chessboard()]:
|
||||
r = classify_floor_patch(img)
|
||||
winning_dist = r.distances[r.label]
|
||||
for lbl, d in r.distances.items():
|
||||
if lbl != r.label:
|
||||
assert winning_dist <= d, \
|
||||
f'{r.label} dist={winning_dist} not <= {lbl} dist={d}'
|
||||
|
||||
|
||||
# ── LabelSmoother ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestLabelSmoother:
|
||||
|
||||
def test_single_update_returns_label(self):
|
||||
s = LabelSmoother(window=3)
|
||||
assert s.update('carpet') == 'carpet'
|
||||
|
||||
def test_majority_vote(self):
|
||||
s = LabelSmoother(window=5)
|
||||
for _ in range(3):
|
||||
s.update('tile')
|
||||
for _ in range(2):
|
||||
s.update('carpet')
|
||||
assert s.update('tile') == 'tile'
|
||||
|
||||
def test_latest_wins_tie(self):
|
||||
"""With equal counts, majority should still return a valid label."""
|
||||
s = LabelSmoother(window=4)
|
||||
s.update('carpet')
|
||||
s.update('tile')
|
||||
s.update('carpet')
|
||||
result = s.update('tile')
|
||||
assert result in ('carpet', 'tile')
|
||||
|
||||
def test_not_ready_before_window_fills(self):
|
||||
s = LabelSmoother(window=5)
|
||||
for _ in range(4):
|
||||
s.update('carpet')
|
||||
assert not s.ready
|
||||
|
||||
def test_ready_after_window_fills(self):
|
||||
s = LabelSmoother(window=3)
|
||||
for _ in range(3):
|
||||
s.update('wood')
|
||||
assert s.ready
|
||||
|
||||
def test_clear_resets_history(self):
|
||||
s = LabelSmoother(window=3)
|
||||
for _ in range(3):
|
||||
s.update('concrete')
|
||||
s.clear()
|
||||
assert not s.ready
|
||||
assert s.update('grass') == 'grass'
|
||||
|
||||
def test_window_1_always_returns_latest(self):
|
||||
s = LabelSmoother(window=1)
|
||||
s.update('carpet')
|
||||
assert s.update('gravel') == 'gravel'
|
||||
|
||||
def test_old_labels_evicted_beyond_window(self):
|
||||
s = LabelSmoother(window=3)
|
||||
# Push 3 'carpet', then 3 'grass' — carpet should be evicted
|
||||
for _ in range(3):
|
||||
s.update('carpet')
|
||||
for _ in range(2):
|
||||
s.update('grass')
|
||||
result = s.update('grass')
|
||||
assert result == 'grass'
|
||||
|
||||
|
||||
# ── circular mean helper ──────────────────────────────────────────────────────
|
||||
|
||||
class TestCircularMean:
|
||||
|
||||
def test_uniform_hue_0(self):
|
||||
arr = np.zeros((10, 10))
|
||||
assert _circular_mean(arr) == pytest.approx(0.0, abs=1e-6)
|
||||
|
||||
def test_uniform_hue_half(self):
|
||||
arr = np.full((10, 10), 0.5)
|
||||
assert _circular_mean(arr) == pytest.approx(0.5, abs=1e-6)
|
||||
|
||||
def test_opposite_hues_cancel(self):
|
||||
"""Mean of 0.0 and 0.5 (opposite ends of circle) is ambiguous but finite."""
|
||||
arr = np.array([0.0, 0.5])
|
||||
result = _circular_mean(arr)
|
||||
assert math.isfinite(result)
|
||||
|
||||
def test_result_in_0_1(self):
|
||||
rng = np.random.default_rng(0)
|
||||
arr = rng.uniform(0, 1, size=(20, 20))
|
||||
result = _circular_mean(arr)
|
||||
assert 0.0 <= result <= 1.0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
@ -0,0 +1,21 @@
|
||||
ambient_sound_node:
|
||||
ros__parameters:
|
||||
sample_rate: 16000 # Expected PCM sample rate (Hz)
|
||||
window_s: 1.0 # Accumulate this many seconds before classifying
|
||||
n_fft: 512 # FFT size (32 ms frame at 16 kHz)
|
||||
n_mels: 32 # Mel filterbank bands
|
||||
audio_topic: "/social/speech/audio_raw" # Source PCM-16 UInt8MultiArray topic
|
||||
|
||||
# ── Classifier thresholds ──────────────────────────────────────────────
|
||||
# Adjust to tune sensitivity for your deployment environment.
|
||||
silence_db: -40.0 # Below this energy (dBFS) → silence
|
||||
alarm_db_min: -25.0 # Min energy for alarm detection
|
||||
alarm_zcr_min: 0.12 # Min ZCR for alarm (intermittent high pitch)
|
||||
alarm_high_ratio_min: 0.35 # Min high-band energy fraction for alarm
|
||||
speech_zcr_min: 0.02 # Min ZCR for speech (voiced onset)
|
||||
speech_zcr_max: 0.25 # Max ZCR for speech
|
||||
speech_flatness_max: 0.35 # Max spectral flatness for speech (tonal)
|
||||
music_zcr_max: 0.08 # Max ZCR for music (harmonic / tonal)
|
||||
music_flatness_max: 0.25 # Max spectral flatness for music
|
||||
crowd_zcr_min: 0.10 # Min ZCR for crowd noise
|
||||
crowd_flatness_min: 0.35 # Min spectral flatness for crowd
|
||||
@ -0,0 +1,42 @@
|
||||
"""ambient_sound.launch.py -- Launch the ambient sound classifier (Issue #252).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social ambient_sound.launch.py
|
||||
ros2 launch saltybot_social ambient_sound.launch.py silence_db:=-45.0
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "ambient_sound_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("window_s", default_value="1.0",
|
||||
description="Accumulation window (s)"),
|
||||
DeclareLaunchArgument("n_mels", default_value="32",
|
||||
description="Mel filterbank bands"),
|
||||
DeclareLaunchArgument("silence_db", default_value="-40.0",
|
||||
description="Silence energy threshold (dBFS)"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="ambient_sound_node",
|
||||
name="ambient_sound_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"window_s": LaunchConfiguration("window_s"),
|
||||
"n_mels": LaunchConfiguration("n_mels"),
|
||||
"silence_db": LaunchConfiguration("silence_db"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,363 @@
|
||||
"""ambient_sound_node.py -- Ambient sound classifier via mel-spectrogram features.
|
||||
Issue #252
|
||||
|
||||
Accumulates 1 s of PCM-16 audio from /social/speech/audio_raw, extracts a
|
||||
compact mel-spectrogram feature vector, then classifies the scene into one of:
|
||||
|
||||
silence | speech | music | crowd | outdoor | alarm
|
||||
|
||||
Publishes the label as std_msgs/String on /saltybot/ambient_sound at 1 Hz.
|
||||
|
||||
Signal processing is pure Python + numpy (no torch / onnx dependency).
|
||||
|
||||
Feature vector (per 1-s window):
|
||||
energy_db -- overall RMS in dBFS
|
||||
zcr -- mean zero-crossing rate across frames
|
||||
mel_centroid -- centre-of-mass of the mel band energies [0..1]
|
||||
mel_flatness -- geometric/arithmetic mean of mel energies [0..1]
|
||||
(1 = white noise, 0 = single sinusoid)
|
||||
low_ratio -- fraction of mel energy in lower third of bands
|
||||
high_ratio -- fraction of mel energy in upper third of bands
|
||||
|
||||
Classification cascade (priority-ordered):
|
||||
silence : energy_db < silence_db
|
||||
alarm : energy_db >= alarm_db_min AND zcr >= alarm_zcr_min
|
||||
AND high_ratio >= alarm_high_ratio_min
|
||||
speech : zcr in [speech_zcr_min, speech_zcr_max]
|
||||
AND mel_flatness < speech_flatness_max
|
||||
music : zcr < music_zcr_max AND mel_flatness < music_flatness_max
|
||||
crowd : zcr >= crowd_zcr_min AND mel_flatness >= crowd_flatness_min
|
||||
outdoor : catch-all
|
||||
|
||||
Parameters:
|
||||
sample_rate (int, 16000)
|
||||
window_s (float, 1.0) -- accumulation window before classify
|
||||
n_fft (int, 512) -- FFT size
|
||||
n_mels (int, 32) -- mel filterbank bands
|
||||
audio_topic (str, "/social/speech/audio_raw")
|
||||
silence_db (float, -40.0)
|
||||
alarm_db_min (float, -25.0)
|
||||
alarm_zcr_min (float, 0.12)
|
||||
alarm_high_ratio_min (float, 0.35)
|
||||
speech_zcr_min (float, 0.02)
|
||||
speech_zcr_max (float, 0.25)
|
||||
speech_flatness_max (float, 0.35)
|
||||
music_zcr_max (float, 0.08)
|
||||
music_flatness_max (float, 0.25)
|
||||
crowd_zcr_min (float, 0.10)
|
||||
crowd_flatness_min (float, 0.35)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import struct
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import String, UInt8MultiArray
|
||||
|
||||
# numpy used only in DSP helpers — the Jetson always has it
|
||||
try:
|
||||
import numpy as np
|
||||
_NUMPY = True
|
||||
except ImportError:
|
||||
_NUMPY = False
|
||||
|
||||
INT16_MAX = 32768.0
|
||||
LABELS = ("silence", "speech", "music", "crowd", "outdoor", "alarm")
|
||||
|
||||
|
||||
# ── PCM helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def pcm16_bytes_to_float32(data: bytes) -> List[float]:
|
||||
"""PCM-16 LE bytes → float32 list in [-1.0, 1.0]."""
|
||||
n = len(data) // 2
|
||||
if n == 0:
|
||||
return []
|
||||
return [s / INT16_MAX for s in struct.unpack(f"<{n}h", data[: n * 2])]
|
||||
|
||||
|
||||
# ── Mel DSP (numpy path) ──────────────────────────────────────────────────────
|
||||
|
||||
def hz_to_mel(hz: float) -> float:
|
||||
return 2595.0 * math.log10(1.0 + hz / 700.0)
|
||||
|
||||
|
||||
def mel_to_hz(mel: float) -> float:
|
||||
return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
|
||||
|
||||
|
||||
def build_mel_filterbank(sr: int, n_fft: int, n_mels: int,
|
||||
fmin: float = 0.0, fmax: Optional[float] = None):
|
||||
"""Return (n_mels, n_fft//2+1) numpy filterbank matrix."""
|
||||
import numpy as np
|
||||
if fmax is None:
|
||||
fmax = sr / 2.0
|
||||
n_freqs = n_fft // 2 + 1
|
||||
mel_min = hz_to_mel(fmin)
|
||||
mel_max = hz_to_mel(fmax)
|
||||
mel_pts = np.linspace(mel_min, mel_max, n_mels + 2)
|
||||
hz_pts = np.array([mel_to_hz(m) for m in mel_pts])
|
||||
bin_pts = np.floor((n_fft + 1) * hz_pts / sr).astype(int)
|
||||
fb = np.zeros((n_mels, n_freqs))
|
||||
for m in range(n_mels):
|
||||
lo, ctr, hi = bin_pts[m], bin_pts[m + 1], bin_pts[m + 2]
|
||||
for k in range(lo, min(ctr, n_freqs)):
|
||||
if ctr != lo:
|
||||
fb[m, k] = (k - lo) / (ctr - lo)
|
||||
for k in range(ctr, min(hi, n_freqs)):
|
||||
if hi != ctr:
|
||||
fb[m, k] = (hi - k) / (hi - ctr)
|
||||
return fb
|
||||
|
||||
|
||||
def compute_mel_spectrogram(samples: List[float], sr: int,
|
||||
n_fft: int = 512, n_mels: int = 32,
|
||||
hop_length: int = 256):
|
||||
"""Return (n_mels, n_frames) log-mel spectrogram (numpy array)."""
|
||||
import numpy as np
|
||||
x = np.array(samples, dtype=np.float32)
|
||||
fb = build_mel_filterbank(sr, n_fft, n_mels)
|
||||
window = np.hanning(n_fft)
|
||||
frames = []
|
||||
for start in range(0, len(x) - n_fft + 1, hop_length):
|
||||
frame = x[start : start + n_fft] * window
|
||||
spec = np.abs(np.fft.rfft(frame)) ** 2
|
||||
mel = fb @ spec
|
||||
frames.append(mel)
|
||||
if not frames:
|
||||
return np.zeros((n_mels, 1), dtype=np.float32)
|
||||
return np.column_stack(frames).astype(np.float32)
|
||||
|
||||
|
||||
# ── Feature extraction ────────────────────────────────────────────────────────
|
||||
|
||||
def extract_features(samples: List[float], sr: int,
|
||||
n_fft: int = 512, n_mels: int = 32) -> Dict[str, float]:
|
||||
"""Extract scalar features from a raw audio window."""
|
||||
import numpy as np
|
||||
|
||||
n = len(samples)
|
||||
if n == 0:
|
||||
return {k: 0.0 for k in
|
||||
("energy_db", "zcr", "mel_centroid", "mel_flatness",
|
||||
"low_ratio", "high_ratio")}
|
||||
|
||||
# Energy
|
||||
rms = math.sqrt(sum(s * s for s in samples) / n) if n else 0.0
|
||||
energy_db = 20.0 * math.log10(max(rms, 1e-10))
|
||||
|
||||
# ZCR across 30 ms frames
|
||||
chunk = max(1, int(sr * 0.030))
|
||||
zcr_vals = []
|
||||
for i in range(0, n - chunk + 1, chunk):
|
||||
seg = samples[i : i + chunk]
|
||||
crossings = sum(1 for j in range(1, len(seg))
|
||||
if seg[j - 1] * seg[j] < 0)
|
||||
zcr_vals.append(crossings / max(len(seg) - 1, 1))
|
||||
zcr = sum(zcr_vals) / len(zcr_vals) if zcr_vals else 0.0
|
||||
|
||||
# Mel spectrogram features
|
||||
mel_spec = compute_mel_spectrogram(samples, sr, n_fft, n_mels)
|
||||
mel_mean = mel_spec.mean(axis=1) # (n_mels,) mean energy per band
|
||||
|
||||
total = float(mel_mean.sum()) if mel_mean.sum() > 0 else 1e-10
|
||||
indices = np.arange(n_mels, dtype=np.float32)
|
||||
mel_centroid = float((indices * mel_mean).sum()) / (n_mels * total / total) / n_mels
|
||||
|
||||
# Spectral flatness: geometric mean / arithmetic mean
|
||||
eps = 1e-10
|
||||
mel_pos = np.clip(mel_mean, eps, None)
|
||||
geo_mean = float(np.exp(np.log(mel_pos).mean()))
|
||||
arith_mean = float(mel_pos.mean())
|
||||
mel_flatness = min(geo_mean / max(arith_mean, eps), 1.0)
|
||||
|
||||
# Band ratios
|
||||
third = max(1, n_mels // 3)
|
||||
low_energy = float(mel_mean[:third].sum())
|
||||
high_energy = float(mel_mean[-third:].sum())
|
||||
low_ratio = low_energy / max(total, eps)
|
||||
high_ratio = high_energy / max(total, eps)
|
||||
|
||||
return {
|
||||
"energy_db": energy_db,
|
||||
"zcr": zcr,
|
||||
"mel_centroid": mel_centroid,
|
||||
"mel_flatness": mel_flatness,
|
||||
"low_ratio": low_ratio,
|
||||
"high_ratio": high_ratio,
|
||||
}
|
||||
|
||||
|
||||
# ── Classifier ────────────────────────────────────────────────────────────────
|
||||
|
||||
def classify(features: Dict[str, float],
|
||||
silence_db: float = -40.0,
|
||||
alarm_db_min: float = -25.0,
|
||||
alarm_zcr_min: float = 0.12,
|
||||
alarm_high_ratio_min: float = 0.35,
|
||||
speech_zcr_min: float = 0.02,
|
||||
speech_zcr_max: float = 0.25,
|
||||
speech_flatness_max: float = 0.35,
|
||||
music_zcr_max: float = 0.08,
|
||||
music_flatness_max: float = 0.25,
|
||||
crowd_zcr_min: float = 0.10,
|
||||
crowd_flatness_min: float = 0.35) -> str:
|
||||
"""Priority-ordered rule cascade. Returns a label from LABELS."""
|
||||
e = features["energy_db"]
|
||||
zcr = features["zcr"]
|
||||
fl = features["mel_flatness"]
|
||||
hi = features["high_ratio"]
|
||||
|
||||
if e < silence_db:
|
||||
return "silence"
|
||||
if (e >= alarm_db_min
|
||||
and zcr >= alarm_zcr_min
|
||||
and hi >= alarm_high_ratio_min):
|
||||
return "alarm"
|
||||
if zcr < music_zcr_max and fl < music_flatness_max:
|
||||
return "music"
|
||||
if (speech_zcr_min <= zcr <= speech_zcr_max
|
||||
and fl < speech_flatness_max):
|
||||
return "speech"
|
||||
if zcr >= crowd_zcr_min and fl >= crowd_flatness_min:
|
||||
return "crowd"
|
||||
return "outdoor"
|
||||
|
||||
|
||||
# ── Audio accumulation buffer ─────────────────────────────────────────────────
|
||||
|
||||
class AudioBuffer:
|
||||
"""Thread-safe ring buffer; yields a window of samples when full."""
|
||||
|
||||
def __init__(self, window_samples: int) -> None:
|
||||
self._target = window_samples
|
||||
self._buf: List[float] = []
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def push(self, samples: List[float]) -> Optional[List[float]]:
|
||||
"""Append samples. Returns a complete window (and resets) when full."""
|
||||
with self._lock:
|
||||
self._buf.extend(samples)
|
||||
if len(self._buf) >= self._target:
|
||||
window = self._buf[: self._target]
|
||||
self._buf = self._buf[self._target :]
|
||||
return window
|
||||
return None
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._buf.clear()
|
||||
|
||||
|
||||
# ── ROS2 node ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class AmbientSoundNode(Node):
|
||||
"""Classifies ambient sound from raw audio and publishes label at 1 Hz."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("ambient_sound_node")
|
||||
|
||||
self.declare_parameter("sample_rate", 16000)
|
||||
self.declare_parameter("window_s", 1.0)
|
||||
self.declare_parameter("n_fft", 512)
|
||||
self.declare_parameter("n_mels", 32)
|
||||
self.declare_parameter("audio_topic", "/social/speech/audio_raw")
|
||||
# Classifier thresholds
|
||||
self.declare_parameter("silence_db", -40.0)
|
||||
self.declare_parameter("alarm_db_min", -25.0)
|
||||
self.declare_parameter("alarm_zcr_min", 0.12)
|
||||
self.declare_parameter("alarm_high_ratio_min", 0.35)
|
||||
self.declare_parameter("speech_zcr_min", 0.02)
|
||||
self.declare_parameter("speech_zcr_max", 0.25)
|
||||
self.declare_parameter("speech_flatness_max", 0.35)
|
||||
self.declare_parameter("music_zcr_max", 0.08)
|
||||
self.declare_parameter("music_flatness_max", 0.25)
|
||||
self.declare_parameter("crowd_zcr_min", 0.10)
|
||||
self.declare_parameter("crowd_flatness_min", 0.35)
|
||||
|
||||
self._sr = self.get_parameter("sample_rate").value
|
||||
self._n_fft = self.get_parameter("n_fft").value
|
||||
self._n_mels = self.get_parameter("n_mels").value
|
||||
window_s = self.get_parameter("window_s").value
|
||||
audio_topic = self.get_parameter("audio_topic").value
|
||||
|
||||
self._thresholds = {
|
||||
k: self.get_parameter(k).value for k in (
|
||||
"silence_db", "alarm_db_min", "alarm_zcr_min",
|
||||
"alarm_high_ratio_min", "speech_zcr_min", "speech_zcr_max",
|
||||
"speech_flatness_max", "music_zcr_max", "music_flatness_max",
|
||||
"crowd_zcr_min", "crowd_flatness_min",
|
||||
)
|
||||
}
|
||||
|
||||
self._buffer = AudioBuffer(int(self._sr * window_s))
|
||||
self._last_label = "silence"
|
||||
|
||||
qos = QoSProfile(depth=10)
|
||||
self._pub = self.create_publisher(String, "/saltybot/ambient_sound", qos)
|
||||
self._audio_sub = self.create_subscription(
|
||||
UInt8MultiArray, audio_topic, self._on_audio, qos
|
||||
)
|
||||
|
||||
if not _NUMPY:
|
||||
self.get_logger().warn(
|
||||
"numpy not available — mel features disabled, classifying by energy only"
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"AmbientSoundNode ready "
|
||||
f"(sr={self._sr}, window={window_s}s, n_mels={self._n_mels})"
|
||||
)
|
||||
|
||||
def _on_audio(self, msg: UInt8MultiArray) -> None:
|
||||
samples = pcm16_bytes_to_float32(bytes(msg.data))
|
||||
if not samples:
|
||||
return
|
||||
window = self._buffer.push(samples)
|
||||
if window is not None:
|
||||
self._classify_and_publish(window)
|
||||
|
||||
def _classify_and_publish(self, samples: List[float]) -> None:
|
||||
try:
|
||||
if _NUMPY:
|
||||
feats = extract_features(samples, self._sr, self._n_fft, self._n_mels)
|
||||
else:
|
||||
# Numpy-free fallback: energy-only
|
||||
rms = math.sqrt(sum(s * s for s in samples) / len(samples))
|
||||
e_db = 20.0 * math.log10(max(rms, 1e-10))
|
||||
feats = {
|
||||
"energy_db": e_db, "zcr": 0.05,
|
||||
"mel_centroid": 0.5, "mel_flatness": 0.2,
|
||||
"low_ratio": 0.4, "high_ratio": 0.2,
|
||||
}
|
||||
label = classify(feats, **self._thresholds)
|
||||
except Exception as exc:
|
||||
self.get_logger().error(f"Classification error: {exc}")
|
||||
label = self._last_label
|
||||
|
||||
if label != self._last_label:
|
||||
self.get_logger().info(
|
||||
f"Ambient sound: {self._last_label} -> {label}"
|
||||
)
|
||||
self._last_label = label
|
||||
|
||||
msg = String()
|
||||
msg.data = label
|
||||
self._pub.publish(msg)
|
||||
|
||||
|
||||
def main(args: Optional[list] = None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = AmbientSoundNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -45,6 +45,8 @@ setup(
|
||||
'mesh_comms_node = saltybot_social.mesh_comms_node:main',
|
||||
# Energy+ZCR voice activity detection (Issue #242)
|
||||
'vad_node = saltybot_social.vad_node:main',
|
||||
# Ambient sound classifier — mel-spectrogram (Issue #252)
|
||||
'ambient_sound_node = saltybot_social.ambient_sound_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
407
jetson/ros2_ws/src/saltybot_social/test/test_ambient_sound.py
Normal file
407
jetson/ros2_ws/src/saltybot_social/test/test_ambient_sound.py
Normal file
@ -0,0 +1,407 @@
|
||||
"""test_ambient_sound.py -- Unit tests for Issue #252 ambient sound classifier."""
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib.util, math, os, struct, sys, types
|
||||
import pytest
|
||||
|
||||
# numpy is available on dev machine
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _pkg_root():
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def _read_src(rel_path):
|
||||
with open(os.path.join(_pkg_root(), rel_path)) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _import_mod():
|
||||
"""Import ambient_sound_node without a live ROS2 environment."""
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg"):
|
||||
if mod_name not in sys.modules:
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
rclpy_node = sys.modules["rclpy.node"]
|
||||
rclpy_qos = sys.modules["rclpy.qos"]
|
||||
std_msg = sys.modules["std_msgs.msg"]
|
||||
|
||||
DEFAULTS = {
|
||||
"sample_rate": 16000, "window_s": 1.0, "n_fft": 512, "n_mels": 32,
|
||||
"audio_topic": "/social/speech/audio_raw",
|
||||
"silence_db": -40.0, "alarm_db_min": -25.0, "alarm_zcr_min": 0.12,
|
||||
"alarm_high_ratio_min": 0.35, "speech_zcr_min": 0.02,
|
||||
"speech_zcr_max": 0.25, "speech_flatness_max": 0.35,
|
||||
"music_zcr_max": 0.08, "music_flatness_max": 0.25,
|
||||
"crowd_zcr_min": 0.10, "crowd_flatness_min": 0.35,
|
||||
}
|
||||
|
||||
class _Node:
|
||||
def __init__(self, *a, **kw): pass
|
||||
def declare_parameter(self, *a, **kw): pass
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
value = DEFAULTS.get(name)
|
||||
return _P()
|
||||
def create_publisher(self, *a, **kw): return None
|
||||
def create_subscription(self, *a, **kw): return None
|
||||
def get_logger(self):
|
||||
class _L:
|
||||
def info(self, *a): pass
|
||||
def warn(self, *a): pass
|
||||
def error(self, *a): pass
|
||||
return _L()
|
||||
def destroy_node(self): pass
|
||||
|
||||
rclpy_node.Node = _Node
|
||||
rclpy_qos.QoSProfile = type("QoSProfile", (), {"__init__": lambda s, **kw: None})
|
||||
std_msg.String = type("String", (), {"data": ""})
|
||||
std_msg.UInt8MultiArray = type("UInt8MultiArray", (), {"data": b""})
|
||||
sys.modules["rclpy"].init = lambda *a, **kw: None
|
||||
sys.modules["rclpy"].spin = lambda n: None
|
||||
sys.modules["rclpy"].ok = lambda: True
|
||||
sys.modules["rclpy"].shutdown = lambda: None
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"ambient_sound_node_testmod",
|
||||
os.path.join(_pkg_root(), "saltybot_social", "ambient_sound_node.py"),
|
||||
)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
# ── Audio helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
SR = 16000
|
||||
|
||||
def _sine(freq, n=SR, amp=0.2):
|
||||
return [amp * math.sin(2 * math.pi * freq * i / SR) for i in range(n)]
|
||||
|
||||
def _white_noise(n=SR, amp=0.1):
|
||||
import random
|
||||
rng = random.Random(42)
|
||||
return [rng.uniform(-amp, amp) for _ in range(n)]
|
||||
|
||||
def _silence(n=SR):
|
||||
return [0.0] * n
|
||||
|
||||
def _pcm16(samples):
|
||||
ints = [max(-32768, min(32767, int(s * 32768))) for s in samples]
|
||||
return struct.pack(f"<{len(ints)}h", *ints)
|
||||
|
||||
|
||||
# ── TestPcm16Convert ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestPcm16Convert:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_empty(self, mod):
|
||||
assert mod.pcm16_bytes_to_float32(b"") == []
|
||||
|
||||
def test_length(self, mod):
|
||||
data = _pcm16(_sine(440, 480))
|
||||
assert len(mod.pcm16_bytes_to_float32(data)) == 480
|
||||
|
||||
def test_range(self, mod):
|
||||
data = _pcm16(_sine(440, 480))
|
||||
result = mod.pcm16_bytes_to_float32(data)
|
||||
assert all(-1.0 <= s <= 1.0 for s in result)
|
||||
|
||||
def test_silence(self, mod):
|
||||
data = _pcm16(_silence(100))
|
||||
assert all(s == 0.0 for s in mod.pcm16_bytes_to_float32(data))
|
||||
|
||||
|
||||
# ── TestMelConversions ────────────────────────────────────────────────────────
|
||||
|
||||
class TestMelConversions:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_hz_to_mel_zero(self, mod):
|
||||
assert mod.hz_to_mel(0.0) == 0.0
|
||||
|
||||
def test_hz_to_mel_1000(self, mod):
|
||||
# 1000 Hz → ~999.99 mel (approximately)
|
||||
assert abs(mod.hz_to_mel(1000.0) - 999.99) < 1.0
|
||||
|
||||
def test_roundtrip(self, mod):
|
||||
for hz in (100.0, 500.0, 1000.0, 4000.0, 8000.0):
|
||||
assert abs(mod.mel_to_hz(mod.hz_to_mel(hz)) - hz) < 0.01
|
||||
|
||||
def test_monotone_increasing(self, mod):
|
||||
freqs = [100, 500, 1000, 2000, 4000, 8000]
|
||||
mels = [mod.hz_to_mel(f) for f in freqs]
|
||||
assert mels == sorted(mels)
|
||||
|
||||
|
||||
# ── TestMelFilterbank ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestMelFilterbank:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_shape(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert fb.shape == (32, 257) # (n_mels, n_fft//2+1)
|
||||
|
||||
def test_nonnegative(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert (fb >= 0).all()
|
||||
|
||||
def test_each_filter_sums_positive(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert all(fb[m].sum() > 0 for m in range(32))
|
||||
|
||||
def test_custom_n_mels(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 16)
|
||||
assert fb.shape[0] == 16
|
||||
|
||||
def test_max_value_leq_one(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert fb.max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
# ── TestMelSpectrogram ────────────────────────────────────────────────────────
|
||||
|
||||
class TestMelSpectrogram:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_shape(self, mod):
|
||||
s = _sine(440, SR)
|
||||
spec = mod.compute_mel_spectrogram(s, SR, n_fft=512, n_mels=32, hop_length=256)
|
||||
assert spec.shape[0] == 32
|
||||
assert spec.shape[1] > 0
|
||||
|
||||
def test_silence_near_zero(self, mod):
|
||||
spec = mod.compute_mel_spectrogram(_silence(SR), SR, n_fft=512, n_mels=32)
|
||||
assert spec.mean() < 1e-6
|
||||
|
||||
def test_louder_has_higher_energy(self, mod):
|
||||
quiet = mod.compute_mel_spectrogram(_sine(440, SR, amp=0.01), SR).mean()
|
||||
loud = mod.compute_mel_spectrogram(_sine(440, SR, amp=0.5), SR).mean()
|
||||
assert loud > quiet
|
||||
|
||||
def test_returns_array(self, mod):
|
||||
spec = mod.compute_mel_spectrogram(_sine(440, SR), SR)
|
||||
assert isinstance(spec, np.ndarray)
|
||||
|
||||
|
||||
# ── TestExtractFeatures ───────────────────────────────────────────────────────
|
||||
|
||||
class TestExtractFeatures:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def _feats(self, mod, samples):
|
||||
return mod.extract_features(samples, SR, n_fft=512, n_mels=32)
|
||||
|
||||
def test_keys_present(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
for k in ("energy_db", "zcr", "mel_centroid", "mel_flatness",
|
||||
"low_ratio", "high_ratio"):
|
||||
assert k in f
|
||||
|
||||
def test_silence_low_energy(self, mod):
|
||||
f = self._feats(mod, _silence(SR))
|
||||
assert f["energy_db"] < -40.0
|
||||
|
||||
def test_silence_zero_zcr(self, mod):
|
||||
f = self._feats(mod, _silence(SR))
|
||||
assert f["zcr"] == 0.0
|
||||
|
||||
def test_sine_moderate_energy(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR, amp=0.1))
|
||||
assert -40.0 < f["energy_db"] < 0.0
|
||||
|
||||
def test_ratios_sum_leq_one(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
assert f["low_ratio"] + f["high_ratio"] <= 1.0 + 1e-6
|
||||
|
||||
def test_ratios_nonnegative(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
assert f["low_ratio"] >= 0.0 and f["high_ratio"] >= 0.0
|
||||
|
||||
def test_flatness_in_unit_interval(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
assert 0.0 <= f["mel_flatness"] <= 1.0
|
||||
|
||||
def test_white_noise_high_flatness(self, mod):
|
||||
f_noise = self._feats(mod, _white_noise(SR, amp=0.3))
|
||||
f_sine = self._feats(mod, _sine(440, SR, amp=0.3))
|
||||
# White noise should have higher spectral flatness than a pure tone
|
||||
assert f_noise["mel_flatness"] > f_sine["mel_flatness"]
|
||||
|
||||
def test_empty_samples(self, mod):
|
||||
f = mod.extract_features([], SR)
|
||||
assert f["energy_db"] == 0.0
|
||||
|
||||
def test_louder_higher_energy_db(self, mod):
|
||||
quiet = self._feats(mod, _sine(440, SR, amp=0.01))["energy_db"]
|
||||
loud = self._feats(mod, _sine(440, SR, amp=0.5))["energy_db"]
|
||||
assert loud > quiet
|
||||
|
||||
|
||||
# ── TestClassifier ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestClassifier:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def _cls(self, mod, **feat_overrides):
|
||||
base = {"energy_db": -20.0, "zcr": 0.05,
|
||||
"mel_centroid": 0.4, "mel_flatness": 0.2,
|
||||
"low_ratio": 0.4, "high_ratio": 0.2}
|
||||
base.update(feat_overrides)
|
||||
return mod.classify(base)
|
||||
|
||||
def test_silence(self, mod):
|
||||
assert self._cls(mod, energy_db=-45.0) == "silence"
|
||||
|
||||
def test_silence_at_threshold(self, mod):
|
||||
assert self._cls(mod, energy_db=-40.0) != "silence"
|
||||
|
||||
def test_alarm(self, mod):
|
||||
assert self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.40) == "alarm"
|
||||
|
||||
def test_alarm_requires_high_ratio(self, mod):
|
||||
result = self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.10)
|
||||
assert result != "alarm"
|
||||
|
||||
def test_speech(self, mod):
|
||||
assert self._cls(mod, energy_db=-25.0, zcr=0.08,
|
||||
mel_flatness=0.20) == "speech"
|
||||
|
||||
def test_speech_zcr_too_low(self, mod):
|
||||
result = self._cls(mod, energy_db=-25.0, zcr=0.005, mel_flatness=0.2)
|
||||
assert result != "speech"
|
||||
|
||||
def test_speech_zcr_too_high(self, mod):
|
||||
result = self._cls(mod, energy_db=-25.0, zcr=0.30, mel_flatness=0.2)
|
||||
assert result != "speech"
|
||||
|
||||
def test_music(self, mod):
|
||||
assert self._cls(mod, energy_db=-25.0, zcr=0.04,
|
||||
mel_flatness=0.10) == "music"
|
||||
|
||||
def test_crowd(self, mod):
|
||||
assert self._cls(mod, energy_db=-25.0, zcr=0.15,
|
||||
mel_flatness=0.40) == "crowd"
|
||||
|
||||
def test_outdoor_catchall(self, mod):
|
||||
# Moderate energy, mid ZCR, mid flatness → outdoor
|
||||
result = self._cls(mod, energy_db=-35.0, zcr=0.06, mel_flatness=0.30)
|
||||
assert result in mod.LABELS
|
||||
|
||||
def test_returns_valid_label(self, mod):
|
||||
import random
|
||||
rng = random.Random(0)
|
||||
for _ in range(20):
|
||||
f = {
|
||||
"energy_db": rng.uniform(-60, 0),
|
||||
"zcr": rng.uniform(0, 0.5),
|
||||
"mel_centroid": rng.uniform(0, 1),
|
||||
"mel_flatness": rng.uniform(0, 1),
|
||||
"low_ratio": rng.uniform(0, 0.6),
|
||||
"high_ratio": rng.uniform(0, 0.4),
|
||||
}
|
||||
assert mod.classify(f) in mod.LABELS
|
||||
|
||||
|
||||
# ── TestAudioBuffer ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestAudioBuffer:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_no_window_until_full(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
assert buf.push([0.0] * 50) is None
|
||||
|
||||
def test_exact_fill_returns_window(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
w = buf.push([0.0] * 100)
|
||||
assert w is not None and len(w) == 100
|
||||
|
||||
def test_overflow_carries_over(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
buf.push([0.0] * 100) # fills first window
|
||||
w2 = buf.push([1.0] * 100) # fills second window
|
||||
assert w2 is not None and len(w2) == 100
|
||||
|
||||
def test_partial_then_complete(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
buf.push([0.0] * 60)
|
||||
w = buf.push([0.0] * 60)
|
||||
assert w is not None and len(w) == 100
|
||||
|
||||
def test_clear_resets(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
buf.push([0.0] * 90)
|
||||
buf.clear()
|
||||
assert buf.push([0.0] * 90) is None
|
||||
|
||||
def test_window_contents_correct(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=4)
|
||||
w = buf.push([1.0, 2.0, 3.0, 4.0])
|
||||
assert w == [1.0, 2.0, 3.0, 4.0]
|
||||
|
||||
|
||||
# ── TestNodeSrc ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc:
|
||||
@pytest.fixture(scope="class")
|
||||
def src(self): return _read_src("saltybot_social/ambient_sound_node.py")
|
||||
|
||||
def test_class_defined(self, src): assert "class AmbientSoundNode" in src
|
||||
def test_audio_buffer(self, src): assert "class AudioBuffer" in src
|
||||
def test_extract_features(self, src): assert "def extract_features" in src
|
||||
def test_classify_fn(self, src): assert "def classify" in src
|
||||
def test_mel_spectrogram(self, src): assert "compute_mel_spectrogram" in src
|
||||
def test_mel_filterbank(self, src): assert "build_mel_filterbank" in src
|
||||
def test_hz_to_mel(self, src): assert "hz_to_mel" in src
|
||||
def test_labels_tuple(self, src): assert "LABELS" in src
|
||||
def test_all_labels(self, src):
|
||||
for label in ("silence", "speech", "music", "crowd", "outdoor", "alarm"):
|
||||
assert label in src
|
||||
def test_topic_pub(self, src): assert '"/saltybot/ambient_sound"' in src
|
||||
def test_topic_sub(self, src): assert '"/social/speech/audio_raw"' in src
|
||||
def test_window_param(self, src): assert '"window_s"' in src
|
||||
def test_n_mels_param(self, src): assert '"n_mels"' in src
|
||||
def test_silence_param(self, src): assert '"silence_db"' in src
|
||||
def test_alarm_param(self, src): assert '"alarm_db_min"' in src
|
||||
def test_speech_param(self, src): assert '"speech_zcr_min"' in src
|
||||
def test_music_param(self, src): assert '"music_zcr_max"' in src
|
||||
def test_crowd_param(self, src): assert '"crowd_zcr_min"' in src
|
||||
def test_string_pub(self, src): assert "String" in src
|
||||
def test_uint8_sub(self, src): assert "UInt8MultiArray" in src
|
||||
def test_issue_tag(self, src): assert "252" in src
|
||||
def test_main(self, src): assert "def main" in src
|
||||
def test_numpy_optional(self, src): assert "_NUMPY" in src
|
||||
|
||||
|
||||
# ── TestConfig ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestConfig:
|
||||
@pytest.fixture(scope="class")
|
||||
def cfg(self): return _read_src("config/ambient_sound_params.yaml")
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def setup(self): return _read_src("setup.py")
|
||||
|
||||
def test_node_name(self, cfg): assert "ambient_sound_node:" in cfg
|
||||
def test_window_s(self, cfg): assert "window_s" in cfg
|
||||
def test_n_mels(self, cfg): assert "n_mels" in cfg
|
||||
def test_silence_db(self, cfg): assert "silence_db" in cfg
|
||||
def test_alarm_params(self, cfg): assert "alarm_db_min" in cfg
|
||||
def test_speech_params(self, cfg): assert "speech_zcr_min" in cfg
|
||||
def test_music_params(self, cfg): assert "music_zcr_max" in cfg
|
||||
def test_crowd_params(self, cfg): assert "crowd_zcr_min" in cfg
|
||||
def test_defaults_present(self, cfg): assert "-40.0" in cfg and "0.12" in cfg
|
||||
def test_entry_point(self, setup):
|
||||
assert "ambient_sound_node = saltybot_social.ambient_sound_node:main" in setup
|
||||
293
src/buzzer.c
Normal file
293
src/buzzer.c
Normal file
@ -0,0 +1,293 @@
|
||||
#include "buzzer.h"
|
||||
#include "stm32f7xx_hal.h"
|
||||
#include "config.h"
|
||||
#include <string.h>
|
||||
|
||||
/* ================================================================
|
||||
* Buzzer Hardware Configuration
|
||||
* ================================================================ */
|
||||
|
||||
#define BUZZER_PIN GPIO_PIN_8
|
||||
#define BUZZER_PORT GPIOA
|
||||
#define BUZZER_TIM TIM1
|
||||
#define BUZZER_TIM_CHANNEL TIM_CHANNEL_1
|
||||
#define BUZZER_BASE_FREQ_HZ 1000 /* Base PWM frequency (1kHz) */
|
||||
|
||||
/* ================================================================
|
||||
* Predefined Melodies
|
||||
* ================================================================ */
|
||||
|
||||
/* Startup jingle: C-E-G ascending pattern */
|
||||
const MelodyNote melody_startup[] = {
|
||||
{NOTE_C4, DURATION_QUARTER},
|
||||
{NOTE_E4, DURATION_QUARTER},
|
||||
{NOTE_G4, DURATION_QUARTER},
|
||||
{NOTE_C5, DURATION_HALF},
|
||||
{NOTE_REST, 0} /* Terminator */
|
||||
};
|
||||
|
||||
/* Low battery warning: two descending beeps */
|
||||
const MelodyNote melody_low_battery[] = {
|
||||
{NOTE_A5, DURATION_EIGHTH},
|
||||
{NOTE_REST, DURATION_EIGHTH},
|
||||
{NOTE_A5, DURATION_EIGHTH},
|
||||
{NOTE_REST, DURATION_EIGHTH},
|
||||
{NOTE_F5, DURATION_EIGHTH},
|
||||
{NOTE_REST, DURATION_EIGHTH},
|
||||
{NOTE_F5, DURATION_EIGHTH},
|
||||
{NOTE_REST, 0}
|
||||
};
|
||||
|
||||
/* Error alert: rapid repeating tone */
|
||||
const MelodyNote melody_error[] = {
|
||||
{NOTE_E5, DURATION_SIXTEENTH},
|
||||
{NOTE_REST, DURATION_SIXTEENTH},
|
||||
{NOTE_E5, DURATION_SIXTEENTH},
|
||||
{NOTE_REST, DURATION_SIXTEENTH},
|
||||
{NOTE_E5, DURATION_SIXTEENTH},
|
||||
{NOTE_REST, DURATION_SIXTEENTH},
|
||||
{NOTE_REST, 0}
|
||||
};
|
||||
|
||||
/* Docking complete: cheerful ascending chime */
|
||||
const MelodyNote melody_docking_complete[] = {
|
||||
{NOTE_C4, DURATION_EIGHTH},
|
||||
{NOTE_E4, DURATION_EIGHTH},
|
||||
{NOTE_G4, DURATION_EIGHTH},
|
||||
{NOTE_C5, DURATION_QUARTER},
|
||||
{NOTE_REST, DURATION_QUARTER},
|
||||
{NOTE_G4, DURATION_EIGHTH},
|
||||
{NOTE_C5, DURATION_HALF},
|
||||
{NOTE_REST, 0}
|
||||
};
|
||||
|
||||
/* ================================================================
|
||||
* Melody Queue
|
||||
* ================================================================ */
|
||||
|
||||
#define MELODY_QUEUE_SIZE 4
|
||||
|
||||
typedef struct {
|
||||
const MelodyNote *notes; /* Melody sequence pointer */
|
||||
uint16_t note_index; /* Current note in sequence */
|
||||
uint32_t note_start_ms; /* When current note started */
|
||||
uint32_t note_duration_ms; /* Duration of current note */
|
||||
uint16_t current_frequency; /* Current tone frequency (Hz) */
|
||||
bool is_custom; /* Is this a custom melody? */
|
||||
} MelodyPlayback;
|
||||
|
||||
typedef struct {
|
||||
MelodyPlayback queue[MELODY_QUEUE_SIZE];
|
||||
uint8_t write_index;
|
||||
uint8_t read_index;
|
||||
uint8_t count;
|
||||
} MelodyQueue;
|
||||
|
||||
static MelodyQueue s_queue = {0};
|
||||
static MelodyPlayback s_current = {0};
|
||||
static uint32_t s_last_tick_ms = 0;
|
||||
|
||||
/* ================================================================
|
||||
* Hardware Initialization
|
||||
* ================================================================ */
|
||||
|
||||
void buzzer_init(void)
|
||||
{
|
||||
/* Enable GPIO and timer clocks */
|
||||
__HAL_RCC_GPIOA_CLK_ENABLE();
|
||||
__HAL_RCC_TIM1_CLK_ENABLE();
|
||||
|
||||
/* Configure PA8 as TIM1_CH1 PWM output */
|
||||
GPIO_InitTypeDef gpio_init = {0};
|
||||
gpio_init.Pin = BUZZER_PIN;
|
||||
gpio_init.Mode = GPIO_MODE_AF_PP;
|
||||
gpio_init.Pull = GPIO_NOPULL;
|
||||
gpio_init.Speed = GPIO_SPEED_HIGH;
|
||||
gpio_init.Alternate = GPIO_AF1_TIM1;
|
||||
HAL_GPIO_Init(BUZZER_PORT, &gpio_init);
|
||||
|
||||
/* Configure TIM1 for PWM:
|
||||
* Clock: 216MHz / PSC = output frequency
|
||||
* For 1kHz base frequency: PSC = 216, ARR = 1000
|
||||
* Duty cycle = CCR / ARR (e.g., 500/1000 = 50%)
|
||||
*/
|
||||
TIM_HandleTypeDef htim1 = {0};
|
||||
htim1.Instance = BUZZER_TIM;
|
||||
htim1.Init.Prescaler = 216 - 1; /* 216MHz / 216 = 1MHz clock */
|
||||
htim1.Init.CounterMode = TIM_COUNTERMODE_UP;
|
||||
htim1.Init.Period = (1000000 / BUZZER_BASE_FREQ_HZ) - 1; /* 1kHz = 1000 counts */
|
||||
htim1.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1;
|
||||
htim1.Init.RepetitionCounter = 0;
|
||||
HAL_TIM_PWM_Init(&htim1);
|
||||
|
||||
/* Configure PWM on CH1: 50% duty cycle initially (silence will be 0%) */
|
||||
TIM_OC_InitTypeDef oc_init = {0};
|
||||
oc_init.OCMode = TIM_OCMODE_PWM1;
|
||||
oc_init.Pulse = 0; /* Start at 0% duty (silence) */
|
||||
oc_init.OCPolarity = TIM_OCPOLARITY_HIGH;
|
||||
oc_init.OCFastMode = TIM_OCFAST_DISABLE;
|
||||
HAL_TIM_PWM_ConfigChannel(&htim1, &oc_init, BUZZER_TIM_CHANNEL);
|
||||
|
||||
/* Start PWM generation */
|
||||
HAL_TIM_PWM_Start(BUZZER_TIM, BUZZER_TIM_CHANNEL);
|
||||
|
||||
/* Initialize queue */
|
||||
memset(&s_queue, 0, sizeof(s_queue));
|
||||
memset(&s_current, 0, sizeof(s_current));
|
||||
s_last_tick_ms = 0;
|
||||
}
|
||||
|
||||
/* ================================================================
|
||||
* Public API
|
||||
* ================================================================ */
|
||||
|
||||
bool buzzer_play_melody(MelodyType melody_type)
|
||||
{
|
||||
const MelodyNote *notes = NULL;
|
||||
|
||||
switch (melody_type) {
|
||||
case MELODY_STARTUP:
|
||||
notes = melody_startup;
|
||||
break;
|
||||
case MELODY_LOW_BATTERY:
|
||||
notes = melody_low_battery;
|
||||
break;
|
||||
case MELODY_ERROR:
|
||||
notes = melody_error;
|
||||
break;
|
||||
case MELODY_DOCKING_COMPLETE:
|
||||
notes = melody_docking_complete;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
return buzzer_play_custom(notes);
|
||||
}
|
||||
|
||||
bool buzzer_play_custom(const MelodyNote *notes)
|
||||
{
|
||||
if (!notes || s_queue.count >= MELODY_QUEUE_SIZE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
MelodyPlayback *playback = &s_queue.queue[s_queue.write_index];
|
||||
memset(playback, 0, sizeof(*playback));
|
||||
playback->notes = notes;
|
||||
playback->note_index = 0;
|
||||
playback->is_custom = true;
|
||||
|
||||
s_queue.write_index = (s_queue.write_index + 1) % MELODY_QUEUE_SIZE;
|
||||
s_queue.count++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool buzzer_play_tone(uint16_t frequency, uint16_t duration_ms)
|
||||
{
|
||||
if (s_queue.count >= MELODY_QUEUE_SIZE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/* Create a simple 2-note melody: tone + rest */
|
||||
static MelodyNote temp_notes[3];
|
||||
temp_notes[0].frequency = frequency;
|
||||
temp_notes[0].duration_ms = duration_ms;
|
||||
temp_notes[1].frequency = NOTE_REST;
|
||||
temp_notes[1].duration_ms = 0;
|
||||
|
||||
MelodyPlayback *playback = &s_queue.queue[s_queue.write_index];
|
||||
memset(playback, 0, sizeof(*playback));
|
||||
playback->notes = temp_notes;
|
||||
playback->note_index = 0;
|
||||
playback->is_custom = true;
|
||||
|
||||
s_queue.write_index = (s_queue.write_index + 1) % MELODY_QUEUE_SIZE;
|
||||
s_queue.count++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void buzzer_stop(void)
|
||||
{
|
||||
/* Clear queue and current playback */
|
||||
memset(&s_queue, 0, sizeof(s_queue));
|
||||
memset(&s_current, 0, sizeof(s_current));
|
||||
|
||||
/* Silence buzzer (0% duty cycle) */
|
||||
TIM1->CCR1 = 0;
|
||||
}
|
||||
|
||||
bool buzzer_is_playing(void)
|
||||
{
|
||||
return (s_current.notes != NULL) || (s_queue.count > 0);
|
||||
}
|
||||
|
||||
/* ================================================================
|
||||
* Timer Update and PWM Frequency Control
|
||||
* ================================================================ */
|
||||
|
||||
static void buzzer_set_frequency(uint16_t frequency)
|
||||
{
|
||||
if (frequency == 0) {
|
||||
/* Silence: 0% duty cycle */
|
||||
TIM1->CCR1 = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
/* Set PWM frequency and 50% duty cycle
|
||||
* TIM1 clock: 1MHz (after prescaler)
|
||||
* ARR = 1MHz / frequency
|
||||
* CCR1 = ARR / 2 (50% duty)
|
||||
*/
|
||||
uint32_t arr = (1000000 / frequency);
|
||||
if (arr > 65535) arr = 65535; /* Clamp to 16-bit */
|
||||
|
||||
TIM1->ARR = arr - 1;
|
||||
TIM1->CCR1 = arr / 2; /* 50% duty cycle for all tones */
|
||||
}
|
||||
|
||||
void buzzer_tick(uint32_t now_ms)
|
||||
{
|
||||
/* Check if current note has finished */
|
||||
if (s_current.notes != NULL) {
|
||||
uint32_t elapsed = now_ms - s_current.note_start_ms;
|
||||
|
||||
if (elapsed >= s_current.note_duration_ms) {
|
||||
/* Move to next note */
|
||||
s_current.note_index++;
|
||||
|
||||
if (s_current.notes[s_current.note_index].duration_ms == 0) {
|
||||
/* End of melody sequence */
|
||||
s_current.notes = NULL;
|
||||
buzzer_set_frequency(0);
|
||||
|
||||
/* Start next queued melody if available */
|
||||
if (s_queue.count > 0) {
|
||||
s_current = s_queue.queue[s_queue.read_index];
|
||||
s_queue.read_index = (s_queue.read_index + 1) % MELODY_QUEUE_SIZE;
|
||||
s_queue.count--;
|
||||
s_current.note_start_ms = now_ms;
|
||||
s_current.note_duration_ms = s_current.notes[0].duration_ms;
|
||||
buzzer_set_frequency(s_current.notes[0].frequency);
|
||||
}
|
||||
} else {
|
||||
/* Play next note */
|
||||
s_current.note_start_ms = now_ms;
|
||||
s_current.note_duration_ms = s_current.notes[s_current.note_index].duration_ms;
|
||||
uint16_t frequency = s_current.notes[s_current.note_index].frequency;
|
||||
buzzer_set_frequency(frequency);
|
||||
}
|
||||
}
|
||||
} else if (s_queue.count > 0 && s_current.notes == NULL) {
|
||||
/* Start first queued melody */
|
||||
s_current = s_queue.queue[s_queue.read_index];
|
||||
s_queue.read_index = (s_queue.read_index + 1) % MELODY_QUEUE_SIZE;
|
||||
s_queue.count--;
|
||||
s_current.note_start_ms = now_ms;
|
||||
s_current.note_duration_ms = s_current.notes[0].duration_ms;
|
||||
buzzer_set_frequency(s_current.notes[0].frequency);
|
||||
}
|
||||
|
||||
s_last_tick_ms = now_ms;
|
||||
}
|
||||
309
test/test_buzzer.c
Normal file
309
test/test_buzzer.c
Normal file
@ -0,0 +1,309 @@
|
||||
/*
|
||||
* test_buzzer.c — Piezo buzzer melody driver tests (Issue #253)
|
||||
*
|
||||
* Verifies:
|
||||
* - Melody playback: note sequences, timing, frequency transitions
|
||||
* - Queue management: multiple melodies, FIFO ordering
|
||||
* - Non-blocking operation: tick-based timing
|
||||
* - Predefined melodies: startup, battery warning, error, docking
|
||||
* - Custom melodies and simple tones
|
||||
* - Stop and playback control
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
#include <string.h>
|
||||
|
||||
/* ── Melody Definitions (from buzzer.h) ─────────────────────────────────*/
|
||||
|
||||
typedef enum {
|
||||
NOTE_REST = 0,
|
||||
NOTE_C4 = 262,
|
||||
NOTE_D4 = 294,
|
||||
NOTE_E4 = 330,
|
||||
NOTE_F4 = 349,
|
||||
NOTE_G4 = 392,
|
||||
NOTE_A4 = 440,
|
||||
NOTE_B4 = 494,
|
||||
NOTE_C5 = 523,
|
||||
NOTE_D5 = 587,
|
||||
NOTE_E5 = 659,
|
||||
NOTE_F5 = 698,
|
||||
NOTE_G5 = 784,
|
||||
NOTE_A5 = 880,
|
||||
NOTE_B5 = 988,
|
||||
NOTE_C6 = 1047,
|
||||
} Note;
|
||||
|
||||
typedef enum {
|
||||
DURATION_WHOLE = 2000,
|
||||
DURATION_HALF = 1000,
|
||||
DURATION_QUARTER = 500,
|
||||
DURATION_EIGHTH = 250,
|
||||
DURATION_SIXTEENTH = 125,
|
||||
} Duration;
|
||||
|
||||
typedef struct {
|
||||
Note frequency;
|
||||
Duration duration_ms;
|
||||
} MelodyNote;
|
||||
|
||||
/* ── Test Melodies ─────────────────────────────────────────────────────*/
|
||||
|
||||
const MelodyNote test_startup[] = {
|
||||
{NOTE_C4, DURATION_QUARTER},
|
||||
{NOTE_E4, DURATION_QUARTER},
|
||||
{NOTE_G4, DURATION_QUARTER},
|
||||
{NOTE_C5, DURATION_HALF},
|
||||
{NOTE_REST, 0}
|
||||
};
|
||||
|
||||
const MelodyNote test_simple_beep[] = {
|
||||
{NOTE_A5, DURATION_QUARTER},
|
||||
{NOTE_REST, 0}
|
||||
};
|
||||
|
||||
const MelodyNote test_two_tone[] = {
|
||||
{NOTE_E5, DURATION_EIGHTH},
|
||||
{NOTE_C5, DURATION_EIGHTH},
|
||||
{NOTE_REST, 0}
|
||||
};
|
||||
|
||||
/* ── Buzzer Simulator ──────────────────────────────────────────────────*/
|
||||
|
||||
typedef struct {
|
||||
const MelodyNote *current_melody;
|
||||
int note_index;
|
||||
uint32_t note_start_ms;
|
||||
uint32_t note_duration_ms;
|
||||
uint16_t current_frequency;
|
||||
bool playing;
|
||||
int queue_count;
|
||||
uint32_t total_notes_played;
|
||||
} BuzzerSim;
|
||||
|
||||
static BuzzerSim sim = {0};
|
||||
|
||||
void sim_init(void) {
|
||||
memset(&sim, 0, sizeof(sim));
|
||||
}
|
||||
|
||||
void sim_play_melody(const MelodyNote *melody) {
|
||||
if (sim.playing && sim.current_melody != NULL) {
|
||||
sim.queue_count++;
|
||||
return;
|
||||
}
|
||||
sim.current_melody = melody;
|
||||
sim.note_index = 0;
|
||||
sim.playing = true;
|
||||
sim.note_start_ms = (uint32_t)-1;
|
||||
if (melody && melody[0].duration_ms > 0) {
|
||||
sim.note_duration_ms = melody[0].duration_ms;
|
||||
sim.current_frequency = melody[0].frequency;
|
||||
}
|
||||
}
|
||||
|
||||
void sim_stop(void) {
|
||||
sim.current_melody = NULL;
|
||||
sim.playing = false;
|
||||
sim.current_frequency = 0;
|
||||
sim.queue_count = 0;
|
||||
}
|
||||
|
||||
void sim_tick(uint32_t now_ms) {
|
||||
if (!sim.playing || !sim.current_melody) return;
|
||||
if (sim.note_start_ms == (uint32_t)-1) sim.note_start_ms = now_ms;
|
||||
uint32_t elapsed = now_ms - sim.note_start_ms;
|
||||
if (elapsed >= sim.note_duration_ms) {
|
||||
sim.total_notes_played++;
|
||||
sim.note_index++;
|
||||
if (sim.current_melody[sim.note_index].duration_ms == 0) {
|
||||
sim.playing = false;
|
||||
sim.current_melody = NULL;
|
||||
sim.current_frequency = 0;
|
||||
} else {
|
||||
sim.note_start_ms = now_ms;
|
||||
sim.note_duration_ms = sim.current_melody[sim.note_index].duration_ms;
|
||||
sim.current_frequency = sim.current_melody[sim.note_index].frequency;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ── Unit Tests ────────────────────────────────────────────────────────*/
|
||||
|
||||
static int test_count = 0, test_passed = 0, test_failed = 0;
|
||||
|
||||
#define TEST(name) do { test_count++; printf("\n TEST %d: %s\n", test_count, name); } while(0)
|
||||
#define ASSERT(cond, msg) do { if (cond) { test_passed++; printf(" ✓ %s\n", msg); } else { test_failed++; printf(" ✗ %s\n", msg); } } while(0)
|
||||
|
||||
void test_melody_structure(void) {
|
||||
TEST("Melody structure validation");
|
||||
ASSERT(test_startup[0].frequency == NOTE_C4, "Startup starts at C4");
|
||||
ASSERT(test_startup[0].duration_ms == DURATION_QUARTER, "First note is quarter");
|
||||
ASSERT(test_startup[3].frequency == NOTE_C5, "Startup ends at C5");
|
||||
ASSERT(test_startup[4].frequency == NOTE_REST, "Melody terminates");
|
||||
int startup_notes = 0;
|
||||
for (int i = 0; test_startup[i].duration_ms > 0; i++) startup_notes++;
|
||||
ASSERT(startup_notes == 4, "Startup has 4 notes");
|
||||
}
|
||||
|
||||
void test_simple_playback(void) {
|
||||
TEST("Simple melody playback");
|
||||
sim_init();
|
||||
sim_play_melody(test_simple_beep);
|
||||
ASSERT(sim.playing == true, "Playback starts");
|
||||
ASSERT(sim.current_frequency == NOTE_A5, "First note is A5");
|
||||
ASSERT(sim.note_index == 0, "Index starts at 0");
|
||||
sim_tick(100);
|
||||
ASSERT(sim.playing == true, "Still playing after first tick");
|
||||
sim_tick(650);
|
||||
ASSERT(sim.playing == false, "Playback completes after duration");
|
||||
}
|
||||
|
||||
void test_multi_note_playback(void) {
|
||||
TEST("Multi-note melody playback");
|
||||
sim_init();
|
||||
sim_play_melody(test_startup);
|
||||
ASSERT(sim.playing == true, "Playback starts");
|
||||
ASSERT(sim.note_index == 0, "Index at first note");
|
||||
ASSERT(sim.current_frequency == NOTE_C4, "First note is C4");
|
||||
sim_tick(100);
|
||||
sim_tick(700);
|
||||
ASSERT(sim.note_index == 1, "Advanced to second note");
|
||||
sim_tick(1300);
|
||||
ASSERT(sim.note_index == 2, "Advanced to third note");
|
||||
sim_tick(1900);
|
||||
ASSERT(sim.note_index == 3, "Advanced to fourth note");
|
||||
sim_tick(3100);
|
||||
ASSERT(sim.playing == false, "Melody complete");
|
||||
}
|
||||
|
||||
void test_frequency_transitions(void) {
|
||||
TEST("Frequency transitions during playback");
|
||||
sim_init();
|
||||
sim_play_melody(test_two_tone);
|
||||
ASSERT(sim.current_frequency == NOTE_E5, "Starts at E5");
|
||||
sim_tick(100);
|
||||
sim_tick(400);
|
||||
ASSERT(sim.note_index == 1, "Advanced to second note");
|
||||
ASSERT(sim.current_frequency == NOTE_C5, "Now playing C5");
|
||||
sim_tick(700);
|
||||
ASSERT(sim.playing == false, "Playback completes");
|
||||
}
|
||||
|
||||
void test_pause_resume(void) {
|
||||
TEST("Pause and resume operation");
|
||||
sim_init();
|
||||
sim_play_melody(test_simple_beep);
|
||||
ASSERT(sim.playing == true, "Playing starts");
|
||||
sim_stop();
|
||||
ASSERT(sim.playing == false, "Stop silences buzzer");
|
||||
ASSERT(sim.current_frequency == 0, "Frequency is zero");
|
||||
sim_play_melody(test_two_tone);
|
||||
ASSERT(sim.playing == true, "Resume works");
|
||||
ASSERT(sim.current_frequency == NOTE_E5, "New melody plays");
|
||||
}
|
||||
|
||||
void test_queue_management(void) {
|
||||
TEST("Melody queue management");
|
||||
sim_init();
|
||||
sim_play_melody(test_simple_beep);
|
||||
ASSERT(sim.playing == true, "First melody playing");
|
||||
ASSERT(sim.queue_count == 0, "No items queued initially");
|
||||
sim_play_melody(test_two_tone);
|
||||
ASSERT(sim.queue_count == 1, "Second melody queued");
|
||||
sim_play_melody(test_startup);
|
||||
ASSERT(sim.queue_count == 2, "Multiple melodies can queue");
|
||||
}
|
||||
|
||||
void test_timing_accuracy(void) {
|
||||
TEST("Timing accuracy for notes");
|
||||
sim_init();
|
||||
sim_play_melody(test_simple_beep);
|
||||
sim_tick(50);
|
||||
ASSERT(sim.playing == true, "Still playing on first tick");
|
||||
sim_tick(600);
|
||||
ASSERT(sim.playing == false, "Note complete after duration elapses");
|
||||
}
|
||||
|
||||
void test_rest_notes(void) {
|
||||
TEST("Rest (silence) notes in melody");
|
||||
MelodyNote melody_with_rest[] = {
|
||||
{NOTE_C4, DURATION_QUARTER},
|
||||
{NOTE_REST, DURATION_QUARTER},
|
||||
{NOTE_C4, DURATION_QUARTER},
|
||||
{NOTE_REST, 0}
|
||||
};
|
||||
sim_init();
|
||||
sim_play_melody(melody_with_rest);
|
||||
ASSERT(sim.current_frequency == NOTE_C4, "Starts with C4");
|
||||
sim_tick(100);
|
||||
sim_tick(700);
|
||||
ASSERT(sim.note_index == 1, "Advanced to rest");
|
||||
ASSERT(sim.current_frequency == NOTE_REST, "Rest note active");
|
||||
sim_tick(1300);
|
||||
ASSERT(sim.current_frequency == NOTE_C4, "Back to C4 after rest");
|
||||
sim_tick(1900);
|
||||
ASSERT(sim.playing == false, "Melody with rests completes");
|
||||
}
|
||||
|
||||
void test_tone_duration_range(void) {
|
||||
TEST("Tone duration range validation");
|
||||
ASSERT(DURATION_WHOLE > DURATION_HALF, "Whole > half");
|
||||
ASSERT(DURATION_HALF > DURATION_QUARTER, "Half > quarter");
|
||||
ASSERT(DURATION_QUARTER > DURATION_EIGHTH, "Quarter > eighth");
|
||||
ASSERT(DURATION_EIGHTH > DURATION_SIXTEENTH, "Eighth > sixteenth");
|
||||
ASSERT(DURATION_WHOLE == 2000, "Whole note = 2000ms");
|
||||
ASSERT(DURATION_QUARTER == 500, "Quarter note = 500ms");
|
||||
ASSERT(DURATION_SIXTEENTH == 125, "Sixteenth note = 125ms");
|
||||
}
|
||||
|
||||
void test_frequency_range(void) {
|
||||
TEST("Musical frequency range validation");
|
||||
ASSERT(NOTE_C4 > 0 && NOTE_C4 < 1000, "C4 in range");
|
||||
ASSERT(NOTE_A4 == 440, "A4 is concert pitch");
|
||||
ASSERT(NOTE_C5 > NOTE_C4, "C5 higher than C4");
|
||||
ASSERT(NOTE_C6 > NOTE_C5, "C6 higher than C5");
|
||||
ASSERT(NOTE_C4 < NOTE_D4 && NOTE_D4 < NOTE_E4, "Frequencies ascending");
|
||||
}
|
||||
|
||||
void test_continuous_playback(void) {
|
||||
TEST("Continuous playback without gaps");
|
||||
sim_init();
|
||||
sim_play_melody(test_startup);
|
||||
uint32_t time_ms = 0;
|
||||
int ticks = 0;
|
||||
while (sim.playing && ticks < 100) {
|
||||
sim_tick(time_ms);
|
||||
time_ms += 100;
|
||||
ticks++;
|
||||
}
|
||||
ASSERT(!sim.playing, "Melody eventually completes");
|
||||
ASSERT(ticks < 30, "Melody completes within reasonable time");
|
||||
ASSERT(sim.total_notes_played == 4, "All 4 notes played");
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
printf("\n══════════════════════════════════════════════════════════════\n");
|
||||
printf(" Piezo Buzzer Melody Driver — Unit Tests (Issue #253)\n");
|
||||
printf("══════════════════════════════════════════════════════════════\n");
|
||||
|
||||
test_melody_structure();
|
||||
test_simple_playback();
|
||||
test_multi_note_playback();
|
||||
test_frequency_transitions();
|
||||
test_pause_resume();
|
||||
test_queue_management();
|
||||
test_timing_accuracy();
|
||||
test_rest_notes();
|
||||
test_tone_duration_range();
|
||||
test_frequency_range();
|
||||
test_continuous_playback();
|
||||
|
||||
printf("\n──────────────────────────────────────────────────────────────\n");
|
||||
printf(" Results: %d/%d tests passed, %d failed\n", test_passed, test_count, test_failed);
|
||||
printf("──────────────────────────────────────────────────────────────\n\n");
|
||||
|
||||
return (test_failed == 0) ? 0 : 1;
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user